├── .gitignore
├── AUTHORS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── colabs
│   ├── Convex_Polygons_Dataset.ipynb
│   ├── s3gan_demo.ipynb
│   └── ssgan_demo.ipynb
├── compare_gan
│   ├── __init__.py
│   ├── architectures
│   │   ├── __init__.py
│   │   ├── abstract_arch.py
│   │   ├── arch_ops.py
│   │   ├── arch_ops_test.py
│   │   ├── arch_ops_tpu_test.py
│   │   ├── architectures_test.py
│   │   ├── dcgan.py
│   │   ├── infogan.py
│   │   ├── resnet30.py
│   │   ├── resnet5.py
│   │   ├── resnet_biggan.py
│   │   ├── resnet_biggan_deep.py
│   │   ├── resnet_biggan_deep_test.py
│   │   ├── resnet_biggan_test.py
│   │   ├── resnet_cifar.py
│   │   ├── resnet_init_test.py
│   │   ├── resnet_norm_test.py
│   │   ├── resnet_ops.py
│   │   ├── resnet_stl.py
│   │   └── sndcgan.py
│   ├── datasets.py
│   ├── datasets_test.py
│   ├── eval_gan_lib.py
│   ├── eval_gan_lib_test.py
│   ├── eval_utils.py
│   ├── gans
│   │   ├── __init__.py
│   │   ├── abstract_gan.py
│   │   ├── consts.py
│   │   ├── loss_lib.py
│   │   ├── modular_gan.py
│   │   ├── modular_gan_conditional_test.py
│   │   ├── modular_gan_test.py
│   │   ├── modular_gan_tpu_test.py
│   │   ├── ops.py
│   │   ├── penalty_lib.py
│   │   ├── s3gan.py
│   │   ├── s3gan_test.py
│   │   ├── ssgan.py
│   │   ├── ssgan_test.py
│   │   └── utils.py
│   ├── hooks.py
│   ├── main.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── accuracy.py
│   │   ├── eval_task.py
│   │   ├── fid_score.py
│   │   ├── fid_score_test.py
│   │   ├── fractal_dimension.py
│   │   ├── fractal_dimension_test.py
│   │   ├── gilbo.py
│   │   ├── image_similarity.py
│   │   ├── inception_score.py
│   │   ├── jacobian_conditioning.py
│   │   ├── jacobian_conditioning_test.py
│   │   ├── kid_score.py
│   │   ├── ms_ssim_score.py
│   │   ├── ms_ssim_score_test.py
│   │   ├── prd_score.py
│   │   └── prd_score_test.py
│   ├── runner_lib.py
│   ├── runner_lib_test.py
│   ├── test_utils.py
│   ├── tpu
│   │   ├── __init__.py
│   │   ├── tpu_ops.py
│   │   ├── tpu_ops_test.py
│   │   ├── tpu_random.py
│   │   ├── tpu_random_test.py
│   │   └── tpu_summaries.py
│   └── utils.py
├── example_configs
│   ├── README.md
│   ├── biggan_imagenet128.gin
│   ├── dcgan_celeba64.gin
│   ├── resnet_cifar10.gin
│   ├── resnet_lsun-bedroom128.gin
│   └── sndcgan_celebahq128.gin
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled python modules.
2 | *.pyc
3 |
4 | # Byte-compiled
5 | __pycache__/
6 | .cache/
7 |
8 | # Python egg metadata, regenerated from source files by setuptools.
9 | /*.egg-info
10 | .eggs/
11 |
12 | # PyPI distribution artifacts.
13 | build/
14 | dist/
15 |
16 | # Sublime project files
17 | *.sublime-project
18 | *.sublime-workspace
19 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | # This is the list of authors for copyright purposes.
2 | #
3 | # This does not necessarily list everyone who has contributed code, since in
4 | # some cases, their employer may be the copyright holder. To see the full list
5 | # of contributors, see the revision history in source control.
6 |
7 | Google LLC
8 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | # Issues
4 |
5 | * Please tag your issue with `bug`, `feature request`, or `question` to help us
6 | effectively respond.
7 | * Please include the versions of TensorFlow and related packages you are
8 | running (run `pip list | grep tensor`).
9 | * Please provide the command line you ran as well as the log output.
10 |
11 | # Pull Requests
12 |
13 | We'd love to accept your patches and contributions to this project. There are
14 | just a few small guidelines you need to follow.
15 |
16 | ## Contributor License Agreement
17 |
18 | Contributions to this project must be accompanied by a Contributor License
19 | Agreement. You (or your employer) retain the copyright to your contribution;
20 | this simply gives us permission to use and redistribute your contributions as
21 | part of the project. Head over to to see
22 | your current agreements on file or to sign a new one.
23 |
24 | You generally only need to submit a CLA once, so if you've already submitted one
25 | (even if it was for a different project), you probably don't need to do it
26 | again.
27 |
28 | ## Code reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use GitHub pull requests for this purpose. Consult
32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
33 | information on using pull requests.
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Compare GAN
2 |
3 | This repository offers TensorFlow implementations for many components related to
4 | **Generative Adversarial Networks**:
5 |
6 | * losses (such as non-saturating GAN, least-squares GAN, and WGAN),
7 | * penalties (such as the gradient penalty),
8 | * normalization techniques (such as spectral normalization, batch
9 | normalization, and layer normalization),
10 | * neural architectures (BigGAN, ResNet, DCGAN), and
11 | * evaluation metrics (FID score, Inception Score, precision-recall, and KID
12 | score).
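The losses listed above have standard textbook formulations. As a rough illustration, here is a NumPy sketch of the discriminator and generator losses computed from discriminator outputs. The function names are hypothetical, and the exact formulations in `compare_gan/gans/loss_lib.py` (not shown here) may differ, for example in whether least-squares targets are applied to logits or probabilities.

```python
import numpy as np


def softplus(x):
  # log(1 + exp(x)), computed in a numerically stable way.
  return np.logaddexp(0.0, x)


def non_saturating_loss(d_real, d_fake):
  """Non-saturating GAN loss on discriminator logits (standard form)."""
  # D minimizes -log(sigmoid(d_real)) - log(1 - sigmoid(d_fake)).
  d_loss = np.mean(softplus(-d_real)) + np.mean(softplus(d_fake))
  # G minimizes -log(sigmoid(d_fake)) instead of log(1 - sigmoid(d_fake)).
  g_loss = np.mean(softplus(-d_fake))
  return d_loss, g_loss


def least_squares_loss(d_real, d_fake):
  """Least-squares GAN loss with targets 1 (real) and 0 (fake), assumed."""
  d_loss = 0.5 * np.mean((d_real - 1.0) ** 2) + 0.5 * np.mean(d_fake ** 2)
  g_loss = 0.5 * np.mean((d_fake - 1.0) ** 2)
  return d_loss, g_loss


def wasserstein_loss(d_real, d_fake):
  """WGAN loss on unbounded critic outputs."""
  d_loss = np.mean(d_fake) - np.mean(d_real)
  g_loss = -np.mean(d_fake)
  return d_loss, g_loss


d_real = np.array([2.0, 1.0])
d_fake = np.array([-1.0, 0.5])
for loss_fn in (non_saturating_loss, least_squares_loss, wasserstein_loss):
  print(loss_fn.__name__, loss_fn(d_real, d_fake))
```

The non-saturating generator loss is the usual choice because it keeps gradients large early in training, when the discriminator easily rejects generated samples.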
13 |
14 | The code is **configurable via [Gin](https://github.com/google/gin-config)** and
15 | runs on **GPU/TPU/CPUs**. Several research papers make use of this repository,
16 | including:
17 |
18 | 1. [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337)
19 | [[Code]](https://github.com/google/compare_gan/tree/v1)
20 | \
21 | Mario Lucic*, Karol Kurach*, Marcin Michalski, Sylvain Gelly, Olivier
22 | Bousquet **[NeurIPS 2018]**
23 |
24 | 2. [The GAN Landscape: Losses, Architectures, Regularization, and Normalization](https://arxiv.org/abs/1807.04720)
25 | [[Code]](https://github.com/google/compare_gan/tree/v2)
26 | [[Colab]](https://colab.research.google.com/github/google/compare_gan/blob/v2/compare_gan/src/tfhub_models.ipynb)
27 | \
28 | Karol Kurach*, Mario Lucic*, Xiaohua Zhai, Marcin Michalski, Sylvain Gelly
29 | **[ICML 2019]**
30 |
31 | 3. [Assessing Generative Models via Precision and Recall](https://arxiv.org/abs/1806.00035)
32 | [[Code]](https://github.com/google/compare_gan/blob/560697ee213f91048c6b4231ab79fcdd9bf20381/compare_gan/src/prd_score.py)
33 | \
34 | Mehdi S. M. Sajjadi, Olivier Bachem, Mario Lucic, Olivier Bousquet, Sylvain
35 | Gelly **[NeurIPS 2018]**
36 |
37 | 4. [GILBO: One Metric to Measure Them All](https://arxiv.org/abs/1802.04874)
38 | [[Code]](https://github.com/google/compare_gan/blob/560697ee213f91048c6b4231ab79fcdd9bf20381/compare_gan/src/gilbo.py)
39 | \
40 | Alexander A. Alemi, Ian Fischer **[NeurIPS 2018]**
41 |
42 | 5. [A Case for Object Compositionality in Deep Generative Models of Images](https://arxiv.org/abs/1810.10340)
43 | [[Code]](https://github.com/google/compare_gan/tree/v2_multigan)
44 | \
45 | Sjoerd van Steenkiste, Karol Kurach, Sylvain Gelly **[2018]**
46 |
47 | 6. [On Self Modulation for Generative Adversarial Networks](https://arxiv.org/abs/1810.01365)
48 | [[Code]](https://github.com/google/compare_gan) \
49 | Ting Chen, Mario Lucic, Neil Houlsby, Sylvain Gelly **[ICLR 2019]**
50 |
51 | 7. [Self-Supervised GANs via Auxiliary Rotation Loss](https://arxiv.org/abs/1811.11212)
52 | [[Code]](https://github.com/google/compare_gan)
53 | [[Colab]](https://colab.research.google.com/github/google/compare_gan/blob/v3/colabs/ssgan_demo.ipynb)
54 | \
55 | Ting Chen, Xiaohua Zhai, Marvin Ritter, Mario Lucic, Neil Houlsby **[CVPR
56 | 2019]**
57 |
58 | 8. [High-Fidelity Image Generation With Fewer Labels](https://arxiv.org/abs/1903.02271)
59 | [[Code]](https://github.com/google/compare_gan)
60 | [[Blog Post]](https://ai.googleblog.com/2019/03/reducing-need-for-labeled-data-in.html)
61 | [[Colab]](https://colab.research.google.com/github/google/compare_gan/blob/v3/colabs/s3gan_demo.ipynb)
62 | \
63 | Mario Lucic*, Michael Tschannen*, Marvin Ritter*, Xiaohua Zhai, Olivier
64 | Bachem, Sylvain Gelly **[ICML 2019]**
65 |
66 | ## Installation
67 |
68 | To install the library and all necessary dependencies, run `pip install -e .`
69 | from the `compare_gan/` folder (the repository root, which contains `setup.py`).
70 |
71 | ## Running experiments
72 |
73 | Run `main.py`, passing a `--model_dir` (where checkpoints are stored) and a
74 | `--gin_config` (which defines the model, the dataset it is trained on, and other
75 | training options). Several example configurations are provided in the
76 | `example_configs/` folder:
77 |
78 | * **dcgan_celeba64**: DCGAN architecture with non-saturating loss on CelebA
79 | 64x64px
80 | * **resnet_cifar10**: ResNet architecture with non-saturating loss and
81 | spectral normalization on CIFAR-10
82 | * **resnet_lsun-bedroom128**: ResNet architecture with WGAN loss and gradient
83 | penalty on LSUN-bedrooms 128x128px
84 | * **sndcgan_celebahq128**: SN-DCGAN architecture with non-saturating loss and
85 | spectral normalization on CelebA-HQ 128x128px
86 | * **biggan_imagenet128**: BigGAN architecture with hinge loss and spectral
87 | normalization on ImageNet 128x128px
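For scripted runs, the invocation can be assembled programmatically. A minimal sketch, assuming it is run from the repository root so that `main.py` lives at `compare_gan/main.py` and the configs in `example_configs/` (as in the tree at the top). `build_main_command` is a hypothetical helper, and it only uses flags named in this README.

```python
import shlex


def build_main_command(model_dir, gin_config, schedule="train",
                       eval_every_steps=None,
                       main_path="compare_gan/main.py"):
  """Returns the argv list for one run of main.py (hypothetical helper)."""
  argv = ["python", main_path,
          "--model_dir=" + model_dir,
          "--gin_config=" + gin_config,
          "--schedule=" + schedule]
  if eval_every_steps is not None:
    argv.append("--eval_every_steps=%d" % eval_every_steps)
  return argv


cmd = build_main_command("/tmp/resnet_cifar10",
                         "example_configs/resnet_cifar10.gin")
# Print a shell-ready string; pass `cmd` to subprocess.run to execute it.
print(" ".join(shlex.quote(arg) for arg in cmd))
```

Returning an argv list rather than a single string avoids shell-quoting problems when model directories contain spaces.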
88 |
89 | ### Training and evaluation
90 |
91 | To see all available options, run `python main.py --help`. The main options are:
92 |
93 | * To **train** the model use `--schedule=train` (default). Training is resumed
94 | from the last saved checkpoint.
95 | * To **evaluate** all checkpoints, use `--schedule=continuous_eval
96 | --eval_every_steps=0`. To evaluate only checkpoints whose step number is
97 | divisible by 5000, use `--schedule=continuous_eval --eval_every_steps=5000`.
98 | By default, 3 averaging runs are used to estimate the Inception Score and
99 | the FID score. Keep in mind that when running locally on a single GPU it may
100 | not be possible to run training and evaluation simultaneously due to memory
101 | constraints.
102 | * To **train and evaluate** the model use `--schedule=eval_after_train
103 | --eval_every_steps=0`.
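The checkpoint-selection rule for `--eval_every_steps` can be expressed as a small predicate. This is a hypothetical sketch of the behavior stated above (0 means every checkpoint, N > 0 means steps divisible by N), not Compare GAN's own implementation.

```python
def should_evaluate(step, eval_every_steps):
  """Returns True if the checkpoint saved at `step` should be evaluated."""
  if eval_every_steps < 0:
    raise ValueError("eval_every_steps must be >= 0.")
  if eval_every_steps == 0:
    return True  # 0 means: evaluate every checkpoint.
  return step % eval_every_steps == 0


checkpoint_steps = [2500, 5000, 7500, 10000]
print([s for s in checkpoint_steps if should_evaluate(s, 5000)])  # [5000, 10000]
print([s for s in checkpoint_steps if should_evaluate(s, 0)])  # all four steps
```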
104 |
105 | ### Training on Cloud TPUs
106 |
107 | We recommend using the
108 | [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu) to create
109 | a Cloud TPU and the corresponding Compute Engine VM. We use a v3-128 slice of
110 | a Cloud TPU v3 Pod to train models on ImageNet at 128x128 resolution. You can
111 | use smaller slices if you reduce the batch size (`options.batch_size` in the
112 | Gin config) or the number of model parameters; keep in mind that the model
113 | quality might change. Before training, make sure that the environment variable `TPU_NAME` is set. Running
114 | evaluation on TPUs is currently not supported. Use a VM with a single GPU
115 | instead.
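Two of the points above lend themselves to a pre-flight check before launching training. The sketch below assumes that the global batch size is split evenly across TPU cores (data parallelism; a v3-128 slice has 128 cores) and only reads `TPU_NAME` from the environment. Both helpers are hypothetical and are not part of Compare GAN.

```python
import os


def per_core_batch_size(global_batch_size, num_cores):
  """Batch size each core sees if the batch is sharded evenly (assumed)."""
  if global_batch_size % num_cores != 0:
    raise ValueError("Batch size %d is not divisible by %d cores."
                     % (global_batch_size, num_cores))
  return global_batch_size // num_cores


def require_tpu_name():
  """Returns $TPU_NAME, or fails early with a readable message."""
  tpu_name = os.environ.get("TPU_NAME")
  if not tpu_name:
    raise RuntimeError("Set the TPU_NAME environment variable before training.")
  return tpu_name


# Keeping the per-core batch fixed when moving from a v3-128 to a v3-32 slice
# (hypothetical numbers) means shrinking the global batch size by 4x.
print(per_core_batch_size(1024, 128))  # 8
print(per_core_batch_size(256, 32))  # 8
```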
116 |
117 | ### Datasets
118 |
119 | Compare GAN uses [TensorFlow Datasets](https://www.tensorflow.org/datasets) and
120 | it automatically downloads and prepares the data. For ImageNet, you need to
121 | download the archive yourself. For CelebA-HQ, you need to download and prepare
122 | the images on your own. If you are using TPUs, point the training script to
123 | your Google Cloud Storage bucket (`--tfds_data_dir`).
124 |
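On Cloud TPUs the input pipeline reads from Cloud Storage rather than from the VM's local disk, so a local `--tfds_data_dir` is an easy mistake to make. A hypothetical check, assuming the bucket is given as a `gs://` URL (`my-bucket` is a placeholder):

```python
def validate_tfds_data_dir(data_dir, use_tpu):
  """Fails early if a TPU run points TFDS at a non-GCS directory."""
  if use_tpu and not data_dir.startswith("gs://"):
    raise ValueError(
        "On TPUs, --tfds_data_dir should be a gs:// bucket path, got %r."
        % data_dir)
  return data_dir


print(validate_tfds_data_dir("gs://my-bucket/tensorflow_datasets", use_tpu=True))
print(validate_tfds_data_dir("/tmp/tensorflow_datasets", use_tpu=False))
```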
--------------------------------------------------------------------------------
/compare_gan/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # coding=utf-8
17 |
--------------------------------------------------------------------------------
/compare_gan/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # coding=utf-8
17 |
--------------------------------------------------------------------------------
/compare_gan/architectures/abstract_arch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines interfaces for generator and discriminator networks."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import abc
23 | from compare_gan import utils
24 | import gin
25 | import six
26 | import tensorflow as tf
27 |
28 |
29 | @six.add_metaclass(abc.ABCMeta)
30 | class _Module(object):
31 |   """Base class for architectures.
32 |
33 |   Long term this will be replaced by `tf.Module` in TF 2.0.
34 |   """
35 |
36 |   def __init__(self, name):
37 |     self._name = name
38 |
39 |   @property
40 |   def name(self):
41 |     return self._name
42 |
43 |   @property
44 |   def trainable_variables(self):
45 |     return [var for var in tf.trainable_variables() if self._name in var.name]
46 |
47 |
48 | @gin.configurable("G", blacklist=["name", "image_shape"])
49 | class AbstractGenerator(_Module):
50 |   """Interface for generator architectures."""
51 |
52 |   def __init__(self,
53 |                name="generator",
54 |                image_shape=None,
55 |                batch_norm_fn=None,
56 |                spectral_norm=False):
57 |     """Constructor for all generator architectures.
58 |
59 |     Args:
60 |       name: Scope name of the generator.
61 |       image_shape: Image shape to be generated, [height, width, colors].
62 |       batch_norm_fn: Function for batch normalization or None.
63 |       spectral_norm: If True use spectral normalization for all weights.
64 |     """
65 |     super(AbstractGenerator, self).__init__(name=name)
66 |     self._name = name
67 |     self._image_shape = image_shape
68 |     self._batch_norm_fn = batch_norm_fn
69 |     self._spectral_norm = spectral_norm
70 |
71 |   def __call__(self, z, y, is_training, reuse=tf.AUTO_REUSE):
72 |     with tf.variable_scope(self.name, values=[z, y], reuse=reuse):
73 |       outputs = self.apply(z=z, y=y, is_training=is_training)
74 |     return outputs
75 |
76 |   def batch_norm(self, inputs, **kwargs):
77 |     if self._batch_norm_fn is None:
78 |       return inputs
79 |     args = kwargs.copy()
80 |     args["inputs"] = inputs
81 |     if "use_sn" not in args:
82 |       args["use_sn"] = self._spectral_norm
83 |     return utils.call_with_accepted_args(self._batch_norm_fn, **args)
84 |
85 |   @abc.abstractmethod
86 |   def apply(self, z, y, is_training):
87 |     """Apply the generator to an input.
88 |
89 |     Args:
90 |       z: `Tensor` of shape [batch_size, z_dim] with latent code.
91 |       y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
92 |         labels.
93 |       is_training: Boolean, whether the architecture should be constructed for
94 |         training or inference.
95 |
96 |     Returns:
97 |       Generated images of shape [batch_size] + image_shape.
98 |     """
99 |
100 |
101 | @gin.configurable("D", blacklist=["name"])
102 | class AbstractDiscriminator(_Module):
103 |   """Interface for discriminator architectures."""
104 |
105 |   def __init__(self,
106 |                name="discriminator",
107 |                batch_norm_fn=None,
108 |                layer_norm=False,
109 |                spectral_norm=False):
110 |     super(AbstractDiscriminator, self).__init__(name=name)
111 |     self._name = name
112 |     self._batch_norm_fn = batch_norm_fn
113 |     self._layer_norm = layer_norm
114 |     self._spectral_norm = spectral_norm
115 |
116 |   def __call__(self, x, y, is_training, reuse=tf.AUTO_REUSE):
117 |     with tf.variable_scope(self.name, values=[x, y], reuse=reuse):
118 |       outputs = self.apply(x=x, y=y, is_training=is_training)
119 |     return outputs
120 |
121 |   def batch_norm(self, inputs, **kwargs):
122 |     if self._batch_norm_fn is None:
123 |       return inputs
124 |     args = kwargs.copy()
125 |     args["inputs"] = inputs
126 |     if "use_sn" not in args:
127 |       args["use_sn"] = self._spectral_norm
128 |     return utils.call_with_accepted_args(self._batch_norm_fn, **args)
129 |
130 |
131 |   @abc.abstractmethod
132 |   def apply(self, x, y, is_training):
133 |     """Apply the discriminator to an input.
134 |
135 |     Args:
136 |       x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
137 |       y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
138 |         labels.
139 |       is_training: Boolean, whether the architecture should be constructed for
140 |         training or inference.
141 |
142 |     Returns:
143 |       Tuple of 3 Tensors: the final prediction of the discriminator, the logits
144 |       before the final output activation function, and the logits from the
145 |       second-to-last layer.
146 |     """
147 |
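Both `batch_norm` methods above add a default `use_sn` argument and then delegate to `utils.call_with_accepted_args`, whose source is not part of this section. A minimal pure-Python sketch of that dispatch, assuming the helper simply drops keyword arguments that the callee does not declare. `dispatch_batch_norm`, `plain_norm` and `sn_aware_norm` are hypothetical names used only for illustration.

```python
import inspect


def call_with_accepted_args(fn, **kwargs):
  """Calls `fn` with only the keyword arguments its signature declares."""
  params = inspect.signature(fn).parameters
  # A function with **kwargs can accept everything, so pass it all through.
  if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
    return fn(**kwargs)
  return fn(**{k: v for k, v in kwargs.items() if k in params})


def dispatch_batch_norm(batch_norm_fn, inputs, spectral_norm, **kwargs):
  """Mirrors the `batch_norm` methods above (a sketch, not the real code)."""
  if batch_norm_fn is None:
    return inputs  # No normalization configured: identity.
  args = dict(kwargs, inputs=inputs)
  args.setdefault("use_sn", spectral_norm)  # Caller's use_sn wins if given.
  return call_with_accepted_args(batch_norm_fn, **args)


def plain_norm(inputs, is_training):
  # Hypothetical normalizer that knows nothing about spectral norm.
  return ("plain", inputs, is_training)


def sn_aware_norm(inputs, is_training, use_sn):
  # Hypothetical normalizer that accepts the `use_sn` flag.
  return ("sn_aware", inputs, is_training, use_sn)


print(dispatch_batch_norm(plain_norm, 1.0, True, is_training=False))
print(dispatch_batch_norm(sn_aware_norm, 1.0, True, is_training=True))
```

This filtering lets a single Gin-configured `batch_norm_fn` be swapped between implementations with different signatures without changing the architecture code.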
--------------------------------------------------------------------------------
/compare_gan/architectures/arch_ops_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for custom architecture operations."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan.architectures import arch_ops
23 | import numpy as np
24 | import tensorflow as tf
25 |
26 |
27 | class ArchOpsTest(tf.test.TestCase):
28 |
29 |   def testBatchNorm(self):
30 |     with tf.Graph().as_default():
31 |       # 4 images with resolution 2x1 and 3 channels.
32 |       x1 = tf.constant([[[5, 7, 2]], [[5, 8, 8]]], dtype=tf.float32)
33 |       x2 = tf.constant([[[1, 2, 0]], [[4, 0, 4]]], dtype=tf.float32)
34 |       x3 = tf.constant([[[6, 2, 6]], [[5, 0, 5]]], dtype=tf.float32)
35 |       x4 = tf.constant([[[2, 4, 2]], [[6, 4, 1]]], dtype=tf.float32)
36 |       x = tf.stack([x1, x2, x3, x4])
37 |       self.assertAllEqual(x.shape.as_list(), [4, 2, 1, 3])
38 |
39 |       core_bn = tf.layers.batch_normalization(x, training=True)
40 |       contrib_bn = tf.contrib.layers.batch_norm(x, is_training=True)
41 |       custom_bn = arch_ops.batch_norm(x, is_training=True)
42 |       with self.session() as sess:
43 |         sess.run(tf.global_variables_initializer())
44 |         core_bn, contrib_bn, custom_bn = sess.run(
45 |             [core_bn, contrib_bn, custom_bn])
46 |         tf.logging.info("core_bn: %s", core_bn[0])
47 |         tf.logging.info("contrib_bn: %s", contrib_bn[0])
48 |         tf.logging.info("custom_bn: %s", custom_bn[0])
49 |         self.assertAllClose(core_bn, contrib_bn)
50 |         self.assertAllClose(custom_bn, contrib_bn)
51 |         expected_values = np.asarray(
52 |             [[[[0.4375205, 1.30336881, -0.58830315]],
53 |               [[0.4375205, 1.66291881, 1.76490951]]],
54 |              [[[-1.89592218, -0.49438119, -1.37270737]],
55 |               [[-0.14584017, -1.21348119, 0.19610107]]],
56 |              [[[1.02088118, -0.49438119, 0.98050523]],
57 |               [[0.4375205, -1.21348119, 0.58830321]]],
58 |              [[[-1.31256151, 0.22471881, -0.58830315]],
59 |               [[1.02088118, 0.22471881, -0.98050523]]]],
60 |             dtype=np.float32)
61 |         self.assertAllClose(custom_bn, expected_values)
62 |
63 |   def testAccumulatedMomentsDuringTraining(self):
64 |     with tf.Graph().as_default():
65 |       mean_in = tf.placeholder(tf.float32, shape=[2])
66 |       variance_in = tf.placeholder(tf.float32, shape=[2])
67 |       mean, variance = arch_ops._accumulated_moments_for_inference(
68 |           mean=mean_in, variance=variance_in, is_training=True)
69 |       variables_by_name = {v.op.name: v for v in tf.global_variables()}
70 |       tf.logging.error(variables_by_name)
71 |       accu_mean = variables_by_name["accu/accu_mean"]
72 |       accu_variance = variables_by_name["accu/accu_variance"]
73 |       accu_counter = variables_by_name["accu/accu_counter"]
74 |       with self.session() as sess:
75 |         sess.run(tf.global_variables_initializer())
76 |         m1, v1 = sess.run(
77 |             [mean, variance],
78 |             feed_dict={mean_in: [1.0, 2.0], variance_in: [3.0, 4.0]})
79 |         self.assertAllClose(m1, [1.0, 2.0])
80 |         self.assertAllClose(v1, [3.0, 4.0])
81 |         m2, v2 = sess.run(
82 |             [mean, variance],
83 |             feed_dict={mean_in: [5.0, 6.0], variance_in: [7.0, 8.0]})
84 |         self.assertAllClose(m2, [5.0, 6.0])
85 |         self.assertAllClose(v2, [7.0, 8.0])
86 |         am, av, ac = sess.run([accu_mean, accu_variance, accu_counter])
87 |         self.assertAllClose(am, [0.0, 0.0])
88 |         self.assertAllClose(av, [0.0, 0.0])
89 |         self.assertAllClose([ac], [0.0])
90 |
91 |   def testAccumulatedMomentsDuringEval(self):
92 |     with tf.Graph().as_default():
93 |       mean_in = tf.placeholder(tf.float32, shape=[2])
94 |       variance_in = tf.placeholder(tf.float32, shape=[2])
95 |       mean, variance = arch_ops._accumulated_moments_for_inference(
96 |           mean=mean_in, variance=variance_in, is_training=False)
97 |       variables_by_name = {v.op.name: v for v in tf.global_variables()}
98 |       tf.logging.error(variables_by_name)
99 |       accu_mean = variables_by_name["accu/accu_mean"]
100 |       accu_variance = variables_by_name["accu/accu_variance"]
101 |       accu_counter = variables_by_name["accu/accu_counter"]
102 |       update_accus = variables_by_name["accu/update_accus"]
103 |       with self.session() as sess:
104 |         sess.run(tf.global_variables_initializer())
105 |         # Fill accumulators.
106 |         sess.run(tf.assign(update_accus, 1))
107 |         m1, v1 = sess.run(
108 |             [mean, variance],
109 |             feed_dict={mean_in: [1.0, 2.0], variance_in: [3.0, 4.0]})
110 |         self.assertAllClose(m1, [1.0, 2.0])
111 |         self.assertAllClose(v1, [3.0, 4.0])
112 |         m2, v2 = sess.run(
113 |             [mean, variance],
114 |             feed_dict={mean_in: [5.0, 6.0], variance_in: [7.0, 8.0]})
115 |         self.assertAllClose(m2, [3.0, 4.0])
116 |         self.assertAllClose(v2, [5.0, 6.0])
117 |         # Check accumulators.
118 |         am, av, ac = sess.run([accu_mean, accu_variance, accu_counter])
119 |         self.assertAllClose(am, [6.0, 8.0])
120 |         self.assertAllClose(av, [10.0, 12.0])
121 |         self.assertAllClose([ac], [2.0])
122 |         # Use accumulators.
123 |         sess.run(tf.assign(update_accus, 0))
124 |         m3, v3 = sess.run(
125 |             [mean, variance],
126 |             feed_dict={mean_in: [2.0, 2.0], variance_in: [3.0, 3.0]})
127 |         self.assertAllClose(m3, [3.0, 4.0])
128 |         self.assertAllClose(v3, [5.0, 6.0])
129 |         am, av, ac = sess.run([accu_mean, accu_variance, accu_counter])
130 |         self.assertAllClose(am, [6.0, 8.0])
131 |         self.assertAllClose(av, [10.0, 12.0])
132 |         self.assertAllClose([ac], [2.0])
133 |
134 |
135 | if __name__ == "__main__":
136 |   tf.test.main()
137 |
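The eval test above pins down a simple state machine. While `update_accus` is 1, each evaluated batch adds its mean and variance to the accumulators, and the output is their running average. Once `update_accus` is 0, the stored averages are reused regardless of the input. In training mode, the batch moments pass straight through. The NumPy sketch below models that behavior, inferring it from the assertions rather than from the `arch_ops` source, which is not shown here. The class name is hypothetical, and the zero-counter guard is an addition specific to this sketch.

```python
import numpy as np


class MomentAccumulator(object):
  """Pure-Python model of the accumulated-moments behavior tested above."""

  def __init__(self, num_features):
    self.accu_mean = np.zeros(num_features)
    self.accu_variance = np.zeros(num_features)
    self.accu_counter = 0.0
    self.update_accus = False  # Plays the role of `accu/update_accus`.

  def __call__(self, mean, variance, is_training):
    mean = np.asarray(mean, dtype=np.float64)
    variance = np.asarray(variance, dtype=np.float64)
    if is_training:
      # Training uses the batch moments and leaves the accumulators untouched.
      return mean, variance
    if self.update_accus:
      # Sum the moments of each evaluated batch and count the batches.
      self.accu_mean += mean
      self.accu_variance += variance
      self.accu_counter += 1
    if self.accu_counter == 0:
      raise ValueError("No moments accumulated yet.")
    # Inference uses the average of all accumulated batch moments.
    return (self.accu_mean / self.accu_counter,
            self.accu_variance / self.accu_counter)


acc = MomentAccumulator(num_features=2)
acc.update_accus = True
acc([1.0, 2.0], [3.0, 4.0], is_training=False)
print(acc([5.0, 6.0], [7.0, 8.0], is_training=False))  # means [3, 4], variances [5, 6]
acc.update_accus = False
print(acc([2.0, 2.0], [3.0, 3.0], is_training=False))  # still [3, 4] and [5, 6]
```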
--------------------------------------------------------------------------------
/compare_gan/architectures/arch_ops_tpu_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for custom architecture operations on TPUs."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import logging
23 | from compare_gan.architectures import arch_ops
24 | import gin
25 | import numpy as np
26 | import tensorflow as tf
27 |
28 |
29 | class ArchOpsTpuTest(tf.test.TestCase):
30 |
31 |   def setUp(self):
32 |     # Construct input for batch norm tests:
33 |     # 4 images with resolution 2x1 and 3 channels.
34 |     x1 = np.asarray([[[5, 7, 2]], [[5, 8, 8]]], dtype=np.float32)
35 |     x2 = np.asarray([[[1, 2, 0]], [[4, 0, 4]]], dtype=np.float32)
36 |     x3 = np.asarray([[[6, 2, 6]], [[5, 0, 5]]], dtype=np.float32)
37 |     x4 = np.asarray([[[2, 4, 2]], [[6, 4, 1]]], dtype=np.float32)
38 |     self._inputs = np.stack([x1, x2, x3, x4])
39 |     self.assertAllEqual(self._inputs.shape, [4, 2, 1, 3])
40 |     # And the expected output for applying batch norm (without additional
41 |     # scaling/shifting).
42 |     self._expected_outputs = np.asarray(
43 |         [[[[0.4375205, 1.30336881, -0.58830315]],
44 |           [[0.4375205, 1.66291881, 1.76490951]]],
45 |          [[[-1.89592218, -0.49438119, -1.37270737]],
46 |           [[-0.14584017, -1.21348119, 0.19610107]]],
47 |          [[[1.02088118, -0.49438119, 0.98050523]],
48 |           [[0.4375205, -1.21348119, 0.58830321]]],
49 |          [[[-1.31256151, 0.22471881, -0.58830315]],
50 |           [[1.02088118, 0.22471881, -0.98050523]]]],
51 |         dtype=np.float32)
52 |     self.assertAllEqual(self._expected_outputs.shape, [4, 2, 1, 3])
53 |
54 |   def testRunsOnTpu(self):
55 |     """Verify that the test case runs on a TPU chip with 2 cores."""
56 |     expected_device_names = [
57 |         "/job:localhost/replica:0/task:0/device:CPU:0",
58 |         "/job:localhost/replica:0/task:0/device:TPU:0",
59 |         "/job:localhost/replica:0/task:0/device:TPU:1",
60 |         "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
61 |     ]
62 |     with self.session() as sess:
63 |       devices = sess.list_devices()
64 |       tf.logging.info("devices:\n%s", "\n".join([str(d) for d in devices]))
65 |       self.assertAllEqual([d.name for d in devices], expected_device_names)
66 |
67 |   def testBatchNormOneCore(self):
68 |     def computation(x):
69 |       core_bn = tf.layers.batch_normalization(x, training=True)
70 |       contrib_bn = tf.contrib.layers.batch_norm(x, is_training=True)
71 |       custom_bn = arch_ops.batch_norm(x, is_training=True)
72 |       tf.logging.info("custom_bn tensor: %s", custom_bn)
73 |       return core_bn, contrib_bn, custom_bn
74 |
75 |     with tf.Graph().as_default():
76 |       x = tf.constant(self._inputs)
77 |       core_bn, contrib_bn, custom_bn = tf.contrib.tpu.batch_parallel(
78 |           computation, [x], num_shards=1)
79 |
80 |       with self.session() as sess:
81 |         sess.run(tf.contrib.tpu.initialize_system())
82 |         sess.run(tf.global_variables_initializer())
83 |         core_bn, contrib_bn, custom_bn = sess.run(
84 |             [core_bn, contrib_bn, custom_bn])
85 |         logging.info("core_bn: %s", core_bn)
86 |         logging.info("contrib_bn: %s", contrib_bn)
87 |         logging.info("custom_bn: %s", custom_bn)
88 |         self.assertAllClose(core_bn, self._expected_outputs)
89 |         self.assertAllClose(contrib_bn, self._expected_outputs)
90 |         self.assertAllClose(custom_bn, self._expected_outputs)
91 |
92 |   def testBatchNormTwoCoresCoreAndContrib(self):
93 |     def computation(x):
94 |       core_bn = tf.layers.batch_normalization(x, training=True)
95 |       contrib_bn = tf.contrib.layers.batch_norm(x, is_training=True)
96 |       return core_bn, contrib_bn
97 |
98 |     with tf.Graph().as_default():
99 |       x = tf.constant(self._inputs)
100 |       core_bn, contrib_bn = tf.contrib.tpu.batch_parallel(
101 |           computation, [x], num_shards=2)
102 |
103 |       with self.session() as sess:
104 |         sess.run(tf.contrib.tpu.initialize_system())
105 |         sess.run(tf.global_variables_initializer())
106 |         core_bn, contrib_bn = sess.run([core_bn, contrib_bn])
107 |         logging.info("core_bn: %s", core_bn)
108 |         logging.info("contrib_bn: %s", contrib_bn)
109 |         self.assertNotAllClose(core_bn, self._expected_outputs)
110 |         self.assertNotAllClose(contrib_bn, self._expected_outputs)
111 |
112 |   def testBatchNormTwoCoresCustom(self):
113 |     def computation(x):
114 |       custom_bn = arch_ops.batch_norm(x, is_training=True, name="custom_bn")
115 |       gin.bind_parameter("cross_replica_moments.parallel", False)
116 |       custom_bn_seq = arch_ops.batch_norm(x, is_training=True,
117 |                                           name="custom_bn_seq")
118 |       return custom_bn, custom_bn_seq
119 |
120 |     with tf.Graph().as_default():
121 |       x = tf.constant(self._inputs)
122 |       custom_bn, custom_bn_seq = tf.contrib.tpu.batch_parallel(
123 |           computation, [x], num_shards=2)
124 |
125 |       with self.session() as sess:
126 |         sess.run(tf.contrib.tpu.initialize_system())
127 |         sess.run(tf.global_variables_initializer())
128 |         custom_bn, custom_bn_seq = sess.run(
129 |             [custom_bn, custom_bn_seq])
130 |         logging.info("custom_bn: %s", custom_bn)
131 |         logging.info("custom_bn_seq: %s", custom_bn_seq)
132 |         self.assertAllClose(custom_bn, self._expected_outputs)
133 |         self.assertAllClose(custom_bn_seq, self._expected_outputs)
134 |
135 |
136 | if __name__ == "__main__":
137 |   tf.test.main()
138 |
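With `num_shards=2`, `tf.contrib.tpu.batch_parallel` splits the batch across cores. The core and contrib layers then normalize each shard with that shard's own moments, which is why the test expects them not to match the full-batch result. The custom `arch_ops.batch_norm` is expected to match in both its parallel and sequential (`cross_replica_moments.parallel=False`) variants. The NumPy sketch below shows both behaviors. It assumes an epsilon of 1e-3 (the default of `tf.layers.batch_normalization`) and equal-sized shards. It is a reference computation, not the `cross_replica_moments` implementation, and the function names are hypothetical.

```python
import numpy as np

EPSILON = 1e-3  # Assumed: default epsilon of tf.layers.batch_normalization.


def batch_norm_reference(x, epsilon=EPSILON):
  """Batch norm over batch, height and width (NHWC), no scale or offset."""
  x = np.asarray(x, dtype=np.float64)
  mean = x.mean(axis=(0, 1, 2))
  variance = x.var(axis=(0, 1, 2))  # Biased (population) variance.
  return (x - mean) / np.sqrt(variance + epsilon)


def per_shard_batch_norm(x, num_shards, epsilon=EPSILON):
  """Each shard is normalized with its own moments (no cross-replica sync)."""
  shards = np.split(np.asarray(x, dtype=np.float64), num_shards, axis=0)
  return np.concatenate([batch_norm_reference(s, epsilon) for s in shards])


def cross_replica_batch_norm(x, num_shards, epsilon=EPSILON):
  """All shards share moments built from per-shard E[x] and E[x^2]."""
  shards = np.split(np.asarray(x, dtype=np.float64), num_shards, axis=0)
  # With equal-sized shards, averaging per-shard moments gives global moments.
  mean = np.mean([s.mean(axis=(0, 1, 2)) for s in shards], axis=0)
  mean_sq = np.mean([np.square(s).mean(axis=(0, 1, 2)) for s in shards], axis=0)
  variance = mean_sq - np.square(mean)
  return (np.concatenate(shards) - mean) / np.sqrt(variance + epsilon)


# The same 4 images (2x1 pixels, 3 channels) as in setUp above.
inputs = np.stack([
    np.asarray([[[5, 7, 2]], [[5, 8, 8]]], dtype=np.float32),
    np.asarray([[[1, 2, 0]], [[4, 0, 4]]], dtype=np.float32),
    np.asarray([[[6, 2, 6]], [[5, 0, 5]]], dtype=np.float32),
    np.asarray([[[2, 4, 2]], [[6, 4, 1]]], dtype=np.float32),
])
reference = batch_norm_reference(inputs)
print(np.allclose(cross_replica_batch_norm(inputs, 2), reference))  # True
print(np.allclose(per_shard_batch_norm(inputs, 2), reference))  # False
```

Sharing the first and second moments, rather than per-shard variances, is what makes the combined statistics exact. Averaging per-shard variances alone would ignore the spread between shard means.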
--------------------------------------------------------------------------------
/compare_gan/architectures/architectures_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for neural architectures."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl.testing import parameterized
23 | from compare_gan.architectures import dcgan
24 | from compare_gan.architectures import infogan
25 | from compare_gan.architectures import resnet30
26 | from compare_gan.architectures import resnet5
27 | from compare_gan.architectures import resnet_biggan
28 | from compare_gan.architectures import resnet_cifar
29 | from compare_gan.architectures import resnet_stl
30 | from compare_gan.architectures import sndcgan
31 | import tensorflow as tf
32 |
33 |
34 | class ArchitectureTest(parameterized.TestCase, tf.test.TestCase):
35 |
36 | def assertArchitectureBuilds(self, gen, disc, image_shape, z_dim=120):
37 | with tf.Graph().as_default():
38 | batch_size = 2
39 | num_classes = 10
40 | # Prepare inputs
41 | z = tf.random.normal((batch_size, z_dim), name="z")
42 | y = tf.one_hot(tf.range(batch_size), num_classes)
43 | # Run check output shapes for G and D.
44 | x = gen(z=z, y=y, is_training=True, reuse=False)
45 | self.assertAllEqual(x.shape.as_list()[1:], image_shape)
46 | out, _, _ = disc(
47 | x, y=y, is_training=True, reuse=False)
48 | self.assertAllEqual(out.shape.as_list(), (batch_size, 1))
49 | # Check that G outputs valid pixel values (we use [0, 1] everywhere) and
50 | # D outputs a probability.
51 | with self.session() as sess:
52 | sess.run(tf.global_variables_initializer())
53 | image, pred = sess.run([x, out])
54 | self.assertAllGreaterEqual(image, 0)
55 | self.assertAllLessEqual(image, 1)
56 | self.assertAllGreaterEqual(pred, 0)
57 | self.assertAllLessEqual(pred, 1)
58 |
59 | @parameterized.parameters(
60 | {"image_shape": (28, 28, 1)},
61 | {"image_shape": (32, 32, 1)},
62 | {"image_shape": (32, 32, 3)},
63 | {"image_shape": (64, 64, 3)},
64 | {"image_shape": (128, 128, 3)},
65 | )
66 | def testDcGan(self, image_shape):
67 | self.assertArchitectureBuilds(
68 | gen=dcgan.Generator(image_shape=image_shape),
69 | disc=dcgan.Discriminator(),
70 | image_shape=image_shape)
71 |
72 | @parameterized.parameters(
73 | {"image_shape": (28, 28, 1)},
74 | {"image_shape": (32, 32, 1)},
75 | {"image_shape": (32, 32, 3)},
76 | {"image_shape": (64, 64, 3)},
77 | {"image_shape": (128, 128, 3)},
78 | )
79 | def testInfoGan(self, image_shape):
80 | self.assertArchitectureBuilds(
81 | gen=infogan.Generator(image_shape=image_shape),
82 | disc=infogan.Discriminator(),
83 | image_shape=image_shape)
84 |
85 | def testResNet30(self, image_shape=(128, 128, 3)):
86 | self.assertArchitectureBuilds(
87 | gen=resnet30.Generator(image_shape=image_shape),
88 | disc=resnet30.Discriminator(),
89 | image_shape=image_shape)
90 |
91 | @parameterized.parameters(
92 | {"image_shape": (32, 32, 1)},
93 | {"image_shape": (32, 32, 3)},
94 | {"image_shape": (64, 64, 3)},
95 | {"image_shape": (128, 128, 3)},
96 | )
97 | def testResNet5(self, image_shape):
98 | self.assertArchitectureBuilds(
99 | gen=resnet5.Generator(image_shape=image_shape),
100 | disc=resnet5.Discriminator(),
101 | image_shape=image_shape)
102 |
103 | @parameterized.parameters(
104 | {"image_shape": (32, 32, 3)},
105 | {"image_shape": (64, 64, 3)},
106 | {"image_shape": (128, 128, 3)},
107 | {"image_shape": (256, 256, 3)},
108 | {"image_shape": (512, 512, 3)},
109 | )
110 | def testResNet5BigGan(self, image_shape):
111 | if image_shape[0] == 512:
112 | z_dim = 160
113 | elif image_shape[0] == 256:
114 | z_dim = 140
115 | else:
116 | z_dim = 120
117 | # Use a small channel multiplier (ch=16) to avoid OOM errors.
118 | self.assertArchitectureBuilds(
119 | gen=resnet_biggan.Generator(image_shape=image_shape, ch=16),
120 | disc=resnet_biggan.Discriminator(ch=16),
121 | image_shape=image_shape,
122 | z_dim=z_dim)
123 |
124 | @parameterized.parameters(
125 | {"image_shape": (32, 32, 1)},
126 | {"image_shape": (32, 32, 3)},
127 | )
128 | def testResNetCifar(self, image_shape):
129 | self.assertArchitectureBuilds(
130 | gen=resnet_cifar.Generator(image_shape=image_shape),
131 | disc=resnet_cifar.Discriminator(),
132 | image_shape=image_shape)
133 |
134 | @parameterized.parameters(
135 | {"image_shape": (48, 48, 1)},
136 | {"image_shape": (48, 48, 3)},
137 | )
138 | def testResNetStl(self, image_shape):
139 | self.assertArchitectureBuilds(
140 | gen=resnet_stl.Generator(image_shape=image_shape),
141 | disc=resnet_stl.Discriminator(),
142 | image_shape=image_shape)
143 |
144 | @parameterized.parameters(
145 | {"image_shape": (28, 28, 1)},
146 | {"image_shape": (32, 32, 1)},
147 | {"image_shape": (32, 32, 3)},
148 | {"image_shape": (64, 64, 3)},
149 | {"image_shape": (128, 128, 3)},
150 | )
151 | def testSnDcGan(self, image_shape):
152 | self.assertArchitectureBuilds(
153 | gen=sndcgan.Generator(image_shape=image_shape),
154 | disc=sndcgan.Discriminator(),
155 | image_shape=image_shape)
156 |
157 |
158 | if __name__ == "__main__":
159 | tf.test.main()
160 |
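The range assertions in `assertArchitectureBuilds` hold by construction: the generators in this directory end in either `tf.nn.sigmoid` or `0.5 * tf.nn.tanh(x) + 0.5` (DCGAN), and the discriminators return `tf.nn.sigmoid(out_logit)`. The following is a minimal stdlib sketch (plain `math` standing in for TensorFlow ops) of why both output heads stay in [0, 1], and of the identity that makes the DCGAN head a sigmoid with doubled slope.

```python
import math


def sigmoid(x):
    """Logistic function, the output head used by most generators here."""
    return 1.0 / (1.0 + math.exp(-x))


def rescaled_tanh(x):
    """DCGAN-style output head: tanh mapped from [-1, 1] to [0, 1]."""
    return 0.5 * math.tanh(x) + 0.5


# Both heads map any finite pre-activation into [0, 1], which is what the
# assertAllGreaterEqual(..., 0) / assertAllLessEqual(..., 1) checks rely on.
samples = [-50.0, -3.0, -0.5, 0.0, 0.5, 3.0, 50.0]
for v in samples:
    assert 0.0 <= sigmoid(v) <= 1.0
    assert 0.0 <= rescaled_tanh(v) <= 1.0

# Identity: 0.5 * tanh(x) + 0.5 == sigmoid(2x), up to float rounding.
max_gap = max(abs(rescaled_tanh(v) - sigmoid(2.0 * v)) for v in samples)
print("max |0.5*tanh(x)+0.5 - sigmoid(2x)| =", max_gap)
```

The identity follows from tanh(x) = (1 - e^(-2x)) / (1 + e^(-2x)), so the choice between the two heads mainly affects the effective slope, not the output range.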
--------------------------------------------------------------------------------
/compare_gan/architectures/dcgan.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of DCGAN generator and discriminator architectures.
17 |
18 | Details are available in https://arxiv.org/abs/1511.06434.
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | from compare_gan.architectures import abstract_arch
26 | from compare_gan.architectures.arch_ops import conv2d
27 | from compare_gan.architectures.arch_ops import deconv2d
28 | from compare_gan.architectures.arch_ops import linear
29 | from compare_gan.architectures.arch_ops import lrelu
30 |
31 | import numpy as np
32 | import tensorflow as tf
33 |
34 |
35 | def conv_out_size_same(size, stride):
36 | return int(np.ceil(float(size) / float(stride)))
37 |
38 |
39 | class Generator(abstract_arch.AbstractGenerator):
40 | """DCGAN generator.
41 |
42 | Details are available at https://arxiv.org/abs/1511.06434. Notable choices
43 | include BatchNorm and ReLU activations in all generator layers except the
44 | output, which uses tanh rescaled from [-1, 1] to [0, 1].
45 | """
46 |
47 | def apply(self, z, y, is_training):
48 | """Build the generator network for the given inputs.
49 |
50 | Args:
51 | z: `Tensor` of shape [batch_size, z_dim] with latent code.
52 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
53 | labels.
54 | is_training: boolean, whether we are in train or eval mode.
55 |
56 | Returns:
57 | A tensor of size [batch_size] + self._image_shape with values in [0, 1].
58 | """
59 | gf_dim = 64 # Dimension of filters in first convolutional layer.
60 | bs = z.shape[0].value
61 | s_h, s_w, colors = self._image_shape
62 | s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
63 | s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
64 | s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
65 | s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)
66 |
67 | net = linear(z, gf_dim * 8 * s_h16 * s_w16, scope="g_fc1")
68 | net = tf.reshape(net, [-1, s_h16, s_w16, gf_dim * 8])
69 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn1")
70 | net = tf.nn.relu(net)
71 | net = deconv2d(net, [bs, s_h8, s_w8, gf_dim*4], 5, 5, 2, 2, name="g_dc1")
72 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn2")
73 | net = tf.nn.relu(net)
74 | net = deconv2d(net, [bs, s_h4, s_w4, gf_dim*2], 5, 5, 2, 2, name="g_dc2")
75 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn3")
76 | net = tf.nn.relu(net)
77 | net = deconv2d(net, [bs, s_h2, s_w2, gf_dim*1], 5, 5, 2, 2, name="g_dc3")
78 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn4")
79 | net = tf.nn.relu(net)
80 | net = deconv2d(net, [bs, s_h, s_w, colors], 5, 5, 2, 2, name="g_dc4")
81 | net = 0.5 * tf.nn.tanh(net) + 0.5
82 | return net
83 |
84 |
85 | class Discriminator(abstract_arch.AbstractDiscriminator):
86 | """DCGAN discriminator.
87 |
88 | Details are available at https://arxiv.org/abs/1511.06434. Notable changes
89 | include BatchNorm in the discriminator and LeakyReLU for all layers.
90 | """
91 |
92 | def apply(self, x, y, is_training):
93 | """Apply the discriminator on an input.
94 |
95 | Args:
96 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
97 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
98 | labels.
99 | is_training: Boolean, whether the architecture should be constructed for
100 | training or inference.
101 |
102 | Returns:
103 | Tuple of 3 Tensors: the final prediction of the discriminator, the logits
104 | before the final output activation function, and the activations from the
105 | second-to-last layer.
106 | """
107 | bs = x.shape[0].value
108 | df_dim = 64 # Dimension of filters in the first convolutional layer.
109 | net = lrelu(conv2d(x, df_dim, 5, 5, 2, 2, name="d_conv1",
110 | use_sn=self._spectral_norm))
111 | net = conv2d(net, df_dim * 2, 5, 5, 2, 2, name="d_conv2",
112 | use_sn=self._spectral_norm)
113 |
114 | net = self.batch_norm(net, y=y, is_training=is_training, name="d_bn1")
115 | net = lrelu(net)
116 | net = conv2d(net, df_dim * 4, 5, 5, 2, 2, name="d_conv3",
117 | use_sn=self._spectral_norm)
118 |
119 | net = self.batch_norm(net, y=y, is_training=is_training, name="d_bn2")
120 | net = lrelu(net)
121 | net = conv2d(net, df_dim * 8, 5, 5, 2, 2, name="d_conv4",
122 | use_sn=self._spectral_norm)
123 |
124 | net = self.batch_norm(net, y=y, is_training=is_training, name="d_bn3")
125 | net = lrelu(net)
126 | out_logit = linear(
127 | tf.reshape(net, [bs, -1]), 1, scope="d_fc4", use_sn=self._spectral_norm)
128 | out = tf.nn.sigmoid(out_logit)
129 | return out, out_logit, net
130 |
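`conv_out_size_same` computes `ceil(size / stride)`, the output size of a strided convolution with SAME padding. The DCGAN generator applies it four times to find its seed resolution, and its four stride-2 `deconv2d` layers then walk that chain back up, using the explicit output shapes. The sketch below traces the chain for the resolutions exercised in `architectures_test.py`; it uses `math.ceil` in place of `np.ceil`, and the helper `dcgan_size_chain` is hypothetical, not part of the module.

```python
import math


def conv_out_size_same(size, stride):
    # Same formula as dcgan.conv_out_size_same, with math.ceil for np.ceil.
    return int(math.ceil(float(size) / float(stride)))


def dcgan_size_chain(size, num_layers=4, stride=2):
    """Hypothetical helper: sizes [s, s/2, s/4, s/8, s/16], rounded up."""
    chain = [size]
    for _ in range(num_layers):
        chain.append(conv_out_size_same(chain[-1], stride))
    return chain


gf_dim = 64  # As in Generator.apply: the seed has gf_dim * 8 channels.
for s in (28, 32, 64, 128):
    chain = dcgan_size_chain(s)
    seed = chain[-1]
    fc_units = gf_dim * 8 * seed * seed  # output width of the "g_fc1" layer
    print(s, chain, "g_fc1 units:", fc_units)
```

For a 28x28 image the chain is 28, 14, 7, 4, 2: the odd size 7 is rounded up on the way down, and on the way up the explicit `[bs, s_h4, s_w4, ...]` output shape passed to `deconv2d` recovers it exactly.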
--------------------------------------------------------------------------------
/compare_gan/architectures/infogan.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of InfoGAN generator and discriminator architectures.
17 |
18 | Details are available in https://arxiv.org/pdf/1606.03657.pdf.
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | from compare_gan.architectures import abstract_arch
26 | from compare_gan.architectures.arch_ops import batch_norm
27 | from compare_gan.architectures.arch_ops import conv2d
28 | from compare_gan.architectures.arch_ops import deconv2d
29 | from compare_gan.architectures.arch_ops import linear
30 | from compare_gan.architectures.arch_ops import lrelu
31 |
32 | import tensorflow as tf
33 |
34 |
35 | class Generator(abstract_arch.AbstractGenerator):
36 | """Generator architecture based on InfoGAN."""
37 |
38 | def apply(self, z, y, is_training):
39 | """Build the generator network for the given inputs.
40 |
41 | Args:
42 | z: `Tensor` of shape [batch_size, z_dim] with latent code.
43 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
44 | labels.
45 | is_training: boolean, are we in train or eval model.
46 |
47 | Returns:
48 | A tensor of size [batch_size] + self._image_shape with values in [0, 1].
49 | """
50 | del y
51 | h, w, c = self._image_shape
52 | bs = z.shape.as_list()[0]
53 | net = linear(z, 1024, scope="g_fc1")
54 | net = lrelu(batch_norm(net, is_training=is_training, name="g_bn1"))
55 | net = linear(net, 128 * (h // 4) * (w // 4), scope="g_fc2")
56 | net = lrelu(batch_norm(net, is_training=is_training, name="g_bn2"))
57 | net = tf.reshape(net, [bs, h // 4, w // 4, 128])
58 | net = deconv2d(net, [bs, h // 2, w // 2, 64], 4, 4, 2, 2, name="g_dc3")
59 | net = lrelu(batch_norm(net, is_training=is_training, name="g_bn3"))
60 | net = deconv2d(net, [bs, h, w, c], 4, 4, 2, 2, name="g_dc4")
61 | out = tf.nn.sigmoid(net)
62 | return out
63 |
64 |
65 | class Discriminator(abstract_arch.AbstractDiscriminator):
66 | """Discriminator architecture based on InfoGAN."""
67 |
68 | def apply(self, x, y, is_training):
69 | """Apply the discriminator on an input.
70 |
71 | Args:
72 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
73 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
74 | labels.
75 | is_training: Boolean, whether the architecture should be constructed for
76 | training or inference.
77 |
78 | Returns:
79 | Tuple of 3 Tensors: the final prediction of the discriminator, the logits
80 | before the final output activation function, and the activations from the
81 | second-to-last layer.
82 | """
83 | use_sn = self._spectral_norm
84 | batch_size = x.shape.as_list()[0]
85 | # Resulting shape: [bs, h/2, w/2, 64].
86 | net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name="d_conv1", use_sn=use_sn))
87 | # Resulting shape: [bs, h/4, w/4, 128].
88 | net = conv2d(net, 128, 4, 4, 2, 2, name="d_conv2", use_sn=use_sn)
89 | net = self.batch_norm(net, y=y, is_training=is_training, name="d_bn2")
90 | net = lrelu(net)
91 | # Resulting shape: [bs, h * w * 8].
92 | net = tf.reshape(net, [batch_size, -1])
93 | # Resulting shape: [bs, 1024].
94 | net = linear(net, 1024, scope="d_fc3", use_sn=use_sn)
95 | net = self.batch_norm(net, y=y, is_training=is_training, name="d_bn3")
96 | net = lrelu(net)
97 | # Resulting shape: [bs, 1].
98 | out_logit = linear(net, 1, scope="d_fc4", use_sn=use_sn)
99 | out = tf.nn.sigmoid(out_logit)
100 | return out, out_logit, net
101 |
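The shape comments in the InfoGAN discriminator implicitly assume that each stride-2 `conv2d` halves the spatial size, rounding up (SAME padding; this is an assumption about `arch_ops.conv2d`, not shown here). Under that assumption the flattened width before `d_fc3` is `ceil(h/4) * ceil(w/4) * 128`, which equals the commented `h * w * 8` only when `h` and `w` are multiples of 4. The generator side uses floor division (`h // 4`) for its seed. A small stdlib check, with hypothetical helper names:

```python
import math


def infogan_flat_dim(h, w, channels=128):
    """Hypothetical helper: width of the flattened D features before d_fc3.

    Assumes both stride-2 convolutions use SAME padding (ceil division).
    """
    h4 = math.ceil(math.ceil(h / 2) / 2)
    w4 = math.ceil(math.ceil(w / 2) / 2)
    return h4 * w4 * channels


def infogan_generator_sizes(h, w):
    """Spatial sizes in Generator.apply: seed, after g_dc3, after g_dc4."""
    return [(h // 4, w // 4), (h // 2, w // 2), (h, w)]


for h, w in [(28, 28), (32, 32), (64, 64), (128, 128)]:
    flat = infogan_flat_dim(h, w)
    # The in-code comment "[bs, h * w * 8]" holds when h, w are multiples of 4.
    print((h, w), "flat:", flat, "h*w*8:", h * w * 8,
          "G sizes:", infogan_generator_sizes(h, w))
```

All resolutions in `testInfoGan` are multiples of 4, so the comment is accurate for them; an input such as 30x30 would give 8 * 8 * 128 = 8192 features rather than 30 * 30 * 8 = 7200.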
--------------------------------------------------------------------------------
/compare_gan/architectures/resnet30.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A 30-block resnet.
17 |
18 | It contains 6 "super-blocks" and each such block contains 5 residual blocks in
19 | both the generator and discriminator. It supports 128x128 resolution only.
20 | Details can be found in "Improved Training of Wasserstein GANs", Gulrajani I.
21 | et al. 2017. The related code is available at
22 | https://github.com/igul222/improved_wgan_training/blob/master/gan_64x64.py.
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | from compare_gan.architectures import arch_ops as ops
30 | from compare_gan.architectures import resnet_ops
31 |
32 | from six.moves import range
33 | import tensorflow as tf
34 |
35 |
36 | class Generator(resnet_ops.ResNetGenerator):
37 | """ResNet30 generator, 30 blocks, generates images of resolution 128x128.
38 |
39 | Tries to match the architecture of Gulrajani et al. 2017. The difference
40 | is that there the final resolution is 64x64, while here it is 128x128.
41 | """
42 |
43 | def apply(self, z, y, is_training):
44 | """Build the generator network for the given inputs.
45 |
46 | Args:
47 | z: `Tensor` of shape [batch_size, z_dim] with latent code.
48 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
49 | labels.
50 | is_training: boolean, whether we are in train or eval mode.
51 |
52 | Returns:
53 | A tensor of size [batch_size] + self._image_shape with values in [0, 1].
54 | """
55 | z_shape = z.get_shape().as_list()
56 | if len(z_shape) != 2:
57 | raise ValueError("Expected shape [batch_size, z_dim], got %s." % z_shape)
58 | ch = 64
59 | colors = self._image_shape[2]
60 | # Map noise to the actual seed.
61 | output = ops.linear(z, 4 * 4 * 8 * ch, scope="fc_noise")
62 | # Reshape the seed to be a rank-4 Tensor.
63 | output = tf.reshape(output, [-1, 4, 4, 8 * ch], name="fc_reshaped")
64 | in_channels = 8 * ch
65 | out_channels = 4 * ch
66 | for superblock in range(6):
67 | for i in range(5):
68 | block = self._resnet_block(
69 | name="B_{}_{}".format(superblock, i),
70 | in_channels=in_channels,
71 | out_channels=in_channels,
72 | scale="none")
73 | output = block(output, z=z, y=y, is_training=is_training)
74 | # We want to upscale 5 times.
75 | if superblock < 5:
76 | block = self._resnet_block(
77 | name="B_{}_up".format(superblock),
78 | in_channels=in_channels,
79 | out_channels=out_channels,
80 | scale="up")
81 | output = block(output, z=z, y=y, is_training=is_training)
82 | in_channels //= 2
83 | out_channels //= 2
84 |
85 | output = ops.conv2d(
86 | output, output_dim=colors, k_h=3, k_w=3, d_h=1, d_w=1,
87 | name="final_conv")
88 | output = tf.nn.sigmoid(output)
89 | return output
90 |
91 |
92 | class Discriminator(resnet_ops.ResNetDiscriminator):
93 | """ResNet discriminator, 30 blocks, 128x128x3 and 128x128x1 resolution."""
94 |
95 | def apply(self, x, y, is_training):
96 | """Apply the discriminator on an input.
97 |
98 | Args:
99 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
100 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
101 | labels.
102 | is_training: Boolean, whether the architecture should be constructed for
103 | training or inference.
104 |
105 | Returns:
106 | Tuple of 3 Tensors: the final prediction of the discriminator, the logits
107 | before the final output activation function, and the activations from the
108 | second-to-last layer.
109 | """
110 | resnet_ops.validate_image_inputs(x)
111 | colors = x.get_shape().as_list()[-1]
112 | assert colors in [1, 3]
113 | ch = 64
114 | output = ops.conv2d(
115 | x, output_dim=ch // 4, k_h=3, k_w=3, d_h=1, d_w=1,
116 | name="color_conv")
117 | in_channels = ch // 4
118 | out_channels = ch // 2
119 | for superblock in range(6):
120 | for i in range(5):
121 | block = self._resnet_block(
122 | name="B_{}_{}".format(superblock, i),
123 | in_channels=in_channels,
124 | out_channels=in_channels,
125 | scale="none")
126 | output = block(output, z=None, y=y, is_training=is_training)
127 | # We want to downscale 5 times.
128 | if superblock < 5:
129 | block = self._resnet_block(
130 | name="B_{}_up".format(superblock),
131 | in_channels=in_channels,
132 | out_channels=out_channels,
133 | scale="down")
134 | output = block(output, z=None, y=y, is_training=is_training)
135 | in_channels *= 2
136 | out_channels *= 2
137 |
138 | # Final part
139 | output = tf.reshape(output, [-1, 4 * 4 * 8 * ch])
140 | out_logit = ops.linear(output, 1, scope="disc_final_fc",
141 | use_sn=self._spectral_norm)
142 | out = tf.nn.sigmoid(out_logit)
143 | return out, out_logit, output
144 |
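The ResNet30 loops are easiest to check by tracing only the scaling blocks: the channel-preserving residual blocks leave shapes unchanged, and 5 of the 6 super-blocks end in an "up" (generator) or "down" (discriminator) block. This pure-Python sketch (a hypothetical helper, not part of the module) uses integer division so channel counts stay integers under `from __future__ import division`, and confirms that the discriminator ends at 4x4 with `8 * ch` channels, matching the hard-coded `tf.reshape(output, [-1, 4 * 4 * 8 * ch])`.

```python
def trace_resnet30(ch=64, image_size=128):
    """Return (G channels, G resolution, D channels, D resolution) at the end.

    Mirrors the loops in resnet30.Generator.apply and
    resnet30.Discriminator.apply, skipping the channel-preserving blocks.
    """
    # Generator: 4x4 seed with 8 * ch channels; each of the 5 "up" blocks
    # halves the channel count and doubles the resolution.
    g_channels, g_res = 8 * ch, 4
    for _ in range(5):
        g_channels //= 2
        g_res *= 2
    # Discriminator: color_conv yields ch // 4 channels at full resolution;
    # each of the 5 "down" blocks doubles the channels and halves the resolution.
    d_channels, d_res = ch // 4, image_size
    for _ in range(5):
        d_channels *= 2
        d_res //= 2
    return g_channels, g_res, d_channels, d_res


g_ch, g_res, d_ch, d_res = trace_resnet30()
print("G:", g_ch, "channels at", g_res, "x", g_res)
print("D:", d_ch, "channels at", d_res, "x", d_res,
      "-> flattened size", d_res * d_res * d_ch)
```

With `image_size=64` the trace ends at 2x2, so the per-example feature size would no longer match the fixed reshape; this is consistent with the module docstring restricting ResNet30 to 128x128.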
--------------------------------------------------------------------------------
/compare_gan/architectures/resnet5.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A deep neural architecture with residual blocks and skip connections.
17 |
18 | It contains 5 residual blocks in both the generator and discriminator and
19 | supports resolutions up to 128x128. Details can be found in "Improved Training
20 | of Wasserstein GANs", Gulrajani I. et al. 2017. The related code is available at
21 | https://github.com/igul222/improved_wgan_training/blob/master/gan_64x64.py.
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 | from compare_gan.architectures import arch_ops as ops
29 | from compare_gan.architectures import resnet_ops
30 |
31 | import numpy as np
32 | from six.moves import range
33 | import tensorflow as tf
34 |
35 |
36 | class Generator(resnet_ops.ResNetGenerator):
37 | """ResNet generator with 5 blocks, for resolutions from 4x4 up to 128x128."""
38 |
39 | def __init__(self, ch=64, channels=(8, 8, 4, 4, 2, 1), **kwargs):
40 | super(Generator, self).__init__(**kwargs)
41 | self._ch = ch
42 | self._channels = channels
43 |
44 | def apply(self, z, y, is_training):
45 | """Build the generator network for the given inputs.
46 |
47 | Args:
48 | z: `Tensor` of shape [batch_size, z_dim] with latent code.
49 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
50 | labels.
51 | is_training: boolean, are we in train or eval model.
52 |
53 | Returns:
54 | A tensor of size [batch_size] + self._image_shape with values in [0, 1].
55 | """
56 | # The first `up_layers` blocks each upscale by a factor of 2.
57 | seed_size = 4
58 | image_size = self._image_shape[0]
59 |
60 | # Map noise to the actual seed.
61 | net = ops.linear(
62 | z,
63 | self._ch * self._channels[0] * seed_size * seed_size,
64 | scope="fc_noise")
65 | # Reshape the seed to be a rank-4 Tensor.
66 | net = tf.reshape(
67 | net,
68 | [-1, seed_size, seed_size, self._ch * self._channels[0]],
69 | name="fc_reshaped")
70 |
71 | up_layers = np.log2(float(image_size) / seed_size)
72 | if not up_layers.is_integer():
73 | raise ValueError("log2({}/{}) must be an integer.".format(
74 | image_size, seed_size))
75 | if up_layers < 0 or up_layers > 5:
76 | raise ValueError("Invalid image_size {}.".format(image_size))
77 | up_layers = int(up_layers)
78 |
79 | for block_idx in range(5):
80 | block = self._resnet_block(
81 | name="B{}".format(block_idx + 1),
82 | in_channels=self._ch * self._channels[block_idx],
83 | out_channels=self._ch * self._channels[block_idx + 1],
84 | scale="up" if block_idx < up_layers else "none")
85 | net = block(net, z=z, y=y, is_training=is_training)
86 |
87 | net = self.batch_norm(
88 | net, z=z, y=y, is_training=is_training, name="final_norm")
89 | net = tf.nn.relu(net)
90 | net = ops.conv2d(net, output_dim=self._image_shape[2],
91 | k_h=3, k_w=3, d_h=1, d_w=1, name="final_conv")
92 | net = tf.nn.sigmoid(net)
93 | return net
94 |
95 |
96 | class Discriminator(resnet_ops.ResNetDiscriminator):
97 | """ResNet5 discriminator, 5 blocks, supporting 128x128x3 and 128x128x1."""
98 |
99 | def __init__(self, ch=64, channels=(1, 2, 4, 4, 8, 8), **kwargs):
100 | super(Discriminator, self).__init__(**kwargs)
101 | self._ch = ch
102 | self._channels = channels
103 |
104 | def apply(self, x, y, is_training):
105 | """Apply the discriminator on an input.
106 |
107 | Args:
108 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
109 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
110 | labels.
111 | is_training: Boolean, whether the architecture should be constructed for
112 | training or inference.
113 |
114 | Returns:
115 | Tuple of 3 Tensors: the final prediction of the discriminator, the logits
116 | before the final output activation function, and the activations from the
117 | second-to-last layer.
118 | """
119 | resnet_ops.validate_image_inputs(x)
120 | colors = x.shape[3].value
121 | if colors not in [1, 3]:
122 | raise ValueError("Number of color channels not supported: {}".format(
123 | colors))
124 |
125 | block = self._resnet_block(
126 | name="B0",
127 | in_channels=colors,
128 | out_channels=self._ch,
129 | scale="down")
130 | output = block(x, z=None, y=y, is_training=is_training)
131 |
132 | for block_idx in range(5):
133 | block = self._resnet_block(
134 | name="B{}".format(block_idx + 1),
135 | in_channels=self._ch * self._channels[block_idx],
136 | out_channels=self._ch * self._channels[block_idx + 1],
137 | scale="down")
138 | output = block(output, z=None, y=y, is_training=is_training)
139 |
140 | output = tf.nn.relu(output)
141 | pre_logits = tf.reduce_mean(output, axis=[1, 2])
142 | out_logit = ops.linear(pre_logits, 1, scope="disc_final_fc",
143 | use_sn=self._spectral_norm)
144 | out = tf.nn.sigmoid(out_logit)
145 | return out, out_logit, pre_logits
146 |
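The ResNet5 generator always builds 5 blocks but decides per block whether to upsample: `image_size / seed_size` must be a power of two with an exponent between 0 and 5, and only the first `up_layers` blocks use `scale="up"`. The following stdlib sketch reproduces that validation and the resulting `scale` arguments; it substitutes `math.log2` for `np.log2` (assumed equivalent for these inputs), and the function name is hypothetical.

```python
import math


def resnet5_block_scales(image_size, seed_size=4, num_blocks=5):
    """Return the `scale` argument used for generator blocks B1..B5.

    Mirrors the checks in resnet5.Generator.apply: image_size / seed_size
    must be a power of two with an exponent between 0 and num_blocks.
    """
    up_layers = math.log2(float(image_size) / seed_size)
    if not up_layers.is_integer():
        raise ValueError("log2({}/{}) must be an integer.".format(
            image_size, seed_size))
    if up_layers < 0 or up_layers > num_blocks:
        raise ValueError("Invalid image_size {}.".format(image_size))
    up_layers = int(up_layers)
    return ["up" if i < up_layers else "none" for i in range(num_blocks)]


print(resnet5_block_scales(32))   # three upsampling blocks, then two "none"
print(resnet5_block_scales(128))  # all five blocks upsample
```

Sizes such as 48 (not a power-of-two multiple of 4) or 256 (which would need 6 upsampling blocks) raise `ValueError`, matching the two checks in `Generator.apply`.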
--------------------------------------------------------------------------------
/compare_gan/architectures/resnet_biggan_deep_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test number of parameters for the BigGAN-Deep architecture."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import logging
23 | from compare_gan import utils
24 | from compare_gan.architectures import arch_ops
25 | from compare_gan.architectures import resnet_biggan_deep
26 | import tensorflow as tf
27 |
28 |
29 | class ResNet5BigGanDeepTest(tf.test.TestCase):
30 |
31 | def testNumberOfParameters(self):
32 | with tf.Graph().as_default():
33 | batch_size = 2
34 | z = tf.zeros((batch_size, 128))
35 | y = tf.one_hot(tf.ones((batch_size,), dtype=tf.int32), 1000)
36 | generator = resnet_biggan_deep.Generator(
37 | image_shape=(128, 128, 3),
38 | batch_norm_fn=arch_ops.conditional_batch_norm)
39 | fake_images = generator(z, y=y, is_training=True, reuse=False)
40 | self.assertEqual(fake_images.shape.as_list(), [batch_size, 128, 128, 3])
41 | discriminator = resnet_biggan_deep.Discriminator()
42 | predictions = discriminator(fake_images, y, is_training=True)
43 | self.assertLen(predictions, 3)
44 |
45 | t_vars = tf.trainable_variables()
46 | g_vars = [var for var in t_vars if "generator" in var.name]
47 | d_vars = [var for var in t_vars if "discriminator" in var.name]
48 | g_param_overview = utils.get_parameter_overview(g_vars, limit=None)
49 | d_param_overview = utils.get_parameter_overview(d_vars, limit=None)
50 | g_param_overview = g_param_overview.split("\n")
51 | logging.info("Generator variables:")
52 | for i in range(0, len(g_param_overview), 80):
53 | logging.info("\n%s", "\n".join(g_param_overview[i:i + 80]))
54 | logging.info("Discriminator variables:\n%s", d_param_overview)
55 |
56 | g_num_weights = sum([v.get_shape().num_elements() for v in g_vars])
57 | self.assertEqual(g_num_weights, 50244484)
58 |
59 | d_num_weights = sum([v.get_shape().num_elements() for v in d_vars])
60 | self.assertEqual(d_num_weights, 34590210)
61 |
62 |
63 | if __name__ == "__main__":
64 | tf.test.main()
65 |
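`testNumberOfParameters` selects variables by substring (`"generator"` or `"discriminator"` in `var.name`) and sums `get_shape().num_elements()`. The same counting logic is shown below in plain Python, with a dict of shapes standing in for `tf.trainable_variables()`. The variable names and shapes are hypothetical and chosen only for illustration; they are not the actual BigGAN-Deep variables.

```python
def num_elements(shape):
    """Product of the dimensions, like TensorShape.num_elements()."""
    count = 1
    for dim in shape:
        count *= dim
    return count


def count_parameters(shapes_by_name, scope):
    """Sum element counts of all variables whose name contains `scope`."""
    return sum(num_elements(shape)
               for name, shape in shapes_by_name.items() if scope in name)


# Hypothetical variable names and shapes, for illustration only.
shapes = {
    "generator/fc_noise/kernel:0": (128, 4 * 4 * 512),
    "generator/fc_noise/bias:0": (4 * 4 * 512,),
    "discriminator/final_fc/kernel:0": (512, 1),
    "discriminator/final_fc/bias:0": (1,),
}
print("G params:", count_parameters(shapes, "generator"))
print("D params:", count_parameters(shapes, "discriminator"))
```

Substring matching relies on the variable scopes: any variable whose name contains neither scope string is silently excluded from both totals.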
--------------------------------------------------------------------------------
/compare_gan/architectures/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Resnet generator and discriminator for CIFAR.
17 |
18 | Based on Table 4 from "Spectral Normalization for Generative Adversarial
19 | Networks", Miyato T. et al., 2018. [https://arxiv.org/pdf/1802.05957.pdf].
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | from compare_gan.architectures import arch_ops as ops
27 | from compare_gan.architectures import resnet_ops
28 |
29 | import gin
30 | from six.moves import range
31 | import tensorflow as tf
32 |
33 |
34 | @gin.configurable
35 | class Generator(resnet_ops.ResNetGenerator):
36 | """ResNet generator, 3 blocks, supporting 32x32 resolution."""
37 |
38 | def __init__(self,
39 | hierarchical_z=False,
40 | embed_z=False,
41 | embed_y=False,
42 | **kwargs):
43 | """Constructor for the ResNet Cifar generator.
44 |
45 | Args:
46 | hierarchical_z: Split z into chunks: the first seeds the linear layer and
47 | each block gets one of the rest, concatenated to y (the one hot labels).
48 | embed_z: If True use a learnable embedding of z that is used instead.
49 | The embedding will have the length of z.
50 | embed_y: If True use a learnable embedding of y that is used instead.
51 | The embedding will have the length of z (not y!).
52 | **kwargs: additional arguments passed on to ResNetGenerator.
53 | """
54 | super(Generator, self).__init__(**kwargs)
55 | self._hierarchical_z = hierarchical_z
56 | self._embed_z = embed_z
57 | self._embed_y = embed_y
58 |
59 | def apply(self, z, y, is_training):
60 | """Build the generator network for the given inputs.
61 |
62 | Args:
63 | z: `Tensor` of shape [batch_size, z_dim] with latent code.
64 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
65 | labels.
66 | is_training: boolean, whether we are in train or eval mode.
67 |
68 | Returns:
69 | A tensor of size [batch_size, 32, 32, colors] with values in [0, 1].
70 | """
71 | assert self._image_shape[0] == 32
72 | assert self._image_shape[1] == 32
73 | num_blocks = 3
74 | z_dim = z.shape[1].value
75 |
76 | if self._embed_z:
77 | z = ops.linear(z, z_dim, scope="embed_z", use_sn=self._spectral_norm)
78 | if self._embed_y:
79 | y = ops.linear(y, z_dim, scope="embed_y", use_sn=self._spectral_norm)
80 | y_per_block = num_blocks * [y]
81 | if self._hierarchical_z:
82 | z_per_block = tf.split(z, num_blocks + 1, axis=1)
83 | z0, z_per_block = z_per_block[0], z_per_block[1:]
84 | if y is not None:
85 | y_per_block = [tf.concat([zi, y], 1) for zi in z_per_block]
86 | else:
87 | z0 = z
88 | z_per_block = num_blocks * [z]
89 |
90 | output = ops.linear(z0, 4 * 4 * 256, scope="fc_noise",
91 | use_sn=self._spectral_norm)
92 | output = tf.reshape(output, [-1, 4, 4, 256], name="fc_reshaped")
93 | for block_idx in range(num_blocks):
94 | block = self._resnet_block(
95 | name="B{}".format(block_idx + 1),
96 | in_channels=256,
97 | out_channels=256,
98 | scale="up")
99 | output = block(
100 | output,
101 | z=z_per_block[block_idx],
102 | y=y_per_block[block_idx],
103 | is_training=is_training)
104 |
105 | # Final processing of the output.
106 | output = self.batch_norm(
107 | output, z=z, y=y, is_training=is_training, name="final_norm")
108 | output = tf.nn.relu(output)
109 | output = ops.conv2d(output, output_dim=self._image_shape[2], k_h=3, k_w=3,
110 | d_h=1, d_w=1, name="final_conv",
111 | use_sn=self._spectral_norm)
112 | return tf.nn.sigmoid(output)
113 |
114 |
115 | @gin.configurable
116 | class Discriminator(resnet_ops.ResNetDiscriminator):
117 | """ResNet discriminator, 4 blocks, supporting 32x32 with 1 or 3 colors."""
118 |
119 | def __init__(self, project_y=False, **kwargs):
120 | super(Discriminator, self).__init__(**kwargs)
121 | self._project_y = project_y
122 |
123 | def apply(self, x, y, is_training):
124 | """Apply the discriminator on an input.
125 |
126 | Args:
127 | x: `Tensor` of shape [batch_size, 32, 32, ?] with real or fake images.
128 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
129 | labels.
130 | is_training: Boolean, whether the architecture should be constructed for
131 | training or inference.
132 |
133 | Returns:
134 | Tuple of 3 Tensors: the final prediction of the discriminator, the logits
135 | before the final output activation function, and the activations of the
136 | second-to-last layer.
137 | """
138 | resnet_ops.validate_image_inputs(x)
139 | colors = x.shape[3].value
140 | if colors not in [1, 3]:
141 | raise ValueError("Number of color channels not supported: {}".format(
142 | colors))
143 |
144 | output = x
145 | for block_idx in range(4):
146 | block = self._resnet_block(
147 | name="B{}".format(block_idx + 1),
148 | in_channels=colors if block_idx == 0 else 128,
149 | out_channels=128,
150 | scale="down" if block_idx <= 1 else "none")
151 | output = block(output, z=None, y=y, is_training=is_training)
152 |
153 | # Final part - ReLU
154 | output = tf.nn.relu(output)
155 |
156 | h = tf.reduce_mean(output, axis=[1, 2])
157 |
158 | out_logit = ops.linear(h, 1, scope="disc_final_fc",
159 | use_sn=self._spectral_norm)
160 | if self._project_y:
161 | if y is None:
162 | raise ValueError("You must provide class information y to project.")
163 | embedded_y = ops.linear(y, 128, use_bias=False,
164 | scope="embedding_fc", use_sn=self._spectral_norm)
165 | out_logit += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True)
166 | out = tf.nn.sigmoid(out_logit)
167 | return out, out_logit, h
168 |
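When `project_y=True`, the discriminator above adds a class-projection term to the unconditional logit: the one-hot label is mapped through a bias-free linear layer to a 128-dimensional embedding, and its inner product with the pooled features `h` is added to `out_logit`. The NumPy sketch below reproduces only that arithmetic. It assumes plain dense weights in place of `ops.linear` and ignores spectral normalization; the weight names `w_fc`, `b_fc` and `w_embed` are hypothetical.

```python
import numpy as np


def projection_logit(h, y, w_fc, b_fc, w_embed):
  """Sketch of the projection logit built in Discriminator.apply.

  Args:
    h: [batch_size, 128] pooled features (the `h` tensor above).
    y: [batch_size, num_classes] one-hot labels.
    w_fc, b_fc: hypothetical weights of the final 128 -> 1 linear layer.
    w_embed: hypothetical [num_classes, 128] weights of the bias-free
      class embedding layer.

  Returns:
    [batch_size, 1] logits.
  """
  out_logit = h @ w_fc + b_fc  # Unconditional term, [batch_size, 1].
  embedded_y = y @ w_embed  # Class embedding, [batch_size, 128].
  # Inner product between each example's class embedding and its features.
  out_logit = out_logit + np.sum(embedded_y * h, axis=1, keepdims=True)
  return out_logit


rng = np.random.default_rng(0)
h = rng.normal(size=(4, 128))
y = np.eye(10)[[0, 3, 3, 9]]  # One-hot labels for 4 examples, 10 classes.
w_fc, b_fc = rng.normal(size=(128, 1)), np.zeros(1)
w_embed = rng.normal(size=(10, 128))
logits = projection_logit(h, y, w_fc, b_fc, w_embed)
print(logits.shape)  # (4, 1)
```

With a one-hot `y`, the embedding simply selects one row of `w_embed`, so the extra term is the dot product between that class's vector and `h`.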
--------------------------------------------------------------------------------
/compare_gan/architectures/resnet_init_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests weight initialization ops using ResNet5 architecture."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan.architectures import resnet5
23 | from compare_gan.gans import consts
24 | import gin
25 | import tensorflow as tf
26 |
27 |
28 | class ResNetInitTest(tf.test.TestCase):
29 |
30 | def setUp(self):
31 | super(ResNetInitTest, self).setUp()
32 | gin.clear_config()
33 |
34 | def testInitializersOldDefault(self):
35 | valid_initializer = [
36 | "kernel/Initializer/random_normal",
37 | "bias/Initializer/Const",
38 | # truncated_normal is the old default for conv2d.
39 | "kernel/Initializer/truncated_normal",
40 | "bias/Initializer/Const",
41 | "beta/Initializer/zeros",
42 | "gamma/Initializer/ones",
43 | ]
44 | valid_op_names = "/({}):0$".format("|".join(valid_initializer))
45 | with tf.Graph().as_default():
46 | z = tf.zeros((2, 128))
47 | fake_image = resnet5.Generator(image_shape=(128, 128, 3))(
48 | z, y=None, is_training=True)
49 | resnet5.Discriminator()(fake_image, y=None, is_training=True)
50 | for var in tf.trainable_variables():
51 | op_name = var.initializer.inputs[1].name
52 | self.assertRegex(op_name, valid_op_names)
53 |
54 | def testInitializersRandomNormal(self):
55 | gin.bind_parameter("weights.initializer", consts.NORMAL_INIT)
56 | valid_initializer = [
57 | "kernel/Initializer/random_normal",
58 | "bias/Initializer/Const",
59 | "kernel/Initializer/random_normal",
60 | "bias/Initializer/Const",
61 | "beta/Initializer/zeros",
62 | "gamma/Initializer/ones",
63 | ]
64 | valid_op_names = "/({}):0$".format("|".join(valid_initializer))
65 | with tf.Graph().as_default():
66 | z = tf.zeros((2, 128))
67 | fake_image = resnet5.Generator(image_shape=(128, 128, 3))(
68 | z, y=None, is_training=True)
69 | resnet5.Discriminator()(fake_image, y=None, is_training=True)
70 | for var in tf.trainable_variables():
71 | op_name = var.initializer.inputs[1].name
72 | self.assertRegex(op_name, valid_op_names)
73 |
74 | def testInitializersTruncatedNormal(self):
75 | gin.bind_parameter("weights.initializer", consts.TRUNCATED_INIT)
76 | valid_initializer = [
77 | "kernel/Initializer/truncated_normal",
78 | "bias/Initializer/Const",
79 | "kernel/Initializer/truncated_normal",
80 | "bias/Initializer/Const",
81 | "beta/Initializer/zeros",
82 | "gamma/Initializer/ones",
83 | ]
84 | valid_op_names = "/({}):0$".format("|".join(valid_initializer))
85 | with tf.Graph().as_default():
86 | z = tf.zeros((2, 128))
87 | fake_image = resnet5.Generator(image_shape=(128, 128, 3))(
88 | z, y=None, is_training=True)
89 | resnet5.Discriminator()(fake_image, y=None, is_training=True)
90 | for var in tf.trainable_variables():
91 | op_name = var.initializer.inputs[1].name
92 | self.assertRegex(op_name, valid_op_names)
93 |
94 | def testGeneratorInitializersOrthogonal(self):
95 | gin.bind_parameter("weights.initializer", consts.ORTHOGONAL_INIT)
96 | valid_initializer = [
97 | "kernel/Initializer/mul_1",
98 | "bias/Initializer/Const",
99 | "kernel/Initializer/mul_1",
100 | "bias/Initializer/Const",
101 | "beta/Initializer/zeros",
102 | "gamma/Initializer/ones",
103 | ]
104 | valid_op_names = "/({}):0$".format("|".join(valid_initializer))
105 | with tf.Graph().as_default():
106 | z = tf.zeros((2, 128))
107 | fake_image = resnet5.Generator(image_shape=(128, 128, 3))(
108 | z, y=None, is_training=True)
109 | resnet5.Discriminator()(fake_image, y=None, is_training=True)
110 | for var in tf.trainable_variables():
111 | op_name = var.initializer.inputs[1].name
112 | self.assertRegex(op_name, valid_op_names)
113 |
114 |
115 | if __name__ == "__main__":
116 | tf.test.main()
117 |
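Each test above joins the expected initializer op suffixes into a single regex alternation and asserts, via `assertRegex` (which applies `re.search`), that the name of every trainable variable's initial-value tensor (`var.initializer.inputs[1].name`) ends with one of them. The standalone sketch below shows that matching with the standard `re` module; the tensor names are made up for illustration, since the real ones depend on the variable scopes.

```python
import re

# A subset of the allowed initializer suffixes from testInitializersRandomNormal.
valid_initializer = [
    "kernel/Initializer/random_normal",
    "bias/Initializer/Const",
    "beta/Initializer/zeros",
    "gamma/Initializer/ones",
]
# One of the suffixes, preceded by "/" and followed by the ":0" output index
# at the very end of the tensor name.
valid_op_names = "/({}):0$".format("|".join(valid_initializer))

# Hypothetical initial-value tensor names.
names = [
    "generator/fc_noise/kernel/Initializer/random_normal:0",
    "generator/fc_noise/bias/Initializer/Const:0",
    "generator/B1/conv/kernel/Initializer/truncated_normal:0",
]
# assertRegex uses re.search, so the pattern may match anywhere in the name.
matches = [re.search(valid_op_names, name) is not None for name in names]
print(matches)  # [True, True, False]
```

Duplicate entries in the test lists, such as the repeated `bias/Initializer/Const`, are harmless inside the alternation.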
--------------------------------------------------------------------------------
/compare_gan/architectures/resnet_stl.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A deep neural architecture with residual blocks and skip connections.
17 |
18 | Based on Table 5 from "Spectral Normalization for Generative Adversarial
19 | Networks", Miyato T. et al., 2018. [https://arxiv.org/pdf/1802.05957.pdf].
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | from compare_gan.architectures import arch_ops as ops
27 | from compare_gan.architectures import resnet_ops
28 |
29 | from six.moves import range
30 | import tensorflow as tf
31 |
32 |
33 | class Generator(resnet_ops.ResNetGenerator):
34 | """ResNet generator, 3 blocks, supporting 48x48 resolution."""
35 |
36 | def apply(self, z, y, is_training):
37 | """Build the generator network for the given inputs.
38 |
39 | Args:
40 | z: `Tensor` of shape [batch_size, z_dim] with latent code.
41 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
42 | labels.
43 | is_training: boolean, whether we are in train or eval mode.
44 |
45 | Returns:
46 | A tensor of size [batch_size, 48, 48, colors] with values in [0, 1].
47 | """
48 | ch = 64
49 | colors = self._image_shape[2]
50 | batch_size = z.get_shape().as_list()[0]
51 | magic = [(8, 4), (4, 2), (2, 1)]
52 | output = ops.linear(z, 6 * 6 * 512, scope="fc_noise")
53 | output = tf.reshape(output, [batch_size, 6, 6, 512], name="fc_reshaped")
54 | for block_idx in range(3):
55 | block = self._resnet_block(
56 | name="B{}".format(block_idx + 1),
57 | in_channels=ch * magic[block_idx][0],
58 | out_channels=ch * magic[block_idx][1],
59 | scale="up")
60 | output = block(output, z=z, y=y, is_training=is_training)
61 | output = self.batch_norm(
62 | output, z=z, y=y, is_training=is_training, name="final_norm")
63 | output = tf.nn.relu(output)
64 | output = ops.conv2d(output, output_dim=colors, k_h=3, k_w=3, d_h=1, d_w=1,
65 | name="final_conv")
66 | return tf.nn.sigmoid(output)
67 |
68 |
69 | class Discriminator(resnet_ops.ResNetDiscriminator):
70 | """ResNet discriminator, 5 blocks, supports 48x48 resolution."""
71 |
72 | def apply(self, x, y, is_training):
73 | """Apply the discriminator on an input.
74 |
75 | Args:
76 | x: `Tensor` of shape [batch_size, 48, 48, ?] with real or fake images.
77 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
78 | labels.
79 | is_training: Boolean, whether the architecture should be constructed for
80 | training or inference.
81 |
82 | Returns:
83 | Tuple of 3 Tensors: the final prediction of the discriminator, the logits
84 | before the final output activation function, and the activations of the
85 | second-to-last layer.
86 | """
87 | resnet_ops.validate_image_inputs(x, validate_power2=False)
88 | colors = x.shape[-1].value
89 | if colors not in [1, 3]:
90 | raise ValueError("Number of color channels unknown: %s" % colors)
91 | ch = 64
92 | block = self._resnet_block(
93 | name="B0", in_channels=colors, out_channels=ch, scale="down")
94 | output = block(x, z=None, y=y, is_training=is_training)
95 | magic = [(1, 2), (2, 4), (4, 8), (8, 16)]
96 | for block_idx in range(4):
97 | block = self._resnet_block(
98 | name="B{}".format(block_idx + 1),
99 | in_channels=ch * magic[block_idx][0],
100 | out_channels=ch * magic[block_idx][1],
101 | scale="down" if block_idx < 3 else "none")
102 | output = block(output, z=None, y=y, is_training=is_training)
103 | output = tf.nn.relu(output)
104 | pre_logits = tf.reduce_mean(output, axis=[1, 2])
105 | out_logit = ops.linear(pre_logits, 1, scope="disc_final_fc",
106 | use_sn=self._spectral_norm)
107 | out = tf.nn.sigmoid(out_logit)
108 | return out, out_logit, pre_logits
109 |
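The `magic` tables in the STL generator and discriminator encode channel multipliers of the base width `ch = 64`. The sketch below traces the resulting tensor shapes. It assumes, as the `scale` argument names suggest, that an "up" block doubles height and width and a "down" block halves them; the helper names are hypothetical. It illustrates why the generator starts from a 6x6x512 tensor: three upsampling blocks reach 48x48, while the discriminator reduces 48x48 inputs to 3x3 before the global mean pool.

```python
# Channel multipliers copied from resnet_stl.py; base width ch = 64.
CH = 64
GEN_MAGIC = [(8, 4), (4, 2), (2, 1)]
DISC_MAGIC = [(1, 2), (2, 4), (4, 8), (8, 16)]


def generator_shapes(start=6, ch=CH):
  """Returns (H, W, C) after fc_reshaped and after each generator block."""
  size = start
  shapes = [(size, size, ch * GEN_MAGIC[0][0])]  # 6x6x512 after the reshape.
  for _, out_mult in GEN_MAGIC:
    size *= 2  # Assumes scale="up" doubles height and width.
    shapes.append((size, size, ch * out_mult))
  return shapes


def discriminator_shapes(size=48, colors=3, ch=CH):
  """Returns (H, W, C) of the input and after each discriminator block."""
  shapes = [(size, size, colors)]
  size //= 2  # B0 uses scale="down", assumed to halve height and width.
  shapes.append((size, size, ch))
  for block_idx, (_, out_mult) in enumerate(DISC_MAGIC):
    if block_idx < 3:  # B1 to B3 downsample; B4 uses scale="none".
      size //= 2
    shapes.append((size, size, ch * out_mult))
  return shapes


print(generator_shapes())
# [(6, 6, 512), (12, 12, 256), (24, 24, 128), (48, 48, 64)]
print(discriminator_shapes())
# [(48, 48, 3), (24, 24, 64), (12, 12, 128), (6, 6, 256), (3, 3, 512), (3, 3, 1024)]
```

Since 48 is not a power of two, this is presumably why the discriminator calls `validate_image_inputs(x, validate_power2=False)`.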
--------------------------------------------------------------------------------
/compare_gan/architectures/sndcgan.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of SNDCGAN generator and discriminator architectures."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan.architectures import abstract_arch
23 | from compare_gan.architectures.arch_ops import conv2d
24 | from compare_gan.architectures.arch_ops import deconv2d
25 | from compare_gan.architectures.arch_ops import linear
26 | from compare_gan.architectures.arch_ops import lrelu
27 |
28 | import numpy as np
29 | import tensorflow as tf
30 |
31 |
32 | def conv_out_size_same(size, stride):
33 | return int(np.ceil(float(size) / float(stride)))
34 |
35 |
36 | class Generator(abstract_arch.AbstractGenerator):
37 | """SNDCGAN generator.
38 |
39 | Details are available at https://openreview.net/pdf?id=B1QRgziT-.
40 | """
41 |
42 | def apply(self, z, y, is_training):
43 | """Build the generator network for the given inputs.
44 |
45 | Args:
46 | z: `Tensor` of shape [batch_size, z_dim] with latent code.
47 | y: `Tensor` of shape [batch_size, num_classes] of one hot encoded labels.
48 | is_training: boolean, whether we are in train or eval mode.
49 |
50 | Returns:
51 | A tensor of size [batch_size] + self._image_shape with values in [0, 1].
52 | """
53 | batch_size = z.shape[0].value
54 | s_h, s_w, colors = self._image_shape
55 | s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
56 | s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
57 | s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
58 |
59 | net = linear(z, s_h8 * s_w8 * 512, scope="g_fc1")
60 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn1")
61 | net = tf.nn.relu(net)
62 | net = tf.reshape(net, [batch_size, s_h8, s_w8, 512])
63 | net = deconv2d(net, [batch_size, s_h4, s_w4, 256], 4, 4, 2, 2, name="g_dc2")
64 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn2")
65 | net = tf.nn.relu(net)
66 | net = deconv2d(net, [batch_size, s_h2, s_w2, 128], 4, 4, 2, 2, name="g_dc3")
67 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn3")
68 | net = tf.nn.relu(net)
69 | net = deconv2d(net, [batch_size, s_h, s_w, 64], 4, 4, 2, 2, name="g_dc4")
70 | net = self.batch_norm(net, z=z, y=y, is_training=is_training, name="g_bn4")
71 | net = tf.nn.relu(net)
72 | net = deconv2d(
73 | net, [batch_size, s_h, s_w, colors], 3, 3, 1, 1, name="g_dc5")
74 | out = tf.tanh(net)
75 |
76 | # This normalization from [-1, 1] to [0, 1] is introduced for consistency
77 | # with other models.
78 | out = tf.div(out + 1.0, 2.0)
79 | return out
80 |
81 |
82 | class Discriminator(abstract_arch.AbstractDiscriminator):
83 | """SNDCGAN discriminator.
84 |
85 | Details are available at https://openreview.net/pdf?id=B1QRgziT-.
86 | """
87 |
88 | def apply(self, x, y, is_training):
89 | """Apply the discriminator on an input.
90 |
91 | Args:
92 | x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
93 | y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
94 | labels.
95 | is_training: Boolean, whether the architecture should be constructed for
96 | training or inference.
97 |
98 | Returns:
99 | Tuple of 3 Tensors: the final prediction of the discriminator, the logits
100 | before the final output activation function, and the activations of the
101 | second-to-last layer.
102 | """
103 | del is_training, y
104 | use_sn = self._spectral_norm
105 | # The compare_gan input pipeline scales image pixels to [0, 1], while the
106 | # original authors used [-1, 1]. Rescale the input here instead of changing
107 | # our preprocessing function.
108 | x = x * 2.0 - 1.0
109 | net = conv2d(x, 64, 3, 3, 1, 1, name="d_conv1", use_sn=use_sn)
110 | net = lrelu(net, leak=0.1)
111 | net = conv2d(net, 128, 4, 4, 2, 2, name="d_conv2", use_sn=use_sn)
112 | net = lrelu(net, leak=0.1)
113 | net = conv2d(net, 128, 3, 3, 1, 1, name="d_conv3", use_sn=use_sn)
114 | net = lrelu(net, leak=0.1)
115 | net = conv2d(net, 256, 4, 4, 2, 2, name="d_conv4", use_sn=use_sn)
116 | net = lrelu(net, leak=0.1)
117 | net = conv2d(net, 256, 3, 3, 1, 1, name="d_conv5", use_sn=use_sn)
118 | net = lrelu(net, leak=0.1)
119 | net = conv2d(net, 512, 4, 4, 2, 2, name="d_conv6", use_sn=use_sn)
120 | net = lrelu(net, leak=0.1)
121 | net = conv2d(net, 512, 3, 3, 1, 1, name="d_conv7", use_sn=use_sn)
122 | net = lrelu(net, leak=0.1)
123 | batch_size = x.shape.as_list()[0]
124 | net = tf.reshape(net, [batch_size, -1])
125 | out_logit = linear(net, 1, scope="d_fc1", use_sn=use_sn)
126 | out = tf.nn.sigmoid(out_logit)
127 | return out, out_logit, net
128 |
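`conv_out_size_same` rounds up when halving, so the SNDCGAN generator's starting resolution (`s_h8`, `s_w8`) comes from three successive ceiling divisions by 2, and odd intermediate sizes (for example 28 to 14 to 7 to 4) stay consistent. The plain-Python sketch below reproduces that arithmetic together with the [0, 1] to [-1, 1] rescaling used by both networks. It uses `math.ceil` in place of `np.ceil`, and the helper names are hypothetical.

```python
import math


def conv_out_size_same(size, stride):
  # Same formula as sndcgan.py, using math.ceil instead of np.ceil.
  return int(math.ceil(float(size) / float(stride)))


def generator_spatial_sizes(size):
  """Spatial sizes after the g_fc1 reshape, g_dc2, g_dc3 and g_dc4."""
  s2 = conv_out_size_same(size, 2)
  s4 = conv_out_size_same(s2, 2)
  s8 = conv_out_size_same(s4, 2)
  return [s8, s4, s2, size]


def to_model_range(x):
  # [0, 1] -> [-1, 1], as applied at the start of Discriminator.apply.
  return x * 2.0 - 1.0


def to_image_range(t):
  # [-1, 1] -> [0, 1], as applied to the generator's tanh output.
  return (t + 1.0) / 2.0


print(generator_spatial_sizes(128))  # [16, 32, 64, 128]
print(generator_spatial_sizes(28))   # [4, 7, 14, 28]
print(to_image_range(to_model_range(0.25)))  # 0.25
```

The two rescaling functions are inverses of each other, which is what lets the framework keep a single [0, 1] convention while the SNDCGAN layers operate on [-1, 1].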
--------------------------------------------------------------------------------
/compare_gan/datasets_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for datasets."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 | from absl.testing import flagsaver
24 | from absl.testing import parameterized
25 | from compare_gan import datasets
26 |
27 | import tensorflow as tf
28 |
29 | FLAGS = flags.FLAGS
30 |
31 | _TPU_SUPPORTED_TYPES = {
32 | tf.float32, tf.int32, tf.complex64, tf.int64, tf.bool, tf.bfloat16
33 | }
34 |
35 |
36 | def _preprocess_fn_id(images, labels):
37 | return {"images": images}, labels
38 |
39 |
40 | def _preprocess_fn_add_noise(images, labels, seed=None):
41 | del labels
42 | tf.set_random_seed(seed)
43 | noise = tf.random.uniform([128], maxval=1.0)
44 | return {"images": images}, noise
45 |
46 |
47 | class DatasetsTest(parameterized.TestCase, tf.test.TestCase):
48 |
49 | def setUp(self):
50 | super(DatasetsTest, self).setUp()
51 | FLAGS.data_shuffle_buffer_size = 100
52 |
53 | def get_element_and_verify_shape(self, dataset_name, expected_shape):
54 | dataset = datasets.get_dataset(dataset_name)
55 | dataset = dataset.eval_input_fn()
56 | image, label = dataset.make_one_shot_iterator().get_next()
57 | # Check if shape is known at compile time, required for TPUs.
58 | self.assertAllEqual(image.shape.as_list(), expected_shape)
59 | self.assertEqual(image.dtype, tf.float32)
60 | self.assertIn(label.dtype, _TPU_SUPPORTED_TYPES)
61 | with self.cached_session() as session:
62 | image = session.run(image)
63 | self.assertEqual(image.shape, expected_shape)
64 | self.assertGreaterEqual(image.min(), 0.0)
65 | self.assertLessEqual(image.max(), 1.0)
66 |
67 | def test_mnist(self):
68 | self.get_element_and_verify_shape("mnist", (28, 28, 1))
69 |
70 | def test_fashion_mnist(self):
71 | self.get_element_and_verify_shape("fashion-mnist", (28, 28, 1))
72 |
73 | def test_celeba(self):
74 | self.get_element_and_verify_shape("celeb_a", (64, 64, 3))
75 |
76 | def test_lsun(self):
77 | self.get_element_and_verify_shape("lsun-bedroom", (128, 128, 3))
78 |
79 | def _run_train_input_fn(self, dataset_name, preprocess_fn):
80 | dataset = datasets.get_dataset(dataset_name)
81 | with tf.Graph().as_default():
82 | dataset = dataset.input_fn(params={"batch_size": 1},
83 | preprocess_fn=preprocess_fn)
84 | iterator = dataset.make_initializable_iterator()
85 | with self.session() as sess:
86 | sess.run(iterator.initializer)
87 | next_batch = iterator.get_next()
88 | return [sess.run(next_batch) for _ in range(5)]
89 |
90 | @parameterized.named_parameters(
91 | ("FakeCifar", _preprocess_fn_id),
92 | ("FakeCifarWithRandomNoise", _preprocess_fn_add_noise),
93 | )
94 | @flagsaver.flagsaver
95 | def test_train_input_fn_is_deterministic(self, preprocess_fn):
96 | FLAGS.data_fake_dataset = True
97 | batches1 = self._run_train_input_fn("cifar10", preprocess_fn)
98 | batches2 = self._run_train_input_fn("cifar10", preprocess_fn)
99 | for i in range(len(batches1)):
100 | # Check that both runs got the same images/noise
101 | self.assertAllClose(batches1[i][0], batches2[i][0])
102 | self.assertAllClose(batches1[i][1], batches2[i][1])
103 |
104 | @flagsaver.flagsaver
105 | def test_train_input_fn_noise_changes(self):
106 | FLAGS.data_fake_dataset = True
107 | batches = self._run_train_input_fn("cifar10", _preprocess_fn_add_noise)
108 | for i in range(1, len(batches)):
109 | self.assertNotAllClose(batches[0][1], batches[i][1])
110 | self.assertNotAllClose(batches[i - 1][1], batches[i][1])
111 |
112 |
113 | if __name__ == "__main__":
114 | tf.test.main()
115 |
--------------------------------------------------------------------------------
/compare_gan/eval_gan_lib_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for eval_gan_lib."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | import os.path
24 |
25 | from absl import flags
26 | from absl.testing import flagsaver
27 | from absl.testing import parameterized
28 |
29 | from compare_gan import datasets
30 | from compare_gan import eval_gan_lib
31 | from compare_gan import eval_utils
32 | from compare_gan.gans import consts as c
33 | from compare_gan.gans.modular_gan import ModularGAN
34 | from compare_gan.metrics import fid_score
35 | from compare_gan.metrics import fractal_dimension
36 | from compare_gan.metrics import inception_score
37 | from compare_gan.metrics import ms_ssim_score
38 |
39 | import gin
40 | import mock
41 | import tensorflow as tf
42 |
43 | FLAGS = flags.FLAGS
44 |
45 |
46 | def create_fake_inception_graph():
47 | """Creates a `GraphDef` that mocks the Inception V1 graph.
48 |
49 | It takes the input, multiplies it through a matrix full of 0.00001 values,
50 | and provides the results in the endpoints 'pool_3' and 'logits', which
51 | match the tensor names in the real Inception V1 model.
53 |
54 | Returns:
55 | `tf.GraphDef` for the mocked Inception V1 graph.
56 | """
57 | fake_inception = tf.Graph()
58 | with fake_inception.as_default():
59 | inputs = tf.placeholder(
60 | tf.float32, shape=[None, 299, 299, 3], name="Mul")
61 | w = tf.ones(shape=[299 * 299 * 3, 10]) * 0.00001
62 | outputs = tf.matmul(tf.layers.flatten(inputs), w)
63 | tf.identity(outputs, name="pool_3")
64 | tf.identity(outputs, name="logits")
65 | return fake_inception.as_graph_def()
66 |
67 |
68 | class EvalGanLibTest(parameterized.TestCase, tf.test.TestCase):
69 |
70 | def setUp(self):
71 | super(EvalGanLibTest, self).setUp()
72 | gin.clear_config()
73 | FLAGS.data_fake_dataset = True
74 | self.mock_get_graph = mock.patch.object(
75 | eval_utils, "get_inception_graph_def").start()
76 | self.mock_get_graph.return_value = create_fake_inception_graph()
77 |
78 | @parameterized.parameters(c.ARCHITECTURES)
79 | @flagsaver.flagsaver
80 | def test_end2end_checkpoint(self, architecture):
81 | """Takes a real GAN (trained for 1 step) and evaluates it."""
82 | if architecture in {c.RESNET_STL_ARCH, c.RESNET30_ARCH}:
83 | # RESNET_STL_ARCH and RESNET30_ARCH do not support the CIFAR image shape.
84 | return
85 | gin.bind_parameter("dataset.name", "cifar10")
86 | dataset = datasets.get_dataset("cifar10")
87 | options = {
88 | "architecture": architecture,
89 | "z_dim": 120,
90 | "disc_iters": 1,
91 | "lambda": 1,
92 | }
93 | model_dir = os.path.join(tf.test.get_temp_dir(), self.id())
94 | tf.logging.info("model_dir: %s", model_dir)
95 | run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)
96 | gan = ModularGAN(dataset=dataset,
97 | parameters=options,
98 | conditional="biggan" in architecture,
99 | model_dir=model_dir)
100 | estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
101 | estimator.train(input_fn=gan.input_fn, steps=1)
102 | export_path = os.path.join(model_dir, "tfhub")
103 | checkpoint_path = os.path.join(model_dir, "model.ckpt-1")
104 | module_spec = gan.as_module_spec()
105 | module_spec.export(export_path, checkpoint_path=checkpoint_path)
106 |
107 | eval_tasks = [
108 | fid_score.FIDScoreTask(),
109 | fractal_dimension.FractalDimensionTask(),
110 | inception_score.InceptionScoreTask(),
111 | ms_ssim_score.MultiscaleSSIMTask()
112 | ]
113 | result_dict = eval_gan_lib.evaluate_tfhub_module(
114 | export_path, eval_tasks, use_tpu=False, num_averaging_runs=1)
115 | tf.logging.info("result_dict: %s", result_dict)
116 | for score in ["fid_score", "fractal_dimension", "inception_score",
117 | "ms_ssim"]:
118 | for stats in ["mean", "std", "list"]:
119 | required_key = "%s_%s" % (score, stats)
120 | self.assertIn(required_key, result_dict, "Missing: %s." % required_key)
121 |
122 |
123 | if __name__ == "__main__":
124 | tf.test.main()
125 |
--------------------------------------------------------------------------------
/compare_gan/gans/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
17 |
--------------------------------------------------------------------------------
/compare_gan/gans/abstract_gan.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Interface for GAN models that can be trained using the Estimator API."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import abc
23 |
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | @six.add_metaclass(abc.ABCMeta)
29 | class AbstractGAN(object):
30 | """Interface for GAN models that can be trained using the Estimator API."""
31 |
32 | def __init__(self,
33 | dataset,
34 | parameters,
35 | model_dir):
36 | super(AbstractGAN, self).__init__()
37 | self._dataset = dataset
38 | self._parameters = parameters
39 | self._model_dir = model_dir
40 |
41 | def as_estimator(self, run_config, batch_size, use_tpu):
42 | """Returns a TPUEstimator for this GAN."""
43 | return tf.contrib.tpu.TPUEstimator(
44 | config=run_config,
45 | use_tpu=use_tpu,
46 | model_fn=self.model_fn,
47 | train_batch_size=batch_size)
48 |
49 | @abc.abstractmethod
50 | def as_module_spec(self):
51 | """Returns the generator network as TFHub module spec."""
52 |
53 | @abc.abstractmethod
54 | def input_fn(self, params, mode):
55 | """Input function that returns a `tf.data.Dataset` object.
56 |
57 | This function will be called once for each host machine.
58 |
59 | Args:
60 | params: Python dictionary with parameters given to TPUEstimator.
61 | Additionally, TPUEstimator will set the key `batch_size` with the batch
62 | size for this host machine and `tpu_context` with a TPUContext
63 | object.
64 | mode: `tf.estimator.ModeKeys` value.
65 |
66 | Returns:
67 | A `tf.data.Dataset` object with batched features and labels.
68 | """
69 |
70 | @abc.abstractmethod
71 | def model_fn(self, features, labels, params, mode):
72 | """Constructs the model for the given features and mode.
73 |
74 | This interface only requires implementing the TRAIN mode.
75 |
76 | On TPUs the model_fn should construct a graph for a single TPU core.
77 | Wrap the optimizer with a `tf.contrib.tpu.CrossShardOptimizer` to do
78 | synchronous training with all TPU cores.
79 |
80 | Args:
81 | features: A dictionary with the feature tensors.
82 | labels: Tensor with labels. Will be None if mode is PREDICT.
83 | params: Dictionary with hyperparameters passed to TPUEstimator.
84 | Additionally, TPUEstimator will set 3 keys: `batch_size`, `use_tpu`,
85 | `tpu_context`. `batch_size` is the batch size for this core.
86 | mode: `tf.estimator.ModeKeys` value (TRAIN, EVAL, PREDICT). The mode
87 | should be passed to the TPUEstimatorSpec and your model should be
88 | built for this mode.
89 |
90 | Returns:
91 | A `tf.contrib.tpu.TPUEstimatorSpec`.
92 | """
93 |
--------------------------------------------------------------------------------
/compare_gan/gans/consts.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Defines constants used across the code base."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 |
23 | NORMAL_INIT = "normal"
24 | TRUNCATED_INIT = "truncated"
25 | ORTHOGONAL_INIT = "orthogonal"
26 | INITIALIZERS = [NORMAL_INIT, TRUNCATED_INIT, ORTHOGONAL_INIT]
27 |
28 | DCGAN_ARCH = "dcgan_arch"
29 | DUMMY_ARCH = "dummy_arch"
30 | INFOGAN_ARCH = "infogan_arch"
31 | RESNET5_ARCH = "resnet5_arch"
32 | RESNET30_ARCH = "resnet30_arch"
33 | RESNET_BIGGAN_ARCH = "resnet_biggan_arch"
34 | RESNET_BIGGAN_DEEP_ARCH = "resnet_biggan_deep_arch"
35 | RESNET_CIFAR_ARCH = "resnet_cifar_arch"
36 | RESNET_STL_ARCH = "resnet_stl_arch"
37 | SNDCGAN_ARCH = "sndcgan_arch"
38 | ARCHITECTURES = [INFOGAN_ARCH, DCGAN_ARCH, RESNET5_ARCH, RESNET30_ARCH,
39 | RESNET_BIGGAN_ARCH, RESNET_BIGGAN_DEEP_ARCH, RESNET_CIFAR_ARCH,
40 | RESNET_STL_ARCH, SNDCGAN_ARCH]
41 |
--------------------------------------------------------------------------------
/compare_gan/gans/loss_lib.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of popular GAN losses."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan import utils
23 | import gin
24 | import tensorflow as tf
25 |
26 |
27 | def check_dimensions(d_real, d_fake, d_real_logits, d_fake_logits):
28 | """Checks the shapes and ranks of logits and prediction tensors.
29 |
30 | Args:
31 | d_real: prediction for real points, values in [0, 1], shape [batch_size, 1].
32 | d_fake: prediction for fake points, values in [0, 1], shape [batch_size, 1].
33 | d_real_logits: logits for real points, shape [batch_size, 1].
34 | d_fake_logits: logits for fake points, shape [batch_size, 1].
35 |
36 | Raises:
37 | ValueError: if the ranks or shapes are mismatched.
38 | """
39 | def _check_pair(a, b):
40 | if a != b:
41 | raise ValueError("Shape mismatch: %s vs %s." % (a, b))
42 | if len(a) != 2 or len(b) != 2:
43 | raise ValueError("Rank: expected 2, got %s and %s" % (len(a), len(b)))
44 |
45 | if (d_real is not None) and (d_fake is not None):
46 | _check_pair(d_real.shape.as_list(), d_fake.shape.as_list())
47 | if (d_real_logits is not None) and (d_fake_logits is not None):
48 | _check_pair(d_real_logits.shape.as_list(), d_fake_logits.shape.as_list())
49 | if (d_real is not None) and (d_real_logits is not None):
50 | _check_pair(d_real.shape.as_list(), d_real_logits.shape.as_list())
51 |
52 |
53 | @gin.configurable(whitelist=[])
54 | def non_saturating(d_real_logits, d_fake_logits, d_real=None, d_fake=None):
55 | """Returns the discriminator and generator loss for the non-saturating loss.
56 |
57 | Args:
58 | d_real_logits: logits for real points, shape [batch_size, 1].
59 | d_fake_logits: logits for fake points, shape [batch_size, 1].
60 | d_real: ignored.
61 | d_fake: ignored.
62 |
63 | Returns:
64 | A tuple consisting of the discriminator loss, discriminator's loss on the
65 | real samples and fake samples, and the generator's loss.
66 | """
67 | with tf.name_scope("non_saturating_loss"):
68 | check_dimensions(d_real, d_fake, d_real_logits, d_fake_logits)
69 | d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
70 | logits=d_real_logits, labels=tf.ones_like(d_real_logits),
71 | name="cross_entropy_d_real"))
72 | d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
73 | logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits),
74 | name="cross_entropy_d_fake"))
75 | d_loss = d_loss_real + d_loss_fake
76 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
77 | logits=d_fake_logits, labels=tf.ones_like(d_fake_logits),
78 | name="cross_entropy_g"))
79 | return d_loss, d_loss_real, d_loss_fake, g_loss
80 |
81 |
82 | @gin.configurable(whitelist=[])
83 | def wasserstein(d_real_logits, d_fake_logits, d_real=None, d_fake=None):
84 | """Returns the discriminator and generator loss for the Wasserstein loss.
85 |
86 | Args:
87 | d_real_logits: logits for real points, shape [batch_size, 1].
88 | d_fake_logits: logits for fake points, shape [batch_size, 1].
89 | d_real: ignored.
90 | d_fake: ignored.
91 |
92 | Returns:
93 | A tuple consisting of the discriminator loss, discriminator's loss on the
94 | real samples and fake samples, and the generator's loss.
95 | """
96 | with tf.name_scope("wasserstein_loss"):
97 | check_dimensions(d_real, d_fake, d_real_logits, d_fake_logits)
98 | d_loss_real = -tf.reduce_mean(d_real_logits)
99 | d_loss_fake = tf.reduce_mean(d_fake_logits)
100 | d_loss = d_loss_real + d_loss_fake
101 | g_loss = -d_loss_fake
102 | return d_loss, d_loss_real, d_loss_fake, g_loss
103 |
104 |
105 | @gin.configurable(whitelist=[])
106 | def least_squares(d_real, d_fake, d_real_logits=None, d_fake_logits=None):
107 | """Returns the discriminator and generator loss for the least-squares loss.
108 |
109 | Args:
110 | d_real: prediction for real points, values in [0, 1], shape [batch_size, 1].
111 | d_fake: prediction for fake points, values in [0, 1], shape [batch_size, 1].
112 | d_real_logits: ignored.
113 | d_fake_logits: ignored.
114 |
115 | Returns:
116 | A tuple consisting of the discriminator loss, discriminator's loss on the
117 | real samples and fake samples, and the generator's loss.
118 | """
119 | with tf.name_scope("least_square_loss"):
120 | check_dimensions(d_real, d_fake, d_real_logits, d_fake_logits)
121 | d_loss_real = tf.reduce_mean(tf.square(d_real - 1.0))
122 | d_loss_fake = tf.reduce_mean(tf.square(d_fake))
123 | d_loss = 0.5 * (d_loss_real + d_loss_fake)
124 | g_loss = 0.5 * tf.reduce_mean(tf.square(d_fake - 1.0))
125 | return d_loss, d_loss_real, d_loss_fake, g_loss
126 |
127 |
128 | @gin.configurable(whitelist=[])
129 | def hinge(d_real_logits, d_fake_logits, d_real=None, d_fake=None):
130 | """Returns the discriminator and generator loss for the hinge loss.
131 |
132 | Args:
133 | d_real_logits: logits for real points, shape [batch_size, 1].
134 | d_fake_logits: logits for fake points, shape [batch_size, 1].
135 | d_real: ignored.
136 | d_fake: ignored.
137 |
138 | Returns:
139 | A tuple consisting of the discriminator loss, discriminator's loss on the
140 | real samples and fake samples, and the generator's loss.
141 | """
142 | with tf.name_scope("hinge_loss"):
143 | check_dimensions(d_real, d_fake, d_real_logits, d_fake_logits)
144 | d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - d_real_logits))
145 | d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + d_fake_logits))
146 | d_loss = d_loss_real + d_loss_fake
147 | g_loss = -tf.reduce_mean(d_fake_logits)
148 | return d_loss, d_loss_real, d_loss_fake, g_loss
149 |
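The remaining three losses reduce to simple means, so they can be mirrored the same way. This NumPy sketch uses hypothetical function names but keeps the argument conventions of `loss_lib`. `wasserstein` and `hinge` take logits, while `least_squares` takes sigmoid outputs in [0, 1]. Note that `least_squares` halves `d_loss` but reports `d_loss_real` and `d_loss_fake` without the 0.5 factor.

```python
import numpy as np


def wasserstein_np(d_real_logits, d_fake_logits):
    """NumPy sketch of loss_lib.wasserstein."""
    d_loss_real = -np.mean(d_real_logits)
    d_loss_fake = np.mean(d_fake_logits)
    return d_loss_real + d_loss_fake, d_loss_real, d_loss_fake, -d_loss_fake


def least_squares_np(d_real, d_fake):
    """NumPy sketch of loss_lib.least_squares (inputs are probabilities)."""
    d_loss_real = np.mean(np.square(d_real - 1.0))
    d_loss_fake = np.mean(np.square(d_fake))
    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    g_loss = 0.5 * np.mean(np.square(d_fake - 1.0))
    return d_loss, d_loss_real, d_loss_fake, g_loss


def hinge_np(d_real_logits, d_fake_logits):
    """NumPy sketch of loss_lib.hinge."""
    # Real logits >= +1 and fake logits <= -1 incur no discriminator loss.
    d_loss_real = np.mean(np.maximum(0.0, 1.0 - d_real_logits))
    d_loss_fake = np.mean(np.maximum(0.0, 1.0 + d_fake_logits))
    g_loss = -np.mean(d_fake_logits)
    return d_loss_real + d_loss_fake, d_loss_real, d_loss_fake, g_loss
```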
150 |
151 | @gin.configurable("loss", whitelist=["fn"])
152 | def get_losses(fn=non_saturating, **kwargs):
153 | """Returns the losses for the discriminator and generator."""
154 | return utils.call_with_accepted_args(fn, **kwargs)
155 |
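`get_losses` passes every keyword argument to the configured function through `utils.call_with_accepted_args`. Likewise, `get_penalty_loss` in `penalty_lib` does the same. The source of that helper is not part of this section. Assuming it simply drops keywords the target does not declare, a stand-in could look like the following. `call_with_accepted_args_sketch` and `hinge_like` are hypothetical names. This filtering lets the caller pass both logits and probabilities, while each loss picks only the arguments it needs.

```python
import inspect


def call_with_accepted_args_sketch(fn, **kwargs):
    """Calls `fn` with only the keyword arguments its signature declares.

    Hypothetical stand-in for compare_gan.utils.call_with_accepted_args.
    """
    params = inspect.signature(fn).parameters
    accepted = {k: v for k, v in kwargs.items() if k in params}
    return fn(**accepted)


def no_penalty():
    return 0.0


def hinge_like(d_real_logits, d_fake_logits, d_real=None, d_fake=None):
    # Hypothetical loss with the same signature shape as loss_lib.hinge.
    return d_real_logits, d_fake_logits
```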
--------------------------------------------------------------------------------
/compare_gan/gans/modular_gan_conditional_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for conditional GANs with different architectures, losses and penalties."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 | from absl.testing import parameterized
24 | from compare_gan import datasets
25 | from compare_gan import test_utils
26 | from compare_gan.gans import consts as c
27 | from compare_gan.gans import loss_lib
28 | from compare_gan.gans import penalty_lib
29 | from compare_gan.gans.modular_gan import ModularGAN
30 | import gin
31 | import tensorflow as tf
32 |
33 |
34 | FLAGS = flags.FLAGS
35 | TEST_ARCHITECTURES = [c.RESNET5_ARCH, c.RESNET_BIGGAN_ARCH, c.RESNET_CIFAR_ARCH]
36 | TEST_LOSSES = [loss_lib.non_saturating, loss_lib.wasserstein,
37 | loss_lib.least_squares, loss_lib.hinge]
38 | TEST_PENALTIES = [penalty_lib.no_penalty, penalty_lib.dragan_penalty,
39 | penalty_lib.wgangp_penalty, penalty_lib.l2_penalty]
40 |
41 |
42 | class ModularGANConditionalTest(parameterized.TestCase,
43 | test_utils.CompareGanTestCase):
44 |
45 | def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn,
46 | labeled_dataset):
47 | parameters = {
48 | "architecture": architecture,
49 | "lambda": 1,
50 | "z_dim": 120,
51 | }
52 | with gin.unlock_config():
53 | gin.bind_parameter("penalty.fn", penalty_fn)
54 | gin.bind_parameter("loss.fn", loss_fn)
55 | model_dir = self._get_empty_model_dir()
56 | run_config = tf.contrib.tpu.RunConfig(
57 | model_dir=model_dir,
58 | tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
59 | dataset = datasets.get_dataset("cifar10")
60 | gan = ModularGAN(
61 | dataset=dataset,
62 | parameters=parameters,
63 | conditional=True,
64 | model_dir=model_dir)
65 | estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
66 | estimator.train(gan.input_fn, steps=1)
67 |
68 | @parameterized.parameters(TEST_ARCHITECTURES)
69 | def testSingleTrainingStepArchitectures(self, architecture):
70 | self._runSingleTrainingStep(architecture, loss_lib.hinge,
71 | penalty_lib.no_penalty, True)
72 |
73 | @parameterized.parameters(TEST_LOSSES)
74 | def testSingleTrainingStepLosses(self, loss_fn):
75 | self._runSingleTrainingStep(c.RESNET_CIFAR_ARCH, loss_fn,
76 | penalty_lib.no_penalty, labeled_dataset=True)
77 |
78 | @parameterized.parameters(TEST_PENALTIES)
79 | def testSingleTrainingStepPenalties(self, penalty_fn):
80 | self._runSingleTrainingStep(c.RESNET_CIFAR_ARCH, loss_lib.hinge, penalty_fn,
81 | labeled_dataset=True)
82 |
83 | def testUnlabeledDatasetRaisesError(self):
84 | parameters = {
85 | "architecture": c.RESNET_CIFAR_ARCH,
86 | "lambda": 1,
87 | "z_dim": 120,
88 | }
89 | with gin.unlock_config():
90 | gin.bind_parameter("loss.fn", loss_lib.hinge)
91 | # Use dataset without labels.
92 | dataset = datasets.get_dataset("celeb_a")
93 | model_dir = self._get_empty_model_dir()
94 | with self.assertRaises(ValueError):
95 | gan = ModularGAN(
96 | dataset=dataset,
97 | parameters=parameters,
98 | conditional=True,
99 | model_dir=model_dir)
100 | del gan
101 |
102 |
103 | if __name__ == "__main__":
104 | tf.test.main()
105 |
--------------------------------------------------------------------------------
/compare_gan/gans/modular_gan_tpu_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests TPU-specific parts of ModularGAN."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 | from absl.testing import parameterized
24 | from compare_gan import datasets
25 | from compare_gan import test_utils
26 | from compare_gan.gans import consts as c
27 | from compare_gan.gans.modular_gan import ModularGAN
28 | import tensorflow as tf
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | class ModularGanTpuTest(parameterized.TestCase, test_utils.CompareGanTestCase):
34 |
35 | def setUp(self):
36 | super(ModularGanTpuTest, self).setUp()
37 | self.model_dir = self._get_empty_model_dir()
38 | self.run_config = tf.contrib.tpu.RunConfig(
39 | model_dir=self.model_dir,
40 | tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
41 |
42 | @parameterized.parameters([1, 2, 5])
43 | def testBatchSize(self, disc_iters, use_tpu=True):
44 | parameters = {
45 | "architecture": c.DUMMY_ARCH,
46 | "lambda": 1,
47 | "z_dim": 128,
48 | "disc_iters": disc_iters,
49 | }
50 | batch_size = 16
51 | dataset = datasets.get_dataset("cifar10")
52 | gan = ModularGAN(
53 | dataset=dataset,
54 | parameters=parameters,
55 | model_dir=self.model_dir)
56 | estimator = gan.as_estimator(self.run_config, batch_size=batch_size,
57 | use_tpu=True)
58 | estimator.train(gan.input_fn, steps=1)
59 |
60 | gen_args = gan.generator.call_arg_list
61 | disc_args = gan.discriminator.call_arg_list
62 | self.assertLen(gen_args, disc_iters + 1) # D steps, G step.
63 | self.assertLen(disc_args, disc_iters + 1) # D steps, G step.
64 |
65 | for args in gen_args:
66 | self.assertAllEqual(args["z"].shape.as_list(), [8, 128])
67 | for args in disc_args:
68 | self.assertAllEqual(args["x"].shape.as_list(), [16, 32, 32, 3])
69 |
70 | @parameterized.parameters([1, 2, 5])
71 | def testBatchSizeSplitDiscCalls(self, disc_iters):
72 | parameters = {
73 | "architecture": c.DUMMY_ARCH,
74 | "lambda": 1,
75 | "z_dim": 128,
76 | "disc_iters": disc_iters,
77 | }
78 | batch_size = 16
79 | dataset = datasets.get_dataset("cifar10")
80 | gan = ModularGAN(
81 | dataset=dataset,
82 | parameters=parameters,
83 | deprecated_split_disc_calls=True,
84 | model_dir=self.model_dir)
85 | estimator = gan.as_estimator(self.run_config, batch_size=batch_size,
86 | use_tpu=True)
87 | estimator.train(gan.input_fn, steps=1)
88 |
89 | gen_args = gan.generator.call_arg_list
90 | disc_args = gan.discriminator.call_arg_list
91 | self.assertLen(gen_args, disc_iters + 1) # D steps, G step.
92 | # Each D and G step calls discriminator twice: for real and fake images.
93 | self.assertLen(disc_args, 2 * (disc_iters + 1))
94 |
95 | for args in gen_args:
96 | self.assertAllEqual(args["z"].shape.as_list(), [8, 128])
97 | for args in disc_args:
98 | self.assertAllEqual(args["x"].shape.as_list(), [8, 32, 32, 3])
99 |
100 | @parameterized.parameters([1, 2, 5])
101 | def testBatchSizeExperimentalJointGenForDisc(self, disc_iters):
102 | parameters = {
103 | "architecture": c.DUMMY_ARCH,
104 | "lambda": 1,
105 | "z_dim": 128,
106 | "disc_iters": disc_iters,
107 | }
108 | batch_size = 16
109 | dataset = datasets.get_dataset("cifar10")
110 | gan = ModularGAN(
111 | dataset=dataset,
112 | parameters=parameters,
113 | experimental_joint_gen_for_disc=True,
114 | model_dir=self.model_dir)
115 | estimator = gan.as_estimator(self.run_config, batch_size=batch_size,
116 | use_tpu=True)
117 | estimator.train(gan.input_fn, steps=1)
118 |
119 | gen_args = gan.generator.call_arg_list
120 | disc_args = gan.discriminator.call_arg_list
121 | self.assertLen(gen_args, 2)
122 | self.assertLen(disc_args, disc_iters + 1)
123 |
124 | self.assertAllEqual(gen_args[0]["z"].shape.as_list(), [8 * disc_iters, 128])
125 | self.assertAllEqual(gen_args[1]["z"].shape.as_list(), [8, 128])
126 | for args in disc_args:
127 | self.assertAllEqual(args["x"].shape.as_list(), [16, 32, 32, 3])
128 |
129 |
130 | if __name__ == "__main__":
131 | tf.test.main()
132 |
--------------------------------------------------------------------------------
/compare_gan/gans/ops.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Customized TensorFlow operations."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan.tpu import tpu_random
23 |
24 | random_uniform = tpu_random.uniform
25 | random_normal = tpu_random.normal
26 |
--------------------------------------------------------------------------------
/compare_gan/gans/penalty_lib.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of popular GAN penalties."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan import utils
23 | from compare_gan.gans import ops
24 | import gin
25 | import tensorflow as tf
26 |
27 |
28 | @gin.configurable
29 | def no_penalty():
30 | return tf.constant(0.0)
31 |
32 |
33 | @gin.configurable(whitelist=[])
34 | def dragan_penalty(discriminator, x, y, is_training):
35 | """Returns the DRAGAN gradient penalty.
36 |
37 | Args:
38 | discriminator: Instance of `AbstractDiscriminator`.
39 | x: Samples from the true distribution, shape [bs, h, w, channels].
40 | y: Encoded class embedding for the samples. None for unsupervised models.
41 | is_training: boolean, whether the model is in training or evaluation mode.
42 |
43 | Returns:
44 | A tensor with the computed penalty.
45 | """
46 | with tf.name_scope("dragan_penalty"):
47 | _, var = tf.nn.moments(x, axes=list(range(len(x.get_shape()))))
48 | std = tf.sqrt(var)
49 | x_noisy = x + std * (ops.random_uniform(x.shape) - 0.5)
50 | x_noisy = tf.clip_by_value(x_noisy, 0.0, 1.0)
51 | logits = discriminator(x_noisy, y=y, is_training=is_training, reuse=True)[1]
52 | gradients = tf.gradients(logits, [x_noisy])[0]
53 | slopes = tf.sqrt(0.0001 + tf.reduce_sum(
54 | tf.square(gradients), axis=[1, 2, 3]))
55 | gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0))
56 | return gradient_penalty
57 |
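The DRAGAN perturbation step is easy to misread. `tf.nn.moments` over every axis yields one scalar standard deviation for the whole batch, not one per pixel. Here is a NumPy sketch of just that step. It assumes images in [0, 1] and a caller-supplied `numpy.random.Generator`, and the function name is hypothetical:

```python
import numpy as np


def dragan_perturb(x, rng):
    """Mirrors the x_noisy computation in dragan_penalty (NumPy sketch)."""
    # Moments over all axes give a single scalar variance for the batch.
    std = np.sqrt(np.var(x))
    # Uniform noise in [-0.5, 0.5), scaled by the batch standard deviation.
    x_noisy = x + std * (rng.uniform(size=x.shape) - 0.5)
    # Images are assumed to live in [0, 1], hence the clip.
    return np.clip(x_noisy, 0.0, 1.0)
```

Each pixel therefore moves by at most half the batch standard deviation, and clipping only moves it back toward the original value. A constant batch is left unchanged.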
58 |
59 | @gin.configurable(whitelist=[])
60 | def wgangp_penalty(discriminator, x, x_fake, y, is_training):
61 | """Returns the WGAN gradient penalty.
62 |
63 | Args:
64 | discriminator: Instance of `AbstractDiscriminator`.
65 | x: samples from the true distribution, shape [bs, h, w, channels].
66 | x_fake: samples from the fake distribution, shape [bs, h, w, channels].
67 | y: Encoded class embedding for the samples. None for unsupervised models.
68 | is_training: boolean, whether the model is in training or evaluation mode.
69 |
70 | Returns:
71 | A tensor with the computed penalty.
72 | """
73 | with tf.name_scope("wgangp_penalty"):
74 | alpha = ops.random_uniform(shape=[x.shape[0].value, 1, 1, 1], name="alpha")
75 | interpolates = x + alpha * (x_fake - x)
76 | logits = discriminator(
77 | interpolates, y=y, is_training=is_training, reuse=True)[1]
78 | gradients = tf.gradients(logits, [interpolates])[0]
79 | slopes = tf.sqrt(0.0001 + tf.reduce_sum(
80 | tf.square(gradients), axis=[1, 2, 3]))
81 | gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0))
82 | return gradient_penalty
83 |
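The WGAN-GP penalty evaluates the critic's input gradient at random points on the segments between real and fake samples. Computing that gradient generally needs autodiff. For that reason, this NumPy sketch takes a caller-supplied `grad_fn`, which is a hypothetical parameter. For a linear critic `D(x) = sum(w * x)`, the gradient is simply `w`. For `D(x) = 0.5 * sum(x ** 2)`, it is `x`.

```python
import numpy as np


def wgangp_penalty_np(grad_fn, x, x_fake, rng):
    """NumPy sketch of wgangp_penalty with an analytic critic gradient."""
    # One interpolation coefficient per sample, broadcast over h, w, channels.
    alpha = rng.uniform(size=(x.shape[0], 1, 1, 1))
    interpolates = x + alpha * (x_fake - x)
    gradients = grad_fn(interpolates)
    # Per-sample gradient norm, with the same 1e-4 stabilizer as the TF code.
    slopes = np.sqrt(1e-4 + np.sum(np.square(gradients), axis=(1, 2, 3)))
    # Penalize deviation of the norm from 1 (the 1-Lipschitz target).
    return np.mean(np.square(slopes - 1.0))
```

The 1e-4 term keeps the square root differentiable when the gradient vanishes. As a result, a zero gradient gives a slope of 0.01 rather than 0.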
84 |
85 | @gin.configurable(whitelist=[])
86 | def l2_penalty(discriminator):
87 | """Returns the mean L2 penalty over all kernel weights, excluding biases.
88 |
89 | Assumes a specific tensor naming followed throughout the compare_gan library.
90 | We penalize all fully connected, conv2d, and deconv2d layers.
91 |
92 | Args:
93 | discriminator: Instance of `AbstractDiscriminator`.
94 |
95 | Returns:
96 | A tensor with the computed penalty.
97 | """
98 | with tf.name_scope("l2_penalty"):
99 | d_weights = [v for v in discriminator.trainable_variables
100 | if v.name.endswith("/kernel:0")]
101 | return tf.reduce_mean(
102 | [tf.nn.l2_loss(i) for i in d_weights], name="l2_penalty")
103 |
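`l2_penalty` selects variables purely by name. The NumPy sketch below mirrors that selection on a plain dict of `{variable_name: array}`. The example names in the test are hypothetical, but they follow the `/kernel:0` suffix the code filters on. Because `tf.nn.l2_loss(t)` is `sum(t ** 2) / 2`, the result is the mean of half squared norms across kernels.

```python
import numpy as np


def l2_penalty_np(named_weights):
    """Mean of 0.5 * sum(w ** 2) over variables named '*/kernel:0'."""
    # Biases and other variables do not end in '/kernel:0' and are skipped.
    kernels = [w for name, w in named_weights.items()
               if name.endswith("/kernel:0")]
    return np.mean([0.5 * np.sum(np.square(w)) for w in kernels])
```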
104 |
105 | @gin.configurable("penalty", whitelist=["fn"])
106 | def get_penalty_loss(fn=no_penalty, **kwargs):
107 | """Returns the penalty loss."""
108 | return utils.call_with_accepted_args(fn, **kwargs)
109 |
--------------------------------------------------------------------------------
/compare_gan/gans/s3gan_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for S3GANs."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 | from absl.testing import parameterized
24 | from compare_gan import datasets
25 | from compare_gan import test_utils
26 | from compare_gan.gans import consts as c
27 | from compare_gan.gans import loss_lib
28 | from compare_gan.gans.s3gan import S3GAN
29 | import gin
30 | import tensorflow as tf
31 |
32 | FLAGS = flags.FLAGS
33 |
34 |
35 | class S3GANTest(parameterized.TestCase, test_utils.CompareGanTestCase):
36 |
37 | @parameterized.parameters(
38 | {"use_predictor": False, "project_y": False}, # unsupervised.
39 | {"use_predictor": False}, # fully supervised.
40 | {"use_predictor": True}, # only oracle.
41 | {"use_predictor": True, "self_supervision": "rotation"}, # oracle + SS.
42 | {"use_predictor": False, "self_supervision": "rotation"}, # only SS.
43 | )
44 | def testSingleTrainingStepArchitectures(
45 | self, use_predictor, project_y=True, self_supervision="none"):
46 | parameters = {
47 | "architecture": c.RESNET_BIGGAN_ARCH,
48 | "lambda": 1,
49 | "z_dim": 120,
50 | }
51 | with gin.unlock_config():
52 | gin.bind_parameter("ModularGAN.conditional", True)
53 | gin.bind_parameter("loss.fn", loss_lib.hinge)
54 | gin.bind_parameter("S3GAN.use_predictor", use_predictor)
55 | gin.bind_parameter("S3GAN.project_y", project_y)
56 | gin.bind_parameter("S3GAN.self_supervision", self_supervision)
57 | # Fake ImageNet dataset by overriding the properties.
58 | dataset = datasets.get_dataset("imagenet_128")
59 | model_dir = self._get_empty_model_dir()
60 | run_config = tf.contrib.tpu.RunConfig(
61 | model_dir=model_dir,
62 | tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
63 | gan = S3GAN(
64 | dataset=dataset,
65 | parameters=parameters,
66 | model_dir=model_dir,
67 | g_optimizer_fn=tf.train.AdamOptimizer,
68 | g_lr=0.0002,
69 | rotated_batch_fraction=2)
70 | estimator = gan.as_estimator(run_config, batch_size=8, use_tpu=False)
71 | estimator.train(gan.input_fn, steps=1)
72 |
73 |
74 | if __name__ == "__main__":
75 | tf.test.main()
76 |
77 |
--------------------------------------------------------------------------------
/compare_gan/gans/ssgan_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for SSGANs."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 | from absl.testing import parameterized
24 | from compare_gan import datasets
25 | from compare_gan import test_utils
26 | from compare_gan.gans import consts as c
27 | from compare_gan.gans import loss_lib
28 | from compare_gan.gans import penalty_lib
29 | from compare_gan.gans.ssgan import SSGAN
30 | import gin
31 | import tensorflow as tf
32 |
33 | FLAGS = flags.FLAGS
34 | TEST_ARCHITECTURES = [c.RESNET_CIFAR_ARCH, c.SNDCGAN_ARCH, c.RESNET5_ARCH]
35 | TEST_LOSSES = [loss_lib.non_saturating, loss_lib.hinge]
36 | TEST_PENALTIES = [penalty_lib.no_penalty, penalty_lib.wgangp_penalty]
37 |
38 |
39 | class SSGANTest(parameterized.TestCase, test_utils.CompareGanTestCase):
40 |
41 | def _runSingleTrainingStep(self, architecture, loss_fn, penalty_fn):
42 | parameters = {
43 | "architecture": architecture,
44 | "lambda": 1,
45 | "z_dim": 128,
46 | }
47 | with gin.unlock_config():
48 | gin.bind_parameter("penalty.fn", penalty_fn)
49 | gin.bind_parameter("loss.fn", loss_fn)
50 | model_dir = self._get_empty_model_dir()
51 | run_config = tf.contrib.tpu.RunConfig(
52 | model_dir=model_dir,
53 | tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1))
54 | dataset = datasets.get_dataset("cifar10")
55 | gan = SSGAN(
56 | dataset=dataset,
57 | parameters=parameters,
58 | model_dir=model_dir,
59 | g_optimizer_fn=tf.train.AdamOptimizer,
60 | g_lr=0.0002,
61 | rotated_batch_size=4)
62 | estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
63 | estimator.train(gan.input_fn, steps=1)
64 |
65 | @parameterized.parameters(TEST_ARCHITECTURES)
66 | def testSingleTrainingStepArchitectures(self, architecture):
67 | self._runSingleTrainingStep(architecture, loss_lib.hinge,
68 | penalty_lib.no_penalty)
69 |
70 | @parameterized.parameters(TEST_LOSSES)
71 | def testSingleTrainingStepLosses(self, loss_fn):
72 | self._runSingleTrainingStep(c.RESNET_CIFAR_ARCH, loss_fn,
73 | penalty_lib.no_penalty)
74 |
75 | @parameterized.parameters(TEST_PENALTIES)
76 | def testSingleTrainingStepPenalties(self, penalty_fn):
77 | self._runSingleTrainingStep(c.RESNET_CIFAR_ARCH, loss_lib.hinge, penalty_fn)
78 |
79 |
80 | if __name__ == "__main__":
81 | tf.test.main()
82 |
--------------------------------------------------------------------------------
/compare_gan/gans/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities library."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import numpy as np
23 | import scipy.misc
24 | import tensorflow as tf
25 |
26 |
27 | def check_folder(log_dir):
28 | if not tf.gfile.IsDirectory(log_dir):
29 | tf.gfile.MakeDirs(log_dir)
30 | return log_dir
31 |
32 |
33 | def save_images(images, image_path):
34 | with tf.gfile.Open(image_path, "wb") as f:
35 | scipy.misc.imsave(f, images * 255.0)
36 |
37 |
38 | def rotate_images(images, rot90_scalars=(0, 1, 2, 3)):
39 | """Returns the images rotated by k * 90 degrees for each k in `rot90_scalars`, concatenated along the batch axis."""
40 | images_rotated = [
41 | images, # 0 degree
42 | tf.image.flip_up_down(tf.image.transpose_image(images)), # 90 degrees
43 | tf.image.flip_left_right(tf.image.flip_up_down(images)), # 180 degrees
44 | tf.image.transpose_image(tf.image.flip_up_down(images)) # 270 degrees
45 | ]
46 |
47 | results = tf.stack([images_rotated[i] for i in rot90_scalars])
48 | results = tf.reshape(results,
49 | [-1] + images.get_shape().as_list()[1:])
50 | return results
51 |
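The flip/transpose combinations in `rotate_images` correspond to quarter turns in the height/width plane. Below is a NumPy sketch that assumes NHWC batches with square images. In that setting, `np.rot90` with `axes=(1, 2)` and `k = 1` equals `flip_up_down(transpose_image(.))`, and `k = 3` equals `transpose_image(flip_up_down(.))`. The function name is hypothetical:

```python
import numpy as np


def rotate_images_np(images, rot90_scalars=(0, 1, 2, 3)):
    """NumPy sketch of rotate_images for NHWC batches."""
    # k quarter turns in the (height, width) plane for each requested k.
    rotated = [np.rot90(images, k=k, axes=(1, 2)) for k in rot90_scalars]
    # Concatenating along the batch axis reproduces tf.stack + reshape([-1, ...]).
    return np.concatenate(rotated, axis=0)
```

The output is rotation-major: first all images at the first rotation, then all images at the second, and so on. Non-square images would fail in both versions, because the rotated tensors no longer share a shape.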
52 |
53 | def gaussian(batch_size, n_dim, mean=0., var=1.):
54 | return np.random.normal(mean, np.sqrt(var), (batch_size, n_dim)).astype(np.float32)
55 |
--------------------------------------------------------------------------------
/compare_gan/hooks.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Contains SessionRunHooks for training."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import time
23 |
24 | from absl import logging
25 | import tensorflow as tf
26 |
27 |
28 | class AsyncCheckpointSaverHook(tf.contrib.tpu.AsyncCheckpointSaverHook):
29 | """Saves checkpoints every N steps in an asynchronous thread.
30 |
31 | This is the same as tf.contrib.tpu.AsyncCheckpointSaverHook but guarantees
32 | that there will be a checkpoint every `save_steps` steps. This helps to have
33 | eval results at fixed step counts, even when training is paused between
34 | regular checkpoint intervals.
35 | """
36 |
37 | def after_create_session(self, session, coord):
38 | super(AsyncCheckpointSaverHook, self).after_create_session(session, coord)
39 | # Interruptions to the training job can cause non-regular checkpoints
40 | # (between every_steps). Modify last triggered step to point to the last
41 | # regular checkpoint step to make sure we trigger on the next regular
42 | # checkpoint step.
43 | step = session.run(self._global_step_tensor)
44 | every_steps = self._timer._every_steps # pylint: disable=protected-access
45 | last_triggered_step = step - step % every_steps
46 | self._timer.update_last_triggered_step(last_triggered_step)
47 |
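The `step - step % every_steps` adjustment in `after_create_session` rounds the resumed step down to the last regular checkpoint. The sketch below simulates which steps would then trigger a save. It assumes the step-based timer fires once `step >= last_triggered + every_steps`. `regular_save_steps` is a hypothetical helper.

```python
def regular_save_steps(resume_step, every_steps, last_step, round_down=True):
    """Steps at which a step-based timer would fire after resuming training."""
    if round_down:
        # Same arithmetic as AsyncCheckpointSaverHook.after_create_session.
        last_triggered = resume_step - resume_step % every_steps
    else:
        last_triggered = resume_step
    fired = []
    for step in range(resume_step + 1, last_step + 1):
        if step >= last_triggered + every_steps:
            fired.append(step)
            last_triggered = step
    return fired
```

Consider resuming at step 2350 with `every_steps=1000`. With the rounding, saves happen at 3000, 4000, and so on. Without it, the schedule would drift to 3350, 4350, and so on.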
48 |
49 | class EveryNSteps(tf.train.SessionRunHook):
50 | """Base class for hooks that execute callbacks every N steps.
51 |
52 | class MyHook(EveryNSteps):
53 | def __init__(self, every_n_steps):
54 | super(MyHook, self).__init__(every_n_steps)
55 |
56 | def every_n_steps_after_run(self, step, run_context, run_values):
57 | # Your Implementation
58 |
59 | If you override begin(), end(), before_run() or after_run(), make sure to
60 | call the super() implementation first.
61 | """
62 |
63 | def __init__(self, every_n_steps):
64 | """Initializes an `EveryNSteps` hook.
65 |
66 | Args:
67 | every_n_steps: `int`, the number of steps to allow between callbacks.
68 | """
69 | self._timer = tf.train.SecondOrStepTimer(every_steps=every_n_steps)
70 | self._global_step_tensor = None
71 |
72 | def begin(self):
73 | self._global_step_tensor = tf.train.get_global_step()
74 | if self._global_step_tensor is None:
75 | raise RuntimeError("Global step must be created to use EveryNSteps.")
76 |
77 | def before_run(self, run_context): # pylint: disable=unused-argument
78 | """Overrides `SessionRunHook.before_run`.
79 |
80 | Args:
81 | run_context: A `SessionRunContext` object.
82 |
83 | Returns:
84 | None or a `SessionRunArgs` object.
85 | """
86 | return tf.train.SessionRunArgs({"global_step": self._global_step_tensor})
87 |
88 | def after_run(self, run_context, run_values):
89 | """Overrides `SessionRunHook.after_run`.
90 |
91 | Args:
92 | run_context: A `SessionRunContext` object.
93 | run_values: A SessionRunValues object.
94 | """
95 | step = run_values.results["global_step"]
96 | if self._timer.should_trigger_for_step(step):
97 | self.every_n_steps_after_run(step, run_context, run_values)
98 | self._timer.update_last_triggered_step(step)
99 |
100 | def end(self, sess):
101 | step = sess.run(self._global_step_tensor)
102 | self.every_n_steps_after_run(step, None, None)
103 |
104 | def every_n_steps_after_run(self, step, run_context, run_values):
105 | """Callback after every n-th call to run().
106 |
107 | Args:
108 | step: Current global_step value.
109 | run_context: A `SessionRunContext` object.
110 | run_values: A SessionRunValues object.
111 | """
112 | raise NotImplementedError("Subclasses of EveryNSteps should implement "
113 | "every_n_steps_after_run().")
114 |
115 |
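To make the trigger schedule of `EveryNSteps` concrete, here is a TF-free sketch. `StepTimer` is a hypothetical stand-in that assumes the step-only behaviour of `tf.train.SecondOrStepTimer`: fire on the first step seen, never twice for the same step, then whenever the step has advanced by at least `every_steps` since the last trigger.

```python
class StepTimer(object):
  """Step-only stand-in for tf.train.SecondOrStepTimer (hypothetical)."""

  def __init__(self, every_steps):
    self._every_steps = every_steps
    self._last_triggered_step = None

  def should_trigger_for_step(self, step):
    if self._last_triggered_step is None:
      return True  # Always fire on the first step seen.
    if step == self._last_triggered_step:
      return False  # Never fire twice for the same step.
    return step >= self._last_triggered_step + self._every_steps

  def update_last_triggered_step(self, step):
    self._last_triggered_step = step


def simulate(every_n_steps, steps):
  """Returns the steps at which an EveryNSteps-style callback would run."""
  timer = StepTimer(every_n_steps)
  fired = []
  for step in steps:
    if timer.should_trigger_for_step(step):
      fired.append(step)  # Stands in for every_n_steps_after_run().
      timer.update_last_triggered_step(step)
  return fired


print(simulate(3, range(1, 11)))  # [1, 4, 7, 10]
```

In the real hook, `end()` additionally calls `every_n_steps_after_run()` once unconditionally, so the final step is always reported even when it falls between triggers.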
116 | class ReportProgressHook(EveryNSteps):
117 | """SessionRunHook that reports progress to a `TaskManager` instance."""
118 |
119 | def __init__(self, task_manager, max_steps, every_n_steps=100):
120 | """Create a new instance of ReportProgressHook.
121 |
122 | Args:
123 | task_manager: A `TaskManager` instance that implements report_progress().
124 | max_steps: Maximum number of training steps.
125 | every_n_steps: How frequently the hook should report progress.
126 | """
127 | super(ReportProgressHook, self).__init__(every_n_steps=every_n_steps)
128 | logging.info("Creating ReportProgressHook to report progress every %d "
129 | "steps.", every_n_steps)
130 | self.max_steps = max_steps
131 | self.task_manager = task_manager
132 | self.start_time = None
133 | self.start_step = None
134 |
135 | def every_n_steps_after_run(self, step, run_context, run_values):
136 | if self.start_time is None:
137 | # First call.
138 | self.start_time = time.time()
139 | self.start_step = step
140 | return
141 |
142 | time_elapsed = time.time() - self.start_time
143 | steps_per_sec = float(step - self.start_step) / time_elapsed
144 | eta_seconds = (self.max_steps - step) / (steps_per_sec + 0.0000001)
145 | message = "{:.1f}% @{:d}, {:.1f} steps/s, ETA: {:.0f} min".format(
146 | 100 * step / self.max_steps, step, steps_per_sec, eta_seconds / 60)
147 | logging.info("Reporting progress: %s", message)
148 | self.task_manager.report_progress(message)
149 |
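The progress message in `ReportProgressHook` is simple arithmetic on the elapsed time. The sketch below (hypothetical function name) lifts that arithmetic out of the hook, passing the elapsed wall time in explicitly instead of reading `time.time()`, so the format can be checked by hand.

```python
def format_progress(step, start_step, max_steps, time_elapsed):
  """Builds the message that ReportProgressHook would report."""
  steps_per_sec = float(step - start_step) / time_elapsed
  # The small epsilon avoids dividing by zero when steps_per_sec is 0.
  eta_seconds = (max_steps - step) / (steps_per_sec + 0.0000001)
  return "{:.1f}% @{:d}, {:.1f} steps/s, ETA: {:.0f} min".format(
      100.0 * step / max_steps, step, steps_per_sec, eta_seconds / 60)


print(format_progress(step=2000, start_step=1000, max_steps=11000,
                      time_elapsed=100.0))
# 18.2% @2000, 10.0 steps/s, ETA: 15 min
```

Because the first call only records `start_time` and `start_step`, the reported rate is averaged over everything since the first trigger rather than over the most recent interval.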
--------------------------------------------------------------------------------
/compare_gan/main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Binary to train and evaluate one GAN configuration."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 |
24 | # pylint: disable=unused-import
25 |
26 | from absl import app
27 | from absl import flags
28 | from absl import logging
29 |
30 | from compare_gan import datasets
31 | from compare_gan import runner_lib
32 | # Import GAN types so that they can be used in Gin configs without module names.
33 | from compare_gan.gans.modular_gan import ModularGAN
34 | from compare_gan.gans.s3gan import S3GAN
35 | from compare_gan.gans.ssgan import SSGAN
36 |
37 | # Required import to configure core TF classes and functions.
38 | import gin
39 | import gin.tf.external_configurables
40 | import tensorflow as tf
41 |
42 |
43 | FLAGS = flags.FLAGS
44 |
45 | flags.DEFINE_string("model_dir", None, "Where to store files.")
46 | flags.DEFINE_string(
47 | "schedule", "train",
48 | "Schedule to run. Options: train, continuous_eval.")
49 | flags.DEFINE_multi_string(
50 | "gin_config", [],
51 | "List of paths to the config files.")
52 | flags.DEFINE_multi_string(
53 | "gin_bindings", [],
54 | "Newline separated list of Gin parameter bindings.")
55 | flags.DEFINE_string(
56 | "score_filename", "scores.csv",
57 | "Name of the CSV file with evaluation results in model_dir.")
58 |
59 | flags.DEFINE_integer(
60 | "num_eval_averaging_runs", 3,
61 | "Number of runs over which FID and IS are averaged.")
62 | flags.DEFINE_integer(
63 | "eval_every_steps", 5000,
64 | "Evaluate only checkpoints whose step is divisible by this integer")
65 |
66 | flags.DEFINE_bool("use_tpu", None, "Whether running on TPU or not.")
67 |
68 |
69 | def _get_cluster():
70 | if not FLAGS.use_tpu: # pylint: disable=unreachable
71 | return None
72 | if "TPU_NAME" not in os.environ:
73 | raise ValueError("Could not find a TPU. Set TPU_NAME.")
74 | return tf.contrib.cluster_resolver.TPUClusterResolver(
75 | tpu=os.environ["TPU_NAME"],
76 | zone=os.environ.get("TPU_ZONE", None))
77 |
78 |
79 | @gin.configurable("run_config")
80 | def _get_run_config(tf_random_seed=None,
81 | single_core=False,
82 | iterations_per_loop=1000,
83 | save_checkpoints_steps=5000,
84 | keep_checkpoint_max=1000):
85 | """Return `RunConfig` for TPUs."""
86 | tpu_config = tf.contrib.tpu.TPUConfig(
87 | num_shards=1 if single_core else None, # None = all cores.
88 | iterations_per_loop=iterations_per_loop)
89 | return tf.contrib.tpu.RunConfig(
90 | model_dir=FLAGS.model_dir,
91 | tf_random_seed=tf_random_seed,
92 | save_checkpoints_steps=save_checkpoints_steps,
93 | keep_checkpoint_max=keep_checkpoint_max,
94 | cluster=_get_cluster(),
95 | tpu_config=tpu_config)
96 |
97 |
98 |
99 |
100 | def _get_task_manager():
101 | """Returns a TaskManager for this experiment."""
102 | score_file = os.path.join(FLAGS.model_dir, FLAGS.score_filename)
103 | return runner_lib.TaskManagerWithCsvResults(
104 | model_dir=FLAGS.model_dir, score_file=score_file)
105 |
106 |
107 | def main(unused_argv):
108 | logging.info("Gin config: %s\nGin bindings: %s",
109 | FLAGS.gin_config, FLAGS.gin_bindings)
110 | gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
111 |
112 |
113 | if FLAGS.use_tpu is None:
114 | FLAGS.use_tpu = bool(os.environ.get("TPU_NAME", ""))
115 | if FLAGS.use_tpu:
116 | logging.info("Found TPU %s.", os.environ["TPU_NAME"])
117 | run_config = _get_run_config()
118 | task_manager = _get_task_manager()
119 | options = runner_lib.get_options_dict()
120 | runner_lib.run_with_schedule(
121 | schedule=FLAGS.schedule,
122 | run_config=run_config,
123 | task_manager=task_manager,
124 | options=options,
125 | use_tpu=FLAGS.use_tpu,
126 | num_eval_averaging_runs=FLAGS.num_eval_averaging_runs,
127 | eval_every_steps=FLAGS.eval_every_steps)
128 | logging.info("I'm done with my work, ciao!")
129 |
130 |
131 | if __name__ == "__main__":
132 | flags.mark_flag_as_required("model_dir")
133 | app.run(main)
134 |
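`main()` fills in `--use_tpu` from the environment when the flag is left unset. The sketch below isolates that rule. `resolve_use_tpu` is a hypothetical helper, and the environment is passed as a plain dict so the logic can be tried without a TPU.

```python
import os


def resolve_use_tpu(use_tpu_flag, environ=None):
  """Mirrors main(): an explicit flag wins, None falls back to TPU_NAME."""
  if environ is None:
    environ = os.environ
  if use_tpu_flag is None:
    # An unset or empty TPU_NAME means "not on TPU".
    return bool(environ.get("TPU_NAME", ""))
  return use_tpu_flag


print(resolve_use_tpu(None, {}))                           # False
print(resolve_use_tpu(None, {"TPU_NAME": "my-tpu"}))       # True
print(resolve_use_tpu(False, {"TPU_NAME": "my-tpu"}))      # False
```

An explicit `--use_tpu=true` without `TPU_NAME` is not caught here. In the original code it surfaces later, when `_get_cluster()` raises `ValueError`.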
--------------------------------------------------------------------------------
/compare_gan/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # coding=utf-8
17 |
--------------------------------------------------------------------------------
/compare_gan/metrics/accuracy.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Discriminator accuracy.
17 |
18 | Computes the discriminator's accuracy on (a subset of) the training dataset,
19 | the test dataset, and a generated dataset. The score is averaged over several
20 | generated datasets and subsets of the training data.
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | from absl import logging
28 |
29 | from compare_gan import datasets
30 | from compare_gan import eval_utils
31 | from compare_gan.metrics import eval_task
32 |
33 | import numpy as np
34 |
35 |
36 | class AccuracyTask(eval_task.EvalTask):
37 | """Evaluation Task for computing and reporting accuracy."""
38 |
39 | def metric_list(self):
40 | return frozenset([
41 | "train_accuracy", "test_accuracy", "fake_accuracy", "train_d_loss",
42 | "test_d_loss"
43 | ])
44 |
45 | def run_in_session(self, options, sess, gan, real_images):
46 | del options
47 | return compute_accuracy_loss(sess, gan, real_images)
48 |
49 |
50 | def compute_accuracy_loss(sess,
51 | gan,
52 | test_images,
53 | max_train_examples=50000,
54 | num_repeat=5):
55 | """Compute discriminator's accuracy and loss on a given dataset.
56 |
57 | Args:
58 | sess: tf.Session object.
59 | gan: Any AbstractGAN instance.
60 | test_images: numpy array with test images.
61 | max_train_examples: How many "train" examples to get from the dataset.
62 | In each round, some of them will be randomly selected
63 | to evaluate train set accuracy.
64 | num_repeat: How many times to repeat the computation.
65 | The mean of all the results is reported.
66 | Returns:
67 | Dict[Text, float] with all the computed scores.
68 |
69 | Raises:
70 | ValueError: If the number of test_images is greater than the number of
71 | training images returned by the dataset.
72 | """
73 | logging.info("Evaluating training and test accuracy...")
74 | train_images = eval_utils.get_real_images(
75 | dataset=datasets.get_dataset(),
76 | num_examples=max_train_examples,
77 | split="train",
78 | failure_on_insufficient_examples=False)
79 | if train_images.shape[0] < test_images.shape[0]:
80 | raise ValueError("num_train %d must be at least num_test %d." %
81 | (train_images.shape[0], test_images.shape[0]))
82 |
83 | num_batches = int(np.floor(test_images.shape[0] / gan.batch_size))
84 | if num_batches * gan.batch_size < test_images.shape[0]:
85 | logging.error("Ignoring the last batch with %d samples / %d batch size.",
86 | test_images.shape[0] - num_batches * gan.batch_size,
87 | gan.batch_size)
88 |
89 | ret = {
90 | "train_accuracy": [],
91 | "test_accuracy": [],
92 | "fake_accuracy": [],
93 | "train_d_loss": [],
94 | "test_d_loss": []
95 | }
96 |
97 | for _ in range(num_repeat):
98 | idx = np.random.choice(train_images.shape[0], test_images.shape[0])
99 | bs = gan.batch_size
100 | train_subset = [train_images[i] for i in idx]
101 | train_predictions, test_predictions, fake_predictions = [], [], []
102 | train_d_losses, test_d_losses = [], []
103 |
104 | for i in range(num_batches):
105 | z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
106 | start_idx = i * bs
107 | end_idx = start_idx + bs
108 | test_batch = test_images[start_idx : end_idx]
109 | train_batch = train_subset[start_idx : end_idx]
110 |
111 | test_prediction, test_d_loss, fake_images = sess.run(
112 | [gan.discriminator_output, gan.d_loss, gan.fake_images],
113 | feed_dict={
114 | gan.inputs: test_batch, gan.z: z_sample
115 | })
116 | train_prediction, train_d_loss = sess.run(
117 | [gan.discriminator_output, gan.d_loss],
118 | feed_dict={
119 | gan.inputs: train_batch,
120 | gan.z: z_sample
121 | })
122 | fake_prediction = sess.run(
123 | gan.discriminator_output,
124 | feed_dict={gan.inputs: fake_images})[0]
125 |
126 | train_predictions.append(train_prediction[0])
127 | test_predictions.append(test_prediction[0])
128 | fake_predictions.append(fake_prediction)
129 | train_d_losses.append(train_d_loss)
130 | test_d_losses.append(test_d_loss)
131 |
132 | train_predictions = [x >= 0.5 for x in train_predictions]
133 | test_predictions = [x >= 0.5 for x in test_predictions]
134 | fake_predictions = [x < 0.5 for x in fake_predictions]
135 |
136 | ret["train_accuracy"].append(np.array(train_predictions).mean())
137 | ret["test_accuracy"].append(np.array(test_predictions).mean())
138 | ret["fake_accuracy"].append(np.array(fake_predictions).mean())
139 | ret["train_d_loss"].append(np.mean(train_d_losses))
140 | ret["test_d_loss"].append(np.mean(test_d_losses))
141 |
142 | for key in ret:
143 | ret[key] = np.mean(ret[key])
144 |
145 | return ret
146 |
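The scoring rule in `compute_accuracy_loss` reduces to thresholding discriminator scores and dropping any partial final batch. The NumPy sketch below uses hypothetical function names and assumes `gan.discriminator_output` holds probabilities that an input is real, which the 0.5 threshold suggests.

```python
import numpy as np


def discriminator_accuracies(real_probs, fake_probs):
  """Accuracy on real and fake inputs for scores interpreted as P(real).

  A real example counts as correct when its score is >= 0.5, a fake one
  when its score is < 0.5, as in compute_accuracy_loss.
  """
  real_correct = np.asarray(real_probs) >= 0.5
  fake_correct = np.asarray(fake_probs) < 0.5
  return real_correct.mean(), fake_correct.mean()


def num_full_batches(num_examples, batch_size):
  """Number of full batches; the remainder is skipped, as in the original."""
  return num_examples // batch_size


print(discriminator_accuracies([0.9, 0.4, 0.7, 0.5], [0.1, 0.6]))  # 0.75, 0.5
print(num_full_batches(1000, 64))  # 15, so the last 40 examples are ignored
```

Note that `np.random.choice(train_images.shape[0], test_images.shape[0])` in the original samples with replacement, so the training subset can contain duplicates.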
--------------------------------------------------------------------------------
/compare_gan/metrics/eval_task.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Abstract class that describes a single evaluation task.
17 |
18 | Tasks can run either inside the session or after it has been closed. Each task
19 | produces a set of metrics.
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import abc
27 |
28 | from absl import flags
29 | import six
30 | import tensorflow as tf
31 |
32 | FLAGS = flags.FLAGS
33 |
34 |
35 | @six.add_metaclass(abc.ABCMeta)
36 | class EvalTask(object):
37 | """Class that describes a single evaluation task.
38 |
39 | For example: compute inception score or compute accuracy.
40 | Subclasses should implement the methods below.
41 | """
42 |
43 | _LABEL = None
44 |
45 | def metric_list(self):
46 | """List of metrics that this class generates.
47 |
48 | These are the only keys that the run_in_session() and run_after_session()
49 | methods may return in their output dicts.
50 | Returns:
51 | frozenset of strings, which are the names of the metrics that task
52 | computes.
53 | """
54 | return frozenset([self._LABEL])
55 |
56 | def _create_session(self):
57 | try:
58 | target = FLAGS.master
59 | except AttributeError:
60 | return tf.Session()
61 | return tf.Session(target)
62 |
63 | @abc.abstractmethod
64 | def run_after_session(self, fake_dset, real_dset):
65 | """Runs the task after all the generator calls, after session was closed.
66 |
67 | WARNING: the images here are in the 0..255 range, with 3 color channels.
68 |
69 | Args:
70 | fake_dset: `EvalDataSample` with fake images and inception features.
71 | real_dset: `EvalDataSample` with real images and inception features.
72 |
73 | Returns:
74 | Dict with metric values. The keys must be contained in the set that
75 | the metric_list() method above returns.
76 | """
77 |
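The contract of `EvalTask` can be illustrated without TensorFlow. In the sketch below, `EvalTaskSketch`, `MeanPixelTask` and the `Sample` tuple are all hypothetical stand-ins. It follows the base class's shape: a single `_LABEL`, an abstract `run_after_session`, and a result dict whose keys must come from `metric_list()`. It also wraps the label in a list, because `frozenset("fid_score")` would produce a set of single characters rather than one metric name.

```python
import abc
import collections

import numpy as np


class EvalTaskSketch(abc.ABC):
  """TF-free stand-in for EvalTask (hypothetical)."""

  _LABEL = None

  def metric_list(self):
    # frozenset([label]), not frozenset(label), which splits the string.
    return frozenset([self._LABEL])

  @abc.abstractmethod
  def run_after_session(self, fake_dset, real_dset):
    """Returns a dict whose keys are a subset of metric_list()."""


class MeanPixelTask(EvalTaskSketch):
  """Hypothetical task: mean pixel value of the fake images (0..255)."""

  _LABEL = "mean_pixel"

  def run_after_session(self, fake_dset, real_dset):
    del real_dset  # Unused.
    return {self._LABEL: float(np.mean(fake_dset.images))}


Sample = collections.namedtuple("Sample", ["images"])
task = MeanPixelTask()
result = task.run_after_session(Sample(images=np.full((2, 4, 4, 3), 128.0)), None)
print(result)              # {'mean_pixel': 128.0}
print(task.metric_list())  # frozenset({'mean_pixel'})
```

Tasks that report several metrics, such as `AccuracyTask`, override `metric_list()` instead of relying on `_LABEL`.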
--------------------------------------------------------------------------------
/compare_gan/metrics/fid_score.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of the Frechet Inception Distance.
17 |
18 | Implemented as a wrapper around the tensorflow_gan library. The details can be
19 | found in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash
20 | Equilibrium", Heusel et al. [https://arxiv.org/abs/1706.08500].
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | from absl import logging
28 |
29 | from compare_gan.metrics import eval_task
30 |
31 | import tensorflow as tf
32 | import tensorflow_gan as tfgan
33 |
34 |
35 | # Special value returned when the FID computation raised an exception.
36 | FID_CODE_FAILED = 4242.0
37 |
38 |
39 | class FIDScoreTask(eval_task.EvalTask):
40 | """Evaluation task for the FID score."""
41 |
42 | _LABEL = "fid_score"
43 |
44 | def run_after_session(self, fake_dset, real_dset):
45 | logging.info("Calculating FID.")
46 | with tf.Graph().as_default():
47 | fake_activations = tf.convert_to_tensor(fake_dset.activations)
48 | real_activations = tf.convert_to_tensor(real_dset.activations)
49 | fid = tfgan.eval.frechet_classifier_distance_from_activations(
50 | real_activations=real_activations,
51 | generated_activations=fake_activations)
52 | with self._create_session() as sess:
53 | fid = sess.run(fid)
54 | logging.info("Frechet Inception Distance: %.3f.", fid)
55 | return {self._LABEL: fid}
56 |
57 |
58 | def compute_fid_from_activations(fake_activations, real_activations):
59 | """Returns the FID based on activations.
60 |
61 | Args:
62 | fake_activations: NumPy array with fake activations.
63 | real_activations: NumPy array with real activations.
64 | Returns:
65 | A float, the Frechet Inception Distance.
66 | """
67 | logging.info("Computing FID score.")
68 | assert fake_activations.shape == real_activations.shape
69 | with tf.Session(graph=tf.Graph()) as sess:
70 | fake_activations = tf.convert_to_tensor(fake_activations)
71 | real_activations = tf.convert_to_tensor(real_activations)
72 | fid = tfgan.eval.frechet_classifier_distance_from_activations(
73 | real_activations=real_activations,
74 | generated_activations=fake_activations)
75 | return sess.run(fid)
76 |
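For sanity-checking FID values outside TensorFlow, the closed form can be evaluated directly in NumPy. This is a sketch of the formula, not the tensorflow_gan implementation. It assumes unbiased (n-1) covariance estimates, which are what reproduce the 89.091 expected in fid_score_test.py. It also obtains Tr((C1 C2)^{1/2}) from the eigenvalues of C1 C2 rather than from an explicit matrix square root, which is a choice of this sketch.

```python
import numpy as np


def frechet_distance(act1, act2):
  """||mu1 - mu2||^2 + Tr(C1 + C2 - 2 (C1 C2)^{1/2}) for two activation sets.

  Rows are samples and columns are features.
  """
  act1 = np.asarray(act1, dtype=np.float64)
  act2 = np.asarray(act2, dtype=np.float64)
  mu1, mu2 = act1.mean(axis=0), act2.mean(axis=0)
  # np.cov with rowvar=False treats columns as variables; ddof=1 by default.
  c1 = np.atleast_2d(np.cov(act1, rowvar=False))
  c2 = np.atleast_2d(np.cov(act2, rowvar=False))
  # Eigenvalues of a product of PSD matrices are real and non-negative, so
  # round-off imaginary parts and tiny negatives are discarded.
  eigvals = np.linalg.eigvals(c1.dot(c2))
  trace_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
  mean_term = np.sum((mu1 - mu2) ** 2)
  return float(mean_term + np.trace(c1) + np.trace(c2) - 2.0 * trace_sqrt)
```

The distance is symmetric in its two arguments, which is why fid_score_test.py can pass the real data as the first (nominally "fake") argument without affecting the result.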
--------------------------------------------------------------------------------
/compare_gan/metrics/fid_score_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for the FID score."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan.metrics import fid_score as fid_score_lib
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 |
27 |
28 | class FIDScoreTest(tf.test.TestCase):
29 |
30 | def test_fid_computation(self):
31 | real_data = np.ones((100, 2))
32 | real_data[:50, 0] = 2
33 | gen_data = np.ones((100, 2)) * 9
34 | gen_data[50:, 0] = 2
35 | # mean(real_data) = [1.5, 1]
36 | # Cov(real_data) = [[ 0.2525, 0], [0, 0]]
37 | # mean(gen_data) = [5.5, 9]
38 | # Cov(gen_data) = [[12.37, 0], [0, 0]]
39 | result = fid_score_lib.compute_fid_from_activations(real_data, gen_data)
40 | self.assertNear(result, 89.091, 1e-4)
41 |
42 | if __name__ == "__main__":
43 | tf.test.main()
44 |
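The expected value 89.091 in this test follows from the closed form of the Frechet distance. The covariance entries 0.2525 and 12.37 in the comments are the unbiased estimates 25/99 and 1225/99:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)

\mu_r = (1.5,\ 1),\quad \mu_g = (5.5,\ 9),\quad
\Sigma_r = \operatorname{diag}\!\left(\tfrac{25}{99},\ 0\right),\quad
\Sigma_g = \operatorname{diag}\!\left(\tfrac{1225}{99},\ 0\right)

\lVert \mu_r - \mu_g \rVert_2^2 = 4^2 + 8^2 = 80,\qquad
(\Sigma_r \Sigma_g)^{1/2} = \operatorname{diag}\!\left(\tfrac{175}{99},\ 0\right)

\mathrm{FID} = 80 + \frac{25 + 1225 - 2 \cdot 175}{99} = 80 + \frac{900}{99} \approx 89.0909
```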
--------------------------------------------------------------------------------
/compare_gan/metrics/fractal_dimension.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of the fractal dimension metric."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan.metrics import eval_task
23 |
24 | import numpy as np
25 | import scipy.spatial
26 |
27 |
28 | class FractalDimensionTask(eval_task.EvalTask):
29 | """Fractal dimension metric."""
30 |
31 | _LABEL = "fractal_dimension"
32 |
33 | def run_after_session(self, fake_dset, real_dset=None):
34 | del real_dset  # Unused.
35 | score = compute_fractal_dimension(fake_dset.images)
36 | return {self._LABEL: score}
37 |
38 |
39 | def compute_fractal_dimension(fake_images,
40 | num_fd_seeds=100,
41 | n_bins=1000,
42 | scale=0.1):
43 | """Compute Fractal Dimension of fake_images.
44 |
45 | Args:
46 | fake_images: an np array of datapoints, the dimensionality and scaling of
47 | images can be arbitrary
48 | num_fd_seeds: number of random centers from which fractal dimension
49 | computation is performed
50 | n_bins: number of bins to split the range of distance values into
51 | scale: the scale of the y interval in the log-log plot for which we apply a
52 | linear regression fit
53 |
54 | Returns:
55 | fractal dimension of the dataset.
56 | """
57 | assert len(fake_images.shape) >= 2
58 | assert fake_images.shape[0] >= num_fd_seeds
59 |
60 | num_images = fake_images.shape[0]
61 | # scipy.spatial.distance.cdist expects 2-D input, so flatten each image
62 | # into a single vector.
63 | fake_images = np.reshape(fake_images, (num_images, -1))
64 | fake_images_subset = fake_images[np.random.randint(
65 | num_images, size=num_fd_seeds)]
66 |
67 | distances = scipy.spatial.distance.cdist(fake_images,
68 | fake_images_subset).flatten()
69 | min_distance = np.min(distances[np.nonzero(distances)])
70 | max_distance = np.max(distances)
71 | buckets = min_distance * (
72 | (max_distance / min_distance)**np.linspace(0, 1, n_bins))
73 | # Create a table where first column corresponds to distances r
74 | # and second column corresponds to number of points N(r) that lie
75 | # within distance r from the random seeds
76 | fd_result = np.zeros((n_bins - 1, 2))
77 | fd_result[:, 0] = buckets[1:]
78 | fd_result[:, 1] = np.sum(np.less.outer(distances, buckets[1:]), axis=0)
79 |
80 | # We compute the slope of the log-log plot at the middle y value
81 | # which is stored in y_val; the linear regression fit is computed on
82 | # the part of the plot that corresponds to an interval around y_val
83 | # whose size is 2*scale*(total width of the y axis)
84 | max_y = np.log(num_images * num_fd_seeds)
85 | min_y = np.log(num_fd_seeds)
86 | x = np.log(fd_result[:, 0])
87 | y = np.log(fd_result[:, 1])
88 | y_width = max_y - min_y
89 | y_val = min_y + 0.5 * y_width
90 |
91 | start = np.argmax(y > y_val - scale * y_width)
92 | end = np.argmax(y > y_val + scale * y_width)
93 |
94 | slope = np.linalg.lstsq(
95 | a=np.vstack([x[start:end], np.ones(end - start)]).transpose(),
96 | b=y[start:end].reshape(end - start, 1))[0][0][0]
97 | return slope
98 |
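As far as I can read the code, `compute_fractal_dimension` is a correlation-dimension-style estimate. It assumes the count of image-to-seed distances below r grows as a power of r, and it fits the exponent on the middle part of the log-log curve. Here n is `num_images`, S is `num_fd_seeds`, and s_j are the randomly chosen seed points:

```latex
N(r) = \#\{(i, j) : \lVert x_i - s_j \rVert < r\} \propto r^{D}
\;\Longrightarrow\;
\log N(r) = D \log r + c

y_{\min} = \log S,\qquad y_{\max} = \log(nS),\qquad
w = y_{\max} - y_{\min},\qquad \bar{y} = y_{\min} + \tfrac{1}{2} w

\hat{D} = \text{least-squares slope of } \log N(r) \text{ against } \log r
\text{ over } \bar{y} - \mathrm{scale}\cdot w < \log N(r) \le \bar{y} + \mathrm{scale}\cdot w
```

The lower bound is log S because each seed is also one of the images, so at least S distances are zero and counted for every r > 0.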
--------------------------------------------------------------------------------
/compare_gan/metrics/fractal_dimension_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for the fractal dimension metric."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan.metrics import fractal_dimension as fractal_dimension_lib
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 |
27 |
28 | class FractalDimensionTest(tf.test.TestCase):
29 |
30 | def test_straight_line(self):
31 | """The fractal dimension of a 1D line must lie near 1.0."""
32 | self.assertAllClose(
33 | fractal_dimension_lib.compute_fractal_dimension(
34 | np.random.uniform(size=(10000, 1))), 1.0, atol=0.05)
35 |
36 | def test_square(self):
37 | """The fractal dimension of a 2D square must lie near 2.0."""
38 | self.assertAllClose(
39 | fractal_dimension_lib.compute_fractal_dimension(
40 | np.random.uniform(size=(10000, 2))), 2.0, atol=0.1)
41 |
--------------------------------------------------------------------------------
/compare_gan/metrics/inception_score.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of the Inception Score.
17 |
18 | Implemented as a wrapper around the tensorflow_gan library. The details can be
19 | found in "Improved Techniques for Training GANs", Salimans et al.
20 | [https://arxiv.org/abs/1606.03498].
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | from absl import logging
28 |
29 | from compare_gan.metrics import eval_task
30 | import tensorflow as tf
31 | import tensorflow_gan as tfgan
32 |
33 |
34 | class InceptionScoreTask(eval_task.EvalTask):
35 | """Task that computes inception score for the generated images."""
36 |
37 | _LABEL = "inception_score"
38 |
39 | def run_after_session(self, fake_dset, real_dset):
40 | del real_dset
41 | logging.info("Computing inception score.")
42 | with tf.Graph().as_default():
43 | fake_logits = tf.convert_to_tensor(fake_dset.logits)
44 | inception_score = tfgan.eval.classifier_score_from_logits(fake_logits)
45 | with self._create_session() as sess:
46 | inception_score = sess.run(inception_score)
47 | logging.info("Inception score: %.3f", inception_score)
48 | return {self._LABEL: inception_score}
49 |
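The quantity that `tfgan.eval.classifier_score_from_logits` returns is, as I understand it, exp(E_x[KL(p(y|x) || p(y))]) over all the logits it receives, with no splitting into groups. The NumPy sketch below (hypothetical function name) evaluates that formula so that edge cases can be reasoned about. It is not the library code.

```python
import numpy as np


def inception_score_from_logits(logits):
  """exp(mean_x KL(p(y|x) || p(y))) for a [num_images, num_classes] array."""
  logits = np.asarray(logits, dtype=np.float64)
  # Numerically stable log-softmax: subtract the row max before exponentiating.
  shifted = logits - logits.max(axis=1, keepdims=True)
  log_p_yx = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
  p_yx = np.exp(log_p_yx)
  # Marginal label distribution p(y), averaged over the images.
  log_p_y = np.log(p_yx.mean(axis=0))
  kl = np.sum(p_yx * (log_p_yx - log_p_y), axis=1)
  return float(np.exp(kl.mean()))


print(inception_score_from_logits(np.zeros((5, 10))))  # about 1.0
print(inception_score_from_logits(100.0 * np.eye(10)))  # about 10.0
```

The two extremes bracket the score. Identical predictions for every image give 1, and confident predictions spread evenly over k classes approach k.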
--------------------------------------------------------------------------------
/compare_gan/metrics/jacobian_conditioning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of the Jacobian Conditioning metrics.
17 |
18 | The details can be found in "Is Generator Conditioning Causally Related to
19 | GAN Performance?", Odena et al. [https://arxiv.org/abs/1802.08768].
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | from compare_gan.metrics import eval_task
27 |
28 | import numpy as np
29 | import tensorflow as tf
30 |
31 |
32 | class GeneratorConditionNumberTask(eval_task.EvalTask):
33 | """Computes the generator condition number.
34 |
35 | Computes the condition number of the metric tensor of the generator Jacobian.
36 | This condition number is computed locally for each z sample in a minibatch.
37 | Returns the mean log condition number and standard deviation across the
38 | minibatch.
39 |
40 | Follows the methods in https://arxiv.org/abs/1802.08768.
41 | """
42 |
43 | _CONDITION_NUMBER_COUNT = "log_condition_number_count"
44 | _CONDITION_NUMBER_MEAN = "log_condition_number_mean"
45 | _CONDITION_NUMBER_STD = "log_condition_number_std"
46 |
47 | def metric_list(self):
48 | return frozenset([
49 | self._CONDITION_NUMBER_COUNT, self._CONDITION_NUMBER_MEAN,
50 | self._CONDITION_NUMBER_STD
51 | ])
52 |
53 | def run_in_session(self, options, sess, gan, real_images):
54 | del options, real_images
55 | result_dict = {}
56 | result = compute_generator_condition_number(sess, gan)
57 | result_dict[self._CONDITION_NUMBER_COUNT] = len(result)
58 | result_dict[self._CONDITION_NUMBER_MEAN] = np.mean(result)
59 | result_dict[self._CONDITION_NUMBER_STD] = np.std(result)
60 | return result_dict
61 |
62 |
63 | def compute_generator_condition_number(sess, gan):
64 | """Computes the generator condition number.
65 |
66 | Computes the Jacobian of the generator in session, then postprocesses to get
67 | the condition number.
68 |
69 | Args:
70 | sess: tf.Session object.
71 | gan: AbstractGAN object, that is already present in the current tf.Graph.
72 |
73 | Returns:
74 | A numpy array of length gan.batch_size. Each element is the log condition
75 | number of the metric tensor computed at a single z sample within a minibatch.
76 | """
77 | shape = gan.fake_images.get_shape().as_list()
78 | flat_generator_output = tf.reshape(
79 | gan.fake_images, [gan.batch_size, np.prod(shape[1:])])
80 | tf_jacobian = compute_jacobian(
81 | xs=gan.z, fx=flat_generator_output)
82 | z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
83 | np_jacobian = sess.run(tf_jacobian, feed_dict={gan.z: z_sample})
84 | result_dict = analyze_jacobian(np_jacobian)
85 | return result_dict["metric_tensor"]["log_condition_number"]
86 |
87 |
88 | def compute_jacobian(xs, fx):
89 | """Computes df/dx matrix.
90 |
91 | We assume x and fx are both batched, so the shape of the Jacobian is:
92 | [fx.shape[0]] + fx.shape[1:] + xs.shape[1:]
93 |
94 | This function computes the gradients inside a TF while loop so that the
95 | graph does not contain a separate copy of the gradient computation for each
96 | output dimension of the function we are taking the Jacobian of.
97 |
98 | Args:
99 | xs: input tensor of shape [batch, x_dim].
100 | fx: f(x) tensor of shape [batch, fx_dim], where fx_dim is statically known.
101 |
102 | Returns:
103 | df/dx tensor of shape [fx.shape[0], fx.shape[1], xs.shape[1]].
104 | """
105 | # Declares an iterator and tensor array loop variables for the gradients.
106 | n = fx.get_shape().as_list()[1]
107 | loop_vars = [tf.constant(0, tf.int32), tf.TensorArray(xs.dtype, n)]
108 |
109 | def accumulator(j, result):
110 | return (j + 1, result.write(j, tf.gradients(fx[:, j], xs)[0]))
111 |
112 | # Iterates over all elements of the gradient and computes all partial
113 | # derivatives.
114 | _, df_dxs = tf.while_loop(lambda j, _: j < n, accumulator, loop_vars)
115 |
116 | df_dx = df_dxs.stack()
117 | df_dx = tf.transpose(df_dx, perm=[1, 0, 2])
118 |
119 | return df_dx
120 |
121 |
122 | def _analyze_metric_tensor(metric_tensor):
123 | """Analyzes a metric tensor.
124 |
125 | Args:
126 | metric_tensor: A numpy array of shape [batch, dim, dim]
127 |
128 | Returns:
129 | A dict containing spectral statistics.
130 | """
131 | # eigenvalues will have shape [batch, dim].
132 | eigenvalues, _ = np.linalg.eig(metric_tensor)
133 |
134 | # Shape [batch,].
135 | condition_number = np.linalg.cond(metric_tensor)
136 | log_condition_number = np.log(condition_number)
137 | (_, logdet) = np.linalg.slogdet(metric_tensor)
138 |
139 | return {
140 | "eigenvalues": eigenvalues,
141 | "logdet": logdet,
142 | "log_condition_number": log_condition_number
143 | }
144 |
145 |
146 | def analyze_jacobian(jacobian_array):
147 | """Computes eigenvalue statistics of the Jacobian.
148 |
149 | Computes the eigenvalues and condition number of the metric tensor for the
150 | Jacobian evaluated at each element of the batch and the mean metric tensor
151 | across the batch.
152 |
153 | Args:
154 | jacobian_array: A numpy array of shape [batch, fx_dim, x_dim] holding the Jacobian.
155 |
156 | Returns:
157 | A dict of spectral statistics with two elements, one containing stats
158 | for every metric tensor in the batch, another for the mean metric tensor.
159 | """
160 | # Shape [batch, x_dim, fx_dim].
161 | jacobian_transpose = np.transpose(jacobian_array, [0, 2, 1])
162 |
163 | # Shape [batch, x_dim, x_dim].
164 | metric_tensor = np.matmul(jacobian_transpose, jacobian_array)
165 |
166 | mean_metric_tensor = np.mean(metric_tensor, 0)
167 | # Reshapes to have a dummy batch dimension.
168 | mean_metric_tensor = np.reshape(mean_metric_tensor,
169 | (1,) + metric_tensor.shape[1:])
170 | return {
171 | "metric_tensor": _analyze_metric_tensor(metric_tensor),
172 | "mean_metric_tensor": _analyze_metric_tensor(mean_metric_tensor)
173 | }
174 |
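The metric-tensor analysis in `analyze_jacobian` can be reproduced in plain NumPy. The sketch below uses a hypothetical helper, `log_condition_numbers` (not part of compare_gan), and assumes a Jacobian of shape [batch, fx_dim, x_dim], as returned by `compute_jacobian`. Because the metric tensor J^T J is symmetric positive semi-definite, the sketch uses `np.linalg.eigvalsh` instead of the `np.linalg.eig`/`np.linalg.cond` pair used in `_analyze_metric_tensor`. For a full-rank J, the ratio of the largest to the smallest eigenvalue is the 2-norm condition number of J^T J, which is the square of the condition number of J itself.

```python
import numpy as np


def log_condition_numbers(jacobian):
    """Returns log cond(J^T J) for each Jacobian J in a batch.

    Hypothetical helper; `jacobian` has shape [batch, fx_dim, x_dim] and each
    J is assumed to have full column rank (so no eigenvalue is zero).
    """
    # Metric tensor J^T J, shape [batch, x_dim, x_dim].
    metric_tensor = np.matmul(np.transpose(jacobian, [0, 2, 1]), jacobian)
    # eigvalsh assumes a symmetric matrix and returns eigenvalues in ascending order.
    eigenvalues = np.linalg.eigvalsh(metric_tensor)
    return np.log(eigenvalues[:, -1] / eigenvalues[:, 0])


# A generator that stretches one latent direction 10x more than the other.
jacobian = np.array([[[1.0, 0.0], [0.0, 10.0]]])
print(log_condition_numbers(jacobian))  # log(100), about 4.605
```

Working on the metric tensor rather than on J directly keeps the matrix square ([x_dim, x_dim]) regardless of the image size fx_dim, which matches how the module reports its statistics.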
--------------------------------------------------------------------------------
/compare_gan/metrics/jacobian_conditioning_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for Jacobian Conditioning metrics."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan.metrics import jacobian_conditioning
23 | import mock
24 | import numpy as np
25 | from six.moves import range
26 | import tensorflow as tf
27 |
28 |
29 | _BATCH_SIZE = 32
30 |
31 |
32 | def SlowJacobian(xs, fx):
33 | """Computes df/dx matrix.
34 |
35 | As jacobian_conditioning.compute_jacobian, but explicitly loops over
36 | dimensions of f.
37 |
38 | Args:
39 | xs: input tensor(s) of arbitrary shape.
40 | fx: f(x) tensor of arbitrary shape.
41 |
42 | Returns:
43 | df/dx tensor.
44 | """
45 | fxs = tf.unstack(fx, axis=-1)
46 | grads = [tf.gradients(fx_i, xs) for fx_i in fxs]
47 | grads = [grad[0] for grad in grads]
48 | df_dx = tf.stack(grads, axis=1)
49 | return df_dx
50 |
51 |
52 | class JacobianConditioningTest(tf.test.TestCase):
53 |
54 | def test_jacobian_simple_case(self):
55 | x = tf.random_normal([_BATCH_SIZE, 2])
56 | W = tf.constant([[2., -1.], [1.5, 1.]]) # pylint: disable=invalid-name
57 | f = tf.matmul(x, W)
58 | j_tensor = jacobian_conditioning.compute_jacobian(xs=x, fx=f)
59 | with tf.Session() as sess:
60 | jacobian = sess.run(j_tensor)
61 |
62 | # 'expected' holds the transpose of W because, in vector notation,
63 | # f = W^T * x.
64 | expected = np.tile([[[2, 1.5], [-1, 1]]], [_BATCH_SIZE, 1, 1])
65 | self.assertAllClose(jacobian, expected)
66 |
67 | def test_jacobian_against_slow_version(self):
68 | x = tf.random_normal([_BATCH_SIZE, 2])
69 | h1 = tf.contrib.layers.fully_connected(x, 20)
70 | h2 = tf.contrib.layers.fully_connected(h1, 20)
71 | f = tf.contrib.layers.fully_connected(h2, 10)
72 |
73 | j_slow_tensor = SlowJacobian(xs=x, fx=f)
74 | j_fast_tensor = jacobian_conditioning.compute_jacobian(xs=x, fx=f)
75 |
76 | with tf.Session() as sess:
77 | sess.run(tf.global_variables_initializer())
78 | j_fast, j_slow = sess.run([j_fast_tensor, j_slow_tensor])
79 | self.assertAllClose(j_fast, j_slow)
80 |
81 | def test_jacobian_numerically(self):
82 | x = tf.random_normal([_BATCH_SIZE, 2])
83 | h1 = tf.contrib.layers.fully_connected(x, 20)
84 | h2 = tf.contrib.layers.fully_connected(h1, 20)
85 | f = tf.contrib.layers.fully_connected(h2, 10)
86 | j_tensor = jacobian_conditioning.compute_jacobian(xs=x, fx=f)
87 |
88 | with tf.Session() as sess:
89 | sess.run(tf.global_variables_initializer())
90 | x_np = sess.run(x)
91 | jacobian = sess.run(j_tensor, feed_dict={x: x_np})
92 |
93 | # Test 10 random elements.
94 | for _ in range(10):
95 | # Pick a random element of Jacobian to test.
96 | batch_idx = np.random.randint(_BATCH_SIZE)
97 | x_idx = np.random.randint(2)
98 | f_idx = np.random.randint(10)
99 |
100 | # Test with finite differences.
101 | epsilon = 1e-4
102 |
103 | x_plus = x_np.copy()
104 | x_plus[batch_idx, x_idx] += epsilon
105 | f_plus = sess.run(f, feed_dict={x: x_plus})[batch_idx, f_idx]
106 |
107 | x_minus = x_np.copy()
108 | x_minus[batch_idx, x_idx] -= epsilon
109 | f_minus = sess.run(f, feed_dict={x: x_minus})[batch_idx, f_idx]
110 |
111 | self.assertAllClose(
112 | jacobian[batch_idx, f_idx, x_idx],
113 | (f_plus - f_minus) / (2. * epsilon),
114 | rtol=1e-3,
115 | atol=1e-3)
116 |
117 | def test_analyze_metric_tensor(self):
118 | # Assumes NumPy works, just tests that output shapes are as expected.
119 | jacobian = np.random.normal(0, 1, (_BATCH_SIZE, 2, 10))
120 | metric_tensor = np.matmul(np.transpose(jacobian, [0, 2, 1]), jacobian)
121 | result_dict = jacobian_conditioning._analyze_metric_tensor(metric_tensor)
122 | self.assertAllEqual(result_dict['eigenvalues'].shape, [_BATCH_SIZE, 10])
123 | self.assertAllEqual(result_dict['logdet'].shape, [_BATCH_SIZE])
124 | self.assertAllEqual(result_dict['log_condition_number'].shape,
125 | [_BATCH_SIZE])
126 |
127 | def test_analyze_jacobian(self):
128 | m = mock.patch.object(
129 | jacobian_conditioning, '_analyze_metric_tensor', new=lambda x: x)
130 | m.start()
131 | jacobian = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
132 | result_dict = jacobian_conditioning.analyze_jacobian(jacobian)
133 | self.assertAllEqual(result_dict['metric_tensor'],
134 | [[[10, 14], [14, 20]], [[40, 56], [56, 80]]])
135 | self.assertAllEqual(result_dict['mean_metric_tensor'],
136 | [[[25, 35], [35, 50]]])
137 | m.stop()
138 |
139 |
140 | if __name__ == '__main__':
141 | tf.test.main()
142 |
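`test_jacobian_numerically` checks individual Jacobian entries with central differences. The same check can be applied to the whole Jacobian at once. `finite_difference_jacobian` below is a hypothetical NumPy helper. It perturbs every batch row simultaneously, which is only valid under the assumption that `f` processes rows independently, as the fully connected layers in the test do.

```python
import numpy as np


def finite_difference_jacobian(f, x, epsilon=1e-4):
    """Central-difference Jacobian of a row-wise function f.

    Hypothetical helper: x has shape [batch, x_dim] and f(x) has shape
    [batch, fx_dim]; the result has shape [batch, fx_dim, x_dim].
    """
    batch, x_dim = x.shape
    fx_dim = f(x).shape[1]
    jacobian = np.zeros((batch, fx_dim, x_dim))
    for i in range(x_dim):
        step = np.zeros_like(x)
        step[:, i] = epsilon  # Perturb input coordinate i in every row at once.
        # (f(x + h e_i) - f(x - h e_i)) / (2 h) approximates column i of df/dx.
        jacobian[:, :, i] = (f(x + step) - f(x - step)) / (2.0 * epsilon)
    return jacobian


# Same linear map as test_jacobian_simple_case: f(x) = x W, so df/dx = W^T.
W = np.array([[2.0, -1.0], [1.5, 1.0]])
x = np.random.RandomState(0).normal(size=(4, 2))
print(finite_difference_jacobian(lambda v: v @ W, x)[0])  # approximately W^T
```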
--------------------------------------------------------------------------------
/compare_gan/metrics/kid_score.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of the KID score.
17 |
18 | The details can be found in "Demystifying MMD GANs", Binkowski et al.
19 | [https://arxiv.org/abs/1801.01401].
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import math
27 |
28 | from compare_gan.metrics import eval_task
29 |
30 | import numpy as np
31 | import tensorflow as tf
32 |
33 |
34 | class KIDScoreTask(eval_task.EvalTask):
35 | """Evaluation task for the KID score."""
36 |
37 | _LABEL = "kid_score"
38 |
39 | def run_after_session(self, fake_dset, real_dset):
40 | score = kid(fake_dset.activations, real_dset.activations)
41 | return {self._LABEL: score}
42 |
43 |
44 | def kid(fake_activations,
45 | real_activations,
46 | max_batch_size=1024,
47 | dtype=None,
48 | return_stderr=False):
49 | """Unbiased estimator of the Kernel Inception Distance.
50 |
51 | As defined by https://arxiv.org/abs/1801.01401.
52 |
53 | If return_stderr, also returns an estimate of the standard error, i.e. the
54 | standard deviation of the KID estimator. The standard error is nan if the
55 | number of blocks is too small (< 5); for more reliable estimates, one could
56 | use the asymptotic variance estimate given in https://arxiv.org/abs/1611.04488.
57 |
58 | Uses a block estimator, as in https://arxiv.org/abs/1307.1954, with blocks
59 | no larger than max_batch_size. This differs slightly from the authors'
60 | provided code, but is also unbiased (and provides a more valid variance
61 | estimate).
62 |
63 | NOTE: the blocking code assumes that real_activations and
64 | fake_activations are in random order. If real_activations is sorted
65 | in a meaningful order, the estimator will be biased.
66 |
67 | Args:
68 | fake_activations: [batch, num_features] tensor with inception features.
69 | real_activations: [batch, num_features] tensor with inception features.
70 | max_batch_size: Maximum number of samples per block used to compute the KID.
71 | dtype: Type used by the computations; defaults to the dtype of the inputs.
72 | return_stderr: If true, also returns the std_error from the KID computation.
73 |
74 | Returns:
75 | KID score (and optionally std error).
76 | """
77 | real_activations.get_shape().assert_has_rank(2)
78 | fake_activations.get_shape().assert_has_rank(2)
79 |
80 | # need to know dimension for the kernel, and batch size to split things
81 | real_activations.get_shape().assert_is_fully_defined()
82 | fake_activations.get_shape().assert_is_fully_defined()
83 |
84 | n_real, dim = real_activations.get_shape().as_list()
85 | n_gen, dim2 = fake_activations.get_shape().as_list()
86 | assert dim2 == dim
87 |
88 | # tensorflow_gan forces doubles for FID, but I don't think we need that here
89 | if dtype is None:
90 | dtype = real_activations.dtype
91 | assert fake_activations.dtype == dtype
92 | else:
93 | real_activations = tf.cast(real_activations, dtype)
94 | fake_activations = tf.cast(fake_activations, dtype)
95 |
96 | # split into largest approximately-equally-sized blocks
97 | n_bins = int(math.ceil(max(n_real, n_gen) / max_batch_size))
98 | bins_r = np.full(n_bins, int(math.ceil(n_real / n_bins)))
99 | bins_g = np.full(n_bins, int(math.ceil(n_gen / n_bins)))
100 | bins_r[:(n_bins * bins_r[0]) - n_real] -= 1
101 | bins_g[:(n_bins * bins_g[0]) - n_gen] -= 1
102 | assert bins_r.min() >= 2
103 | assert bins_g.min() >= 2
104 |
105 | inds_r = tf.constant(np.r_[0, np.cumsum(bins_r)])
106 | inds_g = tf.constant(np.r_[0, np.cumsum(bins_g)])
107 |
108 | dim_ = tf.cast(dim, dtype)
109 |
110 | def get_kid_batch(i):
111 | """Computes KID on a given batch of features.
112 |
113 | Takes real_activations[inds_r[i] : inds_r[i+1]] and
114 | fake_activations[inds_g[i] : inds_g[i+1]].
115 |
116 | Args:
117 | i: is the index of the batch.
118 |
119 | Returns:
120 | KID for the given batch.
121 | """
122 | r_s = inds_r[i]
123 | r_e = inds_r[i + 1]
124 | r = real_activations[r_s:r_e]
125 | m = tf.cast(r_e - r_s, dtype)
126 |
127 | g_s = inds_g[i]
128 | g_e = inds_g[i + 1]
129 | g = fake_activations[g_s:g_e]
130 | n = tf.cast(g_e - g_s, dtype)
131 |
132 | # Could probably do this a bit faster...
133 | k_rr = (tf.matmul(r, r, transpose_b=True) / dim_ + 1)**3
134 | k_rg = (tf.matmul(r, g, transpose_b=True) / dim_ + 1)**3
135 | k_gg = (tf.matmul(g, g, transpose_b=True) / dim_ + 1)**3
136 | return (
137 | -2 * tf.reduce_mean(k_rg) + (tf.reduce_sum(k_rr) - tf.trace(k_rr)) /
138 | (m * (m - 1)) + (tf.reduce_sum(k_gg) - tf.trace(k_gg)) / (n * (n - 1)))
139 |
140 | ests = tf.map_fn(
141 | get_kid_batch, np.arange(n_bins), dtype=dtype, back_prop=False)
142 |
143 | if return_stderr:
144 | if n_bins < 5:
145 | return tf.reduce_mean(ests), np.nan
146 | mn, var = tf.nn.moments(ests, [0])
147 | return mn, tf.sqrt(var / n_bins)
148 | else:
149 | return tf.reduce_mean(ests)
150 |
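The block estimator in `kid` can be sketched in NumPy to make the arithmetic explicit. This is a minimal sketch under stated assumptions, not the library's implementation: the helper names (`block_sizes`, `polynomial_kernel`, `kid_block`, `kid_numpy`) are hypothetical, the features are plain arrays rather than TF tensors, and each block's fake-sample count is taken from the fake block itself. Every block contributes the unbiased MMD^2 estimate with the cubic polynomial kernel k(a, b) = (a.b / d + 1)^3, and the KID is the mean over blocks.

```python
import math

import numpy as np


def block_sizes(n, n_bins):
    """Splits n samples into n_bins blocks whose sizes differ by at most one."""
    sizes = np.full(n_bins, int(math.ceil(n / n_bins)))
    sizes[:n_bins * sizes[0] - n] -= 1  # Shrink the first few blocks by one.
    return sizes


def polynomial_kernel(a, b):
    """Cubic polynomial kernel k(a, b) = (a.b / d + 1)^3 on row vectors."""
    d = a.shape[1]
    return (a @ b.T / d + 1.0) ** 3


def kid_block(real, fake):
    """Unbiased MMD^2 estimate on one block of real and fake features."""
    m, n = real.shape[0], fake.shape[0]
    k_rr = polynomial_kernel(real, real)
    k_gg = polynomial_kernel(fake, fake)
    k_rg = polynomial_kernel(real, fake)
    # Diagonal terms k(x, x) are excluded from the within-set averages.
    return ((k_rr.sum() - np.trace(k_rr)) / (m * (m - 1))
            + (k_gg.sum() - np.trace(k_gg)) / (n * (n - 1))
            - 2.0 * k_rg.mean())


def kid_numpy(fake, real, max_block_size=1024):
    """Mean of per-block estimates, with blocks no larger than max_block_size."""
    n_bins = int(math.ceil(max(len(real), len(fake)) / max_block_size))
    bounds_r = np.r_[0, np.cumsum(block_sizes(len(real), n_bins))]
    bounds_g = np.r_[0, np.cumsum(block_sizes(len(fake), n_bins))]
    estimates = [kid_block(real[bounds_r[i]:bounds_r[i + 1]],
                           fake[bounds_g[i]:bounds_g[i + 1]])
                 for i in range(n_bins)]
    return float(np.mean(estimates))


rng = np.random.RandomState(0)
real = rng.normal(size=(200, 8))
fake_close = rng.normal(size=(200, 8))       # Same distribution as real.
fake_far = rng.normal(size=(200, 8)) + 3.0   # Shifted distribution.
print(kid_numpy(fake_close, real), kid_numpy(fake_far, real, max_block_size=50))
```

Because the estimator is unbiased, a score for two samples from the same distribution can come out slightly negative.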
--------------------------------------------------------------------------------
/compare_gan/metrics/ms_ssim_score.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Implementation of the MS-SSIM metric.
17 |
18 | The details on the application of this metric to GANs can be found in
19 | Section 5.3 of "Many Paths to Equilibrium: GANs Do Not Need to Decrease a
20 | Divergence At Every Step", Fedus*, Rosca* et al.
21 | [https://arxiv.org/abs/1710.08446].
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 | from absl import logging
29 |
30 | from compare_gan.metrics import eval_task
31 | from compare_gan.metrics import image_similarity
32 |
33 | import numpy as np
34 | from six.moves import range
35 | import tensorflow as tf
36 |
37 |
38 | class MultiscaleSSIMTask(eval_task.EvalTask):
39 | """Task that computes the MS-SSIM score for generated images."""
40 |
41 | _LABEL = "ms_ssim"
42 |
43 | def run_after_session(self, options, eval_data_fake, eval_data_real=None):
44 | del options, eval_data_real
45 | score = _compute_multiscale_ssim_score(eval_data_fake.images)
46 | return {self._LABEL: score}
47 |
48 |
49 | def _compute_multiscale_ssim_score(fake_images):
50 | """Computes the MS-SSIM score."""
51 | batch_size = 64
52 | with tf.Graph().as_default():
53 | fake_images_batch = tf.train.shuffle_batch(
54 | [tf.convert_to_tensor(fake_images, dtype=tf.float32)],
55 | capacity=16*batch_size,
56 | min_after_dequeue=8*batch_size,
57 | num_threads=4,
58 | enqueue_many=True,
59 | batch_size=batch_size)
60 |
61 | # Following section 5.3 of https://arxiv.org/pdf/1710.08446.pdf, we only
62 | # evaluate 5 batches of the generated images.
63 | eval_fn = compute_msssim(
64 | generated_images=fake_images_batch, num_batches=5)
65 | with tf.train.MonitoredTrainingSession() as sess:
66 | score = eval_fn(sess)
67 | return score
68 |
69 |
70 | def compute_msssim(generated_images, num_batches):
71 | """Returns a function that computes the MS-SSIM score for generated images.
72 |
73 | Args:
74 | generated_images: TF Tensor of shape [batch_size, dim, dim, 3] which
75 | evaluates to a batch of generated images. Should be in range [0..255].
76 | num_batches: Number of batches to consider.
77 |
78 | Returns:
79 | eval_fn: a function which takes a session as an argument and returns the
80 | average ms ssim score among all the possible image pairs from
81 | generated_images.
82 | """
83 | batch_size = int(generated_images.get_shape()[0])
84 | assert batch_size > 1
85 |
86 | # Generate all possible image pairs from the input set of images.
87 | pair1 = tf.tile(generated_images, [batch_size, 1, 1, 1])
88 | pair2 = tf.reshape(
89 | tf.tile(generated_images, [1, batch_size, 1, 1]), [
90 | batch_size * batch_size, generated_images.shape[1],
91 | generated_images.shape[2], generated_images.shape[3]
92 | ])
93 |
94 | # Compute the mean of the scores, excluding the batch_size pairs that compare
95 | # an image with itself (each of which gets 1.0 from the MultiscaleSSIM).
96 | score = tf.reduce_sum(image_similarity.multiscale_ssim(pair1, pair2))
97 | score -= batch_size
98 | score = tf.div(score, batch_size * batch_size - batch_size)
99 |
100 | # Define a function which wraps some session.run calls to generate a large
101 | # number of images and compute multiscale ssim metric on them.
102 | def _eval_fn(session):
103 | """Function which wraps session.run calls to compute given metric."""
104 | logging.info("Computing MS-SSIM score...")
105 | scores = []
106 | for _ in range(num_batches):
107 | scores.append(session.run(score))
108 |
109 | result = np.mean(scores)
110 | return result
111 | return _eval_fn
112 |
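The pairing trick in `compute_msssim` enumerates all batch_size^2 ordered pairs: tiling along the batch axis gives pair1[k] = images[k % B], while tiling along axis 1 and reshaping gives pair2[k] = images[k // B]. The NumPy sketch below mirrors that logic. `toy_similarity` is a hypothetical stand-in for `image_similarity.multiscale_ssim`; the only property the subtraction step relies on is that an image compared with itself scores exactly 1.0.

```python
import numpy as np


def toy_similarity(a, b):
    """Hypothetical stand-in for MS-SSIM: 1.0 for identical images, else < 1."""
    return 1.0 / (1.0 + np.mean((a - b) ** 2, axis=(1, 2, 3)))


def mean_pairwise_score(images, similarity=toy_similarity):
    """Mean similarity over all ordered pairs of distinct images."""
    b = images.shape[0]
    # pair1[k] = images[k % b], like tf.tile(images, [b, 1, 1, 1]).
    pair1 = np.tile(images, (b, 1, 1, 1))
    # pair2[k] = images[k // b], like tiling along axis 1 and then reshaping.
    pair2 = np.tile(images, (1, b, 1, 1)).reshape((b * b,) + images.shape[1:])
    scores = similarity(pair1, pair2)
    # Drop the b self-pairs, which each score exactly 1.0.
    return (scores.sum() - b) / (b * b - b)


# The same three constant images as ms_ssim_score_test, at a smaller size.
imgs = np.stack([np.ones((4, 4, 3)), 0.7 * np.ones((4, 4, 3)),
                 np.zeros((4, 4, 3))])
print(mean_pairwise_score(imgs))
```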
--------------------------------------------------------------------------------
/compare_gan/metrics/ms_ssim_score_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for the MS-SSIM score."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from compare_gan.metrics import ms_ssim_score
23 | import tensorflow as tf
24 |
25 |
26 | class MsSsimScoreTest(tf.test.TestCase):
27 |
28 | def test_on_one_vs_07_vs_zero_images(self):
29 | """Computes the MS-SSIM value for 3 simple images."""
30 | with tf.Graph().as_default():
31 | generated_images = tf.stack([
32 | tf.ones([64, 64, 3]),
33 | tf.ones([64, 64, 3]) * 0.7,
34 | tf.zeros([64, 64, 3]),
35 | ])
36 | metric = ms_ssim_score.compute_msssim(generated_images, 1)
37 | with tf.Session() as sess:
38 | sess.run(tf.global_variables_initializer())
39 | result = metric(sess)
40 | self.assertNear(result, 0.989989, 0.001)
41 |
42 |
43 | if __name__ == '__main__':
44 | tf.test.main()
45 |
--------------------------------------------------------------------------------
/compare_gan/metrics/prd_score_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Testing precision and recall computation on synthetic data."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import unittest
23 | from compare_gan.metrics import prd_score as prd
24 | import numpy as np
25 |
26 |
27 | class PRDTest(unittest.TestCase):
28 |
29 | def test_compute_prd_no_overlap(self):
30 | eval_dist = [0, 1]
31 | ref_dist = [1, 0]
32 | result = np.ravel(prd.compute_prd(eval_dist, ref_dist))
33 | np.testing.assert_almost_equal(result, 0)
34 |
35 | def test_compute_prd_perfect_overlap(self):
36 | eval_dist = [1, 0]
37 | ref_dist = [1, 0]
38 | result = prd.compute_prd(eval_dist, ref_dist, num_angles=11)
39 | np.testing.assert_almost_equal([result[0][5], result[1][5]], [1, 1])
40 |
41 | def test_compute_prd_low_precision_high_recall(self):
42 | eval_dist = [0.5, 0.5]
43 | ref_dist = [1, 0]
44 | result = prd.compute_prd(eval_dist, ref_dist, num_angles=11)
45 | np.testing.assert_almost_equal(result[0][5], 0.5)
46 | np.testing.assert_almost_equal(result[1][5], 0.5)
47 | np.testing.assert_almost_equal(result[0][10], 0.5)
48 | np.testing.assert_almost_equal(result[1][1], 1)
49 |
50 | def test_compute_prd_high_precision_low_recall(self):
51 | eval_dist = [1, 0]
52 | ref_dist = [0.5, 0.5]
53 | result = prd.compute_prd(eval_dist, ref_dist, num_angles=11)
54 | np.testing.assert_almost_equal([result[0][5], result[1][5]], [0.5, 0.5])
55 | np.testing.assert_almost_equal(result[1][1], 0.5)
56 | np.testing.assert_almost_equal(result[0][10], 1)
57 |
58 | def test_compute_prd_bad_epsilon(self):
59 | with self.assertRaises(ValueError):
60 | prd.compute_prd([1], [1], epsilon=0)
61 | with self.assertRaises(ValueError):
62 | prd.compute_prd([1], [1], epsilon=1)
63 | with self.assertRaises(ValueError):
64 | prd.compute_prd([1], [1], epsilon=-1)
65 |
66 | def test_compute_prd_bad_num_angles(self):
67 | with self.assertRaises(ValueError):
68 | prd.compute_prd([1], [1], num_angles=0)
69 | with self.assertRaises(ValueError):
70 | prd.compute_prd([1], [1], num_angles=1)
71 | with self.assertRaises(ValueError):
72 | prd.compute_prd([1], [1], num_angles=-1)
73 | with self.assertRaises(ValueError):
74 | prd.compute_prd([1], [1], num_angles=1e6+1)
75 | with self.assertRaises(ValueError):
76 | prd.compute_prd([1], [1], num_angles=2.5)
77 |
78 | def test__cluster_into_bins(self):
79 | eval_data = np.zeros([5, 4])
80 | ref_data = np.ones([5, 4])
81 | result = prd._cluster_into_bins(eval_data, ref_data, 3)
82 |
83 | self.assertEqual(len(result), 2)
84 | self.assertEqual(len(result[0]), 3)
85 | self.assertEqual(len(result[1]), 3)
86 | np.testing.assert_almost_equal(sum(result[0]), 1)
87 | np.testing.assert_almost_equal(sum(result[1]), 1)
88 |
89 | def test_compute_prd_from_embedding_mismatch_num_samples_should_fail(self):
90 | # Mismatch in number of samples with enforce_balance set to True
91 | with self.assertRaises(ValueError):
92 | prd.compute_prd_from_embedding(
93 | np.array([[0], [0], [1]]), np.array([[0], [1]]), num_clusters=2,
94 | enforce_balance=True)
95 |
96 | def test_compute_prd_from_embedding_mismatch_num_samples_should_work(self):
97 | # Mismatch in number of samples with enforce_balance set to False
98 | try:
99 | prd.compute_prd_from_embedding(
100 | np.array([[0], [0], [1]]), np.array([[0], [1]]), num_clusters=2,
101 | enforce_balance=False)
102 | except ValueError:
103 | self.fail(
104 | 'compute_prd_from_embedding should not raise a ValueError when '
105 | 'enforce_balance is set to False.')
106 |
107 | def test__prd_to_f_beta_correct_computation(self):
108 | precision = np.array([1, 1, 0, 0, 0.5, 1, 0.5])
109 | recall = np.array([1, 0, 1, 0, 0.5, 0.5, 1])
110 | expected = np.array([1, 0, 0, 0, 0.5, 2/3, 2/3])
111 | with np.errstate(invalid='ignore'):
112 | result = prd._prd_to_f_beta(precision, recall, beta=1)
113 | np.testing.assert_almost_equal(result, expected)
114 |
115 | expected = np.array([1, 0, 0, 0, 0.5, 5/9, 5/6])
116 | with np.errstate(invalid='ignore'):
117 | result = prd._prd_to_f_beta(precision, recall, beta=2)
118 | np.testing.assert_almost_equal(result, expected)
119 |
120 | expected = np.array([1, 0, 0, 0, 0.5, 5/6, 5/9])
121 | with np.errstate(invalid='ignore'):
122 | result = prd._prd_to_f_beta(precision, recall, beta=1/2)
123 | np.testing.assert_almost_equal(result, expected)
124 |
125 | result = prd._prd_to_f_beta(np.array([]), np.array([]), beta=1)
126 | expected = np.array([])
127 | np.testing.assert_almost_equal(result, expected)
128 |
129 | def test__prd_to_f_beta_bad_beta(self):
130 | with self.assertRaises(ValueError):
131 | prd._prd_to_f_beta(np.ones(1), np.ones(1), beta=0)
132 | with self.assertRaises(ValueError):
133 | prd._prd_to_f_beta(np.ones(1), np.ones(1), beta=-3)
134 |
135 | def test__prd_to_f_beta_bad_precision_or_recall(self):
136 | with self.assertRaises(ValueError):
137 | prd._prd_to_f_beta(-np.ones(1), np.ones(1), beta=1)
138 | with self.assertRaises(ValueError):
139 | prd._prd_to_f_beta(np.ones(1), -np.ones(1), beta=1)
140 |
141 | def test_plot_not_enough_labels(self):
142 | with self.assertRaises(ValueError):
143 | prd.plot(np.zeros([3, 2, 5]), labels=['1', '2'])
144 |
145 | def test_plot_too_many_labels(self):
146 | with self.assertRaises(ValueError):
147 | prd.plot(np.zeros([1, 2, 5]), labels=['1', '2', '3'])
148 |
149 |
150 | if __name__ == '__main__':
151 | unittest.main()
152 |
--------------------------------------------------------------------------------
/compare_gan/test_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utility classes and methods for testing."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import datetime
23 | import os
24 |
25 | from absl import flags
26 | from compare_gan import eval_utils
27 | from compare_gan.architectures import abstract_arch
28 | from compare_gan.architectures import arch_ops
29 | import gin
30 | import mock
31 | import numpy as np
32 | import tensorflow as tf
33 |
34 | FLAGS = flags.FLAGS
35 |
36 |
37 | def create_fake_inception_graph():
38 | """Creates a graph that mocks inception.
39 |
40 | It takes the input, multiplies it through a matrix full of 0.00001 values
41 | and returns as logits. It makes sure to match the tensor names of
42 | the real inception model.
43 |
44 | Returns:
45 | tf.GraphDef of a graph with a simple mock inception inside.
46 | """
47 | fake_inception = tf.Graph()
48 | with fake_inception.as_default():
49 | graph_input = tf.placeholder(
50 | tf.float32, shape=[None, 299, 299, 3], name="Mul")
51 | matrix = tf.ones(shape=[299 * 299 * 3, 10]) * 0.00001
52 | output = tf.matmul(tf.layers.flatten(graph_input), matrix)
53 | output = tf.identity(output, name="pool_3")
54 | output = tf.identity(output, name="logits")
55 | return fake_inception.as_graph_def()
56 |
57 |
58 | class Generator(abstract_arch.AbstractGenerator):
59 | """Generator with a single linear layer from z to the output."""
60 |
61 | def __init__(self, **kwargs):
62 | super(Generator, self).__init__(**kwargs)
63 | self.call_arg_list = []
64 |
65 | def apply(self, z, y, is_training):
66 | self.call_arg_list.append(dict(z=z, y=y, is_training=is_training))
67 | batch_size = z.shape[0].value
68 | out = arch_ops.linear(z, np.prod(self._image_shape), scope="fc_noise")
69 | out = tf.nn.sigmoid(out)
70 | return tf.reshape(out, [batch_size] + list(self._image_shape))
71 |
72 |
73 | class Discriminator(abstract_arch.AbstractDiscriminator):
74 | """Discriminator with a single linear layer."""
75 |
76 | def __init__(self, **kwargs):
77 | super(Discriminator, self).__init__(**kwargs)
78 | self.call_arg_list = []
79 |
80 | def apply(self, x, y, is_training):
81 | self.call_arg_list.append(dict(x=x, y=y, is_training=is_training))
82 | h = tf.reduce_mean(x, axis=[1, 2])
83 | out = arch_ops.linear(h, 1)
84 | return tf.nn.sigmoid(out), out, h
85 |
86 |
87 | class CompareGanTestCase(tf.test.TestCase):
88 | """Base class for test cases."""
89 |
90 | def setUp(self):
91 | super(CompareGanTestCase, self).setUp()
92 | # Use fake datasets instead of reading real files.
93 | FLAGS.data_fake_dataset = True
94 | # Clear the gin configuration.
95 | gin.clear_config()
96 | # Mock the inception graph.
97 | fake_inception_graph = create_fake_inception_graph()
98 | self.inception_graph_def_mock = mock.patch.object(
99 | eval_utils,
100 | "get_inception_graph_def",
101 | return_value=fake_inception_graph).start()
102 |
103 | def _get_empty_model_dir(self):
104 | unused_sub_dir = str(datetime.datetime.now().microsecond)
105 | model_dir = os.path.join(FLAGS.test_tmpdir, unused_sub_dir)
106 | assert not tf.gfile.Exists(model_dir)
107 | return model_dir
108 |
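Because `create_fake_inception_graph` multiplies the flattened image by a constant matrix, all 10 outputs of the mock are identical: each equals 0.00001 times the sum of the input pixels. The NumPy sketch below mirrors that computation on a smaller 8x8x3 input instead of 299x299x3 (the helper name `fake_inception_logits` is hypothetical). One consequence is that a softmax over these fake logits is always uniform.

```python
import numpy as np


def fake_inception_logits(images):
    """NumPy analogue of the mock graph: flatten, then multiply by 0.00001s."""
    flat = images.reshape(images.shape[0], -1)
    matrix = np.ones((flat.shape[1], 10)) * 0.00001
    return flat @ matrix


images = np.random.RandomState(0).uniform(size=(2, 8, 8, 3))
logits = fake_inception_logits(images)
# Every column equals 0.00001 * (sum of the pixels of that image).
print(np.allclose(logits, 0.00001 * images.sum(axis=(1, 2, 3))[:, None]))  # True
```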
--------------------------------------------------------------------------------
/compare_gan/tpu/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # coding=utf-8
17 |
--------------------------------------------------------------------------------
/compare_gan/tpu/tpu_ops.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tensorflow operations specific to TPUs."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import gin
23 | from six.moves import range
24 | import tensorflow as tf
25 |
26 | from tensorflow.contrib.tpu.python.tpu import tpu_function
27 |
28 |
29 | def cross_replica_concat(value, replica_id, num_replicas):
30 | """Concatenates `value` across TPU replicas along the first dimension.
31 |
32 | Args:
33 | value: Tensor to concatenate.
34 | replica_id: Integer tensor that indicates the index of the replica.
35 | num_replicas: Python integer, total number of replicas.
36 |
37 | Returns:
38 | Tensor of the same rank as value with first dimension `num_replicas`
39 | times larger.
40 |
41 | Raises:
42 | ValueError: If `value` is a scalar.
43 | """
44 | if value.shape.ndims < 1:
45 | raise ValueError("Value must have at least rank 1 but got {}.".format(
46 | value.shape.ndims))
47 | if num_replicas <= 1:
48 | return value
49 | with tf.name_scope(None, "tpu_cross_replica_concat"):
50 | # Mask is a one-hot encoding of this replica's position (replica_id).
51 | mask = tf.to_float(tf.equal(tf.range(num_replicas), replica_id))
52 | # Expand dims with 1's to match rank of value.
53 | mask = tf.reshape(mask, [num_replicas] + [1] * value.shape.ndims)
54 | if value.dtype in {tf.bfloat16, tf.float32}:
55 | result = mask * value
56 | else:
57 | result = mask * tf.to_float(value)
58 | # Thanks to broadcasting, result now holds `value` only at the position given
59 | # by replica_id; every other slot along the first axis is 0.
60 | # Together these steps implement tf.scatter_nd, which is missing from the
61 | # TPU backend because it doesn't support sparse operations.
62 |
63 | # Merge first 2 dimensions.
64 | # This is equivalent to (value.shape[0].value * num_replicas).
65 | # Using -1 lets TensorFlow infer the merged size instead of hard-coding it.
66 | result = tf.reshape(result, [-1] + result.shape.as_list()[2:])
67 | # Each core sets its values at the position given by replica_id. Summing
68 | # across replicas then exchanges the information and fills each core's
69 | # local 0's with the values from the other cores.
70 | result = tf.contrib.tpu.cross_replica_sum(result)
71 | # Now all the cores see exactly the same data.
72 | return tf.cast(result, dtype=value.dtype)
73 |
74 |
75 | def cross_replica_mean(inputs, group_size=None):
76 | """Calculates the average value of the inputs tensor across TPU replicas."""
77 | num_replicas = tpu_function.get_tpu_context().number_of_shards
78 | if not group_size:
79 | group_size = num_replicas
80 | if group_size == 1:
81 | return inputs
82 | if group_size != num_replicas:
83 | group_assignment = []
84 | assert num_replicas % group_size == 0
85 | for g in range(num_replicas // group_size):
86 | replica_ids = [g * group_size + i for i in range(group_size)]
87 | group_assignment.append(replica_ids)
88 | else:
89 | group_assignment = None
90 | return tf.contrib.tpu.cross_replica_sum(inputs, group_assignment) / tf.cast(
91 | group_size, inputs.dtype)
92 |
93 |
94 | @gin.configurable(blacklist=["inputs", "axis"])
95 | def cross_replica_moments(inputs, axis, parallel=True, group_size=None):
96 | """Compute mean and variance of the inputs tensor across TPU replicas.
97 |
98 | Args:
99 | inputs: A tensor with 2 or more dimensions.
100 | axis: Array of ints. Axes along which to compute mean and variance.
101 | parallel: Use E[x^2] - (E[x])^2 to compute variance. This can be done
102 | in parallel with computing the mean, reducing the communication overhead.
103 | group_size: Integer, the number of replicas to compute moments across.
104 | None or 0 will use all replicas (global).
105 |
106 | Returns:
107 | Two tensors with mean and variance.
108 | """
109 | # Compute local mean and then average across replicas.
110 | mean = tf.math.reduce_mean(inputs, axis=axis)
111 | mean = cross_replica_mean(mean)
112 | if parallel:
113 | # Compute variance using the E[x^2] - (E[x])^2 formula. This is less
114 | # numerically stable than the E[(x-E[x])^2] formula, but allows the two
115 | # cross-replica sums to be computed in parallel, saving communication
116 | # overhead.
117 | mean_of_squares = tf.reduce_mean(tf.square(inputs), axis=axis)
118 | mean_of_squares = cross_replica_mean(mean_of_squares, group_size=group_size)
119 | mean_squared = tf.square(mean)
120 | variance = mean_of_squares - mean_squared
121 | else:
122 | variance = tf.math.reduce_mean(
123 | tf.math.square(inputs - mean), axis=axis)
124 | variance = cross_replica_mean(variance, group_size=group_size)
125 | return mean, variance
126 |
--------------------------------------------------------------------------------
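The comments in `cross_replica_concat` describe a scatter-by-masking trick, and `cross_replica_mean` / `cross_replica_moments` build replica groups and a single-pass variance. As a hedged illustration, the NumPy sketch below simulates replicas as list entries on one CPU. `simulate_cross_replica_concat`, `make_group_assignment`, `simulate_cross_replica_mean` and `simulate_global_moments` are hypothetical helpers, not part of compare_gan, and `np.sum` over the list merely stands in for `tf.contrib.tpu.cross_replica_sum`. The moments helper covers only the global case (`group_size=None`) with equal per-replica batch sizes.

```python
import numpy as np


def simulate_cross_replica_concat(per_replica_values):
  """NumPy sketch of the mask-and-sum trick used by cross_replica_concat.

  per_replica_values[i] stands in for `value` on replica i. All entries must
  have the same shape, with rank >= 1.
  """
  num_replicas = len(per_replica_values)
  contributions = []
  for replica_id, value in enumerate(per_replica_values):
    # One-hot mask over replicas, reshaped to broadcast against `value`.
    mask = (np.arange(num_replicas) == replica_id).astype(np.float32)
    mask = mask.reshape([num_replicas] + [1] * value.ndim)
    # Shape [num_replicas, batch, ...]: zeros except at slot replica_id.
    result = mask * value.astype(np.float32)
    # Merge the replica axis into the batch axis.
    contributions.append(result.reshape([-1] + list(result.shape[2:])))
  # Stand-in for cross_replica_sum: every replica ends up with the
  # element-wise sum, which is the full concatenation.
  return np.sum(contributions, axis=0).astype(per_replica_values[0].dtype)


def make_group_assignment(num_replicas, group_size):
  """Consecutive replica groups, mirroring the loop in cross_replica_mean."""
  assert num_replicas % group_size == 0
  return [[g * group_size + i for i in range(group_size)]
          for g in range(num_replicas // group_size)]


def simulate_cross_replica_mean(per_replica_values, group_size):
  """Each replica receives the mean over the replicas in its group."""
  outputs = [None] * len(per_replica_values)
  for group in make_group_assignment(len(per_replica_values), group_size):
    group_mean = sum(per_replica_values[r] for r in group) / group_size
    for r in group:
      outputs[r] = group_mean
  return outputs


def simulate_global_moments(per_replica_values, parallel=True):
  """Global mean/variance from equal-sized per-replica batches."""
  # With equal batch sizes, the mean of local means is the global mean.
  mean = np.mean([v.mean(axis=0) for v in per_replica_values], axis=0)
  if parallel:
    # E[x^2] does not depend on E[x], so both reductions can run together.
    mean_sq = np.mean([(v ** 2).mean(axis=0) for v in per_replica_values],
                      axis=0)
    return mean, mean_sq - mean ** 2
  # Two-pass form: needs the global mean before the second reduction.
  return mean, np.mean([((v - mean) ** 2).mean(axis=0)
                        for v in per_replica_values], axis=0)


# Shards as produced by batch_parallel in testCrossReplicaConcat.
shards = [np.array([[3, 4]]), np.array([[1, 5]])]
print(simulate_cross_replica_concat(shards))  # Rows [3, 4] and [1, 5].
print(make_group_assignment(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

The parallel variance trades numerical stability (cancellation when the mean is large relative to the spread) for one fewer sequential communication step, which is the trade-off the `parallel` docstring describes.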
/compare_gan/tpu/tpu_ops_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests custom TensorFlow operations for TPU."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import logging
23 | from absl.testing import parameterized
24 | from compare_gan.tpu import tpu_ops
25 | import numpy as np
26 | import tensorflow as tf
27 |
28 |
29 | class TpuOpsTpuTest(parameterized.TestCase, tf.test.TestCase):
30 |
31 | def testRunsOnTpu(self):
32 | """Verify that the test case runs on a TPU chip with 2 cores."""
33 | expected_device_names = [
34 | "/job:localhost/replica:0/task:0/device:CPU:0",
35 | "/job:localhost/replica:0/task:0/device:TPU:0",
36 | "/job:localhost/replica:0/task:0/device:TPU:1",
37 | "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
38 | ]
39 | with self.session() as sess:
40 | devices = sess.list_devices()
41 | logging.info("devices:\n%s", "\n".join([str(d) for d in devices]))
42 | self.assertAllEqual([d.name for d in devices], expected_device_names)
43 |
44 | def testCrossReplicaConcat(self):
45 | def computation(x, replica_id):
46 | logging.info("x: %s\nreplica_id: %s", x, replica_id[0])
47 | return tpu_ops.cross_replica_concat(x, replica_id[0], num_replicas=2)
48 |
49 | inputs = np.asarray([[3, 4], [1, 5]])
50 | expected_output = np.asarray([[3, 4], [1, 5], [3, 4], [1, 5]])
51 |
52 | with tf.Graph().as_default():
53 | x = tf.constant(inputs)
54 | replica_ids = tf.constant([0, 1], dtype=tf.int32)
55 | x_concat, = tf.contrib.tpu.batch_parallel(
56 | computation, [x, replica_ids], num_shards=2)
57 | self.assertAllEqual(x.shape.as_list(), [2, 2])
58 | self.assertAllEqual(x_concat.shape.as_list(), [4, 2])
59 |
60 | with self.session() as sess:
61 | sess.run(tf.contrib.tpu.initialize_system())
62 | sess.run(tf.global_variables_initializer())
63 | x_concat = sess.run(x_concat)
64 | logging.info("x_concat: %s", x_concat)
65 | self.assertAllClose(x_concat, expected_output)
66 |
67 | # Test with group size 2 (the test case has 2 cores, so this is global batch norm).
68 | @parameterized.parameters(
69 | {"group_size": None}, # Defaults to number of TPU cores.
70 | {"group_size": 0}, # Defaults to number of TPU cores.
71 | {"group_size": 2},
72 | )
73 | def testCrossReplicaMean(self, group_size):
74 | # Verify that we average across replicas by feeding 2 vectors to the system.
75 | # Each replica should get one vector which is then averaged across
76 | # all replicas and simply returned.
77 | # After that each replica has the same vector, and since the outputs get
78 | # concatenated we see the same vector twice.
79 | inputs = np.asarray(
80 | [[0.55, 0.70, -1.29, 0.502], [0.57, 0.90, 1.290, 0.202]],
81 | dtype=np.float32)
82 | expected_output = np.asarray(
83 | [[0.56, 0.8, 0.0, 0.352], [0.56, 0.8, 0.0, 0.352]], dtype=np.float32)
84 |
85 | def computation(x):
86 | self.assertAllEqual(x.shape.as_list(), [1, 4])
87 | return tpu_ops.cross_replica_mean(x, group_size=group_size)
88 |
89 | with tf.Graph().as_default():
90 | # Note: Using placeholders for feeding TPUs is discouraged but fine for
91 | # a simple test case.
92 | x = tf.placeholder(name="x", dtype=tf.float32, shape=inputs.shape)
93 | y = tf.contrib.tpu.batch_parallel(computation, inputs=[x], num_shards=2)
94 | with self.session() as sess:
95 | sess.run(tf.contrib.tpu.initialize_system())
96 | # y is actually a list with one tensor. computation would be allowed
97 | # to return multiple tensors (and ops).
98 | actual_output = sess.run(y, {x: inputs})[0]
99 |
100 | self.assertAllEqual(actual_output.shape, (2, 4))
101 | self.assertAllClose(actual_output, expected_output)
102 |
103 | def testCrossReplicaMeanGroupSizeOne(self, group_size=1):
104 | # Since the group size is 1 we only average over 1 replica.
105 | inputs = np.asarray(
106 | [[0.55, 0.70, -1.29, 0.502], [0.57, 0.90, 1.290, 0.202]],
107 | dtype=np.float32)
108 | expected_output = np.asarray(
109 | [[0.55, 0.7, -1.29, 0.502], [0.57, 0.9, 1.290, 0.202]],
110 | dtype=np.float32)
111 |
112 | def computation(x):
113 | self.assertAllEqual(x.shape.as_list(), [1, 4])
114 | return tpu_ops.cross_replica_mean(x, group_size=group_size)
115 |
116 | with tf.Graph().as_default():
117 | # Note: Using placeholders for feeding TPUs is discouraged but fine for
118 | # a simple test case.
119 | x = tf.placeholder(name="x", dtype=tf.float32, shape=inputs.shape)
120 | y = tf.contrib.tpu.batch_parallel(computation, inputs=[x], num_shards=2)
121 | with self.session() as sess:
122 | sess.run(tf.contrib.tpu.initialize_system())
123 | # y is actually a list with one tensor. computation would be allowed
124 | # to return multiple tensors (and ops).
125 | actual_output = sess.run(y, {x: inputs})[0]
126 |
127 | self.assertAllEqual(actual_output.shape, (2, 4))
128 | self.assertAllClose(actual_output, expected_output)
129 |
130 |
131 | if __name__ == "__main__":
132 | tf.test.main()
133 |
--------------------------------------------------------------------------------
/compare_gan/tpu/tpu_summaries.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Provide a helper class for using summaries on TPU via a host call.
17 |
18 | TPUEstimator does not support writing TF summaries out of the box and TPUs can't
19 | perform operations that write files to disk. To monitor tensor values during
20 | training you can copy the tensors back to the CPU of the host machine via
21 | a host call function. This small library provides a convenient API to do this.
22 |
23 | Example:
24 | from compare_gan.tpu import tpu_summaries
25 | def model_fn(features, labels, params, mode):
26 | summary = tpu_summaries.TpuSummaries(my_model_dir)
27 |
28 | summary.scalar("my_scalar_summary", tensor1)
29 | summary.scalar("my_counter", tensor2, reduce_fn=tf.math.reduce_sum)
30 |
31 | return TPUEstimatorSpec(
32 | host_call=summary.get_host_call(),
33 | ...)
34 |
35 | Warning: The host call function will run every step. Writing large tensors to
36 | summaries can slow down your training. Outfeed operations ranking high in
37 | your XProf profile can be an indication of this.
38 | """
39 |
40 | from __future__ import absolute_import
41 | from __future__ import division
42 | from __future__ import print_function
43 |
44 | import collections
45 |
46 | from absl import logging
47 | import tensorflow as tf
48 |
49 |
50 | summary = tf.contrib.summary # TensorFlow Summary API v2.
51 |
52 |
53 | TpuSummaryEntry = collections.namedtuple(
54 | "TpuSummaryEntry", "summary_fn name tensor reduce_fn")
55 |
56 |
57 | class TpuSummaries(object):
58 | """Class to simplify TF summaries on TPU.
59 |
60 | An instance of the class provides simple methods for writing summaries in a
61 | way similar to tf.summary. The difference is that each summary entry must
62 | provide a reduction function that is used to reduce the summary values from
63 | all the TPU cores.
64 | """
65 |
66 | def __init__(self, log_dir, save_summary_steps=250):
67 | self._log_dir = log_dir
68 | self._entries = []
69 | # While False, no summary entries will be added. On TPU we unroll the graph
70 | # and don't want to add multiple summaries per step.
71 | self.record = True
72 | self._save_summary_steps = save_summary_steps
73 |
74 | def image(self, name, tensor, reduce_fn):
75 | """Add a summary for images. Tensor must be a 4-D tensor."""
76 | if not self.record:
77 | return
78 | self._entries.append(
79 | TpuSummaryEntry(summary.image, name, tensor, reduce_fn))
80 |
81 | def scalar(self, name, tensor, reduce_fn=tf.math.reduce_mean):
82 | """Add a summary for a scalar tensor."""
83 | if not self.record:
84 | return
85 | tensor = tf.convert_to_tensor(tensor)
86 | if tensor.shape.ndims == 0:
87 | tensor = tf.expand_dims(tensor, 0)
88 | self._entries.append(
89 | TpuSummaryEntry(summary.scalar, name, tensor, reduce_fn))
90 |
91 | def get_host_call(self):
92 | """Returns the tuple (host_call_fn, host_call_args) for TPUEstimatorSpec."""
93 | # All host_call_args must be tensors with batch dimension.
94 | # All tensors are streamed to the host machine (mind the bandwidth).
95 | global_step = tf.train.get_or_create_global_step()
96 | host_call_args = [tf.expand_dims(global_step, 0)]
97 | host_call_args.extend([e.tensor for e in self._entries])
98 | logging.info("host_call_args: %s", host_call_args)
99 | return (self._host_call_fn, host_call_args)
100 |
101 | def _host_call_fn(self, step, *args):
102 | """Function that will run on the host machine."""
103 | # The host call receives values from all TPU cores (concatenated along the
104 | # batch dimension). The step is the same for all cores.
105 | step = step[0]
106 | logging.info("host_call_fn: args=%s", args)
107 | with summary.create_file_writer(self._log_dir).as_default():
108 | with summary.record_summaries_every_n_global_steps(
109 | self._save_summary_steps, step):
110 | for i, e in enumerate(self._entries):
111 | value = e.reduce_fn(args[i])
112 | e.summary_fn(e.name, value, step=step)
113 | return summary.all_summary_ops()
114 |
--------------------------------------------------------------------------------
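The `TpuSummaries` docstring explains that per-core tensors are streamed to the host, concatenated along the batch dimension, and reduced there with each entry's `reduce_fn`. A minimal NumPy sketch of that data flow, assuming two cores, is below. `simulate_host_call` is a hypothetical helper, and the every-N-steps gating is an assumption about `record_summaries_every_n_global_steps` rather than a quote of TensorFlow's implementation.

```python
import numpy as np


def simulate_host_call(per_core_step, per_core_entries, reduce_fns,
                       save_summary_steps=250):
  """Hypothetical CPU sketch of the data flow into TpuSummaries._host_call_fn.

  per_core_step: one shape-[1] global-step array per core (the expand_dims
    done in get_host_call).
  per_core_entries: per core, a list of summary tensors with a leading batch
    dimension (TpuSummaries.scalar expands 0-d tensors to shape [1]).
  reduce_fns: one reduction per summary entry, e.g. np.mean or np.sum.
  Returns {entry_index: reduced_value}, or {} when the step is not recorded.
  """
  # The host receives each argument concatenated across cores along axis 0.
  step = np.concatenate(per_core_step)[0]  # Identical on all cores; keep one.
  num_entries = len(reduce_fns)
  args = [np.concatenate([entries[i] for entries in per_core_entries])
          for i in range(num_entries)]
  # Assumption: summaries are written only when step % save_summary_steps == 0.
  if step % save_summary_steps != 0:
    return {}
  return {i: reduce_fns[i](args[i]) for i in range(num_entries)}


# Two cores, each reporting a loss (averaged) and a counter (summed).
cores = [[np.array([0.5]), np.array([3])],
         [np.array([1.5]), np.array([4])]]
steps = [np.array([500]), np.array([500])]
print(simulate_host_call(steps, cores, [np.mean, np.sum]))
```

Here the loss reduces to 1.0 and the counter to 7. This is why each entry carries its own `reduce_fn`: averaging suits per-core losses, while counters usually need to be summed across cores.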
/compare_gan/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utilities library."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 | import functools
24 | import inspect
25 |
26 | from absl import logging
27 | import six
28 |
29 |
30 | # In Python 2 the inspect module does not have FullArgSpec. Define a named tuple
31 | # instead.
32 | if hasattr(inspect, "FullArgSpec"):
33 | _FullArgSpec = inspect.FullArgSpec # pylint: disable=invalid-name
34 | else:
35 | _FullArgSpec = collections.namedtuple("FullArgSpec", [
36 | "args", "varargs", "varkw", "defaults", "kwonlyargs", "kwonlydefaults",
37 | "annotations"
38 | ])
39 |
40 |
41 | def _getfullargspec(fn):
42 | """Python 2/3 compatible version of the inspect.getfullargspec method.
43 |
44 | Args:
45 | fn: The function object.
46 |
47 | Returns:
48 | A FullArgSpec. For Python 2 this is emulated by a named tuple.
49 | """
50 | arg_spec_fn = inspect.getfullargspec if six.PY3 else inspect.getargspec
51 | try:
52 | arg_spec = arg_spec_fn(fn)
53 | except TypeError:
54 | # `fn` might be a callable object.
55 | arg_spec = arg_spec_fn(fn.__call__)
56 | if six.PY3:
57 | assert isinstance(arg_spec, _FullArgSpec)
58 | return arg_spec
59 | return _FullArgSpec(
60 | args=arg_spec.args,
61 | varargs=arg_spec.varargs,
62 | varkw=arg_spec.keywords,
63 | defaults=arg_spec.defaults,
64 | kwonlyargs=[],
65 | kwonlydefaults=None,
66 | annotations={})
67 |
68 |
69 | def _has_arg(fn, arg_name):
70 | """Returns True if `arg_name` might be a valid parameter for `fn`.
71 |
72 | Specifically, this means that `fn` either has a parameter named
73 | `arg_name`, or has a `**kwargs` parameter.
74 |
75 | Args:
76 | fn: The function to check.
77 | arg_name: The name fo the parameter.
78 |
79 | Returns:
80 | Whether `arg_name` might be a valid argument of `fn`.
81 | """
82 | while isinstance(fn, functools.partial):
83 | fn = fn.func
84 | while hasattr(fn, "__wrapped__"):
85 | fn = fn.__wrapped__
86 | arg_spec = _getfullargspec(fn)
87 | if arg_spec.varkw:
88 | return True
89 | return arg_name in arg_spec.args or arg_name in arg_spec.kwonlyargs
90 |
91 |
92 | def call_with_accepted_args(fn, **kwargs):
93 | """Calls `fn` only with the keyword arguments that `fn` accepts."""
94 | kwargs = {k: v for k, v in six.iteritems(kwargs) if _has_arg(fn, k)}
95 | logging.debug("Calling %s with args %s.", fn, kwargs)
96 | return fn(**kwargs)
97 |
98 |
99 | def get_parameter_overview(variables, limit=40):
100 | """Returns a string with variables names, their shapes, count, and types.
101 |
102 | To get all trainable parameters pass in `tf.trainable_variables()`.
103 |
104 | Args:
105 | variables: List of `tf.Variable`(s).
106 | limit: If not `None`, the maximum number of variables to include.
107 |
108 | Returns:
109 | A string with a table like in the example.
110 |
111 | +----------------+---------------+------------+---------+
112 | | Name | Shape | Size | Type |
113 | +----------------+---------------+------------+---------+
114 | | FC_1/weights:0 | (63612, 1024) | 65,138,688 | float32 |
115 | | FC_1/biases:0 | (1024,) | 1,024 | float32 |
116 | | FC_2/weights:0 | (1024, 32) | 32,768 | float32 |
117 | | FC_2/biases:0 | (32,) | 32 | float32 |
118 | +----------------+---------------+------------+---------+
119 |
120 | Total: 65,172,512
121 | """
122 | max_name_len = max([len(v.name) for v in variables] + [len("Name")])
123 | max_shape_len = max([len(str(v.get_shape())) for v in variables] + [len(
124 | "Shape")])
125 | max_size_len = max([len("{:,}".format(v.get_shape().num_elements()))
126 | for v in variables] + [len("Size")])
127 | max_type_len = max([len(v.dtype.base_dtype.name) for v in variables] + [len(
128 | "Type")])
129 |
130 | var_line_format = "| {: <{}s} | {: >{}s} | {: >{}s} | {: <{}s} |"
131 | sep_line_format = var_line_format.replace(" ", "-").replace("|", "+")
132 |
133 | header = var_line_format.replace(">", "<").format("Name", max_name_len,
134 | "Shape", max_shape_len,
135 | "Size", max_size_len,
136 | "Type", max_type_len)
137 | separator = sep_line_format.format("", max_name_len, "", max_shape_len, "",
138 | max_size_len, "", max_type_len)
139 |
140 | lines = [separator, header, separator]
141 |
142 | total_weights = sum(v.get_shape().num_elements() for v in variables)
143 |
144 | # Create one line per variable until the table reaches `limit` lines.
145 | for v in variables:
146 | if limit is not None and len(lines) >= limit:
147 | lines.append("[...]")
148 | break
149 | lines.append(var_line_format.format(
150 | v.name, max_name_len,
151 | str(v.get_shape()), max_shape_len,
152 | "{:,}".format(v.get_shape().num_elements()), max_size_len,
153 | v.dtype.base_dtype.name, max_type_len))
154 |
155 | lines.append(separator)
156 | lines.append("Total: {:,}".format(total_weights))
157 |
158 | return "\n".join(lines)
159 |
160 |
161 | def log_parameter_overview(variables, msg):
162 | """Writes a table with variables name and shapes to INFO log.
163 |
164 | See get_parameter_overview for details.
165 |
166 | Args:
167 | variables: List of `tf.Variable`(s).
168 | msg: Message to be logged before the table.
169 | """
170 | table = get_parameter_overview(variables, limit=None)
171 | # The table can be too large to fit into one log entry.
172 | lines = [msg] + table.split("\n")
173 | for i in range(0, len(lines), 80):
174 | logging.info("\n%s", "\n".join(lines[i:i + 80]))
175 |
--------------------------------------------------------------------------------
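`call_with_accepted_args` lets a caller pass one superset of keyword arguments to functions that each need only some of them. On Python 3 only, the same filtering can be sketched with `inspect.signature`. `accepts_kwarg` and `hypothetical_loss` are illustrative names, not compare_gan APIs. Unlike `_has_arg`, this version does not unwrap `functools.partial` to the underlying function: `inspect.signature` drops parameters a partial has already bound positionally, so those are reported as not accepted.

```python
import inspect


def accepts_kwarg(fn, name):
  """Python 3 sketch of utils._has_arg built on inspect.signature."""
  # inspect.signature follows __wrapped__ and understands functools.partial.
  params = inspect.signature(fn).parameters.values()
  if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
    return True  # A **kwargs parameter accepts any keyword.
  keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                   inspect.Parameter.KEYWORD_ONLY)
  return any(p.name == name and p.kind in keyword_kinds for p in params)


def call_with_accepted_args(fn, **kwargs):
  """Calls fn with only the keyword arguments it can accept."""
  accepted = {k: v for k, v in kwargs.items() if accepts_kwarg(fn, k)}
  return fn(**accepted)


def hypothetical_loss(d_real, d_fake):
  return d_real - d_fake


# d_real_logits is dropped because hypothetical_loss cannot take it.
print(call_with_accepted_args(hypothetical_loss, d_real=3.0, d_fake=1.0,
                              d_real_logits=0.2))  # 2.0
```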
/example_configs/README.md:
--------------------------------------------------------------------------------
1 | This folder contains configurations of popular GANs and links to the
2 | corresponding papers.
3 |
4 | **Please note that we provide them as a starting point for the user and that
5 | they represent a best-effort implementation** -- we do not guarantee that all
6 | the implementation details match exactly due to the vast number of design
7 | options. Given the sensitivity of GANs to design choices, hyperparameters and
8 | differences in platforms (GPUs/TPUs), small differences might have a
9 | significant impact on the final results.
10 |
--------------------------------------------------------------------------------
/example_configs/biggan_imagenet128.gin:
--------------------------------------------------------------------------------
1 | # BigGAN architecture and settings on ImageNet 128.
2 | # http://arxiv.org/abs/1809.11096
3 |
4 | # This should be similar to row 7 in Table 1.
5 | # It does not include orthogonal regularization (which would be row 8) and uses
6 | # a different learning rate.
7 |
8 | # Recommended training platform: TPU v3-128.
9 |
10 | dataset.name = "imagenet_128"
11 | options.z_dim = 120
12 |
13 | options.architecture = "resnet_biggan_arch"
14 | ModularGAN.conditional = True
15 | options.batch_size = 2048
16 | options.gan_class = @ModularGAN
17 | options.lamba = 1
18 | options.training_steps = 250000
19 | weights.initializer = "orthogonal"
20 | spectral_norm.singular_value = "auto"
21 |
22 | # Generator
23 | G.batch_norm_fn = @conditional_batch_norm
24 | G.spectral_norm = True
25 | ModularGAN.g_use_ema = True
26 | resnet_biggan.Generator.hierarchical_z = True
27 | resnet_biggan.Generator.embed_y = True
28 | standardize_batch.decay = 0.9
29 | standardize_batch.epsilon = 1e-5
30 | standardize_batch.use_moving_averages = False
31 |
32 | # Discriminator
33 | options.disc_iters = 2
34 | D.spectral_norm = True
35 | resnet_biggan.Discriminator.project_y = True
36 |
37 | # Loss and optimizer
38 | loss.fn = @hinge
39 | penalty.fn = @no_penalty
40 | ModularGAN.g_lr = 0.0001
41 | ModularGAN.g_optimizer_fn = @tf.train.AdamOptimizer
42 | ModularGAN.d_lr = 0.0005
43 | ModularGAN.d_optimizer_fn = @tf.train.AdamOptimizer
44 | tf.train.AdamOptimizer.beta1 = 0.0
45 | tf.train.AdamOptimizer.beta2 = 0.999
46 |
47 | z.distribution_fn = @tf.random.normal
48 | eval_z.distribution_fn = @tf.random.normal
49 |
50 | run_config.iterations_per_loop = 500
51 | run_config.save_checkpoints_steps = 2500
52 |
--------------------------------------------------------------------------------
/example_configs/dcgan_celeba64.gin:
--------------------------------------------------------------------------------
1 |
2 | # Recommended training platform: P100, V100, TPU v2-8 or TPU v3-8
3 |
4 | dataset.name = "celeb_a"
5 | options.architecture = "dcgan_arch"
6 | options.batch_size = 64
7 | options.gan_class = @ModularGAN
8 | options.lamba = 1
9 | options.training_steps = 100000
10 | options.z_dim = 128
11 |
12 | # Generator
13 | G.batch_norm_fn = @batch_norm
14 | standardize_batch.decay = 0.9
15 | standardize_batch.epsilon = 1e-5
16 |
17 | # Discriminator
18 | options.disc_iters = 1
19 | D.spectral_norm = False
20 |
21 | # Loss and optimizer
22 | loss.fn = @non_saturating
23 | penalty.fn = @no_penalty
24 | ModularGAN.g_lr = 0.0002
25 | ModularGAN.g_optimizer_fn = @tf.train.AdamOptimizer
26 | tf.train.AdamOptimizer.beta1 = 0.5
27 | tf.train.AdamOptimizer.beta2 = 0.999
28 |
--------------------------------------------------------------------------------
/example_configs/resnet_cifar10.gin:
--------------------------------------------------------------------------------
1 |
2 | # Recommended training platform: P100, V100, TPU v2-8 or TPU v3-8
3 |
4 | dataset.name = "cifar10"
5 | options.architecture = "resnet_cifar_arch"
6 | options.batch_size = 64
7 | options.gan_class = @ModularGAN
8 | options.lamba = 1
9 | options.training_steps = 40000
10 | options.z_dim = 128
11 |
12 | # Generator
13 | G.batch_norm_fn = @batch_norm
14 | standardize_batch.decay = 0.9
15 | standardize_batch.epsilon = 1e-5
16 |
17 | # Discriminator
18 | options.disc_iters = 5
19 | D.spectral_norm = True
20 |
21 | # Loss and optimizer
22 | loss.fn = @non_saturating
23 | penalty.fn = @no_penalty
24 | ModularGAN.g_lr = 0.0002
25 | ModularGAN.g_optimizer_fn = @tf.train.AdamOptimizer
26 | tf.train.AdamOptimizer.beta1 = 0.5
27 | tf.train.AdamOptimizer.beta2 = 0.999
28 |
--------------------------------------------------------------------------------
/example_configs/resnet_lsun-bedroom128.gin:
--------------------------------------------------------------------------------
1 |
2 | # Recommended training platform: P100, V100, TPU v2-8 or TPU v3-8
3 |
4 | dataset.name = "lsun-bedroom"
5 | options.architecture = "resnet5_arch"
6 | options.batch_size = 64
7 | options.gan_class = @ModularGAN
8 | options.lamba = 10
9 | options.training_steps = 40000
10 | options.z_dim = 128
11 |
12 | # Generator
13 | G.batch_norm_fn = @batch_norm
14 | standardize_batch.decay = 0.9
15 | standardize_batch.epsilon = 1e-5
16 |
17 | # Discriminator
18 | options.disc_iters = 5
19 | D.spectral_norm = False
20 |
21 | # Loss and optimizer
22 | loss.fn = @wasserstein
23 | penalty.fn = @wgangp_penalty
24 | ModularGAN.g_lr = 0.0001
25 | ModularGAN.g_optimizer_fn = @tf.train.AdamOptimizer
26 | tf.train.AdamOptimizer.beta1 = 0.5
27 | tf.train.AdamOptimizer.beta2 = 0.9
28 |
--------------------------------------------------------------------------------
/example_configs/sndcgan_celebahq128.gin:
--------------------------------------------------------------------------------
1 |
2 | # Recommended training platform: P100, V100, TPU v2-8 or TPU v3-8
3 |
4 | dataset.name = "celeb_a_hq_128"
5 | options.architecture = "sndcgan_arch"
6 | options.batch_size = 64
7 | options.gan_class = @ModularGAN
8 | options.lamba = 1
9 | options.training_steps = 100000
10 | options.z_dim = 128
11 |
12 | # Generator
13 | G.batch_norm_fn = @batch_norm
14 | standardize_batch.decay = 0.9
15 | standardize_batch.epsilon = 1e-5
16 |
17 | # Discriminator
18 | options.disc_iters = 1
19 | D.spectral_norm = True
20 |
21 | # Loss and optimizer
22 | loss.fn = @non_saturating
23 | penalty.fn = @no_penalty
24 | ModularGAN.g_lr = 0.0002
25 | ModularGAN.g_optimizer_fn = @tf.train.AdamOptimizer
26 | tf.train.AdamOptimizer.beta1 = 0.5
27 | tf.train.AdamOptimizer.beta2 = 0.999
28 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google LLC & Hwalsuk Lee.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Install compare_gan."""
17 |
18 | from setuptools import find_packages
19 | from setuptools import setup
20 |
21 | setup(
22 | name='compare_gan',
23 | version='3.0',
24 | description=(
25 | 'Compare GAN - A modular library for training and evaluating GANs.'),
26 | author='Google LLC',
27 | author_email='no-reply@google.com',
28 | url='https://github.com/google/compare_gan',
29 | license='Apache 2.0',
30 | packages=find_packages(),
31 | package_data={},
32 | install_requires=[
33 | 'future',
34 | 'gin-config==0.1.4',
35 | 'numpy',
36 | 'pandas',
37 | 'six',
38 | 'tensorflow-datasets==1.0.1',
39 | 'tensorflow-hub>=0.2.0',
40 | 'tensorflow-gan==0.0.0.dev0',
41 | 'matplotlib>=1.5.2',
42 | 'pstar>=0.1.6',
43 | 'scipy>=1.0.0',
44 | ],
45 | extras_require={
46 | 'tf': ['tensorflow>=1.12'],
47 | # Evaluation of Hub modules with EMA variables requires TF > 1.12.
48 | 'tf_gpu': ['tf-nightly-gpu>=1.13.0.dev20190221'],
49 | 'pillow': ['pillow>=5.0.0'],
50 | 'tensorflow-probability': ['tensorflow-probability>=0.5.0'],
51 | },
52 | classifiers=[
53 | 'Development Status :: 4 - Beta',
54 | 'Intended Audience :: Developers',
55 | 'Intended Audience :: Science/Research',
56 | 'License :: OSI Approved :: Apache Software License',
57 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
58 | ],
59 | keywords='tensorflow machine learning gan',
60 | )
61 |
--------------------------------------------------------------------------------