├── .gitignore
├── Deblurring
├── Datasets
│ └── README.md
├── Options
│ └── Deblurring_Restormer.yml
├── README.md
├── download_data.py
├── evaluate_gopro_hide.m
├── evaluate_realblur.py
├── generate_patches_gopro.py
├── pretrained_models
│ └── README.md
├── test.py
└── utils.py
├── Denoising
├── Datasets
│ └── README.md
├── Options
│ ├── GaussianColorDenoising_Restormer.yml
│ ├── GaussianColorDenoising_RestormerSigma15.yml
│ ├── GaussianColorDenoising_RestormerSigma25.yml
│ ├── GaussianColorDenoising_RestormerSigma50.yml
│ ├── GaussianGrayDenoising_Restormer.yml
│ ├── GaussianGrayDenoising_RestormerSigma15.yml
│ ├── GaussianGrayDenoising_RestormerSigma25.yml
│ ├── GaussianGrayDenoising_RestormerSigma50.yml
│ └── RealDenoising_Restormer.yml
├── README.md
├── download_data.py
├── evaluate_gaussian_color_denoising.py
├── evaluate_gaussian_gray_denoising.py
├── evaluate_sidd.m
├── generate_patches_dfwb.py
├── generate_patches_sidd.py
├── pretrained_models
│ └── README.md
├── test_gaussian_color_denoising.py
├── test_gaussian_gray_denoising.py
├── test_real_denoising_dnd.py
├── test_real_denoising_sidd.py
└── utils.py
├── Deraining
├── Datasets
│ └── README.md
├── Options
│ └── Deraining_Restormer.yml
├── README.md
├── download_data.py
├── evaluate_PSNR_SSIM.m
├── pretrained_models
│ └── README.md
├── test.py
└── utils.py
├── INSTALL.md
├── LICENSE
├── README.md
├── basicsr
├── data
│ ├── __init__.py
│ ├── data_sampler.py
│ ├── data_util.py
│ ├── ffhq_dataset.py
│ ├── meta_info
│ │ ├── meta_info_DIV2K800sub_GT.txt
│ │ ├── meta_info_REDS4_test_GT.txt
│ │ ├── meta_info_REDS_GT.txt
│ │ ├── meta_info_REDSofficial4_test_GT.txt
│ │ ├── meta_info_REDSval_official_test_GT.txt
│ │ ├── meta_info_Vimeo90K_test_GT.txt
│ │ ├── meta_info_Vimeo90K_test_fast_GT.txt
│ │ ├── meta_info_Vimeo90K_test_medium_GT.txt
│ │ ├── meta_info_Vimeo90K_test_slow_GT.txt
│ │ └── meta_info_Vimeo90K_train_GT.txt
│ ├── paired_image_dataset.py
│ ├── prefetch_dataloader.py
│ ├── reds_dataset.py
│ ├── single_image_dataset.py
│ ├── transforms.py
│ ├── video_test_dataset.py
│ └── vimeo90k_dataset.py
├── metrics
│ ├── __init__.py
│ ├── fid.py
│ ├── metric_util.py
│ ├── niqe.py
│ ├── niqe_pris_params.npz
│ └── psnr_ssim.py
├── models
│ ├── __init__.py
│ ├── archs
│ │ ├── Maxim_arch.py
│ │ ├── __init__.py
│ │ ├── arch_util.py
│ │ ├── maxim.py
│ │ └── restormer_arch.py
│ ├── base_model.py
│ ├── image_restoration_model.py
│ ├── losses
│ │ ├── __init__.py
│ │ ├── loss_util.py
│ │ └── losses.py
│ └── lr_scheduler.py
├── test.py
├── train.py
├── utils
│ ├── __init__.py
│ ├── bundle_submissions.py
│ ├── create_lmdb.py
│ ├── dist_util.py
│ ├── download_util.py
│ ├── face_util.py
│ ├── file_client.py
│ ├── flow_util.py
│ ├── img_util.py
│ ├── lmdb_util.py
│ ├── logger.py
│ ├── matlab_functions.py
│ ├── misc.py
│ └── options.py
└── version.py
├── demo.py
├── images
├── Deblurring
│ └── input
│ │ ├── 109fromGOPR1096.MP4.png
│ │ ├── 110fromGOPR1087.MP4.png
│ │ ├── 1fromGOPR0950.png
│ │ └── 1fromGOPR1096.MP4.png
├── Dehazing
│ └── input
│ │ ├── 0003_0.8_0.2.png
│ │ ├── 0010_0.95_0.16.png
│ │ ├── 0014_0.8_0.12.png
│ │ ├── 0048_0.9_0.2.png
│ │ ├── 1440_10.png
│ │ └── 1444_10.png
├── Denoising
│ └── input
│ │ ├── 0003_30.png
│ │ ├── 0011_23.png
│ │ ├── 0013_19.png
│ │ └── 0039_04.png
├── Deraining
│ └── input
│ │ ├── 0.jpg
│ │ ├── 1.png
│ │ ├── 15.png
│ │ └── 55.png
├── Enhancement
│ └── input
│ │ ├── 1.png
│ │ ├── 111.png
│ │ ├── 748.png
│ │ └── a4541-DSC_0040-2.png
├── Results
│ ├── 0.jpg
│ ├── 1.png
│ ├── 15.png
│ └── 55.png
└── overview.png
├── maxim_pytorch
├── README.md
├── jax2torch.py
└── maxim_torch.py
├── setup.cfg
├── setup.py
└── train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | *.pyc
3 | *.png
4 | *.tif
5 | *.jpg
6 | *.pth
7 | *.mat
8 | *.npy
9 | .DS_Store
10 |
11 |
--------------------------------------------------------------------------------
/Deblurring/Datasets/README.md:
--------------------------------------------------------------------------------
1 | For training and testing, your directory structure should look like this
2 |
3 | `Datasets`
4 | `├──train`
5 | `└──GoPro`
6 | `├──input_crops`
7 | `└──target_crops`
8 | `├──val`
9 | `└──GoPro`
10 | `├──input_crops`
11 | `└──target_crops`
12 | `└──test`
13 | `├──GoPro`
14 | `├──input`
15 | `└──target`
16 | `├──HIDE`
17 | `├──input`
18 | `└──target`
19 | `├──RealBlur_J`
20 | `├──input`
21 | `└──target`
22 | `└──RealBlur_R`
23 | `├──input`
24 | `└──target`
25 |
26 |
--------------------------------------------------------------------------------
/Deblurring/Options/Deblurring_Restormer.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Deblurring_Restormer
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage
13 | dataroot_gt: ./Motion_Deblurring/Datasets/train/GoPro/target_crops
14 | dataroot_lq: ./Motion_Deblurring/Datasets/train/GoPro/input_crops
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ### -------------Progressive training--------------------------
27 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
28 | iters: [92000,64000,48000,36000,36000,24000]
29 | gt_size: 384 # Max patch size for progressive training
30 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
31 | ### ------------------------------------------------------------
32 |
33 | ### ------- Training on single fixed-patch size 128x128---------
34 | # mini_batch_sizes: [8]
35 | # iters: [300000]
36 | # gt_size: 128
37 | # gt_sizes: [128]
38 | ### ------------------------------------------------------------
39 |
40 | dataset_enlarge_ratio: 1
41 | prefetch_mode: ~
42 |
43 | val:
44 | name: ValSet
45 | type: Dataset_PairedImage
46 | dataroot_gt: ./Motion_Deblurring/Datasets/val/GoPro/target_crops
47 | dataroot_lq: ./Motion_Deblurring/Datasets/val/GoPro/input_crops
48 | io_backend:
49 | type: disk
50 |
51 | # network structures
52 | network_g:
53 | type: Restormer
54 | inp_channels: 3
55 | out_channels: 3
56 | dim: 48
57 | num_blocks: [4,6,6,8]
58 | num_refinement_blocks: 4
59 | heads: [1,2,4,8]
60 | ffn_expansion_factor: 2.66
61 | bias: False
62 | LayerNorm_type: WithBias
63 | dual_pixel_task: False
64 |
65 |
66 | # path
67 | path:
68 | pretrain_network_g: ~
69 | strict_load_g: true
70 | resume_state: ~
71 |
72 | # training settings
73 | train:
74 | total_iter: 300000
75 | warmup_iter: -1 # no warm up
76 | use_grad_clip: true
77 |
78 | # Split 300k iterations into two cycles.
79 | # 1st cycle: fixed 3e-4 LR for 92k iters.
80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
81 | scheduler:
82 | type: CosineAnnealingRestartCyclicLR
83 | periods: [92000, 208000]
84 | restart_weights: [1,1]
85 | eta_mins: [0.0003,0.000001]
86 |
87 | mixing_augs:
88 | mixup: false
89 | mixup_beta: 1.2
90 | use_identity: true
91 |
92 | optim_g:
93 | type: AdamW
94 | lr: !!float 3e-4
95 | weight_decay: !!float 1e-4
96 | betas: [0.9, 0.999]
97 |
98 | # losses
99 | pixel_opt:
100 | type: L1Loss
101 | loss_weight: 1
102 | reduction: mean
103 |
104 | # validation settings
105 | val:
106 | window_size: 8
107 | val_freq: !!float 4e3
108 | save_img: false
109 | rgb2bgr: true
110 | use_image: true
111 | max_minibatch: 8
112 |
113 | metrics:
114 | psnr: # metric name, can be arbitrary
115 | type: calculate_psnr
116 | crop_border: 0
117 | test_y_channel: false
118 |
119 | # logging settings
120 | logger:
121 | print_freq: 1000
122 | save_checkpoint_freq: !!float 4e3
123 | use_tb_logger: true
124 | wandb:
125 | project: ~
126 | resume_id: ~
127 |
128 | # dist training settings
129 | dist_params:
130 | backend: nccl
131 | port: 29500
132 |
--------------------------------------------------------------------------------
/Deblurring/README.md:
--------------------------------------------------------------------------------
1 | ## Training
2 |
3 | 1. To download GoPro training and testing data, run
4 | ```
5 | python download_data.py --data train-test
6 | ```
7 |
8 | 2. Generate image patches from full-resolution training images of GoPro dataset
9 | ```
10 | python generate_patches_gopro.py
11 | ```
12 |
13 | 3. To train Restormer, run
14 | ```
15 | cd Restormer
16 | ./train.sh Deblurring/Options/Deblurring_Restormer.yml
17 | ```
18 |
19 | **Note:** The above training script uses 8 GPUs by default. To use a different number of GPUs, modify [Restormer/train.sh](../train.sh) and [Deblurring/Options/Deblurring_Restormer.yml](Options/Deblurring_Restormer.yml) accordingly.
20 |
21 | ## Evaluation
22 |
23 | Download the pre-trained [model](https://drive.google.com/drive/folders/1czMyfRTQDX3j3ErByYeZ1PM4GVLbJeGK?usp=sharing) and place it in `./pretrained_models/`
24 |
25 | #### Testing on GoPro dataset
26 |
27 | - Download GoPro testset, run
28 | ```
29 | python download_data.py --data test --dataset GoPro
30 | ```
31 |
32 | - Testing
33 | ```
34 | python test.py --dataset GoPro
35 | ```
36 |
37 | #### Testing on HIDE dataset
38 |
39 | - Download HIDE testset, run
40 | ```
41 | python download_data.py --data test --dataset HIDE
42 | ```
43 |
44 | - Testing
45 | ```
46 | python test.py --dataset HIDE
47 | ```
48 |
49 | #### Testing on RealBlur-J dataset
50 |
51 | - Download RealBlur-J testset, run
52 | ```
53 | python download_data.py --data test --dataset RealBlur_J
54 | ```
55 |
56 | - Testing
57 | ```
58 | python test.py --dataset RealBlur_J
59 | ```
60 |
61 | #### Testing on RealBlur-R dataset
62 |
63 | - Download RealBlur-R testset, run
64 | ```
65 | python download_data.py --data test --dataset RealBlur_R
66 | ```
67 |
68 | - Testing
69 | ```
70 | python test.py --dataset RealBlur_R
71 | ```
72 |
73 | #### To reproduce PSNR/SSIM scores of the paper (Table 2) on GoPro and HIDE datasets, run this MATLAB script
74 |
75 | ```
76 | evaluate_gopro_hide.m
77 | ```
78 |
79 | #### To reproduce PSNR/SSIM scores of the paper (Table 2) on RealBlur dataset, run
80 |
81 | ```
82 | evaluate_realblur.py
83 | ```
84 |
--------------------------------------------------------------------------------
/Deblurring/download_data.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | ## Download training and testing data for single-image motion deblurring task
6 | import os
7 | # import gdown
8 | import shutil
9 |
10 | import argparse
11 |
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, required=True, help='train, test or train-test')
parser.add_argument('--dataset', type=str, default='GoPro', help='all, GoPro, HIDE, RealBlur_R, RealBlur_J')
args = parser.parse_args()

### Google drive IDs ######
GoPro_train = '1zgALzrLCC_tcXKu_iHQTHukKUVT1aodI'      ## https://drive.google.com/file/d/1zgALzrLCC_tcXKu_iHQTHukKUVT1aodI/view?usp=sharing
GoPro_test = '1k6DTSHu4saUgrGTYkkZXTptILyG9RRll'       ## https://drive.google.com/file/d/1k6DTSHu4saUgrGTYkkZXTptILyG9RRll/view?usp=sharing
HIDE_test = '1XRomKYJF1H92g1EuD06pCQe4o6HlwB7A'        ## https://drive.google.com/file/d/1XRomKYJF1H92g1EuD06pCQe4o6HlwB7A/view?usp=sharing
RealBlurR_test = '1glgeWXCy7Y0qWDc0MXBTUlZYJf8984hS'   ## https://drive.google.com/file/d/1glgeWXCy7Y0qWDc0MXBTUlZYJf8984hS/view?usp=sharing
RealBlurJ_test = '1Rb1DhhXmX7IXfilQ-zL9aGjQfAAvQTrW'   ## https://drive.google.com/file/d/1Rb1DhhXmX7IXfilQ-zL9aGjQfAAvQTrW/view?usp=sharing

dataset = args.dataset

# One row per test set: (value accepted by --dataset, Drive file ID, banner printed before download).
# The banners are kept verbatim from the previous per-dataset branches.
TEST_SETS = [
    ('GoPro', GoPro_test, 'GoPro Testing Data!'),
    ('HIDE', HIDE_test, 'HIDE Testing Data!'),
    ('RealBlur_R', RealBlurR_test, 'RealBlur_R Testing Data!'),
    ('RealBlur_J', RealBlurJ_test, 'RealBlur_J testing Data!'),
]


def _gdrive_fetch(file_id, dest_dir):
    """Download one Google Drive file into ``dest_dir`` using the ``gdrive`` CLI.

    Aborts the script with a clear message when the command exits with a
    non-zero status, instead of letting the later unpack step fail on a
    missing archive.
    """
    os.makedirs(dest_dir, exist_ok=True)
    status = os.system(f'gdrive download {file_id} --path {dest_dir}')
    if status != 0:
        raise SystemExit(f'gdrive download failed for {file_id} (exit status {status})')


def _unpack_and_remove(archive, extract_dir):
    """Extract ``archive`` into ``extract_dir`` and delete the archive afterwards."""
    shutil.unpack_archive(archive, extract_dir)
    os.remove(archive)


for data in args.data.split('-'):
    if data == 'train':
        print('GoPro Training Data!')
        # gdown.download(id=GoPro_train, output='Datasets/Downloads/train.zip', quiet=False)
        _gdrive_fetch(GoPro_train, 'Datasets/Downloads/')
        print('Extracting GoPro data...')
        # The script assumes the downloaded file is named train.zip and unpacks to a 'train' folder.
        _unpack_and_remove('Datasets/Downloads/train.zip', 'Datasets/Downloads')
        os.rename(os.path.join('Datasets', 'Downloads', 'train'), os.path.join('Datasets', 'Downloads', 'GoPro'))

    if data == 'test':
        for name, file_id, banner in TEST_SETS:
            if dataset == 'all' or dataset == name:
                print(banner)
                # gdown.download(id=file_id, output='Datasets/test.zip', quiet=False)
                _gdrive_fetch(file_id, 'Datasets/')
                print(f'Extracting {name} Data...')
                # Every test archive is expected to arrive as Datasets/test.zip.
                _unpack_and_remove('Datasets/test.zip', 'Datasets')


# print('Download completed successfully!')
72 |
--------------------------------------------------------------------------------
/Deblurring/evaluate_gopro_hide.m:
--------------------------------------------------------------------------------
1 | %% Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | %% Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | %% https://arxiv.org/abs/2111.09881
4 |
% Computes the average PSNR and SSIM between restored images in
% ./results/<dataset>/ and ground-truth images in
% ./Datasets/test/<dataset>/target/ for each dataset listed below.
close all;clear all;

% datasets = {'GoPro'};
datasets = {'GoPro', 'HIDE'};
num_set = length(datasets);

tic
delete(gcp('nocreate'))
parpool('local',20);

for idx_set = 1:num_set
    file_path = strcat('./results/', datasets{idx_set}, '/');
    gt_path = strcat('./Datasets/test/', datasets{idx_set}, '/target/');
    path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))];
    gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))];
    img_num = length(path_list);

    % Nothing to evaluate: skip instead of printing NaN from a 0/0 division.
    if img_num == 0
        warning('No result images found in %s; skipping %s.', file_path, datasets{idx_set});
        continue
    end

    % Results and targets are paired by list position, so the two listings
    % must have the same length; otherwise images would be mispaired or the
    % index would run out of bounds.
    if length(gt_list) ~= img_num
        error('Found %d result images in %s but %d target images in %s.', ...
              img_num, file_path, length(gt_list), gt_path);
    end

    total_psnr = 0;
    total_ssim = 0;
    % total_psnr / total_ssim are accumulated across parfor iterations.
    parfor j = 1:img_num
        image_name = path_list(j).name;
        gt_name = gt_list(j).name;
        input = imread(strcat(file_path,image_name));
        gt = imread(strcat(gt_path, gt_name));
        ssim_val = ssim(input, gt);
        psnr_val = psnr(input, gt);
        total_ssim = total_ssim + ssim_val;
        total_psnr = total_psnr + psnr_val;
    end
    qm_psnr = total_psnr / img_num;
    qm_ssim = total_ssim / img_num;

    fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim);

end
delete(gcp('nocreate'))
toc
44 |
--------------------------------------------------------------------------------
/Deblurring/evaluate_realblur.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import os
6 | import numpy as np
7 | from glob import glob
8 | from natsort import natsorted
9 | from skimage import io
10 | import cv2
11 | from skimage.metrics import structural_similarity
12 | from tqdm import tqdm
13 | import concurrent.futures
14 |
def image_align(deblurred, gt):
    """Align a deblurred image to its ground truth before computing metrics.

    This function is based on the Kohler evaluation code.

    Steps:
      1. Scale ``deblurred`` by ``sum(gt * deblurred) / sum(deblurred ** 2)``
         (simple intensity matching).
      2. Estimate a homography between the grayscale versions of ``gt`` and
         the scaled image with OpenCV's ECC algorithm (100 iterations,
         epsilon 0, Gaussian filter size 5).
      3. Warp the scaled image and an all-ones image with that homography;
         the warped ones act as a validity mask (0 outside the warped area).

    Args:
        deblurred: predicted image; converted with ``cv2.COLOR_RGB2GRAY``, so
            presumably an H x W x 3 float32 RGB array (``proc`` passes such
            arrays) -- confirm for other callers.
        gt: ground-truth image, same layout as ``deblurred``.

    Returns:
        Tuple ``(zr, xr, cr, shift)``: the warped prediction multiplied by the
        mask, the ground truth multiplied by the mask, the mask itself, and
        the estimated 3x3 warp matrix.
    """
    z = deblurred
    x = gt

    zs = (np.sum(x * z) / np.sum(z * z)) * z  # simple intensity matching

    warp_mode = cv2.MOTION_HOMOGRAPHY
    warp_matrix = np.eye(3, 3, dtype=np.float32)

    # Specify the number of iterations.
    number_of_iterations = 100

    # With epsilon 0 the iteration count is effectively the stopping rule.
    termination_eps = 0

    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
                number_of_iterations, termination_eps)

    # Run the ECC algorithm. The results are stored in warp_matrix.
    (cc, warp_matrix) = cv2.findTransformECC(
        cv2.cvtColor(x, cv2.COLOR_RGB2GRAY),
        cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
        warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)

    target_shape = x.shape
    shift = warp_matrix

    # Warp the intensity-matched prediction onto the ground-truth grid.
    zr = cv2.warpPerspective(
        zs,
        warp_matrix,
        (target_shape[1], target_shape[0]),
        flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_REFLECT)

    # Warp an all-ones image the same way; zeros mark pixels with no source data.
    cr = cv2.warpPerspective(
        np.ones_like(zs, dtype='float32'),
        warp_matrix,
        (target_shape[1], target_shape[0]),
        flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0)

    # Zero out invalid pixels in both images so they do not contribute to the error.
    zr = zr * cr
    xr = x * cr

    return zr, xr, cr, shift
59 |
def compute_psnr(image_true, image_test, image_mask, data_range=None):
    """Masked PSNR, based on skimage.metrics.peak_signal_noise_ratio.

    The squared error is summed over the whole image and divided by the sum
    of ``image_mask`` rather than by the pixel count.

    Args:
        image_true: reference image array.
        image_test: image array to compare against the reference.
        image_mask: array whose sum is used as the normaliser.
        data_range: peak signal value. ``None`` now falls back to 1.0, the
            range ``proc`` uses after scaling images by 1/255; previously
            ``None`` raised a TypeError on ``None ** 2``.

    Returns:
        PSNR in decibels (``inf`` with a NumPy warning when the error is 0).
    """
    if data_range is None:
        data_range = 1.0
    err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
    return 10 * np.log10((data_range ** 2) / err)
64 |
65 |
def compute_ssim(tar_img, prd_img, cr1):
    """Masked SSIM between a target and a predicted image.

    The full SSIM map from ``structural_similarity`` is multiplied by the
    mask ``cr1``, cropped by the Gaussian window radius on each side, summed
    per channel, normalised by the per-channel mask sum and averaged over
    channels.

    Args:
        tar_img: target image; indexed as [row, col, channel] below.
        prd_img: predicted image with the same layout.
        cr1: mask with the same layout (``proc`` passes the mask returned by
            ``image_align``).

    Returns:
        Mean masked SSIM over channels.
    """
    import inspect

    # scikit-image 0.19 replaced `multichannel=True` with `channel_axis`, and
    # later releases dropped the old keyword. Choose whichever the installed
    # version understands so the last axis is treated as channels either way.
    if 'channel_axis' in inspect.signature(structural_similarity).parameters:
        channel_kwargs = {'channel_axis': -1}
    else:
        channel_kwargs = {'multichannel': True}

    _, ssim_map = structural_similarity(
        tar_img, prd_img,
        gaussian_weights=True, use_sample_covariance=False,
        data_range=1.0, full=True, **channel_kwargs)
    ssim_map = ssim_map * cr1

    r = int(3.5 * 1.5 + 0.5)  # radius as in ndimage
    win_size = 2 * r + 1
    pad = (win_size - 1) // 2

    # Drop the border where the window does not fully fit.
    ssim = ssim_map[pad:-pad, pad:-pad, :]
    crop_cr1 = cr1[pad:-pad, pad:-pad, :]
    ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
    ssim = np.mean(ssim)
    return ssim
77 |
def proc(filename):
    """Score one (target_path, prediction_path) pair.

    Both images are read, scaled to float32 by 1/255, aligned with
    ``image_align`` and then scored.

    Returns:
        Tuple ``(PSNR, SSIM)`` for the pair.
    """
    target_path, pred_path = filename
    target = io.imread(target_path)
    pred = io.imread(pred_path)

    target = target.astype(np.float32) / 255.0
    pred = pred.astype(np.float32) / 255.0

    # image_align returns (aligned prediction, masked target, mask, warp matrix).
    aligned_pred, masked_target, mask, _ = image_align(pred, target)

    psnr_value = compute_psnr(masked_target, aligned_pred, mask, data_range=1)
    ssim_value = compute_ssim(masked_target, aligned_pred, mask)
    return (psnr_value, ssim_value)
91 |
datasets = ['RealBlur_J', 'RealBlur_R']


def main():
    """Print the average PSNR/SSIM for every dataset in ``datasets``.

    Predictions are read from ``results/<dataset>`` and targets from
    ``Datasets/test/<dataset>/target``. Files are paired by their
    natural-sorted position, so both folders must hold the same number of
    images.
    """
    for dataset in datasets:

        file_path = os.path.join('results', dataset)
        gt_path = os.path.join('Datasets', 'test', dataset, 'target')

        path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg')))
        gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg')))

        assert len(path_list) != 0, "Predicted files not found"
        assert len(gt_list) != 0, "Target files not found"
        # zip() would silently drop the unmatched tail and mispair the rest.
        assert len(path_list) == len(gt_list), "Number of predicted and target files differ"

        psnr, ssim = [], []
        img_files = list(zip(gt_list, path_list))
        with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
            for PSNR_SSIM in executor.map(proc, img_files):
                psnr.append(PSNR_SSIM[0])
                ssim.append(PSNR_SSIM[1])

        avg_psnr = sum(psnr) / len(psnr)
        avg_ssim = sum(ssim) / len(ssim)

        print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))


# Worker processes re-import this module under the "spawn" start method
# (Windows, macOS default); the guard stops them from re-running the
# evaluation and creating pools of their own.
if __name__ == '__main__':
    main()
116 |
--------------------------------------------------------------------------------
/Deblurring/generate_patches_gopro.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | ##### Data preparation file for training Restormer on the GoPro Dataset ########
6 |
7 | import cv2
8 | import numpy as np
9 | from glob import glob
10 | from natsort import natsorted
11 | import os
12 | from tqdm import tqdm
13 | from pdb import set_trace as stx
14 | from joblib import Parallel, delayed
15 | import multiprocessing
16 |
def train_files(file_):
    """Cut one (input, target) training pair into overlapping square patches.

    Reads the module-level settings ``patch_size``, ``overlap``, ``p_max``,
    ``lr_tar`` and ``hr_tar``. When both image sides exceed ``p_max`` the pair
    is tiled with stride ``patch_size - overlap``, plus one extra tile flush
    with the far edge along each axis, and each tile is written as
    ``<name>-<k>.png`` (k counts from 1). Otherwise the whole pair is written
    unchanged as ``<name>.png``.

    Args:
        file_: tuple ``(lr_file, hr_file)`` of image paths.
    """
    lr_file, hr_file = file_
    filename = os.path.splitext(os.path.split(lr_file)[-1])[0]
    lr_img = cv2.imread(lr_file)
    hr_img = cv2.imread(hr_file)
    num_patch = 0
    # shape[:2] is (rows, cols); the old names `w, h` were swapped but used consistently.
    rows, cols = lr_img.shape[:2]
    if rows > p_max and cols > p_max:
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the drop-in replacement.
        row_starts = list(np.arange(0, rows - patch_size, patch_size - overlap, dtype=int))
        col_starts = list(np.arange(0, cols - patch_size, patch_size - overlap, dtype=int))
        # Final tile aligned to the bottom / right border.
        # NOTE(review): this index is negative if an image is smaller than patch_size -- confirm inputs.
        row_starts.append(rows - patch_size)
        col_starts.append(cols - patch_size)
        for i in row_starts:
            for j in col_starts:
                num_patch += 1

                lr_patch = lr_img[i:i+patch_size, j:j+patch_size, :]
                hr_patch = hr_img[i:i+patch_size, j:j+patch_size, :]

                lr_savename = os.path.join(lr_tar, filename + '-' + str(num_patch) + '.png')
                hr_savename = os.path.join(hr_tar, filename + '-' + str(num_patch) + '.png')

                cv2.imwrite(lr_savename, lr_patch)
                cv2.imwrite(hr_savename, hr_patch)

    else:
        lr_savename = os.path.join(lr_tar, filename + '.png')
        hr_savename = os.path.join(hr_tar, filename + '.png')

        cv2.imwrite(lr_savename, lr_img)
        cv2.imwrite(hr_savename, hr_img)
48 |
def val_files(file_):
    """Write the centre crop of one (input, target) validation pair.

    A ``val_patch_size`` x ``val_patch_size`` window centred on the input
    image is taken from both images and saved as ``<name>.png`` under the
    module-level ``lr_tar`` / ``hr_tar`` folders.

    Args:
        file_: tuple ``(lr_file, hr_file)`` of image paths.
    """
    lr_file, hr_file = file_
    stem = os.path.splitext(os.path.split(lr_file)[-1])[0]
    lr_img = cv2.imread(lr_file)
    hr_img = cv2.imread(hr_file)

    out_name = stem + '.png'
    lr_savename = os.path.join(lr_tar, out_name)
    hr_savename = os.path.join(hr_tar, out_name)

    dim0, dim1 = lr_img.shape[:2]

    # Top-left corner of the centred window.
    top = (dim0 - val_patch_size) // 2
    left = (dim1 - val_patch_size) // 2
    bottom = top + val_patch_size
    right = left + val_patch_size

    cv2.imwrite(lr_savename, lr_img[top:bottom, left:right, :])
    cv2.imwrite(hr_savename, hr_img[top:bottom, left:right, :])
68 |
############ Prepare Training data ####################
# train_files() reads patch_size, overlap, p_max, lr_tar and hr_tar as module
# globals, so they must be bound before the Parallel call below.
num_cores = 10     # number of joblib workers
patch_size = 512   # side length of each square training patch
overlap = 256      # tiling stride is patch_size - overlap
p_max = 0          # images with both sides above this are tiled; 0 means tile every image

src = 'Datasets/Downloads/GoPro'
tar = 'Datasets/train/GoPro'

lr_tar = os.path.join(tar, 'input_crops')
hr_tar = os.path.join(tar, 'target_crops')

os.makedirs(lr_tar, exist_ok=True)
os.makedirs(hr_tar, exist_ok=True)

# Inputs and targets are paired by natural-sorted position in their folders.
lr_files = natsorted(glob(os.path.join(src, 'input', '*.png')) + glob(os.path.join(src, 'input', '*.jpg')))
hr_files = natsorted(glob(os.path.join(src, 'target', '*.png')) + glob(os.path.join(src, 'target', '*.jpg')))

files = [(i, j) for i, j in zip(lr_files, hr_files)]

Parallel(n_jobs=num_cores)(delayed(train_files)(file_) for file_ in tqdm(files))


############ Prepare validation data ####################
# src, tar, lr_tar and hr_tar are rebound here for the validation pass;
# val_files() reads val_patch_size, lr_tar and hr_tar as module globals.
val_patch_size = 256   # side length of the centre crop taken from each test image
src = 'Datasets/test/GoPro'
tar = 'Datasets/val/GoPro'

lr_tar = os.path.join(tar, 'input_crops')
hr_tar = os.path.join(tar, 'target_crops')

os.makedirs(lr_tar, exist_ok=True)
os.makedirs(hr_tar, exist_ok=True)

lr_files = natsorted(glob(os.path.join(src, 'input', '*.png')) + glob(os.path.join(src, 'input', '*.jpg')))
hr_files = natsorted(glob(os.path.join(src, 'target', '*.png')) + glob(os.path.join(src, 'target', '*.jpg')))

files = [(i, j) for i, j in zip(lr_files, hr_files)]

Parallel(n_jobs=num_cores)(delayed(val_files)(file_) for file_ in tqdm(files))
109 |
--------------------------------------------------------------------------------
/Deblurring/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | The pre-trained deblurring model is available [here](https://drive.google.com/drive/folders/1czMyfRTQDX3j3ErByYeZ1PM4GVLbJeGK?usp=sharing).
--------------------------------------------------------------------------------
/Deblurring/test.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 |
6 | import numpy as np
7 | import os
8 | import argparse
9 | from tqdm import tqdm
10 |
11 | import torch.nn as nn
12 | import torch
13 | import torch.nn.functional as F
14 | import utils
15 |
16 | from natsort import natsorted
17 | from glob import glob
18 | from basicsr.models.archs.restormer_arch import Restormer
19 | from skimage import img_as_ubyte
20 | from pdb import set_trace as stx
21 |
parser = argparse.ArgumentParser(description='Single Image Motion Deblurring using Restormer')

parser.add_argument('--input_dir', default='./Datasets/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/motion_deblurring.pth', type=str, help='Path to weights')
parser.add_argument('--dataset', default='GoPro', type=str, help='Test Dataset') # ['GoPro', 'HIDE', 'RealBlur_J', 'RealBlur_R']

args = parser.parse_args()

####### Load yaml #######
yaml_file = 'Options/Deblurring_Restormer.yml'
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

# NOTE(review): the full (non-safe) Loader can build arbitrary Python objects.
# The file parsed here is the option file named above; do not point this at
# untrusted YAML without switching to a safe loader.
# The context manager closes the file handle, which was previously leaked.
with open(yaml_file, mode='r') as yaml_stream:
    x = yaml.load(yaml_stream, Loader=Loader)

# 'type' is removed so the remaining keys can be passed straight to the Restormer constructor.
s = x['network_g'].pop('type')
##########################

model_restoration = Restormer(**x['network_g'])

checkpoint = torch.load(args.weights)
model_restoration.load_state_dict(checkpoint['params'])
print("===>Testing using weights: ",args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()


factor = 8
dataset = args.dataset
result_dir = os.path.join(args.result_dir, dataset)
os.makedirs(result_dir, exist_ok=True)

inp_dir = os.path.join(args.input_dir, 'test', dataset, 'input')
files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
with torch.no_grad():
    for file_ in tqdm(files):
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()

        img = np.float32(utils.load_img(file_))/255.
        img = torch.from_numpy(img).permute(2,0,1)
        input_ = img.unsqueeze(0).cuda()

        # Padding in case images are not multiples of 8
        h,w = input_.shape[2], input_.shape[3]
        # Distance to the next multiple of `factor`; 0 when already aligned.
        padh = (factor - h % factor) % factor
        padw = (factor - w % factor) % factor
        input_ = F.pad(input_, (0,padw,0,padh), 'reflect')

        restored = model_restoration(input_)

        # Unpad images to original dimensions
        restored = restored[:,:,:h,:w]

        restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()

        out_name = os.path.splitext(os.path.split(file_)[-1])[0] + '.png'
        utils.save_img(os.path.join(result_dir, out_name), img_as_ubyte(restored))
86 |
--------------------------------------------------------------------------------
/Deblurring/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
def calculate_psnr(img1, img2, border=0):
    """PSNR between two images with values in [0, 255].

    Args:
        img1, img2: arrays of identical shape.
        border: number of pixels trimmed from every side of the first two
            axes before comparing.

    Returns:
        PSNR in dB, or ``inf`` when the cropped images are identical.

    Raises:
        ValueError: if the two shapes differ.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    rows, cols = img1.shape[:2]
    window = (slice(border, rows - border), slice(border, cols - border))

    first = img1[window].astype(np.float64)
    second = img2[window].astype(np.float64)

    mse = np.mean((first - second)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    Args:
        img1, img2: arrays of identical shape; either 2-D, or 3-D with 1 or 3
            channels on the last axis.
        border: number of pixels trimmed from every side of the first two
            axes before comparing.

    Returns:
        SSIM of the pair; for 3-channel input, the mean of the per-channel
        SSIM values.

    Raises:
        ValueError: if the shapes differ, or the input is not 2-D / 3-D with
            1 or 3 channels. A 3-D input with another channel count used to
            return None silently.
    '''
    #img1 = img1.squeeze()
    #img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            # Score each channel independently and average.
            return np.array([ssim(img1[:, :, i], img2[:, :, i]) for i in range(3)]).mean()
        if img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    raise ValueError('Wrong input image dimensions.')
56 |
57 |
def ssim(img1, img2):
    """Mean SSIM of two single-channel images with values in [0, 255].

    Local statistics use an 11x11 Gaussian window (sigma 1.5). Five pixels
    are trimmed from every side of each filtered map so that only positions
    where the window fits entirely inside the image are kept.
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    def _local_mean(arr):
        # Gaussian-weighted local average, cropped to the "valid" region.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu1 = _local_mean(a)
    mu2 = _local_mean(b)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _local_mean(a**2) - mu1_sq
    sigma2_sq = _local_mean(b**2) - mu2_sq
    sigma12 = _local_mean(a * b) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator
    return ssim_map.mean()
79 |
def load_img(filepath):
    """Read an image with OpenCV and return it converted from BGR to RGB.

    Raises:
        IOError: if OpenCV cannot read the file. ``cv2.imread`` returns None
            in that case, which previously surfaced as an opaque error from
            ``cv2.cvtColor``.
    """
    bgr = cv2.imread(filepath)
    if bgr is None:
        raise IOError('Failed to read image: {}'.format(filepath))
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
82 |
def save_img(filepath, img):
    """Write ``img`` to ``filepath`` after converting it from RGB to BGR."""
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filepath, bgr)
85 |
def load_gray_img(filepath):
    """Read an image as grayscale and append a trailing axis of length 1.

    Raises:
        IOError: if OpenCV cannot read the file. ``cv2.imread`` returns None
            in that case, which previously surfaced as an unrelated NumPy
            error from ``np.expand_dims``.
    """
    gray = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise IOError('Failed to read image: {}'.format(filepath))
    return np.expand_dims(gray, axis=2)
88 |
def save_gray_img(filepath, img):
    """Write ``img`` to ``filepath`` as-is, with no colour conversion."""
    cv2.imwrite(filename=filepath, img=img)
91 |
--------------------------------------------------------------------------------
/Denoising/Datasets/README.md:
--------------------------------------------------------------------------------
1 | For training and testing, your directory structure should look like this
2 |
3 |
4 | `Datasets`
5 | `├──train`
6 | `├──DFWB`
7 | `└──SIDD`
8 | `├──input_crops`
9 | `└──target_crops`
10 | `├──val`
11 | `└──SIDD`
12 | `├──input_crops`
13 | `└──target_crops`
14 | `└──test`
15 | `├──BSD68`
16 | `├──CBSD68`
17 | `├──Kodak`
18 | `├──McMaster`
19 | `├──Set12`
20 | `├──Urban100`
21 | `├──SIDD`
22 | `├──ValidationNoisyBlocksSrgb.mat`
23 | `└──ValidationGtBlocksSrgb.mat`
24 | `├──DND`
25 | `├──info.mat`
26 | `└──images_srgb`
27 | `├──0001.mat`
28 | `├──0002.mat`
29 | `├── ... `
30 | `└──0050.mat`
31 |
--------------------------------------------------------------------------------
/Denoising/Options/GaussianColorDenoising_Restormer.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: GaussianColorDenoising_Restormer
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_GaussianDenoising
13 | sigma_type: random
14 | sigma_range: [0,50]
15 | in_ch: 3 ## RGB image
16 | dataroot_gt: ./Denoising/Datasets/train/DFWB
17 | dataroot_lq: none
18 | geometric_augs: true
19 |
20 | filename_tmpl: '{}'
21 | io_backend:
22 | type: disk
23 |
24 | # data loader
25 | use_shuffle: true
26 | num_worker_per_gpu: 8
27 | batch_size_per_gpu: 8
28 |
29 | ### -------------Progressive training--------------------------
30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
31 | iters: [92000,64000,48000,36000,36000,24000]
32 | gt_size: 384 # Max patch size for progressive training
33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
34 | ### ------------------------------------------------------------
35 |
36 | ### ------- Training on single fixed-patch size 128x128---------
37 | # mini_batch_sizes: [8]
38 | # iters: [300000]
39 | # gt_size: 128
40 | # gt_sizes: [128]
41 | ### ------------------------------------------------------------
42 |
43 | dataset_enlarge_ratio: 1
44 | prefetch_mode: ~
45 |
46 | val:
47 | name: ValSet
48 | type: Dataset_GaussianDenoising
49 | sigma_test: 25
50 | in_ch: 3 ## RGB image
51 | dataroot_gt: ./Denoising/Datasets/test/CBSD68
52 | dataroot_lq: none
53 | io_backend:
54 | type: disk
55 |
56 | # network structures
57 | network_g:
58 | type: Restormer
59 | inp_channels: 3
60 | out_channels: 3
61 | dim: 48
62 | num_blocks: [4,6,6,8]
63 | num_refinement_blocks: 4
64 | heads: [1,2,4,8]
65 | ffn_expansion_factor: 2.66
66 | bias: False
67 | LayerNorm_type: BiasFree
68 | dual_pixel_task: False
69 |
70 | # path
71 | path:
72 | pretrain_network_g: ~
73 | strict_load_g: true
74 | resume_state: ~
75 |
76 | # training settings
77 | train:
78 | total_iter: 300000
79 | warmup_iter: -1 # no warm up
80 | use_grad_clip: true
81 |
82 | # Split 300k iterations into two cycles.
83 | # 1st cycle: fixed 3e-4 LR for 92k iters.
84 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
85 | scheduler:
86 | type: CosineAnnealingRestartCyclicLR
87 | periods: [92000, 208000]
88 | restart_weights: [1,1]
89 | eta_mins: [0.0003,0.000001]
90 |
91 | mixing_augs:
92 | mixup: true
93 | mixup_beta: 1.2
94 | use_identity: true
95 |
96 | optim_g:
97 | type: AdamW
98 | lr: !!float 3e-4
99 | weight_decay: !!float 1e-4
100 | betas: [0.9, 0.999]
101 |
102 | # losses
103 | pixel_opt:
104 | type: L1Loss
105 | loss_weight: 1
106 | reduction: mean
107 |
108 | # validation settings
109 | val:
110 | window_size: 8
111 | val_freq: !!float 4e3
112 | save_img: false
113 | rgb2bgr: true
114 | use_image: false
115 | max_minibatch: 8
116 |
117 | metrics:
118 | psnr: # metric name, can be arbitrary
119 | type: calculate_psnr
120 | crop_border: 0
121 | test_y_channel: false
122 |
123 | # logging settings
124 | logger:
125 | print_freq: 1000
126 | save_checkpoint_freq: !!float 4e3
127 | use_tb_logger: true
128 | wandb:
129 | project: ~
130 | resume_id: ~
131 |
132 | # dist training settings
133 | dist_params:
134 | backend: nccl
135 | port: 29500
136 |
--------------------------------------------------------------------------------
/Denoising/Options/GaussianColorDenoising_RestormerSigma15.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: GaussianColorDenoising_RestormerSigma15
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_GaussianDenoising
13 | sigma_type: constant
14 | sigma_range: 15
15 | in_ch: 3 ## RGB image
16 | dataroot_gt: ./Denoising/Datasets/train/DFWB
17 | dataroot_lq: none
18 | geometric_augs: true
19 |
20 | filename_tmpl: '{}'
21 | io_backend:
22 | type: disk
23 |
24 | # data loader
25 | use_shuffle: true
26 | num_worker_per_gpu: 8
27 | batch_size_per_gpu: 8
28 |
29 | ### -------------Progressive training--------------------------
30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
31 | iters: [92000,64000,48000,36000,36000,24000]
32 | gt_size: 384 # Max patch size for progressive training
33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
34 | ### ------------------------------------------------------------
35 |
36 | ### ------- Training on single fixed-patch size 128x128---------
37 | # mini_batch_sizes: [8]
38 | # iters: [300000]
39 | # gt_size: 128
40 | # gt_sizes: [128]
41 | ### ------------------------------------------------------------
42 |
43 | dataset_enlarge_ratio: 1
44 | prefetch_mode: ~
45 |
46 | val:
47 | name: ValSet
48 | type: Dataset_GaussianDenoising
49 | sigma_test: 15
50 | in_ch: 3 ## RGB image
51 | dataroot_gt: ./Denoising/Datasets/test/CBSD68
52 | dataroot_lq: none
53 | io_backend:
54 | type: disk
55 |
56 | # network structures
57 | network_g:
58 | type: Restormer
59 | inp_channels: 3
60 | out_channels: 3
61 | dim: 48
62 | num_blocks: [4,6,6,8]
63 | num_refinement_blocks: 4
64 | heads: [1,2,4,8]
65 | ffn_expansion_factor: 2.66
66 | bias: False
67 | LayerNorm_type: BiasFree
68 | dual_pixel_task: False
69 |
70 |
71 | # path
72 | path:
73 | pretrain_network_g: ~
74 | strict_load_g: true
75 | resume_state: ~
76 |
77 | # training settings
78 | train:
79 | total_iter: 300000
80 | warmup_iter: -1 # no warm up
81 | use_grad_clip: true
82 |
83 | # Split 300k iterations into two cycles.
84 | # 1st cycle: fixed 3e-4 LR for 92k iters.
85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
86 | scheduler:
87 | type: CosineAnnealingRestartCyclicLR
88 | periods: [92000, 208000]
89 | restart_weights: [1,1]
90 | eta_mins: [0.0003,0.000001]
91 |
92 | mixing_augs:
93 | mixup: true
94 | mixup_beta: 1.2
95 | use_identity: true
96 |
97 | optim_g:
98 | type: AdamW
99 | lr: !!float 3e-4
100 | weight_decay: !!float 1e-4
101 | betas: [0.9, 0.999]
102 |
103 | # losses
104 | pixel_opt:
105 | type: L1Loss
106 | loss_weight: 1
107 | reduction: mean
108 |
109 | # validation settings
110 | val:
111 | window_size: 8
112 | val_freq: !!float 4e3
113 | save_img: false
114 | rgb2bgr: true
115 | use_image: false
116 | max_minibatch: 8
117 |
118 | metrics:
119 | psnr: # metric name, can be arbitrary
120 | type: calculate_psnr
121 | crop_border: 0
122 | test_y_channel: false
123 |
124 | # logging settings
125 | logger:
126 | print_freq: 1000
127 | save_checkpoint_freq: !!float 4e3
128 | use_tb_logger: true
129 | wandb:
130 | project: ~
131 | resume_id: ~
132 |
133 | # dist training settings
134 | dist_params:
135 | backend: nccl
136 | port: 29500
137 |
--------------------------------------------------------------------------------
/Denoising/Options/GaussianColorDenoising_RestormerSigma25.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: GaussianColorDenoising_RestormerSigma25
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_GaussianDenoising
13 | sigma_type: constant
14 | sigma_range: 25
15 | in_ch: 3 ## RGB image
16 | dataroot_gt: ./Denoising/Datasets/train/DFWB
17 | dataroot_lq: none
18 | geometric_augs: true
19 |
20 | filename_tmpl: '{}'
21 | io_backend:
22 | type: disk
23 |
24 | # data loader
25 | use_shuffle: true
26 | num_worker_per_gpu: 8
27 | batch_size_per_gpu: 8
28 |
29 | ### -------------Progressive training--------------------------
30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
31 | iters: [92000,64000,48000,36000,36000,24000]
32 | gt_size: 384 # Max patch size for progressive training
33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
34 | ### ------------------------------------------------------------
35 |
36 | ### ------- Training on single fixed-patch size 128x128---------
37 | # mini_batch_sizes: [8]
38 | # iters: [300000]
39 | # gt_size: 128
40 | # gt_sizes: [128]
41 | ### ------------------------------------------------------------
42 |
43 | dataset_enlarge_ratio: 1
44 | prefetch_mode: ~
45 |
46 | val:
47 | name: ValSet
48 | type: Dataset_GaussianDenoising
49 | sigma_test: 25
50 | in_ch: 3 ## RGB image
51 | dataroot_gt: ./Denoising/Datasets/test/CBSD68
52 | dataroot_lq: none
53 | io_backend:
54 | type: disk
55 |
56 | # network structures
57 | network_g:
58 | type: Restormer
59 | inp_channels: 3
60 | out_channels: 3
61 | dim: 48
62 | num_blocks: [4,6,6,8]
63 | num_refinement_blocks: 4
64 | heads: [1,2,4,8]
65 | ffn_expansion_factor: 2.66
66 | bias: False
67 | LayerNorm_type: BiasFree
68 | dual_pixel_task: False
69 |
70 |
71 | # path
72 | path:
73 | pretrain_network_g: ~
74 | strict_load_g: true
75 | resume_state: ~
76 |
77 | # training settings
78 | train:
79 | total_iter: 300000
80 | warmup_iter: -1 # no warm up
81 | use_grad_clip: true
82 |
83 | # Split 300k iterations into two cycles.
84 | # 1st cycle: fixed 3e-4 LR for 92k iters.
85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
86 | scheduler:
87 | type: CosineAnnealingRestartCyclicLR
88 | periods: [92000, 208000]
89 | restart_weights: [1,1]
90 | eta_mins: [0.0003,0.000001]
91 |
92 | mixing_augs:
93 | mixup: true
94 | mixup_beta: 1.2
95 | use_identity: true
96 |
97 | optim_g:
98 | type: AdamW
99 | lr: !!float 3e-4
100 | weight_decay: !!float 1e-4
101 | betas: [0.9, 0.999]
102 |
103 | # losses
104 | pixel_opt:
105 | type: L1Loss
106 | loss_weight: 1
107 | reduction: mean
108 |
109 | # validation settings
110 | val:
111 | window_size: 8
112 | val_freq: !!float 4e3
113 | save_img: false
114 | rgb2bgr: true
115 | use_image: false
116 | max_minibatch: 8
117 |
118 | metrics:
119 | psnr: # metric name, can be arbitrary
120 | type: calculate_psnr
121 | crop_border: 0
122 | test_y_channel: false
123 |
124 | # logging settings
125 | logger:
126 | print_freq: 1000
127 | save_checkpoint_freq: !!float 4e3
128 | use_tb_logger: true
129 | wandb:
130 | project: ~
131 | resume_id: ~
132 |
133 | # dist training settings
134 | dist_params:
135 | backend: nccl
136 | port: 29500
137 |
--------------------------------------------------------------------------------
/Denoising/Options/GaussianColorDenoising_RestormerSigma50.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: GaussianColorDenoising_RestormerSigma50
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_GaussianDenoising
13 | sigma_type: constant
14 | sigma_range: 50
15 | in_ch: 3 ## RGB image
16 | dataroot_gt: ./Denoising/Datasets/train/DFWB
17 | dataroot_lq: none
18 | geometric_augs: true
19 |
20 | filename_tmpl: '{}'
21 | io_backend:
22 | type: disk
23 |
24 | # data loader
25 | use_shuffle: true
26 | num_worker_per_gpu: 8
27 | batch_size_per_gpu: 8
28 |
29 | ### -------------Progressive training--------------------------
30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
31 | iters: [92000,64000,48000,36000,36000,24000]
32 | gt_size: 384 # Max patch size for progressive training
33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
34 | ### ------------------------------------------------------------
35 |
36 | ### ------- Training on single fixed-patch size 128x128---------
37 | # mini_batch_sizes: [8]
38 | # iters: [300000]
39 | # gt_size: 128
40 | # gt_sizes: [128]
41 | ### ------------------------------------------------------------
42 |
43 | dataset_enlarge_ratio: 1
44 | prefetch_mode: ~
45 |
46 | val:
47 | name: ValSet
48 | type: Dataset_GaussianDenoising
49 | sigma_test: 50
50 | in_ch: 3 ## RGB image
51 | dataroot_gt: ./Denoising/Datasets/test/CBSD68
52 | dataroot_lq: none
53 | io_backend:
54 | type: disk
55 |
56 | # network structures
57 | network_g:
58 | type: Restormer
59 | inp_channels: 3
60 | out_channels: 3
61 | dim: 48
62 | num_blocks: [4,6,6,8]
63 | num_refinement_blocks: 4
64 | heads: [1,2,4,8]
65 | ffn_expansion_factor: 2.66
66 | bias: False
67 | LayerNorm_type: BiasFree
68 | dual_pixel_task: False
69 |
70 |
71 | # path
72 | path:
73 | pretrain_network_g: ~
74 | strict_load_g: true
75 | resume_state: ~
76 |
77 | # training settings
78 | train:
79 | total_iter: 300000
80 | warmup_iter: -1 # no warm up
81 | use_grad_clip: true
82 |
83 | # Split 300k iterations into two cycles.
84 | # 1st cycle: fixed 3e-4 LR for 92k iters.
85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
86 | scheduler:
87 | type: CosineAnnealingRestartCyclicLR
88 | periods: [92000, 208000]
89 | restart_weights: [1,1]
90 | eta_mins: [0.0003,0.000001]
91 |
92 | mixing_augs:
93 | mixup: true
94 | mixup_beta: 1.2
95 | use_identity: true
96 |
97 | optim_g:
98 | type: AdamW
99 | lr: !!float 3e-4
100 | weight_decay: !!float 1e-4
101 | betas: [0.9, 0.999]
102 |
103 | # losses
104 | pixel_opt:
105 | type: L1Loss
106 | loss_weight: 1
107 | reduction: mean
108 |
109 | # validation settings
110 | val:
111 | window_size: 8
112 | val_freq: !!float 4e3
113 | save_img: false
114 | rgb2bgr: true
115 | use_image: false
116 | max_minibatch: 8
117 |
118 | metrics:
119 | psnr: # metric name, can be arbitrary
120 | type: calculate_psnr
121 | crop_border: 0
122 | test_y_channel: false
123 |
124 | # logging settings
125 | logger:
126 | print_freq: 1000
127 | save_checkpoint_freq: !!float 4e3
128 | use_tb_logger: true
129 | wandb:
130 | project: ~
131 | resume_id: ~
132 |
133 | # dist training settings
134 | dist_params:
135 | backend: nccl
136 | port: 29500
137 |
--------------------------------------------------------------------------------
/Denoising/Options/GaussianGrayDenoising_Restormer.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: GaussianGrayDenoising_Restormer
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_GaussianDenoising
13 | sigma_type: random
14 | sigma_range: [0,50]
15 | in_ch: 1 ## Grayscale image
16 | dataroot_gt: ./Denoising/Datasets/train/DFWB
17 | dataroot_lq: none
18 | geometric_augs: true
19 |
20 | filename_tmpl: '{}'
21 | io_backend:
22 | type: disk
23 |
24 | # data loader
25 | use_shuffle: true
26 | num_worker_per_gpu: 8
27 | batch_size_per_gpu: 8
28 |
29 | ### -------------Progressive training--------------------------
30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
31 | iters: [92000,64000,48000,36000,36000,24000]
32 | gt_size: 384 # Max patch size for progressive training
33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
34 | ### ------------------------------------------------------------
35 |
36 | ### ------- Training on single fixed-patch size 128x128---------
37 | # mini_batch_sizes: [8]
38 | # iters: [300000]
39 | # gt_size: 128
40 | # gt_sizes: [128]
41 | ### ------------------------------------------------------------
42 |
43 | dataset_enlarge_ratio: 1
44 | prefetch_mode: ~
45 |
46 | val:
47 | name: ValSet
48 | type: Dataset_GaussianDenoising
49 | sigma_test: 25
50 | in_ch: 1 ## Grayscale image
51 | dataroot_gt: ./Denoising/Datasets/test/BSD68
52 | dataroot_lq: none
53 | io_backend:
54 | type: disk
55 |
56 | # network structures
57 | network_g:
58 | type: Restormer
59 | inp_channels: 1
60 | out_channels: 1
61 | dim: 48
62 | num_blocks: [4,6,6,8]
63 | num_refinement_blocks: 4
64 | heads: [1,2,4,8]
65 | ffn_expansion_factor: 2.66
66 | bias: False
67 | LayerNorm_type: BiasFree
68 | dual_pixel_task: False
69 |
70 |
71 | # path
72 | path:
73 | pretrain_network_g: ~
74 | strict_load_g: true
75 | resume_state: ~
76 |
77 | # training settings
78 | train:
79 | total_iter: 300000
80 | warmup_iter: -1 # no warm up
81 | use_grad_clip: true
82 |
83 | # Split 300k iterations into two cycles.
84 | # 1st cycle: fixed 3e-4 LR for 92k iters.
85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
86 | scheduler:
87 | type: CosineAnnealingRestartCyclicLR
88 | periods: [92000, 208000]
89 | restart_weights: [1,1]
90 | eta_mins: [0.0003,0.000001]
91 |
92 | mixing_augs:
93 | mixup: true
94 | mixup_beta: 1.2
95 | use_identity: true
96 |
97 | optim_g:
98 | type: AdamW
99 | lr: !!float 3e-4
100 | weight_decay: !!float 1e-4
101 | betas: [0.9, 0.999]
102 |
103 | # losses
104 | pixel_opt:
105 | type: L1Loss
106 | loss_weight: 1
107 | reduction: mean
108 |
109 | # validation settings
110 | val:
111 | window_size: 8
112 | val_freq: !!float 4e3
113 | save_img: false
114 | rgb2bgr: true
115 | use_image: false
116 | max_minibatch: 8
117 |
118 | metrics:
119 | psnr: # metric name, can be arbitrary
120 | type: calculate_psnr
121 | crop_border: 0
122 | test_y_channel: false
123 |
124 | # logging settings
125 | logger:
126 | print_freq: 1000
127 | save_checkpoint_freq: !!float 4e3
128 | use_tb_logger: true
129 | wandb:
130 | project: ~
131 | resume_id: ~
132 |
133 | # dist training settings
134 | dist_params:
135 | backend: nccl
136 | port: 29500
137 |
--------------------------------------------------------------------------------
/Denoising/Options/GaussianGrayDenoising_RestormerSigma15.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: GaussianGrayDenoising_RestormerSigma15
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_GaussianDenoising
13 | sigma_type: constant
14 | sigma_range: 15
15 | in_ch: 1 ## Grayscale image
16 | dataroot_gt: ./Denoising/Datasets/train/DFWB
17 | dataroot_lq: none
18 | geometric_augs: true
19 |
20 | filename_tmpl: '{}'
21 | io_backend:
22 | type: disk
23 |
24 | # data loader
25 | use_shuffle: true
26 | num_worker_per_gpu: 8
27 | batch_size_per_gpu: 8
28 |
29 | ### -------------Progressive training--------------------------
30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
31 | iters: [92000,64000,48000,36000,36000,24000]
32 | gt_size: 384 # Max patch size for progressive training
33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
34 | ### ------------------------------------------------------------
35 |
36 | ### ------- Training on single fixed-patch size 128x128---------
37 | # mini_batch_sizes: [8]
38 | # iters: [300000]
39 | # gt_size: 128
40 | # gt_sizes: [128]
41 | ### ------------------------------------------------------------
42 |
43 | dataset_enlarge_ratio: 1
44 | prefetch_mode: ~
45 |
46 | val:
47 | name: ValSet
48 | type: Dataset_GaussianDenoising
49 | sigma_test: 15
50 | in_ch: 1 ## Grayscale image
51 | dataroot_gt: ./Denoising/Datasets/test/BSD68
52 | dataroot_lq: none
53 | io_backend:
54 | type: disk
55 |
56 | # network structures
57 | network_g:
58 | type: Restormer
59 | inp_channels: 1
60 | out_channels: 1
61 | dim: 48
62 | num_blocks: [4,6,6,8]
63 | num_refinement_blocks: 4
64 | heads: [1,2,4,8]
65 | ffn_expansion_factor: 2.66
66 | bias: False
67 | LayerNorm_type: BiasFree
68 | dual_pixel_task: False
69 |
70 |
71 | # path
72 | path:
73 | pretrain_network_g: ~
74 | strict_load_g: true
75 | resume_state: ~
76 |
77 | # training settings
78 | train:
79 | total_iter: 300000
80 | warmup_iter: -1 # no warm up
81 | use_grad_clip: true
82 |
83 | # Split 300k iterations into two cycles.
84 | # 1st cycle: fixed 3e-4 LR for 92k iters.
85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
86 | scheduler:
87 | type: CosineAnnealingRestartCyclicLR
88 | periods: [92000, 208000]
89 | restart_weights: [1,1]
90 | eta_mins: [0.0003,0.000001]
91 |
92 | mixing_augs:
93 | mixup: true
94 | mixup_beta: 1.2
95 | use_identity: true
96 |
97 | optim_g:
98 | type: AdamW
99 | lr: !!float 3e-4
100 | weight_decay: !!float 1e-4
101 | betas: [0.9, 0.999]
102 |
103 | # losses
104 | pixel_opt:
105 | type: L1Loss
106 | loss_weight: 1
107 | reduction: mean
108 |
109 | # validation settings
110 | val:
111 | window_size: 8
112 | val_freq: !!float 4e3
113 | save_img: false
114 | rgb2bgr: true
115 | use_image: false
116 | max_minibatch: 8
117 |
118 | metrics:
119 | psnr: # metric name, can be arbitrary
120 | type: calculate_psnr
121 | crop_border: 0
122 | test_y_channel: false
123 |
124 | # logging settings
125 | logger:
126 | print_freq: 1000
127 | save_checkpoint_freq: !!float 4e3
128 | use_tb_logger: true
129 | wandb:
130 | project: ~
131 | resume_id: ~
132 |
133 | # dist training settings
134 | dist_params:
135 | backend: nccl
136 | port: 29500
137 |
--------------------------------------------------------------------------------
/Denoising/Options/GaussianGrayDenoising_RestormerSigma25.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: GaussianGrayDenoising_RestormerSigma25
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_GaussianDenoising
13 | sigma_type: constant
14 | sigma_range: 25
15 | in_ch: 1 ## Grayscale image
16 | dataroot_gt: ./Denoising/Datasets/train/DFWB
17 | dataroot_lq: none
18 | geometric_augs: true
19 |
20 | filename_tmpl: '{}'
21 | io_backend:
22 | type: disk
23 |
24 | # data loader
25 | use_shuffle: true
26 | num_worker_per_gpu: 8
27 | batch_size_per_gpu: 8
28 |
29 | ### -------------Progressive training--------------------------
30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
31 | iters: [92000,64000,48000,36000,36000,24000]
32 | gt_size: 384 # Max patch size for progressive training
33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
34 | ### ------------------------------------------------------------
35 |
36 | ### ------- Training on single fixed-patch size 128x128---------
37 | # mini_batch_sizes: [8]
38 | # iters: [300000]
39 | # gt_size: 128
40 | # gt_sizes: [128]
41 | ### ------------------------------------------------------------
42 |
43 | dataset_enlarge_ratio: 1
44 | prefetch_mode: ~
45 |
46 | val:
47 | name: ValSet
48 | type: Dataset_GaussianDenoising
49 | sigma_test: 25
50 | in_ch: 1 ## Grayscale image
51 | dataroot_gt: ./Denoising/Datasets/test/BSD68
52 | dataroot_lq: none
53 | io_backend:
54 | type: disk
55 |
56 | # network structures
57 | network_g:
58 | type: Restormer
59 | inp_channels: 1
60 | out_channels: 1
61 | dim: 48
62 | num_blocks: [4,6,6,8]
63 | num_refinement_blocks: 4
64 | heads: [1,2,4,8]
65 | ffn_expansion_factor: 2.66
66 | bias: False
67 | LayerNorm_type: BiasFree
68 | dual_pixel_task: False
69 |
70 |
71 | # path
72 | path:
73 | pretrain_network_g: ~
74 | strict_load_g: true
75 | resume_state: ~
76 |
77 | # training settings
78 | train:
79 | total_iter: 300000
80 | warmup_iter: -1 # no warm up
81 | use_grad_clip: true
82 |
83 | # Split 300k iterations into two cycles.
84 | # 1st cycle: fixed 3e-4 LR for 92k iters.
85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
86 | scheduler:
87 | type: CosineAnnealingRestartCyclicLR
88 | periods: [92000, 208000]
89 | restart_weights: [1,1]
90 | eta_mins: [0.0003,0.000001]
91 |
92 | mixing_augs:
93 | mixup: true
94 | mixup_beta: 1.2
95 | use_identity: true
96 |
97 | optim_g:
98 | type: AdamW
99 | lr: !!float 3e-4
100 | weight_decay: !!float 1e-4
101 | betas: [0.9, 0.999]
102 |
103 | # losses
104 | pixel_opt:
105 | type: L1Loss
106 | loss_weight: 1
107 | reduction: mean
108 |
109 | # validation settings
110 | val:
111 | window_size: 8
112 | val_freq: !!float 4e3
113 | save_img: false
114 | rgb2bgr: true
115 | use_image: false
116 | max_minibatch: 8
117 |
118 | metrics:
119 | psnr: # metric name, can be arbitrary
120 | type: calculate_psnr
121 | crop_border: 0
122 | test_y_channel: false
123 |
124 | # logging settings
125 | logger:
126 | print_freq: 1000
127 | save_checkpoint_freq: !!float 4e3
128 | use_tb_logger: true
129 | wandb:
130 | project: ~
131 | resume_id: ~
132 |
133 | # dist training settings
134 | dist_params:
135 | backend: nccl
136 | port: 29500
137 |
--------------------------------------------------------------------------------
/Denoising/Options/GaussianGrayDenoising_RestormerSigma50.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: GaussianGrayDenoising_RestormerSigma50
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_GaussianDenoising
13 | sigma_type: constant
14 | sigma_range: 50
15 | in_ch: 1 ## Grayscale image
16 | dataroot_gt: ./Denoising/Datasets/train/DFWB
17 | dataroot_lq: none
18 | geometric_augs: true
19 |
20 | filename_tmpl: '{}'
21 | io_backend:
22 | type: disk
23 |
24 | # data loader
25 | use_shuffle: true
26 | num_worker_per_gpu: 8
27 | batch_size_per_gpu: 8
28 |
29 | ### -------------Progressive training--------------------------
30 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
31 | iters: [92000,64000,48000,36000,36000,24000]
32 | gt_size: 384 # Max patch size for progressive training
33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
34 | ### ------------------------------------------------------------
35 |
36 | ### ------- Training on single fixed-patch size 128x128---------
37 | # mini_batch_sizes: [8]
38 | # iters: [300000]
39 | # gt_size: 128
40 | # gt_sizes: [128]
41 | ### ------------------------------------------------------------
42 |
43 | dataset_enlarge_ratio: 1
44 | prefetch_mode: ~
45 |
46 | val:
47 | name: ValSet
48 | type: Dataset_GaussianDenoising
49 | sigma_test: 50
50 | in_ch: 1 ## Grayscale image
51 | dataroot_gt: ./Denoising/Datasets/test/BSD68
52 | dataroot_lq: none
53 | io_backend:
54 | type: disk
55 |
56 | # network structures
57 | network_g:
58 | type: Restormer
59 | inp_channels: 1
60 | out_channels: 1
61 | dim: 48
62 | num_blocks: [4,6,6,8]
63 | num_refinement_blocks: 4
64 | heads: [1,2,4,8]
65 | ffn_expansion_factor: 2.66
66 | bias: False
67 | LayerNorm_type: BiasFree
68 | dual_pixel_task: False
69 |
70 |
71 | # path
72 | path:
73 | pretrain_network_g: ~
74 | strict_load_g: true
75 | resume_state: ~
76 |
77 | # training settings
78 | train:
79 | total_iter: 300000
80 | warmup_iter: -1 # no warm up
81 | use_grad_clip: true
82 |
83 | # Split 300k iterations into two cycles.
84 | # 1st cycle: fixed 3e-4 LR for 92k iters.
85 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
86 | scheduler:
87 | type: CosineAnnealingRestartCyclicLR
88 | periods: [92000, 208000]
89 | restart_weights: [1,1]
90 | eta_mins: [0.0003,0.000001]
91 |
92 | mixing_augs:
93 | mixup: true
94 | mixup_beta: 1.2
95 | use_identity: true
96 |
97 | optim_g:
98 | type: AdamW
99 | lr: !!float 3e-4
100 | weight_decay: !!float 1e-4
101 | betas: [0.9, 0.999]
102 |
103 | # losses
104 | pixel_opt:
105 | type: L1Loss
106 | loss_weight: 1
107 | reduction: mean
108 |
109 | # validation settings
110 | val:
111 | window_size: 8
112 | val_freq: !!float 4e3
113 | save_img: false
114 | rgb2bgr: true
115 | use_image: false
116 | max_minibatch: 8
117 |
118 | metrics:
119 | psnr: # metric name, can be arbitrary
120 | type: calculate_psnr
121 | crop_border: 0
122 | test_y_channel: false
123 |
124 | # logging settings
125 | logger:
126 | print_freq: 1000
127 | save_checkpoint_freq: !!float 4e3
128 | use_tb_logger: true
129 | wandb:
130 | project: ~
131 | resume_id: ~
132 |
133 | # dist training settings
134 | dist_params:
135 | backend: nccl
136 | port: 29500
137 |
--------------------------------------------------------------------------------
/Denoising/Options/RealDenoising_Restormer.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: RealDenoising_Restormer
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage
13 | dataroot_gt: ./Denoising/Datasets/train/SIDD/target_crops
14 | dataroot_lq: ./Denoising/Datasets/train/SIDD/input_crops
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ### -------------Progressive training--------------------------
27 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
28 | iters: [92000,64000,48000,36000,36000,24000]
29 | gt_size: 384 # Max patch size for progressive training
30 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
31 | ### ------------------------------------------------------------
32 |
33 | ### ------- Training on single fixed-patch size 128x128---------
34 | # mini_batch_sizes: [8]
35 | # iters: [300000]
36 | # gt_size: 128
37 | # gt_sizes: [128]
38 | ### ------------------------------------------------------------
39 |
40 | dataset_enlarge_ratio: 1
41 | prefetch_mode: ~
42 |
43 | val:
44 | name: ValSet
45 | type: Dataset_PairedImage
46 | dataroot_gt: ./Denoising/Datasets/val/SIDD/target_crops
47 | dataroot_lq: ./Denoising/Datasets/val/SIDD/input_crops
48 | io_backend:
49 | type: disk
50 |
51 | # network structures
52 | network_g:
53 | type: Restormer
54 | inp_channels: 3
55 | out_channels: 3
56 | dim: 48
57 | num_blocks: [4,6,6,8]
58 | num_refinement_blocks: 4
59 | heads: [1,2,4,8]
60 | ffn_expansion_factor: 2.66
61 | bias: False
62 | LayerNorm_type: BiasFree
63 | dual_pixel_task: False
64 |
65 |
66 | # path
67 | path:
68 | pretrain_network_g: ~
69 | strict_load_g: true
70 | resume_state: ~
71 |
72 | # training settings
73 | train:
74 | total_iter: 300000
75 | warmup_iter: -1 # no warm up
76 | use_grad_clip: true
77 |
78 | # Split 300k iterations into two cycles.
79 | # 1st cycle: fixed 3e-4 LR for 92k iters.
80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
81 | scheduler:
82 | type: CosineAnnealingRestartCyclicLR
83 | periods: [92000, 208000]
84 | restart_weights: [1,1]
85 | eta_mins: [0.0003,0.000001]
86 |
87 | mixing_augs:
88 | mixup: true
89 | mixup_beta: 1.2
90 | use_identity: true
91 |
92 | optim_g:
93 | type: AdamW
94 | lr: !!float 3e-4
95 | weight_decay: !!float 1e-4
96 | betas: [0.9, 0.999]
97 |
98 | # losses
99 | pixel_opt:
100 | type: L1Loss
101 | loss_weight: 1
102 | reduction: mean
103 |
104 | # validation settings
105 | val:
106 | window_size: 8
107 | val_freq: !!float 4e3
108 | save_img: false
109 | rgb2bgr: true
110 | use_image: false
111 | max_minibatch: 8
112 |
113 | metrics:
114 | psnr: # metric name, can be arbitrary
115 | type: calculate_psnr
116 | crop_border: 0
117 | test_y_channel: false
118 |
119 | # logging settings
120 | logger:
121 | print_freq: 1000
122 | save_checkpoint_freq: !!float 4e3
123 | use_tb_logger: true
124 | wandb:
125 | project: ~
126 | resume_id: ~
127 |
128 | # dist training settings
129 | dist_params:
130 | backend: nccl
131 | port: 29500
132 |
--------------------------------------------------------------------------------
/Denoising/README.md:
--------------------------------------------------------------------------------
1 | # Image Denoising
2 | - [Gaussian Image Denoising](#gaussian-image-denoising)
3 | * [Training](#training)
4 | * [Evaluation](#evaluation)
5 | - [Grayscale blind image denoising testing](#grayscale-blind-image-denoising-testing)
6 | - [Grayscale non-blind image denoising testing](#grayscale-non-blind-image-denoising-testing)
7 | - [Color blind image denoising testing](#color-blind-image-denoising-testing)
8 | - [Color non-blind image denoising testing](#color-non-blind-image-denoising-testing)
9 | - [Real Image Denoising](#real-image-denoising)
10 | * [Training](#training-1)
11 | * [Evaluation](#evaluation-1)
12 | - [Testing on SIDD dataset](#testing-on-sidd-dataset)
13 | - [Testing on DND dataset](#testing-on-dnd-dataset)
14 |
15 | # Gaussian Image Denoising
16 |
17 | - **Blind Denoising:** One model to handle various noise levels
18 | - **Non-Blind Denoising:** Separate models for each noise level
19 |
20 | ## Training
21 |
22 | - Download training (DIV2K, Flickr2K, WED, BSD) and testing datasets, run
23 | ```
24 | python download_data.py --data train-test --noise gaussian
25 | ```
26 |
27 | - Generate image patches from full-resolution training images, run
28 | ```
29 | python generate_patches_dfwb.py
30 | ```
31 |
32 | - Train Restormer for **grayscale blind** image denoising, run
33 | ```
34 | cd Restormer
35 | ./train.sh Denoising/Options/GaussianGrayDenoising_Restormer.yml
36 | ```
37 |
38 | - Train Restormer for **grayscale non-blind** image denoising, run
39 | ```
40 | cd Restormer
41 | ./train.sh Denoising/Options/GaussianGrayDenoising_RestormerSigma15.yml
42 | ./train.sh Denoising/Options/GaussianGrayDenoising_RestormerSigma25.yml
43 | ./train.sh Denoising/Options/GaussianGrayDenoising_RestormerSigma50.yml
44 | ```
45 |
46 | - Train Restormer for **color blind** image denoising, run
47 | ```
48 | cd Restormer
49 | ./train.sh Denoising/Options/GaussianColorDenoising_Restormer.yml
50 | ```
51 |
52 | - Train Restormer for **color non-blind** image denoising, run
53 | ```
54 | cd Restormer
55 | ./train.sh Denoising/Options/GaussianColorDenoising_RestormerSigma15.yml
56 | ./train.sh Denoising/Options/GaussianColorDenoising_RestormerSigma25.yml
57 | ./train.sh Denoising/Options/GaussianColorDenoising_RestormerSigma50.yml
58 | ```
59 |
60 | **Note:** The above training scripts use 8 GPUs by default. To use any other number of GPUs, modify [Restormer/train.sh](../train.sh) and the yaml file corresponding to each task (e.g., [Denoising/Options/GaussianGrayDenoising_Restormer.yml](Options/GaussianGrayDenoising_Restormer.yml))
61 |
62 | ## Evaluation
63 |
64 | - Download the pre-trained [models](https://drive.google.com/drive/folders/1Qwsjyny54RZWa7zC4Apg7exixLBo4uF0?usp=sharing) and place them in `./pretrained_models/`
65 |
66 | - Download testsets (Set12, BSD68, CBSD68, Kodak, McMaster, Urban100), run
67 | ```
68 | python download_data.py --data test --noise gaussian
69 | ```
70 |
71 | #### Grayscale blind image denoising testing
72 |
73 | - To obtain denoised predictions, run
74 | ```
75 | python test_gaussian_gray_denoising.py --model_type blind --sigmas 15,25,50
76 | ```
77 |
78 | - To reproduce PSNR Table 4 (top super-row), run
79 | ```
80 | python evaluate_gaussian_gray_denoising.py --model_type blind --sigmas 15,25,50
81 | ```
82 |
83 | #### Grayscale non-blind image denoising testing
84 |
85 | - To obtain denoised predictions, run
86 | ```
87 | python test_gaussian_gray_denoising.py --model_type non_blind --sigmas 15,25,50
88 | ```
89 |
90 | - To reproduce PSNR Table 4 (bottom super-row), run
91 | ```
92 | python evaluate_gaussian_gray_denoising.py --model_type non_blind --sigmas 15,25,50
93 | ```
94 |
95 | #### Color blind image denoising testing
96 |
97 | - To obtain denoised predictions, run
98 | ```
99 | python test_gaussian_color_denoising.py --model_type blind --sigmas 15,25,50
100 | ```
101 |
102 | - To reproduce PSNR Table 5 (top super-row), run
103 | ```
104 | python evaluate_gaussian_color_denoising.py --model_type blind --sigmas 15,25,50
105 | ```
106 |
107 | #### Color non-blind image denoising testing
108 |
109 | - To obtain denoised predictions, run
110 | ```
111 | python test_gaussian_color_denoising.py --model_type non_blind --sigmas 15,25,50
112 | ```
113 |
114 | - To reproduce PSNR Table 5 (bottom super-row), run
115 | ```
116 | python evaluate_gaussian_color_denoising.py --model_type non_blind --sigmas 15,25,50
117 | ```
118 |
119 |
120 |
121 | # Real Image Denoising
122 |
123 | ## Training
124 |
125 | - Download SIDD training data, run
126 | ```
127 | python download_data.py --data train --noise real
128 | ```
129 |
130 | - Generate image patches from full-resolution training images, run
131 | ```
132 | python generate_patches_sidd.py
133 | ```
134 |
135 | - Train Restormer for **real** image denoising, run
136 | ```
137 | cd Restormer
138 | ./train.sh Denoising/Options/RealDenoising_Restormer.yml
139 | ```
140 |
141 | **Note:** This training script uses 8 GPUs by default. To use any other number of GPUs, modify [Restormer/train.sh](../train.sh) and [Denoising/Options/RealDenoising_Restormer.yml](Options/RealDenoising_Restormer.yml)
142 |
143 | ## Evaluation
144 |
145 | - Download the pre-trained [model](https://drive.google.com/file/d/1FF_4NTboTWQ7sHCq4xhyLZsSl0U0JfjH/view?usp=sharing) and place it in `./pretrained_models/`
146 |
147 | #### Testing on SIDD dataset
148 |
149 | - Download SIDD validation data, run
150 | ```
151 | python download_data.py --noise real --data test --dataset SIDD
152 | ```
153 |
154 | - To obtain denoised results, run
155 | ```
156 | python test_real_denoising_sidd.py --save_images
157 | ```
158 |
159 | - To reproduce PSNR/SSIM scores on SIDD data (Table 6), run
160 | ```
161 | evaluate_sidd.m
162 | ```
163 |
164 | #### Testing on DND dataset
165 |
166 | - Download the DND benchmark data, run
167 | ```
168 | python download_data.py --noise real --data test --dataset DND
169 | ```
170 |
171 | - To obtain denoised results, run
172 | ```
173 | python test_real_denoising_dnd.py --save_images
174 | ```
175 |
176 | - To reproduce PSNR/SSIM scores (Table 6), upload the results to the DND benchmark website.
177 |
--------------------------------------------------------------------------------
/Denoising/download_data.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | ## Download training and testing data for Image Denoising task
6 |
7 |
8 | import os
9 | # import gdown
10 | import shutil
11 |
12 | import argparse
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--data', type=str, required=True, help='train, test or train-test')
16 | parser.add_argument('--dataset', type=str, default='SIDD', help='all or SIDD or DND')
17 | parser.add_argument('--noise', type=str, required=True, help='real or gaussian')
18 | args = parser.parse_args()
19 |
20 | ### Google drive IDs ######
21 | SIDD_train = '1UHjWZzLPGweA9ZczmV8lFSRcIxqiOVJw' ## https://drive.google.com/file/d/1UHjWZzLPGweA9ZczmV8lFSRcIxqiOVJw/view?usp=sharing
22 | SIDD_val = '1Fw6Ey1R-nCHN9WEpxv0MnMqxij-ECQYJ' ## https://drive.google.com/file/d/1Fw6Ey1R-nCHN9WEpxv0MnMqxij-ECQYJ/view?usp=sharing
23 | SIDD_test = '11vfqV-lqousZTuAit1Qkqghiv_taY0KZ' ## https://drive.google.com/file/d/11vfqV-lqousZTuAit1Qkqghiv_taY0KZ/view?usp=sharing
24 | DND_test = '1CYCDhaVxYYcXhSfEVDUwkvJDtGxeQ10G' ## https://drive.google.com/file/d/1CYCDhaVxYYcXhSfEVDUwkvJDtGxeQ10G/view?usp=sharing
25 |
26 | BSD400 = '1idKFDkAHJGAFDn1OyXZxsTbOSBx9GS8N' ## https://drive.google.com/file/d/1idKFDkAHJGAFDn1OyXZxsTbOSBx9GS8N/view?usp=sharing
27 | DIV2K = '13wLWWXvFkuYYVZMMAYiMVdSA7iVEf2fM' ## https://drive.google.com/file/d/13wLWWXvFkuYYVZMMAYiMVdSA7iVEf2fM/view?usp=sharing
28 | Flickr2K = '1J8xjFCrVzeYccD-LF08H7HiIsmi8l2Wn' ## https://drive.google.com/file/d/1J8xjFCrVzeYccD-LF08H7HiIsmi8l2Wn/view?usp=sharing
29 | WaterlooED = '19_mCE_GXfmE5yYsm-HEzuZQqmwMjPpJr' ## https://drive.google.com/file/d/19_mCE_GXfmE5yYsm-HEzuZQqmwMjPpJr/view?usp=sharing
30 | gaussian_test = '1mwMLt-niNqcQpfN_ZduG9j4k6P_ZkOl0' ## https://drive.google.com/file/d/1mwMLt-niNqcQpfN_ZduG9j4k6P_ZkOl0/view?usp=sharing
31 |
32 |
noise = args.noise

# Reject unsupported selections up front; previously an unknown value simply
# matched no branch below and the script exited without downloading anything.
if noise not in ('real', 'gaussian'):
    parser.error(f"--noise must be 'real' or 'gaussian', got '{noise}'")
if args.dataset not in ('all', 'SIDD', 'DND'):
    parser.error(f"--dataset must be 'all', 'SIDD' or 'DND', got '{args.dataset}'")
requested_splits = args.data.split('-')
for split in requested_splits:
    if split not in ('train', 'test'):
        parser.error(f"--data must be 'train', 'test' or 'train-test', got '{args.data}'")


def fetch_and_extract(title, file_id, dest_dir, archive_name, extract_msg):
    """Download one Google Drive archive with the `gdrive` CLI and unpack it.

    Args:
        title: Message printed before the download starts.
        file_id: Google Drive file ID handed to `gdrive download`.
        dest_dir: Directory that receives the archive and its extracted
            contents; created if it does not exist.
        archive_name: File name the downloaded archive is expected to have
            inside `dest_dir` (NOTE(review): this relies on the name the file
            carries on Google Drive -- confirm if downloads start failing).
        extract_msg: Message printed before extraction starts.

    Raises:
        RuntimeError: If the `gdrive` command exits with a non-zero status.
    """
    print(title)
    os.makedirs(dest_dir, exist_ok=True)
    # Alternative kept from the original script:
    #   gdown.download(id=file_id, output=f'{dest_dir}/{archive_name}', quiet=False)
    status = os.system(f'gdrive download {file_id} --path {dest_dir}/')
    if status != 0:
        # Fail here with a clear message instead of letting unpack_archive
        # complain about a missing or truncated zip file.
        raise RuntimeError(
            f'gdrive download of {archive_name} (id {file_id}) failed with exit status {status}')
    archive_path = f'{dest_dir}/{archive_name}'
    print(extract_msg)
    shutil.unpack_archive(archive_path, dest_dir)
    os.remove(archive_path)


for data in requested_splits:
    if noise == 'real':
        if data == 'train':
            fetch_and_extract('SIDD Training Data!', SIDD_train,
                              'Datasets/Downloads', 'train.zip', 'Extracting SIDD Data...')
            # The archive unpacks to a folder named "train"; store it as "SIDD".
            os.rename(os.path.join('Datasets', 'Downloads', 'train'),
                      os.path.join('Datasets', 'Downloads', 'SIDD'))

            fetch_and_extract('SIDD Validation Data!', SIDD_val,
                              'Datasets', 'val.zip', 'Extracting SIDD Data...')

        if data == 'test':
            if args.dataset == 'all' or args.dataset == 'SIDD':
                fetch_and_extract('SIDD Testing Data!', SIDD_test,
                                  'Datasets', 'test.zip', 'Extracting SIDD Data...')

            if args.dataset == 'all' or args.dataset == 'DND':
                fetch_and_extract('DND Testing Data!', DND_test,
                                  'Datasets', 'test.zip', 'Extracting DND data...')

    if noise == 'gaussian':
        if data == 'train':
            fetch_and_extract('WaterlooED Training Data!', WaterlooED,
                              'Datasets/Downloads', 'WaterlooED.zip', 'Extracting WaterlooED Data...')

            fetch_and_extract('DIV2K Training Data!', DIV2K,
                              'Datasets/Downloads', 'DIV2K.zip', 'Extracting DIV2K Data...')

            fetch_and_extract('BSD400 Training Data!', BSD400,
                              'Datasets/Downloads', 'BSD400.zip', 'Extracting BSD400 data...')

            fetch_and_extract('Flickr2K Training Data!', Flickr2K,
                              'Datasets/Downloads', 'Flickr2K.zip', 'Extracting Flickr2K data...')

        if data == 'test':
            fetch_and_extract('Gaussian Denoising Testing Data!', gaussian_test,
                              'Datasets', 'test.zip', 'Extracting Data...')

# print('Download completed successfully!')
112 |
--------------------------------------------------------------------------------
/Denoising/evaluate_gaussian_color_denoising.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import os
6 | import numpy as np
7 | from glob import glob
8 | from natsort import natsorted
9 | from skimage import io
10 | import cv2
11 | import argparse
12 | from skimage.metrics import structural_similarity
13 | from tqdm import tqdm
14 | import concurrent.futures
15 | import utils
16 |
def proc(filename):
    """Score one (target, prediction) pair of image paths.

    Args:
        filename: Two-item sequence ``(target_path, prediction_path)``.

    Returns:
        Whatever ``utils.calculate_psnr`` returns for the two loaded images.
    """
    target_path, prediction_path = filename
    target = utils.load_img(target_path)
    prediction = utils.load_img(prediction_path)

    # SSIM (utils.calculate_ssim) is deliberately not evaluated here.
    return utils.calculate_psnr(target, prediction)
25 |
def main():
    """Print the average PSNR per test set and noise level for color denoising.

    Ground-truth images are read from ``Datasets/test/<dataset>`` and the
    predictions from
    ``results/Gaussian_Color_Denoising/<model_type>/<dataset>/<sigma>``.
    Both listings are natural-sorted and paired by position.
    """
    parser = argparse.ArgumentParser(description='Gaussian Color Denoising using Restormer')

    parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. non_blind: separate model for each noise level.')
    parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values')

    args = parser.parse_args()

    sigmas = np.int_(args.sigmas.split(','))

    datasets = ['CBSD68', 'Kodak', 'McMaster','Urban100']

    for dataset in datasets:

        gt_path = os.path.join('Datasets','test', dataset)
        gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif')))
        assert len(gt_list) != 0, "Target files not found"

        for sigma_test in sigmas:
            file_path = os.path.join('results', 'Gaussian_Color_Denoising', args.model_type, dataset, str(sigma_test))
            path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif')))
            assert len(path_list) != 0, "Predicted files not found"
            # Pairs are formed by position, so a count mismatch would silently
            # score misaligned or truncated pairs.
            assert len(path_list) == len(gt_list), \
                "Found {:d} predictions but {:d} targets for {:s}".format(len(path_list), len(gt_list), dataset)

            img_files = list(zip(gt_list, path_list))
            with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
                # executor.map yields results in input order.
                psnr = list(executor.map(proc, img_files))

            avg_psnr = sum(psnr)/len(psnr)

            print('For {:s} dataset Noise Level {:d} PSNR: {:f}\n'.format(dataset, sigma_test, avg_psnr))


# Worker processes started with the "spawn" method re-import this module;
# the guard keeps them from re-running the evaluation themselves.
if __name__ == '__main__':
    main()
59 | # print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
60 |
--------------------------------------------------------------------------------
/Denoising/evaluate_gaussian_gray_denoising.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import os
6 | import numpy as np
7 | from glob import glob
8 | from natsort import natsorted
9 | from skimage import io
10 | import cv2
11 | import argparse
12 | from skimage.metrics import structural_similarity
13 | from tqdm import tqdm
14 | import concurrent.futures
15 | import utils
16 |
def proc(filename):
    """Score one (target, prediction) pair of grayscale image paths.

    Args:
        filename: Two-item sequence ``(target_path, prediction_path)``.

    Returns:
        Whatever ``utils.calculate_psnr`` returns for the two loaded images.
    """
    target_path, prediction_path = filename
    target = utils.load_gray_img(target_path)
    prediction = utils.load_gray_img(prediction_path)

    # SSIM (utils.calculate_ssim) is deliberately not evaluated here.
    return utils.calculate_psnr(target, prediction)
25 |
def main():
    """Print the average PSNR per test set and noise level for grayscale denoising.

    Ground-truth images are read from ``Datasets/test/<dataset>`` and the
    predictions from
    ``results/Gaussian_Gray_Denoising/<model_type>/<dataset>/<sigma>``.
    Both listings are natural-sorted and paired by position.
    """
    parser = argparse.ArgumentParser(description='Gaussian Grayscale Denoising using Restormer')

    parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. non_blind: separate model for each noise level.')
    parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values')

    args = parser.parse_args()

    sigmas = np.int_(args.sigmas.split(','))

    datasets = ['Set12', 'BSD68', 'Urban100']

    for dataset in datasets:

        gt_path = os.path.join('Datasets','test', dataset)
        gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif')))
        assert len(gt_list) != 0, "Target files not found"

        for sigma_test in sigmas:
            file_path = os.path.join('results', 'Gaussian_Gray_Denoising', args.model_type, dataset, str(sigma_test))
            path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif')))
            assert len(path_list) != 0, "Predicted files not found"
            # Pairs are formed by position, so a count mismatch would silently
            # score misaligned or truncated pairs.
            assert len(path_list) == len(gt_list), \
                "Found {:d} predictions but {:d} targets for {:s}".format(len(path_list), len(gt_list), dataset)

            img_files = list(zip(gt_list, path_list))
            with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
                # executor.map yields results in input order.
                psnr = list(executor.map(proc, img_files))

            avg_psnr = sum(psnr)/len(psnr)

            print('For {:s} dataset Noise Level {:d} PSNR: {:f}\n'.format(dataset, sigma_test, avg_psnr))


# Worker processes started with the "spawn" method re-import this module;
# the guard keeps them from re-running the evaluation themselves.
if __name__ == '__main__':
    main()
59 | # print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
60 |
--------------------------------------------------------------------------------
/Denoising/evaluate_sidd.m:
--------------------------------------------------------------------------------
% Average PSNR / SSIM of the denoised SIDD blocks against the ground truth.
close all; clear all;

% Load both 5-D block arrays (indexed as image, block, then the patch dims).
pred_data = load('./results/Real_Denoising/SIDD/mat/Idenoised.mat');
ref_data  = load('./Datasets/test/SIDD/ValidationGtBlocksSrgb.mat');

denoised = pred_data.Idenoised;
gt = im2single(ref_data.ValidationGtBlocksSrgb);   % ground truth converted to single

num_images = 40;   % images in the set
num_blocks = 32;   % blocks per image

psnr_sum = 0;
ssim_sum = 0;
for img_idx = 1:num_images
    for blk_idx = 1:num_blocks
        pred_patch = squeeze(denoised(img_idx, blk_idx, :, :, :));
        ref_patch  = squeeze(gt(img_idx, blk_idx, :, :, :));
        ssim_sum = ssim_sum + ssim(pred_patch, ref_patch);
        psnr_sum = psnr_sum + psnr(pred_patch, ref_patch);
    end
end

% Mean over every block of every image.
qm_psnr = psnr_sum / (num_images*num_blocks);
qm_ssim = ssim_sum / (num_images*num_blocks);

fprintf('PSNR: %f SSIM: %f\n', qm_psnr, qm_ssim);
26 |
27 |
--------------------------------------------------------------------------------
/Denoising/generate_patches_dfwb.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from glob import glob
5 | from natsort import natsorted
6 | import os
7 | from tqdm import tqdm
8 | from pdb import set_trace as stx
9 |
src = 'Datasets/Downloads'      # root folder holding one sub-folder per source dataset
tar = 'Datasets/train/DFWB'     # flat output folder for the generated training images
os.makedirs(tar, exist_ok=True)

patch_size = 512   # side length of each square crop, in pixels
overlap = 96       # neighbouring crops share this many pixels (stride = patch_size - overlap)
p_max = 800        # only images whose two sides both exceed this are tiled


def save_files(file_):
    """Write one source image into `tar`, tiled into overlapping crops if large.

    Images whose height and width both exceed `p_max` are cut into
    `patch_size` x `patch_size` crops on a grid with stride
    `patch_size - overlap`, plus a final row/column aligned to the far edge so
    the whole image is covered. Smaller images are written unchanged. Output
    names are ``<parent folder>-<file stem>[-<patch index>].png``.

    Args:
        file_: Path of the image to process; its parent folder name becomes
            part of the output file name.

    Raises:
        IOError: If OpenCV cannot read the image.
    """
    path_contents = file_.split(os.sep)
    foldname = path_contents[-2]
    filename = os.path.splitext(path_contents[-1])[0]
    img = cv2.imread(file_)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('Could not read image: ' + file_)

    # shape[:2] is (rows, cols), i.e. (height, width).
    rows, cols = img.shape[:2]
    if rows > p_max and cols > p_max:
        stride = patch_size - overlap
        # The builtin int replaces the np.int alias, which NumPy >= 1.24 no longer provides.
        row_starts = list(np.arange(0, rows - patch_size, stride, dtype=int))
        col_starts = list(np.arange(0, cols - patch_size, stride, dtype=int))
        # Add a last crop flush with the bottom/right border.
        row_starts.append(rows - patch_size)
        col_starts.append(cols - patch_size)

        num_patch = 0
        for i in row_starts:
            for j in col_starts:
                num_patch += 1
                patch = img[i:i+patch_size, j:j+patch_size, :]
                savename = os.path.join(tar, foldname + '-' + filename + '-' + str(num_patch) + '.png')
                cv2.imwrite(savename, patch)

    else:
        savename = os.path.join(tar, foldname + '-' + filename + '.png')
        cv2.imwrite(savename, img)
41 |
42 |
# Gather the png/jpg/bmp images of each source dataset, natural-sorted per dataset.
files = []
for dataset in ['DIV2K', 'Flickr2K', 'WaterlooED', 'BSD400']:
    found = []
    for pattern in ('*.png', '*.jpg', '*.bmp'):
        found += glob(os.path.join(src, dataset, pattern))
    files.extend(natsorted(found))

from joblib import Parallel, delayed
import multiprocessing
num_cores = 10
# One save_files job per image, with tqdm showing progress as jobs are dispatched.
jobs = (delayed(save_files)(file_) for file_ in tqdm(files))
Parallel(n_jobs=num_cores)(jobs)
52 |
--------------------------------------------------------------------------------
/Denoising/generate_patches_sidd.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from glob import glob
5 | from natsort import natsorted
6 | import os
7 | from tqdm import tqdm
8 | from pdb import set_trace as stx
9 |
10 |
src = 'Datasets/Downloads/SIDD'   # one sub-folder per scene, holding *.PNG files
tar = 'Datasets/train/SIDD'       # output root for the generated crops

lr_tar = os.path.join(tar, 'input_crops')    # noisy crops
hr_tar = os.path.join(tar, 'target_crops')   # ground-truth crops

os.makedirs(lr_tar, exist_ok=True)
os.makedirs(hr_tar, exist_ok=True)

files = natsorted(glob(os.path.join(src, '*', '*.PNG')))

# Split the listing into noisy inputs and ground-truth targets by file name.
lr_files, hr_files = [], []
for file_ in files:
    filename = os.path.split(file_)[-1]
    if 'GT' in filename:
        hr_files.append(file_)
    if 'NOISY' in filename:
        lr_files.append(file_)

# The two lists are paired by position below; a missing file on either side
# would shift every later pair, so refuse to continue on a count mismatch.
if len(lr_files) != len(hr_files):
    raise RuntimeError(
        'Found {} NOISY images but {} GT images under {}'.format(len(lr_files), len(hr_files), src))

files = list(zip(lr_files, hr_files))
31 |
patch_size = 512   # side length of each square crop, in pixels
overlap = 128      # neighbouring crops share this many pixels (stride = patch_size - overlap)
p_max = 0          # images whose two sides both exceed this are tiled


def save_files(file_):
    """Write one noisy/ground-truth image pair as aligned crops.

    When the noisy image is larger than `p_max` in both dimensions (and at
    least `patch_size` in both), the two images are cut at identical positions
    into `patch_size` x `patch_size` crops on a grid with stride
    `patch_size - overlap`, plus a final row/column aligned to the far edge.
    Otherwise both images are written unchanged. Noisy output goes to `lr_tar`
    and ground truth to `hr_tar`, under the same file name derived from the
    noisy image's stem.

    Args:
        file_: Tuple ``(noisy_path, ground_truth_path)``.

    Raises:
        IOError: If OpenCV cannot read either image.
    """
    lr_file, hr_file = file_
    filename = os.path.splitext(os.path.split(lr_file)[-1])[0]
    lr_img = cv2.imread(lr_file)
    hr_img = cv2.imread(hr_file)
    # cv2.imread signals failure by returning None instead of raising.
    if lr_img is None:
        raise IOError('Could not read image: ' + lr_file)
    if hr_img is None:
        raise IOError('Could not read image: ' + hr_file)

    # shape[:2] is (rows, cols), i.e. (height, width).
    rows, cols = lr_img.shape[:2]
    # An image smaller than one patch would get a negative start offset and an
    # undersized crop, so it is written whole instead.
    fits_patch = rows >= patch_size and cols >= patch_size
    if rows > p_max and cols > p_max and fits_patch:
        stride = patch_size - overlap
        # The builtin int replaces the np.int alias, which NumPy >= 1.24 no longer provides.
        row_starts = list(np.arange(0, rows - patch_size, stride, dtype=int))
        col_starts = list(np.arange(0, cols - patch_size, stride, dtype=int))
        # Add a last crop flush with the bottom/right border.
        row_starts.append(rows - patch_size)
        col_starts.append(cols - patch_size)

        num_patch = 0
        for i in row_starts:
            for j in col_starts:
                num_patch += 1

                lr_patch = lr_img[i:i+patch_size, j:j+patch_size, :]
                hr_patch = hr_img[i:i+patch_size, j:j+patch_size, :]

                lr_savename = os.path.join(lr_tar, filename + '-' + str(num_patch) + '.png')
                hr_savename = os.path.join(hr_tar, filename + '-' + str(num_patch) + '.png')

                cv2.imwrite(lr_savename, lr_patch)
                cv2.imwrite(hr_savename, hr_patch)

    else:
        lr_savename = os.path.join(lr_tar, filename + '.png')
        hr_savename = os.path.join(hr_tar, filename + '.png')

        cv2.imwrite(lr_savename, lr_img)
        cv2.imwrite(hr_savename, hr_img)
67 |
from joblib import Parallel, delayed
import multiprocessing
num_cores = 10
# One save_files job per (noisy, ground-truth) pair, with tqdm showing progress.
jobs = (delayed(save_files)(file_) for file_ in tqdm(files))
Parallel(n_jobs=num_cores)(jobs)
72 |
--------------------------------------------------------------------------------
/Denoising/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | Pre-trained denoising models are available [here](https://drive.google.com/drive/folders/1Qwsjyny54RZWa7zC4Apg7exixLBo4uF0?usp=sharing)
--------------------------------------------------------------------------------
/Denoising/test_gaussian_color_denoising.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import argparse
8 | from tqdm import tqdm
9 |
10 | import torch.nn as nn
11 | import torch
12 | import torch.nn.functional as F
13 |
14 | from basicsr.models.archs.restormer_arch import Restormer
15 | from skimage import img_as_ubyte
16 | from natsort import natsorted
17 | from glob import glob
18 | import utils
19 | from pdb import set_trace as stx
20 |
parser = argparse.ArgumentParser(description='Gaussian Color Denoising using Restormer')

parser.add_argument('--input_dir', default='./Datasets/test/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/Gaussian_Color_Denoising/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/gaussian_color_denoising', type=str, help='Path to weights')
parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. non_blind: separate model for each noise level.')
parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values')

args = parser.parse_args()

####### Load yaml #######
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader


def load_network_options(model_type, sigma):
    """Return the `network_g` constructor arguments for one model variant.

    The blind model has a single option file; each non-blind model has its own
    file named after its noise level. The file must be chosen per sigma:
    formatting the raw `--sigmas` string into the name (e.g. "15,25,50")
    points at a file that does not exist.

    Args:
        model_type: 'blind' or 'non_blind'.
        sigma: Noise level used to pick the non-blind option file.

    Returns:
        The `network_g` mapping from the option file with its 'type' entry removed.
    """
    if model_type == 'blind':
        yaml_file = 'Options/GaussianColorDenoising_Restormer.yml'
    else:
        yaml_file = f'Options/GaussianColorDenoising_RestormerSigma{sigma}.yml'
    # NOTE(review): yaml.load with the full Loader is fine for the repository's
    # own option files; do not point it at untrusted YAML.
    with open(yaml_file, mode='r') as stream:
        options = yaml.load(stream, Loader=Loader)
    network_options = options['network_g']
    network_options.pop('type')   # not a Restormer constructor argument
    return network_options
##########################

sigmas = np.int_(args.sigmas.split(','))

factor = 8

datasets = ['CBSD68', 'Kodak', 'McMaster','Urban100']

for sigma_test in sigmas:
    print("Compute results for noise level",sigma_test)
    model_restoration = Restormer(**load_network_options(args.model_type, sigma_test))
    if args.model_type == 'blind':
        weights = args.weights+'_blind.pth'
    else:
        weights = args.weights + '_sigma' + str(sigma_test) +'.pth'
    checkpoint = torch.load(weights)
    model_restoration.load_state_dict(checkpoint['params'])

    print("===>Testing using weights: ",weights)
    print("------------------------------------------------")
    model_restoration.cuda()
    model_restoration = nn.DataParallel(model_restoration)
    model_restoration.eval()

    for dataset in datasets:
        inp_dir = os.path.join(args.input_dir, dataset)
        files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.tif')))
        result_dir_tmp = os.path.join(args.result_dir, args.model_type, dataset, str(sigma_test))
        os.makedirs(result_dir_tmp, exist_ok=True)

        with torch.no_grad():
            for file_ in tqdm(files):
                torch.cuda.ipc_collect()
                torch.cuda.empty_cache()
                img = np.float32(utils.load_img(file_))/255.

                np.random.seed(seed=0)  # for reproducibility
                img += np.random.normal(0, sigma_test/255., img.shape)

                img = torch.from_numpy(img).permute(2,0,1)
                input_ = img.unsqueeze(0).cuda()

                # Padding in case images are not multiples of 8
                h,w = input_.shape[2], input_.shape[3]
                H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
                padh = H-h if h%factor!=0 else 0
                padw = W-w if w%factor!=0 else 0
                input_ = F.pad(input_, (0,padw,0,padh), 'reflect')

                restored = model_restoration(input_)

                # Unpad images to original dimensions
                restored = restored[:,:,:h,:w]

                restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()

                save_file = os.path.join(result_dir_tmp, os.path.split(file_)[-1])
                utils.save_img(save_file, img_as_ubyte(restored))
104 |
--------------------------------------------------------------------------------
/Denoising/test_gaussian_gray_denoising.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import argparse
8 | from tqdm import tqdm
9 |
10 | import torch.nn as nn
11 | import torch
12 | import torch.nn.functional as F
13 |
14 | from basicsr.models.archs.restormer_arch import Restormer
15 | from skimage import img_as_ubyte
16 | from natsort import natsorted
17 | from glob import glob
18 | import utils
19 | from pdb import set_trace as stx
20 |
parser = argparse.ArgumentParser(description='Gaussian Grayscale Denoising using Restormer')

parser.add_argument('--input_dir', default='./Datasets/test/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/Gaussian_Gray_Denoising/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/gaussian_gray_denoising', type=str, help='Path to weights')
parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. non_blind: separate model for each noise level.')
parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values')

args = parser.parse_args()

####### Load yaml #######
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader


def load_network_options(model_type, sigma):
    """Return the `network_g` constructor arguments for one model variant.

    The blind model has a single option file; each non-blind model has its own
    file named after its noise level. The file must be chosen per sigma:
    formatting the raw `--sigmas` string into the name (e.g. "15,25,50")
    points at a file that does not exist.

    Args:
        model_type: 'blind' or 'non_blind'.
        sigma: Noise level used to pick the non-blind option file.

    Returns:
        The `network_g` mapping from the option file with its 'type' entry removed.
    """
    if model_type == 'blind':
        yaml_file = 'Options/GaussianGrayDenoising_Restormer.yml'
    else:
        yaml_file = f'Options/GaussianGrayDenoising_RestormerSigma{sigma}.yml'
    # NOTE(review): yaml.load with the full Loader is fine for the repository's
    # own option files; do not point it at untrusted YAML.
    with open(yaml_file, mode='r') as stream:
        options = yaml.load(stream, Loader=Loader)
    network_options = options['network_g']
    network_options.pop('type')   # not a Restormer constructor argument
    return network_options
##########################

sigmas = np.int_(args.sigmas.split(','))

factor = 8

datasets = ['Set12', 'BSD68', 'Urban100']

for sigma_test in sigmas:
    print("Compute results for noise level",sigma_test)
    model_restoration = Restormer(**load_network_options(args.model_type, sigma_test))
    if args.model_type == 'blind':
        weights = args.weights+'_blind.pth'
    else:
        weights = args.weights + '_sigma' + str(sigma_test) +'.pth'
    checkpoint = torch.load(weights)
    model_restoration.load_state_dict(checkpoint['params'])

    print("===>Testing using weights: ",weights)
    print("------------------------------------------------")
    model_restoration.cuda()
    model_restoration = nn.DataParallel(model_restoration)
    model_restoration.eval()

    for dataset in datasets:
        inp_dir = os.path.join(args.input_dir, dataset)
        files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.tif')))
        result_dir_tmp = os.path.join(args.result_dir, args.model_type, dataset, str(sigma_test))
        os.makedirs(result_dir_tmp, exist_ok=True)

        with torch.no_grad():
            for file_ in tqdm(files):
                torch.cuda.ipc_collect()
                torch.cuda.empty_cache()
                img = np.float32(utils.load_gray_img(file_))/255.

                np.random.seed(seed=0)  # for reproducibility
                img += np.random.normal(0, sigma_test/255., img.shape)

                img = torch.from_numpy(img).permute(2,0,1)
                input_ = img.unsqueeze(0).cuda()

                # Padding in case images are not multiples of 8
                h,w = input_.shape[2], input_.shape[3]
                H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
                padh = H-h if h%factor!=0 else 0
                padw = W-w if w%factor!=0 else 0
                input_ = F.pad(input_, (0,padw,0,padh), 'reflect')

                restored = model_restoration(input_)

                # Unpad images to original dimensions
                restored = restored[:,:,:h,:w]

                restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()

                save_file = os.path.join(result_dir_tmp, os.path.split(file_)[-1])
                utils.save_gray_img(save_file, img_as_ubyte(restored))
104 |
--------------------------------------------------------------------------------
/Denoising/test_real_denoising_dnd.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import argparse
8 | from tqdm import tqdm
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import utils
14 |
15 | from basicsr.models.archs.restormer_arch import Restormer
16 | from skimage import img_as_ubyte
17 | import h5py
18 | import scipy.io as sio
19 | from pdb import set_trace as stx
20 |
21 | parser = argparse.ArgumentParser(description='Real Image Denoising using Restormer')
22 |
23 | parser.add_argument('--input_dir', default='./Datasets/test/DND/', type=str, help='Directory of validation images')
24 | parser.add_argument('--result_dir', default='./results/Real_Denoising/DND/', type=str, help='Directory for results')
25 | parser.add_argument('--weights', default='./pretrained_models/real_denoising.pth', type=str, help='Path to weights')
26 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')
27 |
28 | args = parser.parse_args()
29 |
####### Load yaml #######
yaml_file = 'Options/RealDenoising_Restormer.yml'
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

# Read the option file through a context manager so the handle is closed as
# soon as parsing finishes instead of being left to the garbage collector.
# NOTE(review): the full (C)Loader can construct arbitrary Python objects; this
# is only acceptable because the option file is a trusted, in-repo file.
with open(yaml_file, mode='r') as f:
    x = yaml.load(f, Loader=Loader)

# Drop the 'type' entry from the network options; the remaining entries are
# later expanded as keyword arguments via Restormer(**x['network_g']).
s = x['network_g'].pop('type')
##########################
43 |
44 | result_dir_mat = os.path.join(args.result_dir, 'mat')
45 | os.makedirs(result_dir_mat, exist_ok=True)
46 |
47 | if args.save_images:
48 | result_dir_png = os.path.join(args.result_dir, 'png')
49 | os.makedirs(result_dir_png, exist_ok=True)
50 |
51 | model_restoration = Restormer(**x['network_g'])
52 |
53 | checkpoint = torch.load(args.weights)
54 | model_restoration.load_state_dict(checkpoint['params'])
55 | print("===>Testing using weights: ",args.weights)
56 | model_restoration.cuda()
57 | model_restoration = nn.DataParallel(model_restoration)
58 | model_restoration.eval()
59 |
60 | israw = False
61 | eval_version="1.0"
62 |
63 | # Load info
64 | infos = h5py.File(os.path.join(args.input_dir, 'info.mat'), 'r')
65 | info = infos['info']
66 | bb = info['boundingboxes']
67 |
68 | # Process data
with torch.no_grad():
    for i in tqdm(range(50)):
        # One object slot per crop of this image. `np.object` was deprecated in
        # NumPy 1.20 and removed in 1.24; the builtin `object` is the same
        # dtype and works on every NumPy version.
        Idenoised = np.zeros((20,), dtype=object)
        filename = '%04d.mat'%(i+1)
        filepath = os.path.join(args.input_dir, 'images_srgb', filename)
        # np.array(...) copies the pixels out of the HDF5 dataset, so the file
        # can be closed immediately instead of leaking one handle per image.
        with h5py.File(filepath, 'r') as img:
            Inoisy = np.float32(np.array(img['InoisySRGB']).T)

        # bounding box
        ref = bb[0][i]
        boxes = np.array(info[ref]).T

        for k in range(20):
            # Slice bounds: columns 0 and 1 are shifted by -1, columns 2 and 3 are
            # used as-is (presumably 1-based inclusive boxes -> Python slices;
            # TODO confirm against the DND info.mat layout).
            idx = [int(boxes[k,0]-1),int(boxes[k,2]),int(boxes[k,1]-1),int(boxes[k,3])]
            noisy_patch = torch.from_numpy(Inoisy[idx[0]:idx[1],idx[2]:idx[3],:]).unsqueeze(0).permute(0,3,1,2).cuda()
            restored_patch = model_restoration(noisy_patch)
            # Clamp to [0, 1], move to CPU and go back to a channels-last numpy array.
            restored_patch = torch.clamp(restored_patch,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
            Idenoised[k] = restored_patch

            if args.save_images:
                save_file = os.path.join(result_dir_png, '%04d_%02d.png'%(i+1,k+1))
                denoised_img = img_as_ubyte(restored_patch)
                utils.save_img(save_file, denoised_img)

        # save denoised data
        sio.savemat(os.path.join(result_dir_mat, filename),
                    {"Idenoised": Idenoised,
                     "israw": israw,
                     "eval_version": eval_version},
                    )
99 |
--------------------------------------------------------------------------------
/Denoising/test_real_denoising_sidd.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import argparse
8 | from tqdm import tqdm
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import utils
14 |
15 | from basicsr.models.archs.restormer_arch import Restormer
16 | from skimage import img_as_ubyte
17 | import h5py
18 | import scipy.io as sio
19 | from pdb import set_trace as stx
20 |
21 | parser = argparse.ArgumentParser(description='Real Image Denoising using Restormer')
22 |
23 | parser.add_argument('--input_dir', default='./Datasets/test/SIDD/', type=str, help='Directory of validation images')
24 | parser.add_argument('--result_dir', default='./results/Real_Denoising/SIDD/', type=str, help='Directory for results')
25 | parser.add_argument('--weights', default='./pretrained_models/real_denoising.pth', type=str, help='Path to weights')
26 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')
27 |
28 | args = parser.parse_args()
29 |
####### Load yaml #######
yaml_file = 'Options/RealDenoising_Restormer.yml'
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

# Read the option file through a context manager so the handle is closed as
# soon as parsing finishes instead of being left to the garbage collector.
# NOTE(review): the full (C)Loader can construct arbitrary Python objects; this
# is only acceptable because the option file is a trusted, in-repo file.
with open(yaml_file, mode='r') as f:
    x = yaml.load(f, Loader=Loader)

# Drop the 'type' entry from the network options; the remaining entries are
# later expanded as keyword arguments via Restormer(**x['network_g']).
s = x['network_g'].pop('type')
##########################
43 |
44 | result_dir_mat = os.path.join(args.result_dir, 'mat')
45 | os.makedirs(result_dir_mat, exist_ok=True)
46 |
47 | if args.save_images:
48 | result_dir_png = os.path.join(args.result_dir, 'png')
49 | os.makedirs(result_dir_png, exist_ok=True)
50 |
51 | model_restoration = Restormer(**x['network_g'])
52 |
53 | checkpoint = torch.load(args.weights)
54 | model_restoration.load_state_dict(checkpoint['params'])
55 | print("===>Testing using weights: ",args.weights)
56 | model_restoration.cuda()
57 | model_restoration = nn.DataParallel(model_restoration)
58 | model_restoration.eval()
59 |
60 | # Process data
filepath = os.path.join(args.input_dir, 'ValidationNoisyBlocksSrgb.mat')
img = sio.loadmat(filepath)
Inoisy = np.float32(np.array(img['ValidationNoisyBlocksSrgb']))
Inoisy /=255.
restored = np.zeros_like(Inoisy)
# Take the loop bounds from the loaded array (indexed below as [i, k, :, :, :])
# instead of hard-coding 40 x 32, so other block counts work without an
# IndexError or silently unprocessed blocks.
num_images, num_blocks = Inoisy.shape[:2]
with torch.no_grad():
    for i in tqdm(range(num_images)):
        for k in range(num_blocks):
            noisy_patch = torch.from_numpy(Inoisy[i,k,:,:,:]).unsqueeze(0).permute(0,3,1,2).cuda()
            restored_patch = model_restoration(noisy_patch)
            # Convert explicitly to numpy (as the other test scripts do) rather
            # than relying on implicit tensor->ndarray coercion further down.
            restored_patch = torch.clamp(restored_patch,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
            restored[i,k,:,:,:] = restored_patch

            if args.save_images:
                save_file = os.path.join(result_dir_png, '%04d_%02d.png'%(i+1,k+1))
                utils.save_img(save_file, img_as_ubyte(restored_patch))

# save denoised data
sio.savemat(os.path.join(result_dir_mat, 'Idenoised.mat'), {"Idenoised": restored,})
80 |
--------------------------------------------------------------------------------
/Denoising/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
def calculate_psnr(img1, img2, border=0):
    """Peak signal-to-noise ratio in dB between two images, using a peak of 255.

    Args:
        img1, img2: arrays of identical shape; values are expected in [0, 255].
        border (int): number of pixels trimmed from every side of the first
            two axes before the comparison.

    Returns:
        float: PSNR in dB, or ``inf`` when the cropped images are identical.

    Raises:
        ValueError: if the two shapes differ.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    rows, cols = img1.shape[:2]
    crop = (slice(border, rows - border), slice(border, cols - border))

    # Promote to float64 before subtracting so integer inputs cannot wrap around.
    first = img1[crop].astype(np.float64)
    second = img2[crop].astype(np.float64)

    mse = np.mean((first - second)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's (per the original author's note)
    img1, img2: [0, 255]

    Args:
        img1, img2: arrays of identical shape: 2-D, or 3-D with 1 or 3 channels
            on the last axis.
        border (int): number of pixels trimmed from every side of the first
            two axes before the comparison.

    Returns:
        SSIM of the single channel, or the mean of the per-channel SSIMs for
        3-channel input.

    Raises:
        ValueError: if the shapes differ, or if the input is neither 2-D nor
            3-D with 1 or 3 channels.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = [ssim(img1[:,:,i], img2[:,:,i]) for i in range(3)]
            return np.array(ssims).mean()
        if img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    # Every unsupported layout ends up here. Previously a 3-D input with a
    # channel count other than 1 or 3 fell through and returned None silently.
    raise ValueError('Wrong input image dimensions.')
56 |
57 |
def ssim(img1, img2):
    """Mean SSIM of two single-channel images with a dynamic range of 255.

    Local statistics are computed with an 11-tap Gaussian window (sigma 1.5);
    5 pixels are cropped from every side of each filtered map so only the
    region where the window fits entirely inside the image is used.

    Args:
        img1, img2: 2-D arrays of the same shape.

    Returns:
        Mean of the SSIM map.
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    first = img1.astype(np.float64)
    second = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def _windowed(arr):
        # Gaussian-weighted local average, restricted to the "valid" region.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu1 = _windowed(first)
    mu2 = _windowed(second)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2

    # Local (co)variances as E[xy] - E[x]E[y].
    sigma1_sq = _windowed(first**2) - mu1_sq
    sigma2_sq = _windowed(second**2) - mu2_sq
    sigma12 = _windowed(first * second) - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    return (numerator / denominator).mean()
79 |
def load_img(filepath):
    """Read an image with OpenCV and return it in RGB channel order.

    Args:
        filepath (str): path of the image file.

    Returns:
        The image array after BGR -> RGB conversion.

    Raises:
        IOError: if OpenCV could not read the file. ``cv2.imread`` reports
            failure by returning None, which previously surfaced as an opaque
            error inside ``cv2.cvtColor``.
    """
    img = cv2.imread(filepath)
    if img is None:
        raise IOError('Failed to read image: {}'.format(filepath))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
82 |
def save_img(filepath, img):
    """Write an RGB image to disk with OpenCV.

    Args:
        filepath (str): destination path.
        img: RGB image array; it is converted to BGR before writing.

    Raises:
        IOError: if ``cv2.imwrite`` reports failure (it returns False rather
            than raising), so results are not silently lost.
    """
    if not cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
        raise IOError('Failed to write image: {}'.format(filepath))
85 |
def load_gray_img(filepath):
    """Read an image as grayscale and append a trailing channel axis.

    Args:
        filepath (str): path of the image file.

    Returns:
        The grayscale image with a new axis inserted at position 2.

    Raises:
        IOError: if OpenCV could not read the file. ``cv2.imread`` reports
            failure by returning None, which previously surfaced as an
            unrelated axis error from ``np.expand_dims``.
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise IOError('Failed to read image: {}'.format(filepath))
    return np.expand_dims(img, axis=2)
88 |
def save_gray_img(filepath, img):
    """Write an image array to disk with OpenCV, without colour conversion.

    Args:
        filepath (str): destination path.
        img: image array passed to ``cv2.imwrite`` unchanged.

    Raises:
        IOError: if ``cv2.imwrite`` reports failure (it returns False rather
            than raising), so results are not silently lost.
    """
    if not cv2.imwrite(filepath, img):
        raise IOError('Failed to write image: {}'.format(filepath))
91 |
--------------------------------------------------------------------------------
/Deraining/Datasets/README.md:
--------------------------------------------------------------------------------
1 | For training and testing, your directory structure should look like this:
2 |
3 | `Datasets`
4 | `├──train`
5 | `└──Rain13K`
6 | `├──input`
7 | `└──target`
8 | `└──test`
9 | `├──Test100`
10 | `├──input`
11 | `└──target`
12 | `├──Rain100H`
13 | `├──input`
14 | `└──target`
15 | `├──Rain100L`
16 | `├──input`
17 | `└──target`
18 | `├──Test1200`
19 | `├──input`
20 | `└──target`
21 | `└──Test2800`
22 | `├──input`
23 | `└──target`
24 |
--------------------------------------------------------------------------------
/Deraining/Options/Deraining_Restormer.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Deraining_Restormer
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage
13 | dataroot_gt: ./Deraining/Datasets/train/Rain13K/target
14 | dataroot_lq: ./Deraining/Datasets/train/Rain13K/input
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ### -------------Progressive training--------------------------
27 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
28 | iters: [92000,64000,48000,36000,36000,24000]
29 | gt_size: 384 # Max patch size for progressive training
30 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
31 | ### ------------------------------------------------------------
32 |
33 | ### ------- Training on single fixed-patch size 128x128---------
34 | # mini_batch_sizes: [8]
35 | # iters: [300000]
36 | # gt_size: 128
37 | # gt_sizes: [128]
38 | ### ------------------------------------------------------------
39 |
40 | dataset_enlarge_ratio: 1
41 | prefetch_mode: ~
42 |
43 | val:
44 | name: ValSet
45 | type: Dataset_PairedImage
46 | dataroot_gt: ./Deraining/Datasets/test/Rain100L/target
47 | dataroot_lq: ./Deraining/Datasets/test/Rain100L/input
48 | io_backend:
49 | type: disk
50 |
51 | # network structures
52 | network_g:
53 | type: Restormer
54 | inp_channels: 3
55 | out_channels: 3
56 | dim: 48
57 | num_blocks: [4,6,6,8]
58 | num_refinement_blocks: 4
59 | heads: [1,2,4,8]
60 | ffn_expansion_factor: 2.66
61 | bias: False
62 | LayerNorm_type: WithBias
63 | dual_pixel_task: False
64 |
65 |
66 | # path
67 | path:
68 | pretrain_network_g: ~
69 | strict_load_g: true
70 | resume_state: ~
71 |
72 | # training settings
73 | train:
74 | total_iter: 300000
75 | warmup_iter: -1 # no warm up
76 | use_grad_clip: true
77 |
78 | # Split 300k iterations into two cycles.
79 | # 1st cycle: fixed 3e-4 LR for 92k iters.
80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
81 | scheduler:
82 | type: CosineAnnealingRestartCyclicLR
83 | periods: [92000, 208000]
84 | restart_weights: [1,1]
85 | eta_mins: [0.0003,0.000001]
86 |
87 | mixing_augs:
88 | mixup: false
89 | mixup_beta: 1.2
90 | use_identity: true
91 |
92 | optim_g:
93 | type: AdamW
94 | lr: !!float 3e-4
95 | weight_decay: !!float 1e-4
96 | betas: [0.9, 0.999]
97 |
98 | # losses
99 | pixel_opt:
100 | type: L1Loss
101 | loss_weight: 1
102 | reduction: mean
103 |
104 | # validation settings
105 | val:
106 | window_size: 8
107 | val_freq: !!float 4e3
108 | save_img: false
109 | rgb2bgr: true
110 | use_image: true
111 | max_minibatch: 8
112 |
113 | metrics:
114 | psnr: # metric name, can be arbitrary
115 | type: calculate_psnr
116 | crop_border: 0
117 | test_y_channel: true
118 |
119 | # logging settings
120 | logger:
121 | print_freq: 1000
122 | save_checkpoint_freq: !!float 4e3
123 | use_tb_logger: true
124 | wandb:
125 | project: ~
126 | resume_id: ~
127 |
128 | # dist training settings
129 | dist_params:
130 | backend: nccl
131 | port: 29500
132 |
--------------------------------------------------------------------------------
/Deraining/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Training
3 |
4 | 1. To download Rain13K training and testing data, run
5 | ```
6 | python download_data.py --data train-test
7 | ```
8 |
9 | 2. To train Restormer with default settings, run
10 | ```
11 | cd Restormer
12 | ./train.sh Deraining/Options/Deraining_Restormer.yml
13 | ```
14 |
15 | **Note:** The above training script uses 8 GPUs by default. To use any other number of GPUs, modify [Restormer/train.sh](../train.sh) and [Deraining/Options/Deraining_Restormer.yml](Options/Deraining_Restormer.yml)
16 |
17 | ## Evaluation
18 |
19 | 1. Download the pre-trained [model](https://drive.google.com/drive/folders/1ZEDDEVW0UgkpWi-N4Lj_JUoVChGXCu_u?usp=sharing) and place it in `./pretrained_models/`
20 |
21 | 2. Download test datasets (Test100, Rain100H, Rain100L, Test1200, Test2800), run
22 | ```
23 | python download_data.py --data test
24 | ```
25 |
26 | 3. Testing
27 | ```
28 | python test.py
29 | ```
30 |
31 | #### To reproduce PSNR/SSIM scores of Table 1, run
32 |
33 | ```
34 | evaluate_PSNR_SSIM.m
35 | ```
36 |
--------------------------------------------------------------------------------
/Deraining/download_data.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | ## Download training and testing data for image deraining task
6 | import os
7 | # import gdown
8 | import shutil
9 |
10 | import argparse
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--data', type=str, required=True, help='train, test or train-test')
14 | args = parser.parse_args()
15 |
16 | ### Google drive IDs ######
17 | rain13k_train = '14BidJeG4nSNuFNFDf99K-7eErCq4i47t' ## https://drive.google.com/file/d/14BidJeG4nSNuFNFDf99K-7eErCq4i47t/view?usp=sharing
18 | rain13k_test = '1P_-RAvltEoEhfT-9GrWRdpEi6NSswTs8' ## https://drive.google.com/file/d/1P_-RAvltEoEhfT-9GrWRdpEi6NSswTs8/view?usp=sharing
19 |
# Each '-'-separated token of --data selects one archive to download and unpack.
for data in args.data.split('-'):
    if data == 'train':
        print('Rain13K Training Data!')
        # gdown.download(id=rain13k_train, output='Datasets/train.zip', quiet=False)
        status = os.system(f'gdrive download {rain13k_train} --path Datasets/')
        # os.system reports failure only through its return value; stop here
        # instead of failing later with a confusing "unknown archive" error.
        if status != 0:
            raise SystemExit(f'gdrive download failed for Rain13K training data (status {status})')
        print('Extracting Rain13K data...')
        shutil.unpack_archive('Datasets/train.zip', 'Datasets')
        os.remove('Datasets/train.zip')

    elif data == 'test':
        print('Download Deraining Testing Data')
        # gdown.download(id=rain13k_test, output='Datasets/test.zip', quiet=False)
        status = os.system(f'gdrive download {rain13k_test} --path Datasets/')
        if status != 0:
            raise SystemExit(f'gdrive download failed for deraining test data (status {status})')
        print('Extracting test data...')
        shutil.unpack_archive('Datasets/test.zip', 'Datasets')
        os.remove('Datasets/test.zip')

    else:
        # Previously an unknown token was ignored without any feedback.
        print(f"Ignoring unrecognized --data value '{data}' (expected 'train' or 'test')")
36 |
37 |
38 | # print('Download completed successfully!')
39 |
40 |
41 |
--------------------------------------------------------------------------------
/Deraining/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | The pre-trained deraining model is available [here](https://drive.google.com/drive/folders/1ZEDDEVW0UgkpWi-N4Lj_JUoVChGXCu_u?usp=sharing).
--------------------------------------------------------------------------------
/Deraining/test.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 |
6 |
7 | import numpy as np
8 | import os
9 | import argparse
10 | from tqdm import tqdm
11 |
12 | import torch.nn as nn
13 | import torch
14 | import torch.nn.functional as F
15 | import utils
16 |
17 | from natsort import natsorted
18 | from glob import glob
19 | from basicsr.models.archs.restormer_arch import Restormer
20 | from skimage import img_as_ubyte
21 | from pdb import set_trace as stx
22 |
23 | parser = argparse.ArgumentParser(description='Image Deraining using Restormer')
24 |
25 | parser.add_argument('--input_dir', default='./Datasets/', type=str, help='Directory of validation images')
26 | parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
27 | parser.add_argument('--weights', default='./pretrained_models/deraining.pth', type=str, help='Path to weights')
28 |
29 | args = parser.parse_args()
30 |
####### Load yaml #######
yaml_file = 'Options/Deraining_Restormer.yml'
import yaml

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

# Read the option file through a context manager so the handle is closed as
# soon as parsing finishes instead of being left to the garbage collector.
# NOTE(review): the full (C)Loader can construct arbitrary Python objects; this
# is only acceptable because the option file is a trusted, in-repo file.
with open(yaml_file, mode='r') as f:
    x = yaml.load(f, Loader=Loader)

# Drop the 'type' entry from the network options; the remaining entries are
# later expanded as keyword arguments via Restormer(**x['network_g']).
s = x['network_g'].pop('type')
##########################
44 |
45 | model_restoration = Restormer(**x['network_g'])
46 |
47 | checkpoint = torch.load(args.weights)
48 | model_restoration.load_state_dict(checkpoint['params'])
49 | print("===>Testing using weights: ",args.weights)
50 | model_restoration.cuda()
51 | model_restoration = nn.DataParallel(model_restoration)
52 | model_restoration.eval()
53 |
54 |
# Spatial sizes are padded up to a multiple of `factor` before inference.
factor = 8
datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800']

for dataset in datasets:
    result_dir = os.path.join(args.result_dir, dataset)
    os.makedirs(result_dir, exist_ok=True)

    inp_dir = os.path.join(args.input_dir, 'test', dataset, 'input')
    png_files = glob(os.path.join(inp_dir, '*.png'))
    jpg_files = glob(os.path.join(inp_dir, '*.jpg'))
    files = natsorted(png_files + jpg_files)

    with torch.no_grad():
        for file_ in tqdm(files):
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

            # Load as float32 in [0, 1] and reorder HWC -> 1 x C x H x W on the GPU.
            rgb = np.float32(utils.load_img(file_))/255.
            input_ = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0).cuda()

            # Padding in case images are not multiples of 8
            h, w = input_.shape[2], input_.shape[3]
            padh = (-h) % factor  # 0 when h is already a multiple of factor
            padw = (-w) % factor
            input_ = F.pad(input_, (0,padw,0,padh), 'reflect')

            # Run the model, then unpad back to the original dimensions
            restored = model_restoration(input_)[:,:,:h,:w]

            restored = torch.clamp(restored,0,1).cpu().detach()
            restored = restored.permute(0, 2, 3, 1).squeeze(0).numpy()

            # Results are always written as PNG, keeping the input's base name.
            stem = os.path.splitext(os.path.basename(file_))[0]
            utils.save_img(os.path.join(result_dir, stem+'.png'), img_as_ubyte(restored))
88 |
--------------------------------------------------------------------------------
/Deraining/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
def calculate_psnr(img1, img2, border=0):
    """Peak signal-to-noise ratio in dB between two images, using a peak of 255.

    Args:
        img1, img2: arrays of identical shape; values are expected in [0, 255].
        border (int): number of pixels trimmed from every side of the first
            two axes before the comparison.

    Returns:
        float: PSNR in dB, or ``inf`` when the cropped images are identical.

    Raises:
        ValueError: if the two shapes differ.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    rows, cols = img1.shape[:2]
    crop = (slice(border, rows - border), slice(border, cols - border))

    # Promote to float64 before subtracting so integer inputs cannot wrap around.
    first = img1[crop].astype(np.float64)
    second = img2[crop].astype(np.float64)

    mse = np.mean((first - second)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's (per the original author's note)
    img1, img2: [0, 255]

    Args:
        img1, img2: arrays of identical shape: 2-D, or 3-D with 1 or 3 channels
            on the last axis.
        border (int): number of pixels trimmed from every side of the first
            two axes before the comparison.

    Returns:
        SSIM of the single channel, or the mean of the per-channel SSIMs for
        3-channel input.

    Raises:
        ValueError: if the shapes differ, or if the input is neither 2-D nor
            3-D with 1 or 3 channels.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = [ssim(img1[:,:,i], img2[:,:,i]) for i in range(3)]
            return np.array(ssims).mean()
        if img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    # Every unsupported layout ends up here. Previously a 3-D input with a
    # channel count other than 1 or 3 fell through and returned None silently.
    raise ValueError('Wrong input image dimensions.')
56 |
57 |
def ssim(img1, img2):
    """Mean SSIM of two single-channel images with a dynamic range of 255.

    Local statistics are computed with an 11-tap Gaussian window (sigma 1.5);
    5 pixels are cropped from every side of each filtered map so only the
    region where the window fits entirely inside the image is used.

    Args:
        img1, img2: 2-D arrays of the same shape.

    Returns:
        Mean of the SSIM map.
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    first = img1.astype(np.float64)
    second = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def _windowed(arr):
        # Gaussian-weighted local average, restricted to the "valid" region.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu1 = _windowed(first)
    mu2 = _windowed(second)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2

    # Local (co)variances as E[xy] - E[x]E[y].
    sigma1_sq = _windowed(first**2) - mu1_sq
    sigma2_sq = _windowed(second**2) - mu2_sq
    sigma12 = _windowed(first * second) - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    return (numerator / denominator).mean()
79 |
def load_img(filepath):
    """Read an image with OpenCV and return it in RGB channel order.

    Args:
        filepath (str): path of the image file.

    Returns:
        The image array after BGR -> RGB conversion.

    Raises:
        IOError: if OpenCV could not read the file. ``cv2.imread`` reports
            failure by returning None, which previously surfaced as an opaque
            error inside ``cv2.cvtColor``.
    """
    img = cv2.imread(filepath)
    if img is None:
        raise IOError('Failed to read image: {}'.format(filepath))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
82 |
def save_img(filepath, img):
    """Write an RGB image to disk with OpenCV.

    Args:
        filepath (str): destination path.
        img: RGB image array; it is converted to BGR before writing.

    Raises:
        IOError: if ``cv2.imwrite`` reports failure (it returns False rather
            than raising), so results are not silently lost.
    """
    if not cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
        raise IOError('Failed to write image: {}'.format(filepath))
85 |
def load_gray_img(filepath):
    """Read an image as grayscale and append a trailing channel axis.

    Args:
        filepath (str): path of the image file.

    Returns:
        The grayscale image with a new axis inserted at position 2.

    Raises:
        IOError: if OpenCV could not read the file. ``cv2.imread`` reports
            failure by returning None, which previously surfaced as an
            unrelated axis error from ``np.expand_dims``.
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise IOError('Failed to read image: {}'.format(filepath))
    return np.expand_dims(img, axis=2)
88 |
def save_gray_img(filepath, img):
    """Write an image array to disk with OpenCV, without colour conversion.

    Args:
        filepath (str): destination path.
        img: image array passed to ``cv2.imwrite`` unchanged.

    Raises:
        IOError: if ``cv2.imwrite`` reports failure (it returns False rather
            than raising), so results are not silently lost.
    """
    if not cv2.imwrite(filepath, img):
        raise IOError('Failed to write image: {}'.format(filepath))
91 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | This repository is built in PyTorch 1.8.1 and tested on Ubuntu 16.04 environment (Python3.7, CUDA10.2, cuDNN7.6).
4 | Follow these instructions:
5 |
6 | 1. Clone our repository
7 | ```
8 | git clone https://github.com/swz30/Restormer.git
9 | cd Restormer
10 | ```
11 |
12 | 2. Make conda environment
13 | ```
14 | conda create -n pytorch181 python=3.7
15 | conda activate pytorch181
16 | ```
17 |
18 | 3. Install dependencies
19 | ```
20 | conda install pytorch=1.8 torchvision cudatoolkit=10.2 -c pytorch
21 | pip install matplotlib scikit-learn scikit-image opencv-python yacs joblib natsort h5py tqdm
22 | pip install einops gdown addict future lmdb numpy pyyaml requests scipy tb-nightly yapf lpips
23 | ```
24 |
25 | 4. Install basicsr
26 | ```
27 | python setup.py develop --no_cuda_ext
28 | ```
29 |
30 | ### Download datasets from Google Drive
31 |
32 | To be able to download datasets automatically you would need `go` and `gdrive` installed.
33 |
34 | 1. You can install `go` with the following
35 | ```
36 | curl -O https://storage.googleapis.com/golang/go1.11.1.linux-amd64.tar.gz
37 | mkdir -p ~/installed
38 | tar -C ~/installed -xzf go1.11.1.linux-amd64.tar.gz
39 | mkdir -p ~/go
40 | ```
41 |
42 | 2. Add the lines in `~/.bashrc`
43 | ```
44 | export GOPATH=$HOME/go
45 | export PATH=$PATH:$HOME/go/bin:$HOME/installed/go/bin
46 | ```
47 |
48 | 3. Install `gdrive` using
49 | ```
50 | go get github.com/prasmussen/gdrive
51 | ```
52 |
53 | 4. Close current terminal and open a new terminal.
54 |
--------------------------------------------------------------------------------
/basicsr/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.utils.data
6 | from functools import partial
7 | from os import path as osp
8 |
9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
10 | from basicsr.utils import get_root_logger, scandir
11 | from basicsr.utils.dist_util import get_dist_info
12 |
13 | __all__ = ['create_dataset', 'create_dataloader']
14 |
15 | # automatically scan and import dataset modules
16 | # scan all the files under the data folder with '_dataset' in file names
17 | data_folder = osp.dirname(osp.abspath(__file__))
18 | dataset_filenames = [
19 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
20 | if v.endswith('_dataset.py')
21 | ]
22 | # import all the dataset modules
23 | _dataset_modules = [
24 | importlib.import_module(f'basicsr.data.{file_name}')
25 | for file_name in dataset_filenames
26 | ]
27 |
28 |
def create_dataset(dataset_opt):
    """Create dataset.

    Args:
        dataset_opt (dict): Configuration for dataset. It contains:
            name (str): Dataset name.
            type (str): Dataset type.

    Returns:
        The dataset instance built as ``dataset_cls(dataset_opt)``.

    Raises:
        ValueError: if no imported dataset module exposes an attribute named
            ``dataset_opt['type']``.
    """
    dataset_type = dataset_opt['type']

    # dynamic instantiation: take the first module that defines the class.
    # Start from None so that an empty `_dataset_modules` list reaches the
    # ValueError below instead of failing with UnboundLocalError.
    dataset_cls = None
    for module in _dataset_modules:
        dataset_cls = getattr(module, dataset_type, None)
        if dataset_cls is not None:
            break
    if dataset_cls is None:
        raise ValueError(f'Dataset {dataset_type} is not found.')

    dataset = dataset_cls(dataset_opt)

    logger = get_root_logger()
    logger.info(
        f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
        'is created.')
    return dataset
54 |
55 |
def create_dataloader(dataset,
                      dataset_opt,
                      num_gpu=1,
                      dist=False,
                      sampler=None,
                      seed=None):
    """Build the data loader for a dataset according to its phase.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. Keys read here:
            phase (str): 'train', 'val' or 'test'.
            num_worker_per_gpu (int): Workers per GPU (train phase only).
            batch_size_per_gpu (int): Batch size per GPU (train phase only).
            pin_memory (bool, optional): Default: False.
            prefetch_mode (str, optional): 'cpu' selects PrefetchDataLoader.
            num_prefetch_queue (int, optional): Default: 1.
        num_gpu (int): Number of GPUs. Used only in the non-distributed train
            phase. Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None

    Returns:
        A PrefetchDataLoader when prefetch_mode is 'cpu', otherwise a
        torch.utils.data.DataLoader.

    Raises:
        ValueError: if the phase is not 'train', 'val' or 'test'.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()

    if phase == 'train':
        # Distributed: per-GPU values are used as-is (one process per GPU).
        # Non-distributed: scale by the GPU count, treating 0 GPUs as 1.
        scale = 1 if (dist or num_gpu == 0) else num_gpu
        batch_size = dataset_opt['batch_size_per_gpu'] * scale
        num_workers = dataset_opt['num_worker_per_gpu'] * scale

        init_fn = None
        if seed is not None:
            init_fn = partial(
                worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)

        loader_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            # Shuffle in the loader only when no sampler is supplied.
            'shuffle': sampler is None,
            'num_workers': num_workers,
            'sampler': sampler,
            'drop_last': True,
            'worker_init_fn': init_fn,
        }
    elif phase in ('val', 'test'):  # validation
        loader_kwargs = {
            'dataset': dataset,
            'batch_size': 1,
            'shuffle': False,
            'num_workers': 0,
        }
    else:
        raise ValueError(f'Wrong dataset phase: {phase}. '
                         "Supported ones are 'train', 'val' and 'test'.")

    loader_kwargs['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode != 'cpu':
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**loader_kwargs)

    # CPUPrefetcher
    num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
    logger = get_root_logger()
    logger.info(f'Use {prefetch_mode} prefetch dataloader: '
                f'num_prefetch_queue = {num_prefetch_queue}')
    return PrefetchDataLoader(
        num_prefetch_queue=num_prefetch_queue, **loader_kwargs)
120 |
121 |
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy's and Python's global RNGs for one data-loading worker.

    The seed used is ``num_workers * rank + worker_id + seed``.

    Args:
        worker_id (int): index of the worker.
        num_workers (int): number of workers.
        rank (int): rank of the current process.
        seed (int): base seed.
    """
    worker_offset = num_workers * rank + worker_id
    derived_seed = worker_offset + seed
    for seed_rng in (np.random.seed, random.seed):
        seed_rng(derived_seed)
127 |
--------------------------------------------------------------------------------
/basicsr/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.utils.data.sampler import Sampler
4 |
5 |
class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    Supports enlarging the dataset for iteration-based training, which saves
    the time otherwise spent restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Samples per replica, rounded up so all replicas draw equally many.
        enlarged_len = len(self.dataset) * ratio
        self.num_samples = math.ceil(enlarged_len / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Seeding with the epoch makes the shuffle deterministic per epoch.
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        shuffled = torch.randperm(self.total_size, generator=rng).tolist()

        # Wrap positions beyond the real dataset length back into range.
        dataset_len = len(self.dataset)
        wrapped = [position % dataset_len for position in shuffled]

        # subsample: every num_replicas-th entry, starting at this rank
        own = wrapped[self.rank:self.total_size:self.num_replicas]
        assert len(own) == self.num_samples

        return iter(own)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
50 |
--------------------------------------------------------------------------------
/basicsr/data/ffhq_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.transforms import augment
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor
7 |
8 |
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.

    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # The file client (io backend) is built on the first __getitem__.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] != 'lmdb':
            # FFHQ has 70000 images in total, named by zero-padded index.
            self.paths = [
                osp.join(self.gt_folder, f'{idx:08d}.png')
                for idx in range(70000)
            ]
            return

        # lmdb backend: keys come from the meta_info.txt inside the folder.
        self.io_backend_opt['db_paths'] = self.gt_folder
        if not self.gt_folder.endswith('.lmdb'):
            raise ValueError("'dataroot_gt' should end with '.lmdb', "
                             f'but received {self.gt_folder}')
        meta_path = osp.join(self.gt_folder, 'meta_info.txt')
        with open(meta_path) as meta_file:
            # The key is everything before the first '.' on each line.
            self.paths = [entry.split('.')[0] for entry in meta_file]

    def __getitem__(self, index):
        if self.file_client is None:
            # 'type' is removed from the backend options; the remaining
            # entries are forwarded as keyword arguments.
            backend_type = self.io_backend_opt.pop('type')
            self.file_client = FileClient(backend_type, **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        raw_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(raw_bytes, float32=True)

        # horizontal flip controlled by opt['use_hflip']; rotation disabled
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize in place with the configured mean / std
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
66 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDS4_test_GT.txt:
--------------------------------------------------------------------------------
1 | 000 100 (720,1280,3)
2 | 011 100 (720,1280,3)
3 | 015 100 (720,1280,3)
4 | 020 100 (720,1280,3)
5 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDS_GT.txt:
--------------------------------------------------------------------------------
1 | 000 100 (720,1280,3)
2 | 001 100 (720,1280,3)
3 | 002 100 (720,1280,3)
4 | 003 100 (720,1280,3)
5 | 004 100 (720,1280,3)
6 | 005 100 (720,1280,3)
7 | 006 100 (720,1280,3)
8 | 007 100 (720,1280,3)
9 | 008 100 (720,1280,3)
10 | 009 100 (720,1280,3)
11 | 010 100 (720,1280,3)
12 | 011 100 (720,1280,3)
13 | 012 100 (720,1280,3)
14 | 013 100 (720,1280,3)
15 | 014 100 (720,1280,3)
16 | 015 100 (720,1280,3)
17 | 016 100 (720,1280,3)
18 | 017 100 (720,1280,3)
19 | 018 100 (720,1280,3)
20 | 019 100 (720,1280,3)
21 | 020 100 (720,1280,3)
22 | 021 100 (720,1280,3)
23 | 022 100 (720,1280,3)
24 | 023 100 (720,1280,3)
25 | 024 100 (720,1280,3)
26 | 025 100 (720,1280,3)
27 | 026 100 (720,1280,3)
28 | 027 100 (720,1280,3)
29 | 028 100 (720,1280,3)
30 | 029 100 (720,1280,3)
31 | 030 100 (720,1280,3)
32 | 031 100 (720,1280,3)
33 | 032 100 (720,1280,3)
34 | 033 100 (720,1280,3)
35 | 034 100 (720,1280,3)
36 | 035 100 (720,1280,3)
37 | 036 100 (720,1280,3)
38 | 037 100 (720,1280,3)
39 | 038 100 (720,1280,3)
40 | 039 100 (720,1280,3)
41 | 040 100 (720,1280,3)
42 | 041 100 (720,1280,3)
43 | 042 100 (720,1280,3)
44 | 043 100 (720,1280,3)
45 | 044 100 (720,1280,3)
46 | 045 100 (720,1280,3)
47 | 046 100 (720,1280,3)
48 | 047 100 (720,1280,3)
49 | 048 100 (720,1280,3)
50 | 049 100 (720,1280,3)
51 | 050 100 (720,1280,3)
52 | 051 100 (720,1280,3)
53 | 052 100 (720,1280,3)
54 | 053 100 (720,1280,3)
55 | 054 100 (720,1280,3)
56 | 055 100 (720,1280,3)
57 | 056 100 (720,1280,3)
58 | 057 100 (720,1280,3)
59 | 058 100 (720,1280,3)
60 | 059 100 (720,1280,3)
61 | 060 100 (720,1280,3)
62 | 061 100 (720,1280,3)
63 | 062 100 (720,1280,3)
64 | 063 100 (720,1280,3)
65 | 064 100 (720,1280,3)
66 | 065 100 (720,1280,3)
67 | 066 100 (720,1280,3)
68 | 067 100 (720,1280,3)
69 | 068 100 (720,1280,3)
70 | 069 100 (720,1280,3)
71 | 070 100 (720,1280,3)
72 | 071 100 (720,1280,3)
73 | 072 100 (720,1280,3)
74 | 073 100 (720,1280,3)
75 | 074 100 (720,1280,3)
76 | 075 100 (720,1280,3)
77 | 076 100 (720,1280,3)
78 | 077 100 (720,1280,3)
79 | 078 100 (720,1280,3)
80 | 079 100 (720,1280,3)
81 | 080 100 (720,1280,3)
82 | 081 100 (720,1280,3)
83 | 082 100 (720,1280,3)
84 | 083 100 (720,1280,3)
85 | 084 100 (720,1280,3)
86 | 085 100 (720,1280,3)
87 | 086 100 (720,1280,3)
88 | 087 100 (720,1280,3)
89 | 088 100 (720,1280,3)
90 | 089 100 (720,1280,3)
91 | 090 100 (720,1280,3)
92 | 091 100 (720,1280,3)
93 | 092 100 (720,1280,3)
94 | 093 100 (720,1280,3)
95 | 094 100 (720,1280,3)
96 | 095 100 (720,1280,3)
97 | 096 100 (720,1280,3)
98 | 097 100 (720,1280,3)
99 | 098 100 (720,1280,3)
100 | 099 100 (720,1280,3)
101 | 100 100 (720,1280,3)
102 | 101 100 (720,1280,3)
103 | 102 100 (720,1280,3)
104 | 103 100 (720,1280,3)
105 | 104 100 (720,1280,3)
106 | 105 100 (720,1280,3)
107 | 106 100 (720,1280,3)
108 | 107 100 (720,1280,3)
109 | 108 100 (720,1280,3)
110 | 109 100 (720,1280,3)
111 | 110 100 (720,1280,3)
112 | 111 100 (720,1280,3)
113 | 112 100 (720,1280,3)
114 | 113 100 (720,1280,3)
115 | 114 100 (720,1280,3)
116 | 115 100 (720,1280,3)
117 | 116 100 (720,1280,3)
118 | 117 100 (720,1280,3)
119 | 118 100 (720,1280,3)
120 | 119 100 (720,1280,3)
121 | 120 100 (720,1280,3)
122 | 121 100 (720,1280,3)
123 | 122 100 (720,1280,3)
124 | 123 100 (720,1280,3)
125 | 124 100 (720,1280,3)
126 | 125 100 (720,1280,3)
127 | 126 100 (720,1280,3)
128 | 127 100 (720,1280,3)
129 | 128 100 (720,1280,3)
130 | 129 100 (720,1280,3)
131 | 130 100 (720,1280,3)
132 | 131 100 (720,1280,3)
133 | 132 100 (720,1280,3)
134 | 133 100 (720,1280,3)
135 | 134 100 (720,1280,3)
136 | 135 100 (720,1280,3)
137 | 136 100 (720,1280,3)
138 | 137 100 (720,1280,3)
139 | 138 100 (720,1280,3)
140 | 139 100 (720,1280,3)
141 | 140 100 (720,1280,3)
142 | 141 100 (720,1280,3)
143 | 142 100 (720,1280,3)
144 | 143 100 (720,1280,3)
145 | 144 100 (720,1280,3)
146 | 145 100 (720,1280,3)
147 | 146 100 (720,1280,3)
148 | 147 100 (720,1280,3)
149 | 148 100 (720,1280,3)
150 | 149 100 (720,1280,3)
151 | 150 100 (720,1280,3)
152 | 151 100 (720,1280,3)
153 | 152 100 (720,1280,3)
154 | 153 100 (720,1280,3)
155 | 154 100 (720,1280,3)
156 | 155 100 (720,1280,3)
157 | 156 100 (720,1280,3)
158 | 157 100 (720,1280,3)
159 | 158 100 (720,1280,3)
160 | 159 100 (720,1280,3)
161 | 160 100 (720,1280,3)
162 | 161 100 (720,1280,3)
163 | 162 100 (720,1280,3)
164 | 163 100 (720,1280,3)
165 | 164 100 (720,1280,3)
166 | 165 100 (720,1280,3)
167 | 166 100 (720,1280,3)
168 | 167 100 (720,1280,3)
169 | 168 100 (720,1280,3)
170 | 169 100 (720,1280,3)
171 | 170 100 (720,1280,3)
172 | 171 100 (720,1280,3)
173 | 172 100 (720,1280,3)
174 | 173 100 (720,1280,3)
175 | 174 100 (720,1280,3)
176 | 175 100 (720,1280,3)
177 | 176 100 (720,1280,3)
178 | 177 100 (720,1280,3)
179 | 178 100 (720,1280,3)
180 | 179 100 (720,1280,3)
181 | 180 100 (720,1280,3)
182 | 181 100 (720,1280,3)
183 | 182 100 (720,1280,3)
184 | 183 100 (720,1280,3)
185 | 184 100 (720,1280,3)
186 | 185 100 (720,1280,3)
187 | 186 100 (720,1280,3)
188 | 187 100 (720,1280,3)
189 | 188 100 (720,1280,3)
190 | 189 100 (720,1280,3)
191 | 190 100 (720,1280,3)
192 | 191 100 (720,1280,3)
193 | 192 100 (720,1280,3)
194 | 193 100 (720,1280,3)
195 | 194 100 (720,1280,3)
196 | 195 100 (720,1280,3)
197 | 196 100 (720,1280,3)
198 | 197 100 (720,1280,3)
199 | 198 100 (720,1280,3)
200 | 199 100 (720,1280,3)
201 | 200 100 (720,1280,3)
202 | 201 100 (720,1280,3)
203 | 202 100 (720,1280,3)
204 | 203 100 (720,1280,3)
205 | 204 100 (720,1280,3)
206 | 205 100 (720,1280,3)
207 | 206 100 (720,1280,3)
208 | 207 100 (720,1280,3)
209 | 208 100 (720,1280,3)
210 | 209 100 (720,1280,3)
211 | 210 100 (720,1280,3)
212 | 211 100 (720,1280,3)
213 | 212 100 (720,1280,3)
214 | 213 100 (720,1280,3)
215 | 214 100 (720,1280,3)
216 | 215 100 (720,1280,3)
217 | 216 100 (720,1280,3)
218 | 217 100 (720,1280,3)
219 | 218 100 (720,1280,3)
220 | 219 100 (720,1280,3)
221 | 220 100 (720,1280,3)
222 | 221 100 (720,1280,3)
223 | 222 100 (720,1280,3)
224 | 223 100 (720,1280,3)
225 | 224 100 (720,1280,3)
226 | 225 100 (720,1280,3)
227 | 226 100 (720,1280,3)
228 | 227 100 (720,1280,3)
229 | 228 100 (720,1280,3)
230 | 229 100 (720,1280,3)
231 | 230 100 (720,1280,3)
232 | 231 100 (720,1280,3)
233 | 232 100 (720,1280,3)
234 | 233 100 (720,1280,3)
235 | 234 100 (720,1280,3)
236 | 235 100 (720,1280,3)
237 | 236 100 (720,1280,3)
238 | 237 100 (720,1280,3)
239 | 238 100 (720,1280,3)
240 | 239 100 (720,1280,3)
241 | 240 100 (720,1280,3)
242 | 241 100 (720,1280,3)
243 | 242 100 (720,1280,3)
244 | 243 100 (720,1280,3)
245 | 244 100 (720,1280,3)
246 | 245 100 (720,1280,3)
247 | 246 100 (720,1280,3)
248 | 247 100 (720,1280,3)
249 | 248 100 (720,1280,3)
250 | 249 100 (720,1280,3)
251 | 250 100 (720,1280,3)
252 | 251 100 (720,1280,3)
253 | 252 100 (720,1280,3)
254 | 253 100 (720,1280,3)
255 | 254 100 (720,1280,3)
256 | 255 100 (720,1280,3)
257 | 256 100 (720,1280,3)
258 | 257 100 (720,1280,3)
259 | 258 100 (720,1280,3)
260 | 259 100 (720,1280,3)
261 | 260 100 (720,1280,3)
262 | 261 100 (720,1280,3)
263 | 262 100 (720,1280,3)
264 | 263 100 (720,1280,3)
265 | 264 100 (720,1280,3)
266 | 265 100 (720,1280,3)
267 | 266 100 (720,1280,3)
268 | 267 100 (720,1280,3)
269 | 268 100 (720,1280,3)
270 | 269 100 (720,1280,3)
271 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt:
--------------------------------------------------------------------------------
1 | 240 100 (720,1280,3)
2 | 241 100 (720,1280,3)
3 | 246 100 (720,1280,3)
4 | 257 100 (720,1280,3)
5 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt:
--------------------------------------------------------------------------------
1 | 240 100 (720,1280,3)
2 | 241 100 (720,1280,3)
3 | 242 100 (720,1280,3)
4 | 243 100 (720,1280,3)
5 | 244 100 (720,1280,3)
6 | 245 100 (720,1280,3)
7 | 246 100 (720,1280,3)
8 | 247 100 (720,1280,3)
9 | 248 100 (720,1280,3)
10 | 249 100 (720,1280,3)
11 | 250 100 (720,1280,3)
12 | 251 100 (720,1280,3)
13 | 252 100 (720,1280,3)
14 | 253 100 (720,1280,3)
15 | 254 100 (720,1280,3)
16 | 255 100 (720,1280,3)
17 | 256 100 (720,1280,3)
18 | 257 100 (720,1280,3)
19 | 258 100 (720,1280,3)
20 | 259 100 (720,1280,3)
21 | 260 100 (720,1280,3)
22 | 261 100 (720,1280,3)
23 | 262 100 (720,1280,3)
24 | 263 100 (720,1280,3)
25 | 264 100 (720,1280,3)
26 | 265 100 (720,1280,3)
27 | 266 100 (720,1280,3)
28 | 267 100 (720,1280,3)
29 | 268 100 (720,1280,3)
30 | 269 100 (720,1280,3)
31 |
--------------------------------------------------------------------------------
/basicsr/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
1 | import queue as Queue
2 | import threading
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 |
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    A daemon thread drains ``generator`` into a bounded queue while the
    consumer iterates over this object.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Behaviour notes:
        * End of stream is signalled with a private sentinel object, so a
          ``None`` produced by ``generator`` is delivered like any other
          item instead of ending the iteration.
        * An exception raised by ``generator`` in the worker thread is
          re-raised in the consumer by ``__next__`` once the items
          produced before it have been consumed, instead of leaving the
          consumer blocked forever.
        * After exhaustion every further ``__next__`` raises
          ``StopIteration`` immediately rather than blocking on the empty
          queue.

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    # Unique end-of-stream marker; compared by identity only.
    _END_OF_STREAM = object()

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        # Both must exist before the thread starts running.
        self._prefetch_error = None
        self._prefetch_done = False
        self.start()

    def run(self):
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:
            # Hand the failure over to the consuming thread.
            self._prefetch_error = exc
        finally:
            # Always wake the consumer, even when the generator failed.
            self.queue.put(self._END_OF_STREAM)

    def __next__(self):
        if self._prefetch_done:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._END_OF_STREAM:
            self._prefetch_done = True
            error, self._prefetch_error = self._prefetch_error, None
            if error is not None:
                raise error
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
38 |
39 |
class PrefetchDataLoader(DataLoader):
    """DataLoader whose iterator is wrapped in a ``PrefetchGenerator``.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        # Stored before the base initialiser runs, as in the original.
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        base_iterator = super(PrefetchDataLoader, self).__iter__()
        queue_size = self.num_prefetch_queue
        return PrefetchGenerator(base_iterator, queue_size)
61 |
62 |
class CPUPrefetcher():
    """CPU prefetcher: a thin ``next`` / ``reset`` wrapper over a loader.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        """Return the next batch, or ``None`` once the loader is exhausted."""
        return next(self.loader, None)

    def reset(self):
        """Start a fresh iteration over the original loader."""
        self.loader = iter(self.ori_loader)
82 |
83 |
class CUDAPrefetcher():
    """CUDA prefetcher.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options. ``opt['num_gpu']`` selects the target device.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        device_name = 'cpu' if opt['num_gpu'] == 0 else 'cuda'
        self.device = torch.device(device_name)
        # Fetch the first batch right away.
        self.preload()

    def preload(self):
        """Fetch the next batch and move its tensor values to the device."""
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # Issue the transfers on the prefetcher's own stream.
        with torch.cuda.stream(self.stream):
            for key, value in self.batch.items():
                if torch.is_tensor(value):
                    self.batch[key] = value.to(
                        device=self.device, non_blocking=True)

    def next(self):
        """Return the preloaded batch and start loading the one after it."""
        # Make the current stream wait for the pending transfers.
        torch.cuda.current_stream().wait_stream(self.stream)
        ready_batch = self.batch
        self.preload()
        return ready_batch

    def reset(self):
        """Restart iteration over the original loader and preload again."""
        self.loader = iter(self.ori_loader)
        self.preload()
127 |
--------------------------------------------------------------------------------
/basicsr/data/single_image_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.data_util import paths_from_lmdb
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
7 |
8 |
class SingleImageDataset(data.Dataset):
    """Read only lq images in the test phase.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).

    The list of image paths is built from one of three sources, checked in
    this order:
    1. 'lmdb' io backend: keys are read from the lmdb folder.
    2. 'meta_info_file': the first space-separated token of every
       non-blank line is joined onto ``dataroot_lq``.
    3. 'folder': ``dataroot_lq`` is scanned and the paths are sorted.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            mean (optional): Mean passed to ``normalize``.
            std (optional): Std passed to ``normalize``.
    """

    def __init__(self, opt):
        super(SingleImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend), created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None
        self.lq_folder = opt['dataroot_lq']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder]
            self.io_backend_opt['client_keys'] = ['lq']
            self.paths = paths_from_lmdb(self.lq_folder)
        elif 'meta_info_file' in self.opt:
            with open(self.opt['meta_info_file'], 'r') as fin:
                # Strip the line ending before splitting: a line holding
                # only a file name would otherwise keep its '\n' in the
                # path. Blank lines carry no file name and are skipped.
                self.paths = [
                    osp.join(self.lq_folder,
                             line.rstrip().split(' ')[0]) for line in fin
                    if line.strip()
                ]
        else:
            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))

    def __getitem__(self, index):
        if self.file_client is None:
            # 'type' is removed from the backend options; the remaining
            # entries are forwarded as keyword arguments.
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load lq image
        lq_path = self.paths[index]
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
        # normalize
        # NOTE(review): runs when either mean or std is set; presumably both
        # must be configured together - confirm against normalize().
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
        return {'lq': img_lq, 'lq_path': lq_path}

    def __len__(self):
        return len(self.paths)
68 |
--------------------------------------------------------------------------------
/basicsr/data/vimeo90k_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from pathlib import Path
4 | from torch.utils import data as data
5 |
6 | from basicsr.data.transforms import augment, paired_random_crop
7 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
8 |
9 |
class Vimeo90KDataset(data.Dataset):
    """Vimeo90K dataset for training.

    The keys are generated from a meta info txt file.
    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt

    Each line contains:
    1. clip name; 2. frame number; 3. image shape, separated by a white space.
    Examples:
        00001/0001 7 (256,448,3)
        00001/0002 7 (256,448,3)

    Key examples: "00001/0001"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.

    The neighboring frame list for different num_frame:
    num_frame | frame list
             1 | 4
             3 | 3,4,5
             5 | 2,3,4,5,6
             7 | 1,2,3,4,5,6,7

    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.

            num_frame (int): Window size for input frames.
            gt_size (int): Cropped patched size for gt patches.
            random_reverse (bool): Random reverse input frames.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
    """

    def __init__(self, opt):
        super(Vimeo90KDataset, self).__init__()
        self.opt = opt
        self.gt_root = Path(opt['dataroot_gt'])
        self.lq_root = Path(opt['dataroot_lq'])

        # The key is the first space-separated token of every line.
        with open(opt['meta_info_file'], 'r') as meta_file:
            self.keys = [entry.split(' ')[0] for entry in meta_file]

        # file client (io backend), created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = self.io_backend_opt['type'] == 'lmdb'
        if self.is_lmdb:
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

        # indices of input images: num_frame consecutive frame numbers
        window = opt['num_frame']
        first_frame = (9 - window) // 2
        self.neighbor_list = list(range(first_frame, first_frame + window))

        # temporal augmentation configs
        self.random_reverse = opt['random_reverse']
        get_root_logger().info(f'Random reverse is {self.random_reverse}.')

    def __getitem__(self, index):
        if self.file_client is None:
            backend_type = self.io_backend_opt.pop('type')
            self.file_client = FileClient(backend_type, **self.io_backend_opt)

        # random reverse (flips the stored list in place)
        if self.random_reverse and random.random() < 0.5:
            self.neighbor_list.reverse()

        scale = self.opt['scale']
        gt_size = self.opt['gt_size']
        key = self.keys[index]
        clip, seq = key.split('/')  # key example: 00001/0001

        # get the GT frame (im4.png)
        img_gt_path = (f'{key}/im4' if self.is_lmdb else
                       self.gt_root / clip / seq / 'im4.png')
        gt_bytes = self.file_client.get(img_gt_path, 'gt')
        img_gt = imfrombytes(gt_bytes, float32=True)

        # get the neighboring LQ frames
        img_lqs = []
        for frame_no in self.neighbor_list:
            if self.is_lmdb:
                img_lq_path = f'{clip}/{seq}/im{frame_no}'
            else:
                img_lq_path = self.lq_root / clip / seq / f'im{frame_no}.png'
            lq_bytes = self.file_client.get(img_lq_path, 'lq')
            img_lqs.append(imfrombytes(lq_bytes, float32=True))

        # randomly crop
        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
                                             img_gt_path)

        # augmentation - flip, rotate (GT goes last so it is augmented
        # together with the LQ frames)
        img_lqs.append(img_gt)
        augmented = augment(img_lqs, self.opt['use_flip'],
                            self.opt['use_rot'])

        tensors = img2tensor(augmented)
        lq_stack = torch.stack(tensors[0:-1], dim=0)
        gt_tensor = tensors[-1]

        # lq: (t, c, h, w)
        # gt: (c, h, w)
        # key: str
        return {'lq': lq_stack, 'gt': gt_tensor, 'key': key}

    def __len__(self):
        return len(self.keys)
131 |
--------------------------------------------------------------------------------
/basicsr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .niqe import calculate_niqe
2 | from .psnr_ssim import calculate_psnr, calculate_ssim
3 |
4 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
5 |
--------------------------------------------------------------------------------
/basicsr/metrics/fid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from scipy import linalg
5 | from tqdm import tqdm
6 |
7 | from basicsr.models.archs.inception import InceptionV3
8 |
9 |
def load_patched_inception_v3(device='cuda',
                              resize_input=True,
                              normalize_input=False):
    """Build an ``InceptionV3([3])`` wrapped in ``nn.DataParallel``.

    Args:
        device (str): Device the wrapped model is moved to. Default: cuda.
        resize_input (bool): Forwarded to ``InceptionV3``. Default: True.
        normalize_input (bool): Forwarded to ``InceptionV3``.
            Default: False.

    Returns:
        nn.Module: The wrapped model, in eval mode, on ``device``.
    """
    # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
    # does resize the input.
    backbone = InceptionV3([3],
                           resize_input=resize_input,
                           normalize_input=normalize_input)
    wrapped = nn.DataParallel(backbone)
    wrapped = wrapped.eval()
    return wrapped.to(device)
20 |
21 |
@torch.no_grad()
def extract_inception_features(data_generator,
                               inception,
                               len_generator=None,
                               device='cuda'):
    """Extract inception features.

    Every batch is moved to ``device``, passed through ``inception``, and the
    first output is flattened to one row per sample and collected on the CPU.

    Args:
        data_generator (generator): A data generator.
        inception (nn.Module): Inception model.
        len_generator (int): Length of the data_generator to show the
            progressbar. Default: None.
        device (str): Device. Default: cuda.

    Returns:
        Tensor: Extracted features.
    """
    pbar = None
    if len_generator is not None:
        pbar = tqdm(total=len_generator, unit='batch', desc='Extract')

    collected = []
    for batch in data_generator:
        if pbar:
            pbar.update(1)
        batch = batch.to(device)
        first_output = inception(batch)[0]
        flat = first_output.view(batch.shape[0], -1)
        collected.append(flat.to('cpu'))
    if pbar:
        pbar.close()
    return torch.cat(collected, 0)
55 |
56 |
def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Args:
        mu1 (np.array): The sample mean over activations.
        sigma1 (np.array): The covariance matrix over activations for
            generated samples.
        mu2 (np.array): The sample mean over activations, precalculated on an
            representative data set.
        sigma2 (np.array): The covariance matrix over activations,
            precalculated on an representative data set.
        eps (float): Value added to the diagonal of both covariance matrices
            when the square root of their product is not finite.
            Default: 1e-6.

    Returns:
        float: The Frechet Distance.

    Raises:
        ValueError: If the matrix square root has a diagonal whose imaginary
            part is not close to zero.
    """
    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, (
        'Two covariances have different dimensions')

    cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)

    # Product might be almost singular
    if not np.isfinite(cov_sqrt).all():
        # f-string so the actual eps value is reported (the message used to
        # print the literal text "{eps}").
        print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
              'of cov estimates')
        offset = np.eye(sigma1.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {m}')
        cov_sqrt = cov_sqrt.real

    mean_diff = mu1 - mu2
    mean_norm = mean_diff @ mean_diff
    trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
    fid = mean_norm + trace

    return fid
103 |
--------------------------------------------------------------------------------
/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from basicsr.utils.matlab_functions import bgr2ycbcr
4 |
5 |
def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input_order is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: reordered image.

    Raises:
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            "'HWC' and 'CHW'")
    if len(img.shape) == 2:
        # A 2-D image has no channel axis to move: only append one and stop,
        # otherwise 'CHW' would transpose the result into (w, 1, h).
        return img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img
32 |
33 |
def to_y_channel(img):
    """Change to Y channel of YCbCr.

    Inputs that are not 3-D with three channels in the last axis are only
    converted to float32 and returned unchanged in value range.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    scaled = img.astype(np.float32) / 255.
    has_three_channels = scaled.ndim == 3 and scaled.shape[2] == 3
    if has_three_channels:
        # Keep a trailing channel axis after extracting the Y plane.
        scaled = bgr2ycbcr(scaled, y_only=True)[..., None]
    return scaled * 255.
48 |
--------------------------------------------------------------------------------
/basicsr/metrics/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/basicsr/metrics/niqe_pris_params.npz
--------------------------------------------------------------------------------
/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import get_root_logger, scandir
5 |
# Auto-discovery of model modules: every file under this 'models' folder
# whose name ends with '_model.py' is imported when the package is imported.
model_folder = osp.dirname(osp.abspath(__file__))
# Module names are the matching file names without their extension.
model_filenames = [
    osp.splitext(osp.basename(entry))[0]
    for entry in scandir(model_folder)
    if entry.endswith('_model.py')
]
# Import each discovered module, in discovery order.
_model_modules = [
    importlib.import_module(f'basicsr.models.{stem}')
    for stem in model_filenames
]
19 |
20 |
def create_model(opt):
    """Create model.

    The class named ``opt['model_type']`` is looked up in the automatically
    imported model modules (first match wins) and instantiated with ``opt``.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Model type.

    Returns:
        The instantiated model.

    Raises:
        ValueError: If no imported model module defines ``model_type``.
    """
    model_type = opt['model_type']

    # dynamic instantiation
    # Start from None so that an empty module list also ends in the
    # ValueError below instead of an unbound-variable error.
    model_cls = None
    for module in _model_modules:
        model_cls = getattr(module, model_type, None)
        if model_cls is not None:
            break
    if model_cls is None:
        raise ValueError(f'Model {model_type} is not found.')

    model = model_cls(opt)

    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
43 |
--------------------------------------------------------------------------------
/basicsr/models/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import scandir
5 |
# Auto-discovery of architecture modules: every file under this 'archs'
# folder whose name ends with '_arch.py' is imported when the package is
# imported.
arch_folder = osp.dirname(osp.abspath(__file__))
# Module names are the matching file names without their extension.
arch_filenames = [
    osp.splitext(osp.basename(entry))[0]
    for entry in scandir(arch_folder)
    if entry.endswith('_arch.py')
]
# Import each discovered module, in discovery order.
_arch_modules = [
    importlib.import_module(f'basicsr.models.archs.{stem}')
    for stem in arch_filenames
]
19 |
20 |
def dynamic_instantiation(modules, cls_type, opt):
    """Dynamically instantiate class.

    ``modules`` is searched in order and the first object that exposes an
    attribute named ``cls_type`` (with a value other than None) provides the
    class.

    Args:
        modules (list[importlib modules]): List of modules from importlib
            files.
        cls_type (str): Class type.
        opt (dict): Class initialization kwargs.

    Returns:
        class: Instantiated class.

    Raises:
        ValueError: If ``cls_type`` is not found in any module, including
            when ``modules`` is empty.
    """

    # Start from None so that an empty ``modules`` also ends in the
    # ValueError below instead of an unbound-variable error.
    cls_ = None
    for module in modules:
        cls_ = getattr(module, cls_type, None)
        if cls_ is not None:
            break
    if cls_ is None:
        raise ValueError(f'{cls_type} is not found.')
    return cls_(**opt)
41 |
42 |
def define_network(opt):
    """Instantiate the architecture described by ``opt``.

    ``opt['type']`` names the class and is removed from ``opt`` (the dict is
    modified in place); the remaining entries are passed as keyword
    arguments to the class.

    Args:
        opt (dict): Network options containing a 'type' entry.

    Returns:
        The instantiated network.
    """
    return dynamic_instantiation(_arch_modules, opt.pop('type'), opt)
47 |
--------------------------------------------------------------------------------
/basicsr/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss)
2 |
3 | __all__ = [
4 | 'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss',
5 | ]
6 |
--------------------------------------------------------------------------------
/basicsr/models/losses/loss_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from torch.nn import functional as F
3 |
4 |
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    # get_enum codes: none -> 0, elementwise_mean -> 1, sum -> 2
    mode = F._Reduction.get_enum(reduction)
    if mode == 0:
        return loss
    return loss.mean() if mode == 1 else loss.sum()
23 |
24 |
def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Weight an element-wise loss, then reduce it.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights with the same number of
            dimensions as ``loss``; dimension 1 must have size 1 or match
            ``loss``. Default: None.
        reduction (str): 'none', 'mean' or 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values. With a weight and 'mean', the weighted sum is
            divided by the total weight (counted once per channel when the
            weight has a single channel).
    """
    # Without a weight this is a plain reduction.
    if weight is None:
        return reduce_loss(loss, reduction)

    assert weight.dim() == loss.dim()
    assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
    weighted = loss * weight

    if reduction == 'sum':
        return reduce_loss(weighted, reduction)
    if reduction == 'mean':
        # Normalise by the weighted region rather than the element count.
        normalizer = weight.sum()
        if weight.size(1) <= 1:
            normalizer = normalizer * weighted.size(1)
        return weighted.sum() / normalizer
    return weighted
55 |
56 |
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    When a weight is passed, `weight_reduce_loss` indexes dimension 1 of the
    weight, so weighted calls need tensors with at least two dimensions
    (the example below uses 4-D tensors).

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([[[[0, 2, 3]]]])
    >>> target = torch.Tensor([[[[1, 1, 1]]]])
    >>> weight = torch.Tensor([[[[1, 0, 1]]]])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([[[[1., 1., 2.]]]])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # get element-wise loss, then apply weight and reduction
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction)
        return loss

    return wrapper
96 |
--------------------------------------------------------------------------------
/basicsr/models/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 |
6 | from basicsr.models.losses.loss_util import weighted_loss
7 |
8 | _reduction_modes = ['none', 'mean', 'sum']
9 |
10 |
@weighted_loss
def l1_loss(pred, target):
    """Element-wise absolute error between ``pred`` and ``target``.

    No reduction is applied here; the ``weighted_loss`` decorator adds the
    ``weight`` and ``reduction`` arguments.
    """
    return F.l1_loss(pred, target, reduction='none')
14 |
15 |
@weighted_loss
def mse_loss(pred, target):
    """Element-wise squared error between ``pred`` and ``target``.

    No reduction is applied here; the ``weighted_loss`` decorator adds the
    ``weight`` and ``reduction`` arguments.
    """
    return F.mse_loss(pred, target, reduction='none')
19 |
20 |
21 | # @weighted_loss
22 | # def charbonnier_loss(pred, target, eps=1e-12):
23 | # return torch.sqrt((pred - target)**2 + eps)
24 |
25 |
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss module.

    Args:
        loss_weight (float): Factor multiplied onto the loss. Default: 1.0.
        reduction (str): Reduction applied to the output, one of
            'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        supported = ('none', 'mean', 'sum')
        if reduction not in supported:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.

        Extra keyword arguments are accepted and ignored.
        """
        unscaled = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * unscaled
54 |
class MSELoss(nn.Module):
    """MSE (L2) loss module.

    Args:
        loss_weight (float): Factor multiplied onto the loss. Default: 1.0.
        reduction (str): Reduction applied to the output, one of
            'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        supported = ('none', 'mean', 'sum')
        if reduction not in supported:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.

        Extra keyword arguments are accepted and ignored.
        """
        unscaled = mse_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * unscaled
83 |
class PSNRLoss(nn.Module):
    """Log-MSE loss expressed in decibels.

    For each sample the mean squared error over dims (1, 2, 3) is computed,
    ``1e-8`` is added, and ``10 * log10`` of it is taken (``scale`` is
    ``10 / ln(10)`` and ``torch.log`` is the natural log). The per-sample
    values are averaged over the batch and multiplied by ``loss_weight``.

    Args:
        loss_weight (float): Factor multiplied onto the loss. Default: 1.0.
        reduction (str): Only 'mean' is accepted (asserted).
        toY (bool): If True, both inputs are first collapsed to a single
            channel with ``coef`` (+16, then /255). Presumably a BT.601
            RGB -> Y conversion for inputs in [0, 1] -- confirm with callers.
            Default: False.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
        super(PSNRLoss, self).__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        self.scale = 10 / np.log(10)
        self.toY = toY
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        # Kept for backward compatibility; becomes False after the first
        # forward pass that uses ``toY``.
        self.first = True

    def forward(self, pred, target):
        """Compute the loss for 4-D ``pred`` and ``target`` tensors."""
        assert len(pred.size()) == 4
        if self.toY:
            # Re-check the device on every call: the previous one-shot move
            # (guarded by ``self.first``) left ``coef`` on a stale device when
            # later inputs lived elsewhere.
            if self.coef.device != pred.device:
                self.coef = self.coef.to(pred.device)
            self.first = False

            pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.

            pred, target = pred / 255., target / 255.
        assert len(pred.size()) == 4

        # Per-sample MSE; 1e-8 keeps the log finite for identical inputs.
        mse = ((pred - target) ** 2).mean(dim=(1, 2, 3))
        return self.loss_weight * self.scale * torch.log(mse + 1e-8).mean()
110 |
class CharbonnierLoss(nn.Module):
    """Charbonnier loss: ``sqrt((x - y)^2 + eps^2)``, a smooth L1 variant.

    Args:
        loss_weight (float): Factor multiplied onto the loss. Default: 1.0.
        reduction (str): 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): Smoothing constant. Default: 1e-3.

    Note:
        ``loss_weight`` and ``reduction`` used to be accepted but ignored
        (the result was always the unweighted mean). They are now honoured;
        the defaults reproduce the previous behaviour.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        supported = ('none', 'mean', 'sum')
        if reduction not in supported:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {list(supported)}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, x, y):
        """Return the weighted, reduced Charbonnier distance of ``x``, ``y``."""
        diff = x - y
        loss = torch.sqrt((diff * diff) + (self.eps * self.eps))
        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return self.loss_weight * loss
123 |
--------------------------------------------------------------------------------
/basicsr/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from os import path as osp
4 |
5 | from basicsr.data import create_dataloader, create_dataset
6 | from basicsr.models import create_model
7 | from basicsr.train import parse_options
8 | from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
9 | make_exp_dirs)
10 | from basicsr.utils.options import dict2str
11 |
12 |
def main():
    """Run validation on every dataset listed in the test options."""
    # parse options (non-training mode)
    opt = parse_options(is_train=False)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create experiment directories and initialize the logger
    make_exp_dirs(opt)
    log_dir = opt['path']['log']
    log_file = osp.join(log_dir, f"test_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(
        logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # build one dataloader per configured dataset, in sorted key order
    test_loaders = []
    for _, dataset_opt in sorted(opt['datasets'].items()):
        dataset = create_dataset(dataset_opt)
        loader = create_dataloader(
            dataset,
            dataset_opt,
            num_gpu=opt['num_gpu'],
            dist=opt['dist'],
            sampler=None,
            seed=opt['manual_seed'])
        logger.info(
            f"Number of test images in {dataset_opt['name']}: {len(dataset)}")
        test_loaders.append(loader)

    # create model
    model = create_model(opt)

    for loader in test_loaders:
        test_set_name = loader.dataset.opt['name']
        logger.info(f'Testing {test_set_name}...')
        rgb2bgr = opt['val'].get('rgb2bgr', True)
        # whether to use uint8 images to compute metrics
        use_image = opt['val'].get('use_image', True)
        # the experiment name is passed where an iteration number is expected
        model.validation(
            loader,
            current_iter=opt['name'],
            tb_logger=None,
            save_img=opt['val']['save_img'],
            rgb2bgr=rgb2bgr,
            use_image=use_image)
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_client import FileClient
2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP
3 | from .logger import (MessageLogger, get_env_info, get_root_logger,
4 | init_tb_logger, init_wandb_logger)
5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
6 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
7 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)
8 |
# Public names re-exported by ``from basicsr.utils import *``.
__all__ = [
    # file_client.py
    'FileClient',
    # img_util.py
    'img2tensor',
    'tensor2img',
    'imfrombytes',
    'imwrite',
    'crop_border',
    'padding',
    'padding_DP',
    'imfrombytesDP',
    # logger.py
    'MessageLogger',
    'init_tb_logger',
    'init_wandb_logger',
    'get_root_logger',
    'get_env_info',
    # misc.py
    'set_random_seed',
    'get_time_str',
    'mkdir_and_rename',
    'make_exp_dirs',
    'scandir',
    'scandir_SIDD',  # imported above but previously missing from __all__
    'check_resume',
    'sizeof_fmt',
    # create_lmdb.py
    'create_lmdb_for_reds',
    'create_lmdb_for_gopro',
    'create_lmdb_for_rain13k',
]
39 |
--------------------------------------------------------------------------------
/basicsr/utils/bundle_submissions.py:
--------------------------------------------------------------------------------
1 | # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de)
2 |
3 | # This file is part of the implementation as described in the CVPR 2017 paper:
4 | # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs.
5 | # Please see the file LICENSE.txt for the license governing this code.
6 |
7 |
8 | import numpy as np
9 | import scipy.io as sio
10 | import os
11 | import h5py
12 |
def bundle_submissions_raw(submission_folder,session):
    '''
    Bundles submission data for raw denoising

    submission_folder Folder where denoised images reside
    session           Name of the sub-folder of submission_folder that
                      receives the bundled files

    For each of 50 images, the 20 crop files '%04d_%02d.mat' are read from
    submission_folder (key "Idenoised_crop") and stored together in
    <submission_folder>/<session>/'%04d.mat'. Please submit the content of
    that folder.
    '''

    out_folder = os.path.join(submission_folder, session)
    # out_folder = os.path.join(submission_folder, "bundled/")
    # exist_ok replaces the former bare ``except: pass`` around os.mkdir,
    # which also hid real failures such as permission errors.
    os.makedirs(out_folder, exist_ok=True)

    israw = True
    eval_version="1.0"

    for i in range(50):
        # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is
        # the equivalent dtype.
        Idenoised = np.zeros((20,), dtype=object)
        for bb in range(20):
            filename = '%04d_%02d.mat'%(i+1,bb+1)
            s = sio.loadmat(os.path.join(submission_folder,filename))
            Idenoised_crop = s["Idenoised_crop"]
            Idenoised[bb] = Idenoised_crop
        filename = '%04d.mat'%(i+1)
        sio.savemat(os.path.join(out_folder, filename),
                    {"Idenoised": Idenoised,
                     "israw": israw,
                     "eval_version": eval_version},
                    )
45 |
def bundle_submissions_srgb(submission_folder,session):
    '''
    Bundles submission data for sRGB denoising

    submission_folder Folder where denoised images reside
    session           Name of the sub-folder of submission_folder that
                      receives the bundled files

    For each of 50 images, the 20 crop files '%04d_%02d.mat' are read from
    submission_folder (key "Idenoised_crop") and stored together in
    <submission_folder>/<session>/'%04d.mat'. Please submit the content of
    that folder.
    '''
    out_folder = os.path.join(submission_folder, session)
    # out_folder = os.path.join(submission_folder, "bundled/")
    # exist_ok replaces the former bare ``except: pass`` around os.mkdir,
    # which also hid real failures such as permission errors.
    os.makedirs(out_folder, exist_ok=True)
    israw = False
    eval_version="1.0"

    for i in range(50):
        # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is
        # the equivalent dtype.
        Idenoised = np.zeros((20,), dtype=object)
        for bb in range(20):
            filename = '%04d_%02d.mat'%(i+1,bb+1)
            s = sio.loadmat(os.path.join(submission_folder,filename))
            Idenoised_crop = s["Idenoised_crop"]
            Idenoised[bb] = Idenoised_crop
        filename = '%04d.mat'%(i+1)
        sio.savemat(os.path.join(out_folder, filename),
                    {"Idenoised": Idenoised,
                     "israw": israw,
                     "eval_version": eval_version},
                    )
76 |
77 |
78 |
def bundle_submissions_srgb_v1(submission_folder,session):
    '''
    Bundles submission data for sRGB denoising

    submission_folder Folder where denoised images reside
    session           Name of the sub-folder of submission_folder that
                      receives the bundled files

    Same as bundle_submissions_srgb except that the crop index in the input
    file name is not zero-padded ('%04d_%d.mat'). Output goes to
    <submission_folder>/<session>/'%04d.mat'. Please submit the content of
    that folder.
    '''
    out_folder = os.path.join(submission_folder, session)
    # out_folder = os.path.join(submission_folder, "bundled/")
    # exist_ok replaces the former bare ``except: pass`` around os.mkdir,
    # which also hid real failures such as permission errors.
    os.makedirs(out_folder, exist_ok=True)
    israw = False
    eval_version="1.0"

    for i in range(50):
        # ``np.object`` was removed in NumPy 1.24; the builtin ``object`` is
        # the equivalent dtype.
        Idenoised = np.zeros((20,), dtype=object)
        for bb in range(20):
            filename = '%04d_%d.mat'%(i+1,bb+1)
            s = sio.loadmat(os.path.join(submission_folder,filename))
            Idenoised_crop = s["Idenoised_crop"]
            Idenoised[bb] = Idenoised_crop
        filename = '%04d.mat'%(i+1)
        sio.savemat(os.path.join(out_folder, filename),
                    {"Idenoised": Idenoised,
                     "israw": israw,
                     "eval_version": eval_version},
                    )
--------------------------------------------------------------------------------
/basicsr/utils/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from os import path as osp
3 |
4 | from basicsr.utils import scandir
5 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs
6 |
def prepare_keys(folder_path, suffix='png'):
    """Collect the image paths in a folder and derive one key per image.

    Args:
        folder_path (str): Folder path.
        suffix (str): File suffix to look for. Default: 'png'.

    Returns:
        list[str]: Sorted image path list.
        list[str]: Key list; each key is the path text before the first
            occurrence of '.<suffix>'.
    """
    print('Reading image path list ...')
    img_path_list = sorted(scandir(folder_path, suffix=suffix, recursive=False))
    extension = '.{}'.format(suffix)
    # img_path_list is already sorted, so the keys come out in the same order
    keys = [path.split(extension)[0] for path in img_path_list]

    return img_path_list, keys
23 |
def create_lmdb_for_reds():
    """Build the LMDB databases for the REDS val/train sharp and blur folders."""
    # (image folder, lmdb output, image suffix), processed in this order
    jobs = (
        ('./datasets/REDS/val/sharp_300',
         './datasets/REDS/val/sharp_300.lmdb', 'png'),
        ('./datasets/REDS/val/blur_300',
         './datasets/REDS/val/blur_300.lmdb', 'jpg'),
        ('./datasets/REDS/train/train_sharp',
         './datasets/REDS/train/train_sharp.lmdb', 'png'),
        ('./datasets/REDS/train/train_blur_jpeg',
         './datasets/REDS/train/train_blur_jpeg.lmdb', 'jpg'),
    )
    for folder_path, lmdb_path, suffix in jobs:
        img_path_list, keys = prepare_keys(folder_path, suffix)
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
44 |
45 |
def create_lmdb_for_gopro():
    """Build the LMDB databases for the GoPro train crops and test images."""
    # (image folder, lmdb output), processed in this order; all images are png
    jobs = (
        ('./datasets/GoPro/train/blur_crops',
         './datasets/GoPro/train/blur_crops.lmdb'),
        ('./datasets/GoPro/train/sharp_crops',
         './datasets/GoPro/train/sharp_crops.lmdb'),
        ('./datasets/GoPro/test/target',
         './datasets/GoPro/test/target.lmdb'),
        ('./datasets/GoPro/test/input',
         './datasets/GoPro/test/input.lmdb'),
    )
    for folder_path, lmdb_path in jobs:
        img_path_list, keys = prepare_keys(folder_path, 'png')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
70 |
def create_lmdb_for_rain13k():
    """Build the LMDB databases for the Rain13k train input/target folders."""
    # (image folder, lmdb output), processed in this order; all images are jpg
    jobs = (
        ('./datasets/Rain13k/train/input',
         './datasets/Rain13k/train/input.lmdb'),
        ('./datasets/Rain13k/train/target',
         './datasets/Rain13k/train/target.lmdb'),
    )
    for folder_path, lmdb_path in jobs:
        img_path_list, keys = prepare_keys(folder_path, 'jpg')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
83 |
def create_lmdb_for_SIDD():
    """Build the LMDB databases for the SIDD train and validation crops.

    The training crops are read from existing PNG folders. The validation
    blocks are first unpacked from the two ``.mat`` files into PNG files
    (created under ``./datasets/SIDD/val/``) and then converted to LMDB.

    Raises:
        AssertionError: If one of the validation ``.mat`` files is missing.
    """
    # This module's top-level imports do not provide these names, so the
    # function used to fail with NameError. They are imported here, at
    # function scope, so importing the module does not require them.
    # ``scio`` is taken to be ``scipy.io`` (it is used for ``loadmat``).
    import os
    import cv2
    import scipy.io as scio
    from tqdm import tqdm

    # training crops
    train_sets = (
        ('./datasets/SIDD/train/input_crops',
         './datasets/SIDD/train/input_crops.lmdb'),
        ('./datasets/SIDD/train/gt_crops',
         './datasets/SIDD/train/gt_crops.lmdb'),
    )
    for folder_path, lmdb_path in train_sets:
        img_path_list, keys = prepare_keys(folder_path, 'PNG')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

    # validation blocks: (png folder, lmdb output, mat file, key in mat file)
    val_sets = (
        ('./datasets/SIDD/val/input_crops',
         './datasets/SIDD/val/input_crops.lmdb',
         './datasets/SIDD/ValidationNoisyBlocksSrgb.mat',
         'ValidationNoisyBlocksSrgb'),
        ('./datasets/SIDD/val/gt_crops',
         './datasets/SIDD/val/gt_crops.lmdb',
         './datasets/SIDD/ValidationGtBlocksSrgb.mat',
         'ValidationGtBlocksSrgb'),
    )
    for folder_path, lmdb_path, mat_path, mat_key in val_sets:
        if not osp.exists(folder_path):
            os.makedirs(folder_path)
        assert osp.exists(mat_path)
        data = scio.loadmat(mat_path)[mat_key]
        N, B, H, W, C = data.shape
        # flatten the two leading axes into one block index
        data = data.reshape(N * B, H, W, C)
        for i in tqdm(range(N * B)):
            # Both sets use the same file names so noisy/gt keys line up.
            cv2.imwrite(
                osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)),
                cv2.cvtColor(data[i, ...], cv2.COLOR_RGB2BGR))
        img_path_list, keys = prepare_keys(folder_path, 'png')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
125 |
--------------------------------------------------------------------------------
/basicsr/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2 | import functools
3 | import os
4 | import subprocess
5 | import torch
6 | import torch.distributed as dist
7 | import torch.multiprocessing as mp
8 |
9 |
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize distributed training for the given launcher.

    Sets the multiprocessing start method to 'spawn' when none has been set,
    then delegates to the 'pytorch' or 'slurm' initializer.

    Raises:
        ValueError: If ``launcher`` is neither 'pytorch' nor 'slurm'.
    """
    current_method = mp.get_start_method(allow_none=True)
    if current_method is None:
        mp.set_start_method('spawn')

    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')
19 |
20 |
def _init_dist_pytorch(backend, **kwargs):
    """Pick a CUDA device from the ``RANK`` env variable and join the group.

    The device index is ``RANK`` modulo the number of visible CUDA devices.
    """
    global_rank = int(os.environ['RANK'])
    device_count = torch.cuda.device_count()
    device_index = global_rank % device_count
    torch.cuda.set_device(device_index)
    dist.init_process_group(backend=backend, **kwargs)
26 |
27 |
def _init_dist_slurm(backend, port=None):
    """Set up the distributed environment from SLURM variables.

    The master port is ``port`` when given, otherwise the existing
    ``MASTER_PORT`` environment variable, otherwise ``29500``.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    env = os.environ
    proc_id = int(env['SLURM_PROCID'])
    ntasks = int(env['SLURM_NTASKS'])
    node_list = env['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    local_rank = proc_id % num_gpus
    torch.cuda.set_device(local_rank)
    # first host name of the allocation acts as the master address
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')

    if port is not None:
        env['MASTER_PORT'] = str(port)
    else:
        # keep an existing MASTER_PORT; 29500 is torch.distributed default port
        env.setdefault('MASTER_PORT', '29500')
    env['MASTER_ADDR'] = addr
    env['WORLD_SIZE'] = str(ntasks)
    env['LOCAL_RANK'] = str(local_rank)
    env['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
59 |
60 |
def get_dist_info():
    """Return ``(rank, world_size)`` of the current process.

    Falls back to ``(0, 1)`` when torch.distributed is unavailable or no
    process group has been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
73 |
74 |
def master_only(func):
    """Decorator: run ``func`` only on the rank-0 process.

    On any other rank the wrapped call does nothing and returns None.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank = get_dist_info()[0]
        if rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
84 |
--------------------------------------------------------------------------------
/basicsr/utils/download_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import requests
3 | from tqdm import tqdm
4 |
5 | from .misc import sizeof_fmt
6 |
7 |
def download_file_from_google_drive(file_id, save_path):
    """Download a file from Google Drive and write it to ``save_path``.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """
    URL = 'https://docs.google.com/uc?export=download'
    session = requests.Session()
    query = {'id': file_id}

    response = session.get(URL, params=query, stream=True)
    # a 'download_warning' cookie means the request must be confirmed
    confirm = get_confirm_token(response)
    if confirm:
        query['confirm'] = confirm
        response = session.get(URL, params=query, stream=True)

    # probe the total size with a tiny ranged request
    size_probe = session.get(
        URL, params=query, stream=True, headers={'Range': 'bytes=0-2'})
    file_size = None
    if 'Content-Range' in size_probe.headers:
        file_size = int(size_probe.headers['Content-Range'].split('/')[1])

    save_response_content(response, save_path, file_size)
39 |
40 |
def get_confirm_token(response):
    """Return the value of the first cookie whose name starts with
    'download_warning', or None when there is no such cookie."""
    matches = (
        value for name, value in response.cookies.items()
        if name.startswith('download_warning'))
    return next(matches, None)
46 |
47 |
def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    """Stream the body of ``response`` into the file ``destination``.

    Args:
        response: Object providing ``iter_content(chunk_size)`` that yields
            byte chunks (e.g. a streamed ``requests`` response).
        destination (str): Path of the file to write (opened in 'wb' mode).
        file_size (int, optional): Total size in bytes. When given, a tqdm
            progress bar is shown. Default: None.
        chunk_size (int): Requested chunk size in bytes. Default: 32768.
    """
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None
        readable_file_size = None

    downloaded_size = 0
    try:
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size):
                if not chunk:  # filter out keep-alive new chunks
                    continue
                f.write(chunk)
                # Count the bytes actually received; the last chunk is
                # usually shorter than chunk_size.
                downloaded_size += len(chunk)
                if pbar is not None:
                    pbar.update(1)
                    pbar.set_description(
                        f'Download {sizeof_fmt(downloaded_size)} '
                        f'/ {readable_file_size}')
    finally:
        # close the bar even if the download or the write fails
        if pbar is not None:
            pbar.close()
71 |
--------------------------------------------------------------------------------
/basicsr/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 | from .dist_util import get_dist_info, master_only
6 |
7 | initialized_logger = {}
8 |
9 |
class MessageLogger():
    """Formats and emits periodic training log lines.

    Args:
        opt (dict): Config. It contains the following keys:
            name (str): Exp name.
            logger (dict): Contains 'print_freq' (str) for logger interval
                and 'use_tb_logger' (bool).
            train (dict): Contains 'total_iter' (int) for total iters.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """

    def __init__(self, opt, start_iter=1, tb_logger=None):
        self.exp_name = opt['name']
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.use_tb_logger = opt['logger']['use_tb_logger']
        self.tb_logger = tb_logger
        # reference point for the ETA estimate
        self.start_time = time.time()
        self.logger = get_root_logger()

    @master_only
    def __call__(self, log_vars):
        """Build one log line from ``log_vars`` and emit it.

        Consumed entries are popped, so ``log_vars`` is modified in place.

        Args:
            log_vars (dict): It contains the following keys:
                epoch (int): Epoch number.
                iter (int): Current iter.
                lrs (list): List for learning rates.

                time (float): Iter time.
                data_time (float): Data time for each iter.
            Every remaining entry is logged as ``name: value``.
        """
        epoch = log_vars.pop('epoch')
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        # header: experiment prefix, epoch, iteration, learning rates
        parts = [f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
                 f'iter:{current_iter:8,d}, lr:(']
        parts.extend(f'{lr:.3e},' for lr in lrs)
        parts.append(')] ')

        # timing and estimated time remaining
        if 'time' in log_vars:
            iter_time = log_vars.pop('time')
            data_time = log_vars.pop('data_time')

            elapsed = time.time() - self.start_time
            sec_per_iter = elapsed / (current_iter - self.start_iter + 1)
            remaining = sec_per_iter * (self.max_iters - current_iter - 1)
            eta_str = str(datetime.timedelta(seconds=int(remaining)))
            parts.append(f'[eta: {eta_str}, ')
            parts.append(f'time (data): {iter_time:.3f} ({data_time:.3f})] ')

        # everything left over, typically the losses
        tb_enabled = self.use_tb_logger and 'debug' not in self.exp_name
        for name, value in log_vars.items():
            parts.append(f'{name}: {value:.4e} ')
            if tb_enabled:
                # entries named 'l_*' are grouped under 'losses/'
                tag = f'losses/{name}' if name.startswith('l_') else name
                self.tb_logger.add_scalar(tag, value, current_iter)
        self.logger.info(''.join(parts))
78 |
79 |
@master_only
def init_tb_logger(log_dir):
    """Create a tensorboard ``SummaryWriter`` that writes to ``log_dir``."""
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(log_dir=log_dir)
85 |
86 |
@master_only
def init_wandb_logger(opt):
    """Start a wandb run that syncs the tensorboard log.

    Reads ``opt['logger']['wandb']``: 'project' is required, and 'resume_id'
    (optional) resumes an existing run instead of starting a new one.
    """
    import wandb
    logger = logging.getLogger('basicsr')

    project = opt['logger']['wandb']['project']
    resume_id = opt['logger']['wandb'].get('resume_id')
    if not resume_id:
        # fresh run with a newly generated id
        wandb_id, resume = wandb.util.generate_id(), 'never'
    else:
        wandb_id, resume = resume_id, 'allow'
        logger.warning(f'Resume wandb logger with id={wandb_id}.')

    wandb.init(
        id=wandb_id,
        resume=resume,
        name=opt['name'],
        config=opt,
        project=project,
        sync_tensorboard=True)

    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
106 |
107 |
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    logger.addHandler(stream_handler)
    logger.propagate = False
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    else:
        # Always apply log_level on rank 0. Previously it was only set when
        # log_file was given, so without a file the logger kept the level
        # inherited from the logging root and dropped INFO messages.
        logger.setLevel(log_level)
        if log_file is not None:
            # add file handler
            file_handler = logging.FileHandler(log_file, 'w')
            file_handler.setFormatter(logging.Formatter(format_str))
            file_handler.setLevel(log_level)
            logger.addHandler(file_handler)
    initialized_logger[logger_name] = True
    return logger
148 |
149 |
def get_env_info():
    """Get environment information.

    Currently, only log the software version.

    Returns:
        str: An ASCII-art banner followed by the BasicSR, PyTorch and
            TorchVision version strings.
    """
    # imported here so they are only needed when this function is called
    import torch
    import torchvision

    from basicsr.version import __version__
    msg = r"""
____ _ _____ ____
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
/ __ |/ __ `// ___// // ___/\__ \ / /_/ /
/ /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
/_____/ \__,_//____//_/ \___//____//_/ |_|
______ __ __ __ __
/ ____/____ ____ ____/ / / / __ __ _____ / /__ / /
/ / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
"""
    # append the version lines after the banner
    msg += ('\nVersion Information: '
            f'\n\tBasicSR: {__version__}'
            f'\n\tPyTorch: {torch.__version__}'
            f'\n\tTorchVision: {torchvision.__version__}')
    return msg
--------------------------------------------------------------------------------
/basicsr/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import time
5 | import torch
6 | from os import path as osp
7 |
8 | from .dist_util import master_only
9 | from .logger import get_root_logger
10 |
11 |
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
19 |
20 |
def get_time_str():
    """Return the current local time formatted as 'YYYYMMDD_HHMMSS'."""
    # strftime uses the current local time when no time tuple is given
    return time.strftime('%Y%m%d_%H%M%S')
23 |
24 |
def mkdir_and_rename(path):
    """Create ``path``; an existing one is first renamed with a timestamp.

    Args:
        path (str): Folder path.
    """
    already_there = osp.exists(path)
    if already_there:
        # move the old folder aside as '<path>_archived_<timestamp>'
        new_name = '_archived_'.join((path, get_time_str()))
        print(f'Path already exists. Rename it to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)
36 |
37 |
@master_only
def make_exp_dirs(opt):
    """Create the directories listed in ``opt['path']``.

    The root folder ('experiments_root' when training, 'results_root'
    otherwise) goes through ``mkdir_and_rename``. Every other entry is
    created with ``os.makedirs`` unless its key contains 'strict_load',
    'pretrain_network' or 'resume'.
    """
    path_opt = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(path_opt.pop(root_key))

    # keys containing these tokens are not treated as directories to create
    skipped_tokens = ('strict_load', 'pretrain_network', 'resume')
    for key, path in path_opt.items():
        if any(token in key for token in skipped_tokens):
            continue
        os.makedirs(path, exist_ok=True)
51 |
52 |
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the files of interest.

    Hidden files (names starting with ``.``) are never yielded. When
    ``recursive`` is True, sub-directories are descended into; entries that
    are neither regular files nor directories are skipped.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator over the matching files. Paths are relative to
        ``dir_path`` unless ``full_path`` is True.

    Raises:
        TypeError: If ``suffix`` is neither None, a string nor a tuple.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if entry.is_file():
                # Hidden files are ignored. They used to fall through to the
                # recursive branch, where os.scandir() on a file raised
                # NotADirectoryError.
                if entry.name.startswith('.'):
                    continue
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None or return_path.endswith(suffix):
                    yield return_path
            elif recursive and entry.is_dir():
                # Only real directories are descended into.
                yield from _scandir(
                    entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
94 |
def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
    """Scan a directory to find the files whose path contains a keyword.

    Hidden files (names starting with ``.``) are never yielded. When
    ``recursive`` is True, sub-directories are descended into; entries that
    are neither regular files nor directories are skipped.

    Args:
        dir_path (str): Path of the directory.
        keywords (str | tuple(str), optional): Keyword(s) to look for in the
            returned path. With a tuple, a file matches when any of the
            keywords occurs. Default: None (every file matches).
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator over the matching files. Paths are relative to
        ``dir_path`` unless ``full_path`` is True.

    Raises:
        TypeError: If ``keywords`` is neither None, a string nor a tuple.
    """

    if (keywords is not None) and not isinstance(keywords, (str, tuple)):
        raise TypeError('"keywords" must be a string or tuple of strings')

    root = dir_path
    # Normalise to a tuple so both accepted input types are handled alike
    # (str.find() raised TypeError when handed a tuple).
    if keywords is None or isinstance(keywords, tuple):
        wanted = keywords
    else:
        wanted = (keywords, )

    def _matches(candidate):
        if wanted is None:
            return True
        # Substring test also accepts a match at index 0, which the former
        # ``find(...) > 0`` comparison silently dropped.
        return any(word in candidate for word in wanted)

    def _scandir(dir_path, keywords, recursive):
        for entry in os.scandir(dir_path):
            if entry.is_file():
                # Hidden files are ignored instead of being passed to
                # os.scandir(), which raised NotADirectoryError.
                if entry.name.startswith('.'):
                    continue
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if _matches(return_path):
                    yield return_path
            elif recursive and entry.is_dir():
                yield from _scandir(
                    entry.path, keywords=keywords, recursive=recursive)

    return _scandir(dir_path, keywords=keywords, recursive=recursive)
136 |
def check_resume(opt, resume_iter):
    """Point the ``pretrain_network_*`` paths at the resumed checkpoints.

    Nothing is changed unless ``opt['path']['resume_state']`` is truthy. In
    that case a warning is logged if any ``pretrain_network_*`` path was
    configured, and each network's path is replaced with
    ``<models>/net_<name>_<resume_iter>.pth`` unless the network name is
    listed in ``opt['path']['ignore_resume_networks']``.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    path_opt = opt['path']
    if not path_opt['resume_state']:
        return

    # Every top-level option key of the form ``network_*`` is a network.
    networks = [key for key in opt.keys() if key.startswith('network_')]

    has_pretrain = any(
        path_opt.get(f'pretrain_{network}') is not None
        for network in networks)
    if has_pretrain:
        logger.warning(
            'pretrain_network path will be ignored during resuming.')

    # Redirect the pretrained model paths to the resumed iteration.
    for network in networks:
        name = f'pretrain_{network}'
        basename = network.replace('network_', '')
        ignored = path_opt.get('ignore_resume_networks')
        if ignored is not None and basename in ignored:
            continue
        path_opt[name] = osp.join(
            path_opt['models'], f'net_{basename}_{resume_iter}.pth')
        logger.info(f"Set {name} to {path_opt[name]}")
164 |
165 |
def sizeof_fmt(size, suffix='B'):
    """Get a human readable file size.

    The value is divided by 1024 until its magnitude drops below 1024 and
    is then printed with one decimal and the matching unit prefix.

    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size.
    """
    prefixes = ('', 'K', 'M', 'G', 'T', 'P', 'E', 'Z')
    position = 0
    while position < len(prefixes):
        if abs(size) < 1024.0:
            return f'{size:3.1f} {prefixes[position]}{suffix}'
        size /= 1024.0
        position += 1
    # Anything still >= 1024 after the 'Z' step is reported with 'Y'.
    return f'{size:3.1f} Y{suffix}'
181 |
--------------------------------------------------------------------------------
/basicsr/utils/options.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from collections import OrderedDict
3 | from os import path as osp
4 |
5 |
def ordered_yaml():
    """Build a yaml Loader/Dumper pair that works with OrderedDict.

    The C implementations are used when PyYAML provides them, otherwise the
    pure Python classes. Mappings are constructed as ``OrderedDict`` and
    ``OrderedDict`` instances are dumped as plain mappings.

    Returns:
        yaml Loader and Dumper.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def _construct_ordered(loader, node):
        pairs = loader.construct_pairs(node)
        return OrderedDict(pairs)

    def _represent_ordered(dumper, data):
        return dumper.represent_dict(data.items())

    Loader.add_constructor(mapping_tag, _construct_ordered)
    Dumper.add_representer(OrderedDict, _represent_ordered)
    return Loader, Dumper
29 |
30 |
def parse(opt_path, is_train=True):
    """Parse an option file and fill in the derived settings.

    Args:
        opt_path (str): Option file path.
        is_train (bool): Indicate whether in training or not. Default: True.

    Returns:
        (dict): Options.
    """
    # NOTE(review): the loader from ordered_yaml() is not a "safe" YAML
    # loader, so only trusted option files should be parsed here.
    with open(opt_path, mode='r') as f:
        loader_cls, _ = ordered_yaml()
        opt = yaml.load(f, Loader=loader_cls)

    opt['is_train'] = is_train

    # datasets
    for phase_key, dataset in opt['datasets'].items():
        # several datasets may share one phase, e.g., test_1, test_2
        dataset['phase'] = phase_key.split('_')[0]
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        for root_key in ('dataroot_gt', 'dataroot_lq'):
            if dataset.get(root_key) is not None:
                dataset[root_key] = osp.expanduser(dataset[root_key])

    # paths
    path_opt = opt['path']
    for key, val in path_opt.items():
        if val is None:
            continue
        if 'resume_state' in key or 'pretrain_network' in key:
            path_opt[key] = osp.expanduser(val)

    # Three levels above this file.
    repo_root = osp.abspath(
        osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
    path_opt['root'] = repo_root

    if is_train:
        experiments_root = osp.join(repo_root, 'experiments', opt['name'])
        derived = (
            ('experiments_root', experiments_root),
            ('models', osp.join(experiments_root, 'models')),
            ('training_states', osp.join(experiments_root,
                                         'training_states')),
            ('log', experiments_root),
            ('visualization', osp.join(experiments_root, 'visualization')),
        )
        for key, value in derived:
            path_opt[key] = value

        # change some options for debug mode
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            logger_opt = opt['logger']
            logger_opt['print_freq'] = 1
            logger_opt['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(repo_root, 'results', opt['name'])
        path_opt['results_root'] = results_root
        path_opt['log'] = results_root
        path_opt['visualization'] = osp.join(results_root, 'visualization')

    return opt
90 |
91 |
def dict2str(opt, indent_level=1):
    """Render an option dict as an indented, multi-line string.

    Nested dicts are wrapped in ``key:[ ... ]`` and rendered one indent
    level deeper; every other value is printed as ``key: value``.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing.
    """
    pad = ' ' * (indent_level * 2)
    pieces = ['\n']
    for key, value in opt.items():
        if not isinstance(value, dict):
            pieces.append(pad + key + ': ' + str(value) + '\n')
            continue
        pieces.append(pad + key + ':[')
        pieces.append(dict2str(value, indent_level + 1))
        pieces.append(pad + ']\n')
    return ''.join(pieces)
111 |
--------------------------------------------------------------------------------
/basicsr/version.py:
--------------------------------------------------------------------------------
# GENERATED VERSION FILE
# TIME: Wed Mar 9 22:05:30 2022
# Release number without the build hash.
short_version = '1.2.0'
# Full version: release number plus the short commit hash.
__version__ = short_version + '+10018c6'
# Release number as a tuple of integers.
version_info = tuple(int(part) for part in short_version.split('.'))
6 |
--------------------------------------------------------------------------------
/images/Deblurring/input/109fromGOPR1096.MP4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deblurring/input/109fromGOPR1096.MP4.png
--------------------------------------------------------------------------------
/images/Deblurring/input/110fromGOPR1087.MP4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deblurring/input/110fromGOPR1087.MP4.png
--------------------------------------------------------------------------------
/images/Deblurring/input/1fromGOPR0950.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deblurring/input/1fromGOPR0950.png
--------------------------------------------------------------------------------
/images/Deblurring/input/1fromGOPR1096.MP4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deblurring/input/1fromGOPR1096.MP4.png
--------------------------------------------------------------------------------
/images/Dehazing/input/0003_0.8_0.2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/0003_0.8_0.2.png
--------------------------------------------------------------------------------
/images/Dehazing/input/0010_0.95_0.16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/0010_0.95_0.16.png
--------------------------------------------------------------------------------
/images/Dehazing/input/0014_0.8_0.12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/0014_0.8_0.12.png
--------------------------------------------------------------------------------
/images/Dehazing/input/0048_0.9_0.2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/0048_0.9_0.2.png
--------------------------------------------------------------------------------
/images/Dehazing/input/1440_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/1440_10.png
--------------------------------------------------------------------------------
/images/Dehazing/input/1444_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Dehazing/input/1444_10.png
--------------------------------------------------------------------------------
/images/Denoising/input/0003_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Denoising/input/0003_30.png
--------------------------------------------------------------------------------
/images/Denoising/input/0011_23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Denoising/input/0011_23.png
--------------------------------------------------------------------------------
/images/Denoising/input/0013_19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Denoising/input/0013_19.png
--------------------------------------------------------------------------------
/images/Denoising/input/0039_04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Denoising/input/0039_04.png
--------------------------------------------------------------------------------
/images/Deraining/input/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deraining/input/0.jpg
--------------------------------------------------------------------------------
/images/Deraining/input/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deraining/input/1.png
--------------------------------------------------------------------------------
/images/Deraining/input/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deraining/input/15.png
--------------------------------------------------------------------------------
/images/Deraining/input/55.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Deraining/input/55.png
--------------------------------------------------------------------------------
/images/Enhancement/input/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Enhancement/input/1.png
--------------------------------------------------------------------------------
/images/Enhancement/input/111.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Enhancement/input/111.png
--------------------------------------------------------------------------------
/images/Enhancement/input/748.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Enhancement/input/748.png
--------------------------------------------------------------------------------
/images/Enhancement/input/a4541-DSC_0040-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Enhancement/input/a4541-DSC_0040-2.png
--------------------------------------------------------------------------------
/images/Results/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Results/0.jpg
--------------------------------------------------------------------------------
/images/Results/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Results/1.png
--------------------------------------------------------------------------------
/images/Results/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Results/15.png
--------------------------------------------------------------------------------
/images/Results/55.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/Results/55.png
--------------------------------------------------------------------------------
/images/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vztu/maxim-pytorch/3a6e901d483ac6d9bf47c2e21e30ce2c189f6470/images/overview.png
--------------------------------------------------------------------------------
/maxim_pytorch/README.md:
--------------------------------------------------------------------------------
1 | ## PyTorch re-implementation of MAXIM.
2 | `maxim_torch.py` is the PyTorch re-implementation of the 3-stage MAXIM architecture for image denoising.
3 |
4 |
5 | `jax2torch.py` converts JAX weights (from a pretrained checkpoint) to PyTorch and saves them as a dictionary that can be loaded directly into the PyTorch implementation of the MAXIM model. To use this script, first download the pretrained JAX model from the official directory.
6 |
7 | Note that, due to the incompatibility between `flax.linen.ConvTranspose` and `torch.nn.ConvTranspose2d`, the outputs of the JAX model and the PyTorch model are not exactly the same even when exactly the same pretrained parameters are loaded, although the difference is small.
--------------------------------------------------------------------------------
/maxim_pytorch/jax2torch.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | #convert pretrained Jax params of MAXIM to Pytorch
4 | import argparse
5 | import collections
6 | import io
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | import tensorflow as tf
12 | from maxim_torch import MAXIM_dns_3s
13 |
14 |
def recover_tree(keys, values):
    """Rebuild a nested dict from flat, '/'-separated names and their values.

    This is useful for inspecting saved checkpoints without access to the
    exact source code of the experiment, e.g. to extract and reuse a
    subtree of the checkpoint such as the parameters.

    Args:
        keys: a list of keys, where '/' is used as separator between nodes.
        values: a list of leaf values.
    Returns:
        A nested tree-like dict.
    """
    tree = {}
    grouped = collections.defaultdict(list)
    for name, leaf in zip(keys, values):
        head, separator, remainder = name.partition("/")
        if separator:
            # Defer: gather everything below ``head`` and recurse later.
            grouped[head].append((remainder, leaf))
        else:
            tree[name] = leaf
    for head, children in grouped.items():
        child_keys, child_values = zip(*children)
        tree[head] = recover_tree(child_keys, child_values)
    return tree
39 |
def get_params(ckpt_path):
    """Load a checkpoint file and return its ``opt/target`` subtree.

    The file is read through ``tf.io.gfile``, decoded with ``np.load`` and
    its flat entries are turned into a nested dict by ``recover_tree``.

    Args:
        ckpt_path (str): Path of the checkpoint file.

    Returns:
        The nested dict stored under ``["opt"]["target"]``.
    """
    with tf.io.gfile.GFile(ckpt_path, "rb") as handle:
        raw_bytes = handle.read()
    archive = np.load(io.BytesIO(raw_bytes))
    names_and_arrays = zip(*archive.items())
    tree = recover_tree(*names_and_arrays)
    return tree["opt"]["target"]
48 |
def modify_jax_params(flat_jax_dict):
    """Convert flat JAX parameters into PyTorch-style names and layouts.

    Args:
        flat_jax_dict (dict): Maps '/'-separated parameter names to
            array-like values.

    Returns:
        dict: Maps '.'-separated names to ``torch.float`` tensors.
    """
    converted = {}
    for jax_name, array in flat_jax_dict.items():
        parts = jax_name.split("/")
        tensor = torch.tensor(array, dtype=torch.float)

        # --- values ---
        rank = len(tensor.shape)
        if rank == 1:
            tensor = tensor.squeeze()
        elif rank == 2 and parts[-1] == 'kernel':
            # 2-D kernels are transposed.
            tensor = tensor.T
        elif rank == 4 and parts[-1] == 'kernel':
            # 4-D kernels: the last two axes are moved to the front.
            tensor = tensor.permute(3, 2, 0, 1)
        if rank == 4 and parts[-2] == 'ConvTranspose_0' and parts[-1] == 'kernel':
            # ConvTranspose_0 kernels additionally swap their first two axes.
            tensor = tensor.permute(1, 0, 2, 3)

        # --- names ---
        torch_name = ".".join(parts)
        torch_name = torch_name.replace("kernel", "weight")
        if "LayerNorm" in torch_name or "layernorm" in torch_name:
            torch_name = torch_name.replace("scale", "gamma")
            torch_name = torch_name.replace("bias", "beta")

        converted[torch_name] = tensor

    return converted
83 |
84 |
def main(args):
    """Convert a JAX MAXIM checkpoint and save it as a PyTorch state dict.

    Args:
        args: Parsed command line arguments providing ``ckpt_path`` and
            ``output_file``.
    """
    jax_tree = get_params(args.ckpt_path)
    # Flatten the nested parameters into a single record keyed by
    # '/'-separated names.
    records = pd.json_normalize(jax_tree, sep="/").to_dict(orient="records")
    (flat_jax_dict,) = records

    # Amend the JAX variables to match the names of the torch variables.
    torch_params = modify_jax_params(flat_jax_dict)

    # Overlay the converted parameters on a fresh model's state dict and
    # save the result.
    state = MAXIM_dns_3s().state_dict()
    state.update(torch_params)
    torch.save(state, args.output_file)
97 |
def parse_args():
    """Parse the command line options of the conversion script.

    Returns:
        argparse.Namespace: With ``ckpt_path`` and ``output_file``.
    """
    parser = argparse.ArgumentParser(
        description="Conversion of the JAX pre-trained MAXIM weights to Pytorch."
    )
    # (flags, default value, help text) for each string option.
    option_specs = (
        (("-c", "--ckpt_path"),
         "maxim_ckpt_Denoising_SIDD_checkpoint.npz",
         "Checkpoint to port."),
        (("-o", "--output_file"),
         "torch_weight.pth",
         "Output."),
    )
    for flags, default, help_text in option_specs:
        parser.add_argument(*flags, default=default, type=str, help=help_text)
    return parser.parse_args()
117 |
118 |
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the conversion.
    main(parse_args())
122 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore =
3 | # line break before binary operator (W503)
4 | W503,
5 | # line break after binary operator (W504)
6 | W504,
7 | max-line-length=79
8 |
9 | [yapf]
10 | based_on_style = pep8
11 | blank_line_before_nested_class_or_def = true
12 | split_before_expression_after_opening_paren = true
13 |
14 | [isort]
15 | line_length = 79
16 | multi_line_output = 0
17 | known_standard_library = pkg_resources,setuptools
18 | known_first_party = basicsr
19 | known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml
20 | no_lines_before = STDLIB,LOCALFOLDER
21 | default_section = THIRDPARTY
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import find_packages, setup
4 |
5 | import os
6 | import subprocess
7 | import sys
8 | import time
9 | import torch
10 | from torch.utils.cpp_extension import (BuildExtension, CppExtension,
11 | CUDAExtension)
12 |
13 | version_file = 'basicsr/version.py'
14 |
15 |
def readme():
    """Return the package long description (currently an empty string)."""
    # Reading README.md is disabled; the former implementation was:
    #     with open('README.md', encoding='utf-8') as f:
    #         content = f.read()
    #     return content
    long_description = ''
    return long_description
21 |
22 |
def get_git_hash():
    """Return the output of ``git rev-parse HEAD`` as a string.

    Returns:
        str: The stripped, ASCII-decoded command output, or ``'unknown'``
        when the command cannot be started (``OSError``).
    """

    def _minimal_ext_cmd(cmd):
        # Build a minimal environment: keep only a few variables and force
        # the C locale.
        env = {
            key: os.environ[key]
            for key in ('SYSTEMROOT', 'PATH', 'HOME')
            if os.environ.get(key) is not None
        }
        # LANGUAGE is used on win32
        env.update(LANGUAGE='C', LANG='C', LC_ALL='C')
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env)
        stdout, _ = process.communicate()
        return stdout

    try:
        raw = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
    except OSError:
        return 'unknown'
    return raw.strip().decode('ascii')
47 |
48 |
def get_hash():
    """Return a short revision identifier for the current source tree.

    Returns:
        str: The first seven characters of the git hash when a ``.git``
        entry exists in the working directory; otherwise the part after
        ``+`` in ``basicsr.version.__version__`` when the version file
        exists; otherwise ``'unknown'``.

    Raises:
        ImportError: If the version file exists but cannot be imported.
    """
    if os.path.exists('.git'):
        return get_git_hash()[:7]
    if not os.path.exists(version_file):
        return 'unknown'
    try:
        from basicsr.version import __version__
    except ImportError:
        raise ImportError('Unable to get git version')
    return __version__.split('+')[-1]
62 |
63 |
def write_version_py():
    """Generate the version module from the ``VERSION`` file.

    The short version is read from ``VERSION`` in the working directory,
    combined with the hash from ``get_hash()`` and written to
    ``version_file`` together with a timestamp.
    """
    template = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
short_version = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('VERSION', 'r') as f:
        short_version = f.read().strip()

    # Numeric components stay bare, anything else is double-quoted.
    components = []
    for piece in short_version.split('.'):
        components.append(piece if piece.isdigit() else f'"{piece}"')
    version_info = ', '.join(components)
    full_version = short_version + '+' + sha

    rendered = template.format(time.asctime(), full_version, short_version,
                               version_info)
    with open(version_file, 'w') as f:
        f.write(rendered)
82 |
83 |
def get_version():
    """Return ``__version__`` as defined in the generated version file.

    Returns:
        str: The value assigned to ``__version__`` in ``version_file``.

    Raises:
        KeyError: If the version file does not define ``__version__``.
    """
    # Execute the file in an explicit namespace. Relying on exec() writing
    # into the function's locals and reading them back through locals()
    # raises KeyError on Python 3.13+ (PEP 667).
    namespace = {}
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__version__']
88 |
89 |
def make_cuda_ext(name, module, sources, sources_cuda=None):
    """Create the extension object for one compiled op.

    A ``CUDAExtension`` is built when ``torch.cuda.is_available()`` is true
    or the environment variable ``FORCE_CUDA`` equals ``'1'``; otherwise a
    ``CppExtension`` is built and ``sources_cuda`` is not used.

    Args:
        name (str): Extension name, appended to ``module``.
        module (str): Dotted package path; also used as the directory that
            the source paths are relative to.
        sources (list[str]): Source files used by both variants. The list
            is not modified.
        sources_cuda (list[str], optional): Extra sources for the CUDA
            variant. Default: None.

    Returns:
        The ``CUDAExtension`` or ``CppExtension`` instance.
    """
    if sources_cuda is None:
        sources_cuda = []
    define_macros = []
    extra_compile_args = {'cxx': []}
    # Work on a copy: extending ``sources`` in place would silently change
    # the caller's list.
    all_sources = list(sources)

    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        all_sources += sources_cuda
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in all_sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)
114 |
115 |
def get_requirements(filename='requirements.txt'):
    """Return the install requirements (currently always an empty list).

    Args:
        filename (str): Requirements file name; unused while reading the
            file is disabled. Default: 'requirements.txt'.

    Returns:
        list: An empty list.
    """
    # Reading the requirements file is disabled; the former implementation:
    #     here = os.path.dirname(os.path.realpath(__file__))
    #     with open(os.path.join(here, filename), 'r') as f:
    #         requires = [line.replace('\n', '') for line in f.readlines()]
    #     return requires
    requirements = []
    return requirements
122 |
123 |
if __name__ == '__main__':
    # ``--no_cuda_ext`` is a private flag: it disables the compiled ops and
    # is removed from argv before setuptools sees it.
    if '--no_cuda_ext' not in sys.argv:
        # (name, module, sources, sources_cuda) for each compiled op.
        ext_specs = [
            ('deform_conv_ext', 'basicsr.models.ops.dcn',
             ['src/deform_conv_ext.cpp'],
             ['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
            ('fused_act_ext', 'basicsr.models.ops.fused_act',
             ['src/fused_bias_act.cpp'],
             ['src/fused_bias_act_kernel.cu']),
            ('upfirdn2d_ext', 'basicsr.models.ops.upfirdn2d',
             ['src/upfirdn2d.cpp'],
             ['src/upfirdn2d_kernel.cu']),
        ]
        ext_modules = [
            make_cuda_ext(
                name=ext_name,
                module=ext_module,
                sources=ext_sources,
                sources_cuda=ext_sources_cuda)
            for ext_name, ext_module, ext_sources, ext_sources_cuda
            in ext_specs
        ]
    else:
        ext_modules = []
        sys.argv.remove('--no_cuda_ext')

    # The version file must exist before get_version() reads it.
    write_version_py()

    excluded_packages = ('options', 'datasets', 'experiments', 'results',
                         'tb_logger', 'wandb')
    classifiers = [
        'Development Status :: 4 - Beta',
        'License :: OSI Approved :: Apache Software License',
        'Operating System :: OS Independent',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
    ]
    setup(
        name='basicsr',
        version=get_version(),
        description='Open Source Image and Video Super-Resolution Toolbox',
        long_description=readme(),
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, restoration, super resolution',
        url='https://github.com/xinntao/BasicSR',
        packages=find_packages(exclude=excluded_packages),
        classifiers=classifiers,
        license='Apache License 2.0',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        cmdclass={'build_ext': BuildExtension},
        zip_safe=False)
177 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch basicsr/train.py through torch.distributed.launch with 8 processes
# per node.
#
# Usage: train.sh <config.yml>

CONFIG=$1

# Fail early with a clear message instead of passing an empty -opt value.
if [ -z "$CONFIG" ]; then
    echo "Usage: $0 <config.yml>" >&2
    exit 1
fi

# "$CONFIG" is quoted so that paths containing spaces or glob characters
# reach train.py as a single argument.
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt "$CONFIG" --launcher pytorch
--------------------------------------------------------------------------------