├── .DS_Store ├── FFHQ_mu_cov.pickle ├── LICENSE ├── README.md ├── StyleGAN2.py ├── assets ├── .DS_Store ├── sample_2.gif ├── style_mixing.png ├── teaser.png ├── truncation_trick.png └── uncurated.png ├── cuda ├── custom_ops.py ├── fused_bias_act.cu ├── fused_bias_act.py ├── upfirdn_2d.cu └── upfirdn_2d.py ├── generate_video.py ├── layers.py ├── main.py ├── networks.py ├── ops.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/.DS_Store -------------------------------------------------------------------------------- /FFHQ_mu_cov.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/FFHQ_mu_cov.pickle -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Junho Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Spatial unbiased GANs — Simple TensorFlow Implementation [[Paper]](https://arxiv.org/abs/2108.01285) 2 | ### : Toward Spatially Unbiased Generative Models (ICCV 2021) 3 | 4 |
5 | 6 |
7 | 8 | > **Abstract** *Recent image generation models show remarkable generation performance. However, they mirror strong location preference in datasets, which we call **spatial bias**. Therefore, generators render poor samples at unseen locations and scales. We argue that the generators rely on their implicit positional encoding to render spatial content. From our observations, the generator’s implicit positional encoding is translation-variant, making the generator spatially biased. To address this issue, we propose injecting explicit positional encoding at each scale of the generator. By learning the spatially unbiased generator, we facilitate the robust use of generators in multiple tasks, such as GAN inversion, multi-scale generation, generation of arbitrary sizes and aspect ratios. Furthermore, we show that our method can also be applied to denoising diffusion probabilistic models.* 9 | 10 | ## Requirements 11 | * `Tensorflow >= 2.x` 12 | 13 | ## Usage 14 | ``` 15 | ├── dataset 16 |    └── YOUR_DATASET_NAME 17 | ├── 000001.jpg 18 | ├── 000002.png 19 | └── ... 20 | ``` 21 | 22 | ### Train 23 | ``` 24 | > python main.py --dataset FFHQ --phase train --img_size 256 --batch_size 4 --n_total_image 6400 25 | ``` 26 | 27 | ### Generate Video 28 | ``` 29 | > python generate_video.py 30 | ``` 31 | 32 | ## Results 33 | * **FID: 3.81 (6.4M images(200k iterations), 8GPU, each 4 batch size)** 34 | * FID reported in the paper: **6.75** 35 | ### Video 36 |
37 | 38 |
39 | 40 | ### Uncurated 41 |
42 | 43 |
44 | 45 | ### Style mixing 46 | * Style mixing quality is worse than that of StyleGAN2. 47 |
48 | 49 |
50 | 51 | ### Truncation trick 52 |
53 | 54 |
55 | 56 | ## Reference 57 | * [Official Pytorch](https://github.com/jychoi118/toward_spatial_unbiased) 58 | * [StyleGAN2-Tensorflow](https://github.com/moono/stylegan2-tf-2.x) 59 | 60 | ## Author 61 | [Junho Kim](http://bit.ly/jhkim_resume) 62 | -------------------------------------------------------------------------------- /StyleGAN2.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import time 3 | from tensorflow.python.data.experimental import AUTOTUNE 4 | from networks import * 5 | import PIL.Image 6 | import scipy 7 | import pickle 8 | automatic_gpu_usage() 9 | 10 | class Inception_V3(tf.keras.Model): 11 | def __init__(self, name='Inception_V3'): 12 | super(Inception_V3, self).__init__(name=name) 13 | 14 | # tf.keras.backend.image_data_format = 'channels_first' 15 | self.inception_v3_preprocess = tf.keras.applications.inception_v3.preprocess_input 16 | self.inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights='imagenet', include_top=False, pooling='avg') 17 | self.inception_v3.trainable = False 18 | 19 | def torch_normalization(self, x): 20 | x /= 255. 
21 | 22 | r, g, b = tf.split(axis=-1, num_or_size_splits=3, value=x) 23 | 24 | mean = [0.485, 0.456, 0.406] 25 | std = [0.229, 0.224, 0.225] 26 | 27 | x = tf.concat(axis=-1, values=[ 28 | (r - mean[0]) / std[0], 29 | (g - mean[1]) / std[1], 30 | (b - mean[2]) / std[2] 31 | ]) 32 | 33 | return x 34 | 35 | # @tf.function 36 | def call(self, x, training=False, mask=None): 37 | # x = self.torch_normalization(x) 38 | x = self.inception_v3(x, training=training) 39 | 40 | return x 41 | 42 | def inference_feat(self, x, training=False): 43 | inception_real_img = adjust_dynamic_range(x, range_in=(-1.0, 1.0), range_out=(0.0, 255.0), out_dtype=tf.float32) 44 | inception_real_img = tf.image.resize(inception_real_img, [299, 299], antialias=True, method=tf.image.ResizeMethod.BICUBIC) 45 | inception_real_img = self.torch_normalization(inception_real_img) 46 | 47 | inception_feat = self.inception_v3(inception_real_img, training=training) 48 | 49 | return inception_feat 50 | 51 | class StyleGAN2(): 52 | def __init__(self, t_params, strategy): 53 | super(StyleGAN2, self).__init__() 54 | 55 | self.model_name = 'StyleGAN2' 56 | self.phase = t_params['phase'] 57 | self.checkpoint_dir = t_params['checkpoint_dir'] 58 | self.result_dir = t_params['result_dir'] 59 | self.log_dir = t_params['log_dir'] 60 | self.sample_dir = t_params['sample_dir'] 61 | self.dataset_name = t_params['dataset'] 62 | self.config = t_params['config'] 63 | 64 | self.n_total_image = t_params['n_total_image'] * 1000 65 | 66 | self.strategy = strategy 67 | self.batch_size = t_params['batch_size'] 68 | self.each_batch_size = t_params['batch_size'] // t_params['NUM_GPUS'] 69 | 70 | self.NUM_GPUS = t_params['NUM_GPUS'] 71 | self.iteration = self.n_total_image // self.batch_size 72 | 73 | self.n_samples = min(t_params['batch_size'], t_params['n_samples']) 74 | self.n_test = t_params['n_test'] 75 | self.img_size = t_params['img_size'] 76 | 77 | self.log_template = 'step [{}/{}]: elapsed: {:.2f}s, d_loss: {:.3f}, g_loss: 
{:.3f}, r1_reg: {:.3f}, pl_reg: {:.3f}, fid: {:.2f}, best_fid: {:.2f}, best_fid_iter: {}' 78 | 79 | self.lazy_regularization = t_params['lazy_regularization'] 80 | self.print_freq = t_params['print_freq'] 81 | self.save_freq = t_params['save_freq'] 82 | 83 | self.r1_gamma = 10.0 84 | 85 | # setup optimizer params 86 | self.g_params = t_params['g_params'] 87 | 88 | self.d_params = t_params['d_params'] 89 | self.g_opt = t_params['g_opt'] 90 | self.d_opt = t_params['d_opt'] 91 | self.g_opt = self.set_optimizer_params(self.g_opt) 92 | self.d_opt = self.set_optimizer_params(self.d_opt) 93 | 94 | self.pl_minibatch_shrink = 2 95 | self.pl_decay = 0.01 96 | self.pl_weight = float(self.pl_minibatch_shrink) 97 | self.pl_denorm = tf.math.rsqrt(float(self.img_size) * float(self.img_size)) 98 | self.pl_mean = tf.Variable(initial_value=0.0, name='pl_mean', trainable=False, 99 | synchronization=tf.VariableSynchronization.ON_READ, 100 | aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA 101 | ) 102 | 103 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir) 104 | check_folder(self.sample_dir) 105 | 106 | self.checkpoint_dir = os.path.join(self.checkpoint_dir, self.model_dir) 107 | check_folder(self.checkpoint_dir) 108 | 109 | self.log_dir = os.path.join(self.log_dir, self.model_dir) 110 | check_folder(self.log_dir) 111 | 112 | 113 | 114 | dataset_path = './dataset' 115 | self.dataset_path = os.path.join(dataset_path, self.dataset_name) 116 | 117 | print(self.dataset_path) 118 | 119 | if os.path.exists('{}_mu_cov.pickle'.format(self.dataset_name)): 120 | with open('{}_mu_cov.pickle'.format(self.dataset_name), 'rb') as f: 121 | self.real_mu, self.real_cov = pickle.load(f) 122 | self.real_cache = True 123 | print("Pickle load success !!!") 124 | else: 125 | print("Pickle load fail !!!") 126 | self.real_cache = False 127 | 128 | self.fid_samples_num = 10000 129 | print() 130 | 131 | physical_gpus = tf.config.experimental.list_physical_devices('GPU') 132 | logical_gpus = 
tf.config.experimental.list_logical_devices('GPU') 133 | print(len(physical_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") 134 | print("Each batch size : ", self.each_batch_size) 135 | print("Global batch size : ", self.batch_size) 136 | print("Target image size : ", self.img_size) 137 | print("Print frequency : ", self.print_freq) 138 | print("Save frequency : ", self.save_freq) 139 | 140 | print("TF Version :", tf.__version__) 141 | 142 | def set_optimizer_params(self, params): 143 | if self.lazy_regularization: 144 | mb_ratio = params['reg_interval'] / (params['reg_interval'] + 1) 145 | params['learning_rate'] = params['learning_rate'] * mb_ratio 146 | params['beta1'] = params['beta1'] ** mb_ratio 147 | params['beta2'] = params['beta2'] ** mb_ratio 148 | return params 149 | 150 | ################################################################################## 151 | # Model 152 | ################################################################################## 153 | def build_model(self): 154 | if self.phase == 'train': 155 | """ Input Image""" 156 | img_class = Image_data(self.img_size, self.g_params['z_dim'], self.g_params['labels_dim'], self.dataset_path) 157 | img_class.preprocess() 158 | 159 | dataset_num = len(img_class.train_images) 160 | if dataset_num > 10000: 161 | self.fid_samples_num = 50000 162 | print("Dataset number : ", dataset_num) 163 | print() 164 | 165 | dataset_slice = tf.data.Dataset.from_tensor_slices(img_class.train_images) 166 | 167 | gpu_device = '/gpu:0' 168 | 169 | dataset_iter = dataset_slice.shuffle(buffer_size=dataset_num, reshuffle_each_iteration=True).repeat() 170 | dataset_iter = dataset_iter.map(map_func=img_class.image_processing, num_parallel_calls=AUTOTUNE).batch(self.batch_size, drop_remainder=True) 171 | dataset_iter = dataset_iter.prefetch(buffer_size=AUTOTUNE) 172 | dataset_iter = self.strategy.experimental_distribute_dataset(dataset_iter) 173 | 174 | img_slice = 
dataset_slice.shuffle(buffer_size=dataset_num, reshuffle_each_iteration=True, seed=777) 175 | img_slice = img_slice.map(map_func=inception_processing, num_parallel_calls=AUTOTUNE).batch(self.batch_size, drop_remainder=False) 176 | img_slice = img_slice.prefetch(buffer_size=AUTOTUNE) 177 | self.fid_img_slice = self.strategy.experimental_distribute_dataset(img_slice) 178 | 179 | self.dataset_iter = iter(dataset_iter) 180 | 181 | 182 | """ Network """ 183 | self.generator = Generator(self.g_params, name='Generator') 184 | self.discriminator = Discriminator(self.d_params, name='Discriminator') 185 | self.g_clone = Generator(self.g_params, name='Generator') 186 | self.inception_model = Inception_V3() 187 | 188 | """ Finalize model (build) """ 189 | test_latent = np.ones((1, self.g_params['z_dim']), dtype=np.float32) 190 | test_labels = np.ones((1, self.g_params['labels_dim']), dtype=np.float32) 191 | test_images = np.ones((1, 3, self.img_size, self.img_size), dtype=np.float32) 192 | test_images_inception = np.ones((1, 299, 299, 3), dtype=np.float32) 193 | 194 | _, __ = self.generator([test_latent, test_labels], training=False) 195 | _ = self.discriminator([test_images, test_labels], training=False) 196 | _, __ = self.g_clone([test_latent, test_labels], training=False) 197 | _ = self.inception_model(test_images_inception) 198 | 199 | # Copying g_clone 200 | self.g_clone.set_weights(self.generator.get_weights()) 201 | 202 | """ Optimizer """ 203 | self.d_optimizer = tf.keras.optimizers.Adam(self.d_opt['learning_rate'], 204 | beta_1=self.d_opt['beta1'], 205 | beta_2=self.d_opt['beta2'], 206 | epsilon=self.d_opt['epsilon']) 207 | self.g_optimizer = tf.keras.optimizers.Adam(self.g_opt['learning_rate'], 208 | beta_1=self.g_opt['beta1'], 209 | beta_2=self.g_opt['beta2'], 210 | epsilon=self.g_opt['epsilon']) 211 | 212 | """ Checkpoint """ 213 | self.ckpt = tf.train.Checkpoint(generator=self.generator, 214 | g_clone=self.g_clone, 215 | discriminator=self.discriminator, 216 | 
g_optimizer=self.g_optimizer, 217 | d_optimizer=self.d_optimizer) 218 | self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=2) 219 | self.start_iteration = 0 220 | 221 | if self.manager.latest_checkpoint: 222 | self.ckpt.restore(self.manager.latest_checkpoint).expect_partial() 223 | self.start_iteration = int(self.manager.latest_checkpoint.split('-')[-1]) 224 | print('Latest checkpoint restored!!') 225 | print('start iteration : ', self.start_iteration) 226 | else: 227 | print('Not restoring from saved checkpoint') 228 | 229 | else: 230 | """ Test """ 231 | """ Network """ 232 | self.g_clone = Generator(self.g_params, name='Generator') 233 | self.discriminator = Discriminator(self.d_params, name='Discriminator') 234 | 235 | """ Finalize model (build) """ 236 | test_latent = np.ones((1, self.g_params['z_dim']), dtype=np.float32) 237 | test_labels = np.ones((1, self.g_params['labels_dim']), dtype=np.float32) 238 | test_images = np.ones((1, 3, self.img_size, self.img_size), dtype=np.float32) 239 | _ = self.discriminator([test_images, test_labels], training=False) 240 | _, __ = self.g_clone([test_latent, test_labels], training=False) 241 | 242 | """ Checkpoint """ 243 | self.ckpt = tf.train.Checkpoint(g_clone=self.g_clone, discriminator=self.discriminator) 244 | self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=2) 245 | 246 | if self.manager.latest_checkpoint: 247 | self.ckpt.restore(self.manager.latest_checkpoint).expect_partial() 248 | print('Latest checkpoint restored!!') 249 | else: 250 | print('Not restoring from saved checkpoint') 251 | 252 | def d_train_step(self, z, real_images, labels): 253 | with tf.GradientTape() as d_tape: 254 | # forward pass 255 | fake_images, _ = self.generator([z, labels], training=True) 256 | real_scores = self.discriminator([real_images, labels], training=True) 257 | fake_scores = self.discriminator([fake_images, labels], training=True) 258 | 259 | # gan loss 260 | 
d_adv_loss = tf.math.softplus(fake_scores) 261 | d_adv_loss += tf.math.softplus(-real_scores) 262 | d_adv_loss = multi_gpu_loss(d_adv_loss, global_batch_size=self.batch_size) 263 | 264 | d_loss = d_adv_loss 265 | 266 | d_gradients = d_tape.gradient(d_loss, self.discriminator.trainable_variables) 267 | self.d_optimizer.apply_gradients(zip(d_gradients, self.discriminator.trainable_variables)) 268 | 269 | return d_loss, d_adv_loss 270 | 271 | def d_reg_train_step(self, z, real_images, labels): 272 | with tf.GradientTape() as d_tape: 273 | # forward pass 274 | fake_images, _ = self.generator([z, labels], training=True) 275 | real_scores = self.discriminator([real_images, labels], training=True) 276 | fake_scores = self.discriminator([fake_images, labels], training=True) 277 | 278 | # gan loss 279 | d_adv_loss = tf.math.softplus(fake_scores) 280 | d_adv_loss += tf.math.softplus(-real_scores) 281 | 282 | # simple GP 283 | with tf.GradientTape() as p_tape: 284 | p_tape.watch([real_images, labels]) 285 | real_loss = tf.reduce_sum(self.discriminator([real_images, labels], training=True)) 286 | 287 | real_grads = p_tape.gradient(real_loss, real_images) 288 | r1_penalty = tf.reduce_sum(tf.math.square(real_grads), axis=[1, 2, 3]) 289 | r1_penalty = tf.expand_dims(r1_penalty, axis=1) 290 | r1_penalty = r1_penalty * self.d_opt['reg_interval'] 291 | 292 | # combine 293 | d_adv_loss += r1_penalty * (0.5 * self.r1_gamma) 294 | d_adv_loss = multi_gpu_loss(d_adv_loss, global_batch_size=self.batch_size) 295 | 296 | d_loss = d_adv_loss 297 | 298 | d_gradients = d_tape.gradient(d_loss, self.discriminator.trainable_variables) 299 | self.d_optimizer.apply_gradients(zip(d_gradients, self.discriminator.trainable_variables)) 300 | 301 | r1_penalty = multi_gpu_loss(r1_penalty, global_batch_size=self.batch_size) 302 | 303 | return d_loss, d_adv_loss, r1_penalty 304 | 305 | def g_train_step(self, z, labels): 306 | with tf.GradientTape() as g_tape: 307 | # forward pass 308 | fake_images, _ = 
self.generator([z, labels], training=True) 309 | fake_scores = self.discriminator([fake_images, labels], training=True) 310 | 311 | # gan loss 312 | g_adv_loss = tf.math.softplus(-fake_scores) 313 | g_adv_loss = multi_gpu_loss(g_adv_loss, global_batch_size=self.batch_size) 314 | 315 | g_loss = g_adv_loss 316 | 317 | g_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables) 318 | self.g_optimizer.apply_gradients(zip(g_gradients, self.generator.trainable_variables)) 319 | 320 | return g_loss, g_adv_loss 321 | 322 | def g_reg_train_step(self, z, labels): 323 | with tf.GradientTape() as g_tape: 324 | # forward pass 325 | fake_images, w_broadcasted = self.generator([z, labels], training=True) 326 | fake_scores = self.discriminator([fake_images, labels], training=True) 327 | 328 | # gan loss 329 | g_adv_loss = tf.math.softplus(-fake_scores) 330 | 331 | # path length regularization 332 | pl_reg = self.path_regularization(pl_minibatch_shrink=self.pl_minibatch_shrink) 333 | 334 | # combine 335 | g_adv_loss += pl_reg 336 | g_adv_loss = multi_gpu_loss(g_adv_loss, global_batch_size=self.batch_size) 337 | 338 | g_loss = g_adv_loss 339 | 340 | g_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables) 341 | self.g_optimizer.apply_gradients(zip(g_gradients, self.generator.trainable_variables)) 342 | 343 | pl_reg = multi_gpu_loss(pl_reg, global_batch_size=self.batch_size) 344 | 345 | return g_loss, g_adv_loss, pl_reg 346 | 347 | def path_regularization(self, pl_minibatch_shrink=2): 348 | # path length regularization 349 | # Compute |J*y|. 
350 | 351 | pl_minibatch = tf.maximum(1, tf.math.floordiv(self.each_batch_size, pl_minibatch_shrink)) 352 | pl_z = tf.random.normal(shape=[pl_minibatch, self.g_params['z_dim']], dtype=tf.float32) 353 | pl_labels = tf.random.normal(shape=[pl_minibatch, self.g_params['labels_dim']], dtype=tf.float32) 354 | 355 | with tf.GradientTape() as pl_tape: 356 | pl_tape.watch([pl_z, pl_labels]) 357 | pl_fake_images, pl_w_broadcasted = self.generator([pl_z, pl_labels], training=True) 358 | 359 | pl_noise = tf.random.normal(tf.shape(pl_fake_images), mean=0.0, stddev=1.0, dtype=tf.float32) * self.pl_denorm 360 | pl_noise_applied = tf.reduce_sum(pl_fake_images * pl_noise) 361 | 362 | pl_grads = pl_tape.gradient(pl_noise_applied, pl_w_broadcasted) 363 | pl_lengths = tf.math.sqrt(tf.reduce_mean(tf.reduce_sum(tf.math.square(pl_grads), axis=2), axis=1)) 364 | 365 | # Track exponential moving average of |J*y|. 366 | pl_mean_val = self.pl_mean + self.pl_decay * (tf.reduce_mean(pl_lengths) - self.pl_mean) 367 | self.pl_mean.assign(pl_mean_val) 368 | 369 | # Calculate (|J*y|-a)^2. 
370 | pl_penalty = tf.square(pl_lengths - self.pl_mean) * self.g_opt['reg_interval'] 371 | 372 | # compute 373 | pl_reg = pl_penalty * self.pl_weight 374 | 375 | return pl_reg 376 | 377 | """ Distribute Train """ 378 | @tf.function 379 | def distribute_d_train_step(self, z, real_images, labels): 380 | d_loss, d_adv_loss = self.strategy.run(self.d_train_step, args=(z, real_images, labels)) 381 | 382 | d_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, d_loss, axis=None) 383 | d_adv_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, d_adv_loss, axis=None) 384 | 385 | return d_loss, d_adv_loss 386 | 387 | @tf.function 388 | def distribute_d_reg_train_step(self, z, real_images, labels): 389 | d_loss, d_adv_loss, r1_penalty = self.strategy.run(self.d_reg_train_step, args=(z, real_images, labels)) 390 | 391 | d_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, d_loss, axis=None) 392 | d_adv_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, d_adv_loss, axis=None) 393 | r1_penalty = self.strategy.reduce(tf.distribute.ReduceOp.SUM, r1_penalty, axis=None) 394 | 395 | return d_loss, d_adv_loss, r1_penalty 396 | 397 | @tf.function 398 | def distribute_g_train_step(self, z, labels): 399 | g_loss, g_adv_loss = self.strategy.run(self.g_train_step, args=(z, labels)) 400 | 401 | g_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, g_loss, axis=None) 402 | g_adv_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, g_adv_loss, axis=None) 403 | 404 | return g_loss, g_adv_loss 405 | 406 | @tf.function 407 | def distribute_g_reg_train_step(self, z, labels): 408 | g_loss, g_adv_loss, pl_reg = self.strategy.run(self.g_reg_train_step, args=(z, labels)) 409 | 410 | g_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, g_loss, axis=None) 411 | g_adv_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, g_adv_loss, axis=None) 412 | pl_reg = self.strategy.reduce(tf.distribute.ReduceOp.SUM, pl_reg, axis=None) 413 | 414 | return g_loss, g_adv_loss, pl_reg 
415 | 416 | def train(self): 417 | 418 | start_time = time.time() 419 | 420 | # setup tensorboards 421 | train_summary_writer = tf.summary.create_file_writer(self.log_dir) 422 | 423 | # start training 424 | print('max_steps: {}'.format(self.iteration)) 425 | losses = {'g/loss': 0.0, 'd/loss': 0.0, 'r1_reg': 0.0, 'pl_reg': 0.0, 426 | 'g/adv_loss': 0.0, 427 | 'd/adv_loss': 0.0, 428 | 'fid': 0.0, 'best_fid': 0.0, 'best_fid_iter': 0} 429 | fid = 0 430 | best_fid = 1000 431 | best_fid_iter = 0 432 | for idx in range(self.start_iteration, self.iteration): 433 | iter_start_time = time.time() 434 | 435 | x_real, z, labels = next(self.dataset_iter) 436 | 437 | if idx == 0: 438 | g_params = self.generator.count_params() 439 | d_params = self.discriminator.count_params() 440 | print("G network parameters : ", format(g_params, ',')) 441 | print("D network parameters : ", format(d_params, ',')) 442 | print("Total network parameters : ", format(g_params + d_params, ',')) 443 | 444 | # update discriminator 445 | # At first time, each function takes 1~2 min to make the graph. 446 | if (idx + 1) % self.d_opt['reg_interval'] == 0: 447 | d_loss, d_adv_loss, r1_reg = self.distribute_d_reg_train_step(z, x_real, labels) 448 | losses['r1_reg'] = np.float64(r1_reg) 449 | else: 450 | d_loss, d_adv_loss = self.distribute_d_train_step(z, x_real, labels) 451 | 452 | losses['d/loss'] = np.float64(d_loss) 453 | losses['d/adv_loss'] = np.float64(d_adv_loss) 454 | 455 | # update generator 456 | # At first time, each function takes 1~2 min to make the graph. 
457 | if (idx + 1) % self.g_opt['reg_interval'] == 0: 458 | g_loss, g_adv_loss, pl_reg = self.distribute_g_reg_train_step(z, labels) 459 | losses['pl_reg'] = np.float64(pl_reg) 460 | else: 461 | g_loss, g_adv_loss = self.distribute_g_train_step(z, labels) 462 | 463 | losses['g/loss'] = np.float64(g_loss) 464 | losses['g/adv_loss'] = np.float64(g_adv_loss) 465 | 466 | 467 | # update g_clone 468 | self.g_clone.set_as_moving_average_of(self.generator) 469 | 470 | if np.mod(idx, self.save_freq) == 0 or idx == self.iteration - 1 : 471 | fid = self.calculate_FID() 472 | if fid < best_fid: 473 | print("BEST FID UPDATED") 474 | best_fid = fid 475 | best_fid_iter = idx 476 | self.manager.save(checkpoint_number=idx) 477 | losses['fid'] = np.float64(fid) 478 | 479 | 480 | # save to tensorboard 481 | 482 | with train_summary_writer.as_default(): 483 | tf.summary.scalar('g_loss', losses['g/loss'], step=idx) 484 | tf.summary.scalar('g_adv_loss', losses['g/adv_loss'], step=idx) 485 | 486 | tf.summary.scalar('d_loss', losses['d/loss'], step=idx) 487 | tf.summary.scalar('d_adv_loss', losses['d/adv_loss'], step=idx) 488 | 489 | tf.summary.scalar('r1_reg', losses['r1_reg'], step=idx) 490 | tf.summary.scalar('pl_reg', losses['pl_reg'], step=idx) 491 | # tf.summary.histogram('w_avg', self.generator.w_avg, step=idx) 492 | 493 | if np.mod(idx, self.save_freq) == 0 or idx == self.iteration - 1: 494 | tf.summary.scalar('fid', losses['fid'], step=idx) 495 | 496 | # save every self.save_freq 497 | # if np.mod(idx + 1, self.save_freq) == 0: 498 | # self.manager.save(checkpoint_number=idx + 1) 499 | 500 | # save every self.print_freq 501 | if np.mod(idx + 1, self.print_freq) == 0: 502 | total_num_samples = min(self.n_samples, self.batch_size) 503 | partial_size = int(np.floor(np.sqrt(total_num_samples))) 504 | 505 | # prepare inputs 506 | latents = tf.random.normal(shape=(self.n_samples, self.g_params['z_dim']), dtype=tf.dtypes.float32) 507 | dummy_labels = tf.random.normal((self.n_samples, 
self.g_params['labels_dim']), dtype=tf.dtypes.float32) 508 | 509 | # run networks 510 | fake_img, _ = self.g_clone([latents, dummy_labels], truncation_psi=1.0, training=False) 511 | 512 | save_images(images=fake_img[:partial_size * partial_size, :, :, :], 513 | size=[partial_size, partial_size], 514 | image_path='./{}/fake_{:06d}.png'.format(self.sample_dir, idx + 1)) 515 | 516 | x_real_concat = tf.concat(self.strategy.experimental_local_results(x_real), axis=0) 517 | self.truncation_psi_canvas(x_real_concat, path='./{}/fake_psi_{:06d}.png'.format(self.sample_dir, idx + 1)) 518 | 519 | elapsed = time.time() - iter_start_time 520 | print(self.log_template.format(idx, self.iteration, elapsed, 521 | losses['d/loss'], losses['g/loss'], losses['r1_reg'], losses['pl_reg'], fid, best_fid, best_fid_iter)) 522 | # save model for final step 523 | self.manager.save(checkpoint_number=self.iteration) 524 | 525 | print("LAST FID: ", fid) 526 | print("BEST FID: {}, {}".format(best_fid, best_fid_iter)) 527 | print("Total train time: %4.4f" % (time.time() - start_time)) 528 | 529 | @property 530 | def model_dir(self): 531 | return "{}_{}_{}_{}".format(self.model_name, self.dataset_name, self.img_size, self.config) 532 | 533 | 534 | def calculate_FID(self): 535 | @tf.function 536 | def gen_samples_feats(test_z, test_labels, g_clone, inception_model): 537 | # run networks 538 | fake_img, _ = g_clone([test_z, test_labels], training=False) 539 | fake_img = adjust_dynamic_range(fake_img, range_in=(-1.0, 1.0), range_out=(0.0, 255.0), out_dtype=tf.float32) 540 | fake_img = tf.transpose(fake_img, [0, 2, 3, 1]) 541 | fake_img = tf.image.resize(fake_img, [299, 299], antialias=True, method=tf.image.ResizeMethod.BICUBIC) 542 | 543 | fake_img = torch_normalization(fake_img) 544 | 545 | feats = inception_model(fake_img) 546 | 547 | return feats 548 | 549 | @tf.function 550 | def get_inception_features(img, inception_model): 551 | feats = inception_model(img) 552 | return feats 553 | 554 | 
@tf.function 555 | def get_real_features(img, inception_model): 556 | feats = self.strategy.run(get_inception_features, args=(img, inception_model)) 557 | feats = tf.concat(self.strategy.experimental_local_results(feats), axis=0) 558 | 559 | return feats 560 | 561 | @tf.function 562 | def get_fake_features(z, dummy_labels, g_clone, inception_model): 563 | 564 | feats = self.strategy.run(gen_samples_feats, args=(z, dummy_labels, g_clone, inception_model)) 565 | feats = tf.concat(self.strategy.experimental_local_results(feats), axis=0) 566 | 567 | return feats 568 | 569 | @tf.function 570 | def convert_per_replica_image(nchw_per_replica_images, strategy): 571 | as_tensor = tf.concat(strategy.experimental_local_results(nchw_per_replica_images), axis=0) 572 | as_tensor = tf.transpose(as_tensor, perm=[0, 2, 3, 1]) 573 | as_tensor = (tf.clip_by_value(as_tensor, -1.0, 1.0) + 1.0) * 127.5 574 | as_tensor = tf.cast(as_tensor, tf.uint8) 575 | as_tensor = tf.image.resize(as_tensor, [299, 299], antialias=True, method=tf.image.ResizeMethod.BICUBIC) 576 | 577 | return as_tensor 578 | 579 | if not self.real_cache: 580 | real_feats = tf.zeros([0, 2048]) 581 | """ Input Image""" 582 | # img_class = Image_data(self.img_size, self.g_params['z_dim'], self.g_params['labels_dim'], 583 | # self.dataset_path) 584 | # img_class.preprocess() 585 | # dataset_num = len(img_class.train_images) 586 | # img_slice = tf.data.Dataset.from_tensor_slices(img_class.train_images) 587 | # 588 | # img_slice = img_slice.shuffle(buffer_size=dataset_num, reshuffle_each_iteration=True, seed=777) 589 | # img_slice = img_slice.map(map_func=inception_processing, num_parallel_calls=AUTOTUNE).batch(self.batch_size, 590 | # drop_remainder=False) 591 | # img_slice = img_slice.prefetch(buffer_size=AUTOTUNE) 592 | # img_slice = self.strategy.experimental_distribute_dataset(img_slice) 593 | 594 | for img in self.fid_img_slice: 595 | feats = get_real_features(img, self.inception_model) 596 | real_feats = 
tf.concat([real_feats, feats], axis=0) 597 | print('real feats:', np.shape(real_feats)[0]) 598 | 599 | self.real_mu = np.mean(real_feats, axis=0) 600 | self.real_cov = np.cov(real_feats, rowvar=False) 601 | 602 | with open('{}_mu_cov.pickle'.format(self.dataset_name), 'wb') as f: 603 | pickle.dump((self.real_mu, self.real_cov), f, protocol=pickle.HIGHEST_PROTOCOL) 604 | 605 | print('{} real pickle save !!!'.format(self.dataset_name)) 606 | 607 | self.real_cache = True 608 | del real_feats 609 | 610 | fake_feats = tf.zeros([0, 2048]) 611 | from tqdm import tqdm 612 | for begin in tqdm(range(0, self.fid_samples_num, self.batch_size)): 613 | z = tf.random.normal(shape=[self.each_batch_size, self.g_params['z_dim']]) 614 | dummy_labels = tf.random.normal((self.each_batch_size, self.g_params['labels_dim']), dtype=tf.float32) 615 | 616 | feats = get_fake_features(z, dummy_labels, self.g_clone, self.inception_model) 617 | 618 | fake_feats = tf.concat([fake_feats, feats], axis=0) 619 | # print('fake feats:', np.shape(fake_feats)[0]) 620 | 621 | fake_mu = np.mean(fake_feats, axis=0) 622 | fake_cov = np.cov(fake_feats, rowvar=False) 623 | del fake_feats 624 | 625 | # Calculate FID. 
626 | m = np.square(fake_mu - self.real_mu).sum() 627 | s, _ = scipy.linalg.sqrtm(np.dot(fake_cov, self.real_cov), disp=False) # pylint: disable=no-member 628 | dist = m + np.trace(fake_cov + self.real_cov - 2 * s) 629 | 630 | return dist 631 | 632 | 633 | def truncation_psi_canvas(self, real_images, path): 634 | # prepare inputs 635 | reals = real_images[:self.n_samples, :, :, :] 636 | latents = tf.random.normal(shape=(self.n_samples, self.g_params['z_dim']), dtype=tf.dtypes.float32) 637 | dummy_labels = tf.random.normal((self.n_samples, self.g_params['labels_dim']), dtype=tf.dtypes.float32) 638 | 639 | # run networks 640 | fake_images_00, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.0, training=False) 641 | fake_images_05, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.5, training=False) 642 | fake_images_07, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.7, training=False) 643 | fake_images_10, _ = self.g_clone([latents, dummy_labels], truncation_psi=1.0, training=False) 644 | 645 | # merge on batch dimension: [4 * n_samples, 3, img_size, img_size] 646 | out = tf.concat([fake_images_00, fake_images_05, fake_images_07, fake_images_10], axis=0) 647 | 648 | # prepare for image saving: [4 * n_samples, img_size, img_size, 3] 649 | out = postprocess_images(out) 650 | 651 | # resize to save disk spaces: [4 * n_samples, size, size, 3] 652 | size = min(self.img_size, 256) 653 | out = tf.image.resize(out, size=[size, size], antialias=True, method=tf.image.ResizeMethod.BICUBIC) 654 | 655 | # make single image and add batch dimension for tensorboard: [1, 4 * size, n_samples * size, 3] 656 | out = merge_batch_images(out, size, rows=4, cols=self.n_samples) 657 | 658 | images = cv2.cvtColor(out.astype('uint8'), cv2.COLOR_RGB2BGR) 659 | 660 | return cv2.imwrite(path, images) 661 | 662 | 663 | def test(self): 664 | result_dir = os.path.join(self.result_dir, self.model_dir) 665 | check_folder(result_dir) 666 | 667 | total_num_samples = 
min(self.n_samples, self.batch_size) 668 | partial_size = int(np.floor(np.sqrt(total_num_samples))) 669 | 670 | from tqdm import tqdm 671 | for i in tqdm(range(self.n_test)): 672 | z = tf.random.normal(shape=[self.batch_size, self.g_params['z_dim']]) 673 | dummy_labels = tf.random.normal((self.batch_size, self.g_params['labels_dim']), dtype=tf.float32) 674 | fake_img, _ = self.g_clone([z, dummy_labels], training=False) 675 | 676 | save_images(images=fake_img[:partial_size * partial_size, :, :, :], 677 | size=[partial_size, partial_size], 678 | image_path='./{}/fake_{:01d}.png'.format(result_dir, i)) 679 | 680 | def test_70000(self): 681 | result_dir = os.path.join(self.result_dir, self.model_dir) 682 | check_folder(result_dir) 683 | 684 | total_num_samples = 1 685 | partial_size = int(np.floor(np.sqrt(total_num_samples))) 686 | 687 | from tqdm import tqdm 688 | for i in tqdm(range(70000)): 689 | z = tf.random.normal(shape=[1, self.g_params['z_dim']]) 690 | dummy_labels = tf.random.normal((1, self.g_params['labels_dim']), dtype=tf.float32) 691 | fake_img, _ = self.g_clone([z, dummy_labels], training=False) 692 | 693 | save_images(images=fake_img[:partial_size * partial_size, :, :, :], 694 | size=[partial_size, partial_size], 695 | image_path='./{}/fake_{:01d}.png'.format(result_dir, i)) 696 | 697 | def draw_uncurated_result_figure(self): 698 | 699 | result_dir = os.path.join(self.result_dir, self.model_dir, 'paper_figure') 700 | check_folder(result_dir) 701 | 702 | seed_flag = True 703 | lods = [0, 1, 2, 2, 3, 3] 704 | seed = 3291 705 | rows = 3 706 | cx = 0 707 | cy = 0 708 | 709 | if seed_flag: 710 | latents = tf.cast( 711 | np.random.RandomState(seed).normal(size=[sum(rows * 2 ** lod for lod in lods), self.g_params['z_dim']]), tf.float32) 712 | else: 713 | latents = tf.cast(np.random.normal(size=[sum(rows * 2 ** lod for lod in lods), self.g_params['z_dim']]), tf.float32) 714 | 715 | dummy_labels = tf.random.normal((sum(rows * 2 ** lod for lod in lods), 
self.g_params['labels_dim']), dtype=tf.float32) 716 | 717 | images, _ = self.g_clone([latents, dummy_labels], training=False) 718 | images = postprocess_images(images) 719 | 720 | canvas = PIL.Image.new('RGB', (sum(self.img_size // 2 ** lod for lod in lods), self.img_size * rows), 'white') 721 | image_iter = iter(list(images)) 722 | 723 | for col, lod in enumerate(lods): 724 | for row in range(rows * 2 ** lod): 725 | image = PIL.Image.fromarray(np.uint8(next(image_iter)), 'RGB') 726 | 727 | image = image.crop((cx, cy, cx + self.img_size, cy + self.img_size)) 728 | image = image.resize((self.img_size // 2 ** lod, self.img_size // 2 ** lod), PIL.Image.ANTIALIAS) 729 | canvas.paste(image, 730 | (sum(self.img_size // 2 ** lod for lod in lods[:col]), row * self.img_size // 2 ** lod)) 731 | 732 | canvas.save('{}/figure02-uncurated.png'.format(result_dir)) 733 | 734 | def draw_style_mixing_figure(self): 735 | result_dir = os.path.join(self.result_dir, self.model_dir, 'paper_figure') 736 | check_folder(result_dir) 737 | 738 | seed_flag = True 739 | src_seeds = [604, 8440, 7613, 6978, 3004] 740 | dst_seeds = [1336, 6968, 607, 728, 7036, 9010] 741 | 742 | truncation_psi = 0.7 # Style strength multiplier for the truncation trick 743 | truncation_cutoff = 8 # Number of layers for which to apply the truncation trick 744 | 745 | resolutions = self.g_params['resolutions'] 746 | n_broadcast = len(resolutions) * 2 747 | 748 | style_ranges = [range(0, 4)] * 3 + [range(4, 8)] * 2 + [range(8, n_broadcast)] 749 | 750 | if seed_flag: 751 | src_latents = tf.cast( 752 | np.concatenate(list(np.random.RandomState(seed).normal(size=[1, self.g_params['z_dim']]) for seed in src_seeds), axis=0), tf.float32) 753 | dst_latents = tf.cast( 754 | np.concatenate(list(np.random.RandomState(seed).normal(size=[1, self.g_params['z_dim']]) for seed in dst_seeds), axis=0), tf.float32) 755 | 756 | else: 757 | src_latents = tf.cast(np.random.normal(size=[len(src_seeds), self.g_params['z_dim']]), tf.float32) 
758 | dst_latents = tf.cast(np.random.normal(size=[len(dst_seeds), self.g_params['z_dim']]), tf.float32) 759 | 760 | dummy_labels = tf.random.normal((len(src_seeds), self.g_params['labels_dim']), dtype=tf.float32) 761 | 762 | src_images, src_dlatents = self.g_clone([src_latents, dummy_labels], truncation_cutoff=truncation_cutoff, truncation_psi=truncation_psi, training=False) 763 | dst_images, dst_dlatents = self.g_clone([dst_latents, dummy_labels], truncation_cutoff=truncation_cutoff, truncation_psi=truncation_psi, training=False) 764 | 765 | src_images = postprocess_images(src_images) 766 | dst_images = postprocess_images(dst_images) 767 | 768 | img_out_size = min(self.img_size, 256) 769 | 770 | src_images = tf.image.resize(src_images, size=[img_out_size, img_out_size], antialias=True, method=tf.image.ResizeMethod.BICUBIC) 771 | dst_images = tf.image.resize(dst_images, size=[img_out_size, img_out_size], antialias=True, method=tf.image.ResizeMethod.BICUBIC) 772 | 773 | canvas = PIL.Image.new('RGB', (img_out_size * (len(src_seeds) + 1), img_out_size * (len(dst_seeds) + 1)), 'white') 774 | 775 | for col, src_image in enumerate(list(src_images)): 776 | canvas.paste(PIL.Image.fromarray(np.uint8(src_image), 'RGB'), ((col + 1) * img_out_size, 0)) 777 | 778 | for row, dst_image in enumerate(list(dst_images)): 779 | canvas.paste(PIL.Image.fromarray(np.uint8(dst_image), 'RGB'), (0, (row + 1) * img_out_size)) 780 | 781 | row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds)) 782 | src_dlatents = np.asarray(src_dlatents, dtype=np.float32) 783 | row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]] 784 | 785 | row_images, _ = self.g_clone([row_dlatents, dummy_labels], mapping=False, training=False) 786 | row_images = postprocess_images(row_images) 787 | 788 | 789 | for col, image in enumerate(list(row_images)): 790 | canvas.paste(PIL.Image.fromarray(np.uint8(image), 'RGB'), ((col + 1) * img_out_size, (row + 1) * img_out_size)) 791 | 792 | 
canvas.save('{}/figure03-style-mixing.png'.format(result_dir)) 793 | 794 | def draw_truncation_trick_figure(self): 795 | 796 | result_dir = os.path.join(self.result_dir, self.model_dir, 'paper_figure') 797 | check_folder(result_dir) 798 | 799 | seed_flag = True 800 | seeds = [1653, 4010] 801 | psis = [-1, -0.7, -0.5, 0, 0.5, 0.7, 1] 802 | 803 | if seed_flag: 804 | latents = tf.cast( 805 | np.concatenate(list(np.random.RandomState(seed).normal(size=[1, self.g_params['z_dim']]) for seed in seeds), axis=0), tf.float32) 806 | else: 807 | latents = tf.cast(np.random.normal(size=[len(seeds), self.g_params['z_dim']]), tf.float32) 808 | 809 | dummy_labels = tf.random.normal((len(seeds), self.g_params['labels_dim']), dtype=tf.float32) 810 | 811 | fake_images_10_, _ = self.g_clone([latents, dummy_labels], truncation_psi=-1.0, training=False) 812 | fake_images_05_, _ = self.g_clone([latents, dummy_labels], truncation_psi=-0.5, training=False) 813 | fake_images_07_, _ = self.g_clone([latents, dummy_labels], truncation_psi=-0.7, training=False) 814 | fake_images_00, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.0, training=False) 815 | fake_images_05, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.5, training=False) 816 | fake_images_07, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.7, training=False) 817 | fake_images_10, _ = self.g_clone([latents, dummy_labels], truncation_psi=1.0, training=False) 818 | 819 | # merge on batch dimension: [7, 3, img_size, img_size] 820 | col_images = list([fake_images_10_, fake_images_05_, fake_images_07_, fake_images_00, fake_images_05, fake_images_07, fake_images_10]) 821 | 822 | img_out_size = min(self.img_size, 256) 823 | 824 | for i in range(len(col_images)): 825 | col_images[i] = postprocess_images(col_images[i]) 826 | col_images[i] = tf.image.resize(col_images[i], size=[img_out_size, img_out_size], antialias=True, method=tf.image.ResizeMethod.BICUBIC) 827 | 828 | canvas = PIL.Image.new('RGB', 
(img_out_size * len(psis), img_out_size * len(seeds)), 'white') 829 | 830 | for col, col_img in enumerate(col_images): 831 | for row, image in enumerate(col_img): 832 | canvas.paste(PIL.Image.fromarray(np.uint8(image), 'RGB'), 833 | (col * img_out_size, row * img_out_size)) 834 | 835 | canvas.save('{}/figure08-truncation-trick.png'.format(result_dir)) 836 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/.DS_Store -------------------------------------------------------------------------------- /assets/sample_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/sample_2.gif -------------------------------------------------------------------------------- /assets/style_mixing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/style_mixing.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/teaser.png -------------------------------------------------------------------------------- /assets/truncation_trick.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/truncation_trick.png 
-------------------------------------------------------------------------------- /assets/uncurated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/uncurated.png -------------------------------------------------------------------------------- /cuda/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """TensorFlow custom ops builder. 8 | """ 9 | 10 | import os 11 | import re 12 | import uuid 13 | import hashlib 14 | import tempfile 15 | import shutil 16 | import tensorflow as tf 17 | from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module 18 | 19 | #---------------------------------------------------------------------------- 20 | # Global options. 21 | 22 | cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache') 23 | cuda_cache_version_tag = 'v1' 24 | do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe! 25 | verbose = True # Print status messages to stdout. 26 | 27 | compiler_bindir_search_path = [ 28 | 'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64', 29 | 'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64', 30 | 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin', 31 | ] 32 | 33 | #---------------------------------------------------------------------------- 34 | # Internal helper funcs. 
35 | 36 | def _find_compiler_bindir(): 37 | for compiler_path in compiler_bindir_search_path: 38 | if os.path.isdir(compiler_path): 39 | return compiler_path 40 | return None 41 | 42 | def _get_compute_cap(device): 43 | caps_str = device.physical_device_desc 44 | m = re.search('compute capability: (\\d+).(\\d+)', caps_str) 45 | major = m.group(1) 46 | minor = m.group(2) 47 | return (major, minor) 48 | 49 | def _get_cuda_gpu_arch_string(): 50 | gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU'] 51 | if len(gpus) == 0: 52 | raise RuntimeError('No GPU devices found') 53 | (major, minor) = _get_compute_cap(gpus[0]) 54 | return 'sm_%s%s' % (major, minor) 55 | 56 | def _run_cmd(cmd): 57 | with os.popen(cmd) as pipe: 58 | output = pipe.read() 59 | status = pipe.close() 60 | if status is not None: 61 | raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output)) 62 | 63 | def _prepare_nvcc_cli(opts): 64 | cmd = 'nvcc --std=c++11 -DNDEBUG ' + opts.strip() 65 | cmd += ' --disable-warnings' 66 | cmd += ' --include-path "%s"' % tf.sysconfig.get_include() 67 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src') 68 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl') 69 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive') 70 | 71 | compiler_bindir = _find_compiler_bindir() 72 | if compiler_bindir is None: 73 | # Require that _find_compiler_bindir succeeds on Windows. Allow 74 | # nvcc to use whatever is the default on Linux. 75 | if os.name == 'nt': 76 | raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' 
% __file__) 77 | else: 78 | cmd += ' --compiler-bindir "%s"' % compiler_bindir 79 | cmd += ' 2>&1' 80 | return cmd 81 | 82 | #---------------------------------------------------------------------------- 83 | # Main entry point. 84 | 85 | _plugin_cache = dict() 86 | 87 | def get_plugin(cuda_file): 88 | cuda_file_base = os.path.basename(cuda_file) 89 | cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base) 90 | 91 | # Already in cache? 92 | if cuda_file in _plugin_cache: 93 | return _plugin_cache[cuda_file] 94 | 95 | # Setup plugin. 96 | if verbose: 97 | print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True) 98 | try: 99 | # Hash CUDA source. 100 | md5 = hashlib.md5() 101 | with open(cuda_file, 'rb') as f: 102 | md5.update(f.read()) 103 | md5.update(b'\n') 104 | 105 | # Hash headers included by the CUDA code by running it through the preprocessor. 106 | if not do_not_hash_included_headers: 107 | if verbose: 108 | print('Preprocessing... ', end='', flush=True) 109 | with tempfile.TemporaryDirectory() as tmp_dir: 110 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext) 111 | _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))) 112 | with open(tmp_file, 'rb') as f: 113 | bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros 114 | good_file_str = ('"' + cuda_file_base + '"').encode('utf-8') 115 | for ln in f: 116 | if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas 117 | ln = ln.replace(bad_file_str, good_file_str) 118 | md5.update(ln) 119 | md5.update(b'\n') 120 | 121 | # Select compiler options. 
122 | compile_opts = '' 123 | if os.name == 'nt': 124 | compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib') 125 | elif os.name == 'posix': 126 | compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so') 127 | compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\'' 128 | else: 129 | assert False # not Windows or Linux, w00t? 130 | compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string() 131 | compile_opts += ' --use_fast_math' 132 | nvcc_cmd = _prepare_nvcc_cli(compile_opts) 133 | 134 | # Hash build configuration. 135 | md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n') 136 | md5.update(('tf.VERSION: ' + tf.version.VERSION).encode('utf-8') + b'\n') 137 | md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n') 138 | 139 | # Compile if not already compiled. 140 | bin_file_ext = '.dll' if os.name == 'nt' else '.so' 141 | bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext) 142 | if not os.path.isfile(bin_file): 143 | if verbose: 144 | print('Compiling... ', end='', flush=True) 145 | with tempfile.TemporaryDirectory() as tmp_dir: 146 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext) 147 | _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)) 148 | os.makedirs(cuda_cache_path, exist_ok=True) 149 | intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext) 150 | shutil.copyfile(tmp_file, intermediate_file) 151 | os.rename(intermediate_file, bin_file) # atomic 152 | 153 | # Load. 154 | if verbose: 155 | print('Loading... ', end='', flush=True) 156 | plugin = tf.load_op_library(bin_file) 157 | 158 | # Add to cache. 
159 | _plugin_cache[cuda_file] = plugin 160 | if verbose: 161 | print('Done.', flush=True) 162 | return plugin 163 | 164 | except: 165 | if verbose: 166 | print('Failed!', flush=True) 167 | raise 168 | 169 | #---------------------------------------------------------------------------- 170 | -------------------------------------------------------------------------------- /cuda/fused_bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #define EIGEN_USE_GPU 8 | #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ 9 | #include "tensorflow/core/framework/op.h" 10 | #include "tensorflow/core/framework/op_kernel.h" 11 | #include "tensorflow/core/framework/shape_inference.h" 12 | #include 13 | 14 | using namespace tensorflow; 15 | using namespace tensorflow::shape_inference; 16 | 17 | #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false) 18 | 19 | //------------------------------------------------------------------------ 20 | // CUDA kernel. 
21 | 22 | template 23 | struct FusedBiasActKernelParams 24 | { 25 | const T* x; // [sizeX] 26 | const T* b; // [sizeB] or NULL 27 | const T* ref; // [sizeX] or NULL 28 | T* y; // [sizeX] 29 | 30 | int grad; 31 | int axis; 32 | int act; 33 | float alpha; 34 | float gain; 35 | 36 | int sizeX; 37 | int sizeB; 38 | int stepB; 39 | int loopX; 40 | }; 41 | 42 | template 43 | static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams p) 44 | { 45 | const float expRange = 80.0f; 46 | const float halfExpRange = 40.0f; 47 | const float seluScale = 1.0507009873554804934193349852946f; 48 | const float seluAlpha = 1.6732632423543772848170429916717f; 49 | 50 | // Loop over elements. 51 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 52 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 53 | { 54 | // Load and apply bias. 55 | float x = (float)p.x[xi]; 56 | if (p.b) 57 | x += (float)p.b[(xi / p.stepB) % p.sizeB]; 58 | float ref = (p.ref) ? (float)p.ref[xi] : 0.0f; 59 | if (p.gain != 0.0f & p.act != 9) 60 | ref /= p.gain; 61 | 62 | // Evaluate activation func. 63 | float y; 64 | switch (p.act * 10 + p.grad) 65 | { 66 | // linear 67 | default: 68 | case 10: y = x; break; 69 | case 11: y = x; break; 70 | case 12: y = 0.0f; break; 71 | 72 | // relu 73 | case 20: y = (x > 0.0f) ? x : 0.0f; break; 74 | case 21: y = (ref > 0.0f) ? x : 0.0f; break; 75 | case 22: y = 0.0f; break; 76 | 77 | // lrelu 78 | case 30: y = (x > 0.0f) ? x : x * p.alpha; break; 79 | case 31: y = (ref > 0.0f) ? x : x * p.alpha; break; 80 | case 32: y = 0.0f; break; 81 | 82 | // tanh 83 | case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break; 84 | case 41: y = x * (1.0f - ref * ref); break; 85 | case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break; 86 | 87 | // sigmoid 88 | case 50: y = (x < -expRange) ? 
0.0f : 1.0f / (expf(-x) + 1.0f); break; 89 | case 51: y = x * ref * (1.0f - ref); break; 90 | case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break; 91 | 92 | // elu 93 | case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break; 94 | case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break; 95 | case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break; 96 | 97 | // selu 98 | case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break; 99 | case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break; 100 | case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break; 101 | 102 | // softplus 103 | case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break; 104 | case 81: y = x * (1.0f - expf(-ref)); break; 105 | case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break; 106 | 107 | // swish 108 | case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break; 109 | case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break; 110 | case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break; 111 | } 112 | 113 | // Apply gain and store. 114 | p.y[xi] = (T)(y * p.gain); 115 | } 116 | } 117 | 118 | //------------------------------------------------------------------------ 119 | // TensorFlow op. 
120 | 121 | template 122 | struct FusedBiasActOp : public OpKernel 123 | { 124 | FusedBiasActKernelParams m_attribs; 125 | 126 | FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx) 127 | { 128 | memset(&m_attribs, 0, sizeof(m_attribs)); 129 | OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad)); 130 | OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis)); 131 | OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act)); 132 | OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha)); 133 | OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain)); 134 | OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative")); 135 | OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative")); 136 | OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative")); 137 | } 138 | 139 | void Compute(OpKernelContext* ctx) 140 | { 141 | FusedBiasActKernelParams p = m_attribs; 142 | cudaStream_t stream = ctx->eigen_device().stream(); 143 | 144 | const Tensor& x = ctx->input(0); // [...] 145 | const Tensor& b = ctx->input(1); // [sizeB] or [0] 146 | const Tensor& ref = ctx->input(2); // x.shape or [0] 147 | p.x = x.flat().data(); 148 | p.b = (b.NumElements()) ? b.flat().data() : NULL; 149 | p.ref = (ref.NumElements()) ? ref.flat().data() : NULL; 150 | OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds")); 151 | OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1")); 152 | OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements")); 153 | OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 
0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements")); 154 | OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large")); 155 | 156 | p.sizeX = (int)x.NumElements(); 157 | p.sizeB = (int)b.NumElements(); 158 | p.stepB = 1; 159 | for (int i = m_attribs.axis + 1; i < x.dims(); i++) 160 | p.stepB *= (int)x.dim_size(i); 161 | 162 | Tensor* y = NULL; // x.shape 163 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y)); 164 | p.y = y->flat().data(); 165 | 166 | p.loopX = 4; 167 | int blockSize = 4 * 32; 168 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 169 | void* args[] = {&p}; 170 | OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel, gridSize, blockSize, args, 0, stream)); 171 | } 172 | }; 173 | 174 | REGISTER_OP("FusedBiasAct") 175 | .Input ("x: T") 176 | .Input ("b: T") 177 | .Input ("ref: T") 178 | .Output ("y: T") 179 | .Attr ("T: {float, half}") 180 | .Attr ("grad: int = 0") 181 | .Attr ("axis: int = 1") 182 | .Attr ("act: int = 0") 183 | .Attr ("alpha: float = 0.0") 184 | .Attr ("gain: float = 1.0"); 185 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint("T"), FusedBiasActOp); 186 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint("T"), FusedBiasActOp); 187 | 188 | //------------------------------------------------------------------------ 189 | -------------------------------------------------------------------------------- /cuda/fused_bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Custom TensorFlow ops for efficient bias and activation.""" 8 | 9 | import os 10 | import numpy as np 11 | import tensorflow as tf 12 | from cuda import custom_ops 13 | from utils import EasyDict 14 | 15 | def _get_plugin(): 16 | return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu') 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | activation_funcs = { 21 | 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True), 22 | 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True), 23 | 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True), 24 | 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False), 25 | 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False), 26 | 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False), 27 | 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False), 28 | 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False), 29 | 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False), 30 | } 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'): 35 | r"""Fused bias and activation function. 
36 | 37 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 38 | and scales the result by `gain`. Each of the steps is optional. In most cases, 39 | the fused op is considerably more efficient than performing the same calculation 40 | using standard TensorFlow ops. It supports first and second order gradients, 41 | but not third order gradients. 42 | 43 | Args: 44 | x: Input activation tensor. Can have any shape, but if `b` is defined, the 45 | dimension corresponding to `axis`, as well as the rank, must be known. 46 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 47 | as `x`. The shape must be known, and it must match the dimension of `x` 48 | corresponding to `axis`. 49 | axis: The dimension in `x` corresponding to the elements of `b`. 50 | The value of `axis` is ignored if `b` is not specified. 51 | act: Name of the activation function to evaluate, or `"linear"` to disable. 52 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 53 | See `activation_funcs` for a full list. `None` is not allowed. 54 | alpha: Shape parameter for the activation function, or `None` to use the default. 55 | gain: Scaling factor for the output tensor, or `None` to use default. 56 | See `activation_funcs` for the default scaling of each activation function. 57 | If unsure, consider specifying `1.0`. 58 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 59 | 60 | Returns: 61 | Tensor of the same shape and datatype as `x`. 62 | """ 63 | 64 | impl_dict = { 65 | 'ref': _fused_bias_act_ref, 66 | 'cuda': _fused_bias_act_cuda, 67 | } 68 | return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain) 69 | 70 | #---------------------------------------------------------------------------- 71 | 72 | def _fused_bias_act_ref(x, b, axis, act, alpha, gain): 73 | """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops.""" 74 | 75 | # Validate arguments. 
76 | x = tf.convert_to_tensor(x) 77 | b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype) 78 | act_spec = activation_funcs[act] 79 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) 80 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank 81 | if alpha is None: 82 | alpha = act_spec.def_alpha 83 | if gain is None: 84 | gain = act_spec.def_gain 85 | 86 | # Add bias. 87 | if b.shape[0] != 0: 88 | x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)]) 89 | 90 | # Evaluate activation function. 91 | x = act_spec.func(x, alpha=alpha) 92 | 93 | # Scale by gain. 94 | if gain != 1: 95 | x *= gain 96 | return x 97 | 98 | #---------------------------------------------------------------------------- 99 | 100 | def _fused_bias_act_cuda(x, b, axis, act, alpha, gain): 101 | """Fast CUDA implementation of `fused_bias_act()` using custom ops.""" 102 | 103 | # Validate arguments. 104 | x = tf.convert_to_tensor(x) 105 | empty_tensor = tf.constant([], dtype=x.dtype) 106 | b = tf.convert_to_tensor(b) if b is not None else empty_tensor 107 | act_spec = activation_funcs[act] 108 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) 109 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank 110 | if alpha is None: 111 | alpha = act_spec.def_alpha 112 | if gain is None: 113 | gain = act_spec.def_gain 114 | 115 | # Special cases. 116 | if act == 'linear' and b is None and gain == 1.0: 117 | return x 118 | if act_spec.cuda_idx is None: 119 | return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain) 120 | 121 | # CUDA kernel. 122 | cuda_kernel = _get_plugin().fused_bias_act 123 | cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain) 124 | 125 | # Forward pass: y = func(x, b). 
126 | def func_y(x, b): 127 | y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs) 128 | y.set_shape(x.shape) 129 | return y 130 | 131 | # Backward pass: dx, db = grad(dy, x, y) 132 | def grad_dx(dy, x, y): 133 | ref = {'x': x, 'y': y}[act_spec.ref] 134 | dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs) 135 | dx.set_shape(x.shape) 136 | return dx 137 | def grad_db(dx): 138 | if b.shape[0] == 0: 139 | return empty_tensor 140 | db = dx 141 | if axis < x.shape.rank - 1: 142 | db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank))) 143 | if axis > 0: 144 | db = tf.reduce_sum(db, list(range(axis))) 145 | db.set_shape(b.shape) 146 | return db 147 | 148 | # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y) 149 | def grad2_d_dy(d_dx, d_db, x, y): 150 | ref = {'x': x, 'y': y}[act_spec.ref] 151 | d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs) 152 | d_dy.set_shape(x.shape) 153 | return d_dy 154 | def grad2_d_x(d_dx, d_db, x, y): 155 | ref = {'x': x, 'y': y}[act_spec.ref] 156 | d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs) 157 | d_x.set_shape(x.shape) 158 | return d_x 159 | 160 | # Fast version for piecewise-linear activation funcs. 161 | @tf.custom_gradient 162 | def func_zero_2nd_grad(x, b): 163 | y = func_y(x, b) 164 | @tf.custom_gradient 165 | def grad(dy): 166 | dx = grad_dx(dy, x, y) 167 | db = grad_db(dx) 168 | def grad2(d_dx, d_db): 169 | d_dy = grad2_d_dy(d_dx, d_db, x, y) 170 | return d_dy 171 | return (dx, db), grad2 172 | return y, grad 173 | 174 | # Slow version for general activation funcs. 
175 | @tf.custom_gradient 176 | def func_nonzero_2nd_grad(x, b): 177 | y = func_y(x, b) 178 | def grad_wrap(dy): 179 | @tf.custom_gradient 180 | def grad_impl(dy, x): 181 | dx = grad_dx(dy, x, y) 182 | db = grad_db(dx) 183 | def grad2(d_dx, d_db): 184 | d_dy = grad2_d_dy(d_dx, d_db, x, y) 185 | d_x = grad2_d_x(d_dx, d_db, x, y) 186 | return d_dy, d_x 187 | return (dx, db), grad2 188 | return grad_impl(dy, x) 189 | return y, grad_wrap 190 | 191 | # Which version to use? 192 | if act_spec.zero_2nd_grad: 193 | return func_zero_2nd_grad(x, b) 194 | return func_nonzero_2nd_grad(x, b) 195 | 196 | #---------------------------------------------------------------------------- 197 | -------------------------------------------------------------------------------- /cuda/upfirdn_2d.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #define EIGEN_USE_GPU 8 | #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ 9 | #include "tensorflow/core/framework/op.h" 10 | #include "tensorflow/core/framework/op_kernel.h" 11 | #include "tensorflow/core/framework/shape_inference.h" 12 | #include 13 | 14 | using namespace tensorflow; 15 | using namespace tensorflow::shape_inference; 16 | 17 | //------------------------------------------------------------------------ 18 | // Helpers. 
// Evaluate a CUDA runtime call and fail the op with the CUDA error name if it
// did not return cudaSuccess.
#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)

// Integer division rounded toward negative infinity (C++ '/' truncates toward
// zero). The single correction step is only valid for b > 0; every caller in
// this file passes an up-sampling factor, which UpFirDn2DOp requires to be >= 1.
static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
{
    int c = a / b;
    if (c * b > a)
        c--;
    return c;
}

//------------------------------------------------------------------------
// CUDA kernel params.

template <class T>
struct UpFirDn2DKernelParams
{
    const T*    x;          // [majorDim, inH, inW, minorDim]
    const T*    k;          // [kernelH, kernelW]
    T*          y;          // [majorDim, outH, outW, minorDim]

    int         upx;
    int         upy;
    int         downx;
    int         downy;
    int         padx0;
    int         padx1;
    int         pady0;
    int         pady1;

    int         majorDim;
    int         inH;
    int         inW;
    int         minorDim;
    int         kernelH;
    int         kernelW;
    int         outH;
    int         outW;
    int         loopMajor;  // major indices handled per block (blockIdx.z step)
    int         loopX;      // output-column iterations per thread/tile (blockIdx.y step)
};

//------------------------------------------------------------------------
// General CUDA implementation for large filter kernels.
//
// Each thread produces output pixels for one (outY, minorIdx) pair taken from
// the flattened x index, walking up to loopX output columns (stride
// blockDim.y) and up to loopMajor major indices. Accumulation is done in
// float and the result is cast back to T.

template <class T>
static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
{
    // Calculate thread index.
    int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorIdx / p.minorDim;
    minorIdx -= outY * p.minorDim;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorIdxBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
        return;

    // Setup Y receptive field (input rows clamped to [0, inH]).
    int midY = outY * p.downy + p.upy - 1 - p.pady0;
    int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
    int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
    int kernelY = midY + p.kernelH - (inY + 1) * p.upy;

    // Loop over majorDim and outX.
    for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
    for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
    {
        // Setup X receptive field (input columns clamped to [0, inW]).
        int midX = outX * p.downx + p.upx - 1 - p.padx0;
        int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
        int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
        int kernelX = midX + p.kernelW - (inX + 1) * p.upx;

        // Initialize pointers. The input pointer steps forward one pixel at a
        // time while the filter pointer steps backward by upx / upy taps.
        const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
        const T* kp = &p.k[kernelY * p.kernelW + kernelX];
        int xpx = p.minorDim;
        int kpx = -p.upx;
        int xpy = p.inW * p.minorDim;
        int kpy = -p.upy * p.kernelW;

        // Inner loop.
        float v = 0.0f;
        for (int y = 0; y < h; y++)
        {
            for (int x = 0; x < w; x++)
            {
                v += (float)(*xp) * (float)(*kp);
                xp += xpx;
                kp += kpx;
            }
            // Rewind the column walk and advance one row.
            xp += xpy - w * xpx;
            kp += kpy - w * kpx;
        }

        // Store result.
        p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
    }
}

//------------------------------------------------------------------------
// Specialized CUDA implementation for small filter kernels.
//
// The resampling factors, the filter-size upper bound and the output tile
// size are compile-time template parameters so the shared-memory buffers can
// be statically sized and the inner loops unrolled. Each block owns one
// (tile row, minorIdx) pair via blockIdx.x, loopX tiles along X via
// blockIdx.y and loopMajor major indices via blockIdx.z.

template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
{
    //assert(kernelW % upx == 0);
    //assert(kernelH % upy == 0);
    const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
    __shared__ volatile float sk[kernelH][kernelW];
    __shared__ volatile float sx[tileInH][tileInW];

    // Calculate tile index.
    int minorIdx = blockIdx.x;
    int tileOutY = minorIdx / p.minorDim;
    minorIdx -= tileOutY * p.minorDim;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorIdxBase = blockIdx.z * p.loopMajor;
    // The exit condition depends only on blockIdx and p, so the whole block
    // leaves together and the __syncthreads() calls below are not divergent.
    // (Bitwise | / & on bool operands: every term is evaluated, none has side effects.)
    if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
        return;

    // Load filter kernel (flipped). Taps beyond the runtime kernel size are zero.
    for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
    {
        int ky = tapIdx / kernelW;
        int kx = tapIdx - ky * kernelW;
        float v = 0.0f;
        if (kx < p.kernelW & ky < p.kernelH)
            v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
        sk[ky][kx] = v;
    }

    // Loop over majorDim and outX.
    for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
    for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
    {
        // Load input pixels (zero outside the input image).
        int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
        int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
        int tileInX = floorDiv(tileMidX, upx);
        int tileInY = floorDiv(tileMidY, upy);
        __syncthreads(); // all reads of the previous sx tile must finish before it is overwritten
        for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
        {
            int relInY = inIdx / tileInW;
            int relInX = inIdx - relInY * tileInW;
            int inX = relInX + tileInX;
            int inY = relInY + tileInY;
            float v = 0.0f;
            if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
                v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
            sx[relInY][relInX] = v;
        }

        // Loop over output pixels.
        __syncthreads(); // make sk / sx writes visible to every thread
        for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
        {
            int relOutY = outIdx / tileOutW;
            int relOutX = outIdx - relOutY * tileOutW;
            int outX = relOutX + tileOutX;
            int outY = relOutY + tileOutY;

            // Setup receptive field.
            int midX = tileMidX + relOutX * downx;
            int midY = tileMidY + relOutY * downy;
            int inX = floorDiv(midX, upx);
            int inY = floorDiv(midY, upy);
            int relInX = inX - tileInX;
            int relInY = inY - tileInY;
            int kernelX = (inX + 1) * upx - midX - 1; // flipped
            int kernelY = (inY + 1) * upy - midY - 1; // flipped

            // Inner loop.
            float v = 0.0f;
            #pragma unroll
            for (int y = 0; y < kernelH / upy; y++)
                #pragma unroll
                for (int x = 0; x < kernelW / upx; x++)
                    v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];

            // Store result (the last tile may extend past the output edge).
            if (outX < p.outW & outY < p.outH)
                p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
        }
    }
}

//------------------------------------------------------------------------
// TensorFlow op.
211 | 212 | template 213 | struct UpFirDn2DOp : public OpKernel 214 | { 215 | UpFirDn2DKernelParams m_attribs; 216 | 217 | UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx) 218 | { 219 | memset(&m_attribs, 0, sizeof(m_attribs)); 220 | OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx)); 221 | OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy)); 222 | OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx)); 223 | OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy)); 224 | OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0)); 225 | OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1)); 226 | OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0)); 227 | OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1)); 228 | OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1")); 229 | OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1")); 230 | } 231 | 232 | void Compute(OpKernelContext* ctx) 233 | { 234 | UpFirDn2DKernelParams p = m_attribs; 235 | cudaStream_t stream = ctx->eigen_device().stream(); 236 | 237 | const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim] 238 | const Tensor& k = ctx->input(1); // [kernelH, kernelW] 239 | p.x = x.flat().data(); 240 | p.k = k.flat().data(); 241 | OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4")); 242 | OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2")); 243 | OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large")); 244 | OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large")); 245 | 246 | p.majorDim = (int)x.dim_size(0); 247 | p.inH = (int)x.dim_size(1); 248 | p.inW = (int)x.dim_size(2); 249 | p.minorDim = (int)x.dim_size(3); 250 | p.kernelH = (int)k.dim_size(0); 251 | p.kernelW = (int)k.dim_size(1); 
252 | OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1")); 253 | 254 | p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx; 255 | p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy; 256 | OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1")); 257 | 258 | Tensor* y = NULL; // [majorDim, outH, outW, minorDim] 259 | TensorShape ys; 260 | ys.AddDim(p.majorDim); 261 | ys.AddDim(p.outH); 262 | ys.AddDim(p.outW); 263 | ys.AddDim(p.minorDim); 264 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y)); 265 | p.y = y->flat().data(); 266 | OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large")); 267 | 268 | // Choose CUDA kernel to use. 269 | void* cudaKernel = (void*)UpFirDn2DKernel_large; 270 | int tileOutW = -1; 271 | int tileOutH = -1; 272 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 273 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 274 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 275 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 276 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 277 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; 
tileOutH = 16; } 278 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 279 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 280 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } 281 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } 282 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } 283 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } 284 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } 285 | 286 | // Choose launch params. 287 | dim3 blockSize; 288 | dim3 gridSize; 289 | if (tileOutW > 0 && tileOutH > 0) // small 290 | { 291 | p.loopMajor = (p.majorDim - 1) / 16384 + 1; 292 | p.loopX = 1; 293 | blockSize = dim3(32 * 8, 1, 1); 294 | gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1); 295 | } 296 | else // large 297 | { 298 | p.loopMajor = (p.majorDim - 1) / 16384 + 1; 299 | p.loopX = 4; 300 | blockSize = dim3(4, 32, 1); 301 | gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1); 302 | } 303 | 304 | // Launch CUDA kernel. 
305 | void* args[] = {&p}; 306 | OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream)); 307 | } 308 | }; 309 | 310 | REGISTER_OP("UpFirDn2D") 311 | .Input ("x: T") 312 | .Input ("k: T") 313 | .Output ("y: T") 314 | .Attr ("T: {float, half}") 315 | .Attr ("upx: int = 1") 316 | .Attr ("upy: int = 1") 317 | .Attr ("downx: int = 1") 318 | .Attr ("downy: int = 1") 319 | .Attr ("padx0: int = 0") 320 | .Attr ("padx1: int = 0") 321 | .Attr ("pady0: int = 0") 322 | .Attr ("pady1: int = 0"); 323 | REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint("T"), UpFirDn2DOp); 324 | REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint("T"), UpFirDn2DOp); 325 | 326 | //------------------------------------------------------------------------ 327 | -------------------------------------------------------------------------------- /cuda/upfirdn_2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | from cuda import custom_ops 5 | 6 | 7 | def _get_plugin(): 8 | loc = os.path.dirname(os.path.abspath(__file__)) 9 | cu_fn = 'upfirdn_2d.cu' 10 | return custom_ops.get_plugin(os.path.join(loc, cu_fn)) 11 | 12 | 13 | def _setup_kernel(k): 14 | k = np.asarray(k, dtype=np.float32) 15 | if k.ndim == 1: 16 | k = np.outer(k, k) 17 | k /= np.sum(k) 18 | assert k.ndim == 2 19 | assert k.shape[0] == k.shape[1] 20 | return k 21 | 22 | 23 | def compute_paddings(resample_kernel, convW, up, down, is_conv, factor=2, gain=1): 24 | assert not (up and down) 25 | 26 | k = [1] * factor if resample_kernel is None else resample_kernel 27 | if up: 28 | k = _setup_kernel(k) * (gain * (factor ** 2)) 29 | if is_conv: 30 | p = (k.shape[0] - factor) - (convW - 1) 31 | pad0 = (p + 1) // 2 + factor - 1 32 | pad1 = p // 2 + 1 33 | else: 34 | p = k.shape[0] - factor 35 | pad0 = (p + 1) // 2 + factor - 1 36 | pad1 = p // 2 37 | elif down: 38 | 
k = _setup_kernel(k) * gain 39 | if is_conv: 40 | p = (k.shape[0] - factor) + (convW - 1) 41 | pad0 = (p + 1) // 2 42 | pad1 = p // 2 43 | else: 44 | p = k.shape[0] - factor 45 | pad0 = (p + 1) // 2 46 | pad1 = p // 2 47 | else: 48 | k = resample_kernel 49 | pad0, pad1 = 0, 0 50 | return k, pad0, pad1 51 | 52 | 53 | def upsample_2d(x, pad0, pad1, k, factor=2): 54 | assert isinstance(factor, int) and factor >= 1 55 | x_res = x.shape[2] 56 | return _simple_upfirdn_2d(x, x_res, k, up=factor, pad0=pad0, pad1=pad1) 57 | 58 | 59 | def downsample_2d(x, pad0, pad1, k, factor=2): 60 | assert isinstance(factor, int) and factor >= 1 61 | x_res = x.shape[2] 62 | return _simple_upfirdn_2d(x, x_res, k, down=factor, pad0=pad0, pad1=pad1) 63 | 64 | 65 | def upsample_conv_2d(x, w, convH, convW, pad0, pad1, k, factor=2): 66 | assert isinstance(factor, int) and factor >= 1 67 | 68 | x_res = x.shape[2] 69 | # Check weight shape. 70 | w = tf.convert_to_tensor(w) 71 | assert w.shape.rank == 4 72 | # convH = w.shape[0] 73 | # convW = w.shape[1] 74 | inC = tf.shape(w)[2] 75 | outC = tf.shape(w)[3] 76 | assert convW == convH 77 | 78 | # Determine data dimensions. 79 | stride = [1, 1, factor, factor] 80 | output_shape = [tf.shape(x)[0], outC, (x_res - 1) * factor + convH, (x_res - 1) * factor + convW] 81 | num_groups = tf.shape(x)[1] // inC 82 | 83 | # Transpose weights. 84 | w = tf.reshape(w, [convH, convW, inC, num_groups, -1]) 85 | w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2]) 86 | w = tf.reshape(w, [convH, convW, -1, num_groups * inC]) 87 | 88 | # Execute. 
89 | x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format='NCHW') 90 | new_x_res = output_shape[2] 91 | return _simple_upfirdn_2d(x, new_x_res, k, pad0=pad0, pad1=pad1) 92 | 93 | 94 | def conv_downsample_2d(x, w, convH, convW, pad0, pad1, k, factor=2): 95 | assert isinstance(factor, int) and factor >= 1 96 | x_res = x.shape[2] 97 | w = tf.convert_to_tensor(w) 98 | # convH, convW, _inC, _outC = w.shape.as_list() 99 | assert convW == convH 100 | 101 | s = [1, 1, factor, factor] 102 | x = _simple_upfirdn_2d(x, x_res, k, pad0=pad0, pad1=pad1) 103 | return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format='NCHW') 104 | 105 | 106 | def _simple_upfirdn_2d(x, x_res, k, up=1, down=1, pad0=0, pad1=0): 107 | assert x.shape.rank == 4 108 | y = x 109 | y = tf.reshape(y, [-1, x_res, x_res, 1]) 110 | y = upfirdn_2d_cuda(y, k, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1) 111 | y = tf.reshape(y, [-1, tf.shape(x)[1], tf.shape(y)[1], tf.shape(y)[2]]) 112 | return y 113 | 114 | 115 | def upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1): 116 | """Fast CUDA implementation of `upfirdn_2d()` using custom ops.""" 117 | 118 | x = tf.convert_to_tensor(x) 119 | k = np.asarray(k, dtype=np.float32) 120 | majorDim, inH, inW, minorDim = x.shape.as_list() 121 | kernelH, kernelW = k.shape 122 | assert inW >= 1 and inH >= 1 123 | assert kernelW >= 1 and kernelH >= 1 124 | assert isinstance(upx, int) and isinstance(upy, int) 125 | assert isinstance(downx, int) and isinstance(downy, int) 126 | assert isinstance(padx0, int) and isinstance(padx1, int) 127 | assert isinstance(pady0, int) and isinstance(pady1, int) 128 | 129 | outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1 130 | outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1 131 | assert outW >= 1 and outH >= 1 132 | 133 | kc = tf.constant(k, dtype=x.dtype) 134 | gkc = tf.constant(k[::-1, ::-1], 
dtype=x.dtype) 135 | gpadx0 = kernelW - padx0 - 1 136 | gpady0 = kernelH - pady0 - 1 137 | gpadx1 = inW * upx - outW * downx + padx0 - upx + 1 138 | gpady1 = inH * upy - outH * downy + pady0 - upy + 1 139 | 140 | @tf.custom_gradient 141 | def func(x): 142 | y = _get_plugin().up_fir_dn2d(x=x, k=kc, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1) 143 | y.set_shape([majorDim, outH, outW, minorDim]) 144 | @tf.custom_gradient 145 | def grad(dy): 146 | dx = _get_plugin().up_fir_dn2d(x=dy, k=gkc, upx=downx, upy=downy, downx=upx, downy=upy, padx0=gpadx0, padx1=gpadx1, pady0=gpady0, pady1=gpady1) 147 | dx.set_shape([majorDim, inH, inW, minorDim]) 148 | return dx, func 149 | return y, grad 150 | return func(x) 151 | -------------------------------------------------------------------------------- /generate_video.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from PIL import Image 3 | import numpy as np 4 | 5 | import math 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torchvision import utils 10 | import cv2 11 | 12 | from moviepy.editor import * 13 | 14 | def make_grid(images, res, rows, cols): 15 | images = (tf.clip_by_value(images, -1.0, 1.0) + 1.0) * 127.5 16 | images = tf.transpose(images, perm=[0, 2, 3, 1]) 17 | images = tf.cast(images, tf.uint8) 18 | images = images.numpy() 19 | 20 | batch_size = images.shape[0] 21 | assert rows * cols == batch_size 22 | canvas = np.zeros(shape=[res * rows, res * cols, 3], dtype=np.uint8) 23 | for row in range(rows): 24 | y_start = row * res 25 | for col in range(cols): 26 | x_start = col * res 27 | index = col + row * cols 28 | canvas[y_start:y_start + res, x_start:x_start + res, :] = images[index, :, :, :] 29 | 30 | return canvas 31 | 32 | def load_generator(g_params=None, is_g_clone=True, ckpt_dir='checkpoint'): 33 | 34 | from networks import Generator 35 | 36 | if g_params is None: 37 | g_params = { 38 | 
'z_dim': 512, 39 | 'w_dim': 512, 40 | 'labels_dim': 0, 41 | 'n_mapping': 8, 42 | 'resolutions': [4, 8, 16, 32, 64, 128, 256], 43 | 'featuremaps': [512, 512, 512, 512, 512, 256, 128], 44 | 'w_ema_decay': 0.995, 45 | 'style_mixing_prob': 0.9, 46 | } 47 | 48 | test_latent = tf.ones((1, g_params['z_dim']), dtype=tf.float32) 49 | test_labels = tf.ones((1, g_params['labels_dim']), dtype=tf.float32) 50 | 51 | # build generator model 52 | generator = Generator(g_params) 53 | _, _ = generator([test_latent, test_labels]) 54 | 55 | if ckpt_dir is not None: 56 | if is_g_clone: 57 | ckpt = tf.train.Checkpoint(g_clone=generator) 58 | else: 59 | ckpt = tf.train.Checkpoint(generator=generator) 60 | manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1) 61 | ckpt.restore(manager.latest_checkpoint).expect_partial() 62 | if manager.latest_checkpoint: 63 | print(f'Generator restored from {manager.latest_checkpoint}') 64 | 65 | return generator 66 | 67 | def generate(): 68 | 69 | generator = load_generator(is_g_clone=True) 70 | radius = 30 # 32 71 | pics = 120 72 | truncation_psi = 0.5 # 1.0 73 | 74 | sample_n = 16 # 4 75 | n_row = 4 76 | n_col = 4 77 | res = 256 78 | sample_z = tf.random.normal(shape=[sample_n, 512]) 79 | images = [] 80 | for i in tqdm(range(pics)): 81 | dh = math.sin(2 * math.pi * (i / pics)) * radius 82 | dw = math.cos(2 * math.pi * (i / pics)) * radius 83 | 84 | sample_tf, _ = generator([sample_z, 85 | tf.random.normal(shape=[sample_n, 0])], 86 | shift_h=dh, shift_w=dw, 87 | training=False, truncation_psi=truncation_psi) 88 | # Pytorch 89 | 90 | sample = sample_tf 91 | sample = sample.numpy() 92 | sample = torch.Tensor(sample) 93 | grid = utils.make_grid( 94 | sample.cpu(), normalize=True, nrow=n_row, value_range=(-1, 1) 95 | ) 96 | grid = grid.mul(255).permute(1, 2, 0).numpy().astype(np.uint8) 97 | images.append( 98 | grid 99 | ) 100 | 101 | 102 | # Tensorflow 103 | # grid_tf = make_grid(sample_tf, res=res, rows=n_row, cols=n_col) 104 | # 
images.append(grid_tf) 105 | 106 | 107 | # Image save 108 | """ 109 | for j in tqdm(range(sample_n)): 110 | f_name = 'images/{}_{}.png'.format(j, i) 111 | utils.save_image( 112 | sample[j].unsqueeze(0), 113 | f_name, 114 | nrow=1, 115 | normalize=True, 116 | range=(-1, 1), 117 | ) 118 | """ 119 | 120 | # To video 121 | videodims = (images[0].shape[1], images[0].shape[0]) 122 | fourcc = cv2.VideoWriter_fourcc(*"VP90") 123 | video = cv2.VideoWriter("sample.webm", fourcc, 24, videodims) 124 | 125 | for i in tqdm(images): 126 | video.write(cv2.cvtColor(i, cv2.COLOR_RGB2BGR)) 127 | 128 | video.release() 129 | 130 | # Video to GIF 131 | clip = VideoFileClip("sample.webm") 132 | clip.write_gif("sample.gif") 133 | 134 | 135 | generate() -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | 3 | ################################################################################## 4 | # Synthesis Layers 5 | ################################################################################## 6 | class Synthesis(tf.keras.layers.Layer): 7 | def __init__(self, resolutions, featuremaps, name, **kwargs): 8 | super(Synthesis, self).__init__(name=name, **kwargs) 9 | self.resolutions = resolutions 10 | self.featuremaps = featuremaps 11 | 12 | self.k, self.pad0, self.pad1 = compute_paddings([1, 3, 3, 1], None, up=True, down=False, is_conv=False) 13 | 14 | # initial layer 15 | res, n_f = resolutions[0], featuremaps[0] 16 | self.img_size = resolutions[-1] 17 | self.log_size = int(np.log2(self.img_size)) 18 | 19 | self.shift_h_dict = {4: 0} 20 | self.shift_w_dict = {4: 0} 21 | for i in range(3, self.log_size + 1): 22 | self.shift_h_dict[2 ** i] = 0 23 | self.shift_w_dict[2 ** i] = 0 24 | 25 | self.initial_block = SynthesisConstBlock(fmaps=n_f, name='{:d}x{:d}/const'.format(res, res)) 26 | self.initial_torgb = ToRGB(in_ch=n_f, 
name='{:d}x{:d}/ToRGB'.format(res, res)) 27 | 28 | # stack generator block with lerp block 29 | prev_n_f = n_f 30 | self.blocks = [] 31 | self.torgbs = [] 32 | 33 | for res, n_f in zip(self.resolutions[1:], self.featuremaps[1:]): 34 | self.blocks.append(SynthesisBlock(in_ch=prev_n_f, fmaps=n_f, res=res, 35 | name='{:d}x{:d}/block'.format(res, res))) 36 | self.torgbs.append(ToRGB(in_ch=n_f, name='{:d}x{:d}/ToRGB'.format(res, res))) 37 | prev_n_f = n_f 38 | 39 | def call(self, inputs, shift_h=0, shift_w=0, training=None, mask=None): 40 | ##### positional encoding ##### 41 | # continuous roll 42 | if shift_h: 43 | for i in range(2, self.log_size + 1): 44 | self.shift_h_dict[2 ** i] = shift_h / (self.img_size // (2 ** i)) 45 | if shift_w: 46 | for i in range(2, self.log_size + 1): 47 | self.shift_w_dict[2 ** i] = shift_w / (self.img_size // (2 ** i)) 48 | 49 | w_broadcasted = inputs 50 | 51 | # initial layer 52 | w0, w1 = w_broadcasted[:, 0], w_broadcasted[:, 1] 53 | 54 | x = self.initial_block([w_broadcasted, w0], shift_h_dict=self.shift_h_dict, shift_w_dict=self.shift_w_dict) 55 | y = self.initial_torgb([x, w1]) 56 | 57 | layer_index = 1 58 | for block, torgb in zip(self.blocks, self.torgbs): 59 | w0 = w_broadcasted[:, layer_index] 60 | w1 = w_broadcasted[:, layer_index + 1] 61 | w2 = w_broadcasted[:, layer_index + 2] 62 | 63 | x = block([x, w0, w1], shift_h_dict=self.shift_h_dict, shift_w_dict=self.shift_w_dict) 64 | y = upsample_2d(y, self.pad0, self.pad1, self.k) 65 | y = y + torgb([x, w2]) 66 | 67 | layer_index += 2 68 | 69 | images_out = y 70 | 71 | return images_out 72 | 73 | # def get_config(self): 74 | # config = super(Synthesis, self).get_config() 75 | # config.update({ 76 | # 'resolutions': self.resolutions, 77 | # 'featuremaps': self.featuremaps, 78 | # 'k': self.k, 79 | # 'pad0': self.pad0, 80 | # 'pad1': self.pad1, 81 | # }) 82 | # return config 83 | 84 | 85 | class SynthesisConstBlock(tf.keras.layers.Layer): 86 | def __init__(self, fmaps, **kwargs): 87 
| super(SynthesisConstBlock, self).__init__(**kwargs) 88 | self.res = 4 89 | self.fmaps = fmaps 90 | self.gain = 1.0 91 | self.lrmul = 1.0 92 | 93 | # conv block 94 | self.conv = ModulatedConv2D(fmaps=self.fmaps, style_fmaps=self.fmaps, kernel=3, up=False, down=False, 95 | demodulate=True, resample_kernel=[1, 3, 3, 1], gain=self.gain, lrmul=self.lrmul, 96 | fused_modconv=True, name='conv') 97 | self.apply_noise = Noise(name='noise') 98 | self.apply_bias_act = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias') 99 | 100 | self.pes_start = PE2dStart(512, 4, 4, scale=1.0) 101 | 102 | # def build(self, input_shape): 103 | # # starting const variable 104 | # # tf 1.15 mean(0.0), std(1.0) default value of tf.initializers.random_normal() 105 | # const_init = tf.random.normal(shape=(1, self.fmaps, self.res, self.res), mean=0.0, stddev=1.0) 106 | # self.const = tf.Variable(const_init, name='const', trainable=True) 107 | 108 | def call(self, inputs, shift_h_dict=None, shift_w_dict=None, training=None, mask=None): 109 | w_broadcasted, w0 = inputs 110 | batch_size = tf.shape(w0)[0] 111 | 112 | # const block 113 | # x = tf.tile(self.const, [batch_size, 1, 1, 1]) 114 | x = self.pes_start(w_broadcasted, shift_h_dict[4], shift_w_dict[4]) 115 | 116 | # conv block 117 | x = self.conv([x, w0]) 118 | x = self.apply_noise(x) 119 | x = self.apply_bias_act(x) 120 | return x 121 | 122 | 123 | class SynthesisBlock(tf.keras.layers.Layer): 124 | def __init__(self, in_ch, fmaps, res, **kwargs): 125 | super(SynthesisBlock, self).__init__(**kwargs) 126 | self.in_ch = in_ch 127 | self.fmaps = fmaps 128 | self.gain = 1.0 129 | self.lrmul = 1.0 130 | self.res = res 131 | 132 | # conv0 up 133 | self.conv_0 = ModulatedConv2D(fmaps=self.fmaps, style_fmaps=self.in_ch, kernel=3, up=True, down=False, 134 | demodulate=True, resample_kernel=[1, 3, 3, 1], gain=self.gain, lrmul=self.lrmul, 135 | fused_modconv=True, name='conv_0') 136 | self.apply_noise_0 = Noise(name='noise_0') 137 | self.apply_bias_act_0 = 
BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_0') 138 | 139 | self.pes = PE2d(channel=fmaps, height=res, width=res, scale=1.0) 140 | 141 | # conv block 142 | self.conv_1 = ModulatedConv2D(fmaps=self.fmaps, style_fmaps=self.fmaps, kernel=3, up=False, down=False, 143 | demodulate=True, resample_kernel=[1, 3, 3, 1], gain=self.gain, lrmul=self.lrmul, 144 | fused_modconv=True, name='conv_1') 145 | self.apply_noise_1 = Noise(name='noise_1') 146 | self.apply_bias_act_1 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_1') 147 | 148 | def call(self, inputs, shift_h_dict=None, shift_w_dict=None, training=None, mask=None): 149 | x, w0, w1 = inputs 150 | 151 | # conv0 up 152 | x = self.conv_0([x, w0]) 153 | x = self.apply_noise_0(x) 154 | x = self.apply_bias_act_0(x) 155 | 156 | # pse 157 | x = self.pes(x, shift_h=shift_h_dict[self.res], shift_w=shift_w_dict[self.res]) 158 | 159 | # conv block 160 | x = self.conv_1([x, w1]) 161 | x = self.apply_noise_1(x) 162 | x = self.apply_bias_act_1(x) 163 | 164 | return x 165 | 166 | # def get_config(self): 167 | # config = super(SynthesisBlock, self).get_config() 168 | # config.update({ 169 | # 'in_ch': self.in_ch, 170 | # 'res': self.res, 171 | # 'fmaps': self.fmaps, 172 | # 'gain': self.gain, 173 | # 'lrmul': self.lrmul, 174 | # }) 175 | # return config 176 | 177 | ################################################################################## 178 | # Discriminator Layers 179 | ################################################################################## 180 | class DiscriminatorBlock(tf.keras.layers.Layer): 181 | def __init__(self, n_f0, n_f1, **kwargs): 182 | super(DiscriminatorBlock, self).__init__(**kwargs) 183 | self.gain = 1.0 184 | self.lrmul = 1.0 185 | self.n_f0 = n_f0 186 | self.n_f1 = n_f1 187 | self.resnet_scale = 1. / tf.sqrt(2.) 
188 | 189 | # conv_0 190 | self.conv_0 = Conv2D(fmaps=self.n_f0, kernel=3, up=False, down=False, 191 | resample_kernel=None, gain=self.gain, lrmul=self.lrmul, name='conv_0') 192 | self.apply_bias_act_0 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_0') 193 | 194 | # conv_1 down 195 | self.conv_1 = Conv2D(fmaps=self.n_f1, kernel=3, up=False, down=True, 196 | resample_kernel=[1, 3, 3, 1], gain=self.gain, lrmul=self.lrmul, name='conv_1') 197 | self.apply_bias_act_1 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_1') 198 | 199 | # resnet skip 200 | self.conv_skip = Conv2D(fmaps=self.n_f1, kernel=1, up=False, down=True, 201 | resample_kernel=[1, 3, 3, 1], gain=self.gain, lrmul=self.lrmul, name='skip') 202 | 203 | def call(self, inputs, training=None, mask=None): 204 | x = inputs 205 | residual = x 206 | 207 | # conv0 208 | x = self.conv_0(x) 209 | x = self.apply_bias_act_0(x) 210 | 211 | # conv1 down 212 | x = self.conv_1(x) 213 | x = self.apply_bias_act_1(x) 214 | 215 | # resnet skip 216 | residual = self.conv_skip(residual) 217 | x = (x + residual) * self.resnet_scale 218 | return x 219 | 220 | # def get_config(self): 221 | # config = super(DiscriminatorBlock, self).get_config() 222 | # config.update({ 223 | # 'n_f0': self.n_f0, 224 | # 'n_f1': self.n_f1, 225 | # 'gain': self.gain, 226 | # 'lrmul': self.lrmul, 227 | # 'res': self.res, 228 | # 'resnet_scale': self.resnet_scale, 229 | # }) 230 | # return config 231 | 232 | 233 | class DiscriminatorLastBlock(tf.keras.layers.Layer): 234 | def __init__(self, n_f0, n_f1, **kwargs): 235 | super(DiscriminatorLastBlock, self).__init__(**kwargs) 236 | self.gain = 1.0 237 | self.lrmul = 1.0 238 | self.n_f0 = n_f0 239 | self.n_f1 = n_f1 240 | 241 | self.minibatch_std = MinibatchStd(group_size=4, num_new_features=1, name='minibatchstd') 242 | 243 | # conv_0 244 | self.conv_0 = Conv2D(fmaps=self.n_f0, kernel=3, up=False, down=False, 245 | resample_kernel=None, gain=self.gain, lrmul=self.lrmul, name='conv_0') 246 | 
self.apply_bias_act_0 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_0') 247 | 248 | # dense_1 249 | self.dense_1 = Dense(self.n_f1, gain=self.gain, lrmul=self.lrmul, name='dense_1') 250 | self.apply_bias_act_1 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_1') 251 | 252 | def call(self, x, training=None, mask=None): 253 | x = self.minibatch_std(x) 254 | 255 | # conv_0 256 | x = self.conv_0(x) 257 | x = self.apply_bias_act_0(x) 258 | 259 | # dense_1 260 | x = self.dense_1(x) 261 | x = self.apply_bias_act_1(x) 262 | return x 263 | 264 | # def get_config(self): 265 | # config = super(DiscriminatorLastBlock, self).get_config() 266 | # config.update({ 267 | # 'n_f0': self.n_f0, 268 | # 'n_f1': self.n_f1, 269 | # 'gain': self.gain, 270 | # 'lrmul': self.lrmul, 271 | # 'res': self.res, 272 | # }) 273 | # return config -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from StyleGAN2 import StyleGAN2 2 | import argparse 3 | from utils import * 4 | 5 | def parse_args(): 6 | desc = "Tensorflow implementation of StyleGAN2" 7 | parser = argparse.ArgumentParser(description=desc) 8 | parser.add_argument('--phase', type=str, default='train', help='[train, test, draw]') 9 | parser.add_argument('--draw', type=str, default='all', help='[uncurated, style_mix, truncation_trick, all]') 10 | 11 | parser.add_argument('--dataset', type=str, default='FFHQ', help='dataset_name') 12 | 13 | parser.add_argument('--batch_size', type=int, default=4, help='The size of batch size') 14 | parser.add_argument('--print_freq', type=int, default=2000, help='The number of image_print_freq') 15 | parser.add_argument('--save_freq', type=int, default=10000, help='The number of ckpt_save_freq') 16 | 17 | parser.add_argument('--n_total_image', type=int, default=6400, help='The total iterations') 18 | parser.add_argument('--config', type=str, default='config-f', help='config-e 
or config-f') 19 | parser.add_argument('--lazy_regularization', type=str2bool, default=True, help='lazy_regularization') 20 | 21 | parser.add_argument('--img_size', type=int, default=256, help='The size of image') 22 | 23 | parser.add_argument('--n_test', type=int, default=10, help='The number of images generated by the test phase') 24 | 25 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 26 | help='Directory name to save the checkpoints') 27 | parser.add_argument('--result_dir', type=str, default='results', 28 | help='Directory name to save the generated images') 29 | parser.add_argument('--log_dir', type=str, default='logs', 30 | help='Directory name to save training logs') 31 | parser.add_argument('--sample_dir', type=str, default='samples', 32 | help='Directory name to save the samples on training') 33 | 34 | return check_args(parser.parse_args()) 35 | 36 | 37 | """checking arguments""" 38 | def check_args(args): 39 | # --checkpoint_dir 40 | check_folder(args.checkpoint_dir) 41 | 42 | # --result_dir 43 | check_folder(args.result_dir) 44 | 45 | # --result_dir 46 | check_folder(args.log_dir) 47 | 48 | # --sample_dir 49 | check_folder(args.sample_dir) 50 | 51 | # --batch_size 52 | try: 53 | assert args.batch_size >= 1 54 | except: 55 | print('batch size must be larger than or equal to one') 56 | 57 | return args 58 | 59 | """main""" 60 | def main(): 61 | 62 | args = vars(parse_args()) 63 | 64 | # network params 65 | img_size = args['img_size'] 66 | resolutions = [4, 8, 16, 32, 64, 128, 256, 512, 1024] 67 | if args['config'] == 'config-f': 68 | featuremaps = [512, 512, 512, 512, 512, 256, 128, 64, 32] # config-f 69 | else: 70 | featuremaps = [512, 512, 512, 512, 256, 128, 64, 32, 16] # config-e 71 | train_resolutions, train_featuremaps = filter_resolutions_featuremaps(resolutions, featuremaps, img_size) 72 | g_params = { 73 | 'z_dim': 512, 74 | 'w_dim': 512, 75 | 'labels_dim': 0, 76 | 'n_mapping': 8, 77 | 'resolutions': train_resolutions, 
78 | 'featuremaps': train_featuremaps, 79 | 'w_ema_decay': 0.995, 80 | 'style_mixing_prob': 0.9, 81 | } 82 | d_params = { 83 | 'labels_dim': 0, 84 | 'resolutions': train_resolutions, 85 | 'featuremaps': train_featuremaps, 86 | } 87 | 88 | strategy = tf.distribute.MirroredStrategy() 89 | NUM_GPUS = strategy.num_replicas_in_sync 90 | batch_size = args['batch_size'] * NUM_GPUS # global batch size 91 | 92 | # training parameters 93 | training_parameters = { 94 | # global params 95 | **args, 96 | 97 | # network params 98 | 'g_params': g_params, 99 | 'd_params': d_params, 100 | 101 | # training params 102 | 'g_opt': {'learning_rate': 0.002, 'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08, 'reg_interval': 4}, 103 | 'd_opt': {'learning_rate': 0.002, 'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08, 'reg_interval': 16}, 104 | 'batch_size': batch_size, 105 | 'NUM_GPUS' : NUM_GPUS, 106 | 'n_samples': 4, 107 | } 108 | 109 | # automatic_gpu_usage() 110 | with strategy.scope(): 111 | gan = StyleGAN2(training_parameters, strategy) 112 | 113 | # build graph 114 | gan.build_model() 115 | 116 | 117 | if args['phase'] == 'train' : 118 | gan.train() 119 | # gan.test_70000() # for FID evaluation ... 
120 | print(" [*] Training finished!") 121 | 122 | if args['phase'] == 'test': 123 | gan.test() 124 | print(" [*] Test finished!") 125 | 126 | if args['phase'] == 'draw': 127 | 128 | if args['draw'] == 'style_mix': 129 | 130 | gan.draw_style_mixing_figure() 131 | 132 | print(" [*] Style mix finished!") 133 | 134 | 135 | elif args['draw'] == 'truncation_trick': 136 | 137 | gan.draw_truncation_trick_figure() 138 | 139 | print(" [*] Truncation_trick finished!") 140 | 141 | 142 | elif args['draw'] == 'uncurated': 143 | gan.draw_uncurated_result_figure() 144 | 145 | print(" [*] Un-curated finished!") 146 | 147 | else: 148 | gan.draw_uncurated_result_figure() 149 | print(" [*] Un-curated finished!") 150 | gan.draw_style_mixing_figure() 151 | print(" [*] Style mix finished!") 152 | gan.draw_truncation_trick_figure() 153 | print(" [*] Truncation_trick finished!") 154 | 155 | 156 | 157 | if __name__ == '__main__': 158 | main() -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | from layers import * 2 | ################################################################################## 3 | # Generator Networks 4 | ################################################################################## 5 | class Generator(tf.keras.Model): 6 | def __init__(self, g_params, **kwargs): 7 | super(Generator, self).__init__(**kwargs) 8 | 9 | self.z_dim = g_params['z_dim'] 10 | self.w_dim = g_params['w_dim'] 11 | self.labels_dim = g_params['labels_dim'] 12 | self.n_mapping = g_params['n_mapping'] 13 | self.resolutions = g_params['resolutions'] 14 | self.featuremaps = g_params['featuremaps'] 15 | self.w_ema_decay = g_params['w_ema_decay'] 16 | self.style_mixing_prob = g_params['style_mixing_prob'] 17 | 18 | self.n_broadcast = len(self.resolutions) * 2 19 | self.mixing_layer_indices = np.arange(self.n_broadcast)[np.newaxis, :, np.newaxis] 20 | 21 | self.g_mapping = 
Mapping(self.w_dim, self.labels_dim, self.n_mapping, name='g_mapping') 22 | self.broadcast = tf.keras.layers.Lambda(lambda x: tf.tile(x[:, np.newaxis], [1, self.n_broadcast, 1])) 23 | self.synthesis = Synthesis(self.resolutions, self.featuremaps, name='g_synthesis') 24 | 25 | 26 | 27 | def build(self, input_shape): 28 | # w_avg 29 | self.w_avg = tf.Variable(tf.zeros(shape=[self.w_dim], dtype=tf.float32), name='w_avg', trainable=False, 30 | synchronization=tf.VariableSynchronization.ON_READ, 31 | aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA) 32 | 33 | @tf.function 34 | def set_as_moving_average_of(self, src_net, beta=0.99, beta_nontrainable=0.0): 35 | for cw, sw in zip(self.weights, src_net.weights): 36 | assert sw.shape == cw.shape 37 | 38 | if 'w_avg' in cw.name: 39 | cw.assign(lerp(sw, cw, beta_nontrainable)) 40 | else: 41 | cw.assign(lerp(sw, cw, beta)) 42 | return 43 | 44 | def update_moving_average_of_w(self, w_broadcasted): 45 | # compute average of current w 46 | batch_avg = tf.reduce_mean(w_broadcasted[:, 0], axis=0) 47 | 48 | # compute moving average of w and update(assign) w_avg 49 | update_w_avg = lerp(batch_avg, self.w_avg, self.w_ema_decay) 50 | 51 | return update_w_avg 52 | 53 | def style_mixing_regularization(self, latents1, labels, w_broadcasted1): 54 | # get another w and broadcast it 55 | latents2 = tf.random.normal(shape=tf.shape(latents1), dtype=tf.float32) 56 | dlatents2 = self.g_mapping([latents2, labels]) 57 | w_broadcasted2 = self.broadcast(dlatents2) 58 | 59 | 60 | # find mixing limit index 61 | if tf.random.uniform([], 0.0, 1.0) < self.style_mixing_prob: 62 | mixing_cutoff_index = tf.random.uniform([], 1, self.n_broadcast, dtype=tf.int32) 63 | else: 64 | mixing_cutoff_index = tf.constant(self.n_broadcast, dtype=tf.int32) 65 | 66 | # mix it 67 | mixed_w_broadcasted = tf.where( 68 | condition=tf.broadcast_to(self.mixing_layer_indices < mixing_cutoff_index, tf.shape(w_broadcasted1)), 69 | x=w_broadcasted1, 70 | y=w_broadcasted2) 71 | 
72 | return mixed_w_broadcasted 73 | 74 | def truncation_trick(self, w_broadcasted, truncation_cutoff, truncation_psi): 75 | ones = tf.ones_like(self.mixing_layer_indices, dtype=tf.float32) 76 | tpsi = ones * truncation_psi 77 | 78 | if truncation_cutoff is None: 79 | truncation_coefs = tpsi 80 | else: 81 | # indices = tf.range(self.n_broadcast) 82 | indices = self.mixing_layer_indices 83 | truncation_coefs = tf.where(condition=tf.less(indices, truncation_cutoff), x=tpsi, y=ones) 84 | 85 | truncated_w_broadcasted = lerp(self.w_avg, w_broadcasted, truncation_coefs) 86 | 87 | return truncated_w_broadcasted 88 | 89 | def call(self, inputs, truncation_cutoff=None, truncation_psi=1.0, shift_h=0, shift_w=0, training=None, mapping=True, mask=None): 90 | latents, labels = inputs 91 | 92 | if mapping: 93 | dlatents = self.g_mapping([latents, labels]) 94 | w_broadcasted = self.broadcast(dlatents) 95 | 96 | if training: 97 | self.w_avg.assign(self.update_moving_average_of_w(w_broadcasted)) 98 | w_broadcasted = self.style_mixing_regularization(latents, labels, w_broadcasted) 99 | 100 | if not training: 101 | w_broadcasted = self.truncation_trick(w_broadcasted, truncation_cutoff, truncation_psi) 102 | 103 | else: 104 | w_broadcasted = latents 105 | 106 | image_out = self.synthesis(w_broadcasted, shift_h=shift_h, shift_w=shift_w) 107 | 108 | return image_out, w_broadcasted 109 | 110 | def compute_output_shape(self, input_shape): 111 | assert isinstance(input_shape, list) 112 | 113 | # shape_latents, shape_labels = input_shape 114 | return input_shape[0][0], 3, self.resolutions[-1], self.resolutions[-1] 115 | 116 | 117 | ################################################################################## 118 | # Discriminator Networks 119 | ################################################################################## 120 | class Discriminator(tf.keras.Model): 121 | def __init__(self, d_params, **kwargs): 122 | super(Discriminator, self).__init__(**kwargs) 123 | # 
discriminator's (resolutions and featuremaps) are reversed against generator's 124 | self.labels_dim = d_params['labels_dim'] 125 | self.r_resolutions = d_params['resolutions'][::-1] 126 | self.r_featuremaps = d_params['featuremaps'][::-1] 127 | 128 | # stack discriminator blocks 129 | res0, n_f0 = self.r_resolutions[0], self.r_featuremaps[0] 130 | self.initial_fromrgb = FromRGB(fmaps=n_f0, name='{:d}x{:d}/FromRGB'.format(res0, res0)) 131 | self.blocks = [] 132 | 133 | for index, (res0, n_f0) in enumerate(zip(self.r_resolutions[:-1], self.r_featuremaps[:-1])): 134 | n_f1 = self.r_featuremaps[index + 1] 135 | self.blocks.append(DiscriminatorBlock(n_f0=n_f0, n_f1=n_f1, name='{:d}x{:d}'.format(res0, res0))) 136 | 137 | # set last discriminator block 138 | res = self.r_resolutions[-1] 139 | n_f0, n_f1 = self.r_featuremaps[-2], self.r_featuremaps[-1] 140 | self.last_block = DiscriminatorLastBlock(n_f0, n_f1, name='{:d}x{:d}'.format(res, res)) 141 | 142 | # set last dense layer 143 | self.last_dense = Dense(max(self.labels_dim, 1), gain=1.0, lrmul=1.0, name='last_dense') 144 | self.last_bias = BiasAct(lrmul=1.0, act='linear', name='last_bias') 145 | 146 | 147 | 148 | # @ tf.function 149 | def call(self, inputs, training=None, mask=None): 150 | images, labels = inputs 151 | 152 | x = self.initial_fromrgb(images) 153 | for block in self.blocks: 154 | x = block(x) 155 | 156 | x = self.last_block(x) 157 | 158 | logit = self.last_dense(x) 159 | logit = self.last_bias(logit) 160 | 161 | if self.labels_dim > 0: 162 | logit = tf.reduce_sum(logit * labels, axis=1, keepdims=True) 163 | 164 | scores_out = logit 165 | 166 | return scores_out 167 | 168 | def compute_output_shape(self, input_shape): 169 | return input_shape[0][0], max(self.labels_dim, 1) 170 | 171 | ################################################################################## 172 | # Mapping Networks 173 | ################################################################################## 174 | class 
Mapping(tf.keras.layers.Layer): 175 | def __init__(self, w_dim, labels_dim, n_mapping, **kwargs): 176 | super(Mapping, self).__init__(**kwargs) 177 | self.w_dim = w_dim 178 | self.labels_dim = labels_dim 179 | self.n_mapping = n_mapping 180 | self.gain = 1.0 181 | self.lrmul = 0.01 182 | 183 | if self.labels_dim > 0: 184 | self.labels_embedding = LabelEmbedding(embed_dim=self.w_dim, name='labels_embedding') 185 | 186 | self.normalize = tf.keras.layers.Lambda(lambda x: x * tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + 1e-8)) 187 | 188 | self.dense_layers = [] 189 | self.bias_act_layers = [] 190 | 191 | for ii in range(self.n_mapping): 192 | self.dense_layers.append(Dense(w_dim, gain=self.gain, lrmul=self.lrmul, name='dense_{:d}'.format(ii))) 193 | self.bias_act_layers.append(BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_{:d}'.format(ii))) 194 | 195 | def call(self, inputs, training=None, mask=None): 196 | latents, labels = inputs 197 | x = latents 198 | 199 | # embed label if any 200 | if self.labels_dim > 0: 201 | y = self.labels_embedding(labels) 202 | x = tf.concat([x, y], axis=1) 203 | 204 | # normalize inputs 205 | x = self.normalize(x) 206 | 207 | # apply mapping blocks 208 | for dense, apply_bias_act in zip(self.dense_layers, self.bias_act_layers): 209 | x = dense(x) 210 | x = apply_bias_act(x) 211 | 212 | return x 213 | 214 | # def get_config(self): 215 | # config = super(Mapping, self).get_config() 216 | # config.update({ 217 | # 'w_dim': self.w_dim, 218 | # 'labels_dim': self.labels_dim, 219 | # 'n_mapping': self.n_mapping, 220 | # 'gain': self.gain, 221 | # 'lrmul': self.lrmul, 222 | # }) 223 | # return config 224 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from cuda.upfirdn_2d import * 4 | from cuda.fused_bias_act import fused_bias_act 5 | 6 | def 
compute_runtime_coef(weight_shape, gain, lrmul): 7 | fan_in = tf.reduce_prod(weight_shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out] 8 | fan_in = tf.cast(fan_in, dtype=tf.float32) 9 | he_std = gain / tf.sqrt(fan_in) 10 | init_std = 1.0 / lrmul 11 | runtime_coef = he_std * lrmul 12 | return init_std, runtime_coef 13 | 14 | def lerp(a, b, t): 15 | out = a + (b - a) * t 16 | return out 17 | 18 | def lerp_clip(a, b, t): 19 | out = a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 20 | return out 21 | 22 | ################################################################################## 23 | # Layers 24 | ################################################################################## 25 | 26 | class Conv2D(tf.keras.layers.Layer): 27 | def __init__(self, fmaps, kernel, up, down, resample_kernel, gain, lrmul, **kwargs): 28 | super(Conv2D, self).__init__(**kwargs) 29 | self.fmaps = fmaps 30 | self.kernel = kernel 31 | self.gain = gain 32 | self.lrmul = lrmul 33 | self.up = up 34 | self.down = down 35 | 36 | self.k, self.pad0, self.pad1 = compute_paddings(resample_kernel, self.kernel, up, down, is_conv=True) 37 | 38 | def build(self, input_shape): 39 | weight_shape = [self.kernel, self.kernel, input_shape[1], self.fmaps] 40 | init_std, self.runtime_coef = compute_runtime_coef(weight_shape, self.gain, self.lrmul) 41 | 42 | # [kernel, kernel, fmaps_in, fmaps_out] 43 | w_init = tf.random.normal(shape=weight_shape, mean=0.0, stddev=init_std) 44 | self.w = tf.Variable(w_init, name='w', trainable=True) 45 | 46 | def call(self, inputs, training=None, mask=None): 47 | x = inputs 48 | w = self.runtime_coef * self.w 49 | 50 | # actual conv 51 | if self.up: 52 | x = upsample_conv_2d(x, w, self.kernel, self.kernel, self.pad0, self.pad1, self.k) 53 | elif self.down: 54 | x = conv_downsample_2d(x, w, self.kernel, self.kernel, self.pad0, self.pad1, self.k) 55 | else: 56 | x = tf.nn.conv2d(x, w, data_format='NCHW', strides=[1, 1, 1, 1], padding='SAME') 57 | return x 58 | 59 | 
# def get_config(self):
#     config = super(Conv2D, self).get_config()
#     config.update({
#         'in_res': self.in_res,
#         'in_fmaps': self.in_fmaps,
#         'fmaps': self.fmaps,
#         'kernel': self.kernel,
#         'gain': self.gain,
#         'lrmul': self.lrmul,
#         'up': self.up,
#         'down': self.down,
#         'k': self.k,
#         'pad0': self.pad0,
#         'pad1': self.pad1,
#         'runtime_coef': self.runtime_coef,
#     })
#     return config


class ModulatedConv2D(tf.keras.layers.Layer):
    """Style-modulated 2D convolution on NCHW feature maps.

    Called with ``[x, y]``: ``x`` is the NCHW input feature map and ``y`` the
    per-sample latent that ``mod_dense`` + ``mod_bias`` turn into one scale per
    input channel. The kernel is scaled by that style, optionally demodulated,
    and applied with optional up- or down-sampling.

    Args:
        fmaps: number of output channels.
        style_fmaps: width of the style vector. It multiplies the kernel's
            input-channel axis, so it must equal the channel count of ``x``.
        kernel: spatial kernel size.
        up, down: resample around the convolution (mutually exclusive).
        demodulate: rescale each sample's modulated kernel so every output
            channel has (approximately) unit L2 norm.
        resample_kernel: resampling filter taps passed to ``compute_paddings``.
        gain, lrmul: runtime weight-scaling parameters (see ``compute_runtime_coef``).
        fused_modconv: if True, run one grouped convolution with per-sample
            kernels; otherwise scale the activations before/after a convolution
            with the shared kernel.
    """

    def __init__(self, fmaps, style_fmaps, kernel, up, down, demodulate, resample_kernel, gain, lrmul,
                 fused_modconv, **kwargs):
        super(ModulatedConv2D, self).__init__(**kwargs)
        assert not (up and down)

        self.fmaps = fmaps
        self.style_fmaps = style_fmaps
        self.kernel = kernel
        self.demodulate = demodulate
        self.up = up
        self.down = down
        self.fused_modconv = fused_modconv
        self.gain = gain
        self.lrmul = lrmul

        self.k, self.pad0, self.pad1 = compute_paddings(resample_kernel, self.kernel, up, down, is_conv=True)

        # latent -> one scale per input channel
        self.mod_dense = Dense(self.style_fmaps, gain=1.0, lrmul=1.0, name='mod_dense')
        self.mod_bias = BiasAct(lrmul=1.0, act='linear', name='mod_bias')

    def build(self, input_shape):
        # input_shape is [shape of x, shape of y]; only x's channel count is needed here
        in_fmaps = input_shape[0][1]
        weight_shape = [self.kernel, self.kernel, in_fmaps, self.fmaps]
        init_std, self.runtime_coef = compute_runtime_coef(weight_shape, self.gain, self.lrmul)

        # [kkIO], stored unscaled; runtime_coef is applied on every forward pass
        w_init = tf.random.normal(shape=weight_shape, mean=0.0, stddev=init_std)
        self.w = tf.Variable(w_init, name='w', trainable=True)

    def scale_conv_weights(self, w):
        """Compute the per-sample kernel for the latent batch ``w``.

        Returns:
            ``(weight, style, d)``: ``weight`` is the modulated (and, if enabled,
            demodulated) kernel [BkkIO]; ``style`` is the [BI] per-input-channel
            scale; ``d`` is the [BO] demodulation factor, or None when
            ``demodulate`` is False.
        """
        weight = self.runtime_coef * self.w  # [kkIO]
        weight = weight[np.newaxis]  # [1kkIO], broadcasts over the batch

        # modulation; the +1 means a zero mod_bias output leaves the kernel unchanged
        style = self.mod_dense(w)  # [BI]
        style = self.mod_bias(style) + 1.0  # [BI]
        weight *= style[:, np.newaxis, np.newaxis, :, np.newaxis]  # [BkkIO]

        # demodulation
        d = None
        if self.demodulate:
            d = tf.math.rsqrt(tf.reduce_sum(tf.square(weight), axis=[1, 2, 3]) + 1e-8)  # [BO]
            weight *= d[:, np.newaxis, np.newaxis, np.newaxis, :]  # [BkkIO]

        return weight, style, d

    def call(self, inputs, training=None, mask=None):
        x, y = inputs

        # per-sample kernel [BkkIO], plus the factors the non-fused path applies to activations
        weight, style, d = self.scale_conv_weights(y)

        if self.fused_modconv:
            # Fused: fold the minibatch into the channel axis ([1, B*I, H, W]) and stack the
            # per-sample kernels along the output axis ([k, k, I, B*O]). This relies on the
            # convolution treating input depth B*I against filter depth I as B groups.
            x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]])
            new_weight_shape = [tf.shape(weight)[1], tf.shape(weight)[2], tf.shape(weight)[3], -1]  # [kkI(BO)]
            weight = tf.transpose(weight, [1, 2, 3, 0, 4])  # [kkIBO]
            weight = tf.reshape(weight, shape=new_weight_shape)  # [kkI(BO)]
        else:
            # Not fused: modulate the activations instead and convolve with the shared [kkIO]
            # kernel. The batched [BkkIO] kernel is 5-D (unusable by a 2-D conv) and already
            # contains the demodulation factor, which is applied to the output below.
            x *= style[:, :, tf.newaxis, tf.newaxis]  # [BIhw]
            weight = self.runtime_coef * self.w  # [kkIO]

        # convolution with optional up/downsampling
        if self.up:
            x = upsample_conv_2d(x, weight, self.kernel, self.kernel, self.pad0, self.pad1, self.k)
        elif self.down:
            x = conv_downsample_2d(x, weight, self.kernel, self.kernel, self.pad0, self.pad1, self.k)
        else:
            x = tf.nn.conv2d(x, weight, data_format='NCHW', strides=[1, 1, 1, 1], padding='SAME')

        if self.fused_modconv:
            # unfold the convolution groups back into the minibatch
            x_shape = tf.shape(x)
            x = tf.reshape(x, [-1, self.fmaps, x_shape[2], x_shape[3]])
        elif self.demodulate:
            # [BOhw] not fused => demodulate the output activations
            x *= d[:, :, tf.newaxis, tf.newaxis]

        return x


class Dense(tf.keras.layers.Layer):
    """Fully connected layer without bias, using runtime weight scaling.

    All non-batch input dimensions are flattened before the matmul. The weight is
    stored with std ``1 / lrmul`` and multiplied at runtime by
    ``gain / sqrt(fan_in) * lrmul`` (see ``compute_runtime_coef``). Any bias must
    be added separately, e.g. with ``BiasAct``.
    """

    def __init__(self, fmaps, gain, lrmul, **kwargs):
        super(Dense, self).__init__(**kwargs)
        self.fmaps = fmaps
        self.gain = gain
        self.lrmul = lrmul

    def build(self, input_shape):
        fan_in = tf.reduce_prod(input_shape[1:])
        weight_shape = [fan_in, self.fmaps]
        init_std, self.runtime_coef = compute_runtime_coef(weight_shape, self.gain, self.lrmul)

        w_init = tf.random.normal(shape=weight_shape, mean=0.0, stddev=init_std)
        self.w = tf.Variable(w_init, name='w', trainable=True)

    def call(self, inputs, training=None, mask=None):
        weight = self.runtime_coef * self.w

        # flatten everything but the batch axis
        c = tf.reduce_prod(tf.shape(inputs)[1:])
        x = tf.reshape(inputs, shape=[-1, c])
        x = tf.matmul(x, weight)
        return x


class LabelEmbedding(tf.keras.layers.Layer):
    """Linear projection of label vectors to ``embed_dim`` (no bias, no runtime scaling).

    The weight is [input_shape[1], embed_dim], so inputs are presumably 2-D
    [batch, labels_dim] label vectors (e.g. one-hot) — confirm against callers.
    """

    def __init__(self, embed_dim, **kwargs):
        super(LabelEmbedding, self).__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        weight_shape = [input_shape[1], self.embed_dim]
        # tf 1.15 mean(0.0), std(1.0) default value of tf.initializers.random_normal()
        w_init = tf.random.normal(shape=weight_shape, mean=0.0, stddev=1.0)
        self.w = tf.Variable(w_init, name='w', trainable=True)

    def call(self, inputs, training=None, mask=None):
        x = tf.matmul(inputs, self.w)
        return x


##################################################################################
# Blocks
##################################################################################
class FromRGB(tf.keras.layers.Layer):
    """1x1 convolution from image channels to ``fmaps`` features, then bias + 'lrelu' (via BiasAct)."""

    def __init__(self, fmaps, **kwargs):
        super(FromRGB, self).__init__(**kwargs)
        self.fmaps = fmaps

        self.conv = Conv2D(fmaps=self.fmaps, kernel=1, up=False, down=False,
                           resample_kernel=None, gain=1.0, lrmul=1.0, name='conv')
        self.apply_bias_act = BiasAct(lrmul=1.0, act='lrelu', name='bias')

    def call(self, inputs, training=None, mask=None):
        y = self.conv(inputs)
        y = self.apply_bias_act(y)
        return y


class ToRGB(tf.keras.layers.Layer):
    """Modulated 1x1 convolution (no demodulation) from ``in_ch`` features to 3 channels, plus a linear bias.

    Called with ``[x, w]`` where ``w`` is the latent used for modulation.
    """

    def __init__(self, in_ch, **kwargs):
        super(ToRGB, self).__init__(**kwargs)
        self.in_ch = in_ch

        self.conv = ModulatedConv2D(fmaps=3, style_fmaps=in_ch, kernel=1, up=False, down=False, demodulate=False,
                                    resample_kernel=None, gain=1.0, lrmul=1.0, fused_modconv=True, name='conv')
        self.apply_bias = BiasAct(lrmul=1.0, act='linear', name='bias')

    def call(self, inputs, training=None, mask=None):
        x, w = inputs

        x = self.conv([x, w])
        x = self.apply_bias(x)
        return x


def _sincos_tables(channel, height, width, scale):
    """Validate ``channel`` and build the buffers shared by PE2d and PE2dStart.

    Returns:
        ``(pe, d_model, div_term, pos_h, pos_w)``:
        - ``pe``: a zeroed float32 [channel, H', W'] buffer, with
          H' = int(height * scale) and W' = int(width * scale);
        - ``d_model``: half the channel count;
        - ``div_term``: the [d_model / 2] frequencies, divided by ``scale``;
        - ``pos_h`` / ``pos_w``: [H', 1] / [W', 1] columns holding positions 0..H'-1 / 0..W'-1.

    Raises:
        ValueError: if ``channel`` is not a multiple of 4. Each spatial axis gets
            half the channels, and each half interleaves sin/cos, so it must be even.
    """
    if channel % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with a channel count "
                         "that is not divisible by 4 (got dim={:d})".format(channel))

    height = int(height * scale)
    width = int(width * scale)
    pe = np.zeros(shape=[channel, height, width], dtype=np.float32)

    # each spatial axis uses half of the channels
    d_model = int(channel / 2)
    div_term = np.exp(np.arange(0., d_model, 2.) * -(np.log(10000.) / d_model)) / scale
    pos_h = np.expand_dims(np.arange(0., height), axis=-1)  # [H', 1]
    pos_w = np.expand_dims(np.arange(0., width), axis=-1)  # [W', 1]
    return pe, d_model, div_term, pos_h, pos_w


def _shifted_positions(positions, shift):
    """Roll ``positions`` by ``round(shift)`` along axis 0 and add the remainder ``round(shift) - shift``."""
    rounded = round(shift)
    return np.roll(positions, rounded, 0) + (rounded - shift)


def _fill_sincos_encoding(pe, d_model, div_term, pos_h, pos_w):
    """Write the 2D sin/cos encoding into ``pe`` ([C, H, W]) in place and return it.

    Channels ``[0, d_model)`` encode ``pos_w``; channels ``[d_model, C)`` encode
    ``pos_h``. Even indices hold sin and odd indices hold cos.
    """
    height, width = pos_h.shape[0], pos_w.shape[0]
    angles_w = np.transpose(pos_w * div_term, axes=[1, 0])  # [d_model / 2, W]
    angles_h = np.transpose(pos_h * div_term, axes=[1, 0])  # [d_model / 2, H]

    # horizontal encoding: constant along the H axis
    pe[0:d_model:2, :, :] = np.tile(np.expand_dims(np.sin(angles_w), axis=1), reps=[1, height, 1])
    pe[1:d_model:2, :, :] = np.tile(np.expand_dims(np.cos(angles_w), axis=1), reps=[1, height, 1])
    # vertical encoding: constant along the W axis
    pe[d_model::2, :, :] = np.tile(np.expand_dims(np.sin(angles_h), axis=2), reps=[1, 1, width])
    pe[d_model + 1::2, :, :] = np.tile(np.expand_dims(np.cos(angles_h), axis=2), reps=[1, 1, width])
    return pe


class PE2d(tf.keras.layers.Layer):
    """Add a 2D sinusoidal positional encoding, scaled by a learnable ``gamma``, to NCHW features.

    Channels ``[0, channel/2)`` encode the horizontal position and
    ``[channel/2, channel)`` the vertical one, alternating sin (even) and cos (odd).
    ``shift_h`` / ``shift_w`` translate the encoding; see ``_shifted_positions``.
    """

    def __init__(self, channel, height, width, scale=1.0):
        super(PE2d, self).__init__()
        self.pe, self.d_model, self.div_term, self.pos_h, self.pos_w = _sincos_tables(
            channel, height, width, scale)

        # learnable strength of the encoding, initialised to 1
        self.gamma = tf.Variable(initial_value=tf.ones(shape=[1], dtype=tf.float32), trainable=True)

    def call(self, inputs, shift_h=0, shift_w=0, training=None, mask=None):
        pos_h = _shifted_positions(self.pos_h, shift_h)
        pos_w = _shifted_positions(self.pos_w, shift_w)
        # NOTE: rewrites the cached self.pe buffer in place on every call
        _fill_sincos_encoding(self.pe, self.d_model, self.div_term, pos_h, pos_w)

        x = tf.cast(inputs, dtype=tf.float32) + self.gamma * np.expand_dims(self.pe, axis=0)

        return x


class PE2dStart(tf.keras.layers.Layer):
    """Produce a batch of 2D sinusoidal positional encodings to use as a starting feature map.

    Uses the same encoding layout as PE2d, but nothing is added to the input.
    ``inputs`` is only read for its static batch size. The result is a NumPy
    float32 array of shape [batch, channel, H', W'].
    """

    def __init__(self, channel, height, width, scale=1.0):
        super(PE2dStart, self).__init__()
        # float32 buffer (previously float64), matching PE2d
        self.pe, self.d_model, self.div_term, self.pos_h, self.pos_w = _sincos_tables(
            channel, height, width, scale)

    def call(self, inputs, shift_h=0, shift_w=0, training=None, mask=None):
        pos_h = _shifted_positions(self.pos_h, shift_h)
        pos_w = _shifted_positions(self.pos_w, shift_w)
        # NOTE: rewrites the cached self.pe buffer in place on every call
        _fill_sincos_encoding(self.pe, self.d_model, self.div_term, pos_h, pos_w)

        # NOTE(review): inputs.shape[0] is the static batch size; np.tile fails if it is None
        x = np.tile(np.expand_dims(self.pe, axis=0), reps=[inputs.shape[0], 1, 1, 1])

        return x


class ConstantInput(tf.keras.layers.Layer):
    """Learnable constant [1, channel, size, size] tensor, tiled to the batch size of ``inputs``."""

    def __init__(self, channel, size=4):
        super(ConstantInput, self).__init__()

        const_init = tf.random.normal(shape=(1, channel, size, size), mean=0.0, stddev=1.0)
        self.const = tf.Variable(const_init, name='const', trainable=True)

    def call(self, inputs, training=None, mask=None):
        # dynamic batch size, so this also works when the static batch dimension is None
        batch = tf.shape(inputs)[0]
        x = tf.tile(self.const, multiples=[batch, 1, 1, 1])

        return x

##################################################################################
# etc
##################################################################################
class BiasAct(tf.keras.layers.Layer):
    """Add a learned bias and apply an activation through ``fused_bias_act``.

    Args:
        lrmul: learning-rate multiplier; the stored bias is scaled by it at call time.
        act: activation name forwarded to ``fused_bias_act`` (e.g. ``'linear'``).
    """

    def __init__(self, lrmul, act, **kwargs):
        super(BiasAct, self).__init__(**kwargs)
        self.lrmul = lrmul
        self.act = act

    def build(self, input_shape):
        # One bias entry per element of axis 1 (the channel axis for NCHW inputs).
        b_init = tf.zeros(shape=(input_shape[1],), dtype=tf.float32)
        self.b = tf.Variable(b_init, name='b', trainable=True)

    def call(self, inputs, training=None, mask=None):
        b = self.lrmul * self.b
        x = fused_bias_act(inputs, b=b, act=self.act, alpha=None, gain=None)
        return x

    # def get_config(self):
    #     config = super(BiasAct, self).get_config()
    #     config.update({
    #         'lrmul': self.lrmul,
    #         'act': self.act,
    #     })
    #     return config


class Noise(tf.keras.layers.Layer):
    """Add noise scaled by one learned strength variable (initialised to 0)."""

    def __init__(self, **kwargs):
        super(Noise, self).__init__(**kwargs)

    def build(self, input_shape):
        self.noise_strength = tf.Variable(initial_value=0.0, dtype=tf.float32, trainable=True, name='w')

    def call(self, inputs, noise=None, training=None, mask=None):
        x_shape = tf.shape(inputs)

        # noise: [1, 1, x_shape[2], x_shape[3]] or None.
        # When None, fresh noise of shape [x_shape[0], 1, x_shape[2], x_shape[3]] is drawn,
        # so it is shared across axis 1 via broadcasting.
        if noise is None:
            noise = tf.random.normal(shape=(x_shape[0], 1, x_shape[2], x_shape[3]), dtype=tf.float32)

        x = inputs + noise * self.noise_strength
        return x

    def get_config(self):
        config = super(Noise, self).get_config()
        config.update({})
        return config


class MinibatchStd(tf.keras.layers.Layer):
    """Append minibatch standard-deviation feature maps along axis 1.

    Samples are split into groups of ``min(group_size, batch)``. The batch size must be
    divisible by that group size for the reshape to succeed. The per-group std is
    averaged into ``num_new_features`` values, which are tiled back to every sample
    and spatial position.
    """

    def __init__(self, group_size, num_new_features, **kwargs):
        super(MinibatchStd, self).__init__(**kwargs)
        self.group_size = group_size
        self.num_new_features = num_new_features

    def call(self, inputs, training=None, mask=None):
        s = tf.shape(inputs)
        group_size = tf.minimum(self.group_size, s[0])

        # [G, M, n, c, s2, s3]: G = group size, M = groups, n = new features, c = s1 // n.
        y = tf.reshape(inputs, [group_size, -1, self.num_new_features, s[1] // self.num_new_features, s[2], s[3]])
        y = tf.cast(y, tf.float32)
        y -= tf.reduce_mean(y, axis=0, keepdims=True)         # center within each group
        y = tf.reduce_mean(tf.square(y), axis=0)              # variance: [M, n, c, s2, s3]
        y = tf.sqrt(y + 1e-8)                                 # std
        y = tf.reduce_mean(y, axis=[2, 3, 4], keepdims=True)  # [M, n, 1, 1, 1]
        y = tf.reduce_mean(y, axis=[2])                       # [M, n, 1, 1]
        y = tf.cast(y, inputs.dtype)
        y = tf.tile(y, [group_size, 1, s[2], s[3]])           # [N, n, s2, s3]

        x = tf.concat([inputs, y], axis=1)
        return x

    def get_config(self):
        config = super(MinibatchStd, self).get_config()
        config.update({
            'group_size': self.group_size,
            'num_new_features': self.num_new_features,
        })
        return config


def torch_normalization(x):
    """Scale [0, 255] values to [0, 1], then standardize each last-axis channel.

    Uses the standard ImageNet mean/std constants below. The input is not modified in
    place: the old ``x /= 255.`` mutated NumPy inputs.
    """
    x = x / 255.

    mean = tf.constant([0.485, 0.456, 0.406], dtype=x.dtype)
    std = tf.constant([0.229, 0.224, 0.225], dtype=x.dtype)

    # Broadcasts over the channel (last) axis; the values match a per-channel split/concat.
    return (x - mean) / std


def inception_processing(filename):
    """Load an image file into a [299, 299, 3] float tensor normalized by ``torch_normalization``.

    The image is resized to 256x256 and then to 299x299 (bicubic, antialiased). The
    name suggests input for an Inception feature extractor -- confirm against callers.
    """
    x = tf.io.read_file(filename)
    img = tf.image.decode_jpeg(x, channels=3, dct_method='INTEGER_ACCURATE')
    img = tf.image.resize(img, [256, 256], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
    img = tf.image.resize(img, [299, 299], antialias=True, method=tf.image.ResizeMethod.BICUBIC)

    img = torch_normalization(img)
    # img = tf.transpose(img, [2, 0, 1])
    return img

# --------------------------------------------------------------------------------
# /utils.py
# --------------------------------------------------------------------------------
import numpy as np
import os
import cv2

import tensorflow as tf
from glob import glob


class Image_data:
    """Produce training samples: a preprocessed image plus random latent and label vectors."""

    def __init__(self, img_size, z_dim, labels_dim, dataset_path):
        self.img_size = img_size
        self.z_dim = z_dim
        self.labels_dim = labels_dim
        self.dataset_path = dataset_path

    def image_processing(self, filename):
        """Read one image file and sample its latent and label vectors.

        Returns:
            ``(img, latent, labels)``:
            - ``img``: [3, img_size, img_size], in [-1, 1], randomly flipped left-right.
            - ``latent``: shape [z_dim], drawn from a standard normal.
            - ``labels``: shape [labels_dim], drawn from a standard normal.
        """
        x = tf.io.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=3, dct_method='INTEGER_ACCURATE')
        img = tf.image.resize(x_decode, [self.img_size, self.img_size], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
        img = preprocess_fit_train_image(img)

        latent = tf.random.normal(shape=(self.z_dim,), dtype=tf.float32)
        labels = tf.random.normal((self.labels_dim,), dtype=tf.float32)

        return img, latent, labels

    def preprocess(self):
        """Collect the ``*.png`` and ``*.jpg`` files directly under ``dataset_path``."""
        self.train_images = glob(os.path.join(self.dataset_path, '*.png')) + glob(os.path.join(self.dataset_path, '*.jpg'))


def adjust_dynamic_range(images, range_in, range_out, out_dtype):
    """Linearly map ``range_in`` onto ``range_out``, clip to ``range_out``, and cast to ``out_dtype``."""
    scale = (range_out[1] - range_out[0]) / (range_in[1] - range_in[0])
    bias = range_out[0] - range_in[0] * scale
    images = images * scale + bias
    images = tf.clip_by_value(images, range_out[0], range_out[1])
    images = tf.cast(images, dtype=out_dtype)
    return images


def random_flip_left_right(images):
    """Reverse axis 1 (width of an [H, W, C] image) for the whole image with probability 0.5."""
    s = tf.shape(images)
    mask = tf.random.uniform([1, 1, 1], 0.0, 1.0)
    mask = tf.tile(mask, [s[0], s[1], s[2]])  # [h, w, c]
    images = tf.where(mask < 0.5, images, tf.reverse(images, axis=[1]))
    return images


def preprocess_fit_train_image(images):
    """[H, W, C] in [0, 255] -> [C, H, W] float32 in [-1, 1], with a random left-right flip."""
    images = adjust_dynamic_range(images, range_in=(0.0, 255.0), range_out=(-1.0, 1.0), out_dtype=tf.dtypes.float32)
    images = random_flip_left_right(images)
    images = tf.transpose(images, [2, 0, 1])

    return images


def preprocess_image(images):
    """[H, W, C] in [0, 255] -> [C, H, W] float32 in [-1, 1] (no augmentation)."""
    images = adjust_dynamic_range(images, range_in=(0.0, 255.0), range_out=(-1.0, 1.0), out_dtype=tf.dtypes.float32)
    images = tf.transpose(images, [2, 0, 1])

    return images


def postprocess_images(images):
    """[N, C, H, W] in [-1, 1] -> [N, H, W, C] uint8 (the final cast truncates, it does not round)."""
    images = adjust_dynamic_range(images, range_in=(-1.0, 1.0), range_out=(0.0, 255.0), out_dtype=tf.dtypes.float32)
    images = tf.transpose(images, [0, 2, 3, 1])
    images = tf.cast(images, dtype=tf.dtypes.uint8)
    return images


def merge_batch_images(images, res, rows, cols):
    """Tile ``rows * cols`` images of size [res, res, 3] into one uint8 canvas, row-major.

    Raises:
        ValueError: if the batch size is not exactly ``rows * cols``.
    """
    batch_size = images.shape[0]
    # An explicit check, because `assert` is stripped under `python -O`.
    if rows * cols != batch_size:
        raise ValueError('rows * cols ({}) must equal the batch size ({})'.format(rows * cols, batch_size))
    canvas = np.zeros(shape=[res * rows, res * cols, 3], dtype=np.uint8)
    for row in range(rows):
        y_start = row * res
        for col in range(cols):
            x_start = col * res
            index = col + row * cols
            canvas[y_start:y_start + res, x_start:x_start + res, :] = images[index, :, :, :]
    return canvas


def load_images(image_path, img_width, img_height, img_channel):
    """Read one image and return it as a [1, C, img_height, img_width] float32 array in [-1, 1].

    ``C`` is 1 when ``img_channel == 1`` (grayscale); otherwise 3 (RGB).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    # from PIL import Image
    if img_channel == 1:
        img = cv2.imread(image_path, flags=cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)

    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError('Could not read image: {}'.format(image_path))

    if img_channel == 1:
        # IMREAD_GRAYSCALE yields [H, W]; tf.image.resize and the CHW transpose need [H, W, C].
        img = np.expand_dims(img, axis=-1)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # img = cv2.resize(img, dsize=(img_width, img_height))
    img = tf.image.resize(img, [img_height, img_width], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
    img = preprocess_image(img)  # [C, H, W]

    img = np.expand_dims(img, axis=0)  # [1, C, H, W]

    return img


def save_images(images, size, image_path):
    """Convert [N, C, H, W] images in [-1, 1] to uint8, tile them into a grid, and write the file.

    ``size`` is ``[rows, cols]``. Returns the result of ``cv2.imwrite``.
    """
    return imsave(postprocess_images(images), size, image_path)


def imsave(images, size, path):
    """Tile [N, H, W, 3] RGB images into a ``size``-shaped grid and write it with OpenCV (BGR)."""
    images = merge(images, size)
    images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)

    return cv2.imwrite(path, images)


def merge(images, size):
    """Place [N, H, W, 3] images into a float canvas of ``size[0]`` rows by ``size[1]`` columns, row-major."""
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[h*j:h*(j+1), w*i:w*(i+1), :] = image

    return img
def str2bool(x):
    """Parse a flag string: True only for ``'true'``, compared case-insensitively.

    This is an exact comparison. ``x.lower() in ('true')`` was a substring test,
    because ``('true')`` is a string rather than a tuple, so it accepted '', 't', 'ue', ...
    """
    return x.lower() == 'true'


def check_folder(log_dir):
    """Create ``log_dir`` (including missing parents) if nothing exists at that path, then return it."""
    if not os.path.exists(log_dir):
        # exist_ok avoids a crash if another process creates the directory in between.
        os.makedirs(log_dir, exist_ok=True)
    return log_dir


def filter_resolutions_featuremaps(resolutions, featuremaps, res):
    """Truncate both lists after the position of ``res`` in ``resolutions`` (inclusive).

    Raises:
        ValueError: if ``res`` is not in ``resolutions``.
    """
    index = resolutions.index(res)
    filtered_resolutions = resolutions[:index + 1]
    filtered_featuremaps = featuremaps[:index + 1]
    return filtered_resolutions, filtered_featuremaps


def pytorch_xavier_weight_factor(gain=0.02):
    """Return ``(gain ** 2, 'fan_avg')`` -- a variance-scaling factor and fan mode."""
    factor = gain * gain
    mode = 'fan_avg'

    return factor, mode


def pytorch_kaiming_weight_factor(a=0.0, activation_function='relu'):
    """Return ``(gain ** 2, 'fan_in')`` for the given activation.

    The gains (sqrt(2) for relu, sqrt(2 / (1 + a^2)) for leaky_relu, 5/3 for tanh,
    1 otherwise) match ``torch.nn.init.calculate_gain``.
    """
    if activation_function == 'relu':
        gain = np.sqrt(2.0)
    elif activation_function == 'leaky_relu':
        gain = np.sqrt(2.0 / (1 + a ** 2))
    elif activation_function == 'tanh':
        gain = 5.0 / 3
    else:
        gain = 1.0

    factor = gain * gain
    mode = 'fan_in'

    return factor, mode


def automatic_gpu_usage():
    """Enable memory growth on every visible GPU and print the physical/logical GPU counts."""
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)


def multiple_gpu_usage(num_virtual_gpus=2, memory_limit=4096):
    """Split the first physical GPU into ``num_virtual_gpus`` logical GPUs.

    Args:
        num_virtual_gpus: number of logical devices to create (default 2).
        memory_limit: memory per logical device, in MB (default 4096).
    """
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        # Defaults: 2 virtual GPUs with 4096 MB each (the old comment wrongly said 1GB).
        try:
            tf.config.experimental.set_virtual_device_configuration(
                gpus[0],
                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=memory_limit)
                 for _ in range(num_virtual_gpus)])
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPU,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            # Virtual devices must be set before GPUs have been initialized
            print(e)


def get_batch_sizes(gpu_num):
    """Return an OrderedDict mapping resolution (4..1024) to the per-GPU batch size.

    ``gpu_num`` values outside 1..9 (including 0 or negatives) use the smallest table.
    """
    from collections import OrderedDict

    resolutions = (4, 8, 16, 32, 64, 128, 256, 512, 1024)
    if gpu_num == 1:
        sizes = (256, 256, 128, 64, 32, 16, 8, 4, 4)
    elif gpu_num in (2, 3):
        sizes = (128, 128, 64, 32, 16, 8, 4, 4, 4)
    elif gpu_num in (4, 5, 6):
        sizes = (64, 64, 32, 16, 8, 4, 4, 4, 4)
    elif gpu_num in (7, 8, 9):
        sizes = (32, 32, 16, 8, 4, 4, 4, 4, 4)
    else:  # >= 10
        sizes = (16, 16, 8, 4, 2, 2, 2, 2, 2)

    return OrderedDict(zip(resolutions, sizes))


def multi_gpu_loss(x, global_batch_size):
    """Average ``x`` over all non-batch axes, then sum over the batch and divide by ``global_batch_size``."""
    ndim = len(x.shape)
    no_batch_axis = list(range(1, ndim))
    x = tf.reduce_mean(x, axis=no_batch_axis)
    x = tf.reduce_sum(x) / global_batch_size

    return x


# Imported at module level: inside the class it displaced the docstring and left an
# `Any` class attribute that shadowed the key 'Any' during attribute access.
from typing import Any


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]

# --------------------------------------------------------------------------------