├── save_para └── save_para.txt ├── dataset └── readme.txt ├── IMGS ├── paper.jpg ├── cat2bird.gif ├── cifar10.jpg ├── car2plane.gif ├── imagenet64.jpg ├── ship2horse.gif ├── architecture.jpg └── cifar10-trun-2.jpg ├── LICENSE ├── generate_64.py ├── test_networks_32.py ├── test_ops.py ├── README.md ├── networks_32.py ├── networks_64.py ├── generate_32.py ├── networks_128.py ├── train_64.py ├── train_32.py ├── utils.py └── ops.py /save_para/save_para.txt: -------------------------------------------------------------------------------- 1 | save_para -------------------------------------------------------------------------------- /dataset/readme.txt: -------------------------------------------------------------------------------- 1 | put into dataset -------------------------------------------------------------------------------- /IMGS/paper.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/BigGAN-tensorflow/HEAD/IMGS/paper.jpg -------------------------------------------------------------------------------- /IMGS/cat2bird.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/BigGAN-tensorflow/HEAD/IMGS/cat2bird.gif -------------------------------------------------------------------------------- /IMGS/cifar10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/BigGAN-tensorflow/HEAD/IMGS/cifar10.jpg -------------------------------------------------------------------------------- /IMGS/car2plane.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/BigGAN-tensorflow/HEAD/IMGS/car2plane.gif -------------------------------------------------------------------------------- /IMGS/imagenet64.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/BigGAN-tensorflow/HEAD/IMGS/imagenet64.jpg -------------------------------------------------------------------------------- /IMGS/ship2horse.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/BigGAN-tensorflow/HEAD/IMGS/ship2horse.gif -------------------------------------------------------------------------------- /IMGS/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/BigGAN-tensorflow/HEAD/IMGS/architecture.jpg -------------------------------------------------------------------------------- /IMGS/cifar10-trun-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/BigGAN-tensorflow/HEAD/IMGS/cifar10-trun-2.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 MingtaoGuo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
def generate():
    """Render one 8x8 mosaic of generated 64x64 images for every class.

    Builds the generator graph, restores the variables under the "generator"
    scope from the checkpoint in ./save_para64/, and writes
    ./generate/<class_id>.jpg for each class id in range(NUMS_CLASS).
    """
    if not os.path.exists("./generate"):
        os.mkdir("./generate")
    train_phase = tf.placeholder(tf.bool)
    # The latent batch is a graph op, so a fresh z is drawn on every sess.run.
    z = tf.random_normal([BATCH_SIZE, Z_DIM])
    y = tf.placeholder(tf.int32, [None])
    G = Generator("generator")
    fake_img = G(z, train_phase, y, NUMS_CLASS)
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "generator"))
    # The context manager releases the session even if restoring or sampling fails.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Forward-slash path: the former ".\\model.ckpt" suffix only resolved on Windows.
        saver.restore(sess, "./save_para64/model.ckpt")
        for nums_c in range(NUMS_CLASS):
            # Integer labels match the int32 placeholder (the original fed floats).
            labels = np.full([NUMS_GEN], nums_c, dtype=np.int32)
            FAKE_IMG = sess.run(fake_img, feed_dict={train_phase: False, y: labels})
            # Lay the first 64 samples out row-major as an 8x8 grid of tiles.
            tiles = np.asarray(FAKE_IMG[:64], dtype=np.float64).reshape(8, 8, IMG_H, IMG_W, 3)
            concat_img = tiles.transpose(0, 2, 1, 3, 4).reshape(8 * IMG_H, 8 * IMG_W, 3)
            # The generator ends in tanh; map [-1, 1] onto [0, 255] before saving.
            Image.fromarray(np.uint8((concat_img + 1) * 127.5)).save("./generate/" + str(nums_c) + ".jpg")
class test_Generator:
    """32x32 generator variant that conditions on a blend of two class labels.

    Calling an instance builds (or reuses) the graph under the variable scope
    `name` and returns the tanh-activated image tensor.
    """

    def __init__(self, name):
        self.name = name

    def __call__(self, inputs, train_phase, y1, y2, nums_class, alpha):
        """Map latent `inputs` to images conditioned on `y1`/`y2` mixed by `alpha`.

        `inputs` is split along axis 1 into one chunk per residual block; the
        last chunk absorbs the remainder when the latent size is not divisible
        by the block count.
        """
        z_dim = int(inputs.shape[-1])
        nums_layer = 3
        # Derive both values from nums_layer so the chunk sizes always sum to z_dim
        # (the remainder used to be computed with a separate literal 3).
        chunk_size, remain = divmod(z_dim, nums_layer)
        z_split = tf.split(inputs, [chunk_size] * (nums_layer - 1) + [chunk_size + remain], axis=1)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            inputs = dense("dense", inputs, 256*4*4)
            inputs = tf.reshape(inputs, [-1, 4, 4, 256])
            inputs = test_G_Resblock("ResBlock1", inputs, 256, train_phase, z_split[0], y1, y2, nums_class, alpha)
            inputs = test_G_Resblock("ResBlock2", inputs, 256, train_phase, z_split[1], y1, y2, nums_class, alpha)
            inputs = non_local("Non-local", inputs, None, True)
            inputs = test_G_Resblock("ResBlock3", inputs, 256, train_phase, z_split[2], y1, y2, nums_class, alpha)
            # Final normalisation is called without labels or latent chunk.
            inputs = relu(conditional_batchnorm(inputs, train_phase, "BN"))
            inputs = conv("conv", inputs, k_size=3, nums_out=3, strides=1, is_sn=True)
        return tf.nn.tanh(inputs)

    def var_list(self):
        """Return the global variables collected under this generator's scope name."""
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.name)
def test_conditional_batchnorm(x, train_phase, scope_bn, splited_z=None, y1=None, y2=None, nums_class=10, alpha=None):
    """Batch-normalise `x`, optionally conditioned on a blend of two labels.

    When `y1` is None, plain per-channel `beta`/`gamma` variables are used.
    Otherwise `y1` and `y2` are one-hot encoded, mixed as
    `y1 * alpha + y2 * (1 - alpha)`, concatenated with `splited_z`, and passed
    through two dense layers that produce per-sample `gamma` and `beta`.

    `train_phase` is a boolean tensor: when true the batch moments are used
    and fed to the moving average; when false the averaged moments are used.
    """
    # Batch Normalization
    # Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[J]. 2015:448-456.
    with tf.variable_scope(scope_bn):
        # Identity test: `== None` on a tensor is an equality comparison, not a None check.
        if y1 is None:
            # Variable names are kept as-is so existing checkpoints still load.
            beta = tf.get_variable(name=scope_bn + 'beta', shape=[x.shape[-1]],
                                   initializer=tf.constant_initializer([0.]), trainable=True)
            gamma = tf.get_variable(name=scope_bn + 'gamma', shape=[x.shape[-1]],
                                    initializer=tf.constant_initializer([1.]), trainable=True)
        else:
            y1 = tf.one_hot(y1, nums_class)
            y2 = tf.one_hot(y2, nums_class)
            # alpha = 1 selects y1 entirely, alpha = 0 selects y2 entirely.
            y = y1 * alpha + y2 * (1 - alpha)
            z = tf.concat([splited_z, y], axis=1)
            gamma = dense("gamma", z, x.shape[-1], None, True)
            beta = dense("beta", z, x.shape[-1], None, True)
            # Reshape so the per-sample scale and shift broadcast over the two middle axes.
            gamma = tf.reshape(gamma, [-1, 1, 1, x.shape[-1]])
            beta = tf.reshape(beta, [-1, 1, 1, x.shape[-1]])
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments', keep_dims=True)
        ema = tf.train.ExponentialMovingAverage(decay=0.5)

        def mean_var_with_update():
            # Tie the moving-average update to the returned batch statistics.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(train_phase, mean_var_with_update,
                            lambda: (ema.average(batch_mean), ema.average(batch_var)))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
    return normed
42 | 43 | ![](https://github.com/MingtaoGuo/BigGAN-tensorflow/blob/master/IMGS/cifar10-trun-2.jpg) 44 | 45 | |car2plane|ship2horse|cat2bird| 46 | |-|-|-| 47 | |![](https://github.com/MingtaoGuo/BigGAN-tensorflow/blob/master/IMGS/car2plane.gif)|![](https://github.com/MingtaoGuo/BigGAN-tensorflow/blob/master/IMGS/ship2horse.gif)|![](https://github.com/MingtaoGuo/BigGAN-tensorflow/blob/master/IMGS/cat2bird.gif)| 48 | 49 | 64x64 ImageNet 50 | -------------- 51 | #### Configuration: 52 | Training iteration: 100,000 53 | 54 | ||Discriminator|Generator| 55 | |-|-|-| 56 | |Update step|2|1| 57 | |Learning rate|4e-4|1e-4| 58 | |Orthogonal reg|:heavy_check_mark:|:heavy_check_mark:| 59 | |Orthogonal init|:heavy_check_mark:|:heavy_check_mark:| 60 | |Hierarchical latent|:x:|:heavy_check_mark:| 61 | |Projection batchnorm|:heavy_check_mark:|:x:| 62 | |Truncation threshold|:x:|:heavy_check_mark:| 63 | 64 | Iteration: 30,000 65 | ![](https://github.com/MingtaoGuo/BigGAN-tensorflow/blob/master/IMGS/imagenet64.jpg) 66 | Iteration: 60,000 67 | Training in progress .......... 68 | ----------- 69 | To be continued. 
class Generator:
    """Class-conditional generator for 32x32 images.

    Calling an instance builds (or reuses) the graph under the variable scope
    `name` and returns the tanh-activated image tensor.
    """

    def __init__(self, name):
        self.name = name

    def __call__(self, inputs, train_phase, y, nums_class):
        """Map latent `inputs` and labels `y` to an image tensor.

        `inputs` is split along axis 1 into one chunk per residual block; the
        last chunk absorbs the remainder when the latent size is not divisible
        by the block count.
        """
        z_dim = int(inputs.shape[-1])
        nums_layer = 3
        # Derive both values from nums_layer so the chunk sizes always sum to z_dim
        # (the remainder used to be computed with a separate literal 3).
        chunk_size, remain = divmod(z_dim, nums_layer)
        z_split = tf.split(inputs, [chunk_size] * (nums_layer - 1) + [chunk_size + remain], axis=1)
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            inputs = dense("dense", inputs, 256*4*4)
            inputs = tf.reshape(inputs, [-1, 4, 4, 256])
            inputs = G_Resblock("ResBlock1", inputs, 256, train_phase, z_split[0], y, nums_class)
            inputs = G_Resblock("ResBlock2", inputs, 256, train_phase, z_split[1], y, nums_class)
            inputs = non_local("Non-local", inputs, None, True)
            inputs = G_Resblock("ResBlock3", inputs, 256, train_phase, z_split[2], y, nums_class)
            # Final normalisation is called without labels or latent chunk.
            inputs = relu(conditional_batchnorm(inputs, train_phase, "BN"))
            inputs = conv("conv", inputs, k_size=3, nums_out=3, strides=1, is_sn=True)
        return tf.nn.tanh(inputs)

    def var_list(self):
        """Return the global variables collected under this generator's scope name."""
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.name)
if __name__ == "__main__":
    # Graph-construction smoke test: builds the 32x32 generator and
    # discriminator once, using the same call pattern as the training script
    # (integer label vector, explicit class count, "NO_OPS" on the real pass).
    # The previous version omitted `y` and `nums_class`, so it raised
    # TypeError before any graph was built.
    nums_class = 10
    x = tf.placeholder(tf.float32, [None, 32, 32, 3])
    z = tf.placeholder(tf.float32, [None, 100])
    y = tf.placeholder(tf.int32, [None])
    train_phase = tf.placeholder(tf.bool)
    G = Generator("generator")
    D = Discriminator("discriminator")
    fake_img = G(z, train_phase, y, nums_class)
    fake_logit = D(fake_img, y, nums_class)
    real_logit = D(x, y, nums_class, "NO_OPS")
class Discriminator:
    """Projection discriminator for 64x64 images.

    Calling an instance builds (or reuses) the graph under the variable scope
    `name` and returns one logit per input image.
    """

    def __init__(self, name):
        self.name = name

    def __call__(self, inputs, y, nums_class, update_collection=None):
        """Score `inputs` against labels `y`.

        `update_collection` is forwarded unchanged to every layer.
        NOTE(review): presumably it controls the spectral-norm update
        bookkeeping in ops.py - confirm there.
        """
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            inputs = D_FirstResblock("ResBlock1", inputs, 64, update_collection, is_down=True)
            # Forward update_collection here as well; this layer used to receive a
            # literal None, unlike every other layer in this network and unlike
            # the 32x32 discriminator.
            inputs = non_local("Non-local", inputs, update_collection, True)
            inputs = D_Resblock("ResBlock2", inputs, 128, update_collection, is_down=True)
            inputs = D_Resblock("ResBlock3", inputs, 256, update_collection, is_down=True)
            inputs = D_Resblock("ResBlock4", inputs, 512, update_collection, is_down=True)
            inputs = D_Resblock("ResBlock5", inputs, 1024, update_collection, is_down=True)
            inputs = relu(inputs)
            inputs = global_sum_pooling(inputs)
            # Class-conditional projection term plus the unconditional logit.
            temp = Inner_product(inputs, y, nums_class, update_collection)
            inputs = dense("dense", inputs, 1, update_collection, is_sn=True)
            inputs = temp + inputs
            return inputs

    def var_list(self):
        """Return the global variables collected under this discriminator's scope name."""
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.name)
def Consecutive_category_morphing():
    """Render an 11-frame morph between two classes on a fixed latent batch.

    One truncated latent batch is sampled once and reused for every frame;
    only the blend weight fed as `alpha` changes, stepping from 0 to 1 in
    increments of 0.1. Each frame is an 8x8 mosaic saved as
    ./generate/<frame_index>.jpg.
    """
    if not os.path.exists("./generate"):
        os.mkdir("./generate")
    train_phase = tf.placeholder(tf.bool)
    z = tf.placeholder(tf.float32, [BATCH_SIZE, Z_DIM])
    y1 = tf.placeholder(tf.int32, [None])
    y2 = tf.placeholder(tf.int32, [None])
    alpha = tf.placeholder(tf.float32, [None, 1])
    G = test_Generator("generator")
    fake_img = G(z, train_phase, y1, y2, NUMS_CLASS, alpha)
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "generator"))
    Z = truncated_noise_sample(BATCH_SIZE, Z_DIM)
    CLASS1 = 2
    CLASS2 = 3
    # Integer labels match the int32 placeholders (the original fed floats).
    class1_ids = np.full([BATCH_SIZE], CLASS1, dtype=np.int32)
    class2_ids = np.full([BATCH_SIZE], CLASS2, dtype=np.int32)
    blend_steps = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    # The context manager releases the session even if restoring or sampling fails.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Forward-slash path: the former ".\\model.ckpt" suffix only resolved on Windows.
        saver.restore(sess, "./save_para/model.ckpt")
        for count, blend in enumerate(blend_steps):
            ALPHA = np.ones([BATCH_SIZE, 1]) * blend
            FAKE_IMG = sess.run(fake_img, feed_dict={z: Z, y1: class1_ids, y2: class2_ids,
                                                     train_phase: False, alpha: ALPHA})
            # Lay the first 64 samples out row-major as an 8x8 grid of tiles.
            tiles = np.asarray(FAKE_IMG[:64], dtype=np.float64).reshape(8, 8, IMG_H, IMG_W, 3)
            concat_img = tiles.transpose(0, 2, 1, 3, 4).reshape(8 * IMG_H, 8 * IMG_W, 3)
            # The generator ends in tanh; map [-1, 1] onto [0, 255] before saving.
            Image.fromarray(np.uint8((concat_img + 1) * 127.5)).save("./generate/" + str(count) + ".jpg")
tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "generator")) 58 | saver.restore(sess, "./save_para/.\\model.ckpt") 59 | Z = truncated_noise_sample(BATCH_SIZE, Z_DIM) 60 | for nums_c in range(NUMS_CLASS): 61 | Z = truncated_noise_sample(BATCH_SIZE, Z_DIM) 62 | FAKE_IMG = sess.run(fake_img, feed_dict={z: Z, train_phase: False, y: nums_c * np.ones([NUMS_GEN])}) 63 | concat_img = np.zeros([8*IMG_H, 8*IMG_W, 3]) 64 | c = 0 65 | for i in range(8): 66 | for j in range(8): 67 | concat_img[i*IMG_H:i*IMG_H+IMG_H, j*IMG_W:j*IMG_W+IMG_W] = FAKE_IMG[c] 68 | c += 1 69 | Image.fromarray(np.uint8((concat_img + 1) * 127.5)).save("./generate/"+str(nums_c)+".jpg") 70 | 71 | if __name__ == "__main__": 72 | Consecutive_category_morphing() 73 | -------------------------------------------------------------------------------- /networks_128.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | 3 | class Generator: 4 | def __init__(self, name): 5 | self.name = name 6 | 7 | def __call__(self, inputs, train_phase, y, nums_class): 8 | z_dim = int(inputs.shape[-1]) 9 | nums_layer = 5 10 | remain = z_dim % nums_layer 11 | chunk_size = (z_dim - remain) // nums_layer 12 | z_split = tf.split(inputs, [chunk_size] * (nums_layer - 1) + [chunk_size + remain], axis=1) 13 | with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE): 14 | inputs = dense("dense", inputs, 1024*4*4) 15 | inputs = tf.reshape(inputs, [-1, 4, 4, 1024]) 16 | inputs = G_Resblock("ResBlock1", inputs, 1024, train_phase, z_split[0], y, nums_class) 17 | inputs = G_Resblock("ResBlock2", inputs, 512, train_phase, z_split[1], y, nums_class) 18 | inputs = G_Resblock("ResBlock3", inputs, 256, train_phase, z_split[2], y, nums_class) 19 | inputs = G_Resblock("ResBlock4", inputs, 128, train_phase, z_split[3], y, nums_class) 20 | inputs = non_local("Non-local", inputs, None, True) 21 | inputs = G_Resblock("ResBlock5", inputs, 64, train_phase, z_split[4], y, 
class Discriminator:
    """Projection discriminator for 128x128 images.

    Calling an instance builds (or reuses) the graph under the variable scope
    `name` and returns one logit per input image.
    """

    def __init__(self, name):
        self.name = name

    def __call__(self, inputs, y, nums_class, update_collection=None):
        """Score `inputs` against labels `y`.

        `update_collection` is forwarded unchanged to every layer.
        NOTE(review): presumably it controls the spectral-norm update
        bookkeeping in ops.py - confirm there.
        """
        with tf.variable_scope(name_or_scope=self.name, reuse=tf.AUTO_REUSE):
            inputs = D_FirstResblock("ResBlock1", inputs, 64, update_collection, is_down=True)
            # Forward update_collection here as well; this layer used to receive a
            # literal None, unlike every other layer in this network and unlike
            # the 32x32 discriminator.
            inputs = non_local("Non-local", inputs, update_collection, True)
            inputs = D_Resblock("ResBlock2", inputs, 128, update_collection, is_down=True)
            inputs = D_Resblock("ResBlock3", inputs, 256, update_collection, is_down=True)
            inputs = D_Resblock("ResBlock4", inputs, 512, update_collection, is_down=True)
            inputs = D_Resblock("ResBlock5", inputs, 1024, update_collection, is_down=True)
            inputs = D_Resblock("ResBlock6", inputs, 1024, update_collection, is_down=False)
            inputs = relu(inputs)
            inputs = global_sum_pooling(inputs)
            # Class-conditional projection term plus the unconditional logit.
            temp = Inner_product(inputs, y, nums_class, update_collection)
            inputs = dense("dense", inputs, 1, update_collection, is_sn=True)
            inputs = temp + inputs
            return inputs

    def var_list(self):
        """Return the global variables collected under this discriminator's scope name."""
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.name)
def Train():
    """Train the 64x64 generator/discriminator pair with hinge loss.

    Each iteration runs two discriminator updates followed by one generator
    update. Every 100 iterations the losses are printed and one sample image
    is written to ./save_img/; every 500 iterations a checkpoint is written
    to ./save_para/model.ckpt.
    """
    import os  # local import: this module does not import os at file level

    x = tf.placeholder(tf.float32, [None, IMG_H, IMG_W, 3])
    train_phase = tf.placeholder(tf.bool)
    z = tf.placeholder(tf.float32, [None, Z_DIM])
    y = tf.placeholder(tf.int32, [None])
    G = Generator("generator")
    D = Discriminator("discriminator")
    fake_img = G(z, train_phase, y, NUMS_CLASS)
    fake_logits = D(fake_img, y, NUMS_CLASS, None)
    real_logits = D(x, y, NUMS_CLASS, "NO_OPS")
    D_loss, G_loss = Hinge_loss(real_logits, fake_logits)
    D_ortho = BETA * ortho_reg(D.var_list())
    G_ortho = BETA * ortho_reg(G.var_list())
    D_loss += D_ortho
    G_loss += G_ortho
    # NOTE(review): the README table lists learning rate 4e-4 for the
    # discriminator and 1e-4 for the generator; the values below are the
    # other way round. Left unchanged - confirm which is intended.
    D_opt = tf.train.AdamOptimizer(1e-4, beta1=0., beta2=0.9).minimize(D_loss, var_list=D.var_list())
    G_opt = tf.train.AdamOptimizer(4e-4, beta1=0., beta2=0.9).minimize(G_loss, var_list=G.var_list())
    saver = tf.train.Saver()

    # Create the output folders up front; ./save_img is not shipped with the
    # repository, so the first snapshot (iteration 0) used to fail without it.
    for out_dir in ("./save_img", "./save_para"):
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    data = sio.loadmat("./dataset/imagenet64.mat")
    labels = data["labels"][0, :]
    data = data["data"]

    # The context manager releases the session even if training is interrupted.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # To resume training, restore a checkpoint here with saver.restore(sess, ...).
        for itr in range(TRAIN_ITR):
            readtime = 0
            updatetime = 0
            for _ in range(2):
                s_read = time.time()
                batch, Y = read_imagenet(data, labels, BATCH_SIZE)
                e_read = time.time()
                readtime += e_read - s_read
                # Scale pixel values from [0, 255] to [-1, 1].
                batch = batch / 127.5 - 1
                Z = truncated_noise_sample(BATCH_SIZE, Z_DIM, TRUNCATION)
                s_up = time.time()
                sess.run(D_opt, feed_dict={z: Z, x: batch, train_phase: True, y: Y})
                e_up = time.time()
                updatetime += e_up - s_up

            s = time.time()
            Z = truncated_noise_sample(BATCH_SIZE, Z_DIM, TRUNCATION)
            # The generator step reuses the labels of the last discriminator batch.
            sess.run(G_opt, feed_dict={z: Z, train_phase: True, y: Y})
            e = time.time()
            one_itr_time = e - s + updatetime + readtime
            if itr % 100 == 0:
                Z = truncated_noise_sample(BATCH_SIZE, Z_DIM, TRUNCATION)
                Dis_loss = sess.run(D_loss, feed_dict={z: Z, x: batch, train_phase: False, y: Y})
                Gen_loss = sess.run(G_loss, feed_dict={z: Z, train_phase: False, y: Y})
                print("Iteration: %d, D_loss: %f, G_loss: %f, Read_time: %f, Updata_time: %f, One_itr_time: %f" % (itr, Dis_loss, Gen_loss, readtime, updatetime, one_itr_time))
                FAKE_IMG = sess.run(fake_img, feed_dict={z: Z, train_phase: False, y: Y})
                Image.fromarray(np.uint8((FAKE_IMG[0, :, :, :] + 1)*127.5)).save("./save_img/"+str(itr) + "_" + str(Y[0]) + ".jpg")
            if itr % 500 == 0:
                saver.save(sess, "./save_para/model.ckpt")
fake_img = G(z, train_phase, y, NUMS_CLASS) 27 | fake_logits = D(fake_img, y, NUMS_CLASS, None) 28 | real_logits = D(x, y, NUMS_CLASS, "NO_OPS") 29 | D_loss, G_loss = Hinge_loss(real_logits, fake_logits) 30 | D_ortho = BETA * ortho_reg(D.var_list()) 31 | G_ortho = BETA * ortho_reg(G.var_list()) 32 | D_loss += D_ortho 33 | G_loss += G_ortho 34 | D_opt = tf.train.AdamOptimizer(1e-4, beta1=0., beta2=0.9).minimize(D_loss, var_list=D.var_list()) 35 | G_opt = tf.train.AdamOptimizer(4e-4, beta1=0., beta2=0.9).minimize(G_loss, var_list=G.var_list()) 36 | sess = tf.Session() 37 | sess.run(tf.global_variables_initializer()) 38 | saver = tf.train.Saver() 39 | # saver.restore(sess, path_save_para+".\\model.ckpt") 40 | data = np.concatenate((sio.loadmat("./dataset/data_batch_1.mat")["data"], sio.loadmat("./dataset/data_batch_2.mat")["data"], 41 | sio.loadmat("./dataset/data_batch_3.mat")["data"], sio.loadmat("./dataset/data_batch_4.mat")["data"], 42 | sio.loadmat("./dataset/data_batch_5.mat")["data"]), axis=0) 43 | data = np.reshape(data, [50000, 3, 32, 32]) 44 | data = np.transpose(data, axes=[0, 2, 3, 1]) 45 | labels = np.concatenate((sio.loadmat("./dataset/data_batch_1.mat")["labels"], sio.loadmat("./dataset/data_batch_2.mat")["labels"], 46 | sio.loadmat("./dataset/data_batch_3.mat")["labels"], sio.loadmat("./dataset/data_batch_4.mat")["labels"], 47 | sio.loadmat("./dataset/data_batch_5.mat")["labels"]), axis=0)[:, 0] 48 | for itr in range(TRAIN_ITR): 49 | readtime = 0 50 | updatetime = 0 51 | for d in range(2): 52 | s_read = time.time() 53 | batch, Y = read_cifar(data, labels, BATCH_SIZE) 54 | e_read = time.time() 55 | readtime += e_read - s_read 56 | batch = batch / 127.5 - 1 57 | Z = truncated_noise_sample(BATCH_SIZE, Z_DIM, TRUNCATION) 58 | s_up = time.time() 59 | sess.run(D_opt, feed_dict={z: Z, x: batch, train_phase: True, y: Y}) 60 | e_up = time.time() 61 | updatetime += e_up - s_up 62 | 63 | Z = truncated_noise_sample(BATCH_SIZE, Z_DIM, TRUNCATION) 64 | s = 
def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None):
    """Sample a latent batch from a truncated standard normal.

    Values are drawn from a normal distribution truncated to [-2, 2] and then
    multiplied by `truncation`, so results lie in
    [-2 * truncation, 2 * truncation].

    Args:
        batch_size: number of latent vectors.
        dim_z: length of each latent vector.
        truncation: scale applied to the truncated samples.
        seed: optional integer seed for a reproducible draw.

    Returns:
        float32 array of shape (batch_size, dim_z).
    """
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
    return truncation * values


def read_cifar(data, labels, batch_size):
    """Draw a random batch (with replacement) of rows and matching labels.

    The sampling range follows the number of labels instead of a hard-coded
    50000, so subsets and other dataset sizes work as well.
    """
    rand_select = np.random.randint(0, labels.shape[0], [batch_size])
    return data[rand_select], labels[rand_select]


def read_catdog(data, labels, batch_size):
    """Draw a random batch (with replacement) of rows and matching labels.

    The sampling range follows the number of labels instead of a hard-coded
    11078, so subsets and other dataset sizes work as well.
    """
    rand_select = np.random.randint(0, labels.shape[0], [batch_size])
    return data[rand_select], labels[rand_select]
def read_face(data, batch_size):
    """Return ``(batch, 0)`` where ``batch`` is a random selection of rows.

    The number of available samples is read from ``data.shape[0]`` rather
    than a fixed constant, so arrays of any length can be sampled. The second
    tuple element is always the integer 0 (no labels are used).
    """
    rand_select = np.random.randint(0, data.shape[0], [batch_size])
    return data[rand_select], 0


# os.listdir("./dataset")
# NOTE: a second, 3-argument ``random_batch_`` is defined further down in this
# module and rebinds the name at import time, so this 4-argument version is
# not reachable through the module attribute ``random_batch_``.
def random_batch_(path, batch_size, shape, c_nums):
    """Load a random batch from one randomly chosen class folder's .mat file.

    Args:
        path: Directory whose entries are class folders; concatenated as a
            string prefix, so it presumably must end with a separator.
        batch_size: Number of rows to draw from the folder's ``dataset.mat``.
        shape: Unused.
        c_nums: Expected number of class folders.

    Returns:
        ``(batch, y)`` where ``y`` is a ``[1, c_nums]`` one-hot row for the
        chosen folder, or ``None`` (after printing an error) when ``c_nums``
        does not match the number of folders.
    """
    folder_names = os.listdir(path)
    rand_select = np.random.randint(0, len(folder_names))
    if c_nums != len(folder_names):
        print("Error: c_nums must match the number of the folders")
        return
    y = np.zeros([1, c_nums])
    y[0, rand_select] = 1
    path = path + folder_names[rand_select] + "//"
    data = sio.loadmat(path + "dataset.mat")["data"]
    rand_select = np.random.randint(0, np.size(data, 0), [batch_size])
    batch = data[rand_select]
    return batch, y


def random_batch(path, batch_size, shape, c_nums):
    """Load a random batch of image files from one randomly chosen class folder.

    Args:
        path: Directory whose entries are class folders; concatenated as a
            string prefix, so it presumably must end with a separator.
        batch_size: Number of images to load (drawn with replacement).
        shape: ``(height, width, channels)`` of each batch entry; the code
            keeps at most three channels, so ``channels`` is presumably 3.
        c_nums: Expected number of class folders.

    Returns:
        ``(batch, y)`` where ``batch`` has shape ``[batch_size, *shape]`` and
        ``y`` is a ``[1, 1]`` array holding the chosen folder index, or
        ``None`` (after printing an error) when ``c_nums`` does not match the
        number of folders.
    """
    folder_names = os.listdir(path)
    rand_select = np.random.randint(0, len(folder_names))
    if c_nums != len(folder_names):
        print("Error: c_nums must match the number of the folders")
        return
    y = np.zeros([1, 1])
    y[0, 0] = rand_select
    path = path + folder_names[rand_select] + "//"
    file_names = os.listdir(path)
    rand_select = np.random.randint(0, len(file_names), [batch_size])
    batch = np.zeros([batch_size, shape[0], shape[1], shape[2]])
    for i, file_idx in enumerate(rand_select):
        # PIL's resize takes (width, height); the batch is laid out as
        # (height, width, channels), hence the swapped order.
        img = np.array(Image.open(path + file_names[file_idx]).resize((shape[1], shape[0])))
        # Expand grayscale images BEFORE slicing channels: indexing a 2-D
        # array with three indices would raise IndexError.
        if img.ndim == 2:
            img = np.dstack((img, img, img))
        batch[i, :, :, :] = img[:, :, :3]
    return batch, y


def _resize_image(img, height, width):
    """Resize an image array to ``(height, width)`` with bilinear filtering.

    Stands in for ``scipy.misc.imresize``, which has been removed from SciPy.
    Assumes ``img`` is 8-bit image data as produced by ``np.array(Image.open(...))``.
    """
    resized = Image.fromarray(img).resize((width, height), Image.BILINEAR)
    return np.array(resized)


def random_face_batch(path, batch_size):
    """Load a random 64x64 batch from either the ``0`` or the ``1`` sub-folder.

    One of the two sub-folders of ``path`` is chosen at random and
    ``batch_size`` files are drawn from it with replacement.

    Returns:
        ``(batch, Y)`` where ``batch`` has shape ``[batch_size, 64, 64, 3]``
        and ``Y`` is a ``[1, 2]`` one-hot row marking the chosen sub-folder.
    """
    class_dirs = (path + "0//", path + "1//")
    class_files = [os.listdir(class_dir) for class_dir in class_dirs]
    rand_gender = np.random.randint(0, 2)
    batch = np.zeros([batch_size, 64, 64, 3])
    Y = np.zeros([1, 2])
    chosen_dir = class_dirs[rand_gender]
    chosen_files = class_files[rand_gender]
    rand_samples = np.random.randint(0, len(chosen_files), [batch_size])
    for c, i in enumerate(rand_samples):
        img = np.array(Image.open(chosen_dir + chosen_files[i]))
        batch[c, :, :, :] = _resize_image(img, 64, 64)
    Y[0, rand_gender] = 1
    return batch, Y


def random_batch_(path, batch_size, shape):
    """Load a random batch of square-cropped, resized cat/dog images.

    Files whose name starts with ``"cat"`` get label column 0, all others
    label column 1. If loading a file fails, the first file of the directory
    is used instead and the label is taken from that substitute file so that
    image and label stay consistent.

    Args:
        path: Directory prefix containing the image files.
        batch_size: Number of images to draw (with replacement).
        shape: ``(height, width, channels)`` of each batch entry.

    Returns:
        ``(batch, y)`` with ``batch`` of shape ``[batch_size, *shape]`` and
        ``y`` a ``[batch_size, 2]`` one-hot label array.
    """
    filenames = os.listdir(path)
    rand_samples = np.random.randint(0, len(filenames), [batch_size])
    batch = np.zeros([batch_size, shape[0], shape[1], shape[2]])
    y = np.zeros([batch_size, 2])
    for c, idx in enumerate(rand_samples):
        used_name = filenames[idx]
        try:
            img = crop(np.array(Image.open(path + used_name)))
            batch[c, :, :, :] = _resize_image(img, shape[0], shape[1])
        except Exception:
            # Best-effort fallback kept from the original code: substitute the
            # first file of the directory for an unreadable/incompatible one.
            used_name = filenames[0]
            img = crop(np.array(Image.open(path + used_name)))
            batch[c, :, :, :] = _resize_image(img, shape[0], shape[1])
        y[c, 0 if used_name[:3] == "cat" else 1] = 1
    return batch, y


def crop(img):
    """Return a random square crop along the longer of the first two axes.

    The crop side equals the shorter of ``img.shape[0]`` and ``img.shape[1]``;
    the offset along the longer axis is drawn uniformly at random. An image
    whose first two dimensions are equal is returned uncropped. ``img`` must
    have three dimensions.
    """
    h, w = img.shape[0], img.shape[1]
    length = min(h, w)
    x = 0
    y = 0
    if h < w:
        y = np.random.randint(0, w - h + 1)
    elif h > w:
        x = np.random.randint(0, h - w + 1)
    return img[x:x + length, y:y + length, :]


def unpickle(file):
    """Load and return the object pickled in ``file``.

    NOTE(security): ``pickle.load`` can execute arbitrary code; only use this
    on trusted dataset files.
    """
    with open(file, 'rb') as fo:
        return pickle.load(fo)
def conditional_batchnorm(x, train_phase, scope_bn, splited_z=None, y=None, nums_class=10):
    """Batch normalization with optional class/noise-conditioned scale and shift.

    Reference: Ioffe S, Szegedy C. Batch normalization: accelerating deep
    network training by reducing internal covariate shift. 2015:448-456.

    Args:
        x: Input tensor; moments are taken over axes ``[0, 1, 2]`` and the
            last axis is treated as channels.
        train_phase: Boolean tensor for ``tf.cond``. When true, batch
            statistics are used and the moving averages are updated; when
            false, the moving averages are used.
        scope_bn: Variable-scope name (also used as a variable-name prefix in
            the unconditional case).
        splited_z: Noise chunk concatenated with the one-hot label along
            axis 1; only used when ``y`` is given.
        y: Integer class labels, or ``None`` for plain (unconditional) batch
            normalization.
        nums_class: Depth of the one-hot encoding of ``y``.

    Returns:
        The normalized tensor produced by ``tf.nn.batch_normalization``.
    """
    with tf.variable_scope(scope_bn):
        # Identity comparison: ``y`` may be a tensor, for which ``==`` is not
        # a reliable way to test for "no label supplied".
        if y is None:
            # Unconditional: one learned shift/scale per channel, shape [C].
            beta = tf.get_variable(name=scope_bn + 'beta', shape=[x.shape[-1]],
                                   initializer=tf.constant_initializer([0.]), trainable=True)
            gamma = tf.get_variable(name=scope_bn + 'gamma', shape=[x.shape[-1]],
                                    initializer=tf.constant_initializer([1.]), trainable=True)
        else:
            # Conditional: predict per-sample scale/shift from [noise, one-hot label].
            y = tf.one_hot(y, nums_class)
            z = tf.concat([splited_z, y], axis=1)
            gamma = dense("gamma", z, x.shape[-1], None, True)
            beta = dense("beta", z, x.shape[-1], None, True)
            # Reshape to [batch, 1, 1, C] so they broadcast over the two middle axes.
            gamma = tf.reshape(gamma, [-1, 1, 1, x.shape[-1]])
            beta = tf.reshape(beta, [-1, 1, 1, x.shape[-1]])
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments', keep_dims=True)
        ema = tf.train.ExponentialMovingAverage(decay=0.5)

        def mean_var_with_update():
            # Force the moving-average update to run whenever the batch
            # statistics are consumed in training mode.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(train_phase, mean_var_with_update,
                            lambda: (ema.average(batch_mean), ema.average(batch_var)))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
    return normed
def non_local(name, inputs, update_collection, is_sn):
    """Self-attention block with a learned residual gate.

    Builds three 1x1 convolutions of ``inputs`` ("f", "h", "g"), average-pools
    the "h" and "g" branches, forms a softmax attention map between every
    location of "f" and the pooled locations of "h", applies it to "g", maps
    the result back to the input channel count with the "attn" convolution
    and adds it to ``inputs`` scaled by the scalar ``sigma_ratio`` (initialised to 0).

    Args:
        name: Variable scope for the block.
        inputs: 4-D tensor; axes 1, 2, 3 are read as height, width, channels.
        update_collection: Forwarded to ``conv`` for spectral normalization.
        is_sn: Whether the convolutions use spectral normalization.

    Returns:
        Tensor ``inputs + sigma * attention_output``.
    """
    height, width, channels = inputs.shape[1], inputs.shape[2], inputs.shape[3]
    num_locations = height * width
    # The pooled branches go through downsampling() (2x2 average pooling).
    num_pooled = num_locations // 4
    with tf.variable_scope(name):
        query = conv("f", inputs, channels // 8, 1, 1, update_collection, is_sn)
        query = tf.reshape(query, [-1, num_locations, channels // 8])
        key = conv("h", inputs, channels // 8, 1, 1, update_collection, is_sn)
        key = tf.reshape(downsampling(key), [-1, num_pooled, channels // 8])
        attention = tf.nn.softmax(tf.matmul(query, key, transpose_b=True))
        value = conv("g", inputs, channels // 2, 1, 1, update_collection, is_sn)
        value = tf.reshape(downsampling(value), [-1, num_pooled, channels // 2])
        attended = tf.reshape(tf.matmul(attention, value), [-1, height, width, channels // 2])
        sigma = tf.get_variable("sigma_ratio", [], initializer=tf.constant_initializer(0.0))
        attended = conv("attn", attended, channels, 1, 1, update_collection, is_sn)
        return inputs + sigma * attended


def conv(name, inputs, nums_out, k_size, strides, update_collection=None, is_sn=False):
    """2-D convolution ("SAME" padding) with bias and optional spectral norm.

    Creates variables ``W`` (orthogonally initialised, shape
    ``[k_size, k_size, in_channels, nums_out]``) and ``b`` (zeros) inside
    scope ``name``. When ``is_sn`` is true the kernel is passed through
    ``spectral_normalization`` before use.
    """
    in_channels = inputs.shape[-1]
    with tf.variable_scope(name):
        kernel = tf.get_variable("W", [k_size, k_size, in_channels, nums_out],
                                 initializer=tf.orthogonal_initializer())
        bias = tf.get_variable("b", [nums_out], initializer=tf.constant_initializer([0.0]))
        if is_sn:
            kernel = spectral_normalization("sn", kernel, update_collection=update_collection)
        feature_map = tf.nn.conv2d(inputs, kernel, [1, strides, strides, 1], "SAME")
        return tf.nn.bias_add(feature_map, bias)


def upsampling(inputs):
    """Double axes 1 and 2 of ``inputs`` with nearest-neighbour resizing."""
    height, width = inputs.shape[1], inputs.shape[2]
    target_size = [height * 2, width * 2]
    return tf.image.resize_nearest_neighbor(inputs, target_size)


def downsampling(inputs):
    """Apply 2x2 average pooling with stride 2 and "SAME" padding."""
    window = [1, 2, 2, 1]
    return tf.nn.avg_pool(inputs, window, window, "SAME")


def relu(inputs):
    """Element-wise rectified linear unit."""
    activated = tf.nn.relu(inputs)
    return activated


def global_sum_pooling(inputs):
    """Sum ``inputs`` over axes 1 and 2, dropping those axes."""
    return tf.reduce_sum(inputs, [1, 2], keep_dims=False)


def Hinge_loss(real_logits, fake_logits):
    """Hinge losses for discriminator and generator.

    Returns:
        ``(D_loss, G_loss)`` where
        ``D_loss = -mean(min(0, real - 1)) - mean(min(0, -fake - 1))`` and
        ``G_loss = -mean(fake)``.
    """
    real_term = tf.reduce_mean(tf.minimum(0., -1.0 + real_logits))
    fake_term = tf.reduce_mean(tf.minimum(0., -1.0 - fake_logits))
    D_loss = -real_term - fake_term
    G_loss = -tf.reduce_mean(fake_logits)
    return D_loss, G_loss


def ortho_reg(vars_list):
    """Off-diagonal orthogonality penalty over 4-D weight variables.

    For every variable whose name contains ``"W"`` and that has four
    dimensions, the kernel is flattened to one row per output filter, the
    Gram matrix of the rows is computed and the ``tf.nn.l2_loss`` of its
    off-diagonal entries is accumulated.

    Returns:
        The accumulated penalty, or the integer 0 if no variable qualifies.
    """
    penalty = 0
    for var in vars_list:
        if "W" not in var.name:
            continue
        if len(var.shape) != 4:
            continue
        nums_kernel = int(var.shape[-1])
        # Move the output-filter axis first, then flatten each filter to a row.
        filters = tf.transpose(var, perm=[3, 0, 1, 2])
        filters = tf.reshape(filters, [nums_kernel, -1])
        off_diagonal = tf.ones([nums_kernel, nums_kernel]) - tf.eye(nums_kernel, nums_kernel)
        gram = tf.matmul(filters, filters, transpose_b=True)
        penalty += tf.nn.l2_loss(gram * off_diagonal)
    return penalty


def dense(name, inputs, nums_out, update_collection=None, is_sn=False):
    """Fully connected layer with bias and optional spectral normalization.

    Creates variables ``W`` (orthogonally initialised, shape
    ``[in_features, nums_out]``) and ``b`` (zeros) inside scope ``name``.
    """
    in_features = inputs.shape[-1]
    with tf.variable_scope(name):
        weight = tf.get_variable("W", [in_features, nums_out],
                                 initializer=tf.orthogonal_initializer())
        bias = tf.get_variable("b", [nums_out], initializer=tf.constant_initializer([0.0]))
        if is_sn:
            weight = spectral_normalization("sn", weight, update_collection=update_collection)
        projected = tf.matmul(inputs, weight)
        return tf.nn.bias_add(projected, bias)
def G_Resblock(name, inputs, nums_out, is_training, splited_z, y, nums_class):
    """Generator residual block that doubles the spatial resolution.

    Residual branch: conditional batch norm -> ReLU -> upsample -> 3x3 conv,
    then conditional batch norm -> ReLU -> 3x3 conv. Shortcut branch:
    upsample -> 1x1 conv. All convolutions use spectral normalization.

    Returns:
        Sum of the residual and shortcut branches.
    """
    with tf.variable_scope(name):
        shortcut = tf.identity(inputs)
        branch = conditional_batchnorm(inputs, is_training, "bn1", splited_z, y, nums_class)
        branch = upsampling(relu(branch))
        branch = conv("conv1", branch, nums_out, 3, 1, is_sn=True)
        branch = conditional_batchnorm(branch, is_training, "bn2", splited_z, y, nums_class)
        branch = conv("conv2", relu(branch), nums_out, 3, 1, is_sn=True)
        # Identity mapping
        shortcut = conv("identity", upsampling(shortcut), nums_out, 1, 1, is_sn=True)
        return branch + shortcut


def D_Resblock(name, inputs, nums_out, update_collection=None, is_down=True):
    """Discriminator residual block with optional 2x downsampling.

    Residual branch: ReLU -> 3x3 conv -> ReLU -> 3x3 conv. When ``is_down``
    is true both branches are average-pooled and the shortcut additionally
    goes through a 1x1 conv (conv first, then pooling); otherwise the
    shortcut is the unmodified input.
    """
    with tf.variable_scope(name):
        shortcut = tf.identity(inputs)
        branch = conv("conv1", relu(inputs), nums_out, 3, 1, update_collection, is_sn=True)
        branch = conv("conv2", relu(branch), nums_out, 3, 1, update_collection, is_sn=True)
        if is_down:
            branch = downsampling(branch)
            # Identity mapping
            shortcut = conv("identity", shortcut, nums_out, 1, 1, update_collection, is_sn=True)
            shortcut = downsampling(shortcut)
        return branch + shortcut


def D_FirstResblock(name, inputs, nums_out, update_collection, is_down=True):
    """First discriminator residual block (no ReLU before the first conv).

    Residual branch: 3x3 conv -> ReLU -> 3x3 conv. With ``is_down`` true both
    branches are average-pooled; the shortcut is pooled first and then passed
    through a 1x1 conv.
    """
    with tf.variable_scope(name):
        shortcut = tf.identity(inputs)
        branch = conv("conv1", inputs, nums_out, 3, 1, update_collection=update_collection, is_sn=True)
        branch = conv("conv2", relu(branch), nums_out, 3, 1, update_collection=update_collection, is_sn=True)
        # NOTE(review): the shortcut ops are kept under ``is_down`` to mirror
        # D_Resblock; confirm against callers that pass is_down=False.
        if is_down:
            branch = downsampling(branch)
            # Identity mapping
            shortcut = downsampling(shortcut)
            shortcut = conv("identity", shortcut, nums_out, 1, 1, update_collection=update_collection, is_sn=True)
        return branch + shortcut


def _l2normalize(v, eps=1e-12):
    """Divide ``v`` by its L2 norm; ``eps`` is added to the norm for stability."""
    norm = tf.reduce_sum(v ** 2) ** 0.5
    return v / (norm + eps)


def spectral_normalization(name, weights, num_iters=1, update_collection=None,
                           with_sigma=False):
    """Divide ``weights`` by a power-iteration estimate of its spectral norm.

    The weight tensor is flattened to ``[-1, out_channels]``. A persistent,
    non-trainable vector ``u`` (variable ``name + 'u'``) is refined for
    ``num_iters`` iterations to estimate ``sigma``, and the flattened weights
    are divided by that estimate.

    Args:
        name: Prefix for the ``u`` variable name.
        weights: Weight tensor whose last axis is the output dimension.
        num_iters: Number of power-iteration steps (at least 1 is required).
        update_collection: ``None`` updates ``u`` as a control dependency of
            the returned weights; ``'NO_OPS'`` never updates ``u``; any other
            value adds the assign op to that graph collection instead.
        with_sigma: When true, also return the estimated ``sigma``.

    Returns:
        The normalized weights in their original shape, or
        ``(normalized_weights, sigma)`` when ``with_sigma`` is true.
    """
    w_shape = weights.shape.as_list()
    out_dim = w_shape[-1]
    w_mat = tf.reshape(weights, [-1, out_dim])  # [-1, output_channel]
    u = tf.get_variable(name + 'u', [1, out_dim],
                        initializer=tf.truncated_normal_initializer(),
                        trainable=False)
    u_est = u
    for _ in range(num_iters):
        v_est = _l2normalize(tf.matmul(u_est, w_mat, transpose_b=True))
        u_est = _l2normalize(tf.matmul(v_est, w_mat))

    sigma = tf.squeeze(tf.matmul(tf.matmul(v_est, w_mat), u_est, transpose_b=True))
    w_mat /= sigma
    if update_collection is None:
        # Tie the update of ``u`` to every evaluation of the normalized weights.
        with tf.control_dependencies([u.assign(u_est)]):
            w_bar = tf.reshape(w_mat, w_shape)
    else:
        w_bar = tf.reshape(w_mat, w_shape)
        if update_collection != 'NO_OPS':
            tf.add_to_collection(update_collection, u.assign(u_est))
    if with_sigma:
        return w_bar, sigma
    return w_bar