├── .gitignore
├── LICENSE
├── README.md
├── classification.ipynb
├── conversion.ipynb
├── fine_tune.ipynb
├── i1k_eval
│   ├── README.md
│   ├── eval.ipynb
│   └── imagenet_class_index.json
└── model-selector.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ViT-jax2tf
2 |
3 |
4 | 
5 | Example usage.
6 |
7 |
8 | This repository hosts code for converting the original Vision Transformer models [1] (JAX) to
9 | TensorFlow.
10 |
11 | The original models were fine-tuned on the ImageNet-1k dataset [2]. For more details
12 | on the training protocols, please follow [3]. The authors of [3] open-sourced about
13 | **50k different variants of Vision Transformer models** in JAX. Using the
14 | [`conversion.ipynb`](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/conversion.ipynb)
15 | notebook, one should be able to take a model from the pool of models and convert that
16 | to TensorFlow and use that with TensorFlow Hub and Keras.
17 |
18 | The original model classes and weights [4] were converted using the `jax2tf` tool [5].
19 |
20 | **Note that it's a requirement to use TensorFlow 2.6 or greater to use the converted models.**
21 |
22 | ## Vision Transformers on TensorFlow Hub
23 |
24 | Find the model collection on TensorFlow Hub: https://tfhub.dev/sayakpaul/collections/vision_transformer/1.
25 |
26 | The eight best-performing ImageNet-1k models have also been made available on TensorFlow
27 | Hub and can be used either for off-the-shelf image classification or transfer learning.
28 | Please follow the [`model-selector.ipynb`](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/model-selector.ipynb)
29 | notebook to understand how these models were chosen.
30 |
31 | The table below provides a performance summary:
32 |
33 | | **Model** | **Top-1 Accuracy** | **Checkpoint** | **Misc** |
34 | |:---:|:---:|:---:|:---:|
35 | | B/8 | 85.948 | B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
36 | | L/16 | 85.716 | L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
37 | | B/16 | 84.018 | B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | |
38 | | R50-L/32 | 83.784 | R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
39 | | R26-S/32 (light aug) | 80.944 | R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz | [tb.dev run](https://tensorboard.dev/experiment/8rjW26CoRJWdAR3ejtgvHQ/) |
40 | | R26-S/32 (medium aug) | 80.462 | R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz | |
41 | | S/16 | 80.462 | S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | [tb.dev run](https://tensorboard.dev/experiment/52LkVYfnQDykgyDHmWjzBA/) |
42 | | B/32 | 79.436 | B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | |
43 |
44 | Note that the top-1 accuracy is reported on ImageNet-1k validation set. The checkpoints are present in the following GCS
45 | location: `gs://vit_models/augreg`. More details on these can be found in [4].
46 |
47 | ### Image classifiers
48 |
49 | * [ViT-S16](https://tfhub.dev/sayakpaul/vit_s16_classification/1)
50 | * [ViT-B8](https://tfhub.dev/sayakpaul/vit_b8_classification/1)
51 | * [ViT-B16](https://tfhub.dev/sayakpaul/vit_b16_classification/1)
52 | * [ViT-B32](https://tfhub.dev/sayakpaul/vit_b32_classification/1)
53 | * [ViT-L16](https://tfhub.dev/sayakpaul/vit_l16_classification/1)
54 | * [ViT-R26-S32 (light augmentation)](https://tfhub.dev/sayakpaul/vit_r26_s32_lightaug_classification/1)
55 | * [ViT-R26-S32 (medium augmentation)](https://tfhub.dev/sayakpaul/vit_r26_s32_medaug_classification/1)
56 | * [ViT-R50-L32](https://tfhub.dev/sayakpaul/vit_r50_l32_classification/1)
57 |
58 | ### Feature extractors
59 |
60 | * [ViT-S16](https://tfhub.dev/sayakpaul/vit_s16_fe/1)
61 | * [ViT-B8](https://tfhub.dev/sayakpaul/vit_b8_fe/1)
62 | * [ViT-B16](https://tfhub.dev/sayakpaul/vit_b16_fe/1)
63 | * [ViT-B32](https://tfhub.dev/sayakpaul/vit_b32_fe/1)
64 | * [ViT-L16](https://tfhub.dev/sayakpaul/vit_l16_fe/1)
65 | * [ViT-R26-S32 (light augmentation)](https://tfhub.dev/sayakpaul/vit_r26_s32_lightaug_fe/1)
66 | * [ViT-R26-S32 (medium augmentation)](https://tfhub.dev/sayakpaul/vit_r26_s32_medaug_fe/1)
67 | * [ViT-R50-L32](https://tfhub.dev/sayakpaul/vit_r50_l32_fe/1)
68 |
69 | ## Other notebooks
70 |
71 | * [`classification.ipynb`](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/classification.ipynb): Shows how to load a Vision Transformer model from TensorFlow Hub
72 | and run image classification.
73 | * [`fine_tune.ipynb`](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/fine_tune.ipynb): Shows how to
74 | fine-tune a Vision Transformer model from TensorFlow Hub on the `tf_flowers` dataset.
75 |
76 | Additionally, [`i1k_eval`](https://github.com/sayakpaul/ViT-jax2tf/tree/main/i1k_eval) contains files for running
77 | evaluation on the ImageNet-1k `validation` split.
78 |
79 | ## References
80 |
81 | [1] [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale by Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)
82 |
83 | [2] [ImageNet-1k](https://www.image-net.org/challenges/LSVRC/2012/index.php)
84 |
85 | [3] [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers by Steiner et al.](https://arxiv.org/abs/2106.10270)
86 |
87 | [4] [Vision Transformer GitHub](https://github.com/google-research/vision_transformer)
88 |
89 | [5] [jax2tf tool](https://github.com/google/jax/tree/main/jax/experimental/jax2tf/)
90 |
91 | ## Acknowledgements
92 |
93 | Thanks to the authors of Vision Transformers for their efforts in open-sourcing
94 | the models.
95 |
96 | Thanks to the [ML-GDE program](https://developers.google.com/programs/experts/) for providing GCP Credit support
97 | that helped me execute the experiments for this project.
98 |
--------------------------------------------------------------------------------
/classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "classification",
8 | "provenance": [],
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3 (ipykernel)",
13 | "language": "python",
14 | "name": "python3"
15 | },
16 | "language_info": {
17 | "codemirror_mode": {
18 | "name": "ipython",
19 | "version": 3
20 | },
21 | "file_extension": ".py",
22 | "mimetype": "text/x-python",
23 | "name": "python",
24 | "nbconvert_exporter": "python",
25 | "pygments_lexer": "ipython3",
26 | "version": "3.8.2"
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "view-in-github",
34 | "colab_type": "text"
35 | },
36 | "source": [
37 | "
"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {
43 | "id": "ls5qlUELNIT7"
44 | },
45 | "source": [
46 | "## Imports"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "metadata": {
52 | "id": "lIYdn1woOS1n"
53 | },
54 | "source": [
55 | "import tensorflow as tf\n",
56 | "import tensorflow_hub as hub\n",
57 | "\n",
58 | "from PIL import Image\n",
59 | "from io import BytesIO\n",
60 | "import matplotlib.pyplot as plt\n",
61 | "import numpy as np\n",
62 | "import requests"
63 | ],
64 | "execution_count": null,
65 | "outputs": []
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "metadata": {
70 | "id": "hyQ87VPcNIUA"
71 | },
72 | "source": [
73 | "## Image preprocessing utilities (credits: [Willi Gierke](https://ch.linkedin.com/in/willi-gierke))"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "metadata": {
79 | "id": "HZZC3qJcLQyM"
80 | },
81 | "source": [
82 | "def preprocess_image(image):\n",
83 | " image = np.array(image)\n",
84 | " image_resized = tf.image.resize(image, (224, 224))\n",
85 | " image_resized = tf.cast(image_resized, tf.float32)\n",
86 | " image_resized = (image_resized - 127.5) / 127.5\n",
87 | " return tf.expand_dims(image_resized, 0).numpy()\n",
88 | "\n",
89 | "def load_image_from_url(url):\n",
90 | " response = requests.get(url)\n",
91 | " image = Image.open(BytesIO(response.content))\n",
92 | " image = preprocess_image(image)\n",
93 | " return image\n",
94 | "\n",
95 | "!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
96 | ],
97 | "execution_count": null,
98 | "outputs": []
99 | },
100 | {
101 | "cell_type": "markdown",
102 | "metadata": {
103 | "id": "YNj3BnemNIUB"
104 | },
105 | "source": [
106 | "## Load image and infer"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "metadata": {
112 | "id": "LPh1RSMXLdVD"
113 | },
114 | "source": [
115 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n",
116 | " lines = f.readlines()\n",
117 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n",
118 | "\n",
119 | "img_url = \"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n",
120 | "image = load_image_from_url(img_url)\n",
121 | "\n",
122 | "plt.imshow((image[0] + 1) / 2)\n",
123 | "plt.show()"
124 | ],
125 | "execution_count": null,
126 | "outputs": []
127 | },
128 | {
129 | "cell_type": "code",
130 | "metadata": {
131 | "id": "uOo3vYhqLu3Y"
132 | },
133 | "source": [
134 | "model_url = \"https://tfhub.dev/sayakpaul/vit_s16_classification/1\"\n",
135 | "\n",
136 | "classification_model = tf.keras.Sequential(\n",
137 | " [hub.KerasLayer(model_url)]\n",
138 | ") \n",
139 | "predictions = classification_model.predict(image)\n",
140 | "predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]\n",
141 | "predicted_label"
142 | ],
143 | "execution_count": null,
144 | "outputs": []
145 | }
146 | ]
147 | }
--------------------------------------------------------------------------------
/conversion.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "conversion",
7 | "provenance": [],
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "display_name": "Python 3 (ipykernel)",
12 | "language": "python",
13 | "name": "python3"
14 | },
15 | "language_info": {
16 | "codemirror_mode": {
17 | "name": "ipython",
18 | "version": 3
19 | },
20 | "file_extension": ".py",
21 | "mimetype": "text/x-python",
22 | "name": "python",
23 | "nbconvert_exporter": "python",
24 | "pygments_lexer": "ipython3",
25 | "version": "3.8.2"
26 | }
27 | },
28 | "cells": [
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "view-in-github",
33 | "colab_type": "text"
34 | },
35 | "source": [
36 | "
"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {
42 | "id": "J8J1g5vBT5aj"
43 | },
44 | "source": [
45 | "## References\n",
46 | "\n",
47 | "* https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md\n",
48 | "* https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {
54 | "id": "piv05HW04aUW"
55 | },
56 | "source": [
57 | "## Setup"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "metadata": {
63 | "id": "tLVMT01KScv5"
64 | },
65 | "source": [
66 | "!pip install -q absl-py>=0.12.0 chex>=0.0.7 clu>=0.0.3 einops>=0.3.0\n",
67 | "!pip install -q flax==0.3.3 ml-collections==0.1.0 tf-nightly\n",
68 | "!pip install -q numpy>=1.19.5 pandas>=1.1.0"
69 | ],
70 | "execution_count": null,
71 | "outputs": []
72 | },
73 | {
74 | "cell_type": "code",
75 | "metadata": {
76 | "id": "lIYdn1woOS1n"
77 | },
78 | "source": [
79 | "# Clone repository and pull latest changes.\n",
80 | "![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer\n",
81 | "!cd vision_transformer && git pull"
82 | ],
83 | "execution_count": null,
84 | "outputs": []
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {
89 | "id": "CwBrIdAE4ciM"
90 | },
91 | "source": [
92 | "## Imports"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "metadata": {
98 | "id": "aWuEnEshSdzt"
99 | },
100 | "source": [
101 | "import sys\n",
102 | "\n",
103 | "if \"./vision_transformer\" not in sys.path:\n",
104 | " sys.path.append(\"./vision_transformer\")\n",
105 | "\n",
106 | "from vit_jax import models\n",
107 | "from vit_jax import checkpoint\n",
108 | "from vit_jax.configs import common as common_config\n",
109 | "from vit_jax.configs import models as models_config\n",
110 | "\n",
111 | "from jax.experimental import jax2tf\n",
112 | "import tensorflow as tf\n",
113 | "import flax\n",
114 | "import jax\n",
115 | "\n",
116 | "from PIL import Image\n",
117 | "from io import BytesIO\n",
118 | "import numpy as np\n",
119 | "import requests"
120 | ],
121 | "execution_count": null,
122 | "outputs": []
123 | },
124 | {
125 | "cell_type": "code",
126 | "metadata": {
127 | "id": "RFV5xTgW28ys"
128 | },
129 | "source": [
130 | "print(f\"JAX version: {jax.__version__}\")\n",
131 | "print(f\"FLAX version: {flax.__version__}\")\n",
132 | "print(f\"TensorFlow version: {tf.__version__}\")"
133 | ],
134 | "execution_count": null,
135 | "outputs": []
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "metadata": {
140 | "id": "Vt5uGYJH3LXM"
141 | },
142 | "source": [
143 | "## Classification / Feature Extractor model"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "metadata": {
149 | "id": "hKh0k1M7SgSN"
150 | },
151 | "source": [
152 | "#@title Choose a model type\n",
153 | "VIT_MODELS = \"B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224\" #@param [\"L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224\", \"B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224\", \"R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224\", \"R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224\", \"R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224\", \"S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224\", \"B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224\", \"B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224\"]\n",
154 | "#@markdown The models were selected based on the criteria shown here in [this notebook](https://github.com/sayakpaul/ViT-jax2tf/blob/main/model-selector.ipynb).\n",
155 | "\n",
156 | "print(f\"Model type selected: ViT-{VIT_MODELS.split('-')[0]}\")\n",
157 | "\n",
158 | "ROOT_GCS_PATH = \"gs://vit_models/augreg\""
159 | ],
160 | "execution_count": null,
161 | "outputs": []
162 | },
163 | {
164 | "cell_type": "code",
165 | "metadata": {
166 | "id": "bjlhNR62-JJJ"
167 | },
168 | "source": [
169 | "classification_model = True\n",
170 | "\n",
171 | "if classification_model:\n",
172 | " num_classes = 1000\n",
173 | " print(\"Will be converting a classification model.\")\n",
174 | "else:\n",
175 | " num_classes = None\n",
176 | " print(\"Will be converting a feature extraction model.\")"
177 | ],
178 | "execution_count": null,
179 | "outputs": []
180 | },
181 | {
182 | "cell_type": "code",
183 | "metadata": {
184 | "id": "mclDoMqbShzV"
185 | },
186 | "source": [
187 | "# Instantiate model class and load the corresponding checkpoints.\n",
188 | "config = common_config.get_config()\n",
189 | "config.model = models_config.AUGREG_CONFIGS[f\"{VIT_MODELS.split('-')[0]}\"]\n",
190 | "\n",
191 | "model = models.VisionTransformer(num_classes=num_classes, **config.model)\n",
192 | "\n",
193 | "path = f\"{ROOT_GCS_PATH}/{VIT_MODELS}.npz\"\n",
194 | "params = checkpoint.load(path)\n",
195 | "\n",
196 | "if not num_classes:\n",
197 | " _ = params.pop(\"head\")"
198 | ],
199 | "execution_count": null,
200 | "outputs": []
201 | },
202 | {
203 | "cell_type": "markdown",
204 | "metadata": {
205 | "id": "RqbuBCFw9vyg"
206 | },
207 | "source": [
208 | "## Conversion\n",
209 | "\n",
210 | "Code has been reused from the official examples [here](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md)."
211 | ]
212 | },
213 | {
214 | "cell_type": "markdown",
215 | "metadata": {
216 | "id": "XT2GLwXg95tE"
217 | },
218 | "source": [
219 | "### Step 1: Get a prediction function out of the JAX model & convert it to a native TF function"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "metadata": {
225 | "id": "J2e1h-F2SmDB"
226 | },
227 | "source": [
228 | "predict_fn = lambda params, inputs: model.apply(\n",
229 | " dict(params=params), inputs, train=False\n",
230 | ")\n",
231 | "\n",
232 | "with_gradient = False if num_classes else True\n",
233 | "tf_fn = jax2tf.convert(\n",
234 | " predict_fn,\n",
235 | " with_gradient=with_gradient,\n",
236 | " polymorphic_shapes=[None, \"b, 224, 224, 3\"],\n",
237 | ")"
238 | ],
239 | "execution_count": null,
240 | "outputs": []
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {
245 | "id": "cRGAjnKBRGgU"
246 | },
247 | "source": [
248 | "We set `polymorphic_shapes` to allow the converted model to operate with arbitrary batch sizes. Learn more about the shape polymorphism in JAX from [here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion)."
249 | ]
250 | },
251 | {
252 | "cell_type": "markdown",
253 | "metadata": {
254 | "id": "PKE1msyx-3ge"
255 | },
256 | "source": [
257 | "### Step 2: Set the trainability of the individual param groups and construct TF graph"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "metadata": {
263 | "id": "8RNlRp9pTHgF"
264 | },
265 | "source": [
266 | "param_vars = tf.nest.map_structure(\n",
267 | " lambda param: tf.Variable(param, trainable=with_gradient), params\n",
268 | ")\n",
269 | "tf_graph = tf.function(\n",
270 | " lambda inputs: tf_fn(param_vars, inputs), autograph=False, jit_compile=True\n",
271 | ")"
272 | ],
273 | "execution_count": null,
274 | "outputs": []
275 | },
276 | {
277 | "cell_type": "markdown",
278 | "metadata": {
279 | "id": "3fDRubHD_Sf3"
280 | },
281 | "source": [
282 | "### Step 3: Serialize as a SavedModel"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "metadata": {
288 | "cellView": "form",
289 | "id": "1QJQwDEyTs2V"
290 | },
291 | "source": [
292 | "#@title SavedModel wrapper class utility from [here](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py#L128)\n",
293 | "class _ReusableSavedModelWrapper(tf.train.Checkpoint):\n",
294 | " \"\"\"Wraps a function and its parameters for saving to a SavedModel.\n",
295 | " Implements the interface described at\n",
296 | " https://www.tensorflow.org/hub/reusable_saved_models.\n",
297 | " \"\"\"\n",
298 | "\n",
299 | " def __init__(self, tf_graph, param_vars):\n",
300 | " \"\"\"Args:\n",
301 |       tf_graph: a tf.function taking one argument (the inputs), which can\n",
302 |         be tuples/lists/dictionaries of np.ndarray or tensors. The function\n",
303 | " may have references to the tf.Variables in `param_vars`.\n",
304 | " param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,\n",
305 | " to be saved as the variables of the SavedModel.\n",
306 | " \"\"\"\n",
307 | " super().__init__()\n",
308 | " # Implement the interface from https://www.tensorflow.org/hub/reusable_saved_models\n",
309 | " self.variables = tf.nest.flatten(param_vars)\n",
310 | " self.trainable_variables = [v for v in self.variables if v.trainable]\n",
311 | " # If you intend to prescribe regularization terms for users of the model,\n",
312 | " # add them as @tf.functions with no inputs to this list. Else drop this.\n",
313 | " self.regularization_losses = []\n",
314 | " self.__call__ = tf_graph\n"
315 | ],
316 | "execution_count": null,
317 | "outputs": []
318 | },
319 | {
320 | "cell_type": "code",
321 | "metadata": {
322 | "id": "xr2Vf9Ql_lca"
323 | },
324 | "source": [
325 | "input_signatures = [tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.float32)]\n",
326 | "model_dir = VIT_MODELS if num_classes else f\"{VIT_MODELS}_fe\"\n",
327 | "signatures = {}\n",
328 | "saved_model_options = None\n",
329 | "\n",
330 | "print(f\"Saving model to {model_dir} directory.\")"
331 | ],
332 | "execution_count": null,
333 | "outputs": []
334 | },
335 | {
336 | "cell_type": "code",
337 | "metadata": {
338 | "id": "pMn9fJxuTKON"
339 | },
340 | "source": [
341 | "signatures[\n",
342 | " tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n",
343 | "] = tf_graph.get_concrete_function(input_signatures[0])\n",
344 | "\n",
345 | "wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)\n",
346 | "if with_gradient:\n",
347 | " if not saved_model_options:\n",
348 | " saved_model_options = tf.saved_model.SaveOptions(\n",
349 | " experimental_custom_gradients=True\n",
350 | " )\n",
351 | " else:\n",
352 | " saved_model_options.experimental_custom_gradients = True\n",
353 | "tf.saved_model.save(\n",
354 | " wrapper, model_dir, signatures=signatures, options=saved_model_options\n",
355 | ")\n",
356 | "\n",
357 | "# Note that directly saving the `wrapper` to a GCS location is\n",
358 | "# also supported."
359 | ],
360 | "execution_count": null,
361 | "outputs": []
362 | },
363 | {
364 | "cell_type": "markdown",
365 | "metadata": {
366 | "id": "2PJr-uVs_vz-"
367 | },
368 | "source": [
369 | "## Functional test (credits: [Willi Gierke](https://ch.linkedin.com/in/willi-gierke))"
370 | ]
371 | },
372 | {
373 | "cell_type": "markdown",
374 | "metadata": {
375 | "id": "NA2G4HzvC5_l"
376 | },
377 | "source": [
378 | "### Image preprocessing utilities "
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "metadata": {
384 | "id": "XyvjkBAE5iFL"
385 | },
386 | "source": [
387 | "def preprocess_image(image):\n",
388 | " image = np.array(image)\n",
389 | " image_resized = tf.image.resize(image, (224, 224))\n",
390 | " image_resized = tf.cast(image_resized, tf.float32)\n",
391 | " image_resized = (image_resized - 127.5) / 127.5\n",
392 | " return tf.expand_dims(image_resized, 0).numpy()\n",
393 | "\n",
394 | "def load_image_from_url(url):\n",
395 | " response = requests.get(url)\n",
396 | " image = Image.open(BytesIO(response.content))\n",
397 | " image = preprocess_image(image)\n",
398 | " return image\n",
399 | "\n",
400 | "!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt"
401 | ],
402 | "execution_count": null,
403 | "outputs": []
404 | },
405 | {
406 | "cell_type": "markdown",
407 | "metadata": {
408 | "id": "Hd-YH-hqAIQ9"
409 | },
410 | "source": [
411 | "### Load image and ImageNet-1k class mappings"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "metadata": {
417 | "id": "4vDQd6MEAEp_"
418 | },
419 | "source": [
420 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n",
421 | " lines = f.readlines()\n",
422 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n",
423 | "\n",
424 | "img_url = \"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n",
425 | "image = load_image_from_url(img_url)"
426 | ],
427 | "execution_count": null,
428 | "outputs": []
429 | },
430 | {
431 | "cell_type": "markdown",
432 | "metadata": {
433 | "id": "9A-LxOYBANnv"
434 | },
435 | "source": [
436 | "### Inference\n",
437 | "\n",
438 | "This is only applicable to the classification models. For fine-tuning/feature extraction, please follow [this notebook](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/fine_tune.ipynb) instead."
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "metadata": {
444 | "id": "99DTGYB25o5d"
445 | },
446 | "source": [
447 | "# Load the converted SavedModel and check whether it finds the elephant.\n",
448 | "restored_model = tf.saved_model.load(model_dir)\n",
449 | "predictions = restored_model.signatures[\"serving_default\"](tf.constant(image))\n",
450 | "logits = predictions[\"output_0\"][0]\n",
451 | "predicted_label = imagenet_int_to_str[int(np.argmax(logits))]\n",
452 | "expected_label = \"Indian_elephant, Elephas_maximus\"\n",
453 | "assert (\n",
454 | " predicted_label == expected_label\n",
455 | "), f\"Expected {expected_label} but was {predicted_label}\""
456 | ],
457 | "execution_count": null,
458 | "outputs": []
459 | },
460 | {
461 | "cell_type": "markdown",
462 | "metadata": {
463 | "id": "itfezjKjAXA6"
464 | },
465 | "source": [
466 | "## Inference with TensorFlow Hub \n",
467 | "\n",
468 | "Run the following code snippet. You can also follow [this notebook](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/classification.ipynb). "
469 | ]
470 | },
471 | {
472 | "cell_type": "markdown",
473 | "metadata": {
474 | "id": "HR8lV2377Ad3"
475 | },
476 | "source": [
477 | "```python\n",
478 | "import tensorflow_hub as hub\n",
479 | "\n",
480 | "classification_model = tf.keras.Sequential([hub.KerasLayer(model_dir)])\n",
481 | "predictions = classification_model.predict(image)\n",
482 | "predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]\n",
483 | "predicted_label\n",
484 | "```"
485 | ]
486 | }
487 | ]
488 | }
489 |
--------------------------------------------------------------------------------
/fine_tune.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "Yx9kQrATLdy5"
7 | },
8 | "source": [
9 | "## Imports"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "UqO1f2Z7QoOC"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "from tensorflow import keras\n",
21 | "import tensorflow as tf\n",
22 | "import tensorflow_hub as hub\n",
23 | "\n",
24 | "import tensorflow_datasets as tfds\n",
25 | "\n",
26 | "tfds.disable_progress_bar()\n",
27 | "\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "import numpy as np"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {
35 | "id": "yS9v0obPLgVy"
36 | },
37 | "source": [
38 | "## Model building utility"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {
45 | "id": "yZ0gsA41RVVM"
46 | },
47 | "outputs": [],
48 | "source": [
49 | "def get_model(\n",
50 | " handle=\"https://tfhub.dev/sayakpaul/vit_s16_fe/1\", \n",
51 | " num_classes=5,\n",
52 | "):\n",
53 | " hub_layer = hub.KerasLayer(handle, trainable=True)\n",
54 | "\n",
55 | " model = keras.Sequential(\n",
56 | " [\n",
57 | " keras.layers.InputLayer((224, 224, 3)),\n",
58 | " hub_layer,\n",
59 | " keras.layers.Dense(num_classes, activation=\"softmax\"),\n",
60 | " ]\n",
61 | " )\n",
62 | "\n",
63 | " return model"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {
70 | "id": "jk5VZIafSxGv"
71 | },
72 | "outputs": [],
73 | "source": [
74 | "get_model().summary()"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {
80 | "id": "ELO3it7aLk6E"
81 | },
82 | "source": [
83 | "## Data input pipeline\n",
84 | "\n",
85 | "Code has been reused from the [official repository](https://github.com/google-research/vision_transformer/blob/main/vit_jax/input_pipeline.py)."
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "metadata": {
92 | "id": "hE2MxLkbEfa7"
93 | },
94 | "outputs": [],
95 | "source": [
96 | "BATCH_SIZE = 64\n",
97 | "AUTO = tf.data.AUTOTUNE"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "id": "Ouk89SHNS9us"
105 | },
106 | "outputs": [],
107 | "source": [
108 | "def make_dataset(dataset: tf.data.Dataset, train: bool, image_size: int = 224):\n",
109 | " def preprocess(image, label):\n",
110 | " # For training, do a random crop and horizontal flip.\n",
111 | " if train:\n",
112 | " channels = image.shape[-1]\n",
113 | " begin, size, _ = tf.image.sample_distorted_bounding_box(\n",
114 | " tf.shape(image),\n",
115 | " tf.zeros([0, 0, 4], tf.float32),\n",
116 | " area_range=(0.05, 1.0),\n",
117 | " min_object_covered=0,\n",
118 | " use_image_if_no_bounding_boxes=True,\n",
119 | " )\n",
120 | " image = tf.slice(image, begin, size)\n",
121 | "\n",
122 | " image.set_shape([None, None, channels])\n",
123 | " image = tf.image.resize(image, [image_size, image_size])\n",
124 | " if tf.random.uniform(shape=[]) > 0.5:\n",
125 | " image = tf.image.flip_left_right(image)\n",
126 | "\n",
127 | " else:\n",
128 | " image = tf.image.resize(image, [image_size, image_size])\n",
129 | "\n",
130 | " image = (image - 127.5) / 127.5\n",
131 | " return image, label\n",
132 | "\n",
133 | " if train:\n",
134 | " dataset = dataset.shuffle(BATCH_SIZE * 10)\n",
135 | "\n",
136 | " return dataset.map(preprocess, AUTO).batch(BATCH_SIZE).prefetch(AUTO)"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {
142 | "id": "Hl3i3onrLtO6"
143 | },
144 | "source": [
145 | "## `tf_flowers` dataset"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {
152 | "id": "lHJFUqFZE-0F"
153 | },
154 | "outputs": [],
155 | "source": [
156 | "train_dataset, val_dataset = tfds.load(\n",
157 | " \"tf_flowers\", split=[\"train[:90%]\", \"train[90%:]\"], as_supervised=True\n",
158 | ")\n",
159 | "\n",
160 | "num_train = tf.data.experimental.cardinality(train_dataset)\n",
161 | "num_val = tf.data.experimental.cardinality(val_dataset)\n",
162 | "print(f\"Number of training examples: {num_train}\")\n",
163 | "print(f\"Number of validation examples: {num_val}\")"
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {
169 | "id": "rnjfXf8RLzCp"
170 | },
171 | "source": [
172 | "### Prepare dataset"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "metadata": {
179 | "id": "DRd8kkcMFxSw"
180 | },
181 | "outputs": [],
182 | "source": [
183 | "train_dataset = make_dataset(train_dataset, True)\n",
184 | "val_dataset = make_dataset(val_dataset, False)"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {
190 | "id": "HbeqUHVdLz5J"
191 | },
192 | "source": [
193 | "### Visualize"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": null,
199 | "metadata": {
200 | "id": "O8Ui5B8hGNXQ"
201 | },
202 | "outputs": [],
203 | "source": [
204 | "sample_images, _ = next(iter(train_dataset))\n",
205 | "\n",
206 | "plt.figure(figsize=(10, 10))\n",
207 | "for n in range(25):\n",
208 | " ax = plt.subplot(5, 5, n + 1)\n",
209 | " plt.imshow((sample_images[n].numpy() + 1) / 2)\n",
210 | " plt.axis(\"off\")\n",
211 | "plt.show()"
212 | ]
213 | },
214 | {
215 | "cell_type": "markdown",
216 | "metadata": {
217 | "id": "h1XNooD2L2IY"
218 | },
219 | "source": [
220 | "## Learning rate scheduling \n",
221 | "\n",
222 | "For fine-tuning the authors follow a warm-up + [cosine | linear] schedule as per the [official notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/linen/vit_jax.ipynb). "
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {
229 | "id": "3cFQQaQoGjuF"
230 | },
231 | "outputs": [],
232 | "source": [
233 | "# Reference:\n",
234 | "# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2\n",
235 | "\n",
236 | "\n",
237 | "class WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):\n",
238 | " def __init__(\n",
239 | " self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps\n",
240 | " ):\n",
241 | " super(WarmUpCosine, self).__init__()\n",
242 | "\n",
243 | " self.learning_rate_base = learning_rate_base\n",
244 | " self.total_steps = total_steps\n",
245 | " self.warmup_learning_rate = warmup_learning_rate\n",
246 | " self.warmup_steps = warmup_steps\n",
247 | " self.pi = tf.constant(np.pi)\n",
248 | "\n",
249 | " def __call__(self, step):\n",
250 | " if self.total_steps < self.warmup_steps:\n",
251 | " raise ValueError(\"Total_steps must be larger or equal to warmup_steps.\")\n",
252 | " learning_rate = (\n",
253 | " 0.5\n",
254 | " * self.learning_rate_base\n",
255 | " * (\n",
256 | " 1\n",
257 | " + tf.cos(\n",
258 | " self.pi\n",
259 | " * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
260 | " / float(self.total_steps - self.warmup_steps)\n",
261 | " )\n",
262 | " )\n",
263 | " )\n",
264 | "\n",
265 | " if self.warmup_steps > 0:\n",
266 | " if self.learning_rate_base < self.warmup_learning_rate:\n",
267 | " raise ValueError(\n",
268 | " \"Learning_rate_base must be larger or equal to \"\n",
269 | " \"warmup_learning_rate.\"\n",
270 | " )\n",
271 | " slope = (\n",
272 | " self.learning_rate_base - self.warmup_learning_rate\n",
273 | " ) / self.warmup_steps\n",
274 | " warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n",
275 | " learning_rate = tf.where(\n",
276 | " step < self.warmup_steps, warmup_rate, learning_rate\n",
277 | " )\n",
278 | " return tf.where(\n",
279 | " step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
280 | " )"
281 | ]
282 | },
283 | {
284 | "cell_type": "markdown",
285 | "metadata": {
286 | "id": "1-fPavRyMH1X"
287 | },
288 | "source": [
289 | "## Training hyperparameters\n",
290 | "\n",
291 | "These have been referred from the official notebooks ([1](https://colab.research.google.com/github/google-research/vision_transformer/blob/linen/vit_jax.ipynb) and [2](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb)). \n",
292 | "\n",
293 | "Differences:\n",
294 | "\n",
295 | "* No gradient accumulation\n",
296 | "* Lower batch size for demoing on a single GPU (64 as opposed to 512)"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": null,
302 | "metadata": {
303 | "id": "bgsWyyhgHAaB"
304 | },
305 | "outputs": [],
306 | "source": [
307 | "EPOCHS = 8\n",
308 | "TOTAL_STEPS = int((num_train / BATCH_SIZE) * EPOCHS)\n",
309 | "WARMUP_STEPS = 10\n",
310 | "INIT_LR = 0.03\n",
311 | "WAMRUP_LR = 0.006\n",
312 | "\n",
313 | "print(TOTAL_STEPS)"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": null,
319 | "metadata": {
320 | "id": "LppzZVdbHwPE"
321 | },
322 | "outputs": [],
323 | "source": [
324 | "scheduled_lrs = WarmUpCosine(\n",
325 | " learning_rate_base=INIT_LR,\n",
326 | " total_steps=TOTAL_STEPS,\n",
327 | " warmup_learning_rate=WAMRUP_LR,\n",
328 | " warmup_steps=WARMUP_STEPS,\n",
329 | ")\n",
330 | "\n",
331 | "lrs = [scheduled_lrs(step) for step in range(TOTAL_STEPS)]\n",
332 | "plt.plot(lrs)\n",
333 | "plt.xlabel(\"Step\", fontsize=14)\n",
334 | "plt.ylabel(\"LR\", fontsize=14)\n",
335 | "plt.show()"
336 | ]
337 | },
338 | {
339 | "cell_type": "markdown",
340 | "metadata": {
341 | "id": "i__WUIfcMpDk"
342 | },
343 | "source": [
344 | "### Optimizer and loss function"
345 | ]
346 | },
347 | {
348 | "cell_type": "code",
349 | "execution_count": null,
350 | "metadata": {
351 | "id": "yblMGEOjIiyP"
352 | },
353 | "outputs": [],
354 | "source": [
355 | "optimizer = keras.optimizers.SGD(scheduled_lrs, clipnorm=1.0)\n",
356 | "loss = keras.losses.SparseCategoricalCrossentropy()"
357 | ]
358 | },
359 | {
360 | "cell_type": "markdown",
361 | "metadata": {
362 | "id": "apdPFM_TMsJg"
363 | },
364 | "source": [
365 | "## Model training and validation"
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "execution_count": null,
371 | "metadata": {
372 | "id": "wd5XVBMZJlu_"
373 | },
374 | "outputs": [],
375 | "source": [
376 | "model = get_model()\n",
377 | "model.compile(loss=loss, optimizer=optimizer, metrics=[\"accuracy\"])\n",
378 | "\n",
379 | "history = model.fit(train_dataset, validation_data=val_dataset, epochs=EPOCHS)"
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "execution_count": null,
385 | "metadata": {
386 | "id": "InPOQqfJK3gR"
387 | },
388 | "outputs": [],
389 | "source": [
390 | "plt.figure(figsize=(7, 7))\n",
391 | "history = history.history\n",
392 | "\n",
393 | "plt.plot(history[\"loss\"], label=\"train_loss\")\n",
394 | "plt.plot(history[\"val_loss\"], label=\"val_loss\")\n",
395 | "plt.plot(history[\"accuracy\"], label=\"train_accuracy\")\n",
396 | "plt.plot(history[\"val_accuracy\"], label=\"val_accuracy\")\n",
397 | "\n",
398 | "plt.legend()\n",
399 | "plt.show()"
400 | ]
401 | }
402 | ],
403 | "metadata": {
404 | "accelerator": "GPU",
405 | "colab": {
406 | "collapsed_sections": [],
407 | "include_colab_link": true,
408 | "machine_shape": "hm",
409 | "name": "fine-tune.ipynb",
410 | "provenance": []
411 | },
412 | "kernelspec": {
413 | "display_name": "Python 3 (ipykernel)",
414 | "language": "python",
415 | "name": "python3"
416 | },
417 | "language_info": {
418 | "codemirror_mode": {
419 | "name": "ipython",
420 | "version": 3
421 | },
422 | "file_extension": ".py",
423 | "mimetype": "text/x-python",
424 | "name": "python",
425 | "nbconvert_exporter": "python",
426 | "pygments_lexer": "ipython3",
427 | "version": "3.8.2"
428 | }
429 | },
430 | "nbformat": 4,
431 | "nbformat_minor": 1
432 | }
433 |
--------------------------------------------------------------------------------
/i1k_eval/README.md:
--------------------------------------------------------------------------------
1 | This directory provides a notebook and ImageNet-1k class mapping file to run evaluation on the ImageNet-1k `validation` split using the [ViT models from TF-Hub](https://tfhub.dev/sayakpaul/collections/vision_transformer/1). One should use this same setup to evaluate the [MLP-Mixer models from TF-Hub](https://tfhub.dev/sayakpaul/collections/mlp-mixer/1). The notebook assumes the following files are present in your working directory:
2 |
3 | * The `validation` split directory of ImageNet-1k.
4 | * The class mapping files (`.json`).
5 |
--------------------------------------------------------------------------------
/i1k_eval/eval.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "cc520ed3-785b-4491-aa40-4346b72a4574",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import tensorflow_datasets as tfds\n",
11 | "import tensorflow_hub as hub\n",
12 | "import tensorflow as tf\n",
13 | "\n",
14 | "from imutils import paths\n",
15 | "import json"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "id": "86454d69-4857-45db-a98f-abc5b1c97e3d",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "try:\n",
26 | " tpu = None\n",
27 | " tpu = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
28 | " tf.config.experimental_connect_to_cluster(tpu)\n",
29 | " tf.tpu.experimental.initialize_tpu_system(tpu)\n",
30 | " strategy = tf.distribute.TPUStrategy(tpu)\n",
31 | "except ValueError:\n",
32 | " strategy = tf.distribute.MirroredStrategy()\n",
33 | "\n",
34 | "print(\"Number of accelerators: \", strategy.num_replicas_in_sync)"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "id": "c3f047d5-e54d-43c2-b19d-0ce5e56e9ff6",
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "AUTO = tf.data.AUTOTUNE\n",
45 | "BATCH_SIZE = 128 * strategy.num_replicas_in_sync"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "id": "c0bbd0fe-08b9-4d6a-ac19-ab1322e20594",
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "with open(\"imagenet_class_index.json\", \"r\") as read_file:\n",
56 | " imagenet_labels = json.load(read_file)\n",
57 | "\n",
58 | "MAPPING_DICT = {}\n",
59 | "LABEL_NAMES = {}\n",
60 | "for label_id in list(imagenet_labels.keys()):\n",
61 | " MAPPING_DICT[imagenet_labels[label_id][0]] = int(label_id)\n",
62 | " LABEL_NAMES[int(label_id)] = imagenet_labels[label_id][1]"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "id": "3c9ad7b4-ce65-42c9-b9f6-b7d9a3e2a030",
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "all_val_paths = list(paths.list_images(\"val\"))\n",
73 | "all_val_labels = [MAPPING_DICT[x.split(\"/\")[1]] for x in all_val_paths]\n",
74 | "\n",
75 | "all_val_paths[:5], all_val_labels[:5]"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "id": "3cb2086a-b442-4bd0-8b7d-f717b8816286",
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "def load_and_prepare(path, label):\n",
86 | " image = tf.io.read_file(path)\n",
87 | " image = tf.image.decode_png(image, channels=3)\n",
88 | " image = tf.image.resize(image, (224, 224))\n",
89 | "\n",
90 | " return image, label"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "id": "289f1a88-70b1-4a99-9a95-08daf4f38290",
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "dataset = tf.data.Dataset.from_tensor_slices((all_val_paths, all_val_labels))\n",
101 | "\n",
102 | "dataset = dataset.map(load_and_prepare, num_parallel_calls=AUTO).batch(BATCH_SIZE)\n",
103 | "dataset = dataset.prefetch(AUTO)"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "id": "14bd0f87-24ce-48cf-b263-e878f260e618",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "def get_model(model_url=\"https://tfhub.dev/sayakpaul/vit_s16_classification/1\"):\n",
114 | " classification_model = tf.keras.Sequential(\n",
115 | " [\n",
116 | " tf.keras.layers.InputLayer((224, 224, 3)),\n",
117 | " tf.keras.layers.Rescaling(\n",
118 | " scale=1.0 / 127.5, offset=-1\n",
119 | " ), # Scales to [-1, 1].\n",
120 | " hub.KerasLayer(model_url),\n",
121 | " ]\n",
122 | " )\n",
123 | " return classification_model"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "id": "d0de72ef-495d-41ae-9875-809101b3c177",
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "def eval_util(model_url, arch):\n",
134 | " tb_callback = tf.keras.callbacks.TensorBoard(log_dir=f\"logs_{arch}\")\n",
135 | " with strategy.scope():\n",
136 | " model = get_model(model_url)\n",
137 | " model.compile(metrics=[\"accuracy\"])\n",
138 | " model.evaluate(dataset, callbacks=[tb_callback])"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "id": "e2287612-8d1a-4131-87cd-5d0798a6b313",
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "model_urls = [\n",
149 | " \"https://tfhub.dev/sayakpaul/vit_s16_classification/1\",\n",
150 | " \"https://tfhub.dev/sayakpaul/vit_r26_s32_lightaug_classification/1\",\n",
151 | "]\n",
152 | "\n",
153 | "archs = [\"s16\", \"r26_s32\"]\n",
154 | "\n",
155 | "for model_url, arch in zip(model_urls, archs):\n",
156 | " print(f\"Evaluating {arch}\")\n",
157 | " eval_util(model_url, arch)"
158 | ]
159 | }
160 | ],
161 | "metadata": {
162 | "environment": {
163 | "name": "tf2-gpu.2-6.m80",
164 | "type": "gcloud",
165 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-6:m80"
166 | },
167 | "kernelspec": {
168 | "display_name": "Python 3 (ipykernel)",
169 | "language": "python",
170 | "name": "python3"
171 | },
172 | "language_info": {
173 | "codemirror_mode": {
174 | "name": "ipython",
175 | "version": 3
176 | },
177 | "file_extension": ".py",
178 | "mimetype": "text/x-python",
179 | "name": "python",
180 | "nbconvert_exporter": "python",
181 | "pygments_lexer": "ipython3",
182 | "version": "3.8.2"
183 | }
184 | },
185 | "nbformat": 4,
186 | "nbformat_minor": 5
187 | }
188 |
--------------------------------------------------------------------------------
/i1k_eval/imagenet_class_index.json:
--------------------------------------------------------------------------------
1 | {"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], "2": ["n01484850", "great_white_shark"], "3": ["n01491361", "tiger_shark"], "4": ["n01494475", "hammerhead"], "5": ["n01496331", "electric_ray"], "6": ["n01498041", "stingray"], "7": ["n01514668", "cock"], "8": ["n01514859", "hen"], "9": ["n01518878", "ostrich"], "10": ["n01530575", "brambling"], "11": ["n01531178", "goldfinch"], "12": ["n01532829", "house_finch"], "13": ["n01534433", "junco"], "14": ["n01537544", "indigo_bunting"], "15": ["n01558993", "robin"], "16": ["n01560419", "bulbul"], "17": ["n01580077", "jay"], "18": ["n01582220", "magpie"], "19": ["n01592084", "chickadee"], "20": ["n01601694", "water_ouzel"], "21": ["n01608432", "kite"], "22": ["n01614925", "bald_eagle"], "23": ["n01616318", "vulture"], "24": ["n01622779", "great_grey_owl"], "25": ["n01629819", "European_fire_salamander"], "26": ["n01630670", "common_newt"], "27": ["n01631663", "eft"], "28": ["n01632458", "spotted_salamander"], "29": ["n01632777", "axolotl"], "30": ["n01641577", "bullfrog"], "31": ["n01644373", "tree_frog"], "32": ["n01644900", "tailed_frog"], "33": ["n01664065", "loggerhead"], "34": ["n01665541", "leatherback_turtle"], "35": ["n01667114", "mud_turtle"], "36": ["n01667778", "terrapin"], "37": ["n01669191", "box_turtle"], "38": ["n01675722", "banded_gecko"], "39": ["n01677366", "common_iguana"], "40": ["n01682714", "American_chameleon"], "41": ["n01685808", "whiptail"], "42": ["n01687978", "agama"], "43": ["n01688243", "frilled_lizard"], "44": ["n01689811", "alligator_lizard"], "45": ["n01692333", "Gila_monster"], "46": ["n01693334", "green_lizard"], "47": ["n01694178", "African_chameleon"], "48": ["n01695060", "Komodo_dragon"], "49": ["n01697457", "African_crocodile"], "50": ["n01698640", "American_alligator"], "51": ["n01704323", "triceratops"], "52": ["n01728572", "thunder_snake"], "53": ["n01728920", "ringneck_snake"], "54": ["n01729322", "hognose_snake"], "55": ["n01729977", "green_snake"], "56": 
["n01734418", "king_snake"], "57": ["n01735189", "garter_snake"], "58": ["n01737021", "water_snake"], "59": ["n01739381", "vine_snake"], "60": ["n01740131", "night_snake"], "61": ["n01742172", "boa_constrictor"], "62": ["n01744401", "rock_python"], "63": ["n01748264", "Indian_cobra"], "64": ["n01749939", "green_mamba"], "65": ["n01751748", "sea_snake"], "66": ["n01753488", "horned_viper"], "67": ["n01755581", "diamondback"], "68": ["n01756291", "sidewinder"], "69": ["n01768244", "trilobite"], "70": ["n01770081", "harvestman"], "71": ["n01770393", "scorpion"], "72": ["n01773157", "black_and_gold_garden_spider"], "73": ["n01773549", "barn_spider"], "74": ["n01773797", "garden_spider"], "75": ["n01774384", "black_widow"], "76": ["n01774750", "tarantula"], "77": ["n01775062", "wolf_spider"], "78": ["n01776313", "tick"], "79": ["n01784675", "centipede"], "80": ["n01795545", "black_grouse"], "81": ["n01796340", "ptarmigan"], "82": ["n01797886", "ruffed_grouse"], "83": ["n01798484", "prairie_chicken"], "84": ["n01806143", "peacock"], "85": ["n01806567", "quail"], "86": ["n01807496", "partridge"], "87": ["n01817953", "African_grey"], "88": ["n01818515", "macaw"], "89": ["n01819313", "sulphur-crested_cockatoo"], "90": ["n01820546", "lorikeet"], "91": ["n01824575", "coucal"], "92": ["n01828970", "bee_eater"], "93": ["n01829413", "hornbill"], "94": ["n01833805", "hummingbird"], "95": ["n01843065", "jacamar"], "96": ["n01843383", "toucan"], "97": ["n01847000", "drake"], "98": ["n01855032", "red-breasted_merganser"], "99": ["n01855672", "goose"], "100": ["n01860187", "black_swan"], "101": ["n01871265", "tusker"], "102": ["n01872401", "echidna"], "103": ["n01873310", "platypus"], "104": ["n01877812", "wallaby"], "105": ["n01882714", "koala"], "106": ["n01883070", "wombat"], "107": ["n01910747", "jellyfish"], "108": ["n01914609", "sea_anemone"], "109": ["n01917289", "brain_coral"], "110": ["n01924916", "flatworm"], "111": ["n01930112", "nematode"], "112": ["n01943899", "conch"], 
"113": ["n01944390", "snail"], "114": ["n01945685", "slug"], "115": ["n01950731", "sea_slug"], "116": ["n01955084", "chiton"], "117": ["n01968897", "chambered_nautilus"], "118": ["n01978287", "Dungeness_crab"], "119": ["n01978455", "rock_crab"], "120": ["n01980166", "fiddler_crab"], "121": ["n01981276", "king_crab"], "122": ["n01983481", "American_lobster"], "123": ["n01984695", "spiny_lobster"], "124": ["n01985128", "crayfish"], "125": ["n01986214", "hermit_crab"], "126": ["n01990800", "isopod"], "127": ["n02002556", "white_stork"], "128": ["n02002724", "black_stork"], "129": ["n02006656", "spoonbill"], "130": ["n02007558", "flamingo"], "131": ["n02009229", "little_blue_heron"], "132": ["n02009912", "American_egret"], "133": ["n02011460", "bittern"], "134": ["n02012849", "crane"], "135": ["n02013706", "limpkin"], "136": ["n02017213", "European_gallinule"], "137": ["n02018207", "American_coot"], "138": ["n02018795", "bustard"], "139": ["n02025239", "ruddy_turnstone"], "140": ["n02027492", "red-backed_sandpiper"], "141": ["n02028035", "redshank"], "142": ["n02033041", "dowitcher"], "143": ["n02037110", "oystercatcher"], "144": ["n02051845", "pelican"], "145": ["n02056570", "king_penguin"], "146": ["n02058221", "albatross"], "147": ["n02066245", "grey_whale"], "148": ["n02071294", "killer_whale"], "149": ["n02074367", "dugong"], "150": ["n02077923", "sea_lion"], "151": ["n02085620", "Chihuahua"], "152": ["n02085782", "Japanese_spaniel"], "153": ["n02085936", "Maltese_dog"], "154": ["n02086079", "Pekinese"], "155": ["n02086240", "Shih-Tzu"], "156": ["n02086646", "Blenheim_spaniel"], "157": ["n02086910", "papillon"], "158": ["n02087046", "toy_terrier"], "159": ["n02087394", "Rhodesian_ridgeback"], "160": ["n02088094", "Afghan_hound"], "161": ["n02088238", "basset"], "162": ["n02088364", "beagle"], "163": ["n02088466", "bloodhound"], "164": ["n02088632", "bluetick"], "165": ["n02089078", "black-and-tan_coonhound"], "166": ["n02089867", "Walker_hound"], "167": 
["n02089973", "English_foxhound"], "168": ["n02090379", "redbone"], "169": ["n02090622", "borzoi"], "170": ["n02090721", "Irish_wolfhound"], "171": ["n02091032", "Italian_greyhound"], "172": ["n02091134", "whippet"], "173": ["n02091244", "Ibizan_hound"], "174": ["n02091467", "Norwegian_elkhound"], "175": ["n02091635", "otterhound"], "176": ["n02091831", "Saluki"], "177": ["n02092002", "Scottish_deerhound"], "178": ["n02092339", "Weimaraner"], "179": ["n02093256", "Staffordshire_bullterrier"], "180": ["n02093428", "American_Staffordshire_terrier"], "181": ["n02093647", "Bedlington_terrier"], "182": ["n02093754", "Border_terrier"], "183": ["n02093859", "Kerry_blue_terrier"], "184": ["n02093991", "Irish_terrier"], "185": ["n02094114", "Norfolk_terrier"], "186": ["n02094258", "Norwich_terrier"], "187": ["n02094433", "Yorkshire_terrier"], "188": ["n02095314", "wire-haired_fox_terrier"], "189": ["n02095570", "Lakeland_terrier"], "190": ["n02095889", "Sealyham_terrier"], "191": ["n02096051", "Airedale"], "192": ["n02096177", "cairn"], "193": ["n02096294", "Australian_terrier"], "194": ["n02096437", "Dandie_Dinmont"], "195": ["n02096585", "Boston_bull"], "196": ["n02097047", "miniature_schnauzer"], "197": ["n02097130", "giant_schnauzer"], "198": ["n02097209", "standard_schnauzer"], "199": ["n02097298", "Scotch_terrier"], "200": ["n02097474", "Tibetan_terrier"], "201": ["n02097658", "silky_terrier"], "202": ["n02098105", "soft-coated_wheaten_terrier"], "203": ["n02098286", "West_Highland_white_terrier"], "204": ["n02098413", "Lhasa"], "205": ["n02099267", "flat-coated_retriever"], "206": ["n02099429", "curly-coated_retriever"], "207": ["n02099601", "golden_retriever"], "208": ["n02099712", "Labrador_retriever"], "209": ["n02099849", "Chesapeake_Bay_retriever"], "210": ["n02100236", "German_short-haired_pointer"], "211": ["n02100583", "vizsla"], "212": ["n02100735", "English_setter"], "213": ["n02100877", "Irish_setter"], "214": ["n02101006", "Gordon_setter"], "215": 
["n02101388", "Brittany_spaniel"], "216": ["n02101556", "clumber"], "217": ["n02102040", "English_springer"], "218": ["n02102177", "Welsh_springer_spaniel"], "219": ["n02102318", "cocker_spaniel"], "220": ["n02102480", "Sussex_spaniel"], "221": ["n02102973", "Irish_water_spaniel"], "222": ["n02104029", "kuvasz"], "223": ["n02104365", "schipperke"], "224": ["n02105056", "groenendael"], "225": ["n02105162", "malinois"], "226": ["n02105251", "briard"], "227": ["n02105412", "kelpie"], "228": ["n02105505", "komondor"], "229": ["n02105641", "Old_English_sheepdog"], "230": ["n02105855", "Shetland_sheepdog"], "231": ["n02106030", "collie"], "232": ["n02106166", "Border_collie"], "233": ["n02106382", "Bouvier_des_Flandres"], "234": ["n02106550", "Rottweiler"], "235": ["n02106662", "German_shepherd"], "236": ["n02107142", "Doberman"], "237": ["n02107312", "miniature_pinscher"], "238": ["n02107574", "Greater_Swiss_Mountain_dog"], "239": ["n02107683", "Bernese_mountain_dog"], "240": ["n02107908", "Appenzeller"], "241": ["n02108000", "EntleBucher"], "242": ["n02108089", "boxer"], "243": ["n02108422", "bull_mastiff"], "244": ["n02108551", "Tibetan_mastiff"], "245": ["n02108915", "French_bulldog"], "246": ["n02109047", "Great_Dane"], "247": ["n02109525", "Saint_Bernard"], "248": ["n02109961", "Eskimo_dog"], "249": ["n02110063", "malamute"], "250": ["n02110185", "Siberian_husky"], "251": ["n02110341", "dalmatian"], "252": ["n02110627", "affenpinscher"], "253": ["n02110806", "basenji"], "254": ["n02110958", "pug"], "255": ["n02111129", "Leonberg"], "256": ["n02111277", "Newfoundland"], "257": ["n02111500", "Great_Pyrenees"], "258": ["n02111889", "Samoyed"], "259": ["n02112018", "Pomeranian"], "260": ["n02112137", "chow"], "261": ["n02112350", "keeshond"], "262": ["n02112706", "Brabancon_griffon"], "263": ["n02113023", "Pembroke"], "264": ["n02113186", "Cardigan"], "265": ["n02113624", "toy_poodle"], "266": ["n02113712", "miniature_poodle"], "267": ["n02113799", "standard_poodle"], 
"268": ["n02113978", "Mexican_hairless"], "269": ["n02114367", "timber_wolf"], "270": ["n02114548", "white_wolf"], "271": ["n02114712", "red_wolf"], "272": ["n02114855", "coyote"], "273": ["n02115641", "dingo"], "274": ["n02115913", "dhole"], "275": ["n02116738", "African_hunting_dog"], "276": ["n02117135", "hyena"], "277": ["n02119022", "red_fox"], "278": ["n02119789", "kit_fox"], "279": ["n02120079", "Arctic_fox"], "280": ["n02120505", "grey_fox"], "281": ["n02123045", "tabby"], "282": ["n02123159", "tiger_cat"], "283": ["n02123394", "Persian_cat"], "284": ["n02123597", "Siamese_cat"], "285": ["n02124075", "Egyptian_cat"], "286": ["n02125311", "cougar"], "287": ["n02127052", "lynx"], "288": ["n02128385", "leopard"], "289": ["n02128757", "snow_leopard"], "290": ["n02128925", "jaguar"], "291": ["n02129165", "lion"], "292": ["n02129604", "tiger"], "293": ["n02130308", "cheetah"], "294": ["n02132136", "brown_bear"], "295": ["n02133161", "American_black_bear"], "296": ["n02134084", "ice_bear"], "297": ["n02134418", "sloth_bear"], "298": ["n02137549", "mongoose"], "299": ["n02138441", "meerkat"], "300": ["n02165105", "tiger_beetle"], "301": ["n02165456", "ladybug"], "302": ["n02167151", "ground_beetle"], "303": ["n02168699", "long-horned_beetle"], "304": ["n02169497", "leaf_beetle"], "305": ["n02172182", "dung_beetle"], "306": ["n02174001", "rhinoceros_beetle"], "307": ["n02177972", "weevil"], "308": ["n02190166", "fly"], "309": ["n02206856", "bee"], "310": ["n02219486", "ant"], "311": ["n02226429", "grasshopper"], "312": ["n02229544", "cricket"], "313": ["n02231487", "walking_stick"], "314": ["n02233338", "cockroach"], "315": ["n02236044", "mantis"], "316": ["n02256656", "cicada"], "317": ["n02259212", "leafhopper"], "318": ["n02264363", "lacewing"], "319": ["n02268443", "dragonfly"], "320": ["n02268853", "damselfly"], "321": ["n02276258", "admiral"], "322": ["n02277742", "ringlet"], "323": ["n02279972", "monarch"], "324": ["n02280649", "cabbage_butterfly"], "325": 
["n02281406", "sulphur_butterfly"], "326": ["n02281787", "lycaenid"], "327": ["n02317335", "starfish"], "328": ["n02319095", "sea_urchin"], "329": ["n02321529", "sea_cucumber"], "330": ["n02325366", "wood_rabbit"], "331": ["n02326432", "hare"], "332": ["n02328150", "Angora"], "333": ["n02342885", "hamster"], "334": ["n02346627", "porcupine"], "335": ["n02356798", "fox_squirrel"], "336": ["n02361337", "marmot"], "337": ["n02363005", "beaver"], "338": ["n02364673", "guinea_pig"], "339": ["n02389026", "sorrel"], "340": ["n02391049", "zebra"], "341": ["n02395406", "hog"], "342": ["n02396427", "wild_boar"], "343": ["n02397096", "warthog"], "344": ["n02398521", "hippopotamus"], "345": ["n02403003", "ox"], "346": ["n02408429", "water_buffalo"], "347": ["n02410509", "bison"], "348": ["n02412080", "ram"], "349": ["n02415577", "bighorn"], "350": ["n02417914", "ibex"], "351": ["n02422106", "hartebeest"], "352": ["n02422699", "impala"], "353": ["n02423022", "gazelle"], "354": ["n02437312", "Arabian_camel"], "355": ["n02437616", "llama"], "356": ["n02441942", "weasel"], "357": ["n02442845", "mink"], "358": ["n02443114", "polecat"], "359": ["n02443484", "black-footed_ferret"], "360": ["n02444819", "otter"], "361": ["n02445715", "skunk"], "362": ["n02447366", "badger"], "363": ["n02454379", "armadillo"], "364": ["n02457408", "three-toed_sloth"], "365": ["n02480495", "orangutan"], "366": ["n02480855", "gorilla"], "367": ["n02481823", "chimpanzee"], "368": ["n02483362", "gibbon"], "369": ["n02483708", "siamang"], "370": ["n02484975", "guenon"], "371": ["n02486261", "patas"], "372": ["n02486410", "baboon"], "373": ["n02487347", "macaque"], "374": ["n02488291", "langur"], "375": ["n02488702", "colobus"], "376": ["n02489166", "proboscis_monkey"], "377": ["n02490219", "marmoset"], "378": ["n02492035", "capuchin"], "379": ["n02492660", "howler_monkey"], "380": ["n02493509", "titi"], "381": ["n02493793", "spider_monkey"], "382": ["n02494079", "squirrel_monkey"], "383": ["n02497673", 
"Madagascar_cat"], "384": ["n02500267", "indri"], "385": ["n02504013", "Indian_elephant"], "386": ["n02504458", "African_elephant"], "387": ["n02509815", "lesser_panda"], "388": ["n02510455", "giant_panda"], "389": ["n02514041", "barracouta"], "390": ["n02526121", "eel"], "391": ["n02536864", "coho"], "392": ["n02606052", "rock_beauty"], "393": ["n02607072", "anemone_fish"], "394": ["n02640242", "sturgeon"], "395": ["n02641379", "gar"], "396": ["n02643566", "lionfish"], "397": ["n02655020", "puffer"], "398": ["n02666196", "abacus"], "399": ["n02667093", "abaya"], "400": ["n02669723", "academic_gown"], "401": ["n02672831", "accordion"], "402": ["n02676566", "acoustic_guitar"], "403": ["n02687172", "aircraft_carrier"], "404": ["n02690373", "airliner"], "405": ["n02692877", "airship"], "406": ["n02699494", "altar"], "407": ["n02701002", "ambulance"], "408": ["n02704792", "amphibian"], "409": ["n02708093", "analog_clock"], "410": ["n02727426", "apiary"], "411": ["n02730930", "apron"], "412": ["n02747177", "ashcan"], "413": ["n02749479", "assault_rifle"], "414": ["n02769748", "backpack"], "415": ["n02776631", "bakery"], "416": ["n02777292", "balance_beam"], "417": ["n02782093", "balloon"], "418": ["n02783161", "ballpoint"], "419": ["n02786058", "Band_Aid"], "420": ["n02787622", "banjo"], "421": ["n02788148", "bannister"], "422": ["n02790996", "barbell"], "423": ["n02791124", "barber_chair"], "424": ["n02791270", "barbershop"], "425": ["n02793495", "barn"], "426": ["n02794156", "barometer"], "427": ["n02795169", "barrel"], "428": ["n02797295", "barrow"], "429": ["n02799071", "baseball"], "430": ["n02802426", "basketball"], "431": ["n02804414", "bassinet"], "432": ["n02804610", "bassoon"], "433": ["n02807133", "bathing_cap"], "434": ["n02808304", "bath_towel"], "435": ["n02808440", "bathtub"], "436": ["n02814533", "beach_wagon"], "437": ["n02814860", "beacon"], "438": ["n02815834", "beaker"], "439": ["n02817516", "bearskin"], "440": ["n02823428", "beer_bottle"], "441": 
["n02823750", "beer_glass"], "442": ["n02825657", "bell_cote"], "443": ["n02834397", "bib"], "444": ["n02835271", "bicycle-built-for-two"], "445": ["n02837789", "bikini"], "446": ["n02840245", "binder"], "447": ["n02841315", "binoculars"], "448": ["n02843684", "birdhouse"], "449": ["n02859443", "boathouse"], "450": ["n02860847", "bobsled"], "451": ["n02865351", "bolo_tie"], "452": ["n02869837", "bonnet"], "453": ["n02870880", "bookcase"], "454": ["n02871525", "bookshop"], "455": ["n02877765", "bottlecap"], "456": ["n02879718", "bow"], "457": ["n02883205", "bow_tie"], "458": ["n02892201", "brass"], "459": ["n02892767", "brassiere"], "460": ["n02894605", "breakwater"], "461": ["n02895154", "breastplate"], "462": ["n02906734", "broom"], "463": ["n02909870", "bucket"], "464": ["n02910353", "buckle"], "465": ["n02916936", "bulletproof_vest"], "466": ["n02917067", "bullet_train"], "467": ["n02927161", "butcher_shop"], "468": ["n02930766", "cab"], "469": ["n02939185", "caldron"], "470": ["n02948072", "candle"], "471": ["n02950826", "cannon"], "472": ["n02951358", "canoe"], "473": ["n02951585", "can_opener"], "474": ["n02963159", "cardigan"], "475": ["n02965783", "car_mirror"], "476": ["n02966193", "carousel"], "477": ["n02966687", "carpenter's_kit"], "478": ["n02971356", "carton"], "479": ["n02974003", "car_wheel"], "480": ["n02977058", "cash_machine"], "481": ["n02978881", "cassette"], "482": ["n02979186", "cassette_player"], "483": ["n02980441", "castle"], "484": ["n02981792", "catamaran"], "485": ["n02988304", "CD_player"], "486": ["n02992211", "cello"], "487": ["n02992529", "cellular_telephone"], "488": ["n02999410", "chain"], "489": ["n03000134", "chainlink_fence"], "490": ["n03000247", "chain_mail"], "491": ["n03000684", "chain_saw"], "492": ["n03014705", "chest"], "493": ["n03016953", "chiffonier"], "494": ["n03017168", "chime"], "495": ["n03018349", "china_cabinet"], "496": ["n03026506", "Christmas_stocking"], "497": ["n03028079", "church"], "498": ["n03032252", 
"cinema"], "499": ["n03041632", "cleaver"], "500": ["n03042490", "cliff_dwelling"], "501": ["n03045698", "cloak"], "502": ["n03047690", "clog"], "503": ["n03062245", "cocktail_shaker"], "504": ["n03063599", "coffee_mug"], "505": ["n03063689", "coffeepot"], "506": ["n03065424", "coil"], "507": ["n03075370", "combination_lock"], "508": ["n03085013", "computer_keyboard"], "509": ["n03089624", "confectionery"], "510": ["n03095699", "container_ship"], "511": ["n03100240", "convertible"], "512": ["n03109150", "corkscrew"], "513": ["n03110669", "cornet"], "514": ["n03124043", "cowboy_boot"], "515": ["n03124170", "cowboy_hat"], "516": ["n03125729", "cradle"], "517": ["n03126707", "crane"], "518": ["n03127747", "crash_helmet"], "519": ["n03127925", "crate"], "520": ["n03131574", "crib"], "521": ["n03133878", "Crock_Pot"], "522": ["n03134739", "croquet_ball"], "523": ["n03141823", "crutch"], "524": ["n03146219", "cuirass"], "525": ["n03160309", "dam"], "526": ["n03179701", "desk"], "527": ["n03180011", "desktop_computer"], "528": ["n03187595", "dial_telephone"], "529": ["n03188531", "diaper"], "530": ["n03196217", "digital_clock"], "531": ["n03197337", "digital_watch"], "532": ["n03201208", "dining_table"], "533": ["n03207743", "dishrag"], "534": ["n03207941", "dishwasher"], "535": ["n03208938", "disk_brake"], "536": ["n03216828", "dock"], "537": ["n03218198", "dogsled"], "538": ["n03220513", "dome"], "539": ["n03223299", "doormat"], "540": ["n03240683", "drilling_platform"], "541": ["n03249569", "drum"], "542": ["n03250847", "drumstick"], "543": ["n03255030", "dumbbell"], "544": ["n03259280", "Dutch_oven"], "545": ["n03271574", "electric_fan"], "546": ["n03272010", "electric_guitar"], "547": ["n03272562", "electric_locomotive"], "548": ["n03290653", "entertainment_center"], "549": ["n03291819", "envelope"], "550": ["n03297495", "espresso_maker"], "551": ["n03314780", "face_powder"], "552": ["n03325584", "feather_boa"], "553": ["n03337140", "file"], "554": ["n03344393", 
"fireboat"], "555": ["n03345487", "fire_engine"], "556": ["n03347037", "fire_screen"], "557": ["n03355925", "flagpole"], "558": ["n03372029", "flute"], "559": ["n03376595", "folding_chair"], "560": ["n03379051", "football_helmet"], "561": ["n03384352", "forklift"], "562": ["n03388043", "fountain"], "563": ["n03388183", "fountain_pen"], "564": ["n03388549", "four-poster"], "565": ["n03393912", "freight_car"], "566": ["n03394916", "French_horn"], "567": ["n03400231", "frying_pan"], "568": ["n03404251", "fur_coat"], "569": ["n03417042", "garbage_truck"], "570": ["n03424325", "gasmask"], "571": ["n03425413", "gas_pump"], "572": ["n03443371", "goblet"], "573": ["n03444034", "go-kart"], "574": ["n03445777", "golf_ball"], "575": ["n03445924", "golfcart"], "576": ["n03447447", "gondola"], "577": ["n03447721", "gong"], "578": ["n03450230", "gown"], "579": ["n03452741", "grand_piano"], "580": ["n03457902", "greenhouse"], "581": ["n03459775", "grille"], "582": ["n03461385", "grocery_store"], "583": ["n03467068", "guillotine"], "584": ["n03476684", "hair_slide"], "585": ["n03476991", "hair_spray"], "586": ["n03478589", "half_track"], "587": ["n03481172", "hammer"], "588": ["n03482405", "hamper"], "589": ["n03483316", "hand_blower"], "590": ["n03485407", "hand-held_computer"], "591": ["n03485794", "handkerchief"], "592": ["n03492542", "hard_disc"], "593": ["n03494278", "harmonica"], "594": ["n03495258", "harp"], "595": ["n03496892", "harvester"], "596": ["n03498962", "hatchet"], "597": ["n03527444", "holster"], "598": ["n03529860", "home_theater"], "599": ["n03530642", "honeycomb"], "600": ["n03532672", "hook"], "601": ["n03534580", "hoopskirt"], "602": ["n03535780", "horizontal_bar"], "603": ["n03538406", "horse_cart"], "604": ["n03544143", "hourglass"], "605": ["n03584254", "iPod"], "606": ["n03584829", "iron"], "607": ["n03590841", "jack-o'-lantern"], "608": ["n03594734", "jean"], "609": ["n03594945", "jeep"], "610": ["n03595614", "jersey"], "611": ["n03598930", 
"jigsaw_puzzle"], "612": ["n03599486", "jinrikisha"], "613": ["n03602883", "joystick"], "614": ["n03617480", "kimono"], "615": ["n03623198", "knee_pad"], "616": ["n03627232", "knot"], "617": ["n03630383", "lab_coat"], "618": ["n03633091", "ladle"], "619": ["n03637318", "lampshade"], "620": ["n03642806", "laptop"], "621": ["n03649909", "lawn_mower"], "622": ["n03657121", "lens_cap"], "623": ["n03658185", "letter_opener"], "624": ["n03661043", "library"], "625": ["n03662601", "lifeboat"], "626": ["n03666591", "lighter"], "627": ["n03670208", "limousine"], "628": ["n03673027", "liner"], "629": ["n03676483", "lipstick"], "630": ["n03680355", "Loafer"], "631": ["n03690938", "lotion"], "632": ["n03691459", "loudspeaker"], "633": ["n03692522", "loupe"], "634": ["n03697007", "lumbermill"], "635": ["n03706229", "magnetic_compass"], "636": ["n03709823", "mailbag"], "637": ["n03710193", "mailbox"], "638": ["n03710637", "maillot"], "639": ["n03710721", "maillot"], "640": ["n03717622", "manhole_cover"], "641": ["n03720891", "maraca"], "642": ["n03721384", "marimba"], "643": ["n03724870", "mask"], "644": ["n03729826", "matchstick"], "645": ["n03733131", "maypole"], "646": ["n03733281", "maze"], "647": ["n03733805", "measuring_cup"], "648": ["n03742115", "medicine_chest"], "649": ["n03743016", "megalith"], "650": ["n03759954", "microphone"], "651": ["n03761084", "microwave"], "652": ["n03763968", "military_uniform"], "653": ["n03764736", "milk_can"], "654": ["n03769881", "minibus"], "655": ["n03770439", "miniskirt"], "656": ["n03770679", "minivan"], "657": ["n03773504", "missile"], "658": ["n03775071", "mitten"], "659": ["n03775546", "mixing_bowl"], "660": ["n03776460", "mobile_home"], "661": ["n03777568", "Model_T"], "662": ["n03777754", "modem"], "663": ["n03781244", "monastery"], "664": ["n03782006", "monitor"], "665": ["n03785016", "moped"], "666": ["n03786901", "mortar"], "667": ["n03787032", "mortarboard"], "668": ["n03788195", "mosque"], "669": ["n03788365", 
"mosquito_net"], "670": ["n03791053", "motor_scooter"], "671": ["n03792782", "mountain_bike"], "672": ["n03792972", "mountain_tent"], "673": ["n03793489", "mouse"], "674": ["n03794056", "mousetrap"], "675": ["n03796401", "moving_van"], "676": ["n03803284", "muzzle"], "677": ["n03804744", "nail"], "678": ["n03814639", "neck_brace"], "679": ["n03814906", "necklace"], "680": ["n03825788", "nipple"], "681": ["n03832673", "notebook"], "682": ["n03837869", "obelisk"], "683": ["n03838899", "oboe"], "684": ["n03840681", "ocarina"], "685": ["n03841143", "odometer"], "686": ["n03843555", "oil_filter"], "687": ["n03854065", "organ"], "688": ["n03857828", "oscilloscope"], "689": ["n03866082", "overskirt"], "690": ["n03868242", "oxcart"], "691": ["n03868863", "oxygen_mask"], "692": ["n03871628", "packet"], "693": ["n03873416", "paddle"], "694": ["n03874293", "paddlewheel"], "695": ["n03874599", "padlock"], "696": ["n03876231", "paintbrush"], "697": ["n03877472", "pajama"], "698": ["n03877845", "palace"], "699": ["n03884397", "panpipe"], "700": ["n03887697", "paper_towel"], "701": ["n03888257", "parachute"], "702": ["n03888605", "parallel_bars"], "703": ["n03891251", "park_bench"], "704": ["n03891332", "parking_meter"], "705": ["n03895866", "passenger_car"], "706": ["n03899768", "patio"], "707": ["n03902125", "pay-phone"], "708": ["n03903868", "pedestal"], "709": ["n03908618", "pencil_box"], "710": ["n03908714", "pencil_sharpener"], "711": ["n03916031", "perfume"], "712": ["n03920288", "Petri_dish"], "713": ["n03924679", "photocopier"], "714": ["n03929660", "pick"], "715": ["n03929855", "pickelhaube"], "716": ["n03930313", "picket_fence"], "717": ["n03930630", "pickup"], "718": ["n03933933", "pier"], "719": ["n03935335", "piggy_bank"], "720": ["n03937543", "pill_bottle"], "721": ["n03938244", "pillow"], "722": ["n03942813", "ping-pong_ball"], "723": ["n03944341", "pinwheel"], "724": ["n03947888", "pirate"], "725": ["n03950228", "pitcher"], "726": ["n03954731", "plane"], "727": 
["n03956157", "planetarium"], "728": ["n03958227", "plastic_bag"], "729": ["n03961711", "plate_rack"], "730": ["n03967562", "plow"], "731": ["n03970156", "plunger"], "732": ["n03976467", "Polaroid_camera"], "733": ["n03976657", "pole"], "734": ["n03977966", "police_van"], "735": ["n03980874", "poncho"], "736": ["n03982430", "pool_table"], "737": ["n03983396", "pop_bottle"], "738": ["n03991062", "pot"], "739": ["n03992509", "potter's_wheel"], "740": ["n03995372", "power_drill"], "741": ["n03998194", "prayer_rug"], "742": ["n04004767", "printer"], "743": ["n04005630", "prison"], "744": ["n04008634", "projectile"], "745": ["n04009552", "projector"], "746": ["n04019541", "puck"], "747": ["n04023962", "punching_bag"], "748": ["n04026417", "purse"], "749": ["n04033901", "quill"], "750": ["n04033995", "quilt"], "751": ["n04037443", "racer"], "752": ["n04039381", "racket"], "753": ["n04040759", "radiator"], "754": ["n04041544", "radio"], "755": ["n04044716", "radio_telescope"], "756": ["n04049303", "rain_barrel"], "757": ["n04065272", "recreational_vehicle"], "758": ["n04067472", "reel"], "759": ["n04069434", "reflex_camera"], "760": ["n04070727", "refrigerator"], "761": ["n04074963", "remote_control"], "762": ["n04081281", "restaurant"], "763": ["n04086273", "revolver"], "764": ["n04090263", "rifle"], "765": ["n04099969", "rocking_chair"], "766": ["n04111531", "rotisserie"], "767": ["n04116512", "rubber_eraser"], "768": ["n04118538", "rugby_ball"], "769": ["n04118776", "rule"], "770": ["n04120489", "running_shoe"], "771": ["n04125021", "safe"], "772": ["n04127249", "safety_pin"], "773": ["n04131690", "saltshaker"], "774": ["n04133789", "sandal"], "775": ["n04136333", "sarong"], "776": ["n04141076", "sax"], "777": ["n04141327", "scabbard"], "778": ["n04141975", "scale"], "779": ["n04146614", "school_bus"], "780": ["n04147183", "schooner"], "781": ["n04149813", "scoreboard"], "782": ["n04152593", "screen"], "783": ["n04153751", "screw"], "784": ["n04154565", "screwdriver"], 
"785": ["n04162706", "seat_belt"], "786": ["n04179913", "sewing_machine"], "787": ["n04192698", "shield"], "788": ["n04200800", "shoe_shop"], "789": ["n04201297", "shoji"], "790": ["n04204238", "shopping_basket"], "791": ["n04204347", "shopping_cart"], "792": ["n04208210", "shovel"], "793": ["n04209133", "shower_cap"], "794": ["n04209239", "shower_curtain"], "795": ["n04228054", "ski"], "796": ["n04229816", "ski_mask"], "797": ["n04235860", "sleeping_bag"], "798": ["n04238763", "slide_rule"], "799": ["n04239074", "sliding_door"], "800": ["n04243546", "slot"], "801": ["n04251144", "snorkel"], "802": ["n04252077", "snowmobile"], "803": ["n04252225", "snowplow"], "804": ["n04254120", "soap_dispenser"], "805": ["n04254680", "soccer_ball"], "806": ["n04254777", "sock"], "807": ["n04258138", "solar_dish"], "808": ["n04259630", "sombrero"], "809": ["n04263257", "soup_bowl"], "810": ["n04264628", "space_bar"], "811": ["n04265275", "space_heater"], "812": ["n04266014", "space_shuttle"], "813": ["n04270147", "spatula"], "814": ["n04273569", "speedboat"], "815": ["n04275548", "spider_web"], "816": ["n04277352", "spindle"], "817": ["n04285008", "sports_car"], "818": ["n04286575", "spotlight"], "819": ["n04296562", "stage"], "820": ["n04310018", "steam_locomotive"], "821": ["n04311004", "steel_arch_bridge"], "822": ["n04311174", "steel_drum"], "823": ["n04317175", "stethoscope"], "824": ["n04325704", "stole"], "825": ["n04326547", "stone_wall"], "826": ["n04328186", "stopwatch"], "827": ["n04330267", "stove"], "828": ["n04332243", "strainer"], "829": ["n04335435", "streetcar"], "830": ["n04336792", "stretcher"], "831": ["n04344873", "studio_couch"], "832": ["n04346328", "stupa"], "833": ["n04347754", "submarine"], "834": ["n04350905", "suit"], "835": ["n04355338", "sundial"], "836": ["n04355933", "sunglass"], "837": ["n04356056", "sunglasses"], "838": ["n04357314", "sunscreen"], "839": ["n04366367", "suspension_bridge"], "840": ["n04367480", "swab"], "841": ["n04370456", 
"sweatshirt"], "842": ["n04371430", "swimming_trunks"], "843": ["n04371774", "swing"], "844": ["n04372370", "switch"], "845": ["n04376876", "syringe"], "846": ["n04380533", "table_lamp"], "847": ["n04389033", "tank"], "848": ["n04392985", "tape_player"], "849": ["n04398044", "teapot"], "850": ["n04399382", "teddy"], "851": ["n04404412", "television"], "852": ["n04409515", "tennis_ball"], "853": ["n04417672", "thatch"], "854": ["n04418357", "theater_curtain"], "855": ["n04423845", "thimble"], "856": ["n04428191", "thresher"], "857": ["n04429376", "throne"], "858": ["n04435653", "tile_roof"], "859": ["n04442312", "toaster"], "860": ["n04443257", "tobacco_shop"], "861": ["n04447861", "toilet_seat"], "862": ["n04456115", "torch"], "863": ["n04458633", "totem_pole"], "864": ["n04461696", "tow_truck"], "865": ["n04462240", "toyshop"], "866": ["n04465501", "tractor"], "867": ["n04467665", "trailer_truck"], "868": ["n04476259", "tray"], "869": ["n04479046", "trench_coat"], "870": ["n04482393", "tricycle"], "871": ["n04483307", "trimaran"], "872": ["n04485082", "tripod"], "873": ["n04486054", "triumphal_arch"], "874": ["n04487081", "trolleybus"], "875": ["n04487394", "trombone"], "876": ["n04493381", "tub"], "877": ["n04501370", "turnstile"], "878": ["n04505470", "typewriter_keyboard"], "879": ["n04507155", "umbrella"], "880": ["n04509417", "unicycle"], "881": ["n04515003", "upright"], "882": ["n04517823", "vacuum"], "883": ["n04522168", "vase"], "884": ["n04523525", "vault"], "885": ["n04525038", "velvet"], "886": ["n04525305", "vending_machine"], "887": ["n04532106", "vestment"], "888": ["n04532670", "viaduct"], "889": ["n04536866", "violin"], "890": ["n04540053", "volleyball"], "891": ["n04542943", "waffle_iron"], "892": ["n04548280", "wall_clock"], "893": ["n04548362", "wallet"], "894": ["n04550184", "wardrobe"], "895": ["n04552348", "warplane"], "896": ["n04553703", "washbasin"], "897": ["n04554684", "washer"], "898": ["n04557648", "water_bottle"], "899": ["n04560804", 
"water_jug"], "900": ["n04562935", "water_tower"], "901": ["n04579145", "whiskey_jug"], "902": ["n04579432", "whistle"], "903": ["n04584207", "wig"], "904": ["n04589890", "window_screen"], "905": ["n04590129", "window_shade"], "906": ["n04591157", "Windsor_tie"], "907": ["n04591713", "wine_bottle"], "908": ["n04592741", "wing"], "909": ["n04596742", "wok"], "910": ["n04597913", "wooden_spoon"], "911": ["n04599235", "wool"], "912": ["n04604644", "worm_fence"], "913": ["n04606251", "wreck"], "914": ["n04612504", "yawl"], "915": ["n04613696", "yurt"], "916": ["n06359193", "web_site"], "917": ["n06596364", "comic_book"], "918": ["n06785654", "crossword_puzzle"], "919": ["n06794110", "street_sign"], "920": ["n06874185", "traffic_light"], "921": ["n07248320", "book_jacket"], "922": ["n07565083", "menu"], "923": ["n07579787", "plate"], "924": ["n07583066", "guacamole"], "925": ["n07584110", "consomme"], "926": ["n07590611", "hot_pot"], "927": ["n07613480", "trifle"], "928": ["n07614500", "ice_cream"], "929": ["n07615774", "ice_lolly"], "930": ["n07684084", "French_loaf"], "931": ["n07693725", "bagel"], "932": ["n07695742", "pretzel"], "933": ["n07697313", "cheeseburger"], "934": ["n07697537", "hotdog"], "935": ["n07711569", "mashed_potato"], "936": ["n07714571", "head_cabbage"], "937": ["n07714990", "broccoli"], "938": ["n07715103", "cauliflower"], "939": ["n07716358", "zucchini"], "940": ["n07716906", "spaghetti_squash"], "941": ["n07717410", "acorn_squash"], "942": ["n07717556", "butternut_squash"], "943": ["n07718472", "cucumber"], "944": ["n07718747", "artichoke"], "945": ["n07720875", "bell_pepper"], "946": ["n07730033", "cardoon"], "947": ["n07734744", "mushroom"], "948": ["n07742313", "Granny_Smith"], "949": ["n07745940", "strawberry"], "950": ["n07747607", "orange"], "951": ["n07749582", "lemon"], "952": ["n07753113", "fig"], "953": ["n07753275", "pineapple"], "954": ["n07753592", "banana"], "955": ["n07754684", "jackfruit"], "956": ["n07760859", "custard_apple"], 
"957": ["n07768694", "pomegranate"], "958": ["n07802026", "hay"], "959": ["n07831146", "carbonara"], "960": ["n07836838", "chocolate_sauce"], "961": ["n07860988", "dough"], "962": ["n07871810", "meat_loaf"], "963": ["n07873807", "pizza"], "964": ["n07875152", "potpie"], "965": ["n07880968", "burrito"], "966": ["n07892512", "red_wine"], "967": ["n07920052", "espresso"], "968": ["n07930864", "cup"], "969": ["n07932039", "eggnog"], "970": ["n09193705", "alp"], "971": ["n09229709", "bubble"], "972": ["n09246464", "cliff"], "973": ["n09256479", "coral_reef"], "974": ["n09288635", "geyser"], "975": ["n09332890", "lakeside"], "976": ["n09399592", "promontory"], "977": ["n09421951", "sandbar"], "978": ["n09428293", "seashore"], "979": ["n09468604", "valley"], "980": ["n09472597", "volcano"], "981": ["n09835506", "ballplayer"], "982": ["n10148035", "groom"], "983": ["n10565667", "scuba_diver"], "984": ["n11879895", "rapeseed"], "985": ["n11939491", "daisy"], "986": ["n12057211", "yellow_lady's_slipper"], "987": ["n12144580", "corn"], "988": ["n12267677", "acorn"], "989": ["n12620546", "hip"], "990": ["n12768682", "buckeye"], "991": ["n12985857", "coral_fungus"], "992": ["n12998815", "agaric"], "993": ["n13037406", "gyromitra"], "994": ["n13040303", "stinkhorn"], "995": ["n13044778", "earthstar"], "996": ["n13052670", "hen-of-the-woods"], "997": ["n13054560", "bolete"], "998": ["n13133613", "ear"], "999": ["n15075141", "toilet_tissue"]}
--------------------------------------------------------------------------------
/model-selector.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "model-selector",
7 | "provenance": [],
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "display_name": "Python 3 (ipykernel)",
12 | "language": "python",
13 | "name": "python3"
14 | },
15 | "language_info": {
16 | "codemirror_mode": {
17 | "name": "ipython",
18 | "version": 3
19 | },
20 | "file_extension": ".py",
21 | "mimetype": "text/x-python",
22 | "name": "python",
23 | "nbconvert_exporter": "python",
24 | "pygments_lexer": "ipython3",
25 | "version": "3.8.2"
26 | }
27 | },
28 | "cells": [
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "view-in-github",
33 | "colab_type": "text"
34 | },
35 | "source": [
36 | "
"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {
42 | "id": "yXeTsp2E7Wxq"
43 | },
44 | "source": [
45 | "## Reference\n",
46 | "\n",
47 | "* [vit_jax_augreg.ipynb](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "metadata": {
53 | "id": "lIYdn1woOS1n"
54 | },
55 | "source": [
56 | "import tensorflow as tf\n",
57 | "import pandas as pd"
58 | ],
59 | "execution_count": null,
60 | "outputs": []
61 | },
62 | {
63 | "cell_type": "code",
64 | "metadata": {
65 | "id": "KqgZLLHE7SC7"
66 | },
67 | "source": [
68 | "# Load master table from Cloud.\n",
69 | "with tf.io.gfile.GFile(\"gs://vit_models/augreg/index.csv\") as f:\n",
70 | " df = pd.read_csv(f)"
71 | ],
72 | "execution_count": null,
73 | "outputs": []
74 | },
75 | {
76 | "cell_type": "code",
77 | "metadata": {
78 | "colab": {
79 | "base_uri": "https://localhost:8080/"
80 | },
81 | "id": "bto9lWO17e4x",
82 | "outputId": "207ee223-8e25-4222-c3a7-999434ab91ef"
83 | },
84 | "source": [
85 | "df.columns"
86 | ],
87 | "execution_count": null,
88 | "outputs": [
89 | {
90 | "data": {
91 | "text/plain": [
92 | "Index(['name', 'ds', 'epochs', 'lr', 'aug', 'wd', 'do', 'sd', 'best_val',\n",
93 | " 'final_val', 'final_test', 'adapt_ds', 'adapt_lr', 'adapt_steps',\n",
94 | " 'adapt_resolution', 'adapt_final_val', 'adapt_final_test', 'params',\n",
95 | " 'infer_samples_per_sec', 'filename', 'adapt_filename'],\n",
96 | " dtype='object')"
97 | ]
98 | },
99 | "execution_count": 3,
100 | "metadata": {},
101 | "output_type": "execute_result"
102 | }
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "metadata": {
108 | "colab": {
109 | "base_uri": "https://localhost:8080/"
110 | },
111 | "id": "RdC3XNws74AY",
112 | "outputId": "59bca181-8985-4e0e-88b8-c9fbaf0a0c4a"
113 | },
114 | "source": [
115 | "# How many different pre-training datasets?\n",
116 | "df[\"ds\"].value_counts()"
117 | ],
118 | "execution_count": null,
119 | "outputs": [
120 | {
121 | "data": {
122 | "text/plain": [
123 | "i21k 17238\n",
124 | "i1k 17136\n",
125 | "i21k_30 17135\n",
126 | "Name: ds, dtype: int64"
127 | ]
128 | },
129 | "execution_count": 4,
130 | "metadata": {},
131 | "output_type": "execute_result"
132 | }
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {
138 | "id": "-_AWCq3d_G2I"
139 | },
140 | "source": [
141 | "Filter based on the following criteria:\n",
142 | "\n",
143 | "* Models should be pre-trained on ImageNet-21k and fine-tuned on ImageNet-1k.\n",
144 | "* The final ImageNet-1k validation accuracy should be at least 75%. \n",
145 | "* The transfer resolution should be 224 $\\times$ 224."
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "metadata": {
151 | "colab": {
152 | "base_uri": "https://localhost:8080/",
153 | "height": 632
154 | },
155 | "id": "K7Dmm_aU7l2P",
156 | "outputId": "e3c17dd9-2983-4fa5-f2b4-df9515e1f510"
157 | },
158 | "source": [
159 | "i21k_i1k_models = df.query(\"ds=='i21k' & adapt_ds=='imagenet2012'\")\n",
160 | "models_ge_75 = i21k_i1k_models.query(\"adapt_final_test >= 0.75 & adapt_resolution==224\")\n",
161 | "models_ge_75.head()"
162 | ],
163 | "execution_count": null,
164 | "outputs": [
165 | {
166 | "data": {
167 | "text/html": [
168 | "\n",
169 | "\n",
182 | "
\n",
183 | " \n",
184 | " \n",
185 | " | \n",
186 | " name | \n",
187 | " ds | \n",
188 | " epochs | \n",
189 | " lr | \n",
190 | " aug | \n",
191 | " wd | \n",
192 | " do | \n",
193 | " sd | \n",
194 | " best_val | \n",
195 | " final_val | \n",
196 | " ... | \n",
197 | " adapt_ds | \n",
198 | " adapt_lr | \n",
199 | " adapt_steps | \n",
200 | " adapt_resolution | \n",
201 | " adapt_final_val | \n",
202 | " adapt_final_test | \n",
203 | " params | \n",
204 | " infer_samples_per_sec | \n",
205 | " filename | \n",
206 | " adapt_filename | \n",
207 | "
\n",
208 | " \n",
209 | " \n",
210 | " \n",
211 | " 5508 | \n",
212 | " R26+S/32 | \n",
213 | " i21k | \n",
214 | " 300.0 | \n",
215 | " 0.001 | \n",
216 | " light0 | \n",
217 | " 0.03 | \n",
218 | " 0.0 | \n",
219 | " 0.0 | \n",
220 | " 0.465352 | \n",
221 | " 0.465049 | \n",
222 | " ... | \n",
223 | " imagenet2012 | \n",
224 | " 0.03 | \n",
225 | " 20000 | \n",
226 | " 224 | \n",
227 | " 0.858102 | \n",
228 | " 0.79954 | \n",
229 | " 36430000.0 | \n",
230 | " 1814.25 | \n",
231 | " R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... | \n",
232 | " R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... | \n",
233 | "
\n",
234 | " \n",
235 | " 5509 | \n",
236 | " R26+S/32 | \n",
237 | " i21k | \n",
238 | " 300.0 | \n",
239 | " 0.001 | \n",
240 | " light0 | \n",
241 | " 0.03 | \n",
242 | " 0.0 | \n",
243 | " 0.0 | \n",
244 | " 0.465352 | \n",
245 | " 0.465049 | \n",
246 | " ... | \n",
247 | " imagenet2012 | \n",
248 | " 0.01 | \n",
249 | " 20000 | \n",
250 | " 224 | \n",
251 | " 0.855370 | \n",
252 | " 0.80182 | \n",
253 | " 36430000.0 | \n",
254 | " 1814.25 | \n",
255 | " R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... | \n",
256 | " R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... | \n",
257 | "
\n",
258 | " \n",
259 | " 5576 | \n",
260 | " R26+S/32 | \n",
261 | " i21k | \n",
262 | " 300.0 | \n",
263 | " 0.001 | \n",
264 | " medium2 | \n",
265 | " 0.03 | \n",
266 | " 0.1 | \n",
267 | " 0.1 | \n",
268 | " 0.448574 | \n",
269 | " 0.448105 | \n",
270 | " ... | \n",
271 | " imagenet2012 | \n",
272 | " 0.03 | \n",
273 | " 20000 | \n",
274 | " 224 | \n",
275 | " 0.830393 | \n",
276 | " 0.79632 | \n",
277 | " 36430000.0 | \n",
278 | " 1814.25 | \n",
279 | " R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... | \n",
280 | " R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... | \n",
281 | "
\n",
282 | " \n",
283 | " 5577 | \n",
284 | " R26+S/32 | \n",
285 | " i21k | \n",
286 | " 300.0 | \n",
287 | " 0.001 | \n",
288 | " medium2 | \n",
289 | " 0.03 | \n",
290 | " 0.1 | \n",
291 | " 0.1 | \n",
292 | " 0.448574 | \n",
293 | " 0.448105 | \n",
294 | " ... | \n",
295 | " imagenet2012 | \n",
296 | " 0.01 | \n",
297 | " 20000 | \n",
298 | " 224 | \n",
299 | " 0.821574 | \n",
300 | " 0.78796 | \n",
301 | " 36430000.0 | \n",
302 | " 1814.25 | \n",
303 | " R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... | \n",
304 | " R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... | \n",
305 | "
\n",
306 | " \n",
307 | " 5644 | \n",
308 | " R26+S/32 | \n",
309 | " i21k | \n",
310 | " 300.0 | \n",
311 | " 0.001 | \n",
312 | " light0 | \n",
313 | " 0.10 | \n",
314 | " 0.0 | \n",
315 | " 0.0 | \n",
316 | " 0.478398 | \n",
317 | " 0.477715 | \n",
318 | " ... | \n",
319 | " imagenet2012 | \n",
320 | " 0.03 | \n",
321 | " 20000 | \n",
322 | " 224 | \n",
323 | " 0.852794 | \n",
324 | " 0.80074 | \n",
325 | " 36430000.0 | \n",
326 | " 1814.25 | \n",
327 | " R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.1... | \n",
328 | " R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.1... | \n",
329 | "
\n",
330 | " \n",
331 | "
\n",
332 | "
5 rows × 21 columns
\n",
333 | "
"
334 | ],
335 | "text/plain": [
336 | " name ds epochs lr aug wd do sd best_val \\\n",
337 | "5508 R26+S/32 i21k 300.0 0.001 light0 0.03 0.0 0.0 0.465352 \n",
338 | "5509 R26+S/32 i21k 300.0 0.001 light0 0.03 0.0 0.0 0.465352 \n",
339 | "5576 R26+S/32 i21k 300.0 0.001 medium2 0.03 0.1 0.1 0.448574 \n",
340 | "5577 R26+S/32 i21k 300.0 0.001 medium2 0.03 0.1 0.1 0.448574 \n",
341 | "5644 R26+S/32 i21k 300.0 0.001 light0 0.10 0.0 0.0 0.478398 \n",
342 | "\n",
343 | " final_val ... adapt_ds adapt_lr adapt_steps adapt_resolution \\\n",
344 | "5508 0.465049 ... imagenet2012 0.03 20000 224 \n",
345 | "5509 0.465049 ... imagenet2012 0.01 20000 224 \n",
346 | "5576 0.448105 ... imagenet2012 0.03 20000 224 \n",
347 | "5577 0.448105 ... imagenet2012 0.01 20000 224 \n",
348 | "5644 0.477715 ... imagenet2012 0.03 20000 224 \n",
349 | "\n",
350 | " adapt_final_val adapt_final_test params infer_samples_per_sec \\\n",
351 | "5508 0.858102 0.79954 36430000.0 1814.25 \n",
352 | "5509 0.855370 0.80182 36430000.0 1814.25 \n",
353 | "5576 0.830393 0.79632 36430000.0 1814.25 \n",
354 | "5577 0.821574 0.78796 36430000.0 1814.25 \n",
355 | "5644 0.852794 0.80074 36430000.0 1814.25 \n",
356 | "\n",
357 | " filename \\\n",
358 | "5508 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n",
359 | "5509 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n",
360 | "5576 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n",
361 | "5577 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n",
362 | "5644 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.1... \n",
363 | "\n",
364 | " adapt_filename \n",
365 | "5508 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n",
366 | "5509 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n",
367 | "5576 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n",
368 | "5577 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n",
369 | "5644 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.1... \n",
370 | "\n",
371 | "[5 rows x 21 columns]"
372 | ]
373 | },
374 | "execution_count": 7,
375 | "metadata": {},
376 | "output_type": "execute_result"
377 | }
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "metadata": {
383 | "colab": {
384 | "base_uri": "https://localhost:8080/"
385 | },
386 | "id": "dEmy40f3pBxT",
387 | "outputId": "66d07a15-e092-4edc-a5a0-02544a4860d7"
388 | },
389 | "source": [
390 | "models_ge_75[\"name\"].value_counts()"
391 | ],
392 | "execution_count": null,
393 | "outputs": [
394 | {
395 | "data": {
396 | "text/plain": [
397 | "R26+S/32 56\n",
398 | "S/16 54\n",
399 | "R50+L/32 54\n",
400 | "B/16 54\n",
401 | "B/32 53\n",
402 | "L/16 52\n",
403 | "B/8 6\n",
404 | "Name: name, dtype: int64"
405 | ]
406 | },
407 | "execution_count": 8,
408 | "metadata": {},
409 | "output_type": "execute_result"
410 | }
411 | ]
412 | },
413 | {
414 | "cell_type": "markdown",
415 | "metadata": {
416 | "id": "kOzlSwR-_G2K"
417 | },
418 | "source": [
419 | "Now, we first fetch the maximum accuracies with respect to a given model type and then we pick the underlying models. "
420 | ]
421 | },
422 | {
423 | "cell_type": "code",
424 | "metadata": {
425 | "colab": {
426 | "base_uri": "https://localhost:8080/",
427 | "height": 1000
428 | },
429 | "id": "F0CCk1yupOXC",
430 | "outputId": "47cd3859-fd71-471e-b4e6-b8a2a7bd33d7"
431 | },
432 | "source": [
433 | "best_scores_by_model_type = (\n",
434 | " models_ge_75.groupby(\"name\")[\"adapt_final_test\"].max().values\n",
435 | ")\n",
436 | "results = models_ge_75[\"adapt_final_test\"].apply(\n",
437 | " lambda x: x in best_scores_by_model_type\n",
438 | ")\n",
439 | "models_ge_75[results].sort_values(by=[\"adapt_final_test\"], ascending=False).head(10)"
440 | ],
441 | "execution_count": null,
442 | "outputs": [
443 | {
444 | "data": {
445 | "text/html": [
446 | "\n",
447 | "\n",
460 | "
\n",
461 | " \n",
462 | " \n",
463 | " | \n",
464 | " name | \n",
465 | " ds | \n",
466 | " epochs | \n",
467 | " lr | \n",
468 | " aug | \n",
469 | " wd | \n",
470 | " do | \n",
471 | " sd | \n",
472 | " best_val | \n",
473 | " final_val | \n",
474 | " ... | \n",
475 | " adapt_ds | \n",
476 | " adapt_lr | \n",
477 | " adapt_steps | \n",
478 | " adapt_resolution | \n",
479 | " adapt_final_val | \n",
480 | " adapt_final_test | \n",
481 | " params | \n",
482 | " infer_samples_per_sec | \n",
483 | " filename | \n",
484 | " adapt_filename | \n",
485 | "
\n",
486 | " \n",
487 | " \n",
488 | " \n",
489 | " 50966 | \n",
490 | " B/8 | \n",
491 | " i21k | \n",
492 | " 300.0 | \n",
493 | " 0.001 | \n",
494 | " medium2 | \n",
495 | " 0.10 | \n",
496 | " 0.0 | \n",
497 | " 0.0 | \n",
498 | " 0.521426 | \n",
499 | " 0.521006 | \n",
500 | " ... | \n",
501 | " imagenet2012 | \n",
502 | " 0.01 | \n",
503 | " 20000 | \n",
504 | " 224 | \n",
505 | " 0.891742 | \n",
506 | " 0.85948 | \n",
507 | " NaN | \n",
508 | " NaN | \n",
509 | " B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_... | \n",
510 | " B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_... | \n",
511 | "
\n",
512 | " \n",
513 | " 26011 | \n",
514 | " L/16 | \n",
515 | " i21k | \n",
516 | " 300.0 | \n",
517 | " 0.001 | \n",
518 | " medium1 | \n",
519 | " 0.10 | \n",
520 | " 0.1 | \n",
521 | " 0.1 | \n",
522 | " 0.512275 | \n",
523 | " 0.512275 | \n",
524 | " ... | \n",
525 | " imagenet2012 | \n",
526 | " 0.01 | \n",
527 | " 20000 | \n",
528 | " 224 | \n",
529 | " 0.901108 | \n",
530 | " 0.85716 | \n",
531 | " 304330000.0 | \n",
532 | " 228.01 | \n",
533 | " L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do... | \n",
534 | " L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do... | \n",
535 | "
\n",
536 | " \n",
537 | " 24854 | \n",
538 | " B/16 | \n",
539 | " i21k | \n",
540 | " 300.0 | \n",
541 | " 0.001 | \n",
542 | " medium2 | \n",
543 | " 0.03 | \n",
544 | " 0.0 | \n",
545 | " 0.0 | \n",
546 | " 0.504258 | \n",
547 | " 0.503623 | \n",
548 | " ... | \n",
549 | " imagenet2012 | \n",
550 | " 0.03 | \n",
551 | " 20000 | \n",
552 | " 224 | \n",
553 | " 0.882454 | \n",
554 | " 0.84018 | \n",
555 | " 86570000.0 | \n",
556 | " 658.56 | \n",
557 | " B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-d... | \n",
558 | " B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-d... | \n",
559 | "
\n",
560 | " \n",
561 | " 26861 | \n",
562 | " R50+L/32 | \n",
563 | " i21k | \n",
564 | " 300.0 | \n",
565 | " 0.001 | \n",
566 | " medium1 | \n",
567 | " 0.10 | \n",
568 | " 0.1 | \n",
569 | " 0.1 | \n",
570 | " 0.514453 | \n",
571 | " 0.513877 | \n",
572 | " ... | \n",
573 | " imagenet2012 | \n",
574 | " 0.01 | \n",
575 | " 20000 | \n",
576 | " 224 | \n",
577 | " 0.897908 | \n",
578 | " 0.83784 | \n",
579 | " 110950000.0 | \n",
580 | " 1046.83 | \n",
581 | " R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.... | \n",
582 | " R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.... | \n",
583 | "
\n",
584 | " \n",
585 | " 6392 | \n",
586 | " R26+S/32 | \n",
587 | " i21k | \n",
588 | " 300.0 | \n",
589 | " 0.001 | \n",
590 | " light0 | \n",
591 | " 0.03 | \n",
592 | " 0.1 | \n",
593 | " 0.1 | \n",
594 | " 0.477891 | \n",
595 | " 0.477471 | \n",
596 | " ... | \n",
597 | " imagenet2012 | \n",
598 | " 0.03 | \n",
599 | " 20000 | \n",
600 | " 224 | \n",
601 | " 0.850062 | \n",
602 | " 0.80944 | \n",
603 | " 36430000.0 | \n",
604 | " 1814.25 | \n",
605 | " R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... | \n",
606 | " R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... | \n",
607 | "
\n",
608 | " \n",
609 | " 5849 | \n",
610 | " R26+S/32 | \n",
611 | " i21k | \n",
612 | " 300.0 | \n",
613 | " 0.001 | \n",
614 | " medium2 | \n",
615 | " 0.10 | \n",
616 | " 0.0 | \n",
617 | " 0.0 | \n",
618 | " 0.462588 | \n",
619 | " 0.462373 | \n",
620 | " ... | \n",
621 | " imagenet2012 | \n",
622 | " 0.01 | \n",
623 | " 20000 | \n",
624 | " 224 | \n",
625 | " 0.837964 | \n",
626 | " 0.80462 | \n",
627 | " 36430000.0 | \n",
628 | " 1814.25 | \n",
629 | " R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... | \n",
630 | " R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... | \n",
631 | "
\n",
632 | " \n",
633 | " 12410 | \n",
634 | " S/16 | \n",
635 | " i21k | \n",
636 | " 300.0 | \n",
637 | " 0.001 | \n",
638 | " light1 | \n",
639 | " 0.03 | \n",
640 | " 0.0 | \n",
641 | " 0.0 | \n",
642 | " 0.472676 | \n",
643 | " 0.472402 | \n",
644 | " ... | \n",
645 | " imagenet2012 | \n",
646 | " 0.03 | \n",
647 | " 20000 | \n",
648 | " 224 | \n",
649 | " 0.841321 | \n",
650 | " 0.80462 | \n",
651 | " 22050000.0 | \n",
652 | " 1508.35 | \n",
653 | " S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do... | \n",
654 | " S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do... | \n",
655 | "
\n",
656 | " \n",
657 | " 22610 | \n",
658 | " B/32 | \n",
659 | " i21k | \n",
660 | " 300.0 | \n",
661 | " 0.001 | \n",
662 | " medium1 | \n",
663 | " 0.03 | \n",
664 | " 0.0 | \n",
665 | " 0.0 | \n",
666 | " 0.473789 | \n",
667 | " 0.473525 | \n",
668 | " ... | \n",
669 | " imagenet2012 | \n",
670 | " 0.03 | \n",
671 | " 20000 | \n",
672 | " 224 | \n",
673 | " 0.844131 | \n",
674 | " 0.79436 | \n",
675 | " 88220000.0 | \n",
676 | " 3597.19 | \n",
677 | " B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-d... | \n",
678 | " B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-d... | \n",
679 | "
\n",
680 | " \n",
681 | "
\n",
682 | "
8 rows × 21 columns
\n",
683 | "
"
684 | ],
685 | "text/plain": [
686 | " name ds epochs lr aug wd do sd best_val \\\n",
687 | "50966 B/8 i21k 300.0 0.001 medium2 0.10 0.0 0.0 0.521426 \n",
688 | "26011 L/16 i21k 300.0 0.001 medium1 0.10 0.1 0.1 0.512275 \n",
689 | "24854 B/16 i21k 300.0 0.001 medium2 0.03 0.0 0.0 0.504258 \n",
690 | "26861 R50+L/32 i21k 300.0 0.001 medium1 0.10 0.1 0.1 0.514453 \n",
691 | "6392 R26+S/32 i21k 300.0 0.001 light0 0.03 0.1 0.1 0.477891 \n",
692 | "5849 R26+S/32 i21k 300.0 0.001 medium2 0.10 0.0 0.0 0.462588 \n",
693 | "12410 S/16 i21k 300.0 0.001 light1 0.03 0.0 0.0 0.472676 \n",
694 | "22610 B/32 i21k 300.0 0.001 medium1 0.03 0.0 0.0 0.473789 \n",
695 | "\n",
696 | " final_val ... adapt_ds adapt_lr adapt_steps adapt_resolution \\\n",
697 | "50966 0.521006 ... imagenet2012 0.01 20000 224 \n",
698 | "26011 0.512275 ... imagenet2012 0.01 20000 224 \n",
699 | "24854 0.503623 ... imagenet2012 0.03 20000 224 \n",
700 | "26861 0.513877 ... imagenet2012 0.01 20000 224 \n",
701 | "6392 0.477471 ... imagenet2012 0.03 20000 224 \n",
702 | "5849 0.462373 ... imagenet2012 0.01 20000 224 \n",
703 | "12410 0.472402 ... imagenet2012 0.03 20000 224 \n",
704 | "22610 0.473525 ... imagenet2012 0.03 20000 224 \n",
705 | "\n",
706 | " adapt_final_val adapt_final_test params infer_samples_per_sec \\\n",
707 | "50966 0.891742 0.85948 NaN NaN \n",
708 | "26011 0.901108 0.85716 304330000.0 228.01 \n",
709 | "24854 0.882454 0.84018 86570000.0 658.56 \n",
710 | "26861 0.897908 0.83784 110950000.0 1046.83 \n",
711 | "6392 0.850062 0.80944 36430000.0 1814.25 \n",
712 | "5849 0.837964 0.80462 36430000.0 1814.25 \n",
713 | "12410 0.841321 0.80462 22050000.0 1508.35 \n",
714 | "22610 0.844131 0.79436 88220000.0 3597.19 \n",
715 | "\n",
716 | " filename \\\n",
717 | "50966 B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_... \n",
718 | "26011 L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do... \n",
719 | "24854 B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-d... \n",
720 | "26861 R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.... \n",
721 | "6392 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n",
722 | "5849 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n",
723 | "12410 S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do... \n",
724 | "22610 B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-d... \n",
725 | "\n",
726 | " adapt_filename \n",
727 | "50966 B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_... \n",
728 | "26011 L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do... \n",
729 | "24854 B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-d... \n",
730 | "26861 R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.... \n",
731 | "6392 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n",
732 | "5849 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n",
733 | "12410 S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do... \n",
734 | "22610 B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-d... \n",
735 | "\n",
736 | "[8 rows x 21 columns]"
737 | ]
738 | },
739 | "execution_count": 9,
740 | "metadata": {},
741 | "output_type": "execute_result"
742 | }
743 | ]
744 | },
745 | {
746 | "cell_type": "code",
747 | "metadata": {
748 | "colab": {
749 | "base_uri": "https://localhost:8080/"
750 | },
751 | "id": "kryppwagsWMg",
752 | "outputId": "c4469b79-e23d-4784-b02b-62ad67fb182b"
753 | },
754 | "source": [
755 | "models_ge_75[results].sort_values(by=[\"adapt_final_test\"], ascending=False).head(10)[\n",
756 | " \"adapt_filename\"\n",
757 | "].values.tolist()"
758 | ],
759 | "execution_count": null,
760 | "outputs": [
761 | {
762 | "data": {
763 | "text/plain": [
764 | "['B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224',\n",
765 | " 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224',\n",
766 | " 'B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224',\n",
767 | " 'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224',\n",
768 | " 'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224',\n",
769 | " 'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224',\n",
770 | " 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224',\n",
771 | " 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224']"
772 | ]
773 | },
774 | "execution_count": 10,
775 | "metadata": {},
776 | "output_type": "execute_result"
777 | }
778 | ]
779 | }
780 | ]
781 | }
--------------------------------------------------------------------------------