├── Fig
├── Building.png
├── Car.png
├── Fig0.png
├── Fig1.png
├── Fig2.png
├── Fig3.png
├── Fig4.png
├── Mars.png
└── Road.png
├── README.md
├── Real
├── 044.png
├── 045.png
├── 046.png
├── 047.png
├── 048.png
├── 049.png
├── 050.png
├── 051.png
├── 052.png
├── 053.png
├── 054.png
├── 055.png
├── 056.png
├── 057.png
├── 058.png
├── 059.png
└── 060.png
├── dataset.py
├── demo.py
├── model
├── ASCNet.py
└── cbam.py
├── prepare_patches.py
├── test.py
├── train.py
├── utils.py
└── warmup_scheduler.py
/Fig/Building.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Building.png
--------------------------------------------------------------------------------
/Fig/Car.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Car.png
--------------------------------------------------------------------------------
/Fig/Fig0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Fig0.png
--------------------------------------------------------------------------------
/Fig/Fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Fig1.png
--------------------------------------------------------------------------------
/Fig/Fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Fig2.png
--------------------------------------------------------------------------------
/Fig/Fig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Fig3.png
--------------------------------------------------------------------------------
/Fig/Fig4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Fig4.png
--------------------------------------------------------------------------------
/Fig/Mars.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Mars.png
--------------------------------------------------------------------------------
/Fig/Road.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Road.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # This is the code of paper "ASCNet: Asymmetric Sampling Correction Network for Infrared Image Destriping".[[Paper]](https://ieeexplore.ieee.org/document/10855453) [[Weight]](https://drive.google.com/file/d/1zbBsWUbRVBjNckPg5DiCgKIKOKWnQ2N8/view?usp=sharing)
2 | Shuai Yuan, Hanlin Qin, Xiang Yan, Shiqi Yang, Shuowen Yang, Naveed Akhtar, Huixin Zhou, IEEE Transactions on Geoscience and Remote Sensing 2025.
3 | # Real Destriping Examples
4 |
5 | [
](https://imgsli.com/MjkxNDU2) | [
](https://imgsli.com/MjkxNDU4)
6 | :-------------------------:|:-------------------------:
7 | Mars | Building
8 |
9 |
10 | [
](https://imgsli.com/MjkxNDU5) | [
](https://imgsli.com/MjkxNDYx)
11 | :-------------------------:|:-------------------------:
12 | Road | Car
13 |
14 |
15 | # Chanlleges and inspiration
16 | 
17 |
18 | # Structure
19 | 
20 |
21 | 
22 |
23 |
24 | ## Usage
25 |
26 | #### 1. Dataset
27 | Training dataset: [[Data]](https://drive.google.com/file/d/1o9BmWspPTJtFsBj66NN3FfM83cjp37IW/view?usp=sharing)
28 |
29 | Training dataset augmentation: [[Data_AUG]](https://drive.google.com/file/d/1Iv4CoQiInFORYn1kHjJCCCeuy6LKvnIc/view?usp=sharing)
30 |
31 | Validation dataset: [[clean]](https://drive.google.com/file/d/1WYYZCoEooOXDG49YJXJiNkCtVFgGdx2J/view?usp=sharing), [[dirty]](https://drive.google.com/file/d/1D1NAyMLbso_UL-YRqYfPduFR-Zs8g2sx/view?usp=sharing)
32 |
33 | ##### 2. Train.
34 | ```bash
35 | python train.py
36 | ```
37 |
38 | #### 3. Test and demo. [[Weight]](https://drive.google.com/file/d/1zbBsWUbRVBjNckPg5DiCgKIKOKWnQ2N8/view?usp=sharing)
39 | ```bash
40 | python test.py
41 | ```
42 | If the implementation of this repo is helpful to you, just star it!⭐⭐⭐
43 |
44 | If you find the code useful, please consider citing our paper using the following BibTeX entry.
45 |
46 | ```
47 | @ARTICLE{10855453,
48 | author={Yuan, Shuai and Qin, Hanlin and Yan, Xiang and Yang, Shiqi and Yang, Shuowen and Akhtar, Naveed and Zhou, Huixin},
49 | journal={IEEE Transactions on Geoscience and Remote Sensing},
50 | title={ASCNet: Asymmetric Sampling Correction Network for Infrared Image Destriping},
51 | year={2025},
52 | volume={63},
53 | number={},
54 | pages={1-15},
55 | keywords={Noise;Discrete wavelet transforms;Semantics;Image reconstruction;Feature extraction;Neural networks;Filters;Crosstalk;Aggregates;Geoscience and remote sensing;Asymmetric sampling (AS);column correction;deep learning;infrared (IR) image destriping;wavelet transform},
56 | doi={10.1109/TGRS.2025.3534838}}
57 | ```
58 |
59 | ## Contact
60 | **Welcome to raise issues or email to [yuansy@stu.xidian.edu.cn](yuansy@stu.xidian.edu.cn) for any question regarding our ASCNet.**
61 |
--------------------------------------------------------------------------------
/Real/044.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/044.png
--------------------------------------------------------------------------------
/Real/045.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/045.png
--------------------------------------------------------------------------------
/Real/046.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/046.png
--------------------------------------------------------------------------------
/Real/047.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/047.png
--------------------------------------------------------------------------------
/Real/048.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/048.png
--------------------------------------------------------------------------------
/Real/049.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/049.png
--------------------------------------------------------------------------------
/Real/050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/050.png
--------------------------------------------------------------------------------
/Real/051.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/051.png
--------------------------------------------------------------------------------
/Real/052.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/052.png
--------------------------------------------------------------------------------
/Real/053.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/053.png
--------------------------------------------------------------------------------
/Real/054.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/054.png
--------------------------------------------------------------------------------
/Real/055.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/055.png
--------------------------------------------------------------------------------
/Real/056.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/056.png
--------------------------------------------------------------------------------
/Real/057.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/057.png
--------------------------------------------------------------------------------
/Real/058.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/058.png
--------------------------------------------------------------------------------
/Real/059.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/059.png
--------------------------------------------------------------------------------
/Real/060.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/060.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset related functions
3 |
4 | Copyright (C) 2018, Matias Tassano
5 |
6 | This program is free software: you can use, modify and/or
7 | redistribute it under the terms of the GNU General Public
8 | License as published by the Free Software Foundation, either
9 | version 3 of the License, or (at your option) any later
10 | version. You should have received a copy of this license along
11 | this program. If not, see .
12 | """
13 | import os
14 | import os.path
15 | import random
16 | import glob
17 | import numpy as np
18 | import cv2
19 | import h5py
20 | import torch
21 | import torch.utils.data as udata
22 | from PIL import Image
23 |
24 | from utils import data_augmentation, normalize
25 | class Dataset(udata.Dataset):
26 | r"""Implements torch.utils.data.Dataset
27 | """
28 |
29 | def __init__(self, train=True, gray_mode=False, shuffle=False):
30 | super(Dataset, self).__init__()
31 | self.train = train
32 | self.gray_mode = gray_mode
33 | if not self.gray_mode:
34 | self.traindbf = 'train_rgb.h5'
35 | self.valdbf = 'val_rgb.h5'
36 | self.valdirtydbf = 'val_dirty_rgb.h5'
37 | else:
38 | self.traindbf = 'train_gray.h5'
39 | self.valdbf = 'val_gray.h5'
40 | self.valdirtydbf = 'val_dirty_gray.h5'
41 |
42 | if self.train:
43 | h5f = h5py.File(self.traindbf, 'r')
44 | self.keys = list(h5f.keys())
45 | if shuffle:
46 | random.shuffle(self.keys)
47 | h5f.close()
48 | else:
49 | h5f = h5py.File(self.valdbf, 'r')
50 | h5f_dirty = h5py.File(self.valdirtydbf, 'r')
51 | self.keys = list(h5f.keys())
52 | if shuffle:
53 | random.shuffle(self.keys)
54 | h5f.close()
55 | h5f_dirty.close()
56 |
57 | def __len__(self):
58 | return len(self.keys)
59 |
60 | def __getitem__(self, index):
61 | # 从 计算机的具体路径下 读图片 转化为 pytroch框架可以认识的形式
62 | # pytroch: tensor张量
63 |
64 | if self.train:
65 | h5f = h5py.File(self.traindbf, 'r')
66 | key = self.keys[index]
67 | data = np.array(h5f[key])
68 | h5f.close()
69 | return torch.Tensor(data)
70 | else:
71 | h5f = h5py.File(self.valdbf, 'r')
72 | h5f_dirty = h5py.File(self.valdirtydbf, 'r')
73 | key = self.keys[index]
74 | data_clean = np.array(h5f[key])
75 | data_dirty = np.array(h5f_dirty[key])
76 | h5f.close()
77 | h5f_dirty.close()
78 | return torch.Tensor(data_clean), torch.Tensor(data_dirty)
79 |
80 |
81 | def img_to_patches(img, win, stride=1):
82 | r"""Converts an image to an array of patches.
83 |
84 | Args:
85 | img: a numpy array containing a CxHxW RGB (C=3) or grayscale (C=1)
86 | image
87 | win: size of the output patches
88 | stride: int. stride
89 | """
90 | k = 0
91 | endc = img.shape[0]
92 | endw = img.shape[1]
93 | endh = img.shape[2]
94 | patch = img[:, 0:endw - win + 0 + 1:stride, 0:endh - win + 0 + 1:stride]
95 | total_pat_num = patch.shape[1] * patch.shape[2]
96 | res = np.zeros([endc, win * win, total_pat_num], np.float32)
97 | for i in range(win):
98 | for j in range(win):
99 | patch = img[:, i:endw - win + i + 1:stride, j:endh - win + j + 1:stride]
100 | res[:, k, :] = np.array(patch[:]).reshape(endc, total_pat_num)
101 | k = k + 1
102 | return res.reshape([endc, win, win, total_pat_num])
103 |
104 | def prepare_data(data_path, \
105 | val_data_path, \
106 | val_data_dirty_path, \
107 | patch_size, \
108 | stride, \
109 | max_num_patches=None, \
110 | aug_times=1, \
111 | gray_mode=True):
112 | r"""Builds the training and validations datasets by scanning the
113 | corresponding directories for images and extracting patches from them.
114 |
115 | Args:
116 | data_path: path containing the training image dataset
117 | val_data_path: path containing the validation image dataset
118 | patch_size: size of the patches to extract from the images
119 | stride: size of stride to extract patches
120 | stride: size of stride to extract patches
121 | max_num_patches: maximum number of patches to extract
122 | aug_times: number of times to augment the available data minus one
123 | gray_mode: build the databases composed of grayscale patches
124 | """
125 | # training database
126 | print('> Training database')
127 | # scales = [1, 0.9, 0.8, 0.7]
128 | scales = [1, 0.8, 0.6, 0.4]
129 | # scales = [1]
130 | types = ('*.bmp', '*.png')
131 | files = []
132 | for tp in types:
133 | files.extend(glob.glob(os.path.join(data_path, tp)))
134 | files.sort()
135 |
136 | if gray_mode:
137 | traindbf = 'train_gray.h5'
138 | valdbf = 'val_gray.h5'
139 | valdirtydbf = 'val_dirty_gray.h5'
140 | else:
141 | traindbf = 'train_rgb.h5'
142 | valdbf = 'val_rgb.h5'
143 | valdirtydbf = 'val_dirty_rgb.h5'
144 |
145 | if max_num_patches is None:
146 | max_num_patches = 5000000
147 | print("\tMaximum number of patches not set")
148 | else:
149 | print("\tMaximum number of patches set to {}".format(max_num_patches))
150 | train_num = 0
151 | i = 0
152 | with h5py.File(traindbf, 'w') as h5f:
153 | while i < len(files) and train_num < max_num_patches:
154 | imgor = cv2.imread(files[i])
155 | # h, w, c = img.shape
156 | for sca in scales:
157 | img = cv2.resize(imgor, (0, 0), fx=sca, fy=sca, \
158 | interpolation=cv2.INTER_CUBIC)
159 | if not gray_mode:
160 | # CxHxW RGB image
161 | img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
162 | else:
163 | # CxHxW grayscale image (C=1)
164 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
165 | img = np.expand_dims(img, 0)
166 | img = normalize(img)
167 | # argument
168 | patches = img_to_patches(img, win=patch_size, stride=stride)
169 | # data = patches[:, :, :, 1]
170 | # data = data * 255
171 | # im = Image.fromarray(data.reshape(64, 64)) # 这是把numpy转化成了PIL
172 | # im.show()
173 |
174 | print("\tfile: %s scale %.1f # samples: %d" % \
175 | (files[i], sca, patches.shape[3] * aug_times))
176 | for nx in range(patches.shape[3]):
177 | data = data_augmentation(patches[:, :, :, nx].copy(), \
178 | np.random.randint(0, 7))
179 | h5f.create_dataset(str(train_num), data=data)
180 | train_num += 1
181 | for mx in range(aug_times - 1):
182 | data_aug = data_augmentation(data, np.random.randint(1, 4))
183 | h5f.create_dataset(str(train_num) + "_aug_%d" % (mx + 1), data=data_aug)
184 | train_num += 1
185 | i += 1
186 |
187 | # validation database
188 | print('\n> Validation database')
189 | files = []
190 | for tp in types:
191 | files.extend(glob.glob(os.path.join(val_data_path, tp)))
192 | files.sort()
193 | h5f = h5py.File(valdbf, 'w')
194 | val_num = 0
195 | for i, item in enumerate(files):
196 | print("\tfile: %s" % item)
197 | img = cv2.imread(item)
198 | if not gray_mode:
199 | # C. H. W, RGB image
200 | img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
201 | else:
202 | # C, H, W grayscale image (C=1)
203 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
204 | img = np.expand_dims(img, 0)
205 | img = normalize(img)
206 | h5f.create_dataset(str(val_num), data=img)
207 | val_num += 1
208 | h5f.close()
209 | print('\n> Validation dirty database')
210 | files_dirty = []
211 | for tp in types:
212 | files_dirty.extend(glob.glob(os.path.join(val_data_dirty_path, tp)))
213 | files_dirty.sort()
214 | h5f = h5py.File(valdirtydbf, 'w')
215 | val_num_dirty = 0
216 | for i, item in enumerate(files_dirty):
217 | print("\tfile: %s" % item)
218 | img = cv2.imread(item)
219 | if not gray_mode:
220 | # C. H. W, RGB image
221 | img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
222 | else:
223 | # C, H, W grayscale image (C=1)
224 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
225 | img = np.expand_dims(img, 0)
226 | img = normalize(img)
227 | h5f.create_dataset(str(val_num_dirty), data=img)
228 | val_num_dirty += 1
229 | h5f.close()
230 |
231 | print('\n> Total')
232 | print('\ttraining set, # samples %d' % train_num)
233 | print('\tvalidation set, # samples %d\n' % val_num)
234 | print('\tvalidation dirty set, # samples %d\n' % val_num_dirty)
235 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | # from SSIM import *
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import pywt
7 | import torch.nn as nn
8 | import argparse
9 | from model.ASCNet import ASCNet
10 | import time
11 |
12 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
13 |
14 | parser = argparse.ArgumentParser(description="Demo")
15 | parser.add_argument("--log_path", type=str,
16 | default=r"XXXX.pth")
17 | parser.add_argument("--filename", type=str, default=r"XXXXX")
18 |
19 | parser.add_argument("--savepth", type=str, default=r"XXXXX",
20 | help='path of result image file')
21 |
22 | parser.add_argument("--mk", type=str, default=r"XXXXX/",
23 | help='path of result image file')
24 |
25 |
26 | opt = parser.parse_args()
27 |
28 | model = ASCNet(1, 1, feats=32)
29 | model = nn.DataParallel(model)
30 | model.load_state_dict(torch.load(opt.log_path,map_location='cpu'))
31 |
32 | namelist = os.listdir(opt.filename)
33 | namelist.sort()
34 |
35 | if os.path.exists(opt.mk):
36 | pass
37 | else:
38 | os.makedirs(opt.mk)
39 |
40 |
41 | # def normalization(data):
42 | # _range = np.max(data) - np.min(data)
43 | # return (data - np.min(data)) / _range
44 |
45 |
46 | model.eval()
47 | for name in namelist:
48 | image = cv2.imread(os.path.join(opt.filename, name))
49 | img_np = np.expand_dims(image[:, :, 0], 0)
50 | img_np = np.float32(img_np / 255.)
51 | img_tensor = torch.from_numpy(img_np)
52 | img_tensor = torch.unsqueeze(img_tensor, 0)
53 | # time_start = time.time()
54 | out = model(img_tensor)
55 | # time_end = time.time()
56 | out_np = out.data.cpu().numpy()
57 | # time_c = time_end - time_start # 运行所花时间
58 | # print('time cost', time_c, 's')
59 | out_val = out_np[0, :, :, :]
60 |
61 | out_val = np.transpose(out_val, (1, 2, 0))
62 |
63 |
64 | # Clamp
65 | out_val = out_val * 255
66 | out_valf = np.clip(out_val, 0, 255)
67 |
68 | # Normalization
69 | # final=normalization(out_val)
70 | # out_valf = final * 255
71 |
72 | cv2.imwrite(os.path.join(opt.savepth, name), out_valf.astype("uint8"))
73 |
--------------------------------------------------------------------------------
/model/ASCNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from pytorch_wavelets import DWTForward, DWTInverse
4 | from torchvision import transforms
5 | from model.cbam import *
6 | import cv2
7 | from utils import weights_init_kaiming
8 | import os
9 | from thop import profile
10 | from thop import clever_format
11 | # from torchvision import transforms
12 | # import matplotlib.pyplot as plt
13 | from torch.autograd import Variable
14 | from torchvision import models
15 | import numpy as np
16 |
17 |
18 |
19 | class DWT(nn.Module):
20 | def __init__(self):
21 | super(DWT, self).__init__()
22 | self.requires_grad = False
23 |
24 | def forward(self, x):
25 | return dwt_init(x)
26 |
27 |
28 | class IWT(nn.Module):
29 | def __init__(self):
30 | super(IWT, self).__init__()
31 | self.requires_grad = False
32 |
33 | def forward(self, x):
34 | return iwt_init(x)
35 |
36 |
37 | # double_conv model
38 | class double_conv(nn.Module):
39 | def __init__(self, in_channels, out_channels):
40 | super(double_conv, self).__init__()
41 | self.d_conv = nn.Sequential(
42 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
43 | nn.LeakyReLU(inplace=True),
44 | nn.Conv2d(out_channels, out_channels, 3, padding=1),
45 | nn.LeakyReLU(inplace=True)
46 | )
47 |
48 | def forward(self, x):
49 | x = self.d_conv(x)
50 | return x
51 |
52 |
53 | class single_conv(nn.Module):
54 | def __init__(self, in_channels, out_channels):
55 | super(single_conv, self).__init__()
56 | self.s_conv = nn.Sequential(
57 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
58 | nn.LeakyReLU(inplace=True),
59 | )
60 |
61 | def forward(self, x):
62 | x = self.s_conv(x)
63 | return x
64 |
65 |
66 | class single_conv_res(nn.Module):
67 | def __init__(self, in_channels, out_channels):
68 | super(single_conv_res, self).__init__()
69 | self.s_conv = nn.Sequential(
70 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
71 | nn.LeakyReLU(inplace=True),
72 | )
73 |
74 | def forward(self, x):
75 | residual = x
76 | x = self.s_conv(x)
77 | out = torch.add(x, residual)
78 | return out
79 |
80 |
81 | class conv11(nn.Module):
82 | def __init__(self, in_channels, out_channels):
83 | super(conv11, self).__init__()
84 | self.s_conv = nn.Conv2d(in_channels, out_channels, 1)
85 |
86 | def forward(self, x):
87 | x = self.s_conv(x)
88 | return x
89 |
90 |
91 | class conv33(nn.Module):
92 | def __init__(self, in_channels, out_channels):
93 | super(conv33, self).__init__()
94 | self.s_conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
95 |
96 | def forward(self, x):
97 | x = self.s_conv(x)
98 | return x
99 |
100 |
101 | class ChannelPool(nn.Module):
102 | def forward(self, x):
103 | # 将maxpooling 与 global average pooling 结果拼接在一起
104 | return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
105 |
106 |
107 | class Basic(nn.Module):
108 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, relu=True, bn=True, bias=False):
109 | super(Basic, self).__init__()
110 | self.out_channels = out_planes
111 | self.conv = nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride,
112 | padding=padding, bias=bias)
113 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
114 | self.relu = nn.LeakyReLU() if relu else None
115 |
116 | def forward(self, x):
117 | x = self.conv(x)
118 | if self.bn is not None:
119 | x = self.bn(x)
120 | if self.relu is not None:
121 | x = self.relu(x)
122 | return x
123 |
124 |
125 | class CALayer(nn.Module):
126 | def __init__(self, channel, reduction=16):
127 | super(CALayer, self).__init__()
128 |
129 | self.avgPoolW = nn.AdaptiveAvgPool2d((1, None))
130 | self.maxPoolW = nn.AdaptiveMaxPool2d((1, None))
131 |
132 |
133 | self.conv_1x1 = nn.Conv2d(in_channels=2 * channel, out_channels=2 * channel, kernel_size=1, padding=0, stride=1,
134 | bias=False)
135 | self.bn = nn.BatchNorm2d(2 * channel, eps=1e-5, momentum=0.01, affine=True)
136 | self.Relu = nn.LeakyReLU()
137 |
138 | self.F_h = nn.Sequential( # 激发操作
139 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
140 | nn.BatchNorm2d(channel // reduction, eps=1e-5, momentum=0.01, affine=True),
141 | nn.ReLU(inplace=True),
142 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
143 | )
144 | self.F_w = nn.Sequential( # 激发操作
145 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
146 | nn.BatchNorm2d(channel // reduction, eps=1e-5, momentum=0.01, affine=True),
147 | nn.ReLU(inplace=True),
148 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
149 | )
150 | self.sigmoid = nn.Sigmoid()
151 |
152 | def forward(self, x):
153 | N, C, H, W = x.size()
154 | res = x
155 | x_cat = torch.cat([self.avgPoolW(x), self.maxPoolW(x)], 1)
156 | x = self.Relu(self.bn(self.conv_1x1(x_cat)))
157 | x_1, x_2 = x.split(C, 1)
158 |
159 | x_1 = self.F_h(x_1)
160 | x_2 = self.F_w(x_2)
161 | s_h = self.sigmoid(x_1)
162 | s_w = self.sigmoid(x_2)
163 |
164 | out = res * s_h.expand_as(res) * s_w.expand_as(res)
165 |
166 | return out
167 |
168 |
169 | class spatial_attn_layer(nn.Module):
170 | def __init__(self, kernel_size=3):
171 | super(spatial_attn_layer, self).__init__()
172 | self.compress = ChannelPool()
173 | self.spatial = Basic(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bn=False, relu=False)
174 |
175 | def forward(self, x):
176 | x_compress = self.compress(x)
177 | x_out = self.spatial(x_compress)
178 | scale = torch.sigmoid(x_out) # broadcasting
179 | return x * scale
180 |
181 |
182 | class Sep(nn.Module):
183 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=1, bias=True):
184 | super().__init__()
185 | self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size, stride, padding, groups=in_channel, bias=bias)
186 | self.conv2 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=bias)
187 |
188 | def forward(self, input):
189 | x = self.conv1(input)
190 | x = self.conv2(x)
191 | return x
192 |
193 |
194 | class RCSSC(nn.Module):
195 | def __init__(self, n_feat, reduction=16):
196 | super(RCSSC, self).__init__()
197 | pooling_r = 4
198 | self.head = nn.Sequential(
199 | nn.Conv2d(in_channels=n_feat, out_channels=n_feat, kernel_size=3, padding=1, stride=1, bias=True),
200 | nn.LeakyReLU(),
201 | )
202 | self.SC = nn.Sequential(
203 | nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),
204 | nn.Conv2d(in_channels=n_feat, out_channels=n_feat, kernel_size=3, padding=1, stride=1, bias=True),
205 | nn.BatchNorm2d(n_feat)
206 | )
207 | self.SA = spatial_attn_layer() ## Spatial Attention
208 | self.CA = CALayer(n_feat, reduction) ## Channel Attention
209 |
210 | self.conv1x1 = nn.Sequential(
211 | nn.Conv2d(n_feat * 2, n_feat, kernel_size=1),
212 | nn.Conv2d(in_channels=n_feat, out_channels=n_feat, kernel_size=3, padding=1, stride=1, bias=True)
213 | )
214 | self.ReLU = nn.LeakyReLU()
215 | self.tail = nn.Conv2d(in_channels=n_feat, out_channels=n_feat, kernel_size=3, padding=1)
216 |
217 | def forward(self, x):
218 | res = x
219 | x = self.head(x)
220 | sa_branch = self.SA(x)
221 | ca_branch = self.CA(x)
222 | x1 = torch.cat([sa_branch, ca_branch], dim=1) # 拼接
223 | x1 = self.conv1x1(x1)
224 | x2 = torch.sigmoid(
225 | torch.add(x, F.interpolate(self.SC(x), x.size()[2:])))
226 | out = torch.mul(x1, x2)
227 | out = self.tail(out)
228 | out = out + res
229 | out = self.ReLU(out)
230 | return out
231 |
232 |
233 |
234 | class _DCR_block(nn.Module):
235 | def __init__(self, channel_in):
236 | super(_DCR_block, self).__init__()
237 | self.conv_1 = nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in / 2.), kernel_size=3, stride=1,
238 | padding=1)
239 | self.relu1 = nn.LeakyReLU()
240 | self.conv_2 = nn.Conv2d(in_channels=int(channel_in * 3 / 2.), out_channels=int(channel_in / 2.), kernel_size=3,
241 | stride=1, padding=1)
242 | self.relu2 = nn.LeakyReLU()
243 | self.conv_3 = nn.Conv2d(in_channels=channel_in * 2, out_channels=channel_in, kernel_size=3, stride=1, padding=1)
244 | self.relu3 = nn.LeakyReLU()
245 |
246 | def forward(self, x):
247 | residual = x
248 | out = self.relu1(self.conv_1(x))
249 | conc = torch.cat([x, out], 1)
250 | out = self.relu2(self.conv_2(conc))
251 | conc = torch.cat([conc, out], 1)
252 | out = self.relu3(self.conv_3(conc))
253 | out = torch.add(out, residual)
254 | return out
255 |
256 |
257 | class New_block(nn.Module):
258 | def __init__(self, channel_in, reduction):
259 | super(New_block, self).__init__()
260 |
261 | # RCSSC
262 | self.unit_1 = RCSSC(int(channel_in / 2.), reduction)
263 | self.unit_2 = RCSSC(int(channel_in / 2.), reduction)
264 |
265 | self.conv1 = nn.Sequential(
266 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in / 2.), kernel_size=3, padding=1),
267 | nn.LeakyReLU()
268 | )
269 | self.conv2 = nn.Sequential(
270 | nn.Conv2d(in_channels=int(channel_in * 3 / 2.), out_channels=int(channel_in / 2.), kernel_size=3,
271 | padding=1),
272 | nn.LeakyReLU()
273 | )
274 | self.conv3 = nn.Sequential(
275 | nn.Conv2d(in_channels=channel_in * 2, out_channels=channel_in, kernel_size=1, padding=0,
276 | stride=1), # 做压缩
277 | nn.Conv2d(in_channels=channel_in, out_channels=channel_in, kernel_size=3, padding=1),
278 | nn.LeakyReLU()
279 | )
280 |
281 | def forward(self, x):
282 | residual = x
283 | c1 = self.unit_1(self.conv1(x))
284 | x = torch.cat([residual, c1], 1)
285 | c2 = self.unit_2(self.conv2(x))
286 | x = torch.cat([c2, x], 1)
287 | x = self.conv3(x)
288 | x = torch.add(x, residual)
289 | return x
290 |
291 |
292 | class ASCNet(nn.Module):
293 |
294 | def __init__(self, in_ch, out_ch, feats):
295 | super(ASCNet, self).__init__()
296 | self.features = []
297 |
298 | self.head = single_conv(in_ch, feats)
299 | self.dconv_encode0 = double_conv(feats, feats) # → har
300 |
301 | self.identety1 = nn.Conv2d(in_channels=feats, out_channels=2 * feats, kernel_size=3, stride=2, padding=1)
302 | self.DWT = DWTForward(J=1, wave='haar')
303 | self.dconv_encode1 = single_conv(4 * feats, 2 * feats)
304 |
305 | # CNCM
306 | self.enhance1 = New_block(2 * feats, reduction=16)
307 |
308 | self.identety2 = nn.Conv2d(in_channels=2 * feats, out_channels=4 * feats, kernel_size=3, stride=2, padding=1)
309 |
310 | self.dconv_encode2 = single_conv(8 * feats, 4 * feats)
311 |
312 | self.dconv_encode3 = single_conv(16 * feats, 4 * feats)
313 |
314 | self.enhance2 = New_block(4 * feats, reduction=16)
315 | self.identety3 = nn.Conv2d(in_channels=4 * feats, out_channels=4 * feats, kernel_size=3, stride=2, padding=1)
316 | self.maxpool = nn.MaxPool2d(2)
317 | self.enhance3 = New_block(4 * feats, reduction=16)
318 |
319 | self.mid1 = single_conv(8 * feats, 4 * feats)
320 | self.mid2 = single_conv(4 * feats, 4 * feats + 4 * feats)
321 |
322 | self.pixs = nn.PixelShuffle(2)
323 |
324 | # decoder*****************************************************
325 | self.upsample2 = nn.Sequential(
326 | nn.ConvTranspose2d(8 * feats, 4 * feats, kernel_size=2, stride=2),
327 | # nn.LeakyReLU(inplace=True)
328 | )
329 | self.upsample1 = nn.Sequential(
330 | nn.ConvTranspose2d(4 * feats, 2 * feats, kernel_size=2, stride=2),
331 | # nn.LeakyReLU(inplace=True)
332 | )
333 |
334 | self.upsample0 = nn.Sequential(
335 | nn.ConvTranspose2d(2 * feats, feats, kernel_size=2, stride=2),
336 | # nn.LeakyReLU(inplace=True)
337 | )
338 | self.IDWT = DWTInverse(wave='haar')
339 |
340 | # fair *******************************************************
341 | self.fair2 = nn.Conv2d(2 * feats, 4 * feats, kernel_size=3, padding=1)
342 | self.fair1 = nn.Conv2d(1 * feats, 2 * feats, kernel_size=3, padding=1)
343 | self.fair0 = nn.Conv2d(int(feats / 2), feats, kernel_size=3, padding=1)
344 |
345 | # decoder*****************************************************
346 | self.dconv_decode2 = nn.Sequential(conv11(4 * feats + 4 * feats, 4 * feats),
347 | New_block(4 * feats, reduction=16))
348 |
349 | self.dconv_decode1 = nn.Sequential(conv11(2 * feats + 2 * feats, 2 * feats),
350 | New_block(2 * feats, reduction=16))
351 |
352 | self.dconv_decode0 = double_conv(feats + feats, feats)
353 | self.tail = nn.Sequential(nn.Conv2d(feats, out_ch, 1), nn.Tanh())
354 |
355 | def make_layer(self, block, channel_in):
356 | layers = []
357 | layers.append(block(channel_in))
358 | return nn.Sequential(*layers)
359 |
360 | def _transformer(self, DMT1_yl, DMT1_yh):
361 | list_tensor = []
362 | a = DMT1_yh[0]
363 | list_tensor.append(DMT1_yl)
364 | for i in range(3):
365 | list_tensor.append(a[:, :, i, :, :])
366 | return torch.cat(list_tensor, 1)
367 |
368 | def _Itransformer(self, out):
369 | yh = []
370 | C = int(out.shape[1] / 4)
371 | yl = out[:, 0:C, :, :]
372 | y1 = out[:, C:2 * C, :, :].unsqueeze(2)
373 | y2 = out[:, 2 * C:3 * C, :, :].unsqueeze(2)
374 | y3 = out[:, 3 * C:4 * C, :, :].unsqueeze(2)
375 | final = torch.cat([y1, y2, y3], 2)
376 | yh.append(final)
377 | return yl, yh
378 |
379 | def forward(self, x):
380 | inputs = x
381 |
382 | x0 = self.dconv_encode0(self.head(x))
383 |
384 | DMT1_yl, DMT1_yh = self.DWT(x0)
385 | DMT1 = self._transformer(DMT1_yl, DMT1_yh)
386 | x = self.dconv_encode1(DMT1)
387 |
388 | res1 = self.identety1(x0)
389 | out = torch.add(x, res1)
390 |
391 | x1 = self.enhance1(out)
392 |
393 | DMT1_yl, DMT1_yh = self.DWT(x1)
394 | DMT2 = self._transformer(DMT1_yl, DMT1_yh)
395 | x = self.dconv_encode2(DMT2)
396 |
397 | res1 = self.identety2(x1)
398 | out2 = torch.add(x, res1)
399 |
400 | x2 = self.enhance2(out2)
401 |
402 | DMT1_yl, DMT1_yh = self.DWT(x2)
403 | DMT3 = self._transformer(DMT1_yl, DMT1_yh)
404 | x = self.dconv_encode3(DMT3)
405 |
406 | res1 = self.identety3(x2)
407 | out3 = torch.add(x, res1)
408 | # MI = self.mid1(out3)
409 | x3 = self.mid2(self.enhance3(out3))
410 |
411 | x = self.pixs(x3)
412 | x = self.fair2(x)
413 | x = self.dconv_decode2(torch.cat([x, x2], dim=1))
414 | x = self.pixs(x)
415 | x = self.fair1(x)
416 | x = self.dconv_decode1(torch.cat([x, x1], dim=1))
417 | x = self.pixs(x)
418 | x = self.fair0(x)
419 |
420 | x = self.dconv_decode0(torch.cat([x, x0], dim=1))
421 | x = self.tail(x)
422 | out = x + inputs
423 |
424 | return out
425 |
426 |
427 | if __name__ == '__main__':
428 | net = ASCNet(1, 1, feats=32)
429 | input = torch.zeros((1, 1, 256, 256), dtype=torch.float32)
430 | output = net(input)
431 |
432 | flops, params = profile(net, (input,))
433 | print("-" * 50)
434 | print('FLOPs = ' + str(flops / 1000 ** 3) + ' G')
435 | print('Params = ' + str(params / 1000 ** 2) + ' M')
436 | print(output.shape)
437 |
--------------------------------------------------------------------------------
/model/cbam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class BasicConv(nn.Module):
7 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
8 | super(BasicConv, self).__init__()
9 | self.out_channels = out_planes
10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
11 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
12 | self.relu = nn.ReLU() if relu else None
13 |
14 | def forward(self, x):
15 | x = self.conv(x)
16 | if self.bn is not None:
17 | x = self.bn(x)
18 | if self.relu is not None:
19 | x = self.relu(x)
20 | return x
21 |
22 | class Flatten(nn.Module):
23 | def forward(self, x):
24 | return x.view(x.size(0), -1)
25 |
26 | class ChannelGate(nn.Module):
27 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
28 | super(ChannelGate, self).__init__()
29 | self.gate_channels = gate_channels
30 | self.mlp = nn.Sequential(
31 | Flatten(),
32 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
33 | nn.ReLU(),
34 | nn.Linear(gate_channels // reduction_ratio, gate_channels)
35 | )
36 | self.pool_types = pool_types
37 | def forward(self, x):
38 | channel_att_sum = None
39 | for pool_type in self.pool_types:
40 | if pool_type=='avg':
41 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
42 | channel_att_raw = self.mlp( avg_pool )
43 | elif pool_type=='max':
44 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
45 | channel_att_raw = self.mlp( max_pool )
46 | elif pool_type=='lp':
47 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
48 | channel_att_raw = self.mlp( lp_pool )
49 | elif pool_type=='lse':
50 | # LSE pool only
51 | lse_pool = logsumexp_2d(x)
52 | channel_att_raw = self.mlp( lse_pool )
53 |
54 | if channel_att_sum is None:
55 | channel_att_sum = channel_att_raw
56 | else:
57 | channel_att_sum = channel_att_sum + channel_att_raw
58 |
59 | scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
60 | return x * scale
61 |
62 | def logsumexp_2d(tensor):
63 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
64 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
65 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
66 | return outputs
67 |
68 | class ChannelPool(nn.Module):
69 | def forward(self, x):
70 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
71 |
72 | class SpatialGate(nn.Module):
73 | def __init__(self):
74 | super(SpatialGate, self).__init__()
75 | kernel_size = 7
76 | self.compress = ChannelPool()
77 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
78 | def forward(self, x):
79 | x_compress = self.compress(x)
80 | x_out = self.spatial(x_compress)
81 | scale = F.sigmoid(x_out) # broadcasting
82 | return x * scale
83 |
84 | class CBAM(nn.Module):
85 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
86 | super(CBAM, self).__init__()
87 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
88 | self.no_spatial=no_spatial
89 | if not no_spatial:
90 | self.SpatialGate = SpatialGate()
91 | def forward(self, x):
92 | x_out = self.ChannelGate(x)
93 | if not self.no_spatial:
94 | x_out = self.SpatialGate(x_out)
95 | return x_out
96 |
--------------------------------------------------------------------------------
/prepare_patches.py:
--------------------------------------------------------------------------------
1 | """
2 | 生成 .h5 数据文件
3 | """
4 | import argparse
5 | from dataset import prepare_data
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser(description="Building the training patch database")
9 | parser.add_argument("--gray", default=True, action='store_true', help='prepare grayscale database instead of RGB')
10 | # Preprocessing parameters
11 | parser.add_argument("--patch_size", "--p", type=int, default=64, help="Patch size")
12 | parser.add_argument("--stride", "--s", type=int, default=40, help="Size of stride")
13 | parser.add_argument("--max_number_patches", "--m", type=int, default=180,
14 | # parser.add_argument("--max_number_patches", "--m", type=int, default=18,
15 | help="Maximum number of patches")
16 | parser.add_argument("--aug_times", "--a", type=int, default=2,
17 | help="How many times to perform data augmentation")
18 | # Dirs
19 | parser.add_argument("--trainset_dir", type=str, default=r"D:\SCI\03-SCI\dataset\All", help='path of trainset')
20 | parser.add_argument("--valset_dir", type=str, default=r"D:\SCI\03-SCI\dataset\All_16\image",
21 | help='path of validation set')
22 | parser.add_argument("--valset_dirty_dir", type=str, default=r"D:\SCI\03-SCI\dataset\All_16\nosie",
23 | help='path of validation set')
24 | args = parser.parse_args()
25 |
26 |
27 | print("\n### Building databases ###")
28 | print("> Parameters:")
29 | for p, v in zip(args.__dict__.keys(), args.__dict__.values()):
30 | print('\t{}: {}'.format(p, v))
31 | print('\n')
32 |
33 | prepare_data(args.trainset_dir, args.valset_dir, args.valset_dirty_dir, args.patch_size, args.stride, args.max_number_patches,
34 | aug_times=args.aug_times, gray_mode=args.gray)
35 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import torch.nn as nn
5 | import argparse
6 | from model.ASCNet import ASCNet
7 | import time
8 | from utils import *
9 | import numpy as np
10 | import torch
11 | import pywt
12 | import torch.nn as nn
13 | import lpips
14 |
15 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
16 |
17 | parser = argparse.ArgumentParser(description="Demo")
18 | parser.add_argument("--log_path", type=str,
19 | default=r"XXXXXX")
20 | parser.add_argument("--filename", type=str, default=r"XXXX")
21 | parser.add_argument("--save", type=bool, default=False)
22 | opt = parser.parse_args()
23 |
24 |
25 | def normalization(data):
26 | _range = np.max(data) - np.min(data)
27 | return (data - np.min(data)) / _range
28 |
29 |
30 | ssim = SSIM()
31 | loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
32 | loss_fn_vgg = lpips.LPIPS(net='vgg')
33 |
34 | cleanfilename = os.path.join(opt.filename, 'image')
35 | clclist = ['Gauss', 'Uniform', 'Cycle']
36 |
37 |
38 |
39 | model = ASCNet(1, 1, feats=32)
40 | # model = nn.DataParallel(model).cuda()
41 | model = nn.DataParallel(model)
42 | model.load_state_dict(torch.load(opt.log_path, map_location='cpu'))
43 |
44 |
45 | psnr_sum = 0
46 | ssim_sum = 0
47 | lpips_sum = 0
48 |
49 |
50 | for clc in clclist:
51 | savepth = os.path.join(opt.filename, 'ASCNet', clc)
52 | mk = savepth + '\\'
53 | noisepth = os.path.join(opt.filename, 'noise', clc)
54 | namelist = os.listdir(cleanfilename)
55 | namelist.sort()
56 | model.eval()
57 | with torch.no_grad():
58 | for name in namelist:
59 | # read noise image and process it
60 | image = cv2.imread(os.path.join(noisepth, name))
61 | img_np = np.expand_dims(image[:, :, 0], 0)
62 | img_np = np.float32(img_np / 255.)
63 | img_tensor = torch.from_numpy(img_np)
64 | img_tensor = torch.unsqueeze(img_tensor, 0)
65 | # out, outstripe = model(img_tensor)
66 | out = model(img_tensor)
67 | out_val = torch.clip(out, 0., 1.)
68 |
69 | # read clean image
70 | image2 = cv2.imread(os.path.join(cleanfilename, name))
71 | img_np2 = np.expand_dims(image2[:, :, 0], 0)
72 | img_np2562 = np.float32(img_np2 / 255.)
73 | img_clean = torch.from_numpy(img_np2562)
74 | # img_clean = torch.unsqueeze(img_clean, 0).cuda()
75 | img_clean = torch.unsqueeze(img_clean, 0)
76 |
77 | # calculate PSNR and SSIM
78 | psnr_val = batch_psnr(out_val, img_clean, 1.)
79 | ssim_val = ssim(img_clean, out_val)
80 | lpips_val = loss_fn_alex(out_val, img_clean)
81 |
82 | psnr_sum = psnr_sum + psnr_val
83 | ssim_sum = ssim_sum + ssim_val
84 | lpips_sum = lpips_sum + lpips_val
85 | # print(name)
86 | if opt.save:
87 | if os.path.exists(mk):
88 | pass
89 | else:
90 | os.makedirs(mk)
91 |
92 |
93 | out_np = out.data.cpu().numpy()
94 | # out_np = out.data.numpy()
95 | out_val = out_np[0, :, :, :]
96 | out_val = np.transpose(out_val, (1, 2, 0))
97 | out_val = out_val * 255
98 | out_valf = np.clip(out_val, 0, 255)
99 | savepth = os.path.join(opt.filename, 'ASCNet', clc)
100 | # final=normalization(out_val)
101 | # out_valf = final * 255
102 | cv2.imwrite(os.path.join(savepth, name), out_valf.astype("uint8"))
103 |
104 | psnr_val = psnr_sum / len(namelist)
105 | ssim_val = ssim_sum / len(namelist)
106 | lpips_val = lpips_sum / len(namelist)
107 |
108 | print("*" * 10 + clc + "*" * 10)
109 | print("PSNR_sum: %.4f" % psnr_val)
110 | print("SSIM_sum: %.4f" % ssim_val)
111 | print("LPIPS_sum: %.4f" % lpips_val)
112 |
113 | psnr_sum = 0
114 | ssim_sum = 0
115 | lpips_sum = 0
116 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | 修改网络
3 | GPU编号
4 | logs 修改名称并 远程新建文件 372
5 | pth名称 345
6 | 上传文件
7 |
8 | *************************************
9 |
10 | '''
11 |
12 | import warnings
13 | import os
14 | import argparse
15 | import cv2
16 | import numpy as np
17 | import torch
18 | import torch.nn as nn
19 | import torch.optim as optim
20 | from torch.autograd import Variable
21 | from torch.utils.data import DataLoader
22 | import torchvision.utils as utils
23 | from torch.utils.tensorboard import SummaryWriter
24 | # from models import FFDNet
25 | from dataset import Dataset
26 | from model.ASCNet import ASCNet
27 | from utils import *
28 | from warmup_scheduler import GradualWarmupScheduler
29 | from torchvision import transforms
30 | import matplotlib.pyplot as plt
31 | import random
32 |
33 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
34 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
35 |
36 |
37 | # warnings.filterwarnings('ignore')
38 |
39 |
40 | def main(args):
41 | r"""Performs the main training loop
42 | """
43 | # Load dataset
44 | print('> Loading dataset ...') # 训练和验证都是读的h5文件
45 | dataset_train = Dataset(train=True, gray_mode=args.gray, shuffle=True)
46 | dataset_val = Dataset(train=False, gray_mode=args.gray, shuffle=False)
47 | # 训练数据走的DataLoder 验证数据没有走
48 | loader_train = DataLoader(dataset=dataset_train, num_workers=4, batch_size=args.batch_size, shuffle=True)
49 | print("\t# of training samples: %d\n" % int(len(dataset_train)))
50 |
51 | # Init loggers
52 | if not os.path.exists(args.log_dir):
53 | os.makedirs(args.log_dir)
54 | writer = SummaryWriter(args.log_dir)
55 | # **********************************************************************************************
56 | # build model
57 | # **********************************************************************************************
58 | net = ASCNet(1, 1, feats=16)
59 | # Define loss
60 | criterion = nn.MSELoss().cuda()
61 | ssim = SSIM().cuda()
62 |
63 | # Move to GPU
64 | device_ids = [0]
65 | model = nn.DataParallel(net, device_ids=device_ids).cuda()
66 |
67 | # Optimizer
68 | # optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.9999), eps=1e-8)
69 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.9999))
70 | warmup_epochs = 4
71 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs - warmup_epochs,
72 | eta_min=1e-6)
73 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
74 | after_scheduler=scheduler_cosine)
75 | scheduler.step()
76 | # noise case
77 |
78 | # case = 3
79 |
80 | start_epoch = 0
81 | training_params = {}
82 | training_params['step'] = 0
83 | training_params['no_orthog'] = args.no_orthog
84 |
85 | # Training
86 | for epoch in range(start_epoch, args.epochs):
87 | print("==============ASCNet==============", epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
88 | psnr_sum = 0
89 | psnr_val = 0
90 | ssim_sum = 0
91 | ssim_val = 0
92 |
93 | # train
94 | for i, data in enumerate(loader_train, 0):
95 | # case = random.randint(0, 3)
96 | case = 3
97 | # print(case)
98 | # Pre-training step
99 | model.train()
100 | model.zero_grad()
101 | optimizer.zero_grad()
102 | img_train = data
103 | # add noise
104 | imgn_train = add_noise(img_train, case, args.noiseIntL)
105 | # imgn_train = add_noise2(img_train, case, args.noiseIntL, args.noiseIntS)
106 | # Create input Variables
107 | img_train = Variable(img_train.cuda())
108 | imgn_train = Variable(imgn_train.cuda())
109 |
110 | # Evaluate model and optimize it
111 | out_train = model(imgn_train)
112 | # out_train = torch.clamp(model(imgn_train), 0., 1.)
113 | # torch.clamp(model(imgn_val), 0., 1.)
114 | # *************************************************************************************************************************
115 | # loss
116 | # *************************************************************************************************************************
117 | loss1 = criterion(out_train, img_train)
118 | # loss2 = l1(out_train, img_train)
119 | # loss3 = dre(out_train, img_train)
120 | # loss4 = tv(out_train - img_train)
121 | # loss5 = drestr(out_train, img_train)
122 | loss = loss1
123 | loss.backward()
124 | optimizer.step()
125 |
126 | if training_params['step'] % args.save_every == 0:
127 | # Apply regularization by orthogonalizing filters
128 | # Results
129 | model.eval()
130 | out_train = torch.clip(out_train, 0., 1.)
131 | psnr_train = batch_psnr(out_train, img_train, 1.)
132 | ssim_train = ssim(img_train, out_train)
133 | if not training_params['no_orthog']:
134 | model.apply(svd_orthogonalization)
135 |
136 | # Log the scalar values
137 | writer.add_scalar('loss', loss.item(), training_params['step'])
138 | writer.add_scalar('PSNR on training data', psnr_train, \
139 | training_params['step'])
140 | writer.add_scalar('SSIM on training data', ssim_train, \
141 | training_params['step'])
142 | print("[epoch %d][%d/%d] loss: %.6f PSNR_train: %.4f" % \
143 | (epoch + 1, i + 1, len(loader_train), loss.item(), psnr_train))
144 | training_params['step'] += 1
145 | scheduler.step()
146 | # The end of each epoch
147 |
148 | if epoch % 1 == 0:
149 | model.eval()
150 | with torch.no_grad():
151 | # Validation
152 | for dataclean, datadirty in dataset_val:
153 | datadirty_val = torch.unsqueeze(datadirty, 0)
154 | dataclean_val = torch.unsqueeze(dataclean, 0)
155 | datadirty_val, dataclean_val = Variable(datadirty_val.cuda()), Variable(dataclean_val.cuda())
156 | out_val = torch.clip(model(datadirty_val), 0., 1.)
157 | psnr_val = batch_psnr(out_val, dataclean_val, 1.)
158 | psnr_sum = psnr_sum + psnr_val
159 | ssim_val = ssim(dataclean_val, out_val)
160 | ssim_sum = ssim_sum + ssim_val.item()
161 | psnr_val = psnr_sum / len(dataset_val)
162 | ssim_val = ssim_sum / len(dataset_val)
163 | print("\n[epoch %d] PSNR_val: %.4f SSIM_val: %.6f" % (epoch + 1, psnr_val, ssim_val))
164 | writer.add_scalar('PSNR on validation data', psnr_val, training_params['step'])
165 | writer.add_scalar('SSIM on validation data', ssim_val, training_params['step'])
166 | writer.add_scalar('Learning rate', scheduler.get_lr()[0], training_params['step'])
167 |
168 | if epoch == 0:
169 | best_psnr = psnr_val
170 | best_ssim = ssim_val
171 |
172 | print("[epoch %d][%d/%d] psnr_avg: %.4f, ssim_avg: %.4f, best_psnr: %.4f, best_ssim: %.6f" %
173 | (epoch + 1, i + 1, len(dataset_val), psnr_val, ssim_val, best_psnr, best_ssim))
174 |
175 | if psnr_val >= best_psnr:
176 | best_psnr = psnr_val
177 | best_ssim = ssim_val
178 | print('--- save the model @ ep--{} PSNR--{} SSIM--{}'.format(epoch, best_psnr, best_ssim))
179 | best_psnr_s = format(best_psnr,'.4f')
180 | best_ssim_s = format(best_ssim,'.6f')
181 | s = "best_" + "ASCNet"+"_" + str(best_psnr_s) + "_" + str(best_ssim_s) + ".pth"
182 | torch.save(model.state_dict(), os.path.join(args.log_dir, s))
183 |
184 | training_params['start_epoch'] = epoch + 1
185 |
186 |
187 | if __name__ == "__main__":
188 |
189 | parser = argparse.ArgumentParser(description="ASCNet")
190 | # ********************************************************************************************************************************
191 | parser.add_argument("--log_dir", type=str, default="otherlogs/ASCNet", help='path of log files')
192 | parser.add_argument("--batch_size", type=int, default=128, help="Training batch size")
193 | parser.add_argument("--epochs", "--e", type=int, default=101, help="Number of total training epochs")
194 | parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
195 | parser.add_argument("--noiseIntL", nargs=2, type=int, default=[0.05, 0.15], help="Noise training interval")
196 | # parser.add_argument("--noiseIntS", nargs=2, type=int, default=[0, 0.25], help="Noise training interval")
197 | parser.add_argument("--seed", type=int, default=42, help="Threshold for test")
198 | parser.add_argument("--gray", default=True, action='store_true',
199 | help='train grayscale image denoising instead of RGB')
200 | parser.add_argument("--no_orthog", action='store_true', help="Don't perform orthogonalization as regularization")
201 | parser.add_argument("--save_every", type=int, default=100,
202 | help="Number of training steps to log psnr and perform orthogonalization")
203 | argspar = parser.parse_args()
204 |
205 | print("\n#########################################\n"
206 | " ASCNet "
207 | "\n#########################################\n")
208 | print("> Parameters:")
209 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()):
210 | print('\t{}: {}'.format(p, v))
211 | print('\n')
212 |
213 | seed_pytorch(argspar.seed)
214 |
215 | main(argspar)
216 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Different utilities such as orthogonalization of weights, initialization of
3 | loggers, etc
4 |
5 | Copyright (C) 2018, Matias Tassano
6 |
7 | This program is free software: you can use, modify and/or
8 | redistribute it under the terms of the GNU General Public
9 | License as published by the Free Software Foundation, either
10 | version 3 of the License, or (at your option) any later
11 | version. You should have received a copy of this license along
12 | this program. If not, see .
13 | """
14 | import subprocess
15 | import math
16 | import logging
17 | import numpy as np
18 | import cv2
19 | import torch
20 | import torch.nn as nn
21 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
22 | from math import exp
23 | from torch.autograd import Variable
24 | from torch.nn import functional as F
25 | from PIL import Image
26 | import random
27 |
28 |
29 | from torchvision import transforms
30 | import matplotlib.pyplot as plt
31 |
32 | def seed_pytorch(seed=42):
33 | random.seed(seed)
34 | os.environ['PYTHONHASHSEED'] = str(seed)
35 | np.random.seed(seed)
36 | torch.manual_seed(seed)
37 | torch.cuda.manual_seed(seed)
38 | torch.cuda.manual_seed_all(seed)
39 |
40 |
41 | def weights_init_kaiming(lyr):
42 | r"""Initializes weights of the model according to the "He" initialization
43 | method described in "Delving deep into rectifiers: Surpassing human-level
44 | performance on ImageNet classification" - He, K. et al. (2015), using a
45 | normal distribution.
46 | This function is to be called by the torch.nn.Module.apply() method,
47 | which applies weights_init_kaiming() to every layer of the model.
48 | """
49 | classname = lyr.__class__.__name__
50 | if classname.find('Conv') != -1:
51 | nn.init.kaiming_normal(lyr.weight.data, a=0, mode='fan_in')
52 | elif classname.find('Linear') != -1:
53 | nn.init.kaiming_normal(lyr.weight.data, a=0, mode='fan_in')
54 | elif classname.find('BatchNorm') != -1:
55 | lyr.weight.data.normal_(mean=0, std=math.sqrt(2. / 9. / 64.)). \
56 | clamp_(-0.025, 0.025)
57 | nn.init.constant(lyr.bias.data, 0.0)
58 |
59 |
60 | def ssim(img1, img2, window_size=11, size_average=True):
61 | (_, channel, _, _) = img1.size()
62 | window = create_window(window_size, channel)
63 |
64 | if img1.is_cuda:
65 | window = window.cuda(img1.get_device())
66 | window = window.type_as(img1)
67 |
68 | return _ssim(img1, img2, window, window_size, channel, size_average)
69 |
70 |
71 | def batch_psnr(img, imclean, data_range):
72 | r"""
73 | Computes the PSNR along the batch dimension (not pixel-wise)
74 |
75 | Args:
76 | img: a `torch.Tensor` containing the restored image
77 | imclean: a `torch.Tensor` containing the reference image
78 | data_range: The data range of the input image (distance between
79 | minimum and maximum possible values). By default, this is estimated
80 | from the image data-type.
81 | """
82 | img_cpu = img.data.cpu().numpy().astype(np.float32)
83 | imgclean = imclean.data.cpu().numpy().astype(np.float32)
84 | psnr = 0
85 | for i in range(img_cpu.shape[0]):
86 | psnr += compare_psnr(imgclean[i, :, :, :], img_cpu[i, :, :, :], \
87 | data_range=data_range)
88 | return psnr / img_cpu.shape[0]
89 |
90 |
91 | # def batch_ssim(img, imclean):
92 | #
93 | # img_cpu = img.data.cpu().numpy().astype(np.float32)
94 | # imgclean = imclean.data.cpu().numpy().astype(np.float32)
95 | # ssimall = 0
96 | # for i in range(img_cpu.shape[0]):
97 | # ssimall += ssim(img_cpu[i, :, :, :],imgclean[i, :, :, :])
98 | # return ssimall / img_cpu.shape[0]
99 |
100 | def data_augmentation(image, mode):
101 | r"""Performs dat augmentation of the input image
102 |
103 | Args:
104 | image: a cv2 (OpenCV) image
105 | mode: int. Choice of transformation to apply to the image
106 | 0 - no transformation
107 | 1 - flip up and down
108 | 2 - rotate counterwise 90 degree
109 | 3 - rotate 90 degree and flip up and down
110 | 4 - rotate 180 degree
111 | 5 - rotate 180 degree and flip
112 | 6 - rotate 270 degree
113 | 7 - rotate 270 degree and flip
114 | """
115 | out = np.transpose(image, (1, 2, 0))
116 | if mode == 0:
117 | # original
118 | out = out
119 | elif mode == 1:
120 | # flip up and down
121 | out = np.flipud(out)
122 | elif mode == 2:
123 | # rotate counterwise 90 degree
124 | out = np.rot90(out)
125 | elif mode == 3:
126 | # rotate 90 degree and flip up and down
127 | out = np.rot90(out)
128 | out = np.flipud(out)
129 | elif mode == 4:
130 | # rotate 180 degree
131 | out = np.rot90(out, k=2)
132 | elif mode == 5:
133 | # rotate 180 degree and flip
134 | out = np.rot90(out, k=2)
135 | out = np.flipud(out)
136 | elif mode == 6:
137 | # rotate 270 degree
138 | out = np.rot90(out, k=3)
139 | elif mode == 7:
140 | # rotate 270 degree and flip
141 | out = np.rot90(out, k=3)
142 | out = np.flipud(out)
143 | else:
144 | raise Exception('Invalid choice of image transformation')
145 | return np.transpose(out, (2, 0, 1))
146 |
147 |
148 | def variable_to_cv2_image(varim):
149 | r"""Converts a torch.autograd.Variable to an OpenCV image
150 |
151 | Args:
152 | varim: a torch.autograd.Variable
153 | """
154 | nchannels = varim.size()[1]
155 | if nchannels == 1:
156 | res = (varim.data.cpu().numpy()[0, 0, :] * 255.).clip(0, 255).astype(np.uint8)
157 | elif nchannels == 3:
158 | res = varim.data.cpu().numpy()[0]
159 | res = cv2.cvtColor(res.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
160 | res = (res * 255.).clip(0, 255).astype(np.uint8)
161 | else:
162 | raise Exception('Number of color channels not supported')
163 | return res
164 |
165 |
166 | def get_git_revision_short_hash():
167 | r"""Returns the current Git commit.
168 | """
169 | return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).strip()
170 |
171 |
172 | def init_logger(argdict):
173 | r"""Initializes a logging.Logger to save all the running parameters to a
174 | log file
175 |
176 | Args:
177 | argdict: dictionary of parameters to be logged
178 | """
179 | from os.path import join
180 |
181 | logger = logging.getLogger(__name__)
182 | logger.setLevel(level=logging.INFO)
183 | fh = logging.FileHandler(join(argdict.log_dir, 'log.txt'), mode='a')
184 | formatter = logging.Formatter('%(asctime)s - %(message)s')
185 | fh.setFormatter(formatter)
186 | logger.addHandler(fh)
187 | try:
188 | logger.info("Commit: {}".format(get_git_revision_short_hash()))
189 | except Exception as e:
190 | logger.error("Couldn't get commit number: {}".format(e))
191 | logger.info("Arguments: ")
192 | for k in argdict.__dict__:
193 | logger.info("\t{}: {}".format(k, argdict.__dict__[k]))
194 |
195 | return logger
196 |
197 |
198 | def init_logger_ipol():
199 | r"""Initializes a logging.Logger in order to log the results after
200 | testing a model
201 |
202 | Args:
203 | result_dir: path to the folder with the denoising results
204 | """
205 | logger = logging.getLogger('testlog')
206 | logger.setLevel(level=logging.INFO)
207 | fh = logging.FileHandler('out.txt', mode='w')
208 | formatter = logging.Formatter('%(message)s')
209 | fh.setFormatter(formatter)
210 | logger.addHandler(fh)
211 |
212 | return logger
213 |
214 |
215 | def init_logger_test(result_dir):
216 | r"""Initializes a logging.Logger in order to log the results after testing
217 | a model
218 |
219 | Args:
220 | result_dir: path to the folder with the denoising results
221 | """
222 | from os.path import join
223 |
224 | logger = logging.getLogger('testlog')
225 | logger.setLevel(level=logging.INFO)
226 | fh = logging.FileHandler(join(result_dir, 'log.txt'), mode='a')
227 | formatter = logging.Formatter('%(asctime)s - %(message)s')
228 | fh.setFormatter(formatter)
229 | logger.addHandler(fh)
230 |
231 | return logger
232 |
233 |
234 | def normalize(data):
235 | r"""Normalizes a unit8 image to a float32 image in the range [0, 1]
236 |
237 | Args:
238 | data: a unint8 numpy array to normalize from [0, 255] to [0, 1]
239 | """
240 | return np.float32(data / 255.)
241 |
242 |
243 | def svd_orthogonalization(lyr):
244 | r"""Applies regularization to the training by performing the
245 | orthogonalization technique described in the paper "FFDNet: Toward a fast
246 | and flexible solution for CNN based image denoising." Zhang et al. (2017).
247 | For each Conv layer in the model, the method replaces the matrix whose columns
248 | are the filters of the layer by new filters which are orthogonal to each other.
249 | This is achieved by setting the singular values of a SVD decomposition to 1.
250 |
251 | This function is to be called by the torch.nn.Module.apply() method,
252 | which applies svd_orthogonalization() to every layer of the model.
253 | """
254 | classname = lyr.__class__.__name__
255 | if classname.find('Conv') != -1:
256 | weights = lyr.weight.data.clone()
257 | c_out, c_in, f1, f2 = weights.size()
258 | dtype = lyr.weight.data.type()
259 |
260 | # Reshape filters to columns
261 | # From (c_out, c_in, f1, f2) to (f1*f2*c_in, c_out)
262 | weights = weights.permute(2, 3, 1, 0).contiguous().view(f1 * f2 * c_in, c_out)
263 |
264 | # Convert filter matrix to numpy array
265 | weights = weights.cpu().numpy()
266 |
267 | # SVD decomposition and orthogonalization
268 | mat_u, _, mat_vh = np.linalg.svd(weights, full_matrices=False)
269 | weights = np.dot(mat_u, mat_vh)
270 |
271 | # As full_matrices=False we don't need to set s[:] = 1 and do mat_u*s
272 | lyr.weight.data = torch.Tensor(weights).view(f1, f2, c_in, c_out). \
273 | permute(3, 2, 0, 1).type(dtype)
274 | else:
275 | pass
276 |
277 |
278 | def remove_dataparallel_wrapper(state_dict):
279 | r"""Converts a DataParallel model to a normal one by removing the "module."
280 | wrapper in the module dictionary
281 |
282 | Args:
283 | state_dict: a torch.nn.DataParallel state dictionary
284 | """
285 | from collections import OrderedDict
286 |
287 | new_state_dict = OrderedDict()
288 | for k, vl in state_dict.items():
289 | name = k[7:] # remove 'module.' of DataParallel
290 | new_state_dict[name] = vl
291 |
292 | return new_state_dict
293 |
294 |
295 | def is_rgb(im_path):
296 | r""" Returns True if the image in im_path is an RGB image
297 | """
298 | from skimage.io import imread
299 | rgb = False
300 | im = imread(im_path)
301 | if (len(im.shape) == 3):
302 | if not (np.allclose(im[..., 0], im[..., 1]) and np.allclose(im[..., 2], im[..., 1])):
303 | rgb = True
304 | print("rgb: {}".format(rgb))
305 | print("im shape: {}".format(im.shape))
306 | return rgb
307 |
308 |
309 | def add_noise(img_train, case, noiseIntL):
310 | noise_S = torch.zeros(img_train.size())
311 | if case == 3:
312 | beta1 = np.random.uniform(noiseIntL[0], noiseIntL[1], size=noise_S.size()[0])
313 | beta2 = np.random.uniform(noiseIntL[0], noiseIntL[1], size=noise_S.size()[0])
314 | beta3 = np.random.uniform(noiseIntL[0], noiseIntL[1], size=noise_S.size()[0])
315 | beta4 = np.random.uniform(noiseIntL[0], noiseIntL[1], size=noise_S.size()[0])
316 |
317 | for m in range(noise_S.size()[0]):
318 | sizeN_S = noise_S[0, 0, :, :].size()
319 | A1 = np.random.normal(0, beta1[m], sizeN_S[1]) # 一行向量
320 | A2 = np.random.normal(0, beta2[m], sizeN_S[1]) # 一行向量
321 | A3 = np.random.normal(0, beta3[m], sizeN_S[1]) # 一行向量
322 | A4 = np.random.normal(0, beta4[m], sizeN_S[1]) # 一行向量
323 | # 拉伸
324 | A1 = np.tile(A1, (sizeN_S[0], 1))
325 | A2 = np.tile(A2, (sizeN_S[0], 1))
326 | A3 = np.tile(A3, (sizeN_S[0], 1))
327 | A4 = np.tile(A4, (sizeN_S[0], 1))
328 | # add dim
329 | A1 = np.expand_dims(A1, 0)
330 | A2 = np.expand_dims(A2, 0)
331 | A3 = np.expand_dims(A3, 0)
332 | A4 = np.expand_dims(A4, 0)
333 | # to tensor
334 | A1 = torch.from_numpy(A1)
335 | A2 = torch.from_numpy(A2)
336 | A3 = torch.from_numpy(A3)
337 | A4 = torch.from_numpy(A4)
338 | imgn_train_m = A1 + A2 * img_train[m] + A3 * A3 * img_train[m] + A4 * A4 * A4 * img_train[m] + \
339 | img_train[m]
340 | imgn_train_m_c = torch.clip(imgn_train_m, 0., 1.)
341 | noise_S[m, :, :, :] = imgn_train_m_c
342 | imgn_train = noise_S
343 | return imgn_train
344 |
345 |
346 |
347 |
348 | def gaussian(window_size, sigma):
349 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
350 | return gauss / torch.sum(gauss) # 归一化
351 |
352 |
353 | # x=gaussian(3,1.5)
354 | # # print(x)
355 | # x=x.unsqueeze(1)
356 | # print(x.shape) #torch.Size([3,1])
357 | # print(x.t().unsqueeze(0).unsqueeze(0).shape) # torch.Size([1,1,1, 3])
358 |
359 |
360 | def create_window(window_size, channel=1):
361 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) # window_size,1
362 | # mm:矩阵乘法 t:转置矩阵 ->1,1,window_size,_window_size
363 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
364 | # expand:扩大张量的尺寸,比如3,1->3,4则意味将输入张量的列复制四份,
365 | # 1,1,window_size,_window_size->channel,1,window_size,_window_size
366 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
367 | return window
368 |
369 |
370 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
371 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
372 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
373 |
374 | mu1_sq = mu1.pow(2)
375 | mu2_sq = mu2.pow(2)
376 | mu1_mu2 = mu1 * mu2
377 |
378 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
379 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
380 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
381 |
382 | C1 = 0.01 ** 2
383 | C2 = 0.03 ** 2
384 |
385 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
386 |
387 | if size_average:
388 | return ssim_map.mean()
389 | else:
390 | return ssim_map.mean(1).mean(1).mean(1)
391 |
392 |
393 | # 构造损失函数用于网络训练或者普通计算SSIM值
394 | class SSIM(torch.nn.Module):
395 | def __init__(self, window_size=11, size_average=True):
396 | super(SSIM, self).__init__()
397 | self.window_size = window_size
398 | self.size_average = size_average
399 | self.channel = 1
400 | self.window = create_window(window_size, self.channel)
401 |
402 | def forward(self, img1, img2):
403 | (_, channel, _, _) = img1.size()
404 |
405 | if channel == self.channel and self.window.data.type() == img1.data.type():
406 | window = self.window
407 | else:
408 | window = create_window(self.window_size, channel)
409 |
410 | if img1.is_cuda:
411 | window = window.cuda(img1.get_device())
412 | window = window.type_as(img1)
413 |
414 | self.window = window
415 | self.channel = channel
416 |
417 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
418 |
419 |
420 | def weights_init_kaiming(lyr):
421 | r"""Initializes weights of the model according to the "He" initialization
422 | method described in "Delving deep into rectifiers: Surpassing human-level
423 | performance on ImageNet classification" - He, K. et al. (2015), using a
424 | normal distribution.
425 | This function is to be called by the torch.nn.Module.apply() method,
426 | which applies weights_init_kaiming() to every layer of the model.
427 | """
428 | classname = lyr.__class__.__name__
429 | if classname.find('Conv') != -1:
430 | nn.init.kaiming_normal(lyr.weight.data, a=0, mode='fan_in')
431 | elif classname.find('Linear') != -1:
432 | nn.init.kaiming_normal(lyr.weight.data, a=0, mode='fan_in')
433 | elif classname.find('BatchNorm') != -1:
434 | lyr.weight.data.normal_(mean=0, std=math.sqrt(2. / 9. / 64.)). \
435 | clamp_(-0.025, 0.025)
436 | nn.init.constant(lyr.bias.data, 0.0)
437 |
438 |
439 | def findLastCheckpoint(save_dir):
440 | file_list = glob.glob(os.path.join(save_dir, '*epoch*.pth'))
441 | if file_list:
442 | epochs_exist = []
443 | for file_ in file_list:
444 | result = re.findall(".*epoch(.*).pth.*", file_)
445 | epochs_exist.append(int(result[0]))
446 | initial_epoch = max(epochs_exist)
447 | else:
448 | initial_epoch = 0
449 | return initial_epoch
450 |
451 |
452 | def batch_PSNR(img, imclean, data_range):
453 | Img = img.data.cpu().numpy().astype(np.float32)
454 | Iclean = imclean.data.cpu().numpy().astype(np.float32)
455 | PSNR = 0
456 | for i in range(Img.shape[0]):
457 | PSNR += peak_signal_noise_ratio(Iclean[i, :, :, :], Img[i, :, :, :], data_range=data_range)
458 | return (PSNR / Img.shape[0])
459 |
460 |
461 | def normalize(data):
462 | return data / 255.
463 |
464 |
465 | def is_image(img_name):
466 | if img_name.endswith(".jpg") or img_name.endswith(".bmp") or img_name.endswith(".png"):
467 | return True
468 | else:
469 | return False
470 |
471 |
472 | def print_network(net):
473 | num_params = 0
474 | for param in net.parameters():
475 | num_params += param.numel()
476 | print(net)
477 | print('Total number of parameters: %d' % num_params)
478 |
479 |
480 | class ImagePool():
481 | """This class implements an image buffer that stores previously generated images.
482 |
483 | This buffer enables us to update discriminators using a history of generated images
484 | rather than the ones produced by the latest generators.
485 | """
486 |
487 | def __init__(self, pool_size):
488 | """Initialize the ImagePool class
489 |
490 | Parameters:
491 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
492 | """
493 | self.pool_size = pool_size
494 | if self.pool_size > 0: # create an empty pool
495 | self.num_imgs = 0
496 | self.images = []
497 |
498 | def query(self, images):
499 | """Return an image from the pool.
500 |
501 | Parameters:
502 | images: the latest generated images from the generator
503 |
504 | Returns images from the buffer.
505 |
506 | By 50/100, the buffer will return input images.
507 | By 50/100, the buffer will return images previously stored in the buffer,
508 | and insert the current images to the buffer.
509 | """
510 | if self.pool_size == 0: # if the buffer size is 0, do nothing
511 | return images
512 | return_images = []
513 | for image in images:
514 | image = torch.unsqueeze(image.data, 0)
515 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
516 | self.num_imgs = self.num_imgs + 1
517 | self.images.append(image)
518 | return_images.append(image)
519 | else:
520 | p = random.uniform(0, 1)
521 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
522 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
523 | tmp = self.images[random_id].clone()
524 | self.images[random_id] = image
525 | return_images.append(tmp)
526 | else: # by another 50% chance, the buffer will return the current image
527 | return_images.append(image)
528 | return_images = torch.cat(return_images, 0) # collect all the images and return
529 | return return_images
530 |
531 |
532 | class GANLoss(nn.Module):
533 | """Define different GAN objectives.
534 |
535 | The GANLoss class abstracts away the need to create the target label tensor
536 | that has the same size as the input.
537 | """
538 |
539 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
540 | """ Initialize the GANLoss class.
541 |
542 | Parameters:
543 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
544 | target_real_label (bool) - - label for a real image
545 | target_fake_label (bool) - - label of a fake image
546 |
547 | Note: Do not use sigmoid as the last layer of Discriminator.
548 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
549 | """
550 | super(GANLoss, self).__init__()
551 | self.register_buffer('real_label', torch.tensor(target_real_label))
552 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
553 | self.gan_mode = gan_mode
554 | if gan_mode == 'lsgan':
555 | self.loss = nn.MSELoss()
556 | elif gan_mode == 'vanilla':
557 | self.loss = nn.BCEWithLogitsLoss()
558 | elif gan_mode in ['wgangp']:
559 | self.loss = None
560 | else:
561 | raise NotImplementedError('gan mode %s not implemented' % gan_mode)
562 |
563 | def get_target_tensor(self, prediction, target_is_real):
564 | """Create label tensors with the same size as the input.
565 |
566 | Parameters:
567 | prediction (tensor) - - tpyically the prediction from a discriminator
568 | target_is_real (bool) - - if the ground truth label is for real images or fake images
569 |
570 | Returns:
571 | A label tensor filled with ground truth label, and with the size of the input
572 | """
573 |
574 | if target_is_real:
575 | target_tensor = self.real_label
576 | else:
577 | target_tensor = self.fake_label
578 | return target_tensor.expand_as(prediction)
579 |
580 | def __call__(self, prediction, target_is_real):
581 | """Calculate loss given Discriminator's output and grount truth labels.
582 |
583 | Parameters:
584 | prediction (tensor) - - tpyically the prediction output from a discriminator
585 | target_is_real (bool) - - if the ground truth label is for real images or fake images
586 |
587 | Returns:
588 | the calculated loss.
589 | """
590 | if self.gan_mode in ['lsgan', 'vanilla']:
591 | target_tensor = self.get_target_tensor(prediction, target_is_real)
592 | # pdb.set_trace()
593 | loss = self.loss(prediction, target_tensor)
594 | elif self.gan_mode == 'wgangp':
595 | if target_is_real:
596 | loss = -prediction.mean()
597 | else:
598 | loss = prediction.mean()
599 | return loss
600 |
601 |
602 | class TVloss(nn.Module):
603 | def __init__(self, TVloss_weight=1):
604 | super(TVloss, self).__init__()
605 | self.TVloss_weight = TVloss_weight
606 | # self.x = x
607 | # self.y = y
608 |
609 | def forward(self, x, y):
610 | # x = self.x
611 | # y = self.y
612 | batch_size = x.size()[0]
613 | h_x = x.size()[2]
614 | w_x = x.size()[3]
615 | # count_h = self._tensor_size(x[:, :, 1:, :]) # 算出总共求了多少次差
616 | # count_w = self._tensor_size(x[:, :, :, 1:])
617 | # h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
618 |
619 | # x[:,:,1:,:]-x[:,:,:h_x-1,:]就是对原图进行错位,分成两张像素位置差1的图片,第一张图片
620 | # 从像素点1开始(原图从0开始),到最后一个像素点,第二张图片从像素点0开始,到倒数第二个
621 | # 像素点,这样就实现了对原图进行错位,分成两张图的操作,做差之后就是原图中每个像素点与相
622 | # 邻的下一个像素点的差。
623 | w_tv_x = (x[:, :, :, 1:] - x[:, :, :, :w_x - 1])
624 | w_tv_y = (y[:, :, :, 1:] - y[:, :, :, :w_x - 1])
625 | h_tv_x = (x[:, :, 1:, :] - x[:, :, :h_x - 1, :])
626 | h_tv_y = (y[:, :, 1:, :] - y[:, :, :h_x - 1, :])
627 | MSE = torch.nn.MSELoss()
628 | TVloss = (MSE(h_tv_x, h_tv_y) + MSE(w_tv_x, w_tv_y)) * 0.5
629 | # Drecloss_stripe = torch.pow((w_tv_y - w_tv_x), 2)
630 | # Drecloss_stripe = (w_tv_y - w_tv_x)**2
631 | return self.TVloss_weight * TVloss
632 |
633 | def _tensor_size(self, t):
634 | return t.size()[1] * t.size()[2] * t.size()[3]
635 |
636 |
637 | def WRRGM(A, B):
638 | DWT = DWTForward(J=3, wave='haar').cuda()
639 | IDWT = DWTInverse(wave='haar').cuda()
640 | DMT3_yl, DMT3_yh = DWT(A)
641 | for tensor in DMT3_yh:
642 | tensor.zero_()
643 | out1 = IDWT((DMT3_yl, DMT3_yh))
644 |
645 | DMT3_yl, DMT3_yh = DWT(B)
646 | for tensor in DMT3_yh:
647 | tensor.zero_()
648 | out2 = IDWT((DMT3_yl, DMT3_yh))
649 |
650 | return out1, out2
651 |
652 |
653 | class MS_SSIM_L1_LOSS(nn.Module):
654 | """
655 | Have to use cuda, otherwise the speed is too slow.
656 | Both the group and shape of input image should be attention on.
657 | I set 255 and 1 for gray image as default.
658 | """
659 |
660 | def __init__(self, gaussian_sigmas=[0.5, 1.0, 2.0, 4.0, 8.0],
661 | data_range=1.0,
662 | K=(0.01, 0.03), # c1,c2
663 | alpha=0.025, # weight of ssim and l1 loss
664 | compensation=1.0, # final factor for total loss
665 | cuda_dev=0, # cuda device choice
666 | channel=3): # RGB image should set to 3 and Gray image should be set to 1
667 | super(MS_SSIM_L1_LOSS, self).__init__()
668 | self.channel = channel
669 | self.DR = data_range
670 | self.C1 = (K[0] * data_range) ** 2
671 | self.C2 = (K[1] * data_range) ** 2
672 | self.pad = int(2 * gaussian_sigmas[-1])
673 | self.alpha = alpha
674 | self.compensation = compensation
675 | filter_size = int(4 * gaussian_sigmas[-1] + 1)
676 | g_masks = torch.zeros(
677 | (self.channel * len(gaussian_sigmas), 1, filter_size, filter_size)) # 创建了(3*5, 1, 33, 33)个masks
678 | for idx, sigma in enumerate(gaussian_sigmas):
679 | if self.channel == 1:
680 | # only gray layer
681 | g_masks[idx, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
682 | elif self.channel == 3:
683 | # r0,g0,b0,r1,g1,b1,...,rM,gM,bM
684 | g_masks[self.channel * idx + 0, 0, :, :] = self._fspecial_gauss_2d(filter_size,
685 | sigma) # 每层mask对应不同的sigma
686 | g_masks[self.channel * idx + 1, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
687 | g_masks[self.channel * idx + 2, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
688 | else:
689 | raise ValueError
690 | self.g_masks = g_masks.cuda(cuda_dev) # 转换为cuda数据类型
691 |
692 | def _fspecial_gauss_1d(self, size, sigma):
693 | """Create 1-D gauss kernel
694 | Args:
695 | size (int): the size of gauss kernel
696 | sigma (float): sigma of normal distribution
697 |
698 | Returns:
699 | torch.Tensor: 1D kernel (size)
700 | """
701 | coords = torch.arange(size).to(dtype=torch.float)
702 | coords -= size // 2
703 | g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
704 | g /= g.sum()
705 | return g.reshape(-1)
706 |
707 | def _fspecial_gauss_2d(self, size, sigma):
708 | """Create 2-D gauss kernel
709 | Args:
710 | size (int): the size of gauss kernel
711 | sigma (float): sigma of normal distribution
712 |
713 | Returns:
714 | torch.Tensor: 2D kernel (size x size)
715 | """
716 | gaussian_vec = self._fspecial_gauss_1d(size, sigma)
717 | return torch.outer(gaussian_vec, gaussian_vec)
718 | # Outer product of input and vec2. If input is a vector of size nn and vec2 is a vector of size mm,
719 | # then out must be a matrix of size (n \times m)(n×m).
720 |
721 | def forward(self, x, y):
722 | b, c, h, w = x.shape
723 | assert c == self.channel
724 |
725 | mux = F.conv2d(x, self.g_masks, groups=c, padding=self.pad) # 图像为96*96,和33*33卷积,出来的是64*64,加上pad=16,出来的是96*96
726 | muy = F.conv2d(y, self.g_masks, groups=c, padding=self.pad) # groups 是分组卷积,为了加快卷积的速度
727 |
728 | mux2 = mux * mux
729 | muy2 = muy * muy
730 | muxy = mux * muy
731 |
732 | sigmax2 = F.conv2d(x * x, self.g_masks, groups=c, padding=self.pad) - mux2
733 | sigmay2 = F.conv2d(y * y, self.g_masks, groups=c, padding=self.pad) - muy2
734 | sigmaxy = F.conv2d(x * y, self.g_masks, groups=c, padding=self.pad) - muxy
735 |
736 | # l(j), cs(j) in MS-SSIM
737 | l = (2 * muxy + self.C1) / (mux2 + muy2 + self.C1) # [B, 15, H, W]
738 | cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2)
739 | if self.channel == 3:
740 | lM = l[:, -1, :, :] * l[:, -2, :, :] * l[:, -3, :, :] # 亮度对比因子
741 | PIcs = cs.prod(dim=1)
742 | elif self.channel == 1:
743 | lM = l[:, -1, :, :]
744 | PIcs = cs.prod(dim=1)
745 |
746 | loss_ms_ssim = 1 - lM * PIcs # [B, H, W]
747 | # print(loss_ms_ssim)
748 |
749 | loss_l1 = F.l1_loss(x, y, reduction='none') # [B, C, H, W]
750 | # average l1 loss in num channels
751 | gaussian_l1 = F.conv2d(loss_l1, self.g_masks.narrow(dim=0, start=-self.channel, length=self.channel),
752 | groups=c, padding=self.pad).mean(1) # [B, H, W]
753 |
754 | loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR
755 | loss_mix = self.compensation * loss_mix
756 |
757 | return loss_mix.mean()
758 |
--------------------------------------------------------------------------------
/warmup_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 | class GradualWarmupScheduler(_LRScheduler):
5 | """ Gradually warm-up(increasing) learning rate in optimizer.
6 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
7 | 在optimizer中会设置一个基础学习率base lr,
8 | 当multiplier>1时,预热机制会在total_epoch内把学习率从base lr逐渐增加到multiplier*base lr,再接着开始正常的scheduler
9 | 当multiplier==1.0时,预热机制会在total_epoch内把学习率从0逐渐增加到base lr,再接着开始正常的scheduler
10 | Args:
11 | optimizer (Optimizer): Wrapped optimizer.
12 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
13 | total_epoch: target learning rate is reached at total_epoch, gradually
14 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
15 | """
16 |
17 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
18 | self.multiplier = multiplier
19 | if self.multiplier < 1.:
20 | raise ValueError('multiplier should be greater thant or equal to 1.')
21 | self.total_epoch = total_epoch
22 | self.after_scheduler = after_scheduler
23 | self.finished = False
24 | super(GradualWarmupScheduler, self).__init__(optimizer)
25 |
26 | def get_lr(self):
27 | if self.last_epoch > self.total_epoch:
28 | if self.after_scheduler and (not self.finished):
29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 | self.finished = True
31 | # !这是很关键的一个环节,需要直接返回新的base-lr
32 | return [base_lr for base_lr in self.after_scheduler.base_lrs]
33 | if self.multiplier == 1.0:
34 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
35 | else:
36 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
37 |
38 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
39 | if epoch is None:
40 | epoch = self.last_epoch + 1
41 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
42 | print('warmuping...')
43 | if self.last_epoch <= self.total_epoch:
44 | warmup_lr=None
45 | if self.multiplier == 1.0:
46 | warmup_lr = [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
47 | else:
48 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
49 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
50 | param_group['lr'] = lr
51 | else:
52 | if epoch is None:
53 | self.after_scheduler.step(metrics, None)
54 | else:
55 | self.after_scheduler.step(metrics,epoch - self.total_epoch)
56 |
57 | def step(self, epoch=None, metrics=None):
58 | if type(self.after_scheduler) != ReduceLROnPlateau:
59 | if self.finished and self.after_scheduler:
60 | if epoch is None:
61 | self.after_scheduler.step(None)
62 | else:
63 | self.after_scheduler.step(epoch - self.total_epoch)
64 | self._last_lr = self.after_scheduler.get_last_lr()
65 | else:
66 | return super(GradualWarmupScheduler, self).step(epoch)
67 | else:
68 | self.step_ReduceLROnPlateau(metrics, epoch)
69 |
--------------------------------------------------------------------------------