├── .gitignore
├── CONTRIBUTING
├── LICENSE
├── OWNERS
├── README.md
├── __init__.py
├── autoaugment
│   ├── __init__.py
│   ├── autoaugment.py
│   ├── autoaugment_test.py
│   └── policies.py
├── figures
│   ├── no_sam.png
│   ├── sam_wide.png
│   └── summary_plot.png
├── requirements.txt
└── sam_jax
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── augmentation.py
    │   ├── dataset_source.py
    │   ├── dataset_source_imagenet.py
    │   ├── dataset_source_imagenet_test.py
    │   └── dataset_source_test.py
    ├── efficientnet
    │   ├── __init__.py
    │   ├── efficientnet.py
    │   ├── efficientnet_test.py
    │   ├── optim.py
    │   └── optim_test.py
    ├── imagenet_models
    │   ├── __init__.py
    │   ├── load_model.py
    │   ├── load_model_test.py
    │   ├── resnet.py
    │   └── resnet_test.py
    ├── models
    │   ├── __init__.py
    │   ├── load_model.py
    │   ├── load_model_test.py
    │   ├── pyramidnet.py
    │   ├── utils.py
    │   ├── wide_resnet.py
    │   └── wide_resnet_shakeshake.py
    ├── train.py
    └── training_utils
        ├── __init__.py
        ├── flax_training.py
        └── flax_training_test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/BUILD
2 |
--------------------------------------------------------------------------------
/CONTRIBUTING:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file
13 | or to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/OWNERS:
--------------------------------------------------------------------------------
1 | pierreforet
2 | akleiner
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SAM: Sharpness-Aware Minimization for Efficiently Improving Generalization
2 |
3 | by Pierre Foret, Ariel Kleiner, Hossein Mobahi and Behnam Neyshabur.
4 |
5 |
6 | ## SAM in a few words
7 |
8 | **Abstract**: In today's heavily overparameterized models, the value of the training loss provides few guarantees on model generalization ability. Indeed, optimizing only the training loss value, as is commonly done, can easily lead to suboptimal model quality. Motivated by the connection between geometry of the loss landscape and generalization---including a generalization bound that we prove here---we introduce a novel, effective procedure for instead simultaneously minimizing loss value and loss sharpness. In particular, our procedure, Sharpness-Aware Minimization (SAM), seeks parameters that lie in neighborhoods having uniformly low loss; this formulation results in a min-max optimization problem on which gradient descent can be performed efficiently. We present empirical results showing that SAM improves model generalization across a variety of benchmark datasets (e.g., CIFAR-{10, 100}, ImageNet, finetuning tasks) and models, yielding novel state-of-the-art performance for several. Additionally, we find that SAM natively provides robustness to label noise on par with that provided by state-of-the-art procedures that specifically target learning with noisy labels.
9 |
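At its core, SAM turns each gradient step into two: the weights are first perturbed along the (normalized) gradient direction by a radius rho to find the worst-case nearby loss, and the gradient computed at that perturbed point is then applied to the original weights. A minimal JAX sketch of this update follows; `loss_fn`, `params` and `rho` are illustrative placeholders, not this repository's actual API:

```python
import jax
import jax.numpy as jnp

def sam_gradient(loss_fn, params, rho):
  """Gradient of loss_fn at the worst-case point of the rho-ball around params."""
  grads = jax.grad(loss_fn)(params)
  # L2 ascent direction: scale the gradient so the perturbation has norm rho.
  grad_norm = jnp.sqrt(
      sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grads)))
  perturbed = jax.tree_util.tree_map(
      lambda p, g: p + rho * g / (grad_norm + 1e-12), params, grads)
  # Descend on the original weights using the gradient at the perturbed point.
  return jax.grad(loss_fn)(perturbed)
```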
10 |
11 | | ![Error rate reduction with SAM](figures/summary_plot.png) | ![Sharp minimum (SGD)](figures/no_sam.png) | ![Wide minimum (SAM)](figures/sam_wide.png) |
12 | |:--------------:|:----------:|:----------------------:|
13 | | Error rate reduction obtained by switching to SAM. Each point is a different dataset / model / data augmentation | A sharp minimum to which a ResNet trained with SGD converged | A wide minimum to which the same ResNet trained with SAM converged. |
14 |
15 |
16 |
17 | ## About this repo
18 |
19 | This code allows the user to replicate most of the experiments of the paper, including:
20 |
21 | * Training Wide ResNets and PyramidNets (with Shake-Shake / ShakeDrop) from scratch on CIFAR-10, CIFAR-100, SVHN and Fashion-MNIST, with or without SAM, cutout and AutoAugment.
22 | * Training ResNets and EfficientNets on ImageNet, with or without SAM or RandAugment.
23 | * Finetuning EfficientNets on ImageNet from checkpoints pretrained on ImageNet/JFT.
24 |
25 |
26 | ## How to train from scratch
27 |
28 | Once the repo is cloned, experiments can be launched with the `sam.sam_jax.train` module:
29 |
30 | ```
31 | python3 -m sam.sam_jax.train --dataset cifar10 --model_name WideResnet28x10 \
32 | --output_dir /tmp/my_experiment --image_level_augmentations autoaugment \
33 | --num_epochs 1800 --sam_rho 0.05
34 | ```
35 |
36 | Note that our code uses all available GPUs/TPUs for training.
37 |
38 | To see a detailed list of all available flags, run `python3 -m sam.sam_jax.train --help`.
39 |
40 | #### Output
41 |
42 | Training curves can be visualized with TensorBoard. TensorBoard events are
43 | saved in `output_dir`, and their path contains the learning rate, the
44 | weight decay, `rho` and the random seed.
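For example, `tensorboard --logdir /tmp/my_experiment` will display these curves for the training command above.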
45 |
46 | ## Finetuning EfficientNet
47 |
48 | We provide a FLAX checkpoint compatible with our implementation for each of the checkpoints
49 | available in the [official tensorflow efficientnet implementation](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet):
50 |
51 | | | B0 | B1 | B2 | B3 | B4 | B5 | B6 | B7 | B8 | L2 |
52 | |---------- |-------- | ------| ------|------ |------ |------ | --- | --- | --- | --- |
53 | | Baseline preprocessing | 76.7% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b0/checkpoint.tar.gz)) | 78.7% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b1/checkpoint.tar.gz)) | 79.8% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b2/checkpoint.tar.gz)) | 81.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b3/checkpoint.tar.gz)) | 82.5% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b4/checkpoint.tar.gz)) | 83.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b5/checkpoint.tar.gz)) | | || | |
54 | | AutoAugment (AA) | 77.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b0/checkpoint.tar.gz)) | 79.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b1/checkpoint.tar.gz)) | 80.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b2/checkpoint.tar.gz)) | 81.6% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b3/checkpoint.tar.gz)) | 82.9% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b4/checkpoint.tar.gz)) | 83.6% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b5/checkpoint.tar.gz)) | 84.0% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b6/checkpoint.tar.gz)) | 84.3% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckptsaug/efficientnet-b7/checkpoint.tar.gz)) || |
55 | | RandAugment (RA) | | | | | | 83.7% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/randaug/efficientnet-b5-randaug/checkpoint.tar.gz)) | | 84.7% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/randaug/efficientnet-b7-randaug/checkpoint.tar.gz)) | | |
56 | | AdvProp + AA | 77.6% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b0/checkpoint.tar.gz)) | 79.6% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b1/checkpoint.tar.gz)) | 80.5% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b2/checkpoint.tar.gz)) | 81.9% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b3/checkpoint.tar.gz)) | 83.3% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b4/checkpoint.tar.gz)) | 84.3% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b5/checkpoint.tar.gz)) | 84.8% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b6/checkpoint.tar.gz)) | 85.2% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b7/checkpoint.tar.gz)) | 85.5% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/advprop/efficientnet-b8/checkpoint.tar.gz))|| |
57 | | NoisyStudent + RA | 78.8% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b0/checkpoint.tar.gz)) | 81.5% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b1/checkpoint.tar.gz)) | 82.4% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b2/checkpoint.tar.gz)) | 84.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b3/checkpoint.tar.gz)) | 85.3% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b4/checkpoint.tar.gz)) | 86.1% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b5/checkpoint.tar.gz)) | 86.4% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b6/checkpoint.tar.gz)) | 86.9% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-b7/checkpoint.tar.gz)) | - |88.4% ([ckpt](https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/noisystudent/efficientnet-l2/checkpoint.tar.gz)) |
58 |
59 | * We report in this table the scores as found [here](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet). If you use these checkpoints with another codebase, you might obtain slightly different scores depending on the type of accelerator you use and your data processing pipeline (in particular the resizing algorithm).
60 |
61 | * AdvProp requires a slight modification of the input pipeline. See [here](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) for more details.
62 |
63 | * Please refer to the README of the [official tensorflow efficientnet implementation](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) to see which paper should be cited for which checkpoint.
64 |
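For example, the baseline B0 checkpoint from the table above can be downloaded and extracted with `wget https://storage.googleapis.com/gresearch/sam/efficientnet_checkpoints/ckpts/efficientnet-b0/checkpoint.tar.gz` followed by `tar -xzf checkpoint.tar.gz`.
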
65 | Once the checkpoint is downloaded and uncompressed, it can be finetuned on any of the datasets:
66 |
67 | ```
68 | python3 -m sam.sam_jax.train --output_dir /tmp/my_finetuning_experiment \
69 | --dataset imagenet --from_pretrained_checkpoint true \
70 | --efficientnet_checkpoint_path /tmp/path_to_efficientnet_checkpoint_to_finetune \
71 | --learning_rate 0.1 --model_name efficientnet-l2-475 --batch_size 512 \
72 | --num_epochs 10 --gradient_clipping 1.0 --label_smoothing 0.1 --sam_rho 0.1
73 | ```
74 |
75 |
76 | ## Bibtex
77 |
78 | ```
79 | @ARTICLE{2020arXiv201001412F,
80 | author = {{Foret}, Pierre and {Kleiner}, Ariel and {Mobahi}, Hossein and {Neyshabur}, Behnam},
81 | title = "{Sharpness-Aware Minimization for Efficiently Improving Generalization}",
82 | year = 2020,
83 | eid = {arXiv:2010.01412},
84 | eprint = {2010.01412},
85 | }
86 | ```
87 |
88 | **This is not an official Google product.**
89 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/autoaugment/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/autoaugment/autoaugment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """AutoAugment and RandAugment policies for enhanced image preprocessing.
16 |
17 | AutoAugment Reference: https://arxiv.org/abs/1805.09501
18 | RandAugment Reference: https://arxiv.org/abs/1909.13719
19 |
20 | Forked from
21 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
22 | """
23 |
24 | import inspect
25 | import math
26 |
27 | import dataclasses
28 | from sam.autoaugment import policies as optimal_policies
29 | import tensorflow.compat.v1 as tf
30 | from tensorflow_addons import image as contrib_image
31 |
32 |
33 | # This signifies the max integer that the controller RNN could predict for the
34 | # augmentation scheme.
35 | _MAX_LEVEL = 10.
36 |
37 |
38 | @dataclasses.dataclass
39 | class HParams:
40 | """Parameters for AutoAugment and RandAugment."""
41 | cutout_const: int
42 | translate_const: int
43 |
44 |
45 | def blend(image1, image2, factor):
46 | """Blend image1 and image2 using 'factor'.
47 |
48 | Factor can be above 0.0. A value of 0.0 means only image1 is used.
49 | A value of 1.0 means only image2 is used. A value between 0.0 and
50 | 1.0 means we linearly interpolate the pixel values between the two
51 | images. A value greater than 1.0 "extrapolates" the difference
52 | between the two pixel values, and we clip the results to values
53 | between 0 and 255.
54 |
55 | Args:
56 | image1: An image Tensor of type uint8.
57 | image2: An image Tensor of type uint8.
58 | factor: A floating point value above 0.0.
59 |
60 | Returns:
61 | A blended image Tensor of type uint8.
62 | """
63 | if factor == 0.0:
64 | return tf.convert_to_tensor(image1)
65 | if factor == 1.0:
66 | return tf.convert_to_tensor(image2)
67 |
68 | image1 = tf.to_float(image1)
69 | image2 = tf.to_float(image2)
70 |
71 | difference = image2 - image1
72 | scaled = factor * difference
73 |
74 | # Do addition in float.
75 | temp = tf.to_float(image1) + scaled
76 |
77 | # Interpolate
78 | if factor > 0.0 and factor < 1.0:
79 | # Interpolation means we always stay within 0 and 255.
80 | return tf.cast(temp, tf.uint8)
81 |
82 | # Extrapolate:
83 | #
84 | # We need to clip and then cast.
85 | return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8)
86 |
87 |
88 | def cutout(image, pad_size, replace=0):
89 | """Apply cutout (https://arxiv.org/abs/1708.04552) to image.
90 |
91 | This operation applies a (2*pad_size x 2*pad_size) mask of zeros to
92 | a random location within `img`. The pixel values filled in will be of the
93 |   value `replace`. The location where the mask will be applied is randomly
94 | chosen uniformly over the whole image.
95 |
96 | Args:
97 | image: An image Tensor of type uint8.
98 |     pad_size: Specifies the size of the zero mask that will be generated
99 |       and applied to the image. The mask will be of size
100 | (2*pad_size x 2*pad_size).
101 | replace: What pixel value to fill in the image in the area that has
102 | the cutout mask applied to it.
103 |
104 | Returns:
105 | An image Tensor that is of type uint8.
106 | """
107 | image_height = tf.shape(image)[0]
108 | image_width = tf.shape(image)[1]
109 |
110 | # Sample the center location in the image where the zero mask will be applied.
111 | cutout_center_height = tf.random_uniform(
112 | shape=[], minval=0, maxval=image_height,
113 | dtype=tf.int32)
114 |
115 | cutout_center_width = tf.random_uniform(
116 | shape=[], minval=0, maxval=image_width,
117 | dtype=tf.int32)
118 |
119 | lower_pad = tf.maximum(0, cutout_center_height - pad_size)
120 | upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
121 | left_pad = tf.maximum(0, cutout_center_width - pad_size)
122 | right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
123 |
124 | cutout_shape = [image_height - (lower_pad + upper_pad),
125 | image_width - (left_pad + right_pad)]
126 | padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
127 | mask = tf.pad(
128 | tf.zeros(cutout_shape, dtype=image.dtype),
129 | padding_dims, constant_values=1)
130 | mask = tf.expand_dims(mask, -1)
131 | mask = tf.tile(mask, [1, 1, 3])
132 | image = tf.where(
133 | tf.equal(mask, 0),
134 | tf.ones_like(image, dtype=image.dtype) * replace,
135 | image)
136 | return image
137 |
138 |
139 | def solarize(image, threshold=128):
140 | # For each pixel in the image, select the pixel
141 | # if the value is less than the threshold.
142 | # Otherwise, subtract 255 from the pixel.
143 | return tf.where(image < threshold, image, 255 - image)
144 |
145 |
146 | def solarize_add(image, addition=0, threshold=128):
147 | # For each pixel in the image less than threshold
148 | # we add 'addition' amount to it and then clip the
149 | # pixel value to be between 0 and 255. The value
150 | # of 'addition' is between -128 and 128.
151 | added_image = tf.cast(image, tf.int64) + addition
152 | added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8)
153 | return tf.where(image < threshold, added_image, image)
154 |
155 |
156 | def color(image, factor):
157 | """Equivalent of PIL Color."""
158 | degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
159 | return blend(degenerate, image, factor)
160 |
161 |
162 | def contrast(image, factor):
163 | """Equivalent of PIL Contrast."""
164 | degenerate = tf.image.rgb_to_grayscale(image)
165 | # Cast before calling tf.histogram.
166 | degenerate = tf.cast(degenerate, tf.int32)
167 |
168 | # Compute the grayscale histogram, then compute the mean pixel value,
169 | # and create a constant image size of that value. Use that as the
170 | # blending degenerate target of the original image.
171 | hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
172 | mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0
173 | degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
174 | degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
175 | degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
176 | return blend(degenerate, image, factor)
177 |
178 |
179 | def brightness(image, factor):
180 | """Equivalent of PIL Brightness."""
181 | degenerate = tf.zeros_like(image)
182 | return blend(degenerate, image, factor)
183 |
184 |
185 | def posterize(image, bits):
186 | """Equivalent of PIL Posterize."""
187 | shift = 8 - bits
188 | return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)
189 |
190 |
191 | def rotate(image, degrees, replace):
192 | """Rotates the image by degrees either clockwise or counterclockwise.
193 |
194 | Args:
195 | image: An image Tensor of type uint8.
196 | degrees: Float, a scalar angle in degrees to rotate all images by. If
197 | degrees is positive the image will be rotated clockwise otherwise it will
198 | be rotated counterclockwise.
199 | replace: A one or three value 1D tensor to fill empty pixels caused by
200 | the rotate operation.
201 |
202 | Returns:
203 | The rotated version of image.
204 | """
205 | # Convert from degrees to radians.
206 | degrees_to_radians = math.pi / 180.0
207 | radians = degrees * degrees_to_radians
208 |
209 | # In practice, we should randomize the rotation degrees by flipping
210 | # it negatively half the time, but that's done on 'degrees' outside
211 | # of the function.
212 | image = contrib_image.rotate(wrap(image), radians)
213 | return unwrap(image, replace)
214 |
215 |
216 | def translate_x(image, pixels, replace):
217 | """Equivalent of PIL Translate in X dimension."""
218 | image = contrib_image.translate(wrap(image), [-pixels, 0])
219 | return unwrap(image, replace)
220 |
221 |
222 | def translate_y(image, pixels, replace):
223 | """Equivalent of PIL Translate in Y dimension."""
224 | image = contrib_image.translate(wrap(image), [0, -pixels])
225 | return unwrap(image, replace)
226 |
227 |
228 | def shear_x(image, level, replace):
229 | """Equivalent of PIL Shearing in X dimension."""
230 | # Shear parallel to x axis is a projective transform
231 | # with a matrix form of:
232 | # [1 level
233 | # 0 1].
234 | image = contrib_image.transform(
235 | wrap(image), [1., level, 0., 0., 1., 0., 0., 0.])
236 | return unwrap(image, replace)
237 |
238 |
239 | def shear_y(image, level, replace):
240 | """Equivalent of PIL Shearing in Y dimension."""
241 | # Shear parallel to y axis is a projective transform
242 | # with a matrix form of:
243 | # [1 0
244 | # level 1].
245 | image = contrib_image.transform(
246 | wrap(image), [1., 0., 0., level, 1., 0., 0., 0.])
247 | return unwrap(image, replace)
248 |
249 |
250 | def autocontrast(image):
251 | """Implements Autocontrast function from PIL using TF ops.
252 |
253 | Args:
254 | image: A 3D uint8 tensor.
255 |
256 | Returns:
257 | The image after it has had autocontrast applied to it and will be of type
258 | uint8.
259 | """
260 |
261 | def scale_channel(image):
262 | """Scale the 2D image using the autocontrast rule."""
263 | # A possibly cheaper version can be done using cumsum/unique_with_counts
264 |     # over the histogram values, rather than iterating over the entire image
265 |     # to compute mins and maxes.
266 | lo = tf.to_float(tf.reduce_min(image))
267 | hi = tf.to_float(tf.reduce_max(image))
268 |
269 | # Scale the image, making the lowest value 0 and the highest value 255.
270 | def scale_values(im):
271 | scale = 255.0 / (hi - lo)
272 | offset = -lo * scale
273 | im = tf.to_float(im) * scale + offset
274 | im = tf.clip_by_value(im, 0.0, 255.0)
275 | return tf.cast(im, tf.uint8)
276 |
277 | result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image)
278 | return result
279 |
280 | # Assumes RGB for now. Scales each channel independently
281 | # and then stacks the result.
282 | s1 = scale_channel(image[:, :, 0])
283 | s2 = scale_channel(image[:, :, 1])
284 | s3 = scale_channel(image[:, :, 2])
285 | image = tf.stack([s1, s2, s3], 2)
286 | return image
287 |
288 |
289 | def sharpness(image, factor):
290 | """Implements Sharpness function from PIL using TF ops."""
291 | orig_image = image
292 | image = tf.cast(image, tf.float32)
293 | # Make image 4D for conv operation.
294 | image = tf.expand_dims(image, 0)
295 | # SMOOTH PIL Kernel.
296 | kernel = tf.constant(
297 | [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32,
298 | shape=[3, 3, 1, 1]) / 13.
299 | # Tile across channel dimension.
300 | kernel = tf.tile(kernel, [1, 1, 3, 1])
301 | strides = [1, 1, 1, 1]
302 | degenerate = tf.nn.depthwise_conv2d(
303 | image, kernel, strides, padding='VALID', rate=[1, 1])
304 | degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
305 | degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])
306 |
307 | # For the borders of the resulting image, fill in the values of the
308 | # original image.
309 | mask = tf.ones_like(degenerate)
310 | padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
311 | padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
312 | result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
313 |
314 | # Blend the final result.
315 | return blend(result, orig_image, factor)
316 |
317 |
318 | def equalize(image):
319 | """Implements Equalize function from PIL using TF ops."""
320 | def scale_channel(im, c):
321 | """Scale the data in the channel to implement equalize."""
322 | im = tf.cast(im[:, :, c], tf.int32)
323 | # Compute the histogram of the image channel.
324 | histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)
325 |
326 | # For the purposes of computing the step, filter out the nonzeros.
327 | nonzero = tf.where(tf.not_equal(histo, 0))
328 | nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
329 | step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255
330 |
331 | def build_lut(histo, step):
332 | # Compute the cumulative sum, shifting by step // 2
333 |       # and then normalizing by step.
334 | lut = (tf.cumsum(histo) + (step // 2)) // step
335 | # Shift lut, prepending with 0.
336 | lut = tf.concat([[0], lut[:-1]], 0)
337 | # Clip the counts to be in range. This is done
338 | # in the C code for image.point.
339 | return tf.clip_by_value(lut, 0, 255)
340 |
341 | # If step is zero, return the original image. Otherwise, build
342 | # lut from the full histogram and step and then index from it.
343 | result = tf.cond(tf.equal(step, 0),
344 | lambda: im,
345 | lambda: tf.gather(build_lut(histo, step), im))
346 |
347 | return tf.cast(result, tf.uint8)
348 |
349 | # Assumes RGB for now. Scales each channel independently
350 | # and then stacks the result.
351 | s1 = scale_channel(image, 0)
352 | s2 = scale_channel(image, 1)
353 | s3 = scale_channel(image, 2)
354 | image = tf.stack([s1, s2, s3], 2)
355 | return image
356 |
357 |
358 | def invert(image):
359 | """Inverts the image pixels."""
360 | image = tf.convert_to_tensor(image)
361 | return 255 - image
362 |
363 |
364 | def wrap(image):
365 | """Returns 'image' with an extra channel set to all 1s."""
366 | shape = tf.shape(image)
367 | extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype)
368 | extended = tf.concat([image, extended_channel], 2)
369 | return extended
370 |
371 |
372 | def unwrap(image, replace):
373 | """Unwraps an image produced by wrap.
374 |
375 |   Where there is a 0 in the last channel for a given spatial position,
376 |   the other three channels at that position are grayed
377 | (set to 128). Operations like translate and shear on a wrapped
378 | Tensor will leave 0s in empty locations. Some transformations look
379 | at the intensity of values to do preprocessing, and we want these
380 | empty pixels to assume the 'average' value, rather than pure black.
381 |
382 |
383 | Args:
384 | image: A 3D Image Tensor with 4 channels.
385 | replace: A one or three value 1D tensor to fill empty pixels.
386 |
387 | Returns:
388 | image: A 3D image Tensor with 3 channels.
389 | """
390 | image_shape = tf.shape(image)
391 | # Flatten the spatial dimensions.
392 | flattened_image = tf.reshape(image, [-1, image_shape[2]])
393 |
394 | # Find all pixels where the last channel is zero.
395 | alpha_channel = flattened_image[:, 3]
396 |
397 | replace = tf.concat([replace, tf.ones([1], image.dtype)], 0)
398 |
399 | # Where they are zero, fill them in with 'replace'.
400 | flattened_image = tf.where(
401 | tf.equal(alpha_channel, 0),
402 | tf.ones_like(flattened_image, dtype=image.dtype) * replace,
403 | flattened_image)
404 |
405 | image = tf.reshape(flattened_image, image_shape)
406 | image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3])
407 | return image
408 |
409 |
410 | NAME_TO_FUNC = {
411 | 'AutoContrast': autocontrast,
412 | 'Equalize': equalize,
413 | 'Invert': invert,
414 | 'Rotate': rotate,
415 | 'Posterize': posterize,
416 | 'Solarize': solarize,
417 | 'SolarizeAdd': solarize_add,
418 | 'Color': color,
419 | 'Contrast': contrast,
420 | 'Brightness': brightness,
421 | 'Sharpness': sharpness,
422 | 'ShearX': shear_x,
423 | 'ShearY': shear_y,
424 | 'TranslateX': translate_x,
425 | 'TranslateY': translate_y,
426 | 'Cutout': cutout,
427 | }
428 |
429 |
430 | def _randomly_negate_tensor(tensor):
431 | """With 50% prob turn the tensor negative."""
432 | should_flip = tf.cast(tf.floor(tf.random_uniform([]) + 0.5), tf.bool)
433 | final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
434 | return final_tensor
435 |
436 |
437 | def _rotate_level_to_arg(level):
438 | level = (level/_MAX_LEVEL) * 30.
439 | level = _randomly_negate_tensor(level)
440 | return (level,)
441 |
442 |
443 | def _shrink_level_to_arg(level):
444 | """Converts level to ratio by which we shrink the image content."""
445 | if level == 0:
446 | return (1.0,) # if level is zero, do not shrink the image
447 | # Maximum shrinking ratio is 2.9.
448 | level = 2. / (_MAX_LEVEL / level) + 0.9
449 | return (level,)
450 |
451 |
452 | def _enhance_level_to_arg(level):
453 | return ((level/_MAX_LEVEL) * 1.8 + 0.1,)
454 |
455 |
456 | def _shear_level_to_arg(level):
457 | level = (level/_MAX_LEVEL) * 0.3
458 | # Flip level to negative with 50% chance.
459 | level = _randomly_negate_tensor(level)
460 | return (level,)
461 |
462 |
463 | def _translate_level_to_arg(level, translate_const):
464 | level = (level/_MAX_LEVEL) * float(translate_const)
465 | # Flip level to negative with 50% chance.
466 | level = _randomly_negate_tensor(level)
467 | return (level,)
468 |
469 |
470 | def level_to_arg(hparams):
471 | return {
472 | 'AutoContrast': lambda level: (),
473 | 'Equalize': lambda level: (),
474 | 'Invert': lambda level: (),
475 | 'Rotate': _rotate_level_to_arg,
476 | 'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4),),
477 | 'Solarize': lambda level: (int((level/_MAX_LEVEL) * 256),),
478 | 'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110),),
479 | 'Color': _enhance_level_to_arg,
480 | 'Contrast': _enhance_level_to_arg,
481 | 'Brightness': _enhance_level_to_arg,
482 | 'Sharpness': _enhance_level_to_arg,
483 | 'ShearX': _shear_level_to_arg,
484 | 'ShearY': _shear_level_to_arg,
485 | 'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams.cutout_const),),
486 | # pylint:disable=g-long-lambda
487 | 'TranslateX': lambda level: _translate_level_to_arg(
488 | level, hparams.translate_const),
489 | 'TranslateY': lambda level: _translate_level_to_arg(
490 | level, hparams.translate_const),
491 | # pylint:enable=g-long-lambda
492 | }
493 |
494 |
495 | def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):
496 | """Return the function that corresponds to `name` and update `level` param."""
497 | func = NAME_TO_FUNC[name]
498 | args = level_to_arg(augmentation_hparams)[name](level)
499 |
500 | # Check to see if prob is passed into function. This is used for operations
501 | # where we alter bboxes independently.
502 | # pytype:disable=wrong-arg-types
503 | if 'prob' in inspect.getfullargspec(func)[0]:
504 | args = tuple([prob] + list(args))
505 | # pytype:enable=wrong-arg-types
506 |
507 | # Add in replace arg if it is required for the function that is being called.
508 | # pytype:disable=wrong-arg-types
509 | if 'replace' in inspect.getfullargspec(func)[0]:
510 | # Make sure replace is the final argument
511 | assert 'replace' == inspect.getfullargspec(func)[0][-1]
512 | args = tuple(list(args) + [replace_value])
513 | # pytype:enable=wrong-arg-types
514 |
515 | return (func, prob, args)
516 |
517 |
518 | def _apply_func_with_prob(func, image, args, prob):
519 | """Apply `func` to image w/ `args` as input with probability `prob`."""
520 | assert isinstance(args, tuple)
521 |
522 | # If prob is a function argument, then this randomness is being handled
523 | # inside the function, so make sure it is always called.
524 | # pytype:disable=wrong-arg-types
525 | if 'prob' in inspect.getfullargspec(func)[0]:
526 | prob = 1.0
527 | # pytype:enable=wrong-arg-types
528 |
529 | # Apply the function with probability `prob`.
530 | should_apply_op = tf.cast(
531 | tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool)
532 | augmented_image = tf.cond(
533 | should_apply_op,
534 | lambda: func(image, *args),
535 | lambda: image)
536 | return augmented_image
537 |
538 |
539 | def select_and_apply_random_policy(policies, image):
540 | """Select a random policy from `policies` and apply it to `image`."""
541 | policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32)
542 | # Note that using tf.case instead of tf.conds would result in significantly
543 | # larger graphs and would even break export for some larger policies.
544 | for (i, policy) in enumerate(policies):
545 | image = tf.cond(
546 | tf.equal(i, policy_to_select),
547 | lambda selected_policy=policy: selected_policy(image),
548 | lambda: image)
549 | return image
550 |
551 |
552 | def build_and_apply_nas_policy(policies, image,
553 | augmentation_hparams):
554 | """Build a policy from the given policies passed in and apply to image.
555 |
556 | Args:
557 | policies: list of lists of tuples in the form `(func, prob, level)`, `func`
558 | is a string name of the augmentation function, `prob` is the probability
559 | of applying the `func` operation, `level` is the input argument for
560 | `func`.
561 | image: tf.Tensor that the resulting policy will be applied to.
562 | augmentation_hparams: Hparams associated with the NAS learned policy.
563 |
564 | Returns:
565 | A version of image that now has data augmentation applied to it based on
566 |     the `policies` passed into the function.
567 | """
568 | replace_value = [128, 128, 128]
569 |
570 | # func is the string name of the augmentation function, prob is the
571 | # probability of applying the operation and level is the parameter associated
572 | # with the tf op.
573 |
574 | # tf_policies are functions that take in an image and return an augmented
575 | # image.
576 | tf_policies = []
577 | for policy in policies:
578 | tf_policy = []
579 | # Link string name to the correct python function and make sure the correct
580 | # argument is passed into that function.
581 | for policy_info in policy:
582 | policy_info = list(policy_info) + [replace_value, augmentation_hparams]
583 |
584 | tf_policy.append(_parse_policy_info(*policy_info))
585 |     # Now build the tf policy that will apply the augmentation procedure
586 | # on image.
587 | def make_final_policy(tf_policy_):
588 | def final_policy(image_):
589 | for func, prob, args in tf_policy_:
590 | image_ = _apply_func_with_prob(
591 | func, image_, args, prob)
592 | return image_
593 | return final_policy
594 | tf_policies.append(make_final_policy(tf_policy))
595 |
596 | augmented_image = select_and_apply_random_policy(
597 | tf_policies, image)
598 | return augmented_image
599 |
600 |
601 | def distort_image_with_autoaugment(image, augmentation_name):
602 | """Applies the AutoAugment policy to `image`.
603 |
604 | AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
605 |
606 | Args:
607 | image: `Tensor` of shape [height, width, 3] representing an image.
608 | augmentation_name: The name of the AutoAugment policy to use.
609 |
610 | Returns:
611 |     The augmented version of `image`.
612 | """
613 | available_policies = {'imagenet': optimal_policies.policy_imagenet,
614 | 'cifar': optimal_policies.policy_cifar,
615 | 'svhn': optimal_policies.policy_svhn}
616 | if augmentation_name not in available_policies:
617 | raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name))
618 |
619 | policy = available_policies[augmentation_name]()
620 | # Hparams that will be used for AutoAugment.
621 | augmentation_hparams = HParams(cutout_const=100, translate_const=250)
622 |
623 | return build_and_apply_nas_policy(policy, image, augmentation_hparams)
624 |
625 |
626 | def distort_image_with_randaugment(image, num_layers, magnitude):
627 | """Applies the RandAugment policy to `image`.
628 |
629 |   RandAugment is from the paper https://arxiv.org/abs/1909.13719.
630 |
631 | Args:
632 | image: `Tensor` of shape [height, width, 3] representing an image.
633 | num_layers: Integer, the number of augmentation transformations to apply
634 | sequentially to an image. Represented as (N) in the paper. Usually best
635 | values will be in the range [1, 3].
636 | magnitude: Integer, shared magnitude across all augmentation operations.
637 | Represented as (M) in the paper. Usually best values are in the range
638 | [5, 30].
639 |
640 | Returns:
641 | The augmented version of `image`.
642 | """
643 | replace_value = [128] * 3
644 | tf.logging.info('Using RandAug.')
645 | augmentation_hparams = HParams(cutout_const=40, translate_const=100)
646 | available_ops = [
647 | 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize',
648 | 'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness',
649 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd']
650 |
651 | for layer_num in range(num_layers):
652 | op_to_select = tf.random_uniform(
653 | [], maxval=len(available_ops), dtype=tf.int32)
654 | random_magnitude = float(magnitude)
655 | with tf.name_scope('randaug_layer_{}'.format(layer_num)):
656 | for (i, op_name) in enumerate(available_ops):
657 | prob = tf.random_uniform([], minval=0.2, maxval=0.8, dtype=tf.float32)
658 | func, _, args = _parse_policy_info(op_name, prob, random_magnitude,
659 | replace_value, augmentation_hparams)
660 | image = tf.cond(
661 | tf.equal(i, op_to_select),
662 | # pylint:disable=g-long-lambda
663 | lambda selected_func=func, selected_args=args: selected_func(
664 | image, *selected_args),
665 | # pylint:enable=g-long-lambda
666 | lambda: image)
667 | return image
668 |
--------------------------------------------------------------------------------
/autoaugment/autoaugment_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sam.autoaugment.autoaugment."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from sam.autoaugment import autoaugment
20 | import tensorflow as tf
21 |
22 |
23 | class AutoAugmentTest(parameterized.TestCase):
24 |
25 | @parameterized.named_parameters(
26 | ('ShearX', 'ShearX'),
27 | ('ShearY', 'ShearY'),
28 | ('Cutout', 'Cutout'),
29 | ('TranslateX', 'TranslateX'),
30 | ('TranslateY', 'TranslateY'),
31 | ('Rotate', 'Rotate'),
32 | ('AutoContrast', 'AutoContrast'),
33 | ('Invert', 'Invert'),
34 | ('Equalize', 'Equalize'),
35 | ('Solarize', 'Solarize'),
36 | ('Posterize', 'Posterize'),
37 | ('Contrast', 'Contrast'),
38 | ('Color', 'Color'),
39 | ('Brightness', 'Brightness'),
40 | ('Sharpness', 'Sharpness'))
41 | def test_image_processing_function(self, name: str):
42 | hparams = autoaugment.HParams(cutout_const=10, translate_const=25)
43 | replace_value = [128, 128, 128]
44 | function, _, args = autoaugment._parse_policy_info(
45 | name, 1.0, 10, replace_value, hparams)
46 | cifar_image_shape = [32, 32, 3]
47 | image = tf.zeros(cifar_image_shape, tf.uint8)
48 | augmented_image = function(image, *args)
49 | self.assertEqual(augmented_image.shape, cifar_image_shape)
50 | self.assertEqual(augmented_image.dtype, tf.uint8)
51 |
52 | @parameterized.named_parameters(('cifar', 'cifar'), ('svhn', 'svhn'),
53 | ('imagenet', 'imagenet'))
54 | def test_autoaugment_function(self, dataset_name):
55 | autoaugment_fn = lambda image: autoaugment.distort_image_with_autoaugment( # pylint:disable=g-long-lambda
56 | image, dataset_name)
57 | image_shape = [224, 224, 3] if dataset_name == 'imagenet' else [32, 32, 3]
58 | image = tf.zeros(image_shape, tf.uint8)
59 | augmented_image = autoaugment_fn(image)
60 | self.assertEqual(augmented_image.shape, image_shape)
61 | self.assertEqual(augmented_image.dtype, tf.uint8)
62 |
63 | if __name__ == '__main__':
64 | absltest.main()
65 |
--------------------------------------------------------------------------------
/autoaugment/policies.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """AutoAugment policies for Imagenet, Cifar and SVHN."""
16 |
17 | from typing import List, Tuple
18 |
19 |
20 | def policy_imagenet() -> List[List[Tuple[str, float, int]]]:
21 |   """Returns the AutoAugment policy that was used in the AutoAugment paper.
22 |
23 | A policy is composed of two augmentations applied sequentially to the image.
24 | Each augmentation is described as a tuple where the first element is the
25 | type of transformation to apply, the second is the probability with which the
26 | augmentation should be applied, and the third element is the strength of the
27 | transformation.
28 | """
29 | policy = [
30 | [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
31 | [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
32 | [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
33 | [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
34 | [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
35 | [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
36 | [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
37 | [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
38 | [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
39 | [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
40 | [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
41 | [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
42 | [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
43 | [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
44 | [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
45 | [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
46 | [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
47 | [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
48 | [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
49 | [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
50 | [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
51 | [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
52 | [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
53 | [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
54 | [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
55 | ]
56 | return policy
57 |
58 |
59 | def policy_cifar() -> List[List[Tuple[str, float, int]]]:
60 | """Returns the AutoAugment policies found on Cifar.
61 |
62 |   The returned policy is a list of sub-policies, each composed of two
63 |   augmentations applied sequentially to the image. Each augmentation is
64 |   described as a tuple where the first element is the type of transformation
65 |   to apply, the second is the probability with which the augmentation should
66 |   be applied, and the third element is the strength of the transformation.
67 | """
68 | exp0_0 = [
69 | [('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
70 | [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
71 | [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
72 | [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
73 | [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]]
74 | exp0_1 = [
75 | [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
76 | [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
77 | [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
78 | [('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
79 | [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]]
80 | exp0_2 = [
81 | [('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)],
82 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)],
83 | [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)],
84 | [('Equalize', 0.7, 5), ('Invert', 0.1, 3)],
85 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]]
86 | exp0_3 = [
87 | [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)],
88 | [('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)],
89 | [('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)],
90 | [('TranslateY', 0.2, 7), ('Color', 0.9, 6)],
91 | [('Equalize', 0.7, 6), ('Color', 0.4, 9)]]
92 | exp1_0 = [
93 | [('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
94 | [('Color', 0.4, 3), ('Brightness', 0.6, 7)],
95 | [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
96 | [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
97 | [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]]
98 | exp1_1 = [
99 | [('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)],
100 | [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)],
101 | [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)],
102 | [('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)],
103 | [('Brightness', 0.0, 8), ('Color', 0.8, 8)]]
104 | exp1_2 = [
105 | [('Solarize', 0.2, 6), ('Color', 0.8, 6)],
106 | [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)],
107 | [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)],
108 | [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)],
109 | [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]]
110 | exp1_3 = [
111 | [('Contrast', 0.7, 5), ('Brightness', 0.0, 2)],
112 | [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)],
113 | [('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)],
114 | [('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)],
115 | [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]]
116 | exp1_4 = [
117 | [('Brightness', 0.0, 7), ('Equalize', 0.4, 7)],
118 | [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)],
119 | [('Equalize', 0.6, 8), ('Color', 0.6, 2)],
120 | [('Color', 0.3, 7), ('Color', 0.2, 4)],
121 | [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]]
122 | exp1_5 = [
123 | [('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)],
124 | [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)],
125 | [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)],
126 | [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)],
127 | [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]]
128 | exp1_6 = [
129 | [('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)],
130 | [('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)],
131 | [('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)],
132 | [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)],
133 | [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]]
134 | exp2_0 = [
135 | [('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
136 | [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
137 | [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
138 | [('Brightness', 0.9, 6), ('Color', 0.2, 8)],
139 | [('Solarize', 0.5, 2), ('Invert', 0.0, 3)]]
140 | exp2_1 = [
141 | [('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)],
142 | [('Cutout', 0.2, 4), ('Equalize', 0.1, 1)],
143 | [('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)],
144 | [('Color', 0.1, 8), ('ShearY', 0.2, 3)],
145 | [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]]
146 | exp2_2 = [
147 | [('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)],
148 | [('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)],
149 | [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)],
150 | [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)],
151 | [('Equalize', 0.8, 2), ('Invert', 0.4, 0)]]
152 | exp2_3 = [
153 | [('Equalize', 0.9, 5), ('Color', 0.7, 0)],
154 | [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)],
155 | [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)],
156 | [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)],
157 | [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]]
158 | exp2_4 = [
159 | [('Solarize', 0.2, 3), ('ShearX', 0.0, 0)],
160 | [('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)],
161 | [('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)],
162 | [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)],
163 | [('Equalize', 0.8, 6), ('Invert', 0.3, 6)]]
164 | exp2_5 = [
165 | [('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)],
166 | [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)],
167 | [('ShearX', 0.0, 3), ('Posterize', 0.0, 3)],
168 | [('Solarize', 0.4, 3), ('Color', 0.2, 4)],
169 | [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]]
170 | exp2_6 = [
171 | [('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)],
172 | [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)],
173 | [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)],
174 | [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)],
175 | [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]]
176 | exp2_7 = [
177 | [('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)],
178 | [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)],
179 | [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)],
180 | [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)],
181 | [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]]
182 | exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3
183 | exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6
184 | exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + exp2_7
185 | return exp0s + exp1s + exp2s
186 |
187 |
188 | def policy_svhn() -> List[List[Tuple[str, float, int]]]:
189 | """Returns the AutoAugment policies found on SVHN.
190 |
191 |   The returned policy is a list of sub-policies, each composed of two
192 |   augmentations applied sequentially to the image. Each augmentation is
193 |   described as a tuple where the first element is the type of transformation
194 |   to apply, the second is the probability with which the augmentation should
195 |   be applied, and the third element is the strength of the transformation.
196 | """
197 | return [[('ShearX', 0.9, 4), ('Invert', 0.2, 3)],
198 | [('ShearY', 0.9, 8), ('Invert', 0.7, 5)],
199 | [('Equalize', 0.6, 5), ('Solarize', 0.6, 6)],
200 | [('Invert', 0.9, 3), ('Equalize', 0.6, 3)],
201 | [('Equalize', 0.6, 1), ('Rotate', 0.9, 3)],
202 | [('ShearX', 0.9, 4), ('AutoContrast', 0.8, 3)],
203 | [('ShearY', 0.9, 8), ('Invert', 0.4, 5)],
204 | [('ShearY', 0.9, 5), ('Solarize', 0.2, 6)],
205 | [('Invert', 0.9, 6), ('AutoContrast', 0.8, 1)],
206 | [('Equalize', 0.6, 3), ('Rotate', 0.9, 3)],
207 | [('ShearX', 0.9, 4), ('Solarize', 0.3, 3)],
208 | [('ShearY', 0.8, 8), ('Invert', 0.7, 4)],
209 | [('Equalize', 0.9, 5), ('TranslateY', 0.6, 6)],
210 | [('Invert', 0.9, 4), ('Equalize', 0.6, 7)],
211 | [('Contrast', 0.3, 3), ('Rotate', 0.8, 4)],
212 | [('Invert', 0.8, 5), ('TranslateY', 0.0, 2)],
213 | [('ShearY', 0.7, 6), ('Solarize', 0.4, 8)],
214 | [('Invert', 0.6, 4), ('Rotate', 0.8, 4)],
215 | [('ShearY', 0.3, 7), ('TranslateX', 0.9, 3)],
216 | [('ShearX', 0.1, 6), ('Invert', 0.6, 5)],
217 | [('Solarize', 0.7, 2), ('TranslateY', 0.6, 7)],
218 | [('ShearY', 0.8, 4), ('Invert', 0.8, 8)],
219 | [('ShearX', 0.7, 9), ('TranslateY', 0.8, 3)],
220 | [('ShearY', 0.8, 5), ('AutoContrast', 0.7, 3)],
221 | [('ShearX', 0.7, 2), ('Invert', 0.1, 5)]]
222 |
--------------------------------------------------------------------------------
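Note that the three policy functions above only return data; the actual op
dispatch and image transforms live in autoaugment.py. As a rough sketch of how
a single sub-policy is consumed (`apply_transform` and the list-based `image`
are hypothetical stand-ins, purely for illustration):

    import random

    def apply_transform(image, name, magnitude):
      # Hypothetical stand-in for the op dispatch done in autoaugment.py;
      # here it only records which transforms fired.
      return image + [(name, magnitude)]

    # One sub-policy from policy_imagenet(): two transforms applied in order.
    sub_policy = [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)]

    image = []  # Placeholder for an image tensor.
    for name, prob, magnitude in sub_policy:
      # Each transform fires independently with its own probability;
      # magnitude is an integer strength on a 0-10 scale.
      if random.random() < prob:
        image = apply_transform(image, name, magnitude)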
/figures/no_sam.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/sam/dae9904c4cf3a57a304f7b04cecffe371679c702/figures/no_sam.png
--------------------------------------------------------------------------------
/figures/sam_wide.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/sam/dae9904c4cf3a57a304f7b04cecffe371679c702/figures/sam_wide.png
--------------------------------------------------------------------------------
/figures/summary_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/sam/dae9904c4cf3a57a304f7b04cecffe371679c702/figures/summary_plot.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.11.0
2 | astunparse==1.6.3
3 | attrs==20.3.0
4 | cachetools==4.1.1
5 | certifi==2020.11.8
6 | chardet==3.0.4
7 | cloudpickle==1.6.0
8 | cycler==0.10.0
9 | decorator==4.4.2
10 | dill==0.3.3
11 | dm-tree==0.1.5
12 | flatbuffers==1.12
13 | flax==0.2.2
14 | future==0.18.2
15 | gast==0.3.3
16 | google-auth==1.23.0
17 | google-auth-oauthlib==0.4.2
18 | google-pasta==0.2.0
19 | googleapis-common-protos==1.52.0
20 | grpcio==1.34.0
21 | h5py==2.10.0
22 | idna==2.10
23 | importlib-resources==3.3.0
24 | jax==0.2.6
25 | jaxlib==0.1.57
26 | Keras-Preprocessing==1.1.2
27 | kiwisolver==1.3.1
28 | Markdown==3.3.3
29 | matplotlib==3.3.3
30 | msgpack==1.0.0
31 | numpy==1.18.5
32 | oauthlib==3.1.0
33 | opt-einsum==3.3.0
34 | pandas==1.1.4
35 | Pillow==8.0.1
36 | promise==2.3
37 | protobuf==3.14.0
38 | pyasn1==0.4.8
39 | pyasn1-modules==0.2.8
40 | pyparsing==2.4.7
41 | python-dateutil==2.8.1
42 | pytz==2020.4
43 | requests==2.25.0
44 | requests-oauthlib==1.3.0
45 | rsa==4.6
46 | scipy==1.5.4
47 | six==1.15.0
48 | tensorboard==2.4.0
49 | tensorboard-plugin-wit==1.7.0
50 | tensorflow==2.3.1
51 | tensorflow-addons==0.11.2
52 | tensorflow-datasets==4.1.0
53 | tensorflow-estimator==2.3.0
54 | tensorflow-metadata==0.25.0
55 | tensorflow-probability==0.11.1
56 | termcolor==1.1.0
57 | tqdm==4.54.0
58 | typeguard==2.10.0
59 | typing==3.7.4.3
60 | urllib3==1.26.2
61 | Werkzeug==1.0.1
62 | wrapt==1.12.1
63 |
--------------------------------------------------------------------------------
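These pins capture the late-2020 stack this code was developed against
(jax 0.2.6, flax 0.2.2, tensorflow 2.3.1); presumably they are installed into
a fresh environment the usual way:

    pip install -r requirements.txt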
/sam_jax/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/sam_jax/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/sam_jax/datasets/augmentation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implements data augmentations for cifar10/cifar100."""
16 |
17 | from typing import Dict
18 |
19 | from absl import flags
20 | from sam.autoaugment import autoaugment
21 | import tensorflow as tf
22 | import tensorflow_probability as tfp
23 |
24 |
25 | FLAGS = flags.FLAGS
26 |
27 |
28 | flags.DEFINE_integer('cutout_length', 16,
29 | 'Length (in pixels) of the cutout patch. Default value of '
30 |                      '16 is used to get SOTA on cifar10/cifar100.')
31 |
32 |
33 | def weak_image_augmentation(example: Dict[str, tf.Tensor],
34 | random_crop_pad: int = 4) -> Dict[str, tf.Tensor]:
35 | """Applies random crops and horizontal flips.
36 |
37 | Simple data augmentations that are (almost) always used with cifar. Pad the
38 | image with `random_crop_pad` before randomly cropping it to its original
39 | size. Also randomly apply horizontal flip.
40 |
41 | Args:
42 | example: An example dict containing an image and a label.
43 | random_crop_pad: By how many pixels should the image be padded on each side
44 | before cropping.
45 |
46 | Returns:
47 | An example with the same label and an augmented version of the image.
48 | """
49 | image, label = example['image'], example['label']
50 | image = tf.image.random_flip_left_right(image)
51 | image_shape = tf.shape(image)
52 | image = tf.pad(
53 | image, [[random_crop_pad, random_crop_pad],
54 | [random_crop_pad, random_crop_pad], [0, 0]],
55 | mode='REFLECT')
56 | image = tf.image.random_crop(image, image_shape)
57 | return {'image': image, 'label': label}
58 |
59 |
60 | def auto_augmentation(example: Dict[str, tf.Tensor],
61 | dataset_name: str) -> Dict[str, tf.Tensor]:
62 | """Applies the AutoAugment policy found for the dataset.
63 |
64 | AutoAugment: Learning Augmentation Policies from Data
65 | https://arxiv.org/abs/1805.09501
66 |
67 | Args:
68 | example: An example dict containing an image and a label.
69 | dataset_name: Name of the dataset for which we should return the optimal
70 | policy. Should be 'cifar[10|100]', 'svhn' or 'imagenet'.
71 |
72 | Returns:
73 | An example with the same label and an augmented version of the image.
74 | """
75 | if dataset_name in ('cifar10', 'cifar100'): dataset_name = 'cifar'
76 | image, label = example['image'], example['label']
77 | image = autoaugment.distort_image_with_autoaugment(image, dataset_name)
78 | return {'image': image, 'label': label}
79 |
80 |
81 | def cutout(batch: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
82 | """Applies cutout to a batch of images.
83 |
84 | The cut out patch will be replaced by zeros (thus the batch should be
85 | normalized before cutout is applied).
86 |
87 | Reference:
88 | Improved Regularization of Convolutional Neural Networks with Cutout
89 | https://arxiv.org/abs/1708.04552
90 |
91 | Implementation inspired by:
92 | third_party/cloud_tpu/models/efficientnet/autoaugment.py
93 |
94 | Args:
95 | batch: A batch of images and labels.
96 |
97 | Returns:
98 | The same batch where cutout has been applied to the images.
99 | """
100 | length, replace = FLAGS.cutout_length, 0.0
101 | images, labels = batch['image'], batch['label']
102 | num_channels = tf.shape(images)[3]
103 | image_height, image_width = tf.shape(images)[1], tf.shape(images)[2]
104 |
105 | cutout_center_height = tf.random.uniform(
106 | shape=[], minval=0, maxval=image_height,
107 | dtype=tf.int32)
108 | cutout_center_width = tf.random.uniform(
109 | shape=[], minval=0, maxval=image_width,
110 | dtype=tf.int32)
111 |
112 | lower_pad = tf.maximum(0, cutout_center_height - length // 2)
113 | upper_pad = tf.maximum(0, image_height - cutout_center_height - length // 2)
114 | left_pad = tf.maximum(0, cutout_center_width - length // 2)
115 | right_pad = tf.maximum(0, image_width - cutout_center_width - length // 2)
116 |
117 | cutout_shape = [image_height - (lower_pad + upper_pad),
118 | image_width - (left_pad + right_pad)]
119 |
120 | padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
121 |
122 | mask = tf.pad(
123 | tf.zeros(cutout_shape, dtype=images.dtype),
124 | padding_dims, constant_values=1)
125 |
126 |   patch = tf.ones_like(images, dtype=images.dtype) * replace
127 | 
128 |   mask = tf.expand_dims(mask, -1)
129 |   mask = tf.tile(mask, [1, 1, num_channels])
130 | 
131 |   # The mask is 0 inside the cutout patch and 1 elsewhere; pixels inside the
132 |   # patch are overwritten with the constant `replace` value.
133 |   images = tf.where(
134 |       tf.equal(mask, 0),
135 |       patch,
136 |       images)
137 | 
138 |   return {'image': images, 'label': labels}
139 |
140 |
141 | def mixup(batch: Dict[str, tf.Tensor],
142 | alpha: float = 1.0) -> Dict[str, tf.Tensor]:
143 | """Generates augmented images using Mixup.
144 |
145 | Arguments:
146 | batch: Feature dict containing the images and the labels.
147 | alpha: Float that controls the strength of Mixup regularization.
148 |
149 | Returns:
150 | A feature dict containing the images and labels augmented with mixup.
151 | """
152 | images, labels = batch['image'], batch['label']
153 | batch_size = 1 # Unique mixing parameter for all samples
154 | mix_weight = tfp.distributions.Beta(alpha, alpha).sample([batch_size, 1])
155 | mix_weight = tf.maximum(mix_weight, 1. - mix_weight)
156 | images_mix_weight = tf.reshape(mix_weight, [batch_size, 1, 1, 1])
157 | images_mix = (
158 | images * images_mix_weight + images[::-1] * (1. - images_mix_weight))
159 | labels_mix = labels * mix_weight + labels[::-1] * (1. - mix_weight)
160 | return {'image': images_mix, 'label': labels_mix}
161 |
--------------------------------------------------------------------------------
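For intuition, mixup() above draws a single Beta(alpha, alpha) weight for the
whole batch (hence batch_size = 1 in the sample shape) and blends every
example with the batch read in reverse order. A minimal standalone sketch of
the same arithmetic on a toy one-hot label batch:

    import tensorflow as tf
    import tensorflow_probability as tfp

    labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])  # Toy one-hot batch of 2.
    mix_weight = tfp.distributions.Beta(1.0, 1.0).sample([1, 1])
    mix_weight = tf.maximum(mix_weight, 1. - mix_weight)  # Always >= 0.5.
    # Example i is blended with example (batch_size - 1 - i).
    labels_mix = labels * mix_weight + labels[::-1] * (1. - mix_weight)
    print(labels_mix)  # Each row is a convex combination that still sums to 1.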
/sam_jax/datasets/dataset_source.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utility class to load datasets and apply data augmentation."""
16 |
17 | import abc
18 | from typing import Callable, Dict, Optional
19 |
20 | from absl import flags
21 | from absl import logging
22 | import jax
23 | from sam.sam_jax.datasets import augmentation
24 | import tensorflow as tf
25 | import tensorflow_datasets as tfds
26 |
27 |
28 | FLAGS = flags.FLAGS
29 |
30 |
31 | flags.DEFINE_bool('use_test_set', True,
32 |                   'Whether to use the test set or not. If not, then 10% of '
33 |                   'the observations will be set aside from the training set '
34 |                   'and used as a validation set instead.')
35 |
36 |
37 | class DatasetSource(abc.ABC):
38 | """Parent for classes that load, preprocess and serve datasets.
39 |
40 | Child class constructor should set a `num_training_obs` and a `batch_size`
41 | attribute.
42 | """
43 | batch_size = ... # type: int
44 | num_training_obs = ... # type: int
45 |
46 | @abc.abstractmethod
47 | def get_train(self, use_augmentations: bool) -> tf.data.Dataset:
48 | """Returns the training set.
49 |
50 | The training set will be batched, and the remainder of the batch will be
51 |     dropped (except if use_augmentations is False, in which case we don't drop
52 |     the remainder, as we are most likely computing accuracy on the train set).
53 |
54 | Args:
55 | use_augmentations: Whether we should apply data augmentation (and possibly
56 | cutout) or not.
57 | """
58 |
59 | @abc.abstractmethod
60 | def get_test(self) -> tf.data.Dataset:
61 | """Returns test set."""
62 |
63 |
64 | def _resize(image: tf.Tensor, image_size: int, method: Optional[str] = None):
65 | if method is not None:
66 | return tf.image.resize(image, [image_size, image_size], method)
67 | return tf.compat.v1.image.resize_bicubic(image, [image_size, image_size])
68 |
69 |
70 | class TFDSDatasetSource(DatasetSource):
71 | """Parent for classes that load, preprocess and serve TensorFlow datasets.
72 |
73 | Small datasets like CIFAR, SVHN and Fashion MNIST subclass TFDSDatasetSource.
74 | """
75 | batch_size = ... # type: int
76 | num_training_obs = ... # type: int
77 | _train_ds = ... # type: tf.data.Dataset
78 | _test_ds = ... # type: tf.data.Dataset
79 | _augmentation = ... # type: str
80 | _num_classes = ... # type: int
81 |   _image_mean = ...  # type: tf.Tensor
82 |   _image_std = ...  # type: tf.Tensor
83 | _dataset_name = ... # type: str
84 | _batch_level_augmentations = ... # type: Callable
85 | _image_size = ... # type: Optional[int]
86 |
87 | def _apply_image_augmentations(
88 | self, example: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
89 | if self._augmentation in ['autoaugment', 'aa-only']:
90 | example = augmentation.auto_augmentation(example, self._dataset_name)
91 | if self._augmentation in ['basic', 'autoaugment']:
92 | example = augmentation.weak_image_augmentation(example)
93 | return example
94 |
95 | def _preprocess_batch(self,
96 | examples: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
97 | image, label = examples['image'], examples['label']
98 | image = tf.cast(image, tf.float32) / 255.0
99 | image = (image - self._image_mean) / self._image_std
100 | label = tf.one_hot(
101 | label, depth=self._num_classes, on_value=1.0, off_value=0.0)
102 | return {'image': image, 'label': label}
103 |
104 | def get_train(self, use_augmentations: bool) -> tf.data.Dataset:
105 | """Returns the training set.
106 |
107 | The training set will be batched, and the remainder of the batch will be
108 | dropped (except if use_augmentations is False, in which case we don't drop
109 | the remainder as we are most likely computing the accuracy on the train
110 | set).
111 |
112 | Args:
113 | use_augmentations: Whether we should apply data augmentation (and possibly
114 | cutout) or not.
115 | """
116 | ds = self._train_ds.shuffle(50000)
117 | if use_augmentations:
118 | ds = ds.map(self._apply_image_augmentations,
119 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
120 | # Don't drop remainder if we don't use augmentation, as we are evaluating.
121 | ds = ds.batch(self.batch_size, drop_remainder=use_augmentations)
122 | ds = ds.map(self._preprocess_batch,
123 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
124 | if self._batch_level_augmentations and use_augmentations:
125 | ds = ds.map(self._batch_level_augmentations,
126 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
127 | if self._image_size:
128 | def resize(batch):
129 | image = _resize(batch['image'], self._image_size)
130 | return {'image': image, 'label': batch['label']}
131 | ds = ds.map(resize)
132 | return ds
133 |
134 | def get_test(self) -> tf.data.Dataset:
135 | """Returns the batched test set."""
136 | eval_batch_size = min(32, self.batch_size)
137 | ds = self._test_ds.batch(eval_batch_size).map(
138 | self._preprocess_batch,
139 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
140 | if self._image_size:
141 | def resize(batch):
142 | image = _resize(batch['image'], self._image_size)
143 | return {'image': image, 'label': batch['label']}
144 | ds = ds.map(resize)
145 | return ds
146 |
147 |
148 | class CifarDatasetSource(TFDSDatasetSource):
149 | """Parent class for DatasetSource created from cifar10/cifar100 datasets.
150 |
151 | The child class constructor must set _num_classes (integer, number of classes
152 | in the dataset).
153 | """
154 |
155 | def __init__(self, batch_size: int, name: str, image_level_augmentations: str,
156 | batch_level_augmentations: str,
157 | image_size: Optional[int] = None):
158 | """Instantiates the DatasetSource.
159 |
160 | Args:
161 | batch_size: Batch size to use for training and evaluation.
162 | name: Name of the Tensorflow Dataset to use. Should be cifar10 or
163 | cifar100.
164 | image_level_augmentations: Augmentations to apply to the images. Should be
165 | one of:
166 | * none: No augmentations are applied.
167 | * basic: Applies random crops and horizontal translations.
168 | * autoaugment: Applies the best found policy for Cifar from the
169 | AutoAugment paper.
170 | batch_level_augmentations: Augmentations to apply at the batch level. Only
171 | cutout is needed to get SOTA results. The following are implemented:
172 | * none: No augmentations are applied.
173 | * cutout: Applies cutout (https://arxiv.org/abs/1708.04552).
174 | * mixup: Applies mixup (https://arxiv.org/pdf/1710.09412.pdf).
175 | * mixcut: Applies mixup and cutout.
176 | image_size: Size to which the image should be rescaled. If None, the
177 | standard size is used (32x32).
178 | """
179 | assert name in ['cifar10', 'cifar100']
180 | assert image_level_augmentations in ['none', 'basic', 'autoaugment']
181 |     assert batch_level_augmentations in ['none', 'cutout', 'mixup', 'mixcut']
182 | self._image_size = image_size
183 | self.batch_size = batch_size
184 | if FLAGS.use_test_set:
185 | self.num_training_obs = 50000
186 | train_split_size = self.num_training_obs // jax.host_count()
187 | start = jax.host_id() * train_split_size
188 | train_split = 'train[{}:{}]'.format(start, start + train_split_size)
189 | self._train_ds = tfds.load(name, split=train_split).cache()
190 | self._test_ds = tfds.load(name, split='test').cache()
191 | logging.info('Used test set instead of validation set.')
192 | else:
193 | # Validation split not implemented for multi-host training.
194 | assert jax.host_count() == 1
195 | self._train_ds = tfds.load(name, split='train[:45000]').cache()
196 | self._test_ds = tfds.load(name, split='train[45000:]').cache()
197 | self.num_training_obs = 45000
198 | logging.info('Used validation set instead of test set.')
199 | self._augmentation = image_level_augmentations
200 | if batch_level_augmentations == 'cutout':
201 | self._batch_level_augmentations = augmentation.cutout
202 | elif batch_level_augmentations == 'mixup':
203 | self._batch_level_augmentations = augmentation.mixup
204 | elif batch_level_augmentations == 'mixcut':
205 | self._batch_level_augmentations = (
206 | lambda x: augmentation.cutout(augmentation.mixup(x)))
207 | else:
208 | self._batch_level_augmentations = None
209 | if name == 'cifar10':
210 | self._image_mean = tf.constant([[[0.49139968, 0.48215841, 0.44653091]]])
211 | self._image_std = tf.constant([[[0.24703223, 0.24348513, 0.26158784]]])
212 | else:
213 | self._image_mean = tf.constant([[[0.50707516, 0.48654887, 0.44091784]]])
214 | self._image_std = tf.constant([[[0.26733429, 0.25643846, 0.27615047]]])
215 | self._num_classes = None # To define in child classes
216 |
217 |
218 | class Cifar10(CifarDatasetSource):
219 | """Cifar10 DatasetSource."""
220 |
221 | def __init__(self, batch_size: int, image_level_augmentations: str,
222 |                batch_level_augmentations: str, image_size: Optional[int] = None):
223 | """See parent class for more information."""
224 | super().__init__(batch_size, 'cifar10', image_level_augmentations,
225 | batch_level_augmentations, image_size)
226 | self._num_classes = 10
227 | self._dataset_name = 'cifar10'
228 |
229 |
230 | class Cifar100(CifarDatasetSource):
231 | """Cifar100 DatasetSource."""
232 |
233 | def __init__(self, batch_size: int, image_level_augmentations: str,
234 |                batch_level_augmentations: str, image_size: Optional[int] = None):
235 | """See parent class for more information."""
236 | super().__init__(batch_size, 'cifar100', image_level_augmentations,
237 | batch_level_augmentations, image_size)
238 | self._num_classes = 100
239 | self._dataset_name = 'cifar100'
240 |
241 |
242 | class FashionMnist(TFDSDatasetSource):
243 | """Fashion Mnist dataset."""
244 |
245 | def __init__(self, batch_size: int, image_level_augmentations: str,
246 | batch_level_augmentations: str):
247 | """Instantiates the DatasetSource.
248 |
249 | Args:
250 | batch_size: Batch size to use for training and evaluation.
251 | image_level_augmentations: Augmentations to apply to the images. Should be
252 | one of:
253 | * none: No augmentations are applied.
254 | * basic: Applies random crops and horizontal translations.
255 | batch_level_augmentations: Augmentations to apply at the batch level.
256 | * none: No augmentations are applied.
257 | * cutout: Applies cutout (https://arxiv.org/abs/1708.04552).
258 | """
259 | assert image_level_augmentations in ['none', 'basic']
260 | assert batch_level_augmentations in ['none', 'cutout']
261 | self.batch_size = batch_size
262 | self._image_size = None
263 | if FLAGS.use_test_set:
264 | self._train_ds = tfds.load('fashion_mnist', split='train').cache()
265 | self._test_ds = tfds.load('fashion_mnist', split='test').cache()
266 | logging.info('Used test set instead of validation set.')
267 | self.num_training_obs = 60000
268 | else:
269 | self._train_ds = tfds.load('fashion_mnist', split='train[:54000]').cache()
270 | self._test_ds = tfds.load('fashion_mnist', split='train[54000:]').cache()
271 | self.num_training_obs = 54000
272 | logging.info('Used validation set instead of test set.')
273 | self._augmentation = image_level_augmentations
274 | if batch_level_augmentations == 'cutout':
275 | self._batch_level_augmentations = augmentation.cutout
276 | else:
277 | self._batch_level_augmentations = None
278 | self._image_mean = tf.constant([[[0.1307]]])
279 | self._image_std = tf.constant([[[0.3081]]])
280 | self._num_classes = 10
281 | self._dataset_name = 'fashion_mnist'
282 |
283 |
284 | class SVHN(TFDSDatasetSource):
285 | """SVHN dataset."""
286 |
287 | def __init__(self, batch_size: int, image_level_augmentations: str,
288 | batch_level_augmentations: str):
289 | """Instantiates the DatasetSource.
290 |
291 | Args:
292 | batch_size: Batch size to use for training and evaluation.
293 | image_level_augmentations: Augmentations to apply to the images. Should be
294 | one of:
295 | * none: No augmentations are applied.
296 | * basic: Applies random crops and horizontal translations.
297 | * autoaugment: Applies the best found policy for SVHN from the
298 | AutoAugment paper. Also applies the basic augmentations on top of it.
299 | * aa-only: Same as autoaugment but doesn't apply the basic
300 | augmentations. Should be preferred for SVHN.
301 | batch_level_augmentations: Augmentations to apply at the batch level.
302 | * none: No augmentations are applied.
303 | * cutout: Applies cutout (https://arxiv.org/abs/1708.04552).
304 | """
305 | assert image_level_augmentations in [
306 | 'none', 'basic', 'autoaugment', 'aa-only']
307 | assert batch_level_augmentations in ['none', 'cutout']
308 | self.batch_size = batch_size
309 | self._image_size = None
310 | if FLAGS.use_test_set:
311 | ds_base = tfds.load('svhn_cropped', split='train')
312 | ds_extra = tfds.load('svhn_cropped', split='extra')
313 | self._train_ds = ds_base.concatenate(ds_extra).cache()
314 | self._test_ds = tfds.load('svhn_cropped', split='test').cache()
315 | logging.info('Used test set instead of validation set.')
316 | self.num_training_obs = 73257+531131
317 | else:
318 | ds_base = tfds.load('svhn_cropped', split='train[:65929]')
319 | ds_extra = tfds.load('svhn_cropped', split='extra')
320 | self._train_ds = ds_base.concatenate(ds_extra).cache()
321 | self._test_ds = tfds.load('svhn_cropped', split='train[65929:]').cache()
322 | self.num_training_obs = 65929+531131
323 | logging.info('Used validation set instead of test set.')
324 | self._augmentation = image_level_augmentations
325 | if batch_level_augmentations == 'cutout':
326 | self._batch_level_augmentations = augmentation.cutout
327 | else:
328 | self._batch_level_augmentations = None
329 | self._image_mean = tf.constant([[[0.43090966, 0.4302428, 0.44634357]]])
330 | self._image_std = tf.constant([[[0.19759192, 0.20029082, 0.19811132]]])
331 | self._num_classes = 10
332 | self._dataset_name = 'svhn'
333 |
--------------------------------------------------------------------------------
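Tying the above together, a typical consumer (mirroring dataset_source_test.py
below) builds a source and iterates its batched splits. Note that this module
and augmentation.py define absl flags (use_test_set, cutout_length), so the
sketch below runs inside an absl app so that the flags are parsed:

    from absl import app
    from sam.sam_jax.datasets import dataset_source

    def main(_):
      # Batches of 2, AutoAugment at the image level, cutout at the batch level.
      source = dataset_source.Cifar10(
          2,
          image_level_augmentations='autoaugment',
          batch_level_augmentations='cutout')
      for batch in source.get_train(use_augmentations=True):
        print(batch['image'].shape)  # (2, 32, 32, 3)
        print(batch['label'].shape)  # (2, 10), one-hot labels.
        break

    if __name__ == '__main__':
      app.run(main)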
/sam_jax/datasets/dataset_source_imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Imagenet DatasetSource.
16 |
17 | Initially forked from:
18 | https://github.com/google/flax/blob/master/examples/imagenet/input_pipeline.py
19 | """
20 |
21 | from typing import Dict, Tuple
22 |
23 | from absl import flags
24 | from absl import logging
25 | import jax
26 | from sam.autoaugment import autoaugment
27 | from sam.sam_jax.datasets import dataset_source
28 | import tensorflow as tf
29 | import tensorflow_datasets as tfds
30 | import tensorflow_probability as tfp
31 |
32 |
33 | FLAGS = flags.FLAGS
34 |
35 |
36 | flags.DEFINE_integer('randaug_num_layers', 2,
37 | 'Number of augmentations applied to each images by '
38 | 'RandAugment. Typical value is 2 and is generally not '
39 | 'changed.')
40 | flags.DEFINE_integer('randaug_magnitude', 9,
41 | 'Magnitude of augmentations applied by RandAugment.')
42 | flags.DEFINE_float('imagenet_mixup_alpha', 0.0, 'If > 0, use mixup.')
43 |
44 |
45 | TRAIN_IMAGES = 1281167
46 | EVAL_IMAGES = 50000
47 |
48 | IMAGE_SIZE = 224
49 | CROP_PADDING = 32
50 | MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
51 | STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
52 |
53 |
54 | def _distorted_bounding_box_crop(image_bytes: tf.Tensor,
55 | bbox: tf.Tensor,
56 | min_object_covered: float = 0.1,
57 | aspect_ratio_range: Tuple[float,
58 | float] = (0.75,
59 | 1.33),
60 | area_range: Tuple[float, float] = (0.05, 1.0),
61 | max_attempts: int = 100) -> tf.Tensor:
62 | """Generates cropped_image using one of the bboxes randomly distorted.
63 |
64 | See `tf.image.sample_distorted_bounding_box` for more documentation.
65 |
66 | Args:
67 | image_bytes: `Tensor` of binary image data.
68 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
69 | where each coordinate is [0, 1) and the coordinates are arranged
70 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
71 | image.
72 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
73 | area of the image must contain at least this fraction of any bounding
74 | box supplied.
75 | aspect_ratio_range: An optional list of `float`s. The cropped area of the
76 | image must have an aspect ratio = width / height within this range.
77 | area_range: An optional list of `float`s. The cropped area of the image
78 |       must contain a fraction of the supplied image within this range.
79 | max_attempts: An optional `int`. Number of attempts at generating a cropped
80 | region of the image of the specified constraints. After `max_attempts`
81 | failures, return the entire image.
82 | Returns:
83 | cropped image `Tensor`
84 | """
85 | shape = tf.image.extract_jpeg_shape(image_bytes)
86 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
87 | shape,
88 | bounding_boxes=bbox,
89 | min_object_covered=min_object_covered,
90 | aspect_ratio_range=aspect_ratio_range,
91 | area_range=area_range,
92 | max_attempts=max_attempts,
93 | use_image_if_no_bounding_boxes=True)
94 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box
95 |
96 | # Crop the image to the specified bounding box.
97 | offset_y, offset_x, _ = tf.unstack(bbox_begin)
98 | target_height, target_width, _ = tf.unstack(bbox_size)
99 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
100 | image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
101 |
102 | return image
103 |
104 |
105 | def _resize(image: tf.Tensor, image_size: int) -> tf.Tensor:
106 | """Returns the resized image."""
107 | return tf.compat.v1.image.resize_bicubic([image], [image_size, image_size])[0]
108 |
109 |
110 | def _at_least_x_are_equal(a: tf.Tensor, b: tf.Tensor,
111 | x: int) -> tf.Tensor:
112 |   """Returns whether at least `x` elements of `a` and `b` are equal."""
113 | match = tf.equal(a, b)
114 | match = tf.cast(match, tf.int32)
115 | return tf.greater_equal(tf.reduce_sum(match), x)
116 |
117 |
118 | def _decode_and_random_crop(image_bytes: tf.Tensor,
119 | image_size: int) -> tf.Tensor:
120 | """Make a random crop of image_size."""
121 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
122 | image = _distorted_bounding_box_crop(
123 | image_bytes,
124 | bbox,
125 | min_object_covered=0.1,
126 | aspect_ratio_range=(3. / 4, 4. / 3.),
127 | area_range=(0.08, 1.0),
128 | max_attempts=10)
129 | original_shape = tf.image.extract_jpeg_shape(image_bytes)
130 | bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
131 |
132 | image = tf.cond(bad, lambda: _decode_and_center_crop(image_bytes, image_size),
133 | lambda: _resize(image, image_size))
134 |
135 | return image
136 |
137 |
138 | def _decode_and_center_crop(image_bytes: tf.Tensor,
139 | image_size: int) -> tf.Tensor:
140 |   """Center-crops the image with padding, then resizes to image_size."""
141 | shape = tf.image.extract_jpeg_shape(image_bytes)
142 | image_height = shape[0]
143 | image_width = shape[1]
144 |
145 | padded_center_crop_size = tf.cast(
146 | ((image_size / (image_size + CROP_PADDING)) *
147 | tf.cast(tf.minimum(image_height, image_width), tf.float32)),
148 | tf.int32)
149 |
150 | offset_height = ((image_height - padded_center_crop_size) + 1) // 2
151 | offset_width = ((image_width - padded_center_crop_size) + 1) // 2
152 | crop_window = tf.stack([offset_height, offset_width,
153 | padded_center_crop_size, padded_center_crop_size])
154 | image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
155 | image = _resize(image, image_size)
156 |
157 | return image
158 |
159 |
160 | def normalize_image(image: tf.Tensor) -> tf.Tensor:
161 | """Returns the normalized image.
162 |
163 |   The image is normalized so that the mean and variance of each channel over
164 |   the dataset are 0 and 1.
165 |
166 | Args:
167 | image: An image from the Imagenet dataset to normalize.
168 | """
169 | image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)
170 | image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)
171 | return image
172 |
173 |
174 | def preprocess_for_train(image_bytes: tf.Tensor,
175 | dtype: tf.DType = tf.float32,
176 | image_size: int = IMAGE_SIZE,
177 | use_autoaugment: bool = False) -> tf.Tensor:
178 | """Preprocesses the given image for training.
179 |
180 | Args:
181 | image_bytes: `Tensor` representing an image binary of arbitrary size.
182 | dtype: Data type of the returned image.
183 | image_size: Size of the returned image.
184 | use_autoaugment: If True, will apply autoaugment to the inputs.
185 |
186 | Returns:
187 | A preprocessed image `Tensor`.
188 | """
189 | image = _decode_and_random_crop(image_bytes, image_size)
190 | image = tf.reshape(image, [image_size, image_size, 3])
191 | if use_autoaugment:
192 | logging.info('Using autoaugment.')
193 | image = tf.cast(image, tf.uint8)
194 | image = autoaugment.distort_image_with_randaugment(image,
195 | FLAGS.randaug_num_layers,
196 | FLAGS.randaug_magnitude)
197 | image = tf.cast(image, tf.float32)
198 | image = tf.image.random_flip_left_right(image)
199 | image = normalize_image(image)
200 | image = tf.image.convert_image_dtype(image, dtype=dtype)
201 | return image
202 |
203 |
204 | def preprocess_for_eval(image_bytes: tf.Tensor,
205 | dtype: tf.DType = tf.float32,
206 | image_size: int = IMAGE_SIZE) -> tf.Tensor:
207 | """Preprocesses the given image for evaluation.
208 |
209 | Args:
210 | image_bytes: `Tensor` representing an image binary of arbitrary size.
211 | dtype: Data type of the returned image.
212 | image_size: Size of the returned image.
213 |
214 | Returns:
215 | A preprocessed image `Tensor`.
216 | """
217 | image = _decode_and_center_crop(image_bytes, image_size)
218 | image = tf.reshape(image, [image_size, image_size, 3])
219 | image = normalize_image(image)
220 | image = tf.image.convert_image_dtype(image, dtype=dtype)
221 | return image
222 |
223 |
224 | def load_split(train: bool,
225 | cache: bool) -> tf.data.Dataset:
226 | """Creates a split from the ImageNet dataset using TensorFlow Datasets.
227 |
228 | Args:
229 | train: Whether to load the train or evaluation split.
230 | cache: Whether to cache the dataset.
231 | Returns:
232 | A `tf.data.Dataset`.
233 | """
234 | if train:
235 | split_size = TRAIN_IMAGES // jax.host_count()
236 | start = jax.host_id() * split_size
237 | split = 'train[{}:{}]'.format(start, start + split_size)
238 | else:
239 | # For validation, we load up the dataset on each host. This will have the
240 | # effect of evaluating on the whole dataset num_host times, but will
241 | # prevent size issues. This makes the performance slightly worse when
242 | # evaluating often, but spares us the need to pad the datasets and mask the
243 | # loss accordingly.
244 | split = 'validation'
245 |
246 | ds = tfds.load('imagenet2012:5.*.*', split=split, decoders={
247 | 'image': tfds.decode.SkipDecoding(),
248 | })
249 | ds.options().experimental_threading.private_threadpool_size = 48
250 | ds.options().experimental_threading.max_intra_op_parallelism = 1
251 |
252 | if cache:
253 | ds = ds.cache()
254 |
255 | return ds
256 |
257 |
258 | def mixup(batch: Dict[str, tf.Tensor], alpha: float) -> Dict[str, tf.Tensor]:
259 | """Generates augmented images using Mixup.
260 |
261 | Arguments:
262 | batch: Feature dict containing the images and the labels.
263 | alpha: Float that controls the strength of Mixup regularization.
264 |
265 | Returns:
266 |     A feature dict containing the images and labels augmented with Mixup.
267 | """
268 | images, labels = batch['image'], batch['label']
269 | batch_size = 1 # Unique mixing parameter for all samples
270 | mix_weight = tfp.distributions.Beta(alpha, alpha).sample([batch_size, 1])
271 | mix_weight = tf.maximum(mix_weight, 1. - mix_weight)
272 | images_mix_weight = tf.reshape(mix_weight, [batch_size, 1, 1, 1])
273 | images_mix = (
274 | images * images_mix_weight + images[::-1] * (1. - images_mix_weight))
275 | labels_mix = labels * mix_weight + labels[::-1] * (1. - mix_weight)
276 | return {'image': images_mix, 'label': labels_mix}
277 |
278 |
279 | class Imagenet(dataset_source.DatasetSource):
280 | """Class that loads, preprocess and serves the Imagenet dataset."""
281 |
282 | def __init__(self, batch_size: int, image_size: int,
283 | image_level_augmentations: str = 'none'):
284 | """Instantiates the Imagenet dataset source.
285 |
286 | Args:
287 | batch_size: Global batch size used to train the model.
288 | image_size: Size to which the images should be resized (in number of
289 | pixels).
290 | image_level_augmentations: If set to 'autoaugment', will apply
291 | RandAugment to the training set.
292 | """
293 | self.batch_size = batch_size
294 | self.image_size = image_size
295 | self.num_training_obs = TRAIN_IMAGES
296 | self._train_ds = load_split(train=True, cache=True)
297 | self._test_ds = load_split(train=False, cache=True)
298 | self._num_classes = 1000
299 | self._image_level_augmentations = image_level_augmentations
300 |
301 | def get_train(self, use_augmentations: bool) -> tf.data.Dataset:
302 | """Returns the training set.
303 |
304 | The training set will be batched, and the remainder of the batch will be
305 |     dropped (except if use_augmentations is False, in which case we don't drop
306 | the remainder as we are most likely computing the accuracy on the train
307 | set).
308 |
309 | Args:
310 | use_augmentations: Whether we should apply data augmentation (and possibly
311 | cutout) or not.
312 | """
313 | ds = self._train_ds.shuffle(16 * self.batch_size)
314 | ds = ds.map(lambda d: self.decode_example( # pylint:disable=g-long-lambda
315 | d, use_augmentations=use_augmentations),
316 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
317 |
318 | batched = ds.batch(self.batch_size, drop_remainder=use_augmentations)
319 | if use_augmentations and FLAGS.imagenet_mixup_alpha > 0.0:
320 | batched = batched.map(lambda b: mixup(b, FLAGS.imagenet_mixup_alpha),
321 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
322 | return batched
323 |
324 | def get_test(self) -> tf.data.Dataset:
325 | """Returns test set."""
326 | ds = self._test_ds.map(
327 | lambda d: self.decode_example( # pylint:disable=g-long-lambda
328 | d, use_augmentations=False),
329 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
330 | return ds.batch(self.batch_size, drop_remainder=False)
331 |
332 | def decode_example(self, example: Dict[str, tf.Tensor],
333 | use_augmentations: bool) -> Dict[str, tf.Tensor]:
334 | """Decodes the raw examples from the imagenet tensorflow dataset.
335 |
336 | Args:
337 | example: A feature dict as returned by the tensorflow imagenet dataset.
338 | use_augmentations: Whether to use train time data augmentation or not.
339 |
340 | Returns:
341 |       A dictionary with an 'image' tensor and a one-hot encoded 'label' tensor.
342 | """
343 | if use_augmentations:
344 | image = preprocess_for_train(
345 | example['image'],
346 | image_size=self.image_size,
347 | use_autoaugment=self._image_level_augmentations == 'autoaugment')
348 | else:
349 | image = preprocess_for_eval(example['image'], image_size=self.image_size)
350 | label = tf.one_hot(
351 | example['label'], depth=self._num_classes, on_value=1.0, off_value=0.0)
352 | return {'image': image, 'label': label}
353 |
--------------------------------------------------------------------------------
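A quick numeric check of the evaluation-time crop above:
_decode_and_center_crop keeps the fixed fraction
image_size / (image_size + CROP_PADDING) of the image's shorter side before
resizing, i.e. 224 / 256 = 87.5% with the defaults. Restating that arithmetic
in plain Python (center_crop_size is just a local helper for this check):

    IMAGE_SIZE = 224
    CROP_PADDING = 32

    def center_crop_size(height, width, image_size=IMAGE_SIZE):
      # Same formula as _decode_and_center_crop: a fraction of the shorter side.
      return int((image_size / (image_size + CROP_PADDING)) * min(height, width))

    print(center_crop_size(400, 600))  # 350: a 350x350 crop, resized to 224x224.
    print(center_crop_size(256, 256))  # 224: already exactly the target size.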
/sam_jax/datasets/dataset_source_imagenet_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sam.sam_jax.datasets.dataset_source_imagenet."""
16 |
17 | from absl.testing import absltest
18 | from sam.sam_jax.datasets import dataset_source_imagenet
19 |
20 |
21 | class DatasetSourceImagenetTest(absltest.TestCase):
22 |
23 | def test_LoadImagenet(self):
24 | dataset = dataset_source_imagenet.Imagenet(
25 | batch_size=16, image_size=127, image_level_augmentations='autoaugment')
26 | batch = next(iter(dataset.get_train(use_augmentations=True)))
27 | self.assertEqual(batch['image'].shape, [16, 127, 127, 3])
28 | self.assertEqual(batch['label'].shape, [16, 1000])
29 |
30 |
31 | if __name__ == '__main__':
32 | absltest.main()
33 |
--------------------------------------------------------------------------------
/sam_jax/datasets/dataset_source_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sam.sam_jax.datasets.dataset_source."""
16 |
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from sam.sam_jax.datasets import dataset_source
21 |
22 |
23 | class DatasetSourceTest(parameterized.TestCase):
24 |
25 | @parameterized.named_parameters(
26 | ('none', 'none'),
27 | ('cutout', 'cutout'))
28 | def test_LoadCifar10(self, batch_level_augmentation: str):
29 | cifar_10_source = dataset_source.Cifar10(
30 | 2,
31 | image_level_augmentations='autoaugment',
32 | batch_level_augmentations=batch_level_augmentation)
33 | for batch in cifar_10_source.get_train(use_augmentations=True):
34 | self.assertEqual(batch['image'].shape, [2, 32, 32, 3])
35 | self.assertEqual(batch['label'].shape, [2, 10])
36 | break
37 | for batch in cifar_10_source.get_test():
38 | self.assertEqual(batch['image'].shape, [2, 32, 32, 3])
39 | self.assertEqual(batch['label'].shape, [2, 10])
40 | break
41 |
42 | @parameterized.named_parameters(
43 | ('none', 'none'),
44 | ('cutout', 'cutout'))
45 | def test_LoadCifar100(self, batch_level_augmentation: str):
46 | cifar_100_source = dataset_source.Cifar100(
47 | 2,
48 | image_level_augmentations='autoaugment',
49 | batch_level_augmentations=batch_level_augmentation)
50 | for batch in cifar_100_source.get_train(use_augmentations=True):
51 | self.assertEqual(batch['image'].shape, [2, 32, 32, 3])
52 | self.assertEqual(batch['label'].shape, [2, 100])
53 | break
54 | for batch in cifar_100_source.get_test():
55 | self.assertEqual(batch['image'].shape, [2, 32, 32, 3])
56 | self.assertEqual(batch['label'].shape, [2, 100])
57 | break
58 |
59 | @parameterized.named_parameters(
60 | ('none', 'none'),
61 | ('cutout', 'cutout'))
62 | def test_LoadFashionMnist(self, batch_level_augmentation: str):
63 | fashion_mnist_source = dataset_source.FashionMnist(
64 | 2,
65 | image_level_augmentations='basic',
66 | batch_level_augmentations=batch_level_augmentation)
67 | for batch in fashion_mnist_source.get_train(use_augmentations=True):
68 | self.assertEqual(batch['image'].shape, [2, 28, 28, 1])
69 | self.assertEqual(batch['label'].shape, [2, 10])
70 | break
71 | for batch in fashion_mnist_source.get_test():
72 | self.assertEqual(batch['image'].shape, [2, 28, 28, 1])
73 | self.assertEqual(batch['label'].shape, [2, 10])
74 | break
75 |
76 |
77 | if __name__ == '__main__':
78 | absltest.main()
79 |
--------------------------------------------------------------------------------
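Like the other test modules in this repository, the suite hands control to
absltest.main(), so it can presumably be run directly once the repository is
importable as the `sam` package and TFDS can provide the datasets:

    python -m sam.sam_jax.datasets.dataset_source_test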
/sam_jax/efficientnet/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/sam_jax/efficientnet/efficientnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Defines Efficientnet model."""
16 |
17 | import copy
18 | import math
19 | from typing import Any, Optional, Tuple, Union
20 |
21 | from absl import flags
22 | from absl import logging
23 | import flax
24 | from flax import nn
25 | import jax
26 | from jax import numpy as jnp
27 | import tensorflow as tf
28 |
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | def name_to_image_size(name: str) -> int:
34 | """Returns the expected image size for a given model.
35 |
36 |   If the model is not a recognized efficientnet model, defaults to the
37 |   standard resolution of 224 (for Resnet, etc.).
38 |
39 | Args:
40 | name: Name of the efficientnet model (ex: efficientnet-b0).
41 | """
42 | image_sizes = {
43 | 'efficientnet-b0': 224,
44 | 'efficientnet-b1': 240,
45 | 'efficientnet-b2': 260,
46 | 'efficientnet-b3': 300,
47 | 'efficientnet-b4': 380,
48 | 'efficientnet-b5': 456,
49 | 'efficientnet-b6': 528,
50 | 'efficientnet-b7': 600,
51 | 'efficientnet-b8': 672,
52 | 'efficientnet-l2': 800,
53 | 'efficientnet-l2-475': 475,
54 | }
55 | return image_sizes.get(name, 224)
56 |
57 |
58 | # Relevant initializers. The original implementation uses fan_out Kaiming init.
59 |
60 | conv_kernel_init_fn = jax.nn.initializers.variance_scaling(
61 | 2.0, 'fan_out', 'truncated_normal')
62 |
63 | dense_kernel_init_fn = jax.nn.initializers.variance_scaling(
64 | 1 / 3.0, 'fan_out', 'uniform')
65 |
66 |
67 | class DepthwiseConv(flax.nn.Module):
68 | """Depthwise convolution that matches tensorflow's conventions.
69 |
70 | In Tensorflow, the shapes of depthwise kernels don't match the shapes of a
71 | regular convolutional kernel of appropriate feature_group_count.
72 | It is safer to use this class instead of the regular Conv (easier port of
73 | tensorflow checkpoints, fan_out initialization of the previous layer will
74 |   match the tensorflow behavior, etc.).
75 | """
76 |
77 | def apply(self,
78 | inputs: jnp.ndarray,
79 | features: int,
80 | kernel_size: Tuple[int, int],
81 |             strides: Optional[Tuple[int, int]] = None,
82 |             padding: str = 'SAME',
83 |             input_dilation: Optional[Tuple[int, int]] = None,
84 |             kernel_dilation: Optional[Tuple[int, int]] = None,
85 | bias: bool = True,
86 | dtype: jnp.dtype = jnp.float32,
87 | precision=None,
88 | kernel_init=flax.nn.initializers.lecun_normal(),
89 | bias_init=flax.nn.initializers.zeros) -> jnp.ndarray:
90 | """Applies a convolution to the inputs.
91 |
92 | Args:
93 | inputs: Input data with dimensions (batch, spatial_dims..., features).
94 | features: Number of convolution filters.
95 | kernel_size: Shape of the convolutional kernel.
96 | strides: A sequence of `n` integers, representing the inter-window
97 | strides.
98 | padding: Either the string `'SAME'`, the string `'VALID'`, or a sequence
99 | of `n` `(low, high)` integer pairs that give the padding to apply before
100 | and after each spatial dimension.
101 | input_dilation: `None`, or a sequence of `n` integers, giving the
102 | dilation factor to apply in each spatial dimension of `inputs`.
103 | Convolution with input dilation `d` is equivalent to transposed
104 | convolution with stride `d`.
105 | kernel_dilation: `None`, or a sequence of `n` integers, giving the
106 | dilation factor to apply in each spatial dimension of the convolution
107 | kernel. Convolution with kernel dilation is also known as 'atrous
108 | convolution'.
109 | bias: Whether to add a bias to the output (default: True).
110 | dtype: The dtype of the computation (default: float32).
111 | precision: Numerical precision of the computation see `jax.lax.Precision`
112 | for details.
113 | kernel_init: Initializer for the convolutional kernel.
114 | bias_init: Initializer for the bias.
115 |
116 | Returns:
117 | The convolved data.
118 | """
119 |
120 | inputs = jnp.asarray(inputs, dtype)
121 | in_features = inputs.shape[-1]
122 |
123 | if strides is None:
124 | strides = (1,) * (inputs.ndim - 2)
125 |
126 | kernel_shape = kernel_size + (features, 1)
127 | # Naming convention follows tensorflow.
128 | kernel = self.param('depthwise_kernel', kernel_shape, kernel_init)
129 | kernel = jnp.asarray(kernel, dtype)
130 |
131 | # Need to transpose to convert tensorflow-shaped kernel to lax-shaped kernel
132 | kernel = jnp.transpose(kernel, [0, 1, 3, 2])
133 |
134 | dimension_numbers = flax.nn.linear._conv_dimension_numbers(inputs.shape) # pylint:disable=protected-access
135 |
136 | y = jax.lax.conv_general_dilated(
137 | inputs,
138 | kernel,
139 | strides,
140 | padding,
141 | lhs_dilation=input_dilation,
142 | rhs_dilation=kernel_dilation,
143 | dimension_numbers=dimension_numbers,
144 | feature_group_count=in_features,
145 | precision=precision)
146 |
147 | if bias:
148 | bias = self.param('bias', (features,), bias_init)
149 | bias = jnp.asarray(bias, dtype)
150 | y = y + bias
151 | return y
152 |
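# A minimal sketch of the layout handling above (illustrative, not part of
# the module): tensorflow stores a depthwise kernel as (H, W, in_features,
# channel_multiplier), while lax.conv_general_dilated with
# feature_group_count=in_features expects (H, W, 1, in_features), hence the
# transpose in `apply`:
#
#   import jax.numpy as jnp
#   tf_kernel = jnp.zeros((3, 3, 16, 1))                 # tensorflow layout
#   lax_kernel = jnp.transpose(tf_kernel, [0, 1, 3, 2])  # (3, 3, 1, 16)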
153 |
154 | # pytype: disable=attribute-error
155 | # pylint:disable=unused-argument
156 | class BlockConfig(object):
157 | """Class that contains configuration parameters for a single block."""
158 |
159 | def __init__(self,
160 | input_filters: int = 0,
161 | output_filters: int = 0,
162 | kernel_size: int = 3,
163 | num_repeat: int = 1,
164 | expand_ratio: int = 1,
165 | strides: Tuple[int, int] = (1, 1),
166 | se_ratio: Optional[float] = None,
167 | id_skip: bool = True,
168 | fused_conv: bool = False,
169 | conv_type: str = 'depthwise'):
170 | for arg in locals().items():
171 | setattr(self, *arg)
172 |
173 |
174 | class ModelConfig(object):
175 | """Class that contains configuration parameters for the model."""
176 |
177 | def __init__(
178 | self,
179 | width_coefficient: float = 1.0,
180 | depth_coefficient: float = 1.0,
181 | resolution: int = 224,
182 | dropout_rate: float = 0.2,
183 | blocks: Tuple[BlockConfig, ...] = (
184 | # (input_filters, output_filters, kernel_size, num_repeat,
185 | # expand_ratio, strides, se_ratio)
186 | # pylint: disable=bad-whitespace
187 | BlockConfig(32, 16, 3, 1, 1, (1, 1), 0.25),
188 | BlockConfig(16, 24, 3, 2, 6, (2, 2), 0.25),
189 | BlockConfig(24, 40, 5, 2, 6, (2, 2), 0.25),
190 | BlockConfig(40, 80, 3, 3, 6, (2, 2), 0.25),
191 | BlockConfig(80, 112, 5, 3, 6, (1, 1), 0.25),
192 | BlockConfig(112, 192, 5, 4, 6, (2, 2), 0.25),
193 | BlockConfig(192, 320, 3, 1, 6, (1, 1), 0.25),
194 | # pylint: enable=bad-whitespace
195 | ),
196 | stem_base_filters: int = 32,
197 | top_base_filters: int = 1280,
198 | activation: str = 'swish',
199 | batch_norm: str = 'default',
200 | bn_momentum: float = 0.99,
201 | bn_epsilon: float = 1e-3,
202 | # While the original implementation used a weight decay of 1e-5,
203 |       # tf.nn.l2_loss divides it by 2, so we halve the value to compensate.
204 | weight_decay: float = 5e-6,
205 | drop_connect_rate: float = 0.2,
206 | depth_divisor: int = 8,
207 | min_depth: Optional[int] = None,
208 | use_se: bool = True,
209 | input_channels: int = 3,
210 | model_name: str = 'efficientnet',
211 | rescale_input: bool = True,
212 | data_format: str = 'channels_last',
213 | dtype: str = 'float32'):
214 | """Default Config for Efficientnet-B0."""
215 | for arg in locals().items():
216 | setattr(self, *arg)
217 | # pylint:enable=unused-argument
218 |
219 |
220 | MODEL_CONFIGS = {
221 | # (width, depth, resolution, dropout)
222 | 'efficientnet-b0': ModelConfig(1.0, 1.0, 224, 0.2),
223 | 'efficientnet-b1': ModelConfig(1.0, 1.1, 240, 0.2),
224 | 'efficientnet-b2': ModelConfig(1.1, 1.2, 260, 0.3),
225 | 'efficientnet-b3': ModelConfig(1.2, 1.4, 300, 0.3),
226 | 'efficientnet-b4': ModelConfig(1.4, 1.8, 380, 0.4),
227 | 'efficientnet-b5': ModelConfig(1.6, 2.2, 456, 0.4),
228 | 'efficientnet-b6': ModelConfig(1.8, 2.6, 528, 0.5),
229 | 'efficientnet-b7': ModelConfig(2.0, 3.1, 600, 0.5),
230 | 'efficientnet-b8': ModelConfig(2.2, 3.6, 672, 0.5),
231 | 'efficientnet-l2': ModelConfig(4.3, 5.3, 800, 0.5),
232 | 'efficientnet-l2-475': ModelConfig(4.3, 5.3, 475, 0.5),
233 | }
234 |
235 |
236 | def round_filters(filters: int,
237 | config: ModelConfig) -> int:
238 | """Returns rounded number of filters based on width coefficient."""
239 | width_coefficient = config.width_coefficient
240 | min_depth = config.min_depth
241 | divisor = config.depth_divisor
242 | orig_filters = filters
243 |
244 | if not width_coefficient:
245 | return filters
246 |
247 | filters *= width_coefficient
248 | min_depth = min_depth or divisor
249 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
250 | # Make sure that round down does not go down by more than 10%.
251 | if new_filters < 0.9 * filters:
252 | new_filters += divisor
253 | logging.info('round_filter input=%s output=%s', orig_filters, new_filters)
254 | return int(new_filters)
255 |
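# Worked example of the rounding above (illustrative): for efficientnet-b2
# (width_coefficient=1.1, depth_divisor=8), 32 base filters scale to
# 32 * 1.1 = 35.2, which snaps to the nearest multiple of 8:
#   int(35.2 + 8 / 2) // 8 * 8 = 39 // 8 * 8 = 32
# Since 32 >= 0.9 * 35.2, no correction is added and 32 is returned.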
256 |
257 | def round_repeats(repeats: int, depth_coefficient: float) -> int:
258 | """Returns rounded number of repeats based on depth coefficient."""
259 | return int(math.ceil(depth_coefficient * repeats))
260 |
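# Worked example (illustrative): efficientnet-b3 uses depth_coefficient=1.4,
# so a block with num_repeat=2 in the base architecture is repeated
# ceil(1.4 * 2) = ceil(2.8) = 3 times.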
261 |
262 | def conv2d(inputs: jnp.ndarray,
263 | conv_filters: Optional[int],
264 | config: ModelConfig,
265 | kernel_size: Union[int, Tuple[int, int]] = (1, 1),
266 | strides: Tuple[int, int] = (1, 1),
267 | use_batch_norm: bool = True,
268 | use_bias: bool = False,
269 | activation: Any = None,
270 | depthwise: bool = False,
271 | train: bool = True,
272 |            conv_name: Optional[str] = None,
273 |            bn_name: Optional[str] = None) -> jnp.ndarray:
274 |   """Convolutional layer, optionally followed by batch norm and activation.
275 |
276 | Args:
277 | inputs: Input data with dimensions (batch, spatial_dims..., features).
278 | conv_filters: Number of convolution filters.
279 | config: Configuration for the model.
280 | kernel_size: Size of the kernel, as a tuple of int.
281 | strides: Strides for the convolution, as a tuple of int.
282 | use_batch_norm: Whether batch norm should be applied to the output.
283 | use_bias: Whether we should add bias to the output of the first convolution.
284 | activation: Name of the activation function to use.
285 | depthwise: If true, will use depthwise convolutions.
286 | train: Whether the model should behave in training or inference mode.
287 | conv_name: Name to give to the convolution layer.
288 | bn_name: Name to give to the batch norm layer.
289 |
290 | Returns:
291 | The output of the convolutional layer.
292 | """
293 | conv_fn = DepthwiseConv if depthwise else flax.nn.Conv
294 | kernel_size = ((kernel_size, kernel_size)
295 | if isinstance(kernel_size, int) else tuple(kernel_size))
296 | conv_name = conv_name if conv_name else 'conv2d'
297 | bn_name = bn_name if bn_name else 'batch_normalization'
298 |
299 | x = conv_fn(
300 | inputs,
301 | conv_filters,
302 | kernel_size,
303 | tuple(strides),
304 | padding='SAME',
305 | bias=use_bias,
306 | kernel_init=conv_kernel_init_fn,
307 | name=conv_name)
308 |
309 | if use_batch_norm:
310 | x = nn.BatchNorm(
311 | x,
312 | use_running_average=not train or FLAGS.from_pretrained_checkpoint,
313 | momentum=config.bn_momentum,
314 | epsilon=config.bn_epsilon,
315 | name=bn_name,
316 | axis_name='batch')
317 |
318 | if activation is not None:
319 | x = getattr(flax.nn.activation, activation.lower())(x)
320 | return x
321 |
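# Example usage (illustrative): a 3x3 swish-activated convolution with batch
# norm, as used throughout the blocks below:
#
#   y = conv2d(x, 32, config, kernel_size=3, strides=(1, 1),
#              activation='swish')
#
# The layer order is always convolution (bias-free unless use_bias is set),
# then optional batch norm, then the optional activation.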
322 |
323 | def stochastic_depth(inputs: jnp.ndarray,
324 | survival_probability: float,
325 | deterministic: bool = False,
326 | rng: Optional[jnp.ndarray] = None) -> jnp.ndarray:
327 | """Applies stochastic depth.
328 |
329 | Args:
330 | inputs: The inputs that should be randomly masked.
331 |     survival_probability: 1 minus the probability of masking out a value.
332 |     deterministic: If false, the inputs are scaled by `1 /
333 |       survival_probability` and masked, whereas if true, no mask is applied
334 |       and the inputs are returned as is.
335 | rng: An optional `jax.random.PRNGKey`. By default `nn.make_rng()` will
336 | be used.
337 |
338 | Returns:
339 | The masked inputs.
340 | """
341 | if survival_probability == 1.0 or deterministic:
342 | return inputs
343 |
344 | if rng is None:
345 | rng = flax.nn.make_rng()
346 |   mask_shape = [inputs.shape[0]] + [1 for _ in inputs.shape[1:]]
347 | mask = jax.random.bernoulli(rng, p=survival_probability, shape=mask_shape)
348 | mask = jnp.tile(mask, [1] + list(inputs.shape[1:]))
349 | return jax.lax.select(mask, inputs / survival_probability,
350 | jnp.zeros_like(inputs))
351 |
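# A minimal usage sketch (illustrative; passes an explicit rng instead of
# relying on flax.nn.make_rng):
#
#   import jax
#   import jax.numpy as jnp
#   x = jnp.ones((1000, 4))
#   out = stochastic_depth(x, survival_probability=0.8,
#                          rng=jax.random.PRNGKey(0))
#
# Each row is kept with probability 0.8 and scaled to 1 / 0.8 = 1.25, so the
# expectation of the output matches the input.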
352 |
353 | class SqueezeExcite(flax.nn.Module):
354 | """SqueezeExite block (see paper for more details.)"""
355 |
356 | def apply(self,
357 | x: jnp.ndarray,
358 | filters: int,
359 | block: BlockConfig,
360 | config: ModelConfig,
361 | train: bool) -> jnp.ndarray:
362 | """Applies a convolution to the inputs.
363 |
364 | Args:
365 | x: Input data with dimensions (batch, spatial_dims..., features).
366 | filters: Number of convolution filters.
367 | block: Configuration for this block.
368 | config: Configuration for the model.
369 | train: Whether the model is in training or inference mode.
370 |
371 | Returns:
372 | The output of the squeeze excite block.
373 | """
374 | conv_index = 0
375 | num_reduced_filters = max(1, int(block.input_filters * block.se_ratio))
376 |
377 | se = flax.nn.avg_pool(x, x.shape[1:3])
378 | se = conv2d(
379 | se,
380 | num_reduced_filters,
381 | config,
382 | use_bias=True,
383 | use_batch_norm=False,
384 | activation=config.activation,
385 | conv_name='reduce_conv2d_' + str(conv_index),
386 | train=train)
387 | conv_index += 1
388 |
389 | se = conv2d(
390 | se,
391 | filters,
392 | config,
393 | use_bias=True,
394 | use_batch_norm=False,
395 | activation='sigmoid',
396 | conv_name='expand_conv2d_' + str(conv_index),
397 | train=train)
398 | conv_index += 1
399 | x = x * se
400 | return x
401 |
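# In equations (illustrative summary): with z = avg_pool(x) reducing the
# spatial dimensions to 1x1, the block computes
#   x * sigmoid(conv_expand(act(conv_reduce(z))))
# where conv_reduce projects to max(1, int(input_filters * se_ratio)) channels
# (e.g. 24 for input_filters=96, se_ratio=0.25) and conv_expand restores
# `filters` channels.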
402 |
403 | class MBConvBlock(flax.nn.Module):
404 | """Main building component of Efficientnet."""
405 |
406 | def apply(self,
407 | inputs: jnp.ndarray,
408 | block: BlockConfig,
409 | config: ModelConfig,
410 | train: bool = False) -> jnp.ndarray:
411 | """Mobile Inverted Residual Bottleneck.
412 |
413 | Args:
414 | inputs: Input to the block.
415 | block: BlockConfig, arguments to create a Block.
416 | config: ModelConfig, a set of model parameters.
417 | train: Whether we are training or predicting.
418 |
419 | Returns:
420 | The output of the block.
421 | """
422 | use_se = config.use_se
423 | activation = config.activation
424 | drop_connect_rate = config.drop_connect_rate
425 | use_depthwise = block.conv_type != 'no_depthwise'
426 |
427 | filters = block.input_filters * block.expand_ratio
428 |
429 | x = inputs
430 | bn_index = 0
431 | conv_index = 0
432 |
433 | if block.fused_conv:
434 | # If we use fused mbconv, skip expansion and use regular conv.
435 | x = conv2d(
436 | x,
437 | filters,
438 | config,
439 | kernel_size=block.kernel_size,
440 | strides=block.strides,
441 | activation=activation,
442 | conv_name='fused_conv2d_' + str(conv_index),
443 | bn_name='batch_normalization_' + str(bn_index),
444 | train=train)
445 | bn_index += 1
446 | conv_index += 1
447 | else:
448 | if block.expand_ratio != 1:
449 | # Expansion phase
450 | kernel_size = (1, 1) if use_depthwise else (3, 3)
451 | x = conv2d(
452 | x,
453 | filters,
454 | config,
455 | kernel_size=kernel_size,
456 | activation=activation,
457 | conv_name='expand_conv2d_' + str(conv_index),
458 | bn_name='batch_normalization_' + str(bn_index),
459 | train=train)
460 | bn_index += 1
461 | conv_index += 1
462 | # Depthwise Convolution
463 | if use_depthwise:
464 | x = conv2d(x,
465 | conv_filters=x.shape[-1], # Depthwise conv
466 | config=config,
467 | kernel_size=block.kernel_size,
468 | strides=block.strides,
469 | activation=activation,
470 | depthwise=True,
471 | conv_name='depthwise_conv2d',
472 | bn_name='batch_normalization_' + str(bn_index),
473 | train=train)
474 | bn_index += 1
475 |
476 | # Squeeze and Excitation phase
477 | if use_se:
478 | assert block.se_ratio is not None
479 | assert 0 < block.se_ratio <= 1
480 | x = SqueezeExcite(x, filters, block, config, train=train)
481 |
482 | # Output phase
483 | x = conv2d(
484 | x,
485 | block.output_filters,
486 | config,
487 | activation=None,
488 | conv_name='project_conv2d_' + str(conv_index),
489 | bn_name='batch_normalization_' + str(bn_index),
490 | train=train)
491 | conv_index += 1
492 |
493 | if (block.id_skip and all(s == 1 for s in block.strides) and
494 | block.input_filters == block.output_filters):
495 | if drop_connect_rate and drop_connect_rate > 0:
496 | survival_probability = 1 - drop_connect_rate
497 | x = stochastic_depth(x, survival_probability, deterministic=not train)
498 | x = x + inputs
499 |
500 | return x
501 |
502 |
503 | class Stem(flax.nn.Module):
504 | """Initial block of Efficientnet."""
505 |
506 | def apply(self,
507 | x: jnp.ndarray,
508 | config: ModelConfig,
509 | train: bool = True) -> jnp.ndarray:
510 | """Returns the output of the stem block.
511 |
512 | Args:
513 | x: The input to the block.
514 | config: ModelConfig, a set of model parameters.
515 | train: Whether we are training or predicting.
516 | """
517 | resolution = config.resolution
518 | if x.shape[1:3] != (resolution, resolution):
519 | raise ValueError('Wrong input size. Model was expecting ' +
520 | 'resolution {} '.format((resolution, resolution)) +
521 | 'but got input of resolution {}'.format(x.shape[1:3]))
522 |
523 | # Build stem
524 | x = conv2d(
525 | x,
526 | round_filters(config.stem_base_filters, config),
527 | config,
528 | kernel_size=(3, 3),
529 | strides=(2, 2),
530 | activation=config.activation,
531 | train=train)
532 | return x
533 |
534 |
535 | class Head(flax.nn.Module):
536 | """Final block of Efficientnet."""
537 |
538 | def apply(self,
539 | x: jnp.ndarray,
540 | config: ModelConfig,
541 | num_classes: int,
542 | train: bool = True) -> jnp.ndarray:
543 | """Returns the output of the head block.
544 |
545 | Args:
546 | x: The input to the block.
547 | config: A set of model parameters.
548 | num_classes: Dimension of the output of the model.
549 | train: Whether we are training or predicting.
550 | """
551 | # Build top
552 | x = conv2d(
553 | x,
554 | round_filters(config.top_base_filters, config),
555 | config,
556 | activation=config.activation,
557 | train=train)
558 |
559 | # Build classifier
560 | x = flax.nn.avg_pool(x, x.shape[1:3])
561 | if config.dropout_rate and config.dropout_rate > 0:
562 | x = flax.nn.dropout(x, config.dropout_rate, deterministic=not train)
563 | x = flax.nn.Dense(
564 | x, num_classes, kernel_init=dense_kernel_init_fn, name='dense')
565 | x = x.reshape([x.shape[0], -1])
566 | return x
567 |
568 |
569 | class EfficientNet(flax.nn.Module):
570 | """Implements EfficientNet model."""
571 |
572 | def apply(self,
573 | x: jnp.ndarray,
574 | config: ModelConfig,
575 | num_classes: int = 1000,
576 | train: bool = True) -> jnp.ndarray:
577 | """Returns the output of the EfficientNet model.
578 |
579 | Args:
580 | x: The input batch of images.
581 | config: The model config.
582 | num_classes: Dimension of the output layer.
583 | train: Whether we are in training or inference.
584 |
585 | Returns:
586 | The output of efficientnet
587 | """
588 | config = copy.deepcopy(config)
589 | depth_coefficient = config.depth_coefficient
590 | blocks = config.blocks
591 | drop_connect_rate = config.drop_connect_rate
592 |
593 | resolution = config.resolution
594 | if x.shape[1:3] != (resolution, resolution):
595 | raise ValueError('Wrong input size. Model was expecting ' +
596 | 'resolution {} '.format((resolution, resolution)) +
597 | 'but got input of resolution {}'.format(x.shape[1:3]))
598 |
599 | # Build stem
600 | x = Stem(x, config, train=train)
601 |
602 | # Build blocks
603 | num_blocks_total = sum(
604 | round_repeats(block.num_repeat, depth_coefficient) for block in blocks)
605 | block_num = 0
606 |
607 | for block in blocks:
608 | assert block.num_repeat > 0
609 | # Update block input and output filters based on depth multiplier
610 | block.input_filters = round_filters(block.input_filters, config)
611 | block.output_filters = round_filters(block.output_filters, config)
612 | block.num_repeat = round_repeats(block.num_repeat, depth_coefficient)
613 |
614 | # The first block needs to take care of stride and filter size increase
615 | drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
616 | config.drop_connect_rate = drop_rate
617 | x = MBConvBlock(x, block, config, train=train)
618 | block_num += 1
619 | if block.num_repeat > 1:
620 | block.input_filters = block.output_filters
621 | block.strides = [1, 1]
622 |
623 | for _ in range(block.num_repeat - 1):
624 | drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
625 | config.drop_connect_rate = drop_rate
626 | x = MBConvBlock(x, block, config, train=train)
627 | block_num += 1
628 |
629 | # Build top
630 | x = Head(x, config, num_classes, train=train)
631 |
632 | return x
633 | # pytype: enable=attribute-error
634 |
635 |
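# The block loop above applies a linearly increasing stochastic-depth schedule
# (illustrative arithmetic): with drop_connect_rate=0.2 and, for
# efficientnet-b0, num_blocks_total=16, block k uses a drop rate of
# 0.2 * k / 16, i.e. 0.0 for the first block up to 0.1875 for the last.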
636 | def get_efficientnet_module(model_name: str,
637 | num_classes: int = 1000) -> EfficientNet:
638 | """Returns an EfficientNet module for a given architecture.
639 |
640 | Args:
641 | model_name: Name of the Efficientnet architecture to use (example:
642 | efficientnet-b0).
643 | num_classes: Dimension of the output layer.
644 | """
645 | return EfficientNet.partial(config=MODEL_CONFIGS[model_name],
646 | num_classes=num_classes)
647 |
--------------------------------------------------------------------------------
/sam_jax/efficientnet/efficientnet_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sam.sam_jax.efficientnet.efficientnet."""
16 |
17 | from absl.testing import absltest
18 | import jax
19 | import numpy as np
20 | from sam.sam_jax.efficientnet import efficientnet
21 | from sam.sam_jax.imagenet_models import load_model
22 |
23 |
24 | class EfficientnetTest(absltest.TestCase):
25 |
26 | def _test_model_params(
27 | self, model_name: str, image_size: int, expected_params: int):
28 | module = efficientnet.get_efficientnet_module(model_name)
29 | model, _ = load_model.create_image_model(jax.random.PRNGKey(0),
30 | batch_size=1,
31 | image_size=image_size,
32 | module=module)
33 | num_params = sum(np.prod(e.shape) for e in jax.tree_leaves(model))
34 | self.assertEqual(num_params, expected_params)
35 |
36 | def test_efficientnet_b0(self):
37 | self._test_model_params('efficientnet-b0', 224, expected_params=5288548)
38 |
39 | def test_efficientnet_b1(self):
40 | self._test_model_params('efficientnet-b1', 240, expected_params=7794184)
41 |
42 | def test_efficientnet_b2(self):
43 | self._test_model_params('efficientnet-b2', 260, expected_params=9109994)
44 |
45 | def test_efficientnet_b3(self):
46 | self._test_model_params('efficientnet-b3', 300, expected_params=12233232)
47 |
48 | def test_efficientnet_b4(self):
49 | self._test_model_params('efficientnet-b4', 380, expected_params=19341616)
50 |
51 | def test_efficientnet_b5(self):
52 | self._test_model_params('efficientnet-b5', 456, expected_params=30389784)
53 |
54 | def test_efficientnet_b6(self):
55 | self._test_model_params('efficientnet-b6', 528, expected_params=43040704)
56 |
57 | def test_efficientnet_b7(self):
58 | self._test_model_params('efficientnet-b7', 600, expected_params=66347960)
59 |
60 |
61 | if __name__ == '__main__':
62 | absltest.main()
63 |
--------------------------------------------------------------------------------
/sam_jax/efficientnet/optim.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """EfficientNet is trained with TF1 RMSProp, which is reproduced here.
16 |
17 | This version has two differences from the FLAX RMSProp version:
18 | * It supports momentum, which the FLAX version does not.
19 | * The moving average of the squared gradient is initialized to 1 as in TF1,
20 |   instead of 0 as everywhere else.
21 | """
22 |
23 | from typing import Any, Optional, Tuple
24 |
25 | from flax import struct
26 | from flax.optim.base import OptimizerDef
27 | import jax
28 | import jax.numpy as jnp
29 | import numpy as onp
30 |
31 |
32 | # pytype:disable=wrong-arg-count
33 | @struct.dataclass
34 | class _RMSPropHyperParams:
35 | """RMSProp hyper parameters."""
36 |
37 | learning_rate: float
38 | beta: float
39 | beta2: float
40 | eps: float
41 |
42 |
43 | @struct.dataclass
44 | class _RMSPropParamState:
45 | """RMSProp parameter state."""
46 |
47 | v: onp.ndarray
48 | momentum: onp.ndarray
49 |
50 |
51 | class RMSProp(OptimizerDef):
52 | """RMSProp optimizer."""
53 |
54 | def __init__(self, learning_rate: Optional[float] = None, beta: float = 0.0,
55 | beta2: float = 0.9, eps: float = 1e-8):
56 | """Instantiates the RMSProp optimizer.
57 |
58 | Args:
59 | learning_rate: The step size used to update the parameters.
60 | beta: Momentum.
61 | beta2: The coefficient used for the moving average of the
62 | gradient magnitude (default: 0.9).
63 | eps: The term added to the gradient magnitude estimate for
64 | numerical stability.
65 | """
66 | hyper_params = _RMSPropHyperParams(learning_rate, beta, beta2, eps)
67 | super().__init__(hyper_params)
68 |
69 | def init_param_state(self, param: jnp.ndarray) -> _RMSPropParamState:
70 | """Initializes parameter state. See base class."""
71 | return _RMSPropParamState(
72 | jnp.ones_like(param), jnp.zeros_like(param))
73 |
74 | def apply_param_gradient(
75 | self, step: jnp.ndarray,
76 | hyper_params: _RMSPropHyperParams,
77 | param: jnp.ndarray,
78 | state: _RMSPropParamState,
79 | grad: jnp.ndarray) -> Tuple[jnp.ndarray, _RMSPropParamState]:
80 | """Applies per-parameter gradients. See base class."""
81 | assert hyper_params.learning_rate is not None, 'no learning rate provided.'
82 | new_v = hyper_params.beta2 * state.v + (
83 | 1.0 - hyper_params.beta2) * jnp.square(grad)
84 | grad = grad / jnp.sqrt(new_v + hyper_params.eps)
85 | new_momentum = hyper_params.beta * state.momentum + grad
86 | new_param = param - hyper_params.learning_rate * new_momentum
87 | new_state = _RMSPropParamState(new_v, new_momentum)
88 |
89 | return new_param, new_state
90 |
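# The update above, written out (a transcription of the code, with g the
# gradient):
#   v_t = beta2 * v_{t-1} + (1 - beta2) * g**2        (v_0 = 1, not 0)
#   m_t = beta * m_{t-1} + g / sqrt(v_t + eps)
#   w_t = w_{t-1} - learning_rate * m_t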
91 |
92 | # pytype:disable=attribute-error
93 | @struct.dataclass
94 | class ExponentialMovingAverage:
95 | """Exponential Moving Average as implemented in Tensorflow."""
96 |
97 | # Moving average of the parameters.
98 | param_ema: Any
99 |   # Decay to use for the update (typical values are 0.999, 0.9999, etc.).
100 | decay: float
101 | # For how many steps we should just keep the new parameters instead of an
102 | # average (useful if we don't want the initial weights to be included in the
103 | # average).
104 | warmup_steps: int
105 |
106 | def update_moving_average(self, new_target: Any,
107 | step: jnp.ndarray) -> Any:
108 | """Updates the moving average of the target.
109 |
110 | Args:
111 | new_target: New values of the target (example: weights of a network
112 | after gradient step).
113 | step: Current step (used only for warmup).
114 |
115 | Returns:
116 | The updated ExponentialMovingAverage.
117 | """
118 | factor = jnp.float32(step >= self.warmup_steps)
119 | delta = step - self.warmup_steps
120 | decay = jnp.minimum(self.decay, (1. + delta) / (10. + delta))
121 | decay *= factor
122 | weight_ema = jax.tree_multimap(
123 | lambda a, b: (1 - decay) * a + decay * b, new_target, self.param_ema)
124 | return self.replace(param_ema=weight_ema)
125 | # pytype:enable=attribute-error
126 |
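# Worked example of the decay schedule above (illustrative): with decay=0.999
# and warmup_steps=0, at step 5 the effective decay is
# min(0.999, (1 + 5) / (10 + 5)) = 0.4, so early updates track the new
# parameters closely; once (1 + delta) / (10 + delta) exceeds 0.999, the
# configured decay takes over. Before warmup_steps, factor is 0 and the EMA
# simply copies the new parameters.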
--------------------------------------------------------------------------------
/sam_jax/efficientnet/optim_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for google3.learning.neurosurgeon.research.image_classification.efficientnet.optim."""
16 |
17 | from absl.testing import absltest
18 | from jax.config import config
19 | import numpy as onp
20 | from sam.sam_jax.efficientnet import optim
21 | import tensorflow.compat.v1 as tf
22 |
23 |
24 | # Use double precision for better comparison with Tensorflow version.
25 | config.update("jax_enable_x64", True)
26 |
27 |
28 | class OptimTest(tf.test.TestCase):
29 |
30 | def test_RMSProp(self):
31 | """Updates should match Tensorflow1 behavior."""
32 | lr, mom, rho = 0.5, 0.7, 0.8
33 | onp.random.seed(0)
34 | w0 = onp.random.normal(size=[17, 13, 1])
35 | num_steps = 10
36 |     # First compute the weight updates for the TF1 version.
37 | tf1_updated_weights = []
38 | with tf.Session() as sess:
39 | var0 = tf.Variable(w0, trainable=True)
40 | opt = tf.train.RMSPropOptimizer(
41 | learning_rate=lr, decay=rho, momentum=mom, epsilon=0.001)
42 | loss = lambda: (var0**2) / 2.0
43 | step = opt.minimize(loss, var_list=[var0])
44 | sess.run(tf.global_variables_initializer())
45 | for _ in range(num_steps):
46 | sess.run(step)
47 | tf1_updated_weights.append(sess.run(var0))
48 | # Now compute the updates for FLAX version.
49 | flax_updated_weights = []
50 | optimizer_def = optim.RMSProp(
51 | learning_rate=lr, beta=mom, beta2=rho, eps=0.001)
52 | ref_opt = optimizer_def.create(w0)
53 | for _ in range(num_steps):
54 |       gradient = ref_opt.target  # The gradient of x**2 / 2 is x itself.
55 | ref_opt = ref_opt.apply_gradient(gradient)
56 | flax_updated_weights.append(ref_opt.target)
57 | for a, b in zip(tf1_updated_weights, flax_updated_weights):
58 | self.assertAllClose(a, b)
59 |
60 | def test_RMSPropWithEMA(self):
61 | """Updates should match Tensorflow1 behavior."""
62 | lr, mom, rho, ema_decay = 0.05, 0.4, 0.8, 1.0
63 | onp.random.seed(0)
64 | w0 = onp.array([1.0])
65 | num_steps = 10
66 |     # First compute the weight updates for the TF1 version.
67 | tf1_updated_weights = []
68 | with tf.Session() as sess:
69 | global_step = tf.train.get_or_create_global_step()
70 | ema = tf.train.ExponentialMovingAverage(
71 | decay=ema_decay, num_updates=global_step)
72 | var0 = tf.Variable(w0, trainable=True)
73 | opt = tf.train.RMSPropOptimizer(
74 | learning_rate=lr, decay=rho, momentum=mom, epsilon=0.000)
75 | loss = lambda: (var0**2) / 2.0
76 | step = opt.minimize(loss, var_list=[var0], global_step=global_step)
77 | with tf.control_dependencies([step]):
78 | step = ema.apply([var0])
79 | sess.run(tf.global_variables_initializer())
80 | for _ in range(num_steps):
81 | sess.run(step)
82 | tf1_updated_weights.append(sess.run(ema.average(var0)))
83 |     # Now compute the updates for the FLAX version.
84 | flax_updated_weights = []
85 | optimizer_def = optim.RMSProp(
86 | learning_rate=lr, beta=mom, beta2=rho, eps=0.000)
87 | ref_opt = optimizer_def.create(w0)
88 | ema = optim.ExponentialMovingAverage(w0, ema_decay, 0)
89 | for _ in range(num_steps):
90 | gradient = ref_opt.target
91 | ref_opt = ref_opt.apply_gradient(gradient)
92 | ema = ema.update_moving_average(ref_opt.target, ref_opt.state.step)
93 | flax_updated_weights.append(ema.param_ema)
94 | for a, b in zip(tf1_updated_weights, flax_updated_weights):
95 | self.assertAllClose(a, b)
96 |
97 |
98 | if __name__ == "__main__":
99 | absltest.main()
100 |
--------------------------------------------------------------------------------
/sam_jax/imagenet_models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/sam_jax/imagenet_models/load_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Build FLAX models for image classification."""
16 |
17 | from typing import Optional, Tuple
18 |
19 | from absl import flags
20 | import flax
21 | from flax.training import checkpoints
22 | import jax
23 | from jax import numpy as jnp
24 | from jax import random
25 |
26 | from sam.sam_jax.efficientnet import efficientnet
27 | from sam.sam_jax.imagenet_models import resnet
28 |
29 |
30 | FLAGS = flags.FLAGS
31 |
32 | flags.DEFINE_bool('from_pretrained_checkpoint', False,
33 |                   'If True, the model will be restored from a pretrained '
34 |                   'checkpoint.')
35 | flags.DEFINE_string('efficientnet_checkpoint_path', None,
36 | 'If finetuning, path to the efficientnet checkpoint.')
37 |
38 |
39 | _AVAILABLE_MODEL_NAMES = [
40 |     'Resnet50', 'Resnet101', 'Resnet152',
41 | ] + list(efficientnet.MODEL_CONFIGS.keys())
42 |
43 |
44 | def create_image_model(
45 | prng_key: jnp.ndarray, batch_size: int, image_size: int,
46 | module: flax.nn.Module) -> Tuple[flax.nn.Model, flax.nn.Collection]:
47 | """Instantiates a FLAX model and its state.
48 |
49 | Args:
50 | prng_key: PRNG key to use to sample the initial weights.
51 | batch_size: Batch size that the model should expect.
52 | image_size: Dimension of the image (assumed to be squared).
53 |     module: FLAX module describing the model to instantiate.
54 |
55 | Returns:
56 | A FLAX model and its state.
57 | """
58 | input_shape = (batch_size, image_size, image_size, 3)
59 | with flax.nn.stateful() as init_state:
60 | with flax.nn.stochastic(jax.random.PRNGKey(0)):
61 | _, initial_params = module.init_by_shape(
62 | prng_key, [(input_shape, jnp.float32)])
63 | model = flax.nn.Model(module, initial_params)
64 | return model, init_state
65 |
66 |
67 | class ModelNameError(Exception):
68 | """Exception to raise when the model name is not recognized."""
69 | pass
70 |
71 |
72 | def _replace_dense_layer(model: flax.nn.Model, head: flax.nn.Model):
73 | """Replaces the last layer (head) of a model with the head of another one.
74 |
75 | Args:
76 | model: Model for which we should keep all layers except the head.
77 | head: Model from which we should copy the head.
78 |
79 | Returns:
80 | A model composed from the last layer of `head` and all the other layers of
81 | `model`.
82 | """
83 | new_params = {}
84 | for (ak, av), (bk, bv) in zip(
85 | flax.traverse_util.flatten_dict(model.params).items(),
86 | flax.traverse_util.flatten_dict(head.params).items()):
87 | if ak[1] == 'dense':
88 | new_params[bk] = bv
89 | else:
90 | new_params[ak] = av
91 | return head.replace(params=flax.traverse_util.unflatten_dict(new_params))
92 |
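# A minimal sketch of the head swap above (hypothetical parameter names, not
# the actual FLAX parameter tree):
#
#   body = {'stem': {'kernel': 1}, 'head': {'dense': {'kernel': 2}}}
#   new = {'stem': {'kernel': 3}, 'head': {'dense': {'kernel': 4}}}
#
# Flattening yields keys like ('head', 'dense', 'kernel'); entries whose
# second component is 'dense' are taken from `head`, everything else from
# `model`, so the result here keeps kernel=1 for the stem and kernel=4 for
# the dense head.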
93 |
94 | def get_model(
95 | model_name: str,
96 | batch_size: int,
97 | image_size: int,
98 | num_classes: int = 1000,
99 | prng_key: Optional[jnp.ndarray] = None
100 | ) -> Tuple[flax.nn.Model, flax.nn.Collection]:
101 | """Returns an initialized model of the chosen architecture.
102 |
103 | Args:
104 | model_name: Name of the architecture to use. See image_classification.train
105 | flags for a list of available models.
106 | batch_size: The batch size that the model should expect.
107 | image_size: Dimension of the image (assumed to be squared).
108 |     num_classes: Dimension of the output layer. Defaults to 1000; when
109 |       finetuning from a pretrained checkpoint with a different number of
110 |       classes, the head of the model is reinitialized accordingly.
111 | prng_key: PRNG key to use to sample the weights.
112 |
113 | Returns:
114 | The initialized model and its state.
115 |
116 | Raises:
117 | ModelNameError: If the name of the architecture is not recognized.
118 | """
119 | if model_name == 'Resnet50':
120 | module = resnet.ResNet50.partial(num_classes=num_classes)
121 | elif model_name == 'Resnet101':
122 | module = resnet.ResNet101.partial(num_classes=num_classes)
123 | elif model_name == 'Resnet152':
124 | module = resnet.ResNet152.partial(num_classes=num_classes)
125 | elif model_name in efficientnet.MODEL_CONFIGS:
126 | module = efficientnet.get_efficientnet_module(
127 | model_name, num_classes=num_classes)
128 | else:
129 | raise ModelNameError('Unrecognized model name.')
130 | if not prng_key:
131 | prng_key = random.PRNGKey(0)
132 |
133 | model, init_state = create_image_model(prng_key, batch_size, image_size,
134 | module)
135 |
136 | if FLAGS.from_pretrained_checkpoint:
137 | if FLAGS.efficientnet_checkpoint_path is None:
138 | raise ValueError(
139 | 'For finetuning, must set `efficientnet_checkpoint_path` to a '
140 | 'valid efficientnet checkpoint.')
141 |     # If the number of classes is 1000, just load the imagenet/JFT checkpoint.
142 | if num_classes == 1000:
143 | model, init_state = checkpoints.restore_checkpoint(
144 | FLAGS.efficientnet_checkpoint_path,
145 | (model, init_state))
146 | # Else we need to change the size of the last layer (head):
147 | else:
148 | # Pretrained model on JFT/Imagenet.
149 | imagenet_module = efficientnet.get_efficientnet_module(
150 | model_name, num_classes=1000)
151 | imagenet_model, imagenet_state = create_image_model(
152 | prng_key, batch_size, image_size, imagenet_module)
153 | imagenet_model, imagenet_state = checkpoints.restore_checkpoint(
154 | FLAGS.efficientnet_checkpoint_path,
155 | (imagenet_model, imagenet_state))
156 |       # Replace all the layers of the initialized model, except the head,
157 |       # with the weights extracted from the pretrained model.
158 | model = _replace_dense_layer(imagenet_model, model)
159 | init_state = imagenet_state
160 |
161 | return model, init_state
162 |
--------------------------------------------------------------------------------
/sam_jax/imagenet_models/load_model_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sam.sam_jax.imagenet_models."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import flax
20 | import numpy as np
21 | from sam.sam_jax.imagenet_models import load_model
22 |
23 |
24 | class LoadModelTest(parameterized.TestCase):
25 |
26 | def test_CreateEfficientnetModel(self):
27 | model, state = load_model.get_model('efficientnet-b0', 1, 224, 1000)
28 | self.assertIsInstance(model, flax.nn.Model)
29 | self.assertIsInstance(state, flax.nn.Collection)
30 | fake_input = np.zeros([1, 224, 224, 3])
31 | with flax.nn.stateful(state, mutable=False):
32 | logits = model(fake_input, train=False)
33 | self.assertEqual(logits.shape, (1, 1000))
34 |
35 | def test_CreateResnetModel(self):
36 | model, state = load_model.get_model('Resnet50', 1, 224, 1000)
37 | self.assertIsInstance(model, flax.nn.Model)
38 | self.assertIsInstance(state, flax.nn.Collection)
39 | fake_input = np.zeros([1, 224, 224, 3])
40 | with flax.nn.stateful(state, mutable=False):
41 | logits = model(fake_input, train=False)
42 | self.assertEqual(logits.shape, (1, 1000))
43 |
44 |
45 | if __name__ == '__main__':
46 | absltest.main()
47 |
--------------------------------------------------------------------------------
/sam_jax/imagenet_models/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Flax implementation of ResNet V1.
16 |
17 | Forked from
18 | https://github.com/google/flax/blob/master/examples/imagenet/resnet_v1.py
19 | """
20 |
21 |
22 | from flax import nn
23 |
24 | import jax.numpy as jnp
25 |
26 |
27 | class ResNetBlock(nn.Module):
28 | """ResNet block."""
29 |
30 | def apply(self, x, filters, *,
31 | conv, norm, act,
32 | strides=(1, 1)):
33 | residual = x
34 | y = conv(x, filters, (3, 3), strides)
35 | y = norm(y)
36 | y = act(y)
37 | y = conv(y, filters, (3, 3))
38 | y = norm(y, scale_init=nn.initializers.zeros)
39 |
40 | if residual.shape != y.shape:
41 | residual = conv(residual, filters, (1, 1), strides, name='conv_proj')
42 | residual = norm(residual, name='norm_proj')
43 |
44 | return act(residual + y)
45 |
46 |
47 | class BottleneckResNetBlock(nn.Module):
48 | """Bottleneck ResNet block."""
49 |
50 | def apply(self, x, filters, *,
51 | conv, norm, act,
52 | strides=(1, 1)):
53 | residual = x
54 | y = conv(x, filters, (1, 1))
55 | y = norm(y)
56 | y = act(y)
57 | y = conv(y, filters, (3, 3), strides)
58 | y = norm(y)
59 | y = act(y)
60 | y = conv(y, filters * 4, (1, 1))
61 | y = norm(y, scale_init=nn.initializers.zeros)
62 |
63 | if residual.shape != y.shape:
64 | residual = conv(residual, filters * 4, (1, 1), strides, name='conv_proj')
65 | residual = norm(residual, name='norm_proj')
66 |
67 | return act(residual + y)
68 |
69 |
70 | class ResNet(nn.Module):
71 | """ResNetV1."""
72 |
73 | def apply(self, x, num_classes, *,
74 | stage_sizes,
75 | block_cls,
76 | num_filters=64,
77 | dtype=jnp.float32,
78 | act=nn.relu,
79 | train=True):
80 | conv = nn.Conv.partial(bias=False, dtype=dtype)
81 | norm = nn.BatchNorm.partial(
82 | use_running_average=not train,
83 | momentum=0.9, epsilon=1e-5,
84 | dtype=dtype)
85 |
86 | x = conv(x, num_filters, (7, 7), (2, 2),
87 | padding=[(3, 3), (3, 3)],
88 | name='conv_init')
89 | x = norm(x, name='bn_init')
90 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
91 | for i, block_size in enumerate(stage_sizes):
92 | for j in range(block_size):
93 | strides = (2, 2) if i > 0 and j == 0 else (1, 1)
94 | x = block_cls(x, num_filters * 2 ** i,
95 | strides=strides,
96 | conv=conv,
97 | norm=norm,
98 | act=act)
99 | x = jnp.mean(x, axis=(1, 2))
100 | x = nn.Dense(x, num_classes, dtype=dtype)
101 | return x
102 |
103 |
104 | ResNet18 = ResNet.partial(stage_sizes=[2, 2, 2, 2],
105 | block_cls=ResNetBlock)
106 | ResNet34 = ResNet.partial(stage_sizes=[3, 4, 6, 3],
107 | block_cls=ResNetBlock)
108 | ResNet50 = ResNet.partial(stage_sizes=[3, 4, 6, 3],
109 | block_cls=BottleneckResNetBlock)
110 | ResNet101 = ResNet.partial(stage_sizes=[3, 4, 23, 3],
111 | block_cls=BottleneckResNetBlock)
112 | ResNet152 = ResNet.partial(stage_sizes=[3, 8, 36, 3],
113 | block_cls=BottleneckResNetBlock)
114 | ResNet200 = ResNet.partial(stage_sizes=[3, 24, 36, 3],
115 | block_cls=BottleneckResNetBlock)
116 |
--------------------------------------------------------------------------------
/sam_jax/imagenet_models/resnet_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sam.sam_jax.imagenet_models.resnet.
16 |
17 | Forked from
18 | https://github.com/google/flax/blob/master/examples/imagenet/resnet_v1_test.py
19 | """
20 |
21 | from absl.testing import absltest
22 | import jax
23 | import jax.numpy as jnp
24 | from sam.sam_jax.imagenet_models import resnet
25 |
26 |
27 | class ResNetTest(absltest.TestCase):
28 | """Test cases for ResNet V1 model."""
29 |
30 | def test_resnet_v1_module(self):
31 | """Tests ResNet V1 model definition."""
32 | rng = jax.random.PRNGKey(0)
33 | model_def = resnet.ResNet50.partial(num_classes=10, dtype=jnp.float32)
34 | output, init_params = model_def.init_by_shape(
35 | rng, [((8, 224, 224, 3), jnp.float32)])
36 |
37 | self.assertEqual((8, 10), output.shape)
38 | self.assertLen(init_params, 19)
39 |
40 |
41 | if __name__ == '__main__':
42 | absltest.main()
43 |
--------------------------------------------------------------------------------
/sam_jax/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/sam_jax/models/load_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Build FLAX models for image classification."""
16 |
17 | from typing import Optional, Tuple
18 | import flax
19 | import jax
20 | from jax import numpy as jnp
21 | from jax import random
22 |
23 | from sam.sam_jax.models import pyramidnet
24 | from sam.sam_jax.models import wide_resnet
25 | from sam.sam_jax.models import wide_resnet_shakeshake
26 |
27 |
28 | _AVAILABLE_MODEL_NAMES = [
29 | 'WideResnet28x10',
30 | 'WideResnet28x6_ShakeShake',
31 | 'Pyramid_ShakeDrop',
32 | 'WideResnet_mini', # For testing/debugging purposes.
33 | 'WideResnet_ShakeShake_mini', # For testing/debugging purposes.
34 | 'Pyramid_ShakeDrop_mini', # For testing/debugging purposes.
35 | ]
36 |
37 |
38 | def create_image_model(
39 | prng_key: jnp.ndarray, batch_size: int, image_size: int,
40 | module: flax.nn.Module,
41 | num_channels: int = 3) -> Tuple[flax.nn.Model, flax.nn.Collection]:
42 | """Instantiates a FLAX model and its state.
43 |
44 | Args:
45 | prng_key: PRNG key to use to sample the initial weights.
46 | batch_size: Batch size that the model should expect.
47 | image_size: Dimension of the image (assumed to be squared).
48 |     module: FLAX module describing the model to instantiate.
49 | num_channels: Number of channels for the images.
50 |
51 | Returns:
52 | A FLAX model and its state.
53 | """
54 | input_shape = (batch_size, image_size, image_size, num_channels)
55 | with flax.nn.stateful() as init_state:
56 | with flax.nn.stochastic(jax.random.PRNGKey(0)):
57 | _, initial_params = module.init_by_shape(
58 | prng_key, [(input_shape, jnp.float32)])
59 | model = flax.nn.Model(module, initial_params)
60 | return model, init_state
61 |
62 |
63 | def get_model(
64 | model_name: str,
65 | batch_size: int,
66 | image_size: int,
67 | num_classes: int,
68 | num_channels: int = 3,
69 | prng_key: Optional[jnp.ndarray] = None,
70 | ) -> Tuple[flax.nn.Model, flax.nn.Collection]:
71 | """Returns an initialized model of the chosen architecture.
72 |
73 | Args:
74 | model_name: Name of the architecture to use. Should be one of
75 | _AVAILABLE_MODEL_NAMES.
76 | batch_size: The batch size that the model should expect.
77 | image_size: Dimension of the image (assumed to be squared).
78 | num_classes: Dimension of the output layer.
79 | num_channels: Number of channels for the images.
80 | prng_key: PRNG key to use to sample the weights.
81 |
82 | Returns:
83 | The initialized model and its state.
84 |
85 | Raises:
86 |     ValueError: If the name of the architecture is not recognized.
87 | """
88 | if model_name == 'WideResnet28x10':
89 | module = wide_resnet.WideResnet.partial(
90 | blocks_per_group=4,
91 | channel_multiplier=10,
92 | num_outputs=num_classes)
93 | elif model_name == 'WideResnet28x6_ShakeShake':
94 | module = wide_resnet_shakeshake.WideResnetShakeShake.partial(
95 | blocks_per_group=4,
96 | channel_multiplier=6,
97 | num_outputs=num_classes)
98 | elif model_name == 'Pyramid_ShakeDrop':
99 | module = pyramidnet.PyramidNetShakeDrop.partial(num_outputs=num_classes)
100 | elif model_name == 'WideResnet_mini': # For testing.
101 | module = wide_resnet.WideResnet.partial(
102 | blocks_per_group=2,
103 | channel_multiplier=1,
104 | num_outputs=num_classes)
105 | elif model_name == 'WideResnet_ShakeShake_mini': # For testing.
106 | module = wide_resnet_shakeshake.WideResnetShakeShake.partial(
107 | blocks_per_group=2,
108 | channel_multiplier=1,
109 | num_outputs=num_classes)
110 | elif model_name == 'Pyramid_ShakeDrop_mini':
111 | module = pyramidnet.PyramidNetShakeDrop.partial(num_outputs=num_classes,
112 | pyramid_depth=11)
113 | else:
114 | raise ValueError('Unrecognized model name.')
115 | if not prng_key:
116 | prng_key = random.PRNGKey(0)
117 |
118 | model, init_state = create_image_model(prng_key, batch_size, image_size,
119 | module, num_channels)
120 | return model, init_state
121 |
--------------------------------------------------------------------------------
/sam_jax/models/load_model_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sam.sam_jax.models.load_model."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import flax
20 | import jax
21 | import numpy as np
22 | from sam.sam_jax.models import load_model
23 |
24 |
25 | class LoadModelTest(parameterized.TestCase):
26 |
27 | # Parametrized because other models will be added in following CLs.
28 | @parameterized.named_parameters(
29 | ('WideResnet_mini', 'WideResnet_mini'),
30 | ('WideResnet_ShakeShake_mini', 'WideResnet_ShakeShake_mini'),
31 | ('Pyramid_ShakeDrop_mini', 'Pyramid_ShakeDrop_mini'))
32 | def test_CreateModel(self, model_name: str):
33 | model, state = load_model.get_model(model_name, 1, 32, 10)
34 | self.assertIsInstance(model, flax.nn.Model)
35 | self.assertIsInstance(state, flax.nn.Collection)
36 | fake_input = np.zeros([1, 32, 32, 3])
37 | with flax.nn.stateful(state, mutable=False):
38 | logits = model(fake_input, train=False)
39 | self.assertEqual(logits.shape, (1, 10))
40 |
41 | @parameterized.named_parameters(
42 | ('WideResnet28x10', 'WideResnet28x10'),
43 | ('WideResnet28x6_ShakeShake', 'WideResnet28x6_ShakeShake'),
44 | ('Pyramid_ShakeDrop', 'Pyramid_ShakeDrop'))
45 | def test_ParameterCount(self, model_name: str):
46 | # Parameter count from the autoaugment paper models, 100 classes:
47 | reference_parameter_count = {
48 | 'WideResnet28x10': 36278324,
49 | 'WideResnet28x6_ShakeShake': 26227572,
50 | 'Pyramid_ShakeDrop': 26288692,
51 | }
52 | model, _ = load_model.get_model(model_name, 1, 32, 100)
53 | parameter_count = sum(np.prod(e.shape) for e in jax.tree_leaves(model))
54 | self.assertEqual(parameter_count, reference_parameter_count[model_name])
55 |
56 |
57 | if __name__ == '__main__':
58 | absltest.main()
59 |
--------------------------------------------------------------------------------
/sam_jax/models/pyramidnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """PyramidNet model with ShakeDrop regularization.
16 |
17 | Reference:
18 |
19 | ShakeDrop Regularization for Deep Residual Learning
20 | Yoshihiro Yamada, Masakazu Iwamura, Takuya Akiba, Koichi Kise
21 | https://arxiv.org/abs/1802.02375
22 |
23 | Initially forked from
24 | github.com/google/flax/blob/master/examples/cifar10/models/pyramidnet.py
25 |
26 | This implementation mimics the one from
27 | https://github.com/tensorflow/models/blob/master/research/autoaugment/shake_drop.py
28 | that is widely used as a benchmark.
29 |
30 | We use kaiming normal initialization for convolutional kernels (mode = fan_out,
31 | gain = 2.0). The final dense layer uses a uniform distribution U[-scale, scale]
32 | where scale = 1 / sqrt(num_classes) as per the autoaugment implementation.
33 |
34 | It is worth noting that this model is slightly different from the one presented
35 | in the Deep Pyramidal Residual Networks paper
36 | (https://arxiv.org/pdf/1610.02915.pdf), as we round instead of truncating when
37 | computing the number of channels in each block. This results in a model with
38 | roughly 0.2M additional parameters. Rounding is however the method that was
39 | used in follow-up work (https://arxiv.org/abs/1905.00397,
40 | https://arxiv.org/abs/2002.12047) so we keep it for consistency.
41 | """
42 |
43 | from typing import Tuple
44 |
45 | from flax import nn
46 | import jax.numpy as jnp
47 |
48 | from sam.sam_jax.models import utils
49 |
50 |
51 | def _shortcut(x: jnp.ndarray, chn_out: int, strides: Tuple[int, int]
52 | ) -> jnp.ndarray:
53 | """Pyramid Net shortcut.
54 |
55 | Use Average pooling to downsample.
56 | Use zero-padding to increase channels.
57 |
58 | Args:
59 | x: Input. Should have shape [batch_size, dim, dim, features]
60 | where dim is the resolution (width and height if the input is an image).
61 | chn_out: Expected output channels.
62 | strides: Output stride.
63 |
64 | Returns:
65 | Shortcut value for Pyramid Net. Shape will be
66 | [batch_size, dim, dim, chn_out] if strides = (1, 1) (no downsampling) or
67 | [batch_size, dim/2, dim/2, chn_out] if strides = (2, 2) (downsampling).
68 | """
69 | chn_in = x.shape[3]
70 | if strides != (1, 1):
71 | x = nn.avg_pool(x, strides, strides)
72 | if chn_out != chn_in:
73 | diff = chn_out - chn_in
74 | x = jnp.pad(x, [[0, 0], [0, 0], [0, 0], [0, diff]])
75 | return x
76 |
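# Shape sketch for the shortcut above (illustrative): an input of shape
# [B, 32, 32, 16] with chn_out=64 and strides=(2, 2) is first average-pooled
# to [B, 16, 16, 16], then zero-padded along channels to [B, 16, 16, 64],
# matching the residual branch at no parameter cost.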
77 |
78 | class BottleneckShakeDrop(nn.Module):
79 | """PyramidNet with Shake-Drop Bottleneck."""
80 |
81 | def apply(self,
82 | x: jnp.ndarray,
83 | channels: int,
84 | strides: Tuple[int, int],
85 | prob: float,
86 | alpha_min: float,
87 | alpha_max: float,
88 | beta_min: float,
89 | beta_max: float,
90 | train: bool = True,
91 | true_gradient: bool = False) -> jnp.ndarray:
92 | """Implements the forward pass in the module.
93 |
94 | Args:
95 | x: Input to the module. Should have shape [batch_size, dim, dim, features]
96 | where dim is the resolution (width and height if the input is an image).
97 | channels: How many channels to use in the convolutional layers.
98 | strides: Strides for the pooling.
99 | prob: Probability of dropping the block (see paper for details).
100 | alpha_min: See paper.
101 | alpha_max: See paper.
102 | beta_min: See paper.
103 | beta_max: See paper.
104 | train: If False, will use the moving average for batch norm statistics.
105 | Else, will use statistics computed on the batch.
106 | true_gradient: If true, the same mixing parameter will be used for the
107 | forward and backward pass (see paper for more details).
108 |
109 | Returns:
110 | The output of the bottleneck block.
111 | """
112 | y = utils.activation(x, apply_relu=False, train=train, name='bn_1_pre')
113 | y = nn.Conv(
114 | y,
115 | channels, (1, 1),
116 | padding='SAME',
117 | bias=False,
118 | kernel_init=utils.conv_kernel_init_fn,
119 | name='1x1_conv_contract')
120 | y = utils.activation(y, train=train, name='bn_1_post')
121 | y = nn.Conv(
122 | y,
123 | channels, (3, 3),
124 | strides,
125 | padding='SAME',
126 | bias=False,
127 | kernel_init=utils.conv_kernel_init_fn,
128 | name='3x3')
129 | y = utils.activation(y, train=train, name='bn_2')
130 | y = nn.Conv(
131 | y,
132 | channels * 4, (1, 1),
133 | padding='SAME',
134 | bias=False,
135 | kernel_init=utils.conv_kernel_init_fn,
136 | name='1x1_conv_expand')
137 | y = utils.activation(y, apply_relu=False, train=train, name='bn_3')
138 |
139 | if train and not self.is_initializing():
140 | y = utils.shake_drop_train(y, prob, alpha_min, alpha_max,
141 | beta_min, beta_max,
142 | true_gradient=true_gradient)
143 | else:
144 | y = utils.shake_drop_eval(y, prob, alpha_min, alpha_max)
145 |
146 | x = _shortcut(x, channels * 4, strides)
147 | return x + y
148 |
149 |
150 | def _calc_shakedrop_mask_prob(curr_layer: int,
151 | total_layers: int,
152 | mask_prob: float) -> float:
153 | """Calculates drop prob depending on the current layer."""
154 | return 1 - (float(curr_layer) / total_layers) * mask_prob
155 |
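# Worked example (illustrative): with total_layers=90 (pyramid_depth=272) and
# mask_prob=0.5, the schedule gives 1 - (1 / 90) * 0.5 ~= 0.994 for the first
# block and 1 - (90 / 90) * 0.5 = 0.5 for the last, so the regularization
# strength increases linearly with depth.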
156 |
157 | class PyramidNetShakeDrop(nn.Module):
158 | """PyramidNet with Shake-Drop."""
159 |
160 | def apply(self,
161 | x: jnp.ndarray,
162 | num_outputs: int,
163 | pyramid_alpha: int = 200,
164 | pyramid_depth: int = 272,
165 | train: bool = True,
166 | true_gradient: bool = False) -> jnp.ndarray:
167 | """Implements the forward pass in the module.
168 |
169 | Args:
170 | x: Input to the module. Should have shape [batch_size, dim, dim, 3]
171 | where dim is the resolution of the image.
172 | num_outputs: Dimension of the output of the model (ie number of classes
173 | for a classification problem).
174 | pyramid_alpha: See paper.
175 | pyramid_depth: See paper.
176 | train: If False, will use the moving average for batch norm statistics.
177 | Else, will use statistics computed on the batch.
178 | true_gradient: If true, the same mixing parameter will be used for the
179 | forward and backward pass (see paper for more details).
180 |
181 | Returns:
182 | The output of the PyramidNet model, a tensor of shape
183 | [batch_size, num_classes].
184 | """
185 | assert (pyramid_depth - 2) % 9 == 0
186 |
187 | # Shake-drop hyper-params
188 | mask_prob = 0.5
189 | alpha_min, alpha_max = (-1.0, 1.0)
190 | beta_min, beta_max = (0.0, 1.0)
191 |
192 | # Bottleneck network size
193 | blocks_per_group = (pyramid_depth - 2) // 9
194 | # See Eqn 2 in https://arxiv.org/abs/1610.02915
195 | num_channels = 16
196 | # N in https://arxiv.org/abs/1610.02915
197 | total_blocks = blocks_per_group * 3
198 | delta_channels = pyramid_alpha / total_blocks
199 |
200 | x = nn.Conv(
201 | x,
202 | 16, (3, 3),
203 | padding='SAME',
204 | name='init_conv',
205 | bias=False,
206 | kernel_init=utils.conv_kernel_init_fn)
207 | x = utils.activation(x, apply_relu=False, train=train, name='init_bn')
208 |
209 | layer_num = 1
210 |
211 | for block_i in range(blocks_per_group):
212 | num_channels += delta_channels
213 | layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
214 | mask_prob)
215 | x = BottleneckShakeDrop(
216 | x,
217 | int(round(num_channels)), (1, 1),
218 | layer_mask_prob,
219 | alpha_min,
220 | alpha_max,
221 | beta_min,
222 | beta_max,
223 | train=train,
224 | true_gradient=true_gradient)
225 | layer_num += 1
226 |
227 | for block_i in range(blocks_per_group):
228 | num_channels += delta_channels
229 | layer_mask_prob = _calc_shakedrop_mask_prob(
230 | layer_num, total_blocks, mask_prob)
231 | x = BottleneckShakeDrop(x, int(round(num_channels)),
232 | ((2, 2) if block_i == 0 else (1, 1)),
233 | layer_mask_prob,
234 | alpha_min, alpha_max, beta_min, beta_max,
235 | train=train,
236 | true_gradient=true_gradient)
237 | layer_num += 1
238 |
239 | for block_i in range(blocks_per_group):
240 | num_channels += delta_channels
241 | layer_mask_prob = _calc_shakedrop_mask_prob(
242 | layer_num, total_blocks, mask_prob)
243 | x = BottleneckShakeDrop(x, int(round(num_channels)),
244 | ((2, 2) if block_i == 0 else (1, 1)),
245 | layer_mask_prob,
246 | alpha_min, alpha_max, beta_min, beta_max,
247 | train=train,
248 | true_gradient=true_gradient)
249 | layer_num += 1
250 |
251 | assert layer_num - 1 == total_blocks
252 | x = utils.activation(x, train=train, name='final_bn')
253 | x = nn.avg_pool(x, (8, 8))
254 | x = x.reshape((x.shape[0], -1))
255 | x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
256 | return x
257 |
--------------------------------------------------------------------------------
/sam_jax/models/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Shake-shake and shake-drop functions.
16 |
17 | Forked from:
18 | https://github.com/google-research/google-research/blob/master/flax_models/cifar/models/utils.py
19 | """
20 |
21 | from typing import Optional, Tuple
22 | import flax
23 | from flax import nn
24 | import jax
25 | import jax.numpy as jnp
26 |
27 |
28 | _BATCHNORM_MOMENTUM = 0.9
29 | _BATCHNORM_EPSILON = 1e-5
30 |
31 |
32 | def activation(x: jnp.ndarray,
33 | train: bool,
34 | apply_relu: bool = True,
35 | name: str = '') -> jnp.ndarray:
36 | """Applies BatchNorm and then (optionally) ReLU.
37 |
38 | Args:
39 | x: Tensor on which the activation should be applied.
40 | train: If False, will use the moving average for batch norm statistics.
41 | Else, will use statistics computed on the batch.
42 | apply_relu: Whether or not ReLU should be applied after batch normalization.
43 | name: How to name the BatchNorm layer.
44 |
45 | Returns:
46 |     The input tensor after BatchNorm and (optionally) ReLU were applied.
47 | """
48 | batch_norm = nn.BatchNorm.partial(
49 | use_running_average=not train,
50 | momentum=_BATCHNORM_MOMENTUM,
51 | epsilon=_BATCHNORM_EPSILON)
52 | x = batch_norm(x, name=name)
53 | if apply_relu:
54 | x = jax.nn.relu(x)
55 | return x
56 |
57 |
58 | # Kaiming initialization with fan out mode. Should be used to initialize
59 | # convolutional kernels.
60 | conv_kernel_init_fn = jax.nn.initializers.variance_scaling(
61 | 2.0, 'fan_out', 'normal')
62 |
63 |
64 | def dense_layer_init_fn(key: jnp.ndarray,
65 | shape: Tuple[int, int],
66 | dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
67 | """Initializer for the final dense layer.
68 |
69 | Args:
70 | key: PRNG key to use to sample the weights.
71 | shape: Shape of the tensor to initialize.
72 | dtype: Data type of the tensor to initialize.
73 |
74 | Returns:
75 | The initialized tensor.
76 | """
77 | num_units_out = shape[1]
78 |   unif_init_range = 1.0 / num_units_out**0.5
79 |   return jax.random.uniform(key, shape, dtype, minval=-1) * unif_init_range
80 |
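# Example (illustrative aside): for a 10-class head the weights are drawn
# from U[-1/sqrt(10), 1/sqrt(10)].
#
#   >>> w = dense_layer_init_fn(jax.random.PRNGKey(0), (64, 10))
#   >>> bool(jnp.all(jnp.abs(w) <= 1.0 / jnp.sqrt(10.0)))
#   True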
81 |
82 | def shake_shake_train(xa: jnp.ndarray,
83 | xb: jnp.ndarray,
84 | rng: Optional[jnp.ndarray] = None,
85 | true_gradient: bool = False) -> jnp.ndarray:
86 | """Shake-shake regularization in training mode.
87 |
88 | Shake-shake regularization interpolates between inputs A and B
89 | with *different* random uniform (per-sample) interpolation factors
90 | for the forward and backward/gradient passes.
91 |
92 | Args:
93 | xa: Input, branch A.
94 | xb: Input, branch B.
95 | rng: PRNG key.
96 | true_gradient: If true, the same mixing parameter will be used for the
97 | forward and backward pass (see paper for more details).
98 |
99 | Returns:
100 | Mix of input branches.
101 | """
102 | if rng is None:
103 | rng = flax.nn.make_rng()
104 | gate_forward_key, gate_backward_key = jax.random.split(rng, num=2)
105 | gate_shape = (len(xa), 1, 1, 1)
106 |
107 | # Draw different interpolation factors (gate) for forward and backward pass.
108 | gate_forward = jax.random.uniform(
109 | gate_forward_key, gate_shape, dtype=jnp.float32, minval=0.0, maxval=1.0)
110 | x_forward = xa * gate_forward + xb * (1.0 - gate_forward)
111 | if true_gradient:
112 | return x_forward
113 | gate_backward = jax.random.uniform(
114 | gate_backward_key, gate_shape, dtype=jnp.float32, minval=0.0, maxval=1.0)
115 | # Compute interpolated x for forward and backward.
116 | x_backward = xa * gate_backward + xb * (1.0 - gate_backward)
117 | # Combine using stop_gradient.
118 | return x_backward + jax.lax.stop_gradient(x_forward - x_backward)
119 |
120 |
121 | def shake_shake_eval(xa: jnp.ndarray, xb: jnp.ndarray) -> jnp.ndarray:
122 | """Shake-shake regularization in testing mode.
123 |
124 | Args:
125 | xa: Input, branch A.
126 | xb: Input, branch B.
127 |
128 | Returns:
129 | Mix of input branches.
130 | """
131 | # Blend between inputs A and B 50%-50%.
132 | return (xa + xb) * 0.5
133 |
134 |
135 | def shake_drop_train(x: jnp.ndarray,
136 | mask_prob: float,
137 | alpha_min: float,
138 | alpha_max: float,
139 | beta_min: float,
140 | beta_max: float,
141 | rng: Optional[jnp.ndarray] = None,
142 | true_gradient: bool = False) -> jnp.ndarray:
143 | """ShakeDrop training pass.
144 |
145 | See https://arxiv.org/abs/1802.02375
146 |
147 | Args:
148 | x: Input to apply ShakeDrop to.
149 | mask_prob: Mask probability.
150 | alpha_min: Alpha range lower.
151 | alpha_max: Alpha range upper.
152 | beta_min: Beta range lower.
153 | beta_max: Beta range upper.
154 | rng: PRNG key (if `None`, uses `flax.nn.make_rng`).
155 | true_gradient: If true, the same mixing parameter will be used for the
156 | forward and backward pass (see paper for more details).
157 |
158 | Returns:
159 | The regularized tensor.
160 | """
161 | if rng is None:
162 | rng = flax.nn.make_rng()
163 | bern_key, alpha_key, beta_key = jax.random.split(rng, num=3)
164 | rnd_shape = (len(x), 1, 1, 1)
165 | # Bernoulli variable b_l in Eqn 6, https://arxiv.org/abs/1802.02375.
166 | mask = jax.random.bernoulli(bern_key, mask_prob, rnd_shape)
167 | mask = mask.astype(jnp.float32)
168 |
169 | alpha_values = jax.random.uniform(
170 | alpha_key,
171 | rnd_shape,
172 | dtype=jnp.float32,
173 | minval=alpha_min,
174 | maxval=alpha_max)
175 | beta_values = jax.random.uniform(
176 | beta_key, rnd_shape, dtype=jnp.float32, minval=beta_min, maxval=beta_max)
177 | # See Eqn 6 in https://arxiv.org/abs/1802.02375.
178 | rand_forward = mask + alpha_values - mask * alpha_values
179 | if true_gradient:
180 | return x * rand_forward
181 | rand_backward = mask + beta_values - mask * beta_values
182 | return x * rand_backward + jax.lax.stop_gradient(
183 | x * rand_forward - x * rand_backward)
184 |
185 |
186 | def shake_drop_eval(x: jnp.ndarray,
187 | mask_prob: float,
188 | alpha_min: float,
189 | alpha_max: float) -> jnp.ndarray:
190 | """ShakeDrop eval pass.
191 |
192 | See https://arxiv.org/abs/1802.02375
193 |
194 | Args:
195 | x: Input to apply ShakeDrop to.
196 | mask_prob: Mask probability.
197 | alpha_min: Alpha range lower.
198 | alpha_max: Alpha range upper.
199 |
200 | Returns:
201 | The regularized tensor.
202 | """
203 | expected_alpha = (alpha_max + alpha_min) / 2
204 | # See Eqn 6 in https://arxiv.org/abs/1802.02375.
205 | return (mask_prob + expected_alpha - mask_prob * expected_alpha) * x
206 |
--------------------------------------------------------------------------------
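As an aside on the two training-mode regularizers above: both rely on the same
stop_gradient trick, returning the forward mix as the value while routing the
gradient through the backward mix. A minimal self-contained sketch (scalar
input, fixed gates standing in for the random draws):

import jax
import jax.numpy as jnp

def mix(x, gate_fwd, gate_bwd):
  # Value equals the forward mix; gradient flows through the backward mix.
  x_fwd, x_bwd = x * gate_fwd, x * gate_bwd
  return x_bwd + jax.lax.stop_gradient(x_fwd - x_bwd)

x = jnp.array(3.0)
print(mix(x, 0.2, 0.7))            # 0.6  -> the forward pass uses gate_fwd.
print(jax.grad(mix)(x, 0.2, 0.7))  # 0.7  -> the gradient uses gate_bwd.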
/sam_jax/models/wide_resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Wide Resnet Model.
16 |
17 | Reference:
18 |
19 | Wide Residual Networks, Sergey Zagoruyko, Nikos Komodakis
20 | https://arxiv.org/abs/1605.07146
21 |
22 | Forked from
23 | https://github.com/google-research/google-research/blob/master/flax_models/cifar/models/wide_resnet.py
24 |
25 | This implementation mimics the one from
26 | github.com/tensorflow/models/blob/master/research/autoaugment/wrn.py
27 | that is widely used as a benchmark.
28 |
29 | It uses identity + zero padding skip connections, with kaiming normal
30 | initialization for convolutional kernels (mode = fan_out, gain=2.0).
31 | The final dense layer uses a uniform distribution U[-scale, scale] where
32 | scale = 1 / sqrt(num_classes) as per the autoaugment implementation.
33 |
34 | Using the default initialization instead gives error rates approximately 0.5%
35 | higher on cifar100, most likely because the hyper-parameters used in the
36 | literature were fine-tuned for this particular initialization.
37 |
38 | Finally, the autoaugment implementation adds more residual connections between
39 | the groups (instead of just between the blocks as per the original paper and
40 | most implementations). These connections can safely be removed without
41 | degrading performance, which we do by default to match the original
42 | wideresnet paper. Setting `use_additional_skip_connections` to True adds
43 | them back and exactly reproduces the model used in autoaugment.
44 | """
45 |
46 | from typing import Tuple
47 |
48 | from absl import flags
49 | from flax import nn
50 | from jax import numpy as jnp
51 |
52 | from sam.sam_jax.models import utils
53 |
54 |
55 | FLAGS = flags.FLAGS
56 |
57 |
58 | flags.DEFINE_bool('use_additional_skip_connections', False,
59 | 'Set to True to use additional skip connections between the '
60 | 'resnet groups. This reproduces the autoaugment '
61 | 'implementation, but these connections are not present in '
62 | 'most implementations. Removing them does not impact '
63 | 'the performance of the model.')
64 |
65 |
66 | def _output_add(block_x: jnp.ndarray, orig_x: jnp.ndarray) -> jnp.ndarray:
67 | """Add two tensors, padding them with zeros or pooling them if necessary.
68 |
69 | Args:
70 | block_x: Output of a resnet block.
71 | orig_x: Residual branch to add to the output of the resnet block.
72 |
73 | Returns:
74 |     The sum of block_x and orig_x. If necessary, orig_x will be average
75 |     pooled or zero padded so that its shape matches block_x.
76 | """
77 | stride = orig_x.shape[-2] // block_x.shape[-2]
78 | strides = (stride, stride)
79 | if block_x.shape[-1] != orig_x.shape[-1]:
80 | orig_x = nn.avg_pool(orig_x, strides, strides)
81 | channels_to_add = block_x.shape[-1] - orig_x.shape[-1]
82 | orig_x = jnp.pad(orig_x, [(0, 0), (0, 0), (0, 0), (0, channels_to_add)])
83 | return block_x + orig_x
84 |
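# Example (illustrative aside, hypothetical shapes): a residual branch with
# higher resolution and fewer channels is pooled and zero-padded to match:
#
#   >>> block_x = jnp.zeros((2, 4, 4, 64))
#   >>> orig_x = jnp.ones((2, 8, 8, 32))
#   >>> _output_add(block_x, orig_x).shape
#   (2, 4, 4, 64)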
85 |
86 | class WideResnetBlock(nn.Module):
87 | """Defines a single WideResnetBlock."""
88 |
89 | def apply(self,
90 | x: jnp.ndarray,
91 | channels: int,
92 | strides: Tuple[int, int] = (1, 1),
93 | activate_before_residual: bool = False,
94 | train: bool = True) -> jnp.ndarray:
95 | """Implements the forward pass in the module.
96 |
97 | Args:
98 | x: Input to the module. Should have shape [batch_size, dim, dim, features]
99 | where dim is the resolution (width and height if the input is an image).
100 | channels: How many channels to use in the convolutional layers.
101 | strides: Strides for the pooling.
102 | activate_before_residual: True if the batch norm and relu should be
103 | applied before the residual branches out (should be True only for the
104 | first block of the model).
105 | train: If False, will use the moving average for batch norm statistics.
106 | Else, will use statistics computed on the batch.
107 |
108 | Returns:
109 | The output of the resnet block.
110 | """
111 | if activate_before_residual:
112 | x = utils.activation(x, train, name='init_bn')
113 | orig_x = x
114 | else:
115 | orig_x = x
116 |
117 | block_x = x
118 | if not activate_before_residual:
119 | block_x = utils.activation(block_x, train, name='init_bn')
120 |
121 | block_x = nn.Conv(
122 | block_x,
123 | channels, (3, 3),
124 | strides,
125 | padding='SAME',
126 | bias=False,
127 | kernel_init=utils.conv_kernel_init_fn,
128 | name='conv1')
129 | block_x = utils.activation(block_x, train=train, name='bn_2')
130 | block_x = nn.Conv(
131 | block_x,
132 | channels, (3, 3),
133 | padding='SAME',
134 | bias=False,
135 | kernel_init=utils.conv_kernel_init_fn,
136 | name='conv2')
137 |
138 | return _output_add(block_x, orig_x)
139 |
140 |
141 | class WideResnetGroup(nn.Module):
142 | """Defines a WideResnetGroup."""
143 |
144 | def apply(self,
145 | x: jnp.ndarray,
146 | blocks_per_group: int,
147 | channels: int,
148 | strides: Tuple[int, int] = (1, 1),
149 | activate_before_residual: bool = False,
150 | train: bool = True) -> jnp.ndarray:
151 | """Implements the forward pass in the module.
152 |
153 | Args:
154 | x: Input to the module. Should have shape [batch_size, dim, dim, features]
155 | where dim is the resolution (width and height if the input is an image).
156 | blocks_per_group: How many resnet blocks to add to each group (should be
157 | 4 blocks for a WRN28, and 6 for a WRN40).
158 | channels: How many channels to use in the convolutional layers.
159 | strides: Strides for the pooling.
160 | activate_before_residual: True if the batch norm and relu should be
161 | applied before the residual branches out (should be True only for the
162 | first group of the model).
163 | train: If False, will use the moving average for batch norm statistics.
164 | Else, will use statistics computed on the batch.
165 |
166 | Returns:
167 |       The output of the resnet group.
168 | """
169 | orig_x = x
170 | for i in range(blocks_per_group):
171 | x = WideResnetBlock(
172 | x,
173 | channels,
174 | strides if i == 0 else (1, 1),
175 | activate_before_residual=activate_before_residual and not i,
176 | train=train)
177 | if FLAGS.use_additional_skip_connections:
178 | x = _output_add(x, orig_x)
179 | return x
180 |
181 |
182 | class WideResnet(nn.Module):
183 | """Defines the WideResnet Model."""
184 |
185 | def apply(self,
186 | x: jnp.ndarray,
187 | blocks_per_group: int,
188 | channel_multiplier: int,
189 | num_outputs: int,
190 | train: bool = True) -> jnp.ndarray:
191 | """Implements a WideResnet module.
192 |
193 | Args:
194 | x: Input to the module. Should have shape [batch_size, dim, dim, 3]
195 | where dim is the resolution of the image.
196 | blocks_per_group: How many resnet blocks to add to each group (should be
197 | 4 blocks for a WRN28, and 6 for a WRN40).
198 | channel_multiplier: The multiplier to apply to the number of filters in
199 | the model (1 is classical resnet, 10 for WRN28-10, etc...).
200 | num_outputs: Dimension of the output of the model (ie number of classes
201 | for a classification problem).
202 | train: If False, will use the moving average for batch norm statistics.
203 |
204 | Returns:
205 | The output of the WideResnet, a tensor of shape [batch_size, num_classes].
206 | """
207 | first_x = x
208 | x = nn.Conv(
209 | x,
210 | 16, (3, 3),
211 | padding='SAME',
212 | name='init_conv',
213 | kernel_init=utils.conv_kernel_init_fn,
214 | bias=False)
215 | x = WideResnetGroup(
216 | x,
217 | blocks_per_group,
218 | 16 * channel_multiplier,
219 | activate_before_residual=True,
220 | train=train)
221 | x = WideResnetGroup(
222 | x,
223 | blocks_per_group,
224 | 32 * channel_multiplier, (2, 2),
225 | train=train)
226 | x = WideResnetGroup(
227 | x,
228 | blocks_per_group,
229 | 64 * channel_multiplier, (2, 2),
230 | train=train)
231 | if FLAGS.use_additional_skip_connections:
232 | x = _output_add(x, first_x)
233 | x = utils.activation(x, train=train, name='pre-pool-bn')
234 | x = nn.avg_pool(x, x.shape[1:3])
235 | x = x.reshape((x.shape[0], -1))
236 | x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
237 | return x
238 |
--------------------------------------------------------------------------------
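For orientation, a hedged sketch of instantiating the model above as a
WRN-28-10 with the (now deprecated) flax.nn API used throughout this repo,
mirroring the init pattern from flax_training_test.py; it assumes absl flags
were parsed so that FLAGS.use_additional_skip_connections is defined:

import flax
import jax
import jax.numpy as jnp

module = WideResnet.partial(
    blocks_per_group=4,      # WRN28: (28 - 4) / 6 = 4 blocks per group.
    channel_multiplier=10,   # The "x10" in WRN28-10.
    num_outputs=10)          # e.g. cifar10.
with flax.nn.stateful() as init_state:
  with flax.nn.stochastic(jax.random.PRNGKey(0)):
    _, params = module.init_by_shape(
        jax.random.PRNGKey(0), [((1, 32, 32, 3), jnp.float32)])
model = flax.nn.Model(module, params)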
/sam_jax/models/wide_resnet_shakeshake.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Wide Resnet Model with ShakeShake regularization.
16 |
17 | Reference:
18 |
19 | Shake-Shake regularization, Xavier Gastaldi
20 | https://arxiv.org/abs/1705.07485
21 |
22 | Initially forked from
23 | https://github.com/google-research/google-research/blob/master/flax_models/cifar/models/wide_resnet_shakeshake.py
24 |
25 | This implementation mimics the one from
26 | github.com/tensorflow/models/blob/master/research/autoaugment/shake_shake.py
27 | that is widely used as a benchmark.
28 |
29 | It uses kaiming normal initialization for convolutional kernels (mode = fan_out,
30 | gain=2.0). The final dense layer uses a uniform distribution U[-scale, scale]
31 | where scale = 1 / sqrt(num_classes) as per the autoaugment implementation.
32 |
33 | The residual connections follow the implementation of X. Gastaldi (1x1 pooling
34 | with a one pixel offset for half the channels, see 2.1.1 Implementation details
35 | section in the reference above for more details).
36 | """
37 |
38 | from typing import Tuple
39 |
40 | from flax import nn
41 | import jax
42 | from jax import numpy as jnp
43 |
44 | from sam.sam_jax.models import utils
45 |
46 |
47 | class Shortcut(nn.Module):
48 | """Shortcut for residual connections."""
49 |
50 | def apply(self,
51 | x: jnp.ndarray,
52 | channels: int,
53 | strides: Tuple[int, int] = (1, 1),
54 | train: bool = True) -> jnp.ndarray:
55 | """Implements the forward pass in the module.
56 |
57 | Args:
58 | x: Input to the module. Should have shape [batch_size, dim, dim, features]
59 | where dim is the resolution (width and height if the input is an image).
60 | channels: How many channels to use in the convolutional layers.
61 | strides: Strides for the pooling.
62 | train: If False, will use the moving average for batch norm statistics.
63 |
64 | Returns:
65 | The output of the resnet block. Will have shape
66 | [batch_size, dim, dim, channels] if strides = (1, 1) or
67 | [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
68 | """
69 |
70 | if x.shape[-1] == channels:
71 | return x
72 |
73 | # Skip path 1
74 | h1 = nn.avg_pool(x, (1, 1), strides=strides, padding='VALID')
75 | h1 = nn.Conv(
76 | h1,
77 | channels // 2, (1, 1),
78 | strides=(1, 1),
79 | padding='SAME',
80 | bias=False,
81 | kernel_init=utils.conv_kernel_init_fn,
82 | name='conv_h1')
83 |
84 | # Skip path 2
85 | # The next two lines offset the "image" by one pixel on the right and one
86 | # down (see Shake-Shake regularization, Xavier Gastaldi for details)
87 | pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
88 | h2 = jnp.pad(x, pad_arr)[:, 1:, 1:, :]
89 | h2 = nn.avg_pool(h2, (1, 1), strides=strides, padding='VALID')
90 | h2 = nn.Conv(
91 | h2,
92 | channels // 2, (1, 1),
93 | strides=(1, 1),
94 | padding='SAME',
95 | bias=False,
96 | kernel_init=utils.conv_kernel_init_fn,
97 | name='conv_h2')
98 | merged_branches = jnp.concatenate([h1, h2], axis=3)
99 | return utils.activation(
100 | merged_branches, apply_relu=False, train=train, name='bn_residual')
101 |
102 |
103 | class ShakeShakeBlock(nn.Module):
104 | """Wide ResNet block with shake-shake regularization."""
105 |
106 | def apply(self,
107 | x: jnp.ndarray,
108 | channels: int,
109 | strides: Tuple[int, int] = (1, 1),
110 | train: bool = True,
111 | true_gradient: bool = False) -> jnp.ndarray:
112 | """Implements the forward pass in the module.
113 |
114 | Args:
115 | x: Input to the module. Should have shape [batch_size, dim, dim, features]
116 | where dim is the resolution (width and height if the input is an image).
117 | channels: How many channels to use in the convolutional layers.
118 | strides: Strides for the pooling.
119 | train: If False, will use the moving average for batch norm statistics.
120 | Else, will use statistics computed on the batch.
121 | true_gradient: If true, the same mixing parameter will be used for the
122 | forward and backward pass (see paper for more details).
123 |
124 | Returns:
125 | The output of the resnet block. Will have shape
126 | [batch_size, dim, dim, channels] if strides = (1, 1) or
127 | [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
128 | """
129 | a = b = residual = x
130 |
131 | a = jax.nn.relu(a)
132 | a = nn.Conv(
133 | a,
134 | channels, (3, 3),
135 | strides,
136 | padding='SAME',
137 | bias=False,
138 | kernel_init=utils.conv_kernel_init_fn,
139 | name='conv_a_1')
140 | a = utils.activation(a, train=train, name='bn_a_1')
141 | a = nn.Conv(
142 | a,
143 | channels, (3, 3),
144 | padding='SAME',
145 | bias=False,
146 | kernel_init=utils.conv_kernel_init_fn,
147 | name='conv_a_2')
148 | a = utils.activation(a, apply_relu=False, train=train, name='bn_a_2')
149 |
150 | b = jax.nn.relu(b)
151 | b = nn.Conv(
152 | b,
153 | channels, (3, 3),
154 | strides,
155 | padding='SAME',
156 | bias=False,
157 | kernel_init=utils.conv_kernel_init_fn,
158 | name='conv_b_1')
159 | b = utils.activation(b, train=train, name='bn_b_1')
160 | b = nn.Conv(
161 | b,
162 | channels, (3, 3),
163 | padding='SAME',
164 | bias=False,
165 | kernel_init=utils.conv_kernel_init_fn,
166 | name='conv_b_2')
167 | b = utils.activation(b, apply_relu=False, train=train, name='bn_b_2')
168 |
169 | if train and not self.is_initializing():
170 | ab = utils.shake_shake_train(a, b, true_gradient=true_gradient)
171 | else:
172 | ab = utils.shake_shake_eval(a, b)
173 |
174 | # Apply an up projection in case of channel mismatch.
175 | residual = Shortcut(residual, channels, strides, train)
176 |
177 | return residual + ab
178 |
179 |
180 | class WideResnetShakeShakeGroup(nn.Module):
181 | """Defines a WideResnetGroup."""
182 |
183 | def apply(self,
184 | x: jnp.ndarray,
185 | blocks_per_group: int,
186 | channels: int,
187 | strides: Tuple[int, int] = (1, 1),
188 | train: bool = True,
189 | true_gradient: bool = False) -> jnp.ndarray:
190 | """Implements the forward pass in the module.
191 |
192 | Args:
193 | x: Input to the module. Should have shape [batch_size, dim, dim, features]
194 | where dim is the resolution (width and height if the input is an image).
195 |       blocks_per_group: How many resnet blocks to add to each group (should be
196 |         4 blocks for a WRN26 as per the standard shake-shake implementation).
197 | channels: How many channels to use in the convolutional layers.
198 | strides: Strides for the pooling.
199 | train: If False, will use the moving average for batch norm statistics.
200 | Else, will use statistics computed on the batch.
201 | true_gradient: If true, the same mixing parameter will be used for the
202 | forward and backward pass (see paper for more details).
203 |
204 | Returns:
205 |       The output of the resnet group. Will have shape
206 | [batch_size, dim, dim, channels] if strides = (1, 1) or
207 | [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
208 | """
209 | for i in range(blocks_per_group):
210 | x = ShakeShakeBlock(
211 | x,
212 | channels,
213 | strides if i == 0 else (1, 1),
214 | train=train,
215 | true_gradient=true_gradient)
216 | return x
217 |
218 |
219 | class WideResnetShakeShake(nn.Module):
220 | """Defines the WideResnet Model."""
221 |
222 | def apply(self,
223 | x: jnp.ndarray,
224 | blocks_per_group: int,
225 | channel_multiplier: int,
226 | num_outputs: int,
227 | train: bool = True,
228 | true_gradient: bool = False) -> jnp.ndarray:
229 | """Implements a WideResnet with ShakeShake regularization module.
230 |
231 | Args:
232 | x: Input to the module. Should have shape [batch_size, dim, dim, 3]
233 | where dim is the resolution of the image.
234 | blocks_per_group: How many resnet blocks to add to each group (should be
235 | 4 blocks for a WRN26 as per standard shake shake implementation).
236 | channel_multiplier: The multiplier to apply to the number of filters in
237 | the model (1 is classical resnet, 6 for WRN26-2x6, etc...).
238 | num_outputs: Dimension of the output of the model (ie number of classes
239 | for a classification problem).
240 | train: If False, will use the moving average for batch norm statistics.
241 | Else, will use statistics computed on the batch.
242 | true_gradient: If true, the same mixing parameter will be used for the
243 | forward and backward pass (see paper for more details).
244 |
245 | Returns:
246 | The output of the WideResnet with ShakeShake regularization, a tensor of
247 | shape [batch_size, num_classes].
248 | """
249 | x = nn.Conv(
250 | x,
251 | 16, (3, 3),
252 | padding='SAME',
253 | kernel_init=utils.conv_kernel_init_fn,
254 | bias=False,
255 | name='init_conv')
256 | x = utils.activation(x, apply_relu=False, train=train, name='init_bn')
257 | x = WideResnetShakeShakeGroup(
258 | x,
259 | blocks_per_group,
260 | 16 * channel_multiplier,
261 | train=train,
262 | true_gradient=true_gradient)
263 | x = WideResnetShakeShakeGroup(
264 | x,
265 | blocks_per_group,
266 | 32 * channel_multiplier, (2, 2),
267 | train=train,
268 | true_gradient=true_gradient)
269 | x = WideResnetShakeShakeGroup(
270 | x,
271 | blocks_per_group,
272 | 64 * channel_multiplier, (2, 2),
273 | train=train,
274 | true_gradient=true_gradient)
275 | x = jax.nn.relu(x)
276 | x = nn.avg_pool(x, x.shape[1:3])
277 | x = x.reshape((x.shape[0], -1))
278 | return nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
279 |
--------------------------------------------------------------------------------
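The one-pixel offset in Shortcut's second skip path (pad on the bottom/right,
then drop the first row and column) is easiest to see on a toy array; a
minimal sketch:

import jax.numpy as jnp

x = jnp.arange(9.0).reshape(1, 3, 3, 1)
pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
shifted = jnp.pad(x, pad_arr)[:, 1:, 1:, :]
print(shifted[0, :, :, 0])
# [[4. 5. 0.]
#  [7. 8. 0.]
#  [0. 0. 0.]]
# A strided 1x1 avg_pool on `shifted` therefore samples pixels offset by one
# from those sampled on `x`, decorrelating the two halves of the shortcut.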
/sam_jax/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Trains a model on cifar10, cifar100, SVHN, F-MNIST or imagenet."""
16 |
17 | import os
18 |
19 | from absl import app
20 | from absl import flags
21 | from absl import logging
22 | import jax
23 | from sam.sam_jax.datasets import dataset_source as dataset_source_lib
24 | from sam.sam_jax.datasets import dataset_source_imagenet
25 | from sam.sam_jax.efficientnet import efficientnet
26 | from sam.sam_jax.imagenet_models import load_model as load_imagenet_model
27 | from sam.sam_jax.models import load_model
28 | from sam.sam_jax.training_utils import flax_training
29 | import tensorflow.compat.v2 as tf
30 | from tensorflow.io import gfile
31 |
32 |
33 | FLAGS = flags.FLAGS
34 |
35 | flags.DEFINE_enum('dataset', 'cifar10', [
36 | 'cifar10', 'cifar100', 'fashion_mnist', 'svhn', 'imagenet', 'Birdsnap',
37 | 'cifar100_brain', 'Stanford_Cars', 'Flowers', 'FGVC_Aircraft',
38 | 'Oxford_IIIT_Pets', 'Food_101'
39 | ], 'Name of the dataset.')
40 | flags.DEFINE_enum('model_name', 'WideResnet28x10', [
41 | 'WideResnet28x10', 'WideResnet28x6_ShakeShake', 'Pyramid_ShakeDrop',
42 | 'Resnet50', 'Resnet101', 'Resnet152'
43 | ] + list(efficientnet.MODEL_CONFIGS.keys()), 'Name of the model to train.')
44 | flags.DEFINE_integer('num_epochs', 200,
45 | 'How many epochs the model should be trained for.')
46 | flags.DEFINE_integer(
47 | 'batch_size', 128, 'Global batch size. If multiple '
48 | 'replicas are used, each replica will receive '
49 | 'batch_size / num_replicas examples. Batch size should be divisible by '
50 | 'the number of available devices.')
51 | flags.DEFINE_string(
52 | 'output_dir', '', 'Directory where the checkpoints and the tensorboard '
53 | 'records should be saved.')
54 | flags.DEFINE_enum(
55 | 'image_level_augmentations', 'basic', ['none', 'basic', 'autoaugment',
56 | 'aa-only'],
57 | 'Augmentations applied to the images. Should be `none` for '
58 | 'no augmentations, `basic` for the standard horizontal '
59 | 'flips and random crops, and `autoaugment` for the best '
60 |     'AutoAugment policy for cifar10. For SVHN, `aa-only` should be used for '
61 |     'autoaugment without random crops or flips. '
62 | 'For Imagenet, setting to autoaugment will use RandAugment. For '
63 | 'FromBrainDatasetSource datasets, this flag is ignored.')
64 | flags.DEFINE_enum(
65 | 'batch_level_augmentations', 'none', ['none', 'cutout', 'mixup', 'mixcut'],
66 | 'Augmentations that are applied at the batch level. '
67 | 'Not used by Imagenet and FromBrainDatasetSource datasets.')
68 |
69 |
70 | def main(_):
71 |
72 | tf.enable_v2_behavior()
73 | # make sure tf does not allocate gpu memory
74 | tf.config.experimental.set_visible_devices([], 'GPU')
75 |
76 | # Performance gains on TPU by switching to hardware bernoulli.
77 | def hardware_bernoulli(rng_key, p=jax.numpy.float32(0.5), shape=None):
78 | lax_key = jax.lax.tie_in(rng_key, 0.0)
79 | return jax.lax.rng_uniform(lax_key, 1.0, shape) < p
80 |
81 | def set_hardware_bernoulli():
82 | jax.random.bernoulli = hardware_bernoulli
83 |
84 | set_hardware_bernoulli()
85 |
86 | # As we gridsearch the weight decay and the learning rate, we add them to the
87 | # output directory path so that each model has its own directory to save the
88 | # results in. We also add the `run_seed` which is "gridsearched" on to
89 | # replicate an experiment several times.
90 | output_dir_suffix = os.path.join(
91 | 'lr_' + str(FLAGS.learning_rate),
92 | 'wd_' + str(FLAGS.weight_decay),
93 | 'rho_' + str(FLAGS.sam_rho),
94 | 'seed_' + str(FLAGS.run_seed))
95 |
96 | output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix)
97 |
98 | if not gfile.exists(output_dir):
99 | gfile.makedirs(output_dir)
100 |
101 | num_devices = jax.local_device_count() * jax.host_count()
102 | assert FLAGS.batch_size % num_devices == 0
103 | local_batch_size = FLAGS.batch_size // num_devices
104 | info = 'Total batch size: {} ({} x {} replicas)'.format(
105 | FLAGS.batch_size, local_batch_size, num_devices)
106 | logging.info(info)
107 |
108 | if FLAGS.dataset == 'cifar10':
109 | if FLAGS.from_pretrained_checkpoint:
110 | image_size = efficientnet.name_to_image_size(FLAGS.model_name)
111 | else:
112 | image_size = None
113 | dataset_source = dataset_source_lib.Cifar10(
114 | FLAGS.batch_size // jax.host_count(),
115 | FLAGS.image_level_augmentations,
116 | FLAGS.batch_level_augmentations,
117 | image_size=image_size)
118 | elif FLAGS.dataset == 'cifar100':
119 | if FLAGS.from_pretrained_checkpoint:
120 | image_size = efficientnet.name_to_image_size(FLAGS.model_name)
121 | else:
122 | image_size = None
123 | dataset_source = dataset_source_lib.Cifar100(
124 | FLAGS.batch_size // jax.host_count(), FLAGS.image_level_augmentations,
125 | FLAGS.batch_level_augmentations, image_size=image_size)
126 |
127 | elif FLAGS.dataset == 'fashion_mnist':
128 | dataset_source = dataset_source_lib.FashionMnist(
129 | FLAGS.batch_size, FLAGS.image_level_augmentations,
130 | FLAGS.batch_level_augmentations)
131 | elif FLAGS.dataset == 'svhn':
132 | dataset_source = dataset_source_lib.SVHN(
133 | FLAGS.batch_size, FLAGS.image_level_augmentations,
134 | FLAGS.batch_level_augmentations)
135 | elif FLAGS.dataset == 'imagenet':
136 | imagenet_image_size = efficientnet.name_to_image_size(FLAGS.model_name)
137 | dataset_source = dataset_source_imagenet.Imagenet(
138 | FLAGS.batch_size // jax.host_count(), imagenet_image_size,
139 | FLAGS.image_level_augmentations)
140 | else:
141 | raise ValueError('Dataset not recognized.')
142 |
143 | if 'cifar' in FLAGS.dataset or 'svhn' in FLAGS.dataset:
144 |     if 'svhn' in FLAGS.dataset or image_size is None:  # svhn never sets image_size.
145 | image_size = 32
146 | num_channels = 3
147 | num_classes = 100 if FLAGS.dataset == 'cifar100' else 10
148 | elif FLAGS.dataset == 'fashion_mnist':
149 | image_size = 28 # For Fashion Mnist
150 | num_channels = 1
151 | num_classes = 10
152 | elif FLAGS.dataset == 'imagenet':
153 | image_size = imagenet_image_size
154 | num_channels = 3
155 | num_classes = 1000
156 | else:
157 | raise ValueError('Dataset not recognized.')
158 |
159 | try:
160 | model, state = load_imagenet_model.get_model(FLAGS.model_name,
161 | local_batch_size, image_size,
162 | num_classes)
163 | except load_imagenet_model.ModelNameError:
164 | model, state = load_model.get_model(FLAGS.model_name,
165 | local_batch_size, image_size,
166 | num_classes, num_channels)
167 |
168 | # Learning rate will be overwritten by the lr schedule, we set it to zero.
169 | optimizer = flax_training.create_optimizer(model, 0.0)
170 |
171 | flax_training.train(optimizer, state, dataset_source, output_dir,
172 | FLAGS.num_epochs)
173 |
174 |
175 | if __name__ == '__main__':
176 | tf.enable_v2_behavior()
177 | app.run(main)
178 |
--------------------------------------------------------------------------------
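A small sketch (hypothetical values, not part of the repo) of the bookkeeping
main() performs: the global batch is split evenly across devices, and each
grid-search point writes to its own sub-directory named after the swept flags:

import os

batch_size, num_devices = 128, 8
assert batch_size % num_devices == 0
local_batch_size = batch_size // num_devices  # 16 examples per replica.

output_dir = os.path.join(
    '/tmp/sam_run',                     # FLAGS.output_dir (hypothetical).
    'lr_0.1', 'wd_0.0005', 'rho_0.05',  # Swept hyper-parameters.
    'seed_0')                           # FLAGS.run_seed.
print(local_batch_size, output_dir)
# 16 /tmp/sam_run/lr_0.1/wd_0.0005/rho_0.05/seed_0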
/sam_jax/training_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/sam_jax/training_utils/flax_training_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sam.sam_jax.training_utils.flax_training."""
16 |
17 | import os
18 | from typing import Tuple
19 |
20 | from absl import flags
21 | from absl.testing import absltest
22 | from absl.testing import flagsaver
23 | import flax
24 | import jax
25 | from jax.lib import xla_bridge
26 | import jax.numpy as jnp
27 | import pandas as pd
28 | from sam.sam_jax.datasets import dataset_source
29 | from sam.sam_jax.training_utils import flax_training
30 | import tensorflow as tf
31 | from tensorflow.io import gfile
32 |
33 |
34 | FLAGS = flags.FLAGS
35 |
36 |
37 | class MockDatasetSource(dataset_source.DatasetSource):
38 | """Simple linearly separable dataset for testing.
39 |
40 | See base class for more details.
41 | """
42 |
43 | def __init__(self, mask=False):
44 | positive_input = tf.constant([[1.0 + i/20] for i in range(20)])
45 | negative_input = tf.constant([[-1.0 - i/20] for i in range(20)])
46 | positive_labels = tf.constant([[1, 0] for _ in range(20)])
47 | negative_labels = tf.constant([[0, 1] for _ in range(20)])
48 | inputs = tf.concat((positive_input, negative_input), 0)
49 | labels = tf.concat((positive_labels, negative_labels), 0)
50 | self.inputs, self.labels = inputs.numpy(), labels.numpy()
51 | if not mask:
52 | self._ds = tf.data.Dataset.from_tensor_slices({
53 | 'image': inputs,
54 | 'label': labels
55 | })
56 | else:
57 | self._ds = tf.data.Dataset.from_tensor_slices({
58 | 'image': inputs,
59 | 'label': labels,
60 | 'mask': tf.constant([1.0 for _ in range(40)])
61 | })
62 | self.num_training_obs = 40
63 | self.batch_size = 16
64 |
65 | def get_train(self, use_augmentations: bool) -> tf.data.Dataset:
66 | """Returns the training set.
67 |
68 | Args:
69 | use_augmentations: Ignored (see base class for more details).
70 | """
71 | del use_augmentations
72 | return self._ds.batch(self.batch_size)
73 |
74 | def get_test(self) -> tf.data.Dataset:
75 | """Returns the test set."""
76 | return self._ds.batch(self.batch_size)
77 |
78 |
79 | def _get_linear_model() -> Tuple[flax.nn.Model, flax.nn.Collection]:
80 | """Returns a linear model and its state."""
81 |
82 | class LinearModel(flax.nn.Module):
83 | """Defines the linear model."""
84 |
85 | def apply(self,
86 | x: jnp.ndarray,
87 | num_outputs: int,
88 | train: bool = False) -> jnp.ndarray:
89 | """Forward pass with a linear model.
90 |
91 | Args:
92 | x: Input of shape [batch_size, num_features].
93 | num_outputs: Number of classes.
94 | train: Has no effect.
95 |
96 | Returns:
97 | A tensor of shape [batch_size, num_outputs].
98 | """
99 | del train
100 | x = flax.nn.Dense(x, num_outputs)
101 | return x
102 |
103 | input_shape, num_outputs = [1], 2
104 | module = LinearModel.partial(num_outputs=num_outputs)
105 | with flax.nn.stateful() as init_state:
106 | with flax.nn.stochastic(jax.random.PRNGKey(0)):
107 | _, initial_params = module.init_by_shape(
108 | jax.random.PRNGKey(0), [(input_shape, jnp.float32)])
109 | model = flax.nn.Model(module, initial_params)
110 | return model, init_state
111 |
112 |
113 | def tensorboard_event_to_dataframe(path: str) -> pd.DataFrame:
114 | """Helper to get events written by tests.
115 |
116 | Args:
117 | path: Path where the tensorboard records were saved.
118 |
119 | Returns:
120 | The metric saved by tensorboard, as a dataframe.
121 | """
122 | records = []
123 | all_tb_path = gfile.glob(os.path.join(path, 'events.*.v2'))
124 | for tb_event_path in all_tb_path:
125 | for e in tf.compat.v1.train.summary_iterator(tb_event_path):
126 | if e.step:
127 | for v in e.summary.value:
128 | records.append(dict(
129 | step=e.step, metric=v.tag,
130 | value=float(tf.make_ndarray(v.tensor))))
131 | df = pd.DataFrame.from_records(records)
132 | return df
133 |
134 |
135 | prev_xla_flags = None
136 |
137 |
138 | class FlaxTrainingTest(absltest.TestCase):
139 |
140 | # Run all tests with 8 CPU devices.
141 | # As in third_party/py/jax/tests/pmap_test.py
142 | def setUp(self):
143 | super(FlaxTrainingTest, self).setUp()
144 | global prev_xla_flags
145 | prev_xla_flags = os.getenv('XLA_FLAGS')
146 | flags_str = prev_xla_flags or ''
147 | # Don't override user-specified device count, or other XLA flags.
148 | if 'xla_force_host_platform_device_count' not in flags_str:
149 | os.environ['XLA_FLAGS'] = (
150 | flags_str + ' --xla_force_host_platform_device_count=8')
151 | # Clear any cached backends so new CPU backend will pick up the env var.
152 | xla_bridge.get_backend.cache_clear()
153 |
154 | # Reset to previous configuration in case other test modules will be run.
155 | def tearDown(self):
156 | super(FlaxTrainingTest, self).tearDown()
157 | if prev_xla_flags is None:
158 | del os.environ['XLA_FLAGS']
159 | else:
160 | os.environ['XLA_FLAGS'] = prev_xla_flags
161 | xla_bridge.get_backend.cache_clear()
162 |
163 | @flagsaver.flagsaver
164 | def test_TrainSimpleModel(self):
165 | """Model should reach 100% accuracy easily."""
166 | model, state = _get_linear_model()
167 | dataset = MockDatasetSource()
168 | num_epochs = 10
169 | optimizer = flax_training.create_optimizer(model, 0.0)
170 | training_dir = self.create_tempdir().full_path
171 | FLAGS.learning_rate = 0.01
172 | flax_training.train(
173 | optimizer, state, dataset, training_dir, num_epochs)
174 | records = tensorboard_event_to_dataframe(training_dir)
175 | # Train error rate at the last step should be 0.
176 | records = records[records.metric == 'train_error_rate']
177 | records = records.sort_values('step')
178 | self.assertEqual(records.value.values[-1], 0.0)
179 |
180 | @flagsaver.flagsaver
181 | def _test_ResumeTrainingAfterInterruption(self, use_ema: bool): # pylint:disable=invalid-name
182 | """Resuming training should match a run without interruption."""
183 | if use_ema:
184 | FLAGS.ema_decay = 0.9
185 | model, state = _get_linear_model()
186 | dataset = MockDatasetSource()
187 | optimizer = flax_training.create_optimizer(model, 0.0)
188 | training_dir = self.create_tempdir().full_path
189 | FLAGS.learning_rate = 0.01
190 | FLAGS.use_learning_rate_schedule = False
191 | # First we train for 10 epochs and get the logs.
192 | num_epochs = 10
193 | reference_run_dir = os.path.join(training_dir, 'reference')
194 | flax_training.train(
195 | optimizer, state, dataset, reference_run_dir, num_epochs)
196 | records = tensorboard_event_to_dataframe(reference_run_dir)
197 | # In another directory (new experiment), we run the model for 4 epochs and
198 | # then for 10 epochs, to simulate an interruption.
199 | interrupted_run_dir = os.path.join(training_dir, 'interrupted')
200 | flax_training.train(
201 | optimizer, state, dataset, interrupted_run_dir, 4)
202 | flax_training.train(
203 | optimizer, state, dataset, interrupted_run_dir, 10)
204 | records_interrupted = tensorboard_event_to_dataframe(interrupted_run_dir)
205 |
206 | # Logs should match (order doesn't matter as it is a dataframe in tidy
207 | # format).
208 | def _make_hashable(row):
209 | return str([e if not isinstance(e, float) else round(e, 5) for e in row])
210 |
211 | self.assertEqual(
212 | set([_make_hashable(e) for e in records_interrupted.values]),
213 | set([_make_hashable(e) for e in records.values]))
214 |
215 | def test_ResumeTrainingAfterInterruptionEMA(self):
216 | self._test_ResumeTrainingAfterInterruption(use_ema=True)
217 |
218 | def test_ResumeTrainingAfterInterruption(self):
219 | self._test_ResumeTrainingAfterInterruption(use_ema=False)
220 |
221 |   def _RecomputeTestLoss(self, masked: bool = False):  # pylint:disable=invalid-name
222 | """Recomputes the loss of the final model to check the value logged."""
223 | FLAGS.compute_top_5_error_rate = True
224 | model, state = _get_linear_model()
225 | dataset = MockDatasetSource(mask=masked)
226 | num_epochs = 2
227 | optimizer = flax_training.create_optimizer(model, 0.0)
228 | training_dir = self.create_tempdir().full_path
229 | flax_training.train(
230 | optimizer, state, dataset, training_dir, num_epochs)
231 | records = tensorboard_event_to_dataframe(training_dir)
232 | records = records[records.metric == 'test_loss']
233 | final_test_loss = records.sort_values('step').value.values[-1]
234 | # Loads final model and state.
235 | optimizer, state, _ = flax_training.restore_checkpoint(
236 | optimizer, state, os.path.join(training_dir, 'checkpoints'))
237 | logits = optimizer.target(dataset.inputs)
238 | loss = flax_training.cross_entropy_loss(logits, dataset.labels)
239 |     self.assertLess(abs(final_test_loss - loss), 1e-7)
240 |
241 | def test_RecomputeTestLoss(self):
242 | self._RecomputeTestLoss()
243 |
244 | def test_RecomputeTestLossMasked(self):
245 | self._RecomputeTestLoss(masked=True)
246 |
247 | def test_metrics(self):
248 | logits = jnp.array([[2.0, 1.0], [3.0, 1.0], [0.1, 1.6], [2.0, 5.0]])
249 | truth = jnp.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
250 | mask = jnp.array([1.0, 1.0, 1.0, 0.0])
251 | self.assertEqual(flax_training.error_rate_metric(logits, truth), 0.25)
252 | self.assertEqual(flax_training.error_rate_metric(logits, truth, mask), 1/3)
253 |
254 |
255 | if __name__ == '__main__':
256 | absltest.main()
257 |
--------------------------------------------------------------------------------
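A closing aside: the multi-device-on-CPU trick used in FlaxTrainingTest.setUp
also works standalone. Setting XLA_FLAGS before JAX creates its backend makes
pmap see eight host "devices"; a minimal sketch:

import os
os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '') +
    ' --xla_force_host_platform_device_count=8')

import jax  # Must be imported after the flag is set (before backend creation).
print(jax.local_device_count())  # 8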