├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── dataset.py
├── discriminator.py
├── eval_imagenet.py
├── generator.py
├── imgs
│   └── img1.png
├── model.py
├── non_local.py
├── ops.py
├── train_imagenet.py
└── utils_ori.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
__pycache__/
.DS_Store

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows [Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.
      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution.
      You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE.
      You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Self-Attention GAN
Tensorflow implementation for reproducing main results in the paper [Self-Attention Generative Adversarial Networks](https://arxiv.org/abs/1805.08318) by Han Zhang, Ian Goodfellow, Dimitris Metaxas, Augustus Odena.

### Dependencies
python 3.6

TensorFlow 1.5


**Data**

Download the Imagenet dataset and preprocess the images into tfrecord files as instructed in [improved gan](https://github.com/openai/improved-gan/blob/master/imagenet/convert_imagenet_to_records.py). Put the tfrecord files into ./data


**Training**

The current batch size is 64x4=256. A larger batch size seems to give better performance, but it may require new hyperparameters for the G and D learning rates. Note: it usually takes several weeks to train one million steps.

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_imagenet.py --generator_type test --discriminator_type test --data_dir ./data
```

**Evaluation**

```
CUDA_VISIBLE_DEVICES=4 python eval_imagenet.py --generator_type test --data_dir ./data
```

### Citing Self-attention GAN
If you find Self-Attention GAN useful in your research, please consider citing:

```
@article{Han18,
  author    = {Han Zhang and
               Ian J. Goodfellow and
               Dimitris N. Metaxas and
               Augustus Odena},
  title     = {Self-Attention Generative Adversarial Networks},
  year      = {2018},
  journal   = {arXiv:1805.08318},
}
```

**References**

- Spectral Normalization for Generative Adversarial Networks [Paper](https://arxiv.org/abs/1802.05957)
- cGANs with Projection Discriminator [Paper](https://arxiv.org/abs/1802.05637)
- Non-local Neural Networks [Paper](https://arxiv.org/abs/1711.07971)

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf
import os

IMAGE_SIZE = 128

def _extract_image_and_label(record):
  """Extracts and preprocesses the image and label from the record."""
  features = tf.parse_single_example(
      record,
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64)
      })
  image_size = IMAGE_SIZE
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image.set_shape(image_size * image_size * 3)
  image = tf.reshape(image, [image_size, image_size, 3])

  image = tf.cast(image, tf.float32) * (2. / 255) - 1.
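  # The line above maps uint8 pixels in [0, 255] to floats in [-1, 1], the
  # range of the generator's tanh output; the visualization code elsewhere in
  # this repo inverts the mapping with (x + 1.) * 127.5.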

  label = tf.cast(features['label'], tf.int32)

  return image, label

class InputFunction(object):
  """Wrapper class that is passed as callable to Estimator."""

  def __init__(self, is_training, noise_dim, dataset_name, num_classes,
               data_dir="./dataset", cycle_length=64, shuffle_buffer_size=100000):
    self.is_training = is_training
    self.noise_dim = noise_dim
    split = ('train' if is_training else 'test')
    self.data_files = tf.gfile.Glob(os.path.join(data_dir, '*.tfrecords'))
    self.parser = _extract_image_and_label
    self.num_classes = num_classes
    self.cycle_length = cycle_length
    self.shuffle_buffer_size = shuffle_buffer_size

  def __call__(self, params):
    """Creates a simple Dataset pipeline."""

    batch_size = params['batch_size']
    filename_dataset = tf.data.Dataset.from_tensor_slices(self.data_files)
    filename_dataset = filename_dataset.shuffle(len(self.data_files))

    def tfrecord_dataset(filename):
      buffer_size = 8 * 1024 * 1024  # 8 MiB read buffer per file.
      return tf.data.TFRecordDataset(filename, buffer_size=buffer_size)

    dataset = filename_dataset.apply(tf.contrib.data.parallel_interleave(
        tfrecord_dataset,
        cycle_length=self.cycle_length, sloppy=True))
    if self.is_training:
      dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(
          self.shuffle_buffer_size, -1))
    dataset = dataset.map(self.parser, num_parallel_calls=32)
    dataset = dataset.apply(
        tf.contrib.data.batch_and_drop_remainder(batch_size))

    dataset = dataset.prefetch(4)  # Prefetch overlaps in-feed with training
    images, labels = dataset.make_one_shot_iterator().get_next()
    labels = tf.squeeze(labels)
    random_noise = tf.random_normal([batch_size, self.noise_dim])

    gen_class_logits = tf.zeros((batch_size, self.num_classes))
    gen_class_ints = tf.multinomial(gen_class_logits, 1)
    gen_sparse_class = tf.squeeze(gen_class_ints)

    features = {
        'real_images': images,
        'random_noise': random_noise,
        'fake_labels': gen_sparse_class}

    return features, labels

--------------------------------------------------------------------------------
/discriminator.py:
--------------------------------------------------------------------------------
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The discriminator of SNGAN."""

import tensorflow as tf
import ops
import non_local


def dsample(x):
  """Downsamples the image by a factor of 2."""

  xd = tf.nn.avg_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
  return xd


def block(x, out_channels, name, update_collection=None,
          downsample=True, act=tf.nn.relu):
  """Builds the residual blocks used in the discriminator in SNGAN.

  Args:
    x: The 4D input tensor.
    out_channels: Number of features in the output layer.
    name: The variable scope name for the block.
    update_collection: The update collections used in the
      spectral_normed_weight.
    downsample: If True, downsample the spatial size of the input tensor by
      a factor of 4. If False, the spatial size of the input tensor is
      unchanged.
    act: The activation function used in the block.
  Returns:
    A `Tensor` representing the output of the operation.
  """
  with tf.variable_scope(name):
    input_channels = x.shape.as_list()[-1]
    x_0 = x
    x = act(x)
    x = ops.snconv2d(x, out_channels, 3, 3, 1, 1,
                     update_collection=update_collection, name='sn_conv1')
    x = act(x)
    x = ops.snconv2d(x, out_channels, 3, 3, 1, 1,
                     update_collection=update_collection, name='sn_conv2')
    if downsample:
      x = dsample(x)
    if downsample or input_channels != out_channels:
      x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1,
                         update_collection=update_collection, name='sn_conv3')
      if downsample:
        x_0 = dsample(x_0)
    return x_0 + x


def optimized_block(x, out_channels, name,
                    update_collection=None, act=tf.nn.relu):
  """Builds the simplified residual blocks for downsampling.

  Compared with block, optimized_block always downsamples the spatial resolution
  of the input tensor by a factor of 4.

  Args:
    x: The 4D input tensor.
    out_channels: Number of features in the output layer.
    name: The variable scope name for the block.
    update_collection: The update collections used in the
      spectral_normed_weight.
    act: The activation function used in the block.
  Returns:
    A `Tensor` representing the output of the operation.
  """
  with tf.variable_scope(name):
    x_0 = x
    x = ops.snconv2d(x, out_channels, 3, 3, 1, 1,
                     update_collection=update_collection, name='sn_conv1')
    x = act(x)
    x = ops.snconv2d(x, out_channels, 3, 3, 1, 1,
                     update_collection=update_collection, name='sn_conv2')
    x = dsample(x)
    x_0 = dsample(x_0)
    x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1,
                       update_collection=update_collection, name='sn_conv3')
    return x + x_0


def discriminator_old(image, labels, df_dim, number_classes, update_collection=None,
                      act=tf.nn.relu, scope='Discriminator'):
  """Builds the discriminator graph.

  Args:
    image: The current batch of images to classify as fake or real.
    labels: The corresponding labels for the images.
    df_dim: The df dimension.
    number_classes: The number of classes in the labels.
    update_collection: The update collections used in the
      spectral_normed_weight.
    act: The activation function used in the discriminator.
    scope: Optional scope for `variable_op_scope`.
  Returns:
    A `Tensor` representing the logits of the discriminator.
  """
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    h0 = optimized_block(image, df_dim, 'd_optimized_block1',
                         update_collection, act=act)  # 64 * 64
    h1 = block(h0, df_dim * 2, 'd_block2',
               update_collection, act=act)  # 32 * 32
    h2 = block(h1, df_dim * 4, 'd_block3',
               update_collection, act=act)  # 16 * 16
    h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act)  # 8 * 8
    h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act)  # 4 * 4
    h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
    h5_act = act(h5)
    h6 = tf.reduce_sum(h5_act, [1, 2])
    output = ops.snlinear(h6, 1, update_collection=update_collection,
                          name='d_sn_linear')
    h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
                                update_collection=update_collection,
                                name='d_embedding')
    output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
    return output


def discriminator(image, labels, df_dim, number_classes, update_collection=None,
                  act=tf.nn.relu):
  """Builds the discriminator graph.

  Args:
    image: The current batch of images to classify as fake or real.
    labels: The corresponding labels for the images.
    df_dim: The df dimension.
    number_classes: The number of classes in the labels.
    update_collection: The update collections used in the
      spectral_normed_weight.
    act: The activation function used in the discriminator.
  Returns:
    A `Tensor` representing the logits of the discriminator.
  """
  with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
    h0 = optimized_block(image, df_dim, 'd_optimized_block1',
                         update_collection, act=act)  # 64 * 64
    h1 = block(h0, df_dim * 2, 'd_block2',
               update_collection, act=act)  # 32 * 32
    h2 = block(h1, df_dim * 4, 'd_block3',
               update_collection, act=act)  # 16 * 16
    h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act)  # 8 * 8
    h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act)  # 4 * 4
    h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
    h5_act = act(h5)
    h6 = tf.reduce_sum(h5_act, [1, 2])
    output = ops.snlinear(h6, 1, update_collection=update_collection,
                          name='d_sn_linear')
    h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
                                update_collection=update_collection,
                                name='d_embedding')
    output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
    print('Discriminator Structure')
    return output

def discriminator_test(image, labels, df_dim, number_classes, update_collection=None,
                       act=tf.nn.relu):
  """Builds the discriminator graph, with a non-local (self-attention) block.

  Args:
    image: The current batch of images to classify as fake or real.
    labels: The corresponding labels for the images.
    df_dim: The df dimension.
    number_classes: The number of classes in the labels.
    update_collection: The update collections used in the
      spectral_normed_weight.
    act: The activation function used in the discriminator.
  Returns:
    A `Tensor` representing the logits of the discriminator.
  """
  with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
    h0 = optimized_block(image, df_dim, 'd_optimized_block1',
                         update_collection, act=act)  # 64 * 64
    h1 = block(h0, df_dim * 2, 'd_block2',
               update_collection, act=act)  # 32 * 32
    h1 = non_local.sn_non_local_block_sim(h1, update_collection,
                                          name='d_non_local')  # 32 * 32
    h2 = block(h1, df_dim * 4, 'd_block3',
               update_collection, act=act)  # 16 * 16
    h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act)  # 8 * 8
    h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act)  # 4 * 4
    h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
    h5_act = act(h5)
    h6 = tf.reduce_sum(h5_act, [1, 2])
    output = ops.snlinear(h6, 1, update_collection=update_collection,
                          name='d_sn_linear')
    h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
                                update_collection=update_collection,
                                name='d_embedding')
    output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
    print('Discriminator Test Structure')
    return output

def discriminator_test_64(image, labels, df_dim, number_classes, update_collection=None,
                          act=tf.nn.relu):
  """Builds the discriminator graph, with the non-local block at 64 * 64.

  Args:
    image: The current batch of images to classify as fake or real.
    labels: The corresponding labels for the images.
    df_dim: The df dimension.
    number_classes: The number of classes in the labels.
    update_collection: The update collections used in the
      spectral_normed_weight.
    act: The activation function used in the discriminator.
  Returns:
    A `Tensor` representing the logits of the discriminator.
  """
  with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
    h0 = optimized_block(image, df_dim, 'd_optimized_block1',
                         update_collection, act=act)  # 64 * 64
    h0 = non_local.sn_non_local_block_sim(h0, update_collection,
                                          name='d_non_local')  # 64 * 64
    h1 = block(h0, df_dim * 2, 'd_block2',
               update_collection, act=act)  # 32 * 32
    h2 = block(h1, df_dim * 4, 'd_block3',
               update_collection, act=act)  # 16 * 16
    h3 = block(h2, df_dim * 8, 'd_block4', update_collection, act=act)  # 8 * 8
    h4 = block(h3, df_dim * 16, 'd_block5', update_collection, act=act)  # 4 * 4
    h5 = block(h4, df_dim * 16, 'd_block6', update_collection, False, act=act)
    h5_act = act(h5)
    h6 = tf.reduce_sum(h5_act, [1, 2])
    output = ops.snlinear(h6, 1, update_collection=update_collection,
                          name='d_sn_linear')
    h_labels = ops.sn_embedding(labels, number_classes, df_dim * 16,
                                update_collection=update_collection,
                                name='d_embedding')
    output += tf.reduce_sum(h6 * h_labels, axis=1, keepdims=True)
    return output

--------------------------------------------------------------------------------
/eval_imagenet.py:
--------------------------------------------------------------------------------
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generic eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from absl import flags
import tensorflow as tf


import generator as generator_module
import utils_ori as utils


slim = tf.contrib.slim
tfgan = tf.contrib.gan


flags.DEFINE_string(
    # 'data_dir', '/gpu/hz138/Data/imagenet',  # '/home/hz138/Data/imagenet',
    'data_dir', '/bigdata1/hz138/Data/imagenet',
    'Directory with Imagenet input data as sharded recordio files of pre-'
    'processed images.')
flags.DEFINE_integer('z_dim', 128, 'The dimension of z')
flags.DEFINE_integer('gf_dim', 64, 'Dimensionality of gf. [64]')


flags.DEFINE_string('master', 'local',
                    'BNS name of the TensorFlow master to use')
flags.DEFINE_string('checkpoint_dir', 'checkpoint', 'Directory name to load '
                    'the checkpoints. [checkpoint]')
flags.DEFINE_string('sample_dir', 'sample', 'Directory name to save the '
                    'image samples. [sample]')
flags.DEFINE_string('eval_dir', 'checkpoint/eval', 'Directory name to save the '
                    'eval summaries. [eval]')
flags.DEFINE_integer('batch_size', 64, 'Batch size of samples to feed into '
                     'Inception models for evaluation. [64]')
flags.DEFINE_integer('shuffle_buffer_size', 5000, 'Number of records to load '
                     'before shuffling and yielding for consumption. [5000]')
flags.DEFINE_integer('dcgan_generator_batch_size', 100, 'Size of batch to feed '
                     'into generator -- we may stack multiple of these later.')
flags.DEFINE_integer('eval_sample_size', 50000,
                     'Number of samples to sample from '
                     'generator and real data. [50000]')
flags.DEFINE_boolean('is_train', False, 'Use DCGAN only for evaluation.')

flags.DEFINE_integer('task', 0, 'The task id of the current worker. [0]')
flags.DEFINE_integer('ps_tasks', 0, 'The number of ps tasks. [0]')
flags.DEFINE_integer('num_workers', 1, 'The number of worker tasks. [1]')
flags.DEFINE_integer('replicas_to_aggregate', 1, 'The number of replicas '
                     'to aggregate for synchronous optimization [1]')

flags.DEFINE_integer('num_towers', 1, 'The number of GPUs to use per task. [1]')
flags.DEFINE_integer('eval_interval_secs', 300,
                     'Frequency of generator evaluation with Inception score '
                     'and Frechet Inception Distance. [300]')

flags.DEFINE_integer('num_classes', 1000, 'The number of classes in the dataset')
flags.DEFINE_string('generator_type', 'test', 'test or baseline')

FLAGS = flags.FLAGS


def main(_):
  model_dir = '%s_%s' % ('imagenet', FLAGS.batch_size)
  FLAGS.eval_dir = FLAGS.checkpoint_dir + '/eval'
  checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, model_dir)
  log_dir = os.path.join(FLAGS.eval_dir, model_dir)
  print('log_dir', log_dir)
  graph_def = None  # pylint: disable=protected-access

  # Batch size to feed batches of images through Inception and the generator
  # to extract feature vectors to later stack together and compute metrics.
  local_batch_size = FLAGS.dcgan_generator_batch_size
  if FLAGS.generator_type == 'baseline':
    generator_fn = generator_module.generator
  elif FLAGS.generator_type == 'test':
    generator_fn = generator_module.generator_test
  else:
    raise NotImplementedError
  if FLAGS.num_towers != 1 or FLAGS.num_workers != 1:
    raise NotImplementedError(
        'The eval job does not currently support using multiple GPUs')

  # Get activations from real images.
  with tf.device('/device:CPU:1'):
    real_pools, real_images = utils.get_real_activations(
        FLAGS.data_dir,
        local_batch_size,
        FLAGS.eval_sample_size // local_batch_size,
        label_offset=-1,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  num_classes = FLAGS.num_classes
  gen_class_logits = tf.zeros((local_batch_size, num_classes))
  gen_class_ints = tf.multinomial(gen_class_logits, 1)
  gen_sparse_class = tf.squeeze(gen_class_ints)

  # Generate the first batch of generated images and extract activations;
  # this bootstraps the while_loop with a pools and logits tensor.
  test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim)
  generator = generator_fn(
      test_zs[0],
      gen_sparse_class,
      FLAGS.gf_dim,
      FLAGS.num_classes,
      is_training=False)

  pools, logits = utils.run_custom_inception(
      generator, output_tensor=['pool_3:0', 'logits:0'], graph_def=graph_def)

  # Set up while_loop to compute activations of generated images from generator.
  def while_cond(g_pools, g_logits, i):  # pylint: disable=unused-argument
    return tf.less(i, FLAGS.eval_sample_size // local_batch_size)

  # We use a while loop because we want to generate a batch of images
  # and then feed that batch through Inception to retrieve the activations.
  # Otherwise, if we generate all the samples first and then compute all the
  # activations, we will run out of memory.
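  # Sketch of the recurrence the loop below implements (one iteration per
  # generated batch):
  #   g_pools  <- concat([g_pools,  inception_pool(G(z_i))],   axis=0)
  #   g_logits <- concat([g_logits, inception_logits(G(z_i))], axis=0)
  # Declaring shape_invariants with a None leading dimension is what allows
  # the accumulated tensors to grow across iterations.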
  def while_body(g_pools, g_logits, i):
    with tf.control_dependencies([g_pools, g_logits]):

      test_zs = utils.make_z_normal(1, local_batch_size, FLAGS.z_dim)
      # Uniform distribution
      gen_class_logits = tf.zeros((local_batch_size, num_classes))
      gen_class_ints = tf.multinomial(gen_class_logits, 1)
      gen_sparse_class = tf.squeeze(gen_class_ints)

      generator = generator_fn(
          test_zs[0],
          gen_sparse_class,
          FLAGS.gf_dim,
          FLAGS.num_classes,
          is_training=False)

      pools, logits = utils.run_custom_inception(
          generator,
          output_tensor=['pool_3:0', 'logits:0'],
          graph_def=graph_def)
      g_pools = tf.concat([g_pools, pools], 0)
      g_logits = tf.concat([g_logits, logits], 0)

    return (g_pools, g_logits, tf.add(i, 1))

  # Get the activations
  i = tf.constant(1)
  new_generator_pools_list, new_generator_logits_list, _ = tf.while_loop(
      while_cond,
      while_body, [pools, logits, i],
      shape_invariants=[
          tf.TensorShape([None, 2048]),
          tf.TensorShape([None, 1008]),
          i.get_shape()
      ],
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='GeneratedActivations')

  new_generator_pools_list.set_shape([FLAGS.eval_sample_size, 2048])
  new_generator_logits_list.set_shape([FLAGS.eval_sample_size, 1008])

  # Get a small batch of samples from generator to display in TensorBoard
  vis_batch_size = 16
  eval_vis_zs = utils.make_z_normal(
      1, vis_batch_size, FLAGS.z_dim)

  gen_class_logits_vis = tf.zeros((vis_batch_size, num_classes))
  gen_class_ints_vis = tf.multinomial(gen_class_logits_vis, 1)
  gen_sparse_class_vis = tf.squeeze(gen_class_ints_vis)

  eval_vis_images = generator_fn(
      eval_vis_zs[0],
      gen_sparse_class_vis,
      FLAGS.gf_dim,
      FLAGS.num_classes,
      is_training=False
  )
  eval_vis_images = tf.cast((eval_vis_images + 1.) * 127.5, tf.uint8)

  with tf.variable_scope('eval_vis'):
    tf.summary.image('generated_images', eval_vis_images)
    tf.summary.image('real_images', real_images)
    tf.summary.image('real_images_grid',
                     tfgan.eval.image_grid(
                         real_images[:16],
                         grid_shape=utils.squarest_grid_size(16),
                         image_shape=(128, 128)))
    tf.summary.image('generated_images_grid',
                     tfgan.eval.image_grid(
                         eval_vis_images[:16],
                         grid_shape=utils.squarest_grid_size(16),
                         image_shape=(128, 128)))

  # Use the activations from the real images and generated images to compute
  # Inception score and FID.
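  # For reference, the standard definitions of the two metrics computed below:
  #   Inception score: exp(E_x[KL(p(y|x) || p(y))]), computed from the logits.
  #   FID: ||mu_r - mu_g||^2 + Tr(C_r + C_g - 2 (C_r C_g)^(1/2)), where
  #   (mu, C) are the mean and covariance of the pool_3 activations of the
  #   real (r) and generated (g) samples.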
  generated_logits = tf.concat(new_generator_logits_list, 0)
  generated_pools = tf.concat(new_generator_pools_list, 0)

  # Compute Frechet Inception Distance and Inception score
  incscore = tfgan.eval.classifier_score_from_logits(generated_logits)
  fid = tfgan.eval.frechet_classifier_distance_from_activations(
      real_pools, generated_pools)

  with tf.variable_scope('eval'):
    tf.summary.scalar('fid', fid)
    tf.summary.scalar('incscore', incscore)

  session_config = tf.ConfigProto(
      allow_soft_placement=True, log_device_placement=False)

  tf.contrib.training.evaluate_repeatedly(
      checkpoint_dir=checkpoint_dir,
      hooks=[
          tf.contrib.training.SummaryAtEndHook(log_dir),
          tf.contrib.training.StopAfterNEvalsHook(1)
      ],
      config=session_config)


if __name__ == '__main__':
  tf.app.run()

--------------------------------------------------------------------------------
/generator.py:
--------------------------------------------------------------------------------
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""The generator of SNGAN."""

import tensorflow as tf
import ops
import non_local


def upscale(x, n):
  """Builds box upscaling (also called nearest neighbors).

  Args:
    x: 4D image tensor in B01C format.
    n: integer scale (must be a power of 2).

  Returns:
    4D tensor of images up scaled by a factor n.
  """
  if n == 1:
    return x
  return tf.batch_to_space(tf.tile(x, [n**2, 1, 1, 1]), [[0, 0], [0, 0]], n)


def usample_tpu(x):
  """Upscales the width and height of the input tensor by a factor of 2."""
  x = upscale(x, 2)
  return x

def usample(x):
  """Upscales the width and height by a factor of 2 (nearest neighbor)."""
  _, nh, nw, nx = x.get_shape().as_list()
  x = tf.image.resize_nearest_neighbor(x, [nh * 2, nw * 2])
  return x

def block_no_sn(x, labels, out_channels, num_classes, is_training, name):
  """Builds the residual blocks used in the generator.

  This block upsamples the spatial resolution of the input tensor by a
  factor of 2.

  Args:
    x: The 4D input tensor.
    labels: The conditional labels in the generation.
    out_channels: Number of features in the output layer.
    num_classes: Number of classes in the labels.
    is_training: Whether the block is in training mode.
    name: The variable scope name for the block.
  Returns:
    A `Tensor` representing the output of the operation.
  """
  with tf.variable_scope(name):
    bn0 = ops.ConditionalBatchNorm(num_classes, name='cbn_0')
    bn1 = ops.ConditionalBatchNorm(num_classes, name='cbn_1')
    x_0 = x
    x = tf.nn.relu(bn0(x, labels, is_training))
    x = usample(x)
    x = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv1')
    x = tf.nn.relu(bn1(x, labels, is_training))
    x = ops.conv2d(x, out_channels, 3, 3, 1, 1, name='conv2')

    x_0 = usample(x_0)
    x_0 = ops.conv2d(x_0, out_channels, 1, 1, 1, 1, name='conv3')

    return x_0 + x

def block(x, labels, out_channels, num_classes, is_training, name):
  """Same residual block as block_no_sn, but with spectral-normalized convs."""
  with tf.variable_scope(name):
    bn0 = ops.ConditionalBatchNorm(num_classes, name='cbn_0')
    bn1 = ops.ConditionalBatchNorm(num_classes, name='cbn_1')
    x_0 = x
    x = tf.nn.relu(bn0(x, labels, is_training))
    x = usample(x)
    x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='snconv1')
    x = tf.nn.relu(bn1(x, labels, is_training))
    x = ops.snconv2d(x, out_channels, 3, 3, 1, 1, name='snconv2')

    x_0 = usample(x_0)
    x_0 = ops.snconv2d(x_0, out_channels, 1, 1, 1, 1, name='snconv3')

    return x_0 + x


def generator_old(zs,
                  target_class,
                  gf_dim,
                  num_classes,
                  is_training=True,
                  scope='Generator'):
  """Builds the generator graph propagating from z to x.

  Args:
    zs: The list of noise tensors.
    target_class: The conditional labels in the generation.
    gf_dim: The gf dimension.
    num_classes: Number of classes in the labels.
    is_training: Whether the generator is in training mode.
    scope: Optional scope for `variable_op_scope`.

  Returns:
    outputs: The output layer of the generator.
  """

  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    # project `z` and reshape
    act0 = ops.linear(zs, gf_dim * 16 * 4 * 4, scope='g_h0')
    act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])

    act1 = block_no_sn(act0, target_class, gf_dim * 16,
                       num_classes, is_training, 'g_block1')  # 8 * 8
    act2 = block_no_sn(act1, target_class, gf_dim * 8,
                       num_classes, is_training, 'g_block2')  # 16 * 16
    act3 = block_no_sn(act2, target_class, gf_dim * 4,
                       num_classes, is_training, 'g_block3')  # 32 * 32
    act4 = block_no_sn(act3, target_class, gf_dim * 2,
                       num_classes, is_training, 'g_block4')  # 64 * 64
    act5 = block_no_sn(act4, target_class, gf_dim,
                       num_classes, is_training, 'g_block5')  # 128 * 128
    bn = ops.batch_norm(name='g_bn')

    act5 = tf.nn.relu(bn(act5, is_training))
    act6 = ops.conv2d(act5, 3, 3, 3, 1, 1, name='g_conv_last')
    out = tf.nn.tanh(act6)
    print('GAN baseline with moving average')
    return out

def generator(zs,
              target_class,
              gf_dim,
              num_classes,
              is_training=True):
  """Builds the generator graph propagating from z to x.

  Args:
    zs: The list of noise tensors.
    target_class: The conditional labels in the generation.
    gf_dim: The gf dimension.
    num_classes: Number of classes in the labels.
    is_training: Whether the generator is in training mode.

  Returns:
    outputs: The output layer of the generator.
  """

  with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
    # project `z` and reshape
    act0 = ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0')
    act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])

    act1 = block(act0, target_class, gf_dim * 16,
                 num_classes, is_training, 'g_block1')  # 8 * 8
    act2 = block(act1, target_class, gf_dim * 8,
                 num_classes, is_training, 'g_block2')  # 16 * 16
    act3 = block(act2, target_class, gf_dim * 4,
                 num_classes, is_training, 'g_block3')  # 32 * 32
    act4 = block(act3, target_class, gf_dim * 2,
                 num_classes, is_training, 'g_block4')  # 64 * 64
    act5 = block(act4, target_class, gf_dim,
                 num_classes, is_training, 'g_block5')  # 128 * 128
    bn = ops.batch_norm(name='g_bn')

    act5 = tf.nn.relu(bn(act5, is_training))
    act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last')
    out = tf.nn.tanh(act6)
    print('Generator Structure')
    return out


def generator_test(zs,
                   target_class,
                   gf_dim,
                   num_classes,
                   is_training=True):
  """Builds the generator graph, with a non-local (self-attention) block.

  Args:
    zs: The list of noise tensors.
    target_class: The conditional labels in the generation.
    gf_dim: The gf dimension.
    num_classes: Number of classes in the labels.
    is_training: Whether the generator is in training mode.

  Returns:
    outputs: The output layer of the generator.
  """

  with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
    # project `z` and reshape
    act0 = ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0')
    act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])

    act1 = block(act0, target_class, gf_dim * 16,
                 num_classes, is_training, 'g_block1')  # 8 * 8
    act2 = block(act1, target_class, gf_dim * 8,
                 num_classes, is_training, 'g_block2')  # 16 * 16
    act3 = block(act2, target_class, gf_dim * 4,
                 num_classes, is_training, 'g_block3')  # 32 * 32
    act3 = non_local.sn_non_local_block_sim(act3, None, name='g_non_local')
    act4 = block(act3, target_class, gf_dim * 2,
                 num_classes, is_training, 'g_block4')  # 64 * 64
    act5 = block(act4, target_class, gf_dim,
                 num_classes, is_training, 'g_block5')  # 128 * 128
    bn = ops.batch_norm(name='g_bn')

    act5 = tf.nn.relu(bn(act5, is_training))
    act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last')
    out = tf.nn.tanh(act6)
    print('Generator TEST structure')
    return out

def generator_test_64(zs,
                      target_class,
                      gf_dim,
                      num_classes,
                      is_training=True):
  """Builds the generator graph, with the non-local block at 64 * 64.

  Args:
    zs: The list of noise tensors.
    target_class: The conditional labels in the generation.
    gf_dim: The gf dimension.
    num_classes: Number of classes in the labels.
    is_training: Whether the generator is in training mode.

  Returns:
    outputs: The output layer of the generator.
  """

  with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
    # project `z` and reshape
    act0 = ops.snlinear(zs, gf_dim * 16 * 4 * 4, name='g_snh0')
    act0 = tf.reshape(act0, [-1, 4, 4, gf_dim * 16])

    act1 = block(act0, target_class, gf_dim * 16,
                 num_classes, is_training, 'g_block1')  # 8 * 8
    act2 = block(act1, target_class, gf_dim * 8,
                 num_classes, is_training, 'g_block2')  # 16 * 16
    act3 = block(act2, target_class, gf_dim * 4,
                 num_classes, is_training, 'g_block3')  # 32 * 32

    act4 = block(act3, target_class, gf_dim * 2,
                 num_classes, is_training, 'g_block4')  # 64 * 64
    act4 = non_local.sn_non_local_block_sim(act4, None, name='g_non_local')
    act5 = block(act4, target_class, gf_dim,
                 num_classes, is_training, 'g_block5')  # 128 * 128
    bn = ops.batch_norm(name='g_bn')

    act5 = tf.nn.relu(bn(act5, is_training))
    act6 = ops.snconv2d(act5, 3, 3, 3, 1, 1, name='g_snconv_last')
    out = tf.nn.tanh(act6)
    print('GAN test with moving average')
    return out

--------------------------------------------------------------------------------
/imgs/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brain-research/self-attention-gan/ad9612e60f6ba2b5ad3d3340ebae60f724636d75/imgs/img1.png

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""The SNGAN Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
import tensorflow as tf

import discriminator as disc
import generator as generator_module

import ops
import utils_ori as utils


tfgan = tf.contrib.gan

flags.DEFINE_string(
    # 'data_dir', '/gpu/hz138/Data/imagenet',  # '/home/hz138/Data/imagenet',
    'data_dir', '/bigdata1/hz138/Data/imagenet',
    'Directory with Imagenet input data as sharded recordio files of pre-'
    'processed images.')
flags.DEFINE_float('discriminator_learning_rate', 0.0004,
                   'Learning rate for Adam. [0.0004]')
flags.DEFINE_float('generator_learning_rate', 0.0001,
                   'Learning rate for Adam. [0.0001]')
flags.DEFINE_float('beta1', 0.0, 'Momentum term of Adam. [0.0]')
flags.DEFINE_integer('image_size', 128, 'The size of image to use '
                     '(will be center cropped) [128]')
flags.DEFINE_integer('image_width', 128,
                     'The width of the images presented to the model')
flags.DEFINE_integer('data_parallelism', 64, 'The number of objects to read at'
                     ' one time when loading input data. [64]')
flags.DEFINE_integer('z_dim', 128, 'Dimensionality of latent code z. [128]')
flags.DEFINE_integer('gf_dim', 64, 'Dimensionality of gf. [64]')
flags.DEFINE_integer('df_dim', 64, 'Dimensionality of df. [64]')
flags.DEFINE_integer('number_classes', 1000, 'The number of classes in the dataset')
flags.DEFINE_string('loss_type', 'hinge_loss', 'the loss type can be'
                    ' hinge_loss or kl_loss')
flags.DEFINE_string('generator_type', 'test', 'test or baseline')
flags.DEFINE_string('discriminator_type', 'test', 'test or baseline')


FLAGS = flags.FLAGS


def _get_d_real_loss(discriminator_on_data_logits):
  loss = tf.nn.relu(1.0 - discriminator_on_data_logits)
  return tf.reduce_mean(loss)


def _get_d_fake_loss(discriminator_on_generator_logits):
  return tf.reduce_mean(tf.nn.relu(1 + discriminator_on_generator_logits))


def _get_g_loss(discriminator_on_generator_logits):
  return -tf.reduce_mean(discriminator_on_generator_logits)


def _get_d_real_loss_KL(discriminator_on_data_logits):
  loss = tf.nn.softplus(-discriminator_on_data_logits)
  return tf.reduce_mean(loss)


def _get_d_fake_loss_KL(discriminator_on_generator_logits):
  return tf.reduce_mean(tf.nn.softplus(discriminator_on_generator_logits))


def _get_g_loss_KL(discriminator_on_generator_logits):
  return tf.reduce_mean(-discriminator_on_generator_logits)


class SNGAN(object):
  """SNGAN model."""

  def __init__(self, zs, config=None, global_step=None, devices=None):
    """Initializes the SNGAN model.

    Args:
      zs: input noise tensors for the generator
      config: the configuration FLAGS object
      global_step: the global training step (maintained by the supervisor)
      devices: the list of device names to place ops on (multitower training)
    """

    self.config = config
    self.image_size = FLAGS.image_size
    self.image_shape = [FLAGS.image_size, FLAGS.image_size, 3]
    self.z_dim = FLAGS.z_dim
    self.gf_dim = FLAGS.gf_dim
    self.df_dim = FLAGS.df_dim
    self.num_classes = FLAGS.number_classes

    self.data_parallelism = FLAGS.data_parallelism
    self.zs = zs

    self.c_dim = 3
    self.dataset_name = 'imagenet'
    self.devices = devices
    self.global_step = global_step

    self.build_model()

  def build_model(self):
    """Builds a model."""
    config = self.config
    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
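    # Note on the learning rates set just below: the flag defaults above
    # (d_lr = 4e-4 vs. g_lr = 1e-4, with beta1 = 0) give the discriminator a
    # larger Adam step size than the generator, the "two-timescale update
    # rule" (TTUR) setup described in the SAGAN paper.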
    current_step = tf.cast(self.global_step, tf.float32)
    # g_ratio = (1.0 + 2e-5 * tf.maximum((current_step - 100000.0), 0.0))
    # g_ratio = tf.minimum(g_ratio, 4.0)
    self.d_learning_rate = FLAGS.discriminator_learning_rate
    self.g_learning_rate = FLAGS.generator_learning_rate
    # self.g_learning_rate = FLAGS.generator_learning_rate / (1.0 + 2e-5 * tf.cast(self.global_step, tf.float32))
    # self.g_learning_rate = FLAGS.generator_learning_rate / g_ratio
    with tf.device(tf.train.replica_device_setter(config.ps_tasks)):
      self.d_opt = tf.train.AdamOptimizer(
          self.d_learning_rate, beta1=FLAGS.beta1)
      self.g_opt = tf.train.AdamOptimizer(
          self.g_learning_rate, beta1=FLAGS.beta1)
      if config.sync_replicas and config.num_workers > 1:
        self.d_opt = tf.train.SyncReplicasOptimizer(
            opt=self.d_opt, replicas_to_aggregate=config.replicas_to_aggregate)
        self.g_opt = tf.train.SyncReplicasOptimizer(
            opt=self.g_opt, replicas_to_aggregate=config.replicas_to_aggregate)

    if config.num_towers > 1:
      all_d_grads = []
      all_g_grads = []
      for idx, device in enumerate(self.devices):
        with tf.device('/%s' % device):
          with tf.name_scope('device_%s' % idx):
            with ops.variables_on_gpu0():
              self.build_model_single_gpu(
                  gpu_idx=idx,
                  batch_size=config.batch_size,
                  num_towers=config.num_towers)
              d_grads = self.d_opt.compute_gradients(self.d_losses[-1],
                                                     var_list=self.d_vars)
              g_grads = self.g_opt.compute_gradients(self.g_losses[-1],
                                                     var_list=self.g_vars)
              all_d_grads.append(d_grads)
              all_g_grads.append(g_grads)
      d_grads = ops.avg_grads(all_d_grads)
      g_grads = ops.avg_grads(all_g_grads)
    else:
      with tf.device(tf.train.replica_device_setter(config.ps_tasks)):
        # TODO(olganw): reusing virtual batchnorm doesn't work in the multi-
        # replica case.
        self.build_model_single_gpu(batch_size=config.batch_size,
                                    num_towers=config.num_towers)
        d_grads = self.d_opt.compute_gradients(self.d_losses[-1],
                                               var_list=self.d_vars)
        g_grads = self.g_opt.compute_gradients(self.g_losses[-1],
                                               var_list=self.g_vars)
    with tf.device(tf.train.replica_device_setter(config.ps_tasks)):
      update_moving_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      print('update_moving_ops', update_moving_ops)
      if config.sync_replicas:
        with tf.control_dependencies(update_moving_ops):
          d_step = tf.get_variable('d_step', initializer=0, trainable=False)
          self.d_optim = self.d_opt.apply_gradients(d_grads, global_step=d_step)
          g_step = tf.get_variable('g_step', initializer=0, trainable=False)
          self.g_optim = self.g_opt.apply_gradients(g_grads, global_step=g_step)
      else:
        # Don't create any additional counters, and don't update the global step
        with tf.control_dependencies(update_moving_ops):
          self.d_optim = self.d_opt.apply_gradients(d_grads)
          self.g_optim = self.g_opt.apply_gradients(g_grads)

  def build_model_single_gpu(self, gpu_idx=0, batch_size=1, num_towers=1):
    """Builds a model for a single GPU.

    Args:
      gpu_idx: The index of the gpu in the tower.
      batch_size: The minibatch size. (Default: 1)
      num_towers: The total number of towers in this model. (Default: 1)
(Default: 1)
196 |     """
197 |     config = self.config
198 |     show_num = min(config.batch_size, 64)
199 | 
200 |     reuse_vars = gpu_idx > 0
201 |     if gpu_idx == 0:
202 |       self.increment_global_step = self.global_step.assign_add(1)
203 |       self.batches = utils.get_imagenet_batches(
204 |           FLAGS.data_dir, batch_size, num_towers, label_offset=0,
205 |           cycle_length=config.data_parallelism,
206 |           shuffle_buffer_size=config.shuffle_buffer_size)
207 |       sample_images, _ = self.batches[0]
208 |       vis_images = tf.cast((sample_images + 1.) * 127.5, tf.uint8)
209 |       tf.summary.image('input_image_grid',
210 |                        tfgan.eval.image_grid(
211 |                            vis_images[:show_num],
212 |                            grid_shape=utils.squarest_grid_size(
213 |                                show_num),
214 |                            image_shape=(128, 128)))
215 | 
216 |     images, sparse_labels = self.batches[gpu_idx]
217 |     sparse_labels = tf.squeeze(sparse_labels)
218 |     print('sparse_labels.shape', sparse_labels.shape)
219 | 
220 |     gen_class_logits = tf.zeros((batch_size, self.num_classes))
221 |     gen_class_ints = tf.multinomial(gen_class_logits, 1)
222 |     # NOTE: tf.argmax(gen_class_ints, axis=1) would be a bug here; tf.multinomial already returns sampled class ids, so just squeeze them.
223 |     gen_sparse_class = tf.squeeze(gen_class_ints)
224 |     print('gen_sparse_class.shape', gen_sparse_class.shape)
225 |     assert len(gen_class_ints.get_shape()) == 2
226 |     gen_class_ints = tf.squeeze(gen_class_ints)
227 |     assert len(gen_class_ints.get_shape()) == 1
228 |     gen_class_vector = tf.one_hot(gen_class_ints, self.num_classes)
229 |     assert len(gen_class_vector.get_shape()) == 2
230 |     assert gen_class_vector.dtype == tf.float32
231 | 
232 |     if FLAGS.generator_type == 'baseline':
233 |       generator_fn = generator_module.generator
234 |     elif FLAGS.generator_type == 'test':
235 |       generator_fn = generator_module.generator_test
236 |     else: raise NotImplementedError  # keep parity with the discriminator_type check below
237 |     generator = generator_fn(
238 |         self.zs[gpu_idx],
239 |         gen_sparse_class,
240 |         self.gf_dim,
241 |         self.num_classes
242 |     )
243 | 
244 | 
245 |     if gpu_idx == 0:
246 |       generator_means = tf.reduce_mean(generator, 0, keep_dims=True)
247 |       generator_vars = tf.reduce_mean(
248 |           tf.squared_difference(generator, generator_means), 0, keep_dims=True)
249 |       generator = tf.Print(
250 |           generator,
251 |           [tf.reduce_mean(generator_means), tf.reduce_mean(generator_vars)],
252 |           'generator mean and average var', first_n=1)
253 |       image_means = tf.reduce_mean(images, 0, keep_dims=True)
254 |       image_vars = tf.reduce_mean(
255 |           tf.squared_difference(images, image_means), 0, keep_dims=True)
256 |       images = tf.Print(
257 |           images, [tf.reduce_mean(image_means), tf.reduce_mean(image_vars)],
258 |           'image mean and average var', first_n=1)
259 |       sparse_labels = tf.Print(sparse_labels, [sparse_labels, sparse_labels.shape], 'sparse_labels', first_n=2)
260 |       gen_sparse_class = tf.Print(gen_sparse_class, [gen_sparse_class, gen_sparse_class.shape], 'gen_sparse_labels', first_n=2)
261 | 
262 |       self.generators = []
263 | 
264 |     self.generators.append(generator)
265 | 
266 |     if FLAGS.discriminator_type == 'baseline':
267 |       discriminator_fn = disc.discriminator
268 |     elif FLAGS.discriminator_type == 'test':
269 |       discriminator_fn = disc.discriminator_test
270 |     else:
271 |       raise NotImplementedError
272 |     discriminator_on_data_logits = discriminator_fn(images, sparse_labels, self.df_dim, self.num_classes,
273 |                                                     update_collection=None)
274 |     discriminator_on_generator_logits = discriminator_fn(generator, gen_sparse_class, self.df_dim, self.num_classes,
275 |                                                          update_collection="NO_OPS")
276 | 
277 |     vis_generator = tf.cast((generator + 1.)
                            * 127.5, tf.uint8)
278 |     tf.summary.image('generator', vis_generator)
279 | 
280 |     tf.summary.image('generator_grid',
281 |                      tfgan.eval.image_grid(
282 |                          vis_generator[:show_num],
283 |                          grid_shape=utils.squarest_grid_size(show_num),
284 |                          image_shape=(128, 128)))
285 | 
286 |     if FLAGS.loss_type == 'hinge_loss':
287 |       d_loss_real = _get_d_real_loss(
288 |           discriminator_on_data_logits)
289 |       d_loss_fake = _get_d_fake_loss(discriminator_on_generator_logits)
290 |       g_loss_gan = _get_g_loss(discriminator_on_generator_logits)
291 |       print('using hinge loss')
292 |     elif FLAGS.loss_type == 'kl_loss':
293 |       d_loss_real = _get_d_real_loss_KL(
294 |           discriminator_on_data_logits)
295 |       d_loss_fake = _get_d_fake_loss_KL(discriminator_on_generator_logits)
296 |       g_loss_gan = _get_g_loss_KL(discriminator_on_generator_logits)
297 |       print('using kl loss')
298 |     else:
299 |       raise NotImplementedError
300 | 
301 | 
302 |     d_loss = d_loss_real + d_loss_fake
303 |     g_loss = g_loss_gan
304 | 
305 | 
306 |     # Log the average logits on real and generated data.
307 |     logit_discriminator_on_data = tf.reduce_mean(discriminator_on_data_logits)
308 |     logit_discriminator_on_generator = tf.reduce_mean(
309 |         discriminator_on_generator_logits)
310 | 
311 | 
312 | 
313 |     # Add summaries.
314 |     tf.summary.scalar('d_loss', d_loss)
315 |     tf.summary.scalar('d_loss_real', d_loss_real)
316 |     tf.summary.scalar('d_loss_fake', d_loss_fake)
317 |     tf.summary.scalar('g_loss', g_loss)
318 |     tf.summary.scalar('logit_real', logit_discriminator_on_data)
319 |     tf.summary.scalar('logit_fake', logit_discriminator_on_generator)
320 |     tf.summary.scalar('d_learning_rate', self.d_learning_rate)
321 |     tf.summary.scalar('g_learning_rate', self.g_learning_rate)
322 | 
323 | 
324 | 
325 |     if gpu_idx == 0:
326 |       self.d_loss_reals = []
327 |       self.d_loss_fakes = []
328 |       self.d_losses = []
329 |       self.g_losses = []
330 |     self.d_loss_reals.append(d_loss_real)
331 |     self.d_loss_fakes.append(d_loss_fake)
332 |     self.d_losses.append(d_loss)
333 |     self.g_losses.append(g_loss)
334 | 
335 |     if gpu_idx == 0:
336 |       self.get_vars()
337 |       print('gvars', self.g_vars)
338 |       print('dvars', self.d_vars)
339 |       print('sigma_ratio_vars', self.sigma_ratio_vars)
340 |       for var in self.sigma_ratio_vars:
341 |         tf.summary.scalar(var.name, var)
342 | 
343 |   def get_vars(self):
344 |     """Get variables."""
345 |     t_vars = tf.trainable_variables()
346 |     # TODO(olganw): scoping or collections for this instead of name hack
347 |     self.d_vars = [var for var in t_vars if var.name.startswith('model/d_')]
348 |     self.g_vars = [var for var in t_vars if var.name.startswith('model/g_')]
349 |     self.sigma_ratio_vars = [var for var in t_vars if 'sigma_ratio' in var.name]
350 |     for x in self.d_vars:
351 |       assert x not in self.g_vars
352 |     for x in self.g_vars:
353 |       assert x not in self.d_vars
354 |     for x in t_vars:
355 |       assert x in self.g_vars or x in self.d_vars, x.name
356 |     self.all_vars = t_vars
357 | 
--------------------------------------------------------------------------------
/non_local.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | import tensorflow as tf
17 | import numpy as np
18 | import ops
19 | 
20 | def conv1x1(input_, output_dim,
21 |             init=tf.contrib.layers.xavier_initializer(), name='conv1x1'):
22 |   k_h = 1
23 |   k_w = 1
24 |   d_h = 1
25 |   d_w = 1
26 |   with tf.variable_scope(name):
27 |     w = tf.get_variable(
28 |         'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
29 |         initializer=init)
30 |     conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
31 |     return conv
32 | 
33 | def sn_conv1x1(input_, output_dim, update_collection,
34 |                init=tf.contrib.layers.xavier_initializer(), name='sn_conv1x1'):
35 |   with tf.variable_scope(name):
36 |     k_h = 1
37 |     k_w = 1
38 |     d_h = 1
39 |     d_w = 1
40 |     w = tf.get_variable(
41 |         'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
42 |         initializer=init)
43 |     w_bar = ops.spectral_normed_weight(w, num_iters=1, update_collection=update_collection)
44 | 
45 |     conv = tf.nn.conv2d(input_, w_bar, strides=[1, d_h, d_w, 1], padding='SAME')
46 |     return conv
47 | 
48 | def sn_non_local_block_sim(x, update_collection, name, init=tf.contrib.layers.xavier_initializer()):
49 |   with tf.variable_scope(name):
50 |     batch_size, h, w, num_channels = x.get_shape().as_list()
51 |     location_num = h * w
52 |     downsampled_num = location_num // 4
53 | 
54 |     # theta path
55 |     theta = sn_conv1x1(x, num_channels // 8, update_collection, init, 'sn_conv_theta')
56 |     theta = tf.reshape(
57 |         theta, [batch_size, location_num, num_channels // 8])
58 | 
59 |     # phi path
60 |     phi = sn_conv1x1(x, num_channels // 8, update_collection, init, 'sn_conv_phi')
61 |     phi = tf.layers.max_pooling2d(inputs=phi, pool_size=[2, 2], strides=2)
62 |     phi = tf.reshape(
63 |         phi, [batch_size, downsampled_num, num_channels // 8])
64 | 
65 | 
66 |     attn = tf.matmul(theta, phi, transpose_b=True)
67 |     attn = tf.nn.softmax(attn)
68 |     # Each row of attn sums to 1 after the softmax.
69 | 
70 |     # g path
71 |     g = sn_conv1x1(x, num_channels // 2, update_collection, init, 'sn_conv_g')
72 |     g = tf.layers.max_pooling2d(inputs=g, pool_size=[2, 2], strides=2)
73 |     g = tf.reshape(
74 |         g, [batch_size, downsampled_num, num_channels // 2])
75 | 
76 |     attn_g = tf.matmul(attn, g)
77 |     attn_g = tf.reshape(attn_g, [batch_size, h, w, num_channels // 2])
78 |     sigma = tf.get_variable(
79 |         'sigma_ratio', [], initializer=tf.constant_initializer(0.0))
80 |     attn_g = sn_conv1x1(attn_g, num_channels, update_collection, init, 'sn_conv_attn')
81 |     return x + sigma * attn_g
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """The building block ops for Spectral Normalization GAN."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import numpy as np
22 | import tensorflow as tf
23 | from contextlib import contextmanager
24 | 
25 | 
26 | rng = np.random.RandomState([2016, 6, 1])
27 | 
28 | 
29 | def conv2d(input_, output_dim,
30 |            k_h=3, k_w=3, d_h=2, d_w=2, name='conv2d'):
31 |   """Creates a convolutional layer with a Xavier initializer.
32 | 
33 |   Args:
34 |     input_: 4D input tensor (batch size, height, width, channel).
35 |     output_dim: Number of features in the output layer.
36 |     k_h: The height of the convolutional kernel.
37 |     k_w: The width of the convolutional kernel.
38 |     d_h: The height stride of the convolutional kernel.
39 |     d_w: The width stride of the convolutional kernel.
40 |     name: The name of the variable scope.
41 |   Returns:
42 |     conv: The convolved tensor.
43 |   """
44 |   with tf.variable_scope(name):
45 |     w = tf.get_variable(
46 |         'w', [k_h, k_w, input_.get_shape()[-1], output_dim],
47 |         initializer=tf.contrib.layers.xavier_initializer())
48 |     conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
49 | 
50 |     biases = tf.get_variable('biases', [output_dim],
51 |                              initializer=tf.zeros_initializer())
52 |     conv = tf.nn.bias_add(conv, biases)
53 |     return conv
54 | 
55 | 
56 | def deconv2d(input_, output_shape,
57 |              k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
58 |              name='deconv2d', init_bias=0.):
59 |   """Creates a deconvolutional (transposed convolution) layer.
60 | 
61 |   Args:
62 |     input_: 4D input tensor (batch size, height, width, channel).
63 |     output_shape: The full 4D shape of the output tensor.
64 |     k_h: The height of the convolutional kernel.
65 |     k_w: The width of the convolutional kernel.
66 |     d_h: The height stride of the convolutional kernel.
67 |     d_w: The width stride of the convolutional kernel.
68 |     stddev: The standard deviation for the weights initializer.
69 |     name: The name of the variable scope.
70 |     init_bias: The initial bias for the layer.
71 |   Returns:
72 |     deconv: The deconvolved tensor.
73 |   """
74 |   with tf.variable_scope(name):
75 |     w = tf.get_variable('w',
76 |                         [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
77 |                         initializer=tf.random_normal_initializer(stddev=stddev))
78 |     deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
79 |                                     strides=[1, d_h, d_w, 1])
80 |     biases = tf.get_variable('biases', [output_shape[-1]],
81 |                              initializer=tf.constant_initializer(init_bias))
82 |     deconv = tf.nn.bias_add(deconv, biases)
83 |     deconv.shape.assert_is_compatible_with(output_shape)
84 | 
85 |     return deconv
86 | 
87 | 
88 | def linear(x, output_size, scope=None, bias_start=0.0):
89 |   """Creates a linear layer.
90 | 
91 |   Args:
92 |     x: 2D input tensor (batch size, features)
93 |     output_size: Number of features in the output layer
94 |     scope: Optional, variable scope to put the layer's parameters into
95 |     bias_start: The bias parameters are initialized to this value
96 | 
97 |   Returns:
98 |     The output tensor of the affine transformation
99 |   """
100 |   shape = x.get_shape().as_list()
101 | 
102 |   with tf.variable_scope(scope or 'Linear'):
103 |     matrix = tf.get_variable(
104 |         'Matrix', [shape[1], output_size], tf.float32,
105 |         tf.contrib.layers.xavier_initializer())
106 |     bias = tf.get_variable(
107 |         'bias', [output_size], initializer=tf.constant_initializer(bias_start))
108 |     out = tf.matmul(x, matrix) + bias
109 |     return out
110 | 
111 | 
112 | def lrelu(x, leak=0.2, name='lrelu'):
113 |   """The leaky ReLU operation."""
114 |   with tf.variable_scope(name):
115 |     f1 = 0.5 * (1 + leak)
116 |     f2 = 0.5 * (1 - leak)
117 |     return f1 * x + f2 * abs(x)
118 | 
119 | 
120 | def _l2normalize(v, eps=1e-12):
121 |   """L2-normalizes the input vector."""
122 |   return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)
123 | 
124 | 
125 | def spectral_normed_weight(weights, num_iters=1, update_collection=None,
126 |                            with_sigma=False):
127 |   """Performs Spectral Normalization on a weight tensor.
128 | 
129 |   Specifically it divides the weight tensor by its largest singular value. This
130 |   is intended to stabilize GAN training, by making the discriminator satisfy a
131 |   local 1-Lipschitz constraint.
132 |   Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan]
133 |   [sn-gan] https://openreview.net/pdf?id=B1QRgziT-
134 | 
135 |   Args:
136 |     weights: The weight tensor which requires spectral normalization
137 |     num_iters: Number of SN iterations.
138 |     update_collection: The update collection for assigning the persisted
139 |                        variable u. If None, the function updates u during the
140 |                        forward pass. If update_collection equals 'NO_OPS', the
141 |                        function does not update u at all. This is useful for
142 |                        the discriminator, since it should not update u on its
143 |                        second pass (over generated samples).
144 |                        Otherwise, the assignment is placed in the collection
145 |                        named by the user, and the user then needs to run the
146 |                        assignment explicitly.
147 |     with_sigma: For debugging purposes. If True, the function also returns
148 |                 the estimated singular value for the weight tensor.
149 |   Returns:
150 |     w_bar: The normalized weight tensor
151 |     sigma: The estimated singular value for the weight tensor.
152 | """ 153 | w_shape = weights.shape.as_list() 154 | w_mat = tf.reshape(weights, [-1, w_shape[-1]]) # [-1, output_channel] 155 | u = tf.get_variable('u', [1, w_shape[-1]], 156 | initializer=tf.truncated_normal_initializer(), 157 | trainable=False) 158 | u_ = u 159 | for _ in range(num_iters): 160 | v_ = _l2normalize(tf.matmul(u_, w_mat, transpose_b=True)) 161 | u_ = _l2normalize(tf.matmul(v_, w_mat)) 162 | 163 | sigma = tf.squeeze(tf.matmul(tf.matmul(v_, w_mat), u_, transpose_b=True)) 164 | w_mat /= sigma 165 | if update_collection is None: 166 | with tf.control_dependencies([u.assign(u_)]): 167 | w_bar = tf.reshape(w_mat, w_shape) 168 | else: 169 | w_bar = tf.reshape(w_mat, w_shape) 170 | if update_collection != 'NO_OPS': 171 | tf.add_to_collection(update_collection, u.assign(u_)) 172 | if with_sigma: 173 | return w_bar, sigma 174 | else: 175 | return w_bar 176 | 177 | 178 | def snconv2d(input_, output_dim, 179 | k_h=3, k_w=3, d_h=2, d_w=2, 180 | sn_iters=1, update_collection=None, name='snconv2d'): 181 | """Creates a spectral normalized (SN) convolutional layer. 182 | 183 | Args: 184 | input_: 4D input tensor (batch size, height, width, channel). 185 | output_dim: Number of features in the output layer. 186 | k_h: The height of the convolutional kernel. 187 | k_w: The width of the convolutional kernel. 188 | d_h: The height stride of the convolutional kernel. 189 | d_w: The width stride of the convolutional kernel. 190 | sn_iters: The number of SN iterations. 191 | update_collection: The update collection used in spectral_normed_weight. 192 | name: The name of the variable scope. 193 | Returns: 194 | conv: The normalized tensor. 195 | 196 | """ 197 | with tf.variable_scope(name): 198 | w = tf.get_variable( 199 | 'w', [k_h, k_w, input_.get_shape()[-1], output_dim], 200 | initializer=tf.contrib.layers.xavier_initializer()) 201 | w_bar = spectral_normed_weight(w, num_iters=sn_iters, 202 | update_collection=update_collection) 203 | 204 | conv = tf.nn.conv2d(input_, w_bar, strides=[1, d_h, d_w, 1], padding='SAME') 205 | biases = tf.get_variable('biases', [output_dim], 206 | initializer=tf.zeros_initializer()) 207 | conv = tf.nn.bias_add(conv, biases) 208 | return conv 209 | 210 | 211 | def snlinear(x, output_size, bias_start=0.0, 212 | sn_iters=1, update_collection=None, name='snlinear'): 213 | """Creates a spectral normalized linear layer. 214 | 215 | Args: 216 | x: 2D input tensor (batch size, features). 217 | output_size: Number of features in output of layer. 218 | bias_start: The bias parameters are initialized to this value 219 | sn_iters: Number of SN iterations. 220 | update_collection: The update collection used in spectral_normed_weight 221 | name: Optional, variable scope to put the layer's parameters into 222 | Returns: 223 | The normalized tensor 224 | """ 225 | shape = x.get_shape().as_list() 226 | 227 | with tf.variable_scope(name): 228 | matrix = tf.get_variable( 229 | 'Matrix', [shape[1], output_size], tf.float32, 230 | tf.contrib.layers.xavier_initializer()) 231 | matrix_bar = spectral_normed_weight(matrix, num_iters=sn_iters, 232 | update_collection=update_collection) 233 | bias = tf.get_variable( 234 | 'bias', [output_size], initializer=tf.constant_initializer(bias_start)) 235 | out = tf.matmul(x, matrix_bar) + bias 236 | return out 237 | 238 | 239 | def sn_embedding(x, number_classes, embedding_size, sn_iters=1, 240 | update_collection=None, name='snembedding'): 241 | """Creates a spectral normalized embedding lookup layer. 
242 | 
243 |   Args:
244 |     x: 1D input tensor (batch size, ).
245 |     number_classes: The number of classes.
246 |     embedding_size: The length of the embedding vector for each class.
247 |     sn_iters: Number of SN iterations.
248 |     update_collection: The update collection used in spectral_normed_weight
249 |     name: Optional, variable scope to put the layer's parameters into
250 |   Returns:
251 |     The output tensor (batch size, embedding_size).
252 |   """
253 |   with tf.variable_scope(name):
254 |     embedding_map = tf.get_variable(
255 |         name='embedding_map',
256 |         shape=[number_classes, embedding_size],
257 |         initializer=tf.contrib.layers.xavier_initializer())
258 |     embedding_map_bar_transpose = spectral_normed_weight(
259 |         tf.transpose(embedding_map), num_iters=sn_iters,
260 |         update_collection=update_collection)
261 |     embedding_map_bar = tf.transpose(embedding_map_bar_transpose)
262 |     return tf.nn.embedding_lookup(embedding_map_bar, x)
263 | 
264 | 
265 | class ConditionalBatchNorm_old(object):
266 |   """Conditional BatchNorm.
267 | 
268 |   For each class, it has a specific gamma and beta as normalization variables.
269 |   """
270 | 
271 |   def __init__(self, num_categories, name='conditional_batch_norm', center=True,
272 |                scale=True):
273 |     with tf.variable_scope(name):
274 |       self.name = name
275 |       self.num_categories = num_categories
276 |       self.center = center
277 |       self.scale = scale
278 | 
279 |   def __call__(self, inputs, labels):
280 |     inputs = tf.convert_to_tensor(inputs)
281 |     inputs_shape = inputs.get_shape()
282 |     params_shape = inputs_shape[-1:]
283 |     axis = [0, 1, 2]
284 |     shape = tf.TensorShape([self.num_categories]).concatenate(params_shape)
285 | 
286 |     with tf.variable_scope(self.name):
287 |       self.gamma = tf.get_variable(
288 |           'gamma', shape,
289 |           initializer=tf.ones_initializer())
290 |       self.beta = tf.get_variable(
291 |           'beta', shape,
292 |           initializer=tf.zeros_initializer())
293 |       beta = tf.gather(self.beta, labels)
294 |       beta = tf.expand_dims(tf.expand_dims(beta, 1), 1)
295 |       gamma = tf.gather(self.gamma, labels)
296 |       gamma = tf.expand_dims(tf.expand_dims(gamma, 1), 1)
297 |       mean, variance = tf.nn.moments(inputs, axis, keep_dims=True)
298 |       variance_epsilon = 1E-5
299 |       outputs = tf.nn.batch_normalization(
300 |           inputs, mean, variance, beta, gamma, variance_epsilon)
301 |       outputs.set_shape(inputs_shape)
302 |       return outputs
303 | 
304 | class ConditionalBatchNorm(object):
305 |   """Conditional BatchNorm.
306 | 
307 |   For each class, it has a specific gamma and beta as normalization variables.
308 | """ 309 | 310 | def __init__(self, num_categories, name='conditional_batch_norm', decay_rate=0.999, center=True, 311 | scale=True): 312 | with tf.variable_scope(name): 313 | self.name = name 314 | self.num_categories = num_categories 315 | self.center = center 316 | self.scale = scale 317 | self.decay_rate = decay_rate 318 | 319 | def __call__(self, inputs, labels, is_training=True): 320 | inputs = tf.convert_to_tensor(inputs) 321 | inputs_shape = inputs.get_shape() 322 | params_shape = inputs_shape[-1:] 323 | axis = [0, 1, 2] 324 | shape = tf.TensorShape([self.num_categories]).concatenate(params_shape) 325 | moving_shape = tf.TensorShape([1, 1, 1]).concatenate(params_shape) 326 | 327 | with tf.variable_scope(self.name): 328 | self.gamma = tf.get_variable( 329 | 'gamma', shape, 330 | initializer=tf.ones_initializer()) 331 | self.beta = tf.get_variable( 332 | 'beta', shape, 333 | initializer=tf.zeros_initializer()) 334 | self.moving_mean = tf.get_variable('mean', moving_shape, 335 | initializer=tf.zeros_initializer(), 336 | trainable=False) 337 | self.moving_var = tf.get_variable('var', moving_shape, 338 | initializer=tf.ones_initializer(), 339 | trainable=False) 340 | 341 | beta = tf.gather(self.beta, labels) 342 | beta = tf.expand_dims(tf.expand_dims(beta, 1), 1) 343 | gamma = tf.gather(self.gamma, labels) 344 | gamma = tf.expand_dims(tf.expand_dims(gamma, 1), 1) 345 | decay = self.decay_rate 346 | variance_epsilon = 1E-5 347 | if is_training: 348 | mean, variance = tf.nn.moments(inputs, axis, keep_dims=True) 349 | update_mean = tf.assign(self.moving_mean, self.moving_mean * decay + mean * (1 - decay)) 350 | update_var = tf.assign(self.moving_var, self.moving_var * decay + variance * (1 - decay)) 351 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean) 352 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var) 353 | #with tf.control_dependencies([update_mean, update_var]): 354 | outputs = tf.nn.batch_normalization( 355 | inputs, mean, variance, beta, gamma, variance_epsilon) 356 | else: 357 | outputs = tf.nn.batch_normalization( 358 | inputs, self.moving_mean, self.moving_var, beta, gamma, variance_epsilon) 359 | outputs.set_shape(inputs_shape) 360 | return outputs 361 | 362 | class batch_norm(object): 363 | def __init__(self, epsilon=1e-5, momentum = 0.9999, name="batch_norm"): 364 | with tf.variable_scope(name): 365 | self.epsilon = epsilon 366 | self.momentum = momentum 367 | self.name = name 368 | 369 | def __call__(self, x, train=True): 370 | return tf.contrib.layers.batch_norm(x, 371 | decay=self.momentum, 372 | # updates_collections=None, 373 | epsilon=self.epsilon, 374 | scale=True, 375 | is_training=train, 376 | scope=self.name) 377 | 378 | class BatchNorm(object): 379 | """The Batch Normalization layer.""" 380 | 381 | def __init__(self, name='batch_norm', center=True, 382 | scale=True): 383 | with tf.variable_scope(name): 384 | self.name = name 385 | self.center = center 386 | self.scale = scale 387 | 388 | def __call__(self, inputs): 389 | inputs = tf.convert_to_tensor(inputs) 390 | inputs_shape = inputs.get_shape().as_list() 391 | params_shape = inputs_shape[-1] 392 | axis = [0, 1, 2] 393 | shape = tf.TensorShape([params_shape]) 394 | with tf.variable_scope(self.name): 395 | self.gamma = tf.get_variable( 396 | 'gamma', shape, 397 | initializer=tf.ones_initializer()) 398 | self.beta = tf.get_variable( 399 | 'beta', shape, 400 | initializer=tf.zeros_initializer()) 401 | beta = self.beta 402 | gamma = self.gamma 403 | 404 | mean, variance = 
tf.nn.moments(inputs, axis, keep_dims=True) 405 | variance_epsilon = 1E-5 406 | outputs = tf.nn.batch_normalization( 407 | inputs, mean, variance, beta, gamma, variance_epsilon) 408 | outputs.set_shape(inputs_shape) 409 | return outputs 410 | 411 | @contextmanager 412 | def variables_on_gpu0(): 413 | old_fn = tf.get_variable 414 | def new_fn(*args, **kwargs): 415 | with tf.device('/gpu:0'): 416 | return old_fn(*args, **kwargs) 417 | tf.get_variable = new_fn 418 | yield 419 | tf.get_variable = old_fn 420 | 421 | 422 | def avg_grads(tower_grads): 423 | """Calculate the average gradient for each shared variable across all towers. 424 | 425 | Note that this function provides a synchronization point across all towers. 426 | 427 | Args: 428 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 429 | is over individual gradients. The inner list is over the gradient 430 | calculation for each tower. 431 | Returns: 432 | List of pairs of (gradient, variable) where the gradient has been averaged 433 | across all towers. 434 | """ 435 | average_grads = [] 436 | for grad_and_vars in zip(*tower_grads): 437 | # Note that each grad_and_vars looks like the following: 438 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 439 | grads = [] 440 | for g, _ in grad_and_vars: 441 | # Add 0 dimension to the gradients to represent the tower. 442 | expanded_g = tf.expand_dims(g, 0) 443 | 444 | # Append on a 'tower' dimension which we will average over below. 445 | grads.append(expanded_g) 446 | 447 | # Average over the 'tower' dimension. 448 | grad = tf.concat(grads, 0) 449 | grad = tf.reduce_mean(grad, 0) 450 | 451 | # Keep in mind that the Variables are redundant because they are shared 452 | # across towers. So .. we will just return the first tower's pointer to 453 | # the Variable. 454 | v = grad_and_vars[0][1] 455 | grad_and_var = (grad, v) 456 | average_grads.append(grad_and_var) 457 | return average_grads -------------------------------------------------------------------------------- /train_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Train Imagenet.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | 23 | from absl import flags 24 | import tensorflow as tf 25 | 26 | import utils_ori as utils 27 | import model 28 | 29 | 30 | 31 | tfgan = tf.contrib.gan 32 | gfile = tf.gfile 33 | 34 | 35 | flags.DEFINE_string('master', 'local', 36 | 'BNS name of the TensorFlow master to use. [local]') 37 | 38 | # flags.DEFINE_string('checkpoint_dir', '/usr/local/google/home/zhanghan/Documents/Research/model_output', 39 | # 'Directory name to save the checkpoints. 
[checkpoint]')
40 | flags.DEFINE_string('checkpoint_dir', 'checkpoint',
41 |                     'Directory name to save the checkpoints. [checkpoint]')
42 | flags.DEFINE_integer('batch_size', 64, 'Number of images in input batch. [64]')  # originally 16
43 | flags.DEFINE_integer('shuffle_buffer_size', 100000, 'Number of records to load '
44 |                      'before shuffling and yielding for consumption. [100000]')
45 | flags.DEFINE_integer('save_summaries_steps', 200, 'Number of steps between '
46 |                      'saving summary statistics. [200]')  # previously 300
47 | flags.DEFINE_integer('save_checkpoint_secs', 1200, 'Number of seconds between '
48 |                      'saving checkpoints of model. [1200]')
49 | flags.DEFINE_boolean('is_train', True, 'True for training. [default: True]')
50 | flags.DEFINE_boolean('is_gd_equal', True, 'True for a 1:1 G:D update ratio, False for 1:5')
51 | 
52 | # TODO(olganw) Find the best way to clean up these flags for eval and train.
53 | flags.DEFINE_integer('task', 0, 'The task id of the current worker. [0]')
54 | flags.DEFINE_integer('ps_tasks', 0, 'The number of ps tasks. [0]')
55 | flags.DEFINE_integer('num_workers', 1, 'The number of worker tasks. [1]')
56 | flags.DEFINE_integer('replicas_to_aggregate', 1, 'The number of replicas '
57 |                      'to aggregate for synchronous optimization [1]')
58 | flags.DEFINE_boolean('sync_replicas', True, 'Whether to sync replicas. [True]')
59 | flags.DEFINE_integer('num_towers', 4, 'The number of GPUs to use per task. [4]')
60 | flags.DEFINE_integer('d_step', 1, 'The number of D_step')
61 | flags.DEFINE_integer('g_step', 1, 'The number of G_step')
62 | 
63 | # flags.DEFINE_integer('z_dim', 128, 'The dimension of z')
64 | 
65 | FLAGS = flags.FLAGS
66 | 
67 | 
68 | 
69 | def main(_, is_test=False):
70 |   print('d_learning_rate', FLAGS.discriminator_learning_rate)
71 |   print('g_learning_rate', FLAGS.generator_learning_rate)
72 |   print('data_dir', FLAGS.data_dir)
73 |   print(FLAGS.loss_type, FLAGS.batch_size, FLAGS.beta1)
74 |   print('gf_df_dim', FLAGS.gf_dim, FLAGS.df_dim)
75 |   print('Starting the program...')
76 |   gfile.MakeDirs(FLAGS.checkpoint_dir)
77 | 
78 |   model_dir = '%s_%s' % ('imagenet', FLAGS.batch_size)
79 |   logdir = os.path.join(FLAGS.checkpoint_dir, model_dir)
80 |   gfile.MakeDirs(logdir)
81 | 
82 |   graph = tf.Graph()
83 |   with graph.as_default():
84 | 
85 |     with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
86 |       # Instantiate global_step.
87 |       global_step = tf.train.create_global_step()
88 | 
89 |       # Create model with FLAGS, global_step, and devices.
90 |       devices = ['/gpu:{}'.format(tower) for tower in range(FLAGS.num_towers)]
91 | 
92 |       # Create noise tensors
93 |       zs = utils.make_z_normal(
94 |           FLAGS.num_towers, FLAGS.batch_size, FLAGS.z_dim)
95 | 
96 |       print('save_summaries_steps', FLAGS.save_summaries_steps)
97 | 
98 |       dcgan = model.SNGAN(
99 |           zs=zs,
100 |           config=FLAGS,
101 |           global_step=global_step,
102 |           devices=devices)
103 | 
104 |     with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
105 |       # Create sync_hooks when needed.
106 | if FLAGS.sync_replicas and FLAGS.num_workers > 1: 107 | print('condition 1') 108 | sync_hooks = [ 109 | dcgan.d_opt.make_session_run_hook(FLAGS.task == 0), 110 | dcgan.g_opt.make_session_run_hook(FLAGS.task == 0) 111 | ] 112 | else: 113 | print('condition 2') 114 | sync_hooks = [] 115 | 116 | train_ops = tfgan.GANTrainOps( 117 | generator_train_op=dcgan.g_optim, 118 | discriminator_train_op=dcgan.d_optim, 119 | global_step_inc_op=dcgan.increment_global_step) 120 | 121 | 122 | # We set allow_soft_placement to be True because Saver for the DCGAN model 123 | # gets misplaced on the GPU. 124 | session_config = tf.ConfigProto( 125 | allow_soft_placement=True, log_device_placement=False) 126 | 127 | if is_test: 128 | return graph 129 | 130 | print("G step: ", FLAGS.g_step) 131 | print("D_step: ", FLAGS.d_step) 132 | train_steps = tfgan.GANTrainSteps(FLAGS.g_step, FLAGS.d_step) 133 | 134 | tfgan.gan_train( 135 | train_ops, 136 | get_hooks_fn=tfgan.get_sequential_train_hooks( 137 | train_steps=train_steps), 138 | hooks=([tf.train.StopAtStepHook(num_steps=2000000)] + sync_hooks), 139 | logdir=logdir, 140 | # master=FLAGS.master, 141 | # scaffold=scaffold, # load from google checkpoint 142 | is_chief=(FLAGS.task == 0), 143 | save_summaries_steps=FLAGS.save_summaries_steps, 144 | save_checkpoint_secs=FLAGS.save_checkpoint_secs, 145 | config=session_config) 146 | 147 | 148 | if __name__ == '__main__': 149 | tf.app.run() 150 | -------------------------------------------------------------------------------- /utils_ori.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Some codes from https://github.com/Newmu/dcgan_code.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import math 23 | import os 24 | import numpy as np 25 | import scipy.misc 26 | import sympy 27 | import tensorflow as tf 28 | 29 | tfgan = tf.contrib.gan 30 | classifier_metrics = tf.contrib.gan.eval.classifier_metrics 31 | gfile = tf.gfile 32 | 33 | 34 | 35 | 36 | def make_z_normal(num_batches, batch_size, z_dim): 37 | """Make random noises tensors with normal distribution feeding into the generator 38 | Args: 39 | num_batches: copies of batches 40 | batch_size: the batch_size for z 41 | z_dim: The dimension of the z (noise) vector. 42 | Returns: 43 | zs: noise tensors. 
44 | """ 45 | shape = [num_batches, batch_size, z_dim] 46 | z = tf.random_normal(shape, name='z0', dtype=tf.float32) 47 | return z 48 | 49 | 50 | def run_custom_inception( 51 | images, 52 | output_tensor, 53 | graph_def=None, 54 | # image_size=classifier_metrics.INCEPTION_DEFAULT_IMAGE_SIZE): 55 | image_size=299): 56 | # input_tensor=classifier_metrics.INCEPTION_V1_INPUT): 57 | """Run images through a pretrained Inception classifier. 58 | 59 | This method resizes the images before feeding them through Inception; we do 60 | this to accommodate feeding images through in minibatches without having to 61 | construct any large tensors. 62 | 63 | Args: 64 | images: Input tensors. Must be [batch, height, width, channels]. Input shape 65 | and values must be in [-1, 1], which can be achieved using 66 | `preprocess_image`. 67 | output_tensor: Name of output Tensor. This function will compute activations 68 | at the specified layer. Examples include INCEPTION_V3_OUTPUT and 69 | INCEPTION_V3_FINAL_POOL which would result in this function computing 70 | the final logits or the penultimate pooling layer. 71 | graph_def: A GraphDef proto of a pretrained Inception graph. If `None`, 72 | call `default_graph_def_fn` to get GraphDef. 73 | image_size: Required image width and height. See unit tests for the default 74 | values. 75 | input_tensor: Name of input Tensor. 76 | 77 | Returns: 78 | Logits. 79 | """ 80 | 81 | images = tf.image.resize_bilinear(images, [image_size, image_size]) 82 | 83 | return tfgan.eval.run_inception( 84 | images, 85 | graph_def=graph_def, 86 | image_size=image_size, 87 | # input_tensor=input_tensor, 88 | output_tensor=output_tensor) 89 | 90 | 91 | def get_real_activations(data_dir, 92 | batch_size, 93 | num_batches, 94 | label_offset=0, 95 | cycle_length=1, 96 | shuffle_buffer_size=100000): 97 | """Fetches num_batches batches of size batch_size from the data_dir. 98 | 99 | Args: 100 | data_dir: The directory to read data from. Expected to be a single 101 | TFRecords file. 102 | batch_size: The number of elements in a single minibatch. 103 | num_batches: The number of batches to fetch at a time. 104 | label_offset: The scalar to add to the labels in the dataset. The imagenet 105 | GAN code expects labels in [0, 999], and this scalar can be used to move 106 | other labels into this range. (Default: 0) 107 | cycle_length: The number of input elements to process concurrently in the 108 | Dataset loader. (Default: 1) 109 | shuffle_buffer_size: The number of records to load before shuffling. Larger 110 | means more likely randomization. (Default: 100000) 111 | Returns: 112 | A list of num_batches batches of size batch_size. 
113 | """ 114 | # filenames = gfile.Glob(os.path.join(data_dir, '*_train_*-*-of-*')) 115 | 116 | filenames = tf.gfile.Glob(os.path.join(data_dir, '*.tfrecords')) 117 | filename_dataset = tf.data.Dataset.from_tensor_slices(filenames) 118 | filename_dataset = filename_dataset.shuffle(len(filenames)) 119 | prefetch = max(int((batch_size * num_batches) / cycle_length), 1) 120 | dataset = filename_dataset.interleave( 121 | lambda fn: tf.data.TFRecordDataset(fn).prefetch(prefetch), 122 | cycle_length=cycle_length) 123 | 124 | dataset = dataset.shuffle(shuffle_buffer_size) 125 | image_size = 128 126 | # graph_def = classifier_metrics._default_graph_def_fn() # pylint: disable=protected-access 127 | 128 | def _extract_image_and_label(record): 129 | """Extracts and preprocesses the image and label from the record.""" 130 | features = tf.parse_single_example( 131 | record, 132 | features={ 133 | 'image_raw': tf.FixedLenFeature([], tf.string), 134 | 'label': tf.FixedLenFeature([], tf.int64) 135 | }) 136 | 137 | image = tf.decode_raw(features['image_raw'], tf.uint8) 138 | image.set_shape(image_size * image_size * 3) 139 | image = tf.reshape(image, [image_size, image_size, 3]) 140 | 141 | image = tf.cast(image, tf.float32) * (2. / 255) - 1. 142 | 143 | label = tf.cast(features['label'], tf.int32) 144 | label += label_offset 145 | 146 | return image, label 147 | 148 | dataset = dataset.map( 149 | _extract_image_and_label, 150 | num_parallel_calls=16).prefetch(batch_size * num_batches) 151 | dataset = dataset.batch(batch_size) 152 | iterator = dataset.make_one_shot_iterator() 153 | 154 | real_images, _ = iterator.get_next() 155 | real_images.set_shape([batch_size, image_size, image_size, 3]) 156 | 157 | pools = run_custom_inception( 158 | real_images, graph_def=None, output_tensor=['pool_3:0'])[0] 159 | 160 | def while_cond(_, i): 161 | return tf.less(i, num_batches) 162 | 163 | def while_body(real_pools, i): 164 | with tf.control_dependencies([real_pools]): 165 | imgs, _ = iterator.get_next() 166 | imgs.set_shape([batch_size, image_size, image_size, 3]) 167 | pools = run_custom_inception( 168 | imgs, graph_def=None, output_tensor=['pool_3:0'])[0] 169 | real_pools = tf.concat([real_pools, pools], 0) 170 | return (real_pools, tf.add(i, 1)) 171 | 172 | # Get activations from real images. 173 | i = tf.constant(1) 174 | real_pools, _ = tf.while_loop( 175 | while_cond, 176 | while_body, [pools, i], 177 | shape_invariants=[tf.TensorShape([None, 2048]), 178 | i.get_shape()], 179 | parallel_iterations=1, 180 | back_prop=False, 181 | swap_memory=True, 182 | name='RealActivations') 183 | 184 | real_pools.set_shape([batch_size * num_batches, 2048]) 185 | 186 | return real_pools, real_images 187 | 188 | 189 | def get_imagenet_batches(data_dir, 190 | batch_size, 191 | num_batches, 192 | label_offset=0, 193 | cycle_length=1, 194 | shuffle_buffer_size=100000): 195 | """Fetches num_batches batches of size batch_size from the data_dir. 196 | 197 | Args: 198 | data_dir: The directory to read data from. Expected to be a single 199 | TFRecords file. 200 | batch_size: The number of elements in a single minibatch. 201 | num_batches: The number of batches to fetch at a time. 202 | label_offset: The scalar to add to the labels in the dataset. The imagenet 203 | GAN code expects labels in [0, 999], and this scalar can be used to move 204 | other labels into this range. (Default: 0) 205 | cycle_length: The number of input elements to process concurrently in the 206 | Dataset loader. 
(Default: 1) 207 | shuffle_buffer_size: The number of records to load before shuffling. Larger 208 | means more likely randomization. (Default: 100000) 209 | Returns: 210 | A list of num_batches batches of size batch_size. 211 | """ 212 | # filenames = gfile.Glob(os.path.join(data_dir, '*_train_*-*-of-*')) 213 | filenames = tf.gfile.Glob(os.path.join(data_dir, '*.tfrecords')) 214 | filename_dataset = tf.data.Dataset.from_tensor_slices(filenames) 215 | filename_dataset = filename_dataset.shuffle(len(filenames)) 216 | prefetch = max(int((batch_size * num_batches) / cycle_length), 1) 217 | dataset = filename_dataset.interleave( 218 | lambda fn: tf.data.TFRecordDataset(fn).prefetch(prefetch), 219 | cycle_length=cycle_length) 220 | 221 | dataset = dataset.shuffle(shuffle_buffer_size) 222 | image_size = 128 223 | 224 | def _extract_image_and_label(record): 225 | """Extracts and preprocesses the image and label from the record.""" 226 | features = tf.parse_single_example( 227 | record, 228 | features={ 229 | 'image_raw': tf.FixedLenFeature([], tf.string), 230 | 'label': tf.FixedLenFeature([], tf.int64) 231 | }) 232 | 233 | image = tf.decode_raw(features['image_raw'], tf.uint8) 234 | image.set_shape(image_size * image_size * 3) 235 | image = tf.reshape(image, [image_size, image_size, 3]) 236 | 237 | image = tf.cast(image, tf.float32) * (2. / 255) - 1. 238 | 239 | label = tf.cast(features['label'], tf.int32) 240 | label += label_offset 241 | 242 | return image, label 243 | 244 | dataset = dataset.map( 245 | _extract_image_and_label, 246 | num_parallel_calls=16).prefetch(batch_size * num_batches) 247 | dataset = dataset.repeat() # Repeat for unlimited epochs. 248 | dataset = dataset.batch(batch_size) 249 | dataset = dataset.batch(num_batches) 250 | 251 | iterator = dataset.make_one_shot_iterator() 252 | images, labels = iterator.get_next() 253 | 254 | batches = [] 255 | for i in range(num_batches): 256 | # Dataset batches lose shape information. Put it back in. 257 | im = images[i, ...] 258 | im.set_shape([batch_size, image_size, image_size, 3]) 259 | 260 | lb = labels[i, ...] 261 | lb.set_shape((batch_size,)) 262 | 263 | batches.append((im, tf.expand_dims(lb, 1))) 264 | 265 | return batches 266 | 267 | 268 | def save_images(images, size, image_path): 269 | return imsave(inverse_transform(images), size, image_path) 270 | 271 | 272 | def merge(images, size): 273 | h, w = images.shape[1], images.shape[2] 274 | img = np.zeros((h * size[0], w * size[1], 3)) 275 | 276 | idx = 0 277 | for i in range(0, size[0]): 278 | for j in range(0, size[1]): 279 | img[j * h:(j + 1) * h, i * w:(i + 1) * w, :] = images[idx] 280 | idx += 1 281 | return img 282 | 283 | 284 | def imsave(images, size, path): 285 | with gfile.Open(path, mode='w') as f: 286 | saved = scipy.misc.imsave(f, merge(images, size)) 287 | return saved 288 | 289 | 290 | def inverse_transform(images): 291 | return (images + 1.) / 2. 
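# --- Illustrative usage sketch (added note, not part of the original file) ---
# The three helpers above form a small pipeline: generator outputs live in
# [-1, 1], inverse_transform rescales them to [0, 1], merge tiles the batch
# into one (rows*h, cols*w, 3) canvas, and imsave writes it to disk via
# save_images. The random `fake_batch` below is only a stand-in for real
# generator output:
#
#   import numpy as np
#   fake_batch = np.random.uniform(-1., 1., size=(64, 128, 128, 3))
#   save_images(fake_batch, [8, 8], '/tmp/sample_grid.png')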
292 | 293 | 294 | def visualize(sess, dcgan, config, option): 295 | option = 0 296 | if option == 0: 297 | all_samples = [] 298 | for i in range(484): 299 | print(i) 300 | samples = sess.run(dcgan.generator) 301 | all_samples.append(samples) 302 | samples = np.concatenate(all_samples, 0) 303 | n = int(np.sqrt(samples.shape[0])) 304 | m = samples.shape[0] // n 305 | save_images(samples, [m, n], './' + config.sample_dir + '/test.png') 306 | elif option == 1: 307 | counter = 0 308 | coord = tf.train.Coordinator() 309 | tf.train.start_queue_runners(coord=coord) 310 | while counter < 1005: 311 | print(counter) 312 | samples, fake = sess.run([dcgan.generator, dcgan.d_loss_class]) 313 | fake = np.argsort(fake) 314 | print(np.sum(samples)) 315 | print(fake) 316 | for i in range(samples.shape[0]): 317 | name = '%s%d.png' % (chr(ord('a') + counter % 10), counter) 318 | img = np.expand_dims(samples[fake[i]], 0) 319 | if counter >= 1000: 320 | save_images(img, [1, 1], './{}/turk/fake{}.png'.format( 321 | config.sample_dir, counter - 1000)) 322 | else: 323 | save_images(img, [1, 1], './{}/turk/{}'.format( 324 | config.sample_dir, name)) 325 | counter += 1 326 | elif option == 2: 327 | values = np.arange(0, 1, 1. / config.batch_size) 328 | for idx in range(100): 329 | print(' [*] %d' % idx) 330 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 331 | for kdx, z in enumerate(z_sample): 332 | z[idx] = values[kdx] 333 | 334 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 335 | save_images(samples, [8, 8], './{}/test_arange_{}.png'.format( 336 | config.sample_dir, idx)) 337 | 338 | 339 | def squarest_grid_size(num_images): 340 | """Calculates the size of the most square grid for num_images. 341 | 342 | Calculates the largest integer divisor of num_images less than or equal to 343 | sqrt(num_images) and returns that as the width. The height is 344 | num_images / width. 345 | 346 | Args: 347 | num_images: The total number of images. 348 | 349 | Returns: 350 | A tuple of (height, width) for the image grid. 351 | """ 352 | divisors = sympy.divisors(num_images) 353 | square_root = math.sqrt(num_images) 354 | width = 1 355 | for d in divisors: 356 | if d > square_root: 357 | break 358 | width = d 359 | return (num_images // width, width) 360 | --------------------------------------------------------------------------------
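The power iteration inside ops.spectral_normed_weight is easy to sanity-check outside TensorFlow. The sketch below is an added illustration, not part of the repository: it re-implements the same u/v updates in NumPy (the name estimate_sigma is ours) and compares the estimate against the exact largest singular value from np.linalg.svd. Note that the TF code gets away with num_iters=1 per training step only because the variable u persists across steps; a standalone run needs more iterations to converge.

import numpy as np

def estimate_sigma(w_mat, num_iters=50):
  """Estimates the largest singular value of w_mat by power iteration."""
  u = np.random.randn(1, w_mat.shape[1])  # plays the role of the persisted 'u'
  for _ in range(num_iters):
    v = u @ w_mat.T
    v /= np.linalg.norm(v) + 1e-12        # same role as _l2normalize
    u = v @ w_mat
    u /= np.linalg.norm(u) + 1e-12
  # sigma = v W u^T, mirroring tf.matmul(tf.matmul(v_, w_mat), u_, transpose_b=True)
  return (v @ w_mat @ u.T).item()

w = np.random.randn(3 * 3 * 64, 128)  # e.g. a 3x3x64x128 conv kernel reshaped to 2D
print(estimate_sigma(w))                           # close to ...
print(np.linalg.svd(w, compute_uv=False)[0])       # ... the true top singular value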