├── README.md ├── datasets_torch.py ├── examples ├── CUB_a_150000.jpg ├── CUB_v_150000.jpg ├── FLO_a_150000.jpg ├── FLO_v_150000.jpg ├── aa └── framework.jpg ├── main.py ├── model.py └── trainer.py /README.md: -------------------------------------------------------------------------------- 1 | # ZstGAN-PyTorch 2 | PyTorch Implementation of "ZstGAN: An Adversarial Approach for Unsupervised Zero-Shot Image-to-Image Translation" 3 | 4 | # Dependencies: 5 | Python 3.6 6 | 7 | PyTorch 0.4.0 8 | 9 | # Usage: 10 | ### Unsupervised Zero-Shot Image-to-Image Translation 11 | 1. Download the CUB and FLO training and testing datasets from [CUB and FLO](https://pan.baidu.com/s/1m4a4PFpjFNMNLIdE8TlYAQ) (password: `n6qd`). Alternatively, follow [StackGAN](https://github.com/hanzhanggit/StackGAN) to prepare these two datasets. 12 | 13 | 2. Unzip Data.zip and organize the CUB and FLO training and testing sets as: 14 | 15 | Data 16 | ├── flowers 17 | | ├── train 18 | | ├── test 19 | | └── ... 20 | ├── birds 21 | ├── train 22 | ├── test 23 | └── ... 24 | 25 | 3. Train ZstGAN on seen domains of FLO: 26 | 27 | `$ python main.py --mode train --model_dir flower --datadir Data/flowers/ --c_dim 102 --batch_size 8 --nz_num 312 --ft_num 2048 --lambda_mut 200` 28 | 4. Train ZstGAN on seen domains of CUB: 29 | 30 | `$ python main.py --mode train --model_dir bird --datadir Data/birds/ --c_dim 200 --batch_size 8 --nz_num 312 --ft_num 2048 --lambda_mut 50` 31 | 5. Test ZstGAN on unseen domains of FLO at iteration 200000: 32 | 33 | `$ python main.py --mode test --model_dir flower --datadir Data/flowers/ --c_dim 102 --test_iters 200000` 34 | 6. Test ZstGAN on unseen domains of CUB at iteration 200000: 35 | 36 | `$ python main.py --mode test --model_dir bird --datadir Data/birds/ --c_dim 200 --test_iters 200000` 37 | # Results: 38 | ### 1. 
# Image translation on unseen domains of FLO at iterations 150000:
#
# **# Results of V-ZstGAN**:
#
# **# Results of A-ZstGAN**:
#
# ### 2. Image translation on unseen domains of CUB at iterations 150000:
#
# **# Results of V-ZstGAN**:
#
# **# Results of A-ZstGAN**:
#
# --------------------------------------------------------------------------------
# /datasets_torch.py:
# --------------------------------------------------------------------------------
from __future__ import division
from __future__ import print_function


import numpy as np
import pickle
import random
import sys

import torch


class Dataset(object):
    """In-memory image/text-embedding dataset that serves shuffled mini-batches as torch tensors.

    Args:
        images: per-example images indexed along axis 0. Batches index it as
            ``images[ids, :, :, :]`` and transpose axes ``(0, 3, 1, 2)``, so it
            is expected to be ``(N, H, W, C)``; values are rescaled with
            ``x * 2/255 - 1`` (presumably uint8 pixels in [0, 255]).
        imsize: spatial size used to allocate mirrored copies in ``transform``.
        embeddings: optional per-example embeddings, ``(N, dim)`` or
            ``(N, num_captions, dim)``.
        filenames: optional per-example file names (used by ``readCaptions``).
        workdir: dataset root; ``readCaptions`` reads ``<workdir>/text_c10/...``.
        labels: optional per-example labels; batches return ``label - 1``.
        aug_flag: if True, ``transform`` mirrors each image left-right with p=0.5.
        class_id: per-example class ids, used to pick mismatched "wrong" images.
        class_range: stored but not used by this class.
    """

    def __init__(self, images, imsize, embeddings=None,
                 filenames=None, workdir=None,
                 labels=None, aug_flag=True,
                 class_id=None, class_range=None):
        self._images = images
        self._embeddings = embeddings
        self._filenames = filenames
        self.workdir = workdir
        self._labels = labels
        self._epochs_completed = -1
        self._num_examples = len(images)
        self._saveIDs = self.saveIDs()

        # Start "at the end" of an epoch so the first next_batch() call rolls
        # over, reshuffles and bumps _epochs_completed from -1 to 0.
        self._index_in_epoch = self._num_examples
        self._aug_flag = aug_flag
        self._class_id = np.array(class_id)
        self._class_range = class_range
        self._imsize = imsize
        self._perm = np.arange(self._num_examples)
        np.random.shuffle(self._perm)

    def reinitialize_index(self):
        """Rewind the read position to the start of the current permutation."""
        self._index_in_epoch = 0
        return None

    @property
    def images(self):
        return self._images

    @property
    def embeddings(self):
        return self._embeddings

    @property
    def filenames(self):
        return self._filenames

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def saveIDs(self):
        """Store and return a fresh random permutation of all example indices."""
        self._saveIDs = np.arange(self._num_examples)
        np.random.shuffle(self._saveIDs)
        return self._saveIDs

    def readCaptions(self, filenames, class_id):
        """Return the non-empty caption lines for one example.

        Args:
            filenames: the example's file name (a single string).
            class_id: the example's class id; used to rewrite flower paths.
        """
        name = filenames
        if name.find('jpg/') != -1:  # flowers dataset
            class_name = 'class_%05d/' % class_id
            name = name.replace('jpg/', class_name)
        cap_path = '%s/text_c10/%s.txt' % (self.workdir, name)
        with open(cap_path, "r") as f:
            captions = f.read().split('\n')
        captions = [cap for cap in captions if len(cap) > 0]
        return captions

    def transform(self, images):
        """Randomly mirror each image left-right when augmentation is enabled.

        With ``aug_flag`` set, returns a new float64 array of shape
        ``(n, imsize, imsize, 3)``; otherwise returns ``images`` unchanged.
        """
        if self._aug_flag:
            transformed_images = \
                np.zeros([images.shape[0], self._imsize, self._imsize, 3])
            for i in range(images.shape[0]):
                if random.random() > 0.5:
                    transformed_images[i] = np.fliplr(images[i])
                else:
                    transformed_images[i] = images[i]
            return transformed_images
        else:
            return images

    def sample_embeddings(self, embeddings, filenames, class_id, sample_num):
        """Pick one caption embedding per example, or average ``sample_num`` of them.

        Args:
            embeddings: batch embeddings, ``(batch, dim)`` or
                ``(batch, num_captions, dim)``.
            filenames, class_id: per-example metadata. Currently unused; kept so
                existing callers keep working.
            sample_num: captions drawn per example (without replacement).

        Returns:
            ``(embeddings, captions)``. ``embeddings`` is passed through
            ``np.squeeze``, so every size-1 axis is dropped, including the batch
            axis when the batch has one example. ``captions`` is currently
            always an empty list.
        """
        if len(embeddings.shape) == 2 or embeddings.shape[1] == 1:
            # Only one embedding per example: nothing to sample. Return the same
            # (embeddings, captions) pair as the 3-D path so callers can unpack.
            return np.squeeze(embeddings), []

        batch_size, embedding_num, _ = embeddings.shape
        sampled_embeddings = []
        sampled_captions = []
        for i in range(batch_size):
            randix = np.random.choice(embedding_num, sample_num, replace=False)
            if sample_num == 1:
                # Use the single drawn embedding. Caption text is not loaded:
                # it was never returned, so reading the file was wasted I/O.
                sampled_embeddings.append(embeddings[i, randix[0], :])
            else:
                # Average several caption embeddings into one vector.
                sampled_embeddings.append(np.mean(embeddings[i, randix, :], axis=0))
        return np.squeeze(np.array(sampled_embeddings)), sampled_captions

    def _build_batch(self, current_ids, window):
        """Assemble the batch list for the example indices ``current_ids``.

        Returns ``[images, wrong_images, embeddings, captions, labels]``.
        Images are float tensors in NCHW layout rescaled with ``x * 2/255 - 1``.
        ``wrong_images`` are random images re-drawn when they share the real
        image's class. Embeddings/captions are ``None`` when the dataset has no
        embeddings, and labels are ``None`` when it has no labels.
        """
        batch_size = len(current_ids)
        # Shift "wrong" ids that collide with the real image's class by a random
        # offset. The shifted id is not re-checked, so a collision can survive.
        fake_ids = np.random.randint(self._num_examples, size=batch_size)
        collision_flag = \
            (self._class_id[current_ids] == self._class_id[fake_ids])
        fake_ids[collision_flag] = \
            (fake_ids[collision_flag] +
             np.random.randint(100, 200)) % self._num_examples

        sampled_images = self._images[current_ids]
        sampled_wrong_images = self._images[fake_ids, :, :, :]
        sampled_images = sampled_images.astype(np.float32) * (2. / 255) - 1.
        sampled_wrong_images = \
            sampled_wrong_images.astype(np.float32) * (2. / 255) - 1.

        sampled_images = self.transform(sampled_images)
        sampled_wrong_images = self.transform(sampled_wrong_images)
        # NHWC -> NCHW for PyTorch.
        ret_list = [torch.FloatTensor(sampled_images.transpose((0, 3, 1, 2))),
                    torch.FloatTensor(sampled_wrong_images.transpose((0, 3, 1, 2)))]

        if self._embeddings is not None:
            filenames = None
            if self._filenames is not None:
                filenames = [self._filenames[i] for i in current_ids]
            class_id = [self._class_id[i] for i in current_ids]
            sampled_embeddings, sampled_captions = \
                self.sample_embeddings(self._embeddings[current_ids],
                                       filenames, class_id, window)
            ret_list.append(torch.FloatTensor(sampled_embeddings))
            ret_list.append(torch.FloatTensor(sampled_captions))
        else:
            ret_list.append(None)
            ret_list.append(None)

        if self._labels is not None:
            ret_list.append(torch.LongTensor(np.array(self._labels)[current_ids] - 1))
        else:
            ret_list.append(None)
        return ret_list

    def next_batch(self, batch_size, window):
        """Return the next ``batch_size`` examples, reshuffling when an epoch ends.

        Args:
            batch_size: examples per batch; must not exceed ``num_examples``.
            window: number of caption embeddings sampled/averaged per example.
        """
        start = self._index_in_epoch
        self._index_in_epoch += batch_size

        if self._index_in_epoch > self._num_examples:
            # Finished epoch: reshuffle and restart from the beginning.
            self._epochs_completed += 1
            self._perm = np.arange(self._num_examples)
            np.random.shuffle(self._perm)
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._build_batch(self._perm[start:end], window)

    def next_batch_test(self, batch_size, window):
        """Return the next ``batch_size`` examples, or ``[]`` once a full batch no longer fits.

        Does not reshuffle or rewind; call ``reinitialize_index`` to start over.
        A trailing partial batch is dropped.
        """
        start = self._index_in_epoch
        self._index_in_epoch += batch_size

        if self._index_in_epoch > self._num_examples:
            return []
        end = self._index_in_epoch
        return self._build_batch(self._perm[start:end], window)

    def next_batch_val(self, batch_size, window):
        """Return the next ``batch_size`` examples; terminates the process when exhausted.

        NOTE(review): calls ``sys.exit()`` at the end of the data instead of
        signalling the caller. Kept because callers may rely on it.
        """
        start = self._index_in_epoch
        self._index_in_epoch += batch_size

        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            sys.exit()
        end = self._index_in_epoch
        return self._build_batch(self._perm[start:end], window)


class TextDataset(object):
    """Loader for the pickled image/embedding/filename/class files of one split.

    Args:
        workdir: dataset root, forwarded to ``Dataset`` for caption lookup.
        embedding_type: ``'cnn-rnn'`` or ``'skip-thought'``; selects the
            embedding pickle.
        image_size: square image size recorded in ``image_shape``.

    Raises:
        ValueError: if ``embedding_type`` is not one of the supported values.
    """

    def __init__(self, workdir, embedding_type, image_size):
        # NOTE(review): the image pickle name is fixed regardless of image_size.
        self.image_filename = '/128images.pickle'

        self.image_shape = [image_size, image_size, 3]
        self.image_dim = self.image_shape[0] * self.image_shape[1] * 3
        self.embedding_shape = None
        self.train = None
        self.test = None
        self.workdir = workdir
        if embedding_type == 'cnn-rnn':
            self.embedding_filename = '/char-CNN-RNN-embeddings.pickle'
        elif embedding_type == 'skip-thought':
            self.embedding_filename = '/skip-thought-embeddings.pickle'
        else:
            # Fail fast instead of raising AttributeError later in get_data().
            raise ValueError('Unsupported embedding_type: {}'.format(embedding_type))

    def get_data(self, pickle_path, aug_flag=True):
        """Load one split directory and wrap it in a ``Dataset``.

        The class ids are passed both as ``labels`` and as ``class_id``.
        Security: ``pickle.load`` can execute arbitrary code; only point this
        at trusted dataset files.
        """
        with open(pickle_path + self.image_filename, 'rb') as f:
            images = pickle.load(f, encoding='latin1')
            images = np.array(images)
            print('images: ', images.shape)

        with open(pickle_path + self.embedding_filename, 'rb') as f:
            embeddings = pickle.load(f, encoding='latin1')
            embeddings = np.array(embeddings)
            self.embedding_shape = [embeddings.shape[-1]]
            print('embeddings: ', embeddings.shape)
        with open(pickle_path + '/filenames.pickle', 'rb') as f:
            list_filenames = pickle.load(f, encoding='latin1')
            print('list_filenames: ', len(list_filenames), list_filenames[0])
        with open(pickle_path + '/class_info.pickle', 'rb') as f:
            class_id = pickle.load(f, encoding='latin1')

        return Dataset(images, self.image_shape[0], embeddings,
                       list_filenames, self.workdir, class_id,
                       aug_flag, class_id)
# --------------------------------------------------------------------------------
# /examples/CUB_a_150000.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/linjx-ustc1106/ZstGAN-PyTorch/ef37e81a1a8a4808dbc436803c1e68f5ea1881dd/examples/CUB_a_150000.jpg
# --------------------------------------------------------------------------------
# /examples/CUB_v_150000.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/linjx-ustc1106/ZstGAN-PyTorch/ef37e81a1a8a4808dbc436803c1e68f5ea1881dd/examples/CUB_v_150000.jpg
# --------------------------------------------------------------------------------
# /examples/FLO_a_150000.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/linjx-ustc1106/ZstGAN-PyTorch/ef37e81a1a8a4808dbc436803c1e68f5ea1881dd/examples/FLO_a_150000.jpg
# --------------------------------------------------------------------------------
# /examples/FLO_v_150000.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/linjx-ustc1106/ZstGAN-PyTorch/ef37e81a1a8a4808dbc436803c1e68f5ea1881dd/examples/FLO_v_150000.jpg
# --------------------------------------------------------------------------------
# /examples/aa:
# --------------------------------------------------------------------------------
# (empty placeholder file)
# --------------------------------------------------------------------------------
# /examples/framework.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/linjx-ustc1106/ZstGAN-PyTorch/ef37e81a1a8a4808dbc436803c1e68f5ea1881dd/examples/framework.jpg
# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
import os
import argparse
from trainer import Solver
from torch.backends import cudnn
from torchvision import transforms, datasets
import torch.utils.data as data
import torch
from torchvision.utils import save_image
from datasets_torch import TextDataset


def main(config):
    """Create output directories, load the train/test splits, then train or test ZstGAN.

    Args:
        config: parsed command-line namespace (see the parser below). This
            function adds ``log_dir``, ``model_save_dir``, ``sample_dir`` and
            ``result_dir`` attributes, all placed under ``config.model_dir``.
    """
    cudnn.benchmark = True
    torch.manual_seed(7)  # cpu
    torch.cuda.manual_seed_all(999)  # gpu

    # Create directories if they do not exist.
    config.log_dir = os.path.join(config.model_dir, 'logs')
    config.model_save_dir = os.path.join(config.model_dir, 'models')
    config.sample_dir = os.path.join(config.model_dir, 'samples')
    config.result_dir = os.path.join(config.model_dir, 'results')
    for directory in (config.log_dir, config.model_save_dir,
                      config.sample_dir, config.result_dir):
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(directory, exist_ok=True)

    # dataloader
    dataset = TextDataset(config.datadir, 'cnn-rnn', config.image_size)
    filename_test = '%s/test' % (config.datadir)
    dataset.test = dataset.get_data(filename_test)
    filename_train = '%s/train' % (config.datadir)
    dataset.train = dataset.get_data(filename_train)

    # Solver for training and testing ZstGAN.
    solver = Solver(dataset, config)

    if config.mode == 'train':
        solver.train()  # train mode for ZstGAN
    elif config.mode == 'test':
        solver.test()  # test mode for ZstGAN


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Model configuration.
    parser.add_argument('--c_dim', type=int, default=200, help='dimension of domain labels (1st dataset)')
    parser.add_argument('--image_size', type=int, default=128, help='image resolution')
    parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
    parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
    parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
    parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
    parser.add_argument('--n_blocks', type=int, default=0, help='number of res conv layers in C')
    parser.add_argument('--lambda_mut', type=float, default=10, help='weight for mutual information loss')
    parser.add_argument('--lambda_rec', type=float, default=1, help='weight for reconstruction loss')
    parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
    parser.add_argument('--ft_num', type=int, default=2048, help='number of ds feature')
    parser.add_argument('--nz_num', type=int, default=312, help='number of noise feature')
    parser.add_argument('--att_num', type=int, default=1024, help='number of attribute feature')

    # Training configuration.
    parser.add_argument('--batch_size', type=int, default=8, help='mini-batch size')
    parser.add_argument('--num_iters', type=int, default=300000, help='number of total iterations for training D')
    parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
    parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
    parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
    parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
    parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
    parser.add_argument('--ev_ea_c_iters', type=int, default=80000, help='number of iterations for training encoder_a and encoder_v')
    parser.add_argument('--c_pre_iters', type=int, default=20000, help='number of iterations for pre-training C')

    # Test configuration.
    parser.add_argument('--test_iters', type=int, default=300000, help='test model from this step')

    # Miscellaneous.
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])

    # Directories.
    parser.add_argument('--datadir', type=str, default='Data/birds')
    parser.add_argument('--model_dir', type=str, default='zstgan')

    # Step size.
    parser.add_argument('--log_step', type=int, default=100)
    parser.add_argument('--sample_step', type=int, default=2000)
    parser.add_argument('--model_save_step', type=int, default=20000)
    parser.add_argument('--lr_update_step', type=int, default=1000)

    config = parser.parse_args()
    print(config)
    main(config)
# --------------------------------------------------------------------------------
# /model.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.models as models
try:
    from itertools import izip as zip
except ImportError:  # will be 3.x series
    pass


class AdaINEnc(nn.Module):
    """Content encoder: a single ``ContentEncoder`` with instance norm and reflect padding.

    Args:
        input_dim: number of input image channels.
        ft_num: accepted for symmetry with ``AdaINDec``; not used by this module.
    """

    def __init__(self, input_dim, ft_num):
        super(AdaINEnc, self).__init__()

        dim = 64
        n_downsample = 2
        n_res = 16
        activ = 'relu'
        pad_type = 'reflect'

        # encoder
        self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)

    def forward(self, images):
        # Encode images to their content code (no reconstruction happens here).
        content = self.encode(images)
        return content

    def encode(self, images):
        # Encode images to their content code.
        content = self.enc_content(images)
        return content


class AdaINDec(nn.Module):
    """Decoder that renders a content code with AdaIN parameters predicted from a style vector.

    Args:
        input_dim: number of image channels produced by the decoder.
        ft_num: length of the style vector fed to the AdaIN-parameter MLP.
    """

    def __init__(self, input_dim, ft_num):
        super(AdaINDec, self).__init__()

        dim = 64
        style_dim = ft_num
        n_downsample = 2
        n_res = 16
        activ = 'relu'
        pad_type = 'reflect'
        mlp_dim = 256

        # NOTE(review): this encoder is only used for its output_dim and is never
        # called in forward(). It is still a registered submodule, so its weights
        # live in this module's parameters/state_dict; removing it would likely
        # break existing checkpoints, so it is kept.
        self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
        # decoder
        self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim,
                           res_norm='adain', activ=activ, pad_type=pad_type)

        # MLP to generate AdaIN parameters
        self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ)

    def forward(self, content, style):
        # Decode content and style codes to an image.
        adain_params = self.mlp(style)
        self.assign_adain_params(adain_params, self.dec)
        images = self.dec(content)
        return images

    def assign_adain_params(self, adain_params, model):
        """Distribute ``adain_params`` over the AdaIN layers of ``model`` in module order.

        Each layer consumes ``2 * num_features`` columns: the first half becomes
        its bias (mean), the second half its weight (std), both flattened
        across the batch.
        """
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                mean = adain_params[:, :m.num_features]
                std = adain_params[:, m.num_features:2 * m.num_features]
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                if adain_params.size(1) > 2 * m.num_features:
                    adain_params = adain_params[:, 2 * m.num_features:]

    def get_num_adain_params(self, model):
        """Return the number of AdaIN parameters (``2 * num_features`` per layer) ``model`` needs."""
        num_adain_params = 0
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2 * m.num_features
        return num_adain_params


class ContentEncoder(nn.Module):
    """7x7 stem conv, ``n_downsample`` stride-2 4x4 convs (each doubling channels), then residual blocks.

    ``output_dim`` is the channel count of the output, ``dim * 2 ** n_downsample``.
    """

    def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
        super(ContentEncoder, self).__init__()
        layers = [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # downsampling blocks
        for i in range(n_downsample):
            layers += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        # residual blocks
        layers += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
        self.model = nn.Sequential(*layers)
        self.output_dim = dim

    def forward(self, x):
        return self.model(x)


class Decoder(nn.Module):
    """Residual blocks, ``n_upsample`` (x2 upsample + 5x5 LayerNorm conv) stages, then a 7x7 tanh conv."""

    def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
        super(Decoder, self).__init__()

        # AdaIN residual blocks (res_norm defaults to 'adain')
        layers = [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
        # upsampling blocks, each halving the channel count
        for i in range(n_upsample):
            layers += [nn.Upsample(scale_factor=2),
                       Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
            dim //= 2
        # Final 7x7 conv to output_dim channels with tanh; padding follows pad_type.
        layers += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class ResBlocks(nn.Module):
    """A stack of ``num_blocks`` identical ``ResBlock`` modules."""

    def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlocks, self).__init__()
        layers = [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)
                  for i in range(num_blocks)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class MLP(nn.Module):
    """``n_blk`` linear layers: input -> ``dim`` -> ... -> ``output_dim``; the last has no norm/activation.

    The input is flattened to ``(batch, -1)`` before the first layer.
    """

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):

        super(MLP, self).__init__()
        layers = [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        for i in range(n_blk - 2):
            layers += [LinearBlock(dim, dim, norm=norm, activation=activ)]
        layers += [LinearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x.view(x.size(0), -1))


class ResBlock(nn.Module):
    """Two 3x3 conv blocks (the second without activation) with an identity skip connection."""

    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlock, self).__init__()

        model = []
        model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
        model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out


class Conv2dBlock(nn.Module):
    """Explicit padding -> Conv2d -> optional normalization -> optional activation.

    Args:
        pad_type: 'reflect', 'replicate' or 'zero'.
        norm: 'bn', 'in', 'ln', 'adain', 'sn' (spectral norm on the conv) or 'none'.
        activation: 'relu', 'lrelu', 'prelu', 'selu', 'tanh' or 'none'.

    Raises:
        ValueError: for an unsupported ``pad_type``, ``norm`` or ``activation``
            (raised rather than asserted so the check survives ``python -O``).
    """

    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(Conv2dBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            raise ValueError("Unsupported padding type: {}".format(pad_type))

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            raise ValueError("Unsupported normalization: {}".format(norm))

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(activation))

        # initialize convolution (padding is applied separately by self.pad)
        if norm == 'sn':
            self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class LinearBlock(nn.Module):
    """Linear layer -> optional normalization -> optional activation.

    Args:
        norm: 'bn', 'in', 'ln', 'sn' (spectral norm on the linear layer) or 'none'.
        activation: 'relu', 'lrelu', 'prelu', 'selu', 'tanh' or 'none'.

    Raises:
        ValueError: for an unsupported ``norm`` or ``activation``.
    """

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True
        # initialize fully connected layer
        if norm == 'sn':
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            raise ValueError("Unsupported normalization: {}".format(norm))

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(activation))

    def forward(self, x):
        out = self.fc(x)
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out


class AdaptiveInstanceNorm2d(nn.Module):
    """Instance norm whose affine weight/bias are assigned externally before each forward.

    ``weight`` and ``bias`` must hold ``batch * num_features`` values (see
    ``AdaINDec.assign_adain_params``). The running buffers are placeholders:
    forward passes repeated copies, so they are never updated.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned
        self.weight = None
        self.bias = None
        # just dummy buffers, not used
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)

        # Fold batch into channels so training-mode batch_norm normalizes each
        # (sample, channel) plane independently, i.e. instance norm.
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])

        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)

        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'


class LayerNorm(nn.Module):
    """Normalize each sample over all its non-batch elements, with optional per-channel affine.

    Uses ``(x - mean) / (std + eps)`` with torch's default (unbiased) std.
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # These two lines run much faster in pytorch 0.4 than the two lines listed below.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)

        x = (x - mean) / (std + self.eps)

        if self.affine:
            # Broadcast gamma/beta along the channel axis (dim 1).
            shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


def l2normalize(v, eps=1e-12):
    """Return ``v`` scaled to unit L2 norm; ``eps`` guards against division by zero."""
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """
    Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
    and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

    Replaces ``module.<name>`` by ``<name>_bar / sigma`` on every forward,
    where ``sigma`` is estimated with ``power_iterations`` power-iteration steps
    using the persistent vectors ``<name>_u`` and ``<name>_v``.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        # True once _make_params has registered <name>_u, <name>_v and <name>_bar.
        return all(hasattr(self.module, self.name + suffix)
                   for suffix in ("_u", "_v", "_bar"))

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        # Remove the original parameter; forward() re-sets it as a plain tensor.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


class Resnet_Feature(nn.Module):
    """ImageNet-pretrained ResNet-50 conv trunk + 4x4 average pool, returning ``(N, C)`` features.

    The final ``view(N, C)`` only works when the pooled map is 1x1, i.e. the
    trunk outputs a 4x4 map (presumably 128x128 inputs) — confirm for other sizes.
    """
    name = 'Resnet_Feature'

    def __init__(self):
        super(Resnet_Feature, self).__init__()
        res50_model = models.resnet50(pretrained=True)

        # Drop ResNet's own avgpool and fc layers.
        self.res50_conv = nn.Sequential(*list(res50_model.children())[:-2])

        self.avp = nn.AvgPool2d(kernel_size=4, stride=1, padding=0)

    def forward(self, x):
        output = self.avp(self.res50_conv(x))
        return output.view(output.size(0), output.size(1))


class MLP_Encoder(nn.Module):
    """Map concatenated ``(noise, attribute)`` vectors to a non-negative ``out_dim`` feature.

    Architecture: Linear(in_dim + nz_num -> 4096) + LeakyReLU(0.2), then
    Linear(4096 -> out_dim) + ReLU.
    """

    def __init__(self, in_dim=1024, nz_num=312, out_dim=2048):
        super(MLP_Encoder, self).__init__()
        self.fc1 = nn.Linear(in_dim + nz_num, 4096)
        self.fc2 = nn.Linear(4096, out_dim)
        self.lrelu = nn.LeakyReLU(0.2, True)
        self.relu = nn.ReLU(True)

    def forward(self, att, noise):
        h = torch.cat((noise, att), 1)
        h = self.lrelu(self.fc1(h))
        h = self.relu(self.fc2(h))
        return h


class Linear_Classifier(nn.Module):
    def __init__(self, in_dim= 2048, c_dim = 200):
        super(Linear_Classifier, self).__init__()
        self.fc = nn.Linear(in_dim, c_dim)
        #self.logic = 
nn.LogSoftmax(dim=1) 440 | def forward(self, x): 441 | o = self.fc(x) 442 | return o 443 | 444 | 445 | 446 | class Eb_Discriminator(nn.Module): 447 | def __init__(self, ft_num = 2048, att_num = 1024): 448 | super(Eb_Discriminator, self).__init__() 449 | self.fc1 = nn.Sequential( nn.Linear(ft_num + att_num, 4096), 450 | nn.LeakyReLU(0.2, True)) 451 | #self.fc2 = nn.Linear(opt.ndh, opt.ndh) 452 | #self.fc2 = nn.Sequential(nn.Linear(4096, 1), 453 | # nn.Sigmoid()) 454 | self.fc2 = nn.Linear(4096, 1) 455 | 456 | 457 | def forward(self, x, att): 458 | h = torch.cat((x, att), 1) 459 | 460 | h = self.fc1(h) 461 | h = self.fc2(h) 462 | return h 463 | 464 | class Discriminator(nn.Module): 465 | """Discriminator network with PatchGAN.""" 466 | def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6, ft_num = 16): 467 | super(Discriminator, self).__init__() 468 | layers = [] 469 | layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)) 470 | layers.append(nn.LeakyReLU(0.01)) 471 | 472 | curr_dim = conv_dim 473 | for i in range(1, repeat_num): 474 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1)) 475 | layers.append(nn.LeakyReLU(0.01)) 476 | curr_dim = curr_dim * 2 477 | 478 | kernel_size = int(image_size / np.power(2, repeat_num)) 479 | self.main = nn.Sequential(*layers) 480 | self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False) 481 | 482 | self.conv2 = nn.Conv2d(curr_dim, ft_num, kernel_size=kernel_size, bias=False)#nn.Sequential(*[nn.Conv2d(curr_dim, ft_num, kernel_size= kernel_size), nn.LeakyReLU(0.01)])# 483 | 484 | def forward(self, x): 485 | h = self.main(x) 486 | out_src = self.conv1(h) 487 | out_cls = self.conv2(h) 488 | return out_src, out_cls.view(out_cls.size(0), out_cls.size(1)) 489 | 490 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | 
from model import *
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime
import itertools


def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: (batch, num_classes) tensor of class scores.
        target: (batch,) tensor of ground-truth class indices.
        topk: values of k to evaluate. Any k larger than the number of
            classes is clamped to the number of classes.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``.
    """
    num_classes = output.size(1)
    # Clamp every k so top-k never asks for more predictions than classes
    # (e.g. top-5 with 3 classes becomes top-3). Works for any topk length.
    topk = tuple(min(k, num_classes) for k in topk)
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view): correct[:k] may be non-contiguous after the transpose.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class Solver(object):
    """Solver for training and testing zstgan."""

    def __init__(self, data_loader, config):
        """Initialize configurations."""

        # Data loader.
        self.data_loader = data_loader

        # Model configurations.
        self.ft_num = config.ft_num
        self.nz_num = config.nz_num
        self.c_dim = config.c_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.n_blocks = config.n_blocks
        self.lambda_mut = config.lambda_mut
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.att_num = config.att_num

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters
        self.ev_ea_c_iters = config.ev_ea_c_iters
        self.c_pre_iters = config.c_pre_iters

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model
        self.build_model()

    def build_model(self):
        """Create networks and their optimizers, and move the networks to the device."""

        self.encoder = AdaINEnc(input_dim = 3, ft_num = self.ft_num)
        self.decoder = AdaINDec(input_dim = 3, ft_num = self.ft_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num, ft_num = self.ft_num)
        self.encoder_v = Resnet_Feature()
        self.encoder_a = MLP_Encoder(in_dim = self.att_num, nz_num = self.nz_num, out_dim= self.ft_num)
        self.D_s = Eb_Discriminator(ft_num = self.ft_num, att_num = self.att_num)
        self.C = Linear_Classifier(in_dim= self.ft_num, c_dim = self.c_dim)

        self.g_optimizer = torch.optim.Adam(itertools.chain(self.encoder.parameters(), self.decoder.parameters()), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        # use the same optimizer to update encoder_v and C
        self.ev_optimizer = torch.optim.Adam(itertools.chain(self.encoder_v.parameters(), self.C.parameters()), self.d_lr, [self.beta1, self.beta2])
        self.ea_optimizer = torch.optim.Adam(self.encoder_a.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.ds_optimizer = torch.optim.Adam(self.D_s.parameters(), self.d_lr, [self.beta1, self.beta2])
        # C alone, used only for C pretraining.
        self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.d_lr, [self.beta1, self.beta2])

        self.encoder.to(self.device)
        self.decoder.to(self.device)
        self.D.to(self.device)
        self.encoder_v.to(self.device)
        self.encoder_a.to(self.device)
        self.D_s.to(self.device)
        self.C.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    @staticmethod
    def _load_checkpoint(module, path):
        """Load a state dict from ``path`` into ``module``, mapping storages to CPU first."""
        module.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage))

    def restore_model(self, resume_iters):
        """Restore the trained encoder, decoder and D from step ``resume_iters``."""

        print('Loading the trained models from step {}...'.format(resume_iters))
        encoder_path = os.path.join(self.model_save_dir, '{}-encoder.ckpt'.format(resume_iters))
        decoder_path = os.path.join(self.model_save_dir, '{}-decoder.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        self._load_checkpoint(self.encoder, encoder_path)
        self._load_checkpoint(self.decoder, decoder_path)
        self._load_checkpoint(self.D, D_path)

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers of the generator and D optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def _zero_embedding_grads(self):
        """Reset the gradient buffers of the encoder_a, D_s and encoder_v/C optimizers."""
        self.ea_optimizer.zero_grad()
        self.ds_optimizer.zero_grad()
        self.ev_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones_like(y)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def classification_loss(self, logit, target):
        """Compute softmax cross entropy loss."""
        return F.cross_entropy(logit, target)

    def _save_translation_grid(self, x_real, styles, empty, path):
        """Save an (n+1) x (n+1) translation grid for a batch of n images.

        The top row is the ``empty`` tile followed by every input image. Each
        following row starts with input image ``row``; the rest of the row is
        that image's content decoded with each style vector ``styles[col]``.
        """
        n = x_real.size(0)
        images = [empty]
        images.extend(x_real[col:col + 1] for col in range(n))
        for row in range(n):
            content = x_real[row:row + 1]
            images.append(content)
            # Hoisted: the content code is the same for every style in the row.
            content_code = self.encoder(content)
            for col in range(n):
                images.append(self.decoder(content_code, styles[col:col + 1]))
        results_concat = torch.cat(images)
        save_image(self.denorm(results_concat.data.cpu()), path, nrow=n + 1, padding=0)
        print('Saved real and fake images into {}...'.format(path))

    def _pretrain_classifier(self, C_pre_path):
        """Pretrain C on encoder_v features for ``c_pre_iters`` steps, then save C."""
        data_loader = self.data_loader
        c_pre_iters = self.c_pre_iters
        for i in range(0, c_pre_iters):
            # Fetch real images and labels.
            x_real, wrong_images, attributes, _, label_org = data_loader.train.next_batch(self.batch_size, 10)
            x_real = x_real.to(self.device)
            label_org = label_org.to(self.device)

            # Only C is stepped here (c_optimizer), so no graph through encoder_v is needed.
            with torch.no_grad():
                ev_x = self.encoder_v(x_real)
            cls_x = self.C(ev_x)
            # Classification loss from only images for C training
            c_loss_cls = self.classification_loss(cls_x, label_org)
            self.c_optimizer.zero_grad()
            c_loss_cls.backward()
            self.c_optimizer.step()

            if (i+1) % self.log_step == 0:
                loss = {}
                loss['c_loss_cls'] = c_loss_cls.item()
                prec1, prec5 = accuracy(cls_x.data, label_org.data, topk=(1, 5))
                loss['prec1'] = prec1
                loss['prec5'] = prec5
                log = "C pretraining iteration [{}/{}]".format(i+1, c_pre_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)
        torch.save(self.C.state_dict(), C_pre_path)
        print('Saved model pretrained checkpoints into {}...'.format(C_pre_path))

    def _save_ev_ea_checkpoints(self, step):
        """Save C, encoder_a and encoder_v checkpoints tagged with ``step``."""
        C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(step))
        torch.save(self.C.state_dict(), C_path)
        print('Saved model checkpoints into {}...'.format(C_path))

        encoder_a_path = os.path.join(self.model_save_dir, '{}-encoder_a.ckpt'.format(step))
        torch.save(self.encoder_a.state_dict(), encoder_a_path)
        print('Saved model checkpoints into {}...'.format(encoder_a_path))

        encoder_v_path = os.path.join(self.model_save_dir, '{}-encoder_v.ckpt'.format(step))
        torch.save(self.encoder_v.state_dict(), encoder_v_path)
        print('Saved model checkpoints into {}...'.format(encoder_v_path))

    def train_ev_ea(self):
        """Train encoder_a and encoder_v with C and D_s.

        If checkpoints for step ``ev_ea_c_iters`` exist, C, encoder_a and
        encoder_v are loaded from them and no training is done. Otherwise:
          1. C is loaded from its ``c_pre_iters`` checkpoint, or pretrained on
             visual features if that checkpoint is missing;
          2. iterations ``c_pre_iters`` to ``ev_ea_c_iters`` run. Each one
             updates D_s with a WGAN-GP loss. Every ``n_critic`` iterations,
             encoder_v + C and then encoder_a are updated as well.
        """
        data_loader = self.data_loader

        noise = torch.FloatTensor(self.batch_size, self.nz_num)
        noise = noise.to(self.device)  # noise vector z

        print('Start encoder_a and encoder_v training...')
        start_time = time.time()

        ev_ea_c_iters = self.ev_ea_c_iters
        c_pre_iters = self.c_pre_iters

        C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(ev_ea_c_iters))
        encoder_a_path = os.path.join(self.model_save_dir, '{}-encoder_a.ckpt'.format(ev_ea_c_iters))
        encoder_v_path = os.path.join(self.model_save_dir, '{}-encoder_v.ckpt'.format(ev_ea_c_iters))

        if os.path.exists(C_path):
            # Final checkpoints exist: load them and skip training.
            self._load_checkpoint(self.C, C_path)
            print('Load model checkpoints from {}'.format(C_path))

            self._load_checkpoint(self.encoder_a, encoder_a_path)
            print('Load model checkpoints from {}'.format(encoder_a_path))

            self._load_checkpoint(self.encoder_v, encoder_v_path)
            print('Load model checkpoints from {}'.format(encoder_v_path))
            return

        C_pre_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(c_pre_iters))
        if os.path.exists(C_pre_path):
            self._load_checkpoint(self.C, C_pre_path)
            print('Load model pretrained checkpoints from {}'.format(C_pre_path))
        else:
            self._pretrain_classifier(C_pre_path)

        # Logits and labels from the most recent encoder update. They are logged
        # together so accuracy is never computed against another batch's labels,
        # and so a log step before the first encoder update cannot crash.
        last_cls = None
        for i in range(c_pre_iters, ev_ea_c_iters):
            # Fetch real images, attributes and labels.
            x_real, wrong_images, attributes, _, label_org = data_loader.train.next_batch(self.batch_size, 10)
            x_real = x_real.to(self.device)  # Input images.
            attributes = attributes.to(self.device)  # Input attributes
            label_org = label_org.to(self.device)  # Labels for computing classification loss.

            # =================================================================================== #
            #                 Train the domain-specific features discriminator D_s                #
            # =================================================================================== #

            noise.normal_(0, 1)
            # The embeddings only feed D_s in this step and only D_s is stepped,
            # so skip building a graph through the encoders (incl. ResNet-50).
            with torch.no_grad():
                ea_a = self.encoder_a(attributes, noise)
                ev_x = self.encoder_v(x_real)

            ds_loss_real = -torch.mean(self.D_s(ev_x, attributes))
            ds_loss_fake = torch.mean(self.D_s(ea_a, attributes))

            # Compute loss for gradient penalty.
            alpha = torch.rand(ev_x.size(0), 1).to(self.device)
            ebd_hat = (alpha * ev_x.data + (1 - alpha) * ea_a.data).requires_grad_(True)
            ds_loss_gp = self.gradient_penalty(self.D_s(ebd_hat, attributes), ebd_hat)

            ds_loss = ds_loss_real + ds_loss_fake + self.lambda_gp * ds_loss_gp
            self._zero_embedding_grads()
            ds_loss.backward()
            self.ds_optimizer.step()

            if (i+1) % self.n_critic == 0:
                # =================================================================================== #
                #                  Train encoder_v and C (both in ev_optimizer)                       #
                # =================================================================================== #
                ev_x = self.encoder_v(x_real)
                ev_loss_real = torch.mean(self.D_s(ev_x, attributes))

                cls_x = self.C(ev_x)
                c_loss_cls = self.classification_loss(cls_x, label_org)

                ev_c_loss = ev_loss_real + c_loss_cls
                self._zero_embedding_grads()
                ev_c_loss.backward()
                self.ev_optimizer.step()

                # =================================================================================== #
                #                                 Train encoder_a                                     #
                # =================================================================================== #
                noise.normal_(0, 1)
                ea_a = self.encoder_a(attributes, noise)
                ea_loss_fake = -torch.mean(self.D_s(ea_a, attributes))

                cls_a = self.C(ea_a)
                ebn_loss_cls = self.classification_loss(cls_a, label_org)

                ea_loss = ea_loss_fake + ebn_loss_cls
                self._zero_embedding_grads()
                ea_loss.backward()
                self.ea_optimizer.step()

                last_cls = (cls_x.detach(), cls_a.detach(), label_org)

            # Logging.
            loss = {}
            loss['ds/ds_loss_real'] = ds_loss_real.item()
            loss['ds/ds_loss_fake'] = ds_loss_fake.item()
            loss['ds/ds_loss_gp'] = ds_loss_gp.item()

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                if last_cls is not None:
                    cls_x_log, cls_a_log, labels_log = last_cls
                    prec1, prec5 = accuracy(cls_x_log, labels_log, topk=(1, 5))
                    loss['prec1'] = prec1
                    loss['prec5'] = prec5
                    prec1e, prec5e = accuracy(cls_a_log, labels_log, topk=(1, 5))
                    loss['prec1e'] = prec1e
                    loss['prec5e'] = prec5e
                log = "Encoder_a and Encoder_v Training Elapsed [{}], Iteration [{}/{}]".format(et, i+1, ev_ea_c_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                self._save_ev_ea_checkpoints(i+1)

    def train(self):
        """Train zstgan.

        The embedding networks are first trained (or loaded) by train_ev_ea().
        encoder_v is then put in eval mode and is not updated while the
        encoder/decoder (generator) and D are trained adversarially.
        """
        # train encoder_a and encoder_v first
        self.train_ev_ea()
        self.encoder_v.eval()

        data_loader = self.data_loader

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # noise vector z (used for the attribute-driven samples)
        noise = torch.FloatTensor(self.batch_size, self.nz_num)
        noise = noise.to(self.device)

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        print('Start training...')
        start_time = time.time()
        # Top-left tile of the sample grids (1.0 maps to white after denorm).
        empty = torch.ones(1, 3, self.image_size, self.image_size, device=self.device)
        for i in range(start_iters, self.num_iters):
            # Fetch real images and labels.
            x_real, wrong_images, attributes, _, label_org = data_loader.train.next_batch(self.batch_size, 10)
            label_org = label_org.to(self.device)
            attributes = attributes.to(self.device)
            x_real = x_real.to(self.device)

            # encoder_v is in no optimizer stepped here, so its features need no graph.
            with torch.no_grad():
                ev_x = self.encoder_v(x_real)

            # Target domains: visual features of a shuffled copy of the batch.
            rand_idx = torch.randperm(label_org.size(0))
            trg_ev_x = ev_x[rand_idx]

            # =================================================================================== #
            #                             Train the discriminator                                 #
            # =================================================================================== #

            # Compute loss with real images.
            out_src, out_cls = self.D(x_real)
            d_loss_real = - torch.mean(out_src)
            d_loss_mut = torch.mean(torch.abs(ev_x - out_cls))

            # Compute loss with fake images (generator gradients are not needed here).
            with torch.no_grad():
                x_fake = self.decoder(self.encoder(x_real), trg_ev_x)
            out_src, out_cls = self.D(x_fake)
            d_loss_fake = torch.mean(out_src)

            # Compute loss for gradient penalty.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_mut * d_loss_mut + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_mut'] = d_loss_mut.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                           Train the encoder and decoder                             #
            # =================================================================================== #

            if (i+1) % self.n_critic == 0:
                # Original-to-target domain.
                x_di = self.encoder(x_real)

                x_fake = self.decoder(x_di, trg_ev_x)
                x_reconst1 = self.decoder(x_di, ev_x)
                out_src, out_cls = self.D(x_fake)
                g_loss_fake = - torch.mean(out_src)
                g_loss_mut = torch.mean(torch.abs(trg_ev_x - out_cls))

                # Target-to-original domain.
                x_fake_di = self.encoder(x_fake)
                x_reconst2 = self.decoder(x_fake_di, ev_x)

                g_loss_rec1 = torch.mean(torch.abs(x_real - x_reconst1))
                g_loss_rec12 = torch.mean(torch.abs(x_real - x_reconst2))

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_rec * (g_loss_rec1 + g_loss_rec12) + self.lambda_mut * g_loss_mut
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec1'] = g_loss_rec1.item()
                loss['G/loss_rec2'] = g_loss_rec12.item()
                loss['G/loss_mut'] = g_loss_mut.item()

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

            # Translate fixed images for debugging.
            if (i+1) % self.sample_step == 0:
                with torch.no_grad():
                    self._save_translation_grid(
                        x_real, ev_x, empty,
                        os.path.join(self.sample_dir, '{}_x_AB_results.jpg'.format(i+1)))

                    # save vision-driven and attribute-driven results on unseen domains
                    x_test, _, attributes_test, _, _ = data_loader.test.next_batch(self.batch_size, 10)
                    x_test = x_test.to(self.device)
                    attributes_test = attributes_test.to(self.device)
                    ev_test = self.encoder_v(x_test)
                    noise.normal_(0, 1)
                    ea_test = self.encoder_a(attributes_test, noise)

                    self._save_translation_grid(
                        x_test, ev_test, empty,
                        os.path.join(self.sample_dir, '{}_x_AB_results_test_v.jpg'.format(i+1)))
                    self._save_translation_grid(
                        x_test, ea_test, empty,
                        os.path.join(self.sample_dir, '{}_x_AB_results_test_a.jpg'.format(i+1)))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                encoder_path = os.path.join(self.model_save_dir, '{}-encoder.ckpt'.format(i+1))
                decoder_path = os.path.join(self.model_save_dir, '{}-decoder.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.encoder.state_dict(), encoder_path)
                torch.save(self.decoder.state_dict(), decoder_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def test(self):
        """Translate images using zstgan on unseen test set."""
        # Load the trained models.
        self.train_ev_ea()
        self.restore_model(self.test_iters)
        self.encoder_v.eval()

        data_loader = self.data_loader
        empty = torch.ones(1, 3, self.image_size, self.image_size, device=self.device)
        noise = torch.FloatTensor(self.batch_size, self.nz_num)
        noise = noise.to(self.device)
        step = 0
        data_loader.test.reinitialize_index()
        with torch.no_grad():
            while True:
                try:
                    x_real, wrong_images, attributes, _, label_org = data_loader.test.next_batch_test(self.batch_size, 10)
                except Exception:
                    # The loop's only exit: a failed fetch is treated as the end
                    # of the test set (KeyboardInterrupt/SystemExit still propagate).
                    break
                x_real = x_real.to(self.device)
                attributes = attributes.to(self.device)

                ev_x = self.encoder_v(x_real)
                noise.normal_(0, 1)
                ea_a = self.encoder_a(attributes, noise)

                self._save_translation_grid(
                    x_real, ev_x, empty,
                    os.path.join(self.result_dir, '{}_x_AB_results_test_v.jpg'.format(step+1)))
                self._save_translation_grid(
                    x_real, ea_a, empty,
                    os.path.join(self.result_dir, '{}_x_AB_results_test_a.jpg'.format(step+1)))

                step += 1