├── median ├── __init__.py └── median_derain.py ├── ensemble ├── __init__.py └── ensemble_derain.py ├── restormer_x ├── __init__.py ├── dataset │ ├── __init__.py │ └── gt_rain_dataset.py ├── model │ ├── __init__.py │ └── restormer.py ├── utils │ ├── __init__.py │ ├── log.py │ ├── mixmethod.py │ ├── loss.py │ ├── data_augmentation.py │ └── trainutil.py ├── test.py └── train.py ├── post_process ├── __init__.py ├── post_process_derain.py └── estimate_pixels.py ├── .idea ├── vcs.xml ├── misc.xml ├── .gitignore ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── modules.xml ├── Restormer-Plus.iml └── deployment.xml ├── requirements.txt ├── repeat300.py ├── LICENSE ├── README.md └── .gitignore /median/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ensemble/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /restormer_x/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /post_process/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /restormer_x/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /restormer_x/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /restormer_x/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.3.0 2 | natsort==8.3.1 3 | numpy==1.21.5 4 | opencv_contrib_python==4.2.0.32 5 | Pillow==9.2.0 6 | piq==0.7.0 7 | scikit-image 8 | tabulate==0.8.10 9 | torch==1.12.1 10 | torchvision==0.13.1 -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml:
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/Restormer-Plus.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /repeat300.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from glob import glob 4 | 5 | from natsort import natsorted 6 | 7 | root_dir = '/gt-rain/result/post_process' 8 | scene_names = [] 9 | for sc in list(os.walk(root_dir))[0][1]: 10 | scene_names.append(sc) 11 | 12 | img_paths = {} 13 | for scene in scene_names: 14 | scene_path = os.path.join(root_dir, scene) 15 | scene_img_paths = natsorted(glob(os.path.join(scene_path, '*_r.png'))) 16 | img_paths[scene] = scene_img_paths 17 | 18 | for scene_name, im_paths in img_paths.items(): 19 | print(scene_name) 20 | origin_file = im_paths[0] 21 | for idx in range(2, 301): 22 | new_file = origin_file[:-7] + '{}_r.png'.format(idx) 23 | shutil.copyfile(origin_file, new_file) 24 | -------------------------------------------------------------------------------- /restormer_x/utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | 5 | def set_logger(log_dir, file_name): 6 | loglevel = logging.INFO 7 | 8 | log_path = os.path.join(log_dir, file_name) 9 | 10 | logger = logging.getLogger() 11 | logger.setLevel(loglevel) 12 | 13 | # Logging to a file 14 | file_handler = logging.FileHandler(log_path) 15 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 16 | logger.addHandler(file_handler) 17 | 18 | # Logging to console 19 | stream_handler = logging.StreamHandler() 20 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 21 | logger.addHandler(stream_handler) 22 | 23 | logging.info('writing logs to file {}'.format(log_path)) 24 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Research Center for Applied Mathematics and Machine Intelligence, Zhejiang Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ensemble/ensemble_derain.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from natsort import natsorted 8 | 9 | restormer_x_res_dir = '/gt-rain/result/restormer_x' 10 | median_res_dir = '/gt-rain/result/median' 11 | save_path = '/gt-rain/result' 12 | 13 | 14 | def get_img_paths(data_dir): 15 | scene_names = [] 16 | for sc in list(os.walk(data_dir))[0][1]: 17 | scene_names.append(sc) 18 | img_paths = {} 19 | for scene in scene_names: 20 | img_paths[scene] = natsorted(glob(os.path.join(data_dir, scene, '*_r.png'))) 21 | return img_paths 22 | 23 | 24 | restormer_x_res_paths = get_img_paths(restormer_x_res_dir) 25 | median_res_paths = get_img_paths(median_res_dir) 26 | 27 | wt = 0.9 28 | for scene in restormer_x_res_paths.keys(): 29 | restormer_x_res = np.array(Image.open(restormer_x_res_paths[scene][0])) / 255.0 30 | median_res = np.array(Image.open(median_res_paths[scene][0])) / 255.0 31 | 32 | ensemble_res = wt * restormer_x_res + (1. - wt) * median_res 33 | 34 | ensemble_res = (ensemble_res * 255).astype(np.uint8) 35 | 36 | save_dir = f"{save_path}/ensemble/{scene}" 37 | Path(save_dir).mkdir(parents=True, exist_ok=True) 38 | 39 | filename = os.path.basename(restormer_x_res_paths[scene][0]) 40 | Image.fromarray(ensemble_res).save(f"{save_dir}/{filename}") 41 | -------------------------------------------------------------------------------- /restormer_x/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | import torch 6 | 7 | from restormer_x.model.restormer import get_model 8 | from restormer_x.utils.log import set_logger 9 | from restormer_x.utils.trainutil import predict 10 | 11 | os.environ["CUDA_VISIBLE_DEVICES"] = '6' 12 | 13 | # CONFIG 14 | params = { 15 | # general 16 | 'save_dir': '/gt-rain/model', # Dir to save the model weights 17 | 'result_dir': '/gt-rain/result', 18 | 'method_name': 'restormer_x', 19 | 20 | # data 21 | 'val_dir_list': ['/gt-rain/GT-RAIN_val'], # Dir for the val data 22 | 'test_dir_list': ['/gt-rain/GT-RAIN_test'], # Dir for the test data 23 | 24 | # model 25 | 'model_version': 'base', 26 | 'resume_epoch': 11, # epoch of the checkpoint to load 27 | } 28 | 29 | # INIT 30 | save_path = os.path.join(params['save_dir'], params['method_name']) 31 | Path(save_path).mkdir(parents=True, exist_ok=True) 32 | set_logger(save_path, 'test.log') 33 | logging.info(str(params)) 34 | 35 | # MODEL 36 | 37 | model = get_model(model_version=params['model_version']) 38 | 39 | resume_epoch = params['resume_epoch'] 40 | resume_file = os.path.join(save_path, f'model_epoch_{resume_epoch}.pth') 41 | checkpoint = torch.load(resume_file) 42 | model.load_state_dict(checkpoint['state_dict'], strict=False) 43 | 44 | # EVALUATE OR TEST 45 | 46 |
is_test = True 47 | psnr_res = predict( 48 | model, 49 | params['test_dir_list'][0] if is_test else params['val_dir_list'][0], 50 | is_test=is_test, 51 | save_path=params['result_dir'], 52 | method_name=params['method_name'] 53 | ) 54 | logging.info(psnr_res) 55 | -------------------------------------------------------------------------------- /restormer_x/utils/mixmethod.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def rand_bbox(size, lam): 6 | H = size[2] 7 | W = size[3] 8 | 9 | cut_rat = np.sqrt(1. - lam) 10 | cut_w = int(W * cut_rat) 11 | cut_h = int(H * cut_rat) 12 | 13 | cx = np.random.randint(W) 14 | cy = np.random.randint(H) 15 | 16 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 17 | bby1 = np.clip(cy - cut_h // 2, 0, H) 18 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 19 | bby2 = np.clip(cy + cut_h // 2, 0, H) 20 | 21 | return bbx1, bby1, bbx2, bby2 22 | 23 | 24 | def mixup(input_image, target_image, alpha=1.0): 25 | """Mix two batches of paired images with a convex combination (mixup). 26 | 27 | :param alpha: parameter of the Beta distribution used to sample the mixing ratio 28 | :param input_image: [bs, c, h, w] 29 | :param target_image: [bs, c, h, w], paired with input_image 30 | :return: the mixed (input_image, target_image) pair 31 | """ 32 | image_shape = input_image.shape 33 | rand_index = torch.randperm(image_shape[0]).to(input_image.device) 34 | lam = np.random.beta(alpha, alpha) 35 | 36 | input_image = lam * input_image + (1.0 - lam) * input_image[rand_index] 37 | target_image = lam * target_image + (1.0 - lam) * target_image[rand_index] 38 | 39 | return input_image, target_image 40 | 41 | 42 | def cutmix(input_image, target_image, alpha=1.0): 43 | image_shape = input_image.shape 44 | lam = np.random.beta(alpha, alpha) 45 | bbx1, bby1, bbx2, bby2 = rand_bbox(image_shape, lam) 46 | 47 | rand_index = torch.randperm(image_shape[0]).to(input_image.device) 48 | 49 | input_image[:, :, bby1: bby2, bbx1: bbx2] = input_image[rand_index][:, :, bby1: bby2, bbx1: bbx2] 50 | target_image[:, :, bby1: bby2, bbx1: bbx2] = target_image[rand_index][:, :, bby1: bby2, bbx1: bbx2] 51 | return input_image, target_image 52 | -------------------------------------------------------------------------------- /median/median_derain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from glob import glob 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | from PIL import Image 8 | from natsort import natsorted 9 | 10 | 11 | is_train = True 12 | if is_train: 13 | data_dir = '/gt-rain/GT-RAIN_train' 14 | save_path = '/gt-rain/result' 15 | else: 16 | data_dir = '/gt-rain/GT-RAIN_test' 17 | save_path = '/gt-rain/result' 18 | 19 | 20 | def get_img_paths(data_dir, is_train=False): 21 | scene_names = [] 22 | for sc in list(os.walk(data_dir))[0][1]: 23 | scene_names.append(sc) 24 | img_paths = {} 25 | clean_img_path = {} if is_train else None 26 | for scene in scene_names: 27 | if is_train: 28 | img_paths[scene] = natsorted(glob(os.path.join(data_dir, scene, '*-R-*.png'))) 29 | clean_img_path[scene] = natsorted(glob(os.path.join(data_dir, scene, '*-C-*.png')))[0] 30 | else: 31 | img_paths[scene] = natsorted(glob(os.path.join(data_dir, scene, '*_r.png'))) 32 | return img_paths, clean_img_path 33 | 34 | 35 | img_paths, clean_img_path = get_img_paths(data_dir, is_train) 36 | 37 | for scene, scene_img_paths in img_paths.items(): 38 | 39 | img_list = [] 40 | for img_path in scene_img_paths: 41 | img = Image.open(img_path) 42 | img = np.array(img) / 255.0 43 | img_list.append(img) 44 | median_res = np.median(np.stack(img_list, axis=-1), axis=-1) 45 | median_res =
(median_res * 255).astype(np.uint8) 46 | if is_train: 47 | save_dir = f"{save_path}/train_median/{scene}" 48 | else: 49 | save_dir = f"{save_path}/test_median/{scene}" 50 | Path(save_dir).mkdir(parents=True, exist_ok=True) 51 | 52 | filename = os.path.basename(scene_img_paths[0]) 53 | Image.fromarray(median_res).save(f"{save_dir}/{filename}") 54 | if is_train: 55 | filename = os.path.basename(clean_img_path[scene]) 56 | shutil.copyfile(clean_img_path[scene], f"{save_dir}/{filename}") 57 | -------------------------------------------------------------------------------- /post_process/post_process_derain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from glob import glob 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | from PIL import Image 9 | from natsort import natsorted 10 | 11 | est_pixels_file = '/gt-rain/result/est_pixels.pkl' 12 | est_pixels = pickle.load(open(est_pixels_file, 'rb')) 13 | ensemble_res_dir = '/gt-rain/result/ensemble' 14 | save_path = '/gt-rain/result' 15 | 16 | 17 | def linear_regression(ensemble_res_dir, est_pixels, N=4, K=10, eps=1e-10): 18 | for scene, pixels_data in est_pixels.items(): 19 | x_img_file = natsorted(glob(os.path.join(ensemble_res_dir, scene, '*_r.png')))[0] 20 | x_img = np.array(Image.open(x_img_file)) / 255. 21 | 22 | wt = np.zeros(shape=[N, 3], dtype=np.float32) 23 | bias = np.zeros(shape=[N, 3], dtype=np.float32) 24 | 25 | for i in range(N): 26 | sum_x = 0. 27 | sum_y = 0. 28 | sum_xy = 0. 29 | sum_x2 = 0. 30 | 31 | sub_pixels_data = random.sample(pixels_data, K) 32 | n = len(sub_pixels_data) 33 | for pdata in sub_pixels_data: 34 | h_idx, w_idx = pdata['pos'] 35 | x = x_img[h_idx, w_idx, :].copy() 36 | y = np.array(pdata['rgb']).copy() / 255. 37 | sum_x += x 38 | sum_y += y 39 | sum_xy += x * y 40 | sum_x2 += x * x 41 | wt[i, :] = (sum_xy - sum_x * sum_y / (eps + n)) / (eps + sum_x2 - sum_x * sum_x / (eps + n)) 42 | bias[i, :] = sum_y / (eps + n) - wt[i, :] * sum_x / (eps + n) 43 | 44 | mwt = np.reshape(np.mean(wt, axis=0), (1, 1, 3)) 45 | mbias = np.reshape(np.mean(bias, axis=0), (1, 1, 3)) 46 | post_process_res = mwt * x_img.copy() + mbias 47 | post_process_res = np.clip(post_process_res, 0., 1.)
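# Clarifying note (comment added; not part of the original script): each of the
# N rounds above fits a closed-form per-channel least-squares line
# y = wt * x + bias from K randomly sampled reference pixels, and the N
# (wt, bias) estimates are then averaged -- a simple bagging step that makes
# the brightness correction less sensitive to any single draw of pixels.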
48 | post_process_res = (post_process_res * 255).astype(np.uint8) 49 | 50 | save_dir = f"{save_path}/post_process/{scene}" 51 | Path(save_dir).mkdir(parents=True, exist_ok=True) 52 | filename = os.path.basename(x_img_file) 53 | Image.fromarray(post_process_res).save(f"{save_dir}/{filename}") 54 | 55 | 56 | linear_regression(ensemble_res_dir, est_pixels) 57 | -------------------------------------------------------------------------------- /restormer_x/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from pathlib import Path 5 | 6 | import tabulate 7 | import torch 8 | import torch.nn as nn 9 | 10 | from restormer_x.dataset.gt_rain_dataset import get_datasets 11 | from restormer_x.model.restormer import get_model 12 | from restormer_x.utils.log import set_logger 13 | from restormer_x.utils.loss import ShiftMSSSIM 14 | from restormer_x.utils.trainutil import get_train_settings, train 15 | 16 | os.environ["CUDA_VISIBLE_DEVICES"] = '6' 17 | 18 | # CONFIG 19 | params = { 20 | # general 21 | 'method_name': 'restormer_x', 22 | # data 23 | 'train_dir_list': ['/gt-rain/GT-RAIN_train'], # Dir for the training data 24 | 'rain_mask_dir': '/gt-rain/Streaks_Garg06', # Dir for the rain masks 25 | 'img_size': 256, # the size of image input 26 | 'zoom_min': .06, # the minimum zoom for RainMix 27 | 'zoom_max': 1.8, # the maximum zoom for RainMix 28 | 'batch_size': 2, # batch size 29 | 30 | # model 31 | 'model_version': 'base', 32 | 'pretrained_model': '/pre-train-model/gt_rain/restormer_deraining.pth', 33 | 34 | # train 35 | 'ssim_kernel_size': 11, # img_size >= (kernel_size - 1) * 16 + 1 36 | 'initial_lr': 3e-4, # initial learning rate used by scheduler 37 | 'weight_decay': 1e-4, 38 | 'num_epochs': 20, # number of epochs to train 39 | 'warmup_epochs': 4, # number of epochs for warmup 40 | 'min_lr': 1e-6, # minimum learning rate used by scheduler 41 | 'mixmethod': 'mixup', 42 | 'mix_prob': 0.5, 43 | 'ssim_loss_weight': 0.0, # weight for the ssim loss 44 | 'acc_grad_step': 4, 45 | 'save_freq': 1, 46 | 'save_dir': '/gt-rain/model', # Dir to save the model weights 47 | } 48 | 49 | # INIT 50 | 51 | save_path = os.path.join(params['save_dir'], params['method_name']) 52 | Path(save_path).mkdir(parents=True, exist_ok=True) 53 | set_logger(save_path, 'train.log') 54 | logging.info(str(params)) 55 | 56 | # DATA 57 | 58 | train_loader = get_datasets(params) 59 | 60 | # MODEL 61 | 62 | model = get_model(model_version=params['model_version']) 63 | 64 | if params['pretrained_model'] is not None: 65 | model.load_state_dict(torch.load(params['pretrained_model'])['params'], strict=False) 66 | 67 | # LOSS 68 | 69 | criterion_l1 = nn.L1Loss().cuda() 70 | criterion_ssim = ShiftMSSSIM(ssim_kernel_size=params['ssim_kernel_size']).cuda() 71 | 72 | # TRAIN 73 | 74 | optimizer, scheduler = get_train_settings(model, params) 75 | 76 | start_epoch = 0 77 | 78 | for epoch in range(start_epoch, params['num_epochs']): 79 | time_ep = time.time() 80 | 81 | train_res = train(model, train_loader, optimizer, scheduler, criterion_l1, criterion_ssim, params) 82 | 83 | if ((epoch + 1) % params['save_freq'] == 0) or ((epoch + 1) == params['num_epochs']): 84 | torch.save( 85 | { 86 | 'epoch': epoch, 87 | 'state_dict': model.state_dict(), 88 | 'optimizer': optimizer.state_dict() 89 | }, 90 | os.path.join(save_path, f'model_epoch_{epoch}.pth') 91 | ) 92 | 93 | time_ep = time.time() - time_ep 94 | columns = ["epoch", "learning_rate", 95 | "train_loss",
"train_ssim_loss", "train_l1_loss", 96 | "cost_time"] 97 | 98 | values = [epoch + 1, optimizer.param_groups[0]['lr'], 99 | train_res["total_loss"], train_res["ssim_loss"], train_res["l1_loss"], 100 | time_ep] 101 | 102 | table = tabulate.tabulate([values], columns, tablefmt="simple", floatfmt="8.4f") 103 | if epoch % 50 == 0: 104 | table = table.split("\n") 105 | table = "\n".join([table[1]] + table) 106 | else: 107 | table = table.split("\n")[2] 108 | 109 | logging.info(table) 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Restormer-Plus for Real World Image Deraining: One State-of-the-Art Solution to the GT-RAIN Challenge (CVPR 2023 UG2+ Track 3) 2 | This is the Python code used to implement the Restormer-Plus method as described in the technical report: 3 | 4 | [**Restormer-Plus for Real World Image Deraining: One State-of-the-Art Solution to the GT-RAIN Challenge (CVPR 2023 UG2+ Track 3)** 5 | Chaochao Zheng, Luping Wang, Bin Liu](https://arxiv.org/abs/2305.05454) 6 | 7 | [//]: # (## Technical Report Link) 8 | 9 | [//]: # ([xx](xxx)) 10 | 11 | ## Abstract 12 | This technical report presents our Restormer-Plus approach, which was submitted to the GT-RAIN Challenge (CVPR 2023 UG$^2$+ Track 3). Details regarding the challenge are available at http://cvpr2023.ug2challenge.org/track3.html. Our Restormer-Plus outperformed all other submitted solutions in terms of peak signal-to-noise ratio (PSNR). It consists mainly of four modules: the single image de-raining module, the median filtering module, the weighted averaging module, and the post-processing module. We named the single-image de-raining module Restormer-X, which is built on Restormer and performed on each rainy image. The median filtering module is employed as a median operator for the 300 rainy images associated with each scene. The weighted averaging module combines the median filtering results with that of Restormer-X to alleviate overfitting if we only use Restormer-X. Finally, the post-processing module is used to improve the brightness restoration. Together, these modules render Restormer-Plus to be one state-of-the-art solution to the GT-RAIN Challenge. Our code is available at https://github.com/ZJLAB-AMMI/Restormer-Plus. 13 | 14 | ## Dataset 15 | The dataset can be found [here](https://drive.google.com/drive/folders/1NSRl954QPcGIgoyJa_VjQwh_gEaHWPb8). 16 | 17 | ## Requirements 18 | 19 | - einops==0.3.0 20 | - natsort==8.3.1 21 | - numpy==1.21.5 22 | - opencv_contrib_python==4.2.0.32 23 | - Pillow==9.2.0 24 | - piq==0.7.0 25 | - skimage==0.0 26 | - tabulate==0.8.10 27 | - torch==1.12.1 28 | - torchvision==0.13.1 29 | 30 | ## Setup 31 | Download the dataset from the link above and change the parameters in the ```train.py``` and ```test.py``` code to point to the appropriate directories (e.g., ```./gt-rain/```). 32 | 33 | Download the pre-trained de-rain model from [link](https://drive.google.com/drive/folders/1ZEDDEVW0UgkpWi-N4Lj_JUoVChGXCu_u). 34 | 35 | Install all the required packages. 36 | 37 | ## Running 38 | **restormer-x:** 39 | 40 | - training restormer baseline: set ```model_version=base``` and execute ```python /restormer_x/train.py```. 41 | 42 | - training restormer+: set ```model_version=plus``` and execute ```python /restormer_x/train.py```. 43 | 44 | - evaluate and/or test: execute ```python /restormer_x/test.py```. 
45 | 46 | **median:** execute ```python /median/median_derain.py```. 47 | 48 | **ensemble:** execute ```python /ensemble/ensemble_derain.py```. 49 | 50 | **post process:** execute ```python /post_process/post_process_derain.py```. 51 | 52 | **submit result:** execute ```python repeat300.py```. 53 | 54 | ## Citation 55 | If you find this code useful, please cite 56 | 57 | @article{zheng2023RestormerPlus, 58 | 59 | title={Restormer-Plus for Real World Image Deraining: One State-of-the-Art Solution to the GT-RAIN Challenge (CVPR 2023 UG2+ Track 3)}, 60 | 61 | author={Zheng, Chaochao and Wang, Luping and Liu, Bin}, 62 | 63 | journal={arXiv preprint arXiv:2305.05454}, 64 | 65 | year={2023} 66 | 67 | } 68 | ## Disclaimer 69 | Please only use the code and dataset for research purposes. 70 | 71 | ## Contact 72 | Chaochao Zheng
73 | Zhejiang Lab, Research Center for Applied Mathematics and Machine Intelligence
74 | zhengcc@zhejianglab.com 75 | 76 | Luping Wang
77 | Zhejiang Lab, Research Center for Applied Mathematics and Machine Intelligence
78 | wangluping@zhejianglab.com 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 68 | -------------------------------------------------------------------------------- /restormer_x/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from piq import MultiScaleSSIMLoss 5 | 6 | 7 | class ShiftMSSSIM(torch.nn.Module): 8 | """Shifted SSIM Loss """ 9 | 10 | def __init__(self, ssim_kernel_size=11): 11 | super(ShiftMSSSIM, self).__init__() 12 | self.ssim = MultiScaleSSIMLoss(kernel_size=ssim_kernel_size, data_range=1.) 
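# Note (comment added for clarity; not in the original file): piq's
# MultiScaleSSIMLoss downsamples its input 4 times by default, which is why
# train.py requires img_size >= (ssim_kernel_size - 1) * 16 + 1.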
13 | 14 | def forward(self, est, gt): 15 | # shift images back into range (0, 1) 16 | # est = est * 0.5 + 0.5 17 | # gt = gt * 0.5 + 0.5 18 | return self.ssim(est, gt) 19 | 20 | 21 | class RainRobustLoss(torch.nn.Module): 22 | """Rain Robust Loss""" 23 | 24 | def __init__(self, batch_size, n_views, device, temperature=0.07): 25 | super(RainRobustLoss, self).__init__() 26 | self.batch_size = batch_size 27 | self.n_views = n_views 28 | self.temperature = temperature 29 | self.device = device 30 | self.criterion = torch.nn.CrossEntropyLoss().to(self.device) 31 | 32 | def forward(self, features): 33 | logits, labels = self.info_nce_loss(features) 34 | return self.criterion(logits, labels) 35 | 36 | def info_nce_loss(self, features): 37 | labels = torch.cat([torch.arange(self.batch_size) for i in range(self.n_views)], dim=0) 38 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() 39 | labels = labels.to(self.device) 40 | 41 | features = F.normalize(features, dim=1) 42 | 43 | similarity_matrix = torch.matmul(features, features.T) 44 | 45 | # discard the main diagonal from both: labels and similarities matrix 46 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device) 47 | labels = labels[~mask].view(labels.shape[0], -1) 48 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) 49 | 50 | # select and combine multiple positives 51 | positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1) 52 | 53 | # select only the negatives 54 | negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1) 55 | 56 | logits = torch.cat([positives, negatives], dim=1) 57 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device) 58 | 59 | logits = logits / self.temperature 60 | return logits, labels 61 | 62 | 63 | def rain_robust_loss(params): 64 | return RainRobustLoss( 65 | batch_size=params['batch_size'], 66 | n_views=2, 67 | device=torch.device("cuda"), 68 | temperature=params['temperature'] 69 | ).cuda() 70 | 71 | 72 | class AverageMeter(object): 73 | """Computes and stores the average and current value""" 74 | 75 | def __init__(self): 76 | self.reset() 77 | 78 | def reset(self): 79 | self.val = 0 80 | self.avg = 0 81 | self.sum = 0 82 | self.count = 0 83 | 84 | def add(self, val, n=1): 85 | self.val = val 86 | self.sum += val * n 87 | self.count += n 88 | 89 | def value(self): 90 | return self.sum / self.count if self.count > 0 else 0.0 91 | 92 | 93 | class AverageAccMeter(object): 94 | 95 | def __init__(self): 96 | self.reset() 97 | 98 | def reset(self): 99 | self.val = 0 100 | self.avg = 0 101 | self.sum = 0 102 | self.count = 0 103 | 104 | def add(self, output, target): 105 | n = output.size(0) 106 | self.val = self.accuracy(output, target).item() 107 | self.sum += self.val * n 108 | self.count += n 109 | 110 | def value(self): 111 | if self.sum == 0: 112 | return 0 113 | else: 114 | return self.sum / self.count 115 | 116 | def accuracy(self, output, target, topk=(1,)): 117 | """Computes the precision@k for the specified values of k""" 118 | maxk = max(topk) 119 | batch_size = target.size(0) 120 | 121 | _, pred = output.topk(maxk, 1, True, True) 122 | pred = pred.t() 123 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 124 | 125 | res = [] 126 | for k in topk: 127 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 128 | res.append(correct_k.mul_(100.0 / batch_size)) 129 | 130 | return res[0] 131 |
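# --- Usage sketch (appended for illustration; not part of the original file) ---
# AverageMeter keeps a running weighted sum, e.g. over the batches of an epoch:
#   meter = AverageMeter()
#   meter.add(loss.item(), n=batch_size)   # after each batch
#   epoch_loss = meter.value()             # weighted average over the epoch
# RainRobustLoss expects features from n_views augmented views of each image,
# concatenated along the batch dimension ([n_views * batch_size, dim]); its
# InfoNCE labels treat the views of the same image as the positives.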
-------------------------------------------------------------------------------- /post_process/estimate_pixels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from glob import glob 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | from PIL import Image 8 | from natsort import natsorted 9 | 10 | # ==========config 11 | """ 12 | test_median_res_dir: the directory of the median result of test data, achieved by running median_derain.py 13 | train_median_res_dir: the directory of the median result of train data, achieved by running median_derain.py 14 | pixels_file: a .pkl file that contains the positions of the pixels whose values need to be estimated. 15 | Format: a dict, the key is the scene name, the value is a list of pixel positions. 16 | save_dir: where to save the similar patches. 17 | patch_size: the size of the patch. 18 | min_dis: the threshold used to select similar patches. 19 | """ 20 | test_median_res_dir = '/gt-rain/result/test_median' 21 | train_median_res_dir = '/gt-rain/result/train_median' 22 | pixels_file = '/gt-rain/result/pixels.pkl' 23 | save_dir = '/gt-rain/result/similar_patch' 24 | patch_size = 8 25 | min_dis = 6 26 | # ========== 27 | 28 | pixels = pickle.load(open(pixels_file, 'rb')) 29 | 30 | 31 | def get_img_paths(data_dir): 32 | scene_names = [] 33 | for sc in list(os.walk(data_dir))[0][1]: 34 | scene_names.append(sc) 35 | img_paths = [] 36 | for scene in scene_names: 37 | img_paths.append( 38 | ( 39 | natsorted(glob(os.path.join(data_dir, scene, '*-R-*.png')))[0], 40 | natsorted(glob(os.path.join(data_dir, scene, '*-C-*.png')))[0] 41 | ) 42 | ) 43 | return img_paths 44 | 45 | 46 | for scene_name, pixels_pos in pixels.items(): 47 | test_median_res = np.asarray(Image.open(os.path.join(test_median_res_dir, scene_name, '1_r.png'))) 48 | hts, wts, cts = test_median_res.shape 49 | train_img_paths = get_img_paths(train_median_res_dir) 50 | 51 | for pixel_pos in pixels_pos: 52 | save_path = f"{save_dir}/{scene_name}/{str(pixel_pos[0]) + '_' + str(pixel_pos[1])}" 53 | Path(save_path).mkdir(parents=True, exist_ok=True) 54 | 55 | # test patch 56 | hts1 = np.clip(pixel_pos[0] - patch_size // 2, 0, hts) 57 | hts2 = np.clip(pixel_pos[0] + patch_size // 2, 0, hts) 58 | wts1 = np.clip(pixel_pos[1] - patch_size // 2, 0, wts) 59 | wts2 = np.clip(pixel_pos[1] + patch_size // 2, 0, wts) 60 | test_patch = test_median_res[hts1: hts2, wts1: wts2, :] 61 | Image.fromarray(test_patch).save(f"{save_path}/test_patch.png") 62 | 63 | h_patch_size = hts2 - hts1 64 | w_patch_size = wts2 - wts1 65 | 66 | # search and save similar patch in train data 67 | for train_median_res_file, train_clean_file in train_img_paths: 68 | train_median_res = np.asarray(Image.open(train_median_res_file)) 69 | train_clean = np.asarray(Image.open(train_clean_file)) 70 | htr, wtr, ctr = train_median_res.shape 71 | h_gap = (htr - h_patch_size) // 30 72 | w_gap = (wtr - w_patch_size) // 30 73 | for h_idx in range(0, htr - h_patch_size, h_gap): 74 | for w_idx in range(0, wtr - w_patch_size, w_gap): 75 | train_median_patch = train_median_res[h_idx: (h_idx + h_patch_size), w_idx: (w_idx + w_patch_size), 76 | :] 77 | train_clean_patch = train_clean[h_idx: (h_idx + h_patch_size), w_idx: (w_idx + w_patch_size), :] 78 | 79 | distance = np.median( 80 | np.abs(test_patch.astype(np.int16).flatten() - train_median_patch.astype(np.int16).flatten())  # int cast avoids uint8 wrap-around 81 | ) 82 | 83 | if distance <= min_dis: 84 | pred_val = np.mean(train_clean_patch, axis=(0, 1)) 85 |
Image.fromarray(train_median_patch).save( 86 | os.path.join(save_path, 87 | 'train_median_patch_{}_{}_{}.png'.format(h_idx, w_idx, np.round(distance, 3)))) 88 | Image.fromarray(train_clean_patch).save(os.path.join(save_path, 89 | 'train_clean_patch_{}_{}_{}.png'.format( 90 | h_idx, 91 | w_idx, 92 | np.round(pred_val, 3)))) 93 | -------------------------------------------------------------------------------- /restormer_x/utils/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageChops, ImageOps, ImageEnhance 3 | 4 | 5 | def sample_level(n): 6 | return np.random.uniform(low=0.1, high=n) 7 | 8 | 9 | def int_parameter(level, maxval): 10 | """Helper function to scale `val` between 0 and maxval . 11 | 12 | Args: 13 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 14 | maxval: Maximum value that the operation can have. This will be scaled to 15 | level/PARAMETER_MAX. 16 | 17 | Returns: 18 | An int that results from scaling `maxval` according to `level`. 19 | """ 20 | return int(level * maxval / 10) 21 | 22 | 23 | def float_parameter(level, maxval): 24 | """Helper function to scale `val` between 0 and maxval. 25 | 26 | Args: 27 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 28 | maxval: Maximum value that the operation can have. This will be scaled to 29 | level/PARAMETER_MAX. 30 | 31 | Returns: 32 | A float that results from scaling `maxval` according to `level`. 33 | """ 34 | return float(level) * maxval / 10. 35 | 36 | 37 | def autocontrast(pil_img, _): 38 | return ImageOps.autocontrast(pil_img) 39 | 40 | 41 | def equalize(pil_img, _): 42 | return ImageOps.equalize(pil_img) 43 | 44 | 45 | def posterize(pil_img, level): 46 | level = int_parameter(sample_level(level), 4) 47 | return ImageOps.posterize(pil_img, 4 - level) 48 | 49 | 50 | def rotate(pil_img, level): 51 | degrees = int_parameter(sample_level(level), 30) 52 | if np.random.uniform() > 0.5: 53 | degrees = -degrees 54 | return pil_img.rotate(degrees, resample=Image.BILINEAR) 55 | 56 | 57 | def solarize(pil_img, level): 58 | level = int_parameter(sample_level(level), 256) 59 | return ImageOps.solarize(pil_img, 256 - level) 60 | 61 | 62 | def shear_x(pil_img, level): 63 | level = float_parameter(sample_level(level), 0.3) 64 | if np.random.uniform() > 0.5: 65 | level = -level 66 | return pil_img.transform( 67 | (pil_img.width, pil_img.height), 68 | Image.AFFINE, (1, level, 0, 0, 1, 0), 69 | resample=Image.BILINEAR) 70 | 71 | 72 | def shear_y(pil_img, level): 73 | level = float_parameter(sample_level(level), 0.3) 74 | if np.random.uniform() > 0.5: 75 | level = -level 76 | return pil_img.transform( 77 | (pil_img.width, pil_img.height), 78 | Image.AFFINE, (1, 0, 0, level, 1, 0), 79 | resample=Image.BILINEAR) 80 | 81 | 82 | def roll_x(pil_img, level): 83 | """Roll an image sideways.""" 84 | delta = int_parameter(sample_level(level), pil_img.width / 3) 85 | if np.random.random() > 0.5: 86 | delta = -delta 87 | xsize, ysize = pil_img.size 88 | delta = delta % xsize 89 | if delta == 0: return pil_img 90 | part1 = pil_img.crop((0, 0, delta, ysize)) 91 | part2 = pil_img.crop((delta, 0, xsize, ysize)) 92 | pil_img.paste(part1, (xsize - delta, 0, xsize, ysize)) 93 | pil_img.paste(part2, (0, 0, xsize - delta, ysize)) 94 | 95 | return pil_img 96 | 97 | 98 | def roll_y(pil_img, level): 99 | """Roll an image sideways.""" 100 | delta = int_parameter(sample_level(level), pil_img.width / 3) 101 | if 
np.random.random() > 0.5: 102 | delta = -delta 103 | xsize, ysize = pil_img.size 104 | delta = delta % ysize 105 | if delta == 0: return pil_img 106 | part1 = pil_img.crop((0, 0, xsize, delta)) 107 | part2 = pil_img.crop((0, delta, xsize, ysize)) 108 | pil_img.paste(part1, (0, ysize - delta, xsize, ysize)) 109 | pil_img.paste(part2, (0, 0, xsize, ysize - delta)) 110 | 111 | return pil_img 112 | 113 | 114 | # operation that overlaps with ImageNet-C's test set 115 | def color(pil_img, level): 116 | level = float_parameter(sample_level(level), 1.8) + 0.1 117 | return ImageEnhance.Color(pil_img).enhance(level) 118 | 119 | 120 | # operation that overlaps with ImageNet-C's test set 121 | def contrast(pil_img, level): 122 | level = float_parameter(sample_level(level), 1.8) + 0.1 123 | return ImageEnhance.Contrast(pil_img).enhance(level) 124 | 125 | 126 | # operation that overlaps with ImageNet-C's test set 127 | def brightness(pil_img, level): 128 | level = float_parameter(sample_level(level), 1.8) + 0.1 129 | return ImageEnhance.Brightness(pil_img).enhance(level) 130 | 131 | 132 | # operation that overlaps with ImageNet-C's test set 133 | def sharpness(pil_img, level): 134 | level = float_parameter(sample_level(level), 1.8) + 0.1 135 | return ImageEnhance.Sharpness(pil_img).enhance(level) 136 | 137 | 138 | def zoom_x(pil_img, level): 139 | # zoom from .02 to 2.5 140 | rate = level 141 | zoom_img = pil_img.transform( 142 | (pil_img.width, pil_img.height), 143 | Image.AFFINE, (rate, 0, 0, 0, 1, 0), 144 | resample=Image.BILINEAR) 145 | # need to do reflect padding 146 | if rate > 1.0: 147 | orig_x, orig_y = pil_img.size 148 | new_x = int(orig_x / rate) 149 | zoom_img = np.array(zoom_img) 150 | zoom_img = np.pad(zoom_img[:, :new_x, :], ((0, 0), (0, orig_x - new_x), (0, 0)), 'wrap') 151 | return zoom_img 152 | 153 | 154 | def zoom_y(pil_img, level): 155 | # zoom from .02 to 2.5 156 | rate = level 157 | zoom_img = pil_img.transform( 158 | (pil_img.width, pil_img.height), 159 | Image.AFFINE, (1, 0, 0, 0, rate, 0), 160 | resample=Image.BILINEAR) 161 | # need to do reflect padding 162 | if rate > 1.0: 163 | orig_x, orig_y = pil_img.size 164 | new_y = int(orig_y / rate) 165 | zoom_img = np.array(zoom_img) 166 | zoom_img = np.pad(zoom_img[:new_y, :, :], ((0, orig_y - new_y), (0, 0), (0, 0)), 'wrap') 167 | return zoom_img 168 | 169 | 170 | augmentations = [ 171 | rotate, shear_x, shear_y, 172 | zoom_x, zoom_y, roll_x, roll_y 173 | ] 174 | 175 | 176 | -------------------------------------------------------------------------------- /restormer_x/utils/trainutil.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import torchvision.transforms.functional as TF 10 | from PIL import Image 11 | from natsort import natsorted 12 | from skimage.metrics import peak_signal_noise_ratio as psnr 13 | from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau 14 | 15 | from restormer_x.utils.mixmethod import mixup 16 | from restormer_x.utils.loss import AverageMeter 17 | 18 | 19 | class GradualWarmupScheduler(_LRScheduler): 20 | """ Gradually warm-up(increasing) learning rate in optimizer. 21 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 22 | Args: 23 | optimizer (Optimizer): Wrapped optimizer. 
24 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0; if multiplier = 1.0, lr starts from 0 and ends up at base_lr. 25 | total_epoch: the target learning rate is reached gradually at total_epoch 26 | after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau) 27 | """ 28 | 29 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 30 | self.multiplier = multiplier 31 | if self.multiplier < 1.: 32 | raise ValueError('multiplier should be greater than or equal to 1.') 33 | self.total_epoch = total_epoch 34 | self.after_scheduler = after_scheduler 35 | self.finished = False 36 | super(GradualWarmupScheduler, self).__init__(optimizer) 37 | 38 | def get_lr(self): 39 | if self.last_epoch > self.total_epoch: 40 | if self.after_scheduler: 41 | if not self.finished: 42 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 43 | self.finished = True 44 | return self.after_scheduler.get_last_lr() 45 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 46 | 47 | if self.multiplier == 1.0: 48 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 49 | else: 50 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in 51 | self.base_lrs] 52 | 53 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 54 | if epoch is None: 55 | epoch = self.last_epoch + 1 56 | self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 57 | if self.last_epoch <= self.total_epoch: 58 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in 59 | self.base_lrs] 60 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 61 | param_group['lr'] = lr 62 | else: 63 | if epoch is None: 64 | self.after_scheduler.step(metrics, None) 65 | else: 66 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 67 | 68 | def step(self, epoch=None, metrics=None): 69 | if type(self.after_scheduler) != ReduceLROnPlateau: 70 | if self.finished and self.after_scheduler: 71 | if epoch is None: 72 | self.after_scheduler.step(None) 73 | else: 74 | self.after_scheduler.step(epoch - self.total_epoch) 75 | self._last_lr = self.after_scheduler.get_last_lr() 76 | else: 77 | return super(GradualWarmupScheduler, self).step(epoch) 78 | else: 79 | self.step_ReduceLROnPlateau(metrics, epoch) 80 | 81 | 82 | def get_train_settings(model, params): 83 | optimizer = optim.AdamW( 84 | model.parameters(), 85 | lr=params['initial_lr'], 86 | weight_decay=params['weight_decay'] 87 | ) 88 | 89 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR( 90 | optimizer, 91 | params['num_epochs'] - params['warmup_epochs'], 92 | eta_min=params['min_lr']) 93 | 94 | scheduler = GradualWarmupScheduler( 95 | optimizer, 96 | multiplier=1.0, 97 | total_epoch=params['warmup_epochs'], 98 | after_scheduler=scheduler_cosine 99 | ) 100 | 101 | optimizer.zero_grad() 102 | optimizer.step() 103 | scheduler.step()  # To start warmup 104 | 105 | return optimizer, scheduler 106 | 107 | 108 | def train(model, train_loader, optimizer, scheduler, criterion_l1, criterion_ssim, params): 109 | model.train() 110 | 111 | total_losses = AverageMeter() 112 | l1_losses = AverageMeter() 113 | ssim_losses = AverageMeter() 114 | num_batchs = len(train_loader.dataset) // params['batch_size'] 115 | for batch_idx, batch_data in enumerate(train_loader): 116 |
input_img = batch_data['input_img'].cuda() 117 | target_img = batch_data['target_img'].cuda() 118 | 119 | if (params['mixmethod'] == 'mixup') and (np.random.rand(1) <= params['mix_prob']): 120 | input_img, target_img = mixup(input_img, target_img) 121 | 122 | output_img = model(input_img) 123 | 124 | l1_loss = criterion_l1(output_img, target_img) 125 | loss = l1_loss 126 | l1_losses.add(l1_loss.item(), input_img.size(0)) 127 | 128 | if params['ssim_loss_weight'] > 0: 129 | ssim_loss = criterion_ssim(output_img.clip(0., 1.), target_img) 130 | loss += params['ssim_loss_weight'] * ssim_loss 131 | ssim_losses.add(ssim_loss.item(), input_img.size(0)) 132 | 133 | total_losses.add(loss.item(), input_img.size(0)) 134 | 135 | acc_grad_step = params['acc_grad_step'] 136 | loss = loss / acc_grad_step 137 | loss.backward() 138 | 139 | if (((batch_idx + 1) % acc_grad_step) == 0) or ((batch_idx + 1) == num_batchs): 140 | optimizer.step() 141 | optimizer.zero_grad() 142 | 143 | scheduler.step() 144 | 145 | return { 146 | 'total_loss': total_losses.value(), 147 | 'ssim_loss': ssim_losses.value(), 148 | 'l1_loss': l1_losses.value() 149 | } 150 | 151 | 152 | def predict(model, root_dir, is_test=False, eta=8, save_path=None, method_name=None): 153 | model.eval() 154 | scene_names = [] 155 | for sc in list(os.walk(root_dir))[0][1]: 156 | scene_names.append(sc) 157 | 158 | img_paths = {} 159 | for scene in scene_names: 160 | scene_path = os.path.join(root_dir, scene) 161 | if is_test: 162 | scene_img_paths = natsorted(glob(os.path.join(scene_path, '*_r.png'))) 163 | else: 164 | scene_img_paths = natsorted(glob(os.path.join(scene_path, '*R-*.png'))) 165 | img_paths[scene] = scene_img_paths 166 | 167 | mean_output = {} 168 | with torch.no_grad(): 169 | for scene_name, im_paths in img_paths.items(): 170 | print(scene_name) 171 | if scene_name not in mean_output: 172 | mean_output[scene_name] = {'sum_im': 0.0, 'num_im': 0} 173 | for im_path in im_paths: 174 | img = Image.open(im_path) 175 | img = np.array(img) 176 | img = TF.to_tensor(img)  # [c, h, w] 177 | h, w = img.shape[1:] 178 | padw = eta - (w % eta) if (w % eta) != 0 else 0 179 | padh = eta - (h % eta) if (h % eta) != 0 else 0 180 | if padw != 0 or padh != 0: 181 | img = F.pad(img, (0, padw, 0, padh), mode='reflect') 182 | 183 | input = torch.unsqueeze(img, 0).cuda() 184 | output = model(input) 185 | output = output.squeeze().permute((1, 2, 0)) 186 | output = output.detach().cpu().numpy()[:h, :w, :] 187 | 188 | mean_output[scene_name]['sum_im'] += output 189 | mean_output[scene_name]['num_im'] += 1 190 | 191 | psnr_res = {'scene_psnr': {}, 'psnr': [0.0]} 192 | for scene_name, res in mean_output.items(): 193 | output = res['sum_im'] / res['num_im'] 194 | output = np.clip(output, 0.0, 1.0) 195 | if not is_test: 196 | tmp = img_paths[scene_name][0] 197 | tar_path = tmp[:-9] + 'C-000.png' 198 | if 'Gurutto_1-2' in tmp: 199 | tar_path = tmp[:-9] + 'C' + tmp[-8:] 200 | tar_img = Image.open(tar_path) 201 | tar_img = np.array(tar_img, dtype=np.float32) 202 | tar_img = tar_img / 255  # [h, w, c] 203 | 204 | psnr_val = psnr(tar_img, output) 205 | psnr_res['scene_psnr'][scene_name] = psnr_val 206 | psnr_res['psnr'][0] += psnr_val 207 | else: 208 | save_dir = f"{save_path}/{method_name}/test/{scene_name}" 209 | Path(save_dir).mkdir(parents=True, exist_ok=True) 210 | output = (output * 255).astype(np.uint8) 211 | filename = img_paths[scene_name][0].split('/')[-1] 212 | Image.fromarray(output).save(f"{save_dir}/{filename}") 213 | psnr_res['psnr'][0] /=
len(mean_output.keys()) 214 | return psnr_res 215 | -------------------------------------------------------------------------------- /restormer_x/model/restormer.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch import nn 7 | 8 | 9 | class OverlapPatchEmbed(nn.Module): 10 | def __init__(self, in_c=3, embed_dim=48, bias=False): 11 | super(OverlapPatchEmbed, self).__init__() 12 | 13 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) 14 | 15 | def forward(self, x): 16 | x = self.proj(x) 17 | 18 | return x 19 | 20 | 21 | class BiasFree_LayerNorm(nn.Module): 22 | def __init__(self, normalized_shape): 23 | super(BiasFree_LayerNorm, self).__init__() 24 | if isinstance(normalized_shape, numbers.Integral): 25 | normalized_shape = (normalized_shape,) 26 | normalized_shape = torch.Size(normalized_shape) 27 | 28 | assert len(normalized_shape) == 1 29 | 30 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 31 | self.normalized_shape = normalized_shape 32 | 33 | def forward(self, x): 34 | sigma = x.var(-1, keepdim=True, unbiased=False) 35 | return x / torch.sqrt(sigma + 1e-5) * self.weight 36 | 37 | 38 | class WithBias_LayerNorm(nn.Module): 39 | def __init__(self, normalized_shape): 40 | super(WithBias_LayerNorm, self).__init__() 41 | if isinstance(normalized_shape, numbers.Integral): 42 | normalized_shape = (normalized_shape,) 43 | normalized_shape = torch.Size(normalized_shape) 44 | 45 | assert len(normalized_shape) == 1 46 | 47 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 48 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 49 | self.normalized_shape = normalized_shape 50 | 51 | def forward(self, x): 52 | mu = x.mean(-1, keepdim=True) 53 | sigma = x.var(-1, keepdim=True, unbiased=False) 54 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 55 | 56 | 57 | def to_3d(x): 58 | return rearrange(x, 'b c h w -> b (h w) c') 59 | 60 | 61 | def to_4d(x, h, w): 62 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 63 | 64 | 65 | class LayerNorm(nn.Module): 66 | def __init__(self, dim, LayerNorm_type): 67 | super(LayerNorm, self).__init__() 68 | if LayerNorm_type == 'BiasFree': 69 | self.body = BiasFree_LayerNorm(dim) 70 | else: 71 | self.body = WithBias_LayerNorm(dim) 72 | 73 | def forward(self, x): 74 | h, w = x.shape[-2:] 75 | return to_4d(self.body(to_3d(x)), h, w) 76 | 77 | 78 | class Attention(nn.Module): 79 | def __init__(self, dim, num_heads, bias): 80 | super(Attention, self).__init__() 81 | self.num_heads = num_heads 82 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 83 | 84 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) 85 | self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias) 86 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | 91 | qkv = self.qkv_dwconv(self.qkv(x)) 92 | q, k, v = qkv.chunk(3, dim=1) 93 | 94 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 95 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 96 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 97 | 98 | q = torch.nn.functional.normalize(q, dim=-1) 99 | k = torch.nn.functional.normalize(k, dim=-1) 100 | 101 | attn = (q @ k.transpose(-2, -1)) * 
self.temperature 102 | attn = attn.softmax(dim=-1) 103 | 104 | out = (attn @ v) 105 | 106 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 107 | 108 | out = self.project_out(out) 109 | return out 110 | 111 | 112 | class FeedForward(nn.Module): 113 | def __init__(self, dim, ffn_expansion_factor, bias): 114 | super(FeedForward, self).__init__() 115 | 116 | hidden_features = int(dim * ffn_expansion_factor) 117 | 118 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) 119 | 120 | self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, 121 | groups=hidden_features * 2, bias=bias) 122 | 123 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 124 | 125 | def forward(self, x): 126 | x = self.project_in(x) 127 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 128 | x = F.gelu(x1) * x2 129 | x = self.project_out(x) 130 | return x 131 | 132 | 133 | class TransformerBlock(nn.Module): 134 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): 135 | super(TransformerBlock, self).__init__() 136 | 137 | self.norm1 = LayerNorm(dim, LayerNorm_type) 138 | self.attn = Attention(dim, num_heads, bias) 139 | self.norm2 = LayerNorm(dim, LayerNorm_type) 140 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 141 | 142 | def forward(self, x): 143 | x = x + self.attn(self.norm1(x)) 144 | x = x + self.ffn(self.norm2(x)) 145 | 146 | return x 147 | 148 | 149 | class Downsample(nn.Module): 150 | def __init__(self, n_feat): 151 | super(Downsample, self).__init__() 152 | 153 | self.body = nn.Sequential( 154 | nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), 155 | nn.PixelUnshuffle(2) 156 | ) 157 | 158 | def forward(self, x): 159 | return self.body(x) 160 | 161 | 162 | class Upsample(nn.Module): 163 | def __init__(self, n_feat): 164 | super(Upsample, self).__init__() 165 | 166 | self.body = nn.Sequential( 167 | nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), 168 | nn.PixelShuffle(2) 169 | ) 170 | 171 | def forward(self, x): 172 | return self.body(x) 173 | 174 | 175 | class Restormer(nn.Module): 176 | def __init__( 177 | self, 178 | inp_channels=3, 179 | out_channels=3, 180 | dim=48, 181 | num_blocks=[4, 6, 6, 8], 182 | num_refinement_blocks=4, 183 | heads=[1, 2, 4, 8], 184 | ffn_expansion_factor=2.66, 185 | bias=False, 186 | LayerNorm_type='WithBias', 187 | version='base' # base or plus 188 | ): 189 | super(Restormer, self).__init__() 190 | 191 | self.patch_embed = OverlapPatchEmbed(inp_channels, dim) 192 | 193 | self.encoder_level1 = nn.Sequential(*[ 194 | TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, 195 | LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) 196 | 197 | self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 198 | self.encoder_level2 = nn.Sequential(*[ 199 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, 200 | bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) 201 | 202 | self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3 203 | self.encoder_level3 = nn.Sequential(*[ 204 | TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, 205 | bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) 206 | 207 | self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From 
208 |         self.latent = nn.Sequential(*[
209 |             TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor,
210 |                              bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
211 | 
212 |         self.up4_3 = Upsample(int(dim * 2 ** 3))  ## From Level 4 to Level 3
213 |         self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
214 |         self.decoder_level3 = nn.Sequential(*[
215 |             TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
216 |                              bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
217 | 
218 |         self.up3_2 = Upsample(int(dim * 2 ** 2))  ## From Level 3 to Level 2
219 |         self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
220 |         self.decoder_level2 = nn.Sequential(*[
221 |             TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
222 |                              bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
223 | 
224 |         self.up2_1 = Upsample(int(dim * 2 ** 1))  ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
225 | 
226 |         self.decoder_level1 = nn.Sequential(*[
227 |             TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
228 |                              bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
229 | 
230 |         self.refinement = nn.Sequential(*[
231 |             TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
232 |                              bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
233 |         self.output_wt = None
234 |         if version == 'plus':
235 |             self.output_wt = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
236 | 
237 |         self.output_bias = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
238 | 
239 |         self.version = version
240 | 
241 |     def forward(self, inp_img):
242 |         inp_enc_level1 = self.patch_embed(inp_img)
243 |         out_enc_level1 = self.encoder_level1(inp_enc_level1)
244 | 
245 |         inp_enc_level2 = self.down1_2(out_enc_level1)
246 |         out_enc_level2 = self.encoder_level2(inp_enc_level2)
247 | 
248 |         inp_enc_level3 = self.down2_3(out_enc_level2)
249 |         out_enc_level3 = self.encoder_level3(inp_enc_level3)
250 | 
251 |         inp_enc_level4 = self.down3_4(out_enc_level3)
252 |         latent = self.latent(inp_enc_level4)
253 | 
254 |         inp_dec_level3 = self.up4_3(latent)
255 |         inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
256 |         inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
257 |         out_dec_level3 = self.decoder_level3(inp_dec_level3)
258 | 
259 |         inp_dec_level2 = self.up3_2(out_dec_level3)
260 |         inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
261 |         inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
262 |         out_dec_level2 = self.decoder_level2(inp_dec_level2)
263 | 
264 |         inp_dec_level1 = self.up2_1(out_dec_level2)
265 |         inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
266 |         out_dec_level1 = self.decoder_level1(inp_dec_level1)
267 | 
268 |         out_dec_level1 = self.refinement(out_dec_level1)
269 | 
270 |         if self.version == 'plus' and self.output_wt is not None:
271 |             out_dec_level1 = self.output_wt(out_dec_level1) * inp_img + self.output_bias(out_dec_level1)
272 |         else:
273 |             out_dec_level1 = inp_img + self.output_bias(out_dec_level1)
274 | 
275 |         return out_dec_level1
276 | 
277 | 
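The two output heads above differ only in the last step: the base model adds a predicted residual, out = inp_img + output_bias(f), while the 'plus' model also predicts a per-pixel weight map, out = output_wt(f) * inp_img + output_bias(f). A minimal sketch (not part of the repo) comparing the two heads on a dummy input; it assumes only this file's classes and runs on CPU:

# Sketch, not repo code: both heads return an image-shaped tensor.
# Input height/width must be divisible by 8 because of the three Downsample stages.
if __name__ == '__main__':
    x = torch.randn(1, 3, 64, 64)                  # dummy rainy image, NCHW
    with torch.no_grad():
        base_out = Restormer(version='base')(x)    # x + B(f): additive residual only
        plus_out = Restormer(version='plus')(x)    # W(f) * x + B(f): learned per-pixel gain
    print(base_out.shape, plus_out.shape)          # both torch.Size([1, 3, 64, 64])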
278 | def get_model(model_version='base'):
279 |     model = Restormer(version=model_version)
280 |     model.cuda()  # assumes a CUDA device is available
281 |     return model
282 | 
--------------------------------------------------------------------------------
/restormer_x/dataset/gt_rain_dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import os
4 | import random
5 | from glob import glob
6 | 
7 | import cv2
8 | import numpy as np
9 | import torch.nn.functional as F
10 | import torchvision.transforms.functional as TF
11 | from PIL import Image
12 | from natsort import natsorted
13 | from torch.utils.data import Dataset, DataLoader
14 | 
15 | from restormer_x.utils.data_augmentation import augmentations, zoom_x, zoom_y
16 | 
17 | 
18 | def getRainLayer2(rand_id1, rand_id2, rain_mask_dir):
19 |     path_img_rainlayer_src = os.path.join(rain_mask_dir, f'{rand_id1}-{rand_id2}.png')
20 |     rainlayer_rand = cv2.imread(path_img_rainlayer_src).astype(np.float32) / 255.0
21 |     rainlayer_rand = cv2.cvtColor(rainlayer_rand, cv2.COLOR_BGR2RGB)
22 |     return rainlayer_rand
23 | 
24 | 
25 | def getRandRainLayer2(rain_mask_dir):
26 |     rand_id1 = random.randint(1, 165)
27 |     rand_id2 = random.randint(4, 8)
28 |     rainlayer_rand = getRainLayer2(rand_id1, rand_id2, rain_mask_dir)
29 |     return rainlayer_rand
30 | 
31 | 
32 | def apply_op(image, op, severity):
33 |     image = np.clip(image * 255., 0, 255).astype(np.uint8)
34 |     pil_img = Image.fromarray(image)  # Convert to PIL.Image
35 |     pil_img = op(pil_img, severity)
36 |     return np.asarray(pil_img) / 255.
37 | 
38 | 
39 | def augment_and_mix(image, severity=3, width=3, depth=-1, alpha=1., zoom_min=0.06, zoom_max=1.8):
40 |     """Perform AugMix augmentations and compute mixture.
41 |     Args:
42 |         image: Raw input image as float32 np.ndarray of shape (h, w, c)
43 |         severity: Severity of underlying augmentation operators (between 1 and 10).
44 |         width: Width of augmentation chain
45 |         depth: Depth of augmentation chain. -1 enables stochastic depth sampled
46 |             uniformly from {2, 3}
47 |         alpha: Probability coefficient for Beta and Dirichlet distributions.
48 |     Returns:
49 |         mixed: Augmented and mixed image.
50 |     """
51 |     ws = np.float32(
52 |         np.random.dirichlet([alpha] * width))
53 |     m = np.float32(np.random.beta(alpha, alpha))
54 | 
55 |     mix = np.zeros_like(image)
56 |     for i in range(width):
57 |         image_aug = image.copy()
58 |         depth = depth if depth > 0 else np.random.randint(2, 4)
59 |         for _ in range(depth):
60 |             op = np.random.choice(augmentations)
61 |             if op == zoom_x or op == zoom_y:
62 |                 rate = np.random.uniform(low=zoom_min, high=zoom_max)
63 |                 image_aug = apply_op(image_aug, op, rate)
64 |             else:
65 |                 image_aug = apply_op(image_aug, op, severity)
66 |         # Preprocessing commutes since all coefficients are convex
67 |         mix += ws[i] * image_aug
68 | 
69 |     max_ws = max(ws)
70 |     rate = 1.0 / max_ws
71 | 
72 |     mixed = max((1 - m), 0.7) * image + max(m, rate * 0.5) * mix
73 |     return mixed
74 | 
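The final blend above departs from standard AugMix: the original layer always keeps weight at least 0.7, and the augmented mix keeps at least max(m, 0.5 / max(ws)), so the two coefficients need not sum to 1 (rain_aug below clips the result back into [0, 1]). A sketch (not part of the repo) of that coefficient arithmetic in isolation, using the function's defaults alpha=1, width=3:

# Sketch, not repo code: the blend coefficients used by augment_and_mix.
import numpy as np

ws = np.float32(np.random.dirichlet([1.0] * 3))  # per-chain weights, sum to 1
m = np.float32(np.random.beta(1.0, 1.0))         # image-vs-mix blend factor

w_image = max(1 - m, 0.7)       # original layer always keeps weight >= 0.7
w_mix = max(m, 0.5 / ws.max())  # augmented mix is never fully suppressed
print(w_image, w_mix)           # may sum to > 1; rain_aug clips to [0, 1] later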
75 | 
76 | class RandomCrop(object):
77 |     def __init__(self, image_size, crop_size):
78 |         self.ch, self.cw = crop_size
79 |         ih, iw = image_size
80 | 
81 |         self.h1 = random.randint(0, ih - self.ch)
82 |         self.w1 = random.randint(0, iw - self.cw)
83 | 
84 |         self.h2 = self.h1 + self.ch
85 |         self.w2 = self.w1 + self.cw
86 | 
87 |     def __call__(self, img):
88 |         if len(img.shape) == 3:
89 |             return img[self.h1: self.h2, self.w1: self.w2, :]
90 |         else:
91 |             return img[self.h1: self.h2, self.w1: self.w2]
92 | 
93 | 
94 | def rain_aug(img_rainy, img_gt, rain_mask_dir, zoom_min=0.06, zoom_max=1.8):
95 |     img_rainy = (img_rainy.astype(np.float32)) / 255.0
96 |     img_gt = (img_gt.astype(np.float32)) / 255.0
97 |     img_rainy_ret = img_rainy
98 |     img_gt_ret = img_gt
99 | 
100 |     rainlayer_rand2 = getRandRainLayer2(rain_mask_dir)
101 |     rainlayer_aug2 = augment_and_mix(
102 |         rainlayer_rand2,
103 |         severity=3,
104 |         width=3,
105 |         depth=-1,
106 |         zoom_min=zoom_min,
107 |         zoom_max=zoom_max
108 |     ) * 1
109 | 
110 |     height = min(img_rainy.shape[0], rainlayer_aug2.shape[0])
111 |     width = min(img_rainy.shape[1], rainlayer_aug2.shape[1])
112 | 
113 |     cropper = RandomCrop(rainlayer_aug2.shape[:2], (height, width))
114 |     rainlayer_aug2_crop = cropper(rainlayer_aug2)
115 |     cropper = RandomCrop(img_rainy.shape[:2], (height, width))
116 |     img_rainy_ret = cropper(img_rainy_ret)
117 |     img_gt_ret = cropper(img_gt_ret)
118 |     img_rainy_ret = img_rainy_ret + rainlayer_aug2_crop - img_rainy_ret * rainlayer_aug2_crop
119 |     img_rainy_ret = np.clip(img_rainy_ret, 0.0, 1.0)
120 |     img_rainy_ret = (img_rainy_ret * 255).astype(np.uint8)
121 |     img_gt_ret = (img_gt_ret * 255).astype(np.uint8)
122 | 
123 |     return img_rainy_ret, img_gt_ret
124 | 
125 | 
126 | def get_translation_matrix_2d(dx, dy):
127 |     """
128 |     Returns a numpy affine transformation matrix for a 2D translation of
129 |     (dx, dy)
130 |     """
131 |     return np.matrix([[1, 0, dx], [0, 1, dy], [0, 0, 1]])
132 | 
133 | 
134 | def rotate_image(image, angle):
135 |     """
136 |     Rotates the given image about its centre
137 |     """
138 | 
139 |     image_size = (image.shape[1], image.shape[0])
140 |     image_center = tuple(np.array(image_size) / 2)
141 | 
142 |     rot_mat = np.vstack([cv2.getRotationMatrix2D(image_center, angle, 1.0), [0, 0, 1]])
143 |     trans_mat = np.identity(3)
144 | 
145 |     w2 = image_size[0] * 0.5
146 |     h2 = image_size[1] * 0.5
147 | 
148 |     rot_mat_notranslate = np.matrix(rot_mat[0:2, 0:2])
149 | 
150 |     tl = (np.array([-w2, h2]) * rot_mat_notranslate).A[0]
151 |     tr = (np.array([w2, h2]) * rot_mat_notranslate).A[0]
152 |     bl = (np.array([-w2, -h2]) * rot_mat_notranslate).A[0]
153 |     br = (np.array([w2, -h2]) * rot_mat_notranslate).A[0]
154 | 
155 |     x_coords = [pt[0] for pt in [tl, tr, bl, br]]
156 |     x_pos = [x for x in x_coords if x > 0]
157 |     x_neg = [x for x in x_coords if x < 0]
158 | 
159 |     y_coords = [pt[1] for pt in [tl, tr, bl, br]]
160 |     y_pos = [y for y in y_coords if y > 0]
161 |     y_neg = [y for y in y_coords if y < 0]
162 | 
163 |     right_bound = max(x_pos)
164 |     left_bound = min(x_neg)
165 |     top_bound = max(y_pos)
166 |     bot_bound = min(y_neg)
167 | 
168 |     new_w = int(abs(right_bound - left_bound))
169 |     new_h = int(abs(top_bound - bot_bound))
170 |     new_image_size = (new_w, new_h)
171 | 
172 |     new_midx = new_w * 0.5
173 |     new_midy = new_h * 0.5
174 | 
175 |     dx = int(new_midx - w2)
176 |     dy = int(new_midy - h2)
177 | 
178 |     trans_mat = get_translation_matrix_2d(dx, dy)
179 |     affine_mat = (np.matrix(trans_mat) * np.matrix(rot_mat))[0:2, :]
180 |     result = cv2.warpAffine(image, affine_mat, new_image_size, flags=cv2.INTER_LINEAR)
181 | 
182 |     return result
183 | 
184 | 
185 | def rotated_rect_with_max_area(w, h, angle):
186 |     """
187 |     Given a rectangle of size wxh that has been rotated by 'angle' (in
188 |     radians), computes the width and height of the largest possible
189 |     axis-aligned rectangle (maximal area) within the rotated rectangle.
190 |     """
191 |     if w <= 0 or h <= 0:
192 |         return 0, 0
193 | 
194 |     width_is_longer = w >= h
195 |     side_long, side_short = (w, h) if width_is_longer else (h, w)
196 | 
197 |     # since the solutions for angle, -angle and 180-angle are all the same,
198 |     # it suffices to look at the first quadrant and the absolute values of sin,cos:
199 |     sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
200 |     if side_short <= 2. * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
201 |         # half constrained case: two crop corners touch the longer side,
202 |         # the other two corners are on the mid-line parallel to the longer line
203 |         x = 0.5 * side_short
204 |         wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
205 |     else:
206 |         # fully constrained case: crop touches all 4 sides
207 |         cos_2a = cos_a * cos_a - sin_a * sin_a
208 |         wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
209 | 
210 |     return int(wr), int(hr)
211 | 
212 | 
213 | def gen_rotate_image(img, angle):
214 |     dim = img.shape
215 |     h = dim[0]
216 |     w = dim[1]
217 | 
218 |     img = rotate_image(img, angle)
219 |     dim_bb = img.shape
220 |     h_bb = dim_bb[0]
221 |     w_bb = dim_bb[1]
222 | 
223 |     w_r, h_r = rotated_rect_with_max_area(w, h, math.radians(angle))
224 | 
225 |     w_0 = (w_bb - w_r) // 2
226 |     h_0 = (h_bb - h_r) // 2
227 |     img = img[h_0:h_0 + h_r, w_0:w_0 + w_r, :]
228 | 
229 |     return img
230 | 
231 | 
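Taken together, rotate_image and rotated_rect_with_max_area let gen_rotate_image rotate an image about its centre and then crop away all border fill. A small numeric sketch (not part of the repo) of the guaranteed crop size:

# Sketch, not repo code: a 640x480 image rotated by 13 degrees (one sigma of the
# N(0, 13) angle draw used in the dataset below) keeps a border-free centre crop
# of roughly 573x360; larger angles shrink the usable crop further.
import math
wr, hr = rotated_rect_with_max_area(640, 480, math.radians(13))
print(wr, hr)  # approximately (573, 360)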
232 | class GTRainDataset(Dataset):
233 |     """
234 |     The dataset class for training the deraining model on GT-Rain scenes.
235 | 
236 |     Parameters:
237 |         train_dir_list (list) -- list of root dirs whose scene folders
238 |             make up the training dataset.
239 |         rain_mask_dir (string) -- location of the rain masks used by the
240 |             RainMix data augmentation.
241 |         img_size (int) -- size of the images after cropping.
242 |         sigma (int) -- standard deviation for random angle rotation data
243 |             augmentation.
244 |         zoom_min (float) -- minimum zoom for RainMix data augmentation.
245 |         zoom_max (float) -- maximum zoom for RainMix data augmentation.
246 |     """
247 | 
248 |     def __init__(
249 |             self,
250 |             train_dir_list=None,
251 |             rain_mask_dir=None,
252 |             img_size=256,
253 |             sigma=13,
254 |             zoom_min=0.06,
255 |             zoom_max=1.8
256 |     ):
257 |         super(GTRainDataset, self).__init__()
258 | 
259 |         self.rain_mask_dir = rain_mask_dir
260 |         self.img_size = img_size
261 |         self.sigma = sigma
262 |         self.zoom_min = zoom_min
263 |         self.zoom_max = zoom_max
264 | 
265 |         scene_paths = []
266 |         for root_dir in train_dir_list:
267 |             scene_paths += [os.path.join(root_dir, scene) for scene in list(os.walk(root_dir))[0][1]]
268 | 
269 |         last_index = 0
270 |         self.img_paths = []
271 |         self.scene_indices = []
272 | 
273 |         for scene_path in scene_paths:
274 |             scene_img_paths = natsorted(glob(os.path.join(scene_path, '*R-*.png')))
275 |             scene_length = len(scene_img_paths)
276 |             self.scene_indices.append(list(range(last_index, last_index + scene_length)))
277 |             last_index += scene_length
278 |             self.img_paths += scene_img_paths
279 | 
280 |         # number of images in full dataset
281 |         self.data_len = len(self.img_paths)
282 | 
283 |     def __len__(self):
284 |         return self.data_len
285 | 
286 |     def get_scene_indices(self):
287 |         return self.scene_indices
288 | 
289 |     def __getitem__(self, index):
290 |         ts = self.img_size
291 | 
292 |         inp_path = self.img_paths[index]
293 |         tar_path = self.img_paths[index][:-9] + 'C-000.png'
294 |         if 'Gurutto_1-2' in inp_path:
295 |             tar_path = self.img_paths[index][:-9] + 'C' + self.img_paths[index][-8:]
296 | 
297 |         inp_img = Image.open(inp_path)
298 |         inp_img = np.array(inp_img)
299 | 
300 |         tar_img = Image.open(tar_path)
301 |         tar_img = np.array(tar_img)  # [height, width, channel]
302 | 
303 |         # rain aug
304 |         if random.randint(1, 10) > 4:
305 |             inp_img, tar_img = rain_aug(
306 |                 inp_img,
307 |                 tar_img,
308 |                 self.rain_mask_dir,
309 |                 zoom_min=self.zoom_min,
310 |                 zoom_max=self.zoom_max
311 |             )
312 | 
313 |         # Random rotation
314 |         angle = np.random.normal(0, self.sigma)
315 |         inp_img_rot = gen_rotate_image(inp_img, angle)
316 |         if inp_img_rot.shape[0] >= 256 and inp_img_rot.shape[1] >= 256:
317 |             inp_img = inp_img_rot
318 |             tar_img = gen_rotate_image(tar_img, angle)
319 | 
320 |         # reflect pad and random cropping to ensure the right image size for training
321 |         h, w = inp_img.shape[:2]
322 | 
323 |         # To tensor
324 |         inp_img = TF.to_tensor(inp_img)  # [channel, height, width]
325 |         tar_img = TF.to_tensor(tar_img)
326 | 
327 |         # reflect padding
328 |         padw = ts - w if w < ts else 0
329 |         padh = ts - h if h < ts else 0
330 | 
331 |         if padw != 0 or padh != 0:
332 |             inp_img = F.pad(inp_img, (padw // 2, padw - padw // 2, padh // 2, padh - padh // 2), mode='reflect')
333 |             tar_img = F.pad(tar_img, (padw // 2, padw - padw // 2, padh // 2, padh - padh // 2), mode='reflect')
334 | 
335 |         # random cropping
336 |         hh, ww = inp_img.shape[1], inp_img.shape[2]
337 |         rr = random.randint(0, hh - ts)
338 |         cc = random.randint(0, ww - ts)
339 |         inp_img = inp_img[:, rr:rr + ts, cc:cc + ts]
340 |         tar_img = tar_img[:, rr:rr + ts, cc:cc + ts]
341 | 
342 |         # Data augmentations: flip x, flip y
343 |         aug = random.randint(0, 2)
344 |         if aug == 1:
345 |             inp_img = inp_img.flip(1)
346 |             tar_img = tar_img.flip(1)
347 |         elif aug == 2:
348 |             inp_img = inp_img.flip(2)
349 |             tar_img = tar_img.flip(2)
350 | 
351 |         # Get image name
352 |         scene_name = inp_path.split('/')[-2]  # scene folder name (currently unused)
353 |         file_name = inp_path.split('/')[-1]
354 | 
355 |         # Dict for return
356 |         # If using tanh as the last layer, the range should be [-1, 1]
357 | 
358 |         sample_dict = {
359 |             'input_img': inp_img,
360 |             'target_img': tar_img,
361 |             'file_name': file_name
362 |         }
363 | 
364 |         return sample_dict
365 | 
366 | 
367 | class CustomBatchSampler:
368 |     def __init__(self, scene_indices, batch_size=16):
369 |         self.scene_indices = scene_indices
370 |         self.batch_size = batch_size
371 |         self.num_batches = int(scene_indices[-1][-1] / batch_size)
372 | 
373 |     def __len__(self):
374 |         return self.num_batches
375 | 
376 |     def __iter__(self):
377 |         scene_indices = copy.deepcopy(self.scene_indices)
378 |         for scene_list in scene_indices:
379 |             random.shuffle(scene_list)
380 |         out_indices = []
381 |         done = False
382 |         while not done:
383 |             out_batch_indices = []
384 |             if len(scene_indices) < self.batch_size:
385 |                 self.num_batches = len(out_indices)
386 |                 return iter(out_indices)
387 |             chosen_scenes = np.random.choice(len(scene_indices), self.batch_size, replace=False)
388 |             empty_indices = []
389 |             for i in chosen_scenes:
390 |                 scene_list = scene_indices[i]
391 |                 out_batch_indices.append(scene_list.pop())
392 |                 if len(scene_list) == 0:
393 |                     empty_indices.append(i)
394 |             empty_indices.sort(reverse=True)
395 |             for i in empty_indices:
396 |                 scene_indices.pop(i)
397 |             out_indices.append(out_batch_indices)
398 |         self.num_batches = len(out_indices)
399 |         return iter(out_indices)
400 | 
401 | 
402 | def get_datasets(params):
403 |     train_dataset = GTRainDataset(
404 |         train_dir_list=params['train_dir_list'],
405 |         rain_mask_dir=params['rain_mask_dir'],
406 |         img_size=params['img_size'],
407 |         zoom_min=params['zoom_min'],
408 |         zoom_max=params['zoom_max']
409 |     )
410 | 
411 |     train_loader = DataLoader(
412 |         dataset=train_dataset,
413 |         batch_sampler=CustomBatchSampler(
414 |             train_dataset.get_scene_indices(),
415 |             batch_size=params['batch_size']
416 |         ),
417 |         num_workers=2,
418 |         pin_memory=True
419 |     )
420 | 
421 |     return train_loader
422 | 
--------------------------------------------------------------------------------
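For reference, a minimal sketch (not part of the repo) of how the pieces in this file fit together. The two paths are placeholders; the dict keys mirror exactly what get_datasets reads from params:

# Sketch, not repo code: build the scene-balanced training loader.
params = {
    'train_dir_list': ['/gt-rain/GT-RAIN_train'],  # placeholder dataset root(s)
    'rain_mask_dir': '/gt-rain/rain_masks',        # placeholder rain-mask dir
    'img_size': 256,
    'zoom_min': 0.06,
    'zoom_max': 1.8,
    'batch_size': 16,
}
train_loader = get_datasets(params)

# CustomBatchSampler draws every element of a batch from a different scene and
# ends the epoch once fewer than batch_size scenes still have unused images.
for sample in train_loader:
    inp, tar = sample['input_img'], sample['target_img']  # each [16, 3, 256, 256]
    break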