├── .gitignore ├── README.md ├── main.py └── src ├── __init__.py ├── config.py ├── dataset.py ├── estimator.py ├── model ├── __init__.py ├── discriminator.py ├── generator.py ├── layers.py └── losses.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | logs 3 | __pycache__ 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StackGAN 2 | 3 | Tensorflow implementation of the StackGAN++ outlined in this paper: [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/pdf/1710.10916.pdf). 4 | 5 | This implementation uses the Estimator API, allowing you to train StackGAN++ models on novel datasets with minimal effort. 6 | 7 | ### Features 8 | 9 | - Easy to use. To retrain this model, simply run `python3 main.py --train --data_dir=/path/to/your/dataset --log_dir=/path/to/logs`. See the full list of [configuration options](https://google.com). 10 | - Fully TPU compatible. To run on gcloud TPUs, use the `--tpu_name` flag. 11 | - Native Tensorboard integration. 12 | 13 | ### Getting Started 14 | 15 | #### Create Dataset 16 | 17 | You must format your dataset as a TFRecord file with the given features: 18 | ``` 19 | image : bytes 20 | height : int64 21 | width : int64 22 | channels: int64 23 | label : int64 24 | ``` 25 | 26 | An example is shown below: 27 | 28 | ```python 29 | import tensorflow as tf 30 | import cv2 31 | 32 | writer = tf.python_io.TFRecordWriter("/path/to/my/dataset.tfrecords") 33 | 34 | for f in os.listdir("./raw_images"): 35 | img = cv2.imread(os.path.join("./raw_images", f)) 36 | height, width, channels = img.shape 37 | class_index = 0 # 0 to (num_classes-1) 38 | 39 | example = tf.train.Example( 40 | features = tf.train.Features( 41 | feature = { 42 | "image_64": _bytes_feature(resize(img, 64).tostring()), 43 | "image_128": _bytes_feature(resize(img, 128).tostring()), 44 | "image_256": _bytes_feature(resize(img, 256).tostring()), 45 | "label": _int64_feature(class_index) 46 | } 47 | ) 48 | ) 49 | 50 | writer.write(example.SerializeToString()) 51 | ``` 52 | 53 | Once this `.tfrecords` file is created, you can immediately use it to train your model from your local machine. Alternatively, you can upload it to a gcloud storage bucket that you own and reference it from there, which is advantageous if you are using AWS or Gcloud VMs and don't want to worry about a time-consuming process of downloading the dataset first. 54 | 55 | *NOTE*: Using a google storage location for your dataset and log files is a requirement when using TPUs. 56 | 57 | #### Run Training 58 | 59 | Training the model on a new dataset is easy. Locally, you can just run: 60 | 61 | ``` 62 | python3 main.py --train --data_dir=/path/to/dataset.tfrecords 63 | ``` 64 | 65 | If your dataset is stored on gcloud storage, you can simply replace with `data_dir` with a fully qualified google storage path: 66 | 67 | ``` 68 | python3 main.py --train --data_dir=gs://${BUCKET}/path/to/dataset.tfrecords 69 | ``` 70 | 71 | #### Monitor 72 | 73 | To monitor training, point a `tensorboard` instance to your log dir: 74 | 75 | ``` 76 | tensorboard --logdir=/path/to/logs 77 | ``` 78 | 79 | ### Using TPUs 80 | 81 | This repos is written to be fully TPU compatible. Assuming you have already provisioned a TPU on gcloud, you can use the `--tpu_name` flag: 82 | 83 | ``` 84 | python3 main.py --train --data_dir=gs://${BUCKET}/path/to/dataset.tfrecords --tpu_name=${TPU_NAME} 85 | ``` 86 | 87 | ### Configuration Options 88 | 89 | ``` 90 | - data_dir: String. The data directory. Must point to a tfrecords file. Can be a google storage path (e.g. gs://my-bucket/my/path/file.tfrecords). 91 | ``` 92 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import src.train 3 | import src.config as config 4 | 5 | def main(_): 6 | src.train.main(_) 7 | 8 | if __name__ == '__main__': 9 | tf.app.run() 10 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zacharynevin/StackGAN/faabba31b2089b50276488f78d399a052e4752a2/src/__init__.py -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | flags = tf.app.flags 4 | 5 | flags.DEFINE_bool("train", False, "Run training [False].") 6 | flags.DEFINE_bool("predict", False, "Run prediction [False].") 7 | flags.DEFINE_bool("eval", False, "Run evaluation [False].") 8 | 9 | flags.DEFINE_integer("predict_class", None, "The class to generate. If None, generate images from random classes [None]") 10 | 11 | flags.DEFINE_bool("use_tpu", False, "Set to True to use TPUs [False].") 12 | flags.DEFINE_string("tpu_name", None, "The name of the TPU to use [None].") 13 | flags.DEFINE_integer("tpu_shards", 8, "Number of TPU shards [8].") 14 | flags.DEFINE_integer("tpu_iterations", 50, "Number of iterations per TPU training loop [50].") 15 | 16 | flags.DEFINE_string("data_dir", "./data/dataset.tfrecords", "The data directory. Must point to a tfrecords file. Can be a google storage path (e.g. gs://my-bucket/my/path/file.tfrecords).") 17 | flags.DEFINE_string("log_dir", "./logs", "Directory to store logs. Can be a google storage path (e.g. gs://my-bucket/my/path).") 18 | 19 | flags.DEFINE_integer("buffer_size", 8*1024*1024, "The dataset buffer size. [8388608]") 20 | flags.DEFINE_integer("batch_size", 128, "The batch size. If using TPUs, this is the batch size per shard. [64]") 21 | flags.DEFINE_integer("z_dim", 100, "The z input dimension [100].") 22 | flags.DEFINE_integer("data_shuffle_seed", 12345, "The seed to use when shuffling the database [12345].") 23 | flags.DEFINE_integer("data_map_parallelism", 10, "The number of parallel calls to use in dataset.map [10].") 24 | 25 | flags.DEFINE_float("g_lr", 0.0002, "The generator learning rate [2e-4].") 26 | flags.DEFINE_float("d_lr", 0.0002, "The discriminator learning rate [2e-4].") 27 | 28 | flags.DEFINE_integer("train_steps", 1000, "The number of training steps [1000].") 29 | flags.DEFINE_integer("eval_steps", 1000, "The number of eval steps [1000].") 30 | 31 | config = flags.FLAGS 32 | 33 | if config.train and config.tpu_name: 34 | config.use_tpu = True 35 | config.data_format = 'NCHW' 36 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import tensorflow.contrib.slim as slim 4 | 5 | def get_dataset_iterator(data_dir, 6 | batch_size, 7 | data_format, 8 | buffer_size, 9 | shuffle_seed, 10 | num_parallel_calls): 11 | """Construct a TF dataset from a remote source""" 12 | def transform(tfrecord_proto): 13 | return transform_tfrecord(tfrecord_proto, 14 | data_format=data_format) 15 | 16 | tf_dataset = tf.data.TFRecordDataset(data_dir) 17 | tf_dataset = tf_dataset.map(transform, num_parallel_calls=num_parallel_calls) 18 | tf_dataset = tf_dataset.shuffle(seed=shuffle_seed, buffer_size=buffer_size) 19 | tf_dataset = tf_dataset.repeat() 20 | tf_dataset = tf_dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size)) 21 | tf_iterator = tf_dataset.make_one_shot_iterator() 22 | return tf_iterator 23 | 24 | def decode_image(img, dim, data_format): 25 | """ 26 | Take a raw image byte string and decode to an image 27 | 28 | Params: 29 | img (str): Image byte string 30 | dim (int): The width and height of the image. 31 | data_format (str): The data format for the image 32 | 33 | Return: 34 | Tensor[HCW]: A tensor of shape representing the RGB image. 35 | """ 36 | img = tf.decode_raw(img, out_type=tf.uint8) 37 | img = tf.reshape(img, tf.stack([dim, dim, 3], axis=0)) 38 | img = tf.reverse(img, [-1]) # BGR to RGB 39 | img = transform_image(img, data_format) 40 | return img 41 | 42 | def transform_image(img, data_format): 43 | img = tf.image.convert_image_dtype(img, tf.float32) 44 | if data_format == 'NCHW': 45 | img = tf.transpose(img, [3, 1, 2]) 46 | 47 | return img 48 | 49 | def decode_class(label, num_classes): 50 | return tf.one_hot(label, num_classes, dtype=tf.float32) 51 | 52 | def transform_tfrecord(tf_protobuf, data_format): 53 | """ 54 | Decode the tfrecord protobuf into the image. 55 | 56 | Params: 57 | tf_protobuf (proto): A protobuf representing the data record. 58 | 59 | Returns: 60 | Tensor[64, 64, 3] 61 | Tensor[128, 128, 3] 62 | Tensor[256, 256, 3] 63 | """ 64 | 65 | features = { 66 | "image_64": tf.FixedLenFeature((), tf.string), 67 | "image_128": tf.FixedLenFeature((), tf.string), 68 | "image_256": tf.FixedLenFeature((), tf.string) 69 | } 70 | parsed_features = tf.parse_single_example(tf_protobuf, features) 71 | 72 | image_64 = decode_image(parsed_features["image_64"], 64, data_format=data_format) 73 | image_128 = decode_image(parsed_features["image_128"], 128, data_format=data_format) 74 | image_256 = decode_image(parsed_features["image_256"], 256, data_format=data_format) 75 | 76 | return image_64, image_128, image_256 77 | -------------------------------------------------------------------------------- /src/estimator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import tpu 3 | from tensorflow.contrib.tpu import TPUEstimatorSpec as EstimatorSpec 4 | from tensorflow.contrib import summary 5 | import numpy as np 6 | import os 7 | from src.config import config 8 | from src.dataset import get_dataset_iterator 9 | from src.model.discriminator import StackGANDiscriminator as Discriminator 10 | from src.model.generator import StackGANGenerator as Generator 11 | import src.model.losses as losses 12 | 13 | ModeKeys = tf.estimator.ModeKeys 14 | 15 | def model_fn(features, labels, mode, params): 16 | use_tpu = params['use_tpu'] 17 | D_lr = params["D_lr"] 18 | G_lr = params["G_lr"] 19 | data_format = params["data_format"] 20 | 21 | loss = None 22 | train_op = None 23 | predictions = None 24 | host_call = None 25 | eval_metrics = None 26 | 27 | generator = Generator(data_format) 28 | discriminator = Discriminator(data_format) 29 | 30 | if mode == ModeKeys.PREDICT: 31 | G0, G1, G2 = generator(features) 32 | 33 | predictions = { 34 | 'G0': G0, 35 | 'G1': G1, 36 | 'G2': G2 37 | } 38 | elif mode == ModeKeys.TRAIN or mode == ModeKeys.EVAL: 39 | R0 = features['R0'] 40 | R1 = features['R1'] 41 | R2 = features['R2'] 42 | z = features['z'] 43 | 44 | global_step = tf.train.get_or_create_global_step() 45 | 46 | G_global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name='G_{}_global_step'.format(mode)) 47 | D0_global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name='D0_{}_global_step'.format(mode)) 48 | D1_global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name='D1_{}_global_step'.format(mode)) 49 | D2_global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name='D2_{}_global_step'.format(mode)) 50 | 51 | G0, G1, G2, G_scope = generator(z) 52 | 53 | D_R0, D0_scope = discriminator.D0(R0) 54 | D_R1, D1_scope = discriminator.D1(R1) 55 | D_R2, D2_scope = discriminator.D2(R2) 56 | 57 | D_G0, _ = discriminator.D0(G0) 58 | D_G1, _ = discriminator.D1(G1) 59 | D_G2, _ = discriminator.D2(G2) 60 | 61 | op = tf.group( 62 | tf.assign_add(global_step, 1), 63 | tf.assign_add(G_global_step, 1), 64 | tf.assign_add(D0_global_step, 1), 65 | tf.assign_add(D1_global_step, 1), 66 | tf.assign_add(D2_global_step, 1) 67 | ) if mode == ModeKeys.EVAL else tf.no_op() 68 | 69 | with tf.control_dependencies([op]): 70 | 71 | with tf.variable_scope('losses'): 72 | with tf.variable_scope('G0'): 73 | L_G0 = losses.G_loss(D_G0) 74 | 75 | with tf.variable_scope('G1'): 76 | L_G1 = losses.G_loss(D_G1) 77 | L_G1 += losses.colour_consistency_regularization(G1, G0, data_format=data_format) 78 | 79 | with tf.variable_scope('G2'): 80 | L_G2 = losses.G_loss(D_G2) 81 | L_G2 += losses.colour_consistency_regularization(G2, G1, data_format=data_format) 82 | 83 | with tf.variable_scope('D0'): 84 | L_D0 = losses.D_loss(D_R0, D_G0) 85 | L_D0_W = losses.wasserstein_loss(R0, G0, discriminator.D0, D0_scope) 86 | L_D0 += L_D0_W 87 | 88 | with tf.variable_scope('D1'): 89 | L_D1 = losses.D_loss(D_R1, D_G1) 90 | L_D1_W = losses.wasserstein_loss(R1, G1, discriminator.D1, D1_scope) 91 | L_D1 += L_D1_W 92 | 93 | with tf.variable_scope('D2'): 94 | L_D2 = losses.D_loss(D_R2, D_G2) 95 | L_D2_W = losses.wasserstein_loss(R2, G2, discriminator.D2, D2_scope) 96 | L_D2 += L_D2_W 97 | 98 | with tf.variable_scope('G'): 99 | L_G = L_G0 + L_G1 + L_G2 100 | 101 | if mode == ModeKeys.TRAIN: 102 | with tf.variable_scope('optimizers'): 103 | with tf.control_dependencies([tf.assign_add(global_step, 1)]): 104 | trainable_vars = tf.trainable_variables() 105 | G_vars = [var for var in trainable_vars if 'generator' in var.name] 106 | D0_vars = [var for var in trainable_vars if 'discriminator/D0' in var.name] 107 | D1_vars = [var for var in trainable_vars if 'discriminator/D1' in var.name] 108 | D2_vars = [var for var in trainable_vars if 'discriminator/D2' in var.name] 109 | 110 | D0_train = create_train_op(L_D0, 111 | global_step=D0_global_step, 112 | learning_rate=D_lr, 113 | var_list=D0_vars, 114 | use_tpu=use_tpu) 115 | 116 | D1_train = create_train_op(L_D1, 117 | global_step=D1_global_step, 118 | learning_rate=D_lr, 119 | var_list=D1_vars, 120 | use_tpu=use_tpu) 121 | 122 | D2_train = create_train_op(L_D2, 123 | global_step=D2_global_step, 124 | learning_rate=D_lr, 125 | var_list=D2_vars, 126 | use_tpu=use_tpu) 127 | 128 | with tf.control_dependencies([D2_train, D1_train, D0_train]): 129 | G_train = create_train_op(L_G, 130 | global_step=G_global_step, 131 | learning_rate=G_lr, 132 | var_list=G_vars, 133 | use_tpu=use_tpu) 134 | 135 | train_op = tf.group(G_train, D2_train, D1_train, D0_train) 136 | 137 | loss = L_G 138 | 139 | host_call = (host_call_fn(mode), [ 140 | G0, 141 | G1, 142 | G2, 143 | R0, 144 | R1, 145 | R2, 146 | tpu_pad(L_D0), 147 | tpu_pad(L_D1), 148 | tpu_pad(L_D2), 149 | tpu_pad(L_D0_W), 150 | tpu_pad(L_D1_W), 151 | tpu_pad(L_D2_W), 152 | tpu_pad(L_G0), 153 | tpu_pad(L_G1), 154 | tpu_pad(L_G2), 155 | tpu_pad(L_G), 156 | tpu_pad(D0_global_step), 157 | tpu_pad(D1_global_step), 158 | tpu_pad(D2_global_step), 159 | tpu_pad(G_global_step) 160 | ]) 161 | 162 | return EstimatorSpec(mode, 163 | predictions=predictions, 164 | loss=loss, 165 | host_call=host_call, 166 | eval_metrics=eval_metrics, 167 | train_op=train_op) 168 | 169 | def tpu_pad(scalar): 170 | return tf.reshape(scalar, [1]) 171 | 172 | def tpu_depad(tensor, dtype=None): 173 | tensor = tf.reduce_mean(tensor) 174 | if dtype: 175 | tensor = tf.cast(tensor, dtype) 176 | return tensor 177 | 178 | def host_call_fn(mode): 179 | """ 180 | This is a hack for getting multiple losses to appear in Tensorboard. 181 | It also gives us the ability to write summaries when using TPUs, which are normally incompatible with tf.summary. 182 | """ 183 | def summary_fn(G0, G1, G2, R0, R1, R2, L_D0, L_D1, L_D2, L_D0_W, L_D1_W, L_D2_W, L_G0, L_G1, L_G2, L_G, 184 | D0_global_step, D1_global_step, D2_global_step, G_global_step): 185 | with summary.create_file_writer(config.log_dir).as_default(): 186 | with summary.always_record_summaries(): 187 | max_image_outputs = 10 188 | 189 | D0_global_step = tpu_depad(D0_global_step) 190 | D1_global_step = tpu_depad(D1_global_step) 191 | D2_global_step = tpu_depad(D2_global_step) 192 | G_global_step = tpu_depad(G_global_step) 193 | L_D0 = tpu_depad(L_D0) 194 | L_D1 = tpu_depad(L_D1) 195 | L_D2 = tpu_depad(L_D2) 196 | L_D0_W = tpu_depad(L_D0_W) 197 | L_D1_W = tpu_depad(L_D1_W) 198 | L_D2_W = tpu_depad(L_D2_W) 199 | L_G0 = tpu_depad(L_G0) 200 | L_G1 = tpu_depad(L_G1) 201 | L_G2 = tpu_depad(L_G2) 202 | L_G = tpu_depad(L_G) 203 | 204 | summary.image('R0', R0, max_images=max_image_outputs, step=D0_global_step) 205 | summary.image('R1', R1, max_images=max_image_outputs, step=D1_global_step) 206 | summary.image('R2', R2, max_images=max_image_outputs, step=D2_global_step) 207 | summary.image('G0', G0, max_images=max_image_outputs, step=G_global_step) 208 | summary.image('G1', G1, max_images=max_image_outputs, step=G_global_step) 209 | summary.image('G2', G2, max_images=max_image_outputs, step=G_global_step) 210 | 211 | with tf.name_scope('losses'): 212 | summary.scalar('D0', L_D0, step=D0_global_step) 213 | summary.scalar('D1', L_D1, step=D1_global_step) 214 | summary.scalar('D2', L_D2, step=D2_global_step) 215 | summary.scalar('D0_W', L_D0_W, step=D0_global_step) 216 | summary.scalar('D1_W', L_D1_W, step=D1_global_step) 217 | summary.scalar('D2_W', L_D2_W, step=D2_global_step) 218 | 219 | summary.scalar('G0', L_G0, step=G_global_step) 220 | summary.scalar('G1', L_G1, step=G_global_step) 221 | summary.scalar('G2', L_G2, step=G_global_step) 222 | summary.scalar('G', L_G, step=G_global_step) 223 | 224 | return summary.all_summary_ops() 225 | 226 | return summary_fn 227 | 228 | def predict_input_fn(params, class_label=None): 229 | sample_size = params['batch_size'] 230 | z = tf.random_normal([sample_size, z_dim]) 231 | 232 | return z, None 233 | 234 | def eval_input_fn(params): 235 | return get_dataset(params, 'eval') 236 | 237 | def train_input_fn(params): 238 | return get_dataset(params, 'train') 239 | 240 | def get_dataset(params, mode): 241 | batch_size = params['batch_size'] 242 | buffer_size = params['buffer_size'] 243 | data_dir = params['data_dir'] 244 | data_format = params['data_format'] 245 | z_dim = params['z_dim'] 246 | seed = params['data_seed']*2 if mode == 'eval' else params['data_seed'] 247 | num_parallel_calls = params['data_map_parallelism'] 248 | 249 | iterator = get_dataset_iterator(data_dir, 250 | batch_size, 251 | data_format=data_format, 252 | buffer_size=buffer_size, 253 | shuffle_seed=seed, 254 | num_parallel_calls=num_parallel_calls) 255 | 256 | R0, R1, R2 = iterator.get_next() 257 | 258 | z = tf.random_normal([batch_size, z_dim]) 259 | 260 | features = { 261 | 'R0': R0, 262 | 'R1': R1, 263 | 'R2': R2, 264 | 'z': z 265 | } 266 | 267 | return features, None 268 | 269 | def create_train_op(loss, learning_rate, var_list, global_step, use_tpu=False): 270 | exp_learning_rate = tf.train.exponential_decay(learning_rate, 271 | global_step, 272 | decay_steps=10000, 273 | decay_rate=0.96) 274 | 275 | optimizer = tf.train.AdamOptimizer(learning_rate=exp_learning_rate, 276 | beta1=0.5, 277 | beta2=0.999) 278 | 279 | if use_tpu: 280 | optimizer = tpu.CrossShardOptimizer(optimizer) 281 | 282 | return optimizer.minimize(loss, 283 | var_list=var_list, 284 | global_step=global_step, 285 | colocate_gradients_with_ops=True) 286 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zacharynevin/StackGAN/faabba31b2089b50276488f78d399a052e4752a2/src/model/__init__.py -------------------------------------------------------------------------------- /src/model/discriminator.py: -------------------------------------------------------------------------------- 1 | import math 2 | import tensorflow as tf 3 | import numpy as np 4 | import tensorflow.contrib.slim as slim 5 | import src.model.layers as layers 6 | 7 | class StackGANDiscriminator(): 8 | def __init__(self, data_format): 9 | """ 10 | Initialize StackGAN++ Discriminator 11 | 12 | Params: 13 | data_format (str): The data format to use for the image. 14 | """ 15 | self.data_format = data_format 16 | self.Nd = 64 17 | 18 | def D0(self, Im0, scope=None): 19 | with slim.arg_scope([slim.conv2d, slim.batch_norm], data_format=self.data_format): 20 | with tf.variable_scope(scope or 'discriminator/D0', reuse=tf.AUTO_REUSE) as D0_scope: 21 | net = self.add_noise(Im0) 22 | net = self.encode_x16(net) 23 | 24 | logits = self.logits(net) 25 | 26 | return logits, D0_scope 27 | 28 | def D1(self, Im1, scope=None): 29 | with slim.arg_scope([slim.conv2d, slim.batch_norm], data_format=self.data_format): 30 | with tf.variable_scope(scope or 'discriminator/D1', reuse=tf.AUTO_REUSE) as D1_scope: 31 | net = self.add_noise(Im1) 32 | net = self.encode_x16(net) 33 | net = self.downsample(net, 16*self.Nd) 34 | net = self.conv3x3_block(net, 8*self.Nd) 35 | 36 | logits = self.logits(net) 37 | 38 | return logits, D1_scope 39 | 40 | def D2(self, Im2, scope=None): 41 | with slim.arg_scope([slim.conv2d, slim.batch_norm], data_format=self.data_format): 42 | with tf.variable_scope(scope or 'discriminator/D2', reuse=tf.AUTO_REUSE) as D2_scope: 43 | net = self.add_noise(Im2) 44 | net = self.encode_x16(net) 45 | net = self.downsample(net, 16*self.Nd) 46 | net = self.downsample(net, 32*self.Nd) 47 | net = self.conv3x3_block(net, 16*self.Nd) 48 | net = self.conv3x3_block(net, 32*self.Nd) 49 | 50 | logits = self.logits(net) 51 | 52 | return logits, D2_scope 53 | 54 | def conv3x3_block(self, net, filters): 55 | return layers.conv3x3_block(net, filters, self.data_format) 56 | 57 | def downsample(self, net, filters): 58 | with tf.name_scope('downsample'): 59 | net = slim.conv2d(net, filters, kernel_size=4, stride=2, padding='same', biases_initializer=None) 60 | net = slim.batch_norm(net) 61 | net = tf.nn.leaky_relu(net) 62 | return net 63 | 64 | def logits(self, net): 65 | with tf.name_scope('logits'): 66 | net = slim.conv2d(net, 1, kernel_size=4, stride=4, padding='same') 67 | net = tf.nn.sigmoid(net) 68 | return tf.reshape(net, [-1]) 69 | 70 | def add_noise(self, net): 71 | with tf.name_scope('noise'): 72 | noise = tf.random_normal(tf.shape(net), stddev=0.02, dtype=tf.float32) 73 | net = net + noise 74 | return net 75 | 76 | def encode_x16(self, net): 77 | with tf.name_scope('encode_x16'): 78 | with slim.arg_scope([slim.conv2d], kernel_size=4, stride=2, padding='same'): 79 | net = slim.conv2d(net, self.Nd) 80 | net = tf.nn.leaky_relu(net) 81 | 82 | net = slim.conv2d(net, 2*self.Nd) 83 | net = slim.batch_norm(net) 84 | net = tf.nn.leaky_relu(net) 85 | 86 | net = slim.conv2d(net, 4*self.Nd) 87 | net = slim.batch_norm(net) 88 | net = tf.nn.leaky_relu(net) 89 | 90 | net = slim.conv2d(net, 8*self.Nd) 91 | net = slim.batch_norm(net) 92 | net = tf.nn.leaky_relu(net) 93 | 94 | return net 95 | -------------------------------------------------------------------------------- /src/model/generator.py: -------------------------------------------------------------------------------- 1 | import math 2 | import tensorflow as tf 3 | import numpy as np 4 | import tensorflow.contrib.slim as slim 5 | import src.model.layers as layers 6 | 7 | class StackGANGenerator(): 8 | def __init__(self, data_format): 9 | """ 10 | Initialize StackGAN++ generator 11 | 12 | Params: 13 | data_format (str): The data format to use for the image. 14 | """ 15 | self.data_format = data_format 16 | self.Ng = 32 17 | 18 | def __call__(self, z): 19 | """ 20 | Build StackGAN++ generator graph 21 | 22 | Params: 23 | z Tensor[None, None]: A 2-D tensor representing the z-input. 24 | 25 | Returns: 26 | G0 Tensor[None, 64, 64, 3] 27 | G1 Tensor[None, 128, 128, 3] 28 | G2 Tensor[None, 256, 256, 3] 29 | """ 30 | with tf.variable_scope('generator') as G_scope: 31 | with slim.arg_scope([slim.conv2d, slim.batch_norm], data_format=self.data_format): 32 | 33 | with tf.variable_scope('FC'): 34 | net = slim.fully_connected(z, 4*4*64*self.Ng*2, biases_initializer=None) 35 | net = slim.batch_norm(net) 36 | net = self.glu(net) 37 | 38 | net_shape = [-1, 4, 4, 64*self.Ng] if self.data_format == 'NHWC' else [-1, 64*self.Ng, 4, 4] 39 | net = tf.reshape(net, net_shape) 40 | 41 | G0, net = self.G0(net) 42 | G1, net = self.G1(net, z) 43 | G2, net = self.G2(net, z) 44 | 45 | return G0, G1, G2, G_scope 46 | 47 | def G0(self, net): 48 | with tf.variable_scope('G0'): 49 | net = self.upsample(net, [-1, 8, 8, 32*self.Ng]) 50 | net = self.upsample(net, [-1, 16, 16, 16*self.Ng]) 51 | net = self.upsample(net, [-1, 32, 32, 8*self.Ng]) 52 | net = self.upsample(net, [-1, 64, 64, 4*self.Ng]) 53 | G0 = self.to_image(net) 54 | return G0, net 55 | 56 | def G1(self, net, z): 57 | with tf.variable_scope('G1'): 58 | net = self.joint_conv(net, z, 64) 59 | net = self.residual_block(net, 64) 60 | net = self.residual_block(net, 64) 61 | net = self.upsample(net, [-1, 128, 128, 2*self.Ng]) 62 | G1 = self.to_image(net) 63 | return G1, net 64 | 65 | def G2(self, net, z): 66 | with tf.variable_scope('G2'): 67 | net = self.joint_conv(net, z, 32) 68 | net = self.residual_block(net, 32) 69 | net = self.residual_block(net, 32) 70 | net = self.upsample(net, [-1, 256, 256, 1*self.Ng]) 71 | G2 = self.to_image(net) 72 | return G2, net 73 | 74 | def to_image(self, 75 | net): 76 | net = slim.conv2d(net, 3, kernel_size=3, stride=1, padding='same') 77 | net = tf.nn.tanh(net) 78 | return net 79 | 80 | def upsample(self, 81 | net, 82 | output_shape): 83 | 84 | filters = output_shape[-1] 85 | height = output_shape[1] 86 | width = output_shape[2] 87 | 88 | with tf.name_scope('upsample_%d_%d_%d' % (height, width, filters)): 89 | net = self.resize(net, [height, width]) 90 | net = self.conv3x3_block(net, filters) 91 | return net 92 | 93 | def joint_conv(self, net, z, filters): 94 | with tf.name_scope('joint_conv'): 95 | net_shape = net.get_shape().as_list() 96 | 97 | print(net, z) 98 | if self.data_format == 'NCHW': 99 | channels = net_shape[1] 100 | height = net_shape[2] 101 | width = net_shape[3] 102 | 103 | z = tf.expand_dims(z, -1) 104 | z = tf.expand_dims(z, -1) 105 | z = tf.tile(z, [1, 1, height, width]) 106 | net = tf.concat([net, z], 1) 107 | else: 108 | height = net_shape[1] 109 | width = net_shape[2] 110 | channels = net_shape[3] 111 | 112 | z = tf.expand_dims(z, 1) 113 | z = tf.expand_dims(z, 1) 114 | z = tf.tile(z, [1, height, width, 1]) 115 | net = tf.concat([net, z], -1) 116 | 117 | print(net) 118 | 119 | net = self.conv3x3_block(net, filters) 120 | 121 | return net 122 | 123 | def residual_block(self, x, filters): 124 | with tf.name_scope('residual_block'): 125 | with slim.arg_scope([slim.conv2d], stride=1, padding='same'): 126 | Fx = slim.conv2d(x, filters*2, kernel_size=3) 127 | Fx = slim.batch_norm(Fx) 128 | Fx = self.glu(Fx) 129 | Fx = slim.conv2d(Fx, filters, kernel_size=3) 130 | Fx = slim.batch_norm(Fx) 131 | 132 | return Fx + x 133 | 134 | def conv3x3_block(self, net, filters): 135 | return layers.conv3x3_block(net, filters, self.data_format) 136 | 137 | def glu(self, net): 138 | return layers.glu(net, self.data_format) 139 | 140 | def resize(self, net, dims): 141 | return layers.resize(net, dims, self.data_format) 142 | -------------------------------------------------------------------------------- /src/model/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | 4 | def nhwc_to_nchw(net): 5 | return tf.transpose(net, [0, 3, 1, 2]) 6 | 7 | def nchw_to_nhwc(net): 8 | return tf.transpose(net, [0, 2, 3, 1]) 9 | 10 | def resize(net, dims, data_format): 11 | height = dims[0] 12 | width = dims[1] 13 | 14 | net_shape = net.get_shape().as_list() 15 | 16 | if data_format == 'NCHW': 17 | net_height = net_shape[2] 18 | net_width = net_shape[3] 19 | else: 20 | net_height = net_shape[1] 21 | net_width = net_shape[2] 22 | 23 | is_same_height_width = net_height == height and net_width == width 24 | 25 | if is_same_height_width: 26 | return net 27 | 28 | with tf.name_scope('resize'): 29 | if data_format == 'NCHW': 30 | net = nchw_to_nhwc(net) 31 | 32 | net = tf.image.resize_nearest_neighbor(net, (height, width)) 33 | 34 | if data_format == 'NCHW': 35 | net = nhwc_to_nchw(net) 36 | 37 | return net 38 | 39 | def conv3x3_block(net, filters, data_format): 40 | with tf.name_scope('conv3x3_block'): 41 | net = slim.conv2d(net, filters*2, kernel_size=3, stride=1, padding='same') 42 | net = slim.batch_norm(net) 43 | net = glu(net, data_format) 44 | return net 45 | 46 | def glu(net, data_format): 47 | """ 48 | Gated linear unit 49 | """ 50 | with tf.name_scope('glu'): 51 | if data_format == 'NHWC': 52 | num_channels = net.get_shape().as_list()[-1] 53 | num_channels = int(num_channels/2) 54 | 55 | net = net[...,:num_channels] * tf.nn.sigmoid(net[...,num_channels:]) 56 | else: 57 | num_channels = net.get_shape().as_list()[1] 58 | num_channels = int(num_channels/2) 59 | 60 | net = net[:,:num_channels] * tf.nn.sigmoid(net[:,num_channels:]) 61 | 62 | return net 63 | -------------------------------------------------------------------------------- /src/model/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | import src.model.layers as layers 4 | 5 | def G_loss(G_logits): 6 | return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=G_logits, labels=tf.ones_like(G_logits))) 7 | 8 | def D_loss(D_logits, G_logits): 9 | output = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logits, labels=true_labels(D_logits))) 10 | output += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=G_logits, labels=false_labels(G_logits))) 11 | return output 12 | 13 | def false_labels(labels): 14 | return tf.random_uniform(tf.shape(labels), .0, .3) 15 | 16 | def true_labels(labels): 17 | return tf.random_uniform(tf.shape(labels), .8, 1.2) 18 | 19 | def interpolates(real_batch, fake_batch): 20 | with tf.name_scope('interpolates'): 21 | real_batch = slim.flatten(real_batch) 22 | fake_batch = slim.flatten(fake_batch) 23 | alpha = tf.random_uniform([tf.shape(real_batch)[0], 1], minval=0., maxval=1.) 24 | differences = fake_batch - real_batch 25 | return real_batch + (alpha*differences) 26 | 27 | def lambda_gradient_penalty(logits, diff): 28 | with tf.name_scope('lambda_gradient_penalty'): 29 | gradients = tf.gradients(logits, [diff])[0] 30 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) 31 | gradient_penalty = tf.reduce_mean((slopes-1.)**2) 32 | 33 | return 10*gradient_penalty 34 | 35 | def wasserstein_loss(real_batch, fake_batch, discrim_func, discrim_scope): 36 | return 0 37 | 38 | with tf.name_scope('wasserstein_loss'): 39 | diff = interpolates(real_batch, fake_batch) 40 | diff_reshaped = tf.reshape(diff, tf.shape(real_batch)) 41 | interp_logits, _ = discrim_func(diff_reshaped, discrim_scope) 42 | 43 | return lambda_gradient_penalty(interp_logits, diff) 44 | 45 | def image_mean(img): 46 | with tf.name_scope('image_mean'): 47 | img_shape = img.get_shape().as_list() 48 | channels = img_shape[1] 49 | pixels = img_shape[2] * img_shape[3] 50 | 51 | mu = tf.reduce_mean(img, [2, 3], keepdims=True) 52 | img_mu = tf.reshape(img - mu, [-1, channels, pixels]) 53 | return mu, img_mu, pixels 54 | 55 | def image_covariance(img_mu, pixels): 56 | with tf.name_scope('image_covariance'): 57 | cov_matrix = tf.matmul(img_mu, img_mu, transpose_b=True) 58 | cov_matrix = cov_matrix / pixels 59 | 60 | return cov_matrix 61 | 62 | def colour_consistency_regularization(G1, G0, data_format): 63 | with tf.name_scope('cc_regularization'): 64 | lambda_1 = 1.0 65 | lambda_2 = 5.0 66 | alpha = 50.0 67 | 68 | if data_format == 'NHWC': 69 | G0 = layers.nhwc_to_nchw(G0) 70 | G1 = layers.nhwc_to_nchw(G1) 71 | 72 | mu_si1_j, G0_mu, G0_pixels = image_mean(G0) 73 | mu_si_j, G1_mu, G1_pixels = image_mean(G1) 74 | 75 | cov_si1_j = image_covariance(G0_mu, pixels=G0_pixels) 76 | cov_si_j = image_covariance(G1_mu, pixels=G1_pixels) 77 | 78 | L_ci = lambda_1 * tf.losses.mean_squared_error(mu_si_j, mu_si1_j) 79 | L_ci += lambda_2 * tf.losses.mean_squared_error(cov_si_j, cov_si1_j) 80 | 81 | return alpha * tf.reduce_mean(L_ci) 82 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import tpu 3 | from tensorflow.contrib.tpu import TPUEstimator as Estimator 4 | from tensorflow.contrib.cluster_resolver import TPUClusterResolver 5 | from src.config import config 6 | import src.estimator as estimator 7 | 8 | def main(_): 9 | tpu_grpc_url = None 10 | 11 | if config.use_tpu: 12 | tpu_grpc_url = TPUClusterResolver(tpu=config.tpu_name).get_master() 13 | 14 | run_config = tpu.RunConfig( 15 | master=tpu_grpc_url, 16 | evaluation_master=tpu_grpc_url, 17 | model_dir=config.log_dir, 18 | session_config=tf.ConfigProto(allow_soft_placement=True), 19 | tpu_config=tpu.TPUConfig(config.tpu_iterations, config.tpu_shards) 20 | ) 21 | 22 | batch_size = config.batch_size * config.tpu_shards if config.use_tpu else config.batch_size 23 | est = Estimator( 24 | model_fn=estimator.model_fn, 25 | use_tpu=config.use_tpu, 26 | train_batch_size=batch_size, 27 | eval_batch_size=batch_size, 28 | params={ 29 | "use_tpu": config.use_tpu, 30 | "data_dir": config.data_dir, 31 | "buffer_size": config.buffer_size, 32 | "data_format": "NCHW" if config.use_tpu else "NHWC", 33 | "z_dim": config.z_dim, 34 | "D_lr": config.d_lr, 35 | "G_lr": config.g_lr, 36 | "data_seed": config.data_shuffle_seed, 37 | "data_map_parallelism": config.data_map_parallelism 38 | }, 39 | config=run_config 40 | ) 41 | 42 | if config.train: 43 | est.train( 44 | input_fn=estimator.train_input_fn, 45 | max_steps=config.train_steps 46 | ) 47 | if config.eval: 48 | est.evaluate( 49 | input_fn=estimator.eval_input_fn, 50 | steps=config.eval_steps 51 | ) 52 | elif config.predict: 53 | est.predict( 54 | input_fn=lambda params: estimator.predict_input_fn(params, config.predict_class), 55 | predict_keys=['G2'] 56 | ) 57 | 58 | if __name__ == '__main__': 59 | tf.app.run() 60 | --------------------------------------------------------------------------------