├── README.md
├── LICENSE
└── CIFAR_10C_Evaluation.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # Consistency-Training-with-Supervision
2 | Contains experimentation notebooks for my Keras Example [Consistency Training with Supervision](https://keras.io/examples/vision/consistency_training/). This example also provides a template for performing semi-supervised / weakly supervised learning.
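A consistency-training objective of this kind typically combines two terms: a supervised loss on labeled data, and a term that keeps the student's predictions close to the teacher's. In the teacher/student setup the teacher usually sees a clean view of an image and the student sees an augmented one. The NumPy sketch below is a generic illustration under those assumptions. The helper name, the KL-divergence consistency term, the temperature, and the equal weighting are hypothetical choices, not necessarily what the Keras example uses.

```python
import numpy as np


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax over the last axis."""
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def supervised_consistency_loss(labels, student_logits, teacher_logits,
                                temperature=1.0, consistency_weight=1.0):
    """Hypothetical helper: cross-entropy on labels plus a KL consistency term
    that pulls the student's distribution toward the (fixed) teacher's."""
    # Supervised term: sparse categorical cross-entropy on the student's logits.
    log_p_student = log_softmax(student_logits)
    ce = -log_p_student[np.arange(len(labels)), labels].mean()

    # Consistency term: KL(teacher || student) on temperature-softened outputs.
    log_t = log_softmax(teacher_logits, temperature)
    log_s = log_softmax(student_logits, temperature)
    kl = (np.exp(log_t) * (log_t - log_s)).sum(axis=-1).mean()

    return ce + consistency_weight * kl


labels = np.array([0, 2])
teacher_logits = np.array([[4.0, 1.0, 0.0], [0.0, 1.0, 3.0]])
student_logits = np.zeros((2, 3))  # an untrained student: uniform predictions
loss = supervised_consistency_loss(labels, student_logits, teacher_logits)
print(f"{loss:.4f}")
```

In an actual training loop the same computation would be written with TensorFlow ops (for example `tf.keras.losses.KLDivergence`) so that gradients flow into the student, with the teacher's logits treated as constants.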
3 |
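4 | The process shown in the example gives promising results on [CIFAR-10-C](https://github.com/hendrycks/robustness). Averaged over the 19 corruption types at severity 5, the student model reaches a mean top-1 accuracy of about 62.8%, versus about 59.7% for the SWA-trained teacher (see `CIFAR_10C_Evaluation.ipynb`).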
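The notebook's `evaluate_model` reports mean top-1 accuracy as the unweighted average (`np.mean`) of the 19 per-corruption accuracies. The short sketch below recomputes both means from the logged per-corruption results, rounded to two decimals, and lists where the student falls behind the SWA teacher. It only re-derives numbers already present in the notebook output.

```python
# Per-corruption top-1 accuracy (%) at severity 5, copied from the logged
# output of CIFAR_10C_Evaluation.ipynb and rounded to two decimals.
TEACHER_SWA = {
    "brightness_5": 76.06, "contrast_5": 24.35, "defocus_blur_5": 71.59,
    "elastic_5": 74.20, "fog_5": 48.29, "frost_5": 62.27,
    "frosted_glass_blur_5": 58.98, "gaussian_blur_5": 66.60,
    "gaussian_noise_5": 41.91, "impulse_noise_5": 23.97,
    "jpeg_compression_5": 79.16, "motion_blur_5": 65.77, "pixelate_5": 73.52,
    "saturate_5": 69.64, "shot_noise_5": 45.22, "snow_5": 65.40,
    "spatter_5": 64.16, "speckle_noise_5": 45.63, "zoom_blur_5": 77.49,
}
STUDENT_NOISY_SWA = {
    "brightness_5": 82.91, "contrast_5": 29.73, "defocus_blur_5": 73.00,
    "elastic_5": 74.30, "fog_5": 50.26, "frost_5": 58.92,
    "frosted_glass_blur_5": 59.79, "gaussian_blur_5": 67.93,
    "gaussian_noise_5": 46.26, "impulse_noise_5": 30.99,
    "jpeg_compression_5": 76.40, "motion_blur_5": 66.72, "pixelate_5": 73.95,
    "saturate_5": 81.05, "shot_noise_5": 51.88, "snow_5": 65.82,
    "spatter_5": 70.32, "speckle_noise_5": 53.47, "zoom_blur_5": 79.67,
}


def mean_top1(acc):
    """Unweighted mean over corruption types, as with np.mean in the notebook."""
    return sum(acc.values()) / len(acc)


teacher_mean = mean_top1(TEACHER_SWA)
student_mean = mean_top1(STUDENT_NOISY_SWA)
improved = [k for k in TEACHER_SWA if STUDENT_NOISY_SWA[k] > TEACHER_SWA[k]]
regressed = sorted(k for k in TEACHER_SWA if STUDENT_NOISY_SWA[k] < TEACHER_SWA[k])

print(f"Teacher (SWA) mean top-1: {teacher_mean:.2f}%")  # 59.70%
print(f"Student mean top-1:       {student_mean:.2f}%")  # 62.81%
print(f"Student improves on {len(improved)}/{len(TEACHER_SWA)} corruptions; "
      f"regresses on {regressed}")
```

With these values the student is ahead on 17 of the 19 corruptions and behind only on `frost_5` and `jpeg_compression_5`.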
5 |
6 |
7 |
8 |
9 |
10 | **More things one can incorporate**:
11 |
12 | * Incorporate more data when training the student.
13 | * Filter the teacher's predictions by confidence, keeping only high-confidence ones, when training the student.
14 | * Use recipes like [Stochastic Depth](https://arxiv.org/abs/1603.09382) for training the teacher. The current example uses [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) to induce geometric ensembling.
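One common reading of the confidence-filtering idea is to keep only the samples on which the teacher's top softmax probability clears a threshold. The notebook's `get_training_model` ends in a `Dense` layer without an activation, so the teacher's outputs are logits and need a softmax first. A minimal NumPy sketch follows; `filter_confident` and the 0.9 threshold are hypothetical choices, not part of the Keras example.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def filter_confident(images, teacher_logits, threshold=0.9):
    """Hypothetical helper: keep only samples whose top teacher probability
    is at least `threshold`; return them with the teacher's soft targets."""
    probs = softmax(teacher_logits)
    mask = probs.max(axis=-1) >= threshold
    return images[mask], probs[mask], mask


teacher_logits = np.array([
    [8.0, 0.0, 0.0],  # confident -> kept
    [1.0, 0.9, 0.8],  # uncertain -> dropped
    [0.0, 6.0, 0.0],  # confident -> kept
])
images = np.zeros((3, 72, 72, 3), dtype=np.float32)  # stand-in image batch
kept_images, kept_targets, mask = filter_confident(images, teacher_logits)
print(mask)  # [ True False  True]
```

A fixed threshold trades pseudo-label quality against the number of samples kept, so its value would need tuning.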
15 |
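Stochastic Weight Averaging keeps an equally weighted running average of the weights visited late in training (for example, one snapshot per epoch after a warm-up), and the averaged model behaves like a cheap ensemble. A minimal NumPy sketch of the update w_swa ← (w_swa · n + w) / (n + 1) follows; the class name and the toy snapshots are hypothetical stand-ins for real checkpoints, not the code used in the example.

```python
import numpy as np


class RunningWeightAverage:
    """Equal-weight running average of a list of weight arrays (SWA-style)."""

    def __init__(self):
        self.n_models = 0    # number of snapshots averaged so far
        self.average = None  # list of float64 arrays, one per weight tensor

    def update(self, weights):
        if self.average is None:
            self.average = [np.asarray(w, dtype=np.float64).copy() for w in weights]
        else:
            n = self.n_models
            self.average = [
                (avg * n + np.asarray(w, dtype=np.float64)) / (n + 1)
                for avg, w in zip(self.average, weights)
            ]
        self.n_models += 1
        return self.average


# Three toy "checkpoints", each a list of two weight tensors.
snapshots = [
    [np.array([1.0, 2.0]), np.array([[0.0]])],
    [np.array([3.0, 4.0]), np.array([[2.0]])],
    [np.array([5.0, 6.0]), np.array([[4.0]])],
]
swa = RunningWeightAverage()
for weights in snapshots:
    swa.update(weights)
print(swa.average[0], swa.average[1])  # [3. 4.] [[2.]]
```

With a Keras model, the list returned by `model.get_weights()` could be passed to `update` at each averaging step, and the final average loaded with `model.set_weights()`. Because ResNet50V2 uses batch normalization, the SWA paper recommends recomputing the batch-norm statistics for the averaged weights with an extra pass over the training data.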
16 | Full-scale experiments are available [here](https://git.io/JO55v).
17 |
18 | ## Acknowledgements
19 |
20 | * [ML-GDE program](https://developers.google.com/programs/experts/) for providing GCP credits.
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/CIFAR_10C_Evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 5,
4 | "metadata": {
5 | "environment": {
6 | "name": "tf2-gpu.2-4.mnightly-2021-01-20-debian-10-test",
7 | "type": "gcloud",
8 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-4:mnightly-2021-01-20-debian-10-test"
9 | },
10 | "kernelspec": {
11 | "display_name": "Python 3",
12 | "language": "python",
13 | "name": "python3"
14 | },
15 | "language_info": {
16 | "codemirror_mode": {
17 | "name": "ipython",
18 | "version": 3
19 | },
20 | "file_extension": ".py",
21 | "mimetype": "text/x-python",
22 | "name": "python",
23 | "nbconvert_exporter": "python",
24 | "pygments_lexer": "ipython3",
25 | "version": "3.7.9"
26 | },
27 | "colab": {
28 | "name": "CIFAR_10C_Evaluation.ipynb",
29 | "provenance": [],
30 | "include_colab_link": true
31 | }
32 | },
33 | "cells": [
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {
37 | "id": "view-in-github",
38 | "colab_type": "text"
39 | },
40 | "source": [
41 | ""
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {
47 | "id": "filled-jurisdiction"
48 | },
49 | "source": [
50 | "## Setup"
51 | ],
52 | "id": "filled-jurisdiction"
53 | },
54 | {
55 | "cell_type": "code",
56 | "metadata": {
57 | "id": "liberal-edmonton"
58 | },
59 | "source": [
60 | "# All model weights\n",
61 | "!wget https://git.io/JOKI9 -O consistency_training_model_weights.zip"
62 | ],
63 | "id": "liberal-edmonton",
64 | "execution_count": null,
65 | "outputs": []
66 | },
67 | {
68 | "cell_type": "code",
69 | "metadata": {
70 | "id": "three-niger"
71 | },
72 | "source": [
73 | "from tensorflow.keras import layers\n",
74 | "import tensorflow as tf\n",
75 | "\n",
76 | "import tensorflow_datasets as tfds\n",
77 | "tfds.disable_progress_bar()\n",
78 | "\n",
79 | "from tqdm import tqdm\n",
80 | "import numpy as np"
81 | ],
82 | "id": "three-niger",
83 | "execution_count": null,
84 | "outputs": []
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {
89 | "id": "residential-gossip"
90 | },
91 | "source": [
92 | "## Define Hyperparameters"
93 | ],
94 | "id": "residential-gossip"
95 | },
96 | {
97 | "cell_type": "code",
98 | "metadata": {
99 | "id": "loose-devil"
100 | },
101 | "source": [
102 | "AUTO = tf.data.AUTOTUNE\n",
103 | "DATASET_NAME = \"cifar10_corrupted\"\n",
104 | "BATCH_SIZE = 128\n",
105 | "IMAGE_SIZE = 72"
106 | ],
107 | "id": "loose-devil",
108 | "execution_count": null,
109 | "outputs": []
110 | },
111 | {
112 | "cell_type": "code",
113 | "metadata": {
114 | "id": "related-yahoo",
115 | "outputId": "b1cc4e9a-d29f-41c1-9558-b4477ff469ab"
116 | },
117 | "source": [
118 | "VERSIONS = [\n",
119 | " \"brightness_5\",\n",
120 | " \"contrast_5\",\n",
121 | " \"defocus_blur_5\",\n",
122 | " \"elastic_5\",\n",
123 | " \"fog_5\",\n",
124 | " \"frost_5\",\n",
125 | " \"frosted_glass_blur_5\",\n",
126 | " \"gaussian_blur_5\",\n",
127 | " \"gaussian_noise_5\",\n",
128 | " \"impulse_noise_5\",\n",
129 | " \"jpeg_compression_5\",\n",
130 | " \"motion_blur_5\",\n",
131 | " \"pixelate_5\",\n",
132 | " \"saturate_5\",\n",
133 | " \"shot_noise_5\",\n",
134 | " \"snow_5\",\n",
135 | " \"spatter_5\",\n",
136 | " \"speckle_noise_5\",\n",
137 | " \"zoom_blur_5\"\n",
138 | "]\n",
139 | "\n",
140 | "print(f\"Total sub-versions of the CIFAR10-C dataset: {len(VERSIONS)}\")"
141 | ],
142 | "id": "related-yahoo",
143 | "execution_count": null,
144 | "outputs": [
145 | {
146 | "output_type": "stream",
147 | "text": [
148 | "Total sub-versions of the CIFAR10-C dataset: 19\n"
149 | ],
150 | "name": "stdout"
151 | }
152 | ]
153 | },
154 | {
155 | "cell_type": "markdown",
156 | "metadata": {
157 | "id": "responsible-techno"
158 | },
159 | "source": [
160 | "## Utilities"
161 | ],
162 | "id": "responsible-techno"
163 | },
164 | {
165 | "cell_type": "code",
166 | "metadata": {
167 | "id": "dedicated-typing"
168 | },
169 | "source": [
170 | "def prepare_dataset(ds):\n",
171 | " ds = (ds\n",
172 | " .batch(BATCH_SIZE)\n",
173 | " .map(lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y), \n",
174 | " num_parallel_calls=AUTO)\n",
175 | " .prefetch(AUTO)\n",
176 | " )\n",
177 | " return ds"
178 | ],
179 | "id": "dedicated-typing",
180 | "execution_count": null,
181 | "outputs": []
182 | },
183 | {
184 | "cell_type": "code",
185 | "metadata": {
186 | "id": "designing-chancellor"
187 | },
188 | "source": [
189 | "def get_training_model(num_classes=10):\n",
190 | " resnet50_v2 = tf.keras.applications.ResNet50V2(\n",
191 | " weights=None,\n",
192 | " include_top=False,\n",
193 | " input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),\n",
194 | " )\n",
195 | " model = tf.keras.Sequential(\n",
196 | " [\n",
197 | " layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),\n",
198 | " layers.experimental.preprocessing.Rescaling(scale=1.0 / 127.5, offset=-1),\n",
199 | " resnet50_v2,\n",
200 | " layers.GlobalAveragePooling2D(),\n",
201 | " layers.Dense(num_classes)\n",
202 | " ]\n",
203 | " )\n",
204 | " return model"
205 | ],
206 | "id": "designing-chancellor",
207 | "execution_count": null,
208 | "outputs": []
209 | },
210 | {
211 | "cell_type": "code",
212 | "metadata": {
213 | "id": "dangerous-processing"
214 | },
215 | "source": [
216 | "def evaluate_model(model):\n",
217 | " acc_dict = {}\n",
218 | " for version in tqdm(VERSIONS):\n",
219 | " print(f\"Processing {version}\")\n",
220 | " dataset_fullname = DATASET_NAME + \"/\" + version\n",
221 | " loaded_ds = tfds.load(\n",
222 | " dataset_fullname,\n",
223 | " split=\"test\",\n",
224 | " as_supervised=True\n",
225 | " )\n",
226 | " loaded_ds = prepare_dataset(loaded_ds)\n",
227 | " _, acc = model.evaluate(loaded_ds, verbose=0)\n",
228 | " print(f\"Test accuracy on {version}: {acc*100}%\")\n",
229 | " acc_dict[version] = acc*100\n",
230 | " \n",
231 | " return acc_dict, np.mean(list(acc_dict.values()))"
232 | ],
233 | "id": "dangerous-processing",
234 | "execution_count": null,
235 | "outputs": []
236 | },
237 | {
238 | "cell_type": "markdown",
239 | "metadata": {
240 | "id": "monthly-mount"
241 | },
242 | "source": [
243 | "## Evaluation"
244 | ],
245 | "id": "monthly-mount"
246 | },
247 | {
248 | "cell_type": "markdown",
249 | "metadata": {
250 | "id": "rural-attitude"
251 | },
252 | "source": [
253 | "### SWA"
254 | ],
255 | "id": "rural-attitude"
256 | },
257 | {
258 | "cell_type": "code",
259 | "metadata": {
260 | "id": "essential-height",
261 | "outputId": "6e029f58-9f51-490b-d8a4-f53a97186ba4"
262 | },
263 | "source": [
264 | "# Evaluate teacher model trained with SWA\n",
265 | "teacher_model_swa = get_training_model()\n",
266 | "teacher_model_swa.load_weights(\"teacher_model_swa.h5\")\n",
267 | "teacher_model_swa.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # outputs are logits\n",
268 | " metrics=[\"accuracy\"])\n",
269 | "acc_dict, mean_top_1 = evaluate_model(teacher_model_swa)\n",
270 | "print(f\"Mean Top-1 Accuracy: {mean_top_1}%\")"
271 | ],
272 | "id": "essential-height",
273 | "execution_count": null,
274 | "outputs": [
275 | {
276 | "output_type": "stream",
277 | "text": [
278 | " 0%| | 0/19 [00:00, ?it/s]"
279 | ],
280 | "name": "stderr"
281 | },
282 | {
283 | "output_type": "stream",
284 | "text": [
285 | "Processing brightness_5\n"
286 | ],
287 | "name": "stdout"
288 | },
289 | {
290 | "output_type": "stream",
291 | "text": [
292 | " 5%|▌ | 1/19 [00:17<05:20, 17.80s/it]"
293 | ],
294 | "name": "stderr"
295 | },
296 | {
297 | "output_type": "stream",
298 | "text": [
299 | "Test accuracy on brightness_5: 76.05999708175659%\n",
300 | "Processing contrast_5\n"
301 | ],
302 | "name": "stdout"
303 | },
304 | {
305 | "output_type": "stream",
306 | "text": [
307 | " 11%|█ | 2/19 [00:19<02:22, 8.40s/it]"
308 | ],
309 | "name": "stderr"
310 | },
311 | {
312 | "output_type": "stream",
313 | "text": [
314 | "Test accuracy on contrast_5: 24.34999942779541%\n",
315 | "Processing defocus_blur_5\n"
316 | ],
317 | "name": "stdout"
318 | },
319 | {
320 | "output_type": "stream",
321 | "text": [
322 | " 16%|█▌ | 3/19 [00:21<01:26, 5.42s/it]"
323 | ],
324 | "name": "stderr"
325 | },
326 | {
327 | "output_type": "stream",
328 | "text": [
329 | "Test accuracy on defocus_blur_5: 71.59000039100647%\n",
330 | "Processing elastic_5\n"
331 | ],
332 | "name": "stdout"
333 | },
334 | {
335 | "output_type": "stream",
336 | "text": [
337 | " 21%|██ | 4/19 [00:23<01:00, 4.02s/it]"
338 | ],
339 | "name": "stderr"
340 | },
341 | {
342 | "output_type": "stream",
343 | "text": [
344 | "Test accuracy on elastic_5: 74.19999837875366%\n",
345 | "Processing fog_5\n"
346 | ],
347 | "name": "stdout"
348 | },
349 | {
350 | "output_type": "stream",
351 | "text": [
352 | " 26%|██▋ | 5/19 [00:25<00:45, 3.25s/it]"
353 | ],
354 | "name": "stderr"
355 | },
356 | {
357 | "output_type": "stream",
358 | "text": [
359 | "Test accuracy on fog_5: 48.28999936580658%\n",
360 | "Processing frost_5\n"
361 | ],
362 | "name": "stdout"
363 | },
364 | {
365 | "output_type": "stream",
366 | "text": [
367 | " 32%|███▏ | 6/19 [00:27<00:36, 2.79s/it]"
368 | ],
369 | "name": "stderr"
370 | },
371 | {
372 | "output_type": "stream",
373 | "text": [
374 | "Test accuracy on frost_5: 62.26999759674072%\n",
375 | "Processing frosted_glass_blur_5\n"
376 | ],
377 | "name": "stdout"
378 | },
379 | {
380 | "output_type": "stream",
381 | "text": [
382 | " 37%|███▋ | 7/19 [00:28<00:29, 2.48s/it]"
383 | ],
384 | "name": "stderr"
385 | },
386 | {
387 | "output_type": "stream",
388 | "text": [
389 | "Test accuracy on frosted_glass_blur_5: 58.980000019073486%\n",
390 | "Processing gaussian_blur_5\n"
391 | ],
392 | "name": "stdout"
393 | },
394 | {
395 | "output_type": "stream",
396 | "text": [
397 | " 42%|████▏ | 8/19 [00:30<00:25, 2.28s/it]"
398 | ],
399 | "name": "stderr"
400 | },
401 | {
402 | "output_type": "stream",
403 | "text": [
404 | "Test accuracy on gaussian_blur_5: 66.60000085830688%\n",
405 | "Processing gaussian_noise_5\n"
406 | ],
407 | "name": "stdout"
408 | },
409 | {
410 | "output_type": "stream",
411 | "text": [
412 | " 47%|████▋ | 9/19 [00:32<00:21, 2.15s/it]"
413 | ],
414 | "name": "stderr"
415 | },
416 | {
417 | "output_type": "stream",
418 | "text": [
419 | "Test accuracy on gaussian_noise_5: 41.909998655319214%\n",
420 | "Processing impulse_noise_5\n"
421 | ],
422 | "name": "stdout"
423 | },
424 | {
425 | "output_type": "stream",
426 | "text": [
427 | " 53%|█████▎ | 10/19 [00:34<00:18, 2.06s/it]"
428 | ],
429 | "name": "stderr"
430 | },
431 | {
432 | "output_type": "stream",
433 | "text": [
434 | "Test accuracy on impulse_noise_5: 23.970000445842743%\n",
435 | "Processing jpeg_compression_5\n"
436 | ],
437 | "name": "stdout"
438 | },
439 | {
440 | "output_type": "stream",
441 | "text": [
442 | " 58%|█████▊ | 11/19 [00:36<00:15, 2.00s/it]"
443 | ],
444 | "name": "stderr"
445 | },
446 | {
447 | "output_type": "stream",
448 | "text": [
449 | "Test accuracy on jpeg_compression_5: 79.1599988937378%\n",
450 | "Processing motion_blur_5\n"
451 | ],
452 | "name": "stdout"
453 | },
454 | {
455 | "output_type": "stream",
456 | "text": [
457 | " 63%|██████▎ | 12/19 [00:38<00:13, 1.97s/it]"
458 | ],
459 | "name": "stderr"
460 | },
461 | {
462 | "output_type": "stream",
463 | "text": [
464 | "Test accuracy on motion_blur_5: 65.77000021934509%\n",
465 | "Processing pixelate_5\n"
466 | ],
467 | "name": "stdout"
468 | },
469 | {
470 | "output_type": "stream",
471 | "text": [
472 | " 68%|██████▊ | 13/19 [00:40<00:11, 1.93s/it]"
473 | ],
474 | "name": "stderr"
475 | },
476 | {
477 | "output_type": "stream",
478 | "text": [
479 | "Test accuracy on pixelate_5: 73.51999878883362%\n",
480 | "Processing saturate_5\n"
481 | ],
482 | "name": "stdout"
483 | },
484 | {
485 | "output_type": "stream",
486 | "text": [
487 | " 74%|███████▎ | 14/19 [00:42<00:09, 1.91s/it]"
488 | ],
489 | "name": "stderr"
490 | },
491 | {
492 | "output_type": "stream",
493 | "text": [
494 | "Test accuracy on saturate_5: 69.6399986743927%\n",
495 | "Processing shot_noise_5\n"
496 | ],
497 | "name": "stdout"
498 | },
499 | {
500 | "output_type": "stream",
501 | "text": [
502 | " 79%|███████▉ | 15/19 [00:43<00:07, 1.89s/it]"
503 | ],
504 | "name": "stderr"
505 | },
506 | {
507 | "output_type": "stream",
508 | "text": [
509 | "Test accuracy on shot_noise_5: 45.21999955177307%\n",
510 | "Processing snow_5\n"
511 | ],
512 | "name": "stdout"
513 | },
514 | {
515 | "output_type": "stream",
516 | "text": [
517 | " 84%|████████▍ | 16/19 [00:45<00:05, 1.89s/it]"
518 | ],
519 | "name": "stderr"
520 | },
521 | {
522 | "output_type": "stream",
523 | "text": [
524 | "Test accuracy on snow_5: 65.39999842643738%\n",
525 | "Processing spatter_5\n"
526 | ],
527 | "name": "stdout"
528 | },
529 | {
530 | "output_type": "stream",
531 | "text": [
532 | " 89%|████████▉ | 17/19 [00:47<00:03, 1.87s/it]"
533 | ],
534 | "name": "stderr"
535 | },
536 | {
537 | "output_type": "stream",
538 | "text": [
539 | "Test accuracy on spatter_5: 64.16000127792358%\n",
540 | "Processing speckle_noise_5\n"
541 | ],
542 | "name": "stdout"
543 | },
544 | {
545 | "output_type": "stream",
546 | "text": [
547 | " 95%|█████████▍| 18/19 [00:49<00:01, 1.86s/it]"
548 | ],
549 | "name": "stderr"
550 | },
551 | {
552 | "output_type": "stream",
553 | "text": [
554 | "Test accuracy on speckle_noise_5: 45.62999904155731%\n",
555 | "Processing zoom_blur_5\n"
556 | ],
557 | "name": "stdout"
558 | },
559 | {
560 | "output_type": "stream",
561 | "text": [
562 | "100%|██████████| 19/19 [00:51<00:00, 2.70s/it]"
563 | ],
564 | "name": "stderr"
565 | },
566 | {
567 | "output_type": "stream",
568 | "text": [
569 | "Test accuracy on zoom_blur_5: 77.49000191688538%\n",
570 | "Mean Top-1 Accuracy: 59.695262579541456%\n"
571 | ],
572 | "name": "stdout"
573 | },
574 | {
575 | "output_type": "stream",
576 | "text": [
577 | "\n"
578 | ],
579 | "name": "stderr"
580 | }
581 | ]
582 | },
583 | {
584 | "cell_type": "code",
585 | "metadata": {
586 | "id": "instant-insurance",
587 | "outputId": "e62bb54f-d2e3-4e32-abc1-32bdd5f1a22f"
588 | },
589 | "source": [
590 | "# Evaluate the corresponding student model\n",
591 | "student_noisy_swa = get_training_model()\n",
592 | "student_noisy_swa.load_weights(\"student_noisy_swa.h5\")\n",
593 | "student_noisy_swa.compile(loss=\"sparse_categorical_crossentropy\",\n",
594 | " metrics=[\"accuracy\"])\n",
595 | "acc_dict, mean_top_1 = evaluate_model(student_noisy_swa)\n",
596 | "print(f\"Mean Top-1 Accuracy: {mean_top_1}%\")"
597 | ],
598 | "id": "instant-insurance",
599 | "execution_count": null,
600 | "outputs": [
601 | {
602 | "output_type": "stream",
603 | "text": [
604 | " 0%| | 0/19 [00:00, ?it/s]"
605 | ],
606 | "name": "stderr"
607 | },
608 | {
609 | "output_type": "stream",
610 | "text": [
611 | "Processing brightness_5\n"
612 | ],
613 | "name": "stdout"
614 | },
615 | {
616 | "output_type": "stream",
617 | "text": [
618 | " 5%|▌ | 1/19 [00:02<00:51, 2.85s/it]"
619 | ],
620 | "name": "stderr"
621 | },
622 | {
623 | "output_type": "stream",
624 | "text": [
625 | "Test accuracy on brightness_5: 82.91000127792358%\n",
626 | "Processing contrast_5\n"
627 | ],
628 | "name": "stdout"
629 | },
630 | {
631 | "output_type": "stream",
632 | "text": [
633 | " 11%|█ | 2/19 [00:04<00:37, 2.22s/it]"
634 | ],
635 | "name": "stderr"
636 | },
637 | {
638 | "output_type": "stream",
639 | "text": [
640 | "Test accuracy on contrast_5: 29.730001091957092%\n",
641 | "Processing defocus_blur_5\n"
642 | ],
643 | "name": "stdout"
644 | },
645 | {
646 | "output_type": "stream",
647 | "text": [
648 | " 16%|█▌ | 3/19 [00:06<00:32, 2.01s/it]"
649 | ],
650 | "name": "stderr"
651 | },
652 | {
653 | "output_type": "stream",
654 | "text": [
655 | "Test accuracy on defocus_blur_5: 73.00000190734863%\n",
656 | "Processing elastic_5\n"
657 | ],
658 | "name": "stdout"
659 | },
660 | {
661 | "output_type": "stream",
662 | "text": [
663 | " 21%|██ | 4/19 [00:08<00:28, 1.92s/it]"
664 | ],
665 | "name": "stderr"
666 | },
667 | {
668 | "output_type": "stream",
669 | "text": [
670 | "Test accuracy on elastic_5: 74.29999709129333%\n",
671 | "Processing fog_5\n"
672 | ],
673 | "name": "stdout"
674 | },
675 | {
676 | "output_type": "stream",
677 | "text": [
678 | " 26%|██▋ | 5/19 [00:09<00:26, 1.86s/it]"
679 | ],
680 | "name": "stderr"
681 | },
682 | {
683 | "output_type": "stream",
684 | "text": [
685 | "Test accuracy on fog_5: 50.26000142097473%\n",
686 | "Processing frost_5\n"
687 | ],
688 | "name": "stdout"
689 | },
690 | {
691 | "output_type": "stream",
692 | "text": [
693 | " 32%|███▏ | 6/19 [00:11<00:23, 1.83s/it]"
694 | ],
695 | "name": "stderr"
696 | },
697 | {
698 | "output_type": "stream",
699 | "text": [
700 | "Test accuracy on frost_5: 58.92000198364258%\n",
701 | "Processing frosted_glass_blur_5\n"
702 | ],
703 | "name": "stdout"
704 | },
705 | {
706 | "output_type": "stream",
707 | "text": [
708 | " 37%|███▋ | 7/19 [00:13<00:21, 1.81s/it]"
709 | ],
710 | "name": "stderr"
711 | },
712 | {
713 | "output_type": "stream",
714 | "text": [
715 | "Test accuracy on frosted_glass_blur_5: 59.78999733924866%\n",
716 | "Processing gaussian_blur_5\n"
717 | ],
718 | "name": "stdout"
719 | },
720 | {
721 | "output_type": "stream",
722 | "text": [
723 | " 42%|████▏ | 8/19 [00:15<00:19, 1.80s/it]"
724 | ],
725 | "name": "stderr"
726 | },
727 | {
728 | "output_type": "stream",
729 | "text": [
730 | "Test accuracy on gaussian_blur_5: 67.93000102043152%\n",
731 | "Processing gaussian_noise_5\n"
732 | ],
733 | "name": "stdout"
734 | },
735 | {
736 | "output_type": "stream",
737 | "text": [
738 | " 47%|████▋ | 9/19 [00:17<00:17, 1.79s/it]"
739 | ],
740 | "name": "stderr"
741 | },
742 | {
743 | "output_type": "stream",
744 | "text": [
745 | "Test accuracy on gaussian_noise_5: 46.25999927520752%\n",
746 | "Processing impulse_noise_5\n"
747 | ],
748 | "name": "stdout"
749 | },
750 | {
751 | "output_type": "stream",
752 | "text": [
753 | " 53%|█████▎ | 10/19 [00:18<00:16, 1.78s/it]"
754 | ],
755 | "name": "stderr"
756 | },
757 | {
758 | "output_type": "stream",
759 | "text": [
760 | "Test accuracy on impulse_noise_5: 30.98999857902527%\n",
761 | "Processing jpeg_compression_5\n"
762 | ],
763 | "name": "stdout"
764 | },
765 | {
766 | "output_type": "stream",
767 | "text": [
768 | " 58%|█████▊ | 11/19 [00:20<00:14, 1.78s/it]"
769 | ],
770 | "name": "stderr"
771 | },
772 | {
773 | "output_type": "stream",
774 | "text": [
775 | "Test accuracy on jpeg_compression_5: 76.39999985694885%\n",
776 | "Processing motion_blur_5\n"
777 | ],
778 | "name": "stdout"
779 | },
780 | {
781 | "output_type": "stream",
782 | "text": [
783 | " 63%|██████▎ | 12/19 [00:22<00:12, 1.77s/it]"
784 | ],
785 | "name": "stderr"
786 | },
787 | {
788 | "output_type": "stream",
789 | "text": [
790 | "Test accuracy on motion_blur_5: 66.72000288963318%\n",
791 | "Processing pixelate_5\n"
792 | ],
793 | "name": "stdout"
794 | },
795 | {
796 | "output_type": "stream",
797 | "text": [
798 | " 68%|██████▊ | 13/19 [00:24<00:10, 1.77s/it]"
799 | ],
800 | "name": "stderr"
801 | },
802 | {
803 | "output_type": "stream",
804 | "text": [
805 | "Test accuracy on pixelate_5: 73.94999861717224%\n",
806 | "Processing saturate_5\n"
807 | ],
808 | "name": "stdout"
809 | },
810 | {
811 | "output_type": "stream",
812 | "text": [
813 | " 74%|███████▎ | 14/19 [00:25<00:08, 1.77s/it]"
814 | ],
815 | "name": "stderr"
816 | },
817 | {
818 | "output_type": "stream",
819 | "text": [
820 | "Test accuracy on saturate_5: 81.05000257492065%\n",
821 | "Processing shot_noise_5\n"
822 | ],
823 | "name": "stdout"
824 | },
825 | {
826 | "output_type": "stream",
827 | "text": [
828 | " 79%|███████▉ | 15/19 [00:27<00:07, 1.77s/it]"
829 | ],
830 | "name": "stderr"
831 | },
832 | {
833 | "output_type": "stream",
834 | "text": [
835 | "Test accuracy on shot_noise_5: 51.88000202178955%\n",
836 | "Processing snow_5\n"
837 | ],
838 | "name": "stdout"
839 | },
840 | {
841 | "output_type": "stream",
842 | "text": [
843 | " 84%|████████▍ | 16/19 [00:29<00:05, 1.76s/it]"
844 | ],
845 | "name": "stderr"
846 | },
847 | {
848 | "output_type": "stream",
849 | "text": [
850 | "Test accuracy on snow_5: 65.82000255584717%\n",
851 | "Processing spatter_5\n"
852 | ],
853 | "name": "stdout"
854 | },
855 | {
856 | "output_type": "stream",
857 | "text": [
858 | " 89%|████████▉ | 17/19 [00:31<00:03, 1.76s/it]"
859 | ],
860 | "name": "stderr"
861 | },
862 | {
863 | "output_type": "stream",
864 | "text": [
865 | "Test accuracy on spatter_5: 70.31999826431274%\n",
866 | "Processing speckle_noise_5\n"
867 | ],
868 | "name": "stdout"
869 | },
870 | {
871 | "output_type": "stream",
872 | "text": [
873 | " 95%|█████████▍| 18/19 [00:32<00:01, 1.76s/it]"
874 | ],
875 | "name": "stderr"
876 | },
877 | {
878 | "output_type": "stream",
879 | "text": [
880 | "Test accuracy on speckle_noise_5: 53.46999764442444%\n",
881 | "Processing zoom_blur_5\n"
882 | ],
883 | "name": "stdout"
884 | },
885 | {
886 | "output_type": "stream",
887 | "text": [
888 | "100%|██████████| 19/19 [00:34<00:00, 1.82s/it]"
889 | ],
890 | "name": "stderr"
891 | },
892 | {
893 | "output_type": "stream",
894 | "text": [
895 | "Test accuracy on zoom_blur_5: 79.67000007629395%\n",
896 | "Mean Top-1 Accuracy: 62.80894765728399%\n"
897 | ],
898 | "name": "stdout"
899 | },
900 | {
901 | "output_type": "stream",
902 | "text": [
903 | "\n"
904 | ],
905 | "name": "stderr"
906 | }
907 | ]
908 | },
909 | {
910 | "cell_type": "markdown",
911 | "metadata": {
912 | "id": "stainless-prescription"
913 | },
914 | "source": [
915 | "### MA"
916 | ],
917 | "id": "stainless-prescription"
918 | },
919 | {
920 | "cell_type": "code",
921 | "metadata": {
922 | "id": "limiting-scanning",
923 | "outputId": "5935e901-1288-4abb-bf87-77a295456c62"
924 | },
925 | "source": [
926 | "# Evaluate teacher model trained with MA\n",
927 | "teacher_model_ma = get_training_model()\n",
928 | "teacher_model_ma.load_weights(\"teacher_model_ma.h5\")\n",
929 | "teacher_model_ma.compile(loss=\"sparse_categorical_crossentropy\",\n",
930 | " metrics=[\"accuracy\"])\n",
931 | "acc_dict, mean_top_1 = evaluate_model(teacher_model_ma)\n",
932 | "print(f\"Mean Top-1 Accuracy: {mean_top_1}%\")"
933 | ],
934 | "id": "limiting-scanning",
935 | "execution_count": null,
936 | "outputs": [
937 | {
938 | "output_type": "stream",
939 | "text": [
940 | " 0%| | 0/19 [00:00<?, ?it/s]"
941 | ],
942 | "name": "stderr"
943 | },
944 | {
945 | "output_type": "stream",
946 | "text": [
947 | "Processing brightness_5\n"
948 | ],
949 | "name": "stdout"
950 | },
951 | {
952 | "output_type": "stream",
953 | "text": [
954 | " 5%|▌ | 1/19 [00:02<00:50, 2.83s/it]"
955 | ],
956 | "name": "stderr"
957 | },
958 | {
959 | "output_type": "stream",
960 | "text": [
961 | "Test accuracy on brightness_5: 73.14000129699707%\n",
962 | "Processing contrast_5\n"
963 | ],
964 | "name": "stdout"
965 | },
966 | {
967 | "output_type": "stream",
968 | "text": [
969 | " 11%|█ | 2/19 [00:04<00:37, 2.21s/it]"
970 | ],
971 | "name": "stderr"
972 | },
973 | {
974 | "output_type": "stream",
975 | "text": [
976 | "Test accuracy on contrast_5: 19.679999351501465%\n",
977 | "Processing defocus_blur_5\n"
978 | ],
979 | "name": "stdout"
980 | },
981 | {
982 | "output_type": "stream",
983 | "text": [
984 | " 16%|█▌ | 3/19 [00:06<00:32, 2.01s/it]"
985 | ],
986 | "name": "stderr"
987 | },
988 | {
989 | "output_type": "stream",
990 | "text": [
991 | "Test accuracy on defocus_blur_5: 71.5499997138977%\n",
992 | "Processing elastic_5\n"
993 | ],
994 | "name": "stdout"
995 | },
996 | {
997 | "output_type": "stream",
998 | "text": [
999 | " 21%|██ | 4/19 [00:08<00:28, 1.91s/it]"
1000 | ],
1001 | "name": "stderr"
1002 | },
1003 | {
1004 | "output_type": "stream",
1005 | "text": [
1006 | "Test accuracy on elastic_5: 74.76999759674072%\n",
1007 | "Processing fog_5\n"
1008 | ],
1009 | "name": "stdout"
1010 | },
1011 | {
1012 | "output_type": "stream",
1013 | "text": [
1014 | " 26%|██▋ | 5/19 [00:09<00:25, 1.85s/it]"
1015 | ],
1016 | "name": "stderr"
1017 | },
1018 | {
1019 | "output_type": "stream",
1020 | "text": [
1021 | "Test accuracy on fog_5: 47.96999990940094%\n",
1022 | "Processing frost_5\n"
1023 | ],
1024 | "name": "stdout"
1025 | },
1026 | {
1027 | "output_type": "stream",
1028 | "text": [
1029 | " 32%|███▏ | 6/19 [00:11<00:23, 1.81s/it]"
1030 | ],
1031 | "name": "stderr"
1032 | },
1033 | {
1034 | "output_type": "stream",
1035 | "text": [
1036 | "Test accuracy on frost_5: 61.29999756813049%\n",
1037 | "Processing frosted_glass_blur_5\n"
1038 | ],
1039 | "name": "stdout"
1040 | },
1041 | {
1042 | "output_type": "stream",
1043 | "text": [
1044 | " 37%|███▋ | 7/19 [00:13<00:21, 1.79s/it]"
1045 | ],
1046 | "name": "stderr"
1047 | },
1048 | {
1049 | "output_type": "stream",
1050 | "text": [
1051 | "Test accuracy on frosted_glass_blur_5: 61.41999959945679%\n",
1052 | "Processing gaussian_blur_5\n"
1053 | ],
1054 | "name": "stdout"
1055 | },
1056 | {
1057 | "output_type": "stream",
1058 | "text": [
1059 | " 42%|████▏ | 8/19 [00:15<00:19, 1.78s/it]"
1060 | ],
1061 | "name": "stderr"
1062 | },
1063 | {
1064 | "output_type": "stream",
1065 | "text": [
1066 | "Test accuracy on gaussian_blur_5: 66.03000164031982%\n",
1067 | "Processing gaussian_noise_5\n"
1068 | ],
1069 | "name": "stdout"
1070 | },
1071 | {
1072 | "output_type": "stream",
1073 | "text": [
1074 | " 47%|████▋ | 9/19 [00:16<00:17, 1.77s/it]"
1075 | ],
1076 | "name": "stderr"
1077 | },
1078 | {
1079 | "output_type": "stream",
1080 | "text": [
1081 | "Test accuracy on gaussian_noise_5: 45.899999141693115%\n",
1082 | "Processing impulse_noise_5\n"
1083 | ],
1084 | "name": "stdout"
1085 | },
1086 | {
1087 | "output_type": "stream",
1088 | "text": [
1089 | " 53%|█████▎ | 10/19 [00:18<00:15, 1.76s/it]"
1090 | ],
1091 | "name": "stderr"
1092 | },
1093 | {
1094 | "output_type": "stream",
1095 | "text": [
1096 | "Test accuracy on impulse_noise_5: 30.320000648498535%\n",
1097 | "Processing jpeg_compression_5\n"
1098 | ],
1099 | "name": "stdout"
1100 | },
1101 | {
1102 | "output_type": "stream",
1103 | "text": [
1104 | " 58%|█████▊ | 11/19 [00:20<00:14, 1.76s/it]"
1105 | ],
1106 | "name": "stderr"
1107 | },
1108 | {
1109 | "output_type": "stream",
1110 | "text": [
1111 | "Test accuracy on jpeg_compression_5: 78.4600019454956%\n",
1112 | "Processing motion_blur_5\n"
1113 | ],
1114 | "name": "stdout"
1115 | },
1116 | {
1117 | "output_type": "stream",
1118 | "text": [
1119 | " 63%|██████▎ | 12/19 [00:22<00:12, 1.76s/it]"
1120 | ],
1121 | "name": "stderr"
1122 | },
1123 | {
1124 | "output_type": "stream",
1125 | "text": [
1126 | "Test accuracy on motion_blur_5: 64.19000029563904%\n",
1127 | "Processing pixelate_5\n"
1128 | ],
1129 | "name": "stdout"
1130 | },
1131 | {
1132 | "output_type": "stream",
1133 | "text": [
1134 | " 68%|██████▊ | 13/19 [00:23<00:10, 1.75s/it]"
1135 | ],
1136 | "name": "stderr"
1137 | },
1138 | {
1139 | "output_type": "stream",
1140 | "text": [
1141 | "Test accuracy on pixelate_5: 72.26999998092651%\n",
1142 | "Processing saturate_5\n"
1143 | ],
1144 | "name": "stdout"
1145 | },
1146 | {
1147 | "output_type": "stream",
1148 | "text": [
1149 | " 74%|███████▎ | 14/19 [00:25<00:08, 1.75s/it]"
1150 | ],
1151 | "name": "stderr"
1152 | },
1153 | {
1154 | "output_type": "stream",
1155 | "text": [
1156 | "Test accuracy on saturate_5: 66.04999899864197%\n",
1157 | "Processing shot_noise_5\n"
1158 | ],
1159 | "name": "stdout"
1160 | },
1161 | {
1162 | "output_type": "stream",
1163 | "text": [
1164 | " 79%|███████▉ | 15/19 [00:27<00:07, 1.75s/it]"
1165 | ],
1166 | "name": "stderr"
1167 | },
1168 | {
1169 | "output_type": "stream",
1170 | "text": [
1171 | "Test accuracy on shot_noise_5: 49.36999976634979%\n",
1172 | "Processing snow_5\n"
1173 | ],
1174 | "name": "stdout"
1175 | },
1176 | {
1177 | "output_type": "stream",
1178 | "text": [
1179 | " 84%|████████▍ | 16/19 [00:29<00:05, 1.75s/it]"
1180 | ],
1181 | "name": "stderr"
1182 | },
1183 | {
1184 | "output_type": "stream",
1185 | "text": [
1186 | "Test accuracy on snow_5: 63.60999941825867%\n",
1187 | "Processing spatter_5\n"
1188 | ],
1189 | "name": "stdout"
1190 | },
1191 | {
1192 | "output_type": "stream",
1193 | "text": [
1194 | " 89%|████████▉ | 17/19 [00:30<00:03, 1.75s/it]"
1195 | ],
1196 | "name": "stderr"
1197 | },
1198 | {
1199 | "output_type": "stream",
1200 | "text": [
1201 | "Test accuracy on spatter_5: 63.8700008392334%\n",
1202 | "Processing speckle_noise_5\n"
1203 | ],
1204 | "name": "stdout"
1205 | },
1206 | {
1207 | "output_type": "stream",
1208 | "text": [
1209 | " 95%|█████████▍| 18/19 [00:32<00:01, 1.75s/it]"
1210 | ],
1211 | "name": "stderr"
1212 | },
1213 | {
1214 | "output_type": "stream",
1215 | "text": [
1216 | "Test accuracy on speckle_noise_5: 49.75000023841858%\n",
1217 | "Processing zoom_blur_5\n"
1218 | ],
1219 | "name": "stdout"
1220 | },
1221 | {
1222 | "output_type": "stream",
1223 | "text": [
1224 | "100%|██████████| 19/19 [00:34<00:00, 1.81s/it]"
1225 | ],
1226 | "name": "stderr"
1227 | },
1228 | {
1229 | "output_type": "stream",
1230 | "text": [
1231 | "Test accuracy on zoom_blur_5: 78.32000255584717%\n",
1232 | "Mean Top-1 Accuracy: 59.89315792133934%\n"
1233 | ],
1234 | "name": "stdout"
1235 | },
1236 | {
1237 | "output_type": "stream",
1238 | "text": [
1239 | "\n"
1240 | ],
1241 | "name": "stderr"
1242 | }
1243 | ]
1244 | },
1245 | {
1246 | "cell_type": "code",
1247 | "metadata": {
1248 | "id": "legendary-berry",
1249 | "outputId": "ea17c1ac-412a-4843-ccc9-4f742837db0b"
1250 | },
1251 | "source": [
1252 | "# Evaluate the corresponding student model\n",
1253 | "student_noisy_ma = get_training_model()\n",
1254 | "student_noisy_ma.load_weights(\"student_noisy_ma.h5\")\n",
1255 | "student_noisy_ma.compile(loss=\"sparse_categorical_crossentropy\",\n",
1256 | " metrics=[\"accuracy\"])\n",
1257 | "acc_dict, mean_top_1 = evaluate_model(student_noisy_ma)\n",
1258 | "print(f\"Mean Top-1 Accuracy: {mean_top_1}%\")"
1259 | ],
1260 | "id": "legendary-berry",
1261 | "execution_count": null,
1262 | "outputs": [
1263 | {
1264 | "output_type": "stream",
1265 | "text": [
1266 | " 0%| | 0/19 [00:00<?, ?it/s]"
1267 | ],
1268 | "name": "stderr"
1269 | },
1270 | {
1271 | "output_type": "stream",
1272 | "text": [
1273 | "Processing brightness_5\n"
1274 | ],
1275 | "name": "stdout"
1276 | },
1277 | {
1278 | "output_type": "stream",
1279 | "text": [
1280 | " 5%|▌ | 1/19 [00:02<00:51, 2.84s/it]"
1281 | ],
1282 | "name": "stderr"
1283 | },
1284 | {
1285 | "output_type": "stream",
1286 | "text": [
1287 | "Test accuracy on brightness_5: 81.80999755859375%\n",
1288 | "Processing contrast_5\n"
1289 | ],
1290 | "name": "stdout"
1291 | },
1292 | {
1293 | "output_type": "stream",
1294 | "text": [
1295 | " 11%|█ | 2/19 [00:04<00:37, 2.20s/it]"
1296 | ],
1297 | "name": "stderr"
1298 | },
1299 | {
1300 | "output_type": "stream",
1301 | "text": [
1302 | "Test accuracy on contrast_5: 28.65999937057495%\n",
1303 | "Processing defocus_blur_5\n"
1304 | ],
1305 | "name": "stdout"
1306 | },
1307 | {
1308 | "output_type": "stream",
1309 | "text": [
1310 | " 16%|█▌ | 3/19 [00:06<00:31, 1.99s/it]"
1311 | ],
1312 | "name": "stderr"
1313 | },
1314 | {
1315 | "output_type": "stream",
1316 | "text": [
1317 | "Test accuracy on defocus_blur_5: 70.93999981880188%\n",
1318 | "Processing elastic_5\n"
1319 | ],
1320 | "name": "stdout"
1321 | },
1322 | {
1323 | "output_type": "stream",
1324 | "text": [
1325 | " 21%|██ | 4/19 [00:08<00:28, 1.90s/it]"
1326 | ],
1327 | "name": "stderr"
1328 | },
1329 | {
1330 | "output_type": "stream",
1331 | "text": [
1332 | "Test accuracy on elastic_5: 72.43000268936157%\n",
1333 | "Processing fog_5\n"
1334 | ],
1335 | "name": "stdout"
1336 | },
1337 | {
1338 | "output_type": "stream",
1339 | "text": [
1340 | " 26%|██▋ | 5/19 [00:09<00:25, 1.85s/it]"
1341 | ],
1342 | "name": "stderr"
1343 | },
1344 | {
1345 | "output_type": "stream",
1346 | "text": [
1347 | "Test accuracy on fog_5: 49.43999946117401%\n",
1348 | "Processing frost_5\n"
1349 | ],
1350 | "name": "stdout"
1351 | },
1352 | {
1353 | "output_type": "stream",
1354 | "text": [
1355 | " 32%|███▏ | 6/19 [00:11<00:23, 1.81s/it]"
1356 | ],
1357 | "name": "stderr"
1358 | },
1359 | {
1360 | "output_type": "stream",
1361 | "text": [
1362 | "Test accuracy on frost_5: 60.009998083114624%\n",
1363 | "Processing frosted_glass_blur_5\n"
1364 | ],
1365 | "name": "stdout"
1366 | },
1367 | {
1368 | "output_type": "stream",
1369 | "text": [
1370 | " 37%|███▋ | 7/19 [00:13<00:21, 1.79s/it]"
1371 | ],
1372 | "name": "stderr"
1373 | },
1374 | {
1375 | "output_type": "stream",
1376 | "text": [
1377 | "Test accuracy on frosted_glass_blur_5: 58.139997720718384%\n",
1378 | "Processing gaussian_blur_5\n"
1379 | ],
1380 | "name": "stdout"
1381 | },
1382 | {
1383 | "output_type": "stream",
1384 | "text": [
1385 | " 42%|████▏ | 8/19 [00:15<00:19, 1.78s/it]"
1386 | ],
1387 | "name": "stderr"
1388 | },
1389 | {
1390 | "output_type": "stream",
1391 | "text": [
1392 | "Test accuracy on gaussian_blur_5: 65.90999960899353%\n",
1393 | "Processing gaussian_noise_5\n"
1394 | ],
1395 | "name": "stdout"
1396 | },
1397 | {
1398 | "output_type": "stream",
1399 | "text": [
1400 | " 47%|████▋ | 9/19 [00:16<00:17, 1.77s/it]"
1401 | ],
1402 | "name": "stderr"
1403 | },
1404 | {
1405 | "output_type": "stream",
1406 | "text": [
1407 | "Test accuracy on gaussian_noise_5: 47.00999855995178%\n",
1408 | "Processing impulse_noise_5\n"
1409 | ],
1410 | "name": "stdout"
1411 | },
1412 | {
1413 | "output_type": "stream",
1414 | "text": [
1415 | " 53%|█████▎ | 10/19 [00:18<00:15, 1.76s/it]"
1416 | ],
1417 | "name": "stderr"
1418 | },
1419 | {
1420 | "output_type": "stream",
1421 | "text": [
1422 | "Test accuracy on impulse_noise_5: 29.399999976158142%\n",
1423 | "Processing jpeg_compression_5\n"
1424 | ],
1425 | "name": "stdout"
1426 | },
1427 | {
1428 | "output_type": "stream",
1429 | "text": [
1430 | " 58%|█████▊ | 11/19 [00:20<00:14, 1.76s/it]"
1431 | ],
1432 | "name": "stderr"
1433 | },
1434 | {
1435 | "output_type": "stream",
1436 | "text": [
1437 | "Test accuracy on jpeg_compression_5: 76.10999941825867%\n",
1438 | "Processing motion_blur_5\n"
1439 | ],
1440 | "name": "stdout"
1441 | },
1442 | {
1443 | "output_type": "stream",
1444 | "text": [
1445 | " 63%|██████▎ | 12/19 [00:22<00:12, 1.76s/it]"
1446 | ],
1447 | "name": "stderr"
1448 | },
1449 | {
1450 | "output_type": "stream",
1451 | "text": [
1452 | "Test accuracy on motion_blur_5: 65.93000292778015%\n",
1453 | "Processing pixelate_5\n"
1454 | ],
1455 | "name": "stdout"
1456 | },
1457 | {
1458 | "output_type": "stream",
1459 | "text": [
1460 | " 68%|██████▊ | 13/19 [00:23<00:10, 1.75s/it]"
1461 | ],
1462 | "name": "stderr"
1463 | },
1464 | {
1465 | "output_type": "stream",
1466 | "text": [
1467 | "Test accuracy on pixelate_5: 71.3699996471405%\n",
1468 | "Processing saturate_5\n"
1469 | ],
1470 | "name": "stdout"
1471 | },
1472 | {
1473 | "output_type": "stream",
1474 | "text": [
1475 | " 74%|███████▎ | 14/19 [00:25<00:08, 1.75s/it]"
1476 | ],
1477 | "name": "stderr"
1478 | },
1479 | {
1480 | "output_type": "stream",
1481 | "text": [
1482 | "Test accuracy on saturate_5: 80.47000169754028%\n",
1483 | "Processing shot_noise_5\n"
1484 | ],
1485 | "name": "stdout"
1486 | },
1487 | {
1488 | "output_type": "stream",
1489 | "text": [
1490 | " 79%|███████▉ | 15/19 [00:27<00:06, 1.75s/it]"
1491 | ],
1492 | "name": "stderr"
1493 | },
1494 | {
1495 | "output_type": "stream",
1496 | "text": [
1497 | "Test accuracy on shot_noise_5: 51.42999887466431%\n",
1498 | "Processing snow_5\n"
1499 | ],
1500 | "name": "stdout"
1501 | },
1502 | {
1503 | "output_type": "stream",
1504 | "text": [
1505 | " 84%|████████▍ | 16/19 [00:29<00:05, 1.74s/it]"
1506 | ],
1507 | "name": "stderr"
1508 | },
1509 | {
1510 | "output_type": "stream",
1511 | "text": [
1512 | "Test accuracy on snow_5: 66.38000011444092%\n",
1513 | "Processing spatter_5\n"
1514 | ],
1515 | "name": "stdout"
1516 | },
1517 | {
1518 | "output_type": "stream",
1519 | "text": [
1520 | " 89%|████████▉ | 17/19 [00:30<00:03, 1.75s/it]"
1521 | ],
1522 | "name": "stderr"
1523 | },
1524 | {
1525 | "output_type": "stream",
1526 | "text": [
1527 | "Test accuracy on spatter_5: 70.24000287055969%\n",
1528 | "Processing speckle_noise_5\n"
1529 | ],
1530 | "name": "stdout"
1531 | },
1532 | {
1533 | "output_type": "stream",
1534 | "text": [
1535 | " 95%|█████████▍| 18/19 [00:32<00:01, 1.75s/it]"
1536 | ],
1537 | "name": "stderr"
1538 | },
1539 | {
1540 | "output_type": "stream",
1541 | "text": [
1542 | "Test accuracy on speckle_noise_5: 52.27000117301941%\n",
1543 | "Processing zoom_blur_5\n"
1544 | ],
1545 | "name": "stdout"
1546 | },
1547 | {
1548 | "output_type": "stream",
1549 | "text": [
1550 | "100%|██████████| 19/19 [00:34<00:00, 1.81s/it]"
1551 | ],
1552 | "name": "stderr"
1553 | },
1554 | {
1555 | "output_type": "stream",
1556 | "text": [
1557 | "Test accuracy on zoom_blur_5: 77.46000289916992%\n",
1558 | "Mean Top-1 Accuracy: 61.863684340527186%\n"
1559 | ],
1560 | "name": "stdout"
1561 | },
1562 | {
1563 | "output_type": "stream",
1564 | "text": [
1565 | "\n"
1566 | ],
1567 | "name": "stderr"
1568 | }
1569 | ]
1570 | },
1571 | {
1572 | "cell_type": "markdown",
1573 | "metadata": {
1574 | "id": "regional-composer"
1575 | },
1576 | "source": [
1577 | "### Regular"
1578 | ],
1579 | "id": "regional-composer"
1580 | },
1581 | {
1582 | "cell_type": "code",
1583 | "metadata": {
1584 | "id": "valued-dominican",
1585 | "outputId": "578da18e-4baa-4787-df05-36942085e974"
1586 | },
1587 | "source": [
1588 | "# Evaluate teacher model\n",
1589 | "teacher_model = get_training_model()\n",
1590 | "teacher_model.load_weights(\"teacher_model.h5\")\n",
1591 | "teacher_model.compile(loss=\"sparse_categorical_crossentropy\",\n",
1592 | " metrics=[\"accuracy\"])\n",
1593 | "acc_dict, mean_top_1 = evaluate_model(teacher_model)\n",
1594 | "print(f\"Mean Top-1 Accuracy: {mean_top_1}%\")"
1595 | ],
1596 | "id": "valued-dominican",
1597 | "execution_count": null,
1598 | "outputs": [
1599 | {
1600 | "output_type": "stream",
1601 | "text": [
1602 | " 0%| | 0/19 [00:00<?, ?it/s]"
1603 | ],
1604 | "name": "stderr"
1605 | },
1606 | {
1607 | "output_type": "stream",
1608 | "text": [
1609 | "Processing brightness_5\n"
1610 | ],
1611 | "name": "stdout"
1612 | },
1613 | {
1614 | "output_type": "stream",
1615 | "text": [
1616 | " 5%|▌ | 1/19 [00:03<00:55, 3.08s/it]"
1617 | ],
1618 | "name": "stderr"
1619 | },
1620 | {
1621 | "output_type": "stream",
1622 | "text": [
1623 | "Test accuracy on brightness_5: 75.52000284194946%\n",
1624 | "Processing contrast_5\n"
1625 | ],
1626 | "name": "stdout"
1627 | },
1628 | {
1629 | "output_type": "stream",
1630 | "text": [
1631 | " 11%|█ | 2/19 [00:04<00:38, 2.29s/it]"
1632 | ],
1633 | "name": "stderr"
1634 | },
1635 | {
1636 | "output_type": "stream",
1637 | "text": [
1638 | "Test accuracy on contrast_5: 24.3599995970726%\n",
1639 | "Processing defocus_blur_5\n"
1640 | ],
1641 | "name": "stdout"
1642 | },
1643 | {
1644 | "output_type": "stream",
1645 | "text": [
1646 | " 16%|█▌ | 3/19 [00:06<00:32, 2.04s/it]"
1647 | ],
1648 | "name": "stderr"
1649 | },
1650 | {
1651 | "output_type": "stream",
1652 | "text": [
1653 | "Test accuracy on defocus_blur_5: 71.95000052452087%\n",
1654 | "Processing elastic_5\n"
1655 | ],
1656 | "name": "stdout"
1657 | },
1658 | {
1659 | "output_type": "stream",
1660 | "text": [
1661 | " 21%|██ | 4/19 [00:08<00:28, 1.93s/it]"
1662 | ],
1663 | "name": "stderr"
1664 | },
1665 | {
1666 | "output_type": "stream",
1667 | "text": [
1668 | "Test accuracy on elastic_5: 74.55999851226807%\n",
1669 | "Processing fog_5\n"
1670 | ],
1671 | "name": "stdout"
1672 | },
1673 | {
1674 | "output_type": "stream",
1675 | "text": [
1676 | " 26%|██▋ | 5/19 [00:10<00:26, 1.86s/it]"
1677 | ],
1678 | "name": "stderr"
1679 | },
1680 | {
1681 | "output_type": "stream",
1682 | "text": [
1683 | "Test accuracy on fog_5: 45.71999907493591%\n",
1684 | "Processing frost_5\n"
1685 | ],
1686 | "name": "stdout"
1687 | },
1688 | {
1689 | "output_type": "stream",
1690 | "text": [
1691 | " 32%|███▏ | 6/19 [00:11<00:23, 1.82s/it]"
1692 | ],
1693 | "name": "stderr"
1694 | },
1695 | {
1696 | "output_type": "stream",
1697 | "text": [
1698 | "Test accuracy on frost_5: 62.48999834060669%\n",
1699 | "Processing frosted_glass_blur_5\n"
1700 | ],
1701 | "name": "stdout"
1702 | },
1703 | {
1704 | "output_type": "stream",
1705 | "text": [
1706 | " 37%|███▋ | 7/19 [00:13<00:21, 1.80s/it]"
1707 | ],
1708 | "name": "stderr"
1709 | },
1710 | {
1711 | "output_type": "stream",
1712 | "text": [
1713 | "Test accuracy on frosted_glass_blur_5: 63.33000063896179%\n",
1714 | "Processing gaussian_blur_5\n"
1715 | ],
1716 | "name": "stdout"
1717 | },
1718 | {
1719 | "output_type": "stream",
1720 | "text": [
1721 | " 42%|████▏ | 8/19 [00:15<00:19, 1.78s/it]"
1722 | ],
1723 | "name": "stderr"
1724 | },
1725 | {
1726 | "output_type": "stream",
1727 | "text": [
1728 | "Test accuracy on gaussian_blur_5: 66.90000295639038%\n",
1729 | "Processing gaussian_noise_5\n"
1730 | ],
1731 | "name": "stdout"
1732 | },
1733 | {
1734 | "output_type": "stream",
1735 | "text": [
1736 | " 47%|████▋ | 9/19 [00:17<00:17, 1.78s/it]"
1737 | ],
1738 | "name": "stderr"
1739 | },
1740 | {
1741 | "output_type": "stream",
1742 | "text": [
1743 | "Test accuracy on gaussian_noise_5: 44.62999999523163%\n",
1744 | "Processing impulse_noise_5\n"
1745 | ],
1746 | "name": "stdout"
1747 | },
1748 | {
1749 | "output_type": "stream",
1750 | "text": [
1751 | " 53%|█████▎ | 10/19 [00:18<00:15, 1.77s/it]"
1752 | ],
1753 | "name": "stderr"
1754 | },
1755 | {
1756 | "output_type": "stream",
1757 | "text": [
1758 | "Test accuracy on impulse_noise_5: 27.889999747276306%\n",
1759 | "Processing jpeg_compression_5\n"
1760 | ],
1761 | "name": "stdout"
1762 | },
1763 | {
1764 | "output_type": "stream",
1765 | "text": [
1766 | " 58%|█████▊ | 11/19 [00:20<00:14, 1.76s/it]"
1767 | ],
1768 | "name": "stderr"
1769 | },
1770 | {
1771 | "output_type": "stream",
1772 | "text": [
1773 | "Test accuracy on jpeg_compression_5: 78.46999764442444%\n",
1774 | "Processing motion_blur_5\n"
1775 | ],
1776 | "name": "stdout"
1777 | },
1778 | {
1779 | "output_type": "stream",
1780 | "text": [
1781 | " 63%|██████▎ | 12/19 [00:22<00:12, 1.76s/it]"
1782 | ],
1783 | "name": "stderr"
1784 | },
1785 | {
1786 | "output_type": "stream",
1787 | "text": [
1788 | "Test accuracy on motion_blur_5: 66.21999740600586%\n",
1789 | "Processing pixelate_5\n"
1790 | ],
1791 | "name": "stdout"
1792 | },
1793 | {
1794 | "output_type": "stream",
1795 | "text": [
1796 | " 68%|██████▊ | 13/19 [00:24<00:10, 1.75s/it]"
1797 | ],
1798 | "name": "stderr"
1799 | },
1800 | {
1801 | "output_type": "stream",
1802 | "text": [
1803 | "Test accuracy on pixelate_5: 75.98999738693237%\n",
1804 | "Processing saturate_5\n"
1805 | ],
1806 | "name": "stdout"
1807 | },
1808 | {
1809 | "output_type": "stream",
1810 | "text": [
1811 | " 74%|███████▎ | 14/19 [00:25<00:08, 1.75s/it]"
1812 | ],
1813 | "name": "stderr"
1814 | },
1815 | {
1816 | "output_type": "stream",
1817 | "text": [
1818 | "Test accuracy on saturate_5: 67.47999787330627%\n",
1819 | "Processing shot_noise_5\n"
1820 | ],
1821 | "name": "stdout"
1822 | },
1823 | {
1824 | "output_type": "stream",
1825 | "text": [
1826 | " 79%|███████▉ | 15/19 [00:27<00:06, 1.75s/it]"
1827 | ],
1828 | "name": "stderr"
1829 | },
1830 | {
1831 | "output_type": "stream",
1832 | "text": [
1833 | "Test accuracy on shot_noise_5: 48.44000041484833%\n",
1834 | "Processing snow_5\n"
1835 | ],
1836 | "name": "stdout"
1837 | },
1838 | {
1839 | "output_type": "stream",
1840 | "text": [
1841 | " 84%|████████▍ | 16/19 [00:29<00:05, 1.75s/it]"
1842 | ],
1843 | "name": "stderr"
1844 | },
1845 | {
1846 | "output_type": "stream",
1847 | "text": [
1848 | "Test accuracy on snow_5: 65.85999727249146%\n",
1849 | "Processing spatter_5\n"
1850 | ],
1851 | "name": "stdout"
1852 | },
1853 | {
1854 | "output_type": "stream",
1855 | "text": [
1856 | " 89%|████████▉ | 17/19 [00:31<00:03, 1.74s/it]"
1857 | ],
1858 | "name": "stderr"
1859 | },
1860 | {
1861 | "output_type": "stream",
1862 | "text": [
1863 | "Test accuracy on spatter_5: 63.81999850273132%\n",
1864 | "Processing speckle_noise_5\n"
1865 | ],
1866 | "name": "stdout"
1867 | },
1868 | {
1869 | "output_type": "stream",
1870 | "text": [
1871 | " 95%|█████████▍| 18/19 [00:32<00:01, 1.74s/it]"
1872 | ],
1873 | "name": "stderr"
1874 | },
1875 | {
1876 | "output_type": "stream",
1877 | "text": [
1878 | "Test accuracy on speckle_noise_5: 49.059998989105225%\n",
1879 | "Processing zoom_blur_5\n"
1880 | ],
1881 | "name": "stdout"
1882 | },
1883 | {
1884 | "output_type": "stream",
1885 | "text": [
1886 | "100%|██████████| 19/19 [00:34<00:00, 1.82s/it]"
1887 | ],
1888 | "name": "stderr"
1889 | },
1890 | {
1891 | "output_type": "stream",
1892 | "text": [
1893 | "Test accuracy on zoom_blur_5: 77.02000141143799%\n",
1894 | "Mean Top-1 Accuracy: 60.51105205949984%\n"
1895 | ],
1896 | "name": "stdout"
1897 | },
1898 | {
1899 | "output_type": "stream",
1900 | "text": [
1901 | "\n"
1902 | ],
1903 | "name": "stderr"
1904 | }
1905 | ]
1906 | },
1907 | {
1908 | "cell_type": "code",
1909 | "metadata": {
1910 | "id": "hydraulic-finding",
1911 | "outputId": "4b224c7f-5576-40b6-e3ad-cf26691793c8"
1912 | },
1913 | "source": [
1914 | "# Evaluate the corresponding student model\n",
1915 | "student_noisy = get_training_model()\n",
1916 | "student_noisy.load_weights(\"student_noisy.h5\")\n",
1917 | "student_noisy.compile(loss=\"sparse_categorical_crossentropy\",\n",
1918 | " metrics=[\"accuracy\"])\n",
1919 | "acc_dict, mean_top_1 = evaluate_model(student_noisy)\n",
1920 | "print(f\"Mean Top-1 Accuracy: {mean_top_1}%\")"
1921 | ],
1922 | "id": "hydraulic-finding",
1923 | "execution_count": null,
1924 | "outputs": [
1925 | {
1926 | "output_type": "stream",
1927 | "text": [
1928 | " 0%| | 0/19 [00:00<?, ?it/s]"
1929 | ],
1930 | "name": "stderr"
1931 | },
1932 | {
1933 | "output_type": "stream",
1934 | "text": [
1935 | "Processing brightness_5\n"
1936 | ],
1937 | "name": "stdout"
1938 | },
1939 | {
1940 | "output_type": "stream",
1941 | "text": [
1942 | " 5%|▌ | 1/19 [00:02<00:52, 2.92s/it]"
1943 | ],
1944 | "name": "stderr"
1945 | },
1946 | {
1947 | "output_type": "stream",
1948 | "text": [
1949 | "Test accuracy on brightness_5: 79.80999946594238%\n",
1950 | "Processing contrast_5\n"
1951 | ],
1952 | "name": "stdout"
1953 | },
1954 | {
1955 | "output_type": "stream",
1956 | "text": [
1957 | " 11%|█ | 2/19 [00:04<00:37, 2.23s/it]"
1958 | ],
1959 | "name": "stderr"
1960 | },
1961 | {
1962 | "output_type": "stream",
1963 | "text": [
1964 | "Test accuracy on contrast_5: 20.76999992132187%\n",
1965 | "Processing defocus_blur_5\n"
1966 | ],
1967 | "name": "stdout"
1968 | },
1969 | {
1970 | "output_type": "stream",
1971 | "text": [
1972 | " 16%|█▌ | 3/19 [00:06<00:32, 2.01s/it]"
1973 | ],
1974 | "name": "stderr"
1975 | },
1976 | {
1977 | "output_type": "stream",
1978 | "text": [
1979 | "Test accuracy on defocus_blur_5: 68.98000240325928%\n",
1980 | "Processing elastic_5\n"
1981 | ],
1982 | "name": "stdout"
1983 | },
1984 | {
1985 | "output_type": "stream",
1986 | "text": [
1987 | " 21%|██ | 4/19 [00:08<00:28, 1.90s/it]"
1988 | ],
1989 | "name": "stderr"
1990 | },
1991 | {
1992 | "output_type": "stream",
1993 | "text": [
1994 | "Test accuracy on elastic_5: 71.53000235557556%\n",
1995 | "Processing fog_5\n"
1996 | ],
1997 | "name": "stdout"
1998 | },
1999 | {
2000 | "output_type": "stream",
2001 | "text": [
2002 | " 26%|██▋ | 5/19 [00:09<00:25, 1.84s/it]"
2003 | ],
2004 | "name": "stderr"
2005 | },
2006 | {
2007 | "output_type": "stream",
2008 | "text": [
2009 | "Test accuracy on fog_5: 46.140000224113464%\n",
2010 | "Processing frost_5\n"
2011 | ],
2012 | "name": "stdout"
2013 | },
2014 | {
2015 | "output_type": "stream",
2016 | "text": [
2017 | " 32%|███▏ | 6/19 [00:11<00:23, 1.81s/it]"
2018 | ],
2019 | "name": "stderr"
2020 | },
2021 | {
2022 | "output_type": "stream",
2023 | "text": [
2024 | "Test accuracy on frost_5: 57.56999850273132%\n",
2025 | "Processing frosted_glass_blur_5\n"
2026 | ],
2027 | "name": "stdout"
2028 | },
2029 | {
2030 | "output_type": "stream",
2031 | "text": [
2032 | " 37%|███▋ | 7/19 [00:13<00:21, 1.79s/it]"
2033 | ],
2034 | "name": "stderr"
2035 | },
2036 | {
2037 | "output_type": "stream",
2038 | "text": [
2039 | "Test accuracy on frosted_glass_blur_5: 57.03999996185303%\n",
2040 | "Processing gaussian_blur_5\n"
2041 | ],
2042 | "name": "stdout"
2043 | },
2044 | {
2045 | "output_type": "stream",
2046 | "text": [
2047 | " 42%|████▏ | 8/19 [00:15<00:19, 1.78s/it]"
2048 | ],
2049 | "name": "stderr"
2050 | },
2051 | {
2052 | "output_type": "stream",
2053 | "text": [
2054 | "Test accuracy on gaussian_blur_5: 63.429999351501465%\n",
2055 | "Processing gaussian_noise_5\n"
2056 | ],
2057 | "name": "stdout"
2058 | },
2059 | {
2060 | "output_type": "stream",
2061 | "text": [
2062 | " 47%|████▋ | 9/19 [00:16<00:17, 1.77s/it]"
2063 | ],
2064 | "name": "stderr"
2065 | },
2066 | {
2067 | "output_type": "stream",
2068 | "text": [
2069 | "Test accuracy on gaussian_noise_5: 44.06000077724457%\n",
2070 | "Processing impulse_noise_5\n"
2071 | ],
2072 | "name": "stdout"
2073 | },
2074 | {
2075 | "output_type": "stream",
2076 | "text": [
2077 | " 53%|█████▎ | 10/19 [00:18<00:15, 1.76s/it]"
2078 | ],
2079 | "name": "stderr"
2080 | },
2081 | {
2082 | "output_type": "stream",
2083 | "text": [
2084 | "Test accuracy on impulse_noise_5: 24.68000054359436%\n",
2085 | "Processing jpeg_compression_5\n"
2086 | ],
2087 | "name": "stdout"
2088 | },
2089 | {
2090 | "output_type": "stream",
2091 | "text": [
2092 | " 58%|█████▊ | 11/19 [00:20<00:14, 1.76s/it]"
2093 | ],
2094 | "name": "stderr"
2095 | },
2096 | {
2097 | "output_type": "stream",
2098 | "text": [
2099 | "Test accuracy on jpeg_compression_5: 74.55000281333923%\n",
2100 | "Processing motion_blur_5\n"
2101 | ],
2102 | "name": "stdout"
2103 | },
2104 | {
2105 | "output_type": "stream",
2106 | "text": [
2107 | " 63%|██████▎ | 12/19 [00:22<00:12, 1.75s/it]"
2108 | ],
2109 | "name": "stderr"
2110 | },
2111 | {
2112 | "output_type": "stream",
2113 | "text": [
2114 | "Test accuracy on motion_blur_5: 62.459999322891235%\n",
2115 | "Processing pixelate_5\n"
2116 | ],
2117 | "name": "stdout"
2118 | },
2119 | {
2120 | "output_type": "stream",
2121 | "text": [
2122 | " 68%|██████▊ | 13/19 [00:23<00:10, 1.75s/it]"
2123 | ],
2124 | "name": "stderr"
2125 | },
2126 | {
2127 | "output_type": "stream",
2128 | "text": [
2129 | "Test accuracy on pixelate_5: 70.16000151634216%\n",
2130 | "Processing saturate_5\n"
2131 | ],
2132 | "name": "stdout"
2133 | },
2134 | {
2135 | "output_type": "stream",
2136 | "text": [
2137 | " 74%|███████▎ | 14/19 [00:25<00:08, 1.75s/it]"
2138 | ],
2139 | "name": "stderr"
2140 | },
2141 | {
2142 | "output_type": "stream",
2143 | "text": [
2144 | "Test accuracy on saturate_5: 77.31000185012817%\n",
2145 | "Processing shot_noise_5\n"
2146 | ],
2147 | "name": "stdout"
2148 | },
2149 | {
2150 | "output_type": "stream",
2151 | "text": [
2152 | " 79%|███████▉ | 15/19 [00:27<00:07, 1.75s/it]"
2153 | ],
2154 | "name": "stderr"
2155 | },
2156 | {
2157 | "output_type": "stream",
2158 | "text": [
2159 | "Test accuracy on shot_noise_5: 48.60000014305115%\n",
2160 | "Processing snow_5\n"
2161 | ],
2162 | "name": "stdout"
2163 | },
2164 | {
2165 | "output_type": "stream",
2166 | "text": [
2167 | " 84%|████████▍ | 16/19 [00:29<00:05, 1.75s/it]"
2168 | ],
2169 | "name": "stderr"
2170 | },
2171 | {
2172 | "output_type": "stream",
2173 | "text": [
2174 | "Test accuracy on snow_5: 63.2099986076355%\n",
2175 | "Processing spatter_5\n"
2176 | ],
2177 | "name": "stdout"
2178 | },
2179 | {
2180 | "output_type": "stream",
2181 | "text": [
2182 | " 89%|████████▉ | 17/19 [00:30<00:03, 1.75s/it]"
2183 | ],
2184 | "name": "stderr"
2185 | },
2186 | {
2187 | "output_type": "stream",
2188 | "text": [
2189 | "Test accuracy on spatter_5: 65.49000144004822%\n",
2190 | "Processing speckle_noise_5\n"
2191 | ],
2192 | "name": "stdout"
2193 | },
2194 | {
2195 | "output_type": "stream",
2196 | "text": [
2197 | " 95%|█████████▍| 18/19 [00:32<00:01, 1.75s/it]"
2198 | ],
2199 | "name": "stderr"
2200 | },
2201 | {
2202 | "output_type": "stream",
2203 | "text": [
2204 | "Test accuracy on speckle_noise_5: 48.96999895572662%\n",
2205 | "Processing zoom_blur_5\n"
2206 | ],
2207 | "name": "stdout"
2208 | },
2209 | {
2210 | "output_type": "stream",
2211 | "text": [
2212 | "100%|██████████| 19/19 [00:34<00:00, 1.81s/it]"
2213 | ],
2214 | "name": "stderr"
2215 | },
2216 | {
2217 | "output_type": "stream",
2218 | "text": [
2219 | "Test accuracy on zoom_blur_5: 75.85999965667725%\n",
2220 | "Mean Top-1 Accuracy: 58.98000041120931%\n"
2221 | ],
2222 | "name": "stdout"
2223 | },
2224 | {
2225 | "output_type": "stream",
2226 | "text": [
2227 | "\n"
2228 | ],
2229 | "name": "stderr"
2230 | }
2231 | ]
2232 | },
2233 | {
2234 | "cell_type": "markdown",
2235 | "metadata": {
2236 | "id": "urban-sapphire"
2237 | },
2238 | "source": [
2239 | "## Evaluate on CIFAR-10 Test Set"
2240 | ],
2241 | "id": "urban-sapphire"
2242 | },
2243 | {
2244 | "cell_type": "code",
2245 | "metadata": {
2246 | "id": "precious-release"
2247 | },
2248 | "source": [
2249 | "(_, _), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()\n",
2250 | "test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
2251 | "test_ds = prepare_dataset(test_ds)"
2252 | ],
2253 | "id": "precious-release",
2254 | "execution_count": null,
2255 | "outputs": []
2256 | },
2257 | {
2258 | "cell_type": "code",
2259 | "metadata": {
2260 | "id": "broadband-dream",
2261 | "outputId": "cd48ea0c-e40d-41ae-9fa9-3ac681c5580f"
2262 | },
2263 | "source": [
2264 | "# Evaluate teacher model trained with SWA\n",
2265 | "_, test_acc = teacher_model_swa.evaluate(test_ds, verbose=0)\n",
2266 | "print(\"Test accuracy with SWA teacher: {:.2f}%\".format(test_acc * 100))\n",
2267 | "\n",
2268 | "# Evaluate the corresponding student\n",
2269 | "_, test_acc = student_noisy_swa.evaluate(test_ds, verbose=0)\n",
2270 | "print(\"Test accuracy with noisy SWA student: {:.2f}%\".format(test_acc * 100))"
2271 | ],
2272 | "id": "broadband-dream",
2273 | "execution_count": null,
2274 | "outputs": [
2275 | {
2276 | "output_type": "stream",
2277 | "text": [
2278 | "Test accuracy with SWA teacher: 84.82%\n",
2279 | "Test accuracy with noisy SWA student: 85.24%\n"
2280 | ],
2281 | "name": "stdout"
2282 | }
2283 | ]
2284 | },
2285 | {
2286 | "cell_type": "code",
2287 | "metadata": {
2288 | "id": "enormous-repository",
2289 | "outputId": "ae8923c8-7e55-4753-aa9e-2ee075bbef9f"
2290 | },
2291 | "source": [
2292 | "# Evaluate teacher model trained with MA\n",
2293 | "_, test_acc = teacher_model_ma.evaluate(test_ds, verbose=0)\n",
2294 | "print(\"Test accuracy with MA teacher: {:.2f}%\".format(test_acc * 100))\n",
2295 | "\n",
2296 | "# Evaluate the corresponding student\n",
2297 | "_, test_acc = student_noisy_ma.evaluate(test_ds, verbose=0)\n",
2298 | "print(\"Test accuracy with noisy MA student: {:.2f}%\".format(test_acc * 100))"
2299 | ],
2300 | "id": "enormous-repository",
2301 | "execution_count": null,
2302 | "outputs": [
2303 | {
2304 | "output_type": "stream",
2305 | "text": [
2306 | "Test accuracy with MA teacher: 83.88%\n",
2307 | "Test accuracy with noisy MA student: 84.42%\n"
2308 | ],
2309 | "name": "stdout"
2310 | }
2311 | ]
2312 | },
2313 | {
2314 | "cell_type": "code",
2315 | "metadata": {
2316 | "id": "dental-quest",
2317 | "outputId": "3f18d6cc-9605-44ca-8001-9d9961cd499b"
2318 | },
2319 | "source": [
2320 | "# Evaluate regular teacher model\n",
2321 | "_, test_acc = teacher_model.evaluate(test_ds, verbose=0)\n",
2322 | "print(\"Test accuracy with regular teacher: {:.2f}%\".format(test_acc * 100))\n",
2323 | "\n",
2324 | "# Evaluate the corresponding student\n",
2325 | "_, test_acc = student_noisy.evaluate(test_ds, verbose=0)\n",
2326 | "print(\"Test accuracy with regular noisy student: {:.2f}%\".format(test_acc * 100))"
2327 | ],
2328 | "id": "dental-quest",
2329 | "execution_count": null,
2330 | "outputs": [
2331 | {
2332 | "output_type": "stream",
2333 | "text": [
2334 | "Test accuracy with regular teacher: 83.20%\n",
2335 | "Test accuracy with regular noisy student: 82.16%\n"
2336 | ],
2337 | "name": "stdout"
2338 | }
2339 | ]
2340 | }
2341 | ]
2342 | }
--------------------------------------------------------------------------------