├── .gitignore ├── AUTHORS ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── colabs ├── Convex_Polygons_Dataset.ipynb ├── s3gan_demo.ipynb └── ssgan_demo.ipynb ├── compare_gan ├── __init__.py ├── architectures │ ├── __init__.py │ ├── abstract_arch.py │ ├── arch_ops.py │ ├── arch_ops_test.py │ ├── arch_ops_tpu_test.py │ ├── architectures_test.py │ ├── dcgan.py │ ├── infogan.py │ ├── resnet30.py │ ├── resnet5.py │ ├── resnet_biggan.py │ ├── resnet_biggan_deep.py │ ├── resnet_biggan_deep_test.py │ ├── resnet_biggan_test.py │ ├── resnet_cifar.py │ ├── resnet_init_test.py │ ├── resnet_norm_test.py │ ├── resnet_ops.py │ ├── resnet_stl.py │ └── sndcgan.py ├── datasets.py ├── datasets_test.py ├── eval_gan_lib.py ├── eval_gan_lib_test.py ├── eval_utils.py ├── gans │ ├── __init__.py │ ├── abstract_gan.py │ ├── consts.py │ ├── loss_lib.py │ ├── modular_gan.py │ ├── modular_gan_conditional_test.py │ ├── modular_gan_test.py │ ├── modular_gan_tpu_test.py │ ├── ops.py │ ├── penalty_lib.py │ ├── s3gan.py │ ├── s3gan_test.py │ ├── ssgan.py │ ├── ssgan_test.py │ └── utils.py ├── hooks.py ├── main.py ├── metrics │ ├── __init__.py │ ├── accuracy.py │ ├── eval_task.py │ ├── fid_score.py │ ├── fid_score_test.py │ ├── fractal_dimension.py │ ├── fractal_dimension_test.py │ ├── gilbo.py │ ├── image_similarity.py │ ├── inception_score.py │ ├── jacobian_conditioning.py │ ├── jacobian_conditioning_test.py │ ├── kid_score.py │ ├── ms_ssim_score.py │ ├── ms_ssim_score_test.py │ ├── prd_score.py │ └── prd_score_test.py ├── runner_lib.py ├── runner_lib_test.py ├── test_utils.py ├── tpu │ ├── __init__.py │ ├── tpu_ops.py │ ├── tpu_ops_test.py │ ├── tpu_random.py │ ├── tpu_random_test.py │ └── tpu_summaries.py └── utils.py ├── example_configs ├── README.md ├── biggan_imagenet128.gin ├── dcgan_celeba64.gin ├── resnet_cifar10.gin ├── resnet_lsun-bedroom128.gin └── sndcgan_celebahq128.gin └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # 
Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | __pycache__/ 6 | .cache/ 7 | 8 | # Python egg metadata, regenerated from source files by setuptools. 9 | /*.egg-info 10 | .eggs/ 11 | 12 | # PyPI distribution artifacts. 13 | build/ 14 | dist/ 15 | 16 | # Sublime project files 17 | *.sublime-project 18 | *.sublime-workspace 19 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of authors for copyright purposes. 2 | # 3 | # This does not necessarily list everyone who has contributed code, since in 4 | # some cases, their employer may be the copyright holder. To see the full list 5 | # of contributors, see the revision history in source control. 6 | 7 | Google LLC 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | # Issues 4 | 5 | * Please tag your issue with `bug`, `feature request`, or `question` to help us 6 | effectively respond. 7 | * Please include the versions of TensorFlow and Tensor2Tensor you are running 8 | (run `pip list | grep tensor`) 9 | * Please provide the command line you ran as well as the log output. 10 | 11 | # Pull Requests 12 | 13 | We'd love to accept your patches and contributions to this project. There are 14 | just a few small guidelines you need to follow. 15 | 16 | ## Contributor License Agreement 17 | 18 | Contributions to this project must be accompanied by a Contributor License 19 | Agreement. You (or your employer) retain the copyright to your contribution, 20 | this simply gives us permission to use and redistribute your contributions as 21 | part of the project. Head over to <https://cla.developers.google.com/> to see 22 | your current agreements on file or to sign a new one. 
23 | 24 | You generally only need to submit a CLA once, so if you've already submitted one 25 | (even if it was for a different project), you probably don't need to do it 26 | again. 27 | 28 | ## Code reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use GitHub pull requests for this purpose. Consult 32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 33 | information on using pull requests. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Compare GAN 2 | 3 | This repository offers TensorFlow implementations for many components related to 4 | **Generative Adversarial Networks**: 5 | 6 | * losses (such as non-saturating GAN, least-squares GAN, and WGAN), 7 | * penalties (such as the gradient penalty), 8 | * normalization techniques (such as spectral normalization, batch 9 | normalization, and layer normalization), 10 | * neural architectures (BigGAN, ResNet, DCGAN), and 11 | * evaluation metrics (FID score, Inception Score, precision-recall, and KID 12 | score). 13 | 14 | The code is **configurable via [Gin](https://github.com/google/gin-config)** and 15 | runs on **GPU/TPU/CPUs**. Several research papers make use of this repository, 16 | including: 17 | 18 | 1. [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337) 19 | [[Code]](https://github.com/google/compare_gan/tree/v1) 20 | \ 21 | Mario Lucic*, Karol Kurach*, Marcin Michalski, Sylvain Gelly, Olivier 22 | Bousquet **[NeurIPS 2018]** 23 | 24 | 2. 
[The GAN Landscape: Losses, Architectures, Regularization, and Normalization](https://arxiv.org/abs/1807.04720) 25 | [[Code]](https://github.com/google/compare_gan/tree/v2) 26 | [[Colab]](https://colab.research.google.com/github/google/compare_gan/blob/v2/compare_gan/src/tfhub_models.ipynb) 27 | \ 28 | Karol Kurach*, Mario Lucic*, Xiaohua Zhai, Marcin Michalski, Sylvain Gelly 29 | **[ICML 2019]** 30 | 31 | 3. [Assessing Generative Models via Precision and Recall](https://arxiv.org/abs/1806.00035) 32 | [[Code]](https://github.com/google/compare_gan/blob/560697ee213f91048c6b4231ab79fcdd9bf20381/compare_gan/src/prd_score.py) 33 | \ 34 | Mehdi S. M. Sajjadi, Olivier Bachem, Mario Lucic, Olivier Bousquet, Sylvain 35 | Gelly **[NeurIPS 2018]** 36 | 37 | 4. [GILBO: One Metric to Measure Them All](https://arxiv.org/abs/1802.04874) 38 | [[Code]](https://github.com/google/compare_gan/blob/560697ee213f91048c6b4231ab79fcdd9bf20381/compare_gan/src/gilbo.py) 39 | \ 40 | Alexander A. Alemi, Ian Fischer **[NeurIPS 2018]** 41 | 42 | 5. [A Case for Object Compositionality in Deep Generative Models of Images](https://arxiv.org/abs/1810.10340) 43 | [[Code]](https://github.com/google/compare_gan/tree/v2_multigan) 44 | \ 45 | Sjoerd van Steenkiste, Karol Kurach, Sylvain Gelly **[2018]** 46 | 47 | 6. [On Self Modulation for Generative Adversarial Networks](https://arxiv.org/abs/1810.01365) 48 | [[Code]](https://github.com/google/compare_gan) \ 49 | Ting Chen, Mario Lucic, Neil Houlsby, Sylvain Gelly **[ICLR 2019]** 50 | 51 | 7. [Self-Supervised GANs via Auxiliary Rotation Loss](https://arxiv.org/abs/1811.11212) 52 | [[Code]](https://github.com/google/compare_gan) 53 | [[Colab]](https://colab.research.google.com/github/google/compare_gan/blob/v3/colabs/ssgan_demo.ipynb) 54 | \ 55 | Ting Chen, Xiaohua Zhai, Marvin Ritter, Mario Lucic, Neil Houlsby **[CVPR 56 | 2019]** 57 | 58 | 8. 
[High-Fidelity Image Generation With Fewer Labels](https://arxiv.org/abs/1903.02271) 59 | [[Code]](https://github.com/google/compare_gan) 60 | [[Blog Post]](https://ai.googleblog.com/2019/03/reducing-need-for-labeled-data-in.html) 61 | [[Colab]](https://colab.research.google.com/github/google/compare_gan/blob/v3/colabs/s3gan_demo.ipynb) 62 | \ 63 | Mario Lucic*, Michael Tschannen*, Marvin Ritter*, Xiaohua Zhai, Olivier 64 | Bachem, Sylvain Gelly **[ICML 2019]** 65 | 66 | ## Installation 67 | 68 | You can easily install the library and all necessary dependencies by running: 69 | `pip install -e .` from the `compare_gan/` folder. 70 | 71 | ## Running experiments 72 | 73 | Simply run the `main.py` passing a `--model_dir` (this is where checkpoints are 74 | stored) and a `--gin_config` (defines which model is trained on which data set 75 | and other training options). We provide several example configurations in the 76 | `example_configs/` folder: 77 | 78 | * **dcgan_celeba64**: DCGAN architecture with non-saturating loss on CelebA 79 | 64x64px 80 | * **resnet_cifar10**: ResNet architecture with non-saturating loss and 81 | spectral normalization on CIFAR-10 82 | * **resnet_lsun-bedroom128**: ResNet architecture with WGAN loss and gradient 83 | penalty on LSUN-bedrooms 128x128px 84 | * **sndcgan_celebahq128**: SN-DCGAN architecture with non-saturating loss and 85 | spectral normalization on CelebA-HQ 128x128px 86 | * **biggan_imagenet128**: BigGAN architecture with hinge loss and spectral 87 | normalization on ImageNet 128x128px 88 | 89 | ### Training and evaluation 90 | 91 | To see all available options please run `python main.py --help`. Main options: 92 | 93 | * To **train** the model use `--schedule=train` (default). Training is resumed 94 | from the last saved checkpoint. 95 | * To **evaluate** all checkpoints use `--schedule=continuous_eval 96 | --eval_every_steps=0`. 
To evaluate only checkpoints where the step number is 97 | divisible by 5000, use `--schedule=continuous_eval --eval_every_steps=5000`. 98 | By default, 3 averaging runs are used to estimate the Inception Score and 99 | the FID score. Keep in mind that when running locally on a single GPU it may 100 | not be possible to run training and evaluation simultaneously due to memory 101 | constraints. 102 | * To **train and evaluate** the model use `--schedule=eval_after_train 103 | --eval_every_steps=0`. 104 | 105 | ### Training on Cloud TPUs 106 | 107 | We recommend using the 108 | [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu) to create 109 | a Cloud TPU and corresponding Compute Engine VM. We use a v3-128 Cloud TPU v3 Pod 110 | for training models on ImageNet at 128x128 resolution. You can use smaller 111 | slices if you reduce the batch size (`options.batch_size` in the Gin config) or 112 | model parameters. Keep in mind that the model quality might change. Before 113 | training make sure that the environment variable `TPU_NAME` is set. Running 114 | evaluation on TPUs is currently not supported. Use a VM with a single GPU 115 | instead. 116 | 117 | ### Datasets 118 | 119 | Compare GAN uses [TensorFlow Datasets](https://www.tensorflow.org/datasets) and 120 | it will automatically download and prepare the data. For ImageNet you will need 121 | to download the archive yourself. For CelebAHq you need to download and prepare 122 | the images on your own. If you are using TPUs make sure to point the training 123 | script to your Google Storage Bucket (`--tfds_data_dir`). 124 | -------------------------------------------------------------------------------- /compare_gan/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # coding=utf-8 17 | -------------------------------------------------------------------------------- /compare_gan/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # coding=utf-8 17 | -------------------------------------------------------------------------------- /compare_gan/architectures/abstract_arch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Defines interfaces for generator and discriminator networks.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import abc 23 | from compare_gan import utils 24 | import gin 25 | import six 26 | import tensorflow as tf 27 | 28 | 29 | @six.add_metaclass(abc.ABCMeta) 30 | class _Module(object): 31 | """Base class for architectures. 32 | 33 | Long term this will be replaced by `tf.Module` in TF 2.0. 34 | """ 35 | 36 | def __init__(self, name): 37 | self._name = name 38 | 39 | @property 40 | def name(self): 41 | return self._name 42 | 43 | @property 44 | def trainable_variables(self): 45 | return [var for var in tf.trainable_variables() if self._name in var.name] # NOTE: substring match on var.name, not a scope-prefix match. 46 | 47 | 48 | @gin.configurable("G", blacklist=["name", "image_shape"]) 49 | class AbstractGenerator(_Module): 50 | """Interface for generator architectures.""" 51 | 52 | def __init__(self, 53 | name="generator", 54 | image_shape=None, 55 | batch_norm_fn=None, 56 | spectral_norm=False): 57 | """Constructor for all generator architectures. 58 | 59 | Args: 60 | name: Scope name of the generator. 61 | image_shape: Image shape to be generated, [height, width, colors]. 62 | batch_norm_fn: Function for batch normalization or None. 63 | spectral_norm: If True use spectral normalization for all weights. 
64 | """ 65 | super(AbstractGenerator, self).__init__(name=name) 66 | self._name = name # NOTE(review): appears redundant; _Module.__init__ already stores name in self._name. 67 | self._image_shape = image_shape 68 | self._batch_norm_fn = batch_norm_fn 69 | self._spectral_norm = spectral_norm 70 | 71 | def __call__(self, z, y, is_training, reuse=tf.AUTO_REUSE): 72 | with tf.variable_scope(self.name, values=[z, y], reuse=reuse): 73 | outputs = self.apply(z=z, y=y, is_training=is_training) 74 | return outputs 75 | 76 | def batch_norm(self, inputs, **kwargs): 77 | if self._batch_norm_fn is None: 78 | return inputs # No batch norm function configured: pass inputs through unchanged. 79 | args = kwargs.copy() 80 | args["inputs"] = inputs 81 | if "use_sn" not in args: 82 | args["use_sn"] = self._spectral_norm # Default use_sn to the architecture-wide spectral norm setting. 83 | return utils.call_with_accepted_args(self._batch_norm_fn, **args) 84 | 85 | @abc.abstractmethod 86 | def apply(self, z, y, is_training): 87 | """Apply the generator to an input. 88 | 89 | Args: 90 | z: `Tensor` of shape [batch_size, z_dim] with latent code. 91 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 92 | labels. 93 | is_training: Boolean, whether the architecture should be constructed for 94 | training or inference. 95 | 96 | Returns: 97 | Generated images of shape [batch_size] + self.image_shape. 
98 | """ 99 | 100 | 101 | @gin.configurable("D", blacklist=["name"]) 102 | class AbstractDiscriminator(_Module): 103 | """Interface for discriminator architectures.""" 104 | 105 | def __init__(self, 106 | name="discriminator", 107 | batch_norm_fn=None, 108 | layer_norm=False, 109 | spectral_norm=False): 110 | super(AbstractDiscriminator, self).__init__(name=name) 111 | self._name = name # NOTE(review): appears redundant; _Module.__init__ already stores name in self._name. 112 | self._batch_norm_fn = batch_norm_fn 113 | self._layer_norm = layer_norm 114 | self._spectral_norm = spectral_norm 115 | 116 | def __call__(self, x, y, is_training, reuse=tf.AUTO_REUSE): 117 | with tf.variable_scope(self.name, values=[x, y], reuse=reuse): 118 | outputs = self.apply(x=x, y=y, is_training=is_training) 119 | return outputs 120 | 121 | def batch_norm(self, inputs, **kwargs): 122 | if self._batch_norm_fn is None: 123 | return inputs # No batch norm function configured: pass inputs through unchanged. 124 | args = kwargs.copy() 125 | args["inputs"] = inputs 126 | if "use_sn" not in args: 127 | args["use_sn"] = self._spectral_norm # Default use_sn to the architecture-wide spectral norm setting. 128 | return utils.call_with_accepted_args(self._batch_norm_fn, **args) 129 | 130 | 131 | @abc.abstractmethod 132 | def apply(self, x, y, is_training): 133 | """Apply the discriminator to an input. 134 | 135 | Args: 136 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images. 137 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 138 | labels. 139 | is_training: Boolean, whether the architecture should be constructed for 140 | training or inference. 141 | 142 | Returns: 143 | Tuple of 3 Tensors, the final prediction of the discriminator, the logits 144 | before the final output activation function and logits from the second 145 | last layer. 
146 | """ 147 | -------------------------------------------------------------------------------- /compare_gan/architectures/arch_ops_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for custom architecture operations.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from compare_gan.architectures import arch_ops 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | 27 | class ArchOpsTest(tf.test.TestCase): 28 | 29 | def testBatchNorm(self): 30 | with tf.Graph().as_default(): 31 | # 4 images with resolution 2x1 and 3 channels. 
32 | x1 = tf.constant([[[5, 7, 2]], [[5, 8, 8]]], dtype=tf.float32) 33 | x2 = tf.constant([[[1, 2, 0]], [[4, 0, 4]]], dtype=tf.float32) 34 | x3 = tf.constant([[[6, 2, 6]], [[5, 0, 5]]], dtype=tf.float32) 35 | x4 = tf.constant([[[2, 4, 2]], [[6, 4, 1]]], dtype=tf.float32) 36 | x = tf.stack([x1, x2, x3, x4]) 37 | self.assertAllEqual(x.shape.as_list(), [4, 2, 1, 3]) 38 | 39 | core_bn = tf.layers.batch_normalization(x, training=True) 40 | contrib_bn = tf.contrib.layers.batch_norm(x, is_training=True) 41 | custom_bn = arch_ops.batch_norm(x, is_training=True) 42 | with self.session() as sess: 43 | sess.run(tf.global_variables_initializer()) 44 | core_bn, contrib_bn, custom_bn = sess.run( 45 | [core_bn, contrib_bn, custom_bn]) 46 | tf.logging.info("core_bn: %s", core_bn[0]) 47 | tf.logging.info("contrib_bn: %s", contrib_bn[0]) 48 | tf.logging.info("custom_bn: %s", custom_bn[0]) 49 | self.assertAllClose(core_bn, contrib_bn) # The three batch norm implementations must agree with each other. 50 | self.assertAllClose(custom_bn, contrib_bn) 51 | expected_values = np.asarray( 52 | [[[[0.4375205, 1.30336881, -0.58830315]], 53 | [[0.4375205, 1.66291881, 1.76490951]]], 54 | [[[-1.89592218, -0.49438119, -1.37270737]], 55 | [[-0.14584017, -1.21348119, 0.19610107]]], 56 | [[[1.02088118, -0.49438119, 0.98050523]], 57 | [[0.4375205, -1.21348119, 0.58830321]]], 58 | [[[-1.31256151, 0.22471881, -0.58830315]], 59 | [[1.02088118, 0.22471881, -0.98050523]]]], 60 | dtype=np.float32) 61 | self.assertAllClose(custom_bn, expected_values) 62 | 63 | def testAccumulatedMomentsDuringTraing(self): 64 | with tf.Graph().as_default(): 65 | mean_in = tf.placeholder(tf.float32, shape=[2]) 66 | variance_in = tf.placeholder(tf.float32, shape=[2]) 67 | mean, variance = arch_ops._accumulated_moments_for_inference( 68 | mean=mean_in, variance=variance_in, is_training=True) 69 | variables_by_name = {v.op.name: v for v in tf.global_variables()} 70 | tf.logging.error(variables_by_name) 71 | accu_mean = variables_by_name["accu/accu_mean"] 72 | accu_variance = 
variables_by_name["accu/accu_variance"] 73 | accu_counter = variables_by_name["accu/accu_counter"] 74 | with self.session() as sess: 75 | sess.run(tf.global_variables_initializer()) 76 | m1, v1 = sess.run( 77 | [mean, variance], 78 | feed_dict={mean_in: [1.0, 2.0], variance_in: [3.0, 4.0]}) 79 | self.assertAllClose(m1, [1.0, 2.0]) 80 | self.assertAllClose(v1, [3.0, 4.0]) 81 | m2, v2 = sess.run( 82 | [mean, variance], 83 | feed_dict={mean_in: [5.0, 6.0], variance_in: [7.0, 8.0]}) 84 | self.assertAllClose(m2, [5.0, 6.0]) 85 | self.assertAllClose(v2, [7.0, 8.0]) 86 | am, av, ac = sess.run([accu_mean, accu_variance, accu_counter]) 87 | self.assertAllClose(am, [0.0, 0.0]) # Accumulators are expected to stay at zero when is_training=True. 88 | self.assertAllClose(av, [0.0, 0.0]) 89 | self.assertAllClose([ac], [0.0]) 90 | 91 | def testAccumulatedMomentsDuringEal(self): 92 | with tf.Graph().as_default(): 93 | mean_in = tf.placeholder(tf.float32, shape=[2]) 94 | variance_in = tf.placeholder(tf.float32, shape=[2]) 95 | mean, variance = arch_ops._accumulated_moments_for_inference( 96 | mean=mean_in, variance=variance_in, is_training=False) 97 | variables_by_name = {v.op.name: v for v in tf.global_variables()} 98 | tf.logging.error(variables_by_name) 99 | accu_mean = variables_by_name["accu/accu_mean"] 100 | accu_variance = variables_by_name["accu/accu_variance"] 101 | accu_counter = variables_by_name["accu/accu_counter"] 102 | update_accus = variables_by_name["accu/update_accus"] 103 | with self.session() as sess: 104 | sess.run(tf.global_variables_initializer()) 105 | # Fill accumulators. 
106 | sess.run(tf.assign(update_accus, 1)) 107 | m1, v1 = sess.run( 108 | [mean, variance], 109 | feed_dict={mean_in: [1.0, 2.0], variance_in: [3.0, 4.0]}) 110 | self.assertAllClose(m1, [1.0, 2.0]) 111 | self.assertAllClose(v1, [3.0, 4.0]) 112 | m2, v2 = sess.run( 113 | [mean, variance], 114 | feed_dict={mean_in: [5.0, 6.0], variance_in: [7.0, 8.0]}) 115 | self.assertAllClose(m2, [3.0, 4.0]) # Average of the two fed means. 116 | self.assertAllClose(v2, [5.0, 6.0]) # Average of the two fed variances. 117 | # Check accumulators. 118 | am, av, ac = sess.run([accu_mean, accu_variance, accu_counter]) 119 | self.assertAllClose(am, [6.0, 8.0]) # Sum of the two fed means. 120 | self.assertAllClose(av, [10.0, 12.0]) # Sum of the two fed variances. 121 | self.assertAllClose([ac], [2.0]) 122 | # Use accumulators. 123 | sess.run(tf.assign(update_accus, 0)) 124 | m3, v3 = sess.run( 125 | [mean, variance], 126 | feed_dict={mean_in: [2.0, 2.0], variance_in: [3.0, 3.0]}) 127 | self.assertAllClose(m3, [3.0, 4.0]) # Fed values are ignored; the accumulated average is expected. 128 | self.assertAllClose(v3, [5.0, 6.0]) 129 | am, av, ac = sess.run([accu_mean, accu_variance, accu_counter]) 130 | self.assertAllClose(am, [6.0, 8.0]) 131 | self.assertAllClose(av, [10.0, 12.0]) 132 | self.assertAllClose([ac], [2.0]) 133 | 134 | 135 | if __name__ == "__main__": 136 | tf.test.main() 137 | -------------------------------------------------------------------------------- /compare_gan/architectures/arch_ops_tpu_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for custom architecture operations on TPUs.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl import logging 23 | from compare_gan.architectures import arch_ops 24 | import gin 25 | import numpy as np 26 | import tensorflow as tf 27 | 28 | 29 | class ArchOpsTpuTest(tf.test.TestCase): 30 | 31 | def setUp(self): 32 | # Construct input for batch norm tests: 33 | # 4 images with resolution 2x1 and 3 channels. 34 | x1 = np.asarray([[[5, 7, 2]], [[5, 8, 8]]], dtype=np.float32) 35 | x2 = np.asarray([[[1, 2, 0]], [[4, 0, 4]]], dtype=np.float32) 36 | x3 = np.asarray([[[6, 2, 6]], [[5, 0, 5]]], dtype=np.float32) 37 | x4 = np.asarray([[[2, 4, 2]], [[6, 4, 1]]], dtype=np.float32) 38 | self._inputs = np.stack([x1, x2, x3, x4]) 39 | self.assertAllEqual(self._inputs.shape, [4, 2, 1, 3]) 40 | # And the expected output for applying batch norm (without additional 41 | # scaling/shifting). 
42 | self._expected_outputs = np.asarray( 43 | [[[[0.4375205, 1.30336881, -0.58830315]], 44 | [[0.4375205, 1.66291881, 1.76490951]]], 45 | [[[-1.89592218, -0.49438119, -1.37270737]], 46 | [[-0.14584017, -1.21348119, 0.19610107]]], 47 | [[[1.02088118, -0.49438119, 0.98050523]], 48 | [[0.4375205, -1.21348119, 0.58830321]]], 49 | [[[-1.31256151, 0.22471881, -0.58830315]], 50 | [[1.02088118, 0.22471881, -0.98050523]]]], 51 | dtype=np.float32) 52 | self.assertAllEqual(self._expected_outputs.shape, [4, 2, 1, 3]) 53 | 54 | def testRunsOnTpu(self): 55 | """Verify that the test case runs on a TPU chip that has 2 cores.""" 56 | expected_device_names = [ 57 | "/job:localhost/replica:0/task:0/device:CPU:0", 58 | "/job:localhost/replica:0/task:0/device:TPU:0", 59 | "/job:localhost/replica:0/task:0/device:TPU:1", 60 | "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", 61 | ] 62 | with self.session() as sess: 63 | devices = sess.list_devices() 64 | tf.logging.info("devices:\n%s", "\n".join([str(d) for d in devices])) 65 | self.assertAllEqual([d.name for d in devices], expected_device_names) 66 | 67 | def testBatchNormOneCore(self): 68 | def computation(x): 69 | core_bn = tf.layers.batch_normalization(x, training=True) 70 | contrib_bn = tf.contrib.layers.batch_norm(x, is_training=True) 71 | custom_bn = arch_ops.batch_norm(x, is_training=True) 72 | tf.logging.info("custom_bn tensor: %s", custom_bn) 73 | return core_bn, contrib_bn, custom_bn 74 | 75 | with tf.Graph().as_default(): 76 | x = tf.constant(self._inputs) 77 | core_bn, contrib_bn, custom_bn = tf.contrib.tpu.batch_parallel( 78 | computation, [x], num_shards=1) 79 | 80 | with self.session() as sess: 81 | sess.run(tf.contrib.tpu.initialize_system()) 82 | sess.run(tf.global_variables_initializer()) 83 | core_bn, contrib_bn, custom_bn = sess.run( 84 | [core_bn, contrib_bn, custom_bn]) 85 | logging.info("core_bn: %s", core_bn) 86 | logging.info("contrib_bn: %s", contrib_bn) 87 | logging.info("custom_bn: %s", custom_bn) 88 | 
self.assertAllClose(core_bn, self._expected_outputs) 89 | self.assertAllClose(contrib_bn, self._expected_outputs) 90 | self.assertAllClose(custom_bn, self._expected_outputs) 91 | 92 | def testBatchNormTwoCoresCoreAndContrib(self): 93 | def computation(x): 94 | core_bn = tf.layers.batch_normalization(x, training=True) 95 | contrib_bn = tf.contrib.layers.batch_norm(x, is_training=True) 96 | return core_bn, contrib_bn 97 | 98 | with tf.Graph().as_default(): 99 | x = tf.constant(self._inputs) 100 | core_bn, contrib_bn = tf.contrib.tpu.batch_parallel( 101 | computation, [x], num_shards=2) 102 | 103 | with self.session() as sess: 104 | sess.run(tf.contrib.tpu.initialize_system()) 105 | sess.run(tf.global_variables_initializer()) 106 | core_bn, contrib_bn = sess.run([core_bn, contrib_bn]) 107 | logging.info("core_bn: %s", core_bn) 108 | logging.info("contrib_bn: %s", contrib_bn) 109 | self.assertNotAllClose(core_bn, self._expected_outputs) # Expected to differ from the full-batch reference when sharded over 2 cores (presumably per-core batch statistics). 110 | self.assertNotAllClose(contrib_bn, self._expected_outputs) 111 | 112 | def testBatchNormTwoCoresCustom(self): 113 | def computation(x): 114 | custom_bn = arch_ops.batch_norm(x, is_training=True, name="custom_bn") 115 | gin.bind_parameter("cross_replica_moments.parallel", False) # NOTE(review): this gin binding is not reset anywhere in this test. 116 | custom_bn_seq = arch_ops.batch_norm(x, is_training=True, 117 | name="custom_bn_seq") 118 | return custom_bn, custom_bn_seq 119 | 120 | with tf.Graph().as_default(): 121 | x = tf.constant(self._inputs) 122 | custom_bn, custom_bn_seq = tf.contrib.tpu.batch_parallel( 123 | computation, [x], num_shards=2) 124 | 125 | with self.session() as sess: 126 | sess.run(tf.contrib.tpu.initialize_system()) 127 | sess.run(tf.global_variables_initializer()) 128 | custom_bn, custom_bn_seq = sess.run( 129 | [custom_bn, custom_bn_seq]) 130 | logging.info("custom_bn: %s", custom_bn) 131 | logging.info("custom_bn_seq: %s", custom_bn_seq) 132 | self.assertAllClose(custom_bn, self._expected_outputs) 133 | self.assertAllClose(custom_bn_seq, self._expected_outputs) 134 | 135 | 
136 | if __name__ == "__main__": 137 | tf.test.main() 138 | -------------------------------------------------------------------------------- /compare_gan/architectures/architectures_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for neural architectures.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl.testing import parameterized 23 | from compare_gan.architectures import dcgan 24 | from compare_gan.architectures import infogan 25 | from compare_gan.architectures import resnet30 26 | from compare_gan.architectures import resnet5 27 | from compare_gan.architectures import resnet_biggan 28 | from compare_gan.architectures import resnet_cifar 29 | from compare_gan.architectures import resnet_stl 30 | from compare_gan.architectures import sndcgan 31 | import tensorflow as tf 32 | 33 | 34 | class ArchitectureTest(parameterized.TestCase, tf.test.TestCase): 35 | 36 | def assertArchitectureBuilds(self, gen, disc, image_shape, z_dim=120): 37 | with tf.Graph().as_default(): 38 | batch_size = 2 39 | num_classes = 10 40 | # Prepare inputs 41 | z = tf.random.normal((batch_size, z_dim), name="z") 42 | y = tf.one_hot(tf.range(batch_size), num_classes) 43 | # Run check 
output shapes for G and D. 44 | x = gen(z=z, y=y, is_training=True, reuse=False) 45 | self.assertAllEqual(x.shape.as_list()[1:], image_shape) 46 | out, _, _ = disc( 47 | x, y=y, is_training=True, reuse=False) 48 | self.assertAllEqual(out.shape.as_list(), (batch_size, 1)) 49 | # Check that G outputs valid pixel values (we use [0, 1] everywhere) and 50 | # D outputs a probablilty. 51 | with self.session() as sess: 52 | sess.run(tf.global_variables_initializer()) 53 | image, pred = sess.run([x, out]) 54 | self.assertAllGreaterEqual(image, 0) 55 | self.assertAllLessEqual(image, 1) 56 | self.assertAllGreaterEqual(pred, 0) 57 | self.assertAllLessEqual(pred, 1) 58 | 59 | @parameterized.parameters( 60 | {"image_shape": (28, 28, 1)}, 61 | {"image_shape": (32, 32, 1)}, 62 | {"image_shape": (32, 32, 3)}, 63 | {"image_shape": (64, 64, 3)}, 64 | {"image_shape": (128, 128, 3)}, 65 | ) 66 | def testDcGan(self, image_shape): 67 | self.assertArchitectureBuilds( 68 | gen=dcgan.Generator(image_shape=image_shape), 69 | disc=dcgan.Discriminator(), 70 | image_shape=image_shape) 71 | 72 | @parameterized.parameters( 73 | {"image_shape": (28, 28, 1)}, 74 | {"image_shape": (32, 32, 1)}, 75 | {"image_shape": (32, 32, 3)}, 76 | {"image_shape": (64, 64, 3)}, 77 | {"image_shape": (128, 128, 3)}, 78 | ) 79 | def testInfoGan(self, image_shape): 80 | self.assertArchitectureBuilds( 81 | gen=infogan.Generator(image_shape=image_shape), 82 | disc=infogan.Discriminator(), 83 | image_shape=image_shape) 84 | 85 | def testResNet30(self, image_shape=(128, 128, 3)): 86 | self.assertArchitectureBuilds( 87 | gen=resnet30.Generator(image_shape=image_shape), 88 | disc=resnet30.Discriminator(), 89 | image_shape=image_shape) 90 | 91 | @parameterized.parameters( 92 | {"image_shape": (32, 32, 1)}, 93 | {"image_shape": (32, 32, 3)}, 94 | {"image_shape": (64, 64, 3)}, 95 | {"image_shape": (128, 128, 3)}, 96 | ) 97 | def testResNet5(self, image_shape): 98 | self.assertArchitectureBuilds( 99 | 
gen=resnet5.Generator(image_shape=image_shape), 100 | disc=resnet5.Discriminator(), 101 | image_shape=image_shape) 102 | 103 | @parameterized.parameters( 104 | {"image_shape": (32, 32, 3)}, 105 | {"image_shape": (64, 64, 3)}, 106 | {"image_shape": (128, 128, 3)}, 107 | {"image_shape": (256, 256, 3)}, 108 | {"image_shape": (512, 512, 3)}, 109 | ) 110 | def testResNet5BigGan(self, image_shape): 111 | if image_shape[0] == 512: 112 | z_dim = 160 113 | elif image_shape[0] == 256: 114 | z_dim = 140 115 | else: 116 | z_dim = 120 117 | # Use a reduced `ch=16` to avoid OOM errors. 118 | self.assertArchitectureBuilds( 119 | gen=resnet_biggan.Generator(image_shape=image_shape, ch=16), 120 | disc=resnet_biggan.Discriminator(ch=16), 121 | image_shape=image_shape, 122 | z_dim=z_dim) 123 | 124 | @parameterized.parameters( 125 | {"image_shape": (32, 32, 1)}, 126 | {"image_shape": (32, 32, 3)}, 127 | ) 128 | def testResNetCifar(self, image_shape): 129 | self.assertArchitectureBuilds( 130 | gen=resnet_cifar.Generator(image_shape=image_shape), 131 | disc=resnet_cifar.Discriminator(), 132 | image_shape=image_shape) 133 | 134 | @parameterized.parameters( 135 | {"image_shape": (48, 48, 1)}, 136 | {"image_shape": (48, 48, 3)}, 137 | ) 138 | def testResNetStl(self, image_shape): 139 | self.assertArchitectureBuilds( 140 | gen=resnet_stl.Generator(image_shape=image_shape), 141 | disc=resnet_stl.Discriminator(), 142 | image_shape=image_shape) 143 | 144 | @parameterized.parameters( 145 | {"image_shape": (28, 28, 1)}, 146 | {"image_shape": (32, 32, 1)}, 147 | {"image_shape": (32, 32, 3)}, 148 | {"image_shape": (64, 64, 3)}, 149 | {"image_shape": (128, 128, 3)}, 150 | ) 151 | def testSnDcGan(self, image_shape): 152 | self.assertArchitectureBuilds( 153 | gen=sndcgan.Generator(image_shape=image_shape), 154 | disc=sndcgan.Discriminator(), 155 | image_shape=image_shape) 156 | 157 | 158 | if __name__ == "__main__": 159 | tf.test.main() 160 | 
-------------------------------------------------------------------------------- /compare_gan/architectures/dcgan.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implementation of DCGAN generator and discriminator architectures. 17 | 18 | Details are available in https://arxiv.org/abs/1511.06434. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | from compare_gan.architectures import abstract_arch 26 | from compare_gan.architectures.arch_ops import conv2d 27 | from compare_gan.architectures.arch_ops import deconv2d 28 | from compare_gan.architectures.arch_ops import linear 29 | from compare_gan.architectures.arch_ops import lrelu 30 | 31 | import numpy as np 32 | import tensorflow as tf 33 | 34 | 35 | def conv_out_size_same(size, stride): 36 | return int(np.ceil(float(size) / float(stride))) 37 | 38 | 39 | class Generator(abstract_arch.AbstractGenerator): 40 | """DCGAN generator. 41 | 42 | Details are available at https://arxiv.org/abs/1511.06434. Notable changes 43 | include BatchNorm in the generator, ReLu instead of LeakyReLu and ReLu in the 44 | generator, except for output which uses tanh. 
45 | """ 46 | 47 | def apply(self, z, y, is_training): 48 | """Build the generator network for the given inputs. 49 | 50 | Args: 51 | z: `Tensor` of shape [batch_size, z_dim] with latent code. 52 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 53 | labels. 54 | is_training: boolean, are we in train or eval model. 55 | 56 | Returns: 57 | A tensor of size [batch_size] + self._image_shape with values in [0, 1]. 58 | """ 59 | gf_dim = 64 # Dimension of filters in first convolutional layer. 60 | bs = z.shape[0].value 61 | s_h, s_w, colors = self._image_shape 62 | s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2) 63 | s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2) 64 | s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2) 65 | s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2) 66 | 67 | net = linear(z, gf_dim * 8 *s_h16 * s_w16, scope="g_fc1") 68 | net = tf.reshape(net, [-1, s_h16, s_w16, gf_dim * 8]) 69 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn1") 70 | net = tf.nn.relu(net) 71 | net = deconv2d(net, [bs, s_h8, s_w8, gf_dim*4], 5, 5, 2, 2, name="g_dc1") 72 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn2") 73 | net = tf.nn.relu(net) 74 | net = deconv2d(net, [bs, s_h4, s_w4, gf_dim*2], 5, 5, 2, 2, name="g_dc2") 75 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn3") 76 | net = tf.nn.relu(net) 77 | net = deconv2d(net, [bs, s_h2, s_w2, gf_dim*1], 5, 5, 2, 2, name="g_dc3") 78 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn4") 79 | net = tf.nn.relu(net) 80 | net = deconv2d(net, [bs, s_h, s_w, colors], 5, 5, 2, 2, name="g_dc4") 81 | net = 0.5 * tf.nn.tanh(net) + 0.5 82 | return net 83 | 84 | 85 | class Discriminator(abstract_arch.AbstractDiscriminator): 86 | """DCGAN discriminator. 87 | 88 | Details are available at https://arxiv.org/abs/1511.06434. 
Notable changes 89 | include BatchNorm in the discriminator and LeakyReLU for all layers. 90 | """ 91 | 92 | def apply(self, x, y, is_training): 93 | """Apply the discriminator on a input. 94 | 95 | Args: 96 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images. 97 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 98 | labels. 99 | is_training: Boolean, whether the architecture should be constructed for 100 | training or inference. 101 | 102 | Returns: 103 | Tuple of 3 Tensors, the final prediction of the discriminator, the logits 104 | before the final output activation function and logits form the second 105 | last layer. 106 | """ 107 | bs = x.shape[0].value 108 | df_dim = 64 # Dimension of filters in the first convolutional layer. 109 | net = lrelu(conv2d(x, df_dim, 5, 5, 2, 2, name="d_conv1", 110 | use_sn=self._spectral_norm)) 111 | net = conv2d(net, df_dim * 2, 5, 5, 2, 2, name="d_conv2", 112 | use_sn=self._spectral_norm) 113 | 114 | net = self.batch_norm(net, y=y, is_training=is_training, name="d_bn1") 115 | net = lrelu(net) 116 | net = conv2d(net, df_dim * 4, 5, 5, 2, 2, name="d_conv3", 117 | use_sn=self._spectral_norm) 118 | 119 | net = self.batch_norm(net, y=y, is_training=is_training, name="d_bn2") 120 | net = lrelu(net) 121 | net = conv2d(net, df_dim * 8, 5, 5, 2, 2, name="d_conv4", 122 | use_sn=self._spectral_norm) 123 | 124 | net = self.batch_norm(net, y=y, is_training=is_training, name="d_bn3") 125 | net = lrelu(net) 126 | out_logit = linear( 127 | tf.reshape(net, [bs, -1]), 1, scope="d_fc4", use_sn=self._spectral_norm) 128 | out = tf.nn.sigmoid(out_logit) 129 | return out, out_logit, net 130 | -------------------------------------------------------------------------------- /compare_gan/architectures/infogan.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implementation of InfoGAN generator and discriminator architectures. 17 | 18 | Details are available in https://arxiv.org/pdf/1606.03657.pdf. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | from compare_gan.architectures import abstract_arch 26 | from compare_gan.architectures.arch_ops import batch_norm 27 | from compare_gan.architectures.arch_ops import conv2d 28 | from compare_gan.architectures.arch_ops import deconv2d 29 | from compare_gan.architectures.arch_ops import linear 30 | from compare_gan.architectures.arch_ops import lrelu 31 | 32 | import tensorflow as tf 33 | 34 | 35 | class Generator(abstract_arch.AbstractGenerator): 36 | """Generator architecture based on InfoGAN.""" 37 | 38 | def apply(self, z, y, is_training): 39 | """Build the generator network for the given inputs. 40 | 41 | Args: 42 | z: `Tensor` of shape [batch_size, z_dim] with latent code. 43 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 44 | labels. 45 | is_training: boolean, are we in train or eval model. 46 | 47 | Returns: 48 | A tensor of size [batch_size] + self._image_shape with values in [0, 1]. 
49 | """ 50 | del y 51 | h, w, c = self._image_shape 52 | bs = z.shape.as_list()[0] 53 | net = linear(z, 1024, scope="g_fc1") 54 | net = lrelu(batch_norm(net, is_training=is_training, name="g_bn1")) 55 | net = linear(net, 128 * (h // 4) * (w // 4), scope="g_fc2") 56 | net = lrelu(batch_norm(net, is_training=is_training, name="g_bn2")) 57 | net = tf.reshape(net, [bs, h // 4, w // 4, 128]) 58 | net = deconv2d(net, [bs, h // 2, w // 2, 64], 4, 4, 2, 2, name="g_dc3") 59 | net = lrelu(batch_norm(net, is_training=is_training, name="g_bn3")) 60 | net = deconv2d(net, [bs, h, w, c], 4, 4, 2, 2, name="g_dc4") 61 | out = tf.nn.sigmoid(net) 62 | return out 63 | 64 | 65 | class Discriminator(abstract_arch.AbstractDiscriminator): 66 | """Discriminator architecture based on InfoGAN.""" 67 | 68 | def apply(self, x, y, is_training): 69 | """Apply the discriminator on a input. 70 | 71 | Args: 72 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images. 73 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 74 | labels. 75 | is_training: Boolean, whether the architecture should be constructed for 76 | training or inference. 77 | 78 | Returns: 79 | Tuple of 3 Tensors, the final prediction of the discriminator, the logits 80 | before the final output activation function and logits form the second 81 | last layer. 82 | """ 83 | use_sn = self._spectral_norm 84 | batch_size = x.shape.as_list()[0] 85 | # Resulting shape: [bs, h/2, w/2, 64]. 86 | net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name="d_conv1", use_sn=use_sn)) 87 | # Resulting shape: [bs, h/4, w/4, 128]. 88 | net = conv2d(net, 128, 4, 4, 2, 2, name="d_conv2", use_sn=use_sn) 89 | net = self.batch_norm(net, y=y, is_training=is_training, name="d_bn2") 90 | net = lrelu(net) 91 | # Resulting shape: [bs, h * w * 8]. 92 | net = tf.reshape(net, [batch_size, -1]) 93 | # Resulting shape: [bs, 1024]. 
94 | net = linear(net, 1024, scope="d_fc3", use_sn=use_sn) 95 | net = self.batch_norm(net, y=y, is_training=is_training, name="d_bn3") 96 | net = lrelu(net) 97 | # Resulting shape: [bs, 1]. 98 | out_logit = linear(net, 1, scope="d_fc4", use_sn=use_sn) 99 | out = tf.nn.sigmoid(out_logit) 100 | return out, out_logit, net 101 | -------------------------------------------------------------------------------- /compare_gan/architectures/resnet30.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A 30-block resnet. 17 | 18 | It contains 6 "super-blocks" and each such block contains 5 residual blocks in 19 | both the generator and discriminator. It supports the 128x128 resolution. 20 | Details can be found in "Improved Training of Wasserstein GANs", Gulrajani I. 21 | et al. 2017. The related code is available at 22 | https://github.com/igul222/improved_wgan_training/blob/master/gan_64x64.py. 
23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | from compare_gan.architectures import arch_ops as ops 30 | from compare_gan.architectures import resnet_ops 31 | 32 | from six.moves import range 33 | import tensorflow as tf 34 | 35 | 36 | class Generator(resnet_ops.ResNetGenerator): 37 | """ResNet30 generator, 30 blocks, generates images of resolution 128x128. 38 | 39 | Trying to match the architecture defined in [1]. Difference is that there 40 | the final resolution is 64x64, while here we have 128x128. 41 | """ 42 | 43 | def apply(self, z, y, is_training): 44 | """Build the generator network for the given inputs. 45 | 46 | Args: 47 | z: `Tensor` of shape [batch_size, z_dim] with latent code. 48 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 49 | labels. 50 | is_training: boolean, are we in train or eval model. 51 | 52 | Returns: 53 | A tensor of size [batch_size] + self._image_shape with values in [0, 1]. 54 | """ 55 | z_shape = z.get_shape().as_list() 56 | if len(z_shape) != 2: 57 | raise ValueError("Expected shape [batch_size, z_dim], got %s." % z_shape) 58 | ch = 64 59 | colors = self._image_shape[2] 60 | # Map noise to the actual seed. 61 | output = ops.linear(z, 4 * 4 * 8 * ch, scope="fc_noise") 62 | # Reshape the seed to be a rank-4 Tensor. 63 | output = tf.reshape(output, [-1, 4, 4, 8 * ch], name="fc_reshaped") 64 | in_channels = 8 * ch 65 | out_channels = 4 * ch 66 | for superblock in range(6): 67 | for i in range(5): 68 | block = self._resnet_block( 69 | name="B_{}_{}".format(superblock, i), 70 | in_channels=in_channels, 71 | out_channels=in_channels, 72 | scale="none") 73 | output = block(output, z=z, y=y, is_training=is_training) 74 | # We want to upscale 5 times. 
75 | if superblock < 5: 76 | block = self._resnet_block( 77 | name="B_{}_up".format(superblock), 78 | in_channels=in_channels, 79 | out_channels=out_channels, 80 | scale="up") 81 | output = block(output, z=z, y=y, is_training=is_training) 82 | in_channels /= 2 83 | out_channels /= 2 84 | 85 | output = ops.conv2d( 86 | output, output_dim=colors, k_h=3, k_w=3, d_h=1, d_w=1, 87 | name="final_conv") 88 | output = tf.nn.sigmoid(output) 89 | return output 90 | 91 | 92 | class Discriminator(resnet_ops.ResNetDiscriminator): 93 | """ResNet discriminator, 30 blocks, 128x128x3 and 128x128x1 resolution.""" 94 | 95 | def apply(self, x, y, is_training): 96 | """Apply the discriminator on a input. 97 | 98 | Args: 99 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images. 100 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 101 | labels. 102 | is_training: Boolean, whether the architecture should be constructed for 103 | training or inference. 104 | 105 | Returns: 106 | Tuple of 3 Tensors, the final prediction of the discriminator, the logits 107 | before the final output activation function and logits form the second 108 | last layer. 109 | """ 110 | resnet_ops.validate_image_inputs(x) 111 | colors = x.get_shape().as_list()[-1] 112 | assert colors in [1, 3] 113 | ch = 64 114 | output = ops.conv2d( 115 | x, output_dim=ch // 4, k_h=3, k_w=3, d_h=1, d_w=1, 116 | name="color_conv") 117 | in_channels = ch // 4 118 | out_channels = ch // 2 119 | for superblock in range(6): 120 | for i in range(5): 121 | block = self._resnet_block( 122 | name="B_{}_{}".format(superblock, i), 123 | in_channels=in_channels, 124 | out_channels=in_channels, 125 | scale="none") 126 | output = block(output, z=None, y=y, is_training=is_training) 127 | # We want to downscale 5 times. 
128 | if superblock < 5: 129 | block = self._resnet_block( 130 | name="B_{}_up".format(superblock), 131 | in_channels=in_channels, 132 | out_channels=out_channels, 133 | scale="down") 134 | output = block(output, z=None, y=y, is_training=is_training) 135 | in_channels *= 2 136 | out_channels *= 2 137 | 138 | # Final part 139 | output = tf.reshape(output, [-1, 4 * 4 * 8 * ch]) 140 | out_logit = ops.linear(output, 1, scope="disc_final_fc", 141 | use_sn=self._spectral_norm) 142 | out = tf.nn.sigmoid(out_logit) 143 | return out, out_logit, output 144 | -------------------------------------------------------------------------------- /compare_gan/architectures/resnet5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A deep neural architecture with residual blocks and skip connections. 17 | 18 | It contains 5 residual blocks in both the generator and discriminator and 19 | supports 128x128 resolution. Details can be found in "Improved Training 20 | of Wasserstein GANs", Gulrajani I. et al. 2017. The related code is available at 21 | https://github.com/igul222/improved_wgan_training/blob/master/gan_64x64.py. 
22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | from compare_gan.architectures import arch_ops as ops 29 | from compare_gan.architectures import resnet_ops 30 | 31 | import numpy as np 32 | from six.moves import range 33 | import tensorflow as tf 34 | 35 | 36 | class Generator(resnet_ops.ResNetGenerator): 37 | """ResNet generator consisting of 5 blocks, outputs 128x128x3 resolution.""" 38 | 39 | def __init__(self, ch=64, channels=(8, 8, 4, 4, 2, 1), **kwargs): 40 | super(Generator, self).__init__(**kwargs) 41 | self._ch = ch 42 | self._channels = channels 43 | 44 | def apply(self, z, y, is_training): 45 | """Build the generator network for the given inputs. 46 | 47 | Args: 48 | z: `Tensor` of shape [batch_size, z_dim] with latent code. 49 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 50 | labels. 51 | is_training: boolean, are we in train or eval model. 52 | 53 | Returns: 54 | A tensor of size [batch_size] + self._image_shape with values in [0, 1]. 55 | """ 56 | # Each block upscales by a factor of 2. 57 | seed_size = 4 58 | image_size = self._image_shape[0] 59 | 60 | # Map noise to the actual seed. 61 | net = ops.linear( 62 | z, 63 | self._ch * self._channels[0] * seed_size * seed_size, 64 | scope="fc_noise") 65 | # Reshape the seed to be a rank-4 Tensor. 
66 | net = tf.reshape( 67 | net, 68 | [-1, seed_size, seed_size, self._ch * self._channels[0]], 69 | name="fc_reshaped") 70 | 71 | up_layers = np.log2(float(image_size) / seed_size) 72 | if not up_layers.is_integer(): 73 | raise ValueError("log2({}/{}) must be an integer.".format( 74 | image_size, seed_size)) 75 | if up_layers < 0 or up_layers > 5: 76 | raise ValueError("Invalid image_size {}.".format(image_size)) 77 | up_layers = int(up_layers) 78 | 79 | for block_idx in range(5): 80 | block = self._resnet_block( 81 | name="B{}".format(block_idx + 1), 82 | in_channels=self._ch * self._channels[block_idx], 83 | out_channels=self._ch * self._channels[block_idx + 1], 84 | scale="up" if block_idx < up_layers else "none") 85 | net = block(net, z=z, y=y, is_training=is_training) 86 | 87 | net = self.batch_norm( 88 | net, z=z, y=y, is_training=is_training, name="final_norm") 89 | net = tf.nn.relu(net) 90 | net = ops.conv2d(net, output_dim=self._image_shape[2], 91 | k_h=3, k_w=3, d_h=1, d_w=1, name="final_conv") 92 | net = tf.nn.sigmoid(net) 93 | return net 94 | 95 | 96 | class Discriminator(resnet_ops.ResNetDiscriminator): 97 | """ResNet5 discriminator, 5 blocks, supporting 128x128x3 and 128x128x1.""" 98 | 99 | def __init__(self, ch=64, channels=(1, 2, 4, 4, 8, 8), **kwargs): 100 | super(Discriminator, self).__init__(**kwargs) 101 | self._ch = ch 102 | self._channels = channels 103 | 104 | def apply(self, x, y, is_training): 105 | """Apply the discriminator on a input. 106 | 107 | Args: 108 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images. 109 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 110 | labels. 111 | is_training: Boolean, whether the architecture should be constructed for 112 | training or inference. 113 | 114 | Returns: 115 | Tuple of 3 Tensors, the final prediction of the discriminator, the logits 116 | before the final output activation function and logits form the second 117 | last layer. 
118 | """ 119 | resnet_ops.validate_image_inputs(x) 120 | colors = x.shape[3].value 121 | if colors not in [1, 3]: 122 | raise ValueError("Number of color channels not supported: {}".format( 123 | colors)) 124 | 125 | block = self._resnet_block( 126 | name="B0", 127 | in_channels=colors, 128 | out_channels=self._ch, 129 | scale="down") 130 | output = block(x, z=None, y=y, is_training=is_training) 131 | 132 | for block_idx in range(5): 133 | block = self._resnet_block( 134 | name="B{}".format(block_idx + 1), 135 | in_channels=self._ch * self._channels[block_idx], 136 | out_channels=self._ch * self._channels[block_idx + 1], 137 | scale="down") 138 | output = block(output, z=None, y=y, is_training=is_training) 139 | 140 | output = tf.nn.relu(output) 141 | pre_logits = tf.reduce_mean(output, axis=[1, 2]) 142 | out_logit = ops.linear(pre_logits, 1, scope="disc_final_fc", 143 | use_sn=self._spectral_norm) 144 | out = tf.nn.sigmoid(out_logit) 145 | return out, out_logit, pre_logits 146 | -------------------------------------------------------------------------------- /compare_gan/architectures/resnet_biggan_deep_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Test number of parameters for the BigGAN-Deep architecture.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl import logging 23 | from compare_gan import utils 24 | from compare_gan.architectures import arch_ops 25 | from compare_gan.architectures import resnet_biggan_deep 26 | import tensorflow as tf 27 | 28 | 29 | class ResNet5BigGanDeepTest(tf.test.TestCase): 30 | 31 | def testNumberOfParameters(self): 32 | with tf.Graph().as_default(): 33 | batch_size = 2 34 | z = tf.zeros((batch_size, 128)) 35 | y = tf.one_hot(tf.ones((batch_size,), dtype=tf.int32), 1000) 36 | generator = resnet_biggan_deep.Generator( 37 | image_shape=(128, 128, 3), 38 | batch_norm_fn=arch_ops.conditional_batch_norm) 39 | fake_images = generator(z, y=y, is_training=True, reuse=False) 40 | self.assertEqual(fake_images.shape.as_list(), [batch_size, 128, 128, 3]) 41 | discriminator = resnet_biggan_deep.Discriminator() 42 | predictions = discriminator(fake_images, y, is_training=True) 43 | self.assertLen(predictions, 3) 44 | 45 | t_vars = tf.trainable_variables() 46 | g_vars = [var for var in t_vars if "generator" in var.name] 47 | d_vars = [var for var in t_vars if "discriminator" in var.name] 48 | g_param_overview = utils.get_parameter_overview(g_vars, limit=None) 49 | d_param_overview = utils.get_parameter_overview(d_vars, limit=None) 50 | g_param_overview = g_param_overview.split("\n") 51 | logging.info("Generator variables:") 52 | for i in range(0, len(g_param_overview), 80): 53 | logging.info("\n%s", "\n".join(g_param_overview[i:i + 80])) 54 | logging.info("Discriminator variables:\n%s", d_param_overview) 55 | 56 | g_num_weights = sum([v.get_shape().num_elements() for v in g_vars]) 57 | self.assertEqual(g_num_weights, 50244484) 58 | 59 | d_num_weights = sum([v.get_shape().num_elements() for v in d_vars]) 60 | self.assertEqual(d_num_weights, 34590210) 61 | 62 | 63 | if __name__ 
== "__main__": 64 | tf.test.main() 65 | -------------------------------------------------------------------------------- /compare_gan/architectures/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Resnet generator and discriminator for CIFAR. 17 | 18 | Based on Table 4 from "Spectral Normalization for Generative Adversarial 19 | Networks", Miyato T. et al., 2018. [https://arxiv.org/pdf/1802.05957.pdf]. 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | from compare_gan.architectures import arch_ops as ops 27 | from compare_gan.architectures import resnet_ops 28 | 29 | import gin 30 | from six.moves import range 31 | import tensorflow as tf 32 | 33 | 34 | @gin.configurable 35 | class Generator(resnet_ops.ResNetGenerator): 36 | """ResNet generator, 4 blocks, supporting 32x32 resolution.""" 37 | 38 | def __init__(self, 39 | hierarchical_z=False, 40 | embed_z=False, 41 | embed_y=False, 42 | **kwargs): 43 | """Constructor for the ResNet Cifar generator. 44 | 45 | Args: 46 | hierarchical_z: Split z into chunks and only give one chunk to each. 47 | Each chunk will also be concatenated to y, the one hot encoded labels. 
48 | embed_z: If True use a learnable embedding of z that is used instead. 49 | The embedding will have the length of z. 50 | embed_y: If True use a learnable embedding of y that is used instead. 51 | The embedding will have the length of z (not y!). 52 | **kwargs: additional arguments passed on to ResNetGenerator. 53 | """ 54 | super(Generator, self).__init__(**kwargs) 55 | self._hierarchical_z = hierarchical_z 56 | self._embed_z = embed_z 57 | self._embed_y = embed_y 58 | 59 | def apply(self, z, y, is_training): 60 | """Build the generator network for the given inputs. 61 | 62 | Args: 63 | z: `Tensor` of shape [batch_size, z_dim] with latent code. 64 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 65 | labels. 66 | is_training: boolean, are we in train or eval mode. 67 | 68 | Returns: 69 | A tensor of size [batch_size, 32, 32, colors] with values in [0, 1]. 70 | """ 71 | assert self._image_shape[0] == 32 72 | assert self._image_shape[1] == 32 73 | num_blocks = 3 74 | z_dim = z.shape[1].value 75 | 76 | if self._embed_z: 77 | z = ops.linear(z, z_dim, scope="embed_z", use_sn=self._spectral_norm) 78 | if self._embed_y: 79 | y = ops.linear(y, z_dim, scope="embed_y", use_sn=self._spectral_norm) 80 | y_per_block = num_blocks * [y] 81 | if self._hierarchical_z: 82 | z_per_block = tf.split(z, num_blocks + 1, axis=1) 83 | z0, z_per_block = z_per_block[0], z_per_block[1:] 84 | if y is not None: 85 | y_per_block = [tf.concat([zi, y], 1) for zi in z_per_block] 86 | else: 87 | z0 = z 88 | z_per_block = num_blocks * [z] 89 | 90 | output = ops.linear(z0, 4 * 4 * 256, scope="fc_noise", 91 | use_sn=self._spectral_norm) 92 | output = tf.reshape(output, [-1, 4, 4, 256], name="fc_reshaped") 93 | for block_idx in range(3): 94 | block = self._resnet_block( 95 | name="B{}".format(block_idx + 1), 96 | in_channels=256, 97 | out_channels=256, 98 | scale="up") 99 | output = block( 100 | output, 101 | z=z_per_block[block_idx], 102 | y=y_per_block[block_idx], 103 | 
is_training=is_training) 104 | 105 | # Final processing of the output. 106 | output = self.batch_norm( 107 | output, z=z, y=y, is_training=is_training, name="final_norm") 108 | output = tf.nn.relu(output) 109 | output = ops.conv2d(output, output_dim=self._image_shape[2], k_h=3, k_w=3, 110 | d_h=1, d_w=1, name="final_conv", 111 | use_sn=self._spectral_norm,) 112 | return tf.nn.sigmoid(output) 113 | 114 | 115 | @gin.configurable 116 | class Discriminator(resnet_ops.ResNetDiscriminator): 117 | """ResNet discriminator, 4 blocks, supporting 32x32 with 1 or 3 colors.""" 118 | 119 | def __init__(self, project_y=False, **kwargs): 120 | super(Discriminator, self).__init__(**kwargs) 121 | self._project_y = project_y 122 | 123 | def apply(self, x, y, is_training): 124 | """Apply the discriminator on a input. 125 | 126 | Args: 127 | x: `Tensor` of shape [batch_size, 32, 32, ?] with real or fake images. 128 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded 129 | labels. 130 | is_training: Boolean, whether the architecture should be constructed for 131 | training or inference. 132 | 133 | Returns: 134 | Tuple of 3 Tensors, the final prediction of the discriminator, the logits 135 | before the final output activation function and logits form the second 136 | last layer. 
137 | """ 138 | resnet_ops.validate_image_inputs(x) 139 | colors = x.shape[3].value 140 | if colors not in [1, 3]: 141 | raise ValueError("Number of color channels not supported: {}".format( 142 | colors)) 143 | 144 | output = x 145 | for block_idx in range(4): 146 | block = self._resnet_block( 147 | name="B{}".format(block_idx + 1), 148 | in_channels=colors if block_idx == 0 else 128, 149 | out_channels=128, 150 | scale="down" if block_idx <= 1 else "none") 151 | output = block(output, z=None, y=y, is_training=is_training) 152 | 153 | # Final part - ReLU 154 | output = tf.nn.relu(output) 155 | 156 | h = tf.reduce_mean(output, axis=[1, 2]) 157 | 158 | out_logit = ops.linear(h, 1, scope="disc_final_fc", 159 | use_sn=self._spectral_norm) 160 | if self._project_y: 161 | if y is None: 162 | raise ValueError("You must provide class information y to project.") 163 | embedded_y = ops.linear(y, 128, use_bias=False, 164 | scope="embedding_fc", use_sn=self._spectral_norm) 165 | out_logit += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True) 166 | out = tf.nn.sigmoid(out_logit) 167 | return out, out_logit, h 168 | -------------------------------------------------------------------------------- /compare_gan/architectures/resnet_init_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
# --- compare_gan/architectures/resnet_init_test.py ---
"""Tests weight initialization ops using ResNet5 architecture."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from compare_gan.architectures import resnet5
from compare_gan.gans import consts
import gin
import tensorflow as tf


class ResNetInitTest(tf.test.TestCase):
  """Checks which initializer ops back the ResNet5 trainable variables."""

  # Initializers that are accepted regardless of the configured weight
  # initializer (biases and batch norm parameters).
  _NON_KERNEL_INITIALIZERS = [
      "bias/Initializer/Const",
      "beta/Initializer/zeros",
      "gamma/Initializer/ones",
  ]

  def setUp(self):
    super(ResNetInitTest, self).setUp()
    # Every test binds its own `weights.initializer`; start from a clean slate.
    gin.clear_config()

  def _assert_initializer_ops(self, kernel_initializers):
    """Builds ResNet5 G and D and checks the initializer op of each variable.

    Args:
      kernel_initializers: List of accepted op name suffixes for kernel
        initializers (without the trailing ":0").
    """
    valid_initializers = kernel_initializers + self._NON_KERNEL_INITIALIZERS
    valid_op_names = "/({}):0$".format("|".join(valid_initializers))
    with tf.Graph().as_default():
      z = tf.zeros((2, 128))
      fake_image = resnet5.Generator(image_shape=(128, 128, 3))(
          z, y=None, is_training=True)
      resnet5.Discriminator()(fake_image, y=None, is_training=True)
      for var in tf.trainable_variables():
        # The second input of the initializer op is expected to be the tensor
        # produced by the configured initializer.
        op_name = var.initializer.inputs[1].name
        self.assertRegex(op_name, valid_op_names)

  def testInitializersOldDefault(self):
    self._assert_initializer_ops([
        "kernel/Initializer/random_normal",
        # truncated_normal is the old default for conv2d.
        "kernel/Initializer/truncated_normal",
    ])

  def testInitializersRandomNormal(self):
    gin.bind_parameter("weights.initializer", consts.NORMAL_INIT)
    self._assert_initializer_ops(["kernel/Initializer/random_normal"])

  def testInitializersTruncatedNormal(self):
    gin.bind_parameter("weights.initializer", consts.TRUNCATED_INIT)
    self._assert_initializer_ops(["kernel/Initializer/truncated_normal"])

  def testGeneratorInitializersOrthogonal(self):
    gin.bind_parameter("weights.initializer", consts.ORTHOGONAL_INIT)
    self._assert_initializer_ops(["kernel/Initializer/mul_1"])


if __name__ == "__main__":
  tf.test.main()
# --- compare_gan/architectures/resnet_stl.py ---
"""A deep neural architecture with residual blocks and skip connections.

Based on Table 5 from "Spectral Normalization for Generative Adversarial
Networks", Miyato T. et al., 2018. [https://arxiv.org/pdf/1802.05957.pdf].
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from compare_gan.architectures import arch_ops as ops
from compare_gan.architectures import resnet_ops

from six.moves import range
import tensorflow as tf


class Generator(resnet_ops.ResNetGenerator):
  """ResNet generator, 3 blocks, supporting 48x48 resolution."""

  def apply(self, z, y, is_training):
    """Build the generator network for the given inputs.

    Args:
      z: `Tensor` of shape [batch_size, z_dim] with latent code.
      y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
        labels.
      is_training: boolean, are we in train or eval mode.

    Returns:
      A tensor of size [batch_size, 48, 48, colors] with values in [0, 1].
    """
    ch = 64
    colors = self._image_shape[2]
    # (in, out) channel multipliers of `ch` for blocks B1..B3.
    magic = [(8, 4), (4, 2), (2, 1)]
    output = ops.linear(z, 6 * 6 * 512, scope="fc_noise")
    # Use -1 for the batch dimension so the reshape also works when the batch
    # size of `z` is not statically known.
    output = tf.reshape(output, [-1, 6, 6, 512], name="fc_reshaped")
    for block_idx in range(3):
      block = self._resnet_block(
          name="B{}".format(block_idx + 1),
          in_channels=ch * magic[block_idx][0],
          out_channels=ch * magic[block_idx][1],
          scale="up")
      output = block(output, z=z, y=y, is_training=is_training)
    # NOTE(review): other architectures in this code base pass
    # `name="final_norm"` here; confirm whether `scope=` is actually consumed
    # by the batch norm function before changing it (renaming would alter
    # variable names in existing checkpoints).
    output = self.batch_norm(
        output, z=z, y=y, is_training=is_training, scope="final_norm")
    output = tf.nn.relu(output)
    output = ops.conv2d(output, output_dim=colors, k_h=3, k_w=3, d_h=1, d_w=1,
                        name="final_conv")
    return tf.nn.sigmoid(output)


class Discriminator(resnet_ops.ResNetDiscriminator):
  """ResNet discriminator, 5 blocks (B0-B4), supports 48x48 resolution."""

  def apply(self, x, y, is_training):
    """Apply the discriminator on an input.

    Args:
      x: `Tensor` of shape [batch_size, 48, 48, ?] with real or fake images.
      y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
        labels.
      is_training: Boolean, whether the architecture should be constructed for
        training or inference.

    Returns:
      Tuple of 3 Tensors, the final prediction of the discriminator, the logits
      before the final output activation function and logits from the second
      last layer.

    Raises:
      ValueError: If the number of color channels is neither 1 nor 3.
    """
    # 48 is not a power of two, so that check has to be disabled.
    resnet_ops.validate_image_inputs(x, validate_power2=False)
    colors = x.shape[-1].value
    if colors not in [1, 3]:
      raise ValueError("Number of color channels unknown: %s" % colors)
    ch = 64
    block = self._resnet_block(
        name="B0", in_channels=colors, out_channels=ch, scale="down")
    output = block(x, z=None, y=y, is_training=is_training)
    # (in, out) channel multipliers of `ch` for blocks B1..B4.
    magic = [(1, 2), (2, 4), (4, 8), (8, 16)]
    for block_idx in range(4):
      block = self._resnet_block(
          name="B{}".format(block_idx + 1),
          in_channels=ch * magic[block_idx][0],
          out_channels=ch * magic[block_idx][1],
          # The last block keeps the spatial resolution.
          scale="down" if block_idx < 3 else "none")
      output = block(output, z=None, y=y, is_training=is_training)
    output = tf.nn.relu(output)
    pre_logits = tf.reduce_mean(output, axis=[1, 2])
    out_logit = ops.linear(pre_logits, 1, scope="disc_final_fc",
                           use_sn=self._spectral_norm)
    out = tf.nn.sigmoid(out_logit)
    return out, out_logit, pre_logits
# --- compare_gan/architectures/sndcgan.py ---
"""Implementation of SNDCGAN generator and discriminator architectures."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from compare_gan.architectures import abstract_arch
from compare_gan.architectures.arch_ops import conv2d
from compare_gan.architectures.arch_ops import deconv2d
from compare_gan.architectures.arch_ops import linear
from compare_gan.architectures.arch_ops import lrelu

import numpy as np
import tensorflow as tf


def conv_out_size_same(size, stride):
  """Returns ceil(size / stride) as an int."""
  return int(np.ceil(float(size) / float(stride)))


class Generator(abstract_arch.AbstractGenerator):
  """SNDCGAN generator.

  Details are available at https://openreview.net/pdf?id=B1QRgziT-.
  """

  def apply(self, z, y, is_training):
    """Build the generator network for the given inputs.

    Args:
      z: `Tensor` of shape [batch_size, z_dim] with latent code. The batch
        size is read from the static shape and used in reshape/deconv output
        shapes.
      y: `Tensor` of shape [batch_size, num_classes] of one hot encoded labels.
      is_training: boolean, are we in train or eval mode.

    Returns:
      A tensor of size [batch_size] + self._image_shape with values in [0, 1].
    """
    batch_size = z.shape[0].value
    s_h, s_w, colors = self._image_shape
    # Spatial sizes after each of the three stride-2 upsampling steps, computed
    # backwards from the target resolution.
    s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
    s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
    s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)

    net = linear(z, s_h8 * s_w8 * 512, scope="g_fc1")
    net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn1")
    net = tf.nn.relu(net)
    net = tf.reshape(net, [batch_size, s_h8, s_w8, 512])
    net = deconv2d(net, [batch_size, s_h4, s_w4, 256], 4, 4, 2, 2, name="g_dc2")
    net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn2")
    net = tf.nn.relu(net)
    net = deconv2d(net, [batch_size, s_h2, s_w2, 128], 4, 4, 2, 2, name="g_dc3")
    net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn3")
    net = tf.nn.relu(net)
    net = deconv2d(net, [batch_size, s_h, s_w, 64], 4, 4, 2, 2, name="g_dc4")
    net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn4")
    net = tf.nn.relu(net)
    net = deconv2d(
        net, [batch_size, s_h, s_w, colors], 3, 3, 1, 1, name="g_dc5")
    out = tf.tanh(net)

    # This normalization from [-1, 1] to [0, 1] is introduced for consistency
    # with other models. The division operator replaces the deprecated tf.div.
    out = (out + 1.0) / 2.0
    return out


class Discriminator(abstract_arch.AbstractDiscriminator):
  """SNDCGAN discriminator.

  Details are available at https://openreview.net/pdf?id=B1QRgziT-.
  """

  def apply(self, x, y, is_training):
    """Apply the discriminator on an input.

    Args:
      x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
      y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
        labels. Unused.
      is_training: Boolean, whether the architecture should be constructed for
        training or inference. Unused.

    Returns:
      Tuple of 3 Tensors, the final prediction of the discriminator, the logits
      before the final output activation function and logits from the second
      last layer.
    """
    del is_training, y
    use_sn = self._spectral_norm
    # In compare gan framework, the image preprocess normalize image pixel to
    # range [0, 1], while author used [-1, 1]. Apply this trick to input image
    # instead of changing our preprocessing function.
    x = x * 2.0 - 1.0
    net = conv2d(x, 64, 3, 3, 1, 1, name="d_conv1", use_sn=use_sn)
    net = lrelu(net, leak=0.1)
    net = conv2d(net, 128, 4, 4, 2, 2, name="d_conv2", use_sn=use_sn)
    net = lrelu(net, leak=0.1)
    net = conv2d(net, 128, 3, 3, 1, 1, name="d_conv3", use_sn=use_sn)
    net = lrelu(net, leak=0.1)
    net = conv2d(net, 256, 4, 4, 2, 2, name="d_conv4", use_sn=use_sn)
    net = lrelu(net, leak=0.1)
    net = conv2d(net, 256, 3, 3, 1, 1, name="d_conv5", use_sn=use_sn)
    net = lrelu(net, leak=0.1)
    net = conv2d(net, 512, 4, 4, 2, 2, name="d_conv6", use_sn=use_sn)
    net = lrelu(net, leak=0.1)
    net = conv2d(net, 512, 3, 3, 1, 1, name="d_conv7", use_sn=use_sn)
    net = lrelu(net, leak=0.1)
    batch_size = x.shape.as_list()[0]
    if batch_size is None:
      # A reshape may contain only one unknown dimension, so with an unknown
      # batch size the flattened feature size is taken from the static shape.
      flat_dim = int(np.prod(net.shape.as_list()[1:]))
      net = tf.reshape(net, [-1, flat_dim])
    else:
      net = tf.reshape(net, [batch_size, -1])
    out_logit = linear(net, 1, scope="d_fc1", use_sn=use_sn)
    out = tf.nn.sigmoid(out_logit)
    return out, out_logit, net
# --- compare_gan/datasets_test.py ---
"""Tests for datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
from compare_gan import datasets

import tensorflow as tf

FLAGS = flags.FLAGS

# Label dtypes accepted by the shape/dtype check below.
_TPU_SUPPORTED_TYPES = {
    tf.float32, tf.int32, tf.complex64, tf.int64, tf.bool, tf.bfloat16
}


def _preprocess_fn_id(images, labels):
  """Wraps the images into a feature dict and passes the labels through."""
  features = {"images": images}
  return features, labels


def _preprocess_fn_add_noise(images, labels, seed=None):
  """Replaces the labels with uniform random noise of shape [128]."""
  del labels
  tf.set_random_seed(seed)
  features = {"images": images}
  return features, tf.random.uniform([128], maxval=1.0)


class DatasetsTest(parameterized.TestCase, tf.test.TestCase):
  """Tests shapes, dtypes and determinism of the dataset input functions."""

  def setUp(self):
    super(DatasetsTest, self).setUp()
    FLAGS.data_shuffle_buffer_size = 100

  def get_element_and_verify_shape(self, dataset_name, expected_shape):
    """Reads one eval example and checks its static and runtime properties."""
    eval_dataset = datasets.get_dataset(dataset_name).eval_input_fn()
    image, label = eval_dataset.make_one_shot_iterator().get_next()
    # Check if shape is known at compile time, required for TPUs.
    self.assertAllEqual(image.shape.as_list(), expected_shape)
    self.assertEqual(image.dtype, tf.float32)
    self.assertIn(label.dtype, _TPU_SUPPORTED_TYPES)
    with self.cached_session() as session:
      image_value = session.run(image)
      self.assertEqual(image_value.shape, expected_shape)
      self.assertGreaterEqual(image_value.min(), 0.0)
      self.assertLessEqual(image_value.max(), 1.0)

  def test_mnist(self):
    self.get_element_and_verify_shape("mnist", (28, 28, 1))

  def test_fashion_mnist(self):
    self.get_element_and_verify_shape("fashion-mnist", (28, 28, 1))

  def test_celeba(self):
    self.get_element_and_verify_shape("celeb_a", (64, 64, 3))

  def test_lsun(self):
    self.get_element_and_verify_shape("lsun-bedroom", (128, 128, 3))

  def _run_train_input_fn(self, dataset_name, preprocess_fn):
    """Returns the first 5 training batches (batch size 1) of a dataset."""
    dataset = datasets.get_dataset(dataset_name)
    with tf.Graph().as_default():
      train_dataset = dataset.input_fn(
          params={"batch_size": 1}, preprocess_fn=preprocess_fn)
      iterator = train_dataset.make_initializable_iterator()
      with self.session() as sess:
        sess.run(iterator.initializer)
        next_batch = iterator.get_next()
        batches = []
        for _ in range(5):
          batches.append(sess.run(next_batch))
        return batches

  @parameterized.named_parameters(
      ("FakeCifar", _preprocess_fn_id),
      ("FakeCifarWithRandomNoise", _preprocess_fn_add_noise),
  )
  @flagsaver.flagsaver
  def test_train_input_fn_is_determinsitic(self, preprocess_fn):
    FLAGS.data_fake_dataset = True
    first_run = self._run_train_input_fn("cifar10", preprocess_fn)
    second_run = self._run_train_input_fn("cifar10", preprocess_fn)
    for batch_a, batch_b in zip(first_run, second_run):
      # Check that both runs got the same images/noise
      self.assertAllClose(batch_a[0], batch_b[0])
      self.assertAllClose(batch_a[1], batch_b[1])

  @flagsaver.flagsaver
  def test_train_input_fn_noise_changes(self):
    FLAGS.data_fake_dataset = True
    batches = self._run_train_input_fn("cifar10", _preprocess_fn_add_noise)
    first_noise = batches[0][1]
    # Every batch must differ from the first one and from its predecessor.
    for previous, current in zip(batches, batches[1:]):
      self.assertNotAllClose(first_noise, current[1])
      self.assertNotAllClose(previous[1], current[1])


if __name__ == "__main__":
  tf.test.main()
# --- compare_gan/eval_gan_lib_test.py ---
"""Tests for eval_gan_lib."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import os.path

from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized

from compare_gan import datasets
from compare_gan import eval_gan_lib
from compare_gan import eval_utils
from compare_gan.gans import consts as c
from compare_gan.gans.modular_gan import ModularGAN
from compare_gan.metrics import fid_score
from compare_gan.metrics import fractal_dimension
from compare_gan.metrics import inception_score
from compare_gan.metrics import ms_ssim_score

import gin
import mock
import tensorflow as tf

FLAGS = flags.FLAGS


def create_fake_inception_graph():
  """Creates a `GraphDef` that mocks the Inception V1 graph.

  It takes the input, multiplies it through a matrix full of 0.00001 values,
  and provides the results in the endpoints 'pool_3' and 'logits'. This
  matches the tensor names in the real Inception V1 model.

  Returns:
    `tf.GraphDef` for the mocked Inception V1 graph.
  """
  fake_inception = tf.Graph()
  with fake_inception.as_default():
    inputs = tf.placeholder(
        tf.float32, shape=[None, 299, 299, 3], name="Mul")
    w = tf.ones(shape=[299 * 299 * 3, 10]) * 0.00001
    outputs = tf.matmul(tf.layers.flatten(inputs), w)
    tf.identity(outputs, name="pool_3")
    tf.identity(outputs, name="logits")
  return fake_inception.as_graph_def()


class EvalGanLibTest(parameterized.TestCase, tf.test.TestCase):
  """End-to-end evaluation test on a GAN trained for a single step."""

  def setUp(self):
    super(EvalGanLibTest, self).setUp()
    gin.clear_config()
    # setUp runs before the @flagsaver.flagsaver wrapper of the test method
    # saves the flags, so flag changes made here must be restored explicitly.
    saved_flags = flagsaver.save_flag_values()
    self.addCleanup(flagsaver.restore_flag_values, saved_flags)
    FLAGS.data_fake_dataset = True
    # Stop the patch after each test so it does not leak into other tests.
    patcher = mock.patch.object(eval_utils, "get_inception_graph_def")
    self.mock_get_graph = patcher.start()
    self.addCleanup(patcher.stop)
    self.mock_get_graph.return_value = create_fake_inception_graph()

  @parameterized.parameters(c.ARCHITECTURES)
  @flagsaver.flagsaver
  def test_end2end_checkpoint(self, architecture):
    """Takes real GAN (trained for 1 step) and evaluate it."""
    if architecture in {c.RESNET_STL_ARCH, c.RESNET30_ARCH}:
      # RESNET_STL_ARCH and RESNET30_ARCH do not support CIFAR image shape.
      return
    gin.bind_parameter("dataset.name", "cifar10")
    dataset = datasets.get_dataset("cifar10")
    options = {
        "architecture": architecture,
        "z_dim": 120,
        "disc_iters": 1,
        "lambda": 1,
    }
    model_dir = os.path.join(tf.test.get_temp_dir(), self.id())
    tf.logging.info("model_dir: %s", model_dir)
    run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)
    gan = ModularGAN(dataset=dataset,
                     parameters=options,
                     conditional="biggan" in architecture,
                     model_dir=model_dir)
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(input_fn=gan.input_fn, steps=1)
    export_path = os.path.join(model_dir, "tfhub")
    # Training ran for exactly one step, hence checkpoint number 1.
    checkpoint_path = os.path.join(model_dir, "model.ckpt-1")
    module_spec = gan.as_module_spec()
    module_spec.export(export_path, checkpoint_path=checkpoint_path)

    eval_tasks = [
        fid_score.FIDScoreTask(),
        fractal_dimension.FractalDimensionTask(),
        inception_score.InceptionScoreTask(),
        ms_ssim_score.MultiscaleSSIMTask()
    ]
    result_dict = eval_gan_lib.evaluate_tfhub_module(
        export_path, eval_tasks, use_tpu=False, num_averaging_runs=1)
    tf.logging.info("result_dict: %s", result_dict)
    for score in ["fid_score", "fractal_dimension", "inception_score",
                  "ms_ssim"]:
      for stats in ["mean", "std", "list"]:
        required_key = "%s_%s" % (score, stats)
        self.assertIn(required_key, result_dict, "Missing: %s." % required_key)


if __name__ == "__main__":
  tf.test.main()
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # coding=utf-8 17 | -------------------------------------------------------------------------------- /compare_gan/gans/abstract_gan.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
# --- compare_gan/gans/abstract_gan.py ---
"""Interface for GAN models that can be trained using the Estimator API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six
import tensorflow as tf


@six.add_metaclass(abc.ABCMeta)
class AbstractGAN(object):
  """Interface for GAN models that can be trained using the Estimator API."""

  def __init__(self,
               dataset,
               parameters,
               model_dir):
    """Stores the dataset, the parameters and the model directory."""
    super(AbstractGAN, self).__init__()
    self._dataset = dataset
    self._parameters = parameters
    self._model_dir = model_dir

  def as_estimator(self, run_config, batch_size, use_tpu):
    """Returns a TPUEstimator for this GAN."""
    estimator = tf.contrib.tpu.TPUEstimator(
        config=run_config,
        use_tpu=use_tpu,
        model_fn=self.model_fn,
        train_batch_size=batch_size)
    return estimator

  # NOTE(review): a test in this code base calls `as_module_spec()` without
  # arguments; confirm whether implementations still take `params`/`mode`.
  @abc.abstractmethod
  def as_module_spec(self, params, mode):
    """Returns the generator network as TFHub module spec."""

  @abc.abstractmethod
  def input_fn(self, params, mode):
    """Input function that returns a `tf.data.Dataset` object.

    This function will be called once for each host machine.

    Args:
      params: Python dictionary with parameters given to TPUEstimator.
        Additionally TPUEstimator will set the key `batch_size` with the
        batch size for this host machine and `tpu_context` with a TPUContext
        object.
      mode: `tf.estimator.ModeKeys` value.

    Returns:
      A `tf.data.Dataset` object with batched features and labels.
    """

  @abc.abstractmethod
  def model_fn(self, features, labels, params, mode):
    """Constructs the model for the given features and mode.

    This interface only requires implementing the TRAIN mode.

    On TPUs the model_fn should construct a graph for a single TPU core.
    Wrap the optimizer with a `tf.contrib.tpu.CrossShardOptimizer` to do
    synchronous training with all TPU cores.

    Args:
      features: A dictionary with the feature tensors.
      labels: Tensor with labels. Will be None if mode is PREDICT.
      params: Dictionary with hyperparameters passed to TPUEstimator.
        Additionally TPUEstimator will set 3 keys: `batch_size`, `use_tpu`,
        `tpu_context`. `batch_size` is the batch size for this core.
      mode: `tf.estimator.ModeKeys` value (TRAIN, EVAL, PREDICT). The mode
        should be passed to the TPUEstimatorSpec and your model should be
        built for this mode.

    Returns:
      A `tf.contrib.tpu.TPUEstimatorSpec`.
    """
# --- compare_gan/gans/consts.py ---
"""Defines constants used across the code base."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


# Names of the supported weight initializers.
NORMAL_INIT = "normal"
TRUNCATED_INIT = "truncated"
ORTHOGONAL_INIT = "orthogonal"
INITIALIZERS = [
    NORMAL_INIT,
    TRUNCATED_INIT,
    ORTHOGONAL_INIT,
]

# Names of the known architectures.
DCGAN_ARCH = "dcgan_arch"
DUMMY_ARCH = "dummy_arch"
INFOGAN_ARCH = "infogan_arch"
RESNET5_ARCH = "resnet5_arch"
RESNET30_ARCH = "resnet30_arch"
RESNET_BIGGAN_ARCH = "resnet_biggan_arch"
RESNET_BIGGAN_DEEP_ARCH = "resnet_biggan_deep_arch"
RESNET_CIFAR_ARCH = "resnet_cifar_arch"
RESNET_STL_ARCH = "resnet_stl_arch"
SNDCGAN_ARCH = "sndcgan_arch"

# DUMMY_ARCH is not part of this list.
ARCHITECTURES = [
    INFOGAN_ARCH,
    DCGAN_ARCH,
    RESNET5_ARCH,
    RESNET30_ARCH,
    RESNET_BIGGAN_ARCH,
    RESNET_BIGGAN_DEEP_ARCH,
    RESNET_CIFAR_ARCH,
    RESNET_STL_ARCH,
    SNDCGAN_ARCH,
]
# --- compare_gan/gans/loss_lib.py ---
"""Implementation of popular GAN losses."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from compare_gan import utils
import gin
import tensorflow as tf


def check_dimensions(d_real, d_fake, d_real_logits, d_fake_logits):
  """Checks the shapes and ranks of logits and prediction tensors.

  Any argument may be None; a pair is only checked when both of its tensors
  are given.

  Args:
    d_real: prediction for real points, values in [0, 1], shape [batch_size, 1].
    d_fake: prediction for fake points, values in [0, 1], shape [batch_size, 1].
    d_real_logits: logits for real points, shape [batch_size, 1].
    d_fake_logits: logits for fake points, shape [batch_size, 1].

  Raises:
    ValueError: if the ranks or shapes are mismatched.
  """
  def _check_pair(a, b):
    if a != b:
      raise ValueError("Shape mismatch: %s vs %s." % (a, b))
    if len(a) != 2 or len(b) != 2:
      raise ValueError("Rank: expected 2, got %s and %s" % (len(a), len(b)))

  if (d_real is not None) and (d_fake is not None):
    _check_pair(d_real.shape.as_list(), d_fake.shape.as_list())
  if (d_real_logits is not None) and (d_fake_logits is not None):
    _check_pair(d_real_logits.shape.as_list(), d_fake_logits.shape.as_list())
  if (d_real is not None) and (d_real_logits is not None):
    _check_pair(d_real.shape.as_list(), d_real_logits.shape.as_list())
  # Without this check a mismatch between the fake predictions and the fake
  # logits goes unnoticed whenever d_real or d_real_logits is None.
  if (d_fake is not None) and (d_fake_logits is not None):
    _check_pair(d_fake.shape.as_list(), d_fake_logits.shape.as_list())


@gin.configurable(whitelist=[])
def non_saturating(d_real_logits, d_fake_logits, d_real=None, d_fake=None):
  """Returns the discriminator and generator loss for Non-saturating loss.

  Args:
    d_real_logits: logits for real points, shape [batch_size, 1].
    d_fake_logits: logits for fake points, shape [batch_size, 1].
    d_real: only used for the shape check.
    d_fake: only used for the shape check.

  Returns:
    A tuple consisting of the discriminator loss, discriminator's loss on the
    real samples and fake samples, and the generator's loss.
  """
  with tf.name_scope("non_saturating_loss"):
    check_dimensions(d_real, d_fake, d_real_logits, d_fake_logits)
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_real_logits, labels=tf.ones_like(d_real_logits),
        name="cross_entropy_d_real"))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits),
        name="cross_entropy_d_fake"))
    d_loss = d_loss_real + d_loss_fake
    # The generator is trained to make the discriminator label fakes as real.
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_fake_logits, labels=tf.ones_like(d_fake_logits),
        name="cross_entropy_g"))
    return d_loss, d_loss_real, d_loss_fake, g_loss


@gin.configurable(whitelist=[])
def wasserstein(d_real_logits, d_fake_logits, d_real=None, d_fake=None):
  """Returns the discriminator and generator loss for Wasserstein loss.

  Args:
    d_real_logits: logits for real points, shape [batch_size, 1].
    d_fake_logits: logits for fake points, shape [batch_size, 1].
    d_real: only used for the shape check.
    d_fake: only used for the shape check.

  Returns:
    A tuple consisting of the discriminator loss, discriminator's loss on the
    real samples and fake samples, and the generator's loss.
  """
  with tf.name_scope("wasserstein_loss"):
    check_dimensions(d_real, d_fake, d_real_logits, d_fake_logits)
    d_loss_real = -tf.reduce_mean(d_real_logits)
    d_loss_fake = tf.reduce_mean(d_fake_logits)
    d_loss = d_loss_real + d_loss_fake
    g_loss = -d_loss_fake
    return d_loss, d_loss_real, d_loss_fake, g_loss


@gin.configurable(whitelist=[])
def least_squares(d_real, d_fake, d_real_logits=None, d_fake_logits=None):
  """Returns the discriminator and generator loss for the least-squares loss.

  Args:
    d_real: prediction for real points, values in [0, 1], shape [batch_size, 1].
    d_fake: prediction for fake points, values in [0, 1], shape [batch_size, 1].
    d_real_logits: only used for the shape check.
    d_fake_logits: only used for the shape check.

  Returns:
    A tuple consisting of the discriminator loss, discriminator's loss on the
    real samples and fake samples, and the generator's loss.
  """
  with tf.name_scope("least_square_loss"):
    check_dimensions(d_real, d_fake, d_real_logits, d_fake_logits)
    d_loss_real = tf.reduce_mean(tf.square(d_real - 1.0))
    d_loss_fake = tf.reduce_mean(tf.square(d_fake))
    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    g_loss = 0.5 * tf.reduce_mean(tf.square(d_fake - 1.0))
    return d_loss, d_loss_real, d_loss_fake, g_loss


@gin.configurable(whitelist=[])
def hinge(d_real_logits, d_fake_logits, d_real=None, d_fake=None):
  """Returns the discriminator and generator loss for the hinge loss.

  Args:
    d_real_logits: logits for real points, shape [batch_size, 1].
    d_fake_logits: logits for fake points, shape [batch_size, 1].
    d_real: only used for the shape check.
    d_fake: only used for the shape check.

  Returns:
    A tuple consisting of the discriminator loss, discriminator's loss on the
    real samples and fake samples, and the generator's loss.
  """
  with tf.name_scope("hinge_loss"):
    check_dimensions(d_real, d_fake, d_real_logits, d_fake_logits)
    d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - d_real_logits))
    d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + d_fake_logits))
    d_loss = d_loss_real + d_loss_fake
    g_loss = - tf.reduce_mean(d_fake_logits)
    return d_loss, d_loss_real, d_loss_fake, g_loss


@gin.configurable("loss", whitelist=["fn"])
def get_losses(fn=non_saturating, **kwargs):
  """Returns the losses for the discriminator and generator.

  Args:
    fn: loss function to call, configurable through gin as `loss.fn`.
    **kwargs: candidate arguments; they are handed to
      `utils.call_with_accepted_args` together with `fn`.

  Returns:
    Whatever `fn` returns.
  """
  return utils.call_with_accepted_args(fn, **kwargs)
"""Tests for GANs with different regularizers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl.testing import parameterized
from compare_gan import datasets
from compare_gan import test_utils
from compare_gan.gans import consts as c
from compare_gan.gans import loss_lib
from compare_gan.gans import penalty_lib
from compare_gan.gans.modular_gan import ModularGAN
import gin
import tensorflow as tf


FLAGS = flags.FLAGS
TEST_ARCHITECTURES = [c.RESNET5_ARCH, c.RESNET_BIGGAN_ARCH, c.RESNET_CIFAR_ARCH]
TEST_LOSSES = [loss_lib.non_saturating, loss_lib.wasserstein,
               loss_lib.least_squares, loss_lib.hinge]
TEST_PENALTIES = [penalty_lib.no_penalty, penalty_lib.dragan_penalty,
                  penalty_lib.wgangp_penalty, penalty_lib.l2_penalty]


class ModularGANConditionalTest(parameterized.TestCase,
                                test_utils.CompareGanTestCase):
  """Runs single training steps of a conditional `ModularGAN` on CIFAR10."""

  def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
                             labeled_dataset):
    """Builds a conditional GAN and trains it for exactly one step.

    Args:
      architecture: Architecture constant stored under "architecture" in the
        GAN parameters.
      loss_fn: Function bound to the Gin parameter "loss.fn".
      penalty_fn: Function bound to the Gin parameter "penalty.fn".
      labeled_dataset: Accepted for the callers' benefit; this helper does not
        read it and always uses the "cifar10" dataset.
    """
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)
    model_dir = self._get_empty_model_dir()
    tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir, tpu_config=tpu_config)
    gan = ModularGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters={
            "architecture": architecture,
            "lambda": 1,
            "z_dim": 120,
        },
        conditional=True,
        model_dir=model_dir)
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)

  @parameterized.parameters(TEST_ARCHITECTURES)
  def testSingleTrainingStepArchitectures(self, architecture):
    """One step per architecture, with hinge loss and no penalty."""
    self._runSingleTrainingStep(
        architecture, loss_lib.hinge, penalty_lib.no_penalty, True)

  @parameterized.parameters(TEST_LOSSES)
  def testSingleTrainingStepLosses(self, loss_fn):
    """One step per loss, on the CIFAR ResNet without a penalty."""
    self._runSingleTrainingStep(
        c.RESNET_CIFAR_ARCH, loss_fn, penalty_lib.no_penalty,
        labeled_dataset=True)

  @parameterized.parameters(TEST_PENALTIES)
  def testSingleTrainingStepPenalties(self, penalty_fn):
    """One step per penalty, on the CIFAR ResNet with hinge loss."""
    self._runSingleTrainingStep(
        c.RESNET_CIFAR_ARCH, loss_lib.hinge, penalty_fn, labeled_dataset=True)

  def testUnlabledDatasetRaisesError(self):
    """Constructing a conditional GAN on "celeb_a" must raise ValueError."""
    with gin.unlock_config():
      gin.bind_parameter("loss.fn", loss_lib.hinge)
    # Use dataset without labels.
    unlabeled = datasets.get_dataset("celeb_a")
    model_dir = self._get_empty_model_dir()
    gan_parameters = {
        "architecture": c.RESNET_CIFAR_ARCH,
        "lambda": 1,
        "z_dim": 120,
    }
    with self.assertRaises(ValueError):
      # The constructor itself is expected to raise; nothing is kept around.
      ModularGAN(
          dataset=unlabeled,
          parameters=gan_parameters,
          conditional=True,
          model_dir=model_dir)


if __name__ == "__main__":
  tf.test.main()

# --------------------------------------------------------------------------------
# /compare_gan/gans/modular_gan_tpu_test.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests TPU specific parts of ModularGAN."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl.testing import parameterized
from compare_gan import datasets
from compare_gan import test_utils
from compare_gan.gans import consts as c
from compare_gan.gans.modular_gan import ModularGAN
import tensorflow as tf

FLAGS = flags.FLAGS


class ModularGanTpuTest(parameterized.TestCase, test_utils.CompareGanTestCase):
  """Checks the batch shapes handed to the dummy generator/discriminator."""

  def setUp(self):
    super(ModularGanTpuTest, self).setUp()
    self.model_dir = self._get_empty_model_dir()
    self.run_config = tf.contrib.tpu.RunConfig(
        model_dir=self.model_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))

  def _trainSingleStep(self, disc_iters, use_tpu=True, **gan_kwargs):
    """Trains a dummy-architecture `ModularGAN` for one step.

    Shared setup for all tests in this class: CIFAR10, batch size 16 and
    z_dim 128.

    Args:
      disc_iters: Value stored under "disc_iters" in the GAN parameters.
      use_tpu: Forwarded to `ModularGAN.as_estimator`.
      **gan_kwargs: Extra keyword arguments for the `ModularGAN` constructor.

    Returns:
      The trained `ModularGAN`, so callers can inspect the recorded calls.
    """
    parameters = {
        "architecture": c.DUMMY_ARCH,
        "lambda": 1,
        "z_dim": 128,
        "disc_iters": disc_iters,
    }
    batch_size = 16
    dataset = datasets.get_dataset("cifar10")
    gan = ModularGAN(
        dataset=dataset,
        parameters=parameters,
        model_dir=self.model_dir,
        **gan_kwargs)
    estimator = gan.as_estimator(self.run_config, batch_size=batch_size,
                                 use_tpu=use_tpu)
    estimator.train(gan.input_fn, steps=1)
    return gan

  @parameterized.parameters([1, 2, 5])
  def testBatchSize(self, disc_iters, use_tpu=True):
    # `use_tpu` used to be ignored (the estimator was always built with
    # use_tpu=True); it is now forwarded. The default keeps the old behavior.
    gan = self._trainSingleStep(disc_iters, use_tpu=use_tpu)

    gen_args = gan.generator.call_arg_list
    disc_args = gan.discriminator.call_arg_list
    self.assertLen(gen_args, disc_iters + 1)  # D steps, G step.
    self.assertLen(disc_args, disc_iters + 1)  # D steps, G step.

    for args in gen_args:
      self.assertAllEqual(args["z"].shape.as_list(), [8, 128])
    for args in disc_args:
      self.assertAllEqual(args["x"].shape.as_list(), [16, 32, 32, 3])

  @parameterized.parameters([1, 2, 5])
  def testBatchSizeSplitDiscCalls(self, disc_iters):
    gan = self._trainSingleStep(disc_iters, deprecated_split_disc_calls=True)

    gen_args = gan.generator.call_arg_list
    disc_args = gan.discriminator.call_arg_list
    self.assertLen(gen_args, disc_iters + 1)  # D steps, G step.
    # Each D and G step calls discriminator twice: for real and fake images.
    self.assertLen(disc_args, 2 * (disc_iters + 1))

    for args in gen_args:
      self.assertAllEqual(args["z"].shape.as_list(), [8, 128])
    for args in disc_args:
      self.assertAllEqual(args["x"].shape.as_list(), [8, 32, 32, 3])

  @parameterized.parameters([1, 2, 5])
  def testBatchSizeExperimentalJointGenForDisc(self, disc_iters):
    gan = self._trainSingleStep(
        disc_iters, experimental_joint_gen_for_disc=True)

    gen_args = gan.generator.call_arg_list
    disc_args = gan.discriminator.call_arg_list
    # Two generator calls are expected: the first carries the z batches of all
    # discriminator iterations at once, the second a single z batch.
    self.assertLen(gen_args, 2)
    self.assertLen(disc_args, disc_iters + 1)

    self.assertAllEqual(gen_args[0]["z"].shape.as_list(), [8 * disc_iters, 128])
    self.assertAllEqual(gen_args[1]["z"].shape.as_list(), [8, 128])
    for args in disc_args:
      self.assertAllEqual(args["x"].shape.as_list(), [16, 32, 32, 3])


if __name__ == "__main__":
  tf.test.main()

# --------------------------------------------------------------------------------
# /compare_gan/gans/ops.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Customized TensorFlow operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from compare_gan.tpu import tpu_random

# Module-level aliases: code that imports this module can call
# `ops.random_uniform(...)` / `ops.random_normal(...)` and gets the
# `tpu_random` implementations without importing `tpu_random` itself.
random_uniform = tpu_random.uniform
random_normal = tpu_random.normal

# --------------------------------------------------------------------------------
# /compare_gan/gans/penalty_lib.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of popular GAN penalties."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from compare_gan import utils
from compare_gan.gans import ops
import gin
import tensorflow as tf


@gin.configurable
def no_penalty():
  """Returns a scalar zero tensor, i.e. no penalty."""
  return tf.constant(0.0)


def _unit_gradient_norm_penalty(logits, inputs):
  """Penalizes per-sample gradient norms of `logits` that differ from 1.

  Args:
    logits: Discriminator logits computed from `inputs`.
    inputs: 4-D tensor the gradients are taken with respect to; the squared
      gradients are summed over axes 1, 2 and 3.

  Returns:
    Scalar tensor: mean over the batch of `(slope - 1)^2`.
  """
  gradients = tf.gradients(logits, [inputs])[0]
  # The small constant keeps the sqrt (and its gradient) finite at zero.
  slopes = tf.sqrt(0.0001 + tf.reduce_sum(
      tf.square(gradients), axis=[1, 2, 3]))
  return tf.reduce_mean(tf.square(slopes - 1.0))


@gin.configurable(whitelist=[])
def dragan_penalty(discriminator, x, y, is_training):
  """Returns the DRAGAN gradient penalty.

  The penalty is evaluated at `x` perturbed by noise that is scaled with the
  standard deviation of `x` (computed over all axes) and clipped to [0, 1].

  Args:
    discriminator: Instance of `AbstractDiscriminator`.
    x: Samples from the true distribution, shape [bs, h, w, channels].
    y: Encoded class embedding for the samples. None for unsupervised models.
    is_training: boolean, are we in train or eval model.

  Returns:
    A tensor with the computed penalty.
  """
  with tf.name_scope("dragan_penalty"):
    _, var = tf.nn.moments(x, axes=list(range(len(x.get_shape()))))
    std = tf.sqrt(var)
    x_noisy = x + std * (ops.random_uniform(x.shape) - 0.5)
    x_noisy = tf.clip_by_value(x_noisy, 0.0, 1.0)
    # The second output of the discriminator is used as the logits.
    logits = discriminator(x_noisy, y=y, is_training=is_training, reuse=True)[1]
    return _unit_gradient_norm_penalty(logits, x_noisy)


@gin.configurable(whitelist=[])
def wgangp_penalty(discriminator, x, x_fake, y, is_training):
  """Returns the WGAN gradient penalty.

  The penalty is evaluated at points interpolated between `x` and `x_fake`
  with one interpolation coefficient per sample.

  Args:
    discriminator: Instance of `AbstractDiscriminator`.
    x: samples from the true distribution, shape [bs, h, w, channels].
    x_fake: samples from the fake distribution, shape [bs, h, w, channels].
    y: Encoded class embedding for the samples. None for unsupervised models.
    is_training: boolean, are we in train or eval model.

  Returns:
    A tensor with the computed penalty.
  """
  with tf.name_scope("wgangp_penalty"):
    # Uses the static batch dimension of `x`.
    alpha = ops.random_uniform(shape=[x.shape[0].value, 1, 1, 1], name="alpha")
    interpolates = x + alpha * (x_fake - x)
    # The second output of the discriminator is used as the logits.
    logits = discriminator(
        interpolates, y=y, is_training=is_training, reuse=True)[1]
    return _unit_gradient_norm_penalty(logits, interpolates)


@gin.configurable(whitelist=[])
def l2_penalty(discriminator):
  """Returns the L2 penalty for each matrix/vector excluding biases.

  Assumes a specific tensor naming followed throughout the compare_gan library.
  We penalize all fully connected, conv2d, and deconv2d layers.

  Args:
    discriminator: Instance of `AbstractDiscriminator`.

  Returns:
    A tensor with the computed penalty: the mean of `tf.nn.l2_loss` over all
    trainable variables whose name ends in "/kernel:0", or 0.0 if there are
    none.
  """
  with tf.name_scope("l2_penalty"):
    d_weights = [v for v in discriminator.trainable_variables
                 if v.name.endswith("/kernel:0")]
    if not d_weights:
      # The mean of an empty list would be NaN and poison the total loss.
      return tf.constant(0.0, name="l2_penalty")
    return tf.reduce_mean(
        [tf.nn.l2_loss(i) for i in d_weights], name="l2_penalty")


@gin.configurable("penalty", whitelist=["fn"])
def get_penalty_loss(fn=no_penalty, **kwargs):
  """Returns the penalty loss.

  Args:
    fn: Penalty function to evaluate (Gin-configurable as "penalty.fn").
    **kwargs: Candidate arguments; forwarded via
      `utils.call_with_accepted_args`.
  """
  return utils.call_with_accepted_args(fn, **kwargs)

# --------------------------------------------------------------------------------
# /compare_gan/gans/s3gan_test.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for S3GANs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl.testing import parameterized
from compare_gan import datasets
from compare_gan import test_utils
from compare_gan.gans import consts as c
from compare_gan.gans import loss_lib
from compare_gan.gans.s3gan import S3GAN
import gin
import tensorflow as tf

FLAGS = flags.FLAGS


class S3GANTest(parameterized.TestCase, test_utils.CompareGanTestCase):
  """Runs one training step of `S3GAN` for several supervision settings."""

  @parameterized.parameters(
      {"use_predictor": False, "project_y": False},  # unsupervised.
      {"use_predictor": False},  # fully supervised.
      {"use_predictor": True},  # only oracle.
      {"use_predictor": True, "self_supervision": "rotation"},  # oracle + SS.
      {"use_predictor": False, "self_supervision": "rotation"},  # only SS.
  )
  def testSingleTrainingStepArchitectures(
      self, use_predictor, project_y=True, self_supervision="none"):
    """Binds the S3GAN Gin parameters and trains for a single step."""
    gin_bindings = (
        ("ModularGAN.conditional", True),
        ("loss.fn", loss_lib.hinge),
        ("S3GAN.use_predictor", use_predictor),
        ("S3GAN.project_y", project_y),
        ("S3GAN.self_supervision", self_supervision),
    )
    with gin.unlock_config():
      for binding_name, binding_value in gin_bindings:
        gin.bind_parameter(binding_name, binding_value)
    # Fake ImageNet dataset by overriding the properties.
    dataset = datasets.get_dataset("imagenet_128")
    model_dir = self._get_empty_model_dir()
    tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir, tpu_config=tpu_config)
    gan = S3GAN(
        dataset=dataset,
        parameters={
            "architecture": c.RESNET_BIGGAN_ARCH,
            "lambda": 1,
            "z_dim": 120,
        },
        model_dir=model_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
        rotated_batch_fraction=2)
    estimator = gan.as_estimator(run_config, batch_size=8, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)


if __name__ == "__main__":
  tf.test.main()

# --------------------------------------------------------------------------------
# /compare_gan/gans/ssgan_test.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for SSGANs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl.testing import parameterized
from compare_gan import datasets
from compare_gan import test_utils
from compare_gan.gans import consts as c
from compare_gan.gans import loss_lib
from compare_gan.gans import penalty_lib
from compare_gan.gans.ssgan import SSGAN
import gin
import tensorflow as tf

FLAGS = flags.FLAGS
TEST_ARCHITECTURES = [c.RESNET_CIFAR_ARCH, c.SNDCGAN_ARCH, c.RESNET5_ARCH]
TEST_LOSSES = [loss_lib.non_saturating, loss_lib.hinge]
TEST_PENALTIES = [penalty_lib.no_penalty, penalty_lib.wgangp_penalty]


class SSGANTest(parameterized.TestCase, test_utils.CompareGanTestCase):
  """Runs single training steps of `SSGAN` on CIFAR10."""

  def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
    """Builds an SSGAN and trains it for exactly one step.

    Args:
      architecture: Architecture constant stored under "architecture" in the
        GAN parameters.
      loss_fn: Function bound to the Gin parameter "loss.fn".
      penalty_fn: Function bound to the Gin parameter "penalty.fn".
    """
    with gin.unlock_config():
      gin.bind_parameter("penalty.fn", penalty_fn)
      gin.bind_parameter("loss.fn", loss_fn)
    model_dir = self._get_empty_model_dir()
    tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=1)
    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir, tpu_config=tpu_config)
    gan = SSGAN(
        dataset=datasets.get_dataset("cifar10"),
        parameters={
            "architecture": architecture,
            "lambda": 1,
            "z_dim": 128,
        },
        model_dir=model_dir,
        g_optimizer_fn=tf.train.AdamOptimizer,
        g_lr=0.0002,
        rotated_batch_size=4)
    estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)

  @parameterized.parameters(TEST_ARCHITECTURES)
  def testSingleTrainingStepArchitectures(self, architecture):
    """One step per architecture, with hinge loss and no penalty."""
    self._runSingleTrainingStep(
        architecture, loss_lib.hinge, penalty_lib.no_penalty)

  @parameterized.parameters(TEST_LOSSES)
  def testSingleTrainingStepLosses(self, loss_fn):
    """One step per loss, on the CIFAR ResNet without a penalty."""
    self._runSingleTrainingStep(
        c.RESNET_CIFAR_ARCH, loss_fn, penalty_lib.no_penalty)

  @parameterized.parameters(TEST_PENALTIES)
  def testSingleTrainingStepPenalties(self, penalty_fn):
    """One step per penalty, on the CIFAR ResNet with hinge loss."""
    self._runSingleTrainingStep(
        c.RESNET_CIFAR_ARCH, loss_lib.hinge, penalty_fn)


if __name__ == "__main__":
  tf.test.main()

# --------------------------------------------------------------------------------
# /compare_gan/gans/utils.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | 16 | """Utilities library.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import scipy.misc 24 | import tensorflow as tf 25 | 26 | 27 | def check_folder(log_dir): 28 | if not tf.gfile.IsDirectory(log_dir): 29 | tf.gfile.MakeDirs(log_dir) 30 | return log_dir 31 | 32 | 33 | def save_images(images, image_path): 34 | with tf.gfile.Open(image_path, "wb") as f: 35 | scipy.misc.imsave(f, images * 255.0) 36 | 37 | 38 | def rotate_images(images, rot90_scalars=(0, 1, 2, 3)): 39 | """Return the input image and its 90, 180, and 270 degree rotations.""" 40 | images_rotated = [ 41 | images, # 0 degree 42 | tf.image.flip_up_down(tf.image.transpose_image(images)), # 90 degrees 43 | tf.image.flip_left_right(tf.image.flip_up_down(images)), # 180 degrees 44 | tf.image.transpose_image(tf.image.flip_up_down(images)) # 270 degrees 45 | ] 46 | 47 | results = tf.stack([images_rotated[i] for i in rot90_scalars]) 48 | results = tf.reshape(results, 49 | [-1] + images.get_shape().as_list()[1:]) 50 | return results 51 | 52 | 53 | def gaussian(batch_size, n_dim, mean=0., var=1.): 54 | return np.random.normal(mean, var, (batch_size, n_dim)).astype(np.float32) 55 | -------------------------------------------------------------------------------- /compare_gan/hooks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Contains SessionRunHooks for training.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import time 23 | 24 | from absl import logging 25 | import tensorflow as tf 26 | 27 | 28 | class AsyncCheckpointSaverHook(tf.contrib.tpu.AsyncCheckpointSaverHook): 29 | """Saves checkpoints every N steps in a asynchronous thread. 30 | 31 | This is the same as tf.contrib.tpu.AsyncCheckpointSaverHook but guarantees 32 | that there will be a checkpoint every `save_steps` steps. This helps to have 33 | eval results at fixed step counts, even when training is paused between 34 | regular checkpoint intervals. 35 | """ 36 | 37 | def after_create_session(self, session, coord): 38 | super(AsyncCheckpointSaverHook, self).after_create_session(session, coord) 39 | # Interruptions to the training job can cause non-regular checkpoints 40 | # (between every_steps). Modify last triggered step to point to the last 41 | # regular checkpoint step to make sure we trigger on the next regular 42 | # checkpoint step. 43 | step = session.run(self._global_step_tensor) 44 | every_steps = self._timer._every_steps # pylint: disable=protected-access 45 | last_triggered_step = step - step % every_steps 46 | self._timer.update_last_triggered_step(last_triggered_step) 47 | 48 | 49 | class EveryNSteps(tf.train.SessionRunHook): 50 | """"Base class for hooks that execute callbacks every N steps. 
51 | 52 | class MyHook(EveryNSteps): 53 | def __init__(self, every_n_steps): 54 | super(MyHook, self).__init__(every_n_steps) 55 | 56 | def every_n_steps_after_run(self, step, run_context, run_values): 57 | # Your Implementation 58 | 59 | If you do overwrite begin(), end(), before_run() or after_run() make sure to 60 | call super() at the beginning. 61 | """ 62 | 63 | def __init__(self, every_n_steps): 64 | """Initializes an `EveryNSteps` hook. 65 | 66 | Args: 67 | every_n_steps: `int`, the number of steps to allow between callbacks. 68 | """ 69 | self._timer = tf.train.SecondOrStepTimer(every_steps=every_n_steps) 70 | self._global_step_tensor = None 71 | 72 | def begin(self): 73 | self._global_step_tensor = tf.train.get_global_step() 74 | if self._global_step_tensor is None: 75 | raise RuntimeError("Global step must be created to use EveryNSteps.") 76 | 77 | def before_run(self, run_context): # pylint: disable=unused-argument 78 | """Overrides `SessionRunHook.before_run`. 79 | 80 | Args: 81 | run_context: A `SessionRunContext` object. 82 | 83 | Returns: 84 | None or a `SessionRunArgs` object. 85 | """ 86 | return tf.train.SessionRunArgs({"global_step": self._global_step_tensor}) 87 | 88 | def after_run(self, run_context, run_values): 89 | """Overrides `SessionRunHook.after_run`. 90 | 91 | Args: 92 | run_context: A `SessionRunContext` object. 93 | run_values: A SessionRunValues object. 94 | """ 95 | step = run_values.results["global_step"] 96 | if self._timer.should_trigger_for_step(step): 97 | self.every_n_steps_after_run(step, run_context, run_values) 98 | self._timer.update_last_triggered_step(step) 99 | 100 | def end(self, sess): 101 | step = sess.run(self._global_step_tensor) 102 | self.every_n_steps_after_run(step, None, None) 103 | 104 | def every_n_steps_after_run(self, step, run_context, run_values): 105 | """Callback after every n"th call to run(). 106 | 107 | Args: 108 | step: Current global_step value. 109 | run_context: A `SessionRunContext` object. 
110 | run_values: A SessionRunValues object. 111 | """ 112 | raise NotImplementedError("Subclasses of EveryNSteps should implement " 113 | "every_n_steps_after_run().") 114 | 115 | 116 | class ReportProgressHook(EveryNSteps): 117 | """SessionRunHook that reports progress to a `TaskManager` instance.""" 118 | 119 | def __init__(self, task_manager, max_steps, every_n_steps=100): 120 | """Create a new instance of ReportProgressHook. 121 | 122 | Args: 123 | task_manager: A `TaskManager` instance that implements report_progress(). 124 | max_steps: Maximum number of training steps. 125 | every_n_steps: How frequently the hook should report progress. 126 | """ 127 | super(ReportProgressHook, self).__init__(every_n_steps=every_n_steps) 128 | logging.info("Creating ReportProgressHook to report progress every %d " 129 | "steps.", every_n_steps) 130 | self.max_steps = max_steps 131 | self.task_manager = task_manager 132 | self.start_time = None 133 | self.start_step = None 134 | 135 | def every_n_steps_after_run(self, step, run_context, run_values): 136 | if self.start_time is None: 137 | # First call. 138 | self.start_time = time.time() 139 | self.start_step = step 140 | return 141 | 142 | time_elapsed = time.time() - self.start_time 143 | steps_per_sec = float(step - self.start_step) / time_elapsed 144 | eta_seconds = (self.max_steps - step) / (steps_per_sec + 0.0000001) 145 | message = "{:.1f}% @{:d}, {:.1f} steps/s, ETA: {:.0f} min".format( 146 | 100 * step / self.max_steps, step, steps_per_sec, eta_seconds / 60) 147 | logging.info("Reporting progress: %s", message) 148 | self.task_manager.report_progress(message) 149 | -------------------------------------------------------------------------------- /compare_gan/main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Binary to train and evaluate one GAN configuration.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | # pylint: disable=unused-import 25 | 26 | from absl import app 27 | from absl import flags 28 | from absl import logging 29 | 30 | from compare_gan import datasets 31 | from compare_gan import runner_lib 32 | # Import GAN types so that they can be used in Gin configs without module names. 33 | from compare_gan.gans.modular_gan import ModularGAN 34 | from compare_gan.gans.s3gan import S3GAN 35 | from compare_gan.gans.ssgan import SSGAN 36 | 37 | # Required import to configure core TF classes and functions. 38 | import gin 39 | import gin.tf.external_configurables 40 | import tensorflow as tf 41 | 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | flags.DEFINE_string("model_dir", None, "Where to store files.") 46 | flags.DEFINE_string( 47 | "schedule", "train", 48 | "Schedule to run. 
def _get_cluster():
  """Returns the TPU cluster resolver, or None when not running on TPU.

  The TPU is located through the TPU_NAME (required) and TPU_ZONE (optional)
  environment variables.

  Raises:
    ValueError: If --use_tpu is set but TPU_NAME is not in the environment.
  """
  if not FLAGS.use_tpu:  # pylint: disable=unreachable
    return None
  if "TPU_NAME" not in os.environ:
    raise ValueError("Could not find a TPU. Set TPU_NAME.")
  return tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu=os.environ["TPU_NAME"],
      zone=os.environ.get("TPU_ZONE", None))


@gin.configurable("run_config")
def _get_run_config(tf_random_seed=None,
                    single_core=False,
                    iterations_per_loop=1000,
                    save_checkpoints_steps=5000,
                    keep_checkpoint_max=1000):
  """Return `RunConfig` for TPUs.

  Args:
    tf_random_seed: Forwarded to `RunConfig`.
    single_core: If True, use a single shard; otherwise all cores are used.
    iterations_per_loop: Forwarded to `TPUConfig`.
    save_checkpoints_steps: Forwarded to `RunConfig`.
    keep_checkpoint_max: Forwarded to `RunConfig`.

  Returns:
    A `tf.contrib.tpu.RunConfig` that stores its files in --model_dir.
  """
  tpu_config = tf.contrib.tpu.TPUConfig(
      num_shards=1 if single_core else None,  # None = all cores.
      iterations_per_loop=iterations_per_loop)
  return tf.contrib.tpu.RunConfig(
      model_dir=FLAGS.model_dir,
      tf_random_seed=tf_random_seed,
      save_checkpoints_steps=save_checkpoints_steps,
      keep_checkpoint_max=keep_checkpoint_max,
      cluster=_get_cluster(),
      tpu_config=tpu_config)


def _get_task_manager():
  """Returns a TaskManager for this experiment."""
  score_file = os.path.join(FLAGS.model_dir, FLAGS.score_filename)
  return runner_lib.TaskManagerWithCsvResults(
      model_dir=FLAGS.model_dir, score_file=score_file)


def main(unused_argv):
  """Parses the Gin configuration and runs the requested schedule.

  When --use_tpu is not given, TPU usage is derived from whether the TPU_NAME
  environment variable is set to a non-empty value.

  Args:
    unused_argv: Positional command line arguments (ignored).
  """
  logging.info("Gin config: %s\nGin bindings: %s",
               FLAGS.gin_config, FLAGS.gin_bindings)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

  if FLAGS.use_tpu is None:
    FLAGS.use_tpu = bool(os.environ.get("TPU_NAME", ""))
    if FLAGS.use_tpu:
      logging.info("Found TPU %s.", os.environ["TPU_NAME"])
  run_config = _get_run_config()
  task_manager = _get_task_manager()
  options = runner_lib.get_options_dict()
  runner_lib.run_with_schedule(
      schedule=FLAGS.schedule,
      run_config=run_config,
      task_manager=task_manager,
      options=options,
      use_tpu=FLAGS.use_tpu,
      num_eval_averaging_runs=FLAGS.num_eval_averaging_runs,
      eval_every_steps=FLAGS.eval_every_steps)
  logging.info("I'm done with my work, ciao!")


if __name__ == "__main__":
  flags.mark_flag_as_required("model_dir")
  app.run(main)
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # coding=utf-8 17 | -------------------------------------------------------------------------------- /compare_gan/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Discriminator accuracy. 17 | 18 | Computes the discriminator's accuracy on (a subset of) the training dataset, 19 | test dataset, and a generated data set. The score is averaged over 20 | multiple generated data sets and subsets of the training data. 
class AccuracyTask(eval_task.EvalTask):
  """Evaluation task that reports discriminator accuracy and loss."""

  def metric_list(self):
    """Returns the names of the metrics produced by this task."""
    metric_names = (
        "train_accuracy",
        "test_accuracy",
        "fake_accuracy",
        "train_d_loss",
        "test_d_loss",
    )
    return frozenset(metric_names)

  def run_in_session(self, options, sess, gan, real_images):
    """Computes the metrics in `sess`, using `real_images` as the test set."""
    del options  # Unused.
    return compute_accuracy_loss(sess, gan, real_images)


def compute_accuracy_loss(sess,
                          gan,
                          test_images,
                          max_train_examples=50000,
                          num_repeat=5):
  """Compute discriminator's accuracy and loss on a given dataset.

  Args:
    sess: Tf.Session object.
    gan: Any AbstractGAN instance.
    test_images: numpy array with test images.
    max_train_examples: How many "train" examples to get from the dataset.
                        In each round, some of them will be randomly selected
                        to evaluate train set accuracy.
    num_repeat: How many times to repeat the computation.
                The mean of all the results is reported.
  Returns:
    Dict[Text, float] with all the computed scores.

  Raises:
    ValueError: If the number of test_images is greater than the number of
                training images returned by the dataset.
  """
  logging.info("Evaluating training and test accuracy...")
  train_images = eval_utils.get_real_images(
      dataset=datasets.get_dataset(),
      num_examples=max_train_examples,
      split="train",
      failure_on_insufficient_examples=False)
  num_train = train_images.shape[0]
  num_test = test_images.shape[0]
  if num_train < num_test:
    raise ValueError("num_train %d must be larger than num_test %d." %
                     (num_train, num_test))

  batch_size = gan.batch_size
  # Only full batches are evaluated; a trailing partial batch is dropped.
  num_batches = int(np.floor(num_test / batch_size))
  num_dropped = num_test - num_batches * batch_size
  if num_dropped > 0:
    logging.error("Ignoring the last batch with %d samples / %d epoch size.",
                  num_dropped, batch_size)

  per_round = {
      "train_accuracy": [],
      "test_accuracy": [],
      "fake_accuracy": [],
      "train_d_loss": [],
      "test_d_loss": [],
  }

  for _ in range(num_repeat):
    # Draw as many training examples as there are test examples.
    sampled = np.random.choice(num_train, num_test)
    train_subset = [train_images[j] for j in sampled]
    train_outputs, test_outputs, fake_outputs = [], [], []
    train_losses, test_losses = [], []

    for batch in range(num_batches):
      z_sample = gan.z_generator(batch_size, gan.z_dim)
      window = slice(batch * batch_size, (batch + 1) * batch_size)

      test_out, test_loss, fake_images = sess.run(
          [gan.discriminator_output, gan.d_loss, gan.fake_images],
          feed_dict={gan.inputs: test_images[window], gan.z: z_sample})
      train_out, train_loss = sess.run(
          [gan.discriminator_output, gan.d_loss],
          feed_dict={gan.inputs: train_subset[window], gan.z: z_sample})
      # The images generated above are scored as if they were real inputs.
      fake_out = sess.run(
          gan.discriminator_output, feed_dict={gan.inputs: fake_images})

      train_outputs.append(train_out[0])
      test_outputs.append(test_out[0])
      fake_outputs.append(fake_out[0])
      train_losses.append(train_loss)
      test_losses.append(test_loss)

    # Real images count as correct at >= 0.5, generated ones below 0.5.
    train_correct = np.array([out >= 0.5 for out in train_outputs])
    test_correct = np.array([out >= 0.5 for out in test_outputs])
    fake_correct = np.array([out < 0.5 for out in fake_outputs])

    per_round["train_accuracy"].append(train_correct.mean())
    per_round["test_accuracy"].append(test_correct.mean())
    per_round["fake_accuracy"].append(fake_correct.mean())
    per_round["train_d_loss"].append(np.mean(train_losses))
    per_round["test_d_loss"].append(np.mean(test_losses))

  return {name: np.mean(values) for name, values in per_round.items()}
def metric_list(self):
  """List of metrics that this class generates.

  These are the only keys that RunXX methods can return in
  their output maps.

  A string `_LABEL` is treated as one metric name. Passing it straight to
  frozenset() would split it into its characters. Any other iterable of names
  is converted as is; an unset (None) `_LABEL` raises TypeError as before.

  Returns:
    frozenset of strings, which are the names of the metrics that task
    computes.
  """
  label = self._LABEL
  # (str, type(u"")) covers byte and unicode strings on Python 2 and 3.
  if isinstance(label, (str, type(u""))):
    return frozenset([label])
  return frozenset(label)
class FIDScoreTask(eval_task.EvalTask):
  """Evaluation task that reports the Frechet Inception Distance."""

  _LABEL = "fid_score"

  def run_after_session(self, fake_dset, real_dset):
    """Computes the FID between the activations of the two data samples.

    Args:
      fake_dset: Sample whose `activations` belong to generated images.
      real_dset: Sample whose `activations` belong to real images.

    Returns:
      Dict mapping "fid_score" to the computed distance.
    """
    logging.info("Calculating FID.")
    # A fresh graph keeps the evaluation ops separate from any existing graph.
    with tf.Graph().as_default():
      generated = tf.convert_to_tensor(fake_dset.activations)
      real = tf.convert_to_tensor(real_dset.activations)
      fid_tensor = tfgan.eval.frechet_classifier_distance_from_activations(
          real_activations=real, generated_activations=generated)
      with self._create_session() as sess:
        fid_value = sess.run(fid_tensor)
    logging.info("Frechet Inception Distance: %.3f.", fid_value)
    return {self._LABEL: fid_value}


def compute_fid_from_activations(fake_activations, real_activations):
  """Returns the FID based on activations.

  Args:
    fake_activations: NumPy array with fake activations.
    real_activations: NumPy array with real activations.
  Returns:
    A float, the Frechet Inception Distance.
  """
  logging.info("Computing FID score.")
  assert fake_activations.shape == real_activations.shape
  with tf.Session(graph=tf.Graph()) as sess:
    generated = tf.convert_to_tensor(fake_activations)
    real = tf.convert_to_tensor(real_activations)
    fid_tensor = tfgan.eval.frechet_classifier_distance_from_activations(
        real_activations=real, generated_activations=generated)
    return sess.run(fid_tensor)
class FIDScoreTest(tf.test.TestCase):
  """Checks the FID against a value derived by hand."""

  def test_fid_computation(self):
    # First 50 rows are [2, 1], the remaining 50 rows are [1, 1].
    #   mean = [1.5, 1], Cov = [[0.2525, 0], [0, 0]]
    real_data = np.ones((100, 2))
    real_data[:50, 0] = 2
    # First 50 rows are [9, 9], the remaining 50 rows are [2, 9].
    #   mean = [5.5, 9], Cov = [[12.37, 0], [0, 0]]
    gen_data = np.ones((100, 2)) * 9
    gen_data[50:, 0] = 2

    fid = fid_score_lib.compute_fid_from_activations(real_data, gen_data)

    self.assertNear(fid, 89.091, 1e-4)


if __name__ == "__main__":
  tf.test.main()
def compute_fractal_dimension(fake_images,
                              num_fd_seeds=100,
                              n_bins=1000,
                              scale=0.1):
  """Compute Fractal Dimension of fake_images.

  The estimate is the slope of log N(r) against log r, where N(r) is the
  number of (point, seed) pairs closer than r, fitted around the middle of
  the log N axis.

  Args:
    fake_images: an np array of datapoints, the dimensionality and scaling of
      images can be arbitrary
    num_fd_seeds: number of random centers from which fractal dimension
      computation is performed
    n_bins: number of bins to split the range of distance values into
    scale: the scale of the y interval in the log-log plot for which we apply a
      linear regression fit

  Returns:
    fractal dimension of the dataset.
  """
  assert len(fake_images.shape) >= 2
  assert fake_images.shape[0] >= num_fd_seeds

  num_images = fake_images.shape[0]
  # In order to apply scipy function we need to flatten the number of dimensions
  # to 2
  fake_images = np.reshape(fake_images, (num_images, -1))
  fake_images_subset = fake_images[np.random.randint(
      num_images, size=num_fd_seeds)]

  distances = scipy.spatial.distance.cdist(fake_images,
                                           fake_images_subset).flatten()
  min_distance = np.min(distances[np.nonzero(distances)])
  max_distance = np.max(distances)
  # Radii are spaced geometrically between the smallest non-zero and the
  # largest distance; the first one is dropped as in the original table.
  buckets = min_distance * (
      (max_distance / min_distance)**np.linspace(0, 1, n_bins))
  radii = buckets[1:]
  # N(r): number of distances strictly below each radius. Searching the sorted
  # distances avoids materializing a len(distances) x (n_bins - 1) matrix.
  counts = np.searchsorted(np.sort(distances), radii, side="left")

  # We compute the slope of the log-log plot at the middle y value
  # which is stored in y_val; the linear regression fit is computed on
  # the part of the plot that corresponds to an interval around y_val
  # whose size is 2*scale*(total width of the y axis)
  max_y = np.log(num_images * num_fd_seeds)
  min_y = np.log(num_fd_seeds)
  x = np.log(radii)
  y = np.log(counts)
  y_width = max_y - min_y
  y_val = min_y + 0.5 * y_width

  start = np.argmax(y > y_val - scale * y_width)
  end = np.argmax(y > y_val + scale * y_width)

  # Fit y = slope * x + intercept; rcond=None selects NumPy's current default
  # and silences the FutureWarning emitted when rcond is omitted.
  slope = np.linalg.lstsq(
      a=np.vstack([x[start:end], np.ones(end - start)]).transpose(),
      b=y[start:end].reshape(end - start, 1),
      rcond=None)[0][0][0]
  return slope
class FractalDimensionTest(tf.test.TestCase):
  """Checks the estimate on point sets of known dimension."""

  def test_straight_line(self):
    """The fractal dimension of a 1D line must lie near 1.0."""
    self.assertAllClose(
        fractal_dimension_lib.compute_fractal_dimension(
            np.random.uniform(size=(10000, 1))), 1.0, atol=0.05)

  def test_square(self):
    """The fractal dimension of a 2D square must lie near 2.0."""
    self.assertAllClose(
        fractal_dimension_lib.compute_fractal_dimension(
            np.random.uniform(size=(10000, 2))), 2.0, atol=0.1)


# Allows running this file directly, like the other metric tests.
if __name__ == "__main__":
  tf.test.main()
class InceptionScoreTask(eval_task.EvalTask):
  """Task that computes inception score for the generated images."""

  _LABEL = "inception_score"

  def run_after_session(self, fake_dset, real_dest):
    """Computes the inception score from the logits of the fake sample.

    Args:
      fake_dset: Sample whose `logits` belong to generated images.
      real_dest: Unused.

    Returns:
      Dict mapping "inception_score" to the computed score.
    """
    del real_dest  # Unused.
    logging.info("Computing inception score.")
    # A fresh graph keeps the evaluation ops separate from any existing graph.
    with tf.Graph().as_default():
      logits = tf.convert_to_tensor(fake_dset.logits)
      score_tensor = tfgan.eval.classifier_score_from_logits(logits)
      with self._create_session() as sess:
        score = sess.run(score_tensor)
    logging.info("Inception score: %.3f", score)
    return {self._LABEL: score}
class GeneratorConditionNumberTask(eval_task.EvalTask):
  """Computes the generator condition number.

  Computes the condition number for metric Tensor of the generator Jacobian.
  This condition number is computed locally for each z sample in a minibatch.
  Returns the mean log condition number and standard deviation across the
  minibatch.

  Follows the methods in https://arxiv.org/abs/1802.08768.
  """

  _CONDITION_NUMBER_COUNT = "log_condition_number_count"
  _CONDITION_NUMBER_MEAN = "log_condition_number_mean"
  _CONDITION_NUMBER_STD = "log_condition_number_std"

  def metric_list(self):
    """Returns the names of the three reported statistics."""
    return frozenset((
        self._CONDITION_NUMBER_COUNT,
        self._CONDITION_NUMBER_MEAN,
        self._CONDITION_NUMBER_STD,
    ))

  def run_in_session(self, options, sess, gan, real_images):
    """Reports count, mean and standard deviation for one minibatch."""
    del options, real_images  # Unused.
    log_condition_numbers = compute_generator_condition_number(sess, gan)
    return {
        self._CONDITION_NUMBER_COUNT: len(log_condition_numbers),
        self._CONDITION_NUMBER_MEAN: np.mean(log_condition_numbers),
        self._CONDITION_NUMBER_STD: np.std(log_condition_numbers),
    }


def compute_generator_condition_number(sess, gan):
  """Computes the generator condition number.

  Computes the Jacobian of the generator in session, then postprocesses to get
  the condition number.

  Args:
    sess: tf.Session object.
    gan: AbstractGAN object, that is already present in the current tf.Graph.

  Returns:
    The per-sample "log_condition_number" statistics that analyze_jacobian()
    derives from the Jacobian of one minibatch of z samples.
  """
  image_shape = gan.fake_images.get_shape().as_list()
  # Collapse everything but the batch dimension into one axis.
  flat_dim = np.prod(image_shape[1:])
  flat_images = tf.reshape(gan.fake_images, [gan.batch_size, flat_dim])
  jacobian_tensor = compute_jacobian(xs=gan.z, fx=flat_images)
  z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
  jacobian_value = sess.run(jacobian_tensor, feed_dict={gan.z: z_sample})
  statistics = analyze_jacobian(jacobian_value)
  return statistics["metric_tensor"]["log_condition_number"]
def _analyze_metric_tensor(metric_tensor):
  """Analyzes a metric tensor.

  Args:
    metric_tensor: A numpy array of shape [batch, dim, dim]. Each matrix is
      expected to be symmetric (it is built as J^T J by analyze_jacobian).

  Returns:
    A dict containing spectral statistics: "eigenvalues" of shape [batch, dim]
    in ascending order, and "logdet" and "log_condition_number" of shape
    [batch].
  """
  # The symmetric solver always returns real eigenvalues. The general solver
  # can return complex values for near-zero eigenvalues due to round-off, and
  # also computes eigenvectors that are not needed here.
  eigenvalues = np.linalg.eigvalsh(metric_tensor)

  # Shape [batch,].
  condition_number = np.linalg.cond(metric_tensor)
  log_condition_number = np.log(condition_number)
  # Only the log of the absolute determinant is kept; the sign is dropped.
  _, logdet = np.linalg.slogdet(metric_tensor)

  return {
      "eigenvalues": eigenvalues,
      "logdet": logdet,
      "log_condition_number": log_condition_number
  }


def analyze_jacobian(jacobian_array):
  """Computes eigenvalue statistics of the Jacobian.

  Computes the eigenvalues and condition number of the metric tensor for the
  Jacobian evaluated at each element of the batch and the mean metric tensor
  across the batch.

  Args:
    jacobian_array: A numpy array holding the Jacobian, of shape
      [batch, fx_dim, x_dim].

  Returns:
    A dict of spectral statistics with two elements, one containing stats
    for every metric tensor in the batch ("metric_tensor"), another for the
    mean metric tensor ("mean_metric_tensor").
  """
  # Shape [batch, x_dim, fx_dim].
  jacobian_transpose = np.transpose(jacobian_array, [0, 2, 1])

  # Shape [batch, x_dim, x_dim].
  metric_tensor = np.matmul(jacobian_transpose, jacobian_array)

  mean_metric_tensor = np.mean(metric_tensor, 0)
  # Reshapes to have a dummy batch dimension.
  mean_metric_tensor = np.reshape(mean_metric_tensor,
                                  (1,) + metric_tensor.shape[1:])
  return {
      "metric_tensor": _analyze_metric_tensor(metric_tensor),
      "mean_metric_tensor": _analyze_metric_tensor(mean_metric_tensor)
  }
_BATCH_SIZE = 32


def SlowJacobian(xs, fx):
  """Computes df/dx matrix.

  As jacobian_conditioning.compute_jacobian, but explicitly loops over
  dimensions of f.

  Args:
    xs: input tensor(s) of arbitrary shape.
    fx: f(x) tensor of arbitrary shape.

  Returns:
    df/dx tensor.
  """
  # One gradient op per slice of f along its last axis.
  fxs = tf.unstack(fx, axis=-1)
  grads = [tf.gradients(fx_i, xs) for fx_i in fxs]
  # tf.gradients returns a list; only its first entry is used here.
  grads = [grad[0] for grad in grads]
  df_dx = tf.stack(grads, axis=1)
  return df_dx


class JacobianConditioningTest(tf.test.TestCase):
  """Tests for compute_jacobian and the metric tensor analysis helpers."""

  def test_jacobian_simple_case(self):
    """Jacobian of a linear map f = x * W is W^T for every batch element."""
    x = tf.random_normal([_BATCH_SIZE, 2])
    W = tf.constant([[2., -1.], [1.5, 1.]])  # pylint: disable=invalid-name
    f = tf.matmul(x, W)
    j_tensor = jacobian_conditioning.compute_jacobian(xs=x, fx=f)
    with tf.Session() as sess:
      jacobian = sess.run(j_tensor)

    # Transpose of W in 'expected' is expected because in vector notation
    # f = W^T * x.
    # The expectation is a plain NumPy constant: building it with tf.tile
    # would create a graph op after the session above has already closed.
    expected = np.tile([[[2, 1.5], [-1, 1]]], [_BATCH_SIZE, 1, 1])
    self.assertAllClose(jacobian, expected)

  def test_jacobian_against_slow_version(self):
    """compute_jacobian agrees with the explicit per-output loop."""
    x = tf.random_normal([_BATCH_SIZE, 2])
    h1 = tf.contrib.layers.fully_connected(x, 20)
    h2 = tf.contrib.layers.fully_connected(h1, 20)
    f = tf.contrib.layers.fully_connected(h2, 10)

    j_slow_tensor = SlowJacobian(xs=x, fx=f)
    j_fast_tensor = jacobian_conditioning.compute_jacobian(xs=x, fx=f)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # Both tensors are fetched in one run call so they see the same x.
      j_fast, j_slow = sess.run([j_fast_tensor, j_slow_tensor])
    self.assertAllClose(j_fast, j_slow)

  def test_jacobian_numerically(self):
    """compute_jacobian agrees with central finite differences."""
    x = tf.random_normal([_BATCH_SIZE, 2])
    h1 = tf.contrib.layers.fully_connected(x, 20)
    h2 = tf.contrib.layers.fully_connected(h1, 20)
    f = tf.contrib.layers.fully_connected(h2, 10)
    j_tensor = jacobian_conditioning.compute_jacobian(xs=x, fx=f)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # Sample x once and feed it back so every evaluation uses the same
      # input values.
      x_np = sess.run(x)
      jacobian = sess.run(j_tensor, feed_dict={x: x_np})

      # Test 10 random elements.
      for _ in range(10):
        # Pick a random element of Jacobian to test.
        batch_idx = np.random.randint(_BATCH_SIZE)
        x_idx = np.random.randint(2)
        f_idx = np.random.randint(10)

        # Test with finite differences.
        epsilon = 1e-4

        x_plus = x_np.copy()
        x_plus[batch_idx, x_idx] += epsilon
        f_plus = sess.run(f, feed_dict={x: x_plus})[batch_idx, f_idx]

        x_minus = x_np.copy()
        x_minus[batch_idx, x_idx] -= epsilon
        f_minus = sess.run(f, feed_dict={x: x_minus})[batch_idx, f_idx]

        self.assertAllClose(
            jacobian[batch_idx, f_idx, x_idx],
            (f_plus - f_minus) / (2. * epsilon),
            rtol=1e-3,
            atol=1e-3)

  def test_analyze_metric_tensor(self):
    """Checks the output shapes of _analyze_metric_tensor."""
    # Assumes NumPy works, just tests that output shapes are as expected.
    jacobian = np.random.normal(0, 1, (_BATCH_SIZE, 2, 10))
    metric_tensor = np.matmul(np.transpose(jacobian, [0, 2, 1]), jacobian)
    result_dict = jacobian_conditioning._analyze_metric_tensor(metric_tensor)
    self.assertAllEqual(result_dict['eigenvalues'].shape, [_BATCH_SIZE, 10])
    self.assertAllEqual(result_dict['logdet'].shape, [_BATCH_SIZE])
    self.assertAllEqual(result_dict['log_condition_number'].shape,
                        [_BATCH_SIZE])

  def test_analyze_jacobian(self):
    """analyze_jacobian forms J^T J per element and its batch mean."""
    # Replace the spectral analysis with the identity so the raw metric
    # tensors can be inspected. The context manager guarantees the patch is
    # removed even when an assertion below fails.
    with mock.patch.object(
        jacobian_conditioning, '_analyze_metric_tensor', new=lambda x: x):
      jacobian = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
      result_dict = jacobian_conditioning.analyze_jacobian(jacobian)
      self.assertAllEqual(result_dict['metric_tensor'],
                          [[[10, 14], [14, 20]], [[40, 56], [56, 80]]])
      self.assertAllEqual(result_dict['mean_metric_tensor'],
                          [[[25, 35], [35, 50]]])


if __name__ == '__main__':
  tf.test.main()
class KIDScoreTask(eval_task.EvalTask):
  """Evaluation task for the KID score."""

  _LABEL = "kid_score"

  def run_after_session(self, fake_dset, real_dset):
    """Returns a dict mapping the task label to the KID of the activations."""
    score = kid(fake_dset.activations, real_dset.activations)
    return {self._LABEL: score}


def kid(fake_activations,
        real_activations,
        max_batch_size=1024,
        dtype=None,
        return_stderr=False):
  """Unbiased estimator of the Kernel Inception Distance.

  As defined by https://arxiv.org/abs/1801.01401.

  If return_stderr, also returns an estimate of the standard error, i.e. the
  standard deviation of the KID estimator. Returns nan if the number
  of batches is too small (< 5); for more reliable estimates, one could use
  the asymptotic variance estimate given in https://arxiv.org/abs/1611.04488.

  Uses a block estimator, as in https://arxiv.org/abs/1307.1954, with blocks
  no larger than max_batch_size. This is slightly different than the authors'
  provided code, but is also unbiased (and provides a more valid variance
  estimate).

  NOTE: the blocking code assumes that real_activations and
  fake_activations are in random order. If real_activations is sorted
  in a meaningful order, the estimator will be biased.

  Args:
    fake_activations: [batch, num_features] tensor with inception features.
    real_activations: [batch, num_features] tensor with inception features.
    max_batch_size: Upper bound on the number of samples in one block.
    dtype: Type used by the computations. If None, the dtype of
      real_activations is used and fake_activations must match it.
    return_stderr: If true, also returns the std_error from the KID computation.

  Returns:
    KID score (and optionally std error).
  """
  real_activations.get_shape().assert_has_rank(2)
  fake_activations.get_shape().assert_has_rank(2)

  # need to know dimension for the kernel, and batch size to split things
  real_activations.get_shape().assert_is_fully_defined()
  fake_activations.get_shape().assert_is_fully_defined()

  n_real, dim = real_activations.get_shape().as_list()
  n_gen, dim2 = fake_activations.get_shape().as_list()
  assert dim2 == dim

  # tensorflow_gan forces doubles for FID, but I don't think we need that here
  if dtype is None:
    dtype = real_activations.dtype
    assert fake_activations.dtype == dtype
  else:
    real_activations = tf.cast(real_activations, dtype)
    fake_activations = tf.cast(fake_activations, dtype)

  # split into largest approximately-equally-sized blocks
  n_bins = int(math.ceil(max(n_real, n_gen) / max_batch_size))
  bins_r = np.full(n_bins, int(math.ceil(n_real / n_bins)))
  bins_g = np.full(n_bins, int(math.ceil(n_gen / n_bins)))
  # Every block starts at the rounded-up size; shrink the first few by one so
  # the sizes sum to the sample count. Each side must be balanced against its
  # OWN block size: using bins_r[0] for the generated side (which has also
  # just been decremented) yields blocks that do not sum to n_gen whenever
  # n_real != n_gen.
  bins_r[:(n_bins * bins_r[0]) - n_real] -= 1
  bins_g[:(n_bins * bins_g[0]) - n_gen] -= 1
  # The unbiased estimator divides by m * (m - 1), so blocks need >= 2 samples.
  assert bins_r.min() >= 2
  assert bins_g.min() >= 2

  # Block boundaries: block i covers [inds[i], inds[i + 1]).
  inds_r = tf.constant(np.r_[0, np.cumsum(bins_r)])
  inds_g = tf.constant(np.r_[0, np.cumsum(bins_g)])

  dim_ = tf.cast(dim, dtype)

  def get_kid_batch(i):
    """Computes KID on a given batch of features.

    Takes real_activations[ind_r[i] : ind_r[i+1]] and
    fake_activations[ind_g[i] : ind_g[i+1]].

    Args:
      i: is the index of the batch.

    Returns:
      KID for the given batch.
    """
    r_s = inds_r[i]
    r_e = inds_r[i + 1]
    r = real_activations[r_s:r_e]
    m = tf.cast(r_e - r_s, dtype)

    g_s = inds_g[i]
    g_e = inds_g[i + 1]
    g = fake_activations[g_s:g_e]
    # Size of the generated block; it normalises the k_gg term below and must
    # not be taken from the real block, which can have a different length.
    n = tf.cast(g_e - g_s, dtype)

    # Could probably do this a bit faster...
    # Cubic polynomial kernel k(x, y) = (x.y / dim + 1)^3.
    k_rr = (tf.matmul(r, r, transpose_b=True) / dim_ + 1)**3
    k_rg = (tf.matmul(r, g, transpose_b=True) / dim_ + 1)**3
    k_gg = (tf.matmul(g, g, transpose_b=True) / dim_ + 1)**3
    # Diagonals are removed from the within-set sums (unbiased U-statistic).
    return (
        -2 * tf.reduce_mean(k_rg) + (tf.reduce_sum(k_rr) - tf.trace(k_rr)) /
        (m * (m - 1)) + (tf.reduce_sum(k_gg) - tf.trace(k_gg)) / (n * (n - 1)))

  ests = tf.map_fn(
      get_kid_batch, np.arange(n_bins), dtype=dtype, back_prop=False)

  if return_stderr:
    if n_bins < 5:
      # Too few blocks for a meaningful variance estimate.
      return tf.reduce_mean(ests), np.nan
    mn, var = tf.nn.moments(ests, [0])
    return mn, tf.sqrt(var / n_bins)
  else:
    return tf.reduce_mean(ests)
class MultiscaleSSIMTask(eval_task.EvalTask):
  """Eval task reporting the MS-SSIM score of the generated images."""

  _LABEL = "ms_ssim"

  def run_after_session(self, options, eval_data_fake, eval_data_real=None):
    """Returns {label: score}; only the fake images are used."""
    del options, eval_data_real
    return {self._LABEL: _compute_multiscale_ssim_score(eval_data_fake.images)}


def _compute_multiscale_ssim_score(fake_images):
  """Runs the MS-SSIM evaluation on shuffled batches in a fresh graph."""
  batch_size = 64
  with tf.Graph().as_default():
    images_tensor = tf.convert_to_tensor(fake_images, dtype=tf.float32)
    shuffled_batch = tf.train.shuffle_batch(
        [images_tensor],
        batch_size=batch_size,
        capacity=16*batch_size,
        min_after_dequeue=8*batch_size,
        num_threads=4,
        enqueue_many=True)

    # Following section 5.3 of https://arxiv.org/pdf/1710.08446.pdf, we only
    # evaluate 5 batches of the generated images.
    evaluate = compute_msssim(generated_images=shuffled_batch, num_batches=5)
    with tf.train.MonitoredTrainingSession() as sess:
      return evaluate(sess)


def compute_msssim(generated_images, num_batches):
  """Get a fn returning the ms ssim score for generated images.

  Args:
    generated_images: TF Tensor of shape [batch_size, dim, dim, 3] which
      evaluates to a batch of generated images. Should be in range [0..255].
    num_batches: Number of batches to consider.

  Returns:
    eval_fn: a function which takes a session as an argument and returns the
      average ms ssim score among all the possible image pairs from
      generated_images.
  """
  batch_size = int(generated_images.get_shape()[0])
  assert batch_size > 1

  # Build all batch_size^2 ordered pairs: `first` tiles the whole batch along
  # axis 0, `second` tiles along axis 1 and is then reshaped back to images.
  first = tf.tile(generated_images, [batch_size, 1, 1, 1])
  tiled_along_height = tf.tile(generated_images, [1, batch_size, 1, 1])
  paired_shape = [
      batch_size * batch_size, generated_images.shape[1],
      generated_images.shape[2], generated_images.shape[3]
  ]
  second = tf.reshape(tiled_along_height, paired_shape)

  # Mean over the pairs, discounting the batch_size 'identical' pairs, each of
  # which should score 1.0 under MultiscaleSSIM.
  total = tf.reduce_sum(image_similarity.multiscale_ssim(first, second))
  total -= batch_size
  score = tf.div(total, batch_size * batch_size - batch_size)

  def _eval_fn(session):
    """Evaluates the score tensor num_batches times and averages."""
    logging.info("Computing MS-SSIM score...")
    per_batch = [session.run(score) for _ in range(num_batches)]
    return np.mean(per_batch)

  return _eval_fn
class MsSsimScoreTest(tf.test.TestCase):
  """Regression test for ms_ssim_score.compute_msssim."""

  def test_on_one_vs_07_vs_zero_images(self):
    """Computes the SSIM value for 3 simple images."""
    image_shape = [64, 64, 3]
    with tf.Graph().as_default():
      # Three constant images: all ones, all 0.7 and all zeros.
      white = tf.ones(image_shape)
      grey = tf.ones(image_shape) * 0.7
      black = tf.zeros(image_shape)
      generated_images = tf.stack([white, grey, black])
      metric = ms_ssim_score.compute_msssim(generated_images, 1)
      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        result = metric(sess)
    self.assertNear(result, 0.989989, 0.001)


if __name__ == '__main__':
  tf.test.main()
class PRDTest(unittest.TestCase):
  """Exercises the prd_score helpers on small synthetic inputs."""

  def test_compute_prd_no_overlap(self):
    curves = prd.compute_prd([0, 1], [1, 0])
    np.testing.assert_almost_equal(np.ravel(curves), 0)

  def test_compute_prd_perfect_overlap(self):
    curves = prd.compute_prd([1, 0], [1, 0], num_angles=11)
    np.testing.assert_almost_equal([curves[0][5], curves[1][5]], [1, 1])

  def test_compute_prd_low_precision_high_recall(self):
    curves = prd.compute_prd([0.5, 0.5], [1, 0], num_angles=11)
    # (curve index, angle index, expected value)
    for curve, angle, expected in ((0, 5, 0.5), (1, 5, 0.5), (0, 10, 0.5),
                                   (1, 1, 1)):
      np.testing.assert_almost_equal(curves[curve][angle], expected)

  def test_compute_prd_high_precision_low_recall(self):
    curves = prd.compute_prd([1, 0], [0.5, 0.5], num_angles=11)
    np.testing.assert_almost_equal([curves[0][5], curves[1][5]], [0.5, 0.5])
    np.testing.assert_almost_equal(curves[1][1], 0.5)
    np.testing.assert_almost_equal(curves[0][10], 1)

  def test_compute_prd_bad_epsilon(self):
    for bad_epsilon in (0, 1, -1):
      with self.assertRaises(ValueError):
        prd.compute_prd([1], [1], epsilon=bad_epsilon)

  def test_compute_prd_bad_num_angles(self):
    for bad_num_angles in (0, 1, -1, 1e6+1, 2.5):
      with self.assertRaises(ValueError):
        prd.compute_prd([1], [1], num_angles=bad_num_angles)

  def test__cluster_into_bins(self):
    histograms = prd._cluster_into_bins(np.zeros([5, 4]), np.ones([5, 4]), 3)

    # Two outputs, each with one entry per cluster and summing to one.
    self.assertEqual(len(histograms), 2)
    for index in (0, 1):
      self.assertEqual(len(histograms[index]), 3)
    for index in (0, 1):
      np.testing.assert_almost_equal(sum(histograms[index]), 1)

  def test_compute_prd_from_embedding_mismatch_num_samples_should_fail(self):
    # Mismatch in number of samples with enforce_balance set to True
    eval_data = np.array([[0], [0], [1]])
    ref_data = np.array([[0], [1]])
    with self.assertRaises(ValueError):
      prd.compute_prd_from_embedding(
          eval_data, ref_data, num_clusters=2, enforce_balance=True)

  def test_compute_prd_from_embedding_mismatch_num_samples_should_work(self):
    # Mismatch in number of samples with enforce_balance set to False
    eval_data = np.array([[0], [0], [1]])
    ref_data = np.array([[0], [1]])
    try:
      prd.compute_prd_from_embedding(
          eval_data, ref_data, num_clusters=2, enforce_balance=False)
    except ValueError:
      self.fail(
          'compute_prd_from_embedding should not raise a ValueError when '
          'enforce_balance is set to False.')

  def test__prd_to_f_beta_correct_computation(self):
    precision = np.array([1, 1, 0, 0, 0.5, 1, 0.5])
    recall = np.array([1, 0, 1, 0, 0.5, 0.5, 1])
    cases = (
        (1, np.array([1, 0, 0, 0, 0.5, 2/3, 2/3])),
        (2, np.array([1, 0, 0, 0, 0.5, 5/9, 5/6])),
        (1/2, np.array([1, 0, 0, 0, 0.5, 5/6, 5/9])),
    )
    for beta, expected in cases:
      # Invalid-value warnings are silenced while the helper runs.
      with np.errstate(invalid='ignore'):
        result = prd._prd_to_f_beta(precision, recall, beta=beta)
      np.testing.assert_almost_equal(result, expected)

    # Empty inputs give an empty result.
    empty_result = prd._prd_to_f_beta(np.array([]), np.array([]), beta=1)
    np.testing.assert_almost_equal(empty_result, np.array([]))

  def test__prd_to_f_beta_bad_beta(self):
    for bad_beta in (0, -3):
      with self.assertRaises(ValueError):
        prd._prd_to_f_beta(np.ones(1), np.ones(1), beta=bad_beta)

  def test__prd_to_f_beta_bad_precision_or_recall(self):
    bad_inputs = (
        (-np.ones(1), np.ones(1)),
        (np.ones(1), -np.ones(1)),
    )
    for precision, recall in bad_inputs:
      with self.assertRaises(ValueError):
        prd._prd_to_f_beta(precision, recall, beta=1)

  def test_plot_not_enough_labels(self):
    with self.assertRaises(ValueError):
      prd.plot(np.zeros([3, 2, 5]), labels=['1', '2'])

  def test_plot_too_many_labels(self):
    with self.assertRaises(ValueError):
      prd.plot(np.zeros([1, 2, 5]), labels=['1', '2', '3'])


if __name__ == '__main__':
  unittest.main()
FLAGS = flags.FLAGS


def create_fake_inception_graph():
  """Creates a graph that mocks inception.

  It takes the input, multiplies it through a matrix full of 0.00001 values
  and returns as logits. It makes sure to match the tensor names of
  the real inception model.

  Returns:
    GraphDef (the result of tf.Graph.as_graph_def()) of a simple mock
    inception graph.
  """
  fake_inception = tf.Graph()
  with fake_inception.as_default():
    graph_input = tf.placeholder(
        tf.float32, shape=[None, 299, 299, 3], name="Mul")
    matrix = tf.ones(shape=[299 * 299 * 3, 10]) * 0.00001
    output = tf.matmul(tf.layers.flatten(graph_input), matrix)
    # The same values are exposed under both tensor names.
    output = tf.identity(output, name="pool_3")
    output = tf.identity(output, name="logits")
  return fake_inception.as_graph_def()


class Generator(abstract_arch.AbstractGenerator):
  """Generator with a single linear layer from z to the output."""

  def __init__(self, **kwargs):
    super(Generator, self).__init__(**kwargs)
    # Records the arguments of every apply() call for inspection by tests.
    self.call_arg_list = []

  def apply(self, z, y, is_training):
    """Maps z through one linear layer and a sigmoid to an image batch."""
    self.call_arg_list.append(dict(z=z, y=y, is_training=is_training))
    batch_size = z.shape[0].value
    out = arch_ops.linear(z, np.prod(self._image_shape), scope="fc_noise")
    out = tf.nn.sigmoid(out)
    return tf.reshape(out, [batch_size] + list(self._image_shape))


class Discriminator(abstract_arch.AbstractDiscriminator):
  """Discriminator with a single linear layer."""

  def __init__(self, **kwargs):
    super(Discriminator, self).__init__(**kwargs)
    # Records the arguments of every apply() call for inspection by tests.
    self.call_arg_list = []

  def apply(self, x, y, is_training):
    """Returns (sigmoid(out), out, h) for the mean-pooled input h."""
    self.call_arg_list.append(dict(x=x, y=y, is_training=is_training))
    h = tf.reduce_mean(x, axis=[1, 2])
    out = arch_ops.linear(h, 1)
    return tf.nn.sigmoid(out), out, h


class CompareGanTestCase(tf.test.TestCase):
  """Base class for test cases."""

  def setUp(self):
    super(CompareGanTestCase, self).setUp()
    # Use fake datasets instead of reading real files.
    FLAGS.data_fake_dataset = True
    # Clear the gin configuration.
    gin.clear_config()
    # Mock the inception graph.
    fake_inception_graph = create_fake_inception_graph()
    patcher = mock.patch.object(
        eval_utils,
        "get_inception_graph_def",
        return_value=fake_inception_graph)
    self.inception_graph_def_mock = patcher.start()
    # Undo the patch after each test; a started patch that is never stopped
    # would stay active for every test that runs later in the same process.
    self.addCleanup(patcher.stop)

  def _get_empty_model_dir(self):
    """Returns a path under FLAGS.test_tmpdir that does not exist yet."""
    # The sub directory name is the microsecond part of the current time.
    unused_sub_dir = str(datetime.datetime.now().microsecond)
    model_dir = os.path.join(FLAGS.test_tmpdir, unused_sub_dir)
    assert not tf.gfile.Exists(model_dir)
    return model_dir
def cross_replica_concat(value, replica_id, num_replicas):
  """Reduce a concatenation of the `value` across TPU replicas.

  Args:
    value: Tensor to concatenate.
    replica_id: Integer tensor that indicates the index of the replica.
    num_replicas: Python integer, total number of replicas.

  Returns:
    Tensor of the same rank as value with first dimension `num_replicas`
    times larger.

  Raises:
    ValueError: If `value` is a scalar.
  """
  if value.shape.ndims < 1:
    raise ValueError("Value must have at least rank 1 but got {}.".format(
        value.shape.ndims))
  if num_replicas <= 1:
    return value
  with tf.name_scope(None, "tpu_cross_replica_concat"):
    # Mask is one hot encoded position of the core_index.
    mask = tf.to_float(tf.equal(tf.range(num_replicas), replica_id))
    # Expand dims with 1's to match rank of value.
    mask = tf.reshape(mask, [num_replicas] + [1] * value.shape.ndims)
    if value.dtype in {tf.bfloat16, tf.float32}:
      # The mask is float32; bring it to the dtype of value so the product is
      # also defined for bfloat16 inputs (no-op for float32).
      result = tf.cast(mask, value.dtype) * value
    else:
      result = mask * tf.to_float(value)
    # Thanks to broadcasting now result is set only in the position pointed by
    # replica_id, the rest of the vector is set to 0's.
    # All these steps are basically implementing tf.scatter_nd which is missing
    # in TPU's backend since it doesn't support sparse operations.

    # Merge first 2 dimensions.
    # This is equivalent to (value.shape[0].value * num_replicas).
    # Using [-1] trick to support also scalar input.
    result = tf.reshape(result, [-1] + result.shape.as_list()[2:])
    # Each core set the "results" in position pointed by replica_id. When we now
    # sum across replicas we exchange the information and fill in local 0's with
    # values from other cores.
    result = tf.contrib.tpu.cross_replica_sum(result)
    # Now all the cores see exactly the same data.
    return tf.cast(result, dtype=value.dtype)


def cross_replica_mean(inputs, group_size=None):
  """Calculates the average value of inputs tensor across TPU replicas.

  Args:
    inputs: Tensor to average.
    group_size: Integer, number of consecutive replicas to average over.
      None or 0 averages over all replicas; 1 returns `inputs` unchanged.

  Returns:
    Tensor with the mean of `inputs` over the replicas of each group.
  """
  num_replicas = tpu_function.get_tpu_context().number_of_shards
  if not group_size:
    group_size = num_replicas
  if group_size == 1:
    return inputs
  if group_size != num_replicas:
    # Partition the replicas into consecutive groups of `group_size`.
    group_assignment = []
    assert num_replicas % group_size == 0
    for g in range(num_replicas // group_size):
      replica_ids = [g * group_size + i for i in range(group_size)]
      group_assignment.append(replica_ids)
  else:
    group_assignment = None
  return tf.contrib.tpu.cross_replica_sum(inputs, group_assignment) / tf.cast(
      group_size, inputs.dtype)


@gin.configurable(blacklist=["inputs", "axis"])
def cross_replica_moments(inputs, axis, parallel=True, group_size=None):
  """Compute mean and variance of the inputs tensor across TPU replicas.

  Args:
    inputs: A tensor with 2 or more dimensions.
    axis: Array of ints. Axes along which to compute mean and variance.
    parallel: Use E[x^2] - (E[x])^2 to compute variance. This can be done
      in parallel to computing the mean and reduces the communication overhead.
    group_size: Integer, the number of replicas to compute moments across.
      None or 0 will use all replicas (global).

  Returns:
    Two tensors with mean and variance.
  """
  # Compute local mean and then average across replicas.
  mean = tf.math.reduce_mean(inputs, axis=axis)
  # The mean must be reduced over the same replica group as the variance
  # terms below; otherwise the two moments describe different sets of
  # replicas whenever group_size is smaller than the number of replicas.
  mean = cross_replica_mean(mean, group_size=group_size)
  if parallel:
    # Compute variance using the E[x^2] - (E[x])^2 formula. This is less
    # numerically stable than the E[(x-E[x])^2] formula, but allows the two
    # cross-replica sums to be computed in parallel, saving communication
    # overhead.
    mean_of_squares = tf.reduce_mean(tf.square(inputs), axis=axis)
    mean_of_squares = cross_replica_mean(mean_of_squares, group_size=group_size)
    mean_squared = tf.square(mean)
    variance = mean_of_squares - mean_squared
  else:
    variance = tf.math.reduce_mean(
        tf.math.square(inputs - mean), axis=axis)
    variance = cross_replica_mean(variance, group_size=group_size)
  return mean, variance

"""Tests custom TensorFlow operations for TPU."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
from absl.testing import parameterized
from compare_gan.tpu import tpu_ops
import numpy as np
import tensorflow as tf


class TpuOpsTpuTest(parameterized.TestCase, tf.test.TestCase):
  """Runs the cross replica ops of tpu_ops on two TPU shards."""

  def testRunsOnTpu(self):
    """Verify that the test cases runs on a TPU chip and has 2 cores."""
    expected_device_names = [
        "/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:TPU:0",
        "/job:localhost/replica:0/task:0/device:TPU:1",
        "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
    ]
    with self.session() as sess:
      devices = sess.list_devices()
    logging.info("devices:\n%s", "\n".join([str(d) for d in devices]))
    actual_device_names = [d.name for d in devices]
    self.assertAllEqual(actual_device_names, expected_device_names)

  def testCrossReplicaConcat(self):
    """Each shard contributes its slice; every shard sees the full batch."""
    def computation(x, replica_id):
      logging.info("x: %s\nreplica_id: %s", x, replica_id[0])
      return tpu_ops.cross_replica_concat(x, replica_id[0], num_replicas=2)

    inputs = np.asarray([[3, 4], [1, 5]])
    expected_output = np.asarray([[3, 4], [1, 5], [3, 4], [1, 5]])

    with tf.Graph().as_default():
      x = tf.constant(inputs)
      replica_ids = tf.constant([0, 1], dtype=tf.int32)
      x_concat, = tf.contrib.tpu.batch_parallel(
          computation, [x, replica_ids], num_shards=2)
      self.assertAllEqual(x.shape.as_list(), [2, 2])
      self.assertAllEqual(x_concat.shape.as_list(), [4, 2])

      with self.session() as sess:
        sess.run(tf.contrib.tpu.initialize_system())
        sess.run(tf.global_variables_initializer())
        x_concat = sess.run(x_concat)
        logging.info("x_concat: %s", x_concat)
        self.assertAllClose(x_concat, expected_output)

  def _cross_replica_mean_on_two_shards(self, inputs, group_size):
    """Feeds one row of `inputs` to each of 2 shards and returns the output.

    Args:
      inputs: Numpy array of shape (2, 4); each shard receives one row.
      group_size: Forwarded to tpu_ops.cross_replica_mean.

    Returns:
      Numpy array with the concatenated outputs of both shards.
    """
    def computation(x):
      self.assertAllEqual(x.shape.as_list(), [1, 4])
      return tpu_ops.cross_replica_mean(x, group_size=group_size)

    with tf.Graph().as_default():
      # Note: Using placeholders for feeding TPUs is discouraged but fine for
      # a simple test case.
      x = tf.placeholder(name="x", dtype=tf.float32, shape=inputs.shape)
      y = tf.contrib.tpu.batch_parallel(computation, inputs=[x], num_shards=2)
      with self.session() as sess:
        sess.run(tf.contrib.tpu.initialize_system())
        # y is actually a list with one tensor. computation would be allowed
        # to return multiple tensors (and ops).
        return sess.run(y, {x: inputs})[0]

  # Test with group size 2 (the test case has 2 cores, so this is the same as
  # averaging over all replicas).
  @parameterized.parameters(
      {"group_size": None},  # Defaults to number of TPU cores.
      {"group_size": 0},  # Defaults to number of TPU cores.
      {"group_size": 2},
  )
  def testCrossReplicaMean(self, group_size):
    # Verify that we average across replicas by feeding 2 vectors to the system.
    # Each replica should get one vector which is then averaged across
    # all replicas and simply returned.
    # After that each replica has the same vector and since the outputs gets
    # concatenated we see the same vector twice.
    inputs = np.asarray(
        [[0.55, 0.70, -1.29, 0.502], [0.57, 0.90, 1.290, 0.202]],
        dtype=np.float32)
    expected_output = np.asarray(
        [[0.56, 0.8, 0.0, 0.352], [0.56, 0.8, 0.0, 0.352]], dtype=np.float32)

    actual_output = self._cross_replica_mean_on_two_shards(inputs, group_size)

    self.assertAllEqual(actual_output.shape, (2, 4))
    self.assertAllClose(actual_output, expected_output)

  def testCrossReplicaMeanGroupSizeOne(self, group_size=1):
    # Since the group size is 1 we only average over 1 replica.
    inputs = np.asarray(
        [[0.55, 0.70, -1.29, 0.502], [0.57, 0.90, 1.290, 0.202]],
        dtype=np.float32)
    expected_output = np.asarray(
        [[0.55, 0.7, -1.29, 0.502], [0.57, 0.9, 1.290, 0.202]],
        dtype=np.float32)

    actual_output = self._cross_replica_mean_on_two_shards(inputs, group_size)

    self.assertAllEqual(actual_output.shape, (2, 4))
    self.assertAllClose(actual_output, expected_output)


if __name__ == "__main__":
  tf.test.main()

# --------------------------------------------------------------------------------
# /compare_gan/tpu/tpu_summaries.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | 16 | """Provide a helper class for using summaries on TPU via a host call. 17 | 18 | TPUEstimator does not support writing TF summaries out of the box and TPUs can't 19 | perform operations that write files to disk. To monitor tensor values during 20 | training you can copy the tensors back to the CPU of the host machine via 21 | a host call function. This small library provides a convienent API to do this. 22 | 23 | Example: 24 | from compare_gan.tpu import tpu_summaries 25 | def model_fn(features, labels, params, mode): 26 | summary = tpu_summries.TpuSummaries(my_model_dir) 27 | 28 | summary.scalar("my_scalar_summary", tensor1) 29 | summary.scalar("my_counter", tensor2, reduce_fn=tf.math.reduce_sum) 30 | 31 | return TPUEstimatorSpec( 32 | host_call=summary.get_host_call(), 33 | ...) 34 | 35 | Warning: The host call function will run every step. Writing large tensors to 36 | summaries can slow down your training. High ranking outfeed operations in your 37 | XProf profile can be an indication for this. 38 | """ 39 | 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | 44 | import collections 45 | 46 | from absl import logging 47 | import tensorflow as tf 48 | 49 | 50 | summary = tf.contrib.summary # TensorFlow Summary API v2. 51 | 52 | 53 | TpuSummaryEntry = collections.namedtuple( 54 | "TpuSummaryEntry", "summary_fn name tensor reduce_fn") 55 | 56 | 57 | class TpuSummaries(object): 58 | """Class to simplify TF summaries on TPU. 59 | 60 | An instance of the class provides simple methods for writing summaries in the 61 | similar way to tf.summary. The difference is that each summary entry must 62 | provide a reduction function that is used to reduce the summary values from 63 | all the TPU cores. 64 | """ 65 | 66 | def __init__(self, log_dir, save_summary_steps=250): 67 | self._log_dir = log_dir 68 | self._entries = [] 69 | # While False no summary entries will be added. 
On TPU we unroll the graph 70 | # and don't want to add multiple summaries per step. 71 | self.record = True 72 | self._save_summary_steps = save_summary_steps 73 | 74 | def image(self, name, tensor, reduce_fn): 75 | """Add a summary for images. Tensor must be of 4-D tensor.""" 76 | if not self.record: 77 | return 78 | self._entries.append( 79 | TpuSummaryEntry(summary.image, name, tensor, reduce_fn)) 80 | 81 | def scalar(self, name, tensor, reduce_fn=tf.math.reduce_mean): 82 | """Add a summary for a scalar tensor.""" 83 | if not self.record: 84 | return 85 | tensor = tf.convert_to_tensor(tensor) 86 | if tensor.shape.ndims == 0: 87 | tensor = tf.expand_dims(tensor, 0) 88 | self._entries.append( 89 | TpuSummaryEntry(summary.scalar, name, tensor, reduce_fn)) 90 | 91 | def get_host_call(self): 92 | """Returns the tuple (host_call_fn, host_call_args) for TPUEstimatorSpec.""" 93 | # All host_call_args must be tensors with batch dimension. 94 | # All tensors are streamed to the host machine (mind the band width). 95 | global_step = tf.train.get_or_create_global_step() 96 | host_call_args = [tf.expand_dims(global_step, 0)] 97 | host_call_args.extend([e.tensor for e in self._entries]) 98 | logging.info("host_call_args: %s", host_call_args) 99 | return (self._host_call_fn, host_call_args) 100 | 101 | def _host_call_fn(self, step, *args): 102 | """Function that will run on the host machine.""" 103 | # Host call receives values from all tensor cores (concatenate on the 104 | # batch dimension). Step is the same for all cores. 
105 | step = step[0] 106 | logging.info("host_call_fn: args=%s", args) 107 | with summary.create_file_writer(self._log_dir).as_default(): 108 | with summary.record_summaries_every_n_global_steps( 109 | self._save_summary_steps, step): 110 | for i, e in enumerate(self._entries): 111 | value = e.reduce_fn(args[i]) 112 | e.summary_fn(e.name, value, step=step) 113 | return summary.all_summary_ops() 114 | -------------------------------------------------------------------------------- /compare_gan/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google LLC & Hwalsuk Lee. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utilities library.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import functools 24 | import inspect 25 | 26 | from absl import logging 27 | import six 28 | 29 | 30 | # In Python 2 the inspect module does not have FullArgSpec. Define a named tuple 31 | # instead. 
if hasattr(inspect, "FullArgSpec"):
  _FullArgSpec = inspect.FullArgSpec  # pylint: disable=invalid-name
else:
  _FullArgSpec = collections.namedtuple("FullArgSpec", [
      "args", "varargs", "varkw", "defaults", "kwonlyargs", "kwonlydefaults",
      "annotations"
  ])


def _getfullargspec(fn):
  """Python 2/3 compatible version of the inspect.getfullargspec method.

  Args:
    fn: The function object.

  Returns:
    A FullArgSpec. For Python 2 this is emulated by a named tuple.
  """
  arg_spec_fn = inspect.getfullargspec if six.PY3 else inspect.getargspec
  try:
    arg_spec = arg_spec_fn(fn)
  except TypeError:
    # `fn` might be a callable object.
    arg_spec = arg_spec_fn(fn.__call__)
  if six.PY3:
    assert isinstance(arg_spec, _FullArgSpec)
    return arg_spec
  # Python 2: getargspec has no keyword-only arguments or annotations, fill
  # the missing fields with empty values.
  return _FullArgSpec(
      args=arg_spec.args,
      varargs=arg_spec.varargs,
      varkw=arg_spec.keywords,
      defaults=arg_spec.defaults,
      kwonlyargs=[],
      kwonlydefaults=None,
      annotations={})


def _has_arg(fn, arg_name):
  """Returns True if `arg_name` might be a valid parameter for `fn`.

  Specifically, this means that `fn` either has a parameter named
  `arg_name`, or has a `**kwargs` parameter.

  Args:
    fn: The function to check.
    arg_name: The name of the parameter.

  Returns:
    Whether `arg_name` might be a valid argument of `fn`.
  """
  # Look through functools.partial objects and decorator wrappers to reach
  # the function that carries the real signature.
  while isinstance(fn, functools.partial):
    fn = fn.func
  while hasattr(fn, "__wrapped__"):
    fn = fn.__wrapped__
  arg_spec = _getfullargspec(fn)
  if arg_spec.varkw:
    return True
  return arg_name in arg_spec.args or arg_name in arg_spec.kwonlyargs


def call_with_accepted_args(fn, **kwargs):
  """Calls `fn` only with the keyword arguments that `fn` accepts."""
  kwargs = {k: v for k, v in six.iteritems(kwargs) if _has_arg(fn, k)}
  logging.debug("Calling %s with args %s.", fn, kwargs)
  return fn(**kwargs)


def get_parameter_overview(variables, limit=40):
  """Returns a string with variables names, their shapes, count, and types.

  To get all trainable parameters pass in `tf.trainable_variables()`.

  Args:
    variables: List of `tf.Variable`(s).
    limit: If not `None`, the maximum number of variables to include. When
      there are more variables a "[...]" line is added after the last included
      one. The total always counts all variables.

  Returns:
    A string with a table like in the example.

  +----------------+---------------+------------+---------+
  | Name           | Shape         | Size       | Type    |
  +----------------+---------------+------------+---------+
  | FC_1/weights:0 | (63612, 1024) | 65,138,688 | float32 |
  | FC_1/biases:0  | (1024,)       | 1,024      | float32 |
  | FC_2/weights:0 | (1024, 32)    | 32,768     | float32 |
  | FC_2/biases:0  | (32,)         | 32         | float32 |
  +----------------+---------------+------------+---------+

  Total: 65,172,512
  """
  # Render the cells of every variable once; they are needed for both the
  # column widths and the table rows.
  rows = [(v.name,
           str(v.get_shape()),
           "{:,}".format(v.get_shape().num_elements()),
           v.dtype.base_dtype.name) for v in variables]
  max_name_len = max([len(r[0]) for r in rows] + [len("Name")])
  max_shape_len = max([len(r[1]) for r in rows] + [len("Shape")])
  max_size_len = max([len(r[2]) for r in rows] + [len("Size")])
  max_type_len = max([len(r[3]) for r in rows] + [len("Type")])

  var_line_format = "| {: <{}s} | {: >{}s} | {: >{}s} | {: <{}s} |"
  sep_line_format = var_line_format.replace(" ", "-").replace("|", "+")

  # The header cells are all left aligned.
  header = var_line_format.replace(">", "<").format("Name", max_name_len,
                                                    "Shape", max_shape_len,
                                                    "Size", max_size_len,
                                                    "Type", max_type_len)
  separator = sep_line_format.format("", max_name_len, "", max_shape_len, "",
                                     max_size_len, "", max_type_len)

  lines = [separator, header, separator]

  total_weights = sum(v.get_shape().num_elements() for v in variables)

  # Create lines for up to `limit` variables. Count the variables and not the
  # lines of the table, the header rows must not eat into the limit.
  for index, (name, shape, size, dtype_name) in enumerate(rows):
    if limit is not None and index >= limit:
      lines.append("[...]")
      break
    lines.append(var_line_format.format(
        name, max_name_len,
        shape, max_shape_len,
        size, max_size_len,
        dtype_name, max_type_len))

  lines.append(separator)
  lines.append("Total: {:,}".format(total_weights))

  return "\n".join(lines)


def log_parameter_overview(variables, msg):
  """Writes a table with variables name and shapes to INFO log.

  See get_parameter_overview for details.

  Args:
    variables: List of `tf.Variable`(s).
    msg: Message to be logged before the table.
  """
  table = get_parameter_overview(variables, limit=None)
  # The table can be too large to fit into one log entry, so it is logged in
  # chunks of 80 lines.
  lines = [msg] + table.split("\n")
  for i in range(0, len(lines), 80):
    logging.info("\n%s", "\n".join(lines[i:i + 80]))

# --------------------------------------------------------------------------------
# /example_configs/README.md:
# --------------------------------------------------------------------------------
# This folder contains configurations of popular GANs and links to the
# corresponding papers.
#
# **Please note that we provide them as a starting point for the user and that
# they represent a best-effort implementation** -- we do not guarantee that all
# the implementation details match exactly due to the vast number of design
# options. Given the sensitivity of GANs to design choices, hyperparameters and
# differences in platforms (GPUs/TPUs), small differences might have a
# significant impact on the final results.
10 | -------------------------------------------------------------------------------- /example_configs/biggan_imagenet128.gin: -------------------------------------------------------------------------------- 1 | # BigGAN architecture and settings on ImageNet 128. 2 | # http://arxiv.org/abs/1809.11096 3 | 4 | # This should be similar to row 7 in Table 1. 5 | # It does not include orthogonal regularization (which would be row 8) and uses 6 | # a different learning rate. 7 | 8 | # Recommended training platform: TPU v3-128. 9 | 10 | dataset.name = "imagenet_128" 11 | options.z_dim = 120 12 | 13 | options.architecture = "resnet_biggan_arch" 14 | ModularGAN.conditional = True 15 | options.batch_size = 2048 16 | options.gan_class = @ModularGAN 17 | options.lamba = 1 18 | options.training_steps = 250000 19 | weights.initializer = "orthogonal" 20 | spectral_norm.singular_value = "auto" 21 | 22 | # Generator 23 | G.batch_norm_fn = @conditional_batch_norm 24 | G.spectral_norm = True 25 | ModularGAN.g_use_ema = True 26 | resnet_biggan.Generator.hierarchical_z = True 27 | resnet_biggan.Generator.embed_y = True 28 | standardize_batch.decay = 0.9 29 | standardize_batch.epsilon = 1e-5 30 | standardize_batch.use_moving_averages = False 31 | 32 | # Discriminator 33 | options.disc_iters = 2 34 | D.spectral_norm = True 35 | resnet_biggan.Discriminator.project_y = True 36 | 37 | # Loss and optimizer 38 | loss.fn = @hinge 39 | penalty.fn = @no_penalty 40 | ModularGAN.g_lr = 0.0001 41 | ModularGAN.g_optimizer_fn = @tf.train.AdamOptimizer 42 | ModularGAN.d_lr = 0.0005 43 | ModularGAN.d_optimizer_fn = @tf.train.AdamOptimizer 44 | tf.train.AdamOptimizer.beta1 = 0.0 45 | tf.train.AdamOptimizer.beta2 = 0.999 46 | 47 | z.distribution_fn = @tf.random.normal 48 | eval_z.distribution_fn = @tf.random.normal 49 | 50 | run_config.iterations_per_loop = 500 51 | run_config.save_checkpoints_steps = 2500 52 | -------------------------------------------------------------------------------- 
/example_configs/dcgan_celeba64.gin: -------------------------------------------------------------------------------- 1 | 2 | # Recommended training platform: P100, V100, TPU v2-8 or TPU v3-8 3 | 4 | dataset.name = "celeb_a" 5 | options.architecture = "dcgan_arch" 6 | options.batch_size = 64 7 | options.gan_class = @ModularGAN 8 | options.lamba = 1 9 | options.training_steps = 100000 10 | options.z_dim = 128 11 | 12 | # Generator 13 | G.batch_norm_fn = @batch_norm 14 | standardize_batch.decay = 0.9 15 | standardize_batch.epsilon = 1e-5 16 | 17 | # Discriminator 18 | options.disc_iters = 1 19 | D.spectral_norm = False 20 | 21 | # Loss and optimizer 22 | loss.fn = @non_saturating 23 | penalty.fn = @no_penalty 24 | ModularGAN.g_lr = 0.0002 25 | ModularGAN.g_optimizer_fn = @tf.train.AdamOptimizer 26 | tf.train.AdamOptimizer.beta1 = 0.5 27 | tf.train.AdamOptimizer.beta2 = 0.999 28 | -------------------------------------------------------------------------------- /example_configs/resnet_cifar10.gin: -------------------------------------------------------------------------------- 1 | 2 | # Recommended training platform: P100, V100, TPU v2-8 or TPU v3-8 3 | 4 | dataset.name = "cifar10" 5 | options.architecture = "resnet_cifar_arch" 6 | options.batch_size = 64 7 | options.gan_class = @ModularGAN 8 | options.lamba = 1 9 | options.training_steps = 40000 10 | options.z_dim = 128 11 | 12 | # Generator 13 | G.batch_norm_fn = @batch_norm 14 | standardize_batch.decay = 0.9 15 | standardize_batch.epsilon = 1e-5 16 | 17 | # Discriminator 18 | options.disc_iters = 5 19 | D.spectral_norm = True 20 | 21 | # Loss and optimizer 22 | loss.fn = @non_saturating 23 | penalty.fn = @no_penalty 24 | ModularGAN.g_lr = 0.0002 25 | ModularGAN.g_optimizer_fn = @tf.train.AdamOptimizer 26 | tf.train.AdamOptimizer.beta1 = 0.5 27 | tf.train.AdamOptimizer.beta2 = 0.999 28 | -------------------------------------------------------------------------------- /example_configs/resnet_lsun-bedroom128.gin: 
-------------------------------------------------------------------------------- 1 | 2 | # Recommended training platform: P100, V100, TPU v2-8 or TPU v3-8 3 | 4 | dataset.name = "lsun-bedroom" 5 | options.architecture = "resnet5_arch" 6 | options.batch_size = 64 7 | options.gan_class = @ModularGAN 8 | options.lamba = 10 9 | options.training_steps = 40000 10 | options.z_dim = 128 11 | 12 | # Generator 13 | G.batch_norm_fn = @batch_norm 14 | standardize_batch.decay = 0.9 15 | standardize_batch.epsilon = 1e-5 16 | 17 | # Discriminator 18 | options.disc_iters = 5 19 | D.spectral_norm = False 20 | 21 | # Loss and optimizer 22 | loss.fn = @wasserstein 23 | penalty.fn = @wgangp_penalty 24 | ModularGAN.g_lr = 0.0001 25 | ModularGAN.g_optimizer_fn = @tf.train.AdamOptimizer 26 | tf.train.AdamOptimizer.beta1 = 0.5 27 | tf.train.AdamOptimizer.beta2 = 0.9 28 | -------------------------------------------------------------------------------- /example_configs/sndcgan_celebahq128.gin: -------------------------------------------------------------------------------- 1 | 2 | # Recommended training platform: P100, V100, TPU v2-8 or TPU v3-8 3 | 4 | dataset.name = "celeb_a_hq_128" 5 | options.architecture = "sndcgan_arch" 6 | options.batch_size = 64 7 | options.gan_class = @ModularGAN 8 | options.lamba = 1 9 | options.training_steps = 100000 10 | options.z_dim = 128 11 | 12 | # Generator 13 | G.batch_norm_fn = @batch_norm 14 | standardize_batch.decay = 0.9 15 | standardize_batch.epsilon = 1e-5 16 | 17 | # Discriminator 18 | options.disc_iters = 1 19 | D.spectral_norm = True 20 | 21 | # Loss and optimizer 22 | loss.fn = @non_saturating 23 | penalty.fn = @no_penalty 24 | ModularGAN.g_lr = 0.0002 25 | ModularGAN.g_optimizer_fn = @tf.train.AdamOptimizer 26 | tf.train.AdamOptimizer.beta1 = 0.5 27 | tf.train.AdamOptimizer.beta2 = 0.999 28 | -------------------------------------------------------------------------------- /setup.py: 
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Install compare_gan."""

from setuptools import find_packages
from setuptools import setup

# Packages that are always installed together with compare_gan.
_INSTALL_REQUIRES = [
    'future',
    'gin-config==0.1.4',
    'numpy',
    'pandas',
    'six',
    'tensorflow-datasets==1.0.1',
    'tensorflow-hub>=0.2.0',
    'tensorflow-gan==0.0.0.dev0',
    'matplotlib>=1.5.2',
    'pstar>=0.1.6',
    'scipy>=1.0.0',
]

# Optional dependency groups, selected with `pip install compare_gan[<name>]`.
_EXTRAS_REQUIRE = {
    'tf': ['tensorflow>=1.12'],
    # Evaluation of Hub modules with EMA variables requires TF > 1.12.
    'tf_gpu': ['tf-nightly-gpu>=1.13.0.dev20190221'],
    'pillow': ['pillow>=5.0.0'],
    'tensorflow-probability': ['tensorflow-probability>=0.5.0'],
}

_CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Intended Audience :: Science/Research',
    'License :: OSI Approved :: Apache Software License',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
]

setup(
    name='compare_gan',
    version='3.0',
    description=(
        'Compare GAN - A modular library for training and evaluating GANs.'),
    author='Google LLC',
    author_email='no-reply@google.com',
    url='https://github.com/google/compare_gan',
    license='Apache 2.0',
    packages=find_packages(),
    package_data={},
    install_requires=_INSTALL_REQUIRES,
    extras_require=_EXTRAS_REQUIRE,
    classifiers=_CLASSIFIERS,
    keywords='tensorflow machine learning gan',
)
# --------------------------------------------------------------------------------