├── README.md
└── src
    ├── .DS_Store
    ├── .vscode
    │   └── settings.json
    ├── dataset.py
    ├── model.py
    ├── networks.py
    ├── options.py
    ├── saver.py
    ├── test.py
    └── train.py

/README.md:
--------------------------------------------------------------------------------
# Unsupervised Domain-Specific Deblurring via Disentangled Representations

PyTorch implementation of the paper [Unsupervised Domain-Specific Deblurring via Disentangled Representations](https://arxiv.org/pdf/1903.01594.pdf). The proposed method is unsupervised: it takes a blurred domain-specific image (faces or text) as input and produces the corresponding sharp estimate.

contact: bylu@umiacs.umd.edu

## Sample Results

TO BE ADDED...

## Dataset

To train the model, place the unpaired sharp and blurred images in folders named `datasets/name/trainA` and `datasets/name/trainB`, respectively. Test images can be stored under the same dataset folder, in a subfolder of your own naming (an illustrative layout is shown at the end of this README).

## Usage

### Data Preparation

In our experiments, the face data comes from the [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) and the text data from the [BMVC text dataset](http://www.fit.vutbr.cz/~ihradis/CNN-Deblur/). To synthetically blur the images, we use the method proposed in [DeblurGAN](https://github.com/KupynOrest/DeblurGAN/tree/master/motion_blur). Each dataset is randomly split into three subsets: trainA (sharp), trainB (blurred), and a test set.

### Train

To train the model, run the following command in the source code directory. You may set other parameters according to your experimental setting.

```bash
python train.py --dataroot ../datasets/DatasetName/ --name job_name --batch_size 2 --lambdaB 0.1 --lr 0.0002
```

### Test

Our pre-trained models for face and text can be downloaded [here](https://drive.google.com/drive/folders/1P0mP8JjfdV55tDK7a3fIU4yghVmaUJyF?usp=sharing). To test the model, run the following command in the source code directory. You may set other parameters according to your experimental setting. To use the face perceptual loss (`--percep face`), you need to manually set the VGG-Face model path in `networks.py`. The pretrained VGG-Face model can be found [here](https://drive.google.com/drive/folders/1P0mP8JjfdV55tDK7a3fIU4yghVmaUJyF?usp=sharing).

```bash
python test.py --dataroot ../datasets/dataset_name/test_blur/ --num 1 --resume ../results/model/locations --name job_name --orig_dir ../datasets/dataset_name/test_orig --percep face
```

## Citation

If you find the code helpful in your research or work, please kindly cite our paper.

```
@inproceedings{lu2019unsupervised,
  title={Unsupervised domain-specific deblurring via disentangled representations},
  author={Lu, Boyu and Chen, Jun-Cheng and Chellappa, Rama},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={10225--10234},
  year={2019}
}
```

## Acknowledgments

The code borrows heavily from [DRIT](https://github.com/HsinYingLee/DRIT). We use the image blurring method from [DeblurGAN](https://github.com/KupynOrest/DeblurGAN/tree/master/motion_blur).
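
## Illustrative Dataset Layout

For reference, one layout consistent with the commands above ("faces" and the two test folder names are placeholders; only `trainA` and `trainB` are fixed by the data loader):

```
datasets/
└── faces/
    ├── trainA/       # sharp training images
    ├── trainB/       # blurred training images
    ├── test_blur/    # blurred test inputs (passed via --dataroot to test.py)
    └── test_orig/    # sharp references for evaluation (passed via --orig_dir)
```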
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ustclby/Unsupervised-Domain-Specific-Deblurring/c475b054f92a1b274c7dcb06cd5e622427af2cfc/src/.DS_Store
--------------------------------------------------------------------------------
/src/.vscode/settings.json:
--------------------------------------------------------------------------------
{
  "python.linting.pylintEnabled": false,
  "python.linting.pydocstyleEnabled": true,
  "python.linting.enabled": true
}
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
import os
import torch.utils.data as data
from PIL import Image
from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, RandomHorizontalFlip, ToTensor, Normalize
import random

class dataset_single(data.Dataset):
  def __init__(self, opts, setname, input_dim):
    self.dataroot = opts.dataroot
    images = os.listdir(self.dataroot)
    self.img = [os.path.join(self.dataroot, x) for x in images]
    self.size = len(self.img)
    self.input_dim = input_dim
    transforms = [ToTensor()]
    transforms.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    self.transforms = Compose(transforms)
    print('%s: %d images'%(setname, self.size))
    return

  def __getitem__(self, index):
    data = self.load_img(self.img[index], self.input_dim)
    return data

  def load_img(self, img_name, input_dim):
    img = Image.open(img_name).convert('RGB')
    img = self.transforms(img)
    if input_dim == 1:
      img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
      img = img.unsqueeze(0)
    return img, img_name

  def __len__(self):
    return self.size

class dataset_unpair(data.Dataset):
  def __init__(self, opts):
    self.dataroot = opts.dataroot
    # A
    images_A = sorted(os.listdir(os.path.join(self.dataroot, opts.phase + 'A')))
    self.A = [os.path.join(self.dataroot, opts.phase + 'A', x) for x in images_A]
    # B
    images_B = sorted(os.listdir(os.path.join(self.dataroot, opts.phase + 'B')))
    self.B = [os.path.join(self.dataroot, opts.phase + 'B', x) for x in images_B]

    self.A_size = len(self.A)
    self.B_size = len(self.B)
    self.dataset_size = max(self.A_size, self.B_size)
    self.input_dim_A = opts.input_dim_a
    self.input_dim_B = opts.input_dim_b
    self.resize_x = opts.resize_size_x
    self.resize_y = opts.resize_size_y

    if opts.phase == 'train':
      transforms = [RandomCrop(opts.crop_size)]
    else:
      transforms = [CenterCrop(opts.crop_size)]
    if not opts.no_flip:
      transforms.append(RandomHorizontalFlip())
    transforms.append(ToTensor())
    transforms.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    self.transforms = Compose(transforms)
    print('A: %d, B: %d images'%(self.A_size, self.B_size))
    return

  def __getitem__(self, index):
    if self.dataset_size == self.A_size:
      data_A = self.load_img(self.A[index], self.input_dim_A)
      temp_b_index = random.randint(0, self.B_size - 1)
      data_B = self.load_img(self.B[temp_b_index], self.input_dim_B)
    else:
      data_A = self.load_img(self.A[random.randint(0, self.A_size - 1)], self.input_dim_A)
      data_B = self.load_img(self.B[index], self.input_dim_B)
    return data_A, data_B

  def load_img(self, img_name, input_dim):
    img = Image.open(img_name).convert('RGB')
    (w, h) = img.size
    if w < h:
      resize_x = self.resize_x
      resize_y = round(self.resize_x * h / w)
    else:
      resize_y = self.resize_y
      resize_x = round(self.resize_y * w / h)
    resize_img = Compose([Resize((resize_y, resize_x), Image.BICUBIC)])
    img = resize_img(img)
    img = self.transforms(img)
    if input_dim == 1:
      img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
      img = img.unsqueeze(0)
    return img

  def __len__(self):
    return self.dataset_size
--------------------------------------------------------------------------------
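A minimal sketch of how `dataset_unpair` is consumed (this mirrors what `train.py` does below; the `Namespace` stand-in and the dataset path are illustrative, and the option values follow the defaults in `options.py`):

```python
from argparse import Namespace

import torch

from dataset import dataset_unpair

# Illustrative options; train.py builds the equivalent object via TrainOptions.
opts = Namespace(dataroot='../datasets/faces/', phase='train',
                 input_dim_a=3, input_dim_b=3,
                 resize_size_x=144, resize_size_y=144,
                 crop_size=128, no_flip=False)

dataset = dataset_unpair(opts)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
images_a, images_b = next(iter(loader))  # unpaired sharp (A) / blurred (B) crops
```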
/src/model.py:
--------------------------------------------------------------------------------
import networks
import torch
import torch.nn as nn
import time


class UID(nn.Module):
  def __init__(self, opts):
    super(UID, self).__init__()

    # parameters
    lr = opts.lr
    self.nz = 8
    self.concat = opts.concat
    self.lambdaB = opts.lambdaB
    self.lambdaI = opts.lambdaI

    # discriminators
    if opts.dis_scale > 1:
      self.disA = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
      self.disB = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
      self.disA2 = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
      self.disB2 = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
    else:
      self.disA = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
      self.disB = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
      self.disA2 = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
      self.disB2 = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)

    # encoders
    self.enc_c = networks.E_content(opts.input_dim_a, opts.input_dim_b)
    if self.concat:
      self.enc_a = networks.E_attr_concat(opts.input_dim_b, self.nz, \
          norm_layer=None, nl_layer=networks.get_non_linearity(layer_type='lrelu'))
    else:
      self.enc_a = networks.E_attr(opts.input_dim_a, opts.input_dim_b, self.nz)

    # generator
    if self.concat:
      self.gen = networks.G_concat(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
    else:
      self.gen = networks.G(opts.input_dim_a, opts.input_dim_b, nz=self.nz)

    # optimizers
    self.disA_opt = torch.optim.Adam(self.disA.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.disB_opt = torch.optim.Adam(self.disB.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.disA2_opt = torch.optim.Adam(self.disA2.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.disB2_opt = torch.optim.Adam(self.disB2.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)

    self.enc_c_opt = torch.optim.Adam(self.enc_c.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)

    # Setup the loss function for training
    self.criterionL1 = torch.nn.L1Loss()
    if opts.percep == 'default':
      self.perceptualLoss = networks.PerceptualLoss(nn.MSELoss(), opts.gpu, opts.percp_layer)
    elif opts.percep == 'face':
      self.perceptualLoss = networks.PerceptualLoss16(nn.MSELoss(), opts.gpu, opts.percp_layer)
    else:
      self.perceptualLoss = networks.MultiPerceptualLoss(nn.MSELoss(), opts.gpu)

  def initialize(self):
    self.disA.apply(networks.gaussian_weights_init)
    self.disB.apply(networks.gaussian_weights_init)
    self.disA2.apply(networks.gaussian_weights_init)
    self.disB2.apply(networks.gaussian_weights_init)
    self.gen.apply(networks.gaussian_weights_init)
    self.enc_c.apply(networks.gaussian_weights_init)
    self.enc_a.apply(networks.gaussian_weights_init)

  def set_scheduler(self, opts, last_ep=0):
    self.disA_sch = networks.get_scheduler(self.disA_opt, opts, last_ep)
    self.disB_sch = networks.get_scheduler(self.disB_opt, opts, last_ep)
    self.disA2_sch = networks.get_scheduler(self.disA2_opt, opts, last_ep)
    self.disB2_sch = networks.get_scheduler(self.disB2_opt, opts, last_ep)
    self.enc_c_sch = networks.get_scheduler(self.enc_c_opt, opts, last_ep)
    self.enc_a_sch = networks.get_scheduler(self.enc_a_opt, opts, last_ep)
    self.gen_sch = networks.get_scheduler(self.gen_opt, opts, last_ep)

  def setgpu(self, gpu):
    self.gpu = gpu
    self.disA.cuda(self.gpu)
    self.disB.cuda(self.gpu)
    self.disA2.cuda(self.gpu)
    self.disB2.cuda(self.gpu)
    self.enc_c.cuda(self.gpu)
    self.enc_a.cuda(self.gpu)
    self.gen.cuda(self.gpu)

  def get_z_random(self, batchSize, nz, random_type='gauss'):
    z = torch.randn(batchSize, nz).cuda(self.gpu)
    return z

  def test_forward(self, image, a2b=True):
    if a2b:
      self.z_content = self.enc_c.forward_b(image)
      self.mu_b, self.logvar_b = self.enc_a.forward(image)
      std_b = self.logvar_b.mul(0.5).exp_()
      eps_b = self.get_z_random(std_b.size(0), std_b.size(1), 'gauss')
      self.z_attr_b = eps_b.mul(std_b).add_(self.mu_b)
      output = self.gen.forward_D(self.z_content, self.z_attr_b)
    return output

  def forward(self):
    # input images
    real_I = self.input_I
    real_B = self.input_B
    half_size = real_I.size(0) // 2
    self.real_I_encoded = real_I[0:half_size]
    self.real_I_random = real_I[half_size:]
    self.real_B_encoded = real_B[0:half_size]
    self.real_B_random = real_B[half_size:]

    # get encoded z_c
    self.z_content_i, self.z_content_b = self.enc_c.forward(self.real_I_encoded, self.real_B_encoded)

    # get encoded z_a
    if self.concat:
      self.mu_b, self.logvar_b = self.enc_a.forward(self.real_B_encoded)
      std_b = self.logvar_b.mul(0.5).exp_()
      eps_b = self.get_z_random(std_b.size(0), std_b.size(1), 'gauss')
      self.z_attr_b = eps_b.mul(std_b).add_(self.mu_b)
    else:
      self.z_attr_i, self.z_attr_b = self.enc_a.forward(self.real_I_encoded, self.real_B_encoded)

    # get random z_a
    self.z_random = self.get_z_random(self.real_I_encoded.size(0), self.nz, 'gauss')

    # first cross translation
    input_content_forI = torch.cat((self.z_content_b, self.z_content_i, self.z_content_b), 0)
    input_content_forB = torch.cat((self.z_content_i, self.z_content_b, self.z_content_i), 0)
    input_attr_forI = torch.cat((self.z_attr_b, self.z_attr_b, self.z_random), 0)
    input_attr_forB = torch.cat((self.z_attr_b, self.z_attr_b, self.z_random), 0)

    output_fakeI = self.gen.forward_D(input_content_forI, input_attr_forI)
    output_fakeB = self.gen.forward_B(input_content_forB, input_attr_forB)
    self.fake_I_encoded, self.fake_II_encoded, self.fake_I_random = torch.split(output_fakeI, self.z_content_i.size(0), dim=0)
    self.fake_B_encoded, self.fake_BB_encoded, self.fake_B_random = torch.split(output_fakeB, self.z_content_i.size(0), dim=0)

    # get reconstructed encoded z_c
    self.z_content_recon_b, self.z_content_recon_i = self.enc_c.forward(self.fake_I_encoded, self.fake_B_encoded)

    # get reconstructed encoded z_a
    if self.concat:
      self.mu_recon_b, self.logvar_recon_b = self.enc_a.forward(self.fake_B_encoded)
      std_b = self.logvar_recon_b.mul(0.5).exp_()
      eps_b = self.get_z_random(std_b.size(0), std_b.size(1), 'gauss')
      self.z_attr_recon_b = eps_b.mul(std_b).add_(self.mu_recon_b)
    else:
      self.z_attr_recon_i, self.z_attr_recon_b = self.enc_a.forward(self.fake_I_encoded, self.fake_B_encoded)

    # second cross translation
    self.fake_I_recon = self.gen.forward_D(self.z_content_recon_i, self.z_attr_recon_b)
    self.fake_B_recon = self.gen.forward_B(self.z_content_recon_b, self.z_attr_recon_b)

    # for latent regression
    if self.concat:
      self.mu2_b, _ = self.enc_a.forward(self.fake_B_random)
    else:
      self.z_attr_random_i, self.z_attr_random_b = self.enc_a.forward(self.fake_B_random)

  def update_D(self, image_a, image_b):
    self.input_I = image_a
    self.input_B = image_b
    self.forward()

    # update disA
    self.disA_opt.zero_grad()
    loss_D1_A = self.backward_D(self.disA, self.real_I_encoded, self.fake_I_encoded)
    self.disA_loss = loss_D1_A.item()
    self.disA_opt.step()

    # update disA2
    self.disA2_opt.zero_grad()
    loss_D2_A = self.backward_D(self.disA2, self.real_I_random, self.fake_I_random)
    self.disA2_loss = loss_D2_A.item()
    self.disA2_opt.step()

    # update disB
    self.disB_opt.zero_grad()
    loss_D1_B = self.backward_D(self.disB, self.real_B_encoded, self.fake_B_encoded)
    self.disB_loss = loss_D1_B.item()
    self.disB_opt.step()

    # update disB2
    self.disB2_opt.zero_grad()
    loss_D2_B = self.backward_D(self.disB2, self.real_B_random, self.fake_B_random)
    self.disB2_loss = loss_D2_B.item()
    self.disB2_opt.step()

  def backward_D(self, netD, real, fake):
    pred_fake = netD.forward(fake.detach())
    pred_real = netD.forward(real)
    loss_D = 0
    for it, (out_a, out_b) in enumerate(zip(pred_fake, pred_real)):
      out_fake = torch.sigmoid(out_a)
      out_real = torch.sigmoid(out_b)
      all0 = torch.zeros_like(out_fake).cuda(self.gpu)
      all1 = torch.ones_like(out_real).cuda(self.gpu)
      ad_fake_loss = nn.functional.binary_cross_entropy(out_fake, all0)
      ad_true_loss = nn.functional.binary_cross_entropy(out_real, all1)
      loss_D += ad_true_loss + ad_fake_loss
    loss_D.backward()
    return loss_D

  def update_EG(self):
    # update G, Ec, Ea
    self.enc_c_opt.zero_grad()
    self.enc_a_opt.zero_grad()
    self.gen_opt.zero_grad()
    self.backward_EG()
    self.enc_c_opt.step()
    self.enc_a_opt.step()
    self.gen_opt.step()

    # update G, Ec
    self.enc_c_opt.zero_grad()
    self.gen_opt.zero_grad()
    self.backward_G_alone()
    self.enc_c_opt.step()
    self.gen_opt.step()

  def backward_EG(self):
    # Ladv for generator
    loss_G_GAN_I = self.backward_G_GAN(self.fake_I_encoded, self.disA)
    loss_G_GAN_B = self.backward_G_GAN(self.fake_B_encoded, self.disB)

    # KL loss - z_a
    if self.concat:
      kl_element_b = self.mu_b.pow(2).add_(self.logvar_b.exp()).mul_(-1).add_(1).add_(self.logvar_b)
      loss_kl_za_b = torch.sum(kl_element_b).mul_(-0.5) * 0.01
    else:
      loss_kl_za_b = self._l2_regularize(self.z_attr_b) * 0.01

    # cross cycle consistency loss
    loss_G_L1_I = self.criterionL1(self.fake_I_recon, self.real_I_encoded) * 10
    loss_G_L1_B = self.criterionL1(self.fake_B_recon, self.real_B_encoded) * 10
    loss_G_L1_II = self.criterionL1(self.fake_II_encoded, self.real_I_encoded) * 10
    loss_G_L1_BB = self.criterionL1(self.fake_BB_encoded, self.real_B_encoded) * 10

    # perceptual losses
    percp_loss_B = self.perceptualLoss.getloss(self.fake_I_encoded, self.real_B_encoded) * self.lambdaB
    percp_loss_I = self.perceptualLoss.getloss(self.fake_B_encoded, self.real_I_encoded) * self.lambdaI

    loss_G = loss_G_GAN_I + loss_G_GAN_B + \
             loss_G_L1_II + loss_G_L1_BB + \
             loss_G_L1_I + loss_G_L1_B + \
             loss_kl_za_b + percp_loss_B + \
             percp_loss_I

    loss_G.backward(retain_graph=True)

    self.gan_loss_i = loss_G_GAN_I.item()
    self.gan_loss_b = loss_G_GAN_B.item()

    self.kl_loss_za_b = loss_kl_za_b.item()
    self.l1_recon_I_loss = loss_G_L1_I.item()
    self.l1_recon_B_loss = loss_G_L1_B.item()
    self.l1_recon_II_loss = loss_G_L1_II.item()
    self.l1_recon_BB_loss = loss_G_L1_BB.item()
    self.B_percp_loss = percp_loss_B.item()
    self.G_loss = loss_G.item()

  def backward_G_GAN(self, fake, netD=None):
    outs_fake = netD.forward(fake)
    loss_G = 0
    for out_a in outs_fake:
      outputs_fake = torch.sigmoid(out_a)
      all_ones = torch.ones_like(outputs_fake).cuda(self.gpu)
      loss_G += nn.functional.binary_cross_entropy(outputs_fake, all_ones)
    return loss_G

  def backward_G_alone(self):
    # Ladv for generator
    loss_G_GAN2_I = self.backward_G_GAN(self.fake_I_random, self.disA2)
    loss_G_GAN2_B = self.backward_G_GAN(self.fake_B_random, self.disB2)

    # latent regression loss
    if self.concat:
      loss_z_L1_b = torch.mean(torch.abs(self.mu2_b - self.z_random)) * 10
    else:
      loss_z_L1_b = torch.mean(torch.abs(self.z_attr_random_b - self.z_random)) * 10

    # perceptual losses
    percp_loss_B2 = self.perceptualLoss.getloss(self.fake_I_random, self.real_B_encoded) * self.lambdaB
    percp_loss_I2 = self.perceptualLoss.getloss(self.fake_B_random, self.real_I_encoded) * self.lambdaI

    loss_G2 = loss_z_L1_b + loss_G_GAN2_I + loss_G_GAN2_B + percp_loss_B2 + percp_loss_I2
    loss_G2.backward()
    self.gan2_loss_a = loss_G_GAN2_I.item()
    self.gan2_loss_b = loss_G_GAN2_B.item()

  def update_lr(self):
    self.disA_sch.step()
    self.disB_sch.step()
    self.disA2_sch.step()
    self.disB2_sch.step()
    self.enc_c_sch.step()
    self.enc_a_sch.step()
    self.gen_sch.step()

  def _l2_regularize(self, mu):
    mu_2 = torch.pow(mu, 2)
    encoding_loss = torch.mean(mu_2)
    return encoding_loss

  def resume(self, model_dir, train=True):
    checkpoint = torch.load(model_dir, map_location=lambda storage, loc: storage)

    # weight
    if train:
      self.disA.load_state_dict(checkpoint['disA'])
      self.disA2.load_state_dict(checkpoint['disA2'])
      self.disB.load_state_dict(checkpoint['disB'])
      self.disB2.load_state_dict(checkpoint['disB2'])
    self.enc_c.load_state_dict(checkpoint['enc_c'])
    self.enc_a.load_state_dict(checkpoint['enc_a'])
    self.gen.load_state_dict(checkpoint['gen'])
    # optimizer
    if train:
      self.disA_opt.load_state_dict(checkpoint['disA_opt'])
      self.disA2_opt.load_state_dict(checkpoint['disA2_opt'])
      self.disB_opt.load_state_dict(checkpoint['disB_opt'])
      self.disB2_opt.load_state_dict(checkpoint['disB2_opt'])
      self.enc_c_opt.load_state_dict(checkpoint['enc_c_opt'])
      self.enc_a_opt.load_state_dict(checkpoint['enc_a_opt'])
      self.gen_opt.load_state_dict(checkpoint['gen_opt'])
    return checkpoint['ep'], checkpoint['total_it']

  def save(self, filename, ep, total_it):
    state = {
        'disA': self.disA.state_dict(),
        'disA2': self.disA2.state_dict(),
        'disB': self.disB.state_dict(),
        'disB2': self.disB2.state_dict(),
        'enc_c': self.enc_c.state_dict(),
        'enc_a': self.enc_a.state_dict(),
        'gen': self.gen.state_dict(),
        'disA_opt': self.disA_opt.state_dict(),
        'disA2_opt': self.disA2_opt.state_dict(),
        'disB_opt': self.disB_opt.state_dict(),
        'disB2_opt': self.disB2_opt.state_dict(),
        'enc_c_opt': self.enc_c_opt.state_dict(),
        'enc_a_opt': self.enc_a_opt.state_dict(),
        'gen_opt': self.gen_opt.state_dict(),
        'ep': ep,
        'total_it': total_it
    }
    time.sleep(10)
    torch.save(state, filename)
    return

  def assemble_outputs(self):
    images_a = self.normalize_image(self.real_I_encoded).detach()
    images_b = self.normalize_image(self.real_B_encoded).detach()
    images_a1 = self.normalize_image(self.fake_I_encoded).detach()
    images_a2 = self.normalize_image(self.fake_I_random).detach()
    images_a3 = self.normalize_image(self.fake_I_recon).detach()
    images_a4 = self.normalize_image(self.fake_II_encoded).detach()
    images_b1 = self.normalize_image(self.fake_B_encoded).detach()
    images_b2 = self.normalize_image(self.fake_B_random).detach()
    images_b3 = self.normalize_image(self.fake_B_recon).detach()
    images_b4 = self.normalize_image(self.fake_BB_encoded).detach()
    row1 = torch.cat((images_a[0:1, ::], images_b1[0:1, ::], images_b2[0:1, ::], images_a4[0:1, ::], images_a3[0:1, ::]), 3)
    row2 = torch.cat((images_b[0:1, ::], images_a1[0:1, ::], images_a2[0:1, ::], images_b4[0:1, ::], images_b3[0:1, ::]), 3)
    return torch.cat((row1, row2), 2)

  def normalize_image(self, x):
    return x[:, 0:3, :, :]
--------------------------------------------------------------------------------
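The attribute encoder in `model.py` above is sampled with the reparameterization trick, and the KL term in `backward_EG` is the closed form for KL(N(mu, sigma^2) || N(0, I)), scaled by 0.01. A standalone restatement of those two computations, with hypothetical helper names:

```python
import torch

def sample_z(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # z = mu + eps * std, as in UID.forward: std = exp(0.5 * logvar)
    std = (0.5 * logvar).exp()
    eps = torch.randn_like(std)
    return mu + eps * std

def kl_to_standard_normal(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # -0.5 * sum(1 + logvar - mu^2 - exp(logvar)); backward_EG multiplies this by 0.01
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
```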
/src/networks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable
import functools
from torch.optim import lr_scheduler
import torch.nn.functional as F
import torchvision.models as py_models
import numpy
import copy


####################################################################
#------------------------- Discriminators --------------------------
####################################################################
class Dis_content(nn.Module):
  def __init__(self):
    super(Dis_content, self).__init__()
    model = []
    model += [LeakyReLUConv2d(256, 256, kernel_size=4, stride=2, padding=1, norm='Instance')]
    model += [LeakyReLUConv2d(256, 256, kernel_size=4, stride=2, padding=1, norm='Instance')]
    model += [LeakyReLUConv2d(256, 256, kernel_size=4, stride=2, padding=1, norm='Instance')]
    model += [LeakyReLUConv2d(256, 256, kernel_size=4, stride=1, padding=0)]
    model += [nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0)]
    self.model = nn.Sequential(*model)

  def forward(self, x):
    out = self.model(x)
    out = out.view(-1)
    outs = []
    outs.append(out)
    return outs

class MultiScaleDis(nn.Module):
  def __init__(self, input_dim, n_scale=3, n_layer=4, norm='None', sn=False):
    super(MultiScaleDis, self).__init__()
    ch = 64
    self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
    self.Diss = nn.ModuleList()
    for _ in range(n_scale):
      self.Diss.append(self._make_net(ch, input_dim, n_layer, norm, sn))

  def _make_net(self, ch, input_dim, n_layer, norm, sn):
    model = []
    model += [LeakyReLUConv2d(input_dim, ch, 4, 2, 1, norm, sn)]
    tch = ch
    for _ in range(1, n_layer):
      model += [LeakyReLUConv2d(tch, tch * 2, 4, 2, 1, norm, sn)]
      tch *= 2
    if sn:
      model += [spectral_norm(nn.Conv2d(tch, 1, 1, 1, 0))]
    else:
      model += [nn.Conv2d(tch, 1, 1, 1, 0)]
    return nn.Sequential(*model)

  def forward(self, x):
    outs = []
    for Dis in self.Diss:
      outs.append(Dis(x))
      x = self.downsample(x)
    return outs

class Dis(nn.Module):
  def __init__(self, input_dim, norm='None', sn=False):
    super(Dis, self).__init__()
    ch = 64
    n_layer = 6
    self.model = self._make_net(ch, input_dim, n_layer, norm, sn)

  def _make_net(self, ch, input_dim, n_layer, norm, sn):
    model = []
    model += [LeakyReLUConv2d(input_dim, ch, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn)]  # 16
    tch = ch
    for i in range(1, n_layer-1):
      model += [LeakyReLUConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn)]  # 8
      tch *= 2
    model += [LeakyReLUConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1, norm='None', sn=sn)]  # 2
    tch *= 2
    if sn:
      model += [spectral_norm(nn.Conv2d(tch, 1, kernel_size=1, stride=1, padding=0))]  # 1
    else:
      model += [nn.Conv2d(tch, 1, kernel_size=1, stride=1, padding=0)]  # 1
    return nn.Sequential(*model)

  def cuda(self, gpu):
    self.model.cuda(gpu)

  def forward(self, x_A):
    out_A = self.model(x_A)
    out_A = out_A.view(-1)
    outs_A = []
    outs_A.append(out_A)
    return outs_A

####################################################################
#---------------------------- Encoders -----------------------------
####################################################################
class E_content(nn.Module):
  def __init__(self, input_dim_a, input_dim_b):
    super(E_content, self).__init__()
    encA_c = []
    tch = 64
    encA_c += [LeakyReLUConv2d(input_dim_a, tch, kernel_size=7, stride=1, padding=3)]
    for i in range(1, 3):
      encA_c += [ReLUINSConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1)]
      tch *= 2
    for i in range(0, 3):
      encA_c += [INSResBlock(tch, tch)]

    encB_c = []
    tch = 64
    encB_c += [LeakyReLUConv2d(input_dim_b, tch, kernel_size=7, stride=1, padding=3)]
    for i in range(1, 3):
      encB_c += [ReLUINSConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1)]
      tch *= 2
    for i in range(0, 3):
      encB_c += [INSResBlock(tch, tch)]

    enc_share = []
    for i in range(0, 1):
      enc_share += [INSResBlock(tch, tch)]
      enc_share += [GaussianNoiseLayer()]
      self.conv_share = nn.Sequential(*enc_share)

    self.convA = nn.Sequential(*encA_c)
    self.convB = nn.Sequential(*encB_c)

  def forward(self, xa, xb):
    outputA = self.convA(xa)
    outputB = self.convB(xb)
    outputA = self.conv_share(outputA)
    outputB = self.conv_share(outputB)
    return outputA, outputB

  def forward_a(self, xa):
    outputA = self.convA(xa)
    outputA = self.conv_share(outputA)
    return outputA

  def forward_b(self, xb):
    outputB = self.convB(xb)
    outputB = self.conv_share(outputB)
    return outputB

class E_attr_concat(nn.Module):
  def __init__(self, input_dim_b, output_nc=8, norm_layer=None, nl_layer=None):
    super(E_attr_concat, self).__init__()

    ndf = 64
    n_blocks = 4
    max_ndf = 4

    conv_layers_B = [nn.ReflectionPad2d(1)]
    conv_layers_B += [nn.Conv2d(input_dim_b, ndf, kernel_size=4, stride=2, padding=0, bias=True)]
    for n in range(1, n_blocks):
      input_ndf = ndf * min(max_ndf, n)  # 2**(n-1)
      output_ndf = ndf * min(max_ndf, n+1)  # 2**n
      conv_layers_B += [BasicBlock(input_ndf, output_ndf, norm_layer, nl_layer)]
    conv_layers_B += [nl_layer(), nn.AdaptiveAvgPool2d(1)]  # AvgPool2d(13)
    self.fc_B = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
    self.fcVar_B = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
    self.conv_B = nn.Sequential(*conv_layers_B)

  def forward(self, xb):
    x_conv_B = self.conv_B(xb)
    conv_flat_B = x_conv_B.view(xb.size(0), -1)
    output_B = self.fc_B(conv_flat_B)
    outputVar_B = self.fcVar_B(conv_flat_B)
    return output_B, outputVar_B

  def forward_b(self, xb):
    x_conv_B = self.conv_B(xb)
    conv_flat_B = x_conv_B.view(xb.size(0), -1)
    output_B = self.fc_B(conv_flat_B)
    outputVar_B = self.fcVar_B(conv_flat_B)
    return output_B, outputVar_B

####################################################################
#--------------------------- Generators ----------------------------
####################################################################

class G_concat(nn.Module):
  def __init__(self, output_dim_a, output_dim_b, nz):
    super(G_concat, self).__init__()
    self.nz = nz
    tch = 256
    dec_share = []
    dec_share += [INSResBlock(tch, tch)]
    self.dec_share = nn.Sequential(*dec_share)
    tch = 256+self.nz
    decA1 = []
    for i in range(0, 3):
      decA1 += [INSResBlock(tch, tch)]
    tch = tch + self.nz
    decA2 = ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1)
    tch = tch//2
    tch = tch + self.nz
    decA3 = ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1)
    tch = tch//2
    tch = tch + self.nz
    decA4 = [nn.ConvTranspose2d(tch, output_dim_a, kernel_size=1, stride=1, padding=0)]+[nn.Tanh()]
    self.decA1 = nn.Sequential(*decA1)
    self.decA2 = nn.Sequential(*[decA2])
    self.decA3 = nn.Sequential(*[decA3])
    self.decA4 = nn.Sequential(*decA4)

    tch = 256+self.nz
    decB1 = []
    for i in range(0, 3):
      decB1 += [INSResBlock(tch, tch)]
    tch = tch + self.nz
    decB2 = ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1)
    tch = tch//2
    tch = tch + self.nz
    decB3 = ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1)
    tch = tch//2
    tch = tch + self.nz
    decB4 = [nn.ConvTranspose2d(tch, output_dim_b, kernel_size=1, stride=1, padding=0)]+[nn.Tanh()]
    self.decB1 = nn.Sequential(*decB1)
    self.decB2 = nn.Sequential(*[decB2])
    self.decB3 = nn.Sequential(*[decB3])
    self.decB4 = nn.Sequential(*decB4)

  def forward_D(self, x, z):
    out0 = self.dec_share(x)
    z_img = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), x.size(2), x.size(3))
    x_and_z = torch.cat([out0, z_img], 1)
    out1 = self.decA1(x_and_z)
    z_img2 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out1.size(2), out1.size(3))
    x_and_z2 = torch.cat([out1, z_img2], 1)
    out2 = self.decA2(x_and_z2)
    z_img3 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out2.size(2), out2.size(3))
    x_and_z3 = torch.cat([out2, z_img3], 1)
    out3 = self.decA3(x_and_z3)
    z_img4 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out3.size(2), out3.size(3))
    x_and_z4 = torch.cat([out3, z_img4], 1)
    out4 = self.decA4(x_and_z4)
    return out4

  def forward_B(self, x, z):
    out0 = self.dec_share(x)
    z_img = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), x.size(2), x.size(3))
    x_and_z = torch.cat([out0, z_img], 1)
    out1 = self.decB1(x_and_z)
    z_img2 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out1.size(2), out1.size(3))
    x_and_z2 = torch.cat([out1, z_img2], 1)
    out2 = self.decB2(x_and_z2)
    z_img3 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out2.size(2), out2.size(3))
    x_and_z3 = torch.cat([out2, z_img3], 1)
    out3 = self.decB3(x_and_z3)
    z_img4 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out3.size(2), out3.size(3))
    x_and_z4 = torch.cat([out3, z_img4], 1)
    out4 = self.decB4(x_and_z4)
    return out4
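
# NOTE: both decoders above follow the same pattern: the attribute code z is
# spatially broadcast and re-concatenated with the feature maps before every
# stage. A standalone restatement of that pattern (illustrative helper, not
# part of this file):
#
#   def concat_z(feat, z):
#     z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
#         z.size(0), z.size(1), feat.size(2), feat.size(3))
#     return torch.cat([feat, z_img], 1)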

####################################################################
#--------------------------- losses ----------------------------
####################################################################
class PerceptualLoss():
  def __init__(self, loss, gpu=0, p_layer=14):
    super(PerceptualLoss, self).__init__()
    self.criterion = loss

    cnn = py_models.vgg19(pretrained=True).features
    cnn = cnn.cuda()
    model = nn.Sequential()
    model = model.cuda()
    for i, layer in enumerate(list(cnn)):
      model.add_module(str(i), layer)
      if i == p_layer:
        break
    self.contentFunc = model

  def getloss(self, fakeIm, realIm):
    if isinstance(fakeIm, numpy.ndarray):
      fakeIm = torch.from_numpy(fakeIm).permute(2, 0, 1).unsqueeze(0).float().cuda()
      realIm = torch.from_numpy(realIm).permute(2, 0, 1).unsqueeze(0).float().cuda()
    f_fake = self.contentFunc.forward(fakeIm)
    f_real = self.contentFunc.forward(realIm)
    f_real_no_grad = f_real.detach()
    loss = self.criterion(f_fake, f_real_no_grad)
    return loss

class PerceptualLoss16():
  def __init__(self, loss, gpu=0, p_layer=14):
    super(PerceptualLoss16, self).__init__()
    self.criterion = loss
    # conv_3_3_layer = 14
    # set the VGG-Face model path here (see README)
    checkpoint = torch.load('/vggface_path/VGGFace16.pth')
    vgg16 = py_models.vgg16(num_classes=2622)
    vgg16.load_state_dict(checkpoint['state_dict'])
    cnn = vgg16.features
    cnn = cnn.cuda()
    # cnn = cnn.to(gpu)
    model = nn.Sequential()
    model = model.cuda()
    for i, layer in enumerate(list(cnn)):
      # print(layer)
      model.add_module(str(i), layer)
      if i == p_layer:
        break
    self.contentFunc = model
    del vgg16, cnn, checkpoint

  def getloss(self, fakeIm, realIm):
    if isinstance(fakeIm, numpy.ndarray):
      fakeIm = torch.from_numpy(fakeIm).permute(2, 0, 1).unsqueeze(0).float().cuda()
      realIm = torch.from_numpy(realIm).permute(2, 0, 1).unsqueeze(0).float().cuda()

    f_fake = self.contentFunc.forward(fakeIm)
    f_real = self.contentFunc.forward(realIm)
    f_real_no_grad = f_real.detach()
    loss = self.criterion(f_fake, f_real_no_grad)
    return loss

class GradientLoss():
  def __init__(self, loss, n_scale=3):
    super(GradientLoss, self).__init__()
    self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
    self.criterion = loss
    self.n_scale = n_scale

  def grad_xy(self, img):
    gradient_x = img[:, :, :, :-1] - img[:, :, :, 1:]
    gradient_y = img[:, :, :-1, :] - img[:, :, 1:, :]
    return gradient_x, gradient_y

  def getloss(self, fakeIm, realIm):
    loss = 0
    for i in range(self.n_scale):
      fakeIm = self.downsample(fakeIm)
      realIm = self.downsample(realIm)
      grad_fx, grad_fy = self.grad_xy(fakeIm)
      grad_rx, grad_ry = self.grad_xy(realIm)
      loss += pow(4, i) * self.criterion(grad_fx, grad_rx) + self.criterion(grad_fy, grad_ry)
    return loss

class l1GradientLoss():
  def __init__(self, loss, n_scale=3):
    super(l1GradientLoss, self).__init__()
    self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
    self.criterion = loss
    self.n_scale = n_scale

  def grad_xy(self, img):
    gradient_x = img[:, :, :, :-1] - img[:, :, :, 1:]
    gradient_y = img[:, :, :-1, :] - img[:, :, 1:, :]
    return gradient_x, gradient_y

  def getloss(self, fakeIm):
    loss = 0
    for i in range(self.n_scale):
      fakeIm = self.downsample(fakeIm)
      grad_fx, grad_fy = self.grad_xy(fakeIm)
      loss += self.criterion(grad_fx, torch.zeros_like(grad_fx)) + self.criterion(grad_fy, torch.zeros_like(grad_fy))
    return loss

####################################################################
#------------------------- Basic Functions -------------------------
####################################################################
def get_scheduler(optimizer, opts, cur_ep=-1):
  if opts.lr_policy == 'lambda':
    def lambda_rule(ep):
      lr_l = 1.0 - max(0, ep - opts.n_ep_decay) / float(opts.n_ep - opts.n_ep_decay + 1)
      return lr_l
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule, last_epoch=cur_ep)
  elif opts.lr_policy == 'step':
    scheduler = lr_scheduler.StepLR(optimizer, step_size=opts.n_ep_decay, gamma=0.1, last_epoch=cur_ep)
  else:
    raise NotImplementedError('no such learning rate policy')
  return scheduler

def meanpoolConv(inplanes, outplanes):
  sequence = []
  sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
  sequence += [nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0, bias=True)]
  return nn.Sequential(*sequence)

def convMeanpool(inplanes, outplanes):
  sequence = []
  sequence += conv3x3(inplanes, outplanes)
  sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
  return nn.Sequential(*sequence)

def get_norm_layer(layer_type='instance'):
  if layer_type == 'batch':
    norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
  elif layer_type == 'instance':
    norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
  elif layer_type == 'none':
    norm_layer = None
  else:
    raise NotImplementedError('normalization layer [%s] is not found' % layer_type)
  return norm_layer

def get_non_linearity(layer_type='relu'):
  if layer_type == 'relu':
    nl_layer = functools.partial(nn.ReLU, inplace=True)
  elif layer_type == 'lrelu':
    nl_layer = functools.partial(nn.LeakyReLU, negative_slope=0.2, inplace=False)
  elif layer_type == 'elu':
    nl_layer = functools.partial(nn.ELU, inplace=True)
  else:
    raise NotImplementedError('nonlinearity activitation [%s] is not found' % layer_type)
  return nl_layer

def conv3x3(in_planes, out_planes):
  return [nn.ReflectionPad2d(1), nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=0, bias=True)]

def gaussian_weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1 and classname.find('Conv') == 0:
    m.weight.data.normal_(0.0, 0.02)

####################################################################
#-------------------------- Basic Blocks --------------------------
####################################################################

## The code of LayerNorm is modified from MUNIT (https://github.com/NVlabs/MUNIT)
class LayerNorm(nn.Module):
  def __init__(self, n_out, eps=1e-5, affine=True):
    super(LayerNorm, self).__init__()
    self.n_out = n_out
    self.affine = affine
    if self.affine:
      self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
      self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))
    return

  def forward(self, x):
    normalized_shape = x.size()[1:]
    if self.affine:
      return F.layer_norm(x, normalized_shape, self.weight.expand(normalized_shape), self.bias.expand(normalized_shape))
    else:
      return F.layer_norm(x, normalized_shape)

class BasicBlock(nn.Module):
  def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):
    super(BasicBlock, self).__init__()
    layers = []
    if norm_layer is not None:
      layers += [norm_layer(inplanes)]
    layers += [nl_layer()]
    layers += conv3x3(inplanes, inplanes)
    if norm_layer is not None:
      layers += [norm_layer(inplanes)]
    layers += [nl_layer()]
    layers += [convMeanpool(inplanes, outplanes)]
    self.conv = nn.Sequential(*layers)
    self.shortcut = meanpoolConv(inplanes, outplanes)

  def forward(self, x):
    out = self.conv(x) + self.shortcut(x)
    return out

class LeakyReLUConv2d(nn.Module):
  def __init__(self, n_in, n_out, kernel_size, stride, padding=0, norm='None', sn=False):
    super(LeakyReLUConv2d, self).__init__()
    model = []
    model += [nn.ReflectionPad2d(padding)]
    if sn:
      model += [spectral_norm(nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=0, bias=True))]
    else:
      model += [nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=0, bias=True)]
    if norm == 'Instance':
      model += [nn.InstanceNorm2d(n_out, affine=False)]
    model += [nn.LeakyReLU(inplace=True)]
    self.model = nn.Sequential(*model)
    self.model.apply(gaussian_weights_init)
    #elif == 'Group'

  def forward(self, x):
    return self.model(x)

class ReLUINSConv2d(nn.Module):
  def __init__(self, n_in, n_out, kernel_size, stride, padding=0):
    super(ReLUINSConv2d, self).__init__()
    model = []
    model += [nn.ReflectionPad2d(padding)]
    model += [nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=0, bias=True)]
    model += [nn.InstanceNorm2d(n_out, affine=False)]
    model += [nn.ReLU(inplace=True)]
    self.model = nn.Sequential(*model)
    self.model.apply(gaussian_weights_init)

  def forward(self, x):
    return self.model(x)

class INSResBlock(nn.Module):
  def conv3x3(self, inplanes, out_planes, stride=1):
    return [nn.ReflectionPad2d(1), nn.Conv2d(inplanes, out_planes, kernel_size=3, stride=stride)]

  def __init__(self, inplanes, planes, stride=1, dropout=0.0):
    super(INSResBlock, self).__init__()
    model = []
    model += self.conv3x3(inplanes, planes, stride)
    model += [nn.InstanceNorm2d(planes)]
    model += [nn.ReLU(inplace=True)]
    model += self.conv3x3(planes, planes)
    model += [nn.InstanceNorm2d(planes)]
    if dropout > 0:
      model += [nn.Dropout(p=dropout)]
    self.model = nn.Sequential(*model)
    self.model.apply(gaussian_weights_init)

  def forward(self, x):
    residual = x
    out = self.model(x)
    out += residual
    return out

class MisINSResBlock(nn.Module):
  def conv3x3(self, dim_in, dim_out, stride=1):
    return nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=stride))

  def conv1x1(self, dim_in, dim_out):
    return nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)

  def __init__(self, dim, dim_extra, stride=1, dropout=0.0):
    super(MisINSResBlock, self).__init__()
    self.conv1 = nn.Sequential(
        self.conv3x3(dim, dim, stride),
        nn.InstanceNorm2d(dim))
    self.conv2 = nn.Sequential(
        self.conv3x3(dim, dim, stride),
        nn.InstanceNorm2d(dim))
    self.blk1 = nn.Sequential(
        self.conv1x1(dim + dim_extra, dim + dim_extra),
        nn.ReLU(inplace=False),
        self.conv1x1(dim + dim_extra, dim),
        nn.ReLU(inplace=False))
    self.blk2 = nn.Sequential(
        self.conv1x1(dim + dim_extra, dim + dim_extra),
        nn.ReLU(inplace=False),
        self.conv1x1(dim + dim_extra, dim),
        nn.ReLU(inplace=False))
    model = []
    if dropout > 0:
      model += [nn.Dropout(p=dropout)]
    self.model = nn.Sequential(*model)
    self.model.apply(gaussian_weights_init)
    self.conv1.apply(gaussian_weights_init)
    self.conv2.apply(gaussian_weights_init)
    self.blk1.apply(gaussian_weights_init)
    self.blk2.apply(gaussian_weights_init)

  def forward(self, x, z):
    residual = x
    z_expand = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), x.size(2), x.size(3))
    o1 = self.conv1(x)
    o2 = self.blk1(torch.cat([o1, z_expand], dim=1))
    o3 = self.conv2(o2)
    out = self.blk2(torch.cat([o3, z_expand], dim=1))
    out += residual
    return out

class GaussianNoiseLayer(nn.Module):
  def __init__(self,):
    super(GaussianNoiseLayer, self).__init__()

  def forward(self, x):
    if self.training == False:
      return x
    noise = Variable(torch.randn(x.size()).cuda(x.get_device()))
    return x + noise

class ReLUINSConvTranspose2d(nn.Module):
  def __init__(self, n_in, n_out, kernel_size, stride, padding, output_padding):
    super(ReLUINSConvTranspose2d, self).__init__()
    model = []
    model += [nn.ConvTranspose2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=True)]
    model += [LayerNorm(n_out)]
    model += [nn.ReLU(inplace=True)]
    self.model = nn.Sequential(*model)
    self.model.apply(gaussian_weights_init)

  def forward(self, x):
    return self.model(x)


####################################################################
#--------------------- Spectral Normalization ---------------------
# This part of code is copied from pytorch master branch (0.5.0)
####################################################################
class SpectralNorm(object):
  def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
    self.name = name
    self.dim = dim
    if n_power_iterations <= 0:
      raise ValueError('Expected n_power_iterations to be positive, but '
                       'got n_power_iterations={}'.format(n_power_iterations))
    self.n_power_iterations = n_power_iterations
    self.eps = eps

  def compute_weight(self, module):
    weight = getattr(module, self.name + '_orig')
    u = getattr(module, self.name + '_u')
    weight_mat = weight
    if self.dim != 0:
      # permute dim to front
      weight_mat = weight_mat.permute(self.dim,
                                      *[d for d in range(weight_mat.dim()) if d != self.dim])
    height = weight_mat.size(0)
    weight_mat = weight_mat.reshape(height, -1)
    with torch.no_grad():
      for _ in range(self.n_power_iterations):
        v = F.normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
        u = F.normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
    sigma = torch.dot(u, torch.matmul(weight_mat, v))
    weight = weight / sigma
    return weight, u

  def remove(self, module):
    weight = getattr(module, self.name)
    delattr(module, self.name)
    delattr(module, self.name + '_u')
    delattr(module, self.name + '_orig')
    module.register_parameter(self.name, torch.nn.Parameter(weight))

  def __call__(self, module, inputs):
    if module.training:
      weight, u = self.compute_weight(module)
      setattr(module, self.name, weight)
      setattr(module, self.name + '_u', u)
    else:
      r_g = getattr(module, self.name + '_orig').requires_grad
      getattr(module, self.name).detach_().requires_grad_(r_g)

  @staticmethod
  def apply(module, name, n_power_iterations, dim, eps):
    fn = SpectralNorm(name, n_power_iterations, dim, eps)
    weight = module._parameters[name]
    height = weight.size(dim)
    u = F.normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
    delattr(module, fn.name)
    module.register_parameter(fn.name + "_orig", weight)
    module.register_buffer(fn.name, weight.data)
    module.register_buffer(fn.name + "_u", u)
    module.register_forward_pre_hook(fn)
    return fn

def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
  if dim is None:
    if isinstance(module, (torch.nn.ConvTranspose1d,
                           torch.nn.ConvTranspose2d,
                           torch.nn.ConvTranspose3d)):
      dim = 1
    else:
      dim = 0
  SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
  return module

def remove_spectral_norm(module, name='weight'):
  for k, hook in module._forward_pre_hooks.items():
    if isinstance(hook, SpectralNorm) and hook.name == name:
      hook.remove(module)
      del module._forward_pre_hooks[k]
      return module
  raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
--------------------------------------------------------------------------------
/src/options.py:
--------------------------------------------------------------------------------
import argparse

class TrainOptions():
  def __init__(self):
    self.parser = argparse.ArgumentParser()

    # data loader related
    self.parser.add_argument('--dataroot', type=str, required=True, help='path of data')
    self.parser.add_argument('--phase', type=str, default='train', help='phase for dataloading')
    self.parser.add_argument('--batch_size', type=int, default=2, help='batch size')
    self.parser.add_argument('--resize_size_x', type=int, default=144, help='resized image size for training, 144 for face')
    self.parser.add_argument('--resize_size_y', type=int, default=144, help='resized image size for training')
    self.parser.add_argument('--crop_size', type=int, default=128, help='cropped image size for training, 128 for face')
    self.parser.add_argument('--input_dim_a', type=int, default=3, help='# of input channels for domain A')
    self.parser.add_argument('--input_dim_b', type=int, default=3, help='# of input channels for domain B')
    self.parser.add_argument('--nThreads', type=int, default=8, help='# of threads for data loader')
    self.parser.add_argument('--no_flip', action='store_true', help='set to disable random flipping')

    # output related
    self.parser.add_argument('--name', type=str, default='trial', help='folder name to save outputs')
    self.parser.add_argument('--display_dir', type=str, default='../logs', help='path for saving display results')
    self.parser.add_argument('--result_dir', type=str, default='../results', help='path for saving result images and models')
    self.parser.add_argument('--img_save_freq', type=int, default=1, help='freq (epoch) of saving images')
    self.parser.add_argument('--model_save_freq', type=int, default=1, help='freq (epoch) of saving models')

    # training related
    self.parser.add_argument('--concat', type=int, default=1, help='concatenate attribute features for translation, set 0 for using feature-wise transform')
    self.parser.add_argument('--dis_scale', type=int, default=3, help='scale of discriminator')
    self.parser.add_argument('--dis_norm', type=str, default='None', help='normalization layer in discriminator [None, Instance]')
    self.parser.add_argument('--dis_spectral_norm', action='store_true', help='use spectral normalization in discriminator')
    self.parser.add_argument('--lr_policy', type=str, default='lambda', help='type of learning rate decay')
    self.parser.add_argument('--n_ep', type=int, default=80, help='number of epochs')  # 400 * d_iter
    self.parser.add_argument('--n_ep_decay', type=int, default=40, help='epoch to start decaying the learning rate, set -1 for no decay')
    self.parser.add_argument('--resume', type=str, default=None, help='directory of saved models for resuming training')
    self.parser.add_argument('--d_iter', type=int, default=3, help='# of iterations for updating content discriminator')
    self.parser.add_argument('--gpu', type=str, default='cuda', help='gpu ids: e.g. 0, 0,1,2, 0,2')
    self.parser.add_argument('--lambdaB', type=float, default=0.1, help='perceptual loss weight for B')
    self.parser.add_argument('--lambdaI', type=float, default=0, help='perceptual loss weight for I')
    self.parser.add_argument('--percp_layer', type=int, default=14, help='the feature layer used for the perceptual loss')
    self.parser.add_argument('--percep', type=str, default='default', help='type of perceptual loss: default, face, multi')
    self.parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')

  def parse(self):
    self.opt = self.parser.parse_args()
    args = vars(self.opt)
    print('\n--- load options ---')
    for name, value in sorted(args.items()):
      print('%s: %s' % (str(name), str(value)))
    return self.opt

class TestOptions():
  def __init__(self):
    self.parser = argparse.ArgumentParser()

    # data loader related
    self.parser.add_argument('--dataroot', type=str, required=True, help='path of data')
    self.parser.add_argument('--phase', type=str, default='test', help='phase for dataloading')
    self.parser.add_argument('--resize_size', type=int, default=256, help='resized image size for testing')
    self.parser.add_argument('--crop_size', type=int, default=128, help='cropped image size for testing')
    self.parser.add_argument('--nThreads', type=int, default=4, help='# of threads for data loader')
    self.parser.add_argument('--input_dim_a', type=int, default=3, help='# of input channels for domain A')
    self.parser.add_argument('--input_dim_b', type=int, default=3, help='# of input channels for domain B')
    self.parser.add_argument('--a2b', type=int, default=1, help='translation direction, 1 for a2b, 0 for b2a')
    self.parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')

    # output related
    self.parser.add_argument('--num', type=int, default=5, help='number of outputs per image')
    self.parser.add_argument('--name', type=str, default='trial', help='folder name to save outputs')
    self.parser.add_argument('--result_dir', type=str, default='../outputs', help='path for saving result images and models')
    self.parser.add_argument('--orig_dir', type=str, default='../outputs', help='path of the original sharp images used for evaluation')

    # model related
    self.parser.add_argument('--concat', type=int, default=1, help='concatenate attribute features for translation, set 0 for using feature-wise transform')
    self.parser.add_argument('--resume', type=str, required=True, help='directory of the saved model to load')
    self.parser.add_argument('--gpu', type=int, default=0, help='gpu')
    self.parser.add_argument('--lambdaB', type=float, default=0.1, help='perceptual loss weight for B')
    self.parser.add_argument('--lambdaI', type=float, default=10, help='color loss weight')
    self.parser.add_argument('--percep', type=str, default='default', help='type of perceptual loss: default (vgg19), face (vggface), multi')
    self.parser.add_argument('--percp_layer', type=int, default=14, help='the feature layer used for the perceptual loss')

  def parse(self):
    self.opt = self.parser.parse_args()
    args = vars(self.opt)
    print('\n--- load options ---')
    for name, value in sorted(args.items()):
      print('%s: %s' % (str(name), str(value)))
    # set irrelevant options
    self.opt.dis_scale = 3
    self.opt.dis_norm = 'None'
    self.opt.dis_spectral_norm = False
    return self.opt
--------------------------------------------------------------------------------
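With the defaults above (`--lr_policy lambda`, `n_ep=80`, `n_ep_decay=40`), the schedule built by `networks.get_scheduler` keeps the learning rate constant for the first 40 epochs and then decays it linearly; a small numeric check of the rule:

```python
def lambda_rule(ep, n_ep=80, n_ep_decay=40):
    # multiplier applied to the base lr, as in networks.get_scheduler
    return 1.0 - max(0, ep - n_ep_decay) / float(n_ep - n_ep_decay + 1)

print(lambda_rule(0))    # 1.0   (constant phase)
print(lambda_rule(40))   # 1.0   (decay starts after this epoch)
print(lambda_rule(60))   # ~0.51
print(lambda_rule(79))   # ~0.05 (last epoch)
```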
0 0,1,2, 0,2') 37 | self.parser.add_argument('--lambdaB', type=float, default=0.1, help='perceptual loss weight for B') 38 | self.parser.add_argument('--lambdaI', type=float, default=0, help='perceptual loss weight for I') 39 | self.parser.add_argument('--percp_layer', type=int, default=14, help='the layer of feature for perceptual loss') 40 | self.parser.add_argument('--percep', type=str, default='default', help='type of perceptual loss: default, face, multi') 41 | self.parser.add_argument('--lr', type=float, default=0.0002, help='learning rate') 42 | 43 | def parse(self): 44 | self.opt = self.parser.parse_args() 45 | args = vars(self.opt) 46 | print('\n--- load options ---') 47 | for name, value in sorted(args.items()): 48 | print('%s: %s' % (str(name), str(value))) 49 | return self.opt 50 | 51 | class TestOptions(): 52 | def __init__(self): 53 | self.parser = argparse.ArgumentParser() 54 | 55 | # data loader related 56 | self.parser.add_argument('--dataroot', type=str, required=True, help='path of data') 57 | self.parser.add_argument('--phase', type=str, default='test', help='phase for dataloading') 58 | self.parser.add_argument('--resize_size', type=int, default=256, help='resized image size for training') 59 | self.parser.add_argument('--crop_size', type=int, default=128, help='cropped image size for training') 60 | self.parser.add_argument('--nThreads', type=int, default=4, help='for data loader') 61 | self.parser.add_argument('--input_dim_a', type=int, default=3, help='# of input channels for domain A') 62 | self.parser.add_argument('--input_dim_b', type=int, default=3, help='# of input channels for domain B') 63 | self.parser.add_argument('--a2b', type=int, default=1, help='translation direction, 1 for a2b, 0 for b2a') 64 | self.parser.add_argument('--lr', type=float, default=0.0002, help='learning rate') 65 | 66 | # ouptput related 67 | self.parser.add_argument('--num', type=int, default=5, help='number of outputs per image') 68 | self.parser.add_argument('--name', type=str, default='trial', help='folder name to save outputs') 69 | self.parser.add_argument('--result_dir', type=str, default='../outputs', help='path for saving result images and models') 70 | self.parser.add_argument('--orig_dir', type=str, default='../outputs', help='path for saving result images and models') 71 | 72 | # model related 73 | self.parser.add_argument('--concat', type=int, default=1, help='concatenate attribute features for translation, set 0 for using feature-wise transform') 74 | self.parser.add_argument('--resume', type=str, required=True, help='specified the dir of saved models for resume the training') 75 | self.parser.add_argument('--gpu', type=int, default=0, help='gpu') 76 | self.parser.add_argument('--lambdaB', type=float, default=0.1, help='perceptual loss weight for B') 77 | self.parser.add_argument('--lambdaI', type=float, default=10, help='color loss weight') 78 | self.parser.add_argument('--percep', type=str, default='default', help='type of perceptual loss: default(vgg19), face(vggface), multi') 79 | self.parser.add_argument('--percp_layer', type=int, default=14, help='the layer of feature for perceptual loss') 80 | 81 | 82 | def parse(self): 83 | self.opt = self.parser.parse_args() 84 | args = vars(self.opt) 85 | print('\n--- load options ---') 86 | for name, value in sorted(args.items()): 87 | print('%s: %s' % (str(name), str(value))) 88 | # set irrelevant options 89 | self.opt.dis_scale = 3 90 | self.opt.dis_norm = 'None' 91 | self.opt.dis_spectral_norm = False 92 | return self.opt 93 
--------------------------------------------------------------------------------
/src/saver.py:
--------------------------------------------------------------------------------
import os
import torchvision
import numpy as np
from PIL import Image

# convert a (1, C, H, W) tensor in [-1, 1] to an (H, W, C) uint8 array
def tensor2img(img):
    img = img[0].cpu().float().numpy()
    if img.shape[0] == 1:
        img = np.tile(img, (3, 1, 1))
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
    return img.astype(np.uint8)

# save a set of images
def save_imgs(imgs, names, path, yuv=False):
    if not os.path.exists(path):
        os.mkdir(path)
    img = tensor2img(imgs)
    img = Image.fromarray(img)
    img.save(os.path.join(path, names))

class Saver():
    def __init__(self, opts):
        self.display_dir = os.path.join(opts.display_dir, opts.name)
        self.model_dir = os.path.join(opts.result_dir, opts.name)
        self.image_dir = os.path.join(self.model_dir, 'images')
        self.img_save_freq = opts.img_save_freq
        self.model_save_freq = opts.model_save_freq
        # make directories
        if not os.path.exists(self.display_dir):
            os.makedirs(self.display_dir)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.image_dir):
            os.makedirs(self.image_dir)

    # save result images; ep == -1 marks the final snapshot
    def write_img(self, ep, model):
        if ep == -1:
            assembled_images = model.assemble_outputs()
            img_filename = '%s/gen_last.jpg' % self.image_dir
            torchvision.utils.save_image(assembled_images / 2 + 0.5, img_filename, nrow=1)
        elif (ep + 1) % self.img_save_freq == 0:
            assembled_images = model.assemble_outputs()
            img_filename = '%s/gen_%05d.jpg' % (self.image_dir, ep)
            torchvision.utils.save_image(assembled_images / 2 + 0.5, img_filename, nrow=1)

    # save model; ep == -1 marks the final snapshot
    def write_model(self, ep, total_it, model):
        if ep == -1:
            model.save('%s/last.pth' % self.model_dir, ep, total_it)
        elif (ep + 1) % self.model_save_freq == 0:
            print('--- save the model @ ep %d ---' % (ep))
            model.save('%s/%05d.pth' % (self.model_dir, ep), ep, total_it)
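For reference, `tensor2img` expects a batched `(1, C, H, W)` tensor in `[-1, 1]` and returns an `(H, W, C)` `uint8` array, and `save_imgs` only calls `os.mkdir`, so the parent of `path` must already exist. A small sketch with a dummy tensor (the output directory here is a hypothetical placeholder):

```python
# Sketch of the saver helpers on a dummy tensor in [-1, 1].
import torch
from saver import tensor2img, save_imgs

fake = torch.rand(1, 3, 128, 128) * 2 - 1    # stand-in for a generator output
arr = tensor2img(fake)                       # (128, 128, 3) uint8 array
save_imgs(fake, 'sample.png', './demo_out')  # parent of './demo_out' must exist
```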
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn

from dataset import dataset_single
from model import UID
from networks import PerceptualLoss, PerceptualLoss16, MultiPerceptualLoss
from options import TestOptions
from saver import save_imgs
from skimage.measure import compare_psnr as PSNR
from skimage.measure import compare_ssim as SSIM
from skimage.io import imread

def main():
    # parse options
    parser = TestOptions()
    opts = parser.parse()
    result_dir = os.path.join(opts.result_dir, opts.name)
    orig_dir = opts.orig_dir
    blur_dir = opts.dataroot

    if not os.path.exists(result_dir):
        os.mkdir(result_dir)

    # data loader
    print('\n--- load dataset ---')
    if opts.a2b:
        dataset = dataset_single(opts, 'A', opts.input_dim_a)
    else:
        dataset = dataset_single(opts, 'B', opts.input_dim_b)
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=opts.nThreads)

    # model
    print('\n--- load model ---')
    model = UID(opts)
    model.setgpu(opts.gpu)
    model.resume(opts.resume, train=False)
    model.eval()

    # test: run every image through the generator and save the estimate
    print('\n--- testing ---')
    for idx1, (img1, img_name) in enumerate(loader):
        print('{}/{}'.format(idx1, len(loader)))
        img1 = img1.cuda(opts.gpu).detach()
        with torch.no_grad():
            img = model.test_forward(img1, a2b=opts.a2b)
        img_name = img_name[0].split('/')[-1]
        save_imgs(img, img_name, result_dir)

    # evaluate metrics
    if opts.percep == 'default':
        pLoss = PerceptualLoss(nn.MSELoss(), p_layer=36)
    elif opts.percep == 'face':
        pLoss = PerceptualLoss16(nn.MSELoss(), p_layer=30)
    else:
        pLoss = MultiPerceptualLoss(nn.MSELoss())

    orig_list = sorted(os.listdir(orig_dir))
    deblur_list = sorted(os.listdir(result_dir))
    blur_list = sorted(os.listdir(blur_dir))

    psnr, ssim, percp = [], [], []
    blur_psnr, blur_ssim, blur_percp = [], [], []

    for deblur_img_name, orig_img_name, blur_img_name in zip(deblur_list, orig_list, blur_list):
        deblur_img_name = os.path.join(result_dir, deblur_img_name)
        orig_img_name = os.path.join(orig_dir, orig_img_name)
        blur_img_name = os.path.join(blur_dir, blur_img_name)
        deblur_img = imread(deblur_img_name)
        orig_img = imread(orig_img_name)
        blur_img = imread(blur_img_name)
        try:
            psnr.append(PSNR(deblur_img, orig_img))
            ssim.append(SSIM(deblur_img, orig_img, multichannel=True))
            blur_psnr.append(PSNR(blur_img, orig_img))
            blur_ssim.append(SSIM(blur_img, orig_img, multichannel=True))
        except ValueError:
            print(orig_img_name)

        with torch.no_grad():
            percp.append(pLoss.getloss(deblur_img, orig_img))
            blur_percp.append(pLoss.getloss(blur_img, orig_img))

    # averages for deblurred outputs vs. ground truth
    print(sum(psnr) / len(psnr))
    print(sum(ssim) / len(ssim))
    print(sum(percp) / len(percp))

    # baseline: blurred inputs vs. ground truth
    print(sum(blur_psnr) / len(blur_psnr))
    print(sum(blur_ssim) / len(blur_ssim))
    print(sum(blur_percp) / len(blur_percp))
    return

if __name__ == '__main__':
    main()
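One portability note on the metrics: `compare_psnr` and `compare_ssim` were deprecated in scikit-image 0.16 and removed in 0.18. On newer versions, the equivalent imports keep the rest of the script unchanged (a drop-in sketch; on scikit-image >= 0.19, `multichannel=True` also becomes `channel_axis=-1`):

```python
# Drop-in aliases for scikit-image >= 0.18, matching the names used above.
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM

# usage as in test.py; on scikit-image >= 0.19 prefer channel_axis:
# ssim.append(SSIM(deblur_img, orig_img, channel_axis=-1))
```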
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
import torch
from options import TrainOptions
from dataset import dataset_unpair
from model import UID
from saver import Saver

def main():
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    dataset = dataset_unpair(opts)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=opts.batch_size, shuffle=True, num_workers=opts.nThreads)

    # model
    print('\n--- load model ---')
    model = UID(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    for ep in range(ep0, opts.n_ep):
        for it, (images_a, images_b) in enumerate(train_loader):
            # skip incomplete batches
            if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size:
                continue
            images_a = images_a.cuda(opts.gpu).detach()
            images_b = images_b.cuda(opts.gpu).detach()

            # update discriminators every batch; update encoder/generator
            # only every second batch and on the last batch of the epoch
            model.update_D(images_a, images_b)
            if (it + 1) % 2 != 0 and it != len(train_loader) - 1:
                continue
            model.update_EG()

            # write losses to the display file
            if (it + 1) % 48 == 0:
                print('total_it: %d (ep %d, it %d), lr %08f' % (total_it + 1, ep, it + 1, model.gen_opt.param_groups[0]['lr']))
                print('Dis_I_loss: %04f, Dis_B_loss %04f, GAN_loss_I %04f, GAN_loss_B %04f' % (model.disA_loss, model.disB_loss, model.gan_loss_i, model.gan_loss_b))
                print('B_percp_loss %04f, Recon_II_loss %04f' % (model.B_percp_loss, model.l1_recon_II_loss))
            if (it + 1) % 200 == 0:
                saver.write_img(ep * len(train_loader) + (it + 1), model)

            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                saver.write_model(-1, total_it, model)
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save network weights
        saver.write_model(ep, total_it + 1, model)

    return

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
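The training loop above interleaves updates: the discriminators are stepped on every batch, while the encoder/generator pair is stepped only on every second batch (and on the final batch of each epoch). A stripped-down sketch of that schedule, assuming any `model` exposing the same `update_D`/`update_EG` interface:

```python
# Stripped-down sketch of the alternating update schedule in train.py.
def alternate_updates(model, train_loader):
    for it, (images_a, images_b) in enumerate(train_loader):
        model.update_D(images_a, images_b)   # discriminators: every batch
        last = (it == len(train_loader) - 1)
        if (it + 1) % 2 == 0 or last:        # encoder/generator: every 2nd batch
            model.update_EG()
```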