├── median
│   ├── __init__.py
│   └── median_derain.py
├── ensemble
│   ├── __init__.py
│   └── ensemble_derain.py
├── restormer_x
│   ├── __init__.py
│   ├── dataset
│   │   ├── __init__.py
│   │   └── gt_rain_dataset.py
│   ├── model
│   │   ├── __init__.py
│   │   └── restormer.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── log.py
│   │   ├── mixmethod.py
│   │   ├── loss.py
│   │   ├── data_augmentation.py
│   │   └── trainutil.py
│   ├── test.py
│   └── train.py
├── post_process
│   ├── __init__.py
│   ├── post_process_derain.py
│   └── estimate_pixels.py
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── modules.xml
│   ├── Restormer-Plus.iml
│   └── deployment.xml
├── requirements.txt
├── repeat300.py
├── LICENSE
├── README.md
└── .gitignore
/median/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ensemble/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/restormer_x/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/post_process/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/restormer_x/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/restormer_x/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/restormer_x/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.3.0
2 | natsort==8.3.1
3 | numpy==1.21.5
4 | opencv_contrib_python==4.2.0.32
5 | Pillow==9.2.0
6 | piq==0.7.0
7 | scikit-image  # replaces the broken 'skimage==0.0' stub; provides skimage.metrics
8 | tabulate==0.8.10
9 | torch==1.12.1
10 | torchvision==0.13.1
--------------------------------------------------------------------------------
/repeat300.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from glob import glob
4 |
5 | from natsort import natsorted
6 |
7 | root_dir = '/gt-rain/result/post_process'
8 | scene_names = []
9 | for sc in list(os.walk(root_dir))[0][1]:
10 | scene_names.append(sc)
11 |
12 | img_paths = {}
13 | for scene in scene_names:
14 | scene_path = os.path.join(root_dir, scene)
15 | scene_img_paths = natsorted(glob(os.path.join(scene_path, '*_r.png')))
16 | img_paths[scene] = scene_img_paths
17 |
18 | for scene_name, im_paths in img_paths.items():
19 | print(scene_name)
20 | origin_file = im_paths[0]
21 | for idx in range(2, 301):
22 | new_file = origin_file[:-7] + '{}_r.png'.format(idx)
23 | shutil.copyfile(origin_file, new_file)
24 |
--------------------------------------------------------------------------------
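A note on the script above: each GT-RAIN test scene ships 300 rainy frames, and the submission format apparently expects one output per frame, so `repeat300.py` duplicates the single post-processed result under the remaining frame names. A minimal sketch of the naming logic, assuming (as the slice `origin_file[:-7]` does) that the first result is named `1_r.png`; `repeated_names` is a hypothetical helper, not part of the repo:

```python
def repeated_names(origin_file, n=300):
    """Mirror repeat300.py: derive the duplicate filenames for one scene."""
    stem = origin_file[:-7]  # strip the trailing '1_r.png'
    return [stem + '{}_r.png'.format(idx) for idx in range(2, n + 1)]

assert repeated_names('/scene/1_r.png', n=3) == ['/scene/2_r.png', '/scene/3_r.png']
```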
/restormer_x/utils/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 |
5 | def set_logger(log_dir, file_name):
6 | loglevel = logging.INFO
7 |
8 | log_path = os.path.join(log_dir, file_name)
9 |
10 | logger = logging.getLogger()
11 | logger.setLevel(loglevel)
12 |
13 | # Logging to a file
14 | file_handler = logging.FileHandler(log_path)
15 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
16 | logger.addHandler(file_handler)
17 |
18 | # Logging to console
19 | stream_handler = logging.StreamHandler()
20 | stream_handler.setFormatter(logging.Formatter('%(message)s'))
21 | logger.addHandler(stream_handler)
22 |
23 |     logging.info('writing logs to file {}'.format(log_path))
24 |
--------------------------------------------------------------------------------
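A minimal usage sketch for `set_logger` (the log directory is assumed to exist): it attaches a file handler and a console handler to the root logger, so later `logging.info(...)` calls reach both sinks.

```python
import logging

from restormer_x.utils.log import set_logger

set_logger('/tmp', 'example.log')  # writes /tmp/example.log
logging.info('hello, console and file')
```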
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Research Center for Applied Mathematics and Machine Intelligence, Zhejiang Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/ensemble/ensemble_derain.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | from PIL import Image
7 | from natsort import natsorted
8 |
9 | restormer_x_res_dir = '/gt-rain/result/restormer_x'
10 | median_res_dir = '/gt-rain/result/median'
11 | save_path = '/gt-rain/result'
12 |
13 |
14 | def get_img_paths(data_dir):
15 | scene_names = []
16 | for sc in list(os.walk(data_dir))[0][1]:
17 | scene_names.append(sc)
18 | img_paths = {}
19 | for scene in scene_names:
20 | img_paths[scene] = natsorted(glob(os.path.join(data_dir, scene, '*_r.png')))
21 | return img_paths
22 |
23 |
24 | restormer_x_res_paths = get_img_paths(restormer_x_res_dir)
25 | median_res_paths = get_img_paths(median_res_dir)
26 |
27 | wt = 0.9
28 | for scene in restormer_x_res_paths.keys():
29 | restormer_x_res = np.array(Image.open(restormer_x_res_paths[scene][0])) / 255.0
30 | median_res = np.array(Image.open(median_res_paths[scene][0])) / 255.0
31 |
32 | ensemble_res = wt * restormer_x_res + (1. - wt) * median_res
33 |
34 | ensemble_res = (ensemble_res * 255).astype(np.uint8)
35 |
36 | save_dir = f"{save_path}/ensemble/{scene}"
37 | Path(save_dir).mkdir(parents=True, exist_ok=True)
38 |
39 |     filename = os.path.basename(restormer_x_res_paths[scene][0])
40 | Image.fromarray(ensemble_res).save(f"{save_dir}/{filename}")
41 |
--------------------------------------------------------------------------------
/restormer_x/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 |
5 | import torch
6 |
7 | from restormer_x.model.restormer import get_model
8 | from restormer_x.utils.log import set_logger
9 | from restormer_x.utils.trainutil import predict
10 |
11 | os.environ["CUDA_VISIBLE_DEVICES"] = '6'
12 |
13 | # CONFIG
14 | params = {
15 | # general
16 | 'save_dir': '/gt-rain/model', # Dir to save the model weights
17 | 'result_dir': '/gt-rain/result',
18 | 'method_name': 'restormer_x',
19 |
20 | # data
21 | 'val_dir_list': ['/gt-rain/GT-RAIN_val'], # Dir for the val data
22 |     'test_dir_list': ['/gt-rain/GT-RAIN_test'],  # Dir for the test data
23 |
24 | # model
25 | 'model_version': 'base',
26 |     'resume_epoch': 11,  # epoch of the checkpoint to load for evaluation
27 | }
28 |
29 | # INIT
30 | save_path = os.path.join(params['save_dir'], params['method_name'])
31 | Path(save_path).mkdir(parents=True, exist_ok=True)
32 | set_logger(save_path, 'test.log')
33 | logging.info(str(params))
34 |
35 | # MODEL
36 |
37 | model = get_model(model_version=params['model_version'])
38 |
39 | resume_epoch = params['resume_epoch']
40 | resume_file = os.path.join(save_path, f'model_epoch_{resume_epoch}.pth')
41 | checkpoint = torch.load(resume_file)
42 | model.load_state_dict(checkpoint['state_dict'], strict=False)
43 |
44 | # EVALUATE OR TEST
45 |
46 | is_test = True
47 | psnr_res = predict(
48 | model,
49 | params['test_dir_list'][0] if is_test else params['val_dir_list'][0],
50 | is_test=is_test,
51 | save_path=params['result_dir'],
52 | method_name=params['method_name']
53 | )
54 | logging.info(psnr_res)
55 |
--------------------------------------------------------------------------------
/restormer_x/utils/mixmethod.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def rand_bbox(size, lam):
6 | H = size[2]
7 | W = size[3]
8 |
9 | cut_rat = np.sqrt(1. - lam)
10 |     cut_w = int(W * cut_rat)
11 |     cut_h = int(H * cut_rat)
12 |
13 | cx = np.random.randint(W)
14 | cy = np.random.randint(H)
15 |
16 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
17 | bby1 = np.clip(cy - cut_h // 2, 0, H)
18 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
19 | bby2 = np.clip(cy + cut_h // 2, 0, H)
20 |
21 | return bbx1, bby1, bbx2, bby2
22 |
23 |
24 | def mixup(input_image, target_image, alpha=1.0):
25 |     """Mixup: blend random pairs of samples within a batch.
26 | 
27 |     :param input_image: rainy inputs, [bs, c, h, w]
28 |     :param target_image: clean targets, same shape as input_image
29 |     :param alpha: parameter of the Beta(alpha, alpha) mixing weight
30 |     :return: the mixed (input_image, target_image) pair
31 |     """
32 | image_shape = input_image.shape
33 | rand_index = torch.randperm(image_shape[0]).to(input_image.device)
34 | lam = np.random.beta(alpha, alpha)
35 |
36 | input_image = lam * input_image + (1.0 - lam) * input_image[rand_index]
37 | target_image = lam * target_image + (1.0 - lam) * target_image[rand_index]
38 |
39 | return input_image, target_image
40 |
41 |
42 | def cutmix(input_image, target_image, alpha=1.0):
43 | image_shape = input_image.shape
44 | lam = np.random.beta(alpha, alpha)
45 | bbx1, bby1, bbx2, bby2 = rand_bbox(image_shape, lam)
46 |
47 | rand_index = torch.randperm(image_shape[0]).to(input_image.device)
48 |
49 | input_image[:, :, bby1: bby2, bbx1: bbx2] = input_image[rand_index][:, :, bby1: bby2, bbx1: bbx2]
50 | target_image[:, :, bby1: bby2, bbx1: bbx2] = target_image[rand_index][:, :, bby1: bby2, bbx1: bbx2]
51 | return input_image, target_image
52 |
--------------------------------------------------------------------------------
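A small usage sketch for the two augmentations above, on dummy tensors (the shapes are assumptions). Note that `cutmix` writes into its input tensors in place, so clone them if the originals are still needed:

```python
import torch

from restormer_x.utils.mixmethod import cutmix, mixup

x = torch.rand(4, 3, 64, 64)  # stand-in rainy batch
y = torch.rand(4, 3, 64, 64)  # stand-in clean batch
mx, my = mixup(x, y, alpha=1.0)        # convex blend of random sample pairs
cx, cy = cutmix(x.clone(), y.clone())  # pastes a random box from another sample
assert mx.shape == x.shape and cx.shape == x.shape
```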
/median/median_derain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from glob import glob
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | from PIL import Image
8 | from natsort import natsorted
9 |
10 |
11 | is_train = True
12 | if is_train:
13 | data_dir = '/gt-rain/GT-RAIN_train'
14 | save_path = '/gt-rain/result'
15 | else:
16 | data_dir = '/gt-rain/GT-RAIN_test'
17 | save_path = '/gt-rain/result'
18 |
19 |
20 | def get_img_paths(data_dir, is_train=False):
21 | scene_names = []
22 | for sc in list(os.walk(data_dir))[0][1]:
23 | scene_names.append(sc)
24 | img_paths = {}
25 | clean_img_path = {} if is_train else None
26 | for scene in scene_names:
27 | if is_train:
28 | img_paths[scene] = natsorted(glob(os.path.join(data_dir, scene, '*-R-*.png')))
29 | clean_img_path[scene] = natsorted(glob(os.path.join(data_dir, scene, '*-C-*.png')))[0]
30 | else:
31 | img_paths[scene] = natsorted(glob(os.path.join(data_dir, scene, '*_r.png')))
32 | return img_paths, clean_img_path
33 |
34 |
35 | img_paths, clean_img_path = get_img_paths(data_dir, is_train)
36 |
37 | for scene, scene_img_paths in img_paths.items():
38 |
39 | img_list = []
40 | for img_path in scene_img_paths:
41 | img = Image.open(img_path)
42 | img = np.array(img) / 255.0
43 | img_list.append(img)
44 | median_res = np.median(np.stack(img_list, axis=-1), axis=-1)
45 | median_res = (median_res * 255).astype(np.uint8)
46 | if is_train:
47 | save_dir = f"{save_path}/train_median/{scene}"
48 | else:
49 | save_dir = f"{save_path}/test_median/{scene}"
50 | Path(save_dir).mkdir(parents=True, exist_ok=True)
51 |
52 |     filename = os.path.basename(scene_img_paths[0])
53 | Image.fromarray(median_res).save(f"{save_dir}/{filename}")
54 | if is_train:
55 |         filename = os.path.basename(clean_img_path[scene])
56 | shutil.copyfile(clean_img_path[scene], f"{save_dir}/{filename}")
57 |
--------------------------------------------------------------------------------
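The core of the script above, in isolation: a per-pixel median over the aligned rainy frames of one scene (stacking on a new first axis, as below, is equivalent to the code's stacking on the last axis). Sizes here are made up:

```python
import numpy as np

frames = np.random.rand(300, 480, 640, 3)  # 300 rainy frames, float in [0, 1]
median = np.median(frames, axis=0)         # per-pixel, per-channel median
assert median.shape == (480, 640, 3)
```

One path assumption worth flagging: this script saves under `train_median`/`test_median`, while `ensemble_derain.py` reads from `/gt-rain/result/median`, so the test results presumably need to be moved or the paths aligned between the two steps.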
/post_process/post_process_derain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import random
4 | from glob import glob
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | from PIL import Image
9 | from natsort import natsorted
10 |
11 | est_pixels_file = '/gt-rain/result/est_pixels.pkl'
12 | est_pixels = pickle.load(open(est_pixels_file, 'rb'))
13 | ensemble_res_dir = '/gt-rain/result/ensemble'
14 | save_path = '/gt-rain/result'
15 |
16 |
17 | def linear_regression(ensemble_res_dir, est_pixels, N=4, K=10, eps=1e-10):
18 | for scene, pixels_data in est_pixels.items():
19 | x_img_file = natsorted(glob(os.path.join(ensemble_res_dir, scene, '*_r.png')))[0]
20 | x_img = np.array(Image.open(x_img_file)) / 255.
21 |
22 | wt = np.zeros(shape=[N, 3], dtype=np.float32)
23 | bias = np.zeros(shape=[N, 3], dtype=np.float32)
24 |
25 | for i in range(N):
26 | sum_x = 0.
27 | sum_y = 0.
28 | sum_xy = 0.
29 | sum_x2 = 0.
30 |
31 | sub_pixels_data = random.sample(pixels_data, K)
32 | n = len(sub_pixels_data)
33 | for pdata in sub_pixels_data:
34 | h_idx, w_idx = pdata['pos']
35 | x = x_img[h_idx, w_idx, :].copy()
36 | y = np.array(pdata['rgb']).copy() / 255.
37 | sum_x += x
38 | sum_y += y
39 | sum_xy += x * y
40 | sum_x2 += x * x
41 | wt[i, :] = (sum_xy - sum_x * sum_y / (eps + n)) / (eps + sum_x2 - sum_x * sum_x / (eps + n))
42 | bias[i, :] = sum_y / (eps + n) - wt[i, :] * sum_x / (eps + n)
43 |
44 | mwt = np.reshape(np.mean(wt, axis=0), (1, 1, 3))
45 | mbias = np.reshape(np.mean(bias, axis=0), (1, 1, 3))
46 | post_process_res = mwt * x_img.copy() + mbias
47 | post_process_res = np.clip(post_process_res, 0., 1.)
48 | post_process_res = (post_process_res * 255).astype(np.uint8)
49 |
50 | save_dir = f"{save_path}/post_process/{scene}"
51 | Path(save_dir).mkdir(parents=True, exist_ok=True)
52 |         filename = os.path.basename(x_img_file)
53 | Image.fromarray(post_process_res).save(f"{save_dir}/{filename}")
54 |
55 |
56 | linear_regression(ensemble_res_dir, est_pixels)
57 |
--------------------------------------------------------------------------------
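The inner loop of `linear_regression` above is ordinary least squares fitted per channel on `K` randomly sampled pixels, repeated `N` times with the resulting coefficients averaged. With the stabilizing `eps` terms omitted, each round computes the textbook closed-form slope and intercept:

```latex
w = \frac{\sum_k x_k y_k - \tfrac{1}{n}\sum_k x_k \sum_k y_k}
         {\sum_k x_k^2 - \tfrac{1}{n}\left(\sum_k x_k\right)^2},
\qquad
b = \bar{y} - w\,\bar{x}
```

Here each `x_k` is the ensemble image's value at a sampled pixel and `y_k` is the estimated target value from `est_pixels.pkl`, whose entries (as consumed above) are dicts with a `'pos'` coordinate and an `'rgb'` target in 0-255.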
/restormer_x/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from pathlib import Path
5 |
6 | import tabulate
7 | import torch
8 | import torch.nn as nn
9 |
10 | from restormer_x.dataset.gt_rain_dataset import get_datasets
11 | from restormer_x.model.restormer import get_model
12 | from restormer_x.utils.log import set_logger
13 | from restormer_x.utils.loss import ShiftMSSSIM
14 | from restormer_x.utils.trainutil import get_train_settings, train
15 |
16 | os.environ["CUDA_VISIBLE_DEVICES"] = '6'
17 |
18 | # CONFIG
19 | params = {
20 | # general
21 | 'method_name': 'restormer_x',
22 | # data
23 | 'train_dir_list': ['/gt-rain/GT-RAIN_train'], # Dir for the training data
24 | 'rain_mask_dir': '/gt-rain/Streaks_Garg06', # Dir for the rain masks
25 | 'img_size': 256, # the size of image input
26 | 'zoom_min': .06, # the minimum zoom for RainMix
27 | 'zoom_max': 1.8, # the maximum zoom for RainMix
28 | 'batch_size': 2, # batch size
29 |
30 | # model
31 | 'model_version': 'base',
32 | 'pretrained_model': '/pre-train-model/gt_rain/restormer_deraining.pth',
33 |
34 | # train
35 | 'ssim_kernel_size': 11, # img_size >= (kernel_size - 1) * 16 + 1
36 | 'initial_lr': 3e-4, # initial learning rate used by scheduler
37 | 'weight_decay': 1e-4,
38 | 'num_epochs': 20, # number of epochs to train
39 | 'warmup_epochs': 4, # number of epochs for warmup
40 | 'min_lr': 1e-6, # minimum learning rate used by scheduler
41 | 'mixmethod': 'mixup',
42 | 'mix_prob': 0.5,
43 | 'ssim_loss_weight': 0.0, # weight for the ssim loss
44 | 'acc_grad_step': 4,
45 | 'save_freq': 1,
46 | 'save_dir': '/gt-rain/model', # Dir to save the model weights
47 | }
48 |
49 | # INIT
50 |
51 | save_path = os.path.join(params['save_dir'], params['method_name'])
52 | Path(save_path).mkdir(parents=True, exist_ok=True)
53 | set_logger(save_path, 'train.log')
54 | logging.info(str(params))
55 |
56 | # DATA
57 |
58 | train_loader = get_datasets(params)
59 |
60 | # MODEL
61 |
62 | model = get_model(model_version=params['model_version'])
63 |
64 | if params['pretrained_model'] is not None:
65 | model.load_state_dict(torch.load(params['pretrained_model'])['params'], strict=False)
66 |
67 | # LOSS
68 |
69 | criterion_l1 = nn.L1Loss().cuda()
70 | criterion_ssim = ShiftMSSSIM(ssim_kernel_size=params['ssim_kernel_size']).cuda()
71 |
72 | # TRAIN
73 |
74 | optimizer, scheduler = get_train_settings(model, params)
75 |
76 | start_epoch = 0
77 |
78 | for epoch in range(start_epoch, params['num_epochs']):
79 | time_ep = time.time()
80 |
81 | train_res = train(model, train_loader, optimizer, scheduler, criterion_l1, criterion_ssim, params)
82 |
83 | if ((epoch + 1) % params['save_freq'] == 0) or ((epoch + 1) == params['num_epochs']):
84 | torch.save(
85 | {
86 | 'epoch': epoch,
87 | 'state_dict': model.state_dict(),
88 | 'optimizer': optimizer.state_dict()
89 | },
90 | os.path.join(save_path, f'model_epoch_{epoch}.pth')
91 | )
92 |
93 | time_ep = time.time() - time_ep
94 | columns = ["epoch", "learning_rate",
95 | "train_loss", "train_ssim_loss", "train_l1_loss",
96 | "cost_time"]
97 |
98 | values = [epoch + 1, optimizer.param_groups[0]['lr'],
99 | train_res["total_loss"], train_res["ssim_loss"], train_res["l1_loss"],
100 | time_ep]
101 |
102 | table = tabulate.tabulate([values], columns, tablefmt="simple", floatfmt="8.4f")
103 | if epoch % 50 == 0:
104 | table = table.split("\n")
105 | table = "\n".join([table[1]] + table)
106 | else:
107 | table = table.split("\n")[2]
108 |
109 | logging.info(table)
110 |
--------------------------------------------------------------------------------
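One piece of arithmetic implicit in the config above: with `batch_size = 2` and `acc_grad_step = 4`, gradients are accumulated over four mini-batches before each optimizer step, so the effective batch size is 8; dividing the loss by `acc_grad_step` inside `train` keeps the accumulated gradient equivalent to that of a single large batch.

```python
batch_size = 2      # as set in params
acc_grad_step = 4   # optimizer steps once per 4 mini-batches
assert batch_size * acc_grad_step == 8  # effective batch size
```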
/README.md:
--------------------------------------------------------------------------------
1 | # Restormer-Plus for Real World Image Deraining: One State-of-the-Art Solution to the GT-RAIN Challenge (CVPR 2023 UG2+ Track 3)
2 | This is the Python code used to implement the Restormer-Plus method as described in the technical report:
3 |
4 | [**Restormer-Plus for Real World Image Deraining: One State-of-the-Art Solution to the GT-RAIN Challenge (CVPR 2023 UG2+ Track 3)**
5 | Chaochao Zheng, Luping Wang, Bin Liu](https://arxiv.org/abs/2305.05454)
6 |
10 |
11 | ## Abstract
12 | This technical report presents our Restormer-Plus approach, which was submitted to the GT-RAIN Challenge (CVPR 2023 UG$^2$+ Track 3). Details regarding the challenge are available at http://cvpr2023.ug2challenge.org/track3.html. Our Restormer-Plus outperformed all other submitted solutions in terms of peak signal-to-noise ratio (PSNR). It consists mainly of four modules: the single-image de-raining module, the median filtering module, the weighted averaging module, and the post-processing module. We named the single-image de-raining module Restormer-X; it is built on Restormer and applied to each rainy image. The median filtering module applies a median operator to the 300 rainy images associated with each scene. The weighted averaging module combines the median filtering results with those of Restormer-X to alleviate the overfitting that arises when Restormer-X is used alone. Finally, the post-processing module is used to improve the brightness restoration. Together, these modules make Restormer-Plus one state-of-the-art solution to the GT-RAIN Challenge. Our code is available at https://github.com/ZJLAB-AMMI/Restormer-Plus.
13 |
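A minimal sketch of the arithmetic behind the four modules, assuming aligned float images in [0, 1]; the ensemble weight mirrors ```ensemble/ensemble_derain.py``` (wt = 0.9), while the affine coefficients of the post-processing step are placeholders (```post_process/post_process_derain.py``` fits them by linear regression):

```python
import numpy as np

rainy = np.random.rand(300, 480, 640, 3)     # the 300 rainy frames of one scene (made-up sizes)
restormer_x = np.random.rand(480, 640, 3)    # stand-in for the Restormer-X (de-raining) output
median = np.median(rainy, axis=0)            # median filtering module
ensemble = 0.9 * restormer_x + 0.1 * median  # weighted averaging module
w, b = 1.05, 0.02                            # placeholder affine coefficients
final = np.clip(w * ensemble + b, 0.0, 1.0)  # post-processing (brightness) module
```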
14 | ## Dataset
15 | The dataset can be found [here](https://drive.google.com/drive/folders/1NSRl954QPcGIgoyJa_VjQwh_gEaHWPb8).
16 |
17 | ## Requirements
18 |
19 | - einops==0.3.0
20 | - natsort==8.3.1
21 | - numpy==1.21.5
22 | - opencv_contrib_python==4.2.0.32
23 | - Pillow==9.2.0
24 | - piq==0.7.0
25 | - scikit-image
26 | - tabulate==0.8.10
27 | - torch==1.12.1
28 | - torchvision==0.13.1
29 |
30 | ## Setup
31 | Download the dataset from the link above and change the path parameters in ```train.py``` and ```test.py``` to point to the appropriate directories (e.g., ```./gt-rain/```); the other scripts (```median_derain.py```, ```ensemble_derain.py```, ```post_process_derain.py```, ```repeat300.py```) hard-code similar ```/gt-rain/...``` paths.
32 |
33 | Download the pre-trained de-rain model from [link](https://drive.google.com/drive/folders/1ZEDDEVW0UgkpWi-N4Lj_JUoVChGXCu_u).
34 |
35 | Install all the required packages.
36 |
37 | ## Running
38 | **restormer-x:**
39 |
40 | - training restormer baseline: set ```model_version=base``` and execute ```python /restormer_x/train.py```.
41 |
42 | - training restormer+: set ```model_version=plus``` and execute ```python /restormer_x/train.py```.
43 |
44 | - evaluate and/or test: execute ```python /restormer_x/test.py```.
45 |
46 | **median:** execute ```python /median/median_derain.py```.
47 |
48 | **ensemble:** execute ```python /ensemble/ensemble_derain.py```.
49 |
50 | **post process:** execute ```python /post_process/post_process_derain.py```.
51 |
52 | **submit result:** execute ```python repeat300.py```.
53 |
54 | ## Citation
55 | If you find this code useful, please kindly cite
56 |
57 | @article{zheng2023RestormerPlus,
58 |
59 | title={Restormer-Plus for Real World Image Deraining: One State-of-the-Art Solution to the GT-RAIN Challenge (CVPR 2023 UG2+ Track 3)},
60 |
61 |     author={Zheng, Chaochao and Wang, Luping and Liu, Bin},
62 |
63 | journal={arXiv preprint arXiv:2305.05454},
64 |
65 | year={2023}
66 |
67 | }
68 | ## Disclaimer
69 | Please only use the code and dataset for research purposes.
70 |
71 | ## Contact
72 | Chaochao Zheng
73 | Zhejiang Lab, Research Center for Applied Mathematics and Machine Intelligence
74 | zhengcc@zhejianglab.com
75 |
76 | Luping Wang
77 | Zhejiang Lab, Research Center for Applied Mathematics and Machine Intelligence
78 | wangluping@zhejianglab.com
79 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/restormer_x/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from piq import MultiScaleSSIMLoss
5 |
6 |
7 | class ShiftMSSSIM(torch.nn.Module):
8 | """Shifted SSIM Loss """
9 |
10 | def __init__(self, ssim_kernel_size=11):
11 | super(ShiftMSSSIM, self).__init__()
12 | self.ssim = MultiScaleSSIMLoss(kernel_size=ssim_kernel_size, data_range=1.)
13 |
14 | def forward(self, est, gt):
15 | # shift images back into range (0, 1)
16 | # est = est * 0.5 + 0.5
17 | # gt = gt * 0.5 + 0.5
18 | return self.ssim(est, gt)
19 |
20 |
21 | class RainRobustLoss(torch.nn.Module):
22 | """Rain Robust Loss"""
23 |
24 | def __init__(self, batch_size, n_views, device, temperature=0.07):
25 | super(RainRobustLoss, self).__init__()
26 | self.batch_size = batch_size
27 | self.n_views = n_views
28 | self.temperature = temperature
29 | self.device = device
30 | self.criterion = torch.nn.CrossEntropyLoss().to(self.device)
31 |
32 | def forward(self, features):
33 | logits, labels = self.info_nce_loss(features)
34 | return self.criterion(logits, labels)
35 |
36 | def info_nce_loss(self, features):
37 | labels = torch.cat([torch.arange(self.batch_size) for i in range(self.n_views)], dim=0)
38 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
39 | labels = labels.to(self.device)
40 |
41 | features = F.normalize(features, dim=1)
42 |
43 | similarity_matrix = torch.matmul(features, features.T)
44 |
45 | # discard the main diagonal from both: labels and similarities matrix
46 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device)
47 | labels = labels[~mask].view(labels.shape[0], -1)
48 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
49 |
50 | # select and combine multiple positives
51 | positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
52 |
53 |         # select only the negatives
54 | negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
55 |
56 | logits = torch.cat([positives, negatives], dim=1)
57 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)
58 |
59 | logits = logits / self.temperature
60 | return logits, labels
61 |
62 |
63 | def rain_robust_loss(params):
64 | return RainRobustLoss(
65 | batch_size=params['batch_size'],
66 | n_views=2,
67 | device=torch.device("cuda"),
68 | temperature=params['temperature']
69 | ).cuda()
70 |
71 |
72 | class AverageMeter(object):
73 | """Computes and stores the average and current value"""
74 |
75 | def __init__(self):
76 | self.reset()
77 |
78 | def reset(self):
79 | self.val = 0
80 | self.avg = 0
81 | self.sum = 0
82 | self.count = 0
83 |
84 | def add(self, val, n=1):
85 | self.val = val
86 | self.sum += val * n
87 | self.count += n
88 |
89 | def value(self):
90 | return self.sum / self.count if self.count > 0 else 0.0
91 |
92 |
93 | class AverageAccMeter(object):
94 |
95 | def __init__(self):
96 | self.reset()
97 |
98 | def reset(self):
99 | self.val = 0
100 | self.avg = 0
101 | self.sum = 0
102 | self.count = 0
103 |
104 | def add(self, output, target):
105 | n = output.size(0)
106 | self.val = self.accuracy(output, target).item()
107 | self.sum += self.val * n
108 | self.count += n
109 |
110 | def value(self):
111 | if self.sum == 0:
112 | return 0
113 | else:
114 | return self.sum / self.count
115 |
116 | def accuracy(self, output, target, topk=(1,)):
117 | """Computes the precision@k for the specified values of k"""
118 | maxk = max(topk)
119 | batch_size = target.size(0)
120 |
121 | _, pred = output.topk(maxk, 1, True, True)
122 | pred = pred.t()
123 | correct = pred.eq(target.view(1, -1).expand_as(pred))
124 |
125 | res = []
126 | for k in topk:
127 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
128 | res.append(correct_k.mul_(100.0 / batch_size))
129 |
130 | return res[0]
131 |
--------------------------------------------------------------------------------
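A hedged usage sketch for `RainRobustLoss` (it is defined here but not wired into `train.py`): `features` are assumed to be stacked views of shape `[batch_size * n_views, feat_dim]`, and the InfoNCE objective treats the other view of the same sample as the positive:

```python
import torch

from restormer_x.utils.loss import RainRobustLoss

loss_fn = RainRobustLoss(batch_size=2, n_views=2, device=torch.device('cpu'))
features = torch.randn(4, 128)  # 2 samples x 2 views, stacked view-major
print(loss_fn(features))        # scalar InfoNCE loss
```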
/post_process/estimate_pixels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from glob import glob
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | from PIL import Image
8 | from natsort import natsorted
9 |
10 | # ==========config
11 | f"""
12 | test_median_res_dir: the directory of the median result of test data, achieved by running median_derain.py
13 | train_median_res_dir: the directory of the median result of train data, achieved by running median_derain.py
14 | pixels_file: a .pkl file where contains the position info of the pixels whose values require to be estimated.
15 | Format: a dict, the key is the scene name, the value is a list of pixel-position.
16 | save_dir: where to save the similar patches.
17 | patch_size: the size of the patch.
18 | min_dis: the threshold used to select similar patches.
19 | """
20 | test_median_res_dir = '/gt-rain/result/test_median'
21 | train_median_res_dir = '/gt-rain/result/train_median'
22 | pixels_file = '/gt-rain/result/pixels.pkl'
23 | save_dir = '/gt-rain/result/similar_patch'
24 | patch_size = 8
25 | min_dis = 6
26 | # ==========
27 |
28 | pixels = pickle.load(open(pixels_file, 'rb'))
29 |
30 |
31 | def get_img_paths(data_dir):
32 | scene_names = []
33 | for sc in list(os.walk(data_dir))[0][1]:
34 | scene_names.append(sc)
35 | img_paths = []
36 | for scene in scene_names:
37 | img_paths.append(
38 | (
39 | natsorted(glob(os.path.join(data_dir, scene, '*-R-*.png')))[0],
40 | natsorted(glob(os.path.join(data_dir, scene, '*-C-*.png')))[0]
41 | )
42 | )
43 | return img_paths
44 |
45 |
46 | for scene_name, pixels_pos in pixels.items():
47 | test_median_res = np.asarray(Image.open(os.path.join(test_median_res_dir, scene_name, '1_r.png')))
48 | hts, wts, cts = test_median_res.shape
49 | train_img_paths = get_img_paths(train_median_res_dir)
50 |
51 | for pixel_pos in pixels_pos:
52 | save_path = f"{save_dir}/{scene_name}/{str(pixel_pos[0]) + '_' + str(pixel_pos[1])}"
53 | Path(save_path).mkdir(parents=True, exist_ok=True)
54 |
55 | # test patch
56 | hts1 = np.clip(pixel_pos[0] - patch_size // 2, 0, hts)
57 | hts2 = np.clip(pixel_pos[0] + patch_size // 2, 0, hts)
58 | wts1 = np.clip(pixel_pos[1] - patch_size // 2, 0, wts)
59 | wts2 = np.clip(pixel_pos[1] + patch_size // 2, 0, wts)
60 | test_patch = test_median_res[hts1: hts2, wts1: wts2, :]
61 | Image.fromarray(test_patch).save(f"{save_path}/test_patch.png")
62 |
63 | h_patch_size = hts2 - hts1
64 | w_patch_size = wts2 - wts1
65 |
66 | # search and save similar patch in train data
67 | for train_median_res_file, train_clean_file in train_img_paths:
68 | train_median_res = np.asarray(Image.open(train_median_res_file))
69 | train_clean = np.asarray(Image.open(train_clean_file))
70 | htr, wtr, ctr = train_median_res.shape
71 | h_gap = (htr - h_patch_size) // 30
72 | w_gap = (wtr - w_patch_size) // 30
73 | for h_idx in range(0, htr - h_patch_size, h_gap):
74 | for w_idx in range(0, wtr - w_patch_size, w_gap):
75 |                 train_median_patch = train_median_res[h_idx: (h_idx + h_patch_size),
76 |                                                       w_idx: (w_idx + w_patch_size), :]
77 | train_clean_patch = train_clean[h_idx: (h_idx + h_patch_size), w_idx: (w_idx + w_patch_size), :]
78 |
79 | distance = np.median(
80 | np.abs(test_patch.flatten() - train_median_patch.flatten())
81 | )
82 |
83 | if distance <= min_dis:
84 | pred_val = np.mean(train_clean_patch, axis=(0, 1))
85 | Image.fromarray(train_median_patch).save(
86 | os.path.join(save_path,
87 | 'train_median_patch_{}_{}_{}.png'.format(h_idx, w_idx, np.round(distance, 3))))
88 | Image.fromarray(train_clean_patch).save(os.path.join(save_path,
89 | 'train_clean_patch_{}_{}_{}.png'.format(
90 | h_idx,
91 | w_idx,
92 | np.round(pred_val, 3))))
93 |
--------------------------------------------------------------------------------
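Two pickle formats are assumed by this stage, inferred from how they are read: `pixels.pkl` (consumed above) maps scene names to lists of `(h, w)` positions, and `est_pixels.pkl` (consumed by `post_process_derain.py`) maps scene names to lists of dicts with `'pos'` and `'rgb'` fields. A sketch of writing the first one; the scene name and coordinates are made-up placeholders:

```python
import pickle

pixels = {'some_scene': [(120, 340), (64, 512)]}  # {scene: [(h_idx, w_idx), ...]}
with open('/gt-rain/result/pixels.pkl', 'wb') as f:
    pickle.dump(pixels, f)
```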
/restormer_x/utils/data_augmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image, ImageChops, ImageOps, ImageEnhance
3 |
4 |
5 | def sample_level(n):
6 | return np.random.uniform(low=0.1, high=n)
7 |
8 |
9 | def int_parameter(level, maxval):
10 |     """Helper function to scale `level` between 0 and maxval.
11 |
12 | Args:
13 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
14 | maxval: Maximum value that the operation can have. This will be scaled to
15 | level/PARAMETER_MAX.
16 |
17 | Returns:
18 | An int that results from scaling `maxval` according to `level`.
19 | """
20 | return int(level * maxval / 10)
21 |
22 |
23 | def float_parameter(level, maxval):
24 |     """Helper function to scale `level` between 0 and maxval.
25 |
26 | Args:
27 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
28 | maxval: Maximum value that the operation can have. This will be scaled to
29 | level/PARAMETER_MAX.
30 |
31 | Returns:
32 | A float that results from scaling `maxval` according to `level`.
33 | """
34 | return float(level) * maxval / 10.
35 |
36 |
37 | def autocontrast(pil_img, _):
38 | return ImageOps.autocontrast(pil_img)
39 |
40 |
41 | def equalize(pil_img, _):
42 | return ImageOps.equalize(pil_img)
43 |
44 |
45 | def posterize(pil_img, level):
46 | level = int_parameter(sample_level(level), 4)
47 | return ImageOps.posterize(pil_img, 4 - level)
48 |
49 |
50 | def rotate(pil_img, level):
51 | degrees = int_parameter(sample_level(level), 30)
52 | if np.random.uniform() > 0.5:
53 | degrees = -degrees
54 | return pil_img.rotate(degrees, resample=Image.BILINEAR)
55 |
56 |
57 | def solarize(pil_img, level):
58 | level = int_parameter(sample_level(level), 256)
59 | return ImageOps.solarize(pil_img, 256 - level)
60 |
61 |
62 | def shear_x(pil_img, level):
63 | level = float_parameter(sample_level(level), 0.3)
64 | if np.random.uniform() > 0.5:
65 | level = -level
66 | return pil_img.transform(
67 | (pil_img.width, pil_img.height),
68 | Image.AFFINE, (1, level, 0, 0, 1, 0),
69 | resample=Image.BILINEAR)
70 |
71 |
72 | def shear_y(pil_img, level):
73 | level = float_parameter(sample_level(level), 0.3)
74 | if np.random.uniform() > 0.5:
75 | level = -level
76 | return pil_img.transform(
77 | (pil_img.width, pil_img.height),
78 | Image.AFFINE, (1, 0, 0, level, 1, 0),
79 | resample=Image.BILINEAR)
80 |
81 |
82 | def roll_x(pil_img, level):
83 | """Roll an image sideways."""
84 | delta = int_parameter(sample_level(level), pil_img.width / 3)
85 | if np.random.random() > 0.5:
86 | delta = -delta
87 | xsize, ysize = pil_img.size
88 | delta = delta % xsize
89 | if delta == 0: return pil_img
90 | part1 = pil_img.crop((0, 0, delta, ysize))
91 | part2 = pil_img.crop((delta, 0, xsize, ysize))
92 | pil_img.paste(part1, (xsize - delta, 0, xsize, ysize))
93 | pil_img.paste(part2, (0, 0, xsize - delta, ysize))
94 |
95 | return pil_img
96 |
97 |
98 | def roll_y(pil_img, level):
99 |     """Roll an image vertically."""
100 | delta = int_parameter(sample_level(level), pil_img.width / 3)
101 | if np.random.random() > 0.5:
102 | delta = -delta
103 | xsize, ysize = pil_img.size
104 | delta = delta % ysize
105 | if delta == 0: return pil_img
106 | part1 = pil_img.crop((0, 0, xsize, delta))
107 | part2 = pil_img.crop((0, delta, xsize, ysize))
108 | pil_img.paste(part1, (0, ysize - delta, xsize, ysize))
109 | pil_img.paste(part2, (0, 0, xsize, ysize - delta))
110 |
111 | return pil_img
112 |
113 |
114 | # operation that overlaps with ImageNet-C's test set
115 | def color(pil_img, level):
116 | level = float_parameter(sample_level(level), 1.8) + 0.1
117 | return ImageEnhance.Color(pil_img).enhance(level)
118 |
119 |
120 | # operation that overlaps with ImageNet-C's test set
121 | def contrast(pil_img, level):
122 | level = float_parameter(sample_level(level), 1.8) + 0.1
123 | return ImageEnhance.Contrast(pil_img).enhance(level)
124 |
125 |
126 | # operation that overlaps with ImageNet-C's test set
127 | def brightness(pil_img, level):
128 | level = float_parameter(sample_level(level), 1.8) + 0.1
129 | return ImageEnhance.Brightness(pil_img).enhance(level)
130 |
131 |
132 | # operation that overlaps with ImageNet-C's test set
133 | def sharpness(pil_img, level):
134 | level = float_parameter(sample_level(level), 1.8) + 0.1
135 | return ImageEnhance.Sharpness(pil_img).enhance(level)
136 |
137 |
138 | def zoom_x(pil_img, level):
139 | # zoom from .02 to 2.5
140 | rate = level
141 | zoom_img = pil_img.transform(
142 | (pil_img.width, pil_img.height),
143 | Image.AFFINE, (rate, 0, 0, 0, 1, 0),
144 | resample=Image.BILINEAR)
145 | # need to do reflect padding
146 | if rate > 1.0:
147 | orig_x, orig_y = pil_img.size
148 | new_x = int(orig_x / rate)
149 | zoom_img = np.array(zoom_img)
150 | zoom_img = np.pad(zoom_img[:, :new_x, :], ((0, 0), (0, orig_x - new_x), (0, 0)), 'wrap')
151 | return zoom_img
152 |
153 |
154 | def zoom_y(pil_img, level):
155 | # zoom from .02 to 2.5
156 | rate = level
157 | zoom_img = pil_img.transform(
158 | (pil_img.width, pil_img.height),
159 | Image.AFFINE, (1, 0, 0, 0, rate, 0),
160 | resample=Image.BILINEAR)
161 | # need to do reflect padding
162 | if rate > 1.0:
163 | orig_x, orig_y = pil_img.size
164 | new_y = int(orig_y / rate)
165 | zoom_img = np.array(zoom_img)
166 | zoom_img = np.pad(zoom_img[:new_y, :, :], ((0, orig_y - new_y), (0, 0), (0, 0)), 'wrap')
167 | return zoom_img
168 |
169 |
170 | augmentations = [
171 | rotate, shear_x, shear_y,
172 | zoom_x, zoom_y, roll_x, roll_y
173 | ]
174 |
175 |
176 |
--------------------------------------------------------------------------------
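A minimal usage sketch for the level-based ops above (the severity scale runs roughly 0-10; `int_parameter`/`float_parameter` rescale a randomly sampled level into each op's own range):

```python
import numpy as np
from PIL import Image

from restormer_x.utils.data_augmentation import rotate, shear_x

img = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
out = rotate(img, 5)   # random rotation of up to 15 degrees, either direction
out = shear_x(out, 3)  # random horizontal shear
```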
/restormer_x/utils/trainutil.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | import torchvision.transforms.functional as TF
10 | from PIL import Image
11 | from natsort import natsorted
12 | from skimage.metrics import peak_signal_noise_ratio as psnr
13 | from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
14 |
15 | from restormer_x.utils.mixmethod import mixup
16 | from restormer_x.utils.loss import AverageMeter
17 |
18 |
19 | class GradualWarmupScheduler(_LRScheduler):
20 |     """Gradually warm up (increase) the learning rate in the optimizer.
21 |     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
22 |     Args:
23 |         optimizer (Optimizer): wrapped optimizer.
24 |         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0; if multiplier = 1.0, the lr starts from 0 and ends at base_lr.
25 |         total_epoch: the target learning rate is reached gradually at total_epoch.
26 |         after_scheduler: the scheduler to use after total_epoch (e.g. ReduceLROnPlateau).
27 |     """
28 |
29 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
30 | self.multiplier = multiplier
31 | if self.multiplier < 1.:
32 |             raise ValueError('multiplier should be greater than or equal to 1.')
33 | self.total_epoch = total_epoch
34 | self.after_scheduler = after_scheduler
35 | self.finished = False
36 | super(GradualWarmupScheduler, self).__init__(optimizer)
37 |
38 | def get_lr(self):
39 | if self.last_epoch > self.total_epoch:
40 | if self.after_scheduler:
41 | if not self.finished:
42 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
43 | self.finished = True
44 | return self.after_scheduler.get_last_lr()
45 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
46 |
47 | if self.multiplier == 1.0:
48 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
49 | else:
50 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
51 | self.base_lrs]
52 |
53 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
54 | if epoch is None:
55 | epoch = self.last_epoch + 1
56 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
57 | if self.last_epoch <= self.total_epoch:
58 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
59 | self.base_lrs]
60 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
61 | param_group['lr'] = lr
62 | else:
63 | if epoch is None:
64 | self.after_scheduler.step(metrics, None)
65 | else:
66 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
67 |
68 | def step(self, epoch=None, metrics=None):
69 | if type(self.after_scheduler) != ReduceLROnPlateau:
70 | if self.finished and self.after_scheduler:
71 | if epoch is None:
72 | self.after_scheduler.step(None)
73 | else:
74 | self.after_scheduler.step(epoch - self.total_epoch)
75 | self._last_lr = self.after_scheduler.get_last_lr()
76 | else:
77 | return super(GradualWarmupScheduler, self).step(epoch)
78 | else:
79 | self.step_ReduceLROnPlateau(metrics, epoch)
80 |
81 |
82 | def get_train_settings(model, params):
83 | optimizer = optim.AdamW(
84 | model.parameters(),
85 | lr=params['initial_lr'],
86 | weight_decay=params['weight_decay']
87 | )
88 |
89 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(
90 | optimizer,
91 | params['num_epochs'] - params['warmup_epochs'],
92 | eta_min=params['min_lr'])
93 |
94 | scheduler = GradualWarmupScheduler(
95 | optimizer,
96 | multiplier=1.0,
97 | total_epoch=params['warmup_epochs'],
98 | after_scheduler=scheduler_cosine
99 | )
100 |
101 | optimizer.zero_grad()
102 | optimizer.step()
103 | scheduler.step() # To start warmup
104 |
105 | return optimizer, scheduler
106 |
107 |
108 | def train(model, train_loader, optimizer, scheduler, criterion_l1, criterion_ssim, params):
109 | model.train()
110 |
111 | total_losses = AverageMeter()
112 | l1_losses = AverageMeter()
113 | ssim_losses = AverageMeter()
114 | num_batchs = len(train_loader.dataset) // params['batch_size']
115 | for batch_idx, batch_data in enumerate(train_loader):
116 | input_img = batch_data['input_img'].cuda()
117 | target_img = batch_data['target_img'].cuda()
118 |
119 | if (params['mixmethod'] == 'mixup') and (np.random.rand(1) <= params['mix_prob']):
120 | input_img, target_img = mixup(input_img, target_img)
121 |
122 | output_img = model(input_img)
123 |
124 | l1_loss = criterion_l1(output_img, target_img)
125 | loss = l1_loss
126 | l1_losses.add(l1_loss.item(), input_img.size(0))
127 |
128 | if params['ssim_loss_weight'] > 0:
129 | ssim_loss = criterion_ssim(output_img.clip(0., 1.), target_img)
130 | loss += params['ssim_loss_weight'] * ssim_loss
131 | ssim_losses.add(ssim_loss.item(), input_img.size(0))
132 |
133 | total_losses.add(loss.item(), input_img.size(0))
134 |
135 | acc_grad_step = params['acc_grad_step']
136 | loss = loss / acc_grad_step
137 | loss.backward()
138 |
139 | if (((batch_idx + 1) % acc_grad_step) == 0) or ((batch_idx + 1) == num_batchs):
140 | optimizer.step()
141 | optimizer.zero_grad()
142 |
143 | scheduler.step()
144 |
145 | return {
146 | 'total_loss': total_losses.value(),
147 | 'ssim_loss': ssim_losses.value(),
148 | 'l1_loss': l1_losses.value()
149 | }
150 |
151 |
152 | def predict(model, root_dir, is_test=False, eta=8, save_path=None, method_name=None):
153 | model.eval()
154 | scene_names = []
155 | for sc in list(os.walk(root_dir))[0][1]:
156 | scene_names.append(sc)
157 |
158 | img_paths = {}
159 | for scene in scene_names:
160 | scene_path = os.path.join(root_dir, scene)
161 | if is_test:
162 | scene_img_paths = natsorted(glob(os.path.join(scene_path, '*_r.png')))
163 | else:
164 | scene_img_paths = natsorted(glob(os.path.join(scene_path, '*R-*.png')))
165 | img_paths[scene] = scene_img_paths
166 |
167 | mean_output = {}
168 | with torch.no_grad():
169 | for scene_name, im_paths in img_paths.items():
170 | print(scene_name)
171 | if scene_name not in mean_output:
172 | mean_output[scene_name] = {'sum_im': 0.0, 'num_im': 0}
173 | for im_path in im_paths:
174 | img = Image.open(im_path)
175 | img = np.array(img)
176 | img = TF.to_tensor(img) # [c, h, w]
177 | h, w = img.shape[1:]
178 | padw = eta - (w % eta) if (w % eta) != 0 else 0
179 | padh = eta - (h % eta) if (h % eta) != 0 else 0
180 | if padw != 0 or padh != 0:
181 | img = F.pad(img, (0, padw, 0, padh), mode='reflect')
182 |
183 | input = torch.unsqueeze(img, 0).cuda()
184 | output = model(input)
185 | output = output.squeeze().permute((1, 2, 0))
186 | output = output.detach().cpu().numpy()[:h, :w, :]
187 |
188 | mean_output[scene_name]['sum_im'] += output
189 | mean_output[scene_name]['num_im'] += 1
190 |
191 | psnr_res = {'scene_psnr': {}, 'psnr': [0.0]}
192 | for scene_name, res in mean_output.items():
193 | output = res['sum_im'] / res['num_im']
194 | output = np.clip(output, 0.0, 1.0)
195 | if not is_test:
196 | tmp = img_paths[scene_name][0]
197 | tar_path = tmp[:-9] + 'C-000.png'
198 |                 if 'Gurutto_1-2' in tmp:
199 | tar_path = tmp[:-9] + 'C' + tmp[-8:]
200 | tar_img = Image.open(tar_path)
201 | tar_img = np.array(tar_img, dtype=np.float32)
202 | tar_img = tar_img / 255 # [h, w, c]
203 |
204 | psnr_val = psnr(tar_img, output)
205 | psnr_res['scene_psnr'][scene_name] = psnr_val
206 |                 psnr_res['psnr'][0] += psnr_val
207 | else:
208 | save_dir = f"{save_path}/{method_name}/test/{scene_name}"
209 | Path(save_dir).mkdir(parents=True, exist_ok=True)
210 | output = (output * 255).astype(np.uint8)
211 | filename = img_paths[scene_name][0].split('/')[-1]
212 | Image.fromarray(output).save(f"{save_dir}/{filename}")
213 | psnr_res['psnr'][0] /= len(mean_output.keys())
214 | return psnr_res
215 |
--------------------------------------------------------------------------------
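A sketch of the learning-rate trajectory produced by `get_train_settings` (a dummy model just to hold parameters; the hyperparameters mirror `train.py`): four warmup epochs climbing toward `initial_lr`, then cosine annealing down to `min_lr`.

```python
import torch

from restormer_x.utils.trainutil import get_train_settings

params = {'initial_lr': 3e-4, 'weight_decay': 1e-4,
          'num_epochs': 20, 'warmup_epochs': 4, 'min_lr': 1e-6}
model = torch.nn.Linear(2, 2)  # dummy module, just to supply parameters
optimizer, scheduler = get_train_settings(model, params)
for epoch in range(params['num_epochs']):
    print(epoch, optimizer.param_groups[0]['lr'])
    scheduler.step()  # train() steps the scheduler once per epoch
```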
/restormer_x/model/restormer.py:
--------------------------------------------------------------------------------
1 | import numbers
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | from torch import nn
7 |
8 |
9 | class OverlapPatchEmbed(nn.Module):
10 | def __init__(self, in_c=3, embed_dim=48, bias=False):
11 | super(OverlapPatchEmbed, self).__init__()
12 |
13 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
14 |
15 | def forward(self, x):
16 | x = self.proj(x)
17 |
18 | return x
19 |
20 |
21 | class BiasFree_LayerNorm(nn.Module):
22 | def __init__(self, normalized_shape):
23 | super(BiasFree_LayerNorm, self).__init__()
24 | if isinstance(normalized_shape, numbers.Integral):
25 | normalized_shape = (normalized_shape,)
26 | normalized_shape = torch.Size(normalized_shape)
27 |
28 | assert len(normalized_shape) == 1
29 |
30 | self.weight = nn.Parameter(torch.ones(normalized_shape))
31 | self.normalized_shape = normalized_shape
32 |
33 | def forward(self, x):
34 | sigma = x.var(-1, keepdim=True, unbiased=False)
35 | return x / torch.sqrt(sigma + 1e-5) * self.weight
36 |
37 |
38 | class WithBias_LayerNorm(nn.Module):
39 | def __init__(self, normalized_shape):
40 | super(WithBias_LayerNorm, self).__init__()
41 | if isinstance(normalized_shape, numbers.Integral):
42 | normalized_shape = (normalized_shape,)
43 | normalized_shape = torch.Size(normalized_shape)
44 |
45 | assert len(normalized_shape) == 1
46 |
47 | self.weight = nn.Parameter(torch.ones(normalized_shape))
48 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
49 | self.normalized_shape = normalized_shape
50 |
51 | def forward(self, x):
52 | mu = x.mean(-1, keepdim=True)
53 | sigma = x.var(-1, keepdim=True, unbiased=False)
54 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
55 |
56 |
57 | def to_3d(x):
58 | return rearrange(x, 'b c h w -> b (h w) c')
59 |
60 |
61 | def to_4d(x, h, w):
62 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
63 |
64 |
65 | class LayerNorm(nn.Module):
66 | def __init__(self, dim, LayerNorm_type):
67 | super(LayerNorm, self).__init__()
68 | if LayerNorm_type == 'BiasFree':
69 | self.body = BiasFree_LayerNorm(dim)
70 | else:
71 | self.body = WithBias_LayerNorm(dim)
72 |
73 | def forward(self, x):
74 | h, w = x.shape[-2:]
75 | return to_4d(self.body(to_3d(x)), h, w)
76 |
77 |
78 | class Attention(nn.Module):
79 | def __init__(self, dim, num_heads, bias):
80 | super(Attention, self).__init__()
81 | self.num_heads = num_heads
82 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
83 |
84 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
85 | self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
86 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 |
91 | qkv = self.qkv_dwconv(self.qkv(x))
92 | q, k, v = qkv.chunk(3, dim=1)
93 |
94 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
95 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
96 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
97 |
98 | q = torch.nn.functional.normalize(q, dim=-1)
99 | k = torch.nn.functional.normalize(k, dim=-1)
100 |
101 | attn = (q @ k.transpose(-2, -1)) * self.temperature
102 | attn = attn.softmax(dim=-1)
103 |
104 | out = (attn @ v)
105 |
106 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
107 |
108 | out = self.project_out(out)
109 | return out
110 |
111 |
112 | class FeedForward(nn.Module):
113 | def __init__(self, dim, ffn_expansion_factor, bias):
114 | super(FeedForward, self).__init__()
115 |
116 | hidden_features = int(dim * ffn_expansion_factor)
117 |
118 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
119 |
120 | self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
121 | groups=hidden_features * 2, bias=bias)
122 |
123 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
124 |
125 | def forward(self, x):
126 | x = self.project_in(x)
127 | x1, x2 = self.dwconv(x).chunk(2, dim=1)
128 | x = F.gelu(x1) * x2
129 | x = self.project_out(x)
130 | return x
131 |
132 |
133 | class TransformerBlock(nn.Module):
134 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
135 | super(TransformerBlock, self).__init__()
136 |
137 | self.norm1 = LayerNorm(dim, LayerNorm_type)
138 | self.attn = Attention(dim, num_heads, bias)
139 | self.norm2 = LayerNorm(dim, LayerNorm_type)
140 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
141 |
142 | def forward(self, x):
143 | x = x + self.attn(self.norm1(x))
144 | x = x + self.ffn(self.norm2(x))
145 |
146 | return x
147 |
148 |
149 | class Downsample(nn.Module):
150 | def __init__(self, n_feat):
151 | super(Downsample, self).__init__()
152 |
153 | self.body = nn.Sequential(
154 | nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
155 | nn.PixelUnshuffle(2)
156 | )
157 |
158 | def forward(self, x):
159 | return self.body(x)
160 |
161 |
162 | class Upsample(nn.Module):
163 | def __init__(self, n_feat):
164 | super(Upsample, self).__init__()
165 |
166 | self.body = nn.Sequential(
167 | nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
168 | nn.PixelShuffle(2)
169 | )
170 |
171 | def forward(self, x):
172 | return self.body(x)
173 |
174 |
175 | class Restormer(nn.Module):
176 | def __init__(
177 | self,
178 | inp_channels=3,
179 | out_channels=3,
180 | dim=48,
181 | num_blocks=[4, 6, 6, 8],
182 | num_refinement_blocks=4,
183 | heads=[1, 2, 4, 8],
184 | ffn_expansion_factor=2.66,
185 | bias=False,
186 | LayerNorm_type='WithBias',
187 | version='base' # base or plus
188 | ):
189 | super(Restormer, self).__init__()
190 |
191 | self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
192 |
193 | self.encoder_level1 = nn.Sequential(*[
194 | TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias,
195 | LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
196 |
197 | self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
198 | self.encoder_level2 = nn.Sequential(*[
199 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
200 | bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
201 |
202 | self.down2_3 = Downsample(int(dim * 2 ** 1)) ## From Level 2 to Level 3
203 | self.encoder_level3 = nn.Sequential(*[
204 | TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
205 | bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
206 |
207 | self.down3_4 = Downsample(int(dim * 2 ** 2)) ## From Level 3 to Level 4
208 | self.latent = nn.Sequential(*[
209 | TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor,
210 | bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
211 |
212 | self.up4_3 = Upsample(int(dim * 2 ** 3)) ## From Level 4 to Level 3
213 | self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
214 | self.decoder_level3 = nn.Sequential(*[
215 | TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
216 | bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
217 |
218 | self.up3_2 = Upsample(int(dim * 2 ** 2)) ## From Level 3 to Level 2
219 | self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
220 | self.decoder_level2 = nn.Sequential(*[
221 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
222 | bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
223 |
224 | self.up2_1 = Upsample(int(dim * 2 ** 1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
225 |
226 | self.decoder_level1 = nn.Sequential(*[
227 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
228 | bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
229 |
230 | self.refinement = nn.Sequential(*[
231 | TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
232 | bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
233 | self.output_wt = None
234 | if version == 'plus':
235 | self.output_wt = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)  # multiplicative term of the 'plus' affine output head
236 |
237 | self.output_bias = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)  # additive term, used by both versions
238 |
239 | self.version = version
240 |
241 | def forward(self, inp_img):
242 | inp_enc_level1 = self.patch_embed(inp_img)
243 | out_enc_level1 = self.encoder_level1(inp_enc_level1)
244 |
245 | inp_enc_level2 = self.down1_2(out_enc_level1)
246 | out_enc_level2 = self.encoder_level2(inp_enc_level2)
247 |
248 | inp_enc_level3 = self.down2_3(out_enc_level2)
249 | out_enc_level3 = self.encoder_level3(inp_enc_level3)
250 |
251 | inp_enc_level4 = self.down3_4(out_enc_level3)
252 | latent = self.latent(inp_enc_level4)
253 |
254 | inp_dec_level3 = self.up4_3(latent)
255 | inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
256 | inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
257 | out_dec_level3 = self.decoder_level3(inp_dec_level3)
258 |
259 | inp_dec_level2 = self.up3_2(out_dec_level3)
260 | inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
261 | inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
262 | out_dec_level2 = self.decoder_level2(inp_dec_level2)
263 |
264 | inp_dec_level1 = self.up2_1(out_dec_level2)
265 | inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
266 | out_dec_level1 = self.decoder_level1(inp_dec_level1)
267 |
268 | out_dec_level1 = self.refinement(out_dec_level1)
269 |
270 | if self.version == 'plus' and self.output_wt is not None:
271 | out_dec_level1 = self.output_wt(out_dec_level1) * inp_img + self.output_bias(out_dec_level1)  # affine head: w(feat) * input + b(feat)
272 | else:
273 | out_dec_level1 = inp_img + self.output_bias(out_dec_level1)  # residual head: input + b(feat)
274 |
275 | return out_dec_level1
276 |
277 |
278 | def get_model(model_version='base'):
279 | model = Restormer(version=model_version)
280 | model.cuda()  # assumes a CUDA device is available
281 | return model
282 |
--------------------------------------------------------------------------------
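
A minimal usage sketch for the model file above (not part of the repository): it instantiates the 'plus' variant on CPU and runs a single forward pass on a dummy batch. The import path follows the repository tree; the 256x256 input is an assumption — H and W must be multiples of 8, since the encoder downsamples three times before the skip connections are concatenated.

import torch
from restormer_x.model.restormer import Restormer

model = Restormer(version='plus').eval()  # CPU sketch; get_model() calls .cuda() instead
x = torch.randn(1, 3, 256, 256)  # dummy rainy batch [B, RGB, H, W], H and W divisible by 8
with torch.no_grad():
    y = model(x)  # 'plus' head: output_wt(feat) * x + output_bias(feat)
assert y.shape == x.shape
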
/restormer_x/dataset/gt_rain_dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import os
4 | import random
5 | from glob import glob
6 |
7 | import cv2
8 | import numpy as np
9 | import torch.nn.functional as F
10 | import torchvision.transforms.functional as TF
11 | from PIL import Image
12 | from natsort import natsorted
13 | from torch.utils.data import Dataset, DataLoader
14 |
15 | from restormer_x.utils.data_augmentation import augmentations, zoom_x, zoom_y
16 |
17 |
18 | def getRainLayer2(rand_id1, rand_id2, rain_mask_dir):
19 | path_img_rainlayer_src = os.path.join(rain_mask_dir, f'{rand_id1}-{rand_id2}.png')
20 | rainlayer_rand = cv2.imread(path_img_rainlayer_src)  # returns None if the mask file is missing
21 | if rainlayer_rand is None:
22 | raise FileNotFoundError(f'rain mask not found: {path_img_rainlayer_src}')
23 | rainlayer_rand = rainlayer_rand.astype(np.float32) / 255.0
24 | rainlayer_rand = cv2.cvtColor(rainlayer_rand, cv2.COLOR_BGR2RGB)
25 | return rainlayer_rand
23 |
24 |
25 | def getRandRainLayer2(rain_mask_dir):
26 | rand_id1 = random.randint(1, 165)
27 | rand_id2 = random.randint(4, 8)
28 | rainlayer_rand = getRainLayer2(rand_id1, rand_id2, rain_mask_dir)
29 | return rainlayer_rand
30 |
31 |
32 | def apply_op(image, op, severity):
33 | image = np.clip(image * 255., 0, 255).astype(np.uint8)
34 | pil_img = Image.fromarray(image) # Convert to PIL.Image
35 | pil_img = op(pil_img, severity)
36 | return np.asarray(pil_img) / 255.
37 |
38 |
39 | def augment_and_mix(image, severity=3, width=3, depth=-1, alpha=1., zoom_min=0.06, zoom_max=1.8):
40 | """Perform AugMix augmentations and compute mixture.
41 | Args:
42 | image: Raw input image as float32 np.ndarray of shape (h, w, c)
43 | severity: Severity of underlying augmentation operators (between 1 to 10).
44 | width: Width of augmentation chain
45 | depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
46 | from [1, 3]
47 | alpha: Probability coefficient for Beta and Dirichlet distributions.
48 | Returns:
49 | mixed: Augmented and mixed image.
50 | """
51 | ws = np.float32(np.random.dirichlet([alpha] * width))  # convex mixing weights across the chains
52 | m = np.float32(np.random.beta(alpha, alpha))  # blend coefficient between image and mix
54 |
55 | mix = np.zeros_like(image)
56 | for i in range(width):
57 | image_aug = image.copy()
58 | chain_depth = depth if depth > 0 else np.random.randint(2, 4)  # local var: re-sampled per chain when depth == -1
59 | for _ in range(chain_depth):
60 | op = np.random.choice(augmentations)
61 | if op in (zoom_x, zoom_y):
62 | rate = np.random.uniform(low=zoom_min, high=zoom_max)
63 | image_aug = apply_op(image_aug, op, rate)
64 | else:
65 | image_aug = apply_op(image_aug, op, severity)
66 | # Preprocessing commutes since all coefficients are convex
67 | mix += ws[i] * image_aug
68 |
69 | max_ws = max(ws)
70 | rate = 1.0 / max_ws
71 |
72 | mixed = max((1 - m), 0.7) * image + max(m, rate * 0.5) * mix  # floored weights: neither the clean image nor the rain mix can vanish
73 | return mixed
74 |
75 |
76 | class RandomCrop(object):
77 | def __init__(self, image_size, crop_size):
78 | self.ch, self.cw = crop_size
79 | ih, iw = image_size
80 |
81 | self.h1 = random.randint(0, ih - self.ch)
82 | self.w1 = random.randint(0, iw - self.cw)
83 |
84 | self.h2 = self.h1 + self.ch
85 | self.w2 = self.w1 + self.cw
86 |
87 | def __call__(self, img):
88 | if len(img.shape) == 3:
89 | return img[self.h1: self.h2, self.w1: self.w2, :]
90 | else:
91 | return img[self.h1: self.h2, self.w1: self.w2]
92 |
93 |
94 | def rain_aug(img_rainy, img_gt, rain_mask_dir, zoom_min=0.06, zoom_max=1.8):
95 | img_rainy = (img_rainy.astype(np.float32)) / 255.0
96 | img_gt = (img_gt.astype(np.float32)) / 255.0
97 | img_rainy_ret = img_rainy
98 | img_gt_ret = img_gt
99 |
100 | rainlayer_rand2 = getRandRainLayer2(rain_mask_dir)
101 | rainlayer_aug2 = augment_and_mix(
102 | rainlayer_rand2,
103 | severity=3,
104 | width=3,
105 | depth=-1,
106 | zoom_min=zoom_min,
107 | zoom_max=zoom_max
108 | )
109 |
110 | height = min(img_rainy.shape[0], rainlayer_aug2.shape[0])
111 | width = min(img_rainy.shape[1], rainlayer_aug2.shape[1])
112 |
113 | cropper = RandomCrop(rainlayer_aug2.shape[:2], (height, width))
114 | rainlayer_aug2_crop = cropper(rainlayer_aug2)
115 | cropper = RandomCrop(img_rainy.shape[:2], (height, width))
116 | img_rainy_ret = cropper(img_rainy_ret)
117 | img_gt_ret = cropper(img_gt_ret)
118 | img_rainy_ret = img_rainy_ret + rainlayer_aug2_crop - img_rainy_ret * rainlayer_aug2_crop  # screen blend: 1 - (1 - img)(1 - rain)
119 | img_rainy_ret = np.clip(img_rainy_ret, 0.0, 1.0)
120 | img_rainy_ret = (img_rainy_ret * 255).astype(np.uint8)
121 | img_gt_ret = (img_gt_ret * 255).astype(np.uint8)
122 |
123 | return img_rainy_ret, img_gt_ret
124 |
125 |
126 | def get_translation_matrix_2d(dx, dy):
127 | """
128 | Returns a numpy affine transformation matrix for a 2D translation of
129 | (dx, dy)
130 | """
131 | return np.matrix([[1, 0, dx], [0, 1, dy], [0, 0, 1]])
132 |
133 |
134 | def rotate_image(image, angle):
135 | """
136 | Rotates the given image about its centre.
137 | """
138 |
139 | image_size = (image.shape[1], image.shape[0])
140 | image_center = tuple(np.array(image_size) / 2)
141 |
142 | rot_mat = np.vstack([cv2.getRotationMatrix2D(image_center, angle, 1.0), [0, 0, 1]])
143 | trans_mat = np.identity(3)
144 |
145 | w2 = image_size[0] * 0.5
146 | h2 = image_size[1] * 0.5
147 |
148 | rot_mat_notranslate = np.matrix(rot_mat[0:2, 0:2])
149 |
150 | tl = (np.array([-w2, h2]) * rot_mat_notranslate).A[0]
151 | tr = (np.array([w2, h2]) * rot_mat_notranslate).A[0]
152 | bl = (np.array([-w2, -h2]) * rot_mat_notranslate).A[0]
153 | br = (np.array([w2, -h2]) * rot_mat_notranslate).A[0]
154 |
155 | x_coords = [pt[0] for pt in [tl, tr, bl, br]]
156 | x_pos = [x for x in x_coords if x > 0]
157 | x_neg = [x for x in x_coords if x < 0]
158 |
159 | y_coords = [pt[1] for pt in [tl, tr, bl, br]]
160 | y_pos = [y for y in y_coords if y > 0]
161 | y_neg = [y for y in y_coords if y < 0]
162 |
163 | right_bound = max(x_pos)
164 | left_bound = min(x_neg)
165 | top_bound = max(y_pos)
166 | bot_bound = min(y_neg)
167 |
168 | new_w = int(abs(right_bound - left_bound))
169 | new_h = int(abs(top_bound - bot_bound))
170 | new_image_size = (new_w, new_h)
171 |
172 | new_midx = new_w * 0.5
173 | new_midy = new_h * 0.5
174 |
175 | dx = int(new_midx - w2)
176 | dy = int(new_midy - h2)
177 |
178 | trans_mat = get_translation_matrix_2d(dx, dy)
179 | affine_mat = (np.matrix(trans_mat) * np.matrix(rot_mat))[0:2, :]
180 | result = cv2.warpAffine(image, affine_mat, new_image_size, flags=cv2.INTER_LINEAR)
181 |
182 | return result
183 |
184 |
185 | def rotated_rect_with_max_area(w, h, angle):
186 | """
187 | Given a rectangle of size wxh that has been rotated by 'angle' (in
188 | radians), computes the width and height of the largest possible
189 | axis-aligned rectangle (maximal area) within the rotated rectangle.
190 | """
191 | if w <= 0 or h <= 0:
192 | return 0, 0
193 |
194 | width_is_longer = w >= h
195 | side_long, side_short = (w, h) if width_is_longer else (h, w)
196 |
197 | # since the solutions for angle, -angle and 180-angle are all the same,
198 | # it suffices to look at the first quadrant and the absolute values of sin, cos:
199 | sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
200 | if side_short <= 2. * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
201 | # half constrained case: two crop corners touch the longer side,
202 | # the other two corners are on the mid-line parallel to the longer line
203 | x = 0.5 * side_short
204 | wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
205 | else:
206 | # fully constrained case: crop touches all 4 sides
207 | cos_2a = cos_a * cos_a - sin_a * sin_a
208 | wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
209 |
210 | return int(wr), int(hr)
211 |
212 |
213 | def gen_rotate_image(img, angle):
214 | dim = img.shape
215 | h = dim[0]
216 | w = dim[1]
217 |
218 | img = rotate_image(img, angle)
219 | dim_bb = img.shape
220 | h_bb = dim_bb[0]
221 | w_bb = dim_bb[1]
222 |
223 | w_r, h_r = rotated_rect_with_max_area(w, h, math.radians(angle))
224 |
225 | w_0 = (w_bb - w_r) // 2
226 | h_0 = (h_bb - h_r) // 2
227 | img = img[h_0:h_0 + h_r, w_0:w_0 + w_r, :]
228 |
229 | return img
230 |
231 |
232 | class GTRainDataset(Dataset):
233 | """
234 | The dataset class for weather net training and validation.
235 |
236 | Parameters:
237 | train_dir_list (list) -- list of dirs for the dataset.
238 | val_dir_list (list) -- list of dirs for the dataset.
239 | rain_mask_dir (string) -- location of rain masks for data augmentation.
240 | img_size (int) -- size of the images after cropping.
241 | is_train (bool) -- True for training set.
242 | val_list (list) -- list of validation scenes
243 | sigma (int) -- variance for random angle rotation data augmentation
244 | zoom_min (float) -- minimum zoom for RainMix data augmentation
245 | zoom_max (float) -- maximum zoom for RainMix data augmentation
246 | """
247 |
248 | def __init__(
249 | self,
250 | train_dir_list=None,
251 | rain_mask_dir=None,
252 | img_size=256,
253 | sigma=13,
254 | zoom_min=0.06,
255 | zoom_max=1.8
256 | ):
257 | super(GTRainDataset, self).__init__()
258 |
259 | self.rain_mask_dir = rain_mask_dir
260 | self.img_size = img_size
261 | self.sigma = sigma
262 | self.zoom_min = zoom_min
263 | self.zoom_max = zoom_max
264 |
265 | scene_paths = []
266 | for root_dir in train_dir_list:
267 | scene_paths += [os.path.join(root_dir, scene) for scene in next(os.walk(root_dir))[1]]
268 |
269 | last_index = 0
270 | self.img_paths = []
271 | self.scene_indices = []
272 |
273 | for scene_path in scene_paths:
274 | scene_img_paths = natsorted(glob(os.path.join(scene_path, '*R-*.png')))
275 | scene_length = len(scene_img_paths)
276 | self.scene_indices.append(list(range(last_index, last_index + scene_length)))
277 | last_index += scene_length
278 | self.img_paths += scene_img_paths
279 |
280 | # number of images in full dataset
281 | self.data_len = len(self.img_paths)
282 |
283 | def __len__(self):
284 | return self.data_len
285 |
286 | def get_scene_indices(self):
287 | return self.scene_indices
288 |
289 | def __getitem__(self, index):
290 | ts = self.img_size
291 |
292 | inp_path = self.img_paths[index]  # rainy frame '*R-xxx.png'
293 | tar_path = self.img_paths[index][:-9] + 'C-000.png'  # clean frame shared by the whole scene
294 | if 'Gurutto_1-2' in inp_path:
295 | tar_path = self.img_paths[index][:-9] + 'C' + self.img_paths[index][-8:]  # this scene has per-frame clean images
296 |
297 | inp_img = Image.open(inp_path)
298 | inp_img = np.array(inp_img)
299 |
300 | tar_img = Image.open(tar_path)
301 | tar_img = np.array(tar_img) # [height, width, channel]
302 |
303 | # rain aug
304 | if random.randint(1, 10) > 4:  # apply RainMix with probability 0.6
305 | inp_img, tar_img = rain_aug(
306 | inp_img,
307 | tar_img,
308 | self.rain_mask_dir,
309 | zoom_min=self.zoom_min,
310 | zoom_max=self.zoom_max
311 | )
312 |
313 | # Random rotation
314 | angle = np.random.normal(0, self.sigma)
315 | inp_img_rot = gen_rotate_image(inp_img, angle)
316 | if inp_img_rot.shape[0] >= 256 and inp_img_rot.shape[1] >= 256:  # keep the rotation only if the rotated image is still large enough
317 | inp_img = inp_img_rot
318 | tar_img = gen_rotate_image(tar_img, angle)
319 |
320 | # reflect pad and random cropping to ensure the right image size for training
321 | h, w = inp_img.shape[:2]
322 |
323 | # To tensor
324 | inp_img = TF.to_tensor(inp_img) # [channel, height, width]
325 | tar_img = TF.to_tensor(tar_img)
326 |
327 | # reflect padding
328 | padw = ts - w if w < ts else 0
329 | padh = ts - h if h < ts else 0
330 |
331 | if padw != 0 or padh != 0:
332 | inp_img = F.pad(inp_img, (padw // 2, padw - padw // 2, padh // 2, padh - padh // 2), mode='reflect')
333 | tar_img = F.pad(tar_img, (padw // 2, padw - padw // 2, padh // 2, padh - padh // 2), mode='reflect')
334 |
335 | # random cropping
336 | hh, ww = inp_img.shape[1], inp_img.shape[2]
337 | rr = random.randint(0, hh - ts)
338 | cc = random.randint(0, ww - ts)
339 | inp_img = inp_img[:, rr:rr + ts, cc:cc + ts]
340 | tar_img = tar_img[:, rr:rr + ts, cc:cc + ts]
341 |
342 | # Data augmentations: flip x, flip y
343 | aug = random.randint(0, 2)
344 | if aug == 1:
345 | inp_img = inp_img.flip(1)
346 | tar_img = tar_img.flip(1)
347 | elif aug == 2:
348 | inp_img = inp_img.flip(2)
349 | tar_img = tar_img.flip(2)
350 |
351 | # Get the image file name
352 | file_name = os.path.basename(inp_path)
354 |
355 | # Dict for return
356 | # If using tanh as the last layer, the range should be [-1, 1]
357 |
358 | sample_dict = {
359 | 'input_img': inp_img,
360 | 'target_img': tar_img,
361 | 'file_name': file_name
362 | }
363 |
364 | return sample_dict
365 |
366 |
367 | class CustomBatchSampler:
368 | """Batch sampler: every index in a batch comes from a different scene."""
369 | def __init__(self, scene_indices, batch_size=16):
370 | self.scene_indices = scene_indices
371 | self.batch_size = batch_size
372 | self.num_batches = int(scene_indices[-1][-1] / batch_size)  # rough estimate; recomputed in __iter__
372 |
373 | def __len__(self):
374 | return self.num_batches
375 |
376 | def __iter__(self):
377 | scene_indices = copy.deepcopy(self.scene_indices)
378 | for scene_list in scene_indices:
379 | random.shuffle(scene_list)
380 | out_indices = []
381 | done = False
382 | while not done:
383 | out_batch_indices = []
384 | if len(scene_indices) < self.batch_size:  # too few scenes left to fill a batch
385 | self.num_batches = len(out_indices)
386 | return iter(out_indices)
387 | chosen_scenes = np.random.choice(len(scene_indices), self.batch_size, replace=False)
388 | empty_indices = []
389 | for i in chosen_scenes:
390 | scene_list = scene_indices[i]
391 | out_batch_indices.append(scene_list.pop())
392 | if len(scene_list) == 0:
393 | empty_indices.append(i)
394 | empty_indices.sort(reverse=True)
395 | for i in empty_indices:
396 | scene_indices.pop(i)
397 | out_indices.append(out_batch_indices)
398 | self.num_batches = len(out_indices)
399 | return iter(out_indices)
400 |
401 |
402 | def get_datasets(params):
403 | train_dataset = GTRainDataset(
404 | train_dir_list=params['train_dir_list'],
405 | rain_mask_dir=params['rain_mask_dir'],
406 | img_size=params['img_size'],
407 | zoom_min=params['zoom_min'],
408 | zoom_max=params['zoom_max']
409 | )
410 |
411 | train_loader = DataLoader(
412 | dataset=train_dataset,
413 | batch_sampler=CustomBatchSampler(
414 | train_dataset.get_scene_indices(),
415 | batch_size=params['batch_size']
416 | ),
417 | num_workers=2,
418 | pin_memory=True
419 | )
420 |
421 | return train_loader
422 |
--------------------------------------------------------------------------------
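
A minimal wiring sketch for the data pipeline above (not part of the repository; the paths are placeholders): get_datasets() reads exactly these keys from its params dict and returns a DataLoader driven by CustomBatchSampler, so every image in a batch comes from a different scene.

from restormer_x.dataset.gt_rain_dataset import get_datasets

params = {
    'train_dir_list': ['/data/GT-RAIN/GT-RAIN_train'],  # placeholder dataset root
    'rain_mask_dir': '/data/rain_masks',  # placeholder rain-mask dir
    'img_size': 256,
    'zoom_min': 0.06,
    'zoom_max': 1.8,
    'batch_size': 16,
}

train_loader = get_datasets(params)
batch = next(iter(train_loader))
inp, tar = batch['input_img'], batch['target_img']  # float tensors in [0, 1], shape [B, 3, 256, 256]
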