├── Fig
│   ├── Building.png
│   ├── Car.png
│   ├── Fig0.png
│   ├── Fig1.png
│   ├── Fig2.png
│   ├── Fig3.png
│   ├── Fig4.png
│   ├── Mars.png
│   └── Road.png
├── README.md
├── Real
│   ├── 044.png
│   ├── 045.png
│   ├── 046.png
│   ├── 047.png
│   ├── 048.png
│   ├── 049.png
│   ├── 050.png
│   ├── 051.png
│   ├── 052.png
│   ├── 053.png
│   ├── 054.png
│   ├── 055.png
│   ├── 056.png
│   ├── 057.png
│   ├── 058.png
│   ├── 059.png
│   └── 060.png
├── dataset.py
├── demo.py
├── model
│   ├── ASCNet.py
│   └── cbam.py
├── prepare_patches.py
├── test.py
├── train.py
├── utils.py
└── warmup_scheduler.py

/Fig/Building.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Building.png
--------------------------------------------------------------------------------
/Fig/Car.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Car.png
--------------------------------------------------------------------------------
/Fig/Fig0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Fig0.png
--------------------------------------------------------------------------------
/Fig/Fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Fig1.png
--------------------------------------------------------------------------------
/Fig/Fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Fig2.png
--------------------------------------------------------------------------------
/Fig/Fig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Fig3.png
--------------------------------------------------------------------------------
/Fig/Fig4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Fig4.png
--------------------------------------------------------------------------------
/Fig/Mars.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Mars.png
--------------------------------------------------------------------------------
/Fig/Road.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Fig/Road.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # This is the code of the paper "ASCNet: Asymmetric Sampling Correction Network for Infrared Image Destriping". [[Paper]](https://ieeexplore.ieee.org/document/10855453) [[Weight]](https://drive.google.com/file/d/1zbBsWUbRVBjNckPg5DiCgKIKOKWnQ2N8/view?usp=sharing)
2 | Shuai Yuan, Hanlin Qin, Xiang Yan, Shiqi Yang, Shuowen Yang, Naveed Akhtar, Huixin Zhou, IEEE Transactions on Geoscience and Remote Sensing, 2025.
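**Quick start.** A minimal single-image inference sketch, distilled from `demo.py` (the checkpoint and image file names below are placeholders; the input height and width should be divisible by 8 to survive the three wavelet downsamplings):

```python
import cv2
import numpy as np
import torch
import torch.nn as nn
from model.ASCNet import ASCNet

model = nn.DataParallel(ASCNet(1, 1, feats=32))
model.load_state_dict(torch.load("ASCNet.pth", map_location="cpu"))
model.eval()

img = cv2.imread("striped.png")[:, :, 0]                   # first channel of the BGR image
x = torch.from_numpy(np.float32(img / 255.))[None, None]   # 1x1xHxW tensor in [0, 1]
with torch.no_grad():
    out = model(x).clamp(0., 1.)
cv2.imwrite("destriped.png", (out[0, 0].numpy() * 255).astype("uint8"))
```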
3 | # Real Destriping Examples
4 | 
5 | [![Mars](Fig/Mars.png)](https://imgsli.com/MjkxNDU2) | [![Building](Fig/Building.png)](https://imgsli.com/MjkxNDU4)
6 | :-------------------------:|:-------------------------:
7 | Mars | Building
8 | 
9 | 
10 | [![Road](Fig/Road.png)](https://imgsli.com/MjkxNDU5) | [![Car](Fig/Car.png)](https://imgsli.com/MjkxNDYx)
11 | :-------------------------:|:-------------------------:
12 | Road | Car
13 | 
14 | 
15 | # Challenges and inspiration
16 | ![Image text](https://github.com/xdFai/ASCNet/blob/main/Fig/Fig0.png)
17 | 
18 | # Structure
19 | ![Image text](https://github.com/xdFai/ASCNet/blob/main/Fig/Fig2.png)
20 | 
21 | ![Image text](https://github.com/xdFai/ASCNet/blob/main/Fig/Fig3.png)
22 | 
23 | 
24 | ## Usage
25 | 
26 | #### 1. Dataset
27 | Training dataset: [[Data]](https://drive.google.com/file/d/1o9BmWspPTJtFsBj66NN3FfM83cjp37IW/view?usp=sharing)
28 | 
29 | Augmented training dataset: [[Data_AUG]](https://drive.google.com/file/d/1Iv4CoQiInFORYn1kHjJCCCeuy6LKvnIc/view?usp=sharing)
30 | 
31 | Validation dataset: [[clean]](https://drive.google.com/file/d/1WYYZCoEooOXDG49YJXJiNkCtVFgGdx2J/view?usp=sharing), [[dirty]](https://drive.google.com/file/d/1D1NAyMLbso_UL-YRqYfPduFR-Zs8g2sx/view?usp=sharing)
32 | 
33 | #### 2. Train
34 | ```bash
35 | python train.py
36 | ```
37 | 
38 | #### 3. Test and demo [[Weight]](https://drive.google.com/file/d/1zbBsWUbRVBjNckPg5DiCgKIKOKWnQ2N8/view?usp=sharing)
39 | ```bash
40 | python test.py
41 | ```
42 | If this repo is helpful to you, please give it a star! ⭐⭐⭐
43 | 
44 | If you find the code useful, please consider citing our paper with the following BibTeX entry:
45 | 
46 | ```
47 | @ARTICLE{10855453,
48 |   author={Yuan, Shuai and Qin, Hanlin and Yan, Xiang and Yang, Shiqi and Yang, Shuowen and Akhtar, Naveed and Zhou, Huixin},
49 |   journal={IEEE Transactions on Geoscience and Remote Sensing},
50 |   title={ASCNet: Asymmetric Sampling Correction Network for Infrared Image Destriping},
51 |   year={2025},
52 |   volume={63},
53 |   number={},
54 |   pages={1-15},
55 |   keywords={Noise;Discrete wavelet transforms;Semantics;Image reconstruction;Feature extraction;Neural networks;Filters;Crosstalk;Aggregates;Geoscience and remote sensing;Asymmetric sampling (AS);column correction;deep learning;infrared (IR) image destriping;wavelet transform},
56 |   doi={10.1109/TGRS.2025.3534838}}
57 | ```
58 | 
59 | ## Contact
60 | **Feel free to open an issue or email [yuansy@stu.xidian.edu.cn](mailto:yuansy@stu.xidian.edu.cn) with any questions about ASCNet.**
61 | 
--------------------------------------------------------------------------------
/Real/044.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/044.png
--------------------------------------------------------------------------------
/Real/045.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/045.png
--------------------------------------------------------------------------------
/Real/046.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/046.png
--------------------------------------------------------------------------------
/Real/047.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/047.png -------------------------------------------------------------------------------- /Real/048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/048.png -------------------------------------------------------------------------------- /Real/049.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/049.png -------------------------------------------------------------------------------- /Real/050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/050.png -------------------------------------------------------------------------------- /Real/051.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/051.png -------------------------------------------------------------------------------- /Real/052.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/052.png -------------------------------------------------------------------------------- /Real/053.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/053.png -------------------------------------------------------------------------------- /Real/054.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/054.png -------------------------------------------------------------------------------- /Real/055.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/055.png -------------------------------------------------------------------------------- /Real/056.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/056.png -------------------------------------------------------------------------------- /Real/057.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/057.png -------------------------------------------------------------------------------- /Real/058.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/058.png -------------------------------------------------------------------------------- /Real/059.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/059.png -------------------------------------------------------------------------------- /Real/060.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/ASCNet/4ea618c18366505e8d9a15f47fdb0d0c14941568/Real/060.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset related functions
3 | 
4 | Copyright (C) 2018, Matias Tassano
5 | 
6 | This program is free software: you can use, modify and/or
7 | redistribute it under the terms of the GNU General Public
8 | License as published by the Free Software Foundation, either
9 | version 3 of the License, or (at your option) any later
10 | version. You should have received a copy of this license along
11 | with this program. If not, see <http://www.gnu.org/licenses/>.
12 | """
13 | import os
14 | import os.path
15 | import random
16 | import glob
17 | import numpy as np
18 | import cv2
19 | import h5py
20 | import torch
21 | import torch.utils.data as udata
22 | from PIL import Image
23 | 
24 | from utils import data_augmentation, normalize
25 | class Dataset(udata.Dataset):
26 |     r"""Implements torch.utils.data.Dataset
27 |     """
28 | 
29 |     def __init__(self, train=True, gray_mode=False, shuffle=False):
30 |         super(Dataset, self).__init__()
31 |         self.train = train
32 |         self.gray_mode = gray_mode
33 |         if not self.gray_mode:
34 |             self.traindbf = 'train_rgb.h5'
35 |             self.valdbf = 'val_rgb.h5'
36 |             self.valdirtydbf = 'val_dirty_rgb.h5'
37 |         else:
38 |             self.traindbf = 'train_gray.h5'
39 |             self.valdbf = 'val_gray.h5'
40 |             self.valdirtydbf = 'val_dirty_gray.h5'
41 | 
42 |         if self.train:
43 |             h5f = h5py.File(self.traindbf, 'r')
44 |             self.keys = list(h5f.keys())
45 |             if shuffle:
46 |                 random.shuffle(self.keys)
47 |             h5f.close()
48 |         else:
49 |             h5f = h5py.File(self.valdbf, 'r')
50 |             h5f_dirty = h5py.File(self.valdirtydbf, 'r')
51 |             self.keys = list(h5f.keys())
52 |             if shuffle:
53 |                 random.shuffle(self.keys)
54 |             h5f.close()
55 |             h5f_dirty.close()
56 | 
57 |     def __len__(self):
58 |         return len(self.keys)
59 | 
60 |     def __getitem__(self, index):
61 |         # Read the sample stored in the .h5 database on disk and convert it
62 |         # into a form PyTorch can work with: a torch.Tensor
63 | 
64 |         if self.train:
65 |             h5f = h5py.File(self.traindbf, 'r')
66 |             key = self.keys[index]
67 |             data = np.array(h5f[key])
68 |             h5f.close()
69 |             return torch.Tensor(data)
70 |         else:
71 |             h5f = h5py.File(self.valdbf, 'r')
72 |             h5f_dirty = h5py.File(self.valdirtydbf, 'r')
73 |             key = self.keys[index]
74 |             data_clean = np.array(h5f[key])
75 |             data_dirty = np.array(h5f_dirty[key])
76 |             h5f.close()
77 |             h5f_dirty.close()
78 |             return torch.Tensor(data_clean), torch.Tensor(data_dirty)
79 | 
80 | 
81 | def img_to_patches(img, win, stride=1):
82 |     r"""Converts an image to an array of patches.
83 | 
84 |     Args:
85 |         img: a numpy array containing a CxHxW RGB (C=3) or grayscale (C=1)
86 |             image
87 |         win: size of the output patches
88 |         stride: int. stride between consecutive patches
89 |     """
90 |     k = 0
91 |     endc = img.shape[0]
92 |     endw = img.shape[1]
93 |     endh = img.shape[2]
94 |     patch = img[:, 0:endw - win + 0 + 1:stride, 0:endh - win + 0 + 1:stride]
95 |     total_pat_num = patch.shape[1] * patch.shape[2]
96 |     res = np.zeros([endc, win * win, total_pat_num], np.float32)
97 |     for i in range(win):
98 |         for j in range(win):
99 |             patch = img[:, i:endw - win + i + 1:stride, j:endh - win + j + 1:stride]
100 |             res[:, k, :] = np.array(patch[:]).reshape(endc, total_pat_num)
101 |             k = k + 1
102 |     return res.reshape([endc, win, win, total_pat_num])
103 | 
104 | def prepare_data(data_path, \
105 |                  val_data_path, \
106 |                  val_data_dirty_path, \
107 |                  patch_size, \
108 |                  stride, \
109 |                  max_num_patches=None, \
110 |                  aug_times=1, \
111 |                  gray_mode=True):
112 |     r"""Builds the training and validation datasets by scanning the
113 |     corresponding directories for images and extracting patches from them.
114 | 
115 |     Args:
116 |         data_path: path containing the training image dataset
117 |         val_data_path: path containing the clean validation image dataset
118 |         val_data_dirty_path: path containing the dirty (striped) validation images
119 |         patch_size: size of the patches to extract from the images
120 |         stride: size of stride to extract patches
121 |         max_num_patches: maximum number of patches to extract
122 |         aug_times: number of times to augment the available data minus one
123 |         gray_mode: build the databases composed of grayscale patches
124 |     """
125 |     # training database
126 |     print('> Training database')
127 |     # scales = [1, 0.9, 0.8, 0.7]
128 |     scales = [1, 0.8, 0.6, 0.4]
129 |     # scales = [1]
130 |     types = ('*.bmp', '*.png')
131 |     files = []
132 |     for tp in types:
133 |         files.extend(glob.glob(os.path.join(data_path, tp)))
134 |     files.sort()
135 | 
136 |     if gray_mode:
137 |         traindbf = 'train_gray.h5'
138 |         valdbf = 'val_gray.h5'
139 |         valdirtydbf = 'val_dirty_gray.h5'
140 |     else:
141 |         traindbf = 'train_rgb.h5'
142 |         valdbf = 'val_rgb.h5'
143 |         valdirtydbf = 'val_dirty_rgb.h5'
144 | 
145 |     if max_num_patches is None:
146 |         max_num_patches = 5000000
147 |         print("\tMaximum number of patches not set")
148 |     else:
149 |         print("\tMaximum number of patches set to {}".format(max_num_patches))
150 |     train_num = 0
151 |     i = 0
152 |     with h5py.File(traindbf, 'w') as h5f:
153 |         while i < len(files) and train_num < max_num_patches:
154 |             imgor = cv2.imread(files[i])
155 |             # h, w, c = img.shape
156 |             for sca in scales:
157 |                 img = cv2.resize(imgor, (0, 0), fx=sca, fy=sca, \
158 |                                  interpolation=cv2.INTER_CUBIC)
159 |                 if not gray_mode:
160 |                     # CxHxW RGB image
161 |                     img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
162 |                 else:
163 |                     # CxHxW grayscale image (C=1)
164 |                     img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
165 |                     img = np.expand_dims(img, 0)
166 |                 img = normalize(img)
167 |                 # augment
168 |                 patches = img_to_patches(img, win=patch_size, stride=stride)
169 |                 # data = patches[:, :, :, 1]
170 |                 # data = data * 255
171 |                 # im = Image.fromarray(data.reshape(64, 64))  # convert the NumPy array to a PIL image
172 |                 # im.show()
173 | 
174 |                 print("\tfile: %s scale %.1f # samples: %d" % \
175 |                       (files[i], sca, patches.shape[3] * aug_times))
176 |                 for nx in range(patches.shape[3]):
177 |                     data = data_augmentation(patches[:, :, :, nx].copy(), \
178 |                                              np.random.randint(0, 7))
179 |                     h5f.create_dataset(str(train_num), data=data)
180 |                     train_num += 1
181 |                     for mx in range(aug_times - 1):
182 |                         data_aug = data_augmentation(data, np.random.randint(1, 4))
183 |                         h5f.create_dataset(str(train_num) + "_aug_%d" % (mx + 1), data=data_aug)
184 |                         train_num += 1
185 |             i += 
1 186 | 187 | # validation database 188 | print('\n> Validation database') 189 | files = [] 190 | for tp in types: 191 | files.extend(glob.glob(os.path.join(val_data_path, tp))) 192 | files.sort() 193 | h5f = h5py.File(valdbf, 'w') 194 | val_num = 0 195 | for i, item in enumerate(files): 196 | print("\tfile: %s" % item) 197 | img = cv2.imread(item) 198 | if not gray_mode: 199 | # C. H. W, RGB image 200 | img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1) 201 | else: 202 | # C, H, W grayscale image (C=1) 203 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 204 | img = np.expand_dims(img, 0) 205 | img = normalize(img) 206 | h5f.create_dataset(str(val_num), data=img) 207 | val_num += 1 208 | h5f.close() 209 | print('\n> Validation dirty database') 210 | files_dirty = [] 211 | for tp in types: 212 | files_dirty.extend(glob.glob(os.path.join(val_data_dirty_path, tp))) 213 | files_dirty.sort() 214 | h5f = h5py.File(valdirtydbf, 'w') 215 | val_num_dirty = 0 216 | for i, item in enumerate(files_dirty): 217 | print("\tfile: %s" % item) 218 | img = cv2.imread(item) 219 | if not gray_mode: 220 | # C. H. W, RGB image 221 | img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1) 222 | else: 223 | # C, H, W grayscale image (C=1) 224 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 225 | img = np.expand_dims(img, 0) 226 | img = normalize(img) 227 | h5f.create_dataset(str(val_num_dirty), data=img) 228 | val_num_dirty += 1 229 | h5f.close() 230 | 231 | print('\n> Total') 232 | print('\ttraining set, # samples %d' % train_num) 233 | print('\tvalidation set, # samples %d\n' % val_num) 234 | print('\tvalidation dirty set, # samples %d\n' % val_num_dirty) 235 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | # from SSIM import * 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import pywt 7 | import torch.nn as nn 8 | import argparse 9 | from model.ASCNet import ASCNet 10 | import time 11 | 12 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 13 | 14 | parser = argparse.ArgumentParser(description="Demo") 15 | parser.add_argument("--log_path", type=str, 16 | default=r"XXXX.pth") 17 | parser.add_argument("--filename", type=str, default=r"XXXXX") 18 | 19 | parser.add_argument("--savepth", type=str, default=r"XXXXX", 20 | help='path of result image file') 21 | 22 | parser.add_argument("--mk", type=str, default=r"XXXXX/", 23 | help='path of result image file') 24 | 25 | 26 | opt = parser.parse_args() 27 | 28 | model = ASCNet(1, 1, feats=32) 29 | model = nn.DataParallel(model) 30 | model.load_state_dict(torch.load(opt.log_path,map_location='cpu')) 31 | 32 | namelist = os.listdir(opt.filename) 33 | namelist.sort() 34 | 35 | if os.path.exists(opt.mk): 36 | pass 37 | else: 38 | os.makedirs(opt.mk) 39 | 40 | 41 | # def normalization(data): 42 | # _range = np.max(data) - np.min(data) 43 | # return (data - np.min(data)) / _range 44 | 45 | 46 | model.eval() 47 | for name in namelist: 48 | image = cv2.imread(os.path.join(opt.filename, name)) 49 | img_np = np.expand_dims(image[:, :, 0], 0) 50 | img_np = np.float32(img_np / 255.) 
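    # (Editor's note) The two lines above keep only the first channel of the
    # BGR image read by cv2 and scale it to [0, 1]; the unsqueeze below then
    # builds the 1x1xHxW single-channel batch that ASCNet(1, 1, feats=32) expects.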
51 |     img_tensor = torch.from_numpy(img_np)
52 |     img_tensor = torch.unsqueeze(img_tensor, 0)
53 |     # time_start = time.time()
54 |     out = model(img_tensor)
55 |     # time_end = time.time()
56 |     out_np = out.data.cpu().numpy()
57 |     # time_c = time_end - time_start  # elapsed wall-clock time
58 |     # print('time cost', time_c, 's')
59 |     out_val = out_np[0, :, :, :]
60 | 
61 |     out_val = np.transpose(out_val, (1, 2, 0))
62 | 
63 | 
64 |     # Clamp
65 |     out_val = out_val * 255
66 |     out_valf = np.clip(out_val, 0, 255)
67 | 
68 |     # Normalization
69 |     # final=normalization(out_val)
70 |     # out_valf = final * 255
71 | 
72 |     cv2.imwrite(os.path.join(opt.savepth, name), out_valf.astype("uint8"))
--------------------------------------------------------------------------------
/model/ASCNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from pytorch_wavelets import DWTForward, DWTInverse
4 | from torchvision import transforms
5 | from model.cbam import *
6 | import cv2
7 | from utils import weights_init_kaiming
8 | import os
9 | from thop import profile
10 | from thop import clever_format
11 | # from torchvision import transforms
12 | # import matplotlib.pyplot as plt
13 | from torch.autograd import Variable
14 | from torchvision import models
15 | import numpy as np
16 | 
17 | 
18 | 
19 | class DWT(nn.Module):
20 |     def __init__(self):
21 |         super(DWT, self).__init__()
22 |         self.requires_grad = False
23 | 
24 |     def forward(self, x):
25 |         return dwt_init(x)
26 | 
27 | 
28 | class IWT(nn.Module):
29 |     def __init__(self):
30 |         super(IWT, self).__init__()
31 |         self.requires_grad = False
32 | 
33 |     def forward(self, x):
34 |         return iwt_init(x)
35 | 
36 | 
37 | # double_conv model
38 | class double_conv(nn.Module):
39 |     def __init__(self, in_channels, out_channels):
40 |         super(double_conv, self).__init__()
41 |         self.d_conv = nn.Sequential(
42 |             nn.Conv2d(in_channels, out_channels, 3, padding=1),
43 |             nn.LeakyReLU(inplace=True),
44 |             nn.Conv2d(out_channels, out_channels, 3, padding=1),
45 |             nn.LeakyReLU(inplace=True)
46 |         )
47 | 
48 |     def forward(self, x):
49 |         x = self.d_conv(x)
50 |         return x
51 | 
52 | 
53 | class single_conv(nn.Module):
54 |     def __init__(self, in_channels, out_channels):
55 |         super(single_conv, self).__init__()
56 |         self.s_conv = nn.Sequential(
57 |             nn.Conv2d(in_channels, out_channels, 3, padding=1),
58 |             nn.LeakyReLU(inplace=True),
59 |         )
60 | 
61 |     def forward(self, x):
62 |         x = self.s_conv(x)
63 |         return x
64 | 
65 | 
66 | class single_conv_res(nn.Module):
67 |     def __init__(self, in_channels, out_channels):
68 |         super(single_conv_res, self).__init__()
69 |         self.s_conv = nn.Sequential(
70 |             nn.Conv2d(in_channels, out_channels, 3, padding=1),
71 |             nn.LeakyReLU(inplace=True),
72 |         )
73 | 
74 |     def forward(self, x):
75 |         residual = x
76 |         x = self.s_conv(x)
77 |         out = torch.add(x, residual)
78 |         return out
79 | 
80 | 
81 | class conv11(nn.Module):
82 |     def __init__(self, in_channels, out_channels):
83 |         super(conv11, self).__init__()
84 |         self.s_conv = nn.Conv2d(in_channels, out_channels, 1)
85 | 
86 |     def forward(self, x):
87 |         x = self.s_conv(x)
88 |         return x
89 | 
90 | 
91 | class conv33(nn.Module):
92 |     def __init__(self, in_channels, out_channels):
93 |         super(conv33, self).__init__()
94 |         self.s_conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
95 | 
96 |     def forward(self, x):
97 |         x = self.s_conv(x)
98 |         return x
99 | 
100 | 
101 | class ChannelPool(nn.Module):
102 |     def forward(self, x):
103 |         # concatenate the max-pooling and global average pooling results along the channel dimension
104 |         return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
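# (Editor's sketch) ChannelPool maps an NxCxHxW tensor to Nx2xHxW: channel 0
# holds the per-pixel max over C, channel 1 the per-pixel mean, e.g.
#   x = torch.randn(2, 8, 16, 16)
#   assert ChannelPool()(x).shape == (2, 2, 16, 16)
# The spatial attention layer below convolves these two maps into a single gate.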
105 | 
106 | 
107 | class Basic(nn.Module):
108 |     def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, relu=True, bn=True, bias=False):
109 |         super(Basic, self).__init__()
110 |         self.out_channels = out_planes
111 |         self.conv = nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride,
112 |                               padding=padding, bias=bias)
113 |         self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
114 |         self.relu = nn.LeakyReLU() if relu else None
115 | 
116 |     def forward(self, x):
117 |         x = self.conv(x)
118 |         if self.bn is not None:
119 |             x = self.bn(x)
120 |         if self.relu is not None:
121 |             x = self.relu(x)
122 |         return x
123 | 
124 | 
125 | class CALayer(nn.Module):
126 |     def __init__(self, channel, reduction=16):
127 |         super(CALayer, self).__init__()
128 | 
129 |         self.avgPoolW = nn.AdaptiveAvgPool2d((1, None))
130 |         self.maxPoolW = nn.AdaptiveMaxPool2d((1, None))
131 | 
132 | 
133 |         self.conv_1x1 = nn.Conv2d(in_channels=2 * channel, out_channels=2 * channel, kernel_size=1, padding=0, stride=1,
134 |                                   bias=False)
135 |         self.bn = nn.BatchNorm2d(2 * channel, eps=1e-5, momentum=0.01, affine=True)
136 |         self.Relu = nn.LeakyReLU()
137 | 
138 |         self.F_h = nn.Sequential(  # excitation operation
139 |             nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
140 |             nn.BatchNorm2d(channel // reduction, eps=1e-5, momentum=0.01, affine=True),
141 |             nn.ReLU(inplace=True),
142 |             nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
143 |         )
144 |         self.F_w = nn.Sequential(  # excitation operation
145 |             nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
146 |             nn.BatchNorm2d(channel // reduction, eps=1e-5, momentum=0.01, affine=True),
147 |             nn.ReLU(inplace=True),
148 |             nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
149 |         )
150 |         self.sigmoid = nn.Sigmoid()
151 | 
152 |     def forward(self, x):
153 |         N, C, H, W = x.size()
154 |         res = x
155 |         x_cat = torch.cat([self.avgPoolW(x), self.maxPoolW(x)], 1)
156 |         x = self.Relu(self.bn(self.conv_1x1(x_cat)))
157 |         x_1, x_2 = x.split(C, 1)
158 | 
159 |         x_1 = self.F_h(x_1)
160 |         x_2 = self.F_w(x_2)
161 |         s_h = self.sigmoid(x_1)
162 |         s_w = self.sigmoid(x_2)
163 | 
164 |         out = res * s_h.expand_as(res) * s_w.expand_as(res)
165 | 
166 |         return out
167 | 
168 | 
169 | class spatial_attn_layer(nn.Module):
170 |     def __init__(self, kernel_size=3):
171 |         super(spatial_attn_layer, self).__init__()
172 |         self.compress = ChannelPool()
173 |         self.spatial = Basic(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bn=False, relu=False)
174 | 
175 |     def forward(self, x):
176 |         x_compress = self.compress(x)
177 |         x_out = self.spatial(x_compress)
178 |         scale = torch.sigmoid(x_out)  # broadcasting
179 |         return x * scale
180 | 
181 | 
182 | class Sep(nn.Module):
183 |     def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=1, bias=True):
184 |         super().__init__()
185 |         self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size, stride, padding, groups=in_channel, bias=bias)
186 |         self.conv2 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=bias)
187 | 
188 |     def forward(self, input):
189 |         x = self.conv1(input)
190 |         x = self.conv2(x)
191 |         return x
192 | 
193 | 
194 | class RCSSC(nn.Module):
195 |     def __init__(self, n_feat, reduction=16):
196 |         super(RCSSC, self).__init__()
197 |         pooling_r = 4
198 |         self.head = nn.Sequential(
199 |             nn.Conv2d(in_channels=n_feat, out_channels=n_feat, kernel_size=3, padding=1, stride=1, bias=True),
200 |             nn.LeakyReLU(),
201 |         )
202 |         self.SC = nn.Sequential(
203 |             nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),
204 |             nn.Conv2d(in_channels=n_feat, out_channels=n_feat, kernel_size=3, padding=1, stride=1, bias=True),
205 |             nn.BatchNorm2d(n_feat)
206 |         )
207 |         self.SA = spatial_attn_layer()  ## Spatial Attention
208 |         self.CA = CALayer(n_feat, reduction)  ## Channel Attention
209 | 
210 |         self.conv1x1 = nn.Sequential(
211 |             nn.Conv2d(n_feat * 2, n_feat, kernel_size=1),
212 |             nn.Conv2d(in_channels=n_feat, out_channels=n_feat, kernel_size=3, padding=1, stride=1, bias=True)
213 |         )
214 |         self.ReLU = nn.LeakyReLU()
215 |         self.tail = nn.Conv2d(in_channels=n_feat, out_channels=n_feat, kernel_size=3, padding=1)
216 | 
217 |     def forward(self, x):
218 |         res = x
219 |         x = self.head(x)
220 |         sa_branch = self.SA(x)
221 |         ca_branch = self.CA(x)
222 |         x1 = torch.cat([sa_branch, ca_branch], dim=1)  # concatenate the two attention branches
223 |         x1 = self.conv1x1(x1)
224 |         x2 = torch.sigmoid(
225 |             torch.add(x, F.interpolate(self.SC(x), x.size()[2:])))
226 |         out = torch.mul(x1, x2)
227 |         out = self.tail(out)
228 |         out = out + res
229 |         out = self.ReLU(out)
230 |         return out
231 | 
232 | 
233 | 
234 | class _DCR_block(nn.Module):
235 |     def __init__(self, channel_in):
236 |         super(_DCR_block, self).__init__()
237 |         self.conv_1 = nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in / 2.), kernel_size=3, stride=1,
238 |                                 padding=1)
239 |         self.relu1 = nn.LeakyReLU()
240 |         self.conv_2 = nn.Conv2d(in_channels=int(channel_in * 3 / 2.), out_channels=int(channel_in / 2.), kernel_size=3,
241 |                                 stride=1, padding=1)
242 |         self.relu2 = nn.LeakyReLU()
243 |         self.conv_3 = nn.Conv2d(in_channels=channel_in * 2, out_channels=channel_in, kernel_size=3, stride=1, padding=1)
244 |         self.relu3 = nn.LeakyReLU()
245 | 
246 |     def forward(self, x):
247 |         residual = x
248 |         out = self.relu1(self.conv_1(x))
249 |         conc = torch.cat([x, out], 1)
250 |         out = self.relu2(self.conv_2(conc))
251 |         conc = torch.cat([conc, out], 1)
252 |         out = self.relu3(self.conv_3(conc))
253 |         out = torch.add(out, residual)
254 |         return out
255 | 
256 | 
257 | class New_block(nn.Module):
258 |     def __init__(self, channel_in, reduction):
259 |         super(New_block, self).__init__()
260 | 
261 |         # RCSSC
262 |         self.unit_1 = RCSSC(int(channel_in / 2.), reduction)
263 |         self.unit_2 = RCSSC(int(channel_in / 2.), reduction)
264 | 
265 |         self.conv1 = nn.Sequential(
266 |             nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in / 2.), kernel_size=3, padding=1),
267 |             nn.LeakyReLU()
268 |         )
269 |         self.conv2 = nn.Sequential(
270 |             nn.Conv2d(in_channels=int(channel_in * 3 / 2.), out_channels=int(channel_in / 2.), kernel_size=3,
271 |                       padding=1),
272 |             nn.LeakyReLU()
273 |         )
274 |         self.conv3 = nn.Sequential(
275 |             nn.Conv2d(in_channels=channel_in * 2, out_channels=channel_in, kernel_size=1, padding=0,
276 |                       stride=1),  # channel compression
277 |             nn.Conv2d(in_channels=channel_in, out_channels=channel_in, kernel_size=3, padding=1),
278 |             nn.LeakyReLU()
279 |         )
280 | 
281 |     def forward(self, x):
282 |         residual = x
283 |         c1 = self.unit_1(self.conv1(x))
284 |         x = torch.cat([residual, c1], 1)
285 |         c2 = self.unit_2(self.conv2(x))
286 |         x = torch.cat([c2, x], 1)
287 |         x = self.conv3(x)
288 |         x = torch.add(x, residual)
289 |         return x
290 | 
291 | 
292 | class ASCNet(nn.Module):
293 | 
294 |     def __init__(self, in_ch, out_ch, feats):
295 |         super(ASCNet, self).__init__()
296 |         self.features = []
297 | 
298 |         self.head = single_conv(in_ch, feats)
299 |         self.dconv_encode0 = double_conv(feats, feats)  # → Haar DWT below
300 | 
301 |         self.identety1 = nn.Conv2d(in_channels=feats, out_channels=2 * feats, kernel_size=3, stride=2, padding=1)
302 |         self.DWT = DWTForward(J=1, wave='haar')
303 |         self.dconv_encode1 = single_conv(4 * feats, 2 * feats)
304 | 
305 |         # CNCM
306 |         self.enhance1 = New_block(2 * feats, reduction=16)
307 | 
308 |         self.identety2 = nn.Conv2d(in_channels=2 * feats, out_channels=4 * feats, kernel_size=3, stride=2, padding=1)
309 | 
310 |         self.dconv_encode2 = single_conv(8 * feats, 4 * feats)
311 | 
312 |         self.dconv_encode3 = single_conv(16 * feats, 4 * feats)
313 | 
314 |         self.enhance2 = New_block(4 * feats, reduction=16)
315 |         self.identety3 = nn.Conv2d(in_channels=4 * feats, out_channels=4 * feats, kernel_size=3, stride=2, padding=1)
316 |         self.maxpool = nn.MaxPool2d(2)
317 |         self.enhance3 = New_block(4 * feats, reduction=16)
318 | 
319 |         self.mid1 = single_conv(8 * feats, 4 * feats)
320 |         self.mid2 = single_conv(4 * feats, 4 * feats + 4 * feats)
321 | 
322 |         self.pixs = nn.PixelShuffle(2)
323 | 
324 |         # decoder*****************************************************
325 |         self.upsample2 = nn.Sequential(
326 |             nn.ConvTranspose2d(8 * feats, 4 * feats, kernel_size=2, stride=2),
327 |             # nn.LeakyReLU(inplace=True)
328 |         )
329 |         self.upsample1 = nn.Sequential(
330 |             nn.ConvTranspose2d(4 * feats, 2 * feats, kernel_size=2, stride=2),
331 |             # nn.LeakyReLU(inplace=True)
332 |         )
333 | 
334 |         self.upsample0 = nn.Sequential(
335 |             nn.ConvTranspose2d(2 * feats, feats, kernel_size=2, stride=2),
336 |             # nn.LeakyReLU(inplace=True)
337 |         )
338 |         self.IDWT = DWTInverse(wave='haar')
339 | 
340 |         # fair *******************************************************
341 |         self.fair2 = nn.Conv2d(2 * feats, 4 * feats, kernel_size=3, padding=1)
342 |         self.fair1 = nn.Conv2d(1 * feats, 2 * feats, kernel_size=3, padding=1)
343 |         self.fair0 = nn.Conv2d(int(feats / 2), feats, kernel_size=3, padding=1)
344 | 
345 |         # decoder*****************************************************
346 |         self.dconv_decode2 = nn.Sequential(conv11(4 * feats + 4 * feats, 4 * feats),
347 |                                            New_block(4 * feats, reduction=16))
348 | 
349 |         self.dconv_decode1 = nn.Sequential(conv11(2 * feats + 2 * feats, 2 * feats),
350 |                                            New_block(2 * feats, reduction=16))
351 | 
352 |         self.dconv_decode0 = double_conv(feats + feats, feats)
353 |         self.tail = nn.Sequential(nn.Conv2d(feats, out_ch, 1), nn.Tanh())
354 | 
355 |     def make_layer(self, block, channel_in):
356 |         layers = []
357 |         layers.append(block(channel_in))
358 |         return nn.Sequential(*layers)
359 | 
360 |     def _transformer(self, DMT1_yl, DMT1_yh):
361 |         list_tensor = []
362 |         a = DMT1_yh[0]
363 |         list_tensor.append(DMT1_yl)
364 |         for i in range(3):
365 |             list_tensor.append(a[:, :, i, :, :])
366 |         return torch.cat(list_tensor, 1)
367 | 
368 |     def _Itransformer(self, out):
369 |         yh = []
370 |         C = int(out.shape[1] / 4)
371 |         yl = out[:, 0:C, :, :]
372 |         y1 = out[:, C:2 * C, :, :].unsqueeze(2)
373 |         y2 = out[:, 2 * C:3 * C, :, :].unsqueeze(2)
374 |         y3 = out[:, 3 * C:4 * C, :, :].unsqueeze(2)
375 |         final = torch.cat([y1, y2, y3], 2)
376 |         yh.append(final)
377 |         return yl, yh
378 | 
379 |     def forward(self, x):
380 |         inputs = x
381 | 
382 |         x0 = self.dconv_encode0(self.head(x))
383 | 
384 |         DMT1_yl, DMT1_yh = self.DWT(x0)
385 |         DMT1 = self._transformer(DMT1_yl, DMT1_yh)
386 |         x = self.dconv_encode1(DMT1)
387 | 
388 |         res1 = self.identety1(x0)
389 |         out = torch.add(x, res1)
390 | 
391 |         x1 = self.enhance1(out)
392 | 
393 |         DMT1_yl, DMT1_yh = self.DWT(x1)
394 |         DMT2 = self._transformer(DMT1_yl, DMT1_yh)
395 | 
x = self.dconv_encode2(DMT2) 396 | 397 | res1 = self.identety2(x1) 398 | out2 = torch.add(x, res1) 399 | 400 | x2 = self.enhance2(out2) 401 | 402 | DMT1_yl, DMT1_yh = self.DWT(x2) 403 | DMT3 = self._transformer(DMT1_yl, DMT1_yh) 404 | x = self.dconv_encode3(DMT3) 405 | 406 | res1 = self.identety3(x2) 407 | out3 = torch.add(x, res1) 408 | # MI = self.mid1(out3) 409 | x3 = self.mid2(self.enhance3(out3)) 410 | 411 | x = self.pixs(x3) 412 | x = self.fair2(x) 413 | x = self.dconv_decode2(torch.cat([x, x2], dim=1)) 414 | x = self.pixs(x) 415 | x = self.fair1(x) 416 | x = self.dconv_decode1(torch.cat([x, x1], dim=1)) 417 | x = self.pixs(x) 418 | x = self.fair0(x) 419 | 420 | x = self.dconv_decode0(torch.cat([x, x0], dim=1)) 421 | x = self.tail(x) 422 | out = x + inputs 423 | 424 | return out 425 | 426 | 427 | if __name__ == '__main__': 428 | net = ASCNet(1, 1, feats=32) 429 | input = torch.zeros((1, 1, 256, 256), dtype=torch.float32) 430 | output = net(input) 431 | 432 | flops, params = profile(net, (input,)) 433 | print("-" * 50) 434 | print('FLOPs = ' + str(flops / 1000 ** 3) + ' G') 435 | print('Params = ' + str(params / 1000 ** 2) + ' M') 436 | print(output.shape) 437 | -------------------------------------------------------------------------------- /model/cbam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class BasicConv(nn.Module): 7 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 8 | super(BasicConv, self).__init__() 9 | self.out_channels = out_planes 10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 11 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 12 | self.relu = nn.ReLU() if relu else None 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | if self.bn is not None: 17 | x = self.bn(x) 18 | if self.relu is not None: 19 | x = self.relu(x) 20 | return x 21 | 22 | class Flatten(nn.Module): 23 | def forward(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | class ChannelGate(nn.Module): 27 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 28 | super(ChannelGate, self).__init__() 29 | self.gate_channels = gate_channels 30 | self.mlp = nn.Sequential( 31 | Flatten(), 32 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 33 | nn.ReLU(), 34 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 35 | ) 36 | self.pool_types = pool_types 37 | def forward(self, x): 38 | channel_att_sum = None 39 | for pool_type in self.pool_types: 40 | if pool_type=='avg': 41 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 42 | channel_att_raw = self.mlp( avg_pool ) 43 | elif pool_type=='max': 44 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 45 | channel_att_raw = self.mlp( max_pool ) 46 | elif pool_type=='lp': 47 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 48 | channel_att_raw = self.mlp( lp_pool ) 49 | elif pool_type=='lse': 50 | # LSE pool only 51 | lse_pool = logsumexp_2d(x) 52 | channel_att_raw = self.mlp( lse_pool ) 53 | 54 | if channel_att_sum is None: 55 | channel_att_sum = channel_att_raw 56 | else: 57 | channel_att_sum = channel_att_sum + channel_att_raw 58 | 59 | 
scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
60 |         return x * scale
61 | 
62 | def logsumexp_2d(tensor):
63 |     tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
64 |     s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
65 |     outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
66 |     return outputs
67 | 
68 | class ChannelPool(nn.Module):
69 |     def forward(self, x):
70 |         return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
71 | 
72 | class SpatialGate(nn.Module):
73 |     def __init__(self):
74 |         super(SpatialGate, self).__init__()
75 |         kernel_size = 7
76 |         self.compress = ChannelPool()
77 |         self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
78 |     def forward(self, x):
79 |         x_compress = self.compress(x)
80 |         x_out = self.spatial(x_compress)
81 |         scale = torch.sigmoid(x_out)  # broadcasting
82 |         return x * scale
83 | 
84 | class CBAM(nn.Module):
85 |     def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
86 |         super(CBAM, self).__init__()
87 |         self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
88 |         self.no_spatial=no_spatial
89 |         if not no_spatial:
90 |             self.SpatialGate = SpatialGate()
91 |     def forward(self, x):
92 |         x_out = self.ChannelGate(x)
93 |         if not self.no_spatial:
94 |             x_out = self.SpatialGate(x_out)
95 |         return x_out
96 | 
--------------------------------------------------------------------------------
/prepare_patches.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate the .h5 data files
3 | """
4 | import argparse
5 | from dataset import prepare_data
6 | 
7 | if __name__ == "__main__":
8 |     parser = argparse.ArgumentParser(description="Building the training patch database")
9 |     parser.add_argument("--gray", default=True, action='store_true', help='prepare grayscale database instead of RGB')
10 |     # Preprocessing parameters
11 |     parser.add_argument("--patch_size", "--p", type=int, default=64, help="Patch size")
12 |     parser.add_argument("--stride", "--s", type=int, default=40, help="Size of stride")
13 |     parser.add_argument("--max_number_patches", "--m", type=int, default=180,
14 |                         # parser.add_argument("--max_number_patches", "--m", type=int, default=18,
15 |                         help="Maximum number of patches")
16 |     parser.add_argument("--aug_times", "--a", type=int, default=2,
17 |                         help="How many times to perform data augmentation")
18 |     # Dirs
19 |     parser.add_argument("--trainset_dir", type=str, default=r"D:\SCI\03-SCI\dataset\All", help='path of trainset')
20 |     parser.add_argument("--valset_dir", type=str, default=r"D:\SCI\03-SCI\dataset\All_16\image",
21 |                         help='path of the clean validation set')
22 |     parser.add_argument("--valset_dirty_dir", type=str, default=r"D:\SCI\03-SCI\dataset\All_16\nosie",
23 |                         help='path of the dirty (striped) validation set')
24 |     args = parser.parse_args()
25 | 
26 | 
27 |     print("\n### Building databases ###")
28 |     print("> Parameters:")
29 |     for p, v in zip(args.__dict__.keys(), args.__dict__.values()):
30 |         print('\t{}: {}'.format(p, v))
31 |     print('\n')
32 | 
33 |     prepare_data(args.trainset_dir, args.valset_dir, args.valset_dirty_dir, args.patch_size, args.stride, args.max_number_patches,
34 |                  aug_times=args.aug_times, gray_mode=args.gray)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | 
import torch.nn as nn 5 | import argparse 6 | from model.ASCNet import ASCNet 7 | import time 8 | from utils import * 9 | import numpy as np 10 | import torch 11 | import pywt 12 | import torch.nn as nn 13 | import lpips 14 | 15 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 16 | 17 | parser = argparse.ArgumentParser(description="Demo") 18 | parser.add_argument("--log_path", type=str, 19 | default=r"XXXXXX") 20 | parser.add_argument("--filename", type=str, default=r"XXXX") 21 | parser.add_argument("--save", type=bool, default=False) 22 | opt = parser.parse_args() 23 | 24 | 25 | def normalization(data): 26 | _range = np.max(data) - np.min(data) 27 | return (data - np.min(data)) / _range 28 | 29 | 30 | ssim = SSIM() 31 | loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores 32 | loss_fn_vgg = lpips.LPIPS(net='vgg') 33 | 34 | cleanfilename = os.path.join(opt.filename, 'image') 35 | clclist = ['Gauss', 'Uniform', 'Cycle'] 36 | 37 | 38 | 39 | model = ASCNet(1, 1, feats=32) 40 | # model = nn.DataParallel(model).cuda() 41 | model = nn.DataParallel(model) 42 | model.load_state_dict(torch.load(opt.log_path, map_location='cpu')) 43 | 44 | 45 | psnr_sum = 0 46 | ssim_sum = 0 47 | lpips_sum = 0 48 | 49 | 50 | for clc in clclist: 51 | savepth = os.path.join(opt.filename, 'ASCNet', clc) 52 | mk = savepth + '\\' 53 | noisepth = os.path.join(opt.filename, 'noise', clc) 54 | namelist = os.listdir(cleanfilename) 55 | namelist.sort() 56 | model.eval() 57 | with torch.no_grad(): 58 | for name in namelist: 59 | # read noise image and process it 60 | image = cv2.imread(os.path.join(noisepth, name)) 61 | img_np = np.expand_dims(image[:, :, 0], 0) 62 | img_np = np.float32(img_np / 255.) 63 | img_tensor = torch.from_numpy(img_np) 64 | img_tensor = torch.unsqueeze(img_tensor, 0) 65 | # out, outstripe = model(img_tensor) 66 | out = model(img_tensor) 67 | out_val = torch.clip(out, 0., 1.) 68 | 69 | # read clean image 70 | image2 = cv2.imread(os.path.join(cleanfilename, name)) 71 | img_np2 = np.expand_dims(image2[:, :, 0], 0) 72 | img_np2562 = np.float32(img_np2 / 255.) 73 | img_clean = torch.from_numpy(img_np2562) 74 | # img_clean = torch.unsqueeze(img_clean, 0).cuda() 75 | img_clean = torch.unsqueeze(img_clean, 0) 76 | 77 | # calculate PSNR and SSIM 78 | psnr_val = batch_psnr(out_val, img_clean, 1.) 
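            # (Editor's note) PSNR above uses the [0, 1]-clipped output against
            # the clean frame; the running sums below are averaged over the test
            # set and reset for each noise case (Gauss/Uniform/Cycle).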
79 |             ssim_val = ssim(img_clean, out_val)
80 |             lpips_val = loss_fn_alex(out_val, img_clean)
81 | 
82 |             psnr_sum = psnr_sum + psnr_val
83 |             ssim_sum = ssim_sum + ssim_val
84 |             lpips_sum = lpips_sum + lpips_val
85 |             # print(name)
86 |             if opt.save:
87 |                 if os.path.exists(mk):
88 |                     pass
89 |                 else:
90 |                     os.makedirs(mk)
91 | 
92 | 
93 |                 out_np = out.data.cpu().numpy()
94 |                 # out_np = out.data.numpy()
95 |                 out_val = out_np[0, :, :, :]
96 |                 out_val = np.transpose(out_val, (1, 2, 0))
97 |                 out_val = out_val * 255
98 |                 out_valf = np.clip(out_val, 0, 255)
99 |                 savepth = os.path.join(opt.filename, 'ASCNet', clc)
100 |                 # final=normalization(out_val)
101 |                 # out_valf = final * 255
102 |                 cv2.imwrite(os.path.join(savepth, name), out_valf.astype("uint8"))
103 | 
104 |     psnr_val = psnr_sum / len(namelist)
105 |     ssim_val = ssim_sum / len(namelist)
106 |     lpips_val = lpips_sum / len(namelist)
107 | 
108 |     print("*" * 10 + clc + "*" * 10)
109 |     print("PSNR_avg: %.4f" % psnr_val)
110 |     print("SSIM_avg: %.4f" % ssim_val)
111 |     print("LPIPS_avg: %.4f" % lpips_val)
112 | 
113 |     psnr_sum = 0
114 |     ssim_sum = 0
115 |     lpips_sum = 0
116 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | Pre-run checklist (translated): modify the network
3 | set the GPU index
4 | rename the logs dir and create the remote folder (line 372)
5 | set the .pth checkpoint name (line 345)
6 | upload the files
7 | 
8 | *************************************
9 | 
10 | '''
11 | 
12 | import warnings
13 | import os
14 | import argparse
15 | import cv2
16 | import numpy as np
17 | import torch
18 | import torch.nn as nn
19 | import torch.optim as optim
20 | from torch.autograd import Variable
21 | from torch.utils.data import DataLoader
22 | import torchvision.utils as utils
23 | from torch.utils.tensorboard import SummaryWriter
24 | # from models import FFDNet
25 | from dataset import Dataset
26 | from model.ASCNet import ASCNet
27 | from utils import *
28 | from warmup_scheduler import GradualWarmupScheduler
29 | from torchvision import transforms
30 | import matplotlib.pyplot as plt
31 | import random
32 | 
33 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
34 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
35 | 
36 | 
37 | # warnings.filterwarnings('ignore')
38 | 
39 | 
40 | def main(args):
41 |     r"""Performs the main training loop
42 |     """
43 |     # Load dataset
44 |     print('> Loading dataset ...')  # both training and validation read .h5 databases
45 |     dataset_train = Dataset(train=True, gray_mode=args.gray, shuffle=True)
46 |     dataset_val = Dataset(train=False, gray_mode=args.gray, shuffle=False)
47 |     # training data goes through a DataLoader; validation data is iterated directly
48 |     loader_train = DataLoader(dataset=dataset_train, num_workers=4, batch_size=args.batch_size, shuffle=True)
49 |     print("\t# of training samples: %d\n" % int(len(dataset_train)))
50 | 
51 |     # Init loggers
52 |     if not os.path.exists(args.log_dir):
53 |         os.makedirs(args.log_dir)
54 |     writer = SummaryWriter(args.log_dir)
55 |     # **********************************************************************************************
56 |     # build model
57 |     # **********************************************************************************************
58 |     net = ASCNet(1, 1, feats=16)
59 |     # Define loss
60 |     criterion = nn.MSELoss().cuda()
61 |     ssim = SSIM().cuda()
62 | 
63 |     # Move to GPU
64 |     device_ids = [0]
65 |     model = nn.DataParallel(net, device_ids=device_ids).cuda()
66 | 
67 |     # Optimizer
68 |     # optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.9999), eps=1e-8)
69 |     optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.9999))
70 |     warmup_epochs = 4
71 | 
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs - warmup_epochs, 72 | eta_min=1e-6) 73 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, 74 | after_scheduler=scheduler_cosine) 75 | scheduler.step() 76 | # noise case 77 | 78 | # case = 3 79 | 80 | start_epoch = 0 81 | training_params = {} 82 | training_params['step'] = 0 83 | training_params['no_orthog'] = args.no_orthog 84 | 85 | # Training 86 | for epoch in range(start_epoch, args.epochs): 87 | print("==============ASCNet==============", epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0])) 88 | psnr_sum = 0 89 | psnr_val = 0 90 | ssim_sum = 0 91 | ssim_val = 0 92 | 93 | # train 94 | for i, data in enumerate(loader_train, 0): 95 | # case = random.randint(0, 3) 96 | case = 3 97 | # print(case) 98 | # Pre-training step 99 | model.train() 100 | model.zero_grad() 101 | optimizer.zero_grad() 102 | img_train = data 103 | # add noise 104 | imgn_train = add_noise(img_train, case, args.noiseIntL) 105 | # imgn_train = add_noise2(img_train, case, args.noiseIntL, args.noiseIntS) 106 | # Create input Variables 107 | img_train = Variable(img_train.cuda()) 108 | imgn_train = Variable(imgn_train.cuda()) 109 | 110 | # Evaluate model and optimize it 111 | out_train = model(imgn_train) 112 | # out_train = torch.clamp(model(imgn_train), 0., 1.) 113 | # torch.clamp(model(imgn_val), 0., 1.) 114 | # ************************************************************************************************************************* 115 | # loss 116 | # ************************************************************************************************************************* 117 | loss1 = criterion(out_train, img_train) 118 | # loss2 = l1(out_train, img_train) 119 | # loss3 = dre(out_train, img_train) 120 | # loss4 = tv(out_train - img_train) 121 | # loss5 = drestr(out_train, img_train) 122 | loss = loss1 123 | loss.backward() 124 | optimizer.step() 125 | 126 | if training_params['step'] % args.save_every == 0: 127 | # Apply regularization by orthogonalizing filters 128 | # Results 129 | model.eval() 130 | out_train = torch.clip(out_train, 0., 1.) 131 | psnr_train = batch_psnr(out_train, img_train, 1.) 132 | ssim_train = ssim(img_train, out_train) 133 | if not training_params['no_orthog']: 134 | model.apply(svd_orthogonalization) 135 | 136 | # Log the scalar values 137 | writer.add_scalar('loss', loss.item(), training_params['step']) 138 | writer.add_scalar('PSNR on training data', psnr_train, \ 139 | training_params['step']) 140 | writer.add_scalar('SSIM on training data', ssim_train, \ 141 | training_params['step']) 142 | print("[epoch %d][%d/%d] loss: %.6f PSNR_train: %.4f" % \ 143 | (epoch + 1, i + 1, len(loader_train), loss.item(), psnr_train)) 144 | training_params['step'] += 1 145 | scheduler.step() 146 | # The end of each epoch 147 | 148 | if epoch % 1 == 0: 149 | model.eval() 150 | with torch.no_grad(): 151 | # Validation 152 | for dataclean, datadirty in dataset_val: 153 | datadirty_val = torch.unsqueeze(datadirty, 0) 154 | dataclean_val = torch.unsqueeze(dataclean, 0) 155 | datadirty_val, dataclean_val = Variable(datadirty_val.cuda()), Variable(dataclean_val.cuda()) 156 | out_val = torch.clip(model(datadirty_val), 0., 1.) 157 | psnr_val = batch_psnr(out_val, dataclean_val, 1.) 
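                    # (Editor's note) Per-image PSNR/SSIM accumulate below; the
                    # epoch averages drive checkpointing: a new .pth is written
                    # only when the mean validation PSNR improves.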
158 | psnr_sum = psnr_sum + psnr_val 159 | ssim_val = ssim(dataclean_val, out_val) 160 | ssim_sum = ssim_sum + ssim_val.item() 161 | psnr_val = psnr_sum / len(dataset_val) 162 | ssim_val = ssim_sum / len(dataset_val) 163 | print("\n[epoch %d] PSNR_val: %.4f SSIM_val: %.6f" % (epoch + 1, psnr_val, ssim_val)) 164 | writer.add_scalar('PSNR on validation data', psnr_val, training_params['step']) 165 | writer.add_scalar('SSIM on validation data', ssim_val, training_params['step']) 166 | writer.add_scalar('Learning rate', scheduler.get_lr()[0], training_params['step']) 167 | 168 | if epoch == 0: 169 | best_psnr = psnr_val 170 | best_ssim = ssim_val 171 | 172 | print("[epoch %d][%d/%d] psnr_avg: %.4f, ssim_avg: %.4f, best_psnr: %.4f, best_ssim: %.6f" % 173 | (epoch + 1, i + 1, len(dataset_val), psnr_val, ssim_val, best_psnr, best_ssim)) 174 | 175 | if psnr_val >= best_psnr: 176 | best_psnr = psnr_val 177 | best_ssim = ssim_val 178 | print('--- save the model @ ep--{} PSNR--{} SSIM--{}'.format(epoch, best_psnr, best_ssim)) 179 | best_psnr_s = format(best_psnr,'.4f') 180 | best_ssim_s = format(best_ssim,'.6f') 181 | s = "best_" + "ASCNet"+"_" + str(best_psnr_s) + "_" + str(best_ssim_s) + ".pth" 182 | torch.save(model.state_dict(), os.path.join(args.log_dir, s)) 183 | 184 | training_params['start_epoch'] = epoch + 1 185 | 186 | 187 | if __name__ == "__main__": 188 | 189 | parser = argparse.ArgumentParser(description="ASCNet") 190 | # ******************************************************************************************************************************** 191 | parser.add_argument("--log_dir", type=str, default="otherlogs/ASCNet", help='path of log files') 192 | parser.add_argument("--batch_size", type=int, default=128, help="Training batch size") 193 | parser.add_argument("--epochs", "--e", type=int, default=101, help="Number of total training epochs") 194 | parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate") 195 | parser.add_argument("--noiseIntL", nargs=2, type=int, default=[0.05, 0.15], help="Noise training interval") 196 | # parser.add_argument("--noiseIntS", nargs=2, type=int, default=[0, 0.25], help="Noise training interval") 197 | parser.add_argument("--seed", type=int, default=42, help="Threshold for test") 198 | parser.add_argument("--gray", default=True, action='store_true', 199 | help='train grayscale image denoising instead of RGB') 200 | parser.add_argument("--no_orthog", action='store_true', help="Don't perform orthogonalization as regularization") 201 | parser.add_argument("--save_every", type=int, default=100, 202 | help="Number of training steps to log psnr and perform orthogonalization") 203 | argspar = parser.parse_args() 204 | 205 | print("\n#########################################\n" 206 | " ASCNet " 207 | "\n#########################################\n") 208 | print("> Parameters:") 209 | for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()): 210 | print('\t{}: {}'.format(p, v)) 211 | print('\n') 212 | 213 | seed_pytorch(argspar.seed) 214 | 215 | main(argspar) 216 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Different utilities such as orthogonalization of weights, initialization of 3 | loggers, etc 4 | 5 | Copyright (C) 2018, Matias Tassano 6 | 7 | This program is free software: you can use, modify and/or 8 | redistribute it under the terms of the GNU General Public 9 | License as 
published by the Free Software Foundation, either
10 | version 3 of the License, or (at your option) any later
11 | version. You should have received a copy of this license along
12 | with this program. If not, see <http://www.gnu.org/licenses/>.
13 | """
14 | import subprocess
15 | import math
16 | import logging
17 | import numpy as np
18 | import cv2
19 | import torch
20 | import torch.nn as nn
21 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
22 | from math import exp
23 | from torch.autograd import Variable
24 | from torch.nn import functional as F
25 | from PIL import Image
26 | import random
27 | import os  # required by seed_pytorch below (missing from the original listing)
28 | 
29 | from torchvision import transforms
30 | import matplotlib.pyplot as plt
31 | 
32 | def seed_pytorch(seed=42):
33 |     random.seed(seed)
34 |     os.environ['PYTHONHASHSEED'] = str(seed)
35 |     np.random.seed(seed)
36 |     torch.manual_seed(seed)
37 |     torch.cuda.manual_seed(seed)
38 |     torch.cuda.manual_seed_all(seed)
39 | 
40 | 
41 | def weights_init_kaiming(lyr):
42 |     r"""Initializes weights of the model according to the "He" initialization
43 |     method described in "Delving deep into rectifiers: Surpassing human-level
44 |     performance on ImageNet classification" - He, K. et al. (2015), using a
45 |     normal distribution.
46 |     This function is to be called by the torch.nn.Module.apply() method,
47 |     which applies weights_init_kaiming() to every layer of the model.
48 |     """
49 |     classname = lyr.__class__.__name__
50 |     if classname.find('Conv') != -1:
51 |         nn.init.kaiming_normal_(lyr.weight.data, a=0, mode='fan_in')
52 |     elif classname.find('Linear') != -1:
53 |         nn.init.kaiming_normal_(lyr.weight.data, a=0, mode='fan_in')
54 |     elif classname.find('BatchNorm') != -1:
55 |         lyr.weight.data.normal_(mean=0, std=math.sqrt(2. / 9. / 64.)). \
56 |             clamp_(-0.025, 0.025)
57 |         nn.init.constant_(lyr.bias.data, 0.0)
58 | 
59 | 
60 | def ssim(img1, img2, window_size=11, size_average=True):
61 |     (_, channel, _, _) = img1.size()
62 |     window = create_window(window_size, channel)
63 | 
64 |     if img1.is_cuda:
65 |         window = window.cuda(img1.get_device())
66 |     window = window.type_as(img1)
67 | 
68 |     return _ssim(img1, img2, window, window_size, channel, size_average)
69 | 
70 | 
71 | def batch_psnr(img, imclean, data_range):
72 |     r"""
73 |     Computes the PSNR along the batch dimension (not pixel-wise)
74 | 
75 |     Args:
76 |         img: a `torch.Tensor` containing the restored image
77 |         imclean: a `torch.Tensor` containing the reference image
78 |         data_range: The data range of the input image (distance between
79 |             minimum and maximum possible values). By default, this is estimated
80 |             from the image data-type.
81 |     """
82 |     img_cpu = img.data.cpu().numpy().astype(np.float32)
83 |     imgclean = imclean.data.cpu().numpy().astype(np.float32)
84 |     psnr = 0
85 |     for i in range(img_cpu.shape[0]):
86 |         psnr += compare_psnr(imgclean[i, :, :, :], img_cpu[i, :, :, :], \
87 |                              data_range=data_range)
88 |     return psnr / img_cpu.shape[0]
89 | 
90 | 
91 | # def batch_ssim(img, imclean):
92 | #
93 | #     img_cpu = img.data.cpu().numpy().astype(np.float32)
94 | #     imgclean = imclean.data.cpu().numpy().astype(np.float32)
95 | #     ssimall = 0
96 | #     for i in range(img_cpu.shape[0]):
97 | #         ssimall += ssim(img_cpu[i, :, :, :],imgclean[i, :, :, :])
98 | #     return ssimall / img_cpu.shape[0]
99 | 
100 | def data_augmentation(image, mode):
101 |     r"""Performs data augmentation of the input image
102 | 
103 |     Args:
104 |         image: a cv2 (OpenCV) image
105 |         mode: int. Choice of transformation to apply to the image
def data_augmentation(image, mode):
    r"""Performs data augmentation of the input image

    Args:
        image: a cv2 (OpenCV) image
        mode: int. Choice of transformation to apply to the image
            0 - no transformation
            1 - flip up and down
            2 - rotate counterclockwise 90 degrees
            3 - rotate 90 degrees and flip up and down
            4 - rotate 180 degrees
            5 - rotate 180 degrees and flip
            6 - rotate 270 degrees
            7 - rotate 270 degrees and flip
    """
    out = np.transpose(image, (1, 2, 0))
    if mode == 0:
        # original
        out = out
    elif mode == 1:
        # flip up and down
        out = np.flipud(out)
    elif mode == 2:
        # rotate counterclockwise 90 degrees
        out = np.rot90(out)
    elif mode == 3:
        # rotate 90 degrees and flip up and down
        out = np.rot90(out)
        out = np.flipud(out)
    elif mode == 4:
        # rotate 180 degrees
        out = np.rot90(out, k=2)
    elif mode == 5:
        # rotate 180 degrees and flip
        out = np.rot90(out, k=2)
        out = np.flipud(out)
    elif mode == 6:
        # rotate 270 degrees
        out = np.rot90(out, k=3)
    elif mode == 7:
        # rotate 270 degrees and flip
        out = np.rot90(out, k=3)
        out = np.flipud(out)
    else:
        raise Exception('Invalid choice of image transformation')
    return np.transpose(out, (2, 0, 1))


def variable_to_cv2_image(varim):
    r"""Converts a torch.autograd.Variable to an OpenCV image

    Args:
        varim: a torch.autograd.Variable
    """
    nchannels = varim.size()[1]
    if nchannels == 1:
        res = (varim.data.cpu().numpy()[0, 0, :] * 255.).clip(0, 255).astype(np.uint8)
    elif nchannels == 3:
        res = varim.data.cpu().numpy()[0]
        res = cv2.cvtColor(res.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
        res = (res * 255.).clip(0, 255).astype(np.uint8)
    else:
        raise Exception('Number of color channels not supported')
    return res
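# Usage sketch: pick one of the eight flip/rotate modes at random for a CHW patch:
#   patch = data_augmentation(patch, mode=np.random.randint(0, 8))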
def get_git_revision_short_hash():
    r"""Returns the current Git commit.
    """
    return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).strip()


def init_logger(argdict):
    r"""Initializes a logging.Logger to save all the running parameters to a
    log file

    Args:
        argdict: dictionary of parameters to be logged
    """
    from os.path import join

    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    fh = logging.FileHandler(join(argdict.log_dir, 'log.txt'), mode='a')
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    try:
        logger.info("Commit: {}".format(get_git_revision_short_hash()))
    except Exception as e:
        logger.error("Couldn't get commit number: {}".format(e))
    logger.info("Arguments: ")
    for k in argdict.__dict__:
        logger.info("\t{}: {}".format(k, argdict.__dict__[k]))

    return logger


def init_logger_ipol():
    r"""Initializes a logging.Logger that writes the results of testing
    a model to 'out.txt'
    """
    logger = logging.getLogger('testlog')
    logger.setLevel(level=logging.INFO)
    fh = logging.FileHandler('out.txt', mode='w')
    formatter = logging.Formatter('%(message)s')
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    return logger


def init_logger_test(result_dir):
    r"""Initializes a logging.Logger in order to log the results after testing
    a model

    Args:
        result_dir: path to the folder with the denoising results
    """
    from os.path import join

    logger = logging.getLogger('testlog')
    logger.setLevel(level=logging.INFO)
    fh = logging.FileHandler(join(result_dir, 'log.txt'), mode='a')
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    return logger


def normalize(data):
    r"""Normalizes a uint8 image to a float32 image in the range [0, 1]

    Args:
        data: a uint8 numpy array to normalize from [0, 255] to [0, 1]
    """
    return np.float32(data / 255.)
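# Usage sketch (mirrors the --no_orthog / --save_every options in train.py):
#   if training_params['step'] % args.save_every == 0 and not args.no_orthog:
#       model.apply(svd_orthogonalization)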
def svd_orthogonalization(lyr):
    r"""Applies regularization to the training by performing the
    orthogonalization technique described in the paper "FFDNet: Toward a fast
    and flexible solution for CNN based image denoising." Zhang et al. (2017).
    For each Conv layer in the model, the method replaces the matrix whose columns
    are the filters of the layer by new filters which are orthogonal to each other.
    This is achieved by setting the singular values of a SVD decomposition to 1.

    This function is to be called by the torch.nn.Module.apply() method,
    which applies svd_orthogonalization() to every layer of the model.
    """
    classname = lyr.__class__.__name__
    if classname.find('Conv') != -1:
        weights = lyr.weight.data.clone()
        c_out, c_in, f1, f2 = weights.size()
        dtype = lyr.weight.data.type()

        # Reshape filters to columns
        # From (c_out, c_in, f1, f2) to (f1*f2*c_in, c_out)
        weights = weights.permute(2, 3, 1, 0).contiguous().view(f1 * f2 * c_in, c_out)

        # Convert filter matrix to numpy array
        weights = weights.cpu().numpy()

        # SVD decomposition and orthogonalization
        mat_u, _, mat_vh = np.linalg.svd(weights, full_matrices=False)
        weights = np.dot(mat_u, mat_vh)

        # As full_matrices=False we don't need to set s[:] = 1 and do mat_u*s
        lyr.weight.data = torch.Tensor(weights).view(f1, f2, c_in, c_out). \
            permute(3, 2, 0, 1).type(dtype)
    else:
        pass


def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary

    Args:
        state_dict: a torch.nn.DataParallel state dictionary
    """
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, vl in state_dict.items():
        name = k[7:]  # remove 'module.' of DataParallel
        new_state_dict[name] = vl

    return new_state_dict


def is_rgb(im_path):
    r"""Returns True if the image in im_path is an RGB image
    """
    from skimage.io import imread
    rgb = False
    im = imread(im_path)
    if len(im.shape) == 3:
        if not (np.allclose(im[..., 0], im[..., 1]) and np.allclose(im[..., 2], im[..., 1])):
            rgb = True
    print("rgb: {}".format(rgb))
    print("im shape: {}".format(im.shape))
    return rgb


def add_noise(img_train, case, noiseIntL):
    r"""Adds synthetic column stripe noise to a batch of clean images.
    Only case 3 (a third-order polynomial stripe model) is implemented.
    """
    noise_S = torch.zeros(img_train.size())
    if case == 3:
        beta1 = np.random.uniform(noiseIntL[0], noiseIntL[1], size=noise_S.size()[0])
        beta2 = np.random.uniform(noiseIntL[0], noiseIntL[1], size=noise_S.size()[0])
        beta3 = np.random.uniform(noiseIntL[0], noiseIntL[1], size=noise_S.size()[0])
        beta4 = np.random.uniform(noiseIntL[0], noiseIntL[1], size=noise_S.size()[0])

        for m in range(noise_S.size()[0]):
            sizeN_S = noise_S[0, 0, :, :].size()
            A1 = np.random.normal(0, beta1[m], sizeN_S[1])  # a single row vector
            A2 = np.random.normal(0, beta2[m], sizeN_S[1])  # a single row vector
            A3 = np.random.normal(0, beta3[m], sizeN_S[1])  # a single row vector
            A4 = np.random.normal(0, beta4[m], sizeN_S[1])  # a single row vector
            # tile each row to the full image height (column-wise stripes)
            A1 = np.tile(A1, (sizeN_S[0], 1))
            A2 = np.tile(A2, (sizeN_S[0], 1))
            A3 = np.tile(A3, (sizeN_S[0], 1))
            A4 = np.tile(A4, (sizeN_S[0], 1))
            # add dim
            A1 = np.expand_dims(A1, 0)
            A2 = np.expand_dims(A2, 0)
            A3 = np.expand_dims(A3, 0)
            A4 = np.expand_dims(A4, 0)
            # to tensor
            A1 = torch.from_numpy(A1)
            A2 = torch.from_numpy(A2)
            A3 = torch.from_numpy(A3)
            A4 = torch.from_numpy(A4)
            imgn_train_m = A1 + A2 * img_train[m] + A3 * A3 * img_train[m] + A4 * A4 * A4 * img_train[m] + \
                           img_train[m]
            imgn_train_m_c = torch.clip(imgn_train_m, 0., 1.)
            noise_S[m, :, :, :] = imgn_train_m_c
        imgn_train = noise_S
    else:
        raise NotImplementedError('add_noise: only case 3 is implemented')
    return imgn_train
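# Usage sketch: corrupt a clean batch with the polynomial stripe model; the
# per-image stripe std is drawn from the args.noiseIntL interval:
#   imgn_train = add_noise(img_train, case=3, noiseIntL=[0.05, 0.15])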
def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / torch.sum(gauss)  # normalize so the kernel sums to 1


# x = gaussian(3, 1.5)
# # print(x)
# x = x.unsqueeze(1)
# print(x.shape)  # torch.Size([3, 1])
# print(x.t().unsqueeze(0).unsqueeze(0).shape)  # torch.Size([1, 1, 1, 3])


def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)  # (window_size, 1)
    # mm: matrix product with the transpose -> (1, 1, window_size, window_size)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # expand broadcasts the kernel across channels:
    # (1, 1, window_size, window_size) -> (channel, 1, window_size, window_size)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


# SSIM as a loss module for network training, or for plain SSIM evaluation
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
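# Usage sketch: SSIM is differentiable, so it can serve directly as a training
# criterion (e.g. loss = 1 - SSIM()(out_train, img_train)) or as a plain metric.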
def findLastCheckpoint(save_dir):
    file_list = glob.glob(os.path.join(save_dir, '*epoch*.pth'))
    if file_list:
        epochs_exist = []
        for file_ in file_list:
            result = re.findall(".*epoch(.*).pth.*", file_)
            epochs_exist.append(int(result[0]))
        initial_epoch = max(epochs_exist)
    else:
        initial_epoch = 0
    return initial_epoch


def batch_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = 0
    for i in range(Img.shape[0]):
        PSNR += compare_psnr(Iclean[i, :, :, :], Img[i, :, :, :], data_range=data_range)
    return PSNR / Img.shape[0]


def is_image(img_name):
    return img_name.endswith((".jpg", ".bmp", ".png"))


def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)
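# Resume sketch (hypothetical wiring; findLastCheckpoint only returns the epoch
# number, and the checkpoint filename pattern below is an assumption):
#   initial_epoch = findLastCheckpoint(args.log_dir)
#   if initial_epoch > 0:
#       model.load_state_dict(torch.load(os.path.join(args.log_dir, 'net_epoch%d.pth' % initial_epoch)))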
class ImagePool():
    """This class implements an image buffer that stores previously generated images.

    This buffer enables us to update discriminators using a history of generated images
    rather than the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """Initialize the ImagePool class

        Parameters:
            pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return an image from the pool.

        Parameters:
            images: the latest generated images from the generator

        Returns images from the buffer.

        With 50% probability, the buffer returns the input images.
        With 50% probability, the buffer returns images previously stored in the
        buffer, and inserts the current images into the buffer.
        """
        if self.pool_size == 0:  # if the buffer size is 0, do nothing
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:  # if the buffer is not full, keep inserting current images into the buffer
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # by 50% chance, return a previously stored image and insert the current image into the buffer
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:  # by another 50% chance, return the current image
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)  # collect all the images and return
        return return_images
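# Usage sketch (a pool size of 50 is the common pix2pix/CycleGAN choice):
#   fake_pool = ImagePool(50)
#   fake_for_D = fake_pool.query(fake_images)  # mix of fresh and buffered fakes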
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) - - label for a real image
            target_fake_label (float) - - label for a fake image

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) - - typically the prediction from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss


class TVloss(nn.Module):
    def __init__(self, TVloss_weight=1):
        super(TVloss, self).__init__()
        self.TVloss_weight = TVloss_weight

    def forward(self, x, y):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        # Shifting the image by one pixel and subtracting yields the horizontal
        # and vertical gradient maps: e.g. x[:, :, 1:, :] - x[:, :, :h_x - 1, :]
        # is the difference between every pixel and its vertical neighbour.
        # The loss then matches the gradients of x against those of y, rather
        # than penalizing the gradients of x alone as plain TV would.
        w_tv_x = (x[:, :, :, 1:] - x[:, :, :, :w_x - 1])
        w_tv_y = (y[:, :, :, 1:] - y[:, :, :, :w_x - 1])
        h_tv_x = (x[:, :, 1:, :] - x[:, :, :h_x - 1, :])
        h_tv_y = (y[:, :, 1:, :] - y[:, :, :h_x - 1, :])
        MSE = torch.nn.MSELoss()
        tv_loss = (MSE(h_tv_x, h_tv_y) + MSE(w_tv_x, w_tv_y)) * 0.5
        return self.TVloss_weight * tv_loss

    def _tensor_size(self, t):
        return t.size()[1] * t.size()[2] * t.size()[3]


def WRRGM(A, B):
    # Three-level Haar DWT; zeroing the high-frequency coefficients before the
    # inverse transform keeps only the low-frequency (approximation) content.
    DWT = DWTForward(J=3, wave='haar').cuda()
    IDWT = DWTInverse(wave='haar').cuda()
    DMT3_yl, DMT3_yh = DWT(A)
    for tensor in DMT3_yh:
        tensor.zero_()
    out1 = IDWT((DMT3_yl, DMT3_yh))

    DMT3_yl, DMT3_yh = DWT(B)
    for tensor in DMT3_yh:
        tensor.zero_()
    out2 = IDWT((DMT3_yl, DMT3_yh))

    return out1, out2
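# Usage sketch: compare only the low-frequency (stripe-free) content of the
# prediction and target; pairing with an L1 term is an assumption here:
#   low_pred, low_gt = WRRGM(out_train, img_train)  # needs pytorch_wavelets + CUDA
#   loss_low = F.l1_loss(low_pred, low_gt)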
class MS_SSIM_L1_LOSS(nn.Module):
    """Mixed MS-SSIM + L1 loss.
    CUDA is required, otherwise the speed is too slow (the Gaussian masks live
    on the GPU). Pay attention to the channel count and shape of the inputs:
    set channel=1 for grayscale images and channel=3 for RGB, and match
    data_range to the value range of the images (1.0 for inputs in [0, 1]).
    """

    def __init__(self, gaussian_sigmas=[0.5, 1.0, 2.0, 4.0, 8.0],
                 data_range=1.0,
                 K=(0.01, 0.03),  # c1, c2
                 alpha=0.025,  # weight between the MS-SSIM and L1 terms
                 compensation=1.0,  # final factor for total loss
                 cuda_dev=0,  # cuda device choice
                 channel=3):  # 3 for RGB images, 1 for grayscale images
        super(MS_SSIM_L1_LOSS, self).__init__()
        self.channel = channel
        self.DR = data_range
        self.C1 = (K[0] * data_range) ** 2
        self.C2 = (K[1] * data_range) ** 2
        self.pad = int(2 * gaussian_sigmas[-1])
        self.alpha = alpha
        self.compensation = compensation
        filter_size = int(4 * gaussian_sigmas[-1] + 1)
        # one Gaussian mask per (channel, sigma) pair: (channel * 5, 1, 33, 33) by default
        g_masks = torch.zeros(
            (self.channel * len(gaussian_sigmas), 1, filter_size, filter_size))
        for idx, sigma in enumerate(gaussian_sigmas):
            if self.channel == 1:
                # only gray layer
                g_masks[idx, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
            elif self.channel == 3:
                # r0,g0,b0,r1,g1,b1,...,rM,gM,bM: each group of masks uses a different sigma
                g_masks[self.channel * idx + 0, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
                g_masks[self.channel * idx + 1, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
                g_masks[self.channel * idx + 2, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
            else:
                raise ValueError
        self.g_masks = g_masks.cuda(cuda_dev)  # move the masks to the chosen CUDA device

    def _fspecial_gauss_1d(self, size, sigma):
        """Create 1-D gauss kernel
        Args:
            size (int): the size of gauss kernel
            sigma (float): sigma of normal distribution

        Returns:
            torch.Tensor: 1D kernel (size)
        """
        coords = torch.arange(size).to(dtype=torch.float)
        coords -= size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()
        return g.reshape(-1)

    def _fspecial_gauss_2d(self, size, sigma):
        """Create 2-D gauss kernel
        Args:
            size (int): the size of gauss kernel
            sigma (float): sigma of normal distribution

        Returns:
            torch.Tensor: 2D kernel (size x size)
        """
        gaussian_vec = self._fspecial_gauss_1d(size, sigma)
        # torch.outer: outer product of two vectors, (n,) x (m,) -> (n, m)
        return torch.outer(gaussian_vec, gaussian_vec)
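    # Note: with the default sigmas, filter_size = int(4 * 8.0 + 1) = 33 and
    # pad = int(2 * 8.0) = 16, so the convolutions in forward() preserve H x W.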
    def forward(self, x, y):
        b, c, h, w = x.shape
        assert c == self.channel

        # e.g. a 96x96 input convolved with a 33x33 kernel gives 64x64; pad=16 restores 96x96
        mux = F.conv2d(x, self.g_masks, groups=c, padding=self.pad)
        muy = F.conv2d(y, self.g_masks, groups=c, padding=self.pad)  # grouped convolution keeps channels separate

        mux2 = mux * mux
        muy2 = muy * muy
        muxy = mux * muy

        sigmax2 = F.conv2d(x * x, self.g_masks, groups=c, padding=self.pad) - mux2
        sigmay2 = F.conv2d(y * y, self.g_masks, groups=c, padding=self.pad) - muy2
        sigmaxy = F.conv2d(x * y, self.g_masks, groups=c, padding=self.pad) - muxy

        # l(j), cs(j) in MS-SSIM
        l = (2 * muxy + self.C1) / (mux2 + muy2 + self.C1)  # [B, 15, H, W]
        cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2)
        if self.channel == 3:
            lM = l[:, -1, :, :] * l[:, -2, :, :] * l[:, -3, :, :]  # luminance term from the coarsest scale
            PIcs = cs.prod(dim=1)
        elif self.channel == 1:
            lM = l[:, -1, :, :]
            PIcs = cs.prod(dim=1)

        loss_ms_ssim = 1 - lM * PIcs  # [B, H, W]

        loss_l1 = F.l1_loss(x, y, reduction='none')  # [B, C, H, W]
        # average the L1 loss over channels, weighted by the coarsest Gaussian mask
        gaussian_l1 = F.conv2d(loss_l1, self.g_masks.narrow(dim=0, start=-self.channel, length=self.channel),
                               groups=c, padding=self.pad).mean(1)  # [B, H, W]

        loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR
        loss_mix = self.compensation * loss_mix

        return loss_mix.mean()
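# Usage sketch for grayscale training as in this repo (NCHW inputs in [0, 1], on CUDA):
#   criterion = MS_SSIM_L1_LOSS(channel=1, data_range=1.0)
#   loss = criterion(out_train, img_train)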
--------------------------------------------------------------------------------
/warmup_scheduler.py:
--------------------------------------------------------------------------------
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warms up (increases) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    The optimizer carries a base learning rate (base_lr).
    If multiplier > 1, warmup ramps the lr from base_lr to multiplier * base_lr
    over total_epoch epochs, then hands over to the regular scheduler.
    If multiplier == 1.0, warmup ramps the lr from 0 to base_lr over total_epoch
    epochs, then hands over to the regular scheduler.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base_lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler and (not self.finished):
                self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                self.finished = True
            # key step: once warmup is over, return the new base lrs directly
            return [base_lr for base_lr in self.after_scheduler.base_lrs]
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        print('warming up...')
        if self.last_epoch <= self.total_epoch:
            warmup_lr = None
            if self.multiplier == 1.0:
                warmup_lr = [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
            else:
                warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
--------------------------------------------------------------------------------