├── README.md
├── datasets_torch.py
├── examples
│   ├── CUB_a_150000.jpg
│   ├── CUB_v_150000.jpg
│   ├── FLO_a_150000.jpg
│   ├── FLO_v_150000.jpg
│   ├── aa
│   └── framework.jpg
├── main.py
├── model.py
└── trainer.py
/README.md:
--------------------------------------------------------------------------------
1 | # ZstGAN-PyTorch
2 | PyTorch Implementation of "ZstGAN: An Adversarial Approach for Unsupervised Zero-Shot Image-to-Image Translation"
3 |
4 | # Dependencies:
5 | Python 3.6
6 |
7 | PyTorch 0.4.0
8 |
9 | # Usage:
10 | ### Unsupervised Zero-Shot Image-to-Image Translation
11 | 1. Download the CUB and FLO training and testing datasets from [CUB and FLO](https://pan.baidu.com/s/1m4a4PFpjFNMNLIdE8TlYAQ) with password `n6qd`, or follow [StackGAN](https://github.com/hanzhanggit/StackGAN) to prepare the two datasets.
12 |
13 | 2. Unzip Data.zip and organize the CUB and FLO training and testing sets as follows:
14 |
15 |     Data
16 |     ├── flowers
17 |     |   ├── train
18 |     |   ├── test
19 |     |   └── ...
20 |     └── birds
21 |         ├── train
22 |         ├── test
23 |         └── ...
24 |
25 | 3. Train ZstGAN on seen domains of FLO:
26 |
27 | `$ python main.py --mode train --model_dir flower --datadir Data/flowers/ --c_dim 102 --batch_size 8 --nz_num 312 --ft_num 2048 --lambda_mut 200`
28 | 4. Train ZstGAN on seen domains of CUB:
29 |
30 | `$ python main.py --mode train --model_dir bird --datadir Data/birds/ --c_dim 200 --batch_size 8 --nz_num 312 --ft_num 2048 --lambda_mut 50`
31 | 5. Test ZstGAN on unseen domains of FLO at iteration 200000:
32 |
33 | `$ python main.py --mode test --model_dir flower --datadir Data/flowers/ --c_dim 102 --test_iters 200000`
34 | 6. Test ZstGAN on unseen domains of CUB at iteration 200000:
35 |
36 | `$ python main.py --mode test --model_dir bird --datadir Data/birds/ --c_dim 200 --test_iters 200000`
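
Each `train` and `test` folder is expected to contain the pickled files loaded by `datasets_torch.py`: `128images.pickle`, `char-CNN-RNN-embeddings.pickle`, `filenames.pickle` and `class_info.pickle`. A minimal sanity check for the FLO layout above (illustrative only, not part of the repository):

```python
# Verify that the pickle files expected by datasets_torch.TextDataset are in place.
import os

for split in ('train', 'test'):
    for name in ('128images.pickle', 'char-CNN-RNN-embeddings.pickle',
                 'filenames.pickle', 'class_info.pickle'):
        path = os.path.join('Data/flowers', split, name)
        print(path, 'OK' if os.path.exists(path) else 'MISSING')
```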
37 | # Results:
38 | ### 1. Image translation on unseen domains of FLO at iteration 150000:
39 |
40 | **Results of V-ZstGAN**:
41 |
42 | ![V-ZstGAN results on FLO](examples/FLO_v_150000.jpg)
43 |
44 | **Results of A-ZstGAN**:
45 |
46 | ![A-ZstGAN results on FLO](examples/FLO_a_150000.jpg)
47 |
48 | ### 2. Image translation on unseen domains of CUB at iteration 150000:
49 |
50 | **Results of V-ZstGAN**:
51 |
52 | ![V-ZstGAN results on CUB](examples/CUB_v_150000.jpg)
53 |
54 | **Results of A-ZstGAN**:
55 |
56 | ![A-ZstGAN results on CUB](examples/CUB_a_150000.jpg)
57 |
--------------------------------------------------------------------------------
/datasets_torch.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 |
5 | import numpy as np
6 | import pickle
7 | import random
8 | import sys
9 |
10 | import torch
11 | class Dataset(object):
12 | def __init__(self, images, imsize, embeddings=None,
13 | filenames=None, workdir=None,
14 | labels=None, aug_flag=True,
15 | class_id=None, class_range=None):
16 | self._images = images
17 | self._embeddings = embeddings
18 | self._filenames = filenames
19 | self.workdir = workdir
20 | self._labels = labels
21 | self._epochs_completed = -1
22 | self._num_examples = len(images)
23 | self._saveIDs = self.saveIDs()
24 |
25 | # shuffle on first run
26 | self._index_in_epoch = self._num_examples
27 | self._aug_flag = aug_flag
28 | self._class_id = np.array(class_id)
29 | self._class_range = class_range
30 | self._imsize = imsize
31 | #self._perm = None
32 | self._perm = np.arange(self._num_examples)
33 | np.random.shuffle(self._perm)
34 | def reinitialize_index(self):
35 | self._index_in_epoch = 0
36 | return None
37 | @property
38 | def images(self):
39 | return self._images
40 |
41 | @property
42 | def embeddings(self):
43 | return self._embeddings
44 |
45 | @property
46 | def filenames(self):
47 | return self._filenames
48 |
49 | @property
50 | def num_examples(self):
51 | return self._num_examples
52 |
53 | @property
54 | def epochs_completed(self):
55 | return self._epochs_completed
56 |
57 | def saveIDs(self):
58 | self._saveIDs = np.arange(self._num_examples)
59 | np.random.shuffle(self._saveIDs)
60 | return self._saveIDs
61 |
62 | def readCaptions(self, filenames, class_id):
63 | name = filenames
64 | if name.find('jpg/') != -1: # flowers dataset
65 | class_name = 'class_%05d/' % class_id
66 | name = name.replace('jpg/', class_name)
67 | cap_path = '%s/text_c10/%s.txt' %\
68 | (self.workdir, name)
69 | with open(cap_path, "r") as f:
70 | captions = f.read().split('\n')
71 | captions = [cap for cap in captions if len(cap) > 0]
72 | return captions
73 |
74 | def transform(self, images):
75 | if self._aug_flag:
76 | transformed_images =\
77 | np.zeros([images.shape[0], self._imsize, self._imsize, 3])
78 | for i in range(images.shape[0]):
79 | if random.random() > 0.5:
80 | transformed_images[i] = np.fliplr(images[i])
81 | else:
82 | transformed_images[i] = images[i]
83 | return transformed_images
84 | else:
85 | return images
86 |
87 | def sample_embeddings(self, embeddings, filenames, class_id, sample_num):
88 |         if len(embeddings.shape) == 2 or embeddings.shape[1] == 1:
89 |             return np.squeeze(embeddings), []  # single embedding per image: nothing to sample, no captions
90 | else:
91 | batch_size, embedding_num, _ = embeddings.shape
92 | # Take every sample_num captions to compute the mean vector
93 | sampled_embeddings = []
94 | sampled_captions = []
95 | for i in range(batch_size):
96 | randix = np.random.choice(embedding_num,
97 | sample_num, replace=False)
98 | if sample_num == 1:
99 | randix = int(randix)
100 | captions = self.readCaptions(filenames[i],
101 | class_id[i])
102 | #sampled_captions.append(captions[randix])
103 | sampled_embeddings.append(embeddings[i, randix, :])
104 | else:
105 | e_sample = embeddings[i, randix, :]
106 | e_mean = np.mean(e_sample, axis=0)
107 | sampled_embeddings.append(e_mean)
108 | sampled_embeddings_array = np.array(sampled_embeddings)
109 | return np.squeeze(sampled_embeddings_array), sampled_captions
110 |
111 | def next_batch(self, batch_size, window):
112 | """Return the next `batch_size` examples from this data set."""
113 | start = self._index_in_epoch
114 | self._index_in_epoch += batch_size
115 |
116 | if self._index_in_epoch > self._num_examples:
117 | # Finished epoch
118 | self._epochs_completed += 1
119 | # Shuffle the data
120 | self._perm = np.arange(self._num_examples)
121 | np.random.shuffle(self._perm)
122 |
123 | # Start next epoch
124 | start = 0
125 | self._index_in_epoch = batch_size
126 | assert batch_size <= self._num_examples
127 | end = self._index_in_epoch
128 |
129 | current_ids = self._perm[start:end]
130 |         fake_ids = np.random.randint(self._num_examples, size=batch_size)  # indices of mismatched ("wrong") images
131 | collision_flag =\
132 | (self._class_id[current_ids] == self._class_id[fake_ids])
133 | fake_ids[collision_flag] =\
134 | (fake_ids[collision_flag] +
135 | np.random.randint(100, 200)) % self._num_examples
136 |
137 | sampled_images = self._images[current_ids]
138 | sampled_wrong_images = self._images[fake_ids, :, :, :]
139 | sampled_images = sampled_images.astype(np.float32)
140 | sampled_wrong_images = sampled_wrong_images.astype(np.float32)
141 |         sampled_images = sampled_images * (2. / 255) - 1.  # scale pixels from [0, 255] to [-1, 1]
142 | sampled_wrong_images = sampled_wrong_images * (2. / 255) - 1.
143 |
144 | sampled_images = self.transform(sampled_images)
145 | sampled_wrong_images = self.transform(sampled_wrong_images)
146 | ret_list = [torch.FloatTensor(sampled_images.transpose((0,3,1,2))), torch.FloatTensor(sampled_wrong_images.transpose((0,3,1,2)))]
147 |
148 | if self._embeddings is not None:
149 | filenames = [self._filenames[i] for i in current_ids]
150 | class_id = [self._class_id[i] for i in current_ids]
151 | sampled_embeddings, sampled_captions = \
152 | self.sample_embeddings(self._embeddings[current_ids],
153 | filenames, class_id, window)
154 | ret_list.append(torch.FloatTensor(sampled_embeddings))
155 | ret_list.append(torch.FloatTensor(sampled_captions))
156 | else:
157 | ret_list.append(None)
158 | ret_list.append(None)
159 |
160 | if self._labels is not None:
161 | ret_list.append(torch.LongTensor(np.array(self._labels)[current_ids]-1))
162 | else:
163 | ret_list.append(None)
164 | return ret_list
165 | def next_batch_test(self, batch_size, window):
166 | """Return the next `batch_size` examples from this data set."""
167 | start = self._index_in_epoch
168 | self._index_in_epoch += batch_size
169 |
170 | if self._index_in_epoch > self._num_examples:
171 | ret_list = []
172 | return ret_list
173 | end = self._index_in_epoch
174 |
175 | current_ids = self._perm[start:end]
176 | fake_ids = np.random.randint(self._num_examples, size=batch_size)
177 | collision_flag =\
178 | (self._class_id[current_ids] == self._class_id[fake_ids])
179 | fake_ids[collision_flag] =\
180 | (fake_ids[collision_flag] +
181 | np.random.randint(100, 200)) % self._num_examples
182 |
183 | sampled_images = self._images[current_ids]
184 | sampled_wrong_images = self._images[fake_ids, :, :, :]
185 | sampled_images = sampled_images.astype(np.float32)
186 | sampled_wrong_images = sampled_wrong_images.astype(np.float32)
187 | sampled_images = sampled_images * (2. / 255) - 1.
188 | sampled_wrong_images = sampled_wrong_images * (2. / 255) - 1.
189 |
190 | sampled_images = self.transform(sampled_images)
191 | sampled_wrong_images = self.transform(sampled_wrong_images)
192 | ret_list = [torch.FloatTensor(sampled_images.transpose((0,3,1,2))), torch.FloatTensor(sampled_wrong_images.transpose((0,3,1,2)))]
193 |
194 | if self._embeddings is not None:
195 | filenames = [self._filenames[i] for i in current_ids]
196 | class_id = [self._class_id[i] for i in current_ids]
197 | sampled_embeddings, sampled_captions = \
198 | self.sample_embeddings(self._embeddings[current_ids],
199 | filenames, class_id, window)
200 | ret_list.append(torch.FloatTensor(sampled_embeddings))
201 | ret_list.append(torch.FloatTensor(sampled_captions))
202 | else:
203 | ret_list.append(None)
204 | ret_list.append(None)
205 |
206 | if self._labels is not None:
207 | ret_list.append(torch.LongTensor(np.array(self._labels)[current_ids]-1))
208 | else:
209 | ret_list.append(None)
210 | return ret_list
211 | def next_batch_val(self, batch_size, window):
212 | """Return the next `batch_size` examples from this data set."""
213 | start = self._index_in_epoch
214 | self._index_in_epoch += batch_size
215 |
216 | if self._index_in_epoch > self._num_examples:
217 | # Finished epoch
218 | sys.exit()
219 | end = self._index_in_epoch
220 |
221 | current_ids = self._perm[start:end]
222 | fake_ids = np.random.randint(self._num_examples, size=batch_size)
223 | collision_flag =\
224 | (self._class_id[current_ids] == self._class_id[fake_ids])
225 | fake_ids[collision_flag] =\
226 | (fake_ids[collision_flag] +
227 | np.random.randint(100, 200)) % self._num_examples
228 |
229 | sampled_images = self._images[current_ids]
230 | sampled_wrong_images = self._images[fake_ids, :, :, :]
231 | sampled_images = sampled_images.astype(np.float32)
232 | sampled_wrong_images = sampled_wrong_images.astype(np.float32)
233 | sampled_images = sampled_images * (2. / 255) - 1.
234 | sampled_wrong_images = sampled_wrong_images * (2. / 255) - 1.
235 |
236 | sampled_images = self.transform(sampled_images)
237 | sampled_wrong_images = self.transform(sampled_wrong_images)
238 | ret_list = [torch.FloatTensor(sampled_images.transpose((0,3,1,2))), torch.FloatTensor(sampled_wrong_images.transpose((0,3,1,2)))]
239 |
240 | if self._embeddings is not None:
241 | filenames = [self._filenames[i] for i in current_ids]
242 | class_id = [self._class_id[i] for i in current_ids]
243 | sampled_embeddings, sampled_captions = \
244 | self.sample_embeddings(self._embeddings[current_ids],
245 | filenames, class_id, window)
246 | ret_list.append(torch.FloatTensor(sampled_embeddings))
247 | ret_list.append(torch.FloatTensor(sampled_captions))
248 | else:
249 | ret_list.append(None)
250 | ret_list.append(None)
251 |
252 | if self._labels is not None:
253 | ret_list.append(torch.LongTensor(np.array(self._labels)[current_ids]-1))
254 | else:
255 | ret_list.append(None)
256 | return ret_list
257 |
258 |
259 | class TextDataset(object):
260 | def __init__(self, workdir, embedding_type, image_size):
261 | self.image_filename = '/128images.pickle'
262 |
263 |
264 | self.image_shape = [image_size,
265 | image_size, 3]
266 | self.image_dim = self.image_shape[0] * self.image_shape[1] * 3
267 | self.embedding_shape = None
268 | self.train = None
269 | self.test = None
270 | self.workdir = workdir
271 | if embedding_type == 'cnn-rnn':
272 | self.embedding_filename = '/char-CNN-RNN-embeddings.pickle'
273 | elif embedding_type == 'skip-thought':
274 | self.embedding_filename = '/skip-thought-embeddings.pickle'
275 |
276 | def get_data(self, pickle_path, aug_flag=True):
277 | with open(pickle_path + self.image_filename, 'rb') as f:
278 | images = pickle.load(f, encoding='latin1')
279 | images = np.array(images)
280 | print('images: ', images.shape)
281 |
282 | with open(pickle_path + self.embedding_filename, 'rb') as f:
283 | embeddings = pickle.load(f, encoding='latin1')
284 | embeddings = np.array(embeddings)
285 | self.embedding_shape = [embeddings.shape[-1]]
286 | print('embeddings: ', embeddings.shape)
287 | with open(pickle_path + '/filenames.pickle', 'rb') as f:
288 | list_filenames = pickle.load(f, encoding='latin1')
289 | print('list_filenames: ', len(list_filenames), list_filenames[0])
290 | with open(pickle_path + '/class_info.pickle', 'rb') as f:
291 | class_id = pickle.load(f, encoding='latin1')
292 |
293 | return Dataset(images, self.image_shape[0], embeddings,
294 | list_filenames, self.workdir, class_id,
295 | aug_flag, class_id)
296 |
--------------------------------------------------------------------------------
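A minimal usage sketch of `TextDataset`, mirroring how `main.py` builds it and how `trainer.py` draws batches (the data path and batch size here are placeholders):

```python
# Load the pickled FLO training split and draw one batch, as main.py/trainer.py do.
from datasets_torch import TextDataset

dataset = TextDataset('Data/flowers', 'cnn-rnn', image_size=128)  # workdir, embedding type, image resolution
dataset.train = dataset.get_data('Data/flowers/train')

# next_batch returns: images, mismatched ("wrong") images, mean text embeddings,
# an unused captions placeholder, and zero-based class labels.
x_real, wrong_images, attributes, _, labels = dataset.train.next_batch(batch_size=8, window=10)
print(x_real.shape, attributes.shape, labels.shape)  # e.g. (8, 3, 128, 128), (8, 1024), (8,)
```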
/examples/CUB_a_150000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/ZstGAN-PyTorch/ef37e81a1a8a4808dbc436803c1e68f5ea1881dd/examples/CUB_a_150000.jpg
--------------------------------------------------------------------------------
/examples/CUB_v_150000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/ZstGAN-PyTorch/ef37e81a1a8a4808dbc436803c1e68f5ea1881dd/examples/CUB_v_150000.jpg
--------------------------------------------------------------------------------
/examples/FLO_a_150000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/ZstGAN-PyTorch/ef37e81a1a8a4808dbc436803c1e68f5ea1881dd/examples/FLO_a_150000.jpg
--------------------------------------------------------------------------------
/examples/FLO_v_150000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/ZstGAN-PyTorch/ef37e81a1a8a4808dbc436803c1e68f5ea1881dd/examples/FLO_v_150000.jpg
--------------------------------------------------------------------------------
/examples/aa:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/examples/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linjx-ustc1106/ZstGAN-PyTorch/ef37e81a1a8a4808dbc436803c1e68f5ea1881dd/examples/framework.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from trainer import Solver
4 | from torch.backends import cudnn
5 | from torchvision import transforms, datasets
6 | import torch.utils.data as data
7 | import torch
8 | from torchvision.utils import save_image
9 | from datasets_torch import TextDataset
10 |
11 |
12 | def main(config):
13 | cudnn.benchmark = True
14 | torch.manual_seed(7) # cpu
15 | torch.cuda.manual_seed_all(999) #gpu
16 |
17 | # Create directories if not exist.
18 | config.log_dir = os.path.join(config.model_dir, 'logs')
19 | config.model_save_dir = os.path.join(config.model_dir, 'models')
20 | config.sample_dir = os.path.join(config.model_dir, 'samples')
21 | config.result_dir = os.path.join(config.model_dir, 'results')
22 |
23 | if not os.path.exists(config.log_dir):
24 | os.makedirs(config.log_dir)
25 | if not os.path.exists(config.model_save_dir):
26 | os.makedirs(config.model_save_dir)
27 | if not os.path.exists(config.sample_dir):
28 | os.makedirs(config.sample_dir)
29 | if not os.path.exists(config.result_dir):
30 | os.makedirs(config.result_dir)
31 |
32 | # dataloader
33 | dataset = TextDataset(config.datadir, 'cnn-rnn', config.image_size)
34 | filename_test = '%s/test' % (config.datadir)
35 | dataset.test = dataset.get_data(filename_test)
36 | filename_train = '%s/train' % (config.datadir)
37 | dataset.train = dataset.get_data(filename_train)
38 |
39 | # Solver for training and testing ZstGAN.
40 | solver = Solver(dataset, config)
41 |
42 | if config.mode == 'train':
43 | solver.train() # train mode for ZstGAN
44 | elif config.mode == 'test':
45 | solver.test() # test mode for ZstGAN
46 |
47 |
48 |
49 | if __name__ == '__main__':
50 | parser = argparse.ArgumentParser()
51 |
52 | # Model configuration.
53 |     parser.add_argument('--c_dim', type=int, default=200, help='number of domains (classes) in the dataset')
54 | parser.add_argument('--image_size', type=int, default=128, help='image resolution')
55 | parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
56 | parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
57 | parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
58 | parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
59 | parser.add_argument('--n_blocks', type=int, default=0, help='number of res conv layers in C')
60 |     parser.add_argument('--lambda_mut', type=float, default=10, help='weight for mutual information loss')
61 | parser.add_argument('--lambda_rec', type=float, default=1, help='weight for reconstruction loss')
62 | parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
63 |     parser.add_argument('--ft_num', type=int, default=2048, help='dimension of the domain-specific feature')
64 |     parser.add_argument('--nz_num', type=int, default=312, help='dimension of the noise vector z')
65 |     parser.add_argument('--att_num', type=int, default=1024, help='dimension of the attribute embedding')
66 |
67 | # Training configuration.
68 | parser.add_argument('--batch_size', type=int, default=8, help='mini-batch size')
69 | parser.add_argument('--num_iters', type=int, default=300000, help='number of total iterations for training D')
70 | parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
71 | parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
72 | parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
73 | parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
74 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
75 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
76 | parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
77 | parser.add_argument('--ev_ea_c_iters', type=int, default=80000, help='number of iterations for training encoder_a and encoder_v')
78 | parser.add_argument('--c_pre_iters', type=int, default=20000, help='number of iterations for pre-training C')
79 |
80 | # Test configuration.
81 | parser.add_argument('--test_iters', type=int, default=300000, help='test model from this step')
82 |
83 | # Miscellaneous.
84 | parser.add_argument('--num_workers', type=int, default=1)
85 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
86 |
87 | # Directories.
88 | parser.add_argument('--datadir', type=str, default='Data/birds')
89 | parser.add_argument('--model_dir', type=str, default='zstgan')
90 |
91 | # Step size.
92 | parser.add_argument('--log_step', type=int, default=100)
93 | parser.add_argument('--sample_step', type=int, default=2000)
94 | parser.add_argument('--model_save_step', type=int, default=20000)
95 | parser.add_argument('--lr_update_step', type=int, default=1000)
96 |
97 | config = parser.parse_args()
98 | print(config)
99 | main(config)
100 |
--------------------------------------------------------------------------------
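For reference, `main.py` creates four folders under `--model_dir`, and `trainer.py` writes checkpoints and images into them; a sketch of the resulting layout for `--model_dir flower`:

```
flower/
├── logs/
├── models/     # {iter}-encoder.ckpt, {iter}-decoder.ckpt, {iter}-D.ckpt, {iter}-C.ckpt, {iter}-encoder_a.ckpt, {iter}-encoder_v.ckpt
├── samples/    # {iter}_x_AB_results*.jpg written every --sample_step iterations during training
└── results/    # {step}_x_AB_results_test_v.jpg and {step}_x_AB_results_test_a.jpg written in test mode
```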
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import torchvision.models as models
6 | try:
7 | from itertools import izip as zip
8 | except ImportError: # will be 3.x series
9 | pass
10 |
11 | class AdaINEnc(nn.Module):
12 | # AdaIN encoder architecture
13 | def __init__(self, input_dim, ft_num):
14 | super(AdaINEnc, self).__init__()
15 |
16 | dim = 64
17 | style_dim = ft_num
18 | n_downsample = 2
19 | n_res = 16
20 | activ = 'relu'
21 | pad_type = 'reflect'
22 | mlp_dim = 256
23 |
24 | # encoder
25 | self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
26 |
27 | def forward(self, images):
28 | # reconstruct an image
29 | content = self.encode(images)
30 | return content
31 |
32 | def encode(self, images):
33 | # encode an image to its content and style codes
34 | content = self.enc_content(images)
35 | return content
36 |
37 |
38 |
39 | class AdaINDec(nn.Module):
40 | # AdaIN decoder architecture
41 | def __init__(self, input_dim, ft_num):
42 | super(AdaINDec, self).__init__()
43 |
44 | dim = 64
45 | style_dim = ft_num
46 | n_downsample = 2
47 | n_res = 16
48 | activ = 'relu'
49 | pad_type = 'reflect'
50 | mlp_dim = 256
51 |
52 |
53 |
54 | self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
55 | # decoder
56 | self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type)
57 |
58 | # MLP to generate AdaIN parameters
59 | self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ)
60 |
61 | def forward(self, content, style):
62 | # decode content and style codes to an image
63 | adain_params = self.mlp(style)
64 | self.assign_adain_params(adain_params, self.dec)
65 | images = self.dec(content)
66 | return images
67 |
68 | def assign_adain_params(self, adain_params, model):
69 | # assign the adain_params to the AdaIN layers in model
70 | for m in model.modules():
71 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
72 | mean = adain_params[:, :m.num_features]
73 | std = adain_params[:, m.num_features:2*m.num_features]
74 | m.bias = mean.contiguous().view(-1)
75 | m.weight = std.contiguous().view(-1)
76 | if adain_params.size(1) > 2*m.num_features:
77 | adain_params = adain_params[:, 2*m.num_features:]
78 |
79 | def get_num_adain_params(self, model):
80 | # return the number of AdaIN parameters needed by the model
81 | num_adain_params = 0
82 | for m in model.modules():
83 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
84 | num_adain_params += 2*m.num_features
85 | return num_adain_params
86 |
87 |
88 |
89 | class ContentEncoder(nn.Module):
90 | def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
91 | super(ContentEncoder, self).__init__()
92 | self.model = []
93 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
94 | # downsampling blocks
95 | for i in range(n_downsample):
96 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
97 | dim *= 2
98 | # residual blocks
99 | self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
100 | self.model = nn.Sequential(*self.model)
101 | self.output_dim = dim
102 |
103 | def forward(self, x):
104 | return self.model(x)
105 |
106 | class Decoder(nn.Module):
107 | def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
108 | super(Decoder, self).__init__()
109 |
110 | self.model = []
111 | # AdaIN residual blocks
112 | self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
113 | # upsampling blocks
114 | for i in range(n_upsample):
115 | self.model += [nn.Upsample(scale_factor=2),
116 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
117 | dim //= 2
118 | # use reflection padding in the last conv layer
119 | self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
120 | self.model = nn.Sequential(*self.model)
121 |
122 | def forward(self, x):
123 | return self.model(x)
124 | class ResBlocks(nn.Module):
125 | def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
126 | super(ResBlocks, self).__init__()
127 | self.model = []
128 | for i in range(num_blocks):
129 | self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
130 | self.model = nn.Sequential(*self.model)
131 |
132 | def forward(self, x):
133 | return self.model(x)
134 |
135 | class MLP(nn.Module):
136 | def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
137 |
138 | super(MLP, self).__init__()
139 | self.model = []
140 | self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
141 | for i in range(n_blk - 2):
142 | self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)]
143 | self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations
144 | self.model = nn.Sequential(*self.model)
145 |
146 | def forward(self, x):
147 | return self.model(x.view(x.size(0), -1))
148 | class ResBlock(nn.Module):
149 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
150 | super(ResBlock, self).__init__()
151 |
152 | model = []
153 |         model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
154 |         model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
155 | self.model = nn.Sequential(*model)
156 |
157 | def forward(self, x):
158 | residual = x
159 | out = self.model(x)
160 | out += residual
161 | return out
162 |
163 | class Conv2dBlock(nn.Module):
164 |     def __init__(self, input_dim, output_dim, kernel_size, stride,
165 | padding=0, norm='none', activation='relu', pad_type='zero'):
166 | super(Conv2dBlock, self).__init__()
167 | self.use_bias = True
168 | # initialize padding
169 | if pad_type == 'reflect':
170 | self.pad = nn.ReflectionPad2d(padding)
171 | elif pad_type == 'replicate':
172 | self.pad = nn.ReplicationPad2d(padding)
173 | elif pad_type == 'zero':
174 | self.pad = nn.ZeroPad2d(padding)
175 | else:
176 | assert 0, "Unsupported padding type: {}".format(pad_type)
177 |
178 | # initialize normalization
179 | norm_dim = output_dim
180 | if norm == 'bn':
181 | self.norm = nn.BatchNorm2d(norm_dim)
182 | elif norm == 'in':
183 | #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
184 | self.norm = nn.InstanceNorm2d(norm_dim)
185 | elif norm == 'ln':
186 | self.norm = LayerNorm(norm_dim)
187 | elif norm == 'adain':
188 | self.norm = AdaptiveInstanceNorm2d(norm_dim)
189 | elif norm == 'none' or norm == 'sn':
190 | self.norm = None
191 | else:
192 | assert 0, "Unsupported normalization: {}".format(norm)
193 |
194 | # initialize activation
195 | if activation == 'relu':
196 | self.activation = nn.ReLU(inplace=True)
197 | elif activation == 'lrelu':
198 | self.activation = nn.LeakyReLU(0.2, inplace=True)
199 | elif activation == 'prelu':
200 | self.activation = nn.PReLU()
201 | elif activation == 'selu':
202 | self.activation = nn.SELU(inplace=True)
203 | elif activation == 'tanh':
204 | self.activation = nn.Tanh()
205 | elif activation == 'none':
206 | self.activation = None
207 | else:
208 | assert 0, "Unsupported activation: {}".format(activation)
209 |
210 | # initialize convolution
211 | if norm == 'sn':
212 | self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
213 | else:
214 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
215 |
216 | def forward(self, x):
217 | x = self.conv(self.pad(x))
218 | if self.norm:
219 | x = self.norm(x)
220 | if self.activation:
221 | x = self.activation(x)
222 | return x
223 |
224 | class LinearBlock(nn.Module):
225 | def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
226 | super(LinearBlock, self).__init__()
227 | use_bias = True
228 | # initialize fully connected layer
229 | if norm == 'sn':
230 | self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
231 | else:
232 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
233 |
234 | # initialize normalization
235 | norm_dim = output_dim
236 | if norm == 'bn':
237 | self.norm = nn.BatchNorm1d(norm_dim)
238 | elif norm == 'in':
239 | self.norm = nn.InstanceNorm1d(norm_dim)
240 | elif norm == 'ln':
241 | self.norm = LayerNorm(norm_dim)
242 | elif norm == 'none' or norm == 'sn':
243 | self.norm = None
244 | else:
245 | assert 0, "Unsupported normalization: {}".format(norm)
246 |
247 | # initialize activation
248 | if activation == 'relu':
249 | self.activation = nn.ReLU(inplace=True)
250 | elif activation == 'lrelu':
251 | self.activation = nn.LeakyReLU(0.2, inplace=True)
252 | elif activation == 'prelu':
253 | self.activation = nn.PReLU()
254 | elif activation == 'selu':
255 | self.activation = nn.SELU(inplace=True)
256 | elif activation == 'tanh':
257 | self.activation = nn.Tanh()
258 | elif activation == 'none':
259 | self.activation = None
260 | else:
261 | assert 0, "Unsupported activation: {}".format(activation)
262 |
263 | def forward(self, x):
264 | out = self.fc(x)
265 | if self.norm:
266 | out = self.norm(out)
267 | if self.activation:
268 | out = self.activation(out)
269 | return out
270 |
271 |
272 | class AdaptiveInstanceNorm2d(nn.Module):
273 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
274 | super(AdaptiveInstanceNorm2d, self).__init__()
275 | self.num_features = num_features
276 | self.eps = eps
277 | self.momentum = momentum
278 | # weight and bias are dynamically assigned
279 | self.weight = None
280 | self.bias = None
281 | # just dummy buffers, not used
282 | self.register_buffer('running_mean', torch.zeros(num_features))
283 | self.register_buffer('running_var', torch.ones(num_features))
284 |
285 | def forward(self, x):
286 | assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
287 | b, c = x.size(0), x.size(1)
288 | running_mean = self.running_mean.repeat(b)
289 | running_var = self.running_var.repeat(b)
290 |
291 | # Apply instance norm
292 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
293 |
294 | out = F.batch_norm(
295 | x_reshaped, running_mean, running_var, self.weight, self.bias,
296 | True, self.momentum, self.eps)
297 |
298 | return out.view(b, c, *x.size()[2:])
299 |
300 | def __repr__(self):
301 | return self.__class__.__name__ + '(' + str(self.num_features) + ')'
302 |
303 |
304 | class LayerNorm(nn.Module):
305 | def __init__(self, num_features, eps=1e-5, affine=True):
306 | super(LayerNorm, self).__init__()
307 | self.num_features = num_features
308 | self.affine = affine
309 | self.eps = eps
310 |
311 | if self.affine:
312 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
313 | self.beta = nn.Parameter(torch.zeros(num_features))
314 |
315 | def forward(self, x):
316 | shape = [-1] + [1] * (x.dim() - 1)
317 | # print(x.size())
318 | if x.size(0) == 1:
319 | # These two lines run much faster in pytorch 0.4 than the two lines listed below.
320 | mean = x.view(-1).mean().view(*shape)
321 | std = x.view(-1).std().view(*shape)
322 | else:
323 | mean = x.view(x.size(0), -1).mean(1).view(*shape)
324 | std = x.view(x.size(0), -1).std(1).view(*shape)
325 |
326 | x = (x - mean) / (std + self.eps)
327 |
328 | if self.affine:
329 | shape = [1, -1] + [1] * (x.dim() - 2)
330 | x = x * self.gamma.view(*shape) + self.beta.view(*shape)
331 | return x
332 |
333 | def l2normalize(v, eps=1e-12):
334 | return v / (v.norm() + eps)
335 |
336 |
337 | class SpectralNorm(nn.Module):
338 | """
339 | Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
340 | and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
341 | """
342 | def __init__(self, module, name='weight', power_iterations=1):
343 | super(SpectralNorm, self).__init__()
344 | self.module = module
345 | self.name = name
346 | self.power_iterations = power_iterations
347 | if not self._made_params():
348 | self._make_params()
349 |
350 | def _update_u_v(self):
351 | u = getattr(self.module, self.name + "_u")
352 | v = getattr(self.module, self.name + "_v")
353 | w = getattr(self.module, self.name + "_bar")
354 |
355 | height = w.data.shape[0]
356 | for _ in range(self.power_iterations):
357 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
358 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
359 |
360 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
361 | sigma = u.dot(w.view(height, -1).mv(v))
362 | setattr(self.module, self.name, w / sigma.expand_as(w))
363 |
364 | def _made_params(self):
365 | try:
366 | u = getattr(self.module, self.name + "_u")
367 | v = getattr(self.module, self.name + "_v")
368 | w = getattr(self.module, self.name + "_bar")
369 | return True
370 | except AttributeError:
371 | return False
372 |
373 |
374 | def _make_params(self):
375 | w = getattr(self.module, self.name)
376 |
377 | height = w.data.shape[0]
378 | width = w.view(height, -1).data.shape[1]
379 |
380 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
381 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
382 | u.data = l2normalize(u.data)
383 | v.data = l2normalize(v.data)
384 | w_bar = nn.Parameter(w.data)
385 |
386 | del self.module._parameters[self.name]
387 |
388 | self.module.register_parameter(self.name + "_u", u)
389 | self.module.register_parameter(self.name + "_v", v)
390 | self.module.register_parameter(self.name + "_bar", w_bar)
391 |
392 |
393 | def forward(self, *args):
394 | self._update_u_v()
395 | return self.module.forward(*args)
396 |
397 |
398 |
399 |
400 |
401 | class Resnet_Feature(nn.Module):
402 | name = 'Resnet_Feature'
403 | def __init__(self):
404 | super(Resnet_Feature, self).__init__()
405 | res50_model = models.resnet50(pretrained=True)
406 |
407 | self.res50_conv = nn.Sequential(*list(res50_model.children())[:-2])
408 |
409 | self.avp = nn.AvgPool2d(kernel_size=4, stride=1, padding=0)
410 |
411 | def forward(self, x):
412 | output = self.avp(self.res50_conv(x))
413 | return output.view(output.size(0), output.size(1))
414 |
415 |
416 |
417 | class MLP_Encoder(nn.Module):
418 | def __init__(self, in_dim = 1024, nz_num = 312, out_dim = 2048):
419 | super(MLP_Encoder, self).__init__()
420 | self.fc1 = nn.Linear(in_dim + nz_num, 4096)
421 | self.fc2 = nn.Linear(4096, out_dim)
422 | self.lrelu = nn.LeakyReLU(0.2, True)
423 | #self.prelu = nn.PReLU()
424 | self.relu = nn.ReLU(True)
425 |
426 |
427 | def forward(self, att, noise):
428 | h = torch.cat((noise, att), 1)
429 | h = self.lrelu(self.fc1(h))
430 | h = self.relu(self.fc2(h))
431 | return h
432 |
433 |
434 |
435 | class Linear_Classifier(nn.Module):
436 | def __init__(self, in_dim= 2048, c_dim = 200):
437 | super(Linear_Classifier, self).__init__()
438 | self.fc = nn.Linear(in_dim, c_dim)
439 | #self.logic = nn.LogSoftmax(dim=1)
440 | def forward(self, x):
441 | o = self.fc(x)
442 | return o
443 |
444 |
445 |
446 | class Eb_Discriminator(nn.Module):
447 | def __init__(self, ft_num = 2048, att_num = 1024):
448 | super(Eb_Discriminator, self).__init__()
449 | self.fc1 = nn.Sequential( nn.Linear(ft_num + att_num, 4096),
450 | nn.LeakyReLU(0.2, True))
451 | #self.fc2 = nn.Linear(opt.ndh, opt.ndh)
452 | #self.fc2 = nn.Sequential(nn.Linear(4096, 1),
453 | # nn.Sigmoid())
454 | self.fc2 = nn.Linear(4096, 1)
455 |
456 |
457 | def forward(self, x, att):
458 | h = torch.cat((x, att), 1)
459 |
460 | h = self.fc1(h)
461 | h = self.fc2(h)
462 | return h
463 |
464 | class Discriminator(nn.Module):
465 | """Discriminator network with PatchGAN."""
466 | def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6, ft_num = 16):
467 | super(Discriminator, self).__init__()
468 | layers = []
469 | layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
470 | layers.append(nn.LeakyReLU(0.01))
471 |
472 | curr_dim = conv_dim
473 | for i in range(1, repeat_num):
474 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
475 | layers.append(nn.LeakyReLU(0.01))
476 | curr_dim = curr_dim * 2
477 |
478 | kernel_size = int(image_size / np.power(2, repeat_num))
479 | self.main = nn.Sequential(*layers)
480 | self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
481 |
482 | self.conv2 = nn.Conv2d(curr_dim, ft_num, kernel_size=kernel_size, bias=False)#nn.Sequential(*[nn.Conv2d(curr_dim, ft_num, kernel_size= kernel_size), nn.LeakyReLU(0.01)])#
483 |
484 | def forward(self, x):
485 | h = self.main(x)
486 | out_src = self.conv1(h)
487 | out_cls = self.conv2(h)
488 | return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))
489 |
490 |
--------------------------------------------------------------------------------
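A minimal shape-check sketch (not from the repository) showing how the networks defined in `model.py` are combined in `trainer.py`: the translator is `decoder(encoder(x), style)`, where the style code comes from `encoder_v` (images) or `encoder_a` (attributes plus noise), and `D` returns a realism map together with a `ft_num`-dimensional code used for the mutual-information loss. `ft_num=8` is used here only to keep the check light; the training scripts use 2048.

```python
# Shape check for the translator and discriminator from model.py (illustrative only).
import torch
from model import AdaINEnc, AdaINDec, Discriminator

ft_num = 8                                            # small for a quick check; the scripts use 2048
enc = AdaINEnc(input_dim=3, ft_num=ft_num)            # content encoder
dec = AdaINDec(input_dim=3, ft_num=ft_num)            # AdaIN decoder conditioned on a style code
D = Discriminator(image_size=128, conv_dim=64, c_dim=102, repeat_num=6, ft_num=ft_num)

x = torch.randn(1, 3, 128, 128)                       # one 128x128 RGB image in [-1, 1]
style = torch.randn(1, ft_num)                        # stands in for encoder_v(x) or encoder_a(att, z)

content = enc(x)                                      # (1, 256, 32, 32) content code
x_fake = dec(content, style)                          # (1, 3, 128, 128) translated image
out_src, out_cls = D(x_fake)                          # realism map and ft_num-dim embedding
print(content.shape, x_fake.shape, out_src.shape, out_cls.shape)
```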
/trainer.py:
--------------------------------------------------------------------------------
1 | from model import *
2 | from torch.autograd import Variable
3 | from torchvision.utils import save_image
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import os
8 | import time
9 | import datetime
10 | import itertools
11 |
12 |
13 |
14 | def accuracy(output, target, topk=(1,)):
15 | """Computes the precision@k for the specified values of k"""
16 |     if len(output[0]) < max(topk):
17 |         topk = (1, len(output[0]))
18 | maxk = max(topk)
19 | batch_size = target.size(0)
20 |
21 | _, pred = output.topk(maxk, 1, True, True)
22 | pred = pred.t()
23 | correct = pred.eq(target.view(1, -1).expand_as(pred))
24 |
25 | res = []
26 | for k in topk:
27 |         correct_k = correct[:k].reshape(-1).float().sum(0)
28 | res.append(correct_k.mul_(100.0 / batch_size))
29 | return res
30 |
31 | class Solver(object):
32 | """Solver for training and testing zstgan."""
33 |
34 | def __init__(self, data_loader, config):
35 | """Initialize configurations."""
36 |
37 | # Data loader.
38 | self.data_loader = data_loader
39 |
40 | # Model configurations.
41 | self.ft_num = config.ft_num
42 | self.nz_num = config.nz_num
43 | self.c_dim = config.c_dim
44 | self.image_size = config.image_size
45 | self.g_conv_dim = config.g_conv_dim
46 | self.d_conv_dim = config.d_conv_dim
47 | self.g_repeat_num = config.g_repeat_num
48 | self.d_repeat_num = config.d_repeat_num
49 | self.n_blocks = config.n_blocks
50 | self.lambda_mut = config.lambda_mut
51 | self.lambda_rec = config.lambda_rec
52 | self.lambda_gp = config.lambda_gp
53 | self.att_num = config.att_num
54 |
55 | # Training configurations.
56 | self.batch_size = config.batch_size
57 | self.num_iters = config.num_iters
58 | self.num_iters_decay = config.num_iters_decay
59 | self.g_lr = config.g_lr
60 | self.d_lr = config.d_lr
61 | self.n_critic = config.n_critic
62 | self.beta1 = config.beta1
63 | self.beta2 = config.beta2
64 | self.resume_iters = config.resume_iters
65 | self.ev_ea_c_iters = config.ev_ea_c_iters
66 | self.c_pre_iters = config.c_pre_iters
67 |
68 | # Test configurations.
69 | self.test_iters = config.test_iters
70 |
71 | # Miscellaneous.
72 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
73 |
74 | # Directories.
75 | self.log_dir = config.log_dir
76 | self.sample_dir = config.sample_dir
77 | self.model_save_dir = config.model_save_dir
78 | self.result_dir = config.result_dir
79 |
80 | # Step size.
81 | self.log_step = config.log_step
82 | self.sample_step = config.sample_step
83 | self.model_save_step = config.model_save_step
84 | self.lr_update_step = config.lr_update_step
85 |
86 | # Build the model
87 | self.build_model()
88 |
89 | def build_model(self):
90 | """Create networks."""
91 |
92 | self.encoder = AdaINEnc(input_dim = 3, ft_num = self.ft_num)
93 | self.decoder = AdaINDec(input_dim = 3, ft_num = self.ft_num)
94 | self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num, ft_num = self.ft_num)
95 | self.encoder_v = Resnet_Feature()
96 | self.encoder_a = MLP_Encoder(in_dim = self.att_num, nz_num = self.nz_num, out_dim= self.ft_num)
97 | self.D_s = Eb_Discriminator(ft_num = self.ft_num, att_num = self.att_num)
98 | self.C = Linear_Classifier(in_dim= self.ft_num, c_dim = self.c_dim)
99 |
100 | self.g_optimizer = torch.optim.Adam(itertools.chain(self.encoder.parameters(), self.decoder.parameters()), self.g_lr, [self.beta1, self.beta2])
101 | self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
102 | self.ev_optimizer = torch.optim.Adam(itertools.chain(self.encoder_v.parameters(), self.C.parameters()), self.d_lr, [self.beta1, self.beta2]) # use the same optimizer to update encoder_v and C
103 | self.ea_optimizer = torch.optim.Adam(self.encoder_a.parameters(), self.d_lr, [self.beta1, self.beta2])
104 | self.ds_optimizer = torch.optim.Adam(self.D_s.parameters(), self.d_lr, [self.beta1, self.beta2])
105 | self.c_optimizer = torch.optim.Adam(self.C.parameters(), self.d_lr, [self.beta1, self.beta2])
106 |
107 | self.encoder.to(self.device)
108 | self.decoder.to(self.device)
109 | self.D.to(self.device)
110 | self.encoder_v.to(self.device)
111 | self.encoder_a.to(self.device)
112 | self.D_s.to(self.device)
113 | self.C.to(self.device)
114 |
115 |
116 |
117 |
118 | def print_network(self, model, name):
119 | """Print out the network information."""
120 | num_params = 0
121 | for p in model.parameters():
122 | num_params += p.numel()
123 | print(model)
124 | print(name)
125 | print("The number of parameters: {}".format(num_params))
126 |
127 | def restore_model(self, resume_iters):
128 | """Restore the trained networks."""
129 |
130 | print('Loading the trained models from step {}...'.format(resume_iters))
131 | encoder_path = os.path.join(self.model_save_dir, '{}-encoder.ckpt'.format(resume_iters))
132 | decoder_path = os.path.join(self.model_save_dir, '{}-decoder.ckpt'.format(resume_iters))
133 | D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
134 | self.encoder.load_state_dict(torch.load(encoder_path, map_location=lambda storage, loc: storage))
135 | self.decoder.load_state_dict(torch.load(decoder_path, map_location=lambda storage, loc: storage))
136 | self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
137 |
138 | def update_lr(self, g_lr, d_lr):
139 | """Decay learning rates of the generator and discriminator."""
140 | for param_group in self.g_optimizer.param_groups:
141 | param_group['lr'] = g_lr
142 | for param_group in self.d_optimizer.param_groups:
143 | param_group['lr'] = d_lr
144 |
145 | def reset_grad(self):
146 | """Reset the gradient buffers."""
147 | self.g_optimizer.zero_grad()
148 | self.d_optimizer.zero_grad()
149 |
150 |
151 | def denorm(self, x):
152 | """Convert the range from [-1, 1] to [0, 1]."""
153 | out = (x + 1) / 2
154 | return out.clamp_(0, 1)
155 |
156 | def gradient_penalty(self, y, x):
157 | """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
158 | weight = torch.ones(y.size()).to(self.device)
159 | dydx = torch.autograd.grad(outputs=y,
160 | inputs=x,
161 | grad_outputs=weight,
162 | retain_graph=True,
163 | create_graph=True,
164 | only_inputs=True)[0]
165 |
166 | dydx = dydx.view(dydx.size(0), -1)
167 | dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
168 | return torch.mean((dydx_l2norm-1)**2)
169 |
170 | def label2onehot(self, labels, dim):
171 | """Convert label indices to one-hot vectors."""
172 | batch_size = labels.size(0)
173 | out = torch.zeros(batch_size, dim)
174 | out[np.arange(batch_size), labels.long()] = 1
175 | return out
176 |
177 |
178 | def classification_loss(self, logit, target):
179 | """Compute softmax cross entropy loss."""
180 | return F.cross_entropy(logit, target)
181 |
182 |
183 | def train_ev_ea(self):
184 | """Train encoder_a and encoder_v with C and D_s."""
185 | # Set data loader.
186 | data_loader = self.data_loader
187 |
188 | noise = torch.FloatTensor(self.batch_size, self.nz_num)
189 | noise = noise.to(self.device) # noise vector z
190 |
191 | start_iters = 0
192 |
193 | # Start training.
194 | print('Start encoder_a and encoder_v training...')
195 | start_time = time.time()
196 |
197 | ev_ea_c_iters = self.ev_ea_c_iters
198 | c_pre_iters = self.c_pre_iters
199 |
200 | C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(ev_ea_c_iters))
201 |
202 | encoder_a_path = os.path.join(self.model_save_dir, '{}-encoder_a.ckpt'.format(ev_ea_c_iters))
203 |
204 | encoder_v_path = os.path.join(self.model_save_dir, '{}-encoder_v.ckpt'.format(ev_ea_c_iters))
205 |
206 |
207 | if os.path.exists(C_path):
208 | self.C.load_state_dict(torch.load(C_path, map_location=lambda storage, loc: storage))
209 | print('Load model checkpoints from {}'.format(C_path))
210 |
211 | self.encoder_a.load_state_dict(torch.load(encoder_a_path, map_location=lambda storage, loc: storage))
212 | print('Load model checkpoints from {}'.format(encoder_a_path))
213 |
214 | self.encoder_v.load_state_dict(torch.load(encoder_v_path, map_location=lambda storage, loc: storage))
215 | print('Load model checkpoints from {}'.format(encoder_v_path))
216 | else:
217 | C_pre_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(c_pre_iters))
218 | if os.path.exists(C_pre_path):
219 | self.C.load_state_dict(torch.load(C_pre_path, map_location=lambda storage, loc: storage))
220 | print('Load model pretrained checkpoints from {}'.format(C_pre_path))
221 | else:
222 | for i in range(0, c_pre_iters):
223 | # Fetch real images, attributes and labels.
224 | x_real, wrong_images, attributes, _, label_org = data_loader.train.next_batch(self.batch_size,10)
225 |
226 |
227 | x_real = x_real.to(self.device) # Input images.
228 | attributes = attributes.to(self.device) # Input attributes
229 | label_org = label_org.to(self.device) # Labels for computing classification loss.
230 |
231 | ev_x = self.encoder_v(x_real)
232 | cls_x = self.C(ev_x.detach())
233 | # Classification loss from only images for C training
234 | c_loss_cls = self.classification_loss(cls_x, label_org)
235 | # Backward and optimize.
236 | self.c_optimizer.zero_grad()
237 | c_loss_cls.backward()
238 | self.c_optimizer.step()
239 |
240 | if (i+1) % self.log_step == 0:
241 | loss = {}
242 | loss['c_loss_cls'] = c_loss_cls.item()
243 | prec1, prec5 = accuracy(cls_x.data, label_org.data, topk=(1, 5))
244 | loss['prec1'] = prec1
245 | loss['prec5'] = prec5
246 | log = "C pretraining iteration [{}/{}]".format(i+1, c_pre_iters)
247 | for tag, value in loss.items():
248 | log += ", {}: {:.4f}".format(tag, value)
249 | print(log)
250 | torch.save(self.C.state_dict(), C_pre_path)
251 | print('Saved model pretrained checkpoints into {}...'.format(C_pre_path))
252 |
253 | for i in range(c_pre_iters, ev_ea_c_iters):
254 | # Fetch real images, attributes and labels.
255 | x_real, wrong_images, attributes, _, label_org = data_loader.train.next_batch(self.batch_size,10)
256 |
257 |
258 | x_real = x_real.to(self.device) # Input images.
259 | attributes = attributes.to(self.device) # Input attributes
260 | label_org = label_org.to(self.device) # Labels for computing classification loss.
261 |
262 |
263 | # =================================================================================== #
264 | # Train the domain-specific features discriminator
265 | # =================================================================================== #
266 |
267 | noise.normal_(0, 1)
268 | # Compute embedding of both images and attributes
269 | ea_a = self.encoder_a(attributes, noise)
270 | ev_x = self.encoder_v(x_real)
271 |
272 |
273 | ev_x_real = self.D_s(ev_x, attributes)
274 | ds_loss_real = -torch.mean(ev_x_real)
275 |
276 |
277 | ea_a_fake = self.D_s(ea_a, attributes)
278 | ds_loss_fake = torch.mean(ea_a_fake)
279 |
280 | # Compute loss for gradient penalty.
281 | alpha = torch.rand(ev_x.size(0), 1).to(self.device)
282 | ebd_hat = (alpha * ev_x.data + (1 - alpha) * ea_a.data).requires_grad_(True)
283 |
284 | ebd_inter = self.D_s(ebd_hat, attributes)
285 | ds_loss_gp = self.gradient_penalty(ebd_inter, ebd_hat)
286 |
287 | ds_loss = ds_loss_real + ds_loss_fake + self.lambda_gp * ds_loss_gp #+ ds_loss_realw
288 | #self.reset_grad_eb()
289 | self.ea_optimizer.zero_grad()
290 | self.ds_optimizer.zero_grad()
291 | self.ev_optimizer.zero_grad()
292 |
293 | ds_loss.backward()
294 | self.ds_optimizer.step()
295 | if (i+1) % self.n_critic == 0:
296 | # =================================================================================== #
297 | # Train the encoder_a and C
298 | # =================================================================================== #
299 | ev_x = self.encoder_v(x_real)
300 | ev_x_real = self.D_s(ev_x, attributes)
301 | ev_loss_real = torch.mean(ev_x_real)
302 |
303 | cls_x = self.C(ev_x)
304 | c_loss_cls = self.classification_loss(cls_x, label_org)
305 |
306 | # Backward and optimize.
307 | ev_c_loss = ev_loss_real + c_loss_cls
308 | self.ea_optimizer.zero_grad()
309 | self.ds_optimizer.zero_grad()
310 | self.ev_optimizer.zero_grad()
311 | ev_c_loss.backward()
312 | self.ev_optimizer.step()
313 |
314 | # =================================================================================== #
315 | # Train the encoder_v #
316 | # =================================================================================== #
317 | noise.normal_(0, 1)
318 | ea_a = self.encoder_a(attributes,noise)
319 | ea_a_fake = self.D_s(ea_a, attributes)
320 | ea_loss_fake = -torch.mean(ea_a_fake)
321 |
322 | cls_a = self.C(ea_a)
323 | ebn_loss_cls = self.classification_loss(cls_a, label_org)
324 |
325 |
326 | # Backward and optimize.
327 | ea_loss = ea_loss_fake + ebn_loss_cls
328 | self.ea_optimizer.zero_grad()
329 | self.ds_optimizer.zero_grad()
330 | self.ev_optimizer.zero_grad()
331 | ea_loss.backward()
332 | self.ea_optimizer.step()
333 |
334 | # Logging.
335 | loss = {}
336 |
337 | loss['ds/ds_loss_real'] = ds_loss_real.item()
338 | loss['ds/ds_loss_fake'] = ds_loss_fake.item()
339 | loss['ds/ds_loss_gp'] = ds_loss_gp.item()
340 |
341 | # Print out training information.
342 | if (i+1) % self.log_step == 0:
343 | et = time.time() - start_time
344 | et = str(datetime.timedelta(seconds=et))[:-7]
345 | prec1, prec5 = accuracy(cls_x.data, label_org.data, topk=(1, 5))
346 | loss['prec1'] = prec1
347 | loss['prec5'] = prec5
348 | prec1e, prec5e = accuracy(cls_a.data, label_org.data, topk=(1, 5))
349 | loss['prec1e'] = prec1e
350 | loss['prec5e'] = prec5e
351 | log = "Encoder_a and Encoder_v Training Elapsed [{}], Iteration [{}/{}]".format(et, i+1, ev_ea_c_iters)
352 | for tag, value in loss.items():
353 | log += ", {}: {:.4f}".format(tag, value)
354 | print(log)
355 |
356 |
357 | # Save model checkpoints.
358 | if (i+1) % self.model_save_step == 0:
359 | C_path = os.path.join(self.model_save_dir, '{}-C.ckpt'.format(i+1))
360 | torch.save(self.C.state_dict(), C_path)
361 | print('Saved model checkpoints into {}...'.format(C_path))
362 |
363 | encoder_a_path = os.path.join(self.model_save_dir, '{}-encoder_a.ckpt'.format(i+1))
364 | torch.save(self.encoder_a.state_dict(), encoder_a_path)
365 | print('Saved model checkpoints into {}...'.format(encoder_a_path))
366 |
367 | encoder_v_path = os.path.join(self.model_save_dir, '{}-encoder_v.ckpt'.format(i+1))
368 | torch.save(self.encoder_v.state_dict(), encoder_v_path)
369 | print('Saved model checkpoints into {}...'.format(encoder_v_path))
370 |
371 | def train(self):
372 | """Train zstgan"""
373 | # train encoder_a and encoder_v first
374 | self.train_ev_ea()
375 | self.encoder_v.eval()
376 |
377 | # Set data loader.
378 | data_loader = self.data_loader
379 |
380 | # Learning rate cache for decaying.
381 | g_lr = self.g_lr
382 | d_lr = self.d_lr
383 |
384 | # noise vector z
385 | noise = torch.FloatTensor(self.batch_size, self.nz_num)
386 | noise = noise.to(self.device)
387 |
388 | # Start training from scratch or resume training.
389 | start_iters = 0
390 | if self.resume_iters:
391 | start_iters = self.resume_iters
392 | self.restore_model(self.resume_iters)
393 |
394 | # Start training.
395 | print('Start training...')
396 | start_time = time.time()
397 | empty = torch.FloatTensor(1, 3,self.image_size,self.image_size).to(self.device)
398 | empty.fill_(1)
399 | for i in range(start_iters, self.num_iters):
400 | # Fetch real images and labels.
401 | x_real, wrong_images, attributes, _, label_org = data_loader.train.next_batch(self.batch_size,10)
402 | label_org = label_org.to(self.device)
403 | attributes = attributes.to(self.device)
404 | x_real = x_real.to(self.device)
405 | # Generate target domains
406 | ev_x = self.encoder_v(x_real)
407 |
408 | rand_idx = torch.randperm(label_org.size(0))
409 |
410 | trg_ev_x_1 = ev_x[rand_idx]
411 | trg_ev_x = trg_ev_x_1.clone()
412 | label_trg_1 = label_org[rand_idx]
413 | label_trg = label_trg_1.clone()
414 |
415 | # =================================================================================== #
416 | # Train the discriminator
417 | # =================================================================================== #
418 |
419 | # Compute loss with real images.
420 | out_src, out_cls = self.D(x_real)
421 | d_loss_real = - torch.mean(out_src)
422 |             d_loss_mut = torch.mean(torch.abs(ev_x.detach() - out_cls))  # D's second head regresses the visual domain embedding (mutual-information term)
423 |
424 | # Compute loss with fake images.
425 | x_fake = self.decoder(self.encoder(x_real), trg_ev_x)
426 | out_src, out_cls = self.D(x_fake.detach())
427 | d_loss_fake = torch.mean(out_src)
428 |
429 | # Compute loss for gradient penalty.
430 | alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
431 | x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
432 | out_src, _ = self.D(x_hat)
433 | d_loss_gp = self.gradient_penalty(out_src, x_hat)
434 |
435 | # Backward and optimize.
436 | d_loss = d_loss_real + d_loss_fake + self.lambda_mut * d_loss_mut + self.lambda_gp * d_loss_gp
437 | self.reset_grad()
438 | d_loss.backward()
439 | self.d_optimizer.step()
440 |
441 | # Logging.
442 | loss = {}
443 | loss['D/loss_real'] = d_loss_real.item()
444 | loss['D/loss_fake'] = d_loss_fake.item()
445 | loss['D/loss_mut'] = d_loss_mut.item()
446 | loss['D/loss_gp'] = d_loss_gp.item()
447 |
448 | # =================================================================================== #
449 | # Train the encoder and decoder
450 | # =================================================================================== #
451 |
452 | if (i+1) % self.n_critic == 0:
453 | # Original-to-target domain.
454 | x_di = self.encoder(x_real)
455 |
456 | x_fake = self.decoder(x_di, trg_ev_x)
457 | x_reconst1 = self.decoder(x_di, ev_x)
458 | out_src, out_cls = self.D(x_fake)
459 | g_loss_fake = - torch.mean(out_src)
460 | g_loss_mut = torch.mean(torch.abs(trg_ev_x.detach() - out_cls))
461 |
462 | # Target-to-original domain.
463 | x_fake_di = self.encoder(x_fake)
464 |
465 | x_reconst2 = self.decoder(x_fake_di, ev_x)
466 |
467 | g_loss_rec1 = torch.mean(torch.abs(x_real - x_reconst1))
468 |
469 | g_loss_rec12 = torch.mean(torch.abs(x_real - x_reconst2))
470 |
471 | # Backward and optimize.
472 | g_loss = g_loss_fake + self.lambda_rec * (g_loss_rec1 + g_loss_rec12) + self.lambda_mut * g_loss_mut
473 | self.reset_grad()
474 | g_loss.backward()
475 | self.g_optimizer.step()
476 |
477 | # Logging.
478 | loss['G/loss_fake'] = g_loss_fake.item()
479 | loss['G/loss_rec1'] = g_loss_rec1.item()
480 | loss['G/loss_rec2'] = g_loss_rec12.item()
481 | loss['G/loss_mut'] = g_loss_mut.item()
482 |
483 | # Print out training information.
484 | if (i+1) % self.log_step == 0:
485 | et = time.time() - start_time
486 | et = str(datetime.timedelta(seconds=et))[:-7]
487 | log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
488 | for tag, value in loss.items():
489 | log += ", {}: {:.4f}".format(tag, value)
490 | print(log)
491 |
492 | # Translate fixed images for debugging.
493 | if (i+1) % self.sample_step == 0:
494 | with torch.no_grad():
495 | out_A2B_results = [empty]
496 |
497 | for idx1 in range(label_org.size(0)):
498 | out_A2B_results.append(x_real[idx1:idx1+1])
499 |
500 | for idx2 in range(label_org.size(0)):
501 | out_A2B_results.append(x_real[idx2:idx2+1])
502 |
503 | for idx1 in range(label_org.size(0)):
504 | x_fake = self.decoder(self.encoder(x_real[idx2:idx2+1]), ev_x[idx1:idx1+1])
505 | out_A2B_results.append(x_fake)
506 | results_concat = torch.cat(out_A2B_results)
507 | x_AB_results_path = os.path.join(self.sample_dir, '{}_x_AB_results.jpg'.format(i+1))
508 | save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0)+1,padding=0)
509 | print('Saved real and fake images into {}...'.format(x_AB_results_path))
510 | # save vision-driven and attribute-driven results on unseen domains
511 | x_real, wrong_images, attributes, _, label_org = data_loader.test.next_batch(self.batch_size,10)
512 | label_org = label_org.to(self.device)
513 | x_real = x_real.to(self.device)
514 | attributes = attributes.to(self.device)
515 | ev_x = self.encoder_v(x_real)
516 | noise.normal_(0, 1)
517 | ea_a = self.encoder_a(attributes, noise)
518 |
519 | out_A2B_results = [empty]
520 | out_A2B_results_a = [empty]
521 |
522 | for idx1 in range(label_org.size(0)):
523 | out_A2B_results.append(x_real[idx1:idx1+1])
524 | out_A2B_results_a.append(x_real[idx1:idx1+1])
525 |
526 | for idx2 in range(label_org.size(0)):
527 | out_A2B_results.append(x_real[idx2:idx2+1])
528 | out_A2B_results_a.append(x_real[idx2:idx2+1])
529 |
530 | for idx1 in range(label_org.size(0)):
531 | x_fake = self.decoder(self.encoder(x_real[idx2:idx2+1]), ev_x[idx1:idx1+1])
532 | out_A2B_results.append(x_fake)
533 |
534 | x_fake_a = self.decoder(self.encoder(x_real[idx2:idx2+1]), ea_a[idx1:idx1+1])
535 | out_A2B_results_a.append(x_fake_a)
536 | results_concat = torch.cat(out_A2B_results)
537 | x_AB_results_path = os.path.join(self.sample_dir, '{}_x_AB_results_test_v.jpg'.format(i+1))
538 | save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0)+1,padding=0)
539 | print('Saved real and fake images into {}...'.format(x_AB_results_path))
540 |
541 | results_concat = torch.cat(out_A2B_results_a)
542 | x_AB_results_path = os.path.join(self.sample_dir, '{}_x_AB_results_test_a.jpg'.format(i+1))
543 | save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0)+1,padding=0)
544 | print('Saved real and fake images into {}...'.format(x_AB_results_path))
545 |
546 |
547 |
548 |
549 | # Save model checkpoints.
550 | if (i+1) % self.model_save_step == 0:
551 | encoder_path = os.path.join(self.model_save_dir, '{}-encoder.ckpt'.format(i+1))
552 | decoder_path = os.path.join(self.model_save_dir, '{}-decoder.ckpt'.format(i+1))
553 | D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
554 | torch.save(self.encoder.state_dict(), encoder_path)
555 | torch.save(self.decoder.state_dict(), decoder_path)
556 | torch.save(self.D.state_dict(), D_path)
557 | print('Saved model checkpoints into {}...'.format(self.model_save_dir))
558 |
559 |
560 |
561 | # Decay learning rates.
562 | if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
563 | g_lr -= (self.g_lr / float(self.num_iters_decay))
564 | d_lr -= (self.d_lr / float(self.num_iters_decay))
565 | self.update_lr(g_lr, d_lr)
566 | print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
567 |
568 |
569 |
570 | def test(self):
571 | """Translate images using zstgan on unseen test set."""
572 | # Load the trained models.
573 | self.train_ev_ea()
574 | self.restore_model(self.test_iters)
575 | self.encoder_v.eval()
576 | # Set data loader.
577 | data_loader = self.data_loader
578 | empty = torch.FloatTensor(1, 3,self.image_size,self.image_size).to(self.device)
579 | empty.fill_(1)
580 | noise = torch.FloatTensor(self.batch_size, self.nz_num)
581 | noise = noise.to(self.device)
582 | step = 0
583 | data_loader.test.reinitialize_index()
584 | with torch.no_grad():
585 | while True:
586 | try:
587 | x_real, wrong_images, attributes, _, label_org = data_loader.test.next_batch_test(self.batch_size,10)
588 |                 except ValueError:  # next_batch_test returns an empty list once the test set is exhausted
589 | break
590 | x_real = x_real.to(self.device)
591 | label_org = label_org.to(self.device)
592 | attributes = attributes.to(self.device)
593 |
594 |
595 | ev_x = self.encoder_v(x_real)
596 | noise.normal_(0, 1)
597 | ea_a = self.encoder_a(attributes, noise)
598 |
599 | out_A2B_results = [empty]
600 | out_A2B_results_a = [empty]
601 |
602 | for idx1 in range(label_org.size(0)):
603 | out_A2B_results.append(x_real[idx1:idx1+1])
604 | out_A2B_results_a.append(x_real[idx1:idx1+1])
605 |
606 | for idx2 in range(label_org.size(0)):
607 | out_A2B_results.append(x_real[idx2:idx2+1])
608 | out_A2B_results_a.append(x_real[idx2:idx2+1])
609 |
610 | for idx1 in range(label_org.size(0)):
611 | x_fake = self.decoder(self.encoder(x_real[idx2:idx2+1]), ev_x[idx1:idx1+1])
612 | out_A2B_results.append(x_fake)
613 |
614 | x_fake_a = self.decoder(self.encoder(x_real[idx2:idx2+1]), ea_a[idx1:idx1+1])
615 | out_A2B_results_a.append(x_fake_a)
616 | results_concat = torch.cat(out_A2B_results)
617 | x_AB_results_path = os.path.join(self.result_dir, '{}_x_AB_results_test_v.jpg'.format(step+1))
618 | save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0)+1,padding=0)
619 | print('Saved real and fake images into {}...'.format(x_AB_results_path))
620 |
621 | results_concat = torch.cat(out_A2B_results_a)
622 | x_AB_results_path = os.path.join(self.result_dir, '{}_x_AB_results_test_a.jpg'.format(step+1))
623 | save_image(self.denorm(results_concat.data.cpu()), x_AB_results_path, nrow=label_org.size(0)+1,padding=0)
624 | print('Saved real and fake images into {}...'.format(x_AB_results_path))
625 |
626 | step += 1
627 |
628 |
629 |
630 |
631 |
632 |
--------------------------------------------------------------------------------
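Both adversarial games in `trainer.py` (`D_s` on domain embeddings in `train_ev_ea`, and `D` on images in `train`) use the same WGAN-GP critic objective; the image discriminator additionally carries the L1 mutual-information term weighted by `--lambda_mut`. A compact restatement of the critic loss, with `critic` standing in for either discriminator:

```python
# Sketch of the WGAN-GP critic loss used throughout trainer.py:
#   L_D = -E[D(real)] + E[D(fake)] + lambda_gp * E[(||grad_x D(x_hat)||_2 - 1)^2]
# where x_hat interpolates between real and fake samples.
import torch

def wgan_gp_critic_loss(critic, real, fake, lambda_gp=10.0):
    d_real = critic(real).mean()
    d_fake = critic(fake.detach()).mean()
    # Random interpolation coefficients, broadcast across all non-batch dimensions.
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    grad = torch.autograd.grad(critic(x_hat).sum(), x_hat, create_graph=True)[0]
    gp = ((grad.view(grad.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
    return -d_real + d_fake + lambda_gp * gp
```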