├── .gitignore ├── Deblurring ├── Datasets │ └── README.md ├── Options │ └── Deblurring_Restormer.yml ├── README.md ├── download_data.py ├── evaluate_gopro_hide.m ├── evaluate_realblur.py ├── generate_patches_gopro.py ├── pretrained_models │ └── README.md ├── test.py └── utils.py ├── Denoising ├── Datasets │ └── README.md ├── Options │ ├── GaussianColorDenoising_Restormer.yml │ ├── GaussianColorDenoising_RestormerSigma15.yml │ ├── GaussianColorDenoising_RestormerSigma25.yml │ ├── GaussianColorDenoising_RestormerSigma50.yml │ ├── GaussianGrayDenoising_Restormer.yml │ ├── GaussianGrayDenoising_RestormerSigma15.yml │ ├── GaussianGrayDenoising_RestormerSigma25.yml │ ├── GaussianGrayDenoising_RestormerSigma50.yml │ └── RealDenoising_Restormer.yml ├── README.md ├── download_data.py ├── evaluate_gaussian_color_denoising.py ├── evaluate_gaussian_gray_denoising.py ├── evaluate_sidd.m ├── generate_patches_dfwb.py ├── generate_patches_sidd.py ├── pretrained_models │ └── README.md ├── test_gaussian_color_denoising.py ├── test_gaussian_gray_denoising.py ├── test_real_denoising_dnd.py ├── test_real_denoising_sidd.py └── utils.py ├── Deraining ├── Datasets │ └── README.md ├── Options │ └── Deraining_Restormer.yml ├── README.md ├── download_data.py ├── evaluate_PSNR_SSIM.m ├── pretrained_models │ └── README.md ├── test.py └── utils.py ├── INSTALL.md ├── LICENSE ├── README.md ├── basicsr ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── ffhq_dataset.py │ ├── meta_info │ │ ├── meta_info_DIV2K800sub_GT.txt │ │ ├── meta_info_REDS4_test_GT.txt │ │ ├── meta_info_REDS_GT.txt │ │ ├── meta_info_REDSofficial4_test_GT.txt │ │ ├── meta_info_REDSval_official_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_fast_GT.txt │ │ ├── meta_info_Vimeo90K_test_medium_GT.txt │ │ ├── meta_info_Vimeo90K_test_slow_GT.txt │ │ └── meta_info_Vimeo90K_train_GT.txt │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── reds_dataset.py │ ├── 
single_image_dataset.py │ ├── transforms.py │ ├── video_test_dataset.py │ └── vimeo90k_dataset.py ├── metrics │ ├── __init__.py │ ├── fid.py │ ├── metric_util.py │ ├── niqe.py │ ├── niqe_pris_params.npz │ └── psnr_ssim.py ├── models │ ├── __init__.py │ ├── archs │ │ ├── Maxim_arch.py │ │ ├── __init__.py │ │ ├── arch_util.py │ │ ├── maxim.py │ │ └── restormer_arch.py │ ├── base_model.py │ ├── image_restoration_model.py │ ├── losses │ │ ├── __init__.py │ │ ├── loss_util.py │ │ └── losses.py │ └── lr_scheduler.py ├── test.py ├── train.py ├── utils │ ├── __init__.py │ ├── bundle_submissions.py │ ├── create_lmdb.py │ ├── dist_util.py │ ├── download_util.py │ ├── face_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ └── options.py └── version.py ├── demo.py ├── images ├── Deblurring │ └── input │ │ ├── 109fromGOPR1096.MP4.png │ │ ├── 110fromGOPR1087.MP4.png │ │ ├── 1fromGOPR0950.png │ │ └── 1fromGOPR1096.MP4.png ├── Dehazing │ └── input │ │ ├── 0003_0.8_0.2.png │ │ ├── 0010_0.95_0.16.png │ │ ├── 0014_0.8_0.12.png │ │ ├── 0048_0.9_0.2.png │ │ ├── 1440_10.png │ │ └── 1444_10.png ├── Denoising │ └── input │ │ ├── 0003_30.png │ │ ├── 0011_23.png │ │ ├── 0013_19.png │ │ └── 0039_04.png ├── Deraining │ └── input │ │ ├── 0.jpg │ │ ├── 1.png │ │ ├── 15.png │ │ └── 55.png ├── Enhancement │ └── input │ │ ├── 1.png │ │ ├── 111.png │ │ ├── 748.png │ │ └── a4541-DSC_0040-2.png ├── Results │ ├── 0.jpg │ ├── 1.png │ ├── 15.png │ └── 55.png └── overview.png ├── maxim_pytorch ├── README.md ├── jax2torch.py └── maxim_torch.py ├── setup.cfg ├── setup.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | *.pyc 3 | *.png 4 | *.tif 5 | *.jpg 6 | *.pth 7 | *.mat 8 | *.npy 9 | .DS_Store 10 | 11 | -------------------------------------------------------------------------------- /Deblurring/Datasets/README.md: 
-------------------------------------------------------------------------------- 1 | For training and testing, your directory structure should look like this 2 | 3 | `Datasets`
4 |  `├──train`
5 |      `└──GoPro`
6 |           `├──input_crops`
7 |           `└──target_crops`
8 |  `├──val`
9 |      `└──GoPro`
10 |           `├──input_crops`
11 |           `└──target_crops`
12 |  `└──test`
13 |      `├──GoPro`
14 |           `├──input`
15 |           `└──target`
16 |      `├──HIDE`
17 |           `├──input`
18 |           `└──target`
19 |      `├──RealBlur_J`
20 |           `├──input`
21 |           `└──target`
22 |      `└──RealBlur_R`
23 |           `├──input`
24 |           `└──target` 25 | 26 | -------------------------------------------------------------------------------- /Deblurring/Options/Deblurring_Restormer.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: Deblurring_Restormer 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage 13 | dataroot_gt: ./Motion_Deblurring/Datasets/train/GoPro/target_crops 14 | dataroot_lq: ./Motion_Deblurring/Datasets/train/GoPro/input_crops 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ### -------------Progressive training-------------------------- 27 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 28 | iters: [92000,64000,48000,36000,36000,24000] 29 | gt_size: 384 # Max patch size for progressive training 30 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 
31 | ### ------------------------------------------------------------ 32 | 33 | ### ------- Training on single fixed-patch size 128x128--------- 34 | # mini_batch_sizes: [8] 35 | # iters: [300000] 36 | # gt_size: 128 37 | # gt_sizes: [128] 38 | ### ------------------------------------------------------------ 39 | 40 | dataset_enlarge_ratio: 1 41 | prefetch_mode: ~ 42 | 43 | val: 44 | name: ValSet 45 | type: Dataset_PairedImage 46 | dataroot_gt: ./Motion_Deblurring/Datasets/val/GoPro/target_crops 47 | dataroot_lq: ./Motion_Deblurring/Datasets/val/GoPro/input_crops 48 | io_backend: 49 | type: disk 50 | 51 | # network structures 52 | network_g: 53 | type: Restormer 54 | inp_channels: 3 55 | out_channels: 3 56 | dim: 48 57 | num_blocks: [4,6,6,8] 58 | num_refinement_blocks: 4 59 | heads: [1,2,4,8] 60 | ffn_expansion_factor: 2.66 61 | bias: False 62 | LayerNorm_type: WithBias 63 | dual_pixel_task: False 64 | 65 | 66 | # path 67 | path: 68 | pretrain_network_g: ~ 69 | strict_load_g: true 70 | resume_state: ~ 71 | 72 | # training settings 73 | train: 74 | total_iter: 300000 75 | warmup_iter: -1 # no warm up 76 | use_grad_clip: true 77 | 78 | # Split 300k iterations into two cycles. 79 | # 1st cycle: fixed 3e-4 LR for 92k iters. 80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
81 | scheduler: 82 | type: CosineAnnealingRestartCyclicLR 83 | periods: [92000, 208000] 84 | restart_weights: [1,1] 85 | eta_mins: [0.0003,0.000001] 86 | 87 | mixing_augs: 88 | mixup: false 89 | mixup_beta: 1.2 90 | use_identity: true 91 | 92 | optim_g: 93 | type: AdamW 94 | lr: !!float 3e-4 95 | weight_decay: !!float 1e-4 96 | betas: [0.9, 0.999] 97 | 98 | # losses 99 | pixel_opt: 100 | type: L1Loss 101 | loss_weight: 1 102 | reduction: mean 103 | 104 | # validation settings 105 | val: 106 | window_size: 8 107 | val_freq: !!float 4e3 108 | save_img: false 109 | rgb2bgr: true 110 | use_image: true 111 | max_minibatch: 8 112 | 113 | metrics: 114 | psnr: # metric name, can be arbitrary 115 | type: calculate_psnr 116 | crop_border: 0 117 | test_y_channel: false 118 | 119 | # logging settings 120 | logger: 121 | print_freq: 1000 122 | save_checkpoint_freq: !!float 4e3 123 | use_tb_logger: true 124 | wandb: 125 | project: ~ 126 | resume_id: ~ 127 | 128 | # dist training settings 129 | dist_params: 130 | backend: nccl 131 | port: 29500 132 | -------------------------------------------------------------------------------- /Deblurring/README.md: -------------------------------------------------------------------------------- 1 | ## Training 2 | 3 | 1. To download GoPro training and testing data, run 4 | ``` 5 | python download_data.py --data train-test 6 | ``` 7 | 8 | 2. Generate image patches from full-resolution training images of GoPro dataset 9 | ``` 10 | python generate_patches_gopro.py 11 | ``` 12 | 13 | 3. To train Restormer, run 14 | ``` 15 | cd Restormer 16 | ./train.sh Deblurring/Options/Deblurring_Restormer.yml 17 | ``` 18 | 19 | **Note:** The above training script uses 8 GPUs by default. 
To use any other number of GPUs, modify [Restormer/train.sh](../train.sh) and [Deblurring/Options/Deblurring_Restormer.yml](Options/Deblurring_Restormer.yml) 20 | 21 | ## Evaluation 22 | 23 | Download the pre-trained [model](https://drive.google.com/drive/folders/1czMyfRTQDX3j3ErByYeZ1PM4GVLbJeGK?usp=sharing) and place it in `./pretrained_models/` 24 | 25 | #### Testing on GoPro dataset 26 | 27 | - Download GoPro testset, run 28 | ``` 29 | python download_data.py --data test --dataset GoPro 30 | ``` 31 | 32 | - Testing 33 | ``` 34 | python test.py --dataset GoPro 35 | ``` 36 | 37 | #### Testing on HIDE dataset 38 | 39 | - Download HIDE testset, run 40 | ``` 41 | python download_data.py --data test --dataset HIDE 42 | ``` 43 | 44 | - Testing 45 | ``` 46 | python test.py --dataset HIDE 47 | ``` 48 | 49 | #### Testing on RealBlur-J dataset 50 | 51 | - Download RealBlur-J testset, run 52 | ``` 53 | python download_data.py --data test --dataset RealBlur_J 54 | ``` 55 | 56 | - Testing 57 | ``` 58 | python test.py --dataset RealBlur_J 59 | ``` 60 | 61 | #### Testing on RealBlur-R dataset 62 | 63 | - Download RealBlur-R testset, run 64 | ``` 65 | python download_data.py --data test --dataset RealBlur_R 66 | ``` 67 | 68 | - Testing 69 | ``` 70 | python test.py --dataset RealBlur_R 71 | ``` 72 | 73 | #### To reproduce PSNR/SSIM scores of the paper (Table 2) on GoPro and HIDE datasets, run this MATLAB script 74 | 75 | ``` 76 | evaluate_gopro_hide.m 77 | ``` 78 | 79 | #### To reproduce PSNR/SSIM scores of the paper (Table 2) on RealBlur dataset, run 80 | 81 | ``` 82 | evaluate_realblur.py 83 | ``` 84 | -------------------------------------------------------------------------------- /Deblurring/download_data.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan 
Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | ## Download training and testing data for single-image motion deblurring task 6 | import os 7 | # import gdown 8 | import shutil 9 | 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--data', type=str, required=True, help='train, test or train-test') 14 | parser.add_argument('--dataset', type=str, default='GoPro', help='all, GoPro, HIDE, RealBlur_R, RealBlur_J') 15 | args = parser.parse_args() 16 | 17 | ### Google drive IDs ###### 18 | GoPro_train = '1zgALzrLCC_tcXKu_iHQTHukKUVT1aodI' ## https://drive.google.com/file/d/1zgALzrLCC_tcXKu_iHQTHukKUVT1aodI/view?usp=sharing 19 | GoPro_test = '1k6DTSHu4saUgrGTYkkZXTptILyG9RRll' ## https://drive.google.com/file/d/1k6DTSHu4saUgrGTYkkZXTptILyG9RRll/view?usp=sharing 20 | HIDE_test = '1XRomKYJF1H92g1EuD06pCQe4o6HlwB7A' ## https://drive.google.com/file/d/1XRomKYJF1H92g1EuD06pCQe4o6HlwB7A/view?usp=sharing 21 | RealBlurR_test = '1glgeWXCy7Y0qWDc0MXBTUlZYJf8984hS' ## https://drive.google.com/file/d/1glgeWXCy7Y0qWDc0MXBTUlZYJf8984hS/view?usp=sharing 22 | RealBlurJ_test = '1Rb1DhhXmX7IXfilQ-zL9aGjQfAAvQTrW' ## https://drive.google.com/file/d/1Rb1DhhXmX7IXfilQ-zL9aGjQfAAvQTrW/view?usp=sharing 23 | 24 | dataset = args.dataset 25 | 26 | for data in args.data.split('-'): 27 | if data == 'train': 28 | print('GoPro Training Data!') 29 | os.makedirs(os.path.join('Datasets', 'Downloads'), exist_ok=True) 30 | # gdown.download(id=GoPro_train, output='Datasets/Downloads/train.zip', quiet=False) 31 | os.system(f'gdrive download {GoPro_train} --path Datasets/Downloads/') 32 | print('Extracting GoPro data...') 33 | shutil.unpack_archive('Datasets/Downloads/train.zip', 'Datasets/Downloads') 34 | os.rename(os.path.join('Datasets', 'Downloads', 'train'), os.path.join('Datasets', 'Downloads', 'GoPro')) 35 | os.remove('Datasets/Downloads/train.zip') 36 | 37 | if data == 'test': 38 | if dataset == 'all' or dataset == 'GoPro': 39 | print('GoPro Testing 
Data!') 40 | # gdown.download(id=GoPro_test, output='Datasets/test.zip', quiet=False) 41 | os.system(f'gdrive download {GoPro_test} --path Datasets/') 42 | print('Extracting GoPro Data...') 43 | shutil.unpack_archive('Datasets/test.zip', 'Datasets') 44 | os.remove('Datasets/test.zip') 45 | 46 | if dataset == 'all' or dataset == 'HIDE': 47 | print('HIDE Testing Data!') 48 | # gdown.download(id=HIDE_test, output='Datasets/test.zip', quiet=False) 49 | os.system(f'gdrive download {HIDE_test} --path Datasets/') 50 | print('Extracting HIDE Data...') 51 | shutil.unpack_archive('Datasets/test.zip', 'Datasets') 52 | os.remove('Datasets/test.zip') 53 | 54 | if dataset == 'all' or dataset == 'RealBlur_R': 55 | print('RealBlur_R Testing Data!') 56 | # gdown.download(id=RealBlurR_test, output='Datasets/test.zip', quiet=False) 57 | os.system(f'gdrive download {RealBlurR_test} --path Datasets/') 58 | print('Extracting RealBlur_R Data...') 59 | shutil.unpack_archive('Datasets/test.zip', 'Datasets') 60 | os.remove('Datasets/test.zip') 61 | 62 | if dataset == 'all' or dataset == 'RealBlur_J': 63 | print('RealBlur_J testing Data!') 64 | # gdown.download(id=RealBlurJ_test, output='Datasets/test.zip', quiet=False) 65 | os.system(f'gdrive download {RealBlurJ_test} --path Datasets/') 66 | print('Extracting RealBlur_J Data...') 67 | shutil.unpack_archive('Datasets/test.zip', 'Datasets') 68 | os.remove('Datasets/test.zip') 69 | 70 | 71 | # print('Download completed successfully!') 72 | -------------------------------------------------------------------------------- /Deblurring/evaluate_gopro_hide.m: -------------------------------------------------------------------------------- 1 | %% Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | %% Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | %% https://arxiv.org/abs/2111.09881 4 | 5 | close all;clear all; 6 | 7 | % datasets = {'GoPro'}; 8 | datasets = {'GoPro', 
'HIDE'}; 9 | num_set = length(datasets); 10 | 11 | tic 12 | delete(gcp('nocreate')) 13 | parpool('local',20); 14 | 15 | for idx_set = 1:num_set 16 | file_path = strcat('./results/', datasets{idx_set}, '/'); 17 | gt_path = strcat('./Datasets/test/', datasets{idx_set}, '/target/'); 18 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))]; 19 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))]; 20 | img_num = length(path_list); 21 | 22 | total_psnr = 0; 23 | total_ssim = 0; 24 | if img_num > 0 25 | parfor j = 1:img_num 26 | image_name = path_list(j).name; 27 | gt_name = gt_list(j).name; 28 | input = imread(strcat(file_path,image_name)); 29 | gt = imread(strcat(gt_path, gt_name)); 30 | ssim_val = ssim(input, gt); 31 | psnr_val = psnr(input, gt); 32 | total_ssim = total_ssim + ssim_val; 33 | total_psnr = total_psnr + psnr_val; 34 | end 35 | end 36 | qm_psnr = total_psnr / img_num; 37 | qm_ssim = total_ssim / img_num; 38 | 39 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim); 40 | 41 | end 42 | delete(gcp('nocreate')) 43 | toc 44 | -------------------------------------------------------------------------------- /Deblurring/evaluate_realblur.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import os 6 | import numpy as np 7 | from glob import glob 8 | from natsort import natsorted 9 | from skimage import io 10 | import cv2 11 | from skimage.metrics import structural_similarity 12 | from tqdm import tqdm 13 | import concurrent.futures 14 | 15 | def image_align(deblurred, gt): 16 | # this function is based on kohler evaluation code 17 | z = deblurred 18 | c = np.ones_like(z) 19 | x = gt 20 | 21 | zs = (np.sum(x * z) / np.sum(z * 
z)) * z # simple intensity matching 22 | 23 | warp_mode = cv2.MOTION_HOMOGRAPHY 24 | warp_matrix = np.eye(3, 3, dtype=np.float32) 25 | 26 | # Specify the number of iterations. 27 | number_of_iterations = 100 28 | 29 | termination_eps = 0 30 | 31 | criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 32 | number_of_iterations, termination_eps) 33 | 34 | # Run the ECC algorithm. The results are stored in warp_matrix. 35 | (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY), warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5) 36 | 37 | target_shape = x.shape 38 | shift = warp_matrix 39 | 40 | zr = cv2.warpPerspective( 41 | zs, 42 | warp_matrix, 43 | (target_shape[1], target_shape[0]), 44 | flags=cv2.INTER_CUBIC+ cv2.WARP_INVERSE_MAP, 45 | borderMode=cv2.BORDER_REFLECT) 46 | 47 | cr = cv2.warpPerspective( 48 | np.ones_like(zs, dtype='float32'), 49 | warp_matrix, 50 | (target_shape[1], target_shape[0]), 51 | flags=cv2.INTER_NEAREST+ cv2.WARP_INVERSE_MAP, 52 | borderMode=cv2.BORDER_CONSTANT, 53 | borderValue=0) 54 | 55 | zr = zr * cr 56 | xr = x * cr 57 | 58 | return zr, xr, cr, shift 59 | 60 | def compute_psnr(image_true, image_test, image_mask, data_range=None): 61 | # this function is based on skimage.metrics.peak_signal_noise_ratio 62 | err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask) 63 | return 10 * np.log10((data_range ** 2) / err) 64 | 65 | 66 | def compute_ssim(tar_img, prd_img, cr1): 67 | ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True, use_sample_covariance=False, data_range = 1.0, full=True) 68 | ssim_map = ssim_map * cr1 69 | r = int(3.5 * 1.5 + 0.5) # radius as in ndimage 70 | win_size = 2 * r + 1 71 | pad = (win_size - 1) // 2 72 | ssim = ssim_map[pad:-pad,pad:-pad,:] 73 | crop_cr1 = cr1[pad:-pad,pad:-pad,:] 74 | ssim = ssim.sum(axis=0).sum(axis=0)/crop_cr1.sum(axis=0).sum(axis=0) 75 
| ssim = np.mean(ssim) 76 | return ssim 77 | 78 | def proc(filename): 79 | tar,prd = filename 80 | tar_img = io.imread(tar) 81 | prd_img = io.imread(prd) 82 | 83 | tar_img = tar_img.astype(np.float32)/255.0 84 | prd_img = prd_img.astype(np.float32)/255.0 85 | 86 | prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img) 87 | 88 | PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1) 89 | SSIM = compute_ssim(tar_img, prd_img, cr1) 90 | return (PSNR,SSIM) 91 | 92 | datasets = ['RealBlur_J', 'RealBlur_R'] 93 | 94 | for dataset in datasets: 95 | 96 | file_path = os.path.join('results' , dataset) 97 | gt_path = os.path.join('Datasets', 'test', dataset, 'target') 98 | 99 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg'))) 100 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg'))) 101 | 102 | assert len(path_list) != 0, "Predicted files not found" 103 | assert len(gt_list) != 0, "Target files not found" 104 | 105 | psnr, ssim = [], [] 106 | img_files =[(i, j) for i,j in zip(gt_list,path_list)] 107 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: 108 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)): 109 | psnr.append(PSNR_SSIM[0]) 110 | ssim.append(PSNR_SSIM[1]) 111 | 112 | avg_psnr = sum(psnr)/len(psnr) 113 | avg_ssim = sum(ssim)/len(ssim) 114 | 115 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim)) 116 | -------------------------------------------------------------------------------- /Deblurring/generate_patches_gopro.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | ##### Data preparation file for training Restormer 
on the GoPro Dataset ######## 6 | 7 | import cv2 8 | import numpy as np 9 | from glob import glob 10 | from natsort import natsorted 11 | import os 12 | from tqdm import tqdm 13 | from pdb import set_trace as stx 14 | from joblib import Parallel, delayed 15 | import multiprocessing 16 | 17 | def train_files(file_): 18 | lr_file, hr_file = file_ 19 | filename = os.path.splitext(os.path.split(lr_file)[-1])[0] 20 | lr_img = cv2.imread(lr_file) 21 | hr_img = cv2.imread(hr_file) 22 | num_patch = 0 23 | w, h = lr_img.shape[:2] 24 | if w > p_max and h > p_max: 25 | w1 = list(np.arange(0, w-patch_size, patch_size-overlap, dtype=np.int)) 26 | h1 = list(np.arange(0, h-patch_size, patch_size-overlap, dtype=np.int)) 27 | w1.append(w-patch_size) 28 | h1.append(h-patch_size) 29 | for i in w1: 30 | for j in h1: 31 | num_patch += 1 32 | 33 | lr_patch = lr_img[i:i+patch_size, j:j+patch_size,:] 34 | hr_patch = hr_img[i:i+patch_size, j:j+patch_size,:] 35 | 36 | lr_savename = os.path.join(lr_tar, filename + '-' + str(num_patch) + '.png') 37 | hr_savename = os.path.join(hr_tar, filename + '-' + str(num_patch) + '.png') 38 | 39 | cv2.imwrite(lr_savename, lr_patch) 40 | cv2.imwrite(hr_savename, hr_patch) 41 | 42 | else: 43 | lr_savename = os.path.join(lr_tar, filename + '.png') 44 | hr_savename = os.path.join(hr_tar, filename + '.png') 45 | 46 | cv2.imwrite(lr_savename, lr_img) 47 | cv2.imwrite(hr_savename, hr_img) 48 | 49 | def val_files(file_): 50 | lr_file, hr_file = file_ 51 | filename = os.path.splitext(os.path.split(lr_file)[-1])[0] 52 | lr_img = cv2.imread(lr_file) 53 | hr_img = cv2.imread(hr_file) 54 | 55 | lr_savename = os.path.join(lr_tar, filename + '.png') 56 | hr_savename = os.path.join(hr_tar, filename + '.png') 57 | 58 | w, h = lr_img.shape[:2] 59 | 60 | i = (w-val_patch_size)//2 61 | j = (h-val_patch_size)//2 62 | 63 | lr_patch = lr_img[i:i+val_patch_size, j:j+val_patch_size,:] 64 | hr_patch = hr_img[i:i+val_patch_size, j:j+val_patch_size,:] 65 | 66 | 
cv2.imwrite(lr_savename, lr_patch) 67 | cv2.imwrite(hr_savename, hr_patch) 68 | 69 | ############ Prepare Training data #################### 70 | num_cores = 10 71 | patch_size = 512 72 | overlap = 256 73 | p_max = 0 74 | 75 | src = 'Datasets/Downloads/GoPro' 76 | tar = 'Datasets/train/GoPro' 77 | 78 | lr_tar = os.path.join(tar, 'input_crops') 79 | hr_tar = os.path.join(tar, 'target_crops') 80 | 81 | os.makedirs(lr_tar, exist_ok=True) 82 | os.makedirs(hr_tar, exist_ok=True) 83 | 84 | lr_files = natsorted(glob(os.path.join(src, 'input', '*.png')) + glob(os.path.join(src, 'input', '*.jpg'))) 85 | hr_files = natsorted(glob(os.path.join(src, 'target', '*.png')) + glob(os.path.join(src, 'target', '*.jpg'))) 86 | 87 | files = [(i, j) for i, j in zip(lr_files, hr_files)] 88 | 89 | Parallel(n_jobs=num_cores)(delayed(train_files)(file_) for file_ in tqdm(files)) 90 | 91 | 92 | ############ Prepare validation data #################### 93 | val_patch_size = 256 94 | src = 'Datasets/test/GoPro' 95 | tar = 'Datasets/val/GoPro' 96 | 97 | lr_tar = os.path.join(tar, 'input_crops') 98 | hr_tar = os.path.join(tar, 'target_crops') 99 | 100 | os.makedirs(lr_tar, exist_ok=True) 101 | os.makedirs(hr_tar, exist_ok=True) 102 | 103 | lr_files = natsorted(glob(os.path.join(src, 'input', '*.png')) + glob(os.path.join(src, 'input', '*.jpg'))) 104 | hr_files = natsorted(glob(os.path.join(src, 'target', '*.png')) + glob(os.path.join(src, 'target', '*.jpg'))) 105 | 106 | files = [(i, j) for i, j in zip(lr_files, hr_files)] 107 | 108 | Parallel(n_jobs=num_cores)(delayed(val_files)(file_) for file_ in tqdm(files)) 109 | -------------------------------------------------------------------------------- /Deblurring/pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | pre-trained deblurring model is available [here](https://drive.google.com/drive/folders/1czMyfRTQDX3j3ErByYeZ1PM4GVLbJeGK?usp=sharing) 
-------------------------------------------------------------------------------- /Deblurring/test.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | 6 | import numpy as np 7 | import os 8 | import argparse 9 | from tqdm import tqdm 10 | 11 | import torch.nn as nn 12 | import torch 13 | import torch.nn.functional as F 14 | import utils 15 | 16 | from natsort import natsorted 17 | from glob import glob 18 | from basicsr.models.archs.restormer_arch import Restormer 19 | from skimage import img_as_ubyte 20 | from pdb import set_trace as stx 21 | 22 | parser = argparse.ArgumentParser(description='Single Image Motion Deblurring using Restormer') 23 | 24 | parser.add_argument('--input_dir', default='./Datasets/', type=str, help='Directory of validation images') 25 | parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results') 26 | parser.add_argument('--weights', default='./pretrained_models/motion_deblurring.pth', type=str, help='Path to weights') 27 | parser.add_argument('--dataset', default='GoPro', type=str, help='Test Dataset') # ['GoPro', 'HIDE', 'RealBlur_J', 'RealBlur_R'] 28 | 29 | args = parser.parse_args() 30 | 31 | ####### Load yaml ####### 32 | yaml_file = 'Options/Deblurring_Restormer.yml' 33 | import yaml 34 | 35 | try: 36 | from yaml import CLoader as Loader 37 | except ImportError: 38 | from yaml import Loader 39 | 40 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader) 41 | 42 | s = x['network_g'].pop('type') 43 | ########################## 44 | 45 | model_restoration = Restormer(**x['network_g']) 46 | 47 | checkpoint = torch.load(args.weights) 48 | model_restoration.load_state_dict(checkpoint['params']) 49 | print("===>Testing using weights: 
",args.weights) 50 | model_restoration.cuda() 51 | model_restoration = nn.DataParallel(model_restoration) 52 | model_restoration.eval() 53 | 54 | 55 | factor = 8 56 | dataset = args.dataset 57 | result_dir = os.path.join(args.result_dir, dataset) 58 | os.makedirs(result_dir, exist_ok=True) 59 | 60 | inp_dir = os.path.join(args.input_dir, 'test', dataset, 'input') 61 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg'))) 62 | with torch.no_grad(): 63 | for file_ in tqdm(files): 64 | torch.cuda.ipc_collect() 65 | torch.cuda.empty_cache() 66 | 67 | img = np.float32(utils.load_img(file_))/255. 68 | img = torch.from_numpy(img).permute(2,0,1) 69 | input_ = img.unsqueeze(0).cuda() 70 | 71 | # Padding in case images are not multiples of 8 72 | h,w = input_.shape[2], input_.shape[3] 73 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor 74 | padh = H-h if h%factor!=0 else 0 75 | padw = W-w if w%factor!=0 else 0 76 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 77 | 78 | restored = model_restoration(input_) 79 | 80 | # Unpad images to original dimensions 81 | restored = restored[:,:,:h,:w] 82 | 83 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 84 | 85 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored)) 86 | -------------------------------------------------------------------------------- /Deblurring/utils.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import cv2 8 | import math 9 | 10 | def calculate_psnr(img1, img2, border=0): 11 | # img1 and img2 have range [0, 255] 12 | #img1 = img1.squeeze() 13 
| #img2 = img2.squeeze() 14 | if not img1.shape == img2.shape: 15 | raise ValueError('Input images must have the same dimensions.') 16 | h, w = img1.shape[:2] 17 | img1 = img1[border:h-border, border:w-border] 18 | img2 = img2[border:h-border, border:w-border] 19 | 20 | img1 = img1.astype(np.float64) 21 | img2 = img2.astype(np.float64) 22 | mse = np.mean((img1 - img2)**2) 23 | if mse == 0: 24 | return float('inf') 25 | return 20 * math.log10(255.0 / math.sqrt(mse)) 26 | 27 | 28 | # -------------------------------------------- 29 | # SSIM 30 | # -------------------------------------------- 31 | def calculate_ssim(img1, img2, border=0): 32 | '''calculate SSIM 33 | the same outputs as MATLAB's 34 | img1, img2: [0, 255] 35 | ''' 36 | #img1 = img1.squeeze() 37 | #img2 = img2.squeeze() 38 | if not img1.shape == img2.shape: 39 | raise ValueError('Input images must have the same dimensions.') 40 | h, w = img1.shape[:2] 41 | img1 = img1[border:h-border, border:w-border] 42 | img2 = img2[border:h-border, border:w-border] 43 | 44 | if img1.ndim == 2: 45 | return ssim(img1, img2) 46 | elif img1.ndim == 3: 47 | if img1.shape[2] == 3: 48 | ssims = [] 49 | for i in range(3): 50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i])) 51 | return np.array(ssims).mean() 52 | elif img1.shape[2] == 1: 53 | return ssim(np.squeeze(img1), np.squeeze(img2)) 54 | else: 55 | raise ValueError('Wrong input image dimensions.') 56 | 57 | 58 | def ssim(img1, img2): 59 | C1 = (0.01 * 255)**2 60 | C2 = (0.03 * 255)**2 61 | 62 | img1 = img1.astype(np.float64) 63 | img2 = img2.astype(np.float64) 64 | kernel = cv2.getGaussianKernel(11, 1.5) 65 | window = np.outer(kernel, kernel.transpose()) 66 | 67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 69 | mu1_sq = mu1**2 70 | mu2_sq = mu2**2 71 | mu1_mu2 = mu1 * mu2 72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - 
mu2_sq 74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 75 | 76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 77 | (sigma1_sq + sigma2_sq + C2)) 78 | return ssim_map.mean() 79 | 80 | def load_img(filepath): 81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 82 | 83 | def save_img(filepath, img): 84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 85 | 86 | def load_gray_img(filepath): 87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2) 88 | 89 | def save_gray_img(filepath, img): 90 | cv2.imwrite(filepath, img) 91 | -------------------------------------------------------------------------------- /Denoising/Datasets/README.md: -------------------------------------------------------------------------------- 1 | For training and testing, your directory structure should look like this 2 | 3 | 4 | `Datasets`
5 |  `├──train`
6 |      `├──DFWB`
7 |      `└──SIDD`
8 |           `├──input_crops`
9 |           `└──target_crops`
10 |  `├──val`
11 |      `└──SIDD`
12 |           `├──input_crops`
13 |           `└──target_crops`
14 |  `└──test`
15 |      `├──BSD68`
16 |      `├──CBSD68`
17 |      `├──Kodak`
18 |      `├──McMaster`
19 |      `├──Set12`
20 |      `├──Urban100`
21 |      `├──SIDD`
22 |           `├──ValidationNoisyBlocksSrgb.mat`
23 |           `└──ValidationGtBlocksSrgb.mat`
24 |      `└──DND`

25 |           `├──info.mat`
26 |           `└──images_srgb`
27 |                `├──0001.mat`
28 |                `├──0002.mat`
29 |                `├── ... `
30 |                `└──0050.mat` 31 | -------------------------------------------------------------------------------- /Denoising/Options/GaussianColorDenoising_Restormer.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: GaussianColorDenoising_Restormer 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_GaussianDenoising 13 | sigma_type: random 14 | sigma_range: [0,50] 15 | in_ch: 3 ## RGB image 16 | dataroot_gt: ./Denoising/Datasets/train/DFWB 17 | dataroot_lq: none 18 | geometric_augs: true 19 | 20 | filename_tmpl: '{}' 21 | io_backend: 22 | type: disk 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | 29 | ### -------------Progressive training-------------------------- 30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 31 | iters: [92000,64000,48000,36000,36000,24000] 32 | gt_size: 384 # Max patch size for progressive training 33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 
34 | ### ------------------------------------------------------------ 35 | 36 | ### ------- Training on single fixed-patch size 128x128--------- 37 | # mini_batch_sizes: [8] 38 | # iters: [300000] 39 | # gt_size: 128 40 | # gt_sizes: [128] 41 | ### ------------------------------------------------------------ 42 | 43 | dataset_enlarge_ratio: 1 44 | prefetch_mode: ~ 45 | 46 | val: 47 | name: ValSet 48 | type: Dataset_GaussianDenoising 49 | sigma_test: 25 50 | in_ch: 3 ## RGB image 51 | dataroot_gt: ./Denoising/Datasets/test/CBSD68 52 | dataroot_lq: none 53 | io_backend: 54 | type: disk 55 | 56 | # network structures 57 | network_g: 58 | type: Restormer 59 | inp_channels: 3 60 | out_channels: 3 61 | dim: 48 62 | num_blocks: [4,6,6,8] 63 | num_refinement_blocks: 4 64 | heads: [1,2,4,8] 65 | ffn_expansion_factor: 2.66 66 | bias: False 67 | LayerNorm_type: BiasFree 68 | dual_pixel_task: False 69 | 70 | # path 71 | path: 72 | pretrain_network_g: ~ 73 | strict_load_g: true 74 | resume_state: ~ 75 | 76 | # training settings 77 | train: 78 | total_iter: 300000 79 | warmup_iter: -1 # no warm up 80 | use_grad_clip: true 81 | 82 | # Split 300k iterations into two cycles. 83 | # 1st cycle: fixed 3e-4 LR for 92k iters. 84 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
85 | scheduler: 86 | type: CosineAnnealingRestartCyclicLR 87 | periods: [92000, 208000] 88 | restart_weights: [1,1] 89 | eta_mins: [0.0003,0.000001] 90 | 91 | mixing_augs: 92 | mixup: true 93 | mixup_beta: 1.2 94 | use_identity: true 95 | 96 | optim_g: 97 | type: AdamW 98 | lr: !!float 3e-4 99 | weight_decay: !!float 1e-4 100 | betas: [0.9, 0.999] 101 | 102 | # losses 103 | pixel_opt: 104 | type: L1Loss 105 | loss_weight: 1 106 | reduction: mean 107 | 108 | # validation settings 109 | val: 110 | window_size: 8 111 | val_freq: !!float 4e3 112 | save_img: false 113 | rgb2bgr: true 114 | use_image: false 115 | max_minibatch: 8 116 | 117 | metrics: 118 | psnr: # metric name, can be arbitrary 119 | type: calculate_psnr 120 | crop_border: 0 121 | test_y_channel: false 122 | 123 | # logging settings 124 | logger: 125 | print_freq: 1000 126 | save_checkpoint_freq: !!float 4e3 127 | use_tb_logger: true 128 | wandb: 129 | project: ~ 130 | resume_id: ~ 131 | 132 | # dist training settings 133 | dist_params: 134 | backend: nccl 135 | port: 29500 136 | -------------------------------------------------------------------------------- /Denoising/Options/GaussianColorDenoising_RestormerSigma15.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: GaussianColorDenoising_RestormerSigma15 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_GaussianDenoising 13 | sigma_type: constant 14 | sigma_range: 15 15 | in_ch: 3 ## RGB image 16 | dataroot_gt: ./Denoising/Datasets/train/DFWB 17 | dataroot_lq: none 18 | geometric_augs: true 19 | 20 | filename_tmpl: '{}' 21 | io_backend: 22 | type: disk 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | 29 | ### -------------Progressive 
training-------------------------- 30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 31 | iters: [92000,64000,48000,36000,36000,24000] 32 | gt_size: 384 # Max patch size for progressive training 33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 34 | ### ------------------------------------------------------------ 35 | 36 | ### ------- Training on single fixed-patch size 128x128--------- 37 | # mini_batch_sizes: [8] 38 | # iters: [300000] 39 | # gt_size: 128 40 | # gt_sizes: [128] 41 | ### ------------------------------------------------------------ 42 | 43 | dataset_enlarge_ratio: 1 44 | prefetch_mode: ~ 45 | 46 | val: 47 | name: ValSet 48 | type: Dataset_GaussianDenoising 49 | sigma_test: 15 50 | in_ch: 3 ## RGB image 51 | dataroot_gt: ./Denoising/Datasets/test/CBSD68 52 | dataroot_lq: none 53 | io_backend: 54 | type: disk 55 | 56 | # network structures 57 | network_g: 58 | type: Restormer 59 | inp_channels: 3 60 | out_channels: 3 61 | dim: 48 62 | num_blocks: [4,6,6,8] 63 | num_refinement_blocks: 4 64 | heads: [1,2,4,8] 65 | ffn_expansion_factor: 2.66 66 | bias: False 67 | LayerNorm_type: BiasFree 68 | dual_pixel_task: False 69 | 70 | 71 | # path 72 | path: 73 | pretrain_network_g: ~ 74 | strict_load_g: true 75 | resume_state: ~ 76 | 77 | # training settings 78 | train: 79 | total_iter: 300000 80 | warmup_iter: -1 # no warm up 81 | use_grad_clip: true 82 | 83 | # Split 300k iterations into two cycles. 84 | # 1st cycle: fixed 3e-4 LR for 92k iters. 85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
86 | scheduler: 87 | type: CosineAnnealingRestartCyclicLR 88 | periods: [92000, 208000] 89 | restart_weights: [1,1] 90 | eta_mins: [0.0003,0.000001] 91 | 92 | mixing_augs: 93 | mixup: true 94 | mixup_beta: 1.2 95 | use_identity: true 96 | 97 | optim_g: 98 | type: AdamW 99 | lr: !!float 3e-4 100 | weight_decay: !!float 1e-4 101 | betas: [0.9, 0.999] 102 | 103 | # losses 104 | pixel_opt: 105 | type: L1Loss 106 | loss_weight: 1 107 | reduction: mean 108 | 109 | # validation settings 110 | val: 111 | window_size: 8 112 | val_freq: !!float 4e3 113 | save_img: false 114 | rgb2bgr: true 115 | use_image: false 116 | max_minibatch: 8 117 | 118 | metrics: 119 | psnr: # metric name, can be arbitrary 120 | type: calculate_psnr 121 | crop_border: 0 122 | test_y_channel: false 123 | 124 | # logging settings 125 | logger: 126 | print_freq: 1000 127 | save_checkpoint_freq: !!float 4e3 128 | use_tb_logger: true 129 | wandb: 130 | project: ~ 131 | resume_id: ~ 132 | 133 | # dist training settings 134 | dist_params: 135 | backend: nccl 136 | port: 29500 137 | -------------------------------------------------------------------------------- /Denoising/Options/GaussianColorDenoising_RestormerSigma25.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: GaussianColorDenoising_RestormerSigma25 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_GaussianDenoising 13 | sigma_type: constant 14 | sigma_range: 25 15 | in_ch: 3 ## RGB image 16 | dataroot_gt: ./Denoising/Datasets/train/DFWB 17 | dataroot_lq: none 18 | geometric_augs: true 19 | 20 | filename_tmpl: '{}' 21 | io_backend: 22 | type: disk 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | 29 | ### -------------Progressive 
training-------------------------- 30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 31 | iters: [92000,64000,48000,36000,36000,24000] 32 | gt_size: 384 # Max patch size for progressive training 33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 34 | ### ------------------------------------------------------------ 35 | 36 | ### ------- Training on single fixed-patch size 128x128--------- 37 | # mini_batch_sizes: [8] 38 | # iters: [300000] 39 | # gt_size: 128 40 | # gt_sizes: [128] 41 | ### ------------------------------------------------------------ 42 | 43 | dataset_enlarge_ratio: 1 44 | prefetch_mode: ~ 45 | 46 | val: 47 | name: ValSet 48 | type: Dataset_GaussianDenoising 49 | sigma_test: 25 50 | in_ch: 3 ## RGB image 51 | dataroot_gt: ./Denoising/Datasets/test/CBSD68 52 | dataroot_lq: none 53 | io_backend: 54 | type: disk 55 | 56 | # network structures 57 | network_g: 58 | type: Restormer 59 | inp_channels: 3 60 | out_channels: 3 61 | dim: 48 62 | num_blocks: [4,6,6,8] 63 | num_refinement_blocks: 4 64 | heads: [1,2,4,8] 65 | ffn_expansion_factor: 2.66 66 | bias: False 67 | LayerNorm_type: BiasFree 68 | dual_pixel_task: False 69 | 70 | 71 | # path 72 | path: 73 | pretrain_network_g: ~ 74 | strict_load_g: true 75 | resume_state: ~ 76 | 77 | # training settings 78 | train: 79 | total_iter: 300000 80 | warmup_iter: -1 # no warm up 81 | use_grad_clip: true 82 | 83 | # Split 300k iterations into two cycles. 84 | # 1st cycle: fixed 3e-4 LR for 92k iters. 85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
86 | scheduler: 87 | type: CosineAnnealingRestartCyclicLR 88 | periods: [92000, 208000] 89 | restart_weights: [1,1] 90 | eta_mins: [0.0003,0.000001] 91 | 92 | mixing_augs: 93 | mixup: true 94 | mixup_beta: 1.2 95 | use_identity: true 96 | 97 | optim_g: 98 | type: AdamW 99 | lr: !!float 3e-4 100 | weight_decay: !!float 1e-4 101 | betas: [0.9, 0.999] 102 | 103 | # losses 104 | pixel_opt: 105 | type: L1Loss 106 | loss_weight: 1 107 | reduction: mean 108 | 109 | # validation settings 110 | val: 111 | window_size: 8 112 | val_freq: !!float 4e3 113 | save_img: false 114 | rgb2bgr: true 115 | use_image: false 116 | max_minibatch: 8 117 | 118 | metrics: 119 | psnr: # metric name, can be arbitrary 120 | type: calculate_psnr 121 | crop_border: 0 122 | test_y_channel: false 123 | 124 | # logging settings 125 | logger: 126 | print_freq: 1000 127 | save_checkpoint_freq: !!float 4e3 128 | use_tb_logger: true 129 | wandb: 130 | project: ~ 131 | resume_id: ~ 132 | 133 | # dist training settings 134 | dist_params: 135 | backend: nccl 136 | port: 29500 137 | -------------------------------------------------------------------------------- /Denoising/Options/GaussianColorDenoising_RestormerSigma50.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: GaussianColorDenoising_RestormerSigma50 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_GaussianDenoising 13 | sigma_type: constant 14 | sigma_range: 50 15 | in_ch: 3 ## RGB image 16 | dataroot_gt: ./Denoising/Datasets/train/DFWB 17 | dataroot_lq: none 18 | geometric_augs: true 19 | 20 | filename_tmpl: '{}' 21 | io_backend: 22 | type: disk 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | 29 | ### -------------Progressive 
training-------------------------- 30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 31 | iters: [92000,64000,48000,36000,36000,24000] 32 | gt_size: 384 # Max patch size for progressive training 33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 34 | ### ------------------------------------------------------------ 35 | 36 | ### ------- Training on single fixed-patch size 128x128--------- 37 | # mini_batch_sizes: [8] 38 | # iters: [300000] 39 | # gt_size: 128 40 | # gt_sizes: [128] 41 | ### ------------------------------------------------------------ 42 | 43 | dataset_enlarge_ratio: 1 44 | prefetch_mode: ~ 45 | 46 | val: 47 | name: ValSet 48 | type: Dataset_GaussianDenoising 49 | sigma_test: 50 50 | in_ch: 3 ## RGB image 51 | dataroot_gt: ./Denoising/Datasets/test/CBSD68 52 | dataroot_lq: none 53 | io_backend: 54 | type: disk 55 | 56 | # network structures 57 | network_g: 58 | type: Restormer 59 | inp_channels: 3 60 | out_channels: 3 61 | dim: 48 62 | num_blocks: [4,6,6,8] 63 | num_refinement_blocks: 4 64 | heads: [1,2,4,8] 65 | ffn_expansion_factor: 2.66 66 | bias: False 67 | LayerNorm_type: BiasFree 68 | dual_pixel_task: False 69 | 70 | 71 | # path 72 | path: 73 | pretrain_network_g: ~ 74 | strict_load_g: true 75 | resume_state: ~ 76 | 77 | # training settings 78 | train: 79 | total_iter: 300000 80 | warmup_iter: -1 # no warm up 81 | use_grad_clip: true 82 | 83 | # Split 300k iterations into two cycles. 84 | # 1st cycle: fixed 3e-4 LR for 92k iters. 85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
86 | scheduler: 87 | type: CosineAnnealingRestartCyclicLR 88 | periods: [92000, 208000] 89 | restart_weights: [1,1] 90 | eta_mins: [0.0003,0.000001] 91 | 92 | mixing_augs: 93 | mixup: true 94 | mixup_beta: 1.2 95 | use_identity: true 96 | 97 | optim_g: 98 | type: AdamW 99 | lr: !!float 3e-4 100 | weight_decay: !!float 1e-4 101 | betas: [0.9, 0.999] 102 | 103 | # losses 104 | pixel_opt: 105 | type: L1Loss 106 | loss_weight: 1 107 | reduction: mean 108 | 109 | # validation settings 110 | val: 111 | window_size: 8 112 | val_freq: !!float 4e3 113 | save_img: false 114 | rgb2bgr: true 115 | use_image: false 116 | max_minibatch: 8 117 | 118 | metrics: 119 | psnr: # metric name, can be arbitrary 120 | type: calculate_psnr 121 | crop_border: 0 122 | test_y_channel: false 123 | 124 | # logging settings 125 | logger: 126 | print_freq: 1000 127 | save_checkpoint_freq: !!float 4e3 128 | use_tb_logger: true 129 | wandb: 130 | project: ~ 131 | resume_id: ~ 132 | 133 | # dist training settings 134 | dist_params: 135 | backend: nccl 136 | port: 29500 137 | -------------------------------------------------------------------------------- /Denoising/Options/GaussianGrayDenoising_Restormer.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: GaussianGrayDenoising_Restormer 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_GaussianDenoising 13 | sigma_type: random 14 | sigma_range: [0,50] 15 | in_ch: 1 ## Grayscale image 16 | dataroot_gt: ./Denoising/Datasets/train/DFWB 17 | dataroot_lq: none 18 | geometric_augs: true 19 | 20 | filename_tmpl: '{}' 21 | io_backend: 22 | type: disk 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | 29 | ### -------------Progressive training-------------------------- 30 
| mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 31 | iters: [92000,64000,48000,36000,36000,24000] 32 | gt_size: 384 # Max patch size for progressive training 33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 34 | ### ------------------------------------------------------------ 35 | 36 | ### ------- Training on single fixed-patch size 128x128--------- 37 | # mini_batch_sizes: [8] 38 | # iters: [300000] 39 | # gt_size: 128 40 | # gt_sizes: [128] 41 | ### ------------------------------------------------------------ 42 | 43 | dataset_enlarge_ratio: 1 44 | prefetch_mode: ~ 45 | 46 | val: 47 | name: ValSet 48 | type: Dataset_GaussianDenoising 49 | sigma_test: 25 50 | in_ch: 1 ## Grayscale image 51 | dataroot_gt: ./Denoising/Datasets/test/BSD68 52 | dataroot_lq: none 53 | io_backend: 54 | type: disk 55 | 56 | # network structures 57 | network_g: 58 | type: Restormer 59 | inp_channels: 1 60 | out_channels: 1 61 | dim: 48 62 | num_blocks: [4,6,6,8] 63 | num_refinement_blocks: 4 64 | heads: [1,2,4,8] 65 | ffn_expansion_factor: 2.66 66 | bias: False 67 | LayerNorm_type: BiasFree 68 | dual_pixel_task: False 69 | 70 | 71 | # path 72 | path: 73 | pretrain_network_g: ~ 74 | strict_load_g: true 75 | resume_state: ~ 76 | 77 | # training settings 78 | train: 79 | total_iter: 300000 80 | warmup_iter: -1 # no warm up 81 | use_grad_clip: true 82 | 83 | # Split 300k iterations into two cycles. 84 | # 1st cycle: fixed 3e-4 LR for 92k iters. 85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
86 | scheduler: 87 | type: CosineAnnealingRestartCyclicLR 88 | periods: [92000, 208000] 89 | restart_weights: [1,1] 90 | eta_mins: [0.0003,0.000001] 91 | 92 | mixing_augs: 93 | mixup: true 94 | mixup_beta: 1.2 95 | use_identity: true 96 | 97 | optim_g: 98 | type: AdamW 99 | lr: !!float 3e-4 100 | weight_decay: !!float 1e-4 101 | betas: [0.9, 0.999] 102 | 103 | # losses 104 | pixel_opt: 105 | type: L1Loss 106 | loss_weight: 1 107 | reduction: mean 108 | 109 | # validation settings 110 | val: 111 | window_size: 8 112 | val_freq: !!float 4e3 113 | save_img: false 114 | rgb2bgr: true 115 | use_image: false 116 | max_minibatch: 8 117 | 118 | metrics: 119 | psnr: # metric name, can be arbitrary 120 | type: calculate_psnr 121 | crop_border: 0 122 | test_y_channel: false 123 | 124 | # logging settings 125 | logger: 126 | print_freq: 1000 127 | save_checkpoint_freq: !!float 4e3 128 | use_tb_logger: true 129 | wandb: 130 | project: ~ 131 | resume_id: ~ 132 | 133 | # dist training settings 134 | dist_params: 135 | backend: nccl 136 | port: 29500 137 | -------------------------------------------------------------------------------- /Denoising/Options/GaussianGrayDenoising_RestormerSigma15.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: GaussianGrayDenoising_RestormerSigma15 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_GaussianDenoising 13 | sigma_type: constant 14 | sigma_range: 15 15 | in_ch: 1 ## Grayscale image 16 | dataroot_gt: ./Denoising/Datasets/train/DFWB 17 | dataroot_lq: none 18 | geometric_augs: true 19 | 20 | filename_tmpl: '{}' 21 | io_backend: 22 | type: disk 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | 29 | ### -------------Progressive 
training-------------------------- 30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 31 | iters: [92000,64000,48000,36000,36000,24000] 32 | gt_size: 384 # Max patch size for progressive training 33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 34 | ### ------------------------------------------------------------ 35 | 36 | ### ------- Training on single fixed-patch size 128x128--------- 37 | # mini_batch_sizes: [8] 38 | # iters: [300000] 39 | # gt_size: 128 40 | # gt_sizes: [128] 41 | ### ------------------------------------------------------------ 42 | 43 | dataset_enlarge_ratio: 1 44 | prefetch_mode: ~ 45 | 46 | val: 47 | name: ValSet 48 | type: Dataset_GaussianDenoising 49 | sigma_test: 15 50 | in_ch: 1 ## Grayscale image 51 | dataroot_gt: ./Denoising/Datasets/test/BSD68 52 | dataroot_lq: none 53 | io_backend: 54 | type: disk 55 | 56 | # network structures 57 | network_g: 58 | type: Restormer 59 | inp_channels: 1 60 | out_channels: 1 61 | dim: 48 62 | num_blocks: [4,6,6,8] 63 | num_refinement_blocks: 4 64 | heads: [1,2,4,8] 65 | ffn_expansion_factor: 2.66 66 | bias: False 67 | LayerNorm_type: BiasFree 68 | dual_pixel_task: False 69 | 70 | 71 | # path 72 | path: 73 | pretrain_network_g: ~ 74 | strict_load_g: true 75 | resume_state: ~ 76 | 77 | # training settings 78 | train: 79 | total_iter: 300000 80 | warmup_iter: -1 # no warm up 81 | use_grad_clip: true 82 | 83 | # Split 300k iterations into two cycles. 84 | # 1st cycle: fixed 3e-4 LR for 92k iters. 85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
86 | scheduler: 87 | type: CosineAnnealingRestartCyclicLR 88 | periods: [92000, 208000] 89 | restart_weights: [1,1] 90 | eta_mins: [0.0003,0.000001] 91 | 92 | mixing_augs: 93 | mixup: true 94 | mixup_beta: 1.2 95 | use_identity: true 96 | 97 | optim_g: 98 | type: AdamW 99 | lr: !!float 3e-4 100 | weight_decay: !!float 1e-4 101 | betas: [0.9, 0.999] 102 | 103 | # losses 104 | pixel_opt: 105 | type: L1Loss 106 | loss_weight: 1 107 | reduction: mean 108 | 109 | # validation settings 110 | val: 111 | window_size: 8 112 | val_freq: !!float 4e3 113 | save_img: false 114 | rgb2bgr: true 115 | use_image: false 116 | max_minibatch: 8 117 | 118 | metrics: 119 | psnr: # metric name, can be arbitrary 120 | type: calculate_psnr 121 | crop_border: 0 122 | test_y_channel: false 123 | 124 | # logging settings 125 | logger: 126 | print_freq: 1000 127 | save_checkpoint_freq: !!float 4e3 128 | use_tb_logger: true 129 | wandb: 130 | project: ~ 131 | resume_id: ~ 132 | 133 | # dist training settings 134 | dist_params: 135 | backend: nccl 136 | port: 29500 137 | -------------------------------------------------------------------------------- /Denoising/Options/GaussianGrayDenoising_RestormerSigma25.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: GaussianGrayDenoising_RestormerSigma25 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_GaussianDenoising 13 | sigma_type: constant 14 | sigma_range: 25 15 | in_ch: 1 ## Grayscale image 16 | dataroot_gt: ./Denoising/Datasets/train/DFWB 17 | dataroot_lq: none 18 | geometric_augs: true 19 | 20 | filename_tmpl: '{}' 21 | io_backend: 22 | type: disk 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | 29 | ### -------------Progressive 
training-------------------------- 30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 31 | iters: [92000,64000,48000,36000,36000,24000] 32 | gt_size: 384 # Max patch size for progressive training 33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 34 | ### ------------------------------------------------------------ 35 | 36 | ### ------- Training on single fixed-patch size 128x128--------- 37 | # mini_batch_sizes: [8] 38 | # iters: [300000] 39 | # gt_size: 128 40 | # gt_sizes: [128] 41 | ### ------------------------------------------------------------ 42 | 43 | dataset_enlarge_ratio: 1 44 | prefetch_mode: ~ 45 | 46 | val: 47 | name: ValSet 48 | type: Dataset_GaussianDenoising 49 | sigma_test: 25 50 | in_ch: 1 ## Grayscale image 51 | dataroot_gt: ./Denoising/Datasets/test/BSD68 52 | dataroot_lq: none 53 | io_backend: 54 | type: disk 55 | 56 | # network structures 57 | network_g: 58 | type: Restormer 59 | inp_channels: 1 60 | out_channels: 1 61 | dim: 48 62 | num_blocks: [4,6,6,8] 63 | num_refinement_blocks: 4 64 | heads: [1,2,4,8] 65 | ffn_expansion_factor: 2.66 66 | bias: False 67 | LayerNorm_type: BiasFree 68 | dual_pixel_task: False 69 | 70 | 71 | # path 72 | path: 73 | pretrain_network_g: ~ 74 | strict_load_g: true 75 | resume_state: ~ 76 | 77 | # training settings 78 | train: 79 | total_iter: 300000 80 | warmup_iter: -1 # no warm up 81 | use_grad_clip: true 82 | 83 | # Split 300k iterations into two cycles. 84 | # 1st cycle: fixed 3e-4 LR for 92k iters. 85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
86 | scheduler: 87 | type: CosineAnnealingRestartCyclicLR 88 | periods: [92000, 208000] 89 | restart_weights: [1,1] 90 | eta_mins: [0.0003,0.000001] 91 | 92 | mixing_augs: 93 | mixup: true 94 | mixup_beta: 1.2 95 | use_identity: true 96 | 97 | optim_g: 98 | type: AdamW 99 | lr: !!float 3e-4 100 | weight_decay: !!float 1e-4 101 | betas: [0.9, 0.999] 102 | 103 | # losses 104 | pixel_opt: 105 | type: L1Loss 106 | loss_weight: 1 107 | reduction: mean 108 | 109 | # validation settings 110 | val: 111 | window_size: 8 112 | val_freq: !!float 4e3 113 | save_img: false 114 | rgb2bgr: true 115 | use_image: false 116 | max_minibatch: 8 117 | 118 | metrics: 119 | psnr: # metric name, can be arbitrary 120 | type: calculate_psnr 121 | crop_border: 0 122 | test_y_channel: false 123 | 124 | # logging settings 125 | logger: 126 | print_freq: 1000 127 | save_checkpoint_freq: !!float 4e3 128 | use_tb_logger: true 129 | wandb: 130 | project: ~ 131 | resume_id: ~ 132 | 133 | # dist training settings 134 | dist_params: 135 | backend: nccl 136 | port: 29500 137 | -------------------------------------------------------------------------------- /Denoising/Options/GaussianGrayDenoising_RestormerSigma50.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: GaussianGrayDenoising_RestormerSigma50 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_GaussianDenoising 13 | sigma_type: constant 14 | sigma_range: 50 15 | in_ch: 1 ## Grayscale image 16 | dataroot_gt: ./Denoising/Datasets/train/DFWB 17 | dataroot_lq: none 18 | geometric_augs: true 19 | 20 | filename_tmpl: '{}' 21 | io_backend: 22 | type: disk 23 | 24 | # data loader 25 | use_shuffle: true 26 | num_worker_per_gpu: 8 27 | batch_size_per_gpu: 8 28 | 29 | ### -------------Progressive 
training-------------------------- 30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 31 | iters: [92000,64000,48000,36000,36000,24000] 32 | gt_size: 384 # Max patch size for progressive training 33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 34 | ### ------------------------------------------------------------ 35 | 36 | ### ------- Training on single fixed-patch size 128x128--------- 37 | # mini_batch_sizes: [8] 38 | # iters: [300000] 39 | # gt_size: 128 40 | # gt_sizes: [128] 41 | ### ------------------------------------------------------------ 42 | 43 | dataset_enlarge_ratio: 1 44 | prefetch_mode: ~ 45 | 46 | val: 47 | name: ValSet 48 | type: Dataset_GaussianDenoising 49 | sigma_test: 50 50 | in_ch: 1 ## Grayscale image 51 | dataroot_gt: ./Denoising/Datasets/test/BSD68 52 | dataroot_lq: none 53 | io_backend: 54 | type: disk 55 | 56 | # network structures 57 | network_g: 58 | type: Restormer 59 | inp_channels: 1 60 | out_channels: 1 61 | dim: 48 62 | num_blocks: [4,6,6,8] 63 | num_refinement_blocks: 4 64 | heads: [1,2,4,8] 65 | ffn_expansion_factor: 2.66 66 | bias: False 67 | LayerNorm_type: BiasFree 68 | dual_pixel_task: False 69 | 70 | 71 | # path 72 | path: 73 | pretrain_network_g: ~ 74 | strict_load_g: true 75 | resume_state: ~ 76 | 77 | # training settings 78 | train: 79 | total_iter: 300000 80 | warmup_iter: -1 # no warm up 81 | use_grad_clip: true 82 | 83 | # Split 300k iterations into two cycles. 84 | # 1st cycle: fixed 3e-4 LR for 92k iters. 85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
86 | scheduler: 87 | type: CosineAnnealingRestartCyclicLR 88 | periods: [92000, 208000] 89 | restart_weights: [1,1] 90 | eta_mins: [0.0003,0.000001] 91 | 92 | mixing_augs: 93 | mixup: true 94 | mixup_beta: 1.2 95 | use_identity: true 96 | 97 | optim_g: 98 | type: AdamW 99 | lr: !!float 3e-4 100 | weight_decay: !!float 1e-4 101 | betas: [0.9, 0.999] 102 | 103 | # losses 104 | pixel_opt: 105 | type: L1Loss 106 | loss_weight: 1 107 | reduction: mean 108 | 109 | # validation settings 110 | val: 111 | window_size: 8 112 | val_freq: !!float 4e3 113 | save_img: false 114 | rgb2bgr: true 115 | use_image: false 116 | max_minibatch: 8 117 | 118 | metrics: 119 | psnr: # metric name, can be arbitrary 120 | type: calculate_psnr 121 | crop_border: 0 122 | test_y_channel: false 123 | 124 | # logging settings 125 | logger: 126 | print_freq: 1000 127 | save_checkpoint_freq: !!float 4e3 128 | use_tb_logger: true 129 | wandb: 130 | project: ~ 131 | resume_id: ~ 132 | 133 | # dist training settings 134 | dist_params: 135 | backend: nccl 136 | port: 29500 137 | -------------------------------------------------------------------------------- /Denoising/Options/RealDenoising_Restormer.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: RealDenoising_Restormer 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage 13 | dataroot_gt: ./Denoising/Datasets/train/SIDD/target_crops 14 | dataroot_lq: ./Denoising/Datasets/train/SIDD/input_crops 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ### -------------Progressive training-------------------------- 27 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per 
gpu 28 | iters: [92000,64000,48000,36000,36000,24000] 29 | gt_size: 384 # Max patch size for progressive training 30 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 31 | ### ------------------------------------------------------------ 32 | 33 | ### ------- Training on single fixed-patch size 128x128--------- 34 | # mini_batch_sizes: [8] 35 | # iters: [300000] 36 | # gt_size: 128 37 | # gt_sizes: [128] 38 | ### ------------------------------------------------------------ 39 | 40 | dataset_enlarge_ratio: 1 41 | prefetch_mode: ~ 42 | 43 | val: 44 | name: ValSet 45 | type: Dataset_PairedImage 46 | dataroot_gt: ./Denoising/Datasets/val/SIDD/target_crops 47 | dataroot_lq: ./Denoising/Datasets/val/SIDD/input_crops 48 | io_backend: 49 | type: disk 50 | 51 | # network structures 52 | network_g: 53 | type: Restormer 54 | inp_channels: 3 55 | out_channels: 3 56 | dim: 48 57 | num_blocks: [4,6,6,8] 58 | num_refinement_blocks: 4 59 | heads: [1,2,4,8] 60 | ffn_expansion_factor: 2.66 61 | bias: False 62 | LayerNorm_type: BiasFree 63 | dual_pixel_task: False 64 | 65 | 66 | # path 67 | path: 68 | pretrain_network_g: ~ 69 | strict_load_g: true 70 | resume_state: ~ 71 | 72 | # training settings 73 | train: 74 | total_iter: 300000 75 | warmup_iter: -1 # no warm up 76 | use_grad_clip: true 77 | 78 | # Split 300k iterations into two cycles. 79 | # 1st cycle: fixed 3e-4 LR for 92k iters. 80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
81 | scheduler: 82 | type: CosineAnnealingRestartCyclicLR 83 | periods: [92000, 208000] 84 | restart_weights: [1,1] 85 | eta_mins: [0.0003,0.000001] 86 | 87 | mixing_augs: 88 | mixup: true 89 | mixup_beta: 1.2 90 | use_identity: true 91 | 92 | optim_g: 93 | type: AdamW 94 | lr: !!float 3e-4 95 | weight_decay: !!float 1e-4 96 | betas: [0.9, 0.999] 97 | 98 | # losses 99 | pixel_opt: 100 | type: L1Loss 101 | loss_weight: 1 102 | reduction: mean 103 | 104 | # validation settings 105 | val: 106 | window_size: 8 107 | val_freq: !!float 4e3 108 | save_img: false 109 | rgb2bgr: true 110 | use_image: false 111 | max_minibatch: 8 112 | 113 | metrics: 114 | psnr: # metric name, can be arbitrary 115 | type: calculate_psnr 116 | crop_border: 0 117 | test_y_channel: false 118 | 119 | # logging settings 120 | logger: 121 | print_freq: 1000 122 | save_checkpoint_freq: !!float 4e3 123 | use_tb_logger: true 124 | wandb: 125 | project: ~ 126 | resume_id: ~ 127 | 128 | # dist training settings 129 | dist_params: 130 | backend: nccl 131 | port: 29500 132 | -------------------------------------------------------------------------------- /Denoising/README.md: -------------------------------------------------------------------------------- 1 | # Image Denoising 2 | - [Gaussian Image Denoising](#gaussian-image-denoising) 3 | * [Training](#training) 4 | * [Evaluation](#evaluation) 5 | - [Grayscale blind image denoising testing](#grayscale-blind-image-denoising-testing) 6 | - [Grayscale non-blind image denoising testing](#grayscale-non-blind-image-denoising-testing) 7 | - [Color blind image denoising testing](#color-blind-image-denoising-testing) 8 | - [Color non-blind image denoising testing](#color-non-blind-image-denoising-testing) 9 | - [Real Image Denoising](#real-image-denoising) 10 | * [Training](#training-1) 11 | * [Evaluation](#evaluation-1) 12 | - [Testing on SIDD dataset](#testing-on-sidd-dataset) 13 | - [Testing on DND dataset](#testing-on-dnd-dataset) 14 | 15 | # Gaussian Image 
Denoising 16 | 17 | - **Blind Denoising:** One model to handle various noise levels 18 | - **Non-Blind Denoising:** Separate models for each noise level 19 | 20 | ## Training 21 | 22 | - Download training (DIV2K, Flickr2K, WED, BSD) and testing datasets, run 23 | ``` 24 | python download_data.py --data train-test --noise gaussian 25 | ``` 26 | 27 | - Generate image patches from full-resolution training images, run 28 | ``` 29 | python generate_patches_dfwb.py 30 | ``` 31 | 32 | - Train Restormer for **grayscale blind** image denoising, run 33 | ``` 34 | cd Restormer 35 | ./train.sh Denoising/Options/GaussianGrayDenoising_Restormer.yml 36 | ``` 37 | 38 | - Train Restormer for **grayscale non-blind** image denoising, run 39 | ``` 40 | cd Restormer 41 | ./train.sh Denoising/Options/GaussianGrayDenoising_RestormerSigma15.yml 42 | ./train.sh Denoising/Options/GaussianGrayDenoising_RestormerSigma25.yml 43 | ./train.sh Denoising/Options/GaussianGrayDenoising_RestormerSigma50.yml 44 | ``` 45 | 46 | - Train Restormer for **color blind** image denoising, run 47 | ``` 48 | cd Restormer 49 | ./train.sh Denoising/Options/GaussianColorDenoising_Restormer.yml 50 | ``` 51 | 52 | - Train Restormer for **color non-blind** image denoising, run 53 | ``` 54 | cd Restormer 55 | ./train.sh Denoising/Options/GaussianColorDenoising_RestormerSigma15.yml 56 | ./train.sh Denoising/Options/GaussianColorDenoising_RestormerSigma25.yml 57 | ./train.sh Denoising/Options/GaussianColorDenoising_RestormerSigma50.yml 58 | ``` 59 | 60 | **Note:** The above training scripts use 8 GPUs by default. 
To use any other number of GPUs, modify [Restormer/train.sh](../train.sh) and the yaml file corresponding to each task (e.g., [Denoising/Options/GaussianGrayDenoising_Restormer.yml](Options/GaussianGrayDenoising_Restormer.yml)) 61 | 62 | ## Evaluation 63 | 64 | - Download the pre-trained [models](https://drive.google.com/drive/folders/1Qwsjyny54RZWa7zC4Apg7exixLBo4uF0?usp=sharing) and place them in `./pretrained_models/` 65 | 66 | - Download testsets (Set12, BSD68, CBSD68, Kodak, McMaster, Urban100), run 67 | ``` 68 | python download_data.py --data test --noise gaussian 69 | ``` 70 | 71 | #### Grayscale blind image denoising testing 72 | 73 | - To obtain denoised predictions, run 74 | ``` 75 | python test_gaussian_gray_denoising.py --model_type blind --sigmas 15,25,50 76 | ``` 77 | 78 | - To reproduce PSNR Table 4 (top super-row), run 79 | ``` 80 | python evaluate_gaussian_gray_denoising.py --model_type blind --sigmas 15,25,50 81 | ``` 82 | 83 | #### Grayscale non-blind image denoising testing 84 | 85 | - To obtain denoised predictions, run 86 | ``` 87 | python test_gaussian_gray_denoising.py --model_type non_blind --sigmas 15,25,50 88 | ``` 89 | 90 | - To reproduce PSNR Table 4 (bottom super-row), run 91 | ``` 92 | python evaluate_gaussian_gray_denoising.py --model_type non_blind --sigmas 15,25,50 93 | ``` 94 | 95 | #### Color blind image denoising testing 96 | 97 | - To obtain denoised predictions, run 98 | ``` 99 | python test_gaussian_color_denoising.py --model_type blind --sigmas 15,25,50 100 | ``` 101 | 102 | - To reproduce PSNR Table 5 (top super-row), run 103 | ``` 104 | python evaluate_gaussian_color_denoising.py --model_type blind --sigmas 15,25,50 105 | ``` 106 | 107 | #### Color non-blind image denoising testing 108 | 109 | - To obtain denoised predictions, run 110 | ``` 111 | python test_gaussian_color_denoising.py --model_type non_blind --sigmas 15,25,50 112 | ``` 113 | 114 | - To reproduce PSNR Table 5 (bottom super-row), run 115 | ``` 116 | python 
evaluate_gaussian_color_denoising.py --model_type non_blind --sigmas 15,25,50 117 | ``` 118 | 119 |
120 | 121 | # Real Image Denoising 122 | 123 | ## Training 124 | 125 | - Download SIDD training data, run 126 | ``` 127 | python download_data.py --data train --noise real 128 | ``` 129 | 130 | - Generate image patches from full-resolution training images, run 131 | ``` 132 | python generate_patches_sidd.py 133 | ``` 134 | 135 | - Train Restormer for real image denoising, run 136 | ``` 137 | cd Restormer 138 | ./train.sh Denoising/Options/RealDenoising_Restormer.yml 139 | ``` 140 | 141 | **Note:** This training script uses 8 GPUs by default. To use any other number of GPUs, modify [Restormer/train.sh](../train.sh) and [Denoising/Options/RealDenoising_Restormer.yml](Options/RealDenoising_Restormer.yml). 142 | 143 | ## Evaluation 144 | 145 | - Download the pre-trained [model](https://drive.google.com/file/d/1FF_4NTboTWQ7sHCq4xhyLZsSl0U0JfjH/view?usp=sharing) and place it in `./pretrained_models/` 146 | 147 | #### Testing on SIDD dataset 148 | 149 | - Download SIDD validation data, run 150 | ``` 151 | python download_data.py --noise real --data test --dataset SIDD 152 | ``` 153 | 154 | - To obtain denoised results, run 155 | ``` 156 | python test_real_denoising_sidd.py --save_images 157 | ``` 158 | 159 | - To reproduce PSNR/SSIM scores on SIDD data (Table 6), run the following script in MATLAB 160 | ``` 161 | evaluate_sidd.m 162 | ``` 163 | 164 | #### Testing on DND dataset 165 | 166 | - Download the DND benchmark data, run 167 | ``` 168 | python download_data.py --noise real --data test --dataset DND 169 | ``` 170 | 171 | - To obtain denoised results, run 172 | ``` 173 | python test_real_denoising_dnd.py --save_images 174 | ``` 175 | 176 | - To reproduce PSNR/SSIM scores (Table 6), upload the results to the DND benchmark website. 
177 | -------------------------------------------------------------------------------- /Denoising/download_data.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | ## Download training and testing data for Image Denoising task 6 | 7 | 8 | import os 9 | # import gdown 10 | import shutil 11 | 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--data', type=str, required=True, help='train, test or train-test') 16 | parser.add_argument('--dataset', type=str, default='SIDD', help='all or SIDD or DND') 17 | parser.add_argument('--noise', type=str, required=True, help='real or gaussian') 18 | args = parser.parse_args() 19 | 20 | ### Google drive IDs ###### 21 | SIDD_train = '1UHjWZzLPGweA9ZczmV8lFSRcIxqiOVJw' ## https://drive.google.com/file/d/1UHjWZzLPGweA9ZczmV8lFSRcIxqiOVJw/view?usp=sharing 22 | SIDD_val = '1Fw6Ey1R-nCHN9WEpxv0MnMqxij-ECQYJ' ## https://drive.google.com/file/d/1Fw6Ey1R-nCHN9WEpxv0MnMqxij-ECQYJ/view?usp=sharing 23 | SIDD_test = '11vfqV-lqousZTuAit1Qkqghiv_taY0KZ' ## https://drive.google.com/file/d/11vfqV-lqousZTuAit1Qkqghiv_taY0KZ/view?usp=sharing 24 | DND_test = '1CYCDhaVxYYcXhSfEVDUwkvJDtGxeQ10G' ## https://drive.google.com/file/d/1CYCDhaVxYYcXhSfEVDUwkvJDtGxeQ10G/view?usp=sharing 25 | 26 | BSD400 = '1idKFDkAHJGAFDn1OyXZxsTbOSBx9GS8N' ## https://drive.google.com/file/d/1idKFDkAHJGAFDn1OyXZxsTbOSBx9GS8N/view?usp=sharing 27 | DIV2K = '13wLWWXvFkuYYVZMMAYiMVdSA7iVEf2fM' ## https://drive.google.com/file/d/13wLWWXvFkuYYVZMMAYiMVdSA7iVEf2fM/view?usp=sharing 28 | Flickr2K = '1J8xjFCrVzeYccD-LF08H7HiIsmi8l2Wn' ## https://drive.google.com/file/d/1J8xjFCrVzeYccD-LF08H7HiIsmi8l2Wn/view?usp=sharing 29 | WaterlooED = '19_mCE_GXfmE5yYsm-HEzuZQqmwMjPpJr' ## 
https://drive.google.com/file/d/19_mCE_GXfmE5yYsm-HEzuZQqmwMjPpJr/view?usp=sharing 30 | gaussian_test = '1mwMLt-niNqcQpfN_ZduG9j4k6P_ZkOl0' ## https://drive.google.com/file/d/1mwMLt-niNqcQpfN_ZduG9j4k6P_ZkOl0/view?usp=sharing 31 | 32 | 33 | noise = args.noise 34 | 35 | for data in args.data.split('-'): 36 | if noise == 'real': 37 | if data == 'train': 38 | print('SIDD Training Data!') 39 | os.makedirs(os.path.join('Datasets', 'Downloads'), exist_ok=True) 40 | # gdown.download(id=SIDD_train, output='Datasets/Downloads/train.zip', quiet=False) 41 | os.system(f'gdrive download {SIDD_train} --path Datasets/Downloads/') 42 | print('Extracting SIDD Data...') 43 | shutil.unpack_archive('Datasets/Downloads/train.zip', 'Datasets/Downloads') 44 | os.rename(os.path.join('Datasets', 'Downloads', 'train'), os.path.join('Datasets', 'Downloads', 'SIDD')) 45 | os.remove('Datasets/Downloads/train.zip') 46 | 47 | print('SIDD Validation Data!') 48 | # gdown.download(id=SIDD_val, output='Datasets/val.zip', quiet=False) 49 | os.system(f'gdrive download {SIDD_val} --path Datasets/') 50 | print('Extracting SIDD Data...') 51 | shutil.unpack_archive('Datasets/val.zip', 'Datasets') 52 | os.remove('Datasets/val.zip') 53 | 54 | if data == 'test': 55 | if args.dataset == 'all' or args.dataset == 'SIDD': 56 | print('SIDD Testing Data!') 57 | # gdown.download(id=SIDD_test, output='Datasets/test.zip', quiet=False) 58 | os.system(f'gdrive download {SIDD_test} --path Datasets/') 59 | print('Extracting SIDD Data...') 60 | shutil.unpack_archive('Datasets/test.zip', 'Datasets') 61 | os.remove('Datasets/test.zip') 62 | 63 | if args.dataset == 'all' or args.dataset == 'DND': 64 | print('DND Testing Data!') 65 | # gdown.download(id=DND_test, output='Datasets/test.zip', quiet=False) 66 | os.system(f'gdrive download {DND_test} --path Datasets/') 67 | print('Extracting DND data...') 68 | shutil.unpack_archive('Datasets/test.zip', 'Datasets') 69 | os.remove('Datasets/test.zip') 70 | 71 | if noise == 
'gaussian': 72 | if data == 'train': 73 | os.makedirs(os.path.join('Datasets', 'Downloads'), exist_ok=True) 74 | print('WaterlooED Training Data!') 75 | # gdown.download(id=WaterlooED, output='Datasets/Downloads/WaterlooED.zip', quiet=False) 76 | os.system(f'gdrive download {WaterlooED} --path Datasets/Downloads/') 77 | print('Extracting WaterlooED Data...') 78 | shutil.unpack_archive('Datasets/Downloads/WaterlooED.zip', 'Datasets/Downloads') 79 | os.remove('Datasets/Downloads/WaterlooED.zip') 80 | 81 | print('DIV2K Training Data!') 82 | # gdown.download(id=DIV2K, output='Datasets/Downloads/DIV2K.zip', quiet=False) 83 | os.system(f'gdrive download {DIV2K} --path Datasets/Downloads/') 84 | print('Extracting DIV2K Data...') 85 | shutil.unpack_archive('Datasets/Downloads/DIV2K.zip', 'Datasets/Downloads') 86 | os.remove('Datasets/Downloads/DIV2K.zip') 87 | 88 | 89 | print('BSD400 Training Data!') 90 | # gdown.download(id=BSD400, output='Datasets/Downloads/BSD400.zip', quiet=False) 91 | os.system(f'gdrive download {BSD400} --path Datasets/Downloads/') 92 | print('Extracting BSD400 data...') 93 | shutil.unpack_archive('Datasets/Downloads/BSD400.zip', 'Datasets/Downloads') 94 | os.remove('Datasets/Downloads/BSD400.zip') 95 | 96 | print('Flickr2K Training Data!') 97 | # gdown.download(id=Flickr2K, output='Datasets/Downloads/Flickr2K.zip', quiet=False) 98 | os.system(f'gdrive download {Flickr2K} --path Datasets/Downloads/') 99 | print('Extracting Flickr2K data...') 100 | shutil.unpack_archive('Datasets/Downloads/Flickr2K.zip', 'Datasets/Downloads') 101 | os.remove('Datasets/Downloads/Flickr2K.zip') 102 | 103 | if data == 'test': 104 | print('Gaussian Denoising Testing Data!') 105 | # gdown.download(id=gaussian_test, output='Datasets/test.zip', quiet=False) 106 | os.system(f'gdrive download {gaussian_test} --path Datasets/') 107 | print('Extracting Data...') 108 | shutil.unpack_archive('Datasets/test.zip', 'Datasets') 109 | os.remove('Datasets/test.zip') 110 | 111 | # 
print('Download completed successfully!') 112 | -------------------------------------------------------------------------------- /Denoising/evaluate_gaussian_color_denoising.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import os 6 | import numpy as np 7 | from glob import glob 8 | from natsort import natsorted 9 | from skimage import io 10 | import cv2 11 | import argparse 12 | from skimage.metrics import structural_similarity 13 | from tqdm import tqdm 14 | import concurrent.futures 15 | import utils 16 | 17 | def proc(filename): 18 | tar,prd = filename 19 | tar_img = utils.load_img(tar) 20 | prd_img = utils.load_img(prd) 21 | 22 | PSNR = utils.calculate_psnr(tar_img, prd_img) 23 | # SSIM = utils.calculate_ssim(tar_img, prd_img) 24 | return PSNR 25 | 26 | parser = argparse.ArgumentParser(description='Gasussian Color Denoising using Restormer') 27 | 28 | parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. 
non_blind: separate model for each noise level.') 29 | parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values') 30 | 31 | args = parser.parse_args() 32 | 33 | sigmas = np.int_(args.sigmas.split(',')) 34 | 35 | datasets = ['CBSD68', 'Kodak', 'McMaster','Urban100'] 36 | 37 | for dataset in datasets: 38 | 39 | gt_path = os.path.join('Datasets','test', dataset) 40 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif'))) 41 | assert len(gt_list) != 0, "Target files not found" 42 | 43 | for sigma_test in sigmas: 44 | file_path = os.path.join('results', 'Gaussian_Color_Denoising', args.model_type, dataset, str(sigma_test)) 45 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif'))) 46 | assert len(path_list) != 0, "Predicted files not found" 47 | 48 | psnr, ssim = [], [] 49 | img_files =[(i, j) for i,j in zip(gt_list,path_list)] 50 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: 51 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)): 52 | psnr.append(PSNR_SSIM) 53 | # ssim.append(PSNR_SSIM[1]) 54 | 55 | avg_psnr = sum(psnr)/len(psnr) 56 | # avg_ssim = sum(ssim)/len(ssim) 57 | 58 | print('For {:s} dataset Noise Level {:d} PSNR: {:f}\n'.format(dataset, sigma_test, avg_psnr)) 59 | # print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim)) 60 | -------------------------------------------------------------------------------- /Denoising/evaluate_gaussian_gray_denoising.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import os 6 | import numpy as np 7 | from glob import glob 8 | from natsort import natsorted 9 | from 
skimage import io 10 | import cv2 11 | import argparse 12 | from skimage.metrics import structural_similarity 13 | from tqdm import tqdm 14 | import concurrent.futures 15 | import utils 16 | 17 | def proc(filename): 18 | tar,prd = filename 19 | tar_img = utils.load_gray_img(tar) 20 | prd_img = utils.load_gray_img(prd) 21 | 22 | PSNR = utils.calculate_psnr(tar_img, prd_img) 23 | # SSIM = utils.calculate_ssim(tar_img, prd_img) 24 | return PSNR 25 | 26 | parser = argparse.ArgumentParser(description='Gasussian Grayscale Denoising using Restormer') 27 | 28 | parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. non_blind: separate model for each noise level.') 29 | parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values') 30 | 31 | args = parser.parse_args() 32 | 33 | sigmas = np.int_(args.sigmas.split(',')) 34 | 35 | datasets = ['Set12', 'BSD68', 'Urban100'] 36 | 37 | for dataset in datasets: 38 | 39 | gt_path = os.path.join('Datasets','test', dataset) 40 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif'))) 41 | assert len(gt_list) != 0, "Target files not found" 42 | 43 | for sigma_test in sigmas: 44 | file_path = os.path.join('results', 'Gaussian_Gray_Denoising', args.model_type, dataset, str(sigma_test)) 45 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif'))) 46 | assert len(path_list) != 0, "Predicted files not found" 47 | 48 | psnr, ssim = [], [] 49 | img_files =[(i, j) for i,j in zip(gt_list,path_list)] 50 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: 51 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)): 52 | psnr.append(PSNR_SSIM) 53 | # ssim.append(PSNR_SSIM[1]) 54 | 55 | avg_psnr = sum(psnr)/len(psnr) 56 | # avg_ssim = sum(ssim)/len(ssim) 57 | 58 | print('For {:s} dataset Noise Level {:d} 
PSNR: {:f}\n'.format(dataset, sigma_test, avg_psnr)) 59 | # print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim)) 60 | -------------------------------------------------------------------------------- /Denoising/evaluate_sidd.m: -------------------------------------------------------------------------------- 1 | close all;clear all; 2 | 3 | denoised = load('./results/Real_Denoising/SIDD/mat/Idenoised.mat'); 4 | gt = load('./Datasets/test/SIDD/ValidationGtBlocksSrgb.mat'); 5 | 6 | denoised = denoised.Idenoised; 7 | gt = gt.ValidationGtBlocksSrgb; 8 | gt = im2single(gt); 9 | 10 | total_psnr = 0; 11 | total_ssim = 0; 12 | for i = 1:40 13 | for k = 1:32 14 | denoised_patch = squeeze(denoised(i,k,:,:,:)); 15 | gt_patch = squeeze(gt(i,k,:,:,:)); 16 | ssim_val = ssim(denoised_patch, gt_patch); 17 | psnr_val = psnr(denoised_patch, gt_patch); 18 | total_ssim = total_ssim + ssim_val; 19 | total_psnr = total_psnr + psnr_val; 20 | end 21 | end 22 | qm_psnr = total_psnr / (40*32); 23 | qm_ssim = total_ssim / (40*32); 24 | 25 | fprintf('PSNR: %f SSIM: %f\n', qm_psnr, qm_ssim); 26 | 27 | -------------------------------------------------------------------------------- /Denoising/generate_patches_dfwb.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from glob import glob 5 | from natsort import natsorted 6 | import os 7 | from tqdm import tqdm 8 | from pdb import set_trace as stx 9 | 10 | src = 'Datasets/Downloads' 11 | tar = 'Datasets/train/DFWB' 12 | os.makedirs(tar, exist_ok=True) 13 | 14 | patch_size = 512 15 | overlap = 96 16 | p_max = 800 17 | 18 | 19 | def save_files(file_): 20 | path_contents = file_.split(os.sep) 21 | foldname = path_contents[-2] 22 | filename = os.path.splitext(path_contents[-1])[0] 23 | img = cv2.imread(file_) 24 | num_patch = 0 25 | w, h = img.shape[:2] 26 | if w > p_max and h > p_max: 27 | w1 = list(np.arange(0, w-patch_size, 
patch_size-overlap, dtype=np.int)) 28 | h1 = list(np.arange(0, h-patch_size, patch_size-overlap, dtype=np.int)) 29 | w1.append(w-patch_size) 30 | h1.append(h-patch_size) 31 | for i in w1: 32 | for j in h1: 33 | num_patch += 1 34 | patch = img[i:i+patch_size, j:j+patch_size,:] 35 | savename = os.path.join(tar, foldname + '-' + filename + '-' + str(num_patch) + '.png') 36 | cv2.imwrite(savename, patch) 37 | 38 | else: 39 | savename = os.path.join(tar, foldname + '-' + filename + '.png') 40 | cv2.imwrite(savename, img) 41 | 42 | 43 | files = [] 44 | for dataset in ['DIV2K', 'Flickr2K', 'WaterlooED', 'BSD400']: 45 | df = natsorted(glob(os.path.join(src, dataset, '*.png')) + glob(os.path.join(src, dataset, '*.jpg')) + glob(os.path.join(src, dataset, '*.bmp'))) 46 | files.extend(df) 47 | 48 | from joblib import Parallel, delayed 49 | import multiprocessing 50 | num_cores = 10 51 | Parallel(n_jobs=num_cores)(delayed(save_files)(file_) for file_ in tqdm(files)) 52 | -------------------------------------------------------------------------------- /Denoising/generate_patches_sidd.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from glob import glob 5 | from natsort import natsorted 6 | import os 7 | from tqdm import tqdm 8 | from pdb import set_trace as stx 9 | 10 | 11 | src = 'Datasets/Downloads/SIDD' 12 | tar = 'Datasets/train/SIDD' 13 | 14 | lr_tar = os.path.join(tar, 'input_crops') 15 | hr_tar = os.path.join(tar, 'target_crops') 16 | 17 | os.makedirs(lr_tar, exist_ok=True) 18 | os.makedirs(hr_tar, exist_ok=True) 19 | 20 | files = natsorted(glob(os.path.join(src, '*', '*.PNG'))) 21 | 22 | lr_files, hr_files = [], [] 23 | for file_ in files: 24 | filename = os.path.split(file_)[-1] 25 | if 'GT' in filename: 26 | hr_files.append(file_) 27 | if 'NOISY' in filename: 28 | lr_files.append(file_) 29 | 30 | files = [(i, j) for i, j in zip(lr_files, hr_files)] 31 | 32 | patch_size = 512 
33 | overlap = 128 34 | p_max = 0 35 | 36 | def save_files(file_): 37 | lr_file, hr_file = file_ 38 | filename = os.path.splitext(os.path.split(lr_file)[-1])[0] 39 | lr_img = cv2.imread(lr_file) 40 | hr_img = cv2.imread(hr_file) 41 | num_patch = 0 42 | w, h = lr_img.shape[:2] 43 | if w > p_max and h > p_max: 44 | w1 = list(np.arange(0, w-patch_size, patch_size-overlap, dtype=np.int)) 45 | h1 = list(np.arange(0, h-patch_size, patch_size-overlap, dtype=np.int)) 46 | w1.append(w-patch_size) 47 | h1.append(h-patch_size) 48 | for i in w1: 49 | for j in h1: 50 | num_patch += 1 51 | 52 | lr_patch = lr_img[i:i+patch_size, j:j+patch_size,:] 53 | hr_patch = hr_img[i:i+patch_size, j:j+patch_size,:] 54 | 55 | lr_savename = os.path.join(lr_tar, filename + '-' + str(num_patch) + '.png') 56 | hr_savename = os.path.join(hr_tar, filename + '-' + str(num_patch) + '.png') 57 | 58 | cv2.imwrite(lr_savename, lr_patch) 59 | cv2.imwrite(hr_savename, hr_patch) 60 | 61 | else: 62 | lr_savename = os.path.join(lr_tar, filename + '.png') 63 | hr_savename = os.path.join(hr_tar, filename + '.png') 64 | 65 | cv2.imwrite(lr_savename, lr_img) 66 | cv2.imwrite(hr_savename, hr_img) 67 | 68 | from joblib import Parallel, delayed 69 | import multiprocessing 70 | num_cores = 10 71 | Parallel(n_jobs=num_cores)(delayed(save_files)(file_) for file_ in tqdm(files)) 72 | -------------------------------------------------------------------------------- /Denoising/pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | pre-trained denoising model is available [here](https://drive.google.com/drive/folders/1Qwsjyny54RZWa7zC4Apg7exixLBo4uF0?usp=sharing) -------------------------------------------------------------------------------- /Denoising/test_gaussian_color_denoising.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas 
Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | import torch.nn as nn 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | from basicsr.models.archs.restormer_arch import Restormer 15 | from skimage import img_as_ubyte 16 | from natsort import natsorted 17 | from glob import glob 18 | import utils 19 | from pdb import set_trace as stx 20 | 21 | parser = argparse.ArgumentParser(description='Gaussian Color Denoising using Restormer') 22 | 23 | parser.add_argument('--input_dir', default='./Datasets/test/', type=str, help='Directory of validation images') 24 | parser.add_argument('--result_dir', default='./results/Gaussian_Color_Denoising/', type=str, help='Directory for results') 25 | parser.add_argument('--weights', default='./pretrained_models/gaussian_color_denoising', type=str, help='Path to weights') 26 | parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. 
non_blind: separate model for each noise level.') 27 | parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values') 28 | 29 | args = parser.parse_args() 30 | 31 | ####### Load yaml ####### 32 | if args.model_type == 'blind': 33 | yaml_file = 'Options/GaussianColorDenoising_Restormer.yml' 34 | else: 35 | yaml_file = f'Options/GaussianColorDenoising_RestormerSigma{args.sigmas}.yml' 36 | import yaml 37 | 38 | try: 39 | from yaml import CLoader as Loader 40 | except ImportError: 41 | from yaml import Loader 42 | 43 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader) 44 | 45 | s = x['network_g'].pop('type') 46 | ########################## 47 | 48 | sigmas = np.int_(args.sigmas.split(',')) 49 | 50 | factor = 8 51 | 52 | datasets = ['CBSD68', 'Kodak', 'McMaster','Urban100'] 53 | 54 | for sigma_test in sigmas: 55 | print("Compute results for noise level",sigma_test) 56 | model_restoration = Restormer(**x['network_g']) 57 | if args.model_type == 'blind': 58 | weights = args.weights+'_blind.pth' 59 | else: 60 | weights = args.weights + '_sigma' + str(sigma_test) +'.pth' 61 | checkpoint = torch.load(weights) 62 | model_restoration.load_state_dict(checkpoint['params']) 63 | 64 | print("===>Testing using weights: ",weights) 65 | print("------------------------------------------------") 66 | model_restoration.cuda() 67 | model_restoration = nn.DataParallel(model_restoration) 68 | model_restoration.eval() 69 | 70 | for dataset in datasets: 71 | inp_dir = os.path.join(args.input_dir, dataset) 72 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.tif'))) 73 | result_dir_tmp = os.path.join(args.result_dir, args.model_type, dataset, str(sigma_test)) 74 | os.makedirs(result_dir_tmp, exist_ok=True) 75 | 76 | with torch.no_grad(): 77 | for file_ in tqdm(files): 78 | torch.cuda.ipc_collect() 79 | torch.cuda.empty_cache() 80 | img = np.float32(utils.load_img(file_))/255. 
81 | 82 | np.random.seed(seed=0) # for reproducibility 83 | img += np.random.normal(0, sigma_test/255., img.shape) 84 | 85 | img = torch.from_numpy(img).permute(2,0,1) 86 | input_ = img.unsqueeze(0).cuda() 87 | 88 | # Padding in case images are not multiples of 8 89 | h,w = input_.shape[2], input_.shape[3] 90 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor 91 | padh = H-h if h%factor!=0 else 0 92 | padw = W-w if w%factor!=0 else 0 93 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 94 | 95 | restored = model_restoration(input_) 96 | 97 | # Unpad images to original dimensions 98 | restored = restored[:,:,:h,:w] 99 | 100 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 101 | 102 | save_file = os.path.join(result_dir_tmp, os.path.split(file_)[-1]) 103 | utils.save_img(save_file, img_as_ubyte(restored)) 104 | -------------------------------------------------------------------------------- /Denoising/test_gaussian_gray_denoising.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | import torch.nn as nn 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | from basicsr.models.archs.restormer_arch import Restormer 15 | from skimage import img_as_ubyte 16 | from natsort import natsorted 17 | from glob import glob 18 | import utils 19 | from pdb import set_trace as stx 20 | 21 | parser = argparse.ArgumentParser(description='Gasussian Grayscale Denoising using Restormer') 22 | 23 | parser.add_argument('--input_dir', default='./Datasets/test/', type=str, help='Directory of validation images') 24 | parser.add_argument('--result_dir', 
default='./results/Gaussian_Gray_Denoising/', type=str, help='Directory for results') 25 | parser.add_argument('--weights', default='./pretrained_models/gaussian_gray_denoising', type=str, help='Path to weights') 26 | parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. non_blind: separate model for each noise level.') 27 | parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values') 28 | 29 | args = parser.parse_args() 30 | 31 | ####### Load yaml ####### 32 | if args.model_type == 'blind': 33 | yaml_file = 'Options/GaussianGrayDenoising_Restormer.yml' 34 | else: 35 | yaml_file = f'Options/GaussianGrayDenoising_RestormerSigma{args.sigmas}.yml' 36 | import yaml 37 | 38 | try: 39 | from yaml import CLoader as Loader 40 | except ImportError: 41 | from yaml import Loader 42 | 43 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader) 44 | 45 | s = x['network_g'].pop('type') 46 | ########################## 47 | 48 | sigmas = np.int_(args.sigmas.split(',')) 49 | 50 | factor = 8 51 | 52 | datasets = ['Set12', 'BSD68', 'Urban100'] 53 | 54 | for sigma_test in sigmas: 55 | print("Compute results for noise level",sigma_test) 56 | model_restoration = Restormer(**x['network_g']) 57 | if args.model_type == 'blind': 58 | weights = args.weights+'_blind.pth' 59 | else: 60 | weights = args.weights + '_sigma' + str(sigma_test) +'.pth' 61 | checkpoint = torch.load(weights) 62 | model_restoration.load_state_dict(checkpoint['params']) 63 | 64 | print("===>Testing using weights: ",weights) 65 | print("------------------------------------------------") 66 | model_restoration.cuda() 67 | model_restoration = nn.DataParallel(model_restoration) 68 | model_restoration.eval() 69 | 70 | for dataset in datasets: 71 | inp_dir = os.path.join(args.input_dir, dataset) 72 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.tif'))) 73 | 
result_dir_tmp = os.path.join(args.result_dir, args.model_type, dataset, str(sigma_test)) 74 | os.makedirs(result_dir_tmp, exist_ok=True) 75 | 76 | with torch.no_grad(): 77 | for file_ in tqdm(files): 78 | torch.cuda.ipc_collect() 79 | torch.cuda.empty_cache() 80 | img = np.float32(utils.load_gray_img(file_))/255. 81 | 82 | np.random.seed(seed=0) # for reproducibility 83 | img += np.random.normal(0, sigma_test/255., img.shape) 84 | 85 | img = torch.from_numpy(img).permute(2,0,1) 86 | input_ = img.unsqueeze(0).cuda() 87 | 88 | # Padding in case images are not multiples of 8 89 | h,w = input_.shape[2], input_.shape[3] 90 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor 91 | padh = H-h if h%factor!=0 else 0 92 | padw = W-w if w%factor!=0 else 0 93 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 94 | 95 | restored = model_restoration(input_) 96 | 97 | # Unpad images to original dimensions 98 | restored = restored[:,:,:h,:w] 99 | 100 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 101 | 102 | save_file = os.path.join(result_dir_tmp, os.path.split(file_)[-1]) 103 | utils.save_gray_img(save_file, img_as_ubyte(restored)) 104 | -------------------------------------------------------------------------------- /Denoising/test_real_denoising_dnd.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | import numpy as np 6 | import os 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import utils 14 | 15 | from basicsr.models.archs.restormer_arch import Restormer 16 | from skimage import img_as_ubyte 17 | import h5py 18 | import scipy.io as sio 19 | from pdb import 
# Evaluation script: denoise the DND benchmark crops with a pre-trained
# Restormer and write one .mat file per image (plus optional PNG crops).
parser = argparse.ArgumentParser(description='Real Image Denoising using Restormer')

parser.add_argument('--input_dir', default='./Datasets/test/DND/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/Real_Denoising/DND/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/real_denoising.pth', type=str, help='Path to weights')
parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')

args = parser.parse_args()

####### Load yaml #######
yaml_file = 'Options/RealDenoising_Restormer.yml'
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

# Parse the option file inside a context manager so the handle is closed.
with open(yaml_file, mode='r') as f:
    x = yaml.load(f, Loader=Loader)

# Drop the 'type' entry so the remaining keys can be forwarded as
# Restormer keyword arguments.
s = x['network_g'].pop('type')
##########################

result_dir_mat = os.path.join(args.result_dir, 'mat')
os.makedirs(result_dir_mat, exist_ok=True)

if args.save_images:
    result_dir_png = os.path.join(args.result_dir, 'png')
    os.makedirs(result_dir_png, exist_ok=True)

model_restoration = Restormer(**x['network_g'])

checkpoint = torch.load(args.weights)
model_restoration.load_state_dict(checkpoint['params'])
print("===>Testing using weights: ",args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()

israw = False
eval_version="1.0"

# Load info (kept open for the whole run because `info[ref]` is dereferenced
# lazily inside the loop).
with h5py.File(os.path.join(args.input_dir, 'info.mat'), 'r') as infos:
    info = infos['info']
    bb = info['boundingboxes']

    # Process data
    with torch.no_grad():
        for i in tqdm(range(50)):
            # `np.object` was removed in NumPy 1.24; the builtin `object`
            # is the equivalent dtype on every NumPy version.
            Idenoised = np.zeros((20,), dtype=object)
            filename = '%04d.mat'%(i+1)
            filepath = os.path.join(args.input_dir, 'images_srgb', filename)
            with h5py.File(filepath, 'r') as img:
                Inoisy = np.float32(np.array(img['InoisySRGB']).T)

            # bounding box
            ref = bb[0][i]
            boxes = np.array(info[ref]).T

            for k in range(20):
                # NOTE(review): the "-1" on the start coordinates looks like a
                # 1-based -> 0-based conversion of the stored boxes - confirm.
                idx = [int(boxes[k,0]-1),int(boxes[k,2]),int(boxes[k,1]-1),int(boxes[k,3])]
                noisy_patch = torch.from_numpy(Inoisy[idx[0]:idx[1],idx[2]:idx[3],:]).unsqueeze(0).permute(0,3,1,2).cuda()
                restored_patch = model_restoration(noisy_patch)
                restored_patch = torch.clamp(restored_patch,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
                Idenoised[k] = restored_patch

                if args.save_images:
                    save_file = os.path.join(result_dir_png, '%04d_%02d.png'%(i+1,k+1))
                    denoised_img = img_as_ubyte(restored_patch)
                    utils.save_img(save_file, denoised_img)

            # save denoised data
            sio.savemat(os.path.join(result_dir_mat, filename),
                        {"Idenoised": Idenoised,
                         "israw": israw,
                         "eval_version": eval_version},
                        )
# Evaluation script: denoise the SIDD validation blocks with a pre-trained
# Restormer and store all restored blocks in a single .mat file.
parser = argparse.ArgumentParser(description='Real Image Denoising using Restormer')

parser.add_argument('--input_dir', default='./Datasets/test/SIDD/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/Real_Denoising/SIDD/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/real_denoising.pth', type=str, help='Path to weights')
parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')

args = parser.parse_args()

####### Load yaml #######
yaml_file = 'Options/RealDenoising_Restormer.yml'
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

# Parse the option file inside a context manager so the handle is closed.
with open(yaml_file, mode='r') as f:
    x = yaml.load(f, Loader=Loader)

# Drop the 'type' entry so the remaining keys can be forwarded as
# Restormer keyword arguments.
s = x['network_g'].pop('type')
##########################

result_dir_mat = os.path.join(args.result_dir, 'mat')
os.makedirs(result_dir_mat, exist_ok=True)

if args.save_images:
    result_dir_png = os.path.join(args.result_dir, 'png')
    os.makedirs(result_dir_png, exist_ok=True)

model_restoration = Restormer(**x['network_g'])

checkpoint = torch.load(args.weights)
model_restoration.load_state_dict(checkpoint['params'])
print("===>Testing using weights: ",args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()

# Process data
filepath = os.path.join(args.input_dir, 'ValidationNoisyBlocksSrgb.mat')
img = sio.loadmat(filepath)
Inoisy = np.float32(np.array(img['ValidationNoisyBlocksSrgb']))
Inoisy /=255.
restored = np.zeros_like(Inoisy)

# Take the loop bounds from the data instead of hard-coding 40 x 32; the
# array is indexed as [image, block, :, :, :] below.
num_images, num_blocks = Inoisy.shape[:2]

with torch.no_grad():
    for i in tqdm(range(num_images)):
        for k in range(num_blocks):
            noisy_patch = torch.from_numpy(Inoisy[i,k,:,:,:]).unsqueeze(0).permute(0,3,1,2).cuda()
            restored_patch = model_restoration(noisy_patch)
            # Convert to a NumPy array explicitly (as the DND script does)
            # rather than relying on implicit tensor -> ndarray coercion.
            restored_patch = torch.clamp(restored_patch,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
            restored[i,k,:,:,:] = restored_patch

            if args.save_images:
                save_file = os.path.join(result_dir_png, '%04d_%02d.png'%(i+1,k+1))
                utils.save_img(save_file, img_as_ubyte(restored_patch))

# save denoised data
sio.savemat(os.path.join(result_dir_mat, 'Idenoised.mat'), {"Idenoised": restored,})
def calculate_psnr(img1, img2, border=0):
    """Return the PSNR (in dB) between two images with range [0, 255].

    Args:
        img1, img2: Arrays of identical shape; the first two axes are
            treated as height and width.
        border: Number of pixels cropped from every side before comparing.

    Returns:
        The PSNR as a float, or ``inf`` when the images are identical.

    Raises:
        ValueError: If the two images do not have the same shape.
    """
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    """Return the SSIM between two images with range [0, 255].

    2-D inputs and ``(H, W, 1)`` inputs are scored as a single channel;
    ``(H, W, 3)`` inputs are scored per channel and averaged.

    Args:
        img1, img2: Arrays of identical shape.
        border: Number of pixels cropped from every side before comparing.

    Raises:
        ValueError: If the shapes differ, or if the input is neither 2-D
            nor 3-D with 1 or 3 channels.
    """
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            return np.array([ssim(img1[:,:,i], img2[:,:,i]) for i in range(3)]).mean()
        if img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    # Previously a 3-D input with an unsupported channel count fell through
    # and returned None; every unsupported layout now fails loudly.
    raise ValueError('Wrong input image dimensions.')


def ssim(img1, img2):
    """Return the mean SSIM of two single-channel images (range [0, 255]).

    Uses an 11x11 Gaussian window with sigma 1.5.
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    # The [5:-5, 5:-5] crop drops the half-window border of the filtered
    # image, keeping only the "valid" region.
    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def load_img(filepath):
    """Read an image from disk and return it in RGB channel order.

    Raises:
        IOError: If OpenCV cannot read the file (``cv2.imread`` returns None).
    """
    img = cv2.imread(filepath)
    if img is None:
        raise IOError(f'Failed to read image: {filepath}')
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def save_img(filepath, img):
    """Write an RGB image to disk (converted to OpenCV's BGR order)."""
    cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


def load_gray_img(filepath):
    """Read an image as grayscale and return it with a trailing channel axis.

    Raises:
        IOError: If OpenCV cannot read the file (``cv2.imread`` returns None).
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise IOError(f'Failed to read image: {filepath}')
    return np.expand_dims(img, axis=2)


def save_gray_img(filepath, img):
    """Write a grayscale image to disk as-is."""
    cv2.imwrite(filepath, img)
-------------------------------------------------------------------------------- /Deraining/Datasets/README.md: -------------------------------------------------------------------------------- 1 | For training and testing, your directory structure should look like this 2 | 3 | `Datasets`
4 |  `├──train`
5 |      `└──Rain13K`
6 |           `├──input`
7 |           `└──target`
8 |  `└──test`
9 |      `├──Test100`
10 |           `├──input`
11 |           `└──target`
12 |      `├──Rain100H`
13 |           `├──input`
14 |           `└──target`
15 |      `├──Rain100L`
16 |           `├──input`
17 |           `└──target`
18 |      `├──Test1200`
19 |           `├──input`
20 |           `└──target`
21 |      `└──Test2800`
22 |           `├──input`
23 |           `└──target` 24 | -------------------------------------------------------------------------------- /Deraining/Options/Deraining_Restormer.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: Deraining_Restormer 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 8 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage 13 | dataroot_gt: ./Deraining/Datasets/train/Rain13K/target 14 | dataroot_lq: ./Deraining/Datasets/train/Rain13K/input 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ### -------------Progressive training-------------------------- 27 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 28 | iters: [92000,64000,48000,36000,36000,24000] 29 | gt_size: 384 # Max patch size for progressive training 30 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 
31 | ### ------------------------------------------------------------ 32 | 33 | ### ------- Training on single fixed-patch size 128x128--------- 34 | # mini_batch_sizes: [8] 35 | # iters: [300000] 36 | # gt_size: 128 37 | # gt_sizes: [128] 38 | ### ------------------------------------------------------------ 39 | 40 | dataset_enlarge_ratio: 1 41 | prefetch_mode: ~ 42 | 43 | val: 44 | name: ValSet 45 | type: Dataset_PairedImage 46 | dataroot_gt: ./Deraining/Datasets/test/Rain100L/target 47 | dataroot_lq: ./Deraining/Datasets/test/Rain100L/input 48 | io_backend: 49 | type: disk 50 | 51 | # network structures 52 | network_g: 53 | type: Restormer 54 | inp_channels: 3 55 | out_channels: 3 56 | dim: 48 57 | num_blocks: [4,6,6,8] 58 | num_refinement_blocks: 4 59 | heads: [1,2,4,8] 60 | ffn_expansion_factor: 2.66 61 | bias: False 62 | LayerNorm_type: WithBias 63 | dual_pixel_task: False 64 | 65 | 66 | # path 67 | path: 68 | pretrain_network_g: ~ 69 | strict_load_g: true 70 | resume_state: ~ 71 | 72 | # training settings 73 | train: 74 | total_iter: 300000 75 | warmup_iter: -1 # no warm up 76 | use_grad_clip: true 77 | 78 | # Split 300k iterations into two cycles. 79 | # 1st cycle: fixed 3e-4 LR for 92k iters. 80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 
81 | scheduler: 82 | type: CosineAnnealingRestartCyclicLR 83 | periods: [92000, 208000] 84 | restart_weights: [1,1] 85 | eta_mins: [0.0003,0.000001] 86 | 87 | mixing_augs: 88 | mixup: false 89 | mixup_beta: 1.2 90 | use_identity: true 91 | 92 | optim_g: 93 | type: AdamW 94 | lr: !!float 3e-4 95 | weight_decay: !!float 1e-4 96 | betas: [0.9, 0.999] 97 | 98 | # losses 99 | pixel_opt: 100 | type: L1Loss 101 | loss_weight: 1 102 | reduction: mean 103 | 104 | # validation settings 105 | val: 106 | window_size: 8 107 | val_freq: !!float 4e3 108 | save_img: false 109 | rgb2bgr: true 110 | use_image: true 111 | max_minibatch: 8 112 | 113 | metrics: 114 | psnr: # metric name, can be arbitrary 115 | type: calculate_psnr 116 | crop_border: 0 117 | test_y_channel: true 118 | 119 | # logging settings 120 | logger: 121 | print_freq: 1000 122 | save_checkpoint_freq: !!float 4e3 123 | use_tb_logger: true 124 | wandb: 125 | project: ~ 126 | resume_id: ~ 127 | 128 | # dist training settings 129 | dist_params: 130 | backend: nccl 131 | port: 29500 132 | -------------------------------------------------------------------------------- /Deraining/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Training 3 | 4 | 1. To download Rain13K training and testing data, run 5 | ``` 6 | python download_data.py --data train-test 7 | ``` 8 | 9 | 2. To train Restormer with default settings, run 10 | ``` 11 | cd Restormer 12 | ./train.sh Deraining/Options/Deraining_Restormer.yml 13 | ``` 14 | 15 | **Note:** The above training script uses 8 GPUs by default. To use any other number of GPUs, modify [Restormer/train.sh](../train.sh) and [Deraining/Options/Deraining_Restormer.yml](Options/Deraining_Restormer.yml) 16 | 17 | ## Evaluation 18 | 19 | 1. Download the pre-trained [model](https://drive.google.com/drive/folders/1ZEDDEVW0UgkpWi-N4Lj_JUoVChGXCu_u?usp=sharing) and place it in `./pretrained_models/` 20 | 21 | 2. 
# Download training and/or testing data for the image deraining task.
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, required=True, help='train, test or train-test')
args = parser.parse_args()

### Google drive IDs ######
rain13k_train = '14BidJeG4nSNuFNFDf99K-7eErCq4i47t'      ## https://drive.google.com/file/d/14BidJeG4nSNuFNFDf99K-7eErCq4i47t/view?usp=sharing
rain13k_test  = '1P_-RAvltEoEhfT-9GrWRdpEi6NSswTs8'      ## https://drive.google.com/file/d/1P_-RAvltEoEhfT-9GrWRdpEi6NSswTs8/view?usp=sharing

# Reject unknown split names up front; previously they were silently skipped
# and the script exited without downloading anything.
requested = args.data.split('-')
unknown = [name for name in requested if name not in ('train', 'test')]
if unknown:
    parser.error(f'unsupported --data value(s) {unknown}; use train, test or train-test')

for data in requested:
    if data == 'train':
        print('Rain13K Training Data!')
        # gdown.download(id=rain13k_train, output='Datasets/train.zip', quiet=False)
        # Stop on a failed download instead of failing later while unpacking
        # an archive that does not exist.
        if os.system(f'gdrive download {rain13k_train} --path Datasets/') != 0:
            raise SystemExit('Downloading the Rain13K training data failed.')
        print('Extracting Rain13K data...')
        shutil.unpack_archive('Datasets/train.zip', 'Datasets')
        os.remove('Datasets/train.zip')

    if data == 'test':
        print('Download Deraining Testing Data')
        # gdown.download(id=rain13k_test, output='Datasets/test.zip', quiet=False)
        if os.system(f'gdrive download {rain13k_test} --path Datasets/') != 0:
            raise SystemExit('Downloading the deraining test data failed.')
        print('Extracting test data...')
        shutil.unpack_archive('Datasets/test.zip', 'Datasets')
        os.remove('Datasets/test.zip')


# print('Download completed successfully!')
# Evaluation script: run a pre-trained Restormer over the deraining test sets
# and save each restored image as PNG.
parser = argparse.ArgumentParser(description='Image Deraining using Restormer')

parser.add_argument('--input_dir', default='./Datasets/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/deraining.pth', type=str, help='Path to weights')

args = parser.parse_args()

####### Load yaml #######
yaml_file = 'Options/Deraining_Restormer.yml'
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

# Parse the option file inside a context manager so the handle is closed.
with open(yaml_file, mode='r') as f:
    x = yaml.load(f, Loader=Loader)

# Drop the 'type' entry so the remaining keys can be forwarded as
# Restormer keyword arguments.
s = x['network_g'].pop('type')
##########################

model_restoration = Restormer(**x['network_g'])

checkpoint = torch.load(args.weights)
model_restoration.load_state_dict(checkpoint['params'])
print("===>Testing using weights: ",args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()


factor = 8
datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800']

for dataset in datasets:
    result_dir  = os.path.join(args.result_dir, dataset)
    os.makedirs(result_dir, exist_ok=True)

    inp_dir = os.path.join(args.input_dir, 'test', dataset, 'input')
    files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
    with torch.no_grad():
        for file_ in tqdm(files):
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

            img = np.float32(utils.load_img(file_))/255.
            img = torch.from_numpy(img).permute(2,0,1)
            input_ = img.unsqueeze(0).cuda()

            # Pad bottom/right so both sides become multiples of `factor`
            # (no padding when a side already is one).
            h,w = input_.shape[2], input_.shape[3]
            padh = (factor - h % factor) % factor
            padw = (factor - w % factor) % factor
            input_ = F.pad(input_, (0,padw,0,padh), 'reflect')

            restored = model_restoration(input_)

            # Unpad images to original dimensions
            restored = restored[:,:,:h,:w]

            restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()

            # Results are always written as PNG, whatever the input extension.
            utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
def calculate_psnr(img1, img2, border=0):
    """Return the PSNR (in dB) between two images with range [0, 255].

    Args:
        img1, img2: Arrays of identical shape; the first two axes are
            treated as height and width.
        border: Number of pixels cropped from every side before comparing.

    Returns:
        The PSNR as a float, or ``inf`` when the images are identical.

    Raises:
        ValueError: If the two images do not have the same shape.
    """
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    """Return the SSIM between two images with range [0, 255].

    2-D inputs and ``(H, W, 1)`` inputs are scored as a single channel;
    ``(H, W, 3)`` inputs are scored per channel and averaged.

    Args:
        img1, img2: Arrays of identical shape.
        border: Number of pixels cropped from every side before comparing.

    Raises:
        ValueError: If the shapes differ, or if the input is neither 2-D
            nor 3-D with 1 or 3 channels.
    """
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            return np.array([ssim(img1[:,:,i], img2[:,:,i]) for i in range(3)]).mean()
        if img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    # Previously a 3-D input with an unsupported channel count fell through
    # and returned None; every unsupported layout now fails loudly.
    raise ValueError('Wrong input image dimensions.')


def ssim(img1, img2):
    """Return the mean SSIM of two single-channel images (range [0, 255]).

    Uses an 11x11 Gaussian window with sigma 1.5.
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    # The [5:-5, 5:-5] crop drops the half-window border of the filtered
    # image, keeping only the "valid" region.
    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def load_img(filepath):
    """Read an image from disk and return it in RGB channel order.

    Raises:
        IOError: If OpenCV cannot read the file (``cv2.imread`` returns None).
    """
    img = cv2.imread(filepath)
    if img is None:
        raise IOError(f'Failed to read image: {filepath}')
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def save_img(filepath, img):
    """Write an RGB image to disk (converted to OpenCV's BGR order)."""
    cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


def load_gray_img(filepath):
    """Read an image as grayscale and return it with a trailing channel axis.

    Raises:
        IOError: If OpenCV cannot read the file (``cv2.imread`` returns None).
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise IOError(f'Failed to read image: {filepath}')
    return np.expand_dims(img, axis=2)


def save_gray_img(filepath, img):
    """Write a grayscale image to disk as-is."""
    cv2.imwrite(filepath, img)
img) 91 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | This repository is built in PyTorch 1.8.1 and tested on Ubuntu 16.04 environment (Python3.7, CUDA10.2, cuDNN7.6). 4 | Follow these instructions 5 | 6 | 1. Clone our repository 7 | ``` 8 | git clone https://github.com/swz30/Restormer.git 9 | cd Restormer 10 | ``` 11 | 12 | 2. Make conda environment 13 | ``` 14 | conda create -n pytorch181 python=3.7 15 | conda activate pytorch181 16 | ``` 17 | 18 | 3. Install dependencies 19 | ``` 20 | conda install pytorch=1.8 torchvision cudatoolkit=10.2 -c pytorch 21 | pip install matplotlib scikit-learn scikit-image opencv-python yacs joblib natsort h5py tqdm 22 | pip install einops gdown addict future lmdb numpy pyyaml requests scipy tb-nightly yapf lpips 23 | ``` 24 | 25 | 4. Install basicsr 26 | ``` 27 | python setup.py develop --no_cuda_ext 28 | ``` 29 | 30 | ### Download datasets from Google Drive 31 | 32 | To be able to download datasets automatically you would need `go` and `gdrive` installed. 33 | 34 | 1. You can install `go` with the following 35 | ``` 36 | curl -O https://storage.googleapis.com/golang/go1.11.1.linux-amd64.tar.gz 37 | mkdir -p ~/installed 38 | tar -C ~/installed -xzf go1.11.1.linux-amd64.tar.gz 39 | mkdir -p ~/go 40 | ``` 41 | 42 | 2. Add the following lines to `~/.bashrc` 43 | ``` 44 | export GOPATH=$HOME/go 45 | export PATH=$PATH:$HOME/go/bin:$HOME/installed/go/bin 46 | ``` 47 | 48 | 3. Install `gdrive` using 49 | ``` 50 | go get github.com/prasmussen/gdrive 51 | ``` 52 | 53 | 4. Close current terminal and open a new terminal.
def create_dataset(dataset_opt):
    """Create dataset.

    Args:
        dataset_opt (dict): Configuration for dataset. It contains:
            name (str): Dataset name.
            type (str): Dataset type.

    Raises:
        ValueError: If no scanned dataset module defines ``dataset_opt['type']``.
    """
    dataset_type = dataset_opt['type']

    # dynamic instantiation
    # Initialise explicitly so an empty module list yields the ValueError
    # below instead of an UnboundLocalError.
    dataset_cls = None
    for module in _dataset_modules:
        dataset_cls = getattr(module, dataset_type, None)
        if dataset_cls is not None:
            break
    if dataset_cls is None:
        raise ValueError(f'Dataset {dataset_type} is not found.')

    dataset = dataset_cls(dataset_opt)

    logger = get_root_logger()
    logger.info(
        f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
        'is created.')
    return dataset


def create_dataloader(dataset,
                      dataset_opt,
                      num_gpu=1,
                      dist=False,
                      sampler=None,
                      seed=None):
    """Create dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train', 'val' or 'test'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None

    Raises:
        ValueError: If ``dataset_opt['phase']`` is not a supported phase.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()
    if phase == 'train':
        if dist:  # distributed training
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:  # non-distributed training
            # In CPU mode (num_gpu == 0) behave like a single device.
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True)
        # Shuffle in the loader only when no sampler is supplied.
        if sampler is None:
            dataloader_args['shuffle'] = True
        dataloader_args['worker_init_fn'] = partial(
            worker_init_fn, num_workers=num_workers, rank=rank,
            seed=seed) if seed is not None else None
    elif phase in ['val', 'test']:  # validation
        dataloader_args = dict(
            dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f'Wrong dataset phase: {phase}. '
                         "Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: '
                    f'num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(
            num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    else:
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)


def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy and ``random`` for one dataloader worker.

    The seed is ``num_workers * rank + worker_id + seed``.
    """
    # Set the worker seed to num_workers * rank + worker_id + seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
class EnlargedSampler(Sampler):
    """Distributed sampler that can virtually enlarge the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler. The index
    space is ``ratio`` times the dataset length (rounded up to a multiple of
    ``num_replicas``), which suits iteration-based training because the
    dataloader has to be restarted less often.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        enlarged_len = len(self.dataset) * ratio
        # Every replica draws the same number of samples.
        self.num_samples = math.ceil(enlarged_len / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # The permutation depends only on the epoch, so all replicas agree.
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        permutation = torch.randperm(self.total_size, generator=rng).tolist()

        # Take this rank's strided share, then fold the enlarged indices
        # back into the real dataset range.
        real_len = len(self.dataset)
        own_share = permutation[self.rank:self.total_size:self.num_replicas]
        folded = [position % real_len for position in own_share]
        assert len(folded) == self.num_samples

        return iter(folded)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.

    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # The file client (io backend) is built on first item access.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] != 'lmdb':
            # FFHQ has 70000 images in total
            self.paths = [
                osp.join(self.gt_folder, f'{number:08d}.png')
                for number in range(70000)
            ]
            return

        # lmdb backend: record the database path, then read the keys from
        # the meta file (the part of each line before the first '.').
        self.io_backend_opt['db_paths'] = self.gt_folder
        if not self.gt_folder.endswith('.lmdb'):
            raise ValueError("'dataroot_gt' should end with '.lmdb', "
                             f'but received {self.gt_folder}')
        meta_file = osp.join(self.gt_folder, 'meta_info.txt')
        with open(meta_file) as fin:
            self.paths = [entry.split('.')[0] for entry in fin]

    def __getitem__(self, index):
        if self.file_client is None:
            # 'type' is removed from the options; the rest become kwargs.
            backend_type = self.io_backend_opt.pop('type')
            self.file_client = FileClient(backend_type, **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        raw_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(raw_bytes, float32=True)

        # random horizontal flip
        flip = self.opt['use_hflip']
        img_gt = augment(img_gt, hflip=flip, rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
/basicsr/data/meta_info/meta_info_REDS_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 001 100 (720,1280,3) 3 | 002 100 (720,1280,3) 4 | 003 100 (720,1280,3) 5 | 004 100 (720,1280,3) 6 | 005 100 (720,1280,3) 7 | 006 100 (720,1280,3) 8 | 007 100 (720,1280,3) 9 | 008 100 (720,1280,3) 10 | 009 100 (720,1280,3) 11 | 010 100 (720,1280,3) 12 | 011 100 (720,1280,3) 13 | 012 100 (720,1280,3) 14 | 013 100 (720,1280,3) 15 | 014 100 (720,1280,3) 16 | 015 100 (720,1280,3) 17 | 016 100 (720,1280,3) 18 | 017 100 (720,1280,3) 19 | 018 100 (720,1280,3) 20 | 019 100 (720,1280,3) 21 | 020 100 (720,1280,3) 22 | 021 100 (720,1280,3) 23 | 022 100 (720,1280,3) 24 | 023 100 (720,1280,3) 25 | 024 100 (720,1280,3) 26 | 025 100 (720,1280,3) 27 | 026 100 (720,1280,3) 28 | 027 100 (720,1280,3) 29 | 028 100 (720,1280,3) 30 | 029 100 (720,1280,3) 31 | 030 100 (720,1280,3) 32 | 031 100 (720,1280,3) 33 | 032 100 (720,1280,3) 34 | 033 100 (720,1280,3) 35 | 034 100 (720,1280,3) 36 | 035 100 (720,1280,3) 37 | 036 100 (720,1280,3) 38 | 037 100 (720,1280,3) 39 | 038 100 (720,1280,3) 40 | 039 100 (720,1280,3) 41 | 040 100 (720,1280,3) 42 | 041 100 (720,1280,3) 43 | 042 100 (720,1280,3) 44 | 043 100 (720,1280,3) 45 | 044 100 (720,1280,3) 46 | 045 100 (720,1280,3) 47 | 046 100 (720,1280,3) 48 | 047 100 (720,1280,3) 49 | 048 100 (720,1280,3) 50 | 049 100 (720,1280,3) 51 | 050 100 (720,1280,3) 52 | 051 100 (720,1280,3) 53 | 052 100 (720,1280,3) 54 | 053 100 (720,1280,3) 55 | 054 100 (720,1280,3) 56 | 055 100 (720,1280,3) 57 | 056 100 (720,1280,3) 58 | 057 100 (720,1280,3) 59 | 058 100 (720,1280,3) 60 | 059 100 (720,1280,3) 61 | 060 100 (720,1280,3) 62 | 061 100 (720,1280,3) 63 | 062 100 (720,1280,3) 64 | 063 100 (720,1280,3) 65 | 064 100 (720,1280,3) 66 | 065 100 (720,1280,3) 67 | 066 100 (720,1280,3) 68 | 067 100 (720,1280,3) 69 | 068 100 (720,1280,3) 70 | 069 100 (720,1280,3) 71 | 070 100 (720,1280,3) 72 | 071 100 (720,1280,3) 73 | 072 
100 (720,1280,3) 74 | 073 100 (720,1280,3) 75 | 074 100 (720,1280,3) 76 | 075 100 (720,1280,3) 77 | 076 100 (720,1280,3) 78 | 077 100 (720,1280,3) 79 | 078 100 (720,1280,3) 80 | 079 100 (720,1280,3) 81 | 080 100 (720,1280,3) 82 | 081 100 (720,1280,3) 83 | 082 100 (720,1280,3) 84 | 083 100 (720,1280,3) 85 | 084 100 (720,1280,3) 86 | 085 100 (720,1280,3) 87 | 086 100 (720,1280,3) 88 | 087 100 (720,1280,3) 89 | 088 100 (720,1280,3) 90 | 089 100 (720,1280,3) 91 | 090 100 (720,1280,3) 92 | 091 100 (720,1280,3) 93 | 092 100 (720,1280,3) 94 | 093 100 (720,1280,3) 95 | 094 100 (720,1280,3) 96 | 095 100 (720,1280,3) 97 | 096 100 (720,1280,3) 98 | 097 100 (720,1280,3) 99 | 098 100 (720,1280,3) 100 | 099 100 (720,1280,3) 101 | 100 100 (720,1280,3) 102 | 101 100 (720,1280,3) 103 | 102 100 (720,1280,3) 104 | 103 100 (720,1280,3) 105 | 104 100 (720,1280,3) 106 | 105 100 (720,1280,3) 107 | 106 100 (720,1280,3) 108 | 107 100 (720,1280,3) 109 | 108 100 (720,1280,3) 110 | 109 100 (720,1280,3) 111 | 110 100 (720,1280,3) 112 | 111 100 (720,1280,3) 113 | 112 100 (720,1280,3) 114 | 113 100 (720,1280,3) 115 | 114 100 (720,1280,3) 116 | 115 100 (720,1280,3) 117 | 116 100 (720,1280,3) 118 | 117 100 (720,1280,3) 119 | 118 100 (720,1280,3) 120 | 119 100 (720,1280,3) 121 | 120 100 (720,1280,3) 122 | 121 100 (720,1280,3) 123 | 122 100 (720,1280,3) 124 | 123 100 (720,1280,3) 125 | 124 100 (720,1280,3) 126 | 125 100 (720,1280,3) 127 | 126 100 (720,1280,3) 128 | 127 100 (720,1280,3) 129 | 128 100 (720,1280,3) 130 | 129 100 (720,1280,3) 131 | 130 100 (720,1280,3) 132 | 131 100 (720,1280,3) 133 | 132 100 (720,1280,3) 134 | 133 100 (720,1280,3) 135 | 134 100 (720,1280,3) 136 | 135 100 (720,1280,3) 137 | 136 100 (720,1280,3) 138 | 137 100 (720,1280,3) 139 | 138 100 (720,1280,3) 140 | 139 100 (720,1280,3) 141 | 140 100 (720,1280,3) 142 | 141 100 (720,1280,3) 143 | 142 100 (720,1280,3) 144 | 143 100 (720,1280,3) 145 | 144 100 (720,1280,3) 146 | 145 100 (720,1280,3) 147 | 146 100 (720,1280,3) 148 | 147 
100 (720,1280,3) 149 | 148 100 (720,1280,3) 150 | 149 100 (720,1280,3) 151 | 150 100 (720,1280,3) 152 | 151 100 (720,1280,3) 153 | 152 100 (720,1280,3) 154 | 153 100 (720,1280,3) 155 | 154 100 (720,1280,3) 156 | 155 100 (720,1280,3) 157 | 156 100 (720,1280,3) 158 | 157 100 (720,1280,3) 159 | 158 100 (720,1280,3) 160 | 159 100 (720,1280,3) 161 | 160 100 (720,1280,3) 162 | 161 100 (720,1280,3) 163 | 162 100 (720,1280,3) 164 | 163 100 (720,1280,3) 165 | 164 100 (720,1280,3) 166 | 165 100 (720,1280,3) 167 | 166 100 (720,1280,3) 168 | 167 100 (720,1280,3) 169 | 168 100 (720,1280,3) 170 | 169 100 (720,1280,3) 171 | 170 100 (720,1280,3) 172 | 171 100 (720,1280,3) 173 | 172 100 (720,1280,3) 174 | 173 100 (720,1280,3) 175 | 174 100 (720,1280,3) 176 | 175 100 (720,1280,3) 177 | 176 100 (720,1280,3) 178 | 177 100 (720,1280,3) 179 | 178 100 (720,1280,3) 180 | 179 100 (720,1280,3) 181 | 180 100 (720,1280,3) 182 | 181 100 (720,1280,3) 183 | 182 100 (720,1280,3) 184 | 183 100 (720,1280,3) 185 | 184 100 (720,1280,3) 186 | 185 100 (720,1280,3) 187 | 186 100 (720,1280,3) 188 | 187 100 (720,1280,3) 189 | 188 100 (720,1280,3) 190 | 189 100 (720,1280,3) 191 | 190 100 (720,1280,3) 192 | 191 100 (720,1280,3) 193 | 192 100 (720,1280,3) 194 | 193 100 (720,1280,3) 195 | 194 100 (720,1280,3) 196 | 195 100 (720,1280,3) 197 | 196 100 (720,1280,3) 198 | 197 100 (720,1280,3) 199 | 198 100 (720,1280,3) 200 | 199 100 (720,1280,3) 201 | 200 100 (720,1280,3) 202 | 201 100 (720,1280,3) 203 | 202 100 (720,1280,3) 204 | 203 100 (720,1280,3) 205 | 204 100 (720,1280,3) 206 | 205 100 (720,1280,3) 207 | 206 100 (720,1280,3) 208 | 207 100 (720,1280,3) 209 | 208 100 (720,1280,3) 210 | 209 100 (720,1280,3) 211 | 210 100 (720,1280,3) 212 | 211 100 (720,1280,3) 213 | 212 100 (720,1280,3) 214 | 213 100 (720,1280,3) 215 | 214 100 (720,1280,3) 216 | 215 100 (720,1280,3) 217 | 216 100 (720,1280,3) 218 | 217 100 (720,1280,3) 219 | 218 100 (720,1280,3) 220 | 219 100 (720,1280,3) 221 | 220 100 (720,1280,3) 222 | 221 
100 (720,1280,3) 223 | 222 100 (720,1280,3) 224 | 223 100 (720,1280,3) 225 | 224 100 (720,1280,3) 226 | 225 100 (720,1280,3) 227 | 226 100 (720,1280,3) 228 | 227 100 (720,1280,3) 229 | 228 100 (720,1280,3) 230 | 229 100 (720,1280,3) 231 | 230 100 (720,1280,3) 232 | 231 100 (720,1280,3) 233 | 232 100 (720,1280,3) 234 | 233 100 (720,1280,3) 235 | 234 100 (720,1280,3) 236 | 235 100 (720,1280,3) 237 | 236 100 (720,1280,3) 238 | 237 100 (720,1280,3) 239 | 238 100 (720,1280,3) 240 | 239 100 (720,1280,3) 241 | 240 100 (720,1280,3) 242 | 241 100 (720,1280,3) 243 | 242 100 (720,1280,3) 244 | 243 100 (720,1280,3) 245 | 244 100 (720,1280,3) 246 | 245 100 (720,1280,3) 247 | 246 100 (720,1280,3) 248 | 247 100 (720,1280,3) 249 | 248 100 (720,1280,3) 250 | 249 100 (720,1280,3) 251 | 250 100 (720,1280,3) 252 | 251 100 (720,1280,3) 253 | 252 100 (720,1280,3) 254 | 253 100 (720,1280,3) 255 | 254 100 (720,1280,3) 256 | 255 100 (720,1280,3) 257 | 256 100 (720,1280,3) 258 | 257 100 (720,1280,3) 259 | 258 100 (720,1280,3) 260 | 259 100 (720,1280,3) 261 | 260 100 (720,1280,3) 262 | 261 100 (720,1280,3) 263 | 262 100 (720,1280,3) 264 | 263 100 (720,1280,3) 265 | 264 100 (720,1280,3) 266 | 265 100 (720,1280,3) 267 | 266 100 (720,1280,3) 268 | 267 100 (720,1280,3) 269 | 268 100 (720,1280,3) 270 | 269 100 (720,1280,3) 271 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 246 100 (720,1280,3) 4 | 257 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 242 100 (720,1280,3) 4 | 243 100 (720,1280,3) 5 | 244 100 (720,1280,3) 6 
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    A daemon thread drains ``generator`` into a bounded queue while the
    consumer pulls items through the iterator protocol.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    # Unique end-of-stream marker. Unlike ``None`` it can never be confused
    # with an item produced by the wrapped generator.
    _END = object()

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        # consumer-side state: set once the end marker has been received
        self._exhausted = False
        # exception raised by the generator in the worker thread, if any
        self._error = None
        self.start()

    def run(self):
        """Producer loop executed in the background thread."""
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as err:
            # hand the failure over to the consumer instead of losing it
            self._error = err
        finally:
            # always signal the end so the consumer never blocks forever
            self.queue.put(self._END)

    def __next__(self):
        """Return the next prefetched item.

        Raises:
            StopIteration: When the generator is exhausted; keeps being raised
                on every later call.
            Exception: Whatever the wrapped generator raised, re-raised once
                in the consuming thread.
        """
        if self._exhausted:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._END:
            self._exhausted = True
            if self._error is not None:
                error, self._error = self._error, None
                raise error
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
class CPUPrefetcher():
    """CPU prefetcher.

    Hands out batches from a dataloader one call at a time and can be rewound
    to the start of the loader.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(self.ori_loader)

    def next(self):
        """Return the next batch, or None once the loader is exhausted."""
        return next(self.loader, None)

    def reset(self):
        """Restart iteration from the beginning of the original loader."""
        self.loader = iter(self.ori_loader)
class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).

    There are two modes:
    1. 'meta_info_file': Use meta information file to generate paths.
    2. 'folder': Scan folders to generate paths.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            mean, std (optional): Normalization statistics; must be given
                together or not at all.

    Raises:
        ValueError: If only one of ``mean`` / ``std`` is provided.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend); instantiated on first access
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')
        # normalize() needs both statistics; fail early with a clear message
        # instead of crashing inside normalize() on the first fetched sample
        if (self.mean is None) != (self.std is None):
            raise ValueError(
                "'mean' and 'std' must be provided together or both omitted.")
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                # first space-separated token is the file name; strip() drops
                # the trailing newline (which would otherwise stay in the path
                # for single-token lines) and blank lines are skipped
                self.paths = [
                    osp.join(self.lq_folder, line.strip().split(' ')[0])
                    for line in fin if line.strip()
                ]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        """Load the lq image at ``index`` as a (optionally normalized) tensor.

        Returns:
            dict: ``{'lq': tensor, 'lq_path': path}``.
        """
        if self.file_client is None:
            # Work on a copy so the shared config dict keeps its 'type' entry.
            backend_kwargs = dict(self.io_backend_opt)
            backend_type = backend_kwargs.pop('type')
            self.file_client = FileClient(backend_type, **backend_kwargs)

        # load lq image
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize in place when statistics were configured
        if self.mean is not None and self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains:
    1. clip name; 2. frame number; 3. image shape, separated by a white space.
    Examples:
        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    Key examples: "00001/0001"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:
    num_frame | frame list
             1 | 4
             3 | 3,4,5
             5 | 2,3,4,5,6
             7 | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.

            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            random_reverse (bool): Random reverse input frames.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale: Scale passed to ``paired_random_crop``; per the original
                notes it is added to the options automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root = Path(opt['dataroot_gt'])
        self.lq_root = Path(opt['dataroot_lq'])

        with open(opt['meta_info_file'], 'r') as fin:
            # first token of each line is the key; strip() guards against a
            # trailing newline on single-token lines, blank lines are skipped
            self.keys = [
                line.strip().split(' ')[0] for line in fin if line.strip()
            ]

        # file client (io backend); instantiated on first access
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images, centred on frame 4 (see table above)
        self.neighbor_list = [
            i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])
        ]

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        logger = get_root_logger()
        logger.info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        """Return the cropped/augmented LQ window and GT frame for ``index``.

        Returns:
            dict: ``{'lq': (t, c, h, w) tensor, 'gt': (c, h, w) tensor,
            'key': str}``.
        """
        if self.file_client is None:
            # Work on a copy so the shared config dict keeps its 'type' entry.
            backend_kwargs = dict(self.io_backend_opt)
            backend_type = backend_kwargs.pop('type')
            self.file_client = FileClient(backend_type, **backend_kwargs)

        # random reverse; NOTE: reverses the stored list in place, so the
        # frame order persists until the next reversal
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png)
        if self.is_lmdb:
            img_gt_path = f'{key}/im4'
        else:
            img_gt_path = self.gt_root / clip / seq / 'im4.png'
        img_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for neighbor in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{neighbor}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
            img_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lq = imfrombytes(img_bytes, float32=True)
            img_lqs.append(img_lq)

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
                                             img_gt_path)

        # augmentation - flip, rotate (GT is appended so it receives the same
        # transform as the LQ frames)
        img_lqs.append(img_gt)
        img_results = augment(img_lqs, self.opt['use_flip'],
                              self.opt['use_rot'])

        img_results = img2tensor(img_results)
        img_lqs = torch.stack(img_results[0:-1], dim=0)
        img_gt = img_results[-1]

        # img_lqs: (t, c, h, w)
        # img_gt: (c, h, w)
        # key: str
        return {'lq': img_lqs, 'gt': img_gt, 'key': key}

    def __len__(self):
        return len(self.keys)
def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Args:
        mu1 (np.array): The sample mean over activations.
        sigma1 (np.array): The covariance matrix over activations for
            generated samples.
        mu2 (np.array): The sample mean over activations, precalculated on an
            representative data set.
        sigma2 (np.array): The covariance matrix over activations,
            precalculated on an representative data set.
        eps (float): Value added to the diagonal of both covariances when the
            square root of their product is not finite. Default: 1e-6.

    Returns:
        float: The Frechet Distance.

    Raises:
        AssertionError: If the means or the covariances differ in shape.
        ValueError: If the matrix square root has a non-negligible imaginary
            component on its diagonal.
    """
    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, (
        'Two covariances have different dimensions')

    # Only the matrix itself is needed; the error estimate that
    # ``disp=False`` would additionally return was never used.
    cov_sqrt = linalg.sqrtm(sigma1 @ sigma2)

    # Product might be almost singular
    if not np.isfinite(cov_sqrt).all():
        print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
              'of cov estimates')
        offset = np.eye(sigma1.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {m}')
        cov_sqrt = cov_sqrt.real

    mean_diff = mu1 - mu2
    mean_norm = mean_diff @ mean_diff
    trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
    fid = mean_norm + trace

    return fid
def to_y_channel(img):
    """Change to Y channel of YCbCr.

    Three-channel inputs are converted with ``bgr2ycbcr(..., y_only=True)``
    and returned with a trailing singleton channel axis; any other input is
    only cast to float32.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    scaled = img.astype(np.float32) / 255.
    has_three_channels = scaled.ndim == 3 and scaled.shape[2] == 3
    if not has_three_channels:
        # nothing to convert: just undo the [0, 1] scaling
        return scaled * 255.
    luma = bgr2ycbcr(scaled, y_only=True)
    return luma[..., None] * 255.
def dynamic_instantiation(modules, cls_type, opt):
    """Dynamically instantiate class.

    The first module that exposes a non-None attribute named ``cls_type``
    provides the class, which is then called with ``**opt``.

    Args:
        modules (list[importlib modules]): List of modules from importlib
            files.
        cls_type (str): Class type.
        opt (dict): Class initialization kwargs.

    Returns:
        class: Instantiated class.

    Raises:
        ValueError: If no module defines ``cls_type`` (including when
            ``modules`` is empty).
    """
    # Start from None so an empty ``modules`` list reaches the ValueError
    # below instead of failing on an unbound local.
    cls_ = None
    for module in modules:
        cls_ = getattr(module, cls_type, None)
        if cls_ is not None:
            break
    if cls_ is None:
        raise ValueError(f'{cls_type} is not found.')
    return cls_(**opt)
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    The decorated function must accept ``loss_func(pred, target, **kwargs)``
    and return the element-wise, unreduced loss. The decorator adds the
    ``weight`` and ``reduction`` arguments, so the resulting signature is
    ``loss_func(pred, target, weight=None, reduction='mean', **kwargs)``;
    weighting and reduction are delegated to ``weight_reduce_loss``.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([[0, 2, 3]])
    >>> target = torch.Tensor([[1, 1, 1]])
    >>> weight = torch.Tensor([[1, 0, 1]])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([[1., 1., 2.]])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def weighted(pred, target, weight=None, reduction='mean', **kwargs):
        # element-wise loss first, then optional weighting and reduction
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return weighted
49 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 50 | weights. Default: None. 51 | """ 52 | return self.loss_weight * l1_loss( 53 | pred, target, weight, reduction=self.reduction) 54 | 55 | class MSELoss(nn.Module): 56 | """MSE (L2) loss. 57 | 58 | Args: 59 | loss_weight (float): Loss weight for MSE loss. Default: 1.0. 60 | reduction (str): Specifies the reduction to apply to the output. 61 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 62 | """ 63 | 64 | def __init__(self, loss_weight=1.0, reduction='mean'): 65 | super(MSELoss, self).__init__() 66 | if reduction not in ['none', 'mean', 'sum']: 67 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 68 | f'Supported ones are: {_reduction_modes}') 69 | 70 | self.loss_weight = loss_weight 71 | self.reduction = reduction 72 | 73 | def forward(self, pred, target, weight=None, **kwargs): 74 | """ 75 | Args: 76 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 77 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 78 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 79 | weights. Default: None. 80 | """ 81 | return self.loss_weight * mse_loss( 82 | pred, target, weight, reduction=self.reduction) 83 | 84 | class PSNRLoss(nn.Module): 85 | 86 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False): 87 | super(PSNRLoss, self).__init__() 88 | assert reduction == 'mean' 89 | self.loss_weight = loss_weight 90 | self.scale = 10 / np.log(10) 91 | self.toY = toY 92 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) 93 | self.first = True 94 | 95 | def forward(self, pred, target): 96 | assert len(pred.size()) == 4 97 | if self.toY: 98 | if self.first: 99 | self.coef = self.coef.to(pred.device) 100 | self.first = False 101 | 102 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 103 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 104 | 105 | pred, target = pred / 255., target / 255. 
class CharbonnierLoss(nn.Module):
    """Charbonnier loss, a smooth and differentiable variant of L1.

    The loss is ``loss_weight * mean(sqrt((x - y)**2 + eps**2))`` taken over
    every element of the inputs.

    Args:
        loss_weight (float): Scalar multiplier applied to the loss.
            Default: 1.0.
        reduction (str): Accepted so the constructor signature matches the
            other loss classes in this module. The loss is always averaged
            over all elements, whatever value is given. Default: 'mean'.
        eps (float): Smoothing constant; its square is added under the
            square root so the gradient stays finite when ``x == y``.
            Default: 1e-3.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        # loss_weight used to be accepted and silently dropped, so a weight
        # given in a config had no effect on the optimised objective.
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, x, y):
        """Return the weighted mean Charbonnier distance between x and y.

        Args:
            x (Tensor): Predicted tensor.
            y (Tensor): Target tensor, broadcastable against ``x``.

        Returns:
            Tensor: Scalar loss.
        """
        diff = x - y
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
        return self.loss_weight * loss
dist=opt['dist'], 38 | sampler=None, 39 | seed=opt['manual_seed']) 40 | logger.info( 41 | f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 42 | test_loaders.append(test_loader) 43 | 44 | # create model 45 | model = create_model(opt) 46 | 47 | for test_loader in test_loaders: 48 | test_set_name = test_loader.dataset.opt['name'] 49 | logger.info(f'Testing {test_set_name}...') 50 | rgb2bgr = opt['val'].get('rgb2bgr', True) 51 | # wheather use uint8 image to compute metrics 52 | use_image = opt['val'].get('use_image', True) 53 | model.validation( 54 | test_loader, 55 | current_iter=opt['name'], 56 | tb_logger=None, 57 | save_img=opt['val']['save_img'], 58 | rgb2bgr=rgb2bgr, use_image=use_image) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_client import FileClient 2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP 3 | from .logger import (MessageLogger, get_env_info, get_root_logger, 4 | init_tb_logger, init_wandb_logger) 5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, 6 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt) 7 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k) 8 | 9 | __all__ = [ 10 | # file_client.py 11 | 'FileClient', 12 | # img_util.py 13 | 'img2tensor', 14 | 'tensor2img', 15 | 'imfrombytes', 16 | 'imwrite', 17 | 'crop_border', 18 | # logger.py 19 | 'MessageLogger', 20 | 'init_tb_logger', 21 | 'init_wandb_logger', 22 | 'get_root_logger', 23 | 'get_env_info', 24 | # misc.py 25 | 'set_random_seed', 26 | 'get_time_str', 27 | 'mkdir_and_rename', 28 | 'make_exp_dirs', 29 | 'scandir', 30 | 'check_resume', 31 | 'sizeof_fmt', 32 | 'padding', 33 | 'padding_DP', 34 | 
def bundle_submissions_raw(submission_folder, session, num_images=50,
                           num_blocks=20):
    """Bundle per-block raw denoising results into one .mat file per image.

    For every image index ``i`` (1-based) the files
    ``<submission_folder>/%04d_%02d.mat`` for blocks ``1..num_blocks`` are
    loaded, their ``Idenoised_crop`` entries are collected into an object
    array, and the result is saved as
    ``<submission_folder>/<session>/%04d.mat`` together with ``israw=True``
    and ``eval_version="1.0"``.

    Args:
        submission_folder (str): Folder where the denoised block files
            reside.
        session (str): Name of the output sub-folder created inside
            ``submission_folder``.
        num_images (int): Number of images to bundle. Default: 50.
        num_blocks (int): Number of blocks per image. Default: 20.
    """
    out_folder = os.path.join(submission_folder, session)
    # An existing output folder is fine; any other failure (e.g. missing
    # permissions) now surfaces here instead of being swallowed.
    os.makedirs(out_folder, exist_ok=True)

    israw = True
    eval_version = "1.0"

    for i in range(num_images):
        # The ``np.object`` alias was removed in NumPy 1.24; the builtin
        # ``object`` is the equivalent dtype.
        Idenoised = np.zeros((num_blocks,), dtype=object)
        for bb in range(num_blocks):
            filename = '%04d_%02d.mat' % (i + 1, bb + 1)
            s = sio.loadmat(os.path.join(submission_folder, filename))
            Idenoised[bb] = s["Idenoised_crop"]
        filename = '%04d.mat' % (i + 1)
        sio.savemat(os.path.join(out_folder, filename),
                    {"Idenoised": Idenoised,
                     "israw": israw,
                     "eval_version": eval_version},
                    )
def bundle_submissions_srgb_v1(submission_folder, session, num_images=50,
                               num_blocks=20):
    """Bundle per-block sRGB denoising results (unpadded block index).

    For every image index ``i`` (1-based) the files
    ``<submission_folder>/%04d_%d.mat`` for blocks ``1..num_blocks`` are
    loaded (the block number is NOT zero-padded in this variant), their
    ``Idenoised_crop`` entries are collected into an object array, and the
    result is saved as ``<submission_folder>/<session>/%04d.mat`` together
    with ``israw=False`` and ``eval_version="1.0"``.

    Args:
        submission_folder (str): Folder where the denoised block files
            reside.
        session (str): Name of the output sub-folder created inside
            ``submission_folder``.
        num_images (int): Number of images to bundle. Default: 50.
        num_blocks (int): Number of blocks per image. Default: 20.
    """
    out_folder = os.path.join(submission_folder, session)
    # An existing output folder is fine; any other failure now surfaces
    # here instead of being swallowed by a bare ``except``.
    os.makedirs(out_folder, exist_ok=True)

    israw = False
    eval_version = "1.0"

    for i in range(num_images):
        # The ``np.object`` alias was removed in NumPy 1.24; the builtin
        # ``object`` is the equivalent dtype.
        Idenoised = np.zeros((num_blocks,), dtype=object)
        for bb in range(num_blocks):
            filename = '%04d_%d.mat' % (i + 1, bb + 1)
            s = sio.loadmat(os.path.join(submission_folder, filename))
            Idenoised[bb] = s["Idenoised_crop"]
        filename = '%04d.mat' % (i + 1)
        sio.savemat(os.path.join(out_folder, filename),
                    {"Idenoised": Idenoised,
                     "israw": israw,
                     "eval_version": eval_version},
                    )
16 | """ 17 | print('Reading image path list ...') 18 | img_path_list = sorted( 19 | list(scandir(folder_path, suffix=suffix, recursive=False))) 20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] 21 | 22 | return img_path_list, keys 23 | 24 | def create_lmdb_for_reds(): 25 | folder_path = './datasets/REDS/val/sharp_300' 26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb' 27 | img_path_list, keys = prepare_keys(folder_path, 'png') 28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 29 | # 30 | folder_path = './datasets/REDS/val/blur_300' 31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb' 32 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 34 | 35 | folder_path = './datasets/REDS/train/train_sharp' 36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb' 37 | img_path_list, keys = prepare_keys(folder_path, 'png') 38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 39 | 40 | folder_path = './datasets/REDS/train/train_blur_jpeg' 41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb' 42 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 44 | 45 | 46 | def create_lmdb_for_gopro(): 47 | folder_path = './datasets/GoPro/train/blur_crops' 48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb' 49 | 50 | img_path_list, keys = prepare_keys(folder_path, 'png') 51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 52 | 53 | folder_path = './datasets/GoPro/train/sharp_crops' 54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb' 55 | 56 | img_path_list, keys = prepare_keys(folder_path, 'png') 57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 58 | 59 | folder_path = './datasets/GoPro/test/target' 60 | lmdb_path = './datasets/GoPro/test/target.lmdb' 61 | 62 | img_path_list, keys = 
def create_lmdb_for_SIDD():
    """Create the LMDB databases for the SIDD training and validation sets.

    Training crops are read from ``./datasets/SIDD/train/{input,gt}_crops``.
    The validation blocks are first exported from the SIDD ``.mat`` files to
    PNG images under ``./datasets/SIDD/val/{input,gt}_crops`` and are then
    packed into LMDBs the same way.

    Raises:
        FileNotFoundError: If a validation ``.mat`` file is missing.
    """
    # Imported here because the module header does not import these names;
    # without them the validation export below fails with NameError.
    import os

    import cv2
    import scipy.io as scio
    from tqdm import tqdm

    # ---- training crops -------------------------------------------------
    train_sets = (
        ('./datasets/SIDD/train/input_crops',
         './datasets/SIDD/train/input_crops.lmdb'),
        ('./datasets/SIDD/train/gt_crops',
         './datasets/SIDD/train/gt_crops.lmdb'),
    )
    for folder_path, lmdb_path in train_sets:
        img_path_list, keys = prepare_keys(folder_path, 'PNG')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    # ---- validation blocks ----------------------------------------------
    # (output folder, lmdb path, source .mat file, variable inside the .mat)
    val_sets = (
        ('./datasets/SIDD/val/input_crops',
         './datasets/SIDD/val/input_crops.lmdb',
         './datasets/SIDD/ValidationNoisyBlocksSrgb.mat',
         'ValidationNoisyBlocksSrgb'),
        ('./datasets/SIDD/val/gt_crops',
         './datasets/SIDD/val/gt_crops.lmdb',
         './datasets/SIDD/ValidationGtBlocksSrgb.mat',
         'ValidationGtBlocksSrgb'),
    )
    for folder_path, lmdb_path, mat_path, mat_key in val_sets:
        os.makedirs(folder_path, exist_ok=True)
        if not osp.exists(mat_path):
            raise FileNotFoundError(mat_path)

        data = scio.loadmat(mat_path)[mat_key]
        # Flatten (image, block) into a single index so every block becomes
        # one PNG file.
        N, B, H, W, C = data.shape
        data = data.reshape(N * B, H, W, C)
        for i in tqdm(range(N * B)):
            # The data is converted RGB -> BGR because cv2.imwrite expects
            # BGR channel order.
            cv2.imwrite(
                osp.join(folder_path,
                         'ValidationBlocksSrgb_{}.png'.format(i)),
                cv2.cvtColor(data[i, ...], cv2.COLOR_RGB2BGR))

        img_path_list, keys = prepare_keys(folder_path, 'png')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput( 45 | f'scontrol show hostname {node_list} | head -n1') 46 | # specify master port 47 | if port is not None: 48 | os.environ['MASTER_PORT'] = str(port) 49 | elif 'MASTER_PORT' in os.environ: 50 | pass # use MASTER_PORT in the environment variable 51 | else: 52 | # 29500 is torch.distributed default port 53 | os.environ['MASTER_PORT'] = '29500' 54 | os.environ['MASTER_ADDR'] = addr 55 | os.environ['WORLD_SIZE'] = str(ntasks) 56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 57 | os.environ['RANK'] = str(proc_id) 58 | dist.init_process_group(backend=backend) 59 | 60 | 61 | def get_dist_info(): 62 | if dist.is_available(): 63 | initialized = dist.is_initialized() 64 | else: 65 | initialized = False 66 | if initialized: 67 | rank = dist.get_rank() 68 | world_size = dist.get_world_size() 69 | else: 70 | rank = 0 71 | world_size = 1 72 | return rank, world_size 73 | 74 | 75 | def master_only(func): 76 | 77 | @functools.wraps(func) 78 | def wrapper(*args, **kwargs): 79 | rank, _ = get_dist_info() 80 | if rank == 0: 81 | return func(*args, **kwargs) 82 | 83 | return wrapper 84 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import requests 3 | from 
def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    """Stream the body of ``response`` into the file ``destination``.

    Args:
        response: Object exposing ``iter_content(chunk_size)`` that yields
            byte chunks (e.g. a streamed ``requests`` response).
        destination (str): Path of the file to write (opened in 'wb' mode).
        file_size (int, optional): Total size in bytes. When given, a tqdm
            progress bar is shown. Default: None.
        chunk_size (int): Chunk size requested from ``iter_content``.
            Default: 32768.
    """
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None
        readable_file_size = None

    try:
        with open(destination, 'wb') as f:
            downloaded_size = 0
            for chunk in response.iter_content(chunk_size):
                if not chunk:  # filter out keep-alive new chunks
                    continue
                f.write(chunk)
                # Count the bytes actually received: the last chunk is
                # usually shorter than chunk_size, so adding chunk_size
                # over-reports the downloaded amount.
                downloaded_size += len(chunk)
                if pbar is not None:
                    pbar.update(1)
                    pbar.set_description(
                        f'Download {sizeof_fmt(downloaded_size)} '
                        f'/ {readable_file_size}')
    finally:
        # Close the bar even when the download is interrupted.
        if pbar is not None:
            pbar.close()
-------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class MessageLogger(): 11 | """Message logger for printing. 12 | 13 | Args: 14 | opt (dict): Config. It contains the following keys: 15 | name (str): Exp name. 16 | logger (dict): Contains 'print_freq' (str) for logger interval. 17 | train (dict): Contains 'total_iter' (int) for total iters. 18 | use_tb_logger (bool): Use tensorboard logger. 19 | start_iter (int): Start iter. Default: 1. 20 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 21 | """ 22 | 23 | def __init__(self, opt, start_iter=1, tb_logger=None): 24 | self.exp_name = opt['name'] 25 | self.interval = opt['logger']['print_freq'] 26 | self.start_iter = start_iter 27 | self.max_iters = opt['train']['total_iter'] 28 | self.use_tb_logger = opt['logger']['use_tb_logger'] 29 | self.tb_logger = tb_logger 30 | self.start_time = time.time() 31 | self.logger = get_root_logger() 32 | 33 | @master_only 34 | def __call__(self, log_vars): 35 | """Format logging message. 36 | 37 | Args: 38 | log_vars (dict): It contains the following keys: 39 | epoch (int): Epoch number. 40 | iter (int): Current iter. 41 | lrs (list): List for learning rates. 42 | 43 | time (float): Iter time. 44 | data_time (float): Data time for each iter. 
45 | """ 46 | # epoch, iter, learning rates 47 | epoch = log_vars.pop('epoch') 48 | current_iter = log_vars.pop('iter') 49 | lrs = log_vars.pop('lrs') 50 | 51 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') 52 | for v in lrs: 53 | message += f'{v:.3e},' 54 | message += ')] ' 55 | 56 | # time and estimated time 57 | if 'time' in log_vars.keys(): 58 | iter_time = log_vars.pop('time') 59 | data_time = log_vars.pop('data_time') 60 | 61 | total_time = time.time() - self.start_time 62 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 63 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 64 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 65 | message += f'[eta: {eta_str}, ' 66 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 67 | 68 | # other items, especially losses 69 | for k, v in log_vars.items(): 70 | message += f'{k}: {v:.4e} ' 71 | # tensorboard logger 72 | if self.use_tb_logger and 'debug' not in self.exp_name: 73 | if k.startswith('l_'): 74 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 75 | else: 76 | self.tb_logger.add_scalar(k, v, current_iter) 77 | self.logger.info(message) 78 | 79 | 80 | @master_only 81 | def init_tb_logger(log_dir): 82 | from torch.utils.tensorboard import SummaryWriter 83 | tb_logger = SummaryWriter(log_dir=log_dir) 84 | return tb_logger 85 | 86 | 87 | @master_only 88 | def init_wandb_logger(opt): 89 | """We now only use wandb to sync tensorboard log.""" 90 | import wandb 91 | logger = logging.getLogger('basicsr') 92 | 93 | project = opt['logger']['wandb']['project'] 94 | resume_id = opt['logger']['wandb'].get('resume_id') 95 | if resume_id: 96 | wandb_id = resume_id 97 | resume = 'allow' 98 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 99 | else: 100 | wandb_id = wandb.util.generate_id() 101 | resume = 'never' 102 | 103 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, 
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger (rank 0 only).

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    logger.addHandler(stream_handler)
    logger.propagate = False
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    else:
        # Always apply the requested level on rank 0. Previously it was only
        # set when a log file was given; without one the logger kept level
        # NOTSET and fell back to the root logger's WARNING level, silently
        # dropping info messages.
        logger.setLevel(log_level)
        if log_file is not None:
            # add file handler
            file_handler = logging.FileHandler(log_file, 'w')
            file_handler.setFormatter(logging.Formatter(format_str))
            file_handler.setLevel(log_level)
            logger.addHandler(file_handler)
    initialized_logger[logger_name] = True
    return logger
154 | """ 155 | import torch 156 | import torchvision 157 | 158 | from basicsr.version import __version__ 159 | msg = r""" 160 | ____ _ _____ ____ 161 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 162 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 163 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 164 | /_____/ \__,_//____//_/ \___//____//_/ |_| 165 | ______ __ __ __ __ 166 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 167 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 168 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 169 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 170 | """ 171 | msg += ('\nVersion Information: ' 172 | f'\n\tBasicSR: {__version__}' 173 | f'\n\tPyTorch: {torch.__version__}' 174 | f'\n\tTorchVision: {torchvision.__version__}') 175 | return msg -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | from .logger import get_root_logger 10 | 11 | 12 | def set_random_seed(seed): 13 | """Set random seeds.""" 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | 21 | def get_time_str(): 22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 23 | 24 | 25 | def mkdir_and_rename(path): 26 | """mkdirs. If path exists, rename it with timestamp and create a new one. 27 | 28 | Args: 29 | path (str): Folder path. 30 | """ 31 | if osp.exists(path): 32 | new_name = path + '_archived_' + get_time_str() 33 | print(f'Path already exists. 
Rename it to {new_name}', flush=True) 34 | os.rename(path, new_name) 35 | os.makedirs(path, exist_ok=True) 36 | 37 | 38 | @master_only 39 | def make_exp_dirs(opt): 40 | """Make dirs for experiments.""" 41 | path_opt = opt['path'].copy() 42 | if opt['is_train']: 43 | mkdir_and_rename(path_opt.pop('experiments_root')) 44 | else: 45 | mkdir_and_rename(path_opt.pop('results_root')) 46 | for key, path in path_opt.items(): 47 | if ('strict_load' not in key) and ('pretrain_network' 48 | not in key) and ('resume' 49 | not in key): 50 | os.makedirs(path, exist_ok=True) 51 | 52 | 53 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 54 | """Scan a directory to find the interested files. 55 | 56 | Args: 57 | dir_path (str): Path of the directory. 58 | suffix (str | tuple(str), optional): File suffix that we are 59 | interested in. Default: None. 60 | recursive (bool, optional): If set to True, recursively scan the 61 | directory. Default: False. 62 | full_path (bool, optional): If set to True, include the dir_path. 63 | Default: False. 64 | 65 | Returns: 66 | A generator for all the interested files with relative pathes. 
def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
    """Scan a directory for files whose path contains a keyword.

    Args:
        dir_path (str): Path of the directory.
        keywords (str | tuple(str), optional): Keyword(s) to look for in the
            returned path. A file is yielded when at least one keyword
            occurs anywhere in its path. Default: None (yield every file).
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator over the matching file paths (relative to ``dir_path``
        unless ``full_path`` is set).

    Raises:
        TypeError: If ``keywords`` is neither None, a string, nor a tuple.
    """

    if (keywords is not None) and not isinstance(keywords, (str, tuple)):
        raise TypeError('"keywords" must be a string or tuple of strings')

    # Normalise to a tuple so a single keyword and several keywords share one
    # code path (``str.find`` does not accept a tuple).
    if isinstance(keywords, str):
        keywords = (keywords, )

    root = dir_path

    def _scandir(dir_path, keywords, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                # ``in`` also matches a keyword at position 0, which the
                # former ``find(...) > 0`` test rejected.
                if keywords is None or any(
                        keyword in return_path for keyword in keywords):
                    yield return_path
            elif recursive and entry.is_dir():
                # Only descend into directories; hidden *files* also reach
                # this branch and must not be handed to os.scandir.
                yield from _scandir(
                    entry.path, keywords=keywords, recursive=recursive)

    return _scandir(dir_path, keywords=keywords, recursive=recursive)
168 | 169 | Args: 170 | size (int): File size. 171 | suffix (str): Suffix. Default: 'B'. 172 | 173 | Return: 174 | str: Formated file siz. 175 | """ 176 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 177 | if abs(size) < 1024.0: 178 | return f'{size:3.1f} {unit}{suffix}' 179 | size /= 1024.0 180 | return f'{size:3.1f} Y{suffix}' 181 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from collections import OrderedDict 3 | from os import path as osp 4 | 5 | 6 | def ordered_yaml(): 7 | """Support OrderedDict for yaml. 8 | 9 | Returns: 10 | yaml Loader and Dumper. 11 | """ 12 | try: 13 | from yaml import CDumper as Dumper 14 | from yaml import CLoader as Loader 15 | except ImportError: 16 | from yaml import Dumper, Loader 17 | 18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 19 | 20 | def dict_representer(dumper, data): 21 | return dumper.represent_dict(data.items()) 22 | 23 | def dict_constructor(loader, node): 24 | return OrderedDict(loader.construct_pairs(node)) 25 | 26 | Dumper.add_representer(OrderedDict, dict_representer) 27 | Loader.add_constructor(_mapping_tag, dict_constructor) 28 | return Loader, Dumper 29 | 30 | 31 | def parse(opt_path, is_train=True): 32 | """Parse option file. 33 | 34 | Args: 35 | opt_path (str): Option file path. 36 | is_train (str): Indicate whether in training or not. Default: True. 37 | 38 | Returns: 39 | (dict): Options. 
40 | """ 41 | with open(opt_path, mode='r') as f: 42 | Loader, _ = ordered_yaml() 43 | opt = yaml.load(f, Loader=Loader) 44 | 45 | opt['is_train'] = is_train 46 | 47 | # datasets 48 | for phase, dataset in opt['datasets'].items(): 49 | # for several datasets, e.g., test_1, test_2 50 | phase = phase.split('_')[0] 51 | dataset['phase'] = phase 52 | if 'scale' in opt: 53 | dataset['scale'] = opt['scale'] 54 | if dataset.get('dataroot_gt') is not None: 55 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 56 | if dataset.get('dataroot_lq') is not None: 57 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 58 | 59 | # paths 60 | for key, val in opt['path'].items(): 61 | if (val is not None) and ('resume_state' in key 62 | or 'pretrain_network' in key): 63 | opt['path'][key] = osp.expanduser(val) 64 | opt['path']['root'] = osp.abspath( 65 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 66 | if is_train: 67 | experiments_root = osp.join(opt['path']['root'], 'experiments', 68 | opt['name']) 69 | opt['path']['experiments_root'] = experiments_root 70 | opt['path']['models'] = osp.join(experiments_root, 'models') 71 | opt['path']['training_states'] = osp.join(experiments_root, 72 | 'training_states') 73 | opt['path']['log'] = experiments_root 74 | opt['path']['visualization'] = osp.join(experiments_root, 75 | 'visualization') 76 | 77 | # change some options for debug mode 78 | if 'debug' in opt['name']: 79 | if 'val' in opt: 80 | opt['val']['val_freq'] = 8 81 | opt['logger']['print_freq'] = 1 82 | opt['logger']['save_checkpoint_freq'] = 8 83 | else: # test 84 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 85 | opt['path']['results_root'] = results_root 86 | opt['path']['log'] = results_root 87 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 88 | 89 | return opt 90 | 91 | 92 | def dict2str(opt, indent_level=1): 93 | """dict to string for printing options. 
94 | 95 | Args: 96 | opt (dict): Option dict. 97 | indent_level (int): Indent level. Default: 1. 98 | 99 | Return: 100 | (str): Option string for printing. 101 | """ 102 | msg = '\n' 103 | for k, v in opt.items(): 104 | if isinstance(v, dict): 105 | msg += ' ' * (indent_level * 2) + k + ':[' 106 | msg += dict2str(v, indent_level + 1) 107 | msg += ' ' * (indent_level * 2) + ']\n' 108 | else: 109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 110 | return msg 111 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Wed Mar 9 22:05:30 2022 3 | __version__ = '1.2.0+10018c6' 4 | short_version = '1.2.0' 5 | version_info = (1, 2, 0) 6 | -------------------------------------------------------------------------------- /images/Deblurring/input/109fromGOPR1096.MP4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deblurring/input/109fromGOPR1096.MP4.png -------------------------------------------------------------------------------- /images/Deblurring/input/110fromGOPR1087.MP4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deblurring/input/110fromGOPR1087.MP4.png -------------------------------------------------------------------------------- /images/Deblurring/input/1fromGOPR0950.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deblurring/input/1fromGOPR0950.png -------------------------------------------------------------------------------- 
/images/Deblurring/input/1fromGOPR1096.MP4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deblurring/input/1fromGOPR1096.MP4.png -------------------------------------------------------------------------------- /images/Dehazing/input/0003_0.8_0.2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/0003_0.8_0.2.png -------------------------------------------------------------------------------- /images/Dehazing/input/0010_0.95_0.16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/0010_0.95_0.16.png -------------------------------------------------------------------------------- /images/Dehazing/input/0014_0.8_0.12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/0014_0.8_0.12.png -------------------------------------------------------------------------------- /images/Dehazing/input/0048_0.9_0.2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/0048_0.9_0.2.png -------------------------------------------------------------------------------- /images/Dehazing/input/1440_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/1440_10.png 
-------------------------------------------------------------------------------- /images/Dehazing/input/1444_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/1444_10.png -------------------------------------------------------------------------------- /images/Denoising/input/0003_30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Denoising/input/0003_30.png -------------------------------------------------------------------------------- /images/Denoising/input/0011_23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Denoising/input/0011_23.png -------------------------------------------------------------------------------- /images/Denoising/input/0013_19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Denoising/input/0013_19.png -------------------------------------------------------------------------------- /images/Denoising/input/0039_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Denoising/input/0039_04.png -------------------------------------------------------------------------------- /images/Deraining/input/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deraining/input/0.jpg 
-------------------------------------------------------------------------------- /images/Deraining/input/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deraining/input/1.png -------------------------------------------------------------------------------- /images/Deraining/input/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deraining/input/15.png -------------------------------------------------------------------------------- /images/Deraining/input/55.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deraining/input/55.png -------------------------------------------------------------------------------- /images/Enhancement/input/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Enhancement/input/1.png -------------------------------------------------------------------------------- /images/Enhancement/input/111.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Enhancement/input/111.png -------------------------------------------------------------------------------- /images/Enhancement/input/748.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Enhancement/input/748.png 
-------------------------------------------------------------------------------- /images/Enhancement/input/a4541-DSC_0040-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Enhancement/input/a4541-DSC_0040-2.png -------------------------------------------------------------------------------- /images/Results/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Results/0.jpg -------------------------------------------------------------------------------- /images/Results/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Results/1.png -------------------------------------------------------------------------------- /images/Results/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Results/15.png -------------------------------------------------------------------------------- /images/Results/55.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Results/55.png -------------------------------------------------------------------------------- /images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/overview.png -------------------------------------------------------------------------------- /maxim_pytorch/README.md: 
def recover_tree(keys, values):
    """Recover a tree as a nested dict from flat names and values.

    This function is useful to analyze checkpoints that were saved without
    needing access to the exact source code of the experiment. In particular,
    it can be used to extract and reuse various subtrees of the checkpoint,
    e.g. the subtree of parameters.

    Args:
        keys: a list of keys, where '/' is used as separator between nodes.
        values: a list of leaf values.

    Returns:
        A nested tree-like dict.
    """
    tree = {}
    sub_trees = collections.defaultdict(list)
    for k, v in zip(keys, values):
        if "/" not in k:
            tree[k] = v
        else:
            # Group by the first path component; the remainder is resolved
            # by the recursive call below.
            k_left, k_right = k.split("/", 1)
            sub_trees[k_left].append((k_right, v))
    for k, kv_pairs in sub_trees.items():
        k_subtree, v_subtree = zip(*kv_pairs)
        tree[k] = recover_tree(k_subtree, v_subtree)
    return tree


def get_params(ckpt_path):
    """Get params checkpoint.

    Reads the ``.npz`` archive at ``ckpt_path`` through ``tf.io.gfile``,
    rebuilds the nested tree from its '/'-separated entry names and returns
    the ``["opt"]["target"]`` subtree.
    """
    with tf.io.gfile.GFile(ckpt_path, "rb") as f:
        data = f.read()
    values = np.load(io.BytesIO(data))
    params = recover_tree(*zip(*values.items()))
    params = params["opt"]["target"]
    return params


def modify_jax_params(flat_jax_dict):
    """Convert a flat dict of JAX parameters to torch-style names and tensors.

    Values become ``torch.float`` tensors. 2-D ``kernel`` entries are
    transposed and 4-D ``kernel`` entries are permuted with ``(3, 2, 0, 1)``
    (presumably JAX HWIO -> PyTorch OIHW; confirm against the model). Kernels
    whose parent node is ``ConvTranspose_0`` additionally get their first two
    axes swapped.

    Keys have '/' replaced by '.', ``kernel`` renamed to ``weight`` and, for
    keys containing ``LayerNorm`` or ``layernorm``, ``scale``/``bias``
    renamed to ``gamma``/``beta``.

    Args:
        flat_jax_dict (dict): Maps '/'-separated names to array-like values.

    Returns:
        dict: Maps converted names to ``torch.Tensor`` values.
    """
    modified_dict = {}
    for key, value in flat_jax_dict.items():
        key_split = key.split("/")
        modified_value = torch.tensor(value, dtype=torch.float)

        # modify values
        is_kernel = key_split[-1] == 'kernel'
        num_dim = len(modified_value.shape)
        if num_dim == 1:
            modified_value = modified_value.squeeze()
        elif num_dim == 2 and is_kernel:
            # for normal weight, transpose it
            modified_value = modified_value.T
        elif num_dim == 4 and is_kernel:
            modified_value = modified_value.permute(3, 2, 0, 1)
            # Guard the parent lookup: a top-level "kernel" key has no
            # parent component, and key_split[-2] would raise IndexError.
            if len(key_split) >= 2 and key_split[-2] == 'ConvTranspose_0':
                modified_value = modified_value.permute(1, 0, 2, 3)

        # modify keys
        modified_key = ".".join(key_split)
        modified_key = modified_key.replace("kernel", "weight")
        if "LayerNorm" in modified_key or "layernorm" in modified_key:
            modified_key = modified_key.replace("scale", "gamma")
            modified_key = modified_key.replace("bias", "beta")

        modified_dict[modified_key] = modified_value

    return modified_dict
def main(args):
    """Convert a JAX MAXIM checkpoint and save it as a PyTorch state dict.

    Args:
        args: Namespace with ``ckpt_path`` (input checkpoint) and
            ``output_file`` (destination passed to ``torch.save``).
    """
    jax_params = get_params(args.ckpt_path)
    # Flatten the nested params into one record keyed by '/'-joined names.
    [flat_jax_dict] = pd.json_normalize(
        jax_params, sep="/").to_dict(orient="records")

    # Amend the JAX variables to match the names of the torch variables.
    modified_jax_params = modify_jax_params(flat_jax_dict)

    # update and save
    model = MAXIM_dns_3s()
    maxim_dict = model.state_dict()
    maxim_dict.update(modified_jax_params)
    torch.save(maxim_dict, args.output_file)


def parse_args(argv=None):
    """Parse the command-line options of the conversion script.

    Args:
        argv (list[str] | None): Arguments to parse. Default: None, in which
            case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Holds ``ckpt_path`` and ``output_file``.
    """
    parser = argparse.ArgumentParser(
        description="Conversion of the JAX pre-trained MAXIM weights to Pytorch."
    )
    parser.add_argument(
        "-c",
        "--ckpt_path",
        default="maxim_ckpt_Denoising_SIDD_checkpoint.npz",
        type=str,
        help="Checkpoint to port.",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        default="torch_weight.pth",
        type=str,
        help="Output.",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    main(parse_args())
version_file = 'basicsr/version.py'


def readme():
    """Return the long description passed to ``setup()``.

    Reading README.md is disabled, so an empty string is returned.
    """
    return ''


def get_git_hash():
    """Return the output of ``git rev-parse HEAD``, or 'unknown'.

    'unknown' is returned only when launching git raises ``OSError``.
    """

    def _minimal_ext_cmd(cmd):
        # construct minimal environment
        env = {}
        for k in ['SYSTEMROOT', 'PATH', 'HOME']:
            v = os.environ.get(k)
            if v is not None:
                env[k] = v
        # LANGUAGE is used on win32
        env['LANGUAGE'] = 'C'
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'
        # The exit status is ignored on purpose: a failing git yields empty
        # output rather than an exception.
        return subprocess.run(cmd, stdout=subprocess.PIPE, env=env).stdout

    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        sha = out.strip().decode('ascii')
    except OSError:
        sha = 'unknown'

    return sha


def get_hash():
    """Return the short revision hash used as the version's local part.

    Uses the first seven characters of the git hash when ``.git`` exists,
    otherwise the part after '+' in the existing version file, otherwise
    'unknown'.

    Raises:
        ImportError: If the version file exists but ``basicsr.version``
            cannot be imported.
    """
    if os.path.exists('.git'):
        sha = get_git_hash()[:7]
    elif os.path.exists(version_file):
        try:
            from basicsr.version import __version__
        except ImportError as err:
            # Chain the cause so the underlying import failure stays visible.
            raise ImportError('Unable to get git version') from err
        sha = __version__.split('+')[-1]
    else:
        sha = 'unknown'

    return sha


def write_version_py():
    """Generate ``version_file`` from the ``VERSION`` file and the hash.

    Reads the short version from ``VERSION`` in the current directory and
    writes ``__version__``, ``short_version`` and ``version_info``.
    """
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
short_version = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('VERSION', 'r') as f:
        SHORT_VERSION = f.read().strip()
    # Numeric components stay bare; anything else is quoted as a string.
    VERSION_INFO = ', '.join(
        [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
    VERSION = SHORT_VERSION + '+' + sha

    version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION,
                                      VERSION_INFO)
    with open(version_file, 'w') as f:
        f.write(version_file_str)


def get_version():
    """Return ``__version__`` as defined in ``version_file``.

    The file is executed in an explicit namespace. Relying on ``exec()`` to
    populate the function's ``locals()`` no longer works on Python 3.13+.
    """
    namespace = {}
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__version__']


def make_cuda_ext(name, module, sources, sources_cuda=None):
    """Build a C++/CUDA extension description for ``setup()``.

    Args:
        name (str): Extension name, appended to ``module``.
        module (str): Dotted module path; also the directory of the sources.
        sources (list[str]): Source files, relative to the module directory.
            Not modified.
        sources_cuda (list[str] | None): Extra sources, only compiled when
            CUDA is available or ``FORCE_CUDA=1`` is set. Default: None.

    Returns:
        A ``CUDAExtension`` when building with CUDA, else a ``CppExtension``.
    """
    if sources_cuda is None:
        sources_cuda = []
    define_macros = []
    extra_compile_args = {'cxx': []}
    # Work on a copy so the caller's list is never extended in place.
    all_sources = list(sources)

    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        all_sources += sources_cuda
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in all_sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)


def get_requirements(filename='requirements.txt'):
    """Return the install requirements, which are currently always empty.

    ``filename`` is kept for interface compatibility but is not read: the
    original returned ``[]`` before its file-reading code could run.
    """
    return []


if __name__ == '__main__':
    if '--no_cuda_ext' in sys.argv:
        ext_modules = []
        sys.argv.remove('--no_cuda_ext')
    else:
        ext_modules = [
            make_cuda_ext(
                name='deform_conv_ext',
                module='basicsr.models.ops.dcn',
                sources=['src/deform_conv_ext.cpp'],
                sources_cuda=[
                    'src/deform_conv_cuda.cpp',
                    'src/deform_conv_cuda_kernel.cu'
                ]),
            make_cuda_ext(
                name='fused_act_ext',
                module='basicsr.models.ops.fused_act',
                sources=['src/fused_bias_act.cpp'],
                sources_cuda=['src/fused_bias_act_kernel.cu']),
            make_cuda_ext(
                name='upfirdn2d_ext',
                module='basicsr.models.ops.upfirdn2d',
                sources=['src/upfirdn2d.cpp'],
                sources_cuda=['src/upfirdn2d_kernel.cu']),
        ]

    write_version_py()
    setup(
        name='basicsr',
        version=get_version(),
        description='Open Source Image and Video Super-Resolution Toolbox',
        long_description=readme(),
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, restoration, super resolution',
        url='https://github.com/xinntao/BasicSR',
        packages=find_packages(
            exclude=('options', 'datasets', 'experiments', 'results',
                     'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='Apache License 2.0',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        cmdclass={'build_ext': BuildExtension},
        zip_safe=False)