├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── colabs
│   ├── distillation_self_training.ipynb
│   ├── finetuning.ipynb
│   ├── intriguing_properties
│   │   ├── README.md
│   │   ├── digits_on_tf_flowers.ipynb
│   │   ├── generalized_contrastive_loss.ipynb
│   │   └── randbits_mnist.ipynb
│   └── load_and_inference.ipynb
├── data.py
├── data_util.py
├── imagenet_subsets
│   ├── 10percent.txt
│   └── 1percent.txt
├── lars_optimizer.py
├── model.py
├── model_util.py
├── objective.py
├── requirements.txt
├── resnet.py
├── run.py
└── tf2
    ├── README.md
    ├── colabs
    │   ├── distillation_self_training.ipynb
    │   ├── finetuning.ipynb
    │   ├── imagenet_results.ipynb
    │   └── load_and_inference.ipynb
    ├── data.py
    ├── data_util.py
    ├── lars_optimizer.py
    ├── metrics.py
    ├── model.py
    ├── objective.py
    ├── requirements.txt
    ├── resnet.py
    └── run.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | SimCLR needs to maintain permanent compatibility with the pre-trained model
4 | files, so we do not plan to make any major changes to this library (other than
5 | what was promised in the README). However, we can accept small patches related
6 | to refactoring and documentation. To submit contributions, there are just a few
7 | small guidelines you need to follow.
8 |
9 | ## Contributor License Agreement
10 |
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 |
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 |
21 | ## Code reviews
22 |
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 |
28 | ## Community Guidelines
29 |
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SimCLR - A Simple Framework for Contrastive Learning of Visual Representations
2 |
3 | News! We have released a TF2 implementation of SimCLR (along with checkpoints converted to TF2); they are in the tf2/ folder.
4 |
5 | News! Colabs for Intriguing Properties of Contrastive Losses have been added; see `colabs/intriguing_properties/`.
6 |
7 | *(An illustration figure of SimCLR appeared here.)*
8 |
14 | ## Pre-trained models for SimCLRv2
15 |
16 |
17 | We open-sourced a total of 65 pretrained models here, corresponding to those in Table 1 of the SimCLRv2 paper:
18 |
19 | | Depth | Width | SK | Param (M) | F-T (1%) | F-T(10%) | F-T(100%) | Linear eval | Supervised |
20 | |--------:|--------:|------:|--------:|-------------:|--------------:|---------------:|-----------------:|--------------:|
21 | | 50 | 1X | False | 24 | 57.9 | 68.4 | 76.3 | 71.7 | 76.6 |
22 | | 50 | 1X | True | 35 | 64.5 | 72.1 | 78.7 | 74.6 | 78.5 |
23 | | 50 | 2X | False | 94 | 66.3 | 73.9 | 79.1 | 75.6 | 77.8 |
24 | | 50 | 2X | True | 140 | 70.6 | 77.0 | 81.3 | 77.7 | 79.3 |
25 | | 101 | 1X | False | 43 | 62.1 | 71.4 | 78.2 | 73.6 | 78.0 |
26 | | 101 | 1X | True | 65 | 68.3 | 75.1 | 80.6 | 76.3 | 79.6 |
27 | | 101 | 2X | False | 170 | 69.1 | 75.8 | 80.7 | 77.0 | 78.9 |
28 | | 101 | 2X | True | 257 | 73.2 | 78.8 | 82.4 | 79.0 | 80.1 |
29 | | 152 | 1X | False | 58 | 64.0 | 73.0 | 79.3 | 74.5 | 78.3 |
30 | | 152 | 1X | True | 89 | 70.0 | 76.5 | 81.3 | 77.2 | 79.9 |
31 | | 152 | 2X | False | 233 | 70.2 | 76.6 | 81.1 | 77.4 | 79.1 |
32 | | 152 | 2X | True | 354 | 74.2 | 79.4 | 82.9 | 79.4 | 80.4 |
33 | | 152 | 3X | True | 795 | 74.9 | 80.1 | 83.1 | 79.8 | 80.5 |
34 |
35 | These checkpoints are stored in Google Cloud Storage:
36 |
37 | * Pretrained SimCLRv2 models (with linear eval head): [gs://simclr-checkpoints/simclrv2/pretrained](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv2/pretrained)
38 | * Fine-tuned SimCLRv2 models on 1% of labels: [gs://simclr-checkpoints/simclrv2/finetuned_1pct](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv2/finetuned_1pct)
39 | * Fine-tuned SimCLRv2 models on 10% of labels: [gs://simclr-checkpoints/simclrv2/finetuned_10pct](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv2/finetuned_10pct)
40 | * Fine-tuned SimCLRv2 models on 100% of labels: [gs://simclr-checkpoints/simclrv2/finetuned_100pct](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv2/finetuned_100pct)
41 | * Supervised models with the same architectures: [gs://simclr-checkpoints/simclrv2/supervised](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv2/supervised)
42 | * The distilled / self-trained models (after fine-tuning) are also provided:
43 | * [gs://simclr-checkpoints/simclrv2/distill_1pct](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv2/distill_1pct)
44 | * [gs://simclr-checkpoints/simclrv2/distill_10pct](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv2/distill_10pct)
45 |
46 | We also provide examples of how to use the checkpoints in the `colabs/` folder.
47 |
48 | ## Pre-trained models for SimCLRv1
49 |
50 | The pre-trained models (base network with linear classifier layer) can be found below. Note that for these SimCLRv1 checkpoints, the projection head is not available.
51 |
52 | | Model checkpoint and hub-module | ImageNet Top-1 |
53 | |-----------------------------------------------------------------------------------------|------------------------|
54 | |[ResNet50 (1x)](https://storage.cloud.google.com/simclr-gcs/checkpoints/ResNet50_1x.zip) | 69.1 |
55 | |[ResNet50 (2x)](https://storage.cloud.google.com/simclr-gcs/checkpoints/ResNet50_2x.zip) | 74.2 |
56 | |[ResNet50 (4x)](https://storage.cloud.google.com/simclr-gcs/checkpoints/ResNet50_4x.zip) | 76.6 |
57 |
58 | Additional SimCLRv1 checkpoints are available: [gs://simclr-checkpoints/simclrv1](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv1).
59 |
60 | A note on the signatures of the TensorFlow Hub module: `default` is the representation output of the base network; `logits_sup` is the supervised classification logits for ImageNet 1000 categories. Others (e.g. `initial_max_pool`, `block_group1`) are middle layers of ResNet; refer to resnet.py for the specifics. See this [tutorial](https://www.tensorflow.org/hub/tf1_hub_module) for additional information regarding use of TensorFlow Hub modules.
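
As a rough usage sketch (the module path below is a placeholder for the `hub/` directory of an unzipped SimCLRv1 checkpoint; passing `as_dict=True` exposes the non-default signatures):

```
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

tf.disable_eager_execution()  # needed when running under TF2

# Placeholder path: point this at the hub/ folder of an unzipped checkpoint.
module = hub.Module('/path/to/ResNet50_1x/hub/', trainable=False)
images = tf.placeholder(tf.float32, [None, 224, 224, 3])  # values in [0, 1]
outputs = module(inputs=images, as_dict=True)
features = outputs['default']    # representation output of the base network
logits = outputs['logits_sup']   # supervised logits for 1000 ImageNet classes
```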
61 |
62 | ## Environment setup
63 |
64 | Our models are trained with TPUs. It is recommended to run distributed training with TPUs when using our code for pretraining.
65 |
66 | Our code can also run on a *single* GPU. It does not support multiple GPUs, due to issues such as global BatchNorm and the contrastive loss across cores.
67 |
68 | The code is compatible with both TensorFlow v1 and v2. See requirements.txt for all prerequisites, which you can also install using the following command.
69 |
70 | ```
71 | pip install -r requirements.txt
72 | ```
73 |
74 | ## Pretraining
75 |
76 | To pretrain the model on CIFAR-10 with a *single* GPU, try the following command:
77 |
78 | ```
79 | python run.py --train_mode=pretrain \
80 | --train_batch_size=512 --train_epochs=1000 \
81 | --learning_rate=1.0 --weight_decay=1e-4 --temperature=0.5 \
82 | --dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
83 | --use_blur=False --color_jitter_strength=0.5 \
84 | --model_dir=/tmp/simclr_test --use_tpu=False
85 | ```
86 |
87 | To pretrain the model on ImageNet with Cloud TPUs, first check out the [Google Cloud TPU tutorial](https://cloud.google.com/tpu/docs/tutorials/mnist) for basic information on how to use Google Cloud TPUs.
88 |
89 | Once you have created a virtual machine with Cloud TPUs and pre-downloaded the ImageNet data for [tensorflow_datasets](https://www.tensorflow.org/datasets/catalog/imagenet2012), please set the following environment variables:
90 |
91 | ```
92 | TPU_NAME=
93 | STORAGE_BUCKET=gs://
94 | DATA_DIR=$STORAGE_BUCKET/
95 | MODEL_DIR=$STORAGE_BUCKET/
96 | ```
97 |
98 | The following command can be used to pretrain a ResNet-50 on ImageNet (which reflects the default hyperparameters in our paper):
99 |
100 | ```
101 | python run.py --train_mode=pretrain \
102 | --train_batch_size=4096 --train_epochs=100 --temperature=0.1 \
103 | --learning_rate=0.075 --learning_rate_scaling=sqrt --weight_decay=1e-4 \
104 | --dataset=imagenet2012 --image_size=224 --eval_split=validation \
105 | --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
106 | --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0
107 | ```
108 |
109 | A batch size of 4096 requires at least 32 TPUs. Training for 100 epochs takes around 6 hours with 32 TPU v3s. Note that a learning rate of 0.3 with `learning_rate_scaling=linear` is equivalent to one of 0.075 with `learning_rate_scaling=sqrt` when the batch size is 4096. However, sqrt scaling tends to train better when a smaller batch size is used.
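
For a quick sanity check of that equivalence: the two rules scale the base learning rate as `lr * batch_size / 256` (linear) and `lr * sqrt(batch_size)` (sqrt), so both settings yield the same effective learning rate at a batch size of 4096:

```
import math

batch_size = 4096
print(0.3 * batch_size / 256)         # linear scaling -> 4.8
print(0.075 * math.sqrt(batch_size))  # sqrt scaling   -> 4.8
```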
110 |
111 | ## Finetuning the linear head (linear eval)
112 |
113 | To fine-tune a linear head (with a single GPU), try the following command:
114 |
115 | ```
116 | python run.py --mode=train_then_eval --train_mode=finetune \
117 | --fine_tune_after_block=4 --zero_init_logits_layer=True \
118 | --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
119 | --global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=0.0 \
120 | --train_epochs=100 --train_batch_size=512 --warmup_epochs=0 \
121 | --dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
122 | --checkpoint=/tmp/simclr_test --model_dir=/tmp/simclr_test_ft --use_tpu=False
123 | ```
124 |
125 | You can check the results using TensorBoard, e.g.
126 |
127 | ```
128 | python -m tensorboard.main --logdir=/tmp/simclr_test
129 | ```
130 |
131 | As a reference, the above run on CIFAR-10 should give you around 91% accuracy, though the result can be further optimized.
132 |
133 | For fine-tuning a linear head on ImageNet using Cloud TPUs, first set `CHKPT_DIR` to the pretrained model directory and set a new `MODEL_DIR`, then use the following command:
134 |
135 | ```
136 | python run.py --mode=train_then_eval --train_mode=finetune \
137 | --fine_tune_after_block=4 --zero_init_logits_layer=True \
138 | --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
139 | --global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=1e-6 \
140 | --train_epochs=90 --train_batch_size=4096 --warmup_epochs=0 \
141 | --dataset=imagenet2012 --image_size=224 --eval_split=validation \
142 | --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR \
143 | --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0
144 | ```
145 |
146 | As a reference, the above run on ImageNet should give you around 64.5% accuracy.
147 |
148 | ## Semi-supervised learning and fine-tuning the whole network
149 |
150 | You can access 1% and 10% ImageNet subsets used for semi-supervised learning via [tensorflow datasets](https://www.tensorflow.org/datasets/catalog/imagenet2012_subset): simply set `dataset=imagenet2012_subset/1pct` and `dataset=imagenet2012_subset/10pct` in the command line for fine-tuning on these subsets.
151 |
152 | You can also find image IDs of these subsets in `imagenet_subsets/`.
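
The subsets can also be materialized directly through tensorflow_datasets; a minimal sketch (assuming the imagenet2012 data has already been downloaded and prepared under `data_dir`):

```
import tensorflow_datasets as tfds

# Hypothetical data_dir; imagenet2012 must already be prepared there.
builder = tfds.builder('imagenet2012_subset/1pct',
                       data_dir='/path/to/tensorflow_datasets')
builder.download_and_prepare()
train_ds = builder.as_dataset(split='train', as_supervised=True)
```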
153 |
154 | To fine-tune the whole network on ImageNet (1% of labels), refer to the following command:
155 |
156 | ```
157 | python run.py --mode=train_then_eval --train_mode=finetune \
158 | --fine_tune_after_block=-1 --zero_init_logits_layer=True \
159 | --variable_schema='(?!global_step|(?:.*/|^)Momentum|head_supervised)' \
160 | --global_bn=True --optimizer=lars --learning_rate=0.005 \
161 | --learning_rate_scaling=sqrt --weight_decay=0 \
162 | --train_epochs=60 --train_batch_size=1024 --warmup_epochs=0 \
163 | --dataset=imagenet2012_subset/1pct --image_size=224 --eval_split=validation \
164 | --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR \
165 | --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0 \
166 | --num_proj_layers=3 --ft_proj_selector=1
167 | ```
168 |
169 | Set the `checkpoint` flag to a model that is only pre-trained but not fine-tuned. Given that SimCLRv1 checkpoints do not contain the projection head, it is recommended to run with SimCLRv2 checkpoints (you can still run with SimCLRv1 checkpoints, but `variable_schema` needs to exclude `head`). The `num_proj_layers` and `ft_proj_selector` need to be adjusted accordingly, following the SimCLRv2 paper, to obtain the best performance.
170 |
171 | ## Other resources
172 |
173 | ### Model conversion to Pytorch format
174 |
175 | This [repo](https://github.com/tonylins/simclr-converter) provides a solution for converting the pretrained SimCLRv1 Tensorflow checkpoints into Pytorch ones.
176 |
177 | This [repo](https://github.com/Separius/SimCLRv2-Pytorch) provides a solution for converting the pretrained SimCLRv2 Tensorflow checkpoints into Pytorch ones.
178 |
179 | ### Other *non-official* / *unverified* implementations
180 |
181 | (Feel free to share your implementation by creating an issue)
182 |
183 | Implementations in PyTorch:
184 | * [leftthomas](https://github.com/leftthomas/SimCLR)
185 | * [AndrewAtanov](https://github.com/AndrewAtanov/simclr-pytorch)
186 | * [ae-foster](https://github.com/ae-foster/pytorch-simclr)
187 | * [Spijkervet](https://github.com/Spijkervet/SimCLR)
188 | * [williamFalcon](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#simclr)
189 |
190 | Implementations in TensorFlow 2 / Keras (an official TF2 implementation was added in the tf2/ folder):
191 | * [sayakpaul](https://github.com/sayakpaul/SimCLR-in-TensorFlow-2)
192 | * [mwdhont](https://github.com/mwdhont/SimCLRv1-keras-tensorflow)
193 |
194 | ## Known issues
195 |
196 | * **Batch size**: the original results of SimCLR were tuned under a large batch size (i.e. 4096), which leads to suboptimal results when training with a smaller batch size. However, with a good set of hyper-parameters (mainly learning rate, temperature, projection head depth), small batch sizes can yield results that are on par with large batch sizes (e.g., see Table 2 in [this paper](https://arxiv.org/pdf/2011.02803.pdf)).
197 |
198 | * **Pretrained models / Checkpoints**: SimCLRv1 and SimCLRv2 are pretrained with different weight decays, so the pretrained models from the two versions have very different weight norm scales (convolutional weights in SimCLRv1 ResNet-50 are on average 16.8X those in SimCLRv2). For fine-tuning the pretrained models from both versions, it is fine if you use a LARS optimizer, but very different hyperparameters (e.g. learning rate, weight decay) are required if you use the momentum optimizer. So for the latter case, you may want to either search for very different hparams according to which version is used, or re-scale the weights (i.e. the conv `kernel` parameters of `base_model` in the checkpoints) to make sure they are roughly on the same scale.
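
As a rough sketch of the re-scaling option (the paths and the variable-name filter below are assumptions, and 16.8x is only an average, so treat this as a starting point rather than a recipe):

```
import tensorflow.compat.v1 as tf

src = '/path/to/simclrv1/model.ckpt'  # assumed checkpoint paths
dst = '/path/to/rescaled/model.ckpt'
reader = tf.train.load_checkpoint(src)

with tf.Graph().as_default(), tf.Session() as sess:
  for name, _ in tf.train.list_variables(src):
    value = reader.get_tensor(name)
    # Scale down conv kernels under base_model (name filter is an assumption).
    if 'base_model' in name and name.endswith('kernel'):
      value = value / 16.8
    tf.Variable(value, name=name)
  sess.run(tf.global_variables_initializer())
  tf.train.Saver().save(sess, dst)
```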
199 |
200 | ## Cite
201 |
202 | [SimCLR paper](https://arxiv.org/abs/2002.05709):
203 |
204 | ```
205 | @article{chen2020simple,
206 | title={A Simple Framework for Contrastive Learning of Visual Representations},
207 | author={Chen, Ting and Kornblith, Simon and Norouzi, Mohammad and Hinton, Geoffrey},
208 | journal={arXiv preprint arXiv:2002.05709},
209 | year={2020}
210 | }
211 | ```
212 |
213 | [SimCLRv2 paper](https://arxiv.org/abs/2006.10029):
214 |
215 | ```
216 | @article{chen2020big,
217 | title={Big Self-Supervised Models are Strong Semi-Supervised Learners},
218 | author={Chen, Ting and Kornblith, Simon and Swersky, Kevin and Norouzi, Mohammad and Hinton, Geoffrey},
219 | journal={arXiv preprint arXiv:2006.10029},
220 | year={2020}
221 | }
222 | ```
223 |
224 | ## Disclaimer
225 | This is not an official Google product.
226 |
--------------------------------------------------------------------------------
/colabs/intriguing_properties/README.md:
--------------------------------------------------------------------------------
1 | ## Intriguing Properties of Contrastive Losses
2 |
3 | This folder contains code for the paper [Intriguing Properties of Contrastive Losses](https://arxiv.org/abs/2011.02803). The accompanying website/blog can be found [here](https://contrastive-learning.github.io/intriguing/).
4 |
5 | ```
6 | @article{chen2021intriguing,
7 | title={Intriguing Properties of Contrastive Losses},
8 | author={Chen, Ting and Luo, Calvin and Li, Lala},
9 | journal={Advances in Neural Information Processing Systems},
10 | volume={34},
11 | year={2021}
12 | }
13 | ```
14 |
15 |
16 |
--------------------------------------------------------------------------------
/colabs/intriguing_properties/generalized_contrastive_loss.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "generalized_contrastive_loss.ipynb",
7 | "provenance": [
8 | {
9 | "file_id": "1gP6orB_1mRXdMTrKi8xvL06_5f3u9Nw2",
10 | "timestamp": 1604290125178
11 | }
12 | ],
13 | "collapsed_sections": [],
14 | "last_runtime": {
15 | "build_target": "//learning/deepmind/dm_python:dm_notebook3",
16 | "kind": "private"
17 | }
18 | },
19 | "kernelspec": {
20 | "name": "python3",
21 | "display_name": "Python 3"
22 | }
23 | },
24 | "cells": [
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {
28 | "id": "pEAjCLI8QCjU"
29 | },
30 | "source": [
31 | "##### Copyright 2020 Google LLC.\n",
32 | "\n",
33 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
34 | "you may not use this file except in compliance with the License.\n",
35 | "You may obtain a copy of the License at\n",
36 | "\n",
37 | "https://www.apache.org/licenses/LICENSE-2.0\n",
38 | "\n",
39 | "Unless required by applicable law or agreed to in writing, software\n",
40 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
41 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
42 | "See the License for the specific language governing permissions and\n",
43 | "limitations under the License."
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {
49 | "id": "2lCKaZSM2Ac0"
50 | },
51 | "source": [
52 | "## Generalized contrastive loss\n",
53 | "\n",
54 | "This notebook contains implementation of generalized contrastive loss from ***Intriguing Properties of Contrastive Losses***."
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "metadata": {
60 | "id": "0amaY7x4wgGr",
61 | "cellView": "both"
62 | },
63 | "source": [
64 | "def generalized_contrastive_loss(\n",
65 | " hidden1,\n",
66 | " hidden2,\n",
67 | " lambda_weight=1.0,\n",
68 | " temperature=1.0,\n",
69 | " dist='normal',\n",
70 | " hidden_norm=True,\n",
71 | " loss_scaling=1.0):\n",
72 | " \"\"\"Generalized contrastive loss.\n",
73 | "\n",
74 | " Both hidden1 and hidden2 should have shape of (n, d).\n",
75 | "\n",
76 | " Configurations to get following losses:\n",
77 | " * decoupled NT-Xent loss: set dist='logsumexp', hidden_norm=True\n",
78 | " * SWD with normal distribution: set dist='normal', hidden_norm=False\n",
79 | " * SWD with uniform hypersphere: set dist='normal', hidden_norm=True\n",
80 | " * SWD with uniform hypercube: set dist='uniform', hidden_norm=False\n",
81 | " \"\"\"\n",
82 | " hidden_dim = hidden1.shape[-1] # get hidden dimension\n",
83 | " if hidden_norm:\n",
84 | " hidden1 = tf.math.l2_normalize(hidden1, -1)\n",
85 | " hidden2 = tf.math.l2_normalize(hidden2, -1)\n",
86 | " loss_align = tf.reduce_mean((hidden1 - hidden2)**2) / 2.\n",
87 | " hiddens = tf.concat([hidden1, hidden2], 0)\n",
88 | " if dist == 'logsumexp':\n",
89 | " loss_dist_match = get_logsumexp_loss(hiddens, temperature)\n",
90 | " else:\n",
91 | " initializer = tf.keras.initializers.Orthogonal()\n",
92 | " rand_w = initializer([hidden_dim, hidden_dim])\n",
93 | " loss_dist_match = get_swd_loss(hiddens, rand_w,\n",
94 | " prior=dist,\n",
95 | " hidden_norm=hidden_norm)\n",
96 | " return loss_scaling * (loss_align + lambda_weight * loss_dist_match)"
97 | ],
98 | "execution_count": null,
99 | "outputs": []
100 | },
101 | {
102 | "cell_type": "code",
103 | "metadata": {
104 | "id": "gsXaxQOrJxHe"
105 | },
106 | "source": [
107 | "# Utilities for loss implementation.\n",
108 | "\n",
109 | "\n",
110 | "def get_logsumexp_loss(states, temperature):\n",
111 | " scores = tf.matmul(states, states, transpose_b=True) # (bsz, bsz)\n",
112 | " bias = tf.math.log(tf.cast(tf.shape(states)[1], tf.float32)) # a constant\n",
113 | " return tf.reduce_mean(\n",
114 | " tf.math.reduce_logsumexp(scores / temperature, 1) - bias)\n",
115 | "\n",
116 | "\n",
117 | "def sort(x):\n",
118 | " \"\"\"Returns the matrix x where each row is sorted (ascending).\"\"\"\n",
119 | " xshape = tf.shape(x)\n",
120 | " rank = tf.reduce_sum(\n",
121 | " tf.cast(tf.expand_dims(x, 2) > tf.expand_dims(x, 1), tf.int32), axis=2)\n",
122 | " rank_inv = tf.einsum(\n",
123 | " 'dbc,c->db',\n",
124 | " tf.transpose(tf.cast(tf.one_hot(rank, xshape[1]), tf.float32), [0, 2, 1]),\n",
125 | " tf.range(xshape[1], dtype='float32')) # (dim, bsz)\n",
126 | " x = tf.gather(x, tf.cast(rank_inv, tf.int32), axis=-1, batch_dims=-1)\n",
127 | " return x\n",
128 | "\n",
129 | "\n",
130 | "def get_swd_loss(states, rand_w, prior='normal', stddev=1., hidden_norm=True):\n",
131 | " states_shape = tf.shape(states)\n",
132 | " states = tf.matmul(states, rand_w)\n",
133 | " states_t = sort(tf.transpose(states)) # (dim, bsz)\n",
134 | "\n",
135 | " if prior == 'normal':\n",
136 | " states_prior = tf.random.normal(states_shape, mean=0, stddev=stddev)\n",
137 | " elif prior == 'uniform':\n",
138 | " states_prior = tf.random.uniform(states_shape, -stddev, stddev)\n",
139 | " else:\n",
140 | " raise ValueError('Unknown prior {}'.format(prior))\n",
141 | " if hidden_norm:\n",
142 | " states_prior = tf.math.l2_normalize(states_prior, -1)\n",
143 | " states_prior = tf.matmul(states_prior, rand_w)\n",
144 | " states_prior_t = sort(tf.transpose(states_prior)) # (dim, bsz)\n",
145 | "\n",
146 | " return tf.reduce_mean((states_prior_t - states_t)**2)"
147 | ],
148 | "execution_count": null,
149 | "outputs": []
150 | }
151 | ]
152 | }
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Data pipeline."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | from absl import flags
24 |
25 | import data_util as data_util
26 | import tensorflow.compat.v1 as tf
27 |
28 | FLAGS = flags.FLAGS
29 |
30 |
31 | def pad_to_batch(dataset, batch_size):
32 | """Pad Tensors to specified batch size.
33 |
34 | Args:
35 | dataset: An instance of tf.data.Dataset.
36 | batch_size: The number of samples per batch of input requested.
37 |
38 | Returns:
39 | An instance of tf.data.Dataset that yields the same Tensors with the same
40 | structure as the original padded to batch_size along the leading
41 | dimension.
42 |
43 | Raises:
44 | ValueError: If the dataset does not comprise any tensors; if a tensor
45 | yielded by the dataset has an unknown number of dimensions or is a
46 | scalar; or if it can be statically determined that tensors comprising
47 | a single dataset element will have different leading dimensions.
48 | """
49 | def _pad_to_batch(*args):
50 | """Given Tensors yielded by a Dataset, pads all to the batch size."""
51 | flat_args = tf.nest.flatten(args)
52 |
53 | for tensor in flat_args:
54 | if tensor.shape.ndims is None:
55 | raise ValueError(
56 | 'Unknown number of dimensions for tensor %s.' % tensor.name)
57 | if tensor.shape.ndims == 0:
58 | raise ValueError('Tensor %s is a scalar.' % tensor.name)
59 |
60 | # This will throw if flat_args is empty. However, as of this writing,
61 | # tf.data.Dataset.map will throw first with an internal error, so we do
62 | # not check this case explicitly.
63 | first_tensor = flat_args[0]
64 | first_tensor_shape = tf.shape(first_tensor)
65 | first_tensor_batch_size = first_tensor_shape[0]
66 | difference = batch_size - first_tensor_batch_size
67 |
68 | for i, tensor in enumerate(flat_args):
69 | control_deps = []
70 | if i != 0:
71 | # Check that leading dimensions of this tensor matches the first,
72 | # either statically or dynamically. (If the first dimensions of both
73 | # tensors are statically known, then we have to check the static
74 | # shapes at graph construction time or else we will never get to the
75 | # dynamic assertion.)
76 | if (first_tensor.shape[:1].is_fully_defined() and
77 | tensor.shape[:1].is_fully_defined()):
78 | if first_tensor.shape[0] != tensor.shape[0]:
79 | raise ValueError(
80 | 'Batch size of dataset tensors does not match. %s '
81 | 'has shape %s, but %s has shape %s' % (
82 | first_tensor.name, first_tensor.shape,
83 | tensor.name, tensor.shape))
84 | else:
85 | curr_shape = tf.shape(tensor)
86 | control_deps = [tf.Assert(
87 | tf.equal(curr_shape[0], first_tensor_batch_size),
88 | ['Batch size of dataset tensors %s and %s do not match. '
89 | 'Shapes are' % (tensor.name, first_tensor.name), curr_shape,
90 | first_tensor_shape])]
91 |
92 | with tf.control_dependencies(control_deps):
93 | # Pad to batch_size along leading dimension.
94 | flat_args[i] = tf.pad(
95 | tensor, [[0, difference]] + [[0, 0]] * (tensor.shape.ndims - 1))
96 | flat_args[i].set_shape([batch_size] + tensor.shape.as_list()[1:])
97 |
98 | return tf.nest.pack_sequence_as(args, flat_args)
99 |
100 | return dataset.map(_pad_to_batch)
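
# Example: with batch_size=4, a dataset whose final batch has only 3 elements
# gets that batch zero-padded back up to 4 along the leading dimension:
#   ds = tf.data.Dataset.range(7).batch(4)  # element shapes: (4,), (3,)
#   ds = pad_to_batch(ds, 4)                # element shapes: (4,), (4,)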
101 |
102 |
103 | def build_input_fn(builder, is_training):
104 | """Build input function.
105 |
106 | Args:
107 | builder: TFDS builder for specified dataset.
108 | is_training: Whether to build in training mode.
109 |
110 | Returns:
111 | A function that accepts a dict of params and returns a tuple of images and
112 | features, to be used as the input_fn in TPUEstimator.
113 | """
114 | def _input_fn(params):
115 | """Inner input function."""
116 | preprocess_fn_pretrain = get_preprocess_fn(is_training, is_pretrain=True)
117 | preprocess_fn_finetune = get_preprocess_fn(is_training, is_pretrain=False)
118 | num_classes = builder.info.features['label'].num_classes
119 |
120 | def map_fn(image, label):
121 | """Produces multiple transformations of the same batch."""
122 | if FLAGS.train_mode == 'pretrain':
123 | xs = []
124 | for _ in range(2): # Two transformations
125 | xs.append(preprocess_fn_pretrain(image))
126 | image = tf.concat(xs, -1)
127 | label = tf.zeros([num_classes])
128 | else:
129 | image = preprocess_fn_finetune(image)
130 | label = tf.one_hot(label, num_classes)
131 | return image, label, 1.0
132 |
133 | dataset = builder.as_dataset(
134 | split=FLAGS.train_split if is_training else FLAGS.eval_split,
135 | shuffle_files=is_training, as_supervised=True)
136 | if FLAGS.cache_dataset:
137 | dataset = dataset.cache()
138 | if is_training:
139 | buffer_multiplier = 50 if FLAGS.image_size <= 32 else 10
140 | dataset = dataset.shuffle(params['batch_size'] * buffer_multiplier)
141 | dataset = dataset.repeat(-1)
142 | dataset = dataset.map(map_fn,
143 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
144 | dataset = dataset.batch(params['batch_size'], drop_remainder=is_training)
145 | dataset = pad_to_batch(dataset, params['batch_size'])
146 | images, labels, mask = tf.data.make_one_shot_iterator(dataset).get_next()
147 |
148 | return images, {'labels': labels, 'mask': mask}
149 | return _input_fn
150 |
151 |
152 | def get_preprocess_fn(is_training, is_pretrain):
153 | """Get function that accepts an image and returns a preprocessed image."""
154 | # Disable test cropping for small images (e.g. CIFAR)
155 | if FLAGS.image_size <= 32:
156 | test_crop = False
157 | else:
158 | test_crop = True
159 | return functools.partial(
160 | data_util.preprocess_image,
161 | height=FLAGS.image_size,
162 | width=FLAGS.image_size,
163 | is_training=is_training,
164 | color_distort=is_pretrain,
165 | test_crop=test_crop)
166 |
--------------------------------------------------------------------------------
/data_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Data preprocessing and augmentation."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | from absl import flags
24 |
25 | import tensorflow.compat.v1 as tf
26 |
27 | FLAGS = flags.FLAGS
28 |
29 | CROP_PROPORTION = 0.875 # Standard for ImageNet.
30 |
31 |
32 | def random_apply(func, p, x):
33 | """Randomly apply function func to x with probability p."""
34 | return tf.cond(
35 | tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32),
36 | tf.cast(p, tf.float32)),
37 | lambda: func(x),
38 | lambda: x)
39 |
40 |
41 | def random_brightness(image, max_delta, impl='simclrv2'):
42 | """A multiplicative vs additive change of brightness."""
43 | if impl == 'simclrv2':
44 | factor = tf.random_uniform(
45 | [], tf.maximum(1.0 - max_delta, 0), 1.0 + max_delta)
46 | image = image * factor
47 | elif impl == 'simclrv1':
48 | image = tf.image.random_brightness(image, max_delta=max_delta)
49 | else:
50 | raise ValueError('Unknown impl {} for random brightness.'.format(impl))
51 | return image
52 |
53 |
54 | def to_grayscale(image, keep_channels=True):
55 | image = tf.image.rgb_to_grayscale(image)
56 | if keep_channels:
57 | image = tf.tile(image, [1, 1, 3])
58 | return image
59 |
60 |
61 | def color_jitter(image, strength, random_order=True, impl='simclrv2'):
62 | """Distorts the color of the image.
63 |
64 | Args:
65 | image: The input image tensor.
66 | strength: the floating number for the strength of the color augmentation.
67 | random_order: A bool, specifying whether to randomize the jittering order.
68 | impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
69 | version of random brightness.
70 |
71 | Returns:
72 | The distorted image tensor.
73 | """
74 | brightness = 0.8 * strength
75 | contrast = 0.8 * strength
76 | saturation = 0.8 * strength
77 | hue = 0.2 * strength
78 | if random_order:
79 | return color_jitter_rand(
80 | image, brightness, contrast, saturation, hue, impl=impl)
81 | else:
82 | return color_jitter_nonrand(
83 | image, brightness, contrast, saturation, hue, impl=impl)
84 |
85 |
86 | def color_jitter_nonrand(image,
87 | brightness=0,
88 | contrast=0,
89 | saturation=0,
90 | hue=0,
91 | impl='simclrv2'):
92 | """Distorts the color of the image (jittering order is fixed).
93 |
94 | Args:
95 | image: The input image tensor.
96 | brightness: A float, specifying the brightness for color jitter.
97 | contrast: A float, specifying the contrast for color jitter.
98 | saturation: A float, specifying the saturation for color jitter.
99 | hue: A float, specifying the hue for color jitter.
100 | impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
101 | version of random brightness.
102 |
103 | Returns:
104 | The distorted image tensor.
105 | """
106 | with tf.name_scope('distort_color'):
107 | def apply_transform(i, x, brightness, contrast, saturation, hue):
108 | """Apply the i-th transformation."""
109 | if brightness != 0 and i == 0:
110 | x = random_brightness(x, max_delta=brightness, impl=impl)
111 | elif contrast != 0 and i == 1:
112 | x = tf.image.random_contrast(
113 | x, lower=1-contrast, upper=1+contrast)
114 | elif saturation != 0 and i == 2:
115 | x = tf.image.random_saturation(
116 | x, lower=1-saturation, upper=1+saturation)
117 | elif hue != 0:
118 | x = tf.image.random_hue(x, max_delta=hue)
119 | return x
120 |
121 | for i in range(4):
122 | image = apply_transform(i, image, brightness, contrast, saturation, hue)
123 | image = tf.clip_by_value(image, 0., 1.)
124 | return image
125 |
126 |
127 | def color_jitter_rand(image,
128 | brightness=0,
129 | contrast=0,
130 | saturation=0,
131 | hue=0,
132 | impl='simclrv2'):
133 | """Distorts the color of the image (jittering order is random).
134 |
135 | Args:
136 | image: The input image tensor.
137 | brightness: A float, specifying the brightness for color jitter.
138 | contrast: A float, specifying the contrast for color jitter.
139 | saturation: A float, specifying the saturation for color jitter.
140 | hue: A float, specifying the hue for color jitter.
141 | impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
142 | version of random brightness.
143 |
144 | Returns:
145 | The distorted image tensor.
146 | """
147 | with tf.name_scope('distort_color'):
148 | def apply_transform(i, x):
149 | """Apply the i-th transformation."""
150 | def brightness_foo():
151 | if brightness == 0:
152 | return x
153 | else:
154 | return random_brightness(x, max_delta=brightness, impl=impl)
155 |
156 | def contrast_foo():
157 | if contrast == 0:
158 | return x
159 | else:
160 | return tf.image.random_contrast(x, lower=1-contrast, upper=1+contrast)
161 | def saturation_foo():
162 | if saturation == 0:
163 | return x
164 | else:
165 | return tf.image.random_saturation(
166 | x, lower=1-saturation, upper=1+saturation)
167 | def hue_foo():
168 | if hue == 0:
169 | return x
170 | else:
171 | return tf.image.random_hue(x, max_delta=hue)
172 | x = tf.cond(tf.less(i, 2),
173 | lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
174 | lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo))
175 | return x
176 |
177 | perm = tf.random_shuffle(tf.range(4))
178 | for i in range(4):
179 | image = apply_transform(perm[i], image)
180 | image = tf.clip_by_value(image, 0., 1.)
181 | return image
182 |
183 |
184 | def _compute_crop_shape(
185 | image_height, image_width, aspect_ratio, crop_proportion):
186 | """Compute aspect ratio-preserving shape for central crop.
187 |
188 | The resulting shape retains `crop_proportion` along one side and a proportion
189 | less than or equal to `crop_proportion` along the other side.
190 |
191 | Args:
192 | image_height: Height of image to be cropped.
193 | image_width: Width of image to be cropped.
194 | aspect_ratio: Desired aspect ratio (width / height) of output.
195 | crop_proportion: Proportion of image to retain along the less-cropped side.
196 |
197 | Returns:
198 | crop_height: Height of image after cropping.
199 | crop_width: Width of image after cropping.
200 | """
201 | image_width_float = tf.cast(image_width, tf.float32)
202 | image_height_float = tf.cast(image_height, tf.float32)
203 |
204 | def _requested_aspect_ratio_wider_than_image():
205 | crop_height = tf.cast(tf.rint(
206 | crop_proportion / aspect_ratio * image_width_float), tf.int32)
207 | crop_width = tf.cast(tf.rint(
208 | crop_proportion * image_width_float), tf.int32)
209 | return crop_height, crop_width
210 |
211 | def _image_wider_than_requested_aspect_ratio():
212 | crop_height = tf.cast(
213 | tf.rint(crop_proportion * image_height_float), tf.int32)
214 | crop_width = tf.cast(tf.rint(
215 | crop_proportion * aspect_ratio *
216 | image_height_float), tf.int32)
217 | return crop_height, crop_width
218 |
219 | return tf.cond(
220 | aspect_ratio > image_width_float / image_height_float,
221 | _requested_aspect_ratio_wider_than_image,
222 | _image_wider_than_requested_aspect_ratio)
223 |
224 |
225 | def center_crop(image, height, width, crop_proportion):
226 | """Crops to center of image and rescales to desired size.
227 |
228 | Args:
229 | image: Image Tensor to crop.
230 | height: Height of image to be cropped.
231 | width: Width of image to be cropped.
232 | crop_proportion: Proportion of image to retain along the less-cropped side.
233 |
234 | Returns:
235 | A `height` x `width` x channels Tensor holding a central crop of `image`.
236 | """
237 | shape = tf.shape(image)
238 | image_height = shape[0]
239 | image_width = shape[1]
240 | crop_height, crop_width = _compute_crop_shape(
241 | image_height, image_width, width / height, crop_proportion)
242 | offset_height = ((image_height - crop_height) + 1) // 2
243 | offset_width = ((image_width - crop_width) + 1) // 2
244 | image = tf.image.crop_to_bounding_box(
245 | image, offset_height, offset_width, crop_height, crop_width)
246 |
247 | image = tf.image.resize_bicubic([image], [height, width])[0]
248 |
249 | return image
250 |
251 |
252 | def distorted_bounding_box_crop(image,
253 | bbox,
254 | min_object_covered=0.1,
255 | aspect_ratio_range=(0.75, 1.33),
256 | area_range=(0.05, 1.0),
257 | max_attempts=100,
258 | scope=None):
259 | """Generates cropped_image using one of the bboxes randomly distorted.
260 |
261 | See `tf.image.sample_distorted_bounding_box` for more documentation.
262 |
263 | Args:
264 | image: `Tensor` of image data.
265 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
266 | where each coordinate is [0, 1) and the coordinates are arranged
267 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
268 | image.
269 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
270 | area of the image must contain at least this fraction of any bounding
271 | box supplied.
272 | aspect_ratio_range: An optional list of `float`s. The cropped area of the
273 | image must have an aspect ratio = width / height within this range.
274 | area_range: An optional list of `float`s. The cropped area of the image
275 | must contain a fraction of the supplied image within this range.
276 | max_attempts: An optional `int`. Number of attempts at generating a cropped
277 | region of the image of the specified constraints. After `max_attempts`
278 | failures, return the entire image.
279 | scope: Optional `str` for name scope.
280 | Returns:
281 | (cropped image `Tensor`, distorted bbox `Tensor`).
282 | """
283 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
284 | shape = tf.shape(image)
285 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
286 | shape,
287 | bounding_boxes=bbox,
288 | min_object_covered=min_object_covered,
289 | aspect_ratio_range=aspect_ratio_range,
290 | area_range=area_range,
291 | max_attempts=max_attempts,
292 | use_image_if_no_bounding_boxes=True)
293 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box
294 |
295 | # Crop the image to the specified bounding box.
296 | offset_y, offset_x, _ = tf.unstack(bbox_begin)
297 | target_height, target_width, _ = tf.unstack(bbox_size)
298 | image = tf.image.crop_to_bounding_box(
299 | image, offset_y, offset_x, target_height, target_width)
300 |
301 | return image
302 |
303 |
304 | def crop_and_resize(image, height, width):
305 | """Make a random crop and resize it to height `height` and width `width`.
306 |
307 | Args:
308 | image: Tensor representing the image.
309 | height: Desired image height.
310 | width: Desired image width.
311 |
312 | Returns:
313 | A `height` x `width` x channels Tensor holding a random crop of `image`.
314 | """
315 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
316 | aspect_ratio = width / height
317 | image = distorted_bounding_box_crop(
318 | image,
319 | bbox,
320 | min_object_covered=0.1,
321 | aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. * aspect_ratio),
322 | area_range=(0.08, 1.0),
323 | max_attempts=100,
324 | scope=None)
325 | return tf.image.resize_bicubic([image], [height, width])[0]
326 |
327 |
328 | def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
329 | """Blurs the given image with separable convolution.
330 |
331 |
332 | Args:
333 | image: Tensor of shape [height, width, channels] and dtype float to blur.
334 | kernel_size: Integer Tensor for the size of the blur kernel. This should
335 | be an odd number. If it is an even number, the actual kernel size will be
336 | size + 1.
337 | sigma: Sigma value for gaussian operator.
338 | padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.
339 |
340 | Returns:
341 | A Tensor representing the blurred image.
342 | """
343 | radius = tf.to_int32(kernel_size / 2)
344 | kernel_size = radius * 2 + 1
345 | x = tf.to_float(tf.range(-radius, radius + 1))
346 | blur_filter = tf.exp(
347 | -tf.pow(x, 2.0) / (2.0 * tf.pow(tf.to_float(sigma), 2.0)))
348 | blur_filter /= tf.reduce_sum(blur_filter)
349 | # One vertical and one horizontal filter.
350 | blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
351 | blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
352 | num_channels = tf.shape(image)[-1]
353 | blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
354 | blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
355 | expand_batch_dim = image.shape.ndims == 3
356 | if expand_batch_dim:
357 | # Tensorflow requires batched input to convolutions, which we can fake with
358 | # an extra dimension.
359 | image = tf.expand_dims(image, axis=0)
360 | blurred = tf.nn.depthwise_conv2d(
361 | image, blur_h, strides=[1, 1, 1, 1], padding=padding)
362 | blurred = tf.nn.depthwise_conv2d(
363 | blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
364 | if expand_batch_dim:
365 | blurred = tf.squeeze(blurred, axis=0)
366 | return blurred
367 |
368 |
369 | def random_crop_with_resize(image, height, width, p=1.0):
370 | """Randomly crop and resize an image.
371 |
372 | Args:
373 | image: `Tensor` representing an image of arbitrary size.
374 | height: Height of output image.
375 | width: Width of output image.
376 | p: Probability of applying this transformation.
377 |
378 | Returns:
379 | A preprocessed image `Tensor`.
380 | """
381 | def _transform(image): # pylint: disable=missing-docstring
382 | image = crop_and_resize(image, height, width)
383 | return image
384 | return random_apply(_transform, p=p, x=image)
385 |
386 |
387 | def random_color_jitter(image, p=1.0, impl='simclrv2'):
388 |
389 | def _transform(image):
390 | color_jitter_t = functools.partial(
391 | color_jitter, strength=FLAGS.color_jitter_strength, impl=impl)
392 | image = random_apply(color_jitter_t, p=0.8, x=image)
393 | return random_apply(to_grayscale, p=0.2, x=image)
394 | return random_apply(_transform, p=p, x=image)
395 |
396 |
397 | def random_blur(image, height, width, p=1.0):
398 | """Randomly blur an image.
399 |
400 | Args:
401 | image: `Tensor` representing an image of arbitrary size.
402 | height: Height of output image.
403 | width: Width of output image.
404 | p: probability of applying this transformation.
405 |
406 | Returns:
407 | A preprocessed image `Tensor`.
408 | """
409 | del width
410 | def _transform(image):
411 | sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
412 | return gaussian_blur(
413 | image, kernel_size=height//10, sigma=sigma, padding='SAME')
414 | return random_apply(_transform, p=p, x=image)
415 |
416 |
417 | def batch_random_blur(images_list, height, width, blur_probability=0.5):
418 | """Apply efficient batch data transformations.
419 |
420 | Args:
421 | images_list: a list of image tensors.
422 | height: the height of image.
423 | width: the width of image.
424 | blur_probability: the probability of applying the blur operator.
425 |
426 | Returns:
427 | Preprocessed feature list.
428 | """
429 | def generate_selector(p, bsz):
430 | shape = [bsz, 1, 1, 1]
431 | selector = tf.cast(
432 | tf.less(tf.random_uniform(shape, 0, 1, dtype=tf.float32), p),
433 | tf.float32)
434 | return selector
435 |
436 | new_images_list = []
437 | for images in images_list:
438 | images_new = random_blur(images, height, width, p=1.)
439 | selector = generate_selector(blur_probability, tf.shape(images)[0])
440 | images = images_new * selector + images * (1 - selector)
441 | images = tf.clip_by_value(images, 0., 1.)
442 | new_images_list.append(images)
443 |
444 | return new_images_list
445 |
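# A numpy sketch (illustration only) of the batched selector trick above: the
# blur is computed for every image, then a per-image 0/1 selector of shape
# (bsz, 1, 1, 1) broadcasts over (h, w, c) to keep either the blurred or the
# original version, avoiding per-example control flow.
import numpy as np

bsz, h, w, c = 4, 8, 8, 3
images = np.random.rand(bsz, h, w, c).astype(np.float32)
blurred = images * 0.5  # stand-in for the blurred branch
selector = (np.random.rand(bsz, 1, 1, 1) < 0.5).astype(np.float32)
mixed = blurred * selector + images * (1.0 - selector)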
446 |
447 | def preprocess_for_train(image,
448 | height,
449 | width,
450 | color_distort=True,
451 | crop=True,
452 | flip=True,
453 | impl='simclrv2'):
454 | """Preprocesses the given image for training.
455 |
456 | Args:
457 | image: `Tensor` representing an image of arbitrary size.
458 | height: Height of output image.
459 | width: Width of output image.
460 | color_distort: Whether to apply the color distortion.
461 | crop: Whether to crop the image.
462 | flip: Whether or not to randomly flip the image left-right.
463 | impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
464 | version of random brightness.
465 |
466 | Returns:
467 | A preprocessed image `Tensor`.
468 | """
469 | if crop:
470 | image = random_crop_with_resize(image, height, width)
471 | if flip:
472 | image = tf.image.random_flip_left_right(image)
473 | if color_distort:
474 | image = random_color_jitter(image, impl=impl)
475 | image = tf.reshape(image, [height, width, 3])
476 | image = tf.clip_by_value(image, 0., 1.)
477 | return image
478 |
479 |
480 | def preprocess_for_eval(image, height, width, crop=True):
481 | """Preprocesses the given image for evaluation.
482 |
483 | Args:
484 | image: `Tensor` representing an image of arbitrary size.
485 | height: Height of output image.
486 | width: Width of output image.
487 | crop: Whether or not to (center) crop the test images.
488 |
489 | Returns:
490 | A preprocessed image `Tensor`.
491 | """
492 | if crop:
493 | image = center_crop(image, height, width, crop_proportion=CROP_PROPORTION)
494 | image = tf.reshape(image, [height, width, 3])
495 | image = tf.clip_by_value(image, 0., 1.)
496 | return image
497 |
498 |
499 | def preprocess_image(image, height, width, is_training=False,
500 | color_distort=True, test_crop=True):
501 | """Preprocesses the given image.
502 |
503 | Args:
504 | image: `Tensor` representing an image of arbitrary size.
505 | height: Height of output image.
506 | width: Width of output image.
507 | is_training: `bool` for whether the preprocessing is for training.
508 | color_distort: whether to apply the color distortion.
509 | test_crop: whether or not to extract a central crop of the images
510 | (as for standard ImageNet evaluation) during the evaluation.
511 |
512 | Returns:
513 | A preprocessed image `Tensor` of range [0, 1].
514 | """
515 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
516 | if is_training:
517 | return preprocess_for_train(image, height, width, color_distort)
518 | else:
519 | return preprocess_for_eval(image, height, width, test_crop)
520 |
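# A minimal usage sketch (illustration only): in eval mode, preprocess_image
# depends only on module-level constants, so it can be applied directly to a
# uint8 image tensor; the result is a float32 (224, 224, 3) tensor in [0, 1].
import tensorflow.compat.v1 as tf

raw = tf.cast(
    tf.random.uniform([300, 400, 3], maxval=256, dtype=tf.int32), tf.uint8)
img = preprocess_image(raw, height=224, width=224, is_training=False)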
--------------------------------------------------------------------------------
/lars_optimizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Functions and classes related to optimization (weight updates)."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import re
23 |
24 | import tensorflow.compat.v1 as tf
25 |
26 | EETA_DEFAULT = 0.001
27 |
28 |
29 | class LARSOptimizer(tf.train.Optimizer):
30 | """Layer-wise Adaptive Rate Scaling for large batch training.
31 |
32 | Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
33 | I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
34 | """
35 |
36 | def __init__(self,
37 | learning_rate,
38 | momentum=0.9,
39 | use_nesterov=False,
40 | weight_decay=0.0,
41 | exclude_from_weight_decay=None,
42 | exclude_from_layer_adaptation=None,
43 | classic_momentum=True,
44 | eeta=EETA_DEFAULT,
45 | name="LARSOptimizer"):
46 | """Constructs a LARSOptimizer.
47 |
48 | Args:
49 | learning_rate: A `float` for learning rate.
50 | momentum: A `float` for momentum.
51 | use_nesterov: A `bool` for whether to use Nesterov momentum.
52 | weight_decay: A `float` for weight decay.
53 | exclude_from_weight_decay: A list of `string` for variable screening, if
54 | any of the string appears in a variable's name, the variable will be
55 | excluded for computing weight decay. For example, one could specify
56 | the list like ['batch_normalization', 'bias'] to exclude BN and bias
57 | from weight decay.
58 | exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
59 | for layer adaptation. If None, it defaults to the same list as
60 | exclude_from_weight_decay.
61 | classic_momentum: A `bool` for whether to use classic (heavy-ball)
62 | momentum. With classic momentum the learning rate is applied when
63 | updating the momentum buffer; otherwise it is applied after the update.
64 | eeta: A `float` for scaling of learning rate when computing trust ratio.
65 | name: The name for the scope.
66 | """
67 | super(LARSOptimizer, self).__init__(False, name)
68 |
69 | self.learning_rate = learning_rate
70 | self.momentum = momentum
71 | self.weight_decay = weight_decay
72 | self.use_nesterov = use_nesterov
73 | self.classic_momentum = classic_momentum
74 | self.eeta = eeta
75 | self.exclude_from_weight_decay = exclude_from_weight_decay
76 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
77 | # arg is None.
78 | if exclude_from_layer_adaptation:
79 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
80 | else:
81 | self.exclude_from_layer_adaptation = exclude_from_weight_decay
82 |
83 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
84 | assignments = []
85 | for (grad, param) in grads_and_vars:
86 | if grad is None or param is None:
87 | continue
88 |
89 | param_name = param.op.name
90 |
91 | v = tf.get_variable(
92 | name=param_name + "/Momentum",
93 | shape=param.shape.as_list(),
94 | dtype=tf.float32,
95 | trainable=False,
96 | initializer=tf.zeros_initializer())
97 |
98 | if self._use_weight_decay(param_name):
99 | grad += self.weight_decay * param
100 |
101 | if self.classic_momentum:
102 | trust_ratio = 1.0
103 | if self._do_layer_adaptation(param_name):
104 | w_norm = tf.norm(param, ord=2)
105 | g_norm = tf.norm(grad, ord=2)
106 | trust_ratio = tf.where(
107 | tf.greater(w_norm, 0), tf.where(
108 | tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm),
109 | 1.0),
110 | 1.0)
111 | scaled_lr = self.learning_rate * trust_ratio
112 |
113 | next_v = tf.multiply(self.momentum, v) + scaled_lr * grad
114 | if self.use_nesterov:
115 | update = tf.multiply(self.momentum, next_v) + scaled_lr * grad
116 | else:
117 | update = next_v
118 | next_param = param - update
119 | else:
120 | next_v = tf.multiply(self.momentum, v) + grad
121 | if self.use_nesterov:
122 | update = tf.multiply(self.momentum, next_v) + grad
123 | else:
124 | update = next_v
125 |
126 | trust_ratio = 1.0
127 | if self._do_layer_adaptation(param_name):
128 | w_norm = tf.norm(param, ord=2)
129 | v_norm = tf.norm(update, ord=2)
130 | trust_ratio = tf.where(
131 | tf.greater(w_norm, 0), tf.where(
132 | tf.greater(v_norm, 0), (self.eeta * w_norm / v_norm),
133 | 1.0),
134 | 1.0)
135 | scaled_lr = trust_ratio * self.learning_rate
136 | next_param = param - scaled_lr * update
137 |
138 | assignments.extend([param.assign(next_param), v.assign(next_v)])
139 |
140 | if global_step is not None:
141 | new_global_step = global_step + 1
142 | assignments.append(global_step.assign(new_global_step))
143 | return tf.group(*assignments, name=name)
144 |
145 | def _use_weight_decay(self, param_name):
146 | """Whether to use L2 weight decay for `param_name`."""
147 | if not self.weight_decay:
148 | return False
149 | if self.exclude_from_weight_decay:
150 | for r in self.exclude_from_weight_decay:
151 | if re.search(r, param_name) is not None:
152 | return False
153 | return True
154 |
155 | def _do_layer_adaptation(self, param_name):
156 | """Whether to do layer-wise learning rate adaptation for `param_name`."""
157 | if self.exclude_from_layer_adaptation:
158 | for r in self.exclude_from_layer_adaptation:
159 | if re.search(r, param_name) is not None:
160 | return False
161 | return True
162 |
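# A minimal numpy sketch (illustration only) of the classic-momentum LARS
# update implemented above: weight decay is folded into the gradient, the
# layer's learning rate is scaled by the trust ratio eeta * ||w|| / ||g||,
# and the scaled step is accumulated into the momentum buffer.
import numpy as np

lr, momentum, eeta, weight_decay = 0.1, 0.9, 0.001, 1e-4
w = np.random.randn(10)                      # one layer's weights
g = np.random.randn(10) + weight_decay * w   # decayed gradient
v = np.zeros_like(w)                         # momentum buffer

trust_ratio = eeta * np.linalg.norm(w) / np.linalg.norm(g)
v = momentum * v + (lr * trust_ratio) * g
w = w - v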
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Model specification for SimCLR."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 |
24 | import data_util as data_util
25 | import model_util as model_util
26 | import objective as obj_lib
27 |
28 | import tensorflow.compat.v1 as tf
29 | from tensorflow.compat.v1 import estimator as tf_estimator
30 | import tensorflow.compat.v2 as tf2
31 |
32 | FLAGS = flags.FLAGS
33 |
34 |
35 | def build_model_fn(model, num_classes, num_train_examples):
36 | """Build model function."""
37 | def model_fn(features, labels, mode, params=None):
38 | """Build model and optimizer."""
39 | is_training = mode == tf_estimator.ModeKeys.TRAIN
40 |
41 | # Check training mode.
42 | if FLAGS.train_mode == 'pretrain':
43 | num_transforms = 2
44 | if FLAGS.fine_tune_after_block > -1:
45 | raise ValueError('Does not support layer freezing during pretraining; '
46 | 'set fine_tune_after_block<=-1 for safety.')
47 | elif FLAGS.train_mode == 'finetune':
48 | num_transforms = 1
49 | else:
50 | raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))
51 |
52 | # Split channels, and optionally apply extra batched augmentation.
53 | features_list = tf.split(
54 | features, num_or_size_splits=num_transforms, axis=-1)
55 | if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
56 | features_list = data_util.batch_random_blur(
57 | features_list, FLAGS.image_size, FLAGS.image_size)
58 | features = tf.concat(features_list, 0) # (num_transforms * bsz, h, w, c)
59 |
60 | # Base network forward pass.
61 | with tf.variable_scope('base_model'):
62 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
63 | # Fine-tuning just the supervised (linear) head does not update BN stats.
64 | model_train_mode = False
65 | else:
66 | # Pretraining, or fine-tuning anything else, updates BN stats.
67 | model_train_mode = is_training
68 | hiddens = model(features, is_training=model_train_mode)
69 |
70 | # Add head and loss.
71 | if FLAGS.train_mode == 'pretrain':
72 | tpu_context = params['context'] if 'context' in params else None
73 | hiddens_proj = model_util.projection_head(hiddens, is_training)
74 | contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
75 | hiddens_proj,
76 | hidden_norm=FLAGS.hidden_norm,
77 | temperature=FLAGS.temperature,
78 | tpu_context=tpu_context if is_training else None)
79 | logits_sup = tf.zeros([params['batch_size'], num_classes])
80 | else:
81 | contrast_loss = tf.zeros([])
82 | logits_con = tf.zeros([params['batch_size'], 10])
83 | labels_con = tf.zeros([params['batch_size'], 10])
84 | hiddens = model_util.projection_head(hiddens, is_training)
85 | logits_sup = model_util.supervised_head(
86 | hiddens, num_classes, is_training)
87 | obj_lib.add_supervised_loss(
88 | labels=labels['labels'],
89 | logits=logits_sup,
90 | weights=labels['mask'])
91 |
92 | # Add weight decay to loss, for non-LARS optimizers.
93 | model_util.add_weight_decay(adjust_per_optimizer=True)
94 | loss = tf.losses.get_total_loss()
95 |
96 | if FLAGS.train_mode == 'pretrain':
97 | variables_to_train = tf.trainable_variables()
98 | else:
99 | collection_prefix = 'trainable_variables_inblock_'
100 | variables_to_train = []
101 | for j in range(FLAGS.fine_tune_after_block + 1, 6):
102 | variables_to_train += tf.get_collection(collection_prefix + str(j))
103 | assert variables_to_train, 'variables_to_train shouldn\'t be empty!'
104 |
105 | tf.logging.info('===============Variables to train (begin)===============')
106 | tf.logging.info(variables_to_train)
107 | tf.logging.info('================Variables to train (end)================')
108 |
109 | learning_rate = model_util.learning_rate_schedule(
110 | FLAGS.learning_rate, num_train_examples)
111 |
112 | if is_training:
113 | if FLAGS.train_summary_steps > 0:
114 | # Compute stats for the summary.
115 | prob_con = tf.nn.softmax(logits_con)
116 | entropy_con = - tf.reduce_mean(
117 | tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))
118 |
119 | summary_writer = tf2.summary.create_file_writer(FLAGS.model_dir)
120 | # TODO(iamtingchen): remove this control_dependencies in the future.
121 | with tf.control_dependencies([summary_writer.init()]):
122 | with summary_writer.as_default():
123 | should_record = tf.math.equal(
124 | tf.math.floormod(tf.train.get_global_step(),
125 | FLAGS.train_summary_steps), 0)
126 | with tf2.summary.record_if(should_record):
127 | contrast_acc = tf.equal(
128 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1))
129 | contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))
130 | label_acc = tf.equal(
131 | tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1))
132 | label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
133 | tf2.summary.scalar(
134 | 'train_contrast_loss',
135 | contrast_loss,
136 | step=tf.train.get_global_step())
137 | tf2.summary.scalar(
138 | 'train_contrast_acc',
139 | contrast_acc,
140 | step=tf.train.get_global_step())
141 | tf2.summary.scalar(
142 | 'train_label_accuracy',
143 | label_acc,
144 | step=tf.train.get_global_step())
145 | tf2.summary.scalar(
146 | 'contrast_entropy',
147 | entropy_con,
148 | step=tf.train.get_global_step())
149 | tf2.summary.scalar(
150 | 'learning_rate', learning_rate,
151 | step=tf.train.get_global_step())
152 |
153 | optimizer = model_util.get_optimizer(learning_rate)
154 | control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
155 | if FLAGS.train_summary_steps > 0:
156 | control_deps.extend(tf.summary.all_v2_summary_ops())
157 | with tf.control_dependencies(control_deps):
158 | train_op = optimizer.minimize(
159 | loss, global_step=tf.train.get_or_create_global_step(),
160 | var_list=variables_to_train)
161 |
162 | if FLAGS.checkpoint:
163 | def scaffold_fn():
164 | """Scaffold function to restore non-logits vars from checkpoint."""
165 | tf.train.init_from_checkpoint(
166 | FLAGS.checkpoint,
167 | {v.op.name: v.op.name
168 | for v in tf.global_variables(FLAGS.variable_schema)})
169 |
170 | if FLAGS.zero_init_logits_layer:
171 | # Init op that initializes output layer parameters to zeros.
172 | output_layer_parameters = [
173 | var for var in tf.trainable_variables() if var.name.startswith(
174 | 'head_supervised')]
175 | tf.logging.info('Initializing output layer parameters %s to zero',
176 | [x.op.name for x in output_layer_parameters])
177 | with tf.control_dependencies([tf.global_variables_initializer()]):
178 | init_op = tf.group([
179 | tf.assign(x, tf.zeros_like(x))
180 | for x in output_layer_parameters])
181 | return tf.train.Scaffold(init_op=init_op)
182 | else:
183 | return tf.train.Scaffold()
184 | else:
185 | scaffold_fn = None
186 |
187 | return tf_estimator.tpu.TPUEstimatorSpec(
188 | mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)
189 | else:
190 |
191 | def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
192 | **kws):
193 | """Inner metric function."""
194 | metrics = {k: tf.metrics.mean(v, weights=mask)
195 | for k, v in kws.items()}
196 | metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
197 | tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
198 | weights=mask)
199 | metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
200 | tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
201 | metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
202 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1),
203 | weights=mask)
204 | metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
205 | tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)
206 | return metrics
207 |
208 | metrics = {
209 | 'logits_sup': logits_sup,
210 | 'labels_sup': labels['labels'],
211 | 'logits_con': logits_con,
212 | 'labels_con': labels_con,
213 | 'mask': labels['mask'],
214 | 'contrast_loss': tf.fill((params['batch_size'],), contrast_loss),
215 | 'regularization_loss': tf.fill((params['batch_size'],),
216 | tf.losses.get_regularization_loss()),
217 | }
218 |
219 | return tf_estimator.tpu.TPUEstimatorSpec(
220 | mode=mode,
221 | loss=loss,
222 | eval_metrics=(metric_fn, metrics),
223 | scaffold_fn=None)
224 |
225 | return model_fn
226 |
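# A shape sketch (illustration only) of the pretrain path in model_fn above:
# the input pipeline stacks the two augmented views along channels; model_fn
# splits them and re-concatenates along the batch axis, so the network and
# the contrastive loss see 2 * bsz examples.
import tensorflow.compat.v1 as tf

feats = tf.zeros([8, 224, 224, 6])    # bsz=8, two views stacked on channels
views = tf.split(feats, 2, axis=-1)   # two tensors of shape (8, 224, 224, 3)
stacked = tf.concat(views, axis=0)    # shape (16, 224, 224, 3)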
--------------------------------------------------------------------------------
/model_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Network architectures related functions used in SimCLR."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import math
23 | from absl import flags
24 |
25 | import resnet
26 | from lars_optimizer import LARSOptimizer
27 |
28 | import tensorflow.compat.v1 as tf
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | def add_weight_decay(adjust_per_optimizer=True):
34 | """Compute weight decay from flags."""
35 | if adjust_per_optimizer and 'lars' in FLAGS.optimizer:
36 | # Weight decay is taken care of by the optimizer in these cases,
37 | # except for the supervised head, which is added here.
38 | l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables()
39 | if 'head_supervised' in v.name and 'bias' not in v.name]
40 | if l2_losses:
41 | tf.losses.add_loss(
42 | FLAGS.weight_decay * tf.add_n(l2_losses),
43 | tf.GraphKeys.REGULARIZATION_LOSSES)
44 | return
45 |
46 | l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables()
47 | if 'batch_normalization' not in v.name]
48 | tf.losses.add_loss(
49 | FLAGS.weight_decay * tf.add_n(l2_losses),
50 | tf.GraphKeys.REGULARIZATION_LOSSES)
51 |
52 |
53 | def get_train_steps(num_examples):
54 | """Determine the number of training steps."""
55 | return FLAGS.train_steps or (
56 | num_examples * FLAGS.train_epochs // FLAGS.train_batch_size + 1)
57 |
58 |
59 | def learning_rate_schedule(base_learning_rate, num_examples):
60 | """Build learning rate schedule."""
61 | global_step = tf.train.get_or_create_global_step()
62 | warmup_steps = int(round(
63 | FLAGS.warmup_epochs * num_examples // FLAGS.train_batch_size))
64 | if FLAGS.learning_rate_scaling == 'linear':
65 | scaled_lr = base_learning_rate * FLAGS.train_batch_size / 256.
66 | elif FLAGS.learning_rate_scaling == 'sqrt':
67 | scaled_lr = base_learning_rate * math.sqrt(FLAGS.train_batch_size)
68 | else:
69 | raise ValueError('Unknown learning rate scaling {}'.format(
70 | FLAGS.learning_rate_scaling))
71 | learning_rate = (tf.to_float(global_step) / int(warmup_steps) * scaled_lr
72 | if warmup_steps else scaled_lr)
73 |
74 | # Cosine decay learning rate schedule
75 | total_steps = get_train_steps(num_examples)
76 | learning_rate = tf.where(
77 | global_step < warmup_steps, learning_rate,
78 | tf.train.cosine_decay(
79 | scaled_lr,
80 | global_step - warmup_steps,
81 | total_steps - warmup_steps))
82 |
83 | return learning_rate
84 |
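# A worked example (illustration only) of the schedule above, assuming
# base_learning_rate=0.3, learning_rate_scaling='linear',
# train_batch_size=4096 and warmup_epochs=10 on ImageNet (1,281,167 train
# examples):
#   scaled_lr    = 0.3 * 4096 / 256                 = 4.8
#   warmup_steps = int(round(10 * 1281167 // 4096)) = 3127
# The rate ramps linearly from 0 to 4.8 over the first 3127 steps, then
# follows a cosine decay toward 0 for the remaining steps.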
85 |
86 | def get_optimizer(learning_rate):
87 | """Returns an optimizer."""
88 | if FLAGS.optimizer == 'momentum':
89 | optimizer = tf.train.MomentumOptimizer(
90 | learning_rate, FLAGS.momentum, use_nesterov=True)
91 | elif FLAGS.optimizer == 'adam':
92 | optimizer = tf.train.AdamOptimizer(
93 | learning_rate)
94 | elif FLAGS.optimizer == 'lars':
95 | optimizer = LARSOptimizer(
96 | learning_rate,
97 | momentum=FLAGS.momentum,
98 | weight_decay=FLAGS.weight_decay,
99 | exclude_from_weight_decay=['batch_normalization', 'bias',
100 | 'head_supervised'])
101 | else:
102 | raise ValueError('Unknown optimizer {}'.format(FLAGS.optimizer))
103 |
104 | if FLAGS.use_tpu:
105 | optimizer = tf.tpu.CrossShardOptimizer(optimizer)
106 | return optimizer
107 |
108 |
109 | def linear_layer(x,
110 | is_training,
111 | num_classes,
112 | use_bias=True,
113 | use_bn=False,
114 | name='linear_layer'):
115 | """Linear head for linear evaluation.
116 |
117 | Args:
118 | x: hidden state tensor of shape (bsz, dim).
119 | is_training: boolean indicator for training or test.
120 | num_classes: number of classes.
121 | use_bias: whether or not to use bias.
122 | use_bn: whether or not to use BN for output units.
123 | name: the name for variable scope.
124 |
125 | Returns:
126 | logits of shape (bsz, num_classes)
127 | """
128 | assert x.shape.ndims == 2, x.shape
129 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
130 | x = tf.layers.dense(
131 | inputs=x,
132 | units=num_classes,
133 | use_bias=use_bias and not use_bn,
134 | kernel_initializer=tf.random_normal_initializer(stddev=.01))
135 | if use_bn:
136 | x = resnet.batch_norm_relu(x, is_training, relu=False, center=use_bias)
137 | x = tf.identity(x, '%s_out' % name)
138 | return x
139 |
140 |
141 | def projection_head(hiddens, is_training, name='head_contrastive'):
142 | """Head for projecting hiddens fo contrastive loss."""
143 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
144 | mid_dim = hiddens.shape[-1]
145 | out_dim = FLAGS.proj_out_dim
146 | hiddens_list = [hiddens]
147 | if FLAGS.proj_head_mode == 'none':
148 | pass  # No projection; use the hiddens as-is.
149 | elif FLAGS.proj_head_mode == 'linear':
150 | hiddens = linear_layer(
151 | hiddens, is_training, out_dim,
152 | use_bias=False, use_bn=True, name='l_0')
153 | hiddens_list.append(hiddens)
154 | elif FLAGS.proj_head_mode == 'nonlinear':
155 | for j in range(FLAGS.num_proj_layers):
156 | if j != FLAGS.num_proj_layers - 1:
157 | # for the middle layers, use bias and relu for the output.
158 | dim, bias_relu = mid_dim, True
159 | else:
160 | # for the final layer, neither bias nor relu is used.
161 | dim, bias_relu = FLAGS.proj_out_dim, False
162 | hiddens = linear_layer(
163 | hiddens, is_training, dim,
164 | use_bias=bias_relu, use_bn=True, name='nl_%d'%j)
165 | hiddens = tf.nn.relu(hiddens) if bias_relu else hiddens
166 | hiddens_list.append(hiddens)
167 | else:
168 | raise ValueError('Unknown head projection mode {}'.format(
169 | FLAGS.proj_head_mode))
170 | if FLAGS.train_mode == 'pretrain':
171 | # take the projection head output during pre-training.
172 | hiddens = hiddens_list[-1]
173 | else:
174 | # For checkpoint compatibility, the whole projection head is built here,
175 | # but part of it can be selected during fine-tuning.
176 | hiddens = hiddens_list[FLAGS.ft_proj_selector]
177 | return hiddens
178 |
179 |
180 | def supervised_head(hiddens, num_classes, is_training, name='head_supervised'):
181 | """Add supervised head & also add its variables to inblock collection."""
182 | with tf.variable_scope(name):
183 | logits = linear_layer(hiddens, is_training, num_classes)
184 | for var in tf.trainable_variables():
185 | if var.name.startswith(name):
186 | tf.add_to_collection('trainable_variables_inblock_5', var)
187 | return logits
188 |
--------------------------------------------------------------------------------
/objective.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Contrastive loss functions."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow.compat.v1 as tf
23 |
24 | from tensorflow.compiler.tf2xla.python import xla # pylint: disable=g-direct-tensorflow-import
25 |
26 | LARGE_NUM = 1e9
27 |
28 |
29 | def add_supervised_loss(labels, logits, weights, **kwargs):
30 | """Compute loss for model and add it to loss collection."""
31 | return tf.losses.softmax_cross_entropy(labels, logits, weights, **kwargs)
32 |
33 |
34 | def add_contrastive_loss(hidden,
35 | hidden_norm=True,
36 | temperature=1.0,
37 | tpu_context=None,
38 | weights=1.0):
39 | """Compute loss for model.
40 |
41 | Args:
42 | hidden: hidden vector (`Tensor`) of shape (2 * bsz, dim).
43 | hidden_norm: whether or not to use normalization on the hidden vector.
44 | temperature: a `float` for temperature scaling.
45 | tpu_context: context information for tpu.
46 | weights: a weighting number or vector.
47 |
48 | Returns:
49 | A loss scalar.
50 | The logits for contrastive prediction task.
51 | The labels for contrastive prediction task.
52 | """
53 | # Get (normalized) hidden1 and hidden2.
54 | if hidden_norm:
55 | hidden = tf.math.l2_normalize(hidden, -1)
56 | hidden1, hidden2 = tf.split(hidden, 2, 0)
57 | batch_size = tf.shape(hidden1)[0]
58 |
59 | # Gather hidden1/hidden2 across replicas and create local labels.
60 | if tpu_context is not None:
61 | hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
62 | hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
63 | enlarged_batch_size = tf.shape(hidden1_large)[0]
64 | # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
65 | replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
66 | labels_idx = tf.range(batch_size) + replica_id * batch_size
67 | labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
68 | masks = tf.one_hot(labels_idx, enlarged_batch_size)
69 | else:
70 | hidden1_large = hidden1
71 | hidden2_large = hidden2
72 | labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
73 | masks = tf.one_hot(tf.range(batch_size), batch_size)
74 |
75 | logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
76 | logits_aa = logits_aa - masks * LARGE_NUM
77 | logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
78 | logits_bb = logits_bb - masks * LARGE_NUM
79 | logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
80 | logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature
81 |
82 | loss_a = tf.losses.softmax_cross_entropy(
83 | labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
84 | loss_b = tf.losses.softmax_cross_entropy(
85 | labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
86 | loss = loss_a + loss_b
87 |
88 | return loss, logits_ab, labels
89 |
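# A minimal numpy sketch (illustration only) of the NT-Xent loss computed
# above in the single-replica case: each of the 2*bsz projections must pick
# out its augmented partner against 2*bsz - 1 negatives.
import numpy as np

def ntxent(hidden, temperature=1.0):
  hidden = hidden / np.linalg.norm(hidden, axis=-1, keepdims=True)
  h1, h2 = np.split(hidden, 2, axis=0)
  bsz = h1.shape[0]
  labels = np.eye(2 * bsz)[np.arange(bsz)]  # positive = the partner view
  mask = np.eye(bsz) * 1e9                  # exclude self-similarity
  logits_ab = h1 @ h2.T / temperature
  logits_ba = h2 @ h1.T / temperature
  logits_aa = h1 @ h1.T / temperature - mask
  logits_bb = h2 @ h2.T / temperature - mask
  def xent(logits):  # softmax cross-entropy against `labels`
    logp = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
    return -(labels * logp).sum(-1).mean()
  return (xent(np.concatenate([logits_ab, logits_aa], 1)) +
          xent(np.concatenate([logits_ba, logits_bb], 1)))

# e.g. ntxent(np.random.randn(8, 128), temperature=0.5) for a batch of 4 pairs.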
90 |
91 | def tpu_cross_replica_concat(tensor, tpu_context=None):
92 | """Reduce a concatenation of the `tensor` across TPU cores.
93 |
94 | Args:
95 | tensor: tensor to concatenate.
96 | tpu_context: A `TPUContext`. If not set, CPU execution is assumed.
97 |
98 | Returns:
99 | Tensor of the same rank as `tensor` with first dimension `num_replicas`
100 | times larger.
101 | """
102 | if tpu_context is None or tpu_context.num_replicas <= 1:
103 | return tensor
104 |
105 | num_replicas = tpu_context.num_replicas
106 |
107 | with tf.name_scope('tpu_cross_replica_concat'):
108 | # This creates a tensor that is like the input tensor but has an added
109 | # replica dimension as the outermost dimension. On each replica it will
110 | # contain the local values and zeros for all other values that need to be
111 | # fetched from other replicas.
112 | ext_tensor = tf.scatter_nd(
113 | indices=[[xla.replica_id()]],
114 | updates=[tensor],
115 | shape=[num_replicas] + tensor.shape.as_list())
116 |
117 | # As every value is only present on one replica and 0 in all others, adding
118 | # them all together will result in the full tensor on all replicas.
119 | ext_tensor = tf.tpu.cross_replica_sum(ext_tensor)
120 |
121 | # Flatten the replica dimension.
122 | # The first dimension size will be: tensor.shape[0] * num_replicas
123 | # Using [-1] trick to support also scalar input.
124 | return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
125 |
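# A numpy sketch (illustration only) of the scatter-then-sum trick above with
# two replicas: each replica scatters its local tensor into a zero buffer at
# its replica index; summing the buffers across replicas reproduces the full
# concatenation on every replica.
import numpy as np

t0, t1 = np.array([[1., 2.]]), np.array([[3., 4.]])  # per-replica tensors
ext0 = np.stack([t0, np.zeros_like(t0)])             # replica 0's buffer
ext1 = np.stack([np.zeros_like(t1), t1])             # replica 1's buffer
full = (ext0 + ext1).reshape(-1, 2)                  # == np.concatenate([t0, t1], 0)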
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | tensorflow==1.15.4
3 | tensorflow-datasets==3.1.0
4 | tensorflow-hub==0.8.0
5 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """The main training pipeline."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import json
23 | import math
24 | import os
25 | from absl import app
26 | from absl import flags
27 |
28 | import resnet
29 | import data as data_lib
30 | import model as model_lib
31 | import model_util as model_util
32 |
33 | import tensorflow.compat.v1 as tf
34 | from tensorflow.compat.v1 import estimator as tf_estimator
35 | import tensorflow_datasets as tfds
36 | import tensorflow_hub as hub
37 |
38 |
39 | FLAGS = flags.FLAGS
40 |
41 |
42 | flags.DEFINE_float(
43 | 'learning_rate', 0.3,
44 | 'Initial learning rate per batch size of 256.')
45 |
46 | flags.DEFINE_enum(
47 | 'learning_rate_scaling', 'linear', ['linear', 'sqrt'],
48 | 'How to scale the learning rate as a function of batch size.')
49 |
50 | flags.DEFINE_float(
51 | 'warmup_epochs', 10,
52 | 'Number of epochs of warmup.')
53 |
54 | flags.DEFINE_float(
55 | 'weight_decay', 1e-4,
56 | 'Amount of weight decay to use.')
57 |
58 | flags.DEFINE_float(
59 | 'batch_norm_decay', 0.9,
60 | 'Batch norm decay parameter.')
61 |
62 | flags.DEFINE_integer(
63 | 'train_batch_size', 512,
64 | 'Batch size for training.')
65 |
66 | flags.DEFINE_string(
67 | 'train_split', 'train',
68 | 'Split for training.')
69 |
70 | flags.DEFINE_integer(
71 | 'train_epochs', 100,
72 | 'Number of epochs to train for.')
73 |
74 | flags.DEFINE_integer(
75 | 'train_steps', 0,
76 | 'Number of steps to train for. If provided, overrides train_epochs.')
77 |
78 | flags.DEFINE_integer(
79 | 'eval_batch_size', 256,
80 | 'Batch size for eval.')
81 |
82 | flags.DEFINE_integer(
83 | 'train_summary_steps', 100,
84 | 'Steps before saving training summaries. If 0, will not save.')
85 |
86 | flags.DEFINE_integer(
87 | 'checkpoint_epochs', 1,
88 | 'Number of epochs between checkpoints/summaries.')
89 |
90 | flags.DEFINE_integer(
91 | 'checkpoint_steps', 0,
92 | 'Number of steps between checkpoints/summaries. If provided, overrides '
93 | 'checkpoint_epochs.')
94 |
95 | flags.DEFINE_string(
96 | 'eval_split', 'validation',
97 | 'Split for evaluation.')
98 |
99 | flags.DEFINE_string(
100 | 'dataset', 'imagenet2012',
101 | 'Name of a dataset.')
102 |
103 | flags.DEFINE_bool(
104 | 'cache_dataset', False,
105 | 'Whether to cache the entire dataset in memory. If the dataset is '
106 | 'ImageNet, this is a very bad idea, but for smaller datasets it can '
107 | 'improve performance.')
108 |
109 | flags.DEFINE_enum(
110 | 'mode', 'train', ['train', 'eval', 'train_then_eval'],
111 | 'Whether to perform training or evaluation.')
112 |
113 | flags.DEFINE_enum(
114 | 'train_mode', 'pretrain', ['pretrain', 'finetune'],
115 | 'The train mode controls different objectives and trainable components.')
116 |
117 | flags.DEFINE_string(
118 | 'checkpoint', None,
119 | 'Loading from the given checkpoint for continued training or fine-tuning.')
120 |
121 | flags.DEFINE_string(
122 | 'variable_schema', '?!global_step',
123 | 'Defines which variables from the checkpoint should be loaded.')
124 |
125 | flags.DEFINE_bool(
126 | 'zero_init_logits_layer', False,
127 | 'If True, zero initialize layers after avg_pool for supervised learning.')
128 |
129 | flags.DEFINE_integer(
130 | 'fine_tune_after_block', -1,
131 | 'The block after which to fine-tune. -1 means fine-tuning everything. '
132 | '0 means fine-tuning after the stem block. 4 means fine-tuning just '
133 | 'the linear head.')
134 |
135 | flags.DEFINE_string(
136 | 'master', None,
137 | 'Address/name of the TensorFlow master to use. By default, use an '
138 | 'in-process master.')
139 |
140 | flags.DEFINE_string(
141 | 'model_dir', None,
142 | 'Model directory for training.')
143 |
144 | flags.DEFINE_string(
145 | 'data_dir', None,
146 | 'Directory where dataset is stored.')
147 |
148 | flags.DEFINE_bool(
149 | 'use_tpu', True,
150 | 'Whether to run on TPU.')
151 |
152 | tf.flags.DEFINE_string(
153 | 'tpu_name', None,
154 | 'The Cloud TPU to use for training. This should be either the name '
155 | 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
156 | 'url.')
157 |
158 | tf.flags.DEFINE_string(
159 | 'tpu_zone', None,
160 | '[Optional] GCE zone where the Cloud TPU is located. If not '
161 | 'specified, we will attempt to automatically detect the zone from '
162 | 'metadata.')
163 |
164 | tf.flags.DEFINE_string(
165 | 'gcp_project', None,
166 | '[Optional] Project name for the Cloud TPU-enabled project. If not '
167 | 'specified, we will attempt to automatically detect the GCE project from '
168 | 'metadata.')
169 |
170 | flags.DEFINE_enum(
171 | 'optimizer', 'lars', ['momentum', 'adam', 'lars'],
172 | 'Optimizer to use.')
173 |
174 | flags.DEFINE_float(
175 | 'momentum', 0.9,
176 | 'Momentum parameter.')
177 |
178 | flags.DEFINE_string(
179 | 'eval_name', None,
180 | 'Name for eval.')
181 |
182 | flags.DEFINE_integer(
183 | 'keep_checkpoint_max', 5,
184 | 'Maximum number of checkpoints to keep.')
185 |
186 | flags.DEFINE_integer(
187 | 'keep_hub_module_max', 1,
188 | 'Maximum number of Hub modules to keep.')
189 |
190 | flags.DEFINE_float(
191 | 'temperature', 0.1,
192 | 'Temperature parameter for contrastive loss.')
193 |
194 | flags.DEFINE_boolean(
195 | 'hidden_norm', True,
196 | 'Whether to L2-normalize the hidden vector for the contrastive loss.')
197 |
198 | flags.DEFINE_enum(
199 | 'proj_head_mode', 'nonlinear', ['none', 'linear', 'nonlinear'],
200 | 'How the head projection is done.')
201 |
202 | flags.DEFINE_integer(
203 | 'proj_out_dim', 128,
204 | 'Output dimension of the projection head.')
205 |
206 | flags.DEFINE_integer(
207 | 'num_proj_layers', 3,
208 | 'Number of non-linear head layers.')
209 |
210 | flags.DEFINE_integer(
211 | 'ft_proj_selector', 0,
212 | 'Which layer of the projection head to use during fine-tuning. '
213 | '0 means throwing away the projection head, and -1 means the final layer.')
214 |
215 | flags.DEFINE_boolean(
216 | 'global_bn', True,
217 | 'Whether to aggregate BN statistics across distributed cores.')
218 |
219 | flags.DEFINE_integer(
220 | 'width_multiplier', 1,
221 | 'Multiplier to change width of network.')
222 |
223 | flags.DEFINE_integer(
224 | 'resnet_depth', 50,
225 | 'Depth of ResNet.')
226 |
227 | flags.DEFINE_float(
228 | 'sk_ratio', 0.,
229 | 'If greater than 0, enables selective kernels (SK). Recommended: 0.0625.')
230 |
231 | flags.DEFINE_float(
232 | 'se_ratio', 0.,
233 | 'If greater than 0, enables squeeze-and-excitation (SE).')
234 |
235 | flags.DEFINE_integer(
236 | 'image_size', 224,
237 | 'Input image size.')
238 |
239 | flags.DEFINE_float(
240 | 'color_jitter_strength', 1.0,
241 | 'The strength of color jittering.')
242 |
243 | flags.DEFINE_boolean(
244 | 'use_blur', True,
245 | 'Whether or not to use Gaussian blur for augmentation during pretraining.')
246 |
247 |
248 | def build_hub_module(model, num_classes, global_step, checkpoint_path):
249 | """Create TF-Hub module."""
250 |
251 | tags_and_args = [
252 | # The default graph is built with batch_norm, dropout etc. in inference
253 | # mode. This graph version is good for inference, not training.
254 | ([], {'is_training': False}),
255 | # A separate "train" graph builds batch_norm, dropout etc. in training
256 | # mode.
257 | (['train'], {'is_training': True}),
258 | ]
259 |
260 | def module_fn(is_training):
261 | """Function that builds TF-Hub module."""
262 | endpoints = {}
263 | inputs = tf.placeholder(
264 | tf.float32, [None, None, None, 3])
265 | with tf.variable_scope('base_model', reuse=tf.AUTO_REUSE):
266 | hiddens = model(inputs, is_training)
267 | for v in ['initial_conv', 'initial_max_pool', 'block_group1',
268 | 'block_group2', 'block_group3', 'block_group4',
269 | 'final_avg_pool']:
270 | endpoints[v] = tf.get_default_graph().get_tensor_by_name(
271 | 'base_model/{}:0'.format(v))
272 | if FLAGS.train_mode == 'pretrain':
273 | hiddens_proj = model_util.projection_head(hiddens, is_training)
274 | endpoints['proj_head_input'] = hiddens
275 | endpoints['proj_head_output'] = hiddens_proj
276 | else:
277 | logits_sup = model_util.supervised_head(
278 | hiddens, num_classes, is_training)
279 | endpoints['logits_sup'] = logits_sup
280 | hub.add_signature(inputs=dict(images=inputs),
281 | outputs=dict(endpoints, default=hiddens))
282 |
283 | # Drop the non-supported non-standard graph collection.
284 | drop_collections = ['trainable_variables_inblock_%d'%d for d in range(6)]
285 | spec = hub.create_module_spec(module_fn, tags_and_args, drop_collections)
286 | hub_export_dir = os.path.join(FLAGS.model_dir, 'hub')
287 | checkpoint_export_dir = os.path.join(hub_export_dir, str(global_step))
288 | if tf.io.gfile.exists(checkpoint_export_dir):
289 | # Do not save if checkpoint already saved.
290 | tf.io.gfile.rmtree(checkpoint_export_dir)
291 | spec.export(
292 | checkpoint_export_dir,
293 | checkpoint_path=checkpoint_path,
294 | name_transform_fn=None)
295 |
296 | if FLAGS.keep_hub_module_max > 0:
297 | # Delete old exported Hub modules.
298 | exported_steps = []
299 | for subdir in tf.io.gfile.listdir(hub_export_dir):
300 | if not subdir.isdigit():
301 | continue
302 | exported_steps.append(int(subdir))
303 | exported_steps.sort()
304 | for step_to_delete in exported_steps[:-FLAGS.keep_hub_module_max]:
305 | tf.io.gfile.rmtree(os.path.join(hub_export_dir, str(step_to_delete)))
306 |
307 |
308 | def perform_evaluation(estimator, input_fn, eval_steps, model, num_classes,
309 | checkpoint_path=None):
310 | """Perform evaluation.
311 |
312 | Args:
313 | estimator: TPUEstimator instance.
314 | input_fn: Input function for estimator.
315 | eval_steps: Number of steps for evaluation.
316 | model: The model function used to build the TF-Hub module.
317 | num_classes: Number of classes to build model for.
318 | checkpoint_path: Path of checkpoint to evaluate.
319 |
320 | Returns:
321 | result: A Dict of metrics and their values.
322 | """
323 | if not checkpoint_path:
324 | checkpoint_path = estimator.latest_checkpoint()
325 | result = estimator.evaluate(
326 | input_fn, eval_steps, checkpoint_path=checkpoint_path,
327 | name=FLAGS.eval_name)
328 |
329 | # Record results as JSON.
330 | result_json_path = os.path.join(FLAGS.model_dir, 'result.json')
331 | with tf.io.gfile.GFile(result_json_path, 'w') as f:
332 | json.dump({k: float(v) for k, v in result.items()}, f)
333 | result_json_path = os.path.join(
334 | FLAGS.model_dir, 'result_%d.json'%result['global_step'])
335 | with tf.io.gfile.GFile(result_json_path, 'w') as f:
336 | json.dump({k: float(v) for k, v in result.items()}, f)
337 | flag_json_path = os.path.join(FLAGS.model_dir, 'flags.json')
338 |
339 | def json_serializable(val):
340 | try:
341 | json.dumps(val)
342 | return True
343 | except TypeError:
344 | return False
345 |
346 | with tf.io.gfile.GFile(flag_json_path, 'w') as f:
347 | serializable_flags = {}
348 | for key, val in FLAGS.flag_values_dict().items():
349 | # Some flag value types e.g. datetime.timedelta are not json serializable,
350 | # filter those out.
351 | if json_serializable(val):
352 | serializable_flags[key] = val
353 | json.dump(serializable_flags, f)
354 |
355 | # Save Hub module.
356 | build_hub_module(model, num_classes,
357 | global_step=result['global_step'],
358 | checkpoint_path=checkpoint_path)
359 |
360 | return result
361 |
362 |
363 | def main(argv):
364 | if len(argv) > 1:
365 | raise app.UsageError('Too many command-line arguments.')
366 |
367 | # Enable training summary.
368 | if FLAGS.train_summary_steps > 0:
369 | tf.config.set_soft_device_placement(True)
370 |
371 |
372 | builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
373 | builder.download_and_prepare()
374 | num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
375 | num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
376 | num_classes = builder.info.features['label'].num_classes
377 |
378 | train_steps = model_util.get_train_steps(num_train_examples)
379 | eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
380 | epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))
381 |
382 | resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
383 | model = resnet.resnet_v1(
384 | resnet_depth=FLAGS.resnet_depth,
385 | width_multiplier=FLAGS.width_multiplier,
386 | cifar_stem=FLAGS.image_size <= 32)
387 |
388 | checkpoint_steps = (
389 | FLAGS.checkpoint_steps or (FLAGS.checkpoint_epochs * epoch_steps))
390 |
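# For example (illustration only), with ImageNet's 1,281,167 train examples
# and the defaults train_batch_size=512, train_epochs=100, checkpoint_epochs=1:
#   epoch_steps      = round(1281167 / 512)        = 2502
#   train_steps      = 1281167 * 100 // 512 + 1    = 250228
#   checkpoint_steps = 1 * 2502                    = 2502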
391 | cluster = None
392 | if FLAGS.use_tpu and FLAGS.master is None:
393 | if FLAGS.tpu_name:
394 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
395 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
396 | else:
397 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
398 | tf.config.experimental_connect_to_cluster(cluster)
399 | tf.tpu.experimental.initialize_tpu_system(cluster)
400 |
401 | default_eval_mode = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V1
402 | sliced_eval_mode = tf_estimator.tpu.InputPipelineConfig.SLICED
403 | run_config = tf_estimator.tpu.RunConfig(
404 | tpu_config=tf_estimator.tpu.TPUConfig(
405 | iterations_per_loop=checkpoint_steps,
406 | eval_training_input_configuration=sliced_eval_mode
407 | if FLAGS.use_tpu else default_eval_mode),
408 | model_dir=FLAGS.model_dir,
409 | save_summary_steps=checkpoint_steps,
410 | save_checkpoints_steps=checkpoint_steps,
411 | keep_checkpoint_max=FLAGS.keep_checkpoint_max,
412 | master=FLAGS.master,
413 | cluster=cluster)
414 | estimator = tf_estimator.tpu.TPUEstimator(
415 | model_lib.build_model_fn(model, num_classes, num_train_examples),
416 | config=run_config,
417 | train_batch_size=FLAGS.train_batch_size,
418 | eval_batch_size=FLAGS.eval_batch_size,
419 | use_tpu=FLAGS.use_tpu)
420 |
421 | if FLAGS.mode == 'eval':
422 | for ckpt in tf.train.checkpoints_iterator(
423 | run_config.model_dir, min_interval_secs=15):
424 | try:
425 | result = perform_evaluation(
426 | estimator=estimator,
427 | input_fn=data_lib.build_input_fn(builder, False),
428 | eval_steps=eval_steps,
429 | model=model,
430 | num_classes=num_classes,
431 | checkpoint_path=ckpt)
432 | except tf.errors.NotFoundError:
433 | continue
434 | if result['global_step'] >= train_steps:
435 | return
436 | else:
437 | estimator.train(
438 | data_lib.build_input_fn(builder, True), max_steps=train_steps)
439 | if FLAGS.mode == 'train_then_eval':
440 | perform_evaluation(
441 | estimator=estimator,
442 | input_fn=data_lib.build_input_fn(builder, False),
443 | eval_steps=eval_steps,
444 | model=model,
445 | num_classes=num_classes)
446 |
447 |
448 | if __name__ == '__main__':
449 | tf.disable_v2_behavior() # Disable eager mode when running with TF2.
450 | app.run(main)
451 |
--------------------------------------------------------------------------------
/tf2/README.md:
--------------------------------------------------------------------------------
1 | # TF2 implementation of SimCLR
2 |
3 | This implementation is based on TensorFlow 2.x. We use `tf.keras` layers for building the model and use `tf.data` for our input pipeline. The model is trained using a [custom training loop](https://www.tensorflow.org/tutorials/distribute/custom_training) with `tf.distribute` on multiple TPUs.
4 |
14 | ## Pre-trained models for SimCLRv2
15 |
16 |
17 | We have converted the checkpoints for the TF1 models of SimCLR v1 and v2 to TF2 [SavedModel](https://www.tensorflow.org/guide/saved_model):
18 |
19 | * Pretrained SimCLRv2 models (with linear eval head): [gs://simclr-checkpoints-tf2/simclrv2/pretrained](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv2/pretrained)
20 | * Fine-tuned SimCLRv2 models on 1% of labels: [gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv2/finetuned_1pct)
21 | * Fine-tuned SimCLRv2 models on 10% of labels: [gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv2/finetuned_10pct)
22 | * Fine-tuned SimCLRv2 models on 100% of labels: [gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv2/finetuned_100pct)
23 | * Supervised models with the same architectures: [gs://simclr-checkpoints-tf2/simclrv2/supervised](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv2/supervised)
24 | * The distilled / self-trained models (after fine-tuning) are also provided:
25 | * [gs://simclr-checkpoints-tf2/simclrv2/distill_1pct](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv2/distill_1pct)
26 | * [gs://simclr-checkpoints-tf2/simclrv2/distill_10pct](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv2/distill_10pct)
27 |
28 | We also provide examples of how to use the SavedModels in the `colabs/` folder. In addition to the TF1 colabs, we provide an `imagenet_results.ipynb` colab to verify the ImageNet results of the SimCLR v1 and v2 papers.
29 |
30 | ## Pre-trained models for SimCLRv1
31 |
32 | The pre-trained models (base network with linear classifier layer) can be found below. Note that for these SimCLRv1 checkpoints, the projection head is not available.
33 |
34 | | SavedModel | ImageNet Top-1 |
35 | |--------------------------------------------------------------------------------------------------------------|------------------------|
36 | |[ResNet50 (1x)](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv1/pretrain/1x) | 69.1 |
37 | |[ResNet50 (2x)](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv1/pretrain/2x) | 74.2 |
38 | |[ResNet50 (4x)](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv1/pretrain/4x) | 76.6 |
39 |
40 | Additional SimCLRv1 checkpoints are available: [gs://simclr-checkpoints-tf2/simclrv1](https://console.cloud.google.com/storage/browser/simclr-checkpoints-tf2/simclrv1).
41 |
42 | A note on the signature of the TensorFlow SavedModel: `logits_sup` is the supervised classification logits for ImageNet 1000 categories. Others (e.g. `initial_max_pool`, `block_group1`) are middle layers of ResNet; refer to resnet.py for the specifics.
43 |
44 | ## Environment setup
45 |
46 | Our models are trained with TPUs. It is recommended to run distributed training with TPUs when using our code for pretraining.
47 |
48 | The code can be run on multiple GPUs by replacing `tf.distribute.TPUStrategy` with `tf.distribute.MirroredStrategy`. See the TensorFlow distributed training [guide](https://www.tensorflow.org/guide/distributed_training) for an overview of `tf.distribute`.
49 |
50 | The code is compatible with TensorFlow 2.x. See requirements.txt for all prerequisites; you can install them with the following command.
51 |
52 | ```
53 | pip install -r requirements.txt
54 | ```
55 |
56 | ## Pretraining
57 |
58 | To pretrain the model on CIFAR-10 with a CPU or one or more GPUs, try the following command:
59 |
60 | ```
61 | python run.py --train_mode=pretrain \
62 | --train_batch_size=512 --train_epochs=1000 \
63 | --learning_rate=1.0 --weight_decay=1e-4 --temperature=0.5 \
64 | --dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
65 | --use_blur=False --color_jitter_strength=0.5 \
66 | --model_dir=/tmp/simclr_test --use_tpu=False
67 | ```
68 |
69 | To pretrain the model on ImageNet with Cloud TPUs, first check out the [Google Cloud TPU tutorial](https://cloud.google.com/tpu/docs/tutorials/mnist) for basic information on how to use Google Cloud TPUs.
70 |
71 | Once you have created a virtual machine with Cloud TPUs and pre-downloaded the ImageNet data for [tensorflow_datasets](https://www.tensorflow.org/datasets/catalog/imagenet2012), please set the following environment variables:
72 |
73 | ```
74 | TPU_NAME=<tpu-name>
75 | STORAGE_BUCKET=gs://<storage-bucket>
76 | DATA_DIR=$STORAGE_BUCKET/<path-to-tensorflow-dataset>
77 | MODEL_DIR=$STORAGE_BUCKET/<path-to-store-checkpoints>
78 | ```
79 |
80 | The following command can be used to pretrain a ResNet-50 on ImageNet (which reflects the default hyperparameters in our paper):
81 |
82 | ```
83 | python run.py --train_mode=pretrain \
84 | --train_batch_size=4096 --train_epochs=100 --temperature=0.1 \
85 | --learning_rate=0.075 --learning_rate_scaling=sqrt --weight_decay=1e-4 \
86 | --dataset=imagenet2012 --image_size=224 --eval_split=validation \
87 | --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
88 | --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0
89 | ```
90 |
91 | A batch size of 4096 requires at least 32 TPUs; 100 epochs take around 6 hours on 32 TPU v3s. Note that a learning rate of 0.3 with `learning_rate_scaling=linear` is equivalent to one of 0.075 with `learning_rate_scaling=sqrt` when the batch size is 4096. However, sqrt scaling tends to train better when a smaller batch size is used.
92 |
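Concretely, linear scaling gives 0.3 × 4096 / 256 = 4.8, while sqrt scaling gives 0.075 × √4096 = 0.075 × 64 = 4.8, so the two settings coincide at this batch size.
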
93 | ## Finetuning the linear head (linear eval)
94 |
95 | You could simply set `--lineareval_while_pretraining=True` during pretraining, which trains a linear classifier alongside the pretraining. A `stop_gradient` operator is used to prevent the label information from backpropagating into the representations.
96 |
97 | More conventionally, you can also finetune the linear head on top of a pretrained model after pretraining, as follows:
98 |
99 | ```
100 | class Model(tf.keras.Model):
101 | def __init__(self, path):
102 | super(Model, self).__init__()
103 | # Load a pretrained SimCLR model.
104 | self.saved_model = tf.saved_model.load(path)
105 | # Linear head.
106 | self.dense_layer = tf.keras.layers.Dense(units=num_classes,
107 | name="head_supervised_new")
108 | self.optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)  # or any optimizer of your choice
109 |
110 | def call(self, x):
111 | with tf.GradientTape() as tape:
112 | # Use `trainable=False` since we do not wish to update batch norm
113 | # statistics of the loaded model. If finetuning everything, set this to
114 | # True.
115 | outputs = self.saved_model(x['image'], trainable=False)
116 | logits_t = self.dense_layer(outputs['final_avg_pool'])
117 | loss_t = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
118 | labels=tf.one_hot(x['label'], num_classes), logits=logits_t))
119 | dense_layer_weights = self.dense_layer.trainable_weights
120 | print('Variables to train:', dense_layer_weights)
121 | # Note: We only compute gradients wrt the linear head. To finetune all
122 | # weights use self.trainable_weights instead.
123 | grads = tape.gradient(loss_t, dense_layer_weights)
124 | self.optimizer.apply_gradients(zip(grads, dense_layer_weights))
125 | return loss_t, x["image"], logits_t, x["label"]
126 |
127 | model = Model("gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r50_1x_sk0/saved_model/")
128 |
129 | # Use tf.function to speed up training. Remove this when debugging intermediate
130 | # model activations.
131 | @tf.function
132 | def train_step(x):
133 | return model(x)
134 |
135 | ds = build_dataset(...)
136 | iterator = iter(ds)
137 | for _ in range(num_steps):
138 | train_step(next(iterator))
139 | ```
140 |
141 | Check the colab in `colabs/finetuning.ipynb` for a complete example.
142 |
143 | ## Semi-supervised learning and fine-tuning the whole network
144 |
145 | You can access 1% and 10% ImageNet subsets used for semi-supervised learning via [tensorflow datasets](https://www.tensorflow.org/datasets/catalog/imagenet2012_subset): simply set `dataset=imagenet2012_subset/1pct` and `dataset=imagenet2012_subset/10pct` in the command line for fine-tuning on these subsets.
146 |
147 | You can also find image IDs of these subsets in `imagenet_subsets/`.
148 |
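For reference, fine-tuning the whole network on the 1% subset reuses the flags shown in the pretraining commands above. The configuration below is an illustrative sketch rather than an exact recipe from the papers (hyperparameters such as the learning rate, batch size, and epochs are assumptions), and `MODEL_DIR_FT` stands for a fresh directory for the fine-tuned checkpoints:

```
python run.py --mode=train_then_eval --train_mode=finetune \
  --zero_init_logits_layer=True --fine_tune_after_block=-1 \
  --learning_rate=0.005 --learning_rate_scaling=sqrt --weight_decay=0 \
  --train_epochs=60 --train_batch_size=1024 --warmup_epochs=0 \
  --dataset=imagenet2012_subset/1pct --image_size=224 --eval_split=validation \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR_FT --checkpoint=$MODEL_DIR \
  --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0
```
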
149 | ## Cite
150 |
151 | [SimCLR paper](https://arxiv.org/abs/2002.05709):
152 |
153 | ```
154 | @article{chen2020simple,
155 | title={A Simple Framework for Contrastive Learning of Visual Representations},
156 | author={Chen, Ting and Kornblith, Simon and Norouzi, Mohammad and Hinton, Geoffrey},
157 | journal={arXiv preprint arXiv:2002.05709},
158 | year={2020}
159 | }
160 | ```
161 |
162 | [SimCLRv2 paper](https://arxiv.org/abs/2006.10029):
163 |
164 | ```
165 | @article{chen2020big,
166 | title={Big Self-Supervised Models are Strong Semi-Supervised Learners},
167 | author={Chen, Ting and Kornblith, Simon and Swersky, Kevin and Norouzi, Mohammad and Hinton, Geoffrey},
168 | journal={arXiv preprint arXiv:2006.10029},
169 | year={2020}
170 | }
171 | ```
172 |
173 | ## Disclaimer
174 | This is not an official Google product.
175 |
--------------------------------------------------------------------------------
/tf2/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Data pipeline."""
17 |
18 | import functools
19 | from absl import flags
20 | from absl import logging
21 |
22 | import data_util
23 | import tensorflow.compat.v2 as tf
24 | import tensorflow_datasets as tfds
25 |
26 | FLAGS = flags.FLAGS
27 |
28 |
29 | def build_input_fn(builder, global_batch_size, topology, is_training):
30 | """Build input function.
31 |
32 | Args:
33 | builder: TFDS builder for specified dataset.
34 | global_batch_size: Global batch size.
35 | topology: An instance of `tf.tpu.experimental.Topology` or None.
36 | is_training: Whether to build in training mode.
37 |
38 | Returns:
39 |     A function that accepts an `input_context` and returns a `tf.data.Dataset`,
40 |     for use with `tf.distribute.Strategy.distribute_datasets_from_function`.
41 | """
42 |
43 | def _input_fn(input_context):
44 | """Inner input function."""
45 | batch_size = input_context.get_per_replica_batch_size(global_batch_size)
46 | logging.info('Global batch size: %d', global_batch_size)
47 | logging.info('Per-replica batch size: %d', batch_size)
48 | preprocess_fn_pretrain = get_preprocess_fn(is_training, is_pretrain=True)
49 | preprocess_fn_finetune = get_preprocess_fn(is_training, is_pretrain=False)
50 | num_classes = builder.info.features['label'].num_classes
51 |
52 | def map_fn(image, label):
53 |       """Produces multiple transformations of the same image."""
54 | if is_training and FLAGS.train_mode == 'pretrain':
55 | xs = []
56 | for _ in range(2): # Two transformations
57 | xs.append(preprocess_fn_pretrain(image))
58 | image = tf.concat(xs, -1)
59 | else:
60 | image = preprocess_fn_finetune(image)
61 | label = tf.one_hot(label, num_classes)
62 | return image, label
63 |
64 | logging.info('num_input_pipelines: %d', input_context.num_input_pipelines)
65 | dataset = builder.as_dataset(
66 | split=FLAGS.train_split if is_training else FLAGS.eval_split,
67 | shuffle_files=is_training,
68 | as_supervised=True,
69 | # Passing the input_context to TFDS makes TFDS read different parts
70 | # of the dataset on different workers. We also adjust the interleave
71 | # parameters to achieve better performance.
72 | read_config=tfds.ReadConfig(
73 | interleave_cycle_length=32,
74 | interleave_block_length=1,
75 | input_context=input_context))
76 | if FLAGS.cache_dataset:
77 | dataset = dataset.cache()
78 | if is_training:
79 | options = tf.data.Options()
80 | options.experimental_deterministic = False
81 | options.experimental_slack = True
82 | dataset = dataset.with_options(options)
83 | buffer_multiplier = 50 if FLAGS.image_size <= 32 else 10
84 | dataset = dataset.shuffle(batch_size * buffer_multiplier)
85 | dataset = dataset.repeat(-1)
86 | dataset = dataset.map(
87 | map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
88 | dataset = dataset.batch(batch_size, drop_remainder=is_training)
89 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
90 | return dataset
91 |
92 | return _input_fn
93 |
94 |
95 | def build_distributed_dataset(builder, batch_size, is_training, strategy,
96 | topology):
97 | input_fn = build_input_fn(builder, batch_size, topology, is_training)
98 | return strategy.distribute_datasets_from_function(input_fn)
99 |
100 |
101 | def get_preprocess_fn(is_training, is_pretrain):
102 | """Get function that accepts an image and returns a preprocessed image."""
103 | # Disable test cropping for small images (e.g. CIFAR)
104 | if FLAGS.image_size <= 32:
105 | test_crop = False
106 | else:
107 | test_crop = True
108 | color_jitter_strength = FLAGS.color_jitter_strength if is_pretrain else 0.
109 | return functools.partial(
110 | data_util.preprocess_image,
111 | height=FLAGS.image_size,
112 | width=FLAGS.image_size,
113 | is_training=is_training,
114 | color_jitter_strength=color_jitter_strength,
115 | test_crop=test_crop)
116 |
--------------------------------------------------------------------------------
/tf2/data_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Data preprocessing and augmentation."""
17 |
18 | import functools
19 |
20 | import tensorflow.compat.v2 as tf
21 |
22 | CROP_PROPORTION = 0.875 # Standard for ImageNet.
23 |
24 |
25 | def random_apply(func, p, x):
26 | """Randomly apply function func to x with probability p."""
27 | return tf.cond(
28 | tf.less(
29 | tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
30 | tf.cast(p, tf.float32)), lambda: func(x), lambda: x)
31 |
32 |
33 | def random_brightness(image, max_delta, impl='simclrv2'):
34 |   """Randomly adjusts brightness: multiplicative (simclrv2) or additive (simclrv1)."""
35 | if impl == 'simclrv2':
36 | factor = tf.random.uniform([], tf.maximum(1.0 - max_delta, 0),
37 | 1.0 + max_delta)
38 | image = image * factor
39 | elif impl == 'simclrv1':
40 | image = tf.image.random_brightness(image, max_delta=max_delta)
41 | else:
42 | raise ValueError('Unknown impl {} for random brightness.'.format(impl))
43 | return image
44 |
45 |
46 | def to_grayscale(image, keep_channels=True):
47 | image = tf.image.rgb_to_grayscale(image)
48 | if keep_channels:
49 | image = tf.tile(image, [1, 1, 3])
50 | return image
51 |
52 |
53 | def color_jitter(image, strength, random_order=True, impl='simclrv2'):
54 | """Distorts the color of the image.
55 |
56 | Args:
57 | image: The input image tensor.
58 |     strength: A float specifying the strength of the color augmentation.
59 | random_order: A bool, specifying whether to randomize the jittering order.
60 | impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
61 | version of random brightness.
62 |
63 | Returns:
64 | The distorted image tensor.
65 | """
66 | brightness = 0.8 * strength
67 | contrast = 0.8 * strength
68 | saturation = 0.8 * strength
69 | hue = 0.2 * strength
70 | if random_order:
71 | return color_jitter_rand(
72 | image, brightness, contrast, saturation, hue, impl=impl)
73 | else:
74 | return color_jitter_nonrand(
75 | image, brightness, contrast, saturation, hue, impl=impl)
76 |
77 |
78 | def color_jitter_nonrand(image,
79 | brightness=0,
80 | contrast=0,
81 | saturation=0,
82 | hue=0,
83 | impl='simclrv2'):
84 | """Distorts the color of the image (jittering order is fixed).
85 |
86 | Args:
87 | image: The input image tensor.
88 | brightness: A float, specifying the brightness for color jitter.
89 | contrast: A float, specifying the contrast for color jitter.
90 | saturation: A float, specifying the saturation for color jitter.
91 | hue: A float, specifying the hue for color jitter.
92 | impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
93 | version of random brightness.
94 |
95 | Returns:
96 | The distorted image tensor.
97 | """
98 | with tf.name_scope('distort_color'):
99 | def apply_transform(i, x, brightness, contrast, saturation, hue):
100 | """Apply the i-th transformation."""
101 | if brightness != 0 and i == 0:
102 | x = random_brightness(x, max_delta=brightness, impl=impl)
103 | elif contrast != 0 and i == 1:
104 | x = tf.image.random_contrast(
105 | x, lower=1-contrast, upper=1+contrast)
106 | elif saturation != 0 and i == 2:
107 | x = tf.image.random_saturation(
108 | x, lower=1-saturation, upper=1+saturation)
109 |       elif hue != 0 and i == 3:
110 | x = tf.image.random_hue(x, max_delta=hue)
111 | return x
112 |
113 | for i in range(4):
114 | image = apply_transform(i, image, brightness, contrast, saturation, hue)
115 | image = tf.clip_by_value(image, 0., 1.)
116 | return image
117 |
118 |
119 | def color_jitter_rand(image,
120 | brightness=0,
121 | contrast=0,
122 | saturation=0,
123 | hue=0,
124 | impl='simclrv2'):
125 | """Distorts the color of the image (jittering order is random).
126 |
127 | Args:
128 | image: The input image tensor.
129 | brightness: A float, specifying the brightness for color jitter.
130 | contrast: A float, specifying the contrast for color jitter.
131 | saturation: A float, specifying the saturation for color jitter.
132 | hue: A float, specifying the hue for color jitter.
133 | impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
134 | version of random brightness.
135 |
136 | Returns:
137 | The distorted image tensor.
138 | """
139 | with tf.name_scope('distort_color'):
140 | def apply_transform(i, x):
141 | """Apply the i-th transformation."""
142 | def brightness_foo():
143 | if brightness == 0:
144 | return x
145 | else:
146 | return random_brightness(x, max_delta=brightness, impl=impl)
147 |
148 | def contrast_foo():
149 | if contrast == 0:
150 | return x
151 | else:
152 | return tf.image.random_contrast(x, lower=1-contrast, upper=1+contrast)
153 | def saturation_foo():
154 | if saturation == 0:
155 | return x
156 | else:
157 | return tf.image.random_saturation(
158 | x, lower=1-saturation, upper=1+saturation)
159 | def hue_foo():
160 | if hue == 0:
161 | return x
162 | else:
163 | return tf.image.random_hue(x, max_delta=hue)
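      # Dispatch on the runtime index i: 0 -> brightness, 1 -> contrast,
      # 2 -> saturation, 3 -> hue. Nested tf.cond is needed because i comes
      # from a shuffled tensor, so its value is not known when the graph is
      # built.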
164 | x = tf.cond(tf.less(i, 2),
165 | lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
166 | lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo))
167 | return x
168 |
169 | perm = tf.random.shuffle(tf.range(4))
170 | for i in range(4):
171 | image = apply_transform(perm[i], image)
172 | image = tf.clip_by_value(image, 0., 1.)
173 | return image
174 |
175 |
176 | def _compute_crop_shape(
177 | image_height, image_width, aspect_ratio, crop_proportion):
178 | """Compute aspect ratio-preserving shape for central crop.
179 |
180 | The resulting shape retains `crop_proportion` along one side and a proportion
181 | less than or equal to `crop_proportion` along the other side.
182 |
183 | Args:
184 | image_height: Height of image to be cropped.
185 | image_width: Width of image to be cropped.
186 | aspect_ratio: Desired aspect ratio (width / height) of output.
187 | crop_proportion: Proportion of image to retain along the less-cropped side.
188 |
189 | Returns:
190 | crop_height: Height of image after cropping.
191 | crop_width: Width of image after cropping.
192 | """
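  # Worked example (illustrative): for a 640x480 input with aspect_ratio=1.0
  # and crop_proportion=0.875, the image (w/h = 4/3) is wider than requested,
  # so crop_height = rint(0.875 * 480) = 420 and
  # crop_width = rint(0.875 * 1.0 * 480) = 420, i.e. a 420x420 central crop.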
193 | image_width_float = tf.cast(image_width, tf.float32)
194 | image_height_float = tf.cast(image_height, tf.float32)
195 |
196 | def _requested_aspect_ratio_wider_than_image():
197 | crop_height = tf.cast(
198 | tf.math.rint(crop_proportion / aspect_ratio * image_width_float),
199 | tf.int32)
200 | crop_width = tf.cast(
201 | tf.math.rint(crop_proportion * image_width_float), tf.int32)
202 | return crop_height, crop_width
203 |
204 | def _image_wider_than_requested_aspect_ratio():
205 | crop_height = tf.cast(
206 | tf.math.rint(crop_proportion * image_height_float), tf.int32)
207 | crop_width = tf.cast(
208 | tf.math.rint(crop_proportion * aspect_ratio * image_height_float),
209 | tf.int32)
210 | return crop_height, crop_width
211 |
212 | return tf.cond(
213 | aspect_ratio > image_width_float / image_height_float,
214 | _requested_aspect_ratio_wider_than_image,
215 | _image_wider_than_requested_aspect_ratio)
216 |
217 |
218 | def center_crop(image, height, width, crop_proportion):
219 | """Crops to center of image and rescales to desired size.
220 |
221 | Args:
222 | image: Image Tensor to crop.
223 | height: Height of image to be cropped.
224 | width: Width of image to be cropped.
225 | crop_proportion: Proportion of image to retain along the less-cropped side.
226 |
227 | Returns:
228 | A `height` x `width` x channels Tensor holding a central crop of `image`.
229 | """
230 | shape = tf.shape(image)
231 | image_height = shape[0]
232 | image_width = shape[1]
233 | crop_height, crop_width = _compute_crop_shape(
234 | image_height, image_width, width / height, crop_proportion)
235 | offset_height = ((image_height - crop_height) + 1) // 2
236 | offset_width = ((image_width - crop_width) + 1) // 2
237 | image = tf.image.crop_to_bounding_box(
238 | image, offset_height, offset_width, crop_height, crop_width)
239 |
240 | image = tf.image.resize([image], [height, width],
241 | method=tf.image.ResizeMethod.BICUBIC)[0]
242 |
243 | return image
244 |
245 |
246 | def distorted_bounding_box_crop(image,
247 | bbox,
248 | min_object_covered=0.1,
249 | aspect_ratio_range=(0.75, 1.33),
250 | area_range=(0.05, 1.0),
251 | max_attempts=100,
252 | scope=None):
253 | """Generates cropped_image using one of the bboxes randomly distorted.
254 |
255 | See `tf.image.sample_distorted_bounding_box` for more documentation.
256 |
257 | Args:
258 | image: `Tensor` of image data.
259 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
260 | where each coordinate is [0, 1) and the coordinates are arranged
261 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
262 | image.
263 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
264 | area of the image must contain at least this fraction of any bounding
265 | box supplied.
266 | aspect_ratio_range: An optional list of `float`s. The cropped area of the
267 | image must have an aspect ratio = width / height within this range.
268 | area_range: An optional list of `float`s. The cropped area of the image
269 |     must contain a fraction of the supplied image within this range.
270 | max_attempts: An optional `int`. Number of attempts at generating a cropped
271 | region of the image of the specified constraints. After `max_attempts`
272 | failures, return the entire image.
273 | scope: Optional `str` for name scope.
274 | Returns:
275 | (cropped image `Tensor`, distorted bbox `Tensor`).
276 | """
277 | with tf.name_scope(scope or 'distorted_bounding_box_crop'):
278 | shape = tf.shape(image)
279 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
280 | shape,
281 | bounding_boxes=bbox,
282 | min_object_covered=min_object_covered,
283 | aspect_ratio_range=aspect_ratio_range,
284 | area_range=area_range,
285 | max_attempts=max_attempts,
286 | use_image_if_no_bounding_boxes=True)
287 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box
288 |
289 | # Crop the image to the specified bounding box.
290 | offset_y, offset_x, _ = tf.unstack(bbox_begin)
291 | target_height, target_width, _ = tf.unstack(bbox_size)
292 | image = tf.image.crop_to_bounding_box(
293 | image, offset_y, offset_x, target_height, target_width)
294 |
295 | return image
296 |
297 |
298 | def crop_and_resize(image, height, width):
299 | """Make a random crop and resize it to height `height` and width `width`.
300 |
301 | Args:
302 | image: Tensor representing the image.
303 | height: Desired image height.
304 | width: Desired image width.
305 |
306 | Returns:
307 | A `height` x `width` x channels Tensor holding a random crop of `image`.
308 | """
309 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
310 | aspect_ratio = width / height
311 | image = distorted_bounding_box_crop(
312 | image,
313 | bbox,
314 | min_object_covered=0.1,
315 | aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. * aspect_ratio),
316 | area_range=(0.08, 1.0),
317 | max_attempts=100,
318 | scope=None)
319 | return tf.image.resize([image], [height, width],
320 | method=tf.image.ResizeMethod.BICUBIC)[0]
321 |
322 |
323 | def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
324 | """Blurs the given image with separable convolution.
325 |
326 |
327 | Args:
328 | image: Tensor of shape [height, width, channels] and dtype float to blur.
329 |     kernel_size: Integer Tensor for the size of the blur kernel. This
330 |       should be an odd number; if it is even, the actual kernel size will
331 |       be size + 1.
332 | sigma: Sigma value for gaussian operator.
333 | padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.
334 |
335 | Returns:
336 | A Tensor representing the blurred image.
337 | """
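  # A 2D Gaussian kernel is separable: it factorizes into the outer product of
  # two 1D kernels. One k x k convolution can therefore be replaced by a k x 1
  # pass followed by a 1 x k pass, cutting the per-pixel cost from O(k^2) to
  # O(k); that is what the two depthwise convolutions below implement.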
338 | radius = tf.cast(kernel_size / 2, dtype=tf.int32)
339 | kernel_size = radius * 2 + 1
340 | x = tf.cast(tf.range(-radius, radius + 1), dtype=tf.float32)
341 | blur_filter = tf.exp(-tf.pow(x, 2.0) /
342 | (2.0 * tf.pow(tf.cast(sigma, dtype=tf.float32), 2.0)))
343 | blur_filter /= tf.reduce_sum(blur_filter)
344 | # One vertical and one horizontal filter.
345 | blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
346 | blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
347 | num_channels = tf.shape(image)[-1]
348 | blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
349 | blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
350 | expand_batch_dim = image.shape.ndims == 3
351 | if expand_batch_dim:
352 | # Tensorflow requires batched input to convolutions, which we can fake with
353 | # an extra dimension.
354 | image = tf.expand_dims(image, axis=0)
355 | blurred = tf.nn.depthwise_conv2d(
356 | image, blur_h, strides=[1, 1, 1, 1], padding=padding)
357 | blurred = tf.nn.depthwise_conv2d(
358 | blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
359 | if expand_batch_dim:
360 | blurred = tf.squeeze(blurred, axis=0)
361 | return blurred
362 |
363 |
364 | def random_crop_with_resize(image, height, width, p=1.0):
365 | """Randomly crop and resize an image.
366 |
367 | Args:
368 | image: `Tensor` representing an image of arbitrary size.
369 | height: Height of output image.
370 | width: Width of output image.
371 | p: Probability of applying this transformation.
372 |
373 | Returns:
374 | A preprocessed image `Tensor`.
375 | """
376 | def _transform(image): # pylint: disable=missing-docstring
377 | image = crop_and_resize(image, height, width)
378 | return image
379 | return random_apply(_transform, p=p, x=image)
380 |
381 |
382 | def random_color_jitter(image, p=1.0, strength=1.0,
383 | impl='simclrv2'):
384 |   """With probability p, applies color jitter (p=0.8) then grayscale (p=0.2)."""
385 | def _transform(image):
386 | color_jitter_t = functools.partial(
387 | color_jitter, strength=strength, impl=impl)
388 | image = random_apply(color_jitter_t, p=0.8, x=image)
389 | return random_apply(to_grayscale, p=0.2, x=image)
390 | return random_apply(_transform, p=p, x=image)
391 |
392 |
393 | def random_blur(image, height, width, p=1.0):
394 | """Randomly blur an image.
395 |
396 | Args:
397 | image: `Tensor` representing an image of arbitrary size.
398 | height: Height of output image.
399 | width: Width of output image.
400 | p: probability of applying this transformation.
401 |
402 | Returns:
403 | A preprocessed image `Tensor`.
404 | """
405 | del width
406 | def _transform(image):
407 | sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
408 | return gaussian_blur(
409 | image, kernel_size=height//10, sigma=sigma, padding='SAME')
410 | return random_apply(_transform, p=p, x=image)
411 |
412 |
413 | def batch_random_blur(images_list, height, width, blur_probability=0.5):
414 | """Apply efficient batch data transformations.
415 |
416 | Args:
417 | images_list: a list of image tensors.
418 |     height: the height of the images.
419 |     width: the width of the images.
420 |     blur_probability: the probability of applying the blur operator.
421 |
422 | Returns:
423 | Preprocessed feature list.
424 | """
425 | def generate_selector(p, bsz):
426 | shape = [bsz, 1, 1, 1]
427 | selector = tf.cast(
428 | tf.less(tf.random.uniform(shape, 0, 1, dtype=tf.float32), p),
429 | tf.float32)
430 | return selector
431 |
432 | new_images_list = []
433 | for images in images_list:
434 | images_new = random_blur(images, height, width, p=1.)
435 | selector = generate_selector(blur_probability, tf.shape(images)[0])
436 | images = images_new * selector + images * (1 - selector)
437 | images = tf.clip_by_value(images, 0., 1.)
438 | new_images_list.append(images)
439 |
440 | return new_images_list
441 |
442 |
443 | def preprocess_for_train(image,
444 | height,
445 | width,
446 | color_jitter_strength=0.,
447 | crop=True,
448 | flip=True,
449 | impl='simclrv2'):
450 | """Preprocesses the given image for training.
451 |
452 | Args:
453 | image: `Tensor` representing an image of arbitrary size.
454 | height: Height of output image.
455 | width: Width of output image.
456 | color_jitter_strength: `float` between 0 and 1 indicating the color
457 |       distortion strength; disabled if it is not greater than 0.
458 | crop: Whether to crop the image.
459 |     flip: Whether or not to randomly flip the image left to right.
460 | impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
461 | version of random brightness.
462 |
463 | Returns:
464 | A preprocessed image `Tensor`.
465 | """
466 | if crop:
467 | image = random_crop_with_resize(image, height, width)
468 | if flip:
469 | image = tf.image.random_flip_left_right(image)
470 | if color_jitter_strength > 0:
471 | image = random_color_jitter(image, strength=color_jitter_strength,
472 | impl=impl)
473 | image = tf.reshape(image, [height, width, 3])
474 | image = tf.clip_by_value(image, 0., 1.)
475 | return image
476 |
477 |
478 | def preprocess_for_eval(image, height, width, crop=True):
479 | """Preprocesses the given image for evaluation.
480 |
481 | Args:
482 | image: `Tensor` representing an image of arbitrary size.
483 | height: Height of output image.
484 | width: Width of output image.
485 | crop: Whether or not to (center) crop the test images.
486 |
487 | Returns:
488 | A preprocessed image `Tensor`.
489 | """
490 | if crop:
491 | image = center_crop(image, height, width, crop_proportion=CROP_PROPORTION)
492 | image = tf.reshape(image, [height, width, 3])
493 | image = tf.clip_by_value(image, 0., 1.)
494 | return image
495 |
496 |
497 | def preprocess_image(image, height, width, is_training=False,
498 | color_jitter_strength=0., test_crop=True):
499 | """Preprocesses the given image.
500 |
501 | Args:
502 | image: `Tensor` representing an image of arbitrary size.
503 | height: Height of output image.
504 | width: Width of output image.
505 | is_training: `bool` for whether the preprocessing is for training.
506 | color_jitter_strength: `float` between 0 and 1 indicating the color
507 |       distortion strength; disabled if it is not greater than 0.
508 | test_crop: whether or not to extract a central crop of the images
509 | (as for standard ImageNet evaluation) during the evaluation.
510 |
511 | Returns:
512 | A preprocessed image `Tensor` of range [0, 1].
513 | """
514 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
515 | if is_training:
516 | return preprocess_for_train(image, height, width, color_jitter_strength)
517 | else:
518 | return preprocess_for_eval(image, height, width, test_crop)
519 |
--------------------------------------------------------------------------------
/tf2/lars_optimizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Functions and classes related to optimization (weight updates)."""
17 |
18 | import re
19 |
20 | import tensorflow.compat.v2 as tf
21 |
22 | EETA_DEFAULT = 0.001
23 |
24 |
25 | class LARSOptimizer(tf.keras.optimizers.legacy.Optimizer):
26 | """Layer-wise Adaptive Rate Scaling for large batch training.
27 |
28 | Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
29 | I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
30 | """
31 |
32 | def __init__(self,
33 | learning_rate,
34 | momentum=0.9,
35 | use_nesterov=False,
36 | weight_decay=0.0,
37 | exclude_from_weight_decay=None,
38 | exclude_from_layer_adaptation=None,
39 | classic_momentum=True,
40 | eeta=EETA_DEFAULT,
41 | name="LARSOptimizer"):
42 | """Constructs a LARSOptimizer.
43 |
44 | Args:
45 | learning_rate: A `float` for learning rate.
46 | momentum: A `float` for momentum.
47 |       use_nesterov: A `bool` for whether to use Nesterov momentum.
48 |       weight_decay: A `float` for weight decay.
49 |       exclude_from_weight_decay: A list of `string`s for variable screening:
50 |         if any of the strings appears in a variable's name, the variable is
51 |         excluded from weight decay. For example, one could specify
52 |         ['batch_normalization', 'bias'] to exclude BN and bias from weight
53 |         decay.
54 |       exclude_from_layer_adaptation: Similar to exclude_from_weight_decay,
55 |         but for layer adaptation. If None, it defaults to
56 |         exclude_from_weight_decay.
57 |       classic_momentum: A `bool` for whether to use classic (vs. popular)
58 |         momentum. With classic momentum the learning rate is applied during
59 |         the momentum update; with popular momentum it is applied afterwards.
60 | eeta: A `float` for scaling of learning rate when computing trust ratio.
61 | name: The name for the scope.
62 | """
63 | super(LARSOptimizer, self).__init__(name)
64 |
65 | self._set_hyper("learning_rate", learning_rate)
66 | self.momentum = momentum
67 | self.weight_decay = weight_decay
68 | self.use_nesterov = use_nesterov
69 | self.classic_momentum = classic_momentum
70 | self.eeta = eeta
71 | self.exclude_from_weight_decay = exclude_from_weight_decay
72 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
73 | # arg is None.
74 | if exclude_from_layer_adaptation:
75 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
76 | else:
77 | self.exclude_from_layer_adaptation = exclude_from_weight_decay
78 |
79 | def _create_slots(self, var_list):
80 | for v in var_list:
81 | self.add_slot(v, "Momentum")
82 |
83 | def _resource_apply_dense(self, grad, param, apply_state=None):
84 | if grad is None or param is None:
85 | return tf.no_op()
86 |
87 | var_device, var_dtype = param.device, param.dtype.base_dtype
88 | coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
89 | self._fallback_apply_state(var_device, var_dtype))
90 | learning_rate = coefficients["lr_t"]
91 |
92 | param_name = param.name
93 |
94 | v = self.get_slot(param, "Momentum")
95 |
96 | if self._use_weight_decay(param_name):
97 | grad += self.weight_decay * param
98 |
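    # With classic momentum (and no Nesterov), the update below amounts to:
    #   trust_ratio = eeta * ||w|| / ||g||  (per layer; 1.0 where excluded)
    #   v <- momentum * v + (lr * trust_ratio) * g
    #   w <- w - v
    # i.e. each layer's learning rate is rescaled by the ratio of its weight
    # norm to its gradient norm.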
99 | if self.classic_momentum:
100 | trust_ratio = 1.0
101 | if self._do_layer_adaptation(param_name):
102 | w_norm = tf.norm(param, ord=2)
103 | g_norm = tf.norm(grad, ord=2)
104 | trust_ratio = tf.where(
105 | tf.greater(w_norm, 0),
106 | tf.where(tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm), 1.0),
107 | 1.0)
108 | scaled_lr = learning_rate * trust_ratio
109 |
110 | next_v = tf.multiply(self.momentum, v) + scaled_lr * grad
111 | if self.use_nesterov:
112 | update = tf.multiply(self.momentum, next_v) + scaled_lr * grad
113 | else:
114 | update = next_v
115 | next_param = param - update
116 | else:
117 | next_v = tf.multiply(self.momentum, v) + grad
118 | if self.use_nesterov:
119 | update = tf.multiply(self.momentum, next_v) + grad
120 | else:
121 | update = next_v
122 |
123 | trust_ratio = 1.0
124 | if self._do_layer_adaptation(param_name):
125 | w_norm = tf.norm(param, ord=2)
126 | v_norm = tf.norm(update, ord=2)
127 | trust_ratio = tf.where(
128 | tf.greater(w_norm, 0),
129 | tf.where(tf.greater(v_norm, 0), (self.eeta * w_norm / v_norm), 1.0),
130 | 1.0)
131 | scaled_lr = trust_ratio * learning_rate
132 | next_param = param - scaled_lr * update
133 |
134 | return tf.group(*[
135 | param.assign(next_param, use_locking=False),
136 | v.assign(next_v, use_locking=False)
137 | ])
138 |
139 | def _use_weight_decay(self, param_name):
140 | """Whether to use L2 weight decay for `param_name`."""
141 | if not self.weight_decay:
142 | return False
143 | if self.exclude_from_weight_decay:
144 | for r in self.exclude_from_weight_decay:
145 | # TODO(srbs): Try to avoid name based filtering.
146 | if re.search(r, param_name) is not None:
147 | return False
148 | return True
149 |
150 | def _do_layer_adaptation(self, param_name):
151 | """Whether to do layer-wise learning rate adaptation for `param_name`."""
152 | if self.exclude_from_layer_adaptation:
153 | for r in self.exclude_from_layer_adaptation:
154 | # TODO(srbs): Try to avoid name based filtering.
155 | if re.search(r, param_name) is not None:
156 | return False
157 | return True
158 |
159 | def get_config(self):
160 | config = super(LARSOptimizer, self).get_config()
161 | config.update({
162 | "learning_rate": self._serialize_hyperparameter("learning_rate"),
163 | "momentum": self.momentum,
164 | "classic_momentum": self.classic_momentum,
165 | "weight_decay": self.weight_decay,
166 | "eeta": self.eeta,
167 | "use_nesterov": self.use_nesterov,
168 | })
169 | return config
170 |
--------------------------------------------------------------------------------
/tf2/metrics.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Training utilities."""
17 |
18 | from absl import logging
19 |
20 | import tensorflow.compat.v2 as tf
21 |
22 |
23 | def update_pretrain_metrics_train(contrast_loss, contrast_acc, contrast_entropy,
24 | loss, logits_con, labels_con):
25 |   """Updates pretraining metrics."""
26 | contrast_loss.update_state(loss)
27 |
28 | contrast_acc_val = tf.equal(
29 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1))
30 | contrast_acc_val = tf.reduce_mean(tf.cast(contrast_acc_val, tf.float32))
31 | contrast_acc.update_state(contrast_acc_val)
32 |
33 | prob_con = tf.nn.softmax(logits_con)
34 | entropy_con = -tf.reduce_mean(
35 | tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))
36 | contrast_entropy.update_state(entropy_con)
37 |
38 |
39 | def update_pretrain_metrics_eval(contrast_loss_metric,
40 | contrastive_top_1_accuracy_metric,
41 | contrastive_top_5_accuracy_metric,
42 | contrast_loss, logits_con, labels_con):
43 | contrast_loss_metric.update_state(contrast_loss)
44 | contrastive_top_1_accuracy_metric.update_state(
45 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1))
46 | contrastive_top_5_accuracy_metric.update_state(labels_con, logits_con)
47 |
48 |
49 | def update_finetune_metrics_train(supervised_loss_metric, supervised_acc_metric,
50 | loss, labels, logits):
51 | supervised_loss_metric.update_state(loss)
52 |
53 | label_acc = tf.equal(tf.argmax(labels, 1), tf.argmax(logits, axis=1))
54 | label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
55 | supervised_acc_metric.update_state(label_acc)
56 |
57 |
58 | def update_finetune_metrics_eval(label_top_1_accuracy_metrics,
59 | label_top_5_accuracy_metrics, outputs, labels):
60 | label_top_1_accuracy_metrics.update_state(
61 | tf.argmax(labels, 1), tf.argmax(outputs, axis=1))
62 | label_top_5_accuracy_metrics.update_state(labels, outputs)
63 |
64 |
65 | def _float_metric_value(metric):
66 |   """Gets the value of a float-valued Keras metric."""
67 | return metric.result().numpy().astype(float)
68 |
69 |
70 | def log_and_write_metrics_to_summary(all_metrics, global_step):
71 | for metric in all_metrics:
72 | metric_value = _float_metric_value(metric)
73 | logging.info('Step: [%d] %s = %f', global_step, metric.name, metric_value)
74 | tf.summary.scalar(metric.name, metric_value, step=global_step)
75 |
--------------------------------------------------------------------------------
/tf2/model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Model specification for SimCLR."""
17 |
18 | import math
19 | from absl import flags
20 |
21 | import data_util
22 | import lars_optimizer
23 | import resnet
24 | import tensorflow.compat.v2 as tf
25 |
26 | FLAGS = flags.FLAGS
27 |
28 |
29 | def build_optimizer(learning_rate):
30 | """Returns the optimizer."""
31 | if FLAGS.optimizer == 'momentum':
32 | return tf.keras.optimizers.SGD(learning_rate, FLAGS.momentum, nesterov=True)
33 | elif FLAGS.optimizer == 'adam':
34 | return tf.keras.optimizers.Adam(learning_rate)
35 | elif FLAGS.optimizer == 'lars':
36 | return lars_optimizer.LARSOptimizer(
37 | learning_rate,
38 | momentum=FLAGS.momentum,
39 | weight_decay=FLAGS.weight_decay,
40 | exclude_from_weight_decay=[
41 | 'batch_normalization', 'bias', 'head_supervised'
42 | ])
43 | else:
44 | raise ValueError('Unknown optimizer {}'.format(FLAGS.optimizer))
45 |
46 |
47 | def add_weight_decay(model, adjust_per_optimizer=True):
48 | """Compute weight decay from flags."""
49 | if adjust_per_optimizer and 'lars' in FLAGS.optimizer:
50 |     # Weight decay is handled by the optimizer in this case, except for
51 |     # the supervised head, whose decay is added here.
52 | l2_losses = [
53 | tf.nn.l2_loss(v)
54 | for v in model.trainable_variables
55 | if 'head_supervised' in v.name and 'bias' not in v.name
56 | ]
57 | if l2_losses:
58 | return FLAGS.weight_decay * tf.add_n(l2_losses)
59 | else:
60 | return 0
61 |
62 | # TODO(srbs): Think of a way to avoid name-based filtering here.
63 | l2_losses = [
64 | tf.nn.l2_loss(v)
65 | for v in model.trainable_weights
66 | if 'batch_normalization' not in v.name
67 | ]
68 | loss = FLAGS.weight_decay * tf.add_n(l2_losses)
69 | return loss
70 |
71 |
72 | def get_train_steps(num_examples):
73 | """Determine the number of training steps."""
74 | return FLAGS.train_steps or (
75 | num_examples * FLAGS.train_epochs // FLAGS.train_batch_size + 1)
76 |
77 |
78 | class WarmUpAndCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
79 | """Applies a warmup schedule on a given learning rate decay schedule."""
80 |
81 | def __init__(self, base_learning_rate, num_examples, name=None):
82 | super(WarmUpAndCosineDecay, self).__init__()
83 | self.base_learning_rate = base_learning_rate
84 | self.num_examples = num_examples
85 | self._name = name
86 |
87 | def __call__(self, step):
88 | with tf.name_scope(self._name or 'WarmUpAndCosineDecay'):
89 | warmup_steps = int(
90 | round(FLAGS.warmup_epochs * self.num_examples //
91 | FLAGS.train_batch_size))
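      # Scale the base learning rate with the global batch size, either
      # linearly (base_lr * batch_size / 256) or with its square root
      # (base_lr * sqrt(batch_size)).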
92 | if FLAGS.learning_rate_scaling == 'linear':
93 | scaled_lr = self.base_learning_rate * FLAGS.train_batch_size / 256.
94 | elif FLAGS.learning_rate_scaling == 'sqrt':
95 | scaled_lr = self.base_learning_rate * math.sqrt(FLAGS.train_batch_size)
96 | else:
97 | raise ValueError('Unknown learning rate scaling {}'.format(
98 | FLAGS.learning_rate_scaling))
99 | learning_rate = (
100 | step / float(warmup_steps) * scaled_lr if warmup_steps else scaled_lr)
101 |
102 | # Cosine decay learning rate schedule
103 | total_steps = get_train_steps(self.num_examples)
104 | # TODO(srbs): Cache this object.
105 | cosine_decay = tf.keras.experimental.CosineDecay(
106 | scaled_lr, total_steps - warmup_steps)
107 | learning_rate = tf.where(step < warmup_steps, learning_rate,
108 | cosine_decay(step - warmup_steps))
109 |
110 | return learning_rate
111 |
112 | def get_config(self):
113 | return {
114 | 'base_learning_rate': self.base_learning_rate,
115 | 'num_examples': self.num_examples,
116 | }
117 |
118 |
119 | class LinearLayer(tf.keras.layers.Layer):
120 |
121 | def __init__(self,
122 | num_classes,
123 | use_bias=True,
124 | use_bn=False,
125 | name='linear_layer',
126 | **kwargs):
127 | # Note: use_bias is ignored for the dense layer when use_bn=True.
128 | # However, it is still used for batch norm.
129 | super(LinearLayer, self).__init__(**kwargs)
130 | self.num_classes = num_classes
131 | self.use_bias = use_bias
132 | self.use_bn = use_bn
133 | self._name = name
134 | if self.use_bn:
135 | self.bn_relu = resnet.BatchNormRelu(relu=False, center=use_bias)
136 |
137 | def build(self, input_shape):
138 | # TODO(srbs): Add a new SquareDense layer.
139 | if callable(self.num_classes):
140 | num_classes = self.num_classes(input_shape)
141 | else:
142 | num_classes = self.num_classes
143 | self.dense = tf.keras.layers.Dense(
144 | num_classes,
145 | kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
146 | use_bias=self.use_bias and not self.use_bn)
147 | super(LinearLayer, self).build(input_shape)
148 |
149 | def call(self, inputs, training):
150 | assert inputs.shape.ndims == 2, inputs.shape
151 | inputs = self.dense(inputs)
152 | if self.use_bn:
153 | inputs = self.bn_relu(inputs, training=training)
154 | return inputs
155 |
156 |
157 | class ProjectionHead(tf.keras.layers.Layer):
158 |
159 | def __init__(self, **kwargs):
160 | out_dim = FLAGS.proj_out_dim
161 | self.linear_layers = []
162 | if FLAGS.proj_head_mode == 'none':
163 | pass # directly use the output hiddens as hiddens
164 | elif FLAGS.proj_head_mode == 'linear':
165 | self.linear_layers = [
166 | LinearLayer(
167 | num_classes=out_dim, use_bias=False, use_bn=True, name='l_0')
168 | ]
169 | elif FLAGS.proj_head_mode == 'nonlinear':
170 | for j in range(FLAGS.num_proj_layers):
171 | if j != FLAGS.num_proj_layers - 1:
172 | # for the middle layers, use bias and relu for the output.
173 | self.linear_layers.append(
174 | LinearLayer(
175 | num_classes=lambda input_shape: int(input_shape[-1]),
176 | use_bias=True,
177 | use_bn=True,
178 | name='nl_%d' % j))
179 | else:
180 | # for the final layer, neither bias nor relu is used.
181 | self.linear_layers.append(
182 | LinearLayer(
183 | num_classes=FLAGS.proj_out_dim,
184 | use_bias=False,
185 | use_bn=True,
186 | name='nl_%d' % j))
187 | else:
188 | raise ValueError('Unknown head projection mode {}'.format(
189 | FLAGS.proj_head_mode))
190 | super(ProjectionHead, self).__init__(**kwargs)
191 |
192 | def call(self, inputs, training):
193 | if FLAGS.proj_head_mode == 'none':
194 | return inputs # directly use the output hiddens as hiddens
195 | hiddens_list = [tf.identity(inputs, 'proj_head_input')]
196 | if FLAGS.proj_head_mode == 'linear':
197 | assert len(self.linear_layers) == 1, len(self.linear_layers)
198 |       hiddens_list.append(self.linear_layers[0](hiddens_list[-1],
199 |                                                 training))
200 | elif FLAGS.proj_head_mode == 'nonlinear':
201 | for j in range(FLAGS.num_proj_layers):
202 | hiddens = self.linear_layers[j](hiddens_list[-1], training)
203 | if j != FLAGS.num_proj_layers - 1:
204 | # for the middle layers, use bias and relu for the output.
205 | hiddens = tf.nn.relu(hiddens)
206 | hiddens_list.append(hiddens)
207 | else:
208 | raise ValueError('Unknown head projection mode {}'.format(
209 | FLAGS.proj_head_mode))
210 | # The first element is the output of the projection head.
211 | # The second element is the input of the finetune head.
212 | proj_head_output = tf.identity(hiddens_list[-1], 'proj_head_output')
213 | return proj_head_output, hiddens_list[FLAGS.ft_proj_selector]
214 |
215 |
216 | class SupervisedHead(tf.keras.layers.Layer):
217 |
218 | def __init__(self, num_classes, name='head_supervised', **kwargs):
219 | super(SupervisedHead, self).__init__(name=name, **kwargs)
220 | self.linear_layer = LinearLayer(num_classes)
221 |
222 | def call(self, inputs, training):
223 | inputs = self.linear_layer(inputs, training)
224 | inputs = tf.identity(inputs, name='logits_sup')
225 | return inputs
226 |
227 |
228 | class Model(tf.keras.models.Model):
229 | """Resnet model with projection or supervised layer."""
230 |
231 | def __init__(self, num_classes, **kwargs):
232 | super(Model, self).__init__(**kwargs)
233 | self.resnet_model = resnet.resnet(
234 | resnet_depth=FLAGS.resnet_depth,
235 | width_multiplier=FLAGS.width_multiplier,
236 | cifar_stem=FLAGS.image_size <= 32)
237 | self._projection_head = ProjectionHead()
238 | if FLAGS.train_mode == 'finetune' or FLAGS.lineareval_while_pretraining:
239 | self.supervised_head = SupervisedHead(num_classes)
240 |
241 | def __call__(self, inputs, training):
242 | features = inputs
243 | if training and FLAGS.train_mode == 'pretrain':
244 | if FLAGS.fine_tune_after_block > -1:
245 |         raise ValueError('Does not support layer freezing during pretraining, '
246 |                          'should set fine_tune_after_block<=-1 for safety.')
247 | if inputs.shape[3] is None:
248 | raise ValueError('The input channels dimension must be statically known '
249 | f'(got input shape {inputs.shape})')
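      # data.py concatenates the two augmented views of each image along the
      # channel axis, so a pretraining batch arrives as (bsz, h, w, 2 * 3);
      # split it back into per-view tensors here.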
250 |       num_transforms = inputs.shape[3] // 3
251 |       split_sizes = tf.repeat(3, num_transforms)
252 |       # Split channels, and optionally apply extra batched augmentation.
253 |       features_list = tf.split(
254 |           features, num_or_size_splits=split_sizes, axis=-1)
255 | if FLAGS.use_blur and training and FLAGS.train_mode == 'pretrain':
256 | features_list = data_util.batch_random_blur(features_list,
257 | FLAGS.image_size,
258 | FLAGS.image_size)
259 | features = tf.concat(features_list, 0) # (num_transforms * bsz, h, w, c)
260 |
261 | # Base network forward pass.
262 | hiddens = self.resnet_model(features, training=training)
263 |
264 | # Add heads.
265 | projection_head_outputs, supervised_head_inputs = self._projection_head(
266 | hiddens, training)
267 |
268 | if FLAGS.train_mode == 'finetune':
269 | supervised_head_outputs = self.supervised_head(supervised_head_inputs,
270 | training)
271 | return None, supervised_head_outputs
272 | elif FLAGS.train_mode == 'pretrain' and FLAGS.lineareval_while_pretraining:
273 |       # When performing pretraining and linear evaluation together, we do
274 |       # not want information from the linear eval head flowing back into
275 |       # the pretraining network, so we insert a stop_gradient.
276 | supervised_head_outputs = self.supervised_head(
277 | tf.stop_gradient(supervised_head_inputs), training)
278 | return projection_head_outputs, supervised_head_outputs
279 | else:
280 | return projection_head_outputs, None
281 |
--------------------------------------------------------------------------------
/tf2/objective.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Contrastive loss functions."""
17 |
18 | from absl import flags
19 |
20 | import tensorflow.compat.v2 as tf
21 |
22 | FLAGS = flags.FLAGS
23 |
24 | LARGE_NUM = 1e9
25 |
26 |
27 | def add_supervised_loss(labels, logits):
28 | """Compute mean supervised loss over local batch."""
29 | losses = tf.keras.losses.CategoricalCrossentropy(
30 | from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels,
31 | logits)
32 | return tf.reduce_mean(losses)
33 |
34 |
35 | def add_contrastive_loss(hidden,
36 | hidden_norm=True,
37 | temperature=1.0,
38 | strategy=None):
39 | """Compute loss for model.
40 |
41 | Args:
42 | hidden: hidden vector (`Tensor`) of shape (bsz, dim).
43 | hidden_norm: whether or not to use normalization on the hidden vector.
44 |     temperature: a `float` for temperature scaling.
45 |     strategy: a `tf.distribute.Strategy` for cross-replica gathering, or None.
46 |
47 | Returns:
48 | A loss scalar.
49 | The logits for contrastive prediction task.
50 | The labels for contrastive prediction task.
51 | """
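  # In the paper's notation, for a positive pair (i, j) this computes the
  # NT-Xent loss
  #   l(i, j) = -log( exp(sim(z_i, z_j) / tau)
  #                   / sum_{k != i} exp(sim(z_i, z_k) / tau) ),
  # where sim is the dot product (cosine similarity when hidden_norm=True) and
  # tau is `temperature`; the masks below remove the k == i self-similarity
  # terms from the denominator.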
52 | # Get (normalized) hidden1 and hidden2.
53 | if hidden_norm:
54 | hidden = tf.math.l2_normalize(hidden, -1)
55 | hidden1, hidden2 = tf.split(hidden, 2, 0)
56 | batch_size = tf.shape(hidden1)[0]
57 |
58 | # Gather hidden1/hidden2 across replicas and create local labels.
59 | if strategy is not None:
60 | hidden1_large = tpu_cross_replica_concat(hidden1, strategy)
61 | hidden2_large = tpu_cross_replica_concat(hidden2, strategy)
62 | enlarged_batch_size = tf.shape(hidden1_large)[0]
63 | # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
64 | replica_context = tf.distribute.get_replica_context()
65 | replica_id = tf.cast(
66 | tf.cast(replica_context.replica_id_in_sync_group, tf.uint32), tf.int32)
67 | labels_idx = tf.range(batch_size) + replica_id * batch_size
68 | labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
69 | masks = tf.one_hot(labels_idx, enlarged_batch_size)
70 | else:
71 | hidden1_large = hidden1
72 | hidden2_large = hidden2
73 | labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
74 | masks = tf.one_hot(tf.range(batch_size), batch_size)
75 |
76 | logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
77 | logits_aa = logits_aa - masks * LARGE_NUM
78 | logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
79 | logits_bb = logits_bb - masks * LARGE_NUM
80 | logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
81 | logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature
82 |
83 | loss_a = tf.nn.softmax_cross_entropy_with_logits(
84 | labels, tf.concat([logits_ab, logits_aa], 1))
85 | loss_b = tf.nn.softmax_cross_entropy_with_logits(
86 | labels, tf.concat([logits_ba, logits_bb], 1))
87 | loss = tf.reduce_mean(loss_a + loss_b)
88 |
89 | return loss, logits_ab, labels
90 |
91 |
92 | def tpu_cross_replica_concat(tensor, strategy=None):
93 |   """Concatenates `tensor` across TPU replicas (implemented via all-reduce).
94 |
95 | Args:
96 | tensor: tensor to concatenate.
97 | strategy: A `tf.distribute.Strategy`. If not set, CPU execution is assumed.
98 |
99 | Returns:
100 | Tensor of the same rank as `tensor` with first dimension `num_replicas`
101 | times larger.
102 | """
103 | if strategy is None or strategy.num_replicas_in_sync <= 1:
104 | return tensor
105 |
106 | num_replicas = strategy.num_replicas_in_sync
107 |
108 | replica_context = tf.distribute.get_replica_context()
109 | with tf.name_scope('tpu_cross_replica_concat'):
110 | # This creates a tensor that is like the input tensor but has an added
111 | # replica dimension as the outermost dimension. On each replica it will
112 | # contain the local values and zeros for all other values that need to be
113 | # fetched from other replicas.
114 | ext_tensor = tf.scatter_nd(
115 | indices=[[replica_context.replica_id_in_sync_group]],
116 | updates=[tensor],
117 | shape=tf.concat([[num_replicas], tf.shape(tensor)], axis=0))
118 |
119 | # As every value is only present on one replica and 0 in all others, adding
120 | # them all together will result in the full tensor on all replicas.
121 | ext_tensor = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
122 | ext_tensor)
123 |
124 | # Flatten the replica dimension.
125 | # The first dimension size will be: tensor.shape[0] * num_replicas
126 | # Using [-1] trick to support also scalar input.
127 | return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
128 |
--------------------------------------------------------------------------------
/tf2/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow
2 | tensorflow-datasets
3 |
--------------------------------------------------------------------------------
/tf2/resnet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Contains definitions for the post-activation form of Residual Networks.
17 |
18 | Residual networks (ResNets) were proposed in:
19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
21 | """
22 |
23 | from absl import flags
24 | import tensorflow.compat.v2 as tf
25 |
26 |
27 | FLAGS = flags.FLAGS
28 | BATCH_NORM_EPSILON = 1e-5
29 |
30 |
31 | class BatchNormRelu(tf.keras.layers.Layer): # pylint: disable=missing-docstring
32 |
33 | def __init__(self,
34 | relu=True,
35 | init_zero=False,
36 | center=True,
37 | scale=True,
38 | data_format='channels_last',
39 | **kwargs):
40 | super(BatchNormRelu, self).__init__(**kwargs)
41 | self.relu = relu
42 | if init_zero:
43 | gamma_initializer = tf.zeros_initializer()
44 | else:
45 | gamma_initializer = tf.ones_initializer()
46 | if data_format == 'channels_first':
47 | axis = 1
48 | else:
49 | axis = -1
50 | if FLAGS.global_bn:
51 | # TODO(srbs): Set fused=True
52 | # Batch normalization layers with fused=True only support 4D input
53 | # tensors.
54 | self.bn = tf.keras.layers.experimental.SyncBatchNormalization(
55 | axis=axis,
56 | momentum=FLAGS.batch_norm_decay,
57 | epsilon=BATCH_NORM_EPSILON,
58 | center=center,
59 | scale=scale,
60 | gamma_initializer=gamma_initializer)
61 | else:
62 | # TODO(srbs): Set fused=True
63 | # Batch normalization layers with fused=True only support 4D input
64 | # tensors.
65 | self.bn = tf.keras.layers.BatchNormalization(
66 | axis=axis,
67 | momentum=FLAGS.batch_norm_decay,
68 | epsilon=BATCH_NORM_EPSILON,
69 | center=center,
70 | scale=scale,
71 | fused=False,
72 | gamma_initializer=gamma_initializer)
73 |
74 | def call(self, inputs, training):
75 | inputs = self.bn(inputs, training=training)
76 | if self.relu:
77 | inputs = tf.nn.relu(inputs)
78 | return inputs
79 |
80 |
81 | class DropBlock(tf.keras.layers.Layer): # pylint: disable=missing-docstring
82 |
83 | def __init__(self,
84 | keep_prob,
85 | dropblock_size,
86 | data_format='channels_last',
87 | **kwargs):
88 | self.keep_prob = keep_prob
89 | self.dropblock_size = dropblock_size
90 | self.data_format = data_format
91 | super(DropBlock, self).__init__(**kwargs)
92 |
93 | def call(self, net, training):
94 | keep_prob = self.keep_prob
95 | dropblock_size = self.dropblock_size
96 | data_format = self.data_format
97 | if not training or keep_prob is None:
98 | return net
99 |
100 |     tf.get_logger().info(
101 | 'Applying DropBlock: dropblock_size {}, net.shape {}'.format(
102 | dropblock_size, net.shape))
103 |
104 | if data_format == 'channels_last':
105 | _, width, height, _ = net.get_shape().as_list()
106 | else:
107 | _, _, width, height = net.get_shape().as_list()
108 | if width != height:
109 | raise ValueError('Input tensor with width!=height is not supported.')
110 |
111 | dropblock_size = min(dropblock_size, width)
112 |     # seed_drop_rate is the gamma parameter of DropBlock.
113 | seed_drop_rate = (1.0 - keep_prob) * width**2 / dropblock_size**2 / (
114 | width - dropblock_size + 1)**2
115 |
116 | # Forces the block to be inside the feature map.
117 | w_i, h_i = tf.meshgrid(tf.range(width), tf.range(width))
118 | valid_block_center = tf.logical_and(
119 | tf.logical_and(w_i >= int(dropblock_size // 2),
120 | w_i < width - (dropblock_size - 1) // 2),
121 | tf.logical_and(h_i >= int(dropblock_size // 2),
122 | h_i < width - (dropblock_size - 1) // 2))
123 |
124 | valid_block_center = tf.expand_dims(valid_block_center, 0)
125 | valid_block_center = tf.expand_dims(
126 | valid_block_center, -1 if data_format == 'channels_last' else 0)
127 |
128 |     randnoise = tf.random.uniform(net.shape, dtype=tf.float32)
129 | block_pattern = (
130 | 1 - tf.cast(valid_block_center, dtype=tf.float32) + tf.cast(
131 | (1 - seed_drop_rate), dtype=tf.float32) + randnoise) >= 1
132 | block_pattern = tf.cast(block_pattern, dtype=tf.float32)
133 |
134 | if dropblock_size == width:
135 | block_pattern = tf.reduce_min(
136 | block_pattern,
137 | axis=[1, 2] if data_format == 'channels_last' else [2, 3],
138 | keepdims=True)
139 | else:
140 | if data_format == 'channels_last':
141 | ksize = [1, dropblock_size, dropblock_size, 1]
142 | else:
143 | ksize = [1, 1, dropblock_size, dropblock_size]
144 | block_pattern = -tf.nn.max_pool(
145 | -block_pattern,
146 | ksize=ksize,
147 | strides=[1, 1, 1, 1],
148 | padding='SAME',
149 | data_format='NHWC' if data_format == 'channels_last' else 'NCHW')
150 |
151 | percent_ones = (
152 | tf.cast(tf.reduce_sum((block_pattern)), tf.float32) /
153 | tf.cast(tf.size(block_pattern), tf.float32))
154 |
155 | net = net / tf.cast(percent_ones, net.dtype) * tf.cast(
156 | block_pattern, net.dtype)
157 | return net
158 |
159 |
160 | class FixedPadding(tf.keras.layers.Layer): # pylint: disable=missing-docstring
161 |
162 | def __init__(self, kernel_size, data_format='channels_last', **kwargs):
163 | super(FixedPadding, self).__init__(**kwargs)
164 | self.kernel_size = kernel_size
165 | self.data_format = data_format
166 |
167 | def call(self, inputs, training):
168 | kernel_size = self.kernel_size
169 | data_format = self.data_format
170 | pad_total = kernel_size - 1
171 | pad_beg = pad_total // 2
172 | pad_end = pad_total - pad_beg
173 | if data_format == 'channels_first':
174 | padded_inputs = tf.pad(
175 | inputs, [[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]])
176 | else:
177 | padded_inputs = tf.pad(
178 | inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
179 |
180 | return padded_inputs
181 |
182 |
183 | class Conv2dFixedPadding(tf.keras.layers.Layer): # pylint: disable=missing-docstring
184 |
185 | def __init__(self,
186 | filters,
187 | kernel_size,
188 | strides,
189 | data_format='channels_last',
190 | **kwargs):
191 | super(Conv2dFixedPadding, self).__init__(**kwargs)
192 | if strides > 1:
193 | self.fixed_padding = FixedPadding(kernel_size, data_format=data_format)
194 | else:
195 | self.fixed_padding = None
196 | self.conv2d = tf.keras.layers.Conv2D(
197 | filters=filters,
198 | kernel_size=kernel_size,
199 | strides=strides,
200 | padding=('SAME' if strides == 1 else 'VALID'),
201 | use_bias=False,
202 | kernel_initializer=tf.keras.initializers.VarianceScaling(),
203 | data_format=data_format)
204 |
205 | def call(self, inputs, training):
206 | if self.fixed_padding:
207 | inputs = self.fixed_padding(inputs, training=training)
208 | return self.conv2d(inputs, training=training)
209 |
210 |
211 | class IdentityLayer(tf.keras.layers.Layer):
212 |
213 | def call(self, inputs, training):
214 | return tf.identity(inputs)
215 |
216 |
217 | class SK_Conv2D(tf.keras.layers.Layer): # pylint: disable=invalid-name
218 | """Selective kernel convolutional layer (https://arxiv.org/abs/1903.06586)."""
219 |
220 | def __init__(self,
221 | filters,
222 | strides,
223 | sk_ratio,
224 | min_dim=32,
225 | data_format='channels_last',
226 | **kwargs):
227 | super(SK_Conv2D, self).__init__(**kwargs)
228 | self.data_format = data_format
229 | self.filters = filters
230 | self.sk_ratio = sk_ratio
231 | self.min_dim = min_dim
232 |
233 | # Two stream convs (using split and both are 3x3).
234 | self.conv2d_fixed_padding = Conv2dFixedPadding(
235 | filters=2 * filters,
236 | kernel_size=3,
237 | strides=strides,
238 | data_format=data_format)
239 | self.batch_norm_relu = BatchNormRelu(data_format=data_format)
240 |
241 | # Mixing weights for two streams.
242 | mid_dim = max(int(filters * sk_ratio), min_dim)
243 | self.conv2d_0 = tf.keras.layers.Conv2D(
244 | filters=mid_dim,
245 | kernel_size=1,
246 | strides=1,
247 | kernel_initializer=tf.keras.initializers.VarianceScaling(),
248 | use_bias=False,
249 | data_format=data_format)
250 | self.batch_norm_relu_1 = BatchNormRelu(data_format=data_format)
251 | self.conv2d_1 = tf.keras.layers.Conv2D(
252 | filters=2 * filters,
253 | kernel_size=1,
254 | strides=1,
255 | kernel_initializer=tf.keras.initializers.VarianceScaling(),
256 | use_bias=False,
257 | data_format=data_format)
258 |
259 | def call(self, inputs, training):
260 | channel_axis = 1 if self.data_format == 'channels_first' else 3
261 | pooling_axes = [2, 3] if self.data_format == 'channels_first' else [1, 2]
262 |
263 | # Two stream convs (using split and both are 3x3).
264 | inputs = self.conv2d_fixed_padding(inputs, training=training)
265 | inputs = self.batch_norm_relu(inputs, training=training)
266 | inputs = tf.stack(tf.split(inputs, num_or_size_splits=2, axis=channel_axis))
267 |
268 | # Mixing weights for two streams.
269 | global_features = tf.reduce_mean(
270 | tf.reduce_sum(inputs, axis=0), pooling_axes, keepdims=True)
271 | global_features = self.conv2d_0(global_features, training=training)
272 | global_features = self.batch_norm_relu_1(global_features, training=training)
273 | mixing = self.conv2d_1(global_features, training=training)
274 | mixing = tf.stack(tf.split(mixing, num_or_size_splits=2, axis=channel_axis))
275 | mixing = tf.nn.softmax(mixing, axis=0)
276 |
277 | return tf.reduce_sum(inputs * mixing, axis=0)
278 |
279 |
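# Illustrative sketch (not part of the original file): SK_Conv2D runs one
# 2x-wide conv, splits it into two streams, softmax-mixes them, and returns a
# tensor with `filters` channels and spatial dims determined by `strides`.
# Assumes eager TF2 and that absl FLAGS have been parsed (BatchNormRelu reads
# the batch-norm flags).
def _sk_conv2d_shape_example():
  sk = SK_Conv2D(filters=64, strides=1, sk_ratio=0.0625)
  y = sk(tf.zeros([2, 32, 32, 64]), training=False)
  return y  # shape [2, 32, 32, 64]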
280 | class SE_Layer(tf.keras.layers.Layer): # pylint: disable=invalid-name
281 | """Squeeze and Excitation layer (https://arxiv.org/abs/1709.01507)."""
282 |
283 | def __init__(self, filters, se_ratio, data_format='channels_last', **kwargs):
284 | super(SE_Layer, self).__init__(**kwargs)
285 | self.data_format = data_format
286 | self.se_reduce = tf.keras.layers.Conv2D(
287 | max(1, int(filters * se_ratio)),
288 | kernel_size=[1, 1],
289 | strides=[1, 1],
290 | kernel_initializer=tf.keras.initializers.VarianceScaling(),
291 | padding='same',
292 | data_format=data_format,
293 | use_bias=True)
294 | self.se_expand = tf.keras.layers.Conv2D(
295 | None, # This is filled later in build().
296 | kernel_size=[1, 1],
297 | strides=[1, 1],
298 | kernel_initializer=tf.keras.initializers.VarianceScaling(),
299 | padding='same',
300 | data_format=data_format,
301 | use_bias=True)
302 |
303 | def build(self, input_shape):
304 | self.se_expand.filters = input_shape[-1]
305 | super(SE_Layer, self).build(input_shape)
306 |
307 | def call(self, inputs, training):
308 | spatial_dims = [2, 3] if self.data_format == 'channels_first' else [1, 2]
309 | se_tensor = tf.reduce_mean(inputs, spatial_dims, keepdims=True)
310 | se_tensor = self.se_expand(tf.nn.relu(self.se_reduce(se_tensor)))
311 | return tf.sigmoid(se_tensor) * inputs
312 |
313 |
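# Illustrative sketch (not part of the original file): SE_Layer squeezes the
# input to a global per-channel descriptor, then gates the channels with a
# sigmoid, so the output shape always matches the input shape.
def _se_layer_shape_example():
  se = SE_Layer(filters=64, se_ratio=0.25)  # reduce dim: max(1, 64 * 0.25) = 16
  y = se(tf.zeros([2, 8, 8, 64]), training=False)
  return y  # shape [2, 8, 8, 64]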
314 | class ResidualBlock(tf.keras.layers.Layer): # pylint: disable=missing-docstring
315 |
316 | def __init__(self,
317 | filters,
318 | strides,
319 | use_projection=False,
320 | data_format='channels_last',
321 | dropblock_keep_prob=None,
322 | dropblock_size=None,
323 | **kwargs):
324 | super(ResidualBlock, self).__init__(**kwargs)
325 | del dropblock_keep_prob
326 | del dropblock_size
327 | self.conv2d_bn_layers = []
328 | self.shortcut_layers = []
329 | if use_projection:
330 | if FLAGS.sk_ratio > 0: # Use ResNet-D (https://arxiv.org/abs/1812.01187)
331 | if strides > 1:
332 | self.shortcut_layers.append(FixedPadding(2, data_format))
333 | self.shortcut_layers.append(
334 | tf.keras.layers.AveragePooling2D(
335 | pool_size=2,
336 | strides=strides,
337 | padding='SAME' if strides == 1 else 'VALID',
338 | data_format=data_format))
339 | self.shortcut_layers.append(
340 | Conv2dFixedPadding(
341 | filters=filters,
342 | kernel_size=1,
343 | strides=1,
344 | data_format=data_format))
345 | else:
346 | self.shortcut_layers.append(
347 | Conv2dFixedPadding(
348 | filters=filters,
349 | kernel_size=1,
350 | strides=strides,
351 | data_format=data_format))
352 | self.shortcut_layers.append(
353 | BatchNormRelu(relu=False, data_format=data_format))
354 |
355 | self.conv2d_bn_layers.append(
356 | Conv2dFixedPadding(
357 | filters=filters,
358 | kernel_size=3,
359 | strides=strides,
360 | data_format=data_format))
361 | self.conv2d_bn_layers.append(BatchNormRelu(data_format=data_format))
362 | self.conv2d_bn_layers.append(
363 | Conv2dFixedPadding(
364 | filters=filters, kernel_size=3, strides=1, data_format=data_format))
365 | self.conv2d_bn_layers.append(
366 | BatchNormRelu(relu=False, init_zero=True, data_format=data_format))
367 | if FLAGS.se_ratio > 0:
368 | self.se_layer = SE_Layer(filters, FLAGS.se_ratio, data_format=data_format)
369 |
370 | def call(self, inputs, training):
371 | shortcut = inputs
372 | for layer in self.shortcut_layers:
373 | # Projection shortcut in first layer to match filters and strides
374 | shortcut = layer(shortcut, training=training)
375 |
376 | for layer in self.conv2d_bn_layers:
377 | inputs = layer(inputs, training=training)
378 |
379 | if FLAGS.se_ratio > 0:
380 | inputs = self.se_layer(inputs, training=training)
381 |
382 | return tf.nn.relu(inputs + shortcut)
383 |
384 |
385 | class BottleneckBlock(tf.keras.layers.Layer):
386 | """BottleneckBlock."""
387 |
388 | def __init__(self,
389 | filters,
390 | strides,
391 | use_projection=False,
392 | data_format='channels_last',
393 | dropblock_keep_prob=None,
394 | dropblock_size=None,
395 | **kwargs):
396 | super(BottleneckBlock, self).__init__(**kwargs)
397 | self.projection_layers = []
398 | if use_projection:
399 | filters_out = 4 * filters
400 | if FLAGS.sk_ratio > 0: # Use ResNet-D (https://arxiv.org/abs/1812.01187)
401 | if strides > 1:
402 | self.projection_layers.append(FixedPadding(2, data_format))
403 | self.projection_layers.append(
404 | tf.keras.layers.AveragePooling2D(
405 | pool_size=2,
406 | strides=strides,
407 | padding='SAME' if strides == 1 else 'VALID',
408 | data_format=data_format))
409 | self.projection_layers.append(
410 | Conv2dFixedPadding(
411 | filters=filters_out,
412 | kernel_size=1,
413 | strides=1,
414 | data_format=data_format))
415 | else:
416 | self.projection_layers.append(
417 | Conv2dFixedPadding(
418 | filters=filters_out,
419 | kernel_size=1,
420 | strides=strides,
421 | data_format=data_format))
422 | self.projection_layers.append(
423 | BatchNormRelu(relu=False, data_format=data_format))
424 | self.shortcut_dropblock = DropBlock(
425 | data_format=data_format,
426 | keep_prob=dropblock_keep_prob,
427 | dropblock_size=dropblock_size)
428 |
429 | self.conv_relu_dropblock_layers = []
430 |
431 | self.conv_relu_dropblock_layers.append(
432 | Conv2dFixedPadding(
433 | filters=filters, kernel_size=1, strides=1, data_format=data_format))
434 | self.conv_relu_dropblock_layers.append(
435 | BatchNormRelu(data_format=data_format))
436 | self.conv_relu_dropblock_layers.append(
437 | DropBlock(
438 | data_format=data_format,
439 | keep_prob=dropblock_keep_prob,
440 | dropblock_size=dropblock_size))
441 |
442 | if FLAGS.sk_ratio > 0:
443 | self.conv_relu_dropblock_layers.append(
444 | SK_Conv2D(filters, strides, FLAGS.sk_ratio, data_format=data_format))
445 | else:
446 | self.conv_relu_dropblock_layers.append(
447 | Conv2dFixedPadding(
448 | filters=filters,
449 | kernel_size=3,
450 | strides=strides,
451 | data_format=data_format))
452 | self.conv_relu_dropblock_layers.append(
453 | BatchNormRelu(data_format=data_format))
454 | self.conv_relu_dropblock_layers.append(
455 | DropBlock(
456 | data_format=data_format,
457 | keep_prob=dropblock_keep_prob,
458 | dropblock_size=dropblock_size))
459 |
460 | self.conv_relu_dropblock_layers.append(
461 | Conv2dFixedPadding(
462 | filters=4 * filters,
463 | kernel_size=1,
464 | strides=1,
465 | data_format=data_format))
466 | self.conv_relu_dropblock_layers.append(
467 | BatchNormRelu(relu=False, init_zero=True, data_format=data_format))
468 | self.conv_relu_dropblock_layers.append(
469 | DropBlock(
470 | data_format=data_format,
471 | keep_prob=dropblock_keep_prob,
472 | dropblock_size=dropblock_size))
473 |
474 | if FLAGS.se_ratio > 0:
475 | self.conv_relu_dropblock_layers.append(
476 | SE_Layer(filters, FLAGS.se_ratio, data_format=data_format))
477 |
478 | def call(self, inputs, training):
479 | shortcut = inputs
480 | for layer in self.projection_layers:
481 | shortcut = layer(shortcut, training=training)
482 | shortcut = self.shortcut_dropblock(shortcut, training=training)
483 |
484 | for layer in self.conv_relu_dropblock_layers:
485 | inputs = layer(inputs, training=training)
486 |
487 | return tf.nn.relu(inputs + shortcut)
488 |
489 |
490 | class BlockGroup(tf.keras.layers.Layer): # pylint: disable=missing-docstring
491 |
492 | def __init__(self,
493 | filters,
494 | block_fn,
495 | blocks,
496 | strides,
497 | data_format='channels_last',
498 | dropblock_keep_prob=None,
499 | dropblock_size=None,
500 | **kwargs):
501 | self._name = kwargs.get('name')
502 | super(BlockGroup, self).__init__(**kwargs)
503 |
504 | self.layers = []
505 | self.layers.append(
506 | block_fn(
507 | filters,
508 | strides,
509 | use_projection=True,
510 | data_format=data_format,
511 | dropblock_keep_prob=dropblock_keep_prob,
512 | dropblock_size=dropblock_size))
513 |
514 | for _ in range(1, blocks):
515 | self.layers.append(
516 | block_fn(
517 | filters,
518 | 1,
519 | data_format=data_format,
520 | dropblock_keep_prob=dropblock_keep_prob,
521 | dropblock_size=dropblock_size))
522 |
523 | def call(self, inputs, training):
524 | for layer in self.layers:
525 | inputs = layer(inputs, training=training)
526 | return tf.identity(inputs, self._name)
527 |
528 |
529 | class Resnet(tf.keras.layers.Layer): # pylint: disable=missing-docstring
530 |
531 | def __init__(self,
532 | block_fn,
533 | layers,
534 | width_multiplier,
535 | cifar_stem=False,
536 | data_format='channels_last',
537 | dropblock_keep_probs=None,
538 | dropblock_size=None,
539 | **kwargs):
540 | super(Resnet, self).__init__(**kwargs)
541 | self.data_format = data_format
542 | if dropblock_keep_probs is None:
543 | dropblock_keep_probs = [None] * 4
544 | if not isinstance(dropblock_keep_probs,
545 | list) or len(dropblock_keep_probs) != 4:
546 | raise ValueError('dropblock_keep_probs is not valid:',
547 | dropblock_keep_probs)
548 | trainable = (
549 | FLAGS.train_mode != 'finetune' or FLAGS.fine_tune_after_block == -1)
550 | self.initial_conv_relu_max_pool = []
551 | if cifar_stem:
552 | self.initial_conv_relu_max_pool.append(
553 | Conv2dFixedPadding(
554 | filters=64 * width_multiplier,
555 | kernel_size=3,
556 | strides=1,
557 | data_format=data_format,
558 | trainable=trainable))
559 | self.initial_conv_relu_max_pool.append(
560 | IdentityLayer(name='initial_conv', trainable=trainable))
561 | self.initial_conv_relu_max_pool.append(
562 | BatchNormRelu(data_format=data_format, trainable=trainable))
563 | self.initial_conv_relu_max_pool.append(
564 | IdentityLayer(name='initial_max_pool', trainable=trainable))
565 | else:
566 | if FLAGS.sk_ratio > 0: # Use ResNet-D (https://arxiv.org/abs/1812.01187)
567 | self.initial_conv_relu_max_pool.append(
568 | Conv2dFixedPadding(
569 | filters=64 * width_multiplier // 2,
570 | kernel_size=3,
571 | strides=2,
572 | data_format=data_format,
573 | trainable=trainable))
574 | self.initial_conv_relu_max_pool.append(
575 | BatchNormRelu(data_format=data_format, trainable=trainable))
576 | self.initial_conv_relu_max_pool.append(
577 | Conv2dFixedPadding(
578 | filters=64 * width_multiplier // 2,
579 | kernel_size=3,
580 | strides=1,
581 | data_format=data_format,
582 | trainable=trainable))
583 | self.initial_conv_relu_max_pool.append(
584 | BatchNormRelu(data_format=data_format, trainable=trainable))
585 | self.initial_conv_relu_max_pool.append(
586 | Conv2dFixedPadding(
587 | filters=64 * width_multiplier,
588 | kernel_size=3,
589 | strides=1,
590 | data_format=data_format,
591 | trainable=trainable))
592 | else:
593 | self.initial_conv_relu_max_pool.append(
594 | Conv2dFixedPadding(
595 | filters=64 * width_multiplier,
596 | kernel_size=7,
597 | strides=2,
598 | data_format=data_format,
599 | trainable=trainable))
600 | self.initial_conv_relu_max_pool.append(
601 | IdentityLayer(name='initial_conv', trainable=trainable))
602 | self.initial_conv_relu_max_pool.append(
603 | BatchNormRelu(data_format=data_format, trainable=trainable))
604 |
605 | self.initial_conv_relu_max_pool.append(
606 | tf.keras.layers.MaxPooling2D(
607 | pool_size=3,
608 | strides=2,
609 | padding='SAME',
610 | data_format=data_format,
611 | trainable=trainable))
612 | self.initial_conv_relu_max_pool.append(
613 | IdentityLayer(name='initial_max_pool', trainable=trainable))
614 |
615 | self.block_groups = []
616 | # TODO(srbs): This impl is different from the original one in the case where
617 | # fine_tune_after_block != 4. In that case earlier BN stats were getting
618 | # updated. Now they will not be. Check with Ting to make sure this is ok.
619 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 0:
620 | trainable = True
621 |
622 | self.block_groups.append(
623 | BlockGroup(
624 | filters=64 * width_multiplier,
625 | block_fn=block_fn,
626 | blocks=layers[0],
627 | strides=1,
628 | name='block_group1',
629 | data_format=data_format,
630 | dropblock_keep_prob=dropblock_keep_probs[0],
631 | dropblock_size=dropblock_size,
632 | trainable=trainable))
633 |
634 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 1:
635 | trainable = True
636 |
637 | self.block_groups.append(
638 | BlockGroup(
639 | filters=128 * width_multiplier,
640 | block_fn=block_fn,
641 | blocks=layers[1],
642 | strides=2,
643 | name='block_group2',
644 | data_format=data_format,
645 | dropblock_keep_prob=dropblock_keep_probs[1],
646 | dropblock_size=dropblock_size,
647 | trainable=trainable))
648 |
649 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 2:
650 | trainable = True
651 |
652 | self.block_groups.append(
653 | BlockGroup(
654 | filters=256 * width_multiplier,
655 | block_fn=block_fn,
656 | blocks=layers[2],
657 | strides=2,
658 | name='block_group3',
659 | data_format=data_format,
660 | dropblock_keep_prob=dropblock_keep_probs[2],
661 | dropblock_size=dropblock_size,
662 | trainable=trainable))
663 |
664 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 3:
665 | trainable = True
666 |
667 | self.block_groups.append(
668 | BlockGroup(
669 | filters=512 * width_multiplier,
670 | block_fn=block_fn,
671 | blocks=layers[3],
672 | strides=2,
673 | name='block_group4',
674 | data_format=data_format,
675 | dropblock_keep_prob=dropblock_keep_probs[3],
676 | dropblock_size=dropblock_size,
677 | trainable=trainable))
678 |
679 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 4:
680 | # This case doesn't really matter.
681 | trainable = True
682 |
683 | def call(self, inputs, training):
684 | for layer in self.initial_conv_relu_max_pool:
685 | inputs = layer(inputs, training=training)
686 |
687 | for i, layer in enumerate(self.block_groups):
688 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == i:
689 | inputs = tf.stop_gradient(inputs)
690 | inputs = layer(inputs, training=training)
691 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 4:
692 | inputs = tf.stop_gradient(inputs)
693 | if self.data_format == 'channels_last':
694 | inputs = tf.reduce_mean(inputs, [1, 2])
695 | else:
696 | inputs = tf.reduce_mean(inputs, [2, 3])
697 |
698 | inputs = tf.identity(inputs, 'final_avg_pool')
699 | return inputs
700 |
701 |
702 | def resnet(resnet_depth,
703 | width_multiplier,
704 | cifar_stem=False,
705 | data_format='channels_last',
706 | dropblock_keep_probs=None,
707 | dropblock_size=None):
708 | """Returns the ResNet model for a given size and number of output classes."""
709 | model_params = {
710 | 18: {
711 | 'block': ResidualBlock,
712 | 'layers': [2, 2, 2, 2]
713 | },
714 | 34: {
715 | 'block': ResidualBlock,
716 | 'layers': [3, 4, 6, 3]
717 | },
718 | 50: {
719 | 'block': BottleneckBlock,
720 | 'layers': [3, 4, 6, 3]
721 | },
722 | 101: {
723 | 'block': BottleneckBlock,
724 | 'layers': [3, 4, 23, 3]
725 | },
726 | 152: {
727 | 'block': BottleneckBlock,
728 | 'layers': [3, 8, 36, 3]
729 | },
730 | 200: {
731 | 'block': BottleneckBlock,
732 | 'layers': [3, 24, 36, 3]
733 | }
734 | }
735 |
736 | if resnet_depth not in model_params:
737 | raise ValueError('Not a valid resnet_depth:', resnet_depth)
738 |
739 | params = model_params[resnet_depth]
740 | return Resnet(
741 | params['block'],
742 | params['layers'],
743 | width_multiplier,
744 | cifar_stem=cifar_stem,
745 | dropblock_keep_probs=dropblock_keep_probs,
746 | dropblock_size=dropblock_size,
747 | data_format=data_format)
748 |
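# Illustrative note (not part of the original file): the depth naming follows
# the usual ResNet convention. Each BottleneckBlock contributes 3 convolutions,
# plus the stem conv and the final classifier layer, so for ResNet-50:
# 3 * (3 + 4 + 6 + 3) + 2 = 50. (ResidualBlock depths use 2 convs per block:
# 2 * (2 + 2 + 2 + 2) + 2 = 18.)
def _bottleneck_depth_example(layers=(3, 4, 6, 3)):
  return 3 * sum(layers) + 2  # -> 50 for the ResNet-50 configuration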
--------------------------------------------------------------------------------
/tf2/run.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """The main training pipeline."""
17 |
18 | import json
19 | import math
20 | import os
21 |
22 | from absl import app
23 | from absl import flags
24 | from absl import logging
25 | import data as data_lib
26 | import metrics
27 | import model as model_lib
28 | import objective as obj_lib
29 | import tensorflow.compat.v2 as tf
30 | import tensorflow_datasets as tfds
31 |
32 |
33 |
34 | FLAGS = flags.FLAGS
35 |
36 |
37 | flags.DEFINE_float(
38 | 'learning_rate', 0.3,
39 | 'Initial learning rate per batch size of 256.')
40 |
41 | flags.DEFINE_enum(
42 | 'learning_rate_scaling', 'linear', ['linear', 'sqrt'],
43 | 'How to scale the learning rate as a function of batch size.')
44 |
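# Illustrative sketch (mirrors the scaling logic implemented by
# WarmUpAndCosineDecay in model.py; treat the helper name as hypothetical):
# 'linear' scales the base rate by batch_size / 256, 'sqrt' by sqrt(batch_size).
#   _scaled_lr_example(0.3, 512, 'linear') -> 0.6
#   _scaled_lr_example(0.3, 512, 'sqrt')   -> ~6.79
def _scaled_lr_example(base_lr, batch_size, scaling):
  if scaling == 'linear':
    return base_lr * batch_size / 256.0
  return base_lr * math.sqrt(batch_size)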
45 | flags.DEFINE_float(
46 | 'warmup_epochs', 10,
47 | 'Number of epochs of warmup.')
48 |
49 | flags.DEFINE_float('weight_decay', 1e-6, 'Amount of weight decay to use.')
50 |
51 | flags.DEFINE_float(
52 | 'batch_norm_decay', 0.9,
53 | 'Batch norm decay parameter.')
54 |
55 | flags.DEFINE_integer(
56 | 'train_batch_size', 512,
57 | 'Batch size for training.')
58 |
59 | flags.DEFINE_string(
60 | 'train_split', 'train',
61 | 'Split for training.')
62 |
63 | flags.DEFINE_integer(
64 | 'train_epochs', 100,
65 | 'Number of epochs to train for.')
66 |
67 | flags.DEFINE_integer(
68 | 'train_steps', 0,
69 | 'Number of steps to train for. If provided, overrides train_epochs.')
70 |
71 | flags.DEFINE_integer(
72 | 'eval_steps', 0,
73 | 'Number of steps to eval for. If not provided, evals over entire dataset.')
74 |
75 | flags.DEFINE_integer(
76 | 'eval_batch_size', 256,
77 | 'Batch size for eval.')
78 |
79 | flags.DEFINE_integer(
80 | 'checkpoint_epochs', 1,
81 | 'Number of epochs between checkpoints/summaries.')
82 |
83 | flags.DEFINE_integer(
84 | 'checkpoint_steps', 0,
85 | 'Number of steps between checkpoints/summaries. If provided, overrides '
86 | 'checkpoint_epochs.')
87 |
88 | flags.DEFINE_string(
89 | 'eval_split', 'validation',
90 | 'Split for evaluation.')
91 |
92 | flags.DEFINE_string(
93 | 'dataset', 'imagenet2012',
94 | 'Name of a dataset.')
95 |
96 | flags.DEFINE_bool(
97 | 'cache_dataset', False,
98 | 'Whether to cache the entire dataset in memory. If the dataset is '
99 | 'ImageNet, this is a very bad idea, but for smaller datasets it can '
100 | 'improve performance.')
101 |
102 | flags.DEFINE_enum(
103 | 'mode', 'train', ['train', 'eval', 'train_then_eval'],
104 | 'Whether to perform training or evaluation.')
105 |
106 | flags.DEFINE_enum(
107 | 'train_mode', 'pretrain', ['pretrain', 'finetune'],
108 | 'The train mode controls different objectives and trainable components.')
109 |
110 | flags.DEFINE_bool('lineareval_while_pretraining', True,
111 | 'Whether to finetune supervised head while pretraining.')
112 |
113 | flags.DEFINE_string(
114 | 'checkpoint', None,
115 |     'Checkpoint to load for fine-tuning if a fine-tuning checkpoint '
116 |     'does not already exist in model_dir.')
117 |
118 | flags.DEFINE_bool(
119 | 'zero_init_logits_layer', False,
120 | 'If True, zero initialize layers after avg_pool for supervised learning.')
121 |
122 | flags.DEFINE_integer(
123 | 'fine_tune_after_block', -1,
124 |     'The block after which to start fine-tuning. -1 means fine-tuning '
125 |     'everything. 0 means fine-tuning after the stem block. 4 means '
126 |     'fine-tuning only the linear head.')
127 |
128 | flags.DEFINE_string(
129 | 'master', None,
130 | 'Address/name of the TensorFlow master to use. By default, use an '
131 | 'in-process master.')
132 |
133 | flags.DEFINE_string(
134 | 'model_dir', None,
135 | 'Model directory for training.')
136 |
137 | flags.DEFINE_string(
138 | 'data_dir', None,
139 | 'Directory where dataset is stored.')
140 |
141 | flags.DEFINE_bool(
142 | 'use_tpu', True,
143 | 'Whether to run on TPU.')
144 |
145 | flags.DEFINE_string(
146 | 'tpu_name', None,
147 | 'The Cloud TPU to use for training. This should be either the name '
148 | 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
149 | 'url.')
150 |
151 | flags.DEFINE_string(
152 | 'tpu_zone', None,
153 |     '[Optional] GCE zone where the Cloud TPU is located. If not '
154 |     'specified, we will attempt to automatically detect it from '
155 |     'metadata.')
156 |
157 | flags.DEFINE_string(
158 | 'gcp_project', None,
159 | '[Optional] Project name for the Cloud TPU-enabled project. If not '
160 | 'specified, we will attempt to automatically detect the GCE project from '
161 | 'metadata.')
162 |
163 | flags.DEFINE_enum(
164 | 'optimizer', 'lars', ['momentum', 'adam', 'lars'],
165 | 'Optimizer to use.')
166 |
167 | flags.DEFINE_float(
168 | 'momentum', 0.9,
169 | 'Momentum parameter.')
170 |
171 | flags.DEFINE_string(
172 | 'eval_name', None,
173 | 'Name for eval.')
174 |
175 | flags.DEFINE_integer(
176 | 'keep_checkpoint_max', 5,
177 | 'Maximum number of checkpoints to keep.')
178 |
179 | flags.DEFINE_integer(
180 | 'keep_hub_module_max', 1,
181 | 'Maximum number of Hub modules to keep.')
182 |
183 | flags.DEFINE_float(
184 | 'temperature', 0.1,
185 | 'Temperature parameter for contrastive loss.')
186 |
187 | flags.DEFINE_boolean(
188 | 'hidden_norm', True,
189 |     'Whether to L2-normalize hidden vectors before the contrastive loss.')
190 |
191 | flags.DEFINE_enum(
192 | 'proj_head_mode', 'nonlinear', ['none', 'linear', 'nonlinear'],
193 | 'How the head projection is done.')
194 |
195 | flags.DEFINE_integer(
196 | 'proj_out_dim', 128,
197 |     'Output dimension of the projection head.')
198 |
199 | flags.DEFINE_integer(
200 | 'num_proj_layers', 3,
201 | 'Number of non-linear head layers.')
202 |
203 | flags.DEFINE_integer(
204 | 'ft_proj_selector', 0,
205 | 'Which layer of the projection head to use during fine-tuning. '
206 | '0 means no projection head, and -1 means the final layer.')
207 |
208 | flags.DEFINE_boolean(
209 | 'global_bn', True,
210 | 'Whether to aggregate BN statistics across distributed cores.')
211 |
212 | flags.DEFINE_integer(
213 | 'width_multiplier', 1,
214 | 'Multiplier to change width of network.')
215 |
216 | flags.DEFINE_integer(
217 | 'resnet_depth', 50,
218 | 'Depth of ResNet.')
219 |
220 | flags.DEFINE_float(
221 | 'sk_ratio', 0.,
222 |     'If greater than 0, enables selective kernels (SK). Recommended: 0.0625.')
223 |
224 | flags.DEFINE_float(
225 | 'se_ratio', 0.,
226 |     'If greater than 0, enables squeeze-and-excitation (SE).')
227 |
228 | flags.DEFINE_integer(
229 | 'image_size', 224,
230 | 'Input image size.')
231 |
232 | flags.DEFINE_float(
233 | 'color_jitter_strength', 1.0,
234 | 'The strength of color jittering.')
235 |
236 | flags.DEFINE_boolean(
237 | 'use_blur', True,
238 | 'Whether or not to use Gaussian blur for augmentation during pretraining.')
239 |
240 |
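# Illustrative usage sketch: a hypothetical invocation using the flags defined
# above (the dataset and model_dir values are placeholders, not taken from the
# original docs):
#
#   python run.py --mode=train_then_eval --train_mode=pretrain \
#     --dataset=imagenet2012 --train_batch_size=512 --train_epochs=100 \
#     --use_tpu=False --model_dir=/tmp/simclr_model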
241 | def get_salient_tensors_dict(include_projection_head):
242 | """Returns a dictionary of tensors."""
243 | graph = tf.compat.v1.get_default_graph()
244 | result = {}
245 | for i in range(1, 5):
246 | result['block_group%d' % i] = graph.get_tensor_by_name(
247 | 'resnet/block_group%d/block_group%d:0' % (i, i))
248 | result['initial_conv'] = graph.get_tensor_by_name(
249 | 'resnet/initial_conv/Identity:0')
250 | result['initial_max_pool'] = graph.get_tensor_by_name(
251 | 'resnet/initial_max_pool/Identity:0')
252 | result['final_avg_pool'] = graph.get_tensor_by_name('resnet/final_avg_pool:0')
253 | result['logits_sup'] = graph.get_tensor_by_name(
254 | 'head_supervised/logits_sup:0')
255 | if include_projection_head:
256 | result['proj_head_input'] = graph.get_tensor_by_name(
257 | 'projection_head/proj_head_input:0')
258 | result['proj_head_output'] = graph.get_tensor_by_name(
259 | 'projection_head/proj_head_output:0')
260 | return result
261 |
262 |
263 | def build_saved_model(model, include_projection_head=True):
264 | """Returns a tf.Module for saving to SavedModel."""
265 |
266 | class SimCLRModel(tf.Module):
267 | """Saved model for exporting to hub."""
268 |
269 | def __init__(self, model):
270 | self.model = model
271 | # This can't be called `trainable_variables` because `tf.Module` has
272 | # a getter with the same name.
273 | self.trainable_variables_list = model.trainable_variables
274 |
275 | @tf.function
276 | def __call__(self, inputs, trainable):
277 | self.model(inputs, training=trainable)
278 | return get_salient_tensors_dict(include_projection_head)
279 |
280 | module = SimCLRModel(model)
281 | input_spec = tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)
282 | module.__call__.get_concrete_function(input_spec, trainable=True)
283 | module.__call__.get_concrete_function(input_spec, trainable=False)
284 | return module
285 |
286 |
287 | def save(model, global_step):
288 | """Export as SavedModel for finetuning and inference."""
289 | saved_model = build_saved_model(model)
290 | export_dir = os.path.join(FLAGS.model_dir, 'saved_model')
291 | checkpoint_export_dir = os.path.join(export_dir, str(global_step))
292 | if tf.io.gfile.exists(checkpoint_export_dir):
293 | tf.io.gfile.rmtree(checkpoint_export_dir)
294 | tf.saved_model.save(saved_model, checkpoint_export_dir)
295 |
296 | if FLAGS.keep_hub_module_max > 0:
297 | # Delete old exported SavedModels.
298 | exported_steps = []
299 | for subdir in tf.io.gfile.listdir(export_dir):
300 | if not subdir.isdigit():
301 | continue
302 | exported_steps.append(int(subdir))
303 | exported_steps.sort()
304 | for step_to_delete in exported_steps[:-FLAGS.keep_hub_module_max]:
305 | tf.io.gfile.rmtree(os.path.join(export_dir, str(step_to_delete)))
306 |
307 |
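# Illustrative sketch (hypothetical path, helper not part of the original
# file): loading an exported SavedModel back for inference. The call signature
# matches the two concrete functions registered in build_saved_model above.
def _load_saved_model_example(export_path='/tmp/simclr_model/saved_model/1000'):
  loaded = tf.saved_model.load(export_path)
  outputs = loaded(tf.zeros([1, 224, 224, 3]), trainable=False)
  return outputs['final_avg_pool']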
308 | def try_restore_from_checkpoint(model, global_step, optimizer):
309 |   """Restores the latest ckpt if it exists, otherwise checks FLAGS.checkpoint."""
310 | checkpoint = tf.train.Checkpoint(
311 | model=model, global_step=global_step, optimizer=optimizer)
312 | checkpoint_manager = tf.train.CheckpointManager(
313 | checkpoint,
314 | directory=FLAGS.model_dir,
315 | max_to_keep=FLAGS.keep_checkpoint_max)
316 | latest_ckpt = checkpoint_manager.latest_checkpoint
317 | if latest_ckpt:
318 | # Restore model weights, global step, optimizer states
319 | logging.info('Restoring from latest checkpoint: %s', latest_ckpt)
320 | checkpoint_manager.checkpoint.restore(latest_ckpt).expect_partial()
321 | elif FLAGS.checkpoint:
322 | # Restore model weights only, but not global step and optimizer states
323 | logging.info('Restoring from given checkpoint: %s', FLAGS.checkpoint)
324 | checkpoint_manager2 = tf.train.CheckpointManager(
325 | tf.train.Checkpoint(model=model),
326 | directory=FLAGS.model_dir,
327 | max_to_keep=FLAGS.keep_checkpoint_max)
328 | checkpoint_manager2.checkpoint.restore(FLAGS.checkpoint).expect_partial()
329 | if FLAGS.zero_init_logits_layer:
330 | model = checkpoint_manager2.checkpoint.model
331 | output_layer_parameters = model.supervised_head.trainable_weights
332 | logging.info('Initializing output layer parameters %s to zero',
333 | [x.op.name for x in output_layer_parameters])
334 | for x in output_layer_parameters:
335 | x.assign(tf.zeros_like(x))
336 |
337 | return checkpoint_manager
338 |
339 |
340 | def json_serializable(val):
341 | try:
342 | json.dumps(val)
343 | return True
344 | except TypeError:
345 | return False
346 |
347 |
348 | def perform_evaluation(model, builder, eval_steps, ckpt, strategy, topology):
349 | """Perform evaluation."""
350 | if FLAGS.train_mode == 'pretrain' and not FLAGS.lineareval_while_pretraining:
351 | logging.info('Skipping eval during pretraining without linear eval.')
352 | return
353 | # Build input pipeline.
354 | ds = data_lib.build_distributed_dataset(builder, FLAGS.eval_batch_size, False,
355 | strategy, topology)
356 | summary_writer = tf.summary.create_file_writer(FLAGS.model_dir)
357 |
358 | # Build metrics.
359 | with strategy.scope():
360 | regularization_loss = tf.keras.metrics.Mean('eval/regularization_loss')
361 | label_top_1_accuracy = tf.keras.metrics.Accuracy(
362 | 'eval/label_top_1_accuracy')
363 | label_top_5_accuracy = tf.keras.metrics.TopKCategoricalAccuracy(
364 | 5, 'eval/label_top_5_accuracy')
365 | all_metrics = [
366 | regularization_loss, label_top_1_accuracy, label_top_5_accuracy
367 | ]
368 |
369 | # Restore checkpoint.
370 | logging.info('Restoring from %s', ckpt)
371 | checkpoint = tf.train.Checkpoint(
372 | model=model, global_step=tf.Variable(0, dtype=tf.int64))
373 | checkpoint.restore(ckpt).expect_partial()
374 | global_step = checkpoint.global_step
375 | logging.info('Performing eval at step %d', global_step.numpy())
376 |
377 | def single_step(features, labels):
378 | _, supervised_head_outputs = model(features, training=False)
379 | assert supervised_head_outputs is not None
380 | outputs = supervised_head_outputs
381 | l = labels['labels']
382 | metrics.update_finetune_metrics_eval(label_top_1_accuracy,
383 | label_top_5_accuracy, outputs, l)
384 | reg_loss = model_lib.add_weight_decay(model, adjust_per_optimizer=True)
385 | regularization_loss.update_state(reg_loss)
386 |
387 | with strategy.scope():
388 |
389 | @tf.function
390 | def run_single_step(iterator):
391 | images, labels = next(iterator)
392 | features, labels = images, {'labels': labels}
393 | strategy.run(single_step, (features, labels))
394 |
395 | iterator = iter(ds)
396 | for i in range(eval_steps):
397 | run_single_step(iterator)
398 | logging.info('Completed eval for %d / %d steps', i + 1, eval_steps)
399 | logging.info('Finished eval for %s', ckpt)
400 |
401 | # Write summaries
402 | cur_step = global_step.numpy()
403 |   logging.info('Writing summaries for step %d', cur_step)
404 | with summary_writer.as_default():
405 | metrics.log_and_write_metrics_to_summary(all_metrics, cur_step)
406 | summary_writer.flush()
407 |
408 | # Record results as JSON.
409 | result_json_path = os.path.join(FLAGS.model_dir, 'result.json')
410 | result = {metric.name: metric.result().numpy() for metric in all_metrics}
411 | result['global_step'] = global_step.numpy()
412 | logging.info(result)
413 | with tf.io.gfile.GFile(result_json_path, 'w') as f:
414 | json.dump({k: float(v) for k, v in result.items()}, f)
415 | result_json_path = os.path.join(
416 | FLAGS.model_dir, 'result_%d.json'%result['global_step'])
417 | with tf.io.gfile.GFile(result_json_path, 'w') as f:
418 | json.dump({k: float(v) for k, v in result.items()}, f)
419 | flag_json_path = os.path.join(FLAGS.model_dir, 'flags.json')
420 | with tf.io.gfile.GFile(flag_json_path, 'w') as f:
421 | serializable_flags = {}
422 | for key, val in FLAGS.flag_values_dict().items():
423 | # Some flag value types e.g. datetime.timedelta are not json serializable,
424 | # filter those out.
425 | if json_serializable(val):
426 | serializable_flags[key] = val
427 | json.dump(serializable_flags, f)
428 |
429 | # Export as SavedModel for finetuning and inference.
430 | save(model, global_step=result['global_step'])
431 |
432 | return result
433 |
434 |
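# Illustrative sketch (helper not part of the original file): when
# --eval_steps=0, main() below derives the step count by ceil-dividing the
# eval split size by the eval batch size, e.g. 50000 ImageNet validation
# images with a batch size of 256 -> 196 steps.
def _default_eval_steps_example(num_eval_examples=50000, eval_batch_size=256):
  return int(math.ceil(num_eval_examples / eval_batch_size))  # -> 196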
435 | def _restore_latest_or_from_pretrain(checkpoint_manager):
436 |   """Restores the latest checkpoint if training has already started.
437 | 
438 |   Otherwise restores from FLAGS.checkpoint when in finetune mode.
439 | 
440 |   Args:
441 |     checkpoint_manager: tf.train.CheckpointManager.
442 |   """
443 | latest_ckpt = checkpoint_manager.latest_checkpoint
444 | if latest_ckpt:
445 |     # The model is not built yet so some variables may not be available in
446 | # the object graph. Those are lazily initialized. To suppress the warning
447 | # in that case we specify `expect_partial`.
448 | logging.info('Restoring from %s', latest_ckpt)
449 | checkpoint_manager.checkpoint.restore(latest_ckpt).expect_partial()
450 | elif FLAGS.train_mode == 'finetune':
451 | # Restore from pretrain checkpoint.
452 | assert FLAGS.checkpoint, 'Missing pretrain checkpoint.'
453 | logging.info('Restoring from %s', FLAGS.checkpoint)
454 | checkpoint_manager.checkpoint.restore(FLAGS.checkpoint).expect_partial()
455 | # TODO(iamtingchen): Can we instead use a zeros initializer for the
456 | # supervised head?
457 | if FLAGS.zero_init_logits_layer:
458 | model = checkpoint_manager.checkpoint.model
459 | output_layer_parameters = model.supervised_head.trainable_weights
460 | logging.info('Initializing output layer parameters %s to zero',
461 | [x.op.name for x in output_layer_parameters])
462 | for x in output_layer_parameters:
463 | x.assign(tf.zeros_like(x))
464 |
465 |
466 | def main(argv):
467 | if len(argv) > 1:
468 | raise app.UsageError('Too many command-line arguments.')
469 |
470 |
471 | builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
472 | builder.download_and_prepare()
473 | num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
474 | num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
475 | num_classes = builder.info.features['label'].num_classes
476 |
477 | train_steps = model_lib.get_train_steps(num_train_examples)
478 | eval_steps = FLAGS.eval_steps or int(
479 | math.ceil(num_eval_examples / FLAGS.eval_batch_size))
480 | epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))
481 |
482 | logging.info('# train examples: %d', num_train_examples)
483 | logging.info('# train_steps: %d', train_steps)
484 | logging.info('# eval examples: %d', num_eval_examples)
485 | logging.info('# eval steps: %d', eval_steps)
486 |
487 | checkpoint_steps = (
488 | FLAGS.checkpoint_steps or (FLAGS.checkpoint_epochs * epoch_steps))
489 |
490 | topology = None
491 | if FLAGS.use_tpu:
492 | if FLAGS.tpu_name:
493 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
494 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
495 | else:
496 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.master)
497 | tf.config.experimental_connect_to_cluster(cluster)
498 | topology = tf.tpu.experimental.initialize_tpu_system(cluster)
499 | logging.info('Topology:')
500 | logging.info('num_tasks: %d', topology.num_tasks)
501 | logging.info('num_tpus_per_task: %d', topology.num_tpus_per_task)
502 | strategy = tf.distribute.TPUStrategy(cluster)
503 |
504 | else:
505 | # For (multiple) GPUs.
506 | strategy = tf.distribute.MirroredStrategy()
507 | logging.info('Running using MirroredStrategy on %d replicas',
508 | strategy.num_replicas_in_sync)
509 |
510 | with strategy.scope():
511 | model = model_lib.Model(num_classes)
512 |
513 | if FLAGS.mode == 'eval':
514 | for ckpt in tf.train.checkpoints_iterator(
515 | FLAGS.model_dir, min_interval_secs=15):
516 | result = perform_evaluation(model, builder, eval_steps, ckpt, strategy,
517 | topology)
518 | if result['global_step'] >= train_steps:
519 | logging.info('Eval complete. Exiting...')
520 | return
521 | else:
522 | summary_writer = tf.summary.create_file_writer(FLAGS.model_dir)
523 | with strategy.scope():
524 | # Build input pipeline.
525 | ds = data_lib.build_distributed_dataset(builder, FLAGS.train_batch_size,
526 | True, strategy, topology)
527 |
528 | # Build LR schedule and optimizer.
529 | learning_rate = model_lib.WarmUpAndCosineDecay(FLAGS.learning_rate,
530 | num_train_examples)
531 | optimizer = model_lib.build_optimizer(learning_rate)
532 |
533 | # Build metrics.
534 | all_metrics = [] # For summaries.
535 | weight_decay_metric = tf.keras.metrics.Mean('train/weight_decay')
536 | total_loss_metric = tf.keras.metrics.Mean('train/total_loss')
537 | all_metrics.extend([weight_decay_metric, total_loss_metric])
538 | if FLAGS.train_mode == 'pretrain':
539 | contrast_loss_metric = tf.keras.metrics.Mean('train/contrast_loss')
540 | contrast_acc_metric = tf.keras.metrics.Mean('train/contrast_acc')
541 | contrast_entropy_metric = tf.keras.metrics.Mean(
542 | 'train/contrast_entropy')
543 | all_metrics.extend([
544 | contrast_loss_metric, contrast_acc_metric, contrast_entropy_metric
545 | ])
546 | if FLAGS.train_mode == 'finetune' or FLAGS.lineareval_while_pretraining:
547 | supervised_loss_metric = tf.keras.metrics.Mean('train/supervised_loss')
548 | supervised_acc_metric = tf.keras.metrics.Mean('train/supervised_acc')
549 | all_metrics.extend([supervised_loss_metric, supervised_acc_metric])
550 |
551 | # Restore checkpoint if available.
552 | checkpoint_manager = try_restore_from_checkpoint(
553 | model, optimizer.iterations, optimizer)
554 |
555 | steps_per_loop = checkpoint_steps
556 |
557 | def single_step(features, labels):
558 | with tf.GradientTape() as tape:
559 | # Log summaries on the last step of the training loop to match
560 | # logging frequency of other scalar summaries.
561 | #
562 | # Notes:
563 | # 1. Summary ops on TPUs get outside compiled so they do not affect
564 | # performance.
565 | # 2. Summaries are recorded only on replica 0. So effectively this
566 | # summary would be written once per host when should_record == True.
567 | # 3. optimizer.iterations is incremented in the call to apply_gradients.
568 | # So we use `iterations + 1` here so that the step number matches
569 | # those of scalar summaries.
570 | # 4. We intentionally run the summary op before the actual model
571 | # training so that it can run in parallel.
572 | should_record = tf.equal((optimizer.iterations + 1) % steps_per_loop, 0)
573 | with tf.summary.record_if(should_record):
574 | # Only log augmented images for the first tower.
575 | tf.summary.image(
576 | 'image', features[:, :, :, :3], step=optimizer.iterations + 1)
577 | projection_head_outputs, supervised_head_outputs = model(
578 | features, training=True)
579 | loss = None
580 | if projection_head_outputs is not None:
581 | outputs = projection_head_outputs
582 | con_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
583 | outputs,
584 | hidden_norm=FLAGS.hidden_norm,
585 | temperature=FLAGS.temperature,
586 | strategy=strategy)
587 | if loss is None:
588 | loss = con_loss
589 | else:
590 | loss += con_loss
591 | metrics.update_pretrain_metrics_train(contrast_loss_metric,
592 | contrast_acc_metric,
593 | contrast_entropy_metric,
594 | con_loss, logits_con,
595 | labels_con)
596 | if supervised_head_outputs is not None:
597 | outputs = supervised_head_outputs
598 | l = labels['labels']
599 | if FLAGS.train_mode == 'pretrain' and FLAGS.lineareval_while_pretraining:
600 | l = tf.concat([l, l], 0)
601 | sup_loss = obj_lib.add_supervised_loss(labels=l, logits=outputs)
602 | if loss is None:
603 | loss = sup_loss
604 | else:
605 | loss += sup_loss
606 | metrics.update_finetune_metrics_train(supervised_loss_metric,
607 | supervised_acc_metric, sup_loss,
608 | l, outputs)
609 | weight_decay = model_lib.add_weight_decay(
610 | model, adjust_per_optimizer=True)
611 | weight_decay_metric.update_state(weight_decay)
612 | loss += weight_decay
613 | total_loss_metric.update_state(loss)
614 | # The default behavior of `apply_gradients` is to sum gradients from all
615 | # replicas so we divide the loss by the number of replicas so that the
616 | # mean gradient is applied.
617 | loss = loss / strategy.num_replicas_in_sync
618 | logging.info('Trainable variables:')
619 | for var in model.trainable_variables:
620 | logging.info(var.name)
621 | grads = tape.gradient(loss, model.trainable_variables)
622 | optimizer.apply_gradients(zip(grads, model.trainable_variables))
623 |
624 | with strategy.scope():
625 |
626 | @tf.function
627 | def train_multiple_steps(iterator):
628 | # `tf.range` is needed so that this runs in a `tf.while_loop` and is
629 | # not unrolled.
630 | for _ in tf.range(steps_per_loop):
631 | # Drop the "while" prefix created by tf.while_loop which otherwise
632 | # gets prefixed to every variable name. This does not affect training
633 | # but does affect the checkpoint conversion script.
634 | # TODO(b/161712658): Remove this.
635 | with tf.name_scope(''):
636 | images, labels = next(iterator)
637 | features, labels = images, {'labels': labels}
638 | strategy.run(single_step, (features, labels))
639 |
640 | global_step = optimizer.iterations
641 | cur_step = global_step.numpy()
642 | iterator = iter(ds)
643 | while cur_step < train_steps:
644 |       # Calls to tf.summary.xyz look up the summary writer resource which is
645 | # set by the summary writer's context manager.
646 | with summary_writer.as_default():
647 | train_multiple_steps(iterator)
648 | cur_step = global_step.numpy()
649 | checkpoint_manager.save(cur_step)
650 | logging.info('Completed: %d / %d steps', cur_step, train_steps)
651 | metrics.log_and_write_metrics_to_summary(all_metrics, cur_step)
652 | tf.summary.scalar(
653 | 'learning_rate',
654 | learning_rate(tf.cast(global_step, dtype=tf.float32)),
655 | global_step)
656 | summary_writer.flush()
657 | for metric in all_metrics:
658 | metric.reset_states()
659 | logging.info('Training complete...')
660 |
661 | if FLAGS.mode == 'train_then_eval':
662 | perform_evaluation(model, builder, eval_steps,
663 | checkpoint_manager.latest_checkpoint, strategy,
664 | topology)
665 |
666 |
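# Illustrative sketch (helper not part of the original file): summaries inside
# single_step fire only on the last step of each `steps_per_loop` window,
# matching the `should_record` check above.
def _should_record_example(iterations, steps_per_loop):
  return (iterations + 1) % steps_per_loop == 0  # True for 99 with a loop of 100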
667 | if __name__ == '__main__':
668 | tf.compat.v1.enable_v2_behavior()
669 | # For outside compilation of summaries on TPU.
670 | tf.config.set_soft_device_placement(True)
671 | app.run(main)
672 |
--------------------------------------------------------------------------------