├── Deblur
│   ├── README.md
│   ├── cal.py
│   ├── config.py
│   ├── data_RGB.py
│   ├── dataset_RGB.py
│   ├── eval.py
│   ├── losses.py
│   ├── test.py
│   ├── train.py
│   ├── trmash.yml
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   ├── arch_utils.cpython-38.pyc
│       │   ├── dataset_utils.cpython-38.pyc
│       │   ├── dir_utils.cpython-38.pyc
│       │   ├── dist_util.cpython-38.pyc
│       │   ├── image_utils.cpython-38.pyc
│       │   ├── logger.cpython-38.pyc
│       │   └── model_utils.cpython-38.pyc
│       ├── arch_utils.py
│       ├── dataset_utils.py
│       ├── dir_utils.py
│       ├── dist_util.py
│       ├── image_utils.py
│       ├── logger.py
│       └── model_utils.py
├── Derain
│   ├── README.md
│   ├── cal.py
│   ├── config.py
│   ├── data_RGB.py
│   ├── dataset_RGB.py
│   ├── eval.py
│   ├── evaluate_PSNR_SSIM.py
│   ├── losses.py
│   ├── test.py
│   ├── train.py
│   ├── trmash.yml
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   ├── arch_utils.cpython-38.pyc
│       │   ├── dataset_utils.cpython-38.pyc
│       │   ├── dir_utils.cpython-38.pyc
│       │   ├── dist_util.cpython-38.pyc
│       │   ├── image_utils.cpython-38.pyc
│       │   ├── logger.cpython-38.pyc
│       │   └── model_utils.cpython-38.pyc
│       ├── arch_utils.py
│       ├── dataset_utils.py
│       ├── dir_utils.py
│       ├── dist_util.py
│       ├── image_utils.py
│       ├── logger.py
│       └── model_utils.py
├── LICENSE.md
├── MHNet.py
├── README.md
├── fig
│   ├── blur.jpg
│   ├── dau.png
│   ├── deblur.png
│   ├── derain.png
│   ├── fir_h.jpg
│   ├── muti-net.png
│   ├── network.jpg
│   ├── network.png
│   ├── rain.jpg
│   ├── sec_h.jpg
│   └── three_con.png
└── pytorch-gradual-warmup-lr
    ├── build
    │   └── lib
    │       └── warmup_scheduler
    │           ├── __init__.py
    │           ├── run.py
    │           └── scheduler.py
    ├── dist
    │   └── warmup_scheduler-0.3-py3.8.egg
    ├── setup.py
    ├── warmup_scheduler.egg-info
    │   ├── PKG-INFO
    │   ├── SOURCES.txt
    │   ├── dependency_links.txt
    │   └── top_level.txt
    └── warmup_scheduler
        ├── __init__.py
        ├── run.py
        └── scheduler.py
/Deblur/README.md:
--------------------------------------------------------------------------------
1 | ## Training
2 | - Download the datasets from the Google Drive links below and place them in `./Datasets/`. Your directory tree should look like this:
3 |
4 | `GoPro`
5 | `├──`[train](https://drive.google.com/drive/folders/1AsgIP9_X0bg0olu2-1N6karm2x15cJWE?usp=sharing)
6 | `└──`[test](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing)
7 |
8 | `HIDE`
9 | `└──`[test](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing)
10 |
11 |
12 | - Train the model with default arguments by running
13 |
14 | ```
15 | python train.py
16 | ```
17 |
18 | ## Evaluation
19 |
20 | ### Download the [model](https://drive.google.com/drive/folders/1qBC3mUoLoCuMyuiseYoZWzvyvImG98TW?usp=drive_link) and place it in ./pre-trained/
21 |
22 | #### Testing on GoPro dataset
23 | - Download [images](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) of GoPro and place them in `./Datasets/GoPro/test/`
24 | - Run
25 | ```
26 | python test.py --dataset GoPro
27 | ```
28 |
29 | #### Testing on HIDE dataset
30 | - Download [images](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) of HIDE and place them in `./Datasets/HIDE/test/`
31 | - Run
32 | ```
33 | python test.py --dataset HIDE
34 | ```
35 |
36 |
37 |
38 |
39 | #### To reproduce the PSNR/SSIM scores of the paper on the GoPro and HIDE datasets, run
40 |
41 | ```
42 | python eval.py
43 | ```
44 |
--------------------------------------------------------------------------------
/Deblur/cal.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | import skimage.metrics
6 | import torch
7 | import math
8 |
9 | def calculate_psnr(img1, img2, crop_border, test_y_channel=True):
10 | """Calculate PSNR (Peak Signal-to-Noise Ratio).
11 |
12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
13 | Args:
14 | img1 (ndarray): Images with range [0, 255].
15 | img2 (ndarray): Images with range [0, 255].
16 | crop_border (int): Cropped pixels in each edge of an image. These
17 | pixels are not involved in the PSNR calculation.
18 | test_y_channel (bool): Test on Y channel of YCbCr. Default: True.
19 | Returns:
20 | float: psnr result.
21 | """
22 | assert img1.shape == img2.shape, (
23 | f'Image shapes are different: {img1.shape}, {img2.shape}.')
24 | if type(img1) == torch.Tensor:
25 | if len(img1.shape) == 4:
26 | img1 = img1.squeeze(0)
27 | img1 = img1.detach().cpu().numpy().transpose(1,2,0)
28 | if type(img2) == torch.Tensor:
29 | if len(img2.shape) == 4:
30 | img2 = img2.squeeze(0)
31 | img2 = img2.detach().cpu().numpy().transpose(1,2,0)
32 | img1 = img1.astype(np.float64)
33 | img2 = img2.astype(np.float64)
34 |
35 | if crop_border != 0:
36 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
38 |
39 | if test_y_channel:
40 | img1 = to_y_channel(img1)
41 | img2 = to_y_channel(img2)
42 |
43 | imdff = np.float32(img1) - np.float32(img2)
44 | rmse = np.sqrt(np.mean(imdff**2))
45 | ps = 20 * np.log10(255. / rmse) if rmse > 0 else float('inf')  # guard against identical images
46 | return ps
47 |
48 |
49 | def _convert_input_type_range(img):
50 | """Convert the type and range of the input image.
51 |
52 | It converts the input image to np.float32 type and range of [0, 1].
53 | It is mainly used for pre-processing the input image in colorspace
54 | conversion functions such as rgb2ycbcr and ycbcr2rgb.
55 | Args:
56 | img (ndarray): The input image. It accepts:
57 | 1. np.uint8 type with range [0, 255];
58 | 2. np.float32 type with range [0, 1].
59 | Returns:
60 | (ndarray): The converted image with type of np.float32 and range of
61 | [0, 1].
62 | """
63 | img_type = img.dtype
64 | img = img.astype(np.float32)
65 | if img_type == np.float32:
66 | pass
67 | elif img_type == np.uint8:
68 | img /= 255.
69 | else:
70 | raise TypeError('The img type should be np.float32 or np.uint8, '
71 | f'but got {img_type}')
72 | return img
73 |
74 |
75 | def _convert_output_type_range(img, dst_type):
76 | """Convert the type and range of the image according to dst_type.
77 |
78 | It converts the image to desired type and range. If `dst_type` is np.uint8,
79 | images will be converted to np.uint8 type with range [0, 255]. If
80 | `dst_type` is np.float32, it converts the image to np.float32 type with
81 | range [0, 1].
82 | It is mainly used for post-processing images in colorspace conversion
83 | functions such as rgb2ycbcr and ycbcr2rgb.
84 | Args:
85 | img (ndarray): The image to be converted with np.float32 type and
86 | range [0, 255].
87 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
88 | converts the image to np.uint8 type with range [0, 255]. If
89 | dst_type is np.float32, it converts the image to np.float32 type
90 | with range [0, 1].
91 | Returns:
92 | (ndarray): The converted image with desired type and range.
93 | """
94 | if dst_type not in (np.uint8, np.float32):
95 | raise TypeError('The dst_type should be np.float32 or np.uint8, '
96 | f'but got {dst_type}')
97 | if dst_type == np.uint8:
98 | img = img.round()
99 | else:
100 | img /= 255.
101 |
102 | return img.astype(dst_type)
103 |
104 |
105 | def rgb2ycbcr(img, y_only=True):
106 | """Convert a RGB image to YCbCr image.
107 |
108 | This function produces the same results as Matlab's `rgb2ycbcr` function.
109 | It implements the ITU-R BT.601 conversion for standard-definition
110 | television. See more details in
111 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
112 | It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
113 | In OpenCV, it implements a JPEG conversion. See more details in
114 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
115 |
116 | Args:
117 | img (ndarray): The input image. It accepts:
118 | 1. np.uint8 type with range [0, 255];
119 | 2. np.float32 type with range [0, 1].
120 | y_only (bool): Whether to only return Y channel. Default: True.
121 | Returns:
122 | ndarray: The converted YCbCr image. The output image has the same type
123 | and range as input image.
124 | """
125 | img_type = img.dtype
126 | img = _convert_input_type_range(img)
127 | if y_only:
128 | out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
129 | else:
130 | out_img = np.matmul(img,
131 | [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
132 | [24.966, 112.0, -18.214]]) + [16, 128, 128]
133 | out_img = _convert_output_type_range(out_img, img_type)
134 | return out_img
135 |
136 |
137 | def to_y_channel(img):
138 | """Change to Y channel of YCbCr.
139 |
140 | Args:
141 | img (ndarray): Images with range [0, 255].
142 | Returns:
143 | (ndarray): Images with range [0, 255] (float type) without round.
144 | """
145 | img = img.astype(np.float32) / 255.
146 | if img.ndim == 3 and img.shape[2] == 3:
147 | img = rgb2ycbcr(img, y_only=True)
148 | img = img[..., None]
149 | return img * 255.
150 |
151 | def _ssim(img1, img2):
152 | """Calculate SSIM (structural similarity) for one channel images.
153 |
154 | It is called by func:`calculate_ssim`.
155 |
156 | Args:
157 | img1 (ndarray): Images with range [0, 255] with order 'HWC'.
158 | img2 (ndarray): Images with range [0, 255] with order 'HWC'.
159 |
160 | Returns:
161 | float: ssim result.
162 | """
163 |
164 | C1 = (0.01 * 255)**2
165 | C2 = (0.03 * 255)**2
166 |
167 | img1 = img1.astype(np.float64)
168 | img2 = img2.astype(np.float64)
169 | kernel = cv2.getGaussianKernel(11, 1.5)
170 | window = np.outer(kernel, kernel.transpose())
171 |
172 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
173 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
174 | mu1_sq = mu1**2
175 | mu2_sq = mu2**2
176 | mu1_mu2 = mu1 * mu2
177 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
178 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
179 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
180 |
181 | ssim_map = ((2 * mu1_mu2 + C1) *
182 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
183 | (sigma1_sq + sigma2_sq + C2))
184 | return ssim_map.mean()
185 |
186 | def prepare_for_ssim(img, k):
187 | import torch
188 | with torch.no_grad():
189 | img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
190 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect')
191 | conv.weight.requires_grad = False
192 | conv.weight[:, :, :, :] = 1. / (k * k)
193 |
194 | img = conv(img)
195 |
196 | img = img.squeeze(0).squeeze(0)
197 | img = img[0::k, 0::k]
198 | return img.detach().cpu().numpy()
199 |
200 | def prepare_for_ssim_rgb(img, k):
201 | import torch
202 | with torch.no_grad():
203 | img = torch.from_numpy(img).float() #HxWx3
204 |
205 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
206 | conv.weight.requires_grad = False
207 | conv.weight[:, :, :, :] = 1. / (k * k)
208 |
209 | new_img = []
210 |
211 | for i in range(3):
212 | new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k])
213 |
214 | return torch.stack(new_img, dim=2).detach().cpu().numpy()
215 |
216 | def _3d_gaussian_calculator(img, conv3d):
217 | out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
218 | return out
219 |
220 | def _generate_3d_gaussian_kernel():
221 | kernel = cv2.getGaussianKernel(11, 1.5)
222 | window = np.outer(kernel, kernel.transpose())
223 | kernel_3 = cv2.getGaussianKernel(11, 1.5)
224 | kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
225 | conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate')
226 | conv3d.weight.requires_grad = False
227 | conv3d.weight[0, 0, :, :, :] = kernel
228 | return conv3d
229 |
230 | def _ssim_3d(img1, img2, max_value):
231 | """Calculate SSIM (structural similarity) for three-channel images.
232 | 
233 | Computed as a 3D SSIM on the GPU; it is called by func:`calculate_ssim`.
234 | 
235 | Args:
236 | img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
237 | img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
238 | 
239 | Returns:
240 | float: ssim result.
241 | """
242 | assert len(img1.shape) == 3 and len(img2.shape) == 3
243 | C1 = (0.01 * max_value) ** 2
244 | C2 = (0.03 * max_value) ** 2
245 | img1 = img1.astype(np.float64)
246 | img2 = img2.astype(np.float64)
247 |
248 | kernel = _generate_3d_gaussian_kernel().cuda()
249 |
250 | img1 = torch.tensor(img1).float().cuda()
251 | img2 = torch.tensor(img2).float().cuda()
252 |
253 |
254 | mu1 = _3d_gaussian_calculator(img1, kernel)
255 | mu2 = _3d_gaussian_calculator(img2, kernel)
256 |
257 | mu1_sq = mu1 ** 2
258 | mu2_sq = mu2 ** 2
259 | mu1_mu2 = mu1 * mu2
260 | sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq
261 | sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq
262 | sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2
263 |
264 | ssim_map = ((2 * mu1_mu2 + C1) *
265 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
266 | (sigma1_sq + sigma2_sq + C2))
267 | return float(ssim_map.mean())
268 |
269 | def _ssim_cly(img1, img2):
270 | """Calculate SSIM (structural similarity) for single-channel (2D) images.
271 | 
272 | It is called by func:`calculate_ssim`.
273 | 
274 | Args:
275 | img1 (ndarray): Image with range [0, 255], shape (H, W).
276 | img2 (ndarray): Image with range [0, 255], shape (H, W).
277 | 
278 | Returns:
279 | float: ssim result.
280 | """
281 | assert len(img1.shape) == 2 and len(img2.shape) == 2
282 |
283 | C1 = (0.01 * 255)**2
284 | C2 = (0.03 * 255)**2
285 | img1 = img1.astype(np.float64)
286 | img2 = img2.astype(np.float64)
287 |
288 | kernel = cv2.getGaussianKernel(11, 1.5)
289 | # print(kernel)
290 | window = np.outer(kernel, kernel.transpose())
291 |
292 | bt = cv2.BORDER_REPLICATE
293 |
294 | mu1 = cv2.filter2D(img1, -1, window, borderType=bt)
295 | mu2 = cv2.filter2D(img2, -1, window,borderType=bt)
296 |
297 | mu1_sq = mu1**2
298 | mu2_sq = mu2**2
299 | mu1_mu2 = mu1 * mu2
300 | sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq
301 | sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq
302 | sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2
303 |
304 | ssim_map = ((2 * mu1_mu2 + C1) *
305 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
306 | (sigma1_sq + sigma2_sq + C2))
307 | return ssim_map.mean()
308 | def reorder_image(img, input_order='HWC'):
309 | """Reorder images to 'HWC' order.
310 |
311 | If the input_order is (h, w), return (h, w, 1);
312 | If the input_order is (c, h, w), return (h, w, c);
313 | If the input_order is (h, w, c), return as it is.
314 |
315 | Args:
316 | img (ndarray): Input image.
317 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
318 | If the input image shape is (h, w), input_order will not have
319 | effects. Default: 'HWC'.
320 |
321 | Returns:
322 | ndarray: reordered image.
323 | """
324 |
325 | if input_order not in ['HWC', 'CHW']:
326 | raise ValueError(
327 | f'Wrong input_order {input_order}. Supported input_orders are '
328 | "'HWC' and 'CHW'")
329 | if len(img.shape) == 2:
330 | img = img[..., None]
331 | if input_order == 'CHW':
332 | img = img.transpose(1, 2, 0)
333 | return img
334 |
335 |
336 | def calculate_ssim(img1,
337 | img2,
338 | crop_border,
339 | input_order='HWC',
340 | test_y_channel=True):
341 | """Calculate SSIM (structural similarity).
342 |
343 | Ref:
344 | Image quality assessment: From error visibility to structural similarity
345 |
346 | The results are the same as that of the official released MATLAB code in
347 | https://ece.uwaterloo.ca/~z70wang/research/ssim/.
348 |
349 | For three-channel images, SSIM is calculated for each channel and then
350 | averaged.
351 |
352 | Args:
353 | img1 (ndarray): Images with range [0, 255].
354 | img2 (ndarray): Images with range [0, 255].
355 | crop_border (int): Cropped pixels in each edge of an image. These
356 | pixels are not involved in the SSIM calculation.
357 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
358 | Default: 'HWC'.
359 | test_y_channel (bool): Test on Y channel of YCbCr. Default: True.
360 |
361 | Returns:
362 | float: ssim result.
363 | """
364 |
365 | assert img1.shape == img2.shape, (
366 | f'Image shapes are different: {img1.shape}, {img2.shape}.')
367 | if input_order not in ['HWC', 'CHW']:
368 | raise ValueError(
369 | f'Wrong input_order {input_order}. Supported input_orders are '
370 | '"HWC" and "CHW"')
371 |
372 | if type(img1) == torch.Tensor:
373 | if len(img1.shape) == 4:
374 | img1 = img1.squeeze(0)
375 | img1 = img1.detach().cpu().numpy().transpose(1,2,0)
376 | if type(img2) == torch.Tensor:
377 | if len(img2.shape) == 4:
378 | img2 = img2.squeeze(0)
379 | img2 = img2.detach().cpu().numpy().transpose(1,2,0)
380 |
381 | img1 = reorder_image(img1, input_order=input_order)
382 | img2 = reorder_image(img2, input_order=input_order)
383 |
384 | img1 = img1.astype(np.float64)
385 | img2 = img2.astype(np.float64)
386 |
387 | if crop_border != 0:
388 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
389 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
390 |
391 | if test_y_channel:
392 | img1 = to_y_channel(img1)
393 | img2 = to_y_channel(img2)
394 | return _ssim_cly(img1[..., 0], img2[..., 0])
395 |
396 |
397 | ssims = []
398 | max_value = 1 if img1.max() <= 1 else 255
399 | with torch.no_grad():
400 | final_ssim = _ssim_3d(img1, img2, max_value)
401 | ssims.append(final_ssim)
402 | 
403 | return np.array(ssims).mean()
--------------------------------------------------------------------------------
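A minimal usage sketch for these metrics (not part of the repository; the file paths are placeholders), mirroring how `eval.py` calls them:

```python
# Sketch: compare one restored image against its ground truth, as eval.py does.
import cv2
from cal import calculate_psnr, calculate_ssim

restored = cv2.imread('results/GoPro/0001.png')             # placeholder path
target = cv2.imread('Datasets/GoPro/test/target/0001.png')  # placeholder path

# crop_border=0 evaluates the full image; test_y_channel=True (the default)
# measures on the Y channel of YCbCr.
psnr = calculate_psnr(restored, target, 0)
ssim = calculate_ssim(restored, target, 0)
print(f'PSNR: {psnr:.2f} dB  SSIM: {ssim:.4f}')
```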
/Deblur/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | r"""This module provides package-wide configuration management."""
6 | from typing import Any, List
7 |
8 | from yacs.config import CfgNode as CN
9 |
10 |
11 | class Config(object):
12 | r"""
13 | A collection of all the required configuration parameters. This class is a nested dict-like
14 | structure, with nested keys accessible as attributes. It contains sensible default values for
15 | all the parameters, which may be overridden (first) through a YAML file and (second) through
16 | a list of attributes and values.
17 |
18 | Extended Summary
19 | ----------------
20 | This class definition contains default values corresponding to ``joint_training`` phase, as it
21 | is the final training phase and uses almost all the configuration parameters. Modification of
22 | any parameter after instantiating this class is not possible, so you must override required
23 | parameter values either through the ``config_yaml`` file or the ``config_override`` list.
24 |
25 | Parameters
26 | ----------
27 | config_yaml: str
28 | Path to a YAML file containing configuration parameters to override.
29 | config_override: List[Any], optional (default= [])
30 | A list of sequential attributes and values of parameters to override. This happens after
31 | overriding from YAML file.
32 |
33 | Examples
34 | --------
35 | Let a YAML file named "config.yaml" specify these parameters to override::
36 |
37 | ALPHA: 1000.0
38 | BETA: 0.5
39 |
40 | >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7])
41 | >>> _C.ALPHA # default: 100.0
42 | 1000.0
43 | >>> _C.OPTIM.BATCH_SIZE # default: 256
44 | 2048
45 | >>> _C.BETA # default: 0.1
46 | 0.7
47 |
48 | Attributes
49 | ----------
50 | """
51 |
52 | def __init__(self, config_yaml: str, config_override: List[Any] = []):
53 |
54 | self._C = CN()
55 | self._C.GPU = [0]
56 | self._C.VERBOSE = False
57 |
58 | self._C.MODEL = CN()
59 | self._C.MODEL.MODE = 'global'
60 | self._C.MODEL.SESSION = 'ps128_bs1'
61 |
62 | self._C.OPTIM = CN()
63 | self._C.OPTIM.BATCH_SIZE = 1
64 | self._C.OPTIM.NUM_EPOCHS = 100
65 | self._C.OPTIM.NEPOCH_DECAY = [100]
66 | self._C.OPTIM.LR_INITIAL = 0.0002
67 | self._C.OPTIM.LR_MIN = 0.0002
68 | self._C.OPTIM.BETA1 = 0.5
69 |
70 | self._C.TRAINING = CN()
71 | self._C.TRAINING.VAL_AFTER_EVERY = 3
72 | self._C.TRAINING.RESUME = False
73 | self._C.TRAINING.SAVE_IMAGES = False
74 | self._C.TRAINING.TRAIN_DIR = 'images_dir/train'
75 | self._C.TRAINING.VAL_DIR = 'images_dir/val'
76 | self._C.TRAINING.SAVE_DIR = 'checkpoints'
77 | self._C.TRAINING.TRAIN_PS = 64
78 | self._C.TRAINING.VAL_PS = 64
79 |
80 | # Override parameter values from YAML file first, then from override list.
81 | self._C.merge_from_file(config_yaml)
82 | self._C.merge_from_list(config_override)
83 |
84 | # Make an instantiated object of this class immutable.
85 | self._C.freeze()
86 |
87 | def dump(self, file_path: str):
88 | r"""Save config at the specified file path.
89 |
90 | Parameters
91 | ----------
92 | file_path: str
93 | (YAML) path to save config at.
94 | """
95 | self._C.dump(stream=open(file_path, "w"))
96 |
97 | def __getattr__(self, attr: str):
98 | return self._C.__getattr__(attr)
99 |
100 | def __repr__(self):
101 | return self._C.__repr__()
102 |
--------------------------------------------------------------------------------
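A minimal usage sketch, assuming `trmash.yml` sits next to the script: overrides are applied YAML first, then the list, and the resulting config is frozen.

```python
from config import Config

# Defaults from __init__ are overridden by trmash.yml, then by this list.
opt = Config('trmash.yml', ['OPTIM.BATCH_SIZE', 16, 'TRAINING.RESUME', False])
print(opt.MODEL.SESSION)     # 'MHNet' (set in trmash.yml)
print(opt.OPTIM.BATCH_SIZE)  # 16 (from the override list)

# The frozen config can be snapshotted for reproducibility.
opt.dump('config_snapshot.yml')
```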
/Deblur/data_RGB.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest
3 |
4 | def get_training_data(rgb_dir, img_options):
5 | assert os.path.exists(rgb_dir)
6 | return DataLoaderTrain(rgb_dir, img_options)
7 |
8 | def get_validation_data(rgb_dir, img_options):
9 | assert os.path.exists(rgb_dir)
10 | return DataLoaderVal(rgb_dir, img_options)
11 |
12 | def get_test_data(rgb_dir, img_options):
13 | assert os.path.exists(rgb_dir)
14 | return DataLoaderTest(rgb_dir, img_options)
15 |
--------------------------------------------------------------------------------
/Deblur/dataset_RGB.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import torch
5 | from PIL import Image
6 | import torchvision.transforms.functional as TF
7 | from pdb import set_trace as stx
8 | import random
9 | import utils
10 |
11 |
12 | def is_image_file(filename):
13 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
14 |
15 |
16 | class DataLoaderTrain(Dataset):
17 | def __init__(self, rgb_dir, img_options=None):
18 | super(DataLoaderTrain, self).__init__()
19 |
20 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
21 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
22 |
23 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
24 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
25 |
26 | self.img_options = img_options
27 | self.sizex = len(self.tar_filenames) # get the size of target
28 |
29 | self.ps = self.img_options['patch_size']
30 |
31 | def __len__(self):
32 | return self.sizex
33 |
34 | def __getitem__(self, index):
35 | index_ = index % self.sizex
36 | ps = self.ps
37 |
38 | inp_path = self.inp_filenames[index_]
39 | tar_path = self.tar_filenames[index_]
40 |
41 | inp_img = Image.open(inp_path)
42 | tar_img = Image.open(tar_path)
43 |
44 | w, h = tar_img.size
45 | padw = ps - w if w < ps else 0
46 | padh = ps - h if h < ps else 0
47 |
48 | # Reflect Pad in case image is smaller than patch_size
49 | if padw != 0 or padh != 0:
50 | inp_img = TF.pad(inp_img, (0, 0, padw, padh), padding_mode='reflect')
51 | tar_img = TF.pad(tar_img, (0, 0, padw, padh), padding_mode='reflect')
52 |
53 |
54 | inp_img = TF.to_tensor(inp_img)
55 | tar_img = TF.to_tensor(tar_img)
56 |
57 | hh, ww = tar_img.shape[1], tar_img.shape[2]
58 |
59 | rr = random.randint(0, hh - ps)
60 | cc = random.randint(0, ww - ps)
61 | aug = random.randint(0, 8)
62 |
63 | # Crop patch
64 | inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
65 | tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]
66 |
67 | # Data Augmentations
68 | if aug == 1:
69 | inp_img = inp_img.flip(1)
70 | tar_img = tar_img.flip(1)
71 | elif aug == 2:
72 | inp_img = inp_img.flip(2)
73 | tar_img = tar_img.flip(2)
74 | elif aug == 3:
75 | inp_img = torch.rot90(inp_img, dims=(1, 2))
76 | tar_img = torch.rot90(tar_img, dims=(1, 2))
77 | elif aug == 4:
78 | inp_img = torch.rot90(inp_img, dims=(1, 2), k=2)
79 | tar_img = torch.rot90(tar_img, dims=(1, 2), k=2)
80 | elif aug == 5:
81 | inp_img = torch.rot90(inp_img, dims=(1, 2), k=3)
82 | tar_img = torch.rot90(tar_img, dims=(1, 2), k=3)
83 | elif aug == 6:
84 | inp_img = torch.rot90(inp_img.flip(1), dims=(1, 2))
85 | tar_img = torch.rot90(tar_img.flip(1), dims=(1, 2))
86 | elif aug == 7:
87 | inp_img = torch.rot90(inp_img.flip(2), dims=(1, 2))
88 | tar_img = torch.rot90(tar_img.flip(2), dims=(1, 2))
89 |
90 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
91 |
92 | return tar_img, inp_img, filename
93 |
94 |
95 | class DataLoaderVal(Dataset):
96 | def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
97 | super(DataLoaderVal, self).__init__()
98 |
99 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
100 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
101 |
102 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
103 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
104 |
105 | self.img_options = img_options
106 | self.sizex = len(self.tar_filenames) # get the size of target
107 |
108 | self.ps = self.img_options['patch_size']
109 |
110 | def __len__(self):
111 | return self.sizex
112 |
113 | def __getitem__(self, index):
114 | index_ = index % self.sizex
115 | ps = self.ps
116 |
117 | inp_path = self.inp_filenames[index_]
118 | tar_path = self.tar_filenames[index_]
119 |
120 | inp_img = Image.open(inp_path)
121 | tar_img = Image.open(tar_path)
122 |
123 | # Validate on center crop
124 | if self.ps is not None:
125 | inp_img = TF.center_crop(inp_img, (ps, ps))
126 | tar_img = TF.center_crop(tar_img, (ps, ps))
127 |
128 | inp_img = TF.to_tensor(inp_img)
129 | tar_img = TF.to_tensor(tar_img)
130 |
131 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
132 |
133 |
134 | return tar_img, inp_img, filename
135 |
136 |
137 | class DataLoaderTest(Dataset):
138 | def __init__(self, rgb_dir, img_options):
139 | super(DataLoaderTest, self).__init__()
140 |
141 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
142 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
143 |
144 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
145 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
146 |
147 | self.inp_size = len(self.inp_filenames)
148 | self.img_options = img_options
149 |
150 | def __len__(self):
151 | return self.inp_size
152 |
153 | def __getitem__(self, index):
154 | path_inp = self.inp_filenames[index]
155 | tar_path = self.tar_filenames[index]
156 | filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
157 | inp = Image.open(path_inp)
158 | tar_img = Image.open(tar_path)
159 |
160 | inp = TF.to_tensor(inp)
161 | tar_img = TF.to_tensor(tar_img)
162 | return inp, filename, tar_img
163 |
164 |
--------------------------------------------------------------------------------
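A minimal sketch of the training data pipeline (paths are placeholders; it assumes `input/` and `target/` subfolders as in the GoPro layout), mirroring how `train.py` wires these pieces together:

```python
from torch.utils.data import DataLoader
from data_RGB import get_training_data

train_dataset = get_training_data('./Datasets/GoPro/train', {'patch_size': 256})
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                          num_workers=4, pin_memory=True)

# Note the return order of DataLoaderTrain.__getitem__: (target, input, filename).
target, input_, filenames = next(iter(train_loader))
print(target.shape, input_.shape)  # torch.Size([4, 3, 256, 256]) for both
```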
/Deblur/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from glob import glob
4 | import cv2
5 | from natsort import natsorted
6 |
7 | from skimage.metrics import structural_similarity,peak_signal_noise_ratio
8 | from cal import calculate_psnr,calculate_ssim
9 |
10 | def read_img(path):
11 | return cv2.imread(path)
12 |
13 |
14 |
15 |
16 | def main():
17 | dataset = 'GoPro'  # or 'HIDE'
18 | file_path = os.path.join('results', dataset)
19 | gt_path = os.path.join('Datasets', dataset, 'test', 'target')
20 | print(file_path)
21 | print(gt_path)
22 |
23 | path_fake = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg')))
24 | path_real = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg')))
25 | print(len(path_fake))
26 | list_psnr = []
27 | list_ssim = []
28 | list_mse = []
29 |
30 | for i in range(len(path_real)):
31 | t1 = read_img(path_real[i])
32 | t2 = read_img(path_fake[i])
33 | #result1 = np.zeros(t1.shape,dtype=np.float32)
34 | #result2 = np.zeros(t2.shape,dtype=np.float32)
35 | #cv2.normalize(t1,result1,alpha=0,beta=1,norm_type=cv2.NORM_MINMAX,dtype=cv2.CV_32F)
36 | #cv2.normalize(t2,result2,alpha=0,beta=1,norm_type=cv2.NORM_MINMAX,dtype=cv2.CV_32F)
37 |
38 |
39 |
40 | psnr_num = calculate_psnr(t1, t2,0)
41 | ssim_num = calculate_ssim(t1, t2,0)
42 |
43 | list_ssim.append(ssim_num)
44 | list_psnr.append(psnr_num)
45 |
46 |
47 |
48 | print("AverSSIM:", np.mean(list_ssim)) # ,list_ssim)
49 | print("AverPSNR:", np.mean(list_psnr)) # ,list_ssim)
50 |
51 |
52 | if __name__ == '__main__':
53 | main()
54 |
--------------------------------------------------------------------------------
/Deblur/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class CharbonnierLoss(nn.Module):
7 | """Charbonnier Loss (L1)"""
8 |
9 | def __init__(self, eps=1e-3):
10 | super(CharbonnierLoss, self).__init__()
11 | self.eps = eps
12 |
13 | def forward(self, x, y):
14 | diff = x - y
15 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
16 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
17 | return loss
18 |
19 | class EdgeLoss(nn.Module):
20 | def __init__(self):
21 | super(EdgeLoss, self).__init__()
22 | k = torch.Tensor([[.05, .25, .4, .25, .05]])
23 | self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
24 | if torch.cuda.is_available():
25 | self.kernel = self.kernel.cuda()
26 | self.loss = CharbonnierLoss()
27 |
28 | def conv_gauss(self, img):
29 | n_channels, _, kw, kh = self.kernel.shape
30 | img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
31 | return F.conv2d(img, self.kernel, groups=n_channels)
32 |
33 | def laplacian_kernel(self, current):
34 | filtered = self.conv_gauss(current) # filter
35 | down = filtered[:,:,::2,::2] # downsample
36 | new_filter = torch.zeros_like(filtered)
37 | new_filter[:,:,::2,::2] = down*4 # upsample
38 | filtered = self.conv_gauss(new_filter) # filter
39 | diff = current - filtered
40 | return diff
41 |
42 | def forward(self, x, y):
43 | loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
44 | return loss
45 |
46 |
47 | class PSNRLoss(nn.Module):
48 |
49 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
50 | super(PSNRLoss, self).__init__()
51 | assert reduction == 'mean'
52 | self.loss_weight = loss_weight
53 | self.scale = 10 / np.log(10)
54 | self.toY = toY
55 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
56 | self.first = True
57 |
58 | def forward(self, pred, target):
59 | assert len(pred.size()) == 4
60 | if self.toY:
61 | if self.first:
62 | self.coef = self.coef.to(pred.device)
63 | self.first = False
64 |
65 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
66 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
67 |
68 | pred, target = pred / 255., target / 255.
69 | 
70 | assert len(pred.size()) == 4
71 |
72 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
73 |
--------------------------------------------------------------------------------
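A minimal sketch of the PSNR loss that `train.py` optimizes: the returned value is `10*log10(MSE + 1e-8)`, i.e. the negative of the PSNR (up to the epsilon), so minimizing it maximizes PSNR.

```python
import torch
from losses import PSNRLoss

criterion = PSNRLoss()
pred = torch.rand(2, 3, 64, 64, requires_grad=True)  # NCHW in [0, 1]
target = torch.rand(2, 3, 64, 64)

loss = criterion(pred, target)  # negative when MSE < 1; lower = higher PSNR
loss.backward()
```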
/Deblur/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch.nn as nn
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import torch.nn.functional as F
10 | import utils
11 |
12 | from data_RGB import get_test_data
13 | from MHNet import MHNet
14 | from skimage import img_as_ubyte
15 | from pdb import set_trace as stx
16 |
17 | parser = argparse.ArgumentParser(description='Image Deblurring using MHNet')
18 |
19 | parser.add_argument('--input_dir', default='./Datasets/', type=str, help='Directory of test datasets')
20 | parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
21 | parser.add_argument('--weights', default='./pre-trained/model_best.pth', type=str, help='Path to weights')
22 | parser.add_argument('--gpus', default='2', type=str, help='CUDA_VISIBLE_DEVICES')
23 | parser.add_argument('--dataset', default=None, type=str, help='Test a single dataset (GoPro or HIDE); default: all')
24 | args = parser.parse_args()
25 |
26 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
27 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
28 |
29 | model_restoration = MHNet()
30 |
31 | utils.load_checkpoint(model_restoration,args.weights)
32 | print("===>Testing using weights: ",args.weights)
33 | model_restoration.cuda()
34 | model_restoration = nn.DataParallel(model_restoration)
35 | model_restoration.eval()
36 |
37 | datasets = ['GoPro', 'HIDE'] if args.dataset is None else [args.dataset]
38 | 
39 |
40 | for dataset in datasets:
41 | rgb_dir_test = os.path.join(args.input_dir, dataset, 'test')
42 | test_dataset = get_test_data(rgb_dir_test, img_options={})
43 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)
44 |
45 | result_dir = os.path.join(args.result_dir, dataset)
46 | utils.mkdir(result_dir)
47 |
48 | with torch.no_grad():
49 | for ii, data_test in enumerate(tqdm(test_loader), 0):
50 | torch.cuda.ipc_collect()
51 | torch.cuda.empty_cache()
52 |
53 | input_ = data_test[0].cuda()
54 | filenames = data_test[1]
55 |
56 | restored = model_restoration(input_)
57 | restored = torch.clamp(restored[0],0,1)
58 |
59 | restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
60 |
61 | for batch in range(len(restored)):
62 | restored_img = img_as_ubyte(restored[batch])
63 | utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img)
64 |
--------------------------------------------------------------------------------
/Deblur/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | 
4 |
5 | import os
6 | from config import Config
7 |
8 | opt = Config('trmash.yml')
9 |
10 | gpus = ','.join([str(i) for i in opt.GPU])
11 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
12 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus
13 |
14 | import torch
15 |
16 | torch.backends.cudnn.benchmark = True
17 |
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import torch.optim as optim
21 | from torch.utils.data import DataLoader
22 | 
23 |
24 | import random
25 | import time
26 | import numpy as np
27 | from pathlib import Path
28 |
29 | import utils
30 | from data_RGB import get_training_data, get_validation_data
31 | from MHNet import MHNet
32 | import losses
33 | from warmup_scheduler import GradualWarmupScheduler
34 | from tqdm import tqdm
35 | from pdb import set_trace as stx
36 |
37 |
38 | dir_checkpoint = Path('./mhnetmash/')
39 |
40 | def train():
41 |
42 | ######### Set Seeds ###########
43 | random.seed(1234)
44 | np.random.seed(1234)
45 | torch.manual_seed(42)
46 | torch.cuda.manual_seed_all(42)
47 |
48 | start_epoch = 1
49 | mode = opt.MODEL.MODE
50 | session = opt.MODEL.SESSION
51 |
52 | result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session)
53 | model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
54 |
55 | utils.mkdir(result_dir)
56 | utils.mkdir(model_dir)
57 |
58 | train_dir = opt.TRAINING.TRAIN_DIR
59 | val_dir = opt.TRAINING.VAL_DIR
60 |
61 | ######### Model ###########
62 | model_restoration = MHNet()
63 | print("Total number of param is ", sum(x.numel() for x in model_restoration.parameters()))
64 | model_restoration.cuda()
65 |
66 | device_ids = [i for i in range(torch.cuda.device_count())]
67 | if torch.cuda.device_count() > 1:
68 | print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
69 |
70 |
71 | new_lr = opt.OPTIM.LR_INITIAL
72 |
73 | optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8)
74 |
75 |
76 | ######### Scheduler ###########
77 | warmup_epochs = 3
78 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN)
79 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
80 | scheduler.step()
81 |
82 | ######### Resume ###########
83 | if opt.TRAINING.RESUME:
84 | path_chk_rest = './mhnetmash/model_best.pth'
85 | utils.load_checkpoint(model_restoration,path_chk_rest)
86 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1
87 | utils.load_optim(optimizer, path_chk_rest)
88 |
89 | for i in range(1, start_epoch):
90 | scheduler.step()
91 | new_lr = scheduler.get_lr()[0]
92 | print('------------------------------------------------------------------------------')
93 | print("==> Resuming Training with learning rate:", new_lr)
94 | print('------------------------------------------------------------------------------')
95 |
96 | if len(device_ids)>1:
97 | model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids)
98 | print("duoka")
99 |
100 | ######### Loss ###########
101 | criterion_psnr = losses.PSNRLoss()
102 | ######### DataLoaders ###########
103 | train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
104 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)
105 |
106 | val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
107 | val_loader = DataLoader(dataset=val_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
108 |
109 |
110 |
111 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.OPTIM.NUM_EPOCHS))
112 | print('===> Loading datasets')
113 |
114 | best_psnr = 0
115 | best_epoch = 0
116 | global_step = 0
117 |
118 | for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
119 | epoch_start_time = time.time()
120 | epoch_loss = 0
121 | psnr_train_rgb = []
122 | psnr_train_rgb1 = []
123 | psnr_tr = 0
124 | psnr_tr1 = 0
125 | model_restoration.train()
126 | for i, data in enumerate(tqdm(train_loader), 0):
127 |
128 | # zero_grad
129 | for param in model_restoration.parameters():
130 | param.grad = None
131 |
132 | target = data[0].cuda()
133 | input_ = data[1].cuda()
134 |
135 | restored = model_restoration(input_)
136 |
137 | loss = criterion_psnr(restored[0], target)
138 | loss.backward()
139 | optimizer.step()
140 | epoch_loss += loss.item()
141 | global_step = global_step+1
142 |
143 | psnr_te = 0
144 | psnr_te_1 = 0
145 | ssim_te_1 = 0
146 | #### Evaluation ####
147 | if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0:
148 | model_restoration.eval()
149 | psnr_val_rgb = []
150 | psnr_val_rgb1 = []
151 | for ii, data_val in enumerate((val_loader), 0):
152 | target = data_val[0].cuda()
153 | input_ = data_val[1].cuda()
154 |
155 | with torch.no_grad():
156 | restored = model_restoration(input_)
157 | restore = restored[0]
158 |
159 | for res, tar in zip(restore, target):
160 | tssss = utils.torchPSNR(res, tar)
161 | psnr_te = psnr_te + tssss
162 | psnr_val_rgb.append(utils.torchPSNR(res, tar))
163 |
164 | psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
165 | print("te", psnr_te)
166 |
167 | if psnr_val_rgb > best_psnr:
168 | best_psnr = psnr_val_rgb
169 | best_epoch = epoch
170 | Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
171 | torch.save({'epoch': epoch,
172 | 'state_dict': model_restoration.state_dict(),
173 | 'optimizer': optimizer.state_dict()
174 | }, str(dir_checkpoint / "model_best.pth"))
175 |
176 |
177 | print("[epoch %d PSNR: %.4f best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
178 | Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
179 | torch.save({'epoch': epoch,
180 | 'state_dict': model_restoration.state_dict(),
181 | 'optimizer': optimizer.state_dict()
182 | }, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
183 |
184 | scheduler.step()
185 |
186 | print("------------------------------------------------------------------")
187 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time,
188 | epoch_loss, scheduler.get_lr()[0]))
189 | print("------------------------------------------------------------------")
190 |
191 |
192 | if __name__=='__main__':
193 | train()
194 |
195 |
--------------------------------------------------------------------------------
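A minimal sketch of the learning-rate schedule built in `train()` (the epoch count here is illustrative, not the value from `trmash.yml`): linear warmup for three epochs, then cosine annealing down to `LR_MIN`.

```python
import torch
from torch import optim
from warmup_scheduler import GradualWarmupScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.Adam(params, lr=2e-4, betas=(0.9, 0.999), eps=1e-8)

num_epochs, warmup_epochs = 100, 3  # illustrative epoch count
cosine = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, num_epochs - warmup_epochs, eta_min=1e-6)
scheduler = GradualWarmupScheduler(
    optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=cosine)

for epoch in range(1, 11):
    scheduler.step()
    print(epoch, scheduler.get_lr()[0])  # ramps up to 2e-4, then decays
```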
/Deblur/trmash.yml:
--------------------------------------------------------------------------------
1 | # Training configuration for MHNet deblurring (GoPro)
2 | 
3 | 
4 |
5 |
6 | GPU: [0,1,2,3]
7 |
8 | VERBOSE: True
9 |
10 | MODEL:
11 | MODE: 'Deblurring'
12 | SESSION: 'MHNet'
13 |
14 | # Optimization arguments.
15 | OPTIM:
16 | BATCH_SIZE: 32
17 | NUM_EPOCHS: 45000000000  # effectively unbounded; training is stopped manually
18 | # NEPOCH_DECAY: [10]
19 | LR_INITIAL: 2e-4
20 | LR_MIN: 1e-6
21 | # BETA1: 0.9
22 |
23 | TRAINING:
24 | VAL_AFTER_EVERY: 15
25 | RESUME: True
26 | TRAIN_PS: 256
27 | VAL_PS: 256
28 | TRAIN_DIR: './Datasets/GoPro/train' # path to training data
29 | VAL_DIR: './Datasets/GoPro/test' # path to validation data
30 | SAVE_DIR: './checkpoints' # path to save models and images
31 | # SAVE_IMAGES: False
32 |
--------------------------------------------------------------------------------
/Deblur/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dir_utils import *
2 | from .image_utils import *
3 | from .model_utils import *
4 | from .dataset_utils import *
5 | from .logger import (MessageLogger, get_env_info, get_root_logger,
6 | init_tb_logger, init_wandb_logger)
7 |
8 |
9 | __all__ = [
10 | # file_client.py
11 | 'FileClient',
12 | # img_util.py
13 | 'img2tensor',
14 | 'tensor2img',
15 | 'imfrombytes',
16 | 'imwrite',
17 | 'crop_border',
18 | # logger.py
19 | 'MessageLogger',
20 | 'init_tb_logger',
21 | 'init_wandb_logger',
22 | 'get_root_logger',
23 | 'get_env_info',
24 | # misc.py
25 | 'set_random_seed',
26 | 'get_time_str',
27 | 'mkdir_and_rename',
28 | 'make_exp_dirs',
29 | 'scandir',
30 | 'scandir_SIDD',
31 | 'check_resume',
32 | 'sizeof_fmt',
33 | 'padding',
34 | 'create_lmdb_for_reds',
35 | 'create_lmdb_for_gopro',
36 | 'create_lmdb_for_rain13k',
37 | ]
38 |
39 |
40 |
--------------------------------------------------------------------------------
/Deblur/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/Deblur/utils/__pycache__/arch_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/arch_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Deblur/utils/__pycache__/dataset_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/dataset_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Deblur/utils/__pycache__/dir_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/dir_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Deblur/utils/__pycache__/dist_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/dist_util.cpython-38.pyc
--------------------------------------------------------------------------------
/Deblur/utils/__pycache__/image_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/image_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Deblur/utils/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/Deblur/utils/__pycache__/model_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/model_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Deblur/utils/arch_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn as nn
4 | from torch.nn import functional as F
5 | from torch.nn import init as init
6 | from torch.nn.modules.batchnorm import _BatchNorm
7 |
8 | from utils import get_root_logger
9 |
10 | # try:
11 | # from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
12 | # modulated_deform_conv)
13 | # except ImportError:
14 | # # print('Cannot import dcn. Ignore this warning if dcn is not used. '
15 | # # 'Otherwise install BasicSR with compiling dcn.')
16 | #
17 |
18 | @torch.no_grad()
19 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
20 | """Initialize network weights.
21 | Args:
22 | module_list (list[nn.Module] | nn.Module): Modules to be initialized.
23 | scale (float): Scale initialized weights, especially for residual
24 | blocks. Default: 1.
25 | bias_fill (float): The value to fill bias. Default: 0
26 | kwargs (dict): Other arguments for initialization function.
27 | """
28 | if not isinstance(module_list, list):
29 | module_list = [module_list]
30 | for module in module_list:
31 | for m in module.modules():
32 | if isinstance(m, nn.Conv2d):
33 | init.kaiming_normal_(m.weight, **kwargs)
34 | m.weight.data *= scale
35 | if m.bias is not None:
36 | m.bias.data.fill_(bias_fill)
37 | elif isinstance(m, nn.Linear):
38 | init.kaiming_normal_(m.weight, **kwargs)
39 | m.weight.data *= scale
40 | if m.bias is not None:
41 | m.bias.data.fill_(bias_fill)
42 | elif isinstance(m, _BatchNorm):
43 | init.constant_(m.weight, 1)
44 | if m.bias is not None:
45 | m.bias.data.fill_(bias_fill)
46 |
47 |
48 | def make_layer(basic_block, num_basic_block, **kwarg):
49 | """Make layers by stacking the same blocks.
50 | Args:
51 | basic_block (nn.module): nn.module class for basic block.
52 | num_basic_block (int): number of blocks.
53 | Returns:
54 | nn.Sequential: Stacked blocks in nn.Sequential.
55 | """
56 | layers = []
57 | for _ in range(num_basic_block):
58 | layers.append(basic_block(**kwarg))
59 | return nn.Sequential(*layers)
60 |
61 |
62 | class ResidualBlockNoBN(nn.Module):
63 | """Residual block without BN.
64 | It has a style of:
65 | ---Conv-ReLU-Conv-+-
66 | |________________|
67 | Args:
68 | num_feat (int): Channel number of intermediate features.
69 | Default: 64.
70 | res_scale (float): Residual scale. Default: 1.
71 | pytorch_init (bool): If set to True, use pytorch default init,
72 | otherwise, use default_init_weights. Default: False.
73 | """
74 |
75 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
76 | super(ResidualBlockNoBN, self).__init__()
77 | self.res_scale = res_scale
78 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
79 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
80 | self.relu = nn.ReLU(inplace=True)
81 |
82 | if not pytorch_init:
83 | default_init_weights([self.conv1, self.conv2], 0.1)
84 |
85 | def forward(self, x):
86 | identity = x
87 | out = self.conv2(self.relu(self.conv1(x)))
88 | return identity + out * self.res_scale
89 |
90 |
91 | class Upsample(nn.Sequential):
92 | """Upsample module.
93 | Args:
94 | scale (int): Scale factor. Supported scales: 2^n and 3.
95 | num_feat (int): Channel number of intermediate features.
96 | """
97 |
98 | def __init__(self, scale, num_feat):
99 | m = []
100 | if (scale & (scale - 1)) == 0: # scale = 2^n
101 | for _ in range(int(math.log(scale, 2))):
102 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
103 | m.append(nn.PixelShuffle(2))
104 | elif scale == 3:
105 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
106 | m.append(nn.PixelShuffle(3))
107 | else:
108 | raise ValueError(f'scale {scale} is not supported. '
109 | 'Supported scales: 2^n and 3.')
110 | super(Upsample, self).__init__(*m)
111 |
112 |
113 | def flow_warp(x,
114 | flow,
115 | interp_mode='bilinear',
116 | padding_mode='zeros',
117 | align_corners=True):
118 | """Warp an image or feature map with optical flow.
119 | Args:
120 | x (Tensor): Tensor with size (n, c, h, w).
121 | flow (Tensor): Tensor with size (n, h, w, 2), normal value.
122 | interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
123 | padding_mode (str): 'zeros' or 'border' or 'reflection'.
124 | Default: 'zeros'.
125 | align_corners (bool): Before pytorch 1.3, the default value is
126 | align_corners=True. After pytorch 1.3, the default value is
127 | align_corners=False. Here, we use True as the default.
128 | Returns:
129 | Tensor: Warped image or feature map.
130 | """
131 | assert x.size()[-2:] == flow.size()[1:3]
132 | _, _, h, w = x.size()
133 | # create mesh grid
134 | grid_y, grid_x = torch.meshgrid(
135 | torch.arange(0, h).type_as(x),
136 | torch.arange(0, w).type_as(x))
137 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
138 | grid.requires_grad = False
139 |
140 | vgrid = grid + flow
141 | # scale grid to [-1,1]
142 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
143 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
144 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
145 | output = F.grid_sample(
146 | x,
147 | vgrid_scaled,
148 | mode=interp_mode,
149 | padding_mode=padding_mode,
150 | align_corners=align_corners)
151 |
152 | # TODO, what if align_corners=False
153 | return output
154 |
155 |
156 | def resize_flow(flow,
157 | size_type,
158 | sizes,
159 | interp_mode='bilinear',
160 | align_corners=False):
161 | """Resize a flow according to ratio or shape.
162 | Args:
163 | flow (Tensor): Precomputed flow. shape [N, 2, H, W].
164 | size_type (str): 'ratio' or 'shape'.
165 | sizes (list[int | float]): the ratio for resizing or the final output
166 | shape.
167 | 1) The order of ratio should be [ratio_h, ratio_w]. For
168 | downsampling, the ratio should be smaller than 1.0 (i.e., ratio
169 | < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
170 | ratio > 1.0).
171 | 2) The order of output_size should be [out_h, out_w].
172 | interp_mode (str): The mode of interpolation for resizing.
173 | Default: 'bilinear'.
174 | align_corners (bool): Whether align corners. Default: False.
175 | Returns:
176 | Tensor: Resized flow.
177 | """
178 | _, _, flow_h, flow_w = flow.size()
179 | if size_type == 'ratio':
180 | output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
181 | elif size_type == 'shape':
182 | output_h, output_w = sizes[0], sizes[1]
183 | else:
184 | raise ValueError(
185 | f'Size type should be ratio or shape, but got type {size_type}.')
186 |
187 | input_flow = flow.clone()
188 | ratio_h = output_h / flow_h
189 | ratio_w = output_w / flow_w
190 | input_flow[:, 0, :, :] *= ratio_w
191 | input_flow[:, 1, :, :] *= ratio_h
192 | resized_flow = F.interpolate(
193 | input=input_flow,
194 | size=(output_h, output_w),
195 | mode=interp_mode,
196 | align_corners=align_corners)
197 | return resized_flow
198 |
199 |
200 | # TODO: may write a cpp file
201 | def pixel_unshuffle(x, scale):
202 | """ Pixel unshuffle.
203 | Args:
204 | x (Tensor): Input feature with shape (b, c, hh, hw).
205 | scale (int): Downsample ratio.
206 | Returns:
207 | Tensor: the pixel unshuffled feature.
208 | """
209 | b, c, hh, hw = x.size()
210 | out_channel = c * (scale**2)
211 | assert hh % scale == 0 and hw % scale == 0
212 | h = hh // scale
213 | w = hw // scale
214 | x_view = x.view(b, c, h, scale, w, scale)
215 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
216 |
217 |
218 | # class DCNv2Pack(ModulatedDeformConvPack):
219 | # """Modulated deformable conv for deformable alignment.
220 | #
221 | # Different from the official DCNv2Pack, which generates offsets and masks
222 | # from the preceding features, this DCNv2Pack takes another different
223 | # features to generate offsets and masks.
224 | #
225 | # Ref:
226 | # Delving Deep into Deformable Alignment in Video Super-Resolution.
227 | # """
228 | #
229 | # def forward(self, x, feat):
230 | # out = self.conv_offset(feat)
231 | # o1, o2, mask = torch.chunk(out, 3, dim=1)
232 | # offset = torch.cat((o1, o2), dim=1)
233 | # mask = torch.sigmoid(mask)
234 | #
235 | # offset_absmean = torch.mean(torch.abs(offset))
236 | # if offset_absmean > 50:
237 | # logger = get_root_logger()
238 | # logger.warning(
239 | # f'Offset abs mean is {offset_absmean}, larger than 50.')
240 | #
241 | # return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
242 | # self.stride, self.padding, self.dilation,
243 | # self.groups, self.deformable_groups)
244 |
245 |
246 | class LayerNormFunction(torch.autograd.Function):
247 |
248 | @staticmethod
249 | def forward(ctx, x, weight, bias, eps):
250 | ctx.eps = eps
251 | N, C, H, W = x.size()
252 | mu = x.mean(1, keepdim=True)
253 | var = (x - mu).pow(2).mean(1, keepdim=True)
254 | y = (x - mu) / (var + eps).sqrt()
255 | ctx.save_for_backward(y, var, weight)
256 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
257 | return y
258 |
259 | @staticmethod
260 | def backward(ctx, grad_output):
261 | eps = ctx.eps
262 |
263 | N, C, H, W = grad_output.size()
264 | y, var, weight = ctx.saved_variables
265 | g = grad_output * weight.view(1, C, 1, 1)
266 | mean_g = g.mean(dim=1, keepdim=True)
267 |
268 | mean_gy = (g * y).mean(dim=1, keepdim=True)
269 | gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
270 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
271 | dim=0), None
272 |
273 | class LayerNorm2d(nn.Module):
274 |
275 | def __init__(self, channels, eps=1e-6):
276 | super(LayerNorm2d, self).__init__()
277 | self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
278 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
279 | self.eps = eps
280 |
281 | def forward(self, x):
282 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
283 |
284 | # handle multiple input
285 | class MySequential(nn.Sequential):
286 | def forward(self, *inputs):
287 | for module in self._modules.values():
288 | if type(inputs) == tuple:
289 | inputs = module(*inputs)
290 | else:
291 | inputs = module(inputs)
292 | return inputs
293 |
294 | import time
295 | def measure_inference_speed(model, data, max_iter=200, log_interval=50):
296 | model.eval()
297 |
298 | # the first several iterations may be very slow so skip them
299 | num_warmup = 5
300 | pure_inf_time = 0
301 | fps = 0
302 |
303 | # benchmark over max_iter images (excluding the first num_warmup) and take the average
304 | for i in range(max_iter):
305 |
306 | torch.cuda.synchronize()
307 | start_time = time.perf_counter()
308 |
309 | with torch.no_grad():
310 | model(*data)
311 |
312 | torch.cuda.synchronize()
313 | elapsed = time.perf_counter() - start_time
314 |
315 | if i >= num_warmup:
316 | pure_inf_time += elapsed
317 | if (i + 1) % log_interval == 0:
318 | fps = (i + 1 - num_warmup) / pure_inf_time
319 | print(
320 | f'Done image [{i + 1:<3}/ {max_iter}], '
321 | f'fps: {fps:.1f} img / s, '
322 | f'time per image: {1000 / fps:.1f} ms / img',
323 | flush=True)
324 |
325 | if (i + 1) == max_iter:
326 | fps = (i + 1 - num_warmup) / pure_inf_time
327 | print(
328 | f'Overall fps: {fps:.1f} img / s, '
329 | f'time per image: {1000 / fps:.1f} ms / img',
330 | flush=True)
331 | break
332 | return fps
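# Editor's sketch (illustrative; the model choice and input shape are assumptions,
# not fixed by this file): benchmarking a restoration model on dummy data:
#   model = MHNet().cuda()                        # MHNet.py lives at the repo root
#   data = (torch.randn(1, 3, 256, 256).cuda(),)
#   fps = measure_inference_speed(model, data)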
--------------------------------------------------------------------------------
/Deblur/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class MixUp_AUG:
4 | def __init__(self):
5 | self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))
6 |
7 | def aug(self, rgb_gt, rgb_noisy):
8 | bs = rgb_gt.size(0)
9 | indices = torch.randperm(bs)
10 | rgb_gt2 = rgb_gt[indices]
11 | rgb_noisy2 = rgb_noisy[indices]
12 |
13 | lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
14 |
15 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
16 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
17 |
18 | return rgb_gt, rgb_noisy
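# Editor's sketch (illustrative): mixup on paired clean/degraded batches. Both
# tensors must already be CUDA tensors, since lam is sampled onto the GPU:
#   aug = MixUp_AUG()
#   gt, noisy = aug.aug(gt.cuda(), noisy.cuda())  # convex mixes with Beta(0.6, 0.6) weights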
--------------------------------------------------------------------------------
/Deblur/utils/dir_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from natsort import natsorted
3 | from glob import glob
4 |
5 | def mkdirs(paths):
6 | if isinstance(paths, list) and not isinstance(paths, str):
7 | for path in paths:
8 | mkdir(path)
9 | else:
10 | mkdir(paths)
11 |
12 | def mkdir(path):
13 | if not os.path.exists(path):
14 | os.makedirs(path)
15 |
16 | def get_last_path(path, session):
17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
18 | return x
--------------------------------------------------------------------------------
/Deblur/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 | import subprocess
4 | import torch
5 | import torch.distributed as dist
6 | import torch.multiprocessing as mp
7 |
8 |
9 | def init_dist(launcher, backend='nccl', **kwargs):
10 | if mp.get_start_method(allow_none=True) is None:
11 | mp.set_start_method('spawn')
12 | if launcher == 'pytorch':
13 | _init_dist_pytorch(backend, **kwargs)
14 | elif launcher == 'slurm':
15 | _init_dist_slurm(backend, **kwargs)
16 | else:
17 | raise ValueError(f'Invalid launcher type: {launcher}')
18 |
19 |
20 | def _init_dist_pytorch(backend, **kwargs):
21 | rank = int(os.environ['RANK'])
22 | num_gpus = torch.cuda.device_count()
23 | torch.cuda.set_device(rank % num_gpus)
24 | dist.init_process_group(backend=backend, **kwargs)
25 |
26 |
27 | def _init_dist_slurm(backend, port=None):
28 | """Initialize slurm distributed training environment.
29 | If argument ``port`` is not specified, then the master port will be system
30 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
31 | environment variable, then a default port ``29500`` will be used.
32 | Args:
33 | backend (str): Backend of torch.distributed.
34 | port (int, optional): Master port. Defaults to None.
35 | """
36 | proc_id = int(os.environ['SLURM_PROCID'])
37 | ntasks = int(os.environ['SLURM_NTASKS'])
38 | node_list = os.environ['SLURM_NODELIST']
39 | num_gpus = torch.cuda.device_count()
40 | torch.cuda.set_device(proc_id % num_gpus)
41 | addr = subprocess.getoutput(
42 | f'scontrol show hostname {node_list} | head -n1')
43 | # specify master port
44 | if port is not None:
45 | os.environ['MASTER_PORT'] = str(port)
46 | elif 'MASTER_PORT' in os.environ:
47 | pass # use MASTER_PORT in the environment variable
48 | else:
49 | # 29500 is torch.distributed default port
50 | os.environ['MASTER_PORT'] = '29500'
51 | os.environ['MASTER_ADDR'] = addr
52 | os.environ['WORLD_SIZE'] = str(ntasks)
53 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
54 | os.environ['RANK'] = str(proc_id)
55 | dist.init_process_group(backend=backend)
56 |
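# Editor's sketch (illustrative): typical entry points for the two launchers.
#   init_dist('pytorch')            # when RANK is set, e.g. via torch.distributed.launch
#   init_dist('slurm', port=29500)  # inside an sbatch/srun allocation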
57 |
58 | def get_dist_info():
59 | if dist.is_available():
60 | initialized = dist.is_initialized()
61 | else:
62 | initialized = False
63 | if initialized:
64 | rank = dist.get_rank()
65 | world_size = dist.get_world_size()
66 | else:
67 | rank = 0
68 | world_size = 1
69 | return rank, world_size
70 |
71 |
72 | def master_only(func):
73 |
74 | @functools.wraps(func)
75 | def wrapper(*args, **kwargs):
76 | rank, _ = get_dist_info()
77 | if rank == 0:
78 | return func(*args, **kwargs)
79 |
80 | return
--------------------------------------------------------------------------------
/Deblur/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 |
5 | def torchPSNR(tar_img, prd_img):
6 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
7 | rmse = (imdff**2).mean().sqrt()
8 | ps = 20*torch.log10(1/rmse)
9 | return ps
10 |
11 | def save_img(filepath, img):
12 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
13 |
14 | def numpyPSNR(tar_img, prd_img):
15 | imdff = np.float32(prd_img) - np.float32(tar_img)
16 | rmse = np.sqrt(np.mean(imdff**2))
17 | ps = 20*np.log10(255/rmse)
18 | return ps
19 |
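# Editor's note (worked example): torchPSNR expects [0, 1] tensors, numpyPSNR
# expects [0, 255] arrays. A uniform error of one 8-bit step (1/255 or 1) gives
# RMSE = 1/255 (resp. 1), hence PSNR = 20*log10(255) ~= 48.13 dB with either helper.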
--------------------------------------------------------------------------------
/Deblur/utils/logger.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import datetime
8 | import logging
9 | import time
10 |
11 | from .dist_util import get_dist_info, master_only
12 |
13 |
14 | class MessageLogger():
15 | """Message logger for printing.
16 | Args:
17 | opt (dict): Config. It contains the following keys:
18 | name (str): Exp name.
19 | logger (dict): Contains 'print_freq' (int) for logger interval.
20 | train (dict): Contains 'total_iter' (int) for total iters.
21 | use_tb_logger (bool): Use tensorboard logger.
22 | start_iter (int): Start iter. Default: 1.
23 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
24 | """
25 |
26 | def __init__(self, opt, start_iter=1, tb_logger=None):
27 | self.exp_name = opt['name']
28 | self.interval = opt['logger']['print_freq']
29 | self.start_iter = start_iter
30 | self.max_iters = opt['train']['total_iter']
31 | self.use_tb_logger = opt['logger']['use_tb_logger']
32 | self.tb_logger = tb_logger
33 | self.start_time = time.time()
34 | self.logger = get_root_logger()
35 |
36 | @master_only
37 | def __call__(self, log_vars):
38 | """Format logging message.
39 | Args:
40 | log_vars (dict): It contains the following keys:
41 | epoch (int): Epoch number.
42 | iter (int): Current iter.
43 | lrs (list): List for learning rates.
44 | time (float): Iter time.
45 | data_time (float): Data time for each iter.
46 | """
47 | # epoch, iter, learning rates
48 | epoch = log_vars.pop('epoch')
49 | current_iter = log_vars.pop('iter')
50 | total_iter = log_vars.pop('total_iter')
51 | lrs = log_vars.pop('lrs')
52 |
53 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
54 | f'iter:{current_iter:8,d}, lr:(')
55 | for v in lrs:
56 | message += f'{v:.3e},'
57 | message += ')] '
58 |
59 | # time and estimated time
60 | if 'time' in log_vars.keys():
61 | iter_time = log_vars.pop('time')
62 | data_time = log_vars.pop('data_time')
63 |
64 | total_time = time.time() - self.start_time
65 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
66 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
67 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
68 | message += f'[eta: {eta_str}, '
69 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
70 |
71 | # other items, especially losses
72 | for k, v in log_vars.items():
73 | message += f'{k}: {v:.4e} '
74 | # tensorboard logger
75 | if self.use_tb_logger and 'debug' not in self.exp_name:
76 | normed_step = 10000 * (current_iter / total_iter)
77 | normed_step = int(normed_step)
78 |
79 | if k.startswith('l_'):
80 | self.tb_logger.add_scalar(f'losses/{k}', v, normed_step)
81 | elif k.startswith('m_'):
82 | self.tb_logger.add_scalar(f'metrics/{k}', v, normed_step)
83 | else:
84 | assert False, f'Unexpected log key: {k}'
85 | # else:
86 | # self.tb_logger.add_scalar(k, v, current_iter)
87 | self.logger.info(message)
88 |
89 |
90 | @master_only
91 | def init_tb_logger(log_dir):
92 | from torch.utils.tensorboard import SummaryWriter
93 | tb_logger = SummaryWriter(log_dir=log_dir)
94 | return tb_logger
95 |
96 |
97 | @master_only
98 | def init_wandb_logger(opt):
99 | """We now only use wandb to sync tensorboard log."""
100 | import wandb
101 | logger = logging.getLogger('basicsr')
102 |
103 | project = opt['logger']['wandb']['project']
104 | resume_id = opt['logger']['wandb'].get('resume_id')
105 | if resume_id:
106 | wandb_id = resume_id
107 | resume = 'allow'
108 | logger.warning(f'Resume wandb logger with id={wandb_id}.')
109 | else:
110 | wandb_id = wandb.util.generate_id()
111 | resume = 'never'
112 |
113 | wandb.init(
114 | id=wandb_id,
115 | resume=resume,
116 | name=opt['name'],
117 | config=opt,
118 | project=project,
119 | sync_tensorboard=True)
120 |
121 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
122 |
123 |
124 | def get_root_logger(logger_name='basicsr',
125 | log_level=logging.INFO,
126 | log_file=None):
127 | """Get the root logger.
128 | The logger will be initialized if it has not been initialized. By default a
129 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
130 | also be added.
131 | Args:
132 | logger_name (str): root logger name. Default: 'basicsr'.
133 | log_file (str | None): The log filename. If specified, a FileHandler
134 | will be added to the root logger.
135 | log_level (int): The root logger level. Note that only the process of
136 | rank 0 is affected, while other processes will set the level to
137 | "Error" and be silent most of the time.
138 | Returns:
139 | logging.Logger: The root logger.
140 | """
141 | logger = logging.getLogger(logger_name)
142 | # if the logger has been initialized, just return it
143 | if logger.hasHandlers():
144 | return logger
145 |
146 | format_str = '%(asctime)s %(levelname)s: %(message)s'
147 | logging.basicConfig(format=format_str, level=log_level)
148 | rank, _ = get_dist_info()
149 | if rank != 0:
150 | logger.setLevel('ERROR')
151 | elif log_file is not None:
152 | file_handler = logging.FileHandler(log_file, 'w')
153 | file_handler.setFormatter(logging.Formatter(format_str))
154 | file_handler.setLevel(log_level)
155 | logger.addHandler(file_handler)
156 |
157 | return logger
158 |
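# Editor's sketch (illustrative): on rank 0 this logs to the stream and, if a
# file is given, to that file as well; non-zero ranks are silenced to ERROR:
#   logger = get_root_logger(log_file='train.log')
#   logger.info('visible on rank 0 only')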
159 |
160 | def get_env_info():
161 | """Get environment information.
162 | Currently, only log the software version.
163 | """
164 | import torch
165 | import torchvision
166 |
167 | from basicsr.version import __version__
168 | msg = r"""
169 | ____ _ _____ ____
170 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \
171 | / __ |/ __ `// ___// // ___/\__ \ / /_/ /
172 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
173 | /_____/ \__,_//____//_/ \___//____//_/ |_|
174 | ______ __ __ __ __
175 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
176 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
177 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
178 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
179 | """
180 | msg += ('\nVersion Information: '
181 | f'\n\tBasicSR: {__version__}'
182 | f'\n\tPyTorch: {torch.__version__}'
183 | f'\n\tTorchVision: {torchvision.__version__}')
184 | return msg
--------------------------------------------------------------------------------
/Deblur/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from collections import OrderedDict
4 |
5 | def freeze(model):
6 | for p in model.parameters():
7 | p.requires_grad=False
8 |
9 | def unfreeze(model):
10 | for p in model.parameters():
11 | p.requires_grad=True
12 |
13 | def is_frozen(model):
14 | x = [p.requires_grad for p in model.parameters()]
15 | return not all(x)
16 |
17 | def save_checkpoint(model_dir, state, session):
18 | epoch = state['epoch']
19 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
20 | torch.save(state, model_out_path)
21 |
22 | def load_checkpoint(model, weights):
23 | checkpoint = torch.load(weights)
24 | try:
25 | model.load_state_dict(checkpoint["state_dict"])
26 | except Exception:
27 | state_dict = checkpoint["state_dict"]
28 | new_state_dict = OrderedDict()
29 | for k, v in state_dict.items():
30 | name = k[7:] # remove `module.`
31 | new_state_dict[name] = v
32 | model.load_state_dict(new_state_dict)
33 |
34 |
35 | def load_checkpoint_multigpu(model, weights):
36 | checkpoint = torch.load(weights)
37 | state_dict = checkpoint["state_dict"]
38 | new_state_dict = OrderedDict()
39 | for k, v in state_dict.items():
40 | name = k[7:] # remove `module.`
41 | new_state_dict[name] = v
42 | model.load_state_dict(new_state_dict)
43 |
44 | def load_start_epoch(weights):
45 | checkpoint = torch.load(weights)
46 | epoch = checkpoint["epoch"]
47 | return epoch
48 |
49 | def load_optim(optimizer, weights):
50 | checkpoint = torch.load(weights)
51 | optimizer.load_state_dict(checkpoint['optimizer'])
52 | # for p in optimizer.param_groups: lr = p['lr']
53 | # return lr
54 |
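# Editor's sketch (illustrative; the checkpoint path is an assumption): a typical
# resume flow, mirroring how Derain/train.py uses these helpers:
#   load_checkpoint(model, './pre-trained/model_best.pth')
#   start_epoch = load_start_epoch('./pre-trained/model_best.pth') + 1
#   load_optim(optimizer, './pre-trained/model_best.pth')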
--------------------------------------------------------------------------------
/Derain/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Training
3 | - Download datasets from the google drive links and place them in this directory. Your directory structure should look something like this
4 |
5 | `Synthetic_Rain_Datasets`
6 | `├──`[train](https://drive.google.com/drive/folders/1Hnnlc5kI0v9_BtfMytC2LR5VpLAFZtVe?usp=sharing)
7 | `└──`[test](https://drive.google.com/drive/folders/1PDWggNh8ylevFmrjo-JEvlmqsDlWWvZs?usp=sharing)
8 | `├──Test100`
9 | `├──Rain100H`
10 | `├──Rain100L`
11 | `├──Test1200`
12 |
13 |
14 |
15 | - Train the model with default arguments by running
16 |
17 | ```
18 | python train.py
19 | ```
20 |
21 |
22 | ## Evaluation
23 |
24 | 1. Download the [model](https://drive.google.com/drive/folders/1qBC3mUoLoCuMyuiseYoZWzvyvImG98TW?usp=drive_link) and place it in `./pretrained_models/`
25 |
26 | 2. Download test datasets (Test100, Rain100H, Rain100L, Test1200) from [here](https://drive.google.com/drive/folders/1PDWggNh8ylevFmrjo-JEvlmqsDlWWvZs?usp=sharing) and place them in `./Datasets/Synthetic_Rain_Datasets/test/`
27 |
28 | 3. Run
29 | ```
30 | python test.py
31 | ```
32 |
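`test.py` also accepts a few command-line flags (defined in its argparse block: `--input_dir`, `--result_dir`, `--weights`, `--gpus`). A sketch that spells out the default locations explicitly:

```
python test.py --input_dir ./Datasets/test/ --result_dir ./results/ --weights ./pre-trained/model_best.pth --gpus 0
```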
33 | #### To reproduce PSNR/SSIM scores of the paper, run
34 | ```
35 | python eval.py
36 | ```
37 |
--------------------------------------------------------------------------------
/Derain/cal.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | import skimage.metrics
6 | import torch
7 | import math
8 |
9 | def calculate_psnr(img1, img2, crop_border, test_y_channel=True):
10 | """Calculate PSNR (Peak Signal-to-Noise Ratio).
11 |
12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
13 | Args:
14 | img1 (ndarray): Images with range [0, 255].
15 | img2 (ndarray): Images with range [0, 255].
16 | crop_border (int): Cropped pixels in each edge of an image. These
17 | pixels are not involved in the PSNR calculation.
18 | test_y_channel (bool): Test on Y channel of YCbCr. Default: True.
19 | Returns:
20 | float: psnr result.
21 | """
22 | assert img1.shape == img2.shape, (
23 | f'Image shapes are different: {img1.shape}, {img2.shape}.')
24 | if type(img1) == torch.Tensor:
25 | if len(img1.shape) == 4:
26 | img1 = img1.squeeze(0)
27 | img1 = img1.detach().cpu().numpy().transpose(1,2,0)
28 | if type(img2) == torch.Tensor:
29 | if len(img2.shape) == 4:
30 | img2 = img2.squeeze(0)
31 | img2 = img2.detach().cpu().numpy().transpose(1,2,0)
32 | img1 = img1.astype(np.float64)
33 | img2 = img2.astype(np.float64)
34 |
35 | if crop_border != 0:
36 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
38 |
39 | if test_y_channel:
40 | img1 = to_y_channel(img1)
41 | img2 = to_y_channel(img2)
42 |
43 | imdff = np.float32(img1) - np.float32(img2)
44 | rmse = np.sqrt(np.mean(imdff**2))
45 | ps = 20*np.log10(255/rmse)
46 | return ps
47 |
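# Editor's note (worked example): for two uint8 images differing by 1 at every
# pixel, RMSE = 1 and PSNR = 20*log10(255/1) ~= 48.13 dB; identical images give
# RMSE = 0 and hence an infinite PSNR (division by zero in 255/rmse).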
48 |
49 | def _convert_input_type_range(img):
50 | """Convert the type and range of the input image.
51 |
52 | It converts the input image to np.float32 type and range of [0, 1].
53 | It is mainly used for pre-processing the input image in colorspace
54 | convertion functions such as rgb2ycbcr and ycbcr2rgb.
55 | Args:
56 | img (ndarray): The input image. It accepts:
57 | 1. np.uint8 type with range [0, 255];
58 | 2. np.float32 type with range [0, 1].
59 | Returns:
60 | (ndarray): The converted image with type of np.float32 and range of
61 | [0, 1].
62 | """
63 | img_type = img.dtype
64 | img = img.astype(np.float32)
65 | if img_type == np.float32:
66 | pass
67 | elif img_type == np.uint8:
68 | img /= 255.
69 | else:
70 | raise TypeError('The img type should be np.float32 or np.uint8, '
71 | f'but got {img_type}')
72 | return img
73 |
74 |
75 | def _convert_output_type_range(img, dst_type):
76 | """Convert the type and range of the image according to dst_type.
77 |
78 | It converts the image to desired type and range. If `dst_type` is np.uint8,
79 | images will be converted to np.uint8 type with range [0, 255]. If
80 | `dst_type` is np.float32, it converts the image to np.float32 type with
81 | range [0, 1].
82 | It is mainly used for post-processing images in colorspace convertion
83 | functions such as rgb2ycbcr and ycbcr2rgb.
84 | Args:
85 | img (ndarray): The image to be converted with np.float32 type and
86 | range [0, 255].
87 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
88 | converts the image to np.uint8 type with range [0, 255]. If
89 | dst_type is np.float32, it converts the image to np.float32 type
90 | with range [0, 1].
91 | Returns:
92 | (ndarray): The converted image with desired type and range.
93 | """
94 | if dst_type not in (np.uint8, np.float32):
95 | raise TypeError('The dst_type should be np.float32 or np.uint8, '
96 | f'but got {dst_type}')
97 | if dst_type == np.uint8:
98 | img = img.round()
99 | else:
100 | img /= 255.
101 |
102 | return img.astype(dst_type)
103 |
104 |
105 | def rgb2ycbcr(img, y_only=True):
106 | """Convert a RGB image to YCbCr image.
107 |
108 | This function produces the same results as Matlab's `rgb2ycbcr` function.
109 | It implements the ITU-R BT.601 conversion for standard-definition
110 | television. See more details in
111 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
112 | It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
113 | In OpenCV, it implements a JPEG conversion. See more details in
114 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
115 |
116 | Args:
117 | img (ndarray): The input image. It accepts:
118 | 1. np.uint8 type with range [0, 255];
119 | 2. np.float32 type with range [0, 1].
120 | y_only (bool): Whether to only return Y channel. Default: True.
121 | Returns:
122 | ndarray: The converted YCbCr image. The output image has the same type
123 | and range as input image.
124 | """
125 | img_type = img.dtype
126 | img = _convert_input_type_range(img)
127 | if y_only:
128 | out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
129 | else:
130 | out_img = np.matmul(img,
131 | [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
132 | [24.966, 112.0, -18.214]]) + [16, 128, 128]
133 | out_img = _convert_output_type_range(out_img, img_type)
134 | return out_img
135 |
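# Editor's note (worked example): for a white float32 pixel [1, 1, 1], y_only
# gives 65.481 + 128.553 + 24.966 + 16 = 235.0, which the output conversion
# scales back to 235/255 ~= 0.922 -- the studio-swing Y range [16, 235], as in Matlab.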
136 |
137 | def to_y_channel(img):
138 | """Change to Y channel of YCbCr.
139 |
140 | Args:
141 | img (ndarray): Images with range [0, 255].
142 | Returns:
143 | (ndarray): Images with range [0, 255] (float type) without round.
144 | """
145 | img = img.astype(np.float32) / 255.
146 | if img.ndim == 3 and img.shape[2] == 3:
147 | img = rgb2ycbcr(img, y_only=True)
148 | img = img[..., None]
149 | return img * 255.
150 |
151 | def _ssim(img1, img2):
152 | """Calculate SSIM (structural similarity) for one channel images.
153 |
154 | It is called by func:`calculate_ssim`.
155 |
156 | Args:
157 | img1 (ndarray): Images with range [0, 255] with order 'HWC'.
158 | img2 (ndarray): Images with range [0, 255] with order 'HWC'.
159 |
160 | Returns:
161 | float: ssim result.
162 | """
163 |
164 | C1 = (0.01 * 255)**2
165 | C2 = (0.03 * 255)**2
166 |
167 | img1 = img1.astype(np.float64)
168 | img2 = img2.astype(np.float64)
169 | kernel = cv2.getGaussianKernel(11, 1.5)
170 | window = np.outer(kernel, kernel.transpose())
171 |
172 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
173 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
174 | mu1_sq = mu1**2
175 | mu2_sq = mu2**2
176 | mu1_mu2 = mu1 * mu2
177 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
178 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
179 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
180 |
181 | ssim_map = ((2 * mu1_mu2 + C1) *
182 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
183 | (sigma1_sq + sigma2_sq + C2))
184 | return ssim_map.mean()
185 |
186 | def prepare_for_ssim(img, k):
187 | import torch
188 | with torch.no_grad():
189 | img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
190 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect')
191 | conv.weight.requires_grad = False
192 | conv.weight[:, :, :, :] = 1. / (k * k)
193 |
194 | img = conv(img)
195 |
196 | img = img.squeeze(0).squeeze(0)
197 | img = img[0::k, 0::k]
198 | return img.detach().cpu().numpy()
199 |
200 | def prepare_for_ssim_rgb(img, k):
201 | import torch
202 | with torch.no_grad():
203 | img = torch.from_numpy(img).float() #HxWx3
204 |
205 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
206 | conv.weight.requires_grad = False
207 | conv.weight[:, :, :, :] = 1. / (k * k)
208 |
209 | new_img = []
210 |
211 | for i in range(3):
212 | new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k])
213 |
214 | return torch.stack(new_img, dim=2).detach().cpu().numpy()
215 |
216 | def _3d_gaussian_calculator(img, conv3d):
217 | out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
218 | return out
219 |
220 | def _generate_3d_gaussian_kernel():
221 | kernel = cv2.getGaussianKernel(11, 1.5)
222 | window = np.outer(kernel, kernel.transpose())
223 | kernel_3 = cv2.getGaussianKernel(11, 1.5)
224 | kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
225 | conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate')
226 | conv3d.weight.requires_grad = False
227 | conv3d.weight[0, 0, :, :, :] = kernel
228 | return conv3d
229 |
230 | def _ssim_3d(img1, img2, max_value):
231 | """Calculate SSIM (structural similarity) over 3D (HWC) image volumes.
232 |
233 | It is called by func:`calculate_ssim`.
234 |
235 | Args:
236 | img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
237 | img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
238 |
239 | Returns:
240 | float: ssim result.
241 | """
242 | assert len(img1.shape) == 3 and len(img2.shape) == 3
243 | C1 = (0.01 * max_value) ** 2
244 | C2 = (0.03 * max_value) ** 2
245 | img1 = img1.astype(np.float64)
246 | img2 = img2.astype(np.float64)
247 |
248 | kernel = _generate_3d_gaussian_kernel().cuda()
249 |
250 | img1 = torch.tensor(img1).float().cuda()
251 | img2 = torch.tensor(img2).float().cuda()
252 |
253 |
254 | mu1 = _3d_gaussian_calculator(img1, kernel)
255 | mu2 = _3d_gaussian_calculator(img2, kernel)
256 |
257 | mu1_sq = mu1 ** 2
258 | mu2_sq = mu2 ** 2
259 | mu1_mu2 = mu1 * mu2
260 | sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq
261 | sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq
262 | sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2
263 |
264 | ssim_map = ((2 * mu1_mu2 + C1) *
265 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
266 | (sigma1_sq + sigma2_sq + C2))
267 | return float(ssim_map.mean())
268 |
269 | def _ssim_cly(img1, img2):
270 | """Calculate SSIM (structural similarity) for single-channel images.
271 |
272 | It is called by func:`calculate_ssim`.
273 |
274 | Args:
275 | img1 (ndarray): Images with range [0, 255] with order 'HW'.
276 | img2 (ndarray): Images with range [0, 255] with order 'HW'.
277 |
278 | Returns:
279 | float: ssim result.
280 | """
281 | assert len(img1.shape) == 2 and len(img2.shape) == 2
282 |
283 | C1 = (0.01 * 255)**2
284 | C2 = (0.03 * 255)**2
285 | img1 = img1.astype(np.float64)
286 | img2 = img2.astype(np.float64)
287 |
288 | kernel = cv2.getGaussianKernel(11, 1.5)
289 | # print(kernel)
290 | window = np.outer(kernel, kernel.transpose())
291 |
292 | bt = cv2.BORDER_REPLICATE
293 |
294 | mu1 = cv2.filter2D(img1, -1, window, borderType=bt)
295 | mu2 = cv2.filter2D(img2, -1, window,borderType=bt)
296 |
297 | mu1_sq = mu1**2
298 | mu2_sq = mu2**2
299 | mu1_mu2 = mu1 * mu2
300 | sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq
301 | sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq
302 | sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2
303 |
304 | ssim_map = ((2 * mu1_mu2 + C1) *
305 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
306 | (sigma1_sq + sigma2_sq + C2))
307 | return ssim_map.mean()
308 | def reorder_image(img, input_order='HWC'):
309 | """Reorder images to 'HWC' order.
310 |
311 | If the input_order is (h, w), return (h, w, 1);
312 | If the input_order is (c, h, w), return (h, w, c);
313 | If the input_order is (h, w, c), return as it is.
314 |
315 | Args:
316 | img (ndarray): Input image.
317 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
318 | If the input image shape is (h, w), input_order will not have
319 | effects. Default: 'HWC'.
320 |
321 | Returns:
322 | ndarray: reordered image.
323 | """
324 |
325 | if input_order not in ['HWC', 'CHW']:
326 | raise ValueError(
327 | f'Wrong input_order {input_order}. Supported input_orders are '
328 | "'HWC' and 'CHW'")
329 | if len(img.shape) == 2:
330 | img = img[..., None]
331 | if input_order == 'CHW':
332 | img = img.transpose(1, 2, 0)
333 | return img
334 |
335 |
336 | def calculate_ssim(img1,
337 | img2,
338 | crop_border,
339 | input_order='HWC',
340 | test_y_channel=True):
341 | """Calculate SSIM (structural similarity).
342 |
343 | Ref:
344 | Image quality assessment: From error visibility to structural similarity
345 |
346 | The results are the same as that of the official released MATLAB code in
347 | https://ece.uwaterloo.ca/~z70wang/research/ssim/.
348 |
349 | For three-channel images, SSIM is calculated for each channel and then
350 | averaged.
351 |
352 | Args:
353 | img1 (ndarray): Images with range [0, 255].
354 | img2 (ndarray): Images with range [0, 255].
355 | crop_border (int): Cropped pixels in each edge of an image. These
356 | pixels are not involved in the SSIM calculation.
357 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
358 | Default: 'HWC'.
359 | test_y_channel (bool): Test on Y channel of YCbCr. Default: True.
360 |
361 | Returns:
362 | float: ssim result.
363 | """
364 |
365 | assert img1.shape == img2.shape, (
366 | f'Image shapes are different: {img1.shape}, {img2.shape}.')
367 | if input_order not in ['HWC', 'CHW']:
368 | raise ValueError(
369 | f'Wrong input_order {input_order}. Supported input_orders are '
370 | '"HWC" and "CHW"')
371 |
372 | if type(img1) == torch.Tensor:
373 | if len(img1.shape) == 4:
374 | img1 = img1.squeeze(0)
375 | img1 = img1.detach().cpu().numpy().transpose(1,2,0)
376 | if type(img2) == torch.Tensor:
377 | if len(img2.shape) == 4:
378 | img2 = img2.squeeze(0)
379 | img2 = img2.detach().cpu().numpy().transpose(1,2,0)
380 |
381 | img1 = reorder_image(img1, input_order=input_order)
382 | img2 = reorder_image(img2, input_order=input_order)
383 |
384 | img1 = img1.astype(np.float64)
385 | img2 = img2.astype(np.float64)
386 |
387 | if crop_border != 0:
388 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
389 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
390 |
391 | if test_y_channel:
392 | img1 = to_y_channel(img1)
393 | img2 = to_y_channel(img2)
394 | return _ssim_cly(img1[..., 0], img2[..., 0])
395 |
396 |
397 | ssims = []
398 | # ssims_before = []
399 |
400 | # skimage_before = skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True)
401 | # print('.._skimage',
402 | # skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True))
403 | max_value = 1 if img1.max() <= 1 else 255
404 | with torch.no_grad():
405 | final_ssim = _ssim_3d(img1, img2, max_value)
406 | ssims.append(final_ssim)
407 |
408 | # for i in range(img1.shape[2]):
409 | # ssims_before.append(_ssim(img1, img2))
410 |
411 | # print('..ssim mean , new {:.4f} and before {:.4f} .... skimage before {:.4f}'.format(np.array(ssims).mean(), np.array(ssims_before).mean(), skimage_before))
412 | # ssims.append(skimage.metrics.structural_similarity(img1[..., i], img2[..., i], multichannel=False))
413 |
414 | return np.array(ssims).mean()
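# Editor's sketch (illustrative): eval.py consumes these two metrics on whole
# images read with cv2.imread, with no border cropping:
#   psnr = calculate_psnr(gt_img, restored_img, 0)
#   ssim = calculate_ssim(gt_img, restored_img, 0)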
--------------------------------------------------------------------------------
/Derain/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | r"""This module provides package-wide configuration management."""
6 | from typing import Any, List
7 |
8 | from yacs.config import CfgNode as CN
9 |
10 |
11 | class Config(object):
12 | r"""
13 | A collection of all the required configuration parameters. This class is a nested dict-like
14 | structure, with nested keys accessible as attributes. It contains sensible default values for
15 | all the parameters, which may be overridden (first) through a YAML file and (second) through
16 | a list of attributes and values.
17 |
18 | Extended Summary
19 | ----------------
20 | This class definition contains default values corresponding to ``joint_training`` phase, as it
21 | is the final training phase and uses almost all the configuration parameters. Modification of
22 | any parameter after instantiating this class is not possible, so you must override required
23 | parameter values either through the ``config_yaml`` file or the ``config_override`` list.
24 |
25 | Parameters
26 | ----------
27 | config_yaml: str
28 | Path to a YAML file containing configuration parameters to override.
29 | config_override: List[Any], optional (default= [])
30 | A list of sequential attributes and values of parameters to override. This happens after
31 | overriding from YAML file.
32 |
33 | Examples
34 | --------
35 | Let a YAML file named "config.yaml" specify these parameters to override::
36 |
37 | ALPHA: 1000.0
38 | BETA: 0.5
39 |
40 | >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7])
41 | >>> _C.ALPHA # default: 100.0
42 | 1000.0
43 | >>> _C.BATCH_SIZE # default: 256
44 | 2048
45 | >>> _C.BETA # default: 0.1
46 | 0.7
47 |
48 | Attributes
49 | ----------
50 | """
51 |
52 | def __init__(self, config_yaml: str, config_override: List[Any] = []):
53 |
54 | self._C = CN()
55 | self._C.GPU = [0]
56 | self._C.VERBOSE = False
57 |
58 | self._C.MODEL = CN()
59 | self._C.MODEL.MODE = 'global'
60 | self._C.MODEL.SESSION = 'ps128_bs1'
61 |
62 | self._C.OPTIM = CN()
63 | self._C.OPTIM.BATCH_SIZE = 1
64 | self._C.OPTIM.NUM_EPOCHS = 100
65 | self._C.OPTIM.NEPOCH_DECAY = [100]
66 | self._C.OPTIM.LR_INITIAL = 0.0002
67 | self._C.OPTIM.LR_MIN = 0.0002
68 | self._C.OPTIM.BETA1 = 0.5
69 |
70 | self._C.TRAINING = CN()
71 | self._C.TRAINING.VAL_AFTER_EVERY = 3
72 | self._C.TRAINING.RESUME = False
73 | self._C.TRAINING.SAVE_IMAGES = False
74 | self._C.TRAINING.TRAIN_DIR = 'images_dir/train'
75 | self._C.TRAINING.VAL_DIR = 'images_dir/val'
76 | self._C.TRAINING.SAVE_DIR = 'checkpoints'
77 | self._C.TRAINING.TRAIN_PS = 64
78 | self._C.TRAINING.VAL_PS = 64
79 |
80 | # Override parameter values from YAML file first, then from override list.
81 | self._C.merge_from_file(config_yaml)
82 | self._C.merge_from_list(config_override)
83 |
84 | # Make an instantiated object of this class immutable.
85 | self._C.freeze()
86 |
87 | def dump(self, file_path: str):
88 | r"""Save config at the specified file path.
89 |
90 | Parameters
91 | ----------
92 | file_path: str
93 | (YAML) path to save config at.
94 | """
95 | self._C.dump(stream=open(file_path, "w"))
96 |
97 | def __getattr__(self, attr: str):
98 | return self._C.__getattr__(attr)
99 |
100 | def __repr__(self):
101 | return self._C.__repr__()
102 |
--------------------------------------------------------------------------------
/Derain/data_RGB.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest, DataLoaderTest2
3 |
4 | def get_training_data(rgb_dir, img_options):
5 | assert os.path.exists(rgb_dir)
6 | return DataLoaderTrain(rgb_dir, img_options)
7 |
8 | def get_validation_data(rgb_dir, img_options):
9 | assert os.path.exists(rgb_dir)
10 | return DataLoaderVal(rgb_dir, img_options)
11 |
12 | def get_test_data(rgb_dir, img_options):
13 | assert os.path.exists(rgb_dir)
14 | return DataLoaderTest(rgb_dir, img_options)
15 |
16 | def get_test_data2(rgb_dir, img_options):
17 | assert os.path.exists(rgb_dir)
18 | return DataLoaderTest2(rgb_dir, img_options)
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/Derain/dataset_RGB.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import torch
5 | from PIL import Image
6 | import torchvision.transforms.functional as TF
7 | from pdb import set_trace as stx
8 | import random
9 | import utils
10 |
11 |
12 | def is_image_file(filename):
13 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
14 |
15 |
16 | class DataLoaderTrain(Dataset):
17 | def __init__(self, rgb_dir, img_options=None):
18 | super(DataLoaderTrain, self).__init__()
19 |
20 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
21 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
22 |
23 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
24 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
25 |
26 | self.img_options = img_options
27 | self.sizex = len(self.tar_filenames) # get the size of target
28 |
29 | self.ps = self.img_options['patch_size']
30 |
31 | def __len__(self):
32 | return self.sizex
33 |
34 | def __getitem__(self, index):
35 | index_ = index % self.sizex
36 | ps = self.ps
37 |
38 | inp_path = self.inp_filenames[index_]
39 | tar_path = self.tar_filenames[index_]
40 |
41 | inp_img = Image.open(inp_path)
42 | tar_img = Image.open(tar_path)
43 |
44 | w, h = tar_img.size
45 | padw = ps - w if w < ps else 0
46 | padh = ps - h if h < ps else 0
47 |
48 | # Reflect Pad in case image is smaller than patch_size
49 | if padw != 0 or padh != 0:
50 | inp_img = TF.pad(inp_img, (0, 0, padw, padh), padding_mode='reflect')
51 | tar_img = TF.pad(tar_img, (0, 0, padw, padh), padding_mode='reflect')
52 |
53 |
54 | inp_img = TF.to_tensor(inp_img)
55 | tar_img = TF.to_tensor(tar_img)
56 |
57 | hh, ww = tar_img.shape[1], tar_img.shape[2]
58 |
59 | rr = random.randint(0, hh - ps)
60 | cc = random.randint(0, ww - ps)
61 | aug = random.randint(0, 8)
62 |
63 | # Crop patch
64 | inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
65 | tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]
66 |
67 | # Data Augmentations
68 | if aug == 1:
69 | inp_img = inp_img.flip(1)
70 | tar_img = tar_img.flip(1)
71 | elif aug == 2:
72 | inp_img = inp_img.flip(2)
73 | tar_img = tar_img.flip(2)
74 | elif aug == 3:
75 | inp_img = torch.rot90(inp_img, dims=(1, 2))
76 | tar_img = torch.rot90(tar_img, dims=(1, 2))
77 | elif aug == 4:
78 | inp_img = torch.rot90(inp_img, dims=(1, 2), k=2)
79 | tar_img = torch.rot90(tar_img, dims=(1, 2), k=2)
80 | elif aug == 5:
81 | inp_img = torch.rot90(inp_img, dims=(1, 2), k=3)
82 | tar_img = torch.rot90(tar_img, dims=(1, 2), k=3)
83 | elif aug == 6:
84 | inp_img = torch.rot90(inp_img.flip(1), dims=(1, 2))
85 | tar_img = torch.rot90(tar_img.flip(1), dims=(1, 2))
86 | elif aug == 7:
87 | inp_img = torch.rot90(inp_img.flip(2), dims=(1, 2))
88 | tar_img = torch.rot90(tar_img.flip(2), dims=(1, 2))
89 |
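# Editor's note: aug is drawn from {0, ..., 8}; cases 1-7 above cover the
# non-identity dihedral symmetries of the square (two flips, three rotations,
# two flip+rotate combos), while 0 and 8 leave the patch unchanged.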
90 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
91 |
92 | return tar_img, inp_img, filename
93 |
94 |
95 | class DataLoaderVal(Dataset):
96 | def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
97 | super(DataLoaderVal, self).__init__()
98 |
99 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
100 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
101 |
102 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
103 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
104 |
105 | self.img_options = img_options
106 | self.sizex = len(self.tar_filenames) # get the size of target
107 |
108 | self.ps = self.img_options['patch_size']
109 |
110 | def __len__(self):
111 | return self.sizex
112 |
113 | def __getitem__(self, index):
114 | index_ = index % self.sizex
115 | ps = self.ps
116 |
117 | inp_path = self.inp_filenames[index_]
118 | tar_path = self.tar_filenames[index_]
119 |
120 | inp_img = Image.open(inp_path)
121 | tar_img = Image.open(tar_path)
122 |
123 | # Validate on center crop
124 | if self.ps is not None:
125 | inp_img = TF.center_crop(inp_img, (ps, ps))
126 | tar_img = TF.center_crop(tar_img, (ps, ps))
127 |
128 | inp_img = TF.to_tensor(inp_img)
129 | tar_img = TF.to_tensor(tar_img)
130 |
131 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
132 |
133 |
134 | return tar_img, inp_img, filename
135 |
136 |
137 | class DataLoaderTest(Dataset):
138 | def __init__(self, rgb_dir, img_options):
139 | super(DataLoaderTest, self).__init__()
140 |
141 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
142 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
143 |
144 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
145 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
146 |
147 | self.inp_size = len(self.inp_filenames)
148 | self.img_options = img_options
149 |
150 | def __len__(self):
151 | return self.inp_size
152 |
153 | def __getitem__(self, index):
154 | path_inp = self.inp_filenames[index]
155 | tar_path = self.tar_filenames[index]
156 | filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
157 | inp = Image.open(path_inp)
158 | tar_img = Image.open(tar_path)
159 |
160 | inp = TF.to_tensor(inp)
161 | tar_img = TF.to_tensor(tar_img)
162 | return inp, filename  # test.py expects (image, filename) pairs
163 |
164 |
--------------------------------------------------------------------------------
/Derain/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from glob import glob
4 | import cv2
5 | from natsort import natsorted
6 |
7 | from skimage.metrics import structural_similarity,peak_signal_noise_ratio
8 | from cal import calculate_psnr,calculate_ssim
9 |
10 | def read_img(path):
11 | return cv2.imread(path)
12 |
13 |
14 |
15 |
16 | def main():
17 | datasets = {'GoPro', 'HIDE'}
18 | file_path = os.path.join('resultsmash_g/Raindata/test', 'Rain100H')
19 | gt_path = os.path.join('Dataset/Raindata/test/Rain100H', 'target')
20 | print(file_path)
21 | print(gt_path)
22 |
23 | path_fake = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg')))
24 | path_real = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg')))
25 | print(len(path_fake))
26 | list_psnr = []
27 | list_ssim = []
28 | list_mse = []
29 |
30 | for i in range(len(path_real)):
31 | t1 = read_img(path_real[i])
32 | t2 = read_img(path_fake[i])
33 | #result1 = np.zeros(t1.shape,dtype=np.float32)
34 | #result2 = np.zeros(t2.shape,dtype=np.float32)
35 | #cv2.normalize(t1,result1,alpha=0,beta=1,norm_type=cv2.NORM_MINMAX,dtype=cv2.CV_32F)
36 | #cv2.normalize(t2,result2,alpha=0,beta=1,norm_type=cv2.NORM_MINMAX,dtype=cv2.CV_32F)
37 |
38 |
39 |
40 | psnr_num = calculate_psnr(t1, t2,0)
41 | ssim_num = calculate_ssim(t1, t2,0)
42 |
43 | list_ssim.append(ssim_num)
44 | list_psnr.append(psnr_num)
45 |
46 |
47 |
48 | print("AverSSIM:", np.mean(list_ssim))
49 | print("AverPSNR:", np.mean(list_psnr))
50 |
51 |
52 | if __name__ == '__main__':
53 | main()
--------------------------------------------------------------------------------
/Derain/evaluate_PSNR_SSIM.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from glob import glob
4 | from natsort import natsorted
5 | from skimage import io
6 | import cv2
7 | from skimage.metrics import structural_similarity
8 | from tqdm import tqdm
9 | import concurrent.futures
10 |
11 |
12 | def image_align(deblurred, gt):
13 | # this function is based on kohler evaluation code
14 | z = deblurred
15 | c = np.ones_like(z)
16 | x = gt
17 |
18 | zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching
19 |
20 | warp_mode = cv2.MOTION_HOMOGRAPHY
21 | warp_matrix = np.eye(3, 3, dtype=np.float32)
22 |
23 | # Specify the number of iterations.
24 | number_of_iterations = 100
25 |
26 | termination_eps = 0
27 |
28 | criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
29 | number_of_iterations, termination_eps)
30 |
31 | # Run the ECC algorithm. The results are stored in warp_matrix.
32 | (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
33 | warp_matrix, warp_mode, criteria, inputMask=None)
34 |
35 | target_shape = x.shape
36 | shift = warp_matrix
37 |
38 | zr = cv2.warpPerspective(
39 | zs,
40 | warp_matrix,
41 | (target_shape[1], target_shape[0]),
42 | flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
43 | borderMode=cv2.BORDER_REFLECT)
44 |
45 | cr = cv2.warpPerspective(
46 | np.ones_like(zs, dtype='float32'),
47 | warp_matrix,
48 | (target_shape[1], target_shape[0]),
49 | flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
50 | borderMode=cv2.BORDER_CONSTANT,
51 | borderValue=0)
52 |
53 | zr = zr * cr
54 | xr = x * cr
55 |
56 | return zr, xr, cr, shift
57 |
58 |
59 | def compute_psnr(image_true, image_test, image_mask, data_range=None):
60 | # this function is based on skimage.metrics.peak_signal_noise_ratio
61 | err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
62 | return 10 * np.log10((data_range ** 2) / err)
63 |
64 |
65 | def compute_ssim(tar_img, prd_img, cr1):
66 | ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True,
67 | use_sample_covariance=False, data_range=1.0, full=True)
68 | ssim_map = ssim_map * cr1
69 | r = int(3.5 * 1.5 + 0.5) # radius as in ndimage
70 | win_size = 2 * r + 1
71 | pad = (win_size - 1) // 2
72 | ssim = ssim_map[pad:-pad, pad:-pad, :]
73 | crop_cr1 = cr1[pad:-pad, pad:-pad, :]
74 | ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
75 | ssim = np.mean(ssim)
76 | return ssim
77 |
78 |
79 | def proc(filename):
80 | tar, prd = filename
81 | tar_img = io.imread(tar)
82 | prd_img = io.imread(prd)
83 |
84 | tar_img = tar_img.astype(np.float32) / 255.0
85 | prd_img = prd_img.astype(np.float32) / 255.0
86 |
87 | prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
88 |
89 | PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
90 | SSIM = compute_ssim(tar_img, prd_img, cr1)
91 | return (PSNR, SSIM)
92 |
93 |
94 | def te():
95 | datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200']
96 |
97 | for dataset in datasets:
98 |
99 | file_path = os.path.join('mashresults' , dataset)
100 | gt_path = os.path.join('Datasets','test', dataset, 'target')
101 |
102 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg')))
103 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg')))
104 |
105 | assert len(path_list) != 0, "Predicted files not found"
106 | assert len(gt_list) != 0, "Target files not found"
107 | index = 0
108 | psnr, ssim = [], []
109 |
110 |
111 | img_files =[(i, j) for i,j in zip(gt_list,path_list)]
112 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
113 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
114 | psnr.append(PSNR_SSIM[0])
115 | ssim.append(PSNR_SSIM[1])
116 |
117 |
118 |
119 | #img_files = [(i, j) for i, j in zip(gt_list, path_list)]
120 | #for i in range(len(img_files)):
121 | # res = proc(img_files[i])
122 | # psnr.append(res[0])
123 | # ssim.append(res[1])
124 | # with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
125 | # for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
126 | # index = index + 1
127 | # print(index)
128 | # psnr.append(PSNR_SSIM[0])
129 | # ssim.append(PSNR_SSIM[1])
130 |
131 | avg_psnr = sum(psnr) / len(psnr)
132 | avg_ssim = sum(ssim) / len(ssim)
133 |
134 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
135 |
136 | if __name__=='__main__':
137 | te()
--------------------------------------------------------------------------------
/Derain/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class CharbonnierLoss(nn.Module):
7 | """Charbonnier Loss (L1)"""
8 |
9 | def __init__(self, eps=1e-3):
10 | super(CharbonnierLoss, self).__init__()
11 | self.eps = eps
12 |
13 | def forward(self, x, y):
14 | diff = x - y
15 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
16 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
17 | return loss
18 |
19 | class EdgeLoss(nn.Module):
20 | def __init__(self):
21 | super(EdgeLoss, self).__init__()
22 | k = torch.Tensor([[.05, .25, .4, .25, .05]])
23 | self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
24 | if torch.cuda.is_available():
25 | self.kernel = self.kernel.cuda()
26 | self.loss = CharbonnierLoss()
27 |
28 | def conv_gauss(self, img):
29 | n_channels, _, kw, kh = self.kernel.shape
30 | img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
31 | return F.conv2d(img, self.kernel, groups=n_channels)
32 |
33 | def laplacian_kernel(self, current):
34 | filtered = self.conv_gauss(current) # filter
35 | down = filtered[:,:,::2,::2] # downsample
36 | new_filter = torch.zeros_like(filtered)
37 | new_filter[:,:,::2,::2] = down*4 # upsample
38 | filtered = self.conv_gauss(new_filter) # filter
39 | diff = current - filtered
40 | return diff
41 |
42 | def forward(self, x, y):
43 | loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
44 | return loss
45 |
46 |
47 | class PSNRLoss(nn.Module):
48 |
49 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
50 | super(PSNRLoss, self).__init__()
51 | assert reduction == 'mean'
52 | self.loss_weight = loss_weight
53 | self.scale = 10 / np.log(10)
54 | self.toY = toY
55 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
56 | self.first = True
57 |
58 | def forward(self, pred, target):
59 | assert len(pred.size()) == 4
60 | if self.toY:
61 | if self.first:
62 | self.coef = self.coef.to(pred.device)
63 | self.first = False
64 |
65 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
66 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
67 |
68 | pred, target = pred / 255., target / 255.
69 | pass
70 | assert len(pred.size()) == 4
71 |
72 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
73 |
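# Editor's note (worked example): with scale = 10/ln(10), the returned value is
# 10*log10(MSE + 1e-8), i.e. the negative of PSNR for [0, 1] images; e.g.
# MSE = 1e-4 gives a loss of about -40, matching a 40 dB PSNR.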
--------------------------------------------------------------------------------
/Derain/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch.nn as nn
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import torch.nn.functional as F
10 | import utils
11 |
12 | from data_RGB import get_test_data
13 | from MHNet import MHNet
14 | from skimage import img_as_ubyte
15 | from pdb import set_trace as stx
16 |
17 | parser = argparse.ArgumentParser(description='Image Deraining using MHNet')
18 |
19 | parser.add_argument('--input_dir', default='./Datasets/test/', type=str, help='Directory of validation images')
20 | parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
21 | parser.add_argument('--weights', default='./pre-trained/model_best.pth', type=str, help='Path to weights')
22 | parser.add_argument('--gpus', default='2', type=str, help='CUDA_VISIBLE_DEVICES')
23 |
24 | args = parser.parse_args()
25 |
26 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
27 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
28 |
29 | model_restoration = MHNet()
30 |
31 | utils.load_checkpoint(model_restoration,args.weights)
32 | print("===>Testing using weights: ",args.weights)
33 | model_restoration.cuda()
34 | model_restoration = nn.DataParallel(model_restoration)
35 | model_restoration.eval()
36 |
37 | datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200']
38 | # datasets = ['Rain100L']
39 |
40 | for dataset in datasets:
41 | rgb_dir_test = os.path.join(args.input_dir, dataset, 'input')
42 | test_dataset = get_test_data(rgb_dir_test, img_options={})
43 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)
44 |
45 | result_dir = os.path.join(args.result_dir, dataset)
46 | utils.mkdir(result_dir)
47 |
48 | with torch.no_grad():
49 | for ii, data_test in enumerate(tqdm(test_loader), 0):
50 | torch.cuda.ipc_collect()
51 | torch.cuda.empty_cache()
52 |
53 | input_ = data_test[0].cuda()
54 | filenames = data_test[1]
55 |
56 | restored = model_restoration(input_)
57 | restored = torch.clamp(restored[0],0,1)
58 |
59 | restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
60 |
61 | for batch in range(len(restored)):
62 | restored_img = img_as_ubyte(restored[batch])
63 | utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img)
64 |
--------------------------------------------------------------------------------
/Derain/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 |
5 | import os
6 | from config import Config
7 |
8 | opt = Config('trmash.yml')
9 |
10 | gpus = ','.join([str(i) for i in opt.GPU])
11 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
12 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus
13 |
14 | import torch
15 |
16 | torch.backends.cudnn.benchmark = True
17 |
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import torch.optim as optim
21 | from torch.utils.data import DataLoader
22 | import wandb
23 |
24 | import random
25 | import time
26 | import numpy as np
27 | from pathlib import Path
28 |
29 | import utils
30 | from data_RGB import get_training_data, get_validation_data
31 | from MHNet import MHNet
32 | import losses
33 | from warmup_scheduler import GradualWarmupScheduler
34 | from tqdm import tqdm
35 | from pdb import set_trace as stx
36 |
37 |
38 | dir_checkpoint = Path('./mhnetmash/')
39 |
40 | def train():
41 |
42 | ######### Set Seeds ###########
43 | random.seed(1234)
44 | np.random.seed(1234)
45 | torch.manual_seed(42)
46 | torch.cuda.manual_seed_all(42)
47 |
48 | start_epoch = 1
49 | mode = opt.MODEL.MODE
50 | session = opt.MODEL.SESSION
51 |
52 | result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session)
53 | model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
54 |
55 | utils.mkdir(result_dir)
56 | utils.mkdir(model_dir)
57 |
58 | train_dir = opt.TRAINING.TRAIN_DIR
59 | val_dir = opt.TRAINING.VAL_DIR
60 |
61 | ######### Model ###########
62 | model_restoration = MHNet()
63 | print("Total number of parameters:", sum(x.numel() for x in model_restoration.parameters()))
64 | model_restoration.cuda()
65 |
66 | device_ids = [i for i in range(torch.cuda.device_count())]
67 | if torch.cuda.device_count() > 1:
68 | print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
69 |
70 |
71 | new_lr = opt.OPTIM.LR_INITIAL
72 |
73 | optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8)
74 |
75 |
76 | ######### Scheduler ###########
77 | warmup_epochs = 3
78 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN)
79 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
80 | scheduler.step()
81 |
82 | ######### Resume ###########
83 | if opt.TRAINING.RESUME:
84 | path_chk_rest = './mhnetmash/model_best.pth'
85 | utils.load_checkpoint(model_restoration,path_chk_rest)
86 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1
87 | utils.load_optim(optimizer, path_chk_rest)
88 |
89 | for i in range(1, start_epoch):
90 | scheduler.step()
91 | new_lr = scheduler.get_lr()[0]
92 | print('------------------------------------------------------------------------------')
93 | print("==> Resuming Training with learning rate:", new_lr)
94 | print('------------------------------------------------------------------------------')
95 |
96 | if len(device_ids)>1:
97 | model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids)
98 | print("Using multiple GPUs (DataParallel)")
99 |
100 | ######### Loss ###########
101 | criterion_mse = losses.PSNRLoss()
102 |
103 | ######### DataLoaders ###########
104 | train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
105 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)
106 |
107 | val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
108 | val_loader = DataLoader(dataset=val_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
109 |
110 |
111 |
112 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.OPTIM.NUM_EPOCHS + 1))
113 | print('===> Loading datasets')
114 |
115 | best_psnr = 0
116 | best_epoch = 0
117 | global_step = 0
118 |
119 | for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
120 | epoch_start_time = time.time()
121 | epoch_loss = 0
122 | psnr_train_rgb = []
123 | psnr_train_rgb1 = []
124 | psnr_tr = 0
125 | psnr_tr1 = 0
126 | model_restoration.train()
127 | for i, data in enumerate(tqdm(train_loader), 0):
128 |
129 | # zero_grad
130 | for param in model_restoration.parameters():
131 | param.grad = None
132 |
133 | target = data[0].cuda()
134 | input_ = data[1].cuda()
135 |
136 | restored = model_restoration(input_)
137 |
138 | loss = criterion_mse(restored[0],target)
139 | loss.backward()
140 | optimizer.step()
141 | epoch_loss += loss.item()
142 | global_step = global_step+1
143 |
144 | psnr_te = 0
145 |     # (unused psnr_te_1 / ssim_te_1 accumulators removed)
147 | #### Evaluation ####
148 | if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0:
149 | model_restoration.eval()
150 | psnr_val_rgb = []
151 | psnr_val_rgb1 = []
152 | for ii, data_val in enumerate((val_loader), 0):
153 | target = data_val[0].cuda()
154 | input_ = data_val[1].cuda()
155 |
156 | with torch.no_grad():
157 | restored = model_restoration(input_)
158 | restore = restored[0]
159 |
160 |             for res, tar in zip(restore, target):
161 |                 psnr = utils.torchPSNR(res, tar)
162 |                 psnr_te = psnr_te + psnr
163 |                 psnr_val_rgb.append(psnr)
164 | 
165 |         psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
166 |         print("validation PSNR (summed over images):", psnr_te)
167 |
168 | if psnr_val_rgb > best_psnr:
169 | best_psnr = psnr_val_rgb
170 | best_epoch = epoch
171 | Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
172 | torch.save({'epoch': epoch,
173 | 'state_dict': model_restoration.state_dict(),
174 | 'optimizer': optimizer.state_dict()
175 | }, str(dir_checkpoint / "model_best.pth"))
176 |
177 |
178 | print("[epoch %d PSNR: %.4f best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
179 | Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
180 | torch.save({'epoch': epoch,
181 | 'state_dict': model_restoration.state_dict(),
182 | 'optimizer': optimizer.state_dict()
183 |                    }, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
184 |
185 | scheduler.step()
186 |
187 | print("------------------------------------------------------------------")
188 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time,
189 | epoch_loss, scheduler.get_lr()[0]))
190 | print("------------------------------------------------------------------")
191 |
192 |
193 | if __name__=='__main__':
194 | train()
195 |
196 |
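Editor's note: a minimal, self-contained sketch of the warmup-plus-cosine schedule assembled above, using hypothetical values (LR 2e-4, LR_MIN 1e-6, a 100-epoch horizon); the real values come from trmash.yml:

```
import torch
from torch import optim
from warmup_scheduler import GradualWarmupScheduler  # vendored in pytorch-gradual-warmup-lr/

# toy single-parameter "model" so the schedule can be inspected in isolation
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.Adam(params, lr=2e-4)

warmup_epochs = 3
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100 - warmup_epochs, eta_min=1e-6)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=cosine)
scheduler.step()  # same priming step as in train.py

for epoch in range(1, 11):
    print(epoch, scheduler.get_lr()[0])  # ramps linearly to 2e-4 over 3 epochs, then follows the cosine decay
    scheduler.step()
```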
--------------------------------------------------------------------------------
/Derain/trmash.yml:
--------------------------------------------------------------------------------
1 | ##################################################
2 | ## MHNet deraining training configuration
3 | ##################################################
4 |
5 | GPU: [2,3]
6 |
7 | VERBOSE: True
8 |
9 | MODEL:
10 | MODE: 'Deraining'
11 | SESSION: 'MHNet'
12 | # Optimization arguments.
13 | OPTIM:
14 | BATCH_SIZE: 16
15 | NUM_EPOCHS: 10000
16 | # NEPOCH_DECAY: [10]
17 | LR_INITIAL: 2e-4
18 | LR_MIN: 1e-6
19 | # BETA1: 0.9
20 |
21 | TRAINING:
22 |
23 | VAL_AFTER_EVERY: 10
24 | RESUME: True
25 | TRAIN_PS: 256
26 | VAL_PS: 128
27 | TRAIN_DIR: './Datasets/train' # path to training data
28 | VAL_DIR: './Datasets/test/Rain100L' # path to validation data
29 | SAVE_DIR: './checkpoints' # path to save models and images
30 | # SAVE_IMAGES: False
31 |
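Editor's note: a minimal sketch of reading this file with PyYAML (config.py presumably wraps this access; its exact API is not shown here):

```
import yaml

with open('trmash.yml') as f:
    cfg = yaml.safe_load(f)

print(cfg['OPTIM']['BATCH_SIZE'])         # 16
print(float(cfg['OPTIM']['LR_INITIAL']))  # 0.0002 - YAML 1.1 parses '2e-4' as a string, so cast explicitly
print(cfg['TRAINING']['VAL_DIR'])         # ./Datasets/test/Rain100L
```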
--------------------------------------------------------------------------------
/Derain/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dir_utils import *
2 | from .image_utils import *
3 | from .model_utils import *
4 | from .dataset_utils import *
5 | from .logger import (MessageLogger, get_env_info, get_root_logger,
6 | init_tb_logger, init_wandb_logger)
7 |
8 |
9 | __all__ = [
10 |     # dir_utils.py
11 |     'mkdirs',
12 |     'mkdir',
13 |     'get_last_path',
14 |     # image_utils.py
15 |     'torchPSNR',
16 |     'save_img',
17 |     'numpyPSNR',
18 |     # dataset_utils.py
19 |     'MixUp_AUG',
20 |     # model_utils.py
21 |     'freeze',
22 |     'unfreeze',
23 |     'is_frozen',
24 |     'save_checkpoint',
25 |     'load_checkpoint',
26 |     'load_checkpoint_multigpu',
27 |     'load_start_epoch',
28 |     'load_optim',
29 |     # logger.py
30 |     'MessageLogger',
31 |     'init_tb_logger',
32 |     'init_wandb_logger',
33 |     'get_root_logger',
34 |     'get_env_info',
35 | ]
36 | 
--------------------------------------------------------------------------------
/Derain/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/Derain/utils/__pycache__/arch_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/arch_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Derain/utils/__pycache__/dataset_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/dataset_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Derain/utils/__pycache__/dir_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/dir_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Derain/utils/__pycache__/dist_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/dist_util.cpython-38.pyc
--------------------------------------------------------------------------------
/Derain/utils/__pycache__/image_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/image_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Derain/utils/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/Derain/utils/__pycache__/model_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/model_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Derain/utils/arch_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn as nn
4 | from torch.nn import functional as F
5 | from torch.nn import init as init
6 | from torch.nn.modules.batchnorm import _BatchNorm
7 |
8 | from utils import get_root_logger
9 |
10 | # try:
11 | # from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
12 | # modulated_deform_conv)
13 | # except ImportError:
14 | # # print('Cannot import dcn. Ignore this warning if dcn is not used. '
15 | # # 'Otherwise install BasicSR with compiling dcn.')
16 | #
17 |
18 | @torch.no_grad()
19 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
20 | """Initialize network weights.
21 | Args:
22 | module_list (list[nn.Module] | nn.Module): Modules to be initialized.
23 | scale (float): Scale initialized weights, especially for residual
24 | blocks. Default: 1.
25 | bias_fill (float): The value to fill bias. Default: 0
26 | kwargs (dict): Other arguments for initialization function.
27 | """
28 | if not isinstance(module_list, list):
29 | module_list = [module_list]
30 | for module in module_list:
31 | for m in module.modules():
32 | if isinstance(m, nn.Conv2d):
33 | init.kaiming_normal_(m.weight, **kwargs)
34 | m.weight.data *= scale
35 | if m.bias is not None:
36 | m.bias.data.fill_(bias_fill)
37 | elif isinstance(m, nn.Linear):
38 | init.kaiming_normal_(m.weight, **kwargs)
39 | m.weight.data *= scale
40 | if m.bias is not None:
41 | m.bias.data.fill_(bias_fill)
42 | elif isinstance(m, _BatchNorm):
43 | init.constant_(m.weight, 1)
44 | if m.bias is not None:
45 | m.bias.data.fill_(bias_fill)
46 |
47 |
48 | def make_layer(basic_block, num_basic_block, **kwarg):
49 | """Make layers by stacking the same blocks.
50 | Args:
51 | basic_block (nn.module): nn.module class for basic block.
52 | num_basic_block (int): number of blocks.
53 | Returns:
54 | nn.Sequential: Stacked blocks in nn.Sequential.
55 | """
56 | layers = []
57 | for _ in range(num_basic_block):
58 | layers.append(basic_block(**kwarg))
59 | return nn.Sequential(*layers)
60 |
61 |
62 | class ResidualBlockNoBN(nn.Module):
63 | """Residual block without BN.
64 | It has a style of:
65 | ---Conv-ReLU-Conv-+-
66 | |________________|
67 | Args:
68 | num_feat (int): Channel number of intermediate features.
69 | Default: 64.
70 | res_scale (float): Residual scale. Default: 1.
71 | pytorch_init (bool): If set to True, use pytorch default init,
72 | otherwise, use default_init_weights. Default: False.
73 | """
74 |
75 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
76 | super(ResidualBlockNoBN, self).__init__()
77 | self.res_scale = res_scale
78 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
79 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
80 | self.relu = nn.ReLU(inplace=True)
81 |
82 | if not pytorch_init:
83 | default_init_weights([self.conv1, self.conv2], 0.1)
84 |
85 | def forward(self, x):
86 | identity = x
87 | out = self.conv2(self.relu(self.conv1(x)))
88 | return identity + out * self.res_scale
89 |
90 |
91 | class Upsample(nn.Sequential):
92 | """Upsample module.
93 | Args:
94 | scale (int): Scale factor. Supported scales: 2^n and 3.
95 | num_feat (int): Channel number of intermediate features.
96 | """
97 |
98 | def __init__(self, scale, num_feat):
99 | m = []
100 | if (scale & (scale - 1)) == 0: # scale = 2^n
101 | for _ in range(int(math.log(scale, 2))):
102 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
103 | m.append(nn.PixelShuffle(2))
104 | elif scale == 3:
105 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
106 | m.append(nn.PixelShuffle(3))
107 | else:
108 | raise ValueError(f'scale {scale} is not supported. '
109 | 'Supported scales: 2^n and 3.')
110 | super(Upsample, self).__init__(*m)
111 |
112 |
113 | def flow_warp(x,
114 | flow,
115 | interp_mode='bilinear',
116 | padding_mode='zeros',
117 | align_corners=True):
118 | """Warp an image or feature map with optical flow.
119 | Args:
120 | x (Tensor): Tensor with size (n, c, h, w).
121 | flow (Tensor): Tensor with size (n, h, w, 2), normal value.
122 | interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
123 | padding_mode (str): 'zeros' or 'border' or 'reflection'.
124 | Default: 'zeros'.
125 | align_corners (bool): Before pytorch 1.3, the default value is
126 | align_corners=True. After pytorch 1.3, the default value is
127 | align_corners=False. Here, we use the True as default.
128 | Returns:
129 | Tensor: Warped image or feature map.
130 | """
131 | assert x.size()[-2:] == flow.size()[1:3]
132 | _, _, h, w = x.size()
133 | # create mesh grid
134 | grid_y, grid_x = torch.meshgrid(
135 | torch.arange(0, h).type_as(x),
136 | torch.arange(0, w).type_as(x))
137 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
138 | grid.requires_grad = False
139 |
140 | vgrid = grid + flow
141 | # scale grid to [-1,1]
142 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
143 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
144 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
145 | output = F.grid_sample(
146 | x,
147 | vgrid_scaled,
148 | mode=interp_mode,
149 | padding_mode=padding_mode,
150 | align_corners=align_corners)
151 |
152 | # TODO, what if align_corners=False
153 | return output
154 |
155 |
156 | def resize_flow(flow,
157 | size_type,
158 | sizes,
159 | interp_mode='bilinear',
160 | align_corners=False):
161 | """Resize a flow according to ratio or shape.
162 | Args:
163 | flow (Tensor): Precomputed flow. shape [N, 2, H, W].
164 | size_type (str): 'ratio' or 'shape'.
165 | sizes (list[int | float]): the ratio for resizing or the final output
166 | shape.
167 | 1) The order of ratio should be [ratio_h, ratio_w]. For
168 | downsampling, the ratio should be smaller than 1.0 (i.e., ratio
169 | < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
170 | ratio > 1.0).
171 | 2) The order of output_size should be [out_h, out_w].
172 | interp_mode (str): The mode of interpolation for resizing.
173 | Default: 'bilinear'.
174 | align_corners (bool): Whether align corners. Default: False.
175 | Returns:
176 | Tensor: Resized flow.
177 | """
178 | _, _, flow_h, flow_w = flow.size()
179 | if size_type == 'ratio':
180 | output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
181 | elif size_type == 'shape':
182 | output_h, output_w = sizes[0], sizes[1]
183 | else:
184 | raise ValueError(
185 | f'Size type should be ratio or shape, but got type {size_type}.')
186 |
187 | input_flow = flow.clone()
188 | ratio_h = output_h / flow_h
189 | ratio_w = output_w / flow_w
190 | input_flow[:, 0, :, :] *= ratio_w
191 | input_flow[:, 1, :, :] *= ratio_h
192 | resized_flow = F.interpolate(
193 | input=input_flow,
194 | size=(output_h, output_w),
195 | mode=interp_mode,
196 | align_corners=align_corners)
197 | return resized_flow
198 |
199 |
200 | # TODO: may write a cpp file
201 | def pixel_unshuffle(x, scale):
202 | """ Pixel unshuffle.
203 | Args:
204 | x (Tensor): Input feature with shape (b, c, hh, hw).
205 | scale (int): Downsample ratio.
206 | Returns:
207 | Tensor: the pixel unshuffled feature.
208 | """
209 | b, c, hh, hw = x.size()
210 | out_channel = c * (scale**2)
211 | assert hh % scale == 0 and hw % scale == 0
212 | h = hh // scale
213 | w = hw // scale
214 | x_view = x.view(b, c, h, scale, w, scale)
215 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
216 |
217 |
218 | # class DCNv2Pack(ModulatedDeformConvPack):
219 | # """Modulated deformable conv for deformable alignment.
220 | #
221 | # Different from the official DCNv2Pack, which generates offsets and masks
222 | # from the preceding features, this DCNv2Pack takes another different
223 | # features to generate offsets and masks.
224 | #
225 | # Ref:
226 | # Delving Deep into Deformable Alignment in Video Super-Resolution.
227 | # """
228 | #
229 | # def forward(self, x, feat):
230 | # out = self.conv_offset(feat)
231 | # o1, o2, mask = torch.chunk(out, 3, dim=1)
232 | # offset = torch.cat((o1, o2), dim=1)
233 | # mask = torch.sigmoid(mask)
234 | #
235 | # offset_absmean = torch.mean(torch.abs(offset))
236 | # if offset_absmean > 50:
237 | # logger = get_root_logger()
238 | # logger.warning(
239 | # f'Offset abs mean is {offset_absmean}, larger than 50.')
240 | #
241 | # return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
242 | # self.stride, self.padding, self.dilation,
243 | # self.groups, self.deformable_groups)
244 |
245 |
246 | class LayerNormFunction(torch.autograd.Function):
247 |
248 | @staticmethod
249 | def forward(ctx, x, weight, bias, eps):
250 | ctx.eps = eps
251 | N, C, H, W = x.size()
252 | mu = x.mean(1, keepdim=True)
253 | var = (x - mu).pow(2).mean(1, keepdim=True)
254 | y = (x - mu) / (var + eps).sqrt()
255 | ctx.save_for_backward(y, var, weight)
256 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
257 | return y
258 |
259 | @staticmethod
260 | def backward(ctx, grad_output):
261 | eps = ctx.eps
262 |
263 | N, C, H, W = grad_output.size()
264 | y, var, weight = ctx.saved_variables
265 | g = grad_output * weight.view(1, C, 1, 1)
266 | mean_g = g.mean(dim=1, keepdim=True)
267 |
268 | mean_gy = (g * y).mean(dim=1, keepdim=True)
269 | gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
270 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
271 | dim=0), None
272 |
273 | class LayerNorm2d(nn.Module):
274 |
275 | def __init__(self, channels, eps=1e-6):
276 | super(LayerNorm2d, self).__init__()
277 | self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
278 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
279 | self.eps = eps
280 |
281 | def forward(self, x):
282 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
283 |
284 | # handle multiple input
285 | class MySequential(nn.Sequential):
286 | def forward(self, *inputs):
287 | for module in self._modules.values():
288 | if type(inputs) == tuple:
289 | inputs = module(*inputs)
290 | else:
291 | inputs = module(inputs)
292 | return inputs
293 |
294 | import time
295 | def measure_inference_speed(model, data, max_iter=200, log_interval=50):
296 | model.eval()
297 |
298 | # the first several iterations may be very slow so skip them
299 | num_warmup = 5
300 | pure_inf_time = 0
301 | fps = 0
302 |
303 |     # benchmark over max_iter iterations (excluding the warmup runs) and average
304 | for i in range(max_iter):
305 |
306 | torch.cuda.synchronize()
307 | start_time = time.perf_counter()
308 |
309 | with torch.no_grad():
310 | model(*data)
311 |
312 | torch.cuda.synchronize()
313 | elapsed = time.perf_counter() - start_time
314 |
315 | if i >= num_warmup:
316 | pure_inf_time += elapsed
317 | if (i + 1) % log_interval == 0:
318 | fps = (i + 1 - num_warmup) / pure_inf_time
319 | print(
320 | f'Done image [{i + 1:<3}/ {max_iter}], '
321 | f'fps: {fps:.1f} img / s, '
322 | f'times per image: {1000 / fps:.1f} ms / img',
323 | flush=True)
324 |
325 | if (i + 1) == max_iter:
326 | fps = (i + 1 - num_warmup) / pure_inf_time
327 | print(
328 | f'Overall fps: {fps:.1f} img / s, '
329 | f'times per image: {1000 / fps:.1f} ms / img',
330 | flush=True)
331 | break
332 |     return fps
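Editor's note: a quick equivalence check for the LayerNorm2d defined above; it should agree with F.layer_norm taken over the channel dimension (a verification sketch run inside this module, not part of the original file):

```
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16, 16)
ln = LayerNorm2d(8)  # defined above in this module
ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, ln.eps).permute(0, 3, 1, 2)
print(torch.allclose(ln(x), ref, atol=1e-5))  # expected: True
```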
--------------------------------------------------------------------------------
/Derain/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class MixUp_AUG:
4 | def __init__(self):
5 | self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))
6 |
7 | def aug(self, rgb_gt, rgb_noisy):
8 | bs = rgb_gt.size(0)
9 | indices = torch.randperm(bs)
10 | rgb_gt2 = rgb_gt[indices]
11 | rgb_noisy2 = rgb_noisy[indices]
12 |
13 | lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
14 |
15 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
16 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
17 |
18 | return rgb_gt, rgb_noisy
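Editor's note: a usage sketch with hypothetical tensors (the `.cuda()` call inside `aug()` means a GPU is required):

```
import torch

aug = MixUp_AUG()
rgb_gt = torch.rand(4, 3, 64, 64).cuda()
rgb_noisy = torch.rand(4, 3, 64, 64).cuda()
# the same Beta(0.6, 0.6) weight mixes target and input, keeping pairs aligned
gt_mix, noisy_mix = aug.aug(rgb_gt, rgb_noisy)
print(gt_mix.shape)  # torch.Size([4, 3, 64, 64])
```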
--------------------------------------------------------------------------------
/Derain/utils/dir_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from natsort import natsorted
3 | from glob import glob
4 |
5 | def mkdirs(paths):
6 | if isinstance(paths, list) and not isinstance(paths, str):
7 | for path in paths:
8 | mkdir(path)
9 | else:
10 | mkdir(paths)
11 |
12 | def mkdir(path):
13 | if not os.path.exists(path):
14 | os.makedirs(path)
15 |
16 | def get_last_path(path, session):
17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
18 | return x
--------------------------------------------------------------------------------
/Derain/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 | import subprocess
4 | import torch
5 | import torch.distributed as dist
6 | import torch.multiprocessing as mp
7 |
8 |
9 | def init_dist(launcher, backend='nccl', **kwargs):
10 | if mp.get_start_method(allow_none=True) is None:
11 | mp.set_start_method('spawn')
12 | if launcher == 'pytorch':
13 | _init_dist_pytorch(backend, **kwargs)
14 | elif launcher == 'slurm':
15 | _init_dist_slurm(backend, **kwargs)
16 | else:
17 | raise ValueError(f'Invalid launcher type: {launcher}')
18 |
19 |
20 | def _init_dist_pytorch(backend, **kwargs):
21 | rank = int(os.environ['RANK'])
22 | num_gpus = torch.cuda.device_count()
23 | torch.cuda.set_device(rank % num_gpus)
24 | dist.init_process_group(backend=backend, **kwargs)
25 |
26 |
27 | def _init_dist_slurm(backend, port=None):
28 | """Initialize slurm distributed training environment.
29 | If argument ``port`` is not specified, then the master port will be system
30 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
31 | environment variable, then a default port ``29500`` will be used.
32 | Args:
33 | backend (str): Backend of torch.distributed.
34 | port (int, optional): Master port. Defaults to None.
35 | """
36 | proc_id = int(os.environ['SLURM_PROCID'])
37 | ntasks = int(os.environ['SLURM_NTASKS'])
38 | node_list = os.environ['SLURM_NODELIST']
39 | num_gpus = torch.cuda.device_count()
40 | torch.cuda.set_device(proc_id % num_gpus)
41 | addr = subprocess.getoutput(
42 | f'scontrol show hostname {node_list} | head -n1')
43 | # specify master port
44 | if port is not None:
45 | os.environ['MASTER_PORT'] = str(port)
46 | elif 'MASTER_PORT' in os.environ:
47 | pass # use MASTER_PORT in the environment variable
48 | else:
49 | # 29500 is torch.distributed default port
50 | os.environ['MASTER_PORT'] = '29500'
51 | os.environ['MASTER_ADDR'] = addr
52 | os.environ['WORLD_SIZE'] = str(ntasks)
53 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
54 | os.environ['RANK'] = str(proc_id)
55 | dist.init_process_group(backend=backend)
56 |
57 |
58 | def get_dist_info():
59 | if dist.is_available():
60 | initialized = dist.is_initialized()
61 | else:
62 | initialized = False
63 | if initialized:
64 | rank = dist.get_rank()
65 | world_size = dist.get_world_size()
66 | else:
67 | rank = 0
68 | world_size = 1
69 | return rank, world_size
70 |
71 |
72 | def master_only(func):
73 |
74 | @functools.wraps(func)
75 | def wrapper(*args, **kwargs):
76 | rank, _ = get_dist_info()
77 | if rank == 0:
78 | return func(*args, **kwargs)
79 |
80 |     return wrapper
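Editor's note: a small sketch of the decorator; non-zero ranks silently skip the call, so side effects such as logging or checkpointing happen exactly once per job (`save_snapshot` is a hypothetical helper, not part of the repo):

```
@master_only
def save_snapshot(msg):
    print(msg)

save_snapshot('runs once, on rank 0 only')  # other ranks return None without executing
```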
--------------------------------------------------------------------------------
/Derain/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 |
5 | def torchPSNR(tar_img, prd_img):
6 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
7 | rmse = (imdff**2).mean().sqrt()
8 | ps = 20*torch.log10(1/rmse)
9 | return ps
10 |
11 | def save_img(filepath, img):
12 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
13 |
14 | def numpyPSNR(tar_img, prd_img):
15 | imdff = np.float32(prd_img) - np.float32(tar_img)
16 | rmse = np.sqrt(np.mean(imdff**2))
17 | ps = 20*np.log10(255/rmse)
18 | return ps
19 |
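Editor's note: a worked example of torchPSNR. A uniform error of 0.1 on [0, 1]-range images gives RMSE = 0.1, so PSNR = 20 * log10(1 / 0.1) = 20 dB:

```
import torch

tar = torch.zeros(1, 3, 8, 8)
prd = torch.full_like(tar, 0.1)
print(torchPSNR(tar, prd))  # approximately tensor(20.)
```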
--------------------------------------------------------------------------------
/Derain/utils/logger.py:
--------------------------------------------------------------------------------
1 |
2 | import datetime
3 | import logging
4 | import time
5 |
6 | from .dist_util import get_dist_info, master_only
7 |
8 |
9 | class MessageLogger():
10 | """Message logger for printing.
11 | Args:
12 | opt (dict): Config. It contains the following keys:
13 | name (str): Exp name.
14 | logger (dict): Contains 'print_freq' (str) for logger interval.
15 | train (dict): Contains 'total_iter' (int) for total iters.
16 | use_tb_logger (bool): Use tensorboard logger.
17 | start_iter (int): Start iter. Default: 1.
18 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
19 | """
20 |
21 | def __init__(self, opt, start_iter=1, tb_logger=None):
22 | self.exp_name = opt['name']
23 | self.interval = opt['logger']['print_freq']
24 | self.start_iter = start_iter
25 | self.max_iters = opt['train']['total_iter']
26 | self.use_tb_logger = opt['logger']['use_tb_logger']
27 | self.tb_logger = tb_logger
28 | self.start_time = time.time()
29 | self.logger = get_root_logger()
30 |
31 | @master_only
32 | def __call__(self, log_vars):
33 | """Format logging message.
34 | Args:
35 | log_vars (dict): It contains the following keys:
36 | epoch (int): Epoch number.
37 | iter (int): Current iter.
38 | lrs (list): List for learning rates.
39 | time (float): Iter time.
40 | data_time (float): Data time for each iter.
41 | """
42 | # epoch, iter, learning rates
43 | epoch = log_vars.pop('epoch')
44 | current_iter = log_vars.pop('iter')
45 | total_iter = log_vars.pop('total_iter')
46 | lrs = log_vars.pop('lrs')
47 |
48 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
49 | f'iter:{current_iter:8,d}, lr:(')
50 | for v in lrs:
51 | message += f'{v:.3e},'
52 | message += ')] '
53 |
54 | # time and estimated time
55 | if 'time' in log_vars.keys():
56 | iter_time = log_vars.pop('time')
57 | data_time = log_vars.pop('data_time')
58 |
59 | total_time = time.time() - self.start_time
60 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
61 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
62 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
63 | message += f'[eta: {eta_str}, '
64 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
65 |
66 | # other items, especially losses
67 | for k, v in log_vars.items():
68 | message += f'{k}: {v:.4e} '
69 | # tensorboard logger
70 | if self.use_tb_logger and 'debug' not in self.exp_name:
71 | normed_step = 10000 * (current_iter / total_iter)
72 | normed_step = int(normed_step)
73 |
74 | if k.startswith('l_'):
75 | self.tb_logger.add_scalar(f'losses/{k}', v, normed_step)
76 | elif k.startswith('m_'):
77 | self.tb_logger.add_scalar(f'metrics/{k}', v, normed_step)
78 | else:
79 |                     raise ValueError(f'unexpected log variable key: {k}')
80 | # else:
81 | # self.tb_logger.add_scalar(k, v, current_iter)
82 | self.logger.info(message)
83 |
84 |
85 | @master_only
86 | def init_tb_logger(log_dir):
87 | from torch.utils.tensorboard import SummaryWriter
88 | tb_logger = SummaryWriter(log_dir=log_dir)
89 | return tb_logger
90 |
91 |
92 | @master_only
93 | def init_wandb_logger(opt):
94 | """We now only use wandb to sync tensorboard log."""
95 | import wandb
96 | logger = logging.getLogger('basicsr')
97 |
98 | project = opt['logger']['wandb']['project']
99 | resume_id = opt['logger']['wandb'].get('resume_id')
100 | if resume_id:
101 | wandb_id = resume_id
102 | resume = 'allow'
103 | logger.warning(f'Resume wandb logger with id={wandb_id}.')
104 | else:
105 | wandb_id = wandb.util.generate_id()
106 | resume = 'never'
107 |
108 | wandb.init(
109 | id=wandb_id,
110 | resume=resume,
111 | name=opt['name'],
112 | config=opt,
113 | project=project,
114 | sync_tensorboard=True)
115 |
116 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
117 |
118 |
119 | def get_root_logger(logger_name='basicsr',
120 | log_level=logging.INFO,
121 | log_file=None):
122 | """Get the root logger.
123 | The logger will be initialized if it has not been initialized. By default a
124 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
125 | also be added.
126 | Args:
127 | logger_name (str): root logger name. Default: 'basicsr'.
128 | log_file (str | None): The log filename. If specified, a FileHandler
129 | will be added to the root logger.
130 | log_level (int): The root logger level. Note that only the process of
131 | rank 0 is affected, while other processes will set the level to
132 | "Error" and be silent most of the time.
133 | Returns:
134 | logging.Logger: The root logger.
135 | """
136 | logger = logging.getLogger(logger_name)
137 | # if the logger has been initialized, just return it
138 | if logger.hasHandlers():
139 | return logger
140 |
141 | format_str = '%(asctime)s %(levelname)s: %(message)s'
142 | logging.basicConfig(format=format_str, level=log_level)
143 | rank, _ = get_dist_info()
144 | if rank != 0:
145 | logger.setLevel('ERROR')
146 | elif log_file is not None:
147 | file_handler = logging.FileHandler(log_file, 'w')
148 | file_handler.setFormatter(logging.Formatter(format_str))
149 | file_handler.setLevel(log_level)
150 | logger.addHandler(file_handler)
151 |
152 | return logger
153 |
154 |
155 | def get_env_info():
156 | """Get environment information.
157 | Currently, only log the software version.
158 | """
159 | import torch
160 | import torchvision
161 |
162 | from basicsr.version import __version__
163 | msg = r"""
164 | ____ _ _____ ____
165 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \
166 | / __ |/ __ `// ___// // ___/\__ \ / /_/ /
167 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
168 | /_____/ \__,_//____//_/ \___//____//_/ |_|
169 | ______ __ __ __ __
170 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
171 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
172 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
173 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
174 | """
175 | msg += ('\nVersion Information: '
176 | f'\n\tBasicSR: {__version__}'
177 | f'\n\tPyTorch: {torch.__version__}'
178 | f'\n\tTorchVision: {torchvision.__version__}')
179 |     return msg
180 |
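Editor's note: a usage sketch of get_root_logger; the logger is configured once and later calls return the same instance (the log-file name here is hypothetical):

```
logger = get_root_logger(log_file='train.log')
logger.info('training started')
same = get_root_logger()  # returns the already-configured logger
```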
--------------------------------------------------------------------------------
/Derain/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from collections import OrderedDict
4 |
5 | def freeze(model):
6 | for p in model.parameters():
7 | p.requires_grad=False
8 |
9 | def unfreeze(model):
10 | for p in model.parameters():
11 | p.requires_grad=True
12 |
13 | def is_frozen(model):
14 | x = [p.requires_grad for p in model.parameters()]
15 | return not all(x)
16 |
17 | def save_checkpoint(model_dir, state, session):
18 | epoch = state['epoch']
19 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
20 | torch.save(state, model_out_path)
21 |
22 | def load_checkpoint(model, weights):
23 | checkpoint = torch.load(weights)
24 | try:
25 | model.load_state_dict(checkpoint["state_dict"])
26 |     except Exception:  # checkpoint saved from nn.DataParallel: keys carry a 'module.' prefix
27 | state_dict = checkpoint["state_dict"]
28 | new_state_dict = OrderedDict()
29 | for k, v in state_dict.items():
30 | name = k[7:] # remove `module.`
31 | new_state_dict[name] = v
32 | model.load_state_dict(new_state_dict)
33 |
34 |
35 | def load_checkpoint_multigpu(model, weights):
36 | checkpoint = torch.load(weights)
37 | state_dict = checkpoint["state_dict"]
38 | new_state_dict = OrderedDict()
39 | for k, v in state_dict.items():
40 | name = k[7:] # remove `module.`
41 | new_state_dict[name] = v
42 | model.load_state_dict(new_state_dict)
43 |
44 | def load_start_epoch(weights):
45 | checkpoint = torch.load(weights)
46 | epoch = checkpoint["epoch"]
47 | return epoch
48 |
49 | def load_optim(optimizer, weights):
50 | checkpoint = torch.load(weights)
51 | optimizer.load_state_dict(checkpoint['optimizer'])
52 | # for p in optimizer.param_groups: lr = p['lr']
53 | # return lr
54 |
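Editor's note: a round-trip sketch of these helpers with a hypothetical tiny model and a temporary ./tmp_ckpt directory:

```
import os
import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, 3)
opt = torch.optim.Adam(model.parameters())
os.makedirs('./tmp_ckpt', exist_ok=True)
save_checkpoint('./tmp_ckpt',
                {'epoch': 3, 'state_dict': model.state_dict(), 'optimizer': opt.state_dict()},
                'MHNet')
load_checkpoint(model, './tmp_ckpt/model_epoch_3_MHNet.pth')
print(load_start_epoch('./tmp_ckpt/model_epoch_3_MHNet.pth'))  # 3
```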
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ## ACADEMIC PUBLIC LICENSE
2 |
3 | ### Permissions
4 | :heavy_check_mark: Non-Commercial use
5 | :heavy_check_mark: Modification
6 | :heavy_check_mark: Distribution
7 | :heavy_check_mark: Private use
8 |
9 | ### Limitations
10 | :x: Commercial Use
11 | :x: Liability
12 | :x: Warranty
13 |
14 | ### Conditions
15 | :information_source: License and copyright notice
16 | :information_source: Same License
17 |
18 | MHNet is free for use in noncommercial settings: at academic institutions for teaching and research use, and at non-profit research organizations.
19 | You can use MHNet in your research, academic work, non-commercial work, projects and personal work. We only ask you to credit us appropriately.
20 |
21 | You have the right to use the software, to distribute copies, to receive source code, to change the software and distribute your modifications or the modified software.
22 | If you distribute verbatim or modified copies of this software, they must be distributed under this license.
23 | This license guarantees that you're safe when using MHNet in your work, for teaching or research.
24 | This license guarantees that MHNet will remain available free of charge for nonprofit use.
25 | You can modify MHNet to your purposes, and you can also share your modifications.
26 |
27 | If you would like to use MHNet in commercial settings, contact us so we can discuss options. Send an email to two_bits@163.com.
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/MHNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils.arch_utils import LayerNorm2d
5 | from einops import rearrange
6 |
7 |
8 | class UpDSample(nn.Module):
9 | def __init__(self, in_channels):
10 | super(UpDSample, self).__init__()
11 | self.up = nn.Sequential(
12 | nn.Conv2d(in_channels, in_channels * 2, 1, bias=False),
13 | nn.PixelShuffle(2)
14 | )
15 |
16 | def forward(self, x):
17 | x = self.up(x)
18 | return x
19 |
20 |
21 |
22 |
23 | class SimpleGate(nn.Module):
24 | def forward(self, x):
25 | x1, x2 = x.chunk(2, dim=1)
26 | return x1 * x2
27 |
28 | class Attention(nn.Module):
29 | def __init__(self, dim, num_heads, bias):
30 | super(Attention, self).__init__()
31 | self.dim = dim
32 | self.num_heads = num_heads
33 | self.bias = bias
34 |
35 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
36 |
37 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
38 | self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
39 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
40 | self.attn_drop = nn.Dropout(0.)
41 |
42 | self.attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
43 | self.attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
44 | self.attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
45 | self.attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
46 |
47 | def forward(self, x):
48 | b, c, h, w = x.shape
49 |
50 | qkv = self.qkv_dwconv(self.qkv(x))
51 | q, k, v = qkv.chunk(3, dim=1)
52 |
53 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
54 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
55 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
56 |
57 | q = torch.nn.functional.normalize(q, dim=-1)
58 | k = torch.nn.functional.normalize(k, dim=-1)
59 |
60 | _, _, C, _ = q.shape
61 |
62 | mask1 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
63 | mask2 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
64 | mask3 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
65 | mask4 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
66 |
67 | attn = (q @ k.transpose(-2, -1)) * self.temperature
68 |
69 | index = torch.topk(attn, k=int(C/2), dim=-1, largest=True)[1]
70 | mask1.scatter_(-1, index, 1.)
71 | attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf')))
72 |
73 | index = torch.topk(attn, k=int(C*2/3), dim=-1, largest=True)[1]
74 | mask2.scatter_(-1, index, 1.)
75 | attn2 = torch.where(mask2 > 0, attn, torch.full_like(attn, float('-inf')))
76 |
77 | index = torch.topk(attn, k=int(C*3/4), dim=-1, largest=True)[1]
78 | mask3.scatter_(-1, index, 1.)
79 | attn3 = torch.where(mask3 > 0, attn, torch.full_like(attn, float('-inf')))
80 |
81 | index = torch.topk(attn, k=int(C*4/5), dim=-1, largest=True)[1]
82 | mask4.scatter_(-1, index, 1.)
83 | attn4 = torch.where(mask4 > 0, attn, torch.full_like(attn, float('-inf')))
84 |
85 | attn1 = attn1.softmax(dim=-1)
86 | attn2 = attn2.softmax(dim=-1)
87 | attn3 = attn3.softmax(dim=-1)
88 | attn4 = attn4.softmax(dim=-1)
89 |
90 | out1 = (attn1 @ v)
91 | out2 = (attn2 @ v)
92 | out3 = (attn3 @ v)
93 | out4 = (attn4 @ v)
94 |
95 | out = out1 * self.attn1 + out2 * self.attn2 + out3 * self.attn3 + out4 * self.attn4
96 |
97 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
98 |
99 | out = self.project_out(out)
100 | return out
101 |
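Editor's note: a toy sketch (hypothetical 1x1x4x4 scores) of the top-k masked softmax used in the Attention block above: each row keeps only its k strongest keys and renormalizes over them.

```
import torch

attn = torch.randn(1, 1, 4, 4)
idx = torch.topk(attn, k=2, dim=-1, largest=True).indices
mask = torch.zeros_like(attn).scatter_(-1, idx, 1.)
sparse = torch.where(mask > 0, attn, torch.full_like(attn, float('-inf'))).softmax(dim=-1)
print((sparse > 0).sum(dim=-1))  # tensor([[[2, 2, 2, 2]]]) - two active keys per query row
```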
102 |
103 | class BotBlock(nn.Module):
104 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
105 | super().__init__()
106 | dw_channel = c * DW_Expand
107 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1,
108 | bias=True)
109 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1,
110 | groups=dw_channel,
111 | bias=True)
112 | self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
113 | groups=1, bias=True)
114 |
115 | # Simplified Channel Attention
116 | self.sca = nn.Sequential(
117 | nn.AdaptiveAvgPool2d(1),
118 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
119 | groups=1, bias=True),
120 | )
121 |
122 | # SimpleGate
123 | self.sg = SimpleGate()
124 |
125 | ffn_channel = FFN_Expand * c
126 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1,
127 | bias=True)
128 | self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
129 | groups=1, bias=True)
130 |
131 | self.norm1 = LayerNorm2d(c)
132 | self.norm2 = LayerNorm2d(c)
133 |
134 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
135 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
136 |
137 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
138 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
139 |
140 | def forward(self, inp):
141 | x = inp
142 |
143 | x = self.norm1(x)
144 |
145 | x = self.conv1(x)
146 | x = self.conv2(x)
147 | x = self.sg(x)
148 | x = x * self.sca(x)
149 | x = self.conv3(x)
150 |
151 | x = self.dropout1(x)
152 |
153 | y = inp + x * self.beta
154 |
155 | x = self.conv4(self.norm2(y))
156 | x = self.sg(x)
157 | x = self.conv5(x)
158 |
159 | x = self.dropout2(x)
160 |
161 | return y + x * self.gamma
162 |
163 |
164 | class Bottleneck(nn.Module):
165 |
166 | def __init__(self, img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1,1,1,28], dec_blk_nums=[1,1,1,1]):
167 | super().__init__()
168 |
169 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1,
170 | groups=1,
171 | bias=True)
172 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1,
173 | groups=1,
174 | bias=True)
175 |
176 | self.encoders = nn.ModuleList()
177 | self.decoders = nn.ModuleList()
178 | self.middle_blks = nn.ModuleList()
179 | self.ups = nn.ModuleList()
180 | self.downs = nn.ModuleList()
181 |
182 | chan = width
183 | for num in enc_blk_nums:
184 | self.encoders.append(
185 | nn.Sequential(
186 | *[BotBlock(chan) for _ in range(num)]
187 | )
188 | )
189 | self.downs.append(
190 | nn.Conv2d(chan, 2 * chan, 2, 2)
191 | )
192 | chan = chan * 2
193 |
194 | self.middle_blks = \
195 | nn.Sequential(
196 | *[Attention(chan, 8, False) for _ in range(middle_blk_num)]
197 | )
198 |
199 | for num in dec_blk_nums:
200 | self.ups.append(
201 | nn.Sequential(
202 | nn.Conv2d(chan, chan * 2, 1, bias=False),
203 | nn.PixelShuffle(2)
204 | )
205 | )
206 | chan = chan // 2
207 | self.decoders.append(
208 | nn.Sequential(
209 | *[BotBlock(chan) for _ in range(num)]
210 | )
211 | )
212 |
213 | self.padder_size = 2 ** len(self.encoders)
214 |
215 | def forward(self, inp):
216 | B, C, H, W = inp.shape
217 | inp = self.check_image_size(inp)
218 |
219 | x = self.intro(inp)
220 |
221 | encs = []
222 |
223 |
224 | for encoder, down in zip(self.encoders, self.downs):
225 | x = encoder(x)
226 | encs.append(x)
227 | x = down(x)
228 |
229 | x = self.middle_blks(x)
230 | decs = []
231 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
232 | x = up(x)
233 | x = x + enc_skip
234 | x = decoder(x)
235 | decs.append(x)
236 |
237 |         x = self.ending(x)  # projects back to image space; result is unused, only the feature lists are returned
238 |
239 | return encs, decs
240 |
241 | def check_image_size(self, x):
242 | _, _, h, w = x.size()
243 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
244 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
245 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
246 | return x
247 |
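Editor's note: a padding sketch. With 4 encoder stages, padder_size = 2**4 = 16, so check_image_size zero-pads inputs up to the next multiple of 16; the small widths below are hypothetical, chosen only to keep the sketch cheap:

```
import torch

net = Bottleneck(width=4, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1])
x = torch.randn(1, 3, 250, 250)
print(net.check_image_size(x).shape)  # torch.Size([1, 3, 256, 256])
```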
248 |
249 |
250 | class CABG(nn.Module):
251 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
252 | super().__init__()
253 | dw_channel = c * DW_Expand
254 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1,
255 | bias=True)
256 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1,
257 | groups=dw_channel,
258 | bias=True)
259 | self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
260 | groups=1, bias=True)
261 |
262 | # Simplified Channel Attention
263 | self.sca = nn.Sequential(
264 | nn.AdaptiveAvgPool2d(1),
265 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
266 | groups=1, bias=True),
267 | )
268 |
269 | # SimpleGate
270 | self.sg = SimpleGate()
271 |
272 | ffn_channel = FFN_Expand * c
273 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1,
274 | bias=True)
275 | self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
276 | groups=1, bias=True)
277 |
278 | self.norm1 = LayerNorm2d(c)
279 | self.norm2 = LayerNorm2d(c)
280 |
281 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
282 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
283 |
284 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
285 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
286 |
287 | def forward(self, inp):
288 | x = inp
289 |
290 | x = self.norm1(x)
291 |
292 | x = self.conv1(x)
293 | x = self.conv2(x)
294 | x = self.sg(x)
295 | x = x * self.sca(x)
296 | x = self.conv3(x)
297 |
298 | x = self.dropout1(x)
299 |
300 | y = inp + x * self.beta
301 |
302 | x = self.conv4(self.norm2(y))
303 | x = self.sg(x)
304 | x = self.conv5(x)
305 |
306 | x = self.dropout2(x)
307 |
308 | return y + x * self.gamma
309 |
310 | class AFFM(nn.Module):
311 | def __init__(self, in_channels, height=3,reduction=8, bias=False):
312 |         super(AFFM, self).__init__()
313 |
314 | self.height = height
315 | d = max(int(in_channels/reduction),4)
316 |
317 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
318 | self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias),SimpleGate())
319 |
320 | self.fcs = nn.ModuleList([])
321 | for i in range(self.height):
322 | self.fcs.append(nn.Conv2d(d//2, in_channels, kernel_size=1, stride=1,bias=bias))
323 |
324 | self.softmax = nn.Softmax(dim=1)
325 |
326 | def forward(self, f, f_e, f_d):
327 |
328 |
329 | feats_U = f + f_e + f_d
330 | feats_S = self.avg_pool(feats_U)
331 | feats_Z = self.conv_du(feats_S)
332 |
333 | a = self.softmax(self.fcs[0](feats_Z))
334 | a_e = self.softmax(self.fcs[1](feats_Z))
335 | a_d = self.softmax(self.fcs[2](feats_Z))
336 |
337 | return a*f + f_e*a_e + a_d*f_d
338 |
339 | ##########################################################################
340 | class FRSNet(nn.Module):
341 | def __init__(self, width, bias, num):
342 |         super(FRSNet, self).__init__()
343 |
344 |
345 |
346 | self.CABG1 = nn.Sequential(
347 | *[BotBlock(width) for _ in range(num)]
348 | )
349 | self.CABG2 = nn.Sequential(
350 | *[BotBlock(width) for _ in range(num)]
351 | )
352 | self.CABG3 = nn.Sequential(
353 | *[BotBlock(width) for _ in range(num)]
354 | )
355 | self.CABG4 = nn.Sequential(
356 | *[BotBlock(width) for _ in range(num)]
357 | )
358 |
359 | self.up_enc1 = UpDSample( width*2)
360 | self.up_dec1 = UpDSample(width*2)
361 |
362 | self.up_enc2 = nn.Sequential(UpDSample(width*4), UpDSample(width*2))
363 | self.up_dec2 = nn.Sequential(UpDSample(width*4), UpDSample(width*2))
364 |
365 | self.up_enc3 = nn.Sequential(UpDSample(width*8), UpDSample(width*4), UpDSample(width*2))
366 | self.up_dec3 = nn.Sequential(UpDSample(width*8), UpDSample(width*4), UpDSample(width*2))
367 |
368 | self.norm1 = LayerNorm2d(width)
369 | self.norm2 = LayerNorm2d(width)
370 | self.norm3 = LayerNorm2d(width)
371 | self.norm4 = LayerNorm2d(width)
372 |
373 | self.conv_enc1 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
374 | self.conv_enc2 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
375 | self.conv_enc3 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
376 | self.conv_enc4 = nn.Conv2d(width, width , kernel_size=1, bias=bias)
377 |
378 | self.conv_dnc1 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
379 | self.conv_dnc2 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
380 | self.conv_dnc3 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
381 | self.conv_dnc4 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
382 |
383 | self.skff1 = AFFM(width)
384 | self.skff2 = AFFM(width)
385 | self.skff3 = AFFM(width)
386 | self.skff4 = AFFM(width)
387 |
388 | def forward(self, x, encoder_outs, decoder_outs):
389 | x = self.norm1(x)
390 | x = self.CABG1(x) + x
391 | x = self.skff1(x, self.conv_enc1(encoder_outs[0]), self.conv_dnc1(decoder_outs[3]))
392 |
393 | x = self.norm2(x)
394 | x = self.CABG2(x) + x
395 |         x = self.skff2(x, self.conv_enc2(self.up_enc1(encoder_outs[1])), self.conv_dnc2(self.up_dec1(decoder_outs[2])))
396 |
397 | x = self.norm3(x)
398 | x = self.CABG3(x) + x
399 |         x = self.skff3(x, self.conv_enc3(self.up_enc2(encoder_outs[2])), self.conv_dnc3(self.up_dec2(decoder_outs[1])))
400 |
401 | x = self.norm4(x)
402 | x = self.CABG4(x) + x
403 |         x = self.skff4(x, self.conv_enc4(self.up_enc3(encoder_outs[3])), self.conv_dnc4(self.up_dec3(decoder_outs[0])))
404 |
405 | return x
406 |
407 |
408 |
409 | class MHNet(nn.Module):
410 | def __init__(self, in_c=3, out_c=3, width=64, num_cab=8,
411 | bias=False):
412 | super().__init__()
413 |         # gating (SimpleGate) is built into the blocks; no standalone activation module is needed
414 | self.intro = nn.Conv2d(in_channels=in_c, out_channels=width, kernel_size=3, padding=1, stride=1,
415 | groups=1,
416 | bias=bias)
417 | self.intro2 = nn.Conv2d(in_channels=in_c, out_channels=width, kernel_size=3, padding=1, stride=1,
418 | groups=1,
419 | bias=bias)
420 |         self.stage1 = Bottleneck(in_c, width)
421 | self.stage2 = FRSNet(width, bias=bias, num=num_cab)
422 | self.concat12 = nn.Conv2d(width*2, width, kernel_size=1, stride=1, padding=0, bias=bias)
423 | self.ending = nn.Conv2d(in_channels=width, out_channels=out_c, kernel_size=3, padding=1, stride=1,
424 | groups=1,
425 | bias=bias)
426 |
427 |
428 |     def forward(self, x3_img):
429 |         # Stage 1: the encoder-decoder subnetwork processes the full
430 |         # image and returns the multi-scale encoder and decoder
431 |         # feature lists.
432 |         # Note: an unused `x1 = self.intro(x3_img)` branch has been
433 |         # dropped here; self.intro itself is kept in __init__ so that
434 |         # pretrained checkpoints still load.
435 |         x1_en, x1_dn = self.stage1(x3_img)
436 | 
437 |         # Stage 2: FRSNet runs at full resolution, fusing the stage-1
438 |         # encoder/decoder features through the AFFM modules.
439 |         x2 = self.intro2(x3_img)
440 |         t = torch.cat([x2, x1_dn[3]], 1)
441 |
442 | x2_cat = self.concat12(t)
443 | x2_cat = self.stage2(x2_cat, x1_en, x1_dn)
444 |
445 | stage2_img = self.ending(x2_cat)
446 |
447 | return stage2_img + x3_img
448 |
449 | class AvgPool2d(nn.Module):
450 | def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
451 | super().__init__()
452 | self.kernel_size = kernel_size
453 | self.base_size = base_size
454 | self.auto_pad = auto_pad
455 |
456 | # only used for fast implementation
457 | self.fast_imp = fast_imp
458 | self.rs = [5, 4, 3, 2, 1]
459 | self.max_r1 = self.rs[0]
460 | self.max_r2 = self.rs[0]
461 | self.train_size = train_size
462 |
463 | def extra_repr(self) -> str:
464 | return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
465 | self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
466 | )
467 |
468 | def forward(self, x):
469 | if self.kernel_size is None and self.base_size:
470 | train_size = self.train_size
471 | if isinstance(self.base_size, int):
472 | self.base_size = (self.base_size, self.base_size)
473 | self.kernel_size = list(self.base_size)
474 | self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
475 | self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]
476 |
477 | # only used for fast implementation
478 | self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
479 | self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])
480 |
481 | if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
482 | return F.adaptive_avg_pool2d(x, 1)
483 |
484 | if self.fast_imp: # Non-equivalent implementation but faster
485 | h, w = x.shape[2:]
486 | if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
487 | out = F.adaptive_avg_pool2d(x, 1)
488 | else:
489 | r1 = [r for r in self.rs if h % r == 0][0]
490 | r2 = [r for r in self.rs if w % r == 0][0]
491 | # reduction_constraint
492 | r1 = min(self.max_r1, r1)
493 | r2 = min(self.max_r2, r2)
494 | s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
495 | n, c, h, w = s.shape
496 | k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
497 | out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
498 | out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
499 | else:
500 | n, c, h, w = x.shape
501 | s = x.cumsum(dim=-1).cumsum_(dim=-2)
502 | s = torch.nn.functional.pad(s, (1, 0, 1, 0)) # pad 0 for convenience
503 | k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
504 | s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
505 | out = s4 + s1 - s2 - s3
506 | out = out / (k1 * k2)
507 |
508 | if self.auto_pad:
509 | n, c, h, w = x.shape
510 | _h, _w = out.shape[2:]
511 | # print(x.shape, self.kernel_size)
512 | pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
513 | out = torch.nn.functional.pad(out, pad2d, mode='replicate')
514 |
515 | return out
516 |
517 |
518 |
519 |
520 |
521 |
522 | def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
523 | for n, m in model.named_children():
524 | if len(list(m.children())) > 0:
525 | ## compound module, go inside it
526 | replace_layers(m, base_size, train_size, fast_imp, **kwargs)
527 |
528 | if isinstance(m, nn.AdaptiveAvgPool2d):
529 | pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
530 | assert m.output_size == 1
531 | setattr(model, n, pool)
532 |
533 |
534 |
535 |
536 |
537 | class Local_Base():
538 | def convert(self, *args, train_size, **kwargs):
539 | replace_layers(self, *args, train_size=train_size, **kwargs)
540 | imgs = torch.rand(train_size)
541 | with torch.no_grad():
542 | self.forward(imgs)
543 |
544 |
545 | class MHNetLocal(Local_Base, MHNet):
546 | def __init__(self, *args, train_size=(1, 3, 256, 256), base_size=None, fast_imp=False, **kwargs):
547 | Local_Base.__init__(self)
548 | MHNet.__init__(self, *args, **kwargs)
549 | N, C, H, W = train_size
550 | if base_size is None:
551 | base_size = (int(H * 1.5), int(W * 1.5))
552 |
553 | self.eval()
554 | with torch.no_grad():
555 | self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
556 |
557 |
558 | if __name__=='__main__':
559 | x = torch.randn([2, 3, 256, 256])
560 | model = MHNet()
561 | print("Total number of param is ", sum(i.numel() for i in model.parameters()))
562 | t = model(x)
563 | print(t.shape)
564 |
565 |
566 | from thop import profile
567 | x3 = torch.randn((1, 3, 256, 256))
568 | flops, params = profile(model, inputs=(x3, ))
569 | print('FLOPs = ' + str(flops/1000**3) + 'G')
570 | print('Params = ' + str(params/1000**2) + 'M')
571 | from ptflops import get_model_complexity_info
572 | FLOPS = 0
573 | inp_shape=(3,256,256)
574 | macs, params = get_model_complexity_info(model, inp_shape, verbose=False, print_per_layer_stat=True)
575 | #print(params)
576 | macs = float(macs[:-4]) + FLOPS / 10 ** 9
577 |
578 |
579 |
580 | print('mac', macs, params)
581 |
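Editor's note: a sketch of the local (TLC) variant defined above. MHNetLocal swaps every AdaptiveAvgPool2d for the size-aware AvgPool2d so that test-time pooling statistics match the 256x256 training crops; the 512x512 input below is hypothetical.

```
local_model = MHNetLocal(train_size=(1, 3, 256, 256), fast_imp=True)
with torch.no_grad():
    y = local_model(torch.randn(1, 3, 512, 512))
print(y.shape)  # torch.Size([1, 3, 512, 512])
```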
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Mixed Hierarchy Network for Image Restoration
4 |
5 | [](http://arxiv.org/abs/2302.09554)
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
![]() |
26 |
Overall Framework of MHNet |
29 |
![]() |
32 |
(a) Encoder-decoder subnetwork. (b) Selective multi-head attention mechanism (SMAM) (c) The architecture of nonlinear activation free block (NAFBlock). (d) Simplified Channel Attention (SCA). |
35 |
![]() |
38 |
(a) The architecture of nonlinear activation free block groups (NAFG). Each NAFG further contains multiple nonlinear activation free blocks (NAFBlocks). (b) Adaptive feature fusion mechanism (AFFM) between an encoder-decoder subnetwork and FRSNet. |
41 |