├── Dehazing
│   ├── config
│   │   └── config.yaml
│   ├── data
│   │   ├── __init__.py
│   │   └── loader.py
│   ├── loss.py
│   ├── metrics.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── final_model.py
│   │   ├── model_utils.py
│   │   └── module.py
│   ├── readme.md
│   ├── test.py
│   ├── train.py
│   └── utils.py
├── Deraining
│   ├── .DS_Store
│   ├── VERSION
│   ├── basicsr.egg-info
│   │   ├── PKG-INFO
│   │   ├── SOURCES.txt
│   │   ├── dependency_links.txt
│   │   ├── not-zip-safe
│   │   └── top_level.txt
│   ├── basicsr
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── data_sampler.py
│   │   │   ├── data_util.py
│   │   │   ├── ffhq_dataset.py
│   │   │   ├── paired_image_dataset.py
│   │   │   ├── prefetch_dataloader.py
│   │   │   ├── reds_dataset.py
│   │   │   ├── single_image_dataset.py
│   │   │   └── transforms.py
│   │   ├── demo.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   ├── fid.py
│   │   │   ├── metric_util.py
│   │   │   ├── niqe.py
│   │   │   ├── niqe_pris_params.npz
│   │   │   └── psnr_ssim.py
│   │   ├── models
│   │   │   ├── .DS_Store
│   │   │   ├── __init__.py
│   │   │   ├── archs
│   │   │   │   ├── .DS_Store
│   │   │   │   ├── __init__.py
│   │   │   │   ├── arch_util.py
│   │   │   │   ├── common.py
│   │   │   │   ├── fourmer.py
│   │   │   │   └── layers.py
│   │   │   ├── base_model.py
│   │   │   ├── image_restoration_model.py
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── loss_util.py
│   │   │   │   └── losses.py
│   │   │   ├── lr_scheduler.py
│   │   │   └── prenet_model.py
│   │   ├── test.py
│   │   ├── train.py
│   │   ├── train_rain.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── create_lmdb.py
│   │   │   ├── dist_util.py
│   │   │   ├── download_util.py
│   │   │   ├── face_util.py
│   │   │   ├── file_client.py
│   │   │   ├── flow_util.py
│   │   │   ├── img_util.py
│   │   │   ├── lmdb_util.py
│   │   │   ├── logger.py
│   │   │   ├── matlab_functions.py
│   │   │   ├── misc.py
│   │   │   ├── options.py
│   │   │   └── show.py
│   │   └── version.py
│   ├── generate_lmdb.py
│   ├── get_best.py
│   ├── options
│   │   ├── .DS_Store
│   │   └── train
│   │       ├── .DS_Store
│   │       ├── RAIN200H
│   │       │   └── fourmer.yml
│   │       └── RAIN200L
│   │           └── fourmer.yml
│   ├── readme.md
│   ├── requirements.txt
│   ├── scripts
│   │   ├── data_preparation
│   │   │   ├── gopro.py
│   │   │   ├── rain13k.py
│   │   │   ├── reds.py
│   │   │   └── sidd.py
│   │   ├── download_gdrive.py
│   │   ├── download_pretrained_models.py
│   │   └── publish_models.py
│   ├── setup.py
│   ├── show.py
│   └── test.py
├── LLIE
│   ├── create_txt.py
│   ├── data
│   │   ├── SIEN_dataset.py
│   │   ├── __init__.py
│   │   ├── groups_test_Huawei.txt
│   │   ├── groups_test_LOL.txt
│   │   ├── groups_train_Huawei.txt
│   │   ├── groups_train_LOL.txt
│   │   ├── shuffle.py
│   │   └── util.py
│   ├── eval.py
│   ├── eval_test.py
│   ├── metrics
│   │   ├── calculate_PSNR_SSIM.m
│   │   └── calculate_PSNR_SSIM.py
│   ├── models
│   │   ├── SIEN_model.py
│   │   ├── __init__.py
│   │   ├── archs
│   │   │   ├── EnhanceN_arch.py
│   │   │   ├── EnhanceN_arch1.py
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── EDVR_arch.cpython-35.pyc
│   │   │   │   ├── EDVR_arch.cpython-36.pyc
│   │   │   │   ├── EnhanceN_arch.cpython-35.pyc
│   │   │   │   ├── EnhanceN_arch.cpython-38.pyc
│   │   │   │   ├── EnhanceN_arch1.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-35.pyc
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── arch_util.cpython-35.pyc
│   │   │   │   ├── arch_util.cpython-36.pyc
│   │   │   │   ├── arch_util.cpython-38.pyc
│   │   │   │   ├── discriminator_vgg_arch.cpython-36.pyc
│   │   │   │   └── discriminator_vgg_arch.cpython-38.pyc
│   │   │   ├── arch_util.py
│   │   │   ├── dehaze_arch.py
│   │   │   └── discriminator_vgg_arch.py
│   │   ├── base_model.py
│   │   ├── loss.py
│   │   ├── loss_new.py
│   │   ├── lr_scheduler.py
│   │   └── networks.py
│   ├── options
│   │   ├── __init__.py
│   │   ├── options.py
│   │   ├── test
│   │   │   ├── test_ESRGAN.yml
│   │   │   ├── test_SRGAN.yml
│   │   │   └── test_SRResNet.yml
│   │   └── train
│   │       ├── train_Enhance.yml
│   │       └── train_Enhance1.yml
│   ├── pretrain
│   │   └── LOL
│   │       └── 0_bestavg.pth
│   ├── readme.md
│   ├── test.py
│   ├── train.py
│   ├── train1.py
│   └── utils
│       ├── __init__.py
│       └── util.py
├── README.md
└── pan-sharpening
    ├── .gitignore
    ├── configs
    │   └── config.yaml
    ├── datasets
    │   └── data.py
    ├── main.py
    ├── models
    │   ├── Network.py
    │   ├── __init__.py
    │   ├── fusion.py
    │   ├── grad_loss.py
    │   ├── hazer_cfm_adp.py
    │   ├── histgram_loss.py
    │   ├── model.py
    │   ├── model2.py
    │   └── pipeline.py
    ├── test.py
    └── utils
        ├── __init__.py
        ├── global_config.py
        └── util.py

/Dehazing/config/config.yaml:
--------------------------------------------------------------------------------
1 | output_dir: 'experiment/upsample/aod_4'
2 | data:
3 |   train_dir: data/ITS/
4 |   test_dir: data/SOTS/indoor
5 | 
6 | model:
7 |   in_channel: 3
8 |   model_channel: 36
9 | 
10 | train_loader:
11 |   num_workers: 8
12 |   batch_size: 4
13 |   loader: crop
14 |   img_size: (256, 256)
15 |   shuffle: True
16 | 
17 | test_loader:
18 |   num_workers: 8
19 |   batch_size: 1
20 |   loader: default
21 |   img_size: (600, 600)
22 |   shuffle: False
23 | 
24 | optimizer:
25 |   type: step
26 |   total_epoch: 45
27 |   lr: 0.0002
28 |   T_0: 100
29 |   T_MULT: 1
30 |   ETA_MIN: 0.000001
31 |   step: 15
32 |   gamma: 0.75
33 | 
34 | hyper_params:
35 |   x_lambda: 0.03
36 | 
37 | resume:
38 |   flag: False
39 |   checkpoint: None
40 | 
41 | evaluate_intervel: 5
42 | 
43 | 
44 | 
--------------------------------------------------------------------------------
/Dehazing/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .loader import *
--------------------------------------------------------------------------------
/Dehazing/data/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data.dataset import Dataset
6 | from PIL import Image
7 | import torchvision.transforms.functional as TF
8 | import torchvision.transforms as tf
9 | from PIL import Image, ImageFile
10 | import random
11 | import math
12 | from model import *
13 | import torch
14 | 
15 | ImageFile.LOAD_TRUNCATED_IMAGES = True
16 | 
17 | 
18 | class base_dataset(Dataset):
19 |     def __init__(self, data_dir, img_size, transforms=False, crop=False):
20 |         imgs = sorted(os.listdir(data_dir + "/hazy"))
21 |         gt_imgs = [i.split("_")[0] for i in imgs]
22 |         self.input_imgs = [os.path.join(data_dir + "/hazy", name) for name in imgs]
23 | 
24 |         self.gt_imgs = [os.path.join(data_dir + "/gt", name + ".png") for name in gt_imgs]
25 |         self.transforms = transforms
26 |         self.crop = crop
27 |         self.img_size = img_size
28 | 
29 |     def __getitem__(self, index):
30 |         inp_img_path = self.input_imgs[index]
31 |         gt_img_path = self.gt_imgs[index]
32 |         inp_img = Image.open(inp_img_path).convert("RGB")
33 |         gt_img = Image.open(gt_img_path).convert("RGB")
34 |         if self.transforms:
35 |             inp_img = self.transforms(inp_img)
36 |             gt_img = self.transforms(gt_img)
37 | 
38 |         if self.crop:
39 |             inp_img, gt_img = self.crop_image(inp_img, gt_img)
40 | 
41 |         return inp_img, gt_img, inp_img_path
42 | 
43 |     def __len__(self):
44 |         return len(self.gt_imgs)
45 | 
46 |     def crop_image(self, inp_img, gt_img):
47 |         crop_h, crop_w = self.img_size
48 |         i, j, h, w = tf.RandomCrop.get_params(
49 |             inp_img, output_size=((crop_h, crop_w)))
50 |         inp_img = TF.crop(inp_img, i, j, h, w)
51 |         gt_img = TF.crop(gt_img, i, j, h, w)
52 |         inp_img = TF.to_tensor(inp_img)
53 |         gt_img = TF.to_tensor(gt_img)
54 | 
55 |         return inp_img, gt_img
56 | 
57 | 
58 | class random_scale_dataset(Dataset):
59 |     def __init__(self, data_dir, img_size, transforms=False, crop=False):
60 |         imgs = sorted(os.listdir(data_dir + "/low"))
61 |         self.input_imgs = [os.path.join(data_dir + "/low", name) for name in imgs]
62 |         self.gt_imgs = [os.path.join(data_dir + "/high", name) for name in imgs]
63 |         self.transforms = transforms
64 |         self.crop = crop
65 |         self.img_size = img_size
66 | 
67 |     def __getitem__(self, index):
68 |         inp_img_path = self.input_imgs[index]
69 |         gt_img_path = self.gt_imgs[index]
70 |         inp_img = Image.open(inp_img_path).convert("RGB")
71 |         gt_img = Image.open(gt_img_path).convert("RGB")
72 | 
73 |         random_scale_factor = random.randrange(int(self.img_size[0] * 0.25), self.img_size[0], 8)  # randrange needs integer bounds
74 |         down_h = down_w = random_scale_factor
75 | 
76 |         if self.transforms:
77 |             inp_img = self.transforms(inp_img)
78 |             gt_img = self.transforms(gt_img)
79 | 
80 |         if self.crop:
81 |             inp_img, gt_img = self.crop_image(inp_img, gt_img)
82 | 
83 |         return inp_img, gt_img, down_h, down_w, inp_img_path  # always return, even if neither transforms nor crop is set
84 | 
85 |     def __len__(self):
86 |         return len(self.gt_imgs)
87 | 
88 |     def crop_image(self, inp_img, gt_img):
89 |         crop_h, crop_w = self.img_size
90 |         i, j, h, w = tf.RandomCrop.get_params(
91 |             inp_img, output_size=((crop_h, crop_w)))
92 |         inp_img = TF.crop(inp_img, i, j, h, w)
93 |         gt_img = TF.crop(gt_img, i, j, h, w)
94 |         inp_img = TF.to_tensor(inp_img)
95 |         gt_img = TF.to_tensor(gt_img)
96 | 
97 |         return inp_img, gt_img
98 | 
99 | 
100 | def get_loader(data_dir, img_size, transforms, crop_flag, batch_size, num_workers, shuffle, random_flag=False):
101 |     if random_flag:
102 |         dataset = random_scale_dataset(data_dir, img_size, transforms, crop_flag)
103 |         dataloader = DataLoader(dataset, batch_size=batch_size,
104 |                                 shuffle=shuffle, num_workers=num_workers, pin_memory=True)
105 |     else:
106 |         dataset = base_dataset(data_dir, img_size, transforms, crop_flag)
107 |         dataloader = DataLoader(dataset, batch_size=batch_size,
108 |                                 shuffle=shuffle, num_workers=num_workers, pin_memory=True)
109 |     return dataloader
110 | 
--------------------------------------------------------------------------------
/Dehazing/metrics.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from PIL import Image
4 | import torchvision.transforms as transforms
5 | 
6 | 
7 | def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
8 |     """Calculate PSNR (Peak Signal-to-Noise Ratio).
9 | 
10 |     Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
11 | 
12 |     Args:
13 |         img (ndarray): Images with range [0, 255].
14 |         img2 (ndarray): Images with range [0, 255].
15 |         crop_border (int): Cropped pixels in each edge of an image. These
16 |             pixels are not involved in the PSNR calculation.
17 |         input_order (str): Whether the input order is 'HWC' or 'CHW'.
18 |             Default: 'HWC'.
19 |         test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
20 | 
21 |     Returns:
22 |         float: psnr result.
23 |     """
24 | 
25 |     assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
26 |     if input_order not in ['HWC', 'CHW']:
27 |         raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
28 |     img = img.astype(np.float64)
29 |     img2 = img2.astype(np.float64)
30 | 
31 |     if crop_border != 0:
32 |         img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
33 |         img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
34 | 
35 |     if test_y_channel:
36 |         img = to_y_channel(img)
37 |         img2 = to_y_channel(img2)
38 | 
39 |     mse = np.mean((img - img2)**2)
40 |     if mse == 0:
41 |         return float('inf')
42 |     return 20.
* np.log10(255. / np.sqrt(mse)) 43 | 44 | 45 | def _ssim(img, img2): 46 | """Calculate SSIM (structural similarity) for one channel images. 47 | 48 | It is called by func:`calculate_ssim`. 49 | 50 | Args: 51 | img (ndarray): Images with range [0, 255] with order 'HWC'. 52 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 53 | 54 | Returns: 55 | float: ssim result. 56 | """ 57 | 58 | c1 = (0.01 * 255)**2 59 | c2 = (0.03 * 255)**2 60 | 61 | img = img.astype(np.float64) 62 | img2 = img2.astype(np.float64) 63 | kernel = cv2.getGaussianKernel(11, 1.5) 64 | window = np.outer(kernel, kernel.transpose()) 65 | 66 | mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] 67 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 68 | mu1_sq = mu1**2 69 | mu2_sq = mu2**2 70 | mu1_mu2 = mu1 * mu2 71 | sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq 72 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 73 | sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 74 | 75 | ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) 76 | return ssim_map.mean() 77 | 78 | def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): 79 | """Calculate SSIM (structural similarity). 80 | 81 | Ref: 82 | Image quality assessment: From error visibility to structural similarity 83 | 84 | The results are the same as that of the official released MATLAB code in 85 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 86 | 87 | For three-channel images, SSIM is calculated for each channel and then 88 | averaged. 89 | 90 | Args: 91 | img (ndarray): Images with range [0, 255]. 92 | img2 (ndarray): Images with range [0, 255]. 93 | crop_border (int): Cropped pixels in each edge of an image. These 94 | pixels are not involved in the SSIM calculation. 95 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 96 | Default: 'HWC'. 97 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 98 | 99 | Returns: 100 | float: ssim result. 101 | """ 102 | 103 | assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') 104 | if input_order not in ['HWC', 'CHW']: 105 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 106 | # img = reorder_image(img, input_order=input_order) 107 | # img2 = reorder_image(img2, input_order=input_order) 108 | img = img.astype(np.float64) 109 | img2 = img2.astype(np.float64) 110 | 111 | if crop_border != 0: 112 | img = img[crop_border:-crop_border, crop_border:-crop_border, ...] 113 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 
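# (optional) restrict the comparison to the luma (Y) channel of YCbCr, matching the common MATLAB evaluation protocol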
114 | 115 | if test_y_channel: 116 | img = to_y_channel(img) 117 | img2 = to_y_channel(img2) 118 | 119 | ssims = [] 120 | for i in range(img.shape[2]): 121 | ssims.append(_ssim(img[..., i], img2[..., i])) 122 | return np.array(ssims).mean() 123 | 124 | if __name__ == '__main__': 125 | 126 | # test_transforms = transforms.Compose([transforms.Resize((512, 512)),transforms.ToTensor()]) 127 | # inp_img = Image.open("/mnt/disk1/yuwei/data/4Kdehaze/train/clear/0_000002.jpg").convert("RGB") 128 | # img = test_transforms(inp_img) 129 | img = cv2.imread("/mnt/disk1/yuwei/data/4Kdehaze/train/clear/0_000002.jpg") 130 | psnr = calculate_psnr(img, img, 0) 131 | ssim = calculate_ssim(img, img, 0) 132 | print(psnr) 133 | print(ssim) -------------------------------------------------------------------------------- /Dehazing/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * 2 | from .final_model import * 3 | from .module import * -------------------------------------------------------------------------------- /Dehazing/model/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import cv2 5 | from skimage import feature 6 | from torchvision.transforms.functional import rgb_to_grayscale 7 | 8 | 9 | def gaussian_2d(size, fwhm=3, center=None): 10 | x = np.arange(0, size, 1, float) 11 | y = x[:, np.newaxis] 12 | 13 | if center is None: 14 | x0 = y0 = size // 2 15 | else: 16 | x0 = center[0] 17 | y0 = center[1] 18 | 19 | return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / fwhm ** 2) 20 | 21 | 22 | def coord_pos_norm(grid): 23 | # grid: [-1, 1] 24 | grid = (grid + 1) / 2 25 | grid = torch.clamp(grid, 0, 1) 26 | return grid 27 | 28 | 29 | def coord_neg_norm(grid): 30 | # grid: [0, 1] 31 | grid = grid * 2 - 1 32 | grid = torch.clamp(grid, -1, 1) 33 | return grid 34 | 35 | 36 | def grid_offset(coord, offset): 37 | # grid: b, 2, h, w 38 | x_cor = coord[:, 0, :, :] 39 | y_cor = coord[:, 1, :, :] 40 | x_cor += offset[:, :1, :, :] 41 | y_cor += offset[:, 1:, :, :] 42 | 43 | offseted_coord = torch.cat([x_cor, y_cor], dim=1) 44 | 45 | return offseted_coord 46 | 47 | 48 | def invert(grid): # h, w, 2 49 | I = np.zeros_like(grid) 50 | I[:, :, 1], I[:, :, 0] = np.indices((grid.shape[0], grid.shape[1])) 51 | P = np.copy(I) 52 | for i in range(5): 53 | P += (I - cv2.remap(grid, P, None, interpolation=cv2.INTER_LINEAR)) * 0.5 54 | return P 55 | -------------------------------------------------------------------------------- /Dehazing/readme.md: -------------------------------------------------------------------------------- 1 | ## Applications 2 | ### Image Dehazing 3 | #### Prepare data 4 | Download the training data and add the data path to the config file (config/config.yaml). Please refer to [RESIDE](https://github.com/BookerDeWitt/MSBDN-DFF) for data download. 5 | 6 | #### Training 7 | 8 | Set path 'output_dir', 'data' in config/config.yaml. 
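For reference, the fields to edit look like this (excerpt from config/config.yaml in this repo; point the paths at your local RESIDE layout):
```yaml
output_dir: 'experiment/upsample/aod_4'
data:
  train_dir: data/ITS/
  test_dir: data/SOTS/indoor
```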
9 | Run: 10 | ``` 11 | python train.py 12 | ``` 13 | -------------------------------------------------------------------------------- /Dehazing/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import torchvision.transforms as transforms 4 | from utils import read_args, save_checkpoint, AverageMeter, calculate_metrics, CosineAnnealingWarmRestarts 5 | # import torchvision.transforms.InterpolationMode 6 | import time 7 | from tqdm import trange, tqdm 8 | from torchvision.utils import save_image 9 | import os 10 | 11 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 12 | import json 13 | import time 14 | import logging 15 | import torch 16 | from torch import nn, optim 17 | import numpy as np 18 | import torch.nn.functional as F 19 | 20 | from model import * 21 | from data import * 22 | from PIL import Image 23 | from torchvision.transforms import Resize 24 | import pyiqa 25 | from thop import profile 26 | from thop import clever_format 27 | 28 | psnr_calculator = pyiqa.create_metric('psnr').cuda() 29 | ssim_calculator = pyiqa.create_metric('ssimc', downsample=True).cuda() 30 | lpips_calculator = pyiqa.create_metric('lpips').cuda() 31 | niqe_calculator = pyiqa.create_metric('niqe').cuda() 32 | 33 | 34 | def test(load_path, data_loader, args): 35 | if not os.path.exists(args.output_dir + '/image_test'): 36 | os.mkdir(args.output_dir + '/image_test') 37 | 38 | save_path = args.output_dir + '/image_test' 39 | 40 | model = guide_net(args.model["model_channel"]) 41 | checkpoint = torch.load(load_path) 42 | model.load_state_dict(checkpoint["state_dict"]) 43 | model.cuda() 44 | model.eval() 45 | 46 | psnrs = AverageMeter() 47 | ssims = AverageMeter() 48 | lpipss = AverageMeter() 49 | niqes = AverageMeter() 50 | 51 | start_time = time.time() 52 | with torch.no_grad(): 53 | for i, batch in enumerate(tqdm(data_loader)): 54 | input_img, gt_img, inp_img_path = batch 55 | 56 | name = inp_img_path[0].split("/")[-1] 57 | input_img = input_img.cuda() 58 | batch_size = input_img.size(0) 59 | start_time = time.time() 60 | output, _ = model(input_img) 61 | 62 | # metrics 63 | clamped_out = torch.clamp(output, 0, 1) 64 | psnr_val, ssim_val = psnr_calculator(clamped_out, gt_img), ssim_calculator(clamped_out, gt_img) 65 | psnrs.update(torch.mean(psnr_val).item(), batch_size) 66 | ssims.update(torch.mean(ssim_val).item(), batch_size) 67 | 68 | save_image(clamped_out[0], os.path.join(save_path, name)) 69 | # lpips = lpips_calculator(clamped_out, gt_img) 70 | # lpipss.update(torch.mean(lpips).item(), batch_size) 71 | # niqe = niqe_calculator(clamped_out) 72 | # niqes.update(torch.mean(niqe).item(), batch_size) 73 | torch.cuda.empty_cache() 74 | 75 | if i % 20 == 0: 76 | logging.info( 77 | "PSNR {:.4f}, SSIM {:.4f}, LPIPS {:.4F}, NIQE {:.4F}, Elapse time {:.2f}\n".format(psnrs.avg, 78 | ssims.avg, 79 | lpipss.avg, 80 | niqes.avg, 81 | time.time() - start_time)) 82 | 83 | logging.info( 84 | "Finish test: avg PSNR: %.4f, avg SSIM: %.4F, avg LPIPS: %.4F, avg NIQE: %.4F, and takes %.2f seconds" % ( 85 | psnrs.avg, ssims.avg, lpipss.avg, niqes.avg, time.time() - start_time)) 86 | 87 | 88 | def main(args, load_path): 89 | if not os.path.exists(args.output_dir): 90 | os.mkdir(args.output_dir) 91 | test_transforms = transforms.Compose([transforms.ToTensor()]) 92 | 93 | log_format = "%(asctime)s %(levelname)-8s %(message)s" 94 | log_file = os.path.join(args.output_dir, "test_log") 95 | logging.basicConfig(filename=log_file, level=logging.INFO, 
format=log_format) 96 | logging.getLogger().addHandler(logging.StreamHandler()) 97 | 98 | logging.info("Building data loader") 99 | 100 | test_loader = get_loader(args.data["test_dir"], 101 | eval(args.test_loader["img_size"]), test_transforms, False, 102 | int(args.test_loader["batch_size"]), args.test_loader["num_workers"], 103 | args.test_loader["shuffle"], random_flag=False) 104 | test(load_path, test_loader, args) 105 | 106 | 107 | if __name__ == '__main__': 108 | parser = read_args("config.yaml") 109 | args = parser.parse_args() 110 | main(args, "model.pth") 111 | -------------------------------------------------------------------------------- /Deraining/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/Deraining/.DS_Store -------------------------------------------------------------------------------- /Deraining/VERSION: -------------------------------------------------------------------------------- 1 | 1.2.0 2 | -------------------------------------------------------------------------------- /Deraining/basicsr.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: basicsr 3 | Version: 1.2.0+unknown 4 | Summary: Open Source Image and Video Super-Resolution Toolbox 5 | Home-page: https://github.com/xinntao/BasicSR 6 | Author: Xintao Wang 7 | Author-email: xintao.wang@outlook.com 8 | License: Apache License 2.0 9 | Keywords: computer vision,restoration,super resolution 10 | Classifier: Development Status :: 4 - Beta 11 | Classifier: License :: OSI Approved :: Apache Software License 12 | Classifier: Operating System :: OS Independent 13 | Classifier: Programming Language :: Python :: 3 14 | Classifier: Programming Language :: Python :: 3.7 15 | Classifier: Programming Language :: Python :: 3.8 16 | -------------------------------------------------------------------------------- /Deraining/basicsr.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | basicsr/__init__.py 3 | basicsr/demo.py 4 | basicsr/test.py 5 | basicsr/train.py 6 | basicsr/train_rain.py 7 | basicsr/version.py 8 | basicsr.egg-info/PKG-INFO 9 | basicsr.egg-info/SOURCES.txt 10 | basicsr.egg-info/dependency_links.txt 11 | basicsr.egg-info/not-zip-safe 12 | basicsr.egg-info/top_level.txt 13 | basicsr/data/__init__.py 14 | basicsr/data/data_sampler.py 15 | basicsr/data/data_util.py 16 | basicsr/data/ffhq_dataset.py 17 | basicsr/data/paired_image_dataset.py 18 | basicsr/data/prefetch_dataloader.py 19 | basicsr/data/reds_dataset.py 20 | basicsr/data/single_image_dataset.py 21 | basicsr/data/transforms.py 22 | basicsr/metrics/__init__.py 23 | basicsr/metrics/fid.py 24 | basicsr/metrics/metric_util.py 25 | basicsr/metrics/niqe.py 26 | basicsr/metrics/psnr_ssim.py 27 | basicsr/models/__init__.py 28 | basicsr/models/base_model.py 29 | basicsr/models/image_restoration_model.py 30 | basicsr/models/lpnet_model.py 31 | basicsr/models/lr_scheduler.py 32 | basicsr/models/prenet_model.py 33 | basicsr/models/archs/LPNet_corner_arch.py 34 | basicsr/models/archs/LPNet_pad_arch.py 35 | basicsr/models/archs/LPNet_v1_arch.py 36 | basicsr/models/archs/LPNet_v2_arch.py 37 | basicsr/models/archs/__init__.py 38 | basicsr/models/archs/arch_util.py 39 | basicsr/models/archs/common.py 40 | basicsr/models/archs/layers.py 41 | basicsr/models/losses/__init__.py 42 | 
basicsr/models/losses/loss_util.py 43 | basicsr/models/losses/losses.py 44 | basicsr/utils/__init__.py 45 | basicsr/utils/create_lmdb.py 46 | basicsr/utils/dist_util.py 47 | basicsr/utils/download_util.py 48 | basicsr/utils/face_util.py 49 | basicsr/utils/file_client.py 50 | basicsr/utils/flow_util.py 51 | basicsr/utils/img_util.py 52 | basicsr/utils/lmdb_util.py 53 | basicsr/utils/logger.py 54 | basicsr/utils/matlab_functions.py 55 | basicsr/utils/misc.py 56 | basicsr/utils/options.py 57 | basicsr/utils/show.py -------------------------------------------------------------------------------- /Deraining/basicsr.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Deraining/basicsr.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Deraining/basicsr.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | basicsr 2 | -------------------------------------------------------------------------------- /Deraining/basicsr/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/Deraining/basicsr/.DS_Store -------------------------------------------------------------------------------- /Deraining/basicsr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/Deraining/basicsr/__init__.py -------------------------------------------------------------------------------- /Deraining/basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | 8 | import importlib 9 | import numpy as np 10 | import random 11 | import torch 12 | import torch.utils.data 13 | from functools import partial 14 | from os import path as osp 15 | 16 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 17 | from basicsr.utils import get_root_logger, scandir 18 | from basicsr.utils.dist_util import get_dist_info 19 | 20 | __all__ = ['create_dataset', 'create_dataloader'] 21 | 22 | # automatically scan and import dataset modules 23 | # scan all the files under the data folder with '_dataset' in file names 24 | data_folder = osp.dirname(osp.abspath(__file__)) 25 | dataset_filenames = [ 26 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) 27 | if v.endswith('_dataset.py') 28 | ] 29 | # import all the dataset modules 30 | _dataset_modules = [ 31 | importlib.import_module(f'basicsr.data.{file_name}') 32 | for file_name in dataset_filenames 33 | ] 34 | 35 | 36 | def create_dataset(dataset_opt): 37 | """Create dataset. 38 | 39 | Args: 40 | dataset_opt (dict): Configuration for dataset. It constains: 41 | name (str): Dataset name. 42 | type (str): Dataset type. 
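            Example (illustrative only; the remaining keys depend on the
            dataset class being instantiated, e.g. PairedImageDataset from
            this repo, which additionally expects keys such as dataroot_gt
            and dataroot_lq):
                dataset_opt = dict(name='demo', type='PairedImageDataset')
                dataset = create_dataset(dataset_opt)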
43 | """ 44 | dataset_type = dataset_opt['type'] 45 | 46 | # dynamic instantiation 47 | for module in _dataset_modules: 48 | dataset_cls = getattr(module, dataset_type, None) 49 | if dataset_cls is not None: 50 | break 51 | if dataset_cls is None: 52 | raise ValueError(f'Dataset {dataset_type} is not found.') 53 | 54 | dataset = dataset_cls(dataset_opt) 55 | 56 | logger = get_root_logger() 57 | logger.info( 58 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} ' 59 | 'is created.') 60 | return dataset 61 | 62 | 63 | def create_dataloader(dataset, 64 | dataset_opt, 65 | num_gpu=1, 66 | dist=False, 67 | sampler=None, 68 | seed=None): 69 | """Create dataloader. 70 | 71 | Args: 72 | dataset (torch.utils.data.Dataset): Dataset. 73 | dataset_opt (dict): Dataset options. It contains the following keys: 74 | phase (str): 'train' or 'val'. 75 | num_worker_per_gpu (int): Number of workers for each GPU. 76 | batch_size_per_gpu (int): Training batch size for each GPU. 77 | num_gpu (int): Number of GPUs. Used only in the train phase. 78 | Default: 1. 79 | dist (bool): Whether in distributed training. Used only in the train 80 | phase. Default: False. 81 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 82 | seed (int | None): Seed. Default: None 83 | """ 84 | phase = dataset_opt['phase'] 85 | rank, _ = get_dist_info() 86 | if phase == 'train': 87 | if dist: # distributed training 88 | batch_size = dataset_opt['batch_size_per_gpu'] 89 | num_workers = dataset_opt['num_worker_per_gpu'] 90 | else: # non-distributed training 91 | multiplier = 1 if num_gpu == 0 else num_gpu 92 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 93 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 94 | dataloader_args = dict( 95 | dataset=dataset, 96 | batch_size=batch_size, 97 | shuffle=False, 98 | num_workers=num_workers, 99 | sampler=sampler, 100 | drop_last=True) 101 | if sampler is None: 102 | dataloader_args['shuffle'] = True 103 | dataloader_args['worker_init_fn'] = partial( 104 | worker_init_fn, num_workers=num_workers, rank=rank, 105 | seed=seed) if seed is not None else None 106 | elif phase in ['val', 'test']: # validation 107 | dataloader_args = dict( 108 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 109 | else: 110 | raise ValueError(f'Wrong dataset phase: {phase}. 
' 111 | "Supported ones are 'train', 'val' and 'test'.") 112 | 113 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 114 | 115 | prefetch_mode = dataset_opt.get('prefetch_mode') 116 | if prefetch_mode == 'cpu': # CPUPrefetcher 117 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 118 | logger = get_root_logger() 119 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' 120 | f'num_prefetch_queue = {num_prefetch_queue}') 121 | return PrefetchDataLoader( 122 | num_prefetch_queue=num_prefetch_queue, **dataloader_args) 123 | else: 124 | # prefetch_mode=None: Normal dataloader 125 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 126 | return torch.utils.data.DataLoader(**dataloader_args) 127 | 128 | 129 | def worker_init_fn(worker_id, num_workers, rank, seed): 130 | # Set the worker seed to num_workers * rank + worker_id + seed 131 | worker_seed = num_workers * rank + worker_id + seed 132 | np.random.seed(worker_seed) 133 | random.seed(worker_seed) 134 | -------------------------------------------------------------------------------- /Deraining/basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | 8 | import math 9 | import torch 10 | from torch.utils.data.sampler import Sampler 11 | 12 | 13 | class EnlargedSampler(Sampler): 14 | """Sampler that restricts data loading to a subset of the dataset. 15 | 16 | Modified from torch.utils.data.distributed.DistributedSampler 17 | Support enlarging the dataset for iteration-based training, for saving 18 | time when restart the dataloader after each epoch 19 | 20 | Args: 21 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 22 | num_replicas (int | None): Number of processes participating in 23 | the training. It is usually the world_size. 24 | rank (int | None): Rank of the current process within num_replicas. 25 | ratio (int): Enlarging ratio. Default: 1. 
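        Example (illustrative, single-process training): with ratio=100, one
        "epoch" of indices spans 100x the dataset, so the dataloader is
        restarted far less often:
            sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=100)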
26 | """ 27 | 28 | def __init__(self, dataset, num_replicas, rank, ratio=1): 29 | self.dataset = dataset 30 | self.num_replicas = num_replicas 31 | self.rank = rank 32 | self.epoch = 0 33 | self.num_samples = math.ceil( 34 | len(self.dataset) * ratio / self.num_replicas) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | indices = torch.randperm(self.total_size, generator=g).tolist() 42 | 43 | dataset_size = len(self.dataset) 44 | indices = [v % dataset_size for v in indices] 45 | 46 | # subsample 47 | indices = indices[self.rank:self.total_size:self.num_replicas] 48 | assert len(indices) == self.num_samples 49 | 50 | return iter(indices) 51 | 52 | def __len__(self): 53 | return self.num_samples 54 | 55 | def set_epoch(self, epoch): 56 | self.epoch = epoch 57 | -------------------------------------------------------------------------------- /Deraining/basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | from os import path as osp 8 | from torch.utils import data as data 9 | from torchvision.transforms.functional import normalize 10 | 11 | from basicsr.data.transforms import augment 12 | from basicsr.utils import FileClient, imfrombytes, img2tensor 13 | 14 | 15 | class FFHQDataset(data.Dataset): 16 | """FFHQ dataset for StyleGAN. 17 | 18 | Args: 19 | opt (dict): Config for train datasets. It contains the following keys: 20 | dataroot_gt (str): Data root path for gt. 21 | io_backend (dict): IO backend type and other kwarg. 22 | mean (list | tuple): Image mean. 23 | std (list | tuple): Image std. 24 | use_hflip (bool): Whether to horizontally flip. 
25 | 26 | """ 27 | 28 | def __init__(self, opt): 29 | super(FFHQDataset, self).__init__() 30 | self.opt = opt 31 | # file client (io backend) 32 | self.file_client = None 33 | self.io_backend_opt = opt['io_backend'] 34 | 35 | self.gt_folder = opt['dataroot_gt'] 36 | self.mean = opt['mean'] 37 | self.std = opt['std'] 38 | 39 | if self.io_backend_opt['type'] == 'lmdb': 40 | self.io_backend_opt['db_paths'] = self.gt_folder 41 | if not self.gt_folder.endswith('.lmdb'): 42 | raise ValueError("'dataroot_gt' should end with '.lmdb', " 43 | f'but received {self.gt_folder}') 44 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 45 | self.paths = [line.split('.')[0] for line in fin] 46 | else: 47 | # FFHQ has 70000 images in total 48 | self.paths = [ 49 | osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000) 50 | ] 51 | 52 | def __getitem__(self, index): 53 | if self.file_client is None: 54 | self.file_client = FileClient( 55 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 56 | 57 | # load gt image 58 | gt_path = self.paths[index] 59 | img_bytes = self.file_client.get(gt_path) 60 | img_gt = imfrombytes(img_bytes, float32=True) 61 | 62 | # random horizontal flip 63 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 64 | # BGR to RGB, HWC to CHW, numpy to tensor 65 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 66 | # normalize 67 | normalize(img_gt, self.mean, self.std, inplace=True) 68 | return {'gt': img_gt, 'gt_path': gt_path} 69 | 70 | def __len__(self): 71 | return len(self.paths) 72 | -------------------------------------------------------------------------------- /Deraining/basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import queue as Queue 8 | import threading 9 | import torch 10 | from torch.utils.data import DataLoader 11 | 12 | 13 | class PrefetchGenerator(threading.Thread): 14 | """A general prefetch generator. 15 | 16 | Ref: 17 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 18 | 19 | Args: 20 | generator: Python generator. 21 | num_prefetch_queue (int): Number of prefetch queue. 22 | """ 23 | 24 | def __init__(self, generator, num_prefetch_queue): 25 | threading.Thread.__init__(self) 26 | self.queue = Queue.Queue(num_prefetch_queue) 27 | self.generator = generator 28 | self.daemon = True 29 | self.start() 30 | 31 | def run(self): 32 | for item in self.generator: 33 | self.queue.put(item) 34 | self.queue.put(None) 35 | 36 | def __next__(self): 37 | next_item = self.queue.get() 38 | if next_item is None: 39 | raise StopIteration 40 | return next_item 41 | 42 | def __iter__(self): 43 | return self 44 | 45 | 46 | class PrefetchDataLoader(DataLoader): 47 | """Prefetch version of dataloader. 48 | 49 | Ref: 50 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 51 | 52 | TODO: 53 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 54 | ddp. 55 | 56 | Args: 57 | num_prefetch_queue (int): Number of prefetch queue. 58 | kwargs (dict): Other arguments for dataloader. 
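        Example (illustrative; the keyword arguments are forwarded to the
        underlying torch DataLoader):
            loader = PrefetchDataLoader(num_prefetch_queue=1, dataset=dataset,
                                        batch_size=4, num_workers=2)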
59 | """ 60 | 61 | def __init__(self, num_prefetch_queue, **kwargs): 62 | self.num_prefetch_queue = num_prefetch_queue 63 | super(PrefetchDataLoader, self).__init__(**kwargs) 64 | 65 | def __iter__(self): 66 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 67 | 68 | 69 | class CPUPrefetcher(): 70 | """CPU prefetcher. 71 | 72 | Args: 73 | loader: Dataloader. 74 | """ 75 | 76 | def __init__(self, loader): 77 | self.ori_loader = loader 78 | self.loader = iter(loader) 79 | 80 | def next(self): 81 | try: 82 | return next(self.loader) 83 | except StopIteration: 84 | return None 85 | 86 | def reset(self): 87 | self.loader = iter(self.ori_loader) 88 | 89 | 90 | class CUDAPrefetcher(): 91 | """CUDA prefetcher. 92 | 93 | Ref: 94 | https://github.com/NVIDIA/apex/issues/304# 95 | 96 | It may consums more GPU memory. 97 | 98 | Args: 99 | loader: Dataloader. 100 | opt (dict): Options. 101 | """ 102 | 103 | def __init__(self, loader, opt): 104 | self.ori_loader = loader 105 | self.loader = iter(loader) 106 | self.opt = opt 107 | self.stream = torch.cuda.Stream() 108 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 109 | self.preload() 110 | 111 | def preload(self): 112 | try: 113 | self.batch = next(self.loader) # self.batch is a dict 114 | except StopIteration: 115 | self.batch = None 116 | return None 117 | # put tensors to gpu 118 | with torch.cuda.stream(self.stream): 119 | for k, v in self.batch.items(): 120 | if torch.is_tensor(v): 121 | self.batch[k] = self.batch[k].to( 122 | device=self.device, non_blocking=True) 123 | 124 | def next(self): 125 | torch.cuda.current_stream().wait_stream(self.stream) 126 | batch = self.batch 127 | self.preload() 128 | return batch 129 | 130 | def reset(self): 131 | self.loader = iter(self.ori_loader) 132 | self.preload() 133 | -------------------------------------------------------------------------------- /Deraining/basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | from os import path as osp 8 | from torch.utils import data as data 9 | from torchvision.transforms.functional import normalize 10 | 11 | from basicsr.data.data_util import paths_from_lmdb 12 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir 13 | 14 | 15 | class SingleImageDataset(data.Dataset): 16 | """Read only lq images in the test phase. 17 | 18 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 19 | 20 | There are two modes: 21 | 1. 'meta_info_file': Use meta information file to generate paths. 22 | 2. 'folder': Scan folders to generate paths. 23 | 24 | Args: 25 | opt (dict): Config for train datasets. It contains the following keys: 26 | dataroot_lq (str): Data root path for lq. 27 | meta_info_file (str): Path for meta information file. 28 | io_backend (dict): IO backend type and other kwarg. 
29 | """ 30 | 31 | def __init__(self, opt): 32 | super(SingleImageDataset, self).__init__() 33 | self.opt = opt 34 | # file client (io backend) 35 | self.file_client = None 36 | self.io_backend_opt = opt['io_backend'] 37 | self.mean = opt['mean'] if 'mean' in opt else None 38 | self.std = opt['std'] if 'std' in opt else None 39 | self.lq_folder = opt['dataroot_lq'] 40 | 41 | if self.io_backend_opt['type'] == 'lmdb': 42 | self.io_backend_opt['db_paths'] = [self.lq_folder] 43 | self.io_backend_opt['client_keys'] = ['lq'] 44 | self.paths = paths_from_lmdb(self.lq_folder) 45 | elif 'meta_info_file' in self.opt: 46 | with open(self.opt['meta_info_file'], 'r') as fin: 47 | self.paths = [ 48 | osp.join(self.lq_folder, 49 | line.split(' ')[0]) for line in fin 50 | ] 51 | else: 52 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 53 | 54 | def __getitem__(self, index): 55 | if self.file_client is None: 56 | self.file_client = FileClient( 57 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 58 | 59 | # load lq image 60 | lq_path = self.paths[index] 61 | img_bytes = self.file_client.get(lq_path, 'lq') 62 | img_lq = imfrombytes(img_bytes, float32=True) 63 | 64 | # TODO: color space transform 65 | # BGR to RGB, HWC to CHW, numpy to tensor 66 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 67 | # normalize 68 | if self.mean is not None or self.std is not None: 69 | normalize(img_lq, self.mean, self.std, inplace=True) 70 | return {'lq': img_lq, 'lq_path': lq_path} 71 | 72 | def __len__(self): 73 | return len(self.paths) 74 | -------------------------------------------------------------------------------- /Deraining/basicsr/demo.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import torch 8 | 9 | # from basicsr.data import create_dataloader, create_dataset 10 | from basicsr.models import create_model 11 | from basicsr.train import parse_options 12 | from basicsr.utils import FileClient, imfrombytes, img2tensor, padding 13 | 14 | # from basicsr.utils import (get_env_info, get_root_logger, get_time_str, 15 | # make_exp_dirs) 16 | # from basicsr.utils.options import dict2str 17 | 18 | def main(): 19 | # parse options, set distributed setting, set ramdom seed 20 | opt = parse_options(is_train=False) 21 | 22 | img_path = opt['img_path'].get('input_img') 23 | output_path = opt['img_path'].get('output_img') 24 | 25 | 26 | ## 1. read image 27 | file_client = FileClient('disk') 28 | 29 | img_bytes = file_client.get(img_path, None) 30 | try: 31 | img = imfrombytes(img_bytes, float32=True) 32 | except: 33 | raise Exception("path {} not working".format(img_path)) 34 | 35 | img = img2tensor(img, bgr2rgb=True, float32=True) 36 | 37 | 38 | 39 | ## 2. run inference 40 | model = create_model(opt) 41 | model.single_image_inference(img, output_path) 42 | 43 | print('inference {} .. 
finished.'.format(img_path)) 44 | 45 | if __name__ == '__main__': 46 | main() 47 | 48 | -------------------------------------------------------------------------------- /Deraining/basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | from .niqe import calculate_niqe 8 | from .psnr_ssim import calculate_psnr, calculate_ssim 9 | 10 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 11 | -------------------------------------------------------------------------------- /Deraining/basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from scipy import linalg 11 | from tqdm import tqdm 12 | 13 | from basicsr.models.archs.inception import InceptionV3 14 | 15 | 16 | def load_patched_inception_v3(device='cuda', 17 | resize_input=True, 18 | normalize_input=False): 19 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 20 | # does resize the input. 21 | inception = InceptionV3([3], 22 | resize_input=resize_input, 23 | normalize_input=normalize_input) 24 | inception = nn.DataParallel(inception).eval().to(device) 25 | return inception 26 | 27 | 28 | @torch.no_grad() 29 | def extract_inception_features(data_generator, 30 | inception, 31 | len_generator=None, 32 | device='cuda'): 33 | """Extract inception features. 34 | 35 | Args: 36 | data_generator (generator): A data generator. 37 | inception (nn.Module): Inception model. 38 | len_generator (int): Length of the data_generator to show the 39 | progressbar. Default: None. 40 | device (str): Device. Default: cuda. 41 | 42 | Returns: 43 | Tensor: Extracted features. 44 | """ 45 | if len_generator is not None: 46 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 47 | else: 48 | pbar = None 49 | features = [] 50 | 51 | for data in data_generator: 52 | if pbar: 53 | pbar.update(1) 54 | data = data.to(device) 55 | feature = inception(data)[0].view(data.shape[0], -1) 56 | features.append(feature.to('cpu')) 57 | if pbar: 58 | pbar.close() 59 | features = torch.cat(features, 0) 60 | return features 61 | 62 | 63 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 64 | """Numpy implementation of the Frechet Distance. 65 | 66 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 67 | and X_2 ~ N(mu_2, C_2) is 68 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 69 | Stable version by Dougal J. Sutherland. 70 | 71 | Args: 72 | mu1 (np.array): The sample mean over activations. 73 | sigma1 (np.array): The covariance matrix over activations for 74 | generated samples. 
75 |         mu2 (np.array): The sample mean over activations, precalculated on a
76 |             representative data set.
77 |         sigma2 (np.array): The covariance matrix over activations,
78 |             precalculated on a representative data set.
79 | 
80 |     Returns:
81 |         float: The Frechet Distance.
82 |     """
83 |     assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
84 |     assert sigma1.shape == sigma2.shape, (
85 |         'Two covariances have different dimensions')
86 | 
87 |     cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
88 | 
89 |     # Product might be almost singular
90 |     if not np.isfinite(cov_sqrt).all():
91 |         print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
92 |               'of cov estimates')
93 |         offset = np.eye(sigma1.shape[0]) * eps
94 |         cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
95 | 
96 |     # Numerical error might give slight imaginary component
97 |     if np.iscomplexobj(cov_sqrt):
98 |         if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
99 |             m = np.max(np.abs(cov_sqrt.imag))
100 |             raise ValueError(f'Imaginary component {m}')
101 |         cov_sqrt = cov_sqrt.real
102 | 
103 |     mean_diff = mu1 - mu2
104 |     mean_norm = mean_diff @ mean_diff
105 |     trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
106 |     fid = mean_norm + trace
107 | 
108 |     return fid
109 | 
--------------------------------------------------------------------------------
/Deraining/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import numpy as np
8 | 
9 | from basicsr.utils.matlab_functions import bgr2ycbcr
10 | 
11 | 
12 | def reorder_image(img, input_order='HWC'):
13 |     """Reorder images to 'HWC' order.
14 | 
15 |     If the input_order is (h, w), return (h, w, 1);
16 |     If the input_order is (c, h, w), return (h, w, c);
17 |     If the input_order is (h, w, c), return as it is.
18 | 
19 |     Args:
20 |         img (ndarray): Input image.
21 |         input_order (str): Whether the input order is 'HWC' or 'CHW'.
22 |             If the input image shape is (h, w), input_order will not have
23 |             effects. Default: 'HWC'.
24 | 
25 |     Returns:
26 |         ndarray: reordered image.
27 |     """
28 | 
29 |     if input_order not in ['HWC', 'CHW']:
30 |         raise ValueError(
31 |             f'Wrong input_order {input_order}. Supported input_orders are '
32 |             "'HWC' and 'CHW'")
33 |     if len(img.shape) == 2:
34 |         img = img[..., None]
35 |     if input_order == 'CHW':
36 |         img = img.transpose(1, 2, 0)
37 |     return img
38 | 
39 | 
40 | def to_y_channel(img):
41 |     """Change to Y channel of YCbCr.
42 | 
43 |     Args:
44 |         img (ndarray): Images with range [0, 255].
45 | 
46 |     Returns:
47 |         (ndarray): Images with range [0, 255] (float type) without round.
48 |     """
49 |     img = img.astype(np.float32) / 255.
50 |     if img.ndim == 3 and img.shape[2] == 3:
51 |         img = bgr2ycbcr(img, y_only=True)
52 |         img = img[..., None]
53 |     return img * 255.
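# Example (hypothetical values): for an (h, w, 3) BGR uint8 image `img`,
# to_y_channel(img) returns an (h, w, 1) float array still scaled to [0, 255].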
54 | 
--------------------------------------------------------------------------------
/Deraining/basicsr/metrics/niqe_pris_params.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/Deraining/basicsr/metrics/niqe_pris_params.npz
--------------------------------------------------------------------------------
/Deraining/basicsr/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/Deraining/basicsr/models/.DS_Store
--------------------------------------------------------------------------------
/Deraining/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import importlib
8 | from os import path as osp
9 | 
10 | from basicsr.utils import get_root_logger, scandir
11 | 
12 | # automatically scan and import model modules
13 | # scan all the files under the 'models' folder and collect files ending with
14 | # '_model.py'
15 | model_folder = osp.dirname(osp.abspath(__file__))
16 | model_filenames = [
17 |     osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
18 |     if v.endswith('_model.py')
19 | ]
20 | # import all the model modules
21 | _model_modules = [
22 |     importlib.import_module(f'basicsr.models.{file_name}')
23 |     for file_name in model_filenames
24 | ]
25 | 
26 | 
27 | def create_model(opt):
28 |     """Create model.
29 | 
30 |     Args:
31 |         opt (dict): Configuration. It contains:
32 |             model_type (str): Model type.
33 |     """
34 |     model_type = opt['model_type']
35 |     # dynamic instantiation
36 |     for module in _model_modules:
37 |         model_cls = getattr(module, model_type, None)
38 |         if model_cls is not None:
39 |             break
40 |     if model_cls is None:
41 |         raise ValueError(f'Model {model_type} is not found.')
42 | 
43 |     model = model_cls(opt)
44 | 
45 |     logger = get_root_logger()
46 |     logger.info(f'Model [{model.__class__.__name__}] is created.')
47 |     return model
48 | 
--------------------------------------------------------------------------------
/Deraining/basicsr/models/archs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/Deraining/basicsr/models/archs/.DS_Store
--------------------------------------------------------------------------------
/Deraining/basicsr/models/archs/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import importlib 8 | from os import path as osp 9 | 10 | from basicsr.utils import scandir 11 | 12 | # automatically scan and import arch modules 13 | # scan all the files under the 'archs' folder and collect files ending with 14 | # '_arch.py' 15 | arch_folder = osp.dirname(osp.abspath(__file__)) 16 | arch_filenames = [ 17 | osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) 18 | if v.endswith('_arch.py') 19 | ] 20 | # import all the arch modules 21 | _arch_modules = [ 22 | importlib.import_module(f'basicsr.models.archs.{file_name}') 23 | for file_name in arch_filenames 24 | ] 25 | 26 | 27 | def dynamic_instantiation(modules, cls_type, opt): 28 | """Dynamically instantiate class. 29 | 30 | Args: 31 | modules (list[importlib modules]): List of modules from importlib 32 | files. 33 | cls_type (str): Class type. 34 | opt (dict): Class initialization kwargs. 35 | 36 | Returns: 37 | class: Instantiated class. 38 | """ 39 | 40 | for module in modules: 41 | cls_ = getattr(module, cls_type, None) 42 | if cls_ is not None: 43 | break 44 | if cls_ is None: 45 | raise ValueError(f'{cls_type} is not found.') 46 | return cls_(**opt) 47 | 48 | 49 | def define_network(opt): 50 | network_type = opt.pop('type') 51 | net = dynamic_instantiation(_arch_modules, network_type, opt) 52 | return net 53 | -------------------------------------------------------------------------------- /Deraining/basicsr/models/archs/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size//2), bias=bias) 11 | 12 | class SFT_Layer(nn.Module): 13 | ''' SFT layer ''' 14 | def __init__(self, nf=64, para=10): 15 | super(SFT_Layer, self).__init__() 16 | self.mul_conv1 = nn.Conv2d(para + nf, 32, kernel_size=3, stride=1, padding=1) 17 | self.mul_leaky = nn.LeakyReLU(0.2) 18 | self.mul_conv2 = nn.Conv2d(32, nf, kernel_size=3, stride=1, padding=1) 19 | 20 | self.add_conv1 = nn.Conv2d(para + nf, 32, kernel_size=3, stride=1, padding=1) 21 | self.add_leaky = nn.LeakyReLU(0.2) 22 | self.add_conv2 = nn.Conv2d(32, nf, kernel_size=3, stride=1, padding=1) 23 | 24 | def forward(self, feature_maps, para_maps): 25 | cat_input = torch.cat((feature_maps, para_maps), dim=1) 26 | mul = torch.sigmoid(self.mul_conv2(self.mul_leaky(self.mul_conv1(cat_input)))) 27 | add = self.add_conv2(self.add_leaky(self.add_conv1(cat_input))) 28 | return feature_maps * mul + add 29 | 30 | class DA_conv(nn.Module): 31 | def __init__(self, channels_in, channels_out, kernel_size, reduction): 32 | super(DA_conv, self).__init__() 33 | self.channels_out = channels_out 34 | self.channels_in = channels_in 35 | self.kernel_size = kernel_size 36 | 37 | self.kernel = nn.Sequential( 38 | nn.Linear(256, self.channels_in, bias=False), 39 | nn.LeakyReLU(0.1, True), 40 | nn.Linear(self.channels_in, self.channels_in * self.kernel_size * self.kernel_size, bias=False) 41 | ) 42 | 43 | self.conv = default_conv(channels_in, channels_out, 1) 44 | self.ca = CA_layer(channels_in, channels_out, reduction) 45 | 46 | 
self.relu = nn.LeakyReLU(0.1, True) 47 | 48 | def forward(self, x): 49 | ''' 50 | :param x[0]: feature map: B * C * H * W 51 | :param x[1]: degradation representation: B * C 52 | ''' 53 | # print('channels_in', self.channels_in) 54 | # print('channels_out', self.channels_out) 55 | # print('x0:', x[0].shape) 56 | # print('x1:', x[1].shape) 57 | b, c, h, w = x[0].size() 58 | # branch 1 59 | # kernel = self.kernel(x[1]).view(-1, 1, self.kernel_size, self.kernel_size) 60 | # out = self.relu(F.conv2d(x[0].view(1, -1, h, w), kernel, groups=b*c, padding=(self.kernel_size-1)//2)) 61 | # out = self.conv(out.view(b, -1, h, w)) 62 | 63 | # branch 2 64 | # print('out', out.shape) 65 | # out = out + self.ca(x) 66 | out = self.ca(x) 67 | # print('out', out.shape) 68 | return out 69 | 70 | 71 | class CA_layer(nn.Module): 72 | def __init__(self,channels_in, channels_out, reduction): 73 | super(CA_layer, self).__init__() 74 | 75 | if channels_in//reduction == 0: 76 | reduction = channels_in 77 | 78 | self.conv_du = nn.Sequential( 79 | nn.Conv2d(256, channels_in//reduction, 1, 1, 0, bias=False), 80 | nn.LeakyReLU(0.1, True), 81 | nn.Conv2d(channels_in // reduction, channels_out, 1, 1, 0, bias=False), 82 | nn.Sigmoid() 83 | ) 84 | 85 | def forward(self, x): 86 | ''' 87 | :param x[0]: feature map: B * C * H * W 88 | :param x[1]: degradation representation: B * C 89 | ''' 90 | att = self.conv_du(x[1][:, :, None, None]) 91 | return x[0] * att 92 | 93 | class MeanShift(nn.Conv2d): 94 | def __init__( 95 | self, rgb_range, 96 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 97 | 98 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 99 | std = torch.Tensor(rgb_std) 100 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 101 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 102 | for p in self.parameters(): 103 | p.requires_grad = False 104 | 105 | class BasicBlock(nn.Sequential): 106 | def __init__( 107 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 108 | bn=True, act=nn.ReLU(True)): 109 | 110 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 111 | if bn: 112 | m.append(nn.BatchNorm2d(out_channels)) 113 | if act is not None: 114 | m.append(act) 115 | 116 | super(BasicBlock, self).__init__(*m) 117 | 118 | class ResBlock(nn.Module): 119 | def __init__( 120 | self, conv, n_feats, kernel_size, 121 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 122 | 123 | super(ResBlock, self).__init__() 124 | m = [] 125 | for i in range(2): 126 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 127 | if bn: 128 | m.append(nn.BatchNorm2d(n_feats)) 129 | if i == 0: 130 | m.append(act) 131 | 132 | self.body = nn.Sequential(*m) 133 | self.res_scale = res_scale 134 | 135 | def forward(self, x): 136 | res = self.body(x).mul(self.res_scale) 137 | res += x 138 | 139 | return res 140 | 141 | class Upsampler(nn.Sequential): 142 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 143 | 144 | m = [] 145 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
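# Power-of-two scales are built from repeated x2 stages below: each conv expands channels 4x, then PixelShuffle(2) trades those channels for 2x spatial resolution (e.g. scale=4 -> two conv+PixelShuffle(2) stages).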
146 | for _ in range(int(math.log(scale, 2))): 147 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 148 | m.append(nn.PixelShuffle(2)) 149 | if bn: 150 | m.append(nn.BatchNorm2d(n_feats)) 151 | if act == 'relu': 152 | m.append(nn.ReLU(True)) 153 | elif act == 'prelu': 154 | m.append(nn.PReLU(n_feats)) 155 | 156 | elif scale == 3: 157 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 158 | m.append(nn.PixelShuffle(3)) 159 | if bn: 160 | m.append(nn.BatchNorm2d(n_feats)) 161 | if act == 'relu': 162 | m.append(nn.ReLU(True)) 163 | elif act == 'prelu': 164 | m.append(nn.PReLU(n_feats)) 165 | else: 166 | raise NotImplementedError 167 | 168 | super(Upsampler, self).__init__(*m) 169 | 170 | -------------------------------------------------------------------------------- /Deraining/basicsr/models/archs/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .do_conv import DOConv2d 4 | 5 | class BasicConv(nn.Module): 6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False): 7 | super(BasicConv, self).__init__() 8 | if bias and norm: 9 | bias = False 10 | 11 | padding = kernel_size // 2 12 | layers = list() 13 | if transpose: 14 | padding = kernel_size // 2 -1 15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 16 | else: 17 | layers.append( 18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 19 | if norm: 20 | layers.append(nn.BatchNorm2d(out_channel)) 21 | if relu: 22 | layers.append(nn.ReLU(inplace=True)) 23 | self.main = nn.Sequential(*layers) 24 | 25 | def forward(self, x): 26 | return self.main(x) 27 | 28 | 29 | class ResBlock(nn.Module): 30 | def __init__(self, in_channel, out_channel): 31 | super(ResBlock, self).__init__() 32 | self.main = nn.Sequential( 33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 34 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.main(x) + x 39 | 40 | class BasicDOConv(nn.Module): 41 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False): 42 | super(BasicDOConv, self).__init__() 43 | if bias and norm: 44 | bias = False 45 | 46 | padding = kernel_size // 2 47 | layers = list() 48 | if transpose: 49 | padding = kernel_size // 2 -1 50 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 51 | else: 52 | layers.append( 53 | DOConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 54 | if norm: 55 | layers.append(nn.BatchNorm2d(out_channel)) 56 | if relu: 57 | layers.append(nn.ReLU(inplace=True)) 58 | self.main = nn.Sequential(*layers) 59 | 60 | def forward(self, x): 61 | return self.main(x) 62 | 63 | class ResFFTBlock(nn.Module): 64 | def __init__(self, in_channel, out_channel): 65 | super(ResFFTBlock, self).__init__() 66 | self.main = nn.Sequential( 67 | BasicDOConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 68 | BasicDOConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 69 | ) 70 | self.conv_1 = BasicDOConv(out_channel, out_channel, kernel_size=1, stride=1, relu=True) 71 | self.conv_2 = BasicDOConv(out_channel, out_channel, kernel_size=1, stride=1, relu=False) 72 | def forward(self, x): 73 | x_fft 
=torch.fft.fft2(x, dim=(-2, -1)) 74 | x_real = x_fft.real 75 | 76 | x_real = self.conv_1(x_real) 77 | x_real = self.conv_2(x_real) 78 | x_fft.real = x_real 79 | x_fft_res = torch.fft.ifft2(x_fft, dim=(-2, -1)) 80 | 81 | return self.main(x) + x + x_fft_res.real -------------------------------------------------------------------------------- /Deraining/basicsr/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | from .losses import (L1Loss, MSELoss, PSNRLoss, SSIMLoss) 8 | 9 | __all__ = [ 10 | 'L1Loss', 'MSELoss', 'PSNRLoss','SSIMLoss', 11 | ] 12 | -------------------------------------------------------------------------------- /Deraining/basicsr/models/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import functools 8 | from torch.nn import functional as F 9 | 10 | 11 | def reduce_loss(loss, reduction): 12 | """Reduce loss as specified. 13 | 14 | Args: 15 | loss (Tensor): Elementwise loss tensor. 16 | reduction (str): Options are 'none', 'mean' and 'sum'. 17 | 18 | Returns: 19 | Tensor: Reduced loss tensor. 20 | """ 21 | reduction_enum = F._Reduction.get_enum(reduction) 22 | # none: 0, elementwise_mean:1, sum: 2 23 | if reduction_enum == 0: 24 | return loss 25 | elif reduction_enum == 1: 26 | return loss.mean() 27 | else: 28 | return loss.sum() 29 | 30 | 31 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 32 | """Apply element-wise weight and reduce loss. 33 | 34 | Args: 35 | loss (Tensor): Element-wise loss. 36 | weight (Tensor): Element-wise weights. Default: None. 37 | reduction (str): Same as built-in losses of PyTorch. Options are 38 | 'none', 'mean' and 'sum'. Default: 'mean'. 39 | 40 | Returns: 41 | Tensor: Loss values. 42 | """ 43 | # if weight is specified, apply element-wise weight 44 | if weight is not None: 45 | assert weight.dim() == loss.dim() 46 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 47 | loss = loss * weight 48 | 49 | # if weight is not specified or reduction is sum, just reduce the loss 50 | if weight is None or reduction == 'sum': 51 | loss = reduce_loss(loss, reduction) 52 | # if reduction is mean, then compute mean over weight region 53 | elif reduction == 'mean': 54 | if weight.size(1) > 1: 55 | weight = weight.sum() 56 | else: 57 | weight = weight.sum() * loss.size(1) 58 | loss = loss.sum() / weight 59 | 60 | return loss 61 | 62 | 63 | def weighted_loss(loss_func): 64 | """Create a weighted version of a given loss function. 65 | 66 | To use this decorator, the loss function must have the signature like 67 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 68 | element-wise loss without any reduction. 
This decorator will add weight 69 | and reduction arguments to the function. The decorated function will have 70 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 71 | **kwargs)`. 72 | 73 | :Example: 74 | 75 | >>> import torch 76 | >>> @weighted_loss 77 | ... def l1_loss(pred, target): 78 | ...     return (pred - target).abs() 79 | 80 | >>> pred = torch.Tensor([0, 2, 3]) 81 | >>> target = torch.Tensor([1, 1, 1]) 82 | >>> weight = torch.Tensor([1, 0, 1]) 83 | 84 | >>> l1_loss(pred, target) 85 | tensor(1.3333) 86 | >>> l1_loss(pred, target, weight) 87 | tensor(1.5000) 88 | >>> l1_loss(pred, target, reduction='none') 89 | tensor([1., 1., 2.]) 90 | >>> l1_loss(pred, target, weight, reduction='sum') 91 | tensor(3.) 92 | """ 93 | 94 | @functools.wraps(loss_func) 95 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 96 | # get element-wise loss 97 | loss = loss_func(pred, target, **kwargs) 98 | loss = weight_reduce_loss(loss, weight, reduction) 99 | return loss 100 | 101 | return wrapper 102 | -------------------------------------------------------------------------------- /Deraining/basicsr/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import logging 8 | import torch 9 | from os import path as osp 10 | 11 | from basicsr.data import create_dataloader, create_dataset 12 | from basicsr.models import create_model 13 | from basicsr.train import parse_options 14 | from basicsr.utils import (get_env_info, get_root_logger, get_time_str, 15 | make_exp_dirs) 16 | from basicsr.utils.options import dict2str 17 | 18 | 19 | def main(): 20 | # parse options, set distributed setting, set random seed 21 | opt = parse_options(is_train=False) 22 | 23 | torch.backends.cudnn.benchmark = True 24 | # torch.backends.cudnn.deterministic = True 25 | 26 | # mkdir and initialize loggers 27 | make_exp_dirs(opt) 28 | log_file = osp.join(opt['path']['log'], 29 | f"test_{opt['name']}_{get_time_str()}.log") 30 | logger = get_root_logger( 31 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 32 | logger.info(get_env_info()) 33 | logger.info(dict2str(opt)) 34 | 35 | # create test dataset and dataloader 36 | test_loaders = [] 37 | for phase, dataset_opt in sorted(opt['datasets'].items()): 38 | test_set = create_dataset(dataset_opt) 39 | test_loader = create_dataloader( 40 | test_set, 41 | dataset_opt, 42 | num_gpu=opt['num_gpu'], 43 | dist=opt['dist'], 44 | sampler=None, 45 | seed=opt['manual_seed']) 46 | logger.info( 47 | f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 48 | test_loaders.append(test_loader) 49 | 50 | # create model 51 | model = create_model(opt) 52 | 53 | for test_loader in test_loaders: 54 | test_set_name = test_loader.dataset.opt['name'] 55 | logger.info(f'Testing {test_set_name}...') 56 | rgb2bgr = opt['val'].get('rgb2bgr', True) 57 | # whether to use uint8 images to compute metrics 58 | use_image = opt['val'].get('use_image', True) 59 | model.validation( 60 | test_loader, 61 | current_iter=opt['name'], 62 | tb_logger=None, 63 | save_img=opt['val']['save_img'], 64 | rgb2bgr=rgb2bgr,
use_image=use_image) 65 | print(model.cost) 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /Deraining/basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | from .file_client import FileClient 8 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding 9 | from .logger import (MessageLogger, get_env_info, get_root_logger, 10 | init_tb_logger, init_wandb_logger) 11 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, 12 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt) 13 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k) 14 | 15 | __all__ = [ 16 | # file_client.py 17 | 'FileClient', 18 | # img_util.py 19 | 'img2tensor', 20 | 'tensor2img', 21 | 'imfrombytes', 22 | 'imwrite', 23 | 'crop_border', 24 | # logger.py 25 | 'MessageLogger', 26 | 'init_tb_logger', 27 | 'init_wandb_logger', 28 | 'get_root_logger', 29 | 'get_env_info', 30 | # misc.py 31 | 'set_random_seed', 32 | 'get_time_str', 33 | 'mkdir_and_rename', 34 | 'make_exp_dirs', 35 | 'scandir', 36 | 'check_resume', 37 | 'sizeof_fmt', 38 | 'padding', 39 | 'create_lmdb_for_reds', 40 | 'create_lmdb_for_gopro', 41 | 'create_lmdb_for_rain13k', 42 | ] 43 | -------------------------------------------------------------------------------- /Deraining/basicsr/utils/create_lmdb.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import argparse, cv2, os  # cv2/os used by create_lmdb_for_SIDD below 8 | from os import path as osp 9 | import scipy.io as scio 10 | from basicsr.utils import scandir 11 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs 12 | from tqdm import tqdm 13 | def prepare_keys(folder_path, suffix='png'): 14 | """Prepare image path list and keys for a dataset folder. 15 | 16 | Args: 17 | folder_path (str): Folder path. 18 | 19 | Returns: 20 | list[str]: Image path list. 21 | list[str]: Key list.
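Note: ``suffix`` (default 'png') both filters which files are scanned and is stripped from each file name to form its LMDB key.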
22 | """ 23 | print('Reading image path list ...') 24 | img_path_list = sorted( 25 | list(scandir(folder_path, suffix=suffix, recursive=False))) 26 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] 27 | 28 | return img_path_list, keys 29 | 30 | def create_lmdb_for_reds(): 31 | folder_path = './datasets/REDS/val/sharp_300' 32 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb' 33 | img_path_list, keys = prepare_keys(folder_path, 'png') 34 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 35 | # 36 | folder_path = './datasets/REDS/val/blur_300' 37 | lmdb_path = './datasets/REDS/val/blur_300.lmdb' 38 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 39 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 40 | 41 | folder_path = './datasets/REDS/train/train_sharp' 42 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb' 43 | img_path_list, keys = prepare_keys(folder_path, 'png') 44 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 45 | 46 | folder_path = './datasets/REDS/train/train_blur_jpeg' 47 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb' 48 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 49 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 50 | 51 | 52 | def create_lmdb_for_gopro(): 53 | folder_path = './datasets/GoPro/train/blur_crops' 54 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb' 55 | 56 | img_path_list, keys = prepare_keys(folder_path, 'png') 57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 58 | 59 | folder_path = './datasets/GoPro/train/sharp_crops' 60 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb' 61 | 62 | img_path_list, keys = prepare_keys(folder_path, 'png') 63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 64 | 65 | folder_path = './datasets/GoPro/test/target' 66 | lmdb_path = './datasets/GoPro/test/target.lmdb' 67 | 68 | img_path_list, keys = prepare_keys(folder_path, 'png') 69 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 70 | 71 | folder_path = './datasets/GoPro/test/input' 72 | lmdb_path = './datasets/GoPro/test/input.lmdb' 73 | 74 | img_path_list, keys = prepare_keys(folder_path, 'png') 75 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 76 | 77 | def create_lmdb_for_rain13k(): 78 | folder_path = '/research/datasets/RAIN_SYN/RAIN13K/train/rainy_image' 79 | lmdb_path = '/research/datasets/RAIN_SYN/RAIN13K/train/rainy_image.lmdb' 80 | 81 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 83 | 84 | folder_path = '/research/datasets/RAIN_SYN/RAIN13K/train/ground_truth' 85 | lmdb_path = '/research/datasets/RAIN_SYN/RAIN13K/train/ground_truth.lmdb' 86 | 87 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 88 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 89 | 90 | def create_lmdb_for_SIDD(): 91 | folder_path = './datasets/SIDD/train/input_crops' 92 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb' 93 | 94 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 96 | 97 | folder_path = './datasets/SIDD/train/gt_crops' 98 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb' 99 | 100 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 101 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 102 | 103 | #for val 104 | folder_path = 
'./datasets/SIDD/val/input_crops' 105 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb' 106 | mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat' 107 | if not osp.exists(folder_path): 108 | os.makedirs(folder_path) 109 | assert osp.exists(mat_path) 110 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] 111 | N, B, H ,W, C = data.shape 112 | data = data.reshape(N*B, H, W, C) 113 | for i in tqdm(range(N*B)): 114 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 115 | img_path_list, keys = prepare_keys(folder_path, 'png') 116 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 117 | 118 | folder_path = './datasets/SIDD/val/gt_crops' 119 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' 120 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' 121 | if not osp.exists(folder_path): 122 | os.makedirs(folder_path) 123 | assert osp.exists(mat_path) 124 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] 125 | N, B, H ,W, C = data.shape 126 | data = data.reshape(N*B, H, W, C) 127 | for i in tqdm(range(N*B)): 128 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 129 | img_path_list, keys = prepare_keys(folder_path, 'png') 130 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 131 | -------------------------------------------------------------------------------- /Deraining/basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | 8 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 9 | import functools 10 | import os 11 | import subprocess 12 | import torch 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp 15 | 16 | 17 | def init_dist(launcher, backend='nccl', **kwargs): 18 | if mp.get_start_method(allow_none=True) is None: 19 | mp.set_start_method('spawn') 20 | if launcher == 'pytorch': 21 | _init_dist_pytorch(backend, **kwargs) 22 | elif launcher == 'slurm': 23 | _init_dist_slurm(backend, **kwargs) 24 | else: 25 | raise ValueError(f'Invalid launcher type: {launcher}') 26 | 27 | 28 | def _init_dist_pytorch(backend, **kwargs): 29 | rank = int(os.environ['RANK']) 30 | num_gpus = torch.cuda.device_count() 31 | torch.cuda.set_device(rank % num_gpus) 32 | dist.init_process_group(backend=backend, **kwargs) 33 | 34 | 35 | def _init_dist_slurm(backend, port=None): 36 | """Initialize slurm distributed training environment. 37 | 38 | If argument ``port`` is not specified, then the master port will be system 39 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 40 | environment variable, then a default port ``29500`` will be used. 41 | 42 | Args: 43 | backend (str): Backend of torch.distributed. 44 | port (int, optional): Master port. Defaults to None. 
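Note: the master address is resolved from the first hostname in ``SLURM_NODELIST`` (via ``scontrol show hostname``), and RANK / WORLD_SIZE are derived from ``SLURM_PROCID`` / ``SLURM_NTASKS``.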
45 | """ 46 | proc_id = int(os.environ['SLURM_PROCID']) 47 | ntasks = int(os.environ['SLURM_NTASKS']) 48 | node_list = os.environ['SLURM_NODELIST'] 49 | num_gpus = torch.cuda.device_count() 50 | torch.cuda.set_device(proc_id % num_gpus) 51 | addr = subprocess.getoutput( 52 | f'scontrol show hostname {node_list} | head -n1') 53 | # specify master port 54 | if port is not None: 55 | os.environ['MASTER_PORT'] = str(port) 56 | elif 'MASTER_PORT' in os.environ: 57 | pass # use MASTER_PORT in the environment variable 58 | else: 59 | # 29500 is torch.distributed default port 60 | os.environ['MASTER_PORT'] = '29500' 61 | os.environ['MASTER_ADDR'] = addr 62 | os.environ['WORLD_SIZE'] = str(ntasks) 63 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 64 | os.environ['RANK'] = str(proc_id) 65 | dist.init_process_group(backend=backend) 66 | 67 | 68 | def get_dist_info(): 69 | if dist.is_available(): 70 | initialized = dist.is_initialized() 71 | else: 72 | initialized = False 73 | if initialized: 74 | rank = dist.get_rank() 75 | world_size = dist.get_world_size() 76 | else: 77 | rank = 0 78 | world_size = 1 79 | return rank, world_size 80 | 81 | 82 | def master_only(func): 83 | 84 | @functools.wraps(func) 85 | def wrapper(*args, **kwargs): 86 | rank, _ = get_dist_info() 87 | if rank == 0: 88 | return func(*args, **kwargs) 89 | 90 | return wrapper 91 | -------------------------------------------------------------------------------- /Deraining/basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import math 8 | import requests 9 | from tqdm import tqdm 10 | 11 | from .misc import sizeof_fmt 12 | 13 | 14 | def download_file_from_google_drive(file_id, save_path): 15 | """Download files from google drive. 16 | 17 | Ref: 18 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 19 | 20 | Args: 21 | file_id (str): File id. 22 | save_path (str): Save path. 
23 | """ 24 | 25 | session = requests.Session() 26 | URL = 'https://docs.google.com/uc?export=download' 27 | params = {'id': file_id} 28 | 29 | response = session.get(URL, params=params, stream=True) 30 | token = get_confirm_token(response) 31 | if token: 32 | params['confirm'] = token 33 | response = session.get(URL, params=params, stream=True) 34 | 35 | # get file size 36 | response_file_size = session.get( 37 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 38 | if 'Content-Range' in response_file_size.headers: 39 | file_size = int( 40 | response_file_size.headers['Content-Range'].split('/')[1]) 41 | else: 42 | file_size = None 43 | 44 | save_response_content(response, save_path, file_size) 45 | 46 | 47 | def get_confirm_token(response): 48 | for key, value in response.cookies.items(): 49 | if key.startswith('download_warning'): 50 | return value 51 | return None 52 | 53 | 54 | def save_response_content(response, 55 | destination, 56 | file_size=None, 57 | chunk_size=32768): 58 | if file_size is not None: 59 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 60 | 61 | readable_file_size = sizeof_fmt(file_size) 62 | else: 63 | pbar = None 64 | 65 | with open(destination, 'wb') as f: 66 | downloaded_size = 0 67 | for chunk in response.iter_content(chunk_size): 68 | downloaded_size += chunk_size 69 | if pbar is not None: 70 | pbar.update(1) 71 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' 72 | f'/ {readable_file_size}') 73 | if chunk: # filter out keep-alive new chunks 74 | f.write(chunk) 75 | if pbar is not None: 76 | pbar.close() 77 | -------------------------------------------------------------------------------- /Deraining/basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import yaml 8 | from collections import OrderedDict 9 | from os import path as osp 10 | 11 | 12 | def ordered_yaml(): 13 | """Support OrderedDict for yaml. 14 | 15 | Returns: 16 | yaml Loader and Dumper. 17 | """ 18 | try: 19 | from yaml import CDumper as Dumper 20 | from yaml import CLoader as Loader 21 | except ImportError: 22 | from yaml import Dumper, Loader 23 | 24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 25 | 26 | def dict_representer(dumper, data): 27 | return dumper.represent_dict(data.items()) 28 | 29 | def dict_constructor(loader, node): 30 | return OrderedDict(loader.construct_pairs(node)) 31 | 32 | Dumper.add_representer(OrderedDict, dict_representer) 33 | Loader.add_constructor(_mapping_tag, dict_constructor) 34 | return Loader, Dumper 35 | 36 | 37 | def parse(opt_path, is_train=True): 38 | """Parse option file. 39 | 40 | Args: 41 | opt_path (str): Option file path. 42 | is_train (str): Indicate whether in training or not. Default: True. 43 | 44 | Returns: 45 | (dict): Options. 
46 | """ 47 | with open(opt_path, mode='r') as f: 48 | Loader, _ = ordered_yaml() 49 | opt = yaml.load(f, Loader=Loader) 50 | 51 | opt['is_train'] = is_train 52 | 53 | # datasets 54 | if 'datasets' in opt: 55 | for phase, dataset in opt['datasets'].items(): 56 | # for several datasets, e.g., test_1, test_2 57 | phase = phase.split('_')[0] 58 | dataset['phase'] = phase 59 | if 'scale' in opt: 60 | dataset['scale'] = opt['scale'] 61 | if dataset.get('dataroot_gt') is not None: 62 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 63 | if dataset.get('dataroot_lq') is not None: 64 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 65 | 66 | # paths 67 | for key, val in opt['path'].items(): 68 | if (val is not None) and ('resume_state' in key 69 | or 'pretrain_network' in key): 70 | opt['path'][key] = osp.expanduser(val) 71 | opt['path']['root'] = osp.abspath( 72 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 73 | if is_train: 74 | experiments_root = osp.join(opt['path']['root'], 'experiments', 75 | opt['name']) 76 | opt['path']['experiments_root'] = experiments_root 77 | opt['path']['models'] = osp.join(experiments_root, 'models') 78 | opt['path']['training_states'] = osp.join(experiments_root, 79 | 'training_states') 80 | opt['path']['log'] = experiments_root 81 | opt['path']['visualization'] = osp.join(experiments_root, 82 | 'visualization') 83 | 84 | # change some options for debug mode 85 | if 'debug' in opt['name']: 86 | if 'val' in opt: 87 | opt['val']['val_freq'] = 8 88 | opt['logger']['print_freq'] = 1 89 | opt['logger']['save_checkpoint_freq'] = 8 90 | else: # test 91 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 92 | opt['path']['results_root'] = results_root 93 | opt['path']['log'] = results_root 94 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 95 | 96 | return opt 97 | 98 | 99 | def dict2str(opt, indent_level=1): 100 | """dict to string for printing options. 101 | 102 | Args: 103 | opt (dict): Option dict. 104 | indent_level (int): Indent level. Default: 1. 105 | 106 | Return: 107 | (str): Option string for printing. 
108 | """ 109 | msg = '\n' 110 | for k, v in opt.items(): 111 | if isinstance(v, dict): 112 | msg += ' ' * (indent_level * 2) + k + ':[' 113 | msg += dict2str(v, indent_level + 1) 114 | msg += ' ' * (indent_level * 2) + ']\n' 115 | else: 116 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 117 | return msg 118 | -------------------------------------------------------------------------------- /Deraining/basicsr/utils/show.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import torch.nn.functional as F 6 | # from PIL import Image, ImageOps 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import os 10 | 11 | from torchvision.utils import save_image 12 | import cv2 13 | import time 14 | 15 | def feature_show(x,name, iter, stage): 16 | x_out = x.detach().cpu() 17 | x_out = x_out[0] 18 | x_out = np.average(x_out, axis=0) 19 | plt.imshow(x_out) 20 | plt.axis('off') # plt.show() 之前,plt.imshow() 之后 21 | # plt.xticks([]) #plt.show() 之前,plt.imshow() 之后 22 | # plt.yticks([]) 23 | plt.savefig(os.path.join('./demo', '{}_iter_{}_stage_{}.jpg'.format(name,iter, stage))) 24 | -------------------------------------------------------------------------------- /Deraining/basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Tue Feb 27 15:50:54 2024 3 | __version__ = '1.2.0+unknown' 4 | short_version = '1.2.0' 5 | version_info = (1, 2, 0) 6 | -------------------------------------------------------------------------------- /Deraining/generate_lmdb.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.create_lmdb import create_lmdb_for_rain13k 2 | from basicsr.utils import scandir 3 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs 4 | def prepare_keys(folder_path, suffix='png'): 5 | """Prepare image path list and keys for DIV2K dataset. 6 | 7 | Args: 8 | folder_path (str): Folder path. 9 | 10 | Returns: 11 | list[str]: Image path list. 12 | list[str]: Key list. 
13 | """ 14 | print('Reading image path list ...') 15 | img_path_list = sorted( 16 | list(scandir(folder_path, suffix=suffix, recursive=False))) 17 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] 18 | 19 | return img_path_list, keys 20 | 21 | def create_lmdb_for_rainds(): 22 | folder_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAINDS_REAL/train/rainy_image' 23 | lmdb_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAINDS_REAL/train/rainy_image.lmdb' 24 | 25 | img_path_list, keys = prepare_keys(folder_path, 'png') 26 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 27 | 28 | folder_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAINDS_REAL/train/ground_truth' 29 | lmdb_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAINDS_REAL/train/ground_truth.lmdb' 30 | 31 | img_path_list, keys = prepare_keys(folder_path, 'png') 32 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 33 | 34 | def create_lmdb_for_rain200(): 35 | folder_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAIN200H/train/rainy_image' 36 | lmdb_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAIN200H/train/rainy_image.lmdb' 37 | 38 | img_path_list, keys = prepare_keys(folder_path, 'png') 39 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 40 | 41 | folder_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAIN200H/train/ground_truth' 42 | lmdb_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAIN200H/train/ground_truth.lmdb' 43 | 44 | img_path_list, keys = prepare_keys(folder_path, 'png') 45 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 46 | 47 | folder_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAIN200L/train/rainy_image' 48 | lmdb_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAIN200L/train/rainy_image.lmdb' 49 | 50 | img_path_list, keys = prepare_keys(folder_path, 'png') 51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 52 | 53 | folder_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAIN200L/train/ground_truth' 54 | lmdb_path = '/home/jieh/Dataset/DERAIN_DATASETS/RAIN200L/train/ground_truth.lmdb' 55 | 56 | img_path_list, keys = prepare_keys(folder_path, 'png') 57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 58 | 59 | 60 | def create_lmdb_for_rain13k(): 61 | folder_path = '/home/jieh/Dataset/RAIN_SYN/RAIN13K/train/rainy_image' 62 | lmdb_path = '/home/jieh/Dataset/RAIN_SYN/RAIN13K/train/rainy_image.lmdb' 63 | 64 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 65 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 66 | 67 | folder_path = '/home/jieh/Dataset/RAIN_SYN/RAIN13K/train/ground_truth' 68 | lmdb_path = '/home/jieh/Dataset/RAIN_SYN/RAIN13K/train/ground_truth.lmdb' 69 | 70 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 71 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 72 | 73 | def create_lmdb_for_snow100k(): 74 | folder_path = '/home/jieh/Dataset/DESNOW_DATASETS/SNOW100K/train/snow' 75 | lmdb_path = '/home/jieh/Dataset/DESNOW_DATASETS/SNOW100K/train/snow.lmdb' 76 | 77 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 78 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 79 | 80 | folder_path = '/home/jieh/Dataset/DESNOW_DATASETS/SNOW100K/train/gt' 81 | lmdb_path = '/home/jieh/Dataset/DESNOW_DATASETS/SNOW100K/train/gt.lmdb' 82 | 83 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 84 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 85 | 86 | def create_lmdb_for_KITTI_DATASET(): 87 | folder_path = 
'/home/jieh/Dataset/DESNOW_DATASETS/KITTI_DATASET/train/snow' 88 | lmdb_path = '/home/jieh/Dataset/DESNOW_DATASETS/KITTI_DATASET/train/snow.lmdb' 89 | 90 | img_path_list, keys = prepare_keys(folder_path, 'png') 91 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 92 | 93 | folder_path = '/home/jieh/Dataset/DESNOW_DATASETS/KITTI_DATASET/train/gt' 94 | lmdb_path = '/home/jieh/Dataset/DESNOW_DATASETS/KITTI_DATASET/train/gt.lmdb' 95 | 96 | img_path_list, keys = prepare_keys(folder_path, 'png') 97 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 98 | 99 | def create_lmdb_for_CSD(): 100 | folder_path = '/home/jieh/Dataset/DESNOW_DATASETS/CSD/train/snow' 101 | lmdb_path = '/home/jieh/Dataset/DESNOW_DATASETS/CSD/train/snow.lmdb' 102 | 103 | img_path_list, keys = prepare_keys(folder_path, 'tif') 104 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 105 | 106 | folder_path = '/home/jieh/Dataset/DESNOW_DATASETS/CSD/train/gt' 107 | lmdb_path = '/home/jieh/Dataset/DESNOW_DATASETS/CSD/train/gt.lmdb' 108 | 109 | img_path_list, keys = prepare_keys(folder_path, 'tif') 110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 111 | def create_lmdb_for_CITYSCAPE_DATASET(): 112 | folder_path = '/home/jieh/Dataset/DESNOW_DATASETS/CITYSCAPE_DATASET/train/snow' 113 | lmdb_path = '/home/jieh/Dataset/DESNOW_DATASETS/CITYSCAPE_DATASET/train/snow.lmdb' 114 | 115 | img_path_list, keys = prepare_keys(folder_path, 'png') 116 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 117 | 118 | folder_path = '/home/jieh/Dataset/DESNOW_DATASETS/CITYSCAPE_DATASET/train/gt' 119 | lmdb_path = '/home/jieh/Dataset/DESNOW_DATASETS/CITYSCAPE_DATASET/train/gt.lmdb' 120 | 121 | img_path_list, keys = prepare_keys(folder_path, 'png') 122 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 123 | # create_lmdb_for_snow100k() 124 | # create_lmdb_for_KITTI_DATASET() 125 | # create_lmdb_for_CSD() 126 | # create_lmdb_for_CITYSCAPE_DATASET() 127 | # create_lmdb_for_rainds() 128 | create_lmdb_for_rain200() -------------------------------------------------------------------------------- /Deraining/get_best.py: -------------------------------------------------------------------------------- 1 | log_path = './train_Rain13k-derain_nips_20211106_163626.log' 2 | count = 0 3 | line_num = 0 4 | best_line_num = 0 5 | best_psnr = 0 6 | psnr_dict = {1:0,2:0,3:0} 7 | best_psnr_dict = {1:0,2:0,3:0} 8 | with open(log_path, 'r') as f: 9 | lines = f.readlines() 10 | for line in lines: 11 | line_num += 1 12 | if line.find('Validation Rain13k, # psnr: ') > 0: 13 | index = line.find('Validation Rain13k, # psnr: ') 14 | off = len('Validation Rain13k, # psnr: ') 15 | psnr = line[index+off:] 16 | psnr = psnr[:-2] 17 | psnr = float(psnr) 18 | count += 1 19 | psnr_dict[count] = psnr 20 | if count % 3 == 0: 21 | count = 0 22 | if best_psnr < sum(psnr_dict.values())/3: 23 | best_psnr = sum(psnr_dict.values())/3 24 | best_psnr_dict[1] = psnr_dict[1] 25 | best_psnr_dict[2] = psnr_dict[2] 26 | best_psnr_dict[3] = psnr_dict[3] 27 | best_line_num = line_num 28 | print(best_psnr) 29 | print(best_psnr_dict) 30 | print(best_line_num) -------------------------------------------------------------------------------- /Deraining/options/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/Deraining/options/.DS_Store 
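A note on get_best.py above: it averages every consecutive run of three 'Validation Rain13k, # psnr:' log lines and keeps the highest mean, but the fixed-offset string slicing ties it to that exact phrase. Below is a regex-based sketch of the same scan (a hypothetical helper, assuming the same log format and `group_size` validation sets per evaluation round):

```python
import re

# Matches lines like "... Validation Rain13k, # psnr: 31.4159" and captures the value.
PSNR_RE = re.compile(r"Validation \S+, # psnr: ([0-9]+(?:\.[0-9]+)?)")

def best_avg_psnr(log_path, group_size=3):
    """Return (best mean PSNR, its group of values, log line number)."""
    best_avg, best_group, best_line = 0.0, None, -1
    group = []
    with open(log_path) as f:
        for line_num, line in enumerate(f, start=1):
            match = PSNR_RE.search(line)
            if match is None:
                continue
            group.append(float(match.group(1)))
            if len(group) == group_size:  # one full validation round collected
                avg = sum(group) / group_size
                if avg > best_avg:
                    best_avg, best_group, best_line = avg, list(group), line_num
                group.clear()  # start the next round
    return best_avg, best_group, best_line
```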
-------------------------------------------------------------------------------- /Deraining/options/train/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/Deraining/options/train/.DS_Store -------------------------------------------------------------------------------- /Deraining/options/train/RAIN200H/fourmer.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: RAIN200H-Fourmer 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | train: 17 | name: RAIN200H 18 | type: PairedImageDataset 19 | dataroot_gt: /RAIN200H/train/ground_truth 20 | dataroot_lq: /RAIN200H/train/rainy_image 21 | 22 | filename_tmpl: '{}' 23 | io_backend: 24 | type: disk 25 | 26 | gt_size: 128 27 | use_flip: true 28 | use_rot: true 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 0 33 | batch_size_per_gpu: 12 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: RAIN200H 39 | type: PairedImageDataset 40 | dataroot_gt: ~ 41 | dataroot_lq: ~ 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: fourmer 48 | 49 | 50 | 51 | # path 52 | path: 53 | pretrain_network_g: ~ 54 | strict_load_g: true 55 | resume_state: ~ 56 | 57 | # training settings 58 | train: 59 | optim_g: 60 | type: Adam 61 | lr: !!float 1e-3 62 | weight_decay: 0 63 | betas: [0.9, 0.99] 64 | 65 | scheduler: 66 | type: TrueCosineAnnealingLR 67 | T_max: 400000 68 | eta_min: !!float 1e-7 69 | 70 | total_iter: 400000 71 | warmup_iter: 0 # no warm up 10000 72 | 73 | # losses 74 | pixel_opt: 75 | type: L1Loss 76 | loss_weight: 0.5 77 | reduction: mean 78 | # toY: true 79 | 80 | # validation settings 81 | val: 82 | # val_freq: 10 83 | # val_freq: !!float 2.5e4 84 | val_freq: !!float 2e3 85 | save_img: false 86 | grids: false 87 | crop_size: 256 88 | max_minibatch: 8 89 | 90 | metrics: 91 | psnr: # metric name, can be arbitrary 92 | type: calculate_psnr 93 | crop_border: 0 94 | test_y_channel: true 95 | 96 | # logging settings 97 | logger: 98 | print_freq: 200 99 | save_checkpoint_freq: !!float 2e4 100 | use_tb_logger: true 101 | wandb: 102 | project: ~ 103 | resume_id: ~ 104 | 105 | # dist training settings 106 | dist_params: 107 | backend: nccl 108 | port: 29500 109 | -------------------------------------------------------------------------------- /Deraining/options/train/RAIN200L/fourmer.yml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # general settings 8 | name: RAIN200L-Fourmer 9 | model_type: ImageRestorationModel 10 | scale: 1 11 | num_gpu: 1 12 | manual_seed: 10 13 | 14 | # dataset and data loader settings 15 | datasets: 16 | train: 17 | name: RAIN200L 18 | type: PairedImageDataset 19 | dataroot_gt: /RAIN200L/train/ground_truth 20 | dataroot_lq: /RAIN200L/train/rainy_image 21 | 22 | filename_tmpl: '{}' 23 | io_backend: 24 | type: disk 25 | 26 | gt_size: 128 27 | use_flip: true 28 | use_rot: true 29 | 30 | # data loader 31 | use_shuffle: true 32 | num_worker_per_gpu: 0 33 | batch_size_per_gpu: 12 34 | dataset_enlarge_ratio: 1 35 | prefetch_mode: ~ 36 | 37 | val: 38 | name: RAIN200L 39 | type: PairedImageDataset 40 | dataroot_gt: ~ 41 | dataroot_lq: ~ 42 | io_backend: 43 | type: disk 44 | 45 | # network structures 46 | network_g: 47 | type: fourmer 48 | 49 | 50 | 51 | # path 52 | path: 53 | pretrain_network_g: ~ 54 | strict_load_g: true 55 | resume_state: ~ 56 | 57 | # training settings 58 | train: 59 | optim_g: 60 | type: Adam 61 | lr: !!float 1e-3 62 | weight_decay: 0 63 | betas: [0.9, 0.99] 64 | 65 | scheduler: 66 | type: TrueCosineAnnealingLR 67 | T_max: 400000 68 | eta_min: !!float 1e-7 69 | 70 | total_iter: 400000 71 | warmup_iter: 0 # no warm up 10000 72 | 73 | # losses 74 | pixel_opt: 75 | type: L1Loss 76 | loss_weight: 0.5 77 | reduction: mean 78 | # toY: true 79 | 80 | # validation settings 81 | val: 82 | # val_freq: 10 83 | # val_freq: !!float 2.5e4 84 | val_freq: !!float 2e3 85 | save_img: false 86 | grids: false 87 | crop_size: 256 88 | max_minibatch: 8 89 | 90 | metrics: 91 | psnr: # metric name, can be arbitrary 92 | type: calculate_psnr 93 | crop_border: 0 94 | test_y_channel: true 95 | 96 | # logging settings 97 | logger: 98 | print_freq: 200 99 | save_checkpoint_freq: !!float 2e4 100 | use_tb_logger: true 101 | wandb: 102 | project: ~ 103 | resume_id: ~ 104 | 105 | # dist training settings 106 | dist_params: 107 | backend: nccl 108 | port: 29500 109 | -------------------------------------------------------------------------------- /Deraining/readme.md: -------------------------------------------------------------------------------- 1 | ## Applications 2 | ### Image deraining 3 | #### Prepare data 4 | Download the training data and fill in the data paths in the config file (options/train/RAIN200H(L)/fourmer.yml). 5 | #### Training 6 | ``` 7 | python basicsr/train.py -opt options/train/RAIN200H/fourmer.yml 8 | python basicsr/train.py -opt options/train/RAIN200L/fourmer.yml 9 | ``` 10 | 11 | -------------------------------------------------------------------------------- /Deraining/requirements.txt: -------------------------------------------------------------------------------- 1 | addict 2 | future 3 | lmdb 4 | numpy 5 | opencv-python 6 | Pillow 7 | pyyaml 8 | requests 9 | scikit-image 10 | scipy 11 | tb-nightly 12 | tqdm 13 | yapf 14 | -------------------------------------------------------------------------------- /Deraining/scripts/data_preparation/gopro.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import cv2 8 | import numpy as np 9 | import os 10 | import sys 11 | from multiprocessing import Pool 12 | from os import path as osp 13 | from tqdm import tqdm 14 | 15 | from basicsr.utils import scandir 16 | from basicsr.utils.create_lmdb import create_lmdb_for_gopro 17 | 18 | def main(): 19 | opt = {} 20 | opt['n_thread'] = 20 21 | opt['compression_level'] = 3 22 | 23 | opt['input_folder'] = './datasets/GoPro/train/input' 24 | opt['save_folder'] = './datasets/GoPro/train/blur_crops' 25 | opt['crop_size'] = 512 26 | opt['step'] = 256 27 | opt['thresh_size'] = 0 28 | extract_subimages(opt) 29 | 30 | opt['input_folder'] = './datasets/GoPro/train/target' 31 | opt['save_folder'] = './datasets/GoPro/train/sharp_crops' 32 | opt['crop_size'] = 512 33 | opt['step'] = 256 34 | opt['thresh_size'] = 0 35 | extract_subimages(opt) 36 | 37 | create_lmdb_for_gopro() 38 | 39 | 40 | def extract_subimages(opt): 41 | """Crop images to subimages. 42 | 43 | Args: 44 | opt (dict): Configuration dict. It contains: 45 | input_folder (str): Path to the input folder. 46 | save_folder (str): Path to save folder. 47 | n_thread (int): Thread number. 48 | """ 49 | input_folder = opt['input_folder'] 50 | save_folder = opt['save_folder'] 51 | if not osp.exists(save_folder): 52 | os.makedirs(save_folder) 53 | print(f'mkdir {save_folder} ...') 54 | else: 55 | print(f'Folder {save_folder} already exists. Exit.') 56 | sys.exit(1) 57 | 58 | img_list = list(scandir(input_folder, full_path=True)) 59 | 60 | pbar = tqdm(total=len(img_list), unit='image', desc='Extract') 61 | pool = Pool(opt['n_thread']) 62 | for path in img_list: 63 | pool.apply_async( 64 | worker, args=(path, opt), callback=lambda arg: pbar.update(1)) 65 | pool.close() 66 | pool.join() 67 | pbar.close() 68 | print('All processes done.') 69 | 70 | 71 | def worker(path, opt): 72 | """Worker for each process. 73 | 74 | Args: 75 | path (str): Image path. 76 | opt (dict): Configuration dict. It contains: 77 | crop_size (int): Crop size. 78 | step (int): Step for overlapped sliding window. 79 | thresh_size (int): Threshold size. Patches whose size is lower 80 | than thresh_size will be dropped. 81 | save_folder (str): Path to save folder. 82 | compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION. 83 | 84 | Returns: 85 | process_info (str): Process information displayed in progress bar. 
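Note: with the settings used in main() (crop_size=512, step=256), adjacent patches overlap by 256 pixels; any leftover border strip wider than thresh_size gets one extra edge-aligned crop.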
86 | """ 87 | crop_size = opt['crop_size'] 88 | step = opt['step'] 89 | thresh_size = opt['thresh_size'] 90 | img_name, extension = osp.splitext(osp.basename(path)) 91 | 92 | # remove the x2, x3, x4 and x8 in the filename for DIV2K 93 | img_name = img_name.replace('x2', 94 | '').replace('x3', 95 | '').replace('x4', 96 | '').replace('x8', '') 97 | 98 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 99 | 100 | if img.ndim == 2: 101 | h, w = img.shape 102 | elif img.ndim == 3: 103 | h, w, c = img.shape 104 | else: 105 | raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}') 106 | 107 | h_space = np.arange(0, h - crop_size + 1, step) 108 | if h - (h_space[-1] + crop_size) > thresh_size: 109 | h_space = np.append(h_space, h - crop_size) 110 | w_space = np.arange(0, w - crop_size + 1, step) 111 | if w - (w_space[-1] + crop_size) > thresh_size: 112 | w_space = np.append(w_space, w - crop_size) 113 | 114 | index = 0 115 | for x in h_space: 116 | for y in w_space: 117 | index += 1 118 | cropped_img = img[x:x + crop_size, y:y + crop_size, ...] 119 | cropped_img = np.ascontiguousarray(cropped_img) 120 | cv2.imwrite( 121 | osp.join(opt['save_folder'], 122 | f'{img_name}_s{index:03d}{extension}'), cropped_img, 123 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) 124 | process_info = f'Processing {img_name} ...' 125 | return process_info 126 | 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /Deraining/scripts/data_preparation/rain13k.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | from basicsr.utils.create_lmdb import create_lmdb_for_rain13k 8 | 9 | 10 | if __name__ == '__main__': 11 | create_lmdb_for_rain13k() -------------------------------------------------------------------------------- /Deraining/scripts/data_preparation/reds.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | ''' 8 | for val set, extract the subset val-300 9 | 10 | ''' 11 | import os 12 | import time 13 | from basicsr.utils.create_lmdb import create_lmdb_for_reds 14 | 15 | def make_val_300(folder, dst): 16 | if not os.path.exists(dst): 17 | os.mkdir(dst) 18 | templates = '*9.*' 19 | cp_command = 'cp {} {}'.format(os.path.join(folder, templates), dst) 20 | os.system(cp_command) 21 | 22 | 23 | def flatten_folders(folder): 24 | for vid in range(300): 25 | vidfolder_path = '{:03}'.format(vid) 26 | 27 | if not os.path.exists(os.path.join(folder, vidfolder_path)): 28 | continue 29 | 30 | print('working on .. {} .. 
{}'.format(folder, vid)) 31 | for fid in range(100): 32 | src_filename = '{:08}'.format(fid) 33 | 34 | suffixes = ['.jpg', '.png'] 35 | suffix = None 36 | 37 | for suf in suffixes: 38 | # print(os.path.join(folder, vidfolder_path, src_filename+suf)) 39 | if os.path.exists(os.path.join(folder, vidfolder_path, src_filename+suf)): 40 | suffix = suf 41 | break 42 | assert suffix is not None 43 | 44 | 45 | src_filepath = os.path.join(folder, vidfolder_path, src_filename+suffix) 46 | dst_filepath = os.path.join(folder, '{}_{}{}'.format(vidfolder_path, src_filename, suffix)) 47 | os.system('mv {} {}'.format(src_filepath, dst_filepath)) 48 | time.sleep(0.001) 49 | os.system('rm -r {}'.format(os.path.join(folder, vidfolder_path))) 50 | 51 | 52 | if __name__ == '__main__': 53 | flatten_folders('./datasets/REDS/train/train_blur_jpeg') 54 | flatten_folders('./datasets/REDS/train/train_sharp') 55 | 56 | flatten_folders('./datasets/REDS/val/val_blur_jpeg') 57 | flatten_folders('./datasets/REDS/val/val_sharp') 58 | make_val_300('./datasets/REDS/val/val_blur_jpeg', './datasets/REDS/val/blur_300') 59 | make_val_300('./datasets/REDS/val/val_sharp', './datasets/REDS/val/sharp_300') 60 | 61 | create_lmdb_for_reds() 62 | 63 | 64 | -------------------------------------------------------------------------------- /Deraining/scripts/data_preparation/sidd.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import cv2 8 | import numpy as np 9 | import os 10 | import sys 11 | from multiprocessing import Pool 12 | from os import path as osp 13 | from tqdm import tqdm 14 | 15 | from basicsr.utils import scandir_SIDD 16 | from basicsr.utils.create_lmdb import create_lmdb_for_SIDD 17 | 18 | def main(): 19 | opt = {} 20 | opt['n_thread'] = 20 21 | opt['compression_level'] = 3 22 | 23 | opt['input_folder'] = './datasets/SIDD/Data' 24 | opt['save_folder'] = './datasets/SIDD/train/input_crops' 25 | opt['crop_size'] = 512 26 | opt['step'] = 384 27 | opt['thresh_size'] = 0 28 | opt['keywords'] = '_NOISY' 29 | extract_subimages(opt) 30 | 31 | 32 | opt['save_folder'] = './datasets/SIDD/train/gt_crops' 33 | opt['keywords'] = '_GT' 34 | extract_subimages(opt) 35 | 36 | create_lmdb_for_SIDD() 37 | 38 | def extract_subimages(opt): 39 | """Crop images to subimages. 40 | 41 | Args: 42 | opt (dict): Configuration dict. It contains: 43 | input_folder (str): Path to the input folder. 44 | save_folder (str): Path to save folder. 45 | n_thread (int): Thread number. 46 | """ 47 | input_folder = opt['input_folder'] 48 | save_folder = opt['save_folder'] 49 | if not osp.exists(save_folder): 50 | os.makedirs(save_folder) 51 | print(f'mkdir {save_folder} ...') 52 | else: 53 | print(f'Folder {save_folder} already exists. 
Exit.') 54 | #sys.exit(1) 55 | 56 | img_list = list(scandir_SIDD(input_folder, keywords=opt['keywords'], recursive=True, full_path=True)) 57 | 58 | pbar = tqdm(total=len(img_list), unit='image', desc='Extract') 59 | pool = Pool(opt['n_thread']) 60 | for path in img_list: 61 | pool.apply_async( 62 | worker, args=(path, opt), callback=lambda arg: pbar.update(1)) 63 | pool.close() 64 | pool.join() 65 | pbar.close() 66 | print('All processes done.') 67 | 68 | 69 | def worker(path, opt): 70 | """Worker for each process. 71 | 72 | Args: 73 | path (str): Image path. 74 | opt (dict): Configuration dict. It contains: 75 | crop_size (int): Crop size. 76 | step (int): Step for overlapped sliding window. 77 | thresh_size (int): Threshold size. Patches whose size is lower 78 | than thresh_size will be dropped. 79 | save_folder (str): Path to save folder. 80 | compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION. 81 | 82 | Returns: 83 | process_info (str): Process information displayed in progress bar. 84 | """ 85 | crop_size = opt['crop_size'] 86 | step = opt['step'] 87 | thresh_size = opt['thresh_size'] 88 | img_name, extension = osp.splitext(osp.basename(path)) 89 | 90 | 91 | img_name = img_name.replace(opt['keywords'], '') 92 | 93 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 94 | 95 | if img.ndim == 2: 96 | h, w = img.shape 97 | elif img.ndim == 3: 98 | h, w, c = img.shape 99 | else: 100 | raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}') 101 | 102 | h_space = np.arange(0, h - crop_size + 1, step) 103 | if h - (h_space[-1] + crop_size) > thresh_size: 104 | h_space = np.append(h_space, h - crop_size) 105 | w_space = np.arange(0, w - crop_size + 1, step) 106 | if w - (w_space[-1] + crop_size) > thresh_size: 107 | w_space = np.append(w_space, w - crop_size) 108 | 109 | index = 0 110 | for x in h_space: 111 | for y in w_space: 112 | index += 1 113 | cropped_img = img[x:x + crop_size, y:y + crop_size, ...] 114 | cropped_img = np.ascontiguousarray(cropped_img) 115 | cv2.imwrite( 116 | osp.join(opt['save_folder'], 117 | f'{img_name}_s{index:03d}{extension}'), cropped_img, 118 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) 119 | process_info = f'Processing {img_name} ...' 120 | return process_info 121 | 122 | 123 | if __name__ == '__main__': 124 | main() 125 | #... make sidd to lmdb 126 | -------------------------------------------------------------------------------- /Deraining/scripts/download_gdrive.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import argparse 8 | 9 | from basicsr.utils.download_util import download_file_from_google_drive 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--id', type=str, help='File id') 15 | parser.add_argument('--output', type=str, help='Save path') 16 | args = parser.parse_args() 17 | 18 | download_file_from_google_drive(args.id, args.output) 19 | -------------------------------------------------------------------------------- /Deraining/scripts/download_pretrained_models.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import argparse 8 | import os 9 | from os import path as osp 10 | 11 | from basicsr.utils.download_util import download_file_from_google_drive 12 | 13 | 14 | def download_pretrained_models(method, file_ids): 15 | save_path_root = f'./experiments/pretrained_models/{method}' 16 | os.makedirs(save_path_root, exist_ok=True) 17 | 18 | for file_name, file_id in file_ids.items(): 19 | save_path = osp.abspath(osp.join(save_path_root, file_name)) 20 | if osp.exists(save_path): 21 | user_response = input( 22 | f'{file_name} already exists. Do you want to overwrite it? Y/N\n') 23 | if user_response.lower() == 'y': 24 | print(f'Overwriting {file_name} at {save_path}') 25 | download_file_from_google_drive(file_id, save_path) 26 | elif user_response.lower() == 'n': 27 | print(f'Skipping {file_name}') 28 | else: 29 | raise ValueError('Wrong input. Only accepts Y/N.') 30 | else: 31 | print(f'Downloading {file_name} to {save_path}') 32 | download_file_from_google_drive(file_id, save_path) 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | 38 | parser.add_argument( 39 | 'method', 40 | type=str, 41 | help=( 42 | "Options: 'ESRGAN', 'EDVR', 'StyleGAN', 'EDSR', 'DUF', 'DFDNet', " 43 | "'dlib'.
Set to 'all' if you want to download all the models.")) 44 | args = parser.parse_args() 45 | 46 | file_ids = { 47 | 'ESRGAN': { 48 | 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth': # file name 49 | '1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT', # file id 50 | 'ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth': 51 | '1swaV5iBMFfg-DL6ZyiARztbhutDCWXMM' 52 | }, 53 | 'EDVR': { 54 | 'EDVR_L_x4_SR_REDS_official-9f5f5039.pth': 55 | '127KXEjlCwfoPC1aXyDkluNwr9elwyHNb', 56 | 'EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth': 57 | '1aVR3lkX6ItCphNLcT7F5bbbC484h4Qqy', 58 | 'EDVR_M_woTSA_x4_SR_REDS_official-1edf645c.pth': 59 | '1C_WdN-NyNj-P7SOB5xIVuHl4EBOwd-Ny', 60 | 'EDVR_M_x4_SR_REDS_official-32075921.pth': 61 | '1dd6aFj-5w2v08VJTq5mS9OFsD-wALYD6', 62 | 'EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth': 63 | '1GZz_87ybR8eAAY3X2HWwI3L6ny7-5Yvl', 64 | 'EDVR_L_deblur_REDS_official-ca46bd8c.pth': 65 | '1_ma2tgHscZtkIY2tEJkVdU-UP8bnqBRE', 66 | 'EDVR_L_deblurcomp_REDS_official-0e988e5c.pth': 67 | '1fEoSeLFnHSBbIs95Au2W197p8e4ws4DW' 68 | }, 69 | 'StyleGAN': { 70 | 'stylegan2_ffhq_config_f_1024_official-b09c3668.pth': 71 | '163PfuVSYKh4vhkYkfEaufw84CiF4pvWG', 72 | 'stylegan2_ffhq_config_f_1024_discriminator_official-806ddc5e.pth': 73 | '1wyOdcJnMtAT_fEwXYJObee7hcLzI8usT', 74 | 'stylegan2_cat_config_f_256_official-b82c74e3.pth': 75 | '1dGUvw8FLch50FEDAgAa6st1AXGnjduc7', 76 | 'stylegan2_cat_config_f_256_discriminator_official-f6f5ed5c.pth': 77 | '19wuj7Ztg56QtwEs01-p_LjQeoz6G11kF', 78 | 'stylegan2_church_config_f_256_official-12725a53.pth': 79 | '1Rcpguh4t833wHlFrWz9UuqFcSYERyd2d', 80 | 'stylegan2_church_config_f_256_discriminator_official-feba65b0.pth': # noqa: E501 81 | '1ImOfFUOwKqDDKZCxxM4VUdPQCc-j85Z9', 82 | 'stylegan2_car_config_f_512_official-32c42d4e.pth': 83 | '1FviBGvzORv4T3w0c3m7BaIfLNeEd0dC8', 84 | 'stylegan2_car_config_f_512_discriminator_official-31f302ab.pth': 85 | '1hlZ7M2GrK6cDFd2FIYazPxOZXTUfudB3', 86 | 'stylegan2_horse_config_f_256_official-d3d97ebc.pth': 87 | '1LV4OR22tJN19HHfGk0e7dVqMhjD0APRm', 88 | 'stylegan2_horse_config_f_256_discriminator_official-efc5e50e.pth': 89 | '1T8xbI-Tz8EeSg3gCmQBNqGjLP5l3Mv84' 90 | }, 91 | 'EDSR': { 92 | 'EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth': 93 | '1mREMGVDymId3NzIc2u90sl_X4-pb4ZcV', 94 | 'EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth': 95 | '1EriqQqlIiRyPbrYGBbwr_FZzvb3iwqz5', 96 | 'EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth': 97 | '1bCK6cFYU01uJudLgUUe-jgx-tZ3ikOWn', 98 | 'EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth': 99 | '15257lZCRZ0V6F9LzTyZFYbbPrqNjKyMU', 100 | 'EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth': 101 | '18q_D434sLG_rAZeHGonAX8dkqjoyZ2su', 102 | 'EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth': 103 | '1GCi30YYCzgMCcgheGWGusP9aWKOAy5vl' 104 | }, 105 | 'DUF': { 106 | 'DUF_x2_16L_official-39537cb9.pth': 107 | '1e91cEZOlUUk35keK9EnuK0F54QegnUKo', 108 | 'DUF_x3_16L_official-34ce53ec.pth': 109 | '1XN6aQj20esM7i0hxTbfiZr_SL8i4PZ76', 110 | 'DUF_x4_16L_official-bf8f0cfa.pth': 111 | '1V_h9U1CZgLSHTv1ky2M3lvuH-hK5hw_J', 112 | 'DUF_x4_28L_official-cbada450.pth': 113 | '1M8w0AMBJW65MYYD-_8_be0cSH_SHhDQ4', 114 | 'DUF_x4_52L_official-483d2c78.pth': 115 | '1GcmEWNr7mjTygi-QCOVgQWOo5OCNbh_T' 116 | }, 117 | 'DFDNet': { 118 | 'DFDNet_dict_512-f79685f0.pth': 119 | '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79', 120 | 'DFDNet_official-d1fa5650.pth': 121 | '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe' 122 | }, 123 | 'dlib': { 124 | 'mmod_human_face_detector-4cb19393.dat': 125 | '1FUM-hcoxNzFCOpCWbAUStBBMiU4uIGIL', 126 | 'shape_predictor_5_face_landmarks-c4b1e980.dat': 127 | 
'1PNPSmFjmbuuUDd5Mg5LDxyk7tu7TQv2F', 128 | 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': 129 | '1IneH-O-gNkG0SQpNCplwxtOAtRCkG2ni' 130 | } 131 | } 132 | 133 | if args.method == 'all': 134 | for method in file_ids.keys(): 135 | download_pretrained_models(method, file_ids[method]) 136 | else: 137 | download_pretrained_models(args.method, file_ids[args.method]) 138 | -------------------------------------------------------------------------------- /Deraining/scripts/publish_models.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import glob 8 | import subprocess 9 | import torch 10 | from os import path as osp 11 | from torch.serialization import _is_zipfile, _open_file_like 12 | 13 | 14 | def update_sha(paths): 15 | print('# Update sha ...') 16 | for idx, path in enumerate(paths): 17 | print(f'{idx+1:03d}: Processing {path}') 18 | net = torch.load(path, map_location=torch.device('cpu')) 19 | basename = osp.basename(path) 20 | if 'params' not in net and 'params_ema' not in net: 21 | raise ValueError(f'Please check! Model {basename} does not ' 22 | f"have 'params'/'params_ema' key.") 23 | else: 24 | if '-' in basename: 25 | # check whether the sha is the latest 26 | old_sha = basename.split('-')[1].split('.')[0] 27 | new_sha = subprocess.check_output(['sha256sum', 28 | path]).decode()[:8] 29 | if old_sha != new_sha: 30 | final_file = path.split('-')[0] + f'-{new_sha}.pth' 31 | print(f'\tSave from {path} to {final_file}') 32 | subprocess.Popen(['mv', path, final_file]) 33 | else: 34 | sha = subprocess.check_output(['sha256sum', path]).decode()[:8] 35 | final_file = path.split('.pth')[0] + f'-{sha}.pth' 36 | print(f'\tSave from {path} to {final_file}') 37 | subprocess.Popen(['mv', path, final_file]) 38 | 39 | 40 | def convert_to_backward_compatible_models(paths): 41 | """Convert to backward compatible pth files. 42 | 43 | PyTorch 1.6 uses a updated version of torch.save. In order to be compatible 44 | with previous PyTorch version, save it with 45 | _use_new_zipfile_serialization=False. 46 | """ 47 | print('# Convert to backward compatible pth files ...') 48 | for idx, path in enumerate(paths): 49 | print(f'{idx+1:03d}: Processing {path}') 50 | flag_need_conversion = False 51 | with _open_file_like(path, 'rb') as opened_file: 52 | if _is_zipfile(opened_file): 53 | flag_need_conversion = True 54 | 55 | if flag_need_conversion: 56 | net = torch.load(path, map_location=torch.device('cpu')) 57 | print('\tConverting to compatible pth file...') 58 | torch.save(net, path, _use_new_zipfile_serialization=False) 59 | 60 | 61 | if __name__ == '__main__': 62 | paths = glob.glob('experiments/pretrained_models/*.pth') + glob.glob( 63 | 'experiments/pretrained_models/**/*.pth') 64 | convert_to_backward_compatible_models(paths) 65 | update_sha(paths) 66 | -------------------------------------------------------------------------------- /Deraining/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | #!/usr/bin/env python 8 | 9 | from setuptools import find_packages, setup 10 | 11 | import os 12 | import subprocess 13 | import sys 14 | import time 15 | import torch 16 | from torch.utils.cpp_extension import (BuildExtension, CppExtension, 17 | CUDAExtension) 18 | 19 | version_file = 'basicsr/version.py' 20 | 21 | 22 | def readme(): 23 | return '' 24 | # with open('README.md', encoding='utf-8') as f: 25 | # content = f.read() 26 | # return content 27 | 28 | 29 | def get_git_hash(): 30 | 31 | def _minimal_ext_cmd(cmd): 32 | # construct minimal environment 33 | env = {} 34 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 35 | v = os.environ.get(k) 36 | if v is not None: 37 | env[k] = v 38 | # LANGUAGE is used on win32 39 | env['LANGUAGE'] = 'C' 40 | env['LANG'] = 'C' 41 | env['LC_ALL'] = 'C' 42 | out = subprocess.Popen( 43 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 44 | return out 45 | 46 | try: 47 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 48 | sha = out.strip().decode('ascii') 49 | except OSError: 50 | sha = 'unknown' 51 | 52 | return sha 53 | 54 | 55 | def get_hash(): 56 | if os.path.exists('.git'): 57 | sha = get_git_hash()[:7] 58 | elif os.path.exists(version_file): 59 | try: 60 | from basicsr.version import __version__ 61 | sha = __version__.split('+')[-1] 62 | except ImportError: 63 | raise ImportError('Unable to get git version') 64 | else: 65 | sha = 'unknown' 66 | 67 | return sha 68 | 69 | 70 | def write_version_py(): 71 | content = """# GENERATED VERSION FILE 72 | # TIME: {} 73 | __version__ = '{}' 74 | short_version = '{}' 75 | version_info = ({}) 76 | """ 77 | sha = get_hash() 78 | with open('VERSION', 'r') as f: 79 | SHORT_VERSION = f.read().strip() 80 | VERSION_INFO = ', '.join( 81 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 82 | VERSION = SHORT_VERSION + '+' + sha 83 | 84 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, 85 | VERSION_INFO) 86 | with open(version_file, 'w') as f: 87 | f.write(version_file_str) 88 | 89 | 90 | def get_version(): 91 | with open(version_file, 'r') as f: 92 | exec(compile(f.read(), version_file, 'exec')) 93 | return locals()['__version__'] 94 | 95 | 96 | def make_cuda_ext(name, module, sources, sources_cuda=None): 97 | if sources_cuda is None: 98 | sources_cuda = [] 99 | define_macros = [] 100 | extra_compile_args = {'cxx': []} 101 | 102 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 103 | define_macros += [('WITH_CUDA', None)] 104 | extension = CUDAExtension 105 | extra_compile_args['nvcc'] = [ 106 | '-D__CUDA_NO_HALF_OPERATORS__', 107 | '-D__CUDA_NO_HALF_CONVERSIONS__', 108 | '-D__CUDA_NO_HALF2_OPERATORS__', 109 | ] 110 | sources += sources_cuda 111 | else: 112 | print(f'Compiling {name} without CUDA') 113 | extension = CppExtension 114 | 115 | return extension( 116 | name=f'{module}.{name}', 117 | sources=[os.path.join(*module.split('.'), p) for p in sources], 118 | define_macros=define_macros, 119 | extra_compile_args=extra_compile_args) 120 | 121 | 122 | def get_requirements(filename='requirements.txt'): 123 | return [] 124 | here = os.path.dirname(os.path.realpath(__file__)) 125 | with open(os.path.join(here, filename), 'r') as f: 126 | requires = [line.replace('\n', '') for line in 
f.readlines()] 127 | return requires 128 | 129 | 130 | if __name__ == '__main__': 131 | if '--no_cuda_ext' in sys.argv: 132 | ext_modules = [] 133 | sys.argv.remove('--no_cuda_ext') 134 | else: 135 | ext_modules = [ 136 | make_cuda_ext( 137 | name='deform_conv_ext', 138 | module='basicsr.models.ops.dcn', 139 | sources=['src/deform_conv_ext.cpp'], 140 | sources_cuda=[ 141 | 'src/deform_conv_cuda.cpp', 142 | 'src/deform_conv_cuda_kernel.cu' 143 | ]), 144 | make_cuda_ext( 145 | name='fused_act_ext', 146 | module='basicsr.models.ops.fused_act', 147 | sources=['src/fused_bias_act.cpp'], 148 | sources_cuda=['src/fused_bias_act_kernel.cu']), 149 | make_cuda_ext( 150 | name='upfirdn2d_ext', 151 | module='basicsr.models.ops.upfirdn2d', 152 | sources=['src/upfirdn2d.cpp'], 153 | sources_cuda=['src/upfirdn2d_kernel.cu']), 154 | ] 155 | 156 | write_version_py() 157 | setup( 158 | name='basicsr', 159 | version=get_version(), 160 | description='Open Source Image and Video Super-Resolution Toolbox', 161 | long_description=readme(), 162 | author='Xintao Wang', 163 | author_email='xintao.wang@outlook.com', 164 | keywords='computer vision, restoration, super resolution', 165 | url='https://github.com/xinntao/BasicSR', 166 | packages=find_packages( 167 | exclude=('options', 'datasets', 'experiments', 'results', 168 | 'tb_logger', 'wandb')), 169 | classifiers=[ 170 | 'Development Status :: 4 - Beta', 171 | 'License :: OSI Approved :: Apache Software License', 172 | 'Operating System :: OS Independent', 173 | 'Programming Language :: Python :: 3', 174 | 'Programming Language :: Python :: 3.7', 175 | 'Programming Language :: Python :: 3.8', 176 | ], 177 | license='Apache License 2.0', 178 | setup_requires=['cython', 'numpy'], 179 | install_requires=get_requirements(), 180 | ext_modules=ext_modules, 181 | cmdclass={'build_ext': BuildExtension}, 182 | zip_safe=False) 183 | -------------------------------------------------------------------------------- /Deraining/show.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import torch.nn.functional as F 6 | # from PIL import Image, ImageOps 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import os 10 | 11 | from torchvision.utils import save_image 12 | import cv2 13 | import time 14 | 15 | # Visualize feature maps. NOTE: this file is a debugging fragment; AVW and the shape variables b, c, h are expected to come from the model code this snippet was extracted from, and the helper functions defined below must be in scope before this block runs. 16 | # print(AVW.shape) #[1, 12, 16384] 17 | AV_show = AVW.detach().cpu() # .transpose(1, 2).contiguous() 18 | # print(AV_show.shape) #[1, 12, 16384] 19 | show_AV = AV_show.view(b, c, h, -1) 20 | # print(show_AV.shape) #[1, 12, 128, 128] 21 | viz(show_AV) 22 | 23 | 24 | # save_features(show_AV) 25 | # save_features_pcolor(show_AV) 26 | # print(AVW.shape) 27 | 28 | 29 | def save_features(feature_map): 30 | for i in range(feature_map.size(1)): 31 | save_image(feature_map[0][i], os.path.join('./feature_maps', 'image_{}.jpg'.format(i)), nrow=1, padding=0) 32 | 33 | 34 | def save_features_pcolor(feature_map): 35 | print(feature_map.shape) 36 | length = feature_map.shape[1] 37 | for i in range(length): 38 | feature = np.asanyarray(feature_map[0][i] * 255, dtype=np.uint8) 39 | features_pcolor = cv2.applyColorMap(feature, cv2.COLORMAP_JET) 40 | cv2.imwrite(os.path.join('./feature_maps', 'image_{}.jpg'.format(i)), features_pcolor) 41 | 42 | 43 | def viz(input): 44 | x = input[0] 45 | print(x.shape) 46 | min_num = np.minimum(16, x.size()[0]) 47 | for i in range(min_num): 48 | # plt.subplot(2, 8, i+1)
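# Each selected channel is drawn as a 2D map; matplotlib auto-scales the
# values, so the plot conveys the relative activation pattern of that channel.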
49 | plt.imshow(x[i]) 50 | 51 | plt.axis('off')  # call after plt.imshow() and before plt.show() 52 | # plt.xticks([])  # likewise: after plt.imshow(), before plt.show() 53 | # plt.yticks([]) 54 | 55 | plt.savefig(os.path.join('./feature_maps', 'image_{}.jpg'.format(time.time()))) 56 | # plt.show() 57 | -------------------------------------------------------------------------------- /Deraining/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | Pha = torch.zeros((1,2,6,6)) 3 | Pha = torch.tile(Pha, (2, 2)) 4 | print(Pha.shape) -------------------------------------------------------------------------------- /LLIE/create_txt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | def mkdir(path): 6 | if not os.path.exists(path): 7 | os.mkdir(path) 8 | 9 | 10 | def main(): 11 | assert os.path.exists(inputdir), 'Input dir not found' 12 | assert os.path.exists(targetdir), 'Target dir not found' 13 | mkdir(outputdir) 14 | imgs = os.listdir(inputdir) 15 | for img in imgs: 16 | groups = '' 17 | 18 | groups += os.path.join(inputdir, img) + '|' 19 | groups += os.path.join(targetdir,img) 20 | 21 | with open(os.path.join(outputdir, 'groups_test_lowReFive.txt'), 'a') as f: 22 | f.write(groups + '\n') 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--input', type=str, default='/gdata/huangjie/Continous/ExpFive/test/Low', metavar='PATH', help='root dir of low resolution input images') 27 | parser.add_argument('--target', type=str, default='/gdata/huangjie/Continous/ExpFive/test/Retouch', metavar='PATH', help='root dir of high resolution target images') 28 | parser.add_argument('--output', type=str, default='/ghome/huangjie/Continous/Baseline/', metavar='PATH', help='output dir to save group txt files') 29 | parser.add_argument('--ext', type=str, default='.png', help='Extension of files') 30 | args = parser.parse_args() 31 | 32 | inputdir = args.input 33 | targetdir = args.target 34 | outputdir = args.output 35 | ext = args.ext 36 | 37 | main() 38 | -------------------------------------------------------------------------------- /LLIE/data/__init__.py: -------------------------------------------------------------------------------- 1 | """create dataset and dataloader""" 2 | import logging 3 | import torch 4 | import torch.utils.data 5 | 6 | 7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 8 | phase = dataset_opt['phase'] 9 | if phase == 'train': 10 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])  # computed but unused: the loader below is created with num_workers=0 11 | batch_size = dataset_opt['batch_size'] 12 | # shuffle = True 13 | shuffle = dataset_opt['use_shuffle'] 14 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 15 | num_workers=0, sampler=sampler, 16 | pin_memory=True) 17 | else: 18 | batch_size = dataset_opt['batch_size'] 19 | # shuffle = dataset_opt['use_shuffle'] 20 | shuffle = False 21 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0, 22 | pin_memory=False) 23 | 24 | 25 | def create_dataset(opt,dataset_opt): 26 | mode = dataset_opt['mode'] 27 | # datasets for image restoration 28 | if mode == 'UEN_train': 29 | from data.SIEN_dataset import DatasetFromFolder as D 30 | 31 | dataset = D(upscale_factor=opt['scale'], data_augmentation=dataset_opt['augment'], 32 | group_file=dataset_opt['filelist'], 33 | patch_size=dataset_opt['IN_size'], black_edges_crop=False, hflip=True, rot=True) 34 | 35 | elif
mode == 'UEN_val': 36 | from data.SIEN_dataset import DatasetFromFolder as D 37 | dataset = D(upscale_factor=opt['scale'], data_augmentation=False, 38 | group_file=dataset_opt['filelist'], 39 | patch_size=None, black_edges_crop=False, hflip=False, rot=False) 40 | 41 | else: 42 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) 43 | 44 | 45 | logger = logging.getLogger('base') 46 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 47 | dataset_opt['name'])) 48 | return dataset 49 | -------------------------------------------------------------------------------- /LLIE/data/groups_test_LOL.txt: -------------------------------------------------------------------------------- 1 | /home/jieh/Dataset/Continous/LOL/test/low/1.png|/home/jieh/Dataset/Continous/LOL/test/high/1.png 2 | /home/jieh/Dataset/Continous/LOL/test/low/111.png|/home/jieh/Dataset/Continous/LOL/test/high/111.png 3 | /home/jieh/Dataset/Continous/LOL/test/low/146.png|/home/jieh/Dataset/Continous/LOL/test/high/146.png 4 | /home/jieh/Dataset/Continous/LOL/test/low/179.png|/home/jieh/Dataset/Continous/LOL/test/high/179.png 5 | /home/jieh/Dataset/Continous/LOL/test/low/22.png|/home/jieh/Dataset/Continous/LOL/test/high/22.png 6 | /home/jieh/Dataset/Continous/LOL/test/low/23.png|/home/jieh/Dataset/Continous/LOL/test/high/23.png 7 | /home/jieh/Dataset/Continous/LOL/test/low/493.png|/home/jieh/Dataset/Continous/LOL/test/high/493.png 8 | /home/jieh/Dataset/Continous/LOL/test/low/547.png|/home/jieh/Dataset/Continous/LOL/test/high/547.png 9 | /home/jieh/Dataset/Continous/LOL/test/low/55.png|/home/jieh/Dataset/Continous/LOL/test/high/55.png 10 | /home/jieh/Dataset/Continous/LOL/test/low/665.png|/home/jieh/Dataset/Continous/LOL/test/high/665.png 11 | /home/jieh/Dataset/Continous/LOL/test/low/669.png|/home/jieh/Dataset/Continous/LOL/test/high/669.png 12 | /home/jieh/Dataset/Continous/LOL/test/low/748.png|/home/jieh/Dataset/Continous/LOL/test/high/748.png 13 | /home/jieh/Dataset/Continous/LOL/test/low/778.png|/home/jieh/Dataset/Continous/LOL/test/high/778.png 14 | /home/jieh/Dataset/Continous/LOL/test/low/780.png|/home/jieh/Dataset/Continous/LOL/test/high/780.png 15 | /home/jieh/Dataset/Continous/LOL/test/low/79.png|/home/jieh/Dataset/Continous/LOL/test/high/79.png 16 | -------------------------------------------------------------------------------- /LLIE/data/shuffle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | out = open("/ghome/huangjie/Continous/Baseline/groups_train_mixReFive.txt",'w') 4 | lines=[] 5 | with open("/ghome/huangjie/Continous/Baseline/mix.txt", 'r') as infile: 6 | for line in infile: 7 | lines.append(line) 8 | random.shuffle(lines) 9 | for line in lines: 10 | out.write(line) 11 | 12 | infile.close() 13 | out.close() 14 | 15 | -------------------------------------------------------------------------------- /LLIE/eval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Test Vid4 (SR) and REDS4 (SR-clean, SR-blur, deblur-clean, deblur-compression) datasets 3 | ''' 4 | import os 5 | import os.path as osp 6 | import glob 7 | import logging 8 | import numpy as np 9 | import cv2 10 | import torch 11 | import argparse 12 | import utils.util as util 13 | import data.util as data_util 14 | import models.archs.EnhanceN_arch as EnhanceN_arch 15 | 16 | def dataload_test(img_path): 17 | img_numpy = cv2.imread(img_path,cv2.IMREAD_UNCHANGED).astype(np.float32)/255.0 
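# The steps below turn OpenCV's HxWxC BGR image into the 1xCxHxW RGB float
# tensor the network expects: reverse the channel order, transpose to
# channel-first, then add a batch dimension with unsqueeze(0).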
18 | # img_numpy = cv2.resize(img_numpy,(512,512)) 19 | img_numpy = img_numpy[:, :, [2, 1, 0]] 20 | img_numpy = torch.from_numpy(np.ascontiguousarray(np.transpose(img_numpy, (2, 0, 1)))).float() 21 | img_numpy = img_numpy.unsqueeze(0) 22 | return img_numpy 23 | 24 | 25 | def forward_eval(model,img): 26 | with torch.no_grad(): 27 | model_output = model(img,torch.cuda.FloatTensor().resize_(1).zero_()+1, 28 | torch.cuda.FloatTensor().resize_(1).zero_()+1) 29 | # if isinstance(model_output, list) or isinstance(model_output, tuple): 30 | output = model_output 31 | # else: 32 | # output = model_output 33 | output = output.data.float().cpu() 34 | 35 | return output 36 | 37 | 38 | def main(root, save_folder, imageLists, modelPath): 39 | ################# 40 | # configurations 41 | ################# 42 | device = torch.device('cuda') 43 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 44 | ############################################################################ 45 | #### model 46 | model_path = modelPath 47 | model = EnhanceN_arch.Net() 48 | 49 | #### dataset 50 | test_dataset_folder = root 51 | save_imgs = True 52 | util.mkdirs(save_folder) 53 | # util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) 54 | # logger = logging.getLogger('base') 55 | 56 | #### set up the models 57 | model.load_state_dict(torch.load(model_path), strict=True) 58 | print("load model successfully") 59 | model.eval() 60 | model = model.to(device) 61 | 62 | image_filenames = [line.rstrip() for line in open(os.path.join(imageLists))] 63 | 64 | # process each image 65 | j = 0 66 | for img_idx in range(len(image_filenames)): 67 | 68 | img_name = os.path.basename(image_filenames[img_idx]) 69 | img = dataload_test(image_filenames[img_idx]).to(device) 70 | folder = image_filenames[img_idx].split('/')[-2] 71 | util.mkdirs(os.path.join(save_folder,folder)) 72 | # if img_right.shape[3] < 1000: 73 | # continue 74 | # for tx in range(2): 75 | # img_left = img_left_ori[:, :, :, max(0, 736 * tx - 32):min(736 * (tx + 1) + 32, img_left_ori.shape[3])] 76 | # img_right = img_right_ori[:, :, :, max(0, 736 * tx - 32):min(736 * (tx + 1) + 32, img_right_ori.shape[3])] 77 | 78 | output = forward_eval(model,img) 79 | output = util.tensor2img(output.squeeze(0)) 80 | 81 | if save_imgs: 82 | 83 | cv2.imwrite(osp.join(save_folder,folder,img_name), output) 84 | 85 | 86 | j = j + 1 87 | print("process %d th image" % j) 88 | 89 | print('################ Finish Testing ################') 90 | 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('--root', type=str, default='/data/1760921465/NTIRE2021/SR/test_input/', metavar='PATH', help='validation dataset root dir') 96 | parser.add_argument('--imageLists', type=str, default='/code/UEN/data/srtest_input.txt', metavar='FILE', help='record video ids') 97 | parser.add_argument('--save_folder', type=str, default='/data/1760921465/NTIRE2021/SR/GFN_test', metavar='PATH', help='save results') 98 | parser.add_argument('--modelPath', type=str, default='/model/1760921465/NTIRE2021/SR/GFN.pth', help='Model path') 99 | 100 | args = parser.parse_args() 101 | 102 | root = args.root 103 | imageLists = args.imageLists 104 | save_folder = args.save_folder 105 | modelPath = args.modelPath 106 | 107 | main(root, save_folder, imageLists, modelPath) 108 | -------------------------------------------------------------------------------- /LLIE/eval_test.py: -------------------------------------------------------------------------------- 1 
| ''' 2 | Test Vid4 (SR) and REDS4 (SR-clean, SR-blur, deblur-clean, deblur-compression) datasets 3 | ''' 4 | import os 5 | import os.path as osp 6 | import glob 7 | import logging 8 | import numpy as np 9 | import cv2 10 | import torch 11 | import argparse 12 | 13 | import utils.util as util 14 | import data.util as data_util 15 | import models.archs.EDVR_arch as EDVR_arch 16 | 17 | 18 | def main(root, output, videoLists, modelPath): 19 | ################# 20 | # configurations 21 | ################# 22 | device = torch.device('cuda') 23 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 24 | data_mode = 'Vid4' # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp 25 | # Vid4: SR 26 | # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur); 27 | # blur (deblur-clean), blur_comp (deblur-compression). 28 | flip_test = False 29 | ############################################################################ 30 | #### model 31 | 32 | model_path = modelPath 33 | # model_path = '/output/experiments/EDVR_NEW/models/40_G.pth' 34 | N_in = 5 35 | predeblur, HR_in = False, False 36 | back_RBs = 10 37 | 38 | model = EDVR_arch.EDVR(nf=32, nframes=N_in, groups=8, front_RBs=5, back_RBs=back_RBs, predeblur=predeblur, 39 | HR_in=HR_in) 40 | 41 | #### dataset 42 | test_dataset_folder = root 43 | 44 | #### evaluation 45 | # temporal padding mode 46 | 47 | padding = 'new_info' 48 | save_imgs = True 49 | 50 | save_folder = output 51 | util.mkdirs(save_folder) 52 | util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) 53 | logger = logging.getLogger('base') 54 | 55 | #### log info 56 | logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) 57 | logger.info('Padding mode: {}'.format(padding)) 58 | logger.info('Model path: {}'.format(model_path)) 59 | logger.info('Save images: {}'.format(save_imgs)) 60 | logger.info('Flip test: {}'.format(flip_test)) 61 | 62 | #### set up the models 63 | model.load_state_dict(torch.load(model_path), strict=True) 64 | model.eval() 65 | model = model.to(device) 66 | 67 | subfolder_name_l = [] 68 | 69 | with open(videoLists, 'r') as f: 70 | while True: 71 | line = f.readline().strip() 72 | if line == '': 73 | break 74 | subfolder_name_l.append(osp.join(root, line)) 75 | 76 | subfolder_l = sorted(subfolder_name_l) 77 | subfolder_name_l = [] 78 | 79 | print(subfolder_l) 80 | # temp = input() 81 | 82 | # subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*'))) 83 | # for each subfolder 84 | i = 0 85 | for subfolder in subfolder_l: 86 | subfolder_name = osp.basename(subfolder) 87 | subfolder_name_l.append(subfolder_name) 88 | save_subfolder = osp.join(save_folder, subfolder_name) 89 | 90 | img_path_l = sorted(glob.glob(osp.join(subfolder, '*'))) 91 | max_idx = len(img_path_l) 92 | if save_imgs: 93 | util.mkdirs(save_subfolder) 94 | 95 | #### read LQ and GT images 96 | imgs_LQ = data_util.read_img_seq(subfolder) 97 | i = i + 1 98 | print("process %d th subfolder" % i) 99 | 100 | # process each image 101 | j = 0 102 | for img_idx, img_path in enumerate(img_path_l): 103 | img_name = osp.splitext(osp.basename(img_path))[0] 104 | select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) 105 | imgs_in = (imgs_LQ.index_select(0, torch.LongTensor(select_idx))[2]).unsqueeze(0).to(device) 106 | 107 | if flip_test: 108 | output = util.flipx4_forward(model, imgs_in) 109 | else: 110 | output = util.single_forward(model, imgs_in) 111 | output = util.tensor2img(output.squeeze(0)) 112 | 113 | if save_imgs: 114 | 
cv2.imwrite(osp.join(save_subfolder, '{}.png'.format(img_name)), output) 115 | 116 | j = j + 1 117 | print("process %d th image" % j) 118 | 119 | logger.info('################ Finish Testing ################') 120 | # logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder)) 121 | # logger.info('Padding mode: {}'.format(padding)) 122 | # logger.info('Model path: {}'.format(model_path)) 123 | # logger.info('Save images: {}'.format(save_imgs)) 124 | # logger.info('Flip test: {}'.format(flip_test)) 125 | 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--root', type=str, default='/tmp/data/answer/stage1_test_input', metavar='PATH', help='validation dataset root dir') 130 | parser.add_argument('--videoLists', type=str, default='/code/code_enhance/data/videolist_test1.txt', metavar='FILE', help='record video ids') 131 | parser.add_argument('--output', type=str, default='/tmp/data/answer/stage2_test_input/', metavar='PATH', help='save results') 132 | parser.add_argument('--modelPath', type=str, default='/tmp/data/model/stage1_model/experiments/EDVR_NEW/models/140000_G.pth', help='Model path') 133 | 134 | args = parser.parse_args() 135 | 136 | root = args.root 137 | videoLists = args.videoLists 138 | output = args.output 139 | modelPath = args.modelPath 140 | 141 | main(root, output, videoLists, modelPath) 142 | -------------------------------------------------------------------------------- /LLIE/metrics/calculate_PSNR_SSIM.py: -------------------------------------------------------------------------------- 1 | ''' 2 | calculate the PSNR and SSIM. 3 | same as MATLAB's results 4 | ''' 5 | import os 6 | import math 7 | import numpy as np 8 | import cv2 9 | import glob 10 | import torch 11 | 12 | 13 | def main(): 14 | # Configurations 15 | 16 | # GT - Ground-truth; 17 | # Gen: Generated / Restored / Recovered images 18 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5' 19 | folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5' 20 | 21 | crop_border = 4 22 | suffix = '' # suffix for Gen images 23 | test_Y = False # True: test Y channel only; False: test RGB channels 24 | 25 | PSNR_all = [] 26 | SSIM_all = [] 27 | img_list = sorted(glob.glob(folder_GT + '/*')) 28 | 29 | if test_Y: 30 | print('Testing Y channel.') 31 | else: 32 | print('Testing RGB channels.') 33 | 34 | for i, img_path in enumerate(img_list): 35 | base_name = os.path.splitext(os.path.basename(img_path))[0] 36 | im_GT = cv2.imread(img_path) / 255. 37 | im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255. 38 | 39 | if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space 40 | im_GT_in = bgr2ycbcr(im_GT) 41 | im_Gen_in = bgr2ycbcr(im_Gen) 42 | else: 43 | im_GT_in = im_GT 44 | im_Gen_in = im_Gen 45 | 46 | # crop borders 47 | if im_GT_in.ndim == 3: 48 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :] 49 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :] 50 | elif im_GT_in.ndim == 2: 51 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border] 52 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border] 53 | else: 54 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim)) 55 | 56 | # calculate PSNR and SSIM 57 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255) 58 | 59 | SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255) 60 | print('{:3d} - {:25}. 
\tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format( 61 | i + 1, base_name, PSNR, SSIM)) 62 | PSNR_all.append(PSNR) 63 | SSIM_all.append(SSIM) 64 | print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format( 65 | sum(PSNR_all) / len(PSNR_all), 66 | sum(SSIM_all) / len(SSIM_all))) 67 | 68 | 69 | def calculate_psnr(img1, img2): 70 | # img1 and img2 have range [0, 255] 71 | img1 = img1.astype(np.float64) 72 | img2 = img2.astype(np.float64) 73 | mse = np.mean((img1 - img2)**2) 74 | if mse == 0: 75 | return float('inf') 76 | return 20 * math.log10(255.0 / math.sqrt(mse)) 77 | 78 | 79 | def ssim(img1, img2): 80 | C1 = (0.01 * 255)**2 81 | C2 = (0.03 * 255)**2 82 | 83 | img1 = img1.astype(np.float64) 84 | img2 = img2.astype(np.float64) 85 | kernel = cv2.getGaussianKernel(11, 1.5) 86 | window = np.outer(kernel, kernel.transpose()) 87 | 88 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 89 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 90 | mu1_sq = mu1**2 91 | mu2_sq = mu2**2 92 | mu1_mu2 = mu1 * mu2 93 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 94 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 95 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 96 | 97 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 98 | (sigma1_sq + sigma2_sq + C2)) 99 | return ssim_map.mean() 100 | 101 | 102 | def calculate_ssim(img1, img2): 103 | '''calculate SSIM 104 | the same outputs as MATLAB's 105 | img1, img2: [0, 255] 106 | ''' 107 | if not img1.shape == img2.shape: 108 | raise ValueError('Input images must have the same dimensions.') 109 | if img1.ndim == 2: 110 | return ssim(img1, img2) 111 | elif img1.ndim == 3: 112 | if img1.shape[2] == 3: 113 | ssims = [] 114 | for i in range(3): 115 | ssims.append(ssim(img1[..., i], img2[..., i]))  # SSIM per channel, averaged below 116 | return np.array(ssims).mean() 117 | elif img1.shape[2] == 1: 118 | return ssim(np.squeeze(img1), np.squeeze(img2)) 119 | else: 120 | raise ValueError('Wrong input image dimensions.') 121 | 122 | 123 | def bgr2ycbcr(img, only_y=True): 124 | '''same as matlab rgb2ycbcr 125 | only_y: only return Y channel 126 | Input: 127 | uint8, [0, 255] 128 | float, [0, 1] 129 | ''' 130 | in_img_type = img.dtype 131 | img = img.astype(np.float32)  # work on a float copy so the caller's array is not modified in place 132 | if in_img_type != np.uint8: 133 | img *= 255. 134 | # convert 135 | if only_y: 136 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 137 | else: 138 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 139 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 140 | if in_img_type == np.uint8: 141 | rlt = rlt.round() 142 | else: 143 | rlt /= 255.
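# Mirror the input convention on the way out: uint8 inputs are rounded and
# returned as uint8, float inputs are scaled back to [0, 1] before the cast.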
144 | return rlt.astype(in_img_type) 145 | 146 | 147 | def psnr_np(enhanced, image_dslr): 148 | # target = np.array(image_dslr) 149 | # enhanced = np.array(enhanced) 150 | # enhanced = np.clip(enhanced, 0, 1) 151 | # 152 | # 153 | # squared_error = np.square(enhanced - target) 154 | # mse = np.mean(squared_error) 155 | # psnr = 10 * np.log10(1.0 / mse) 156 | squares = (enhanced-image_dslr).pow(2) 157 | squares = squares.view([squares.shape[0],-1]) 158 | psnr = torch.mean((-10/np.log(10))*torch.log(torch.mean(squares, dim=1))) 159 | 160 | return psnr 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /LLIE/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def create_model(opt): 6 | # image restoration 7 | model = opt['model'] 8 | if model == 'sr': 9 | from .SIEN_model import SIEN_Model as M 10 | else: 11 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) 12 | m = M(opt) 13 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 14 | return m 15 | 16 | -------------------------------------------------------------------------------- /LLIE/models/archs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__init__.py -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/EDVR_arch.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/EDVR_arch.cpython-35.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/EDVR_arch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/EDVR_arch.cpython-36.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/EnhanceN_arch.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/EnhanceN_arch.cpython-35.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/EnhanceN_arch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/EnhanceN_arch.cpython-38.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/EnhanceN_arch1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/EnhanceN_arch1.cpython-38.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/__init__.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/arch_util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/arch_util.cpython-35.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/arch_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/arch_util.cpython-36.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/arch_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/arch_util.cpython-38.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/discriminator_vgg_arch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/discriminator_vgg_arch.cpython-36.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/__pycache__/discriminator_vgg_arch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/models/archs/__pycache__/discriminator_vgg_arch.cpython-38.pyc -------------------------------------------------------------------------------- /LLIE/models/archs/arch_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | ## AAAI 2022 3 | """ 4 | 5 | # --- Imports --- # 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | 12 | class LayerNormFunction(torch.autograd.Function): 13 | 14 | @staticmethod 15 | def forward(ctx, x, weight, bias, eps): 16 | ctx.eps = eps 17 | N, C, H, W = x.size() 18 | mu = x.mean(1, keepdim=True) 19 | var = (x - mu).pow(2).mean(1, keepdim=True) 20 | y = (x - mu) / (var + eps).sqrt() 21 | ctx.save_for_backward(y, var, 
weight) 22 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) 23 | return y 24 | 25 | @staticmethod 26 | def backward(ctx, grad_output): 27 | eps = ctx.eps 28 | 29 | N, C, H, W = grad_output.size() 30 | y, var, weight = ctx.saved_variables 31 | g = grad_output * weight.view(1, C, 1, 1) 32 | mean_g = g.mean(dim=1, keepdim=True) 33 | 34 | mean_gy = (g * y).mean(dim=1, keepdim=True) 35 | gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) 36 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( 37 | dim=0), None 38 | 39 | class LayerNorm2d(nn.Module): 40 | def __init__(self,channels,eps=1e-6): 41 | super(LayerNorm2d, self).__init__() 42 | self.register_parameter('weight',nn.Parameter(torch.ones(channels))) 43 | self.register_parameter('bias',nn.Parameter(torch.zeros(channels))) 44 | self.eps = eps 45 | 46 | def forward(self,x): 47 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) 48 | 49 | 50 | 51 | class SimpleGate(nn.Module): 52 | def forward(self, x): 53 | x1, x2 = x.chunk(2, dim=1) 54 | return x1 * x2 55 | 56 | 57 | class NAFBlock(nn.Module): 58 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): 59 | super().__init__() 60 | dw_channel = c * DW_Expand 61 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, 62 | bias=True) 63 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, 64 | groups=dw_channel, 65 | bias=True) 66 | self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 67 | groups=1, bias=True) 68 | 69 | # Simplified Channel Attention 70 | self.sca = nn.Sequential( 71 | nn.AdaptiveAvgPool2d(1), 72 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, 73 | groups=1, bias=True), 74 | ) 75 | 76 | # SimpleGate 77 | self.sg = SimpleGate() 78 | 79 | ffn_channel = FFN_Expand * c 80 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, 81 | bias=True) 82 | self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, 83 | groups=1, bias=True) 84 | 85 | self.norm1 = LayerNorm2d(c) 86 | self.norm2 = LayerNorm2d(c) 87 | 88 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() 89 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else nn.Identity() 90 | 91 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 92 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 93 | 94 | def forward(self, inp): 95 | x = inp 96 | 97 | x = self.norm1(x) 98 | 99 | x = self.conv1(x) 100 | x = self.conv2(x) 101 | x = self.sg(x) 102 | x = x * self.sca(x) 103 | x = self.conv3(x) 104 | 105 | x = self.dropout1(x) 106 | 107 | y = inp + x * self.beta 108 | 109 | x = self.conv4(self.norm2(y)) 110 | x = self.sg(x) 111 | x = self.conv5(x) 112 | 113 | x = self.dropout2(x) 114 | 115 | return y + x * self.gamma 116 | 117 | -------------------------------------------------------------------------------- /LLIE/models/archs/discriminator_vgg_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class Discriminator_VGG_128(nn.Module): 7 | def __init__(self, in_nc, nf): 8 | super(Discriminator_VGG_128, self).__init__() 9 | # [64, 128, 128] 10 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 11 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 12 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 13 | # [64, 64, 64] 14 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 15 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 16 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 17 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 18 | # [128, 32, 32] 19 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 20 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 21 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 22 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 23 | # [256, 16, 16] 24 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 25 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 26 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 27 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 28 | # [512, 8, 8] 29 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 30 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 31 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 32 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 33 | 34 | self.linear1 = nn.Linear(512 * 4 * 4, 100) 35 | self.linear2 = nn.Linear(100, 1) 36 | 37 | # activation function 38 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 39 | 40 | def forward(self, x): 41 | fea = self.lrelu(self.conv0_0(x)) 42 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 43 | 44 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 45 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 46 | 47 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 48 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 49 | 50 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 51 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 52 | 53 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 54 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 55 | 56 | fea = fea.view(fea.size(0), -1) 57 | fea = self.lrelu(self.linear1(fea)) 58 | out = self.linear2(fea) 59 | return out 60 | 61 | 62 | class VGGFeatureExtractor(nn.Module): 63 | def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, 64 | device=torch.device('cpu')): 65 | super(VGGFeatureExtractor, self).__init__() 66 | self.use_input_norm = use_input_norm 67 | if use_bn: 68 | model = torchvision.models.vgg19_bn(pretrained=True) 69 | else: 70 | model = torchvision.models.vgg19(pretrained=True) 71 | if 
self.use_input_norm: 72 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 73 | # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] 74 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 75 | # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] 76 | self.register_buffer('mean', mean) 77 | self.register_buffer('std', std) 78 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) 79 | # No need to BP to variable 80 | for k, v in self.features.named_parameters(): 81 | v.requires_grad = False 82 | 83 | def forward(self, x): 84 | # Assume input range is [0, 1] 85 | if self.use_input_norm: 86 | x = (x - self.mean) / self.std 87 | output = self.features(x) 88 | return output 89 | -------------------------------------------------------------------------------- /LLIE/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | 8 | class BaseModel(): 9 | def __init__(self, opt): 10 | self.opt = opt 11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') 12 | self.is_train = opt['is_train'] 13 | self.schedulers = [] 14 | self.optimizers = [] 15 | 16 | def feed_data(self, data): 17 | pass 18 | 19 | def optimize_parameters(self): 20 | pass 21 | 22 | def get_current_visuals(self): 23 | pass 24 | 25 | def get_current_losses(self): 26 | pass 27 | 28 | def print_network(self): 29 | pass 30 | 31 | def save(self, label): 32 | pass 33 | 34 | def load(self): 35 | pass 36 | 37 | def _set_lr(self, lr_groups_l): 38 | """Set learning rate for warmup 39 | lr_groups_l: list for lr_groups. 
One for each optimizer.""" 40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 41 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 42 | param_group['lr'] = lr 43 | 44 | def _get_init_lr(self): 45 | """Get the initial lr, which is set by the scheduler""" 46 | init_lr_groups_l = [] 47 | for optimizer in self.optimizers: 48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 49 | return init_lr_groups_l 50 | 51 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 52 | for scheduler in self.schedulers: 53 | scheduler.step() 54 | # set up warm-up learning rate 55 | if cur_iter < warmup_iter: 56 | # get initial lr for each group 57 | init_lr_g_l = self._get_init_lr() 58 | # modify warming-up learning rates 59 | warm_up_lr_l = [] 60 | for init_lr_g in init_lr_g_l: 61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 62 | # set learning rate 63 | self._set_lr(warm_up_lr_l) 64 | 65 | def get_current_learning_rate(self): 66 | return [param_group['lr'] for param_group in self.optimizers[0].param_groups] 67 | 68 | def get_network_description(self, network): 69 | """Get the string and total parameters of the network""" 70 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 71 | network = network.module 72 | return str(network), sum(map(lambda x: x.numel(), network.parameters())) 73 | 74 | def save_network(self, network, network_label, iter_label): 75 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 76 | save_path = os.path.join(self.opt['path']['models'], save_filename) 77 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 78 | network = network.module 79 | state_dict = network.state_dict() 80 | for key, param in state_dict.items(): 81 | state_dict[key] = param.cpu() 82 | torch.save(state_dict, save_path) 83 | 84 | def load_network(self, load_path, network, strict=True): 85 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 86 | network = network.module 87 | if os.path.exists(load_path): 88 | load_net = torch.load(load_path) 89 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 90 | for k, v in load_net.items(): 91 | if k.startswith('module.'): 92 | load_net_clean[k[7:]] = v 93 | else: 94 | load_net_clean[k] = v 95 | network.load_state_dict(load_net_clean, strict=strict) 96 | print("Successfully loaded the pretrained model.") 97 | else: 98 | print("Warning:
pretrained path not exists") 99 | 100 | def save_training_state(self, epoch, iter_step): 101 | """Save training state during training, which will be used for resuming""" 102 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 103 | for s in self.schedulers: 104 | state['schedulers'].append(s.state_dict()) 105 | for o in self.optimizers: 106 | state['optimizers'].append(o.state_dict()) 107 | save_filename = '{}.state'.format(iter_step) 108 | save_path = os.path.join(self.opt['path']['training_state'], save_filename) 109 | torch.save(state, save_path) 110 | 111 | def resume_training(self, resume_state): 112 | """Resume the optimizers and schedulers for training""" 113 | resume_optimizers = resume_state['optimizers'] 114 | resume_schedulers = resume_state['schedulers'] 115 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 116 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 117 | for i, o in enumerate(resume_optimizers): 118 | self.optimizers[i].load_state_dict(o) 119 | for i, s in enumerate(resume_schedulers): 120 | self.schedulers[i].load_state_dict(s) 121 | -------------------------------------------------------------------------------- /LLIE/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-6): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | return loss 16 | 17 | 18 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 19 | class GANLoss(nn.Module): 20 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 21 | super(GANLoss, self).__init__() 22 | self.gan_type = gan_type.lower() 23 | self.real_label_val = real_label_val 24 | self.fake_label_val = fake_label_val 25 | 26 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 27 | self.loss = nn.BCEWithLogitsLoss() 28 | elif self.gan_type == 'lsgan': 29 | self.loss = nn.MSELoss() 30 | elif self.gan_type == 'wgan-gp': 31 | 32 | def wgan_loss(input, target): 33 | # target is boolean 34 | return -1 * input.mean() if target else input.mean() 35 | 36 | self.loss = wgan_loss 37 | else: 38 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 39 | 40 | def get_target_label(self, input, target_is_real): 41 | if self.gan_type == 'wgan-gp': 42 | return target_is_real 43 | if target_is_real: 44 | return torch.empty_like(input).fill_(self.real_label_val) 45 | else: 46 | return torch.empty_like(input).fill_(self.fake_label_val) 47 | 48 | def forward(self, input, target_is_real): 49 | target_label = self.get_target_label(input, target_is_real) 50 | loss = self.loss(input, target_label) 51 | return loss 52 | 53 | 54 | class GradientPenaltyLoss(nn.Module): 55 | def __init__(self, device=torch.device('cpu')): 56 | super(GradientPenaltyLoss, self).__init__() 57 | self.register_buffer('grad_outputs', torch.Tensor()) 58 | self.grad_outputs = self.grad_outputs.to(device) 59 | 60 | def get_grad_outputs(self, input): 61 | if self.grad_outputs.size() != input.size(): 62 | self.grad_outputs.resize_(input.size()).fill_(1.0) 63 | return self.grad_outputs 64 | 65 | def forward(self, interp, interp_crit): 66 | grad_outputs = self.get_grad_outputs(interp_crit) 67 | grad_interp = 
-------------------------------------------------------------------------------- /LLIE/models/lr_scheduler.py: --------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from collections import defaultdict
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 | 
7 | 
8 | class MultiStepLR_Restart(_LRScheduler):
9 |     def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
10 |                  clear_state=False, last_epoch=-1):
11 |         self.milestones = Counter(milestones)
12 |         self.gamma = gamma
13 |         self.clear_state = clear_state
14 |         self.restarts = restarts if restarts else [0]
15 |         self.restarts = [v + 1 for v in self.restarts]
16 |         self.restart_weights = weights if weights else [1]
17 |         assert len(self.restarts) == len(
18 |             self.restart_weights), 'restarts and their weights do not match.'
19 |         super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
20 | 
21 |     def get_lr(self):
22 |         if self.last_epoch in self.restarts:
23 |             if self.clear_state:
24 |                 self.optimizer.state = defaultdict(dict)
25 |             weight = self.restart_weights[self.restarts.index(self.last_epoch)]
26 |             return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
27 |         if self.last_epoch not in self.milestones:
28 |             return [group['lr'] for group in self.optimizer.param_groups]
29 |         return [
30 |             group['lr'] * self.gamma**self.milestones[self.last_epoch]
31 |             for group in self.optimizer.param_groups
32 |         ]
33 | 
34 | 
35 | class CosineAnnealingLR_Restart(_LRScheduler):
36 |     def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
37 |         self.T_period = T_period
38 |         self.T_max = self.T_period[0]  # current T period
39 |         self.eta_min = eta_min
40 |         self.restarts = restarts if restarts else [0]
41 |         self.restarts = [v + 1 for v in self.restarts]
42 |         self.restart_weights = weights if weights else [1]
43 |         self.last_restart = 0
44 |         assert len(self.restarts) == len(
45 |             self.restart_weights), 'restarts and their weights do not match.'
46 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 47 | 48 | def get_lr(self): 49 | if self.last_epoch == 0: 50 | return self.base_lrs 51 | elif self.last_epoch in self.restarts: 52 | self.last_restart = self.last_epoch 53 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 54 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 55 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 56 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 57 | return [ 58 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 59 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 60 | ] 61 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 62 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 63 | (group['lr'] - self.eta_min) + self.eta_min 64 | for group in self.optimizer.param_groups] 65 | 66 | 67 | if __name__ == "__main__": 68 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 69 | betas=(0.9, 0.99)) 70 | ############################## 71 | # MultiStepLR_Restart 72 | ############################## 73 | ## Original 74 | lr_steps = [200000, 400000, 600000, 800000] 75 | restarts = None 76 | restart_weights = None 77 | 78 | ## two 79 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 80 | restarts = [500000] 81 | restart_weights = [1] 82 | 83 | ## four 84 | lr_steps = [ 85 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 86 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 87 | ] 88 | restarts = [250000, 500000, 750000] 89 | restart_weights = [1, 1, 1] 90 | 91 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 92 | clear_state=False) 93 | 94 | ############################## 95 | # Cosine Annealing Restart 96 | ############################## 97 | ## two 98 | T_period = [500000, 500000] 99 | restarts = [500000] 100 | restart_weights = [1] 101 | 102 | ## four 103 | T_period = [250000, 250000, 250000, 250000] 104 | restarts = [250000, 500000, 750000] 105 | restart_weights = [1, 1, 1] 106 | 107 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, 108 | weights=restart_weights) 109 | 110 | ############################## 111 | # Draw figure 112 | ############################## 113 | N_iter = 1000000 114 | lr_l = list(range(N_iter)) 115 | for i in range(N_iter): 116 | scheduler.step() 117 | current_lr = optimizer.param_groups[0]['lr'] 118 | lr_l[i] = current_lr 119 | 120 | import matplotlib as mpl 121 | from matplotlib import pyplot as plt 122 | import matplotlib.ticker as mtick 123 | mpl.style.use('default') 124 | import seaborn 125 | seaborn.set(style='whitegrid') 126 | seaborn.set_context('paper') 127 | 128 | plt.figure(1) 129 | plt.subplot(111) 130 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 131 | plt.title('Title', fontsize=16, color='k') 132 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 133 | legend = plt.legend(loc='upper right', shadow=False) 134 | ax = plt.gca() 135 | labels = ax.get_xticks().tolist() 136 | for k, v in enumerate(labels): 137 | labels[k] = str(int(v / 1000)) + 'K' 138 | ax.set_xticklabels(labels) 139 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 140 | 141 
| ax.set_ylabel('Learning rate') 142 | ax.set_xlabel('Iteration') 143 | fig = plt.gcf() 144 | plt.show() 145 | -------------------------------------------------------------------------------- /LLIE/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import models.archs.discriminator_vgg_arch as SRGAN_arch 3 | import models.archs.EnhanceN_arch as EnhanceN_arch 4 | import models.archs.EnhanceN_arch1 as EnhanceN_arch1 5 | 6 | 7 | # Generator 8 | def define_G(opt): 9 | opt_net = opt['network_G'] 10 | which_model = opt_net['which_model_G'] 11 | 12 | # video restoration 13 | if which_model == 'Net': 14 | netG = EnhanceN_arch.InteractNet(nc=opt_net['nc']) 15 | elif which_model == 'Net1': 16 | netG = EnhanceN_arch1.InteractNet(nc=opt_net['nc']) 17 | else: 18 | raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) 19 | 20 | return netG 21 | 22 | 23 | # Discriminator 24 | def define_D(opt): 25 | opt_net = opt['network_D'] 26 | which_model = opt_net['which_model_D'] 27 | 28 | if which_model == 'discriminator_vgg_128': 29 | netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 30 | else: 31 | raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) 32 | return netD 33 | 34 | 35 | # Define network used for perceptual loss 36 | def define_F(opt, use_bn=False): 37 | gpu_ids = opt['gpu_ids'] 38 | device = torch.device('cuda' if gpu_ids else 'cpu') 39 | # PyTorch pretrained VGG19-54, before ReLU. 40 | if use_bn: 41 | feature_layer = 49 42 | else: 43 | feature_layer = 34 44 | netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, 45 | use_input_norm=True, device=device) 46 | netF.eval() # No need to train 47 | return netF 48 | -------------------------------------------------------------------------------- /LLIE/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/options/__init__.py -------------------------------------------------------------------------------- /LLIE/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | Loader, Dumper = OrderedYaml() 7 | 8 | 9 | def parse(opt_path, is_train=True): 10 | with open(opt_path, mode='r') as f: 11 | opt = yaml.load(f, Loader=Loader) 12 | # export CUDA_VISIBLE_DEVICES 13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 16 | 17 | opt['is_train'] = is_train 18 | if opt['distortion'] == 'sr': 19 | scale = opt['scale'] 20 | 21 | # datasets 22 | for phase, dataset in opt['datasets'].items(): 23 | phase = phase.split('_')[0] 24 | dataset['phase'] = phase 25 | if opt['distortion'] == 'sr': 26 | dataset['scale'] = scale 27 | is_lmdb = False 28 | if dataset.get('dataroot_GT', None) is not None: 29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 30 | if dataset['dataroot_GT'].endswith('lmdb'): 31 | is_lmdb = True 32 | if dataset.get('dataroot_LQ', None) is not None: 33 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 34 | if dataset['dataroot_LQ'].endswith('lmdb'): 35 | is_lmdb = True 36 | dataset['data_type'] = 'lmdb' if is_lmdb else 
'img' 37 | if dataset['mode'].endswith('mc'): # for memcached 38 | dataset['data_type'] = 'mc' 39 | dataset['mode'] = dataset['mode'].replace('_mc', '') 40 | 41 | # path 42 | for key, path in opt['path'].items(): 43 | if path and key in opt['path'] and key != 'strict_load': 44 | opt['path'][key] = osp.expanduser(path) 45 | # opt['path']['root'] = '/home/zhanghc/IEEEyellow/code_enhance/checkpoints' 46 | if is_train: 47 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 48 | opt['path']['experiments_root'] = experiments_root 49 | opt['path']['models'] = osp.join(experiments_root, 'models') 50 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 51 | opt['path']['log'] = experiments_root 52 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 53 | 54 | # change some options for debug mode 55 | if 'debug' in opt['name']: 56 | opt['train']['val_freq'] = 8 57 | opt['logger']['print_freq'] = 1 58 | opt['logger']['save_checkpoint_freq'] = 8 59 | else: # test 60 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 61 | opt['path']['results_root'] = results_root 62 | opt['path']['log'] = results_root 63 | 64 | # network 65 | if opt['distortion'] == 'sr': 66 | opt['network_G']['scale'] = scale 67 | 68 | return opt 69 | 70 | 71 | def dict2str(opt, indent_l=1): 72 | '''dict to string for logger''' 73 | msg = '' 74 | for k, v in opt.items(): 75 | if isinstance(v, dict): 76 | msg += ' ' * (indent_l * 2) + k + ':[\n' 77 | msg += dict2str(v, indent_l + 1) 78 | msg += ' ' * (indent_l * 2) + ']\n' 79 | else: 80 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 81 | return msg 82 | 83 | 84 | class NoneDict(dict): 85 | def __missing__(self, key): 86 | return None 87 | 88 | 89 | # convert to NoneDict, which return None for missing key. 90 | def dict_to_nonedict(opt): 91 | if isinstance(opt, dict): 92 | new_opt = dict() 93 | for key, sub_opt in opt.items(): 94 | new_opt[key] = dict_to_nonedict(sub_opt) 95 | return NoneDict(**new_opt) 96 | elif isinstance(opt, list): 97 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 98 | else: 99 | return opt 100 | 101 | 102 | def check_resume(opt, resume_iter): 103 | '''Check resume states and pretrain_model paths''' 104 | logger = logging.getLogger('base') 105 | if opt['path']['resume_state']: 106 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 107 | 'pretrain_model_D', None) is not None: 108 | logger.warning('pretrain_model path will be ignored when resuming training.') 109 | 110 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 111 | '{}_G.pth'.format(resume_iter)) 112 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 113 | if 'gan' in opt['model']: 114 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 115 | '{}_D.pth'.format(resume_iter)) 116 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 117 | -------------------------------------------------------------------------------- /LLIE/options/test/test_ESRGAN.yml: -------------------------------------------------------------------------------- 1 | name: RRDB_ESRGAN_x4 2 | suffix: ~ # add suffix to saved images 3 | model: sr 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels
7 | gpu_ids: [0]
8 | 
9 | datasets:
10 |   test_1: # the 1st test dataset
11 |     name: set5
12 |     mode: LQGT
13 |     dataroot_GT: ../datasets/val_set5/Set5
14 |     dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
15 |   test_2: # the 2nd test dataset
16 |     name: set14
17 |     mode: LQGT
18 |     dataroot_GT: ../datasets/val_set14/Set14
19 |     dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
20 | 
21 | #### network structures
22 | network_G:
23 |   which_model_G: RRDBNet
24 |   in_nc: 3
25 |   out_nc: 3
26 |   nf: 64
27 |   nb: 23
28 |   upscale: 4
29 | 
30 | #### path
31 | path:
32 |   pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth
33 | 
-------------------------------------------------------------------------------- /LLIE/options/test/test_SRGAN.yml: --------------------------------------------------------------------------------
1 | name: MSRGANx4
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 4
6 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
7 | gpu_ids: [0]
8 | 
9 | datasets:
10 |   test_1: # the 1st test dataset
11 |     name: set5
12 |     mode: LQGT
13 |     dataroot_GT: ../datasets/val_set5/Set5
14 |     dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
15 |   test_2: # the 2nd test dataset
16 |     name: set14
17 |     mode: LQGT
18 |     dataroot_GT: ../datasets/val_set14/Set14
19 |     dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
20 | 
21 | #### network structures
22 | network_G:
23 |   which_model_G: MSRResNet
24 |   in_nc: 3
25 |   out_nc: 3
26 |   nf: 64
27 |   nb: 16
28 |   upscale: 4
29 | 
30 | #### path
31 | path:
32 |   pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth
33 | 
-------------------------------------------------------------------------------- /LLIE/options/test/test_SRResNet.yml: --------------------------------------------------------------------------------
1 | name: MSRResNetx4
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 4
6 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
7 | gpu_ids: [0]
8 | 
9 | datasets:
10 |   test_1: # the 1st test dataset
11 |     name: set5
12 |     mode: LQGT
13 |     dataroot_GT: ../datasets/val_set5/Set5
14 |     dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
15 |   test_2: # the 2nd test dataset
16 |     name: set14
17 |     mode: LQGT
18 |     dataroot_GT: ../datasets/val_set14/Set14
19 |     dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
20 |   test_3:
21 |     name: bsd100
22 |     mode: LQGT
23 |     dataroot_GT: ../datasets/BSD/BSDS100
24 |     dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4
25 |   test_4:
26 |     name: urban100
27 |     mode: LQGT
28 |     dataroot_GT: ../datasets/urban100
29 |     dataroot_LQ: ../datasets/urban100_bicLRx4
30 |   test_5:
31 |     name: div2k100
32 |     mode: LQGT
33 |     dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR
34 |     dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4
35 | 
36 | 
37 | #### network structures
38 | network_G:
39 |   which_model_G: MSRResNet
40 |   in_nc: 3
41 |   out_nc: 3
42 |   nf: 64
43 |   nb: 16
44 |   upscale: 4
45 | 
46 | #### path
47 | path:
48 |   pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
49 | 
-------------------------------------------------------------------------------- /LLIE/options/train/train_Enhance.yml: --------------------------------------------------------------------------------
1 | #### general settings
2 | name: STEN
3 | use_tb_logger: true
4 | model: sr
5 | distortion: sr
6 | scale: 1
7 | gpu_ids: [0]
8 | 
9 | #### datasets
10 | datasets:
11 |   train:
12 |     name: UEN
13 |     mode: UEN_train
14 |     interval_list: [1]
15 |     random_reverse: false
16 |     border_mode: false
17 |     # dataroot: /data/1760921465/dped/iphone/test_data/patches
18 |     cache_keys: ~
19 |     filelist: /home/jieh/Projects/FourRestore/MainNet/data/groups_train_Huawei.txt
20 | 
21 |     use_shuffle: true
22 |     n_workers: 0 # per GPU
23 |     batch_size: 4
24 |     IN_size: 384
25 |     augment: true
26 |     color: RGB
27 | 
28 |   val:
29 |     name: UEN
30 |     mode: UEN_val
31 |     # dataroot: /data/1760921465/dped/iphone/test_data/patches
32 |     filelist: /home/jieh/Projects/FourRestore/MainNet/data/groups_test_Huawei.txt
33 | 
34 |     batch_size: 1
35 |     use_shuffle: false
36 | 
37 | 
38 | #### network structures
39 | network_G:
40 |   which_model_G: Net
41 |   nc: 8
42 |   groups: 8
43 | 
44 | #### path
45 | path:
46 |   root: /home/jieh/Projects/FourRestore/MainNet/output
47 |   results_root: /home/jieh/Projects/FourRestore/MainNet/output
48 |   pretrain: /home/jieh/Projects/FourRestore/MainNet/pretrain
49 |   pretrain_model_G: /home/jieh/Projects/FourRestore/MainNet/output/experiments/SOTA_Huawei/models/0_bestavg.pth
50 |   strict_load: false
51 |   resume_state: ~
52 | 
53 | #### training settings: learning rate scheme, loss
54 | train:
55 |   lr_G: !!float 8e-4
56 |   #lr_scheme: MultiStepLR
57 |   beta1: 0.9
58 |   beta2: 0.99
59 |   niter: 80000
60 |   fix_some_part: ~
61 |   warmup_iter: -1 # -1: no warm up
62 |   is_training: False
63 |   #### for cosine adjustment
64 |   # T_period: [400000, 1000000, 1500000, 1500000, 1500000]
65 |   # restarts: [400000, 1400000, 2700000, 4200000]
66 |   # restart_weights: [1, 1, 1, 1]
67 |   lr_scheme: MultiStepLR
68 |   lr_steps: [10000, 30000, 50000, 60000]
69 |   lr_gamma: 0.5
70 | 
71 |   eta_min: !!float 5e-6
72 |   pixel_criterion: l1
73 |   pixel_weight: 5000.0
74 |   ssim_weight: 1000.0
75 |   vgg_weight: 1000.0
76 | 
77 |   val_epoch: !!float 1
78 |   manual_seed: 0
79 | 
80 | #### logger
81 | logger:
82 |   print_freq: 20
83 |   save_checkpoint_epoch: !!float 100
84 | 
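The training entry point is not reproduced in this listing, but the `train:` block above presumably gets wired into `MultiStepLR_Restart` from `models/lr_scheduler.py`; a sketch under that assumption, with a dummy parameter standing in for the real generator:

```python
import torch
from models.lr_scheduler import MultiStepLR_Restart  # LLIE/models/lr_scheduler.py

params = [torch.zeros(3, 3, requires_grad=True)]  # stand-in for netG.parameters()
optimizer = torch.optim.Adam(params, lr=8e-4, betas=(0.9, 0.99))  # lr_G, beta1, beta2
scheduler = MultiStepLR_Restart(optimizer,
                                milestones=[10000, 30000, 50000, 60000],  # lr_steps
                                gamma=0.5)                                # lr_gamma
```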
-------------------------------------------------------------------------------- /LLIE/options/train/train_Enhance1.yml: --------------------------------------------------------------------------------
1 | #### general settings
2 | name: STEN
3 | use_tb_logger: true
4 | model: sr
5 | distortion: sr
6 | scale: 1
7 | gpu_ids: [1]
8 | 
9 | #### datasets
10 | datasets:
11 |   train:
12 |     name: UEN
13 |     mode: UEN_train
14 |     interval_list: [1]
15 |     random_reverse: false
16 |     border_mode: false
17 |     # dataroot: /data/1760921465/dped/iphone/test_data/patches
18 |     cache_keys: ~
19 |     filelist: /home/jieh/Projects/FourRestore/MainNet/data/groups_train_LOL.txt
20 | 
21 |     use_shuffle: true
22 |     n_workers: 0 # per GPU
23 |     batch_size: 4
24 |     IN_size: 384
25 |     augment: true
26 |     color: RGB
27 | 
28 |   val:
29 |     name: UEN
30 |     mode: UEN_val
31 |     # dataroot: /data/1760921465/dped/iphone/test_data/patches
32 |     filelist: /home/jieh/Projects/FourRestore/MainNet/data/groups_test_LOL.txt
33 | 
34 |     batch_size: 1
35 |     use_shuffle: false
36 | 
37 | 
38 | #### network structures
39 | network_G:
40 |   which_model_G: Net1
41 |   nc: 8
42 |   groups: 8
43 | 
44 | #### path
45 | path:
46 |   root: /home/jieh/Projects/FourRestore/MainNet/output1
47 |   results_root: /home/jieh/Projects/FourRestore/MainNet/output1
48 |   pretrain: /home/jieh/Projects/FourRestore/MainNet/output1
49 |   pretrain_model_G: ~
50 |   strict_load: false
51 |   resume_state: ~
52 | 
53 | #### training settings: learning rate scheme, loss
54 | train:
55 |   lr_G: !!float 8e-4
56 |   #lr_scheme: MultiStepLR
57 |   beta1: 0.9
58 |   beta2: 0.99
59 |   niter: 80000
60 |   fix_some_part: ~
61 |   warmup_iter: -1 # -1: no warm up
62 | 
63 |   #### for cosine adjustment
64 |   # T_period: [400000, 1000000, 1500000, 1500000, 1500000]
65 |   # restarts: [400000, 1400000, 2700000, 4200000]
66 |   # restart_weights: [1, 1, 1, 1]
67 |   lr_scheme: MultiStepLR
68 |   lr_steps: [10000, 30000, 50000, 60000]
69 |   lr_gamma: 0.5
70 | 
71 |   eta_min: !!float 5e-6
72 |   pixel_criterion: l1
73 |   pixel_weight: 5000.0
74 |   ssim_weight: 1000.0
75 |   vgg_weight: 1000.0
76 | 
77 |   val_epoch: !!float 1
78 |   manual_seed: 0
79 | 
80 | #### logger
81 | logger:
82 |   print_freq: 20
83 |   save_checkpoint_epoch: !!float 100
84 | 
-------------------------------------------------------------------------------- /LLIE/pretrain/LOL/0_bestavg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/pretrain/LOL/0_bestavg.pth
-------------------------------------------------------------------------------- /LLIE/readme.md: --------------------------------------------------------------------------------
1 | ## Applications
2 | ### Low-Light Image Enhancement
3 | #### Prepare data
4 | Download the training data and add the data path to the config file (/LLIE/options/train/*.yml). Please refer to [LOL](https://daooshee.github.io/BMVC2018website/) and [Huawei](https://drive.google.com/drive/folders/1rFUSdcw833haZfkGKODvu1hrv2pgxYT_?usp=drive_link) (the Huawei set includes 2480 images; we randomly select 2200 for training and the remaining 280 for testing) for data download.
5 | #### Training
6 | ```
7 | python LLIE/train.py -opt LLIE/options/train/train_Enhance.yml
8 | ```
9 | 
-------------------------------------------------------------------------------- /LLIE/test.py: --------------------------------------------------------------------------------
1 | import os.path as osp
2 | import logging
3 | import time
4 | import argparse
5 | from collections import OrderedDict
6 | 
7 | import options.options as option
8 | import utils.util as util
9 | from data.util import bgr2ycbcr
10 | from data import create_dataset, create_dataloader
11 | from models import create_model
12 | 
13 | #### options
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
16 | opt = option.parse(parser.parse_args().opt, is_train=False)
17 | opt = option.dict_to_nonedict(opt)
18 | 
19 | util.mkdirs(
20 |     (path for key, path in opt['path'].items()
21 |      if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
22 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
23 |                   screen=True, tofile=True)
24 | logger = logging.getLogger('base')
25 | logger.info(option.dict2str(opt))
26 | 
27 | #### Create test dataset and dataloader
28 | test_loaders = []
29 | for phase, dataset_opt in sorted(opt['datasets'].items()):
30 |     test_set = create_dataset(dataset_opt)
31 |     test_loader = create_dataloader(test_set, dataset_opt)
32 |     logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
33 |     test_loaders.append(test_loader)
34 | 
35 | model = create_model(opt)
36 | for test_loader in test_loaders:
37 |     test_set_name = test_loader.dataset.opt['name']
38 |     logger.info('\nTesting [{:s}]...'.format(test_set_name))
39 |     test_start_time = time.time()
40 |     dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
41 |     util.mkdir(dataset_dir)
42 | 
43 |     test_results = OrderedDict()
44 |     test_results['psnr'] = []
45 |     test_results['ssim'] = []
46 |     test_results['psnr_y'] = []
47 |     test_results['ssim_y'] = []
48 | 
49 |     for data in test_loader:
50 |         need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
51 |         model.feed_data(data, need_GT=need_GT)
52 |         img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
53 |         img_name = osp.splitext(osp.basename(img_path))[0]
54 | 
55 |         model.test()
56 |         visuals = model.get_current_visuals(need_GT=need_GT)
57 | 
58 |         sr_img = util.tensor2img(visuals['rlt'])  # uint8
59 | 
60 |         # save images
61 |         suffix = opt['suffix']
62 |         if suffix:
63 |             save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
64 |         else:
65 |             save_img_path = osp.join(dataset_dir, img_name + '.png')
66 |         util.save_img(sr_img, save_img_path)
67 | 
68 |         # calculate PSNR and SSIM
69 |         if need_GT:
70 |             gt_img = util.tensor2img(visuals['GT'])
71 |             sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
72 |             psnr = util.calculate_psnr(sr_img, gt_img)
73 |             ssim = util.calculate_ssim(sr_img, gt_img)
74 |             test_results['psnr'].append(psnr)
75 |             test_results['ssim'].append(ssim)
76 | 
77 |             if gt_img.shape[2] == 3:  # RGB image
78 |                 sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
79 |                 gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
80 | 
81 |                 psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
82 |                 ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
83 |                 test_results['psnr_y'].append(psnr_y)
84 |                 test_results['ssim_y'].append(ssim_y)
85 |                 logger.info(
86 |                     '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
87 |                     format(img_name, psnr, ssim, psnr_y, ssim_y))
88 |             else:
89 |                 logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
90 |         else:
91 |             logger.info(img_name)
92 | 
93 |     if need_GT:  # metrics
94 |         # Average PSNR/SSIM results
95 |         ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
96 |         ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
97 |         logger.info(
98 |             '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format(
99 |                 test_set_name, ave_psnr, ave_ssim))
100 |         if test_results['psnr_y'] and test_results['ssim_y']:
101 |             ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
102 |             ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
103 |             logger.info(
104 |                 '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'.
105 |                 format(ave_psnr_y, ave_ssim_y))
106 | 
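For reference, the PSNR reported by `util.calculate_psnr` in the script above follows the standard definition for 8-bit images; a minimal self-contained equivalent (written as an illustration, not copied from `utils/util.py`):

```python
import numpy as np

def psnr_8bit(img1, img2):
    """Peak signal-to-noise ratio between two uint8 images, in dB."""
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 10 * np.log10(255.0 ** 2 / mse)
```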
-------------------------------------------------------------------------------- /LLIE/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/LLIE/utils/__init__.py
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Fourmer: An Efficient Global Modeling Paradigm for Image Restoration
2 | Welcome to the Fourmer repository! This repository contains implementations for the Fourmer model, an efficient global modeling paradigm for image restoration tasks including dehazing, deraining, low-light image enhancement (LLIE), and pan-sharpening.
3 | 
4 | For specific details and implementations, please refer to the following folders:
5 | 
6 | - Dehazing: Contains implementations for image dehazing using the Fourmer model.
7 | - Deraining: Contains implementations for image deraining using the Fourmer model.
8 | - LLIE: Contains implementations for low-light image enhancement (LLIE) using the Fourmer model.
9 | - pan-sharpening: Contains implementations for pan-sharpening using the Fourmer model.
10 | 
11 | 
-------------------------------------------------------------------------------- /pan-sharpening/.gitignore: --------------------------------------------------------------------------------
1 | __pycache__/
2 | examples/*
3 | !examples/.gitkeep
4 | *.DS_Store
5 | .idea/
6 | local/
7 | 
8 | data/*
9 | results/*
10 | weights/*
11 | model/*
12 | logs/*
13 | 
14 | visualization/out*
15 | envs/*/*/*.txt
16 | envs/*/*/*.json
17 | envs/*/*.json
18 | wandb/*
19 | 
-------------------------------------------------------------------------------- /pan-sharpening/configs/config.yaml: --------------------------------------------------------------------------------
1 | name: HGK # project_name
2 | dataset_name: 'WV3' # only support 'GF2', 'QB', 'WV3', 'WV2' now
3 | #model_name: 'HGK' # only support 'HGK', 'HGK2', 'Fourmer', 'GPPNN' and related ablation parts of HGK2 now
4 | #model_name: 'HGK2'
5 | #model_name: 'Fuser'
6 | # model_name: 'PANINN'
7 | #model_name: 'woCFI'
8 | #model_name: 'woCBIFA'
9 | #model_name: 'woSSM'
10 | #model_name: 'GPPNN'
11 | model_name: 'Fourmer'
12 | #with_hist_loss: True
13 | with_hist_loss: False
14 | #with_grad_loss: True
15 | with_grad_loss: False
16 | alpha: 0.1
17 | #with_hist_loss: False
18 | #with_hist_loss: False
19 | hidden_channel: 8
20 | #hidden_channel: 32
21 | 
22 | epoch_num: 500
23 | #epoch_num: 300
24 | #epoch_num: 1000
25 | batch_size: 32
26 | #batch_size: 24
27 | #batch_size: 16
28 | base_lr: 5e-4
29 | #base_lr: 1e-4
30 | #base_lr: 3e-4
31 | # parameters for scheduler lr
32 | gamma: 0.1
33 | step_size: 250
34 | 
35 | data_path: './data_files'
36 | # tensorboard file path
37 | log_dir: './logs'
38 | tb_log_path: './tb_logs'
39 | weights_path: './weights/'
40 | results_path: './results/'
41 | gpu_list: [0]
42 | workers: 0
43 | save_epoch: 25
44 | #resume:
45 | # path to the latest check point (default - none)
46 | #resume_log_id:
47 | # the id of resumed run if you want to continue log
48 | #start_epoch: 0
49 | 
50 | test_mode: 'reduced'
51 | #test_mode: 'full'
52 | test_weight_path: 'weights/WV3/20240101-152247/CSNET500.pth'
53 | 
-------------------------------------------------------------------------------- /pan-sharpening/datasets/data.py: --------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torch
3 | import h5py
4 | import cv2
5 | import numpy as np
6 | from pathlib import Path
7 | from torch.utils.data import DataLoader
8 | 
9 | 
10 | def get_edge(data):  # for training
11 |     rs = np.zeros_like(data)
12 |     N = data.shape[0]
13 |     for i in range(N):
14 |         if len(data.shape) == 3:
15 |             rs[i, :, :] = data[i, :, :] - cv2.boxFilter(data[i, :, :], -1, (5, 5))
16 |         else:
17 |             rs[i, :, :, :] = data[i, :, :, :] - cv2.boxFilter(data[i, :, :, :], -1, (5, 5))
18 |     return rs
19 | 
20 | 
21 | class Dataset_Pro(data.Dataset):
22 |     def __init__(self, file_path):
23 |         super(Dataset_Pro, self).__init__()
24 |         data = h5py.File(file_path, 'r')  # NxCxHxW = 0x1x2x3
25 |         # max_value = 1023.0 if 'gf2' in file_path else 2047.0
26 |         max_value = 2047.0
27 | 
28 |         # tensor type:
29 |         gt1 = data["gt"][...]  # convert to np type for CV2.filter
30 |         gt1 = np.array(gt1, dtype=np.float32) / max_value
31 |         self.gt = torch.from_numpy(gt1)  # NxCxHxW:
32 | 
33 |         lms1 = data["lms"][...]  # convert to np type for CV2.filter
34 |         lms1 = np.array(lms1, dtype=np.float32) / max_value
35 |         self.lms = torch.from_numpy(lms1)
36 | 
37 |         ms1 = data["ms"][...]
# NxCxHxW 38 | ms1 = np.array(ms1, dtype=np.float32) / max_value 39 | self.ms = torch.from_numpy(ms1) 40 | # ms1 = np.array(ms1.transpose(0, 2, 3, 1), dtype=np.float32) / max_value # NxHxWxC 41 | # ms1_tmp = get_edge(ms1) # NxHxWxC 42 | # self.ms_hp = torch.from_numpy(ms1_tmp).permute(0, 3, 1, 2) # NxCxHxW: 43 | 44 | pan1 = data['pan'][...] # Nx1xHxW 45 | pan1 = np.array(pan1.transpose(0, 2, 3, 1), dtype=np.float32) / max_value # NxHxWx1 46 | pan1 = np.squeeze(pan1, axis=3) # NxHxW 47 | pan_hp_tmp = get_edge(pan1) # NxHxW 48 | pan_hp_tmp = np.expand_dims(pan_hp_tmp, axis=3) # NxHxWx1 49 | self.pan_hp = torch.from_numpy(pan_hp_tmp).permute(0, 3, 1, 2) # Nx1xHxW: 50 | 51 | pan1 = data['pan'][...] # Nx1xHxW 52 | pan1 = np.array(pan1, dtype=np.float32) / max_value # Nx1xHxW 53 | self.pan = torch.from_numpy(pan1) # Nx1xHxW: 54 | 55 | def __getitem__(self, index): 56 | return self.gt[index, :, :, :].float(), self.lms[index, :, :, :].float(), \ 57 | self.ms[index, :, :, :].float(), self.pan_hp[index, :, :, :].float(), \ 58 | self.pan[index, :, :, :].float() 59 | 60 | def __len__(self): 61 | return self.gt.shape[0] 62 | 63 | 64 | def create_loaders(config): 65 | assert config.dataset_name in ('GF2', 'QB', 'WV3', 'WV2') 66 | data_path = Path(config.data_path) 67 | batch_size = config.batch_size 68 | 69 | # if training: 70 | dataset_name = config.dataset_name.lower() 71 | train_data_path = str(data_path / f'training_{dataset_name}/train_{dataset_name}.h5') 72 | train_set = Dataset_Pro(train_data_path) 73 | training_data_loader = DataLoader(dataset=train_set, num_workers=config.workers, batch_size=batch_size, 74 | shuffle=True, pin_memory=True, drop_last=True) 75 | print('Train set ground truth shape', train_set.gt.shape) 76 | 77 | validate_data_path = str(data_path / f'training_{dataset_name}/valid_{dataset_name}.h5') 78 | validate_set = Dataset_Pro(validate_data_path) 79 | validate_data_loader = DataLoader(dataset=validate_set, num_workers=0, batch_size=batch_size, shuffle=False, 80 | pin_memory=True, drop_last=True) 81 | print('Validate set ground truth shape', validate_set.gt.shape) 82 | return training_data_loader, validate_data_loader 83 | # else: 84 | # dataset_name = config.dataset_name.lower() 85 | # test_mode = {'reduced': 'reduce_examples', 'full': 'full_examples'}[config.test_mode] 86 | # test_data_path = str(data_path / f'test_data/h5/{dataset_name}/{test_mode}/test_{dataset_name}_multiExm1.h5') 87 | # test_set = Dataset_Pro(test_data_path) 88 | # test_data_loader = DataLoader(dataset=test_set, num_workers=0, batch_size=batch_size, shuffle=False, 89 | # pin_memory=True, drop_last=True) 90 | # return test_data_loader 91 | -------------------------------------------------------------------------------- /pan-sharpening/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import utils.util as u 3 | from utils.global_config import parse_args, init_global_config 4 | from models.pipeline import Trainer 5 | import torch 6 | 7 | 8 | if __name__ == "__main__": 9 | args = parse_args() 10 | config = init_global_config(args) 11 | # os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in config.gpu_list) 12 | logger = u.get_logger(config) 13 | u.setup_seed(10) 14 | u.log_args_and_parameters(logger, args, config) 15 | 16 | trainer = Trainer(config, logger) 17 | trainer.train_all() 18 | -------------------------------------------------------------------------------- /pan-sharpening/models/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .model import HGKNet 2 | from .model2 import HGKNet2 3 | from .model2 import woCFI, woCBIFA, woSSM 4 | from .Network import Fourmer 5 | from .GPPNN import GPPNN 6 | from .hazer_cfm_adp import Fuser 7 | from .pan_inn9 import pan_inn as PANINN 8 | -------------------------------------------------------------------------------- /pan-sharpening/models/grad_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GradLoss(nn.Module): 7 | def __init__(self): 8 | super(GradLoss, self).__init__() 9 | self.sobelconv = Sobelxy() 10 | 11 | def forward(self, x, y): 12 | x_grad = self.sobelconv(x) 13 | y_grad = self.sobelconv(y) 14 | # loss_grad = F.l1_loss(x_grad, y_grad) 15 | loss_grad = F.mse_loss(x_grad, y_grad) 16 | return loss_grad 17 | 18 | 19 | class Sobelxy(nn.Module): 20 | def __init__(self): 21 | super(Sobelxy, self).__init__() 22 | kernelx = [[-1, 0, 1], 23 | [-2, 0, 2], 24 | [-1, 0, 1]] 25 | kernely = [[1, 2, 1], 26 | [0, 0, 0], 27 | [-1, -2, -1]] 28 | kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0) 29 | kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0) 30 | self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda() 31 | self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda() 32 | 33 | def forward(self, x): 34 | b, c, h, w = x.shape 35 | x = x.reshape(b*c, 1, h, w) 36 | sobelx = F.conv2d(x, self.weightx, padding=1, groups=1) 37 | sobely = F.conv2d(x, self.weighty, padding=1, groups=1) 38 | out = torch.abs(sobelx) + torch.abs(sobely) 39 | out = out.reshape(b, c, h, w) 40 | return out 41 | -------------------------------------------------------------------------------- /pan-sharpening/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils.global_config import parse_args, init_global_config 3 | from models.pipeline import Tester 4 | 5 | 6 | if __name__ == '__main__': 7 | args = parse_args() 8 | config = init_global_config(args) 9 | # os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in config.gpu_list) 10 | tester = Tester(config) 11 | # tester.test(analyse_fms=True) 12 | tester.test() 13 | -------------------------------------------------------------------------------- /pan-sharpening/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manman1995/Fourmer/02b66eb06ffafca13e6e2272adf1bc0992882634/pan-sharpening/utils/__init__.py -------------------------------------------------------------------------------- /pan-sharpening/utils/global_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import json 4 | from argparse import ArgumentParser 5 | from utils.util import yaml_read 6 | 7 | 8 | class Config: 9 | def __init__(self, **entries): 10 | for k, v in entries.items(): 11 | if isinstance(v, dict): 12 | entries[k] = Config(**v) 13 | self.__dict__.update(entries) 14 | 15 | def __str__(self): 16 | return '\n'.join(f"{key}: {value}" for key, value in self.__dict__.items()) 17 | # return json.dumps(self.__dict__) 18 | 19 | 20 | def str2bool(v): 21 | if isinstance(v, bool): 22 | return v 23 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 24 | return True 25 | elif v.lower() in ('no', 'false', 
'f', 'n', '0'): 26 | return False 27 | else: 28 | raise TypeError('Boolean value expected.') 29 | 30 | 31 | # Setup 32 | def parse_args(): 33 | # here only support partial arguments in config.yaml, plus debug and config_file arguments. 34 | # and extra arguments will append into run_cfg, full argument list can be found in config.yaml file. 35 | # default value is also specified by config.yaml file. 36 | parser = ArgumentParser(description='FourierAttention') 37 | parser.add_argument('--config_file', type=str, default='configs/config.yaml', help='The global config file') 38 | parser.add_argument('--debug', type=str2bool, const=True, nargs='?', 39 | default=False, help='When in the debug mode, it will not record logs') 40 | 41 | # train config 42 | parser.add_argument('--epoch_num', type=int, help='The number of epoch for training') 43 | parser.add_argument('--batch_size', type=int, help='Batch size for training') 44 | parser.add_argument('--dataset_name', type=str, choices=('GF2', 'QB', 'WV3', 'WV2'), 45 | help='The dataset name, support GF2, QB, WV3') 46 | parser.add_argument('--data_path', type=str, help='The path to the dataset file') 47 | parser.add_argument('--log_dir', type=str, help='Log dir') 48 | parser.add_argument('--gpu_list', type=int, nargs='+', help='The list of used gpu') 49 | parser.add_argument('--workers', type=int, help='Data loader workers') 50 | args = parser.parse_args() 51 | 52 | # test config 53 | parser.add_argument('--test_mode', type=str, choices=('full', 'reduced'), help='Choose the test mode') 54 | parser.add_argument('--test_weight_path', type=str, help='The model weight path for testing') 55 | return args 56 | 57 | 58 | # the argument in parse_args or specified from CLI will override values in config.yaml 59 | def init_global_config(args): 60 | cfg_file = args.config_file 61 | config = yaml_read(cfg_file) 62 | for k, v in vars(args).items(): 63 | if v is None: 64 | continue 65 | if k in config.keys(): 66 | config[k] = v 67 | g_c = Config(**config) 68 | g_c.debug = args.debug 69 | return g_c 70 | -------------------------------------------------------------------------------- /pan-sharpening/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import random 5 | import numpy as np 6 | from datetime import datetime 7 | import logging 8 | from sklearn.manifold import TSNE 9 | # from sklearn.decomposition import PCA 10 | import matplotlib.pyplot as plt 11 | from tqdm import tqdm 12 | import pickle 13 | import cv2 14 | 15 | 16 | def get_logger(config): 17 | if not os.path.exists(config.log_dir): 18 | os.makedirs(config.log_dir) 19 | logger = logging.getLogger('main') 20 | logger.setLevel(logging.DEBUG) 21 | 22 | file_name_time = f"{datetime.now().strftime('%Y%m%d-%H%M%S')}" 23 | file_name = f"{config.log_dir}/{file_name_time}" 24 | 25 | if not config.debug: 26 | fh = logging.FileHandler(file_name + '.log') 27 | fh.setLevel(logging.DEBUG) 28 | logger.addHandler(fh) 29 | sh = logging.StreamHandler() 30 | sh.setLevel(logging.INFO) 31 | logger.addHandler(sh) 32 | return logger 33 | 34 | 35 | class AverageMeter(object): 36 | """Computes and stores the average and current value""" 37 | 38 | def __init__(self): 39 | self.reset() 40 | 41 | def reset(self): 42 | self.val = 0 43 | self.avg = 0 44 | self.sum = 0 45 | self.count = 0 46 | 47 | def update(self, val, n=1): 48 | self.val = val 49 | self.sum += val * n 50 | self.count += n 51 | self.avg = self.sum / self.count 52 | 53 | 54 
| def log_args_and_parameters(logger, args, config): 55 | logger.info("config_file: ") 56 | logger.info(args.config_file) 57 | logger.info("args: ") 58 | logger.info(args) 59 | logger.info("config: ") 60 | logger.info(config) 61 | 62 | 63 | def setup_seed(seed): 64 | torch.manual_seed(seed) 65 | torch.cuda.manual_seed_all(seed) 66 | np.random.seed(seed) 67 | random.seed(seed) 68 | torch.backends.cudnn.deterministic = True 69 | 70 | 71 | def yaml_read(yaml_file): 72 | with open(yaml_file) as f: 73 | data = yaml.load(f, Loader=yaml.FullLoader) 74 | return data 75 | 76 | 77 | def calc_fourier_magnitude(data): 78 | dft = cv2.dft(data, flags=cv2.DFT_COMPLEX_OUTPUT) 79 | # Shift the zero-frequency component from the left-top corner to the center of the spectrum 80 | dft_shift = np.fft.fftshift(dft) 81 | # convert the magnitude of fourier complex into 0-255 82 | magnitude = 20 * np.log(cv2.magnitude(dft_shift[:, :, 0], dft_shift[:, :, 1])) 83 | return magnitude 84 | 85 | 86 | def plot_feature_maps(block_num, features, gts=None, rgb_idx=None, save_path=None): 87 | # for i in range(1): 88 | for i in tqdm(range(len(gts))): 89 | fig, axes = plt.subplots(1, block_num + 1 if gts is not None else block_num, figsize=(16, 10), sharey='row') 90 | fig.set_tight_layout(True) 91 | for j in range(block_num): 92 | cur_block_fm = features[f'rijab_{j}'][0] 93 | b, c, h, w = cur_block_fm.shape 94 | cur_block_fm = cur_block_fm[i].reshape(h*w, c) 95 | cur_path = save_path / f"tsne_data/blocks_b{i}_block{j}.pkl" 96 | if cur_path.exists(): 97 | with open(str(cur_path), 'rb') as file: 98 | cur_block_fm = pickle.load(file) 99 | else: 100 | cur_block_fm = TSNE(n_components=1, perplexity=30).fit_transform(cur_block_fm) 101 | with open(str(cur_path), 'wb') as file: 102 | strs = pickle.dumps(cur_block_fm) 103 | file.write(strs) 104 | 105 | # pca = PCA(n_components=1).fit(cur_block_fm) 106 | # print(pca.explained_variance_ratio_) 107 | # cur_block_fm = pca.transform(cur_block_fm) 108 | 109 | # cur_block_fm = cur_block_fm.sum(-1) 110 | cur_block_fm = cur_block_fm.reshape(h, w) 111 | cur_block_fm = calc_fourier_magnitude(cur_block_fm) 112 | 113 | # temporal_fm1 = features[f'rijab_{j}.sat1'] 114 | # temporal_fm3 = features[f'rijab_{j}.sat3'] 115 | # temporal_fm5 = features[f'rijab_{j}.sat5'] 116 | axes[j].set_title(f'rijab_{j}') 117 | # axes[j].imshow(cur_block_fm) 118 | # axes[j].axis('off') 119 | axes[j].hist(cur_block_fm) 120 | if gts is not None: 121 | gt_gray = gts[i][rgb_idx] * np.array([0.3, 0.59, 0.11])[:, None, None] 122 | gt_gray = gt_gray.sum(0) 123 | gt_gray = calc_fourier_magnitude(gt_gray) 124 | axes[-1].set_title('GT') 125 | # axes[-1].imshow(gt_gray) 126 | # axes[-1].axis('off') 127 | axes[-1].hist(gt_gray) 128 | if save_path is not None: 129 | # plt.savefig(str(save_path / f"blocks_b{i}.png")) 130 | # plt.savefig(str(save_path / f"fdomain_blocks_b{i}.png")) 131 | plt.savefig(str(save_path / f"fhist_blocks_b{i}.png")) 132 | # plt.show() 133 | plt.close() 134 | return 135 | --------------------------------------------------------------------------------
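Finally, a small usage sketch of the bookkeeping utilities in `utils/util.py` above (the loss values are made up; `setup_seed(10)` mirrors the call in `main.py`):

```python
from utils.util import AverageMeter, setup_seed  # pan-sharpening/utils/util.py

setup_seed(10)  # the same seed main.py uses

loss_meter = AverageMeter()
for loss in (0.9, 0.7, 0.6):  # dummy per-batch losses
    loss_meter.update(loss, n=1)
print('avg loss: {:.3f}'.format(loss_meter.avg))  # -> 0.733
```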