├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── dataset.py
├── discriminator.py
├── eval_imagenet.py
├── generator.py
├── imgs
└── img1.png
├── model.py
├── non_local.py
├── ops.py
├── train_imagenet.py
└── utils_ori.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 | .DS_Store
4 |
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
 9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Attention GAN
2 | TensorFlow implementation for reproducing the main results of the paper [Self-Attention Generative Adversarial Networks](https://arxiv.org/abs/1805.08318) by Han Zhang, Ian Goodfellow, Dimitris Metaxas, Augustus Odena.
3 |
4 | <img src="imgs/img1.png"/>
5 |
6 |
7 | ### Dependencies
8 | python 3.6
9 |
10 | TensorFlow 1.5
11 |
12 |
13 | **Data**
14 |
15 | Download the ImageNet dataset and preprocess the images into tfrecord files as instructed in [improved gan](https://github.com/openai/improved-gan/blob/master/imagenet/convert_imagenet_to_records.py). Put the tfrecord files into `./data`.
16 |
17 |
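For reference, `dataset.py` expects each tfrecord example to hold a raw 128x128x3 uint8 image under `image_raw` and an int64 class id under `label`. A minimal writer sketch under that assumption (the output file name and the `image`/`label` variables are illustrative):

```python
import tensorflow as tf

def make_example(image_uint8_hwc, label):
  # image_uint8_hwc: numpy uint8 array of shape (128, 128, 3).
  return tf.train.Example(features=tf.train.Features(feature={
      'image_raw': tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[image_uint8_hwc.tobytes()])),
      'label': tf.train.Feature(
          int64_list=tf.train.Int64List(value=[int(label)])),
  }))

with tf.python_io.TFRecordWriter('./data/train-0.tfrecords') as writer:
  # image, label come from your own preprocessing loop.
  writer.write(make_example(image, label).SerializeToString())
```
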
18 | **Training**
19 |
20 | The current batch size is 64x4=256. A larger batch size seems to give better performance, but it may require retuning the learning rates for G and D. Note: training for one million steps usually takes several weeks.
21 |
22 | `CUDA_VISIBLE_DEVICES=0,1,2,3 python train_imagenet.py --generator_type test --discriminator_type test --data_dir ./data`
23 |
24 | **Evaluation**
25 |
26 | `CUDA_VISIBLE_DEVICES=4 python eval_imagenet.py --generator_type test --data_dir ./data`
27 |
28 | ### Citing Self-attention GAN
29 | If you find Self-Attention GAN useful in your research, please consider citing:
30 |
31 | ```
32 | @article{Han18,
33 | author = {Han Zhang and
34 | Ian J. Goodfellow and
35 | Dimitris N. Metaxas and
36 | Augustus Odena},
37 | title = {Self-Attention Generative Adversarial Networks},
38 | year = {2018},
39 | journal = {arXiv:1805.08318},
40 | }
41 | ```
42 |
43 | **References**
44 |
45 | - Spectral Normalization for Generative Adversarial Networks [Paper](https://arxiv.org/abs/1802.05957)
46 | - cGANs with Projection Discriminator [Paper](https://arxiv.org/abs/1802.05637)
47 | - Non-local Neural Networks [Paper](https://arxiv.org/abs/1711.07971) (see the sketch below)
48 |
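The non-local (self-attention) block at the heart of SAGAN attends over all spatial positions of a feature map. A minimal illustrative sketch of the computation (the repo's real block, with spectral normalization and pooling, lives in `non_local.py`):

```python
import tensorflow as tf

def self_attention_sketch(x, name='attn'):
  """Illustrative self-attention over spatial positions; x is [B, H, W, C]."""
  with tf.variable_scope(name):
    _, h, w, c = x.shape.as_list()
    f = tf.layers.conv2d(x, c // 8, 1, name='f_queries')
    g = tf.layers.conv2d(x, c // 8, 1, name='g_keys')
    v = tf.layers.conv2d(x, c, 1, name='v_values')
    f = tf.reshape(f, [-1, h * w, c // 8])
    g = tf.reshape(g, [-1, h * w, c // 8])
    v = tf.reshape(v, [-1, h * w, c])
    attn = tf.nn.softmax(tf.matmul(f, g, transpose_b=True))  # [B, HW, HW]
    o = tf.reshape(tf.matmul(attn, v), [-1, h, w, c])
    gamma = tf.get_variable('gamma', [], initializer=tf.zeros_initializer())
    return x + gamma * o  # residual; gamma starts at 0, so the block begins as identity
```
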
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | import tensorflow as tf
17 | import os
18 | IMAGE_SIZE=128
19 |
20 | def _extract_image_and_label(record):
21 | """Extracts and preprocesses the image and label from the record."""
22 | features = tf.parse_single_example(
23 | record,
24 | features={
25 | 'image_raw': tf.FixedLenFeature([], tf.string),
26 | 'label': tf.FixedLenFeature([], tf.int64)
27 | })
28 | image_size = IMAGE_SIZE
29 | image = tf.decode_raw(features['image_raw'], tf.uint8)
30 | image.set_shape(image_size * image_size * 3)
31 | image = tf.reshape(image, [image_size, image_size, 3])
32 |
33 | image = tf.cast(image, tf.float32) * (2. / 255) - 1.
34 |
35 | label = tf.cast(features['label'], tf.int32)
36 |
37 | return image, label
38 |
39 | class InputFunction(object):
40 | """Wrapper class that is passed as callable to Estimator."""
41 |
42 | def __init__(self, is_training, noise_dim, dataset_name, num_classes, data_dir="./dataset",
43 | cycle_length=64, shuffle_buffer_size=100000):
44 | self.is_training = is_training
45 | self.noise_dim = noise_dim
46 | split = ('train' if is_training else 'test')
47 | self.data_files = tf.gfile.Glob(os.path.join(data_dir, '*.tfrecords'))
48 | self.parser = _extract_image_and_label
49 | self.num_classes = num_classes
50 | self.cycle_length = cycle_length
51 | self.shuffle_buffer_size = shuffle_buffer_size
52 |
53 | def __call__(self, params):
54 | """Creates a simple Dataset pipeline."""
55 |
56 | batch_size = params['batch_size']
57 | filename_dataset = tf.data.Dataset.from_tensor_slices(self.data_files)
58 | filename_dataset = filename_dataset.shuffle(len(self.data_files))
59 |
60 | def tfrecord_dataset(filename):
61 |       buffer_size = 8 * 1024 * 1024  # 8 MiB read buffer.
62 | return tf.data.TFRecordDataset(filename, buffer_size=buffer_size)
63 |
64 | dataset = filename_dataset.apply(tf.contrib.data.parallel_interleave(
65 | tfrecord_dataset,
66 | cycle_length=self.cycle_length, sloppy=True))
67 | if self.is_training:
68 | dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(
69 | self.shuffle_buffer_size, -1))
70 | dataset = dataset.map(self.parser, num_parallel_calls=32)
71 | dataset = dataset.apply(
72 | tf.contrib.data.batch_and_drop_remainder(batch_size))
73 |
74 | dataset = dataset.prefetch(4) # Prefetch overlaps in-feed with training
75 | images, labels = dataset.make_one_shot_iterator().get_next()
76 | labels = tf.squeeze(labels)
77 | random_noise = tf.random_normal([batch_size, self.noise_dim])
78 |
79 | gen_class_logits = tf.zeros((batch_size, self.num_classes))
80 | gen_class_ints = tf.multinomial(gen_class_logits, 1)
81 | gen_sparse_class = tf.squeeze(gen_class_ints)
82 |
83 | features = {
84 | 'real_images': images,
85 | 'random_noise': random_noise,
86 | 'fake_labels': gen_sparse_class}
87 |
88 | return features, labels
89 |
90 |
--------------------------------------------------------------------------------
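
A minimal usage sketch for `InputFunction` above; the `params` dict mirrors what a `tf.estimator`-style runner passes, and the values here are illustrative:

```python
import tensorflow as tf
from dataset import InputFunction

input_fn = InputFunction(is_training=True, noise_dim=128,
                         dataset_name='imagenet', num_classes=1000,
                         data_dir='./data')
features, labels = input_fn({'batch_size': 64})
with tf.Session() as sess:
  batch = sess.run(features['real_images'])  # [64, 128, 128, 3], values in [-1, 1]
```
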
/discriminator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The discriminator of SNGAN."""
16 |
17 | import tensorflow as tf
18 | import ops
19 | import non_local
20 |
21 |
22 | def dsample(x):
23 | """Downsamples the image by a factor of 2."""
24 |
25 | xd = tf.nn.avg_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
26 | return xd
27 |
28 |
29 | def block(x, out_channels, name, update_collection=None,
30 | downsample=True, act=tf.nn.relu):
31 | """Builds the residual blocks used in the discriminator in SNGAN.
32 |
33 | Args:
34 | x: The 4D input vector.
35 | out_channels: Number of features in the output layer.
36 | name: The variable scope name for the block.
37 | update_collection: The update collections used in the
38 | spectral_normed_weight.
39 |     downsample: If True, downsample the spatial size of the input tensor by
40 |       a factor of 2 in each dimension. If False, the spatial size of the
41 |       input tensor is unchanged.
42 | act: The activation function used in the block.
43 | Returns:
44 | A `Tensor` representing the output of the operation.
45 | """
46 | with tf.variable_scope(name):
47 | input_channels = x.shape.as_list()[-1]
48 | x_0 = x
49 | x = act(x)
50 | x = ops.snconv2d(x, out_channels, 3, 3, 1, 1,
51 | update_collection=update_collection, name='sn_conv1')
52 | x = act(x)
53 | x = ops.snconv2d(x, out_channels, 3, 3, 1, 1,
54 | update_collection=update_collection, name='sn_conv2')
55 | if downsample:
56 | x = dsample(x)
57 | if downsample or input_channels != out_channels:
58 | x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1,
59 | update_collection=update_collection, name='sn_conv3')
60 | if downsample:
61 | x_0 = dsample(x_0)
62 | return x_0 + x
63 |
64 |
65 | def optimized_block(x, out_channels, name,
66 | update_collection=None, act=tf.nn.relu):
67 | """Builds the simplified residual blocks for downsampling.
68 |
69 |   Compared with block, optimized_block always downsamples the spatial
70 |   resolution of the input tensor by a factor of 2 in each dimension.
71 |
72 | Args:
73 | x: The 4D input vector.
74 | out_channels: Number of features in the output layer.
75 | name: The variable scope name for the block.
76 | update_collection: The update collections used in the
77 | spectral_normed_weight.
78 | act: The activation function used in the block.
79 | Returns:
80 | A `Tensor` representing the output of the operation.
81 | """
82 | with tf.variable_scope(name):
83 | x_0 = x
84 | x = ops.snconv2d(x, out_channels, 3, 3, 1, 1,
85 | update_collection=update_collection, name='sn_conv1')
86 | x = act(x)
87 | x = ops.snconv2d(x, out_channels, 3, 3, 1, 1,
88 | update_collection=update_collection, name='sn_conv2')
89 | x = dsample(x)
90 | x_0 = dsample(x_0)
91 | x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1,
92 | update_collection=update_collection, name='sn_conv3')
93 | return x + x_0
94 |
95 |
96 | def discriminator_old(image, labels, df_dim, number_classes, update_collection=None,
97 | act=tf.nn.relu, scope='Discriminator'):
98 | """Builds the discriminator graph.
99 |
100 | Args:
101 | image: The current batch of images to classify as fake or real.
102 | labels: The corresponding labels for the images.
103 | df_dim: The df dimension.
104 | number_classes: The number of classes in the labels.
105 | update_collection: The update collections used in the
106 | spectral_normed_weight.
107 | act: The activation function used in the discriminator.
108 | scope: Optional scope for `variable_op_scope`.
109 | Returns:
110 | A `Tensor` representing the logits of the discriminator.
111 | """
112 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
113 | h0 = optimized_block(image, df_dim, 'd_optimized_block1',
114 | update_collection, act=act) # 64 * 64
115 | h1 = block(h0, df_dim * 2, 'd_block2',
116 | update_collection, act=act) # 32 * 32
117 | h2 = block(h1, df_dim * 4, 'd_block3',
118 | update_collection, act=act) # 16 * 16
119 | h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
120 | h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
121 | h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
122 | h5_act = act(h5)
123 | h6 = tf.reduce_sum(h5_act, [1, 2])
124 | output = ops.snlinear(h6, 1, update_collection=update_collection,
125 | name='d_sn_linear')
126 | h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
127 | update_collection=update_collection,
128 | name='d_embedding')
129 | output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
130 | return output
131 |
132 |
133 | def discriminator(image, labels, df_dim, number_classes, update_collection=None,
134 | act=tf.nn.relu):
135 | """Builds the discriminator graph.
136 |
137 | Args:
138 | image: The current batch of images to classify as fake or real.
139 | labels: The corresponding labels for the images.
140 | df_dim: The df dimension.
141 | number_classes: The number of classes in the labels.
142 | update_collection: The update collections used in the
143 | spectral_normed_weight.
144 | act: The activation function used in the discriminator.
146 | Returns:
147 | A `Tensor` representing the logits of the discriminator.
148 | """
149 | with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
150 | h0 = optimized_block(image, df_dim, 'd_optimized_block1',
151 | update_collection, act=act) # 64 * 64
152 | h1 = block(h0, df_dim * 2, 'd_block2',
153 | update_collection, act=act) # 32 * 32
154 | h2 = block(h1, df_dim * 4, 'd_block3',
155 | update_collection, act=act) # 16 * 16
156 | h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
157 | h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
158 | h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
159 | h5_act = act(h5)
160 | h6 = tf.reduce_sum(h5_act, [1, 2])
161 | output = ops.snlinear(h6, 1, update_collection=update_collection,
162 | name='d_sn_linear')
163 | h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
164 | update_collection=update_collection,
165 | name='d_embedding')
166 | output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
167 | print('Discriminator Structure')
168 | return output
169 |
170 | def discriminator_test(image, labels, df_dim, number_classes, update_collection=None,
171 | act=tf.nn.relu):
172 | """Builds the discriminator graph.
173 |
174 | Args:
175 | image: The current batch of images to classify as fake or real.
176 | labels: The corresponding labels for the images.
177 | df_dim: The df dimension.
178 | number_classes: The number of classes in the labels.
179 | update_collection: The update collections used in the
180 | spectral_normed_weight.
181 | act: The activation function used in the discriminator.
183 | Returns:
184 | A `Tensor` representing the logits of the discriminator.
185 | """
186 | with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
187 | h0 = optimized_block(image, df_dim, 'd_optimized_block1',
188 | update_collection, act=act) # 64 * 64
189 | h1 = block(h0, df_dim * 2, 'd_block2',
190 | update_collection, act=act) # 32 * 32
191 | h1 = non_local.sn_non_local_block_sim(h1, update_collection, name='d_non_local') # 32 * 32
192 | h2 = block(h1, df_dim * 4, 'd_block3',
193 | update_collection, act=act) # 16 * 16
194 | h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
195 | h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
196 | h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
197 | h5_act = act(h5)
198 | h6 = tf.reduce_sum(h5_act, [1, 2])
199 | output = ops.snlinear(h6, 1, update_collection=update_collection,
200 | name='d_sn_linear')
201 | h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
202 | update_collection=update_collection,
203 | name='d_embedding')
204 | output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
205 | print('Discriminator Test Structure')
206 | return output
207 |
208 | def discriminator_test_64(image, labels, df_dim, number_classes, update_collection=None,
209 | act=tf.nn.relu):
210 | """Builds the discriminator graph.
211 |
212 | Args:
213 | image: The current batch of images to classify as fake or real.
214 | labels: The corresponding labels for the images.
215 | df_dim: The df dimension.
216 | number_classes: The number of classes in the labels.
217 | update_collection: The update collections used in the
218 | spectral_normed_weight.
219 | act: The activation function used in the discriminator.
221 | Returns:
222 | A `Tensor` representing the logits of the discriminator.
223 | """
224 | with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
225 | h0 = optimized_block(image, df_dim, 'd_optimized_block1',
226 | update_collection, act=act) # 64 * 64
227 | h0 = non_local.sn_non_local_block_sim(h0, update_collection, name='d_non_local') # 64 * 64
228 | h1 = block(h0, df_dim * 2, 'd_block2',
229 | update_collection, act=act) # 32 * 32
230 | h2 = block(h1, df_dim * 4, 'd_block3',
231 | update_collection, act=act) # 16 * 16
232 | h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act) # 8 * 8
233 | h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act) # 4 * 4
234 | h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
235 | h5_act = act(h5)
236 | h6 = tf.reduce_sum(h5_act, [1, 2])
237 | output = ops.snlinear(h6, 1, update_collection=update_collection,
238 | name='d_sn_linear')
239 | h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
240 | update_collection=update_collection,
241 | name='d_embedding')
242 | output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
243 | return output
--------------------------------------------------------------------------------
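
A sketch of how `model.py` wires these discriminators into the hinge-loss objective; the input tensors are illustrative, and variables are shared between the two calls via `tf.AUTO_REUSE`:

```python
import tensorflow as tf
import discriminator as disc

# real_images/fake_images: [B, 128, 128, 3] in [-1, 1]; labels: [B] int class ids.
d_real = disc.discriminator(real_images, real_labels, df_dim=64, number_classes=1000)
d_fake = disc.discriminator(fake_images, fake_labels, df_dim=64, number_classes=1000)

# Hinge losses, matching _get_d_real_loss/_get_d_fake_loss/_get_g_loss in model.py.
d_loss = (tf.reduce_mean(tf.nn.relu(1.0 - d_real)) +
          tf.reduce_mean(tf.nn.relu(1.0 + d_fake)))
g_loss = -tf.reduce_mean(d_fake)
```
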
/eval_imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Generic train."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | from absl import flags
23 | import tensorflow as tf
24 |
25 |
26 | import generator as generator_module
27 | import utils_ori as utils
28 |
29 |
30 |
31 |
32 | slim = tf.contrib.slim
33 | tfgan = tf.contrib.gan
34 |
35 |
36 | flags.DEFINE_string(
38 |     'data_dir', '/bigdata1/hz138/Data/imagenet',
39 |     'Directory with ImageNet input data as sharded tfrecord files of pre-'
40 |     'processed images.')
41 | flags.DEFINE_integer('z_dim', 128, 'The dimension of z')
42 | flags.DEFINE_integer('gf_dim', 64, 'Dimensionality of gf. [64]')
43 |
44 |
45 | flags.DEFINE_string('master', 'local',
46 | 'BNS name of the TensorFlow master to use')
47 | flags.DEFINE_string('checkpoint_dir', 'checkpoint', 'Directory name to load '
48 | 'the checkpoints. [checkpoint]')
49 | flags.DEFINE_string('sample_dir', 'sample', 'Directory name to save the '
50 | 'image samples. [sample]')
51 | flags.DEFINE_string('eval_dir', 'checkpoint/eval', 'Directory name to save the '
52 | 'eval summaries . [eval]')
53 | flags.DEFINE_integer('batch_size', 64, 'Batch size of samples to feed into '
54 |                      'Inception models for evaluation. [64]')
55 | flags.DEFINE_integer('shuffle_buffer_size', 5000, 'Number of records to load '
56 | 'before shuffling and yielding for consumption. [5000]')
57 | flags.DEFINE_integer('dcgan_generator_batch_size', 100, 'Size of batch to feed '
58 | 'into generator -- we may stack multiple of these later.')
59 | flags.DEFINE_integer('eval_sample_size', 50000,
60 | 'Number of samples to sample from '
61 |                      'generator and real data. [50000]')
62 | flags.DEFINE_boolean('is_train', False, 'Use DCGAN only for evaluation.')
63 |
64 | flags.DEFINE_integer('task', 0, 'The task id of the current worker. [0]')
65 | flags.DEFINE_integer('ps_tasks', 0, 'The number of ps tasks. [0]')
66 | flags.DEFINE_integer('num_workers', 1, 'The number of worker tasks. [1]')
67 | flags.DEFINE_integer('replicas_to_aggregate', 1, 'The number of replicas '
68 | 'to aggregate for synchronous optimization [1]')
69 |
70 | flags.DEFINE_integer('num_towers', 1, 'The number of GPUs to use per task. [1]')
71 | flags.DEFINE_integer('eval_interval_secs', 300,
72 | 'Frequency of generator evaluation with Inception score '
73 | 'and Frechet Inception Distance. [300]')
74 |
75 | flags.DEFINE_integer('num_classes', 1000, 'The number of classes in the dataset')
76 | flags.DEFINE_string('generator_type', 'test', 'test or baseline')
77 |
78 | FLAGS = flags.FLAGS
79 |
80 |
81 | def main(_):
82 | model_dir = '%s_%s' % ('imagenet', FLAGS.batch_size)
83 | FLAGS.eval_dir = FLAGS.checkpoint_dir + '/eval'
84 | checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, model_dir)
85 | log_dir = os.path.join(FLAGS.eval_dir, model_dir)
86 | print('log_dir', log_dir)
87 |   graph_def = None
88 |
89 | # Batch size to feed batches of images through Inception and the generator
90 | # to extract feature vectors to later stack together and compute metrics.
91 | local_batch_size = FLAGS.dcgan_generator_batch_size
92 | if FLAGS.generator_type == 'baseline':
93 | generator_fn = generator_module.generator
94 | elif FLAGS.generator_type == 'test':
95 | generator_fn = generator_module.generator_test
96 | else:
97 | raise NotImplementedError
98 | if FLAGS.num_towers != 1 or FLAGS.num_workers != 1:
99 | raise NotImplementedError(
100 | 'The eval job does not currently support using multiple GPUs')
101 |
102 | # Get activations from real images.
103 | with tf.device('/device:CPU:1'):
104 | real_pools, real_images = utils.get_real_activations(
105 | FLAGS.data_dir,
106 | local_batch_size,
107 | FLAGS.eval_sample_size // local_batch_size,
108 | label_offset=-1,
109 | shuffle_buffer_size=FLAGS.shuffle_buffer_size)
110 |
111 | num_classes = FLAGS.num_classes
112 | gen_class_logits = tf.zeros((local_batch_size, num_classes))
113 | gen_class_ints = tf.multinomial(gen_class_logits, 1)
114 | gen_sparse_class = tf.squeeze(gen_class_ints)
115 |
116 |
117 |
118 | # Generate the first batch of generated images and extract activations;
119 | # this bootstraps the while_loop with a pools and logits tensor.
120 |
121 |
122 | test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim)
123 | generator = generator_fn(
124 | test_zs[0],
125 | gen_sparse_class,
126 | FLAGS.gf_dim,
127 | FLAGS.num_classes,
128 | is_training=False)
129 |
130 |
131 |
132 | pools, logits = utils.run_custom_inception(
133 | generator, output_tensor=['pool_3:0', 'logits:0'], graph_def=graph_def)
134 |
135 | # Set up while_loop to compute activations of generated images from generator.
136 | def while_cond(g_pools, g_logits, i): # pylint: disable=unused-argument
137 | return tf.less(i, FLAGS.eval_sample_size // local_batch_size)
138 |
139 | # We use a while loop because we want to generate a batch of images
140 | # and then feed that batch through Inception to retrieve the activations.
141 | # Otherwise, if we generate all the samples first and then compute all the
142 | # activations, we will run out of memory.
143 | def while_body(g_pools, g_logits, i):
144 | with tf.control_dependencies([g_pools, g_logits]):
145 |
146 | test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim)
147 | # Uniform distribution
148 | gen_class_logits = tf.zeros((local_batch_size, num_classes))
149 | gen_class_ints = tf.multinomial(gen_class_logits, 1)
150 | gen_sparse_class = tf.squeeze(gen_class_ints)
151 |
152 | generator = generator_fn(
153 | test_zs[0],
154 | gen_sparse_class,
155 | FLAGS.gf_dim,
156 | FLAGS.num_classes,
157 | is_training=False)
158 |
159 | pools, logits = utils.run_custom_inception(
160 | generator,
161 | output_tensor=['pool_3:0', 'logits:0'],
162 | graph_def=graph_def)
163 | g_pools = tf.concat([g_pools, pools], 0)
164 | g_logits = tf.concat([g_logits, logits], 0)
165 |
166 | return (g_pools, g_logits, tf.add(i, 1))
167 |
168 | # Get the activations
169 | i = tf.constant(1)
170 | new_generator_pools_list, new_generator_logits_list, _ = tf.while_loop(
171 | while_cond,
172 | while_body, [pools, logits, i],
173 | shape_invariants=[
174 | tf.TensorShape([None, 2048]),
175 | tf.TensorShape([None, 1008]),
176 | i.get_shape()
177 | ],
178 | parallel_iterations=1,
179 | back_prop=False,
180 | swap_memory=True,
181 | name='GeneratedActivations')
182 |
183 | new_generator_pools_list.set_shape([FLAGS.eval_sample_size, 2048])
184 | new_generator_logits_list.set_shape([FLAGS.eval_sample_size, 1008])
185 |
186 |   # Get a small batch of samples from the generator to display in TensorBoard.
187 | vis_batch_size = 16
188 | eval_vis_zs = utils.make_z_normal(
189 | 1, vis_batch_size, FLAGS.z_dim)
190 |
191 | gen_class_logits_vis = tf.zeros((vis_batch_size, num_classes))
192 | gen_class_ints_vis = tf.multinomial(gen_class_logits_vis, 1)
193 | gen_sparse_class_vis = tf.squeeze(gen_class_ints_vis)
194 |
195 | eval_vis_images = generator_fn(
196 | eval_vis_zs[0],
197 | gen_sparse_class_vis,
198 | FLAGS.gf_dim,
199 | FLAGS.num_classes,
200 | is_training=False
201 | )
202 | eval_vis_images = tf.cast((eval_vis_images + 1.) * 127.5, tf.uint8)
203 |
204 | with tf.variable_scope('eval_vis'):
205 | tf.summary.image('generated_images', eval_vis_images)
206 | tf.summary.image('real_images', real_images)
207 | tf.summary.image('real_images_grid',
208 | tfgan.eval.image_grid(
209 | real_images[:16],
210 | grid_shape=utils.squarest_grid_size(16),
211 | image_shape=(128, 128)))
212 | tf.summary.image('generated_images_grid',
213 | tfgan.eval.image_grid(
214 | eval_vis_images[:16],
215 | grid_shape=utils.squarest_grid_size(16),
216 | image_shape=(128, 128)))
217 |
218 | # Use the activations from the real images and generated images to compute
219 | # Inception score and FID.
220 | generated_logits = tf.concat(new_generator_logits_list, 0)
221 | generated_pools = tf.concat(new_generator_pools_list, 0)
222 |
223 | # Compute Frechet Inception Distance and Inception score
224 | incscore = tfgan.eval.classifier_score_from_logits(generated_logits)
225 | fid = tfgan.eval.frechet_classifier_distance_from_activations(
226 | real_pools, generated_pools)
227 |
228 | with tf.variable_scope('eval'):
229 | tf.summary.scalar('fid', fid)
230 | tf.summary.scalar('incscore', incscore)
231 |
232 | session_config = tf.ConfigProto(
233 | allow_soft_placement=True, log_device_placement=False)
234 |
235 | tf.contrib.training.evaluate_repeatedly(
236 | checkpoint_dir=checkpoint_dir,
237 | hooks=[
238 | tf.contrib.training.SummaryAtEndHook(log_dir),
239 | tf.contrib.training.StopAfterNEvalsHook(1)
240 | ],
241 | config=session_config)
242 |
243 |
244 | if __name__ == '__main__':
245 | tf.app.run()
246 |
--------------------------------------------------------------------------------
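
For reference, `tfgan.eval.frechet_classifier_distance_from_activations` used above computes the Fréchet distance between Gaussians fitted to the two sets of pool_3 activations; an equivalent NumPy sketch:

```python
import numpy as np
from scipy import linalg

def fid_from_activations(real, fake):
  # real, fake: [N, 2048] Inception pool_3 activations.
  mu_r, mu_f = real.mean(axis=0), fake.mean(axis=0)
  cov_r = np.cov(real, rowvar=False)
  cov_f = np.cov(fake, rowvar=False)
  covmean = linalg.sqrtm(cov_r.dot(cov_f))
  if np.iscomplexobj(covmean):
    covmean = covmean.real  # drop tiny imaginary parts from numerical error
  diff = mu_r - mu_f
  return diff.dot(diff) + np.trace(cov_r + cov_f - 2.0 * covmean)
```
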
/generator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """The generator of SNGAN."""
17 |
18 | import tensorflow as tf
19 | import ops
20 | import non_local
21 |
22 |
23 | def upscale(x, n):
24 | """Builds box upscaling (also called nearest neighbors).
25 |
26 | Args:
27 | x: 4D image tensor in B01C format.
28 | n: integer scale (must be a power of 2).
29 |
30 | Returns:
31 | 4D tensor of images up scaled by a factor n.
32 | """
33 | if n == 1:
34 | return x
35 | return tf.batch_to_space(tf.tile(x, [n**2, 1, 1, 1]), [[0, 0], [0, 0]], n)
36 |
37 |
38 | def usample_tpu(x):
39 | """Upscales the width and height of the input vector by a factor of 2."""
40 | x = upscale(x, 2)
41 | return x
42 |
43 | def usample(x):  # Upsamples H and W by a factor of 2 (nearest neighbor).
44 | _, nh, nw, nx = x.get_shape().as_list()
45 | x = tf.image.resize_nearest_neighbor(x, [nh * 2, nw * 2])
46 | return x
47 |
48 | def block_no_sn(x, labels, out_channels, num_classes, is_training, name):
49 |   """Builds a residual block (without spectral normalization) for the generator.
50 | 
51 |   The block upsamples the spatial resolution of the input tensor by a factor
52 |   of 2 in each dimension.
53 |
54 | Args:
55 | x: The 4D input vector.
56 | labels: The conditional labels in the generation.
57 | out_channels: Number of features in the output layer.
58 | num_classes: Number of classes in the labels.
59 | name: The variable scope name for the block.
60 | Returns:
61 | A `Tensor` representing the output of the operation.
62 | """
63 | with tf.variable_scope(name):
64 | bn0 = ops.ConditionalBatchNorm(num_classes, name='cbn_0')
65 | bn1 = ops.ConditionalBatchNorm(num_classes, name='cbn_1')
66 | x_0 = x
67 | x = tf.nn.relu(bn0(x, labels, is_training))
68 | x = usample(x)
69 | x = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv1')
70 | x = tf.nn.relu(bn1(x, labels, is_training))
71 | x = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv2')
72 |
73 | x_0 = usample(x_0)
74 | x_0 = ops.conv2d(x_0, out_channels, 1, 1, 1, 1, name='conv3')
75 |
76 | return x_0 + x
77 |
78 | def block(x, labels, out_channels, num_classes, is_training, name):  # Spectrally normalized residual block; upsamples by 2.
79 | with tf.variable_scope(name):
80 | bn0 = ops.ConditionalBatchNorm(num_classes, name='cbn_0')
81 | bn1 = ops.ConditionalBatchNorm(num_classes, name='cbn_1')
82 | x_0 = x
83 | x = tf.nn.relu(bn0(x, labels, is_training))
84 | x = usample(x)
85 | x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='snconv1')
86 | x = tf.nn.relu(bn1(x, labels, is_training))
87 | x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='snconv2')
88 |
89 | x_0 = usample(x_0)
90 | x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1, name='snconv3')
91 |
92 | return x_0 + x
93 |
94 |
95 | def generator_old(zs,
96 | target_class,
97 | gf_dim,
98 | num_classes,
99 | is_training=True,
100 | scope='Generator'):
101 | """Builds the generator graph propagating from z to x.
102 |
103 | Args:
104 | zs: The list of noise tensors.
105 | target_class: The conditional labels in the generation.
106 | gf_dim: The gf dimension.
107 | num_classes: Number of classes in the labels.
108 | scope: Optional scope for `variable_op_scope`.
109 |
110 | Returns:
111 | outputs: The output layer of the generator.
112 | """
113 |
114 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
115 | # project `z` and reshape
116 | act0 = ops.linear(zs, gf_dim * 16 * 4 * 4, scope='g_h0')
117 | act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])
118 |
119 | act1 = block_no_sn(act0, target_class, gf_dim * 16,
120 | num_classes, is_training, 'g_block1') # 8 * 8
121 | act2 = block_no_sn(act1, target_class, gf_dim * 8,
122 | num_classes, is_training, 'g_block2') # 16 * 16
123 | act3 = block_no_sn(act2, target_class, gf_dim * 4,
124 | num_classes, is_training, 'g_block3') # 32 * 32
125 | act4 = block_no_sn(act3, target_class, gf_dim * 2,
126 | num_classes, is_training, 'g_block4') # 64 * 64
127 | act5 = block_no_sn(act4, target_class, gf_dim,
128 | num_classes, is_training, 'g_block5') # 128 * 128
129 | bn = ops.batch_norm(name='g_bn')
130 |
131 | act5 = tf.nn.relu(bn(act5, is_training))
132 | act6 = ops.conv2d(act5, 3, 3, 3, 1, 1, name='g_conv_last')
133 | out = tf.nn.tanh(act6)
134 | print('GAN baseline with moving average')
135 | return out
136 |
137 | def generator(zs,
138 | target_class,
139 | gf_dim,
140 | num_classes,
141 | is_training=True):
142 | """Builds the generator graph propagating from z to x.
143 |
144 | Args:
145 | zs: The list of noise tensors.
146 | target_class: The conditional labels in the generation.
147 | gf_dim: The gf dimension.
148 | num_classes: Number of classes in the labels.
149 |     is_training: Whether the graph is built for training (affects batch norm).
150 |
151 | Returns:
152 | outputs: The output layer of the generator.
153 | """
154 |
155 | with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
156 | # project `z` and reshape
157 | act0 = ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0')
158 | act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])
159 |
160 | act1 = block(act0, target_class, gf_dim * 16,
161 | num_classes, is_training, 'g_block1') # 8 * 8
162 | act2 = block(act1, target_class, gf_dim * 8,
163 | num_classes, is_training, 'g_block2') # 16 * 16
164 | act3 = block(act2, target_class, gf_dim * 4,
165 | num_classes, is_training, 'g_block3') # 32 * 32
166 | act4 = block(act3, target_class, gf_dim * 2,
167 | num_classes, is_training, 'g_block4') # 64 * 64
168 | act5 = block(act4, target_class, gf_dim,
169 | num_classes, is_training, 'g_block5') # 128 * 128
170 | bn = ops.batch_norm(name='g_bn')
171 |
172 | act5 = tf.nn.relu(bn(act5, is_training))
173 | act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last')
174 | out = tf.nn.tanh(act6)
175 | print('Generator Structure')
176 | return out
177 |
178 |
179 | def generator_test(zs,
180 | target_class,
181 | gf_dim,
182 | num_classes,
183 | is_training=True):
184 | """Builds the generator graph propagating from z to x.
185 |
186 | Args:
187 | zs: The list of noise tensors.
188 | target_class: The conditional labels in the generation.
189 | gf_dim: The gf dimension.
190 | num_classes: Number of classes in the labels.
191 |     is_training: Whether the graph is built for training (affects batch norm).
192 |
193 | Returns:
194 | outputs: The output layer of the generator.
195 | """
196 |
197 | with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
198 | # project `z` and reshape
199 | act0 = ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0')
200 | act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])
201 |
202 | act1 = block(act0, target_class, gf_dim * 16,
203 | num_classes, is_training, 'g_block1') # 8 * 8
204 | act2 = block(act1, target_class, gf_dim * 8,
205 | num_classes, is_training, 'g_block2') # 16 * 16
206 | act3 = block(act2, target_class, gf_dim * 4,
207 | num_classes, is_training, 'g_block3') # 32 * 32
208 | act3 = non_local.sn_non_local_block_sim(act3, None, name='g_non_local')
209 | act4 = block(act3, target_class, gf_dim * 2,
210 | num_classes, is_training, 'g_block4') # 64 * 64
211 | act5 = block(act4, target_class, gf_dim,
212 | num_classes, is_training, 'g_block5') # 128 * 128
213 | bn = ops.batch_norm(name='g_bn')
214 |
215 | act5 = tf.nn.relu(bn(act5, is_training))
216 | act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last')
217 | out = tf.nn.tanh(act6)
218 | print('Generator TEST structure')
219 | return out
220 |
221 | def generator_test_64(zs,
222 | target_class,
223 | gf_dim,
224 | num_classes,
225 | is_training=True):
226 | """Builds the generator graph propagating from z to x.
227 |
228 | Args:
229 | zs: The list of noise tensors.
230 | target_class: The conditional labels in the generation.
231 | gf_dim: The gf dimension.
232 | num_classes: Number of classes in the labels.
233 |     is_training: Whether the graph is built for training (affects batch norm).
234 |
235 | Returns:
236 | outputs: The output layer of the generator.
237 | """
238 |
239 | with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
240 | # project `z` and reshape
241 | act0 = ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0')
242 | act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])
243 |
244 | act1 = block(act0, target_class, gf_dim * 16,
245 | num_classes, is_training, 'g_block1') # 8 * 8
246 | act2 = block(act1, target_class, gf_dim * 8,
247 | num_classes, is_training, 'g_block2') # 16 * 16
248 | act3 = block(act2, target_class, gf_dim * 4,
249 | num_classes, is_training, 'g_block3') # 32 * 32
250 |
251 | act4 = block(act3, target_class, gf_dim * 2,
252 | num_classes, is_training, 'g_block4') # 64 * 64
253 | act4 = non_local.sn_non_local_block_sim(act4, None, name='g_non_local')
254 | act5 = block(act4, target_class, gf_dim,
255 | num_classes, is_training, 'g_block5') # 128 * 128
256 | bn = ops.batch_norm(name='g_bn')
257 |
258 | act5 = tf.nn.relu(bn(act5, is_training))
259 | act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last')
260 | out = tf.nn.tanh(act6)
261 |     print('Generator TEST 64 structure')
262 | return out
263 |
264 |
--------------------------------------------------------------------------------
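
A minimal sampling sketch for the generators above (values mirror the defaults used in `eval_imagenet.py`; variable names are illustrative):

```python
import tensorflow as tf
import generator as generator_module

batch_size, z_dim, gf_dim, num_classes = 16, 128, 64, 1000
zs = tf.random_normal([batch_size, z_dim])
# Uniform random class ids, as in eval_imagenet.py.
labels = tf.squeeze(tf.multinomial(tf.zeros([batch_size, num_classes]), 1))
images = generator_module.generator_test(zs, labels, gf_dim, num_classes,
                                         is_training=False)
images_uint8 = tf.cast((images + 1.) * 127.5, tf.uint8)  # [16, 128, 128, 3]
```
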
/imgs/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brain-research/self-attention-gan/ad9612e60f6ba2b5ad3d3340ebae60f724636d75/imgs/img1.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """The DCGAN Model."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from absl import flags
22 | import tensorflow as tf
23 |
24 | import discriminator as disc
25 | import generator as generator_module
26 |
27 | import ops
28 | import utils_ori as utils
29 |
30 |
31 | tfgan = tf.contrib.gan
32 |
33 | flags.DEFINE_string(
35 |     'data_dir', '/bigdata1/hz138/Data/imagenet',
36 |     'Directory with ImageNet input data as sharded tfrecord files of pre-'
37 |     'processed images.')
38 | flags.DEFINE_float('discriminator_learning_rate', 0.0004,
39 |                    'Learning rate for Adam. [0.0004]')
40 | flags.DEFINE_float('generator_learning_rate', 0.0001,
41 |                    'Learning rate for Adam. [0.0001]')
42 | flags.DEFINE_float('beta1', 0.0, 'Momentum term of Adam. [0.0]')
43 | flags.DEFINE_integer('image_size', 128, 'The size of image to use '
44 | '(will be center cropped) [128]')
45 | flags.DEFINE_integer('image_width', 128,
46 | 'The width of the images presented to the model')
47 | flags.DEFINE_integer('data_parallelism', 64, 'The number of objects to read at'
48 | ' one time when loading input data. [64]')
49 | flags.DEFINE_integer('z_dim', 128, 'Dimensionality of latent code z. [128]')
50 | flags.DEFINE_integer('gf_dim', 64, 'Dimensionality of gf. [64]')
51 | flags.DEFINE_integer('df_dim', 64, 'Dimensionality of df. [64]')
52 | flags.DEFINE_integer('number_classes', 1000, 'The number of classes in the dataset')
53 | flags.DEFINE_string('loss_type', 'hinge_loss', 'the loss type can be'
54 | ' hinge_loss or kl_loss')
55 | flags.DEFINE_string('generator_type', 'test', 'test or baseline')
56 | flags.DEFINE_string('discriminator_type', 'test', 'test or baseline')
57 |
58 |
59 |
60 | FLAGS = flags.FLAGS
61 |
62 |
63 | def _get_d_real_loss(discriminator_on_data_logits):
64 | loss = tf.nn.relu(1.0 - discriminator_on_data_logits)
65 | return tf.reduce_mean(loss)
66 |
67 |
68 | def _get_d_fake_loss(discriminator_on_generator_logits):
69 | return tf.reduce_mean(tf.nn.relu(1 + discriminator_on_generator_logits))
70 |
71 |
72 | def _get_g_loss(discriminator_on_generator_logits):
73 | return -tf.reduce_mean(discriminator_on_generator_logits)
74 |
75 |
76 | def _get_d_real_loss_KL(discriminator_on_data_logits):
77 | loss = tf.nn.softplus(-discriminator_on_data_logits)
78 | return tf.reduce_mean(loss)
79 |
80 |
81 | def _get_d_fake_loss_KL(discriminator_on_generator_logits):
82 | return tf.reduce_mean(tf.nn.softplus(discriminator_on_generator_logits))
83 |
84 |
85 | def _get_g_loss_KL(discriminator_on_generator_logits):
86 | return tf.reduce_mean(-discriminator_on_generator_logits)
87 |
88 |
89 |
90 | class SNGAN(object):
91 | """SNGAN model."""
92 |
93 | def __init__(self, zs, config=None, global_step=None, devices=None):
94 |     """Initializes the SNGAN model.
95 |
96 | Args:
97 | zs: input noise tensors for the generator
98 | config: the configuration FLAGS object
99 | global_step: the global training step (maintained by the supervisor)
100 | devices: the list of device names to place ops on (multitower training)
101 | """
102 |
103 | self.config = config
104 | self.image_size = FLAGS.image_size
105 | self.image_shape = [FLAGS.image_size, FLAGS.image_size, 3]
106 | self.z_dim = FLAGS.z_dim
107 | self.gf_dim = FLAGS.gf_dim
108 | self.df_dim = FLAGS.df_dim
109 | self.num_classes = FLAGS.number_classes
110 |
111 | self.data_parallelism = FLAGS.data_parallelism
112 | self.zs = zs
113 |
114 | self.c_dim = 3
115 | self.dataset_name = 'imagenet'
116 | self.devices = devices
117 | self.global_step = global_step
118 |
119 | self.build_model()
120 |
121 | def build_model(self):
122 | """Builds a model."""
123 | config = self.config
124 | # If ps_tasks is zero, the local device is used. When using multiple
125 | # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
126 | # across the different devices.
127 | current_step = tf.cast(self.global_step, tf.float32)
128 | # g_ratio = (1.0 + 2e-5 * tf.maximum((current_step - 100000.0), 0.0))
129 | # g_ratio = tf.minimum(g_ratio, 4.0)
130 | self.d_learning_rate = FLAGS.discriminator_learning_rate
131 | self.g_learning_rate = FLAGS.generator_learning_rate
132 | # self.g_learning_rate = FLAGS.generator_learning_rate / (1.0 + 2e-5 * tf.cast(self.global_step, tf.float32))
133 | # self.g_learning_rate = FLAGS.generator_learning_rate / g_ratio
134 | with tf.device(tf.train.replica_device_setter(config.ps_tasks)):
135 | self.d_opt = tf.train.AdamOptimizer(
136 | self.d_learning_rate, beta1=FLAGS.beta1)
137 | self.g_opt = tf.train.AdamOptimizer(
138 | self.g_learning_rate, beta1=FLAGS.beta1)
139 | if config.sync_replicas and config.num_workers > 1:
140 | self.d_opt = tf.train.SyncReplicasOptimizer(
141 | opt=self.d_opt, replicas_to_aggregate=config.replicas_to_aggregate)
142 | self.g_opt = tf.train.SyncReplicasOptimizer(
143 | opt=self.g_opt, replicas_to_aggregate=config.replicas_to_aggregate)
144 |
145 | if config.num_towers > 1:
146 | all_d_grads = []
147 | all_g_grads = []
148 | for idx, device in enumerate(self.devices):
149 | with tf.device('/%s' % device):
150 | with tf.name_scope('device_%s' % idx):
151 | with ops.variables_on_gpu0():
152 | self.build_model_single_gpu(
153 | gpu_idx=idx,
154 | batch_size=config.batch_size,
155 | num_towers=config.num_towers)
156 | d_grads = self.d_opt.compute_gradients(self.d_losses[-1],
157 | var_list=self.d_vars)
158 | g_grads = self.g_opt.compute_gradients(self.g_losses[-1],
159 | var_list=self.g_vars)
160 | all_d_grads.append(d_grads)
161 | all_g_grads.append(g_grads)
162 | d_grads = ops.avg_grads(all_d_grads)
163 | g_grads = ops.avg_grads(all_g_grads)
164 | else:
165 | with tf.device(tf.train.replica_device_setter(config.ps_tasks)):
166 | # TODO(olganw): reusing virtual batchnorm doesn't work in the multi-
167 | # replica case.
168 | self.build_model_single_gpu(batch_size=config.batch_size,
169 | num_towers=config.num_towers)
170 | d_grads = self.d_opt.compute_gradients(self.d_losses[-1],
171 | var_list=self.d_vars)
172 | g_grads = self.g_opt.compute_gradients(self.g_losses[-1],
173 | var_list=self.g_vars)
174 | with tf.device(tf.train.replica_device_setter(config.ps_tasks)):
175 | update_moving_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
176 | print('update_moving_ops', update_moving_ops)
177 | if config.sync_replicas:
178 | with tf.control_dependencies(update_moving_ops):
179 | d_step = tf.get_variable('d_step', initializer=0, trainable=False)
180 | self.d_optim = self.d_opt.apply_gradients(d_grads, global_step=d_step)
181 | g_step = tf.get_variable('g_step', initializer=0, trainable=False)
182 | self.g_optim = self.g_opt.apply_gradients(g_grads, global_step=g_step)
183 | else:
184 | # Don't create any additional counters, and don't update the global step
185 | with tf.control_dependencies(update_moving_ops):
186 | self.d_optim = self.d_opt.apply_gradients(d_grads)
187 | self.g_optim = self.g_opt.apply_gradients(g_grads)
188 |
189 | def build_model_single_gpu(self, gpu_idx=0, batch_size=1, num_towers=1):
190 | """Builds a model for a single GPU.
191 |
192 | Args:
193 | gpu_idx: The index of the gpu in the tower.
194 | batch_size: The minibatch size. (Default: 1)
195 | num_towers: The total number of towers in this model. (Default: 1)
196 | """
197 | config = self.config
198 | show_num = min(config.batch_size, 64)
199 |
200 | reuse_vars = gpu_idx > 0
201 | if gpu_idx == 0:
202 | self.increment_global_step = self.global_step.assign_add(1)
203 | self.batches = utils.get_imagenet_batches(
204 | FLAGS.data_dir, batch_size, num_towers, label_offset=0,
205 | cycle_length=config.data_parallelism,
206 | shuffle_buffer_size=config.shuffle_buffer_size)
207 | sample_images, _ = self.batches[0]
208 | vis_images = tf.cast((sample_images + 1.) * 127.5, tf.uint8)
209 | tf.summary.image('input_image_grid',
210 | tfgan.eval.image_grid(
211 | vis_images[:show_num],
212 | grid_shape=utils.squarest_grid_size(
213 | show_num),
214 | image_shape=(128, 128)))
215 |
216 | images, sparse_labels = self.batches[gpu_idx]
217 | sparse_labels = tf.squeeze(sparse_labels)
218 |     print('sparse_labels.shape', sparse_labels.shape)
219 |
220 | gen_class_logits = tf.zeros((batch_size, self.num_classes))
221 | gen_class_ints = tf.multinomial(gen_class_logits, 1)
222 |     # NOTE: tf.argmax(gen_class_ints, axis=1) would be wrong here; tf.multinomial already returns sampled class ids.
223 | gen_sparse_class = tf.squeeze(gen_class_ints)
224 |     print('gen_sparse_class.shape', gen_sparse_class.shape)
225 | assert len(gen_class_ints.get_shape()) == 2
226 | gen_class_ints = tf.squeeze(gen_class_ints)
227 | assert len(gen_class_ints.get_shape()) == 1
228 | gen_class_vector = tf.one_hot(gen_class_ints, self.num_classes)
229 | assert len(gen_class_vector.get_shape()) == 2
230 | assert gen_class_vector.dtype == tf.float32
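    # Sampling sketch: with all-zero (uniform) logits, tf.multinomial draws one
    # class id per row, so gen_sparse_class holds [batch_size] integer labels
    # and gen_class_vector is their [batch_size, num_classes] one-hot encoding.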
231 |
232 | if FLAGS.generator_type == 'baseline':
233 | generator_fn = generator_module.generator
234 | elif FLAGS.generator_type == 'test':
235 | generator_fn = generator_module.generator_test
236 |     else: raise NotImplementedError
237 | generator = generator_fn(
238 | self.zs[gpu_idx],
239 | gen_sparse_class,
240 | self.gf_dim,
241 | self.num_classes
242 | )
243 |
244 |
245 | if gpu_idx == 0:
246 | generator_means = tf.reduce_mean(generator, 0, keep_dims=True)
247 | generator_vars = tf.reduce_mean(
248 | tf.squared_difference(generator, generator_means), 0, keep_dims=True)
249 | generator = tf.Print(
250 | generator,
251 | [tf.reduce_mean(generator_means), tf.reduce_mean(generator_vars)],
252 | 'generator mean and average var', first_n=1)
253 | image_means = tf.reduce_mean(images, 0, keep_dims=True)
254 | image_vars = tf.reduce_mean(
255 | tf.squared_difference(images, image_means), 0, keep_dims=True)
256 | images = tf.Print(
257 | images, [tf.reduce_mean(image_means), tf.reduce_mean(image_vars)],
258 | 'image mean and average var', first_n=1)
259 | sparse_labels = tf.Print(sparse_labels, [sparse_labels, sparse_labels.shape], 'sparse_labels', first_n=2)
260 | gen_sparse_class = tf.Print(gen_sparse_class, [gen_sparse_class, gen_sparse_class.shape], 'gen_sparse_labels', first_n=2)
261 |
262 |     if gpu_idx == 0:
263 |       self.generators = []
264 |     self.generators.append(generator)
265 |
266 | if FLAGS.discriminator_type == 'baseline':
267 | discriminator_fn = disc.discriminator
268 | elif FLAGS.discriminator_type == 'test':
269 | discriminator_fn = disc.discriminator_test
270 | else:
271 | raise NotImplementedError
272 | discriminator_on_data_logits = discriminator_fn(images, sparse_labels, self.df_dim, self.num_classes,
273 | update_collection=None)
274 | discriminator_on_generator_logits = discriminator_fn(generator, gen_sparse_class, self.df_dim, self.num_classes,
275 | update_collection="NO_OPS")
276 |
277 | vis_generator = tf.cast((generator + 1.) * 127.5, tf.uint8)
278 | tf.summary.image('generator', vis_generator)
279 |
280 | tf.summary.image('generator_grid',
281 | tfgan.eval.image_grid(
282 | vis_generator[:show_num],
283 | grid_shape=utils.squarest_grid_size(show_num),
284 | image_shape=(128, 128)))
285 |
286 | if FLAGS.loss_type == 'hinge_loss':
287 | d_loss_real = _get_d_real_loss(
288 | discriminator_on_data_logits)
289 | d_loss_fake = _get_d_fake_loss(discriminator_on_generator_logits)
290 | g_loss_gan = _get_g_loss(discriminator_on_generator_logits)
291 |       print('using hinge loss')
292 | elif FLAGS.loss_type == 'kl_loss':
293 | d_loss_real = _get_d_real_loss_KL(
294 | discriminator_on_data_logits)
295 | d_loss_fake = _get_d_fake_loss_KL(discriminator_on_generator_logits)
296 | g_loss_gan = _get_g_loss_KL(discriminator_on_generator_logits)
297 |       print('using kl loss')
298 | else:
299 | raise NotImplementedError
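    # For reference, the hinge formulation these helpers are assumed to
    # implement (standard in SN-GAN/SAGAN; the _get_* functions are defined
    # earlier in this file):
    #   d_loss_real = E[relu(1 - D(x, y))]
    #   d_loss_fake = E[relu(1 + D(G(z), y))]
    #   g_loss      = -E[D(G(z), y)]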
300 |
301 |
302 | d_loss = d_loss_real + d_loss_fake
303 | g_loss = g_loss_gan
304 |
305 |
306 | # add logit log
307 | logit_discriminator_on_data = tf.reduce_mean(discriminator_on_data_logits)
308 | logit_discriminator_on_generator = tf.reduce_mean(
309 | discriminator_on_generator_logits)
310 |
311 |
312 |
313 | # Add summaries.
314 | tf.summary.scalar('d_loss', d_loss)
315 | tf.summary.scalar('d_loss_real', d_loss_real)
316 | tf.summary.scalar('d_loss_fake', d_loss_fake)
317 | tf.summary.scalar('g_loss', g_loss)
318 | tf.summary.scalar('logit_real', logit_discriminator_on_data)
319 | tf.summary.scalar('logit_fake', logit_discriminator_on_generator)
320 | tf.summary.scalar('d_learning_rate', self.d_learning_rate)
321 | tf.summary.scalar('g_learning_rate', self.g_learning_rate)
322 |
323 |
324 |
325 | if gpu_idx == 0:
326 | self.d_loss_reals = []
327 | self.d_loss_fakes = []
328 | self.d_losses = []
329 | self.g_losses = []
330 | self.d_loss_reals.append(d_loss_real)
331 | self.d_loss_fakes.append(d_loss_fake)
332 | self.d_losses.append(d_loss)
333 | self.g_losses.append(g_loss)
334 |
335 | if gpu_idx == 0:
336 | self.get_vars()
337 | print('gvars', self.g_vars)
338 | print('dvars', self.d_vars)
339 | print('sigma_ratio_vars', self.sigma_ratio_vars)
340 | for var in self.sigma_ratio_vars:
341 | tf.summary.scalar(var.name, var)
342 |
343 | def get_vars(self):
344 | """Get variables."""
345 | t_vars = tf.trainable_variables()
346 | # TODO(olganw): scoping or collections for this instead of name hack
347 | self.d_vars = [var for var in t_vars if var.name.startswith('model/d_')]
348 | self.g_vars = [var for var in t_vars if var.name.startswith('model/g_')]
349 | self.sigma_ratio_vars = [var for var in t_vars if 'sigma_ratio' in var.name]
350 | for x in self.d_vars:
351 | assert x not in self.g_vars
352 | for x in self.g_vars:
353 | assert x not in self.d_vars
354 | for x in t_vars:
355 | assert x in self.g_vars or x in self.d_vars, x.name
356 | self.all_vars = t_vars
357 |
--------------------------------------------------------------------------------
/non_local.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | import tensorflow as tf
17 | import numpy as np
18 | import ops
19 |
20 | def conv1x1(input_, output_dim,
21 | init=tf.contrib.layers.xavier_initializer(), name='conv1x1'):
22 | k_h = 1
23 | k_w = 1
24 | d_h = 1
25 | d_w = 1
26 | with tf.variable_scope(name):
27 | w = tf.get_variable(
28 | 'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
29 | initializer=init)
30 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
31 | return conv
32 |
33 | def sn_conv1x1(input_, output_dim, update_collection,
34 | init=tf.contrib.layers.xavier_initializer(), name='sn_conv1x1'):
35 | with tf.variable_scope(name):
36 | k_h = 1
37 | k_w = 1
38 | d_h = 1
39 | d_w = 1
40 | w = tf.get_variable(
41 | 'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
42 | initializer=init)
43 | w_bar = ops.spectral_normed_weight(w, num_iters=1, update_collection=update_collection)
44 |
45 | conv = tf.nn.conv2d(input_, w_bar, strides=[1, d_h, d_w, 1], padding='SAME')
46 | return conv
47 |
48 | def sn_non_local_block_sim(x, update_collection, name, init=tf.contrib.layers.xavier_initializer()):
49 | with tf.variable_scope(name):
50 | batch_size, h, w, num_channels = x.get_shape().as_list()
51 | location_num = h * w
52 | downsampled_num = location_num // 4
53 |
54 | # theta path
55 | theta = sn_conv1x1(x, num_channels // 8, update_collection, init, 'sn_conv_theta')
56 | theta = tf.reshape(
57 | theta, [batch_size, location_num, num_channels // 8])
58 |
59 | # phi path
60 | phi = sn_conv1x1(x, num_channels // 8, update_collection, init, 'sn_conv_phi')
61 | phi = tf.layers.max_pooling2d(inputs=phi, pool_size=[2, 2], strides=2)
62 | phi = tf.reshape(
63 | phi, [batch_size, downsampled_num, num_channels // 8])
64 |
65 |
66 | attn = tf.matmul(theta, phi, transpose_b=True)
67 | attn = tf.nn.softmax(attn)
68 |     # Each row of attn sums to 1 after the softmax.
69 |
70 | # g path
71 | g = sn_conv1x1(x, num_channels // 2, update_collection, init, 'sn_conv_g')
72 | g = tf.layers.max_pooling2d(inputs=g, pool_size=[2, 2], strides=2)
73 | g = tf.reshape(
74 | g, [batch_size, downsampled_num, num_channels // 2])
75 |
76 | attn_g = tf.matmul(attn, g)
77 | attn_g = tf.reshape(attn_g, [batch_size, h, w, num_channels // 2])
78 | sigma = tf.get_variable(
79 | 'sigma_ratio', [], initializer=tf.constant_initializer(0.0))
80 | attn_g = sn_conv1x1(attn_g, num_channels, update_collection, init, 'sn_conv_attn')
81 | return x + sigma * attn_g
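
# Shape walkthrough (illustrative sketch, not part of the original file): for a
# hypothetical [8, 32, 32, 256] input, theta is [8, 1024, 32], phi and g are
# max-pooled to [8, 256, 32] and [8, 256, 128], attn is [8, 1024, 256], and
# attn_g is reshaped to [8, 32, 32, 128] before the final 1x1 conv restores
# 256 channels. Since sigma starts at 0, the block begins as an identity map.
def _non_local_block_example():
  x = tf.zeros([8, 32, 32, 256])  # assumed feature map from a conv layer
  y = sn_non_local_block_sim(x, update_collection=None, name='demo_attn')
  return y  # same [8, 32, 32, 256] shape as x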
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """The building block ops for Spectral Normalization GAN."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 | from contextlib import contextmanager
24 |
25 |
26 | rng = np.random.RandomState([2016, 6, 1])
27 |
28 |
29 | def conv2d(input_, output_dim,
30 | k_h=3, k_w=3, d_h=2, d_w=2, name='conv2d'):
31 | """Creates convolutional layers which use xavier initializer.
32 |
33 | Args:
34 | input_: 4D input tensor (batch size, height, width, channel).
35 | output_dim: Number of features in the output layer.
36 | k_h: The height of the convolutional kernel.
37 | k_w: The width of the convolutional kernel.
38 | d_h: The height stride of the convolutional kernel.
39 | d_w: The width stride of the convolutional kernel.
40 | name: The name of the variable scope.
41 | Returns:
42 |     conv: The convolution output tensor with bias added.
43 | """
44 | with tf.variable_scope(name):
45 | w = tf.get_variable(
46 | 'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
47 | initializer=tf.contrib.layers.xavier_initializer())
48 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
49 |
50 | biases = tf.get_variable('biases', [output_dim],
51 | initializer=tf.zeros_initializer())
52 | conv = tf.nn.bias_add(conv, biases)
53 | return conv
54 |
55 |
56 | def deconv2d(input_, output_shape,
57 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
58 | name='deconv2d', init_bias=0.):
59 | """Creates deconvolutional layers.
60 |
61 | Args:
62 | input_: 4D input tensor (batch size, height, width, channel).
63 |     output_shape: The full output shape [batch, height, width, channels].
64 | k_h: The height of the convolutional kernel.
65 | k_w: The width of the convolutional kernel.
66 | d_h: The height stride of the convolutional kernel.
67 | d_w: The width stride of the convolutional kernel.
68 | stddev: The standard deviation for weights initializer.
69 | name: The name of the variable scope.
70 | init_bias: The initial bias for the layer.
71 | Returns:
72 |     deconv: The deconvolution output tensor with bias added.
73 | """
74 | with tf.variable_scope(name):
75 | w = tf.get_variable('w',
76 | [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
77 | initializer=tf.random_normal_initializer(stddev=stddev))
78 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
79 | strides=[1, d_h, d_w, 1])
80 | biases = tf.get_variable('biases', [output_shape[-1]],
81 | initializer=tf.constant_initializer(init_bias))
82 | deconv = tf.nn.bias_add(deconv, biases)
83 | deconv.shape.assert_is_compatible_with(output_shape)
84 |
85 | return deconv
86 |
87 |
88 | def linear(x, output_size, scope=None, bias_start=0.0):
89 | """Creates a linear layer.
90 |
91 | Args:
92 | x: 2D input tensor (batch size, features)
93 | output_size: Number of features in the output layer
94 | scope: Optional, variable scope to put the layer's parameters into
95 | bias_start: The bias parameters are initialized to this value
96 |
97 | Returns:
98 |     The output tensor of the linear layer.
99 | """
100 | shape = x.get_shape().as_list()
101 |
102 | with tf.variable_scope(scope or 'Linear'):
103 | matrix = tf.get_variable(
104 | 'Matrix', [shape[1], output_size], tf.float32,
105 | tf.contrib.layers.xavier_initializer())
106 | bias = tf.get_variable(
107 | 'bias', [output_size], initializer=tf.constant_initializer(bias_start))
108 | out = tf.matmul(x, matrix) + bias
109 | return out
110 |
111 |
112 | def lrelu(x, leak=0.2, name='lrelu'):
113 | """The leaky RELU operation."""
114 | with tf.variable_scope(name):
115 | f1 = 0.5 * (1 + leak)
116 | f2 = 0.5 * (1 - leak)
117 | return f1 * x + f2 * abs(x)
118 |
119 |
120 | def _l2normalize(v, eps=1e-12):
121 | """l2 normize the input vector."""
122 | return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)
123 |
124 |
125 | def spectral_normed_weight(weights, num_iters=1, update_collection=None,
126 | with_sigma=False):
127 | """Performs Spectral Normalization on a weight tensor.
128 |
129 | Specifically it divides the weight tensor by its largest singular value. This
130 | is intended to stabilize GAN training, by making the discriminator satisfy a
131 | local 1-Lipschitz constraint.
132 | Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan]
133 | [sn-gan] https://openreview.net/pdf?id=B1QRgziT-
134 |
135 | Args:
136 | weights: The weight tensor which requires spectral normalization
137 | num_iters: Number of SN iterations.
138 |     update_collection: The update collection for the persisted variable u.
139 |                        If None, the function updates u during the forward
140 |                        pass. If update_collection equals 'NO_OPS', the
141 |                        function does not update u during the forward pass;
142 |                        this is useful for the discriminator, which should
143 |                        not update u again on its second (fake-batch) pass.
144 |                        Otherwise, the assignment op is added to the
145 |                        collection named by update_collection, and the user
146 |                        needs to run that assignment explicitly.
147 |     with_sigma: For debugging purposes. If True, the function also returns
148 |                 the estimated singular value of the weight tensor.
149 | Returns:
150 | w_bar: The normalized weight tensor
151 | sigma: The estimated singular value for the weight tensor.
152 | """
153 | w_shape = weights.shape.as_list()
154 | w_mat = tf.reshape(weights, [-1, w_shape[-1]]) # [-1, output_channel]
155 | u = tf.get_variable('u', [1, w_shape[-1]],
156 | initializer=tf.truncated_normal_initializer(),
157 | trainable=False)
158 | u_ = u
159 | for _ in range(num_iters):
160 | v_ = _l2normalize(tf.matmul(u_, w_mat, transpose_b=True))
161 | u_ = _l2normalize(tf.matmul(v_, w_mat))
162 |
163 | sigma = tf.squeeze(tf.matmul(tf.matmul(v_, w_mat), u_, transpose_b=True))
164 | w_mat /= sigma
165 | if update_collection is None:
166 | with tf.control_dependencies([u.assign(u_)]):
167 | w_bar = tf.reshape(w_mat, w_shape)
168 | else:
169 | w_bar = tf.reshape(w_mat, w_shape)
170 | if update_collection != 'NO_OPS':
171 | tf.add_to_collection(update_collection, u.assign(u_))
172 | if with_sigma:
173 | return w_bar, sigma
174 | else:
175 | return w_bar
176 |
177 |
178 | def snconv2d(input_, output_dim,
179 | k_h=3, k_w=3, d_h=2, d_w=2,
180 | sn_iters=1, update_collection=None, name='snconv2d'):
181 | """Creates a spectral normalized (SN) convolutional layer.
182 |
183 | Args:
184 | input_: 4D input tensor (batch size, height, width, channel).
185 | output_dim: Number of features in the output layer.
186 | k_h: The height of the convolutional kernel.
187 | k_w: The width of the convolutional kernel.
188 | d_h: The height stride of the convolutional kernel.
189 | d_w: The width stride of the convolutional kernel.
190 | sn_iters: The number of SN iterations.
191 | update_collection: The update collection used in spectral_normed_weight.
192 | name: The name of the variable scope.
193 | Returns:
194 | conv: The normalized tensor.
195 |
196 | """
197 | with tf.variable_scope(name):
198 | w = tf.get_variable(
199 | 'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
200 | initializer=tf.contrib.layers.xavier_initializer())
201 | w_bar = spectral_normed_weight(w, num_iters=sn_iters,
202 | update_collection=update_collection)
203 |
204 | conv = tf.nn.conv2d(input_, w_bar, strides=[1, d_h, d_w, 1], padding='SAME')
205 | biases = tf.get_variable('biases', [output_dim],
206 | initializer=tf.zeros_initializer())
207 | conv = tf.nn.bias_add(conv, biases)
208 | return conv
209 |
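# Usage sketch (illustrative, not part of the original file): the discriminator
# is built twice per step on shared weights. The real-batch pass updates the
# power-iteration vector u in place (update_collection=None); the fake-batch
# pass reuses the same u without touching it ('NO_OPS'), mirroring how
# model.py calls the discriminator in this repository.
def _sn_update_collection_example(real_images, fake_images):
  with tf.variable_scope('demo_d'):
    real_out = snconv2d(real_images, 64, update_collection=None)
  with tf.variable_scope('demo_d', reuse=True):
    fake_out = snconv2d(fake_images, 64, update_collection='NO_OPS')
  return real_out, fake_out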
210 |
211 | def snlinear(x, output_size, bias_start=0.0,
212 | sn_iters=1, update_collection=None, name='snlinear'):
213 | """Creates a spectral normalized linear layer.
214 |
215 | Args:
216 | x: 2D input tensor (batch size, features).
217 | output_size: Number of features in output of layer.
218 | bias_start: The bias parameters are initialized to this value
219 | sn_iters: Number of SN iterations.
220 | update_collection: The update collection used in spectral_normed_weight
221 | name: Optional, variable scope to put the layer's parameters into
222 | Returns:
223 | The normalized tensor
224 | """
225 | shape = x.get_shape().as_list()
226 |
227 | with tf.variable_scope(name):
228 | matrix = tf.get_variable(
229 | 'Matrix', [shape[1], output_size], tf.float32,
230 | tf.contrib.layers.xavier_initializer())
231 | matrix_bar = spectral_normed_weight(matrix, num_iters=sn_iters,
232 | update_collection=update_collection)
233 | bias = tf.get_variable(
234 | 'bias', [output_size], initializer=tf.constant_initializer(bias_start))
235 | out = tf.matmul(x, matrix_bar) + bias
236 | return out
237 |
238 |
239 | def sn_embedding(x, number_classes, embedding_size, sn_iters=1,
240 | update_collection=None, name='snembedding'):
241 | """Creates a spectral normalized embedding lookup layer.
242 |
243 | Args:
244 | x: 1D input tensor (batch size, ).
245 | number_classes: The number of classes.
246 |     embedding_size: The length of the embedding vector for each class.
247 | sn_iters: Number of SN iterations.
248 | update_collection: The update collection used in spectral_normed_weight
249 | name: Optional, variable scope to put the layer's parameters into
250 | Returns:
251 | The output tensor (batch size, embedding_size).
252 | """
253 | with tf.variable_scope(name):
254 | embedding_map = tf.get_variable(
255 | name='embedding_map',
256 | shape=[number_classes, embedding_size],
257 | initializer=tf.contrib.layers.xavier_initializer())
258 | embedding_map_bar_transpose = spectral_normed_weight(
259 | tf.transpose(embedding_map), num_iters=sn_iters,
260 | update_collection=update_collection)
261 | embedding_map_bar = tf.transpose(embedding_map_bar_transpose)
262 | return tf.nn.embedding_lookup(embedding_map_bar, x)
263 |
264 |
265 | class ConditionalBatchNorm_old(object):
266 | """Conditional BatchNorm.
267 |
268 | For each class, it has a specific gamma and beta as normalization variable.
269 | """
270 |
271 | def __init__(self, num_categories, name='conditional_batch_norm', center=True,
272 | scale=True):
273 | with tf.variable_scope(name):
274 | self.name = name
275 | self.num_categories = num_categories
276 | self.center = center
277 | self.scale = scale
278 |
279 | def __call__(self, inputs, labels):
280 | inputs = tf.convert_to_tensor(inputs)
281 | inputs_shape = inputs.get_shape()
282 | params_shape = inputs_shape[-1:]
283 | axis = [0, 1, 2]
284 | shape = tf.TensorShape([self.num_categories]).concatenate(params_shape)
285 |
286 | with tf.variable_scope(self.name):
287 | self.gamma = tf.get_variable(
288 | 'gamma', shape,
289 | initializer=tf.ones_initializer())
290 | self.beta = tf.get_variable(
291 | 'beta', shape,
292 | initializer=tf.zeros_initializer())
293 | beta = tf.gather(self.beta, labels)
294 | beta = tf.expand_dims(tf.expand_dims(beta, 1), 1)
295 | gamma = tf.gather(self.gamma, labels)
296 | gamma = tf.expand_dims(tf.expand_dims(gamma, 1), 1)
297 | mean, variance = tf.nn.moments(inputs, axis, keep_dims=True)
298 | variance_epsilon = 1E-5
299 | outputs = tf.nn.batch_normalization(
300 | inputs, mean, variance, beta, gamma, variance_epsilon)
301 | outputs.set_shape(inputs_shape)
302 | return outputs
303 |
304 | class ConditionalBatchNorm(object):
305 | """Conditional BatchNorm.
306 |
307 | For each class, it has a specific gamma and beta as normalization variable.
308 | """
309 |
310 | def __init__(self, num_categories, name='conditional_batch_norm', decay_rate=0.999, center=True,
311 | scale=True):
312 | with tf.variable_scope(name):
313 | self.name = name
314 | self.num_categories = num_categories
315 | self.center = center
316 | self.scale = scale
317 | self.decay_rate = decay_rate
318 |
319 | def __call__(self, inputs, labels, is_training=True):
320 | inputs = tf.convert_to_tensor(inputs)
321 | inputs_shape = inputs.get_shape()
322 | params_shape = inputs_shape[-1:]
323 | axis = [0, 1, 2]
324 | shape = tf.TensorShape([self.num_categories]).concatenate(params_shape)
325 | moving_shape = tf.TensorShape([1, 1, 1]).concatenate(params_shape)
326 |
327 | with tf.variable_scope(self.name):
328 | self.gamma = tf.get_variable(
329 | 'gamma', shape,
330 | initializer=tf.ones_initializer())
331 | self.beta = tf.get_variable(
332 | 'beta', shape,
333 | initializer=tf.zeros_initializer())
334 | self.moving_mean = tf.get_variable('mean', moving_shape,
335 | initializer=tf.zeros_initializer(),
336 | trainable=False)
337 | self.moving_var = tf.get_variable('var', moving_shape,
338 | initializer=tf.ones_initializer(),
339 | trainable=False)
340 |
341 | beta = tf.gather(self.beta, labels)
342 | beta = tf.expand_dims(tf.expand_dims(beta, 1), 1)
343 | gamma = tf.gather(self.gamma, labels)
344 | gamma = tf.expand_dims(tf.expand_dims(gamma, 1), 1)
345 | decay = self.decay_rate
346 | variance_epsilon = 1E-5
347 | if is_training:
348 | mean, variance = tf.nn.moments(inputs, axis, keep_dims=True)
349 | update_mean = tf.assign(self.moving_mean, self.moving_mean * decay + mean * (1 - decay))
350 | update_var = tf.assign(self.moving_var, self.moving_var * decay + variance * (1 - decay))
351 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
352 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var)
353 | #with tf.control_dependencies([update_mean, update_var]):
354 | outputs = tf.nn.batch_normalization(
355 | inputs, mean, variance, beta, gamma, variance_epsilon)
356 | else:
357 | outputs = tf.nn.batch_normalization(
358 | inputs, self.moving_mean, self.moving_var, beta, gamma, variance_epsilon)
359 | outputs.set_shape(inputs_shape)
360 | return outputs
361 |
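# Usage sketch (illustrative, not part of the original file): one (gamma, beta)
# row is kept per class and gathered by the integer labels, so each class gets
# its own affine transform on top of the shared batch statistics.
def _conditional_batch_norm_example(features, labels):
  # features: [batch, height, width, channels]; labels: [batch] int32 class ids.
  cbn = ConditionalBatchNorm(num_categories=1000, name='demo_cbn')
  return cbn(features, labels, is_training=True)
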
362 | class batch_norm(object):
363 |   def __init__(self, epsilon=1e-5, momentum=0.9999, name="batch_norm"):
364 | with tf.variable_scope(name):
365 | self.epsilon = epsilon
366 | self.momentum = momentum
367 | self.name = name
368 |
369 | def __call__(self, x, train=True):
370 | return tf.contrib.layers.batch_norm(x,
371 | decay=self.momentum,
372 | # updates_collections=None,
373 | epsilon=self.epsilon,
374 | scale=True,
375 | is_training=train,
376 | scope=self.name)
377 |
378 | class BatchNorm(object):
379 | """The Batch Normalization layer."""
380 |
381 | def __init__(self, name='batch_norm', center=True,
382 | scale=True):
383 | with tf.variable_scope(name):
384 | self.name = name
385 | self.center = center
386 | self.scale = scale
387 |
388 | def __call__(self, inputs):
389 | inputs = tf.convert_to_tensor(inputs)
390 | inputs_shape = inputs.get_shape().as_list()
391 | params_shape = inputs_shape[-1]
392 | axis = [0, 1, 2]
393 | shape = tf.TensorShape([params_shape])
394 | with tf.variable_scope(self.name):
395 | self.gamma = tf.get_variable(
396 | 'gamma', shape,
397 | initializer=tf.ones_initializer())
398 | self.beta = tf.get_variable(
399 | 'beta', shape,
400 | initializer=tf.zeros_initializer())
401 | beta = self.beta
402 | gamma = self.gamma
403 |
404 | mean, variance = tf.nn.moments(inputs, axis, keep_dims=True)
405 | variance_epsilon = 1E-5
406 | outputs = tf.nn.batch_normalization(
407 | inputs, mean, variance, beta, gamma, variance_epsilon)
408 | outputs.set_shape(inputs_shape)
409 | return outputs
410 |
411 | @contextmanager
412 | def variables_on_gpu0():
413 | old_fn = tf.get_variable
414 | def new_fn(*args, **kwargs):
415 | with tf.device('/gpu:0'):
416 | return old_fn(*args, **kwargs)
417 | tf.get_variable = new_fn
418 | yield
419 | tf.get_variable = old_fn
420 |
421 |
422 | def avg_grads(tower_grads):
423 | """Calculate the average gradient for each shared variable across all towers.
424 |
425 | Note that this function provides a synchronization point across all towers.
426 |
427 | Args:
428 | tower_grads: List of lists of (gradient, variable) tuples. The outer list
429 | is over individual gradients. The inner list is over the gradient
430 | calculation for each tower.
431 | Returns:
432 | List of pairs of (gradient, variable) where the gradient has been averaged
433 | across all towers.
434 | """
435 | average_grads = []
436 | for grad_and_vars in zip(*tower_grads):
437 | # Note that each grad_and_vars looks like the following:
438 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
439 | grads = []
440 | for g, _ in grad_and_vars:
441 | # Add 0 dimension to the gradients to represent the tower.
442 | expanded_g = tf.expand_dims(g, 0)
443 |
444 | # Append on a 'tower' dimension which we will average over below.
445 | grads.append(expanded_g)
446 |
447 | # Average over the 'tower' dimension.
448 | grad = tf.concat(grads, 0)
449 | grad = tf.reduce_mean(grad, 0)
450 |
451 | # Keep in mind that the Variables are redundant because they are shared
452 | # across towers. So .. we will just return the first tower's pointer to
453 | # the Variable.
454 | v = grad_and_vars[0][1]
455 | grad_and_var = (grad, v)
456 | average_grads.append(grad_and_var)
457 | return average_grads
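
# Example (illustrative, not part of the original file): two towers computing
# gradients for the same shared variable. Gradients are stacked on a new tower
# axis and averaged; the variable pointer from the first tower is kept.
def _avg_grads_example():
  v = tf.get_variable('demo_v', initializer=0.0)
  tower_grads = [[(tf.constant(1.0), v)], [(tf.constant(3.0), v)]]
  return avg_grads(tower_grads)  # evaluates to [(2.0, v)]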
--------------------------------------------------------------------------------
/train_imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Train Imagenet."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 |
23 | from absl import flags
24 | import tensorflow as tf
25 |
26 | import utils_ori as utils
27 | import model
28 |
29 |
30 |
31 | tfgan = tf.contrib.gan
32 | gfile = tf.gfile
33 |
34 |
35 | flags.DEFINE_string('master', 'local',
36 | 'BNS name of the TensorFlow master to use. [local]')
37 |
38 | # flags.DEFINE_string('checkpoint_dir', '/usr/local/google/home/zhanghan/Documents/Research/model_output',
39 | # 'Directory name to save the checkpoints. [checkpoint]')
40 | flags.DEFINE_string('checkpoint_dir', 'checkpoint',
41 | 'Directory name to save the checkpoints. [checkpoint]')
42 | flags.DEFINE_integer('batch_size', 64, 'Number of images in input batch. [64]')  # originally 16
43 | flags.DEFINE_integer('shuffle_buffer_size', 100000, 'Number of records to load '
44 | 'before shuffling and yielding for consumption. [100000]')
45 | flags.DEFINE_integer('save_summaries_steps', 200, 'Number of steps between '
46 |                      'saving summary statistics. [200]')
47 | flags.DEFINE_integer('save_checkpoint_secs', 1200, 'Number of seconds between '
48 | 'saving checkpoints of model. [1200]')
49 | flags.DEFINE_boolean('is_train', True, 'True for training. [default: True]')
50 | flags.DEFINE_boolean('is_gd_equal', True, 'True for 1:1, False for 1:5')
51 |
52 | # TODO(olganw) Find the best way to clean up these flags for eval and train.
53 | flags.DEFINE_integer('task', 0, 'The task id of the current worker. [0]')
54 | flags.DEFINE_integer('ps_tasks', 0, 'The number of ps tasks. [0]')
55 | flags.DEFINE_integer('num_workers', 1, 'The number of worker tasks. [1]')
56 | flags.DEFINE_integer('replicas_to_aggregate', 1, 'The number of replicas '
57 | 'to aggregate for synchronous optimization [1]')
58 | flags.DEFINE_boolean('sync_replicas', True, 'Whether to sync replicas. [True]')
59 | flags.DEFINE_integer('num_towers', 4, 'The number of GPUs to use per task. [4]')
60 | flags.DEFINE_integer('d_step', 1, 'The number of D_step')
61 | flags.DEFINE_integer('g_step', 1, 'The number of G_step')
62 |
63 | # flags.DEFINE_integer('z_dim', 128, 'The dimension of z')
64 |
65 | FLAGS = flags.FLAGS
66 |
67 |
68 |
69 | def main(_, is_test=False):
70 | print('d_learning_rate', FLAGS.discriminator_learning_rate)
71 | print('g_learning_rate', FLAGS.generator_learning_rate)
72 | print('data_dir', FLAGS.data_dir)
73 | print(FLAGS.loss_type, FLAGS.batch_size, FLAGS.beta1)
74 | print('gf_df_dim', FLAGS.gf_dim, FLAGS.df_dim)
75 | print('Starting the program..')
76 | gfile.MakeDirs(FLAGS.checkpoint_dir)
77 |
78 | model_dir = '%s_%s' % ('imagenet', FLAGS.batch_size)
79 | logdir = os.path.join(FLAGS.checkpoint_dir, model_dir)
80 | gfile.MakeDirs(logdir)
81 |
82 | graph = tf.Graph()
83 | with graph.as_default():
84 |
85 | with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
86 | # Instantiate global_step.
87 | global_step = tf.train.create_global_step()
88 |
89 | # Create model with FLAGS, global_step, and devices.
90 | devices = ['/gpu:{}'.format(tower) for tower in range(FLAGS.num_towers)]
91 |
92 | # Create noise tensors
93 | zs = utils.make_z_normal(
94 | FLAGS.num_towers, FLAGS.batch_size, FLAGS.z_dim)
95 |
96 | print('save_summaries_steps', FLAGS.save_summaries_steps)
97 |
98 | dcgan = model.SNGAN(
99 | zs=zs,
100 | config=FLAGS,
101 | global_step=global_step,
102 | devices=devices)
103 |
104 | with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
105 | # Create sync_hooks when needed.
106 | if FLAGS.sync_replicas and FLAGS.num_workers > 1:
107 | print('condition 1')
108 | sync_hooks = [
109 | dcgan.d_opt.make_session_run_hook(FLAGS.task == 0),
110 | dcgan.g_opt.make_session_run_hook(FLAGS.task == 0)
111 | ]
112 | else:
113 | print('condition 2')
114 | sync_hooks = []
115 |
116 | train_ops = tfgan.GANTrainOps(
117 | generator_train_op=dcgan.g_optim,
118 | discriminator_train_op=dcgan.d_optim,
119 | global_step_inc_op=dcgan.increment_global_step)
120 |
121 |
122 | # We set allow_soft_placement to be True because Saver for the DCGAN model
123 | # gets misplaced on the GPU.
124 | session_config = tf.ConfigProto(
125 | allow_soft_placement=True, log_device_placement=False)
126 |
127 | if is_test:
128 | return graph
129 |
130 | print("G step: ", FLAGS.g_step)
131 | print("D_step: ", FLAGS.d_step)
132 | train_steps = tfgan.GANTrainSteps(FLAGS.g_step, FLAGS.d_step)
133 |
134 | tfgan.gan_train(
135 | train_ops,
136 | get_hooks_fn=tfgan.get_sequential_train_hooks(
137 | train_steps=train_steps),
138 | hooks=([tf.train.StopAtStepHook(num_steps=2000000)] + sync_hooks),
139 | logdir=logdir,
140 | # master=FLAGS.master,
141 | # scaffold=scaffold, # load from google checkpoint
142 | is_chief=(FLAGS.task == 0),
143 | save_summaries_steps=FLAGS.save_summaries_steps,
144 | save_checkpoint_secs=FLAGS.save_checkpoint_secs,
145 | config=session_config)
146 |
147 |
148 | if __name__ == '__main__':
149 | tf.app.run()
150 |
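# Example invocation (hypothetical values; data_dir and the learning-rate
# flags are defined in model.py):
#   python train_imagenet.py --data_dir=/path/to/tfrecords --batch_size=64 \
#     --num_towers=4 --checkpoint_dir=checkpoint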
--------------------------------------------------------------------------------
/utils_ori.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Some codes from https://github.com/Newmu/dcgan_code."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import math
23 | import os
24 | import numpy as np
25 | import scipy.misc
26 | import sympy
27 | import tensorflow as tf
28 |
29 | tfgan = tf.contrib.gan
30 | classifier_metrics = tf.contrib.gan.eval.classifier_metrics
31 | gfile = tf.gfile
32 |
33 |
34 |
35 |
36 | def make_z_normal(num_batches, batch_size, z_dim):
37 | """Make random noises tensors with normal distribution feeding into the generator
38 | Args:
39 | num_batches: copies of batches
40 | batch_size: the batch_size for z
41 | z_dim: The dimension of the z (noise) vector.
42 | Returns:
43 | zs: noise tensors.
44 | """
45 | shape = [num_batches, batch_size, z_dim]
46 | z = tf.random_normal(shape, name='z0', dtype=tf.float32)
47 | return z
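
# Example (illustrative): make_z_normal(4, 64, 128) yields a [4, 64, 128]
# tensor; model.py slices zs[gpu_idx] so each tower gets its own noise batch.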
48 |
49 |
50 | def run_custom_inception(
51 | images,
52 | output_tensor,
53 | graph_def=None,
54 | # image_size=classifier_metrics.INCEPTION_DEFAULT_IMAGE_SIZE):
55 | image_size=299):
56 | # input_tensor=classifier_metrics.INCEPTION_V1_INPUT):
57 | """Run images through a pretrained Inception classifier.
58 |
59 | This method resizes the images before feeding them through Inception; we do
60 | this to accommodate feeding images through in minibatches without having to
61 | construct any large tensors.
62 |
63 | Args:
64 | images: Input tensors. Must be [batch, height, width, channels]. Input shape
65 | and values must be in [-1, 1], which can be achieved using
66 | `preprocess_image`.
67 | output_tensor: Name of output Tensor. This function will compute activations
68 | at the specified layer. Examples include INCEPTION_V3_OUTPUT and
69 | INCEPTION_V3_FINAL_POOL which would result in this function computing
70 | the final logits or the penultimate pooling layer.
71 | graph_def: A GraphDef proto of a pretrained Inception graph. If `None`,
72 | call `default_graph_def_fn` to get GraphDef.
73 | image_size: Required image width and height. See unit tests for the default
74 | values.
75 |     input_tensor: Name of input Tensor (currently unused; the argument is commented out above).
76 |
77 | Returns:
78 | Logits.
79 | """
80 |
81 | images = tf.image.resize_bilinear(images, [image_size, image_size])
82 |
83 | return tfgan.eval.run_inception(
84 | images,
85 | graph_def=graph_def,
86 | image_size=image_size,
87 | # input_tensor=input_tensor,
88 | output_tensor=output_tensor)
89 |
90 |
91 | def get_real_activations(data_dir,
92 | batch_size,
93 | num_batches,
94 | label_offset=0,
95 | cycle_length=1,
96 | shuffle_buffer_size=100000):
97 | """Fetches num_batches batches of size batch_size from the data_dir.
98 |
99 | Args:
100 | data_dir: The directory to read data from. Expected to be a single
101 | TFRecords file.
102 | batch_size: The number of elements in a single minibatch.
103 | num_batches: The number of batches to fetch at a time.
104 | label_offset: The scalar to add to the labels in the dataset. The imagenet
105 | GAN code expects labels in [0, 999], and this scalar can be used to move
106 | other labels into this range. (Default: 0)
107 | cycle_length: The number of input elements to process concurrently in the
108 | Dataset loader. (Default: 1)
109 | shuffle_buffer_size: The number of records to load before shuffling. Larger
110 | means more likely randomization. (Default: 100000)
111 | Returns:
112 |     A tuple (real_pools, real_images): Inception activations of shape [batch_size * num_batches, 2048] and the last batch of real images.
113 | """
114 | # filenames = gfile.Glob(os.path.join(data_dir, '*_train_*-*-of-*'))
115 |
116 | filenames = tf.gfile.Glob(os.path.join(data_dir, '*.tfrecords'))
117 | filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
118 | filename_dataset = filename_dataset.shuffle(len(filenames))
119 | prefetch = max(int((batch_size * num_batches) / cycle_length), 1)
120 | dataset = filename_dataset.interleave(
121 | lambda fn: tf.data.TFRecordDataset(fn).prefetch(prefetch),
122 | cycle_length=cycle_length)
123 |
124 | dataset = dataset.shuffle(shuffle_buffer_size)
125 | image_size = 128
126 | # graph_def = classifier_metrics._default_graph_def_fn() # pylint: disable=protected-access
127 |
128 | def _extract_image_and_label(record):
129 | """Extracts and preprocesses the image and label from the record."""
130 | features = tf.parse_single_example(
131 | record,
132 | features={
133 | 'image_raw': tf.FixedLenFeature([], tf.string),
134 | 'label': tf.FixedLenFeature([], tf.int64)
135 | })
136 |
137 | image = tf.decode_raw(features['image_raw'], tf.uint8)
138 | image.set_shape(image_size * image_size * 3)
139 | image = tf.reshape(image, [image_size, image_size, 3])
140 |
141 | image = tf.cast(image, tf.float32) * (2. / 255) - 1.
142 |
143 | label = tf.cast(features['label'], tf.int32)
144 | label += label_offset
145 |
146 | return image, label
147 |
148 | dataset = dataset.map(
149 | _extract_image_and_label,
150 | num_parallel_calls=16).prefetch(batch_size * num_batches)
151 | dataset = dataset.batch(batch_size)
152 | iterator = dataset.make_one_shot_iterator()
153 |
154 | real_images, _ = iterator.get_next()
155 | real_images.set_shape([batch_size, image_size, image_size, 3])
156 |
157 | pools = run_custom_inception(
158 | real_images, graph_def=None, output_tensor=['pool_3:0'])[0]
159 |
160 | def while_cond(_, i):
161 | return tf.less(i, num_batches)
162 |
163 | def while_body(real_pools, i):
164 | with tf.control_dependencies([real_pools]):
165 | imgs, _ = iterator.get_next()
166 | imgs.set_shape([batch_size, image_size, image_size, 3])
167 | pools = run_custom_inception(
168 | imgs, graph_def=None, output_tensor=['pool_3:0'])[0]
169 | real_pools = tf.concat([real_pools, pools], 0)
170 | return (real_pools, tf.add(i, 1))
171 |
172 | # Get activations from real images.
173 | i = tf.constant(1)
174 | real_pools, _ = tf.while_loop(
175 | while_cond,
176 | while_body, [pools, i],
177 | shape_invariants=[tf.TensorShape([None, 2048]),
178 | i.get_shape()],
179 | parallel_iterations=1,
180 | back_prop=False,
181 | swap_memory=True,
182 | name='RealActivations')
183 |
184 | real_pools.set_shape([batch_size * num_batches, 2048])
185 |
186 | return real_pools, real_images
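
# Usage sketch (illustrative, not part of the original file): the pooled
# activations are typically fed to a Frechet-distance metric, e.g.
#   real_pools, _ = get_real_activations(data_dir, 16, 100)  # [1600, 2048]
#   fid = tfgan.eval.frechet_classifier_distance_from_activations(
#       real_pools, fake_pools)
# where fake_pools would come from run_custom_inception on generated images.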
187 |
188 |
189 | def get_imagenet_batches(data_dir,
190 | batch_size,
191 | num_batches,
192 | label_offset=0,
193 | cycle_length=1,
194 | shuffle_buffer_size=100000):
195 | """Fetches num_batches batches of size batch_size from the data_dir.
196 |
197 | Args:
198 | data_dir: The directory to read data from. Expected to be a single
199 | TFRecords file.
200 | batch_size: The number of elements in a single minibatch.
201 | num_batches: The number of batches to fetch at a time.
202 | label_offset: The scalar to add to the labels in the dataset. The imagenet
203 | GAN code expects labels in [0, 999], and this scalar can be used to move
204 | other labels into this range. (Default: 0)
205 | cycle_length: The number of input elements to process concurrently in the
206 | Dataset loader. (Default: 1)
207 | shuffle_buffer_size: The number of records to load before shuffling. Larger
208 | means more likely randomization. (Default: 100000)
209 | Returns:
210 |     A list of num_batches (images, labels) tuples, each of batch size batch_size.
211 | """
212 | # filenames = gfile.Glob(os.path.join(data_dir, '*_train_*-*-of-*'))
213 | filenames = tf.gfile.Glob(os.path.join(data_dir, '*.tfrecords'))
214 | filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
215 | filename_dataset = filename_dataset.shuffle(len(filenames))
216 | prefetch = max(int((batch_size * num_batches) / cycle_length), 1)
217 | dataset = filename_dataset.interleave(
218 | lambda fn: tf.data.TFRecordDataset(fn).prefetch(prefetch),
219 | cycle_length=cycle_length)
220 |
221 | dataset = dataset.shuffle(shuffle_buffer_size)
222 | image_size = 128
223 |
224 | def _extract_image_and_label(record):
225 | """Extracts and preprocesses the image and label from the record."""
226 | features = tf.parse_single_example(
227 | record,
228 | features={
229 | 'image_raw': tf.FixedLenFeature([], tf.string),
230 | 'label': tf.FixedLenFeature([], tf.int64)
231 | })
232 |
233 | image = tf.decode_raw(features['image_raw'], tf.uint8)
234 | image.set_shape(image_size * image_size * 3)
235 | image = tf.reshape(image, [image_size, image_size, 3])
236 |
237 | image = tf.cast(image, tf.float32) * (2. / 255) - 1.
238 |
239 | label = tf.cast(features['label'], tf.int32)
240 | label += label_offset
241 |
242 | return image, label
243 |
244 | dataset = dataset.map(
245 | _extract_image_and_label,
246 | num_parallel_calls=16).prefetch(batch_size * num_batches)
247 | dataset = dataset.repeat() # Repeat for unlimited epochs.
248 | dataset = dataset.batch(batch_size)
249 | dataset = dataset.batch(num_batches)
250 |
251 | iterator = dataset.make_one_shot_iterator()
252 | images, labels = iterator.get_next()
253 |
254 | batches = []
255 | for i in range(num_batches):
256 | # Dataset batches lose shape information. Put it back in.
257 | im = images[i, ...]
258 | im.set_shape([batch_size, image_size, image_size, 3])
259 |
260 | lb = labels[i, ...]
261 | lb.set_shape((batch_size,))
262 |
263 | batches.append((im, tf.expand_dims(lb, 1)))
264 |
265 | return batches
266 |
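# Usage sketch (illustrative, hypothetical path): one (images, labels) pair per
# tower, with static shapes restored after batching:
#   batches = get_imagenet_batches('/path/to/tfrecords', 64, num_batches=4)
#   images, labels = batches[0]  # [64, 128, 128, 3] float32, [64, 1] int32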
267 |
268 | def save_images(images, size, image_path):
269 | return imsave(inverse_transform(images), size, image_path)
270 |
271 |
272 | def merge(images, size):
273 | h, w = images.shape[1], images.shape[2]
274 | img = np.zeros((h * size[0], w * size[1], 3))
275 |
276 | idx = 0
277 | for i in range(0, size[0]):
278 | for j in range(0, size[1]):
279 | img[j * h:(j + 1) * h, i * w:(i + 1) * w, :] = images[idx]
280 | idx += 1
281 | return img
282 |
283 |
284 | def imsave(images, size, path):
285 | with gfile.Open(path, mode='w') as f:
286 | saved = scipy.misc.imsave(f, merge(images, size))
287 | return saved
288 |
289 |
290 | def inverse_transform(images):
291 | return (images + 1.) / 2.
292 |
293 |
294 | def visualize(sess, dcgan, config, option):
295 |   option = 0  # NOTE: option is forced to 0, so the branches below are unreachable.
296 | if option == 0:
297 | all_samples = []
298 | for i in range(484):
299 | print(i)
300 | samples = sess.run(dcgan.generator)
301 | all_samples.append(samples)
302 | samples = np.concatenate(all_samples, 0)
303 | n = int(np.sqrt(samples.shape[0]))
304 | m = samples.shape[0] // n
305 | save_images(samples, [m, n], './' + config.sample_dir + '/test.png')
306 | elif option == 1:
307 | counter = 0
308 | coord = tf.train.Coordinator()
309 | tf.train.start_queue_runners(coord=coord)
310 | while counter < 1005:
311 | print(counter)
312 | samples, fake = sess.run([dcgan.generator, dcgan.d_loss_class])
313 | fake = np.argsort(fake)
314 | print(np.sum(samples))
315 | print(fake)
316 | for i in range(samples.shape[0]):
317 | name = '%s%d.png' % (chr(ord('a') + counter % 10), counter)
318 | img = np.expand_dims(samples[fake[i]], 0)
319 | if counter >= 1000:
320 | save_images(img, [1, 1], './{}/turk/fake{}.png'.format(
321 | config.sample_dir, counter - 1000))
322 | else:
323 | save_images(img, [1, 1], './{}/turk/{}'.format(
324 | config.sample_dir, name))
325 | counter += 1
326 | elif option == 2:
327 | values = np.arange(0, 1, 1. / config.batch_size)
328 | for idx in range(100):
329 | print(' [*] %d' % idx)
330 | z_sample = np.zeros([config.batch_size, dcgan.z_dim])
331 | for kdx, z in enumerate(z_sample):
332 | z[idx] = values[kdx]
333 |
334 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
335 | save_images(samples, [8, 8], './{}/test_arange_{}.png'.format(
336 | config.sample_dir, idx))
337 |
338 |
339 | def squarest_grid_size(num_images):
340 | """Calculates the size of the most square grid for num_images.
341 |
342 | Calculates the largest integer divisor of num_images less than or equal to
343 | sqrt(num_images) and returns that as the width. The height is
344 | num_images / width.
345 |
346 | Args:
347 | num_images: The total number of images.
348 |
349 | Returns:
350 | A tuple of (height, width) for the image grid.
351 | """
352 | divisors = sympy.divisors(num_images)
353 | square_root = math.sqrt(num_images)
354 | width = 1
355 | for d in divisors:
356 | if d > square_root:
357 | break
358 | width = d
359 | return (num_images // width, width)
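
# Examples (illustrative): squarest_grid_size(64) == (8, 8), and
# squarest_grid_size(24) == (6, 4), since 4 is the largest divisor of 24 not
# exceeding sqrt(24).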
360 |
--------------------------------------------------------------------------------