├── model └── wavelet.mat ├── data ├── rain12_test │ ├── 1.png │ ├── 10.png │ ├── 11.png │ ├── 12.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ ├── 7.png │ ├── 8.png │ └── 9.png └── rain12_test.list ├── README.md ├── dataset.py ├── test.py ├── main.py └── networks.py /model/wavelet.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/model/wavelet.mat -------------------------------------------------------------------------------- /data/rain12_test/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/1.png -------------------------------------------------------------------------------- /data/rain12_test/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/10.png -------------------------------------------------------------------------------- /data/rain12_test/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/11.png -------------------------------------------------------------------------------- /data/rain12_test/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/12.png -------------------------------------------------------------------------------- /data/rain12_test/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/2.png -------------------------------------------------------------------------------- /data/rain12_test/3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/3.png -------------------------------------------------------------------------------- /data/rain12_test/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/4.png -------------------------------------------------------------------------------- /data/rain12_test/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/5.png -------------------------------------------------------------------------------- /data/rain12_test/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/6.png -------------------------------------------------------------------------------- /data/rain12_test/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/7.png -------------------------------------------------------------------------------- /data/rain12_test/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/8.png -------------------------------------------------------------------------------- /data/rain12_test/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hhb072/SWAL/HEAD/data/rain12_test/9.png -------------------------------------------------------------------------------- /data/rain12_test.list: -------------------------------------------------------------------------------- 1 | 1.png 2 | 2.png 3 | 3.png 4 | 4.png 5 | 5.png 6 | 6.png 7 | 7.png 8 | 8.png 9 | 9.png 10 | 10.png 11 | 11.png 12 | 12.png 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SWAL 2 | Code for the paper ["Selective Wavelet Attention Learning for Single Image Deraining"](https://link.springer.com/article/10.1007/s11263-020-01421-z) 3 | 4 | ## Prerequisites 5 | * Python 3 6 | * PyTorch 7 | 8 | ## Models 9 | 10 | We provide models trained on the DDN, DID, Rain100H, Rain100L, and AGAN datasets at the following links: 11 | 12 | * [Google Drive](https://drive.google.com/drive/folders/1rOuxUmOEHf_6t7-ZhNfrvbwRj-Se_oFA?usp=sharing) 13 | * [Jianguo Yun](https://www.jianguoyun.com/p/DbB0gXUQiaCuBxi37v0D) 14 | 15 | Download them into the *model* folder before testing. 16 | 17 | ## Dataset 18 | 19 | 1. Download the rain datasets. 20 | 2. Arrange the images and generate a list file, just like the rain12 set in the *data* folder. Each image file must hold the rainy input (left half) and its clean ground truth (right half) side by side, because the loader splits every image down the middle. 21 | 22 | You can also modify the data-loading code in dataset.py to suit your own data layout. 23 | 24 | ## Run 25 | 26 | Train SWAL on a single GPU: 27 | 28 | CUDA_VISIBLE_DEVICES=0 python main.py --ngf=16 --ndf=64 --output_height=320 --trainroot=YOURPATH --trainfiles='YOUR_FILELIST' --save_iter=1 --batchSize=8 --nrow=8 --lr_d=1e-4 --lr_g=1e-4 --cuda --nEpochs=500 29 | 30 | Train SWAL on multiple GPUs: 31 | 32 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --ngf=16 --ndf=64 --output_height=320 --trainroot=YOURPATH --trainfiles='YOUR_FILELIST' --save_iter=1 --batchSize=32 --nrow=8 --lr_d=1e-4 --lr_g=1e-4 --cuda --nEpochs=500 33 | 34 | Test SWAL: 35 | 36 | CUDA_VISIBLE_DEVICES=0 python test.py --ngf=16 --outf='test' --testroot='data/rain12_test' --testfiles='data/rain12_test.list' --pretrained='model/rain100l_best.pth' --cuda 37 | 38 | Adjust the parameters according to your own settings. 
def load_image(file_path, input_height=None, input_width=None, output_height=None, output_width=None,
              crop_height=None, crop_width=None, is_random_crop=True, is_mirror=True, is_gray=False, align=32):
    """Load a side-by-side image and return its two halves with identical geometry.

    The file is split down the middle; every following step (mirror, resize,
    crop, resize, alignment) is applied with the same parameters to both
    halves so they stay pixel-aligned.

    Args:
        file_path: Path of the image to open.
        input_height, input_width: Optional size each half is resized to
            right after splitting. Width defaults to the height.
        output_height, output_width: Optional size each half is resized to
            after cropping. Width defaults to the height.
        crop_height, crop_width: Optional crop size. Width defaults to the
            height. Cropping only happens when ``crop_height`` is given.
        is_random_crop: Pick a random crop window (True) or a centred one.
        is_mirror: Horizontally flip both halves with probability 0.5.
        is_gray: Convert to single-channel ``'L'`` instead of ``'RGB'``.
        align: Final width and height are rounded down to a multiple of this.

    Returns:
        ``(imgR, imgL)``: the left half and the right half of the file, in
        that order, as PIL images.
    """
    if input_width is None:
        input_width = input_height
    if output_width is None:
        output_width = output_height
    if crop_width is None:
        crop_width = crop_height

    # convert() always hands back a fully loaded copy, so the file handle can
    # be released right away instead of lingering until garbage collection
    # (which matters inside long-running DataLoader workers).
    with Image.open(file_path) as src:
        img = src.convert('L' if is_gray else 'RGB')

    w, h = img.size
    # ImageOps.crop takes the border to REMOVE as (left, top, right, bottom):
    # dropping the right w//2 columns keeps the left half, and vice versa.
    imgR = ImageOps.crop(img, (0, 0, w // 2, 0))
    imgL = ImageOps.crop(img, (w // 2, 0, 0, 0))

    if is_mirror and random.randint(0, 1) == 0:
        imgR = ImageOps.mirror(imgR)
        imgL = ImageOps.mirror(imgL)

    if input_height is not None:
        imgR = imgR.resize((input_width, input_height), Image.BICUBIC)
        imgL = imgL.resize((input_width, input_height), Image.BICUBIC)

    w, h = imgR.size
    if crop_height is not None:
        if is_random_crop:
            cx1 = random.randint(0, w - crop_width) if crop_width < w else 0
            cy1 = random.randint(0, h - crop_height) if crop_height < h else 0
        else:
            # Floor division for the leading border; the trailing border
            # below absorbs the remainder. The previous round(diff / 2) on
            # BOTH sides produced a crop one pixel off the requested size
            # whenever the difference was odd.
            cx1 = (w - crop_width) // 2
            cy1 = (h - crop_height) // 2
        cx2 = w - crop_width - cx1
        cy2 = h - crop_height - cy1
        imgR = ImageOps.crop(imgR, (cx1, cy1, cx2, cy2))
        imgL = ImageOps.crop(imgL, (cx1, cy1, cx2, cy2))

    if output_height is not None:
        imgR = imgR.resize((output_width, output_height), Image.BICUBIC)
        imgL = imgL.resize((output_width, output_height), Image.BICUBIC)

    # Round both dimensions down to a multiple of `align`.
    # NOTE(review): presumably required by the strided wavelet layers of the
    # network - confirm against networks.py.
    w, h = imgR.size
    h1 = h // align * align
    w1 = w // align * align
    if h1 != h or w1 != w:
        imgR = imgR.resize((w1, h1), Image.BILINEAR)
        imgL = imgL.resize((w1, h1), Image.BILINEAR)
    return imgR, imgL


class ImageDatasetFromFile(data.Dataset):
    """Dataset yielding the two halves of side-by-side image files as tensors.

    Each item is ``(imgR, imgL)`` as returned by ``load_image`` (left half,
    right half), passed through ``ToTensor`` and, if given, ``normalize``.

    Args:
        image_list: File names relative to ``root_path``.
        root_path: Directory containing the images.
        normalize: Optional transform appended after ``ToTensor``.
        Remaining arguments are forwarded unchanged to ``load_image``.
    """

    def __init__(self, image_list, root_path,
                input_height=None, input_width=None, output_height=None, output_width=None,
                crop_height=None, crop_width=None, is_random_crop=False, is_mirror=True, is_gray=False, normalize=None):
        super(ImageDatasetFromFile, self).__init__()

        self.image_filenames = image_list
        self.is_random_crop = is_random_crop
        self.is_mirror = is_mirror
        self.input_height = input_height
        self.input_width = input_width
        self.output_height = output_height
        self.output_width = output_width
        self.root_path = root_path
        self.crop_height = crop_height
        self.crop_width = crop_width
        self.is_gray = is_gray

        if normalize is None:
            self.transform = transforms.Compose([transforms.ToTensor()])
        else:
            self.transform = transforms.Compose([transforms.ToTensor(), normalize])

    def __getitem__(self, index):
        """Return the ``(imgR, imgL)`` tensor pair for the file at ``index``."""
        imgR, imgL = load_image(join(self.root_path, self.image_filenames[index]),
                                self.input_height, self.input_width, self.output_height, self.output_width,
                                self.crop_height, self.crop_width, self.is_random_crop, self.is_mirror, self.is_gray)

        imgR = self.transform(imgR)
        imgL = self.transform(imgL)

        return imgR, imgL

    def __len__(self):
        """Number of files in the list."""
        return len(self.image_filenames)
default="/home/huaibo.huang/data/celeba-hq/celeba-hq-wx-256", type=str, help='path to dataset') 37 | parser.add_argument('--testfiles', default="celeba_hq_attr.list", type=str, help='the list of training files') 38 | parser.add_argument('--testroot', default="/home/huaibo.huang/data/celeba-hq/celeba-hq-wx-256", type=str, help='path to dataset') 39 | parser.add_argument('--trainsize', type=int, help='number of training data', default=28000) 40 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=12) 41 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') 42 | parser.add_argument('--input_height', type=int, default=128, help='the height of the input image to network') 43 | parser.add_argument('--input_width', type=int, default=None, help='the width of the input image to network') 44 | parser.add_argument('--output_height', type=int, default=128, help='the height of the output image to network') 45 | parser.add_argument('--output_width', type=int, default=None, help='the width of the output image to network') 46 | parser.add_argument('--crop_height', type=int, default=None, help='the width of the output image to network') 47 | parser.add_argument('--crop_width', type=int, default=None, help='the width of the output image to network') 48 | parser.add_argument("--nEpochs", type=int, default=500, help="number of epochs to train for") 49 | parser.add_argument("--start_epoch", default=0, type=int, help="Manual epoch number (useful on restarts)") 50 | parser.add_argument("--cur_iter", default=0, type=int, help="Manual epoch number (useful on restarts)") 51 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 52 | parser.add_argument('--outf', default='results/', help='folder to output images and model checkpoints') 53 | parser.add_argument('--manualSeed', type=int, help='manual seed') 54 | parser.add_argument('--tensorboard', action='store_true', help='enables tensorboard') 55 | 
def readlinesFromFile(path):
    """Return the first whitespace-separated token of each non-blank line.

    Blank or whitespace-only lines are skipped; previously they raised
    ``IndexError``. The file is closed even if reading fails.
    """
    print("Load from file %s" % path)
    with open(path) as f:
        return [tokens[0] for tokens in (line.split() for line in f) if tokens]


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n == 0 on an empty meter must not divide by zero;
        # the average simply keeps its previous value.
        if self.count:
            self.avg = self.sum / self.count


def time2str(t):
    """Format a duration in seconds as ``DD/HH/MM/SS`` (fraction dropped)."""
    t = int(t)
    day, rest = divmod(t, 86400)
    hour, rest = divmod(rest, 3600)
    minute, second = divmod(rest, 60)
    return "{:02d}/{:02d}/{:02d}/{:02d}".format(day, hour, minute, second)


def compute_metrics(img, gt):
    """Compute PSNR and SSIM on the luma (Y of YCbCr) channel.

    Only the first sample of each batch is evaluated.

    Args:
        img: Predicted batch as a CPU tensor, indexed (batch, channel, H, W)
            per the transpose below.
            NOTE(review): values are assumed to be nominally in [0, 1] -
            confirm against the generator's output range.
        gt: Ground-truth batch with the same layout.

    Returns:
        ``(psnr, ssim)`` for the first image of the batch.
    """
    def _to_luma(batch):
        rgb = batch.numpy().transpose((0, 2, 3, 1))[0] * 255.
        # Round, then clamp, before the uint8 cast. A bare cast truncates
        # (so a value a hair below an integer drops a whole grey level) and
        # wraps around for out-of-range network outputs, which corrupted
        # the scores.
        rgb = np.clip(np.rint(rgb), 0, 255).astype('uint8')
        return skimage.color.rgb2ycbcr(rgb)[:, :, 0]

    img = _to_luma(img)
    gt = _to_luma(gt)
    cur_psnr = compare_psnr(img, gt, data_range=255)
    cur_ssim = compare_ssim(img, gt, data_range=255)
    return cur_psnr, cur_ssim
def load_model(model, pretrained, strict=False):
    """Load generator weights from ``pretrained`` into ``model.G``.

    Two file layouts are accepted:

    * a bare generator ``state_dict``;
    * a training checkpoint dict whose ``"G"`` entry holds the generator
      ``state_dict`` (the layout written by main.py's ``save_checkpoint``).

    Previously only the first layout worked. A training checkpoint was
    passed whole to ``load_state_dict`` and, with ``strict=False``, matched
    no parameter and silently left the generator at its random
    initialisation.

    Args:
        model: Object exposing the generator as ``model.G``.
        pretrained: Path to the weight file.
        strict: Forwarded to ``load_state_dict``.
    """
    # Deserialise onto the CPU so GPU-saved files also load on CPU-only
    # hosts; load_state_dict then copies into whatever device G lives on.
    state = torch.load(pretrained, map_location="cpu")
    if isinstance(state, dict) and isinstance(state.get("G"), dict):
        state = state["G"]
    model.G.load_state_dict(state, strict=strict)
peak_signal_noise_ratio as compare_psnr 28 | from skimage.metrics import structural_similarity as compare_ssim 29 | # from tensorboardX import SummaryWriter 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--lr_d', type=float, default=0.0001, help='learning rate of the encoder, default=0.0002') 33 | parser.add_argument('--lr_g', type=float, default=0.0001, help='learning rate of the generator, default=0.0002') 34 | parser.add_argument("--weight_align", type=float, default=1e-5, help="Default=1.0") 35 | parser.add_argument("--weight_aug", type=float, default=1e-2, help="Default=1.0") 36 | parser.add_argument("--weight_adv", type=float, default=1e-3, help="Default=1.0") 37 | parser.add_argument("--ngf", type=int, default=16, help="dim of the latent code, Default=512") 38 | parser.add_argument("--ndf", type=int, default=64, help="dim of the latent code, Default=512") 39 | parser.add_argument("--save_iter", type=int, default=1, help="Default=1") 40 | parser.add_argument("--test_iter", type=int, default=1, help="Default=1") 41 | parser.add_argument('--num_mse', type=int, help='the number of images in each row', default=100000) 42 | parser.add_argument('--nrow', type=int, help='the number of images in each row', default=8) 43 | parser.add_argument('--trainfiles', default="train.list", type=str, help='the list of training files') 44 | parser.add_argument('--trainroot', default="trainroot", type=str, help='path to dataset') 45 | parser.add_argument('--testfiles', default="test.list", type=str, help='the list of training files') 46 | parser.add_argument('--testroot', default="testroot", type=str, help='path to dataset') 47 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=12) 48 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') 49 | parser.add_argument('--input_height', type=int, default=128, help='the height of the input image to network') 50 | 
def readlinesFromFile(path):
    """Return the first whitespace-separated token of each non-blank line.

    Blank or whitespace-only lines are skipped; previously they raised
    ``IndexError``. The file is closed even if reading fails.
    """
    print("Load from file %s" % path)
    with open(path) as f:
        return [tokens[0] for tokens in (line.split() for line in f) if tokens]


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n == 0 on an empty meter must not divide by zero;
        # the average simply keeps its previous value.
        if self.count:
            self.avg = self.sum / self.count


def time2str(t):
    """Format a duration in seconds as ``DD/HH/MM/SS`` (fraction dropped)."""
    t = int(t)
    day, rest = divmod(t, 86400)
    hour, rest = divmod(rest, 3600)
    minute, second = divmod(rest, 60)
    return "{:02d}/{:02d}/{:02d}/{:02d}".format(day, hour, minute, second)
batch_size = data.shape[0] 155 | 156 | data, label = Variable(data).cuda(), Variable(label).cuda() 157 | 158 | #=========== Update batchSize ================== 159 | fake, att, res = model.generate(data) 160 | 161 | lossR = model.get_l1_loss(fake, label) 162 | 163 | 164 | lossG = lossR 165 | 166 | #=========== Backward && Update ================= 167 | lossesR.update(lossR.data, batch_size) 168 | 169 | optimizerG.zero_grad() 170 | lossG.backward() 171 | optimizerG.step() 172 | 173 | #========= Print loss ================= 174 | if iteration % 10 == 0: 175 | info = "\n====> Cur_iter: [{}]: Epoch[{}]({}/{}): time: {}: ".format(opt.cur_iter, epoch, iteration, len(train_data_loader), time2str(time.time()-start_time)) 176 | info += 'L1: {:.5f}({:.5f}), '.format(lossesR.val, lossesR.avg) 177 | print(info, flush=True) 178 | 179 | opt.cur_iter += 1 180 | 181 | def train_gan(epoch): 182 | model.train() 183 | lossesR = AverageMeter() 184 | lossesR_ = AverageMeter() 185 | lossesF = AverageMeter() 186 | 187 | for iteration, batch in enumerate(train_data_loader, 0): 188 | data, label = batch 189 | 190 | batch_size = data.shape[0] 191 | 192 | data, label = Variable(data).cuda(), Variable(label).cuda() 193 | 194 | fake, att, res = model.generate(data) 195 | 196 | alpha = random.random() * 0.5 197 | rec, _, res_ = model.generate(label*alpha+data*(1-alpha)) 198 | 199 | lossR = model.get_l1_loss(fake, label) 200 | lossR_ = model.get_l1_loss(rec, label) 201 | 202 | lossF = 0 203 | for r0, r1 in zip(res, res_): 204 | lossF += model.get_l1_loss(r0,r1.detach()) 205 | 206 | lossG = lossR + lossR_ * opt.weight_aug + lossF * opt.weight_align # 207 | 208 | #=========== Update D ================== 209 | realD = model.discriminate(torch.cat((data, label), dim=1)) 210 | fakeD = model.discriminate(torch.cat((data, fake.detach()), dim=1)) 211 | 212 | lossD_real = model.get_ls_loss(realD, 1.0) 213 | lossD_fake = model.get_ls_loss(fakeD, 0.0) 214 | 215 | lossD = lossD_real + lossD_fake 216 | 217 | 
def load_model(model, pretrained, strict=False):
    """Restore generator weights from a training checkpoint.

    Args:
        model: Object exposing the generator as ``model.G``.
        pretrained: Path to a checkpoint written by ``save_checkpoint``.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        ``(epoch, iteration)`` as stored in the checkpoint.

    NOTE(review): only ``G`` is restored. The stored discriminator (``"D"``)
    is ignored, so a resumed run in the adversarial phase starts from a
    fresh discriminator - confirm this is intended.
    """
    # Deserialise onto the CPU so GPU-saved files also load on CPU-only
    # hosts; load_state_dict then copies into whatever device G lives on.
    state = torch.load(pretrained, map_location="cpu")
    model.G.load_state_dict(state['G'], strict=strict)
    return state['epoch'], state['iter']


def save_checkpoint(model, epoch, iteration, prefix="", manualSeed=0, out_dir="model/"):
    """Write generator and discriminator weights plus progress counters.

    The file is named
    ``<out_dir><prefix>model_epoch_<epoch>_iter_<iteration>_seed_<manualSeed>.pth``.

    Args:
        model: Object exposing ``G`` and ``D`` modules.
        epoch: Epoch number stored in the checkpoint and the file name.
        iteration: Iteration counter stored in the checkpoint and file name.
        prefix: Optional file-name prefix.
        manualSeed: Seed recorded in the file name.
        out_dir: Destination directory (created if missing). The default
            reproduces the previously hard-coded ``model/`` location.

    Returns:
        The path of the file that was written.
    """
    file_name = prefix + "model_epoch_{}_iter_{}_seed_{}.pth".format(epoch, iteration, manualSeed)
    model_out_path = os.path.join(out_dir, file_name)
    state = {"epoch": epoch, "iter": iteration, "G": model.G.state_dict(), "D": model.D.state_dict()}
    if out_dir:
        # exist_ok avoids the check-then-create race of the old
        # os.path.exists() guard when several processes save at once.
        os.makedirs(out_dir, exist_ok=True)

    torch.save(state, model_out_path)

    print("Checkpoint saved to {}".format(model_out_path))
    return model_out_path
class _BN_Residual_Block(nn.Module):
    """Two 3x3 convolutions with a skip connection and optional batch norm.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        groups: Group count used by every convolution in the block.
        act: Activation applied after the residual sum when ``last`` is
            False. Defaults to a fresh ``nn.ReLU(True)`` per block.
        bn: Create batch-norm layers.
        wider_channel: Unused; kept for interface compatibility.
        last: When True, the block ends without activation and skips the
            batch norm on the skip path and after the second convolution.
    """

    def __init__(self, in_channels, out_channels, groups=1, act=None, bn=False, wider_channel=True, last=True):
        super().__init__()

        # A module instance as default argument is created once and shared
        # by every block; build one per block instead.
        if act is None:
            act = nn.ReLU(True)

        self.last = last
        mid_channels = max(in_channels, out_channels)

        # NOTE(review): `is not` compares object identity, not value. For
        # equal channel counts computed separately and larger than CPython's
        # small-int cache (256), this is True and a 1x1 skip convolution is
        # still created. Intent is almost certainly `!=`, but switching
        # changes which layers exist and therefore the state_dict layout of
        # already-trained checkpoints, so it is left as is here. Decide
        # together with a checkpoint migration.
        if in_channels is not out_channels:
            self.expand_conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0, groups=groups, bias=False)
            self.expand_bn = nn.BatchNorm2d(out_channels) if bn else None
        else:
            self.expand_conv, self.expand_bn = None, None

        self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1, groups=groups, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels) if bn else None
        self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, 1, 1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels) if bn else None
        if not last:
            self.act = act

    def forward(self, x):
        """Return ``conv2(relu(conv1(x))) + skip(x)``, activated unless ``last``."""
        if self.expand_conv is not None:
            identity = self.expand_conv(x)
            if (not self.last) and (self.expand_bn is not None):
                identity = self.expand_bn(identity)
        else:
            identity = x
        x = self.conv1(x)
        if self.bn1 is not None:
            x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        if (not self.last) and (self.bn2 is not None):
            x = self.bn2(x)
        x = x + identity
        if not self.last:
            x = self.act(x)
        return x


class WaveletPooling(nn.Module):
    """Wavelet decomposition followed by attention-based coefficient selection.

    The input is decomposed with ``WaveletTransform``, processed by a
    grouped residual block, and multiplied by a sigmoid attention map. The
    attended coefficients are pooled and passed on; the remaining
    coefficients are returned separately.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels per wavelet band after processing.
        downsample: If True, pool with a 1x1 convolution down to
            ``out_channels``; otherwise apply the inverse wavelet transform
            (requires ``in_channels == out_channels``).
        scale: Decomposition level; ``4**scale`` bands are produced.
        reduction: Unused; kept for interface compatibility.
        hdim: Hidden width of the attention branch.
    """

    def __init__(self, in_channels, out_channels, downsample=True, scale=1, reduction=16, hdim=32):
        super().__init__()

        num_wavelets = 4**scale
        self.wavelet_dec = WaveletTransform(scale, dec=True, cdim=in_channels, transpose=True)

        self.process = _BN_Residual_Block(in_channels*num_wavelets, out_channels*num_wavelets, groups=num_wavelets)

        # One stride-2 convolution per decomposition level, so the attention
        # features end up with the same spatial size as the coefficients.
        self.att0 = nn.Sequential(
            nn.Conv2d(in_channels, hdim, 3, 2, 1, bias=False),
            nn.ReLU(True),
        )
        for i in range(scale - 1):
            # The features carry `hdim` channels at this point and `att1`
            # expects `hdim` of them. The former Conv2d(out_channels,
            # out_channels) only worked when out_channels == hdim, where
            # this is identical.
            self.att0.add_module('conv_{}'.format(i), nn.Conv2d(hdim, hdim, 3, 2, 1, bias=False))
            self.att0.add_module('relu_{}'.format(i), nn.ReLU(True))

        self.att1 = nn.Sequential(
            nn.Conv2d(out_channels*num_wavelets+hdim, hdim, 1, 1, 0, bias=False),
            nn.ReLU(True),
            nn.Conv2d(hdim, 4, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        )

        if downsample:
            self.pool = nn.Conv2d(out_channels*num_wavelets, out_channels, 1, 1, 0, bias=False)
        else:
            assert in_channels == out_channels
            self.pool = WaveletTransform(scale, dec=False, cdim=in_channels, transpose=True)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, 1)``.

        The noise is drawn on the device and in the dtype of the inputs, so
        the method also works without CUDA.
        """
        std = logvar.mul(0.5).exp_()
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, x, prior_att=None):
        """Decompose ``x`` and split the coefficients by attention.

        Args:
            x: Input feature map.
            prior_att: Optional attention map used instead of the one
                predicted from ``x``.

        Returns:
            ``(x_forward, x_res, out_att)``: pooled attended coefficients,
            the coefficients left over (``coeffs - coeffs * att``), and the
            attention map before channel expansion.
        """
        coeffs = self.process(self.wavelet_dec(x))
        if prior_att is None:
            att = self.att1(torch.cat((self.att0(x), coeffs), dim=1))
        else:
            # A supplied map replaces the prediction entirely, so the
            # attention branch is not evaluated at all.
            att = prior_att

        out_att = att
        # Tile each attention channel over a contiguous run of coefficient
        # channels so the map matches the coefficient tensor's channel count.
        att = att.unsqueeze(2).repeat(1, 1, coeffs.shape[1]//att.shape[1], 1, 1).flatten(1, 2)
        x_att = coeffs * att
        x_res = coeffs - x_att
        x_forward = self.pool(x_att)
        return x_forward, x_res, out_att
self.up = nn.Conv2d(in_channels, in_channels*num_wavelets, 1, 1, 0, bias=False) 164 | else: 165 | assert in_channels == out_channels 166 | self.up = WaveletTransform(scale, dec=True, cdim=in_channels) 167 | self.process = _BN_Residual_Block(in_channels*num_wavelets, out_channels*num_wavelets, groups=num_wavelets, last=last) 168 | self.wavelet_rec = WaveletTransform(scale, dec=False, cdim=out_channels) 169 | 170 | if last: 171 | self.uncertainty = nn.ConvTranspose2d(in_channels*num_wavelets, 1, 4, 2, 0) 172 | else: 173 | self.uncertainty = None 174 | 175 | 176 | def forward(self, x, res): 177 | x = self.up(x) 178 | x = x + res 179 | x = self.process(x) 180 | x = self.wavelet_rec(x) 181 | return x 182 | 183 | class ResPool(nn.Module): 184 | def __init__(self, in_channels, out_channels, downsample, num_layers=1, scale=1, reduction=8): 185 | super().__init__() 186 | 187 | self.pool = WaveletPooling(in_channels, out_channels, downsample, reduction=reduction, scale=scale) 188 | self.res = nn.ModuleList([_BN_Residual_Block(out_channels, out_channels) for i in range(num_layers)]) 189 | 190 | def forward(self, x, prior_att=None): 191 | x, res, att = self.pool(x, prior_att) 192 | for block in self.res: 193 | x = block(x) 194 | 195 | return x, res, att 196 | 197 | class ResUpsample(nn.Module): 198 | def __init__(self, in_channels, out_channels, upsample, num_layers=1, scale=1, last=True): 199 | super().__init__() 200 | 201 | self.res = nn.ModuleList([_BN_Residual_Block(in_channels, in_channels) for i in range(num_layers)]) 202 | self.up = WaveletUpsample(in_channels, out_channels, upsample, last=last, scale=scale) 203 | 204 | def forward(self, x, res): 205 | for block in self.res: 206 | x = block(x) 207 | x = self.up(x, res) 208 | return x 209 | 210 | class Generator(nn.Module): 211 | def __init__(self, ngf=64): 212 | super().__init__() 213 | 214 | self.down, self.up = [], [] 215 | self.down += [ResPool(3, ngf, True, 16, reduction=2)] # 128 216 | self.up += [ResUpsample(ngf, 
3, True, 16, last=True)] 217 | 218 | self.down += [ResPool(ngf, ngf*2, True, 16, reduction=4)] # 64 219 | self.up += [ResUpsample(ngf*2, ngf, True, 16)] 220 | 221 | self.down += [ResPool(ngf*2, ngf*4, True, 16, reduction=8)] # 32 222 | self.up += [ResUpsample(ngf*4, ngf*2, True, 16)] 223 | 224 | self.down += [ResPool(ngf*4, ngf*4, True, 16, reduction=8)] # 16 225 | self.up += [ResUpsample(ngf*4, ngf*4, True, 16)] 226 | 227 | self.down += [ResPool(ngf*4, ngf*4, True, 16, reduction=8)] # 8 228 | self.up += [ResUpsample(ngf*4, ngf*4, True, 16)] 229 | 230 | self.down, self.up = nn.ModuleList(self.down), nn.ModuleList(self.up[::-1]) 231 | 232 | def forward(self, x): 233 | x = x * 2 - 1 234 | res = [] 235 | att = [] 236 | 237 | for down in self.down: 238 | x, r, a = down(x) 239 | res.append(r) 240 | att.append(a.detach()) 241 | 242 | feats = [] 243 | for up, r in zip(self.up, res[::-1]): 244 | feats.append(x) 245 | x = up(x, r) 246 | 247 | x = torch.tanh(x) * 0.5 + 0.5 248 | 249 | return x, att, feats 250 | 251 | class Discriminator(nn.Module): 252 | def __init__(self, ngf=64): 253 | super().__init__() 254 | 255 | self.down, self.fc = [], [] 256 | self.down += [ResPool(6, ngf, True, 1, reduction=2)] # 128 257 | self.fc += [nn.Conv2d(ngf*4, 1, 5, 2, 0)] 258 | 259 | self.down += [ResPool(ngf, ngf, True, 1, reduction=2)] # 64 260 | self.fc += [nn.Conv2d(ngf*4, 1, 5, 2, 0)] 261 | 262 | self.down += [ResPool(ngf, ngf, True, 1, reduction=2)] # 32 263 | self.fc += [nn.Conv2d(ngf*4, 1, 5, 2, 0)] 264 | 265 | self.down += [ResPool(ngf, ngf, True, 1, reduction=2)] # 16 266 | self.fc += [nn.Conv2d(ngf*4, 1, 5, 2, 0)] 267 | 268 | self.down += [ResPool(ngf, ngf, True, 1, reduction=2)] # 8 269 | self.fc += [nn.Conv2d(ngf*4, 1, 5, 2, 0)] 270 | 271 | self.down = nn.ModuleList(self.down) 272 | self.fc = nn.ModuleList(self.fc) 273 | 274 | def forward(self, x): 275 | x = x * 2 - 1 276 | 277 | res = [] 278 | out = [] 279 | for down, fc in zip(self.down, self.fc): 280 | x, r, _ = down(x) 281 
| o = fc(r) 282 | res.append(r.flatten(1)) 283 | out.append(o.flatten(1)) 284 | 285 | res = torch.cat(res, dim=1) 286 | out = torch.cat(out, dim=1) 287 | out = torch.sigmoid(out) 288 | return out 289 | 290 | class WaveletModel(nn.Module): 291 | def __init__(self, ngf=64, ndf=64): 292 | super().__init__() 293 | 294 | self.G = Generator(ngf) 295 | self.D = Discriminator(ndf) 296 | 297 | self.L1 = nn.L1Loss() 298 | 299 | def forward(self, x): 300 | pass 301 | 302 | def generate(self, x): 303 | return data_parallel(self.G, x) 304 | 305 | def discriminate(self, x): 306 | return data_parallel(self.D, x).mean() 307 | 308 | def get_l1_loss(self, x, y): 309 | return data_parallel(self.L1, (x, y)).mean() 310 | 311 | def get_ls_loss(self, x, y): 312 | return ((x - y)**2).mean() 313 | 314 | 315 | --------------------------------------------------------------------------------