├── Dehaze
│   ├── Options
│   │   └── RealDehazing_FPro.yml
│   ├── evaluate_SOTS.py
│   ├── test_SOTS.py
│   └── utils.py
├── Demoiring
│   ├── Options
│   │   └── RealDemoiring_FPro.yml
│   ├── dataset_demoire.py
│   ├── evaluate_demoire.py
│   ├── test_moire.py
│   └── utils.py
├── Deraining
│   ├── Options
│   │   ├── Deraining_FPro_spad.yml
│   │   └── RealDeraindrop_FPro.yml
│   ├── evaluate_PSNR_SSIM.m
│   ├── evaluate_raindrop.py
│   ├── test_AGAN.py
│   ├── test_spad.py
│   └── utils.py
├── INSTALL.md
├── Motion_Deblurring
│   ├── Options
│   │   └── Deblurring_FPro.yml
│   ├── evaluate_gopro_hide.m
│   ├── generate_patches_gopro.py
│   ├── test_FPro.py
│   └── utils.py
├── README.md
├── basicsr
│   ├── .DS_Store
│   ├── __pycache__
│   │   └── version.cpython-37.pyc
│   ├── data
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── data_sampler.cpython-37.pyc
│   │   │   ├── data_util.cpython-37.pyc
│   │   │   ├── ffhq_dataset.cpython-37.pyc
│   │   │   ├── paired_image_dataset.cpython-37.pyc
│   │   │   ├── prefetch_dataloader.cpython-37.pyc
│   │   │   ├── reds_dataset.cpython-37.pyc
│   │   │   ├── single_image_dataset.cpython-37.pyc
│   │   │   ├── transforms.cpython-37.pyc
│   │   │   ├── video_test_dataset.cpython-37.pyc
│   │   │   └── vimeo90k_dataset.cpython-37.pyc
│   │   ├── data_sampler.py
│   │   ├── data_util.py
│   │   ├── ffhq_dataset.py
│   │   ├── paired_image_dataset.py
│   │   ├── prefetch_dataloader.py
│   │   ├── reds_dataset.py
│   │   ├── single_image_dataset.py
│   │   ├── transforms.py
│   │   ├── video_test_dataset.py
│   │   └── vimeo90k_dataset.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── metric_util.cpython-37.pyc
│   │   │   ├── niqe.cpython-37.pyc
│   │   │   └── psnr_ssim.cpython-37.pyc
│   │   ├── fid.py
│   │   ├── metric_util.py
│   │   ├── niqe.py
│   │   ├── niqe_pris_params.npz
│   │   └── psnr_ssim.py
│   ├── models
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── base_model.cpython-37.pyc
│   │   │   ├── image_restoration_model.cpython-37.pyc
│   │   │   └── lr_scheduler.cpython-37.pyc
│   │   ├── archs
│   │   │   ├── FPro_arch.py
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── arch_util.cpython-37.pyc
│   │   │   │   ├── graph_layers.cpython-37.pyc
│   │   │   │   └── local_arch.cpython-37.pyc
│   │   │   └── arch_util.py
│   │   ├── base_model.py
│   │   ├── image_restoration_model.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── loss_util.cpython-37.pyc
│   │   │   │   └── losses.cpython-37.pyc
│   │   │   ├── loss_util.py
│   │   │   └── losses.py
│   │   └── lr_scheduler.py
│   ├── test.py
│   ├── train.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── create_lmdb.cpython-37.pyc
│   │   │   ├── dist_util.cpython-37.pyc
│   │   │   ├── file_client.cpython-37.pyc
│   │   │   ├── flow_util.cpython-37.pyc
│   │   │   ├── img_util.cpython-37.pyc
│   │   │   ├── lmdb_util.cpython-37.pyc
│   │   │   ├── logger.cpython-37.pyc
│   │   │   ├── matlab_functions.cpython-37.pyc
│   │   │   ├── misc.cpython-37.pyc
│   │   │   └── options.cpython-37.pyc
│   │   ├── bundle_submissions.py
│   │   ├── create_lmdb.py
│   │   ├── dist_util.py
│   │   ├── download_util.py
│   │   ├── face_util.py
│   │   ├── file_client.py
│   │   ├── flow_util.py
│   │   ├── img_util.py
│   │   ├── lmdb_util.py
│   │   ├── logger.py
│   │   ├── matlab_functions.py
│   │   ├── misc.py
│   │   └── options.py
│   └── version.py
├── setup.py
├── test.sh
└── train.sh
/Dehaze/Options/RealDehazing_FPro.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Dehazing_FPro
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage_dehazeSOT
13 | dataroot_gt: /mnt/sda/zsh/dataset/haze
14 | dataroot_lq: /mnt/sda/zsh/dataset/haze
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ## ------- Training on single fixed-patch size 256x256 ---------
27 | mini_batch_sizes: [2]
28 | iters: [300000]
29 | gt_size: 256
30 | gt_sizes: [256]
31 | ## ------------------------------------------------------------
32 |
33 | dataset_enlarge_ratio: 1
34 | prefetch_mode: ~
35 |
36 | val:
37 | name: ValSet
38 | type: Dataset_PairedImage_dehazeSOT
39 | dataroot_gt: /mnt/sda/zsh/dataset/haze
40 | dataroot_lq: /mnt/sda/zsh/dataset/haze
41 | gt_size: 256
42 | io_backend:
43 | type: disk
44 |
45 | # network structures
46 |
47 | network_g:
48 | type: FPro
49 | inp_channels: 3
50 | out_channels: 3
51 | # input_res: 128
52 | dim: 48
53 | # num_blocks: [4,6,6,8]
54 | num_blocks: [2,3,6]
55 | # num_refinement_blocks: 4
56 | num_refinement_blocks: 2
57 | # heads: [1,2,4,8]
58 | heads: [2,4,8]
59 | # ffn_expansion_factor: 2.66
60 | ffn_expansion_factor: 3
61 | bias: False
62 | LayerNorm_type: WithBias
63 | dual_pixel_task: False
64 |
65 |
66 | # path
67 | path:
68 | pretrain_network_g: ~
69 | strict_load_g: true
70 | resume_state: ~
71 |
72 | # training settings
73 | train:
74 | total_iter: 300000
75 | warmup_iter: -1 # no warm up
76 | use_grad_clip: true
77 |
78 | # Split 300k iterations into two cycles.
79 | # 1st cycle: fixed 3e-4 LR for 92k iters.
80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
81 | scheduler:
82 | type: CosineAnnealingRestartCyclicLR
83 | periods: [92000, 208000]
84 | restart_weights: [1,1]
85 | eta_mins: [0.0003,0.000001]
86 |
87 | mixing_augs:
88 | mixup: true
89 | mixup_beta: 1.2
90 | use_identity: true
91 |
92 | optim_g:
93 | type: AdamW
94 | lr: !!float 3e-4
95 | weight_decay: !!float 1e-4
96 | betas: [0.9, 0.999]
97 |
98 | # losses
99 | pixel_opt:
100 | type: L1Loss
101 | loss_weight: 1
102 | reduction: mean
103 | fft_loss_opt:
104 | type: FFTLoss
105 | loss_weight: 0.1
106 | reduction: mean
107 |
108 | # validation settings
109 | val:
110 | window_size: 8
111 | val_freq: !!float 4e3
112 | save_img: false
113 | rgb2bgr: true
114 | use_image: false
115 | max_minibatch: 8
116 |
117 | metrics:
118 | psnr: # metric name, can be arbitrary
119 | type: calculate_psnr
120 | crop_border: 0
121 | test_y_channel: false
122 |
123 | # logging settings
124 | logger:
125 | print_freq: 1000
126 | save_checkpoint_freq: !!float 4e3
127 | use_tb_logger: true
128 | wandb:
129 | project: ~
130 | resume_id: ~
131 |
132 | # dist training settings
133 | dist_params:
134 | backend: nccl
135 | port: 29500
136 |
--------------------------------------------------------------------------------
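The `scheduler` comments above describe a two-cycle learning-rate plan: with `eta_mins: [0.0003, 0.000001]`, the first cosine cycle anneals from 3e-4 to 3e-4 (i.e., stays flat for 92k iterations) and the second anneals from 3e-4 down to 1e-6 over 208k. Below is a minimal standalone sketch of that piecewise curve, for illustration only (it is not the BasicSR `CosineAnnealingRestartCyclicLR` implementation):

```python
import math

def lr_at(it, base_lr=3e-4, periods=(92_000, 208_000), eta_mins=(3e-4, 1e-6)):
    """Cosine annealing restarted per cycle; cycle 1 is flat since eta_min == base_lr."""
    if it < periods[0]:
        t, T, eta = it, periods[0], eta_mins[0]
    else:
        t, T, eta = it - periods[0], periods[1], eta_mins[1]
    return eta + 0.5 * (base_lr - eta) * (1 + math.cos(math.pi * t / T))

for it in (0, 50_000, 92_000, 200_000, 299_999):
    print(f"iter {it:>6}: lr = {lr_at(it):.2e}")
```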
/Dehaze/evaluate_SOTS.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import os
6 | import numpy as np
7 | from glob import glob
8 | from natsort import natsorted
9 | from skimage import io
10 | import cv2
11 | import argparse
12 | from skimage.metrics import structural_similarity
13 | from tqdm import tqdm
14 | import concurrent.futures
15 | import utils
16 |
17 | def proc(filename):
18 | tar,prd = filename  # tar is unused; the GT path is rebuilt from the prediction name below
19 | prd_name = prd.split('/')[-1].split('_')[0]+'.png'  # keep the prefix before the first '_' as the GT name
20 | tar_name = '/mnt/sda/zsh/dataset/haze/promptIR/outdoor/gt/' + prd_name
21 | # print('tar',tar)
22 | # print('prd',prd)
23 | tar_img = utils.load_img(tar_name)
24 | prd_img = utils.load_img(prd)
25 |
26 | PSNR = utils.calculate_psnr(tar_img, prd_img)
27 | SSIM = utils.calculate_ssim(tar_img, prd_img)
28 | return PSNR,SSIM
29 |
30 | parser = argparse.ArgumentParser(description='Dehazing using FPro')
31 |
32 | args = parser.parse_args()
33 |
34 |
35 | datasets = ['outdoor']
36 |
37 | for dataset in datasets:
38 |
39 | gt_path = os.path.join('/mnt/sda/zsh/dataset/haze/promptIR/outdoor/gt')
40 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif')))
41 | assert len(gt_list) != 0, "Target files not found"
42 |
43 |
44 | file_path = os.path.join('results', 'FPro', dataset)
45 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif')))
46 | assert len(path_list) != 0, "Predicted files not found"
47 |
48 | psnr, ssim = [], []
49 | img_files =[(i, j) for i,j in zip(gt_list,path_list)]
50 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
51 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
52 | psnr.append(PSNR_SSIM[0])
53 | ssim.append(PSNR_SSIM[1])
54 |
55 | avg_psnr = sum(psnr)/len(psnr)
56 | avg_ssim = sum(ssim)/len(ssim)
57 |
58 | # print('For {:s} dataset PSNR: {:f}\n'.format(dataset, avg_psnr))
59 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
60 |
--------------------------------------------------------------------------------
/Dehaze/test_SOTS.py:
--------------------------------------------------------------------------------
1 | ## Seeing the Unseen: A Frequency Prompt Guided Transformer for Image Restoration
2 |
3 | import numpy as np
4 | import os
5 | import argparse
6 | from tqdm import tqdm
7 | import math
8 | import torch.nn as nn
9 | import torch
10 | import torch.nn.functional as F
11 | import utils
12 |
13 | from natsort import natsorted
14 | from glob import glob
15 | from basicsr.models.archs.FPro_arch import FPro
16 | from skimage import img_as_ubyte
17 | from pdb import set_trace as stx
18 |
19 | parser = argparse.ArgumentParser(description='Image Dehazing using FPro')
20 |
21 | parser.add_argument('--input_dir', default='/mnt/sda/zsh/dataset/haze/promptIR/', type=str, help='Directory of validation images')
22 | parser.add_argument('--result_dir', default='./results/FPro/', type=str, help='Directory for results')
23 | parser.add_argument('--weights', default='/mnt/sda/zsh/FPro/Dehaze/models/synDehaze.pth', type=str, help='Path to weights')
24 |
25 | args = parser.parse_args()
26 |
27 | def splitimage(imgtensor, crop_size=128, overlap_size=64):
28 | _, C, H, W = imgtensor.shape
29 | hstarts = [x for x in range(0, H, crop_size - overlap_size)]
30 | while hstarts and hstarts[-1] + crop_size >= H:
31 | hstarts.pop()
32 | hstarts.append(H - crop_size)
33 | wstarts = [x for x in range(0, W, crop_size - overlap_size)]
34 | while wstarts and wstarts[-1] + crop_size >= W:
35 | wstarts.pop()
36 | wstarts.append(W - crop_size)
37 | starts = []
38 | split_data = []
39 | for hs in hstarts:
40 | for ws in wstarts:
41 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
42 | starts.append((hs, ws))
43 | split_data.append(cimgdata)
44 | return split_data, starts
45 |
46 | def get_scoremap(H, W, C, B=1, is_mean=True):
47 | center_h = H / 2
48 | center_w = W / 2
49 |
50 | score = torch.ones((B, C, H, W))
51 | if not is_mean:
52 | for h in range(H):
53 | for w in range(W):
54 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
55 | return score
56 |
57 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)):
58 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
59 | tot_score = torch.zeros((B, C, H, W))
60 | merge_img = torch.zeros((B, C, H, W))
61 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
62 | for simg, cstart in zip(split_data, starts):
63 | hs, ws = cstart
64 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
65 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
66 | merge_img = merge_img / tot_score
67 | return merge_img
68 |
69 | ####### Load yaml #######
70 | yaml_file = 'Options/RealDehazing_FPro.yml'
71 | import yaml
72 |
73 | try:
74 | from yaml import CLoader as Loader
75 | except ImportError:
76 | from yaml import Loader
77 |
78 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
79 |
80 | s = x['network_g'].pop('type')  # drop the registry key; FPro is constructed directly below
81 | ##########################
82 |
83 | model_restoration = FPro(**x['network_g'])
84 |
85 | checkpoint = torch.load(args.weights)
86 | model_restoration.load_state_dict(checkpoint['params'])
87 | print("===>Testing using weights: ",args.weights)
88 | model_restoration.cuda()
89 | model_restoration = nn.DataParallel(model_restoration)
90 | model_restoration.eval()
91 |
92 |
93 | factor = 8
94 | datasets = ['outdoor']
95 |
96 | for dataset in datasets:
97 | result_dir = os.path.join(args.result_dir, dataset)
98 | os.makedirs(result_dir, exist_ok=True)
99 |
100 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain')
101 | inp_dir = os.path.join(args.input_dir, dataset, 'hazy/')
102 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
103 | with torch.no_grad():
104 | for file_ in tqdm(files):
105 | torch.cuda.ipc_collect()
106 | torch.cuda.empty_cache()
107 |
108 | img = np.float32(utils.load_img(file_))/255.
109 | img = torch.from_numpy(img).permute(2,0,1)
110 | input_ = img.unsqueeze(0).cuda()
111 |
112 | # Padding in case images are not multiples of 8
113 | h,w = input_.shape[2], input_.shape[3]
114 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
115 | padh = H-h if h%factor!=0 else 0
116 | padw = W-w if w%factor!=0 else 0
117 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
118 |
119 | B, C, H, W = input_.shape
120 | crop_size_arg = 256
121 | overlap_size_arg = 158
122 | # crop_size_arg = 512
123 | # overlap_size_arg = 204
124 | split_data, starts = splitimage(input_, crop_size=crop_size_arg, overlap_size=overlap_size_arg)
125 | for i, data in enumerate(split_data):
126 | split_data[i] = model_restoration(data).cpu()
127 | restored = mergeimage(split_data, starts, crop_size=crop_size_arg, resolution=(B, C, H, W))
128 | # rgb_restored = torch.clamp(restored, 0, 1).permute(0, 2, 3, 1).numpy()
129 |
130 | # restored = rgb_restored
131 | # restored = model_restoration(input_)
132 |
133 | # Unpad images to original dimensions
134 | restored = restored[:,:,:h,:w]
135 |
136 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
137 |
138 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
139 |
--------------------------------------------------------------------------------
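`splitimage`/`mergeimage` above implement overlapping-tile inference: the padded image is cut into 256x256 crops with 158 pixels of overlap, each crop is restored independently, and the crops are blended back with a uniform score map (`is_mean=True`), i.e., a per-pixel average. Here is a simplified standalone sketch of the same idea with an identity "model" so the round trip can be checked; the tile-placement rule is a plain stride, slightly simpler than the script's:

```python
import torch

def tile_starts(size, crop, overlap):
    stride = crop - overlap
    starts = list(range(0, size - crop, stride))
    starts.append(size - crop)  # keep the last tile flush with the border
    return starts

x = torch.rand(1, 3, 300, 460)          # stand-in for the padded input
crop, overlap = 256, 158
out = torch.zeros_like(x)
weight = torch.zeros_like(x)
for hs in tile_starts(x.shape[2], crop, overlap):
    for ws in tile_starts(x.shape[3], crop, overlap):
        tile = x[:, :, hs:hs+crop, ws:ws+crop]  # the restoration model would run here
        out[:, :, hs:hs+crop, ws:ws+crop] += tile
        weight[:, :, hs:hs+crop, ws:ws+crop] += 1.0
out = out / weight                       # per-pixel average over overlapping tiles
assert torch.allclose(out, x)            # identity model reconstructs the input
```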
/Dehaze/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
10 | def calculate_psnr(img1, img2, border=0):
11 | # img1 and img2 have range [0, 255]
12 | #img1 = img1.squeeze()
13 | #img2 = img2.squeeze()
14 | if not img1.shape == img2.shape:
15 | raise ValueError('Input images must have the same dimensions.')
16 | h, w = img1.shape[:2]
17 | img1 = img1[border:h-border, border:w-border]
18 | img2 = img2[border:h-border, border:w-border]
19 |
20 | img1 = img1.astype(np.float64)
21 | img2 = img2.astype(np.float64)
22 | mse = np.mean((img1 - img2)**2)
23 | if mse == 0:
24 | return float('inf')
25 | return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
31 | def calculate_ssim(img1, img2, border=0):
32 | '''calculate SSIM
33 | the same outputs as MATLAB's
34 | img1, img2: [0, 255]
35 | '''
36 | #img1 = img1.squeeze()
37 | #img2 = img2.squeeze()
38 | if not img1.shape == img2.shape:
39 | raise ValueError('Input images must have the same dimensions.')
40 | h, w = img1.shape[:2]
41 | img1 = img1[border:h-border, border:w-border]
42 | img2 = img2[border:h-border, border:w-border]
43 |
44 | if img1.ndim == 2:
45 | return ssim(img1, img2)
46 | elif img1.ndim == 3:
47 | if img1.shape[2] == 3:
48 | ssims = []
49 | for i in range(3):
50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
51 | return np.array(ssims).mean()
52 | elif img1.shape[2] == 1:
53 | return ssim(np.squeeze(img1), np.squeeze(img2))
54 | else:
55 | raise ValueError('Wrong input image dimensions.')
56 |
57 |
58 | def ssim(img1, img2):
59 | C1 = (0.01 * 255)**2
60 | C2 = (0.03 * 255)**2
61 |
62 | img1 = img1.astype(np.float64)
63 | img2 = img2.astype(np.float64)
64 | kernel = cv2.getGaussianKernel(11, 1.5)
65 | window = np.outer(kernel, kernel.transpose())
66 |
67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
69 | mu1_sq = mu1**2
70 | mu2_sq = mu2**2
71 | mu1_mu2 = mu1 * mu2
72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 |
76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 | (sigma1_sq + sigma2_sq + C2))
78 | return ssim_map.mean()
79 |
80 | def load_img(filepath):
81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 |
83 | def save_img(filepath, img):
84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 |
86 | def load_gray_img(filepath):
87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 |
89 | def save_gray_img(filepath, img):
90 | cv2.imwrite(filepath, img)
91 |
--------------------------------------------------------------------------------
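For reference, `calculate_psnr` and `ssim` above follow the standard definitions on [0, 255] images, using an 11x11 Gaussian window (sigma 1.5) and the MATLAB-style constants:

```latex
\mathrm{PSNR}(I_1, I_2) = 20\log_{10}\frac{255}{\sqrt{\mathrm{MSE}}},\qquad
\mathrm{MSE} = \frac{1}{HW}\sum_{i,j}\bigl(I_1(i,j)-I_2(i,j)\bigr)^2

\mathrm{SSIM}(x,y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}
{(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)},\qquad
C_1=(0.01\cdot 255)^2,\; C_2=(0.03\cdot 255)^2
```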
/Demoiring/Options/RealDemoiring_FPro.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: RealDemoiring_FPro
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage_denseHaze
13 | dataroot_gt: /home/ubuntu/zsh/datasets/TIP18/process/train/thin_target
14 | dataroot_lq: /home/ubuntu/zsh/datasets/TIP18/process/train/thin_source
15 | geometric_augs: False
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ## ------- Training on single fixed-patch size 256x256 ---------
27 | mini_batch_sizes: [2]
28 | iters: [300000]
29 | gt_size: 256
30 | gt_sizes: [256]
31 | ## ------------------------------------------------------------
32 |
33 | dataset_enlarge_ratio: 1
34 | prefetch_mode: ~
35 |
36 | val:
37 | name: ValSet
38 | type: Dataset_PairedImage_denseHaze
39 | dataroot_gt: /home/ubuntu/zsh/datasets/TIP18/process/val/thin_target
40 | dataroot_lq: /home/ubuntu/zsh/datasets/TIP18/process/val/thin_source
41 | gt_size: 256
42 | io_backend:
43 | type: disk
44 |
45 | # network structures
46 |
47 | network_g:
48 | type: FPro
49 | inp_channels: 3
50 | out_channels: 3
51 | # input_res: 128
52 | dim: 48
53 | # num_blocks: [4,6,6,8]
54 | num_blocks: [2,3,6]
55 | # num_refinement_blocks: 4
56 | num_refinement_blocks: 2
57 | # heads: [1,2,4,8]
58 | heads: [2,4,8]
59 | # ffn_expansion_factor: 2.66
60 | ffn_expansion_factor: 3
61 | bias: False
62 | LayerNorm_type: WithBias
63 | dual_pixel_task: False
64 |
65 |
66 | # path
67 | path:
68 | pretrain_network_g: ~
69 | strict_load_g: true
70 | resume_state: ~
71 |
72 | # training settings
73 | train:
74 | total_iter: 300000
75 | warmup_iter: -1 # no warm up
76 | use_grad_clip: true
77 |
78 | # Split 300k iterations into two cycles.
79 | # 1st cycle: fixed 3e-4 LR for 92k iters.
80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
81 | scheduler:
82 | type: CosineAnnealingRestartCyclicLR
83 | periods: [92000, 208000]
84 | restart_weights: [1,1]
85 | eta_mins: [0.0003,0.000001]
86 |
87 | mixing_augs:
88 | mixup: true
89 | mixup_beta: 1.2
90 | use_identity: true
91 |
92 | optim_g:
93 | type: AdamW
94 | lr: !!float 3e-4
95 | weight_decay: !!float 1e-4
96 | betas: [0.9, 0.999]
97 |
98 | # losses
99 | pixel_opt:
100 | type: L1Loss
101 | loss_weight: 1
102 | reduction: mean
103 | fft_loss_opt:
104 | type: FFTLoss
105 | loss_weight: 0.1
106 | reduction: mean
107 |
108 | # validation settings
109 | val:
110 | window_size: 8
111 | val_freq: !!float 4e3
112 | save_img: false
113 | rgb2bgr: true
114 | use_image: false
115 | max_minibatch: 8
116 |
117 | metrics:
118 | psnr: # metric name, can be arbitrary
119 | type: calculate_psnr
120 | crop_border: 0
121 | test_y_channel: false
122 |
123 | # logging settings
124 | logger:
125 | print_freq: 1000
126 | save_checkpoint_freq: !!float 4e3
127 | use_tb_logger: true
128 | wandb:
129 | project: ~
130 | resume_id: ~
131 |
132 | # dist training settings
133 | dist_params:
134 | backend: nccl
135 | port: 29500
136 |
--------------------------------------------------------------------------------
/Demoiring/evaluate_demoire.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import os
6 | import numpy as np
7 | from glob import glob
8 | from natsort import natsorted
9 | from skimage import io
10 | import cv2
11 | import argparse
12 | from skimage.metrics import structural_similarity
13 | from tqdm import tqdm
14 | import concurrent.futures
15 | import utils
16 |
17 | def proc(filename):
18 | tar,prd = filename
19 | tar_img = utils.load_img(tar)
20 | prd_img = utils.load_img(prd)
21 |
22 | PSNR = utils.calculate_psnr(tar_img, prd_img)
23 | SSIM = utils.calculate_ssim(tar_img, prd_img)
24 | return PSNR,SSIM
25 |
26 | parser = argparse.ArgumentParser(description='Demoireing using FPro')
27 |
28 | args = parser.parse_args()
29 |
30 |
31 | datasets = ['TIP18']
32 |
33 | for dataset in datasets:
34 | #/home/ubuntu/zsh/datasets/TIP18/process/test_resize286_crop256/thin_target
35 | #/home/ubuntu/zsh/datasets/TIP18/process/test_256/thin_target
36 | gt_path = os.path.join('/mnt/sda/zsh/FPro/Demoiring/test_resize286_crop256/thin_target')
37 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif')))
38 | assert len(gt_list) != 0, "Target files not found"
39 |
40 |
41 | file_path = os.path.join('results/', 'FPro_test/', dataset)
42 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif')))
43 | assert len(path_list) != 0, "Predicted files not found"
44 |
45 | psnr, ssim = [], []
46 | img_files =[(i, j) for i,j in zip(gt_list,path_list)]
47 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
48 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
49 | psnr.append(PSNR_SSIM[0])
50 | ssim.append(PSNR_SSIM[1])
51 |
52 | avg_psnr = sum(psnr)/len(psnr)
53 | avg_ssim = sum(ssim)/len(ssim)
54 |
55 | print('For {:s} dataset PSNR: {:f}\n'.format(dataset, avg_psnr))
56 | print('For {:s} dataset SSIM: {:f}\n'.format(dataset, avg_ssim))
57 | # print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
58 |
--------------------------------------------------------------------------------
/Demoiring/test_moire.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import os
4 | import argparse
5 | from tqdm import tqdm
6 | import math
7 | import torch.nn as nn
8 | import torch
9 | import torch.nn.functional as F
10 | import utils
11 |
12 | from natsort import natsorted
13 | from glob import glob
14 | from basicsr.models.archs.FPro_arch import FPro
15 | from skimage import img_as_ubyte
16 | from pdb import set_trace as stx
17 |
18 | parser = argparse.ArgumentParser(description='Image Demoireing using FPro')
19 | #test_resize286_crop256 test_256
20 | parser.add_argument('--input_dir', default='/mnt/sda/zsh/FPro/Demoiring/test_resize286_crop256/thin_source', type=str, help='Directory of validation images')
21 | parser.add_argument('--result_dir', default='./results/FPro_test/', type=str, help='Directory for results')
22 | parser.add_argument('--weights', default='./models/demoire_noAug.pth', type=str, help='Path to weights')
23 |
24 | args = parser.parse_args()
25 |
26 | def splitimage(imgtensor, crop_size=128, overlap_size=64):
27 | _, C, H, W = imgtensor.shape
28 | hstarts = [x for x in range(0, H, crop_size - overlap_size)]
29 | while hstarts and hstarts[-1] + crop_size >= H:
30 | hstarts.pop()
31 | hstarts.append(H - crop_size)
32 | wstarts = [x for x in range(0, W, crop_size - overlap_size)]
33 | while wstarts and wstarts[-1] + crop_size >= W:
34 | wstarts.pop()
35 | wstarts.append(W - crop_size)
36 | starts = []
37 | split_data = []
38 | for hs in hstarts:
39 | for ws in wstarts:
40 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
41 | starts.append((hs, ws))
42 | split_data.append(cimgdata)
43 | return split_data, starts
44 |
45 | def get_scoremap(H, W, C, B=1, is_mean=True):
46 | center_h = H / 2
47 | center_w = W / 2
48 |
49 | score = torch.ones((B, C, H, W))
50 | if not is_mean:
51 | for h in range(H):
52 | for w in range(W):
53 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
54 | return score
55 |
56 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)):
57 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
58 | tot_score = torch.zeros((B, C, H, W))
59 | merge_img = torch.zeros((B, C, H, W))
60 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
61 | for simg, cstart in zip(split_data, starts):
62 | hs, ws = cstart
63 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
64 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
65 | merge_img = merge_img / tot_score
66 | return merge_img
67 |
68 | ####### Load yaml #######
69 | yaml_file = 'Options/RealDemoiring_FPro.yml'
70 | import yaml
71 |
72 | try:
73 | from yaml import CLoader as Loader
74 | except ImportError:
75 | from yaml import Loader
76 |
77 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
78 |
79 | s = x['network_g'].pop('type')
80 | ##########################
81 |
82 | model_restoration = FPro(**x['network_g'])
83 |
84 | checkpoint = torch.load(args.weights)
85 | model_restoration.load_state_dict(checkpoint['params'])
86 | print("===>Testing using weights: ",args.weights)
87 | model_restoration.cuda()
88 | model_restoration = nn.DataParallel(model_restoration)
89 | model_restoration.eval()
90 |
91 |
92 | factor = 8
93 | datasets = ['TIP18']
94 |
95 | for dataset in datasets:
96 | result_dir = os.path.join(args.result_dir, dataset)
97 | os.makedirs(result_dir, exist_ok=True)
98 |
99 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain')
100 | inp_dir = os.path.join(args.input_dir)
101 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
102 | with torch.no_grad():
103 | for file_ in tqdm(files):
104 | torch.cuda.ipc_collect()
105 | torch.cuda.empty_cache()
106 |
107 | img = np.float32(utils.load_img(file_))/255.
108 | img = torch.from_numpy(img).permute(2,0,1)
109 | input_ = img.unsqueeze(0).cuda()
110 |
111 | # Padding in case images are not multiples of 8
112 | h,w = input_.shape[2], input_.shape[3]
113 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
114 | padh = H-h if h%factor!=0 else 0
115 | padw = W-w if w%factor!=0 else 0
116 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
117 |
118 | restored = model_restoration(input_)
119 | restored = restored[:,:,:h,:w]
120 |
121 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
122 |
123 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
124 |
--------------------------------------------------------------------------------
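Unlike the tiled dehazing script, `test_moire.py` restores the whole padded image in one pass. The reflect-padding arithmetic above rounds each side up to the next multiple of 8; a small standalone check of that logic, with arbitrary example shapes:

```python
import torch
import torch.nn.functional as F

factor = 8
x = torch.rand(1, 3, 123, 250)                 # example shape, not a real moire image
h, w = x.shape[2], x.shape[3]
H = ((h + factor) // factor) * factor          # 123 -> 128
W = ((w + factor) // factor) * factor          # 250 -> 256
padh = H - h if h % factor != 0 else 0
padw = W - w if w % factor != 0 else 0
x = F.pad(x, (0, padw, 0, padh), 'reflect')    # pad right and bottom only
assert x.shape[2] % factor == 0 and x.shape[3] % factor == 0
```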
/Demoiring/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
10 | def calculate_psnr(img1, img2, border=0):
11 | # img1 and img2 have range [0, 255]
12 | #img1 = img1.squeeze()
13 | #img2 = img2.squeeze()
14 | if not img1.shape == img2.shape:
15 | raise ValueError('Input images must have the same dimensions.')
16 | h, w = img1.shape[:2]
17 | img1 = img1[border:h-border, border:w-border]
18 | img2 = img2[border:h-border, border:w-border]
19 |
20 | img1 = img1.astype(np.float64)
21 | img2 = img2.astype(np.float64)
22 | mse = np.mean((img1 - img2)**2)
23 | if mse == 0:
24 | return float('inf')
25 | return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
31 | def calculate_ssim(img1, img2, border=0):
32 | '''calculate SSIM
33 | the same outputs as MATLAB's
34 | img1, img2: [0, 255]
35 | '''
36 | #img1 = img1.squeeze()
37 | #img2 = img2.squeeze()
38 | if not img1.shape == img2.shape:
39 | raise ValueError('Input images must have the same dimensions.')
40 | h, w = img1.shape[:2]
41 | img1 = img1[border:h-border, border:w-border]
42 | img2 = img2[border:h-border, border:w-border]
43 |
44 | if img1.ndim == 2:
45 | return ssim(img1, img2)
46 | elif img1.ndim == 3:
47 | if img1.shape[2] == 3:
48 | ssims = []
49 | for i in range(3):
50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
51 | return np.array(ssims).mean()
52 | elif img1.shape[2] == 1:
53 | return ssim(np.squeeze(img1), np.squeeze(img2))
54 | else:
55 | raise ValueError('Wrong input image dimensions.')
56 |
57 |
58 | def ssim(img1, img2):
59 | C1 = (0.01 * 255)**2
60 | C2 = (0.03 * 255)**2
61 |
62 | img1 = img1.astype(np.float64)
63 | img2 = img2.astype(np.float64)
64 | kernel = cv2.getGaussianKernel(11, 1.5)
65 | window = np.outer(kernel, kernel.transpose())
66 |
67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
69 | mu1_sq = mu1**2
70 | mu2_sq = mu2**2
71 | mu1_mu2 = mu1 * mu2
72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 |
76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 | (sigma1_sq + sigma2_sq + C2))
78 | return ssim_map.mean()
79 |
80 | def load_img(filepath):
81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 |
83 | def save_img(filepath, img):
84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 |
86 | def load_gray_img(filepath):
87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 |
89 | def save_gray_img(filepath, img):
90 | cv2.imwrite(filepath, img)
91 |
--------------------------------------------------------------------------------
/Deraining/Options/Deraining_FPro_spad.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Deraining_FPro
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 2 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage_derainSpad
13 | dataroot_gt: /home/ubuntu/zsh/datasets/derain/real_world_gt
14 | dataroot_lq: /home/ubuntu/zsh/datasets/derain/real_world
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | # ### -------------Progressive training--------------------------
27 | # # mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
28 | # mini_batch_sizes: [6,4,3,1] # Batch size per gpu
29 | # # mini_batch_sizes: [20,16,12,8] # Batch size per gpu
30 | # # iters: [92000,64000,48000,36000,36000,24000]
31 | # iters: [152000,74000,48000,26000]
32 | # # gt_size: 384 # Max patch size for progressive training
33 | # gt_size: 256 # Max patch size for progressive training
34 | # gt_sizes: [128,160,192,256] # Patch sizes for progressive training.
35 | # ### ------------------------------------------------------------
36 |
37 | ### ------- Training on single fixed-patch size 128x128---------
38 | # mini_batch_sizes: [8]
39 | # iters: [300000]
40 | # gt_size: 128
41 | # gt_sizes: [128]
42 | ## ------------------------------------------------------------
43 | ## ------- Training on single fixed-patch size 256x256 ---------
44 | mini_batch_sizes: [2]
45 | iters: [300000]
46 | gt_size: 256
47 | gt_sizes: [256]
48 | ## ------------------------------------------------------------
49 |
50 | dataset_enlarge_ratio: 1
51 | prefetch_mode: ~
52 |
53 | val:
54 | name: ValSet
55 | type: Dataset_PairedImage_derainSpad
56 | dataroot_gt: /home/ubuntu/zsh/datasets/derain/real_test_1000/gt
57 | dataroot_lq: /home/ubuntu/zsh/datasets/derain/real_test_1000/rain
58 | gt_size: 256
59 | io_backend:
60 | type: disk
61 |
62 | # network structures
63 | network_g:
64 | type: FPro
65 | inp_channels: 3
66 | out_channels: 3
67 | # input_res: 128
68 | dim: 48
69 | # num_blocks: [4,6,6,8]
70 | num_blocks: [2,3,6]
71 | # num_refinement_blocks: 4
72 | num_refinement_blocks: 2
73 | # heads: [1,2,4,8]
74 | heads: [2,4,8]
75 | # ffn_expansion_factor: 2.66
76 | ffn_expansion_factor: 3
77 | bias: False
78 | LayerNorm_type: WithBias
79 | dual_pixel_task: False
80 | # type: Restormer
81 | # inp_channels: 3
82 | # out_channels: 3
83 | # # input_res: 128
84 | # dim: 48
85 | # num_blocks: [4,6,6,8]
86 | # # num_blocks: [1,3,6]
87 | # num_refinement_blocks: 4
88 | # # num_refinement_blocks: 2
89 | # heads: [1,2,4,8]
90 | # ffn_expansion_factor: 2.66
91 | # # ffn_expansion_factor: 3
92 | # bias: False
93 | # LayerNorm_type: WithBias
94 | # dual_pixel_task: False
95 |
96 |
97 | # path
98 | path:
99 | pretrain_network_g: ~
100 | strict_load_g: true
101 | resume_state: ~
102 |
103 | # training settings
104 | train:
105 | # total_iter: 300000
106 | total_iter: 300000
107 | warmup_iter: -1 # no warm up
108 | use_grad_clip: true
109 |
110 | # Split 300k iterations into two cycles.
111 | # 1st cycle: fixed 3e-4 LR for 92k iters.
112 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
113 | scheduler:
114 | type: CosineAnnealingRestartCyclicLR
115 | periods: [92000, 208000]
116 | # periods: [480000, 720000]
117 | restart_weights: [1,1]
118 | eta_mins: [0.0003,0.000001]
119 |
120 | mixing_augs:
121 | mixup: false
122 | mixup_beta: 1.2
123 | use_identity: true
124 |
125 | optim_g:
126 | type: AdamW
127 | lr: !!float 3e-4
128 | weight_decay: !!float 1e-4
129 | betas: [0.9, 0.999]
130 |
131 | # losses
132 | pixel_opt:
133 | type: L1Loss
134 | loss_weight: 1
135 | reduction: mean
136 |
137 | fft_loss_opt:
138 | type: FFTLoss
139 | loss_weight: 0.1
140 | reduction: mean
141 |
142 | # validation settings
143 | val:
144 | window_size: 8
145 | val_freq: !!float 4e3
146 | # val_freq: !!float 300e3
147 | save_img: true
148 | rgb2bgr: true
149 | use_image: true
150 | max_minibatch: 8
151 |
152 | metrics:
153 | psnr: # metric name, can be arbitrary
154 | type: calculate_psnr
155 | crop_border: 0
156 | test_y_channel: true
157 |
158 | # logging settings
159 | logger:
160 | print_freq: 1000
161 | save_checkpoint_freq: !!float 4e3
162 | use_tb_logger: true
163 | wandb:
164 | project: ~
165 | resume_id: ~
166 |
167 | # dist training settings
168 | dist_params:
169 | backend: nccl
170 | port: 29500
171 |
--------------------------------------------------------------------------------
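The commented 'Progressive training' block in this file pairs its lists element-wise: stage `k` runs for `iters[k]` iterations at patch size `gt_sizes[k]` with per-GPU batch `mini_batch_sizes[k]`; the active config instead trains a single 256x256 stage. A small sketch of how those commented values would line up:

```python
# Values copied from the commented progressive-training block above.
iters = [152000, 74000, 48000, 26000]      # iterations per stage
gt_sizes = [128, 160, 192, 256]            # patch size per stage
batches = [6, 4, 3, 1]                     # per-GPU batch size per stage

start = 0
for n, gt, bs in zip(iters, gt_sizes, batches):
    print(f"iters {start:>6}-{start + n:>6}: patch {gt}x{gt}, batch {bs}/GPU")
    start += n                             # stages sum to total_iter = 300000
```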
/Deraining/Options/RealDeraindrop_FPro.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: RealDeraindrop_FPro
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage_denseHaze
13 | dataroot_gt: /mnt/sda/dataset/raindrop/train/gt
14 | dataroot_lq: /mnt/sda/dataset/raindrop/train/data
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ## ------- Training on single fixed-patch size 256x256 ---------
27 | mini_batch_sizes: [2]
28 | iters: [300000]
29 | gt_size: 256
30 | gt_sizes: [256]
31 | ## ------------------------------------------------------------
32 |
33 | dataset_enlarge_ratio: 1
34 | prefetch_mode: ~
35 |
36 | val:
37 | name: ValSet
38 | type: Dataset_PairedImage_denseHaze
39 | dataroot_gt: /mnt/sda/dataset/raindrop/test_a/gt
40 | dataroot_lq: /mnt/sda/dataset/raindrop/test_a/data
41 | gt_size: 256
42 | io_backend:
43 | type: disk
44 |
45 | # network structures
46 |
47 | network_g:
48 | type: FPro
49 | inp_channels: 3
50 | out_channels: 3
51 | # input_res: 128
52 | dim: 48
53 | # num_blocks: [4,6,6,8]
54 | num_blocks: [2,3,6]
55 | # num_refinement_blocks: 4
56 | num_refinement_blocks: 2
57 | # heads: [1,2,4,8]
58 | heads: [2,4,8]
59 | # ffn_expansion_factor: 2.66
60 | ffn_expansion_factor: 3
61 | bias: False
62 | LayerNorm_type: WithBias
63 | dual_pixel_task: False
64 |
65 |
66 | # path
67 | path:
68 | pretrain_network_g: ~
69 | strict_load_g: true
70 | resume_state: ~
71 |
72 | # training settings
73 | train:
74 | total_iter: 300000
75 | warmup_iter: -1 # no warm up
76 | use_grad_clip: true
77 |
78 | # Split 300k iterations into two cycles.
79 | # 1st cycle: fixed 3e-4 LR for 92k iters.
80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
81 | scheduler:
82 | type: CosineAnnealingRestartCyclicLR
83 | periods: [92000, 208000]
84 | restart_weights: [1,1]
85 | eta_mins: [0.0003,0.000001]
86 |
87 | mixing_augs:
88 | mixup: true
89 | mixup_beta: 1.2
90 | use_identity: true
91 |
92 | optim_g:
93 | type: AdamW
94 | lr: !!float 3e-4
95 | weight_decay: !!float 1e-4
96 | betas: [0.9, 0.999]
97 |
98 | # losses
99 | pixel_opt:
100 | type: L1Loss
101 | loss_weight: 1
102 | reduction: mean
103 | fft_loss_opt:
104 | type: FFTLoss
105 | loss_weight: 0.1
106 | reduction: mean
107 |
108 | # validation settings
109 | val:
110 | window_size: 8
111 | val_freq: !!float 4e3
112 | save_img: false
113 | rgb2bgr: true
114 | use_image: false
115 | max_minibatch: 8
116 |
117 | metrics:
118 | psnr: # metric name, can be arbitrary
119 | type: calculate_psnr
120 | crop_border: 0
121 | test_y_channel: false
122 |
123 | # logging settings
124 | logger:
125 | print_freq: 1000
126 | save_checkpoint_freq: !!float 4e3
127 | use_tb_logger: true
128 | wandb:
129 | project: ~
130 | resume_id: ~
131 |
132 | # dist training settings
133 | dist_params:
134 | backend: nccl
135 | port: 29500
136 |
--------------------------------------------------------------------------------
/Deraining/evaluate_raindrop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | from glob import glob
5 | from tqdm import tqdm
6 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
7 | from skimage.metrics import structural_similarity as compare_ssim
8 |
9 |
10 | def calc_psnr(im1, im2):
11 | im1_y = cv2.cvtColor(im1, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
12 | im2_y = cv2.cvtColor(im2, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
13 | return compare_psnr(im1_y, im2_y)
14 |
15 |
16 | def calc_ssim(im1, im2):
17 | im1_y = cv2.cvtColor(im1, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
18 | im2_y = cv2.cvtColor(im2, cv2.COLOR_BGR2YCR_CB)[:, :, 0]
19 | return compare_ssim(im1_y, im2_y)
20 |
21 |
22 | def align_to_four(img):
23 | a_row = int(img.shape[0]/4)*4
24 | a_col = int(img.shape[1]/4)*4
25 | img = img[0:a_row, 0:a_col, :]
26 | return img
27 |
28 |
29 | def evaluate_raindrop(in_dir, gt_dir):
30 | inputs = sorted(glob(os.path.join(in_dir, '*.png')) + glob(os.path.join(in_dir, '*.jpg')))
31 | gts = sorted(glob(os.path.join(gt_dir, '*.png')) + glob(os.path.join(gt_dir, '*.jpg')))
32 | psnrs = []
33 | ssims = []
34 | for input, gt in tqdm(zip(inputs, gts)):
35 | inputdata = cv2.imread(input)
36 | gtdata = cv2.imread(gt)
37 | inputdata = align_to_four(inputdata)
38 | gtdata = align_to_four(gtdata)
39 | psnrs.append(calc_psnr(inputdata, gtdata))
40 | ssims.append(calc_ssim(inputdata, gtdata))
41 |
42 | ave_psnr = np.array(psnrs).mean()
43 | ave_ssim = np.array(ssims).mean()
44 | return ave_psnr, ave_ssim
45 |
46 |
47 | if __name__ == '__main__':
48 | ave_psnr, ave_ssim = evaluate_raindrop('/mnt/sda/zsh/FPro/Deraining/results/FPro_AGAN/test_a', '/mnt/sda/zsh/dataset/test_a/gt')
49 | print('PSNR: ', ave_psnr)
50 | print('SSIM: ', ave_ssim)
51 |
--------------------------------------------------------------------------------
/Deraining/test_AGAN.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 |
6 | import math
7 | import numpy as np
8 | import os
9 | import argparse
10 | from tqdm import tqdm
11 |
12 | import torch.nn as nn
13 | import torch
14 | import torch.nn.functional as F
15 | import utils
16 |
17 | from natsort import natsorted
18 | from glob import glob
19 | from basicsr.models.archs.FPro_arch import FPro
20 | from skimage import img_as_ubyte
21 | from pdb import set_trace as stx
22 |
23 | parser = argparse.ArgumentParser(description='Image Raindrop Removal using FPro')
24 |
25 | parser.add_argument('--input_dir', default='/mnt/sda/zsh/dataset/', type=str, help='Directory of validation images')
26 | parser.add_argument('--result_dir', default='./results/FPro_AGAN/', type=str, help='Directory for results')
27 | parser.add_argument('--weights', default='./models/deraindrop_FPro.pth', type=str, help='Path to weights')
28 |
29 | args = parser.parse_args()
30 |
31 | def splitimage(imgtensor, crop_size=128, overlap_size=64):
32 | _, C, H, W = imgtensor.shape
33 | hstarts = [x for x in range(0, H, crop_size - overlap_size)]
34 | while hstarts and hstarts[-1] + crop_size >= H:
35 | hstarts.pop()
36 | hstarts.append(H - crop_size)
37 | wstarts = [x for x in range(0, W, crop_size - overlap_size)]
38 | while wstarts and wstarts[-1] + crop_size >= W:
39 | wstarts.pop()
40 | wstarts.append(W - crop_size)
41 | starts = []
42 | split_data = []
43 | for hs in hstarts:
44 | for ws in wstarts:
45 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
46 | starts.append((hs, ws))
47 | split_data.append(cimgdata)
48 | return split_data, starts
49 |
50 | def get_scoremap(H, W, C, B=1, is_mean=True):
51 | center_h = H / 2
52 | center_w = W / 2
53 |
54 | score = torch.ones((B, C, H, W))
55 | if not is_mean:
56 | for h in range(H):
57 | for w in range(W):
58 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
59 | return score
60 |
61 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)):
62 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
63 | tot_score = torch.zeros((B, C, H, W))
64 | merge_img = torch.zeros((B, C, H, W))
65 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
66 | for simg, cstart in zip(split_data, starts):
67 | hs, ws = cstart
68 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
69 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
70 | merge_img = merge_img / tot_score
71 | return merge_img
72 |
73 | ####### Load yaml #######
74 | yaml_file = 'Options/RealDeraindrop_FPro.yml'
75 | import yaml
76 |
77 | try:
78 | from yaml import CLoader as Loader
79 | except ImportError:
80 | from yaml import Loader
81 |
82 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
83 |
84 | s = x['network_g'].pop('type')
85 | ##########################
86 |
87 | model_restoration = FPro(**x['network_g'])
88 |
89 | checkpoint = torch.load(args.weights)
90 | model_restoration.load_state_dict(checkpoint['params'])
91 | print("===>Testing using weights: ",args.weights)
92 | model_restoration.cuda()
93 | model_restoration = nn.DataParallel(model_restoration)
94 | model_restoration.eval()
95 |
96 |
97 | factor = 8
98 | datasets = ['test_a']
99 |
100 | for dataset in datasets:
101 | result_dir = os.path.join(args.result_dir, dataset)
102 | os.makedirs(result_dir, exist_ok=True)
103 |
104 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain')
105 | inp_dir = os.path.join(args.input_dir, dataset, 'data')
106 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
107 | with torch.no_grad():
108 | for file_ in tqdm(files):
109 | torch.cuda.ipc_collect()
110 | torch.cuda.empty_cache()
111 |
112 | img = np.float32(utils.load_img(file_))/255.
113 | img = torch.from_numpy(img).permute(2,0,1)
114 | input_ = img.unsqueeze(0).cuda()
115 |
116 | # Padding in case images are not multiples of 8
117 | h,w = input_.shape[2], input_.shape[3]
118 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
119 | padh = H-h if h%factor!=0 else 0
120 | padw = W-w if w%factor!=0 else 0
121 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
122 |
123 | B, C, H, W = input_.shape
124 | crop_size_arg = 256
125 | overlap_size_arg = 200
126 | split_data, starts = splitimage(input_, crop_size=crop_size_arg, overlap_size=overlap_size_arg)
127 | for i, data in enumerate(split_data):
128 | split_data[i] = model_restoration(data).cpu()
129 | restored = mergeimage(split_data, starts, crop_size=crop_size_arg, resolution=(B, C, H, W))
130 | # restored = model_restoration(input_)
131 |
132 | restored = restored[:,:,:h,:w]
133 |
134 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
135 |
136 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
137 |
--------------------------------------------------------------------------------
/Deraining/test_spad.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 |
6 | import math
7 | import numpy as np
8 | import os
9 | import argparse
10 | from tqdm import tqdm
11 |
12 | import torch.nn as nn
13 | import torch
14 | import torch.nn.functional as F
15 | import utils
16 |
17 | from natsort import natsorted
18 | from glob import glob
19 | from basicsr.models.archs.FPro_arch import FPro
20 | from skimage import img_as_ubyte
21 | from pdb import set_trace as stx
22 |
23 | parser = argparse.ArgumentParser(description='Image Deraining using FPro')
24 |
25 | parser.add_argument('--input_dir', default='/mnt/sda/zsh/derain/', type=str, help='Directory of validation images')
26 | parser.add_argument('--result_dir', default='./results/FPro/', type=str, help='Directory for results')
27 | parser.add_argument('--weights', default='./models/derain_spad.pth', type=str, help='Path to weights')
28 |
29 | args = parser.parse_args()
30 |
31 | def splitimage(imgtensor, crop_size=128, overlap_size=64):
32 | _, C, H, W = imgtensor.shape
33 | hstarts = [x for x in range(0, H, crop_size - overlap_size)]
34 | while hstarts and hstarts[-1] + crop_size >= H:
35 | hstarts.pop()
36 | hstarts.append(H - crop_size)
37 | wstarts = [x for x in range(0, W, crop_size - overlap_size)]
38 | while wstarts and wstarts[-1] + crop_size >= W:
39 | wstarts.pop()
40 | wstarts.append(W - crop_size)
41 | starts = []
42 | split_data = []
43 | for hs in hstarts:
44 | for ws in wstarts:
45 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
46 | starts.append((hs, ws))
47 | split_data.append(cimgdata)
48 | return split_data, starts
49 |
50 | def get_scoremap(H, W, C, B=1, is_mean=True):
51 | center_h = H / 2
52 | center_w = W / 2
53 |
54 | score = torch.ones((B, C, H, W))
55 | if not is_mean:
56 | for h in range(H):
57 | for w in range(W):
58 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
59 | return score
60 |
61 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)):
62 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
63 | tot_score = torch.zeros((B, C, H, W))
64 | merge_img = torch.zeros((B, C, H, W))
65 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
66 | for simg, cstart in zip(split_data, starts):
67 | hs, ws = cstart
68 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
69 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
70 | merge_img = merge_img / tot_score
71 | return merge_img
72 |
73 | ####### Load yaml #######
74 | yaml_file = 'Options/Deraining_FPro_spad.yml'
75 | import yaml
76 |
77 | try:
78 | from yaml import CLoader as Loader
79 | except ImportError:
80 | from yaml import Loader
81 |
82 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
83 |
84 | s = x['network_g'].pop('type')
85 | ##########################
86 |
87 | model_restoration = FPro(**x['network_g'])
88 |
89 | checkpoint = torch.load(args.weights)
90 | model_restoration.load_state_dict(checkpoint['params'])
91 | print("===>Testing using weights: ",args.weights)
92 | model_restoration.cuda()
93 | model_restoration = nn.DataParallel(model_restoration)
94 | model_restoration.eval()
95 |
96 |
97 | factor = 8
98 | datasets = ['real_test_1000']
99 |
100 | for dataset in datasets:
101 | result_dir = os.path.join(args.result_dir, dataset)
102 | os.makedirs(result_dir, exist_ok=True)
103 |
104 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain')
105 | inp_dir = os.path.join(args.input_dir, dataset, 'rain')
106 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
107 | with torch.no_grad():
108 | for file_ in tqdm(files):
109 | torch.cuda.ipc_collect()
110 | torch.cuda.empty_cache()
111 |
112 | img = np.float32(utils.load_img(file_))/255.
113 | img = torch.from_numpy(img).permute(2,0,1)
114 | input_ = img.unsqueeze(0).cuda()
115 |
116 | # Padding in case images are not multiples of 8
117 | h,w = input_.shape[2], input_.shape[3]
118 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
119 | padh = H-h if h%factor!=0 else 0
120 | padw = W-w if w%factor!=0 else 0
121 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
122 |
123 | B, C, H, W = input_.shape
124 | crop_size_arg = 256
125 | overlap_size_arg = 200
126 | split_data, starts = splitimage(input_, crop_size=crop_size_arg, overlap_size=overlap_size_arg)
127 | for i, data in enumerate(split_data):
128 | split_data[i] = model_restoration(data).cpu()
129 | restored = mergeimage(split_data, starts, crop_size=crop_size_arg, resolution=(B, C, H, W))
130 |
131 | restored = restored[:,:,:h,:w]
132 |
133 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
134 |
135 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
136 |
--------------------------------------------------------------------------------
/Deraining/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
10 | def calculate_psnr(img1, img2, border=0):
11 | # img1 and img2 have range [0, 255]
12 | #img1 = img1.squeeze()
13 | #img2 = img2.squeeze()
14 | if not img1.shape == img2.shape:
15 | raise ValueError('Input images must have the same dimensions.')
16 | h, w = img1.shape[:2]
17 | img1 = img1[border:h-border, border:w-border]
18 | img2 = img2[border:h-border, border:w-border]
19 |
20 | img1 = img1.astype(np.float64)
21 | img2 = img2.astype(np.float64)
22 | mse = np.mean((img1 - img2)**2)
23 | if mse == 0:
24 | return float('inf')
25 | return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
31 | def calculate_ssim(img1, img2, border=0):
32 | '''calculate SSIM
33 | the same outputs as MATLAB's
34 | img1, img2: [0, 255]
35 | '''
36 | #img1 = img1.squeeze()
37 | #img2 = img2.squeeze()
38 | if not img1.shape == img2.shape:
39 | raise ValueError('Input images must have the same dimensions.')
40 | h, w = img1.shape[:2]
41 | img1 = img1[border:h-border, border:w-border]
42 | img2 = img2[border:h-border, border:w-border]
43 |
44 | if img1.ndim == 2:
45 | return ssim(img1, img2)
46 | elif img1.ndim == 3:
47 | if img1.shape[2] == 3:
48 | ssims = []
49 | for i in range(3):
50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
51 | return np.array(ssims).mean()
52 | elif img1.shape[2] == 1:
53 | return ssim(np.squeeze(img1), np.squeeze(img2))
54 | else:
55 | raise ValueError('Wrong input image dimensions.')
56 |
57 |
58 | def ssim(img1, img2):
59 | C1 = (0.01 * 255)**2
60 | C2 = (0.03 * 255)**2
61 |
62 | img1 = img1.astype(np.float64)
63 | img2 = img2.astype(np.float64)
64 | kernel = cv2.getGaussianKernel(11, 1.5)
65 | window = np.outer(kernel, kernel.transpose())
66 |
67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
69 | mu1_sq = mu1**2
70 | mu2_sq = mu2**2
71 | mu1_mu2 = mu1 * mu2
72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 |
76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 | (sigma1_sq + sigma2_sq + C2))
78 | return ssim_map.mean()
79 |
80 | def load_img(filepath):
81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 |
83 | def save_img(filepath, img):
84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 |
86 | def load_gray_img(filepath):
87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 |
89 | def save_gray_img(filepath, img):
90 | cv2.imwrite(filepath, img)
91 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | This repository is built with PyTorch 1.8.1 and tested on Ubuntu 16.04 (Python 3.7, CUDA 10.2, cuDNN 7.6).
4 | Follow these instructions:
5 |
6 | 1. Clone our repository
7 | ```
8 | git clone https://github.com/swz30/Restormer.git
9 | cd Restormer
10 | ```
11 |
12 | 2. Make conda environment
13 | ```
14 | conda create -n pytorch181 python=3.7
15 | conda activate pytorch181
16 | ```
17 |
18 | 3. Install dependencies
19 | ```
20 | conda install pytorch=1.8 torchvision cudatoolkit=10.2 -c pytorch
21 | pip install matplotlib scikit-learn scikit-image opencv-python yacs joblib natsort h5py tqdm
22 | pip install einops gdown addict future lmdb numpy pyyaml requests scipy tb-nightly yapf lpips
23 | ```
24 |
25 | 4. Install basicsr
26 | ```
27 | python setup.py develop --no_cuda_ext
28 | ```
29 |
30 | ### Download datasets from Google Drive
31 |
32 | To download datasets automatically, you need `go` and `gdrive` installed.
33 |
34 | 1. You can install `go` with the following
35 | ```
36 | curl -O https://storage.googleapis.com/golang/go1.11.1.linux-amd64.tar.gz
37 | mkdir -p ~/installed
38 | tar -C ~/installed -xzf go1.11.1.linux-amd64.tar.gz
39 | mkdir -p ~/go
40 | ```
41 |
42 | 2. Add the following lines to `~/.bashrc`
43 | ```
44 | export GOPATH=$HOME/go
45 | export PATH=$PATH:$HOME/go/bin:$HOME/installed/go/bin
46 | ```
47 |
48 | 3. Install `gdrive` using
49 | ```
50 | go get github.com/prasmussen/gdrive
51 | ```
52 |
53 | 4. Close current terminal and open a new terminal.
54 |
--------------------------------------------------------------------------------
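Once `gdrive` is on your `PATH`, fetching a dataset typically looks like the sketch below; the file ID is a placeholder, the first run prints an authentication link, and exact flags can vary across `gdrive` versions.

```
gdrive list                  # list files visible to your account
gdrive download <file-id>    # download one file by its Drive ID
```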
/Motion_Deblurring/Options/Deblurring_FPro.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Deblurring_FPro
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage_denseHaze
13 | dataroot_gt: ./Motion_Deblurring/Datasets/train/GoPro/target_crops
14 | dataroot_lq: ./Motion_Deblurring/Datasets/train/GoPro/input_crops
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | # ### -------------Progressive training--------------------------
27 | # mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
28 | # iters: [92000,64000,48000,36000,36000,24000]
29 | # gt_size: 384 # Max patch size for progressive training
30 | # gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
31 | mini_batch_sizes: [2]
32 | iters: [600000]
33 | gt_size: 256
34 | gt_sizes: [256]
35 | ### ------------------------------------------------------------
36 |
37 | ### ------- Training on single fixed-patch size 128x128---------
38 | # mini_batch_sizes: [8]
39 | # iters: [300000]
40 | # gt_size: 128
41 | # gt_sizes: [128]
42 | ### ------------------------------------------------------------
43 |
44 | dataset_enlarge_ratio: 1
45 | prefetch_mode: ~
46 |
47 | val:
48 | name: ValSet
49 | type: Dataset_PairedImage_denseHaze
50 | dataroot_gt: ./Motion_Deblurring/Datasets/val/GoPro/target_crops
51 | dataroot_lq: ./Motion_Deblurring/Datasets/val/GoPro/input_crops
52 | gt_size: 256
53 | io_backend:
54 | type: disk
55 |
56 | # network structures
57 | network_g:
58 | type: Restormer
59 | inp_channels: 3
60 | out_channels: 3
61 | # input_res: 128
62 | dim: 48
63 | # num_blocks: [4,6,6,8]
64 | num_blocks: [2,3,6]
65 | # num_refinement_blocks: 4
66 | num_refinement_blocks: 2
67 | # heads: [1,2,4,8]
68 | heads: [2,4,8]
69 | # ffn_expansion_factor: 2.66
70 | ffn_expansion_factor: 3
71 | bias: False
72 | LayerNorm_type: WithBias
73 | dual_pixel_task: False
74 | # network_g:
75 | # type: Restormer
76 | # inp_channels: 3
77 | # out_channels: 3
78 | # dim: 48
79 | # num_blocks: [4,6,6,8]
80 | # num_refinement_blocks: 4
81 | # heads: [1,2,4,8]
82 | # ffn_expansion_factor: 2.66
83 | # bias: False
84 | # LayerNorm_type: WithBias
85 | # dual_pixel_task: False
86 |
87 |
88 | # path
89 | path:
90 | pretrain_network_g: ~
91 | strict_load_g: true
92 | resume_state: ~
93 |
94 | # training settings
95 | train:
96 | total_iter: 600000
97 | warmup_iter: -1 # no warm up
98 | use_grad_clip: true
99 |
100 | # Split 600k iterations into two cycles.
101 | # 1st cycle: fixed 3e-4 LR for 184k iters.
102 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 416k iters.
103 | scheduler:
104 | type: CosineAnnealingRestartCyclicLR
105 | periods: [184000, 416000]
106 | restart_weights: [1,1]
107 | eta_mins: [0.0003,0.000001]
108 |
109 | mixing_augs:
110 | mixup: false
111 | mixup_beta: 1.2
112 | use_identity: true
113 |
114 | optim_g:
115 | type: AdamW
116 | lr: !!float 3e-4
117 | weight_decay: !!float 1e-4
118 | betas: [0.9, 0.999]
119 |
120 | # losses
121 | pixel_opt:
122 | type: L1Loss
123 | loss_weight: 1
124 | reduction: mean
125 | fft_loss_opt:
126 | type: FFTLoss
127 | loss_weight: 0.1
128 | reduction: mean
129 |
130 | # validation settings
131 | val:
132 | window_size: 8
133 | val_freq: !!float 4e3
134 | save_img: false
135 | rgb2bgr: true
136 | use_image: true
137 | max_minibatch: 8
138 |
139 | metrics:
140 | psnr: # metric name, can be arbitrary
141 | type: calculate_psnr
142 | crop_border: 0
143 | test_y_channel: false
144 |
145 | # logging settings
146 | logger:
147 | print_freq: 1000
148 | save_checkpoint_freq: !!float 4e3
149 | use_tb_logger: true
150 | wandb:
151 | project: ~
152 | resume_id: ~
153 |
154 | # dist training settings
155 | dist_params:
156 | backend: nccl
157 | port: 29500
158 |
--------------------------------------------------------------------------------
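The scheduler block above encodes the two-cycle schedule described in its comment. A small sketch of the implied learning-rate curve, assuming the usual basicsr `CosineAnnealingRestartCyclicLR` semantics (within cycle i the LR anneals from `lr * restart_weights[i]` down to `eta_mins[i]` on a cosine):

```python
import math

def lr_at(it, base_lr=3e-4, periods=(184000, 416000),
          restart_weights=(1.0, 1.0), eta_mins=(3e-4, 1e-6)):
    start = 0
    for period, weight, eta_min in zip(periods, restart_weights, eta_mins):
        if it < start + period:
            t = (it - start) / period  # progress within the current cycle
            return eta_min + 0.5 * (base_lr * weight - eta_min) * (1 + math.cos(math.pi * t))
        start += period
    return eta_mins[-1]

# Cycle 1 stays flat at 3e-4 (its eta_min equals the base LR); cycle 2 decays to 1e-6.
for it in (0, 100_000, 184_000, 400_000, 599_999):
    print(it, f'{lr_at(it):.2e}')
```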
/Motion_Deblurring/evaluate_gopro_hide.m:
--------------------------------------------------------------------------------
1 | %% Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | %% Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | %% https://arxiv.org/abs/2111.09881
4 |
5 | close all;clear all;
6 |
7 | % datasets = {'GoPro'};
8 | datasets = {'GoPro', 'HIDE'};
9 | num_set = length(datasets);
10 |
11 | tic
12 | delete(gcp('nocreate'))
13 | parpool('local',20);
14 |
15 | for idx_set = 1:num_set
16 | file_path = strcat('./results/', datasets{idx_set}, '/');
17 | gt_path = strcat('./Datasets/test/', datasets{idx_set}, '/target/');
18 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))];
19 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))];
20 | img_num = length(path_list);
21 |
22 | total_psnr = 0;
23 | total_ssim = 0;
24 | if img_num > 0
25 | parfor j = 1:img_num
26 | image_name = path_list(j).name;
27 | gt_name = gt_list(j).name;
28 | input = imread(strcat(file_path,image_name));
29 | gt = imread(strcat(gt_path, gt_name));
30 | ssim_val = ssim(input, gt);
31 | psnr_val = psnr(input, gt);
32 | total_ssim = total_ssim + ssim_val;
33 | total_psnr = total_psnr + psnr_val;
34 | end
35 | end
36 | qm_psnr = total_psnr / img_num;
37 | qm_ssim = total_ssim / img_num;
38 |
39 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim);
40 |
41 | end
42 | delete(gcp('nocreate'))
43 | toc
44 |
--------------------------------------------------------------------------------
/Motion_Deblurring/generate_patches_gopro.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | ##### Data preparation file for training Restormer on the GoPro Dataset ########
6 |
7 | import cv2
8 | import numpy as np
9 | from glob import glob
10 | from natsort import natsorted
11 | import os
12 | from tqdm import tqdm
13 | from pdb import set_trace as stx
14 | from joblib import Parallel, delayed
15 | import multiprocessing
16 |
17 | def train_files(file_):
18 | lr_file, hr_file = file_
19 | filename = os.path.splitext(os.path.split(lr_file)[-1])[0]
20 | lr_img = cv2.imread(lr_file)
21 | hr_img = cv2.imread(hr_file)
22 | num_patch = 0
23 | w, h = lr_img.shape[:2]
24 | if w > p_max and h > p_max:
25 | w1 = list(np.arange(0, w-patch_size, patch_size-overlap, dtype=int))  # np.int was removed in NumPy >= 1.24
26 | h1 = list(np.arange(0, h-patch_size, patch_size-overlap, dtype=int))
27 | w1.append(w-patch_size)
28 | h1.append(h-patch_size)
29 | for i in w1:
30 | for j in h1:
31 | num_patch += 1
32 |
33 | lr_patch = lr_img[i:i+patch_size, j:j+patch_size,:]
34 | hr_patch = hr_img[i:i+patch_size, j:j+patch_size,:]
35 |
36 | lr_savename = os.path.join(lr_tar, filename + '-' + str(num_patch) + '.png')
37 | hr_savename = os.path.join(hr_tar, filename + '-' + str(num_patch) + '.png')
38 |
39 | cv2.imwrite(lr_savename, lr_patch)
40 | cv2.imwrite(hr_savename, hr_patch)
41 |
42 | else:
43 | lr_savename = os.path.join(lr_tar, filename + '.png')
44 | hr_savename = os.path.join(hr_tar, filename + '.png')
45 |
46 | cv2.imwrite(lr_savename, lr_img)
47 | cv2.imwrite(hr_savename, hr_img)
48 |
49 | def val_files(file_):
50 | lr_file, hr_file = file_
51 | filename = os.path.splitext(os.path.split(lr_file)[-1])[0]
52 | lr_img = cv2.imread(lr_file)
53 | hr_img = cv2.imread(hr_file)
54 |
55 | lr_savename = os.path.join(lr_tar, filename + '.png')
56 | hr_savename = os.path.join(hr_tar, filename + '.png')
57 |
58 | w, h = lr_img.shape[:2]
59 |
60 | i = (w-val_patch_size)//2
61 | j = (h-val_patch_size)//2
62 |
63 | lr_patch = lr_img[i:i+val_patch_size, j:j+val_patch_size,:]
64 | hr_patch = hr_img[i:i+val_patch_size, j:j+val_patch_size,:]
65 |
66 | cv2.imwrite(lr_savename, lr_patch)
67 | cv2.imwrite(hr_savename, hr_patch)
68 |
69 | ############ Prepare Training data ####################
70 | num_cores = 10
71 | patch_size = 512
72 | overlap = 256
73 | p_max = 0
74 |
75 | src = '/home/ubuntu/test/datasets/deblurring/GoPro/train'
76 | tar = 'Datasets/train/GoPro'
77 |
78 | lr_tar = os.path.join(tar, 'input_crops')
79 | hr_tar = os.path.join(tar, 'target_crops')
80 |
81 | os.makedirs(lr_tar, exist_ok=True)
82 | os.makedirs(hr_tar, exist_ok=True)
83 |
84 | lr_files = natsorted(glob(os.path.join(src, 'input', '*.png')) + glob(os.path.join(src, 'input', '*.jpg')))
85 | hr_files = natsorted(glob(os.path.join(src, 'groundtruth', '*.png')) + glob(os.path.join(src, 'groundtruth', '*.jpg')))
86 |
87 | files = [(i, j) for i, j in zip(lr_files, hr_files)]
88 |
89 | Parallel(n_jobs=num_cores)(delayed(train_files)(file_) for file_ in tqdm(files))
90 |
91 |
92 | ############ Prepare validation data ####################
93 | val_patch_size = 256
94 | src = '/home/ubuntu/test/datasets/deblurring/GoPro/test'
95 | tar = 'Datasets/val/GoPro'
96 |
97 | lr_tar = os.path.join(tar, 'input_crops')
98 | hr_tar = os.path.join(tar, 'target_crops')
99 |
100 | os.makedirs(lr_tar, exist_ok=True)
101 | os.makedirs(hr_tar, exist_ok=True)
102 |
103 | lr_files = natsorted(glob(os.path.join(src, 'input', '*.png')) + glob(os.path.join(src, 'input', '*.jpg')))
104 | hr_files = natsorted(glob(os.path.join(src, 'groundtruth', '*.png')) + glob(os.path.join(src, 'groundtruth', '*.jpg')))
105 |
106 | files = [(i, j) for i, j in zip(lr_files, hr_files)]
107 |
108 | Parallel(n_jobs=num_cores)(delayed(val_files)(file_) for file_ in tqdm(files))
109 |
--------------------------------------------------------------------------------
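A worked example of the crop-grid arithmetic in `train_files` above: starts advance by `patch_size - overlap` pixels, with one extra start flush against the border so the last patch always fits (the width value here is illustrative):

```python
import numpy as np

w, patch_size, overlap = 1280, 512, 256
w1 = list(np.arange(0, w - patch_size, patch_size - overlap, dtype=int))
w1.append(w - patch_size)
print(w1)  # [0, 256, 512, 768]: the final 512-px patch ends exactly at 1280
```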
/Motion_Deblurring/test_FPro.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import math  # used by get_scoremap() when is_mean=False
6 | import numpy as np
7 | import os
8 | import argparse
9 | from tqdm import tqdm
10 |
11 | import torch.nn as nn
12 | import torch
13 | import torch.nn.functional as F
14 | import utils
15 |
16 | from natsort import natsorted
17 | from glob import glob
18 | from basicsr.models.archs.FPro_arch import FPro
19 | from skimage import img_as_ubyte
20 | from pdb import set_trace as stx
21 |
22 | parser = argparse.ArgumentParser(description='Single Image Motion Deblurring using FPro')
23 |
24 | parser.add_argument('--input_dir', default='/home/ubuntu13/zsh/dataset/Uformer/deblurring/', type=str, help='Directory of validation images')
25 | parser.add_argument('--result_dir', default='./results/FPro/', type=str, help='Directory for results')
26 | parser.add_argument('--weights', default='./models/deblur.pth', type=str, help='Path to weights')
27 | parser.add_argument('--dataset', default='GoPro', type=str, help='Test Dataset') # ['GoPro', 'hide', 'RealBlur_J', 'RealBlur_R']
28 |
29 | args = parser.parse_args()
30 |
31 | def splitimage(imgtensor, crop_size=128, overlap_size=64):
32 | _, C, H, W = imgtensor.shape
33 | hstarts = [x for x in range(0, H, crop_size - overlap_size)]
34 | while hstarts and hstarts[-1] + crop_size >= H:
35 | hstarts.pop()
36 | hstarts.append(H - crop_size)
37 | wstarts = [x for x in range(0, W, crop_size - overlap_size)]
38 | while wstarts and wstarts[-1] + crop_size >= W:
39 | wstarts.pop()
40 | wstarts.append(W - crop_size)
41 | starts = []
42 | split_data = []
43 | for hs in hstarts:
44 | for ws in wstarts:
45 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
46 | starts.append((hs, ws))
47 | split_data.append(cimgdata)
48 | return split_data, starts
49 |
50 | def get_scoremap(H, W, C, B=1, is_mean=True):
51 | center_h = H / 2
52 | center_w = W / 2
53 |
54 | score = torch.ones((B, C, H, W))
55 | if not is_mean:
56 | for h in range(H):
57 | for w in range(W):
58 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
59 | return score
60 |
61 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)):
62 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
63 | tot_score = torch.zeros((B, C, H, W))
64 | merge_img = torch.zeros((B, C, H, W))
65 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
66 | for simg, cstart in zip(split_data, starts):
67 | hs, ws = cstart
68 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
69 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
70 | merge_img = merge_img / tot_score
71 | return merge_img
72 |
73 | ####### Load yaml #######
74 | yaml_file = 'Options/Deblurring_FPro.yml'
75 | import yaml
76 |
77 | try:
78 | from yaml import CLoader as Loader
79 | except ImportError:
80 | from yaml import Loader
81 |
82 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
83 |
84 | s = x['network_g'].pop('type')
85 | ##########################
86 |
87 | model_restoration = FPro(**x['network_g'])
88 |
89 | checkpoint = torch.load(args.weights)
90 | model_restoration.load_state_dict(checkpoint['params'])
91 | print("===>Testing using weights: ",args.weights)
92 | model_restoration.cuda()
93 | model_restoration = nn.DataParallel(model_restoration)
94 | model_restoration.eval()
95 |
96 |
97 | factor = 8
98 | dataset = args.dataset
99 | result_dir = os.path.join(args.result_dir, dataset)
100 | os.makedirs(result_dir, exist_ok=True)
101 |
102 | inp_dir = os.path.join(args.input_dir, dataset,'test', 'blur')
103 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
104 | with torch.no_grad():
105 | for file_ in tqdm(files):
106 | torch.cuda.ipc_collect()
107 | torch.cuda.empty_cache()
108 |
109 | img = np.float32(utils.load_img(file_))/255.
110 | img = torch.from_numpy(img).permute(2,0,1)
111 | input_ = img.unsqueeze(0).cuda()
112 |
113 | B, C, H, W = input_.shape
114 | crop_size_arg = 256
115 | overlap_size_arg = 200
116 | split_data, starts = splitimage(input_, crop_size=crop_size_arg, overlap_size=overlap_size_arg)
117 | for i, data in enumerate(split_data):
118 | split_data[i] = model_restoration(data).cpu()
119 | restored = mergeimage(split_data, starts, crop_size=crop_size_arg, resolution=(B, C, H, W))
120 |
121 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
122 |
123 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
124 |
--------------------------------------------------------------------------------
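A quick self-check of the tiling helpers in `test_FPro.py` above (a sketch assuming `splitimage`, `get_scoremap`, and `mergeimage` are in scope): with the uniform score map, splitting an image into overlapping crops and merging them back reproduces the input exactly when the "model" is the identity.

```python
import torch

x = torch.rand(1, 3, 300, 417)  # arbitrary non-multiple-of-crop size
crops, starts = splitimage(x, crop_size=128, overlap_size=64)
merged = mergeimage(crops, starts, crop_size=128, resolution=x.shape)
print(torch.allclose(merged, x, atol=1e-6))  # True
```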
/Motion_Deblurring/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
10 | def calculate_psnr(img1, img2, border=0):
11 | # img1 and img2 have range [0, 255]
12 | #img1 = img1.squeeze()
13 | #img2 = img2.squeeze()
14 | if not img1.shape == img2.shape:
15 | raise ValueError('Input images must have the same dimensions.')
16 | h, w = img1.shape[:2]
17 | img1 = img1[border:h-border, border:w-border]
18 | img2 = img2[border:h-border, border:w-border]
19 |
20 | img1 = img1.astype(np.float64)
21 | img2 = img2.astype(np.float64)
22 | mse = np.mean((img1 - img2)**2)
23 | if mse == 0:
24 | return float('inf')
25 | return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
31 | def calculate_ssim(img1, img2, border=0):
32 | '''calculate SSIM
33 | the same outputs as MATLAB's
34 | img1, img2: [0, 255]
35 | '''
36 | #img1 = img1.squeeze()
37 | #img2 = img2.squeeze()
38 | if not img1.shape == img2.shape:
39 | raise ValueError('Input images must have the same dimensions.')
40 | h, w = img1.shape[:2]
41 | img1 = img1[border:h-border, border:w-border]
42 | img2 = img2[border:h-border, border:w-border]
43 |
44 | if img1.ndim == 2:
45 | return ssim(img1, img2)
46 | elif img1.ndim == 3:
47 | if img1.shape[2] == 3:
48 | ssims = []
49 | for i in range(3):
50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
51 | return np.array(ssims).mean()
52 | elif img1.shape[2] == 1:
53 | return ssim(np.squeeze(img1), np.squeeze(img2))
54 | else:
55 | raise ValueError('Wrong input image dimensions.')
56 |
57 |
58 | def ssim(img1, img2):
59 | C1 = (0.01 * 255)**2
60 | C2 = (0.03 * 255)**2
61 |
62 | img1 = img1.astype(np.float64)
63 | img2 = img2.astype(np.float64)
64 | kernel = cv2.getGaussianKernel(11, 1.5)
65 | window = np.outer(kernel, kernel.transpose())
66 |
67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
69 | mu1_sq = mu1**2
70 | mu2_sq = mu2**2
71 | mu1_mu2 = mu1 * mu2
72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 |
76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 | (sigma1_sq + sigma2_sq + C2))
78 | return ssim_map.mean()
79 |
80 | def load_img(filepath):
81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 |
83 | def save_img(filepath, img):
84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 |
86 | def load_gray_img(filepath):
87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 |
89 | def save_gray_img(filepath, img):
90 | cv2.imwrite(filepath, img)
91 |
--------------------------------------------------------------------------------
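For reference, the `ssim()` helper above (identical in each task's `utils.py`) computes the standard single-scale SSIM with an 11x11 Gaussian window (sigma = 1.5), C1 = (0.01*255)^2 and C2 = (0.03*255)^2, averaged over the valid (un-padded) region:

```latex
\mathrm{SSIM}(x, y) =
  \frac{(2\mu_x\mu_y + C_1)\,(2\sigma_{xy} + C_2)}
       {(\mu_x^2 + \mu_y^2 + C_1)\,(\sigma_x^2 + \sigma_y^2 + C_2)}
```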
/README.md:
--------------------------------------------------------------------------------
1 | # Seeing the Unseen: A Frequency Prompt Guided Transformer for Image Restoration (ECCV 2024)
2 |
3 | [Shihao Zhou](https://joshyzhou.github.io/), [Jinshan Pan](https://jspan.github.io/), [Jinglei Shi](https://jingleishi.github.io/), [Duosheng Chen](https://github.com/Calvin11311), [Lishen Qu](https://github.com/qulishen) and [Jufeng Yang](https://cv.nankai.edu.cn/)
4 |
5 | #### News
6 | - **Jul 02, 2024:** FPro has been accepted to ECCV 2024 :tada:
7 |
8 |
9 |
10 | ## Training
11 | ### Derain
12 | To train FPro on SPAD, you can run:
13 | ```sh
14 | ./train.sh Deraining/Options/Deraining_FPro_spad.yml
15 | ```
16 | ### Dehaze
17 | To train FPro on SOTS, you can run:
18 | ```sh
19 | ./train.sh Dehaze/Options/RealDehazing_FPro.yml
20 | ```
21 | ### Deblur
22 | To train FPro on GoPro, you can run:
23 | ```sh
24 | ./train.sh Motion_Deblurring/Options/Deblurring_FPro.yml
25 | ```
26 | ### Deraindrop
27 | To train FPro on AGAN, you can run:
28 | ```sh
29 | ./train.sh Deraining/Options/RealDeraindrop_FPro.yml
30 | ```
31 | ### Demoire
32 | To train FPro on TIP18, you can run:
33 | ```sh
34 | ./train.sh Demoiring/Options/RealDemoiring_FPro.yml
35 | ```
36 |
37 | ## Evaluation
38 | To evaluate FPro, refer to the commands in `test.sh` (an example invocation is shown below).
39 |
40 | To evaluate on a given dataset, uncomment the corresponding line.
41 |
42 |
43 | ## Results
44 | Experiments are performed on several image restoration tasks, including rain streak removal, raindrop removal, haze removal, motion blur removal, and moire pattern removal.
45 | Here is a summary table containing hyperlinks for easy navigation:
46 |
79 |
80 |
81 | ## Citation
82 | If you find this project useful, please consider citing:
83 |
84 | @inproceedings{zhou_ECCV2024_FPro,
85 | title={Seeing the Unseen: A Frequency Prompt Guided Transformer for Image Restoration},
86 | author={Zhou, Shihao and Pan, Jinshan and Shi, Jinglei and Chen, Duosheng and Qu, Lishen and Yang, Jufeng},
87 | booktitle={ECCV},
88 | year={2024}
89 | }
90 |
91 | ## Acknowledgement
92 |
93 | This code borrows heavily from [Restormer](https://github.com/swz30/Restormer).
--------------------------------------------------------------------------------
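As a concrete example of the evaluation workflow referenced in the README above, a typical `test.sh` entry for GoPro deblurring looks like the following (only `--dataset` varies per dataset; the weight path shown is the default in `Motion_Deblurring/test_FPro.py` and is assumed here):

```sh
python Motion_Deblurring/test_FPro.py --dataset GoPro --weights ./models/deblur.pth
```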
/basicsr/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/.DS_Store
--------------------------------------------------------------------------------
/basicsr/__pycache__/version.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/__pycache__/version.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/.DS_Store
--------------------------------------------------------------------------------
/basicsr/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.utils.data
6 | from functools import partial
7 | from os import path as osp
8 |
9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
10 | from basicsr.utils import get_root_logger, scandir
11 | from basicsr.utils.dist_util import get_dist_info
12 |
13 | __all__ = ['create_dataset', 'create_dataloader']
14 |
15 | # automatically scan and import dataset modules
16 | # scan all the files under the data folder with '_dataset' in file names
17 | data_folder = osp.dirname(osp.abspath(__file__))
18 | dataset_filenames = [
19 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
20 | if v.endswith('_dataset.py')
21 | ]
22 | # import all the dataset modules
23 | _dataset_modules = [
24 | importlib.import_module(f'basicsr.data.{file_name}')
25 | for file_name in dataset_filenames
26 | ]
27 |
28 |
29 | def create_dataset(dataset_opt):
30 | """Create dataset.
31 |
32 | Args:
33 | dataset_opt (dict): Configuration for dataset. It contains:
34 | name (str): Dataset name.
35 | type (str): Dataset type.
36 | """
37 | dataset_type = dataset_opt['type']
38 |
39 | # dynamic instantiation
40 | for module in _dataset_modules:
41 | dataset_cls = getattr(module, dataset_type, None)
42 | if dataset_cls is not None:
43 | break
44 | if dataset_cls is None:
45 | raise ValueError(f'Dataset {dataset_type} is not found.')
46 |
47 | dataset = dataset_cls(dataset_opt)
48 |
49 | logger = get_root_logger()
50 | logger.info(
51 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
52 | 'is created.')
53 | return dataset
54 |
55 |
56 | def create_dataloader(dataset,
57 | dataset_opt,
58 | num_gpu=1,
59 | dist=False,
60 | sampler=None,
61 | seed=None):
62 | """Create dataloader.
63 |
64 | Args:
65 | dataset (torch.utils.data.Dataset): Dataset.
66 | dataset_opt (dict): Dataset options. It contains the following keys:
67 | phase (str): 'train' or 'val'.
68 | num_worker_per_gpu (int): Number of workers for each GPU.
69 | batch_size_per_gpu (int): Training batch size for each GPU.
70 | num_gpu (int): Number of GPUs. Used only in the train phase.
71 | Default: 1.
72 | dist (bool): Whether in distributed training. Used only in the train
73 | phase. Default: False.
74 | sampler (torch.utils.data.sampler): Data sampler. Default: None.
75 | seed (int | None): Seed. Default: None
76 | """
77 | phase = dataset_opt['phase']
78 | rank, _ = get_dist_info()
79 | if phase == 'train':
80 | if dist: # distributed training
81 | batch_size = dataset_opt['batch_size_per_gpu']
82 | num_workers = dataset_opt['num_worker_per_gpu']
83 | else: # non-distributed training
84 | multiplier = 1 if num_gpu == 0 else num_gpu
85 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
86 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
87 | dataloader_args = dict(
88 | dataset=dataset,
89 | batch_size=batch_size,
90 | shuffle=False,
91 | num_workers=num_workers,
92 | sampler=sampler,
93 | drop_last=True)
94 | if sampler is None:
95 | dataloader_args['shuffle'] = True
96 | dataloader_args['worker_init_fn'] = partial(
97 | worker_init_fn, num_workers=num_workers, rank=rank,
98 | seed=seed) if seed is not None else None
99 | elif phase in ['val', 'test']: # validation
100 | dataloader_args = dict(
101 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
102 | else:
103 | raise ValueError(f'Wrong dataset phase: {phase}. '
104 | "Supported ones are 'train', 'val' and 'test'.")
105 |
106 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
107 |
108 | prefetch_mode = dataset_opt.get('prefetch_mode')
109 | if prefetch_mode == 'cpu': # CPUPrefetcher
110 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
111 | logger = get_root_logger()
112 | logger.info(f'Use {prefetch_mode} prefetch dataloader: '
113 | f'num_prefetch_queue = {num_prefetch_queue}')
114 | return PrefetchDataLoader(
115 | num_prefetch_queue=num_prefetch_queue, **dataloader_args)
116 | else:
117 | # prefetch_mode=None: Normal dataloader
118 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher
119 | return torch.utils.data.DataLoader(**dataloader_args)
120 |
121 |
122 | def worker_init_fn(worker_id, num_workers, rank, seed):
123 | # Set the worker seed to num_workers * rank + worker_id + seed
124 | worker_seed = num_workers * rank + worker_id + seed
125 | np.random.seed(worker_seed)
126 | random.seed(worker_seed)
127 |
--------------------------------------------------------------------------------
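A minimal sketch of `create_dataloader` above on a toy dataset (in the val phase only `'phase'` is required in `dataset_opt`; the train phase additionally reads the batch/worker counts documented in the docstring):

```python
import torch
from torch.utils.data import TensorDataset
from basicsr.data import create_dataloader

toy = TensorDataset(torch.rand(4, 3, 8, 8))
loader = create_dataloader(toy, {'phase': 'val'}, num_gpu=1, dist=False)
for batch in loader:
    print(batch[0].shape)  # torch.Size([1, 3, 8, 8]) -- val batch size is 1
```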
/basicsr/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/data_sampler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/data_sampler.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/data_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/data_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/reds_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/reds_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/transforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/transforms.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.utils.data.sampler import Sampler
4 |
5 |
6 | class EnlargedSampler(Sampler):
7 | """Sampler that restricts data loading to a subset of the dataset.
8 |
9 | Modified from torch.utils.data.distributed.DistributedSampler
10 | Supports enlarging the dataset for iteration-based training, which saves
11 | time when restarting the dataloader after each epoch.
12 |
13 | Args:
14 | dataset (torch.utils.data.Dataset): Dataset used for sampling.
15 | num_replicas (int | None): Number of processes participating in
16 | the training. It is usually the world_size.
17 | rank (int | None): Rank of the current process within num_replicas.
18 | ratio (int): Enlarging ratio. Default: 1.
19 | """
20 |
21 | def __init__(self, dataset, num_replicas, rank, ratio=1):
22 | self.dataset = dataset
23 | self.num_replicas = num_replicas
24 | self.rank = rank
25 | self.epoch = 0
26 | self.num_samples = math.ceil(
27 | len(self.dataset) * ratio / self.num_replicas)
28 | self.total_size = self.num_samples * self.num_replicas
29 |
30 | def __iter__(self):
31 | # deterministically shuffle based on epoch
32 | g = torch.Generator()
33 | g.manual_seed(self.epoch)
34 | indices = torch.randperm(self.total_size, generator=g).tolist()
35 |
36 | dataset_size = len(self.dataset)
37 | indices = [v % dataset_size for v in indices]
38 |
39 | # subsample
40 | indices = indices[self.rank:self.total_size:self.num_replicas]
41 | assert len(indices) == self.num_samples
42 |
43 | return iter(indices)
44 |
45 | def __len__(self):
46 | return self.num_samples
47 |
48 | def set_epoch(self, epoch):
49 | self.epoch = epoch
50 |
--------------------------------------------------------------------------------
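A small sketch of the enlarging behaviour documented above: with `ratio=2` and two replicas, each rank draws `ceil(10 * 2 / 2) = 10` indices per epoch, with indices wrapped modulo the dataset size.

```python
import torch
from torch.utils.data import TensorDataset
from basicsr.data.data_sampler import EnlargedSampler

ds = TensorDataset(torch.arange(10))
sampler = EnlargedSampler(ds, num_replicas=2, rank=0, ratio=2)
sampler.set_epoch(0)
print(len(sampler))        # 10
print(list(sampler)[:5])   # deterministic for a fixed epoch
```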
/basicsr/data/ffhq_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.transforms import augment
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor
7 |
8 |
9 | class FFHQDataset(data.Dataset):
10 | """FFHQ dataset for StyleGAN.
11 |
12 | Args:
13 | opt (dict): Config for train datasets. It contains the following keys:
14 | dataroot_gt (str): Data root path for gt.
15 | io_backend (dict): IO backend type and other kwarg.
16 | mean (list | tuple): Image mean.
17 | std (list | tuple): Image std.
18 | use_hflip (bool): Whether to horizontally flip.
19 |
20 | """
21 |
22 | def __init__(self, opt):
23 | super(FFHQDataset, self).__init__()
24 | self.opt = opt
25 | # file client (io backend)
26 | self.file_client = None
27 | self.io_backend_opt = opt['io_backend']
28 |
29 | self.gt_folder = opt['dataroot_gt']
30 | self.mean = opt['mean']
31 | self.std = opt['std']
32 |
33 | if self.io_backend_opt['type'] == 'lmdb':
34 | self.io_backend_opt['db_paths'] = self.gt_folder
35 | if not self.gt_folder.endswith('.lmdb'):
36 | raise ValueError("'dataroot_gt' should end with '.lmdb', "
37 | f'but received {self.gt_folder}')
38 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
39 | self.paths = [line.split('.')[0] for line in fin]
40 | else:
41 | # FFHQ has 70000 images in total
42 | self.paths = [
43 | osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)
44 | ]
45 |
46 | def __getitem__(self, index):
47 | if self.file_client is None:
48 | self.file_client = FileClient(
49 | self.io_backend_opt.pop('type'), **self.io_backend_opt)
50 |
51 | # load gt image
52 | gt_path = self.paths[index]
53 | img_bytes = self.file_client.get(gt_path)
54 | img_gt = imfrombytes(img_bytes, float32=True)
55 |
56 | # random horizontal flip
57 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
58 | # BGR to RGB, HWC to CHW, numpy to tensor
59 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
60 | # normalize
61 | normalize(img_gt, self.mean, self.std, inplace=True)
62 | return {'gt': img_gt, 'gt_path': gt_path}
63 |
64 | def __len__(self):
65 | return len(self.paths)
66 |
--------------------------------------------------------------------------------
/basicsr/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
1 | import queue as Queue
2 | import threading
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 |
7 | class PrefetchGenerator(threading.Thread):
8 | """A general prefetch generator.
9 |
10 | Ref:
11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12 |
13 | Args:
14 | generator: Python generator.
15 | num_prefetch_queue (int): Number of prefetch queue.
16 | """
17 |
18 | def __init__(self, generator, num_prefetch_queue):
19 | threading.Thread.__init__(self)
20 | self.queue = Queue.Queue(num_prefetch_queue)
21 | self.generator = generator
22 | self.daemon = True
23 | self.start()
24 |
25 | def run(self):
26 | for item in self.generator:
27 | self.queue.put(item)
28 | self.queue.put(None)
29 |
30 | def __next__(self):
31 | next_item = self.queue.get()
32 | if next_item is None:
33 | raise StopIteration
34 | return next_item
35 |
36 | def __iter__(self):
37 | return self
38 |
39 |
40 | class PrefetchDataLoader(DataLoader):
41 | """Prefetch version of dataloader.
42 |
43 | Ref:
44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45 |
46 | TODO:
47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48 | ddp.
49 |
50 | Args:
51 | num_prefetch_queue (int): Number of prefetch queue.
52 | kwargs (dict): Other arguments for dataloader.
53 | """
54 |
55 | def __init__(self, num_prefetch_queue, **kwargs):
56 | self.num_prefetch_queue = num_prefetch_queue
57 | super(PrefetchDataLoader, self).__init__(**kwargs)
58 |
59 | def __iter__(self):
60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61 |
62 |
63 | class CPUPrefetcher():
64 | """CPU prefetcher.
65 |
66 | Args:
67 | loader: Dataloader.
68 | """
69 |
70 | def __init__(self, loader):
71 | self.ori_loader = loader
72 | self.loader = iter(loader)
73 |
74 | def next(self):
75 | try:
76 | return next(self.loader)
77 | except StopIteration:
78 | return None
79 |
80 | def reset(self):
81 | self.loader = iter(self.ori_loader)
82 |
83 |
84 | class CUDAPrefetcher():
85 | """CUDA prefetcher.
86 |
87 | Ref:
88 | https://github.com/NVIDIA/apex/issues/304#
89 |
90 | It may consume more GPU memory.
91 |
92 | Args:
93 | loader: Dataloader.
94 | opt (dict): Options.
95 | """
96 |
97 | def __init__(self, loader, opt):
98 | self.ori_loader = loader
99 | self.loader = iter(loader)
100 | self.opt = opt
101 | self.stream = torch.cuda.Stream()
102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103 | self.preload()
104 |
105 | def preload(self):
106 | try:
107 | self.batch = next(self.loader) # self.batch is a dict
108 | except StopIteration:
109 | self.batch = None
110 | return None
111 | # put tensors to gpu
112 | with torch.cuda.stream(self.stream):
113 | for k, v in self.batch.items():
114 | if torch.is_tensor(v):
115 | self.batch[k] = self.batch[k].to(
116 | device=self.device, non_blocking=True)
117 |
118 | def next(self):
119 | torch.cuda.current_stream().wait_stream(self.stream)
120 | batch = self.batch
121 | self.preload()
122 | return batch
123 |
124 | def reset(self):
125 | self.loader = iter(self.ori_loader)
126 | self.preload()
127 |
--------------------------------------------------------------------------------
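A minimal sketch of the `CPUPrefetcher` contract above: `next()` yields batches until the underlying loader is exhausted, then returns `None`; `reset()` rewinds for the next epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from basicsr.data.prefetch_dataloader import CPUPrefetcher

loader = DataLoader(TensorDataset(torch.arange(4)), batch_size=2)
prefetcher = CPUPrefetcher(loader)

batch = prefetcher.next()
while batch is not None:
    print(batch)            # [tensor([0, 1])], then [tensor([2, 3])]
    batch = prefetcher.next()
prefetcher.reset()          # ready for another pass
```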
/basicsr/data/single_image_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.data_util import paths_from_lmdb
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
7 |
8 |
9 | class SingleImageDataset(data.Dataset):
10 | """Read only lq images in the test phase.
11 |
12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
13 |
14 | There are two modes:
15 | 1. 'meta_info_file': Use meta information file to generate paths.
16 | 2. 'folder': Scan folders to generate paths.
17 |
18 | Args:
19 | opt (dict): Config for train datasets. It contains the following keys:
20 | dataroot_lq (str): Data root path for lq.
21 | meta_info_file (str): Path for meta information file.
22 | io_backend (dict): IO backend type and other kwarg.
23 | """
24 |
25 | def __init__(self, opt):
26 | super(SingleImageDataset, self).__init__()
27 | self.opt = opt
28 | # file client (io backend)
29 | self.file_client = None
30 | self.io_backend_opt = opt['io_backend']
31 | self.mean = opt['mean'] if 'mean' in opt else None
32 | self.std = opt['std'] if 'std' in opt else None
33 | self.lq_folder = opt['dataroot_lq']
34 |
35 | if self.io_backend_opt['type'] == 'lmdb':
36 | self.io_backend_opt['db_paths'] = [self.lq_folder]
37 | self.io_backend_opt['client_keys'] = ['lq']
38 | self.paths = paths_from_lmdb(self.lq_folder)
39 | elif 'meta_info_file' in self.opt:
40 | with open(self.opt['meta_info_file'], 'r') as fin:
41 | self.paths = [
42 | osp.join(self.lq_folder,
43 | line.split(' ')[0]) for line in fin
44 | ]
45 | else:
46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
47 |
48 | def __getitem__(self, index):
49 | if self.file_client is None:
50 | self.file_client = FileClient(
51 | self.io_backend_opt.pop('type'), **self.io_backend_opt)
52 |
53 | # load lq image
54 | lq_path = self.paths[index]
55 | img_bytes = self.file_client.get(lq_path, 'lq')
56 | img_lq = imfrombytes(img_bytes, float32=True)
57 |
58 | # TODO: color space transform
59 | # BGR to RGB, HWC to CHW, numpy to tensor
60 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
61 | # normalize
62 | if self.mean is not None or self.std is not None:
63 | normalize(img_lq, self.mean, self.std, inplace=True)
64 | return {'lq': img_lq, 'lq_path': lq_path}
65 |
66 | def __len__(self):
67 | return len(self.paths)
68 |
--------------------------------------------------------------------------------
/basicsr/data/vimeo90k_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from pathlib import Path
4 | from torch.utils import data as data
5 |
6 | from basicsr.data.transforms import augment, paired_random_crop
7 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
8 |
9 |
10 | class Vimeo90KDataset(data.Dataset):
11 | """Vimeo90K dataset for training.
12 |
13 | The keys are generated from a meta info txt file.
14 | basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
15 |
16 | Each line contains:
17 | 1. clip name; 2. frame number; 3. image shape, separated by a white space.
18 | Examples:
19 | 00001/0001 7 (256,448,3)
20 | 00001/0002 7 (256,448,3)
21 |
22 | Key examples: "00001/0001"
23 | GT (gt): Ground-Truth;
24 | LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
25 |
26 | The neighboring frame list for different num_frame:
27 | num_frame | frame list
28 | 1 | 4
29 | 3 | 3,4,5
30 | 5 | 2,3,4,5,6
31 | 7 | 1,2,3,4,5,6,7
32 |
33 | Args:
34 | opt (dict): Config for train dataset. It contains the following keys:
35 | dataroot_gt (str): Data root path for gt.
36 | dataroot_lq (str): Data root path for lq.
37 | meta_info_file (str): Path for meta information file.
38 | io_backend (dict): IO backend type and other kwarg.
39 |
40 | num_frame (int): Window size for input frames.
41 | gt_size (int): Cropped patched size for gt patches.
42 | random_reverse (bool): Random reverse input frames.
43 | use_flip (bool): Use horizontal flips.
44 | use_rot (bool): Use rotation (use vertical flip and transposing h
45 | and w for implementation).
46 |
47 | scale (bool): Scale, which will be added automatically.
48 | """
49 |
50 | def __init__(self, opt):
51 | super(Vimeo90KDataset, self).__init__()
52 | self.opt = opt
53 | self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
54 | opt['dataroot_lq'])
55 |
56 | with open(opt['meta_info_file'], 'r') as fin:
57 | self.keys = [line.split(' ')[0] for line in fin]
58 |
59 | # file client (io backend)
60 | self.file_client = None
61 | self.io_backend_opt = opt['io_backend']
62 | self.is_lmdb = False
63 | if self.io_backend_opt['type'] == 'lmdb':
64 | self.is_lmdb = True
65 | self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
66 | self.io_backend_opt['client_keys'] = ['lq', 'gt']
67 |
68 | # indices of input images
69 | self.neighbor_list = [
70 | i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])
71 | ]
72 |
73 | # temporal augmentation configs
74 | self.random_reverse = opt['random_reverse']
75 | logger = get_root_logger()
76 | logger.info(f'Random reverse is {self.random_reverse}.')
77 |
78 | def __getitem__(self, index):
79 | if self.file_client is None:
80 | self.file_client = FileClient(
81 | self.io_backend_opt.pop('type'), **self.io_backend_opt)
82 |
83 | # random reverse
84 | if self.random_reverse and random.random() < 0.5:
85 | self.neighbor_list.reverse()
86 |
87 | scale = self.opt['scale']
88 | gt_size = self.opt['gt_size']
89 | key = self.keys[index]
90 | clip, seq = key.split('/') # key example: 00001/0001
91 |
92 | # get the GT frame (im4.png)
93 | if self.is_lmdb:
94 | img_gt_path = f'{key}/im4'
95 | else:
96 | img_gt_path = self.gt_root / clip / seq / 'im4.png'
97 | img_bytes = self.file_client.get(img_gt_path, 'gt')
98 | img_gt = imfrombytes(img_bytes, float32=True)
99 |
100 | # get the neighboring LQ frames
101 | img_lqs = []
102 | for neighbor in self.neighbor_list:
103 | if self.is_lmdb:
104 | img_lq_path = f'{clip}/{seq}/im{neighbor}'
105 | else:
106 | img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
107 | img_bytes = self.file_client.get(img_lq_path, 'lq')
108 | img_lq = imfrombytes(img_bytes, float32=True)
109 | img_lqs.append(img_lq)
110 |
111 | # randomly crop
112 | img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
113 | img_gt_path)
114 |
115 | # augmentation - flip, rotate
116 | img_lqs.append(img_gt)
117 | img_results = augment(img_lqs, self.opt['use_flip'],
118 | self.opt['use_rot'])
119 |
120 | img_results = img2tensor(img_results)
121 | img_lqs = torch.stack(img_results[0:-1], dim=0)
122 | img_gt = img_results[-1]
123 |
124 | # img_lqs: (t, c, h, w)
125 | # img_gt: (c, h, w)
126 | # key: str
127 | return {'lq': img_lqs, 'gt': img_gt, 'key': key}
128 |
129 | def __len__(self):
130 | return len(self.keys)
131 |
--------------------------------------------------------------------------------
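The neighbor-frame table in the `Vimeo90KDataset` docstring falls straight out of the index arithmetic used in `__init__`, as this one-liner shows:

```python
for num_frame in (1, 3, 5, 7):
    print(num_frame, [i + (9 - num_frame) // 2 for i in range(num_frame)])
# 1 [4] / 3 [3, 4, 5] / 5 [2, 3, 4, 5, 6] / 7 [1, 2, 3, 4, 5, 6, 7]
```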
/basicsr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .niqe import calculate_niqe
2 | from .psnr_ssim import calculate_psnr, calculate_ssim
3 |
4 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
5 |
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/metrics/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/metric_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/metrics/__pycache__/metric_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/niqe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/metrics/__pycache__/niqe.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/fid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from scipy import linalg
5 | from tqdm import tqdm
6 |
7 | from basicsr.models.archs.inception import InceptionV3
8 |
9 |
10 | def load_patched_inception_v3(device='cuda',
11 | resize_input=True,
12 | normalize_input=False):
13 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
14 | # does resize the input.
15 | inception = InceptionV3([3],
16 | resize_input=resize_input,
17 | normalize_input=normalize_input)
18 | inception = nn.DataParallel(inception).eval().to(device)
19 | return inception
20 |
21 |
22 | @torch.no_grad()
23 | def extract_inception_features(data_generator,
24 | inception,
25 | len_generator=None,
26 | device='cuda'):
27 | """Extract inception features.
28 |
29 | Args:
30 | data_generator (generator): A data generator.
31 | inception (nn.Module): Inception model.
32 | len_generator (int): Length of the data_generator to show the
33 | progressbar. Default: None.
34 | device (str): Device. Default: cuda.
35 |
36 | Returns:
37 | Tensor: Extracted features.
38 | """
39 | if len_generator is not None:
40 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
41 | else:
42 | pbar = None
43 | features = []
44 |
45 | for data in data_generator:
46 | if pbar:
47 | pbar.update(1)
48 | data = data.to(device)
49 | feature = inception(data)[0].view(data.shape[0], -1)
50 | features.append(feature.to('cpu'))
51 | if pbar:
52 | pbar.close()
53 | features = torch.cat(features, 0)
54 | return features
55 |
56 |
57 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
58 | """Numpy implementation of the Frechet Distance.
59 |
60 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
61 | and X_2 ~ N(mu_2, C_2) is
62 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
63 | Stable version by Dougal J. Sutherland.
64 |
65 | Args:
66 | mu1 (np.array): The sample mean over activations.
67 | sigma1 (np.array): The covariance matrix over activations for
68 | generated samples.
69 | mu2 (np.array): The sample mean over activations, precalculated on a
70 | representative data set.
71 | sigma2 (np.array): The covariance matrix over activations,
72 | precalculated on a representative data set.
73 |
74 | Returns:
75 | float: The Frechet Distance.
76 | """
77 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
78 | assert sigma1.shape == sigma2.shape, (
79 | 'Two covariances have different dimensions')
80 |
81 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
82 |
83 | # Product might be almost singular
84 | if not np.isfinite(cov_sqrt).all():
85 | print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
86 | 'of cov estimates')
87 | offset = np.eye(sigma1.shape[0]) * eps
88 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
89 |
90 | # Numerical error might give slight imaginary component
91 | if np.iscomplexobj(cov_sqrt):
92 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
93 | m = np.max(np.abs(cov_sqrt.imag))
94 | raise ValueError(f'Imaginary component {m}')
95 | cov_sqrt = cov_sqrt.real
96 |
97 | mean_diff = mu1 - mu2
98 | mean_norm = mean_diff @ mean_diff
99 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
100 | fid = mean_norm + trace
101 |
102 | return fid
103 |
--------------------------------------------------------------------------------
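A numeric sanity check of the Frechet-distance formula implemented by `calculate_fid` above, computed inline so nothing beyond NumPy/SciPy is imported (identical Gaussians must give distance 0):

```python
import numpy as np
from scipy import linalg

mu1 = mu2 = np.zeros(4)
sigma1 = sigma2 = np.eye(4)

# d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2 * sqrt(C_1 * C_2))
cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
d2 = (mu1 - mu2) @ (mu1 - mu2) + np.trace(sigma1 + sigma2 - 2 * cov_sqrt)
print(d2)  # 0.0
```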
/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from basicsr.utils.matlab_functions import bgr2ycbcr
4 |
5 |
6 | def reorder_image(img, input_order='HWC'):
7 | """Reorder images to 'HWC' order.
8 |
9 | If the input_order is (h, w), return (h, w, 1);
10 | If the input_order is (c, h, w), return (h, w, c);
11 | If the input_order is (h, w, c), return as it is.
12 |
13 | Args:
14 | img (ndarray): Input image.
15 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
16 | If the input image shape is (h, w), input_order will not have
17 | effects. Default: 'HWC'.
18 |
19 | Returns:
20 | ndarray: reordered image.
21 | """
22 |
23 | if input_order not in ['HWC', 'CHW']:
24 | raise ValueError(
25 | f'Wrong input_order {input_order}. Supported input_orders are '
26 | "'HWC' and 'CHW'")
27 | if len(img.shape) == 2:
28 | img = img[..., None]
29 | if input_order == 'CHW':
30 | img = img.transpose(1, 2, 0)
31 | return img
32 |
33 |
34 | def to_y_channel(img):
35 | """Change to Y channel of YCbCr.
36 |
37 | Args:
38 | img (ndarray): Images with range [0, 255].
39 |
40 | Returns:
41 | (ndarray): Images with range [0, 255] (float type) without round.
42 | """
43 | img = img.astype(np.float32) / 255.
44 | if img.ndim == 3 and img.shape[2] == 3:
45 | img = bgr2ycbcr(img, y_only=True)
46 | img = img[..., None]
47 | return img * 255.
48 |
--------------------------------------------------------------------------------
/basicsr/metrics/niqe.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | from scipy.ndimage import convolve  # scipy.ndimage.filters is deprecated
5 | from scipy.special import gamma
6 |
7 | from basicsr.metrics.metric_util import reorder_image, to_y_channel
8 |
9 |
10 | def estimate_aggd_param(block):
11 | """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) paramters.
12 |
13 | Args:
14 | block (ndarray): 2D Image block.
15 |
16 | Returns:
17 | tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
18 | distribution (estimating the parameters in Equation 7 of the paper).
19 | """
20 | block = block.flatten()
21 | gam = np.arange(0.2, 10.001, 0.001) # len = 9801
22 | gam_reciprocal = np.reciprocal(gam)
23 | r_gam = np.square(gamma(gam_reciprocal * 2)) / (
24 | gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
25 |
26 | left_std = np.sqrt(np.mean(block[block < 0]**2))
27 | right_std = np.sqrt(np.mean(block[block > 0]**2))
28 | gammahat = left_std / right_std
29 | rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
30 | rhatnorm = (rhat * (gammahat**3 + 1) *
31 | (gammahat + 1)) / ((gammahat**2 + 1)**2)
32 | array_position = np.argmin((r_gam - rhatnorm)**2)
33 |
34 | alpha = gam[array_position]
35 | beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
36 | beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
37 | return (alpha, beta_l, beta_r)
38 |
39 |
40 | def compute_feature(block):
41 | """Compute features.
42 |
43 | Args:
44 | block (ndarray): 2D Image block.
45 |
46 | Returns:
47 | list: Features with length of 18.
48 | """
49 | feat = []
50 | alpha, beta_l, beta_r = estimate_aggd_param(block)
51 | feat.extend([alpha, (beta_l + beta_r) / 2])
52 |
53 | # distortions disturb the fairly regular structure of natural images.
54 | # This deviation can be captured by analyzing the sample distribution of
55 | # the products of pairs of adjacent coefficients computed along
56 | # horizontal, vertical and diagonal orientations.
57 | shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
58 | for i in range(len(shifts)):
59 | shifted_block = np.roll(block, shifts[i], axis=(0, 1))
60 | alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
61 | # Eq. 8
62 | mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
63 | feat.extend([alpha, mean, beta_l, beta_r])
64 | return feat
65 |
66 |
67 | def niqe(img,
68 | mu_pris_param,
69 | cov_pris_param,
70 | gaussian_window,
71 | block_size_h=96,
72 | block_size_w=96):
73 | """Calculate NIQE (Natural Image Quality Evaluator) metric.
74 |
75 | Ref: Making a "Completely Blind" Image Quality Analyzer.
76 | This implementation could produce almost the same results as the official
77 | MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
78 |
79 | Note that we do not include block overlap height and width, since they are
80 | always 0 in the official implementation.
81 |
82 | For good performance, the official implementation advises dividing the
83 | distorted image into patches of the same size as those used for the
84 | construction of the multivariate Gaussian model.
85 |
86 | Args:
87 | img (ndarray): Input image whose quality needs to be computed. The
88 | image must be a gray or Y (of YCbCr) image with shape (h, w).
89 | Range [0, 255] with float type.
90 | mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
91 | model calculated on the pristine dataset.
92 | cov_pris_param (ndarray): Covariance of a pre-defined multivariate
93 | Gaussian model calculated on the pristine dataset.
94 | gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
95 | image.
96 | block_size_h (int): Height of the blocks in to which image is divided.
97 | Default: 96 (the official recommended value).
98 | block_size_w (int): Width of the blocks in to which image is divided.
99 | Default: 96 (the official recommended value).
100 | """
101 | assert img.ndim == 2, (
102 | 'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
103 | # crop image
104 | h, w = img.shape
105 | num_block_h = math.floor(h / block_size_h)
106 | num_block_w = math.floor(w / block_size_w)
107 | img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
108 |
109 | distparam = [] # dist param is actually the multiscale features
110 | for scale in (1, 2): # perform on two scales (1, 2)
111 | mu = convolve(img, gaussian_window, mode='nearest')
112 | sigma = np.sqrt(
113 | np.abs(
114 | convolve(np.square(img), gaussian_window, mode='nearest') -
115 | np.square(mu)))
116 | # normalize, as in Eq. 1 in the paper
117 | img_normalized = (img - mu) / (sigma + 1)
118 |
119 | feat = []
120 | for idx_w in range(num_block_w):
121 | for idx_h in range(num_block_h):
122 | # process each block
123 | block = img_normalized[idx_h * block_size_h //
124 | scale:(idx_h + 1) * block_size_h //
125 | scale, idx_w * block_size_w //
126 | scale:(idx_w + 1) * block_size_w //
127 | scale]
128 | feat.append(compute_feature(block))
129 |
130 | distparam.append(np.array(feat))
131 | # TODO: matlab bicubic downsample with anti-aliasing
132 | # for simplicity, now we use opencv instead, which will result in
133 | # a slight difference.
134 | if scale == 1:
135 | h, w = img.shape
136 | img = cv2.resize(
137 | img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
138 | img = img * 255.
139 |
140 | distparam = np.concatenate(distparam, axis=1)
141 |
142 | # fit a MVG (multivariate Gaussian) model to distorted patch features
143 | mu_distparam = np.nanmean(distparam, axis=0)
144 | # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
145 | distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
146 | cov_distparam = np.cov(distparam_no_nan, rowvar=False)
147 |
148 | # compute niqe quality, Eq. 10 in the paper
149 | invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
150 | quality = np.matmul(
151 | np.matmul((mu_pris_param - mu_distparam), invcov_param),
152 | np.transpose((mu_pris_param - mu_distparam)))
153 | quality = np.sqrt(quality)
154 |
155 | return quality
156 |
157 |
158 | def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
159 | """Calculate NIQE (Natural Image Quality Evaluator) metric.
160 |
161 | Ref: Making a "Completely Blind" Image Quality Analyzer.
162 |     This implementation produces nearly the same results as the official
163 |     MATLAB code: http://live.ece.utexas.edu/research/quality/niqe_release.zip
164 |
165 | We use the official params estimated from the pristine dataset.
166 | We use the recommended block size (96, 96) without overlaps.
167 |
168 | Args:
169 | img (ndarray): Input image whose quality needs to be computed.
170 | The input image must be in range [0, 255] with float/int type.
171 | The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
172 | If the input order is 'HWC' or 'CHW', it will be converted to gray
173 | or Y (of YCbCr) image according to the ``convert_to`` argument.
174 | crop_border (int): Cropped pixels in each edge of an image. These
175 | pixels are not involved in the metric calculation.
176 | input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
177 | Default: 'HWC'.
178 |         convert_to (str): Whether to convert to 'y' (of MATLAB YCbCr) or 'gray'.
179 | Default: 'y'.
180 |
181 | Returns:
182 | float: NIQE result.
183 | """
184 |
185 | # we use the official params estimated from the pristine dataset.
186 | niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz')
187 | mu_pris_param = niqe_pris_params['mu_pris_param']
188 | cov_pris_param = niqe_pris_params['cov_pris_param']
189 | gaussian_window = niqe_pris_params['gaussian_window']
190 |
191 | img = img.astype(np.float32)
192 | if input_order != 'HW':
193 | img = reorder_image(img, input_order=input_order)
194 | if convert_to == 'y':
195 | img = to_y_channel(img)
196 | elif convert_to == 'gray':
197 | img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
198 | img = np.squeeze(img)
199 |
200 | if crop_border != 0:
201 | img = img[crop_border:-crop_border, crop_border:-crop_border]
202 |
203 | niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
204 |
205 | return niqe_result
206 |
--------------------------------------------------------------------------------
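A minimal usage sketch for calculate_niqe (the image path is hypothetical; note that the function loads basicsr/metrics/niqe_pris_params.npz via a relative path, so it should be run from the repository root):

import cv2
from basicsr.metrics.niqe import calculate_niqe

img = cv2.imread('example.png')  # hypothetical path; BGR uint8 in [0, 255]
score = calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y')
print(f'NIQE: {score:.4f}')  # lower means better perceptual quality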
/basicsr/metrics/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/metrics/niqe_pris_params.npz
--------------------------------------------------------------------------------
/basicsr/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/.DS_Store
--------------------------------------------------------------------------------
/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import get_root_logger, scandir
5 |
6 | # automatically scan and import model modules
7 | # scan all the files under the 'models' folder and collect files ending with
8 | # '_model.py'
9 | model_folder = osp.dirname(osp.abspath(__file__))
10 | model_filenames = [
11 | osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
12 | if v.endswith('_model.py')
13 | ]
14 | # import all the model modules
15 | _model_modules = [
16 | importlib.import_module(f'basicsr.models.{file_name}')
17 | for file_name in model_filenames
18 | ]
19 |
20 |
21 | def create_model(opt):
22 | """Create model.
23 |
24 | Args:
 25 |         opt (dict): Configuration. It contains:
26 | model_type (str): Model type.
27 | """
28 | model_type = opt['model_type']
29 |
30 | # dynamic instantiation
31 | for module in _model_modules:
32 | model_cls = getattr(module, model_type, None)
33 | if model_cls is not None:
34 | break
35 | if model_cls is None:
36 | raise ValueError(f'Model {model_type} is not found.')
37 |
38 | model = model_cls(opt)
39 |
40 | logger = get_root_logger()
41 | logger.info(f'Model [{model.__class__.__name__}] is created.')
42 | return model
43 |
--------------------------------------------------------------------------------
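The scan-and-getattr registry above makes any class defined in a *_model.py file under basicsr/models reachable by name alone; a minimal sketch, assuming a hypothetical basicsr/models/toy_model.py:

# basicsr/models/toy_model.py (hypothetical file; the scan above would import it)
class ToyModel:
    def __init__(self, opt):
        self.opt = opt

# anywhere else in the codebase
from basicsr.models import create_model

model = create_model({'model_type': 'ToyModel'})  # logs: Model [ToyModel] is created.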
/basicsr/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/base_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/__pycache__/base_model.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import scandir
5 |
6 | # automatically scan and import arch modules
7 | # scan all the files under the 'archs' folder and collect files ending with
8 | # '_arch.py'
9 | arch_folder = osp.dirname(osp.abspath(__file__))
10 | arch_filenames = [
11 | osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
12 | if v.endswith('_arch.py')
13 | ]
14 | # import all the arch modules
15 | _arch_modules = [
16 | importlib.import_module(f'basicsr.models.archs.{file_name}')
17 | for file_name in arch_filenames
18 | ]
19 |
20 |
21 | def dynamic_instantiation(modules, cls_type, opt):
22 | """Dynamically instantiate class.
23 |
24 | Args:
25 | modules (list[importlib modules]): List of modules from importlib
26 | files.
27 | cls_type (str): Class type.
28 | opt (dict): Class initialization kwargs.
29 |
30 | Returns:
31 | class: Instantiated class.
32 | """
33 |
34 | for module in modules:
35 | cls_ = getattr(module, cls_type, None)
36 | if cls_ is not None:
37 | break
38 | if cls_ is None:
39 | raise ValueError(f'{cls_type} is not found.')
40 | return cls_(**opt)
41 |
42 |
43 | def define_network(opt):
44 | network_type = opt.pop('type')
45 | net = dynamic_instantiation(_arch_modules, network_type, opt)
46 | return net
47 |
--------------------------------------------------------------------------------
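define_network pops the 'type' key and forwards the remaining options as constructor kwargs, so YAML network options map one-to-one onto architecture arguments; a minimal sketch with a hypothetical arch file:

# basicsr/models/archs/toy_arch.py (hypothetical; picked up by the '_arch.py' scan)
import torch.nn as nn

class ToyArch(nn.Module):
    def __init__(self, num_feat=64):
        super().__init__()
        self.body = nn.Conv2d(3, num_feat, 3, padding=1)

# anywhere else
from basicsr.models.archs import define_network

net = define_network({'type': 'ToyArch', 'num_feat': 32})  # note: pop('type') mutates the dict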
/basicsr/models/archs/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/archs/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss, FFTLoss)
2 | 
3 | __all__ = [
4 |     'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss', 'FFTLoss',
5 | ]
6 |
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/losses/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/loss_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/losses/__pycache__/loss_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/models/losses/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/loss_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from torch.nn import functional as F
3 |
4 |
5 | def reduce_loss(loss, reduction):
6 | """Reduce loss as specified.
7 |
8 | Args:
9 | loss (Tensor): Elementwise loss tensor.
10 | reduction (str): Options are 'none', 'mean' and 'sum'.
11 |
12 | Returns:
13 | Tensor: Reduced loss tensor.
14 | """
15 | reduction_enum = F._Reduction.get_enum(reduction)
16 | # none: 0, elementwise_mean:1, sum: 2
17 | if reduction_enum == 0:
18 | return loss
19 | elif reduction_enum == 1:
20 | return loss.mean()
21 | else:
22 | return loss.sum()
23 |
24 |
25 | def weight_reduce_loss(loss, weight=None, reduction='mean'):
26 | """Apply element-wise weight and reduce loss.
27 |
28 | Args:
29 | loss (Tensor): Element-wise loss.
30 | weight (Tensor): Element-wise weights. Default: None.
31 | reduction (str): Same as built-in losses of PyTorch. Options are
32 | 'none', 'mean' and 'sum'. Default: 'mean'.
33 |
34 | Returns:
35 | Tensor: Loss values.
36 | """
37 | # if weight is specified, apply element-wise weight
38 | if weight is not None:
39 | assert weight.dim() == loss.dim()
40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41 | loss = loss * weight
42 |
43 | # if weight is not specified or reduction is sum, just reduce the loss
44 | if weight is None or reduction == 'sum':
45 | loss = reduce_loss(loss, reduction)
46 | # if reduction is mean, then compute mean over weight region
47 | elif reduction == 'mean':
48 | if weight.size(1) > 1:
49 | weight = weight.sum()
50 | else:
51 | weight = weight.sum() * loss.size(1)
52 | loss = loss.sum() / weight
53 |
54 | return loss
55 |
56 |
57 | def weighted_loss(loss_func):
58 | """Create a weighted version of a given loss function.
59 |
60 | To use this decorator, the loss function must have the signature like
61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
62 | element-wise loss without any reduction. This decorator will add weight
63 | and reduction arguments to the function. The decorated function will have
64 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
65 | **kwargs)`.
66 |
67 | :Example:
68 |
69 | >>> import torch
70 | >>> @weighted_loss
71 | >>> def l1_loss(pred, target):
72 | >>> return (pred - target).abs()
73 |
74 | >>> pred = torch.Tensor([0, 2, 3])
75 | >>> target = torch.Tensor([1, 1, 1])
76 | >>> weight = torch.Tensor([1, 0, 1])
77 |
78 | >>> l1_loss(pred, target)
79 | tensor(1.3333)
80 | >>> l1_loss(pred, target, weight)
81 | tensor(1.5000)
82 | >>> l1_loss(pred, target, reduction='none')
83 | tensor([1., 1., 2.])
84 | >>> l1_loss(pred, target, weight, reduction='sum')
85 | tensor(3.)
86 | """
87 |
88 | @functools.wraps(loss_func)
89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90 | # get element-wise loss
91 | loss = loss_func(pred, target, **kwargs)
92 | loss = weight_reduce_loss(loss, weight, reduction)
93 | return loss
94 |
95 | return wrapper
96 |
--------------------------------------------------------------------------------
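The decorator turns any element-wise loss into a weighted, reducible one. Note that weight_reduce_loss indexes weight.size(1), so weights must be at least 2-D; the 1-D doctest above predates that check. A minimal sketch reviving the Charbonnier variant that is commented out in losses.py:

import torch
from basicsr.models.losses.loss_util import weighted_loss

@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    # element-wise only; weighting and reduction are added by the decorator
    return torch.sqrt((pred - target) ** 2 + eps)

pred = torch.zeros(1, 1, 2, 2)
target = torch.ones(1, 1, 2, 2)
weight = torch.tensor([[[[1., 0.], [1., 1.]]]])  # mask out one pixel
print(charbonnier_loss(pred, target, weight=weight))  # ~1.0, the mean over the weighted region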
/basicsr/models/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 |
6 | from basicsr.models.losses.loss_util import weighted_loss
7 |
8 | _reduction_modes = ['none', 'mean', 'sum']
9 |
10 |
11 | @weighted_loss
12 | def l1_loss(pred, target):
13 | return F.l1_loss(pred, target, reduction='none')
14 |
15 |
16 | @weighted_loss
17 | def mse_loss(pred, target):
18 | return F.mse_loss(pred, target, reduction='none')
19 |
20 |
21 | # @weighted_loss
22 | # def charbonnier_loss(pred, target, eps=1e-12):
23 | # return torch.sqrt((pred - target)**2 + eps)
24 |
25 |
26 | class L1Loss(nn.Module):
27 | """L1 (mean absolute error, MAE) loss.
28 |
29 | Args:
30 | loss_weight (float): Loss weight for L1 loss. Default: 1.0.
31 | reduction (str): Specifies the reduction to apply to the output.
32 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
33 | """
34 |
35 | def __init__(self, loss_weight=1.0, reduction='mean'):
36 | super(L1Loss, self).__init__()
37 | if reduction not in ['none', 'mean', 'sum']:
38 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
39 | f'Supported ones are: {_reduction_modes}')
40 |
41 | self.loss_weight = loss_weight
42 | self.reduction = reduction
43 |
44 | def forward(self, pred, target, weight=None, **kwargs):
45 | """
46 | Args:
47 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
48 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
49 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
50 | weights. Default: None.
51 | """
52 | return self.loss_weight * l1_loss(
53 | pred, target, weight, reduction=self.reduction)
54 |
55 |
56 | class FFTLoss(nn.Module):
57 | """L1 loss in frequency domain with FFT.
58 |
59 | Args:
60 | loss_weight (float): Loss weight for FFT loss. Default: 1.0.
61 | reduction (str): Specifies the reduction to apply to the output.
62 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
63 | """
64 |
65 | def __init__(self, loss_weight=1.0, reduction='mean'):
66 | super(FFTLoss, self).__init__()
67 | if reduction not in ['none', 'mean', 'sum']:
68 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
69 |
70 | self.loss_weight = loss_weight
71 | self.reduction = reduction
72 |
73 | def forward(self, pred, target, weight=None, **kwargs):
74 | """
75 | Args:
76 | pred (Tensor): of shape (..., C, H, W). Predicted tensor.
77 | target (Tensor): of shape (..., C, H, W). Ground truth tensor.
78 | weight (Tensor, optional): of shape (..., C, H, W). Element-wise
79 | weights. Default: None.
80 | """
81 |
82 | pred_fft = torch.fft.fft2(pred, dim=(-2, -1))
83 | pred_fft = torch.stack([pred_fft.real, pred_fft.imag], dim=-1)
84 | target_fft = torch.fft.fft2(target, dim=(-2, -1))
85 | target_fft = torch.stack([target_fft.real, target_fft.imag], dim=-1)
86 | return self.loss_weight * l1_loss(pred_fft, target_fft, weight, reduction=self.reduction)
87 |
88 |
89 | class MSELoss(nn.Module):
90 | """MSE (L2) loss.
91 |
92 | Args:
93 | loss_weight (float): Loss weight for MSE loss. Default: 1.0.
94 | reduction (str): Specifies the reduction to apply to the output.
95 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
96 | """
97 |
98 | def __init__(self, loss_weight=1.0, reduction='mean'):
99 | super(MSELoss, self).__init__()
100 | if reduction not in ['none', 'mean', 'sum']:
101 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
102 | f'Supported ones are: {_reduction_modes}')
103 |
104 | self.loss_weight = loss_weight
105 | self.reduction = reduction
106 |
107 | def forward(self, pred, target, weight=None, **kwargs):
108 | """
109 | Args:
110 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
111 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
112 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
113 | weights. Default: None.
114 | """
115 | return self.loss_weight * mse_loss(
116 | pred, target, weight, reduction=self.reduction)
117 |
118 | class PSNRLoss(nn.Module):
119 |
120 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
121 | super(PSNRLoss, self).__init__()
122 | assert reduction == 'mean'
123 | self.loss_weight = loss_weight
124 | self.scale = 10 / np.log(10)
125 | self.toY = toY
126 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
127 | self.first = True
128 |
129 | def forward(self, pred, target):
130 | assert len(pred.size()) == 4
131 | if self.toY:
132 | if self.first:
133 | self.coef = self.coef.to(pred.device)
134 | self.first = False
135 |
136 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
137 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
138 |
139 | pred, target = pred / 255., target / 255.
140 | 
141 | assert len(pred.size()) == 4
142 |
143 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
144 |
145 | class CharbonnierLoss(nn.Module):
146 | """Charbonnier Loss (L1)"""
147 |
148 | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3):
149 | super(CharbonnierLoss, self).__init__()
150 | self.eps = eps
151 |
152 | def forward(self, x, y):
153 | diff = x - y
154 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
155 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
156 | return loss
157 |
--------------------------------------------------------------------------------
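Because scale = 10 / ln(10) and torch.log is the natural logarithm, PSNRLoss returns 10 * log10(MSE + 1e-8) averaged over the batch, i.e. the negative PSNR of images in [0, 1], so minimizing it maximizes PSNR. A quick numerical check (random tensors, illustrative only):

import torch
from basicsr.models.losses.losses import FFTLoss, PSNRLoss

pred = torch.rand(2, 3, 8, 8)
target = torch.rand(2, 3, 8, 8)

mse = ((pred - target) ** 2).mean(dim=(1, 2, 3))
psnr = (-10 * torch.log10(mse + 1e-8)).mean()
print(PSNRLoss()(pred, target), -psnr)  # the two values agree

print(FFTLoss()(pred, target))  # L1 between stacked real/imag FFT coefficients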
/basicsr/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from torch.optim.lr_scheduler import _LRScheduler
4 | import torch
5 |
6 |
7 | class MultiStepRestartLR(_LRScheduler):
8 | """ MultiStep with restarts learning rate scheme.
9 |
10 | Args:
11 | optimizer (torch.nn.optimizer): Torch optimizer.
12 | milestones (list): Iterations that will decrease learning rate.
13 | gamma (float): Decrease ratio. Default: 0.1.
14 | restarts (list): Restart iterations. Default: [0].
15 | restart_weights (list): Restart weights at each restart iteration.
16 | Default: [1].
17 | last_epoch (int): Used in _LRScheduler. Default: -1.
18 | """
19 |
20 | def __init__(self,
21 | optimizer,
22 | milestones,
23 | gamma=0.1,
24 | restarts=(0, ),
25 | restart_weights=(1, ),
26 | last_epoch=-1):
27 | self.milestones = Counter(milestones)
28 | self.gamma = gamma
29 | self.restarts = restarts
30 | self.restart_weights = restart_weights
31 | assert len(self.restarts) == len(
32 | self.restart_weights), 'restarts and their weights do not match.'
33 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
34 |
35 | def get_lr(self):
36 | if self.last_epoch in self.restarts:
37 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
38 | return [
39 | group['initial_lr'] * weight
40 | for group in self.optimizer.param_groups
41 | ]
42 | if self.last_epoch not in self.milestones:
43 | return [group['lr'] for group in self.optimizer.param_groups]
44 | return [
45 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
46 | for group in self.optimizer.param_groups
47 | ]
48 |
49 | class LinearLR(_LRScheduler):
50 | """
51 |
52 | Args:
53 | optimizer (torch.nn.optimizer): Torch optimizer.
54 | milestones (list): Iterations that will decrease learning rate.
55 | gamma (float): Decrease ratio. Default: 0.1.
56 | last_epoch (int): Used in _LRScheduler. Default: -1.
57 | """
58 |
59 | def __init__(self,
60 | optimizer,
61 | total_iter,
62 | last_epoch=-1):
63 | self.total_iter = total_iter
64 | super(LinearLR, self).__init__(optimizer, last_epoch)
65 |
66 | def get_lr(self):
67 | process = self.last_epoch / self.total_iter
68 | weight = (1 - process)
69 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups])
70 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
71 |
72 | class VibrateLR(_LRScheduler):
73 | """
74 |
75 | Args:
76 | optimizer (torch.nn.optimizer): Torch optimizer.
77 | milestones (list): Iterations that will decrease learning rate.
78 | gamma (float): Decrease ratio. Default: 0.1.
79 | last_epoch (int): Used in _LRScheduler. Default: -1.
80 | """
81 |
82 | def __init__(self,
83 | optimizer,
84 | total_iter,
85 | last_epoch=-1):
86 | self.total_iter = total_iter
87 | super(VibrateLR, self).__init__(optimizer, last_epoch)
88 |
89 | def get_lr(self):
90 | process = self.last_epoch / self.total_iter
91 |
92 | f = 0.1
93 | if process < 3 / 8:
94 | f = 1 - process * 8 / 3
95 | elif process < 5 / 8:
96 | f = 0.2
97 |
98 | T = self.total_iter // 80
99 | Th = T // 2
100 |
101 | t = self.last_epoch % T
102 |
103 | f2 = t / Th
104 | if t >= Th:
105 | f2 = 2 - f2
106 |
107 | weight = f * f2
108 |
109 | if self.last_epoch < Th:
110 | weight = max(0.1, weight)
111 |
112 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2))
113 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
114 |
115 | def get_position_from_periods(iteration, cumulative_period):
116 | """Get the position from a period list.
117 |
118 | It will return the index of the right-closest number in the period list.
119 | For example, the cumulative_period = [100, 200, 300, 400],
120 | if iteration == 50, return 0;
121 | if iteration == 210, return 2;
122 | if iteration == 300, return 2.
123 |
124 | Args:
125 | iteration (int): Current iteration.
126 | cumulative_period (list[int]): Cumulative period list.
127 |
128 | Returns:
129 | int: The position of the right-closest number in the period list.
130 | """
131 | for i, period in enumerate(cumulative_period):
132 | if iteration <= period:
133 | return i
134 |
135 |
136 | class CosineAnnealingRestartLR(_LRScheduler):
137 | """ Cosine annealing with restarts learning rate scheme.
138 |
139 | An example of config:
140 | periods = [10, 10, 10, 10]
141 | restart_weights = [1, 0.5, 0.5, 0.5]
142 | eta_min=1e-7
143 |
144 |     It has four cycles, each with 10 iterations. At the 10th, 20th and 30th
145 |     iterations, the scheduler restarts with the weights in restart_weights.
146 |
147 | Args:
148 | optimizer (torch.nn.optimizer): Torch optimizer.
149 |         periods (list): Period for each cosine annealing cycle.
150 |         restart_weights (list): Restart weights at each restart iteration.
151 |             Default: [1].
152 |         eta_min (float): The minimum lr. Default: 0.
153 | last_epoch (int): Used in _LRScheduler. Default: -1.
154 | """
155 |
156 | def __init__(self,
157 | optimizer,
158 | periods,
159 | restart_weights=(1, ),
160 | eta_min=0,
161 | last_epoch=-1):
162 | self.periods = periods
163 | self.restart_weights = restart_weights
164 | self.eta_min = eta_min
165 | assert (len(self.periods) == len(self.restart_weights)
166 | ), 'periods and restart_weights should have the same length.'
167 | self.cumulative_period = [
168 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
169 | ]
170 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
171 |
172 | def get_lr(self):
173 | idx = get_position_from_periods(self.last_epoch,
174 | self.cumulative_period)
175 | current_weight = self.restart_weights[idx]
176 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
177 | current_period = self.periods[idx]
178 |
179 | return [
180 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
181 | (1 + math.cos(math.pi * (
182 | (self.last_epoch - nearest_restart) / current_period)))
183 | for base_lr in self.base_lrs
184 | ]
185 |
186 | class CosineAnnealingRestartCyclicLR(_LRScheduler):
187 | """ Cosine annealing with restarts learning rate scheme.
188 | An example of config:
189 | periods = [10, 10, 10, 10]
190 | restart_weights = [1, 0.5, 0.5, 0.5]
191 | eta_min=1e-7
192 |     It has four cycles, each with 10 iterations. At the 10th, 20th and 30th
193 |     iterations, the scheduler restarts with the weights in restart_weights.
194 | Args:
195 | optimizer (torch.nn.optimizer): Torch optimizer.
196 |         periods (list): Period for each cosine annealing cycle.
197 |         restart_weights (list): Restart weights at each restart iteration.
198 |             Default: [1].
199 |         eta_mins (list): The minimum lr of each cycle. Default: [0].
200 | last_epoch (int): Used in _LRScheduler. Default: -1.
201 | """
202 |
203 | def __init__(self,
204 | optimizer,
205 | periods,
206 | restart_weights=(1, ),
207 | eta_mins=(0, ),
208 | last_epoch=-1):
209 | self.periods = periods
210 | self.restart_weights = restart_weights
211 | self.eta_mins = eta_mins
212 | assert (len(self.periods) == len(self.restart_weights)
213 | ), 'periods and restart_weights should have the same length.'
214 | self.cumulative_period = [
215 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
216 | ]
217 | super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
218 |
219 | def get_lr(self):
220 | idx = get_position_from_periods(self.last_epoch,
221 | self.cumulative_period)
222 | current_weight = self.restart_weights[idx]
223 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
224 | current_period = self.periods[idx]
225 | eta_min = self.eta_mins[idx]
226 |
227 | return [
228 | eta_min + current_weight * 0.5 * (base_lr - eta_min) *
229 | (1 + math.cos(math.pi * (
230 | (self.last_epoch - nearest_restart) / current_period)))
231 | for base_lr in self.base_lrs
232 | ]
233 |
--------------------------------------------------------------------------------
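A minimal sketch of CosineAnnealingRestartCyclicLR driving a dummy optimizer; the periods and eta_mins below are illustrative (a flat warm cycle followed by a cosine-decayed one), not taken from the repo's configs:

import torch
from basicsr.models.lr_scheduler import CosineAnnealingRestartCyclicLR

net = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(net.parameters(), lr=3e-4)
scheduler = CosineAnnealingRestartCyclicLR(
    optimizer, periods=[100, 200], restart_weights=[1, 1], eta_mins=[3e-4, 1e-6])

for it in range(300):
    optimizer.step()
    scheduler.step()
    if it % 100 == 0:
        print(it, scheduler.get_last_lr())  # constant at 3e-4, then cosine decay to 1e-6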
/basicsr/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from os import path as osp
4 |
5 | from basicsr.data import create_dataloader, create_dataset
6 | from basicsr.models import create_model
7 | from basicsr.train import parse_options
8 | from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
9 | make_exp_dirs)
10 | from basicsr.utils.options import dict2str
11 |
12 |
13 | def main():
 14 |     # parse options, set distributed setting, set random seed
15 | opt = parse_options(is_train=False)
16 |
17 | torch.backends.cudnn.benchmark = True
18 | # torch.backends.cudnn.deterministic = True
19 |
20 | # mkdir and initialize loggers
21 | make_exp_dirs(opt)
22 | log_file = osp.join(opt['path']['log'],
23 | f"test_{opt['name']}_{get_time_str()}.log")
24 | logger = get_root_logger(
25 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
26 | logger.info(get_env_info())
27 | logger.info(dict2str(opt))
28 |
29 | # create test dataset and dataloader
30 | test_loaders = []
31 | for phase, dataset_opt in sorted(opt['datasets'].items()):
32 | test_set = create_dataset(dataset_opt)
33 | test_loader = create_dataloader(
34 | test_set,
35 | dataset_opt,
36 | num_gpu=opt['num_gpu'],
37 | dist=opt['dist'],
38 | sampler=None,
39 | seed=opt['manual_seed'])
40 | logger.info(
41 | f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
42 | test_loaders.append(test_loader)
43 |
44 | # create model
45 | model = create_model(opt)
46 |
47 | for test_loader in test_loaders:
48 | test_set_name = test_loader.dataset.opt['name']
49 | logger.info(f'Testing {test_set_name}...')
50 | rgb2bgr = opt['val'].get('rgb2bgr', True)
 51 |         # whether to use uint8 images to compute metrics
52 | use_image = opt['val'].get('use_image', True)
53 | model.validation(
54 | test_loader,
55 | current_iter=opt['name'],
56 | tb_logger=None,
57 | save_img=opt['val']['save_img'],
58 | rgb2bgr=rgb2bgr, use_image=use_image)
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_client import FileClient
2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP
3 | from .logger import (MessageLogger, get_env_info, get_root_logger,
4 | init_tb_logger, init_wandb_logger)
5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
6 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
7 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)
8 |
9 | __all__ = [
10 | # file_client.py
11 | 'FileClient',
12 | # img_util.py
13 | 'img2tensor',
14 | 'tensor2img',
15 | 'imfrombytes',
16 | 'imwrite',
17 | 'crop_border',
18 | # logger.py
19 | 'MessageLogger',
20 | 'init_tb_logger',
21 | 'init_wandb_logger',
22 | 'get_root_logger',
23 | 'get_env_info',
24 | # misc.py
25 | 'set_random_seed',
26 | 'get_time_str',
27 | 'mkdir_and_rename',
28 | 'make_exp_dirs',
29 | 'scandir',
30 | 'check_resume',
31 | 'sizeof_fmt',
32 | 'padding',
33 | 'padding_DP',
34 | 'imfrombytesDP',
35 | 'create_lmdb_for_reds',
36 | 'create_lmdb_for_gopro',
37 | 'create_lmdb_for_rain13k',
38 | ]
39 |
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/create_lmdb.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/create_lmdb.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/dist_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/dist_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/file_client.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/file_client.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/flow_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/flow_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/img_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/img_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/lmdb_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/lmdb_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/logger.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/matlab_functions.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/matlab_functions.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/misc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/misc.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/FPro/8acde68b73514f7837a28240af45d461df35758c/basicsr/utils/__pycache__/options.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/bundle_submissions.py:
--------------------------------------------------------------------------------
1 | # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de)
2 |
3 | # This file is part of the implementation as described in the CVPR 2017 paper:
4 | # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs.
5 | # Please see the file LICENSE.txt for the license governing this code.
6 |
7 |
8 | import numpy as np
9 | import scipy.io as sio
10 | import os
11 | import h5py
12 |
13 | def bundle_submissions_raw(submission_folder,session):
14 | '''
15 | Bundles submission data for raw denoising
16 |
17 | submission_folder Folder where denoised images reside
18 |
 19 |     Output is written to the given session subfolder. Please submit
 20 |     the content of this folder.
21 | '''
22 |
23 | out_folder = os.path.join(submission_folder, session)
24 | # out_folder = os.path.join(submission_folder, "bundled/")
25 | try:
26 | os.mkdir(out_folder)
 27 |     except OSError: pass  # folder may already exist
28 |
29 | israw = True
30 | eval_version="1.0"
31 |
32 | for i in range(50):
 33 |         Idenoised = np.zeros((20,), dtype=object)  # np.object was removed in NumPy 1.24
34 | for bb in range(20):
35 | filename = '%04d_%02d.mat'%(i+1,bb+1)
36 | s = sio.loadmat(os.path.join(submission_folder,filename))
37 | Idenoised_crop = s["Idenoised_crop"]
38 | Idenoised[bb] = Idenoised_crop
39 | filename = '%04d.mat'%(i+1)
40 | sio.savemat(os.path.join(out_folder, filename),
41 | {"Idenoised": Idenoised,
42 | "israw": israw,
43 | "eval_version": eval_version},
44 | )
45 |
46 | def bundle_submissions_srgb(submission_folder,session):
47 | '''
48 | Bundles submission data for sRGB denoising
49 |
50 | submission_folder Folder where denoised images reside
51 |
 52 |     Output is written to the given session subfolder. Please submit
 53 |     the content of this folder.
54 | '''
55 | out_folder = os.path.join(submission_folder, session)
56 | # out_folder = os.path.join(submission_folder, "bundled/")
57 | try:
58 | os.mkdir(out_folder)
 59 |     except OSError: pass  # folder may already exist
60 | israw = False
61 | eval_version="1.0"
62 |
63 | for i in range(50):
 64 |         Idenoised = np.zeros((20,), dtype=object)  # np.object was removed in NumPy 1.24
65 | for bb in range(20):
66 | filename = '%04d_%02d.mat'%(i+1,bb+1)
67 | s = sio.loadmat(os.path.join(submission_folder,filename))
68 | Idenoised_crop = s["Idenoised_crop"]
69 | Idenoised[bb] = Idenoised_crop
70 | filename = '%04d.mat'%(i+1)
71 | sio.savemat(os.path.join(out_folder, filename),
72 | {"Idenoised": Idenoised,
73 | "israw": israw,
74 | "eval_version": eval_version},
75 | )
76 |
77 |
78 |
79 | def bundle_submissions_srgb_v1(submission_folder,session):
80 | '''
81 | Bundles submission data for sRGB denoising
82 |
83 | submission_folder Folder where denoised images reside
84 |
 85 |     Output is written to the given session subfolder. Please submit
 86 |     the content of this folder.
87 | '''
88 | out_folder = os.path.join(submission_folder, session)
89 | # out_folder = os.path.join(submission_folder, "bundled/")
90 | try:
91 | os.mkdir(out_folder)
 92 |     except OSError: pass  # folder may already exist
93 | israw = False
94 | eval_version="1.0"
95 |
96 | for i in range(50):
 97 |         Idenoised = np.zeros((20,), dtype=object)  # np.object was removed in NumPy 1.24
98 | for bb in range(20):
99 | filename = '%04d_%d.mat'%(i+1,bb+1)
100 | s = sio.loadmat(os.path.join(submission_folder,filename))
101 | Idenoised_crop = s["Idenoised_crop"]
102 | Idenoised[bb] = Idenoised_crop
103 | filename = '%04d.mat'%(i+1)
104 | sio.savemat(os.path.join(out_folder, filename),
105 | {"Idenoised": Idenoised,
106 | "israw": israw,
107 | "eval_version": eval_version},
108 | )
--------------------------------------------------------------------------------
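A minimal usage sketch (the results folder is hypothetical; it must hold the 50x20 per-block .mat files named as in the loops above):

from basicsr.utils.bundle_submissions import bundle_submissions_srgb

bundle_submissions_srgb('./results/dnd_srgb/', 'bundled')  # writes ./results/dnd_srgb/bundled/0001.mat ...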
/basicsr/utils/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2, os, scipy.io as scio  # needed by create_lmdb_for_SIDD below
3 | from os import path as osp
4 | from basicsr.utils import scandir
5 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs
6 | from tqdm import tqdm
7 | def prepare_keys(folder_path, suffix='png'):
8 | """Prepare image path list and keys for DIV2K dataset.
9 |
10 | Args:
11 | folder_path (str): Folder path.
12 |
13 | Returns:
14 | list[str]: Image path list.
15 | list[str]: Key list.
16 | """
17 | print('Reading image path list ...')
18 | img_path_list = sorted(
19 | list(scandir(folder_path, suffix=suffix, recursive=False)))
20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)]
21 |
22 | return img_path_list, keys
23 |
24 | def create_lmdb_for_reds():
25 | folder_path = './datasets/REDS/val/sharp_300'
26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb'
27 | img_path_list, keys = prepare_keys(folder_path, 'png')
28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
29 | #
30 | folder_path = './datasets/REDS/val/blur_300'
31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb'
32 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
34 |
35 | folder_path = './datasets/REDS/train/train_sharp'
36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb'
37 | img_path_list, keys = prepare_keys(folder_path, 'png')
38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
39 |
40 | folder_path = './datasets/REDS/train/train_blur_jpeg'
41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb'
42 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
44 |
45 |
46 | def create_lmdb_for_gopro():
47 | folder_path = './datasets/GoPro/train/blur_crops'
48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb'
49 |
50 | img_path_list, keys = prepare_keys(folder_path, 'png')
51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
52 |
53 | folder_path = './datasets/GoPro/train/sharp_crops'
54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb'
55 |
56 | img_path_list, keys = prepare_keys(folder_path, 'png')
57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
58 |
59 | folder_path = './datasets/GoPro/test/target'
60 | lmdb_path = './datasets/GoPro/test/target.lmdb'
61 |
62 | img_path_list, keys = prepare_keys(folder_path, 'png')
63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
64 |
65 | folder_path = './datasets/GoPro/test/input'
66 | lmdb_path = './datasets/GoPro/test/input.lmdb'
67 |
68 | img_path_list, keys = prepare_keys(folder_path, 'png')
69 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
70 |
71 | def create_lmdb_for_rain13k():
72 | folder_path = './datasets/Rain13k/train/input'
73 | lmdb_path = './datasets/Rain13k/train/input.lmdb'
74 |
75 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
77 |
78 | folder_path = './datasets/Rain13k/train/target'
79 | lmdb_path = './datasets/Rain13k/train/target.lmdb'
80 |
81 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
83 |
84 | def create_lmdb_for_SIDD():
85 | folder_path = './datasets/SIDD/train/input_crops'
86 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb'
87 |
88 | img_path_list, keys = prepare_keys(folder_path, 'PNG')
89 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
90 |
91 | folder_path = './datasets/SIDD/train/gt_crops'
92 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb'
93 |
94 | img_path_list, keys = prepare_keys(folder_path, 'PNG')
95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
96 |
97 | #for val
98 | folder_path = './datasets/SIDD/val/input_crops'
99 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb'
100 | mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat'
101 | if not osp.exists(folder_path):
102 | os.makedirs(folder_path)
103 | assert osp.exists(mat_path)
104 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb']
105 | N, B, H ,W, C = data.shape
106 | data = data.reshape(N*B, H, W, C)
107 | for i in tqdm(range(N*B)):
108 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
109 | img_path_list, keys = prepare_keys(folder_path, 'png')
110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
111 |
112 | folder_path = './datasets/SIDD/val/gt_crops'
113 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb'
114 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat'
115 | if not osp.exists(folder_path):
116 | os.makedirs(folder_path)
117 | assert osp.exists(mat_path)
118 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb']
119 | N, B, H ,W, C = data.shape
120 | data = data.reshape(N*B, H, W, C)
121 | for i in tqdm(range(N*B)):
122 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
123 | img_path_list, keys = prepare_keys(folder_path, 'png')
124 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
125 |
--------------------------------------------------------------------------------
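prepare_keys and make_lmdb_from_imgs generalize to any flat image folder; a minimal sketch (the paths are hypothetical):

from basicsr.utils.create_lmdb import prepare_keys
from basicsr.utils.lmdb_util import make_lmdb_from_imgs

folder = './datasets/MyData/train/input'        # hypothetical folder of .png images
lmdb_path = './datasets/MyData/train/input.lmdb'
img_paths, keys = prepare_keys(folder, 'png')   # keys are the filenames minus the suffix
make_lmdb_from_imgs(folder, lmdb_path, img_paths, keys)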
/basicsr/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2 | import functools
3 | import os
4 | import subprocess
5 | import torch
6 | import torch.distributed as dist
7 | import torch.multiprocessing as mp
8 |
9 |
10 | def init_dist(launcher, backend='nccl', **kwargs):
11 | if mp.get_start_method(allow_none=True) is None:
12 | mp.set_start_method('spawn')
13 | if launcher == 'pytorch':
14 | _init_dist_pytorch(backend, **kwargs)
15 | elif launcher == 'slurm':
16 | _init_dist_slurm(backend, **kwargs)
17 | else:
18 | raise ValueError(f'Invalid launcher type: {launcher}')
19 |
20 |
21 | def _init_dist_pytorch(backend, **kwargs):
22 | rank = int(os.environ['RANK'])
23 | num_gpus = torch.cuda.device_count()
24 | torch.cuda.set_device(rank % num_gpus)
25 | dist.init_process_group(backend=backend, **kwargs)
26 |
27 |
28 | def _init_dist_slurm(backend, port=None):
29 | """Initialize slurm distributed training environment.
30 |
31 | If argument ``port`` is not specified, then the master port will be system
32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33 | environment variable, then a default port ``29500`` will be used.
34 |
35 | Args:
36 | backend (str): Backend of torch.distributed.
37 | port (int, optional): Master port. Defaults to None.
38 | """
39 | proc_id = int(os.environ['SLURM_PROCID'])
40 | ntasks = int(os.environ['SLURM_NTASKS'])
41 | node_list = os.environ['SLURM_NODELIST']
42 | num_gpus = torch.cuda.device_count()
43 | torch.cuda.set_device(proc_id % num_gpus)
44 | addr = subprocess.getoutput(
45 | f'scontrol show hostname {node_list} | head -n1')
46 | # specify master port
47 | if port is not None:
48 | os.environ['MASTER_PORT'] = str(port)
49 | elif 'MASTER_PORT' in os.environ:
50 | pass # use MASTER_PORT in the environment variable
51 | else:
52 | # 29500 is torch.distributed default port
53 | os.environ['MASTER_PORT'] = '29500'
54 | os.environ['MASTER_ADDR'] = addr
55 | os.environ['WORLD_SIZE'] = str(ntasks)
56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
57 | os.environ['RANK'] = str(proc_id)
58 | dist.init_process_group(backend=backend)
59 |
60 |
61 | def get_dist_info():
62 | if dist.is_available():
63 | initialized = dist.is_initialized()
64 | else:
65 | initialized = False
66 | if initialized:
67 | rank = dist.get_rank()
68 | world_size = dist.get_world_size()
69 | else:
70 | rank = 0
71 | world_size = 1
72 | return rank, world_size
73 |
74 |
75 | def master_only(func):
76 |
77 | @functools.wraps(func)
78 | def wrapper(*args, **kwargs):
79 | rank, _ = get_dist_info()
80 | if rank == 0:
81 | return func(*args, **kwargs)
82 |
83 | return wrapper
84 |
--------------------------------------------------------------------------------
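get_dist_info degrades gracefully outside distributed runs (rank 0, world size 1), so master_only-decorated helpers stay safe in single-GPU scripts; a minimal sketch:

from basicsr.utils.dist_util import get_dist_info, master_only

@master_only
def log_once(msg):
    print(msg)  # executes only on rank 0; returns None elsewhere

rank, world_size = get_dist_info()
log_once(f'running as rank {rank} of {world_size}')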
/basicsr/utils/download_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import requests
3 | from tqdm import tqdm
4 |
5 | from .misc import sizeof_fmt
6 |
7 |
8 | def download_file_from_google_drive(file_id, save_path):
9 | """Download files from google drive.
10 |
11 | Ref:
12 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
13 |
14 | Args:
15 | file_id (str): File id.
16 | save_path (str): Save path.
17 | """
18 |
19 | session = requests.Session()
20 | URL = 'https://docs.google.com/uc?export=download'
21 | params = {'id': file_id}
22 |
23 | response = session.get(URL, params=params, stream=True)
24 | token = get_confirm_token(response)
25 | if token:
26 | params['confirm'] = token
27 | response = session.get(URL, params=params, stream=True)
28 |
29 | # get file size
30 | response_file_size = session.get(
31 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
32 | if 'Content-Range' in response_file_size.headers:
33 | file_size = int(
34 | response_file_size.headers['Content-Range'].split('/')[1])
35 | else:
36 | file_size = None
37 |
38 | save_response_content(response, save_path, file_size)
39 |
40 |
41 | def get_confirm_token(response):
42 | for key, value in response.cookies.items():
43 | if key.startswith('download_warning'):
44 | return value
45 | return None
46 |
47 |
48 | def save_response_content(response,
49 | destination,
50 | file_size=None,
51 | chunk_size=32768):
52 | if file_size is not None:
53 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
54 |
55 | readable_file_size = sizeof_fmt(file_size)
56 | else:
57 | pbar = None
58 |
59 | with open(destination, 'wb') as f:
60 | downloaded_size = 0
61 | for chunk in response.iter_content(chunk_size):
 62 |             downloaded_size += len(chunk)  # the last chunk may be shorter than chunk_size
63 | if pbar is not None:
64 | pbar.update(1)
65 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
66 | f'/ {readable_file_size}')
67 | if chunk: # filter out keep-alive new chunks
68 | f.write(chunk)
69 | if pbar is not None:
70 | pbar.close()
71 |
--------------------------------------------------------------------------------
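A minimal usage sketch of the Google Drive helper (the file id and save path are placeholders; real ids come from a Drive share link):

from basicsr.utils.download_util import download_file_from_google_drive

download_file_from_google_drive('1AbCdEfGhIjK', 'pretrained/model.pth')  # hypothetical id/path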
/basicsr/utils/file_client.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
5 | class BaseStorageBackend(metaclass=ABCMeta):
6 | """Abstract class of storage backends.
7 |
8 | All backends need to implement two apis: ``get()`` and ``get_text()``.
9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10 | as texts.
11 | """
12 |
13 | @abstractmethod
14 | def get(self, filepath):
15 | pass
16 |
17 | @abstractmethod
18 | def get_text(self, filepath):
19 | pass
20 |
21 |
22 | class MemcachedBackend(BaseStorageBackend):
23 | """Memcached storage backend.
24 |
25 | Attributes:
26 | server_list_cfg (str): Config file for memcached server list.
27 | client_cfg (str): Config file for memcached client.
28 | sys_path (str | None): Additional path to be appended to `sys.path`.
29 | Default: None.
30 | """
31 |
32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33 | if sys_path is not None:
34 | import sys
35 | sys.path.append(sys_path)
36 | try:
37 | import mc
38 | except ImportError:
39 | raise ImportError(
40 | 'Please install memcached to enable MemcachedBackend.')
41 |
42 | self.server_list_cfg = server_list_cfg
43 | self.client_cfg = client_cfg
44 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
45 | self.client_cfg)
 46 |         # mc.pyvector serves as a pointer to a memory cache
47 | self._mc_buffer = mc.pyvector()
48 |
49 | def get(self, filepath):
50 | filepath = str(filepath)
51 | import mc
52 | self._client.Get(filepath, self._mc_buffer)
53 | value_buf = mc.ConvertBuffer(self._mc_buffer)
54 | return value_buf
55 |
56 | def get_text(self, filepath):
57 | raise NotImplementedError
58 |
59 |
60 | class HardDiskBackend(BaseStorageBackend):
61 | """Raw hard disks storage backend."""
62 |
63 | def get(self, filepath):
64 | filepath = str(filepath)
65 | with open(filepath, 'rb') as f:
66 | value_buf = f.read()
67 | return value_buf
68 |
69 | def get_text(self, filepath):
70 | filepath = str(filepath)
71 | with open(filepath, 'r') as f:
72 | value_buf = f.read()
73 | return value_buf
74 |
75 |
76 | class LmdbBackend(BaseStorageBackend):
77 | """Lmdb storage backend.
78 |
79 | Args:
80 | db_paths (str | list[str]): Lmdb database paths.
81 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
82 | readonly (bool, optional): Lmdb environment parameter. If True,
83 | disallow any write operations. Default: True.
84 | lock (bool, optional): Lmdb environment parameter. If False, when
85 | concurrent access occurs, do not lock the database. Default: False.
86 | readahead (bool, optional): Lmdb environment parameter. If False,
87 | disable the OS filesystem readahead mechanism, which may improve
88 | random read performance when a database is larger than RAM.
89 | Default: False.
90 |
91 | Attributes:
92 | db_paths (list): Lmdb database path.
93 | _client (list): A list of several lmdb envs.
94 | """
95 |
96 | def __init__(self,
97 | db_paths,
98 | client_keys='default',
99 | readonly=True,
100 | lock=False,
101 | readahead=False,
102 | **kwargs):
103 | try:
104 | import lmdb
105 | except ImportError:
106 | raise ImportError('Please install lmdb to enable LmdbBackend.')
107 |
108 | if isinstance(client_keys, str):
109 | client_keys = [client_keys]
110 |
111 | if isinstance(db_paths, list):
112 | self.db_paths = [str(v) for v in db_paths]
113 | elif isinstance(db_paths, str):
114 | self.db_paths = [str(db_paths)]
115 | assert len(client_keys) == len(self.db_paths), (
116 | 'client_keys and db_paths should have the same length, '
117 | f'but received {len(client_keys)} and {len(self.db_paths)}.')
118 |
119 | self._client = {}
120 |
121 | for client, path in zip(client_keys, self.db_paths):
122 | self._client[client] = lmdb.open(
123 | path,
124 | readonly=readonly,
125 | lock=lock,
126 | readahead=readahead,
127 | map_size=8*1024*10485760,
128 | # max_readers=1,
129 | **kwargs)
130 |
131 | def get(self, filepath, client_key):
132 | """Get values according to the filepath from one lmdb named client_key.
133 |
134 | Args:
135 | filepath (str | obj:`Path`): Here, filepath is the lmdb key.
136 |             client_key (str): Used for distinguishing different lmdb envs.
137 | """
138 | filepath = str(filepath)
139 | assert client_key in self._client, (f'client_key {client_key} is not '
140 | 'in lmdb clients.')
141 | client = self._client[client_key]
142 | with client.begin(write=False) as txn:
143 | value_buf = txn.get(filepath.encode('ascii'))
144 | return value_buf
145 |
146 | def get_text(self, filepath):
147 | raise NotImplementedError
148 |
149 |
150 | class FileClient(object):
151 | """A general file client to access files in different backend.
152 |
153 |     The client loads a file or text from its path using the specified
154 |     backend and returns it as a binary stream. It can also register other
155 |     backend accessors with a given name and backend class.
156 |
157 | Attributes:
158 | backend (str): The storage backend type. Options are "disk",
159 | "memcached" and "lmdb".
160 | client (:obj:`BaseStorageBackend`): The backend object.
161 | """
162 |
163 | _backends = {
164 | 'disk': HardDiskBackend,
165 | 'memcached': MemcachedBackend,
166 | 'lmdb': LmdbBackend,
167 | }
168 |
169 | def __init__(self, backend='disk', **kwargs):
170 | if backend not in self._backends:
171 | raise ValueError(
172 | f'Backend {backend} is not supported. Currently supported ones'
173 | f' are {list(self._backends.keys())}')
174 | self.backend = backend
175 | self.client = self._backends[backend](**kwargs)
176 |
177 | def get(self, filepath, client_key='default'):
178 | # client_key is used only for lmdb, where different fileclients have
179 | # different lmdb environments.
180 | if self.backend == 'lmdb':
181 | return self.client.get(filepath, client_key)
182 | else:
183 | return self.client.get(filepath)
184 |
185 | def get_text(self, filepath):
186 | return self.client.get_text(filepath)
187 |
--------------------------------------------------------------------------------
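FileClient returns raw bytes, which the repo's img_util helpers decode; a minimal sketch for the disk and lmdb backends (the lmdb path and key are hypothetical and follow the create_lmdb layout above):

from basicsr.utils import FileClient

disk = FileClient('disk')
img_bytes = disk.get('example.png')  # hypothetical path; raw bytes for imfrombytes

lq_client = FileClient('lmdb', db_paths=['./datasets/GoPro/train/blur_crops.lmdb'],
                       client_keys=['lq'])
img_bytes = lq_client.get('some_image_key', client_key='lq')  # key = image name without suffix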
/basicsr/utils/flow_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
2 | import cv2
3 | import numpy as np
4 | import os
5 |
6 |
7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
8 | """Read an optical flow map.
9 |
10 | Args:
11 | flow_path (ndarray or str): Flow path.
12 | quantize (bool): whether to read quantized pair, if set to True,
13 | remaining args will be passed to :func:`dequantize_flow`.
14 | concat_axis (int): The axis that dx and dy are concatenated,
15 | can be either 0 or 1. Ignored if quantize is False.
16 |
17 | Returns:
18 | ndarray: Optical flow represented as a (h, w, 2) numpy array
19 | """
20 | if quantize:
21 | assert concat_axis in [0, 1]
22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
23 | if cat_flow.ndim != 2:
24 | raise IOError(f'{flow_path} is not a valid quantized flow file, '
25 | f'its dimension is {cat_flow.ndim}.')
26 | assert cat_flow.shape[concat_axis] % 2 == 0
27 | dx, dy = np.split(cat_flow, 2, axis=concat_axis)
28 | flow = dequantize_flow(dx, dy, *args, **kwargs)
29 | else:
30 | with open(flow_path, 'rb') as f:
31 | try:
32 | header = f.read(4).decode('utf-8')
33 | except Exception:
34 | raise IOError(f'Invalid flow file: {flow_path}')
35 | else:
36 | if header != 'PIEH':
37 | raise IOError(f'Invalid flow file: {flow_path}, '
38 | 'header does not contain PIEH')
39 |
40 | w = np.fromfile(f, np.int32, 1).squeeze()
41 | h = np.fromfile(f, np.int32, 1).squeeze()
42 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
43 |
44 | return flow.astype(np.float32)
45 |
46 |
47 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
48 | """Write optical flow to file.
49 |
50 |     If the flow is not quantized, it will be saved as a .flo file losslessly;
51 |     otherwise it is quantized and saved as a single lossy image of much
52 |     smaller size (dx and dy are concatenated along `concat_axis`).
53 |
54 | Args:
55 | flow (ndarray): (h, w, 2) array of optical flow.
56 | filename (str): Output filepath.
57 |         quantize (bool): Whether to quantize the flow and save it as an
58 |             image. If set to True, remaining args will be passed to
59 |             :func:`quantize_flow`.
60 | concat_axis (int): The axis that dx and dy are concatenated,
61 | can be either 0 or 1. Ignored if quantize is False.
62 | """
63 | if not quantize:
64 | with open(filename, 'wb') as f:
65 | f.write('PIEH'.encode('utf-8'))
66 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
67 | flow = flow.astype(np.float32)
68 | flow.tofile(f)
69 | f.flush()
70 | else:
71 | assert concat_axis in [0, 1]
72 | dx, dy = quantize_flow(flow, *args, **kwargs)
73 | dxdy = np.concatenate((dx, dy), axis=concat_axis)
74 |         os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
75 |         cv2.imwrite(filename, dxdy)
76 |
77 |
78 | def quantize_flow(flow, max_val=0.02, norm=True):
79 | """Quantize flow to [0, 255].
80 |
81 | After this step, the size of flow will be much smaller, and can be
82 | dumped as jpeg images.
83 |
84 | Args:
85 | flow (ndarray): (h, w, 2) array of optical flow.
86 | max_val (float): Maximum value of flow, values beyond
87 | [-max_val, max_val] will be truncated.
88 | norm (bool): Whether to divide flow values by image width/height.
89 |
90 | Returns:
91 | tuple[ndarray]: Quantized dx and dy.
92 | """
93 | h, w, _ = flow.shape
94 | dx = flow[..., 0]
95 | dy = flow[..., 1]
96 | if norm:
97 | dx = dx / w # avoid inplace operations
98 | dy = dy / h
99 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
100 | flow_comps = [
101 | quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
102 | ]
103 | return tuple(flow_comps)
104 |
105 |
106 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
107 | """Recover from quantized flow.
108 |
109 | Args:
110 | dx (ndarray): Quantized dx.
111 | dy (ndarray): Quantized dy.
112 | max_val (float): Maximum value used when quantizing.
113 | denorm (bool): Whether to multiply flow values with width/height.
114 |
115 | Returns:
116 | ndarray: Dequantized flow.
117 | """
118 | assert dx.shape == dy.shape
119 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
120 |
121 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
122 |
123 | if denorm:
124 | dx *= dx.shape[1]
125 |         dy *= dy.shape[0]
126 | flow = np.dstack((dx, dy))
127 | return flow
128 |
129 |
130 | def quantize(arr, min_val, max_val, levels, dtype=np.int64):
131 | """Quantize an array of (-inf, inf) to [0, levels-1].
132 |
133 | Args:
134 | arr (ndarray): Input array.
135 | min_val (scalar): Minimum value to be clipped.
136 | max_val (scalar): Maximum value to be clipped.
137 | levels (int): Quantization levels.
138 | dtype (np.type): The type of the quantized array.
139 |
140 | Returns:
141 |         ndarray: Quantized array.
142 | """
143 | if not (isinstance(levels, int) and levels > 1):
144 | raise ValueError(
145 |             f'levels must be an integer greater than 1, but got {levels}')
146 | if min_val >= max_val:
147 | raise ValueError(
148 | f'min_val ({min_val}) must be smaller than max_val ({max_val})')
149 |
150 | arr = np.clip(arr, min_val, max_val) - min_val
151 | quantized_arr = np.minimum(
152 | np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
153 |
154 | return quantized_arr
155 |
156 |
157 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
158 | """Dequantize an array.
159 |
160 | Args:
161 | arr (ndarray): Input array.
162 | min_val (scalar): Minimum value to be clipped.
163 | max_val (scalar): Maximum value to be clipped.
164 | levels (int): Quantization levels.
165 | dtype (np.type): The type of the dequantized array.
166 |
167 | Returns:
168 |         ndarray: Dequantized array.
169 | """
170 | if not (isinstance(levels, int) and levels > 1):
171 | raise ValueError(
172 |             f'levels must be an integer greater than 1, but got {levels}')
173 | if min_val >= max_val:
174 | raise ValueError(
175 | f'min_val ({min_val}) must be smaller than max_val ({max_val})')
176 |
177 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
178 | min_val) / levels + min_val
179 |
180 | return dequantized_arr
181 |
--------------------------------------------------------------------------------
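
A round-trip sketch for the quantization helpers above (not from the repo).
With 255 levels over [-max_val, max_val], the reconstruction error is at most
half a quantization step, i.e. 0.5 * 2 * max_val / 255; norm/denorm are
disabled here so the values round-trip directly:

    import numpy as np

    from basicsr.utils.flow_util import dequantize_flow, quantize_flow

    flow = np.random.uniform(-0.01, 0.01, size=(64, 64, 2)).astype(np.float32)
    dx, dy = quantize_flow(flow, max_val=0.02, norm=False)  # uint8, in [0, 254]
    restored = dequantize_flow(dx, dy, max_val=0.02, denorm=False)
    assert np.abs(restored - flow).max() <= 0.5 * 0.04 / 255 + 1e-7
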
/basicsr/utils/img_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | import os
5 | import torch
6 | from torchvision.utils import make_grid
7 |
8 |
9 | def img2tensor(imgs, bgr2rgb=True, float32=True):
10 | """Numpy array to tensor.
11 |
12 | Args:
13 | imgs (list[ndarray] | ndarray): Input images.
14 | bgr2rgb (bool): Whether to change bgr to rgb.
15 | float32 (bool): Whether to change to float32.
16 |
17 | Returns:
18 | list[tensor] | tensor: Tensor images. If returned results only have
19 | one element, just return tensor.
20 | """
21 |
22 | def _totensor(img, bgr2rgb, float32):
23 | if img.shape[2] == 3 and bgr2rgb:
24 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
25 | img = torch.from_numpy(img.transpose(2, 0, 1))
26 | if float32:
27 | img = img.float()
28 | return img
29 |
30 | if isinstance(imgs, list):
31 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
32 | else:
33 | return _totensor(imgs, bgr2rgb, float32)
34 |
35 |
36 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
37 | """Convert torch Tensors into image numpy arrays.
38 |
39 | After clamping to [min, max], values will be normalized to [0, 1].
40 |
41 | Args:
42 | tensor (Tensor or list[Tensor]): Accept shapes:
43 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
44 | 2) 3D Tensor of shape (3/1 x H x W);
45 | 3) 2D Tensor of shape (H x W).
46 | Tensor channel should be in RGB order.
47 | rgb2bgr (bool): Whether to change rgb to bgr.
48 | out_type (numpy type): output types. If ``np.uint8``, transform outputs
49 | to uint8 type with range [0, 255]; otherwise, float type with
50 | range [0, 1]. Default: ``np.uint8``.
51 | min_max (tuple[int]): min and max values for clamp.
52 |
53 | Returns:
54 |         (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
55 |             ndarray of shape (H x W). The channel order is BGR.
56 | """
57 | if not (torch.is_tensor(tensor) or
58 | (isinstance(tensor, list)
59 | and all(torch.is_tensor(t) for t in tensor))):
60 | raise TypeError(
61 | f'tensor or list of tensors expected, got {type(tensor)}')
62 |
63 | if torch.is_tensor(tensor):
64 | tensor = [tensor]
65 | result = []
66 | for _tensor in tensor:
67 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
68 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
69 |
70 | n_dim = _tensor.dim()
71 | if n_dim == 4:
72 | img_np = make_grid(
73 | _tensor, nrow=int(math.sqrt(_tensor.size(0))),
74 | normalize=False).numpy()
75 | img_np = img_np.transpose(1, 2, 0)
76 | if rgb2bgr:
77 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
78 | elif n_dim == 3:
79 | img_np = _tensor.numpy()
80 | img_np = img_np.transpose(1, 2, 0)
81 | if img_np.shape[2] == 1: # gray image
82 | img_np = np.squeeze(img_np, axis=2)
83 | else:
84 | if rgb2bgr:
85 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
86 | elif n_dim == 2:
87 | img_np = _tensor.numpy()
88 | else:
89 | raise TypeError('Only support 4D, 3D or 2D tensor. '
90 | f'But received with dimension: {n_dim}')
91 | if out_type == np.uint8:
92 |             # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
93 | img_np = (img_np * 255.0).round()
94 | img_np = img_np.astype(out_type)
95 | result.append(img_np)
96 | if len(result) == 1:
97 | result = result[0]
98 | return result
99 |
100 |
101 | def imfrombytes(content, flag='color', float32=False):
102 | """Read an image from bytes.
103 |
104 | Args:
105 | content (bytes): Image bytes got from files or other streams.
106 | flag (str): Flags specifying the color type of a loaded image,
107 | candidates are `color`, `grayscale` and `unchanged`.
108 |         float32 (bool): Whether to change to float32. If True, will also
109 |             normalize values to [0, 1]. Default: False.
110 |
111 | Returns:
112 | ndarray: Loaded image array.
113 | """
114 | img_np = np.frombuffer(content, np.uint8)
115 | imread_flags = {
116 | 'color': cv2.IMREAD_COLOR,
117 | 'grayscale': cv2.IMREAD_GRAYSCALE,
118 | 'unchanged': cv2.IMREAD_UNCHANGED
119 | }
120 |     img = cv2.imdecode(img_np, imread_flags[flag])
121 |     if img is None:
122 |         raise ValueError('Failed to decode image bytes.')
123 | if float32:
124 | img = img.astype(np.float32) / 255.
125 | return img
126 |
127 | def imfrombytesDP(content, flag='color', float32=False):
128 | """Read an image from bytes.
129 |
130 | Args:
131 | content (bytes): Image bytes got from files or other streams.
132 | flag (str): Flags specifying the color type of a loaded image,
133 | candidates are `color`, `grayscale` and `unchanged`.
134 |         float32 (bool): Whether to change to float32. If True, will also
135 |             normalize values to [0, 1]. Default: False.
136 |
137 | Returns:
138 | ndarray: Loaded image array.
139 | """
140 | img_np = np.frombuffer(content, np.uint8)
141 |     img = cv2.imdecode(img_np, cv2.IMREAD_UNCHANGED)
142 |     if img is None:
143 |         raise ValueError('Failed to decode image bytes.')
144 | if float32:
145 | img = img.astype(np.float32) / 65535.
146 | return img
147 |
148 | def padding(img_lq, img_gt, gt_size):
149 | h, w, _ = img_lq.shape
150 |
151 | h_pad = max(0, gt_size - h)
152 | w_pad = max(0, gt_size - w)
153 |
154 | if h_pad == 0 and w_pad == 0:
155 | return img_lq, img_gt
156 |
157 | img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
158 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
159 | # print('img_lq', img_lq.shape, img_gt.shape)
160 | if img_lq.ndim == 2:
161 | img_lq = np.expand_dims(img_lq, axis=2)
162 | if img_gt.ndim == 2:
163 | img_gt = np.expand_dims(img_gt, axis=2)
164 | return img_lq, img_gt
165 |
166 | def padding_DP(img_lqL, img_lqR, img_gt, gt_size):
167 | h, w, _ = img_gt.shape
168 |
169 | h_pad = max(0, gt_size - h)
170 | w_pad = max(0, gt_size - w)
171 |
172 | if h_pad == 0 and w_pad == 0:
173 | return img_lqL, img_lqR, img_gt
174 |
175 | img_lqL = cv2.copyMakeBorder(img_lqL, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
176 | img_lqR = cv2.copyMakeBorder(img_lqR, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
177 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
178 | # print('img_lq', img_lq.shape, img_gt.shape)
179 | return img_lqL, img_lqR, img_gt
180 |
181 | def imwrite(img, file_path, params=None, auto_mkdir=True):
182 | """Write image to file.
183 |
184 | Args:
185 | img (ndarray): Image array to be written.
186 | file_path (str): Image file path.
187 | params (None or list): Same as opencv's :func:`imwrite` interface.
188 | auto_mkdir (bool): If the parent folder of `file_path` does not exist,
189 | whether to create it automatically.
190 |
191 | Returns:
192 | bool: Successful or not.
193 | """
194 | if auto_mkdir:
195 | dir_name = os.path.abspath(os.path.dirname(file_path))
196 | os.makedirs(dir_name, exist_ok=True)
197 | return cv2.imwrite(file_path, img, params)
198 |
199 |
200 | def crop_border(imgs, crop_border):
201 | """Crop borders of images.
202 |
203 | Args:
204 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
205 |         crop_border (int): Crop border for each end of height and width.
206 |
207 | Returns:
208 | list[ndarray]: Cropped images.
209 | """
210 | if crop_border == 0:
211 | return imgs
212 | else:
213 | if isinstance(imgs, list):
214 | return [
215 | v[crop_border:-crop_border, crop_border:-crop_border, ...]
216 | for v in imgs
217 | ]
218 | else:
219 | return imgs[crop_border:-crop_border, crop_border:-crop_border,
220 | ...]
221 |
--------------------------------------------------------------------------------
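
A round-trip sketch for img2tensor/tensor2img (not from the repo): a uint8 BGR
image becomes a float32 RGB CHW tensor in [0, 1] and converts back losslessly:

    import numpy as np

    from basicsr.utils.img_util import img2tensor, tensor2img

    img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # HWC, BGR
    tensor = img2tensor(img.astype(np.float32) / 255., bgr2rgb=True)   # CHW, RGB
    restored = tensor2img(tensor, rgb2bgr=True, min_max=(0, 1))        # HWC, BGR
    assert np.array_equal(restored, img)
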
/basicsr/utils/lmdb_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import lmdb
3 | import sys
4 | from multiprocessing import Pool
5 | from os import path as osp
6 | from tqdm import tqdm
7 |
8 |
9 | def make_lmdb_from_imgs(data_path,
10 | lmdb_path,
11 | img_path_list,
12 | keys,
13 | batch=5000,
14 | compress_level=1,
15 | multiprocessing_read=False,
16 | n_thread=40,
17 | map_size=None):
18 | """Make lmdb from images.
19 |
20 | Contents of lmdb. The file structure is:
21 | example.lmdb
22 | ├── data.mdb
23 | ├── lock.mdb
24 | ├── meta_info.txt
25 |
26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to
27 | https://lmdb.readthedocs.io/en/release/ for more details.
28 |
29 | The meta_info.txt is a specified txt file to record the meta information
30 | of our datasets. It will be automatically created when preparing
31 | datasets by our provided dataset tools.
32 |     Each line in the txt file records 1) image name (with extension),
33 |     2) image shape, and 3) compression level, separated by a white space.
34 |
35 | For example, the meta information could be:
36 | `000_00000000.png (720,1280,3) 1`, which means:
37 | 1) image name (with extension): 000_00000000.png;
38 | 2) image shape: (720,1280,3);
39 | 3) compression level: 1
40 |
41 | We use the image name without extension as the lmdb key.
42 |
43 | If `multiprocessing_read` is True, it will read all the images to memory
44 | using multiprocessing. Thus, your server needs to have enough memory.
45 |
46 | Args:
47 | data_path (str): Data path for reading images.
48 | lmdb_path (str): Lmdb save path.
49 |         img_path_list (list[str]): Image path list.
50 |         keys (list[str]): Used for lmdb keys.
51 | batch (int): After processing batch images, lmdb commits.
52 | Default: 5000.
53 | compress_level (int): Compress level when encoding images. Default: 1.
54 |         multiprocessing_read (bool): Whether to use multiprocessing to read
55 |             all the images into memory. Default: False.
56 | n_thread (int): For multiprocessing.
57 | map_size (int | None): Map size for lmdb env. If None, use the
58 | estimated size from images. Default: None
59 | """
60 |
61 | assert len(img_path_list) == len(keys), (
62 | 'img_path_list and keys should have the same length, '
63 | f'but got {len(img_path_list)} and {len(keys)}')
64 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
65 |     print(f'Total images: {len(img_path_list)}')
66 | if not lmdb_path.endswith('.lmdb'):
67 | raise ValueError("lmdb_path must end with '.lmdb'.")
68 | if osp.exists(lmdb_path):
69 | print(f'Folder {lmdb_path} already exists. Exit.')
70 | sys.exit(1)
71 |
72 | if multiprocessing_read:
73 | # read all the images to memory (multiprocessing)
74 | dataset = {} # use dict to keep the order for multiprocessing
75 | shapes = {}
76 | print(f'Read images with multiprocessing, #thread: {n_thread} ...')
77 | pbar = tqdm(total=len(img_path_list), unit='image')
78 |
79 | def callback(arg):
80 | """get the image data and update pbar."""
81 | key, dataset[key], shapes[key] = arg
82 | pbar.update(1)
83 | pbar.set_description(f'Read {key}')
84 |
85 | pool = Pool(n_thread)
86 | for path, key in zip(img_path_list, keys):
87 | pool.apply_async(
88 | read_img_worker,
89 | args=(osp.join(data_path, path), key, compress_level),
90 | callback=callback)
91 | pool.close()
92 | pool.join()
93 | pbar.close()
94 | print(f'Finish reading {len(img_path_list)} images.')
95 |
96 | # create lmdb environment
97 | if map_size is None:
98 | # obtain data size for one image
99 | img = cv2.imread(
100 | osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
101 | _, img_byte = cv2.imencode(
102 | '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
103 | data_size_per_img = img_byte.nbytes
104 | print('Data size per image is: ', data_size_per_img)
105 | data_size = data_size_per_img * len(img_path_list)
106 | map_size = data_size * 10
107 |
108 | env = lmdb.open(lmdb_path, map_size=map_size)
109 |
110 | # write data to lmdb
111 | pbar = tqdm(total=len(img_path_list), unit='chunk')
112 | txn = env.begin(write=True)
113 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
114 | for idx, (path, key) in enumerate(zip(img_path_list, keys)):
115 | pbar.update(1)
116 | pbar.set_description(f'Write {key}')
117 | key_byte = key.encode('ascii')
118 | if multiprocessing_read:
119 | img_byte = dataset[key]
120 | h, w, c = shapes[key]
121 | else:
122 | _, img_byte, img_shape = read_img_worker(
123 | osp.join(data_path, path), key, compress_level)
124 | h, w, c = img_shape
125 |
126 | txn.put(key_byte, img_byte)
127 | # write meta information
128 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
129 | if idx % batch == 0:
130 | txn.commit()
131 | txn = env.begin(write=True)
132 | pbar.close()
133 | txn.commit()
134 | env.close()
135 | txt_file.close()
136 | print('\nFinish writing lmdb.')
137 |
138 |
139 | def read_img_worker(path, key, compress_level):
140 | """Read image worker.
141 |
142 | Args:
143 | path (str): Image path.
144 | key (str): Image key.
145 | compress_level (int): Compress level when encoding images.
146 |
147 | Returns:
148 | str: Image key.
149 | byte: Image byte.
150 | tuple[int]: Image shape.
151 | """
152 |
153 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
154 | if img.ndim == 2:
155 | h, w = img.shape
156 | c = 1
157 | else:
158 | h, w, c = img.shape
159 | _, img_byte = cv2.imencode('.png', img,
160 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
161 | return (key, img_byte, (h, w, c))
162 |
163 |
164 | class LmdbMaker():
165 | """LMDB Maker.
166 |
167 | Args:
168 | lmdb_path (str): Lmdb save path.
169 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
170 | batch (int): After processing batch images, lmdb commits.
171 | Default: 5000.
172 | compress_level (int): Compress level when encoding images. Default: 1.
173 | """
174 |
175 | def __init__(self,
176 | lmdb_path,
177 | map_size=1024**4,
178 | batch=5000,
179 | compress_level=1):
180 | if not lmdb_path.endswith('.lmdb'):
181 | raise ValueError("lmdb_path must end with '.lmdb'.")
182 | if osp.exists(lmdb_path):
183 | print(f'Folder {lmdb_path} already exists. Exit.')
184 | sys.exit(1)
185 |
186 | self.lmdb_path = lmdb_path
187 | self.batch = batch
188 | self.compress_level = compress_level
189 | self.env = lmdb.open(lmdb_path, map_size=map_size)
190 | self.txn = self.env.begin(write=True)
191 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
192 | self.counter = 0
193 |
194 | def put(self, img_byte, key, img_shape):
195 | self.counter += 1
196 | key_byte = key.encode('ascii')
197 | self.txn.put(key_byte, img_byte)
198 | # write meta information
199 | h, w, c = img_shape
200 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
201 | if self.counter % self.batch == 0:
202 | self.txn.commit()
203 | self.txn = self.env.begin(write=True)
204 |
205 | def close(self):
206 | self.txn.commit()
207 | self.env.close()
208 | self.txt_file.close()
209 |
--------------------------------------------------------------------------------
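
A usage sketch for LmdbMaker (not from the repo; the lmdb path and keys are
hypothetical). Note that both make_lmdb_from_imgs and LmdbMaker refuse to
overwrite an existing .lmdb folder:

    import cv2
    import numpy as np

    from basicsr.utils.lmdb_util import LmdbMaker

    maker = LmdbMaker('datasets/example.lmdb')  # path must end with '.lmdb'
    for i in range(3):
        img = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
        _, img_byte = cv2.imencode('.png', img,
                                   [cv2.IMWRITE_PNG_COMPRESSION, 1])
        maker.put(img_byte, f'{i:03d}', img.shape)
    maker.close()  # commits the last transaction and closes meta_info.txt
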
/basicsr/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 | from .dist_util import get_dist_info, master_only
6 |
7 | initialized_logger = {}
8 |
9 |
10 | class MessageLogger():
11 | """Message logger for printing.
12 |
13 | Args:
14 | opt (dict): Config. It contains the following keys:
15 | name (str): Exp name.
16 |             logger (dict): Contains 'print_freq' (int) for logger interval.
17 | train (dict): Contains 'total_iter' (int) for total iters.
18 | use_tb_logger (bool): Use tensorboard logger.
19 | start_iter (int): Start iter. Default: 1.
20 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
21 | """
22 |
23 | def __init__(self, opt, start_iter=1, tb_logger=None):
24 | self.exp_name = opt['name']
25 | self.interval = opt['logger']['print_freq']
26 | self.start_iter = start_iter
27 | self.max_iters = opt['train']['total_iter']
28 | self.use_tb_logger = opt['logger']['use_tb_logger']
29 | self.tb_logger = tb_logger
30 | self.start_time = time.time()
31 | self.logger = get_root_logger()
32 |
33 | @master_only
34 | def __call__(self, log_vars):
35 | """Format logging message.
36 |
37 | Args:
38 | log_vars (dict): It contains the following keys:
39 | epoch (int): Epoch number.
40 | iter (int): Current iter.
41 | lrs (list): List for learning rates.
42 |
43 | time (float): Iter time.
44 | data_time (float): Data time for each iter.
45 | """
46 | # epoch, iter, learning rates
47 | epoch = log_vars.pop('epoch')
48 | current_iter = log_vars.pop('iter')
49 | lrs = log_vars.pop('lrs')
50 |
51 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
52 | for v in lrs:
53 | message += f'{v:.3e},'
54 | message += ')] '
55 |
56 | # time and estimated time
57 | if 'time' in log_vars.keys():
58 | iter_time = log_vars.pop('time')
59 | data_time = log_vars.pop('data_time')
60 |
61 | total_time = time.time() - self.start_time
62 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
63 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
64 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
65 | message += f'[eta: {eta_str}, '
66 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
67 |
68 | # other items, especially losses
69 | for k, v in log_vars.items():
70 | message += f'{k}: {v:.4e} '
71 | # tensorboard logger
72 | if self.use_tb_logger and 'debug' not in self.exp_name:
73 | if k.startswith('l_'):
74 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
75 | else:
76 | self.tb_logger.add_scalar(k, v, current_iter)
77 | self.logger.info(message)
78 |
79 |
80 | @master_only
81 | def init_tb_logger(log_dir):
82 | from torch.utils.tensorboard import SummaryWriter
83 | tb_logger = SummaryWriter(log_dir=log_dir)
84 | return tb_logger
85 |
86 |
87 | @master_only
88 | def init_wandb_logger(opt):
89 | """We now only use wandb to sync tensorboard log."""
90 | import wandb
91 | logger = logging.getLogger('basicsr')
92 |
93 | project = opt['logger']['wandb']['project']
94 | resume_id = opt['logger']['wandb'].get('resume_id')
95 | if resume_id:
96 | wandb_id = resume_id
97 | resume = 'allow'
98 | logger.warning(f'Resume wandb logger with id={wandb_id}.')
99 | else:
100 | wandb_id = wandb.util.generate_id()
101 | resume = 'never'
102 |
103 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
104 |
105 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
106 |
107 |
108 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
109 | """Get the root logger.
110 |
111 | The logger will be initialized if it has not been initialized. By default a
112 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
113 | also be added.
114 |
115 | Args:
116 | logger_name (str): root logger name. Default: 'basicsr'.
117 | log_file (str | None): The log filename. If specified, a FileHandler
118 | will be added to the root logger.
119 | log_level (int): The root logger level. Note that only the process of
120 | rank 0 is affected, while other processes will set the level to
121 | "Error" and be silent most of the time.
122 |
123 | Returns:
124 | logging.Logger: The root logger.
125 | """
126 | logger = logging.getLogger(logger_name)
127 | # if the logger has been initialized, just return it
128 | if logger_name in initialized_logger:
129 | return logger
130 |
131 | format_str = '%(asctime)s %(levelname)s: %(message)s'
132 | stream_handler = logging.StreamHandler()
133 | stream_handler.setFormatter(logging.Formatter(format_str))
134 | logger.addHandler(stream_handler)
135 | logger.propagate = False
136 | rank, _ = get_dist_info()
137 | if rank != 0:
138 | logger.setLevel('ERROR')
139 | elif log_file is not None:
140 | logger.setLevel(log_level)
141 | # add file handler
142 | file_handler = logging.FileHandler(log_file, 'w')
143 | file_handler.setFormatter(logging.Formatter(format_str))
144 | file_handler.setLevel(log_level)
145 | logger.addHandler(file_handler)
146 | initialized_logger[logger_name] = True
147 | return logger
148 |
149 |
150 | def get_env_info():
151 | """Get environment information.
152 |
153 | Currently, only log the software version.
154 | """
155 | import torch
156 | import torchvision
157 |
158 | from basicsr.version import __version__
159 | msg = r"""
160 | ____ _ _____ ____
161 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \
162 | / __ |/ __ `// ___// // ___/\__ \ / /_/ /
163 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
164 | /_____/ \__,_//____//_/ \___//____//_/ |_|
165 | ______ __ __ __ __
166 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
167 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
168 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
169 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
170 | """
171 | msg += ('\nVersion Information: '
172 | f'\n\tBasicSR: {__version__}'
173 | f'\n\tPyTorch: {torch.__version__}'
174 | f'\n\tTorchVision: {torchvision.__version__}')
175 | return msg
--------------------------------------------------------------------------------
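
A usage sketch for the logging helpers (not from the repo; the log file path
is hypothetical). get_root_logger configures the 'basicsr' logger once and
returns the same instance on later calls, so handlers are not duplicated:

    import logging

    from basicsr.utils.logger import get_env_info, get_root_logger

    logger = get_root_logger(log_level=logging.INFO, log_file='example.log')
    logger.info(get_env_info())
    assert get_root_logger() is logger  # cached after first initialization
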
/basicsr/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import time
5 | import torch
6 | from os import path as osp
7 |
8 | from .dist_util import master_only
9 | from .logger import get_root_logger
10 |
11 |
12 | def set_random_seed(seed):
13 | """Set random seeds."""
14 | random.seed(seed)
15 | np.random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 | torch.cuda.manual_seed_all(seed)
19 |
20 |
21 | def get_time_str():
22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime())
23 |
24 |
25 | def mkdir_and_rename(path):
26 | """mkdirs. If path exists, rename it with timestamp and create a new one.
27 |
28 | Args:
29 | path (str): Folder path.
30 | """
31 | if osp.exists(path):
32 | new_name = path + '_archived_' + get_time_str()
33 | print(f'Path already exists. Rename it to {new_name}', flush=True)
34 | os.rename(path, new_name)
35 | os.makedirs(path, exist_ok=True)
36 |
37 |
38 | @master_only
39 | def make_exp_dirs(opt):
40 | """Make dirs for experiments."""
41 | path_opt = opt['path'].copy()
42 | if opt['is_train']:
43 | mkdir_and_rename(path_opt.pop('experiments_root'))
44 | else:
45 | mkdir_and_rename(path_opt.pop('results_root'))
46 | for key, path in path_opt.items():
47 | if ('strict_load' not in key) and ('pretrain_network'
48 | not in key) and ('resume'
49 | not in key):
50 | os.makedirs(path, exist_ok=True)
51 |
52 |
53 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
54 | """Scan a directory to find the interested files.
55 |
56 | Args:
57 | dir_path (str): Path of the directory.
58 | suffix (str | tuple(str), optional): File suffix that we are
59 | interested in. Default: None.
60 | recursive (bool, optional): If set to True, recursively scan the
61 | directory. Default: False.
62 | full_path (bool, optional): If set to True, include the dir_path.
63 | Default: False.
64 |
65 | Returns:
66 |         A generator for all the files of interest, with relative paths.
67 | """
68 |
69 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
70 | raise TypeError('"suffix" must be a string or tuple of strings')
71 |
72 | root = dir_path
73 |
74 | def _scandir(dir_path, suffix, recursive):
75 | for entry in os.scandir(dir_path):
76 | if not entry.name.startswith('.') and entry.is_file():
77 | if full_path:
78 | return_path = entry.path
79 | else:
80 | return_path = osp.relpath(entry.path, root)
81 |
82 | if suffix is None:
83 | yield return_path
84 | elif return_path.endswith(suffix):
85 | yield return_path
86 | else:
87 | if recursive:
88 | yield from _scandir(
89 | entry.path, suffix=suffix, recursive=recursive)
90 | else:
91 | continue
92 |
93 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
94 |
95 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
96 | """Scan a directory to find the interested files.
97 |
98 | Args:
99 | dir_path (str): Path of the directory.
100 | keywords (str | tuple(str), optional): File keywords that we are
101 | interested in. Default: None.
102 | recursive (bool, optional): If set to True, recursively scan the
103 | directory. Default: False.
104 | full_path (bool, optional): If set to True, include the dir_path.
105 | Default: False.
106 |
107 | Returns:
108 |         A generator for all the files of interest, with relative paths.
109 | """
110 |
111 | if (keywords is not None) and not isinstance(keywords, (str, tuple)):
112 | raise TypeError('"keywords" must be a string or tuple of strings')
113 |
114 | root = dir_path
115 |
116 | def _scandir(dir_path, keywords, recursive):
117 | for entry in os.scandir(dir_path):
118 | if not entry.name.startswith('.') and entry.is_file():
119 | if full_path:
120 | return_path = entry.path
121 | else:
122 | return_path = osp.relpath(entry.path, root)
123 |
124 | if keywords is None:
125 | yield return_path
126 |                 elif return_path.find(keywords) >= 0:
127 | yield return_path
128 | else:
129 | if recursive:
130 | yield from _scandir(
131 | entry.path, keywords=keywords, recursive=recursive)
132 | else:
133 | continue
134 |
135 | return _scandir(dir_path, keywords=keywords, recursive=recursive)
136 |
137 | def check_resume(opt, resume_iter):
138 | """Check resume states and pretrain_network paths.
139 |
140 | Args:
141 | opt (dict): Options.
142 | resume_iter (int): Resume iteration.
143 | """
144 | logger = get_root_logger()
145 | if opt['path']['resume_state']:
146 | # get all the networks
147 | networks = [key for key in opt.keys() if key.startswith('network_')]
148 | flag_pretrain = False
149 | for network in networks:
150 | if opt['path'].get(f'pretrain_{network}') is not None:
151 | flag_pretrain = True
152 | if flag_pretrain:
153 | logger.warning(
154 | 'pretrain_network path will be ignored during resuming.')
155 | # set pretrained model paths
156 | for network in networks:
157 | name = f'pretrain_{network}'
158 | basename = network.replace('network_', '')
159 | if opt['path'].get('ignore_resume_networks') is None or (
160 | basename not in opt['path']['ignore_resume_networks']):
161 | opt['path'][name] = osp.join(
162 | opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
163 | logger.info(f"Set {name} to {opt['path'][name]}")
164 |
165 |
166 | def sizeof_fmt(size, suffix='B'):
167 | """Get human readable file size.
168 |
169 | Args:
170 | size (int): File size.
171 | suffix (str): Suffix. Default: 'B'.
172 |
173 | Return:
174 |         str: Formatted file size.
175 | """
176 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
177 | if abs(size) < 1024.0:
178 | return f'{size:3.1f} {unit}{suffix}'
179 | size /= 1024.0
180 | return f'{size:3.1f} Y{suffix}'
181 |
--------------------------------------------------------------------------------
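
A usage sketch for scandir (not from the repo; the directory is hypothetical):

    from basicsr.utils.misc import scandir

    # Recursively collect all .png files; yielded paths are relative to the
    # given directory unless full_path=True is passed.
    png_paths = sorted(scandir('datasets/example', suffix='.png', recursive=True))
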
/basicsr/utils/options.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from collections import OrderedDict
3 | from os import path as osp
4 |
5 |
6 | def ordered_yaml():
7 | """Support OrderedDict for yaml.
8 |
9 | Returns:
10 | yaml Loader and Dumper.
11 | """
12 | try:
13 | from yaml import CDumper as Dumper
14 | from yaml import CLoader as Loader
15 | except ImportError:
16 | from yaml import Dumper, Loader
17 |
18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
19 |
20 | def dict_representer(dumper, data):
21 | return dumper.represent_dict(data.items())
22 |
23 | def dict_constructor(loader, node):
24 | return OrderedDict(loader.construct_pairs(node))
25 |
26 | Dumper.add_representer(OrderedDict, dict_representer)
27 | Loader.add_constructor(_mapping_tag, dict_constructor)
28 | return Loader, Dumper
29 |
30 |
31 | def parse(opt_path, is_train=True):
32 | """Parse option file.
33 |
34 | Args:
35 | opt_path (str): Option file path.
36 |         is_train (bool): Whether in training mode. Default: True.
37 |
38 | Returns:
39 | (dict): Options.
40 | """
41 | with open(opt_path, mode='r') as f:
42 | Loader, _ = ordered_yaml()
43 | opt = yaml.load(f, Loader=Loader)
44 |
45 | opt['is_train'] = is_train
46 |
47 | # datasets
48 | for phase, dataset in opt['datasets'].items():
49 | # for several datasets, e.g., test_1, test_2
50 | phase = phase.split('_')[0]
51 | dataset['phase'] = phase
52 | if 'scale' in opt:
53 | dataset['scale'] = opt['scale']
54 | if dataset.get('dataroot_gt') is not None:
55 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
56 | if dataset.get('dataroot_lq') is not None:
57 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
58 |
59 | # paths
60 | for key, val in opt['path'].items():
61 | if (val is not None) and ('resume_state' in key
62 | or 'pretrain_network' in key):
63 | opt['path'][key] = osp.expanduser(val)
64 | opt['path']['root'] = osp.abspath(
65 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
66 | if is_train:
67 | experiments_root = osp.join(opt['path']['root'], 'experiments',
68 | opt['name'])
69 | opt['path']['experiments_root'] = experiments_root
70 | opt['path']['models'] = osp.join(experiments_root, 'models')
71 | opt['path']['training_states'] = osp.join(experiments_root,
72 | 'training_states')
73 | opt['path']['log'] = experiments_root
74 | opt['path']['visualization'] = osp.join(experiments_root,
75 | 'visualization')
76 |
77 | # change some options for debug mode
78 | if 'debug' in opt['name']:
79 | if 'val' in opt:
80 | opt['val']['val_freq'] = 8
81 | opt['logger']['print_freq'] = 1
82 | opt['logger']['save_checkpoint_freq'] = 8
83 | else: # test
84 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
85 | opt['path']['results_root'] = results_root
86 | opt['path']['log'] = results_root
87 | opt['path']['visualization'] = osp.join(results_root, 'visualization')
88 |
89 | return opt
90 |
91 |
92 | def dict2str(opt, indent_level=1):
93 | """dict to string for printing options.
94 |
95 | Args:
96 | opt (dict): Option dict.
97 | indent_level (int): Indent level. Default: 1.
98 |
99 | Return:
100 | (str): Option string for printing.
101 | """
102 | msg = '\n'
103 | for k, v in opt.items():
104 | if isinstance(v, dict):
105 | msg += ' ' * (indent_level * 2) + k + ':['
106 | msg += dict2str(v, indent_level + 1)
107 | msg += ' ' * (indent_level * 2) + ']\n'
108 | else:
109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
110 | return msg
111 |
--------------------------------------------------------------------------------
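
A usage sketch for the option parser (not from the repo), loading one of the
provided config files; ordered_yaml preserves the key order from the file and
parse fills in derived paths such as results_root:

    from basicsr.utils.options import dict2str, parse

    opt = parse('Motion_Deblurring/Options/Deblurring_FPro.yml', is_train=False)
    print(dict2str(opt))
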
/basicsr/version.py:
--------------------------------------------------------------------------------
1 | # GENERATED VERSION FILE
2 | # TIME: Sun Jan 28 22:05:08 2024
3 | __version__ = '1.2.0+733ceb2'
4 | short_version = '1.2.0'
5 | version_info = (1, 2, 0)
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import find_packages, setup
4 |
5 | import os
6 | import subprocess
7 | import sys
8 | import time
9 | import torch
10 | from torch.utils.cpp_extension import (BuildExtension, CppExtension,
11 | CUDAExtension)
12 |
13 | version_file = 'basicsr/version.py'
14 |
15 |
16 | def readme():
17 | return ''
18 | # with open('README.md', encoding='utf-8') as f:
19 | # content = f.read()
20 | # return content
21 |
22 |
23 | def get_git_hash():
24 |
25 | def _minimal_ext_cmd(cmd):
26 | # construct minimal environment
27 | env = {}
28 | for k in ['SYSTEMROOT', 'PATH', 'HOME']:
29 | v = os.environ.get(k)
30 | if v is not None:
31 | env[k] = v
32 | # LANGUAGE is used on win32
33 | env['LANGUAGE'] = 'C'
34 | env['LANG'] = 'C'
35 | env['LC_ALL'] = 'C'
36 | out = subprocess.Popen(
37 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
38 | return out
39 |
40 | try:
41 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
42 | sha = out.strip().decode('ascii')
43 | except OSError:
44 | sha = 'unknown'
45 |
46 | return sha
47 |
48 |
49 | def get_hash():
50 | if os.path.exists('.git'):
51 | sha = get_git_hash()[:7]
52 | elif os.path.exists(version_file):
53 | try:
54 | from basicsr.version import __version__
55 | sha = __version__.split('+')[-1]
56 | except ImportError:
57 | raise ImportError('Unable to get git version')
58 | else:
59 | sha = 'unknown'
60 |
61 | return sha
62 |
63 |
64 | def write_version_py():
65 | content = """# GENERATED VERSION FILE
66 | # TIME: {}
67 | __version__ = '{}'
68 | short_version = '{}'
69 | version_info = ({})
70 | """
71 | sha = get_hash()
72 | with open('VERSION', 'r') as f:
73 | SHORT_VERSION = f.read().strip()
74 | VERSION_INFO = ', '.join(
75 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
76 | VERSION = SHORT_VERSION + '+' + sha
77 |
78 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION,
79 | VERSION_INFO)
80 | with open(version_file, 'w') as f:
81 | f.write(version_file_str)
82 |
83 |
84 | def get_version():
85 | with open(version_file, 'r') as f:
86 | exec(compile(f.read(), version_file, 'exec'))
87 | return locals()['__version__']
88 |
89 |
90 | def make_cuda_ext(name, module, sources, sources_cuda=None):
91 | if sources_cuda is None:
92 | sources_cuda = []
93 | define_macros = []
94 | extra_compile_args = {'cxx': []}
95 |
96 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
97 | define_macros += [('WITH_CUDA', None)]
98 | extension = CUDAExtension
99 | extra_compile_args['nvcc'] = [
100 | '-D__CUDA_NO_HALF_OPERATORS__',
101 | '-D__CUDA_NO_HALF_CONVERSIONS__',
102 | '-D__CUDA_NO_HALF2_OPERATORS__',
103 | ]
104 | sources += sources_cuda
105 | else:
106 | print(f'Compiling {name} without CUDA')
107 | extension = CppExtension
108 |
109 | return extension(
110 | name=f'{module}.{name}',
111 | sources=[os.path.join(*module.split('.'), p) for p in sources],
112 | define_macros=define_macros,
113 | extra_compile_args=extra_compile_args)
114 |
115 |
116 | def get_requirements(filename='requirements.txt'):
117 | return []
118 | here = os.path.dirname(os.path.realpath(__file__))
119 | with open(os.path.join(here, filename), 'r') as f:
120 | requires = [line.replace('\n', '') for line in f.readlines()]
121 | return requires
122 |
123 |
124 | if __name__ == '__main__':
125 | if '--no_cuda_ext' in sys.argv:
126 | ext_modules = []
127 | sys.argv.remove('--no_cuda_ext')
128 | else:
129 | ext_modules = [
130 | make_cuda_ext(
131 | name='deform_conv_ext',
132 | module='basicsr.models.ops.dcn',
133 | sources=['src/deform_conv_ext.cpp'],
134 | sources_cuda=[
135 | 'src/deform_conv_cuda.cpp',
136 | 'src/deform_conv_cuda_kernel.cu'
137 | ]),
138 | make_cuda_ext(
139 | name='fused_act_ext',
140 | module='basicsr.models.ops.fused_act',
141 | sources=['src/fused_bias_act.cpp'],
142 | sources_cuda=['src/fused_bias_act_kernel.cu']),
143 | make_cuda_ext(
144 | name='upfirdn2d_ext',
145 | module='basicsr.models.ops.upfirdn2d',
146 | sources=['src/upfirdn2d.cpp'],
147 | sources_cuda=['src/upfirdn2d_kernel.cu']),
148 | ]
149 |
150 | write_version_py()
151 | setup(
152 | name='basicsr',
153 | version=get_version(),
154 | description='Open Source Image and Video Super-Resolution Toolbox',
155 | long_description=readme(),
156 | author='Xintao Wang',
157 | author_email='xintao.wang@outlook.com',
158 | keywords='computer vision, restoration, super resolution',
159 | url='https://github.com/xinntao/BasicSR',
160 | packages=find_packages(
161 | exclude=('options', 'datasets', 'experiments', 'results',
162 | 'tb_logger', 'wandb')),
163 | classifiers=[
164 | 'Development Status :: 4 - Beta',
165 | 'License :: OSI Approved :: Apache Software License',
166 | 'Operating System :: OS Independent',
167 | 'Programming Language :: Python :: 3',
168 | 'Programming Language :: Python :: 3.7',
169 | 'Programming Language :: Python :: 3.8',
170 | ],
171 | license='Apache License 2.0',
172 | setup_requires=['cython', 'numpy'],
173 | install_requires=get_requirements(),
174 | ext_modules=ext_modules,
175 | cmdclass={'build_ext': BuildExtension},
176 | zip_safe=False)
177 |
--------------------------------------------------------------------------------
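
To skip building the CUDA extensions (e.g. on a CPU-only machine), install
with the flag handled above: `python setup.py develop --no_cuda_ext`.
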
/test.sh:
--------------------------------------------------------------------------------
1 |
2 | #dehaze
3 | # python test_SOTS.py
4 |
5 | #derain
6 | # python test_spad.py
7 |
8 | #deraindrop
9 | # python test_AGAN.py
10 |
11 | #deblur
12 | # python test_FPro.py
13 |
14 | #demoire
15 | # python test_moire.py
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 |
5 | python -m torch.distributed.launch --nproc_per_node=2 --master_port=4321 basicsr/train.py -opt "$CONFIG" --launcher pytorch
6 |
--------------------------------------------------------------------------------
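
Example invocation with one of the provided configs (two GPUs, as hard-coded
above): `bash train.sh Motion_Deblurring/Options/Deblurring_FPro.yml`.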