├── .DS_Store ├── Dehaze ├── Options │ └── RealDehazing_HINT.yml ├── evaluate_SOTS.py ├── test_SOTS_HINT.py └── utils.py ├── Denoising ├── Options │ └── GaussianColorDenoising_HINT.yml ├── evaluate_gaussian_color_denoising_HINT.py ├── test_gaussian_color_denoising_HINT.py └── utils.py ├── Deraining ├── Options │ └── Deraining_HINT_syn_rain100L.yml ├── evaluate_PSNR_SSIM.m ├── test_rain100L.py └── utils.py ├── Desnowing ├── Options │ └── Desnow_snow100k_HINT.yml ├── evaluate_Snow100k.py ├── test_snow100k.py └── utils.py ├── Enhancement ├── Options │ ├── HINT_LOL_v2_real.yml │ └── HINT_LOL_v2_synthetic.yml ├── test_from_dataset_LOLv2_Real.py ├── test_from_dataset_LOLv2_Syn.py └── utils.py ├── README.md ├── VERSION ├── basicsr ├── .DS_Store ├── __pycache__ │ ├── version.cpython-37.pyc │ └── version.cpython-38.pyc ├── data │ ├── SDSD_image_dataset.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── SDSD_image_dataset.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── data_sampler.cpython-37.pyc │ │ ├── data_util.cpython-37.pyc │ │ ├── ffhq_dataset.cpython-37.pyc │ │ ├── paired_image_dataset.cpython-37.pyc │ │ ├── prefetch_dataloader.cpython-37.pyc │ │ ├── reds_dataset.cpython-37.pyc │ │ ├── single_image_dataset.cpython-37.pyc │ │ ├── transforms.cpython-37.pyc │ │ ├── util.cpython-37.pyc │ │ ├── video_test_dataset.cpython-37.pyc │ │ └── vimeo90k_dataset.cpython-37.pyc │ ├── data_sampler.py │ ├── data_util.py │ ├── ffhq_dataset.py │ ├── meta_info │ │ ├── meta_info_DIV2K800sub_GT.txt │ │ ├── meta_info_REDS4_test_GT.txt │ │ ├── meta_info_REDS_GT.txt │ │ ├── meta_info_REDSofficial4_test_GT.txt │ │ ├── meta_info_REDSval_official_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_fast_GT.txt │ │ ├── meta_info_Vimeo90K_test_medium_GT.txt │ │ ├── meta_info_Vimeo90K_test_slow_GT.txt │ │ └── meta_info_Vimeo90K_train_GT.txt │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── reds_dataset.py │ ├── single_image_dataset.py │ ├── 
transforms.py │ ├── util.py │ ├── video_test_dataset.py │ └── vimeo90k_dataset.py ├── metrics │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── metric_util.cpython-37.pyc │ │ ├── metric_util.cpython-38.pyc │ │ ├── niqe.cpython-37.pyc │ │ ├── niqe.cpython-38.pyc │ │ ├── psnr_ssim.cpython-37.pyc │ │ └── psnr_ssim.cpython-38.pyc │ ├── fid.py │ ├── metric_util.py │ ├── niqe.py │ ├── niqe_pris_params.npz │ └── psnr_ssim.py ├── models │ ├── .DS_Store │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── base_model.cpython-37.pyc │ │ ├── base_model.cpython-38.pyc │ │ ├── image_restoration_model.cpython-37.pyc │ │ ├── image_restoration_model.cpython-38.pyc │ │ ├── lr_scheduler.cpython-37.pyc │ │ └── lr_scheduler.cpython-38.pyc │ ├── archs │ │ ├── HINT_arch.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── FPro_arch.cpython-37.pyc │ │ │ ├── HINT_arch.cpython-37.pyc │ │ │ ├── HINT_arch.cpython-38.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── arch_util.cpython-37.pyc │ │ │ ├── graph_layers.cpython-37.pyc │ │ │ ├── local_arch.cpython-37.pyc │ │ │ ├── restormer_arch.cpython-37.pyc │ │ │ ├── restormer_arch.py │ │ │ └── restormer_local_arch.cpython-37.pyc │ │ └── arch_util.py │ ├── base_model.py │ ├── image_restoration_model.py │ ├── losses │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── loss_util.cpython-37.pyc │ │ │ ├── loss_util.cpython-38.pyc │ │ │ ├── losses.cpython-37.pyc │ │ │ └── losses.cpython-38.pyc │ │ ├── loss_util.py │ │ └── losses.py │ └── lr_scheduler.py ├── test.py ├── train.py ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── create_lmdb.cpython-37.pyc │ │ ├── create_lmdb.cpython-38.pyc │ │ ├── dist_util.cpython-37.pyc │ │ ├── dist_util.cpython-38.pyc │ │ ├── 
file_client.cpython-37.pyc │ │ ├── file_client.cpython-38.pyc │ │ ├── flow_util.cpython-37.pyc │ │ ├── img_util.cpython-37.pyc │ │ ├── img_util.cpython-38.pyc │ │ ├── lmdb_util.cpython-37.pyc │ │ ├── lmdb_util.cpython-38.pyc │ │ ├── logger.cpython-37.pyc │ │ ├── logger.cpython-38.pyc │ │ ├── matlab_functions.cpython-37.pyc │ │ ├── matlab_functions.cpython-38.pyc │ │ ├── misc.cpython-37.pyc │ │ ├── misc.cpython-38.pyc │ │ ├── options.cpython-37.pyc │ │ └── options.cpython-38.pyc │ ├── bundle_submissions.py │ ├── create_lmdb.py │ ├── dist_util.py │ ├── download_util.py │ ├── face_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ └── options.py └── version.py ├── environment.yml ├── setup.py ├── test.sh └── train.sh /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/.DS_Store -------------------------------------------------------------------------------- /Dehaze/Options/RealDehazing_HINT.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: Dehazing_HINT 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage_dehazeSOT 13 | dataroot_gt: ./dataset/haze 14 | dataroot_lq: ./dataset/haze 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ## ------- Training on single fixed-patch size 128x128--------- 27 | mini_batch_sizes: [6,1] 28 | iters: [200000,100000] 29 | gt_size: 256 30 | gt_sizes: [128,256] 31 | ## ------------------------------------------------------------ 32 
| 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: ValSet 38 | type: Dataset_PairedImage_dehazeSOT 39 | dataroot_gt: ./dataset/haze 40 | dataroot_lq: ./dataset/haze 41 | gt_size: 256 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | 47 | # network structures 48 | network_g: 49 | type: HINT 50 | inp_channels: 3 51 | out_channels: 3 52 | dim: 48 53 | num_blocks: [4,6,6,8] 54 | num_refinement_blocks: 4 55 | heads: [8,8,8,8] 56 | ffn_expansion_factor: 2.66 57 | bias: False 58 | LayerNorm_type: WithBias 59 | dual_pixel_task: False 60 | 61 | 62 | # path 63 | path: 64 | pretrain_network_g: ~ 65 | strict_load_g: true 66 | resume_state: ~ 67 | 68 | # training settings 69 | train: 70 | total_iter: 300000 71 | warmup_iter: -1 # no warm up 72 | use_grad_clip: true 73 | 74 | # Split 300k iterations into two cycles. 75 | # 1st cycle: fixed 3e-4 LR for 92k iters. 76 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 77 | scheduler: 78 | type: CosineAnnealingRestartCyclicLR 79 | periods: [92000, 208000] 80 | restart_weights: [1,1] 81 | eta_mins: [0.0003,0.000001] 82 | 83 | mixing_augs: 84 | mixup: true 85 | mixup_beta: 1.2 86 | use_identity: true 87 | 88 | optim_g: 89 | type: AdamW 90 | lr: !!float 3e-4 91 | weight_decay: !!float 1e-4 92 | betas: [0.9, 0.999] 93 | 94 | # losses 95 | pixel_opt: 96 | type: L1Loss 97 | loss_weight: 1 98 | reduction: mean 99 | fft_loss_opt: 100 | type: FFTLoss 101 | loss_weight: 0.1 102 | reduction: mean 103 | 104 | # validation settings 105 | val: 106 | window_size: 8 107 | val_freq: !!float 4e3 108 | save_img: false 109 | rgb2bgr: true 110 | use_image: false 111 | max_minibatch: 8 112 | 113 | metrics: 114 | psnr: # metric name, can be arbitrary 115 | type: calculate_psnr 116 | crop_border: 0 117 | test_y_channel: false 118 | 119 | # logging settings 120 | logger: 121 | print_freq: 1000 122 | save_checkpoint_freq: !!float 4e3 123 | use_tb_logger: true 124 | wandb: 125 | project: ~ 126 | 
resume_id: ~ 127 | 128 | # dist training settings 129 | dist_params: 130 | backend: nccl 131 | port: 29500 132 | -------------------------------------------------------------------------------- /Dehaze/evaluate_SOTS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from glob import glob 4 | from natsort import natsorted 5 | from skimage import io 6 | import cv2 7 | import argparse 8 | from skimage.metrics import structural_similarity 9 | from tqdm import tqdm 10 | import concurrent.futures 11 | import utils 12 | 13 | def proc(filename): 14 | tar,prd = filename 15 | prd_name = prd.split('/')[-1].split('_')[0]+'.png' 16 | tar_name = './dataset/haze/promptIR/outdoor/gt/' + prd_name 17 | tar_img = utils.load_img(tar_name) 18 | prd_img = utils.load_img(prd) 19 | 20 | PSNR = utils.calculate_psnr(tar_img, prd_img) 21 | SSIM = utils.calculate_ssim(tar_img, prd_img) 22 | return PSNR,SSIM 23 | 24 | parser = argparse.ArgumentParser(description='Dehazing using HINT') 25 | 26 | args = parser.parse_args() 27 | 28 | 29 | datasets = ['outdoor'] 30 | 31 | for dataset in datasets: 32 | 33 | gt_path = os.path.join('./dataset/haze/promptIR/outdoor/gt') 34 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif'))) 35 | assert len(gt_list) != 0, "Target files not found" 36 | 37 | 38 | file_path = os.path.join('results', 'HINT', dataset) 39 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif'))) 40 | assert len(path_list) != 0, "Predicted files not found" 41 | 42 | psnr, ssim = [], [] 43 | img_files =[(i, j) for i,j in zip(gt_list,path_list)] 44 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: 45 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)): 46 | psnr.append(PSNR_SSIM[0]) 47 | ssim.append(PSNR_SSIM[1]) 48 | 49 | avg_psnr = sum(psnr)/len(psnr) 50 | avg_ssim = 
sum(ssim)/len(ssim) 51 | 52 | # print('For {:s} dataset PSNR: {:f}\n'.format(dataset, avg_psnr)) 53 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim)) 54 | -------------------------------------------------------------------------------- /Dehaze/test_SOTS_HINT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import torch.nn as nn 7 | import torch 8 | import torch.nn.functional as F 9 | import utils 10 | 11 | from natsort import natsorted 12 | from glob import glob 13 | from basicsr.models.archs.HINT_arch import HINT 14 | from skimage import img_as_ubyte 15 | from pdb import set_trace as stx 16 | 17 | parser = argparse.ArgumentParser(description='Image Dehazning using HINT') 18 | 19 | parser.add_argument('--input_dir', default='./dataset/haze/promptIR/', type=str, help='Directory of validation images') 20 | parser.add_argument('--result_dir', default='./results/HINT/', type=str, help='Directory for results') 21 | parser.add_argument('--weights', default='./models/Dehazing.pth', type=str, help='Path to weights') 22 | 23 | args = parser.parse_args() 24 | 25 | def splitimage(imgtensor, crop_size=128, overlap_size=64): 26 | _, C, H, W = imgtensor.shape 27 | hstarts = [x for x in range(0, H, crop_size - overlap_size)] 28 | while hstarts and hstarts[-1] + crop_size >= H: 29 | hstarts.pop() 30 | hstarts.append(H - crop_size) 31 | wstarts = [x for x in range(0, W, crop_size - overlap_size)] 32 | while wstarts and wstarts[-1] + crop_size >= W: 33 | wstarts.pop() 34 | wstarts.append(W - crop_size) 35 | starts = [] 36 | split_data = [] 37 | for hs in hstarts: 38 | for ws in wstarts: 39 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size] 40 | starts.append((hs, ws)) 41 | split_data.append(cimgdata) 42 | return split_data, starts 43 | 44 | def get_scoremap(H, W, C, B=1, is_mean=True): 45 | center_h = H / 
2 46 | center_w = W / 2 47 | 48 | score = torch.ones((B, C, H, W)) 49 | if not is_mean: 50 | for h in range(H): 51 | for w in range(W): 52 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6)) 53 | return score 54 | 55 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)): 56 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3] 57 | tot_score = torch.zeros((B, C, H, W)) 58 | merge_img = torch.zeros((B, C, H, W)) 59 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True) 60 | for simg, cstart in zip(split_data, starts): 61 | hs, ws = cstart 62 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg 63 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap 64 | merge_img = merge_img / tot_score 65 | return merge_img 66 | 67 | ####### Load yaml ####### 68 | yaml_file = 'Options/RealDehazing_HINT.yml' 69 | import yaml 70 | 71 | try: 72 | from yaml import CLoader as Loader 73 | except ImportError: 74 | from yaml import Loader 75 | 76 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader) 77 | 78 | s = x['network_g'].pop('type') 79 | ########################## 80 | 81 | model_restoration = HINT(**x['network_g']) 82 | 83 | checkpoint = torch.load(args.weights) 84 | model_restoration.load_state_dict(checkpoint['params']) 85 | print("===>Testing using weights: ",args.weights) 86 | model_restoration.cuda() 87 | model_restoration = nn.DataParallel(model_restoration) 88 | model_restoration.eval() 89 | 90 | 91 | factor = 8 92 | datasets = ['outdoor'] 93 | 94 | for dataset in datasets: 95 | result_dir = os.path.join(args.result_dir, dataset) 96 | os.makedirs(result_dir, exist_ok=True) 97 | 98 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain') 99 | inp_dir = os.path.join(args.input_dir, dataset, 'hazy/') 100 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg'))) 101 | with torch.no_grad(): 
102 | for file_ in tqdm(files): 103 | torch.cuda.ipc_collect() 104 | torch.cuda.empty_cache() 105 | 106 | img = np.float32(utils.load_img(file_))/255. 107 | img = torch.from_numpy(img).permute(2,0,1) 108 | input_ = img.unsqueeze(0).cuda() 109 | 110 | # Padding in case images are not multiples of 8 111 | h,w = input_.shape[2], input_.shape[3] 112 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor 113 | padh = H-h if h%factor!=0 else 0 114 | padw = W-w if w%factor!=0 else 0 115 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 116 | 117 | restored = model_restoration(input_) 118 | 119 | restored = restored[:,:,:h,:w] 120 | 121 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 122 | 123 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored)) 124 | -------------------------------------------------------------------------------- /Dehaze/utils.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import cv2 8 | import math 9 | 10 | def calculate_psnr(img1, img2, border=0): 11 | # img1 and img2 have range [0, 255] 12 | #img1 = img1.squeeze() 13 | #img2 = img2.squeeze() 14 | if not img1.shape == img2.shape: 15 | raise ValueError('Input images must have the same dimensions.') 16 | h, w = img1.shape[:2] 17 | img1 = img1[border:h-border, border:w-border] 18 | img2 = img2[border:h-border, border:w-border] 19 | 20 | img1 = img1.astype(np.float64) 21 | img2 = img2.astype(np.float64) 22 | mse = np.mean((img1 - img2)**2) 23 | if mse == 0: 24 | return float('inf') 25 | return 20 * math.log10(255.0 / math.sqrt(mse)) 26 | 27 | 28 | # 
-------------------------------------------- 29 | # SSIM 30 | # -------------------------------------------- 31 | def calculate_ssim(img1, img2, border=0): 32 | '''calculate SSIM 33 | the same outputs as MATLAB's 34 | img1, img2: [0, 255] 35 | ''' 36 | #img1 = img1.squeeze() 37 | #img2 = img2.squeeze() 38 | if not img1.shape == img2.shape: 39 | raise ValueError('Input images must have the same dimensions.') 40 | h, w = img1.shape[:2] 41 | img1 = img1[border:h-border, border:w-border] 42 | img2 = img2[border:h-border, border:w-border] 43 | 44 | if img1.ndim == 2: 45 | return ssim(img1, img2) 46 | elif img1.ndim == 3: 47 | if img1.shape[2] == 3: 48 | ssims = [] 49 | for i in range(3): 50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i])) 51 | return np.array(ssims).mean() 52 | elif img1.shape[2] == 1: 53 | return ssim(np.squeeze(img1), np.squeeze(img2)) 54 | else: 55 | raise ValueError('Wrong input image dimensions.') 56 | 57 | 58 | def ssim(img1, img2): 59 | C1 = (0.01 * 255)**2 60 | C2 = (0.03 * 255)**2 61 | 62 | img1 = img1.astype(np.float64) 63 | img2 = img2.astype(np.float64) 64 | kernel = cv2.getGaussianKernel(11, 1.5) 65 | window = np.outer(kernel, kernel.transpose()) 66 | 67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 69 | mu1_sq = mu1**2 70 | mu2_sq = mu2**2 71 | mu1_mu2 = mu1 * mu2 72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 75 | 76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 77 | (sigma1_sq + sigma2_sq + C2)) 78 | return ssim_map.mean() 79 | 80 | def load_img(filepath): 81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 82 | 83 | def save_img(filepath, img): 84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 85 | 86 | def load_gray_img(filepath): 87 | 
return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2) 88 | 89 | def save_gray_img(filepath, img): 90 | cv2.imwrite(filepath, img) 91 | -------------------------------------------------------------------------------- /Denoising/Options/GaussianColorDenoising_HINT.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: GaussianColorDenoising_HINT 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_GaussianDenoising 13 | sigma_type: random 14 | sigma_range: [0,50] 15 | in_ch: 3 ## RGB image 16 | dataroot_gt: ./Denoising/Datasets/train/WB 17 | dataroot_lq: none 18 | geometric_augs: true 19 | 20 | filename_tmpl: '{}' 21 | io_backend: 22 | type: disk 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | 29 | # -------------Progressive training-------------------------- 30 | mini_batch_sizes: [6,4,3,1,1,1] # Batch size per gpu 31 | iters: [92000,64000,48000,36000,36000,24000] 32 | gt_size: 256 # Max patch size for progressive training 33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 
34 | ### ------------------------------------------------------------ 35 | 36 | dataset_enlarge_ratio: 1 37 | prefetch_mode: ~ 38 | 39 | val: 40 | name: ValSet 41 | type: Dataset_GaussianDenoising 42 | sigma_test: 25 43 | in_ch: 3 ## RGB image 44 | dataroot_gt: ./Denoising/Datasets/test/CBSD68 45 | dataroot_lq: none 46 | gt_size: 256 47 | io_backend: 48 | type: disk 49 | 50 | # network structures 51 | network_g: 52 | type: HINT 53 | inp_channels: 3 54 | out_channels: 3 55 | dim: 48 56 | num_blocks: [4,6,6,8] 57 | num_refinement_blocks: 4 58 | heads: [8,8,8,8] 59 | ffn_expansion_factor: 2.66 60 | bias: False 61 | LayerNorm_type: WithBias 62 | dual_pixel_task: False 63 | # path 64 | path: 65 | pretrain_network_g: ~ 66 | strict_load_g: true 67 | resume_state: ~ 68 | 69 | # training settings 70 | train: 71 | total_iter: 300000 72 | warmup_iter: -1 # no warm up 73 | use_grad_clip: true 74 | 75 | # Split 300k iterations into two cycles. 76 | # 1st cycle: fixed 3e-4 LR for 92k iters. 77 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
78 | scheduler: 79 | type: CosineAnnealingRestartCyclicLR 80 | periods: [92000, 208000] 81 | restart_weights: [1,1] 82 | eta_mins: [0.0003,0.000001] 83 | 84 | mixing_augs: 85 | mixup: true 86 | mixup_beta: 1.2 87 | use_identity: true 88 | 89 | optim_g: 90 | type: AdamW 91 | lr: !!float 3e-4 92 | weight_decay: !!float 1e-4 93 | betas: [0.9, 0.999] 94 | 95 | # losses 96 | pixel_opt: 97 | type: L1Loss 98 | loss_weight: 1 99 | reduction: mean 100 | fft_loss_opt: 101 | type: FFTLoss 102 | loss_weight: 0.1 103 | reduction: mean 104 | # validation settings 105 | val: 106 | window_size: 8 107 | val_freq: !!float 4e3 108 | save_img: false 109 | rgb2bgr: true 110 | use_image: false 111 | max_minibatch: 8 112 | 113 | metrics: 114 | psnr: # metric name, can be arbitrary 115 | type: calculate_psnr 116 | crop_border: 0 117 | test_y_channel: false 118 | 119 | # logging settings 120 | logger: 121 | print_freq: 1000 122 | save_checkpoint_freq: !!float 4e3 123 | use_tb_logger: true 124 | wandb: 125 | project: ~ 126 | resume_id: ~ 127 | 128 | # dist training settings 129 | dist_params: 130 | backend: nccl 131 | port: 29500 132 | -------------------------------------------------------------------------------- /Denoising/evaluate_gaussian_color_denoising_HINT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from glob import glob 4 | from natsort import natsorted 5 | from skimage import io 6 | import cv2 7 | import argparse 8 | from skimage.metrics import structural_similarity 9 | from tqdm import tqdm 10 | import concurrent.futures 11 | import utils 12 | 13 | def proc(filename): 14 | tar,prd = filename 15 | tar_img = utils.load_img(tar) 16 | prd_img = utils.load_img(prd) 17 | 18 | PSNR = utils.calculate_psnr(tar_img, prd_img) 19 | SSIM = utils.calculate_ssim(tar_img, prd_img) 20 | return PSNR,SSIM 21 | 22 | parser = argparse.ArgumentParser(description='Gasussian Color Denoising using HINT') 23 | 24 | 
parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. non_blind: separate model for each noise level.') 25 | parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values') 26 | 27 | args = parser.parse_args() 28 | 29 | sigmas = np.int_(args.sigmas.split(',')) 30 | 31 | datasets = ['CBSD68','Urban100'] 32 | 33 | for dataset in datasets: 34 | 35 | gt_path = os.path.join('./Denoising/Datasets','test', dataset) 36 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif'))) 37 | assert len(gt_list) != 0, "Target files not found" 38 | 39 | for sigma_test in sigmas: 40 | file_path = os.path.join('results', 'Gaussian_Color_Denoising', args.model_type, dataset, str(sigma_test)) 41 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif'))) 42 | assert len(path_list) != 0, "Predicted files not found" 43 | 44 | psnr, ssim = [], [] 45 | img_files =[(i, j) for i,j in zip(gt_list,path_list)] 46 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: 47 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)): 48 | psnr.append(PSNR_SSIM[0]) 49 | ssim.append(PSNR_SSIM[1]) 50 | 51 | avg_psnr = sum(psnr)/len(psnr) 52 | avg_ssim = sum(ssim)/len(ssim) 53 | 54 | print('For {:s} dataset Noise Level {:d} PSNR: {:f}\n'.format(dataset, sigma_test, avg_psnr)) 55 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim)) 56 | -------------------------------------------------------------------------------- /Denoising/test_gaussian_color_denoising_HINT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import torch.nn as nn 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from 
basicsr.models.archs.HINT_arch import HINT 11 | from skimage import img_as_ubyte 12 | from natsort import natsorted 13 | from glob import glob 14 | import utils 15 | from pdb import set_trace as stx 16 | 17 | parser = argparse.ArgumentParser(description='Gaussian Color Denoising using HINT') 18 | 19 | parser.add_argument('--input_dir', default='./Denoising/Datasets/test/', type=str, help='Directory of validation images') 20 | parser.add_argument('--result_dir', default='./results/Gaussian_Color_Denoising/', type=str, help='Directory for results') 21 | parser.add_argument('--weights', default='./models/net_g_latest', type=str, help='Path to weights') 22 | parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. non_blind: separate model for each noise level.') 23 | parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values') 24 | 25 | args = parser.parse_args() 26 | 27 | ####### Load yaml ####### 28 | if args.model_type == 'blind': 29 | yaml_file = 'Options/GaussianColorDenoising_HINT.yml' 30 | else: 31 | yaml_file = f'Options/GaussianColorDenoising_RestormerSigma{args.sigmas}.yml' 32 | import yaml 33 | 34 | try: 35 | from yaml import CLoader as Loader 36 | except ImportError: 37 | from yaml import Loader 38 | 39 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader) 40 | 41 | s = x['network_g'].pop('type') 42 | ########################## 43 | 44 | sigmas = np.int_(args.sigmas.split(',')) 45 | 46 | factor = 8 47 | 48 | datasets = ['CBSD68','Urban100'] 49 | 50 | for sigma_test in sigmas: 51 | print("Compute results for noise level",sigma_test) 52 | model_restoration = HINT(**x['network_g']) 53 | if args.model_type == 'blind': 54 | weights = args.weights+'_blind.pth' 55 | else: 56 | weights = args.weights + '_sigma' + str(sigma_test) +'.pth' 57 | checkpoint = torch.load(weights) 58 | model_restoration.load_state_dict(checkpoint['params']) 59 | 60 | 
print("===>Testing using weights: ",weights) 61 | print("------------------------------------------------") 62 | model_restoration.cuda() 63 | model_restoration = nn.DataParallel(model_restoration) 64 | model_restoration.eval() 65 | 66 | for dataset in datasets: 67 | inp_dir = os.path.join(args.input_dir, dataset) 68 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.tif'))) 69 | result_dir_tmp = os.path.join(args.result_dir, args.model_type, dataset, str(sigma_test)) 70 | os.makedirs(result_dir_tmp, exist_ok=True) 71 | 72 | with torch.no_grad(): 73 | for file_ in tqdm(files): 74 | torch.cuda.ipc_collect() 75 | torch.cuda.empty_cache() 76 | img = np.float32(utils.load_img(file_))/255. 77 | 78 | np.random.seed(seed=0) # for reproducibility 79 | img += np.random.normal(0, sigma_test/255., img.shape) 80 | 81 | img = torch.from_numpy(img).permute(2,0,1) 82 | input_ = img.unsqueeze(0).cuda() 83 | 84 | # Padding in case images are not multiples of 8 85 | h,w = input_.shape[2], input_.shape[3] 86 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor 87 | padh = H-h if h%factor!=0 else 0 88 | padw = W-w if w%factor!=0 else 0 89 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 90 | 91 | restored = model_restoration(input_) 92 | 93 | # Unpad images to original dimensions 94 | restored = restored[:,:,:h,:w] 95 | 96 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 97 | 98 | save_file = os.path.join(result_dir_tmp, os.path.split(file_)[-1]) 99 | utils.save_img(save_file, img_as_ubyte(restored)) 100 | -------------------------------------------------------------------------------- /Denoising/utils.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## 
https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import cv2 8 | import math 9 | 10 | def calculate_psnr(img1, img2, border=0): 11 | # img1 and img2 have range [0, 255] 12 | #img1 = img1.squeeze() 13 | #img2 = img2.squeeze() 14 | if not img1.shape == img2.shape: 15 | raise ValueError('Input images must have the same dimensions.') 16 | h, w = img1.shape[:2] 17 | img1 = img1[border:h-border, border:w-border] 18 | img2 = img2[border:h-border, border:w-border] 19 | 20 | img1 = img1.astype(np.float64) 21 | img2 = img2.astype(np.float64) 22 | mse = np.mean((img1 - img2)**2) 23 | if mse == 0: 24 | return float('inf') 25 | return 20 * math.log10(255.0 / math.sqrt(mse)) 26 | 27 | 28 | # -------------------------------------------- 29 | # SSIM 30 | # -------------------------------------------- 31 | def calculate_ssim(img1, img2, border=0): 32 | '''calculate SSIM 33 | the same outputs as MATLAB's 34 | img1, img2: [0, 255] 35 | ''' 36 | #img1 = img1.squeeze() 37 | #img2 = img2.squeeze() 38 | if not img1.shape == img2.shape: 39 | raise ValueError('Input images must have the same dimensions.') 40 | h, w = img1.shape[:2] 41 | img1 = img1[border:h-border, border:w-border] 42 | img2 = img2[border:h-border, border:w-border] 43 | 44 | if img1.ndim == 2: 45 | return ssim(img1, img2) 46 | elif img1.ndim == 3: 47 | if img1.shape[2] == 3: 48 | ssims = [] 49 | for i in range(3): 50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i])) 51 | return np.array(ssims).mean() 52 | elif img1.shape[2] == 1: 53 | return ssim(np.squeeze(img1), np.squeeze(img2)) 54 | else: 55 | raise ValueError('Wrong input image dimensions.') 56 | 57 | 58 | def ssim(img1, img2): 59 | C1 = (0.01 * 255)**2 60 | C2 = (0.03 * 255)**2 61 | 62 | img1 = img1.astype(np.float64) 63 | img2 = img2.astype(np.float64) 64 | kernel = cv2.getGaussianKernel(11, 1.5) 65 | window = np.outer(kernel, kernel.transpose()) 66 | 67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 68 | mu2 = cv2.filter2D(img2, 
-1, window)[5:-5, 5:-5] 69 | mu1_sq = mu1**2 70 | mu2_sq = mu2**2 71 | mu1_mu2 = mu1 * mu2 72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 75 | 76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 77 | (sigma1_sq + sigma2_sq + C2)) 78 | return ssim_map.mean() 79 | 80 | def load_img(filepath): 81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 82 | 83 | def save_img(filepath, img): 84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 85 | 86 | def load_gray_img(filepath): 87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2) 88 | 89 | def save_gray_img(filepath, img): 90 | cv2.imwrite(filepath, img) 91 | -------------------------------------------------------------------------------- /Deraining/Options/Deraining_HINT_syn_rain100L.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: Deraining_HINT_rain100L 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 4 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage 13 | dataroot_gt: ./dataset/Rain100L/train/clean 14 | dataroot_lq: ./dataset/Rain100L/train/rainy 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ### -------------Progressive training-------------------------- 27 | mini_batch_sizes: [6,4,3,1] # Batch size per gpu 28 | iters: [92000,64000,48000,96000] 29 | gt_size: 384 # Max patch size for progressive training 30 | gt_sizes: [128,160,192,256] # Patch sizes for progressive training. 
31 | ### ------------------------------------------------------------ 32 | 33 | ### ------- Training on single fixed-patch size 128x128--------- 34 | # mini_batch_sizes: [8] 35 | # iters: [300000] 36 | # gt_size: 128 37 | # gt_sizes: [128] 38 | ### ------------------------------------------------------------ 39 | 40 | dataset_enlarge_ratio: 1 41 | prefetch_mode: ~ 42 | 43 | val: 44 | name: ValSet 45 | type: Dataset_PairedImage 46 | dataroot_gt: ./dataset/Rain100L/test/clean 47 | dataroot_lq: ./dataset/Rain100L/test/rainy 48 | io_backend: 49 | type: disk 50 | 51 | # network structures 52 | network_g: 53 | type: HINT 54 | inp_channels: 3 55 | out_channels: 3 56 | dim: 48 57 | num_blocks: [4,6,6,8] 58 | num_refinement_blocks: 4 59 | heads: [8,8,8,8] 60 | ffn_expansion_factor: 2.66 61 | bias: False 62 | LayerNorm_type: WithBias 63 | dual_pixel_task: false 64 | 65 | 66 | # path 67 | path: 68 | pretrain_network_g: ~ 69 | strict_load_g: true 70 | resume_state: ~ 71 | 72 | # training settings 73 | train: 74 | total_iter: 300000 75 | warmup_iter: -1 # no warm up 76 | use_grad_clip: true 77 | 78 | # Split 300k iterations into two cycles. 79 | # 1st cycle: fixed 3e-4 LR for 92k iters. 80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
81 | scheduler: 82 | type: CosineAnnealingRestartCyclicLR 83 | periods: [92000, 208000] 84 | restart_weights: [1,1] 85 | eta_mins: [0.0003,0.000001] 86 | 87 | mixing_augs: 88 | mixup: false 89 | mixup_beta: 1.2 90 | use_identity: true 91 | 92 | optim_g: 93 | type: AdamW 94 | lr: !!float 3e-4 95 | weight_decay: !!float 1e-4 96 | betas: [0.9, 0.999] 97 | 98 | # losses 99 | pixel_opt: 100 | type: L1Loss 101 | loss_weight: 1 102 | reduction: mean 103 | fft_loss_opt: 104 | type: FFTLoss 105 | loss_weight: 0.1 106 | reduction: mean 107 | 108 | 109 | # validation settings 110 | val: 111 | window_size: 8 112 | val_freq: !!float 4e3 113 | save_img: false 114 | rgb2bgr: true 115 | use_image: true 116 | max_minibatch: 8 117 | 118 | metrics: 119 | psnr: # metric name, can be arbitrary 120 | type: calculate_psnr 121 | crop_border: 0 122 | test_y_channel: true 123 | 124 | # logging settings 125 | logger: 126 | print_freq: 1000 127 | save_checkpoint_freq: !!float 4e3 128 | use_tb_logger: true 129 | wandb: 130 | project: ~ 131 | resume_id: ~ 132 | 133 | # dist training settings 134 | dist_params: 135 | backend: nccl 136 | port: 29500 137 | -------------------------------------------------------------------------------- /Deraining/test_rain100L.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | 6 | 7 | import numpy as np 8 | import os 9 | import argparse 10 | from tqdm import tqdm 11 | 12 | import torch.nn as nn 13 | import torch 14 | import torch.nn.functional as F 15 | import utils 16 | 17 | from natsort import natsorted 18 | from glob import glob 19 | from basicsr.models.archs.HINT_arch import HINT 20 | from skimage import img_as_ubyte 21 | from pdb import set_trace as stx 22 | 23 | parser = 
argparse.ArgumentParser(description='Image Deraining using HINT') 24 | 25 | parser.add_argument('--input_dir', default='./dataset', type=str, help='Directory of validation images') 26 | parser.add_argument('--result_dir', default='./results/Rain100L_HINT/', type=str, help='Directory for results') 27 | parser.add_argument('--weights', default='./models/Rain100L_HINT.pth', type=str, help='Path to weights') 28 | 29 | args = parser.parse_args() 30 | 31 | ####### Load yaml ####### 32 | yaml_file = 'Options/Deraining_HINT_syn_rain100L.yml' 33 | import yaml 34 | 35 | try: 36 | from yaml import CLoader as Loader 37 | except ImportError: 38 | from yaml import Loader 39 | 40 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader) 41 | 42 | s = x['network_g'].pop('type') 43 | ########################## 44 | 45 | model_restoration = HINT(**x['network_g']) 46 | 47 | checkpoint = torch.load(args.weights) 48 | model_restoration.load_state_dict(checkpoint['params']) 49 | print("===>Testing using weights: ",args.weights) 50 | model_restoration.cuda() 51 | model_restoration = nn.DataParallel(model_restoration) 52 | model_restoration.eval() 53 | 54 | 55 | factor = 8 56 | datasets = ['Rain100L'] 57 | 58 | for dataset in datasets: 59 | result_dir = os.path.join(args.result_dir, dataset) 60 | os.makedirs(result_dir, exist_ok=True) 61 | 62 | inp_dir = os.path.join(args.input_dir, dataset,'test','rainy') 63 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg'))) 64 | with torch.no_grad(): 65 | for file_ in tqdm(files): 66 | torch.cuda.ipc_collect() 67 | torch.cuda.empty_cache() 68 | 69 | img = np.float32(utils.load_img(file_))/255. 
70 | img = torch.from_numpy(img).permute(2,0,1) 71 | input_ = img.unsqueeze(0).cuda() 72 | 73 | # Padding in case images are not multiples of 8 74 | h,w = input_.shape[2], input_.shape[3] 75 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor 76 | padh = H-h if h%factor!=0 else 0 77 | padw = W-w if w%factor!=0 else 0 78 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 79 | 80 | restored = model_restoration(input_) 81 | 82 | # Unpad images to original dimensions 83 | restored = restored[:,:,:h,:w] 84 | 85 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 86 | 87 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored)) 88 | -------------------------------------------------------------------------------- /Deraining/utils.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import cv2 8 | import math 9 | 10 | def calculate_psnr(img1, img2, border=0): 11 | # img1 and img2 have range [0, 255] 12 | #img1 = img1.squeeze() 13 | #img2 = img2.squeeze() 14 | if not img1.shape == img2.shape: 15 | raise ValueError('Input images must have the same dimensions.') 16 | h, w = img1.shape[:2] 17 | img1 = img1[border:h-border, border:w-border] 18 | img2 = img2[border:h-border, border:w-border] 19 | 20 | img1 = img1.astype(np.float64) 21 | img2 = img2.astype(np.float64) 22 | mse = np.mean((img1 - img2)**2) 23 | if mse == 0: 24 | return float('inf') 25 | return 20 * math.log10(255.0 / math.sqrt(mse)) 26 | 27 | 28 | # -------------------------------------------- 29 | # SSIM 30 | # -------------------------------------------- 31 | def calculate_ssim(img1, img2, 
border=0): 32 | '''calculate SSIM 33 | the same outputs as MATLAB's 34 | img1, img2: [0, 255] 35 | ''' 36 | #img1 = img1.squeeze() 37 | #img2 = img2.squeeze() 38 | if not img1.shape == img2.shape: 39 | raise ValueError('Input images must have the same dimensions.') 40 | h, w = img1.shape[:2] 41 | img1 = img1[border:h-border, border:w-border] 42 | img2 = img2[border:h-border, border:w-border] 43 | 44 | if img1.ndim == 2: 45 | return ssim(img1, img2) 46 | elif img1.ndim == 3: 47 | if img1.shape[2] == 3: 48 | ssims = [] 49 | for i in range(3): 50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i])) 51 | return np.array(ssims).mean() 52 | elif img1.shape[2] == 1: 53 | return ssim(np.squeeze(img1), np.squeeze(img2)) 54 | else: 55 | raise ValueError('Wrong input image dimensions.') 56 | 57 | 58 | def ssim(img1, img2): 59 | C1 = (0.01 * 255)**2 60 | C2 = (0.03 * 255)**2 61 | 62 | img1 = img1.astype(np.float64) 63 | img2 = img2.astype(np.float64) 64 | kernel = cv2.getGaussianKernel(11, 1.5) 65 | window = np.outer(kernel, kernel.transpose()) 66 | 67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 69 | mu1_sq = mu1**2 70 | mu2_sq = mu2**2 71 | mu1_mu2 = mu1 * mu2 72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 75 | 76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 77 | (sigma1_sq + sigma2_sq + C2)) 78 | return ssim_map.mean() 79 | 80 | def load_img(filepath): 81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 82 | 83 | def save_img(filepath, img): 84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 85 | 86 | def load_gray_img(filepath): 87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2) 88 | 89 | def save_gray_img(filepath, img): 90 | cv2.imwrite(filepath, 
img) 91 | -------------------------------------------------------------------------------- /Desnowing/Options/Desnow_snow100k_HINT.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: Desnow_HINT 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage 13 | dataroot_gt: ./dataset/Snow100K/train2500/Gt 14 | dataroot_lq: ./dataset/Snow100K/train2500/Snow 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ### -------------Progressive training-------------------------- 27 | mini_batch_sizes: [6,5,2,1,1] 28 | iters: [50000,40000,30000,20000,10000] 29 | gt_size: 128 30 | gt_sizes: [128,192,256,320,384] 31 | ### ------------------------------------------------------------ 32 | 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: ValSet 38 | type: Dataset_PairedImage 39 | dataroot_gt: ./dataset/Snow100K/test2000/Gt 40 | dataroot_lq: ./dataset/Snow100K/test2000/Snow 41 | gt_size: 256 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: HINT 48 | inp_channels: 3 49 | out_channels: 3 50 | dim: 48 51 | num_blocks: [4,6,6,8] 52 | num_refinement_blocks: 4 53 | heads: [8,8,8,8] 54 | ffn_expansion_factor: 2.66 55 | bias: False 56 | LayerNorm_type: WithBias 57 | dual_pixel_task: False 58 | 59 | 60 | # path 61 | path: 62 | pretrain_network_g: ~ 63 | strict_load_g: true 64 | resume_state: ~ 65 | 66 | # training settings 67 | train: 68 | total_iter: 300000 69 | warmup_iter: -1 # no warm up 70 | use_grad_clip: true 71 | 72 | # Split 300k iterations into two cycles. 73 | # 1st cycle: fixed 3e-4 LR for 92k iters.
74 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 75 | scheduler: 76 | type: CosineAnnealingRestartCyclicLR 77 | periods: [92000, 208000] 78 | restart_weights: [1,1] 79 | eta_mins: [0.0003,0.000001] 80 | 81 | mixing_augs: 82 | mixup: true 83 | mixup_beta: 1.2 84 | use_identity: true 85 | 86 | optim_g: 87 | type: AdamW 88 | lr: !!float 3e-4 89 | weight_decay: !!float 1e-4 90 | betas: [0.9, 0.999] 91 | 92 | # losses 93 | pixel_opt: 94 | type: L1Loss 95 | loss_weight: 1 96 | reduction: mean 97 | fft_loss_opt: 98 | type: FFTLoss 99 | loss_weight: 0.1 100 | reduction: mean 101 | 102 | # validation settings 103 | val: 104 | window_size: 8 105 | val_freq: !!float 4e3 106 | save_img: false 107 | rgb2bgr: true 108 | use_image: false 109 | max_minibatch: 8 110 | 111 | metrics: 112 | psnr: # metric name, can be arbitrary 113 | type: calculate_psnr 114 | crop_border: 0 115 | test_y_channel: false 116 | 117 | # logging settings 118 | logger: 119 | print_freq: 1000 120 | save_checkpoint_freq: !!float 4e3 121 | use_tb_logger: true 122 | wandb: 123 | project: ~ 124 | resume_id: ~ 125 | 126 | # dist training settings 127 | dist_params: 128 | backend: nccl 129 | port: 29500 130 | -------------------------------------------------------------------------------- /Desnowing/evaluate_Snow100k.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import os 6 | import numpy as np 7 | from glob import glob 8 | from natsort import natsorted 9 | from skimage import io 10 | import cv2 11 | import argparse 12 | from skimage.metrics import structural_similarity 13 | from tqdm import tqdm 14 | import concurrent.futures 15 | import utils 16 | 17 | def proc(filename): 18 | tar,prd = filename 19 | prd_name = 
prd.split('/')[-1]+'.png' 20 | t_name = prd.split('/')[-1].split('.')[0]+'.jpg' 21 | tar_name = './dataset/Snow100K/test2000/Gt/' + t_name 22 | tar_img = utils.load_img(tar_name) 23 | prd_img = utils.load_img(prd) 24 | 25 | PSNR = utils.calculate_psnr(tar_img, prd_img) 26 | SSIM = utils.calculate_ssim(tar_img, prd_img) 27 | return PSNR,SSIM 28 | 29 | parser = argparse.ArgumentParser(description='Desnowing using HINT') 30 | 31 | args = parser.parse_args() 32 | 33 | 34 | datasets = ['test2000'] 35 | 36 | for dataset in datasets: 37 | 38 | gt_path = os.path.join('./dataset/Snow100K/test2000/Gt') 39 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg'))) 40 | assert len(gt_list) != 0, "Target files not found" 41 | 42 | 43 | file_path = os.path.join('results', 'HINT', dataset) 44 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg'))) 45 | assert len(path_list) != 0, "Predicted files not found" 46 | 47 | psnr, ssim = [], [] 48 | img_files =[(i, j) for i,j in zip(gt_list,path_list)] 49 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: 50 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)): 51 | psnr.append(PSNR_SSIM[0]) 52 | ssim.append(PSNR_SSIM[1]) 53 | 54 | avg_psnr = sum(psnr)/len(psnr) 55 | avg_ssim = sum(ssim)/len(ssim) 56 | 57 | # print('For {:s} dataset PSNR: {:f}\n'.format(dataset, avg_psnr)) 58 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim)) 59 | -------------------------------------------------------------------------------- /Desnowing/test_snow100k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import torch.nn as nn 7 | import torch 8 | import torch.nn.functional as F 9 | import utils 10 | 11 | from natsort import natsorted 12 | from glob import glob 13 | 
from basicsr.models.archs.HINT_arch import HINT 14 | from skimage import img_as_ubyte 15 | from pdb import set_trace as stx 16 | 17 | parser = argparse.ArgumentParser(description='Image Desnowing using HINT') 18 | 19 | parser.add_argument('--input_dir', default='./dataset/Snow100K/', type=str, help='Directory of validation images') 20 | parser.add_argument('--result_dir', default='./results/HINT', type=str, help='Directory for results') 21 | parser.add_argument('--weights', default='./models/snow100k.pth', type=str, help='Path to weights') 22 | 23 | args = parser.parse_args() 24 | 25 | def splitimage(imgtensor, crop_size=128, overlap_size=64): 26 | _, C, H, W = imgtensor.shape 27 | hstarts = [x for x in range(0, H, crop_size - overlap_size)] 28 | while hstarts and hstarts[-1] + crop_size >= H: 29 | hstarts.pop() 30 | hstarts.append(H - crop_size) 31 | wstarts = [x for x in range(0, W, crop_size - overlap_size)] 32 | while wstarts and wstarts[-1] + crop_size >= W: 33 | wstarts.pop() 34 | wstarts.append(W - crop_size) 35 | starts = [] 36 | split_data = [] 37 | for hs in hstarts: 38 | for ws in wstarts: 39 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size] 40 | starts.append((hs, ws)) 41 | split_data.append(cimgdata) 42 | return split_data, starts 43 | 44 | def get_scoremap(H, W, C, B=1, is_mean=True): 45 | center_h = H / 2 46 | center_w = W / 2 47 | 48 | score = torch.ones((B, C, H, W)) 49 | if not is_mean: 50 | for h in range(H): 51 | for w in range(W): 52 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6)) 53 | return score 54 | 55 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)): 56 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3] 57 | tot_score = torch.zeros((B, C, H, W)) 58 | merge_img = torch.zeros((B, C, H, W)) 59 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True) 60 | for simg, cstart in zip(split_data, starts): 61 | hs, ws = 
cstart 62 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg 63 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap 64 | merge_img = merge_img / tot_score 65 | return merge_img 66 | 67 | ####### Load yaml ####### 68 | yaml_file = 'Options/Desnow_snow100k_HINT.yml' 69 | import yaml 70 | 71 | try: 72 | from yaml import CLoader as Loader 73 | except ImportError: 74 | from yaml import Loader 75 | 76 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader) 77 | 78 | s = x['network_g'].pop('type') 79 | ########################## 80 | 81 | model_restoration = HINT(**x['network_g']) 82 | 83 | checkpoint = torch.load(args.weights) 84 | model_restoration.load_state_dict(checkpoint['params']) 85 | print("===>Testing using weights: ",args.weights) 86 | model_restoration.cuda() 87 | model_restoration = nn.DataParallel(model_restoration) 88 | model_restoration.eval() 89 | 90 | 91 | factor = 8 92 | datasets = ['test2000'] 93 | 94 | for dataset in datasets: 95 | result_dir = os.path.join(args.result_dir, dataset) 96 | os.makedirs(result_dir, exist_ok=True) 97 | 98 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain') 99 | inp_dir = os.path.join(args.input_dir, dataset, 'Snow/') 100 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg'))) 101 | with torch.no_grad(): 102 | for file_ in tqdm(files): 103 | torch.cuda.ipc_collect() 104 | torch.cuda.empty_cache() 105 | 106 | img = np.float32(utils.load_img(file_))/255. 
107 | img = torch.from_numpy(img).permute(2,0,1) 108 | input_ = img.unsqueeze(0).cuda() 109 | 110 | B, C, H, W = input_.shape 111 | corp_size_arg = 256 112 | overlap_size_arg = 128 113 | split_data, starts = splitimage(input_, crop_size=corp_size_arg, overlap_size=overlap_size_arg) 114 | for i, data in enumerate(split_data): 115 | split_data[i] = model_restoration(data).cpu() 116 | restored = mergeimage(split_data, starts, crop_size=corp_size_arg, resolution=(B, C, H, W)) 117 | 118 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 119 | 120 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored)) 121 | -------------------------------------------------------------------------------- /Desnowing/utils.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import cv2 8 | import math 9 | 10 | def calculate_psnr(img1, img2, border=0): 11 | # img1 and img2 have range [0, 255] 12 | #img1 = img1.squeeze() 13 | #img2 = img2.squeeze() 14 | if not img1.shape == img2.shape: 15 | raise ValueError('Input images must have the same dimensions.') 16 | h, w = img1.shape[:2] 17 | img1 = img1[border:h-border, border:w-border] 18 | img2 = img2[border:h-border, border:w-border] 19 | 20 | img1 = img1.astype(np.float64) 21 | img2 = img2.astype(np.float64) 22 | mse = np.mean((img1 - img2)**2) 23 | if mse == 0: 24 | return float('inf') 25 | return 20 * math.log10(255.0 / math.sqrt(mse)) 26 | 27 | 28 | # -------------------------------------------- 29 | # SSIM 30 | # -------------------------------------------- 31 | def calculate_ssim(img1, img2, border=0): 32 | '''calculate SSIM 33 | the 
same outputs as MATLAB's 34 | img1, img2: [0, 255] 35 | ''' 36 | #img1 = img1.squeeze() 37 | #img2 = img2.squeeze() 38 | if not img1.shape == img2.shape: 39 | raise ValueError('Input images must have the same dimensions.') 40 | h, w = img1.shape[:2] 41 | img1 = img1[border:h-border, border:w-border] 42 | img2 = img2[border:h-border, border:w-border] 43 | 44 | if img1.ndim == 2: 45 | return ssim(img1, img2) 46 | elif img1.ndim == 3: 47 | if img1.shape[2] == 3: 48 | ssims = [] 49 | for i in range(3): 50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i])) 51 | return np.array(ssims).mean() 52 | elif img1.shape[2] == 1: 53 | return ssim(np.squeeze(img1), np.squeeze(img2)) 54 | else: 55 | raise ValueError('Wrong input image dimensions.') 56 | 57 | 58 | def ssim(img1, img2): 59 | C1 = (0.01 * 255)**2 60 | C2 = (0.03 * 255)**2 61 | 62 | img1 = img1.astype(np.float64) 63 | img2 = img2.astype(np.float64) 64 | kernel = cv2.getGaussianKernel(11, 1.5) 65 | window = np.outer(kernel, kernel.transpose()) 66 | 67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 69 | mu1_sq = mu1**2 70 | mu2_sq = mu2**2 71 | mu1_mu2 = mu1 * mu2 72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 75 | 76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 77 | (sigma1_sq + sigma2_sq + C2)) 78 | return ssim_map.mean() 79 | 80 | def load_img(filepath): 81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 82 | 83 | def save_img(filepath, img): 84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 85 | 86 | def load_gray_img(filepath): 87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2) 88 | 89 | def save_gray_img(filepath, img): 90 | cv2.imwrite(filepath, img) 91 | 
-------------------------------------------------------------------------------- /Enhancement/Options/HINT_LOL_v2_real.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: Enhancement_HINT 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage 13 | dataroot_gt: ./dataset/LOLv2/Real_captured/Train/Normal 14 | dataroot_lq: ./dataset/LOLv2/Real_captured/Train/Low 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ## -------------Progressive training-------------------------- 27 | mini_batch_sizes: [6] # Batch size per gpu 28 | iters: [150000] 29 | gt_size: 384 # Max patch size for progressive training 30 | gt_sizes: [128] # Patch sizes for progressive training. 31 | ### ------------------------------------------------------------ 32 | 33 | 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: ValSet 39 | type: Dataset_PairedImage 40 | dataroot_gt: ./dataset/LOLv2/Real_captured/Test/Normal 41 | dataroot_lq: ./dataset/LOLv2/Real_captured/Test/Low 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: HINT 48 | inp_channels: 3 49 | out_channels: 3 50 | dim: 48 51 | num_blocks: [4,6,6,8] 52 | num_refinement_blocks: 4 53 | heads: [8,8,8,8] 54 | ffn_expansion_factor: 2.66 55 | bias: False 56 | LayerNorm_type: WithBias 57 | dual_pixel_task: False 58 | 59 | # path 60 | path: 61 | pretrain_network_g: ~ 62 | strict_load_g: true 63 | resume_state: ~ 64 | 65 | # training settings 66 | train: 67 | total_iter: 150000 68 | warmup_iter: -1 # no warm up 69 | use_grad_clip: true 70 | 71 | # Split 150k iterations into two cycles.
72 | # 1st cycle: 46k iters with eta_min 3e-4 (optim_g lr is 2e-4). 73 | # 2nd cycle: cosine annealing down to 1e-6 for 104k iters. 74 | scheduler: 75 | type: CosineAnnealingRestartCyclicLR 76 | periods: [46000, 104000] 77 | restart_weights: [1,1] 78 | eta_mins: [0.0003,0.000001] 79 | 80 | mixing_augs: 81 | mixup: true 82 | mixup_beta: 1.2 83 | use_identity: true 84 | 85 | optim_g: 86 | type: Adam 87 | lr: !!float 2e-4 88 | # weight_decay: !!float 1e-4 89 | betas: [0.9, 0.999] 90 | 91 | pixel_opt: 92 | type: L1Loss 93 | loss_weight: 1 94 | reduction: mean 95 | 96 | fft_loss_opt: 97 | type: FFTLoss 98 | loss_weight: 0.1 99 | reduction: mean 100 | 101 | 102 | # validation settings 103 | val: 104 | window_size: 4 105 | val_freq: !!float 1e3 106 | save_img: false 107 | rgb2bgr: true 108 | use_image: false 109 | max_minibatch: 8 110 | 111 | metrics: 112 | psnr: # metric name, can be arbitrary 113 | type: calculate_psnr 114 | crop_border: 0 115 | test_y_channel: false 116 | 117 | # logging settings 118 | logger: 119 | print_freq: 500 120 | save_checkpoint_freq: !!float 1e3 121 | use_tb_logger: true 122 | wandb: 123 | project: ~ 124 | resume_id: ~ 125 | 126 | # dist training settings 127 | dist_params: 128 | backend: nccl 129 | port: 29500 130 | -------------------------------------------------------------------------------- /Enhancement/Options/HINT_LOL_v2_synthetic.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: Enhancement_HINT 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage 13 | dataroot_gt: ./dataset/LOLv2/Synthetic/Train/Normal 14 | dataroot_lq: ./dataset/LOLv2/Synthetic/Train/Low 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 |
num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ### -------------Progressive training-------------------------- 27 | mini_batch_sizes: [6,5,2,1,1] 28 | iters: [50000,40000,30000,20000,10000] 29 | gt_size: 128 30 | gt_sizes: [128,192,256,320,384] 31 | ### ------------------------------------------------------------ 32 | 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: ValSet 38 | type: Dataset_PairedImage 39 | dataroot_gt: ./dataset/LOLv2/Synthetic/Test/Normal 40 | dataroot_lq: ./dataset/LOLv2/Synthetic/Test/Low 41 | io_backend: 42 | type: disk 43 | 44 | # network structures 45 | network_g: 46 | type: HINT 47 | inp_channels: 3 48 | out_channels: 3 49 | dim: 48 50 | num_blocks: [4,6,6,8] 51 | num_refinement_blocks: 4 52 | heads: [8,8,8,8] 53 | ffn_expansion_factor: 2.66 54 | bias: False 55 | LayerNorm_type: WithBias 56 | dual_pixel_task: False 57 | 58 | 59 | # path 60 | path: 61 | pretrain_network_g: ~ 62 | strict_load_g: true 63 | resume_state: ~ 64 | 65 | # training settings 66 | train: 67 | total_iter: 150000 68 | warmup_iter: -1 # no warm up 69 | use_grad_clip: true 70 | 71 | # Split 150k iterations into two cycles. 72 | # 1st cycle: 46k iters with eta_min 3e-4 (optim_g lr is 2e-4). 73 | # 2nd cycle: cosine annealing down to 1e-6 for 104k iters.
74 | scheduler: 75 | type: CosineAnnealingRestartCyclicLR 76 | periods: [46000, 104000] 77 | restart_weights: [1,1] 78 | eta_mins: [0.0003,0.000001] 79 | 80 | mixing_augs: 81 | mixup: true 82 | mixup_beta: 1.2 83 | use_identity: true 84 | 85 | optim_g: 86 | type: Adam 87 | lr: !!float 2e-4 88 | # weight_decay: !!float 1e-4 89 | betas: [0.9, 0.999] 90 | 91 | pixel_opt: 92 | type: L1Loss 93 | loss_weight: 1 94 | reduction: mean 95 | 96 | fft_loss_opt: 97 | type: FFTLoss 98 | loss_weight: 0.1 99 | reduction: mean 100 | 101 | # validation settings 102 | val: 103 | window_size: 4 104 | val_freq: !!float 1e3 105 | save_img: false 106 | rgb2bgr: true 107 | use_image: false 108 | max_minibatch: 8 109 | 110 | metrics: 111 | psnr: # metric name, can be arbitrary 112 | type: calculate_psnr 113 | crop_border: 0 114 | test_y_channel: false 115 | 116 | # logging settings 117 | logger: 118 | print_freq: 500 119 | save_checkpoint_freq: !!float 1e3 120 | use_tb_logger: true 121 | wandb: 122 | project: ~ 123 | resume_id: ~ 124 | 125 | # dist training settings 126 | dist_params: 127 | backend: nccl 128 | port: 29500 129 | -------------------------------------------------------------------------------- /Enhancement/utils.py: -------------------------------------------------------------------------------- 1 | # Retinexformer: One-stage Retinex-based Transformer for Low-light Image Enhancement 2 | # Yuanhao Cai, Hao Bian, Jing Lin, Haoqian Wang, Radu Timofte, Yulun Zhang 3 | # International Conference on Computer Vision (ICCV), 2023 4 | # https://arxiv.org/abs/2303.06705 5 | # https://github.com/caiyuanhao1998/Retinexformer 6 | 7 | import numpy as np 8 | import os 9 | import cv2 10 | import math 11 | from pdb import set_trace as stx 12 | 13 | 14 | def calculate_psnr(img1, img2, border=0): 15 | # img1 and img2 have range [0, 255] 16 | #img1 = img1.squeeze() 17 | #img2 = img2.squeeze() 18 | if not img1.shape == img2.shape: 19 | raise ValueError('Input images must have the same 
dimensions.') 20 | h, w = img1.shape[:2] 21 | img1 = img1[border:h - border, border:w - border] 22 | img2 = img2[border:h - border, border:w - border] 23 | 24 | img1 = img1.astype(np.float64) 25 | img2 = img2.astype(np.float64) 26 | mse = np.mean((img1 - img2)**2) 27 | if mse == 0: 28 | return float('inf') 29 | return 20 * math.log10(255.0 / math.sqrt(mse)) 30 | 31 | 32 | def PSNR(img1, img2): 33 | mse_ = np.mean((img1 - img2) ** 2) 34 | if mse_ == 0: 35 | return 100 36 | return 10 * math.log10(1 / mse_) 37 | 38 | 39 | # -------------------------------------------- 40 | # SSIM 41 | # -------------------------------------------- 42 | def calculate_ssim(img1, img2, border=0): 43 | '''calculate SSIM 44 | the same outputs as MATLAB's 45 | img1, img2: [0, 255] 46 | ''' 47 | #img1 = img1.squeeze() 48 | #img2 = img2.squeeze() 49 | if not img1.shape == img2.shape: 50 | raise ValueError('Input images must have the same dimensions.') 51 | h, w = img1.shape[:2] 52 | img1 = img1[border:h - border, border:w - border] 53 | img2 = img2[border:h - border, border:w - border] 54 | 55 | if img1.ndim == 2: 56 | return ssim(img1, img2) 57 | elif img1.ndim == 3: 58 | if img1.shape[2] == 3: 59 | ssims = [] 60 | for i in range(3): 61 | ssims.append(ssim(img1[:, :, i], img2[:, :, i])) 62 | return np.array(ssims).mean() 63 | elif img1.shape[2] == 1: 64 | return ssim(np.squeeze(img1), np.squeeze(img2)) 65 | else: 66 | raise ValueError('Wrong input image dimensions.') 67 | 68 | 69 | def ssim(img1, img2): 70 | C1 = (0.01 * 255)**2 71 | C2 = (0.03 * 255)**2 72 | 73 | img1 = img1.astype(np.float64) 74 | img2 = img2.astype(np.float64) 75 | kernel = cv2.getGaussianKernel(11, 1.5) 76 | window = np.outer(kernel, kernel.transpose()) 77 | 78 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 79 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 80 | mu1_sq = mu1**2 81 | mu2_sq = mu2**2 82 | mu1_mu2 = mu1 * mu2 83 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 84 | 
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 85 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 86 | 87 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 88 | (sigma1_sq + sigma2_sq + C2)) 89 | return ssim_map.mean() 90 | 91 | 92 | def load_img(filepath): 93 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 94 | 95 | 96 | def save_img(filepath, img): 97 | cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 98 | 99 | 100 | def load_gray_img(filepath): 101 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2) 102 | 103 | 104 | def save_gray_img(filepath, img): 105 | cv2.imwrite(filepath, img) 106 | 107 | 108 | def visualization(feature, save_path, type='max', colormap=cv2.COLORMAP_JET): 109 | ''' 110 | :param feature: [C,H,W] 111 | :param save_path: saving path 112 | :param type: 'mean' or 'max' 113 | :param colormap: the type of the pseudocolor map 114 | ''' 115 | feature = feature.cpu().numpy() 116 | if type == 'mean': 117 | feature = np.mean(feature, axis=0) 118 | else: 119 | feature = np.max(feature, axis=0) 120 | normed_feat = (feature - feature.min()) / (feature.max() - feature.min()) 121 | normed_feat = (normed_feat * 255).astype('uint8') 122 | color_feat = cv2.applyColorMap(normed_feat, colormap) 123 | # stx() 124 | cv2.imwrite(save_path, color_feat) 125 | 126 | def my_summary(test_model, H = 256, W = 256, C = 3, N = 1): 127 | model = test_model.cuda() 128 | print(model) 129 | inputs = torch.randn((N, C, H, W)).cuda() 130 | flops = FlopCountAnalysis(model,inputs) 131 | n_param = sum([p.nelement() for p in model.parameters()]) 132 | print(f'GMac:{flops.total()/(1024*1024*1024)}') 133 | print(f'Params:{n_param}') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Devil is in the Uniformity: Exploring Diverse 
Learners within Transformer for Image Restoration 2 | 3 | [![Hugging Face Demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demos-blue)](https://huggingface.co/spaces/yssszzzzzzzzy/HINT) 4 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=joshyZhou/HINT) 5 | [![GitHub Stars](https://img.shields.io/github/stars/joshyZhou/HINT?style=social)](https://github.com/joshyZhou/HINT)
6 | 7 | [Shihao Zhou](https://joshyzhou.github.io/), [Dayu Li](https://github.com/nkldy22), [Jinshan Pan](https://jspan.github.io/), [Juncheng Zhou](https://github.com/ZhouJunCheng99), [Jinglei Shi](https://jingleishi.github.io/) and [Jufeng Yang](https://cv.nankai.edu.cn/) 8 | 9 | #### News 10 | - **Jul 19, 2025:** [Hugging Face Demo](https://huggingface.co/spaces/yssszzzzzzzzy/HINT) is now available, thanks to the contribution of [Sen](https://github.com/yss730). 11 | - **Jun 26, 2025:** HINT has been accepted to ICCV 2025 :tada: 12 |
13 | 14 | ## Training 15 | ### Derain 16 | To train HINT on rain100L, you can run: 17 | ```sh 18 | ./train.sh Deraining/Options/Deraining_HINT_syn_rain100L.yml 19 | ``` 20 | ### Dehaze 21 | To train HINT on SOTS, you can run: 22 | ```sh 23 | ./train.sh Dehaze/Options/RealDehazing_HINT.yml 24 | ``` 25 | ### Denoising 26 | To train HINT on WB, you can run: 27 | ```sh 28 | ./train.sh Denoising/Options/GaussianColorDenoising_HINT.yml 29 | ``` 30 | ### Desnowing 31 | To train HINT on snow100k, you can run: 32 | ```sh 33 | ./train.sh Desnowing/Options/Desnow_snow100k_HINT.yml 34 | ``` 35 | ### Enhancement 36 | To train HINT on LOL_v2_real, you can run: 37 | ```sh 38 | ./train.sh Enhancement/Options/HINT_LOL_v2_real.yml 39 | ``` 40 | 41 | To train HINT on LOL_v2_synthetic, you can run: 42 | ```sh 43 | ./train.sh Enhancement/Options/HINT_LOL_v2_synthetic.yml 44 | ``` 45 | 46 | ## Evaluation 47 | To evaluate HINT, refer to the commands in `test.sh`. 48 | 49 | To evaluate on a particular dataset, uncomment the corresponding line in `test.sh`. 50 | 51 | 52 | ## Results 53 | Experiments are performed for different image processing tasks. 54 | Here is a summary table containing hyperlinks for easy navigation: 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 |
| Benchmark | Pretrained model | Visual Results |
|---|---|---|
| Rain100L | (code:ngn8) | (code:bdpg) |
| SOTS | (code:64j8) | (code:dypf) |
| Snow100K | (code:q2cm) | (code:s7xx) |
| LOL-v2-Real | (code:6cux) | (code:5bxm) |
| LOL-v2-Syn | (code:7fi5) | (code:y9uq) |
| WB | (code:7fi5) | (code:ss8c) |
93 | 94 | 95 | ## Citation 96 | If you find this project useful, please consider citing: 97 | 98 | @inproceedings{zhou_ICCV25_HINT, 99 | title={Devil is in the Uniformity: Exploring Diverse Learners within Transformer for Image Restoration}, 100 | author={Zhou, Shihao and Li, Dayu and Pan, Jinshan and Zhou, Juncheng and Shi, Jinglei and Yang, Jufeng}, 101 | booktitle={ICCV}, 102 | year={2025} 103 | } 104 | 105 | ## Acknowledgement 106 | 107 | This code borrows heavily from [Restormer](https://github.com/swz30/Restormer). -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 1.2.0 2 | -------------------------------------------------------------------------------- /basicsr/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/.DS_Store -------------------------------------------------------------------------------- /basicsr/__pycache__/version.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/__pycache__/version.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/__pycache__/version.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/__pycache__/version.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/data/SDSD_image_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import torch 3 | import torch.utils.data as data 4 | import basicsr.data.util as util 5 
| import torch.nn.functional as F 6 | import random 7 | import cv2 8 | import numpy as np 9 | import glob 10 | import os 11 | import functools 12 | 13 | 14 | class Dataset_SDSDImage(data.Dataset): 15 | def __init__(self, opt): 16 | super(Dataset_SDSDImage, self).__init__() 17 | self.opt = opt 18 | self.cache_data = opt['cache_data'] 19 | self.half_N_frames = opt['N_frames'] // 2 20 | self.GT_root, self.LQ_root = opt['dataroot_gt'], opt['dataroot_lq'] 21 | self.io_backend_opt = opt['io_backend'] 22 | self.data_type = self.io_backend_opt['type'] 23 | self.data_info = {'path_LQ': [], 'path_GT': [], 24 | 'folder': [], 'idx': [], 'border': []} 25 | if self.data_type == 'lmdb': 26 | raise ValueError('No need to use LMDB during validation/test.') 27 | # Generate data info and cache data 28 | self.imgs_LQ, self.imgs_GT = {}, {} 29 | 30 | if opt['testing_dir'] is not None: 31 | testing_dir = opt['testing_dir'] 32 | testing_dir = testing_dir.split(',') 33 | else: 34 | testing_dir = [] 35 | print('testing_dir', testing_dir) 36 | 37 | subfolders_LQ = util.glob_file_list(self.LQ_root) 38 | subfolders_GT = util.glob_file_list(self.GT_root) 39 | 40 | for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT): 41 | # for frames in each video: 42 | subfolder_name = osp.basename(subfolder_GT) 43 | 44 | if self.opt['phase'] == 'train': 45 | if (subfolder_name in testing_dir): 46 | continue 47 | 48 | if (subfolder_name.split('_2')[0] in testing_dir): 49 | continue 50 | else: # val test 51 | if not(subfolder_name in testing_dir) and not(subfolder_name.split('_2')[0] in testing_dir): 52 | continue 53 | 54 | img_paths_LQ = util.glob_file_list(subfolder_LQ) 55 | img_paths_GT = util.glob_file_list(subfolder_GT) 56 | 57 | max_idx = len(img_paths_LQ) 58 | assert max_idx == len( 59 | img_paths_GT), 'Different number of images in LQ and GT folders' 60 | self.data_info['path_LQ'].extend( 61 | img_paths_LQ) # list of path str of images 62 | self.data_info['path_GT'].extend(img_paths_GT) 
63 | 64 | self.data_info['folder'].extend([subfolder_name] * max_idx) 65 | for i in range(max_idx): 66 | self.data_info['idx'].append('{}/{}'.format(i, max_idx)) 67 | 68 | border_l = [0] * max_idx 69 | for i in range(self.half_N_frames): 70 | border_l[i] = 1 71 | border_l[max_idx - i - 1] = 1 72 | self.data_info['border'].extend(border_l) 73 | 74 | if self.cache_data: 75 | self.imgs_LQ[subfolder_name] = img_paths_LQ 76 | self.imgs_GT[subfolder_name] = img_paths_GT 77 | 78 | def __getitem__(self, index): 79 | folder = self.data_info['folder'][index] 80 | idx, max_idx = self.data_info['idx'][index].split('/') 81 | idx, max_idx = int(idx), int(max_idx) 82 | border = self.data_info['border'][index] 83 | 84 | img_LQ_path = self.imgs_LQ[folder][idx:idx + 1] 85 | img_GT_path = self.imgs_GT[folder][idx:idx + 1] 86 | 87 | img_LQ = util.read_img_seq2(img_LQ_path, self.opt['train_size']) 88 | img_LQ = img_LQ[0] 89 | img_GT = util.read_img_seq2(img_GT_path, self.opt['train_size']) 90 | img_GT = img_GT[0] 91 | 92 | if self.opt['phase'] == 'train': 93 | 94 | # LQ_size = self.opt['LQ_size'] 95 | # GT_size = self.opt['GT_size'] 96 | 97 | # _, H, W = img_GT.shape # real img size 98 | 99 | # rnd_h = random.randint(0, max(0, H - GT_size)) 100 | # rnd_w = random.randint(0, max(0, W - GT_size)) 101 | # img_LQ = img_LQ[:, rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size] 102 | # img_GT = img_GT[:, rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size] 103 | 104 | img_LQ_l = [img_LQ] 105 | img_LQ_l.append(img_GT) 106 | rlt = util.augment_torch( 107 | img_LQ_l, self.opt['use_flip'], self.opt['use_rot']) 108 | img_LQ = rlt[0] 109 | img_GT = rlt[1] 110 | 111 | # img_nf = img_LQ.clone().permute(1, 2, 0).numpy() * 255.0 112 | # img_nf = cv2.blur(img_nf, (5, 5)) 113 | # img_nf = img_nf * 1.0 / 255.0 114 | # img_nf = torch.Tensor(img_nf).float().permute(2, 0, 1) 115 | 116 | return { 117 | 'lq': img_LQ, 118 | 'gt': img_GT, 119 | # 'nf': img_nf, 120 | 'folder': folder, 121 | 'idx': 
self.data_info['idx'][index], 122 | 'border': border, 123 | 'lq_path': img_LQ_path[0], 124 | 'gt_path': img_GT_path[0] 125 | } 126 | 127 | def __len__(self): 128 | return len(self.data_info['path_LQ']) 129 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from functools import partial 7 | from os import path as osp 8 | 9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 10 | from basicsr.utils import get_root_logger, scandir 11 | from basicsr.utils.dist_util import get_dist_info 12 | 13 | __all__ = ['create_dataset', 'create_dataloader'] 14 | 15 | # automatically scan and import dataset modules 16 | # scan all the files under the data folder with '_dataset' in file names 17 | data_folder = osp.dirname(osp.abspath(__file__)) 18 | dataset_filenames = [ 19 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) 20 | if v.endswith('_dataset.py') 21 | ] 22 | # import all the dataset modules 23 | _dataset_modules = [ 24 | importlib.import_module(f'basicsr.data.{file_name}') 25 | for file_name in dataset_filenames 26 | ] 27 | 28 | 29 | def create_dataset(dataset_opt): 30 | """Create dataset. 31 | 32 | Args: 33 | dataset_opt (dict): Configuration for dataset. It constains: 34 | name (str): Dataset name. 35 | type (str): Dataset type. 
36 | """ 37 | dataset_type = dataset_opt['type'] 38 | 39 | # dynamic instantiation 40 | for module in _dataset_modules: 41 | dataset_cls = getattr(module, dataset_type, None) 42 | if dataset_cls is not None: 43 | break 44 | if dataset_cls is None: 45 | raise ValueError(f'Dataset {dataset_type} is not found.') 46 | 47 | dataset = dataset_cls(dataset_opt) 48 | 49 | logger = get_root_logger() 50 | logger.info( 51 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} ' 52 | 'is created.') 53 | return dataset 54 | 55 | 56 | def create_dataloader(dataset, 57 | dataset_opt, 58 | num_gpu=1, 59 | dist=False, 60 | sampler=None, 61 | seed=None): 62 | """Create dataloader. 63 | 64 | Args: 65 | dataset (torch.utils.data.Dataset): Dataset. 66 | dataset_opt (dict): Dataset options. It contains the following keys: 67 | phase (str): 'train' or 'val'. 68 | num_worker_per_gpu (int): Number of workers for each GPU. 69 | batch_size_per_gpu (int): Training batch size for each GPU. 70 | num_gpu (int): Number of GPUs. Used only in the train phase. 71 | Default: 1. 72 | dist (bool): Whether in distributed training. Used only in the train 73 | phase. Default: False. 74 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 75 | seed (int | None): Seed. 
Default: None 76 | """ 77 | phase = dataset_opt['phase'] 78 | rank, _ = get_dist_info() 79 | if phase == 'train': 80 | if dist: # distributed training 81 | batch_size = dataset_opt['batch_size_per_gpu'] 82 | num_workers = dataset_opt['num_worker_per_gpu'] 83 | else: # non-distributed training 84 | multiplier = 1 if num_gpu == 0 else num_gpu 85 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 86 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 87 | dataloader_args = dict( 88 | dataset=dataset, 89 | batch_size=batch_size, 90 | shuffle=False, 91 | num_workers=num_workers, 92 | sampler=sampler, 93 | drop_last=True) 94 | if sampler is None: 95 | dataloader_args['shuffle'] = True 96 | dataloader_args['worker_init_fn'] = partial( 97 | worker_init_fn, num_workers=num_workers, rank=rank, 98 | seed=seed) if seed is not None else None 99 | elif phase in ['val', 'test']: # validation 100 | dataloader_args = dict( 101 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 102 | else: 103 | raise ValueError(f'Wrong dataset phase: {phase}. 
' 104 | "Supported ones are 'train', 'val' and 'test'.") 105 | 106 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 107 | 108 | prefetch_mode = dataset_opt.get('prefetch_mode') 109 | if prefetch_mode == 'cpu': # CPUPrefetcher 110 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 111 | logger = get_root_logger() 112 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' 113 | f'num_prefetch_queue = {num_prefetch_queue}') 114 | return PrefetchDataLoader( 115 | num_prefetch_queue=num_prefetch_queue, **dataloader_args) 116 | else: 117 | # prefetch_mode=None: Normal dataloader 118 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 119 | return torch.utils.data.DataLoader(**dataloader_args) 120 | 121 | 122 | def worker_init_fn(worker_id, num_workers, rank, seed): 123 | # Set the worker seed to num_workers * rank + worker_id + seed 124 | worker_seed = num_workers * rank + worker_id + seed 125 | np.random.seed(worker_seed) 126 | random.seed(worker_seed) 127 | -------------------------------------------------------------------------------- /basicsr/data/__pycache__/SDSD_image_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/SDSD_image_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/data_sampler.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/data_sampler.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/data_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/data_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/reds_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/reds_dataset.cpython-37.pyc 
-------------------------------------------------------------------------------- /basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil( 27 | len(self.dataset) * ratio / self.num_replicas) 28 | self.total_size = self.num_samples * self.num_replicas 29 | 30 | def __iter__(self): 31 | # deterministically shuffle based on epoch 32 | g = torch.Generator() 33 | g.manual_seed(self.epoch) 34 | indices = torch.randperm(self.total_size, generator=g).tolist() 35 | 36 | dataset_size = len(self.dataset) 37 | indices = [v % dataset_size for v in indices] 38 | 39 | # subsample 40 | indices = indices[self.rank:self.total_size:self.num_replicas] 41 | assert len(indices) == self.num_samples 42 | 43 | return iter(indices) 44 | 45 | def __len__(self): 46 | return self.num_samples 47 | 48 | def set_epoch(self, epoch): 49 | self.epoch = epoch 50 | -------------------------------------------------------------------------------- /basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from 
torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.transforms import augment 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor 7 | 8 | 9 | class FFHQDataset(data.Dataset): 10 | """FFHQ dataset for StyleGAN. 11 | 12 | Args: 13 | opt (dict): Config for train datasets. It contains the following keys: 14 | dataroot_gt (str): Data root path for gt. 15 | io_backend (dict): IO backend type and other kwarg. 16 | mean (list | tuple): Image mean. 17 | std (list | tuple): Image std. 18 | use_hflip (bool): Whether to horizontally flip. 19 | 20 | """ 21 | 22 | def __init__(self, opt): 23 | super(FFHQDataset, self).__init__() 24 | self.opt = opt 25 | # file client (io backend) 26 | self.file_client = None 27 | self.io_backend_opt = opt['io_backend'] 28 | 29 | self.gt_folder = opt['dataroot_gt'] 30 | self.mean = opt['mean'] 31 | self.std = opt['std'] 32 | 33 | if self.io_backend_opt['type'] == 'lmdb': 34 | self.io_backend_opt['db_paths'] = self.gt_folder 35 | if not self.gt_folder.endswith('.lmdb'): 36 | raise ValueError("'dataroot_gt' should end with '.lmdb', " 37 | f'but received {self.gt_folder}') 38 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 39 | self.paths = [line.split('.')[0] for line in fin] 40 | else: 41 | # FFHQ has 70000 images in total 42 | self.paths = [ 43 | osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000) 44 | ] 45 | 46 | def __getitem__(self, index): 47 | if self.file_client is None: 48 | self.file_client = FileClient( 49 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | img_bytes = self.file_client.get(gt_path) 54 | img_gt = imfrombytes(img_bytes, float32=True) 55 | 56 | # random horizontal flip 57 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 58 | # BGR to RGB, HWC to CHW, numpy to tensor 59 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 60 | # normalize 61 | normalize(img_gt, 
self.mean, self.std, inplace=True) 62 | return {'gt': img_gt, 'gt_path': gt_path} 63 | 64 | def __len__(self): 65 | return len(self.paths) 66 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDS4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 011 100 (720,1280,3) 3 | 015 100 (720,1280,3) 4 | 020 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDS_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 001 100 (720,1280,3) 3 | 002 100 (720,1280,3) 4 | 003 100 (720,1280,3) 5 | 004 100 (720,1280,3) 6 | 005 100 (720,1280,3) 7 | 006 100 (720,1280,3) 8 | 007 100 (720,1280,3) 9 | 008 100 (720,1280,3) 10 | 009 100 (720,1280,3) 11 | 010 100 (720,1280,3) 12 | 011 100 (720,1280,3) 13 | 012 100 (720,1280,3) 14 | 013 100 (720,1280,3) 15 | 014 100 (720,1280,3) 16 | 015 100 (720,1280,3) 17 | 016 100 (720,1280,3) 18 | 017 100 (720,1280,3) 19 | 018 100 (720,1280,3) 20 | 019 100 (720,1280,3) 21 | 020 100 (720,1280,3) 22 | 021 100 (720,1280,3) 23 | 022 100 (720,1280,3) 24 | 023 100 (720,1280,3) 25 | 024 100 (720,1280,3) 26 | 025 100 (720,1280,3) 27 | 026 100 (720,1280,3) 28 | 027 100 (720,1280,3) 29 | 028 100 (720,1280,3) 30 | 029 100 (720,1280,3) 31 | 030 100 (720,1280,3) 32 | 031 100 (720,1280,3) 33 | 032 100 (720,1280,3) 34 | 033 100 (720,1280,3) 35 | 034 100 (720,1280,3) 36 | 035 100 (720,1280,3) 37 | 036 100 (720,1280,3) 38 | 037 100 (720,1280,3) 39 | 038 100 (720,1280,3) 40 | 039 100 (720,1280,3) 41 | 040 100 (720,1280,3) 42 | 041 100 (720,1280,3) 43 | 042 100 (720,1280,3) 44 | 043 100 (720,1280,3) 45 | 044 100 (720,1280,3) 46 | 045 100 (720,1280,3) 47 | 046 100 (720,1280,3) 48 | 047 100 (720,1280,3) 49 | 048 100 (720,1280,3) 50 | 049 100 (720,1280,3) 51 | 050 100 (720,1280,3) 52 | 051 
100 (720,1280,3) 53 | 052 100 (720,1280,3) 54 | 053 100 (720,1280,3) 55 | 054 100 (720,1280,3) 56 | 055 100 (720,1280,3) 57 | 056 100 (720,1280,3) 58 | 057 100 (720,1280,3) 59 | 058 100 (720,1280,3) 60 | 059 100 (720,1280,3) 61 | 060 100 (720,1280,3) 62 | 061 100 (720,1280,3) 63 | 062 100 (720,1280,3) 64 | 063 100 (720,1280,3) 65 | 064 100 (720,1280,3) 66 | 065 100 (720,1280,3) 67 | 066 100 (720,1280,3) 68 | 067 100 (720,1280,3) 69 | 068 100 (720,1280,3) 70 | 069 100 (720,1280,3) 71 | 070 100 (720,1280,3) 72 | 071 100 (720,1280,3) 73 | 072 100 (720,1280,3) 74 | 073 100 (720,1280,3) 75 | 074 100 (720,1280,3) 76 | 075 100 (720,1280,3) 77 | 076 100 (720,1280,3) 78 | 077 100 (720,1280,3) 79 | 078 100 (720,1280,3) 80 | 079 100 (720,1280,3) 81 | 080 100 (720,1280,3) 82 | 081 100 (720,1280,3) 83 | 082 100 (720,1280,3) 84 | 083 100 (720,1280,3) 85 | 084 100 (720,1280,3) 86 | 085 100 (720,1280,3) 87 | 086 100 (720,1280,3) 88 | 087 100 (720,1280,3) 89 | 088 100 (720,1280,3) 90 | 089 100 (720,1280,3) 91 | 090 100 (720,1280,3) 92 | 091 100 (720,1280,3) 93 | 092 100 (720,1280,3) 94 | 093 100 (720,1280,3) 95 | 094 100 (720,1280,3) 96 | 095 100 (720,1280,3) 97 | 096 100 (720,1280,3) 98 | 097 100 (720,1280,3) 99 | 098 100 (720,1280,3) 100 | 099 100 (720,1280,3) 101 | 100 100 (720,1280,3) 102 | 101 100 (720,1280,3) 103 | 102 100 (720,1280,3) 104 | 103 100 (720,1280,3) 105 | 104 100 (720,1280,3) 106 | 105 100 (720,1280,3) 107 | 106 100 (720,1280,3) 108 | 107 100 (720,1280,3) 109 | 108 100 (720,1280,3) 110 | 109 100 (720,1280,3) 111 | 110 100 (720,1280,3) 112 | 111 100 (720,1280,3) 113 | 112 100 (720,1280,3) 114 | 113 100 (720,1280,3) 115 | 114 100 (720,1280,3) 116 | 115 100 (720,1280,3) 117 | 116 100 (720,1280,3) 118 | 117 100 (720,1280,3) 119 | 118 100 (720,1280,3) 120 | 119 100 (720,1280,3) 121 | 120 100 (720,1280,3) 122 | 121 100 (720,1280,3) 123 | 122 100 (720,1280,3) 124 | 123 100 (720,1280,3) 125 | 124 100 (720,1280,3) 126 | 125 100 (720,1280,3) 127 | 126 100 (720,1280,3) 128 
| 127 100 (720,1280,3) 129 | 128 100 (720,1280,3) 130 | 129 100 (720,1280,3) 131 | 130 100 (720,1280,3) 132 | 131 100 (720,1280,3) 133 | 132 100 (720,1280,3) 134 | 133 100 (720,1280,3) 135 | 134 100 (720,1280,3) 136 | 135 100 (720,1280,3) 137 | 136 100 (720,1280,3) 138 | 137 100 (720,1280,3) 139 | 138 100 (720,1280,3) 140 | 139 100 (720,1280,3) 141 | 140 100 (720,1280,3) 142 | 141 100 (720,1280,3) 143 | 142 100 (720,1280,3) 144 | 143 100 (720,1280,3) 145 | 144 100 (720,1280,3) 146 | 145 100 (720,1280,3) 147 | 146 100 (720,1280,3) 148 | 147 100 (720,1280,3) 149 | 148 100 (720,1280,3) 150 | 149 100 (720,1280,3) 151 | 150 100 (720,1280,3) 152 | 151 100 (720,1280,3) 153 | 152 100 (720,1280,3) 154 | 153 100 (720,1280,3) 155 | 154 100 (720,1280,3) 156 | 155 100 (720,1280,3) 157 | 156 100 (720,1280,3) 158 | 157 100 (720,1280,3) 159 | 158 100 (720,1280,3) 160 | 159 100 (720,1280,3) 161 | 160 100 (720,1280,3) 162 | 161 100 (720,1280,3) 163 | 162 100 (720,1280,3) 164 | 163 100 (720,1280,3) 165 | 164 100 (720,1280,3) 166 | 165 100 (720,1280,3) 167 | 166 100 (720,1280,3) 168 | 167 100 (720,1280,3) 169 | 168 100 (720,1280,3) 170 | 169 100 (720,1280,3) 171 | 170 100 (720,1280,3) 172 | 171 100 (720,1280,3) 173 | 172 100 (720,1280,3) 174 | 173 100 (720,1280,3) 175 | 174 100 (720,1280,3) 176 | 175 100 (720,1280,3) 177 | 176 100 (720,1280,3) 178 | 177 100 (720,1280,3) 179 | 178 100 (720,1280,3) 180 | 179 100 (720,1280,3) 181 | 180 100 (720,1280,3) 182 | 181 100 (720,1280,3) 183 | 182 100 (720,1280,3) 184 | 183 100 (720,1280,3) 185 | 184 100 (720,1280,3) 186 | 185 100 (720,1280,3) 187 | 186 100 (720,1280,3) 188 | 187 100 (720,1280,3) 189 | 188 100 (720,1280,3) 190 | 189 100 (720,1280,3) 191 | 190 100 (720,1280,3) 192 | 191 100 (720,1280,3) 193 | 192 100 (720,1280,3) 194 | 193 100 (720,1280,3) 195 | 194 100 (720,1280,3) 196 | 195 100 (720,1280,3) 197 | 196 100 (720,1280,3) 198 | 197 100 (720,1280,3) 199 | 198 100 (720,1280,3) 200 | 199 100 (720,1280,3) 201 | 200 100 (720,1280,3) 202 | 
201 100 (720,1280,3) 203 | 202 100 (720,1280,3) 204 | 203 100 (720,1280,3) 205 | 204 100 (720,1280,3) 206 | 205 100 (720,1280,3) 207 | 206 100 (720,1280,3) 208 | 207 100 (720,1280,3) 209 | 208 100 (720,1280,3) 210 | 209 100 (720,1280,3) 211 | 210 100 (720,1280,3) 212 | 211 100 (720,1280,3) 213 | 212 100 (720,1280,3) 214 | 213 100 (720,1280,3) 215 | 214 100 (720,1280,3) 216 | 215 100 (720,1280,3) 217 | 216 100 (720,1280,3) 218 | 217 100 (720,1280,3) 219 | 218 100 (720,1280,3) 220 | 219 100 (720,1280,3) 221 | 220 100 (720,1280,3) 222 | 221 100 (720,1280,3) 223 | 222 100 (720,1280,3) 224 | 223 100 (720,1280,3) 225 | 224 100 (720,1280,3) 226 | 225 100 (720,1280,3) 227 | 226 100 (720,1280,3) 228 | 227 100 (720,1280,3) 229 | 228 100 (720,1280,3) 230 | 229 100 (720,1280,3) 231 | 230 100 (720,1280,3) 232 | 231 100 (720,1280,3) 233 | 232 100 (720,1280,3) 234 | 233 100 (720,1280,3) 235 | 234 100 (720,1280,3) 236 | 235 100 (720,1280,3) 237 | 236 100 (720,1280,3) 238 | 237 100 (720,1280,3) 239 | 238 100 (720,1280,3) 240 | 239 100 (720,1280,3) 241 | 240 100 (720,1280,3) 242 | 241 100 (720,1280,3) 243 | 242 100 (720,1280,3) 244 | 243 100 (720,1280,3) 245 | 244 100 (720,1280,3) 246 | 245 100 (720,1280,3) 247 | 246 100 (720,1280,3) 248 | 247 100 (720,1280,3) 249 | 248 100 (720,1280,3) 250 | 249 100 (720,1280,3) 251 | 250 100 (720,1280,3) 252 | 251 100 (720,1280,3) 253 | 252 100 (720,1280,3) 254 | 253 100 (720,1280,3) 255 | 254 100 (720,1280,3) 256 | 255 100 (720,1280,3) 257 | 256 100 (720,1280,3) 258 | 257 100 (720,1280,3) 259 | 258 100 (720,1280,3) 260 | 259 100 (720,1280,3) 261 | 260 100 (720,1280,3) 262 | 261 100 (720,1280,3) 263 | 262 100 (720,1280,3) 264 | 263 100 (720,1280,3) 265 | 264 100 (720,1280,3) 266 | 265 100 (720,1280,3) 267 | 266 100 (720,1280,3) 268 | 267 100 (720,1280,3) 269 | 268 100 (720,1280,3) 270 | 269 100 (720,1280,3) 271 | -------------------------------------------------------------------------------- 
/basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 246 100 (720,1280,3) 4 | 257 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 242 100 (720,1280,3) 4 | 243 100 (720,1280,3) 5 | 244 100 (720,1280,3) 6 | 245 100 (720,1280,3) 7 | 246 100 (720,1280,3) 8 | 247 100 (720,1280,3) 9 | 248 100 (720,1280,3) 10 | 249 100 (720,1280,3) 11 | 250 100 (720,1280,3) 12 | 251 100 (720,1280,3) 13 | 252 100 (720,1280,3) 14 | 253 100 (720,1280,3) 15 | 254 100 (720,1280,3) 16 | 255 100 (720,1280,3) 17 | 256 100 (720,1280,3) 18 | 257 100 (720,1280,3) 19 | 258 100 (720,1280,3) 20 | 259 100 (720,1280,3) 21 | 260 100 (720,1280,3) 22 | 261 100 (720,1280,3) 23 | 262 100 (720,1280,3) 24 | 263 100 (720,1280,3) 25 | 264 100 (720,1280,3) 26 | 265 100 (720,1280,3) 27 | 266 100 (720,1280,3) 28 | 267 100 (720,1280,3) 29 | 268 100 (720,1280,3) 30 | 269 100 (720,1280,3) 31 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 
16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to( 116 | device=self.device, non_blocking=True) 117 | 118 | def next(self): 119 | torch.cuda.current_stream().wait_stream(self.stream) 120 | batch = self.batch 121 | self.preload() 122 | return batch 123 | 124 | def reset(self): 125 | self.loader = iter(self.ori_loader) 126 | self.preload() 127 | -------------------------------------------------------------------------------- /basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir 7 | 8 | 9 | class SingleImageDataset(data.Dataset): 10 | """Read only lq images in the test phase. 11 | 12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 13 | 14 | There are two modes: 15 | 1. 'meta_info_file': Use meta information file to generate paths. 16 | 2. 'folder': Scan folders to generate paths. 17 | 18 | Args: 19 | opt (dict): Config for train datasets. It contains the following keys: 20 | dataroot_lq (str): Data root path for lq. 21 | meta_info_file (str): Path for meta information file. 22 | io_backend (dict): IO backend type and other kwarg. 
23 | """ 24 | 25 | def __init__(self, opt): 26 | super(SingleImageDataset, self).__init__() 27 | self.opt = opt 28 | # file client (io backend) 29 | self.file_client = None 30 | self.io_backend_opt = opt['io_backend'] 31 | self.mean = opt['mean'] if 'mean' in opt else None 32 | self.std = opt['std'] if 'std' in opt else None 33 | self.lq_folder = opt['dataroot_lq'] 34 | 35 | if self.io_backend_opt['type'] == 'lmdb': 36 | self.io_backend_opt['db_paths'] = [self.lq_folder] 37 | self.io_backend_opt['client_keys'] = ['lq'] 38 | self.paths = paths_from_lmdb(self.lq_folder) 39 | elif 'meta_info_file' in self.opt: 40 | with open(self.opt['meta_info_file'], 'r') as fin: 41 | self.paths = [ 42 | osp.join(self.lq_folder, 43 | line.split(' ')[0]) for line in fin 44 | ] 45 | else: 46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 47 | 48 | def __getitem__(self, index): 49 | if self.file_client is None: 50 | self.file_client = FileClient( 51 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 52 | 53 | # load lq image 54 | lq_path = self.paths[index] 55 | img_bytes = self.file_client.get(lq_path, 'lq') 56 | img_lq = imfrombytes(img_bytes, float32=True) 57 | 58 | # TODO: color space transform 59 | # BGR to RGB, HWC to CHW, numpy to tensor 60 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 61 | # normalize 62 | if self.mean is not None or self.std is not None: 63 | normalize(img_lq, self.mean, self.std, inplace=True) 64 | return {'lq': img_lq, 'lq_path': lq_path} 65 | 66 | def __len__(self): 67 | return len(self.paths) 68 | -------------------------------------------------------------------------------- /basicsr/data/vimeo90k_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from pathlib import Path 4 | from torch.utils import data as data 5 | 6 | from basicsr.data.transforms import augment, paired_random_crop 7 | from basicsr.utils import FileClient, 
get_root_logger, imfrombytes, img2tensor 8 | 9 | 10 | class Vimeo90KDataset(data.Dataset): 11 | """Vimeo90K dataset for training. 12 | 13 | The keys are generated from a meta info txt file. 14 | basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt 15 | 16 | Each line contains: 17 | 1. clip name; 2. frame number; 3. image shape, seperated by a white space. 18 | Examples: 19 | 00001/0001 7 (256,448,3) 20 | 00001/0002 7 (256,448,3) 21 | 22 | Key examples: "00001/0001" 23 | GT (gt): Ground-Truth; 24 | LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. 25 | 26 | The neighboring frame list for different num_frame: 27 | num_frame | frame list 28 | 1 | 4 29 | 3 | 3,4,5 30 | 5 | 2,3,4,5,6 31 | 7 | 1,2,3,4,5,6,7 32 | 33 | Args: 34 | opt (dict): Config for train dataset. It contains the following keys: 35 | dataroot_gt (str): Data root path for gt. 36 | dataroot_lq (str): Data root path for lq. 37 | meta_info_file (str): Path for meta information file. 38 | io_backend (dict): IO backend type and other kwarg. 39 | 40 | num_frame (int): Window size for input frames. 41 | gt_size (int): Cropped patched size for gt patches. 42 | random_reverse (bool): Random reverse input frames. 43 | use_flip (bool): Use horizontal flips. 44 | use_rot (bool): Use rotation (use vertical flip and transposing h 45 | and w for implementation). 46 | 47 | scale (bool): Scale, which will be added automatically. 
48 | """ 49 | 50 | def __init__(self, opt): 51 | super(Vimeo90KDataset, self).__init__() 52 | self.opt = opt 53 | self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path( 54 | opt['dataroot_lq']) 55 | 56 | with open(opt['meta_info_file'], 'r') as fin: 57 | self.keys = [line.split(' ')[0] for line in fin] 58 | 59 | # file client (io backend) 60 | self.file_client = None 61 | self.io_backend_opt = opt['io_backend'] 62 | self.is_lmdb = False 63 | if self.io_backend_opt['type'] == 'lmdb': 64 | self.is_lmdb = True 65 | self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] 66 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 67 | 68 | # indices of input images 69 | self.neighbor_list = [ 70 | i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame']) 71 | ] 72 | 73 | # temporal augmentation configs 74 | self.random_reverse = opt['random_reverse'] 75 | logger = get_root_logger() 76 | logger.info(f'Random reverse is {self.random_reverse}.') 77 | 78 | def __getitem__(self, index): 79 | if self.file_client is None: 80 | self.file_client = FileClient( 81 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 82 | 83 | # random reverse 84 | if self.random_reverse and random.random() < 0.5: 85 | self.neighbor_list.reverse() 86 | 87 | scale = self.opt['scale'] 88 | gt_size = self.opt['gt_size'] 89 | key = self.keys[index] 90 | clip, seq = key.split('/') # key example: 00001/0001 91 | 92 | # get the GT frame (im4.png) 93 | if self.is_lmdb: 94 | img_gt_path = f'{key}/im4' 95 | else: 96 | img_gt_path = self.gt_root / clip / seq / 'im4.png' 97 | img_bytes = self.file_client.get(img_gt_path, 'gt') 98 | img_gt = imfrombytes(img_bytes, float32=True) 99 | 100 | # get the neighboring LQ frames 101 | img_lqs = [] 102 | for neighbor in self.neighbor_list: 103 | if self.is_lmdb: 104 | img_lq_path = f'{clip}/{seq}/im{neighbor}' 105 | else: 106 | img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' 107 | img_bytes = self.file_client.get(img_lq_path, 'lq') 108 
| img_lq = imfrombytes(img_bytes, float32=True) 109 | img_lqs.append(img_lq) 110 | 111 | # randomly crop 112 | img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, 113 | img_gt_path) 114 | 115 | # augmentation - flip, rotate 116 | img_lqs.append(img_gt) 117 | img_results = augment(img_lqs, self.opt['use_flip'], 118 | self.opt['use_rot']) 119 | 120 | img_results = img2tensor(img_results) 121 | img_lqs = torch.stack(img_results[0:-1], dim=0) 122 | img_gt = img_results[-1] 123 | 124 | # img_lqs: (t, c, h, w) 125 | # img_gt: (c, h, w) 126 | # key: str 127 | return {'lq': img_lqs, 'gt': img_gt, 'key': key} 128 | 129 | def __len__(self): 130 | return len(self.keys) 131 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .niqe import calculate_niqe 2 | from .psnr_ssim import calculate_psnr, calculate_ssim 3 | 4 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 5 | -------------------------------------------------------------------------------- /basicsr/metrics/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/metrics/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/metrics/__pycache__/metric_util.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/metric_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/metrics/__pycache__/metric_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/metric_util.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/metrics/__pycache__/niqe.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/niqe.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/metrics/__pycache__/niqe.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/niqe.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/metrics/__pycache__/psnr_ssim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/psnr_ssim.cpython-38.pyc 
-------------------------------------------------------------------------------- /basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.models.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', 11 | resize_input=True, 12 | normalize_input=False): 13 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 14 | # does resize the input. 15 | inception = InceptionV3([3], 16 | resize_input=resize_input, 17 | normalize_input=normalize_input) 18 | inception = nn.DataParallel(inception).eval().to(device) 19 | return inception 20 | 21 | 22 | @torch.no_grad() 23 | def extract_inception_features(data_generator, 24 | inception, 25 | len_generator=None, 26 | device='cuda'): 27 | """Extract inception features. 28 | 29 | Args: 30 | data_generator (generator): A data generator. 31 | inception (nn.Module): Inception model. 32 | len_generator (int): Length of the data_generator to show the 33 | progressbar. Default: None. 34 | device (str): Device. Default: cuda. 35 | 36 | Returns: 37 | Tensor: Extracted features. 38 | """ 39 | if len_generator is not None: 40 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 41 | else: 42 | pbar = None 43 | features = [] 44 | 45 | for data in data_generator: 46 | if pbar: 47 | pbar.update(1) 48 | data = data.to(device) 49 | feature = inception(data)[0].view(data.shape[0], -1) 50 | features.append(feature.to('cpu')) 51 | if pbar: 52 | pbar.close() 53 | features = torch.cat(features, 0) 54 | return features 55 | 56 | 57 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 58 | """Numpy implementation of the Frechet Distance. 
59 | 60 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 61 | and X_2 ~ N(mu_2, C_2) is 62 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 63 | Stable version by Dougal J. Sutherland. 64 | 65 | Args: 66 | mu1 (np.array): The sample mean over activations. 67 | sigma1 (np.array): The covariance matrix over activations for 68 | generated samples. 69 | mu2 (np.array): The sample mean over activations, precalculated on an 70 | representative data set. 71 | sigma2 (np.array): The covariance matrix over activations, 72 | precalculated on an representative data set. 73 | 74 | Returns: 75 | float: The Frechet Distance. 76 | """ 77 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' 78 | assert sigma1.shape == sigma2.shape, ( 79 | 'Two covariances have different dimensions') 80 | 81 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) 82 | 83 | # Product might be almost singular 84 | if not np.isfinite(cov_sqrt).all(): 85 | print('Product of cov matrices is singular. 
Adding {eps} to diagonal ' 86 | 'of cov estimates') 87 | offset = np.eye(sigma1.shape[0]) * eps 88 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) 89 | 90 | # Numerical error might give slight imaginary component 91 | if np.iscomplexobj(cov_sqrt): 92 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 93 | m = np.max(np.abs(cov_sqrt.imag)) 94 | raise ValueError(f'Imaginary component {m}') 95 | cov_sqrt = cov_sqrt.real 96 | 97 | mean_diff = mu1 - mu2 98 | mean_norm = mean_diff @ mean_diff 99 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) 100 | fid = mean_norm + trace 101 | 102 | return fid 103 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError( 25 | f'Wrong input_order {input_order}. Supported input_orders are ' 26 | "'HWC' and 'CHW'") 27 | if len(img.shape) == 2: 28 | img = img[..., None] 29 | if input_order == 'CHW': 30 | img = img.transpose(1, 2, 0) 31 | return img 32 | 33 | 34 | def to_y_channel(img): 35 | """Change to Y channel of YCbCr. 36 | 37 | Args: 38 | img (ndarray): Images with range [0, 255]. 39 | 40 | Returns: 41 | (ndarray): Images with range [0, 255] (float type) without round. 
42 | """ 43 | img = img.astype(np.float32) / 255. 44 | if img.ndim == 3 and img.shape[2] == 3: 45 | img = bgr2ycbcr(img, y_only=True) 46 | img = img[..., None] 47 | return img * 255. 48 | -------------------------------------------------------------------------------- /basicsr/metrics/niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/niqe_pris_params.npz -------------------------------------------------------------------------------- /basicsr/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/.DS_Store -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import get_root_logger, scandir 5 | 6 | # automatically scan and import model modules 7 | # scan all the files under the 'models' folder and collect files ending with 8 | # '_model.py' 9 | model_folder = osp.dirname(osp.abspath(__file__)) 10 | model_filenames = [ 11 | osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) 12 | if v.endswith('_model.py') 13 | ] 14 | # import all the model modules 15 | _model_modules = [ 16 | importlib.import_module(f'basicsr.models.{file_name}') 17 | for file_name in model_filenames 18 | ] 19 | 20 | 21 | def create_model(opt): 22 | """Create model. 23 | 24 | Args: 25 | opt (dict): Configuration. It constains: 26 | model_type (str): Model type. 
27 | """ 28 | model_type = opt['model_type'] 29 | 30 | # dynamic instantiation 31 | for module in _model_modules: 32 | model_cls = getattr(module, model_type, None) 33 | if model_cls is not None: 34 | break 35 | if model_cls is None: 36 | raise ValueError(f'Model {model_type} is not found.') 37 | 38 | model = model_cls(opt) 39 | 40 | logger = get_root_logger() 41 | logger.info(f'Model [{model.__class__.__name__}] is created.') 42 | return model 43 | -------------------------------------------------------------------------------- /basicsr/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/base_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/base_model.cpython-38.pyc -------------------------------------------------------------------------------- 
/basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/image_restoration_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/image_restoration_model.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import arch modules 7 | # scan all the files under the 'archs' folder and collect files ending with 8 | # '_arch.py' 9 | arch_folder = osp.dirname(osp.abspath(__file__)) 10 | arch_filenames = [ 11 | osp.splitext(osp.basename(v))[0] for v in 
scandir(arch_folder) 12 | if v.endswith('_arch.py') 13 | ] 14 | # import all the arch modules 15 | _arch_modules = [ 16 | importlib.import_module(f'basicsr.models.archs.{file_name}') 17 | for file_name in arch_filenames 18 | ] 19 | 20 | 21 | def dynamic_instantiation(modules, cls_type, opt): 22 | """Dynamically instantiate class. 23 | 24 | Args: 25 | modules (list[importlib modules]): List of modules from importlib 26 | files. 27 | cls_type (str): Class type. 28 | opt (dict): Class initialization kwargs. 29 | 30 | Returns: 31 | class: Instantiated class. 32 | """ 33 | 34 | for module in modules: 35 | cls_ = getattr(module, cls_type, None) 36 | if cls_ is not None: 37 | break 38 | if cls_ is None: 39 | raise ValueError(f'{cls_type} is not found.') 40 | return cls_(**opt) 41 | 42 | 43 | def define_network(opt): 44 | network_type = opt.pop('type') 45 | net = dynamic_instantiation(_arch_modules, network_type, opt) 46 | return net 47 | -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/FPro_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/FPro_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/HINT_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/HINT_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/HINT_arch.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/HINT_arch.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/restormer_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/restormer_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/restormer_local_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/restormer_local_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss,FFTLoss) 2 | 3 | __all__ = [ 4 | 'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss','FFTLoss', 5 | ] 6 | -------------------------------------------------------------------------------- /basicsr/models/losses/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/losses/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/__init__.cpython-38.pyc
# --------------------------------------------------------------------------------
# /basicsr/models/losses/__pycache__/loss_util.cpython-37.pyc:
# https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/loss_util.cpython-37.pyc
# --------------------------------------------------------------------------------
# /basicsr/models/losses/__pycache__/loss_util.cpython-38.pyc:
# https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/loss_util.cpython-38.pyc
# --------------------------------------------------------------------------------
# /basicsr/models/losses/__pycache__/losses.cpython-37.pyc:
# https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/losses.cpython-37.pyc
# --------------------------------------------------------------------------------
# /basicsr/models/losses/__pycache__/losses.cpython-38.pyc:
# https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/losses.cpython-38.pyc
# --------------------------------------------------------------------------------
# /basicsr/models/losses/loss_util.py:
# --------------------------------------------------------------------------------
import functools
from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    # Map the string onto torch's internal enum (private torch helper).
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    else:
        return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights with the same rank as ``loss``
            and a dim-1 size of either 1 or ``loss.size(1)``. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        loss = reduce_loss(loss, reduction)
    # if reduction is mean, then compute mean over weight region
    elif reduction == 'mean':
        if weight.size(1) > 1:
            weight = weight.sum()
        else:
            # a dim-1 size of 1 was broadcast across loss.size(1) channels
            weight = weight.sum() * loss.size(1)
        loss = loss.sum() / weight

    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([[0, 2, 3]])
    >>> target = torch.Tensor([[1, 1, 1]])
    >>> weight = torch.Tensor([[1, 0, 1]])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([[1., 1., 2.]])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction)
        return loss

    return wrapper

# --------------------------------------------------------------------------------
# /basicsr/models/losses/losses.py:
# --------------------------------------------------------------------------------
import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np

from basicsr.models.losses.loss_util import weighted_loss

_reduction_modes = ['none', 'mean', 'sum']


def _check_reduction(reduction):
    """Raise ``ValueError`` unless *reduction* is one of ``_reduction_modes``."""
    if reduction not in _reduction_modes:
        raise ValueError(f'Unsupported reduction mode: {reduction}. '
                         f'Supported ones are: {_reduction_modes}')


@weighted_loss
def l1_loss(pred, target):
    """Element-wise absolute error (reduction is added by ``weighted_loss``)."""
    return F.l1_loss(pred, target, reduction='none')


@weighted_loss
def mse_loss(pred, target):
    """Element-wise squared error (reduction is added by ``weighted_loss``)."""
    return F.mse_loss(pred, target, reduction='none')


# @weighted_loss
# def charbonnier_loss(pred, target, eps=1e-12):
#     return torch.sqrt((pred - target)**2 + eps)


class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        _check_reduction(reduction)

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * l1_loss(
            pred, target, weight, reduction=self.reduction)


class FFTLoss(nn.Module):
    """L1 loss in frequency domain with FFT.

    The 2-D FFT is taken over the last two axes; real and imaginary parts are
    stacked on a new trailing axis and compared with an L1 loss.

    Args:
        loss_weight (float): Loss weight for FFT loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(FFTLoss, self).__init__()
        _check_reduction(reduction)

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (..., C, H, W). Predicted tensor.
            target (Tensor): of shape (..., C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (..., C, H, W). Element-wise
                weights over the frequency bins, shared by the real and
                imaginary parts. Default: None.
        """

        pred_fft = torch.fft.fft2(pred, dim=(-2, -1))
        pred_fft = torch.stack([pred_fft.real, pred_fft.imag], dim=-1)
        target_fft = torch.fft.fft2(target, dim=(-2, -1))
        target_fft = torch.stack([target_fft.real, target_fft.imag], dim=-1)
        if weight is not None:
            # The real/imag stacking adds a trailing axis. Give the weight the
            # same axis (applied to both parts) so weight_reduce_loss's rank
            # check passes and the 'mean' normalisation counts both parts.
            weight = weight.unsqueeze(-1).expand_as(pred_fft)
        return self.loss_weight * l1_loss(pred_fft, target_fft, weight, reduction=self.reduction)


class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        _check_reduction(reduction)

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * mse_loss(
            pred, target, weight, reduction=self.reduction)


class PSNRLoss(nn.Module):
    """Loss equal to ``loss_weight * 10 * log10(per-image MSE + 1e-8)``, batch-averaged.

    Minimising it maximises PSNR (for a peak value of 1).

    Args:
        loss_weight (float): Multiplier applied to the loss. Default: 1.0.
        reduction (str): Only 'mean' is supported. Default: 'mean'.
        toY (bool): If True, convert both inputs from 3-channel RGB to a
            single Y channel before computing the MSE. Presumably expects RGB
            in [0, 1]; confirm with callers. Default: False.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
        super(PSNRLoss, self).__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        # 10 / ln(10) * ln(x) == 10 * log10(x)
        self.scale = 10 / np.log(10)
        self.toY = toY
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        self.first = True  # kept for backward compatibility; no longer gates device moves

    def forward(self, pred, target):
        assert len(pred.size()) == 4
        if self.toY:
            # Follow the inputs' device on every call (not just the first), so
            # the loss keeps working if the inputs move to a different device.
            if self.coef.device != pred.device:
                self.coef = self.coef.to(pred.device)
            self.first = False

            pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.

            pred, target = pred / 255., target / 255.
        assert len(pred.size()) == 4

        return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()


class CharbonnierLoss(nn.Module):
    """Charbonnier loss, a smooth L1 variant: ``sqrt(diff**2 + eps**2)``.

    Args:
        loss_weight (float): Multiplier applied to the loss. Default: 1.0.
        reduction (str): 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): Smoothing term; it is squared before being added.
            Default: 1e-3.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        _check_reduction(reduction)
        # Previously both loss_weight and reduction were accepted but ignored.
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        loss = torch.sqrt((diff * diff) + (self.eps * self.eps))
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return self.loss_weight * loss

# --------------------------------------------------------------------------------
# /basicsr/test.py:
# --------------------------------------------------------------------------------
import logging
import torch
from os import path as osp

from basicsr.data import create_dataloader, create_dataset
from basicsr.models import create_model
from basicsr.train import parse_options
from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
                           make_exp_dirs)
from basicsr.utils.options import dict2str


def main():
    """Evaluate a model on every dataset listed in the test options.

    Parses options with ``is_train=False``, creates experiment directories and
    a log file, builds one dataloader per entry in ``opt['datasets']`` and
    runs ``model.validation`` on each one.
    """
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=False)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # mkdir and initialize loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'],
                        f"test_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(
        logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(
            test_set,
            dataset_opt,
            num_gpu=opt['num_gpu'],
            dist=opt['dist'],
            sampler=None,
            seed=opt['manual_seed'])
        logger.info(
            f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
        test_loaders.append(test_loader)

    # create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info(f'Testing {test_set_name}...')
        rgb2bgr = opt['val'].get('rgb2bgr', True)
        # whether to use uint8 images to compute metrics
        use_image = opt['val'].get('use_image', True)
        model.validation(
            test_loader,
            current_iter=opt['name'],
            tb_logger=None,
            save_img=opt['val']['save_img'],
            rgb2bgr=rgb2bgr, use_image=use_image)


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /basicsr/utils/__init__.py:
# --------------------------------------------------------------------------------
from .file_client import FileClient
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP
from .logger import (MessageLogger, get_env_info, get_root_logger,
                     init_tb_logger, init_wandb_logger)
from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
                   scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)

__all__ = [
    # file_client.py
    'FileClient',
    # img_util.py
    'img2tensor',
    'tensor2img',
    'imfrombytes',
    'imwrite',
    'crop_border',
    # logger.py
    'MessageLogger',
    'init_tb_logger',
    'init_wandb_logger',
    'get_root_logger',
    'get_env_info',
    # misc.py
    'set_random_seed',
    'get_time_str',
    'mkdir_and_rename',
    'make_exp_dirs',
    'scandir',
    'scandir_SIDD',  # imported above but previously not exported
    'check_resume',
    'sizeof_fmt',
    'padding',
    'padding_DP',
    'imfrombytesDP',
    'create_lmdb_for_reds',
    'create_lmdb_for_gopro',
    'create_lmdb_for_rain13k',
]

# --------------------------------------------------------------------------------
# /basicsr/utils/__pycache__/__init__.cpython-37.pyc:
# https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /basicsr/utils/__pycache__/__init__.cpython-38.pyc:
# https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/__init__.cpython-38.pyc
# --------------------------------------------------------------------------------
# /basicsr/utils/__pycache__/create_lmdb.cpython-37.pyc:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/create_lmdb.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/create_lmdb.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/create_lmdb.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/dist_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/dist_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/dist_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/dist_util.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/file_client.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/file_client.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/file_client.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/file_client.cpython-38.pyc 
-------------------------------------------------------------------------------- /basicsr/utils/__pycache__/flow_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/flow_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/img_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/img_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/img_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/img_util.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/lmdb_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/lmdb_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/lmdb_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/lmdb_util.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/matlab_functions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/matlab_functions.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/matlab_functions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/matlab_functions.cpython-38.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/misc.cpython-38.pyc 
# --------------------------------------------------------------------------------
# /basicsr/utils/__pycache__/options.cpython-37.pyc:
# https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/options.cpython-37.pyc
# --------------------------------------------------------------------------------
# /basicsr/utils/__pycache__/options.cpython-38.pyc:
# https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/options.cpython-38.pyc
# --------------------------------------------------------------------------------
# /basicsr/utils/bundle_submissions.py:
# --------------------------------------------------------------------------------
# Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de)

# This file is part of the implementation as described in the CVPR 2017 paper:
# Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs.
# Please see the file LICENSE.txt for the license governing this code.


import numpy as np
import scipy.io as sio
import os
import h5py


def _bundle_submissions(submission_folder, session, israw, crop_name_fmt):
    """Pack per-crop ``.mat`` results into one ``.mat`` file per image.

    For each of the 50 images, the 20 crop files named
    ``crop_name_fmt % (image_index, crop_index)`` (both 1-based) are read from
    *submission_folder*. Their ``Idenoised_crop`` arrays are collected into a
    20-element object array and saved as ``Idenoised`` (with ``israw`` and
    ``eval_version``) to ``<submission_folder>/<session>/%04d.mat``.
    """
    out_folder = os.path.join(submission_folder, session)
    # Previously `os.mkdir` wrapped in a bare `except: pass`, which also hid
    # real failures (e.g. permissions) until savemat failed later.
    os.makedirs(out_folder, exist_ok=True)
    eval_version = "1.0"

    for i in range(50):
        # Builtin `object` dtype: the `np.object` alias was removed in NumPy 1.24.
        Idenoised = np.zeros((20,), dtype=object)
        for bb in range(20):
            filename = crop_name_fmt % (i + 1, bb + 1)
            s = sio.loadmat(os.path.join(submission_folder, filename))
            Idenoised[bb] = s["Idenoised_crop"]
        filename = '%04d.mat' % (i + 1)
        sio.savemat(os.path.join(out_folder, filename),
                    {"Idenoised": Idenoised,
                     "israw": israw,
                     "eval_version": eval_version},
                    )


def bundle_submissions_raw(submission_folder, session):
    '''
    Bundles submission data for raw denoising

    submission_folder Folder where denoised images reside
    session           Name of the output sub-folder

    Crops are read as '%04d_%02d.mat'. Output is written to
    <submission_folder>/<session>/. Please submit the content of this folder.
    '''
    _bundle_submissions(submission_folder, session, True, '%04d_%02d.mat')


def bundle_submissions_srgb(submission_folder, session):
    '''
    Bundles submission data for sRGB denoising

    submission_folder Folder where denoised images reside
    session           Name of the output sub-folder

    Crops are read as '%04d_%02d.mat'. Output is written to
    <submission_folder>/<session>/. Please submit the content of this folder.
    '''
    _bundle_submissions(submission_folder, session, False, '%04d_%02d.mat')


def bundle_submissions_srgb_v1(submission_folder, session):
    '''
    Bundles submission data for sRGB denoising

    submission_folder Folder where denoised images reside
    session           Name of the output sub-folder

    Same as bundle_submissions_srgb, but crops are read as '%04d_%d.mat'
    (crop index not zero-padded). Output is written to
    <submission_folder>/<session>/. Please submit the content of this folder.
    '''
    _bundle_submissions(submission_folder, session, False, '%04d_%d.mat')

# --------------------------------------------------------------------------------
# /basicsr/utils/create_lmdb.py:
# --------------------------------------------------------------------------------
import argparse
from os import path as osp

from basicsr.utils import scandir
from basicsr.utils.lmdb_util import make_lmdb_from_imgs


def prepare_keys(folder_path, suffix='png'):
    """Prepare image path list and keys for DIV2K dataset.

    Args:
        folder_path (str): Folder path.
        suffix (str): Image file extension to collect (no leading dot).

    Returns:
        list[str]: Image path list (sorted).
        list[str]: Key list: each path with its '.<suffix>' part removed.
    """
    print('Reading image path list ...')
    img_path_list = sorted(
        list(scandir(folder_path, suffix=suffix, recursive=False)))
    keys = [img_path.split('.{}'.format(suffix))[0] for img_path in img_path_list]

    return img_path_list, keys


def _folder_to_lmdb(folder_path, lmdb_path, suffix):
    """Build an LMDB at *lmdb_path* from all ``*.suffix`` images in *folder_path*."""
    img_path_list, keys = prepare_keys(folder_path, suffix)
    make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)


def create_lmdb_for_reds():
    _folder_to_lmdb('./datasets/REDS/val/sharp_300',
                    './datasets/REDS/val/sharp_300.lmdb', 'png')
    _folder_to_lmdb('./datasets/REDS/val/blur_300',
                    './datasets/REDS/val/blur_300.lmdb', 'jpg')
    _folder_to_lmdb('./datasets/REDS/train/train_sharp',
                    './datasets/REDS/train/train_sharp.lmdb', 'png')
    _folder_to_lmdb('./datasets/REDS/train/train_blur_jpeg',
                    './datasets/REDS/train/train_blur_jpeg.lmdb', 'jpg')


def create_lmdb_for_gopro():
    _folder_to_lmdb('./datasets/GoPro/train/blur_crops',
                    './datasets/GoPro/train/blur_crops.lmdb', 'png')
    _folder_to_lmdb('./datasets/GoPro/train/sharp_crops',
                    './datasets/GoPro/train/sharp_crops.lmdb', 'png')
    _folder_to_lmdb('./datasets/GoPro/test/target',
                    './datasets/GoPro/test/target.lmdb', 'png')
    _folder_to_lmdb('./datasets/GoPro/test/input',
                    './datasets/GoPro/test/input.lmdb', 'png')


def create_lmdb_for_rain13k():
    _folder_to_lmdb('./datasets/Rain13k/train/input',
                    './datasets/Rain13k/train/input.lmdb', 'jpg')
    _folder_to_lmdb('./datasets/Rain13k/train/target',
                    './datasets/Rain13k/train/target.lmdb', 'jpg')


def _dump_sidd_mat_blocks(mat_path, mat_key, folder_path):
    """Write every block of the 5-D array ``loadmat(mat_path)[mat_key]`` as a PNG.

    The array is reshaped from (N, B, H, W, C) to (N*B, H, W, C); each block is
    converted RGB->BGR and saved as ``ValidationBlocksSrgb_<i>.png``.
    """
    # Local imports: these names were used here but never imported by the module.
    import os
    import cv2
    import scipy.io as scio
    from tqdm import tqdm

    if not osp.exists(folder_path):
        os.makedirs(folder_path)
    assert osp.exists(mat_path)
    data = scio.loadmat(mat_path)[mat_key]
    N, B, H, W, C = data.shape
    data = data.reshape(N * B, H, W, C)
    for i in tqdm(range(N * B)):
        cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)),
                    cv2.cvtColor(data[i, ...], cv2.COLOR_RGB2BGR))


def create_lmdb_for_SIDD():
    _folder_to_lmdb('./datasets/SIDD/train/input_crops',
                    './datasets/SIDD/train/input_crops.lmdb', 'PNG')
    _folder_to_lmdb('./datasets/SIDD/train/gt_crops',
                    './datasets/SIDD/train/gt_crops.lmdb', 'PNG')

    # for val: first unpack the .mat blocks to PNGs, then build the LMDBs
    folder_path = './datasets/SIDD/val/input_crops'
    _dump_sidd_mat_blocks('./datasets/SIDD/ValidationNoisyBlocksSrgb.mat',
                          'ValidationNoisyBlocksSrgb', folder_path)
    _folder_to_lmdb(folder_path, './datasets/SIDD/val/input_crops.lmdb', 'png')

    folder_path = './datasets/SIDD/val/gt_crops'
    _dump_sidd_mat_blocks('./datasets/SIDD/ValidationGtBlocksSrgb.mat',
                          'ValidationGtBlocksSrgb', folder_path)
    _folder_to_lmdb(folder_path, './datasets/SIDD/val/gt_crops.lmdb', 'png')

# --------------------------------------------------------------------------------
# /basicsr/utils/dist_util.py:
# --------------------------------------------------------------------------------
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py  # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_dist(launcher, backend='nccl', **kwargs):
    """Initialise torch.distributed for the 'pytorch' or 'slurm' launcher.

    Sets the multiprocessing start method to 'spawn' if none is set yet.

    Raises:
        ValueError: If *launcher* is not 'pytorch' or 'slurm'.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    """Bind this process to GPU ``RANK % num_gpus`` and init the process group."""
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    # The first host in the node list acts as the master.
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def get_dist_info():
    """Return ``(rank, world_size)``; ``(0, 1)`` when not running distributed."""
    if dist.is_available():
        initialized = dist.is_initialized()
    else:
        initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def master_only(func):
    """Decorator: run *func* only on rank 0; other ranks get ``None``."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper

# --------------------------------------------------------------------------------
# /basicsr/utils/download_util.py:
# --------------------------------------------------------------------------------
import math
import requests
from tqdm import tqdm

from .misc import sizeof_fmt


def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """

    session = requests.Session()
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # get file size from a tiny ranged request
    response_file_size = session.get(
        URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    if 'Content-Range' in response_file_size.headers:
        file_size = int(
            response_file_size.headers['Content-Range'].split('/')[1])
    else:
        file_size = None

    save_response_content(response, save_path, file_size)


def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, else ``None``."""
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    """Stream *response* to *destination*, with a progress bar if size is known."""
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')

        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            if not chunk:  # filter out keep-alive new chunks
                continue
            f.write(chunk)
            # Count the bytes actually received; the last chunk is usually short.
            downloaded_size += len(chunk)
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
                                     f'/ {readable_file_size}')
        if pbar is not None:
            pbar.close()

# --------------------------------------------------------------------------------
# /basicsr/utils/file_client.py:
# --------------------------------------------------------------------------------
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py  # noqa: E501
from abc import ABCMeta, abstractmethod


class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract class of storage backends.

    All backends need to implement two apis: ``get()`` and ``get_text()``.
    ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
    as texts.
    """

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass


class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError(
                'Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
                                                      self.client_cfg)
        # mc.pyvector servers as a point which points to a memory cache
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        filepath = str(filepath)
        import mc
        self._client.Get(filepath, self._mc_buffer)
        value_buf = mc.ConvertBuffer(self._mc_buffer)
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError


class HardDiskBackend(BaseStorageBackend):
    """Raw hard disks storage backend."""

    def get(self, filepath):
        filepath = str(filepath)
        with open(filepath, 'rb') as f:
            value_buf = f.read()
        return value_buf

    def get_text(self, filepath):
        filepath = str(filepath)
        with open(filepath, 'r') as f:
            value_buf = f.read()
        return value_buf


class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | os.PathLike | list | tuple): Lmdb database paths.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.

    Attributes:
        db_paths (list): Lmdb database path.
        _client (dict): Maps each client key to its opened lmdb env.
    """

    def __init__(self,
                 db_paths,
                 client_keys='default',
                 readonly=True,
                 lock=False,
                 readahead=False,
                 **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]

        if isinstance(db_paths, (list, tuple)):
            self.db_paths = [str(v) for v in db_paths]
        else:
            # A single path (str or PathLike). Previously anything other than
            # list/str left self.db_paths undefined (AttributeError below).
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), (
            'client_keys and db_paths should have the same length, '
            f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {}

        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(
                path,
                readonly=readonly,
                lock=lock,
                readahead=readahead,
                map_size=8*1024*10485760,
                # max_readers=1,
                **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing different lmdb envs.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not '
                                            'in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError


# The FileClient definition continues beyond this section; its visible start:
# 150 | class FileClient(object):
# 151 |     """A general file client to access files in different backend.
152 | 153 | The client loads a file or text in a specified backend from its path 154 | and return it as a binary file. it can also register other backend 155 | accessor with a given name and backend class. 156 | 157 | Attributes: 158 | backend (str): The storage backend type. Options are "disk", 159 | "memcached" and "lmdb". 160 | client (:obj:`BaseStorageBackend`): The backend object. 161 | """ 162 | 163 | _backends = { 164 | 'disk': HardDiskBackend, 165 | 'memcached': MemcachedBackend, 166 | 'lmdb': LmdbBackend, 167 | } 168 | 169 | def __init__(self, backend='disk', **kwargs): 170 | if backend not in self._backends: 171 | raise ValueError( 172 | f'Backend {backend} is not supported. Currently supported ones' 173 | f' are {list(self._backends.keys())}') 174 | self.backend = backend 175 | self.client = self._backends[backend](**kwargs) 176 | 177 | def get(self, filepath, client_key='default'): 178 | # client_key is used only for lmdb, where different fileclients have 179 | # different lmdb environments. 180 | if self.backend == 'lmdb': 181 | return self.client.get(filepath, client_key) 182 | else: 183 | return self.client.get(filepath) 184 | 185 | def get_text(self, filepath): 186 | return self.client.get_text(filepath) 187 | -------------------------------------------------------------------------------- /basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 2 | import cv2 3 | import numpy as np 4 | import os 5 | 6 | 7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 8 | """Read an optical flow map. 9 | 10 | Args: 11 | flow_path (ndarray or str): Flow path. 12 | quantize (bool): whether to read quantized pair, if set to True, 13 | remaining args will be passed to :func:`dequantize_flow`. 14 | concat_axis (int): The axis that dx and dy are concatenated, 15 | can be either 0 or 1. 
Ignored if quantize is False. 16 | 17 | Returns: 18 | ndarray: Optical flow represented as a (h, w, 2) numpy array 19 | """ 20 | if quantize: 21 | assert concat_axis in [0, 1] 22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 23 | if cat_flow.ndim != 2: 24 | raise IOError(f'{flow_path} is not a valid quantized flow file, ' 25 | f'its dimension is {cat_flow.ndim}.') 26 | assert cat_flow.shape[concat_axis] % 2 == 0 27 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 28 | flow = dequantize_flow(dx, dy, *args, **kwargs) 29 | else: 30 | with open(flow_path, 'rb') as f: 31 | try: 32 | header = f.read(4).decode('utf-8') 33 | except Exception: 34 | raise IOError(f'Invalid flow file: {flow_path}') 35 | else: 36 | if header != 'PIEH': 37 | raise IOError(f'Invalid flow file: {flow_path}, ' 38 | 'header does not contain PIEH') 39 | 40 | w = np.fromfile(f, np.int32, 1).squeeze() 41 | h = np.fromfile(f, np.int32, 1).squeeze() 42 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 43 | 44 | return flow.astype(np.float32) 45 | 46 | 47 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 48 | """Write optical flow to file. 49 | 50 | If the flow is not quantized, it will be saved as a .flo file losslessly, 51 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 52 | will be concatenated horizontally into a single image if quantize is True.) 53 | 54 | Args: 55 | flow (ndarray): (h, w, 2) array of optical flow. 56 | filename (str): Output filepath. 57 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg 58 | images. If set to True, remaining args will be passed to 59 | :func:`quantize_flow`. 60 | concat_axis (int): The axis that dx and dy are concatenated, 61 | can be either 0 or 1. Ignored if quantize is False. 
62 | """ 63 | if not quantize: 64 | with open(filename, 'wb') as f: 65 | f.write('PIEH'.encode('utf-8')) 66 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 67 | flow = flow.astype(np.float32) 68 | flow.tofile(f) 69 | f.flush() 70 | else: 71 | assert concat_axis in [0, 1] 72 | dx, dy = quantize_flow(flow, *args, **kwargs) 73 | dxdy = np.concatenate((dx, dy), axis=concat_axis) 74 | os.makedirs(filename, exist_ok=True) 75 | cv2.imwrite(dxdy, filename) 76 | 77 | 78 | def quantize_flow(flow, max_val=0.02, norm=True): 79 | """Quantize flow to [0, 255]. 80 | 81 | After this step, the size of flow will be much smaller, and can be 82 | dumped as jpeg images. 83 | 84 | Args: 85 | flow (ndarray): (h, w, 2) array of optical flow. 86 | max_val (float): Maximum value of flow, values beyond 87 | [-max_val, max_val] will be truncated. 88 | norm (bool): Whether to divide flow values by image width/height. 89 | 90 | Returns: 91 | tuple[ndarray]: Quantized dx and dy. 92 | """ 93 | h, w, _ = flow.shape 94 | dx = flow[..., 0] 95 | dy = flow[..., 1] 96 | if norm: 97 | dx = dx / w # avoid inplace operations 98 | dy = dy / h 99 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 100 | flow_comps = [ 101 | quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy] 102 | ] 103 | return tuple(flow_comps) 104 | 105 | 106 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 107 | """Recover from quantized flow. 108 | 109 | Args: 110 | dx (ndarray): Quantized dx. 111 | dy (ndarray): Quantized dy. 112 | max_val (float): Maximum value used when quantizing. 113 | denorm (bool): Whether to multiply flow values with width/height. 114 | 115 | Returns: 116 | ndarray: Dequantized flow. 
117 | """ 118 | assert dx.shape == dy.shape 119 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 120 | 121 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 122 | 123 | if denorm: 124 | dx *= dx.shape[1] 125 | dy *= dx.shape[0] 126 | flow = np.dstack((dx, dy)) 127 | return flow 128 | 129 | 130 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 131 | """Quantize an array of (-inf, inf) to [0, levels-1]. 132 | 133 | Args: 134 | arr (ndarray): Input array. 135 | min_val (scalar): Minimum value to be clipped. 136 | max_val (scalar): Maximum value to be clipped. 137 | levels (int): Quantization levels. 138 | dtype (np.type): The type of the quantized array. 139 | 140 | Returns: 141 | tuple: Quantized array. 142 | """ 143 | if not (isinstance(levels, int) and levels > 1): 144 | raise ValueError( 145 | f'levels must be a positive integer, but got {levels}') 146 | if min_val >= max_val: 147 | raise ValueError( 148 | f'min_val ({min_val}) must be smaller than max_val ({max_val})') 149 | 150 | arr = np.clip(arr, min_val, max_val) - min_val 151 | quantized_arr = np.minimum( 152 | np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 153 | 154 | return quantized_arr 155 | 156 | 157 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 158 | """Dequantize an array. 159 | 160 | Args: 161 | arr (ndarray): Input array. 162 | min_val (scalar): Minimum value to be clipped. 163 | max_val (scalar): Maximum value to be clipped. 164 | levels (int): Quantization levels. 165 | dtype (np.type): The type of the dequantized array. 166 | 167 | Returns: 168 | tuple: Dequantized array. 
169 | """ 170 | if not (isinstance(levels, int) and levels > 1): 171 | raise ValueError( 172 | f'levels must be a positive integer, but got {levels}') 173 | if min_val >= max_val: 174 | raise ValueError( 175 | f'min_val ({min_val}) must be smaller than max_val ({max_val})') 176 | 177 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - 178 | min_val) / levels + min_val 179 | 180 | return dequantized_arr 181 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 25 | img = torch.from_numpy(img.transpose(2, 0, 1)) 26 | if float32: 27 | img = img.float() 28 | return img 29 | 30 | if isinstance(imgs, list): 31 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 32 | else: 33 | return _totensor(imgs, bgr2rgb, float32) 34 | 35 | 36 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 37 | """Convert torch Tensors into image numpy arrays. 38 | 39 | After clamping to [min, max], values will be normalized to [0, 1]. 40 | 41 | Args: 42 | tensor (Tensor or list[Tensor]): Accept shapes: 43 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 44 | 2) 3D Tensor of shape (3/1 x H x W); 45 | 3) 2D Tensor of shape (H x W). 
46 | Tensor channel should be in RGB order. 47 | rgb2bgr (bool): Whether to change rgb to bgr. 48 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 49 | to uint8 type with range [0, 255]; otherwise, float type with 50 | range [0, 1]. Default: ``np.uint8``. 51 | min_max (tuple[int]): min and max values for clamp. 52 | 53 | Returns: 54 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 55 | shape (H x W). The channel order is BGR. 56 | """ 57 | if not (torch.is_tensor(tensor) or 58 | (isinstance(tensor, list) 59 | and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError( 61 | f'tensor or list of tensors expected, got {type(tensor)}') 62 | 63 | if torch.is_tensor(tensor): 64 | tensor = [tensor] 65 | result = [] 66 | for _tensor in tensor: 67 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 68 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 69 | 70 | n_dim = _tensor.dim() 71 | if n_dim == 4: 72 | img_np = make_grid( 73 | _tensor, nrow=int(math.sqrt(_tensor.size(0))), 74 | normalize=False).numpy() 75 | img_np = img_np.transpose(1, 2, 0) 76 | if rgb2bgr: 77 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 78 | elif n_dim == 3: 79 | img_np = _tensor.numpy() 80 | img_np = img_np.transpose(1, 2, 0) 81 | if img_np.shape[2] == 1: # gray image 82 | img_np = np.squeeze(img_np, axis=2) 83 | else: 84 | if rgb2bgr: 85 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 86 | elif n_dim == 2: 87 | img_np = _tensor.numpy() 88 | else: 89 | raise TypeError('Only support 4D, 3D or 2D tensor. ' 90 | f'But received with dimension: {n_dim}') 91 | if out_type == np.uint8: 92 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 
93 | img_np = (img_np * 255.0).round() 94 | img_np = img_np.astype(out_type) 95 | result.append(img_np) 96 | if len(result) == 1: 97 | result = result[0] 98 | return result 99 | 100 | 101 | def imfrombytes(content, flag='color', float32=False): 102 | """Read an image from bytes. 103 | 104 | Args: 105 | content (bytes): Image bytes got from files or other streams. 106 | flag (str): Flags specifying the color type of a loaded image, 107 | candidates are `color`, `grayscale` and `unchanged`. 108 | float32 (bool): Whether to change to float32., If True, will also norm 109 | to [0, 1]. Default: False. 110 | 111 | Returns: 112 | ndarray: Loaded image array. 113 | """ 114 | img_np = np.frombuffer(content, np.uint8) 115 | imread_flags = { 116 | 'color': cv2.IMREAD_COLOR, 117 | 'grayscale': cv2.IMREAD_GRAYSCALE, 118 | 'unchanged': cv2.IMREAD_UNCHANGED 119 | } 120 | if img_np is None: 121 | raise Exception('None .. !!!') 122 | img = cv2.imdecode(img_np, imread_flags[flag]) 123 | if float32: 124 | img = img.astype(np.float32) / 255. 125 | return img 126 | 127 | def imfrombytesDP(content, flag='color', float32=False): 128 | """Read an image from bytes. 129 | 130 | Args: 131 | content (bytes): Image bytes got from files or other streams. 132 | flag (str): Flags specifying the color type of a loaded image, 133 | candidates are `color`, `grayscale` and `unchanged`. 134 | float32 (bool): Whether to change to float32., If True, will also norm 135 | to [0, 1]. Default: False. 136 | 137 | Returns: 138 | ndarray: Loaded image array. 139 | """ 140 | img_np = np.frombuffer(content, np.uint8) 141 | if img_np is None: 142 | raise Exception('None .. !!!') 143 | img = cv2.imdecode(img_np, cv2.IMREAD_UNCHANGED) 144 | if float32: 145 | img = img.astype(np.float32) / 65535. 
146 | return img 147 | 148 | def padding(img_lq, img_gt, gt_size): 149 | h, w, _ = img_lq.shape 150 | 151 | h_pad = max(0, gt_size - h) 152 | w_pad = max(0, gt_size - w) 153 | 154 | if h_pad == 0 and w_pad == 0: 155 | return img_lq, img_gt 156 | 157 | img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 158 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 159 | # print('img_lq', img_lq.shape, img_gt.shape) 160 | if img_lq.ndim == 2: 161 | img_lq = np.expand_dims(img_lq, axis=2) 162 | if img_gt.ndim == 2: 163 | img_gt = np.expand_dims(img_gt, axis=2) 164 | return img_lq, img_gt 165 | 166 | def padding_DP(img_lqL, img_lqR, img_gt, gt_size): 167 | h, w, _ = img_gt.shape 168 | 169 | h_pad = max(0, gt_size - h) 170 | w_pad = max(0, gt_size - w) 171 | 172 | if h_pad == 0 and w_pad == 0: 173 | return img_lqL, img_lqR, img_gt 174 | 175 | img_lqL = cv2.copyMakeBorder(img_lqL, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 176 | img_lqR = cv2.copyMakeBorder(img_lqR, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 177 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 178 | # print('img_lq', img_lq.shape, img_gt.shape) 179 | return img_lqL, img_lqR, img_gt 180 | 181 | def imwrite(img, file_path, params=None, auto_mkdir=True): 182 | """Write image to file. 183 | 184 | Args: 185 | img (ndarray): Image array to be written. 186 | file_path (str): Image file path. 187 | params (None or list): Same as opencv's :func:`imwrite` interface. 188 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 189 | whether to create it automatically. 190 | 191 | Returns: 192 | bool: Successful or not. 193 | """ 194 | if auto_mkdir: 195 | dir_name = os.path.abspath(os.path.dirname(file_path)) 196 | os.makedirs(dir_name, exist_ok=True) 197 | return cv2.imwrite(file_path, img, params) 198 | 199 | 200 | def crop_border(imgs, crop_border): 201 | """Crop borders of images. 
202 | 203 | Args: 204 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 205 | crop_border (int): Crop border for each end of height and weight. 206 | 207 | Returns: 208 | list[ndarray]: Cropped images. 209 | """ 210 | if crop_border == 0: 211 | return imgs 212 | else: 213 | if isinstance(imgs, list): 214 | return [ 215 | v[crop_border:-crop_border, crop_border:-crop_border, ...] 216 | for v in imgs 217 | ] 218 | else: 219 | return imgs[crop_border:-crop_border, crop_border:-crop_border, 220 | ...] 221 | -------------------------------------------------------------------------------- /basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lmdb 3 | import sys 4 | from multiprocessing import Pool 5 | from os import path as osp 6 | from tqdm import tqdm 7 | 8 | 9 | def make_lmdb_from_imgs(data_path, 10 | lmdb_path, 11 | img_path_list, 12 | keys, 13 | batch=5000, 14 | compress_level=1, 15 | multiprocessing_read=False, 16 | n_thread=40, 17 | map_size=None): 18 | """Make lmdb from images. 19 | 20 | Contents of lmdb. The file structure is: 21 | example.lmdb 22 | ├── data.mdb 23 | ├── lock.mdb 24 | ├── meta_info.txt 25 | 26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 27 | https://lmdb.readthedocs.io/en/release/ for more details. 28 | 29 | The meta_info.txt is a specified txt file to record the meta information 30 | of our datasets. It will be automatically created when preparing 31 | datasets by our provided dataset tools. 32 | Each line in the txt file records 1)image name (with extension), 33 | 2)image shape, and 3)compression level, separated by a white space. 34 | 35 | For example, the meta information could be: 36 | `000_00000000.png (720,1280,3) 1`, which means: 37 | 1) image name (with extension): 000_00000000.png; 38 | 2) image shape: (720,1280,3); 39 | 3) compression level: 1 40 | 41 | We use the image name without extension as the lmdb key. 
42 | 43 | If `multiprocessing_read` is True, it will read all the images to memory 44 | using multiprocessing. Thus, your server needs to have enough memory. 45 | 46 | Args: 47 | data_path (str): Data path for reading images. 48 | lmdb_path (str): Lmdb save path. 49 | img_path_list (str): Image path list. 50 | keys (str): Used for lmdb keys. 51 | batch (int): After processing batch images, lmdb commits. 52 | Default: 5000. 53 | compress_level (int): Compress level when encoding images. Default: 1. 54 | multiprocessing_read (bool): Whether use multiprocessing to read all 55 | the images to memory. Default: False. 56 | n_thread (int): For multiprocessing. 57 | map_size (int | None): Map size for lmdb env. If None, use the 58 | estimated size from images. Default: None 59 | """ 60 | 61 | assert len(img_path_list) == len(keys), ( 62 | 'img_path_list and keys should have the same length, ' 63 | f'but got {len(img_path_list)} and {len(keys)}') 64 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...') 65 | print(f'Totoal images: {len(img_path_list)}') 66 | if not lmdb_path.endswith('.lmdb'): 67 | raise ValueError("lmdb_path must end with '.lmdb'.") 68 | if osp.exists(lmdb_path): 69 | print(f'Folder {lmdb_path} already exists. 
Exit.') 70 | sys.exit(1) 71 | 72 | if multiprocessing_read: 73 | # read all the images to memory (multiprocessing) 74 | dataset = {} # use dict to keep the order for multiprocessing 75 | shapes = {} 76 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 77 | pbar = tqdm(total=len(img_path_list), unit='image') 78 | 79 | def callback(arg): 80 | """get the image data and update pbar.""" 81 | key, dataset[key], shapes[key] = arg 82 | pbar.update(1) 83 | pbar.set_description(f'Read {key}') 84 | 85 | pool = Pool(n_thread) 86 | for path, key in zip(img_path_list, keys): 87 | pool.apply_async( 88 | read_img_worker, 89 | args=(osp.join(data_path, path), key, compress_level), 90 | callback=callback) 91 | pool.close() 92 | pool.join() 93 | pbar.close() 94 | print(f'Finish reading {len(img_path_list)} images.') 95 | 96 | # create lmdb environment 97 | if map_size is None: 98 | # obtain data size for one image 99 | img = cv2.imread( 100 | osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 101 | _, img_byte = cv2.imencode( 102 | '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 103 | data_size_per_img = img_byte.nbytes 104 | print('Data size per image is: ', data_size_per_img) 105 | data_size = data_size_per_img * len(img_path_list) 106 | map_size = data_size * 10 107 | 108 | env = lmdb.open(lmdb_path, map_size=map_size) 109 | 110 | # write data to lmdb 111 | pbar = tqdm(total=len(img_path_list), unit='chunk') 112 | txn = env.begin(write=True) 113 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 114 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 115 | pbar.update(1) 116 | pbar.set_description(f'Write {key}') 117 | key_byte = key.encode('ascii') 118 | if multiprocessing_read: 119 | img_byte = dataset[key] 120 | h, w, c = shapes[key] 121 | else: 122 | _, img_byte, img_shape = read_img_worker( 123 | osp.join(data_path, path), key, compress_level) 124 | h, w, c = img_shape 125 | 126 | txn.put(key_byte, img_byte) 127 | 
# write meta information 128 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 129 | if idx % batch == 0: 130 | txn.commit() 131 | txn = env.begin(write=True) 132 | pbar.close() 133 | txn.commit() 134 | env.close() 135 | txt_file.close() 136 | print('\nFinish writing lmdb.') 137 | 138 | 139 | def read_img_worker(path, key, compress_level): 140 | """Read image worker. 141 | 142 | Args: 143 | path (str): Image path. 144 | key (str): Image key. 145 | compress_level (int): Compress level when encoding images. 146 | 147 | Returns: 148 | str: Image key. 149 | byte: Image byte. 150 | tuple[int]: Image shape. 151 | """ 152 | 153 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 154 | if img.ndim == 2: 155 | h, w = img.shape 156 | c = 1 157 | else: 158 | h, w, c = img.shape 159 | _, img_byte = cv2.imencode('.png', img, 160 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 161 | return (key, img_byte, (h, w, c)) 162 | 163 | 164 | class LmdbMaker(): 165 | """LMDB Maker. 166 | 167 | Args: 168 | lmdb_path (str): Lmdb save path. 169 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. 170 | batch (int): After processing batch images, lmdb commits. 171 | Default: 5000. 172 | compress_level (int): Compress level when encoding images. Default: 1. 173 | """ 174 | 175 | def __init__(self, 176 | lmdb_path, 177 | map_size=1024**4, 178 | batch=5000, 179 | compress_level=1): 180 | if not lmdb_path.endswith('.lmdb'): 181 | raise ValueError("lmdb_path must end with '.lmdb'.") 182 | if osp.exists(lmdb_path): 183 | print(f'Folder {lmdb_path} already exists. 
Exit.') 184 | sys.exit(1) 185 | 186 | self.lmdb_path = lmdb_path 187 | self.batch = batch 188 | self.compress_level = compress_level 189 | self.env = lmdb.open(lmdb_path, map_size=map_size) 190 | self.txn = self.env.begin(write=True) 191 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 192 | self.counter = 0 193 | 194 | def put(self, img_byte, key, img_shape): 195 | self.counter += 1 196 | key_byte = key.encode('ascii') 197 | self.txn.put(key_byte, img_byte) 198 | # write meta information 199 | h, w, c = img_shape 200 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 201 | if self.counter % self.batch == 0: 202 | self.txn.commit() 203 | self.txn = self.env.begin(write=True) 204 | 205 | def close(self): 206 | self.txn.commit() 207 | self.env.close() 208 | self.txt_file.close() 209 | -------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class MessageLogger(): 11 | """Message logger for printing. 12 | 13 | Args: 14 | opt (dict): Config. It contains the following keys: 15 | name (str): Exp name. 16 | logger (dict): Contains 'print_freq' (str) for logger interval. 17 | train (dict): Contains 'total_iter' (int) for total iters. 18 | use_tb_logger (bool): Use tensorboard logger. 19 | start_iter (int): Start iter. Default: 1. 20 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 
21 | """ 22 | 23 | def __init__(self, opt, start_iter=1, tb_logger=None): 24 | self.exp_name = opt['name'] 25 | self.interval = opt['logger']['print_freq'] 26 | self.start_iter = start_iter 27 | self.max_iters = opt['train']['total_iter'] 28 | self.use_tb_logger = opt['logger']['use_tb_logger'] 29 | self.tb_logger = tb_logger 30 | self.start_time = time.time() 31 | self.logger = get_root_logger() 32 | 33 | @master_only 34 | def __call__(self, log_vars): 35 | """Format logging message. 36 | 37 | Args: 38 | log_vars (dict): It contains the following keys: 39 | epoch (int): Epoch number. 40 | iter (int): Current iter. 41 | lrs (list): List for learning rates. 42 | 43 | time (float): Iter time. 44 | data_time (float): Data time for each iter. 45 | """ 46 | # epoch, iter, learning rates 47 | epoch = log_vars.pop('epoch') 48 | current_iter = log_vars.pop('iter') 49 | lrs = log_vars.pop('lrs') 50 | 51 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') 52 | for v in lrs: 53 | message += f'{v:.3e},' 54 | message += ')] ' 55 | 56 | # time and estimated time 57 | if 'time' in log_vars.keys(): 58 | iter_time = log_vars.pop('time') 59 | data_time = log_vars.pop('data_time') 60 | 61 | total_time = time.time() - self.start_time 62 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 63 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 64 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 65 | message += f'[eta: {eta_str}, ' 66 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 67 | 68 | # other items, especially losses 69 | for k, v in log_vars.items(): 70 | message += f'{k}: {v:.4e} ' 71 | # tensorboard logger 72 | if self.use_tb_logger and 'debug' not in self.exp_name: 73 | if k.startswith('l_'): 74 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 75 | else: 76 | self.tb_logger.add_scalar(k, v, current_iter) 77 | self.logger.info(message) 78 | 79 | 80 | @master_only 81 | def 
init_tb_logger(log_dir): 82 | from torch.utils.tensorboard import SummaryWriter 83 | tb_logger = SummaryWriter(log_dir=log_dir) 84 | return tb_logger 85 | 86 | 87 | @master_only 88 | def init_wandb_logger(opt): 89 | """We now only use wandb to sync tensorboard log.""" 90 | import wandb 91 | logger = logging.getLogger('basicsr') 92 | 93 | project = opt['logger']['wandb']['project'] 94 | resume_id = opt['logger']['wandb'].get('resume_id') 95 | if resume_id: 96 | wandb_id = resume_id 97 | resume = 'allow' 98 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 99 | else: 100 | wandb_id = wandb.util.generate_id() 101 | resume = 'never' 102 | 103 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 104 | 105 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 106 | 107 | 108 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 109 | """Get the root logger. 110 | 111 | The logger will be initialized if it has not been initialized. By default a 112 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 113 | also be added. 114 | 115 | Args: 116 | logger_name (str): root logger name. Default: 'basicsr'. 117 | log_file (str | None): The log filename. If specified, a FileHandler 118 | will be added to the root logger. 119 | log_level (int): The root logger level. Note that only the process of 120 | rank 0 is affected, while other processes will set the level to 121 | "Error" and be silent most of the time. 122 | 123 | Returns: 124 | logging.Logger: The root logger. 
125 | """ 126 | logger = logging.getLogger(logger_name) 127 | # if the logger has been initialized, just return it 128 | if logger_name in initialized_logger: 129 | return logger 130 | 131 | format_str = '%(asctime)s %(levelname)s: %(message)s' 132 | stream_handler = logging.StreamHandler() 133 | stream_handler.setFormatter(logging.Formatter(format_str)) 134 | logger.addHandler(stream_handler) 135 | logger.propagate = False 136 | rank, _ = get_dist_info() 137 | if rank != 0: 138 | logger.setLevel('ERROR') 139 | elif log_file is not None: 140 | logger.setLevel(log_level) 141 | # add file handler 142 | file_handler = logging.FileHandler(log_file, 'w') 143 | file_handler.setFormatter(logging.Formatter(format_str)) 144 | file_handler.setLevel(log_level) 145 | logger.addHandler(file_handler) 146 | initialized_logger[logger_name] = True 147 | return logger 148 | 149 | 150 | def get_env_info(): 151 | """Get environment information. 152 | 153 | Currently, only log the software version. 154 | """ 155 | import torch 156 | import torchvision 157 | 158 | from basicsr.version import __version__ 159 | msg = r""" 160 | ____ _ _____ ____ 161 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 162 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 163 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 164 | /_____/ \__,_//____//_/ \___//____//_/ |_| 165 | ______ __ __ __ __ 166 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 167 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 168 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 169 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 170 | """ 171 | msg += ('\nVersion Information: ' 172 | f'\n\tBasicSR: {__version__}' 173 | f'\n\tPyTorch: {torch.__version__}' 174 | f'\n\tTorchVision: {torchvision.__version__}') 175 | return msg -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | from .logger import get_root_logger 10 | 11 | 12 | def set_random_seed(seed): 13 | """Set random seeds.""" 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | 21 | def get_time_str(): 22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 23 | 24 | 25 | def mkdir_and_rename(path): 26 | """mkdirs. If path exists, rename it with timestamp and create a new one. 27 | 28 | Args: 29 | path (str): Folder path. 30 | """ 31 | if osp.exists(path): 32 | new_name = path + '_archived_' + get_time_str() 33 | print(f'Path already exists. Rename it to {new_name}', flush=True) 34 | os.rename(path, new_name) 35 | os.makedirs(path, exist_ok=True) 36 | 37 | 38 | @master_only 39 | def make_exp_dirs(opt): 40 | """Make dirs for experiments.""" 41 | path_opt = opt['path'].copy() 42 | if opt['is_train']: 43 | mkdir_and_rename(path_opt.pop('experiments_root')) 44 | else: 45 | mkdir_and_rename(path_opt.pop('results_root')) 46 | for key, path in path_opt.items(): 47 | if ('strict_load' not in key) and ('pretrain_network' 48 | not in key) and ('resume' 49 | not in key): 50 | os.makedirs(path, exist_ok=True) 51 | 52 | 53 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 54 | """Scan a directory to find the interested files. 55 | 56 | Args: 57 | dir_path (str): Path of the directory. 58 | suffix (str | tuple(str), optional): File suffix that we are 59 | interested in. Default: None. 60 | recursive (bool, optional): If set to True, recursively scan the 61 | directory. Default: False. 62 | full_path (bool, optional): If set to True, include the dir_path. 63 | Default: False. 64 | 65 | Returns: 66 | A generator for all the interested files with relative pathes. 
67 | """ 68 | 69 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 70 | raise TypeError('"suffix" must be a string or tuple of strings') 71 | 72 | root = dir_path 73 | 74 | def _scandir(dir_path, suffix, recursive): 75 | for entry in os.scandir(dir_path): 76 | if not entry.name.startswith('.') and entry.is_file(): 77 | if full_path: 78 | return_path = entry.path 79 | else: 80 | return_path = osp.relpath(entry.path, root) 81 | 82 | if suffix is None: 83 | yield return_path 84 | elif return_path.endswith(suffix): 85 | yield return_path 86 | else: 87 | if recursive: 88 | yield from _scandir( 89 | entry.path, suffix=suffix, recursive=recursive) 90 | else: 91 | continue 92 | 93 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 94 | 95 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False): 96 | """Scan a directory to find the interested files. 97 | 98 | Args: 99 | dir_path (str): Path of the directory. 100 | keywords (str | tuple(str), optional): File keywords that we are 101 | interested in. Default: None. 102 | recursive (bool, optional): If set to True, recursively scan the 103 | directory. Default: False. 104 | full_path (bool, optional): If set to True, include the dir_path. 105 | Default: False. 106 | 107 | Returns: 108 | A generator for all the interested files with relative pathes. 
109 | """ 110 | 111 | if (keywords is not None) and not isinstance(keywords, (str, tuple)): 112 | raise TypeError('"keywords" must be a string or tuple of strings') 113 | 114 | root = dir_path 115 | 116 | def _scandir(dir_path, keywords, recursive): 117 | for entry in os.scandir(dir_path): 118 | if not entry.name.startswith('.') and entry.is_file(): 119 | if full_path: 120 | return_path = entry.path 121 | else: 122 | return_path = osp.relpath(entry.path, root) 123 | 124 | if keywords is None: 125 | yield return_path 126 | elif return_path.find(keywords) > 0: 127 | yield return_path 128 | else: 129 | if recursive: 130 | yield from _scandir( 131 | entry.path, keywords=keywords, recursive=recursive) 132 | else: 133 | continue 134 | 135 | return _scandir(dir_path, keywords=keywords, recursive=recursive) 136 | 137 | def check_resume(opt, resume_iter): 138 | """Check resume states and pretrain_network paths. 139 | 140 | Args: 141 | opt (dict): Options. 142 | resume_iter (int): Resume iteration. 143 | """ 144 | logger = get_root_logger() 145 | if opt['path']['resume_state']: 146 | # get all the networks 147 | networks = [key for key in opt.keys() if key.startswith('network_')] 148 | flag_pretrain = False 149 | for network in networks: 150 | if opt['path'].get(f'pretrain_{network}') is not None: 151 | flag_pretrain = True 152 | if flag_pretrain: 153 | logger.warning( 154 | 'pretrain_network path will be ignored during resuming.') 155 | # set pretrained model paths 156 | for network in networks: 157 | name = f'pretrain_{network}' 158 | basename = network.replace('network_', '') 159 | if opt['path'].get('ignore_resume_networks') is None or ( 160 | basename not in opt['path']['ignore_resume_networks']): 161 | opt['path'][name] = osp.join( 162 | opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 163 | logger.info(f"Set {name} to {opt['path'][name]}") 164 | 165 | 166 | def sizeof_fmt(size, suffix='B'): 167 | """Get human readable file size. 
168 | 169 | Args: 170 | size (int): File size. 171 | suffix (str): Suffix. Default: 'B'. 172 | 173 | Return: 174 | str: Formated file siz. 175 | """ 176 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 177 | if abs(size) < 1024.0: 178 | return f'{size:3.1f} {unit}{suffix}' 179 | size /= 1024.0 180 | return f'{size:3.1f} Y{suffix}' 181 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from collections import OrderedDict 3 | from os import path as osp 4 | 5 | 6 | def ordered_yaml(): 7 | """Support OrderedDict for yaml. 8 | 9 | Returns: 10 | yaml Loader and Dumper. 11 | """ 12 | try: 13 | from yaml import CDumper as Dumper 14 | from yaml import CLoader as Loader 15 | except ImportError: 16 | from yaml import Dumper, Loader 17 | 18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 19 | 20 | def dict_representer(dumper, data): 21 | return dumper.represent_dict(data.items()) 22 | 23 | def dict_constructor(loader, node): 24 | return OrderedDict(loader.construct_pairs(node)) 25 | 26 | Dumper.add_representer(OrderedDict, dict_representer) 27 | Loader.add_constructor(_mapping_tag, dict_constructor) 28 | return Loader, Dumper 29 | 30 | 31 | def parse(opt_path, is_train=True): 32 | """Parse option file. 33 | 34 | Args: 35 | opt_path (str): Option file path. 36 | is_train (str): Indicate whether in training or not. Default: True. 37 | 38 | Returns: 39 | (dict): Options. 
40 | """ 41 | with open(opt_path, mode='r') as f: 42 | Loader, _ = ordered_yaml() 43 | opt = yaml.load(f, Loader=Loader) 44 | 45 | opt['is_train'] = is_train 46 | 47 | # datasets 48 | for phase, dataset in opt['datasets'].items(): 49 | # for several datasets, e.g., test_1, test_2 50 | phase = phase.split('_')[0] 51 | dataset['phase'] = phase 52 | if 'scale' in opt: 53 | dataset['scale'] = opt['scale'] 54 | if dataset.get('dataroot_gt') is not None: 55 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 56 | if dataset.get('dataroot_lq') is not None: 57 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 58 | 59 | # paths 60 | for key, val in opt['path'].items(): 61 | if (val is not None) and ('resume_state' in key 62 | or 'pretrain_network' in key): 63 | opt['path'][key] = osp.expanduser(val) 64 | opt['path']['root'] = osp.abspath( 65 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 66 | if is_train: 67 | experiments_root = osp.join(opt['path']['root'], 'experiments', 68 | opt['name']) 69 | opt['path']['experiments_root'] = experiments_root 70 | opt['path']['models'] = osp.join(experiments_root, 'models') 71 | opt['path']['training_states'] = osp.join(experiments_root, 72 | 'training_states') 73 | opt['path']['log'] = experiments_root 74 | opt['path']['visualization'] = osp.join(experiments_root, 75 | 'visualization') 76 | 77 | # change some options for debug mode 78 | if 'debug' in opt['name']: 79 | if 'val' in opt: 80 | opt['val']['val_freq'] = 8 81 | opt['logger']['print_freq'] = 1 82 | opt['logger']['save_checkpoint_freq'] = 8 83 | else: # test 84 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 85 | opt['path']['results_root'] = results_root 86 | opt['path']['log'] = results_root 87 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 88 | 89 | return opt 90 | 91 | 92 | def dict2str(opt, indent_level=1): 93 | """dict to string for printing options. 
94 | 95 | Args: 96 | opt (dict): Option dict. 97 | indent_level (int): Indent level. Default: 1. 98 | 99 | Return: 100 | (str): Option string for printing. 101 | """ 102 | msg = '\n' 103 | for k, v in opt.items(): 104 | if isinstance(v, dict): 105 | msg += ' ' * (indent_level * 2) + k + ':[' 106 | msg += dict2str(v, indent_level + 1) 107 | msg += ' ' * (indent_level * 2) + ']\n' 108 | else: 109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 110 | return msg 111 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Tue Jul 1 18:25:19 2025 3 | __version__ = '1.2.0+733ceb2' 4 | short_version = '1.2.0' 5 | version_info = (1, 2, 0) 6 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: HINT_n 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 4 | - https://mirrors.ustc.edu.cn/anaconda/cloud/menpo/ 5 | - https://mirrors.ustc.edu.cn/anaconda/cloud/bioconda/ 6 | - https://mirrors.ustc.edu.cn/anaconda/cloud/msys2/ 7 | - https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/ 8 | - https://mirrors.ustc.edu.cn/anaconda/pkgs/free/ 9 | - https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ 10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 11 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 12 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 13 | - https://repo.continuum.io/pkgs/main/win-64/ 14 | - https://repo.continuum.io/pkgs/free/win-64/ 15 | - defaults 16 | dependencies: 17 | - _libgcc_mutex=0.1=conda_forge 18 | - _openmp_mutex=4.5=2_gnu 19 | - bzip2=1.0.8=hd590300_5 20 | - ca-certificates=2023.11.17=hbcca054_0 21 | - ld_impl_linux-64=2.40=h41732ed_0 22 | - libffi=3.4.2=h7f98852_5 
23 | - libgcc-ng=13.2.0=h807b86a_3 24 | - libgomp=13.2.0=h807b86a_3 25 | - libnsl=2.0.1=hd590300_0 26 | - libsqlite=3.44.2=h2797004_0 27 | - libuuid=2.38.1=h0b41bf4_0 28 | - libxcrypt=4.4.36=hd590300_1 29 | - libzlib=1.2.13=hd590300_5 30 | - ncurses=6.4=h59595ed_2 31 | - openssl=3.2.0=hd590300_1 32 | - pip=23.3.2=pyhd8ed1ab_0 33 | - python=3.8.18=hd12c33a_1_cpython 34 | - readline=8.2=h8228510_1 35 | - setuptools=68.2.2=pyhd8ed1ab_0 36 | - tk=8.6.13=noxft_h4845f30_101 37 | - wheel=0.42.0=pyhd8ed1ab_0 38 | - xz=5.2.6=h166bdaf_0 39 | - pip: 40 | - absl-py==2.2.2 41 | - cachetools==5.5.2 42 | - certifi==2023.11.17 43 | - charset-normalizer==3.3.2 44 | - einops==0.7.0 45 | - filelock==3.13.1 46 | - fsspec==2023.12.2 47 | - google-auth==2.40.1 48 | - google-auth-oauthlib==1.0.0 49 | - grpcio==1.70.0 50 | - huggingface-hub==0.20.1 51 | - idna==3.6 52 | - imageio==2.35.1 53 | - importlib-metadata==8.5.0 54 | - joblib==1.4.2 55 | - lazy-loader==0.4 56 | - lmdb==1.6.2 57 | - markdown==3.7 58 | - markupsafe==2.1.5 59 | - natsort==8.4.0 60 | - networkx==3.1 61 | - numpy==1.24.4 62 | - oauthlib==3.2.2 63 | - opencv-python==4.8.1.78 64 | - packaging==23.2 65 | - pillow==10.1.0 66 | - protobuf==5.29.4 67 | - pyasn1==0.6.1 68 | - pyasn1-modules==0.4.2 69 | - pywavelets==1.4.1 70 | - pyyaml==6.0.1 71 | - requests==2.31.0 72 | - requests-oauthlib==2.0.0 73 | - rsa==4.9.1 74 | - safetensors==0.4.1 75 | - scikit-image==0.21.0 76 | - scikit-learn==1.3.2 77 | - scipy==1.10.1 78 | - six==1.17.0 79 | - tensorboard==2.14.0 80 | - tensorboard-data-server==0.7.2 81 | - threadpoolctl==3.5.0 82 | - tifffile==2023.7.10 83 | - timm==0.9.12 84 | - torch==1.12.0+cu113 85 | - torchaudio==0.12.0+cu113 86 | - torchvision==0.13.0+cu113 87 | - tqdm==4.66.1 88 | - typing-extensions==4.9.0 89 | - urllib3==2.1.0 90 | - werkzeug==3.0.6 91 | - zipp==3.20.2 92 | prefix: /home/ubuntu13/anaconda3/envs/HINT_n 93 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import sys 8 | import time 9 | import torch 10 | from torch.utils.cpp_extension import (BuildExtension, CppExtension, 11 | CUDAExtension) 12 | 13 | version_file = 'basicsr/version.py' 14 | 15 | 16 | def readme(): 17 | return '' 18 | # with open('README.md', encoding='utf-8') as f: 19 | # content = f.read() 20 | # return content 21 | 22 | 23 | def get_git_hash(): 24 | 25 | def _minimal_ext_cmd(cmd): 26 | # construct minimal environment 27 | env = {} 28 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 29 | v = os.environ.get(k) 30 | if v is not None: 31 | env[k] = v 32 | # LANGUAGE is used on win32 33 | env['LANGUAGE'] = 'C' 34 | env['LANG'] = 'C' 35 | env['LC_ALL'] = 'C' 36 | out = subprocess.Popen( 37 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 38 | return out 39 | 40 | try: 41 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 42 | sha = out.strip().decode('ascii') 43 | except OSError: 44 | sha = 'unknown' 45 | 46 | return sha 47 | 48 | 49 | def get_hash(): 50 | if os.path.exists('.git'): 51 | sha = get_git_hash()[:7] 52 | elif os.path.exists(version_file): 53 | try: 54 | from basicsr.version import __version__ 55 | sha = __version__.split('+')[-1] 56 | except ImportError: 57 | raise ImportError('Unable to get git version') 58 | else: 59 | sha = 'unknown' 60 | 61 | return sha 62 | 63 | 64 | def write_version_py(): 65 | content = """# GENERATED VERSION FILE 66 | # TIME: {} 67 | __version__ = '{}' 68 | short_version = '{}' 69 | version_info = ({}) 70 | """ 71 | sha = get_hash() 72 | with open('VERSION', 'r') as f: 73 | SHORT_VERSION = f.read().strip() 74 | VERSION_INFO = ', '.join( 75 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 76 | VERSION = SHORT_VERSION + '+' + sha 77 | 78 | version_file_str = content.format(time.asctime(), 
VERSION, SHORT_VERSION, 79 | VERSION_INFO) 80 | with open(version_file, 'w') as f: 81 | f.write(version_file_str) 82 | 83 | 84 | def get_version(): 85 | with open(version_file, 'r') as f: 86 | exec(compile(f.read(), version_file, 'exec')) 87 | return locals()['__version__'] 88 | 89 | 90 | def make_cuda_ext(name, module, sources, sources_cuda=None): 91 | if sources_cuda is None: 92 | sources_cuda = [] 93 | define_macros = [] 94 | extra_compile_args = {'cxx': []} 95 | 96 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 97 | define_macros += [('WITH_CUDA', None)] 98 | extension = CUDAExtension 99 | extra_compile_args['nvcc'] = [ 100 | '-D__CUDA_NO_HALF_OPERATORS__', 101 | '-D__CUDA_NO_HALF_CONVERSIONS__', 102 | '-D__CUDA_NO_HALF2_OPERATORS__', 103 | ] 104 | sources += sources_cuda 105 | else: 106 | print(f'Compiling {name} without CUDA') 107 | extension = CppExtension 108 | 109 | return extension( 110 | name=f'{module}.{name}', 111 | sources=[os.path.join(*module.split('.'), p) for p in sources], 112 | define_macros=define_macros, 113 | extra_compile_args=extra_compile_args) 114 | 115 | 116 | def get_requirements(filename='requirements.txt'): 117 | return [] 118 | here = os.path.dirname(os.path.realpath(__file__)) 119 | with open(os.path.join(here, filename), 'r') as f: 120 | requires = [line.replace('\n', '') for line in f.readlines()] 121 | return requires 122 | 123 | 124 | if __name__ == '__main__': 125 | if '--no_cuda_ext' in sys.argv: 126 | ext_modules = [] 127 | sys.argv.remove('--no_cuda_ext') 128 | else: 129 | ext_modules = [ 130 | make_cuda_ext( 131 | name='deform_conv_ext', 132 | module='basicsr.models.ops.dcn', 133 | sources=['src/deform_conv_ext.cpp'], 134 | sources_cuda=[ 135 | 'src/deform_conv_cuda.cpp', 136 | 'src/deform_conv_cuda_kernel.cu' 137 | ]), 138 | make_cuda_ext( 139 | name='fused_act_ext', 140 | module='basicsr.models.ops.fused_act', 141 | sources=['src/fused_bias_act.cpp'], 142 | 
sources_cuda=['src/fused_bias_act_kernel.cu']), 143 | make_cuda_ext( 144 | name='upfirdn2d_ext', 145 | module='basicsr.models.ops.upfirdn2d', 146 | sources=['src/upfirdn2d.cpp'], 147 | sources_cuda=['src/upfirdn2d_kernel.cu']), 148 | ] 149 | 150 | write_version_py() 151 | setup( 152 | name='basicsr', 153 | version=get_version(), 154 | description='Open Source Image and Video Super-Resolution Toolbox', 155 | long_description=readme(), 156 | author='Xintao Wang', 157 | author_email='xintao.wang@outlook.com', 158 | keywords='computer vision, restoration, super resolution', 159 | url='https://github.com/xinntao/BasicSR', 160 | packages=find_packages( 161 | exclude=('options', 'datasets', 'experiments', 'results', 162 | 'tb_logger', 'wandb')), 163 | classifiers=[ 164 | 'Development Status :: 4 - Beta', 165 | 'License :: OSI Approved :: Apache Software License', 166 | 'Operating System :: OS Independent', 167 | 'Programming Language :: Python :: 3', 168 | 'Programming Language :: Python :: 3.7', 169 | 'Programming Language :: Python :: 3.8', 170 | ], 171 | license='Apache License 2.0', 172 | setup_requires=['cython', 'numpy'], 173 | install_requires=get_requirements(), 174 | ext_modules=ext_modules, 175 | cmdclass={'build_ext': BuildExtension}, 176 | zip_safe=False) 177 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | 2 | ### Dehaze 3 | python test_SOTS_HINT.py 4 | python evaluate_SOTS.py 5 | 6 | ### Derain 7 | python test_rain100L.py 8 | 9 | ### Denoising 10 | python test_gaussian_color_denoising_HINT.py --model_type blind 11 | python evaluate_gaussian_color_denoising_HINT.py --model_type blind 12 | 13 | ### Desnowing 14 | python test_snow100k.py 15 | python evaluate_Snow100k.py 16 | 17 | ### Enhancement 18 | python test_from_dataset_LOLv2_Real.py 19 | python test_from_dataset_LOLv2_Syn.py 20 | 21 | 22 | 23 | 24 | 
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch distributed training of basicsr/train.py with 2 processes
# (--nproc_per_node=2) on this node.
# Usage: ./train.sh <path/to/option.yml>

CONFIG=$1
if [ -z "$CONFIG" ]; then
    echo "Usage: $0 <path/to/option.yml>" >&2
    exit 1
fi

# Quoted so option paths containing spaces are passed as one argument.
python -m torch.distributed.launch --nproc_per_node=2 --master_port=4321 basicsr/train.py -opt "$CONFIG" --launcher pytorch
--------------------------------------------------------------------------------