├── README.md ├── LICENSE └── CIFAR_10C_Evaluation.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Consistency-Training-with-Supervision 2 | Contains experimentation notebooks for my Keras Example [Consistency Training with Supervision](https://keras.io/examples/vision/consistency_training/). This example also provides a template for performing semi-supervised / weakly supervised learning. 3 | 4 | Promising results on [CIFAR-10-C](https://github.com/hendrycks/robustness) with the process shown in the example: 5 | 6 |

7 | 8 |

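The evaluation notebook summarizes CIFAR-10-C performance as the mean top-1 accuracy over the corruption types. Below is a minimal, dependency-free sketch of that averaging step. The helper names `top1_accuracy` and `mean_top1` and the toy logits are assumptions made up for illustration; they are not taken from the notebook and are not results from this repository.

```python
from statistics import mean


def top1_accuracy(logits, labels):
    """Return the percentage of rows whose arg-max logit equals the label."""
    correct = sum(
        1
        for row, label in zip(logits, labels)
        if max(range(len(row)), key=row.__getitem__) == label
    )
    return 100.0 * correct / len(labels)


def mean_top1(per_corruption):
    """Compute per-corruption top-1 accuracy and its unweighted mean."""
    acc = {
        name: top1_accuracy(logits, labels)
        for name, (logits, labels) in per_corruption.items()
    }
    return acc, mean(acc.values())


# Toy, made-up logits and labels for two corruption types (two classes each).
toy = {
    "fog_5": ([[2.0, 0.1], [0.3, 1.5]], [0, 1]),
    "snow_5": ([[0.2, 0.9], [0.4, 0.1]], [0, 0]),
}
acc_by_version, mean_acc = mean_top1(toy)
print(acc_by_version, mean_acc)
```

The mean is unweighted, which matches averaging over equally sized corruption test sets; sets of different sizes would need a weighted mean instead.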
9 | 10 | **More things one can incorporate**: 11 | 12 | * Use more data while training the student. 13 | * Filter the teacher's predictions by confidence, keeping the high-confidence ones, while training the student. 14 | * Use recipes like [Stochastic Depth](https://arxiv.org/abs/1603.09382) for training the teacher. The current example uses [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) to induce geometric ensembling. 15 | 16 | Full-scale experiments are available [here](https://git.io/JO55v). 17 | 18 | ## Acknowledgements 19 | 20 | * [ML-GDE program](https://developers.google.com/programs/experts/) for providing GCP credits. 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /CIFAR_10C_Evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 5, 4 | "metadata": { 5 | "environment": { 6 | "name": "tf2-gpu.2-4.mnightly-2021-01-20-debian-10-test", 7 | "type": "gcloud", 8 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-4:mnightly-2021-01-20-debian-10-test" 9 | }, 10 | "kernelspec": { 11 | "display_name": "Python 3", 12 | "language": "python", 13 | "name": "python3" 14 | }, 15 | "language_info": { 16 | "codemirror_mode": { 17 | "name": "ipython", 18 | "version": 3 19 | }, 20 | "file_extension": ".py", 21 | "mimetype": "text/x-python", 22 | "name": "python", 23 | "nbconvert_exporter": "python", 24 | "pygments_lexer": "ipython3", 25 | "version": "3.7.9" 26 | }, 27 | "colab": { 28 | "name": "CIFAR_10C_Evaluation.ipynb", 29 | "provenance": [], 30 | "include_colab_link": true 31 | } 32 | }, 33 | "cells": [ 
34 | { 35 | "cell_type": "markdown", 36 | "metadata": { 37 | "id": "view-in-github", 38 | "colab_type": "text" 39 | }, 40 | "source": [ 41 | "\"Open" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": { 47 | "id": "filled-jurisdiction" 48 | }, 49 | "source": [ 50 | "## Setup" 51 | ], 52 | "id": "filled-jurisdiction" 53 | }, 54 | { 55 | "cell_type": "code", 56 | "metadata": { 57 | "id": "liberal-edmonton" 58 | }, 59 | "source": [ 60 | "# All model weights\n", 61 | "!wget https://git.io/JOKI9 -O consistency_training_model_weights.zip" 62 | ], 63 | "id": "liberal-edmonton", 64 | "execution_count": null, 65 | "outputs": [] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "metadata": { 70 | "id": "three-niger" 71 | }, 72 | "source": [ 73 | "from tensorflow.keras import layers\n", 74 | "import tensorflow as tf\n", 75 | "\n", 76 | "import tensorflow_datasets as tfds\n", 77 | "tfds.disable_progress_bar()\n", 78 | "\n", 79 | "from tqdm import tqdm\n", 80 | "import numpy as np" 81 | ], 82 | "id": "three-niger", 83 | "execution_count": null, 84 | "outputs": [] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": { 89 | "id": "residential-gossip" 90 | }, 91 | "source": [ 92 | "## Define Hyperparameters" 93 | ], 94 | "id": "residential-gossip" 95 | }, 96 | { 97 | "cell_type": "code", 98 | "metadata": { 99 | "id": "loose-devil" 100 | }, 101 | "source": [ 102 | "AUTO = tf.data.AUTOTUNE\n", 103 | "DATASET_NAME = \"cifar10_corrupted\"\n", 104 | "BATCH_SIZE = 128\n", 105 | "IMAGE_SIZE = 72" 106 | ], 107 | "id": "loose-devil", 108 | "execution_count": null, 109 | "outputs": [] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "metadata": { 114 | "id": "related-yahoo", 115 | "outputId": "b1cc4e9a-d29f-41c1-9558-b4477ff469ab" 116 | }, 117 | "source": [ 118 | "VERSIONS = [\n", 119 | " \"brightness_5\",\n", 120 | " \"contrast_5\",\n", 121 | " \"defocus_blur_5\",\n", 122 | " \"elastic_5\",\n", 123 | " \"fog_5\",\n", 124 | " \"frost_5\",\n", 125 | " 
\"frosted_glass_blur_5\",\n", 126 | " \"gaussian_blur_5\",\n", 127 | " \"gaussian_noise_5\",\n", 128 | " \"impulse_noise_5\",\n", 129 | " \"jpeg_compression_5\",\n", 130 | " \"motion_blur_5\",\n", 131 | " \"pixelate_5\",\n", 132 | " \"saturate_5\",\n", 133 | " \"shot_noise_5\",\n", 134 | " \"snow_5\",\n", 135 | " \"spatter_5\",\n", 136 | " \"speckle_noise_5\",\n", 137 | " \"zoom_blur_5\"\n", 138 | "]\n", 139 | "\n", 140 | "print(f\"Total sub-versions of the CIFAR10-C dataset: {len(VERSIONS)}\")" 141 | ], 142 | "id": "related-yahoo", 143 | "execution_count": null, 144 | "outputs": [ 145 | { 146 | "output_type": "stream", 147 | "text": [ 148 | "Total sub-versions of the CIFAR10-C dataset: 19\n" 149 | ], 150 | "name": "stdout" 151 | } 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": { 157 | "id": "responsible-techno" 158 | }, 159 | "source": [ 160 | "## Utilities" 161 | ], 162 | "id": "responsible-techno" 163 | }, 164 | { 165 | "cell_type": "code", 166 | "metadata": { 167 | "id": "dedicated-typing" 168 | }, 169 | "source": [ 170 | "def prepare_dataset(ds):\n", 171 | " ds = (ds\n", 172 | " .batch(BATCH_SIZE)\n", 173 | " .map(lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y), \n", 174 | " num_parallel_calls=AUTO)\n", 175 | " .prefetch(AUTO)\n", 176 | " )\n", 177 | " return ds" 178 | ], 179 | "id": "dedicated-typing", 180 | "execution_count": null, 181 | "outputs": [] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "metadata": { 186 | "id": "designing-chancellor" 187 | }, 188 | "source": [ 189 | "def get_training_model(num_classes=10):\n", 190 | " resnet50_v2 = tf.keras.applications.ResNet50V2(\n", 191 | " weights=None,\n", 192 | " include_top=False,\n", 193 | " input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),\n", 194 | " )\n", 195 | " model = tf.keras.Sequential(\n", 196 | " [\n", 197 | " layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),\n", 198 | " layers.experimental.preprocessing.Rescaling(scale=1.0 / 127.5, offset=-1),\n", 199 | " 
resnet50_v2,\n", 200 | " layers.GlobalAveragePooling2D(),\n", 201 | " layers.Dense(num_classes)\n", 202 | " ]\n", 203 | " )\n", 204 | " return model" 205 | ], 206 | "id": "designing-chancellor", 207 | "execution_count": null, 208 | "outputs": [] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "metadata": { 213 | "id": "dangerous-processing" 214 | }, 215 | "source": [ 216 | "def evaluate_model(model):\n", 217 | " acc_dict = {}\n", 218 | " for version in tqdm(VERSIONS):\n", 219 | " print(f\"Processing {version}\")\n", 220 | " dataset_fullname = DATASET_NAME + \"/\" + version\n", 221 | " loaded_ds = tfds.load(\n", 222 | " dataset_fullname,\n", 223 | " split=\"test\",\n", 224 | " as_supervised=True\n", 225 | " )\n", 226 | " loaded_ds = prepare_dataset(loaded_ds)\n", 227 | " _, acc = model.evaluate(loaded_ds, verbose=0)\n", 228 | " print(f\"Test accuracy on {version}: {acc*100}%\")\n", 229 | " acc_dict[version] = acc*100\n", 230 | " \n", 231 | " return acc_dict, np.mean(list(acc_dict.values()))" 232 | ], 233 | "id": "dangerous-processing", 234 | "execution_count": null, 235 | "outputs": [] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": { 240 | "id": "monthly-mount" 241 | }, 242 | "source": [ 243 | "## Evaluation" 244 | ], 245 | "id": "monthly-mount" 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": { 250 | "id": "rural-attitude" 251 | }, 252 | "source": [ 253 | "### SWA" 254 | ], 255 | "id": "rural-attitude" 256 | }, 257 | { 258 | "cell_type": "code", 259 | "metadata": { 260 | "id": "essential-height", 261 | "outputId": "6e029f58-9f51-490b-d8a4-f53a97186ba4" 262 | }, 263 | "source": [ 264 | "# Evaluate teacher model trained with SWA\n", 265 | "teacher_model_swa = get_training_model()\n", 266 | "teacher_model_swa.load_weights(\"teacher_model_swa.h5\")\n", 267 | "teacher_model_swa.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", 268 | " metrics=[\"accuracy\"])\n", 269 | "acc_dict, mean_top_1 = evaluate_model(teacher_model_swa)\n", 
270 | "print(f\"Mean Top-1 Accuracy: {mean_top_1}%\")" 271 | ], 272 | "id": "essential-height", 273 | "execution_count": null, 274 | "outputs": [ 275 | { 276 | "output_type": "stream", 277 | "text": [ 278 | " 0%| | 0/19 [00:00