├── .DS_Store
├── Dehaze
│   ├── Options
│   │   └── RealDehazing_HINT.yml
│   ├── evaluate_SOTS.py
│   ├── test_SOTS_HINT.py
│   └── utils.py
├── Denoising
│   ├── Options
│   │   └── GaussianColorDenoising_HINT.yml
│   ├── evaluate_gaussian_color_denoising_HINT.py
│   ├── test_gaussian_color_denoising_HINT.py
│   └── utils.py
├── Deraining
│   ├── Options
│   │   └── Deraining_HINT_syn_rain100L.yml
│   ├── evaluate_PSNR_SSIM.m
│   ├── test_rain100L.py
│   └── utils.py
├── Desnowing
│   ├── Options
│   │   └── Desnow_snow100k_HINT.yml
│   ├── evaluate_Snow100k.py
│   ├── test_snow100k.py
│   └── utils.py
├── Enhancement
│   ├── Options
│   │   ├── HINT_LOL_v2_real.yml
│   │   └── HINT_LOL_v2_synthetic.yml
│   ├── test_from_dataset_LOLv2_Real.py
│   ├── test_from_dataset_LOLv2_Syn.py
│   └── utils.py
├── README.md
├── VERSION
├── basicsr
│   ├── .DS_Store
│   ├── __pycache__
│   │   ├── version.cpython-37.pyc
│   │   └── version.cpython-38.pyc
│   ├── data
│   │   ├── SDSD_image_dataset.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── SDSD_image_dataset.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── data_sampler.cpython-37.pyc
│   │   │   ├── data_util.cpython-37.pyc
│   │   │   ├── ffhq_dataset.cpython-37.pyc
│   │   │   ├── paired_image_dataset.cpython-37.pyc
│   │   │   ├── prefetch_dataloader.cpython-37.pyc
│   │   │   ├── reds_dataset.cpython-37.pyc
│   │   │   ├── single_image_dataset.cpython-37.pyc
│   │   │   ├── transforms.cpython-37.pyc
│   │   │   ├── util.cpython-37.pyc
│   │   │   ├── video_test_dataset.cpython-37.pyc
│   │   │   └── vimeo90k_dataset.cpython-37.pyc
│   │   ├── data_sampler.py
│   │   ├── data_util.py
│   │   ├── ffhq_dataset.py
│   │   ├── meta_info
│   │   │   ├── meta_info_DIV2K800sub_GT.txt
│   │   │   ├── meta_info_REDS4_test_GT.txt
│   │   │   ├── meta_info_REDS_GT.txt
│   │   │   ├── meta_info_REDSofficial4_test_GT.txt
│   │   │   ├── meta_info_REDSval_official_test_GT.txt
│   │   │   ├── meta_info_Vimeo90K_test_GT.txt
│   │   │   ├── meta_info_Vimeo90K_test_fast_GT.txt
│   │   │   ├── meta_info_Vimeo90K_test_medium_GT.txt
│   │   │   ├── meta_info_Vimeo90K_test_slow_GT.txt
│   │   │   └── meta_info_Vimeo90K_train_GT.txt
│   │   ├── paired_image_dataset.py
│   │   ├── prefetch_dataloader.py
│   │   ├── reds_dataset.py
│   │   ├── single_image_dataset.py
│   │   ├── transforms.py
│   │   ├── util.py
│   │   ├── video_test_dataset.py
│   │   └── vimeo90k_dataset.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── metric_util.cpython-37.pyc
│   │   │   ├── metric_util.cpython-38.pyc
│   │   │   ├── niqe.cpython-37.pyc
│   │   │   ├── niqe.cpython-38.pyc
│   │   │   ├── psnr_ssim.cpython-37.pyc
│   │   │   └── psnr_ssim.cpython-38.pyc
│   │   ├── fid.py
│   │   ├── metric_util.py
│   │   ├── niqe.py
│   │   ├── niqe_pris_params.npz
│   │   └── psnr_ssim.py
│   ├── models
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── base_model.cpython-37.pyc
│   │   │   ├── base_model.cpython-38.pyc
│   │   │   ├── image_restoration_model.cpython-37.pyc
│   │   │   ├── image_restoration_model.cpython-38.pyc
│   │   │   ├── lr_scheduler.cpython-37.pyc
│   │   │   └── lr_scheduler.cpython-38.pyc
│   │   ├── archs
│   │   │   ├── HINT_arch.py
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── FPro_arch.cpython-37.pyc
│   │   │   │   ├── HINT_arch.cpython-37.pyc
│   │   │   │   ├── HINT_arch.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── arch_util.cpython-37.pyc
│   │   │   │   ├── graph_layers.cpython-37.pyc
│   │   │   │   ├── local_arch.cpython-37.pyc
│   │   │   │   ├── restormer_arch.cpython-37.pyc
│   │   │   │   ├── restormer_arch.py
│   │   │   │   └── restormer_local_arch.cpython-37.pyc
│   │   │   └── arch_util.py
│   │   ├── base_model.py
│   │   ├── image_restoration_model.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── loss_util.cpython-37.pyc
│   │   │   │   ├── loss_util.cpython-38.pyc
│   │   │   │   ├── losses.cpython-37.pyc
│   │   │   │   └── losses.cpython-38.pyc
│   │   │   ├── loss_util.py
│   │   │   └── losses.py
│   │   └── lr_scheduler.py
│   ├── test.py
│   ├── train.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── create_lmdb.cpython-37.pyc
│   │   │   ├── create_lmdb.cpython-38.pyc
│   │   │   ├── dist_util.cpython-37.pyc
│   │   │   ├── dist_util.cpython-38.pyc
│   │   │   ├── file_client.cpython-37.pyc
│   │   │   ├── file_client.cpython-38.pyc
│   │   │   ├── flow_util.cpython-37.pyc
│   │   │   ├── img_util.cpython-37.pyc
│   │   │   ├── img_util.cpython-38.pyc
│   │   │   ├── lmdb_util.cpython-37.pyc
│   │   │   ├── lmdb_util.cpython-38.pyc
│   │   │   ├── logger.cpython-37.pyc
│   │   │   ├── logger.cpython-38.pyc
│   │   │   ├── matlab_functions.cpython-37.pyc
│   │   │   ├── matlab_functions.cpython-38.pyc
│   │   │   ├── misc.cpython-37.pyc
│   │   │   ├── misc.cpython-38.pyc
│   │   │   ├── options.cpython-37.pyc
│   │   │   └── options.cpython-38.pyc
│   │   ├── bundle_submissions.py
│   │   ├── create_lmdb.py
│   │   ├── dist_util.py
│   │   ├── download_util.py
│   │   ├── face_util.py
│   │   ├── file_client.py
│   │   ├── flow_util.py
│   │   ├── img_util.py
│   │   ├── lmdb_util.py
│   │   ├── logger.py
│   │   ├── matlab_functions.py
│   │   ├── misc.py
│   │   └── options.py
│   └── version.py
├── environment.yml
├── setup.py
├── test.sh
└── train.sh
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/.DS_Store
--------------------------------------------------------------------------------
/Dehaze/Options/RealDehazing_HINT.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Dehazing_HINT
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage_dehazeSOT
13 | dataroot_gt: ./dataset/haze
14 | dataroot_lq: ./dataset/haze
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ## ------- Progressive training on two patch sizes (128, 256) ---------
27 | mini_batch_sizes: [6,1]
28 | iters: [200000,100000]
29 | gt_size: 256
30 | gt_sizes: [128,256]
31 | ## ------------------------------------------------------------
32 |
33 | dataset_enlarge_ratio: 1
34 | prefetch_mode: ~
35 |
36 | val:
37 | name: ValSet
38 | type: Dataset_PairedImage_dehazeSOT
39 | dataroot_gt: ./dataset/haze
40 | dataroot_lq: ./dataset/haze
41 | gt_size: 256
42 | io_backend:
43 | type: disk
44 |
45 |
46 |
47 | # network structures
48 | network_g:
49 | type: HINT
50 | inp_channels: 3
51 | out_channels: 3
52 | dim: 48
53 | num_blocks: [4,6,6,8]
54 | num_refinement_blocks: 4
55 | heads: [8,8,8,8]
56 | ffn_expansion_factor: 2.66
57 | bias: False
58 | LayerNorm_type: WithBias
59 | dual_pixel_task: False
60 |
61 |
62 | # path
63 | path:
64 | pretrain_network_g: ~
65 | strict_load_g: true
66 | resume_state: ~
67 |
68 | # training settings
69 | train:
70 | total_iter: 300000
71 | warmup_iter: -1 # no warm up
72 | use_grad_clip: true
73 |
74 | # Split 300k iterations into two cycles.
75 | # 1st cycle: fixed 3e-4 LR for 92k iters.
76 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
77 | scheduler:
78 | type: CosineAnnealingRestartCyclicLR
79 | periods: [92000, 208000]
80 | restart_weights: [1,1]
81 | eta_mins: [0.0003,0.000001]
82 |
83 | mixing_augs:
84 | mixup: true
85 | mixup_beta: 1.2
86 | use_identity: true
87 |
88 | optim_g:
89 | type: AdamW
90 | lr: !!float 3e-4
91 | weight_decay: !!float 1e-4
92 | betas: [0.9, 0.999]
93 |
94 | # losses
95 | pixel_opt:
96 | type: L1Loss
97 | loss_weight: 1
98 | reduction: mean
99 | fft_loss_opt:
100 | type: FFTLoss
101 | loss_weight: 0.1
102 | reduction: mean
103 |
104 | # validation settings
105 | val:
106 | window_size: 8
107 | val_freq: !!float 4e3
108 | save_img: false
109 | rgb2bgr: true
110 | use_image: false
111 | max_minibatch: 8
112 |
113 | metrics:
114 | psnr: # metric name, can be arbitrary
115 | type: calculate_psnr
116 | crop_border: 0
117 | test_y_channel: false
118 |
119 | # logging settings
120 | logger:
121 | print_freq: 1000
122 | save_checkpoint_freq: !!float 4e3
123 | use_tb_logger: true
124 | wandb:
125 | project: ~
126 | resume_id: ~
127 |
128 | # dist training settings
129 | dist_params:
130 | backend: nccl
131 | port: 29500
132 |
--------------------------------------------------------------------------------
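
A note on the scheduler block above: with periods [92000, 208000] and eta_mins [0.0003, 0.000001], the first cycle anneals from 3e-4 to 3e-4 (i.e. it is effectively constant) and the second decays to 1e-6, which is what the inline comment describes. Below is a minimal sketch of the implied LR curve, assuming basicsr-style cosine-restart semantics; it is an illustration of the config, not the library's implementation.

# Hedged sketch of the LR curve implied by the scheduler block above.
# Assumed semantics: within each period, LR anneals on a cosine from
# (restart_weight * base_lr) down to that period's eta_min.
import math

def lr_at(iteration, base_lr=3e-4,
          periods=(92_000, 208_000),
          restart_weights=(1.0, 1.0),
          eta_mins=(3e-4, 1e-6)):
    start = 0
    for period, weight, eta_min in zip(periods, restart_weights, eta_mins):
        if iteration < start + period:
            t = (iteration - start) / period  # progress within the cycle, in [0, 1)
            top = weight * base_lr
            return eta_min + 0.5 * (top - eta_min) * (1 + math.cos(math.pi * t))
        start += period
    return eta_mins[-1]

# First cycle is flat at 3e-4 (top == eta_min); the second decays
# from 3e-4 toward 1e-6 over the remaining 208k iterations.
for it in (0, 91_999, 92_000, 196_000, 299_999):
    print(it, f"{lr_at(it):.2e}")
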
/Dehaze/evaluate_SOTS.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from glob import glob
4 | from natsort import natsorted
5 | from skimage import io
6 | import cv2
7 | import argparse
8 | from skimage.metrics import structural_similarity
9 | from tqdm import tqdm
10 | import concurrent.futures
11 | import utils
12 |
13 | def proc(filename):
14 | tar,prd = filename
15 | prd_name = prd.split('/')[-1].split('_')[0]+'.png'
16 | tar_name = './dataset/haze/promptIR/outdoor/gt/' + prd_name
17 | tar_img = utils.load_img(tar_name)
18 | prd_img = utils.load_img(prd)
19 |
20 | PSNR = utils.calculate_psnr(tar_img, prd_img)
21 | SSIM = utils.calculate_ssim(tar_img, prd_img)
22 | return PSNR,SSIM
23 |
24 | parser = argparse.ArgumentParser(description='Dehazing using HINT')
25 |
26 | args = parser.parse_args()
27 |
28 |
29 | datasets = ['outdoor']
30 |
31 | for dataset in datasets:
32 |
33 | gt_path = os.path.join('./dataset/haze/promptIR/outdoor/gt')
34 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif')))
35 | assert len(gt_list) != 0, "Target files not found"
36 |
37 |
38 | file_path = os.path.join('results', 'HINT', dataset)
39 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif')))
40 | assert len(path_list) != 0, "Predicted files not found"
41 |
42 | psnr, ssim = [], []
43 | img_files =[(i, j) for i,j in zip(gt_list,path_list)]
44 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
45 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
46 | psnr.append(PSNR_SSIM[0])
47 | ssim.append(PSNR_SSIM[1])
48 |
49 | avg_psnr = sum(psnr)/len(psnr)
50 | avg_ssim = sum(ssim)/len(ssim)
51 |
52 | # print('For {:s} dataset PSNR: {:f}\n'.format(dataset, avg_psnr))
53 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
54 |
--------------------------------------------------------------------------------
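
A note on proc() above: it ignores its tar argument and rebuilds the ground-truth path from the prediction filename, so the zip with gt_list only fixes the number of evaluated pairs. A small sketch of the name mapping, assuming SOTS-style prediction names such as 0001_0.9_0.2.png whose ground truth is 0001.png (the exact naming is dataset-dependent; gt_dir below mirrors the hard-coded path in the script):

# Sketch of the prediction -> ground-truth name mapping used by proc().
import os

def gt_name_for(pred_path, gt_dir='./dataset/haze/promptIR/outdoor/gt'):
    # '0001_0.9_0.2.png' -> '0001' -> '<gt_dir>/0001.png'
    stem = os.path.basename(pred_path).split('_')[0]
    return os.path.join(gt_dir, stem + '.png')

print(gt_name_for('results/HINT/outdoor/0001_0.9_0.2.png'))

Using os.path.basename instead of prd.split('/')[-1] would also keep the mapping working with Windows path separators.
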
/Dehaze/test_SOTS_HINT.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch.nn as nn
7 | import torch
8 | import torch.nn.functional as F
9 | import utils
10 |
11 | from natsort import natsorted
12 | from glob import glob
13 | from basicsr.models.archs.HINT_arch import HINT
14 | from skimage import img_as_ubyte
15 | from pdb import set_trace as stx
16 |
17 | parser = argparse.ArgumentParser(description='Image Dehazing using HINT')
18 |
19 | parser.add_argument('--input_dir', default='./dataset/haze/promptIR/', type=str, help='Directory of validation images')
20 | parser.add_argument('--result_dir', default='./results/HINT/', type=str, help='Directory for results')
21 | parser.add_argument('--weights', default='./models/Dehazing.pth', type=str, help='Path to weights')
22 |
23 | args = parser.parse_args()
24 |
25 | def splitimage(imgtensor, crop_size=128, overlap_size=64):
26 | _, C, H, W = imgtensor.shape
27 | hstarts = [x for x in range(0, H, crop_size - overlap_size)]
28 | while hstarts and hstarts[-1] + crop_size >= H:
29 | hstarts.pop()
30 | hstarts.append(H - crop_size)
31 | wstarts = [x for x in range(0, W, crop_size - overlap_size)]
32 | while wstarts and wstarts[-1] + crop_size >= W:
33 | wstarts.pop()
34 | wstarts.append(W - crop_size)
35 | starts = []
36 | split_data = []
37 | for hs in hstarts:
38 | for ws in wstarts:
39 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
40 | starts.append((hs, ws))
41 | split_data.append(cimgdata)
42 | return split_data, starts
43 |
44 | def get_scoremap(H, W, C, B=1, is_mean=True):
45 | center_h = H / 2
46 | center_w = W / 2
47 |
48 | score = torch.ones((B, C, H, W))
49 | if not is_mean:
50 | for h in range(H):
51 | for w in range(W):
52 | score[:, :, h, w] = 1.0 / (((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6) ** 0.5)  # use ** 0.5: math is never imported in this script
53 | return score
54 |
55 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)):
56 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
57 | tot_score = torch.zeros((B, C, H, W))
58 | merge_img = torch.zeros((B, C, H, W))
59 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
60 | for simg, cstart in zip(split_data, starts):
61 | hs, ws = cstart
62 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
63 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
64 | merge_img = merge_img / tot_score
65 | return merge_img
66 |
67 | ####### Load yaml #######
68 | yaml_file = 'Options/RealDehazing_HINT.yml'
69 | import yaml
70 |
71 | try:
72 | from yaml import CLoader as Loader
73 | except ImportError:
74 | from yaml import Loader
75 |
76 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
77 |
78 | s = x['network_g'].pop('type')
79 | ##########################
80 |
81 | model_restoration = HINT(**x['network_g'])
82 |
83 | checkpoint = torch.load(args.weights)
84 | model_restoration.load_state_dict(checkpoint['params'])
85 | print("===>Testing using weights: ",args.weights)
86 | model_restoration.cuda()
87 | model_restoration = nn.DataParallel(model_restoration)
88 | model_restoration.eval()
89 |
90 |
91 | factor = 8
92 | datasets = ['outdoor']
93 |
94 | for dataset in datasets:
95 | result_dir = os.path.join(args.result_dir, dataset)
96 | os.makedirs(result_dir, exist_ok=True)
97 |
98 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain')
99 | inp_dir = os.path.join(args.input_dir, dataset, 'hazy/')
100 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
101 | with torch.no_grad():
102 | for file_ in tqdm(files):
103 | torch.cuda.ipc_collect()
104 | torch.cuda.empty_cache()
105 |
106 | img = np.float32(utils.load_img(file_))/255.
107 | img = torch.from_numpy(img).permute(2,0,1)
108 | input_ = img.unsqueeze(0).cuda()
109 |
110 | # Padding in case images are not multiples of 8
111 | h,w = input_.shape[2], input_.shape[3]
112 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
113 | padh = H-h if h%factor!=0 else 0
114 | padw = W-w if w%factor!=0 else 0
115 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
116 |
117 | restored = model_restoration(input_)
118 |
119 | restored = restored[:,:,:h,:w]
120 |
121 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
122 |
123 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
124 |
--------------------------------------------------------------------------------
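
The pad/unpad pattern in the loop above (reflect-pad H and W up to the next multiple of `factor`, run the model, crop the output back) is the standard way to satisfy an encoder-decoder's downsampling constraint. A self-contained sketch with an equivalent, slightly more direct formulation of the padding arithmetic; the identity assignment stands in for model_restoration:

# Sketch: reflect-pad to the next multiple of `factor`, then unpad.
import torch
import torch.nn.functional as F

factor = 8
x = torch.rand(1, 3, 37, 50)              # toy input; H, W not multiples of 8
h, w = x.shape[2], x.shape[3]
padh = (factor - h % factor) % factor     # 3 -> padded H = 40
padw = (factor - w % factor) % factor     # 6 -> padded W = 56
x_pad = F.pad(x, (0, padw, 0, padh), 'reflect')
assert x_pad.shape[2] % factor == 0 and x_pad.shape[3] % factor == 0

y = x_pad                                 # stand-in for model_restoration(x_pad)
y = y[:, :, :h, :w]                       # crop back to the original resolution
assert y.shape == x.shape
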
/Dehaze/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
10 | def calculate_psnr(img1, img2, border=0):
11 | # img1 and img2 have range [0, 255]
12 | #img1 = img1.squeeze()
13 | #img2 = img2.squeeze()
14 | if not img1.shape == img2.shape:
15 | raise ValueError('Input images must have the same dimensions.')
16 | h, w = img1.shape[:2]
17 | img1 = img1[border:h-border, border:w-border]
18 | img2 = img2[border:h-border, border:w-border]
19 |
20 | img1 = img1.astype(np.float64)
21 | img2 = img2.astype(np.float64)
22 | mse = np.mean((img1 - img2)**2)
23 | if mse == 0:
24 | return float('inf')
25 | return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
31 | def calculate_ssim(img1, img2, border=0):
32 | '''calculate SSIM
33 | the same outputs as MATLAB's
34 | img1, img2: [0, 255]
35 | '''
36 | #img1 = img1.squeeze()
37 | #img2 = img2.squeeze()
38 | if not img1.shape == img2.shape:
39 | raise ValueError('Input images must have the same dimensions.')
40 | h, w = img1.shape[:2]
41 | img1 = img1[border:h-border, border:w-border]
42 | img2 = img2[border:h-border, border:w-border]
43 |
44 | if img1.ndim == 2:
45 | return ssim(img1, img2)
46 | elif img1.ndim == 3:
47 | if img1.shape[2] == 3:
48 | ssims = []
49 | for i in range(3):
50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
51 | return np.array(ssims).mean()
52 | elif img1.shape[2] == 1:
53 | return ssim(np.squeeze(img1), np.squeeze(img2))
54 | else:
55 | raise ValueError('Wrong input image dimensions.')
56 |
57 |
58 | def ssim(img1, img2):
59 | C1 = (0.01 * 255)**2
60 | C2 = (0.03 * 255)**2
61 |
62 | img1 = img1.astype(np.float64)
63 | img2 = img2.astype(np.float64)
64 | kernel = cv2.getGaussianKernel(11, 1.5)
65 | window = np.outer(kernel, kernel.transpose())
66 |
67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
69 | mu1_sq = mu1**2
70 | mu2_sq = mu2**2
71 | mu1_mu2 = mu1 * mu2
72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 |
76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 | (sigma1_sq + sigma2_sq + C2))
78 | return ssim_map.mean()
79 |
80 | def load_img(filepath):
81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 |
83 | def save_img(filepath, img):
84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 |
86 | def load_gray_img(filepath):
87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 |
89 | def save_gray_img(filepath, img):
90 | cv2.imwrite(filepath, img)
91 |
--------------------------------------------------------------------------------
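
calculate_psnr above implements PSNR = 20 * log10(255 / sqrt(MSE)) for images in the [0, 255] range. A quick numeric check: a uniform error of one gray level gives MSE = 1, hence roughly 48.13 dB.

# Sanity check for calculate_psnr: uniform error of 1 gray level.
import numpy as np

a = np.zeros((16, 16), dtype=np.float64)
b = a + 1.0
mse = np.mean((a - b) ** 2)                 # 1.0
print(20 * np.log10(255.0 / np.sqrt(mse)))  # ~48.1308 dB
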
/Denoising/Options/GaussianColorDenoising_HINT.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: GaussianColorDenoising_HINT
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_GaussianDenoising
13 | sigma_type: random
14 | sigma_range: [0,50]
15 | in_ch: 3 ## RGB image
16 | dataroot_gt: ./Denoising/Datasets/train/WB
17 | dataroot_lq: none
18 | geometric_augs: true
19 |
20 | filename_tmpl: '{}'
21 | io_backend:
22 | type: disk
23 |
24 | # data loader
25 | use_shuffle: true
26 | num_worker_per_gpu: 8
27 | batch_size_per_gpu: 8
28 |
29 | # -------------Progressive training--------------------------
30 | mini_batch_sizes: [6,4,3,1,1,1] # Batch size per gpu
31 | iters: [92000,64000,48000,36000,36000,24000]
32 | gt_size: 256 # Max patch size for progressive training
33 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
34 | ### ------------------------------------------------------------
35 |
36 | dataset_enlarge_ratio: 1
37 | prefetch_mode: ~
38 |
39 | val:
40 | name: ValSet
41 | type: Dataset_GaussianDenoising
42 | sigma_test: 25
43 | in_ch: 3 ## RGB image
44 | dataroot_gt: ./Denoising/Datasets/test/CBSD68
45 | dataroot_lq: none
46 | gt_size: 256
47 | io_backend:
48 | type: disk
49 |
50 | # network structures
51 | network_g:
52 | type: HINT
53 | inp_channels: 3
54 | out_channels: 3
55 | dim: 48
56 | num_blocks: [4,6,6,8]
57 | num_refinement_blocks: 4
58 | heads: [8,8,8,8]
59 | ffn_expansion_factor: 2.66
60 | bias: False
61 | LayerNorm_type: WithBias
62 | dual_pixel_task: False
63 | # path
64 | path:
65 | pretrain_network_g: ~
66 | strict_load_g: true
67 | resume_state: ~
68 |
69 | # training settings
70 | train:
71 | total_iter: 300000
72 | warmup_iter: -1 # no warm up
73 | use_grad_clip: true
74 |
75 | # Split 300k iterations into two cycles.
76 | # 1st cycle: fixed 3e-4 LR for 92k iters.
77 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
78 | scheduler:
79 | type: CosineAnnealingRestartCyclicLR
80 | periods: [92000, 208000]
81 | restart_weights: [1,1]
82 | eta_mins: [0.0003,0.000001]
83 |
84 | mixing_augs:
85 | mixup: true
86 | mixup_beta: 1.2
87 | use_identity: true
88 |
89 | optim_g:
90 | type: AdamW
91 | lr: !!float 3e-4
92 | weight_decay: !!float 1e-4
93 | betas: [0.9, 0.999]
94 |
95 | # losses
96 | pixel_opt:
97 | type: L1Loss
98 | loss_weight: 1
99 | reduction: mean
100 | fft_loss_opt:
101 | type: FFTLoss
102 | loss_weight: 0.1
103 | reduction: mean
104 | # validation settings
105 | val:
106 | window_size: 8
107 | val_freq: !!float 4e3
108 | save_img: false
109 | rgb2bgr: true
110 | use_image: false
111 | max_minibatch: 8
112 |
113 | metrics:
114 | psnr: # metric name, can be arbitrary
115 | type: calculate_psnr
116 | crop_border: 0
117 | test_y_channel: false
118 |
119 | # logging settings
120 | logger:
121 | print_freq: 1000
122 | save_checkpoint_freq: !!float 4e3
123 | use_tb_logger: true
124 | wandb:
125 | project: ~
126 | resume_id: ~
127 |
128 | # dist training settings
129 | dist_params:
130 | backend: nccl
131 | port: 29500
132 |
--------------------------------------------------------------------------------
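
The progressive-training block above pairs its three lists element-wise: stage k runs for iters[k] iterations at patch size gt_sizes[k] with per-GPU batch mini_batch_sizes[k], and the stage boundaries are the cumulative sums of iters (92k + 64k + 48k + 36k + 36k + 24k = 300k, matching total_iter). A sketch of the stage lookup under that assumed pairing:

# Sketch of how the progressive-training lists pair up by stage.
from itertools import accumulate

mini_batch_sizes = [6, 4, 3, 1, 1, 1]
iters = [92_000, 64_000, 48_000, 36_000, 36_000, 24_000]
gt_sizes = [128, 160, 192, 256, 320, 384]

boundaries = list(accumulate(iters))        # [92000, 156000, ..., 300000]

def stage_for(iteration):
    for k, end in enumerate(boundaries):
        if iteration < end:
            return k
    return len(boundaries) - 1

k = stage_for(200_000)
print(k, gt_sizes[k], mini_batch_sizes[k])  # 2 192 3
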
/Denoising/evaluate_gaussian_color_denoising_HINT.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from glob import glob
4 | from natsort import natsorted
5 | from skimage import io
6 | import cv2
7 | import argparse
8 | from skimage.metrics import structural_similarity
9 | from tqdm import tqdm
10 | import concurrent.futures
11 | import utils
12 |
13 | def proc(filename):
14 | tar,prd = filename
15 | tar_img = utils.load_img(tar)
16 | prd_img = utils.load_img(prd)
17 |
18 | PSNR = utils.calculate_psnr(tar_img, prd_img)
19 | SSIM = utils.calculate_ssim(tar_img, prd_img)
20 | return PSNR,SSIM
21 |
22 | parser = argparse.ArgumentParser(description='Gaussian Color Denoising using HINT')
23 |
24 | parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. non_blind: separate model for each noise level.')
25 | parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values')
26 |
27 | args = parser.parse_args()
28 |
29 | sigmas = np.int_(args.sigmas.split(','))
30 |
31 | datasets = ['CBSD68','Urban100']
32 |
33 | for dataset in datasets:
34 |
35 | gt_path = os.path.join('./Denoising/Datasets','test', dataset)
36 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.tif')))
37 | assert len(gt_list) != 0, "Target files not found"
38 |
39 | for sigma_test in sigmas:
40 | file_path = os.path.join('results', 'Gaussian_Color_Denoising', args.model_type, dataset, str(sigma_test))
41 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.tif')))
42 | assert len(path_list) != 0, "Predicted files not found"
43 |
44 | psnr, ssim = [], []
45 | img_files =[(i, j) for i,j in zip(gt_list,path_list)]
46 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
47 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
48 | psnr.append(PSNR_SSIM[0])
49 | ssim.append(PSNR_SSIM[1])
50 |
51 | avg_psnr = sum(psnr)/len(psnr)
52 | avg_ssim = sum(ssim)/len(ssim)
53 |
54 | print('For {:s} dataset Noise Level {:d} PSNR: {:f} SSIM: {:f}\n'.format(dataset, sigma_test, avg_psnr, avg_ssim))
55 |
56 |
--------------------------------------------------------------------------------
/Denoising/test_gaussian_color_denoising_HINT.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch.nn as nn
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from basicsr.models.archs.HINT_arch import HINT
11 | from skimage import img_as_ubyte
12 | from natsort import natsorted
13 | from glob import glob
14 | import utils
15 | from pdb import set_trace as stx
16 |
17 | parser = argparse.ArgumentParser(description='Gaussian Color Denoising using HINT')
18 |
19 | parser.add_argument('--input_dir', default='./Denoising/Datasets/test/', type=str, help='Directory of validation images')
20 | parser.add_argument('--result_dir', default='./results/Gaussian_Color_Denoising/', type=str, help='Directory for results')
21 | parser.add_argument('--weights', default='./models/net_g_latest', type=str, help='Path to weights')
22 | parser.add_argument('--model_type', required=True, choices=['non_blind','blind'], type=str, help='blind: single model to handle various noise levels. non_blind: separate model for each noise level.')
23 | parser.add_argument('--sigmas', default='15,25,50', type=str, help='Sigma values')
24 |
25 | args = parser.parse_args()
26 |
27 | ####### Load yaml #######
28 | if args.model_type == 'blind':
29 | yaml_file = 'Options/GaussianColorDenoising_HINT.yml'
30 | else:
31 | yaml_file = f'Options/GaussianColorDenoising_RestormerSigma{args.sigmas}.yml'
32 | import yaml
33 |
34 | try:
35 | from yaml import CLoader as Loader
36 | except ImportError:
37 | from yaml import Loader
38 |
39 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
40 |
41 | s = x['network_g'].pop('type')
42 | ##########################
43 |
44 | sigmas = np.int_(args.sigmas.split(','))
45 |
46 | factor = 8
47 |
48 | datasets = ['CBSD68','Urban100']
49 |
50 | for sigma_test in sigmas:
51 | print("Compute results for noise level",sigma_test)
52 | model_restoration = HINT(**x['network_g'])
53 | if args.model_type == 'blind':
54 | weights = args.weights+'_blind.pth'
55 | else:
56 | weights = args.weights + '_sigma' + str(sigma_test) +'.pth'
57 | checkpoint = torch.load(weights)
58 | model_restoration.load_state_dict(checkpoint['params'])
59 |
60 | print("===>Testing using weights: ",weights)
61 | print("------------------------------------------------")
62 | model_restoration.cuda()
63 | model_restoration = nn.DataParallel(model_restoration)
64 | model_restoration.eval()
65 |
66 | for dataset in datasets:
67 | inp_dir = os.path.join(args.input_dir, dataset)
68 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.tif')))
69 | result_dir_tmp = os.path.join(args.result_dir, args.model_type, dataset, str(sigma_test))
70 | os.makedirs(result_dir_tmp, exist_ok=True)
71 |
72 | with torch.no_grad():
73 | for file_ in tqdm(files):
74 | torch.cuda.ipc_collect()
75 | torch.cuda.empty_cache()
76 | img = np.float32(utils.load_img(file_))/255.
77 |
78 | np.random.seed(seed=0) # for reproducibility
79 | img += np.random.normal(0, sigma_test/255., img.shape)
80 |
81 | img = torch.from_numpy(img).permute(2,0,1)
82 | input_ = img.unsqueeze(0).cuda()
83 |
84 | # Padding in case images are not multiples of 8
85 | h,w = input_.shape[2], input_.shape[3]
86 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
87 | padh = H-h if h%factor!=0 else 0
88 | padw = W-w if w%factor!=0 else 0
89 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
90 |
91 | restored = model_restoration(input_)
92 |
93 | # Unpad images to original dimensions
94 | restored = restored[:,:,:h,:w]
95 |
96 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
97 |
98 | save_file = os.path.join(result_dir_tmp, os.path.split(file_)[-1])
99 | utils.save_img(save_file, img_as_ubyte(restored))
100 |
--------------------------------------------------------------------------------
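
The test script above synthesizes its noisy inputs on the fly: additive white Gaussian noise with standard deviation sigma_test/255 is added to a [0, 1]-range image, and np.random.seed(0) is re-applied for every image so each method is evaluated on identical noisy inputs (images of the same shape therefore receive the same noise realization). A minimal sketch of that synthesis step:

# Sketch of the reproducible noise synthesis used in the test loop.
import numpy as np

sigma_test = 25
img = np.full((8, 8, 3), 0.5, dtype=np.float32)  # stand-in clean image in [0, 1]

np.random.seed(seed=0)                    # fixed seed -> identical noise per run
noisy = img + np.random.normal(0, sigma_test / 255., img.shape)
print(noisy.std())                        # roughly sigma_test / 255 ~= 0.098
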
/Denoising/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
10 | def calculate_psnr(img1, img2, border=0):
11 | # img1 and img2 have range [0, 255]
12 | #img1 = img1.squeeze()
13 | #img2 = img2.squeeze()
14 | if not img1.shape == img2.shape:
15 | raise ValueError('Input images must have the same dimensions.')
16 | h, w = img1.shape[:2]
17 | img1 = img1[border:h-border, border:w-border]
18 | img2 = img2[border:h-border, border:w-border]
19 |
20 | img1 = img1.astype(np.float64)
21 | img2 = img2.astype(np.float64)
22 | mse = np.mean((img1 - img2)**2)
23 | if mse == 0:
24 | return float('inf')
25 | return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
31 | def calculate_ssim(img1, img2, border=0):
32 | '''calculate SSIM
33 | the same outputs as MATLAB's
34 | img1, img2: [0, 255]
35 | '''
36 | #img1 = img1.squeeze()
37 | #img2 = img2.squeeze()
38 | if not img1.shape == img2.shape:
39 | raise ValueError('Input images must have the same dimensions.')
40 | h, w = img1.shape[:2]
41 | img1 = img1[border:h-border, border:w-border]
42 | img2 = img2[border:h-border, border:w-border]
43 |
44 | if img1.ndim == 2:
45 | return ssim(img1, img2)
46 | elif img1.ndim == 3:
47 | if img1.shape[2] == 3:
48 | ssims = []
49 | for i in range(3):
50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
51 | return np.array(ssims).mean()
52 | elif img1.shape[2] == 1:
53 | return ssim(np.squeeze(img1), np.squeeze(img2))
54 | else:
55 | raise ValueError('Wrong input image dimensions.')
56 |
57 |
58 | def ssim(img1, img2):
59 | C1 = (0.01 * 255)**2
60 | C2 = (0.03 * 255)**2
61 |
62 | img1 = img1.astype(np.float64)
63 | img2 = img2.astype(np.float64)
64 | kernel = cv2.getGaussianKernel(11, 1.5)
65 | window = np.outer(kernel, kernel.transpose())
66 |
67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
69 | mu1_sq = mu1**2
70 | mu2_sq = mu2**2
71 | mu1_mu2 = mu1 * mu2
72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 |
76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 | (sigma1_sq + sigma2_sq + C2))
78 | return ssim_map.mean()
79 |
80 | def load_img(filepath):
81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 |
83 | def save_img(filepath, img):
84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 |
86 | def load_gray_img(filepath):
87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 |
89 | def save_gray_img(filepath, img):
90 | cv2.imwrite(filepath, img)
91 |
--------------------------------------------------------------------------------
/Deraining/Options/Deraining_HINT_syn_rain100L.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Deraining_HINT_rain100L
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 4 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage
13 | dataroot_gt: ./dataset/Rain100L/train/clean
14 | dataroot_lq: ./dataset/Rain100L/train/rainy
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ### -------------Progressive training--------------------------
27 | mini_batch_sizes: [6,4,3,1] # Batch size per gpu
28 | iters: [92000,64000,48000,96000]
29 | gt_size: 384 # Max patch size for progressive training
30 | gt_sizes: [128,160,192,256] # Patch sizes for progressive training.
31 | ### ------------------------------------------------------------
32 |
33 | ### ------- Training on single fixed-patch size 128x128---------
34 | # mini_batch_sizes: [8]
35 | # iters: [300000]
36 | # gt_size: 128
37 | # gt_sizes: [128]
38 | ### ------------------------------------------------------------
39 |
40 | dataset_enlarge_ratio: 1
41 | prefetch_mode: ~
42 |
43 | val:
44 | name: ValSet
45 | type: Dataset_PairedImage
46 | dataroot_gt: ./dataset/Rain100L/test/clean
47 | dataroot_lq: ./dataset/Rain100L/test/rainy
48 | io_backend:
49 | type: disk
50 |
51 | # network structures
52 | network_g:
53 | type: HINT
54 | inp_channels: 3
55 | out_channels: 3
56 | dim: 48
57 | num_blocks: [4,6,6,8]
58 | num_refinement_blocks: 4
59 | heads: [8,8,8,8]
60 | ffn_expansion_factor: 2.66
61 | bias: False
62 | LayerNorm_type: WithBias
63 | dual_pixel_task: false
64 |
65 |
66 | # path
67 | path:
68 | pretrain_network_g: ~
69 | strict_load_g: true
70 | resume_state: ~
71 |
72 | # training settings
73 | train:
74 | total_iter: 300000
75 | warmup_iter: -1 # no warm up
76 | use_grad_clip: true
77 |
78 | # Split 300k iterations into two cycles.
79 | # 1st cycle: fixed 3e-4 LR for 92k iters.
80 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
81 | scheduler:
82 | type: CosineAnnealingRestartCyclicLR
83 | periods: [92000, 208000]
84 | restart_weights: [1,1]
85 | eta_mins: [0.0003,0.000001]
86 |
87 | mixing_augs:
88 | mixup: false
89 | mixup_beta: 1.2
90 | use_identity: true
91 |
92 | optim_g:
93 | type: AdamW
94 | lr: !!float 3e-4
95 | weight_decay: !!float 1e-4
96 | betas: [0.9, 0.999]
97 |
98 | # losses
99 | pixel_opt:
100 | type: L1Loss
101 | loss_weight: 1
102 | reduction: mean
103 | fft_loss_opt:
104 | type: FFTLoss
105 | loss_weight: 0.1
106 | reduction: mean
107 |
108 |
109 | # validation settings
110 | val:
111 | window_size: 8
112 | val_freq: !!float 4e3
113 | save_img: false
114 | rgb2bgr: true
115 | use_image: true
116 | max_minibatch: 8
117 |
118 | metrics:
119 | psnr: # metric name, can be arbitrary
120 | type: calculate_psnr
121 | crop_border: 0
122 | test_y_channel: true
123 |
124 | # logging settings
125 | logger:
126 | print_freq: 1000
127 | save_checkpoint_freq: !!float 4e3
128 | use_tb_logger: true
129 | wandb:
130 | project: ~
131 | resume_id: ~
132 |
133 | # dist training settings
134 | dist_params:
135 | backend: nccl
136 | port: 29500
137 |
--------------------------------------------------------------------------------
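
Note test_y_channel: true in the validation metrics above: deraining benchmarks conventionally report PSNR/SSIM on the luma (Y) channel of YCbCr rather than on RGB. A sketch of the conversion, assuming the MATLAB rgb2ycbcr (BT.601) convention commonly used for these benchmarks:

# Sketch: RGB in [0, 1] -> MATLAB-convention luma channel Y in [16, 235].
import numpy as np

def rgb_to_y(img):
    return (65.481 * img[..., 0] + 128.553 * img[..., 1]
            + 24.966 * img[..., 2] + 16.0)

rgb = np.random.rand(4, 4, 3)
print(rgb_to_y(rgb).shape)                # (4, 4): metrics run on this channel
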
/Deraining/test_rain100L.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 |
6 |
7 | import numpy as np
8 | import os
9 | import argparse
10 | from tqdm import tqdm
11 |
12 | import torch.nn as nn
13 | import torch
14 | import torch.nn.functional as F
15 | import utils
16 |
17 | from natsort import natsorted
18 | from glob import glob
19 | from basicsr.models.archs.HINT_arch import HINT
20 | from skimage import img_as_ubyte
21 | from pdb import set_trace as stx
22 |
23 | parser = argparse.ArgumentParser(description='Image Deraining using HINT')
24 |
25 | parser.add_argument('--input_dir', default='./dataset', type=str, help='Directory of validation images')
26 | parser.add_argument('--result_dir', default='./results/Rain100L_HINT/', type=str, help='Directory for results')
27 | parser.add_argument('--weights', default='./models/Rain100L_HINT.pth', type=str, help='Path to weights')
28 |
29 | args = parser.parse_args()
30 |
31 | ####### Load yaml #######
32 | yaml_file = 'Options/Deraining_HINT_syn_rain100L.yml'
33 | import yaml
34 |
35 | try:
36 | from yaml import CLoader as Loader
37 | except ImportError:
38 | from yaml import Loader
39 |
40 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
41 |
42 | s = x['network_g'].pop('type')
43 | ##########################
44 |
45 | model_restoration = HINT(**x['network_g'])
46 |
47 | checkpoint = torch.load(args.weights)
48 | model_restoration.load_state_dict(checkpoint['params'])
49 | print("===>Testing using weights: ",args.weights)
50 | model_restoration.cuda()
51 | model_restoration = nn.DataParallel(model_restoration)
52 | model_restoration.eval()
53 |
54 |
55 | factor = 8
56 | datasets = ['Rain100L']
57 |
58 | for dataset in datasets:
59 | result_dir = os.path.join(args.result_dir, dataset)
60 | os.makedirs(result_dir, exist_ok=True)
61 |
62 | inp_dir = os.path.join(args.input_dir, dataset,'test','rainy')
63 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
64 | with torch.no_grad():
65 | for file_ in tqdm(files):
66 | torch.cuda.ipc_collect()
67 | torch.cuda.empty_cache()
68 |
69 | img = np.float32(utils.load_img(file_))/255.
70 | img = torch.from_numpy(img).permute(2,0,1)
71 | input_ = img.unsqueeze(0).cuda()
72 |
73 | # Padding in case images are not multiples of 8
74 | h,w = input_.shape[2], input_.shape[3]
75 | H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
76 | padh = H-h if h%factor!=0 else 0
77 | padw = W-w if w%factor!=0 else 0
78 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
79 |
80 | restored = model_restoration(input_)
81 |
82 | # Unpad images to original dimensions
83 | restored = restored[:,:,:h,:w]
84 |
85 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
86 |
87 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
88 |
--------------------------------------------------------------------------------
/Deraining/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
10 | def calculate_psnr(img1, img2, border=0):
11 | # img1 and img2 have range [0, 255]
12 | #img1 = img1.squeeze()
13 | #img2 = img2.squeeze()
14 | if not img1.shape == img2.shape:
15 | raise ValueError('Input images must have the same dimensions.')
16 | h, w = img1.shape[:2]
17 | img1 = img1[border:h-border, border:w-border]
18 | img2 = img2[border:h-border, border:w-border]
19 |
20 | img1 = img1.astype(np.float64)
21 | img2 = img2.astype(np.float64)
22 | mse = np.mean((img1 - img2)**2)
23 | if mse == 0:
24 | return float('inf')
25 | return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
31 | def calculate_ssim(img1, img2, border=0):
32 | '''calculate SSIM
33 | the same outputs as MATLAB's
34 | img1, img2: [0, 255]
35 | '''
36 | #img1 = img1.squeeze()
37 | #img2 = img2.squeeze()
38 | if not img1.shape == img2.shape:
39 | raise ValueError('Input images must have the same dimensions.')
40 | h, w = img1.shape[:2]
41 | img1 = img1[border:h-border, border:w-border]
42 | img2 = img2[border:h-border, border:w-border]
43 |
44 | if img1.ndim == 2:
45 | return ssim(img1, img2)
46 | elif img1.ndim == 3:
47 | if img1.shape[2] == 3:
48 | ssims = []
49 | for i in range(3):
50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
51 | return np.array(ssims).mean()
52 | elif img1.shape[2] == 1:
53 | return ssim(np.squeeze(img1), np.squeeze(img2))
54 | else:
55 | raise ValueError('Wrong input image dimensions.')
56 |
57 |
58 | def ssim(img1, img2):
59 | C1 = (0.01 * 255)**2
60 | C2 = (0.03 * 255)**2
61 |
62 | img1 = img1.astype(np.float64)
63 | img2 = img2.astype(np.float64)
64 | kernel = cv2.getGaussianKernel(11, 1.5)
65 | window = np.outer(kernel, kernel.transpose())
66 |
67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
69 | mu1_sq = mu1**2
70 | mu2_sq = mu2**2
71 | mu1_mu2 = mu1 * mu2
72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 |
76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 | (sigma1_sq + sigma2_sq + C2))
78 | return ssim_map.mean()
79 |
80 | def load_img(filepath):
81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 |
83 | def save_img(filepath, img):
84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 |
86 | def load_gray_img(filepath):
87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 |
89 | def save_gray_img(filepath, img):
90 | cv2.imwrite(filepath, img)
91 |
--------------------------------------------------------------------------------
/Desnowing/Options/Desnow_snow100k_HINT.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Desnow_HINT
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 8 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage
13 | dataroot_gt: ./dataset/Snow100K/train2500/Gt
14 | dataroot_lq: ./dataset/Snow100K/train2500/Snow
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ### ------- Progressive training on patch sizes 128 to 384 ---------
27 | mini_batch_sizes: [6,5,2,1,1]
28 | iters: [50000,40000,30000,20000,10000]
29 | gt_size: 128
30 | gt_sizes: [128,192,256,320,384]
31 | ### ------------------------------------------------------------
32 |
33 | dataset_enlarge_ratio: 1
34 | prefetch_mode: ~
35 |
36 | val:
37 | name: ValSet
38 | type: Dataset_PairedImage
39 | dataroot_gt: ./dataset/Snow100K/test2000/Gt
40 | dataroot_lq: ./dataset/Snow100K/test2000/Snow
41 | gt_size: 256
42 | io_backend:
43 | type: disk
44 |
45 | # network structures
46 | network_g:
47 | type: HINT
48 | inp_channels: 3
49 | out_channels: 3
50 | dim: 48
51 | num_blocks: [4,6,6,8]
52 | num_refinement_blocks: 4
53 | heads: [8,8,8,8]
54 | ffn_expansion_factor: 2.66
55 | bias: False
56 | LayerNorm_type: WithBias
57 | dual_pixel_task: False
58 |
59 |
60 | # path
61 | path:
62 | pretrain_network_g: ~
63 | strict_load_g: true
64 | resume_state: ~
65 |
66 | # training settings
67 | train:
68 | total_iter: 300000
69 | warmup_iter: -1 # no warm up
70 | use_grad_clip: true
71 |
72 | # Split 300k iterations into two cycles.
73 | # 1st cycle: fixed 3e-4 LR for 92k iters.
74 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
75 | scheduler:
76 | type: CosineAnnealingRestartCyclicLR
77 | periods: [92000, 208000]
78 | restart_weights: [1,1]
79 | eta_mins: [0.0003,0.000001]
80 |
81 | mixing_augs:
82 | mixup: true
83 | mixup_beta: 1.2
84 | use_identity: true
85 |
86 | optim_g:
87 | type: AdamW
88 | lr: !!float 3e-4
89 | weight_decay: !!float 1e-4
90 | betas: [0.9, 0.999]
91 |
92 | # losses
93 | pixel_opt:
94 | type: L1Loss
95 | loss_weight: 1
96 | reduction: mean
97 | fft_loss_opt:
98 | type: FFTLoss
99 | loss_weight: 0.1
100 | reduction: mean
101 |
102 | # validation settings
103 | val:
104 | window_size: 8
105 | val_freq: !!float 4e3
106 | save_img: false
107 | rgb2bgr: true
108 | use_image: false
109 | max_minibatch: 8
110 |
111 | metrics:
112 | psnr: # metric name, can be arbitrary
113 | type: calculate_psnr
114 | crop_border: 0
115 | test_y_channel: false
116 |
117 | # logging settings
118 | logger:
119 | print_freq: 1000
120 | save_checkpoint_freq: !!float 4e3
121 | use_tb_logger: true
122 | wandb:
123 | project: ~
124 | resume_id: ~
125 |
126 | # dist training settings
127 | dist_params:
128 | backend: nccl
129 | port: 29500
130 |
--------------------------------------------------------------------------------
/Desnowing/evaluate_Snow100k.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import os
6 | import numpy as np
7 | from glob import glob
8 | from natsort import natsorted
9 | from skimage import io
10 | import cv2
11 | import argparse
12 | from skimage.metrics import structural_similarity
13 | from tqdm import tqdm
14 | import concurrent.futures
15 | import utils
16 |
17 | def proc(filename):
18 | tar,prd = filename
19 |
20 | t_name = prd.split('/')[-1].split('.')[0]+'.jpg'
21 | tar_name = './dataset/Snow100K/test2000/Gt/' + t_name
22 | tar_img = utils.load_img(tar_name)
23 | prd_img = utils.load_img(prd)
24 |
25 | PSNR = utils.calculate_psnr(tar_img, prd_img)
26 | SSIM = utils.calculate_ssim(tar_img, prd_img)
27 | return PSNR,SSIM
28 |
29 | parser = argparse.ArgumentParser(description='Desnowing using HINT')
30 |
31 | args = parser.parse_args()
32 |
33 |
34 | datasets = ['test2000']
35 |
36 | for dataset in datasets:
37 |
38 | gt_path = os.path.join('./dataset/Snow100K/test2000/Gt')
39 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg')))
40 | assert len(gt_list) != 0, "Target files not found"
41 |
42 |
43 | file_path = os.path.join('results', 'HINT', dataset)
44 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg')))
45 | assert len(path_list) != 0, "Predicted files not found"
46 |
47 | psnr, ssim = [], []
48 | img_files =[(i, j) for i,j in zip(gt_list,path_list)]
49 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
50 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
51 | psnr.append(PSNR_SSIM[0])
52 | ssim.append(PSNR_SSIM[1])
53 |
54 | avg_psnr = sum(psnr)/len(psnr)
55 | avg_ssim = sum(ssim)/len(ssim)
56 |
57 | # print('For {:s} dataset PSNR: {:f}\n'.format(dataset, avg_psnr))
58 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
59 |
--------------------------------------------------------------------------------
/Desnowing/test_snow100k.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch.nn as nn
7 | import torch
8 | import torch.nn.functional as F
9 | import utils
10 |
11 | from natsort import natsorted
12 | from glob import glob
13 | from basicsr.models.archs.HINT_arch import HINT
14 | from skimage import img_as_ubyte
15 | from pdb import set_trace as stx
16 |
17 | parser = argparse.ArgumentParser(description='Image Desnowing using HINT')
18 |
19 | parser.add_argument('--input_dir', default='./dataset/Snow100K/', type=str, help='Directory of validation images')
20 | parser.add_argument('--result_dir', default='./results/HINT', type=str, help='Directory for results')
21 | parser.add_argument('--weights', default='./models/snow100k.pth', type=str, help='Path to weights')
22 |
23 | args = parser.parse_args()
24 |
25 | def splitimage(imgtensor, crop_size=128, overlap_size=64):
26 | _, C, H, W = imgtensor.shape
27 | hstarts = [x for x in range(0, H, crop_size - overlap_size)]
28 | while hstarts and hstarts[-1] + crop_size >= H:
29 | hstarts.pop()
30 | hstarts.append(H - crop_size)
31 | wstarts = [x for x in range(0, W, crop_size - overlap_size)]
32 | while wstarts and wstarts[-1] + crop_size >= W:
33 | wstarts.pop()
34 | wstarts.append(W - crop_size)
35 | starts = []
36 | split_data = []
37 | for hs in hstarts:
38 | for ws in wstarts:
39 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
40 | starts.append((hs, ws))
41 | split_data.append(cimgdata)
42 | return split_data, starts
43 |
44 | def get_scoremap(H, W, C, B=1, is_mean=True):
45 | center_h = H / 2
46 | center_w = W / 2
47 |
48 | score = torch.ones((B, C, H, W))
49 | if not is_mean:
50 | for h in range(H):
51 | for w in range(W):
52 | score[:, :, h, w] = 1.0 / (((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6) ** 0.5)  # use ** 0.5: math is never imported in this script
53 | return score
54 |
55 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)):
56 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
57 | tot_score = torch.zeros((B, C, H, W))
58 | merge_img = torch.zeros((B, C, H, W))
59 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
60 | for simg, cstart in zip(split_data, starts):
61 | hs, ws = cstart
62 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
63 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
64 | merge_img = merge_img / tot_score
65 | return merge_img
66 |
67 | ####### Load yaml #######
68 | yaml_file = 'Options/Desnow_snow100k_HINT.yml'
69 | import yaml
70 |
71 | try:
72 | from yaml import CLoader as Loader
73 | except ImportError:
74 | from yaml import Loader
75 |
76 | x = yaml.load(open(yaml_file, mode='r'), Loader=Loader)
77 |
78 | s = x['network_g'].pop('type')
79 | ##########################
80 |
81 | model_restoration = HINT(**x['network_g'])
82 |
83 | checkpoint = torch.load(args.weights)
84 | model_restoration.load_state_dict(checkpoint['params'])
85 | print("===>Testing using weights: ",args.weights)
86 | model_restoration.cuda()
87 | model_restoration = nn.DataParallel(model_restoration)
88 | model_restoration.eval()
89 |
90 |
91 | factor = 8
92 | datasets = ['test2000']
93 |
94 | for dataset in datasets:
95 | result_dir = os.path.join(args.result_dir, dataset)
96 | os.makedirs(result_dir, exist_ok=True)
97 |
98 | # inp_dir = os.path.join(args.input_dir, 'test', dataset, 'rain')
99 | inp_dir = os.path.join(args.input_dir, dataset, 'Snow/')
100 | files = natsorted(glob(os.path.join(inp_dir, '*.png')) + glob(os.path.join(inp_dir, '*.jpg')))
101 | with torch.no_grad():
102 | for file_ in tqdm(files):
103 | torch.cuda.ipc_collect()
104 | torch.cuda.empty_cache()
105 |
106 | img = np.float32(utils.load_img(file_))/255.
107 | img = torch.from_numpy(img).permute(2,0,1)
108 | input_ = img.unsqueeze(0).cuda()
109 |
110 | B, C, H, W = input_.shape
111 | crop_size_arg = 256
112 | overlap_size_arg = 128
113 | split_data, starts = splitimage(input_, crop_size=crop_size_arg, overlap_size=overlap_size_arg)
114 | for i, data in enumerate(split_data):
115 | split_data[i] = model_restoration(data).cpu()
116 | restored = mergeimage(split_data, starts, crop_size=crop_size_arg, resolution=(B, C, H, W))
117 |
118 | restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
119 |
120 | utils.save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
121 |
--------------------------------------------------------------------------------
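
Unlike the other test scripts, this one restores large images by tiling: overlapping 256-pixel crops are processed independently and mergeimage averages the overlaps with a uniform score map (is_mean=True), which suppresses seams between tiles. If run in the same module as the definitions above, an identity "model" round-trips exactly (inputs must be at least crop_size in each dimension):

# Roundtrip check for splitimage/mergeimage with an identity "model":
# equal-weight averaging of overlapping identical tiles reproduces the input.
import torch

x = torch.rand(1, 3, 300, 400)
tiles, starts = splitimage(x, crop_size=256, overlap_size=128)
merged = mergeimage(tiles, starts, crop_size=256, resolution=x.shape)
print(torch.allclose(merged, x))          # True
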
/Desnowing/utils.py:
--------------------------------------------------------------------------------
1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3 | ## https://arxiv.org/abs/2111.09881
4 |
5 | import numpy as np
6 | import os
7 | import cv2
8 | import math
9 |
10 | def calculate_psnr(img1, img2, border=0):
11 | # img1 and img2 have range [0, 255]
12 | #img1 = img1.squeeze()
13 | #img2 = img2.squeeze()
14 | if not img1.shape == img2.shape:
15 | raise ValueError('Input images must have the same dimensions.')
16 | h, w = img1.shape[:2]
17 | img1 = img1[border:h-border, border:w-border]
18 | img2 = img2[border:h-border, border:w-border]
19 |
20 | img1 = img1.astype(np.float64)
21 | img2 = img2.astype(np.float64)
22 | mse = np.mean((img1 - img2)**2)
23 | if mse == 0:
24 | return float('inf')
25 | return 20 * math.log10(255.0 / math.sqrt(mse))
26 |
27 |
28 | # --------------------------------------------
29 | # SSIM
30 | # --------------------------------------------
31 | def calculate_ssim(img1, img2, border=0):
32 | '''calculate SSIM
33 | the same outputs as MATLAB's
34 | img1, img2: [0, 255]
35 | '''
36 | #img1 = img1.squeeze()
37 | #img2 = img2.squeeze()
38 | if not img1.shape == img2.shape:
39 | raise ValueError('Input images must have the same dimensions.')
40 | h, w = img1.shape[:2]
41 | img1 = img1[border:h-border, border:w-border]
42 | img2 = img2[border:h-border, border:w-border]
43 |
44 | if img1.ndim == 2:
45 | return ssim(img1, img2)
46 | elif img1.ndim == 3:
47 | if img1.shape[2] == 3:
48 | ssims = []
49 | for i in range(3):
50 | ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
51 | return np.array(ssims).mean()
52 | elif img1.shape[2] == 1:
53 | return ssim(np.squeeze(img1), np.squeeze(img2))
54 | else:
55 | raise ValueError('Wrong input image dimensions.')
56 |
57 |
58 | def ssim(img1, img2):
59 | C1 = (0.01 * 255)**2
60 | C2 = (0.03 * 255)**2
61 |
62 | img1 = img1.astype(np.float64)
63 | img2 = img2.astype(np.float64)
64 | kernel = cv2.getGaussianKernel(11, 1.5)
65 | window = np.outer(kernel, kernel.transpose())
66 |
67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
69 | mu1_sq = mu1**2
70 | mu2_sq = mu2**2
71 | mu1_mu2 = mu1 * mu2
72 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
73 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
75 |
76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
77 | (sigma1_sq + sigma2_sq + C2))
78 | return ssim_map.mean()
79 |
80 | def load_img(filepath):
81 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
82 |
83 | def save_img(filepath, img):
84 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
85 |
86 | def load_gray_img(filepath):
87 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
88 |
89 | def save_gray_img(filepath, img):
90 | cv2.imwrite(filepath, img)
91 |
--------------------------------------------------------------------------------
/Enhancement/Options/HINT_LOL_v2_real.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Enhancement_HINT
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 1 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage
13 | dataroot_gt: ./dataset/LOLv2/Real_captured/Train/Normal
14 | dataroot_lq: ./dataset/LOLv2/Real_captured/Train/Low
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ## ------------- Single-stage training (128px patches) ---------
27 | mini_batch_sizes: [6] # Batch size per gpu
28 | iters: [150000]
29 | gt_size: 384 # Max patch size for progressive training
30 | gt_sizes: [128] # Patch sizes for progressive training.
31 | ### ------------------------------------------------------------
32 |
33 |
34 | dataset_enlarge_ratio: 1
35 | prefetch_mode: ~
36 |
37 | val:
38 | name: ValSet
39 | type: Dataset_PairedImage
40 | dataroot_gt: ./dataset/LOLv2/Real_captured/Test/Normal
41 | dataroot_lq: ./dataset/LOLv2/Real_captured/Test/Low
42 | io_backend:
43 | type: disk
44 |
45 | # network structures
46 | network_g:
47 | type: HINT
48 | inp_channels: 3
49 | out_channels: 3
50 | dim: 48
51 | num_blocks: [4,6,6,8]
52 | num_refinement_blocks: 4
53 | heads: [8,8,8,8]
54 | ffn_expansion_factor: 2.66
55 | bias: False
56 | LayerNorm_type: WithBias
57 | dual_pixel_task: False
58 |
59 | # path
60 | path:
61 | pretrain_network_g: ~
62 | strict_load_g: true
63 | resume_state: ~
64 |
65 | # training settings
66 | train:
67 | total_iter: 150000
68 | warmup_iter: -1 # no warm up
69 | use_grad_clip: true
70 |
71 | # Split 150k iterations into two cycles.
72 | # 1st cycle: 46k iters.
73 | # 2nd cycle: cosine annealing down to 1e-6 for 104k iters.
74 | scheduler:
75 | type: CosineAnnealingRestartCyclicLR
76 | periods: [46000, 104000]
77 | restart_weights: [1,1]
78 | eta_mins: [0.0003,0.000001]
79 |
80 | mixing_augs:
81 | mixup: true
82 | mixup_beta: 1.2
83 | use_identity: true
84 |
85 | optim_g:
86 | type: Adam
87 | lr: !!float 2e-4
88 | # weight_decay: !!float 1e-4
89 | betas: [0.9, 0.999]
90 |
91 | pixel_opt:
92 | type: L1Loss
93 | loss_weight: 1
94 | reduction: mean
95 |
96 | fft_loss_opt:
97 | type: FFTLoss
98 | loss_weight: 0.1
99 | reduction: mean
100 |
101 |
102 | # validation settings
103 | val:
104 | window_size: 4
105 | val_freq: !!float 1e3
106 | save_img: false
107 | rgb2bgr: true
108 | use_image: false
109 | max_minibatch: 8
110 |
111 | metrics:
112 | psnr: # metric name, can be arbitrary
113 | type: calculate_psnr
114 | crop_border: 0
115 | test_y_channel: false
116 |
117 | # logging settings
118 | logger:
119 | print_freq: 500
120 | save_checkpoint_freq: !!float 1e3
121 | use_tb_logger: true
122 | wandb:
123 | project: ~
124 | resume_id: ~
125 |
126 | # dist training settings
127 | dist_params:
128 | backend: nccl
129 | port: 29500
130 |
--------------------------------------------------------------------------------
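For intuition, here is a minimal, hypothetical sketch of how a two-cycle schedule of this shape evolves, assuming basicsr-style `CosineAnnealingRestartCyclicLR` semantics (cycle `i` runs for `periods[i]` iterations and anneals from its peak down to `eta_mins[i]`); it is an illustration, not the library implementation. Because the first cycle's `eta_min` equals its 3e-4 peak, the LR stays flat there, which is what the "fixed 3e-4 LR" comment describes; the sketch uses 3e-4 as the base LR to match that comment (the optimizer block itself sets `lr: 2e-4`).

```python
# Illustrative two-cycle cosine schedule with restarts (not the basicsr class).
import math

def cyclic_cosine_lr(iteration, base_lr, periods, restart_weights, eta_mins):
    """Return the learning rate at a given iteration."""
    start = 0
    for period, weight, eta_min in zip(periods, restart_weights, eta_mins):
        if iteration < start + period:
            t = (iteration - start) / period          # progress within cycle
            peak = base_lr * weight
            return eta_min + 0.5 * (peak - eta_min) * (1 + math.cos(math.pi * t))
        start += period
    return eta_mins[-1]                               # past the last cycle

# Values from the config above; base LR taken as 3e-4 to match the comment.
for it in (0, 45999, 46000, 98000, 149999):
    print(it, cyclic_cosine_lr(it, 3e-4, [46000, 104000], [1, 1], [3e-4, 1e-6]))
```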
/Enhancement/Options/HINT_LOL_v2_synthetic.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: Enhancement_HINT
3 | model_type: ImageCleanModel
4 | scale: 1
5 | num_gpu: 1 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: Dataset_PairedImage
13 | dataroot_gt: ./dataset/LOLv2/Synthetic/Train/Normal
14 | dataroot_lq: ./dataset/LOLv2/Synthetic/Train/Low
15 | geometric_augs: true
16 |
17 | filename_tmpl: '{}'
18 | io_backend:
19 | type: disk
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 8
24 | batch_size_per_gpu: 8
25 |
26 | ### ------------- Progressive training ---------------------------
27 | mini_batch_sizes: [6,5,2,1,1]             # Batch size per gpu at each stage
28 | iters: [50000,40000,30000,20000,10000]
29 | gt_size: 384   # Max patch size for progressive training; must cover the largest entry in gt_sizes
30 | gt_sizes: [128,192,256,320,384]  # Patch sizes for progressive training.
31 | ### ------------------------------------------------------------
32 |
33 | dataset_enlarge_ratio: 1
34 | prefetch_mode: ~
35 |
36 | val:
37 | name: ValSet
38 | type: Dataset_PairedImage
39 | dataroot_gt: ./dataset/LOLv2/Synthetic/Test/Normal
40 | dataroot_lq: ./dataset/LOLv2/Synthetic/Test/Low
41 | io_backend:
42 | type: disk
43 |
44 | # network structures
45 | network_g:
46 | type: HINT
47 | inp_channels: 3
48 | out_channels: 3
49 | dim: 48
50 | num_blocks: [4,6,6,8]
51 | num_refinement_blocks: 4
52 | heads: [8,8,8,8]
53 | ffn_expansion_factor: 2.66
54 | bias: False
55 | LayerNorm_type: WithBias
56 | dual_pixel_task: False
57 |
58 |
59 | # path
60 | path:
61 | pretrain_network_g: ~
62 | strict_load_g: true
63 | resume_state: ~
64 |
65 | # training settings
66 | train:
67 | total_iter: 150000
68 | warmup_iter: -1 # no warm up
69 | use_grad_clip: true
70 |
71 | # Split 150k iterations into two cycles.
72 | # 1st cycle: fixed 3e-4 LR for 46k iters.
73 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 104k iters.
74 | scheduler:
75 | type: CosineAnnealingRestartCyclicLR
76 | periods: [46000, 104000]
77 | restart_weights: [1,1]
78 | eta_mins: [0.0003,0.000001]
79 |
80 | mixing_augs:
81 | mixup: true
82 | mixup_beta: 1.2
83 | use_identity: true
84 |
85 | optim_g:
86 | type: Adam
87 | lr: !!float 2e-4
88 | # weight_decay: !!float 1e-4
89 | betas: [0.9, 0.999]
90 |
91 | pixel_opt:
92 | type: L1Loss
93 | loss_weight: 1
94 | reduction: mean
95 |
96 | fft_loss_opt:
97 | type: FFTLoss
98 | loss_weight: 0.1
99 | reduction: mean
100 |
101 | # validation settings
102 | val:
103 | window_size: 4
104 | val_freq: !!float 1e3
105 | save_img: false
106 | rgb2bgr: true
107 | use_image: false
108 | max_minibatch: 8
109 |
110 | metrics:
111 | psnr: # metric name, can be arbitrary
112 | type: calculate_psnr
113 | crop_border: 0
114 | test_y_channel: false
115 |
116 | # logging settings
117 | logger:
118 | print_freq: 500
119 | save_checkpoint_freq: !!float 1e3
120 | use_tb_logger: true
121 | wandb:
122 | project: ~
123 | resume_id: ~
124 |
125 | # dist training settings
126 | dist_params:
127 | backend: nccl
128 | port: 29500
129 |
--------------------------------------------------------------------------------
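The five parallel lists in the progressive-training block above drive the patch and batch schedule. Below is a small stage-lookup sketch, assuming Restormer-style semantics (each stage lasts `iters[i]` iterations and trains on `gt_sizes[i]` patches with `mini_batch_sizes[i]` samples per GPU); the helper is illustrative, not the repo's training code.

```python
# Stage lookup for progressive training, assuming Restormer-style semantics.
from itertools import accumulate

iters = [50000, 40000, 30000, 20000, 10000]   # stage lengths from the config
gt_sizes = [128, 192, 256, 320, 384]          # patch size per stage
mini_batch_sizes = [6, 5, 2, 1, 1]            # per-GPU batch size per stage
boundaries = list(accumulate(iters))          # [50000, 90000, 120000, 140000, 150000]

def stage_for(iteration):
    """Return (patch_size, mini_batch_size) for a training iteration."""
    for stage, bound in enumerate(boundaries):
        if iteration < bound:
            return gt_sizes[stage], mini_batch_sizes[stage]
    return gt_sizes[-1], mini_batch_sizes[-1]

assert stage_for(0) == (128, 6)
assert stage_for(50000) == (192, 5)    # first iteration of the second stage
assert stage_for(149999) == (384, 1)   # final stage
```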
/Enhancement/utils.py:
--------------------------------------------------------------------------------
1 | # Retinexformer: One-stage Retinex-based Transformer for Low-light Image Enhancement
2 | # Yuanhao Cai, Hao Bian, Jing Lin, Haoqian Wang, Radu Timofte, Yulun Zhang
3 | # International Conference on Computer Vision (ICCV), 2023
4 | # https://arxiv.org/abs/2303.06705
5 | # https://github.com/caiyuanhao1998/Retinexformer
6 |
7 | import numpy as np
8 | import os
9 | import cv2
10 | import math
11 | from pdb import set_trace as stx
12 |
13 |
14 | def calculate_psnr(img1, img2, border=0):
15 | # img1 and img2 have range [0, 255]
16 | #img1 = img1.squeeze()
17 | #img2 = img2.squeeze()
18 | if not img1.shape == img2.shape:
19 | raise ValueError('Input images must have the same dimensions.')
20 | h, w = img1.shape[:2]
21 | img1 = img1[border:h - border, border:w - border]
22 | img2 = img2[border:h - border, border:w - border]
23 |
24 | img1 = img1.astype(np.float64)
25 | img2 = img2.astype(np.float64)
26 | mse = np.mean((img1 - img2)**2)
27 | if mse == 0:
28 | return float('inf')
29 | return 20 * math.log10(255.0 / math.sqrt(mse))
30 |
31 |
32 | def PSNR(img1, img2):
33 | mse_ = np.mean((img1 - img2) ** 2)
34 | if mse_ == 0:
35 | return 100
36 | return 10 * math.log10(1 / mse_)
37 |
38 |
39 | # --------------------------------------------
40 | # SSIM
41 | # --------------------------------------------
42 | def calculate_ssim(img1, img2, border=0):
43 | '''calculate SSIM
44 | the same outputs as MATLAB's
45 | img1, img2: [0, 255]
46 | '''
47 | #img1 = img1.squeeze()
48 | #img2 = img2.squeeze()
49 | if not img1.shape == img2.shape:
50 | raise ValueError('Input images must have the same dimensions.')
51 | h, w = img1.shape[:2]
52 | img1 = img1[border:h - border, border:w - border]
53 | img2 = img2[border:h - border, border:w - border]
54 |
55 | if img1.ndim == 2:
56 | return ssim(img1, img2)
57 | elif img1.ndim == 3:
58 | if img1.shape[2] == 3:
59 | ssims = []
60 | for i in range(3):
61 | ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
62 | return np.array(ssims).mean()
63 | elif img1.shape[2] == 1:
64 | return ssim(np.squeeze(img1), np.squeeze(img2))
65 | else:
66 | raise ValueError('Wrong input image dimensions.')
67 |
68 |
69 | def ssim(img1, img2):
70 | C1 = (0.01 * 255)**2
71 | C2 = (0.03 * 255)**2
72 |
73 | img1 = img1.astype(np.float64)
74 | img2 = img2.astype(np.float64)
75 | kernel = cv2.getGaussianKernel(11, 1.5)
76 | window = np.outer(kernel, kernel.transpose())
77 |
78 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
79 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
80 | mu1_sq = mu1**2
81 | mu2_sq = mu2**2
82 | mu1_mu2 = mu1 * mu2
83 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
84 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
85 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
86 |
87 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
88 | (sigma1_sq + sigma2_sq + C2))
89 | return ssim_map.mean()
90 |
91 |
92 | def load_img(filepath):
93 | return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
94 |
95 |
96 | def save_img(filepath, img):
97 | cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
98 |
99 |
100 | def load_gray_img(filepath):
101 | return np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
102 |
103 |
104 | def save_gray_img(filepath, img):
105 | cv2.imwrite(filepath, img)
106 |
107 |
108 | def visualization(feature, save_path, type='max', colormap=cv2.COLORMAP_JET):
109 | '''
110 | :param feature: [C,H,W]
111 | :param save_path: saving path
112 | :param type: 'mean' or 'max'
113 | :param colormap: the type of the pseudocolor map
114 | '''
115 | feature = feature.cpu().numpy()
116 | if type == 'mean':
117 | feature = np.mean(feature, axis=0)
118 | else:
119 | feature = np.max(feature, axis=0)
120 | normed_feat = (feature - feature.min()) / (feature.max() - feature.min())
121 | normed_feat = (normed_feat * 255).astype('uint8')
122 | color_feat = cv2.applyColorMap(normed_feat, colormap)
123 | # stx()
124 | cv2.imwrite(save_path, color_feat)
125 |
126 | def my_summary(test_model, H=256, W=256, C=3, N=1):
127 |     # torch and fvcore are only needed by this helper, and neither is
128 |     # imported at the top of this file, so import them lazily here.
129 |     import torch
130 |     from fvcore.nn import FlopCountAnalysis
131 |     model = test_model.cuda()
132 |     print(model)
133 |     inputs = torch.randn((N, C, H, W)).cuda()
134 |     flops = FlopCountAnalysis(model, inputs)
135 |     n_param = sum(p.nelement() for p in model.parameters())
136 |     print(f'GMac:{flops.total()/(1024*1024*1024)}')
137 |     print(f'Params:{n_param}')
--------------------------------------------------------------------------------
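A short usage sketch for the metric helpers above; the image paths are placeholders, and it assumes the snippet is run from the Enhancement/ directory so `utils` resolves to this file. Both metrics expect same-shape arrays in the [0, 255] range:

```python
# Usage sketch: run from Enhancement/ so `utils` is this file.
from utils import calculate_psnr, calculate_ssim, load_img

restored = load_img('results/example_restored.png')  # placeholder path
target = load_img('gt/example_gt.png')               # placeholder path

print(f'PSNR: {calculate_psnr(restored, target, border=0):.2f} dB')
print(f'SSIM: {calculate_ssim(restored, target, border=0):.4f}')
```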
/README.md:
--------------------------------------------------------------------------------
1 | # Devil is in the Uniformity: Exploring Diverse Learners within Transformer for Image Restoration
2 |
3 | [](https://huggingface.co/spaces/yssszzzzzzzzy/HINT)
4 | 
5 | [](https://github.com/joshyZhou/HINT)
6 |
7 | [Shihao Zhou](https://joshyzhou.github.io/), [Dayu Li](https://github.com/nkldy22), [Jinshan Pan](https://jspan.github.io/), [Juncheng Zhou](https://github.com/ZhouJunCheng99), [Jinglei Shi](https://jingleishi.github.io/) and [Jufeng Yang](https://cv.nankai.edu.cn/)
8 |
9 | #### News
10 | - **Jul 19, 2025:** [Hugging Face Demo](https://huggingface.co/spaces/yssszzzzzzzzy/HINT) is available now, thanks to the contribution of [Sen](https://github.com/yss730)
11 | - **Jun 26, 2025:** HINT has been accepted to ICCV 2025 :tada:
12 |
13 |
14 | ## Training
15 | ### Derain
16 | To train HINT on rain100L, you can run:
17 | ```sh
18 | ./train.sh Deraining/Options/Deraining_HINT_syn_rain100L.yml
19 | ```
20 | ### Dehaze
21 | To train HINT on SOTS, you can run:
22 | ```sh
23 | ./train.sh Dehaze/Options/RealDehazing_HINT.yml
24 | ```
25 | ### Denoising
26 | To train HINT for Gaussian color denoising, you can run:
27 | ```sh
28 | ./train.sh Denoising/Options/GaussianColorDenoising_HINT.yml
29 | ```
30 | ### Desnowing
31 | To train HINT on snow100k, you can run:
32 | ```sh
33 | ./train.sh Desnowing/Options/Desnow_snow100k_HINT.yml
34 | ```
35 | ### Enhancement
36 | To train HINT on LOL_v2_real, you can run:
37 | ```sh
38 | ./train.sh Enhancement/Options/HINT_LOL_v2_real.yml
39 | ```
40 |
41 | To train HINT on LOL_v2_synthetic, you can run:
42 | ```sh
43 | ./train.sh Enhancement/Options/HINT_LOL_v2_synthetic.yml
44 | ```
45 |
46 | ## Evaluation
47 | To evaluate HINT, refer to the commands in `test.sh`.
48 |
49 | To evaluate on a particular dataset, uncomment the corresponding line.
50 |
51 |
52 | ## Results
53 | Experiments are performed on several image restoration tasks.
54 | Here is a summary table with hyperlinks for easy navigation:
55 |
93 |
94 |
95 | ## Citation
96 | If you find this project useful, please consider citing:
97 |
98 | @inproceedings{zhou_ICCV25_HINT,
99 | title={Devil is in the Uniformity: Exploring Diverse Learners within Transformer for Image Restoration},
100 | author={Zhou, Shihao and Li, Dayu and Pan, Jinshan and Zhou, Juncheng and Shi, Jinglei and Yang, Jufeng},
101 | booktitle={ICCV},
102 | year={2025}
103 | }
104 |
105 | ## Acknowledgement
106 |
107 | This code borrows heavily from [Restormer](https://github.com/swz30/Restormer).
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 1.2.0
2 |
--------------------------------------------------------------------------------
/basicsr/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/.DS_Store
--------------------------------------------------------------------------------
/basicsr/__pycache__/version.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/__pycache__/version.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/__pycache__/version.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/__pycache__/version.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/data/SDSD_image_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import torch
3 | import torch.utils.data as data
4 | import basicsr.data.util as util
5 | import torch.nn.functional as F
6 | import random
7 | import cv2
8 | import numpy as np
9 | import glob
10 | import os
11 | import functools
12 |
13 |
14 | class Dataset_SDSDImage(data.Dataset):
15 | def __init__(self, opt):
16 | super(Dataset_SDSDImage, self).__init__()
17 | self.opt = opt
18 | self.cache_data = opt['cache_data']
19 | self.half_N_frames = opt['N_frames'] // 2
20 | self.GT_root, self.LQ_root = opt['dataroot_gt'], opt['dataroot_lq']
21 | self.io_backend_opt = opt['io_backend']
22 | self.data_type = self.io_backend_opt['type']
23 | self.data_info = {'path_LQ': [], 'path_GT': [],
24 | 'folder': [], 'idx': [], 'border': []}
25 | if self.data_type == 'lmdb':
26 | raise ValueError('No need to use LMDB during validation/test.')
27 | # Generate data info and cache data
28 | self.imgs_LQ, self.imgs_GT = {}, {}
29 |
30 | if opt['testing_dir'] is not None:
31 | testing_dir = opt['testing_dir']
32 | testing_dir = testing_dir.split(',')
33 | else:
34 | testing_dir = []
35 | print('testing_dir', testing_dir)
36 |
37 | subfolders_LQ = util.glob_file_list(self.LQ_root)
38 | subfolders_GT = util.glob_file_list(self.GT_root)
39 |
40 | for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT):
41 | # for frames in each video:
42 | subfolder_name = osp.basename(subfolder_GT)
43 |
44 | if self.opt['phase'] == 'train':
45 | if (subfolder_name in testing_dir):
46 | continue
47 |
48 | if (subfolder_name.split('_2')[0] in testing_dir):
49 | continue
50 | else: # val test
51 | if not(subfolder_name in testing_dir) and not(subfolder_name.split('_2')[0] in testing_dir):
52 | continue
53 |
54 | img_paths_LQ = util.glob_file_list(subfolder_LQ)
55 | img_paths_GT = util.glob_file_list(subfolder_GT)
56 |
57 | max_idx = len(img_paths_LQ)
58 | assert max_idx == len(
59 | img_paths_GT), 'Different number of images in LQ and GT folders'
60 | self.data_info['path_LQ'].extend(
61 | img_paths_LQ) # list of path str of images
62 | self.data_info['path_GT'].extend(img_paths_GT)
63 |
64 | self.data_info['folder'].extend([subfolder_name] * max_idx)
65 | for i in range(max_idx):
66 | self.data_info['idx'].append('{}/{}'.format(i, max_idx))
67 |
68 | border_l = [0] * max_idx
69 | for i in range(self.half_N_frames):
70 | border_l[i] = 1
71 | border_l[max_idx - i - 1] = 1
72 | self.data_info['border'].extend(border_l)
73 |
74 | if self.cache_data:
75 | self.imgs_LQ[subfolder_name] = img_paths_LQ
76 | self.imgs_GT[subfolder_name] = img_paths_GT
77 |
78 | def __getitem__(self, index):
79 | folder = self.data_info['folder'][index]
80 | idx, max_idx = self.data_info['idx'][index].split('/')
81 | idx, max_idx = int(idx), int(max_idx)
82 | border = self.data_info['border'][index]
83 |
84 | img_LQ_path = self.imgs_LQ[folder][idx:idx + 1]
85 | img_GT_path = self.imgs_GT[folder][idx:idx + 1]
86 |
87 | img_LQ = util.read_img_seq2(img_LQ_path, self.opt['train_size'])
88 | img_LQ = img_LQ[0]
89 | img_GT = util.read_img_seq2(img_GT_path, self.opt['train_size'])
90 | img_GT = img_GT[0]
91 |
92 | if self.opt['phase'] == 'train':
93 |
94 | # LQ_size = self.opt['LQ_size']
95 | # GT_size = self.opt['GT_size']
96 |
97 | # _, H, W = img_GT.shape # real img size
98 |
99 | # rnd_h = random.randint(0, max(0, H - GT_size))
100 | # rnd_w = random.randint(0, max(0, W - GT_size))
101 | # img_LQ = img_LQ[:, rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size]
102 | # img_GT = img_GT[:, rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size]
103 |
104 | img_LQ_l = [img_LQ]
105 | img_LQ_l.append(img_GT)
106 | rlt = util.augment_torch(
107 | img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
108 | img_LQ = rlt[0]
109 | img_GT = rlt[1]
110 |
111 | # img_nf = img_LQ.clone().permute(1, 2, 0).numpy() * 255.0
112 | # img_nf = cv2.blur(img_nf, (5, 5))
113 | # img_nf = img_nf * 1.0 / 255.0
114 | # img_nf = torch.Tensor(img_nf).float().permute(2, 0, 1)
115 |
116 | return {
117 | 'lq': img_LQ,
118 | 'gt': img_GT,
119 | # 'nf': img_nf,
120 | 'folder': folder,
121 | 'idx': self.data_info['idx'][index],
122 | 'border': border,
123 | 'lq_path': img_LQ_path[0],
124 | 'gt_path': img_GT_path[0]
125 | }
126 |
127 | def __len__(self):
128 | return len(self.data_info['path_LQ'])
129 |
--------------------------------------------------------------------------------
/basicsr/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.utils.data
6 | from functools import partial
7 | from os import path as osp
8 |
9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
10 | from basicsr.utils import get_root_logger, scandir
11 | from basicsr.utils.dist_util import get_dist_info
12 |
13 | __all__ = ['create_dataset', 'create_dataloader']
14 |
15 | # automatically scan and import dataset modules
16 | # scan all the files under the data folder with '_dataset' in file names
17 | data_folder = osp.dirname(osp.abspath(__file__))
18 | dataset_filenames = [
19 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
20 | if v.endswith('_dataset.py')
21 | ]
22 | # import all the dataset modules
23 | _dataset_modules = [
24 | importlib.import_module(f'basicsr.data.{file_name}')
25 | for file_name in dataset_filenames
26 | ]
27 |
28 |
29 | def create_dataset(dataset_opt):
30 | """Create dataset.
31 |
32 | Args:
33 | dataset_opt (dict): Configuration for dataset. It contains:
34 | name (str): Dataset name.
35 | type (str): Dataset type.
36 | """
37 | dataset_type = dataset_opt['type']
38 |
39 | # dynamic instantiation
40 | for module in _dataset_modules:
41 | dataset_cls = getattr(module, dataset_type, None)
42 | if dataset_cls is not None:
43 | break
44 | if dataset_cls is None:
45 | raise ValueError(f'Dataset {dataset_type} is not found.')
46 |
47 | dataset = dataset_cls(dataset_opt)
48 |
49 | logger = get_root_logger()
50 | logger.info(
51 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
52 | 'is created.')
53 | return dataset
54 |
55 |
56 | def create_dataloader(dataset,
57 | dataset_opt,
58 | num_gpu=1,
59 | dist=False,
60 | sampler=None,
61 | seed=None):
62 | """Create dataloader.
63 |
64 | Args:
65 | dataset (torch.utils.data.Dataset): Dataset.
66 | dataset_opt (dict): Dataset options. It contains the following keys:
67 | phase (str): 'train' or 'val'.
68 | num_worker_per_gpu (int): Number of workers for each GPU.
69 | batch_size_per_gpu (int): Training batch size for each GPU.
70 | num_gpu (int): Number of GPUs. Used only in the train phase.
71 | Default: 1.
72 | dist (bool): Whether in distributed training. Used only in the train
73 | phase. Default: False.
74 | sampler (torch.utils.data.sampler): Data sampler. Default: None.
75 | seed (int | None): Seed. Default: None
76 | """
77 | phase = dataset_opt['phase']
78 | rank, _ = get_dist_info()
79 | if phase == 'train':
80 | if dist: # distributed training
81 | batch_size = dataset_opt['batch_size_per_gpu']
82 | num_workers = dataset_opt['num_worker_per_gpu']
83 | else: # non-distributed training
84 | multiplier = 1 if num_gpu == 0 else num_gpu
85 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
86 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
87 | dataloader_args = dict(
88 | dataset=dataset,
89 | batch_size=batch_size,
90 | shuffle=False,
91 | num_workers=num_workers,
92 | sampler=sampler,
93 | drop_last=True)
94 | if sampler is None:
95 | dataloader_args['shuffle'] = True
96 | dataloader_args['worker_init_fn'] = partial(
97 | worker_init_fn, num_workers=num_workers, rank=rank,
98 | seed=seed) if seed is not None else None
99 | elif phase in ['val', 'test']: # validation
100 | dataloader_args = dict(
101 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
102 | else:
103 | raise ValueError(f'Wrong dataset phase: {phase}. '
104 | "Supported ones are 'train', 'val' and 'test'.")
105 |
106 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
107 |
108 | prefetch_mode = dataset_opt.get('prefetch_mode')
109 | if prefetch_mode == 'cpu': # CPUPrefetcher
110 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
111 | logger = get_root_logger()
112 | logger.info(f'Use {prefetch_mode} prefetch dataloader: '
113 | f'num_prefetch_queue = {num_prefetch_queue}')
114 | return PrefetchDataLoader(
115 | num_prefetch_queue=num_prefetch_queue, **dataloader_args)
116 | else:
117 | # prefetch_mode=None: Normal dataloader
118 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher
119 | return torch.utils.data.DataLoader(**dataloader_args)
120 |
121 |
122 | def worker_init_fn(worker_id, num_workers, rank, seed):
123 | # Set the worker seed to num_workers * rank + worker_id + seed
124 | worker_seed = num_workers * rank + worker_id + seed
125 | np.random.seed(worker_seed)
126 | random.seed(worker_seed)
127 |
--------------------------------------------------------------------------------
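A sketch of how `create_dataset` and `create_dataloader` are typically wired together, with an option dict mirroring the `val` block of the Enhancement configs earlier in this dump; the paths are placeholders and the exact keys `Dataset_PairedImage` expects may differ:

```python
# Wiring sketch; dataset paths below are placeholders.
from basicsr.data import create_dataset, create_dataloader

dataset_opt = {
    'name': 'ValSet',
    'type': 'Dataset_PairedImage',   # resolved via the dynamic lookup above
    'phase': 'val',
    'scale': 1,
    'dataroot_gt': './dataset/LOLv2/Real_captured/Test/Normal',
    'dataroot_lq': './dataset/LOLv2/Real_captured/Test/Low',
    'io_backend': {'type': 'disk'},
}
dataset = create_dataset(dataset_opt)
loader = create_dataloader(dataset, dataset_opt, num_gpu=1, dist=False)
batch = next(iter(loader))
print(batch['lq'].shape, batch['gt'].shape)
```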
/basicsr/data/__pycache__/SDSD_image_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/SDSD_image_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/data_sampler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/data_sampler.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/data_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/data_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/ffhq_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/paired_image_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/prefetch_dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/reds_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/reds_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/single_image_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/transforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/transforms.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/video_test_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/data/__pycache__/vimeo90k_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.utils.data.sampler import Sampler
4 |
5 |
6 | class EnlargedSampler(Sampler):
7 | """Sampler that restricts data loading to a subset of the dataset.
8 |
9 | Modified from torch.utils.data.distributed.DistributedSampler
10 | Supports enlarging the dataset for iteration-based training, saving
11 | time when restarting the dataloader after each epoch.
12 |
13 | Args:
14 | dataset (torch.utils.data.Dataset): Dataset used for sampling.
15 | num_replicas (int | None): Number of processes participating in
16 | the training. It is usually the world_size.
17 | rank (int | None): Rank of the current process within num_replicas.
18 | ratio (int): Enlarging ratio. Default: 1.
19 | """
20 |
21 | def __init__(self, dataset, num_replicas, rank, ratio=1):
22 | self.dataset = dataset
23 | self.num_replicas = num_replicas
24 | self.rank = rank
25 | self.epoch = 0
26 | self.num_samples = math.ceil(
27 | len(self.dataset) * ratio / self.num_replicas)
28 | self.total_size = self.num_samples * self.num_replicas
29 |
30 | def __iter__(self):
31 | # deterministically shuffle based on epoch
32 | g = torch.Generator()
33 | g.manual_seed(self.epoch)
34 | indices = torch.randperm(self.total_size, generator=g).tolist()
35 |
36 | dataset_size = len(self.dataset)
37 | indices = [v % dataset_size for v in indices]
38 |
39 | # subsample
40 | indices = indices[self.rank:self.total_size:self.num_replicas]
41 | assert len(indices) == self.num_samples
42 |
43 | return iter(indices)
44 |
45 | def __len__(self):
46 | return self.num_samples
47 |
48 | def set_epoch(self, epoch):
49 | self.epoch = epoch
50 |
--------------------------------------------------------------------------------
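A minimal, self-contained sketch of `EnlargedSampler` in the non-distributed case (`num_replicas=1`, `rank=0`); `ratio` stretches one epoch so iteration-based training restarts the dataloader less often:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from basicsr.data.data_sampler import EnlargedSampler

dataset = TensorDataset(torch.arange(10).float())     # 10-sample toy dataset
sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=3)

loader = DataLoader(dataset, batch_size=4, sampler=sampler)
sampler.set_epoch(0)                                  # deterministic shuffle
print(len(sampler))                                   # 30 = 10 * ratio
print(sum(batch[0].numel() for batch in loader))      # 30 samples per "epoch"
```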
/basicsr/data/ffhq_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.transforms import augment
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor
7 |
8 |
9 | class FFHQDataset(data.Dataset):
10 | """FFHQ dataset for StyleGAN.
11 |
12 | Args:
13 | opt (dict): Config for train datasets. It contains the following keys:
14 | dataroot_gt (str): Data root path for gt.
15 | io_backend (dict): IO backend type and other kwarg.
16 | mean (list | tuple): Image mean.
17 | std (list | tuple): Image std.
18 | use_hflip (bool): Whether to horizontally flip.
19 |
20 | """
21 |
22 | def __init__(self, opt):
23 | super(FFHQDataset, self).__init__()
24 | self.opt = opt
25 | # file client (io backend)
26 | self.file_client = None
27 | self.io_backend_opt = opt['io_backend']
28 |
29 | self.gt_folder = opt['dataroot_gt']
30 | self.mean = opt['mean']
31 | self.std = opt['std']
32 |
33 | if self.io_backend_opt['type'] == 'lmdb':
34 | self.io_backend_opt['db_paths'] = self.gt_folder
35 | if not self.gt_folder.endswith('.lmdb'):
36 | raise ValueError("'dataroot_gt' should end with '.lmdb', "
37 | f'but received {self.gt_folder}')
38 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
39 | self.paths = [line.split('.')[0] for line in fin]
40 | else:
41 | # FFHQ has 70000 images in total
42 | self.paths = [
43 | osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)
44 | ]
45 |
46 | def __getitem__(self, index):
47 | if self.file_client is None:
48 | self.file_client = FileClient(
49 | self.io_backend_opt.pop('type'), **self.io_backend_opt)
50 |
51 | # load gt image
52 | gt_path = self.paths[index]
53 | img_bytes = self.file_client.get(gt_path)
54 | img_gt = imfrombytes(img_bytes, float32=True)
55 |
56 | # random horizontal flip
57 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
58 | # BGR to RGB, HWC to CHW, numpy to tensor
59 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
60 | # normalize
61 | normalize(img_gt, self.mean, self.std, inplace=True)
62 | return {'gt': img_gt, 'gt_path': gt_path}
63 |
64 | def __len__(self):
65 | return len(self.paths)
66 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDS4_test_GT.txt:
--------------------------------------------------------------------------------
1 | 000 100 (720,1280,3)
2 | 011 100 (720,1280,3)
3 | 015 100 (720,1280,3)
4 | 020 100 (720,1280,3)
5 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDS_GT.txt:
--------------------------------------------------------------------------------
1 | 000 100 (720,1280,3)
2 | 001 100 (720,1280,3)
3 | 002 100 (720,1280,3)
4 | 003 100 (720,1280,3)
5 | 004 100 (720,1280,3)
6 | 005 100 (720,1280,3)
7 | 006 100 (720,1280,3)
8 | 007 100 (720,1280,3)
9 | 008 100 (720,1280,3)
10 | 009 100 (720,1280,3)
11 | 010 100 (720,1280,3)
12 | 011 100 (720,1280,3)
13 | 012 100 (720,1280,3)
14 | 013 100 (720,1280,3)
15 | 014 100 (720,1280,3)
16 | 015 100 (720,1280,3)
17 | 016 100 (720,1280,3)
18 | 017 100 (720,1280,3)
19 | 018 100 (720,1280,3)
20 | 019 100 (720,1280,3)
21 | 020 100 (720,1280,3)
22 | 021 100 (720,1280,3)
23 | 022 100 (720,1280,3)
24 | 023 100 (720,1280,3)
25 | 024 100 (720,1280,3)
26 | 025 100 (720,1280,3)
27 | 026 100 (720,1280,3)
28 | 027 100 (720,1280,3)
29 | 028 100 (720,1280,3)
30 | 029 100 (720,1280,3)
31 | 030 100 (720,1280,3)
32 | 031 100 (720,1280,3)
33 | 032 100 (720,1280,3)
34 | 033 100 (720,1280,3)
35 | 034 100 (720,1280,3)
36 | 035 100 (720,1280,3)
37 | 036 100 (720,1280,3)
38 | 037 100 (720,1280,3)
39 | 038 100 (720,1280,3)
40 | 039 100 (720,1280,3)
41 | 040 100 (720,1280,3)
42 | 041 100 (720,1280,3)
43 | 042 100 (720,1280,3)
44 | 043 100 (720,1280,3)
45 | 044 100 (720,1280,3)
46 | 045 100 (720,1280,3)
47 | 046 100 (720,1280,3)
48 | 047 100 (720,1280,3)
49 | 048 100 (720,1280,3)
50 | 049 100 (720,1280,3)
51 | 050 100 (720,1280,3)
52 | 051 100 (720,1280,3)
53 | 052 100 (720,1280,3)
54 | 053 100 (720,1280,3)
55 | 054 100 (720,1280,3)
56 | 055 100 (720,1280,3)
57 | 056 100 (720,1280,3)
58 | 057 100 (720,1280,3)
59 | 058 100 (720,1280,3)
60 | 059 100 (720,1280,3)
61 | 060 100 (720,1280,3)
62 | 061 100 (720,1280,3)
63 | 062 100 (720,1280,3)
64 | 063 100 (720,1280,3)
65 | 064 100 (720,1280,3)
66 | 065 100 (720,1280,3)
67 | 066 100 (720,1280,3)
68 | 067 100 (720,1280,3)
69 | 068 100 (720,1280,3)
70 | 069 100 (720,1280,3)
71 | 070 100 (720,1280,3)
72 | 071 100 (720,1280,3)
73 | 072 100 (720,1280,3)
74 | 073 100 (720,1280,3)
75 | 074 100 (720,1280,3)
76 | 075 100 (720,1280,3)
77 | 076 100 (720,1280,3)
78 | 077 100 (720,1280,3)
79 | 078 100 (720,1280,3)
80 | 079 100 (720,1280,3)
81 | 080 100 (720,1280,3)
82 | 081 100 (720,1280,3)
83 | 082 100 (720,1280,3)
84 | 083 100 (720,1280,3)
85 | 084 100 (720,1280,3)
86 | 085 100 (720,1280,3)
87 | 086 100 (720,1280,3)
88 | 087 100 (720,1280,3)
89 | 088 100 (720,1280,3)
90 | 089 100 (720,1280,3)
91 | 090 100 (720,1280,3)
92 | 091 100 (720,1280,3)
93 | 092 100 (720,1280,3)
94 | 093 100 (720,1280,3)
95 | 094 100 (720,1280,3)
96 | 095 100 (720,1280,3)
97 | 096 100 (720,1280,3)
98 | 097 100 (720,1280,3)
99 | 098 100 (720,1280,3)
100 | 099 100 (720,1280,3)
101 | 100 100 (720,1280,3)
102 | 101 100 (720,1280,3)
103 | 102 100 (720,1280,3)
104 | 103 100 (720,1280,3)
105 | 104 100 (720,1280,3)
106 | 105 100 (720,1280,3)
107 | 106 100 (720,1280,3)
108 | 107 100 (720,1280,3)
109 | 108 100 (720,1280,3)
110 | 109 100 (720,1280,3)
111 | 110 100 (720,1280,3)
112 | 111 100 (720,1280,3)
113 | 112 100 (720,1280,3)
114 | 113 100 (720,1280,3)
115 | 114 100 (720,1280,3)
116 | 115 100 (720,1280,3)
117 | 116 100 (720,1280,3)
118 | 117 100 (720,1280,3)
119 | 118 100 (720,1280,3)
120 | 119 100 (720,1280,3)
121 | 120 100 (720,1280,3)
122 | 121 100 (720,1280,3)
123 | 122 100 (720,1280,3)
124 | 123 100 (720,1280,3)
125 | 124 100 (720,1280,3)
126 | 125 100 (720,1280,3)
127 | 126 100 (720,1280,3)
128 | 127 100 (720,1280,3)
129 | 128 100 (720,1280,3)
130 | 129 100 (720,1280,3)
131 | 130 100 (720,1280,3)
132 | 131 100 (720,1280,3)
133 | 132 100 (720,1280,3)
134 | 133 100 (720,1280,3)
135 | 134 100 (720,1280,3)
136 | 135 100 (720,1280,3)
137 | 136 100 (720,1280,3)
138 | 137 100 (720,1280,3)
139 | 138 100 (720,1280,3)
140 | 139 100 (720,1280,3)
141 | 140 100 (720,1280,3)
142 | 141 100 (720,1280,3)
143 | 142 100 (720,1280,3)
144 | 143 100 (720,1280,3)
145 | 144 100 (720,1280,3)
146 | 145 100 (720,1280,3)
147 | 146 100 (720,1280,3)
148 | 147 100 (720,1280,3)
149 | 148 100 (720,1280,3)
150 | 149 100 (720,1280,3)
151 | 150 100 (720,1280,3)
152 | 151 100 (720,1280,3)
153 | 152 100 (720,1280,3)
154 | 153 100 (720,1280,3)
155 | 154 100 (720,1280,3)
156 | 155 100 (720,1280,3)
157 | 156 100 (720,1280,3)
158 | 157 100 (720,1280,3)
159 | 158 100 (720,1280,3)
160 | 159 100 (720,1280,3)
161 | 160 100 (720,1280,3)
162 | 161 100 (720,1280,3)
163 | 162 100 (720,1280,3)
164 | 163 100 (720,1280,3)
165 | 164 100 (720,1280,3)
166 | 165 100 (720,1280,3)
167 | 166 100 (720,1280,3)
168 | 167 100 (720,1280,3)
169 | 168 100 (720,1280,3)
170 | 169 100 (720,1280,3)
171 | 170 100 (720,1280,3)
172 | 171 100 (720,1280,3)
173 | 172 100 (720,1280,3)
174 | 173 100 (720,1280,3)
175 | 174 100 (720,1280,3)
176 | 175 100 (720,1280,3)
177 | 176 100 (720,1280,3)
178 | 177 100 (720,1280,3)
179 | 178 100 (720,1280,3)
180 | 179 100 (720,1280,3)
181 | 180 100 (720,1280,3)
182 | 181 100 (720,1280,3)
183 | 182 100 (720,1280,3)
184 | 183 100 (720,1280,3)
185 | 184 100 (720,1280,3)
186 | 185 100 (720,1280,3)
187 | 186 100 (720,1280,3)
188 | 187 100 (720,1280,3)
189 | 188 100 (720,1280,3)
190 | 189 100 (720,1280,3)
191 | 190 100 (720,1280,3)
192 | 191 100 (720,1280,3)
193 | 192 100 (720,1280,3)
194 | 193 100 (720,1280,3)
195 | 194 100 (720,1280,3)
196 | 195 100 (720,1280,3)
197 | 196 100 (720,1280,3)
198 | 197 100 (720,1280,3)
199 | 198 100 (720,1280,3)
200 | 199 100 (720,1280,3)
201 | 200 100 (720,1280,3)
202 | 201 100 (720,1280,3)
203 | 202 100 (720,1280,3)
204 | 203 100 (720,1280,3)
205 | 204 100 (720,1280,3)
206 | 205 100 (720,1280,3)
207 | 206 100 (720,1280,3)
208 | 207 100 (720,1280,3)
209 | 208 100 (720,1280,3)
210 | 209 100 (720,1280,3)
211 | 210 100 (720,1280,3)
212 | 211 100 (720,1280,3)
213 | 212 100 (720,1280,3)
214 | 213 100 (720,1280,3)
215 | 214 100 (720,1280,3)
216 | 215 100 (720,1280,3)
217 | 216 100 (720,1280,3)
218 | 217 100 (720,1280,3)
219 | 218 100 (720,1280,3)
220 | 219 100 (720,1280,3)
221 | 220 100 (720,1280,3)
222 | 221 100 (720,1280,3)
223 | 222 100 (720,1280,3)
224 | 223 100 (720,1280,3)
225 | 224 100 (720,1280,3)
226 | 225 100 (720,1280,3)
227 | 226 100 (720,1280,3)
228 | 227 100 (720,1280,3)
229 | 228 100 (720,1280,3)
230 | 229 100 (720,1280,3)
231 | 230 100 (720,1280,3)
232 | 231 100 (720,1280,3)
233 | 232 100 (720,1280,3)
234 | 233 100 (720,1280,3)
235 | 234 100 (720,1280,3)
236 | 235 100 (720,1280,3)
237 | 236 100 (720,1280,3)
238 | 237 100 (720,1280,3)
239 | 238 100 (720,1280,3)
240 | 239 100 (720,1280,3)
241 | 240 100 (720,1280,3)
242 | 241 100 (720,1280,3)
243 | 242 100 (720,1280,3)
244 | 243 100 (720,1280,3)
245 | 244 100 (720,1280,3)
246 | 245 100 (720,1280,3)
247 | 246 100 (720,1280,3)
248 | 247 100 (720,1280,3)
249 | 248 100 (720,1280,3)
250 | 249 100 (720,1280,3)
251 | 250 100 (720,1280,3)
252 | 251 100 (720,1280,3)
253 | 252 100 (720,1280,3)
254 | 253 100 (720,1280,3)
255 | 254 100 (720,1280,3)
256 | 255 100 (720,1280,3)
257 | 256 100 (720,1280,3)
258 | 257 100 (720,1280,3)
259 | 258 100 (720,1280,3)
260 | 259 100 (720,1280,3)
261 | 260 100 (720,1280,3)
262 | 261 100 (720,1280,3)
263 | 262 100 (720,1280,3)
264 | 263 100 (720,1280,3)
265 | 264 100 (720,1280,3)
266 | 265 100 (720,1280,3)
267 | 266 100 (720,1280,3)
268 | 267 100 (720,1280,3)
269 | 268 100 (720,1280,3)
270 | 269 100 (720,1280,3)
271 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt:
--------------------------------------------------------------------------------
1 | 240 100 (720,1280,3)
2 | 241 100 (720,1280,3)
3 | 246 100 (720,1280,3)
4 | 257 100 (720,1280,3)
5 |
--------------------------------------------------------------------------------
/basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt:
--------------------------------------------------------------------------------
1 | 240 100 (720,1280,3)
2 | 241 100 (720,1280,3)
3 | 242 100 (720,1280,3)
4 | 243 100 (720,1280,3)
5 | 244 100 (720,1280,3)
6 | 245 100 (720,1280,3)
7 | 246 100 (720,1280,3)
8 | 247 100 (720,1280,3)
9 | 248 100 (720,1280,3)
10 | 249 100 (720,1280,3)
11 | 250 100 (720,1280,3)
12 | 251 100 (720,1280,3)
13 | 252 100 (720,1280,3)
14 | 253 100 (720,1280,3)
15 | 254 100 (720,1280,3)
16 | 255 100 (720,1280,3)
17 | 256 100 (720,1280,3)
18 | 257 100 (720,1280,3)
19 | 258 100 (720,1280,3)
20 | 259 100 (720,1280,3)
21 | 260 100 (720,1280,3)
22 | 261 100 (720,1280,3)
23 | 262 100 (720,1280,3)
24 | 263 100 (720,1280,3)
25 | 264 100 (720,1280,3)
26 | 265 100 (720,1280,3)
27 | 266 100 (720,1280,3)
28 | 267 100 (720,1280,3)
29 | 268 100 (720,1280,3)
30 | 269 100 (720,1280,3)
31 |
--------------------------------------------------------------------------------
/basicsr/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
1 | import queue as Queue
2 | import threading
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 |
7 | class PrefetchGenerator(threading.Thread):
8 | """A general prefetch generator.
9 |
10 | Ref:
11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12 |
13 | Args:
14 | generator: Python generator.
15 | num_prefetch_queue (int): Number of prefetch queue.
16 | """
17 |
18 | def __init__(self, generator, num_prefetch_queue):
19 | threading.Thread.__init__(self)
20 | self.queue = Queue.Queue(num_prefetch_queue)
21 | self.generator = generator
22 | self.daemon = True
23 | self.start()
24 |
25 | def run(self):
26 | for item in self.generator:
27 | self.queue.put(item)
28 | self.queue.put(None)
29 |
30 | def __next__(self):
31 | next_item = self.queue.get()
32 | if next_item is None:
33 | raise StopIteration
34 | return next_item
35 |
36 | def __iter__(self):
37 | return self
38 |
39 |
40 | class PrefetchDataLoader(DataLoader):
41 | """Prefetch version of dataloader.
42 |
43 | Ref:
44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45 |
46 | TODO:
47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48 | ddp.
49 |
50 | Args:
51 | num_prefetch_queue (int): Number of prefetch queue.
52 | kwargs (dict): Other arguments for dataloader.
53 | """
54 |
55 | def __init__(self, num_prefetch_queue, **kwargs):
56 | self.num_prefetch_queue = num_prefetch_queue
57 | super(PrefetchDataLoader, self).__init__(**kwargs)
58 |
59 | def __iter__(self):
60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61 |
62 |
63 | class CPUPrefetcher():
64 | """CPU prefetcher.
65 |
66 | Args:
67 | loader: Dataloader.
68 | """
69 |
70 | def __init__(self, loader):
71 | self.ori_loader = loader
72 | self.loader = iter(loader)
73 |
74 | def next(self):
75 | try:
76 | return next(self.loader)
77 | except StopIteration:
78 | return None
79 |
80 | def reset(self):
81 | self.loader = iter(self.ori_loader)
82 |
83 |
84 | class CUDAPrefetcher():
85 | """CUDA prefetcher.
86 |
87 | Ref:
88 | https://github.com/NVIDIA/apex/issues/304#
89 |
90 | It may consume more GPU memory.
91 |
92 | Args:
93 | loader: Dataloader.
94 | opt (dict): Options.
95 | """
96 |
97 | def __init__(self, loader, opt):
98 | self.ori_loader = loader
99 | self.loader = iter(loader)
100 | self.opt = opt
101 | self.stream = torch.cuda.Stream()
102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103 | self.preload()
104 |
105 | def preload(self):
106 | try:
107 | self.batch = next(self.loader) # self.batch is a dict
108 | except StopIteration:
109 | self.batch = None
110 | return None
111 | # put tensors to gpu
112 | with torch.cuda.stream(self.stream):
113 | for k, v in self.batch.items():
114 | if torch.is_tensor(v):
115 | self.batch[k] = self.batch[k].to(
116 | device=self.device, non_blocking=True)
117 |
118 | def next(self):
119 | torch.cuda.current_stream().wait_stream(self.stream)
120 | batch = self.batch
121 | self.preload()
122 | return batch
123 |
124 | def reset(self):
125 | self.loader = iter(self.ori_loader)
126 | self.preload()
127 |
--------------------------------------------------------------------------------
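A sketch of the `CUDAPrefetcher` loop pattern, using a toy dict-returning dataset as a stand-in for the datasets in this repo; it assumes a CUDA device is available:

```python
import torch
from torch.utils.data import DataLoader, Dataset

from basicsr.data.prefetch_dataloader import CUDAPrefetcher

class ToyDictDataset(Dataset):
    """Tiny stand-in for the dict-returning datasets in this repo."""
    def __len__(self):
        return 8
    def __getitem__(self, i):
        return {'lq': torch.rand(3, 16, 16), 'gt': torch.rand(3, 16, 16)}

loader = DataLoader(ToyDictDataset(), batch_size=4)
prefetcher = CUDAPrefetcher(loader, opt={'num_gpu': 1})  # needs a CUDA device

batch = prefetcher.next()              # first batch, copied on a side stream
while batch is not None:
    print(batch['lq'].device, batch['lq'].shape)   # cuda:0, [4, 3, 16, 16]
    batch = prefetcher.next()
prefetcher.reset()                     # rewind the loader for a new epoch
```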
/basicsr/data/single_image_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.data_util import paths_from_lmdb
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
7 |
8 |
9 | class SingleImageDataset(data.Dataset):
10 | """Read only lq images in the test phase.
11 |
12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
13 |
14 | There are two modes:
15 | 1. 'meta_info_file': Use meta information file to generate paths.
16 | 2. 'folder': Scan folders to generate paths.
17 |
18 | Args:
19 | opt (dict): Config for train datasets. It contains the following keys:
20 | dataroot_lq (str): Data root path for lq.
21 | meta_info_file (str): Path for meta information file.
22 | io_backend (dict): IO backend type and other kwarg.
23 | """
24 |
25 | def __init__(self, opt):
26 | super(SingleImageDataset, self).__init__()
27 | self.opt = opt
28 | # file client (io backend)
29 | self.file_client = None
30 | self.io_backend_opt = opt['io_backend']
31 | self.mean = opt['mean'] if 'mean' in opt else None
32 | self.std = opt['std'] if 'std' in opt else None
33 | self.lq_folder = opt['dataroot_lq']
34 |
35 | if self.io_backend_opt['type'] == 'lmdb':
36 | self.io_backend_opt['db_paths'] = [self.lq_folder]
37 | self.io_backend_opt['client_keys'] = ['lq']
38 | self.paths = paths_from_lmdb(self.lq_folder)
39 | elif 'meta_info_file' in self.opt:
40 | with open(self.opt['meta_info_file'], 'r') as fin:
41 | self.paths = [
42 | osp.join(self.lq_folder,
43 | line.split(' ')[0]) for line in fin
44 | ]
45 | else:
46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
47 |
48 | def __getitem__(self, index):
49 | if self.file_client is None:
50 | self.file_client = FileClient(
51 | self.io_backend_opt.pop('type'), **self.io_backend_opt)
52 |
53 | # load lq image
54 | lq_path = self.paths[index]
55 | img_bytes = self.file_client.get(lq_path, 'lq')
56 | img_lq = imfrombytes(img_bytes, float32=True)
57 |
58 | # TODO: color space transform
59 | # BGR to RGB, HWC to CHW, numpy to tensor
60 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
61 | # normalize
62 | if self.mean is not None or self.std is not None:
63 | normalize(img_lq, self.mean, self.std, inplace=True)
64 | return {'lq': img_lq, 'lq_path': lq_path}
65 |
66 | def __len__(self):
67 | return len(self.paths)
68 |
--------------------------------------------------------------------------------
/basicsr/data/vimeo90k_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from pathlib import Path
4 | from torch.utils import data as data
5 |
6 | from basicsr.data.transforms import augment, paired_random_crop
7 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
8 |
9 |
10 | class Vimeo90KDataset(data.Dataset):
11 | """Vimeo90K dataset for training.
12 |
13 | The keys are generated from a meta info txt file.
14 | basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
15 |
16 | Each line contains:
17 | 1. clip name; 2. frame number; 3. image shape, separated by a white space.
18 | Examples:
19 | 00001/0001 7 (256,448,3)
20 | 00001/0002 7 (256,448,3)
21 |
22 | Key examples: "00001/0001"
23 | GT (gt): Ground-Truth;
24 | LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
25 |
26 | The neighboring frame list for different num_frame:
27 | num_frame | frame list
28 | 1 | 4
29 | 3 | 3,4,5
30 | 5 | 2,3,4,5,6
31 | 7 | 1,2,3,4,5,6,7
32 |
33 | Args:
34 | opt (dict): Config for train dataset. It contains the following keys:
35 | dataroot_gt (str): Data root path for gt.
36 | dataroot_lq (str): Data root path for lq.
37 | meta_info_file (str): Path for meta information file.
38 | io_backend (dict): IO backend type and other kwarg.
39 |
40 | num_frame (int): Window size for input frames.
41 | gt_size (int): Cropped patched size for gt patches.
42 | random_reverse (bool): Random reverse input frames.
43 | use_flip (bool): Use horizontal flips.
44 | use_rot (bool): Use rotation (use vertical flip and transposing h
45 | and w for implementation).
46 |
47 | scale (bool): Scale, which will be added automatically.
48 | """
49 |
50 | def __init__(self, opt):
51 | super(Vimeo90KDataset, self).__init__()
52 | self.opt = opt
53 | self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
54 | opt['dataroot_lq'])
55 |
56 | with open(opt['meta_info_file'], 'r') as fin:
57 | self.keys = [line.split(' ')[0] for line in fin]
58 |
59 | # file client (io backend)
60 | self.file_client = None
61 | self.io_backend_opt = opt['io_backend']
62 | self.is_lmdb = False
63 | if self.io_backend_opt['type'] == 'lmdb':
64 | self.is_lmdb = True
65 | self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
66 | self.io_backend_opt['client_keys'] = ['lq', 'gt']
67 |
68 | # indices of input images
69 | self.neighbor_list = [
70 | i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])
71 | ]
72 |
73 | # temporal augmentation configs
74 | self.random_reverse = opt['random_reverse']
75 | logger = get_root_logger()
76 | logger.info(f'Random reverse is {self.random_reverse}.')
77 |
78 | def __getitem__(self, index):
79 | if self.file_client is None:
80 | self.file_client = FileClient(
81 | self.io_backend_opt.pop('type'), **self.io_backend_opt)
82 |
83 | # random reverse
84 | if self.random_reverse and random.random() < 0.5:
85 | self.neighbor_list.reverse()
86 |
87 | scale = self.opt['scale']
88 | gt_size = self.opt['gt_size']
89 | key = self.keys[index]
90 | clip, seq = key.split('/') # key example: 00001/0001
91 |
92 | # get the GT frame (im4.png)
93 | if self.is_lmdb:
94 | img_gt_path = f'{key}/im4'
95 | else:
96 | img_gt_path = self.gt_root / clip / seq / 'im4.png'
97 | img_bytes = self.file_client.get(img_gt_path, 'gt')
98 | img_gt = imfrombytes(img_bytes, float32=True)
99 |
100 | # get the neighboring LQ frames
101 | img_lqs = []
102 | for neighbor in self.neighbor_list:
103 | if self.is_lmdb:
104 | img_lq_path = f'{clip}/{seq}/im{neighbor}'
105 | else:
106 | img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
107 | img_bytes = self.file_client.get(img_lq_path, 'lq')
108 | img_lq = imfrombytes(img_bytes, float32=True)
109 | img_lqs.append(img_lq)
110 |
111 | # randomly crop
112 | img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale,
113 | img_gt_path)
114 |
115 | # augmentation - flip, rotate
116 | img_lqs.append(img_gt)
117 | img_results = augment(img_lqs, self.opt['use_flip'],
118 | self.opt['use_rot'])
119 |
120 | img_results = img2tensor(img_results)
121 | img_lqs = torch.stack(img_results[0:-1], dim=0)
122 | img_gt = img_results[-1]
123 |
124 | # img_lqs: (t, c, h, w)
125 | # img_gt: (c, h, w)
126 | # key: str
127 | return {'lq': img_lqs, 'gt': img_gt, 'key': key}
128 |
129 | def __len__(self):
130 | return len(self.keys)
131 |
--------------------------------------------------------------------------------
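A quick check of the `neighbor_list` formula above: `i + (9 - num_frame) // 2` centers the window on frame index 4 (the clip's middle frame, `im4.png`), reproducing the table in the docstring:

```python
# Worked example of the neighbor_list computation from Vimeo90KDataset.
for num_frame in (1, 3, 5, 7):
    neighbors = [i + (9 - num_frame) // 2 for i in range(num_frame)]
    print(num_frame, neighbors)
# 1 [4]
# 3 [3, 4, 5]
# 5 [2, 3, 4, 5, 6]
# 7 [1, 2, 3, 4, 5, 6, 7]
```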
/basicsr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .niqe import calculate_niqe
2 | from .psnr_ssim import calculate_psnr, calculate_ssim
3 |
4 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
5 |
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/metric_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/metric_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/metric_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/metric_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/niqe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/niqe.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/niqe.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/niqe.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/psnr_ssim.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/__pycache__/psnr_ssim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/__pycache__/psnr_ssim.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/metrics/fid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from scipy import linalg
5 | from tqdm import tqdm
6 |
7 | from basicsr.models.archs.inception import InceptionV3
8 |
9 |
10 | def load_patched_inception_v3(device='cuda',
11 | resize_input=True,
12 | normalize_input=False):
13 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
14 | # does resize the input.
15 | inception = InceptionV3([3],
16 | resize_input=resize_input,
17 | normalize_input=normalize_input)
18 | inception = nn.DataParallel(inception).eval().to(device)
19 | return inception
20 |
21 |
22 | @torch.no_grad()
23 | def extract_inception_features(data_generator,
24 | inception,
25 | len_generator=None,
26 | device='cuda'):
27 | """Extract inception features.
28 |
29 | Args:
30 | data_generator (generator): A data generator.
31 | inception (nn.Module): Inception model.
32 | len_generator (int): Length of the data_generator to show the
33 | progressbar. Default: None.
34 | device (str): Device. Default: cuda.
35 |
36 | Returns:
37 | Tensor: Extracted features.
38 | """
39 | if len_generator is not None:
40 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
41 | else:
42 | pbar = None
43 | features = []
44 |
45 | for data in data_generator:
46 | if pbar:
47 | pbar.update(1)
48 | data = data.to(device)
49 | feature = inception(data)[0].view(data.shape[0], -1)
50 | features.append(feature.to('cpu'))
51 | if pbar:
52 | pbar.close()
53 | features = torch.cat(features, 0)
54 | return features
55 |
56 |
57 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
58 | """Numpy implementation of the Frechet Distance.
59 |
60 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
61 | and X_2 ~ N(mu_2, C_2) is
62 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
63 | Stable version by Dougal J. Sutherland.
64 |
65 | Args:
66 | mu1 (np.array): The sample mean over activations.
67 | sigma1 (np.array): The covariance matrix over activations for
68 | generated samples.
69 | mu2 (np.array): The sample mean over activations, precalculated on a
70 | representative data set.
71 | sigma2 (np.array): The covariance matrix over activations,
72 | precalculated on a representative data set.
73 |
74 | Returns:
75 | float: The Frechet Distance.
76 | """
77 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
78 | assert sigma1.shape == sigma2.shape, (
79 | 'Two covariances have different dimensions')
80 |
81 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
82 |
83 | # Product might be almost singular
84 | if not np.isfinite(cov_sqrt).all():
85 | print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
86 | 'of cov estimates')
87 | offset = np.eye(sigma1.shape[0]) * eps
88 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
89 |
90 | # Numerical error might give slight imaginary component
91 | if np.iscomplexobj(cov_sqrt):
92 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
93 | m = np.max(np.abs(cov_sqrt.imag))
94 | raise ValueError(f'Imaginary component {m}')
95 | cov_sqrt = cov_sqrt.real
96 |
97 | mean_diff = mu1 - mu2
98 | mean_norm = mean_diff @ mean_diff
99 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
100 | fid = mean_norm + trace
101 |
102 | return fid
103 |
--------------------------------------------------------------------------------
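A minimal sketch of how these helpers fit together (not part of the repository: `real_loader` and `fake_loader` are placeholder generators that yield image batches prepared the way InceptionV3 expects, and a CUDA device is assumed by the defaults):

    import numpy as np

    inception = load_patched_inception_v3()
    feat_real = extract_inception_features(real_loader, inception).numpy()
    feat_fake = extract_inception_features(fake_loader, inception).numpy()

    # sample statistics of the pooled inception activations
    mu_r, sigma_r = feat_real.mean(axis=0), np.cov(feat_real, rowvar=False)
    mu_f, sigma_f = feat_fake.mean(axis=0), np.cov(feat_fake, rowvar=False)
    print('FID:', calculate_fid(mu_f, sigma_f, mu_r, sigma_r))

--------------------------------------------------------------------------------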
/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from basicsr.utils.matlab_functions import bgr2ycbcr
4 |
5 |
6 | def reorder_image(img, input_order='HWC'):
7 | """Reorder images to 'HWC' order.
8 |
9 |     If the input image is of shape (h, w), return (h, w, 1);
10 |     If the input is of shape (c, h, w), return (h, w, c);
11 |     If the input is of shape (h, w, c), return it as is.
12 |
13 | Args:
14 | img (ndarray): Input image.
15 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
16 | If the input image shape is (h, w), input_order will not have
17 | effects. Default: 'HWC'.
18 |
19 | Returns:
20 | ndarray: reordered image.
21 | """
22 |
23 | if input_order not in ['HWC', 'CHW']:
24 | raise ValueError(
25 | f'Wrong input_order {input_order}. Supported input_orders are '
26 | "'HWC' and 'CHW'")
27 | if len(img.shape) == 2:
28 | img = img[..., None]
29 | if input_order == 'CHW':
30 | img = img.transpose(1, 2, 0)
31 | return img
32 |
33 |
34 | def to_y_channel(img):
35 | """Change to Y channel of YCbCr.
36 |
37 | Args:
38 | img (ndarray): Images with range [0, 255].
39 |
40 | Returns:
41 | (ndarray): Images with range [0, 255] (float type) without round.
42 | """
43 | img = img.astype(np.float32) / 255.
44 | if img.ndim == 3 and img.shape[2] == 3:
45 | img = bgr2ycbcr(img, y_only=True)
46 | img = img[..., None]
47 | return img * 255.
48 |
--------------------------------------------------------------------------------
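As a quick usage sketch (hypothetical arrays, not from the repository): reorder_image normalizes the memory layout before metric computation, and to_y_channel converts a [0, 255] BGR image to its luma plane:

    import numpy as np

    chw = np.random.rand(3, 64, 64)              # channel-first input
    hwc = reorder_image(chw, input_order='CHW')  # -> (64, 64, 3)
    y = to_y_channel(hwc * 255.)                 # -> (64, 64, 1), float in [0, 255]

--------------------------------------------------------------------------------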
/basicsr/metrics/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/metrics/niqe_pris_params.npz
--------------------------------------------------------------------------------
/basicsr/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/.DS_Store
--------------------------------------------------------------------------------
/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import get_root_logger, scandir
5 |
6 | # automatically scan and import model modules
7 | # scan all the files under the 'models' folder and collect files ending with
8 | # '_model.py'
9 | model_folder = osp.dirname(osp.abspath(__file__))
10 | model_filenames = [
11 | osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
12 | if v.endswith('_model.py')
13 | ]
14 | # import all the model modules
15 | _model_modules = [
16 | importlib.import_module(f'basicsr.models.{file_name}')
17 | for file_name in model_filenames
18 | ]
19 |
20 |
21 | def create_model(opt):
22 | """Create model.
23 |
24 | Args:
25 |         opt (dict): Configuration. It contains:
26 | model_type (str): Model type.
27 | """
28 | model_type = opt['model_type']
29 |     model_cls = None
30 |     # dynamic instantiation
31 |     for module in _model_modules:
32 | model_cls = getattr(module, model_type, None)
33 | if model_cls is not None:
34 | break
35 | if model_cls is None:
36 | raise ValueError(f'Model {model_type} is not found.')
37 |
38 | model = model_cls(opt)
39 |
40 | logger = get_root_logger()
41 | logger.info(f'Model [{model.__class__.__name__}] is created.')
42 | return model
43 |
--------------------------------------------------------------------------------
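For reference, create_model is driven entirely by the parsed YAML options; a hedged sketch (the model_type value below is illustrative and must name a class defined in one of the *_model.py files in this folder, e.g. image_restoration_model.py):

    from basicsr.models import create_model

    opt = {'model_type': 'ImageRestorationModel'}  # plus the rest of the parsed config
    model = create_model(opt)

--------------------------------------------------------------------------------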
/basicsr/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/base_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/base_model.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/base_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/base_model.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/image_restoration_model.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/image_restoration_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/image_restoration_model.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/lr_scheduler.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/__pycache__/lr_scheduler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/__pycache__/lr_scheduler.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import scandir
5 |
6 | # automatically scan and import arch modules
7 | # scan all the files under the 'archs' folder and collect files ending with
8 | # '_arch.py'
9 | arch_folder = osp.dirname(osp.abspath(__file__))
10 | arch_filenames = [
11 | osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
12 | if v.endswith('_arch.py')
13 | ]
14 | # import all the arch modules
15 | _arch_modules = [
16 | importlib.import_module(f'basicsr.models.archs.{file_name}')
17 | for file_name in arch_filenames
18 | ]
19 |
20 |
21 | def dynamic_instantiation(modules, cls_type, opt):
22 | """Dynamically instantiate class.
23 |
24 | Args:
25 | modules (list[importlib modules]): List of modules from importlib
26 | files.
27 | cls_type (str): Class type.
28 | opt (dict): Class initialization kwargs.
29 |
30 | Returns:
31 | class: Instantiated class.
32 | """
33 |     cls_ = None
34 |     for module in modules:
35 |         cls_ = getattr(module, cls_type, None)
36 | if cls_ is not None:
37 | break
38 | if cls_ is None:
39 | raise ValueError(f'{cls_type} is not found.')
40 | return cls_(**opt)
41 |
42 |
43 | def define_network(opt):
44 | network_type = opt.pop('type')
45 | net = dynamic_instantiation(_arch_modules, network_type, opt)
46 | return net
47 |
--------------------------------------------------------------------------------
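A hedged usage sketch: define_network mirrors the network section of a config, where 'type' must match a class defined in a *_arch.py file here (the kwargs below are illustrative, not the real HINT signature). Note that define_network pops 'type' and therefore mutates the dict it is given, so pass a copy if the options are reused:

    network_opt = {'type': 'HINT', 'inp_channels': 3, 'out_channels': 3}
    net = define_network(dict(network_opt))  # dict(...) keeps network_opt intact

--------------------------------------------------------------------------------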
/basicsr/models/archs/__pycache__/FPro_arch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/FPro_arch.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/HINT_arch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/HINT_arch.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/HINT_arch.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/HINT_arch.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/arch_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/graph_layers.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/local_arch.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/restormer_arch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/restormer_arch.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/archs/__pycache__/restormer_local_arch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/archs/__pycache__/restormer_local_arch.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss, FFTLoss)
2 |
3 | __all__ = [
4 |     'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss', 'FFTLoss',
5 | ]
6 |
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/loss_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/loss_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/loss_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/loss_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/models/losses/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/models/losses/loss_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from torch.nn import functional as F
3 |
4 |
5 | def reduce_loss(loss, reduction):
6 | """Reduce loss as specified.
7 |
8 | Args:
9 | loss (Tensor): Elementwise loss tensor.
10 | reduction (str): Options are 'none', 'mean' and 'sum'.
11 |
12 | Returns:
13 | Tensor: Reduced loss tensor.
14 | """
15 | reduction_enum = F._Reduction.get_enum(reduction)
16 | # none: 0, elementwise_mean:1, sum: 2
17 | if reduction_enum == 0:
18 | return loss
19 | elif reduction_enum == 1:
20 | return loss.mean()
21 | else:
22 | return loss.sum()
23 |
24 |
25 | def weight_reduce_loss(loss, weight=None, reduction='mean'):
26 | """Apply element-wise weight and reduce loss.
27 |
28 | Args:
29 | loss (Tensor): Element-wise loss.
30 | weight (Tensor): Element-wise weights. Default: None.
31 | reduction (str): Same as built-in losses of PyTorch. Options are
32 | 'none', 'mean' and 'sum'. Default: 'mean'.
33 |
34 | Returns:
35 | Tensor: Loss values.
36 | """
37 | # if weight is specified, apply element-wise weight
38 | if weight is not None:
39 | assert weight.dim() == loss.dim()
40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41 | loss = loss * weight
42 |
43 | # if weight is not specified or reduction is sum, just reduce the loss
44 | if weight is None or reduction == 'sum':
45 | loss = reduce_loss(loss, reduction)
46 | # if reduction is mean, then compute mean over weight region
47 | elif reduction == 'mean':
48 | if weight.size(1) > 1:
49 | weight = weight.sum()
50 | else:
51 | weight = weight.sum() * loss.size(1)
52 | loss = loss.sum() / weight
53 |
54 | return loss
55 |
56 |
57 | def weighted_loss(loss_func):
58 | """Create a weighted version of a given loss function.
59 |
60 | To use this decorator, the loss function must have the signature like
61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
62 | element-wise loss without any reduction. This decorator will add weight
63 | and reduction arguments to the function. The decorated function will have
64 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
65 | **kwargs)`.
66 |
67 | :Example:
68 |
69 | >>> import torch
70 |     >>> @weighted_loss
71 |     ... def l1_loss(pred, target):
72 |     ...     return (pred - target).abs()
73 |
74 | >>> pred = torch.Tensor([0, 2, 3])
75 | >>> target = torch.Tensor([1, 1, 1])
76 | >>> weight = torch.Tensor([1, 0, 1])
77 |
78 | >>> l1_loss(pred, target)
79 | tensor(1.3333)
80 | >>> l1_loss(pred, target, weight)
81 | tensor(1.5000)
82 | >>> l1_loss(pred, target, reduction='none')
83 | tensor([1., 1., 2.])
84 | >>> l1_loss(pred, target, weight, reduction='sum')
85 | tensor(3.)
86 | """
87 |
88 | @functools.wraps(loss_func)
89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90 | # get element-wise loss
91 | loss = loss_func(pred, target, **kwargs)
92 | loss = weight_reduce_loss(loss, weight, reduction)
93 | return loss
94 |
95 | return wrapper
96 |
--------------------------------------------------------------------------------
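As an illustration of the decorator (this mirrors the charbonnier_loss variant left commented out in losses.py below; a sketch, not shipped code):

    import torch

    @weighted_loss
    def charbonnier_loss(pred, target, eps=1e-12):
        return torch.sqrt((pred - target) ** 2 + eps)

    pred, target = torch.zeros(1, 1, 2, 2), torch.ones(1, 1, 2, 2)
    charbonnier_loss(pred, target)                    # scalar mean over all elements
    charbonnier_loss(pred, target, reduction='none')  # element-wise, no reduction

--------------------------------------------------------------------------------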
/basicsr/models/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 |
6 | from basicsr.models.losses.loss_util import weighted_loss
7 |
8 | _reduction_modes = ['none', 'mean', 'sum']
9 |
10 |
11 | @weighted_loss
12 | def l1_loss(pred, target):
13 | return F.l1_loss(pred, target, reduction='none')
14 |
15 |
16 | @weighted_loss
17 | def mse_loss(pred, target):
18 | return F.mse_loss(pred, target, reduction='none')
19 |
20 |
21 | # @weighted_loss
22 | # def charbonnier_loss(pred, target, eps=1e-12):
23 | # return torch.sqrt((pred - target)**2 + eps)
24 |
25 |
26 | class L1Loss(nn.Module):
27 | """L1 (mean absolute error, MAE) loss.
28 |
29 | Args:
30 | loss_weight (float): Loss weight for L1 loss. Default: 1.0.
31 | reduction (str): Specifies the reduction to apply to the output.
32 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
33 | """
34 |
35 | def __init__(self, loss_weight=1.0, reduction='mean'):
36 | super(L1Loss, self).__init__()
37 | if reduction not in ['none', 'mean', 'sum']:
38 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
39 | f'Supported ones are: {_reduction_modes}')
40 |
41 | self.loss_weight = loss_weight
42 | self.reduction = reduction
43 |
44 | def forward(self, pred, target, weight=None, **kwargs):
45 | """
46 | Args:
47 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
48 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
49 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
50 | weights. Default: None.
51 | """
52 | return self.loss_weight * l1_loss(
53 | pred, target, weight, reduction=self.reduction)
54 |
55 |
56 | class FFTLoss(nn.Module):
57 | """L1 loss in frequency domain with FFT.
58 |
59 | Args:
60 | loss_weight (float): Loss weight for FFT loss. Default: 1.0.
61 | reduction (str): Specifies the reduction to apply to the output.
62 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
63 | """
64 |
65 | def __init__(self, loss_weight=1.0, reduction='mean'):
66 | super(FFTLoss, self).__init__()
67 | if reduction not in ['none', 'mean', 'sum']:
68 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
69 |
70 | self.loss_weight = loss_weight
71 | self.reduction = reduction
72 |
73 | def forward(self, pred, target, weight=None, **kwargs):
74 | """
75 | Args:
76 | pred (Tensor): of shape (..., C, H, W). Predicted tensor.
77 | target (Tensor): of shape (..., C, H, W). Ground truth tensor.
78 | weight (Tensor, optional): of shape (..., C, H, W). Element-wise
79 | weights. Default: None.
80 | """
81 |
82 | pred_fft = torch.fft.fft2(pred, dim=(-2, -1))
83 | pred_fft = torch.stack([pred_fft.real, pred_fft.imag], dim=-1)
84 | target_fft = torch.fft.fft2(target, dim=(-2, -1))
85 | target_fft = torch.stack([target_fft.real, target_fft.imag], dim=-1)
86 | return self.loss_weight * l1_loss(pred_fft, target_fft, weight, reduction=self.reduction)
87 |
88 |
89 | class MSELoss(nn.Module):
90 | """MSE (L2) loss.
91 |
92 | Args:
93 | loss_weight (float): Loss weight for MSE loss. Default: 1.0.
94 | reduction (str): Specifies the reduction to apply to the output.
95 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
96 | """
97 |
98 | def __init__(self, loss_weight=1.0, reduction='mean'):
99 | super(MSELoss, self).__init__()
100 | if reduction not in ['none', 'mean', 'sum']:
101 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
102 | f'Supported ones are: {_reduction_modes}')
103 |
104 | self.loss_weight = loss_weight
105 | self.reduction = reduction
106 |
107 | def forward(self, pred, target, weight=None, **kwargs):
108 | """
109 | Args:
110 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
111 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
112 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
113 | weights. Default: None.
114 | """
115 | return self.loss_weight * mse_loss(
116 | pred, target, weight, reduction=self.reduction)
117 |
118 | class PSNRLoss(nn.Module):
119 |
120 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
121 | super(PSNRLoss, self).__init__()
122 | assert reduction == 'mean'
123 | self.loss_weight = loss_weight
124 | self.scale = 10 / np.log(10)
125 | self.toY = toY
126 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
127 | self.first = True
128 |
129 | def forward(self, pred, target):
130 | assert len(pred.size()) == 4
131 | if self.toY:
132 | if self.first:
133 | self.coef = self.coef.to(pred.device)
134 | self.first = False
135 |
136 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
137 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
138 |
139 | pred, target = pred / 255., target / 255.
140 |
141 |         # 10 * log10(MSE) equals -PSNR, so minimizing this loss maximizes PSNR
142 |
143 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
144 |
145 | class CharbonnierLoss(nn.Module):
146 | """Charbonnier Loss (L1)"""
147 |
148 |     def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3):
149 |         super(CharbonnierLoss, self).__init__()
150 |         self.loss_weight, self.eps = loss_weight, eps
151 |
152 |     def forward(self, x, y):
153 |         diff = x - y
154 |         # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
155 |         loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
156 |         return self.loss_weight * loss
157 |
--------------------------------------------------------------------------------
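A short sketch of how these losses combine in practice (the weights are illustrative; actual values come from the YAML configs):

    import torch

    pred = torch.rand(2, 3, 64, 64)
    target = torch.rand(2, 3, 64, 64)

    pixel_loss = L1Loss(loss_weight=1.0)(pred, target)
    freq_loss = FFTLoss(loss_weight=0.1)(pred, target)
    total = pixel_loss + freq_loss

    # PSNRLoss returns 10 * log10(MSE), i.e. the negative PSNR in dB
    psnr_loss = PSNRLoss()(pred, target)

--------------------------------------------------------------------------------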
/basicsr/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from os import path as osp
4 |
5 | from basicsr.data import create_dataloader, create_dataset
6 | from basicsr.models import create_model
7 | from basicsr.train import parse_options
8 | from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
9 | make_exp_dirs)
10 | from basicsr.utils.options import dict2str
11 |
12 |
13 | def main():
14 |     # parse options, set distributed setting, set random seed
15 | opt = parse_options(is_train=False)
16 |
17 | torch.backends.cudnn.benchmark = True
18 | # torch.backends.cudnn.deterministic = True
19 |
20 | # mkdir and initialize loggers
21 | make_exp_dirs(opt)
22 | log_file = osp.join(opt['path']['log'],
23 | f"test_{opt['name']}_{get_time_str()}.log")
24 | logger = get_root_logger(
25 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
26 | logger.info(get_env_info())
27 | logger.info(dict2str(opt))
28 |
29 | # create test dataset and dataloader
30 | test_loaders = []
31 | for phase, dataset_opt in sorted(opt['datasets'].items()):
32 | test_set = create_dataset(dataset_opt)
33 | test_loader = create_dataloader(
34 | test_set,
35 | dataset_opt,
36 | num_gpu=opt['num_gpu'],
37 | dist=opt['dist'],
38 | sampler=None,
39 | seed=opt['manual_seed'])
40 | logger.info(
41 | f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
42 | test_loaders.append(test_loader)
43 |
44 | # create model
45 | model = create_model(opt)
46 |
47 | for test_loader in test_loaders:
48 | test_set_name = test_loader.dataset.opt['name']
49 | logger.info(f'Testing {test_set_name}...')
50 | rgb2bgr = opt['val'].get('rgb2bgr', True)
51 |         # whether to use uint8 images to compute metrics
52 | use_image = opt['val'].get('use_image', True)
53 | model.validation(
54 | test_loader,
55 | current_iter=opt['name'],
56 | tb_logger=None,
57 | save_img=opt['val']['save_img'],
58 | rgb2bgr=rgb2bgr, use_image=use_image)
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_client import FileClient
2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP
3 | from .logger import (MessageLogger, get_env_info, get_root_logger,
4 | init_tb_logger, init_wandb_logger)
5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
6 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
7 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)
8 |
9 | __all__ = [
10 | # file_client.py
11 | 'FileClient',
12 | # img_util.py
13 | 'img2tensor',
14 | 'tensor2img',
15 | 'imfrombytes',
16 | 'imwrite',
17 | 'crop_border',
18 | # logger.py
19 | 'MessageLogger',
20 | 'init_tb_logger',
21 | 'init_wandb_logger',
22 | 'get_root_logger',
23 | 'get_env_info',
24 | # misc.py
25 | 'set_random_seed',
26 | 'get_time_str',
27 | 'mkdir_and_rename',
28 | 'make_exp_dirs',
29 | 'scandir',
30 | 'check_resume',
31 | 'sizeof_fmt',
32 | 'padding',
33 | 'padding_DP',
34 | 'imfrombytesDP',
35 | 'create_lmdb_for_reds',
36 | 'create_lmdb_for_gopro',
37 | 'create_lmdb_for_rain13k',
38 | ]
39 |
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/create_lmdb.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/create_lmdb.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/create_lmdb.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/create_lmdb.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/dist_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/dist_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/dist_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/dist_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/file_client.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/file_client.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/file_client.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/file_client.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/flow_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/flow_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/img_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/img_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/img_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/img_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/lmdb_util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/lmdb_util.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/lmdb_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/lmdb_util.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/logger.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/matlab_functions.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/matlab_functions.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/matlab_functions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/matlab_functions.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/misc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/misc.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/misc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/misc.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/options.cpython-37.pyc
--------------------------------------------------------------------------------
/basicsr/utils/__pycache__/options.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/HINT/acb541738b59a4ec6660f6e8f6becd4fabd59eda/basicsr/utils/__pycache__/options.cpython-38.pyc
--------------------------------------------------------------------------------
/basicsr/utils/bundle_submissions.py:
--------------------------------------------------------------------------------
1 | # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de)
2 |
3 | # This file is part of the implementation as described in the CVPR 2017 paper:
4 | # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs.
5 | # Please see the file LICENSE.txt for the license governing this code.
6 |
7 |
8 | import numpy as np
9 | import scipy.io as sio
10 | import os
11 | import h5py
12 |
13 | def bundle_submissions_raw(submission_folder,session):
14 | '''
15 | Bundles submission data for raw denoising
16 |
17 | submission_folder Folder where denoised images reside
18 |
19 |     Output is written to the <session> subfolder. Please submit
20 |     the content of this folder.
21 | '''
22 |
23 | out_folder = os.path.join(submission_folder, session)
24 | # out_folder = os.path.join(submission_folder, "bundled/")
25 | try:
26 | os.mkdir(out_folder)
27 |     except OSError: pass
28 |
29 | israw = True
30 | eval_version="1.0"
31 |
32 | for i in range(50):
33 |         Idenoised = np.zeros((20,), dtype=object)  # np.object was removed in NumPy >= 1.24
34 | for bb in range(20):
35 | filename = '%04d_%02d.mat'%(i+1,bb+1)
36 | s = sio.loadmat(os.path.join(submission_folder,filename))
37 | Idenoised_crop = s["Idenoised_crop"]
38 | Idenoised[bb] = Idenoised_crop
39 | filename = '%04d.mat'%(i+1)
40 | sio.savemat(os.path.join(out_folder, filename),
41 | {"Idenoised": Idenoised,
42 | "israw": israw,
43 | "eval_version": eval_version},
44 | )
45 |
46 | def bundle_submissions_srgb(submission_folder,session):
47 | '''
48 | Bundles submission data for sRGB denoising
49 |
50 | submission_folder Folder where denoised images reside
51 |
52 |     Output is written to the <session> subfolder. Please submit
53 |     the content of this folder.
54 | '''
55 | out_folder = os.path.join(submission_folder, session)
56 | # out_folder = os.path.join(submission_folder, "bundled/")
57 | try:
58 | os.mkdir(out_folder)
59 |     except OSError: pass
60 | israw = False
61 | eval_version="1.0"
62 |
63 | for i in range(50):
64 |         Idenoised = np.zeros((20,), dtype=object)
65 | for bb in range(20):
66 | filename = '%04d_%02d.mat'%(i+1,bb+1)
67 | s = sio.loadmat(os.path.join(submission_folder,filename))
68 | Idenoised_crop = s["Idenoised_crop"]
69 | Idenoised[bb] = Idenoised_crop
70 | filename = '%04d.mat'%(i+1)
71 | sio.savemat(os.path.join(out_folder, filename),
72 | {"Idenoised": Idenoised,
73 | "israw": israw,
74 | "eval_version": eval_version},
75 | )
76 |
77 |
78 |
79 | def bundle_submissions_srgb_v1(submission_folder,session):
80 | '''
81 | Bundles submission data for sRGB denoising
82 |
83 | submission_folder Folder where denoised images reside
84 |
85 |     Output is written to the <session> subfolder. Please submit
86 |     the content of this folder.
87 | '''
88 | out_folder = os.path.join(submission_folder, session)
89 | # out_folder = os.path.join(submission_folder, "bundled/")
90 | try:
91 | os.mkdir(out_folder)
92 |     except OSError: pass
93 | israw = False
94 | eval_version="1.0"
95 |
96 | for i in range(50):
97 |         Idenoised = np.zeros((20,), dtype=object)
98 | for bb in range(20):
99 | filename = '%04d_%d.mat'%(i+1,bb+1)
100 | s = sio.loadmat(os.path.join(submission_folder,filename))
101 | Idenoised_crop = s["Idenoised_crop"]
102 | Idenoised[bb] = Idenoised_crop
103 | filename = '%04d.mat'%(i+1)
104 | sio.savemat(os.path.join(out_folder, filename),
105 | {"Idenoised": Idenoised,
106 | "israw": israw,
107 | "eval_version": eval_version},
108 | )
--------------------------------------------------------------------------------
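A hedged usage note: after a model has written the per-block .mat files (0001_01.mat through 0050_20.mat) into a results folder, bundling is a single call; the paths below are placeholders:

    bundle_submissions_srgb('./results/dnd_srgb/', 'bundled')

--------------------------------------------------------------------------------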
/basicsr/utils/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from os import path as osp
3 |
4 | from basicsr.utils import scandir
5 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs
6 |
7 | def prepare_keys(folder_path, suffix='png'):
8 | """Prepare image path list and keys for DIV2K dataset.
9 |
10 | Args:
11 | folder_path (str): Folder path.
12 |
13 | Returns:
14 | list[str]: Image path list.
15 | list[str]: Key list.
16 | """
17 | print('Reading image path list ...')
18 | img_path_list = sorted(
19 | list(scandir(folder_path, suffix=suffix, recursive=False)))
20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)]
21 |
22 | return img_path_list, keys
23 |
24 | def create_lmdb_for_reds():
25 | folder_path = './datasets/REDS/val/sharp_300'
26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb'
27 | img_path_list, keys = prepare_keys(folder_path, 'png')
28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
29 | #
30 | folder_path = './datasets/REDS/val/blur_300'
31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb'
32 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
34 |
35 | folder_path = './datasets/REDS/train/train_sharp'
36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb'
37 | img_path_list, keys = prepare_keys(folder_path, 'png')
38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
39 |
40 | folder_path = './datasets/REDS/train/train_blur_jpeg'
41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb'
42 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
44 |
45 |
46 | def create_lmdb_for_gopro():
47 | folder_path = './datasets/GoPro/train/blur_crops'
48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb'
49 |
50 | img_path_list, keys = prepare_keys(folder_path, 'png')
51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
52 |
53 | folder_path = './datasets/GoPro/train/sharp_crops'
54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb'
55 |
56 | img_path_list, keys = prepare_keys(folder_path, 'png')
57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
58 |
59 | folder_path = './datasets/GoPro/test/target'
60 | lmdb_path = './datasets/GoPro/test/target.lmdb'
61 |
62 | img_path_list, keys = prepare_keys(folder_path, 'png')
63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
64 |
65 | folder_path = './datasets/GoPro/test/input'
66 | lmdb_path = './datasets/GoPro/test/input.lmdb'
67 |
68 | img_path_list, keys = prepare_keys(folder_path, 'png')
69 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
70 |
71 | def create_lmdb_for_rain13k():
72 | folder_path = './datasets/Rain13k/train/input'
73 | lmdb_path = './datasets/Rain13k/train/input.lmdb'
74 |
75 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
77 |
78 | folder_path = './datasets/Rain13k/train/target'
79 | lmdb_path = './datasets/Rain13k/train/target.lmdb'
80 |
81 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
83 |
84 | def create_lmdb_for_SIDD():
85 | folder_path = './datasets/SIDD/train/input_crops'
86 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb'
87 |
88 | img_path_list, keys = prepare_keys(folder_path, 'PNG')
89 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
90 |
91 | folder_path = './datasets/SIDD/train/gt_crops'
92 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb'
93 |
94 | img_path_list, keys = prepare_keys(folder_path, 'PNG')
95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
96 |     import os, cv2, scipy.io as scio
97 |     from tqdm import tqdm  # for val: the .mat-block conversion below needs these
98 | folder_path = './datasets/SIDD/val/input_crops'
99 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb'
100 | mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat'
101 | if not osp.exists(folder_path):
102 | os.makedirs(folder_path)
103 | assert osp.exists(mat_path)
104 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb']
105 | N, B, H ,W, C = data.shape
106 | data = data.reshape(N*B, H, W, C)
107 | for i in tqdm(range(N*B)):
108 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
109 | img_path_list, keys = prepare_keys(folder_path, 'png')
110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
111 |
112 | folder_path = './datasets/SIDD/val/gt_crops'
113 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb'
114 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat'
115 | if not osp.exists(folder_path):
116 | os.makedirs(folder_path)
117 | assert osp.exists(mat_path)
118 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb']
119 | N, B, H ,W, C = data.shape
120 | data = data.reshape(N*B, H, W, C)
121 | for i in tqdm(range(N*B)):
122 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
123 | img_path_list, keys = prepare_keys(folder_path, 'png')
124 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
125 |
--------------------------------------------------------------------------------
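The same prepare_keys/make_lmdb_from_imgs pattern extends to any flat folder of images; a sketch with a hypothetical dataset path:

    folder_path = './datasets/MyData/train/input'
    img_path_list, keys = prepare_keys(folder_path, suffix='png')
    make_lmdb_from_imgs(folder_path, folder_path + '.lmdb', img_path_list, keys)

--------------------------------------------------------------------------------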
/basicsr/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2 | import functools
3 | import os
4 | import subprocess
5 | import torch
6 | import torch.distributed as dist
7 | import torch.multiprocessing as mp
8 |
9 |
10 | def init_dist(launcher, backend='nccl', **kwargs):
11 | if mp.get_start_method(allow_none=True) is None:
12 | mp.set_start_method('spawn')
13 | if launcher == 'pytorch':
14 | _init_dist_pytorch(backend, **kwargs)
15 | elif launcher == 'slurm':
16 | _init_dist_slurm(backend, **kwargs)
17 | else:
18 | raise ValueError(f'Invalid launcher type: {launcher}')
19 |
20 |
21 | def _init_dist_pytorch(backend, **kwargs):
22 | rank = int(os.environ['RANK'])
23 | num_gpus = torch.cuda.device_count()
24 | torch.cuda.set_device(rank % num_gpus)
25 | dist.init_process_group(backend=backend, **kwargs)
26 |
27 |
28 | def _init_dist_slurm(backend, port=None):
29 | """Initialize slurm distributed training environment.
30 |
31 | If argument ``port`` is not specified, then the master port will be system
32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33 | environment variable, then a default port ``29500`` will be used.
34 |
35 | Args:
36 | backend (str): Backend of torch.distributed.
37 | port (int, optional): Master port. Defaults to None.
38 | """
39 | proc_id = int(os.environ['SLURM_PROCID'])
40 | ntasks = int(os.environ['SLURM_NTASKS'])
41 | node_list = os.environ['SLURM_NODELIST']
42 | num_gpus = torch.cuda.device_count()
43 | torch.cuda.set_device(proc_id % num_gpus)
44 | addr = subprocess.getoutput(
45 | f'scontrol show hostname {node_list} | head -n1')
46 | # specify master port
47 | if port is not None:
48 | os.environ['MASTER_PORT'] = str(port)
49 | elif 'MASTER_PORT' in os.environ:
50 | pass # use MASTER_PORT in the environment variable
51 | else:
52 | # 29500 is torch.distributed default port
53 | os.environ['MASTER_PORT'] = '29500'
54 | os.environ['MASTER_ADDR'] = addr
55 | os.environ['WORLD_SIZE'] = str(ntasks)
56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
57 | os.environ['RANK'] = str(proc_id)
58 | dist.init_process_group(backend=backend)
59 |
60 |
61 | def get_dist_info():
62 | if dist.is_available():
63 | initialized = dist.is_initialized()
64 | else:
65 | initialized = False
66 | if initialized:
67 | rank = dist.get_rank()
68 | world_size = dist.get_world_size()
69 | else:
70 | rank = 0
71 | world_size = 1
72 | return rank, world_size
73 |
74 |
75 | def master_only(func):
76 |
77 | @functools.wraps(func)
78 | def wrapper(*args, **kwargs):
79 | rank, _ = get_dist_info()
80 | if rank == 0:
81 | return func(*args, **kwargs)
82 |
83 | return wrapper
84 |
--------------------------------------------------------------------------------
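A minimal launch-time sketch (assuming the usual torchrun/torch.distributed.launch environment variables such as RANK are set, since _init_dist_pytorch reads them):

    init_dist('pytorch')
    rank, world_size = get_dist_info()

    @master_only
    def log_once():
        print(f'rank 0 of {world_size} processes')

    log_once()  # runs only on rank 0; other ranks return None

--------------------------------------------------------------------------------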
/basicsr/utils/download_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import requests
3 | from tqdm import tqdm
4 |
5 | from .misc import sizeof_fmt
6 |
7 |
8 | def download_file_from_google_drive(file_id, save_path):
9 | """Download files from google drive.
10 |
11 | Ref:
12 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
13 |
14 | Args:
15 | file_id (str): File id.
16 | save_path (str): Save path.
17 | """
18 |
19 | session = requests.Session()
20 | URL = 'https://docs.google.com/uc?export=download'
21 | params = {'id': file_id}
22 |
23 | response = session.get(URL, params=params, stream=True)
24 | token = get_confirm_token(response)
25 | if token:
26 | params['confirm'] = token
27 | response = session.get(URL, params=params, stream=True)
28 |
29 | # get file size
30 | response_file_size = session.get(
31 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
32 | if 'Content-Range' in response_file_size.headers:
33 | file_size = int(
34 | response_file_size.headers['Content-Range'].split('/')[1])
35 | else:
36 | file_size = None
37 |
38 | save_response_content(response, save_path, file_size)
39 |
40 |
41 | def get_confirm_token(response):
42 | for key, value in response.cookies.items():
43 | if key.startswith('download_warning'):
44 | return value
45 | return None
46 |
47 |
48 | def save_response_content(response,
49 | destination,
50 | file_size=None,
51 | chunk_size=32768):
52 | if file_size is not None:
53 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
54 |
55 | readable_file_size = sizeof_fmt(file_size)
56 | else:
57 | pbar = None
58 |
59 | with open(destination, 'wb') as f:
60 | downloaded_size = 0
61 | for chunk in response.iter_content(chunk_size):
62 |             downloaded_size += len(chunk)  # count the bytes actually received
63 | if pbar is not None:
64 | pbar.update(1)
65 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
66 | f'/ {readable_file_size}')
67 | if chunk: # filter out keep-alive new chunks
68 | f.write(chunk)
69 | if pbar is not None:
70 | pbar.close()
71 |
--------------------------------------------------------------------------------
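Usage is a single call; a sketch with placeholder values (the file id below is hypothetical, not a real asset):

    download_file_from_google_drive('1a2b3c4d5e', './pretrained/model.pth')

--------------------------------------------------------------------------------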
/basicsr/utils/file_client.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
5 | class BaseStorageBackend(metaclass=ABCMeta):
6 | """Abstract class of storage backends.
7 |
8 | All backends need to implement two apis: ``get()`` and ``get_text()``.
9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10 | as texts.
11 | """
12 |
13 | @abstractmethod
14 | def get(self, filepath):
15 | pass
16 |
17 | @abstractmethod
18 | def get_text(self, filepath):
19 | pass
20 |
21 |
22 | class MemcachedBackend(BaseStorageBackend):
23 | """Memcached storage backend.
24 |
25 | Attributes:
26 | server_list_cfg (str): Config file for memcached server list.
27 | client_cfg (str): Config file for memcached client.
28 | sys_path (str | None): Additional path to be appended to `sys.path`.
29 | Default: None.
30 | """
31 |
32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33 | if sys_path is not None:
34 | import sys
35 | sys.path.append(sys_path)
36 | try:
37 | import mc
38 | except ImportError:
39 | raise ImportError(
40 | 'Please install memcached to enable MemcachedBackend.')
41 |
42 | self.server_list_cfg = server_list_cfg
43 | self.client_cfg = client_cfg
44 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
45 | self.client_cfg)
46 |         # mc.pyvector serves as a pointer to a memory cache
47 | self._mc_buffer = mc.pyvector()
48 |
49 | def get(self, filepath):
50 | filepath = str(filepath)
51 | import mc
52 | self._client.Get(filepath, self._mc_buffer)
53 | value_buf = mc.ConvertBuffer(self._mc_buffer)
54 | return value_buf
55 |
56 | def get_text(self, filepath):
57 | raise NotImplementedError
58 |
59 |
60 | class HardDiskBackend(BaseStorageBackend):
61 | """Raw hard disks storage backend."""
62 |
63 | def get(self, filepath):
64 | filepath = str(filepath)
65 | with open(filepath, 'rb') as f:
66 | value_buf = f.read()
67 | return value_buf
68 |
69 | def get_text(self, filepath):
70 | filepath = str(filepath)
71 | with open(filepath, 'r') as f:
72 | value_buf = f.read()
73 | return value_buf
74 |
75 |
76 | class LmdbBackend(BaseStorageBackend):
77 | """Lmdb storage backend.
78 |
79 | Args:
80 | db_paths (str | list[str]): Lmdb database paths.
81 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
82 | readonly (bool, optional): Lmdb environment parameter. If True,
83 | disallow any write operations. Default: True.
84 | lock (bool, optional): Lmdb environment parameter. If False, when
85 | concurrent access occurs, do not lock the database. Default: False.
86 | readahead (bool, optional): Lmdb environment parameter. If False,
87 | disable the OS filesystem readahead mechanism, which may improve
88 | random read performance when a database is larger than RAM.
89 | Default: False.
90 |
91 | Attributes:
92 | db_paths (list): Lmdb database path.
93 | _client (list): A list of several lmdb envs.
94 | """
95 |
96 | def __init__(self,
97 | db_paths,
98 | client_keys='default',
99 | readonly=True,
100 | lock=False,
101 | readahead=False,
102 | **kwargs):
103 | try:
104 | import lmdb
105 | except ImportError:
106 | raise ImportError('Please install lmdb to enable LmdbBackend.')
107 |
108 | if isinstance(client_keys, str):
109 | client_keys = [client_keys]
110 |
111 | if isinstance(db_paths, list):
112 | self.db_paths = [str(v) for v in db_paths]
113 | elif isinstance(db_paths, str):
114 | self.db_paths = [str(db_paths)]
115 | assert len(client_keys) == len(self.db_paths), (
116 | 'client_keys and db_paths should have the same length, '
117 | f'but received {len(client_keys)} and {len(self.db_paths)}.')
118 |
119 | self._client = {}
120 |
121 | for client, path in zip(client_keys, self.db_paths):
122 | self._client[client] = lmdb.open(
123 | path,
124 | readonly=readonly,
125 | lock=lock,
126 | readahead=readahead,
127 | map_size=8*1024*10485760,
128 | # max_readers=1,
129 | **kwargs)
130 |
131 | def get(self, filepath, client_key):
132 | """Get values according to the filepath from one lmdb named client_key.
133 |
134 | Args:
135 | filepath (str | obj:`Path`): Here, filepath is the lmdb key.
136 |             client_key (str): Used for distinguishing different lmdb envs.
137 | """
138 | filepath = str(filepath)
139 | assert client_key in self._client, (f'client_key {client_key} is not '
140 | 'in lmdb clients.')
141 | client = self._client[client_key]
142 | with client.begin(write=False) as txn:
143 | value_buf = txn.get(filepath.encode('ascii'))
144 | return value_buf
145 |
146 | def get_text(self, filepath):
147 | raise NotImplementedError
148 |
149 |
150 | class FileClient(object):
151 | """A general file client to access files in different backend.
152 |
153 | The client loads a file or text in a specified backend from its path
154 | and return it as a binary file. it can also register other backend
155 | accessor with a given name and backend class.
156 |
157 | Attributes:
158 | backend (str): The storage backend type. Options are "disk",
159 | "memcached" and "lmdb".
160 | client (:obj:`BaseStorageBackend`): The backend object.
161 | """
162 |
163 | _backends = {
164 | 'disk': HardDiskBackend,
165 | 'memcached': MemcachedBackend,
166 | 'lmdb': LmdbBackend,
167 | }
168 |
169 | def __init__(self, backend='disk', **kwargs):
170 | if backend not in self._backends:
171 | raise ValueError(
172 | f'Backend {backend} is not supported. Currently supported ones'
173 | f' are {list(self._backends.keys())}')
174 | self.backend = backend
175 | self.client = self._backends[backend](**kwargs)
176 |
177 | def get(self, filepath, client_key='default'):
178 | # client_key is used only for lmdb, where different fileclients have
179 | # different lmdb environments.
180 | if self.backend == 'lmdb':
181 | return self.client.get(filepath, client_key)
182 | else:
183 | return self.client.get(filepath)
184 |
185 | def get_text(self, filepath):
186 | return self.client.get_text(filepath)
187 |
--------------------------------------------------------------------------------
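A hedged sketch of the two common backends (paths and keys are placeholders; the lmdb variant mirrors how paired datasets pass db_paths/client_keys through FileClient):

    client = FileClient('disk')
    img_bytes = client.get('datasets/demo/0001.png')

    lmdb_client = FileClient(
        'lmdb', db_paths=['datasets/demo/input.lmdb'], client_keys=['lq'])
    img_bytes = lmdb_client.get('0001', client_key='lq')

--------------------------------------------------------------------------------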
/basicsr/utils/flow_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
2 | import cv2
3 | import numpy as np
4 | import os
5 |
6 |
7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
8 | """Read an optical flow map.
9 |
10 | Args:
11 | flow_path (ndarray or str): Flow path.
12 | quantize (bool): whether to read quantized pair, if set to True,
13 | remaining args will be passed to :func:`dequantize_flow`.
14 | concat_axis (int): The axis that dx and dy are concatenated,
15 | can be either 0 or 1. Ignored if quantize is False.
16 |
17 | Returns:
18 | ndarray: Optical flow represented as a (h, w, 2) numpy array
19 | """
20 | if quantize:
21 | assert concat_axis in [0, 1]
22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
23 | if cat_flow.ndim != 2:
24 | raise IOError(f'{flow_path} is not a valid quantized flow file, '
25 | f'its dimension is {cat_flow.ndim}.')
26 | assert cat_flow.shape[concat_axis] % 2 == 0
27 | dx, dy = np.split(cat_flow, 2, axis=concat_axis)
28 | flow = dequantize_flow(dx, dy, *args, **kwargs)
29 | else:
30 | with open(flow_path, 'rb') as f:
31 | try:
32 | header = f.read(4).decode('utf-8')
33 | except Exception:
34 | raise IOError(f'Invalid flow file: {flow_path}')
35 | else:
36 | if header != 'PIEH':
37 | raise IOError(f'Invalid flow file: {flow_path}, '
38 | 'header does not contain PIEH')
39 |
40 | w = np.fromfile(f, np.int32, 1).squeeze()
41 | h = np.fromfile(f, np.int32, 1).squeeze()
42 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
43 |
44 | return flow.astype(np.float32)
45 |
46 |
47 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
48 | """Write optical flow to file.
49 |
50 | If the flow is not quantized, it will be saved as a .flo file losslessly,
51 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
52 | will be concatenated horizontally into a single image if quantize is True.)
53 |
54 | Args:
55 | flow (ndarray): (h, w, 2) array of optical flow.
56 | filename (str): Output filepath.
57 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg
58 | images. If set to True, remaining args will be passed to
59 | :func:`quantize_flow`.
60 | concat_axis (int): The axis that dx and dy are concatenated,
61 | can be either 0 or 1. Ignored if quantize is False.
62 | """
63 | if not quantize:
64 | with open(filename, 'wb') as f:
65 | f.write('PIEH'.encode('utf-8'))
66 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
67 | flow = flow.astype(np.float32)
68 | flow.tofile(f)
69 | f.flush()
70 | else:
71 | assert concat_axis in [0, 1]
72 | dx, dy = quantize_flow(flow, *args, **kwargs)
73 | dxdy = np.concatenate((dx, dy), axis=concat_axis)
74 |         os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
75 |         cv2.imwrite(filename, dxdy)  # fixed: cv2.imwrite takes (path, img)
76 |
77 |
78 | def quantize_flow(flow, max_val=0.02, norm=True):
79 | """Quantize flow to [0, 255].
80 |
81 | After this step, the size of flow will be much smaller, and can be
82 | dumped as jpeg images.
83 |
84 | Args:
85 | flow (ndarray): (h, w, 2) array of optical flow.
86 | max_val (float): Maximum value of flow, values beyond
87 | [-max_val, max_val] will be truncated.
88 | norm (bool): Whether to divide flow values by image width/height.
89 |
90 | Returns:
91 | tuple[ndarray]: Quantized dx and dy.
92 | """
93 | h, w, _ = flow.shape
94 | dx = flow[..., 0]
95 | dy = flow[..., 1]
96 | if norm:
97 | dx = dx / w # avoid inplace operations
98 | dy = dy / h
99 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
100 | flow_comps = [
101 | quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
102 | ]
103 | return tuple(flow_comps)
104 |
105 |
106 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
107 | """Recover from quantized flow.
108 |
109 | Args:
110 | dx (ndarray): Quantized dx.
111 | dy (ndarray): Quantized dy.
112 | max_val (float): Maximum value used when quantizing.
113 | denorm (bool): Whether to multiply flow values with width/height.
114 |
115 | Returns:
116 | ndarray: Dequantized flow.
117 | """
118 | assert dx.shape == dy.shape
119 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
120 |
121 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
122 |
123 | if denorm:
124 |         dx *= dx.shape[1]  # width
125 |         dy *= dy.shape[0]  # height (dx and dy share a shape, asserted above)
126 | flow = np.dstack((dx, dy))
127 | return flow
128 |
129 |
130 | def quantize(arr, min_val, max_val, levels, dtype=np.int64):
131 | """Quantize an array of (-inf, inf) to [0, levels-1].
132 |
133 | Args:
134 | arr (ndarray): Input array.
135 | min_val (scalar): Minimum value to be clipped.
136 | max_val (scalar): Maximum value to be clipped.
137 | levels (int): Quantization levels.
138 | dtype (np.type): The type of the quantized array.
139 |
140 | Returns:
141 |         ndarray: Quantized array.
142 | """
143 | if not (isinstance(levels, int) and levels > 1):
144 | raise ValueError(
145 |             f'levels must be an integer greater than 1, but got {levels}')
146 | if min_val >= max_val:
147 | raise ValueError(
148 | f'min_val ({min_val}) must be smaller than max_val ({max_val})')
149 |
150 | arr = np.clip(arr, min_val, max_val) - min_val
151 | quantized_arr = np.minimum(
152 | np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
153 |
154 | return quantized_arr
155 |
156 |
157 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
158 | """Dequantize an array.
159 |
160 | Args:
161 | arr (ndarray): Input array.
162 |         min_val (scalar): Minimum value used when quantizing.
163 |         max_val (scalar): Maximum value used when quantizing.
164 | levels (int): Quantization levels.
165 | dtype (np.type): The type of the dequantized array.
166 |
167 | Returns:
168 |         ndarray: Dequantized array.
169 | """
170 | if not (isinstance(levels, int) and levels > 1):
171 | raise ValueError(
172 |             f'levels must be an integer greater than 1, but got {levels}')
173 | if min_val >= max_val:
174 | raise ValueError(
175 | f'min_val ({min_val}) must be smaller than max_val ({max_val})')
176 |
177 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
178 | min_val) / levels + min_val
179 |
180 | return dequantized_arr
181 |
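
[Editor's note] A self-contained round-trip sketch for the quantization helpers above (toy data; max_val=0.02 and the 255-level default match the functions' defaults). The round trip is lossy by design: values are clipped to [-max_val, max_val] and stored in 255 uint8 levels.

    import numpy as np
    from basicsr.utils.flow_util import quantize_flow, dequantize_flow

    flow = np.random.uniform(-0.01, 0.01, size=(4, 6, 2)).astype(np.float32)
    dx, dy = quantize_flow(flow, max_val=0.02, norm=False)         # two uint8 maps
    restored = dequantize_flow(dx, dy, max_val=0.02, denorm=False)
    assert restored.shape == flow.shape
    # bin centers are restored, so the error is at most one quantization step
    assert np.abs(restored - flow).max() <= 0.04 / 255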
--------------------------------------------------------------------------------
/basicsr/utils/img_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | import os
5 | import torch
6 | from torchvision.utils import make_grid
7 |
8 |
9 | def img2tensor(imgs, bgr2rgb=True, float32=True):
10 | """Numpy array to tensor.
11 |
12 | Args:
13 | imgs (list[ndarray] | ndarray): Input images.
14 | bgr2rgb (bool): Whether to change bgr to rgb.
15 | float32 (bool): Whether to change to float32.
16 |
17 | Returns:
18 | list[tensor] | tensor: Tensor images. If returned results only have
19 | one element, just return tensor.
20 | """
21 |
22 | def _totensor(img, bgr2rgb, float32):
23 | if img.shape[2] == 3 and bgr2rgb:
24 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
25 | img = torch.from_numpy(img.transpose(2, 0, 1))
26 | if float32:
27 | img = img.float()
28 | return img
29 |
30 | if isinstance(imgs, list):
31 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
32 | else:
33 | return _totensor(imgs, bgr2rgb, float32)
34 |
35 |
36 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
37 | """Convert torch Tensors into image numpy arrays.
38 |
39 | After clamping to [min, max], values will be normalized to [0, 1].
40 |
41 | Args:
42 | tensor (Tensor or list[Tensor]): Accept shapes:
43 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
44 | 2) 3D Tensor of shape (3/1 x H x W);
45 | 3) 2D Tensor of shape (H x W).
46 | Tensor channel should be in RGB order.
47 | rgb2bgr (bool): Whether to change rgb to bgr.
48 | out_type (numpy type): output types. If ``np.uint8``, transform outputs
49 | to uint8 type with range [0, 255]; otherwise, float type with
50 | range [0, 1]. Default: ``np.uint8``.
51 | min_max (tuple[int]): min and max values for clamp.
52 |
53 | Returns:
54 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
55 | shape (H x W). The channel order is BGR.
56 | """
57 | if not (torch.is_tensor(tensor) or
58 | (isinstance(tensor, list)
59 | and all(torch.is_tensor(t) for t in tensor))):
60 | raise TypeError(
61 | f'tensor or list of tensors expected, got {type(tensor)}')
62 |
63 | if torch.is_tensor(tensor):
64 | tensor = [tensor]
65 | result = []
66 | for _tensor in tensor:
67 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
68 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
69 |
70 | n_dim = _tensor.dim()
71 | if n_dim == 4:
72 | img_np = make_grid(
73 | _tensor, nrow=int(math.sqrt(_tensor.size(0))),
74 | normalize=False).numpy()
75 | img_np = img_np.transpose(1, 2, 0)
76 | if rgb2bgr:
77 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
78 | elif n_dim == 3:
79 | img_np = _tensor.numpy()
80 | img_np = img_np.transpose(1, 2, 0)
81 | if img_np.shape[2] == 1: # gray image
82 | img_np = np.squeeze(img_np, axis=2)
83 | else:
84 | if rgb2bgr:
85 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
86 | elif n_dim == 2:
87 | img_np = _tensor.numpy()
88 | else:
89 | raise TypeError('Only support 4D, 3D or 2D tensor. '
90 | f'But received with dimension: {n_dim}')
91 | if out_type == np.uint8:
92 |             # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
93 | img_np = (img_np * 255.0).round()
94 | img_np = img_np.astype(out_type)
95 | result.append(img_np)
96 | if len(result) == 1:
97 | result = result[0]
98 | return result
99 |
100 |
101 | def imfrombytes(content, flag='color', float32=False):
102 | """Read an image from bytes.
103 |
104 | Args:
105 | content (bytes): Image bytes got from files or other streams.
106 | flag (str): Flags specifying the color type of a loaded image,
107 | candidates are `color`, `grayscale` and `unchanged`.
108 |         float32 (bool): Whether to change to float32. If True, will also
109 |             normalize to [0, 1]. Default: False.
110 |
111 | Returns:
112 | ndarray: Loaded image array.
113 | """
114 | img_np = np.frombuffer(content, np.uint8)
115 | imread_flags = {
116 | 'color': cv2.IMREAD_COLOR,
117 | 'grayscale': cv2.IMREAD_GRAYSCALE,
118 | 'unchanged': cv2.IMREAD_UNCHANGED
119 | }
120 |     img = cv2.imdecode(img_np, imread_flags[flag])
121 |     if img is None:  # cv2.imdecode returns None on failure
122 |         raise ValueError('Failed to decode image from bytes.')
123 | if float32:
124 | img = img.astype(np.float32) / 255.
125 | return img
126 |
127 | def imfrombytesDP(content, flag='color', float32=False):
128 | """Read an image from bytes.
129 |
130 | Args:
131 | content (bytes): Image bytes got from files or other streams.
132 | flag (str): Flags specifying the color type of a loaded image,
133 | candidates are `color`, `grayscale` and `unchanged`.
134 |         float32 (bool): Whether to change to float32. If True, will also
135 |             normalize to [0, 1]. Default: False.
136 |
137 | Returns:
138 | ndarray: Loaded image array.
139 | """
140 | img_np = np.frombuffer(content, np.uint8)
141 |     img = cv2.imdecode(img_np, cv2.IMREAD_UNCHANGED)
142 |     if img is None:  # cv2.imdecode returns None on failure
143 |         raise ValueError('Failed to decode image from bytes.')
144 | if float32:
145 |         img = img.astype(np.float32) / 65535.  # 16-bit input
146 | return img
147 |
148 | def padding(img_lq, img_gt, gt_size):
149 | h, w, _ = img_lq.shape
150 |
151 | h_pad = max(0, gt_size - h)
152 | w_pad = max(0, gt_size - w)
153 |
154 | if h_pad == 0 and w_pad == 0:
155 | return img_lq, img_gt
156 |
157 | img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
158 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
159 | # print('img_lq', img_lq.shape, img_gt.shape)
160 | if img_lq.ndim == 2:
161 | img_lq = np.expand_dims(img_lq, axis=2)
162 | if img_gt.ndim == 2:
163 | img_gt = np.expand_dims(img_gt, axis=2)
164 | return img_lq, img_gt
165 |
166 | def padding_DP(img_lqL, img_lqR, img_gt, gt_size):
167 | h, w, _ = img_gt.shape
168 |
169 | h_pad = max(0, gt_size - h)
170 | w_pad = max(0, gt_size - w)
171 |
172 | if h_pad == 0 and w_pad == 0:
173 | return img_lqL, img_lqR, img_gt
174 |
175 | img_lqL = cv2.copyMakeBorder(img_lqL, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
176 | img_lqR = cv2.copyMakeBorder(img_lqR, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
177 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
178 | # print('img_lq', img_lq.shape, img_gt.shape)
179 | return img_lqL, img_lqR, img_gt
180 |
181 | def imwrite(img, file_path, params=None, auto_mkdir=True):
182 | """Write image to file.
183 |
184 | Args:
185 | img (ndarray): Image array to be written.
186 | file_path (str): Image file path.
187 | params (None or list): Same as opencv's :func:`imwrite` interface.
188 | auto_mkdir (bool): If the parent folder of `file_path` does not exist,
189 | whether to create it automatically.
190 |
191 | Returns:
192 | bool: Successful or not.
193 | """
194 | if auto_mkdir:
195 | dir_name = os.path.abspath(os.path.dirname(file_path))
196 | os.makedirs(dir_name, exist_ok=True)
197 | return cv2.imwrite(file_path, img, params)
198 |
199 |
200 | def crop_border(imgs, crop_border):
201 | """Crop borders of images.
202 |
203 | Args:
204 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
205 |         crop_border (int): Crop border for each end of height and width.
206 |
207 | Returns:
208 | list[ndarray]: Cropped images.
209 | """
210 | if crop_border == 0:
211 | return imgs
212 | else:
213 | if isinstance(imgs, list):
214 | return [
215 | v[crop_border:-crop_border, crop_border:-crop_border, ...]
216 | for v in imgs
217 | ]
218 | else:
219 | return imgs[crop_border:-crop_border, crop_border:-crop_border,
220 | ...]
221 |
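
[Editor's note] A round-trip sketch for img2tensor/tensor2img (random toy image, assumed to be in [0, 1] to match min_max=(0, 1)).

    import numpy as np
    from basicsr.utils.img_util import img2tensor, tensor2img

    img = np.random.rand(8, 8, 3).astype(np.float32)    # HWC, BGR, [0, 1]
    t = img2tensor(img, bgr2rgb=True, float32=True)     # CHW, RGB tensor
    out = tensor2img(t, rgb2bgr=True, min_max=(0, 1))   # HWC, BGR, uint8
    assert out.shape == img.shape and out.dtype == np.uint8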
--------------------------------------------------------------------------------
/basicsr/utils/lmdb_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import lmdb
3 | import sys
4 | from multiprocessing import Pool
5 | from os import path as osp
6 | from tqdm import tqdm
7 |
8 |
9 | def make_lmdb_from_imgs(data_path,
10 | lmdb_path,
11 | img_path_list,
12 | keys,
13 | batch=5000,
14 | compress_level=1,
15 | multiprocessing_read=False,
16 | n_thread=40,
17 | map_size=None):
18 | """Make lmdb from images.
19 |
20 | Contents of lmdb. The file structure is:
21 | example.lmdb
22 | ├── data.mdb
23 | ├── lock.mdb
24 | ├── meta_info.txt
25 |
26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to
27 | https://lmdb.readthedocs.io/en/release/ for more details.
28 |
29 | The meta_info.txt is a specified txt file to record the meta information
30 | of our datasets. It will be automatically created when preparing
31 | datasets by our provided dataset tools.
32 |     Each line in the txt file records 1) image name (with extension),
33 |     2) image shape, and 3) compression level, separated by whitespace.
34 |
35 | For example, the meta information could be:
36 | `000_00000000.png (720,1280,3) 1`, which means:
37 | 1) image name (with extension): 000_00000000.png;
38 | 2) image shape: (720,1280,3);
39 | 3) compression level: 1
40 |
41 | We use the image name without extension as the lmdb key.
42 |
43 | If `multiprocessing_read` is True, it will read all the images to memory
44 | using multiprocessing. Thus, your server needs to have enough memory.
45 |
46 | Args:
47 | data_path (str): Data path for reading images.
48 | lmdb_path (str): Lmdb save path.
49 |         img_path_list (list[str]): Image path list.
50 |         keys (list[str]): Used for lmdb keys.
51 | batch (int): After processing batch images, lmdb commits.
52 | Default: 5000.
53 | compress_level (int): Compress level when encoding images. Default: 1.
54 | multiprocessing_read (bool): Whether use multiprocessing to read all
55 | the images to memory. Default: False.
56 | n_thread (int): For multiprocessing.
57 | map_size (int | None): Map size for lmdb env. If None, use the
58 | estimated size from images. Default: None
59 | """
60 |
61 | assert len(img_path_list) == len(keys), (
62 | 'img_path_list and keys should have the same length, '
63 | f'but got {len(img_path_list)} and {len(keys)}')
64 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
65 |     print(f'Total images: {len(img_path_list)}')
66 | if not lmdb_path.endswith('.lmdb'):
67 | raise ValueError("lmdb_path must end with '.lmdb'.")
68 | if osp.exists(lmdb_path):
69 | print(f'Folder {lmdb_path} already exists. Exit.')
70 | sys.exit(1)
71 |
72 | if multiprocessing_read:
73 | # read all the images to memory (multiprocessing)
74 | dataset = {} # use dict to keep the order for multiprocessing
75 | shapes = {}
76 | print(f'Read images with multiprocessing, #thread: {n_thread} ...')
77 | pbar = tqdm(total=len(img_path_list), unit='image')
78 |
79 | def callback(arg):
80 | """get the image data and update pbar."""
81 | key, dataset[key], shapes[key] = arg
82 | pbar.update(1)
83 | pbar.set_description(f'Read {key}')
84 |
85 | pool = Pool(n_thread)
86 | for path, key in zip(img_path_list, keys):
87 | pool.apply_async(
88 | read_img_worker,
89 | args=(osp.join(data_path, path), key, compress_level),
90 | callback=callback)
91 | pool.close()
92 | pool.join()
93 | pbar.close()
94 | print(f'Finish reading {len(img_path_list)} images.')
95 |
96 | # create lmdb environment
97 | if map_size is None:
98 | # obtain data size for one image
99 | img = cv2.imread(
100 | osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
101 | _, img_byte = cv2.imencode(
102 | '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
103 | data_size_per_img = img_byte.nbytes
104 | print('Data size per image is: ', data_size_per_img)
105 | data_size = data_size_per_img * len(img_path_list)
106 | map_size = data_size * 10
107 |
108 | env = lmdb.open(lmdb_path, map_size=map_size)
109 |
110 | # write data to lmdb
111 | pbar = tqdm(total=len(img_path_list), unit='chunk')
112 | txn = env.begin(write=True)
113 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
114 | for idx, (path, key) in enumerate(zip(img_path_list, keys)):
115 | pbar.update(1)
116 | pbar.set_description(f'Write {key}')
117 | key_byte = key.encode('ascii')
118 | if multiprocessing_read:
119 | img_byte = dataset[key]
120 | h, w, c = shapes[key]
121 | else:
122 | _, img_byte, img_shape = read_img_worker(
123 | osp.join(data_path, path), key, compress_level)
124 | h, w, c = img_shape
125 |
126 | txn.put(key_byte, img_byte)
127 | # write meta information
128 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
129 | if idx % batch == 0:
130 | txn.commit()
131 | txn = env.begin(write=True)
132 | pbar.close()
133 | txn.commit()
134 | env.close()
135 | txt_file.close()
136 | print('\nFinish writing lmdb.')
137 |
138 |
139 | def read_img_worker(path, key, compress_level):
140 | """Read image worker.
141 |
142 | Args:
143 | path (str): Image path.
144 | key (str): Image key.
145 | compress_level (int): Compress level when encoding images.
146 |
147 | Returns:
148 | str: Image key.
149 | byte: Image byte.
150 | tuple[int]: Image shape.
151 | """
152 |
153 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
154 | if img.ndim == 2:
155 | h, w = img.shape
156 | c = 1
157 | else:
158 | h, w, c = img.shape
159 | _, img_byte = cv2.imencode('.png', img,
160 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
161 | return (key, img_byte, (h, w, c))
162 |
163 |
164 | class LmdbMaker():
165 | """LMDB Maker.
166 |
167 | Args:
168 | lmdb_path (str): Lmdb save path.
169 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
170 | batch (int): After processing batch images, lmdb commits.
171 | Default: 5000.
172 | compress_level (int): Compress level when encoding images. Default: 1.
173 | """
174 |
175 | def __init__(self,
176 | lmdb_path,
177 | map_size=1024**4,
178 | batch=5000,
179 | compress_level=1):
180 | if not lmdb_path.endswith('.lmdb'):
181 | raise ValueError("lmdb_path must end with '.lmdb'.")
182 | if osp.exists(lmdb_path):
183 | print(f'Folder {lmdb_path} already exists. Exit.')
184 | sys.exit(1)
185 |
186 | self.lmdb_path = lmdb_path
187 | self.batch = batch
188 | self.compress_level = compress_level
189 | self.env = lmdb.open(lmdb_path, map_size=map_size)
190 | self.txn = self.env.begin(write=True)
191 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
192 | self.counter = 0
193 |
194 | def put(self, img_byte, key, img_shape):
195 | self.counter += 1
196 | key_byte = key.encode('ascii')
197 | self.txn.put(key_byte, img_byte)
198 | # write meta information
199 | h, w, c = img_shape
200 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
201 | if self.counter % self.batch == 0:
202 | self.txn.commit()
203 | self.txn = self.env.begin(write=True)
204 |
205 | def close(self):
206 | self.txn.commit()
207 | self.env.close()
208 | self.txt_file.close()
209 |
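
[Editor's note] A minimal LmdbMaker sketch with hypothetical paths/keys. It reuses read_img_worker above to PNG-encode one image, then writes the bytes and the matching meta entry; close() commits the pending transaction.

    from basicsr.utils.lmdb_util import LmdbMaker, read_img_worker

    maker = LmdbMaker('example.lmdb', batch=100, compress_level=1)
    key, img_byte, shape = read_img_worker('example.png', 'example', 1)
    maker.put(img_byte, key, shape)
    maker.close()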
--------------------------------------------------------------------------------
/basicsr/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 | from .dist_util import get_dist_info, master_only
6 |
7 | initialized_logger = {}
8 |
9 |
10 | class MessageLogger():
11 | """Message logger for printing.
12 |
13 | Args:
14 | opt (dict): Config. It contains the following keys:
15 | name (str): Exp name.
16 |             logger (dict): Contains 'print_freq' (int) for logger interval.
17 | train (dict): Contains 'total_iter' (int) for total iters.
18 | use_tb_logger (bool): Use tensorboard logger.
19 | start_iter (int): Start iter. Default: 1.
20 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
21 | """
22 |
23 | def __init__(self, opt, start_iter=1, tb_logger=None):
24 | self.exp_name = opt['name']
25 | self.interval = opt['logger']['print_freq']
26 | self.start_iter = start_iter
27 | self.max_iters = opt['train']['total_iter']
28 | self.use_tb_logger = opt['logger']['use_tb_logger']
29 | self.tb_logger = tb_logger
30 | self.start_time = time.time()
31 | self.logger = get_root_logger()
32 |
33 | @master_only
34 | def __call__(self, log_vars):
35 | """Format logging message.
36 |
37 | Args:
38 | log_vars (dict): It contains the following keys:
39 | epoch (int): Epoch number.
40 | iter (int): Current iter.
41 | lrs (list): List for learning rates.
42 |
43 | time (float): Iter time.
44 | data_time (float): Data time for each iter.
45 | """
46 | # epoch, iter, learning rates
47 | epoch = log_vars.pop('epoch')
48 | current_iter = log_vars.pop('iter')
49 | lrs = log_vars.pop('lrs')
50 |
51 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
52 | for v in lrs:
53 | message += f'{v:.3e},'
54 | message += ')] '
55 |
56 | # time and estimated time
57 | if 'time' in log_vars.keys():
58 | iter_time = log_vars.pop('time')
59 | data_time = log_vars.pop('data_time')
60 |
61 | total_time = time.time() - self.start_time
62 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
63 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
64 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
65 | message += f'[eta: {eta_str}, '
66 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
67 |
68 | # other items, especially losses
69 | for k, v in log_vars.items():
70 | message += f'{k}: {v:.4e} '
71 | # tensorboard logger
72 | if self.use_tb_logger and 'debug' not in self.exp_name:
73 | if k.startswith('l_'):
74 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
75 | else:
76 | self.tb_logger.add_scalar(k, v, current_iter)
77 | self.logger.info(message)
78 |
79 |
80 | @master_only
81 | def init_tb_logger(log_dir):
82 | from torch.utils.tensorboard import SummaryWriter
83 | tb_logger = SummaryWriter(log_dir=log_dir)
84 | return tb_logger
85 |
86 |
87 | @master_only
88 | def init_wandb_logger(opt):
89 | """We now only use wandb to sync tensorboard log."""
90 | import wandb
91 | logger = logging.getLogger('basicsr')
92 |
93 | project = opt['logger']['wandb']['project']
94 | resume_id = opt['logger']['wandb'].get('resume_id')
95 | if resume_id:
96 | wandb_id = resume_id
97 | resume = 'allow'
98 | logger.warning(f'Resume wandb logger with id={wandb_id}.')
99 | else:
100 | wandb_id = wandb.util.generate_id()
101 | resume = 'never'
102 |
103 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
104 |
105 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
106 |
107 |
108 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
109 | """Get the root logger.
110 |
111 | The logger will be initialized if it has not been initialized. By default a
112 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
113 | also be added.
114 |
115 | Args:
116 |         logger_name (str): Root logger name. Default: 'basicsr'.
117 |         log_level (int): The root logger level. Note that only the process
118 |             of rank 0 is affected, while other processes will set the level
119 |             to "Error" and be silent most of the time.
120 |         log_file (str | None): The log filename. If specified, a FileHandler
121 |             will be added to the root logger.
122 |
123 | Returns:
124 | logging.Logger: The root logger.
125 | """
126 | logger = logging.getLogger(logger_name)
127 | # if the logger has been initialized, just return it
128 | if logger_name in initialized_logger:
129 | return logger
130 |
131 | format_str = '%(asctime)s %(levelname)s: %(message)s'
132 | stream_handler = logging.StreamHandler()
133 | stream_handler.setFormatter(logging.Formatter(format_str))
134 | logger.addHandler(stream_handler)
135 | logger.propagate = False
136 | rank, _ = get_dist_info()
137 | if rank != 0:
138 | logger.setLevel('ERROR')
139 | elif log_file is not None:
140 | logger.setLevel(log_level)
141 | # add file handler
142 | file_handler = logging.FileHandler(log_file, 'w')
143 | file_handler.setFormatter(logging.Formatter(format_str))
144 | file_handler.setLevel(log_level)
145 | logger.addHandler(file_handler)
146 | initialized_logger[logger_name] = True
147 | return logger
148 |
149 |
150 | def get_env_info():
151 | """Get environment information.
152 |
153 | Currently, only log the software version.
154 | """
155 | import torch
156 | import torchvision
157 |
158 | from basicsr.version import __version__
159 | msg = r"""
160 | ____ _ _____ ____
161 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \
162 | / __ |/ __ `// ___// // ___/\__ \ / /_/ /
163 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
164 | /_____/ \__,_//____//_/ \___//____//_/ |_|
165 | ______ __ __ __ __
166 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
167 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
168 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
169 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
170 | """
171 | msg += ('\nVersion Information: '
172 | f'\n\tBasicSR: {__version__}'
173 | f'\n\tPyTorch: {torch.__version__}'
174 | f'\n\tTorchVision: {torchvision.__version__}')
175 | return msg
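
[Editor's note] A minimal logging sketch; 'train.log' is a placeholder path. get_root_logger caches by name in initialized_logger, so later calls return the already-configured logger.

    from basicsr.utils.logger import get_root_logger, get_env_info

    logger = get_root_logger(log_file='train.log')  # rank 0 also gets a FileHandler
    logger.info(get_env_info())
    logger.info('training started')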
--------------------------------------------------------------------------------
/basicsr/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import time
5 | import torch
6 | from os import path as osp
7 |
8 | from .dist_util import master_only
9 | from .logger import get_root_logger
10 |
11 |
12 | def set_random_seed(seed):
13 | """Set random seeds."""
14 | random.seed(seed)
15 | np.random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 | torch.cuda.manual_seed_all(seed)
19 |
20 |
21 | def get_time_str():
22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime())
23 |
24 |
25 | def mkdir_and_rename(path):
26 | """mkdirs. If path exists, rename it with timestamp and create a new one.
27 |
28 | Args:
29 | path (str): Folder path.
30 | """
31 | if osp.exists(path):
32 | new_name = path + '_archived_' + get_time_str()
33 | print(f'Path already exists. Rename it to {new_name}', flush=True)
34 | os.rename(path, new_name)
35 | os.makedirs(path, exist_ok=True)
36 |
37 |
38 | @master_only
39 | def make_exp_dirs(opt):
40 | """Make dirs for experiments."""
41 | path_opt = opt['path'].copy()
42 | if opt['is_train']:
43 | mkdir_and_rename(path_opt.pop('experiments_root'))
44 | else:
45 | mkdir_and_rename(path_opt.pop('results_root'))
46 | for key, path in path_opt.items():
47 | if ('strict_load' not in key) and ('pretrain_network'
48 | not in key) and ('resume'
49 | not in key):
50 | os.makedirs(path, exist_ok=True)
51 |
52 |
53 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
54 | """Scan a directory to find the interested files.
55 |
56 | Args:
57 | dir_path (str): Path of the directory.
58 | suffix (str | tuple(str), optional): File suffix that we are
59 | interested in. Default: None.
60 | recursive (bool, optional): If set to True, recursively scan the
61 | directory. Default: False.
62 | full_path (bool, optional): If set to True, include the dir_path.
63 | Default: False.
64 |
65 | Returns:
66 |         A generator for all the interested files with relative paths.
67 | """
68 |
69 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
70 | raise TypeError('"suffix" must be a string or tuple of strings')
71 |
72 | root = dir_path
73 |
74 | def _scandir(dir_path, suffix, recursive):
75 | for entry in os.scandir(dir_path):
76 | if not entry.name.startswith('.') and entry.is_file():
77 | if full_path:
78 | return_path = entry.path
79 | else:
80 | return_path = osp.relpath(entry.path, root)
81 |
82 | if suffix is None:
83 | yield return_path
84 | elif return_path.endswith(suffix):
85 | yield return_path
86 | else:
87 | if recursive:
88 | yield from _scandir(
89 | entry.path, suffix=suffix, recursive=recursive)
90 | else:
91 | continue
92 |
93 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
94 |
95 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
96 | """Scan a directory to find the interested files.
97 |
98 | Args:
99 | dir_path (str): Path of the directory.
100 | keywords (str | tuple(str), optional): File keywords that we are
101 | interested in. Default: None.
102 | recursive (bool, optional): If set to True, recursively scan the
103 | directory. Default: False.
104 | full_path (bool, optional): If set to True, include the dir_path.
105 | Default: False.
106 |
107 | Returns:
108 |         A generator for all the interested files with relative paths.
109 | """
110 |
111 | if (keywords is not None) and not isinstance(keywords, (str, tuple)):
112 | raise TypeError('"keywords" must be a string or tuple of strings')
113 |
114 | root = dir_path
115 |
116 | def _scandir(dir_path, keywords, recursive):
117 | for entry in os.scandir(dir_path):
118 | if not entry.name.startswith('.') and entry.is_file():
119 | if full_path:
120 | return_path = entry.path
121 | else:
122 | return_path = osp.relpath(entry.path, root)
123 |
124 | if keywords is None:
125 | yield return_path
126 |                 elif return_path.find(keywords) >= 0:  # include match at index 0
127 | yield return_path
128 | else:
129 | if recursive:
130 | yield from _scandir(
131 | entry.path, keywords=keywords, recursive=recursive)
132 | else:
133 | continue
134 |
135 | return _scandir(dir_path, keywords=keywords, recursive=recursive)
136 |
137 | def check_resume(opt, resume_iter):
138 | """Check resume states and pretrain_network paths.
139 |
140 | Args:
141 | opt (dict): Options.
142 | resume_iter (int): Resume iteration.
143 | """
144 | logger = get_root_logger()
145 | if opt['path']['resume_state']:
146 | # get all the networks
147 | networks = [key for key in opt.keys() if key.startswith('network_')]
148 | flag_pretrain = False
149 | for network in networks:
150 | if opt['path'].get(f'pretrain_{network}') is not None:
151 | flag_pretrain = True
152 | if flag_pretrain:
153 | logger.warning(
154 | 'pretrain_network path will be ignored during resuming.')
155 | # set pretrained model paths
156 | for network in networks:
157 | name = f'pretrain_{network}'
158 | basename = network.replace('network_', '')
159 | if opt['path'].get('ignore_resume_networks') is None or (
160 | basename not in opt['path']['ignore_resume_networks']):
161 | opt['path'][name] = osp.join(
162 | opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
163 | logger.info(f"Set {name} to {opt['path'][name]}")
164 |
165 |
166 | def sizeof_fmt(size, suffix='B'):
167 | """Get human readable file size.
168 |
169 | Args:
170 | size (int): File size.
171 | suffix (str): Suffix. Default: 'B'.
172 |
173 | Return:
174 |         str: Formatted file size.
175 | """
176 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
177 | if abs(size) < 1024.0:
178 | return f'{size:3.1f} {unit}{suffix}'
179 | size /= 1024.0
180 | return f'{size:3.1f} Y{suffix}'
181 |
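
[Editor's note] A short sketch of the directory and size helpers above ('datasets' is a placeholder directory).

    from basicsr.utils.misc import scandir, sizeof_fmt

    png_paths = list(scandir('datasets', suffix='.png', recursive=True))
    print(f'{len(png_paths)} PNGs found')
    print(sizeof_fmt(3 * 1024 ** 2))  # -> '3.0 MB'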
--------------------------------------------------------------------------------
/basicsr/utils/options.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from collections import OrderedDict
3 | from os import path as osp
4 |
5 |
6 | def ordered_yaml():
7 | """Support OrderedDict for yaml.
8 |
9 | Returns:
10 | yaml Loader and Dumper.
11 | """
12 | try:
13 | from yaml import CDumper as Dumper
14 | from yaml import CLoader as Loader
15 | except ImportError:
16 | from yaml import Dumper, Loader
17 |
18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
19 |
20 | def dict_representer(dumper, data):
21 | return dumper.represent_dict(data.items())
22 |
23 | def dict_constructor(loader, node):
24 | return OrderedDict(loader.construct_pairs(node))
25 |
26 | Dumper.add_representer(OrderedDict, dict_representer)
27 | Loader.add_constructor(_mapping_tag, dict_constructor)
28 | return Loader, Dumper
29 |
30 |
31 | def parse(opt_path, is_train=True):
32 | """Parse option file.
33 |
34 | Args:
35 | opt_path (str): Option file path.
36 |         is_train (bool): Indicate whether in training or not. Default: True.
37 |
38 | Returns:
39 | (dict): Options.
40 | """
41 | with open(opt_path, mode='r') as f:
42 | Loader, _ = ordered_yaml()
43 | opt = yaml.load(f, Loader=Loader)
44 |
45 | opt['is_train'] = is_train
46 |
47 | # datasets
48 | for phase, dataset in opt['datasets'].items():
49 | # for several datasets, e.g., test_1, test_2
50 | phase = phase.split('_')[0]
51 | dataset['phase'] = phase
52 | if 'scale' in opt:
53 | dataset['scale'] = opt['scale']
54 | if dataset.get('dataroot_gt') is not None:
55 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
56 | if dataset.get('dataroot_lq') is not None:
57 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
58 |
59 | # paths
60 | for key, val in opt['path'].items():
61 | if (val is not None) and ('resume_state' in key
62 | or 'pretrain_network' in key):
63 | opt['path'][key] = osp.expanduser(val)
64 | opt['path']['root'] = osp.abspath(
65 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
66 | if is_train:
67 | experiments_root = osp.join(opt['path']['root'], 'experiments',
68 | opt['name'])
69 | opt['path']['experiments_root'] = experiments_root
70 | opt['path']['models'] = osp.join(experiments_root, 'models')
71 | opt['path']['training_states'] = osp.join(experiments_root,
72 | 'training_states')
73 | opt['path']['log'] = experiments_root
74 | opt['path']['visualization'] = osp.join(experiments_root,
75 | 'visualization')
76 |
77 | # change some options for debug mode
78 | if 'debug' in opt['name']:
79 | if 'val' in opt:
80 | opt['val']['val_freq'] = 8
81 | opt['logger']['print_freq'] = 1
82 | opt['logger']['save_checkpoint_freq'] = 8
83 | else: # test
84 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
85 | opt['path']['results_root'] = results_root
86 | opt['path']['log'] = results_root
87 | opt['path']['visualization'] = osp.join(results_root, 'visualization')
88 |
89 | return opt
90 |
91 |
92 | def dict2str(opt, indent_level=1):
93 | """dict to string for printing options.
94 |
95 | Args:
96 | opt (dict): Option dict.
97 | indent_level (int): Indent level. Default: 1.
98 |
99 | Return:
100 | (str): Option string for printing.
101 | """
102 | msg = '\n'
103 | for k, v in opt.items():
104 | if isinstance(v, dict):
105 | msg += ' ' * (indent_level * 2) + k + ':['
106 | msg += dict2str(v, indent_level + 1)
107 | msg += ' ' * (indent_level * 2) + ']\n'
108 | else:
109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
110 | return msg
111 |
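
[Editor's note] A sketch of loading one of the shipped configs and pretty-printing it; the YAML path is relative to the repo root and is assumed to define the keys parse() expects (datasets, path, ...).

    from basicsr.utils.options import parse, dict2str

    opt = parse('Denoising/Options/GaussianColorDenoising_HINT.yml',
                is_train=False)
    print(dict2str(opt))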
--------------------------------------------------------------------------------
/basicsr/version.py:
--------------------------------------------------------------------------------
1 | # GENERATED VERSION FILE
2 | # TIME: Tue Jul 1 18:25:19 2025
3 | __version__ = '1.2.0+733ceb2'
4 | short_version = '1.2.0'
5 | version_info = (1, 2, 0)
6 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: HINT_n
2 | channels:
3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
4 | - https://mirrors.ustc.edu.cn/anaconda/cloud/menpo/
5 | - https://mirrors.ustc.edu.cn/anaconda/cloud/bioconda/
6 | - https://mirrors.ustc.edu.cn/anaconda/cloud/msys2/
7 | - https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/
8 | - https://mirrors.ustc.edu.cn/anaconda/pkgs/free/
9 | - https://mirrors.ustc.edu.cn/anaconda/pkgs/main/
10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
11 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
12 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
13 | - https://repo.continuum.io/pkgs/main/win-64/
14 | - https://repo.continuum.io/pkgs/free/win-64/
15 | - defaults
16 | dependencies:
17 | - _libgcc_mutex=0.1=conda_forge
18 | - _openmp_mutex=4.5=2_gnu
19 | - bzip2=1.0.8=hd590300_5
20 | - ca-certificates=2023.11.17=hbcca054_0
21 | - ld_impl_linux-64=2.40=h41732ed_0
22 | - libffi=3.4.2=h7f98852_5
23 | - libgcc-ng=13.2.0=h807b86a_3
24 | - libgomp=13.2.0=h807b86a_3
25 | - libnsl=2.0.1=hd590300_0
26 | - libsqlite=3.44.2=h2797004_0
27 | - libuuid=2.38.1=h0b41bf4_0
28 | - libxcrypt=4.4.36=hd590300_1
29 | - libzlib=1.2.13=hd590300_5
30 | - ncurses=6.4=h59595ed_2
31 | - openssl=3.2.0=hd590300_1
32 | - pip=23.3.2=pyhd8ed1ab_0
33 | - python=3.8.18=hd12c33a_1_cpython
34 | - readline=8.2=h8228510_1
35 | - setuptools=68.2.2=pyhd8ed1ab_0
36 | - tk=8.6.13=noxft_h4845f30_101
37 | - wheel=0.42.0=pyhd8ed1ab_0
38 | - xz=5.2.6=h166bdaf_0
39 | - pip:
40 | - absl-py==2.2.2
41 | - cachetools==5.5.2
42 | - certifi==2023.11.17
43 | - charset-normalizer==3.3.2
44 | - einops==0.7.0
45 | - filelock==3.13.1
46 | - fsspec==2023.12.2
47 | - google-auth==2.40.1
48 | - google-auth-oauthlib==1.0.0
49 | - grpcio==1.70.0
50 | - huggingface-hub==0.20.1
51 | - idna==3.6
52 | - imageio==2.35.1
53 | - importlib-metadata==8.5.0
54 | - joblib==1.4.2
55 | - lazy-loader==0.4
56 | - lmdb==1.6.2
57 | - markdown==3.7
58 | - markupsafe==2.1.5
59 | - natsort==8.4.0
60 | - networkx==3.1
61 | - numpy==1.24.4
62 | - oauthlib==3.2.2
63 | - opencv-python==4.8.1.78
64 | - packaging==23.2
65 | - pillow==10.1.0
66 | - protobuf==5.29.4
67 | - pyasn1==0.6.1
68 | - pyasn1-modules==0.4.2
69 | - pywavelets==1.4.1
70 | - pyyaml==6.0.1
71 | - requests==2.31.0
72 | - requests-oauthlib==2.0.0
73 | - rsa==4.9.1
74 | - safetensors==0.4.1
75 | - scikit-image==0.21.0
76 | - scikit-learn==1.3.2
77 | - scipy==1.10.1
78 | - six==1.17.0
79 | - tensorboard==2.14.0
80 | - tensorboard-data-server==0.7.2
81 | - threadpoolctl==3.5.0
82 | - tifffile==2023.7.10
83 | - timm==0.9.12
84 | - torch==1.12.0+cu113
85 | - torchaudio==0.12.0+cu113
86 | - torchvision==0.13.0+cu113
87 | - tqdm==4.66.1
88 | - typing-extensions==4.9.0
89 | - urllib3==2.1.0
90 | - werkzeug==3.0.6
91 | - zipp==3.20.2
92 | prefix: /home/ubuntu13/anaconda3/envs/HINT_n
93 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import find_packages, setup
4 |
5 | import os
6 | import subprocess
7 | import sys
8 | import time
9 | import torch
10 | from torch.utils.cpp_extension import (BuildExtension, CppExtension,
11 | CUDAExtension)
12 |
13 | version_file = 'basicsr/version.py'
14 |
15 |
16 | def readme():
17 | return ''
18 | # with open('README.md', encoding='utf-8') as f:
19 | # content = f.read()
20 | # return content
21 |
22 |
23 | def get_git_hash():
24 |
25 | def _minimal_ext_cmd(cmd):
26 | # construct minimal environment
27 | env = {}
28 | for k in ['SYSTEMROOT', 'PATH', 'HOME']:
29 | v = os.environ.get(k)
30 | if v is not None:
31 | env[k] = v
32 | # LANGUAGE is used on win32
33 | env['LANGUAGE'] = 'C'
34 | env['LANG'] = 'C'
35 | env['LC_ALL'] = 'C'
36 | out = subprocess.Popen(
37 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
38 | return out
39 |
40 | try:
41 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
42 | sha = out.strip().decode('ascii')
43 | except OSError:
44 | sha = 'unknown'
45 |
46 | return sha
47 |
48 |
49 | def get_hash():
50 | if os.path.exists('.git'):
51 | sha = get_git_hash()[:7]
52 | elif os.path.exists(version_file):
53 | try:
54 | from basicsr.version import __version__
55 | sha = __version__.split('+')[-1]
56 | except ImportError:
57 | raise ImportError('Unable to get git version')
58 | else:
59 | sha = 'unknown'
60 |
61 | return sha
62 |
63 |
64 | def write_version_py():
65 | content = """# GENERATED VERSION FILE
66 | # TIME: {}
67 | __version__ = '{}'
68 | short_version = '{}'
69 | version_info = ({})
70 | """
71 | sha = get_hash()
72 | with open('VERSION', 'r') as f:
73 | SHORT_VERSION = f.read().strip()
74 | VERSION_INFO = ', '.join(
75 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
76 | VERSION = SHORT_VERSION + '+' + sha
77 |
78 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION,
79 | VERSION_INFO)
80 | with open(version_file, 'w') as f:
81 | f.write(version_file_str)
82 |
83 |
84 | def get_version():
85 | with open(version_file, 'r') as f:
86 | exec(compile(f.read(), version_file, 'exec'))
87 | return locals()['__version__']
88 |
89 |
90 | def make_cuda_ext(name, module, sources, sources_cuda=None):
91 | if sources_cuda is None:
92 | sources_cuda = []
93 | define_macros = []
94 | extra_compile_args = {'cxx': []}
95 |
96 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
97 | define_macros += [('WITH_CUDA', None)]
98 | extension = CUDAExtension
99 | extra_compile_args['nvcc'] = [
100 | '-D__CUDA_NO_HALF_OPERATORS__',
101 | '-D__CUDA_NO_HALF_CONVERSIONS__',
102 | '-D__CUDA_NO_HALF2_OPERATORS__',
103 | ]
104 | sources += sources_cuda
105 | else:
106 | print(f'Compiling {name} without CUDA')
107 | extension = CppExtension
108 |
109 | return extension(
110 | name=f'{module}.{name}',
111 | sources=[os.path.join(*module.split('.'), p) for p in sources],
112 | define_macros=define_macros,
113 | extra_compile_args=extra_compile_args)
114 |
115 |
116 | def get_requirements(filename='requirements.txt'):
117 |     return []  # requirements disabled; the code below is intentionally unreachable
118 | here = os.path.dirname(os.path.realpath(__file__))
119 | with open(os.path.join(here, filename), 'r') as f:
120 | requires = [line.replace('\n', '') for line in f.readlines()]
121 | return requires
122 |
123 |
124 | if __name__ == '__main__':
125 | if '--no_cuda_ext' in sys.argv:
126 | ext_modules = []
127 | sys.argv.remove('--no_cuda_ext')
128 | else:
129 | ext_modules = [
130 | make_cuda_ext(
131 | name='deform_conv_ext',
132 | module='basicsr.models.ops.dcn',
133 | sources=['src/deform_conv_ext.cpp'],
134 | sources_cuda=[
135 | 'src/deform_conv_cuda.cpp',
136 | 'src/deform_conv_cuda_kernel.cu'
137 | ]),
138 | make_cuda_ext(
139 | name='fused_act_ext',
140 | module='basicsr.models.ops.fused_act',
141 | sources=['src/fused_bias_act.cpp'],
142 | sources_cuda=['src/fused_bias_act_kernel.cu']),
143 | make_cuda_ext(
144 | name='upfirdn2d_ext',
145 | module='basicsr.models.ops.upfirdn2d',
146 | sources=['src/upfirdn2d.cpp'],
147 | sources_cuda=['src/upfirdn2d_kernel.cu']),
148 | ]
149 |
150 | write_version_py()
151 | setup(
152 | name='basicsr',
153 | version=get_version(),
154 | description='Open Source Image and Video Super-Resolution Toolbox',
155 | long_description=readme(),
156 | author='Xintao Wang',
157 | author_email='xintao.wang@outlook.com',
158 | keywords='computer vision, restoration, super resolution',
159 | url='https://github.com/xinntao/BasicSR',
160 | packages=find_packages(
161 | exclude=('options', 'datasets', 'experiments', 'results',
162 | 'tb_logger', 'wandb')),
163 | classifiers=[
164 | 'Development Status :: 4 - Beta',
165 | 'License :: OSI Approved :: Apache Software License',
166 | 'Operating System :: OS Independent',
167 | 'Programming Language :: Python :: 3',
168 | 'Programming Language :: Python :: 3.7',
169 | 'Programming Language :: Python :: 3.8',
170 | ],
171 | license='Apache License 2.0',
172 | setup_requires=['cython', 'numpy'],
173 | install_requires=get_requirements(),
174 | ext_modules=ext_modules,
175 | cmdclass={'build_ext': BuildExtension},
176 | zip_safe=False)
177 |
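
[Editor's note] A standalone sketch of how write_version_py formats version_info (mirrors the string logic above; '1.2.0' matches basicsr/version.py).

    SHORT_VERSION = '1.2.0'
    VERSION_INFO = ', '.join(
        x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.'))
    print(f'version_info = ({VERSION_INFO})')  # -> version_info = (1, 2, 0)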
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | ### Dehaze
3 | python test_SOTS_HINT.py
4 | python evaluate_SOTS.py
5 |
6 | ### Derain
7 | python test_rain100L.py
8 |
9 | ### Denoising
10 | python test_gaussian_color_denoising_HINT.py --model_type blind
11 | python evaluate_gaussian_color_denoising_HINT.py --model_type blind
12 |
13 | ### Desnowing
14 | python test_snow100k.py
15 | python evaluate_Snow100k.py
16 |
17 | ### Enhancement
18 | python test_from_dataset_LOLv2_Real.py
19 | python test_from_dataset_LOLv2_Syn.py
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 |
5 | python -m torch.distributed.launch --nproc_per_node=2 --master_port=4321 basicsr/train.py -opt $CONFIG --launcher pytorch
6 |
--------------------------------------------------------------------------------