├── .gitignore ├── CONTRIBUTING ├── LICENSE ├── OWNERS ├── README.md ├── __init__.py ├── autoaugment ├── __init__.py ├── autoaugment.py ├── autoaugment_test.py └── policies.py ├── figures ├── no_sam.png ├── sam_wide.png └── summary_plot.png ├── requirements.txt └── sam_jax ├── __init__.py ├── datasets ├── __init__.py ├── augmentation.py ├── dataset_source.py ├── dataset_source_imagenet.py ├── dataset_source_imagenet_test.py └── dataset_source_test.py ├── efficientnet ├── __init__.py ├── efficientnet.py ├── efficientnet_test.py ├── optim.py └── optim_test.py ├── imagenet_models ├── __init__.py ├── load_model.py ├── load_model_test.py ├── resnet.py └── resnet_test.py ├── models ├── __init__.py ├── load_model.py ├── load_model_test.py ├── pyramidnet.py ├── utils.py ├── wide_resnet.py └── wide_resnet_shakeshake.py ├── train.py └── training_utils ├── __init__.py ├── flax_training.py └── flax_training_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/BUILD 2 | -------------------------------------------------------------------------------- /CONTRIBUTING: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
-------------------------------------------------------------------------------- /OWNERS: -------------------------------------------------------------------------------- 1 | pierreforet 2 | akleiner 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAM: Sharpness-Aware Minimization for Efficiently Improving Generalization 2 | 3 | by Pierre Foret, Ariel Kleiner, Hossein Mobahi and Behnam Neyshabur. 4 | 5 | 6 | ## SAM in a few words 7 | 8 | **Abstract**: In today's heavily overparameterized models, the value of the training loss provides few guarantees on model generalization ability. Indeed, optimizing only the training loss value, as is commonly done, can easily lead to suboptimal model quality. Motivated by the connection between geometry of the loss landscape and generalization---including a generalization bound that we prove here---we introduce a novel, effective procedure for instead simultaneously minimizing loss value and loss sharpness. In particular, our procedure, Sharpness-Aware Minimization (SAM), seeks parameters that lie in neighborhoods having uniformly low loss; this formulation results in a min-max optimization problem on which gradient descent can be performed efficiently. We present empirical results showing that SAM improves model generalization across a variety of benchmark datasets (e.g., CIFAR-{10, 100}, ImageNet, finetuning tasks) and models, yielding novel state-of-the-art performance for several. Additionally, we find that SAM natively provides robustness to label noise on par with that provided by state-of-the-art procedures that specifically target learning with noisy labels. 9 | 10 | 11 | | ![fig](figures/summary_plot.png) | ![fig](figures/no_sam.png) | ![fig](figures/sam_wide.png) | 12 | |:--------------:|:----------:|:----------------------:| 13 | | Error rate reduction obtained by switching to SAM. Each point is a different dataset / model / data augmentation. | A sharp minimum to which a ResNet trained with SGD converged. | A wide minimum to which the same ResNet trained with SAM converged. | 14 | 15 | 16 | 17 | ## About this repo 18 | 19 | This code allows the user to replicate most of the experiments of the paper, including: 20 | 21 | * Training WideResNets and PyramidNets (with Shake-Shake / ShakeDrop) from scratch on CIFAR-10/CIFAR-100/SVHN/Fashion-MNIST, with or without SAM, cutout and AutoAugment. 22 | * Training ResNets and EfficientNets on ImageNet, with or without SAM and RandAugment. 23 | * Finetuning EfficientNets on ImageNet, starting from checkpoints pretrained on ImageNet/JFT. 24 | 25 | 26 | ## How to train from scratch 27 | 28 | Once the repo is cloned, experiments can be launched with the `sam.sam_jax.train` module: 29 | 30 | ``` 31 | python3 -m sam.sam_jax.train --dataset cifar10 --model_name WideResnet28x10 \ 32 | --output_dir /tmp/my_experiment --image_level_augmentations autoaugment \ 33 | --num_epochs 1800 --sam_rho 0.05 34 | ``` 35 | 36 | Note that our code uses all available GPUs/TPUs for training. 37 | 38 | To see a detailed list of all available flags, run `python3 -m sam.sam_jax.train --help`. 39 | 40 | #### Output 41 | 42 | Training curves can be loaded using TensorBoard. TensorBoard events will be 43 | saved in the `output_dir`, and their path will contain the learning rate, 44 | the weight decay, rho and the random seed. 
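The update rule is easy to state in code. Below is a minimal, illustrative sketch of one SAM step in plain JAX, following the description above: take an ascent step of size rho toward the worst-case neighbor, then descend using the gradient measured there. The names (`loss_fn`, `params`, `rho`, `lr`) are placeholders of ours, and this sketch is not the implementation this repository uses (which also handles optimizers, schedules and multi-device training):

```
import jax
import jax.numpy as jnp

def sam_sgd_step(loss_fn, params, rho=0.05, lr=0.1):
  """One SAM step: ascend to the worst-case neighbor, then descend."""
  # Gradient at the current point gives the ascent direction.
  grads = jax.grad(loss_fn)(params)
  # Global l2 norm across the whole parameter pytree.
  grad_norm = jnp.sqrt(sum(jnp.vdot(g, g)
                           for g in jax.tree_util.tree_leaves(grads)))
  # Worst-case neighbor: a step of size rho along the normalized gradient.
  noised = jax.tree_util.tree_map(
      lambda p, g: p + rho * g / (grad_norm + 1e-12), params, grads)
  # The gradient taken at the perturbed point drives the actual update.
  sam_grads = jax.grad(loss_fn)(noised)
  return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, sam_grads)
```

Each step therefore costs two forward/backward passes; the `--sam_rho` flag above controls the neighborhood radius.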
45 | 46 | ## Finetuning EfficientNet: 47 | 48 | We provide a FLAX checkpoint compatible with our implementation for all checkpoints 49 | available on [official tensorflow efficientnet implementation](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet): 50 | 51 | | | B0 | B1 | B2 | B3 | B4 | B5 | B6 | B7 | B8 | L2 | 52 | |---------- |-------- | ------| ------|------ |------ |------ | --- | --- | --- | --- | 53 | | Baseline preprocessing | 76.7% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b0/checkpoint.tar.gz)) | 78.7% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b1/checkpoint.tar.gz)) | 79.8% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b2/checkpoint.tar.gz)) | 81.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b3/checkpoint.tar.gz)) | 82.5% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b4/checkpoint.tar.gz)) | 83.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b5/checkpoint.tar.gz)) | | || | | 54 | | AutoAugment (AA) | 77.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b0/checkpoint.tar.gz)) | 79.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b1/checkpoint.tar.gz)) | 80.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b2/checkpoint.tar.gz)) | 81.6% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b3/checkpoint.tar.gz)) | 82.9% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b4/checkpoint.tar.gz)) | 83.6% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b5/checkpoint.tar.gz)) | 84.0% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b6/checkpoint.tar.gz)) | 84.3% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b7/checkpoint.tar.gz)) || | 55 | | RandAugment (RA) | | | | | | 83.7% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/randaug/efficientnet-b5-randaug/checkpoint.tar.gz)) | | 84.7% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/randaug/efficientnet-b7-randaug/checkpoint.tar.gz)) | | | 56 | | AdvProp + AA | 77.6% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b0/checkpoint.tar.gz)) | 79.6% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b1/checkpoint.tar.gz)) | 80.5% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b2/checkpoint.tar.gz)) | 81.9% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b3/checkpoint.tar.gz)) | 83.3% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b4/checkpoint.tar.gz)) | 84.3% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b5/checkpoint.tar.gz)) | 84.8% 
([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b6/checkpoint.tar.gz)) | 85.2% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b7/checkpoint.tar.gz)) | 85.5% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b8/checkpoint.tar.gz))|| | 57 | | NoisyStudent + RA | 78.8% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b0/checkpoint.tar.gz)) | 81.5% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b1/checkpoint.tar.gz)) | 82.4% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b2/checkpoint.tar.gz)) | 84.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b3/checkpoint.tar.gz)) | 85.3% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b4/checkpoint.tar.gz)) | 86.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b5/checkpoint.tar.gz)) | 86.4% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b6/checkpoint.tar.gz)) | 86.9% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b7/checkpoint.tar.gz)) | - |88.4% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-l2/checkpoint.tar.gz)) | 58 | 59 | * We report in this table the scores as found [here](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet). If you use these checkpoints with another codebase, you might obtain some slightly different scores based on the type of accelerator you use and your data processing pipeline (resizing algorithm in particular). 60 | 61 | * Advprop requires some slight modification of the input pipeline. See [here](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) for more details. 62 | 63 | * Please refer to the README of the [official tensorflow efficientnet implementation](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) to see which paper should be cited for which checkpoint. 64 | 65 | Once the checkpoint is downloaded and uncompressed, it can be finetuned on any of the datasets: 66 | 67 | ``` 68 | python3 -m sam.sam_jax.train --output_dir /tmp/my_finetuning_experiment \ 69 | --dataset imagenet --from_pretrained_checkpoint true \ 70 | --efficientnet_checkpoint_path /tmp/path_to_efficientnet_checkpoint_to_finetune \ 71 | --learning_rate 0.1 --model_name efficientnet-l2-475 --batch_size 512 \ 72 | --num_epochs 10 --gradient_clipping 1.0 --label_smoothing 0.1 --sam_rho 0.1 73 | ``` 74 | 75 | 76 | ## Bibtex 77 | 78 | ``` 79 | @ARTICLE{2020arXiv201001412F, 80 | author = {{Foret}, Pierre and {Kleiner}, Ariel and {Mobahi}, Hossein and {Neyshabur}, Behnam}, 81 | title = "{Sharpness-Aware Minimization for Efficiently Improving Generalization}", 82 | year = 2020, 83 | eid = {arXiv:2010.01412}, 84 | eprint = {2010.01412}, 85 | } 86 | ``` 87 | 88 | **This is not an official Google product.** 89 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /autoaugment/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /autoaugment/autoaugment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """AutoAugment and RandAugment policies for enhanced image preprocessing. 16 | 17 | AutoAugment Reference: https://arxiv.org/abs/1805.09501 18 | RandAugment Reference: https://arxiv.org/abs/1909.13719 19 | 20 | Forked from 21 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 22 | """ 23 | 24 | import inspect 25 | import math 26 | 27 | import dataclasses 28 | from sam.autoaugment import policies as optimal_policies 29 | import tensorflow.compat.v1 as tf 30 | from tensorflow_addons import image as contrib_image 31 | 32 | 33 | # This signifies the max integer that the controller RNN could predict for the 34 | # augmentation scheme. 35 | _MAX_LEVEL = 10. 36 | 37 | 38 | @dataclasses.dataclass 39 | class HParams: 40 | """Parameters for AutoAugment and RandAugment.""" 41 | cutout_const: int 42 | translate_const: int 43 | 44 | 45 | def blend(image1, image2, factor): 46 | """Blend image1 and image2 using 'factor'. 47 | 48 | Factor can be above 0.0. A value of 0.0 means only image1 is used. 49 | A value of 1.0 means only image2 is used. A value between 0.0 and 50 | 1.0 means we linearly interpolate the pixel values between the two 51 | images. 
A value greater than 1.0 "extrapolates" the difference 52 | between the two pixel values, and we clip the results to values 53 | between 0 and 255. 54 | 55 | Args: 56 | image1: An image Tensor of type uint8. 57 | image2: An image Tensor of type uint8. 58 | factor: A floating point value above 0.0. 59 | 60 | Returns: 61 | A blended image Tensor of type uint8. 62 | """ 63 | if factor == 0.0: 64 | return tf.convert_to_tensor(image1) 65 | if factor == 1.0: 66 | return tf.convert_to_tensor(image2) 67 | 68 | image1 = tf.to_float(image1) 69 | image2 = tf.to_float(image2) 70 | 71 | difference = image2 - image1 72 | scaled = factor * difference 73 | 74 | # Do addition in float. 75 | temp = tf.to_float(image1) + scaled 76 | 77 | # Interpolate 78 | if factor > 0.0 and factor < 1.0: 79 | # Interpolation means we always stay within 0 and 255. 80 | return tf.cast(temp, tf.uint8) 81 | 82 | # Extrapolate: 83 | # 84 | # We need to clip and then cast. 85 | return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8) 86 | 87 | 88 | def cutout(image, pad_size, replace=0): 89 | """Apply cutout (https://arxiv.org/abs/1708.04552) to image. 90 | 91 | This operation applies a (2*pad_size x 2*pad_size) mask of zeros to 92 | a random location within `image`. The pixel values filled in will be of the 93 | value `replace`. The location where the mask will be applied is 94 | chosen uniformly at random over the whole image. 95 | 96 | Args: 97 | image: An image Tensor of type uint8. 98 | pad_size: Specifies how big the generated zero mask applied to the 99 | image will be. The mask will be of size 100 | (2*pad_size x 2*pad_size). 101 | replace: What pixel value to fill in the image in the area that has 102 | the cutout mask applied to it. 103 | 104 | Returns: 105 | An image Tensor that is of type uint8. 106 | """ 107 | image_height = tf.shape(image)[0] 108 | image_width = tf.shape(image)[1] 109 | 110 | # Sample the center location in the image where the zero mask will be applied. 111 | cutout_center_height = tf.random_uniform( 112 | shape=[], minval=0, maxval=image_height, 113 | dtype=tf.int32) 114 | 115 | cutout_center_width = tf.random_uniform( 116 | shape=[], minval=0, maxval=image_width, 117 | dtype=tf.int32) 118 | 119 | lower_pad = tf.maximum(0, cutout_center_height - pad_size) 120 | upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) 121 | left_pad = tf.maximum(0, cutout_center_width - pad_size) 122 | right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) 123 | 124 | cutout_shape = [image_height - (lower_pad + upper_pad), 125 | image_width - (left_pad + right_pad)] 126 | padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] 127 | mask = tf.pad( 128 | tf.zeros(cutout_shape, dtype=image.dtype), 129 | padding_dims, constant_values=1) 130 | mask = tf.expand_dims(mask, -1) 131 | mask = tf.tile(mask, [1, 1, 3]) 132 | image = tf.where( 133 | tf.equal(mask, 0), 134 | tf.ones_like(image, dtype=image.dtype) * replace, 135 | image) 136 | return image 137 | 138 | 139 | def solarize(image, threshold=128): 140 | # For each pixel in the image, select the pixel 141 | # if the value is less than the threshold. 142 | # Otherwise, invert it by subtracting the pixel value from 255. 143 | return tf.where(image < threshold, image, 255 - image) 144 | 145 | 146 | def solarize_add(image, addition=0, threshold=128): 147 | # For each pixel in the image less than threshold 148 | # we add 'addition' amount to it and then clip the 149 | # pixel value to be between 0 and 255. 
The value 150 | # of 'addition' is between -128 and 128. 151 | added_image = tf.cast(image, tf.int64) + addition 152 | added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8) 153 | return tf.where(image < threshold, added_image, image) 154 | 155 | 156 | def color(image, factor): 157 | """Equivalent of PIL Color.""" 158 | degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)) 159 | return blend(degenerate, image, factor) 160 | 161 | 162 | def contrast(image, factor): 163 | """Equivalent of PIL Contrast.""" 164 | degenerate = tf.image.rgb_to_grayscale(image) 165 | # Cast before calling tf.histogram. 166 | degenerate = tf.cast(degenerate, tf.int32) 167 | 168 | # Compute the grayscale histogram, then compute the mean pixel value, 169 | # and create a constant image size of that value. Use that as the 170 | # blending degenerate target of the original image. 171 | hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) 172 | mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 173 | degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean 174 | degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) 175 | degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) 176 | return blend(degenerate, image, factor) 177 | 178 | 179 | def brightness(image, factor): 180 | """Equivalent of PIL Brightness.""" 181 | degenerate = tf.zeros_like(image) 182 | return blend(degenerate, image, factor) 183 | 184 | 185 | def posterize(image, bits): 186 | """Equivalent of PIL Posterize.""" 187 | shift = 8 - bits 188 | return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift) 189 | 190 | 191 | def rotate(image, degrees, replace): 192 | """Rotates the image by degrees either clockwise or counterclockwise. 193 | 194 | Args: 195 | image: An image Tensor of type uint8. 196 | degrees: Float, a scalar angle in degrees to rotate all images by. If 197 | degrees is positive, the image will be rotated clockwise; otherwise it 198 | will be rotated counterclockwise. 199 | replace: A one or three value 1D tensor to fill empty pixels caused by 200 | the rotate operation. 201 | 202 | Returns: 203 | The rotated version of image. 204 | """ 205 | # Convert from degrees to radians. 206 | degrees_to_radians = math.pi / 180.0 207 | radians = degrees * degrees_to_radians 208 | 209 | # In practice, we should randomize the rotation degrees by flipping 210 | # it negatively half the time, but that's done on 'degrees' outside 211 | # of the function. 212 | image = contrib_image.rotate(wrap(image), radians) 213 | return unwrap(image, replace) 214 | 215 | 216 | def translate_x(image, pixels, replace): 217 | """Equivalent of PIL Translate in X dimension.""" 218 | image = contrib_image.translate(wrap(image), [-pixels, 0]) 219 | return unwrap(image, replace) 220 | 221 | 222 | def translate_y(image, pixels, replace): 223 | """Equivalent of PIL Translate in Y dimension.""" 224 | image = contrib_image.translate(wrap(image), [0, -pixels]) 225 | return unwrap(image, replace) 226 | 227 | 228 | def shear_x(image, level, replace): 229 | """Equivalent of PIL Shearing in X dimension.""" 230 | # Shear parallel to x axis is a projective transform 231 | # with a matrix form of: 232 | # [1 level 233 | # 0 1]. 
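  # contrib_image.transform takes the flattened projective transform
  # [a0, a1, a2, b0, b1, b2, c0, c1], mapping an output pixel (x, y) to the
  # input point ((a0*x + a1*y + a2)/k, (b0*x + b1*y + b2)/k) with
  # k = c0*x + c1*y + 1, so the vector below samples input (x + level*y, y):
  # a pure shear along the x axis.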
234 | image = contrib_image.transform( 235 | wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) 236 | return unwrap(image, replace) 237 | 238 | 239 | def shear_y(image, level, replace): 240 | """Equivalent of PIL Shearing in Y dimension.""" 241 | # Shear parallel to y axis is a projective transform 242 | # with a matrix form of: 243 | # [1 0 244 | # level 1]. 245 | image = contrib_image.transform( 246 | wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) 247 | return unwrap(image, replace) 248 | 249 | 250 | def autocontrast(image): 251 | """Implements Autocontrast function from PIL using TF ops. 252 | 253 | Args: 254 | image: A 3D uint8 tensor. 255 | 256 | Returns: 257 | The image after autocontrast has been applied; it will be of type 258 | uint8. 259 | """ 260 | 261 | def scale_channel(image): 262 | """Scale the 2D image using the autocontrast rule.""" 263 | # A possibly cheaper version can be done using cumsum/unique_with_counts 264 | # over the histogram values, rather than iterating over the entire image 265 | # to compute mins and maxes. 266 | lo = tf.to_float(tf.reduce_min(image)) 267 | hi = tf.to_float(tf.reduce_max(image)) 268 | 269 | # Scale the image, making the lowest value 0 and the highest value 255. 270 | def scale_values(im): 271 | scale = 255.0 / (hi - lo) 272 | offset = -lo * scale 273 | im = tf.to_float(im) * scale + offset 274 | im = tf.clip_by_value(im, 0.0, 255.0) 275 | return tf.cast(im, tf.uint8) 276 | 277 | result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image) 278 | return result 279 | 280 | # Assumes RGB for now. Scales each channel independently 281 | # and then stacks the result. 282 | s1 = scale_channel(image[:, :, 0]) 283 | s2 = scale_channel(image[:, :, 1]) 284 | s3 = scale_channel(image[:, :, 2]) 285 | image = tf.stack([s1, s2, s3], 2) 286 | return image 287 | 288 | 289 | def sharpness(image, factor): 290 | """Implements Sharpness function from PIL using TF ops.""" 291 | orig_image = image 292 | image = tf.cast(image, tf.float32) 293 | # Make image 4D for conv operation. 294 | image = tf.expand_dims(image, 0) 295 | # SMOOTH PIL Kernel. 296 | kernel = tf.constant( 297 | [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, 298 | shape=[3, 3, 1, 1]) / 13. 299 | # Tile across channel dimension. 300 | kernel = tf.tile(kernel, [1, 1, 3, 1]) 301 | strides = [1, 1, 1, 1] 302 | degenerate = tf.nn.depthwise_conv2d( 303 | image, kernel, strides, padding='VALID', rate=[1, 1]) 304 | degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) 305 | degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) 306 | 307 | # For the borders of the resulting image, fill in the values of the 308 | # original image. 309 | mask = tf.ones_like(degenerate) 310 | padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) 311 | padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) 312 | result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) 313 | 314 | # Blend the final result. 315 | return blend(result, orig_image, factor) 316 | 317 | 318 | def equalize(image): 319 | """Implements Equalize function from PIL using TF ops.""" 320 | def scale_channel(im, c): 321 | """Scale the data in the channel to implement equalize.""" 322 | im = tf.cast(im[:, :, c], tf.int32) 323 | # Compute the histogram of the image channel. 324 | histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) 325 | 326 | # For the purposes of computing the step, keep only the nonzero buckets. 
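    # 'step' below mirrors PIL's equalize: the total pixel mass minus the
    # last nonzero bucket, spread over 255 output levels; build_lut then maps
    # each input level to (roughly) its cumulative count divided by 'step'.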
327 | nonzero = tf.where(tf.not_equal(histo, 0)) 328 | nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) 329 | step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 330 | 331 | def build_lut(histo, step): 332 | # Compute the cumulative sum, shifting by step // 2 333 | # and then normalization by step. 334 | lut = (tf.cumsum(histo) + (step // 2)) // step 335 | # Shift lut, prepending with 0. 336 | lut = tf.concat([[0], lut[:-1]], 0) 337 | # Clip the counts to be in range. This is done 338 | # in the C code for image.point. 339 | return tf.clip_by_value(lut, 0, 255) 340 | 341 | # If step is zero, return the original image. Otherwise, build 342 | # lut from the full histogram and step and then index from it. 343 | result = tf.cond(tf.equal(step, 0), 344 | lambda: im, 345 | lambda: tf.gather(build_lut(histo, step), im)) 346 | 347 | return tf.cast(result, tf.uint8) 348 | 349 | # Assumes RGB for now. Scales each channel independently 350 | # and then stacks the result. 351 | s1 = scale_channel(image, 0) 352 | s2 = scale_channel(image, 1) 353 | s3 = scale_channel(image, 2) 354 | image = tf.stack([s1, s2, s3], 2) 355 | return image 356 | 357 | 358 | def invert(image): 359 | """Inverts the image pixels.""" 360 | image = tf.convert_to_tensor(image) 361 | return 255 - image 362 | 363 | 364 | def wrap(image): 365 | """Returns 'image' with an extra channel set to all 1s.""" 366 | shape = tf.shape(image) 367 | extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype) 368 | extended = tf.concat([image, extended_channel], 2) 369 | return extended 370 | 371 | 372 | def unwrap(image, replace): 373 | """Unwraps an image produced by wrap. 374 | 375 | Where the last channel is 0 at a given spatial position, 376 | the other three channels at that position are grayed 377 | (set to 128). Operations like translate and shear on a wrapped 378 | Tensor will leave 0s in empty locations. Some transformations look 379 | at the intensity of values to do preprocessing, and we want these 380 | empty pixels to assume the 'average' value, rather than pure black. 381 | 382 | 383 | Args: 384 | image: A 3D Image Tensor with 4 channels. 385 | replace: A one or three value 1D tensor to fill empty pixels. 386 | 387 | Returns: 388 | image: A 3D image Tensor with 3 channels. 389 | """ 390 | image_shape = tf.shape(image) 391 | # Flatten the spatial dimensions. 392 | flattened_image = tf.reshape(image, [-1, image_shape[2]]) 393 | 394 | # Find all pixels where the last channel is zero. 395 | alpha_channel = flattened_image[:, 3] 396 | 397 | replace = tf.concat([replace, tf.ones([1], image.dtype)], 0) 398 | 399 | # Where they are zero, fill them in with 'replace'. 
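  # 'replace' gained a fourth (alpha) component above, so it broadcasts
  # against the 4-channel flattened image; the alpha channel is stripped
  # again by the tf.slice at the end of this function.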
400 | flattened_image = tf.where( 401 | tf.equal(alpha_channel, 0), 402 | tf.ones_like(flattened_image, dtype=image.dtype) * replace, 403 | flattened_image) 404 | 405 | image = tf.reshape(flattened_image, image_shape) 406 | image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) 407 | return image 408 | 409 | 410 | NAME_TO_FUNC = { 411 | 'AutoContrast': autocontrast, 412 | 'Equalize': equalize, 413 | 'Invert': invert, 414 | 'Rotate': rotate, 415 | 'Posterize': posterize, 416 | 'Solarize': solarize, 417 | 'SolarizeAdd': solarize_add, 418 | 'Color': color, 419 | 'Contrast': contrast, 420 | 'Brightness': brightness, 421 | 'Sharpness': sharpness, 422 | 'ShearX': shear_x, 423 | 'ShearY': shear_y, 424 | 'TranslateX': translate_x, 425 | 'TranslateY': translate_y, 426 | 'Cutout': cutout, 427 | } 428 | 429 | 430 | def _randomly_negate_tensor(tensor): 431 | """With 50% prob turn the tensor negative.""" 432 | should_flip = tf.cast(tf.floor(tf.random_uniform([]) + 0.5), tf.bool) 433 | final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor) 434 | return final_tensor 435 | 436 | 437 | def _rotate_level_to_arg(level): 438 | level = (level/_MAX_LEVEL) * 30. 439 | level = _randomly_negate_tensor(level) 440 | return (level,) 441 | 442 | 443 | def _shrink_level_to_arg(level): 444 | """Converts level to ratio by which we shrink the image content.""" 445 | if level == 0: 446 | return (1.0,) # if level is zero, do not shrink the image 447 | # Maximum shrinking ratio is 2.9. 448 | level = 2. / (_MAX_LEVEL / level) + 0.9 449 | return (level,) 450 | 451 | 452 | def _enhance_level_to_arg(level): 453 | return ((level/_MAX_LEVEL) * 1.8 + 0.1,) 454 | 455 | 456 | def _shear_level_to_arg(level): 457 | level = (level/_MAX_LEVEL) * 0.3 458 | # Flip level to negative with 50% chance. 459 | level = _randomly_negate_tensor(level) 460 | return (level,) 461 | 462 | 463 | def _translate_level_to_arg(level, translate_const): 464 | level = (level/_MAX_LEVEL) * float(translate_const) 465 | # Flip level to negative with 50% chance. 466 | level = _randomly_negate_tensor(level) 467 | return (level,) 468 | 469 | 470 | def level_to_arg(hparams): 471 | return { 472 | 'AutoContrast': lambda level: (), 473 | 'Equalize': lambda level: (), 474 | 'Invert': lambda level: (), 475 | 'Rotate': _rotate_level_to_arg, 476 | 'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4),), 477 | 'Solarize': lambda level: (int((level/_MAX_LEVEL) * 256),), 478 | 'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110),), 479 | 'Color': _enhance_level_to_arg, 480 | 'Contrast': _enhance_level_to_arg, 481 | 'Brightness': _enhance_level_to_arg, 482 | 'Sharpness': _enhance_level_to_arg, 483 | 'ShearX': _shear_level_to_arg, 484 | 'ShearY': _shear_level_to_arg, 485 | 'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams.cutout_const),), 486 | # pylint:disable=g-long-lambda 487 | 'TranslateX': lambda level: _translate_level_to_arg( 488 | level, hparams.translate_const), 489 | 'TranslateY': lambda level: _translate_level_to_arg( 490 | level, hparams.translate_const), 491 | # pylint:enable=g-long-lambda 492 | } 493 | 494 | 495 | def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams): 496 | """Return the function that corresponds to `name` and update `level` param.""" 497 | func = NAME_TO_FUNC[name] 498 | args = level_to_arg(augmentation_hparams)[name](level) 499 | 500 | # Check to see if prob is passed into function. This is used for operations 501 | # where we alter bboxes independently. 
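  # None of the ops registered in NAME_TO_FUNC above declare a 'prob'
  # argument, so this branch is inert in this fork; it is kept for parity
  # with the upstream bbox-altering ops referenced in the comment above.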
502 | # pytype:disable=wrong-arg-types 503 | if 'prob' in inspect.getfullargspec(func)[0]: 504 | args = tuple([prob] + list(args)) 505 | # pytype:enable=wrong-arg-types 506 | 507 | # Add in replace arg if it is required for the function that is being called. 508 | # pytype:disable=wrong-arg-types 509 | if 'replace' in inspect.getfullargspec(func)[0]: 510 | # Make sure replace is the final argument 511 | assert 'replace' == inspect.getfullargspec(func)[0][-1] 512 | args = tuple(list(args) + [replace_value]) 513 | # pytype:enable=wrong-arg-types 514 | 515 | return (func, prob, args) 516 | 517 | 518 | def _apply_func_with_prob(func, image, args, prob): 519 | """Apply `func` to image w/ `args` as input with probability `prob`.""" 520 | assert isinstance(args, tuple) 521 | 522 | # If prob is a function argument, then this randomness is being handled 523 | # inside the function, so make sure it is always called. 524 | # pytype:disable=wrong-arg-types 525 | if 'prob' in inspect.getfullargspec(func)[0]: 526 | prob = 1.0 527 | # pytype:enable=wrong-arg-types 528 | 529 | # Apply the function with probability `prob`. 530 | should_apply_op = tf.cast( 531 | tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool) 532 | augmented_image = tf.cond( 533 | should_apply_op, 534 | lambda: func(image, *args), 535 | lambda: image) 536 | return augmented_image 537 | 538 | 539 | def select_and_apply_random_policy(policies, image): 540 | """Select a random policy from `policies` and apply it to `image`.""" 541 | policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32) 542 | # Note that using tf.case instead of tf.cond would result in significantly 543 | # larger graphs and would even break export for some larger policies. 544 | for (i, policy) in enumerate(policies): 545 | image = tf.cond( 546 | tf.equal(i, policy_to_select), 547 | lambda selected_policy=policy: selected_policy(image), 548 | lambda: image) 549 | return image 550 | 551 | 552 | def build_and_apply_nas_policy(policies, image, 553 | augmentation_hparams): 554 | """Build a policy from the given policies passed in and apply to image. 555 | 556 | Args: 557 | policies: list of lists of tuples in the form `(func, prob, level)`, `func` 558 | is a string name of the augmentation function, `prob` is the probability 559 | of applying the `func` operation, `level` is the input argument for 560 | `func`. 561 | image: tf.Tensor that the resulting policy will be applied to. 562 | augmentation_hparams: Hparams associated with the NAS learned policy. 563 | 564 | Returns: 565 | A version of image that now has data augmentation applied to it based on 566 | the `policies` passed into the function. 567 | """ 568 | replace_value = [128, 128, 128] 569 | 570 | # func is the string name of the augmentation function, prob is the 571 | # probability of applying the operation and level is the parameter associated 572 | # with the tf op. 573 | 574 | # tf_policies are functions that take in an image and return an augmented 575 | # image. 576 | tf_policies = [] 577 | for policy in policies: 578 | tf_policy = [] 579 | # Link string name to the correct python function and make sure the correct 580 | # argument is passed into that function. 581 | for policy_info in policy: 582 | policy_info = list(policy_info) + [replace_value, augmentation_hparams] 583 | 584 | tf_policy.append(_parse_policy_info(*policy_info)) 585 | # Now build the tf policy that will apply the augmentation procedure 586 | # on image. 
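    # make_final_policy threads tf_policy_ through a function argument so
    # that each closure captures the current policy by value; a bare inner
    # function would late-bind the loop variable and every entry of
    # tf_policies would end up applying only the last policy.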
587 | def make_final_policy(tf_policy_): 588 | def final_policy(image_): 589 | for func, prob, args in tf_policy_: 590 | image_ = _apply_func_with_prob( 591 | func, image_, args, prob) 592 | return image_ 593 | return final_policy 594 | tf_policies.append(make_final_policy(tf_policy)) 595 | 596 | augmented_image = select_and_apply_random_policy( 597 | tf_policies, image) 598 | return augmented_image 599 | 600 | 601 | def distort_image_with_autoaugment(image, augmentation_name): 602 | """Applies the AutoAugment policy to `image`. 603 | 604 | AutoAugment is from the paper: https://arxiv.org/abs/1805.09501. 605 | 606 | Args: 607 | image: `Tensor` of shape [height, width, 3] representing an image. 608 | augmentation_name: The name of the AutoAugment policy to use. 609 | 610 | Returns: 611 | The augmented version of `image`. 612 | """ 613 | available_policies = {'imagenet': optimal_policies.policy_imagenet, 614 | 'cifar': optimal_policies.policy_cifar, 615 | 'svhn': optimal_policies.policy_svhn} 616 | if augmentation_name not in available_policies: 617 | raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name)) 618 | 619 | policy = available_policies[augmentation_name]() 620 | # Hparams that will be used for AutoAugment. 621 | augmentation_hparams = HParams(cutout_const=100, translate_const=250) 622 | 623 | return build_and_apply_nas_policy(policy, image, augmentation_hparams) 624 | 625 | 626 | def distort_image_with_randaugment(image, num_layers, magnitude): 627 | """Applies the RandAugment policy to `image`. 628 | 629 | RandAugment is from the paper https://arxiv.org/abs/1909.13719. 630 | 631 | Args: 632 | image: `Tensor` of shape [height, width, 3] representing an image. 633 | num_layers: Integer, the number of augmentation transformations to apply 634 | sequentially to an image. Represented as (N) in the paper. Usually best 635 | values will be in the range [1, 3]. 636 | magnitude: Integer, shared magnitude across all augmentation operations. 637 | Represented as (M) in the paper. Usually best values are in the range 638 | [5, 30]. 639 | 640 | Returns: 641 | The augmented version of `image`. 642 | """ 643 | replace_value = [128] * 3 644 | tf.logging.info('Using RandAug.') 645 | augmentation_hparams = HParams(cutout_const=40, translate_const=100) 646 | available_ops = [ 647 | 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 648 | 'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness', 649 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'] 650 | 651 | for layer_num in range(num_layers): 652 | op_to_select = tf.random_uniform( 653 | [], maxval=len(available_ops), dtype=tf.int32) 654 | random_magnitude = float(magnitude) 655 | with tf.name_scope('randaug_layer_{}'.format(layer_num)): 656 | for (i, op_name) in enumerate(available_ops): 657 | prob = tf.random_uniform([], minval=0.2, maxval=0.8, dtype=tf.float32) 658 | func, _, args = _parse_policy_info(op_name, prob, random_magnitude, 659 | replace_value, augmentation_hparams) 660 | image = tf.cond( 661 | tf.equal(i, op_to_select), 662 | # pylint:disable=g-long-lambda 663 | lambda selected_func=func, selected_args=args: selected_func( 664 | image, *selected_args), 665 | # pylint:enable=g-long-lambda 666 | lambda: image) 667 | return image 668 | -------------------------------------------------------------------------------- /autoaugment/autoaugment_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for sam.autoaugment.autoaugment.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from sam.autoaugment import autoaugment 20 | import tensorflow as tf 21 | 22 | 23 | class AutoAugmentTest(parameterized.TestCase): 24 | 25 | @parameterized.named_parameters( 26 | ('ShearX', 'ShearX'), 27 | ('ShearY', 'ShearY'), 28 | ('Cutout', 'Cutout'), 29 | ('TranslateX', 'TranslateX'), 30 | ('TranslateY', 'TranslateY'), 31 | ('Rotate', 'Rotate'), 32 | ('AutoContrast', 'AutoContrast'), 33 | ('Invert', 'Invert'), 34 | ('Equalize', 'Equalize'), 35 | ('Solarize', 'Solarize'), 36 | ('Posterize', 'Posterize'), 37 | ('Contrast', 'Contrast'), 38 | ('Color', 'Color'), 39 | ('Brightness', 'Brightness'), 40 | ('Sharpness', 'Sharpness')) 41 | def test_image_processing_function(self, name: str): 42 | hparams = autoaugment.HParams(cutout_const=10, translate_const=25) 43 | replace_value = [128, 128, 128] 44 | function, _, args = autoaugment._parse_policy_info( 45 | name, 1.0, 10, replace_value, hparams) 46 | cifar_image_shape = [32, 32, 3] 47 | image = tf.zeros(cifar_image_shape, tf.uint8) 48 | augmented_image = function(image, *args) 49 | self.assertEqual(augmented_image.shape, cifar_image_shape) 50 | self.assertEqual(augmented_image.dtype, tf.uint8) 51 | 52 | @parameterized.named_parameters(('cifar', 'cifar'), ('svhn', 'svhn'), 53 | ('imagenet', 'imagenet')) 54 | def test_autoaugment_function(self, dataset_name): 55 | autoaugment_fn = lambda image: autoaugment.distort_image_with_autoaugment( # pylint:disable=g-long-lambda 56 | image, dataset_name) 57 | image_shape = [224, 224, 3] if dataset_name == 'imagenet' else [32, 32, 3] 58 | image = tf.zeros(image_shape, tf.uint8) 59 | augmented_image = autoaugment_fn(image) 60 | self.assertEqual(augmented_image.shape, image_shape) 61 | self.assertEqual(augmented_image.dtype, tf.uint8) 62 | 63 | if __name__ == '__main__': 64 | absltest.main() 65 | -------------------------------------------------------------------------------- /autoaugment/policies.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """AutoAugment policies for Imagenet, Cifar and SVHN.""" 16 | 17 | from typing import List, Tuple 18 | 19 | 20 | def policy_imagenet() -> List[List[Tuple[str, float, int]]]: 21 | """Returns the autoaugment policy that was used in AutoAugment Paper. 22 | 23 | A policy is composed of two augmentations applied sequentially to the image. 24 | Each augmentation is described as a tuple where the first element is the 25 | type of transformation to apply, the second is the probability with which the 26 | augmentation should be applied, and the third element is the strength of the 27 | transformation. 28 | """ 29 | policy = [ 30 | [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], 31 | [('Color', 0.4, 9), ('Equalize', 0.6, 3)], 32 | [('Color', 0.4, 1), ('Rotate', 0.6, 8)], 33 | [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], 34 | [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], 35 | [('Color', 0.2, 0), ('Equalize', 0.8, 8)], 36 | [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], 37 | [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], 38 | [('Color', 0.6, 1), ('Equalize', 1.0, 2)], 39 | [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], 40 | [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], 41 | [('Color', 0.4, 7), ('Equalize', 0.6, 0)], 42 | [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], 43 | [('Solarize', 0.6, 8), ('Color', 0.6, 9)], 44 | [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], 45 | [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)], 46 | [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], 47 | [('ShearY', 0.8, 0), ('Color', 0.6, 4)], 48 | [('Color', 1.0, 0), ('Rotate', 0.6, 2)], 49 | [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], 50 | [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], 51 | [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], 52 | [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], 53 | [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], 54 | [('Color', 0.8, 6), ('Rotate', 0.4, 5)], 55 | ] 56 | return policy 57 | 58 | 59 | def policy_cifar() -> List[List[Tuple[str, float, int]]]: 60 | """Returns the AutoAugment policies found on Cifar. 61 | 62 | A policy is composed of two augmentations applied sequentially to the image. 63 | Each augmentation is described as a tuple where the first element is the 64 | type of transformation to apply, the second is the probability with which the 65 | augmentation should be applied, and the third element is the strength of the 66 | transformation. 
67 | """ 68 | exp0_0 = [ 69 | [('Invert', 0.1, 7), ('Contrast', 0.2, 6)], 70 | [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)], 71 | [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)], 72 | [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)], 73 | [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]] 74 | exp0_1 = [ 75 | [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)], 76 | [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)], 77 | [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)], 78 | [('Equalize', 0.8, 8), ('Invert', 0.1, 3)], 79 | [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]] 80 | exp0_2 = [ 81 | [('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)], 82 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)], 83 | [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)], 84 | [('Equalize', 0.7, 5), ('Invert', 0.1, 3)], 85 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]] 86 | exp0_3 = [ 87 | [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)], 88 | [('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)], 89 | [('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)], 90 | [('TranslateY', 0.2, 7), ('Color', 0.9, 6)], 91 | [('Equalize', 0.7, 6), ('Color', 0.4, 9)]] 92 | exp1_0 = [ 93 | [('ShearY', 0.2, 7), ('Posterize', 0.3, 7)], 94 | [('Color', 0.4, 3), ('Brightness', 0.6, 7)], 95 | [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)], 96 | [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)], 97 | [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]] 98 | exp1_1 = [ 99 | [('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)], 100 | [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)], 101 | [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)], 102 | [('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)], 103 | [('Brightness', 0.0, 8), ('Color', 0.8, 8)]] 104 | exp1_2 = [ 105 | [('Solarize', 0.2, 6), ('Color', 0.8, 6)], 106 | [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)], 107 | [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)], 108 | [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)], 109 | [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]] 110 | exp1_3 = [ 111 | [('Contrast', 0.7, 5), ('Brightness', 0.0, 2)], 112 | [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)], 113 | [('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)], 114 | [('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)], 115 | [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]] 116 | exp1_4 = [ 117 | [('Brightness', 0.0, 7), ('Equalize', 0.4, 7)], 118 | [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)], 119 | [('Equalize', 0.6, 8), ('Color', 0.6, 2)], 120 | [('Color', 0.3, 7), ('Color', 0.2, 4)], 121 | [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]] 122 | exp1_5 = [ 123 | [('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)], 124 | [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)], 125 | [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)], 126 | [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)], 127 | [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]] 128 | exp1_6 = [ 129 | [('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)], 130 | [('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)], 131 | [('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)], 132 | [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)], 133 | [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]] 134 | exp2_0 = [ 135 | [('Color', 0.7, 7), ('TranslateX', 0.5, 8)], 136 | [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)], 137 | [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)], 138 | [('Brightness', 0.9, 6), ('Color', 0.2, 8)], 139 | [('Solarize', 0.5, 2), ('Invert', 0.0, 3)]] 140 | exp2_1 = [ 141 | [('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)], 142 | [('Cutout', 0.2, 4), ('Equalize', 0.1, 1)], 143 | [('Equalize', 
0.7, 7), ('AutoContrast', 0.6, 4)], 144 | [('Color', 0.1, 8), ('ShearY', 0.2, 3)], 145 | [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]] 146 | exp2_2 = [ 147 | [('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)], 148 | [('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)], 149 | [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)], 150 | [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)], 151 | [('Equalize', 0.8, 2), ('Invert', 0.4, 0)]] 152 | exp2_3 = [ 153 | [('Equalize', 0.9, 5), ('Color', 0.7, 0)], 154 | [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)], 155 | [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)], 156 | [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)], 157 | [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]] 158 | exp2_4 = [ 159 | [('Solarize', 0.2, 3), ('ShearX', 0.0, 0)], 160 | [('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)], 161 | [('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)], 162 | [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)], 163 | [('Equalize', 0.8, 6), ('Invert', 0.3, 6)]] 164 | exp2_5 = [ 165 | [('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)], 166 | [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)], 167 | [('ShearX', 0.0, 3), ('Posterize', 0.0, 3)], 168 | [('Solarize', 0.4, 3), ('Color', 0.2, 4)], 169 | [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]] 170 | exp2_6 = [ 171 | [('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)], 172 | [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)], 173 | [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)], 174 | [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)], 175 | [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]] 176 | exp2_7 = [ 177 | [('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)], 178 | [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)], 179 | [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)], 180 | [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)], 181 | [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]] 182 | exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3 183 | exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6 184 | exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + exp2_7 185 | return exp0s + exp1s + exp2s 186 | 187 | 188 | def policy_svhn() -> List[List[Tuple[str, float, int]]]: 189 | """Returns the AutoAugment policies found on SVHN. 190 | 191 | A policy is composed of two augmentations applied sequentially to the image. 192 | Each augmentation is described as a tuple where the first element is the 193 | type of transformation to apply, the second is the probability with which the 194 | augmentation should be applied, and the third element is the strength of the 195 | transformation. 
196 | """ 197 | return [[('ShearX', 0.9, 4), ('Invert', 0.2, 3)], 198 | [('ShearY', 0.9, 8), ('Invert', 0.7, 5)], 199 | [('Equalize', 0.6, 5), ('Solarize', 0.6, 6)], 200 | [('Invert', 0.9, 3), ('Equalize', 0.6, 3)], 201 | [('Equalize', 0.6, 1), ('Rotate', 0.9, 3)], 202 | [('ShearX', 0.9, 4), ('AutoContrast', 0.8, 3)], 203 | [('ShearY', 0.9, 8), ('Invert', 0.4, 5)], 204 | [('ShearY', 0.9, 5), ('Solarize', 0.2, 6)], 205 | [('Invert', 0.9, 6), ('AutoContrast', 0.8, 1)], 206 | [('Equalize', 0.6, 3), ('Rotate', 0.9, 3)], 207 | [('ShearX', 0.9, 4), ('Solarize', 0.3, 3)], 208 | [('ShearY', 0.8, 8), ('Invert', 0.7, 4)], 209 | [('Equalize', 0.9, 5), ('TranslateY', 0.6, 6)], 210 | [('Invert', 0.9, 4), ('Equalize', 0.6, 7)], 211 | [('Contrast', 0.3, 3), ('Rotate', 0.8, 4)], 212 | [('Invert', 0.8, 5), ('TranslateY', 0.0, 2)], 213 | [('ShearY', 0.7, 6), ('Solarize', 0.4, 8)], 214 | [('Invert', 0.6, 4), ('Rotate', 0.8, 4)], 215 | [('ShearY', 0.3, 7), ('TranslateX', 0.9, 3)], 216 | [('ShearX', 0.1, 6), ('Invert', 0.6, 5)], 217 | [('Solarize', 0.7, 2), ('TranslateY', 0.6, 7)], 218 | [('ShearY', 0.8, 4), ('Invert', 0.8, 8)], 219 | [('ShearX', 0.7, 9), ('TranslateY', 0.8, 3)], 220 | [('ShearY', 0.8, 5), ('AutoContrast', 0.7, 3)], 221 | [('ShearX', 0.7, 2), ('Invert', 0.1, 5)]] 222 | -------------------------------------------------------------------------------- /figures/no_sam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/sam/dae9904c4cf3a57a304f7b04cecffe371679c702/figures/no_sam.png -------------------------------------------------------------------------------- /figures/sam_wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/sam/dae9904c4cf3a57a304f7b04cecffe371679c702/figures/sam_wide.png -------------------------------------------------------------------------------- /figures/summary_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/sam/dae9904c4cf3a57a304f7b04cecffe371679c702/figures/summary_plot.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | astunparse==1.6.3 3 | attrs==20.3.0 4 | cachetools==4.1.1 5 | certifi==2020.11.8 6 | chardet==3.0.4 7 | cloudpickle==1.6.0 8 | cycler==0.10.0 9 | decorator==4.4.2 10 | dill==0.3.3 11 | dm-tree==0.1.5 12 | flatbuffers==1.12 13 | flax==0.2.2 14 | future==0.18.2 15 | gast==0.3.3 16 | google-auth==1.23.0 17 | google-auth-oauthlib==0.4.2 18 | google-pasta==0.2.0 19 | googleapis-common-protos==1.52.0 20 | grpcio==1.34.0 21 | h5py==2.10.0 22 | idna==2.10 23 | importlib-resources==3.3.0 24 | jax==0.2.6 25 | jaxlib==0.1.57 26 | Keras-Preprocessing==1.1.2 27 | kiwisolver==1.3.1 28 | Markdown==3.3.3 29 | matplotlib==3.3.3 30 | msgpack==1.0.0 31 | numpy==1.18.5 32 | oauthlib==3.1.0 33 | opt-einsum==3.3.0 34 | pandas==1.1.4 35 | Pillow==8.0.1 36 | promise==2.3 37 | protobuf==3.14.0 38 | pyasn1==0.4.8 39 | pyasn1-modules==0.2.8 40 | pyparsing==2.4.7 41 | python-dateutil==2.8.1 42 | pytz==2020.4 43 | requests==2.25.0 44 | requests-oauthlib==1.3.0 45 | rsa==4.6 46 | scipy==1.5.4 47 | six==1.15.0 48 | tensorboard==2.4.0 49 | tensorboard-plugin-wit==1.7.0 50 | tensorflow==2.3.1 51 | tensorflow-addons==0.11.2 52 | tensorflow-datasets==4.1.0 53 | 
tensorflow-estimator==2.3.0 54 | tensorflow-metadata==0.25.0 55 | tensorflow-probability==0.11.1 56 | termcolor==1.1.0 57 | tqdm==4.54.0 58 | typeguard==2.10.0 59 | typing==3.7.4.3 60 | urllib3==1.26.2 61 | Werkzeug==1.0.1 62 | wrapt==1.12.1 63 | -------------------------------------------------------------------------------- /sam_jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /sam_jax/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /sam_jax/datasets/augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implements data augmentations for cifar10/cifar100.""" 16 | 17 | from typing import Dict 18 | 19 | from absl import flags 20 | from sam.autoaugment import autoaugment 21 | import tensorflow as tf 22 | import tensorflow_probability as tfp 23 | 24 | 25 | FLAGS = flags.FLAGS 26 | 27 | 28 | flags.DEFINE_integer('cutout_length', 16, 29 | 'Length (in pixels) of the cutout patch. Default value of ' 30 | '16 is used to get SOTA on cifar10/cifar100') 31 | 32 | 33 | def weak_image_augmentation(example: Dict[str, tf.Tensor], 34 | random_crop_pad: int = 4) -> Dict[str, tf.Tensor]: 35 | """Applies random crops and horizontal flips. 36 | 37 | Simple data augmentations that are (almost) always used with cifar. Pad the 38 | image with `random_crop_pad` before randomly cropping it to its original 39 | size. 
Also randomly apply horizontal flip. 40 | 41 | Args: 42 | example: An example dict containing an image and a label. 43 | random_crop_pad: By how many pixels should the image be padded on each side 44 | before cropping. 45 | 46 | Returns: 47 | An example with the same label and an augmented version of the image. 48 | """ 49 | image, label = example['image'], example['label'] 50 | image = tf.image.random_flip_left_right(image) 51 | image_shape = tf.shape(image) 52 | image = tf.pad( 53 | image, [[random_crop_pad, random_crop_pad], 54 | [random_crop_pad, random_crop_pad], [0, 0]], 55 | mode='REFLECT') 56 | image = tf.image.random_crop(image, image_shape) 57 | return {'image': image, 'label': label} 58 | 59 | 60 | def auto_augmentation(example: Dict[str, tf.Tensor], 61 | dataset_name: str) -> Dict[str, tf.Tensor]: 62 | """Applies the AutoAugment policy found for the dataset. 63 | 64 | AutoAugment: Learning Augmentation Policies from Data 65 | https://arxiv.org/abs/1805.09501 66 | 67 | Args: 68 | example: An example dict containing an image and a label. 69 | dataset_name: Name of the dataset for which we should return the optimal 70 | policy. Should be 'cifar[10|100]', 'svhn' or 'imagenet'. 71 | 72 | Returns: 73 | An example with the same label and an augmented version of the image. 74 | """ 75 | if dataset_name in ('cifar10', 'cifar100'): dataset_name = 'cifar' 76 | image, label = example['image'], example['label'] 77 | image = autoaugment.distort_image_with_autoaugment(image, dataset_name) 78 | return {'image': image, 'label': label} 79 | 80 | 81 | def cutout(batch: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: 82 | """Applies cutout to a batch of images. 83 | 84 | The cut out patch will be replaced by zeros (thus the batch should be 85 | normalized before cutout is applied). 86 | 87 | Reference: 88 | Improved Regularization of Convolutional Neural Networks with Cutout 89 | https://arxiv.org/abs/1708.04552 90 | 91 | Implementation inspired by: 92 | third_party/cloud_tpu/models/efficientnet/autoaugment.py 93 | 94 | Args: 95 | batch: A batch of images and labels. 96 | 97 | Returns: 98 | The same batch where cutout has been applied to the images. 
99 | """ 100 | length, replace = FLAGS.cutout_length, 0.0 101 | images, labels = batch['image'], batch['label'] 102 | num_channels = tf.shape(images)[3] 103 | image_height, image_width = tf.shape(images)[1], tf.shape(images)[2] 104 | 105 | cutout_center_height = tf.random.uniform( 106 | shape=[], minval=0, maxval=image_height, 107 | dtype=tf.int32) 108 | cutout_center_width = tf.random.uniform( 109 | shape=[], minval=0, maxval=image_width, 110 | dtype=tf.int32) 111 | 112 | lower_pad = tf.maximum(0, cutout_center_height - length // 2) 113 | upper_pad = tf.maximum(0, image_height - cutout_center_height - length // 2) 114 | left_pad = tf.maximum(0, cutout_center_width - length // 2) 115 | right_pad = tf.maximum(0, image_width - cutout_center_width - length // 2) 116 | 117 | cutout_shape = [image_height - (lower_pad + upper_pad), 118 | image_width - (left_pad + right_pad)] 119 | 120 | padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] 121 | 122 | mask = tf.pad( 123 | tf.zeros(cutout_shape, dtype=images.dtype), 124 | padding_dims, constant_values=1) 125 | 126 | patch = tf.ones_like(images, dtype=images.dtype) * replace, 127 | 128 | mask = tf.expand_dims(mask, -1) 129 | mask = tf.tile(mask, [1, 1, num_channels]) 130 | 131 | images = tf.where( 132 | tf.equal(mask, 0), 133 | patch, 134 | images) 135 | 136 | images = tf.squeeze(images, axis=0) 137 | 138 | return {'image': images, 'label': labels} 139 | 140 | 141 | def mixup(batch: Dict[str, tf.Tensor], 142 | alpha: float = 1.0) -> Dict[str, tf.Tensor]: 143 | """Generates augmented images using Mixup. 144 | 145 | Arguments: 146 | batch: Feature dict containing the images and the labels. 147 | alpha: Float that controls the strength of Mixup regularization. 148 | 149 | Returns: 150 | A feature dict containing the images and labels augmented with mixup. 151 | """ 152 | images, labels = batch['image'], batch['label'] 153 | batch_size = 1 # Unique mixing parameter for all samples 154 | mix_weight = tfp.distributions.Beta(alpha, alpha).sample([batch_size, 1]) 155 | mix_weight = tf.maximum(mix_weight, 1. - mix_weight) 156 | images_mix_weight = tf.reshape(mix_weight, [batch_size, 1, 1, 1]) 157 | images_mix = ( 158 | images * images_mix_weight + images[::-1] * (1. - images_mix_weight)) 159 | labels_mix = labels * mix_weight + labels[::-1] * (1. - mix_weight) 160 | return {'image': images_mix, 'label': labels_mix} 161 | -------------------------------------------------------------------------------- /sam_jax/datasets/dataset_source.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Utility class to load datasets and apply data augmentation.""" 16 | 17 | import abc 18 | from typing import Callable, Dict, Optional 19 | 20 | from absl import flags 21 | from absl import logging 22 | import jax 23 | from sam.sam_jax.datasets import augmentation 24 | import tensorflow as tf 25 | import tensorflow_datasets as tfds 26 | 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | 31 | flags.DEFINE_bool('use_test_set', True, 32 | 'Whether to use the test set or not. If not, then 10% ' 33 | 'observations will be set aside from the training set and ' 34 | 'used as a validation set instead.') 35 | 36 | 37 | class DatasetSource(abc.ABC): 38 | """Parent for classes that load, preprocess and serve datasets. 39 | 40 | Child class constructor should set a `num_training_obs` and a `batch_size` 41 | attribute. 42 | """ 43 | batch_size = ... # type: int 44 | num_training_obs = ... # type: int 45 | 46 | @abc.abstractmethod 47 | def get_train(self, use_augmentations: bool) -> tf.data.Dataset: 48 | """Returns the training set. 49 | 50 | The training set will be batched, and the remainder of the batch will be 51 | dropped (except if use_augmentation is False, in which case we don't drop 52 | the remainder as we are most likely computing the accuracy on the train set. 53 | 54 | Args: 55 | use_augmentations: Whether we should apply data augmentation (and possibly 56 | cutout) or not. 57 | """ 58 | 59 | @abc.abstractmethod 60 | def get_test(self) -> tf.data.Dataset: 61 | """Returns test set.""" 62 | 63 | 64 | def _resize(image: tf.Tensor, image_size: int, method: Optional[str] = None): 65 | if method is not None: 66 | return tf.image.resize(image, [image_size, image_size], method) 67 | return tf.compat.v1.image.resize_bicubic(image, [image_size, image_size]) 68 | 69 | 70 | class TFDSDatasetSource(DatasetSource): 71 | """Parent for classes that load, preprocess and serve TensorFlow datasets. 72 | 73 | Small datasets like CIFAR, SVHN and Fashion MNIST subclass TFDSDatasetSource. 74 | """ 75 | batch_size = ... # type: int 76 | num_training_obs = ... # type: int 77 | _train_ds = ... # type: tf.data.Dataset 78 | _test_ds = ... # type: tf.data.Dataset 79 | _augmentation = ... # type: str 80 | _num_classes = ... # type: int 81 | _image_mean = ... # type: tf.tensor 82 | _image_std = ... # type: tf.tensor 83 | _dataset_name = ... # type: str 84 | _batch_level_augmentations = ... # type: Callable 85 | _image_size = ... # type: Optional[int] 86 | 87 | def _apply_image_augmentations( 88 | self, example: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: 89 | if self._augmentation in ['autoaugment', 'aa-only']: 90 | example = augmentation.auto_augmentation(example, self._dataset_name) 91 | if self._augmentation in ['basic', 'autoaugment']: 92 | example = augmentation.weak_image_augmentation(example) 93 | return example 94 | 95 | def _preprocess_batch(self, 96 | examples: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: 97 | image, label = examples['image'], examples['label'] 98 | image = tf.cast(image, tf.float32) / 255.0 99 | image = (image - self._image_mean) / self._image_std 100 | label = tf.one_hot( 101 | label, depth=self._num_classes, on_value=1.0, off_value=0.0) 102 | return {'image': image, 'label': label} 103 | 104 | def get_train(self, use_augmentations: bool) -> tf.data.Dataset: 105 | """Returns the training set. 
106 | 107 | The training set will be batched, and the remainder of the batch will be 108 | dropped (except if use_augmentations is False, in which case we don't drop 109 | the remainder as we are most likely computing the accuracy on the train 110 | set). 111 | 112 | Args: 113 | use_augmentations: Whether we should apply data augmentation (and possibly 114 | cutout) or not. 115 | """ 116 | ds = self._train_ds.shuffle(50000) 117 | if use_augmentations: 118 | ds = ds.map(self._apply_image_augmentations, 119 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 120 | # Don't drop remainder if we don't use augmentation, as we are evaluating. 121 | ds = ds.batch(self.batch_size, drop_remainder=use_augmentations) 122 | ds = ds.map(self._preprocess_batch, 123 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 124 | if self._batch_level_augmentations and use_augmentations: 125 | ds = ds.map(self._batch_level_augmentations, 126 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 127 | if self._image_size: 128 | def resize(batch): 129 | image = _resize(batch['image'], self._image_size) 130 | return {'image': image, 'label': batch['label']} 131 | ds = ds.map(resize) 132 | return ds 133 | 134 | def get_test(self) -> tf.data.Dataset: 135 | """Returns the batched test set.""" 136 | eval_batch_size = min(32, self.batch_size) 137 | ds = self._test_ds.batch(eval_batch_size).map( 138 | self._preprocess_batch, 139 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 140 | if self._image_size: 141 | def resize(batch): 142 | image = _resize(batch['image'], self._image_size) 143 | return {'image': image, 'label': batch['label']} 144 | ds = ds.map(resize) 145 | return ds 146 | 147 | 148 | class CifarDatasetSource(TFDSDatasetSource): 149 | """Parent class for DatasetSource created from cifar10/cifar100 datasets. 150 | 151 | The child class constructor must set _num_classes (integer, number of classes 152 | in the dataset). 153 | """ 154 | 155 | def __init__(self, batch_size: int, name: str, image_level_augmentations: str, 156 | batch_level_augmentations: str, 157 | image_size: Optional[int] = None): 158 | """Instantiates the DatasetSource. 159 | 160 | Args: 161 | batch_size: Batch size to use for training and evaluation. 162 | name: Name of the Tensorflow Dataset to use. Should be cifar10 or 163 | cifar100. 164 | image_level_augmentations: Augmentations to apply to the images. Should be 165 | one of: 166 | * none: No augmentations are applied. 167 | * basic: Applies random crops and horizontal translations. 168 | * autoaugment: Applies the best found policy for Cifar from the 169 | AutoAugment paper. 170 | batch_level_augmentations: Augmentations to apply at the batch level. Only 171 | cutout is needed to get SOTA results. The following are implemented: 172 | * none: No augmentations are applied. 173 | * cutout: Applies cutout (https://arxiv.org/abs/1708.04552). 174 | * mixup: Applies mixup (https://arxiv.org/pdf/1710.09412.pdf). 175 | * mixcut: Applies mixup and cutout. 176 | image_size: Size to which the image should be rescaled. If None, the 177 | standard size is used (32x32). 
178 | """ 179 | assert name in ['cifar10', 'cifar100'] 180 | assert image_level_augmentations in ['none', 'basic', 'autoaugment'] 181 | assert batch_level_augmentations in ['none', 'cutout'] 182 | self._image_size = image_size 183 | self.batch_size = batch_size 184 | if FLAGS.use_test_set: 185 | self.num_training_obs = 50000 186 | train_split_size = self.num_training_obs // jax.host_count() 187 | start = jax.host_id() * train_split_size 188 | train_split = 'train[{}:{}]'.format(start, start + train_split_size) 189 | self._train_ds = tfds.load(name, split=train_split).cache() 190 | self._test_ds = tfds.load(name, split='test').cache() 191 | logging.info('Used test set instead of validation set.') 192 | else: 193 | # Validation split not implemented for multi-host training. 194 | assert jax.host_count() == 1 195 | self._train_ds = tfds.load(name, split='train[:45000]').cache() 196 | self._test_ds = tfds.load(name, split='train[45000:]').cache() 197 | self.num_training_obs = 45000 198 | logging.info('Used validation set instead of test set.') 199 | self._augmentation = image_level_augmentations 200 | if batch_level_augmentations == 'cutout': 201 | self._batch_level_augmentations = augmentation.cutout 202 | elif batch_level_augmentations == 'mixup': 203 | self._batch_level_augmentations = augmentation.mixup 204 | elif batch_level_augmentations == 'mixcut': 205 | self._batch_level_augmentations = ( 206 | lambda x: augmentation.cutout(augmentation.mixup(x))) 207 | else: 208 | self._batch_level_augmentations = None 209 | if name == 'cifar10': 210 | self._image_mean = tf.constant([[[0.49139968, 0.48215841, 0.44653091]]]) 211 | self._image_std = tf.constant([[[0.24703223, 0.24348513, 0.26158784]]]) 212 | else: 213 | self._image_mean = tf.constant([[[0.50707516, 0.48654887, 0.44091784]]]) 214 | self._image_std = tf.constant([[[0.26733429, 0.25643846, 0.27615047]]]) 215 | self._num_classes = None # To define in child classes 216 | 217 | 218 | class Cifar10(CifarDatasetSource): 219 | """Cifar10 DatasetSource.""" 220 | 221 | def __init__(self, batch_size: int, image_level_augmentations: str, 222 | batch_level_augmentations: str, image_size: int = None): 223 | """See parent class for more information.""" 224 | super().__init__(batch_size, 'cifar10', image_level_augmentations, 225 | batch_level_augmentations, image_size) 226 | self._num_classes = 10 227 | self._dataset_name = 'cifar10' 228 | 229 | 230 | class Cifar100(CifarDatasetSource): 231 | """Cifar100 DatasetSource.""" 232 | 233 | def __init__(self, batch_size: int, image_level_augmentations: str, 234 | batch_level_augmentations: str, image_size: int = None): 235 | """See parent class for more information.""" 236 | super().__init__(batch_size, 'cifar100', image_level_augmentations, 237 | batch_level_augmentations, image_size) 238 | self._num_classes = 100 239 | self._dataset_name = 'cifar100' 240 | 241 | 242 | class FashionMnist(TFDSDatasetSource): 243 | """Fashion Mnist dataset.""" 244 | 245 | def __init__(self, batch_size: int, image_level_augmentations: str, 246 | batch_level_augmentations: str): 247 | """Instantiates the DatasetSource. 248 | 249 | Args: 250 | batch_size: Batch size to use for training and evaluation. 251 | image_level_augmentations: Augmentations to apply to the images. Should be 252 | one of: 253 | * none: No augmentations are applied. 254 | * basic: Applies random crops and horizontal translations. 255 | batch_level_augmentations: Augmentations to apply at the batch level. 256 | * none: No augmentations are applied. 
257 | * cutout: Applies cutout (https://arxiv.org/abs/1708.04552). 258 | """ 259 | assert image_level_augmentations in ['none', 'basic'] 260 | assert batch_level_augmentations in ['none', 'cutout'] 261 | self.batch_size = batch_size 262 | self._image_size = None 263 | if FLAGS.use_test_set: 264 | self._train_ds = tfds.load('fashion_mnist', split='train').cache() 265 | self._test_ds = tfds.load('fashion_mnist', split='test').cache() 266 | logging.info('Used test set instead of validation set.') 267 | self.num_training_obs = 60000 268 | else: 269 | self._train_ds = tfds.load('fashion_mnist', split='train[:54000]').cache() 270 | self._test_ds = tfds.load('fashion_mnist', split='train[54000:]').cache() 271 | self.num_training_obs = 54000 272 | logging.info('Used validation set instead of test set.') 273 | self._augmentation = image_level_augmentations 274 | if batch_level_augmentations == 'cutout': 275 | self._batch_level_augmentations = augmentation.cutout 276 | else: 277 | self._batch_level_augmentations = None 278 | self._image_mean = tf.constant([[[0.1307]]]) 279 | self._image_std = tf.constant([[[0.3081]]]) 280 | self._num_classes = 10 281 | self._dataset_name = 'fashion_mnist' 282 | 283 | 284 | class SVHN(TFDSDatasetSource): 285 | """SVHN dataset.""" 286 | 287 | def __init__(self, batch_size: int, image_level_augmentations: str, 288 | batch_level_augmentations: str): 289 | """Instantiates the DatasetSource. 290 | 291 | Args: 292 | batch_size: Batch size to use for training and evaluation. 293 | image_level_augmentations: Augmentations to apply to the images. Should be 294 | one of: 295 | * none: No augmentations are applied. 296 | * basic: Applies random crops and horizontal translations. 297 | * autoaugment: Applies the best found policy for SVHN from the 298 | AutoAugment paper. Also applies the basic augmentations on top of it. 299 | * aa-only: Same as autoaugment but doesn't apply the basic 300 | augmentations. Should be preferred for SVHN. 301 | batch_level_augmentations: Augmentations to apply at the batch level. 302 | * none: No augmentations are applied. 303 | * cutout: Applies cutout (https://arxiv.org/abs/1708.04552). 
304 | """ 305 | assert image_level_augmentations in [ 306 | 'none', 'basic', 'autoaugment', 'aa-only'] 307 | assert batch_level_augmentations in ['none', 'cutout'] 308 | self.batch_size = batch_size 309 | self._image_size = None 310 | if FLAGS.use_test_set: 311 | ds_base = tfds.load('svhn_cropped', split='train') 312 | ds_extra = tfds.load('svhn_cropped', split='extra') 313 | self._train_ds = ds_base.concatenate(ds_extra).cache() 314 | self._test_ds = tfds.load('svhn_cropped', split='test').cache() 315 | logging.info('Used test set instead of validation set.') 316 | self.num_training_obs = 73257+531131 317 | else: 318 | ds_base = tfds.load('svhn_cropped', split='train[:65929]') 319 | ds_extra = tfds.load('svhn_cropped', split='extra') 320 | self._train_ds = ds_base.concatenate(ds_extra).cache() 321 | self._test_ds = tfds.load('svhn_cropped', split='train[65929:]').cache() 322 | self.num_training_obs = 65929+531131 323 | logging.info('Used validation set instead of test set.') 324 | self._augmentation = image_level_augmentations 325 | if batch_level_augmentations == 'cutout': 326 | self._batch_level_augmentations = augmentation.cutout 327 | else: 328 | self._batch_level_augmentations = None 329 | self._image_mean = tf.constant([[[0.43090966, 0.4302428, 0.44634357]]]) 330 | self._image_std = tf.constant([[[0.19759192, 0.20029082, 0.19811132]]]) 331 | self._num_classes = 10 332 | self._dataset_name = 'svhn' 333 | -------------------------------------------------------------------------------- /sam_jax/datasets/dataset_source_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Imagenet DatasetSource. 16 | 17 | Initially forked from: 18 | https://github.com/google/flax/blob/master/examples/imagenet/input_pipeline.py 19 | """ 20 | 21 | from typing import Dict, Tuple 22 | 23 | from absl import flags 24 | from absl import logging 25 | import jax 26 | from sam.autoaugment import autoaugment 27 | from sam.sam_jax.datasets import dataset_source 28 | import tensorflow as tf 29 | import tensorflow_datasets as tfds 30 | import tensorflow_probability as tfp 31 | 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | 36 | flags.DEFINE_integer('randaug_num_layers', 2, 37 | 'Number of augmentations applied to each images by ' 38 | 'RandAugment. 
Typical value is 2 and is generally not ' 39 | 'changed.') 40 | flags.DEFINE_integer('randaug_magnitude', 9, 41 | 'Magnitude of augmentations applied by RandAugment.') 42 | flags.DEFINE_float('imagenet_mixup_alpha', 0.0, 'If > 0, use mixup.') 43 | 44 | 45 | TRAIN_IMAGES = 1281167 46 | EVAL_IMAGES = 50000 47 | 48 | IMAGE_SIZE = 224 49 | CROP_PADDING = 32 50 | MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] 51 | STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] 52 | 53 | 54 | def _distorted_bounding_box_crop(image_bytes: tf.Tensor, 55 | bbox: tf.Tensor, 56 | min_object_covered: float = 0.1, 57 | aspect_ratio_range: Tuple[float, 58 | float] = (0.75, 59 | 1.33), 60 | area_range: Tuple[float, float] = (0.05, 1.0), 61 | max_attempts: int = 100) -> tf.Tensor: 62 | """Generates cropped_image using one of the bboxes randomly distorted. 63 | 64 | See `tf.image.sample_distorted_bounding_box` for more documentation. 65 | 66 | Args: 67 | image_bytes: `Tensor` of binary image data. 68 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` 69 | where each coordinate is [0, 1) and the coordinates are arranged 70 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole 71 | image. 72 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 73 | area of the image must contain at least this fraction of any bounding 74 | box supplied. 75 | aspect_ratio_range: An optional list of `float`s. The cropped area of the 76 | image must have an aspect ratio = width / height within this range. 77 | area_range: An optional list of `float`s. The cropped area of the image 78 | must contain a fraction of the supplied image within in this range. 79 | max_attempts: An optional `int`. Number of attempts at generating a cropped 80 | region of the image of the specified constraints. After `max_attempts` 81 | failures, return the entire image. 82 | Returns: 83 | cropped image `Tensor` 84 | """ 85 | shape = tf.image.extract_jpeg_shape(image_bytes) 86 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 87 | shape, 88 | bounding_boxes=bbox, 89 | min_object_covered=min_object_covered, 90 | aspect_ratio_range=aspect_ratio_range, 91 | area_range=area_range, 92 | max_attempts=max_attempts, 93 | use_image_if_no_bounding_boxes=True) 94 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 95 | 96 | # Crop the image to the specified bounding box. 
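# `bbox_begin` holds (offset_y, offset_x, 0) and `bbox_size` holds
# (target_height, target_width, -1); tf.image.decode_and_crop_jpeg expects
# the window as [y, x, height, width], which is assembled below.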
97 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 98 | target_height, target_width, _ = tf.unstack(bbox_size) 99 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) 100 | image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) 101 | 102 | return image 103 | 104 | 105 | def _resize(image: tf.Tensor, image_size: int) -> tf.Tensor: 106 | """Returns the resized image.""" 107 | return tf.compat.v1.image.resize_bicubic([image], [image_size, image_size])[0] 108 | 109 | 110 | def _at_least_x_are_equal(a: tf.Tensor, b: tf.Tensor, 111 | x: int) -> tf.Tensor: 112 | """At least `x` of `a` and `b` `Tensors` are equal.""" 113 | match = tf.equal(a, b) 114 | match = tf.cast(match, tf.int32) 115 | return tf.greater_equal(tf.reduce_sum(match), x) 116 | 117 | 118 | def _decode_and_random_crop(image_bytes: tf.Tensor, 119 | image_size: int) -> tf.Tensor: 120 | """Make a random crop of image_size.""" 121 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) 122 | image = _distorted_bounding_box_crop( 123 | image_bytes, 124 | bbox, 125 | min_object_covered=0.1, 126 | aspect_ratio_range=(3. / 4, 4. / 3.), 127 | area_range=(0.08, 1.0), 128 | max_attempts=10) 129 | original_shape = tf.image.extract_jpeg_shape(image_bytes) 130 | bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) 131 | 132 | image = tf.cond(bad, lambda: _decode_and_center_crop(image_bytes, image_size), 133 | lambda: _resize(image, image_size)) 134 | 135 | return image 136 | 137 | 138 | def _decode_and_center_crop(image_bytes: tf.Tensor, 139 | image_size: int) -> tf.Tensor: 140 | """Crops to center of image with padding then scales image_size.""" 141 | shape = tf.image.extract_jpeg_shape(image_bytes) 142 | image_height = shape[0] 143 | image_width = shape[1] 144 | 145 | padded_center_crop_size = tf.cast( 146 | ((image_size / (image_size + CROP_PADDING)) * 147 | tf.cast(tf.minimum(image_height, image_width), tf.float32)), 148 | tf.int32) 149 | 150 | offset_height = ((image_height - padded_center_crop_size) + 1) // 2 151 | offset_width = ((image_width - padded_center_crop_size) + 1) // 2 152 | crop_window = tf.stack([offset_height, offset_width, 153 | padded_center_crop_size, padded_center_crop_size]) 154 | image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) 155 | image = _resize(image, image_size) 156 | 157 | return image 158 | 159 | 160 | def normalize_image(image: tf.Tensor) -> tf.Tensor: 161 | """Returns the normalized image. 162 | 163 | Image is normalized so that the mean and variance of each channel over the 164 | dataset is 0 and 1. 165 | 166 | Args: 167 | image: An image from the Imagenet dataset to normalize. 168 | """ 169 | image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype) 170 | image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype) 171 | return image 172 | 173 | 174 | def preprocess_for_train(image_bytes: tf.Tensor, 175 | dtype: tf.DType = tf.float32, 176 | image_size: int = IMAGE_SIZE, 177 | use_autoaugment: bool = False) -> tf.Tensor: 178 | """Preprocesses the given image for training. 179 | 180 | Args: 181 | image_bytes: `Tensor` representing an image binary of arbitrary size. 182 | dtype: Data type of the returned image. 183 | image_size: Size of the returned image. 184 | use_autoaugment: If True, will apply autoaugment to the inputs. 185 | 186 | Returns: 187 | A preprocessed image `Tensor`. 
188 | """ 189 | image = _decode_and_random_crop(image_bytes, image_size) 190 | image = tf.reshape(image, [image_size, image_size, 3]) 191 | if use_autoaugment: 192 | logging.info('Using autoaugment.') 193 | image = tf.cast(image, tf.uint8) 194 | image = autoaugment.distort_image_with_randaugment(image, 195 | FLAGS.randaug_num_layers, 196 | FLAGS.randaug_magnitude) 197 | image = tf.cast(image, tf.float32) 198 | image = tf.image.random_flip_left_right(image) 199 | image = normalize_image(image) 200 | image = tf.image.convert_image_dtype(image, dtype=dtype) 201 | return image 202 | 203 | 204 | def preprocess_for_eval(image_bytes: tf.Tensor, 205 | dtype: tf.DType = tf.float32, 206 | image_size: int = IMAGE_SIZE) -> tf.Tensor: 207 | """Preprocesses the given image for evaluation. 208 | 209 | Args: 210 | image_bytes: `Tensor` representing an image binary of arbitrary size. 211 | dtype: Data type of the returned image. 212 | image_size: Size of the returned image. 213 | 214 | Returns: 215 | A preprocessed image `Tensor`. 216 | """ 217 | image = _decode_and_center_crop(image_bytes, image_size) 218 | image = tf.reshape(image, [image_size, image_size, 3]) 219 | image = normalize_image(image) 220 | image = tf.image.convert_image_dtype(image, dtype=dtype) 221 | return image 222 | 223 | 224 | def load_split(train: bool, 225 | cache: bool) -> tf.data.Dataset: 226 | """Creates a split from the ImageNet dataset using TensorFlow Datasets. 227 | 228 | Args: 229 | train: Whether to load the train or evaluation split. 230 | cache: Whether to cache the dataset. 231 | Returns: 232 | A `tf.data.Dataset`. 233 | """ 234 | if train: 235 | split_size = TRAIN_IMAGES // jax.host_count() 236 | start = jax.host_id() * split_size 237 | split = 'train[{}:{}]'.format(start, start + split_size) 238 | else: 239 | # For validation, we load up the dataset on each host. This will have the 240 | # effect of evaluating on the whole dataset num_host times, but will 241 | # prevent size issues. This makes the performance slightly worse when 242 | # evaluating often, but spares us the need to pad the datasets and mask the 243 | # loss accordingly. 244 | split = 'validation' 245 | 246 | ds = tfds.load('imagenet2012:5.*.*', split=split, decoders={ 247 | 'image': tfds.decode.SkipDecoding(), 248 | }) 249 | ds.options().experimental_threading.private_threadpool_size = 48 250 | ds.options().experimental_threading.max_intra_op_parallelism = 1 251 | 252 | if cache: 253 | ds = ds.cache() 254 | 255 | return ds 256 | 257 | 258 | def mixup(batch: Dict[str, tf.Tensor], alpha: float) -> Dict[str, tf.Tensor]: 259 | """Generates augmented images using Mixup. 260 | 261 | Arguments: 262 | batch: Feature dict containing the images and the labels. 263 | alpha: Float that controls the strength of Mixup regularization. 264 | 265 | Returns: 266 | A feature dict containing the mix-uped images. 267 | """ 268 | images, labels = batch['image'], batch['label'] 269 | batch_size = 1 # Unique mixing parameter for all samples 270 | mix_weight = tfp.distributions.Beta(alpha, alpha).sample([batch_size, 1]) 271 | mix_weight = tf.maximum(mix_weight, 1. - mix_weight) 272 | images_mix_weight = tf.reshape(mix_weight, [batch_size, 1, 1, 1]) 273 | images_mix = ( 274 | images * images_mix_weight + images[::-1] * (1. - images_mix_weight)) 275 | labels_mix = labels * mix_weight + labels[::-1] * (1. 
- mix_weight) 276 | return {'image': images_mix, 'label': labels_mix} 277 | 278 | 279 | class Imagenet(dataset_source.DatasetSource): 280 | """Class that loads, preprocess and serves the Imagenet dataset.""" 281 | 282 | def __init__(self, batch_size: int, image_size: int, 283 | image_level_augmentations: str = 'none'): 284 | """Instantiates the Imagenet dataset source. 285 | 286 | Args: 287 | batch_size: Global batch size used to train the model. 288 | image_size: Size to which the images should be resized (in number of 289 | pixels). 290 | image_level_augmentations: If set to 'autoaugment', will apply 291 | RandAugment to the training set. 292 | """ 293 | self.batch_size = batch_size 294 | self.image_size = image_size 295 | self.num_training_obs = TRAIN_IMAGES 296 | self._train_ds = load_split(train=True, cache=True) 297 | self._test_ds = load_split(train=False, cache=True) 298 | self._num_classes = 1000 299 | self._image_level_augmentations = image_level_augmentations 300 | 301 | def get_train(self, use_augmentations: bool) -> tf.data.Dataset: 302 | """Returns the training set. 303 | 304 | The training set will be batched, and the remainder of the batch will be 305 | dropped (except if use_augmentation is False, in which case we don't drop 306 | the remainder as we are most likely computing the accuracy on the train 307 | set). 308 | 309 | Args: 310 | use_augmentations: Whether we should apply data augmentation (and possibly 311 | cutout) or not. 312 | """ 313 | ds = self._train_ds.shuffle(16 * self.batch_size) 314 | ds = ds.map(lambda d: self.decode_example( # pylint:disable=g-long-lambda 315 | d, use_augmentations=use_augmentations), 316 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 317 | 318 | batched = ds.batch(self.batch_size, drop_remainder=use_augmentations) 319 | if use_augmentations and FLAGS.imagenet_mixup_alpha > 0.0: 320 | batched = batched.map(lambda b: mixup(b, FLAGS.imagenet_mixup_alpha), 321 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 322 | return batched 323 | 324 | def get_test(self) -> tf.data.Dataset: 325 | """Returns test set.""" 326 | ds = self._test_ds.map( 327 | lambda d: self.decode_example( # pylint:disable=g-long-lambda 328 | d, use_augmentations=False), 329 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 330 | return ds.batch(self.batch_size, drop_remainder=False) 331 | 332 | def decode_example(self, example: Dict[str, tf.Tensor], 333 | use_augmentations: bool) -> Dict[str, tf.Tensor]: 334 | """Decodes the raw examples from the imagenet tensorflow dataset. 335 | 336 | Args: 337 | example: A feature dict as returned by the tensorflow imagenet dataset. 338 | use_augmentations: Whether to use train time data augmentation or not. 339 | 340 | Returns: 341 | A dictionnary with an 'image' tensor and a one hot encoded 'label' tensor. 342 | """ 343 | if use_augmentations: 344 | image = preprocess_for_train( 345 | example['image'], 346 | image_size=self.image_size, 347 | use_autoaugment=self._image_level_augmentations == 'autoaugment') 348 | else: 349 | image = preprocess_for_eval(example['image'], image_size=self.image_size) 350 | label = tf.one_hot( 351 | example['label'], depth=self._num_classes, on_value=1.0, off_value=0.0) 352 | return {'image': image, 'label': label} 353 | -------------------------------------------------------------------------------- /sam_jax/datasets/dataset_source_imagenet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for sam.sam_jax.datasets.dataset_source_imagenet.""" 16 | 17 | from absl.testing import absltest 18 | from sam.sam_jax.datasets import dataset_source_imagenet 19 | 20 | 21 | class DatasetSourceImagenetTest(absltest.TestCase): 22 | 23 | def test_LoadImagenet(self): 24 | dataset = dataset_source_imagenet.Imagenet( 25 | batch_size=16, image_size=127, image_level_augmentations='autoaugment') 26 | batch = next(iter(dataset.get_train(use_augmentations=True))) 27 | self.assertEqual(batch['image'].shape, [16, 127, 127, 3]) 28 | self.assertEqual(batch['label'].shape, [16, 1000]) 29 | 30 | 31 | if __name__ == '__main__': 32 | absltest.main() 33 | -------------------------------------------------------------------------------- /sam_jax/datasets/dataset_source_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
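# These tests fetch the datasets through TFDS on first run. Assuming the
# package layout shown above, they can be launched directly, e.g.:
#
#   python -m sam.sam_jax.datasets.dataset_source_test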
14 | 15 | """Tests for sam.sam_jax.datasets.dataset_source.""" 16 | 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from sam.sam_jax.datasets import dataset_source 21 | 22 | 23 | class DatasetSourceTest(parameterized.TestCase): 24 | 25 | @parameterized.named_parameters( 26 | ('none', 'none'), 27 | ('cutout', 'cutout')) 28 | def test_LoadCifar10(self, batch_level_augmentation: str): 29 | cifar_10_source = dataset_source.Cifar10( 30 | 2, 31 | image_level_augmentations='autoaugment', 32 | batch_level_augmentations=batch_level_augmentation) 33 | for batch in cifar_10_source.get_train(use_augmentations=True): 34 | self.assertEqual(batch['image'].shape, [2, 32, 32, 3]) 35 | self.assertEqual(batch['label'].shape, [2, 10]) 36 | break 37 | for batch in cifar_10_source.get_test(): 38 | self.assertEqual(batch['image'].shape, [2, 32, 32, 3]) 39 | self.assertEqual(batch['label'].shape, [2, 10]) 40 | break 41 | 42 | @parameterized.named_parameters( 43 | ('none', 'none'), 44 | ('cutout', 'cutout')) 45 | def test_LoadCifar100(self, batch_level_augmentation: str): 46 | cifar_100_source = dataset_source.Cifar100( 47 | 2, 48 | image_level_augmentations='autoaugment', 49 | batch_level_augmentations=batch_level_augmentation) 50 | for batch in cifar_100_source.get_train(use_augmentations=True): 51 | self.assertEqual(batch['image'].shape, [2, 32, 32, 3]) 52 | self.assertEqual(batch['label'].shape, [2, 100]) 53 | break 54 | for batch in cifar_100_source.get_test(): 55 | self.assertEqual(batch['image'].shape, [2, 32, 32, 3]) 56 | self.assertEqual(batch['label'].shape, [2, 100]) 57 | break 58 | 59 | @parameterized.named_parameters( 60 | ('none', 'none'), 61 | ('cutout', 'cutout')) 62 | def test_LoadFashionMnist(self, batch_level_augmentation: str): 63 | fashion_mnist_source = dataset_source.FashionMnist( 64 | 2, 65 | image_level_augmentations='basic', 66 | batch_level_augmentations=batch_level_augmentation) 67 | for batch in fashion_mnist_source.get_train(use_augmentations=True): 68 | self.assertEqual(batch['image'].shape, [2, 28, 28, 1]) 69 | self.assertEqual(batch['label'].shape, [2, 10]) 70 | break 71 | for batch in fashion_mnist_source.get_test(): 72 | self.assertEqual(batch['image'].shape, [2, 28, 28, 1]) 73 | self.assertEqual(batch['label'].shape, [2, 10]) 74 | break 75 | 76 | 77 | if __name__ == '__main__': 78 | absltest.main() 79 | -------------------------------------------------------------------------------- /sam_jax/efficientnet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /sam_jax/efficientnet/efficientnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Defines Efficientnet model.""" 16 | 17 | import copy 18 | import math 19 | from typing import Any, Optional, Tuple, Union 20 | 21 | from absl import flags 22 | from absl import logging 23 | import flax 24 | from flax import nn 25 | import jax 26 | from jax import numpy as jnp 27 | import tensorflow as tf 28 | 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | def name_to_image_size(name: str) -> int: 34 | """Returns the expected image size for a given model. 35 | 36 | If the model is not a recognized efficientnet model, will default to the 37 | standard resolution of 224 (for ResNet, etc.). 38 | 39 | Args: 40 | name: Name of the efficientnet model (ex: efficientnet-b0). 41 | """ 42 | image_sizes = { 43 | 'efficientnet-b0': 224, 44 | 'efficientnet-b1': 240, 45 | 'efficientnet-b2': 260, 46 | 'efficientnet-b3': 300, 47 | 'efficientnet-b4': 380, 48 | 'efficientnet-b5': 456, 49 | 'efficientnet-b6': 528, 50 | 'efficientnet-b7': 600, 51 | 'efficientnet-b8': 672, 52 | 'efficientnet-l2': 800, 53 | 'efficientnet-l2-475': 475, 54 | } 55 | return image_sizes.get(name, 224) 56 | 57 | 58 | # Relevant initializers. The original implementation uses fan_out Kaiming init. 59 | 60 | conv_kernel_init_fn = jax.nn.initializers.variance_scaling( 61 | 2.0, 'fan_out', 'truncated_normal') 62 | 63 | dense_kernel_init_fn = jax.nn.initializers.variance_scaling( 64 | 1 / 3.0, 'fan_out', 'uniform') 65 | 66 | 67 | class DepthwiseConv(flax.nn.Module): 68 | """Depthwise convolution that matches tensorflow's conventions. 69 | 70 | In Tensorflow, the shapes of depthwise kernels don't match the shapes of a 71 | regular convolutional kernel of appropriate feature_group_count. 72 | It is safer to use this class instead of the regular Conv (easier port of 73 | tensorflow checkpoints, fan_out initialization of the previous layer will 74 | match the tensorflow behavior, etc.). 75 | """ 76 | 77 | def apply(self, 78 | inputs: jnp.ndarray, 79 | features: int, 80 | kernel_size: Tuple[int, int], 81 | strides: Optional[Tuple[int, int]] = None, 82 | padding: str = 'SAME', 83 | input_dilation: Optional[int] = None, 84 | kernel_dilation: Optional[int] = None, 85 | bias: bool = True, 86 | dtype: jnp.dtype = jnp.float32, 87 | precision=None, 88 | kernel_init=flax.nn.initializers.lecun_normal(), 89 | bias_init=flax.nn.initializers.zeros) -> jnp.ndarray: 90 | """Applies a convolution to the inputs. 91 | 92 | Args: 93 | inputs: Input data with dimensions (batch, spatial_dims..., features). 94 | features: Number of convolution filters. 95 | kernel_size: Shape of the convolutional kernel. 96 | strides: A sequence of `n` integers, representing the inter-window 97 | strides. 98 | padding: Either the string `'SAME'`, the string `'VALID'`, or a sequence 99 | of `n` `(low, high)` integer pairs that give the padding to apply before 100 | and after each spatial dimension. 
101 | input_dilation: `None`, or a sequence of `n` integers, giving the 102 | dilation factor to apply in each spatial dimension of `inputs`. 103 | Convolution with input dilation `d` is equivalent to transposed 104 | convolution with stride `d`. 105 | kernel_dilation: `None`, or a sequence of `n` integers, giving the 106 | dilation factor to apply in each spatial dimension of the convolution 107 | kernel. Convolution with kernel dilation is also known as 'atrous 108 | convolution'. 109 | bias: Whether to add a bias to the output (default: True). 110 | dtype: The dtype of the computation (default: float32). 111 | precision: Numerical precision of the computation see `jax.lax.Precision` 112 | for details. 113 | kernel_init: Initializer for the convolutional kernel. 114 | bias_init: Initializer for the bias. 115 | 116 | Returns: 117 | The convolved data. 118 | """ 119 | 120 | inputs = jnp.asarray(inputs, dtype) 121 | in_features = inputs.shape[-1] 122 | 123 | if strides is None: 124 | strides = (1,) * (inputs.ndim - 2) 125 | 126 | kernel_shape = kernel_size + (features, 1) 127 | # Naming convention follows tensorflow. 128 | kernel = self.param('depthwise_kernel', kernel_shape, kernel_init) 129 | kernel = jnp.asarray(kernel, dtype) 130 | 131 | # Need to transpose to convert tensorflow-shaped kernel to lax-shaped kernel 132 | kernel = jnp.transpose(kernel, [0, 1, 3, 2]) 133 | 134 | dimension_numbers = flax.nn.linear._conv_dimension_numbers(inputs.shape) # pylint:disable=protected-access 135 | 136 | y = jax.lax.conv_general_dilated( 137 | inputs, 138 | kernel, 139 | strides, 140 | padding, 141 | lhs_dilation=input_dilation, 142 | rhs_dilation=kernel_dilation, 143 | dimension_numbers=dimension_numbers, 144 | feature_group_count=in_features, 145 | precision=precision) 146 | 147 | if bias: 148 | bias = self.param('bias', (features,), bias_init) 149 | bias = jnp.asarray(bias, dtype) 150 | y = y + bias 151 | return y 152 | 153 | 154 | # pytype: disable=attribute-error 155 | # pylint:disable=unused-argument 156 | class BlockConfig(object): 157 | """Class that contains configuration parameters for a single block.""" 158 | 159 | def __init__(self, 160 | input_filters: int = 0, 161 | output_filters: int = 0, 162 | kernel_size: int = 3, 163 | num_repeat: int = 1, 164 | expand_ratio: int = 1, 165 | strides: Tuple[int, int] = (1, 1), 166 | se_ratio: Optional[float] = None, 167 | id_skip: bool = True, 168 | fused_conv: bool = False, 169 | conv_type: str = 'depthwise'): 170 | for arg in locals().items(): 171 | setattr(self, *arg) 172 | 173 | 174 | class ModelConfig(object): 175 | """Class that contains configuration parameters for the model.""" 176 | 177 | def __init__( 178 | self, 179 | width_coefficient: float = 1.0, 180 | depth_coefficient: float = 1.0, 181 | resolution: int = 224, 182 | dropout_rate: float = 0.2, 183 | blocks: Tuple[BlockConfig, ...] 
= ( 184 | # (input_filters, output_filters, kernel_size, num_repeat, 185 | # expand_ratio, strides, se_ratio) 186 | # pylint: disable=bad-whitespace 187 | BlockConfig(32, 16, 3, 1, 1, (1, 1), 0.25), 188 | BlockConfig(16, 24, 3, 2, 6, (2, 2), 0.25), 189 | BlockConfig(24, 40, 5, 2, 6, (2, 2), 0.25), 190 | BlockConfig(40, 80, 3, 3, 6, (2, 2), 0.25), 191 | BlockConfig(80, 112, 5, 3, 6, (1, 1), 0.25), 192 | BlockConfig(112, 192, 5, 4, 6, (2, 2), 0.25), 193 | BlockConfig(192, 320, 3, 1, 6, (1, 1), 0.25), 194 | # pylint: enable=bad-whitespace 195 | ), 196 | stem_base_filters: int = 32, 197 | top_base_filters: int = 1280, 198 | activation: str = 'swish', 199 | batch_norm: str = 'default', 200 | bn_momentum: float = 0.99, 201 | bn_epsilon: float = 1e-3, 202 | # While the original implementation used a weight decay of 1e-5, 203 | # tf.nn.l2_loss divides it by 2, so we halve this to compensate in Keras 204 | weight_decay: float = 5e-6, 205 | drop_connect_rate: float = 0.2, 206 | depth_divisor: int = 8, 207 | min_depth: Optional[int] = None, 208 | use_se: bool = True, 209 | input_channels: int = 3, 210 | model_name: str = 'efficientnet', 211 | rescale_input: bool = True, 212 | data_format: str = 'channels_last', 213 | dtype: str = 'float32'): 214 | """Default Config for Efficientnet-B0.""" 215 | for arg in locals().items(): 216 | setattr(self, *arg) 217 | # pylint:enable=unused-argument 218 | 219 | 220 | MODEL_CONFIGS = { 221 | # (width, depth, resolution, dropout) 222 | 'efficientnet-b0': ModelConfig(1.0, 1.0, 224, 0.2), 223 | 'efficientnet-b1': ModelConfig(1.0, 1.1, 240, 0.2), 224 | 'efficientnet-b2': ModelConfig(1.1, 1.2, 260, 0.3), 225 | 'efficientnet-b3': ModelConfig(1.2, 1.4, 300, 0.3), 226 | 'efficientnet-b4': ModelConfig(1.4, 1.8, 380, 0.4), 227 | 'efficientnet-b5': ModelConfig(1.6, 2.2, 456, 0.4), 228 | 'efficientnet-b6': ModelConfig(1.8, 2.6, 528, 0.5), 229 | 'efficientnet-b7': ModelConfig(2.0, 3.1, 600, 0.5), 230 | 'efficientnet-b8': ModelConfig(2.2, 3.6, 672, 0.5), 231 | 'efficientnet-l2': ModelConfig(4.3, 5.3, 800, 0.5), 232 | 'efficientnet-l2-475': ModelConfig(4.3, 5.3, 475, 0.5), 233 | } 234 | 235 | 236 | def round_filters(filters: int, 237 | config: ModelConfig) -> int: 238 | """Returns rounded number of filters based on width coefficient.""" 239 | width_coefficient = config.width_coefficient 240 | min_depth = config.min_depth 241 | divisor = config.depth_divisor 242 | orig_filters = filters 243 | 244 | if not width_coefficient: 245 | return filters 246 | 247 | filters *= width_coefficient 248 | min_depth = min_depth or divisor 249 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) 250 | # Make sure that round down does not go down by more than 10%. 
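  # Editor's worked example (illustrative values): with divisor=8 and
  # width_coefficient=1.1 (efficientnet-b2), filters=32 scales to 35.2 and
  # rounds to 32, which is kept since 32 >= 0.9 * 35.2. With a hypothetical
  # coefficient of 1.12, 35.84 would also round down to 32, but
  # 32 < 0.9 * 35.84, so the divisor is added back below, giving 40.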
251 |   if new_filters < 0.9 * filters:
252 |     new_filters += divisor
253 |   logging.info('round_filter input=%s output=%s', orig_filters, new_filters)
254 |   return int(new_filters)
255 | 
256 | 
257 | def round_repeats(repeats: int, depth_coefficient: float) -> int:
258 |   """Returns rounded number of repeats based on depth coefficient."""
259 |   return int(math.ceil(depth_coefficient * repeats))
260 | 
261 | 
262 | def conv2d(inputs: jnp.ndarray,
263 |            conv_filters: Optional[int],
264 |            config: ModelConfig,
265 |            kernel_size: Union[int, Tuple[int, int]] = (1, 1),
266 |            strides: Tuple[int, int] = (1, 1),
267 |            use_batch_norm: bool = True,
268 |            use_bias: bool = False,
269 |            activation: Any = None,
270 |            depthwise: bool = False,
271 |            train: bool = True,
272 |            conv_name: Optional[str] = None,
273 |            bn_name: Optional[str] = None) -> jnp.ndarray:
274 |   """Convolutional layer, optionally followed by batch norm and activation.
275 | 
276 |   Args:
277 |     inputs: Input data with dimensions (batch, spatial_dims..., features).
278 |     conv_filters: Number of convolution filters.
279 |     config: Configuration for the model.
280 |     kernel_size: Size of the kernel, as a tuple of int.
281 |     strides: Strides for the convolution, as a tuple of int.
282 |     use_batch_norm: Whether batch norm should be applied to the output.
283 |     use_bias: Whether a bias should be added to the output of the convolution.
284 |     activation: Name of the activation function to use.
285 |     depthwise: If true, will use depthwise convolutions.
286 |     train: Whether the model should behave in training or inference mode.
287 |     conv_name: Name to give to the convolution layer.
288 |     bn_name: Name to give to the batch norm layer.
289 | 
290 |   Returns:
291 |     The output of the convolutional layer.
292 |   """
293 |   conv_fn = DepthwiseConv if depthwise else flax.nn.Conv
294 |   kernel_size = ((kernel_size, kernel_size)
295 |                  if isinstance(kernel_size, int) else tuple(kernel_size))
296 |   conv_name = conv_name if conv_name else 'conv2d'
297 |   bn_name = bn_name if bn_name else 'batch_normalization'
298 | 
299 |   x = conv_fn(
300 |       inputs,
301 |       conv_filters,
302 |       kernel_size,
303 |       tuple(strides),
304 |       padding='SAME',
305 |       bias=use_bias,
306 |       kernel_init=conv_kernel_init_fn,
307 |       name=conv_name)
308 | 
309 |   if use_batch_norm:
310 |     x = nn.BatchNorm(
311 |         x,
312 |         use_running_average=(not train) or FLAGS.from_pretrained_checkpoint,
313 |         momentum=config.bn_momentum,
314 |         epsilon=config.bn_epsilon,
315 |         name=bn_name,
316 |         axis_name='batch')
317 | 
318 |   if activation is not None:
319 |     x = getattr(flax.nn.activation, activation.lower())(x)
320 |   return x
321 | 
322 | 
323 | def stochastic_depth(inputs: jnp.ndarray,
324 |                      survival_probability: float,
325 |                      deterministic: bool = False,
326 |                      rng: Optional[jnp.ndarray] = None) -> jnp.ndarray:
327 |   """Applies stochastic depth.
328 | 
329 |   Args:
330 |     inputs: The inputs that should be randomly masked.
331 |     survival_probability: 1 - the probability of masking out a value.
332 |     deterministic: If false, the inputs are scaled by
333 |       `1 / survival_probability` and randomly masked, whereas if true, no
334 |       mask is applied and the inputs are returned as is.
335 |     rng: An optional `jax.random.PRNGKey`. By default `nn.make_rng()` will
336 |       be used.
337 | 
338 |   Returns:
339 |     The masked inputs.
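  Example (editor's note, illustrative values): with survival_probability=0.9,
  each example in the batch is zeroed out with probability 0.1 and otherwise
  scaled by 1 / 0.9, so the expected value of the output matches the input.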
340 | """ 341 | if survival_probability == 1.0 or deterministic: 342 | return inputs 343 | 344 | if rng is None: 345 | rng = flax.nn.make_rng() 346 | mask_shape = [inputs.shape[0]]+ [1 for _ in inputs.shape[1:]] 347 | mask = jax.random.bernoulli(rng, p=survival_probability, shape=mask_shape) 348 | mask = jnp.tile(mask, [1] + list(inputs.shape[1:])) 349 | return jax.lax.select(mask, inputs / survival_probability, 350 | jnp.zeros_like(inputs)) 351 | 352 | 353 | class SqueezeExcite(flax.nn.Module): 354 | """SqueezeExite block (see paper for more details.)""" 355 | 356 | def apply(self, 357 | x: jnp.ndarray, 358 | filters: int, 359 | block: BlockConfig, 360 | config: ModelConfig, 361 | train: bool) -> jnp.ndarray: 362 | """Applies a convolution to the inputs. 363 | 364 | Args: 365 | x: Input data with dimensions (batch, spatial_dims..., features). 366 | filters: Number of convolution filters. 367 | block: Configuration for this block. 368 | config: Configuration for the model. 369 | train: Whether the model is in training or inference mode. 370 | 371 | Returns: 372 | The output of the squeeze excite block. 373 | """ 374 | conv_index = 0 375 | num_reduced_filters = max(1, int(block.input_filters * block.se_ratio)) 376 | 377 | se = flax.nn.avg_pool(x, x.shape[1:3]) 378 | se = conv2d( 379 | se, 380 | num_reduced_filters, 381 | config, 382 | use_bias=True, 383 | use_batch_norm=False, 384 | activation=config.activation, 385 | conv_name='reduce_conv2d_' + str(conv_index), 386 | train=train) 387 | conv_index += 1 388 | 389 | se = conv2d( 390 | se, 391 | filters, 392 | config, 393 | use_bias=True, 394 | use_batch_norm=False, 395 | activation='sigmoid', 396 | conv_name='expand_conv2d_' + str(conv_index), 397 | train=train) 398 | conv_index += 1 399 | x = x * se 400 | return x 401 | 402 | 403 | class MBConvBlock(flax.nn.Module): 404 | """Main building component of Efficientnet.""" 405 | 406 | def apply(self, 407 | inputs: jnp.ndarray, 408 | block: BlockConfig, 409 | config: ModelConfig, 410 | train: bool = False) -> jnp.ndarray: 411 | """Mobile Inverted Residual Bottleneck. 412 | 413 | Args: 414 | inputs: Input to the block. 415 | block: BlockConfig, arguments to create a Block. 416 | config: ModelConfig, a set of model parameters. 417 | train: Whether we are training or predicting. 418 | 419 | Returns: 420 | The output of the block. 421 | """ 422 | use_se = config.use_se 423 | activation = config.activation 424 | drop_connect_rate = config.drop_connect_rate 425 | use_depthwise = block.conv_type != 'no_depthwise' 426 | 427 | filters = block.input_filters * block.expand_ratio 428 | 429 | x = inputs 430 | bn_index = 0 431 | conv_index = 0 432 | 433 | if block.fused_conv: 434 | # If we use fused mbconv, skip expansion and use regular conv. 
435 | x = conv2d( 436 | x, 437 | filters, 438 | config, 439 | kernel_size=block.kernel_size, 440 | strides=block.strides, 441 | activation=activation, 442 | conv_name='fused_conv2d_' + str(conv_index), 443 | bn_name='batch_normalization_' + str(bn_index), 444 | train=train) 445 | bn_index += 1 446 | conv_index += 1 447 | else: 448 | if block.expand_ratio != 1: 449 | # Expansion phase 450 | kernel_size = (1, 1) if use_depthwise else (3, 3) 451 | x = conv2d( 452 | x, 453 | filters, 454 | config, 455 | kernel_size=kernel_size, 456 | activation=activation, 457 | conv_name='expand_conv2d_' + str(conv_index), 458 | bn_name='batch_normalization_' + str(bn_index), 459 | train=train) 460 | bn_index += 1 461 | conv_index += 1 462 | # Depthwise Convolution 463 | if use_depthwise: 464 | x = conv2d(x, 465 | conv_filters=x.shape[-1], # Depthwise conv 466 | config=config, 467 | kernel_size=block.kernel_size, 468 | strides=block.strides, 469 | activation=activation, 470 | depthwise=True, 471 | conv_name='depthwise_conv2d', 472 | bn_name='batch_normalization_' + str(bn_index), 473 | train=train) 474 | bn_index += 1 475 | 476 | # Squeeze and Excitation phase 477 | if use_se: 478 | assert block.se_ratio is not None 479 | assert 0 < block.se_ratio <= 1 480 | x = SqueezeExcite(x, filters, block, config, train=train) 481 | 482 | # Output phase 483 | x = conv2d( 484 | x, 485 | block.output_filters, 486 | config, 487 | activation=None, 488 | conv_name='project_conv2d_' + str(conv_index), 489 | bn_name='batch_normalization_' + str(bn_index), 490 | train=train) 491 | conv_index += 1 492 | 493 | if (block.id_skip and all(s == 1 for s in block.strides) and 494 | block.input_filters == block.output_filters): 495 | if drop_connect_rate and drop_connect_rate > 0: 496 | survival_probability = 1 - drop_connect_rate 497 | x = stochastic_depth(x, survival_probability, deterministic=not train) 498 | x = x + inputs 499 | 500 | return x 501 | 502 | 503 | class Stem(flax.nn.Module): 504 | """Initial block of Efficientnet.""" 505 | 506 | def apply(self, 507 | x: jnp.ndarray, 508 | config: ModelConfig, 509 | train: bool = True) -> jnp.ndarray: 510 | """Returns the output of the stem block. 511 | 512 | Args: 513 | x: The input to the block. 514 | config: ModelConfig, a set of model parameters. 515 | train: Whether we are training or predicting. 516 | """ 517 | resolution = config.resolution 518 | if x.shape[1:3] != (resolution, resolution): 519 | raise ValueError('Wrong input size. Model was expecting ' + 520 | 'resolution {} '.format((resolution, resolution)) + 521 | 'but got input of resolution {}'.format(x.shape[1:3])) 522 | 523 | # Build stem 524 | x = conv2d( 525 | x, 526 | round_filters(config.stem_base_filters, config), 527 | config, 528 | kernel_size=(3, 3), 529 | strides=(2, 2), 530 | activation=config.activation, 531 | train=train) 532 | return x 533 | 534 | 535 | class Head(flax.nn.Module): 536 | """Final block of Efficientnet.""" 537 | 538 | def apply(self, 539 | x: jnp.ndarray, 540 | config: ModelConfig, 541 | num_classes: int, 542 | train: bool = True) -> jnp.ndarray: 543 | """Returns the output of the head block. 544 | 545 | Args: 546 | x: The input to the block. 547 | config: A set of model parameters. 548 | num_classes: Dimension of the output of the model. 549 | train: Whether we are training or predicting. 
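    Returns:
      The output of the head block, of shape (batch_size, num_classes).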
550 | """ 551 | # Build top 552 | x = conv2d( 553 | x, 554 | round_filters(config.top_base_filters, config), 555 | config, 556 | activation=config.activation, 557 | train=train) 558 | 559 | # Build classifier 560 | x = flax.nn.avg_pool(x, x.shape[1:3]) 561 | if config.dropout_rate and config.dropout_rate > 0: 562 | x = flax.nn.dropout(x, config.dropout_rate, deterministic=not train) 563 | x = flax.nn.Dense( 564 | x, num_classes, kernel_init=dense_kernel_init_fn, name='dense') 565 | x = x.reshape([x.shape[0], -1]) 566 | return x 567 | 568 | 569 | class EfficientNet(flax.nn.Module): 570 | """Implements EfficientNet model.""" 571 | 572 | def apply(self, 573 | x: jnp.ndarray, 574 | config: ModelConfig, 575 | num_classes: int = 1000, 576 | train: bool = True) -> jnp.ndarray: 577 | """Returns the output of the EfficientNet model. 578 | 579 | Args: 580 | x: The input batch of images. 581 | config: The model config. 582 | num_classes: Dimension of the output layer. 583 | train: Whether we are in training or inference. 584 | 585 | Returns: 586 | The output of efficientnet 587 | """ 588 | config = copy.deepcopy(config) 589 | depth_coefficient = config.depth_coefficient 590 | blocks = config.blocks 591 | drop_connect_rate = config.drop_connect_rate 592 | 593 | resolution = config.resolution 594 | if x.shape[1:3] != (resolution, resolution): 595 | raise ValueError('Wrong input size. Model was expecting ' + 596 | 'resolution {} '.format((resolution, resolution)) + 597 | 'but got input of resolution {}'.format(x.shape[1:3])) 598 | 599 | # Build stem 600 | x = Stem(x, config, train=train) 601 | 602 | # Build blocks 603 | num_blocks_total = sum( 604 | round_repeats(block.num_repeat, depth_coefficient) for block in blocks) 605 | block_num = 0 606 | 607 | for block in blocks: 608 | assert block.num_repeat > 0 609 | # Update block input and output filters based on depth multiplier 610 | block.input_filters = round_filters(block.input_filters, config) 611 | block.output_filters = round_filters(block.output_filters, config) 612 | block.num_repeat = round_repeats(block.num_repeat, depth_coefficient) 613 | 614 | # The first block needs to take care of stride and filter size increase 615 | drop_rate = drop_connect_rate * float(block_num) / num_blocks_total 616 | config.drop_connect_rate = drop_rate 617 | x = MBConvBlock(x, block, config, train=train) 618 | block_num += 1 619 | if block.num_repeat > 1: 620 | block.input_filters = block.output_filters 621 | block.strides = [1, 1] 622 | 623 | for _ in range(block.num_repeat - 1): 624 | drop_rate = drop_connect_rate * float(block_num) / num_blocks_total 625 | config.drop_connect_rate = drop_rate 626 | x = MBConvBlock(x, block, config, train=train) 627 | block_num += 1 628 | 629 | # Build top 630 | x = Head(x, config, num_classes, train=train) 631 | 632 | return x 633 | # pytype: enable=attribute-error 634 | 635 | 636 | def get_efficientnet_module(model_name: str, 637 | num_classes: int = 1000) -> EfficientNet: 638 | """Returns an EfficientNet module for a given architecture. 639 | 640 | Args: 641 | model_name: Name of the Efficientnet architecture to use (example: 642 | efficientnet-b0). 643 | num_classes: Dimension of the output layer. 
644 | """ 645 | return EfficientNet.partial(config=MODEL_CONFIGS[model_name], 646 | num_classes=num_classes) 647 | -------------------------------------------------------------------------------- /sam_jax/efficientnet/efficientnet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for sam.sam_jax.efficientnet.efficientnet.""" 16 | 17 | from absl.testing import absltest 18 | import jax 19 | import numpy as np 20 | from sam.sam_jax.efficientnet import efficientnet 21 | from sam.sam_jax.imagenet_models import load_model 22 | 23 | 24 | class EfficientnetTest(absltest.TestCase): 25 | 26 | def _test_model_params( 27 | self, model_name: str, image_size: int, expected_params: int): 28 | module = efficientnet.get_efficientnet_module(model_name) 29 | model, _ = load_model.create_image_model(jax.random.PRNGKey(0), 30 | batch_size=1, 31 | image_size=image_size, 32 | module=module) 33 | num_params = sum(np.prod(e.shape) for e in jax.tree_leaves(model)) 34 | self.assertEqual(num_params, expected_params) 35 | 36 | def test_efficientnet_b0(self): 37 | self._test_model_params('efficientnet-b0', 224, expected_params=5288548) 38 | 39 | def test_efficientnet_b1(self): 40 | self._test_model_params('efficientnet-b1', 240, expected_params=7794184) 41 | 42 | def test_efficientnet_b2(self): 43 | self._test_model_params('efficientnet-b2', 260, expected_params=9109994) 44 | 45 | def test_efficientnet_b3(self): 46 | self._test_model_params('efficientnet-b3', 300, expected_params=12233232) 47 | 48 | def test_efficientnet_b4(self): 49 | self._test_model_params('efficientnet-b4', 380, expected_params=19341616) 50 | 51 | def test_efficientnet_b5(self): 52 | self._test_model_params('efficientnet-b5', 456, expected_params=30389784) 53 | 54 | def test_efficientnet_b6(self): 55 | self._test_model_params('efficientnet-b6', 528, expected_params=43040704) 56 | 57 | def test_efficientnet_b7(self): 58 | self._test_model_params('efficientnet-b7', 600, expected_params=66347960) 59 | 60 | 61 | if __name__ == '__main__': 62 | absltest.main() 63 | -------------------------------------------------------------------------------- /sam_jax/efficientnet/optim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """EfficientNet is trained with TF1 RMSProp, which is reproduced here.
16 | 
17 | This version has two differences from the FLAX RMSProp version:
18 |  * It supports momentum, which the FLAX version does not.
19 |  * The moving average of the squared gradient is initialized to 1 as in TF1,
20 |    instead of 0 as everywhere else.
21 | """
22 | 
23 | from typing import Any, Optional, Tuple
24 | 
25 | from flax import struct
26 | from flax.optim.base import OptimizerDef
27 | import jax
28 | import jax.numpy as jnp
29 | import numpy as onp
30 | 
31 | 
32 | # pytype:disable=wrong-arg-count
33 | @struct.dataclass
34 | class _RMSPropHyperParams:
35 |   """RMSProp hyperparameters."""
36 | 
37 |   learning_rate: float
38 |   beta: float
39 |   beta2: float
40 |   eps: float
41 | 
42 | 
43 | @struct.dataclass
44 | class _RMSPropParamState:
45 |   """RMSProp parameter state."""
46 | 
47 |   v: onp.ndarray
48 |   momentum: onp.ndarray
49 | 
50 | 
51 | class RMSProp(OptimizerDef):
52 |   """RMSProp optimizer."""
53 | 
54 |   def __init__(self, learning_rate: Optional[float] = None, beta: float = 0.0,
55 |                beta2: float = 0.9, eps: float = 1e-8):
56 |     """Instantiates the RMSProp optimizer.
57 | 
58 |     Args:
59 |       learning_rate: The step size used to update the parameters.
60 |       beta: Momentum coefficient.
61 |       beta2: The coefficient used for the moving average of the
62 |         gradient magnitude (default: 0.9).
63 |       eps: The term added to the gradient magnitude estimate for
64 |         numerical stability.
65 |     """
66 |     hyper_params = _RMSPropHyperParams(learning_rate, beta, beta2, eps)
67 |     super().__init__(hyper_params)
68 | 
69 |   def init_param_state(self, param: jnp.ndarray) -> _RMSPropParamState:
70 |     """Initializes parameter state. See base class."""
71 |     return _RMSPropParamState(
72 |         jnp.ones_like(param), jnp.zeros_like(param))
73 | 
74 |   def apply_param_gradient(
75 |       self, step: jnp.ndarray,
76 |       hyper_params: _RMSPropHyperParams,
77 |       param: jnp.ndarray,
78 |       state: _RMSPropParamState,
79 |       grad: jnp.ndarray) -> Tuple[jnp.ndarray, _RMSPropParamState]:
80 |     """Applies per-parameter gradients. See base class."""
81 |     assert hyper_params.learning_rate is not None, 'no learning rate provided.'
82 |     new_v = hyper_params.beta2 * state.v + (
83 |         1.0 - hyper_params.beta2) * jnp.square(grad)
84 |     grad = grad / jnp.sqrt(new_v + hyper_params.eps)
85 |     new_momentum = hyper_params.beta * state.momentum + grad
86 |     new_param = param - hyper_params.learning_rate * new_momentum
87 |     new_state = _RMSPropParamState(new_v, new_momentum)
88 | 
89 |     return new_param, new_state
90 | 
91 | 
92 | # pytype:disable=attribute-error
93 | @struct.dataclass
94 | class ExponentialMovingAverage:
95 |   """Exponential Moving Average as implemented in TensorFlow."""
96 | 
97 |   # Moving average of the parameters.
98 |   param_ema: Any
99 |   # Decay to use for the update (typical values are 0.999, 0.9999, etc.).
100 |   decay: float
101 |   # For how many steps we should just keep the new parameters instead of an
102 |   # average (useful if we don't want the initial weights to be included in the
103 |   # average).
104 |   warmup_steps: int
105 | 
106 |   def update_moving_average(self, new_target: Any,
107 |                             step: jnp.ndarray) -> Any:
108 |     """Updates the moving average of the target.
109 | 
110 |     Args:
111 |       new_target: New values of the target (example: weights of a network
112 |         after gradient step).
113 |       step: Current step (used only for warmup).
114 | 
115 |     Returns:
116 |       The updated ExponentialMovingAverage.
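    Note (editor): this matches tf.train.ExponentialMovingAverage with
    `num_updates` set: past warmup, the effective decay is
    min(decay, (1 + steps) / (10 + steps)) where steps counts updates since
    warmup_steps, and during warmup the decay is zeroed so the newest
    parameters are kept as-is.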
117 | """ 118 | factor = jnp.float32(step >= self.warmup_steps) 119 | delta = step - self.warmup_steps 120 | decay = jnp.minimum(self.decay, (1. + delta) / (10. + delta)) 121 | decay *= factor 122 | weight_ema = jax.tree_multimap( 123 | lambda a, b: (1 - decay) * a + decay * b, new_target, self.param_ema) 124 | return self.replace(param_ema=weight_ema) 125 | # pytype:enable=attribute-error 126 | -------------------------------------------------------------------------------- /sam_jax/efficientnet/optim_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for google3.learning.neurosurgeon.research.image_classification.efficientnet.optim.""" 16 | 17 | from absl.testing import absltest 18 | from jax.config import config 19 | import numpy as onp 20 | from sam.sam_jax.efficientnet import optim 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | # Use double precision for better comparison with Tensorflow version. 25 | config.update("jax_enable_x64", True) 26 | 27 | 28 | class OptimTest(tf.test.TestCase): 29 | 30 | def test_RMSProp(self): 31 | """Updates should match Tensorflow1 behavior.""" 32 | lr, mom, rho = 0.5, 0.7, 0.8 33 | onp.random.seed(0) 34 | w0 = onp.random.normal(size=[17, 13, 1]) 35 | num_steps = 10 36 | # First compute weights updates for TF1 version. 37 | tf1_updated_weights = [] 38 | with tf.Session() as sess: 39 | var0 = tf.Variable(w0, trainable=True) 40 | opt = tf.train.RMSPropOptimizer( 41 | learning_rate=lr, decay=rho, momentum=mom, epsilon=0.001) 42 | loss = lambda: (var0**2) / 2.0 43 | step = opt.minimize(loss, var_list=[var0]) 44 | sess.run(tf.global_variables_initializer()) 45 | for _ in range(num_steps): 46 | sess.run(step) 47 | tf1_updated_weights.append(sess.run(var0)) 48 | # Now compute the updates for FLAX version. 49 | flax_updated_weights = [] 50 | optimizer_def = optim.RMSProp( 51 | learning_rate=lr, beta=mom, beta2=rho, eps=0.001) 52 | ref_opt = optimizer_def.create(w0) 53 | for _ in range(num_steps): 54 | gradient = ref_opt.target 55 | ref_opt = ref_opt.apply_gradient(gradient) 56 | flax_updated_weights.append(ref_opt.target) 57 | for a, b in zip(tf1_updated_weights, flax_updated_weights): 58 | self.assertAllClose(a, b) 59 | 60 | def test_RMSPropWithEMA(self): 61 | """Updates should match Tensorflow1 behavior.""" 62 | lr, mom, rho, ema_decay = 0.05, 0.4, 0.8, 1.0 63 | onp.random.seed(0) 64 | w0 = onp.array([1.0]) 65 | num_steps = 10 66 | # First compute weights updates for TF1 version. 
67 | tf1_updated_weights = [] 68 | with tf.Session() as sess: 69 | global_step = tf.train.get_or_create_global_step() 70 | ema = tf.train.ExponentialMovingAverage( 71 | decay=ema_decay, num_updates=global_step) 72 | var0 = tf.Variable(w0, trainable=True) 73 | opt = tf.train.RMSPropOptimizer( 74 | learning_rate=lr, decay=rho, momentum=mom, epsilon=0.000) 75 | loss = lambda: (var0**2) / 2.0 76 | step = opt.minimize(loss, var_list=[var0], global_step=global_step) 77 | with tf.control_dependencies([step]): 78 | step = ema.apply([var0]) 79 | sess.run(tf.global_variables_initializer()) 80 | for _ in range(num_steps): 81 | sess.run(step) 82 | tf1_updated_weights.append(sess.run(ema.average(var0))) 83 | # Now computes the updates for FLAX version. 84 | flax_updated_weights = [] 85 | optimizer_def = optim.RMSProp( 86 | learning_rate=lr, beta=mom, beta2=rho, eps=0.000) 87 | ref_opt = optimizer_def.create(w0) 88 | ema = optim.ExponentialMovingAverage(w0, ema_decay, 0) 89 | for _ in range(num_steps): 90 | gradient = ref_opt.target 91 | ref_opt = ref_opt.apply_gradient(gradient) 92 | ema = ema.update_moving_average(ref_opt.target, ref_opt.state.step) 93 | flax_updated_weights.append(ema.param_ema) 94 | for a, b in zip(tf1_updated_weights, flax_updated_weights): 95 | self.assertAllClose(a, b) 96 | 97 | 98 | if __name__ == "__main__": 99 | absltest.main() 100 | -------------------------------------------------------------------------------- /sam_jax/imagenet_models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /sam_jax/imagenet_models/load_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Build FLAX models for image classification."""
16 | 
17 | from typing import Optional, Tuple
18 | 
19 | from absl import flags
20 | import flax
21 | from flax.training import checkpoints
22 | import jax
23 | from jax import numpy as jnp
24 | from jax import random
25 | 
26 | from sam.sam_jax.efficientnet import efficientnet
27 | from sam.sam_jax.imagenet_models import resnet
28 | 
29 | 
30 | FLAGS = flags.FLAGS
31 | 
32 | flags.DEFINE_bool('from_pretrained_checkpoint', False,
33 |                   'If True, the model will be restored from a pretrained '
34 |                   'checkpoint.')
35 | flags.DEFINE_string('efficientnet_checkpoint_path', None,
36 |                     'If finetuning, path to the efficientnet checkpoint.')
37 | 
38 | 
39 | _AVAILABLE_MODEL_NAMES = [
40 |     'Resnet50', 'Resnet101', 'Resnet152'
41 | ] + list(efficientnet.MODEL_CONFIGS.keys())
42 | 
43 | 
44 | def create_image_model(
45 |     prng_key: jnp.ndarray, batch_size: int, image_size: int,
46 |     module: flax.nn.Module) -> Tuple[flax.nn.Model, flax.nn.Collection]:
47 |   """Instantiates a FLAX model and its state.
48 | 
49 |   Args:
50 |     prng_key: PRNG key to use to sample the initial weights.
51 |     batch_size: Batch size that the model should expect.
52 |     image_size: Dimension of the image (assumed to be squared).
53 |     module: FLAX module describing the model to instantiate.
54 | 
55 |   Returns:
56 |     A FLAX model and its state.
57 |   """
58 |   input_shape = (batch_size, image_size, image_size, 3)
59 |   with flax.nn.stateful() as init_state:
60 |     with flax.nn.stochastic(jax.random.PRNGKey(0)):
61 |       _, initial_params = module.init_by_shape(
62 |           prng_key, [(input_shape, jnp.float32)])
63 |       model = flax.nn.Model(module, initial_params)
64 |   return model, init_state
65 | 
66 | 
67 | class ModelNameError(Exception):
68 |   """Exception to raise when the model name is not recognized."""
69 |   pass
70 | 
71 | 
72 | def _replace_dense_layer(model: flax.nn.Model, head: flax.nn.Model):
73 |   """Replaces the last layer (head) of a model with the head of another one.
74 | 
75 |   Args:
76 |     model: Model for which we should keep all layers except the head.
77 |     head: Model from which we should copy the head.
78 | 
79 |   Returns:
80 |     A model composed from the last layer of `head` and all the other layers of
81 |     `model`.
82 |   """
83 |   new_params = {}
84 |   for (ak, av), (bk, bv) in zip(
85 |       flax.traverse_util.flatten_dict(model.params).items(),
86 |       flax.traverse_util.flatten_dict(head.params).items()):
87 |     if ak[1] == 'dense':
88 |       new_params[bk] = bv
89 |     else:
90 |       new_params[ak] = av
91 |   return head.replace(params=flax.traverse_util.unflatten_dict(new_params))
92 | 
93 | 
94 | def get_model(
95 |     model_name: str,
96 |     batch_size: int,
97 |     image_size: int,
98 |     num_classes: int = 1000,
99 |     prng_key: Optional[jnp.ndarray] = None
100 | ) -> Tuple[flax.nn.Model, flax.nn.Collection]:
101 |   """Returns an initialized model of the chosen architecture.
102 | 
103 |   Args:
104 |     model_name: Name of the architecture to use. See the sam.sam_jax.train
105 |       flags for a list of available models.
106 |     batch_size: The batch size that the model should expect.
107 |     image_size: Dimension of the image (assumed to be squared).
108 |     num_classes: Dimension of the output layer. Usually 1000 (ImageNet);
109 |       other values are supported when finetuning from a pretrained
110 |       checkpoint, in which case the final dense layer is freshly initialized.
111 |     prng_key: PRNG key to use to sample the weights.
112 | 
113 |   Returns:
114 |     The initialized model and its state.
115 | 
116 |   Raises:
117 |     ModelNameError: If the name of the architecture is not recognized.
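  Example (editor's sketch, as exercised in load_model_test.py):
    model, init_state = get_model('efficientnet-b0', batch_size=1,
                                  image_size=224, num_classes=1000)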
118 | """ 119 | if model_name == 'Resnet50': 120 | module = resnet.ResNet50.partial(num_classes=num_classes) 121 | elif model_name == 'Resnet101': 122 | module = resnet.ResNet101.partial(num_classes=num_classes) 123 | elif model_name == 'Resnet152': 124 | module = resnet.ResNet152.partial(num_classes=num_classes) 125 | elif model_name in efficientnet.MODEL_CONFIGS: 126 | module = efficientnet.get_efficientnet_module( 127 | model_name, num_classes=num_classes) 128 | else: 129 | raise ModelNameError('Unrecognized model name.') 130 | if not prng_key: 131 | prng_key = random.PRNGKey(0) 132 | 133 | model, init_state = create_image_model(prng_key, batch_size, image_size, 134 | module) 135 | 136 | if FLAGS.from_pretrained_checkpoint: 137 | if FLAGS.efficientnet_checkpoint_path is None: 138 | raise ValueError( 139 | 'For finetuning, must set `efficientnet_checkpoint_path` to a ' 140 | 'valid efficientnet checkpoint.') 141 | # If the number of class is 1000, just load the imagenet/JFT checkpoint. 142 | if num_classes == 1000: 143 | model, init_state = checkpoints.restore_checkpoint( 144 | FLAGS.efficientnet_checkpoint_path, 145 | (model, init_state)) 146 | # Else we need to change the size of the last layer (head): 147 | else: 148 | # Pretrained model on JFT/Imagenet. 149 | imagenet_module = efficientnet.get_efficientnet_module( 150 | model_name, num_classes=1000) 151 | imagenet_model, imagenet_state = create_image_model( 152 | prng_key, batch_size, image_size, imagenet_module) 153 | imagenet_model, imagenet_state = checkpoints.restore_checkpoint( 154 | FLAGS.efficientnet_checkpoint_path, 155 | (imagenet_model, imagenet_state)) 156 | # Replace all the layers of the initialized model with the weights 157 | # extracted from the pretrained model. 158 | model = _replace_dense_layer(imagenet_model, model) 159 | init_state = imagenet_state 160 | 161 | return model, init_state 162 | -------------------------------------------------------------------------------- /sam_jax/imagenet_models/load_model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for sam.sam_jax.imagenet_models.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import flax 20 | import numpy as np 21 | from sam.sam_jax.imagenet_models import load_model 22 | 23 | 24 | class LoadModelTest(parameterized.TestCase): 25 | 26 | def test_CreateEfficientnetModel(self): 27 | model, state = load_model.get_model('efficientnet-b0', 1, 224, 1000) 28 | self.assertIsInstance(model, flax.nn.Model) 29 | self.assertIsInstance(state, flax.nn.Collection) 30 | fake_input = np.zeros([1, 224, 224, 3]) 31 | with flax.nn.stateful(state, mutable=False): 32 | logits = model(fake_input, train=False) 33 | self.assertEqual(logits.shape, (1, 1000)) 34 | 35 | def test_CreateResnetModel(self): 36 | model, state = load_model.get_model('Resnet50', 1, 224, 1000) 37 | self.assertIsInstance(model, flax.nn.Model) 38 | self.assertIsInstance(state, flax.nn.Collection) 39 | fake_input = np.zeros([1, 224, 224, 3]) 40 | with flax.nn.stateful(state, mutable=False): 41 | logits = model(fake_input, train=False) 42 | self.assertEqual(logits.shape, (1, 1000)) 43 | 44 | 45 | if __name__ == '__main__': 46 | absltest.main() 47 | -------------------------------------------------------------------------------- /sam_jax/imagenet_models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Flax implementation of ResNet V1. 
16 | 17 | Forked from 18 | https://github.com/google/flax/blob/master/examples/imagenet/resnet_v1.py 19 | """ 20 | 21 | 22 | from flax import nn 23 | 24 | import jax.numpy as jnp 25 | 26 | 27 | class ResNetBlock(nn.Module): 28 | """ResNet block.""" 29 | 30 | def apply(self, x, filters, *, 31 | conv, norm, act, 32 | strides=(1, 1)): 33 | residual = x 34 | y = conv(x, filters, (3, 3), strides) 35 | y = norm(y) 36 | y = act(y) 37 | y = conv(y, filters, (3, 3)) 38 | y = norm(y, scale_init=nn.initializers.zeros) 39 | 40 | if residual.shape != y.shape: 41 | residual = conv(residual, filters, (1, 1), strides, name='conv_proj') 42 | residual = norm(residual, name='norm_proj') 43 | 44 | return act(residual + y) 45 | 46 | 47 | class BottleneckResNetBlock(nn.Module): 48 | """Bottleneck ResNet block.""" 49 | 50 | def apply(self, x, filters, *, 51 | conv, norm, act, 52 | strides=(1, 1)): 53 | residual = x 54 | y = conv(x, filters, (1, 1)) 55 | y = norm(y) 56 | y = act(y) 57 | y = conv(y, filters, (3, 3), strides) 58 | y = norm(y) 59 | y = act(y) 60 | y = conv(y, filters * 4, (1, 1)) 61 | y = norm(y, scale_init=nn.initializers.zeros) 62 | 63 | if residual.shape != y.shape: 64 | residual = conv(residual, filters * 4, (1, 1), strides, name='conv_proj') 65 | residual = norm(residual, name='norm_proj') 66 | 67 | return act(residual + y) 68 | 69 | 70 | class ResNet(nn.Module): 71 | """ResNetV1.""" 72 | 73 | def apply(self, x, num_classes, *, 74 | stage_sizes, 75 | block_cls, 76 | num_filters=64, 77 | dtype=jnp.float32, 78 | act=nn.relu, 79 | train=True): 80 | conv = nn.Conv.partial(bias=False, dtype=dtype) 81 | norm = nn.BatchNorm.partial( 82 | use_running_average=not train, 83 | momentum=0.9, epsilon=1e-5, 84 | dtype=dtype) 85 | 86 | x = conv(x, num_filters, (7, 7), (2, 2), 87 | padding=[(3, 3), (3, 3)], 88 | name='conv_init') 89 | x = norm(x, name='bn_init') 90 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') 91 | for i, block_size in enumerate(stage_sizes): 92 | for j in range(block_size): 93 | strides = (2, 2) if i > 0 and j == 0 else (1, 1) 94 | x = block_cls(x, num_filters * 2 ** i, 95 | strides=strides, 96 | conv=conv, 97 | norm=norm, 98 | act=act) 99 | x = jnp.mean(x, axis=(1, 2)) 100 | x = nn.Dense(x, num_classes, dtype=dtype) 101 | return x 102 | 103 | 104 | ResNet18 = ResNet.partial(stage_sizes=[2, 2, 2, 2], 105 | block_cls=ResNetBlock) 106 | ResNet34 = ResNet.partial(stage_sizes=[3, 4, 6, 3], 107 | block_cls=ResNetBlock) 108 | ResNet50 = ResNet.partial(stage_sizes=[3, 4, 6, 3], 109 | block_cls=BottleneckResNetBlock) 110 | ResNet101 = ResNet.partial(stage_sizes=[3, 4, 23, 3], 111 | block_cls=BottleneckResNetBlock) 112 | ResNet152 = ResNet.partial(stage_sizes=[3, 8, 36, 3], 113 | block_cls=BottleneckResNetBlock) 114 | ResNet200 = ResNet.partial(stage_sizes=[3, 24, 36, 3], 115 | block_cls=BottleneckResNetBlock) 116 | -------------------------------------------------------------------------------- /sam_jax/imagenet_models/resnet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for sam.sam_jax.imagenet_models.resnet. 16 | 17 | Forked from 18 | https://github.com/google/flax/blob/master/examples/imagenet/resnet_v1_test.py 19 | """ 20 | 21 | from absl.testing import absltest 22 | import jax 23 | import jax.numpy as jnp 24 | from sam.sam_jax.imagenet_models import resnet 25 | 26 | 27 | class ResNetTest(absltest.TestCase): 28 | """Test cases for ResNet V1 model.""" 29 | 30 | def test_resnet_v1_module(self): 31 | """Tests ResNet V1 model definition.""" 32 | rng = jax.random.PRNGKey(0) 33 | model_def = resnet.ResNet50.partial(num_classes=10, dtype=jnp.float32) 34 | output, init_params = model_def.init_by_shape( 35 | rng, [((8, 224, 224, 3), jnp.float32)]) 36 | 37 | self.assertEqual((8, 10), output.shape) 38 | self.assertLen(init_params, 19) 39 | 40 | 41 | if __name__ == '__main__': 42 | absltest.main() 43 | -------------------------------------------------------------------------------- /sam_jax/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /sam_jax/models/load_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Build FLAX models for image classification.""" 16 | 17 | from typing import Optional, Tuple 18 | import flax 19 | import jax 20 | from jax import numpy as jnp 21 | from jax import random 22 | 23 | from sam.sam_jax.models import pyramidnet 24 | from sam.sam_jax.models import wide_resnet 25 | from sam.sam_jax.models import wide_resnet_shakeshake 26 | 27 | 28 | _AVAILABLE_MODEL_NAMES = [ 29 | 'WideResnet28x10', 30 | 'WideResnet28x6_ShakeShake', 31 | 'Pyramid_ShakeDrop', 32 | 'WideResnet_mini', # For testing/debugging purposes. 
33 |     'WideResnet_ShakeShake_mini',  # For testing/debugging purposes.
34 |     'Pyramid_ShakeDrop_mini',  # For testing/debugging purposes.
35 | ]
36 | 
37 | 
38 | def create_image_model(
39 |     prng_key: jnp.ndarray, batch_size: int, image_size: int,
40 |     module: flax.nn.Module,
41 |     num_channels: int = 3) -> Tuple[flax.nn.Model, flax.nn.Collection]:
42 |   """Instantiates a FLAX model and its state.
43 | 
44 |   Args:
45 |     prng_key: PRNG key to use to sample the initial weights.
46 |     batch_size: Batch size that the model should expect.
47 |     image_size: Dimension of the image (assumed to be squared).
48 |     module: FLAX module describing the model to instantiate.
49 |     num_channels: Number of channels for the images.
50 | 
51 |   Returns:
52 |     A FLAX model and its state.
53 |   """
54 |   input_shape = (batch_size, image_size, image_size, num_channels)
55 |   with flax.nn.stateful() as init_state:
56 |     with flax.nn.stochastic(jax.random.PRNGKey(0)):
57 |       _, initial_params = module.init_by_shape(
58 |           prng_key, [(input_shape, jnp.float32)])
59 |       model = flax.nn.Model(module, initial_params)
60 |   return model, init_state
61 | 
62 | 
63 | def get_model(
64 |     model_name: str,
65 |     batch_size: int,
66 |     image_size: int,
67 |     num_classes: int,
68 |     num_channels: int = 3,
69 |     prng_key: Optional[jnp.ndarray] = None,
70 | ) -> Tuple[flax.nn.Model, flax.nn.Collection]:
71 |   """Returns an initialized model of the chosen architecture.
72 | 
73 |   Args:
74 |     model_name: Name of the architecture to use. Should be one of
75 |       _AVAILABLE_MODEL_NAMES.
76 |     batch_size: The batch size that the model should expect.
77 |     image_size: Dimension of the image (assumed to be squared).
78 |     num_classes: Dimension of the output layer.
79 |     num_channels: Number of channels for the images.
80 |     prng_key: PRNG key to use to sample the weights.
81 | 
82 |   Returns:
83 |     The initialized model and its state.
84 | 
85 |   Raises:
86 |     ValueError: If the name of the architecture is not recognized.
87 |   """
88 |   if model_name == 'WideResnet28x10':
89 |     module = wide_resnet.WideResnet.partial(
90 |         blocks_per_group=4,
91 |         channel_multiplier=10,
92 |         num_outputs=num_classes)
93 |   elif model_name == 'WideResnet28x6_ShakeShake':
94 |     module = wide_resnet_shakeshake.WideResnetShakeShake.partial(
95 |         blocks_per_group=4,
96 |         channel_multiplier=6,
97 |         num_outputs=num_classes)
98 |   elif model_name == 'Pyramid_ShakeDrop':
99 |     module = pyramidnet.PyramidNetShakeDrop.partial(num_outputs=num_classes)
100 |   elif model_name == 'WideResnet_mini':  # For testing.
101 |     module = wide_resnet.WideResnet.partial(
102 |         blocks_per_group=2,
103 |         channel_multiplier=1,
104 |         num_outputs=num_classes)
105 |   elif model_name == 'WideResnet_ShakeShake_mini':  # For testing.
106 |     module = wide_resnet_shakeshake.WideResnetShakeShake.partial(
107 |         blocks_per_group=2,
108 |         channel_multiplier=1,
109 |         num_outputs=num_classes)
110 |   elif model_name == 'Pyramid_ShakeDrop_mini':  # For testing.
111 |     module = pyramidnet.PyramidNetShakeDrop.partial(num_outputs=num_classes,
112 |                                                     pyramid_depth=11)
113 |   else:
114 |     raise ValueError('Unrecognized model name.')
115 |   if not prng_key:
116 |     prng_key = random.PRNGKey(0)
117 | 
118 |   model, init_state = create_image_model(prng_key, batch_size, image_size,
119 |                                          module, num_channels)
120 |   return model, init_state
121 | 
-------------------------------------------------------------------------------- /sam_jax/models/load_model_test.py: --------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for sam.sam_jax.models.load_model.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import flax 20 | import jax 21 | import numpy as np 22 | from sam.sam_jax.models import load_model 23 | 24 | 25 | class LoadModelTest(parameterized.TestCase): 26 | 27 | # Parametrized because other models will be added in following CLs. 28 | @parameterized.named_parameters( 29 | ('WideResnet_mini', 'WideResnet_mini'), 30 | ('WideResnet_ShakeShake_mini', 'WideResnet_ShakeShake_mini'), 31 | ('Pyramid_ShakeDrop_mini', 'Pyramid_ShakeDrop_mini')) 32 | def test_CreateModel(self, model_name: str): 33 | model, state = load_model.get_model(model_name, 1, 32, 10) 34 | self.assertIsInstance(model, flax.nn.Model) 35 | self.assertIsInstance(state, flax.nn.Collection) 36 | fake_input = np.zeros([1, 32, 32, 3]) 37 | with flax.nn.stateful(state, mutable=False): 38 | logits = model(fake_input, train=False) 39 | self.assertEqual(logits.shape, (1, 10)) 40 | 41 | @parameterized.named_parameters( 42 | ('WideResnet28x10', 'WideResnet28x10'), 43 | ('WideResnet28x6_ShakeShake', 'WideResnet28x6_ShakeShake'), 44 | ('Pyramid_ShakeDrop', 'Pyramid_ShakeDrop')) 45 | def test_ParameterCount(self, model_name: str): 46 | # Parameter count from the autoaugment paper models, 100 classes: 47 | reference_parameter_count = { 48 | 'WideResnet28x10': 36278324, 49 | 'WideResnet28x6_ShakeShake': 26227572, 50 | 'Pyramid_ShakeDrop': 26288692, 51 | } 52 | model, _ = load_model.get_model(model_name, 1, 32, 100) 53 | parameter_count = sum(np.prod(e.shape) for e in jax.tree_leaves(model)) 54 | self.assertEqual(parameter_count, reference_parameter_count[model_name]) 55 | 56 | 57 | if __name__ == '__main__': 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /sam_jax/models/pyramidnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """PyramidNet model with ShakeDrop regularization. 
16 | 
17 | Reference:
18 | 
19 | ShakeDrop Regularization for Deep Residual Learning
20 | Yoshihiro Yamada, Masakazu Iwamura, Takuya Akiba, Koichi Kise
21 | https://arxiv.org/abs/1802.02375
22 | 
23 | Initially forked from
24 | github.com/google/flax/blob/master/examples/cifar10/models/pyramidnet.py
25 | 
26 | This implementation mimics the one from
27 | https://github.com/tensorflow/models/blob/master/research/autoaugment/shake_drop.py
28 | that is widely used as a benchmark.
29 | 
30 | We use kaiming normal initialization for convolutional kernels (mode = fan_out,
31 | gain = 2.0). The final dense layer uses a uniform distribution U[-scale, scale]
32 | where scale = 1 / sqrt(num_classes) as per the autoaugment implementation.
33 | 
34 | It is worth noting that this model is slightly different from the one presented
35 | in the Deep Pyramidal Residual Networks paper
36 | (https://arxiv.org/pdf/1610.02915.pdf), as we round instead of truncating when
37 | computing the number of channels in each block. This results in a model with
38 | roughly 0.2M additional parameters. Rounding is, however, the method used in
39 | follow-up work (https://arxiv.org/abs/1905.00397,
40 | https://arxiv.org/abs/2002.12047), so we keep it for consistency.
41 | """
42 | 
43 | from typing import Tuple
44 | 
45 | from flax import nn
46 | import jax.numpy as jnp
47 | 
48 | from sam.sam_jax.models import utils
49 | 
50 | 
51 | def _shortcut(x: jnp.ndarray, chn_out: int, strides: Tuple[int, int]
52 |               ) -> jnp.ndarray:
53 |   """Pyramid Net shortcut.
54 | 
55 |   Uses average pooling to downsample.
56 |   Uses zero-padding to increase channels.
57 | 
58 |   Args:
59 |     x: Input. Should have shape [batch_size, dim, dim, features]
60 |       where dim is the resolution (width and height if the input is an image).
61 |     chn_out: Expected output channels.
62 |     strides: Output stride.
63 | 
64 |   Returns:
65 |     Shortcut value for Pyramid Net. Shape will be
66 |     [batch_size, dim, dim, chn_out] if strides = (1, 1) (no downsampling) or
67 |     [batch_size, dim/2, dim/2, chn_out] if strides = (2, 2) (downsampling).
68 |   """
69 |   chn_in = x.shape[3]
70 |   if strides != (1, 1):
71 |     x = nn.avg_pool(x, strides, strides)
72 |   if chn_out != chn_in:
73 |     diff = chn_out - chn_in
74 |     x = jnp.pad(x, [[0, 0], [0, 0], [0, 0], [0, diff]])
75 |   return x
76 | 
77 | 
78 | class BottleneckShakeDrop(nn.Module):
79 |   """PyramidNet with Shake-Drop Bottleneck."""
80 | 
81 |   def apply(self,
82 |             x: jnp.ndarray,
83 |             channels: int,
84 |             strides: Tuple[int, int],
85 |             prob: float,
86 |             alpha_min: float,
87 |             alpha_max: float,
88 |             beta_min: float,
89 |             beta_max: float,
90 |             train: bool = True,
91 |             true_gradient: bool = False) -> jnp.ndarray:
92 |     """Implements the forward pass in the module.
93 | 
94 |     Args:
95 |       x: Input to the module. Should have shape [batch_size, dim, dim, features]
96 |         where dim is the resolution (width and height if the input is an image).
97 |       channels: How many channels to use in the convolutional layers.
98 |       strides: Strides for the pooling.
99 |       prob: Probability of keeping the block (b_l = 1 in the paper; see
100 |         paper for details).
101 |       alpha_min: See paper.
102 |       alpha_max: See paper.
103 |       beta_min: See paper.
104 |       beta_max: See paper.
105 |       train: If False, will use the moving average for batch norm statistics.
106 |         Else, will use statistics computed on the batch.
107 |       true_gradient: If true, the same mixing parameter will be used for the
108 |         forward and backward pass (see paper for more details).
109 | 
110 |     Returns:
111 |       The output of the bottleneck block.
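    Note (editor): at module-initialization time the deterministic
    shake_drop_eval branch below is taken (see the is_initializing() check),
    so parameter shapes never depend on the sampled mask.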
111 | """ 112 | y = utils.activation(x, apply_relu=False, train=train, name='bn_1_pre') 113 | y = nn.Conv( 114 | y, 115 | channels, (1, 1), 116 | padding='SAME', 117 | bias=False, 118 | kernel_init=utils.conv_kernel_init_fn, 119 | name='1x1_conv_contract') 120 | y = utils.activation(y, train=train, name='bn_1_post') 121 | y = nn.Conv( 122 | y, 123 | channels, (3, 3), 124 | strides, 125 | padding='SAME', 126 | bias=False, 127 | kernel_init=utils.conv_kernel_init_fn, 128 | name='3x3') 129 | y = utils.activation(y, train=train, name='bn_2') 130 | y = nn.Conv( 131 | y, 132 | channels * 4, (1, 1), 133 | padding='SAME', 134 | bias=False, 135 | kernel_init=utils.conv_kernel_init_fn, 136 | name='1x1_conv_expand') 137 | y = utils.activation(y, apply_relu=False, train=train, name='bn_3') 138 | 139 | if train and not self.is_initializing(): 140 | y = utils.shake_drop_train(y, prob, alpha_min, alpha_max, 141 | beta_min, beta_max, 142 | true_gradient=true_gradient) 143 | else: 144 | y = utils.shake_drop_eval(y, prob, alpha_min, alpha_max) 145 | 146 | x = _shortcut(x, channels * 4, strides) 147 | return x + y 148 | 149 | 150 | def _calc_shakedrop_mask_prob(curr_layer: int, 151 | total_layers: int, 152 | mask_prob: float) -> float: 153 | """Calculates drop prob depending on the current layer.""" 154 | return 1 - (float(curr_layer) / total_layers) * mask_prob 155 | 156 | 157 | class PyramidNetShakeDrop(nn.Module): 158 | """PyramidNet with Shake-Drop.""" 159 | 160 | def apply(self, 161 | x: jnp.ndarray, 162 | num_outputs: int, 163 | pyramid_alpha: int = 200, 164 | pyramid_depth: int = 272, 165 | train: bool = True, 166 | true_gradient: bool = False) -> jnp.ndarray: 167 | """Implements the forward pass in the module. 168 | 169 | Args: 170 | x: Input to the module. Should have shape [batch_size, dim, dim, 3] 171 | where dim is the resolution of the image. 172 | num_outputs: Dimension of the output of the model (ie number of classes 173 | for a classification problem). 174 | pyramid_alpha: See paper. 175 | pyramid_depth: See paper. 176 | train: If False, will use the moving average for batch norm statistics. 177 | Else, will use statistics computed on the batch. 178 | true_gradient: If true, the same mixing parameter will be used for the 179 | forward and backward pass (see paper for more details). 180 | 181 | Returns: 182 | The output of the PyramidNet model, a tensor of shape 183 | [batch_size, num_classes]. 
184 | """ 185 | assert (pyramid_depth - 2) % 9 == 0 186 | 187 | # Shake-drop hyper-params 188 | mask_prob = 0.5 189 | alpha_min, alpha_max = (-1.0, 1.0) 190 | beta_min, beta_max = (0.0, 1.0) 191 | 192 | # Bottleneck network size 193 | blocks_per_group = (pyramid_depth - 2) // 9 194 | # See Eqn 2 in https://arxiv.org/abs/1610.02915 195 | num_channels = 16 196 | # N in https://arxiv.org/abs/1610.02915 197 | total_blocks = blocks_per_group * 3 198 | delta_channels = pyramid_alpha / total_blocks 199 | 200 | x = nn.Conv( 201 | x, 202 | 16, (3, 3), 203 | padding='SAME', 204 | name='init_conv', 205 | bias=False, 206 | kernel_init=utils.conv_kernel_init_fn) 207 | x = utils.activation(x, apply_relu=False, train=train, name='init_bn') 208 | 209 | layer_num = 1 210 | 211 | for block_i in range(blocks_per_group): 212 | num_channels += delta_channels 213 | layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks, 214 | mask_prob) 215 | x = BottleneckShakeDrop( 216 | x, 217 | int(round(num_channels)), (1, 1), 218 | layer_mask_prob, 219 | alpha_min, 220 | alpha_max, 221 | beta_min, 222 | beta_max, 223 | train=train, 224 | true_gradient=true_gradient) 225 | layer_num += 1 226 | 227 | for block_i in range(blocks_per_group): 228 | num_channels += delta_channels 229 | layer_mask_prob = _calc_shakedrop_mask_prob( 230 | layer_num, total_blocks, mask_prob) 231 | x = BottleneckShakeDrop(x, int(round(num_channels)), 232 | ((2, 2) if block_i == 0 else (1, 1)), 233 | layer_mask_prob, 234 | alpha_min, alpha_max, beta_min, beta_max, 235 | train=train, 236 | true_gradient=true_gradient) 237 | layer_num += 1 238 | 239 | for block_i in range(blocks_per_group): 240 | num_channels += delta_channels 241 | layer_mask_prob = _calc_shakedrop_mask_prob( 242 | layer_num, total_blocks, mask_prob) 243 | x = BottleneckShakeDrop(x, int(round(num_channels)), 244 | ((2, 2) if block_i == 0 else (1, 1)), 245 | layer_mask_prob, 246 | alpha_min, alpha_max, beta_min, beta_max, 247 | train=train, 248 | true_gradient=true_gradient) 249 | layer_num += 1 250 | 251 | assert layer_num - 1 == total_blocks 252 | x = utils.activation(x, train=train, name='final_bn') 253 | x = nn.avg_pool(x, (8, 8)) 254 | x = x.reshape((x.shape[0], -1)) 255 | x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn) 256 | return x 257 | -------------------------------------------------------------------------------- /sam_jax/models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Shake-shake and shake-drop functions. 
16 | 
17 | Forked from: 
18 | https://github.com/google-research/google-research/blob/master/flax_models/cifar/models/utils.py 
19 | """ 
20 | 
21 | from typing import Optional, Tuple 
22 | import flax 
23 | from flax import nn 
24 | import jax 
25 | import jax.numpy as jnp 
26 | 
27 | 
28 | _BATCHNORM_MOMENTUM = 0.9 
29 | _BATCHNORM_EPSILON = 1e-5 
30 | 
31 | 
32 | def activation(x: jnp.ndarray, 
33 |                train: bool, 
34 |                apply_relu: bool = True, 
35 |                name: str = '') -> jnp.ndarray: 
36 |   """Applies BatchNorm and then (optionally) ReLU. 
37 | 
38 |   Args: 
39 |     x: Tensor on which the activation should be applied. 
40 |     train: If False, will use the moving average for batch norm statistics. 
41 |       Else, will use statistics computed on the batch. 
42 |     apply_relu: Whether or not ReLU should be applied after batch normalization. 
43 |     name: How to name the BatchNorm layer. 
44 | 
45 |   Returns: 
46 |     The input tensor where BatchNorm and (optionally) ReLU were applied. 
47 |   """ 
48 |   batch_norm = nn.BatchNorm.partial( 
49 |       use_running_average=not train, 
50 |       momentum=_BATCHNORM_MOMENTUM, 
51 |       epsilon=_BATCHNORM_EPSILON) 
52 |   x = batch_norm(x, name=name) 
53 |   if apply_relu: 
54 |     x = jax.nn.relu(x) 
55 |   return x 
56 | 
57 | 
58 | # Kaiming initialization with fan out mode. Should be used to initialize 
59 | # convolutional kernels. 
60 | conv_kernel_init_fn = jax.nn.initializers.variance_scaling( 
61 |     2.0, 'fan_out', 'normal') 
62 | 
63 | 
64 | def dense_layer_init_fn(key: jnp.ndarray, 
65 |                         shape: Tuple[int, int], 
66 |                         dtype: jnp.dtype = jnp.float32) -> jnp.ndarray: 
67 |   """Initializer for the final dense layer. 
68 | 
69 |   Args: 
70 |     key: PRNG key to use to sample the weights. 
71 |     shape: Shape of the tensor to initialize. 
72 |     dtype: Data type of the tensor to initialize. 
73 | 
74 |   Returns: 
75 |     The initialized tensor. 
76 |   """ 
77 |   num_units_out = shape[1] 
78 |   unif_init_range = 1.0 / (num_units_out)**(0.5) 
79 |   return jax.random.uniform(key, shape, dtype, -1) * unif_init_range 
80 | 
81 | 
82 | def shake_shake_train(xa: jnp.ndarray, 
83 |                       xb: jnp.ndarray, 
84 |                       rng: Optional[jnp.ndarray] = None, 
85 |                       true_gradient: bool = False) -> jnp.ndarray: 
86 |   """Shake-shake regularization in training mode. 
87 | 
88 |   Shake-shake regularization interpolates between inputs A and B 
89 |   with *different* random uniform (per-sample) interpolation factors 
90 |   for the forward and backward/gradient passes. 
91 | 
92 |   Args: 
93 |     xa: Input, branch A. 
94 |     xb: Input, branch B. 
95 |     rng: PRNG key. 
96 |     true_gradient: If true, the same mixing parameter will be used for the 
97 |       forward and backward pass (see paper for more details). 
98 | 
99 |   Returns: 
100 |     Mix of input branches. 
101 |   """ 
102 |   if rng is None: 
103 |     rng = flax.nn.make_rng() 
104 |   gate_forward_key, gate_backward_key = jax.random.split(rng, num=2) 
105 |   gate_shape = (len(xa), 1, 1, 1) 
106 | 
107 |   # Draw different interpolation factors (gate) for forward and backward pass. 
108 |   gate_forward = jax.random.uniform( 
109 |       gate_forward_key, gate_shape, dtype=jnp.float32, minval=0.0, maxval=1.0) 
110 |   x_forward = xa * gate_forward + xb * (1.0 - gate_forward) 
111 |   if true_gradient: 
112 |     return x_forward 
113 |   gate_backward = jax.random.uniform( 
114 |       gate_backward_key, gate_shape, dtype=jnp.float32, minval=0.0, maxval=1.0) 
115 |   # Compute interpolated x for forward and backward. 
116 |   x_backward = xa * gate_backward + xb * (1.0 - gate_backward) 
117 |   # Combine using stop_gradient. 
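  # This straight-through combination evaluates to x_forward in the forward
  # pass (the stop_gradient term contributes x_forward - x_backward
  # numerically), while gradients flow only through x_backward, because
  # stop_gradient blocks differentiation of its argument.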
118 | return x_backward + jax.lax.stop_gradient(x_forward - x_backward) 119 | 120 | 121 | def shake_shake_eval(xa: jnp.ndarray, xb: jnp.ndarray) -> jnp.ndarray: 122 | """Shake-shake regularization in testing mode. 123 | 124 | Args: 125 | xa: Input, branch A. 126 | xb: Input, branch B. 127 | 128 | Returns: 129 | Mix of input branches. 130 | """ 131 | # Blend between inputs A and B 50%-50%. 132 | return (xa + xb) * 0.5 133 | 134 | 135 | def shake_drop_train(x: jnp.ndarray, 136 | mask_prob: float, 137 | alpha_min: float, 138 | alpha_max: float, 139 | beta_min: float, 140 | beta_max: float, 141 | rng: Optional[jnp.ndarray] = None, 142 | true_gradient: bool = False) -> jnp.ndarray: 143 | """ShakeDrop training pass. 144 | 145 | See https://arxiv.org/abs/1802.02375 146 | 147 | Args: 148 | x: Input to apply ShakeDrop to. 149 | mask_prob: Mask probability. 150 | alpha_min: Alpha range lower. 151 | alpha_max: Alpha range upper. 152 | beta_min: Beta range lower. 153 | beta_max: Beta range upper. 154 | rng: PRNG key (if `None`, uses `flax.nn.make_rng`). 155 | true_gradient: If true, the same mixing parameter will be used for the 156 | forward and backward pass (see paper for more details). 157 | 158 | Returns: 159 | The regularized tensor. 160 | """ 161 | if rng is None: 162 | rng = flax.nn.make_rng() 163 | bern_key, alpha_key, beta_key = jax.random.split(rng, num=3) 164 | rnd_shape = (len(x), 1, 1, 1) 165 | # Bernoulli variable b_l in Eqn 6, https://arxiv.org/abs/1802.02375. 166 | mask = jax.random.bernoulli(bern_key, mask_prob, rnd_shape) 167 | mask = mask.astype(jnp.float32) 168 | 169 | alpha_values = jax.random.uniform( 170 | alpha_key, 171 | rnd_shape, 172 | dtype=jnp.float32, 173 | minval=alpha_min, 174 | maxval=alpha_max) 175 | beta_values = jax.random.uniform( 176 | beta_key, rnd_shape, dtype=jnp.float32, minval=beta_min, maxval=beta_max) 177 | # See Eqn 6 in https://arxiv.org/abs/1802.02375. 178 | rand_forward = mask + alpha_values - mask * alpha_values 179 | if true_gradient: 180 | return x * rand_forward 181 | rand_backward = mask + beta_values - mask * beta_values 182 | return x * rand_backward + jax.lax.stop_gradient( 183 | x * rand_forward - x * rand_backward) 184 | 185 | 186 | def shake_drop_eval(x: jnp.ndarray, 187 | mask_prob: float, 188 | alpha_min: float, 189 | alpha_max: float) -> jnp.ndarray: 190 | """ShakeDrop eval pass. 191 | 192 | See https://arxiv.org/abs/1802.02375 193 | 194 | Args: 195 | x: Input to apply ShakeDrop to. 196 | mask_prob: Mask probability. 197 | alpha_min: Alpha range lower. 198 | alpha_max: Alpha range upper. 199 | 200 | Returns: 201 | The regularized tensor. 202 | """ 203 | expected_alpha = (alpha_max + alpha_min) / 2 204 | # See Eqn 6 in https://arxiv.org/abs/1802.02375. 205 | return (mask_prob + expected_alpha - mask_prob * expected_alpha) * x 206 | -------------------------------------------------------------------------------- /sam_jax/models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
15 | """Wide Resnet Model. 
16 | 
17 | Reference: 
18 | 
19 | Wide Residual Networks, Sergey Zagoruyko, Nikos Komodakis 
20 | https://arxiv.org/abs/1605.07146 
21 | 
22 | Forked from 
23 | https://github.com/google-research/google-research/blob/master/flax_models/cifar/models/wide_resnet.py 
24 | 
25 | This implementation mimics the one from 
26 | github.com/tensorflow/models/blob/master/research/autoaugment/wrn.py 
27 | that is widely used as a benchmark. 
28 | 
29 | It uses identity + zero padding skip connections, with kaiming normal 
30 | initialization for convolutional kernels (mode = fan_out, gain=2.0). 
31 | The final dense layer uses a uniform distribution U[-scale, scale] where 
32 | scale = 1 / sqrt(num_classes) as per the autoaugment implementation. 
33 | 
34 | Using the default initialization instead gives error rates approximately 0.5% 
35 | greater on cifar100, most likely because the parameters used in the literature 
36 | were finetuned for this particular initialization. 
37 | 
38 | Finally, the autoaugment implementation adds more residual connections between 
39 | the groups (instead of just between the blocks as per the original paper and 
40 | most implementations). It is possible to safely remove those connections 
41 | without degrading performance, which we do by default to match the original 
42 | wideresnet paper. Setting `use_additional_skip_connections` to True will add 
43 | them back, exactly reproducing the model used in autoaugment. 
44 | """ 
45 | 
46 | from typing import Tuple 
47 | 
48 | from absl import flags 
49 | from flax import nn 
50 | from jax import numpy as jnp 
51 | 
52 | from sam.sam_jax.models import utils 
53 | 
54 | 
55 | FLAGS = flags.FLAGS 
56 | 
57 | 
58 | flags.DEFINE_bool('use_additional_skip_connections', False, 
59 |                   'Set to True to use additional skip connections between the ' 
60 |                   'resnet groups. This reproduces the autoaugment ' 
61 |                   'implementation, but these connections are not present in ' 
62 |                   'most implementations. Removing them does not impact ' 
63 |                   'the performance of the model.') 
64 | 
65 | 
66 | def _output_add(block_x: jnp.ndarray, orig_x: jnp.ndarray) -> jnp.ndarray: 
67 |   """Add two tensors, padding them with zeros or pooling them if necessary. 
68 | 
69 |   Args: 
70 |     block_x: Output of a resnet block. 
71 |     orig_x: Residual branch to add to the output of the resnet block. 
72 | 
73 |   Returns: 
74 |     The sum of block_x and orig_x. If necessary, orig_x will be average pooled 
75 |     or zero padded so that its shape matches block_x. 
76 |   """ 
77 |   stride = orig_x.shape[-2] // block_x.shape[-2] 
78 |   strides = (stride, stride) 
79 |   if block_x.shape[-1] != orig_x.shape[-1]: 
80 |     orig_x = nn.avg_pool(orig_x, strides, strides) 
81 |     channels_to_add = block_x.shape[-1] - orig_x.shape[-1] 
82 |     orig_x = jnp.pad(orig_x, [(0, 0), (0, 0), (0, 0), (0, channels_to_add)]) 
83 |   return block_x + orig_x 
84 | 
85 | 
86 | class WideResnetBlock(nn.Module): 
87 |   """Defines a single WideResnetBlock.""" 
88 | 
89 |   def apply(self, 
90 |             x: jnp.ndarray, 
91 |             channels: int, 
92 |             strides: Tuple[int, int] = (1, 1), 
93 |             activate_before_residual: bool = False, 
94 |             train: bool = True) -> jnp.ndarray: 
95 |     """Implements the forward pass in the module. 
96 | 
97 |     Args: 
98 |       x: Input to the module. Should have shape [batch_size, dim, dim, features] 
99 |         where dim is the resolution (width and height if the input is an image). 
100 | channels: How many channels to use in the convolutional layers. 101 | strides: Strides for the pooling. 102 | activate_before_residual: True if the batch norm and relu should be 103 | applied before the residual branches out (should be True only for the 104 | first block of the model). 105 | train: If False, will use the moving average for batch norm statistics. 106 | Else, will use statistics computed on the batch. 107 | 108 | Returns: 109 | The output of the resnet block. 110 | """ 111 | if activate_before_residual: 112 | x = utils.activation(x, train, name='init_bn') 113 | orig_x = x 114 | else: 115 | orig_x = x 116 | 117 | block_x = x 118 | if not activate_before_residual: 119 | block_x = utils.activation(block_x, train, name='init_bn') 120 | 121 | block_x = nn.Conv( 122 | block_x, 123 | channels, (3, 3), 124 | strides, 125 | padding='SAME', 126 | bias=False, 127 | kernel_init=utils.conv_kernel_init_fn, 128 | name='conv1') 129 | block_x = utils.activation(block_x, train=train, name='bn_2') 130 | block_x = nn.Conv( 131 | block_x, 132 | channels, (3, 3), 133 | padding='SAME', 134 | bias=False, 135 | kernel_init=utils.conv_kernel_init_fn, 136 | name='conv2') 137 | 138 | return _output_add(block_x, orig_x) 139 | 140 | 141 | class WideResnetGroup(nn.Module): 142 | """Defines a WideResnetGroup.""" 143 | 144 | def apply(self, 145 | x: jnp.ndarray, 146 | blocks_per_group: int, 147 | channels: int, 148 | strides: Tuple[int, int] = (1, 1), 149 | activate_before_residual: bool = False, 150 | train: bool = True) -> jnp.ndarray: 151 | """Implements the forward pass in the module. 152 | 153 | Args: 154 | x: Input to the module. Should have shape [batch_size, dim, dim, features] 155 | where dim is the resolution (width and height if the input is an image). 156 | blocks_per_group: How many resnet blocks to add to each group (should be 157 | 4 blocks for a WRN28, and 6 for a WRN40). 158 | channels: How many channels to use in the convolutional layers. 159 | strides: Strides for the pooling. 160 | activate_before_residual: True if the batch norm and relu should be 161 | applied before the residual branches out (should be True only for the 162 | first group of the model). 163 | train: If False, will use the moving average for batch norm statistics. 164 | Else, will use statistics computed on the batch. 165 | 166 | Returns: 167 | The output of the resnet block. 168 | """ 169 | orig_x = x 170 | for i in range(blocks_per_group): 171 | x = WideResnetBlock( 172 | x, 173 | channels, 174 | strides if i == 0 else (1, 1), 175 | activate_before_residual=activate_before_residual and not i, 176 | train=train) 177 | if FLAGS.use_additional_skip_connections: 178 | x = _output_add(x, orig_x) 179 | return x 180 | 181 | 182 | class WideResnet(nn.Module): 183 | """Defines the WideResnet Model.""" 184 | 185 | def apply(self, 186 | x: jnp.ndarray, 187 | blocks_per_group: int, 188 | channel_multiplier: int, 189 | num_outputs: int, 190 | train: bool = True) -> jnp.ndarray: 191 | """Implements a WideResnet module. 192 | 193 | Args: 194 | x: Input to the module. Should have shape [batch_size, dim, dim, 3] 195 | where dim is the resolution of the image. 196 | blocks_per_group: How many resnet blocks to add to each group (should be 197 | 4 blocks for a WRN28, and 6 for a WRN40). 198 | channel_multiplier: The multiplier to apply to the number of filters in 199 | the model (1 is classical resnet, 10 for WRN28-10, etc...). 
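        For example, a WRN28-10 uses blocks_per_group=4 and
        channel_multiplier=10.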
200 |       num_outputs: Dimension of the output of the model (ie number of classes 
201 |         for a classification problem). 
202 |       train: If False, will use the moving average for batch norm statistics. 
203 | 
204 |     Returns: 
205 |       The output of the WideResnet, a tensor of shape [batch_size, num_classes]. 
206 |     """ 
207 |     first_x = x 
208 |     x = nn.Conv( 
209 |         x, 
210 |         16, (3, 3), 
211 |         padding='SAME', 
212 |         name='init_conv', 
213 |         kernel_init=utils.conv_kernel_init_fn, 
214 |         bias=False) 
215 |     x = WideResnetGroup( 
216 |         x, 
217 |         blocks_per_group, 
218 |         16 * channel_multiplier, 
219 |         activate_before_residual=True, 
220 |         train=train) 
221 |     x = WideResnetGroup( 
222 |         x, 
223 |         blocks_per_group, 
224 |         32 * channel_multiplier, (2, 2), 
225 |         train=train) 
226 |     x = WideResnetGroup( 
227 |         x, 
228 |         blocks_per_group, 
229 |         64 * channel_multiplier, (2, 2), 
230 |         train=train) 
231 |     if FLAGS.use_additional_skip_connections: 
232 |       x = _output_add(x, first_x) 
233 |     x = utils.activation(x, train=train, name='pre-pool-bn') 
234 |     x = nn.avg_pool(x, x.shape[1:3]) 
235 |     x = x.reshape((x.shape[0], -1)) 
236 |     x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn) 
237 |     return x 
238 | 
--------------------------------------------------------------------------------
/sam_jax/models/wide_resnet_shakeshake.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors. 
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | #     http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License. 
14 | 
15 | """Wide Resnet Model with ShakeShake regularization. 
16 | 
17 | Reference: 
18 | 
19 | Shake-Shake regularization, Xavier Gastaldi 
20 | https://arxiv.org/abs/1705.07485 
21 | 
22 | Initially forked from 
23 | https://github.com/google-research/google-research/blob/master/flax_models/cifar/models/wide_resnet_shakeshake.py 
24 | 
25 | This implementation mimics the one from 
26 | github.com/tensorflow/models/blob/master/research/autoaugment/shake_shake.py 
27 | that is widely used as a benchmark. 
28 | 
29 | It uses kaiming normal initialization for convolutional kernels (mode = fan_out, 
30 | gain=2.0). The final dense layer uses a uniform distribution U[-scale, scale] 
31 | where scale = 1 / sqrt(num_classes) as per the autoaugment implementation. 
32 | 
33 | The residual connections follow the implementation of X. Gastaldi (1x1 pooling 
34 | with a one pixel offset for half the channels, see 2.1.1 Implementation details 
35 | section in the reference above for more details). 
36 | """ 
37 | 
38 | from typing import Tuple 
39 | 
40 | from flax import nn 
41 | import jax 
42 | from jax import numpy as jnp 
43 | 
44 | from sam.sam_jax.models import utils 
45 | 
46 | 
47 | class Shortcut(nn.Module): 
48 |   """Shortcut for residual connections.""" 
49 | 
50 |   def apply(self, 
51 |             x: jnp.ndarray, 
52 |             channels: int, 
53 |             strides: Tuple[int, int] = (1, 1), 
54 |             train: bool = True) -> jnp.ndarray: 
55 |     """Implements the forward pass in the module. 
56 | 
57 |     Args: 
58 |       x: Input to the module. 
Should have shape [batch_size, dim, dim, features] 59 | where dim is the resolution (width and height if the input is an image). 60 | channels: How many channels to use in the convolutional layers. 61 | strides: Strides for the pooling. 62 | train: If False, will use the moving average for batch norm statistics. 63 | 64 | Returns: 65 | The output of the resnet block. Will have shape 66 | [batch_size, dim, dim, channels] if strides = (1, 1) or 67 | [batch_size, dim/2, dim/2, channels] if strides = (2, 2). 68 | """ 69 | 70 | if x.shape[-1] == channels: 71 | return x 72 | 73 | # Skip path 1 74 | h1 = nn.avg_pool(x, (1, 1), strides=strides, padding='VALID') 75 | h1 = nn.Conv( 76 | h1, 77 | channels // 2, (1, 1), 78 | strides=(1, 1), 79 | padding='SAME', 80 | bias=False, 81 | kernel_init=utils.conv_kernel_init_fn, 82 | name='conv_h1') 83 | 84 | # Skip path 2 85 | # The next two lines offset the "image" by one pixel on the right and one 86 | # down (see Shake-Shake regularization, Xavier Gastaldi for details) 87 | pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]] 88 | h2 = jnp.pad(x, pad_arr)[:, 1:, 1:, :] 89 | h2 = nn.avg_pool(h2, (1, 1), strides=strides, padding='VALID') 90 | h2 = nn.Conv( 91 | h2, 92 | channels // 2, (1, 1), 93 | strides=(1, 1), 94 | padding='SAME', 95 | bias=False, 96 | kernel_init=utils.conv_kernel_init_fn, 97 | name='conv_h2') 98 | merged_branches = jnp.concatenate([h1, h2], axis=3) 99 | return utils.activation( 100 | merged_branches, apply_relu=False, train=train, name='bn_residual') 101 | 102 | 103 | class ShakeShakeBlock(nn.Module): 104 | """Wide ResNet block with shake-shake regularization.""" 105 | 106 | def apply(self, 107 | x: jnp.ndarray, 108 | channels: int, 109 | strides: Tuple[int, int] = (1, 1), 110 | train: bool = True, 111 | true_gradient: bool = False) -> jnp.ndarray: 112 | """Implements the forward pass in the module. 113 | 114 | Args: 115 | x: Input to the module. Should have shape [batch_size, dim, dim, features] 116 | where dim is the resolution (width and height if the input is an image). 117 | channels: How many channels to use in the convolutional layers. 118 | strides: Strides for the pooling. 119 | train: If False, will use the moving average for batch norm statistics. 120 | Else, will use statistics computed on the batch. 121 | true_gradient: If true, the same mixing parameter will be used for the 122 | forward and backward pass (see paper for more details). 123 | 124 | Returns: 125 | The output of the resnet block. Will have shape 126 | [batch_size, dim, dim, channels] if strides = (1, 1) or 127 | [batch_size, dim/2, dim/2, channels] if strides = (2, 2). 
128 | """ 129 | a = b = residual = x 130 | 131 | a = jax.nn.relu(a) 132 | a = nn.Conv( 133 | a, 134 | channels, (3, 3), 135 | strides, 136 | padding='SAME', 137 | bias=False, 138 | kernel_init=utils.conv_kernel_init_fn, 139 | name='conv_a_1') 140 | a = utils.activation(a, train=train, name='bn_a_1') 141 | a = nn.Conv( 142 | a, 143 | channels, (3, 3), 144 | padding='SAME', 145 | bias=False, 146 | kernel_init=utils.conv_kernel_init_fn, 147 | name='conv_a_2') 148 | a = utils.activation(a, apply_relu=False, train=train, name='bn_a_2') 149 | 150 | b = jax.nn.relu(b) 151 | b = nn.Conv( 152 | b, 153 | channels, (3, 3), 154 | strides, 155 | padding='SAME', 156 | bias=False, 157 | kernel_init=utils.conv_kernel_init_fn, 158 | name='conv_b_1') 159 | b = utils.activation(b, train=train, name='bn_b_1') 160 | b = nn.Conv( 161 | b, 162 | channels, (3, 3), 163 | padding='SAME', 164 | bias=False, 165 | kernel_init=utils.conv_kernel_init_fn, 166 | name='conv_b_2') 167 | b = utils.activation(b, apply_relu=False, train=train, name='bn_b_2') 168 | 169 | if train and not self.is_initializing(): 170 | ab = utils.shake_shake_train(a, b, true_gradient=true_gradient) 171 | else: 172 | ab = utils.shake_shake_eval(a, b) 173 | 174 | # Apply an up projection in case of channel mismatch. 175 | residual = Shortcut(residual, channels, strides, train) 176 | 177 | return residual + ab 178 | 179 | 180 | class WideResnetShakeShakeGroup(nn.Module): 181 | """Defines a WideResnetGroup.""" 182 | 183 | def apply(self, 184 | x: jnp.ndarray, 185 | blocks_per_group: int, 186 | channels: int, 187 | strides: Tuple[int, int] = (1, 1), 188 | train: bool = True, 189 | true_gradient: bool = False) -> jnp.ndarray: 190 | """Implements the forward pass in the module. 191 | 192 | Args: 193 | x: Input to the module. Should have shape [batch_size, dim, dim, features] 194 | where dim is the resolution (width and height if the input is an image). 195 | blocks_per_group: How many resnet blocks to add to each group (should be 196 | 4 blocks for a WRN28, and 6 for a WRN40). 197 | channels: How many channels to use in the convolutional layers. 198 | strides: Strides for the pooling. 199 | train: If False, will use the moving average for batch norm statistics. 200 | Else, will use statistics computed on the batch. 201 | true_gradient: If true, the same mixing parameter will be used for the 202 | forward and backward pass (see paper for more details). 203 | 204 | Returns: 205 | The output of the resnet block. Will have shape 206 | [batch_size, dim, dim, channels] if strides = (1, 1) or 207 | [batch_size, dim/2, dim/2, channels] if strides = (2, 2). 208 | """ 209 | for i in range(blocks_per_group): 210 | x = ShakeShakeBlock( 211 | x, 212 | channels, 213 | strides if i == 0 else (1, 1), 214 | train=train, 215 | true_gradient=true_gradient) 216 | return x 217 | 218 | 219 | class WideResnetShakeShake(nn.Module): 220 | """Defines the WideResnet Model.""" 221 | 222 | def apply(self, 223 | x: jnp.ndarray, 224 | blocks_per_group: int, 225 | channel_multiplier: int, 226 | num_outputs: int, 227 | train: bool = True, 228 | true_gradient: bool = False) -> jnp.ndarray: 229 | """Implements a WideResnet with ShakeShake regularization module. 230 | 231 | Args: 232 | x: Input to the module. Should have shape [batch_size, dim, dim, 3] 233 | where dim is the resolution of the image. 234 | blocks_per_group: How many resnet blocks to add to each group (should be 235 | 4 blocks for a WRN26 as per standard shake shake implementation). 
236 | channel_multiplier: The multiplier to apply to the number of filters in 237 | the model (1 is classical resnet, 6 for WRN26-2x6, etc...). 238 | num_outputs: Dimension of the output of the model (ie number of classes 239 | for a classification problem). 240 | train: If False, will use the moving average for batch norm statistics. 241 | Else, will use statistics computed on the batch. 242 | true_gradient: If true, the same mixing parameter will be used for the 243 | forward and backward pass (see paper for more details). 244 | 245 | Returns: 246 | The output of the WideResnet with ShakeShake regularization, a tensor of 247 | shape [batch_size, num_classes]. 248 | """ 249 | x = nn.Conv( 250 | x, 251 | 16, (3, 3), 252 | padding='SAME', 253 | kernel_init=utils.conv_kernel_init_fn, 254 | bias=False, 255 | name='init_conv') 256 | x = utils.activation(x, apply_relu=False, train=train, name='init_bn') 257 | x = WideResnetShakeShakeGroup( 258 | x, 259 | blocks_per_group, 260 | 16 * channel_multiplier, 261 | train=train, 262 | true_gradient=true_gradient) 263 | x = WideResnetShakeShakeGroup( 264 | x, 265 | blocks_per_group, 266 | 32 * channel_multiplier, (2, 2), 267 | train=train, 268 | true_gradient=true_gradient) 269 | x = WideResnetShakeShakeGroup( 270 | x, 271 | blocks_per_group, 272 | 64 * channel_multiplier, (2, 2), 273 | train=train, 274 | true_gradient=true_gradient) 275 | x = jax.nn.relu(x) 276 | x = nn.avg_pool(x, x.shape[1:3]) 277 | x = x.reshape((x.shape[0], -1)) 278 | return nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn) 279 | -------------------------------------------------------------------------------- /sam_jax/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Trains a model on cifar10, cifar100, SVHN, F-MNIST or imagenet.""" 
16 | 
17 | import os 
18 | 
19 | from absl import app 
20 | from absl import flags 
21 | from absl import logging 
22 | import jax 
23 | from sam.sam_jax.datasets import dataset_source as dataset_source_lib 
24 | from sam.sam_jax.datasets import dataset_source_imagenet 
25 | from sam.sam_jax.efficientnet import efficientnet 
26 | from sam.sam_jax.imagenet_models import load_model as load_imagenet_model 
27 | from sam.sam_jax.models import load_model 
28 | from sam.sam_jax.training_utils import flax_training 
29 | import tensorflow.compat.v2 as tf 
30 | from tensorflow.io import gfile 
31 | 
32 | 
33 | FLAGS = flags.FLAGS 
34 | 
35 | flags.DEFINE_enum('dataset', 'cifar10', [ 
36 |     'cifar10', 'cifar100', 'fashion_mnist', 'svhn', 'imagenet', 'Birdsnap', 
37 |     'cifar100_brain', 'Stanford_Cars', 'Flowers', 'FGVC_Aircraft', 
38 |     'Oxford_IIIT_Pets', 'Food_101' 
39 | ], 'Name of the dataset.') 
40 | flags.DEFINE_enum('model_name', 'WideResnet28x10', [ 
41 |     'WideResnet28x10', 'WideResnet28x6_ShakeShake', 'Pyramid_ShakeDrop', 
42 |     'Resnet50', 'Resnet101', 'Resnet152' 
43 | ] + list(efficientnet.MODEL_CONFIGS.keys()), 'Name of the model to train.') 
44 | flags.DEFINE_integer('num_epochs', 200, 
45 |                      'How many epochs the model should be trained for.') 
46 | flags.DEFINE_integer( 
47 |     'batch_size', 128, 'Global batch size. If multiple ' 
48 |     'replicas are used, each replica will receive ' 
49 |     'batch_size / num_replicas examples. Batch size should be divisible by ' 
50 |     'the number of available devices.') 
51 | flags.DEFINE_string( 
52 |     'output_dir', '', 'Directory where the checkpoints and the tensorboard ' 
53 |     'records should be saved.') 
54 | flags.DEFINE_enum( 
55 |     'image_level_augmentations', 'basic', ['none', 'basic', 'autoaugment', 
56 |                                            'aa-only'], 
57 |     'Augmentations applied to the images. Should be `none` for ' 
58 |     'no augmentations, `basic` for the standard horizontal ' 
59 |     'flips and random crops, and `autoaugment` for the best ' 
60 |     'AutoAugment policy for cifar10. For SVHN, aa-only should be used for ' 
61 |     'autoaugment without random crops or flips. ' 
62 |     'For Imagenet, setting to autoaugment will use RandAugment. For ' 
63 |     'FromBrainDatasetSource datasets, this flag is ignored.') 
64 | flags.DEFINE_enum( 
65 |     'batch_level_augmentations', 'none', ['none', 'cutout', 'mixup', 'mixcut'], 
66 |     'Augmentations that are applied at the batch level. ' 
67 |     'Not used by Imagenet and FromBrainDatasetSource datasets.') 
68 | 
69 | 
70 | def main(_): 
71 | 
72 |   tf.enable_v2_behavior() 
73 |   # Make sure TF does not allocate GPU memory. 
74 |   tf.config.experimental.set_visible_devices([], 'GPU') 
75 | 
76 |   # Performance gains on TPU by switching to hardware bernoulli. 
77 |   def hardware_bernoulli(rng_key, p=jax.numpy.float32(0.5), shape=None): 
78 |     lax_key = jax.lax.tie_in(rng_key, 0.0) 
79 |     return jax.lax.rng_uniform(lax_key, 1.0, shape) < p 
80 | 
81 |   def set_hardware_bernoulli(): 
82 |     jax.random.bernoulli = hardware_bernoulli 
83 | 
84 |   set_hardware_bernoulli() 
85 | 
86 |   # As we gridsearch the weight decay and the learning rate, we add them to the 
87 |   # output directory path so that each model has its own directory to save the 
88 |   # results in. We also add the `run_seed`, which is "gridsearched" over to 
89 |   # replicate an experiment several times. 
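  # For example, with learning_rate=0.1, weight_decay=0.0005, sam_rho=0.05 and
  # run_seed=0, results are saved under
  # <output_dir>/lr_0.1/wd_0.0005/rho_0.05/seed_0.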
90 |   output_dir_suffix = os.path.join( 
91 |       'lr_' + str(FLAGS.learning_rate), 
92 |       'wd_' + str(FLAGS.weight_decay), 
93 |       'rho_' + str(FLAGS.sam_rho), 
94 |       'seed_' + str(FLAGS.run_seed)) 
95 | 
96 |   output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix) 
97 | 
98 |   if not gfile.exists(output_dir): 
99 |     gfile.makedirs(output_dir) 
100 | 
101 |   num_devices = jax.local_device_count() * jax.host_count() 
102 |   assert FLAGS.batch_size % num_devices == 0 
103 |   local_batch_size = FLAGS.batch_size // num_devices 
104 |   info = 'Total batch size: {} ({} x {} replicas)'.format( 
105 |       FLAGS.batch_size, local_batch_size, num_devices) 
106 |   logging.info(info) 
107 | 
108 |   if FLAGS.dataset == 'cifar10': 
109 |     if FLAGS.from_pretrained_checkpoint: 
110 |       image_size = efficientnet.name_to_image_size(FLAGS.model_name) 
111 |     else: 
112 |       image_size = None 
113 |     dataset_source = dataset_source_lib.Cifar10( 
114 |         FLAGS.batch_size // jax.host_count(), 
115 |         FLAGS.image_level_augmentations, 
116 |         FLAGS.batch_level_augmentations, 
117 |         image_size=image_size) 
118 |   elif FLAGS.dataset == 'cifar100': 
119 |     if FLAGS.from_pretrained_checkpoint: 
120 |       image_size = efficientnet.name_to_image_size(FLAGS.model_name) 
121 |     else: 
122 |       image_size = None 
123 |     dataset_source = dataset_source_lib.Cifar100( 
124 |         FLAGS.batch_size // jax.host_count(), FLAGS.image_level_augmentations, 
125 |         FLAGS.batch_level_augmentations, image_size=image_size) 
126 | 
127 |   elif FLAGS.dataset == 'fashion_mnist': 
128 |     dataset_source = dataset_source_lib.FashionMnist( 
129 |         FLAGS.batch_size, FLAGS.image_level_augmentations, 
130 |         FLAGS.batch_level_augmentations) 
131 |   elif FLAGS.dataset == 'svhn': 
132 |     dataset_source = dataset_source_lib.SVHN( 
133 |         FLAGS.batch_size, FLAGS.image_level_augmentations, 
134 |         FLAGS.batch_level_augmentations) 
135 |   elif FLAGS.dataset == 'imagenet': 
136 |     imagenet_image_size = efficientnet.name_to_image_size(FLAGS.model_name) 
137 |     dataset_source = dataset_source_imagenet.Imagenet( 
138 |         FLAGS.batch_size // jax.host_count(), imagenet_image_size, 
139 |         FLAGS.image_level_augmentations) 
140 |   else: 
141 |     raise ValueError('Dataset not recognized.') 
142 | 
143 |   if 'cifar' in FLAGS.dataset or 'svhn' in FLAGS.dataset: 
144 |     if 'svhn' in FLAGS.dataset or image_size is None:  # image_size is only set above for cifar. 
145 |       image_size = 32 
146 |     num_channels = 3 
147 |     num_classes = 100 if FLAGS.dataset == 'cifar100' else 10 
148 |   elif FLAGS.dataset == 'fashion_mnist': 
149 |     image_size = 28  # For Fashion Mnist 
150 |     num_channels = 1 
151 |     num_classes = 10 
152 |   elif FLAGS.dataset == 'imagenet': 
153 |     image_size = imagenet_image_size 
154 |     num_channels = 3 
155 |     num_classes = 1000 
156 |   else: 
157 |     raise ValueError('Dataset not recognized.') 
158 | 
159 |   try: 
160 |     model, state = load_imagenet_model.get_model(FLAGS.model_name, 
161 |                                                  local_batch_size, image_size, 
162 |                                                  num_classes) 
163 |   except load_imagenet_model.ModelNameError: 
164 |     model, state = load_model.get_model(FLAGS.model_name, 
165 |                                         local_batch_size, image_size, 
166 |                                         num_classes, num_channels) 
167 | 
168 |   # Learning rate will be overwritten by the lr schedule, so we set it to zero. 
169 | optimizer = flax_training.create_optimizer(model, 0.0) 170 | 171 | flax_training.train(optimizer, state, dataset_source, output_dir, 172 | FLAGS.num_epochs) 173 | 174 | 175 | if __name__ == '__main__': 176 | tf.enable_v2_behavior() 177 | app.run(main) 178 | -------------------------------------------------------------------------------- /sam_jax/training_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /sam_jax/training_utils/flax_training_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for sam.sam_jax.training_utils.flax_training.""" 16 | 17 | import os 18 | from typing import Tuple 19 | 20 | from absl import flags 21 | from absl.testing import absltest 22 | from absl.testing import flagsaver 23 | import flax 24 | import jax 25 | from jax.lib import xla_bridge 26 | import jax.numpy as jnp 27 | import pandas as pd 28 | from sam.sam_jax.datasets import dataset_source 29 | from sam.sam_jax.training_utils import flax_training 30 | import tensorflow as tf 31 | from tensorflow.io import gfile 32 | 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | 37 | class MockDatasetSource(dataset_source.DatasetSource): 38 | """Simple linearly separable dataset for testing. 39 | 40 | See base class for more details. 
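
  The two classes (positive and negative scalar inputs) are separated by a
  threshold at zero, so a linear model can reach zero training error.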
41 | """ 42 | 43 | def __init__(self, mask=False): 44 | positive_input = tf.constant([[1.0 + i/20] for i in range(20)]) 45 | negative_input = tf.constant([[-1.0 - i/20] for i in range(20)]) 46 | positive_labels = tf.constant([[1, 0] for _ in range(20)]) 47 | negative_labels = tf.constant([[0, 1] for _ in range(20)]) 48 | inputs = tf.concat((positive_input, negative_input), 0) 49 | labels = tf.concat((positive_labels, negative_labels), 0) 50 | self.inputs, self.labels = inputs.numpy(), labels.numpy() 51 | if not mask: 52 | self._ds = tf.data.Dataset.from_tensor_slices({ 53 | 'image': inputs, 54 | 'label': labels 55 | }) 56 | else: 57 | self._ds = tf.data.Dataset.from_tensor_slices({ 58 | 'image': inputs, 59 | 'label': labels, 60 | 'mask': tf.constant([1.0 for _ in range(40)]) 61 | }) 62 | self.num_training_obs = 40 63 | self.batch_size = 16 64 | 65 | def get_train(self, use_augmentations: bool) -> tf.data.Dataset: 66 | """Returns the training set. 67 | 68 | Args: 69 | use_augmentations: Ignored (see base class for more details). 70 | """ 71 | del use_augmentations 72 | return self._ds.batch(self.batch_size) 73 | 74 | def get_test(self) -> tf.data.Dataset: 75 | """Returns the test set.""" 76 | return self._ds.batch(self.batch_size) 77 | 78 | 79 | def _get_linear_model() -> Tuple[flax.nn.Model, flax.nn.Collection]: 80 | """Returns a linear model and its state.""" 81 | 82 | class LinearModel(flax.nn.Module): 83 | """Defines the linear model.""" 84 | 85 | def apply(self, 86 | x: jnp.ndarray, 87 | num_outputs: int, 88 | train: bool = False) -> jnp.ndarray: 89 | """Forward pass with a linear model. 90 | 91 | Args: 92 | x: Input of shape [batch_size, num_features]. 93 | num_outputs: Number of classes. 94 | train: Has no effect. 95 | 96 | Returns: 97 | A tensor of shape [batch_size, num_outputs]. 98 | """ 99 | del train 100 | x = flax.nn.Dense(x, num_outputs) 101 | return x 102 | 103 | input_shape, num_outputs = [1], 2 104 | module = LinearModel.partial(num_outputs=num_outputs) 105 | with flax.nn.stateful() as init_state: 106 | with flax.nn.stochastic(jax.random.PRNGKey(0)): 107 | _, initial_params = module.init_by_shape( 108 | jax.random.PRNGKey(0), [(input_shape, jnp.float32)]) 109 | model = flax.nn.Model(module, initial_params) 110 | return model, init_state 111 | 112 | 113 | def tensorboard_event_to_dataframe(path: str) -> pd.DataFrame: 114 | """Helper to get events written by tests. 115 | 116 | Args: 117 | path: Path where the tensorboard records were saved. 118 | 119 | Returns: 120 | The metric saved by tensorboard, as a dataframe. 121 | """ 122 | records = [] 123 | all_tb_path = gfile.glob(os.path.join(path, 'events.*.v2')) 124 | for tb_event_path in all_tb_path: 125 | for e in tf.compat.v1.train.summary_iterator(tb_event_path): 126 | if e.step: 127 | for v in e.summary.value: 128 | records.append(dict( 129 | step=e.step, metric=v.tag, 130 | value=float(tf.make_ndarray(v.tensor)))) 131 | df = pd.DataFrame.from_records(records) 132 | return df 133 | 134 | 135 | prev_xla_flags = None 136 | 137 | 138 | class FlaxTrainingTest(absltest.TestCase): 139 | 140 | # Run all tests with 8 CPU devices. 141 | # As in third_party/py/jax/tests/pmap_test.py 142 | def setUp(self): 143 | super(FlaxTrainingTest, self).setUp() 144 | global prev_xla_flags 145 | prev_xla_flags = os.getenv('XLA_FLAGS') 146 | flags_str = prev_xla_flags or '' 147 | # Don't override user-specified device count, or other XLA flags. 
148 | if 'xla_force_host_platform_device_count' not in flags_str: 149 | os.environ['XLA_FLAGS'] = ( 150 | flags_str + ' --xla_force_host_platform_device_count=8') 151 | # Clear any cached backends so new CPU backend will pick up the env var. 152 | xla_bridge.get_backend.cache_clear() 153 | 154 | # Reset to previous configuration in case other test modules will be run. 155 | def tearDown(self): 156 | super(FlaxTrainingTest, self).tearDown() 157 | if prev_xla_flags is None: 158 | del os.environ['XLA_FLAGS'] 159 | else: 160 | os.environ['XLA_FLAGS'] = prev_xla_flags 161 | xla_bridge.get_backend.cache_clear() 162 | 163 | @flagsaver.flagsaver 164 | def test_TrainSimpleModel(self): 165 | """Model should reach 100% accuracy easily.""" 166 | model, state = _get_linear_model() 167 | dataset = MockDatasetSource() 168 | num_epochs = 10 169 | optimizer = flax_training.create_optimizer(model, 0.0) 170 | training_dir = self.create_tempdir().full_path 171 | FLAGS.learning_rate = 0.01 172 | flax_training.train( 173 | optimizer, state, dataset, training_dir, num_epochs) 174 | records = tensorboard_event_to_dataframe(training_dir) 175 | # Train error rate at the last step should be 0. 176 | records = records[records.metric == 'train_error_rate'] 177 | records = records.sort_values('step') 178 | self.assertEqual(records.value.values[-1], 0.0) 179 | 180 | @flagsaver.flagsaver 181 | def _test_ResumeTrainingAfterInterruption(self, use_ema: bool): # pylint:disable=invalid-name 182 | """Resuming training should match a run without interruption.""" 183 | if use_ema: 184 | FLAGS.ema_decay = 0.9 185 | model, state = _get_linear_model() 186 | dataset = MockDatasetSource() 187 | optimizer = flax_training.create_optimizer(model, 0.0) 188 | training_dir = self.create_tempdir().full_path 189 | FLAGS.learning_rate = 0.01 190 | FLAGS.use_learning_rate_schedule = False 191 | # First we train for 10 epochs and get the logs. 192 | num_epochs = 10 193 | reference_run_dir = os.path.join(training_dir, 'reference') 194 | flax_training.train( 195 | optimizer, state, dataset, reference_run_dir, num_epochs) 196 | records = tensorboard_event_to_dataframe(reference_run_dir) 197 | # In another directory (new experiment), we run the model for 4 epochs and 198 | # then for 10 epochs, to simulate an interruption. 199 | interrupted_run_dir = os.path.join(training_dir, 'interrupted') 200 | flax_training.train( 201 | optimizer, state, dataset, interrupted_run_dir, 4) 202 | flax_training.train( 203 | optimizer, state, dataset, interrupted_run_dir, 10) 204 | records_interrupted = tensorboard_event_to_dataframe(interrupted_run_dir) 205 | 206 | # Logs should match (order doesn't matter as it is a dataframe in tidy 207 | # format). 
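    # Floats are rounded to 5 decimal places so that benign floating point
    # noise does not cause spurious mismatches between the two runs.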
208 |     def _make_hashable(row): 
209 |       return str([e if not isinstance(e, float) else round(e, 5) for e in row]) 
210 | 
211 |     self.assertEqual( 
212 |         set([_make_hashable(e) for e in records_interrupted.values]), 
213 |         set([_make_hashable(e) for e in records.values])) 
214 | 
215 |   def test_ResumeTrainingAfterInterruptionEMA(self): 
216 |     self._test_ResumeTrainingAfterInterruption(use_ema=True) 
217 | 
218 |   def test_ResumeTrainingAfterInterruption(self): 
219 |     self._test_ResumeTrainingAfterInterruption(use_ema=False) 
220 | 
221 |   def _RecomputeTestLoss(self, masked: bool = False):  # pylint:disable=invalid-name 
222 |     """Recomputes the loss of the final model to check the value logged.""" 
223 |     FLAGS.compute_top_5_error_rate = True 
224 |     model, state = _get_linear_model() 
225 |     dataset = MockDatasetSource(mask=masked) 
226 |     num_epochs = 2 
227 |     optimizer = flax_training.create_optimizer(model, 0.0) 
228 |     training_dir = self.create_tempdir().full_path 
229 |     flax_training.train( 
230 |         optimizer, state, dataset, training_dir, num_epochs) 
231 |     records = tensorboard_event_to_dataframe(training_dir) 
232 |     records = records[records.metric == 'test_loss'] 
233 |     final_test_loss = records.sort_values('step').value.values[-1] 
234 |     # Loads final model and state. 
235 |     optimizer, state, _ = flax_training.restore_checkpoint( 
236 |         optimizer, state, os.path.join(training_dir, 'checkpoints')) 
237 |     logits = optimizer.target(dataset.inputs) 
238 |     loss = flax_training.cross_entropy_loss(logits, dataset.labels) 
239 |     self.assertLess(abs(final_test_loss - loss), 1e-7) 
240 | 
241 |   def test_RecomputeTestLoss(self): 
242 |     self._RecomputeTestLoss() 
243 | 
244 |   def test_RecomputeTestLossMasked(self): 
245 |     self._RecomputeTestLoss(masked=True) 
246 | 
247 |   def test_metrics(self): 
248 |     logits = jnp.array([[2.0, 1.0], [3.0, 1.0], [0.1, 1.6], [2.0, 5.0]]) 
249 |     truth = jnp.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) 
250 |     mask = jnp.array([1.0, 1.0, 1.0, 0.0]) 
251 |     self.assertEqual(flax_training.error_rate_metric(logits, truth), 0.25) 
252 |     self.assertEqual(flax_training.error_rate_metric(logits, truth, mask), 1/3) 
253 | 
254 | 
255 | if __name__ == '__main__': 
256 |   absltest.main() 
257 | 
--------------------------------------------------------------------------------