├── data
│   ├── trainA
│   │   └── Readme.txt
│   ├── trainB
│   │   └── Readme.txt
│   ├── valA
│   │   └── Readme.txt
│   ├── valB
│   │   └── Readme.txt
│   └── addnoise_clip_Gauss.py
├── Fig
│   ├── image0.png
│   ├── image1.png
│   └── image2.png
├── test.py
├── SSIM.py
├── README.md
├── warmup_scheduler.py
├── saver.py
├── train.py
├── dataset.py
├── networks.py
├── options.py
├── model_singalG.py
├── models
│   └── MWUNet.py
└── utils.py

/data/trainA/Readme.txt:
--------------------------------------------------------------------------------
1 | Stripe Domain Training Set
--------------------------------------------------------------------------------
/data/trainB/Readme.txt:
--------------------------------------------------------------------------------
1 | Clean Domain Training Set
--------------------------------------------------------------------------------
/data/valA/Readme.txt:
--------------------------------------------------------------------------------
1 | Stripe Domain Validation Set
--------------------------------------------------------------------------------
/data/valB/Readme.txt:
--------------------------------------------------------------------------------
1 | Clean Domain Validation Set
--------------------------------------------------------------------------------
/Fig/image0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/DestripeCycleGAN/HEAD/Fig/image0.png
--------------------------------------------------------------------------------
/Fig/image1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/DestripeCycleGAN/HEAD/Fig/image1.png
--------------------------------------------------------------------------------
/Fig/image2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xdFai/DestripeCycleGAN/HEAD/Fig/image2.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from options import TestOptions
3 | from dataset import dataset_single_test
4 | from model_singalG import DerainCycleGAN
5 | 
6 | import os
7 | import time
8 | import torchvision
9 | import numpy as np
10 | os.environ['CUDA_VISIBLE_DEVICES'] = '2'
11 | 
12 | def normalization(data):
13 |     _range = np.max(data) - np.min(data)
14 |     return (data - np.min(data)) / _range
15 | 
16 | def main():
17 |     # parse options
18 |     parser = TestOptions()
19 |     opts = parser.parse()
20 | 
21 |     # data loader
22 |     print('\n--- load dataset ---')
23 |     dataset = dataset_single_test(opts, opts.input_dim_a)
24 |     loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=opts.nThreads)
25 | 
26 |     # model
27 |     print('\n--- load model ---')
28 |     model = DerainCycleGAN(opts)
29 |     model.setgpu(opts.gpu)
30 |     model.resume(opts.resume, train=False)
31 |     model.eval()
32 | 
33 |     # directory
34 |     result_dir = os.path.join(opts.result_dir, opts.name)
35 |     if not os.path.exists(result_dir):
36 |         os.mkdir(result_dir)
37 | 
38 |     # test
39 |     print('\n--- testing ---')
40 |     time_test = 0
41 |     count = 0
42 |     for idx1, (img1, needcrop, img_names) in enumerate(loader):
43 |         print('{}/{}'.format(idx1, len(loader)))
44 |         img1 = img1.cuda()
45 |         imgs = []
46 |         masks = []
47 |         names = []
48 |         start_time = time.time()
49 |         # for idx2 in range(1):
50 |         with torch.no_grad():
51 |             # img, mask = model.test_forward(img1, a2b=opts.a2b)
52 | 
img = model.test_forward(img1, a2b=opts.a2b) 53 | 54 | img = torch.clamp(img, -1., 1.) 55 | img = (img + 1) / 2 56 | end_time = time.time() 57 | dur_time = end_time - start_time 58 | time_test += dur_time 59 | print(img_names[idx1], ': ', dur_time) 60 | imgs.append(img) 61 | torchvision.utils.save_image(img, os.path.join(result_dir, (img_names[idx1][0] + '.png')), nrow=1) 62 | count += 1 63 | print('Avg. time:%.4f' % (time_test / count)) 64 | 65 | return 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /SSIM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DestripeCycleGAN: Stripe 
Simulation CycleGAN for Unsupervised Infrared Image Destriping
2 | 
3 | Shiqi Yang, Hanlin Qin, Shuai Yuan, Xiang Yan, Hossein Rahmani, IEEE Transactions on Instrumentation and Measurement, 2024 [[Paper]](https://ieeexplore.ieee.org/document/10711892) [[Weight]](https://drive.google.com/file/d/1VhCR8dTqmqpQSaA_f4GokRzaCCXq5cG8/view?usp=sharing)
4 | 
5 | 
6 | # Challenges and Inspiration
7 | ![Image text](https://github.com/xdFai/DestripeCycleGAN/blob/main/Fig/image0.png)
8 | 
9 | 
10 | # Structure
11 | ![Image text](https://github.com/xdFai/DestripeCycleGAN/blob/main/Fig/image1.png)
12 | 
13 | ![Image text](https://github.com/xdFai/DestripeCycleGAN/blob/main/Fig/image2.png)
14 | 
15 | 
16 | # Introduction
17 | DestripeCycleGAN: Stripe Simulation CycleGAN for Unsupervised Infrared Image Destriping. Shiqi Yang, Hanlin Qin, Shuai Yuan, Xiang Yan
18 | 
19 | 
20 | The main contributions of this paper are as follows:
21 | 1. An efficient deep unsupervised DestripeCycleGAN is proposed for infrared image destriping. We incorporate a stripe generation model (SGM) into the framework to balance the semantic information between the degraded and clean domains.
22 | 
23 | 2. The Haar Wavelet Background Guidance Module (HBGM) is designed to mitigate the impact of vertical stripes and accurately assess the consistency of background details. As a plug-and-play image-constraint module, it offers a powerful unsupervised restriction for DestripeCycleGAN.
24 | 
25 | 3. We design a multi-level wavelet U-Net (MWUNet) that leverages Haar wavelet sampling to minimize feature loss. The network effectively integrates multi-scale features and strengthens long-range dependencies by using group fusion blocks (GFB) in the skip connections.
26 | 
27 | 
28 | ## Citation
29 | 
30 | If you find the code useful, please consider citing our paper using the following BibTeX entry.
31 | 
32 | ```
33 | @ARTICLE{10711892,
34 |   author={Yang, Shiqi and Qin, Hanlin and Yuan, Shuai and Yan, Xiang and Rahmani, Hossein},
35 |   journal={IEEE Transactions on Instrumentation and Measurement}, 
36 |   title={DestripeCycleGAN: Stripe Simulation CycleGAN for Unsupervised Infrared Image Destriping}, 
37 |   year={2024},
38 |   volume={73},
39 |   number={},
40 |   pages={1-14},
41 |   keywords={Noise;Wavelet transforms;Generators;Semantics;Noise reduction;Image restoration;Computational modeling;Adaptation models;Wavelet domain;Image reconstruction;Convolutional neural network (CNN);CycleGAN;infrared image destriping;stripe prior modeling;unsupervised learning},
42 |   doi={10.1109/TIM.2024.3476560}}
43 | 
44 | ```
45 | 
46 | ## Usage
47 | 
48 | 
49 | #### 1. Data
50 | * **Our project has the directory structure shown in the tree at the top of this listing.**
51 | 
52 | #### 2. Train.
53 | ```bash
54 | python train.py
55 | ```
56 | 
57 | #### 3. Test and demo. [[Weight]](https://drive.google.com/file/d/1VhCR8dTqmqpQSaA_f4GokRzaCCXq5cG8/view?usp=sharing)
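Download the pretrained weight above, point `--resume` at the checkpoint file, and set `--test_path` to the directory holding your test images (note that `dataset_single_test` currently hard-codes the sub-folder `00042_set12_0.1`). Pass `--a2b 1` so that `test_forward` actually runs the destriping generator; with the default of 0 it returns nothing. A typical invocation, with placeholder paths, might look like:
```bash
python test.py --resume ./weights/DestripeCycleGAN.pth --test_path ./testsets --name demo --a2b 1
```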
58 | ```bash
59 | python test.py
60 | ```
61 | 
62 | ## Contact
63 | **Feel free to raise issues or email [22191214967@stu.xidian.edu.cn](mailto:22191214967@stu.xidian.edu.cn) or [yuansy@stu.xidian.edu.cn](mailto:yuansy@stu.xidian.edu.cn) with any questions regarding our DestripeCycleGAN.**
64 | 
--------------------------------------------------------------------------------
/warmup_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 | 
4 | class GradualWarmupScheduler(_LRScheduler):
5 |     """ Gradually warm up (increase) the learning rate in the optimizer.
6 |     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
7 |     The optimizer is created with a base learning rate (base_lr).
8 |     When multiplier > 1, warm-up raises the learning rate from base_lr to multiplier * base_lr over total_epoch epochs, after which the regular scheduler takes over.
9 |     When multiplier == 1.0, warm-up raises the learning rate from 0 to base_lr over total_epoch epochs, after which the regular scheduler takes over.
10 |     Args:
11 |         optimizer (Optimizer): Wrapped optimizer.
12 |         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
13 |         total_epoch: target learning rate is reached at total_epoch, gradually
14 |         after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
15 |     """
16 | 
17 |     def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
18 |         self.multiplier = multiplier
19 |         if self.multiplier < 1.:
20 |             raise ValueError('multiplier should be greater than or equal to 1.')
21 |         self.total_epoch = total_epoch
22 |         self.after_scheduler = after_scheduler
23 |         self.finished = False
24 |         super(GradualWarmupScheduler, self).__init__(optimizer)
25 | 
26 |     def get_lr(self):
27 |         if self.last_epoch > self.total_epoch:
28 |             if self.after_scheduler and (not self.finished):
29 |                 self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 |                 self.finished = True
31 |             # !this step is critical: the new base lrs must be returned directly
32 |             return [base_lr for base_lr in self.after_scheduler.base_lrs]
33 |         if self.multiplier == 1.0:
34 |             return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
35 |         else:
36 |             return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
37 | 
38 |     def step_ReduceLROnPlateau(self, metrics, epoch=None):
39 |         if epoch is None:
40 |             epoch = self.last_epoch + 1
41 |         self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
42 |         print('warming up...')
43 |         if self.last_epoch <= self.total_epoch:
44 |             warmup_lr = None
45 |             if self.multiplier == 1.0:
46 |                 warmup_lr = [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
47 |             else:
48 |                 warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
49 |             for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
50 |                 param_group['lr'] = lr
51 |         else:
52 |             if epoch is None:
53 |                 self.after_scheduler.step(metrics, None)
54 |             else:
55 |                 self.after_scheduler.step(metrics, epoch - self.total_epoch)
56 | 
57 |     def step(self, epoch=None, metrics=None):
58 |         if type(self.after_scheduler) != ReduceLROnPlateau:
59 |             if self.finished and self.after_scheduler:
60 |                 if epoch is None:
61 |                     self.after_scheduler.step(None)
62 |                 else:
63 |                     self.after_scheduler.step(epoch - self.total_epoch)
64 |                 self._last_lr = self.after_scheduler.get_last_lr()
65 |             else:
66 |                 return super(GradualWarmupScheduler, self).step(epoch)
67 |         else:
68 |             self.step_ReduceLROnPlateau(metrics, epoch)
69 | 
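
# --- Editor's usage sketch (hypothetical, not part of the original file) ---
# Mirrors how networks.get_scheduler wires this class up for the 'warmup'
# policy: linear warm-up for 6 epochs, then cosine annealing.
if __name__ == '__main__':
    import torch
    net = torch.nn.Linear(8, 8)
    optim = torch.optim.Adam(net.parameters(), lr=1e-4)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=144, eta_min=1e-6)
    sched = GradualWarmupScheduler(optim, multiplier=1, total_epoch=6, after_scheduler=cosine)
    for epoch in range(10):
        sched.step()
        print(epoch, optim.param_groups[0]['lr'])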
--------------------------------------------------------------------------------
/data/addnoise_clip_Gauss.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | 
5 | import os
6 | import argparse
7 | 
8 | # argument parsing
9 | parser = argparse.ArgumentParser(description='add stripe noise')
10 | parser.add_argument('--cleanfilename', type=str, default=r'/path/to/your/clean/data', help="path of the clean images")
11 | parser.add_argument('--generatefilename', type=str, default=r'/path/to/your/clean/result',
12 |                     help="path where the noisy images are written")
13 | opt = parser.parse_args()
14 | 
15 | if not os.path.isdir(opt.generatefilename):
16 |     os.makedirs(opt.generatefilename)
17 | 
18 | clist = os.listdir(opt.cleanfilename)
19 | clist.sort()
20 | 
21 | 
22 | # stripe-noise magnitude range
23 | 
24 | # noiseB_S = [0.01, 0.05]
25 | # noiseB_S = [0.01, 0.1]
26 | noiseB_S = [0.1, 0.1]
27 | # noiseB_S = [0.10, 0.15]
28 | # noiseB_S = [0.0, 0.1]
29 | 
30 | case = 3
31 | 
32 | 
33 | for i in clist:
34 |     path = os.path.join(opt.cleanfilename, i)
35 |     image = cv2.imread(path)
36 |     img = image[:, :, 0]
37 |     # img = np.expand_dims(a, axis=0)
38 |     img = np.float32(img / 255.)
39 |     img = torch.Tensor(img)
40 |     # img_val = torch.unsqueeze(img, 0)
41 |     noise_S = torch.zeros(img.size())
42 |     sizeN_S = noise_S.size()
43 |     if case == 0:
44 |         # add stripe noise
45 |         # randomly pick the maximum amplitude of the distribution
46 |         beta = np.random.uniform(noiseB_S[0], noiseB_S[1])
47 |         noise_col = np.random.uniform(-beta, beta, sizeN_S[1])
48 |         S_noise = np.tile(noise_col, (sizeN_S[0], 1))
49 |         S_noise = torch.from_numpy(S_noise)
50 |         imgn_val = S_noise + img
51 | 
52 |     elif case == 1:
53 |         beta1 = np.random.uniform(noiseB_S[0], noiseB_S[1])  # a scalar
54 |         beta2 = np.random.uniform(noiseB_S[0], noiseB_S[1])  # a scalar
55 | 
56 |         A1 = np.random.uniform(-beta1, beta1, sizeN_S[1])  # a row vector
57 |         A2 = np.random.uniform(-beta2, beta2, sizeN_S[1])  # a row vector
58 | 
59 |         A1 = np.tile(A1, (sizeN_S[0], 1))
60 |         A2 = np.tile(A2, (sizeN_S[0], 1))
61 |         #
62 |         A1 = torch.from_numpy(A1)
63 |         A2 = torch.from_numpy(A2)
64 |         imgn_val = A1 + A2 * img + img
65 | 
66 |     elif case == 2:
67 |         beta1 = np.random.uniform(noiseB_S[0], noiseB_S[1])  # a scalar
68 |         beta2 = np.random.uniform(noiseB_S[0], noiseB_S[1])  # a scalar
69 |         beta3 = np.random.uniform(noiseB_S[0], noiseB_S[1])  # a scalar
70 | 
71 |         A1 = np.random.uniform(-beta1, beta1, sizeN_S[1])  # a row vector
72 |         A2 = np.random.uniform(-beta2, beta2, sizeN_S[1])  # a row vector
73 |         A3 = np.random.uniform(-beta3, beta3, sizeN_S[1])  # a row vector
74 |         # tile the row vectors over all rows
75 |         A1 = np.tile(A1, (sizeN_S[0], 1))
76 |         A2 = np.tile(A2, (sizeN_S[0], 1))
77 |         A3 = np.tile(A3, (sizeN_S[0], 1))
78 |         #
79 |         A1 = torch.from_numpy(A1)
80 |         A2 = torch.from_numpy(A2)
81 |         A3 = torch.from_numpy(A3)
82 |         imgn_val = A1 + A2 * img + A3 * A3 * img + img
83 | 
84 |     elif case == 3:
85 |         beta1 = np.random.uniform(noiseB_S[0], noiseB_S[1])  # a scalar
86 |         beta2 = np.random.uniform(noiseB_S[0], noiseB_S[1])  # a scalar
87 |         beta3 = np.random.uniform(noiseB_S[0], noiseB_S[1])  # a scalar
88 |         beta4 = np.random.uniform(noiseB_S[0], noiseB_S[1])  # a scalar
89 | 
90 |         A1 = np.random.normal(-beta1, beta1, sizeN_S[1])  # a row vector (Gaussian)
91 |         A2 = np.random.normal(-beta2, beta2, sizeN_S[1])  # a row vector (Gaussian)
92 |         A3 = np.random.normal(-beta3, beta3, sizeN_S[1])  # a row vector (Gaussian)
93 |         A4 = np.random.normal(-beta4, beta4, sizeN_S[1])  # a row vector (Gaussian)
94 |         # tile the row vectors over all rows
95 |         A1 = np.tile(A1, (sizeN_S[0], 1))
96 |         A2 = np.tile(A2, (sizeN_S[0], 1))
97 |         A3 = np.tile(A3, (sizeN_S[0], 1))
98 |         A4 = np.tile(A4, (sizeN_S[0], 1))
99 |         #
100 |         A1 = torch.from_numpy(A1)
101 |         A2 = torch.from_numpy(A2)
102 |         A3 = torch.from_numpy(A3)
103 |         A4 = torch.from_numpy(A4)
104 |         imgn_val = A1 + A2 * img + A3 * A3 * img + A4 * A4 * A4 * img + img
105 | 
106 | 
107 | 
108 |     noise_img = imgn_val.numpy()
109 | 
110 |     noise_img_f = noise_img * 255
111 |     noise_img_f = np.clip(noise_img_f, 0, 255)
112 | 
113 |     cv2.imwrite(os.path.join(opt.generatefilename, i), noise_img_f.astype("uint8"))
114 | 
115 | 
116 |     print(i)
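
# --- Editor's note (placeholder paths): typical use is to synthesize the striped
# domain A from the clean domain B, e.g.
#   python addnoise_clip_Gauss.py --cleanfilename data/trainB --generatefilename data/trainA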
--------------------------------------------------------------------------------
/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torchvision
3 | from tensorboardX import SummaryWriter
4 | import numpy as np
5 | from PIL import Image
6 | import torchvision.transforms.functional as F
7 | import pdb
8 | import cv2
9 | import torch
10 | 
11 | # tensor to PIL Image
12 | def tensor2img(img):
13 |     # clamp added here to suppress out-of-range noise points
14 |     img = torch.clamp(img, -1., 1.)
15 | 
16 |     img = img.squeeze(0)
17 |     img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...]
* 0.114 18 | img = img.unsqueeze(0) 19 | img = img.unsqueeze(0) 20 | img = img[0].cpu().float().numpy() 21 | img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0 22 | return img.astype(np.uint8) 23 | 24 | # save a set of images 25 | def save_imgs(imgs, names, path, needcrop): 26 | if not os.path.exists(path): 27 | os.mkdir(path) 28 | for img, name in zip(imgs, names): 29 | h, w = img.shape[2], img.shape[3] 30 | img = tensor2img(img) 31 | img = Image.fromarray(np.uint8(img)) 32 | img.save(os.path.join(path, name + '.png')) 33 | 34 | def save_masks(imgs, names, path, needcrop): 35 | if not os.path.exists(path): 36 | os.mkdir(path) 37 | for img, name in zip(imgs, names): 38 | h, w = img.shape[2], img.shape[3] 39 | img = tensor2img(img) 40 | img = Image.fromarray(np.uint8(img)) 41 | name += '_mask' 42 | if needcrop == 1: 43 | img = F.crop(img, 0, 0, h-3, w-3) 44 | # print(os.path.join(path, name + '.png')) 45 | img.save(os.path.join(path, name + '.png')) 46 | 47 | class Saver(): 48 | def __init__(self, opts): 49 | self.display_dir = os.path.join(opts.display_dir, opts.name) 50 | self.model_dir = os.path.join(opts.result_dir, opts.name) 51 | self.image_dir = os.path.join(self.model_dir, 'images') 52 | self.dict_dir = os.path.join(self.model_dir, 'dicts') 53 | self.display_freq = opts.display_freq 54 | self.img_save_freq = opts.img_save_freq 55 | self.model_save_freq = opts.model_save_freq 56 | 57 | # make directory 58 | if not os.path.exists(self.display_dir): 59 | os.makedirs(self.display_dir) 60 | if not os.path.exists(self.model_dir): 61 | os.makedirs(self.model_dir) 62 | if not os.path.exists(self.image_dir): 63 | os.makedirs(self.image_dir) 64 | if not os.path.exists(self.dict_dir): 65 | os.makedirs(self.dict_dir) 66 | 67 | # create tensorboard writer 68 | self.writer = SummaryWriter(log_dir=self.display_dir) 69 | 70 | # write losses and images to tensorboard 71 | def write_display(self, total_it, model): 72 | if (total_it + 1) % self.display_freq == 0: 73 | # write loss 74 | members = [attr for attr in dir(model) if not callable(getattr(model, attr)) and not attr.startswith("__") and 'loss' in attr] 75 | for m in members: 76 | self.writer.add_scalar(m, getattr(model, m), total_it) 77 | # write img 78 | # image_dis = torchvision.utils.make_grid(model.image_display, nrow=model.image_display.size(0)//2)/2 + 0.5 79 | # self.writer.add_image('Image', image_dis, total_it) 80 | 81 | # save result images 82 | def write_img(self, ep, model): 83 | if (ep + 1) % self.img_save_freq == 0: 84 | assembled_images1= model.assemble_outputs() 85 | img_filename = '%s/gen_%05d.jpg' % (self.image_dir, ep) 86 | torchvision.utils.save_image(assembled_images1 / 2 + 0.5, img_filename, nrow=1) 87 | elif ep == -1: 88 | assembled_images1, assembled_images2, assembled_images3 = model.assemble_outputs() 89 | img_filename = '%s/gen_%05d.jpg' % (self.image_dir, ep) 90 | torchvision.utils.save_image(assembled_images1 / 2 + 0.5, img_filename, nrow=1) 91 | img_filename = '%s/maska_%05d.jpg' % (self.image_dir, ep) 92 | torchvision.utils.save_image(assembled_images2 / 2 + 0.5, img_filename, nrow=1) 93 | img_filename = '%s/maskb_%05d.jpg' % (self.image_dir, ep) 94 | torchvision.utils.save_image(assembled_images3 / 2 + 0.5, img_filename, nrow=1) 95 | 96 | # save model 97 | def write_model(self, ep, total_it, model): 98 | if (ep + 1) % self.model_save_freq == 0: 99 | print('--- save the model @ ep %d ---' % (ep)) 100 | model.save('%s/%05d.pth' % (self.model_dir, ep), ep, total_it) 101 | elif ep == -1: 102 | 
model.save('%s/last.pth' % self.model_dir, ep, total_it) 103 | 104 | # save dict 105 | def write_dict(self, obj, ep, model): 106 | if (ep + 1) % self.model_save_freq == 0: 107 | dict_filename = '%s/%05d' % (self.dict_dir, ep) 108 | print('--- save the dict @ ep %d ---' % (ep)) 109 | model.save_dict(obj, dict_filename) 110 | elif ep == -1: 111 | dict_filename = '%s/last' % (self.dict_dir) 112 | model.save_dict(dict, dict_filename) 113 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from options import TrainOptions 3 | from dataset import dataset_unpair, dataset_unpair_val 4 | from model_singalG import DerainCycleGAN 5 | from saver import Saver 6 | import os 7 | 8 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 9 | from SSIM import * 10 | from utils import * 11 | import torch.backends.cudnn as cudnn 12 | 13 | cudnn.benchmark = True 14 | cudnn.fastest = True 15 | 16 | def main(): 17 | # parse options 18 | parser = TrainOptions() 19 | opts = parser.parse() 20 | 21 | # data loader 22 | print('\n--- load dataset ---') 23 | dataset = dataset_unpair(opts) 24 | train_loader = torch.utils.data.DataLoader(dataset, batch_size=opts.batch_size, shuffle=True, 25 | num_workers=opts.nThreads) 26 | dataset_val = dataset_unpair_val(opts) 27 | loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=opts.nThreads) 28 | criterion = SSIM() 29 | criterion.cuda(opts.gpu) 30 | if not os.path.exists(os.path.join(opts.result_dir, opts.name)): 31 | os.makedirs(os.path.join(opts.result_dir, opts.name)) 32 | trainLogger = open('%s/psnr&ssim.log' % os.path.join(opts.result_dir, opts.name), 'w') 33 | 34 | # model 35 | print('\n--- load model ---') 36 | model = DerainCycleGAN(opts) 37 | model.setgpu(opts.gpu) 38 | if opts.resume is None: 39 | model.initialize() 40 | ep0 = -1 41 | total_it = 0 42 | else: 43 | ep0, total_it = model.resume(opts.resume) 44 | 45 | model.set_scheduler(opts, last_ep=ep0) 46 | ep0 += 1 47 | print('start the training at epoch %d' % (ep0)) 48 | 49 | # saver for display and output 50 | saver = Saver(opts) 51 | 52 | # train 53 | print('\n--- train ---') 54 | for ep in range(ep0, opts.n_ep): 55 | 56 | ssim_sum = 0 57 | ssim_avg = 0 58 | psnr_sum = 0 59 | psnr_avg = 0 60 | 61 | # for it, (images_a, images_b, images_c) in enumerate(train_loader): 62 | for it, (images_a, images_b) in enumerate(train_loader): 63 | if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size: 64 | continue 65 | 66 | model.train() 67 | # input data 68 | images_a = images_a.cuda(opts.gpu).detach() 69 | images_b = images_b.cuda(opts.gpu).detach() 70 | 71 | # operation Structure outside the discriminator 72 | model.update_EG(images_a, images_b, ep, opts) 73 | 74 | # save to display file 75 | if not opts.no_display_img: 76 | saver.write_display(total_it, model) 77 | 78 | print( 79 | 'total_it: %d (ep %d, it %d), lr %08f, ganA %04f, WRGM %04f, recA %04f, recB %04f, identity %04f, percp %04f, total %04f' % ( 80 | total_it, ep, it, model.genA_opt.param_groups[0]['lr'], \ 81 | model.gan_loss_a, model.tvloss, \ 82 | model.l1_recon_A_loss, model.l1_recon_B_loss, \ 83 | model.identity_loss, model.perceptual_loss, model.G_loss)) 84 | 85 | total_it += 1 86 | 87 | # decay learning rate 88 | if opts.n_ep_decay > -1: 89 | model.update_lr() 90 | 91 | model.eval() 92 | 93 | # save result image 94 | saver.write_img(ep, model) 95 | 96 | # Save 
network weights 97 | saver.write_model(ep, total_it, model) 98 | 99 | print('\n--- valing ---') 100 | model.eval() 101 | with torch.no_grad(): 102 | for i, (input_val, target_val) in enumerate(loader_val, 0): 103 | input_val, target_val = input_val.cuda(opts.gpu), target_val.cuda(opts.gpu) 104 | out_val = model.test_forward(input_val, a2b=opts.a2b) 105 | out_val = torch.clamp(out_val, -1., 1.) 106 | out_val = (out_val + 1) / 2 107 | ssim_val = criterion(target_val, out_val) 108 | ssim_sum = ssim_sum + ssim_val.item() 109 | # out_val = torch.clamp(out_val, 0., 1.) 110 | psnr_val = batch_PSNR(out_val, target_val, 1.) 111 | psnr_sum = psnr_sum + psnr_val 112 | 113 | print("[epoch %d][%d/%d] ssim: %.4f, psnr: %.4f" % 114 | (ep + 1, i + 1, len(loader_val), ssim_val.item(), psnr_val)) 115 | 116 | ssim_avg = ssim_sum / len(loader_val) 117 | psnr_avg = psnr_sum / len(loader_val) 118 | 119 | trainLogger.write('%03d\t%04f\t%04f\r\n' % \ 120 | (ep, psnr_avg, ssim_avg)) 121 | trainLogger.flush() 122 | 123 | if ep == ep0: 124 | best_psnr = psnr_avg 125 | best_ssim = ssim_avg 126 | 127 | print("[epoch %d][%d/%d] ssim_avg: %.4f, psnr_avg: %.4f, best_ssim: %.4f, best_psnr: %.4f" % 128 | (ep + 1, i + 1, len(loader_val), ssim_avg, psnr_avg, best_ssim, best_psnr)) 129 | 130 | if psnr_avg >= best_psnr: 131 | best_psnr = psnr_avg 132 | best_ssim = ssim_avg 133 | print('--- save the model @ ep %d ---' % (ep)) 134 | model.save('%s/net_best_%05d.pth' % (os.path.join(opts.result_dir, opts.name), ep), ep, total_it) 135 | 136 | trainLogger.close() 137 | 138 | return 139 | 140 | 141 | if __name__ == '__main__': 142 | main() 143 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import torch.utils.data as data 4 | from PIL import Image 5 | from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, RandomHorizontalFlip, ToTensor, Normalize, Pad, ToPILImage 6 | import random 7 | import cv2 8 | 9 | class dataset_single_test(data.Dataset): 10 | def __init__(self, opts, input_dim): 11 | self.test_path = opts.test_path 12 | images = os.listdir(os.path.join(self.test_path, '00042_set12_0.1')) 13 | 14 | images.sort() 15 | self.img = [os.path.join(self.test_path,'00042_set12_0.1', x) for x in images] 16 | 17 | self.size = len(self.img) 18 | self.input_dim = input_dim 19 | self.img_name = self.img 20 | transforms1 = [Pad((0, 0, 3, 3), padding_mode='edge'), ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 21 | self.transforms1 = Compose(transforms1) 22 | transforms2 = [ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 23 | self.transforms2 = Compose(transforms2) 24 | return 25 | 26 | def __getitem__(self, index): 27 | data, needcrop, img_names = self.load_img(self.img[index], self.input_dim) 28 | return data, needcrop, img_names 29 | 30 | def load_img(self, img_name, input_dim): 31 | img = Image.open(img_name).convert('RGB') 32 | y = cv2.imread(img_name) 33 | h,w = y.shape[0], y.shape[1] 34 | needcrop = 0 35 | if h == 321 or w == 321: 36 | img = self.transforms1(img) 37 | needcrop = 1 38 | else: 39 | img = self.transforms2(img) 40 | if input_dim == 1: 41 | img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] 
* 0.114 42 | img = img.unsqueeze(0) 43 | img_names = [os.path.splitext(x.split('/')[-1])[0] for x in self.img_name] 44 | return img, needcrop, img_names 45 | 46 | def __len__(self): 47 | return self.size 48 | 49 | class dataset_unpair(data.Dataset): 50 | def __init__(self, opts): 51 | self.train_path = opts.train_path 52 | # A 53 | images_A = os.listdir(os.path.join(self.train_path, opts.phase + 'A')) 54 | self.A = [os.path.join(self.train_path, opts.phase + 'A', x) for x in images_A] 55 | # B 56 | images_B = os.listdir(os.path.join(self.train_path, opts.phase + 'B')) 57 | self.B = [os.path.join(self.train_path, opts.phase + 'B', x) for x in images_B] 58 | 59 | self.A_size = len(self.A) 60 | self.B_size = len(self.B) 61 | self.dataset_size = max(self.A_size, self.B_size) 62 | self.input_dim_A = opts.input_dim_a 63 | self.input_dim_B = opts.input_dim_b 64 | 65 | transforms = [ToTensor()] 66 | self.transforms = Compose(transforms) 67 | print('A: %d, B: %d images'%(self.A_size, self.B_size)) 68 | return 69 | 70 | def __getitem__(self, index): 71 | if self.dataset_size == self.A_size: 72 | data_A = self.load_img(self.A[index], self.input_dim_A) 73 | data_B = self.load_img(self.B[random.randint(0, self.B_size - 1)], self.input_dim_B) 74 | else: 75 | data_A = self.load_img(self.A[random.randint(0, self.A_size - 1)], self.input_dim_A) 76 | data_B = self.load_img(self.B[index], self.input_dim_B) 77 | return data_A, data_B 78 | 79 | def load_img(self, img_name, input_dim): 80 | img = Image.open(img_name).convert('RGB') 81 | img = self.transforms(img) 82 | if input_dim == 1: 83 | img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114 84 | img = img.unsqueeze(0) 85 | return img 86 | 87 | def __len__(self): 88 | return self.dataset_size 89 | 90 | class dataset_unpair_val(data.Dataset): 91 | def __init__(self, opts): 92 | self.val_path = opts.val_path 93 | 94 | # A 95 | images_A = os.listdir(os.path.join(self.val_path, opts.valphase + 'A_1007')) 96 | self.A = [os.path.join(self.val_path, opts.valphase + 'A_1007', x) for x in images_A] 97 | # B 98 | images_B = os.listdir(os.path.join(self.val_path, opts.valphase + 'B_1007')) 99 | self.B = [os.path.join(self.val_path, opts.valphase + 'B_1007', x) for x in images_B] 100 | 101 | self.A_size = len(self.A) 102 | self.B_size = len(self.B) 103 | self.dataset_size = max(self.A_size, self.B_size) 104 | self.input_dim_A = opts.input_dim_a 105 | self.input_dim_B = opts.input_dim_b 106 | 107 | # setup image transformation 108 | # transforms = [ToTensor()] 109 | transforms = [ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 110 | transforms_clean = [ToTensor()] 111 | self.transforms = Compose(transforms) 112 | self.transforms_clean = Compose(transforms_clean) 113 | print('A: %d, B: %d images'%(self.A_size, self.B_size)) 114 | return 115 | 116 | def __getitem__(self, index): 117 | if self.dataset_size == self.A_size: 118 | data_A = self.load_img(self.A[index], self.input_dim_A) 119 | data_B = self.load_img_clean(self.B[index], self.input_dim_B) 120 | else: 121 | data_A = self.load_img(self.A[index], self.input_dim_A) 122 | data_B = self.load_img_clean(self.B[index], self.input_dim_B) 123 | return data_A, data_B 124 | 125 | def load_img(self, img_name, input_dim): 126 | img = Image.open(img_name).convert('RGB') 127 | img = self.transforms(img) 128 | if input_dim == 1: 129 | img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] 
* 0.114 130 | img = img.unsqueeze(0) 131 | return img 132 | 133 | def load_img_clean(self, img_name, input_dim): 134 | img = Image.open(img_name).convert('RGB') 135 | img = self.transforms_clean(img) 136 | if input_dim == 1: 137 | img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114 138 | img = img.unsqueeze(0) 139 | return img 140 | 141 | def __len__(self): 142 | return self.dataset_size -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.optim import lr_scheduler 4 | import torchvision.models as models 5 | from warmup_scheduler import GradualWarmupScheduler 6 | 7 | #################################################################### 8 | #------------------------- Discriminators -------------------------- 9 | #################################################################### 10 | class MultiScaleDis(nn.Module): 11 | def __init__(self, input_dim, n_scale=3, n_layer=4, norm='None', sn=False): 12 | super(MultiScaleDis, self).__init__() 13 | ch = 64 14 | self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False) 15 | self.Diss = nn.ModuleList() 16 | for _ in range(n_scale): 17 | self.Diss.append(self._make_net(ch, input_dim, n_layer, norm, sn)) 18 | 19 | def _make_net(self, ch, input_dim, n_layer, norm, sn): 20 | model = [] 21 | model += [LeakyReLUConv2d(input_dim, ch, 4, 2, 1, norm, sn)] 22 | tch = ch 23 | for _ in range(1, n_layer): 24 | model += [LeakyReLUConv2d(tch, tch * 2, 4, 2, 1, norm, sn)] 25 | tch *= 2 26 | if sn: 27 | model += [spectral_norm(nn.Conv2d(tch, 1, 1, 1, 0))] 28 | else: 29 | model += [nn.Conv2d(tch, 1, 1, 1, 0)] 30 | # model += [nn.Sigmoid()] 31 | return nn.Sequential(*model) 32 | 33 | def forward(self, x): 34 | outs = [] 35 | for Dis in self.Diss: 36 | outs.append(Dis(x)) 37 | x = self.downsample(x) 38 | # for Dis in self.Diss: 39 | # print(Dis) 40 | 41 | return outs 42 | 43 | #################################################################### 44 | #--------------------------- Vgg16 ---------------------------- 45 | #################################################################### 46 | class Vgg16(nn.Module): 47 | def __init__(self): 48 | super(Vgg16, self).__init__() 49 | features = models.vgg16(pretrained=True).features 50 | self.relu1_1 = torch.nn.Sequential() 51 | self.relu1_2 = torch.nn.Sequential() 52 | 53 | self.relu2_1 = torch.nn.Sequential() 54 | self.relu2_2 = torch.nn.Sequential() 55 | 56 | self.relu3_1 = torch.nn.Sequential() 57 | self.relu3_2 = torch.nn.Sequential() 58 | self.relu3_3 = torch.nn.Sequential() 59 | self.max3 = torch.nn.Sequential() 60 | 61 | 62 | self.relu4_1 = torch.nn.Sequential() 63 | self.relu4_2 = torch.nn.Sequential() 64 | self.relu4_3 = torch.nn.Sequential() 65 | 66 | 67 | self.relu5_1 = torch.nn.Sequential() 68 | self.relu5_2 = torch.nn.Sequential() 69 | self.relu5_3 = torch.nn.Sequential() 70 | 71 | for x in range(2): 72 | self.relu1_1.add_module(str(x), features[x]) 73 | 74 | for x in range(2, 4): 75 | self.relu1_2.add_module(str(x), features[x]) 76 | 77 | for x in range(4, 7): 78 | self.relu2_1.add_module(str(x), features[x]) 79 | 80 | for x in range(7, 9): 81 | self.relu2_2.add_module(str(x), features[x]) 82 | 83 | for x in range(9, 12): 84 | self.relu3_1.add_module(str(x), features[x]) 85 | 86 | for x in range(12, 14): 87 | self.relu3_2.add_module(str(x), features[x]) 88 | 89 | for x in range(14, 16): 90 | 
self.relu3_3.add_module(str(x), features[x])
91 |         for x in range(16, 17):
92 |             self.max3.add_module(str(x), features[x])
93 | 
94 |         for x in range(17, 19):
95 |             self.relu4_1.add_module(str(x), features[x])
96 | 
97 |         for x in range(19, 21):
98 |             self.relu4_2.add_module(str(x), features[x])
99 | 
100 |         for x in range(21, 23):
101 |             self.relu4_3.add_module(str(x), features[x])
102 | 
103 |         for x in range(23, 26):
104 |             self.relu5_1.add_module(str(x), features[x])
105 | 
106 |         for x in range(26, 28):
107 |             self.relu5_2.add_module(str(x), features[x])
108 | 
109 |         for x in range(28, 30):
110 |             self.relu5_3.add_module(str(x), features[x])
111 | 
112 | 
113 |         # don't need the gradients, just want the features
114 |         for param in self.parameters():
115 |             param.requires_grad = False
116 | 
117 |     def forward(self, x):
118 |         relu1_1 = self.relu1_1(x)
119 |         relu1_2 = self.relu1_2(relu1_1)
120 | 
121 |         relu2_1 = self.relu2_1(relu1_2)
122 |         relu2_2 = self.relu2_2(relu2_1)
123 | 
124 |         relu3_1 = self.relu3_1(relu2_2)
125 |         relu3_2 = self.relu3_2(relu3_1)
126 |         relu3_3 = self.relu3_3(relu3_2)
127 |         max_3 = self.max3(relu3_3)
128 | 
129 | 
130 |         relu4_1 = self.relu4_1(max_3)
131 |         relu4_2 = self.relu4_2(relu4_1)
132 |         relu4_3 = self.relu4_3(relu4_2)
133 | 
134 | 
135 |         relu5_1 = self.relu5_1(relu4_3)
136 |         relu5_2 = self.relu5_2(relu5_1)  # fixed: was self.relu5_1
137 |         relu5_3 = self.relu5_3(relu5_2)  # fixed: was self.relu5_1
138 |         out = {
139 |             'relu1_1': relu1_1,
140 |             'relu1_2': relu1_2,
141 | 
142 |             'relu2_1': relu2_1,
143 |             'relu2_2': relu2_2,
144 | 
145 |             'relu3_1': relu3_1,
146 |             'relu3_2': relu3_2,
147 |             'relu3_3': relu3_3,
148 |             'max_3': max_3,
149 | 
150 | 
151 |             'relu4_1': relu4_1,
152 |             'relu4_2': relu4_2,
153 |             'relu4_3': relu4_3,
154 | 
155 | 
156 |             'relu5_1': relu5_1,
157 |             'relu5_2': relu5_2,
158 |             'relu5_3': relu5_3,
159 |         }
160 |         return out['relu3_3']
161 | 
162 | ####################################################################
163 | #------------------------- Basic Functions -------------------------
164 | ####################################################################
165 | def get_scheduler(optimizer, opts, cur_ep=-1):
166 |     if opts.lr_policy == 'lambda':
167 |         def lambda_rule(ep):
168 |             lr_l = 1.0 - max(0, ep - opts.n_ep_decay) / float(opts.n_ep - opts.n_ep_decay + 1)
169 |             return lr_l
170 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule, last_epoch=cur_ep)
171 |         # scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=lr_l, last_epoch=cur_ep)
172 |     elif opts.lr_policy == 'step':
173 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opts.n_ep_decay, gamma=0.1, last_epoch=cur_ep)
174 |     elif opts.lr_policy == 'warmup':
175 |         warmup_epochs = 6
176 |         scheduler_cosine = lr_scheduler.CosineAnnealingLR(optimizer, opts.n_ep - warmup_epochs, eta_min=1e-6)
177 |         scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
178 |         scheduler.step()
179 |     else:
180 |         raise NotImplementedError('no such learning rate policy')
181 |     return scheduler
182 | 
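
# --- Editor's usage sketch (hypothetical, not part of the original file) ---
# model_singalG.py uses Vgg16 for its perceptual loss: the relu3_3 feature
# maps of two images are compared with an L2 loss. Note that Vgg16() downloads
# the pretrained VGG-16 weights on first run.
if __name__ == '__main__':
    vgg = Vgg16().eval()
    x = torch.randn(1, 3, 64, 64)
    y = torch.randn(1, 3, 64, 64)
    perceptual = torch.nn.functional.mse_loss(vgg(x), vgg(y))
    print('relu3_3 perceptual loss:', perceptual.item())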
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | 
4 | class TrainOptions():
5 |     def __init__(self):
6 |         self.parser = argparse.ArgumentParser()
7 | 
8 |         # data loader related
9 |         self.parser.add_argument('--phase', type=str, default='train', help='phase for dataloading')
10 |         self.parser.add_argument('--valphase', type=str, default='val', help='phase for dataloading')
11 |         self.parser.add_argument('--batch_size', type=int, default=64, help='batch size')
12 |         self.parser.add_argument('--resize_size', type=int, default=64, help='resized image size for training')
13 |         self.parser.add_argument('--crop_size', type=int, default=64, help='cropped image size for training')
14 |         self.parser.add_argument('--input_dim_a', type=int, default=3, help='# of input channels for domain A')
15 |         self.parser.add_argument('--input_dim_b', type=int, default=3, help='# of input channels for domain B')
16 |         self.parser.add_argument('--nThreads', type=int, default=0, help='# of threads for data loader')
17 |         self.parser.add_argument('--no_flip', action='store_true', default=False, help='specified if no flipping')
18 | 
19 |         # output related
20 |         self.parser.add_argument('--name', type=str, default='7_3_test', help='folder name to save outputs')
21 |         self.parser.add_argument('--display_dir', type=str, default='../logs', help='path for saving display results')
22 |         self.parser.add_argument('--result_dir', type=str, default='../results',
23 |                                  help='path for saving result images and models')
24 |         self.parser.add_argument('--display_freq', type=int, default=10, help='freq (iteration) of display')
25 |         self.parser.add_argument('--img_save_freq', type=int, default=1, help='freq (epoch) of saving images')
26 |         self.parser.add_argument('--model_save_freq', type=int, default=20, help='freq (epoch) of saving models')
27 |         self.parser.add_argument('--no_display_img', action='store_true', help='specified if no display')
28 | 
29 |         # training related
30 |         self.parser.add_argument('--dis_scale', type=int, default=3, help='scale of discriminator')
31 |         self.parser.add_argument('--dis_norm', type=str, default='None',
32 |                                  help='normalization layer in discriminator [None, Instance]')
33 |         self.parser.add_argument('--dis_spectral_norm', action='store_true',
34 |                                  help='use spectral normalization in discriminator')
35 |         self.parser.add_argument('--lr_policy', type=str, default='warmup', help='type of learning rate decay')
36 |         self.parser.add_argument('--n_ep', type=int, default=150, help='number of epochs')
37 |         self.parser.add_argument('--n_ep_decay', type=int, default=100,
38 |                                  help='epoch to start decaying the learning rate; set -1 for no decay')
39 |         self.parser.add_argument('--resume', type=str, default=None,
40 |                                  help='directory of saved models used to resume training')
41 |         self.parser.add_argument('--gpu', type=int, default=0, help='gpu')
42 |         self.parser.add_argument('--train_path', type=str, default=r'/media/omnisky/c91e9985-5113-463d-83a6-6ec3405ef3a7/ysq/SCI02/datasets/new_jiaozheng', help='path of training data')
43 |         self.parser.add_argument('--val_path', type=str, default=r'/media/omnisky/c91e9985-5113-463d-83a6-6ec3405ef3a7/ysq/SCI02/datasets/new_jiaozheng', help='path of testing data')
44 |         self.parser.add_argument('--a2b', type=int, default=0, help='translation direction, 1 for a2b, 0 for b2a')
45 |         self.parser.add_argument('--gan_mode', type=str, default='vanilla',
46 |                                  help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
47 |         self.parser.add_argument('--pool_size', type=int, default=50,
48 |                                  help='the size of image buffer that stores previously generated images')
49 | 
50 |     def parse(self):
51 |         self.opt = self.parser.parse_args()
52 |         args = vars(self.opt)
53 |         print('\n--- load options ---')
54 |         for name, value in sorted(args.items()):
55 |             print('%s: %s' % (str(name), str(value)))
56 |         return self.opt
57 | 
58 | 
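# --- Editor's note: a typical training invocation (placeholder paths), using the
# options defined above:
#   python train.py --name my_exp --train_path /path/to/data --val_path /path/to/data --gpu 0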
59 | class TestOptions():
60 |     def __init__(self):
61 |         self.parser = argparse.ArgumentParser()
62 | 
63 |         # data loader related
64 |         self.parser.add_argument('--phase', type=str, default='test', help='phase for dataloading')
65 |         self.parser.add_argument('--resize_size', type=int, default=256, help='resized image size for training')
66 |         self.parser.add_argument('--crop_size', type=int, default=256, help='cropped image size for training')
67 |         self.parser.add_argument('--nThreads', type=int, default=4, help='# of threads for data loader')
68 |         self.parser.add_argument('--input_dim_a', type=int, default=3, help='# of input channels for domain A')
69 |         self.parser.add_argument('--input_dim_b', type=int, default=3, help='# of input channels for domain B')
70 |         self.parser.add_argument('--a2b', type=int, default=0, help='translation direction, 1 for a2b, 0 for b2a')
71 | 
72 |         # output related
73 |         self.parser.add_argument('--num', type=int, default=5, help='number of outputs per image')
74 |         self.parser.add_argument('--name', type=str, default=r'', help='folder name to save outputs')
75 |         self.parser.add_argument('--result_dir', type=str, default='../outputs', help='path for saving result images and models')
76 | 
77 |         # model related
78 |         self.parser.add_argument('--resume', type=str, default='', help='directory of saved models used to resume training')
79 |         self.parser.add_argument('--gpu', type=int, default=0, help='gpu')
80 |         self.parser.add_argument('--test_path', type=str, default='', help='path of testing data')
81 |         self.parser.add_argument('--val_path', type=str, default=r'', help='path of testing data')
82 |         self.parser.add_argument('--valphase', type=str, default='val', help='phase for dataloading')
83 |         self.parser.add_argument('--gan_mode', type=str, default='lsgan',
84 |                                  help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
85 |         self.parser.add_argument('--pool_size', type=int, default=50,
86 |                                  help='the size of image buffer that stores previously generated images')
87 | 
88 |     def parse(self):
89 |         self.opt = self.parser.parse_args()
90 |         args = vars(self.opt)
91 |         print('\n--- load options ---')
92 |         for name, value in sorted(args.items()):
93 |             print('%s: %s' % (str(name), str(value)))
94 |         # set irrelevant options
95 |         self.opt.dis_scale = 3
96 |         self.opt.dis_norm = 'None'
97 |         self.opt.dis_spectral_norm = False
98 |         return self.opt
99 | 
--------------------------------------------------------------------------------
/model_singalG.py:
--------------------------------------------------------------------------------
1 | import networks
2 | import torch
3 | import torch.nn as nn
4 | import pickle
5 | from utils import *
6 | from models.MWUNet import MWUNet
7 | from SSIM import *
8 | 
9 | 
10 | class DerainCycleGAN(nn.Module):
11 |     def __init__(self, opts):
12 |         super(DerainCycleGAN, self).__init__()
13 | 
14 |         # parameters
15 |         lr = 0.0001
16 | 
17 |         # discriminators
18 |         self.disA = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm,
19 |                                            sn=opts.dis_spectral_norm)
20 | 
21 |         # generator
22 |         self.genA = MWUNet(3, 3)
23 |         # cubic noise
24 |         case = 3
25 |         noise = [0.02, 0.12]
26 |         self.genB = add_noise(case, noise)
27 |         self.myBatchNormlize = myBatchNormlize().cuda(opts.gpu)
28 |         self.myUnnormlize = myUnormlize().cuda(opts.gpu)
29 | 
30 |         # vgg
31 |         self.vgg = networks.Vgg16()
32 | 
33 |         # optimizers
34 |         self.disA_opt = torch.optim.Adam(self.disA.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
35 |         self.genA_opt = torch.optim.Adam(self.genA.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
36 | 
37 |         # Setup the loss function for training
38 |         self.criterionL1 = torch.nn.L1Loss()
39 |         self.criterionL2 = torch.nn.MSELoss()
40 |         self.criterionGAN = GANLoss(opts.gan_mode).cuda(opts.gpu)
41 |         self.criterionRGM = GANLoss('lsgan').cuda(opts.gpu)
42 |         self.TVloss = Drecloss_stripe().cuda(opts.gpu)
43 |         self.ms_ssim_mix = MS_SSIM_L1_LOSS().cuda(opts.gpu)
44 |         self.ssimloss = SSIM().cuda(opts.gpu)
45 | 
46 |         # create image buffer to store previously generated images
47 |         self.fake_A_pool = ImagePool(opts.pool_size)
48 |         self.fake_A1_pool = ImagePool(opts.pool_size)
49 |         self.fake_B_pool = ImagePool(opts.pool_size)
50 | 
51 |     # weight initialization
52 |     def initialize(self):
53 |         self.disA.apply(networks.gaussian_weights_init)
54 |         self.genA.apply(networks.gaussian_weights_init)
55 | 
56 |     # learning-rate decay schedule
57 |     def set_scheduler(self, opts, last_ep=0):
58 |         self.disA_sch = networks.get_scheduler(self.disA_opt, opts, last_ep)
59 |         self.genA_sch = networks.get_scheduler(self.genA_opt, opts, last_ep)
60 | 
61 |     # move modules onto the GPU
62 |     def setgpu(self, gpu):
63 |         self.gpu = gpu
64 |         self.disA.cuda(self.gpu)
65 |         self.genA.cuda(self.gpu)
66 |         self.genB.cuda(self.gpu)
67 |         self.vgg.cuda(self.gpu)
68 | 
69 |     def get_z_random(self, batchSize, nz, random_type='gauss'):
70 |         z = torch.randn(batchSize, nz).cuda(self.gpu)
71 |         return z
72 | 
73 |     def test_forward(self, image1, image2=None, a2b=None):
74 |         if a2b:
75 |             self.fake_A_encoded = self.genA.forward(image1)
76 |             return self.fake_A_encoded
77 | 
78 |     def forward(self, ep, opts):
79 |         '''self.real_A_encoded -> self.fake_A_encoded -> self.real_A_recon'''
80 |         '''self.real_B_encoded -> self.fake_B_encoded -> self.real_B_recon'''
81 |         # input images
82 |         real_A = self.input_A
83 |         real_B = self.input_B
84 |         self.real_A_encoded = real_A
85 |         self.real_B_encoded = real_B
86 | 
87 |         # get first cycle
88 |         '''self.real_A_encoded -> self.fake_A_encoded'''
89 |         '''self.real_B_encoded -> self.fake_B_encoded'''
90 |         self.real_A_train = self.myBatchNormlize(self.real_A_encoded)  # real_A_train: norm
91 |         self.fake_A_encoded = self.genA.forward(self.real_A_train)  # fake_A_encoded: norm
92 |         self.fake_B_encoded = self.genB.forward(self.real_B_encoded)  # fake_B_encoded: tensor
93 | 
94 |         # get perceptual loss
95 |         self.perc_real_A = self.vgg(self.real_A_train).detach()
96 |         self.perc_fake_A = self.vgg(self.fake_A_encoded).detach()
97 | 
98 |         # get second cycle
99 |         '''self.fake_A_encoded -> self.real_A_recon'''
100 |         '''self.fake_B_encoded -> self.real_B_recon'''
101 |         self.fake_B_encoded = self.myBatchNormlize.forward(self.fake_B_encoded)  # fake_B_encoded: norm
102 |         self.fake_A_tensor = self.myUnnormlize.forward(self.fake_A_encoded)  # fake_A_tensor: tensor
103 |         self.real_B_recon = self.genA.forward(self.fake_B_encoded)  # real_B_recon: norm
104 |         self.real_A_recon = self.genB.forward(self.fake_A_tensor)  # real_A_recon: tensor
105 | 
106 |         self.real_B_train = self.myBatchNormlize.forward(self.real_B_encoded)  # real_B_train: norm
107 |         self.fake_B_I = self.genA.forward(self.real_B_train)  # fake_B_I: norm
108 | 
109 |         # self.image_display = torch.cat((self.real_A_encoded[0:1].detach().cpu(), self.fake_A_encoded[0:1].detach().cpu(), \
110 |         #                                 self.real_A_recon[0:1].detach().cpu(), \
111 |         #                                 self.real_B_encoded[0:1].detach().cpu(), self.fake_B_encoded[0:1].detach().cpu(), \
112 |         #                                 self.real_B_recon[0:1].detach().cpu()), dim=0)
113 | 
114 |     def update_D(self, opts):
115 |         self.fake_A_encoded = self.fake_A_pool.query(self.fake_A_encoded)  # sample from the 50-image buffer for the discriminator
116 |         # self.fake_A1 = self.fake_A1_pool.query(self.fake_A1)
117 |         self.real_B_recon = self.fake_B_pool.query(self.real_B_recon)  # sample from the 50-image buffer for the discriminator
118 | 
119 |         # update disA (discriminator optimization step)
120 |         self.disA_opt.zero_grad()
121 | 
122 |         loss_D1_A = self.backward_D_basic(self.disA, self.real_B_train, self.fake_A_encoded)
123 |         loss_D2_A = self.backward_D_basic(self.disA, self.real_B_train, self.real_B_recon)
124 |         self.disA_loss = ((loss_D1_A + loss_D2_A) * 0.5).item()
125 |         self.disA_opt.step()
126 | 
127 |     def backward_D_basic(self, netD, real, fake):
128 |         # Real
129 |         pred_real = netD(real)
130 |         loss_D_real1 = self.criterionGAN(pred_real[0], True)
131 |         loss_D_real2 = self.criterionGAN(pred_real[1], True)
132 |         loss_D_real3 = self.criterionGAN(pred_real[2], True)
133 |         loss_D_real = (loss_D_real1 + loss_D_real2 + loss_D_real3) / 3
134 | 
135 |         # Fake
136 |         pred_fake = netD(fake.detach())
137 |         loss_D_fake1 = self.criterionGAN(pred_fake[0], False)
138 |         loss_D_fake2 = self.criterionGAN(pred_fake[1], False)
139 |         loss_D_fake3 = self.criterionGAN(pred_fake[2], False)
140 |         loss_D_fake = (loss_D_fake1 + loss_D_fake2 + loss_D_fake3) / 3
141 | 
142 |         loss_D = (loss_D_real + loss_D_fake) * 0.5
143 |         loss_D.backward()
144 |         return loss_D
145 | 
146 |     def update_EG(self, image_a, image_b, ep, opts):
147 |         self.input_A = image_a
148 |         self.input_B = image_b
149 |         # step one: forward pass of everything outside the discriminator
150 |         self.forward(ep, opts)
151 |         # step two: zero the gradients of everything outside the discriminator
152 |         self.genA_opt.zero_grad()
153 |         # step three: compute the losses outside the discriminator
154 |         # step four: compute the gradients outside the discriminator
155 |         self.backward_EG(opts)
156 |         # step five: update the parameters outside the discriminator
157 |         self.genA_opt.step()
158 | 
159 |     def backward_EG(self, opts):
160 |         # adversarial loss
161 |         disA_out1 = self.disA(self.fake_A_encoded)[0]
162 |         disA_out2 = self.disA(self.real_B_recon)[0]
163 |         loss_G_GAN_A = (self.criterionGAN(disA_out1, True) + self.criterionGAN(disA_out2, True)) * 0.5
164 | 
165 |         # HBGM
166 |         A = self.real_A_train.clone()
167 |         B = self.fake_A_encoded.clone()
168 |         WR_outA1, WR_outA2 = HBGM(A, B)
169 |         loss_tv = self.ms_ssim_mix(WR_outA1, WR_outA2) * 10
170 | 
171 |         # cross cycle consistency loss
172 |         self.real_A_recon = self.myBatchNormlize(self.real_A_recon)
173 |         loss_G_L1_A = self.TVloss(self.real_A_recon, self.real_A_train) * 100
174 |         loss_G_L1_B = self.ms_ssim_mix(self.real_B_recon, self.real_B_train) * 10
175 | 
176 |         # perceptual loss
177 |         loss_perceptual = self.criterionL2(self.perc_fake_A, self.perc_real_A) * 0.01
178 | 
179 |         # Identity loss
180 |         loss_identity_B = self.ms_ssim_mix(self.real_B_train, self.fake_B_I) * 10
181 |         loss_identity = loss_identity_B
182 | 
183 |         loss_G = loss_G_GAN_A + \
184 |                  loss_G_L1_A + loss_G_L1_B + \
185 |                  loss_identity + \
186 |                  loss_perceptual
187 | 
188 |         # compute gradients
189 |         loss_G.backward(retain_graph=True)
190 |         # record the losses
191 |         self.gan_loss_a = loss_G_GAN_A.item()  # adversarial
192 |         self.l1_recon_A_loss = loss_G_L1_A.item()  # cycle consistency
193 |         self.l1_recon_B_loss = loss_G_L1_B.item()  # cycle consistency
194 |         self.perceptual_loss = loss_perceptual.item()  # perceptual
195 |         self.identity_loss = loss_identity.item()
196 |         self.tvloss = loss_tv.item()
197 |         self.G_loss = loss_G.item()  # total loss
198 | 
199 |     def update_lr(self):
200 |         self.disA_sch.step()
201 |         self.genA_sch.step()
202 | 
203 |     def resume(self, model_dir, train=True):
204 |         checkpoint = torch.load(model_dir)
205 |         # weight
206 |         if train:
207 |             self.disA.load_state_dict(checkpoint['disA'])
208 |         self.genA.load_state_dict(checkpoint['genA'])
209 | 
210 |         # optimizer
211 |         if train:
212 |             self.disA_opt.load_state_dict(checkpoint['disA_opt'])
213 |             self.genA_opt.load_state_dict(checkpoint['genA_opt'])
214 |         return checkpoint['ep'], checkpoint['total_it']
215 | 
216 |     def save(self, filename, ep, total_it):
217 |         state = {
218 |             'disA': self.disA.state_dict(),
219 |             'genA': self.genA.state_dict(),
220 |             'disA_opt': self.disA_opt.state_dict(),
221 |             'genA_opt': self.genA_opt.state_dict(),
222 |             'ep': ep,
223 |             'total_it': total_it
224 |         }
225 |         torch.save(state, filename)
226 |         return
227 | 
228 |     def save_dict(self, obj, name):
229 |         with open(name + '.pkl', 'wb') as f:
230 |             pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
231 | 
232 |     def load_dict(self, name):
233 |         with open(name + '.pkl', 'rb') as f:
234 |             return pickle.load(f)
235 | 
236 |     def assemble_outputs(self):
237 |         images_a = self.normalize_image(self.real_A_encoded).detach()
238 |         images_b = self.normalize_image(self.real_B_encoded).detach()
239 |         images_a1 = self.normalize_image(self.fake_A_encoded).detach()
240 |         images_a3 = self.normalize_image(self.real_A_recon).detach()
241 |         images_b1 = self.normalize_image(self.fake_B_encoded).detach()
242 |         images_b3 = self.normalize_image(self.real_B_recon).detach()
243 | 
244 |         row1 = torch.cat((images_a[0:1, ::], images_a1[0:1, ::], images_a3[0:1, ::]), 3)
245 |         row2 = torch.cat((images_b[0:1, ::], images_b1[0:1, ::], images_b3[0:1, ::]), 3)
246 |         return torch.cat((row1, row2), 2)
247 | 
248 |     def normalize_image(self, x):
249 |         return x[:, 0:3, :, :]
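
# --- Editor's sketch (hypothetical; utils.HBGM lives in utils.py, which is not
# shown in this section). It illustrates the Haar wavelet background guidance
# idea from the README: compare only the low-frequency (LL) sub-bands of the
# striped input and the destriped output, where vertical stripe energy is
# attenuated, so background consistency can be assessed.
if __name__ == '__main__':
    from pytorch_wavelets import DWTForward
    dwt = DWTForward(J=1, wave='haar')
    a = torch.rand(1, 3, 64, 64)  # striped input
    b = torch.rand(1, 3, 64, 64)  # destriped output
    ll_a, _ = dwt(a)
    ll_b, _ = dwt(b)
    background_consistency = torch.nn.functional.l1_loss(ll_a, ll_b)
    print(ll_a.shape, background_consistency.item())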
--------------------------------------------------------------------------------
/models/MWUNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from pytorch_wavelets import DWTForward, DWTInverse
4 | from torchvision import transforms
5 | import os
6 | import torch.nn.functional as F
7 | import matplotlib.pylab as plt
8 | import torchvision
9 | 
10 | 
11 | """
12 | The DWT implementation has been rewritten.
13 | """
14 | 
15 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
16 | 
17 | class DWT(nn.Module):
18 |     def __init__(self):
19 |         super(DWT, self).__init__()
20 |         self.requires_grad = False
21 | 
22 |     def forward(self, x):
23 |         return dwt_init(x)
24 | 
25 | 
26 | class IWT(nn.Module):
27 |     def __init__(self):
28 |         super(IWT, self).__init__()
29 |         self.requires_grad = False
30 | 
31 |     def forward(self, x):
32 |         return iwt_init(x)
33 | 
34 | class single_conv(nn.Module):
35 |     def __init__(self, in_channels, out_channels):
36 |         super(single_conv, self).__init__()
37 |         self.s_conv = nn.Sequential(
38 |             nn.Conv2d(in_channels, out_channels, 3, padding=1),
39 |             nn.LeakyReLU(inplace=True),
40 |         )
41 | 
42 |     def forward(self, x):
43 |         x = self.s_conv(x)
44 |         return x
45 | 
46 | class conv11(nn.Module):
47 |     def __init__(self, in_channels, out_channels):
48 |         super(conv11, self).__init__()
49 |         self.s_conv = nn.Conv2d(in_channels, out_channels, 1)
50 | 
51 |     def forward(self, x):
52 |         x = self.s_conv(x)
53 |         return x
54 | 
55 | 
56 | class conv33(nn.Module):
57 |     def __init__(self, in_channels, out_channels):
58 |         super(conv33, self).__init__()
59 |         self.s_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
60 | 
61 |     def forward(self, x):
62 |         x = self.s_conv(x)
63 |         return x
64 | 
65 | class conv55(nn.Module):
66 |     def __init__(self, in_channels, out_channels):
67 |         super(conv55, self).__init__()
68 |         self.s_conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2)
69 | 
70 |     def forward(self, x):
71 |         x = self.s_conv(x)
72 |         return x
73 | 
74 | class conv77(nn.Module):
75 |     def __init__(self, in_channels, out_channels):
76 |         super(conv77, self).__init__()
77 |         self.s_conv = nn.Conv2d(in_channels, out_channels, kernel_size=7, padding=3)
78 | 
79 |     def forward(self, x):
80 |         x = self.s_conv(x)
81 |         return x
82 | 
83 | class _DCR_block(nn.Module):
84 |     def __init__(self, channel_in):
85 |         super(_DCR_block, self).__init__()
86 |         self.conv_1 = nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in / 2.), kernel_size=3, stride=1,
87 |                                 padding=1)
88 |         self.relu1 = nn.LeakyReLU()
89 |         self.conv_2 = nn.Conv2d(in_channels=int(channel_in * 3 / 2.), out_channels=int(channel_in / 2.), kernel_size=3,
90 |                                 stride=1, padding=1)
91 |         self.relu2 = nn.LeakyReLU()
92 |         self.conv_3 = nn.Conv2d(in_channels=channel_in * 2, out_channels=channel_in, kernel_size=3, stride=1, padding=1)
93 |         self.relu3 = nn.LeakyReLU()
94 | 
95 |     def forward(self, x):
96 |         residual = x
97 |         out = self.relu1(self.conv_1(x))
98 |         conc = torch.cat([x, out], 1)
99 |         out = self.relu2(self.conv_2(conc))
100 |         conc = torch.cat([conc, out], 1)
101 |         out = self.relu3(self.conv_3(conc))
102 |         out = torch.add(out, residual)
103 |         return out
104 | 
105 | ##########################################################################
106 | class ChannelPool(nn.Module):
107 |     def forward(self, x):
108 |         # concatenate the max-pooling and global-average-pooling results along the channel axis
109 |         return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
110 | 
111 | class SpatialGate(nn.Module):
112 |     def __init__(self):
113 |         super(SpatialGate, self).__init__()
114 |         self.compress = ChannelPool()
115 |         self.spatial = conv77(2, 1)
116 | 
117 |     def forward(self, x):
118 |         x_compress = self.compress(x)
119 |         x_out = self.spatial(x_compress)
120 |         scale = 
torch.sigmoid_(x_out) 121 | return x * scale 122 | 123 | class TripletAttention(nn.Module): 124 | def __init__(self, no_spatial=False): 125 | super(TripletAttention, self).__init__() 126 | 127 | self.ChannelGateH = SpatialGate() 128 | self.ChannelGateW = SpatialGate() 129 | self.no_spatial = no_spatial 130 | if not no_spatial: 131 | self.SpatialGate = SpatialGate() 132 | 133 | def forward(self, x): 134 | x_perm1 = x.permute(0, 2, 1, 3).contiguous()#H*C*W 135 | x_out1 = self.ChannelGateH(x_perm1) 136 | x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()#C*H*W 137 | x_perm2 = x.permute(0, 3, 2, 1).contiguous()#W*H*C 138 | x_out2 = self.ChannelGateW(x_perm2) 139 | x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()#C*H*W 140 | if not self.no_spatial: 141 | x_out = self.SpatialGate(x) 142 | x_out = (1 / 3) * (x_out + x_out11 + x_out21) 143 | else: 144 | x_out = (1 / 2) * (x_out11 + x_out21) 145 | return x_out 146 | 147 | 148 | class LayerNorm(nn.Module): 149 | r""" From ConvNeXt (https://arxiv.org/pdf/2201.03545.pdf) 150 | """ 151 | 152 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 153 | super().__init__() 154 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 155 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 156 | self.eps = eps 157 | self.data_format = data_format 158 | if self.data_format not in ["channels_last", "channels_first"]: 159 | raise NotImplementedError 160 | self.normalized_shape = (normalized_shape,) 161 | 162 | def forward(self, x): 163 | if self.data_format == "channels_last": 164 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 165 | elif self.data_format == "channels_first": 166 | u = x.mean(1, keepdim=True) 167 | s = (x - u).pow(2).mean(1, keepdim=True) 168 | x = (x - u) / torch.sqrt(s + self.eps) 169 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 170 | return x 171 | 172 | class group_aggregation_bridge(nn.Module): 173 | def __init__(self, dim_xh, dim_xl): 174 | super().__init__() 175 | self.pre_project = nn.Conv2d(dim_xh, dim_xl, 1) 176 | group_size = dim_xl // 2 177 | self.g0 = nn.Sequential( 178 | LayerNorm(normalized_shape=group_size, data_format='channels_first'), 179 | conv11(group_size,group_size), 180 | TripletAttention() 181 | ) 182 | self.g1 = nn.Sequential( 183 | LayerNorm(normalized_shape=group_size, data_format='channels_first'), 184 | conv33(group_size,group_size), 185 | TripletAttention() 186 | ) 187 | self.g2 = nn.Sequential( 188 | LayerNorm(normalized_shape=group_size, data_format='channels_first'), 189 | conv55(group_size,group_size), 190 | TripletAttention() 191 | ) 192 | self.g3 = nn.Sequential( 193 | LayerNorm(normalized_shape=group_size, data_format='channels_first'), 194 | conv77(group_size,group_size), 195 | TripletAttention() 196 | ) 197 | self.tail_conv = nn.Sequential( 198 | LayerNorm(normalized_shape=dim_xl * 2, data_format='channels_first'), 199 | nn.Conv2d(dim_xl * 2, dim_xl, 1) 200 | ) 201 | def forward(self, xh, xl): 202 | xh = self.pre_project(xh) 203 | xh = F.interpolate(xh, size=[xl.size(2), xl.size(3)], mode ='bilinear', align_corners=True) 204 | xh = torch.chunk(xh, 4, dim=1) 205 | xl = torch.chunk(xl, 4, dim=1) 206 | x0 = self.g0(torch.cat((xh[0], xl[0]), dim=1)) 207 | x1 = self.g1(torch.cat((xh[1], xl[1]), dim=1)) 208 | x2 = self.g2(torch.cat((xh[2], xl[2]), dim=1)) 209 | x3 = self.g3(torch.cat((xh[3], xl[3]), dim=1)) 210 | x = torch.cat((x0,x1,x2,x3), dim=1) 211 | x = self.tail_conv(x) 212 | return x 213 | 214 | class MWUNet(nn.Module): 
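    # MWUNet: a U-Net-style generator in which the forward pass downsamples with
    # a single-level Haar DWT (DWTForward) and upsamples with its inverse
    # (DWTInverse); each scale stacks single_conv and _DCR_block units, and the
    # group_aggregation_bridge (GAB) modules defined above fuse encoder and
    # decoder features along the skip connections.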
215 | 216 | def __init__(self, in_ch, out_ch): 217 | super(MWUNet, self).__init__() 218 | self.features = [] 219 | 220 | # encoder***************************************************** 221 | self.head = single_conv(in_ch, 32) 222 | self.dconv_encode0 = nn.Sequential(single_conv(32, 32), _DCR_block(32)) # → haar 223 | self.DWT = DWTForward(J=1, wave='haar').cuda() 224 | self.dconv_encode1 = nn.Sequential(single_conv(128, 64), _DCR_block(64)) # → haar 225 | self.DWT = DWTForward(J=1, wave='haar').cuda() 226 | self.dconv_encode2 = nn.Sequential(single_conv(256, 128), _DCR_block(128)) # → pool 227 | self.maxpool = nn.MaxPool2d(2) 228 | self.mid = nn.Sequential(single_conv(512, 256), _DCR_block(256), 229 | single_conv(256, 512)) 230 | 231 | # upsample***************************************************** 232 | self.upsample2 = nn.Sequential( 233 | nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2), 234 | nn.LeakyReLU(inplace=True) 235 | ) 236 | self.upsample1 = nn.Sequential( 237 | nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2), 238 | nn.LeakyReLU(inplace=True) 239 | ) 240 | 241 | self.upsample0 = nn.Sequential( 242 | nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2), 243 | nn.LeakyReLU(inplace=True) 244 | ) 245 | self.IDWT = DWTInverse(wave='haar').cuda() 246 | 247 | # skip***************************************************** 248 | self.GAB1 = group_aggregation_bridge(256, 64) 249 | self.GAB2 = group_aggregation_bridge(512, 128) 250 | 251 | # decoder***************************************************** 252 | self.dconv_decode2 = nn.Sequential(single_conv(128 + 128, 128), _DCR_block(128),single_conv(128, 256)) 253 | 254 | self.dconv_decode1 = nn.Sequential(single_conv(64 + 64, 64), _DCR_block(64),single_conv(64, 128)) 255 | 256 | self.dconv_decode0 = nn.Sequential(single_conv(64, 32), _DCR_block(32),single_conv(32, 32)) 257 | self.tail = nn.Sequential(nn.Conv2d(32, out_ch, 1), nn.Tanh()) 258 | 259 | def make_layer(self, block, channel_in): 260 | layers = [] 261 | layers.append(block(channel_in)) 262 | return nn.Sequential(*layers) 263 | 264 | def _transformer(self, DMT1_yl, DMT1_yh): 265 | list_tensor = [] 266 | for i in range(3): 267 | list_tensor.append(DMT1_yh[0][:, :, i, :, :]) 268 | list_tensor.append(DMT1_yl) 269 | return torch.cat(list_tensor, 1) 270 | 271 | def _Itransformer(self, out): 272 | yh = [] 273 | C = int(out.shape[1] / 4) 274 | yl = out[:, 0:C, :, :] 275 | y1 = out[:, C:2 * C, :, :].unsqueeze(2) 276 | y2 = out[:, 2 * C:3 * C, :, :].unsqueeze(2) 277 | y3 = out[:, 3 * C:4 * C, :, :].unsqueeze(2) 278 | final = torch.cat([y1, y2, y3], 2) 279 | yh.append(final) 280 | return yl, yh 281 | 282 | def forward(self, x): 283 | input = x 284 | # x = torch.cat((x, mask), 1) 285 | # ***************************************************************************** 286 | # head + encoder 287 | x0 = self.dconv_encode0(self.head(x)) 288 | res0 = x0 289 | # ***************************************************************************** 290 | # haar 291 | DMT1_yl, DMT1_yh = self.DWT(x0) 292 | DMT1 = self._transformer(DMT1_yl, DMT1_yh) 293 | x1 = self.dconv_encode1(DMT1) 294 | res1 = x1 295 | # ***************************************************************************** 296 | # haar 297 | DMT1_yl, DMT1_yh = self.DWT(x1) 298 | DMT2 = self._transformer(DMT1_yl, DMT1_yh) 299 | x2 = self.dconv_encode2(DMT2) 300 | res2 = x2 301 | # ***************************************************************************** 302 | # haar 303 | DMT1_yl, DMT1_yh = self.DWT(x2) 304 | DMT3 = self._transformer(DMT1_yl,
DMT1_yh) 305 | x3 = self.mid(DMT3) 306 | # ***************************************************************************** 307 | 308 | x2 = self.GAB2(x3,x2) 309 | x = self._Itransformer(x3) 310 | 311 | x = self.IDWT(x) 312 | # ***************************************************************************** 313 | x = self.dconv_decode2(torch.cat([x, x2], dim=1)) 314 | # ***************************************************************************** 315 | x1 = self.GAB1(x, x1) 316 | x = self._Itransformer(x) 317 | x = self.IDWT(x) 318 | # ***************************************************************************** 319 | x = self.dconv_decode1(torch.cat([x, x1], dim=1)) 320 | # ***************************************************************************** 321 | x = self._Itransformer(x) 322 | x = self.IDWT(x) 323 | # ***************************************************************************** 324 | x = self.dconv_decode0(torch.cat([x, x0], dim=1)) 325 | x = self.tail(x) 326 | # ***************************************************************************** 327 | out = x + input 328 | 329 | return out 330 | 331 | if __name__ == '__main__': 332 | net = MWUNet(3, 3).cuda() 333 | input = torch.zeros((1, 3, 64, 64), dtype=torch.float32).cuda() 334 | output = net(input) 335 | # print(net.features) 336 | print(output.shape) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | from torchvision import transforms 3 | import matplotlib.pyplot as plt 4 | import torch 5 | import re 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from skimage.metrics import peak_signal_noise_ratio 10 | import os 11 | import glob 12 | import random 13 | import math 14 | from pytorch_wavelets import DWTForward, DWTInverse 15 | from torchvision.transforms import Compose, Normalize 16 | from SSIM import * 17 | import PIL.Image as Image 18 | import matplotlib.pyplot as plt 19 | from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, RandomHorizontalFlip, ToTensor, Normalize, Pad, ToPILImage 20 | import torchvision 21 | from scipy import signal 22 | import cv2 23 | 24 | 25 | class myBatchNormlize(nn.Module): 26 | def __init__(self): 27 | super(myBatchNormlize, self).__init__() 28 | transforms = [Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 29 | self.transforms = Compose(transforms) 30 | 31 | def forward(self,x): 32 | for m in range(x.size()[0]): 33 | x[m,:,:,:] = self.transforms(x[m]) 34 | return x 35 | 36 | class myUnormlize(nn.Module): 37 | def __init__(self): 38 | super(myUnormlize, self).__init__() 39 | 40 | def forward(self, x): 41 | x = torch.clamp(x, -1., 1.) 42 | x = (x + 1) / 2 43 | return x 44 | 45 | def weights_init_kaiming(lyr): 46 | r"""Initializes weights of the model according to the "He" initialization 47 | method described in "Delving deep into rectifiers: Surpassing human-level 48 | performance on ImageNet classification" - He, K. et al. (2015), using a 49 | normal distribution. 50 | This function is to be called by the torch.nn.Module.apply() method, 51 | which applies weights_init_kaiming() to every layer of the model. 
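
    Example (hypothetical usage sketch):
        net = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1))
        net.apply(weights_init_kaiming)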
52 | """ 53 | classname = lyr.__class__.__name__ 54 | if classname.find('Conv') != -1: 55 | nn.init.kaiming_normal(lyr.weight.data, a=0, mode='fan_in') 56 | elif classname.find('Linear') != -1: 57 | nn.init.kaiming_normal(lyr.weight.data, a=0, mode='fan_in') 58 | elif classname.find('BatchNorm') != -1: 59 | lyr.weight.data.normal_(mean=0, std=math.sqrt(2. / 9. / 64.)). \ 60 | clamp_(-0.025, 0.025) 61 | nn.init.constant(lyr.bias.data, 0.0) 62 | 63 | 64 | def findLastCheckpoint(save_dir): 65 | file_list = glob.glob(os.path.join(save_dir, '*epoch*.pth')) 66 | if file_list: 67 | epochs_exist = [] 68 | for file_ in file_list: 69 | result = re.findall(".*epoch(.*).pth.*", file_) 70 | epochs_exist.append(int(result[0])) 71 | initial_epoch = max(epochs_exist) 72 | else: 73 | initial_epoch = 0 74 | return initial_epoch 75 | 76 | 77 | def batch_PSNR(img, imclean, data_range): 78 | Img = img.data.cpu().numpy().astype(np.float32) 79 | Iclean = imclean.data.cpu().numpy().astype(np.float32) 80 | PSNR = 0 81 | for i in range(Img.shape[0]): 82 | PSNR += peak_signal_noise_ratio(Iclean[i, :, :, :], Img[i, :, :, :], data_range=data_range) 83 | return (PSNR/Img.shape[0]) 84 | 85 | 86 | def normalize(data): 87 | return data / 255. 88 | 89 | 90 | def is_image(img_name): 91 | if img_name.endswith(".jpg") or img_name.endswith(".bmp") or img_name.endswith(".png"): 92 | return True 93 | else: 94 | return False 95 | 96 | 97 | def print_network(net): 98 | num_params = 0 99 | for param in net.parameters(): 100 | num_params += param.numel() 101 | print(net) 102 | print('Total number of parameters: %d' % num_params) 103 | 104 | class ImagePool(): 105 | """This class implements an image buffer that stores previously generated images. 106 | 107 | This buffer enables us to update discriminators using a history of generated images 108 | rather than the ones produced by the latest generators. 109 | """ 110 | 111 | def __init__(self, pool_size): 112 | """Initialize the ImagePool class 113 | 114 | Parameters: 115 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 116 | """ 117 | self.pool_size = pool_size 118 | if self.pool_size > 0: # create an empty pool 119 | self.num_imgs = 0 120 | self.images = [] 121 | 122 | def query(self, images): 123 | """Return an image from the pool. 124 | 125 | Parameters: 126 | images: the latest generated images from the generator 127 | 128 | Returns images from the buffer. 129 | 130 | By 50/100, the buffer will return input images. 131 | By 50/100, the buffer will return images previously stored in the buffer, 132 | and insert the current images to the buffer. 
133 | """ 134 | if self.pool_size == 0: # if the buffer size is 0, do nothing 135 | return images 136 | return_images = [] 137 | for image in images: 138 | image = torch.unsqueeze(image.data, 0) 139 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 140 | self.num_imgs = self.num_imgs + 1 141 | self.images.append(image) 142 | return_images.append(image) 143 | else: 144 | p = random.uniform(0, 1) 145 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 146 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 147 | tmp = self.images[random_id].clone() 148 | self.images[random_id] = image 149 | return_images.append(tmp) 150 | else: # by another 50% chance, the buffer will return the current image 151 | return_images.append(image) 152 | return_images = torch.cat(return_images, 0) # collect all the images and return 153 | return return_images 154 | 155 | class GANLoss(nn.Module): 156 | """Define different GAN objectives. 157 | 158 | The GANLoss class abstracts away the need to create the target label tensor 159 | that has the same size as the input. 160 | """ 161 | 162 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 163 | """ Initialize the GANLoss class. 164 | 165 | Parameters: 166 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 167 | target_real_label (bool) - - label for a real image 168 | target_fake_label (bool) - - label of a fake image 169 | 170 | Note: Do not use sigmoid as the last layer of Discriminator. 171 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 172 | """ 173 | super(GANLoss, self).__init__() 174 | self.register_buffer('real_label', torch.tensor(target_real_label)) 175 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 176 | self.gan_mode = gan_mode 177 | if gan_mode == 'lsgan': 178 | self.loss = nn.MSELoss() 179 | elif gan_mode == 'vanilla': 180 | self.loss = nn.BCEWithLogitsLoss() 181 | elif gan_mode in ['wgangp']: 182 | self.loss = None 183 | else: 184 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 185 | 186 | def get_target_tensor(self, prediction, target_is_real): 187 | """Create label tensors with the same size as the input. 188 | 189 | Parameters: 190 | prediction (tensor) - - tpyically the prediction from a discriminator 191 | target_is_real (bool) - - if the ground truth label is for real images or fake images 192 | 193 | Returns: 194 | A label tensor filled with ground truth label, and with the size of the input 195 | """ 196 | 197 | if target_is_real: 198 | target_tensor = self.real_label 199 | else: 200 | target_tensor = self.fake_label 201 | return target_tensor.expand_as(prediction) 202 | 203 | def __call__(self, prediction, target_is_real): 204 | """Calculate loss given Discriminator's output and grount truth labels. 205 | 206 | Parameters: 207 | prediction (tensor) - - tpyically the prediction output from a discriminator 208 | target_is_real (bool) - - if the ground truth label is for real images or fake images 209 | 210 | Returns: 211 | the calculated loss. 
212 | """ 213 | if self.gan_mode in ['lsgan', 'vanilla']: 214 | target_tensor = self.get_target_tensor(prediction, target_is_real) 215 | # pdb.set_trace() 216 | loss = self.loss(prediction, target_tensor) 217 | elif self.gan_mode == 'wgangp': 218 | if target_is_real: 219 | loss = -prediction.mean() 220 | else: 221 | loss = prediction.mean() 222 | return loss 223 | 224 | class Drecloss_stripe(nn.Module): 225 | def __init__(self, Drecloss_stripe_weight=1): 226 | super(Drecloss_stripe, self).__init__() 227 | self.Drecloss_stripe_weight = Drecloss_stripe_weight 228 | 229 | def forward(self, x, y): 230 | h_x = x.size()[2] 231 | h_y = y.size()[2] 232 | h_tv_x = (x[:, :, 1:, :] - x[:, :, :h_x - 1, :]) 233 | h_tv_y = (y[:, :, 1:, :] - y[:, :, :h_y - 1, :]) 234 | L1 = torch.nn.L1Loss() 235 | Drecloss_stripe = L1(h_tv_x, h_tv_y) 236 | return self.Drecloss_stripe_weight * Drecloss_stripe 237 | 238 | def _tensor_size(self, t): 239 | return t.size()[1] * t.size()[2] * t.size()[3] 240 | 241 | class TVloss(nn.Module): 242 | def __init__(self, TVloss_weight=1): 243 | super(TVloss, self).__init__() 244 | self.TVloss_weight = TVloss_weight 245 | 246 | def forward(self,x,y): 247 | h_x = x.size()[2] 248 | w_x = x.size()[3] 249 | w_tv_x = (x[:, :, :, 1:] - x[:, :, :, :w_x - 1]) 250 | w_tv_y = (y[:, :, :, 1:] - y[:, :, :, :w_x - 1]) 251 | h_tv_x = (x[:, :, 1:, :] - x[:, :, :h_x - 1, :]) 252 | h_tv_y = (y[:, :, 1:, :] - y[:, :, :h_x - 1, :]) 253 | MSE = torch.nn.MSELoss() 254 | TVloss = (MSE(h_tv_x, h_tv_y) + MSE(w_tv_x, w_tv_y))*0.5 255 | return self.TVloss_weight * TVloss 256 | 257 | def _tensor_size(self, t): 258 | return t.size()[1] * t.size()[2] * t.size()[3] 259 | 260 | def HBGM(A,B): 261 | DWT = DWTForward(J=3, wave='haar').cuda() 262 | IDWT = DWTInverse(wave='haar').cuda() 263 | DMT3_yl, DMT3_yh = DWT(A) 264 | DMT3_yl.zero_() 265 | for i, tensor in enumerate(DMT3_yh): 266 | DMT3_yh[i][:, :, 1, :, :].zero_() 267 | out1 = IDWT((DMT3_yl, DMT3_yh)) 268 | 269 | DMT3_yl, DMT3_yh = DWT(B) 270 | DMT3_yl.zero_() 271 | for i, tensor in enumerate(DMT3_yh): 272 | DMT3_yh[i][:, :, 1, :, :].zero_() 273 | out2 = IDWT((DMT3_yl, DMT3_yh)) 274 | return out1,out2 275 | 276 | class MS_SSIM_L1_LOSS(nn.Module): 277 | """ 278 | Have to use cuda, otherwise the speed is too slow. 279 | Both the group and shape of input image should be attention on. 280 | I set 255 and 1 for gray image as default. 
281 | """ 282 | 283 | def __init__(self, gaussian_sigmas=[0.5, 1.0, 2.0, 4.0, 8.0], 284 | data_range=1.0, 285 | K=(0.01, 0.03), # c1,c2 286 | alpha=0.025, # weight of ssim and l1 loss 287 | compensation=1.0, # final factor for total loss 288 | cuda_dev=0, # cuda device choice 289 | channel=3): # RGB image should set to 3 and Gray image should be set to 1 290 | super(MS_SSIM_L1_LOSS, self).__init__() 291 | self.channel = channel 292 | self.DR = data_range 293 | self.C1 = (K[0] * data_range) ** 2 294 | self.C2 = (K[1] * data_range) ** 2 295 | self.pad = int(2 * gaussian_sigmas[-1]) 296 | self.alpha = alpha 297 | self.compensation = compensation 298 | filter_size = int(4 * gaussian_sigmas[-1] + 1) 299 | g_masks = torch.zeros( 300 | (self.channel * len(gaussian_sigmas), 1, filter_size, filter_size)) # 创建了(3*5, 1, 33, 33)个masks 301 | for idx, sigma in enumerate(gaussian_sigmas): 302 | if self.channel == 1: 303 | # only gray layer 304 | g_masks[idx, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma) 305 | elif self.channel == 3: 306 | # r0,g0,b0,r1,g1,b1,...,rM,gM,bM 307 | g_masks[self.channel * idx + 0, 0, :, :] = self._fspecial_gauss_2d(filter_size, 308 | sigma) # 每层mask对应不同的sigma 309 | g_masks[self.channel * idx + 1, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma) 310 | g_masks[self.channel * idx + 2, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma) 311 | else: 312 | raise ValueError 313 | self.g_masks = g_masks.cuda(cuda_dev) # 转换为cuda数据类型 314 | 315 | def _fspecial_gauss_1d(self, size, sigma): 316 | """Create 1-D gauss kernel 317 | Args: 318 | size (int): the size of gauss kernel 319 | sigma (float): sigma of normal distribution 320 | 321 | Returns: 322 | torch.Tensor: 1D kernel (size) 323 | """ 324 | coords = torch.arange(size).to(dtype=torch.float) 325 | coords -= size // 2 326 | g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) 327 | g /= g.sum() 328 | return g.reshape(-1) 329 | 330 | def _fspecial_gauss_2d(self, size, sigma): 331 | """Create 2-D gauss kernel 332 | Args: 333 | size (int): the size of gauss kernel 334 | sigma (float): sigma of normal distribution 335 | 336 | Returns: 337 | torch.Tensor: 2D kernel (size x size) 338 | """ 339 | gaussian_vec = self._fspecial_gauss_1d(size, sigma) 340 | return torch.outer(gaussian_vec, gaussian_vec) 341 | # Outer product of input and vec2. If input is a vector of size nn and vec2 is a vector of size mm, 342 | # then out must be a matrix of size (n \times m)(n×m). 
343 | 344 | def forward(self, x, y): 345 | b, c, h, w = x.shape 346 | assert c == self.channel 347 | 348 | mux = F.conv2d(x, self.g_masks, groups=c, padding=self.pad) # e.g. a 96x96 image convolved with a 33x33 kernel gives 64x64; with pad=16 the output is 96x96 again 349 | muy = F.conv2d(y, self.g_masks, groups=c, padding=self.pad) # grouped convolution (one group per channel) to speed things up 350 | 351 | mux2 = mux * mux 352 | muy2 = muy * muy 353 | muxy = mux * muy 354 | 355 | sigmax2 = F.conv2d(x * x, self.g_masks, groups=c, padding=self.pad) - mux2 356 | sigmay2 = F.conv2d(y * y, self.g_masks, groups=c, padding=self.pad) - muy2 357 | sigmaxy = F.conv2d(x * y, self.g_masks, groups=c, padding=self.pad) - muxy 358 | 359 | # l(j), cs(j) in MS-SSIM 360 | l = (2 * muxy + self.C1) / (mux2 + muy2 + self.C1) # [B, 15, H, W] 361 | cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2) 362 | if self.channel == 3: 363 | lM = l[:, -1, :, :] * l[:, -2, :, :] * l[:, -3, :, :] # luminance term at the coarsest scale 364 | PIcs = cs.prod(dim=1) 365 | elif self.channel == 1: 366 | lM = l[:, -1, :, :] 367 | PIcs = cs.prod(dim=1) 368 | 369 | loss_ms_ssim = 1 - lM * PIcs # [B, H, W] 370 | # print(loss_ms_ssim) 371 | 372 | loss_l1 = F.l1_loss(x, y, reduction='none') # [B, C, H, W] 373 | # average l1 loss in num channels 374 | gaussian_l1 = F.conv2d(loss_l1, self.g_masks.narrow(dim=0, start=-self.channel, length=self.channel), 375 | groups=c, padding=self.pad).mean(1) # [B, H, W] 376 | 377 | loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR 378 | loss_mix = self.compensation * loss_mix 379 | 380 | return loss_mix.mean() 381 | 382 | class add_noise(nn.Module): 383 | def __init__(self, case, noiseIntL): 384 | super(add_noise, self).__init__() 385 | self.case = case 386 | self.noiseIntL = noiseIntL 387 | 388 | def forward(self,img_train): 389 | noise_S = torch.zeros(img_train.size()) 390 | if self.case == 0: 391 | # randomly choose the stripe-noise intensity (std of the distribution) 392 | beta = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], 393 | size=noise_S.size()[0]) # one noise intensity per image in the batch 394 | for m in range(noise_S.size()[0]): 395 | sizeN_S = noise_S[0, 0, :, :].size() 396 | noise_col = np.random.normal(0, beta[m], sizeN_S[1]) # row tensor 397 | S_noise = np.tile(noise_col, (sizeN_S[0], 1)) # replicate the row down all image rows (vertical stripes) 398 | S_noise = np.expand_dims(S_noise, 0) # add dim 399 | S_noise = torch.from_numpy(S_noise) # to tensor 400 | noise_S[m, :, :, :] = S_noise # broadcast back into the batch tensor 401 | 402 | imgn_trainC = img_train + noise_S 403 | imgn_train = torch.clip(imgn_trainC, 0., 1.) 404 | 405 | 406 | elif self.case == 1: 407 | beta1 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 408 | beta2 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 409 | for m in range(noise_S.size()[0]): 410 | sizeN_S = noise_S[0, 0, :, :].size() 411 | A1 = np.random.normal(0, beta1[m], sizeN_S[1]) # a row vector 412 | A2 = np.random.normal(0, beta2[m], sizeN_S[1]) # a row vector 413 | # tile down the rows 414 | A1 = np.tile(A1, (sizeN_S[0], 1)) 415 | A2 = np.tile(A2, (sizeN_S[0], 1)) 416 | # add dim 417 | A1 = np.expand_dims(A1, 0) 418 | A2 = np.expand_dims(A2, 0) 419 | # to tensor 420 | A1 = torch.from_numpy(A1) 421 | A2 = torch.from_numpy(A2) 422 | imgn_train_m = A1 + A2 * img_train[m] + img_train[m] 423 | imgn_train_m_c = torch.clip(imgn_train_m, 0., 1.)
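                # case 1: signal-dependent stripe model y = x + A1 + A2 * x,
                # clipped to [0, 1] before being written back into the batch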
424 | noise_S[m, :, :, :] = imgn_train_m_c 425 | imgn_train = noise_S 426 | 427 | elif self.case == 2: 428 | beta1 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 429 | beta2 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 430 | beta3 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 431 | for m in range(noise_S.size()[0]): 432 | sizeN_S = noise_S[0, 0, :, :].size() 433 | A1 = np.random.normal(0, beta1[m], sizeN_S[1]) # a row vector 434 | A2 = np.random.normal(0, beta2[m], sizeN_S[1]) # a row vector 435 | A3 = np.random.normal(0, beta3[m], sizeN_S[1]) # a row vector 436 | # tile down the rows 437 | A1 = np.tile(A1, (sizeN_S[0], 1)) 438 | A2 = np.tile(A2, (sizeN_S[0], 1)) 439 | A3 = np.tile(A3, (sizeN_S[0], 1)) 440 | # add dim 441 | A1 = np.expand_dims(A1, 0) 442 | A2 = np.expand_dims(A2, 0) 443 | A3 = np.expand_dims(A3, 0) 444 | # to tensor 445 | A1 = torch.from_numpy(A1) 446 | A2 = torch.from_numpy(A2) 447 | A3 = torch.from_numpy(A3) 448 | imgn_train_m = A1 + A2 * img_train[m] + A3 * A3 * img_train[m] + img_train[m] 449 | imgn_train_m_c = torch.clip(imgn_train_m, 0., 1.) 450 | noise_S[m, :, :, :] = imgn_train_m_c 451 | imgn_train = noise_S 452 | 453 | elif self.case == 3: 454 | beta1 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 455 | beta2 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 456 | beta3 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 457 | beta4 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 458 | 459 | for m in range(noise_S.size()[0]): 460 | sizeN_S = noise_S[0, 0, :, :].size() 461 | A1 = np.random.normal(0, beta1[m], sizeN_S[1]) # a row vector 462 | A2 = np.random.normal(0, beta2[m], sizeN_S[1]) # a row vector 463 | A3 = np.random.normal(0, beta3[m], sizeN_S[1]) # a row vector 464 | A4 = np.random.normal(0, beta4[m], sizeN_S[1]) # a row vector 465 | # tile down the rows 466 | A1 = np.tile(A1, (sizeN_S[0], 1)) 467 | A2 = np.tile(A2, (sizeN_S[0], 1)) 468 | A3 = np.tile(A3, (sizeN_S[0], 1)) 469 | A4 = np.tile(A4, (sizeN_S[0], 1)) 470 | # add dim 471 | A1 = np.expand_dims(A1, 0) 472 | A2 = np.expand_dims(A2, 0) 473 | A3 = np.expand_dims(A3, 0) 474 | A4 = np.expand_dims(A4, 0) 475 | # to tensor 476 | A1 = torch.from_numpy(A1).cuda() 477 | A2 = torch.from_numpy(A2).cuda() 478 | A3 = torch.from_numpy(A3).cuda() 479 | A4 = torch.from_numpy(A4).cuda() 480 | imgn_train_m = A1 + A2 * img_train[m] + A3 * A3 * img_train[m] + A4 * A4 * A4 * img_train[m] + \ 481 | img_train[m] 482 | imgn_train_m_c = torch.clip(imgn_train_m, 0., 1.)
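                # case 3: cubic stripe model y = x + A1 + A2*x + A3^2*x + A4^3*x,
                # computed on the GPU and clipped to [0, 1] before being stored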
483 | noise_S[m, :, :, :] = imgn_train_m_c 484 | imgn_train = noise_S.cuda() 485 | elif self.case == 4: 486 | beta1 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 487 | beta2 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 488 | beta3 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 489 | beta4 = np.random.uniform(self.noiseIntL[0], self.noiseIntL[1], size=noise_S.size()[0]) 490 | 491 | for m in range(noise_S.size()[0]): 492 | sizeN_S = noise_S[0, 0, :, :].size() 493 | A1 = np.random.normal(0, beta1[m], sizeN_S[1]) # a row vector 494 | A2 = np.random.normal(0, beta2[m], sizeN_S[1]) # a row vector 495 | A3 = np.random.normal(0, beta3[m], sizeN_S[1]) # a row vector 496 | A4 = np.random.normal(0, beta4[m], sizeN_S[1]) # a row vector 497 | # tile down the rows 498 | A1 = np.tile(A1, (sizeN_S[0], 1)) 499 | A2 = np.tile(A2, (sizeN_S[0], 1)) 500 | A3 = np.tile(A3, (sizeN_S[0], 1)) 501 | A4 = np.tile(A4, (sizeN_S[0], 1)) 502 | # add dim 503 | A1 = np.expand_dims(A1, 0) 504 | A2 = np.expand_dims(A2, 0) 505 | A3 = np.expand_dims(A3, 0) 506 | A4 = np.expand_dims(A4, 0) 507 | # to tensor 508 | A1 = torch.from_numpy(A1) 509 | A2 = torch.from_numpy(A2) 510 | A3 = torch.from_numpy(A3) 511 | A4 = torch.from_numpy(A4) 512 | imgn_train_m = A1 + A2 * img_train[m] + A3 * A3 * img_train[m] + A4 * A4 * A4 * img_train[m] + \ 513 | img_train[m] 514 | # imgn_train_m_c = torch.clip(imgn_train_m, 0., 1.) 515 | noise_S[m, :, :, :] = imgn_train_m 516 | imgn_train = noise_S 517 | return imgn_train 518 | --------------------------------------------------------------------------------
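A minimal end-to-end sketch of how the noise synthesis and losses in utils.py compose (hypothetical usage, assuming the repository's dependencies are installed; only the names add_noise, TVloss, and batch_PSNR come from this file, the tensors are illustrative):

    import torch
    from utils import add_noise, TVloss, batch_PSNR

    clean = torch.rand(4, 3, 64, 64)                    # a batch of clean images in [0, 1]
    noiser = add_noise(case=0, noiseIntL=[0.05, 0.15])  # column-wise Gaussian stripe noise
    striped = noiser(clean)                             # noisy batch, clipped to [0, 1]
    tv = TVloss()(striped, clean)                       # gradient-domain MSE between the batches
    psnr = batch_PSNR(striped, clean, data_range=1.0)   # average PSNR over the batch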