├── .DS_Store
├── FFHQ_mu_cov.pickle
├── LICENSE
├── README.md
├── StyleGAN2.py
├── assets
│   ├── .DS_Store
│   ├── sample_2.gif
│   ├── style_mixing.png
│   ├── teaser.png
│   ├── truncation_trick.png
│   └── uncurated.png
├── cuda
│   ├── custom_ops.py
│   ├── fused_bias_act.cu
│   ├── fused_bias_act.py
│   ├── upfirdn_2d.cu
│   └── upfirdn_2d.py
├── generate_video.py
├── layers.py
├── main.py
├── networks.py
├── ops.py
└── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/.DS_Store
--------------------------------------------------------------------------------
/FFHQ_mu_cov.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/FFHQ_mu_cov.pickle
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Spatial unbiased GANs — Simple TensorFlow Implementation [[Paper]](https://arxiv.org/abs/2108.01285)
2 | ### : Toward Spatially Unbiased Generative Models (ICCV 2021)
3 |
4 |
5 |

6 |
7 |
8 | > **Abstract** *Recent image generation models show remarkable generation performance. However, they mirror strong location preference in datasets, which we call **spatial bias**. Therefore, generators render poor samples at unseen locations and scales. We argue that the generators rely on their implicit positional encoding to render spatial content. From our observations, the generator’s implicit positional encoding is translation-variant, making the generator spatially biased. To address this issue, we propose injecting explicit positional encoding at each scale of the generator. By learning the spatially unbiased generator, we facilitate the robust use of generators in multiple tasks, such as GAN inversion, multi-scale generation, generation of arbitrary sizes and aspect ratios. Furthermore, we show that our method can also be applied to denoising diffusion probabilistic models.*
9 |
10 | ## Requirements
11 | * `TensorFlow >= 2.x`
12 |
13 | ## Usage
14 | ```
15 | ├── dataset
16 |     └── YOUR_DATASET_NAME
17 |         ├── 000001.jpg
18 |         ├── 000002.png
19 |         └── ...
20 | ```
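A quick way to confirm that a dataset folder matches the layout above before training is to list its image files. This is a minimal sketch, not code from this repository: `list_dataset_images` and the accepted extensions are assumptions for illustration, and the repository's own loader (`Image_data`) may collect files differently.

```python
import os

# Hypothetical helper; the extension list is an assumption, not taken from the repository.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def list_dataset_images(dataset_name, root='./dataset'):
    """Return sorted paths of image files directly under <root>/<dataset_name>."""
    folder = os.path.join(root, dataset_name)
    if not os.path.isdir(folder):
        raise FileNotFoundError('dataset folder not found: {}'.format(folder))
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )
```

Calling `list_dataset_images('YOUR_DATASET_NAME')` from the repository root should return one path per training image; an empty list suggests the images are in a nested folder or use an extension outside the assumed set.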
21 |
22 | ### Train
23 | ```
24 | > python main.py --dataset FFHQ --phase train --img_size 256 --batch_size 4 --n_total_image 6400
25 | ```
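`--n_total_image` is given in thousands of images: `StyleGAN2.py` multiplies it by 1000 and divides by the global batch size to get the iteration count. The arithmetic can be sketched as below. The function name is hypothetical, and the global batch size of 32 is an assumption inferred from the Results section (8 GPUs, 4 images each); how `main.py` derives it from `--batch_size` is not shown here.

```python
def training_iterations(n_total_image_k, global_batch_size):
    """Illustrative mirror of the iteration arithmetic in StyleGAN2.__init__."""
    n_total_image = n_total_image_k * 1000       # the flag counts thousands of images
    return n_total_image // global_batch_size    # one iteration consumes one global batch


print(training_iterations(6400, 32))  # 200000
```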
26 |
27 | ### Generate Video
28 | ```
29 | > python generate_video.py
30 | ```
31 |
32 | ## Results
33 | * **FID: 3.81** (6.4M images (200k iterations), 8 GPUs, batch size 4 per GPU)
34 | * FID reported in the paper: **6.75**
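
For context, FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images; `calculate_FID` in `StyleGAN2.py` computes it from feature means and covariances. The sketch below shows the same formula in plain NumPy. It is an illustration under assumptions, not the repository's code: `frechet_distance` is a hypothetical name, and the trace of the matrix square root is obtained from eigenvalues instead of `scipy.linalg.sqrtm`.

```python
import numpy as np


def frechet_distance(mu1, cov1, mu2, cov2):
    """Squared Frechet distance between N(mu1, cov1) and N(mu2, cov2)."""
    mu1 = np.asarray(mu1, dtype=np.float64)
    mu2 = np.asarray(mu2, dtype=np.float64)
    cov1 = np.asarray(cov1, dtype=np.float64)
    cov2 = np.asarray(cov2, dtype=np.float64)

    # Squared distance between the two means.
    mean_term = np.square(mu1 - mu2).sum()

    # For positive semi-definite covariances, Tr(sqrt(cov1 @ cov2)) equals the
    # sum of the square roots of the eigenvalues of cov1 @ cov2.
    eigvals = np.linalg.eigvals(cov1 @ cov2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()

    return float(mean_term + np.trace(cov1) + np.trace(cov2) - 2.0 * tr_sqrt)
```

Identical statistics give a distance of zero, and with equal covariances the result reduces to the squared distance between the means.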
35 | ### Video
36 |
37 |

38 |
39 |
40 | ### Uncurated
41 |
42 |

43 |
44 |
45 | ### Style mixing
46 | * Style mixing quality is worse than with StyleGAN2.
47 |
48 |

49 |
50 |
51 | ### Truncation trick
52 |
53 |

54 |
55 |
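The truncation trick, as described for StyleGAN, pulls each intermediate latent `w` toward the average latent `w_avg`, trading diversity for sample quality. `StyleGAN2.py` renders its comparison canvas by calling the generator with `truncation_psi` set to 0.0, 0.5, 0.7 and 1.0. A minimal NumPy sketch of that interpolation follows; `truncate` is a hypothetical helper, and the repository's generator, which is not shown in this section, may apply truncation differently (for example only to some layers).

```python
import numpy as np


def truncate(w, w_avg, truncation_psi):
    """Interpolate latents toward the average latent.

    truncation_psi=1.0 leaves w unchanged; truncation_psi=0.0 returns w_avg.
    """
    w = np.asarray(w, dtype=np.float32)
    w_avg = np.asarray(w_avg, dtype=np.float32)
    return w_avg + truncation_psi * (w - w_avg)
```

With `truncation_psi=0.0` every sample collapses to the image of the average latent, which is why the first row of such a canvas typically shows the same face repeated.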
56 | ## Reference
57 | * [Official PyTorch](https://github.com/jychoi118/toward_spatial_unbiased)
58 | * [StyleGAN2-Tensorflow](https://github.com/moono/stylegan2-tf-2.x)
59 |
60 | ## Author
61 | [Junho Kim](http://bit.ly/jhkim_resume)
62 |
--------------------------------------------------------------------------------
/StyleGAN2.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | import time
3 | from tensorflow.python.data.experimental import AUTOTUNE
4 | from networks import *
5 | import PIL.Image
6 | import scipy.linalg
7 | import pickle
8 | automatic_gpu_usage()
9 |
10 | class Inception_V3(tf.keras.Model):
11 | def __init__(self, name='Inception_V3'):
12 | super(Inception_V3, self).__init__(name=name)
13 |
14 | # tf.keras.backend.image_data_format = 'channels_first'
15 | self.inception_v3_preprocess = tf.keras.applications.inception_v3.preprocess_input
16 | self.inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights='imagenet', include_top=False, pooling='avg')
17 | self.inception_v3.trainable = False
18 |
19 | def torch_normalization(self, x):
20 | x /= 255.
21 |
22 | r, g, b = tf.split(axis=-1, num_or_size_splits=3, value=x)
23 |
24 | mean = [0.485, 0.456, 0.406]
25 | std = [0.229, 0.224, 0.225]
26 |
27 | x = tf.concat(axis=-1, values=[
28 | (r - mean[0]) / std[0],
29 | (g - mean[1]) / std[1],
30 | (b - mean[2]) / std[2]
31 | ])
32 |
33 | return x
34 |
35 | # @tf.function
36 | def call(self, x, training=False, mask=None):
37 | # x = self.torch_normalization(x)
38 | x = self.inception_v3(x, training=training)
39 |
40 | return x
41 |
42 | def inference_feat(self, x, training=False):
43 | inception_real_img = adjust_dynamic_range(x, range_in=(-1.0, 1.0), range_out=(0.0, 255.0), out_dtype=tf.float32)
44 | inception_real_img = tf.image.resize(inception_real_img, [299, 299], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
45 | inception_real_img = self.torch_normalization(inception_real_img)
46 |
47 | inception_feat = self.inception_v3(inception_real_img, training=training)
48 |
49 | return inception_feat
50 |
51 | class StyleGAN2():
52 | def __init__(self, t_params, strategy):
53 | super(StyleGAN2, self).__init__()
54 |
55 | self.model_name = 'StyleGAN2'
56 | self.phase = t_params['phase']
57 | self.checkpoint_dir = t_params['checkpoint_dir']
58 | self.result_dir = t_params['result_dir']
59 | self.log_dir = t_params['log_dir']
60 | self.sample_dir = t_params['sample_dir']
61 | self.dataset_name = t_params['dataset']
62 | self.config = t_params['config']
63 |
64 | self.n_total_image = t_params['n_total_image'] * 1000
65 |
66 | self.strategy = strategy
67 | self.batch_size = t_params['batch_size']
68 | self.each_batch_size = t_params['batch_size'] // t_params['NUM_GPUS']
69 |
70 | self.NUM_GPUS = t_params['NUM_GPUS']
71 | self.iteration = self.n_total_image // self.batch_size
72 |
73 | self.n_samples = min(t_params['batch_size'], t_params['n_samples'])
74 | self.n_test = t_params['n_test']
75 | self.img_size = t_params['img_size']
76 |
77 | self.log_template = 'step [{}/{}]: elapsed: {:.2f}s, d_loss: {:.3f}, g_loss: {:.3f}, r1_reg: {:.3f}, pl_reg: {:.3f}, fid: {:.2f}, best_fid: {:.2f}, best_fid_iter: {}'
78 |
79 | self.lazy_regularization = t_params['lazy_regularization']
80 | self.print_freq = t_params['print_freq']
81 | self.save_freq = t_params['save_freq']
82 |
83 | self.r1_gamma = 10.0
84 |
85 | # setup optimizer params
86 | self.g_params = t_params['g_params']
87 |
88 | self.d_params = t_params['d_params']
89 | self.g_opt = t_params['g_opt']
90 | self.d_opt = t_params['d_opt']
91 | self.g_opt = self.set_optimizer_params(self.g_opt)
92 | self.d_opt = self.set_optimizer_params(self.d_opt)
93 |
94 | self.pl_minibatch_shrink = 2
95 | self.pl_decay = 0.01
96 | self.pl_weight = float(self.pl_minibatch_shrink)
97 | self.pl_denorm = tf.math.rsqrt(float(self.img_size) * float(self.img_size))
98 | self.pl_mean = tf.Variable(initial_value=0.0, name='pl_mean', trainable=False,
99 | synchronization=tf.VariableSynchronization.ON_READ,
100 | aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA
101 | )
102 |
103 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
104 | check_folder(self.sample_dir)
105 |
106 | self.checkpoint_dir = os.path.join(self.checkpoint_dir, self.model_dir)
107 | check_folder(self.checkpoint_dir)
108 |
109 | self.log_dir = os.path.join(self.log_dir, self.model_dir)
110 | check_folder(self.log_dir)
111 |
112 |
113 |
114 | dataset_path = './dataset'
115 | self.dataset_path = os.path.join(dataset_path, self.dataset_name)
116 |
117 | print(self.dataset_path)
118 |
119 | if os.path.exists('{}_mu_cov.pickle'.format(self.dataset_name)):
120 | with open('{}_mu_cov.pickle'.format(self.dataset_name), 'rb') as f:
121 | self.real_mu, self.real_cov = pickle.load(f)
122 | self.real_cache = True
123 | print("Pickle load success !!!")
124 | else:
125 | print("Pickle load fail !!!")
126 | self.real_cache = False
127 |
128 | self.fid_samples_num = 10000
129 | print()
130 |
131 | physical_gpus = tf.config.experimental.list_physical_devices('GPU')
132 | logical_gpus = tf.config.experimental.list_logical_devices('GPU')
133 | print(len(physical_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
134 | print("Each batch size : ", self.each_batch_size)
135 | print("Global batch size : ", self.batch_size)
136 | print("Target image size : ", self.img_size)
137 | print("Print frequency : ", self.print_freq)
138 | print("Save frequency : ", self.save_freq)
139 |
140 | print("TF Version :", tf.__version__)
141 |
142 | def set_optimizer_params(self, params):
143 | if self.lazy_regularization:
144 | mb_ratio = params['reg_interval'] / (params['reg_interval'] + 1)
145 | params['learning_rate'] = params['learning_rate'] * mb_ratio
146 | params['beta1'] = params['beta1'] ** mb_ratio
147 | params['beta2'] = params['beta2'] ** mb_ratio
148 | return params
149 |
150 | ##################################################################################
151 | # Model
152 | ##################################################################################
153 | def build_model(self):
154 | if self.phase == 'train':
155 | """ Input Image"""
156 | img_class = Image_data(self.img_size, self.g_params['z_dim'], self.g_params['labels_dim'], self.dataset_path)
157 | img_class.preprocess()
158 |
159 | dataset_num = len(img_class.train_images)
160 | if dataset_num > 10000:
161 | self.fid_samples_num = 50000
162 | print("Dataset number : ", dataset_num)
163 | print()
164 |
165 | dataset_slice = tf.data.Dataset.from_tensor_slices(img_class.train_images)
166 |
167 | gpu_device = '/gpu:0'
168 |
169 | dataset_iter = dataset_slice.shuffle(buffer_size=dataset_num, reshuffle_each_iteration=True).repeat()
170 | dataset_iter = dataset_iter.map(map_func=img_class.image_processing, num_parallel_calls=AUTOTUNE).batch(self.batch_size, drop_remainder=True)
171 | dataset_iter = dataset_iter.prefetch(buffer_size=AUTOTUNE)
172 | dataset_iter = self.strategy.experimental_distribute_dataset(dataset_iter)
173 |
174 | img_slice = dataset_slice.shuffle(buffer_size=dataset_num, reshuffle_each_iteration=True, seed=777)
175 | img_slice = img_slice.map(map_func=inception_processing, num_parallel_calls=AUTOTUNE).batch(self.batch_size, drop_remainder=False)
176 | img_slice = img_slice.prefetch(buffer_size=AUTOTUNE)
177 | self.fid_img_slice = self.strategy.experimental_distribute_dataset(img_slice)
178 |
179 | self.dataset_iter = iter(dataset_iter)
180 |
181 |
182 | """ Network """
183 | self.generator = Generator(self.g_params, name='Generator')
184 | self.discriminator = Discriminator(self.d_params, name='Discriminator')
185 | self.g_clone = Generator(self.g_params, name='Generator')
186 | self.inception_model = Inception_V3()
187 |
188 | """ Finalize model (build) """
189 | test_latent = np.ones((1, self.g_params['z_dim']), dtype=np.float32)
190 | test_labels = np.ones((1, self.g_params['labels_dim']), dtype=np.float32)
191 | test_images = np.ones((1, 3, self.img_size, self.img_size), dtype=np.float32)
192 | test_images_inception = np.ones((1, 299, 299, 3), dtype=np.float32)
193 |
194 | _, __ = self.generator([test_latent, test_labels], training=False)
195 | _ = self.discriminator([test_images, test_labels], training=False)
196 | _, __ = self.g_clone([test_latent, test_labels], training=False)
197 | _ = self.inception_model(test_images_inception)
198 |
199 | # Copying g_clone
200 | self.g_clone.set_weights(self.generator.get_weights())
201 |
202 | """ Optimizer """
203 | self.d_optimizer = tf.keras.optimizers.Adam(self.d_opt['learning_rate'],
204 | beta_1=self.d_opt['beta1'],
205 | beta_2=self.d_opt['beta2'],
206 | epsilon=self.d_opt['epsilon'])
207 | self.g_optimizer = tf.keras.optimizers.Adam(self.g_opt['learning_rate'],
208 | beta_1=self.g_opt['beta1'],
209 | beta_2=self.g_opt['beta2'],
210 | epsilon=self.g_opt['epsilon'])
211 |
212 | """ Checkpoint """
213 | self.ckpt = tf.train.Checkpoint(generator=self.generator,
214 | g_clone=self.g_clone,
215 | discriminator=self.discriminator,
216 | g_optimizer=self.g_optimizer,
217 | d_optimizer=self.d_optimizer)
218 | self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=2)
219 | self.start_iteration = 0
220 |
221 | if self.manager.latest_checkpoint:
222 | self.ckpt.restore(self.manager.latest_checkpoint).expect_partial()
223 | self.start_iteration = int(self.manager.latest_checkpoint.split('-')[-1])
224 | print('Latest checkpoint restored!!')
225 | print('start iteration : ', self.start_iteration)
226 | else:
227 | print('Not restoring from saved checkpoint')
228 |
229 | else:
230 | """ Test """
231 | """ Network """
232 | self.g_clone = Generator(self.g_params, name='Generator')
233 | self.discriminator = Discriminator(self.d_params, name='Discriminator')
234 |
235 | """ Finalize model (build) """
236 | test_latent = np.ones((1, self.g_params['z_dim']), dtype=np.float32)
237 | test_labels = np.ones((1, self.g_params['labels_dim']), dtype=np.float32)
238 | test_images = np.ones((1, 3, self.img_size, self.img_size), dtype=np.float32)
239 | _ = self.discriminator([test_images, test_labels], training=False)
240 | _, __ = self.g_clone([test_latent, test_labels], training=False)
241 |
242 | """ Checkpoint """
243 | self.ckpt = tf.train.Checkpoint(g_clone=self.g_clone, discriminator=self.discriminator)
244 | self.manager = tf.train.CheckpointManager(self.ckpt, self.checkpoint_dir, max_to_keep=2)
245 |
246 | if self.manager.latest_checkpoint:
247 | self.ckpt.restore(self.manager.latest_checkpoint).expect_partial()
248 | print('Latest checkpoint restored!!')
249 | else:
250 | print('Not restoring from saved checkpoint')
251 |
252 | def d_train_step(self, z, real_images, labels):
253 | with tf.GradientTape() as d_tape:
254 | # forward pass
255 | fake_images, _ = self.generator([z, labels], training=True)
256 | real_scores = self.discriminator([real_images, labels], training=True)
257 | fake_scores = self.discriminator([fake_images, labels], training=True)
258 |
259 | # gan loss
260 | d_adv_loss = tf.math.softplus(fake_scores)
261 | d_adv_loss += tf.math.softplus(-real_scores)
262 | d_adv_loss = multi_gpu_loss(d_adv_loss, global_batch_size=self.batch_size)
263 |
264 | d_loss = d_adv_loss
265 |
266 | d_gradients = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
267 | self.d_optimizer.apply_gradients(zip(d_gradients, self.discriminator.trainable_variables))
268 |
269 | return d_loss, d_adv_loss
270 |
271 | def d_reg_train_step(self, z, real_images, labels):
272 | with tf.GradientTape() as d_tape:
273 | # forward pass
274 | fake_images, _ = self.generator([z, labels], training=True)
275 | real_scores = self.discriminator([real_images, labels], training=True)
276 | fake_scores = self.discriminator([fake_images, labels], training=True)
277 |
278 | # gan loss
279 | d_adv_loss = tf.math.softplus(fake_scores)
280 | d_adv_loss += tf.math.softplus(-real_scores)
281 |
282 | # R1 gradient penalty on real images
283 | with tf.GradientTape() as p_tape:
284 | p_tape.watch([real_images, labels])
285 | real_loss = tf.reduce_sum(self.discriminator([real_images, labels], training=True))
286 |
287 | real_grads = p_tape.gradient(real_loss, real_images)
288 | r1_penalty = tf.reduce_sum(tf.math.square(real_grads), axis=[1, 2, 3])
289 | r1_penalty = tf.expand_dims(r1_penalty, axis=1)
290 | r1_penalty = r1_penalty * self.d_opt['reg_interval']
291 |
292 | # combine
293 | d_adv_loss += r1_penalty * (0.5 * self.r1_gamma)
294 | d_adv_loss = multi_gpu_loss(d_adv_loss, global_batch_size=self.batch_size)
295 |
296 | d_loss = d_adv_loss
297 |
298 | d_gradients = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
299 | self.d_optimizer.apply_gradients(zip(d_gradients, self.discriminator.trainable_variables))
300 |
301 | r1_penalty = multi_gpu_loss(r1_penalty, global_batch_size=self.batch_size)
302 |
303 | return d_loss, d_adv_loss, r1_penalty
304 |
305 | def g_train_step(self, z, labels):
306 | with tf.GradientTape() as g_tape:
307 | # forward pass
308 | fake_images, _ = self.generator([z, labels], training=True)
309 | fake_scores = self.discriminator([fake_images, labels], training=True)
310 |
311 | # gan loss
312 | g_adv_loss = tf.math.softplus(-fake_scores)
313 | g_adv_loss = multi_gpu_loss(g_adv_loss, global_batch_size=self.batch_size)
314 |
315 | g_loss = g_adv_loss
316 |
317 | g_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables)
318 | self.g_optimizer.apply_gradients(zip(g_gradients, self.generator.trainable_variables))
319 |
320 | return g_loss, g_adv_loss
321 |
322 | def g_reg_train_step(self, z, labels):
323 | with tf.GradientTape() as g_tape:
324 | # forward pass
325 | fake_images, w_broadcasted = self.generator([z, labels], training=True)
326 | fake_scores = self.discriminator([fake_images, labels], training=True)
327 |
328 | # gan loss
329 | g_adv_loss = tf.math.softplus(-fake_scores)
330 |
331 | # path length regularization
332 | pl_reg = self.path_regularization(pl_minibatch_shrink=self.pl_minibatch_shrink)
333 |
334 | # combine
335 | g_adv_loss += pl_reg
336 | g_adv_loss = multi_gpu_loss(g_adv_loss, global_batch_size=self.batch_size)
337 |
338 | g_loss = g_adv_loss
339 |
340 | g_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables)
341 | self.g_optimizer.apply_gradients(zip(g_gradients, self.generator.trainable_variables))
342 |
343 | pl_reg = multi_gpu_loss(pl_reg, global_batch_size=self.batch_size)
344 |
345 | return g_loss, g_adv_loss, pl_reg
346 |
347 | def path_regularization(self, pl_minibatch_shrink=2):
348 | # path length regularization
349 | # Compute |J*y|.
350 |
351 | pl_minibatch = tf.maximum(1, tf.math.floordiv(self.each_batch_size, pl_minibatch_shrink))
352 | pl_z = tf.random.normal(shape=[pl_minibatch, self.g_params['z_dim']], dtype=tf.float32)
353 | pl_labels = tf.random.normal(shape=[pl_minibatch, self.g_params['labels_dim']], dtype=tf.float32)
354 |
355 | with tf.GradientTape() as pl_tape:
356 | pl_tape.watch([pl_z, pl_labels])
357 | pl_fake_images, pl_w_broadcasted = self.generator([pl_z, pl_labels], training=True)
358 |
359 | pl_noise = tf.random.normal(tf.shape(pl_fake_images), mean=0.0, stddev=1.0, dtype=tf.float32) * self.pl_denorm
360 | pl_noise_applied = tf.reduce_sum(pl_fake_images * pl_noise)
361 |
362 | pl_grads = pl_tape.gradient(pl_noise_applied, pl_w_broadcasted)
363 | pl_lengths = tf.math.sqrt(tf.reduce_mean(tf.reduce_sum(tf.math.square(pl_grads), axis=2), axis=1))
364 |
365 | # Track exponential moving average of |J*y|.
366 | pl_mean_val = self.pl_mean + self.pl_decay * (tf.reduce_mean(pl_lengths) - self.pl_mean)
367 | self.pl_mean.assign(pl_mean_val)
368 |
369 | # Calculate (|J*y|-a)^2.
370 | pl_penalty = tf.square(pl_lengths - self.pl_mean) * self.g_opt['reg_interval']
371 |
372 | # scale the penalty by the path length regularization weight
373 | pl_reg = pl_penalty * self.pl_weight
374 |
375 | return pl_reg
376 |
377 | """ Distribute Train """
378 | @tf.function
379 | def distribute_d_train_step(self, z, real_images, labels):
380 | d_loss, d_adv_loss = self.strategy.run(self.d_train_step, args=(z, real_images, labels))
381 |
382 | d_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, d_loss, axis=None)
383 | d_adv_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, d_adv_loss, axis=None)
384 |
385 | return d_loss, d_adv_loss
386 |
387 | @tf.function
388 | def distribute_d_reg_train_step(self, z, real_images, labels):
389 | d_loss, d_adv_loss, r1_penalty = self.strategy.run(self.d_reg_train_step, args=(z, real_images, labels))
390 |
391 | d_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, d_loss, axis=None)
392 | d_adv_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, d_adv_loss, axis=None)
393 | r1_penalty = self.strategy.reduce(tf.distribute.ReduceOp.SUM, r1_penalty, axis=None)
394 |
395 | return d_loss, d_adv_loss, r1_penalty
396 |
397 | @tf.function
398 | def distribute_g_train_step(self, z, labels):
399 | g_loss, g_adv_loss = self.strategy.run(self.g_train_step, args=(z, labels))
400 |
401 | g_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, g_loss, axis=None)
402 | g_adv_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, g_adv_loss, axis=None)
403 |
404 | return g_loss, g_adv_loss
405 |
406 | @tf.function
407 | def distribute_g_reg_train_step(self, z, labels):
408 | g_loss, g_adv_loss, pl_reg = self.strategy.run(self.g_reg_train_step, args=(z, labels))
409 |
410 | g_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, g_loss, axis=None)
411 | g_adv_loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM, g_adv_loss, axis=None)
412 | pl_reg = self.strategy.reduce(tf.distribute.ReduceOp.SUM, pl_reg, axis=None)
413 |
414 | return g_loss, g_adv_loss, pl_reg
415 |
416 | def train(self):
417 |
418 | start_time = time.time()
419 |
420 | # setup tensorboards
421 | train_summary_writer = tf.summary.create_file_writer(self.log_dir)
422 |
423 | # start training
424 | print('max_steps: {}'.format(self.iteration))
425 | losses = {'g/loss': 0.0, 'd/loss': 0.0, 'r1_reg': 0.0, 'pl_reg': 0.0,
426 | 'g/adv_loss': 0.0,
427 | 'd/adv_loss': 0.0,
428 | 'fid': 0.0, 'best_fid': 0.0, 'best_fid_iter': 0}
429 | fid = 0
430 | best_fid = 1000
431 | best_fid_iter = 0
432 | for idx in range(self.start_iteration, self.iteration):
433 | iter_start_time = time.time()
434 |
435 | x_real, z, labels = next(self.dataset_iter)
436 |
437 | if idx == 0:
438 | g_params = self.generator.count_params()
439 | d_params = self.discriminator.count_params()
440 | print("G network parameters : ", format(g_params, ','))
441 | print("D network parameters : ", format(d_params, ','))
442 | print("Total network parameters : ", format(g_params + d_params, ','))
443 |
444 | # update discriminator
445 | # On the first call, each tf.function takes 1-2 minutes to trace and build its graph.
446 | if (idx + 1) % self.d_opt['reg_interval'] == 0:
447 | d_loss, d_adv_loss, r1_reg = self.distribute_d_reg_train_step(z, x_real, labels)
448 | losses['r1_reg'] = np.float64(r1_reg)
449 | else:
450 | d_loss, d_adv_loss = self.distribute_d_train_step(z, x_real, labels)
451 |
452 | losses['d/loss'] = np.float64(d_loss)
453 | losses['d/adv_loss'] = np.float64(d_adv_loss)
454 |
455 | # update generator
456 | # On the first call, each tf.function takes 1-2 minutes to trace and build its graph.
457 | if (idx + 1) % self.g_opt['reg_interval'] == 0:
458 | g_loss, g_adv_loss, pl_reg = self.distribute_g_reg_train_step(z, labels)
459 | losses['pl_reg'] = np.float64(pl_reg)
460 | else:
461 | g_loss, g_adv_loss = self.distribute_g_train_step(z, labels)
462 |
463 | losses['g/loss'] = np.float64(g_loss)
464 | losses['g/adv_loss'] = np.float64(g_adv_loss)
465 |
466 |
467 | # update g_clone
468 | self.g_clone.set_as_moving_average_of(self.generator)
469 |
470 | if np.mod(idx, self.save_freq) == 0 or idx == self.iteration - 1 :
471 | fid = self.calculate_FID()
472 | if fid < best_fid:
473 | print("BEST FID UPDATED")
474 | best_fid = fid
475 | best_fid_iter = idx
476 | self.manager.save(checkpoint_number=idx)
477 | losses['fid'] = np.float64(fid)
478 |
479 |
480 | # save to tensorboard
481 |
482 | with train_summary_writer.as_default():
483 | tf.summary.scalar('g_loss', losses['g/loss'], step=idx)
484 | tf.summary.scalar('g_adv_loss', losses['g/adv_loss'], step=idx)
485 |
486 | tf.summary.scalar('d_loss', losses['d/loss'], step=idx)
487 | tf.summary.scalar('d_adv_loss', losses['d/adv_loss'], step=idx)
488 |
489 | tf.summary.scalar('r1_reg', losses['r1_reg'], step=idx)
490 | tf.summary.scalar('pl_reg', losses['pl_reg'], step=idx)
491 | # tf.summary.histogram('w_avg', self.generator.w_avg, step=idx)
492 |
493 | if np.mod(idx, self.save_freq) == 0 or idx == self.iteration - 1:
494 | tf.summary.scalar('fid', losses['fid'], step=idx)
495 |
496 | # save every self.save_freq
497 | # if np.mod(idx + 1, self.save_freq) == 0:
498 | # self.manager.save(checkpoint_number=idx + 1)
499 |
500 | # save sample images every self.print_freq iterations
501 | if np.mod(idx + 1, self.print_freq) == 0:
502 | total_num_samples = min(self.n_samples, self.batch_size)
503 | partial_size = int(np.floor(np.sqrt(total_num_samples)))
504 |
505 | # prepare inputs
506 | latents = tf.random.normal(shape=(self.n_samples, self.g_params['z_dim']), dtype=tf.dtypes.float32)
507 | dummy_labels = tf.random.normal((self.n_samples, self.g_params['labels_dim']), dtype=tf.dtypes.float32)
508 |
509 | # run networks
510 | fake_img, _ = self.g_clone([latents, dummy_labels], truncation_psi=1.0, training=False)
511 |
512 | save_images(images=fake_img[:partial_size * partial_size, :, :, :],
513 | size=[partial_size, partial_size],
514 | image_path='./{}/fake_{:06d}.png'.format(self.sample_dir, idx + 1))
515 |
516 | x_real_concat = tf.concat(self.strategy.experimental_local_results(x_real), axis=0)
517 | self.truncation_psi_canvas(x_real_concat, path='./{}/fake_psi_{:06d}.png'.format(self.sample_dir, idx + 1))
518 |
519 | elapsed = time.time() - iter_start_time
520 | print(self.log_template.format(idx, self.iteration, elapsed,
521 | losses['d/loss'], losses['g/loss'], losses['r1_reg'], losses['pl_reg'], fid, best_fid, best_fid_iter))
522 | # save model for final step
523 | self.manager.save(checkpoint_number=self.iteration)
524 |
525 | print("LAST FID: ", fid)
526 | print("BEST FID: {}, {}".format(best_fid, best_fid_iter))
527 | print("Total train time: %4.4f" % (time.time() - start_time))
528 |
529 | @property
530 | def model_dir(self):
531 | return "{}_{}_{}_{}".format(self.model_name, self.dataset_name, self.img_size, self.config)
532 |
533 |
534 | def calculate_FID(self):
535 | @tf.function
536 | def gen_samples_feats(test_z, test_labels, g_clone, inception_model):
537 | # run networks
538 | fake_img, _ = g_clone([test_z, test_labels], training=False)
539 | fake_img = adjust_dynamic_range(fake_img, range_in=(-1.0, 1.0), range_out=(0.0, 255.0), out_dtype=tf.float32)
540 | fake_img = tf.transpose(fake_img, [0, 2, 3, 1])
541 | fake_img = tf.image.resize(fake_img, [299, 299], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
542 |
543 | fake_img = torch_normalization(fake_img)
544 |
545 | feats = inception_model(fake_img)
546 |
547 | return feats
548 |
549 | @tf.function
550 | def get_inception_features(img, inception_model):
551 | feats = inception_model(img)
552 | return feats
553 |
554 | @tf.function
555 | def get_real_features(img, inception_model):
556 | feats = self.strategy.run(get_inception_features, args=(img, inception_model))
557 | feats = tf.concat(self.strategy.experimental_local_results(feats), axis=0)
558 |
559 | return feats
560 |
561 | @tf.function
562 | def get_fake_features(z, dummy_labels, g_clone, inception_model):
563 |
564 | feats = self.strategy.run(gen_samples_feats, args=(z, dummy_labels, g_clone, inception_model))
565 | feats = tf.concat(self.strategy.experimental_local_results(feats), axis=0)
566 |
567 | return feats
568 |
569 | @tf.function
570 | def convert_per_replica_image(nchw_per_replica_images, strategy):
571 | as_tensor = tf.concat(strategy.experimental_local_results(nchw_per_replica_images), axis=0)
572 | as_tensor = tf.transpose(as_tensor, perm=[0, 2, 3, 1])
573 | as_tensor = (tf.clip_by_value(as_tensor, -1.0, 1.0) + 1.0) * 127.5
574 | as_tensor = tf.cast(as_tensor, tf.uint8)
575 | as_tensor = tf.image.resize(as_tensor, [299, 299], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
576 |
577 | return as_tensor
578 |
579 | if not self.real_cache:
580 | real_feats = tf.zeros([0, 2048])
581 | """ Input Image"""
582 | # img_class = Image_data(self.img_size, self.g_params['z_dim'], self.g_params['labels_dim'],
583 | # self.dataset_path)
584 | # img_class.preprocess()
585 | # dataset_num = len(img_class.train_images)
586 | # img_slice = tf.data.Dataset.from_tensor_slices(img_class.train_images)
587 | #
588 | # img_slice = img_slice.shuffle(buffer_size=dataset_num, reshuffle_each_iteration=True, seed=777)
589 | # img_slice = img_slice.map(map_func=inception_processing, num_parallel_calls=AUTOTUNE).batch(self.batch_size,
590 | # drop_remainder=False)
591 | # img_slice = img_slice.prefetch(buffer_size=AUTOTUNE)
592 | # img_slice = self.strategy.experimental_distribute_dataset(img_slice)
593 |
594 | for img in self.fid_img_slice:
595 | feats = get_real_features(img, self.inception_model)
596 | real_feats = tf.concat([real_feats, feats], axis=0)
597 | print('real feats:', np.shape(real_feats)[0])
598 |
599 | self.real_mu = np.mean(real_feats, axis=0)
600 | self.real_cov = np.cov(real_feats, rowvar=False)
601 |
602 | with open('{}_mu_cov.pickle'.format(self.dataset_name), 'wb') as f:
603 | pickle.dump((self.real_mu, self.real_cov), f, protocol=pickle.HIGHEST_PROTOCOL)
604 |
605 | print('{} real pickle save !!!'.format(self.dataset_name))
606 |
607 | self.real_cache = True
608 | del real_feats
609 |
610 | fake_feats = tf.zeros([0, 2048])
611 | from tqdm import tqdm
612 | for begin in tqdm(range(0, self.fid_samples_num, self.batch_size)):
613 | z = tf.random.normal(shape=[self.each_batch_size, self.g_params['z_dim']])
614 | dummy_labels = tf.random.normal((self.each_batch_size, self.g_params['labels_dim']), dtype=tf.float32)
615 |
616 | feats = get_fake_features(z, dummy_labels, self.g_clone, self.inception_model)
617 |
618 | fake_feats = tf.concat([fake_feats, feats], axis=0)
619 | # print('fake feats:', np.shape(fake_feats)[0])
620 |
621 | fake_mu = np.mean(fake_feats, axis=0)
622 | fake_cov = np.cov(fake_feats, rowvar=False)
623 | del fake_feats
624 |
625 | # Calculate FID.
626 | m = np.square(fake_mu - self.real_mu).sum()
627 | s, _ = scipy.linalg.sqrtm(np.dot(fake_cov, self.real_cov), disp=False) # pylint: disable=no-member
628 | dist = m + np.trace(fake_cov + self.real_cov - 2 * s)
629 |
630 | return dist
631 |
632 |
633 | def truncation_psi_canvas(self, real_images, path):
634 | # prepare inputs
635 | reals = real_images[:self.n_samples, :, :, :]
636 | latents = tf.random.normal(shape=(self.n_samples, self.g_params['z_dim']), dtype=tf.dtypes.float32)
637 | dummy_labels = tf.random.normal((self.n_samples, self.g_params['labels_dim']), dtype=tf.dtypes.float32)
638 |
639 | # run networks
640 | fake_images_00, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.0, training=False)
641 | fake_images_05, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.5, training=False)
642 | fake_images_07, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.7, training=False)
643 | fake_images_10, _ = self.g_clone([latents, dummy_labels], truncation_psi=1.0, training=False)
644 |
645 | # merge on batch dimension: [4 * n_samples, 3, img_size, img_size]
646 | out = tf.concat([fake_images_00, fake_images_05, fake_images_07, fake_images_10], axis=0)
647 |
648 | # prepare for image saving: [4 * n_samples, img_size, img_size, 3]
649 | out = postprocess_images(out)
650 |
651 | # resize to save disk space: [4 * n_samples, size, size, 3]
652 | size = min(self.img_size, 256)
653 | out = tf.image.resize(out, size=[size, size], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
654 |
655 | # make single image and add batch dimension for tensorboard: [1, 4 * size, n_samples * size, 3]
656 | out = merge_batch_images(out, size, rows=4, cols=self.n_samples)
657 |
658 | images = cv2.cvtColor(out.astype('uint8'), cv2.COLOR_RGB2BGR)
659 |
660 | return cv2.imwrite(path, images)
661 |
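`truncation_psi_canvas` renders the same latents at several `truncation_psi` values. In the usual StyleGAN formulation, truncation linearly interpolates each mapped latent toward the running average of W: `w' = w_avg + psi * (w - w_avg)`. The sketch below illustrates that formula on plain lists. It assumes this standard formulation; how `g_clone` applies it (including `truncation_cutoff`) is defined in `networks.py`, which is not shown here, and `truncate` is a hypothetical name.

```python
def truncate(w, w_avg, psi):
    """Linearly interpolate from the average latent toward w by factor psi."""
    return [avg + psi * (x - avg) for x, avg in zip(w, w_avg)]

w_avg = [0.5, -1.0, 2.0]   # illustrative running average of W
w = [1.5, 1.0, 0.0]        # illustrative mapped latent

# psi = 0 collapses every sample onto the average; psi = 1 leaves w unchanged.
for psi in (0.0, 0.5, 0.7, 1.0):
    print(psi, truncate(w, w_avg, psi))
```

Negative values of psi, as used in `draw_truncation_trick_figure`, move the latent to the opposite side of the average.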
662 |
663 | def test(self):
664 | result_dir = os.path.join(self.result_dir, self.model_dir)
665 | check_folder(result_dir)
666 |
667 | total_num_samples = min(self.n_samples, self.batch_size)
668 | partial_size = int(np.floor(np.sqrt(total_num_samples)))
669 |
670 | from tqdm import tqdm
671 | for i in tqdm(range(self.n_test)):
672 | z = tf.random.normal(shape=[self.batch_size, self.g_params['z_dim']])
673 | dummy_labels = tf.random.normal((self.batch_size, self.g_params['labels_dim']), dtype=tf.float32)
674 | fake_img, _ = self.g_clone([z, dummy_labels], training=False)
675 |
676 | save_images(images=fake_img[:partial_size * partial_size, :, :, :],
677 | size=[partial_size, partial_size],
678 | image_path='./{}/fake_{:01d}.png'.format(result_dir, i))
679 |
680 | def test_70000(self):
681 | result_dir = os.path.join(self.result_dir, self.model_dir)
682 | check_folder(result_dir)
683 |
684 | total_num_samples = 1
685 | partial_size = int(np.floor(np.sqrt(total_num_samples)))
686 |
687 | from tqdm import tqdm
688 | for i in tqdm(range(70000)):
689 | z = tf.random.normal(shape=[1, self.g_params['z_dim']])
690 | dummy_labels = tf.random.normal((1, self.g_params['labels_dim']), dtype=tf.float32)
691 | fake_img, _ = self.g_clone([z, dummy_labels], training=False)
692 |
693 | save_images(images=fake_img[:partial_size * partial_size, :, :, :],
694 | size=[partial_size, partial_size],
695 | image_path='./{}/fake_{:01d}.png'.format(result_dir, i))
696 |
697 | def draw_uncurated_result_figure(self):
698 |
699 | result_dir = os.path.join(self.result_dir, self.model_dir, 'paper_figure')
700 | check_folder(result_dir)
701 |
702 | seed_flag = True
703 | lods = [0, 1, 2, 2, 3, 3]
704 | seed = 3291
705 | rows = 3
706 | cx = 0
707 | cy = 0
708 |
709 | if seed_flag:
710 | latents = tf.cast(
711 | np.random.RandomState(seed).normal(size=[sum(rows * 2 ** lod for lod in lods), self.g_params['z_dim']]), tf.float32)
712 | else:
713 | latents = tf.cast(np.random.normal(size=[sum(rows * 2 ** lod for lod in lods), self.g_params['z_dim']]), tf.float32)
714 |
715 | dummy_labels = tf.random.normal((sum(rows * 2 ** lod for lod in lods), self.g_params['labels_dim']), dtype=tf.float32)
716 |
717 | images, _ = self.g_clone([latents, dummy_labels], training=False)
718 | images = postprocess_images(images)
719 |
720 | canvas = PIL.Image.new('RGB', (sum(self.img_size // 2 ** lod for lod in lods), self.img_size * rows), 'white')
721 | image_iter = iter(list(images))
722 |
723 | for col, lod in enumerate(lods):
724 | for row in range(rows * 2 ** lod):
725 | image = PIL.Image.fromarray(np.uint8(next(image_iter)), 'RGB')
726 |
727 | image = image.crop((cx, cy, cx + self.img_size, cy + self.img_size))
728 | image = image.resize((self.img_size // 2 ** lod, self.img_size // 2 ** lod), PIL.Image.LANCZOS) # ANTIALIAS was an alias of LANCZOS and is removed in Pillow 10
729 | canvas.paste(image,
730 | (sum(self.img_size // 2 ** lod for lod in lods[:col]), row * self.img_size // 2 ** lod))
731 |
732 | canvas.save('{}/figure02-uncurated.png'.format(result_dir))
733 |
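The canvas layout in `draw_uncurated_result_figure` follows from `lods` and `rows`: each column holds `rows * 2 ** lod` tiles of side `img_size // 2 ** lod`, so every column has the same total height. A small sketch of that arithmetic, assuming a power-of-two `img_size`; `uncurated_layout` is a hypothetical helper that computes only positions and draws nothing.

```python
def uncurated_layout(img_size, lods, rows):
    """Return the canvas size and a list of (x, y, tile_side) paste positions."""
    width = sum(img_size // 2 ** lod for lod in lods)
    height = img_size * rows
    tiles = []
    x = 0
    for lod in lods:
        tile = img_size // 2 ** lod
        for row in range(rows * 2 ** lod):
            tiles.append((x, row * tile, tile))
        x += tile  # next column starts where this one ends
    return (width, height), tiles

canvas_size, tiles = uncurated_layout(img_size=256, lods=[0, 1, 2, 2, 3, 3], rows=3)
print(canvas_size, len(tiles))
```

The number of tiles equals the number of latents the method samples, `sum(rows * 2 ** lod for lod in lods)`.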
734 | def draw_style_mixing_figure(self):
735 | result_dir = os.path.join(self.result_dir, self.model_dir, 'paper_figure')
736 | check_folder(result_dir)
737 |
738 | seed_flag = True
739 | src_seeds = [604, 8440, 7613, 6978, 3004]
740 | dst_seeds = [1336, 6968, 607, 728, 7036, 9010]
741 |
742 | truncation_psi = 0.7 # Style strength multiplier for the truncation trick
743 | truncation_cutoff = 8 # Number of layers for which to apply the truncation trick
744 |
745 | resolutions = self.g_params['resolutions']
746 | n_broadcast = len(resolutions) * 2
747 |
748 | style_ranges = [range(0, 4)] * 3 + [range(4, 8)] * 2 + [range(8, n_broadcast)]
749 |
750 | if seed_flag:
751 | src_latents = tf.cast(
752 | np.concatenate(list(np.random.RandomState(seed).normal(size=[1, self.g_params['z_dim']]) for seed in src_seeds), axis=0), tf.float32)
753 | dst_latents = tf.cast(
754 | np.concatenate(list(np.random.RandomState(seed).normal(size=[1, self.g_params['z_dim']]) for seed in dst_seeds), axis=0), tf.float32)
755 |
756 | else:
757 | src_latents = tf.cast(np.random.normal(size=[len(src_seeds), self.g_params['z_dim']]), tf.float32)
758 | dst_latents = tf.cast(np.random.normal(size=[len(dst_seeds), self.g_params['z_dim']]), tf.float32)
759 |
760 | dummy_labels = tf.random.normal((len(src_seeds), self.g_params['labels_dim']), dtype=tf.float32)
761 |
762 | src_images, src_dlatents = self.g_clone([src_latents, dummy_labels], truncation_cutoff=truncation_cutoff, truncation_psi=truncation_psi, training=False)
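763 | dst_images, dst_dlatents = self.g_clone([dst_latents, tf.random.normal((len(dst_seeds), self.g_params['labels_dim']), dtype=tf.float32)], truncation_cutoff=truncation_cutoff, truncation_psi=truncation_psi, training=False) # label batch size must match len(dst_seeds)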
764 |
765 | src_images = postprocess_images(src_images)
766 | dst_images = postprocess_images(dst_images)
767 |
768 | img_out_size = min(self.img_size, 256)
769 |
770 | src_images = tf.image.resize(src_images, size=[img_out_size, img_out_size], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
771 | dst_images = tf.image.resize(dst_images, size=[img_out_size, img_out_size], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
772 |
773 | canvas = PIL.Image.new('RGB', (img_out_size * (len(src_seeds) + 1), img_out_size * (len(dst_seeds) + 1)), 'white')
774 |
775 | for col, src_image in enumerate(list(src_images)):
776 | canvas.paste(PIL.Image.fromarray(np.uint8(src_image), 'RGB'), ((col + 1) * img_out_size, 0))
777 |
778 | for row, dst_image in enumerate(list(dst_images)):
779 | canvas.paste(PIL.Image.fromarray(np.uint8(dst_image), 'RGB'), (0, (row + 1) * img_out_size))
780 |
781 | row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))
782 | src_dlatents = np.asarray(src_dlatents, dtype=np.float32)
783 | row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]
784 |
785 | row_images, _ = self.g_clone([row_dlatents, dummy_labels], mapping=False, training=False)
786 | row_images = postprocess_images(row_images)
787 |
788 |
789 | for col, image in enumerate(list(row_images)):
790 | canvas.paste(PIL.Image.fromarray(np.uint8(image), 'RGB'), ((col + 1) * img_out_size, (row + 1) * img_out_size))
791 |
792 | canvas.save('{}/figure03-style-mixing.png'.format(result_dir))
793 |
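Style mixing in `draw_style_mixing_figure` copies a range of per-layer styles from each source into a copy of one destination's styles: the first three destination rows take layers 0 to 3 from the sources, the next two take layers 4 to 7, and the last takes the remaining layers. The sketch below shows the copy step on plain lists, with strings standing in for style vectors; `mix_styles` is a hypothetical name.

```python
def mix_styles(dst_dlatent, src_dlatents, layer_range):
    """For each source, copy its styles in layer_range into a copy of dst."""
    mixed = []
    for src in src_dlatents:
        row = list(dst_dlatent)          # copy so dst is left untouched
        for layer in layer_range:
            row[layer] = src[layer]
        mixed.append(row)
    return mixed

n_layers = 8
dst = ["d%d" % i for i in range(n_layers)]
srcs = [["a%d" % i for i in range(n_layers)],
        ["b%d" % i for i in range(n_layers)]]

coarse = mix_styles(dst, srcs, range(0, 4))   # coarse layers come from the sources
print(coarse[0])
```

Each mixed row is then fed to the synthesis network directly, which is what the `mapping=False` call in the method does.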
794 | def draw_truncation_trick_figure(self):
795 |
796 | result_dir = os.path.join(self.result_dir, self.model_dir, 'paper_figure')
797 | check_folder(result_dir)
798 |
799 | seed_flag = True
800 | seeds = [1653, 4010]
801 | psis = [-1, -0.7, -0.5, 0, 0.5, 0.7, 1]
802 |
803 | if seed_flag:
804 | latents = tf.cast(
805 | np.concatenate(list(np.random.RandomState(seed).normal(size=[1, self.g_params['z_dim']]) for seed in seeds), axis=0), tf.float32)
806 | else:
807 | latents = tf.cast(np.random.normal(size=[len(seeds), self.g_params['z_dim']]), tf.float32)
808 |
809 | dummy_labels = tf.random.normal((len(seeds), self.g_params['labels_dim']), dtype=tf.float32)
810 |
811 | fake_images_10_, _ = self.g_clone([latents, dummy_labels], truncation_psi=-1.0, training=False)
812 | fake_images_05_, _ = self.g_clone([latents, dummy_labels], truncation_psi=-0.5, training=False)
813 | fake_images_07_, _ = self.g_clone([latents, dummy_labels], truncation_psi=-0.7, training=False)
814 | fake_images_00, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.0, training=False)
815 | fake_images_05, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.5, training=False)
816 | fake_images_07, _ = self.g_clone([latents, dummy_labels], truncation_psi=0.7, training=False)
817 | fake_images_10, _ = self.g_clone([latents, dummy_labels], truncation_psi=1.0, training=False)
818 |
819 | # collect one batch per psi, in the same order as `psis`: 7 tensors of [len(seeds), 3, img_size, img_size]
820 | col_images = [fake_images_10_, fake_images_07_, fake_images_05_, fake_images_00, fake_images_05, fake_images_07, fake_images_10]
821 |
822 | img_out_size = min(self.img_size, 256)
823 |
824 | for i in range(len(col_images)):
825 | col_images[i] = postprocess_images(col_images[i])
826 | col_images[i] = tf.image.resize(col_images[i], size=[img_out_size, img_out_size], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
827 |
828 | canvas = PIL.Image.new('RGB', (img_out_size * len(psis), img_out_size * len(seeds)), 'white')
829 |
830 | for col, col_img in enumerate(col_images):
831 | for row, image in enumerate(col_img):
832 | canvas.paste(PIL.Image.fromarray(np.uint8(image), 'RGB'),
833 | (col * img_out_size, row * img_out_size))
834 |
835 | canvas.save('{}/figure08-truncation-trick.png'.format(result_dir))
836 |
--------------------------------------------------------------------------------
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/sample_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/sample_2.gif
--------------------------------------------------------------------------------
/assets/style_mixing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/style_mixing.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/teaser.png
--------------------------------------------------------------------------------
/assets/truncation_trick.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/truncation_trick.png
--------------------------------------------------------------------------------
/assets/uncurated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/Toward_spatial_unbiased-Tensorflow/d37fd26061f28ed064cb86176df86f86d947625f/assets/uncurated.png
--------------------------------------------------------------------------------
/cuda/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """TensorFlow custom ops builder.
8 | """
9 |
10 | import os
11 | import re
12 | import uuid
13 | import hashlib
14 | import tempfile
15 | import shutil
16 | import tensorflow as tf
17 | from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
18 |
19 | #----------------------------------------------------------------------------
20 | # Global options.
21 |
22 | cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache')
23 | cuda_cache_version_tag = 'v1'
24 | do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe!
25 | verbose = True # Print status messages to stdout.
26 |
27 | compiler_bindir_search_path = [
28 | 'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64',
29 | 'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64',
30 | 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin',
31 | ]
32 |
33 | #----------------------------------------------------------------------------
34 | # Internal helper funcs.
35 |
36 | def _find_compiler_bindir():
37 | for compiler_path in compiler_bindir_search_path:
38 | if os.path.isdir(compiler_path):
39 | return compiler_path
40 | return None
41 |
42 | def _get_compute_cap(device):
43 | caps_str = device.physical_device_desc
44 | m = re.search('compute capability: (\\d+)\\.(\\d+)', caps_str) # escape the dot so it matches only a literal '.'
45 | major = m.group(1)
46 | minor = m.group(2)
47 | return (major, minor)
48 |
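`_get_compute_cap` and `_get_cuda_gpu_arch_string` turn a device description into an `sm_XY` architecture flag for nvcc. A standalone sketch of that parsing, assuming the description contains text of the form `compute capability: 7.5`; the sample string and the helper name `arch_from_desc` are illustrative, and unlike the original the sketch returns `None` when no match is found.

```python
import re

def arch_from_desc(physical_device_desc):
    """Extract 'sm_XY' from a device description string, or None if absent."""
    m = re.search(r'compute capability: (\d+)\.(\d+)', physical_device_desc)
    if m is None:
        return None
    return 'sm_%s%s' % (m.group(1), m.group(2))

# Illustrative description string, not captured from a real device.
desc = 'device: 0, name: Example GPU, pci bus id: 0000:01:00.0, compute capability: 7.5'
print(arch_from_desc(desc))
```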
49 | def _get_cuda_gpu_arch_string():
50 | gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
51 | if len(gpus) == 0:
52 | raise RuntimeError('No GPU devices found')
53 | (major, minor) = _get_compute_cap(gpus[0])
54 | return 'sm_%s%s' % (major, minor)
55 |
56 | def _run_cmd(cmd):
57 | with os.popen(cmd) as pipe:
58 | output = pipe.read()
59 | status = pipe.close()
60 | if status is not None:
61 | raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
62 |
63 | def _prepare_nvcc_cli(opts):
64 | cmd = 'nvcc --std=c++11 -DNDEBUG ' + opts.strip()
65 | cmd += ' --disable-warnings'
66 | cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
67 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
68 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
69 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')
70 |
71 | compiler_bindir = _find_compiler_bindir()
72 | if compiler_bindir is None:
73 | # Require that _find_compiler_bindir succeeds on Windows. Allow
74 | # nvcc to use whatever is the default on Linux.
75 | if os.name == 'nt':
76 | raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
77 | else:
78 | cmd += ' --compiler-bindir "%s"' % compiler_bindir
79 | cmd += ' 2>&1'
80 | return cmd
81 |
82 | #----------------------------------------------------------------------------
83 | # Main entry point.
84 |
85 | _plugin_cache = dict()
86 |
87 | def get_plugin(cuda_file):
88 | cuda_file_base = os.path.basename(cuda_file)
89 | cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
90 |
91 | # Already in cache?
92 | if cuda_file in _plugin_cache:
93 | return _plugin_cache[cuda_file]
94 |
95 | # Setup plugin.
96 | if verbose:
97 | print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
98 | try:
99 | # Hash CUDA source.
100 | md5 = hashlib.md5()
101 | with open(cuda_file, 'rb') as f:
102 | md5.update(f.read())
103 | md5.update(b'\n')
104 |
105 | # Hash headers included by the CUDA code by running it through the preprocessor.
106 | if not do_not_hash_included_headers:
107 | if verbose:
108 | print('Preprocessing... ', end='', flush=True)
109 | with tempfile.TemporaryDirectory() as tmp_dir:
110 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
111 | _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
112 | with open(tmp_file, 'rb') as f:
113 | bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
114 | good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
115 | for ln in f:
116 | if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
117 | ln = ln.replace(bad_file_str, good_file_str)
118 | md5.update(ln)
119 | md5.update(b'\n')
120 |
121 | # Select compiler options.
122 | compile_opts = ''
123 | if os.name == 'nt':
124 | compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
125 | elif os.name == 'posix':
126 | compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so')
127 | compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\''
128 | else:
129 | assert False # not Windows or Linux, w00t?
130 | compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string()
131 | compile_opts += ' --use_fast_math'
132 | nvcc_cmd = _prepare_nvcc_cli(compile_opts)
133 |
134 | # Hash build configuration.
135 | md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
136 | md5.update(('tf.VERSION: ' + tf.version.VERSION).encode('utf-8') + b'\n')
137 | md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')
138 |
139 | # Compile if not already compiled.
140 | bin_file_ext = '.dll' if os.name == 'nt' else '.so'
141 | bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
142 | if not os.path.isfile(bin_file):
143 | if verbose:
144 | print('Compiling... ', end='', flush=True)
145 | with tempfile.TemporaryDirectory() as tmp_dir:
146 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
147 | _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
148 | os.makedirs(cuda_cache_path, exist_ok=True)
149 | intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
150 | shutil.copyfile(tmp_file, intermediate_file)
151 | os.rename(intermediate_file, bin_file) # atomic
152 |
153 | # Load.
154 | if verbose:
155 | print('Loading... ', end='', flush=True)
156 | plugin = tf.load_op_library(bin_file)
157 |
158 | # Add to cache.
159 | _plugin_cache[cuda_file] = plugin
160 | if verbose:
161 | print('Done.', flush=True)
162 | return plugin
163 |
164 | except:
165 | if verbose:
166 | print('Failed!', flush=True)
167 | raise
168 |
169 | #----------------------------------------------------------------------------
170 |
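`get_plugin` names the compiled binary after an MD5 digest of everything that affects the build (source, nvcc command line, TensorFlow version, cache version tag), so a change to any of them triggers recompilation while an unchanged configuration reuses the cached file. A reduced sketch of that keying scheme follows; it omits the preprocessed-header hashing, and `cache_file_name` and the `plugin_` prefix are hypothetical.

```python
import hashlib

def cache_file_name(source_bytes, nvcc_cmd, tf_version, version_tag, ext='.so'):
    """Derive a binary file name from everything that affects the build."""
    md5 = hashlib.md5()
    md5.update(source_bytes + b'\n')
    md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
    md5.update(('tf.VERSION: ' + tf_version).encode('utf-8') + b'\n')
    md5.update(('cuda_cache_version_tag: ' + version_tag).encode('utf-8') + b'\n')
    return 'plugin_' + md5.hexdigest() + ext

# Same inputs map to the same file; changing the GPU architecture does not.
name_a = cache_file_name(b'kernel source', 'nvcc --gpu-architecture=sm_75', '2.4.0', 'v1')
name_b = cache_file_name(b'kernel source', 'nvcc --gpu-architecture=sm_75', '2.4.0', 'v1')
name_c = cache_file_name(b'kernel source', 'nvcc --gpu-architecture=sm_80', '2.4.0', 'v1')
print(name_a == name_b, name_a == name_c)
```

The original additionally copies the result to a uniquely named temporary file and renames it into place, so concurrent processes never load a half-written binary.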
--------------------------------------------------------------------------------
/cuda/fused_bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #define EIGEN_USE_GPU
8 | #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
9 | #include "tensorflow/core/framework/op.h"
10 | #include "tensorflow/core/framework/op_kernel.h"
11 | #include "tensorflow/core/framework/shape_inference.h"
12 | #include <stdio.h>
13 |
14 | using namespace tensorflow;
15 | using namespace tensorflow::shape_inference;
16 |
17 | #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
18 |
19 | //------------------------------------------------------------------------
20 | // CUDA kernel.
21 |
22 | template <class T>
23 | struct FusedBiasActKernelParams
24 | {
25 | const T* x; // [sizeX]
26 | const T* b; // [sizeB] or NULL
27 | const T* ref; // [sizeX] or NULL
28 | T* y; // [sizeX]
29 |
30 | int grad;
31 | int axis;
32 | int act;
33 | float alpha;
34 | float gain;
35 |
36 | int sizeX;
37 | int sizeB;
38 | int stepB;
39 | int loopX;
40 | };
41 |
42 | template <class T>
43 | static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
44 | {
45 | const float expRange = 80.0f;
46 | const float halfExpRange = 40.0f;
47 | const float seluScale = 1.0507009873554804934193349852946f;
48 | const float seluAlpha = 1.6732632423543772848170429916717f;
49 |
50 | // Loop over elements.
51 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
52 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
53 | {
54 | // Load and apply bias.
55 | float x = (float)p.x[xi];
56 | if (p.b)
57 | x += (float)p.b[(xi / p.stepB) % p.sizeB];
58 | float ref = (p.ref) ? (float)p.ref[xi] : 0.0f;
59 | if (p.gain != 0.0f && p.act != 9)
60 | ref /= p.gain;
61 |
62 | // Evaluate activation func.
63 | float y;
64 | switch (p.act * 10 + p.grad)
65 | {
66 | // linear
67 | default:
68 | case 10: y = x; break;
69 | case 11: y = x; break;
70 | case 12: y = 0.0f; break;
71 |
72 | // relu
73 | case 20: y = (x > 0.0f) ? x : 0.0f; break;
74 | case 21: y = (ref > 0.0f) ? x : 0.0f; break;
75 | case 22: y = 0.0f; break;
76 |
77 | // lrelu
78 | case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
79 | case 31: y = (ref > 0.0f) ? x : x * p.alpha; break;
80 | case 32: y = 0.0f; break;
81 |
82 | // tanh
83 | case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
84 | case 41: y = x * (1.0f - ref * ref); break;
85 | case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break;
86 |
87 | // sigmoid
88 | case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
89 | case 51: y = x * ref * (1.0f - ref); break;
90 | case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break;
91 |
92 | // elu
93 | case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
94 | case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break;
95 | case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break;
96 |
97 | // selu
98 | case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
99 | case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break;
100 | case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break;
101 |
102 | // softplus
103 | case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
104 | case 81: y = x * (1.0f - expf(-ref)); break;
105 | case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break;
106 |
107 | // swish
108 | case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
109 | case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break;
110 | case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break;
111 | }
112 |
113 | // Apply gain and store.
114 | p.y[xi] = (T)(y * p.gain);
115 | }
116 | }
117 |
118 | //------------------------------------------------------------------------
119 | // TensorFlow op.
120 |
121 | template <class T>
122 | struct FusedBiasActOp : public OpKernel
123 | {
124 | FusedBiasActKernelParams<T> m_attribs;
125 |
126 | FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
127 | {
128 | memset(&m_attribs, 0, sizeof(m_attribs));
129 | OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
130 | OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
131 | OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
132 | OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
133 | OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
134 | OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
135 | OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
136 | OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
137 | }
138 |
139 | void Compute(OpKernelContext* ctx)
140 | {
141 | FusedBiasActKernelParams<T> p = m_attribs;
142 | cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
143 |
144 | const Tensor& x = ctx->input(0); // [...]
145 | const Tensor& b = ctx->input(1); // [sizeB] or [0]
146 | const Tensor& ref = ctx->input(2); // x.shape or [0]
147 | p.x = x.flat<T>().data();
148 | p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
149 | p.ref = (ref.NumElements()) ? ref.flat<T>().data() : NULL;
150 | OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
151 | OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
152 | OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
153 | OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements"));
154 | OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));
155 |
156 | p.sizeX = (int)x.NumElements();
157 | p.sizeB = (int)b.NumElements();
158 | p.stepB = 1;
159 | for (int i = m_attribs.axis + 1; i < x.dims(); i++)
160 | p.stepB *= (int)x.dim_size(i);
161 |
162 | Tensor* y = NULL; // x.shape
163 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
164 | p.y = y->flat<T>().data();
165 |
166 | p.loopX = 4;
167 | int blockSize = 4 * 32;
168 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
169 | void* args[] = {&p};
170 | OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
171 | }
172 | };
173 |
174 | REGISTER_OP("FusedBiasAct")
175 | .Input ("x: T")
176 | .Input ("b: T")
177 | .Input ("ref: T")
178 | .Output ("y: T")
179 | .Attr ("T: {float, half}")
180 | .Attr ("grad: int = 0")
181 | .Attr ("axis: int = 1")
182 | .Attr ("act: int = 0")
183 | .Attr ("alpha: float = 0.0")
184 | .Attr ("gain: float = 1.0");
185 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
186 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
187 |
188 | //------------------------------------------------------------------------
189 |
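Two pieces of index arithmetic in the kernel above are easy to misread: the bias lookup `(xi / p.stepB) % p.sizeB`, where `stepB` is the product of the dimensions after `axis`, and the grid size `(sizeX - 1) / (loopX * blockSize) + 1`, which is a ceiling division. The sketch below reproduces both in Python for readability; the tensor shape is illustrative and the function names are hypothetical.

```python
def bias_index(xi, step_b, size_b):
    """Index into the bias vector for flat element index xi."""
    return (xi // step_b) % size_b

def launch_config(size_x, loop_x=4, block_size=4 * 32):
    """Number of blocks so that grid * block * loop covers size_x elements."""
    return (size_x - 1) // (loop_x * block_size) + 1

shape = [2, 3, 2, 2]          # illustrative NCHW tensor, bias along axis=1
axis = 1
step_b = 1
for dim in shape[axis + 1:]:
    step_b *= dim             # product of the dimensions after `axis`
size_b = shape[axis]

indices = [bias_index(xi, step_b, size_b) for xi in range(2 * 3 * 2 * 2)]
print(step_b, indices)
print(launch_config(24), launch_config(512), launch_config(513))
```

Each run of `step_b` consecutive elements shares one bias entry, and the pattern repeats for every sample in the batch.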
--------------------------------------------------------------------------------
/cuda/fused_bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Custom TensorFlow ops for efficient bias and activation."""
8 |
9 | import os
10 | import numpy as np
11 | import tensorflow as tf
12 | from cuda import custom_ops
13 | from utils import EasyDict
14 |
15 | def _get_plugin():
16 | return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | activation_funcs = {
21 | 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True),
22 | 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True),
23 | 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True),
24 | 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False),
25 | 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False),
26 | 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False),
27 | 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False),
28 | 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False),
29 | 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False),
30 | }
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'):
35 | r"""Fused bias and activation function.
36 |
37 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
38 | and scales the result by `gain`. Each of the steps is optional. In most cases,
39 | the fused op is considerably more efficient than performing the same calculation
40 | using standard TensorFlow ops. It supports first and second order gradients,
41 | but not third order gradients.
42 |
43 | Args:
44 | x: Input activation tensor. Can have any shape, but if `b` is defined, the
45 | dimension corresponding to `axis`, as well as the rank, must be known.
46 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
47 | as `x`. The shape must be known, and it must match the dimension of `x`
48 | corresponding to `axis`.
49 | axis: The dimension in `x` corresponding to the elements of `b`.
50 | The value of `axis` is ignored if `b` is not specified.
51 | act: Name of the activation function to evaluate, or `"linear"` to disable.
52 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
53 | See `activation_funcs` for a full list. `None` is not allowed.
54 | alpha: Shape parameter for the activation function, or `None` to use the default.
55 | gain: Scaling factor for the output tensor, or `None` to use default.
56 | See `activation_funcs` for the default scaling of each activation function.
57 | If unsure, consider specifying `1.0`.
58 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
59 |
60 | Returns:
61 | Tensor of the same shape and datatype as `x`.
62 | """
63 |
64 | impl_dict = {
65 | 'ref': _fused_bias_act_ref,
66 | 'cuda': _fused_bias_act_cuda,
67 | }
68 | return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
69 |
70 | #----------------------------------------------------------------------------
71 |
72 | def _fused_bias_act_ref(x, b, axis, act, alpha, gain):
73 | """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops."""
74 |
75 | # Validate arguments.
76 | x = tf.convert_to_tensor(x)
77 | b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype)
78 | act_spec = activation_funcs[act]
79 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
80 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
81 | if alpha is None:
82 | alpha = act_spec.def_alpha
83 | if gain is None:
84 | gain = act_spec.def_gain
85 |
86 | # Add bias.
87 | if b.shape[0] != 0:
88 | x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)])
89 |
90 | # Evaluate activation function.
91 | x = act_spec.func(x, alpha=alpha)
92 |
93 | # Scale by gain.
94 | if gain != 1:
95 | x *= gain
96 | return x
97 |
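The reference implementation applies three steps in order: add bias along `axis`, evaluate the activation, then scale by `gain` (defaulting to the per-activation value from `activation_funcs`). A dependency-free sketch of those semantics for a `[N, C]` nested list with `axis=1` follows; it covers only three activations and `bias_act_ref` is a hypothetical name, not a drop-in replacement.

```python
import math

def bias_act_ref(rows, bias=None, act='linear', alpha=0.2, gain=None):
    """Apply bias -> activation -> gain to a [N, C] nested list (axis=1)."""
    acts = {
        # name: (function, default gain), mirroring the table above
        'linear': (lambda v: v, 1.0),
        'relu': (lambda v: max(v, 0.0), math.sqrt(2)),
        'lrelu': (lambda v: v if v > 0.0 else v * alpha, math.sqrt(2)),
    }
    func, default_gain = acts[act]
    if gain is None:
        gain = default_gain
    out = []
    for row in rows:
        biased = [v + (bias[c] if bias is not None else 0.0) for c, v in enumerate(row)]
        out.append([func(v) * gain for v in biased])
    return out

y = bias_act_ref([[1.0, -2.0]], bias=[0.5, 1.0], act='lrelu', gain=1.0)
print(y)
```

Passing `gain=1.0` explicitly, as the docstring suggests when unsure, disables the default `sqrt(2)` scaling used for `relu` and `lrelu`.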
98 | #----------------------------------------------------------------------------
99 |
100 | def _fused_bias_act_cuda(x, b, axis, act, alpha, gain):
101 | """Fast CUDA implementation of `fused_bias_act()` using custom ops."""
102 |
103 | # Validate arguments.
104 | x = tf.convert_to_tensor(x)
105 | empty_tensor = tf.constant([], dtype=x.dtype)
106 | b = tf.convert_to_tensor(b) if b is not None else empty_tensor
107 | act_spec = activation_funcs[act]
108 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
109 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
110 | if alpha is None:
111 | alpha = act_spec.def_alpha
112 | if gain is None:
113 | gain = act_spec.def_gain
114 |
115 | # Special cases.
116 | if act == 'linear' and b.shape[0] == 0 and gain == 1.0: # `b` was already replaced by an empty tensor above, so test its size instead of `is None`
117 | return x
118 | if act_spec.cuda_idx is None:
119 | return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
120 |
121 | # CUDA kernel.
122 | cuda_kernel = _get_plugin().fused_bias_act
123 | cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain)
124 |
125 | # Forward pass: y = func(x, b).
126 | def func_y(x, b):
127 | y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs)
128 | y.set_shape(x.shape)
129 | return y
130 |
131 | # Backward pass: dx, db = grad(dy, x, y)
132 | def grad_dx(dy, x, y):
133 | ref = {'x': x, 'y': y}[act_spec.ref]
134 | dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs)
135 | dx.set_shape(x.shape)
136 | return dx
137 | def grad_db(dx):
138 | if b.shape[0] == 0:
139 | return empty_tensor
140 | db = dx
141 | if axis < x.shape.rank - 1:
142 | db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
143 | if axis > 0:
144 | db = tf.reduce_sum(db, list(range(axis)))
145 | db.set_shape(b.shape)
146 | return db
147 |
148 | # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
149 | def grad2_d_dy(d_dx, d_db, x, y):
150 | ref = {'x': x, 'y': y}[act_spec.ref]
151 | d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs)
152 | d_dy.set_shape(x.shape)
153 | return d_dy
154 | def grad2_d_x(d_dx, d_db, x, y):
155 | ref = {'x': x, 'y': y}[act_spec.ref]
156 | d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs)
157 | d_x.set_shape(x.shape)
158 | return d_x
159 |
160 | # Fast version for piecewise-linear activation funcs.
161 | @tf.custom_gradient
162 | def func_zero_2nd_grad(x, b):
163 | y = func_y(x, b)
164 | @tf.custom_gradient
165 | def grad(dy):
166 | dx = grad_dx(dy, x, y)
167 | db = grad_db(dx)
168 | def grad2(d_dx, d_db):
169 | d_dy = grad2_d_dy(d_dx, d_db, x, y)
170 | return d_dy
171 | return (dx, db), grad2
172 | return y, grad
173 |
174 | # Slow version for general activation funcs.
175 | @tf.custom_gradient
176 | def func_nonzero_2nd_grad(x, b):
177 | y = func_y(x, b)
178 | def grad_wrap(dy):
179 | @tf.custom_gradient
180 | def grad_impl(dy, x):
181 | dx = grad_dx(dy, x, y)
182 | db = grad_db(dx)
183 | def grad2(d_dx, d_db):
184 | d_dy = grad2_d_dy(d_dx, d_db, x, y)
185 | d_x = grad2_d_x(d_dx, d_db, x, y)
186 | return d_dy, d_x
187 | return (dx, db), grad2
188 | return grad_impl(dy, x)
189 | return y, grad_wrap
190 |
191 | # Which version to use?
192 | if act_spec.zero_2nd_grad:
193 | return func_zero_2nd_grad(x, b)
194 | return func_nonzero_2nd_grad(x, b)
195 |
196 | #----------------------------------------------------------------------------
197 |
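The reference path above adds a per-channel bias, applies the activation, and scales by a gain, while `grad_db` reduces the incoming gradient over every axis except the bias axis. The block below is a minimal pure-Python sketch of those two steps for `[N][C]` data with a leaky ReLU. The names `bias_act_ref` and `bias_grad` are hypothetical, and the `alpha` and `gain` defaults are illustrative assumptions rather than values read from `activation_funcs`.

```python
import math


def bias_act_ref(x, b, alpha=0.2, gain=math.sqrt(2.0)):
    """Add a per-channel bias, apply leaky ReLU, then scale by gain.

    x is a list of rows shaped [N][C]; b has length C, or is empty for no bias.
    alpha and gain are illustrative defaults (assumptions of this sketch).
    """
    out = []
    for row in x:
        new_row = []
        for c, v in enumerate(row):
            if b:  # an empty bias means "no bias", mirroring b.shape[0] == 0
                v = v + b[c]
            v = v if v >= 0 else alpha * v  # leaky ReLU
            new_row.append(v * gain)
        out.append(new_row)
    return out


def bias_grad(dx):
    """Sum dx over every axis except the channel axis, as grad_db does."""
    return [sum(col) for col in zip(*dx)]


y = bias_act_ref([[1.0, -2.0], [0.5, 0.0]], b=[0.0, 1.0], alpha=0.2, gain=2.0)
db = bias_grad([[1.0, 2.0], [3.0, 4.0]])
```

The sketch fixes the bias axis to the last one for brevity; the TensorFlow code reshapes `b` so that it broadcasts along an arbitrary `axis`.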
--------------------------------------------------------------------------------
/cuda/upfirdn_2d.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #define EIGEN_USE_GPU
8 | #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
9 | #include "tensorflow/core/framework/op.h"
10 | #include "tensorflow/core/framework/op_kernel.h"
11 | #include "tensorflow/core/framework/shape_inference.h"
12 | #include <stdio.h>
13 |
14 | using namespace tensorflow;
15 | using namespace tensorflow::shape_inference;
16 |
17 | //------------------------------------------------------------------------
18 | // Helpers.
19 |
20 | #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
21 |
22 | static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
23 | {
24 | int c = a / b;
25 | if (c * b > a)
26 | c--;
27 | return c;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 | // CUDA kernel params.
32 |
33 | template <class T>
34 | struct UpFirDn2DKernelParams
35 | {
36 | const T* x; // [majorDim, inH, inW, minorDim]
37 | const T* k; // [kernelH, kernelW]
38 | T* y; // [majorDim, outH, outW, minorDim]
39 |
40 | int upx;
41 | int upy;
42 | int downx;
43 | int downy;
44 | int padx0;
45 | int padx1;
46 | int pady0;
47 | int pady1;
48 |
49 | int majorDim;
50 | int inH;
51 | int inW;
52 | int minorDim;
53 | int kernelH;
54 | int kernelW;
55 | int outH;
56 | int outW;
57 | int loopMajor;
58 | int loopX;
59 | };
60 |
61 | //------------------------------------------------------------------------
62 | // General CUDA implementation for large filter kernels.
63 |
64 | template <class T>
65 | static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
66 | {
67 | // Calculate thread index.
68 | int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
69 | int outY = minorIdx / p.minorDim;
70 | minorIdx -= outY * p.minorDim;
71 | int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
72 | int majorIdxBase = blockIdx.z * p.loopMajor;
73 | if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
74 | return;
75 |
76 | // Setup Y receptive field.
77 | int midY = outY * p.downy + p.upy - 1 - p.pady0;
78 | int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
79 | int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
80 | int kernelY = midY + p.kernelH - (inY + 1) * p.upy;
81 |
82 | // Loop over majorDim and outX.
83 | for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
84 | for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
85 | {
86 | // Setup X receptive field.
87 | int midX = outX * p.downx + p.upx - 1 - p.padx0;
88 | int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
89 | int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
90 | int kernelX = midX + p.kernelW - (inX + 1) * p.upx;
91 |
92 | // Initialize pointers.
93 | const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
94 | const T* kp = &p.k[kernelY * p.kernelW + kernelX];
95 | int xpx = p.minorDim;
96 | int kpx = -p.upx;
97 | int xpy = p.inW * p.minorDim;
98 | int kpy = -p.upy * p.kernelW;
99 |
100 | // Inner loop.
101 | float v = 0.0f;
102 | for (int y = 0; y < h; y++)
103 | {
104 | for (int x = 0; x < w; x++)
105 | {
106 | v += (float)(*xp) * (float)(*kp);
107 | xp += xpx;
108 | kp += kpx;
109 | }
110 | xp += xpy - w * xpx;
111 | kp += kpy - w * kpx;
112 | }
113 |
114 | // Store result.
115 | p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
116 | }
117 | }
118 |
119 | //------------------------------------------------------------------------
120 | // Specialized CUDA implementation for small filter kernels.
121 |
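122 | template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
123 | static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)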
124 | {
125 | //assert(kernelW % upx == 0);
126 | //assert(kernelH % upy == 0);
127 | const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
128 | const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
129 | __shared__ volatile float sk[kernelH][kernelW];
130 | __shared__ volatile float sx[tileInH][tileInW];
131 |
132 | // Calculate tile index.
133 | int minorIdx = blockIdx.x;
134 | int tileOutY = minorIdx / p.minorDim;
135 | minorIdx -= tileOutY * p.minorDim;
136 | tileOutY *= tileOutH;
137 | int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
138 | int majorIdxBase = blockIdx.z * p.loopMajor;
139 | if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
140 | return;
141 |
142 | // Load filter kernel (flipped).
143 | for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
144 | {
145 | int ky = tapIdx / kernelW;
146 | int kx = tapIdx - ky * kernelW;
147 | float v = 0.0f;
148 | if (kx < p.kernelW & ky < p.kernelH)
149 | v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
150 | sk[ky][kx] = v;
151 | }
152 |
153 | // Loop over majorDim and outX.
154 | for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
155 | for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
156 | {
157 | // Load input pixels.
158 | int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
159 | int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
160 | int tileInX = floorDiv(tileMidX, upx);
161 | int tileInY = floorDiv(tileMidY, upy);
162 | __syncthreads();
163 | for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
164 | {
165 | int relInY = inIdx / tileInW;
166 | int relInX = inIdx - relInY * tileInW;
167 | int inX = relInX + tileInX;
168 | int inY = relInY + tileInY;
169 | float v = 0.0f;
170 | if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
171 | v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
172 | sx[relInY][relInX] = v;
173 | }
174 |
175 | // Loop over output pixels.
176 | __syncthreads();
177 | for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
178 | {
179 | int relOutY = outIdx / tileOutW;
180 | int relOutX = outIdx - relOutY * tileOutW;
181 | int outX = relOutX + tileOutX;
182 | int outY = relOutY + tileOutY;
183 |
184 | // Setup receptive field.
185 | int midX = tileMidX + relOutX * downx;
186 | int midY = tileMidY + relOutY * downy;
187 | int inX = floorDiv(midX, upx);
188 | int inY = floorDiv(midY, upy);
189 | int relInX = inX - tileInX;
190 | int relInY = inY - tileInY;
191 | int kernelX = (inX + 1) * upx - midX - 1; // flipped
192 | int kernelY = (inY + 1) * upy - midY - 1; // flipped
193 |
194 | // Inner loop.
195 | float v = 0.0f;
196 | #pragma unroll
197 | for (int y = 0; y < kernelH / upy; y++)
198 | #pragma unroll
199 | for (int x = 0; x < kernelW / upx; x++)
200 | v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];
201 |
202 | // Store result.
203 | if (outX < p.outW & outY < p.outH)
204 | p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
205 | }
206 | }
207 | }
208 |
209 | //------------------------------------------------------------------------
210 | // TensorFlow op.
211 |
212 | template <class T>
213 | struct UpFirDn2DOp : public OpKernel
214 | {
215 | UpFirDn2DKernelParams<T> m_attribs;
216 |
217 | UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
218 | {
219 | memset(&m_attribs, 0, sizeof(m_attribs));
220 | OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
221 | OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
222 | OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
223 | OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
224 | OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
225 | OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
226 | OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
227 | OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
228 | OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
229 | OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
230 | }
231 |
232 | void Compute(OpKernelContext* ctx)
233 | {
234 | UpFirDn2DKernelParams<T> p = m_attribs;
235 | cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
236 |
237 | const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
238 | const Tensor& k = ctx->input(1); // [kernelH, kernelW]
239 | p.x = x.flat<T>().data();
240 | p.k = k.flat<T>().data();
241 | OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
242 | OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
243 | OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
244 | OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));
245 |
246 | p.majorDim = (int)x.dim_size(0);
247 | p.inH = (int)x.dim_size(1);
248 | p.inW = (int)x.dim_size(2);
249 | p.minorDim = (int)x.dim_size(3);
250 | p.kernelH = (int)k.dim_size(0);
251 | p.kernelW = (int)k.dim_size(1);
252 | OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));
253 |
254 | p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
255 | p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
256 | OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));
257 |
258 | Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
259 | TensorShape ys;
260 | ys.AddDim(p.majorDim);
261 | ys.AddDim(p.outH);
262 | ys.AddDim(p.outW);
263 | ys.AddDim(p.minorDim);
264 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
265 | p.y = y->flat<T>().data();
266 | OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));
267 |
268 | // Choose CUDA kernel to use.
269 | void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
270 | int tileOutW = -1;
271 | int tileOutH = -1;
272 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; }
273 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
274 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; }
275 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
276 | if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; }
277 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; }
278 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
279 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
280 | if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; }
281 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8>; tileOutW = 32; tileOutH = 8; }
282 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8>; tileOutW = 32; tileOutH = 8; }
283 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8>; tileOutW = 32; tileOutH = 8; }
284 | if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8>; tileOutW = 32; tileOutH = 8; }
285 |
286 | // Choose launch params.
287 | dim3 blockSize;
288 | dim3 gridSize;
289 | if (tileOutW > 0 && tileOutH > 0) // small
290 | {
291 | p.loopMajor = (p.majorDim - 1) / 16384 + 1;
292 | p.loopX = 1;
293 | blockSize = dim3(32 * 8, 1, 1);
294 | gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
295 | }
296 | else // large
297 | {
298 | p.loopMajor = (p.majorDim - 1) / 16384 + 1;
299 | p.loopX = 4;
300 | blockSize = dim3(4, 32, 1);
301 | gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
302 | }
303 |
304 | // Launch CUDA kernel.
305 | void* args[] = {&p};
306 | OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
307 | }
308 | };
309 |
310 | REGISTER_OP("UpFirDn2D")
311 | .Input ("x: T")
312 | .Input ("k: T")
313 | .Output ("y: T")
314 | .Attr ("T: {float, half}")
315 | .Attr ("upx: int = 1")
316 | .Attr ("upy: int = 1")
317 | .Attr ("downx: int = 1")
318 | .Attr ("downy: int = 1")
319 | .Attr ("padx0: int = 0")
320 | .Attr ("padx1: int = 0")
321 | .Attr ("pady0: int = 0")
322 | .Attr ("pady1: int = 0");
323 | REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
324 | REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
325 |
326 | //------------------------------------------------------------------------
327 |
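The kernels above appear to implement the usual "upfirdn" pipeline: upsample by zero insertion, pad, apply an FIR filter, then downsample. The block below is a slow one-dimensional pure-Python sketch of that pipeline, intended only to make the data flow and the output-size formula easier to follow. The function name `upfirdn_1d` is hypothetical, and the sketch assumes non-negative padding, which the CUDA code does not require.

```python
def upfirdn_1d(x, k, up=1, down=1, pad0=0, pad1=0):
    """Upsample by zero insertion, pad, FIR-filter, then downsample (1D sketch)."""
    # 1. Upsample: insert (up - 1) zeros after every input sample.
    upsampled = []
    for v in x:
        upsampled.append(float(v))
        upsampled.extend([0.0] * (up - 1))
    # 2. Pad with zeros (this sketch only handles non-negative padding).
    padded = [0.0] * pad0 + upsampled + [0.0] * pad1
    # 3. Filter: sliding dot product with the flipped kernel, i.e. a convolution.
    flipped = list(k)[::-1]
    n_valid = len(padded) - len(flipped) + 1
    filtered = [
        sum(padded[i + j] * flipped[j] for j in range(len(flipped)))
        for i in range(n_valid)
    ]
    # 4. Downsample: keep every `down`-th sample.
    return filtered[::down]


# [1, 3, 3, 1] normalized to sum 1, then scaled by the upsampling factor 2.
kernel = [2 * t / 8 for t in (1, 3, 3, 1)]
y = upfirdn_1d([1, 1, 1, 1], kernel, up=2, pad0=2, pad1=1)
print(len(y), y)
```

The result has `(inW * up + pad0 + pad1 - kernelW) // down + 1` samples, the same count the op computes for `outW`. A constant input stays constant in the interior; only the two border samples are attenuated by the zero padding.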
--------------------------------------------------------------------------------
/cuda/upfirdn_2d.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | from cuda import custom_ops
5 |
6 |
7 | def _get_plugin():
8 | loc = os.path.dirname(os.path.abspath(__file__))
9 | cu_fn = 'upfirdn_2d.cu'
10 | return custom_ops.get_plugin(os.path.join(loc, cu_fn))
11 |
12 |
13 | def _setup_kernel(k):
14 | k = np.asarray(k, dtype=np.float32)
15 | if k.ndim == 1:
16 | k = np.outer(k, k)
17 | k /= np.sum(k)
18 | assert k.ndim == 2
19 | assert k.shape[0] == k.shape[1]
20 | return k
21 |
22 |
23 | def compute_paddings(resample_kernel, convW, up, down, is_conv, factor=2, gain=1):
24 | assert not (up and down)
25 |
26 | k = [1] * factor if resample_kernel is None else resample_kernel
27 | if up:
28 | k = _setup_kernel(k) * (gain * (factor ** 2))
29 | if is_conv:
30 | p = (k.shape[0] - factor) - (convW - 1)
31 | pad0 = (p + 1) // 2 + factor - 1
32 | pad1 = p // 2 + 1
33 | else:
34 | p = k.shape[0] - factor
35 | pad0 = (p + 1) // 2 + factor - 1
36 | pad1 = p // 2
37 | elif down:
38 | k = _setup_kernel(k) * gain
39 | if is_conv:
40 | p = (k.shape[0] - factor) + (convW - 1)
41 | pad0 = (p + 1) // 2
42 | pad1 = p // 2
43 | else:
44 | p = k.shape[0] - factor
45 | pad0 = (p + 1) // 2
46 | pad1 = p // 2
47 | else:
48 | k = resample_kernel
49 | pad0, pad1 = 0, 0
50 | return k, pad0, pad1
51 |
52 |
53 | def upsample_2d(x, pad0, pad1, k, factor=2):
54 | assert isinstance(factor, int) and factor >= 1
55 | x_res = x.shape[2]
56 | return _simple_upfirdn_2d(x, x_res, k, up=factor, pad0=pad0, pad1=pad1)
57 |
58 |
59 | def downsample_2d(x, pad0, pad1, k, factor=2):
60 | assert isinstance(factor, int) and factor >= 1
61 | x_res = x.shape[2]
62 | return _simple_upfirdn_2d(x, x_res, k, down=factor, pad0=pad0, pad1=pad1)
63 |
64 |
65 | def upsample_conv_2d(x, w, convH, convW, pad0, pad1, k, factor=2):
66 | assert isinstance(factor, int) and factor >= 1
67 |
68 | x_res = x.shape[2]
69 | # Check weight shape.
70 | w = tf.convert_to_tensor(w)
71 | assert w.shape.rank == 4
72 | # convH = w.shape[0]
73 | # convW = w.shape[1]
74 | inC = tf.shape(w)[2]
75 | outC = tf.shape(w)[3]
76 | assert convW == convH
77 |
78 | # Determine data dimensions.
79 | stride = [1, 1, factor, factor]
80 | output_shape = [tf.shape(x)[0], outC, (x_res - 1) * factor + convH, (x_res - 1) * factor + convW]
81 | num_groups = tf.shape(x)[1] // inC
82 |
83 | # Transpose weights.
84 | w = tf.reshape(w, [convH, convW, inC, num_groups, -1])
85 | w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
86 | w = tf.reshape(w, [convH, convW, -1, num_groups * inC])
87 |
88 | # Execute.
89 | x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format='NCHW')
90 | new_x_res = output_shape[2]
91 | return _simple_upfirdn_2d(x, new_x_res, k, pad0=pad0, pad1=pad1)
92 |
93 |
94 | def conv_downsample_2d(x, w, convH, convW, pad0, pad1, k, factor=2):
95 | assert isinstance(factor, int) and factor >= 1
96 | x_res = x.shape[2]
97 | w = tf.convert_to_tensor(w)
98 | # convH, convW, _inC, _outC = w.shape.as_list()
99 | assert convW == convH
100 |
101 | s = [1, 1, factor, factor]
102 | x = _simple_upfirdn_2d(x, x_res, k, pad0=pad0, pad1=pad1)
103 | return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format='NCHW')
104 |
105 |
106 | def _simple_upfirdn_2d(x, x_res, k, up=1, down=1, pad0=0, pad1=0):
107 | assert x.shape.rank == 4
108 | y = x
109 | y = tf.reshape(y, [-1, x_res, x_res, 1])
110 | y = upfirdn_2d_cuda(y, k, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1)
111 | y = tf.reshape(y, [-1, tf.shape(x)[1], tf.shape(y)[1], tf.shape(y)[2]])
112 | return y
113 |
114 |
115 | def upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
116 | """Fast CUDA implementation of `upfirdn_2d()` using custom ops."""
117 |
118 | x = tf.convert_to_tensor(x)
119 | k = np.asarray(k, dtype=np.float32)
120 | majorDim, inH, inW, minorDim = x.shape.as_list()
121 | kernelH, kernelW = k.shape
122 | assert inW >= 1 and inH >= 1
123 | assert kernelW >= 1 and kernelH >= 1
124 | assert isinstance(upx, int) and isinstance(upy, int)
125 | assert isinstance(downx, int) and isinstance(downy, int)
126 | assert isinstance(padx0, int) and isinstance(padx1, int)
127 | assert isinstance(pady0, int) and isinstance(pady1, int)
128 |
129 | outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
130 | outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
131 | assert outW >= 1 and outH >= 1
132 |
133 | kc = tf.constant(k, dtype=x.dtype)
134 | gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
135 | gpadx0 = kernelW - padx0 - 1
136 | gpady0 = kernelH - pady0 - 1
137 | gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
138 | gpady1 = inH * upy - outH * downy + pady0 - upy + 1
139 |
140 | @tf.custom_gradient
141 | def func(x):
142 | y = _get_plugin().up_fir_dn2d(x=x, k=kc, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
143 | y.set_shape([majorDim, outH, outW, minorDim])
144 | @tf.custom_gradient
145 | def grad(dy):
146 | dx = _get_plugin().up_fir_dn2d(x=dy, k=gkc, upx=downx, upy=downy, downx=upx, downy=upy, padx0=gpadx0, padx1=gpadx1, pady0=gpady0, pady1=gpady1)
147 | dx.set_shape([majorDim, inH, inW, minorDim])
148 | return dx, func
149 | return y, grad
150 | return func(x)
151 |
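The padding and output-size arithmetic in this file can be checked by hand. The block below is a small pure-Python sketch that reuses the `outW` formula and the `gpadx0`/`gpadx1` expressions from `upfirdn_2d_cuda` along one axis. The helper names `out_size` and `grad_pads` and the feature-map size of 16 are hypothetical; the paddings passed in are the values `compute_paddings` yields for the `[1, 3, 3, 1]` kernel with `factor=2` and `is_conv=False`.

```python
def out_size(in_size, kernel, up=1, down=1, pad0=0, pad1=0):
    """Output length along one axis, following the formula in upfirdn_2d_cuda."""
    return (in_size * up + pad0 + pad1 - kernel) // down + 1


def grad_pads(in_size, kernel, up, down, pad0, pad1):
    """Forward output size and backward paddings, following gpadx0/gpadx1."""
    out = out_size(in_size, kernel, up, down, pad0, pad1)
    gpad0 = kernel - pad0 - 1
    gpad1 = in_size * up - out * down + pad0 - up + 1
    return out, gpad0, gpad1


res, kernel = 16, 4  # hypothetical feature-map size; kernel length of [1, 3, 3, 1]

# Plain 2x upsampling: pad0=2, pad1=1.
up_res = out_size(res, kernel, up=2, pad0=2, pad1=1)

# Plain 2x downsampling: pad0=1, pad1=1.
down_res = out_size(res, kernel, down=2, pad0=1, pad1=1)

# Backward pass of the upsampling: up and down swap roles.
out, gpad0, gpad1 = grad_pads(res, kernel, up=2, down=1, pad0=2, pad1=1)
grad_res = out_size(out, kernel, up=1, down=2, pad0=gpad0, pad1=gpad1)
```

With these numbers the upsampled size is 32, the downsampled size is 8, and the backward call maps the 32-sample gradient back to the 16-sample input, which is why `dx.set_shape` can reuse `inH` and `inW`.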
--------------------------------------------------------------------------------
/generate_video.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from PIL import Image
3 | import numpy as np
4 |
5 | import math
6 | from tqdm import tqdm
7 |
8 | import torch
9 | from torchvision import utils
10 | import cv2
11 |
12 | from moviepy.editor import VideoFileClip
13 |
14 | def make_grid(images, res, rows, cols):
15 | images = (tf.clip_by_value(images, -1.0, 1.0) + 1.0) * 127.5
16 | images = tf.transpose(images, perm=[0, 2, 3, 1])
17 | images = tf.cast(images, tf.uint8)
18 | images = images.numpy()
19 |
20 | batch_size = images.shape[0]
21 | assert rows * cols == batch_size
22 | canvas = np.zeros(shape=[res * rows, res * cols, 3], dtype=np.uint8)
23 | for row in range(rows):
24 | y_start = row * res
25 | for col in range(cols):
26 | x_start = col * res
27 | index = col + row * cols
28 | canvas[y_start:y_start + res, x_start:x_start + res, :] = images[index, :, :, :]
29 |
30 | return canvas
31 |
32 | def load_generator(g_params=None, is_g_clone=True, ckpt_dir='checkpoint'):
33 |
34 | from networks import Generator
35 |
36 | if g_params is None:
37 | g_params = {
38 | 'z_dim': 512,
39 | 'w_dim': 512,
40 | 'labels_dim': 0,
41 | 'n_mapping': 8,
42 | 'resolutions': [4, 8, 16, 32, 64, 128, 256],
43 | 'featuremaps': [512, 512, 512, 512, 512, 256, 128],
44 | 'w_ema_decay': 0.995,
45 | 'style_mixing_prob': 0.9,
46 | }
47 |
48 | test_latent = tf.ones((1, g_params['z_dim']), dtype=tf.float32)
49 | test_labels = tf.ones((1, g_params['labels_dim']), dtype=tf.float32)
50 |
51 | # build generator model
52 | generator = Generator(g_params)
53 | _, _ = generator([test_latent, test_labels])
54 |
55 | if ckpt_dir is not None:
56 | if is_g_clone:
57 | ckpt = tf.train.Checkpoint(g_clone=generator)
58 | else:
59 | ckpt = tf.train.Checkpoint(generator=generator)
60 | manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1)
61 | ckpt.restore(manager.latest_checkpoint).expect_partial()
62 | if manager.latest_checkpoint:
63 | print(f'Generator restored from {manager.latest_checkpoint}')
64 |
65 | return generator
66 |
67 | def generate():
68 |
69 | generator = load_generator(is_g_clone=True)
70 | radius = 30 # 32
71 | pics = 120
72 | truncation_psi = 0.5 # 1.0
73 |
74 | sample_n = 16 # 4
75 | n_row = 4
76 | n_col = 4
77 | res = 256
78 | sample_z = tf.random.normal(shape=[sample_n, 512])
79 | images = []
80 | for i in tqdm(range(pics)):
81 | dh = math.sin(2 * math.pi * (i / pics)) * radius
82 | dw = math.cos(2 * math.pi * (i / pics)) * radius
83 |
84 | sample_tf, _ = generator([sample_z,
85 | tf.random.normal(shape=[sample_n, 0])],
86 | shift_h=dh, shift_w=dw,
87 | training=False, truncation_psi=truncation_psi)
88 | # Pytorch
89 |
90 | sample = sample_tf
91 | sample = sample.numpy()
92 | sample = torch.Tensor(sample)
93 | grid = utils.make_grid(
94 | sample.cpu(), normalize=True, nrow=n_col, value_range=(-1, 1)  # nrow = images per row
95 | )
96 | grid = grid.mul(255).permute(1, 2, 0).numpy().astype(np.uint8)
97 | images.append(
98 | grid
99 | )
100 |
101 |
102 | # Tensorflow
103 | # grid_tf = make_grid(sample_tf, res=res, rows=n_row, cols=n_col)
104 | # images.append(grid_tf)
105 |
106 |
107 | # Image save
108 | """
109 | for j in tqdm(range(sample_n)):
110 | f_name = 'images/{}_{}.png'.format(j, i)
111 | utils.save_image(
112 | sample[j].unsqueeze(0),
113 | f_name,
114 | nrow=1,
115 | normalize=True,
116 | range=(-1, 1),
117 | )
118 | """
119 |
120 | # To video
121 | videodims = (images[0].shape[1], images[0].shape[0])
122 | fourcc = cv2.VideoWriter_fourcc(*"VP90")
123 | video = cv2.VideoWriter("sample.webm", fourcc, 24, videodims)
124 |
125 | for i in tqdm(images):
126 | video.write(cv2.cvtColor(i, cv2.COLOR_RGB2BGR))
127 |
128 | video.release()
129 |
130 | # Video to GIF
131 | clip = VideoFileClip("sample.webm")
132 | clip.write_gif("sample.gif")
133 |
134 |
135 | generate()
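The loop in `generate()` moves the shift `(dh, dw)` once around a circle of the given radius, and `Synthesis.call` in `layers.py` then rescales a nonzero shift for each layer resolution by dividing it by `img_size // res`. The block below is a minimal pure-Python sketch of both steps without TensorFlow. The function names `circular_shifts` and `per_resolution_shift` are hypothetical.

```python
import math


def circular_shifts(n_frames, radius):
    """(dh, dw) per frame for one full circle, as in the generate() loop."""
    shifts = []
    for i in range(n_frames):
        angle = 2 * math.pi * (i / n_frames)
        shifts.append((math.sin(angle) * radius, math.cos(angle) * radius))
    return shifts


def per_resolution_shift(shift, img_size):
    """Scale a full-resolution shift for each layer resolution 4..img_size."""
    log_size = int(math.log2(img_size))
    return {2 ** i: shift / (img_size // (2 ** i)) for i in range(2, log_size + 1)}


shifts = circular_shifts(n_frames=120, radius=30)
scaled = per_resolution_shift(30.0, img_size=256)
```

A 30-pixel shift at 256x256 therefore corresponds to 30 / 64 of a pixel at the 4x4 layer, so every layer is displaced by the same fraction of its own resolution.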
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 |
3 | ##################################################################################
4 | # Synthesis Layers
5 | ##################################################################################
6 | class Synthesis(tf.keras.layers.Layer):
7 | def __init__(self, resolutions, featuremaps, name, **kwargs):
8 | super(Synthesis, self).__init__(name=name, **kwargs)
9 | self.resolutions = resolutions
10 | self.featuremaps = featuremaps
11 |
12 | self.k, self.pad0, self.pad1 = compute_paddings([1, 3, 3, 1], None, up=True, down=False, is_conv=False)
13 |
14 | # initial layer
15 | res, n_f = resolutions[0], featuremaps[0]
16 | self.img_size = resolutions[-1]
17 | self.log_size = int(np.log2(self.img_size))
18 |
19 | self.shift_h_dict = {4: 0}
20 | self.shift_w_dict = {4: 0}
21 | for i in range(3, self.log_size + 1):
22 | self.shift_h_dict[2 ** i] = 0
23 | self.shift_w_dict[2 ** i] = 0
24 |
25 | self.initial_block = SynthesisConstBlock(fmaps=n_f, name='{:d}x{:d}/const'.format(res, res))
26 | self.initial_torgb = ToRGB(in_ch=n_f, name='{:d}x{:d}/ToRGB'.format(res, res))
27 |
28 | # stack generator block with lerp block
29 | prev_n_f = n_f
30 | self.blocks = []
31 | self.torgbs = []
32 |
33 | for res, n_f in zip(self.resolutions[1:], self.featuremaps[1:]):
34 | self.blocks.append(SynthesisBlock(in_ch=prev_n_f, fmaps=n_f, res=res,
35 | name='{:d}x{:d}/block'.format(res, res)))
36 | self.torgbs.append(ToRGB(in_ch=n_f, name='{:d}x{:d}/ToRGB'.format(res, res)))
37 | prev_n_f = n_f
38 |
39 | def call(self, inputs, shift_h=0, shift_w=0, training=None, mask=None):
40 | ##### positional encoding #####
41 | # continuous roll
42 | if shift_h:
43 | for i in range(2, self.log_size + 1):
44 | self.shift_h_dict[2 ** i] = shift_h / (self.img_size // (2 ** i))
45 | if shift_w:
46 | for i in range(2, self.log_size + 1):
47 | self.shift_w_dict[2 ** i] = shift_w / (self.img_size // (2 ** i))
48 |
49 | w_broadcasted = inputs
50 |
51 | # initial layer
52 | w0, w1 = w_broadcasted[:, 0], w_broadcasted[:, 1]
53 |
54 | x = self.initial_block([w_broadcasted, w0], shift_h_dict=self.shift_h_dict, shift_w_dict=self.shift_w_dict)
55 | y = self.initial_torgb([x, w1])
56 |
57 | layer_index = 1
58 | for block, torgb in zip(self.blocks, self.torgbs):
59 | w0 = w_broadcasted[:, layer_index]
60 | w1 = w_broadcasted[:, layer_index + 1]
61 | w2 = w_broadcasted[:, layer_index + 2]
62 |
63 | x = block([x, w0, w1], shift_h_dict=self.shift_h_dict, shift_w_dict=self.shift_w_dict)
64 | y = upsample_2d(y, self.pad0, self.pad1, self.k)
65 | y = y + torgb([x, w2])
66 |
67 | layer_index += 2
68 |
69 | images_out = y
70 |
71 | return images_out
72 |
73 | # def get_config(self):
74 | # config = super(Synthesis, self).get_config()
75 | # config.update({
76 | # 'resolutions': self.resolutions,
77 | # 'featuremaps': self.featuremaps,
78 | # 'k': self.k,
79 | # 'pad0': self.pad0,
80 | # 'pad1': self.pad1,
81 | # })
82 | # return config
83 |
84 |
85 | class SynthesisConstBlock(tf.keras.layers.Layer):
86 | def __init__(self, fmaps, **kwargs):
87 | super(SynthesisConstBlock, self).__init__(**kwargs)
88 | self.res = 4
89 | self.fmaps = fmaps
90 | self.gain = 1.0
91 | self.lrmul = 1.0
92 |
93 | # conv block
94 | self.conv = ModulatedConv2D(fmaps=self.fmaps, style_fmaps=self.fmaps, kernel=3, up=False, down=False,
95 | demodulate=True, resample_kernel=[1, 3, 3, 1], gain=self.gain, lrmul=self.lrmul,
96 | fused_modconv=True, name='conv')
97 | self.apply_noise = Noise(name='noise')
98 | self.apply_bias_act = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias')
99 |
100 | self.pes_start = PE2dStart(512, 4, 4, scale=1.0)
101 |
102 | # def build(self, input_shape):
103 | # # starting const variable
104 | # # tf 1.15 mean(0.0), std(1.0) default value of tf.initializers.random_normal()
105 | # const_init = tf.random.normal(shape=(1, self.fmaps, self.res, self.res), mean=0.0, stddev=1.0)
106 | # self.const = tf.Variable(const_init, name='const', trainable=True)
107 |
108 | def call(self, inputs, shift_h_dict=None, shift_w_dict=None, training=None, mask=None):
109 | w_broadcasted, w0 = inputs
110 | batch_size = tf.shape(w0)[0]
111 |
112 | # const block
113 | # x = tf.tile(self.const, [batch_size, 1, 1, 1])
114 | x = self.pes_start(w_broadcasted, shift_h_dict[4], shift_w_dict[4])
115 |
116 | # conv block
117 | x = self.conv([x, w0])
118 | x = self.apply_noise(x)
119 | x = self.apply_bias_act(x)
120 | return x
121 |
122 |
123 | class SynthesisBlock(tf.keras.layers.Layer):
124 | def __init__(self, in_ch, fmaps, res, **kwargs):
125 | super(SynthesisBlock, self).__init__(**kwargs)
126 | self.in_ch = in_ch
127 | self.fmaps = fmaps
128 | self.gain = 1.0
129 | self.lrmul = 1.0
130 | self.res = res
131 |
132 | # conv0 up
133 | self.conv_0 = ModulatedConv2D(fmaps=self.fmaps, style_fmaps=self.in_ch, kernel=3, up=True, down=False,
134 | demodulate=True, resample_kernel=[1, 3, 3, 1], gain=self.gain, lrmul=self.lrmul,
135 | fused_modconv=True, name='conv_0')
136 | self.apply_noise_0 = Noise(name='noise_0')
137 | self.apply_bias_act_0 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_0')
138 |
139 | self.pes = PE2d(channel=fmaps, height=res, width=res, scale=1.0)
140 |
141 | # conv block
142 | self.conv_1 = ModulatedConv2D(fmaps=self.fmaps, style_fmaps=self.fmaps, kernel=3, up=False, down=False,
143 | demodulate=True, resample_kernel=[1, 3, 3, 1], gain=self.gain, lrmul=self.lrmul,
144 | fused_modconv=True, name='conv_1')
145 | self.apply_noise_1 = Noise(name='noise_1')
146 | self.apply_bias_act_1 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_1')
147 |
148 | def call(self, inputs, shift_h_dict=None, shift_w_dict=None, training=None, mask=None):
149 | x, w0, w1 = inputs
150 |
151 | # conv0 up
152 | x = self.conv_0([x, w0])
153 | x = self.apply_noise_0(x)
154 | x = self.apply_bias_act_0(x)
155 |
156 | # positional encoding
157 | x = self.pes(x, shift_h=shift_h_dict[self.res], shift_w=shift_w_dict[self.res])
158 |
159 | # conv block
160 | x = self.conv_1([x, w1])
161 | x = self.apply_noise_1(x)
162 | x = self.apply_bias_act_1(x)
163 |
164 | return x
165 |
166 | # def get_config(self):
167 | # config = super(SynthesisBlock, self).get_config()
168 | # config.update({
169 | # 'in_ch': self.in_ch,
170 | # 'res': self.res,
171 | # 'fmaps': self.fmaps,
172 | # 'gain': self.gain,
173 | # 'lrmul': self.lrmul,
174 | # })
175 | # return config
176 |
177 | ##################################################################################
178 | # Discriminator Layers
179 | ##################################################################################
180 | class DiscriminatorBlock(tf.keras.layers.Layer):
181 | def __init__(self, n_f0, n_f1, **kwargs):
182 | super(DiscriminatorBlock, self).__init__(**kwargs)
183 | self.gain = 1.0
184 | self.lrmul = 1.0
185 | self.n_f0 = n_f0
186 | self.n_f1 = n_f1
187 | self.resnet_scale = 1. / tf.sqrt(2.)
188 |
189 | # conv_0
190 | self.conv_0 = Conv2D(fmaps=self.n_f0, kernel=3, up=False, down=False,
191 | resample_kernel=None, gain=self.gain, lrmul=self.lrmul, name='conv_0')
192 | self.apply_bias_act_0 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_0')
193 |
194 | # conv_1 down
195 | self.conv_1 = Conv2D(fmaps=self.n_f1, kernel=3, up=False, down=True,
196 | resample_kernel=[1, 3, 3, 1], gain=self.gain, lrmul=self.lrmul, name='conv_1')
197 | self.apply_bias_act_1 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_1')
198 |
199 | # resnet skip
200 | self.conv_skip = Conv2D(fmaps=self.n_f1, kernel=1, up=False, down=True,
201 | resample_kernel=[1, 3, 3, 1], gain=self.gain, lrmul=self.lrmul, name='skip')
202 |
203 | def call(self, inputs, training=None, mask=None):
204 | x = inputs
205 | residual = x
206 |
207 | # conv0
208 | x = self.conv_0(x)
209 | x = self.apply_bias_act_0(x)
210 |
211 | # conv1 down
212 | x = self.conv_1(x)
213 | x = self.apply_bias_act_1(x)
214 |
215 | # resnet skip
216 | residual = self.conv_skip(residual)
217 | x = (x + residual) * self.resnet_scale
218 | return x
219 |
220 | # def get_config(self):
221 | # config = super(DiscriminatorBlock, self).get_config()
222 | # config.update({
223 | # 'n_f0': self.n_f0,
224 | # 'n_f1': self.n_f1,
225 | # 'gain': self.gain,
226 | # 'lrmul': self.lrmul,
227 | # 'res': self.res,
228 | # 'resnet_scale': self.resnet_scale,
229 | # })
230 | # return config
231 |
232 |
233 | class DiscriminatorLastBlock(tf.keras.layers.Layer):
234 | def __init__(self, n_f0, n_f1, **kwargs):
235 | super(DiscriminatorLastBlock, self).__init__(**kwargs)
236 | self.gain = 1.0
237 | self.lrmul = 1.0
238 | self.n_f0 = n_f0
239 | self.n_f1 = n_f1
240 |
241 | self.minibatch_std = MinibatchStd(group_size=4, num_new_features=1, name='minibatchstd')
242 |
243 | # conv_0
244 | self.conv_0 = Conv2D(fmaps=self.n_f0, kernel=3, up=False, down=False,
245 | resample_kernel=None, gain=self.gain, lrmul=self.lrmul, name='conv_0')
246 | self.apply_bias_act_0 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_0')
247 |
248 | # dense_1
249 | self.dense_1 = Dense(self.n_f1, gain=self.gain, lrmul=self.lrmul, name='dense_1')
250 | self.apply_bias_act_1 = BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_1')
251 |
252 | def call(self, x, training=None, mask=None):
253 | x = self.minibatch_std(x)
254 |
255 | # conv_0
256 | x = self.conv_0(x)
257 | x = self.apply_bias_act_0(x)
258 |
259 | # dense_1
260 | x = self.dense_1(x)
261 | x = self.apply_bias_act_1(x)
262 | return x
263 |
264 | # def get_config(self):
265 | # config = super(DiscriminatorLastBlock, self).get_config()
266 | # config.update({
267 | # 'n_f0': self.n_f0,
268 | # 'n_f1': self.n_f1,
269 | # 'gain': self.gain,
270 | # 'lrmul': self.lrmul,
271 | # 'res': self.res,
272 | # })
273 | # return config
--------------------------------------------------------------------------------
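The SynthesisBlock above adds a 2D sin/cos positional encoding (PE2d, defined in ops.py) whose sampling grid can be shifted per resolution. The following is a minimal pure-Python sketch of that encoding, written only to illustrate the channel layout and the integer/fractional shift handling. The names `roll` and `positional_encoding_2d` are hypothetical helpers, not part of the repository, and the learnable `gamma` scale and the `scale` argument are left out.

```python
import math


def roll(values, shift):
    """Circularly shift a list to the right by `shift` positions (like np.roll)."""
    shift %= len(values)
    return values[-shift:] + values[:-shift] if shift else list(values)


def positional_encoding_2d(channel, height, width, shift_h=0.0, shift_w=0.0):
    """Return pe[c][y][x]: the first half of the channels encodes x, the second half y."""
    if channel % 4 != 0:
        raise ValueError("channel must be divisible by 4, got {}".format(channel))
    d_model = channel // 2
    # One frequency per sin/cos channel pair.
    div_term = [math.exp(-i * math.log(10000.0) / d_model) for i in range(0, d_model, 2)]

    # The rounded shift rolls the grid; the fractional remainder offsets every position.
    pos_h = [p + (round(shift_h) - shift_h) for p in roll(list(range(height)), round(shift_h))]
    pos_w = [p + (round(shift_w) - shift_w) for p in roll(list(range(width)), round(shift_w))]

    pe = [[[0.0] * width for _ in range(height)] for _ in range(channel)]
    for j, div in enumerate(div_term):
        for y in range(height):
            for x in range(width):
                pe[2 * j][y][x] = math.sin(pos_w[x] * div)
                pe[2 * j + 1][y][x] = math.cos(pos_w[x] * div)
                pe[d_model + 2 * j][y][x] = math.sin(pos_h[y] * div)
                pe[d_model + 2 * j + 1][y][x] = math.cos(pos_h[y] * div)
    return pe


pe = positional_encoding_2d(channel=8, height=4, width=4)
pe_shifted = positional_encoding_2d(channel=8, height=4, width=4, shift_w=1)
```

With `shift_w=1`, column 0 of the shifted encoding holds the values that belonged to the last column of the unshifted grid, which is how the generator can be asked to render a translated view.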
/main.py:
--------------------------------------------------------------------------------
1 | from StyleGAN2 import StyleGAN2
2 | import argparse
3 | from utils import *
4 |
5 | def parse_args():
6 | desc = "Tensorflow implementation of StyleGAN2"
7 | parser = argparse.ArgumentParser(description=desc)
8 | parser.add_argument('--phase', type=str, default='train', help='[train, test, draw]')
9 | parser.add_argument('--draw', type=str, default='all', help='[uncurated, style_mix, truncation_trick, all]')
10 |
11 | parser.add_argument('--dataset', type=str, default='FFHQ', help='dataset_name')
12 |
13 | parser.add_argument('--batch_size', type=int, default=4, help='The batch size per GPU')
14 | parser.add_argument('--print_freq', type=int, default=2000, help='The image print frequency')
15 | parser.add_argument('--save_freq', type=int, default=10000, help='The checkpoint save frequency')
16 |
17 | parser.add_argument('--n_total_image', type=int, default=6400, help='The total iterations')
18 | parser.add_argument('--config', type=str, default='config-f', help='config-e or config-f')
19 | parser.add_argument('--lazy_regularization', type=str2bool, default=True, help='lazy_regularization')
20 |
21 | parser.add_argument('--img_size', type=int, default=256, help='The size of image')
22 |
23 | parser.add_argument('--n_test', type=int, default=10, help='The number of images generated by the test phase')
24 |
25 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
26 | help='Directory name to save the checkpoints')
27 | parser.add_argument('--result_dir', type=str, default='results',
28 | help='Directory name to save the generated images')
29 | parser.add_argument('--log_dir', type=str, default='logs',
30 | help='Directory name to save training logs')
31 | parser.add_argument('--sample_dir', type=str, default='samples',
32 | help='Directory name to save the samples on training')
33 |
34 | return check_args(parser.parse_args())
35 |
36 |
37 | """checking arguments"""
38 | def check_args(args):
39 | # --checkpoint_dir
40 | check_folder(args.checkpoint_dir)
41 |
42 | # --result_dir
43 | check_folder(args.result_dir)
44 |
45 | # --log_dir
46 | check_folder(args.log_dir)
47 |
48 | # --sample_dir
49 | check_folder(args.sample_dir)
50 |
51 | # --batch_size
52 | # reject invalid values instead of printing a warning and continuing
53 | if args.batch_size < 1:
54 | raise ValueError('batch size must be larger than or equal to one')
55 |
56 |
57 | return args
58 |
59 | """main"""
60 | def main():
61 |
62 | args = vars(parse_args())
63 |
64 | # network params
65 | img_size = args['img_size']
66 | resolutions = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
67 | if args['config'] == 'config-f':
68 | featuremaps = [512, 512, 512, 512, 512, 256, 128, 64, 32] # config-f
69 | else:
70 | featuremaps = [512, 512, 512, 512, 256, 128, 64, 32, 16] # config-e
71 | train_resolutions, train_featuremaps = filter_resolutions_featuremaps(resolutions, featuremaps, img_size)
72 | g_params = {
73 | 'z_dim': 512,
74 | 'w_dim': 512,
75 | 'labels_dim': 0,
76 | 'n_mapping': 8,
77 | 'resolutions': train_resolutions,
78 | 'featuremaps': train_featuremaps,
79 | 'w_ema_decay': 0.995,
80 | 'style_mixing_prob': 0.9,
81 | }
82 | d_params = {
83 | 'labels_dim': 0,
84 | 'resolutions': train_resolutions,
85 | 'featuremaps': train_featuremaps,
86 | }
87 |
88 | strategy = tf.distribute.MirroredStrategy()
89 | NUM_GPUS = strategy.num_replicas_in_sync
90 | batch_size = args['batch_size'] * NUM_GPUS # global batch size
91 |
92 | # training parameters
93 | training_parameters = {
94 | # global params
95 | **args,
96 |
97 | # network params
98 | 'g_params': g_params,
99 | 'd_params': d_params,
100 |
101 | # training params
102 | 'g_opt': {'learning_rate': 0.002, 'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08, 'reg_interval': 4},
103 | 'd_opt': {'learning_rate': 0.002, 'beta1': 0.0, 'beta2': 0.99, 'epsilon': 1e-08, 'reg_interval': 16},
104 | 'batch_size': batch_size,
105 | 'NUM_GPUS': NUM_GPUS,
106 | 'n_samples': 4,
107 | }
108 |
109 | # automatic_gpu_usage()
110 | with strategy.scope():
111 | gan = StyleGAN2(training_parameters, strategy)
112 |
113 | # build graph
114 | gan.build_model()
115 |
116 |
117 | if args['phase'] == 'train':
118 | gan.train()
119 | # gan.test_70000() # for FID evaluation ...
120 | print(" [*] Training finished!")
121 |
122 | if args['phase'] == 'test':
123 | gan.test()
124 | print(" [*] Test finished!")
125 |
126 | if args['phase'] == 'draw':
127 |
128 | if args['draw'] == 'style_mix':
129 |
130 | gan.draw_style_mixing_figure()
131 |
132 | print(" [*] Style mix finished!")
133 |
134 |
135 | elif args['draw'] == 'truncation_trick':
136 |
137 | gan.draw_truncation_trick_figure()
138 |
139 | print(" [*] Truncation_trick finished!")
140 |
141 |
142 | elif args['draw'] == 'uncurated':
143 | gan.draw_uncurated_result_figure()
144 |
145 | print(" [*] Un-curated finished!")
146 |
147 | else:
148 | gan.draw_uncurated_result_figure()
149 | print(" [*] Un-curated finished!")
150 | gan.draw_style_mixing_figure()
151 | print(" [*] Style mix finished!")
152 | gan.draw_truncation_trick_figure()
153 | print(" [*] Truncation_trick finished!")
154 |
155 |
156 |
157 | if __name__ == '__main__':
158 | main()
--------------------------------------------------------------------------------
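main.py trims the resolution and feature-map lists to the requested image size through `filter_resolutions_featuremaps`, which lives in utils.py and is not shown here. The sketch below is an assumption about what such a filter could look like (keep every pyramid level up to and including `img_size`); the function name with the `_sketch` suffix and the GPU count are hypothetical and chosen only for illustration.

```python
def filter_resolutions_featuremaps_sketch(resolutions, featuremaps, img_size):
    """Hypothetical stand-in: keep the levels whose resolution does not exceed img_size."""
    kept = [(r, f) for r, f in zip(resolutions, featuremaps) if r <= img_size]
    return [r for r, _ in kept], [f for _, f in kept]


resolutions = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
featuremaps = [512, 512, 512, 512, 512, 256, 128, 64, 32]  # config-f

train_res, train_fmaps = filter_resolutions_featuremaps_sketch(resolutions, featuremaps, 256)

# The --batch_size flag is multiplied by the replica count to get the global batch size.
per_gpu_batch = 4   # default of --batch_size
num_gpus = 2        # assumed value; main.py reads it from strategy.num_replicas_in_sync
global_batch = per_gpu_batch * num_gpus
```

Under these assumptions a 256x256 run keeps seven levels, so the generator would broadcast the latent to 14 style inputs (`len(resolutions) * 2` in networks.py).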
/networks.py:
--------------------------------------------------------------------------------
1 | from layers import *
2 | ##################################################################################
3 | # Generator Networks
4 | ##################################################################################
5 | class Generator(tf.keras.Model):
6 | def __init__(self, g_params, **kwargs):
7 | super(Generator, self).__init__(**kwargs)
8 |
9 | self.z_dim = g_params['z_dim']
10 | self.w_dim = g_params['w_dim']
11 | self.labels_dim = g_params['labels_dim']
12 | self.n_mapping = g_params['n_mapping']
13 | self.resolutions = g_params['resolutions']
14 | self.featuremaps = g_params['featuremaps']
15 | self.w_ema_decay = g_params['w_ema_decay']
16 | self.style_mixing_prob = g_params['style_mixing_prob']
17 |
18 | self.n_broadcast = len(self.resolutions) * 2
19 | self.mixing_layer_indices = np.arange(self.n_broadcast)[np.newaxis, :, np.newaxis]
20 |
21 | self.g_mapping = Mapping(self.w_dim, self.labels_dim, self.n_mapping, name='g_mapping')
22 | self.broadcast = tf.keras.layers.Lambda(lambda x: tf.tile(x[:, np.newaxis], [1, self.n_broadcast, 1]))
23 | self.synthesis = Synthesis(self.resolutions, self.featuremaps, name='g_synthesis')
24 |
25 |
26 |
27 | def build(self, input_shape):
28 | # w_avg
29 | self.w_avg = tf.Variable(tf.zeros(shape=[self.w_dim], dtype=tf.float32), name='w_avg', trainable=False,
30 | synchronization=tf.VariableSynchronization.ON_READ,
31 | aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
32 |
33 | @tf.function
34 | def set_as_moving_average_of(self, src_net, beta=0.99, beta_nontrainable=0.0):
35 | for cw, sw in zip(self.weights, src_net.weights):
36 | assert sw.shape == cw.shape
37 |
38 | if 'w_avg' in cw.name:
39 | cw.assign(lerp(sw, cw, beta_nontrainable))
40 | else:
41 | cw.assign(lerp(sw, cw, beta))
42 | return
43 |
44 | def update_moving_average_of_w(self, w_broadcasted):
45 | # compute average of current w
46 | batch_avg = tf.reduce_mean(w_broadcasted[:, 0], axis=0)
47 |
48 | # compute the moving average of w; the caller assigns the result to w_avg
49 | update_w_avg = lerp(batch_avg, self.w_avg, self.w_ema_decay)
50 |
51 | return update_w_avg
52 |
53 | def style_mixing_regularization(self, latents1, labels, w_broadcasted1):
54 | # get another w and broadcast it
55 | latents2 = tf.random.normal(shape=tf.shape(latents1), dtype=tf.float32)
56 | dlatents2 = self.g_mapping([latents2, labels])
57 | w_broadcasted2 = self.broadcast(dlatents2)
58 |
59 |
60 | # find mixing limit index
61 | if tf.random.uniform([], 0.0, 1.0) < self.style_mixing_prob:
62 | mixing_cutoff_index = tf.random.uniform([], 1, self.n_broadcast, dtype=tf.int32)
63 | else:
64 | mixing_cutoff_index = tf.constant(self.n_broadcast, dtype=tf.int32)
65 |
66 | # mix it
67 | mixed_w_broadcasted = tf.where(
68 | condition=tf.broadcast_to(self.mixing_layer_indices < mixing_cutoff_index, tf.shape(w_broadcasted1)),
69 | x=w_broadcasted1,
70 | y=w_broadcasted2)
71 |
72 | return mixed_w_broadcasted
73 |
74 | def truncation_trick(self, w_broadcasted, truncation_cutoff, truncation_psi):
75 | ones = tf.ones_like(self.mixing_layer_indices, dtype=tf.float32)
76 | tpsi = ones * truncation_psi
77 |
78 | if truncation_cutoff is None:
79 | truncation_coefs = tpsi
80 | else:
81 | # indices = tf.range(self.n_broadcast)
82 | indices = self.mixing_layer_indices
83 | truncation_coefs = tf.where(condition=tf.less(indices, truncation_cutoff), x=tpsi, y=ones)
84 |
85 | truncated_w_broadcasted = lerp(self.w_avg, w_broadcasted, truncation_coefs)
86 |
87 | return truncated_w_broadcasted
88 |
89 | def call(self, inputs, truncation_cutoff=None, truncation_psi=1.0, shift_h=0, shift_w=0, training=None, mapping=True, mask=None):
90 | latents, labels = inputs
91 |
92 | if mapping:
93 | dlatents = self.g_mapping([latents, labels])
94 | w_broadcasted = self.broadcast(dlatents)
95 |
96 | if training:
97 | self.w_avg.assign(self.update_moving_average_of_w(w_broadcasted))
98 | w_broadcasted = self.style_mixing_regularization(latents, labels, w_broadcasted)
99 |
100 | if not training:
101 | w_broadcasted = self.truncation_trick(w_broadcasted, truncation_cutoff, truncation_psi)
102 |
103 | else:
104 | w_broadcasted = latents
105 |
106 | image_out = self.synthesis(w_broadcasted, shift_h=shift_h, shift_w=shift_w)
107 |
108 | return image_out, w_broadcasted
109 |
110 | def compute_output_shape(self, input_shape):
111 | assert isinstance(input_shape, list)
112 |
113 | # shape_latents, shape_labels = input_shape
114 | return input_shape[0][0], 3, self.resolutions[-1], self.resolutions[-1]
115 |
116 |
117 | ##################################################################################
118 | # Discriminator Networks
119 | ##################################################################################
120 | class Discriminator(tf.keras.Model):
121 | def __init__(self, d_params, **kwargs):
122 | super(Discriminator, self).__init__(**kwargs)
123 | # the discriminator's resolutions and featuremaps are in reverse order relative to the generator's
124 | self.labels_dim = d_params['labels_dim']
125 | self.r_resolutions = d_params['resolutions'][::-1]
126 | self.r_featuremaps = d_params['featuremaps'][::-1]
127 |
128 | # stack discriminator blocks
129 | res0, n_f0 = self.r_resolutions[0], self.r_featuremaps[0]
130 | self.initial_fromrgb = FromRGB(fmaps=n_f0, name='{:d}x{:d}/FromRGB'.format(res0, res0))
131 | self.blocks = []
132 |
133 | for index, (res0, n_f0) in enumerate(zip(self.r_resolutions[:-1], self.r_featuremaps[:-1])):
134 | n_f1 = self.r_featuremaps[index + 1]
135 | self.blocks.append(DiscriminatorBlock(n_f0=n_f0, n_f1=n_f1, name='{:d}x{:d}'.format(res0, res0)))
136 |
137 | # set last discriminator block
138 | res = self.r_resolutions[-1]
139 | n_f0, n_f1 = self.r_featuremaps[-2], self.r_featuremaps[-1]
140 | self.last_block = DiscriminatorLastBlock(n_f0, n_f1, name='{:d}x{:d}'.format(res, res))
141 |
142 | # set last dense layer
143 | self.last_dense = Dense(max(self.labels_dim, 1), gain=1.0, lrmul=1.0, name='last_dense')
144 | self.last_bias = BiasAct(lrmul=1.0, act='linear', name='last_bias')
145 |
146 |
147 |
148 | # @ tf.function
149 | def call(self, inputs, training=None, mask=None):
150 | images, labels = inputs
151 |
152 | x = self.initial_fromrgb(images)
153 | for block in self.blocks:
154 | x = block(x)
155 |
156 | x = self.last_block(x)
157 |
158 | logit = self.last_dense(x)
159 | logit = self.last_bias(logit)
160 |
161 | if self.labels_dim > 0:
162 | logit = tf.reduce_sum(logit * labels, axis=1, keepdims=True)
163 |
164 | scores_out = logit
165 |
166 | return scores_out
167 |
168 | def compute_output_shape(self, input_shape):
169 | return input_shape[0][0], max(self.labels_dim, 1)
170 |
171 | ##################################################################################
172 | # Mapping Networks
173 | ##################################################################################
174 | class Mapping(tf.keras.layers.Layer):
175 | def __init__(self, w_dim, labels_dim, n_mapping, **kwargs):
176 | super(Mapping, self).__init__(**kwargs)
177 | self.w_dim = w_dim
178 | self.labels_dim = labels_dim
179 | self.n_mapping = n_mapping
180 | self.gain = 1.0
181 | self.lrmul = 0.01
182 |
183 | if self.labels_dim > 0:
184 | self.labels_embedding = LabelEmbedding(embed_dim=self.w_dim, name='labels_embedding')
185 |
186 | self.normalize = tf.keras.layers.Lambda(lambda x: x * tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + 1e-8))
187 |
188 | self.dense_layers = []
189 | self.bias_act_layers = []
190 |
191 | for ii in range(self.n_mapping):
192 | self.dense_layers.append(Dense(w_dim, gain=self.gain, lrmul=self.lrmul, name='dense_{:d}'.format(ii)))
193 | self.bias_act_layers.append(BiasAct(lrmul=self.lrmul, act='lrelu', name='bias_{:d}'.format(ii)))
194 |
195 | def call(self, inputs, training=None, mask=None):
196 | latents, labels = inputs
197 | x = latents
198 |
199 | # embed label if any
200 | if self.labels_dim > 0:
201 | y = self.labels_embedding(labels)
202 | x = tf.concat([x, y], axis=1)
203 |
204 | # normalize inputs
205 | x = self.normalize(x)
206 |
207 | # apply mapping blocks
208 | for dense, apply_bias_act in zip(self.dense_layers, self.bias_act_layers):
209 | x = dense(x)
210 | x = apply_bias_act(x)
211 |
212 | return x
213 |
214 | # def get_config(self):
215 | # config = super(Mapping, self).get_config()
216 | # config.update({
217 | # 'w_dim': self.w_dim,
218 | # 'labels_dim': self.labels_dim,
219 | # 'n_mapping': self.n_mapping,
220 | # 'gain': self.gain,
221 | # 'lrmul': self.lrmul,
222 | # })
223 | # return config
224 |
--------------------------------------------------------------------------------
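The Generator above relies on `lerp` (defined in ops.py) for both the moving average of w and the truncation trick, and on a per-layer cutoff for style mixing. The following pure-Python sketch restates those three steps on plain lists so the arithmetic is easy to check by hand. The helper names `ema_update`, `truncate`, and `mix_styles` are hypothetical and do not exist in the repository.

```python
def lerp(a, b, t):
    """Linear interpolation, same formula as ops.lerp: a + (b - a) * t."""
    return a + (b - a) * t


def ema_update(batch_avg, w_avg, decay):
    """lerp(batch_avg, w_avg, decay) equals decay * w_avg + (1 - decay) * batch_avg."""
    return [lerp(b, a, decay) for b, a in zip(batch_avg, w_avg)]


def truncate(w_layers, w_avg, psi, cutoff=None):
    """Pull each layer's w toward w_avg by psi; layers at or past the cutoff are untouched."""
    out = []
    for layer_idx, w in enumerate(w_layers):
        coef = psi if cutoff is None or layer_idx < cutoff else 1.0
        out.append([lerp(avg, v, coef) for avg, v in zip(w_avg, w)])
    return out


def mix_styles(w1_layers, w2_layers, cutoff):
    """Use the first latent for layers below the cutoff and the second one afterwards."""
    return [w1 if i < cutoff else w2 for i, (w1, w2) in enumerate(zip(w1_layers, w2_layers))]


w_avg = [0.0, 0.0]
w_layers = [[2.0, 4.0] for _ in range(4)]

truncated = truncate(w_layers, w_avg, psi=0.5, cutoff=2)
mixed = mix_styles(w_layers, [[9.0, 9.0] for _ in range(4)], cutoff=1)
new_avg = ema_update(batch_avg=[1.0], w_avg=[0.0], decay=0.995)
```

Setting `psi=1.0` leaves w unchanged and `psi=0.0` collapses every truncated layer onto `w_avg`, which matches the default `truncation_psi=1.0` in `Generator.call` being a no-op.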
/ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from cuda.upfirdn_2d import *
4 | from cuda.fused_bias_act import fused_bias_act
5 |
6 | def compute_runtime_coef(weight_shape, gain, lrmul):
7 | fan_in = tf.reduce_prod(weight_shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]
8 | fan_in = tf.cast(fan_in, dtype=tf.float32)
9 | he_std = gain / tf.sqrt(fan_in)
10 | init_std = 1.0 / lrmul
11 | runtime_coef = he_std * lrmul
12 | return init_std, runtime_coef
13 |
14 | def lerp(a, b, t):
15 | out = a + (b - a) * t
16 | return out
17 |
18 | def lerp_clip(a, b, t):
19 | out = a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
20 | return out
21 |
22 | ##################################################################################
23 | # Layers
24 | ##################################################################################
25 |
26 | class Conv2D(tf.keras.layers.Layer):
27 | def __init__(self, fmaps, kernel, up, down, resample_kernel, gain, lrmul, **kwargs):
28 | super(Conv2D, self).__init__(**kwargs)
29 | self.fmaps = fmaps
30 | self.kernel = kernel
31 | self.gain = gain
32 | self.lrmul = lrmul
33 | self.up = up
34 | self.down = down
35 |
36 | self.k, self.pad0, self.pad1 = compute_paddings(resample_kernel, self.kernel, up, down, is_conv=True)
37 |
38 | def build(self, input_shape):
39 | weight_shape = [self.kernel, self.kernel, input_shape[1], self.fmaps]
40 | init_std, self.runtime_coef = compute_runtime_coef(weight_shape, self.gain, self.lrmul)
41 |
42 | # [kernel, kernel, fmaps_in, fmaps_out]
43 | w_init = tf.random.normal(shape=weight_shape, mean=0.0, stddev=init_std)
44 | self.w = tf.Variable(w_init, name='w', trainable=True)
45 |
46 | def call(self, inputs, training=None, mask=None):
47 | x = inputs
48 | w = self.runtime_coef * self.w
49 |
50 | # actual conv
51 | if self.up:
52 | x = upsample_conv_2d(x, w, self.kernel, self.kernel, self.pad0, self.pad1, self.k)
53 | elif self.down:
54 | x = conv_downsample_2d(x, w, self.kernel, self.kernel, self.pad0, self.pad1, self.k)
55 | else:
56 | x = tf.nn.conv2d(x, w, data_format='NCHW', strides=[1, 1, 1, 1], padding='SAME')
57 | return x
58 |
59 | # def get_config(self):
60 | # config = super(Conv2D, self).get_config()
61 | # config.update({
62 | # 'in_res': self.in_res,
63 | # 'in_fmaps': self.in_fmaps,
64 | # 'fmaps': self.fmaps,
65 | # 'kernel': self.kernel,
66 | # 'gain': self.gain,
67 | # 'lrmul': self.lrmul,
68 | # 'up': self.up,
69 | # 'down': self.down,
70 | # 'k': self.k,
71 | # 'pad0': self.pad0,
72 | # 'pad1': self.pad1,
73 | # 'runtime_coef': self.runtime_coef,
74 | # })
75 | # return config
76 |
77 | class ModulatedConv2D(tf.keras.layers.Layer):
78 | def __init__(self, fmaps, style_fmaps, kernel, up, down, demodulate, resample_kernel, gain, lrmul, fused_modconv, **kwargs):
79 | super(ModulatedConv2D, self).__init__(**kwargs)
80 | assert not (up and down)
81 |
82 | self.fmaps = fmaps
83 | self.style_fmaps = style_fmaps
84 | self.kernel = kernel
85 | self.demodulate = demodulate
86 | self.up = up
87 | self.down = down
88 | self.fused_modconv = fused_modconv
89 | self.gain = gain
90 | self.lrmul = lrmul
91 |
92 | self.k, self.pad0, self.pad1 = compute_paddings(resample_kernel, self.kernel, up, down, is_conv=True)
93 |
94 | # self.factor = 2
95 | self.mod_dense = Dense(self.style_fmaps, gain=1.0, lrmul=1.0, name='mod_dense')
96 | self.mod_bias = BiasAct(lrmul=1.0, act='linear', name='mod_bias')
97 |
98 | def build(self, input_shape):
99 | x_shape, w_shape = input_shape[0], input_shape[1]
100 | in_fmaps = x_shape[1]
101 | weight_shape = [self.kernel, self.kernel, in_fmaps, self.fmaps]
102 | init_std, self.runtime_coef = compute_runtime_coef(weight_shape, self.gain, self.lrmul)
103 |
104 | # [kkIO]
105 | w_init = tf.random.normal(shape=weight_shape, mean=0.0, stddev=init_std)
106 | self.w = tf.Variable(w_init, name='w', trainable=True)
107 |
108 | def scale_conv_weights(self, w):
109 | # convolution kernel weights for fused conv
110 | weight = self.runtime_coef * self.w # [kkIO]
111 | weight = weight[np.newaxis] # [BkkIO]
112 |
113 | # modulation
114 | style = self.mod_dense(w) # [BI]
115 | style = self.mod_bias(style) + 1.0 # [BI]
116 | weight *= style[:, np.newaxis, np.newaxis, :, np.newaxis] # [BkkIO]
117 |
118 | # demodulation
119 | d = None
120 | if self.demodulate:
121 | d = tf.math.rsqrt(tf.reduce_sum(tf.square(weight), axis=[1, 2, 3]) + 1e-8) # [BO]
122 | weight *= d[:, np.newaxis, np.newaxis, np.newaxis, :] # [BkkIO]
123 |
124 | return weight, style, d
125 |
126 | def call(self, inputs, training=None, mask=None):
127 | x, y = inputs
128 | # height, width = tf.shape(x)[2], tf.shape(x)[3]
129 |
130 | # prepare weights: [BkkIO] Introduce minibatch dimension
131 | # prepare convolution kernel weights
132 | weight, style, d = self.scale_conv_weights(y)
133 |
134 | if self.fused_modconv:
135 | # Fused => reshape minibatch to convolution groups
136 | x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]])
137 |
138 | # weight: reshape, prepare for fused operation
139 | new_weight_shape = [tf.shape(weight)[1], tf.shape(weight)[2], tf.shape(weight)[3], -1] # [kkI(BO)]
140 | weight = tf.transpose(weight, [1, 2, 3, 0, 4]) # [kkIBO]
141 | weight = tf.reshape(weight, shape=new_weight_shape) # [kkI(BO)]
142 | else:
143 | # [BIhw] Not fused => scale input activations
144 | x *= style[:, :, tf.newaxis, tf.newaxis]
145 |
146 | # Convolution with optional up/downsampling.
147 | if self.up:
148 | x = upsample_conv_2d(x, weight, self.kernel, self.kernel, self.pad0, self.pad1, self.k)
149 | elif self.down:
150 | x = conv_downsample_2d(x, weight, self.kernel, self.kernel, self.pad0, self.pad1, self.k)
151 | else:
152 | x = tf.nn.conv2d(x, weight, data_format='NCHW', strides=[1, 1, 1, 1], padding='SAME')
153 |
154 | # Reshape/scale output
155 | if self.fused_modconv:
156 | # Fused => reshape convolution groups back to minibatch
157 | x_shape = tf.shape(x)
158 | x = tf.reshape(x, [-1, self.fmaps, x_shape[2], x_shape[3]])
159 | elif self.demodulate:
160 | # [BOhw] Not fused => scale output activations
161 | x *= d[:, :, tf.newaxis, tf.newaxis]
162 |
163 | return x
164 |
165 | # def get_config(self):
166 | # config = super(ModulatedConv2D, self).get_config()
167 | # config.update({
168 | # 'in_res': self.in_res,
169 | # 'in_fmaps': self.in_fmaps,
170 | # 'fmaps': self.fmaps,
171 | # 'kernel': self.kernel,
172 | # 'demodulate': self.demodulate,
173 | # 'fused_modconv': self.fused_modconv,
174 | # 'gain': self.gain,
175 | # 'lrmul': self.lrmul,
176 | # 'up': self.up,
177 | # 'down': self.down,
178 | # 'k': self.k,
179 | # 'pad0': self.pad0,
180 | # 'pad1': self.pad1,
181 | # 'runtime_coef': self.runtime_coef,
182 | # })
183 | # return config
184 |
185 | class Dense(tf.keras.layers.Layer):
186 | def __init__(self, fmaps, gain, lrmul, **kwargs):
187 | super(Dense, self).__init__(**kwargs)
188 | self.fmaps = fmaps
189 | self.gain = gain
190 | self.lrmul = lrmul
191 |
192 | def build(self, input_shape):
193 | fan_in = tf.reduce_prod(input_shape[1:])
194 | weight_shape = [fan_in, self.fmaps]
195 | init_std, self.runtime_coef = compute_runtime_coef(weight_shape, self.gain, self.lrmul)
196 |
197 | w_init = tf.random.normal(shape=weight_shape, mean=0.0, stddev=init_std)
198 | self.w = tf.Variable(w_init, name='w', trainable=True)
199 |
200 | def call(self, inputs, training=None, mask=None):
201 | weight = self.runtime_coef * self.w
202 |
203 | c = tf.reduce_prod(tf.shape(inputs)[1:])
204 | x = tf.reshape(inputs, shape=[-1, c])
205 | x = tf.matmul(x, weight)
206 | return x
207 |
208 | # def get_config(self):
209 | # config = super(Dense, self).get_config()
210 | # config.update({
211 | # 'fmaps': self.fmaps,
212 | # 'gain': self.gain,
213 | # 'lrmul': self.lrmul,
214 | # 'runtime_coef': self.runtime_coef,
215 | # })
216 | # return config
217 |
218 | class LabelEmbedding(tf.keras.layers.Layer):
219 | def __init__(self, embed_dim, **kwargs):
220 | super(LabelEmbedding, self).__init__(**kwargs)
221 | self.embed_dim = embed_dim
222 |
223 | def build(self, input_shape):
224 | weight_shape = [input_shape[1], self.embed_dim]
225 | # match the TF 1.15 defaults of tf.initializers.random_normal(): mean=0.0, stddev=1.0
226 | w_init = tf.random.normal(shape=weight_shape, mean=0.0, stddev=1.0)
227 | self.w = tf.Variable(w_init, name='w', trainable=True)
228 |
229 | def call(self, inputs, training=None, mask=None):
230 | x = tf.matmul(inputs, self.w)
231 | return x
232 |
233 | # def get_config(self):
234 | # config = super(LabelEmbedding, self).get_config()
235 | # config.update({
236 | # 'embed_dim': self.embed_dim,
237 | # })
238 | # return config
239 |
240 | ##################################################################################
241 | # Blocks
242 | ##################################################################################
243 | class FromRGB(tf.keras.layers.Layer):
244 | def __init__(self, fmaps, **kwargs):
245 | super(FromRGB, self).__init__(**kwargs)
246 | self.fmaps = fmaps
247 |
248 | self.conv = Conv2D(fmaps=self.fmaps, kernel=1, up=False, down=False,
249 | resample_kernel=None, gain=1.0, lrmul=1.0, name='conv')
250 | self.apply_bias_act = BiasAct(lrmul=1.0, act='lrelu', name='bias')
251 |
252 | def call(self, inputs, training=None, mask=None):
253 | y = self.conv(inputs)
254 | y = self.apply_bias_act(y)
255 | return y
256 |
257 | # def get_config(self):
258 | # config = super(FromRGB, self).get_config()
259 | # config.update({
260 | # 'fmaps': self.fmaps,
261 | # 'res': self.res,
262 | # })
263 | # return config
264 |
265 | class ToRGB(tf.keras.layers.Layer):
266 | def __init__(self, in_ch, **kwargs):
267 | super(ToRGB, self).__init__(**kwargs)
268 | self.in_ch = in_ch
269 |
270 | self.conv = ModulatedConv2D(fmaps=3, style_fmaps=in_ch, kernel=1, up=False, down=False, demodulate=False,
271 | resample_kernel=None, gain=1.0, lrmul=1.0, fused_modconv=True, name='conv')
272 | self.apply_bias = BiasAct(lrmul=1.0, act='linear', name='bias')
273 |
274 | def call(self, inputs, training=None, mask=None):
275 | x, w = inputs
276 |
277 | x = self.conv([x, w])
278 | x = self.apply_bias(x)
279 | return x
280 |
281 | # def get_config(self):
282 | # config = super(ToRGB, self).get_config()
283 | # config.update({
284 | # 'in_ch': self.in_ch,
285 | # 'res': self.res,
286 | # })
287 | # return config
288 |
289 | class PE2d(tf.keras.layers.Layer):
290 | def __init__(self, channel, height, width, scale=1.0):
291 | super(PE2d, self).__init__()
292 | if channel % 4 != 0:
293 | raise ValueError("Cannot use sin/cos positional encoding with a "
294 | "channel count not divisible by 4 (got dim={:d})".format(channel))
295 |
296 | height = int(height * scale)
297 | width = int(width * scale)
298 | self.pe = np.zeros(shape=[channel, height, width], dtype=np.float32)
299 |
300 | # Each spatial dimension uses half of the channels
301 | self.d_model = int(channel / 2)
302 | self.div_term = np.exp(np.arange(0., self.d_model, 2.) * -(np.log(10000.) / self.d_model)) / scale
303 | self.pos_h = np.expand_dims(np.arange(0., height), axis=-1) # [height, 1]
304 | self.pos_w = np.expand_dims(np.arange(0., width), axis=-1) # [width, 1]
305 |
306 |
307 | self.gamma = tf.Variable(initial_value=tf.ones(shape=[1], dtype=tf.float32), trainable=True)
308 |
309 | def call(self, inputs, shift_h=0, shift_w=0, training=None, mask=None):
310 | pos_h = np.roll(self.pos_h, round(shift_h), 0) + (round(shift_h) - shift_h)
311 | pos_w = np.roll(self.pos_w, round(shift_w), 0) + (round(shift_w) - shift_w)
312 |
313 | self.pe[0:self.d_model:2, :, :] = np.tile(
314 | np.expand_dims(
315 | np.transpose(
316 | np.sin(pos_w * self.div_term),
317 | axes=[1, 0]),
318 | axis=1),
319 | reps=[1, pos_h.shape[0], 1])
320 |
321 | self.pe[1:self.d_model:2, :, :] = np.tile(
322 | np.expand_dims(
323 | np.transpose(
324 | np.cos(pos_w * self.div_term),
325 | axes=[1, 0]),
326 | axis=1),
327 | reps=[1, pos_h.shape[0], 1])
328 |
329 | self.pe[self.d_model::2, :, :] = np.tile(
330 | np.expand_dims(
331 | np.transpose(
332 | np.sin(pos_h * self.div_term),
333 | axes=[1, 0]),
334 | axis=2),
335 | reps=[1, 1, pos_w.shape[0]])
336 |
337 | self.pe[self.d_model + 1::2, :, :] = np.tile(
338 | np.expand_dims(
339 | np.transpose(
340 | np.cos(pos_h * self.div_term),
341 | axes=[1, 0]),
342 | axis=2),
343 | reps=[1, 1, pos_w.shape[0]])
344 |
345 | x = tf.cast(inputs, dtype=tf.float32) + self.gamma * np.expand_dims(self.pe, axis=0)
346 |
347 | return x
348 |
349 | class PE2dStart(tf.keras.layers.Layer):
350 | def __init__(self, channel, height, width, scale=1.0):
351 | super(PE2dStart, self).__init__()
352 | if channel % 4 != 0:
353 | raise ValueError("Cannot use sin/cos positional encoding with a "
354 | "channel count not divisible by 4 (got dim={:d})".format(channel))
355 |
356 | height = int(height * scale)
357 | width = int(width * scale)
358 | self.pe = np.zeros(shape=[channel, height, width], dtype=np.float32)
359 |
360 | # Each spatial dimension uses half of the channels
361 | self.d_model = int(channel / 2)
362 | self.div_term = np.exp(np.arange(0., self.d_model, 2.) * -(np.log(10000.) / self.d_model)) / scale
363 | self.pos_h = np.expand_dims(np.arange(0., height), axis=-1) # [height, 1]
364 | self.pos_w = np.expand_dims(np.arange(0., width), axis=-1) # [width, 1]
365 |
366 | def call(self, inputs, shift_h=0, shift_w=0, training=None, mask=None):
367 | pos_h = np.roll(self.pos_h, round(shift_h), 0) + (round(shift_h) - shift_h)
368 | pos_w = np.roll(self.pos_w, round(shift_w), 0) + (round(shift_w) - shift_w)
369 |
370 | self.pe[0:self.d_model:2, :, :] = np.tile(
371 | np.expand_dims(
372 | np.transpose(
373 | np.sin(pos_w * self.div_term),
374 | axes=[1, 0]),
375 | axis=1),
376 | reps=[1, pos_h.shape[0], 1])
377 |
378 | self.pe[1:self.d_model:2, :, :] = np.tile(
379 | np.expand_dims(
380 | np.transpose(
381 | np.cos(pos_w * self.div_term),
382 | axes=[1, 0]),
383 | axis=1),
384 | reps=[1, pos_h.shape[0], 1])
385 |
386 | self.pe[self.d_model::2, :, :] = np.tile(
387 | np.expand_dims(
388 | np.transpose(
389 | np.sin(pos_h * self.div_term),
390 | axes=[1, 0]),
391 | axis=2),
392 | reps=[1, 1, pos_w.shape[0]])
393 |
394 | self.pe[self.d_model + 1::2, :, :] = np.tile(
395 | np.expand_dims(
396 | np.transpose(
397 | np.cos(pos_h * self.div_term),
398 | axes=[1, 0]),
399 | axis=2),
400 | reps=[1, 1, pos_w.shape[0]])
401 |
402 | x = np.tile(np.expand_dims(self.pe, axis=0), reps=[inputs.shape[0], 1, 1, 1])
403 |
404 | return x
405 |
406 | class ConstantInput(tf.keras.layers.Layer):
407 | def __init__(self, channel, size=4):
408 | super(ConstantInput, self).__init__()
409 |
410 | const_init = tf.random.normal(shape=(1, channel, size, size), mean=0.0, stddev=1.0)
411 | self.const = tf.Variable(const_init, name='const', trainable=True)
412 |
413 | def call(self, inputs, training=None, mask=None):
414 | batch = inputs.shape[0]
415 | x = tf.tile(self.const, multiples=[batch, 1, 1, 1])
416 |
417 | return x
418 |
419 | ##################################################################################
420 | # etc
421 | ##################################################################################
422 | class BiasAct(tf.keras.layers.Layer):
423 | def __init__(self, lrmul, act, **kwargs):
424 | super(BiasAct, self).__init__(**kwargs)
425 | self.lrmul = lrmul
426 | self.act = act
427 |
428 | def build(self, input_shape):
429 | b_init = tf.zeros(shape=(input_shape[1],), dtype=tf.float32)
430 | self.b = tf.Variable(b_init, name='b', trainable=True)
431 |
432 | def call(self, inputs, training=None, mask=None):
433 | b = self.lrmul * self.b
434 | x = fused_bias_act(inputs, b=b, act=self.act, alpha=None, gain=None)
435 | return x
436 |
437 | # def get_config(self):
438 | # config = super(BiasAct, self).get_config()
439 | # config.update({
440 | # 'lrmul': self.lrmul,
441 | # 'act': self.act,
442 | # })
443 | # return config
444 |
445 | class Noise(tf.keras.layers.Layer):
446 | def __init__(self, **kwargs):
447 | super(Noise, self).__init__(**kwargs)
448 |
449 | def build(self, input_shape):
450 | self.noise_strength = tf.Variable(initial_value=0.0, dtype=tf.float32, trainable=True, name='w')
451 |
452 |
453 | def call(self, inputs, noise=None, training=None, mask=None):
454 | x_shape = tf.shape(inputs)
455 |
456 | # noise: [1, 1, x_shape[2], x_shape[3]] or None
457 | if noise is None:
458 | noise = tf.random.normal(shape=(x_shape[0], 1, x_shape[2], x_shape[3]), dtype=tf.float32)
459 |
460 | x = inputs + noise * self.noise_strength
461 | return x
462 |
463 | def get_config(self):
464 | config = super(Noise, self).get_config()
465 | config.update({})
466 | return config
467 |
468 | class MinibatchStd(tf.keras.layers.Layer):
469 | def __init__(self, group_size, num_new_features, **kwargs):
470 | super(MinibatchStd, self).__init__(**kwargs)
471 | self.group_size = group_size
472 | self.num_new_features = num_new_features
473 |
474 | def call(self, inputs, training=None, mask=None):
475 | s = tf.shape(inputs)
476 | group_size = tf.minimum(self.group_size, s[0])
477 |
478 | y = tf.reshape(inputs, [group_size, -1, self.num_new_features, s[1] // self.num_new_features, s[2], s[3]])
479 | y = tf.cast(y, tf.float32)
480 | y -= tf.reduce_mean(y, axis=0, keepdims=True)
481 | y = tf.reduce_mean(tf.square(y), axis=0)
482 | y = tf.sqrt(y + 1e-8)
483 | y = tf.reduce_mean(y, axis=[2, 3, 4], keepdims=True)
484 | y = tf.reduce_mean(y, axis=[2])
485 | y = tf.cast(y, inputs.dtype)
486 | y = tf.tile(y, [group_size, 1, s[2], s[3]])
487 |
488 | x = tf.concat([inputs, y], axis=1)
489 | return x
490 |
491 | def get_config(self):
492 | config = super(MinibatchStd, self).get_config()
493 | config.update({
494 | 'group_size': self.group_size,
495 | 'num_new_features': self.num_new_features,
496 | })
497 | return config
498 |
499 | def torch_normalization(x):
500 | x /= 255.
501 |
502 | r, g, b = tf.split(axis=-1, num_or_size_splits=3, value=x)
503 |
504 | mean = [0.485, 0.456, 0.406]
505 | std = [0.229, 0.224, 0.225]
506 |
507 | x = tf.concat(axis=-1, values=[
508 | (r - mean[0]) / std[0],
509 | (g - mean[1]) / std[1],
510 | (b - mean[2]) / std[2]
511 | ])
512 |
513 | return x
514 |
515 |
516 | def inception_processing(filename):
517 | x = tf.io.read_file(filename)
518 | img = tf.image.decode_jpeg(x, channels=3, dct_method='INTEGER_ACCURATE')
519 | img = tf.image.resize(img, [256, 256], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
520 | img = tf.image.resize(img, [299, 299], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
521 |
522 | img = torch_normalization(img)
523 | # img = tf.transpose(img, [2, 0, 1])
524 | return img
--------------------------------------------------------------------------------
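The 2D sinusoidal encoding that `PE2dStart` builds with `np.tile` and `np.expand_dims` can also be written with plain NumPy broadcasting. The following is a minimal sketch, not the repository's implementation: the function name `sincos_pe_2d` is hypothetical, and the `scale`, `shift_h`, and `shift_w` handling from the layer is deliberately left out.

```python
import numpy as np


def sincos_pe_2d(channel, height, width):
    """Return a [channel, height, width] sin/cos positional encoding.

    Hypothetical helper (not part of layers.py). The first half of the
    channels encodes the column index, the second half the row index.
    """
    if channel % 4 != 0:
        raise ValueError("channel must be divisible by 4 (got {:d})".format(channel))

    d_model = channel // 2  # channels per spatial dimension
    # Frequencies, shape [d_model / 2]
    div_term = np.exp(np.arange(0., d_model, 2.) * -(np.log(10000.) / d_model))
    pos_h = np.arange(0., height)[:, None]  # [height, 1]
    pos_w = np.arange(0., width)[:, None]   # [width, 1]

    pe = np.zeros((channel, height, width))
    # Width terms: [d_model / 2, 1, width], broadcast over the height axis
    pe[0:d_model:2] = np.sin(pos_w * div_term).T[:, None, :]
    pe[1:d_model:2] = np.cos(pos_w * div_term).T[:, None, :]
    # Height terms: [d_model / 2, height, 1], broadcast over the width axis
    pe[d_model::2] = np.sin(pos_h * div_term).T[:, :, None]
    pe[d_model + 1::2] = np.cos(pos_h * div_term).T[:, :, None]
    return pe


pe = sincos_pe_2d(channel=8, height=4, width=4)
print(pe.shape)
```

Broadcasting avoids materializing the tiled intermediates, but the resulting array has the same channel layout as the one assembled in `PE2dStart.call` when no shift or scale is applied.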
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 |
5 | import tensorflow as tf
6 | from glob import glob
7 |
8 | class Image_data:
9 |
10 | def __init__(self, img_size, z_dim, labels_dim, dataset_path):
11 | self.img_size = img_size
12 | self.z_dim = z_dim
13 | self.labels_dim = labels_dim
14 | self.dataset_path = dataset_path
15 |
16 |
17 | def image_processing(self, filename):
18 |
19 | x = tf.io.read_file(filename)
20 | x_decode = tf.image.decode_jpeg(x, channels=3, dct_method='INTEGER_ACCURATE')
21 | img = tf.image.resize(x_decode, [self.img_size, self.img_size], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
22 | img = preprocess_fit_train_image(img)
23 |
24 | latent = tf.random.normal(shape=(self.z_dim,), dtype=tf.float32)
25 | labels = tf.random.normal((self.labels_dim,), dtype=tf.float32)
26 |
27 | return img, latent, labels
28 |
29 | def preprocess(self):
30 |
31 | self.train_images = glob(os.path.join(self.dataset_path, '*.png')) + glob(os.path.join(self.dataset_path, '*.jpg'))
32 |
33 | def adjust_dynamic_range(images, range_in, range_out, out_dtype):
34 | scale = (range_out[1] - range_out[0]) / (range_in[1] - range_in[0])
35 | bias = range_out[0] - range_in[0] * scale
36 | images = images * scale + bias
37 | images = tf.clip_by_value(images, range_out[0], range_out[1])
38 | images = tf.cast(images, dtype=out_dtype)
39 | return images
40 |
41 | def random_flip_left_right(images):
42 | s = tf.shape(images)
43 | mask = tf.random.uniform([1, 1, 1], 0.0, 1.0)
44 | mask = tf.tile(mask, [s[0], s[1], s[2]]) # [h, w, c]
45 | images = tf.where(mask < 0.5, images, tf.reverse(images, axis=[1]))
46 | return images
47 |
48 | def preprocess_fit_train_image(images):
49 | images = adjust_dynamic_range(images, range_in=(0.0, 255.0), range_out=(-1.0, 1.0), out_dtype=tf.dtypes.float32)
50 | images = random_flip_left_right(images)
51 | images = tf.transpose(images, [2, 0, 1])
52 |
53 | return images
54 |
55 | def preprocess_image(images):
56 | images = adjust_dynamic_range(images, range_in=(0.0, 255.0), range_out=(-1.0, 1.0), out_dtype=tf.dtypes.float32)
57 | images = tf.transpose(images, [2, 0, 1])
58 |
59 | return images
60 |
61 | def postprocess_images(images):
62 | images = adjust_dynamic_range(images, range_in=(-1.0, 1.0), range_out=(0.0, 255.0), out_dtype=tf.dtypes.float32)
63 | images = tf.transpose(images, [0, 2, 3, 1])
64 | images = tf.cast(images, dtype=tf.dtypes.uint8)
65 | return images
66 |
67 | def merge_batch_images(images, res, rows, cols):
68 | batch_size = images.shape[0]
69 | assert rows * cols == batch_size
70 | canvas = np.zeros(shape=[res * rows, res * cols, 3], dtype=np.uint8)
71 | for row in range(rows):
72 | y_start = row * res
73 | for col in range(cols):
74 | x_start = col * res
75 | index = col + row * cols
76 | canvas[y_start:y_start + res, x_start:x_start + res, :] = images[index, :, :, :]
77 | return canvas
78 |
79 | def load_images(image_path, img_width, img_height, img_channel):
80 |
81 | # from PIL import Image
82 | if img_channel == 1 :
83 | img = cv2.imread(image_path, flags=cv2.IMREAD_GRAYSCALE)
84 | else :
85 | img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
86 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
87 |
88 | # img = cv2.resize(img, dsize=(img_width, img_height))
89 | img = tf.image.resize(img, [img_height, img_width], antialias=True, method=tf.image.ResizeMethod.BICUBIC)
90 | img = preprocess_image(img)
91 |
92 | if img_channel == 1 :
93 | img = np.expand_dims(img, axis=0)
94 | img = np.expand_dims(img, axis=-1)
95 | else :
96 | img = np.expand_dims(img, axis=0)
97 |
98 | return img
99 |
100 | def save_images(images, size, image_path):
101 | # size = [height, width]
102 | return imsave(postprocess_images(images), size, image_path)
103 |
104 | def imsave(images, size, path):
105 | images = merge(images, size)
106 | images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
107 |
108 | return cv2.imwrite(path, images)
109 |
110 | def merge(images, size):
111 | h, w = images.shape[1], images.shape[2]
112 | img = np.zeros((h * size[0], w * size[1], 3))
113 | for idx, image in enumerate(images):
114 | i = idx % size[1]
115 | j = idx // size[1]
116 | img[h*j:h*(j+1), w*i:w*(i+1), :] = image
117 |
118 | return img
119 |
120 | def str2bool(x):
121 | return x.lower() == 'true'
122 |
123 | def check_folder(log_dir):
124 | if not os.path.exists(log_dir):
125 | os.makedirs(log_dir)
126 | return log_dir
127 |
128 | def filter_resolutions_featuremaps(resolutions, featuremaps, res):
129 | index = resolutions.index(res)
130 | filtered_resolutions = resolutions[:index + 1]
131 | filtered_featuremaps = featuremaps[:index + 1]
132 | return filtered_resolutions, filtered_featuremaps
133 |
134 | def pytorch_xavier_weight_factor(gain=0.02) :
135 |
136 | factor = gain * gain
137 | mode = 'fan_avg'
138 |
139 | return factor, mode
140 |
141 | def pytorch_kaiming_weight_factor(a=0.0, activation_function='relu') :
142 |
143 | if activation_function == 'relu' :
144 | gain = np.sqrt(2.0)
145 | elif activation_function == 'leaky_relu' :
146 | gain = np.sqrt(2.0 / (1 + a ** 2))
147 | elif activation_function =='tanh' :
148 | gain = 5.0 / 3
149 | else :
150 | gain = 1.0
151 |
152 | factor = gain * gain
153 | mode = 'fan_in'
154 |
155 | return factor, mode
156 |
157 | def automatic_gpu_usage() :
158 | gpus = tf.config.experimental.list_physical_devices('GPU')
159 | if gpus:
160 | try:
161 | # Currently, memory growth needs to be the same across GPUs
162 | for gpu in gpus:
163 | tf.config.experimental.set_memory_growth(gpu, True)
164 | logical_gpus = tf.config.experimental.list_logical_devices('GPU')
165 | print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
166 | except RuntimeError as e:
167 | # Memory growth must be set before GPUs have been initialized
168 | print(e)
169 |
170 | def multiple_gpu_usage():
171 | gpus = tf.config.experimental.list_physical_devices('GPU')
172 | if gpus:
173 | # Create 2 virtual GPUs with 4GB (4096 MB) memory each
174 | try:
175 | tf.config.experimental.set_virtual_device_configuration(
176 | gpus[0],
177 | [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096),
178 | tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)])
179 | logical_gpus = tf.config.experimental.list_logical_devices('GPU')
180 | print(len(gpus), "Physical GPU,", len(logical_gpus), "Logical GPUs")
181 | except RuntimeError as e:
182 | # Virtual devices must be set before GPUs have been initialized
183 | print(e)
184 |
185 | def get_batch_sizes(gpu_num) :
186 | from collections import OrderedDict
187 | # batch size for each gpu
188 |
189 | if gpu_num == 1:
190 | x = OrderedDict([(4, 256), (8, 256), (16, 128), (32, 64), (64, 32), (128, 16), (256, 8), (512, 4), (1024, 4)])
191 |
192 | elif gpu_num == 2 or gpu_num == 3:
193 | x = OrderedDict([(4, 128), (8, 128), (16, 64), (32, 32), (64, 16), (128, 8), (256, 4), (512, 4), (1024, 4)])
194 |
195 | elif gpu_num == 4 or gpu_num == 5 or gpu_num == 6:
196 | x = OrderedDict([(4, 64), (8, 64), (16, 32), (32, 16), (64, 8), (128, 4), (256, 4), (512, 4), (1024, 4)])
197 |
198 | elif gpu_num == 7 or gpu_num == 8 or gpu_num == 9:
199 | x = OrderedDict([(4, 32), (8, 32), (16, 16), (32, 8), (64, 4), (128, 4), (256, 4), (512, 4), (1024, 4)])
200 |
201 | else: # any other gpu_num (intended for >= 10)
202 | x = OrderedDict([(4, 16), (8, 16), (16, 8), (32, 4), (64, 2), (128, 2), (256, 2), (512, 2), (1024, 2)])
203 |
204 | return x
205 |
206 | def multi_gpu_loss(x, global_batch_size):
207 | ndim = len(x.shape)
208 | no_batch_axis = list(range(1, ndim))
209 | x = tf.reduce_mean(x, axis=no_batch_axis)
210 | x = tf.reduce_sum(x) / global_batch_size
211 |
212 | return x
213 |
214 | class EasyDict(dict):
215 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
216 | from typing import Any
217 |
218 | def __getattr__(self, name: str) -> Any:
219 | try:
220 | return self[name]
221 | except KeyError:
222 | raise AttributeError(name)
223 |
224 | def __setattr__(self, name: str, value: Any) -> None:
225 | self[name] = value
226 |
227 | def __delattr__(self, name: str) -> None:
228 | del self[name]
--------------------------------------------------------------------------------
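The linear range mapping used by `adjust_dynamic_range` in `utils.py` can be checked in isolation. Below is a small NumPy-only sketch under stated assumptions: the name `adjust_dynamic_range_np` is hypothetical, and the dtype cast from the TensorFlow version is omitted. It applies the same scale, bias, and clip steps and round-trips a few pixel values between `[0, 255]` and `[-1, 1]`.

```python
import numpy as np


def adjust_dynamic_range_np(images, range_in, range_out):
    """Linearly map values from range_in to range_out, then clip.

    Hypothetical NumPy counterpart of utils.adjust_dynamic_range
    (no dtype cast, no TensorFlow dependency).
    """
    scale = (range_out[1] - range_out[0]) / (range_in[1] - range_in[0])
    bias = range_out[0] - range_in[0] * scale
    images = np.asarray(images, dtype=np.float64) * scale + bias
    return np.clip(images, range_out[0], range_out[1])


pixels = np.array([0.0, 127.5, 255.0])
normalized = adjust_dynamic_range_np(pixels, (0.0, 255.0), (-1.0, 1.0))
restored = adjust_dynamic_range_np(normalized, (-1.0, 1.0), (0.0, 255.0))
print(normalized)  # [-1.  0.  1.]
print(restored)    # [  0.  127.5 255. ]
```

Values outside `range_in` are clipped to the ends of `range_out`, so the mapping is only invertible for inputs that already lie inside the input range.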