├── .gitignore
├── LICENSE
├── README.md
├── TripleGAN.py
├── assests
│   ├── accuracy.png
│   ├── algorithm.JPG
│   ├── classification_result.JPG
│   ├── generated_image
│   │   ├── all_class.png
│   │   ├── class_0.png
│   │   ├── class_1.png
│   │   ├── class_2.png
│   │   ├── class_3.png
│   │   ├── class_4.png
│   │   ├── class_5.png
│   │   ├── class_6.png
│   │   ├── class_7.png
│   │   ├── class_8.png
│   │   ├── class_9.png
│   │   └── style_by_style.png
│   ├── loss.png
│   ├── network.JPG
│   ├── result.JPG
│   └── result2.JPG
├── cifar10.py
├── main.py
├── ops.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | 
49 | # Translations
50 | *.mo
51 | *.pot
52 | 
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 | 
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 | 
61 | # Scrapy stuff:
62 | .scrapy
63 | 
64 | # Sphinx documentation
65 | docs/_build/
66 | 
67 | # PyBuilder
68 | target/
69 | 
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 | 
73 | # pyenv
74 | .python-version
75 | 
76 | # celery beat schedule file
77 | celerybeat-schedule
78 | 
79 | # SageMath parsed files
80 | *.sage.py
81 | 
82 | # dotenv
83 | .env
84 | 
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 | 
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 | 
94 | # Rope project settings
95 | .ropeproject
96 | 
97 | # mkdocs documentation
98 | /site
99 | 
100 | # mypy
101 | .mypy_cache/
102 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Junho Kim (1993.01.12)
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TripleGAN-Tensorflow
2 | Simple TensorFlow implementation of [Triple Generative Adversarial Nets](https://arxiv.org/pdf/1703.02291.pdf) (Triple-GAN)
3 | 
4 | If you want to see the original author's code, please refer to this [link](https://github.com/zhenxuan00/triple-gan)
5 | 
6 | ## Issue
7 | * I am still working on the ***weight normalization*** (if you know how to implement it in TensorFlow, please let me know)
8 | 
9 | ## Usage
10 | ```bash
11 | > python main.py --n 4000 --epoch 1000 --batch_size 20 --unlabel_batch_size 250 --z_dim 100
12 | ```
13 | * See `main.py` for the other arguments.
14 | 
15 | ## Idea
16 | ### Network Architecture
17 | ![network](./assests/network.JPG)
18 | 
19 | ### Algorithm
20 | ![algorithm](./assests/algorithm.JPG)
21 | 
22 | ## Results (from the paper)
23 | ### Classification result
24 | ![c_result](./assests/result.JPG)
25 | 
26 | ### Convergence speed on SVHN
27 | ![s_result](./assests/result2.JPG)
28 | 
29 | ## My results (CIFAR-10, 4,000 labelled images)
30 | ### Loss
31 | ![loss](./assests/loss.png)
32 | 
33 | ### Classification accuracy
34 | ![accuracy](./assests/accuracy.png)
35 | 
36 | ### Generated images (other classes are in the `assests` folder)
37 | #### Automobile
38 | ![automobile](./assests/generated_image/class_1.png)
39 | 
40 | ## Related works
41 | * [CycleGAN](https://github.com/taki0112/CycleGAN-Tensorflow)
42 | * [DiscoGAN](https://github.com/taki0112/DiscoGAN-Tensorflow)
43 | 
44 | ## Reference
45 | * [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections)
46 | 
47 | ## Author
48 | Junho Kim
--------------------------------------------------------------------------------
/TripleGAN.py:
--------------------------------------------------------------------------------
1 | import cifar10
2 | from ops import *
3 | from utils import *
4 | import time
5 | 
6 | class TripleGAN(object) :
7 |     def __init__(self, sess, epoch, batch_size, unlabel_batch_size, z_dim, dataset_name, n, gan_lr, cla_lr, checkpoint_dir, result_dir, log_dir):
8 |         self.sess = sess
9 |         self.dataset_name = dataset_name
10 |         self.checkpoint_dir = checkpoint_dir
11 |         self.result_dir = result_dir
12 |         self.log_dir = log_dir
13 |         self.epoch = epoch
14 |         self.batch_size = batch_size
15 |         self.unlabelled_batch_size = unlabel_batch_size
16 |         self.test_batch_size = 1000
17 |         self.model_name = "TripleGAN" # name for checkpoint
18 |         if self.dataset_name == 'cifar10' :
19 |             self.input_height = 32
20 |             self.input_width = 32
21 |             self.output_height = 32
22 |             self.output_width = 32
23 | 
24 |             self.z_dim = z_dim
25 |             self.y_dim = 10
26 |             self.c_dim = 3
27 | 
28 |             self.learning_rate = gan_lr # 3e-4, 1e-3
29 |             self.cla_learning_rate = cla_lr # 3e-3, 1e-2 ?
30 |             self.GAN_beta1 = 0.5
31 |             self.beta1 = 0.9
32 |             self.beta2 = 0.999
33 |             self.epsilon = 1e-8
34 |             self.alpha = 0.5
35 |             self.alpha_cla_adv = 0.01
36 |             self.init_alpha_p = 0.0 # 0.1, 0.03
37 |             self.apply_alpha_p = 0.1
38 |             self.apply_epoch = 200 # 200, 300
39 |             self.decay_epoch = 50
40 | 
41 |             self.sample_num = 64
42 |             self.visual_num = 100
43 |             self.len_discrete_code = 10
44 | 
45 |             self.data_X, self.data_y, self.unlabelled_X, self.unlabelled_y, self.test_X, self.test_y = cifar10.prepare_data(n) # labelled X/y, unlabelled X/y, test X/y
46 | 
47 |             self.num_batches = len(self.data_X) // self.batch_size
48 | 
49 |         else :
50 |             raise NotImplementedError
51 | 
52 |     def discriminator(self, x, y_, scope='discriminator', is_training=True, reuse=False):
53 |         with tf.variable_scope(scope, reuse=reuse) :
54 |             x = dropout(x, rate=0.2, is_training=is_training)
55 |             y = tf.reshape(y_, [-1, 1, 1, self.y_dim])
56 |             x = conv_concat(x,y)
57 | 
58 |             x = lrelu(conv_layer(x, filter_size=32, kernel=[3,3], layer_name=scope+'_conv1'))
59 |             x = conv_concat(x,y)
60 |             x = lrelu(conv_layer(x, filter_size=32, kernel=[3,3], stride=2, layer_name=scope+'_conv2'))
61 |             x = dropout(x, rate=0.2, is_training=is_training)
62 |             x = conv_concat(x,y)
63 | 
64 |             x = lrelu(conv_layer(x, filter_size=64, kernel=[3,3], layer_name=scope+'_conv3'))
65 |             x = conv_concat(x,y)
66 |             x = lrelu(conv_layer(x, filter_size=64, kernel=[3,3], stride=2, layer_name=scope+'_conv4'))
67 |             x = dropout(x, rate=0.2, is_training=is_training)
68 |             x = conv_concat(x,y)
69 | 
70 |             x = lrelu(conv_layer(x, filter_size=128, kernel=[3,3], layer_name=scope+'_conv5'))
71 |             x = conv_concat(x,y)
72 |             x = lrelu(conv_layer(x, filter_size=128, kernel=[3,3], layer_name=scope+'_conv6'))
73 |             x = conv_concat(x,y)
74 | 
75 |             x = Global_Average_Pooling(x)
76 |             x = flatten(x)
77 |             x = concat([x,y_]) # mlp_concat
78 | 
79 |             x_logit = linear(x, unit=1, layer_name=scope+'_linear1')
80 |             out = sigmoid(x_logit)
81 | 
82 | 
83 |             return out, x_logit, x
84 | 
85 |     def generator(self, z, y, scope='generator', is_training=True, reuse=False):
86 |         with tf.variable_scope(scope, reuse=reuse) :
87 | 
88 |             x = concat([z, y]) # mlp_concat
89 | 
90 |             x = relu(linear(x, unit=512*4*4, layer_name=scope+'_linear1'))
91 |             x = batch_norm(x, is_training=is_training, scope=scope+'_batch1')
92 | 
93 |             x = tf.reshape(x, shape=[-1, 4, 4, 512])
94 |             y = tf.reshape(y, [-1, 1, 1, self.y_dim])
95 |             x = conv_concat(x,y)
96 | 
97 |             x = relu(deconv_layer(x, filter_size=256, kernel=[5,5], stride=2, layer_name=scope+'_deconv1'))
98 |             x = batch_norm(x, is_training=is_training, scope=scope+'_batch2')
99 |             x = conv_concat(x,y)
100 | 
101 |             x = relu(deconv_layer(x, filter_size=128, kernel=[5,5], stride=2, layer_name=scope+'_deconv2'))
102 |             x = batch_norm(x, is_training=is_training, scope=scope+'_batch3')
103 |             x = conv_concat(x,y)
104 | 
105 |             x = tanh(deconv_layer(x, filter_size=3, kernel=[5,5], stride=2, wn=False, layer_name=scope+'_deconv3'))
106 | 
107 |             return x
108 |     def classifier(self, x, scope='classifier', is_training=True, reuse=False):
109 |         with tf.variable_scope(scope, reuse=reuse) :
110 |             x = gaussian_noise_layer(x) # default std = 0.15
111 |             x = lrelu(conv_layer(x, filter_size=128, kernel=[3,3], layer_name=scope+'_conv1'))
112 |             x = lrelu(conv_layer(x, filter_size=128, kernel=[3,3], layer_name=scope+'_conv2'))
113 |             x = lrelu(conv_layer(x, filter_size=128, kernel=[3,3], layer_name=scope+'_conv3'))
114 | 
115 |             x = max_pooling(x, kernel=[2,2], stride=2)
116 |             x = dropout(x, rate=0.5, is_training=is_training)
117 | 
118 |             x = 
lrelu(conv_layer(x, filter_size=256, kernel=[3,3], layer_name=scope+'_conv4')) 119 | x = lrelu(conv_layer(x, filter_size=256, kernel=[3,3], layer_name=scope+'_conv5')) 120 | x = lrelu(conv_layer(x, filter_size=256, kernel=[3,3], layer_name=scope+'_conv6')) 121 | 122 | x = max_pooling(x, kernel=[2,2], stride=2) 123 | x = dropout(x, rate=0.5, is_training=is_training) 124 | 125 | x = lrelu(conv_layer(x, filter_size=512, kernel=[3,3], layer_name=scope+'_conv7')) 126 | x = nin(x, unit=256, layer_name=scope+'_nin1') 127 | x = nin(x, unit=128, layer_name=scope+'_nin2') 128 | 129 | x = Global_Average_Pooling(x) 130 | x = flatten(x) 131 | x = linear(x, unit=10, layer_name=scope+'_linear1') 132 | return x 133 | 134 | def build_model(self): 135 | image_dims = [self.input_height, self.input_width, self.c_dim] 136 | bs = self.batch_size 137 | unlabel_bs = self.unlabelled_batch_size 138 | test_bs = self.test_batch_size 139 | alpha = self.alpha 140 | alpha_cla_adv = self.alpha_cla_adv 141 | self.alpha_p = tf.placeholder(tf.float32, name='alpha_p') 142 | self.gan_lr = tf.placeholder(tf.float32, name='gan_lr') 143 | self.cla_lr = tf.placeholder(tf.float32, name='cla_lr') 144 | self.unsup_weight = tf.placeholder(tf.float32, name='unsup_weight') 145 | self.c_beta1 = tf.placeholder(tf.float32, name='c_beta1') 146 | 147 | """ Graph Input """ 148 | # images 149 | self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images') 150 | self.unlabelled_inputs = tf.placeholder(tf.float32, [unlabel_bs] + image_dims, name='unlabelled_images') 151 | self.test_inputs = tf.placeholder(tf.float32, [test_bs] + image_dims, name='test_images') 152 | 153 | # labels 154 | self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y') 155 | self.unlabelled_inputs_y = tf.placeholder(tf.float32, [unlabel_bs, self.y_dim]) 156 | self.test_label = tf.placeholder(tf.float32, [test_bs, self.y_dim], name='test_label') 157 | self.visual_y = tf.placeholder(tf.float32, [self.visual_num, self.y_dim], name='visual_y') 158 | 159 | # noises 160 | self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z') 161 | self.visual_z = tf.placeholder(tf.float32, [self.visual_num, self.z_dim], name='visual_z') 162 | 163 | """ Loss Function """ 164 | # A Game with Three Players 165 | 166 | # output of D for real images 167 | D_real, D_real_logits, _ = self.discriminator(self.inputs, self.y, is_training=True, reuse=False) 168 | 169 | # output of D for fake images 170 | G = self.generator(self.z, self.y, is_training=True, reuse=False) 171 | D_fake, D_fake_logits, _ = self.discriminator(G, self.y, is_training=True, reuse=True) 172 | 173 | # output of C for real images 174 | C_real_logits = self.classifier(self.inputs, is_training=True, reuse=False) 175 | R_L = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=C_real_logits)) 176 | 177 | # output of D for unlabelled images 178 | Y_c = self.classifier(self.unlabelled_inputs, is_training=True, reuse=True) 179 | D_cla, D_cla_logits, _ = self.discriminator(self.unlabelled_inputs, Y_c, is_training=True, reuse=True) 180 | 181 | # output of C for fake images 182 | C_fake_logits = self.classifier(G, is_training=True, reuse=True) 183 | R_P = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=C_fake_logits)) 184 | 185 | # 186 | 187 | # get loss for discriminator 188 | d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real))) 189 | d_loss_fake = 
(1-alpha)*tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))
190 |         d_loss_cla = alpha*tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_cla_logits, labels=tf.zeros_like(D_cla)))
191 |         self.d_loss = d_loss_real + d_loss_fake + d_loss_cla
192 | 
193 |         # get loss for generator
194 |         self.g_loss = (1-alpha)*tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))
195 | 
196 |         # test accuracy of the classifier
197 |         test_Y = self.classifier(self.test_inputs, is_training=False, reuse=True)
198 |         correct_prediction = tf.equal(tf.argmax(test_Y, 1), tf.argmax(self.test_label, 1))
199 |         self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
200 | 
201 |         # get loss for the classifier (sigmoid CE on the 1-unit D output; softmax CE on a single logit is identically zero)
202 |         max_c = tf.cast(tf.argmax(Y_c, axis=1), tf.float32)
203 |         c_loss_dis = tf.reduce_mean(max_c * tf.squeeze(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_cla_logits, labels=tf.ones_like(D_cla)), axis=-1))
204 |         # self.c_loss = alpha * c_loss_dis + R_L + self.alpha_p*R_P
205 | 
206 |         # R_UL = self.unsup_weight * tf.reduce_mean(tf.squared_difference(Y_c, self.unlabelled_inputs_y))
207 |         self.c_loss = alpha_cla_adv * alpha * c_loss_dis + R_L + self.alpha_p*R_P
208 | 
209 |         """ Training """
210 | 
211 |         # divide trainable variables into groups for D, G and C
212 |         t_vars = tf.trainable_variables()
213 |         d_vars = [var for var in t_vars if 'discriminator' in var.name]
214 |         g_vars = [var for var in t_vars if 'generator' in var.name]
215 |         c_vars = [var for var in t_vars if 'classifier' in var.name]
216 | 
217 |         for var in t_vars: print(var.name)
218 |         # optimizers
219 |         with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
220 |             self.d_optim = tf.train.AdamOptimizer(self.gan_lr, beta1=self.GAN_beta1).minimize(self.d_loss, var_list=d_vars)
221 |             self.g_optim = tf.train.AdamOptimizer(self.gan_lr, beta1=self.GAN_beta1).minimize(self.g_loss, var_list=g_vars)
222 |             self.c_optim = tf.train.AdamOptimizer(self.cla_lr, beta1=self.beta1, beta2=self.beta2, epsilon=self.epsilon).minimize(self.c_loss, var_list=c_vars)
223 | 
224 |         """ Testing """
225 |         # for test
226 |         self.fake_images = self.generator(self.visual_z, self.visual_y, is_training=False, reuse=True)
227 | 
228 |         """ Summary """
229 |         d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
230 |         d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
231 |         d_loss_cla_sum = tf.summary.scalar("d_loss_cla", d_loss_cla)
232 | 
233 |         d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
234 |         g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
235 |         c_loss_sum = tf.summary.scalar("c_loss", self.c_loss)
236 | 
237 | 
238 | 
239 |         # final summary operations
240 |         self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
241 |         self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])
242 |         self.c_sum = tf.summary.merge([d_loss_cla_sum, c_loss_sum])
243 | 
244 | 
245 |     def train(self):
246 | 
247 |         # initialize all variables
248 |         tf.global_variables_initializer().run()
249 |         gan_lr = self.learning_rate
250 |         cla_lr = self.cla_learning_rate
251 | 
252 |         # graph inputs for visualizing training results
253 |         self.sample_z = np.random.uniform(-1, 1, size=(self.visual_num, self.z_dim))
254 |         self.test_codes = self.data_y[0:self.visual_num]
255 | 
256 |         # saver to save model
257 |         self.saver = tf.train.Saver()
258 | 
259 |         # summary writer
260 |         self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)
261 | 
262 |         # restore checkpoint if it exists
263 |         could_load, checkpoint_counter = self.load(self.checkpoint_dir)
264 |         if could_load:
265 |             start_epoch = int(checkpoint_counter / self.num_batches)
266 |             start_batch_id = checkpoint_counter - start_epoch * self.num_batches
267 |             counter = checkpoint_counter
268 |             with open('lr_logs.txt', 'r') as f :
269 |                 line = f.readlines()
270 |                 line = line[-1]
271 |                 gan_lr = float(line.split()[0])
272 |                 cla_lr = float(line.split()[1])
273 |                 print("gan_lr : ", gan_lr)
274 |                 print("cla_lr : ", cla_lr)
275 |             print(" [*] Load SUCCESS")
276 |         else:
277 |             start_epoch = 0
278 |             start_batch_id = 0
279 |             counter = 1
280 |             print(" [!] Load failed...")
281 | 
282 |         # loop for epoch
283 |         start_time = time.time()
284 | 
285 |         for epoch in range(start_epoch, self.epoch):
286 | 
287 |             if epoch >= self.decay_epoch :
288 |                 gan_lr *= 0.995
289 |                 cla_lr *= 0.99
290 |                 print("**** learning rate DECAY ****")
291 |                 print(gan_lr)
292 |                 print(cla_lr)
293 | 
294 |             if epoch >= self.apply_epoch :
295 |                 alpha_p = self.apply_alpha_p
296 |             else :
297 |                 alpha_p = self.init_alpha_p
298 | 
299 |             rampup_value = rampup(epoch - 1)
300 |             unsup_weight = rampup_value * 100.0 if epoch > 1 else 0
301 | 
302 |             # get batch data
303 |             for idx in range(start_batch_id, self.num_batches):
304 |                 batch_images = self.data_X[idx * self.batch_size : (idx + 1) * self.batch_size]
305 |                 batch_codes = self.data_y[idx * self.batch_size : (idx + 1) * self.batch_size]
306 | 
307 |                 batch_unlabelled_images = self.unlabelled_X[idx * self.unlabelled_batch_size : (idx + 1) * self.unlabelled_batch_size]
308 |                 batch_unlabelled_images_y = self.unlabelled_y[idx * self.unlabelled_batch_size : (idx + 1) * self.unlabelled_batch_size]
309 | 
310 |                 batch_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))
311 | 
312 |                 feed_dict = {
313 |                     self.inputs: batch_images, self.y: batch_codes,
314 |                     self.unlabelled_inputs: batch_unlabelled_images,
315 |                     self.unlabelled_inputs_y: batch_unlabelled_images_y,
316 |                     self.z: batch_z, self.alpha_p: alpha_p,
317 |                     self.gan_lr: gan_lr, self.cla_lr: cla_lr,
318 |                     self.unsup_weight : unsup_weight
319 |                 }
320 |                 # update D network
321 |                 _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss], feed_dict=feed_dict)
322 |                 self.writer.add_summary(summary_str, counter)
323 | 
324 |                 # update G network
325 |                 _, summary_str_g, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], feed_dict=feed_dict)
326 |                 self.writer.add_summary(summary_str_g, counter)
327 | 
328 |                 # update C network
329 |                 _, summary_str_c, c_loss = self.sess.run([self.c_optim, self.c_sum, self.c_loss], feed_dict=feed_dict)
330 |                 self.writer.add_summary(summary_str_c, counter)
331 | 
332 |                 # display training status
333 |                 counter += 1
334 |                 print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, c_loss: %.8f" \
335 |                       % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss, c_loss))
336 | 
337 |             # save training results for every 100 steps
338 |             """
339 |             if np.mod(counter, 100) == 0:
340 |                 samples = self.sess.run(self.fake_images,
341 |                                         feed_dict={self.z: self.sample_z, self.y: self.test_codes})
342 |                 image_frame_dim = int(np.floor(np.sqrt(self.visual_num)))
343 |                 save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
344 |                             './' + check_folder(
345 |                                 self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
346 |                                 epoch, idx))
347 |             """
348 | 
349 |             # classifier test
350 |             test_acc = 0.0
351 | 
352 |             for idx in range(10) :
353 |                 test_batch_x = self.test_X[idx * self.test_batch_size : (idx+1) * self.test_batch_size]
354 |                 test_batch_y = self.test_y[idx * self.test_batch_size : (idx+1) * self.test_batch_size]
355 | 
356 |                 acc_ = self.sess.run(self.accuracy, feed_dict={
357 |                     self.test_inputs: test_batch_x,
358 |                     self.test_label: test_batch_y
359 |                 })
360 | 
361 |                 test_acc += acc_
362 |             test_acc /= 10
363 | 
364 |             summary_test = tf.Summary(value=[tf.Summary.Value(tag='test_accuracy', simple_value=test_acc)])
365 |             self.writer.add_summary(summary_test, epoch)
366 | 
367 |             line = "Epoch: [%2d], test_acc: %.4f\n" % (epoch, test_acc)
368 |             print(line)
369 |             lr = "{} {}".format(gan_lr, cla_lr)
370 |             with open('logs.txt', 'a') as f:
371 |                 f.write(line)
372 |             with open('lr_logs.txt', 'a') as f :
373 |                 f.write(lr+'\n')
374 | 
375 |             # After an epoch, start_batch_id is set to zero
376 |             # non-zero value is only for the first epoch after loading a pre-trained model
377 |             start_batch_id = 0
378 | 
379 |             # save model
380 |             self.save(self.checkpoint_dir, counter)
381 | 
382 |             # show temporal results
383 |             self.visualize_results(epoch)
384 | 
385 |         # save model for final step
386 |         self.save(self.checkpoint_dir, counter)
387 | 
388 |     def visualize_results(self, epoch):
389 |         # tot_num_samples = min(self.sample_num, self.batch_size)
390 |         image_frame_dim = int(np.floor(np.sqrt(self.visual_num)))
391 |         z_sample = np.random.uniform(-1, 1, size=(self.visual_num, self.z_dim))
392 | 
393 |         """ random noise, random discrete code, fixed continuous code """
394 |         y = np.random.choice(self.len_discrete_code, self.visual_num)
395 |         # sample a random class label for each of the visual_num images
396 |         y_one_hot = np.zeros((self.visual_num, self.y_dim))
397 |         y_one_hot[np.arange(self.visual_num), y] = 1
398 | 
399 |         samples = self.sess.run(self.fake_images, feed_dict={self.visual_z: z_sample, self.visual_y: y_one_hot})
400 | 
401 |         save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
402 |                     check_folder(
403 |                         self.result_dir + '/' + self.model_dir + '/all_classes') + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')
404 | 
405 |         """ specified condition, random noise """
406 |         n_styles = 10 # must be less than or equal to self.visual_num
407 | 
408 |         np.random.seed()
409 |         si = np.random.choice(self.visual_num, n_styles)
410 | 
411 |         for l in range(self.len_discrete_code):
412 |             y = np.zeros(self.visual_num, dtype=np.int64) + l
413 |             y_one_hot = np.zeros((self.visual_num, self.y_dim))
414 |             y_one_hot[np.arange(self.visual_num), y] = 1
415 | 
416 |             samples = self.sess.run(self.fake_images, feed_dict={self.visual_z: z_sample, self.visual_y: y_one_hot})
417 |             save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
418 |                         check_folder(
419 |                             self.result_dir + '/' + self.model_dir + '/class_%d' % l) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_class_%d.png' % l)
420 | 
421 |             samples = samples[si, :, :, :]
422 | 
423 |             if l == 0:
424 |                 all_samples = samples
425 |             else:
426 |                 all_samples = np.concatenate((all_samples, samples), axis=0)
427 | 
428 |         """ save merged images to check style-consistency """
429 |         canvas = np.zeros_like(all_samples)
430 |         for s in range(n_styles):
431 |             for c in range(self.len_discrete_code):
432 |                 canvas[s * self.len_discrete_code + c, :, :, :] = all_samples[c * n_styles + s, :, :, :]
433 | 
434 |         save_images(canvas, [n_styles, self.len_discrete_code],
435 |                     check_folder(
436 |                         self.result_dir + '/' + self.model_dir + '/all_classes_style_by_style') + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes_style_by_style.png')
437 | 
438 |     @property
439 |     def model_dir(self):
440 |         return "{}_{}_{}_{}".format(
441 |             self.model_name, self.dataset_name,
442 |             self.batch_size, self.z_dim)
443 | 
444 |     def save(self, checkpoint_dir, step):
445 |         checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
446 | 
447 |         if not os.path.exists(checkpoint_dir):
448 |             os.makedirs(checkpoint_dir)
449 | 
450 |         self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
451 | 
452 |     def load(self, checkpoint_dir):
453 |         import re
454 |         print(" [*] Reading checkpoints...")
455 |         checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
456 | 
457 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
458 |         if ckpt and ckpt.model_checkpoint_path:
459 |             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
460 |             self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
461 |             counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0))
462 |             print(" [*] Successfully read {}".format(ckpt_name))
463 |             return True, counter
464 |         else:
465 |             print(" [*] Failed to find a checkpoint")
466 |             return False, 0
467 | 
--------------------------------------------------------------------------------
/assests/accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/accuracy.png
--------------------------------------------------------------------------------
/assests/algorithm.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/algorithm.JPG
--------------------------------------------------------------------------------
/assests/classification_result.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/classification_result.JPG
--------------------------------------------------------------------------------
/assests/generated_image/all_class.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/all_class.png
--------------------------------------------------------------------------------
/assests/generated_image/class_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/class_0.png
--------------------------------------------------------------------------------
/assests/generated_image/class_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/class_1.png
--------------------------------------------------------------------------------
/assests/generated_image/class_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/class_2.png -------------------------------------------------------------------------------- /assests/generated_image/class_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/class_3.png -------------------------------------------------------------------------------- /assests/generated_image/class_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/class_4.png -------------------------------------------------------------------------------- /assests/generated_image/class_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/class_5.png -------------------------------------------------------------------------------- /assests/generated_image/class_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/class_6.png -------------------------------------------------------------------------------- /assests/generated_image/class_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/class_7.png -------------------------------------------------------------------------------- /assests/generated_image/class_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/class_8.png -------------------------------------------------------------------------------- /assests/generated_image/class_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/class_9.png -------------------------------------------------------------------------------- /assests/generated_image/style_by_style.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/generated_image/style_by_style.png -------------------------------------------------------------------------------- /assests/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/loss.png -------------------------------------------------------------------------------- /assests/network.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/network.JPG 
-------------------------------------------------------------------------------- /assests/result.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/result.JPG -------------------------------------------------------------------------------- /assests/result2.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/TripleGAN-Tensorflow/f12c9cac37fc3b99965c0146f39fac04e4f53d21/assests/result2.JPG -------------------------------------------------------------------------------- /cifar10.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from collections import defaultdict 4 | from keras.datasets import cifar10 5 | 6 | class_num = 10 7 | image_size = 32 8 | img_channels = 3 9 | 10 | 11 | def prepare_data(n): 12 | (train_data, train_labels), (test_data, test_labels) = cifar10.load_data() 13 | train_data, test_data = color_preprocessing(train_data, test_data) # pre-processing 14 | 15 | criteria = n//10 16 | input_dict, labelled_x, labelled_y, unlabelled_x, unlabelled_y = defaultdict(int), list(), list(), list(), list() 17 | 18 | for image, label in zip(train_data,train_labels) : 19 | if input_dict[int(label)] != criteria : 20 | input_dict[int(label)] += 1 21 | labelled_x.append(image) 22 | labelled_y.append(label) 23 | 24 | unlabelled_x.append(image) 25 | unlabelled_y.append(label) 26 | 27 | 28 | labelled_x = np.asarray(labelled_x) 29 | labelled_y = np.asarray(labelled_y) 30 | unlabelled_x = np.asarray(unlabelled_x) 31 | unlabelled_y = np.asarray(unlabelled_y) 32 | 33 | print("labelled data:", np.shape(labelled_x), np.shape(labelled_y)) 34 | print("unlabelled data :", np.shape(unlabelled_x), np.shape(unlabelled_y)) 35 | print("Test data :", np.shape(test_data), np.shape(test_labels)) 36 | print("======Load finished======") 37 | 38 | print("======Shuffling data======") 39 | indices = np.random.permutation(len(labelled_x)) 40 | labelled_x = labelled_x[indices] 41 | labelled_y = labelled_y[indices] 42 | 43 | indices = np.random.permutation(len(unlabelled_x)) 44 | unlabelled_x = unlabelled_x[indices] 45 | unlabelled_y = unlabelled_y[indices] 46 | 47 | print("======Prepare Finished======") 48 | 49 | 50 | labelled_y_vec = np.zeros((len(labelled_y), 10), dtype=np.float) 51 | for i, label in enumerate(labelled_y) : 52 | labelled_y_vec[i, labelled_y[i]] = 1.0 53 | 54 | unlabelled_y_vec = np.zeros((len(unlabelled_y), 10), dtype=np.float) 55 | for i, label in enumerate(unlabelled_y) : 56 | unlabelled_y_vec[i, unlabelled_y[i]] = 1.0 57 | 58 | test_labels_vec = np.zeros((len(test_labels), 10), dtype=np.float) 59 | for i, label in enumerate(test_labels) : 60 | test_labels_vec[i, test_labels[i]] = 1.0 61 | 62 | 63 | return labelled_x, labelled_y_vec, unlabelled_x, unlabelled_y_vec, test_data, test_labels_vec 64 | 65 | 66 | # ========================================================== # 67 | # ├─ _random_crop() 68 | # ├─ _random_flip_leftright() 69 | # ├─ data_augmentation() 70 | # └─ color_preprocessing() 71 | # ========================================================== # 72 | 73 | def _random_crop(batch, crop_shape, padding=None): 74 | oshape = np.shape(batch[0]) 75 | 76 | if padding: 77 | oshape = (oshape[0] + 2 * padding, oshape[1] + 2 * padding) 78 | new_batch = [] 79 | npad = ((padding, padding), (padding, padding), 
(0, 0))
80 |     for i in range(len(batch)):
81 |         new_batch.append(batch[i])
82 |         if padding:
83 |             new_batch[i] = np.lib.pad(batch[i], pad_width=npad,
84 |                                       mode='constant', constant_values=0)
85 |         nh = random.randint(0, oshape[0] - crop_shape[0])
86 |         nw = random.randint(0, oshape[1] - crop_shape[1])
87 |         new_batch[i] = new_batch[i][nh:nh + crop_shape[0],
88 |                                     nw:nw + crop_shape[1]]
89 |     return new_batch
90 | 
91 | 
92 | def _random_flip_leftright(batch):
93 |     for i in range(len(batch)):
94 |         if bool(random.getrandbits(1)):
95 |             batch[i] = np.fliplr(batch[i])
96 |     return batch
97 | 
98 | 
99 | def color_preprocessing(x_train, x_test):
100 |     """
101 |     x_train = x_train.astype('float32')
102 |     x_test = x_test.astype('float32')
103 |     x_train[:, :, :, 0] = (x_train[:, :, :, 0] - np.mean(x_train[:, :, :, 0])) / np.std(x_train[:, :, :, 0])
104 |     x_train[:, :, :, 1] = (x_train[:, :, :, 1] - np.mean(x_train[:, :, :, 1])) / np.std(x_train[:, :, :, 1])
105 |     x_train[:, :, :, 2] = (x_train[:, :, :, 2] - np.mean(x_train[:, :, :, 2])) / np.std(x_train[:, :, :, 2])
106 | 
107 |     x_test[:, :, :, 0] = (x_test[:, :, :, 0] - np.mean(x_test[:, :, :, 0])) / np.std(x_test[:, :, :, 0])
108 |     x_test[:, :, :, 1] = (x_test[:, :, :, 1] - np.mean(x_test[:, :, :, 1])) / np.std(x_test[:, :, :, 1])
109 |     x_test[:, :, :, 2] = (x_test[:, :, :, 2] - np.mean(x_test[:, :, :, 2])) / np.std(x_test[:, :, :, 2])
110 |     """
111 |     x_train = x_train/127.5 - 1
112 |     x_test = x_test/127.5 - 1
113 |     return x_train, x_test
114 | 
115 | 
116 | def data_augmentation(batch):
117 |     batch = _random_flip_leftright(batch)
118 |     batch = _random_crop(batch, [32, 32], 4)
119 |     return batch
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from TripleGAN import TripleGAN
2 | 
3 | from utils import show_all_variables
4 | from utils import check_folder
5 | 
6 | import tensorflow as tf
7 | import argparse
8 | 
9 | """parsing and configuration"""
10 | def parse_args():
11 |     desc = "Tensorflow implementation of TripleGAN"
12 |     parser = argparse.ArgumentParser(description=desc)
13 |     parser.add_argument('--n', type=int, default=4000, help='The number of labelled training examples')
14 |     parser.add_argument('--dataset', type=str, default='cifar10', choices=['mnist', 'fashion-mnist', 'celebA', 'cifar10'],
15 |                         help='The name of dataset')
16 |     # For now, only cifar10 is supported
17 |     parser.add_argument('--epoch', type=int, default=1000, help='The number of epochs to run')
18 |     parser.add_argument('--batch_size', type=int, default=20, help='The size of the labelled batch')
19 |     parser.add_argument('--unlabel_batch_size', type=int, default=250, help='The size of the unlabelled batch')
20 |     parser.add_argument('--z_dim', type=int, default=100, help='Dimension of noise vector')
21 |     parser.add_argument('--gan_lr', type=float, default=2e-4, help='learning rate of the GAN')
22 |     parser.add_argument('--cla_lr', type=float, default=2e-3, help='learning rate of the classifier')
23 |     parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
24 |                         help='Directory name to save the checkpoints')
25 |     parser.add_argument('--result_dir', type=str, default='results',
26 |                         help='Directory name to save the generated images')
27 |     parser.add_argument('--log_dir', type=str, default='logs',
28 |                         help='Directory name to save training logs')
29 | 
30 |     return check_args(parser.parse_args())
31 | 
32 | """checking arguments"""
33 | def check_args(args):
34 |     # --checkpoint_dir
35 |     check_folder(args.checkpoint_dir)
36 | 
37 |     # --result_dir
38 |     check_folder(args.result_dir)
39 | 
40 |     # --log_dir
41 |     check_folder(args.log_dir)
42 | 
43 |     # --epoch
44 |     try:
45 |         assert args.epoch >= 1
46 |     except:
47 |         print('number of epochs must be larger than or equal to one')
48 |         return None
49 | 
50 |     # --batch_size
51 |     try:
52 |         assert args.batch_size >= 1
53 |     except:
54 |         print('batch size must be larger than or equal to one')
55 |         return None
56 | 
57 |     # --z_dim
58 |     try:
59 |         assert args.z_dim >= 1
60 |     except:
61 |         print('dimension of noise vector must be larger than or equal to one')
62 |         return None
63 | 
64 |     return args
65 | 
66 | """main"""
67 | def main():
68 |     # parse arguments
69 |     args = parse_args()
70 |     if args is None:
71 |         exit()
72 | 
73 |     # open session
74 |     with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
75 |         gan = TripleGAN(sess, epoch=args.epoch, batch_size=args.batch_size, unlabel_batch_size=args.unlabel_batch_size,
76 |                         z_dim=args.z_dim, dataset_name=args.dataset, n=args.n, gan_lr=args.gan_lr, cla_lr=args.cla_lr,
77 |                         checkpoint_dir=args.checkpoint_dir, result_dir=args.result_dir, log_dir=args.log_dir)
78 | 
79 |         # build graph
80 |         gan.build_model()
81 | 
82 |         # show network architecture
83 |         show_all_variables()
84 | 
85 |         # launch the graph in a session
86 |         gan.train()
87 |         print(" [*] Training finished!")
88 | 
89 |         # visualize learned generator
90 |         gan.visualize_results(args.epoch-1)
91 |         print(" [*] Testing finished!")
92 | 
93 | if __name__ == '__main__':
94 |     main()
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tflearn import global_avg_pool
3 | from tensorflow.contrib.layers import variance_scaling_initializer
4 | import numpy as np
5 | import math
6 | 
7 | he_init = variance_scaling_initializer()
8 | # he_init = tf.truncated_normal_initializer(stddev=0.02)
9 | """
10 | Weight normalization is not implemented yet (see the README issue); the weight_norm() helper below is a stub and is never called with wn=True.
11 | """ 12 | 13 | def weight_norm(x, output_dim) : 14 | input_dim = int(x.get_shape()[-1]) 15 | g = tf.get_variable('g_scalar', shape=[output_dim], dtype=tf.float32, initializer=tf.ones_initializer()) 16 | w = tf.get_variable('weight', shape=[input_dim, output_dim], dtype=tf.float32, initializer=he_init) 17 | w_init = tf.nn.l2_normalize(w, dim=0) * g # SAME dim=1 18 | 19 | return tf.variables_initializer(w_init) 20 | 21 | def conv_layer(x, filter_size, kernel, stride=1, padding='SAME', wn=False, layer_name="conv"): 22 | with tf.name_scope(layer_name): 23 | if wn: 24 | w_init = weight_norm(x, filter_size) 25 | 26 | x = tf.layers.conv2d(inputs=x, filters=filter_size, kernel_size=kernel, kernel_initializer=w_init, strides=stride, padding=padding) 27 | else : 28 | x = tf.layers.conv2d(inputs=x, filters=filter_size, kernel_size=kernel, kernel_initializer=he_init, strides=stride, padding=padding) 29 | return x 30 | 31 | 32 | def deconv_layer(x, filter_size, kernel, stride=1, padding='SAME', wn=False, layer_name='deconv'): 33 | with tf.name_scope(layer_name): 34 | if wn : 35 | w_init = weight_norm(x, filter_size) 36 | x = tf.layers.conv2d_transpose(inputs=x, filters=filter_size, kernel_size=kernel, kernel_initializer=w_init, strides=stride, padding=padding) 37 | else : 38 | x = tf.layers.conv2d_transpose(inputs=x, filters=filter_size, kernel_size=kernel, kernel_initializer=he_init, strides=stride, padding=padding) 39 | return x 40 | 41 | 42 | def linear(x, unit, wn=False, layer_name='linear'): 43 | with tf.name_scope(layer_name): 44 | if wn : 45 | w_init = weight_norm(x, unit) 46 | x = tf.layers.dense(inputs=x, units=unit, kernel_initializer=w_init) 47 | else : 48 | x = tf.layers.dense(inputs=x, units=unit, kernel_initializer=he_init) 49 | return x 50 | 51 | 52 | def nin(x, unit, wn=False, layer_name='nin'): 53 | # https://github.com/openai/weightnorm/blob/master/tensorflow/nn.py 54 | with tf.name_scope(layer_name): 55 | s = list(map(int, x.get_shape())) 56 | x = tf.reshape(x, [np.prod(s[:-1]), s[-1]]) 57 | x = linear(x, unit, wn, layer_name) 58 | x = tf.reshape(x, s[:-1] + [unit]) 59 | 60 | 61 | return x 62 | 63 | 64 | def gaussian_noise_layer(x, std=0.15): 65 | noise = tf.random_normal(shape=tf.shape(x), mean=0.0, stddev=std, dtype=tf.float32) 66 | return x + noise 67 | 68 | def Global_Average_Pooling(x): 69 | return global_avg_pool(x, name='Global_avg_pooling') 70 | 71 | 72 | def max_pooling(x, kernel, stride): 73 | return tf.layers.max_pooling2d(x, pool_size=kernel, strides=stride, padding='VALID') 74 | 75 | 76 | def flatten(x): 77 | return tf.contrib.layers.flatten(x) 78 | 79 | 80 | def lrelu(x, leak=0.2, name="lrelu"): 81 | return tf.maximum(x, leak * x) 82 | 83 | 84 | def sigmoid(x): 85 | return tf.nn.sigmoid(x) 86 | 87 | 88 | def relu(x): 89 | return tf.nn.relu(x) 90 | 91 | 92 | def tanh(x): 93 | return tf.nn.tanh(x) 94 | 95 | def conv_concat(x, y): 96 | x_shapes = x.get_shape() 97 | y_shapes = y.get_shape() 98 | 99 | return concat([x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], axis=3) 100 | 101 | 102 | def concat(x, axis=1): 103 | return tf.concat(x, axis=axis) 104 | 105 | 106 | def reshape(x, shape): 107 | return tf.reshape(x, shape=shape) 108 | 109 | 110 | def batch_norm(x, is_training, scope): 111 | return tf.contrib.layers.batch_norm(x, 112 | decay=0.9, 113 | updates_collections=None, 114 | epsilon=1e-5, 115 | scale=True, 116 | is_training=is_training, 117 | scope=scope) 118 | 119 | def instance_norm(x, is_training, scope): 120 | with 
tf.variable_scope(scope): 121 | epsilon = 1e-5 122 | mean, var = tf.nn.moments(x, [1, 2], keep_dims=True) 123 | scale = tf.get_variable('scale', [x.get_shape()[-1]], 124 | initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02)) 125 | offset = tf.get_variable('offset', [x.get_shape()[-1]], initializer=tf.constant_initializer(0.0)) 126 | out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset 127 | 128 | return out 129 | 130 | def dropout(x, rate, is_training): 131 | return tf.layers.dropout(inputs=x, rate=rate, training=is_training) 132 | 133 | def rampup(epoch): 134 | if epoch < 80: 135 | p = max(0.0, float(epoch)) / float(80) 136 | p = 1.0 - p 137 | return math.exp(-p*p*5.0) 138 | else: 139 | return 1.0 140 | 141 | def rampdown(epoch): 142 | if epoch >= (300 - 50): 143 | ep = (epoch - (300 - 50)) * 0.5 144 | return math.exp(-(ep * ep) / 50) 145 | else: 146 | return 1.0 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import scipy.misc, os 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow.contrib.slim as slim 5 | 6 | def show_all_variables(): 7 | model_vars = tf.trainable_variables() 8 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 9 | 10 | def check_folder(log_dir): 11 | if not os.path.exists(log_dir): 12 | os.makedirs(log_dir) 13 | return log_dir 14 | 15 | def save_images(images, size, image_path): 16 | return imsave(inverse_transform(images), size, image_path) 17 | 18 | def imsave(images, size, path): 19 | image = np.squeeze(merge(images, size)) 20 | return scipy.misc.imsave(path, image) 21 | 22 | def inverse_transform(images): 23 | return (images+1.)/2. 24 | # return ((images + 1.) * 127.5).astype('uint8') 25 | 26 | def merge(images, size): 27 | h, w = images.shape[1], images.shape[2] 28 | if (images.shape[3] in (3,4)): 29 | c = images.shape[3] 30 | img = np.zeros((h * size[0], w * size[1], c)) 31 | for idx, image in enumerate(images): 32 | i = idx % size[1] 33 | j = idx // size[1] 34 | img[j * h:j * h + h, i * w:i * w + w, :] = image 35 | return img 36 | elif images.shape[3]==1: 37 | img = np.zeros((h * size[0], w * size[1])) 38 | for idx, image in enumerate(images): 39 | i = idx % size[1] 40 | j = idx // size[1] 41 | img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0] 42 | return img 43 | else: 44 | raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4') 45 | --------------------------------------------------------------------------------
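
--------------------------------------------------------------------------------

A note on the open weight-normalization issue: as flagged in the README and in the `ops.py` docstring, the `weight_norm()` helper is a stub — it builds the normalized tensor `g * l2_normalize(w, dim=0)` but then returns `tf.variables_initializer(...)`, which is not a valid `kernel_initializer` for `tf.layers.conv2d`/`tf.layers.dense`. A minimal sketch of the reparameterization `w = g * v / ||v||` from Salimans & Kingma (2016), written against the same TF 1.x API, might look like the block below. `wn_linear` is a hypothetical helper, not part of this repo, and the paper's data-dependent initialization of `g` and `b` is omitted:

```python
import tensorflow as tf

def wn_linear(x, output_dim, scope='wn_linear'):
    # A sketch, not the repo's implementation: a dense layer with weight
    # normalization, parameterizing the kernel as w = g * v / ||v||.
    input_dim = int(x.get_shape()[-1])
    with tf.variable_scope(scope):
        v = tf.get_variable('v', [input_dim, output_dim], tf.float32,
                            initializer=tf.random_normal_initializer(stddev=0.05))
        g = tf.get_variable('g', [output_dim], tf.float32,
                            initializer=tf.ones_initializer())
        b = tf.get_variable('b', [output_dim], tf.float32,
                            initializer=tf.zeros_initializer())
        # normalize each column of v over the input dimension, then rescale by g
        w = g * tf.nn.l2_normalize(v, dim=0)
        return tf.matmul(x, w) + b
```

The same pattern should extend to `conv_layer`/`deconv_layer` by building the kernel variable explicitly and calling `tf.nn.conv2d` instead of `tf.layers.conv2d`, since the `tf.layers` wrappers only accept an initializer, not a precomputed kernel tensor.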