├── Dehaze ├── Options │ └── RealDehazing_FPro.yml ├── evaluate_SOTS.py ├── test_SOTS.py └── utils.py ├── Demoiring ├── Options │ └── RealDemoiring_FPro.yml ├── dataset_demoire.py ├── evaluate_demoire.py ├── test_moire.py └── utils.py ├── Deraining ├── Options │ ├── Deraining_FPro_spad.yml │ └── RealDeraindrop_FPro.yml ├── evaluate_PSNR_SSIM.m ├── evaluate_raindrop.py ├── test_AGAN.py ├── test_spad.py └── utils.py ├── INSTALL.md ├── Motion_Deblurring ├── Options │ └── Deblurring_FPro.yml ├── evaluate_gopro_hide.m ├── generate_patches_gopro.py ├── test_FPro.py └── utils.py ├── README.md ├── basicsr ├── .DS_Store ├── __pycache__ │ └── version.cpython-37.pyc ├── data │ ├── .DS_Store │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── data_sampler.cpython-37.pyc │ │ ├── data_util.cpython-37.pyc │ │ ├── ffhq_dataset.cpython-37.pyc │ │ ├── paired_image_dataset.cpython-37.pyc │ │ ├── prefetch_dataloader.cpython-37.pyc │ │ ├── reds_dataset.cpython-37.pyc │ │ ├── single_image_dataset.cpython-37.pyc │ │ ├── transforms.cpython-37.pyc │ │ ├── video_test_dataset.cpython-37.pyc │ │ └── vimeo90k_dataset.cpython-37.pyc │ ├── data_sampler.py │ ├── data_util.py │ ├── ffhq_dataset.py │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── reds_dataset.py │ ├── single_image_dataset.py │ ├── transforms.py │ ├── video_test_dataset.py │ └── vimeo90k_dataset.py ├── metrics │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── metric_util.cpython-37.pyc │ │ ├── niqe.cpython-37.pyc │ │ └── psnr_ssim.cpython-37.pyc │ ├── fid.py │ ├── metric_util.py │ ├── niqe.py │ ├── niqe_pris_params.npz │ └── psnr_ssim.py ├── models │ ├── .DS_Store │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── base_model.cpython-37.pyc │ │ ├── image_restoration_model.cpython-37.pyc │ │ └── lr_scheduler.cpython-37.pyc │ ├── archs │ │ ├── FPro_arch.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── arch_util.cpython-37.pyc │ │ │ ├── graph_layers.cpython-37.pyc │ │ │ └── local_arch.cpython-37.pyc │ │ └── arch_util.py │ ├── base_model.py │ ├── image_restoration_model.py │ ├── losses │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── loss_util.cpython-37.pyc │ │ │ └── losses.cpython-37.pyc │ │ ├── loss_util.py │ │ └── losses.py │ └── lr_scheduler.py ├── test.py ├── train.py ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── create_lmdb.cpython-37.pyc │ │ ├── dist_util.cpython-37.pyc │ │ ├── file_client.cpython-37.pyc │ │ ├── flow_util.cpython-37.pyc │ │ ├── img_util.cpython-37.pyc │ │ ├── lmdb_util.cpython-37.pyc │ │ ├── logger.cpython-37.pyc │ │ ├── matlab_functions.cpython-37.pyc │ │ ├── misc.cpython-37.pyc │ │ └── options.cpython-37.pyc │ ├── bundle_submissions.py │ ├── create_lmdb.py │ ├── dist_util.py │ ├── download_util.py │ ├── face_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ └── options.py └── version.py ├── setup.py ├── test.sh └── train.sh /Dehaze/Options/RealDehazing_FPro.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: Dehazing_FPro 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage_dehazeSOT 13 | dataroot_gt: 
/mnt/sda/zsh/dataset/haze 14 | dataroot_lq: /mnt/sda/zsh/dataset/haze 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ## ------- Training on single fixed-patch size 128x128--------- 27 | mini_batch_sizes: [2] 28 | iters: [300000] 29 | gt_size: 256 30 | gt_sizes: [256] 31 | ## ------------------------------------------------------------ 32 | 33 | dataset_enlarge_ratio: 1 34 | prefetch_mode: ~ 35 | 36 | val: 37 | name: ValSet 38 | type: Dataset_PairedImage_dehazeSOT 39 | dataroot_gt: /mnt/sda/zsh/dataset/haze 40 | dataroot_lq: /mnt/sda/zsh/dataset/haze 41 | gt_size: 256 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | 47 | network_g: 48 | type: FPro 49 | inp_channels: 3 50 | out_channels: 3 51 | # input_res: 128 52 | dim: 48 53 | # num_blocks: [4,6,6,8] 54 | num_blocks: [2,3,6] 55 | # num_refinement_blocks: 4 56 | num_refinement_blocks: 2 57 | # heads: [1,2,4,8] 58 | heads: [2,4,8] 59 | # ffn_expansion_factor: 2.66 60 | ffn_expansion_factor: 3 61 | bias: False 62 | LayerNorm_type: WithBias 63 | dual_pixel_task: False 64 | 65 | 66 | # path 67 | path: 68 | pretrain_network_g: ~ 69 | strict_load_g: true 70 | resume_state: ~ 71 | 72 | # training settings 73 | train: 74 | total_iter: 300000 75 | warmup_iter: -1 # no warm up 76 | use_grad_clip: true 77 | 78 | # Split 300k iterations into two cycles. 79 | # 1st cycle: fixed 3e-4 LR for 92k iters. 80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 81 | scheduler: 82 | type: CosineAnnealingRestartCyclicLR 83 | periods: [92000, 208000] 84 | restart_weights: [1,1] 85 | eta_mins: [0.0003,0.000001] 86 | 87 | mixing_augs: 88 | mixup: true 89 | mixup_beta: 1.2 90 | use_identity: true 91 | 92 | optim_g: 93 | type: AdamW 94 | lr: !!float 3e-4 95 | weight_decay: !!float 1e-4 96 | betas: [0.9, 0.999] 97 | 98 | # losses 99 | pixel_opt: 100 | type: L1Loss 101 | loss_weight: 1 102 | reduction: mean 103 | fft_loss_opt: 104 | type: FFTLoss 105 | loss_weight: 0.1 106 | reduction: mean 107 | 108 | # validation settings 109 | val: 110 | window_size: 8 111 | val_freq: !!float 4e3 112 | save_img: false 113 | rgb2bgr: true 114 | use_image: false 115 | max_minibatch: 8 116 | 117 | metrics: 118 | psnr: # metric name, can be arbitrary 119 | type: calculate_psnr 120 | crop_border: 0 121 | test_y_channel: false 122 | 123 | # logging settings 124 | logger: 125 | print_freq: 1000 126 | save_checkpoint_freq: !!float 4e3 127 | use_tb_logger: true 128 | wandb: 129 | project: ~ 130 | resume_id: ~ 131 | 132 | # dist training settings 133 | dist_params: 134 | backend: nccl 135 | port: 29500 136 | -------------------------------------------------------------------------------- /Dehaze/evaluate_SOTS.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import os 6 | import numpy as np 7 | from glob import glob 8 | from natsort import natsorted 9 | from skimage import io 10 | import cv2 11 | import argparse 12 | from skimage.metrics import structural_similarity 13 | from tqdm import tqdm 14 | import concurrent.futures 15 | import utils 16 | 17 | def proc(filename): 18 | tar,prd = filename 19 | prd_name = prd.split('/')[-1].split('_')[0]+'.png' 
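    # SOTS results are saved as '<id>_<suffix>.png'; drop the suffix to recover the matching '<id>.png' ground truth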
20 |     tar_name = '/mnt/sda/zsh/dataset/haze/promptIR/outdoor/gt/' + prd_name
21 |     # print('tar',tar)
22 |     # print('prd',prd)
23 |     tar_img = utils.load_img(tar_name)
24 |     prd_img = utils.load_img(prd)
25 | 
26 |     PSNR = utils.calculate_psnr(tar_img, prd_img)
27 |     SSIM = utils.calculate_ssim(tar_img, prd_img)
28 |     return PSNR, SSIM
29 | 
30 | parser = argparse.ArgumentParser(description='Dehazing using FPro')
31 | 
32 | args = parser.parse_args()
33 | 
34 | 
35 | datasets = ['outdoor']
36 | 
37 | for dataset in datasets:
38 | 
39 |     gt_path = os.path.join('/mnt/sda/zsh/dataset/haze/promptIR/outdoor/gt')
40 |     gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif')))
41 |     assert len(gt_list) != 0, "Target files not found"
42 | 
43 | 
44 |     file_path = os.path.join('results', 'FPro', dataset)
45 |     path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif')))
46 |     assert len(path_list) != 0, "Predicted files not found"
47 | 
48 |     psnr, ssim = [], []
49 |     img_files = [(i, j) for i, j in zip(gt_list, path_list)]
50 |     with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
51 |         for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
52 |             psnr.append(PSNR_SSIM[0])
53 |             ssim.append(PSNR_SSIM[1])
54 | 
55 |     avg_psnr = sum(psnr)/len(psnr)
56 |     avg_ssim = sum(ssim)/len(ssim)
57 | 
58 |     # print('For {:s} dataset PSNR: {:f}\n'.format(dataset, avg_psnr))
59 |     print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
60 | 
--------------------------------------------------------------------------------
/Dehaze/test_SOTS.py:
--------------------------------------------------------------------------------
1 | ## Seeing the Unseen: A Frequency Prompt Guided Transformer for Image Restoration
2 | 
3 | import numpy as np
4 | import os
5 | import argparse
6 | from tqdm import tqdm
7 | import math
8 | import torch.nn as nn
9 | import torch
10 | import torch.nn.functional as F
11 | import utils
12 | 
13 | from natsort import natsorted
14 | from glob import glob
15 | from basicsr.models.archs.FPro_arch import FPro
16 | from skimage import img_as_ubyte
17 | from pdb import set_trace as stx
18 | 
19 | parser = argparse.ArgumentParser(description='Image Dehazing using FPro')
20 | 
21 | parser.add_argument('--input_dir', default='/mnt/sda/zsh/dataset/haze/promptIR/', type=str, help='Directory of validation images')
22 | parser.add_argument('--result_dir', default='./results/FPro/', type=str, help='Directory for results')
23 | parser.add_argument('--weights', default='/mnt/sda/zsh/FPro/Dehaze/models/synDehaze.pth', type=str, help='Path to weights')
24 | 
25 | args = parser.parse_args()
26 | 
27 | def splitimage(imgtensor, crop_size=128, overlap_size=64):  # split an image into overlapping crops
28 |     _, C, H, W = imgtensor.shape
29 |     hstarts = [x for x in range(0, H, crop_size - overlap_size)]
30 |     while hstarts and hstarts[-1] + crop_size >= H:
31 |         hstarts.pop()
32 |     hstarts.append(H - crop_size)
33 |     wstarts = [x for x in range(0, W, crop_size - overlap_size)]
34 |     while wstarts and wstarts[-1] + crop_size >= W:
35 |         wstarts.pop()
36 |     wstarts.append(W - crop_size)
37 |     starts = []
38 |     split_data = []
39 |     for hs in hstarts:
40 |         for ws in wstarts:
41 |             cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
42 |             starts.append((hs, ws))
43 |             split_data.append(cimgdata)
44 |     return split_data, starts
45 | 
46 | def get_scoremap(H, W, C, B=1, is_mean=True):  # per-pixel blending weights used when merging crops
47 |     center_h = H / 2
48 |     center_w = W / 2
49 | 
50 |     score = torch.ones((B, C, H, W))
51 |     if 
not is_mean: 52 | for h in range(H): 53 | for w in range(W): 54 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6)) 55 | return score 56 | 57 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)): 58 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3] 59 | tot_score = torch.zeros((B, C, H, W)) 60 | merge_img = torch.zeros((B, C, H, W)) 61 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True) 62 | for simg, cstart in zip(split_data, starts): 63 | hs, ws = cstart 64 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg 65 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap 66 | merge_img = merge_img / tot_score 67 | return merge_img 68 | 69 | ####### Load yaml ####### 70 | yaml_file = 'Options/RealDehazing_FPro.yml' 71 | import yaml 72 | 73 | try: 74 | from yaml import CLoader as Loader 75 | except ImportError: 76 | from yaml import Loader 77 | 78 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader) 79 | 80 | s = x['network_g'].pop('type') 81 | ########################## 82 | 83 | model_restoration = FPro(**x['network_g']) 84 | 85 | checkpoint = torch.load(args.weights) 86 | model_restoration.load_state_dict(checkpoint['params']) 87 | print("===>Testing using weights: ",args.weights) 88 | model_restoration.cuda() 89 | model_restoration = nn.DataParallel(model_restoration) 90 | model_restoration.eval() 91 | 92 | 93 | factor = 8 94 | datasets = ['outdoor'] 95 | 96 | for dataset in datasets: 97 | result_dir = os.path.join(args.result_dir, dataset) 98 | os.makedirs(result_dir, exist_ok=True) 99 | 100 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain') 101 | inp_dir = os.path.join(args.input_dir, dataset, 'hazy/') 102 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg'))) 103 | with torch.no_grad(): 104 | for file_ in tqdm(files): 105 | torch.cuda.ipc_collect() 106 | torch.cuda.empty_cache() 107 | 108 | img = np.float32(utils.load_img(file_))/255. 
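            # load_img returns HWC RGB uint8; scale to float32 in [0,1] before converting to a CHW tensor with a batch dim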
109 | img = torch.from_numpy(img).permute(2,0,1) 110 | input_ = img.unsqueeze(0).cuda() 111 | 112 | # Padding in case images are not multiples of 8 113 | h,w = input_.shape[2], input_.shape[3] 114 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor 115 | padh = H-h if h%factor!=0 else 0 116 | padw = W-w if w%factor!=0 else 0 117 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 118 | 119 | B, C, H, W = input_.shape 120 | corp_size_arg = 256 121 | overlap_size_arg = 158 122 | # corp_size_arg = 512 123 | # overlap_size_arg = 204 124 | split_data, starts = splitimage(input_, crop_size=corp_size_arg, overlap_size=overlap_size_arg) 125 | for i, data in enumerate(split_data): 126 | split_data[i] = model_restoration(data).cpu() 127 | restored = mergeimage(split_data, starts, crop_size=corp_size_arg, resolution=(B, C, H, W)) 128 | # rgb_restored = torch.clamp(restored, 0, 1).permute(0, 2, 3, 1).numpy() 129 | 130 | # restored = rgb_restored 131 | # restored = model_restoration(input_) 132 | 133 | # Unpad images to original dimensions 134 | restored = restored[:,:,:h,:w] 135 | 136 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 137 | 138 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored)) 139 | -------------------------------------------------------------------------------- /Dehaze/utils.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import cv2 8 | import math 9 | 10 | def calculate_psnr(img1, img2, border=0): 11 | # img1 and img2 have range [0, 255] 12 | #img1 = img1.squeeze() 13 | #img2 = img2.squeeze() 14 | if not img1.shape == img2.shape: 15 | raise ValueError('Input images must have the same dimensions.') 16 | h, w = img1.shape[:2] 17 | img1 = img1[border:h-border, border:w-border] 18 | img2 = img2[border:h-border, border:w-border] 19 | 20 | img1 = img1.astype(np.float64) 21 | img2 = img2.astype(np.float64) 22 | mse = np.mean((img1 - img2)**2) 23 | if mse == 0: 24 | return float('inf') 25 | return 20 * math.log10(255.0 / math.sqrt(mse)) 26 | 27 | 28 | # -------------------------------------------- 29 | # SSIM 30 | # -------------------------------------------- 31 | def calculate_ssim(img1, img2, border=0): 32 | '''calculate SSIM 33 | the same outputs as MATLAB's 34 | img1, img2: [0, 255] 35 | ''' 36 | #img1 = img1.squeeze() 37 | #img2 = img2.squeeze() 38 | if not img1.shape == img2.shape: 39 | raise ValueError('Input images must have the same dimensions.') 40 | h, w = img1.shape[:2] 41 | img1 = img1[border:h-border, border:w-border] 42 | img2 = img2[border:h-border, border:w-border] 43 | 44 | if img1.ndim == 2: 45 | return ssim(img1, img2) 46 | elif img1.ndim == 3: 47 | if img1.shape[2] == 3: 48 | ssims = [] 49 | for i in range(3): 50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i])) 51 | return np.array(ssims).mean() 52 | elif img1.shape[2] == 1: 53 | return ssim(np.squeeze(img1), np.squeeze(img2)) 54 | else: 55 | raise ValueError('Wrong input image dimensions.') 56 | 57 | 58 | def ssim(img1, img2): 59 | C1 = (0.01 * 255)**2 60 | C2 = (0.03 * 255)**2 61 | 62 | img1 = img1.astype(np.float64) 63 | img2 = img2.astype(np.float64) 64 | kernel = cv2.getGaussianKernel(11, 
1.5)
65 |     window = np.outer(kernel, kernel.transpose())
66 | 
67 |     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
68 |     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
69 |     mu1_sq = mu1**2
70 |     mu2_sq = mu2**2
71 |     mu1_mu2 = mu1 * mu2
72 |     sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 |     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 |     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 | 
76 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 |                                                             (sigma1_sq + sigma2_sq + C2))
78 |     return ssim_map.mean()
79 | 
80 | def load_img(filepath):
81 |     return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 | 
83 | def save_img(filepath, img):
84 |     cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 | 
86 | def load_gray_img(filepath):
87 |     return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 | 
89 | def save_gray_img(filepath, img):
90 |     cv2.imwrite(filepath, img)
91 | 
--------------------------------------------------------------------------------
/Demoiring/Options/RealDemoiring_FPro.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: RealDemoiring_FPro
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8  # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 | 
8 | # dataset and data loader settings
9 | datasets:
10 |   train:
11 |     name: TrainSet
12 |     type: Dataset_PairedImage_denseHaze
13 |     dataroot_gt: /home/ubuntu/zsh/datasets/TIP18/process/train/thin_target
14 |     dataroot_lq: /home/ubuntu/zsh/datasets/TIP18/process/train/thin_source
15 |     geometric_augs: False
16 | 
17 |     filename_tmpl: '{}'
18 |     io_backend:
19 |       type: disk
20 | 
21 |     # data loader
22 |     use_shuffle: true
23 |     num_worker_per_gpu: 8
24 |     batch_size_per_gpu: 8
25 | 
26 |     ## ------- Training on single fixed-patch size 256x256 ---------
27 |     mini_batch_sizes: [2]
28 |     iters: [300000]
29 |     gt_size: 256
30 |     gt_sizes: [256]
31 |     ## ------------------------------------------------------------
32 | 
33 |     dataset_enlarge_ratio: 1
34 |     prefetch_mode: ~
35 | 
36 |   val:
37 |     name: ValSet
38 |     type: Dataset_PairedImage_denseHaze
39 |     dataroot_gt: /home/ubuntu/zsh/datasets/TIP18/process/val/thin_target
40 |     dataroot_lq: /home/ubuntu/zsh/datasets/TIP18/process/val/thin_source
41 |     gt_size: 256
42 |     io_backend:
43 |       type: disk
44 | 
45 | # network structures
46 | 
47 | network_g:
48 |   type: FPro
49 |   inp_channels: 3
50 |   out_channels: 3
51 |   # input_res: 128
52 |   dim: 48
53 |   # num_blocks: [4,6,6,8]
54 |   num_blocks: [2,3,6]
55 |   # num_refinement_blocks: 4
56 |   num_refinement_blocks: 2
57 |   # heads: [1,2,4,8]
58 |   heads: [2,4,8]
59 |   # ffn_expansion_factor: 2.66
60 |   ffn_expansion_factor: 3
61 |   bias: False
62 |   LayerNorm_type: WithBias
63 |   dual_pixel_task: False
64 | 
65 | 
66 | # path
67 | path:
68 |   pretrain_network_g: ~
69 |   strict_load_g: true
70 |   resume_state: ~
71 | 
72 | # training settings
73 | train:
74 |   total_iter: 300000
75 |   warmup_iter: -1  # no warm up
76 |   use_grad_clip: true
77 | 
78 |   # Split 300k iterations into two cycles.
79 |   # 1st cycle: fixed 3e-4 LR for 92k iters.
80 |   # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
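  # Note: with the 1st cycle's eta_min equal to the initial LR (3e-4), its cosine curve is effectively a constant LR.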
81 | scheduler: 82 | type: CosineAnnealingRestartCyclicLR 83 | periods: [92000, 208000] 84 | restart_weights: [1,1] 85 | eta_mins: [0.0003,0.000001] 86 | 87 | mixing_augs: 88 | mixup: true 89 | mixup_beta: 1.2 90 | use_identity: true 91 | 92 | optim_g: 93 | type: AdamW 94 | lr: !!float 3e-4 95 | weight_decay: !!float 1e-4 96 | betas: [0.9, 0.999] 97 | 98 | # losses 99 | pixel_opt: 100 | type: L1Loss 101 | loss_weight: 1 102 | reduction: mean 103 | fft_loss_opt: 104 | type: FFTLoss 105 | loss_weight: 0.1 106 | reduction: mean 107 | 108 | # validation settings 109 | val: 110 | window_size: 8 111 | val_freq: !!float 4e3 112 | save_img: false 113 | rgb2bgr: true 114 | use_image: false 115 | max_minibatch: 8 116 | 117 | metrics: 118 | psnr: # metric name, can be arbitrary 119 | type: calculate_psnr 120 | crop_border: 0 121 | test_y_channel: false 122 | 123 | # logging settings 124 | logger: 125 | print_freq: 1000 126 | save_checkpoint_freq: !!float 4e3 127 | use_tb_logger: true 128 | wandb: 129 | project: ~ 130 | resume_id: ~ 131 | 132 | # dist training settings 133 | dist_params: 134 | backend: nccl 135 | port: 29500 136 | -------------------------------------------------------------------------------- /Demoiring/evaluate_demoire.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import os 6 | import numpy as np 7 | from glob import glob 8 | from natsort import natsorted 9 | from skimage import io 10 | import cv2 11 | import argparse 12 | from skimage.metrics import structural_similarity 13 | from tqdm import tqdm 14 | import concurrent.futures 15 | import utils 16 | 17 | def proc(filename): 18 | tar,prd = filename 19 | tar_img = utils.load_img(tar) 20 | prd_img = utils.load_img(prd) 21 | 22 | PSNR = utils.calculate_psnr(tar_img, prd_img) 23 | SSIM = utils.calculate_ssim(tar_img, prd_img) 24 | return PSNR,SSIM 25 | 26 | parser = argparse.ArgumentParser(description='Demoireing using FPro') 27 | 28 | args = parser.parse_args() 29 | 30 | 31 | datasets = ['TIP18'] 32 | 33 | for dataset in datasets: 34 | #/home/ubuntu/zsh/datasets/TIP18/process/test_resize286_crop256/thin_target 35 | #/home/ubuntu/zsh/datasets/TIP18/process/test_256/thin_target 36 | gt_path = os.path.join('/mnt/sda/zsh/FPro/Demoiring/test_resize286_crop256/thin_target') 37 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif'))) 38 | assert len(gt_list) != 0, "Target files not found" 39 | 40 | 41 | file_path = os.path.join('results/', 'FPro_test/', dataset) 42 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif'))) 43 | assert len(path_list) != 0, "Predicted files not found" 44 | 45 | psnr, ssim = [], [] 46 | img_files =[(i, j) for i,j in zip(gt_list,path_list)] 47 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: 48 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)): 49 | psnr.append(PSNR_SSIM[0]) 50 | ssim.append(PSNR_SSIM[1]) 51 | 52 | avg_psnr = sum(psnr)/len(psnr) 53 | avg_ssim = sum(ssim)/len(ssim) 54 | 55 | print('For {:s} dataset PSNR: {:f}\n'.format(dataset, avg_psnr)) 56 | print('For {:s} dataset SSIM: {:f}\n'.format(dataset, avg_ssim)) 57 | # print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, 
avg_ssim))
58 | 
--------------------------------------------------------------------------------
/Demoiring/test_moire.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | import os
4 | import argparse
5 | from tqdm import tqdm
6 | import math
7 | import torch.nn as nn
8 | import torch
9 | import torch.nn.functional as F
10 | import utils
11 | 
12 | from natsort import natsorted
13 | from glob import glob
14 | from basicsr.models.archs.FPro_arch import FPro
15 | from skimage import img_as_ubyte
16 | from pdb import set_trace as stx
17 | 
18 | parser = argparse.ArgumentParser(description='Image Demoireing using FPro')
19 | #test_resize286_crop256 test_256
20 | parser.add_argument('--input_dir', default='/mnt/sda/zsh/FPro/Demoiring/test_resize286_crop256/thin_source', type=str, help='Directory of validation images')
21 | parser.add_argument('--result_dir', default='./results/FPro_test/', type=str, help='Directory for results')
22 | parser.add_argument('--weights', default='./models/demoire_noAug.pth', type=str, help='Path to weights')
23 | 
24 | args = parser.parse_args()
25 | 
26 | def splitimage(imgtensor, crop_size=128, overlap_size=64):  # split an image into overlapping crops
27 |     _, C, H, W = imgtensor.shape
28 |     hstarts = [x for x in range(0, H, crop_size - overlap_size)]
29 |     while hstarts and hstarts[-1] + crop_size >= H:
30 |         hstarts.pop()
31 |     hstarts.append(H - crop_size)
32 |     wstarts = [x for x in range(0, W, crop_size - overlap_size)]
33 |     while wstarts and wstarts[-1] + crop_size >= W:
34 |         wstarts.pop()
35 |     wstarts.append(W - crop_size)
36 |     starts = []
37 |     split_data = []
38 |     for hs in hstarts:
39 |         for ws in wstarts:
40 |             cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
41 |             starts.append((hs, ws))
42 |             split_data.append(cimgdata)
43 |     return split_data, starts
44 | 
45 | def get_scoremap(H, W, C, B=1, is_mean=True):  # per-pixel blending weights used when merging crops
46 |     center_h = H / 2
47 |     center_w = W / 2
48 | 
49 |     score = torch.ones((B, C, H, W))
50 |     if not is_mean:
51 |         for h in range(H):
52 |             for w in range(W):
53 |                 score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
54 |     return score
55 | 
56 | def mergeimage(split_data, starts, crop_size=128, resolution=(1, 3, 128, 128)):  # weighted-average overlapping crops back into one image
57 |     B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
58 |     tot_score = torch.zeros((B, C, H, W))
59 |     merge_img = torch.zeros((B, C, H, W))
60 |     scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
61 |     for simg, cstart in zip(split_data, starts):
62 |         hs, ws = cstart
63 |         merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
64 |         tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
65 |     merge_img = merge_img / tot_score
66 |     return merge_img
67 | 
68 | ####### Load yaml #######
69 | yaml_file = 'Options/RealDemoiring_FPro.yml'
70 | import yaml
71 | 
72 | try:
73 |     from yaml import CLoader as Loader
74 | except ImportError:
75 |     from yaml import Loader
76 | 
77 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
78 | 
79 | s = x['network_g'].pop('type')
80 | ##########################
81 | 
82 | model_restoration = FPro(**x['network_g'])
83 | 
84 | checkpoint = torch.load(args.weights)
85 | model_restoration.load_state_dict(checkpoint['params'])
86 | print("===>Testing using weights: ", args.weights)
87 | model_restoration.cuda()
88 | model_restoration = nn.DataParallel(model_restoration)
89 | model_restoration.eval()
90 | 
91 | 
92 | factor = 8
93 | datasets = ['TIP18']
94 | 
95 | for dataset in datasets:
96 | 
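    # TIP18 test crops are restored in one pass: pad to a multiple of 8, run the model, then crop back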
result_dir = os.path.join(args.result_dir, dataset) 97 | os.makedirs(result_dir, exist_ok=True) 98 | 99 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain') 100 | inp_dir = os.path.join(args.input_dir) 101 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg'))) 102 | with torch.no_grad(): 103 | for file_ in tqdm(files): 104 | torch.cuda.ipc_collect() 105 | torch.cuda.empty_cache() 106 | 107 | img = np.float32(utils.load_img(file_))/255. 108 | img = torch.from_numpy(img).permute(2,0,1) 109 | input_ = img.unsqueeze(0).cuda() 110 | 111 | # Padding in case images are not multiples of 8 112 | h,w = input_.shape[2], input_.shape[3] 113 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor 114 | padh = H-h if h%factor!=0 else 0 115 | padw = W-w if w%factor!=0 else 0 116 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 117 | 118 | restored = model_restoration(input_) 119 | restored = restored[:,:,:h,:w] 120 | 121 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 122 | 123 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored)) 124 | -------------------------------------------------------------------------------- /Demoiring/utils.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import cv2 8 | import math 9 | 10 | def calculate_psnr(img1, img2, border=0): 11 | # img1 and img2 have range [0, 255] 12 | #img1 = img1.squeeze() 13 | #img2 = img2.squeeze() 14 | if not img1.shape == img2.shape: 15 | raise ValueError('Input images must have the same dimensions.') 16 | h, w = img1.shape[:2] 17 | img1 = img1[border:h-border, border:w-border] 18 | img2 = img2[border:h-border, border:w-border] 19 | 20 | img1 = img1.astype(np.float64) 21 | img2 = img2.astype(np.float64) 22 | mse = np.mean((img1 - img2)**2) 23 | if mse == 0: 24 | return float('inf') 25 | return 20 * math.log10(255.0 / math.sqrt(mse)) 26 | 27 | 28 | # -------------------------------------------- 29 | # SSIM 30 | # -------------------------------------------- 31 | def calculate_ssim(img1, img2, border=0): 32 | '''calculate SSIM 33 | the same outputs as MATLAB's 34 | img1, img2: [0, 255] 35 | ''' 36 | #img1 = img1.squeeze() 37 | #img2 = img2.squeeze() 38 | if not img1.shape == img2.shape: 39 | raise ValueError('Input images must have the same dimensions.') 40 | h, w = img1.shape[:2] 41 | img1 = img1[border:h-border, border:w-border] 42 | img2 = img2[border:h-border, border:w-border] 43 | 44 | if img1.ndim == 2: 45 | return ssim(img1, img2) 46 | elif img1.ndim == 3: 47 | if img1.shape[2] == 3: 48 | ssims = [] 49 | for i in range(3): 50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i])) 51 | return np.array(ssims).mean() 52 | elif img1.shape[2] == 1: 53 | return ssim(np.squeeze(img1), np.squeeze(img2)) 54 | else: 55 | raise ValueError('Wrong input image dimensions.') 56 | 57 | 58 | def ssim(img1, img2): 59 | C1 = (0.01 * 255)**2 60 | C2 = (0.03 * 255)**2 61 | 62 | img1 = img1.astype(np.float64) 63 | img2 = img2.astype(np.float64) 64 | kernel = cv2.getGaussianKernel(11, 1.5) 65 | window = np.outer(kernel, kernel.transpose()) 66 | 67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 
5:-5]  # valid
68 |     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
69 |     mu1_sq = mu1**2
70 |     mu2_sq = mu2**2
71 |     mu1_mu2 = mu1 * mu2
72 |     sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 |     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 |     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 | 
76 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 |                                                             (sigma1_sq + sigma2_sq + C2))
78 |     return ssim_map.mean()
79 | 
80 | def load_img(filepath):
81 |     return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 | 
83 | def save_img(filepath, img):
84 |     cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 | 
86 | def load_gray_img(filepath):
87 |     return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 | 
89 | def save_gray_img(filepath, img):
90 |     cv2.imwrite(filepath, img)
91 | 
--------------------------------------------------------------------------------
/Deraining/Options/Deraining_FPro_spad.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Deraining_FPro
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 2  # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 | 
8 | # dataset and data loader settings
9 | datasets:
10 |   train:
11 |     name: TrainSet
12 |     type: Dataset_PairedImage_derainSpad
13 |     dataroot_gt: /home/ubuntu/zsh/datasets/derain/real_world_gt
14 |     dataroot_lq: /home/ubuntu/zsh/datasets/derain/real_world
15 |     geometric_augs: true
16 | 
17 |     filename_tmpl: '{}'
18 |     io_backend:
19 |       type: disk
20 | 
21 |     # data loader
22 |     use_shuffle: true
23 |     num_worker_per_gpu: 8
24 |     batch_size_per_gpu: 8
25 | 
26 |     # ### -------------Progressive training--------------------------
27 |     # # mini_batch_sizes: [8,5,4,2,1,1]   # Batch size per gpu
28 |     # mini_batch_sizes: [6,4,3,1]   # Batch size per gpu
29 |     # # mini_batch_sizes: [20,16,12,8]   # Batch size per gpu
30 |     # # iters: [92000,64000,48000,36000,36000,24000]
31 |     # iters: [152000,74000,48000,26000]
32 |     # # gt_size: 384   # Max patch size for progressive training
33 |     # gt_size: 256   # Max patch size for progressive training
34 |     # gt_sizes: [128,160,192,256]   # Patch sizes for progressive training.
35 |     # ### ------------------------------------------------------------
36 | 
37 |     ### ------- Training on single fixed-patch size 128x128---------
38 |     # mini_batch_sizes: [8]
39 |     # iters: [300000]
40 |     # gt_size: 128
41 |     # gt_sizes: [128]
42 |     ## ------------------------------------------------------------
43 |     ## ------- Training on single fixed-patch size 256x256 ---------
44 |     mini_batch_sizes: [2]
45 |     iters: [300000]
46 |     gt_size: 256
47 |     gt_sizes: [256]
48 |     ## ------------------------------------------------------------
49 | 
50 |     dataset_enlarge_ratio: 1
51 |     prefetch_mode: ~
52 | 
53 |   val:
54 |     name: ValSet
55 |     type: Dataset_PairedImage_derainSpad
56 |     dataroot_gt: /home/ubuntu/zsh/datasets/derain/real_test_1000/gt
57 |     dataroot_lq: /home/ubuntu/zsh/datasets/derain/real_test_1000/rain
58 |     gt_size: 256
59 |     io_backend:
60 |       type: disk
61 | 
62 | # network structures
63 | network_g:
64 |   type: FPro
65 |   inp_channels: 3
66 |   out_channels: 3
67 |   # input_res: 128
68 |   dim: 48
69 |   # num_blocks: [4,6,6,8]
70 |   num_blocks: [2,3,6]
71 |   # num_refinement_blocks: 4
72 |   num_refinement_blocks: 2
73 |   # heads: [1,2,4,8]
74 |   heads: [2,4,8]
75 |   # ffn_expansion_factor: 2.66
76 |   ffn_expansion_factor: 3
77 |   bias: False
78 |   LayerNorm_type: WithBias
79 |   dual_pixel_task: False
80 |   # type: Restormer
81 |   # inp_channels: 3
82 |   # out_channels: 3
83 |   # # input_res: 128
84 |   # dim: 48
85 |   # num_blocks: [4,6,6,8]
86 |   # # num_blocks: [1,3,6]
87 |   # num_refinement_blocks: 4
88 |   # # num_refinement_blocks: 2
89 |   # heads: [1,2,4,8]
90 |   # ffn_expansion_factor: 2.66
91 |   # # ffn_expansion_factor: 3
92 |   # bias: False
93 |   # LayerNorm_type: WithBias
94 |   # dual_pixel_task: False
95 | 
96 | 
97 | # path
98 | path:
99 |   pretrain_network_g: ~
100 |   strict_load_g: true
101 |   resume_state: ~
102 | 
103 | # training settings
104 | train:
105 |   # total_iter: 300000
106 |   total_iter: 300000
107 |   warmup_iter: -1  # no warm up
108 |   use_grad_clip: true
109 | 
110 |   # Split 300k iterations into two cycles.
111 |   # 1st cycle: fixed 3e-4 LR for 92k iters.
112 |   # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
113 |   scheduler:
114 |     type: CosineAnnealingRestartCyclicLR
115 |     periods: [92000, 208000]
116 |     # periods: [480000, 720000]
117 |     restart_weights: [1,1]
118 |     eta_mins: [0.0003,0.000001]
119 | 
120 |   mixing_augs:
121 |     mixup: false
122 |     mixup_beta: 1.2
123 |     use_identity: true
124 | 
125 |   optim_g:
126 |     type: AdamW
127 |     lr: !!float 3e-4
128 |     weight_decay: !!float 1e-4
129 |     betas: [0.9, 0.999]
130 | 
131 |   # losses
132 |   pixel_opt:
133 |     type: L1Loss
134 |     loss_weight: 1
135 |     reduction: mean
136 | 
137 |   fft_loss_opt:
138 |     type: FFTLoss
139 |     loss_weight: 0.1
140 |     reduction: mean
141 | 
142 | # validation settings
143 | val:
144 |   window_size: 8
145 |   val_freq: !!float 4e3
146 |   # val_freq: !!float 300e3
147 |   save_img: true
148 |   rgb2bgr: true
149 |   use_image: true
150 |   max_minibatch: 8
151 | 
152 |   metrics:
153 |     psnr:   # metric name, can be arbitrary
154 |       type: calculate_psnr
155 |       crop_border: 0
156 |       test_y_channel: true
157 | 
158 | # logging settings
159 | logger:
160 |   print_freq: 1000
161 |   save_checkpoint_freq: !!float 4e3
162 |   use_tb_logger: true
163 |   wandb:
164 |     project: ~
165 |     resume_id: ~
166 | 
167 | # dist training settings
168 | dist_params:
169 |   backend: nccl
170 |   port: 29500
171 | 
--------------------------------------------------------------------------------
/Deraining/Options/RealDeraindrop_FPro.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: RealDeraindrop_FPro
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8  # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 | 
8 | # dataset and data loader settings
9 | datasets:
10 |   train:
11 |     name: TrainSet
12 |     type: Dataset_PairedImage_denseHaze
13 |     dataroot_gt: /mnt/sda/dataset/raindrop/train/gt
14 |     dataroot_lq: /mnt/sda/dataset/raindrop/train/data
15 |     geometric_augs: true
16 | 
17 |     filename_tmpl: '{}'
18 |     io_backend:
19 |       type: disk
20 | 
21 |     # data loader
22 |     use_shuffle: true
23 |     num_worker_per_gpu: 8
24 |     batch_size_per_gpu: 8
25 | 
26 |     ## ------- Training on single fixed-patch size 256x256 ---------
27 |     mini_batch_sizes: [2]
28 |     iters: [300000]
29 |     gt_size: 256
30 |     gt_sizes: [256]
31 |     ## ------------------------------------------------------------
32 | 
33 |     dataset_enlarge_ratio: 1
34 |     prefetch_mode: ~
35 | 
36 |   val:
37 |     name: ValSet
38 |     type: Dataset_PairedImage_denseHaze
39 |     dataroot_gt: /mnt/sda/dataset/raindrop/test_a/gt
40 |     dataroot_lq: /mnt/sda/dataset/raindrop/test_a/data
41 |     gt_size: 256
42 |     io_backend:
43 |       type: disk
44 | 
45 | # network structures
46 | 
47 | network_g:
48 |   type: FPro
49 |   inp_channels: 3
50 |   out_channels: 3
51 |   # input_res: 128
52 |   dim: 48
53 |   # num_blocks: [4,6,6,8]
54 |   num_blocks: [2,3,6]
55 |   # num_refinement_blocks: 4
56 |   num_refinement_blocks: 2
57 |   # heads: [1,2,4,8]
58 |   heads: [2,4,8]
59 |   # ffn_expansion_factor: 2.66
60 |   ffn_expansion_factor: 3
61 |   bias: False
62 |   LayerNorm_type: WithBias
63 |   dual_pixel_task: False
64 | 
65 | 
66 | # path
67 | path:
68 |   pretrain_network_g: ~
69 |   strict_load_g: true
70 |   resume_state: ~
71 | 
72 | # training settings
73 | train:
74 |   total_iter: 300000
75 |   warmup_iter: -1  # no warm up
76 |   use_grad_clip: true
77 | 
78 |   # Split 300k iterations into two cycles.
79 |   # 1st cycle: fixed 3e-4 LR for 92k iters.
80 |   # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
81 |   scheduler:
82 |     type: CosineAnnealingRestartCyclicLR
83 |     periods: [92000, 208000]
84 |     restart_weights: [1,1]
85 |     eta_mins: [0.0003,0.000001]
86 | 
87 |   mixing_augs:
88 |     mixup: true
89 |     mixup_beta: 1.2
90 |     use_identity: true
91 | 
92 |   optim_g:
93 |     type: AdamW
94 |     lr: !!float 3e-4
95 |     weight_decay: !!float 1e-4
96 |     betas: [0.9, 0.999]
97 | 
98 |   # losses
99 |   pixel_opt:
100 |     type: L1Loss
101 |     loss_weight: 1
102 |     reduction: mean
103 |   fft_loss_opt:
104 |     type: FFTLoss
105 |     loss_weight: 0.1
106 |     reduction: mean
107 | 
108 | # validation settings
109 | val:
110 |   window_size: 8
111 |   val_freq: !!float 4e3
112 |   save_img: false
113 |   rgb2bgr: true
114 |   use_image: false
115 |   max_minibatch: 8
116 | 
117 |   metrics:
118 |     psnr:   # metric name, can be arbitrary
119 |       type: calculate_psnr
120 |       crop_border: 0
121 |       test_y_channel: false
122 | 
123 | # logging settings
124 | logger:
125 |   print_freq: 1000
126 |   save_checkpoint_freq: !!float 4e3
127 |   use_tb_logger: true
128 |   wandb:
129 |     project: ~
130 |     resume_id: ~
131 | 
132 | # dist training settings
133 | dist_params:
134 |   backend: nccl
135 |   port: 29500
136 | 
--------------------------------------------------------------------------------
/Deraining/evaluate_raindrop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | from glob import glob
5 | from tqdm import tqdm
6 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
7 | from skimage.metrics import structural_similarity as compare_ssim
8 | 
9 | 
10 | def calc_psnr(im1, im2):
11 |     im1_y = cv2.cvtColor(im1, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
12 |     im2_y = cv2.cvtColor(im2, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
13 |     return compare_psnr(im1_y, im2_y)
14 | 
15 | 
16 | def calc_ssim(im1, im2):
17 |     im1_y = cv2.cvtColor(im1, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
18 |     im2_y = cv2.cvtColor(im2, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
19 |     return compare_ssim(im1_y, im2_y)
20 | 
21 | 
22 | def align_to_four(img):
23 |     a_row = int(img.shape[0]/4)*4
24 |     a_col = int(img.shape[1]/4)*4
25 |     img = img[0:a_row, 0:a_col, :]
26 |     return img
27 | 
28 | 
29 | def evaluate_raindrop(in_dir, gt_dir):
30 |     inputs = sorted(glob(os.path.join(in_dir, '*.png')) + glob(os.path.join(in_dir, '*.jpg')))
31 |     gts = sorted(glob(os.path.join(gt_dir, '*.png')) + glob(os.path.join(gt_dir, '*.jpg')))
32 |     psnrs = []
33 |     ssims = []
34 |     for inp, gt in tqdm(zip(inputs, gts)):  # 'inp' avoids shadowing the built-in input()
35 |         inputdata = cv2.imread(inp)
36 |         gtdata = cv2.imread(gt)
37 |         inputdata = align_to_four(inputdata)
38 |         gtdata = align_to_four(gtdata)
39 |         psnrs.append(calc_psnr(inputdata, gtdata))
40 |         ssims.append(calc_ssim(inputdata, gtdata))
41 | 
42 |     ave_psnr = np.array(psnrs).mean()
43 |     ave_ssim = np.array(ssims).mean()
44 |     return ave_psnr, ave_ssim
45 | 
46 | 
47 | if __name__ == '__main__':
48 |     ave_psnr, ave_ssim = evaluate_raindrop('/mnt/sda/zsh/FPro/Deraining/results/FPro_AGAN/test_a', '/mnt/sda/zsh/dataset/test_a/gt')
49 |     print('PSNR: ', ave_psnr)
50 |     print('SSIM: ', ave_ssim)
51 | 
--------------------------------------------------------------------------------
/Deraining/test_AGAN.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 | 
5 | 
6 | import math
7 | import numpy as np
8 | import os
9 | import argparse
10 | from tqdm 
import tqdm 11 | 12 | import torch.nn as nn 13 | import torch 14 | import torch.nn.functional as F 15 | import utils 16 | 17 | from natsort import natsorted 18 | from glob import glob 19 | from basicsr.models.archs.FPro_arch import FPro 20 | from skimage import img_as_ubyte 21 | from pdb import set_trace as stx 22 | 23 | parser = argparse.ArgumentParser(description='Image Deraindrop using FPro') 24 | 25 | parser.add_argument('--input_dir', default='/mnt/sda/zsh/dataset/', type=str, help='Directory of validation images') 26 | parser.add_argument('--result_dir', default='./results/FPro_AGAN/', type=str, help='Directory for results') 27 | parser.add_argument('--weights', default='./models/deraindrop_FPro.pth', type=str, help='Path to weights') 28 | 29 | args = parser.parse_args() 30 | 31 | def splitimage(imgtensor, crop_size=128, overlap_size=64): 32 | _, C, H, W = imgtensor.shape 33 | hstarts = [x for x in range(0, H, crop_size - overlap_size)] 34 | while hstarts and hstarts[-1] + crop_size >= H: 35 | hstarts.pop() 36 | hstarts.append(H - crop_size) 37 | wstarts = [x for x in range(0, W, crop_size - overlap_size)] 38 | while wstarts and wstarts[-1] + crop_size >= W: 39 | wstarts.pop() 40 | wstarts.append(W - crop_size) 41 | starts = [] 42 | split_data = [] 43 | for hs in hstarts: 44 | for ws in wstarts: 45 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size] 46 | starts.append((hs, ws)) 47 | split_data.append(cimgdata) 48 | return split_data, starts 49 | 50 | def get_scoremap(H, W, C, B=1, is_mean=True): 51 | center_h = H / 2 52 | center_w = W / 2 53 | 54 | score = torch.ones((B, C, H, W)) 55 | if not is_mean: 56 | for h in range(H): 57 | for w in range(W): 58 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6)) 59 | return score 60 | 61 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)): 62 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3] 63 | tot_score = torch.zeros((B, C, H, W)) 64 | merge_img = torch.zeros((B, C, H, W)) 65 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True) 66 | for simg, cstart in zip(split_data, starts): 67 | hs, ws = cstart 68 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg 69 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap 70 | merge_img = merge_img / tot_score 71 | return merge_img 72 | 73 | ####### Load yaml ####### 74 | yaml_file = 'Options/RealDeraindrop_FPro.yml' 75 | import yaml 76 | 77 | try: 78 | from yaml import CLoader as Loader 79 | except ImportError: 80 | from yaml import Loader 81 | 82 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader) 83 | 84 | s = x['network_g'].pop('type') 85 | ########################## 86 | 87 | model_restoration = FPro(**x['network_g']) 88 | 89 | checkpoint = torch.load(args.weights) 90 | model_restoration.load_state_dict(checkpoint['params']) 91 | print("===>Testing using weights: ",args.weights) 92 | model_restoration.cuda() 93 | model_restoration = nn.DataParallel(model_restoration) 94 | model_restoration.eval() 95 | 96 | 97 | factor = 8 98 | datasets = ['test_a'] 99 | 100 | for dataset in datasets: 101 | result_dir = os.path.join(args.result_dir, dataset) 102 | os.makedirs(result_dir, exist_ok=True) 103 | 104 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain') 105 | inp_dir = os.path.join(args.input_dir, dataset, 'data') 106 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg'))) 107 | with 
torch.no_grad():
108 |         for file_ in tqdm(files):
109 |             torch.cuda.ipc_collect()
110 |             torch.cuda.empty_cache()
111 | 
112 |             img = np.float32(utils.load_img(file_))/255.
113 |             img = torch.from_numpy(img).permute(2,0,1)
114 |             input_ = img.unsqueeze(0).cuda()
115 | 
116 |             # Padding in case images are not multiples of 8
117 |             h,w = input_.shape[2], input_.shape[3]
118 |             H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
119 |             padh = H-h if h%factor!=0 else 0
120 |             padw = W-w if w%factor!=0 else 0
121 |             input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
122 | 
123 |             B, C, H, W = input_.shape
124 |             corp_size_arg = 256
125 |             overlap_size_arg = 200
126 |             split_data, starts = splitimage(input_, crop_size=corp_size_arg, overlap_size=overlap_size_arg)
127 |             for i, data in enumerate(split_data):
128 |                 split_data[i] = model_restoration(data).cpu()
129 |             restored = mergeimage(split_data, starts, crop_size=corp_size_arg, resolution=(B, C, H, W))
130 |             # restored = model_restoration(input_)
131 | 
132 |             restored = restored[:,:,:h,:w]
133 | 
134 |             restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
135 | 
136 |             utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
137 | 
--------------------------------------------------------------------------------
/Deraining/test_spad.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 | 
5 | 
6 | 
7 | import numpy as np
8 | import os
9 | import argparse
10 | from tqdm import tqdm
11 | import math
12 | import torch.nn as nn
13 | import torch
14 | import torch.nn.functional as F
15 | import utils
16 | 
17 | from natsort import natsorted
18 | from glob import glob
19 | from basicsr.models.archs.FPro_arch import FPro
20 | from skimage import img_as_ubyte
21 | from pdb import set_trace as stx
22 | 
23 | parser = argparse.ArgumentParser(description='Image Deraining using FPro')
24 | 
25 | parser.add_argument('--input_dir', default='/mnt/sda/zsh/derain/', type=str, help='Directory of validation images')
26 | parser.add_argument('--result_dir', default='./results/FPro/', type=str, help='Directory for results')
27 | parser.add_argument('--weights', default='./models/derain_spad.pth', type=str, help='Path to weights')
28 | 
29 | args = parser.parse_args()
30 | 
31 | def splitimage(imgtensor, crop_size=128, overlap_size=64):  # split an image into overlapping crops
32 |     _, C, H, W = imgtensor.shape
33 |     hstarts = [x for x in range(0, H, crop_size - overlap_size)]
34 |     while hstarts and hstarts[-1] + crop_size >= H:
35 |         hstarts.pop()
36 |     hstarts.append(H - crop_size)
37 |     wstarts = [x for x in range(0, W, crop_size - overlap_size)]
38 |     while wstarts and wstarts[-1] + crop_size >= W:
39 |         wstarts.pop()
40 |     wstarts.append(W - crop_size)
41 |     starts = []
42 |     split_data = []
43 |     for hs in hstarts:
44 |         for ws in wstarts:
45 |             cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
46 |             starts.append((hs, ws))
47 |             split_data.append(cimgdata)
48 |     return split_data, starts
49 | 
50 | def get_scoremap(H, W, C, B=1, is_mean=True):  # per-pixel blending weights used when merging crops
51 |     center_h = H / 2
52 |     center_w = W / 2
53 | 
54 |     score = torch.ones((B, C, H, W))
55 |     if not is_mean:
56 |         for h in range(H):
57 |             for w in range(W):
58 |                 score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 
2 + 1e-6)) 59 | return score 60 | 61 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)): 62 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3] 63 | tot_score = torch.zeros((B, C, H, W)) 64 | merge_img = torch.zeros((B, C, H, W)) 65 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True) 66 | for simg, cstart in zip(split_data, starts): 67 | hs, ws = cstart 68 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg 69 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap 70 | merge_img = merge_img / tot_score 71 | return merge_img 72 | 73 | ####### Load yaml ####### 74 | yaml_file = 'Options/Deraining_FPro_spad.yml' 75 | import yaml 76 | 77 | try: 78 | from yaml import CLoader as Loader 79 | except ImportError: 80 | from yaml import Loader 81 | 82 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader) 83 | 84 | s = x['network_g'].pop('type') 85 | ########################## 86 | 87 | model_restoration = FPro(**x['network_g']) 88 | 89 | checkpoint = torch.load(args.weights) 90 | model_restoration.load_state_dict(checkpoint['params']) 91 | print("===>Testing using weights: ",args.weights) 92 | model_restoration.cuda() 93 | model_restoration = nn.DataParallel(model_restoration) 94 | model_restoration.eval() 95 | 96 | 97 | factor = 8 98 | datasets = ['real_test_1000'] 99 | 100 | for dataset in datasets: 101 | result_dir = os.path.join(args.result_dir, dataset) 102 | os.makedirs(result_dir, exist_ok=True) 103 | 104 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain') 105 | inp_dir = os.path.join(args.input_dir, dataset, 'rain') 106 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg'))) 107 | with torch.no_grad(): 108 | for file_ in tqdm(files): 109 | torch.cuda.ipc_collect() 110 | torch.cuda.empty_cache() 111 | 112 | img = np.float32(utils.load_img(file_))/255. 
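            # tiled inference: 256x256 crops with 200px overlap (stride 56); mergeimage averages the overlapping regions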
113 | img = torch.from_numpy(img).permute(2,0,1) 114 | input_ = img.unsqueeze(0).cuda() 115 | 116 | # Padding in case images are not multiples of 8 117 | h,w = input_.shape[2], input_.shape[3] 118 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor 119 | padh = H-h if h%factor!=0 else 0 120 | padw = W-w if w%factor!=0 else 0 121 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 122 | 123 | B, C, H, W = input_.shape 124 | corp_size_arg = 256 125 | overlap_size_arg = 200 126 | split_data, starts = splitimage(input_, crop_size=corp_size_arg, overlap_size=overlap_size_arg) 127 | for i, data in enumerate(split_data): 128 | split_data[i] = model_restoration(data).cpu() 129 | restored = mergeimage(split_data, starts, crop_size=corp_size_arg, resolution=(B, C, H, W)) 130 | 131 | restored = restored[:,:,:h,:w] 132 | 133 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 134 | 135 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored)) 136 | -------------------------------------------------------------------------------- /Deraining/utils.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import cv2 8 | import math 9 | 10 | def calculate_psnr(img1, img2, border=0): 11 | # img1 and img2 have range [0, 255] 12 | #img1 = img1.squeeze() 13 | #img2 = img2.squeeze() 14 | if not img1.shape == img2.shape: 15 | raise ValueError('Input images must have the same dimensions.') 16 | h, w = img1.shape[:2] 17 | img1 = img1[border:h-border, border:w-border] 18 | img2 = img2[border:h-border, border:w-border] 19 | 20 | img1 = img1.astype(np.float64) 21 | img2 = img2.astype(np.float64) 22 | mse = np.mean((img1 - img2)**2) 23 | if mse == 0: 24 | return float('inf') 25 | return 20 * math.log10(255.0 / math.sqrt(mse)) 26 | 27 | 28 | # -------------------------------------------- 29 | # SSIM 30 | # -------------------------------------------- 31 | def calculate_ssim(img1, img2, border=0): 32 | '''calculate SSIM 33 | the same outputs as MATLAB's 34 | img1, img2: [0, 255] 35 | ''' 36 | #img1 = img1.squeeze() 37 | #img2 = img2.squeeze() 38 | if not img1.shape == img2.shape: 39 | raise ValueError('Input images must have the same dimensions.') 40 | h, w = img1.shape[:2] 41 | img1 = img1[border:h-border, border:w-border] 42 | img2 = img2[border:h-border, border:w-border] 43 | 44 | if img1.ndim == 2: 45 | return ssim(img1, img2) 46 | elif img1.ndim == 3: 47 | if img1.shape[2] == 3: 48 | ssims = [] 49 | for i in range(3): 50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i])) 51 | return np.array(ssims).mean() 52 | elif img1.shape[2] == 1: 53 | return ssim(np.squeeze(img1), np.squeeze(img2)) 54 | else: 55 | raise ValueError('Wrong input image dimensions.') 56 | 57 | 58 | def ssim(img1, img2): 59 | C1 = (0.01 * 255)**2 60 | C2 = (0.03 * 255)**2 61 | 62 | img1 = img1.astype(np.float64) 63 | img2 = img2.astype(np.float64) 64 | kernel = cv2.getGaussianKernel(11, 1.5) 65 | window = np.outer(kernel, kernel.transpose()) 66 | 67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 69 | mu1_sq = mu1**2 70 | mu2_sq = mu2**2 71 | mu1_mu2 = mu1 * mu2 72 | sigma1_sq = 
cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 |     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 |     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 | 
76 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 |                                                             (sigma1_sq + sigma2_sq + C2))
78 |     return ssim_map.mean()
79 | 
80 | def load_img(filepath):
81 |     return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 | 
83 | def save_img(filepath, img):
84 |     cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 | 
86 | def load_gray_img(filepath):
87 |     return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 | 
89 | def save_gray_img(filepath, img):
90 |     cv2.imwrite(filepath, img)
91 | 
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | 
3 | This repository is built in PyTorch 1.8.1 and tested on Ubuntu 16.04 (Python 3.7, CUDA 10.2, cuDNN 7.6).
4 | Follow these instructions:
5 | 
6 | 1. Clone our repository
7 | ```
8 | git clone https://github.com/swz30/Restormer.git
9 | cd Restormer
10 | ```
11 | 
12 | 2. Make conda environment
13 | ```
14 | conda create -n pytorch181 python=3.7
15 | conda activate pytorch181
16 | ```
17 | 
18 | 3. Install dependencies
19 | ```
20 | conda install pytorch=1.8 torchvision cudatoolkit=10.2 -c pytorch
21 | pip install matplotlib scikit-learn scikit-image opencv-python yacs joblib natsort h5py tqdm
22 | pip install einops gdown addict future lmdb numpy pyyaml requests scipy tb-nightly yapf lpips
23 | ```
24 | 
25 | 4. Install basicsr
26 | ```
27 | python setup.py develop --no_cuda_ext
28 | ```
29 | 
30 | ### Download datasets from Google Drive
31 | 
32 | To download datasets automatically, you will need `go` and `gdrive` installed.
33 | 
34 | 1. Install `go` with the following
35 | ```
36 | curl -O https://storage.googleapis.com/golang/go1.11.1.linux-amd64.tar.gz
37 | mkdir -p ~/installed
38 | tar -C ~/installed -xzf go1.11.1.linux-amd64.tar.gz
39 | mkdir -p ~/go
40 | ```
41 | 
42 | 2. Add the following lines to `~/.bashrc`
43 | ```
44 | export GOPATH=$HOME/go
45 | export PATH=$PATH:$HOME/go/bin:$HOME/installed/go/bin
46 | ```
47 | 
48 | 3. Install `gdrive` using
49 | ```
50 | go get github.com/prasmussen/gdrive
51 | ```
52 | 
53 | 4. Close the current terminal and open a new one so the PATH changes take effect.
54 | -------------------------------------------------------------------------------- /Motion_Deblurring/Options/Deblurring_FPro.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: Deblurring_FPro 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage_denseHaze 13 | dataroot_gt: ./Motion_Deblurring/Datasets/train/GoPro/target_crops 14 | dataroot_lq: ./Motion_Deblurring/Datasets/train/GoPro/input_crops 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | # ### -------------Progressive training-------------------------- 27 | # mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 28 | # iters: [92000,64000,48000,36000,36000,24000] 29 | # gt_size: 384 # Max patch size for progressive training 30 | # gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 31 | mini_batch_sizes: [2] 32 | iters: [600000] 33 | gt_size: 256 34 | gt_sizes: [256] 35 | ### ------------------------------------------------------------ 36 | 37 | ### ------- Training on single fixed-patch size 128x128--------- 38 | # mini_batch_sizes: [8] 39 | # iters: [300000] 40 | # gt_size: 128 41 | # gt_sizes: [128] 42 | ### ------------------------------------------------------------ 43 | 44 | dataset_enlarge_ratio: 1 45 | prefetch_mode: ~ 46 | 47 | val: 48 | name: ValSet 49 | type: Dataset_PairedImage_denseHaze 50 | dataroot_gt: ./Motion_Deblurring/Datasets/val/GoPro/target_crops 51 | dataroot_lq: ./Motion_Deblurring/Datasets/val/GoPro/input_crops 52 | gt_size: 256 53 | io_backend: 54 | type: disk 55 | 56 | # network structures 57 | network_g: 58 | type: Restormer 59 | inp_channels: 3 60 | out_channels: 3 61 | # input_res: 128 62 | dim: 48 63 | # num_blocks: [4,6,6,8] 64 | num_blocks: [2,3,6] 65 | # num_refinement_blocks: 4 66 | num_refinement_blocks: 2 67 | # heads: [1,2,4,8] 68 | heads: [2,4,8] 69 | # ffn_expansion_factor: 2.66 70 | ffn_expansion_factor: 3 71 | bias: False 72 | LayerNorm_type: WithBias 73 | dual_pixel_task: False 74 | # network_g: 75 | # type: Restormer 76 | # inp_channels: 3 77 | # out_channels: 3 78 | # dim: 48 79 | # num_blocks: [4,6,6,8] 80 | # num_refinement_blocks: 4 81 | # heads: [1,2,4,8] 82 | # ffn_expansion_factor: 2.66 83 | # bias: False 84 | # LayerNorm_type: WithBias 85 | # dual_pixel_task: False 86 | 87 | 88 | # path 89 | path: 90 | pretrain_network_g: ~ 91 | strict_load_g: true 92 | resume_state: ~ 93 | 94 | # training settings 95 | train: 96 | total_iter: 600000 97 | warmup_iter: -1 # no warm up 98 | use_grad_clip: true 99 | 100 | # Split 300k iterations into two cycles. 101 | # 1st cycle: fixed 3e-4 LR for 92k iters. 102 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
103 | scheduler: 104 | type: CosineAnnealingRestartCyclicLR 105 | periods: [184000, 416000] 106 | restart_weights: [1,1] 107 | eta_mins: [0.0003,0.000001] 108 | 109 | mixing_augs: 110 | mixup: false 111 | mixup_beta: 1.2 112 | use_identity: true 113 | 114 | optim_g: 115 | type: AdamW 116 | lr: !!float 3e-4 117 | weight_decay: !!float 1e-4 118 | betas: [0.9, 0.999] 119 | 120 | # losses 121 | pixel_opt: 122 | type: L1Loss 123 | loss_weight: 1 124 | reduction: mean 125 | fft_loss_opt: 126 | type: FFTLoss 127 | loss_weight: 0.1 128 | reduction: mean 129 | 130 | # validation settings 131 | val: 132 | window_size: 8 133 | val_freq: !!float 4e3 134 | save_img: false 135 | rgb2bgr: true 136 | use_image: true 137 | max_minibatch: 8 138 | 139 | metrics: 140 | psnr: # metric name, can be arbitrary 141 | type: calculate_psnr 142 | crop_border: 0 143 | test_y_channel: false 144 | 145 | # logging settings 146 | logger: 147 | print_freq: 1000 148 | save_checkpoint_freq: !!float 4e3 149 | use_tb_logger: true 150 | wandb: 151 | project: ~ 152 | resume_id: ~ 153 | 154 | # dist training settings 155 | dist_params: 156 | backend: nccl 157 | port: 29500 158 | -------------------------------------------------------------------------------- /Motion_Deblurring/evaluate_gopro_hide.m: -------------------------------------------------------------------------------- 1 | %% Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | %% Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | %% https://arxiv.org/abs/2111.09881 4 | 5 | close all;clear all; 6 | 7 | % datasets = {'GoPro'}; 8 | datasets = {'GoPro', 'HIDE'}; 9 | num_set = length(datasets); 10 | 11 | tic 12 | delete(gcp('nocreate')) 13 | parpool('local',20); 14 | 15 | for idx_set = 1:num_set 16 | file_path = strcat('./results/', datasets{idx_set}, '/'); 17 | gt_path = strcat('./Datasets/test/', datasets{idx_set}, '/target/'); 18 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))]; 19 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))]; 20 | img_num = length(path_list); 21 | 22 | total_psnr = 0; 23 | total_ssim = 0; 24 | if img_num > 0 25 | parfor j = 1:img_num 26 | image_name = path_list(j).name; 27 | gt_name = gt_list(j).name; 28 | input = imread(strcat(file_path,image_name)); 29 | gt = imread(strcat(gt_path, gt_name)); 30 | ssim_val = ssim(input, gt); 31 | psnr_val = psnr(input, gt); 32 | total_ssim = total_ssim + ssim_val; 33 | total_psnr = total_psnr + psnr_val; 34 | end 35 | end 36 | qm_psnr = total_psnr / img_num; 37 | qm_ssim = total_ssim / img_num; 38 | 39 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim); 40 | 41 | end 42 | delete(gcp('nocreate')) 43 | toc 44 | -------------------------------------------------------------------------------- /Motion_Deblurring/generate_patches_gopro.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | ##### Data preparation file for training Restormer on the GoPro Dataset ######## 6 | 7 | import cv2 8 | import numpy as np 9 | from glob import glob 10 | from natsort import natsorted 11 | import os 12 | from tqdm import tqdm 13 | from pdb import set_trace as stx 14 | from joblib import Parallel, delayed 15 | import 
multiprocessing
 16 | 
 17 | def train_files(file_):
 18 |     lr_file, hr_file = file_
 19 |     filename = os.path.splitext(os.path.split(lr_file)[-1])[0]
 20 |     lr_img = cv2.imread(lr_file)
 21 |     hr_img = cv2.imread(hr_file)
 22 |     num_patch = 0
 23 |     w, h = lr_img.shape[:2]
 24 |     if w > p_max and h > p_max:
 25 |         w1 = list(np.arange(0, w-patch_size, patch_size-overlap, dtype=int))  # np.int was removed in NumPy 1.24; use the builtin int
 26 |         h1 = list(np.arange(0, h-patch_size, patch_size-overlap, dtype=int))
 27 |         w1.append(w-patch_size)
 28 |         h1.append(h-patch_size)
 29 |         for i in w1:
 30 |             for j in h1:
 31 |                 num_patch += 1
 32 | 
 33 |                 lr_patch = lr_img[i:i+patch_size, j:j+patch_size,:]
 34 |                 hr_patch = hr_img[i:i+patch_size, j:j+patch_size,:]
 35 | 
 36 |                 lr_savename = os.path.join(lr_tar, filename + '-' + str(num_patch) + '.png')
 37 |                 hr_savename = os.path.join(hr_tar, filename + '-' + str(num_patch) + '.png')
 38 | 
 39 |                 cv2.imwrite(lr_savename, lr_patch)
 40 |                 cv2.imwrite(hr_savename, hr_patch)
 41 | 
 42 |     else:
 43 |         lr_savename = os.path.join(lr_tar, filename + '.png')
 44 |         hr_savename = os.path.join(hr_tar, filename + '.png')
 45 | 
 46 |         cv2.imwrite(lr_savename, lr_img)
 47 |         cv2.imwrite(hr_savename, hr_img)
 48 | 
 49 | def val_files(file_):
 50 |     lr_file, hr_file = file_
 51 |     filename = os.path.splitext(os.path.split(lr_file)[-1])[0]
 52 |     lr_img = cv2.imread(lr_file)
 53 |     hr_img = cv2.imread(hr_file)
 54 | 
 55 |     lr_savename = os.path.join(lr_tar, filename + '.png')
 56 |     hr_savename = os.path.join(hr_tar, filename + '.png')
 57 | 
 58 |     w, h = lr_img.shape[:2]
 59 | 
 60 |     i = (w-val_patch_size)//2
 61 |     j = (h-val_patch_size)//2
 62 | 
 63 |     lr_patch = lr_img[i:i+val_patch_size, j:j+val_patch_size,:]
 64 |     hr_patch = hr_img[i:i+val_patch_size, j:j+val_patch_size,:]
 65 | 
 66 |     cv2.imwrite(lr_savename, lr_patch)
 67 |     cv2.imwrite(hr_savename, hr_patch)
 68 | 
 69 | ############ Prepare Training data ####################
 70 | num_cores = 10
 71 | patch_size = 512
 72 | overlap = 256
 73 | p_max = 0
 74 | 
 75 | src = '/home/ubuntu/test/datasets/deblurring/GoPro/train'
 76 | tar = 'Datasets/train/GoPro'
 77 | 
 78 | lr_tar = os.path.join(tar, 'input_crops')
 79 | hr_tar = os.path.join(tar, 'target_crops')
 80 | 
 81 | os.makedirs(lr_tar, exist_ok=True)
 82 | os.makedirs(hr_tar, exist_ok=True)
 83 | 
 84 | lr_files = natsorted(glob(os.path.join(src, 'input', '*.png')) + glob(os.path.join(src, 'input', '*.jpg')))
 85 | hr_files = natsorted(glob(os.path.join(src, 'groundtruth', '*.png')) + glob(os.path.join(src, 'groundtruth', '*.jpg')))
 86 | 
 87 | files = [(i, j) for i, j in zip(lr_files, hr_files)]
 88 | 
 89 | Parallel(n_jobs=num_cores)(delayed(train_files)(file_) for file_ in tqdm(files))
 90 | 
 91 | 
 92 | ############ Prepare validation data ####################
 93 | val_patch_size = 256
 94 | src = '/home/ubuntu/test/datasets/deblurring/GoPro/test'
 95 | tar = 'Datasets/val/GoPro'
 96 | 
 97 | lr_tar = os.path.join(tar, 'input_crops')
 98 | hr_tar = os.path.join(tar, 'target_crops')
 99 | 
100 | os.makedirs(lr_tar, exist_ok=True)
101 | os.makedirs(hr_tar, exist_ok=True)
102 | 
103 | lr_files = natsorted(glob(os.path.join(src, 'input', '*.png')) + glob(os.path.join(src, 'input', '*.jpg')))
104 | hr_files = natsorted(glob(os.path.join(src, 'groundtruth', '*.png')) + glob(os.path.join(src, 'groundtruth', '*.jpg')))
105 | 
106 | files = [(i, j) for i, j in zip(lr_files, hr_files)]
107 | 
108 | Parallel(n_jobs=num_cores)(delayed(val_files)(file_) for file_ in tqdm(files))
109 | 
--------------------------------------------------------------------------------
/Motion_Deblurring/test_FPro.py:
--------------------------------------------------------------------------------
  1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
  2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
  3 | ## https://arxiv.org/abs/2111.09881
  4 | 
  5 | import math  # needed by get_scoremap() below
  6 | import numpy as np
  7 | import os
  8 | import argparse
  9 | from tqdm import tqdm
 10 | 
 11 | import torch.nn as nn
 12 | import torch
 13 | import torch.nn.functional as F
 14 | import utils
 15 | 
 16 | from natsort import natsorted
 17 | from glob import glob
 18 | from basicsr.models.archs.FPro_arch import FPro
 19 | from skimage import img_as_ubyte
 20 | from pdb import set_trace as stx
 21 | 
 22 | parser = argparse.ArgumentParser(description='Single Image Motion Deblurring using FPro')
 23 | 
 24 | parser.add_argument('--input_dir', default='/home/ubuntu13/zsh/dataset/Uformer/deblurring/', type=str, help='Directory of validation images')
 25 | parser.add_argument('--result_dir', default='./results/FPro/', type=str, help='Directory for results')
 26 | parser.add_argument('--weights', default='./models/deblur.pth', type=str, help='Path to weights')
 27 | parser.add_argument('--dataset', default='GoPro', type=str, help='Test Dataset')  # ['GoPro', 'hide', 'RealBlur_J', 'RealBlur_R']
 28 | 
 29 | args = parser.parse_args()
 30 | 
 31 | def splitimage(imgtensor, crop_size=128, overlap_size=64):
 32 |     _, C, H, W = imgtensor.shape
 33 |     hstarts = [x for x in range(0, H, crop_size - overlap_size)]
 34 |     while hstarts and hstarts[-1] + crop_size >= H:
 35 |         hstarts.pop()
 36 |     hstarts.append(H - crop_size)
 37 |     wstarts = [x for x in range(0, W, crop_size - overlap_size)]
 38 |     while wstarts and wstarts[-1] + crop_size >= W:
 39 |         wstarts.pop()
 40 |     wstarts.append(W - crop_size)
 41 |     starts = []
 42 |     split_data = []
 43 |     for hs in hstarts:
 44 |         for ws in wstarts:
 45 |             cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
 46 |             starts.append((hs, ws))
 47 |             split_data.append(cimgdata)
 48 |     return split_data, starts
 49 | 
 50 | def get_scoremap(H, W, C, B=1, is_mean=True):
 51 |     center_h = H / 2
 52 |     center_w = W / 2
 53 | 
 54 |     score = torch.ones((B, C, H, W))
 55 |     if not is_mean:
 56 |         for h in range(H):
 57 |             for w in range(W):
 58 |                 score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
 59 |     return score
 60 | 
 61 | def mergeimage(split_data, starts, crop_size=128, resolution=(1, 3, 128, 128)):
 62 |     B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
 63 |     tot_score = torch.zeros((B, C, H, W))
 64 |     merge_img = torch.zeros((B, C, H, W))
 65 |     scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
 66 |     for simg, cstart in zip(split_data, starts):
 67 |         hs, ws = cstart
 68 |         merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
 69 |         tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
 70 |     merge_img = merge_img / tot_score
 71 |     return merge_img
 72 | 
 73 | ####### Load yaml #######
 74 | yaml_file = 'Options/Deblurring_FPro.yml'
 75 | import yaml
 76 | 
 77 | try:
 78 |     from yaml import CLoader as Loader
 79 | except ImportError:
 80 |     from yaml import Loader
 81 | 
 82 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
 83 | 
 84 | s = x['network_g'].pop('type')
 85 | ##########################
 86 | 
 87 | model_restoration = FPro(**x['network_g'])
 88 | 
 89 | checkpoint = torch.load(args.weights)
 90 | model_restoration.load_state_dict(checkpoint['params'])
 91 | print("===>Testing using weights: ", args.weights)
 92 | model_restoration.cuda()
 93 | model_restoration = nn.DataParallel(model_restoration)
 94 | model_restoration.eval()
 95 | 
 96 | 
 97 | factor = 8
 98 | dataset = args.dataset
 99 | result_dir = os.path.join(args.result_dir, dataset)
100 | os.makedirs(result_dir, exist_ok=True)
101 | 
102 | inp_dir = os.path.join(args.input_dir, dataset, 'test', 'blur')
103 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
104 | with torch.no_grad():
105 |     for file_ in tqdm(files):
106 |         torch.cuda.ipc_collect()
107 |         torch.cuda.empty_cache()
108 | 
109 |         img = np.float32(utils.load_img(file_))/255.
110 |         img = torch.from_numpy(img).permute(2,0,1)
111 |         input_ = img.unsqueeze(0).cuda()
112 | 
113 |         B, C, H, W = input_.shape
114 |         crop_size_arg = 256
115 |         overlap_size_arg = 200
116 |         split_data, starts = splitimage(input_, crop_size=crop_size_arg, overlap_size=overlap_size_arg)
117 |         for i, data in enumerate(split_data):
118 |             split_data[i] = model_restoration(data).cpu()
119 |         restored = mergeimage(split_data, starts, crop_size=crop_size_arg, resolution=(B, C, H, W))
120 | 
121 |         restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
122 | 
123 |         utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
124 | 
--------------------------------------------------------------------------------
/Motion_Deblurring/utils.py:
--------------------------------------------------------------------------------
  1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
  2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
  3 | ## https://arxiv.org/abs/2111.09881
  4 | 
  5 | import numpy as np
  6 | import os
  7 | import cv2
  8 | import math
  9 | 
 10 | def calculate_psnr(img1, img2, border=0):
 11 |     # img1 and img2 have range [0, 255]
 12 |     #img1 = img1.squeeze()
 13 |     #img2 = img2.squeeze()
 14 |     if not img1.shape == img2.shape:
 15 |         raise ValueError('Input images must have the same dimensions.')
 16 |     h, w = img1.shape[:2]
 17 |     img1 = img1[border:h-border, border:w-border]
 18 |     img2 = img2[border:h-border, border:w-border]
 19 | 
 20 |     img1 = img1.astype(np.float64)
 21 |     img2 = img2.astype(np.float64)
 22 |     mse = np.mean((img1 - img2)**2)
 23 |     if mse == 0:
 24 |         return float('inf')
 25 |     return 20 * math.log10(255.0 / math.sqrt(mse))
 26 | 
 27 | 
 28 | # --------------------------------------------
 29 | # SSIM
 30 | # --------------------------------------------
 31 | def calculate_ssim(img1, img2, border=0):
 32 |     '''calculate SSIM
 33 |     the same outputs as MATLAB's
 34 |     img1, img2: [0, 255]
 35 |     '''
 36 |     #img1 = img1.squeeze()
 37 |     #img2 = img2.squeeze()
 38 |     if not img1.shape == img2.shape:
 39 |         raise ValueError('Input images must have the same dimensions.')
 40 |     h, w = img1.shape[:2]
 41 |     img1 = img1[border:h-border, border:w-border]
 42 |     img2 = img2[border:h-border, border:w-border]
 43 | 
 44 |     if img1.ndim == 2:
 45 |         return ssim(img1, img2)
 46 |     elif img1.ndim == 3:
 47 |         if img1.shape[2] == 3:
 48 |             ssims = []
 49 |             for i in range(3):
 50 |                 ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
 51 |             return np.array(ssims).mean()
 52 |         elif img1.shape[2] == 1:
 53 |             return ssim(np.squeeze(img1), np.squeeze(img2))
 54 |     else:
 55 |         raise ValueError('Wrong input image dimensions.')
 56 | 
 57 | 
 58 | def ssim(img1, img2):
 59 |     C1 = (0.01 * 255)**2
 60 |     C2 = (0.03 * 255)**2
 61 | 
 62 |     img1 = img1.astype(np.float64)
 63 |     img2 = img2.astype(np.float64)
 64 |     kernel =
cv2.getGaussianKernel(11, 1.5) 65 | window = np.outer(kernel, kernel.transpose()) 66 | 67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 69 | mu1_sq = mu1**2 70 | mu2_sq = mu2**2 71 | mu1_mu2 = mu1 * mu2 72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 75 | 76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 77 | (sigma1_sq + sigma2_sq + C2)) 78 | return ssim_map.mean() 79 | 80 | def load_img(filepath): 81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 82 | 83 | def save_img(filepath, img): 84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 85 | 86 | def load_gray_img(filepath): 87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2) 88 | 89 | def save_gray_img(filepath, img): 90 | cv2.imwrite(filepath, img) 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Seeing the Unseen: A Frequency Prompt Guided Transformer for Image Restoration (ECCV 2024) 2 | 3 | [Shihao Zhou](https://joshyzhou.github.io/), [Jinshan Pan](https://jspan.github.io/), [Jinglei Shi](https://jingleishi.github.io/), [Duosheng Chen](https://github.com/Calvin11311), [Lishen Qu](https://github.com/qulishen) and [Jufeng Yang](https://cv.nankai.edu.cn/) 4 | 5 | #### News 6 | - **Jul 02, 2024:** FPro has been accepted to ECCV 2024 :tada: 7 |
8 | 9 | 10 | ## Training 11 | ### Derain 12 | To train FPro on SPAD, you can run: 13 | ```sh 14 | ./train.sh Deraining/Options/Deraining_FPro_spad.yml 15 | ``` 16 | ### Dehaze 17 | To train FPro on SOTS, you can run: 18 | ```sh 19 | ./train.sh Dehaze/Options/RealDehazing_FPro.yml 20 | ``` 21 | ### Deblur 22 | To train FPro on GoPro, you can run: 23 | ```sh 24 | ./train.sh Motion_Deblurring/Options/Deblurring_FPro.yml 25 | ``` 26 | ### Deraindrop 27 | To train FPro on AGAN, you can run: 28 | ```sh 29 | ./train.sh Deraining/Options/RealDeraindrop_FPro.yml 30 | ``` 31 | ### Demoire 32 | To train FPro on TIP18, you can run: 33 | ```sh 34 | ./train.sh Demoiring/Options/RealDemoiring_FPro.yml 35 | ``` 36 | 37 | ## Evaluation 38 | To evaluate FPro, you can refer commands in 'test.sh' 39 | 40 | For evaluate on each dataset, you should uncomment corresponding line. 41 | 42 | 43 | ## Results 44 | Experiments are performed for different image processing tasks including, rain streak removal, raindrop removal, haze removal, motion blur removal, and moire pattern removal. 45 | Here is a summary table containing hyperlinks for easy navigation: 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
BenchmarkPretrained modelVisual Results
SPAD(code:gd8j)(code:ntgp)
AGAN(code:dqml)(code:ul55)
SOTS(code:aagq)(code:9ssj)
GoPro(code:lhds)(code:764e)
TIP18(code:l13v)(code:9und)
79 | 
80 | 
81 | ## Citation
82 | If you find this project useful, please consider citing:
83 | 
84 |     @inproceedings{zhou_ECCV2024_FPro,
85 |         title={Seeing the Unseen: A Frequency Prompt Guided Transformer for Image Restoration},
86 |         author={Zhou, Shihao and Pan, Jinshan and Shi, Jinglei and Chen, Duosheng and Qu, Lishen and Yang, Jufeng},
87 |         booktitle={ECCV},
88 |         year={2024}
89 |     }
90 | 
91 | ## Acknowledgement
92 | 
93 | This code borrows heavily from [Restormer](https://github.com/swz30/Restormer).
--------------------------------------------------------------------------------
/basicsr/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/.DS_Store
--------------------------------------------------------------------------------
/basicsr/__pycache__/version.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/__pycache__/version.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/.DS_Store
--------------------------------------------------------------------------------
/basicsr/data/__init__.py:
--------------------------------------------------------------------------------
  1 | import importlib
  2 | import numpy as np
  3 | import random
  4 | import torch
  5 | import torch.utils.data
  6 | from functools import partial
  7 | from os import path as osp
  8 | 
  9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
 10 | from basicsr.utils import get_root_logger, scandir
 11 | from basicsr.utils.dist_util import get_dist_info
 12 | 
 13 | __all__ = ['create_dataset', 'create_dataloader']
 14 | 
 15 | # automatically scan and import dataset modules
 16 | # scan all the files under the data folder with '_dataset' in file names
 17 | data_folder = osp.dirname(osp.abspath(__file__))
 18 | dataset_filenames = [
 19 |     osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
 20 |     if v.endswith('_dataset.py')
 21 | ]
 22 | # import all the dataset modules
 23 | _dataset_modules = [
 24 |     importlib.import_module(f'basicsr.data.{file_name}')
 25 |     for file_name in dataset_filenames
 26 | ]
 27 | 
 28 | 
 29 | def create_dataset(dataset_opt):
 30 |     """Create dataset.
 31 | 
 32 |     Args:
 33 |         dataset_opt (dict): Configuration for dataset. It contains:
 34 |             name (str): Dataset name.
 35 |             type (str): Dataset type.
 36 |     """
 37 |     dataset_type = dataset_opt['type']
 38 | 
 39 |     # dynamic instantiation
 40 |     for module in _dataset_modules:
 41 |         dataset_cls = getattr(module, dataset_type, None)
 42 |         if dataset_cls is not None:
 43 |             break
 44 |     if dataset_cls is None:
 45 |         raise ValueError(f'Dataset {dataset_type} is not found.')
 46 | 
 47 |     dataset = dataset_cls(dataset_opt)
 48 | 
 49 |     logger = get_root_logger()
 50 |     logger.info(
 51 |         f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
 52 |         'is created.')
 53 |     return dataset
 54 | 
 55 | 
 56 | def create_dataloader(dataset,
 57 |                       dataset_opt,
 58 |                       num_gpu=1,
 59 |                       dist=False,
 60 |                       sampler=None,
 61 |                       seed=None):
 62 |     """Create dataloader.
 63 | 
 64 |     Args:
 65 |         dataset (torch.utils.data.Dataset): Dataset.
 66 |         dataset_opt (dict): Dataset options.
It contains the following keys: 67 | phase (str): 'train' or 'val'. 68 | num_worker_per_gpu (int): Number of workers for each GPU. 69 | batch_size_per_gpu (int): Training batch size for each GPU. 70 | num_gpu (int): Number of GPUs. Used only in the train phase. 71 | Default: 1. 72 | dist (bool): Whether in distributed training. Used only in the train 73 | phase. Default: False. 74 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 75 | seed (int | None): Seed. Default: None 76 | """ 77 | phase = dataset_opt['phase'] 78 | rank, _ = get_dist_info() 79 | if phase == 'train': 80 | if dist: # distributed training 81 | batch_size = dataset_opt['batch_size_per_gpu'] 82 | num_workers = dataset_opt['num_worker_per_gpu'] 83 | else: # non-distributed training 84 | multiplier = 1 if num_gpu == 0 else num_gpu 85 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 86 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 87 | dataloader_args = dict( 88 | dataset=dataset, 89 | batch_size=batch_size, 90 | shuffle=False, 91 | num_workers=num_workers, 92 | sampler=sampler, 93 | drop_last=True) 94 | if sampler is None: 95 | dataloader_args['shuffle'] = True 96 | dataloader_args['worker_init_fn'] = partial( 97 | worker_init_fn, num_workers=num_workers, rank=rank, 98 | seed=seed) if seed is not None else None 99 | elif phase in ['val', 'test']: # validation 100 | dataloader_args = dict( 101 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 102 | else: 103 | raise ValueError(f'Wrong dataset phase: {phase}. ' 104 | "Supported ones are 'train', 'val' and 'test'.") 105 | 106 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 107 | 108 | prefetch_mode = dataset_opt.get('prefetch_mode') 109 | if prefetch_mode == 'cpu': # CPUPrefetcher 110 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 111 | logger = get_root_logger() 112 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' 113 | f'num_prefetch_queue = {num_prefetch_queue}') 114 | return PrefetchDataLoader( 115 | num_prefetch_queue=num_prefetch_queue, **dataloader_args) 116 | else: 117 | # prefetch_mode=None: Normal dataloader 118 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 119 | return torch.utils.data.DataLoader(**dataloader_args) 120 | 121 | 122 | def worker_init_fn(worker_id, num_workers, rank, seed): 123 | # Set the worker seed to num_workers * rank + worker_id + seed 124 | worker_seed = num_workers * rank + worker_id + seed 125 | np.random.seed(worker_seed) 126 | random.seed(worker_seed) 127 | -------------------------------------------------------------------------------- /basicsr/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/data_sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/data_sampler.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/data_util.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/data_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/reds_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/reds_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 
 8 | 
 9 |     Modified from torch.utils.data.distributed.DistributedSampler.
10 |     Supports enlarging the dataset for iteration-based training, saving
11 |     time when restarting the dataloader after each epoch.
12 | 
13 |     Args:
14 |         dataset (torch.utils.data.Dataset): Dataset used for sampling.
15 |         num_replicas (int | None): Number of processes participating in
16 |             the training. It is usually the world_size.
17 |         rank (int | None): Rank of the current process within num_replicas.
18 |         ratio (int): Enlarging ratio. Default: 1.
19 |     """
20 | 
21 |     def __init__(self, dataset, num_replicas, rank, ratio=1):
22 |         self.dataset = dataset
23 |         self.num_replicas = num_replicas
24 |         self.rank = rank
25 |         self.epoch = 0
26 |         self.num_samples = math.ceil(
27 |             len(self.dataset) * ratio / self.num_replicas)
28 |         self.total_size = self.num_samples * self.num_replicas
29 | 
30 |     def __iter__(self):
31 |         # deterministically shuffle based on epoch
32 |         g = torch.Generator()
33 |         g.manual_seed(self.epoch)
34 |         indices = torch.randperm(self.total_size, generator=g).tolist()
35 | 
36 |         dataset_size = len(self.dataset)
37 |         indices = [v % dataset_size for v in indices]
38 | 
39 |         # subsample
40 |         indices = indices[self.rank:self.total_size:self.num_replicas]
41 |         assert len(indices) == self.num_samples
42 | 
43 |         return iter(indices)
44 | 
45 |     def __len__(self):
46 |         return self.num_samples
47 | 
48 |     def set_epoch(self, epoch):
49 |         self.epoch = epoch
50 | 
--------------------------------------------------------------------------------
/basicsr/data/ffhq_dataset.py:
--------------------------------------------------------------------------------
 1 | from os import path as osp
 2 | from torch.utils import data as data
 3 | from torchvision.transforms.functional import normalize
 4 | 
 5 | from basicsr.data.transforms import augment
 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor
 7 | 
 8 | 
 9 | class FFHQDataset(data.Dataset):
10 |     """FFHQ dataset for StyleGAN.
11 | 
12 |     Args:
13 |         opt (dict): Config for train datasets. It contains the following keys:
14 |             dataroot_gt (str): Data root path for gt.
15 |             io_backend (dict): IO backend type and other kwarg.
16 |             mean (list | tuple): Image mean.
17 |             std (list | tuple): Image std.
18 |             use_hflip (bool): Whether to horizontally flip.
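
    Example (an illustrative opt; the path and normalization values are
    placeholders, not settings shipped with this repo):
        opt = {'dataroot_gt': 'datasets/ffhq',
               'io_backend': {'type': 'disk'},
               'mean': [0.5, 0.5, 0.5],
               'std': [0.5, 0.5, 0.5],
               'use_hflip': True}
        dataset = FFHQDataset(opt)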
19 | 20 | """ 21 | 22 | def __init__(self, opt): 23 | super(FFHQDataset, self).__init__() 24 | self.opt = opt 25 | # file client (io backend) 26 | self.file_client = None 27 | self.io_backend_opt = opt['io_backend'] 28 | 29 | self.gt_folder = opt['dataroot_gt'] 30 | self.mean = opt['mean'] 31 | self.std = opt['std'] 32 | 33 | if self.io_backend_opt['type'] == 'lmdb': 34 | self.io_backend_opt['db_paths'] = self.gt_folder 35 | if not self.gt_folder.endswith('.lmdb'): 36 | raise ValueError("'dataroot_gt' should end with '.lmdb', " 37 | f'but received {self.gt_folder}') 38 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 39 | self.paths = [line.split('.')[0] for line in fin] 40 | else: 41 | # FFHQ has 70000 images in total 42 | self.paths = [ 43 | osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000) 44 | ] 45 | 46 | def __getitem__(self, index): 47 | if self.file_client is None: 48 | self.file_client = FileClient( 49 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | img_bytes = self.file_client.get(gt_path) 54 | img_gt = imfrombytes(img_bytes, float32=True) 55 | 56 | # random horizontal flip 57 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 58 | # BGR to RGB, HWC to CHW, numpy to tensor 59 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 60 | # normalize 61 | normalize(img_gt, self.mean, self.std, inplace=True) 62 | return {'gt': img_gt, 'gt_path': gt_path} 63 | 64 | def __len__(self): 65 | return len(self.paths) 66 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 
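
    Example (a minimal usage sketch; `loader` is any iterable dataloader):
        prefetcher = CPUPrefetcher(loader)
        batch = prefetcher.next()   # returns None once the loader is exhausted
        prefetcher.reset()          # restart iteration from the beginning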
 68 |     """
 69 | 
 70 |     def __init__(self, loader):
 71 |         self.ori_loader = loader
 72 |         self.loader = iter(loader)
 73 | 
 74 |     def next(self):
 75 |         try:
 76 |             return next(self.loader)
 77 |         except StopIteration:
 78 |             return None
 79 | 
 80 |     def reset(self):
 81 |         self.loader = iter(self.ori_loader)
 82 | 
 83 | 
 84 | class CUDAPrefetcher():
 85 |     """CUDA prefetcher.
 86 | 
 87 |     Ref:
 88 |     https://github.com/NVIDIA/apex/issues/304#
 89 | 
 90 |     It may consume more GPU memory.
 91 | 
 92 |     Args:
 93 |         loader: Dataloader.
 94 |         opt (dict): Options.
 95 |     """
 96 | 
 97 |     def __init__(self, loader, opt):
 98 |         self.ori_loader = loader
 99 |         self.loader = iter(loader)
100 |         self.opt = opt
101 |         self.stream = torch.cuda.Stream()
102 |         self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103 |         self.preload()
104 | 
105 |     def preload(self):
106 |         try:
107 |             self.batch = next(self.loader)  # self.batch is a dict
108 |         except StopIteration:
109 |             self.batch = None
110 |             return None
111 |         # put tensors to gpu
112 |         with torch.cuda.stream(self.stream):
113 |             for k, v in self.batch.items():
114 |                 if torch.is_tensor(v):
115 |                     self.batch[k] = self.batch[k].to(
116 |                         device=self.device, non_blocking=True)
117 | 
118 |     def next(self):
119 |         torch.cuda.current_stream().wait_stream(self.stream)
120 |         batch = self.batch
121 |         self.preload()
122 |         return batch
123 | 
124 |     def reset(self):
125 |         self.loader = iter(self.ori_loader)
126 |         self.preload()
127 | 
--------------------------------------------------------------------------------
/basicsr/data/single_image_dataset.py:
--------------------------------------------------------------------------------
 1 | from os import path as osp
 2 | from torch.utils import data as data
 3 | from torchvision.transforms.functional import normalize
 4 | 
 5 | from basicsr.data.data_util import paths_from_lmdb
 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
 7 | 
 8 | 
 9 | class SingleImageDataset(data.Dataset):
10 |     """Read only lq images in the test phase.
11 | 
12 |     Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc.).
13 | 
14 |     There are two modes:
15 |     1. 'meta_info_file': Use meta information file to generate paths.
16 |     2. 'folder': Scan folders to generate paths.
17 | 
18 |     Args:
19 |         opt (dict): Config for train datasets. It contains the following keys:
20 |             dataroot_lq (str): Data root path for lq.
21 |             meta_info_file (str): Path for meta information file.
22 |             io_backend (dict): IO backend type and other kwarg.
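
    Example (an illustrative 'folder'-mode opt; the path is a placeholder):
        opt = {'dataroot_lq': 'datasets/test/lq',
               'io_backend': {'type': 'disk'}}
        dataset = SingleImageDataset(opt)
        sample = dataset[0]  # {'lq': tensor, 'lq_path': str}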
23 |     """
24 | 
25 |     def __init__(self, opt):
26 |         super(SingleImageDataset, self).__init__()
27 |         self.opt = opt
28 |         # file client (io backend)
29 |         self.file_client = None
30 |         self.io_backend_opt = opt['io_backend']
31 |         self.mean = opt['mean'] if 'mean' in opt else None
32 |         self.std = opt['std'] if 'std' in opt else None
33 |         self.lq_folder = opt['dataroot_lq']
34 | 
35 |         if self.io_backend_opt['type'] == 'lmdb':
36 |             self.io_backend_opt['db_paths'] = [self.lq_folder]
37 |             self.io_backend_opt['client_keys'] = ['lq']
38 |             self.paths = paths_from_lmdb(self.lq_folder)
39 |         elif 'meta_info_file' in self.opt:
40 |             with open(self.opt['meta_info_file'], 'r') as fin:
41 |                 self.paths = [
42 |                     osp.join(self.lq_folder,
43 |                              line.split(' ')[0]) for line in fin
44 |                 ]
45 |         else:
46 |             self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
47 | 
48 |     def __getitem__(self, index):
49 |         if self.file_client is None:
50 |             self.file_client = FileClient(
51 |                 self.io_backend_opt.pop('type'), **self.io_backend_opt)
52 | 
53 |         # load lq image
54 |         lq_path = self.paths[index]
55 |         img_bytes = self.file_client.get(lq_path, 'lq')
56 |         img_lq = imfrombytes(img_bytes, float32=True)
57 | 
58 |         # TODO: color space transform
59 |         # BGR to RGB, HWC to CHW, numpy to tensor
60 |         img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
61 |         # normalize
62 |         if self.mean is not None or self.std is not None:
63 |             normalize(img_lq, self.mean, self.std, inplace=True)
64 |         return {'lq': img_lq, 'lq_path': lq_path}
65 | 
66 |     def __len__(self):
67 |         return len(self.paths)
68 | 
--------------------------------------------------------------------------------
/basicsr/data/vimeo90k_dataset.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | import torch
 3 | from pathlib import Path
 4 | from torch.utils import data as data
 5 | 
 6 | from basicsr.data.transforms import augment, paired_random_crop
 7 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
 8 | 
 9 | 
10 | class Vimeo90KDataset(data.Dataset):
11 |     """Vimeo90K dataset for training.
12 | 
13 |     The keys are generated from a meta info txt file.
14 |     basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
15 | 
16 |     Each line contains:
17 |     1. clip name; 2. frame number; 3. image shape, separated by a white space.
18 |     Examples:
19 |         00001/0001 7 (256,448,3)
20 |         00001/0002 7 (256,448,3)
21 | 
22 |     Key examples: "00001/0001"
23 |     GT (gt): Ground-Truth;
24 |     LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
25 | 
26 |     The neighboring frame list for different num_frame:
27 |     num_frame | frame list
28 |             1 | 4
29 |             3 | 3,4,5
30 |             5 | 2,3,4,5,6
31 |             7 | 1,2,3,4,5,6,7
32 | 
33 |     Args:
34 |         opt (dict): Config for train dataset. It contains the following keys:
35 |             dataroot_gt (str): Data root path for gt.
36 |             dataroot_lq (str): Data root path for lq.
37 |             meta_info_file (str): Path for meta information file.
38 |             io_backend (dict): IO backend type and other kwarg.
39 | 
40 |             num_frame (int): Window size for input frames.
41 |             gt_size (int): Cropped patched size for gt patches.
42 |             random_reverse (bool): Random reverse input frames.
43 |             use_flip (bool): Use horizontal flips.
44 |             use_rot (bool): Use rotation (use vertical flip and transposing h
45 |                 and w for implementation).
46 | 
47 |             scale (bool): Scale, which will be added automatically.
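
    Example (an illustrative opt built only from the keys above; paths are
    placeholders):
        opt = {'dataroot_gt': 'datasets/vimeo90k/GT',
               'dataroot_lq': 'datasets/vimeo90k/LQ',
               'meta_info_file': 'basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt',
               'io_backend': {'type': 'disk'},
               'num_frame': 7, 'gt_size': 256, 'random_reverse': False,
               'use_flip': True, 'use_rot': True, 'scale': 1}
        dataset = Vimeo90KDataset(opt)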
48 | """ 49 | 50 | def __init__(self, opt): 51 | super(Vimeo90KDataset, self).__init__() 52 | self.opt = opt 53 | self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path( 54 | opt['dataroot_lq']) 55 | 56 | with open(opt['meta_info_file'], 'r') as fin: 57 | self.keys = [line.split(' ')[0] for line in fin] 58 | 59 | # file client (io backend) 60 | self.file_client = None 61 | self.io_backend_opt = opt['io_backend'] 62 | self.is_lmdb = False 63 | if self.io_backend_opt['type'] == 'lmdb': 64 | self.is_lmdb = True 65 | self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] 66 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 67 | 68 | # indices of input images 69 | self.neighbor_list = [ 70 | i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame']) 71 | ] 72 | 73 | # temporal augmentation configs 74 | self.random_reverse = opt['random_reverse'] 75 | logger = get_root_logger() 76 | logger.info(f'Random reverse is {self.random_reverse}.') 77 | 78 | def __getitem__(self, index): 79 | if self.file_client is None: 80 | self.file_client = FileClient( 81 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 82 | 83 | # random reverse 84 | if self.random_reverse and random.random() < 0.5: 85 | self.neighbor_list.reverse() 86 | 87 | scale = self.opt['scale'] 88 | gt_size = self.opt['gt_size'] 89 | key = self.keys[index] 90 | clip, seq = key.split('/') # key example: 00001/0001 91 | 92 | # get the GT frame (im4.png) 93 | if self.is_lmdb: 94 | img_gt_path = f'{key}/im4' 95 | else: 96 | img_gt_path = self.gt_root / clip / seq / 'im4.png' 97 | img_bytes = self.file_client.get(img_gt_path, 'gt') 98 | img_gt = imfrombytes(img_bytes, float32=True) 99 | 100 | # get the neighboring LQ frames 101 | img_lqs = [] 102 | for neighbor in self.neighbor_list: 103 | if self.is_lmdb: 104 | img_lq_path = f'{clip}/{seq}/im{neighbor}' 105 | else: 106 | img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' 107 | img_bytes = self.file_client.get(img_lq_path, 'lq') 108 | img_lq = imfrombytes(img_bytes, float32=True) 109 | img_lqs.append(img_lq) 110 | 111 | # randomly crop 112 | img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, 113 | img_gt_path) 114 | 115 | # augmentation - flip, rotate 116 | img_lqs.append(img_gt) 117 | img_results = augment(img_lqs, self.opt['use_flip'], 118 | self.opt['use_rot']) 119 | 120 | img_results = img2tensor(img_results) 121 | img_lqs = torch.stack(img_results[0:-1], dim=0) 122 | img_gt = img_results[-1] 123 | 124 | # img_lqs: (t, c, h, w) 125 | # img_gt: (c, h, w) 126 | # key: str 127 | return {'lq': img_lqs, 'gt': img_gt, 'key': key} 128 | 129 | def __len__(self): 130 | return len(self.keys) 131 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .niqe import calculate_niqe 2 | from .psnr_ssim import calculate_psnr, calculate_ssim 3 | 4 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 5 | -------------------------------------------------------------------------------- /basicsr/metrics/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/metrics/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/metrics/__pycache__/metric_util.cpython-37.pyc: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/metrics/__pycache__/metric_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/niqe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/metrics/__pycache__/niqe.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/fid.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | import torch.nn as nn
 4 | from scipy import linalg
 5 | from tqdm import tqdm
 6 | 
 7 | from basicsr.models.archs.inception import InceptionV3
 8 | 
 9 | 
10 | def load_patched_inception_v3(device='cuda',
11 |                               resize_input=True,
12 |                               normalize_input=False):
13 |     # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
14 |     # does resize the input.
15 |     inception = InceptionV3([3],
16 |                             resize_input=resize_input,
17 |                             normalize_input=normalize_input)
18 |     inception = nn.DataParallel(inception).eval().to(device)
19 |     return inception
20 | 
21 | 
22 | @torch.no_grad()
23 | def extract_inception_features(data_generator,
24 |                                inception,
25 |                                len_generator=None,
26 |                                device='cuda'):
27 |     """Extract inception features.
28 | 
29 |     Args:
30 |         data_generator (generator): A data generator.
31 |         inception (nn.Module): Inception model.
32 |         len_generator (int): Length of the data_generator to show the
33 |             progressbar. Default: None.
34 |         device (str): Device. Default: cuda.
35 | 
36 |     Returns:
37 |         Tensor: Extracted features.
38 |     """
39 |     if len_generator is not None:
40 |         pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
41 |     else:
42 |         pbar = None
43 |     features = []
44 | 
45 |     for data in data_generator:
46 |         if pbar:
47 |             pbar.update(1)
48 |         data = data.to(device)
49 |         feature = inception(data)[0].view(data.shape[0], -1)
50 |         features.append(feature.to('cpu'))
51 |     if pbar:
52 |         pbar.close()
53 |     features = torch.cat(features, 0)
54 |     return features
55 | 
56 | 
57 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
58 |     """Numpy implementation of the Frechet Distance.
59 | 
60 |     The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
61 |     and X_2 ~ N(mu_2, C_2) is
62 |         d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
63 |     Stable version by Dougal J. Sutherland.
64 | 
65 |     Args:
66 |         mu1 (np.array): The sample mean over activations.
67 |         sigma1 (np.array): The covariance matrix over activations for
68 |             generated samples.
69 |         mu2 (np.array): The sample mean over activations, precalculated on a
70 |             representative data set.
71 |         sigma2 (np.array): The covariance matrix over activations,
72 |             precalculated on a representative data set.
73 | 
74 |     Returns:
75 |         float: The Frechet Distance.
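
    Example (a sanity check that follows from the formula above: identical
    Gaussians have zero distance, up to numerical precision):
        >>> mu = np.zeros(2048)
        >>> sigma = np.eye(2048)
        >>> calculate_fid(mu, sigma, mu, sigma)  # -> 0.0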
 76 |     """
 77 |     assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
 78 |     assert sigma1.shape == sigma2.shape, (
 79 |         'Two covariances have different dimensions')
 80 | 
 81 |     cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
 82 | 
 83 |     # Product might be almost singular
 84 |     if not np.isfinite(cov_sqrt).all():
 85 |         print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
 86 |               'of cov estimates')
 87 |         offset = np.eye(sigma1.shape[0]) * eps
 88 |         cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
 89 | 
 90 |     # Numerical error might give slight imaginary component
 91 |     if np.iscomplexobj(cov_sqrt):
 92 |         if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
 93 |             m = np.max(np.abs(cov_sqrt.imag))
 94 |             raise ValueError(f'Imaginary component {m}')
 95 |         cov_sqrt = cov_sqrt.real
 96 | 
 97 |     mean_diff = mu1 - mu2
 98 |     mean_norm = mean_diff @ mean_diff
 99 |     trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
100 |     fid = mean_norm + trace
101 | 
102 |     return fid
103 | 
--------------------------------------------------------------------------------
/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | from basicsr.utils.matlab_functions import bgr2ycbcr
 4 | 
 5 | 
 6 | def reorder_image(img, input_order='HWC'):
 7 |     """Reorder images to 'HWC' order.
 8 | 
 9 |     If the input_order is (h, w), return (h, w, 1);
10 |     If the input_order is (c, h, w), return (h, w, c);
11 |     If the input_order is (h, w, c), return as it is.
12 | 
13 |     Args:
14 |         img (ndarray): Input image.
15 |         input_order (str): Whether the input order is 'HWC' or 'CHW'.
16 |             If the input image shape is (h, w), input_order will not have
17 |             effects. Default: 'HWC'.
18 | 
19 |     Returns:
20 |         ndarray: reordered image.
21 |     """
22 | 
23 |     if input_order not in ['HWC', 'CHW']:
24 |         raise ValueError(
25 |             f'Wrong input_order {input_order}. Supported input_orders are '
26 |             "'HWC' and 'CHW'")
27 |     if len(img.shape) == 2:
28 |         img = img[..., None]
29 |     if input_order == 'CHW':
30 |         img = img.transpose(1, 2, 0)
31 |     return img
32 | 
33 | 
34 | def to_y_channel(img):
35 |     """Change to Y channel of YCbCr.
36 | 
37 |     Args:
38 |         img (ndarray): Images with range [0, 255].
39 | 
40 |     Returns:
41 |         (ndarray): Images with range [0, 255] (float type) without round.
42 |     """
43 |     img = img.astype(np.float32) / 255.
44 |     if img.ndim == 3 and img.shape[2] == 3:
45 |         img = bgr2ycbcr(img, y_only=True)
46 |         img = img[..., None]
47 |     return img * 255.
48 | 
--------------------------------------------------------------------------------
/basicsr/metrics/niqe.py:
--------------------------------------------------------------------------------
 1 | import cv2
 2 | import math
 3 | import numpy as np
 4 | from scipy.ndimage.filters import convolve
 5 | from scipy.special import gamma
 6 | 
 7 | from basicsr.metrics.metric_util import reorder_image, to_y_channel
 8 | 
 9 | 
10 | def estimate_aggd_param(block):
11 |     """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.
12 | 
13 |     Args:
14 |         block (ndarray): 2D Image block.
15 | 
16 |     Returns:
17 |         tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
18 |             distribution (estimating the parameters in Equation 7 in the paper).
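
    Example (illustrative; any roughly zero-mean 2D block works, such as the
    normalized blocks produced inside niqe() below):
        >>> block = np.random.randn(96, 96)
        >>> alpha, beta_l, beta_r = estimate_aggd_param(block)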
19 |     """
20 |     block = block.flatten()
21 |     gam = np.arange(0.2, 10.001, 0.001)  # len = 9801
22 |     gam_reciprocal = np.reciprocal(gam)
23 |     r_gam = np.square(gamma(gam_reciprocal * 2)) / (
24 |         gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
25 | 
26 |     left_std = np.sqrt(np.mean(block[block < 0]**2))
27 |     right_std = np.sqrt(np.mean(block[block > 0]**2))
28 |     gammahat = left_std / right_std
29 |     rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
30 |     rhatnorm = (rhat * (gammahat**3 + 1) *
31 |                 (gammahat + 1)) / ((gammahat**2 + 1)**2)
32 |     array_position = np.argmin((r_gam - rhatnorm)**2)
33 | 
34 |     alpha = gam[array_position]
35 |     beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
36 |     beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
37 |     return (alpha, beta_l, beta_r)
38 | 
39 | 
40 | def compute_feature(block):
41 |     """Compute features.
42 | 
43 |     Args:
44 |         block (ndarray): 2D Image block.
45 | 
46 |     Returns:
47 |         list: Features with length of 18.
48 |     """
49 |     feat = []
50 |     alpha, beta_l, beta_r = estimate_aggd_param(block)
51 |     feat.extend([alpha, (beta_l + beta_r) / 2])
52 | 
53 |     # distortions disturb the fairly regular structure of natural images.
54 |     # This deviation can be captured by analyzing the sample distribution of
55 |     # the products of pairs of adjacent coefficients computed along
56 |     # horizontal, vertical and diagonal orientations.
57 |     shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
58 |     for i in range(len(shifts)):
59 |         shifted_block = np.roll(block, shifts[i], axis=(0, 1))
60 |         alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
61 |         # Eq. 8
62 |         mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
63 |         feat.extend([alpha, mean, beta_l, beta_r])
64 |     return feat
65 | 
66 | 
67 | def niqe(img,
68 |          mu_pris_param,
69 |          cov_pris_param,
70 |          gaussian_window,
71 |          block_size_h=96,
72 |          block_size_w=96):
73 |     """Calculate NIQE (Natural Image Quality Evaluator) metric.
74 | 
75 |     Ref: Making a "Completely Blind" Image Quality Analyzer.
76 |     This implementation could produce almost the same results as the official
77 |     MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
78 | 
79 |     Note that we do not include block overlap height and width, since they are
80 |     always 0 in the official implementation.
81 | 
82 |     For good performance, the official implementation advises dividing the
83 |     distorted image into patches of the same size as those used for the
84 |     construction of the multivariate Gaussian model.
85 | 
86 |     Args:
87 |         img (ndarray): Input image whose quality needs to be computed. The
88 |             image must be a gray or Y (of YCbCr) image with shape (h, w).
89 |             Range [0, 255] with float type.
90 |         mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
91 |             model calculated on the pristine dataset.
92 |         cov_pris_param (ndarray): Covariance of a pre-defined multivariate
93 |             Gaussian model calculated on the pristine dataset.
94 |         gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
95 |             image.
96 |         block_size_h (int): Height of the blocks into which the image is
97 |             divided. Default: 96 (the official recommended value).
98 |         block_size_w (int): Width of the blocks into which the image is
99 |             divided. Default: 96 (the official recommended value).
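
    Returns:
        float: The NIQE score of the input image (lower means better
            perceptual quality).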
100 |     """
101 |     assert img.ndim == 2, (
102 |         'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
103 |     # crop image
104 |     h, w = img.shape
105 |     num_block_h = math.floor(h / block_size_h)
106 |     num_block_w = math.floor(w / block_size_w)
107 |     img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
108 | 
109 |     distparam = []  # dist param is actually the multiscale features
110 |     for scale in (1, 2):  # perform on two scales (1, 2)
111 |         mu = convolve(img, gaussian_window, mode='nearest')
112 |         sigma = np.sqrt(
113 |             np.abs(
114 |                 convolve(np.square(img), gaussian_window, mode='nearest') -
115 |                 np.square(mu)))
116 |         # normalize, as in Eq. 1 in the paper
117 |         img_normalized = (img - mu) / (sigma + 1)
118 | 
119 |         feat = []
120 |         for idx_w in range(num_block_w):
121 |             for idx_h in range(num_block_h):
122 |                 # process each block
123 |                 block = img_normalized[idx_h * block_size_h //
124 |                                        scale:(idx_h + 1) * block_size_h //
125 |                                        scale, idx_w * block_size_w //
126 |                                        scale:(idx_w + 1) * block_size_w //
127 |                                        scale]
128 |                 feat.append(compute_feature(block))
129 | 
130 |         distparam.append(np.array(feat))
131 |         # TODO: matlab bicubic downsample with anti-aliasing
132 |         # for simplicity, now we use opencv instead, which will result in
133 |         # a slight difference.
134 |         if scale == 1:
135 |             h, w = img.shape
136 |             img = cv2.resize(
137 |                 img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
138 |             img = img * 255.
139 | 
140 |     distparam = np.concatenate(distparam, axis=1)
141 | 
142 |     # fit a MVG (multivariate Gaussian) model to distorted patch features
143 |     mu_distparam = np.nanmean(distparam, axis=0)
144 |     # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
145 |     distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
146 |     cov_distparam = np.cov(distparam_no_nan, rowvar=False)
147 | 
148 |     # compute niqe quality, Eq. 10 in the paper
149 |     invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
150 |     quality = np.matmul(
151 |         np.matmul((mu_pris_param - mu_distparam), invcov_param),
152 |         np.transpose((mu_pris_param - mu_distparam)))
153 |     quality = np.sqrt(quality)
154 | 
155 |     return quality
156 | 
157 | 
158 | def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
159 |     """Calculate NIQE (Natural Image Quality Evaluator) metric.
160 | 
161 |     Ref: Making a "Completely Blind" Image Quality Analyzer.
162 |     This implementation could produce almost the same results as the official
163 |     MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
164 | 
165 |     We use the official params estimated from the pristine dataset.
166 |     We use the recommended block size (96, 96) without overlaps.
167 | 
168 |     Args:
169 |         img (ndarray): Input image whose quality needs to be computed.
170 |             The input image must be in range [0, 255] with float/int type.
171 |             The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
172 |             If the input order is 'HWC' or 'CHW', it will be converted to gray
173 |             or Y (of YCbCr) image according to the ``convert_to`` argument.
174 |         crop_border (int): Cropped pixels in each edge of an image. These
175 |             pixels are not involved in the metric calculation.
176 |         input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
177 |             Default: 'HWC'.
178 |         convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
179 |             Default: 'y'.
180 | 
181 |     Returns:
182 |         float: NIQE result.
183 |     """
184 | 
185 |     # we use the official params estimated from the pristine dataset.
186 |     niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz')
187 |     mu_pris_param = niqe_pris_params['mu_pris_param']
188 |     cov_pris_param = niqe_pris_params['cov_pris_param']
189 |     gaussian_window = niqe_pris_params['gaussian_window']
190 | 
191 |     img = img.astype(np.float32)
192 |     if input_order != 'HW':
193 |         img = reorder_image(img, input_order=input_order)
194 |         if convert_to == 'y':
195 |             img = to_y_channel(img)
196 |         elif convert_to == 'gray':
197 |             img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
198 |         img = np.squeeze(img)
199 | 
200 |     if crop_border != 0:
201 |         img = img[crop_border:-crop_border, crop_border:-crop_border]
202 | 
203 |     niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
204 | 
205 |     return niqe_result
206 | 
--------------------------------------------------------------------------------
/basicsr/metrics/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/metrics/niqe_pris_params.npz
--------------------------------------------------------------------------------
/basicsr/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/.DS_Store
--------------------------------------------------------------------------------
/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
 1 | import importlib
 2 | from os import path as osp
 3 | 
 4 | from basicsr.utils import get_root_logger, scandir
 5 | 
 6 | # automatically scan and import model modules
 7 | # scan all the files under the 'models' folder and collect files ending with
 8 | # '_model.py'
 9 | model_folder = osp.dirname(osp.abspath(__file__))
10 | model_filenames = [
11 |     osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
12 |     if v.endswith('_model.py')
13 | ]
14 | # import all the model modules
15 | _model_modules = [
16 |     importlib.import_module(f'basicsr.models.{file_name}')
17 |     for file_name in model_filenames
18 | ]
19 | 
20 | 
21 | def create_model(opt):
22 |     """Create model.
23 | 
24 |     Args:
25 |         opt (dict): Configuration. It contains:
26 |             model_type (str): Model type.
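
    Returns:
        The instantiated model object.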
27 | """ 28 | model_type = opt['model_type'] 29 | 30 | # dynamic instantiation 31 | for module in _model_modules: 32 | model_cls = getattr(module, model_type, None) 33 | if model_cls is not None: 34 | break 35 | if model_cls is None: 36 | raise ValueError(f'Model {model_type} is not found.') 37 | 38 | model = model_cls(opt) 39 | 40 | logger = get_root_logger() 41 | logger.info(f'Model [{model.__class__.__name__}] is created.') 42 | return model 43 | -------------------------------------------------------------------------------- /basicsr/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import arch modules 7 | # scan all the files under the 'archs' folder and collect files ending with 8 | # '_arch.py' 9 | arch_folder = osp.dirname(osp.abspath(__file__)) 10 | arch_filenames = [ 11 | osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) 12 | if v.endswith('_arch.py') 13 | ] 14 | # import all the arch modules 15 | _arch_modules = [ 16 | importlib.import_module(f'basicsr.models.archs.{file_name}') 17 | for file_name in arch_filenames 18 | ] 19 | 20 | 21 | def dynamic_instantiation(modules, cls_type, opt): 22 | """Dynamically instantiate class. 23 | 24 | Args: 25 | modules (list[importlib modules]): List of modules from importlib 26 | files. 27 | cls_type (str): Class type. 28 | opt (dict): Class initialization kwargs. 29 | 30 | Returns: 31 | class: Instantiated class. 
32 | """ 33 | 34 | for module in modules: 35 | cls_ = getattr(module, cls_type, None) 36 | if cls_ is not None: 37 | break 38 | if cls_ is None: 39 | raise ValueError(f'{cls_type} is not found.') 40 | return cls_(**opt) 41 | 42 | 43 | def define_network(opt): 44 | network_type = opt.pop('type') 45 | net = dynamic_instantiation(_arch_modules, network_type, opt) 46 | return net 47 | -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/archs/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss,FFTLoss) 2 | 3 | __all__ = [ 4 | 'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss','FFTLoss', 5 | ] 6 | -------------------------------------------------------------------------------- /basicsr/models/losses/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/losses/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/losses/__pycache__/loss_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/losses/__pycache__/loss_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/losses/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/losses/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/models/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | 
import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 
86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /basicsr/models/losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from basicsr.models.losses.loss_util import weighted_loss 7 | 8 | _reduction_modes = ['none', 'mean', 'sum'] 9 | 10 | 11 | @weighted_loss 12 | def l1_loss(pred, target): 13 | return F.l1_loss(pred, target, reduction='none') 14 | 15 | 16 | @weighted_loss 17 | def mse_loss(pred, target): 18 | return F.mse_loss(pred, target, reduction='none') 19 | 20 | 21 | # @weighted_loss 22 | # def charbonnier_loss(pred, target, eps=1e-12): 23 | # return torch.sqrt((pred - target)**2 + eps) 24 | 25 | 26 | class L1Loss(nn.Module): 27 | """L1 (mean absolute error, MAE) loss. 28 | 29 | Args: 30 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 31 | reduction (str): Specifies the reduction to apply to the output. 32 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 33 | """ 34 | 35 | def __init__(self, loss_weight=1.0, reduction='mean'): 36 | super(L1Loss, self).__init__() 37 | if reduction not in ['none', 'mean', 'sum']: 38 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 39 | f'Supported ones are: {_reduction_modes}') 40 | 41 | self.loss_weight = loss_weight 42 | self.reduction = reduction 43 | 44 | def forward(self, pred, target, weight=None, **kwargs): 45 | """ 46 | Args: 47 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 48 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 49 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 50 | weights. Default: None. 51 | """ 52 | return self.loss_weight * l1_loss( 53 | pred, target, weight, reduction=self.reduction) 54 | 55 | 56 | class FFTLoss(nn.Module): 57 | """L1 loss in frequency domain with FFT. 58 | 59 | Args: 60 | loss_weight (float): Loss weight for FFT loss. Default: 1.0. 61 | reduction (str): Specifies the reduction to apply to the output. 62 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 63 | """ 64 | 65 | def __init__(self, loss_weight=1.0, reduction='mean'): 66 | super(FFTLoss, self).__init__() 67 | if reduction not in ['none', 'mean', 'sum']: 68 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') 69 | 70 | self.loss_weight = loss_weight 71 | self.reduction = reduction 72 | 73 | def forward(self, pred, target, weight=None, **kwargs): 74 | """ 75 | Args: 76 | pred (Tensor): of shape (..., C, H, W). Predicted tensor. 77 | target (Tensor): of shape (..., C, H, W). Ground truth tensor. 78 | weight (Tensor, optional): of shape (..., C, H, W). Element-wise 79 | weights. Default: None. 80 | """ 81 | 82 | pred_fft = torch.fft.fft2(pred, dim=(-2, -1)) 83 | pred_fft = torch.stack([pred_fft.real, pred_fft.imag], dim=-1) 84 | target_fft = torch.fft.fft2(target, dim=(-2, -1)) 85 | target_fft = torch.stack([target_fft.real, target_fft.imag], dim=-1) 86 | return self.loss_weight * l1_loss(pred_fft, target_fft, weight, reduction=self.reduction) 87 | 88 | 89 | class MSELoss(nn.Module): 90 | """MSE (L2) loss. 
91 | 
92 |     Args:
93 |         loss_weight (float): Loss weight for MSE loss. Default: 1.0.
94 |         reduction (str): Specifies the reduction to apply to the output.
95 |             Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
96 |     """
97 | 
98 |     def __init__(self, loss_weight=1.0, reduction='mean'):
99 |         super(MSELoss, self).__init__()
100 |         if reduction not in ['none', 'mean', 'sum']:
101 |             raise ValueError(f'Unsupported reduction mode: {reduction}. '
102 |                              f'Supported ones are: {_reduction_modes}')
103 | 
104 |         self.loss_weight = loss_weight
105 |         self.reduction = reduction
106 | 
107 |     def forward(self, pred, target, weight=None, **kwargs):
108 |         """
109 |         Args:
110 |             pred (Tensor): of shape (N, C, H, W). Predicted tensor.
111 |             target (Tensor): of shape (N, C, H, W). Ground truth tensor.
112 |             weight (Tensor, optional): of shape (N, C, H, W). Element-wise
113 |                 weights. Default: None.
114 |         """
115 |         return self.loss_weight * mse_loss(
116 |             pred, target, weight, reduction=self.reduction)
117 | 
118 | class PSNRLoss(nn.Module):
119 | 
120 |     def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
121 |         super(PSNRLoss, self).__init__()
122 |         assert reduction == 'mean'
123 |         self.loss_weight = loss_weight
124 |         self.scale = 10 / np.log(10)
125 |         self.toY = toY
126 |         self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
127 |         self.first = True
128 | 
129 |     def forward(self, pred, target):
130 |         assert len(pred.size()) == 4
131 |         if self.toY:
132 |             if self.first:
133 |                 self.coef = self.coef.to(pred.device)
134 |                 self.first = False
135 | 
136 |             pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
137 |             target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
138 | 
139 |             pred, target = pred / 255., target / 255.
140 |             # pred and target are now single-channel Y images in [0, 1]
141 |         assert len(pred.size()) == 4
142 | 
143 |         return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
144 | 
145 | class CharbonnierLoss(nn.Module):
146 |     """Charbonnier Loss (L1)"""
147 | 
148 |     def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3):
149 |         super(CharbonnierLoss, self).__init__()
150 |         self.loss_weight = loss_weight
151 |         self.eps = eps
152 | 
153 |     def forward(self, x, y):
154 |         diff = x - y
155 |         loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
156 |         return self.loss_weight * loss  # apply the configured loss weight
157 | 
--------------------------------------------------------------------------------
/basicsr/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from torch.optim.lr_scheduler import _LRScheduler
4 | import torch
5 | 
6 | 
7 | class MultiStepRestartLR(_LRScheduler):
8 |     """ MultiStep with restarts learning rate scheme.
9 | 
10 |     Args:
11 |         optimizer (torch.nn.optimizer): Torch optimizer.
12 |         milestones (list): Iterations that will decrease learning rate.
13 |         gamma (float): Decrease ratio. Default: 0.1.
14 |         restarts (list): Restart iterations. Default: [0].
15 |         restart_weights (list): Restart weights at each restart iteration.
16 |             Default: [1].
17 |         last_epoch (int): Used in _LRScheduler. Default: -1.
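    Example (illustrative)::

        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
        # halve the lr at 100k/200k iters; restart at half strength at 250k
        scheduler = MultiStepRestartLR(optimizer,
                                       milestones=[100000, 200000],
                                       gamma=0.5, restarts=[250000],
                                       restart_weights=[0.5])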
18 | """ 19 | 20 | def __init__(self, 21 | optimizer, 22 | milestones, 23 | gamma=0.1, 24 | restarts=(0, ), 25 | restart_weights=(1, ), 26 | last_epoch=-1): 27 | self.milestones = Counter(milestones) 28 | self.gamma = gamma 29 | self.restarts = restarts 30 | self.restart_weights = restart_weights 31 | assert len(self.restarts) == len( 32 | self.restart_weights), 'restarts and their weights do not match.' 33 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 34 | 35 | def get_lr(self): 36 | if self.last_epoch in self.restarts: 37 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 38 | return [ 39 | group['initial_lr'] * weight 40 | for group in self.optimizer.param_groups 41 | ] 42 | if self.last_epoch not in self.milestones: 43 | return [group['lr'] for group in self.optimizer.param_groups] 44 | return [ 45 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 46 | for group in self.optimizer.param_groups 47 | ] 48 | 49 | class LinearLR(_LRScheduler): 50 | """ 51 | 52 | Args: 53 | optimizer (torch.nn.optimizer): Torch optimizer. 54 | milestones (list): Iterations that will decrease learning rate. 55 | gamma (float): Decrease ratio. Default: 0.1. 56 | last_epoch (int): Used in _LRScheduler. Default: -1. 57 | """ 58 | 59 | def __init__(self, 60 | optimizer, 61 | total_iter, 62 | last_epoch=-1): 63 | self.total_iter = total_iter 64 | super(LinearLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self): 67 | process = self.last_epoch / self.total_iter 68 | weight = (1 - process) 69 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups]) 70 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 71 | 72 | class VibrateLR(_LRScheduler): 73 | """ 74 | 75 | Args: 76 | optimizer (torch.nn.optimizer): Torch optimizer. 77 | milestones (list): Iterations that will decrease learning rate. 78 | gamma (float): Decrease ratio. Default: 0.1. 79 | last_epoch (int): Used in _LRScheduler. Default: -1. 80 | """ 81 | 82 | def __init__(self, 83 | optimizer, 84 | total_iter, 85 | last_epoch=-1): 86 | self.total_iter = total_iter 87 | super(VibrateLR, self).__init__(optimizer, last_epoch) 88 | 89 | def get_lr(self): 90 | process = self.last_epoch / self.total_iter 91 | 92 | f = 0.1 93 | if process < 3 / 8: 94 | f = 1 - process * 8 / 3 95 | elif process < 5 / 8: 96 | f = 0.2 97 | 98 | T = self.total_iter // 80 99 | Th = T // 2 100 | 101 | t = self.last_epoch % T 102 | 103 | f2 = t / Th 104 | if t >= Th: 105 | f2 = 2 - f2 106 | 107 | weight = f * f2 108 | 109 | if self.last_epoch < Th: 110 | weight = max(0.1, weight) 111 | 112 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2)) 113 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 114 | 115 | def get_position_from_periods(iteration, cumulative_period): 116 | """Get the position from a period list. 117 | 118 | It will return the index of the right-closest number in the period list. 119 | For example, the cumulative_period = [100, 200, 300, 400], 120 | if iteration == 50, return 0; 121 | if iteration == 210, return 2; 122 | if iteration == 300, return 2. 123 | 124 | Args: 125 | iteration (int): Current iteration. 126 | cumulative_period (list[int]): Cumulative period list. 127 | 128 | Returns: 129 | int: The position of the right-closest number in the period list. 
130 | """ 131 | for i, period in enumerate(cumulative_period): 132 | if iteration <= period: 133 | return i 134 | 135 | 136 | class CosineAnnealingRestartLR(_LRScheduler): 137 | """ Cosine annealing with restarts learning rate scheme. 138 | 139 | An example of config: 140 | periods = [10, 10, 10, 10] 141 | restart_weights = [1, 0.5, 0.5, 0.5] 142 | eta_min=1e-7 143 | 144 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 145 | scheduler will restart with the weights in restart_weights. 146 | 147 | Args: 148 | optimizer (torch.nn.optimizer): Torch optimizer. 149 | periods (list): Period for each cosine anneling cycle. 150 | restart_weights (list): Restart weights at each restart iteration. 151 | Default: [1]. 152 | eta_min (float): The mimimum lr. Default: 0. 153 | last_epoch (int): Used in _LRScheduler. Default: -1. 154 | """ 155 | 156 | def __init__(self, 157 | optimizer, 158 | periods, 159 | restart_weights=(1, ), 160 | eta_min=0, 161 | last_epoch=-1): 162 | self.periods = periods 163 | self.restart_weights = restart_weights 164 | self.eta_min = eta_min 165 | assert (len(self.periods) == len(self.restart_weights) 166 | ), 'periods and restart_weights should have the same length.' 167 | self.cumulative_period = [ 168 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 169 | ] 170 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 171 | 172 | def get_lr(self): 173 | idx = get_position_from_periods(self.last_epoch, 174 | self.cumulative_period) 175 | current_weight = self.restart_weights[idx] 176 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 177 | current_period = self.periods[idx] 178 | 179 | return [ 180 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 181 | (1 + math.cos(math.pi * ( 182 | (self.last_epoch - nearest_restart) / current_period))) 183 | for base_lr in self.base_lrs 184 | ] 185 | 186 | class CosineAnnealingRestartCyclicLR(_LRScheduler): 187 | """ Cosine annealing with restarts learning rate scheme. 188 | An example of config: 189 | periods = [10, 10, 10, 10] 190 | restart_weights = [1, 0.5, 0.5, 0.5] 191 | eta_min=1e-7 192 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 193 | scheduler will restart with the weights in restart_weights. 194 | Args: 195 | optimizer (torch.nn.optimizer): Torch optimizer. 196 | periods (list): Period for each cosine anneling cycle. 197 | restart_weights (list): Restart weights at each restart iteration. 198 | Default: [1]. 199 | eta_min (float): The mimimum lr. Default: 0. 200 | last_epoch (int): Used in _LRScheduler. Default: -1. 201 | """ 202 | 203 | def __init__(self, 204 | optimizer, 205 | periods, 206 | restart_weights=(1, ), 207 | eta_mins=(0, ), 208 | last_epoch=-1): 209 | self.periods = periods 210 | self.restart_weights = restart_weights 211 | self.eta_mins = eta_mins 212 | assert (len(self.periods) == len(self.restart_weights) 213 | ), 'periods and restart_weights should have the same length.' 
214 |         self.cumulative_period = [
215 |             sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
216 |         ]
217 |         super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
218 | 
219 |     def get_lr(self):
220 |         idx = get_position_from_periods(self.last_epoch,
221 |                                         self.cumulative_period)
222 |         current_weight = self.restart_weights[idx]
223 |         nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
224 |         current_period = self.periods[idx]
225 |         eta_min = self.eta_mins[idx]
226 | 
227 |         return [
228 |             eta_min + current_weight * 0.5 * (base_lr - eta_min) *
229 |             (1 + math.cos(math.pi * (
230 |                 (self.last_epoch - nearest_restart) / current_period)))
231 |             for base_lr in self.base_lrs
232 |         ]
233 | 
--------------------------------------------------------------------------------
/basicsr/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from os import path as osp
4 | 
5 | from basicsr.data import create_dataloader, create_dataset
6 | from basicsr.models import create_model
7 | from basicsr.train import parse_options
8 | from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
9 |                            make_exp_dirs)
10 | from basicsr.utils.options import dict2str
11 | 
12 | 
13 | def main():
14 |     # parse options, set distributed setting, set random seed
15 |     opt = parse_options(is_train=False)
16 | 
17 |     torch.backends.cudnn.benchmark = True
18 |     # torch.backends.cudnn.deterministic = True
19 | 
20 |     # mkdir and initialize loggers
21 |     make_exp_dirs(opt)
22 |     log_file = osp.join(opt['path']['log'],
23 |                         f"test_{opt['name']}_{get_time_str()}.log")
24 |     logger = get_root_logger(
25 |         logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
26 |     logger.info(get_env_info())
27 |     logger.info(dict2str(opt))
28 | 
29 |     # create test dataset and dataloader
30 |     test_loaders = []
31 |     for phase, dataset_opt in sorted(opt['datasets'].items()):
32 |         test_set = create_dataset(dataset_opt)
33 |         test_loader = create_dataloader(
34 |             test_set,
35 |             dataset_opt,
36 |             num_gpu=opt['num_gpu'],
37 |             dist=opt['dist'],
38 |             sampler=None,
39 |             seed=opt['manual_seed'])
40 |         logger.info(
41 |             f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
42 |         test_loaders.append(test_loader)
43 | 
44 |     # create model
45 |     model = create_model(opt)
46 | 
47 |     for test_loader in test_loaders:
48 |         test_set_name = test_loader.dataset.opt['name']
49 |         logger.info(f'Testing {test_set_name}...')
50 |         rgb2bgr = opt['val'].get('rgb2bgr', True)
51 |         # whether to use uint8 images to compute metrics
52 |         use_image = opt['val'].get('use_image', True)
53 |         model.validation(
54 |             test_loader,
55 |             current_iter=opt['name'],
56 |             tb_logger=None,
57 |             save_img=opt['val']['save_img'],
58 |             rgb2bgr=rgb2bgr, use_image=use_image)
59 | 
60 | 
61 | if __name__ == '__main__':
62 |     main()
63 | 
--------------------------------------------------------------------------------
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_client import FileClient
2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP
3 | from .logger import (MessageLogger, get_env_info, get_root_logger,
4 |                      init_tb_logger, init_wandb_logger)
5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
6 |                    scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
7 | from .create_lmdb import (create_lmdb_for_reds,
create_lmdb_for_gopro, create_lmdb_for_rain13k) 8 | 9 | __all__ = [ 10 | # file_client.py 11 | 'FileClient', 12 | # img_util.py 13 | 'img2tensor', 14 | 'tensor2img', 15 | 'imfrombytes', 16 | 'imwrite', 17 | 'crop_border', 18 | # logger.py 19 | 'MessageLogger', 20 | 'init_tb_logger', 21 | 'init_wandb_logger', 22 | 'get_root_logger', 23 | 'get_env_info', 24 | # misc.py 25 | 'set_random_seed', 26 | 'get_time_str', 27 | 'mkdir_and_rename', 28 | 'make_exp_dirs', 29 | 'scandir', 30 | 'check_resume', 31 | 'sizeof_fmt', 32 | 'padding', 33 | 'padding_DP', 34 | 'imfrombytesDP', 35 | 'create_lmdb_for_reds', 36 | 'create_lmdb_for_gopro', 37 | 'create_lmdb_for_rain13k', 38 | ] 39 | -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/create_lmdb.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/create_lmdb.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/dist_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/dist_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/file_client.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/file_client.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/flow_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/flow_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/img_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/img_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/lmdb_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/lmdb_util.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/logger.cpython-37.pyc 
-------------------------------------------------------------------------------- /basicsr/utils/__pycache__/matlab_functions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/matlab_functions.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/__pycache__/options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/options.cpython-37.pyc -------------------------------------------------------------------------------- /basicsr/utils/bundle_submissions.py: -------------------------------------------------------------------------------- 1 | # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de) 2 | 3 | # This file is part of the implementation as described in the CVPR 2017 paper: 4 | # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs. 5 | # Please see the file LICENSE.txt for the license governing this code. 6 | 7 | 8 | import numpy as np 9 | import scipy.io as sio 10 | import os 11 | import h5py 12 | 13 | def bundle_submissions_raw(submission_folder,session): 14 | ''' 15 | Bundles submission data for raw denoising 16 | 17 | submission_folder Folder where denoised images reside 18 | 19 | Output is written to /bundled/. Please submit 20 | the content of this folder. 21 | ''' 22 | 23 | out_folder = os.path.join(submission_folder, session) 24 | # out_folder = os.path.join(submission_folder, "bundled/") 25 | try: 26 | os.mkdir(out_folder) 27 | except:pass 28 | 29 | israw = True 30 | eval_version="1.0" 31 | 32 | for i in range(50): 33 | Idenoised = np.zeros((20,), dtype=np.object) 34 | for bb in range(20): 35 | filename = '%04d_%02d.mat'%(i+1,bb+1) 36 | s = sio.loadmat(os.path.join(submission_folder,filename)) 37 | Idenoised_crop = s["Idenoised_crop"] 38 | Idenoised[bb] = Idenoised_crop 39 | filename = '%04d.mat'%(i+1) 40 | sio.savemat(os.path.join(out_folder, filename), 41 | {"Idenoised": Idenoised, 42 | "israw": israw, 43 | "eval_version": eval_version}, 44 | ) 45 | 46 | def bundle_submissions_srgb(submission_folder,session): 47 | ''' 48 | Bundles submission data for sRGB denoising 49 | 50 | submission_folder Folder where denoised images reside 51 | 52 | Output is written to /bundled/. Please submit 53 | the content of this folder. 
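    Example (hypothetical result folder and session name)::

        bundle_submissions_srgb('./results/dnd_srgb/', 'bundled')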
54 |     '''
55 |     out_folder = os.path.join(submission_folder, session)
56 |     # out_folder = os.path.join(submission_folder, "bundled/")
57 |     try:
58 |         os.mkdir(out_folder)
59 |     except:pass
60 |     israw = False
61 |     eval_version="1.0"
62 | 
63 |     for i in range(50):
64 |         Idenoised = np.zeros((20,), dtype=np.object)
65 |         for bb in range(20):
66 |             filename = '%04d_%02d.mat'%(i+1,bb+1)
67 |             s = sio.loadmat(os.path.join(submission_folder,filename))
68 |             Idenoised_crop = s["Idenoised_crop"]
69 |             Idenoised[bb] = Idenoised_crop
70 |         filename = '%04d.mat'%(i+1)
71 |         sio.savemat(os.path.join(out_folder, filename),
72 |                     {"Idenoised": Idenoised,
73 |                      "israw": israw,
74 |                      "eval_version": eval_version},
75 |                     )
76 | 
77 | 
78 | 
79 | def bundle_submissions_srgb_v1(submission_folder,session):
80 |     '''
81 |     Bundles submission data for sRGB denoising
82 | 
83 |     submission_folder Folder where denoised images reside
84 | 
85 |     Output is written to <submission_folder>/<session>. Please submit
86 |     the content of this folder.
87 |     '''
88 |     out_folder = os.path.join(submission_folder, session)
89 |     # out_folder = os.path.join(submission_folder, "bundled/")
90 |     try:
91 |         os.mkdir(out_folder)
92 |     except:pass
93 |     israw = False
94 |     eval_version="1.0"
95 | 
96 |     for i in range(50):
97 |         Idenoised = np.zeros((20,), dtype=np.object)
98 |         for bb in range(20):
99 |             filename = '%04d_%d.mat'%(i+1,bb+1)
100 |             s = sio.loadmat(os.path.join(submission_folder,filename))
101 |             Idenoised_crop = s["Idenoised_crop"]
102 |             Idenoised[bb] = Idenoised_crop
103 |         filename = '%04d.mat'%(i+1)
104 |         sio.savemat(os.path.join(out_folder, filename),
105 |                     {"Idenoised": Idenoised,
106 |                      "israw": israw,
107 |                      "eval_version": eval_version},
108 |                     )
--------------------------------------------------------------------------------
/basicsr/utils/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse, cv2, os
2 | import scipy.io as scio
3 | from os import path as osp
4 | from tqdm import tqdm
5 | from basicsr.utils import scandir
6 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs
7 | def prepare_keys(folder_path, suffix='png'):
8 |     """Prepare image path list and keys for a dataset folder.
9 | 
10 |     Args:
11 |         folder_path (str): Folder path.
12 | 
13 |     Returns:
14 |         list[str]: Image path list.
15 |         list[str]: Key list.
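    Example (illustrative; the GoPro crop folder used below)::

        >>> img_path_list, keys = prepare_keys(
        ...     './datasets/GoPro/train/blur_crops', 'png')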
16 | """ 17 | print('Reading image path list ...') 18 | img_path_list = sorted( 19 | list(scandir(folder_path, suffix=suffix, recursive=False))) 20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] 21 | 22 | return img_path_list, keys 23 | 24 | def create_lmdb_for_reds(): 25 | folder_path = './datasets/REDS/val/sharp_300' 26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb' 27 | img_path_list, keys = prepare_keys(folder_path, 'png') 28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 29 | # 30 | folder_path = './datasets/REDS/val/blur_300' 31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb' 32 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 34 | 35 | folder_path = './datasets/REDS/train/train_sharp' 36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb' 37 | img_path_list, keys = prepare_keys(folder_path, 'png') 38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 39 | 40 | folder_path = './datasets/REDS/train/train_blur_jpeg' 41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb' 42 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 44 | 45 | 46 | def create_lmdb_for_gopro(): 47 | folder_path = './datasets/GoPro/train/blur_crops' 48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb' 49 | 50 | img_path_list, keys = prepare_keys(folder_path, 'png') 51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 52 | 53 | folder_path = './datasets/GoPro/train/sharp_crops' 54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb' 55 | 56 | img_path_list, keys = prepare_keys(folder_path, 'png') 57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 58 | 59 | folder_path = './datasets/GoPro/test/target' 60 | lmdb_path = './datasets/GoPro/test/target.lmdb' 61 | 62 | img_path_list, keys = prepare_keys(folder_path, 'png') 63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 64 | 65 | folder_path = './datasets/GoPro/test/input' 66 | lmdb_path = './datasets/GoPro/test/input.lmdb' 67 | 68 | img_path_list, keys = prepare_keys(folder_path, 'png') 69 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 70 | 71 | def create_lmdb_for_rain13k(): 72 | folder_path = './datasets/Rain13k/train/input' 73 | lmdb_path = './datasets/Rain13k/train/input.lmdb' 74 | 75 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 77 | 78 | folder_path = './datasets/Rain13k/train/target' 79 | lmdb_path = './datasets/Rain13k/train/target.lmdb' 80 | 81 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 83 | 84 | def create_lmdb_for_SIDD(): 85 | folder_path = './datasets/SIDD/train/input_crops' 86 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb' 87 | 88 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 89 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 90 | 91 | folder_path = './datasets/SIDD/train/gt_crops' 92 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb' 93 | 94 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 96 | 97 | #for val 98 | folder_path = './datasets/SIDD/val/input_crops' 99 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb' 100 | mat_path = 
'./datasets/SIDD/ValidationNoisyBlocksSrgb.mat' 101 | if not osp.exists(folder_path): 102 | os.makedirs(folder_path) 103 | assert osp.exists(mat_path) 104 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] 105 | N, B, H ,W, C = data.shape 106 | data = data.reshape(N*B, H, W, C) 107 | for i in tqdm(range(N*B)): 108 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 109 | img_path_list, keys = prepare_keys(folder_path, 'png') 110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 111 | 112 | folder_path = './datasets/SIDD/val/gt_crops' 113 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' 114 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' 115 | if not osp.exists(folder_path): 116 | os.makedirs(folder_path) 117 | assert osp.exists(mat_path) 118 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] 119 | N, B, H ,W, C = data.shape 120 | data = data.reshape(N*B, H, W, C) 121 | for i in tqdm(range(N*B)): 122 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 123 | img_path_list, keys = prepare_keys(folder_path, 'png') 124 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 125 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
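    Example (illustrative; normally reached via ``init_dist('slurm')``)::

        init_dist('slurm', backend='nccl', port=29500)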
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput( 45 | f'scontrol show hostname {node_list} | head -n1') 46 | # specify master port 47 | if port is not None: 48 | os.environ['MASTER_PORT'] = str(port) 49 | elif 'MASTER_PORT' in os.environ: 50 | pass # use MASTER_PORT in the environment variable 51 | else: 52 | # 29500 is torch.distributed default port 53 | os.environ['MASTER_PORT'] = '29500' 54 | os.environ['MASTER_ADDR'] = addr 55 | os.environ['WORLD_SIZE'] = str(ntasks) 56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 57 | os.environ['RANK'] = str(proc_id) 58 | dist.init_process_group(backend=backend) 59 | 60 | 61 | def get_dist_info(): 62 | if dist.is_available(): 63 | initialized = dist.is_initialized() 64 | else: 65 | initialized = False 66 | if initialized: 67 | rank = dist.get_rank() 68 | world_size = dist.get_world_size() 69 | else: 70 | rank = 0 71 | world_size = 1 72 | return rank, world_size 73 | 74 | 75 | def master_only(func): 76 | 77 | @functools.wraps(func) 78 | def wrapper(*args, **kwargs): 79 | rank, _ = get_dist_info() 80 | if rank == 0: 81 | return func(*args, **kwargs) 82 | 83 | return wrapper 84 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import requests 3 | from tqdm import tqdm 4 | 5 | from .misc import sizeof_fmt 6 | 7 | 8 | def download_file_from_google_drive(file_id, save_path): 9 | """Download files from google drive. 10 | 11 | Ref: 12 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 13 | 14 | Args: 15 | file_id (str): File id. 16 | save_path (str): Save path. 
17 | """ 18 | 19 | session = requests.Session() 20 | URL = 'https://docs.google.com/uc?export=download' 21 | params = {'id': file_id} 22 | 23 | response = session.get(URL, params=params, stream=True) 24 | token = get_confirm_token(response) 25 | if token: 26 | params['confirm'] = token 27 | response = session.get(URL, params=params, stream=True) 28 | 29 | # get file size 30 | response_file_size = session.get( 31 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 32 | if 'Content-Range' in response_file_size.headers: 33 | file_size = int( 34 | response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, 49 | destination, 50 | file_size=None, 51 | chunk_size=32768): 52 | if file_size is not None: 53 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 54 | 55 | readable_file_size = sizeof_fmt(file_size) 56 | else: 57 | pbar = None 58 | 59 | with open(destination, 'wb') as f: 60 | downloaded_size = 0 61 | for chunk in response.iter_content(chunk_size): 62 | downloaded_size += chunk_size 63 | if pbar is not None: 64 | pbar.update(1) 65 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' 66 | f'/ {readable_file_size}') 67 | if chunk: # filter out keep-alive new chunks 68 | f.write(chunk) 69 | if pbar is not None: 70 | pbar.close() 71 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 
30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError( 40 | 'Please install memcached to enable MemcachedBackend.') 41 | 42 | self.server_list_cfg = server_list_cfg 43 | self.client_cfg = client_cfg 44 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, 45 | self.client_cfg) 46 | # mc.pyvector servers as a point which points to a memory cache 47 | self._mc_buffer = mc.pyvector() 48 | 49 | def get(self, filepath): 50 | filepath = str(filepath) 51 | import mc 52 | self._client.Get(filepath, self._mc_buffer) 53 | value_buf = mc.ConvertBuffer(self._mc_buffer) 54 | return value_buf 55 | 56 | def get_text(self, filepath): 57 | raise NotImplementedError 58 | 59 | 60 | class HardDiskBackend(BaseStorageBackend): 61 | """Raw hard disks storage backend.""" 62 | 63 | def get(self, filepath): 64 | filepath = str(filepath) 65 | with open(filepath, 'rb') as f: 66 | value_buf = f.read() 67 | return value_buf 68 | 69 | def get_text(self, filepath): 70 | filepath = str(filepath) 71 | with open(filepath, 'r') as f: 72 | value_buf = f.read() 73 | return value_buf 74 | 75 | 76 | class LmdbBackend(BaseStorageBackend): 77 | """Lmdb storage backend. 78 | 79 | Args: 80 | db_paths (str | list[str]): Lmdb database paths. 81 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 82 | readonly (bool, optional): Lmdb environment parameter. If True, 83 | disallow any write operations. Default: True. 84 | lock (bool, optional): Lmdb environment parameter. If False, when 85 | concurrent access occurs, do not lock the database. Default: False. 86 | readahead (bool, optional): Lmdb environment parameter. If False, 87 | disable the OS filesystem readahead mechanism, which may improve 88 | random read performance when a database is larger than RAM. 89 | Default: False. 90 | 91 | Attributes: 92 | db_paths (list): Lmdb database path. 93 | _client (list): A list of several lmdb envs. 94 | """ 95 | 96 | def __init__(self, 97 | db_paths, 98 | client_keys='default', 99 | readonly=True, 100 | lock=False, 101 | readahead=False, 102 | **kwargs): 103 | try: 104 | import lmdb 105 | except ImportError: 106 | raise ImportError('Please install lmdb to enable LmdbBackend.') 107 | 108 | if isinstance(client_keys, str): 109 | client_keys = [client_keys] 110 | 111 | if isinstance(db_paths, list): 112 | self.db_paths = [str(v) for v in db_paths] 113 | elif isinstance(db_paths, str): 114 | self.db_paths = [str(db_paths)] 115 | assert len(client_keys) == len(self.db_paths), ( 116 | 'client_keys and db_paths should have the same length, ' 117 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 118 | 119 | self._client = {} 120 | 121 | for client, path in zip(client_keys, self.db_paths): 122 | self._client[client] = lmdb.open( 123 | path, 124 | readonly=readonly, 125 | lock=lock, 126 | readahead=readahead, 127 | map_size=8*1024*10485760, 128 | # max_readers=1, 129 | **kwargs) 130 | 131 | def get(self, filepath, client_key): 132 | """Get values according to the filepath from one lmdb named client_key. 133 | 134 | Args: 135 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 136 | client_key (str): Used for distinguishing differnet lmdb envs. 
137 | """ 138 | filepath = str(filepath) 139 | assert client_key in self._client, (f'client_key {client_key} is not ' 140 | 'in lmdb clients.') 141 | client = self._client[client_key] 142 | with client.begin(write=False) as txn: 143 | value_buf = txn.get(filepath.encode('ascii')) 144 | return value_buf 145 | 146 | def get_text(self, filepath): 147 | raise NotImplementedError 148 | 149 | 150 | class FileClient(object): 151 | """A general file client to access files in different backend. 152 | 153 | The client loads a file or text in a specified backend from its path 154 | and return it as a binary file. it can also register other backend 155 | accessor with a given name and backend class. 156 | 157 | Attributes: 158 | backend (str): The storage backend type. Options are "disk", 159 | "memcached" and "lmdb". 160 | client (:obj:`BaseStorageBackend`): The backend object. 161 | """ 162 | 163 | _backends = { 164 | 'disk': HardDiskBackend, 165 | 'memcached': MemcachedBackend, 166 | 'lmdb': LmdbBackend, 167 | } 168 | 169 | def __init__(self, backend='disk', **kwargs): 170 | if backend not in self._backends: 171 | raise ValueError( 172 | f'Backend {backend} is not supported. Currently supported ones' 173 | f' are {list(self._backends.keys())}') 174 | self.backend = backend 175 | self.client = self._backends[backend](**kwargs) 176 | 177 | def get(self, filepath, client_key='default'): 178 | # client_key is used only for lmdb, where different fileclients have 179 | # different lmdb environments. 180 | if self.backend == 'lmdb': 181 | return self.client.get(filepath, client_key) 182 | else: 183 | return self.client.get(filepath) 184 | 185 | def get_text(self, filepath): 186 | return self.client.get_text(filepath) 187 | -------------------------------------------------------------------------------- /basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 2 | import cv2 3 | import numpy as np 4 | import os 5 | 6 | 7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 8 | """Read an optical flow map. 9 | 10 | Args: 11 | flow_path (ndarray or str): Flow path. 12 | quantize (bool): whether to read quantized pair, if set to True, 13 | remaining args will be passed to :func:`dequantize_flow`. 14 | concat_axis (int): The axis that dx and dy are concatenated, 15 | can be either 0 or 1. Ignored if quantize is False. 
16 | 
17 |     Returns:
18 |         ndarray: Optical flow represented as a (h, w, 2) numpy array
19 |     """
20 |     if quantize:
21 |         assert concat_axis in [0, 1]
22 |         cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
23 |         if cat_flow.ndim != 2:
24 |             raise IOError(f'{flow_path} is not a valid quantized flow file, '
25 |                           f'its dimension is {cat_flow.ndim}.')
26 |         assert cat_flow.shape[concat_axis] % 2 == 0
27 |         dx, dy = np.split(cat_flow, 2, axis=concat_axis)
28 |         flow = dequantize_flow(dx, dy, *args, **kwargs)
29 |     else:
30 |         with open(flow_path, 'rb') as f:
31 |             try:
32 |                 header = f.read(4).decode('utf-8')
33 |             except Exception:
34 |                 raise IOError(f'Invalid flow file: {flow_path}')
35 |             else:
36 |                 if header != 'PIEH':
37 |                     raise IOError(f'Invalid flow file: {flow_path}, '
38 |                                   'header does not contain PIEH')
39 | 
40 |             w = np.fromfile(f, np.int32, 1).squeeze()
41 |             h = np.fromfile(f, np.int32, 1).squeeze()
42 |             flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
43 | 
44 |     return flow.astype(np.float32)
45 | 
46 | 
47 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
48 |     """Write optical flow to file.
49 | 
50 |     If the flow is not quantized, it will be saved as a .flo file losslessly,
51 |     otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
52 |     will be concatenated horizontally into a single image if quantize is True.)
53 | 
54 |     Args:
55 |         flow (ndarray): (h, w, 2) array of optical flow.
56 |         filename (str): Output filepath.
57 |         quantize (bool): Whether to quantize the flow and save it to 2 jpeg
58 |             images. If set to True, remaining args will be passed to
59 |             :func:`quantize_flow`.
60 |         concat_axis (int): The axis that dx and dy are concatenated,
61 |             can be either 0 or 1. Ignored if quantize is False.
62 |     """
63 |     if not quantize:
64 |         with open(filename, 'wb') as f:
65 |             f.write('PIEH'.encode('utf-8'))
66 |             np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
67 |             flow = flow.astype(np.float32)
68 |             flow.tofile(f)
69 |             f.flush()
70 |     else:
71 |         assert concat_axis in [0, 1]
72 |         dx, dy = quantize_flow(flow, *args, **kwargs)
73 |         dxdy = np.concatenate((dx, dy), axis=concat_axis)
74 |         os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
75 |         cv2.imwrite(filename, dxdy)
76 | 
77 | 
78 | def quantize_flow(flow, max_val=0.02, norm=True):
79 |     """Quantize flow to [0, 255].
80 | 
81 |     After this step, the size of flow will be much smaller, and can be
82 |     dumped as jpeg images.
83 | 
84 |     Args:
85 |         flow (ndarray): (h, w, 2) array of optical flow.
86 |         max_val (float): Maximum value of flow, values beyond
87 |             [-max_val, max_val] will be truncated.
88 |         norm (bool): Whether to divide flow values by image width/height.
89 | 
90 |     Returns:
91 |         tuple[ndarray]: Quantized dx and dy.
92 |     """
93 |     h, w, _ = flow.shape
94 |     dx = flow[..., 0]
95 |     dy = flow[..., 1]
96 |     if norm:
97 |         dx = dx / w  # avoid inplace operations
98 |         dy = dy / h
99 |     # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
100 |     flow_comps = [
101 |         quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
102 |     ]
103 |     return tuple(flow_comps)
104 | 
105 | 
106 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
107 |     """Recover from quantized flow.
108 | 
109 |     Args:
110 |         dx (ndarray): Quantized dx.
111 |         dy (ndarray): Quantized dy.
112 |         max_val (float): Maximum value used when quantizing.
113 |         denorm (bool): Whether to multiply flow values with width/height.
114 | 
115 |     Returns:
116 |         ndarray: Dequantized flow.
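    Example (round trip with :func:`quantize_flow`)::

        dx, dy = quantize_flow(flow)            # uint8 maps, jpeg-friendly
        flow_restored = dequantize_flow(dx, dy)  # back to float flow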
117 | """ 118 | assert dx.shape == dy.shape 119 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 120 | 121 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 122 | 123 | if denorm: 124 | dx *= dx.shape[1] 125 | dy *= dx.shape[0] 126 | flow = np.dstack((dx, dy)) 127 | return flow 128 | 129 | 130 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 131 | """Quantize an array of (-inf, inf) to [0, levels-1]. 132 | 133 | Args: 134 | arr (ndarray): Input array. 135 | min_val (scalar): Minimum value to be clipped. 136 | max_val (scalar): Maximum value to be clipped. 137 | levels (int): Quantization levels. 138 | dtype (np.type): The type of the quantized array. 139 | 140 | Returns: 141 | tuple: Quantized array. 142 | """ 143 | if not (isinstance(levels, int) and levels > 1): 144 | raise ValueError( 145 | f'levels must be a positive integer, but got {levels}') 146 | if min_val >= max_val: 147 | raise ValueError( 148 | f'min_val ({min_val}) must be smaller than max_val ({max_val})') 149 | 150 | arr = np.clip(arr, min_val, max_val) - min_val 151 | quantized_arr = np.minimum( 152 | np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 153 | 154 | return quantized_arr 155 | 156 | 157 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 158 | """Dequantize an array. 159 | 160 | Args: 161 | arr (ndarray): Input array. 162 | min_val (scalar): Minimum value to be clipped. 163 | max_val (scalar): Maximum value to be clipped. 164 | levels (int): Quantization levels. 165 | dtype (np.type): The type of the dequantized array. 166 | 167 | Returns: 168 | tuple: Dequantized array. 169 | """ 170 | if not (isinstance(levels, int) and levels > 1): 171 | raise ValueError( 172 | f'levels must be a positive integer, but got {levels}') 173 | if min_val >= max_val: 174 | raise ValueError( 175 | f'min_val ({min_val}) must be smaller than max_val ({max_val})') 176 | 177 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - 178 | min_val) / levels + min_val 179 | 180 | return dequantized_arr 181 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 25 | img = torch.from_numpy(img.transpose(2, 0, 1)) 26 | if float32: 27 | img = img.float() 28 | return img 29 | 30 | if isinstance(imgs, list): 31 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 32 | else: 33 | return _totensor(imgs, bgr2rgb, float32) 34 | 35 | 36 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 37 | """Convert torch Tensors into image numpy arrays. 38 | 39 | After clamping to [min, max], values will be normalized to [0, 1]. 
40 | 
41 |     Args:
42 |         tensor (Tensor or list[Tensor]): Accept shapes:
43 |             1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
44 |             2) 3D Tensor of shape (3/1 x H x W);
45 |             3) 2D Tensor of shape (H x W).
46 |             Tensor channel should be in RGB order.
47 |         rgb2bgr (bool): Whether to change rgb to bgr.
48 |         out_type (numpy type): output types. If ``np.uint8``, transform outputs
49 |             to uint8 type with range [0, 255]; otherwise, float type with
50 |             range [0, 1]. Default: ``np.uint8``.
51 |         min_max (tuple[int]): min and max values for clamp.
52 | 
53 |     Returns:
54 |         (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
55 |             ndarray of shape (H x W). The channel order is BGR.
56 |     """
57 |     if not (torch.is_tensor(tensor) or
58 |             (isinstance(tensor, list)
59 |              and all(torch.is_tensor(t) for t in tensor))):
60 |         raise TypeError(
61 |             f'tensor or list of tensors expected, got {type(tensor)}')
62 | 
63 |     if torch.is_tensor(tensor):
64 |         tensor = [tensor]
65 |     result = []
66 |     for _tensor in tensor:
67 |         _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
68 |         _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
69 | 
70 |         n_dim = _tensor.dim()
71 |         if n_dim == 4:
72 |             img_np = make_grid(
73 |                 _tensor, nrow=int(math.sqrt(_tensor.size(0))),
74 |                 normalize=False).numpy()
75 |             img_np = img_np.transpose(1, 2, 0)
76 |             if rgb2bgr:
77 |                 img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
78 |         elif n_dim == 3:
79 |             img_np = _tensor.numpy()
80 |             img_np = img_np.transpose(1, 2, 0)
81 |             if img_np.shape[2] == 1:  # gray image
82 |                 img_np = np.squeeze(img_np, axis=2)
83 |             else:
84 |                 if rgb2bgr:
85 |                     img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
86 |         elif n_dim == 2:
87 |             img_np = _tensor.numpy()
88 |         else:
89 |             raise TypeError('Only support 4D, 3D or 2D tensor. '
90 |                             f'But received with dimension: {n_dim}')
91 |         if out_type == np.uint8:
92 |             # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
93 |             img_np = (img_np * 255.0).round()
94 |         img_np = img_np.astype(out_type)
95 |         result.append(img_np)
96 |     if len(result) == 1:
97 |         result = result[0]
98 |     return result
99 | 
100 | 
101 | def imfrombytes(content, flag='color', float32=False):
102 |     """Read an image from bytes.
103 | 
104 |     Args:
105 |         content (bytes): Image bytes got from files or other streams.
106 |         flag (str): Flags specifying the color type of a loaded image,
107 |             candidates are `color`, `grayscale` and `unchanged`.
108 |         float32 (bool): Whether to change to float32. If True, will also norm
109 |             to [0, 1]. Default: False.
110 | 
111 |     Returns:
112 |         ndarray: Loaded image array.
113 |     """
114 |     img_np = np.frombuffer(content, np.uint8)
115 |     imread_flags = {
116 |         'color': cv2.IMREAD_COLOR,
117 |         'grayscale': cv2.IMREAD_GRAYSCALE,
118 |         'unchanged': cv2.IMREAD_UNCHANGED
119 |     }
120 |     img = cv2.imdecode(img_np, imread_flags[flag])
121 |     if img is None:
122 |         raise ValueError(f'Failed to decode image bytes with flag {flag}.')
123 |     if float32:
124 |         img = img.astype(np.float32) / 255.
125 |     return img
126 | 
127 | def imfrombytesDP(content, flag='color', float32=False):
128 |     """Read an image from bytes.
129 | 
130 |     Args:
131 |         content (bytes): Image bytes got from files or other streams.
132 |         flag (str): Flags specifying the color type of a loaded image,
133 |             candidates are `color`, `grayscale` and `unchanged`.
134 |         float32 (bool): Whether to change to float32. If True, will also norm
135 |             to [0, 1]. Default: False.
136 | 
137 |     Returns:
138 |         ndarray: Loaded image array.
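
    Example (illustrative sketch; assumes ``content`` holds an encoded
    16-bit image, e.g. read via a file client; the path is hypothetical):

        >>> with open('dual_pixel_left.png', 'rb') as f:
        ...     content = f.read()
        >>> img = imfrombytesDP(content, float32=True)  # values in [0, 1]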
139 | """ 140 | img_np = np.frombuffer(content, np.uint8) 141 | if img_np is None: 142 | raise Exception('None .. !!!') 143 | img = cv2.imdecode(img_np, cv2.IMREAD_UNCHANGED) 144 | if float32: 145 | img = img.astype(np.float32) / 65535. 146 | return img 147 | 148 | def padding(img_lq, img_gt, gt_size): 149 | h, w, _ = img_lq.shape 150 | 151 | h_pad = max(0, gt_size - h) 152 | w_pad = max(0, gt_size - w) 153 | 154 | if h_pad == 0 and w_pad == 0: 155 | return img_lq, img_gt 156 | 157 | img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 158 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 159 | # print('img_lq', img_lq.shape, img_gt.shape) 160 | if img_lq.ndim == 2: 161 | img_lq = np.expand_dims(img_lq, axis=2) 162 | if img_gt.ndim == 2: 163 | img_gt = np.expand_dims(img_gt, axis=2) 164 | return img_lq, img_gt 165 | 166 | def padding_DP(img_lqL, img_lqR, img_gt, gt_size): 167 | h, w, _ = img_gt.shape 168 | 169 | h_pad = max(0, gt_size - h) 170 | w_pad = max(0, gt_size - w) 171 | 172 | if h_pad == 0 and w_pad == 0: 173 | return img_lqL, img_lqR, img_gt 174 | 175 | img_lqL = cv2.copyMakeBorder(img_lqL, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 176 | img_lqR = cv2.copyMakeBorder(img_lqR, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 177 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 178 | # print('img_lq', img_lq.shape, img_gt.shape) 179 | return img_lqL, img_lqR, img_gt 180 | 181 | def imwrite(img, file_path, params=None, auto_mkdir=True): 182 | """Write image to file. 183 | 184 | Args: 185 | img (ndarray): Image array to be written. 186 | file_path (str): Image file path. 187 | params (None or list): Same as opencv's :func:`imwrite` interface. 188 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 189 | whether to create it automatically. 190 | 191 | Returns: 192 | bool: Successful or not. 193 | """ 194 | if auto_mkdir: 195 | dir_name = os.path.abspath(os.path.dirname(file_path)) 196 | os.makedirs(dir_name, exist_ok=True) 197 | return cv2.imwrite(file_path, img, params) 198 | 199 | 200 | def crop_border(imgs, crop_border): 201 | """Crop borders of images. 202 | 203 | Args: 204 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 205 | crop_border (int): Crop border for each end of height and weight. 206 | 207 | Returns: 208 | list[ndarray]: Cropped images. 209 | """ 210 | if crop_border == 0: 211 | return imgs 212 | else: 213 | if isinstance(imgs, list): 214 | return [ 215 | v[crop_border:-crop_border, crop_border:-crop_border, ...] 216 | for v in imgs 217 | ] 218 | else: 219 | return imgs[crop_border:-crop_border, crop_border:-crop_border, 220 | ...] 221 | -------------------------------------------------------------------------------- /basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lmdb 3 | import sys 4 | from multiprocessing import Pool 5 | from os import path as osp 6 | from tqdm import tqdm 7 | 8 | 9 | def make_lmdb_from_imgs(data_path, 10 | lmdb_path, 11 | img_path_list, 12 | keys, 13 | batch=5000, 14 | compress_level=1, 15 | multiprocessing_read=False, 16 | n_thread=40, 17 | map_size=None): 18 | """Make lmdb from images. 19 | 20 | Contents of lmdb. The file structure is: 21 | example.lmdb 22 | ├── data.mdb 23 | ├── lock.mdb 24 | ├── meta_info.txt 25 | 26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 27 | https://lmdb.readthedocs.io/en/release/ for more details. 
28 | 
29 |     The meta_info.txt is a specified txt file to record the meta information
30 |     of our datasets. It will be automatically created when preparing
31 |     datasets by our provided dataset tools.
32 |     Each line in the txt file records 1)image name (with extension),
33 |     2)image shape, and 3)compression level, separated by a white space.
34 | 
35 |     For example, the meta information could be:
36 |     `000_00000000.png (720,1280,3) 1`, which means:
37 |     1) image name (with extension): 000_00000000.png;
38 |     2) image shape: (720,1280,3);
39 |     3) compression level: 1
40 | 
41 |     We use the image name without extension as the lmdb key.
42 | 
43 |     If `multiprocessing_read` is True, it will read all the images to memory
44 |     using multiprocessing. Thus, your server needs to have enough memory.
45 | 
46 |     Args:
47 |         data_path (str): Data path for reading images.
48 |         lmdb_path (str): Lmdb save path.
49 |         img_path_list (list[str]): Image path list.
50 |         keys (list[str]): Used for lmdb keys.
51 |         batch (int): After processing batch images, lmdb commits.
52 |             Default: 5000.
53 |         compress_level (int): Compress level when encoding images. Default: 1.
54 |         multiprocessing_read (bool): Whether to use multiprocessing to read
55 |             all the images to memory. Default: False.
56 |         n_thread (int): For multiprocessing.
57 |         map_size (int | None): Map size for lmdb env. If None, use the
58 |             estimated size from images. Default: None
59 |     """
60 | 
61 |     assert len(img_path_list) == len(keys), (
62 |         'img_path_list and keys should have the same length, '
63 |         f'but got {len(img_path_list)} and {len(keys)}')
64 |     print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
65 |     print(f'Total images: {len(img_path_list)}')
66 |     if not lmdb_path.endswith('.lmdb'):
67 |         raise ValueError("lmdb_path must end with '.lmdb'.")
68 |     if osp.exists(lmdb_path):
69 |         print(f'Folder {lmdb_path} already exists. 
Exit.') 70 | sys.exit(1) 71 | 72 | if multiprocessing_read: 73 | # read all the images to memory (multiprocessing) 74 | dataset = {} # use dict to keep the order for multiprocessing 75 | shapes = {} 76 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 77 | pbar = tqdm(total=len(img_path_list), unit='image') 78 | 79 | def callback(arg): 80 | """get the image data and update pbar.""" 81 | key, dataset[key], shapes[key] = arg 82 | pbar.update(1) 83 | pbar.set_description(f'Read {key}') 84 | 85 | pool = Pool(n_thread) 86 | for path, key in zip(img_path_list, keys): 87 | pool.apply_async( 88 | read_img_worker, 89 | args=(osp.join(data_path, path), key, compress_level), 90 | callback=callback) 91 | pool.close() 92 | pool.join() 93 | pbar.close() 94 | print(f'Finish reading {len(img_path_list)} images.') 95 | 96 | # create lmdb environment 97 | if map_size is None: 98 | # obtain data size for one image 99 | img = cv2.imread( 100 | osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 101 | _, img_byte = cv2.imencode( 102 | '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 103 | data_size_per_img = img_byte.nbytes 104 | print('Data size per image is: ', data_size_per_img) 105 | data_size = data_size_per_img * len(img_path_list) 106 | map_size = data_size * 10 107 | 108 | env = lmdb.open(lmdb_path, map_size=map_size) 109 | 110 | # write data to lmdb 111 | pbar = tqdm(total=len(img_path_list), unit='chunk') 112 | txn = env.begin(write=True) 113 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 114 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 115 | pbar.update(1) 116 | pbar.set_description(f'Write {key}') 117 | key_byte = key.encode('ascii') 118 | if multiprocessing_read: 119 | img_byte = dataset[key] 120 | h, w, c = shapes[key] 121 | else: 122 | _, img_byte, img_shape = read_img_worker( 123 | osp.join(data_path, path), key, compress_level) 124 | h, w, c = img_shape 125 | 126 | txn.put(key_byte, img_byte) 127 | # write meta information 128 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 129 | if idx % batch == 0: 130 | txn.commit() 131 | txn = env.begin(write=True) 132 | pbar.close() 133 | txn.commit() 134 | env.close() 135 | txt_file.close() 136 | print('\nFinish writing lmdb.') 137 | 138 | 139 | def read_img_worker(path, key, compress_level): 140 | """Read image worker. 141 | 142 | Args: 143 | path (str): Image path. 144 | key (str): Image key. 145 | compress_level (int): Compress level when encoding images. 146 | 147 | Returns: 148 | str: Image key. 149 | byte: Image byte. 150 | tuple[int]: Image shape. 151 | """ 152 | 153 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 154 | if img.ndim == 2: 155 | h, w = img.shape 156 | c = 1 157 | else: 158 | h, w, c = img.shape 159 | _, img_byte = cv2.imencode('.png', img, 160 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 161 | return (key, img_byte, (h, w, c)) 162 | 163 | 164 | class LmdbMaker(): 165 | """LMDB Maker. 166 | 167 | Args: 168 | lmdb_path (str): Lmdb save path. 169 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. 170 | batch (int): After processing batch images, lmdb commits. 171 | Default: 5000. 172 | compress_level (int): Compress level when encoding images. Default: 1. 
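
    Example (illustrative sketch; reuses ``read_img_worker`` defined above,
    and the paths are placeholders):

        >>> maker = LmdbMaker('datasets/gt.lmdb')
        >>> key, img_byte, img_shape = read_img_worker(
        ...     'datasets/gt/000.png', '000', 1)
        >>> maker.put(img_byte, key, img_shape)
        >>> maker.close()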
173 | """ 174 | 175 | def __init__(self, 176 | lmdb_path, 177 | map_size=1024**4, 178 | batch=5000, 179 | compress_level=1): 180 | if not lmdb_path.endswith('.lmdb'): 181 | raise ValueError("lmdb_path must end with '.lmdb'.") 182 | if osp.exists(lmdb_path): 183 | print(f'Folder {lmdb_path} already exists. Exit.') 184 | sys.exit(1) 185 | 186 | self.lmdb_path = lmdb_path 187 | self.batch = batch 188 | self.compress_level = compress_level 189 | self.env = lmdb.open(lmdb_path, map_size=map_size) 190 | self.txn = self.env.begin(write=True) 191 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 192 | self.counter = 0 193 | 194 | def put(self, img_byte, key, img_shape): 195 | self.counter += 1 196 | key_byte = key.encode('ascii') 197 | self.txn.put(key_byte, img_byte) 198 | # write meta information 199 | h, w, c = img_shape 200 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 201 | if self.counter % self.batch == 0: 202 | self.txn.commit() 203 | self.txn = self.env.begin(write=True) 204 | 205 | def close(self): 206 | self.txn.commit() 207 | self.env.close() 208 | self.txt_file.close() 209 | -------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class MessageLogger(): 11 | """Message logger for printing. 12 | 13 | Args: 14 | opt (dict): Config. It contains the following keys: 15 | name (str): Exp name. 16 | logger (dict): Contains 'print_freq' (str) for logger interval. 17 | train (dict): Contains 'total_iter' (int) for total iters. 18 | use_tb_logger (bool): Use tensorboard logger. 19 | start_iter (int): Start iter. Default: 1. 20 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 21 | """ 22 | 23 | def __init__(self, opt, start_iter=1, tb_logger=None): 24 | self.exp_name = opt['name'] 25 | self.interval = opt['logger']['print_freq'] 26 | self.start_iter = start_iter 27 | self.max_iters = opt['train']['total_iter'] 28 | self.use_tb_logger = opt['logger']['use_tb_logger'] 29 | self.tb_logger = tb_logger 30 | self.start_time = time.time() 31 | self.logger = get_root_logger() 32 | 33 | @master_only 34 | def __call__(self, log_vars): 35 | """Format logging message. 36 | 37 | Args: 38 | log_vars (dict): It contains the following keys: 39 | epoch (int): Epoch number. 40 | iter (int): Current iter. 41 | lrs (list): List for learning rates. 42 | 43 | time (float): Iter time. 44 | data_time (float): Data time for each iter. 
45 | """ 46 | # epoch, iter, learning rates 47 | epoch = log_vars.pop('epoch') 48 | current_iter = log_vars.pop('iter') 49 | lrs = log_vars.pop('lrs') 50 | 51 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') 52 | for v in lrs: 53 | message += f'{v:.3e},' 54 | message += ')] ' 55 | 56 | # time and estimated time 57 | if 'time' in log_vars.keys(): 58 | iter_time = log_vars.pop('time') 59 | data_time = log_vars.pop('data_time') 60 | 61 | total_time = time.time() - self.start_time 62 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 63 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 64 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 65 | message += f'[eta: {eta_str}, ' 66 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 67 | 68 | # other items, especially losses 69 | for k, v in log_vars.items(): 70 | message += f'{k}: {v:.4e} ' 71 | # tensorboard logger 72 | if self.use_tb_logger and 'debug' not in self.exp_name: 73 | if k.startswith('l_'): 74 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 75 | else: 76 | self.tb_logger.add_scalar(k, v, current_iter) 77 | self.logger.info(message) 78 | 79 | 80 | @master_only 81 | def init_tb_logger(log_dir): 82 | from torch.utils.tensorboard import SummaryWriter 83 | tb_logger = SummaryWriter(log_dir=log_dir) 84 | return tb_logger 85 | 86 | 87 | @master_only 88 | def init_wandb_logger(opt): 89 | """We now only use wandb to sync tensorboard log.""" 90 | import wandb 91 | logger = logging.getLogger('basicsr') 92 | 93 | project = opt['logger']['wandb']['project'] 94 | resume_id = opt['logger']['wandb'].get('resume_id') 95 | if resume_id: 96 | wandb_id = resume_id 97 | resume = 'allow' 98 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 99 | else: 100 | wandb_id = wandb.util.generate_id() 101 | resume = 'never' 102 | 103 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 104 | 105 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 106 | 107 | 108 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 109 | """Get the root logger. 110 | 111 | The logger will be initialized if it has not been initialized. By default a 112 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 113 | also be added. 114 | 115 | Args: 116 | logger_name (str): root logger name. Default: 'basicsr'. 117 | log_file (str | None): The log filename. If specified, a FileHandler 118 | will be added to the root logger. 119 | log_level (int): The root logger level. Note that only the process of 120 | rank 0 is affected, while other processes will set the level to 121 | "Error" and be silent most of the time. 122 | 123 | Returns: 124 | logging.Logger: The root logger. 
125 | """ 126 | logger = logging.getLogger(logger_name) 127 | # if the logger has been initialized, just return it 128 | if logger_name in initialized_logger: 129 | return logger 130 | 131 | format_str = '%(asctime)s %(levelname)s: %(message)s' 132 | stream_handler = logging.StreamHandler() 133 | stream_handler.setFormatter(logging.Formatter(format_str)) 134 | logger.addHandler(stream_handler) 135 | logger.propagate = False 136 | rank, _ = get_dist_info() 137 | if rank != 0: 138 | logger.setLevel('ERROR') 139 | elif log_file is not None: 140 | logger.setLevel(log_level) 141 | # add file handler 142 | file_handler = logging.FileHandler(log_file, 'w') 143 | file_handler.setFormatter(logging.Formatter(format_str)) 144 | file_handler.setLevel(log_level) 145 | logger.addHandler(file_handler) 146 | initialized_logger[logger_name] = True 147 | return logger 148 | 149 | 150 | def get_env_info(): 151 | """Get environment information. 152 | 153 | Currently, only log the software version. 154 | """ 155 | import torch 156 | import torchvision 157 | 158 | from basicsr.version import __version__ 159 | msg = r""" 160 | ____ _ _____ ____ 161 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 162 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 163 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 164 | /_____/ \__,_//____//_/ \___//____//_/ |_| 165 | ______ __ __ __ __ 166 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 167 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 168 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 169 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 170 | """ 171 | msg += ('\nVersion Information: ' 172 | f'\n\tBasicSR: {__version__}' 173 | f'\n\tPyTorch: {torch.__version__}' 174 | f'\n\tTorchVision: {torchvision.__version__}') 175 | return msg -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | from .logger import get_root_logger 10 | 11 | 12 | def set_random_seed(seed): 13 | """Set random seeds.""" 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | 21 | def get_time_str(): 22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 23 | 24 | 25 | def mkdir_and_rename(path): 26 | """mkdirs. If path exists, rename it with timestamp and create a new one. 27 | 28 | Args: 29 | path (str): Folder path. 30 | """ 31 | if osp.exists(path): 32 | new_name = path + '_archived_' + get_time_str() 33 | print(f'Path already exists. Rename it to {new_name}', flush=True) 34 | os.rename(path, new_name) 35 | os.makedirs(path, exist_ok=True) 36 | 37 | 38 | @master_only 39 | def make_exp_dirs(opt): 40 | """Make dirs for experiments.""" 41 | path_opt = opt['path'].copy() 42 | if opt['is_train']: 43 | mkdir_and_rename(path_opt.pop('experiments_root')) 44 | else: 45 | mkdir_and_rename(path_opt.pop('results_root')) 46 | for key, path in path_opt.items(): 47 | if ('strict_load' not in key) and ('pretrain_network' 48 | not in key) and ('resume' 49 | not in key): 50 | os.makedirs(path, exist_ok=True) 51 | 52 | 53 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 54 | """Scan a directory to find the interested files. 
55 | 
56 |     Args:
57 |         dir_path (str): Path of the directory.
58 |         suffix (str | tuple(str), optional): File suffix that we are
59 |             interested in. Default: None.
60 |         recursive (bool, optional): If set to True, recursively scan the
61 |             directory. Default: False.
62 |         full_path (bool, optional): If set to True, include the dir_path.
63 |             Default: False.
64 | 
65 |     Returns:
66 |         A generator for all the interested files with relative paths.
67 |     """
68 | 
69 |     if (suffix is not None) and not isinstance(suffix, (str, tuple)):
70 |         raise TypeError('"suffix" must be a string or tuple of strings')
71 | 
72 |     root = dir_path
73 | 
74 |     def _scandir(dir_path, suffix, recursive):
75 |         for entry in os.scandir(dir_path):
76 |             if not entry.name.startswith('.') and entry.is_file():
77 |                 if full_path:
78 |                     return_path = entry.path
79 |                 else:
80 |                     return_path = osp.relpath(entry.path, root)
81 | 
82 |                 if suffix is None:
83 |                     yield return_path
84 |                 elif return_path.endswith(suffix):
85 |                     yield return_path
86 |             else:
87 |                 if recursive:
88 |                     yield from _scandir(
89 |                         entry.path, suffix=suffix, recursive=recursive)
90 |                 else:
91 |                     continue
92 | 
93 |     return _scandir(dir_path, suffix=suffix, recursive=recursive)
94 | 
95 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
96 |     """Scan a directory to find the interested files.
97 | 
98 |     Args:
99 |         dir_path (str): Path of the directory.
100 |         keywords (str | tuple(str), optional): File keywords that we are
101 |             interested in. Default: None.
102 |         recursive (bool, optional): If set to True, recursively scan the
103 |             directory. Default: False.
104 |         full_path (bool, optional): If set to True, include the dir_path.
105 |             Default: False.
106 | 
107 |     Returns:
108 |         A generator for all the interested files with relative paths.
109 |     """
110 | 
111 |     if (keywords is not None) and not isinstance(keywords, (str, tuple)):
112 |         raise TypeError('"keywords" must be a string or tuple of strings')
113 | 
114 |     root = dir_path
115 | 
116 |     def _scandir(dir_path, keywords, recursive):
117 |         for entry in os.scandir(dir_path):
118 |             if not entry.name.startswith('.') and entry.is_file():
119 |                 if full_path:
120 |                     return_path = entry.path
121 |                 else:
122 |                     return_path = osp.relpath(entry.path, root)
123 | 
124 |                 if keywords is None:
125 |                     yield return_path
126 |                 elif return_path.find(keywords) >= 0:  # keyword may be at index 0
127 |                     yield return_path
128 |             else:
129 |                 if recursive:
130 |                     yield from _scandir(
131 |                         entry.path, keywords=keywords, recursive=recursive)
132 |                 else:
133 |                     continue
134 | 
135 |     return _scandir(dir_path, keywords=keywords, recursive=recursive)
136 | 
137 | def check_resume(opt, resume_iter):
138 |     """Check resume states and pretrain_network paths.
139 | 
140 |     Args:
141 |         opt (dict): Options.
142 |         resume_iter (int): Resume iteration.
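
    For example (illustrative values), resuming ``network_g`` at
    ``resume_iter=100000`` sets ``opt['path']['pretrain_network_g']`` to
    ``opt['path']['models']`` joined with ``net_g_100000.pth``.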
143 | """ 144 | logger = get_root_logger() 145 | if opt['path']['resume_state']: 146 | # get all the networks 147 | networks = [key for key in opt.keys() if key.startswith('network_')] 148 | flag_pretrain = False 149 | for network in networks: 150 | if opt['path'].get(f'pretrain_{network}') is not None: 151 | flag_pretrain = True 152 | if flag_pretrain: 153 | logger.warning( 154 | 'pretrain_network path will be ignored during resuming.') 155 | # set pretrained model paths 156 | for network in networks: 157 | name = f'pretrain_{network}' 158 | basename = network.replace('network_', '') 159 | if opt['path'].get('ignore_resume_networks') is None or ( 160 | basename not in opt['path']['ignore_resume_networks']): 161 | opt['path'][name] = osp.join( 162 | opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 163 | logger.info(f"Set {name} to {opt['path'][name]}") 164 | 165 | 166 | def sizeof_fmt(size, suffix='B'): 167 | """Get human readable file size. 168 | 169 | Args: 170 | size (int): File size. 171 | suffix (str): Suffix. Default: 'B'. 172 | 173 | Return: 174 | str: Formated file siz. 175 | """ 176 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 177 | if abs(size) < 1024.0: 178 | return f'{size:3.1f} {unit}{suffix}' 179 | size /= 1024.0 180 | return f'{size:3.1f} Y{suffix}' 181 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from collections import OrderedDict 3 | from os import path as osp 4 | 5 | 6 | def ordered_yaml(): 7 | """Support OrderedDict for yaml. 8 | 9 | Returns: 10 | yaml Loader and Dumper. 11 | """ 12 | try: 13 | from yaml import CDumper as Dumper 14 | from yaml import CLoader as Loader 15 | except ImportError: 16 | from yaml import Dumper, Loader 17 | 18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 19 | 20 | def dict_representer(dumper, data): 21 | return dumper.represent_dict(data.items()) 22 | 23 | def dict_constructor(loader, node): 24 | return OrderedDict(loader.construct_pairs(node)) 25 | 26 | Dumper.add_representer(OrderedDict, dict_representer) 27 | Loader.add_constructor(_mapping_tag, dict_constructor) 28 | return Loader, Dumper 29 | 30 | 31 | def parse(opt_path, is_train=True): 32 | """Parse option file. 33 | 34 | Args: 35 | opt_path (str): Option file path. 36 | is_train (str): Indicate whether in training or not. Default: True. 37 | 38 | Returns: 39 | (dict): Options. 
40 | """ 41 | with open(opt_path, mode='r') as f: 42 | Loader, _ = ordered_yaml() 43 | opt = yaml.load(f, Loader=Loader) 44 | 45 | opt['is_train'] = is_train 46 | 47 | # datasets 48 | for phase, dataset in opt['datasets'].items(): 49 | # for several datasets, e.g., test_1, test_2 50 | phase = phase.split('_')[0] 51 | dataset['phase'] = phase 52 | if 'scale' in opt: 53 | dataset['scale'] = opt['scale'] 54 | if dataset.get('dataroot_gt') is not None: 55 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 56 | if dataset.get('dataroot_lq') is not None: 57 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 58 | 59 | # paths 60 | for key, val in opt['path'].items(): 61 | if (val is not None) and ('resume_state' in key 62 | or 'pretrain_network' in key): 63 | opt['path'][key] = osp.expanduser(val) 64 | opt['path']['root'] = osp.abspath( 65 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 66 | if is_train: 67 | experiments_root = osp.join(opt['path']['root'], 'experiments', 68 | opt['name']) 69 | opt['path']['experiments_root'] = experiments_root 70 | opt['path']['models'] = osp.join(experiments_root, 'models') 71 | opt['path']['training_states'] = osp.join(experiments_root, 72 | 'training_states') 73 | opt['path']['log'] = experiments_root 74 | opt['path']['visualization'] = osp.join(experiments_root, 75 | 'visualization') 76 | 77 | # change some options for debug mode 78 | if 'debug' in opt['name']: 79 | if 'val' in opt: 80 | opt['val']['val_freq'] = 8 81 | opt['logger']['print_freq'] = 1 82 | opt['logger']['save_checkpoint_freq'] = 8 83 | else: # test 84 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 85 | opt['path']['results_root'] = results_root 86 | opt['path']['log'] = results_root 87 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 88 | 89 | return opt 90 | 91 | 92 | def dict2str(opt, indent_level=1): 93 | """dict to string for printing options. 94 | 95 | Args: 96 | opt (dict): Option dict. 97 | indent_level (int): Indent level. Default: 1. 98 | 99 | Return: 100 | (str): Option string for printing. 
101 | """ 102 | msg = '\n' 103 | for k, v in opt.items(): 104 | if isinstance(v, dict): 105 | msg += ' ' * (indent_level * 2) + k + ':[' 106 | msg += dict2str(v, indent_level + 1) 107 | msg += ' ' * (indent_level * 2) + ']\n' 108 | else: 109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 110 | return msg 111 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Sun Jan 28 22:05:08 2024 3 | __version__ = '1.2.0+733ceb2' 4 | short_version = '1.2.0' 5 | version_info = (1, 2, 0) 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import sys 8 | import time 9 | import torch 10 | from torch.utils.cpp_extension import (BuildExtension, CppExtension, 11 | CUDAExtension) 12 | 13 | version_file = 'basicsr/version.py' 14 | 15 | 16 | def readme(): 17 | return '' 18 | # with open('README.md', encoding='utf-8') as f: 19 | # content = f.read() 20 | # return content 21 | 22 | 23 | def get_git_hash(): 24 | 25 | def _minimal_ext_cmd(cmd): 26 | # construct minimal environment 27 | env = {} 28 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 29 | v = os.environ.get(k) 30 | if v is not None: 31 | env[k] = v 32 | # LANGUAGE is used on win32 33 | env['LANGUAGE'] = 'C' 34 | env['LANG'] = 'C' 35 | env['LC_ALL'] = 'C' 36 | out = subprocess.Popen( 37 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 38 | return out 39 | 40 | try: 41 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 42 | sha = out.strip().decode('ascii') 43 | except OSError: 44 | sha = 'unknown' 45 | 46 | return sha 47 | 48 | 49 | def get_hash(): 50 | if os.path.exists('.git'): 51 | sha = get_git_hash()[:7] 52 | elif os.path.exists(version_file): 53 | try: 54 | from basicsr.version import __version__ 55 | sha = __version__.split('+')[-1] 56 | except ImportError: 57 | raise ImportError('Unable to get git version') 58 | else: 59 | sha = 'unknown' 60 | 61 | return sha 62 | 63 | 64 | def write_version_py(): 65 | content = """# GENERATED VERSION FILE 66 | # TIME: {} 67 | __version__ = '{}' 68 | short_version = '{}' 69 | version_info = ({}) 70 | """ 71 | sha = get_hash() 72 | with open('VERSION', 'r') as f: 73 | SHORT_VERSION = f.read().strip() 74 | VERSION_INFO = ', '.join( 75 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 76 | VERSION = SHORT_VERSION + '+' + sha 77 | 78 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, 79 | VERSION_INFO) 80 | with open(version_file, 'w') as f: 81 | f.write(version_file_str) 82 | 83 | 84 | def get_version(): 85 | with open(version_file, 'r') as f: 86 | exec(compile(f.read(), version_file, 'exec')) 87 | return locals()['__version__'] 88 | 89 | 90 | def make_cuda_ext(name, module, sources, sources_cuda=None): 91 | if sources_cuda is None: 92 | sources_cuda = [] 93 | define_macros = [] 94 | extra_compile_args = {'cxx': []} 95 | 96 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 97 | define_macros += [('WITH_CUDA', None)] 98 | extension = CUDAExtension 99 | extra_compile_args['nvcc'] = [ 100 | '-D__CUDA_NO_HALF_OPERATORS__', 101 | '-D__CUDA_NO_HALF_CONVERSIONS__', 102 | '-D__CUDA_NO_HALF2_OPERATORS__', 103 | ] 104 | sources 
+= sources_cuda 105 | else: 106 | print(f'Compiling {name} without CUDA') 107 | extension = CppExtension 108 | 109 | return extension( 110 | name=f'{module}.{name}', 111 | sources=[os.path.join(*module.split('.'), p) for p in sources], 112 | define_macros=define_macros, 113 | extra_compile_args=extra_compile_args) 114 | 115 | 116 | def get_requirements(filename='requirements.txt'): 117 | return [] 118 | here = os.path.dirname(os.path.realpath(__file__)) 119 | with open(os.path.join(here, filename), 'r') as f: 120 | requires = [line.replace('\n', '') for line in f.readlines()] 121 | return requires 122 | 123 | 124 | if __name__ == '__main__': 125 | if '--no_cuda_ext' in sys.argv: 126 | ext_modules = [] 127 | sys.argv.remove('--no_cuda_ext') 128 | else: 129 | ext_modules = [ 130 | make_cuda_ext( 131 | name='deform_conv_ext', 132 | module='basicsr.models.ops.dcn', 133 | sources=['src/deform_conv_ext.cpp'], 134 | sources_cuda=[ 135 | 'src/deform_conv_cuda.cpp', 136 | 'src/deform_conv_cuda_kernel.cu' 137 | ]), 138 | make_cuda_ext( 139 | name='fused_act_ext', 140 | module='basicsr.models.ops.fused_act', 141 | sources=['src/fused_bias_act.cpp'], 142 | sources_cuda=['src/fused_bias_act_kernel.cu']), 143 | make_cuda_ext( 144 | name='upfirdn2d_ext', 145 | module='basicsr.models.ops.upfirdn2d', 146 | sources=['src/upfirdn2d.cpp'], 147 | sources_cuda=['src/upfirdn2d_kernel.cu']), 148 | ] 149 | 150 | write_version_py() 151 | setup( 152 | name='basicsr', 153 | version=get_version(), 154 | description='Open Source Image and Video Super-Resolution Toolbox', 155 | long_description=readme(), 156 | author='Xintao Wang', 157 | author_email='xintao.wang@outlook.com', 158 | keywords='computer vision, restoration, super resolution', 159 | url='https://github.com/xinntao/BasicSR', 160 | packages=find_packages( 161 | exclude=('options', 'datasets', 'experiments', 'results', 162 | 'tb_logger', 'wandb')), 163 | classifiers=[ 164 | 'Development Status :: 4 - Beta', 165 | 'License :: OSI Approved :: Apache Software License', 166 | 'Operating System :: OS Independent', 167 | 'Programming Language :: Python :: 3', 168 | 'Programming Language :: Python :: 3.7', 169 | 'Programming Language :: Python :: 3.8', 170 | ], 171 | license='Apache License 2.0', 172 | setup_requires=['cython', 'numpy'], 173 | install_requires=get_requirements(), 174 | ext_modules=ext_modules, 175 | cmdclass={'build_ext': BuildExtension}, 176 | zip_safe=False) 177 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | 2 | #dehaze 3 | # python test_SOTS.py 4 | 5 | #derain 6 | # python test_spad.py 7 | 8 | #deraindrop 9 | # python test_AGAN.py 10 | 11 | #deblur 12 | # python test_FPro.py 13 | 14 | #demoire 15 | # python test_moire.py -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | 5 | python -m torch.distributed.launch --nproc_per_node=2 --master_port=4321 basicsr/train.py -opt $CONFIG --launcher pytorch 6 | --------------------------------------------------------------------------------
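Usage note (illustrative; the config path below is one of the Options files in
this repository): train.sh expects a config file as its first argument and
launches a 2-GPU distributed run via --nproc_per_node=2, e.g.

    bash train.sh Dehaze/Options/RealDehazing_FPro.yml

test.sh collects the per-task test commands; uncomment the line for the task
you want to evaluate before running `bash test.sh`.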