├── .gitignore ├── LICENSE ├── README.md ├── classification.ipynb ├── conversion.ipynb ├── fine_tune.ipynb ├── i1k_eval ├── README.md ├── eval.ipynb └── imagenet_class_index.json └── model-selector.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ViT-jax2tf 2 | 3 |

4 |
5 | Example usage. 6 |

7 | 8 | This repository hosts code for converting the original Vision Transformer models [1] (JAX) to 9 | TensorFlow. 10 | 11 | The original models were fine-tuned on the ImageNet-1k dataset [2]. For more details 12 | on the training protocols, please follow [3]. The authors of [3] open-sourced about 13 | **50k different variants of Vision Transformer models** in JAX. Using the 14 | [`conversion.ipynb`](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/conversion.ipynb) 15 | notebook, one should be able to take a model from the pool of models and convert that 16 | to TensorFlow and use that with TensorFlow Hub and Keras. 17 | 18 | The original model classes and weights [4] were converted using the `jax2tf` tool [5]. 19 | 20 | **Note that it's a requirement to use TensorFlow 2.6 or greater to use the converted models.** 21 | 22 | ## Vision Transformers on TensorFlow Hub 23 | 24 | Find the model collection on TensorFlow Hub: https://tfhub.dev/sayakpaul/collections/vision_transformer/1. 25 | 26 | Eight best performing ImageNet-1k models have also been made available on TensorFlow 27 | Hub that can be used either for off-the-shelf image classification or transfer learning. 28 | Please follow the [`model-selector.ipynb`](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/model-selector.ipynb) 29 | notebook to understand how these models were chosen. 
30 | 31 | The table below provides a performance summary: 32 | 33 | | **Model** | **Top-1 Accuracy** | **Checkpoint** | **Misc** | 34 | |:---:|:---:|:---:|:---:| 35 | | B/8 | 85.948 | B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz | | 36 | | L/16 | 85.716 | L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz | | 37 | | B/16 | 84.018 | B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | | 38 | | R50-L/32 | 83.784 | R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz | | 39 | | R26-S/32 (light aug) | 80.944 | R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz | [tb.dev run](https://tensorboard.dev/experiment/8rjW26CoRJWdAR3ejtgvHQ/) | 40 | | R26-S/32 (medium aug) | 80.462 | R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz | | 41 | | S/16 | 80.462 | S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | [tb.dev run](https://tensorboard.dev/experiment/52LkVYfnQDykgyDHmWjzBA/) | 42 | | B/32 | 79.436 | B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz | | 43 | 44 | Note that the top-1 accuracy is reported on ImageNet-1k validation set. The checkpoints are present in the following GCS 45 | location: `gs://vit_models/augreg`. More details on these can be found in [4]. 
46 | 47 | ### Image classifiers 48 | 49 | * [ViT-S16](https://tfhub.dev/sayakpaul/vit_s16_classification/1) 50 | * [ViT-B8](https://tfhub.dev/sayakpaul/vit_b8_classification/1) 51 | * [ViT-B16](https://tfhub.dev/sayakpaul/vit_b16_classification/1) 52 | * [ViT-B32](https://tfhub.dev/sayakpaul/vit_b32_classification/1) 53 | * [ViT-L16](https://tfhub.dev/sayakpaul/vit_l16_classification/1) 54 | * [ViT-R26-S32 (light augmentation)](https://tfhub.dev/sayakpaul/vit_r26_s32_lightaug_classification/1) 55 | * [ViT-R26-S32 (medium augmentation)](https://tfhub.dev/sayakpaul/vit_r26_s32_medaug_classification/1) 56 | * [ViT-R50-L32](https://tfhub.dev/sayakpaul/vit_r50_l32_classification/1) 57 | 58 | ### Feature extractors 59 | 60 | * [ViT-S16](https://tfhub.dev/sayakpaul/vit_s16_fe/1) 61 | * [ViT-B8](https://tfhub.dev/sayakpaul/vit_b8_fe/1) 62 | * [ViT-B16](https://tfhub.dev/sayakpaul/vit_b16_fe/1) 63 | * [ViT-B32](https://tfhub.dev/sayakpaul/vit_b32_fe/1) 64 | * [ViT-L16](https://tfhub.dev/sayakpaul/vit_l16_fe/1) 65 | * [ViT-R26-S32 (light augmentation)](https://tfhub.dev/sayakpaul/vit_r26_s32_lightaug_fe/1) 66 | * [ViT-R26-S32 (medium augmentation)](https://tfhub.dev/sayakpaul/vit_r26_s32_medaug_fe/1) 67 | * [ViT-R50-L32](https://tfhub.dev/sayakpaul/vit_r50_l32_fe/1) 68 | 69 | ## Other notebooks 70 | 71 | * [`classification.ipynb`](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/classification.ipynb): Shows how to load a Vision Transformer model from TensorFlow Hub 72 | and run image classification. 73 | * [`fine_tune.ipynb`](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/fine_tune.ipynb): Shows how to 74 | fine-tune a Vision Transformer model from TensorFlow Hub on the `tf_flowers` dataset. 75 | 76 | Additionally, [`i1k_eval`](https://github.com/sayakpaul/ViT-jax2tf/tree/main/i1k_eval) contains files for running 77 | evaluation on ImageNet-1k `validation` split. 
78 | 79 | ## References 80 | 81 | [1] [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale by Dosovitskiy et al.](https://arxiv.org/abs/2010.11929) 82 | 83 | [2] [ImageNet-1k](https://www.image-net.org/challenges/LSVRC/2012/index.php) 84 | 85 | [3] [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers by Steiner et al.](https://arxiv.org/abs/2106.10270) 86 | 87 | [4] [Vision Transformer GitHub](https://github.com/google-research/vision_transformer) 88 | 89 | [5] [jax2tf tool](https://github.com/google/jax/tree/main/jax/experimental/jax2tf/) 90 | 91 | ## Acknowledgements 92 | 93 | Thanks to the authors of Vision Transformers for their efforts put into open-sourcing 94 | the models. 95 | 96 | Thanks to the [ML-GDE program](https://developers.google.com/programs/experts/) for providing GCP Credit support 97 | that helped me execute the experiments for this project. 98 | -------------------------------------------------------------------------------- /classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "classification", 8 | "provenance": [], 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3 (ipykernel)", 13 | "language": "python", 14 | "name": "python3" 15 | }, 16 | "language_info": { 17 | "codemirror_mode": { 18 | "name": "ipython", 19 | "version": 3 20 | }, 21 | "file_extension": ".py", 22 | "mimetype": "text/x-python", 23 | "name": "python", 24 | "nbconvert_exporter": "python", 25 | "pygments_lexer": "ipython3", 26 | "version": "3.8.2" 27 | } 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "view-in-github", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "\"Open" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "id": 
"ls5qlUELNIT7" 44 | }, 45 | "source": [ 46 | "## Imports" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "metadata": { 52 | "id": "lIYdn1woOS1n" 53 | }, 54 | "source": [ 55 | "import tensorflow as tf\n", 56 | "import tensorflow_hub as hub\n", 57 | "\n", 58 | "from PIL import Image\n", 59 | "from io import BytesIO\n", 60 | "import matplotlib.pyplot as plt\n", 61 | "import numpy as np\n", 62 | "import requests" 63 | ], 64 | "execution_count": null, 65 | "outputs": [] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": { 70 | "id": "hyQ87VPcNIUA" 71 | }, 72 | "source": [ 73 | "## Image preprocessing utilities (credits: [Willi Gierke](https://ch.linkedin.com/in/willi-gierke))" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "metadata": { 79 | "id": "HZZC3qJcLQyM" 80 | }, 81 | "source": [ 82 | "def preprocess_image(image):\n", 83 | " image = np.array(image)\n", 84 | " image_resized = tf.image.resize(image, (224, 224))\n", 85 | " image_resized = tf.cast(image_resized, tf.float32)\n", 86 | " image_resized = (image_resized - 127.5) / 127.5\n", 87 | " return tf.expand_dims(image_resized, 0).numpy()\n", 88 | "\n", 89 | "def load_image_from_url(url):\n", 90 | " response = requests.get(url)\n", 91 | " image = Image.open(BytesIO(response.content))\n", 92 | " image = preprocess_image(image)\n", 93 | " return image\n", 94 | "\n", 95 | "!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt" 96 | ], 97 | "execution_count": null, 98 | "outputs": [] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": { 103 | "id": "YNj3BnemNIUB" 104 | }, 105 | "source": [ 106 | "## Load image and infer" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "metadata": { 112 | "id": "LPh1RSMXLdVD" 113 | }, 114 | "source": [ 115 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n", 116 | " lines = f.readlines()\n", 117 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n", 118 | "\n", 119 | "img_url = 
\"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n", 120 | "image = load_image_from_url(img_url)\n", 121 | "\n", 122 | "plt.imshow((image[0] + 1) / 2)\n", 123 | "plt.show()" 124 | ], 125 | "execution_count": null, 126 | "outputs": [] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "metadata": { 131 | "id": "uOo3vYhqLu3Y" 132 | }, 133 | "source": [ 134 | "model_url = \"https://tfhub.dev/sayakpaul/vit_s16_classification/1\"\n", 135 | "\n", 136 | "classification_model = tf.keras.Sequential(\n", 137 | " [hub.KerasLayer(model_url)]\n", 138 | ") \n", 139 | "predictions = classification_model.predict(image)\n", 140 | "predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]\n", 141 | "predicted_label" 142 | ], 143 | "execution_count": null, 144 | "outputs": [] 145 | } 146 | ] 147 | } -------------------------------------------------------------------------------- /conversion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "conversion", 7 | "provenance": [], 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "display_name": "Python 3 (ipykernel)", 12 | "language": "python", 13 | "name": "python3" 14 | }, 15 | "language_info": { 16 | "codemirror_mode": { 17 | "name": "ipython", 18 | "version": 3 19 | }, 20 | "file_extension": ".py", 21 | "mimetype": "text/x-python", 22 | "name": "python", 23 | "nbconvert_exporter": "python", 24 | "pygments_lexer": "ipython3", 25 | "version": "3.8.2" 26 | } 27 | }, 28 | "cells": [ 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "view-in-github", 33 | "colab_type": "text" 34 | }, 35 | "source": [ 36 | "\"Open" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "id": "J8J1g5vBT5aj" 43 | }, 44 | "source": [ 45 | "## References\n", 46 | "\n", 47 | "* 
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md\n", 48 | "* https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": { 54 | "id": "piv05HW04aUW" 55 | }, 56 | "source": [ 57 | "## Setup" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "metadata": { 63 | "id": "tLVMT01KScv5" 64 | }, 65 | "source": [ 66 | "!pip install -q absl-py>=0.12.0 chex>=0.0.7 clu>=0.0.3 einops>=0.3.0\n", 67 | "!pip install -q flax==0.3.3 ml-collections==0.1.0 tf-nightly\n", 68 | "!pip install -q numpy>=1.19.5 pandas>=1.1.0" 69 | ], 70 | "execution_count": null, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "metadata": { 76 | "id": "lIYdn1woOS1n" 77 | }, 78 | "source": [ 79 | "# Clone repository and pull latest changes.\n", 80 | "![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer\n", 81 | "!cd vision_transformer && git pull" 82 | ], 83 | "execution_count": null, 84 | "outputs": [] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": { 89 | "id": "CwBrIdAE4ciM" 90 | }, 91 | "source": [ 92 | "## Imports" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "metadata": { 98 | "id": "aWuEnEshSdzt" 99 | }, 100 | "source": [ 101 | "import sys\n", 102 | "\n", 103 | "if \"./vision_transformer\" not in sys.path:\n", 104 | " sys.path.append(\"./vision_transformer\")\n", 105 | "\n", 106 | "from vit_jax import models\n", 107 | "from vit_jax import checkpoint\n", 108 | "from vit_jax.configs import common as common_config\n", 109 | "from vit_jax.configs import models as models_config\n", 110 | "\n", 111 | "from jax.experimental import jax2tf\n", 112 | "import tensorflow as tf\n", 113 | "import flax\n", 114 | "import jax\n", 115 | "\n", 116 | "from PIL import Image\n", 117 | "from io import BytesIO\n", 118 | "import numpy as np\n", 119 | "import requests" 120 | ], 121 | "execution_count": null, 122 | 
"outputs": [] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "metadata": { 127 | "id": "RFV5xTgW28ys" 128 | }, 129 | "source": [ 130 | "print(f\"JAX version: {jax.__version__}\")\n", 131 | "print(f\"FLAX version: {flax.__version__}\")\n", 132 | "print(f\"TensorFlow version: {tf.__version__}\")" 133 | ], 134 | "execution_count": null, 135 | "outputs": [] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": { 140 | "id": "Vt5uGYJH3LXM" 141 | }, 142 | "source": [ 143 | "## Classification / Feature Extractor model" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "metadata": { 149 | "id": "hKh0k1M7SgSN" 150 | }, 151 | "source": [ 152 | "#@title Choose a model type\n", 153 | "VIT_MODELS = \"B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224\" #@param [\"L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224\", \"B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224\", \"R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224\", \"R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224\", \"R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224\", \"S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224\", \"B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224\", \"B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224\"]\n", 154 | "#@markdown The models were selected based on the criteria shown here in [this notebook](https://github.com/sayakpaul/ViT-jax2tf/blob/main/model-selector.ipynb).\n", 155 | "\n", 156 | "print(f\"Model type selected: ViT-{VIT_MODELS.split('-')[0]}\")\n", 157 | "\n", 158 | "ROOT_GCS_PATH = \"gs://vit_models/augreg\"" 159 | ], 
160 | "execution_count": null, 161 | "outputs": [] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "bjlhNR62-JJJ" 167 | }, 168 | "source": [ 169 | "classification_model = True\n", 170 | "\n", 171 | "if classification_model:\n", 172 | " num_classes = 1000\n", 173 | " print(\"Will be converting a classification model.\")\n", 174 | "else:\n", 175 | " num_classes = None\n", 176 | " print(\"Will be converting a feature extraction model.\")" 177 | ], 178 | "execution_count": null, 179 | "outputs": [] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "metadata": { 184 | "id": "mclDoMqbShzV" 185 | }, 186 | "source": [ 187 | "# Instantiate model class and load the corresponding checkpoints.\n", 188 | "config = common_config.get_config()\n", 189 | "config.model = models_config.AUGREG_CONFIGS[f\"{VIT_MODELS.split('-')[0]}\"]\n", 190 | "\n", 191 | "model = models.VisionTransformer(num_classes=num_classes, **config.model)\n", 192 | "\n", 193 | "path = f\"{ROOT_GCS_PATH}/{VIT_MODELS}.npz\"\n", 194 | "params = checkpoint.load(path)\n", 195 | "\n", 196 | "if not num_classes:\n", 197 | " _ = params.pop(\"head\")" 198 | ], 199 | "execution_count": null, 200 | "outputs": [] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": { 205 | "id": "RqbuBCFw9vyg" 206 | }, 207 | "source": [ 208 | "## Conversion\n", 209 | "\n", 210 | "Code has been reused from the official examples [here](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md)." 
211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": { 216 | "id": "XT2GLwXg95tE" 217 | }, 218 | "source": [ 219 | "### Step 1: Get a prediction function out of the JAX model & convert it to a native TF function" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "metadata": { 225 | "id": "J2e1h-F2SmDB" 226 | }, 227 | "source": [ 228 | "predict_fn = lambda params, inputs: model.apply(\n", 229 | " dict(params=params), inputs, train=False\n", 230 | ")\n", 231 | "\n", 232 | "with_gradient = False if num_classes else True\n", 233 | "tf_fn = jax2tf.convert(\n", 234 | " predict_fn,\n", 235 | " with_gradient=with_gradient,\n", 236 | " polymorphic_shapes=[None, \"b, 224, 224, 3\"],\n", 237 | ")" 238 | ], 239 | "execution_count": null, 240 | "outputs": [] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": { 245 | "id": "cRGAjnKBRGgU" 246 | }, 247 | "source": [ 248 | "We set `polymorphic_shapes` to allow the converted model to operate with arbitrary batch sizes. Learn more about shape polymorphism in JAX from [here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion)." 
249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": { 254 | "id": "PKE1msyx-3ge" 255 | }, 256 | "source": [ 257 | "### Step 2: Set the trainability of the individual param groups and construct TF graph" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "metadata": { 263 | "id": "8RNlRp9pTHgF" 264 | }, 265 | "source": [ 266 | "param_vars = tf.nest.map_structure(\n", 267 | " lambda param: tf.Variable(param, trainable=with_gradient), params\n", 268 | ")\n", 269 | "tf_graph = tf.function(\n", 270 | " lambda inputs: tf_fn(param_vars, inputs), autograph=False, jit_compile=True\n", 271 | ")" 272 | ], 273 | "execution_count": null, 274 | "outputs": [] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": { 279 | "id": "3fDRubHD_Sf3" 280 | }, 281 | "source": [ 282 | "### Step 3: Serialize as a SavedModel" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "metadata": { 288 | "cellView": "form", 289 | "id": "1QJQwDEyTs2V" 290 | }, 291 | "source": [ 292 | "#@title SavedModel wrapper class utility from [here](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py#L128)\n", 293 | "class _ReusableSavedModelWrapper(tf.train.Checkpoint):\n", 294 | " \"\"\"Wraps a function and its parameters for saving to a SavedModel.\n", 295 | " Implements the interface described at\n", 296 | " https://www.tensorflow.org/hub/reusable_saved_models.\n", 297 | " \"\"\"\n", 298 | "\n", 299 | " def __init__(self, tf_graph, param_vars):\n", 300 | " \"\"\"Args:\n", 301 | " tf_graph: a tf.function taking one argument (the inputs), which can be\n", 302 | " be tuples/lists/dictionaries of np.ndarray or tensors. 
The function\n", 303 | " may have references to the tf.Variables in `param_vars`.\n", 304 | " param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,\n", 305 | " to be saved as the variables of the SavedModel.\n", 306 | " \"\"\"\n", 307 | " super().__init__()\n", 308 | " # Implement the interface from https://www.tensorflow.org/hub/reusable_saved_models\n", 309 | " self.variables = tf.nest.flatten(param_vars)\n", 310 | " self.trainable_variables = [v for v in self.variables if v.trainable]\n", 311 | " # If you intend to prescribe regularization terms for users of the model,\n", 312 | " # add them as @tf.functions with no inputs to this list. Else drop this.\n", 313 | " self.regularization_losses = []\n", 314 | " self.__call__ = tf_graph\n" 315 | ], 316 | "execution_count": null, 317 | "outputs": [] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "metadata": { 322 | "id": "xr2Vf9Ql_lca" 323 | }, 324 | "source": [ 325 | "input_signatures = [tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.float32)]\n", 326 | "model_dir = VIT_MODELS if num_classes else f\"{VIT_MODELS}_fe\"\n", 327 | "signatures = {}\n", 328 | "saved_model_options = None\n", 329 | "\n", 330 | "print(f\"Saving model to {model_dir} directory.\")" 331 | ], 332 | "execution_count": null, 333 | "outputs": [] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "metadata": { 338 | "id": "pMn9fJxuTKON" 339 | }, 340 | "source": [ 341 | "signatures[\n", 342 | " tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n", 343 | "] = tf_graph.get_concrete_function(input_signatures[0])\n", 344 | "\n", 345 | "wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)\n", 346 | "if with_gradient:\n", 347 | " if not saved_model_options:\n", 348 | " saved_model_options = tf.saved_model.SaveOptions(\n", 349 | " experimental_custom_gradients=True\n", 350 | " )\n", 351 | " else:\n", 352 | " saved_model_options.experimental_custom_gradients = True\n", 353 | "tf.saved_model.save(\n", 354 | " wrapper, model_dir, 
signatures=signatures, options=saved_model_options\n", 355 | ")\n", 356 | "\n", 357 | "# Note that directly saving the `wrapper` to a GCS location is\n", 358 | "# also supported." 359 | ], 360 | "execution_count": null, 361 | "outputs": [] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": { 366 | "id": "2PJr-uVs_vz-" 367 | }, 368 | "source": [ 369 | "## Functional test (credits: [Willi Gierke](https://ch.linkedin.com/in/willi-gierke))" 370 | ] 371 | }, 372 | { 373 | "cell_type": "markdown", 374 | "metadata": { 375 | "id": "NA2G4HzvC5_l" 376 | }, 377 | "source": [ 378 | "### Image preprocessing utilities " 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "metadata": { 384 | "id": "XyvjkBAE5iFL" 385 | }, 386 | "source": [ 387 | "def preprocess_image(image):\n", 388 | " image = np.array(image)\n", 389 | " image_resized = tf.image.resize(image, (224, 224))\n", 390 | " image_resized = tf.cast(image_resized, tf.float32)\n", 391 | " image_resized = (image_resized - 127.5) / 127.5\n", 392 | " return tf.expand_dims(image_resized, 0).numpy()\n", 393 | "\n", 394 | "def load_image_from_url(url):\n", 395 | " response = requests.get(url)\n", 396 | " image = Image.open(BytesIO(response.content))\n", 397 | " image = preprocess_image(image)\n", 398 | " return image\n", 399 | "\n", 400 | "!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt" 401 | ], 402 | "execution_count": null, 403 | "outputs": [] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "metadata": { 408 | "id": "Hd-YH-hqAIQ9" 409 | }, 410 | "source": [ 411 | "### Load image and ImageNet-1k class mappings" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "metadata": { 417 | "id": "4vDQd6MEAEp_" 418 | }, 419 | "source": [ 420 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n", 421 | " lines = f.readlines()\n", 422 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n", 423 | "\n", 424 | "img_url = 
\"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n", 425 | "image = load_image_from_url(img_url)" 426 | ], 427 | "execution_count": null, 428 | "outputs": [] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": { 433 | "id": "9A-LxOYBANnv" 434 | }, 435 | "source": [ 436 | "### Inference\n", 437 | "\n", 438 | "This is only application for the classification models. For fine-tuning/feature extraction, please follow [this notebook](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/fine_tune.ipynb) instead." 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "metadata": { 444 | "id": "99DTGYB25o5d" 445 | }, 446 | "source": [ 447 | "# Load the converted SavedModel and check whether it finds the elephant.\n", 448 | "restored_model = tf.saved_model.load(model_dir)\n", 449 | "predictions = restored_model.signatures[\"serving_default\"](tf.constant(image))\n", 450 | "logits = predictions[\"output_0\"][0]\n", 451 | "predicted_label = imagenet_int_to_str[int(np.argmax(logits))]\n", 452 | "expected_label = \"Indian_elephant, Elephas_maximus\"\n", 453 | "assert (\n", 454 | " predicted_label == expected_label\n", 455 | "), f\"Expected {expected_label} but was {predicted_label}\"" 456 | ], 457 | "execution_count": null, 458 | "outputs": [] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "metadata": { 463 | "id": "itfezjKjAXA6" 464 | }, 465 | "source": [ 466 | "## Inference with TensorFlow Hub \n", 467 | "\n", 468 | "Run the following code snippet. You can also follow [this notebook](https://colab.research.google.com/github/sayakpaul/ViT-jax2tf/blob/main/classification.ipynb). 
" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": { 474 | "id": "HR8lV2377Ad3" 475 | }, 476 | "source": [ 477 | "```python\n", 478 | "import tensorflow_hub as hub\n", 479 | "\n", 480 | "classification_model = tf.keras.Sequential([hub.KerasLayer(model_dir)])\n", 481 | "predictions = classification_model.predict(image)\n", 482 | "predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]\n", 483 | "predicted_label\n", 484 | "```" 485 | ] 486 | } 487 | ] 488 | } 489 | -------------------------------------------------------------------------------- /fine_tune.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "Yx9kQrATLdy5" 7 | }, 8 | "source": [ 9 | "## Imports" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "id": "UqO1f2Z7QoOC" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "from tensorflow import keras\n", 21 | "import tensorflow as tf\n", 22 | "import tensorflow_hub as hub\n", 23 | "\n", 24 | "import tensorflow_datasets as tfds\n", 25 | "\n", 26 | "tfds.disable_progress_bar()\n", 27 | "\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import numpy as np" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": { 35 | "id": "yS9v0obPLgVy" 36 | }, 37 | "source": [ 38 | "## Model building utility" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "id": "yZ0gsA41RVVM" 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "def get_model(\n", 50 | " handle=\"https://tfhub.dev/sayakpaul/vit_s16_fe/1\", \n", 51 | " num_classes=5,\n", 52 | "):\n", 53 | " hub_layer = hub.KerasLayer(handle, trainable=True)\n", 54 | "\n", 55 | " model = keras.Sequential(\n", 56 | " [\n", 57 | " keras.layers.InputLayer((224, 224, 3)),\n", 58 | " hub_layer,\n", 59 | " keras.layers.Dense(num_classes, 
activation=\"softmax\"),\n", 60 | " ]\n", 61 | " )\n", 62 | "\n", 63 | " return model" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "id": "jk5VZIafSxGv" 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "get_model().summary()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": { 80 | "id": "ELO3it7aLk6E" 81 | }, 82 | "source": [ 83 | "## Data input pipeline\n", 84 | "\n", 85 | "Code has been reused from the [official repository](https://github.com/google-research/vision_transformer/blob/main/vit_jax/input_pipeline.py)." 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "id": "hE2MxLkbEfa7" 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "BATCH_SIZE = 64\n", 97 | "AUTO = tf.data.AUTOTUNE" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": { 104 | "id": "Ouk89SHNS9us" 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "def make_dataset(dataset: tf.data.Dataset, train: bool, image_size: int = 224):\n", 109 | " def preprocess(image, label):\n", 110 | " # For training, do a random crop and horizontal flip.\n", 111 | " if train:\n", 112 | " channels = image.shape[-1]\n", 113 | " begin, size, _ = tf.image.sample_distorted_bounding_box(\n", 114 | " tf.shape(image),\n", 115 | " tf.zeros([0, 0, 4], tf.float32),\n", 116 | " area_range=(0.05, 1.0),\n", 117 | " min_object_covered=0,\n", 118 | " use_image_if_no_bounding_boxes=True,\n", 119 | " )\n", 120 | " image = tf.slice(image, begin, size)\n", 121 | "\n", 122 | " image.set_shape([None, None, channels])\n", 123 | " image = tf.image.resize(image, [image_size, image_size])\n", 124 | " if tf.random.uniform(shape=[]) > 0.5:\n", 125 | " image = tf.image.flip_left_right(image)\n", 126 | "\n", 127 | " else:\n", 128 | " image = tf.image.resize(image, [image_size, image_size])\n", 129 | "\n", 130 | " image = (image - 127.5) / 127.5\n", 131 | " return image, 
label\n", 132 | "\n", 133 | " if train:\n", 134 | " dataset = dataset.shuffle(BATCH_SIZE * 10)\n", 135 | "\n", 136 | " return dataset.map(preprocess, AUTO).batch(BATCH_SIZE).prefetch(AUTO)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": { 142 | "id": "Hl3i3onrLtO6" 143 | }, 144 | "source": [ 145 | "## `tf_flowers` dataset" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": { 152 | "id": "lHJFUqFZE-0F" 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "train_dataset, val_dataset = tfds.load(\n", 157 | " \"tf_flowers\", split=[\"train[:90%]\", \"train[90%:]\"], as_supervised=True\n", 158 | ")\n", 159 | "\n", 160 | "num_train = tf.data.experimental.cardinality(train_dataset)\n", 161 | "num_val = tf.data.experimental.cardinality(val_dataset)\n", 162 | "print(f\"Number of training examples: {num_train}\")\n", 163 | "print(f\"Number of validation examples: {num_val}\")" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": { 169 | "id": "rnjfXf8RLzCp" 170 | }, 171 | "source": [ 172 | "### Prepare dataset" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": { 179 | "id": "DRd8kkcMFxSw" 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "train_dataset = make_dataset(train_dataset, True)\n", 184 | "val_dataset = make_dataset(val_dataset, False)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": { 190 | "id": "HbeqUHVdLz5J" 191 | }, 192 | "source": [ 193 | "### Visualize" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": { 200 | "id": "O8Ui5B8hGNXQ" 201 | }, 202 | "outputs": [], 203 | "source": [ 204 | "sample_images, _ = next(iter(train_dataset))\n", 205 | "\n", 206 | "plt.figure(figsize=(10, 10))\n", 207 | "for n in range(25):\n", 208 | " ax = plt.subplot(5, 5, n + 1)\n", 209 | " plt.imshow((sample_images[n].numpy() + 1) / 2)\n", 210 | " 
plt.axis(\"off\")\n", 211 | "plt.show()" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": { 217 | "id": "h1XNooD2L2IY" 218 | }, 219 | "source": [ 220 | "## Learning rate scheduling \n", 221 | "\n", 222 | "For fine-tuning the authors follow a warm-up + [cosine | linear] schedule as per the [official notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/linen/vit_jax.ipynb). " 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": { 229 | "id": "3cFQQaQoGjuF" 230 | }, 231 | "outputs": [], 232 | "source": [ 233 | "# Reference:\n", 234 | "# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2\n", 235 | "\n", 236 | "\n", 237 | "class WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):\n", 238 | " def __init__(\n", 239 | " self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps\n", 240 | " ):\n", 241 | " super(WarmUpCosine, self).__init__()\n", 242 | "\n", 243 | " self.learning_rate_base = learning_rate_base\n", 244 | " self.total_steps = total_steps\n", 245 | " self.warmup_learning_rate = warmup_learning_rate\n", 246 | " self.warmup_steps = warmup_steps\n", 247 | " self.pi = tf.constant(np.pi)\n", 248 | "\n", 249 | " def __call__(self, step):\n", 250 | " if self.total_steps < self.warmup_steps:\n", 251 | " raise ValueError(\"Total_steps must be larger or equal to warmup_steps.\")\n", 252 | " learning_rate = (\n", 253 | " 0.5\n", 254 | " * self.learning_rate_base\n", 255 | " * (\n", 256 | " 1\n", 257 | " + tf.cos(\n", 258 | " self.pi\n", 259 | " * (tf.cast(step, tf.float32) - self.warmup_steps)\n", 260 | " / float(self.total_steps - self.warmup_steps)\n", 261 | " )\n", 262 | " )\n", 263 | " )\n", 264 | "\n", 265 | " if self.warmup_steps > 0:\n", 266 | " if self.learning_rate_base < self.warmup_learning_rate:\n", 267 | " raise ValueError(\n", 268 | " \"Learning_rate_base must be larger or equal to \"\n", 269 | " 
\"warmup_learning_rate.\"\n", 270 | " )\n", 271 | " slope = (\n", 272 | " self.learning_rate_base - self.warmup_learning_rate\n", 273 | " ) / self.warmup_steps\n", 274 | " warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n", 275 | " learning_rate = tf.where(\n", 276 | " step < self.warmup_steps, warmup_rate, learning_rate\n", 277 | " )\n", 278 | " return tf.where(\n", 279 | " step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n", 280 | " )" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": { 286 | "id": "1-fPavRyMH1X" 287 | }, 288 | "source": [ 289 | "## Training hyperparameters\n", 290 | "\n", 291 | "These have been referred from the official notebooks ([1](https://colab.research.google.com/github/google-research/vision_transformer/blob/linen/vit_jax.ipynb) and [2](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb)). \n", 292 | "\n", 293 | "Differences:\n", 294 | "\n", 295 | "* No gradient accumulation\n", 296 | "* Lower batch size for demoing on a single GPU (64 as opposed to 512)" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": { 303 | "id": "bgsWyyhgHAaB" 304 | }, 305 | "outputs": [], 306 | "source": [ 307 | "EPOCHS = 8\n", 308 | "TOTAL_STEPS = int((num_train / BATCH_SIZE) * EPOCHS)\n", 309 | "WARMUP_STEPS = 10\n", 310 | "INIT_LR = 0.03\n", 311 | "WAMRUP_LR = 0.006\n", 312 | "\n", 313 | "print(TOTAL_STEPS)" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": { 320 | "id": "LppzZVdbHwPE" 321 | }, 322 | "outputs": [], 323 | "source": [ 324 | "scheduled_lrs = WarmUpCosine(\n", 325 | " learning_rate_base=INIT_LR,\n", 326 | " total_steps=TOTAL_STEPS,\n", 327 | " warmup_learning_rate=WAMRUP_LR,\n", 328 | " warmup_steps=WARMUP_STEPS,\n", 329 | ")\n", 330 | "\n", 331 | "lrs = [scheduled_lrs(step) for step in range(TOTAL_STEPS)]\n", 332 | 
"plt.plot(lrs)\n", 333 | "plt.xlabel(\"Step\", fontsize=14)\n", 334 | "plt.ylabel(\"LR\", fontsize=14)\n", 335 | "plt.show()" 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": { 341 | "id": "i__WUIfcMpDk" 342 | }, 343 | "source": [ 344 | "### Optimizer and loss function" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": { 351 | "id": "yblMGEOjIiyP" 352 | }, 353 | "outputs": [], 354 | "source": [ 355 | "optimizer = keras.optimizers.SGD(scheduled_lrs, clipnorm=1.0)\n", 356 | "loss = keras.losses.SparseCategoricalCrossentropy()" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": { 362 | "id": "apdPFM_TMsJg" 363 | }, 364 | "source": [ 365 | "## Model training and validation" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": { 372 | "id": "wd5XVBMZJlu_" 373 | }, 374 | "outputs": [], 375 | "source": [ 376 | "model = get_model()\n", 377 | "model.compile(loss=loss, optimizer=optimizer, metrics=[\"accuracy\"])\n", 378 | "\n", 379 | "history = model.fit(train_dataset, validation_data=val_dataset, epochs=EPOCHS)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "metadata": { 386 | "id": "InPOQqfJK3gR" 387 | }, 388 | "outputs": [], 389 | "source": [ 390 | "plt.figure(figsize=(7, 7))\n", 391 | "history = history.history\n", 392 | "\n", 393 | "plt.plot(history[\"loss\"], label=\"train_loss\")\n", 394 | "plt.plot(history[\"val_loss\"], label=\"val_loss\")\n", 395 | "plt.plot(history[\"accuracy\"], label=\"train_accuracy\")\n", 396 | "plt.plot(history[\"val_accuracy\"], label=\"val_accuracy\")\n", 397 | "\n", 398 | "plt.legend()\n", 399 | "plt.show()" 400 | ] 401 | } 402 | ], 403 | "metadata": { 404 | "accelerator": "GPU", 405 | "colab": { 406 | "collapsed_sections": [], 407 | "include_colab_link": true, 408 | "machine_shape": "hm", 409 | "name": "fine-tune.ipynb", 410 | "provenance": [] 411 | 
}, 412 | "kernelspec": { 413 | "display_name": "Python 3 (ipykernel)", 414 | "language": "python", 415 | "name": "python3" 416 | }, 417 | "language_info": { 418 | "codemirror_mode": { 419 | "name": "ipython", 420 | "version": 3 421 | }, 422 | "file_extension": ".py", 423 | "mimetype": "text/x-python", 424 | "name": "python", 425 | "nbconvert_exporter": "python", 426 | "pygments_lexer": "ipython3", 427 | "version": "3.8.2" 428 | } 429 | }, 430 | "nbformat": 4, 431 | "nbformat_minor": 1 432 | } 433 | -------------------------------------------------------------------------------- /i1k_eval/README.md: -------------------------------------------------------------------------------- 1 | This directory provides a notebook and ImageNet-1k class mapping file to run evaluation on the ImageNet-1k `validation` split using the [ViT models from TF-Hub](https://tfhub.dev/sayakpaul/collections/vision_transformer/1). One should use this same setup to evaluate the [MLP-Mixer models from TF-Hub](https://tfhub.dev/sayakpaul/collections/mlp-mixer/1). The notebook assumes the following files are present in your working directory: 2 | 3 | * The `validation` split directory of ImageNet-1k. 4 | * The class mapping files (`.json`). 
5 | -------------------------------------------------------------------------------- /i1k_eval/eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "cc520ed3-785b-4491-aa40-4346b72a4574", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import tensorflow_datasets as tfds\n", 11 | "import tensorflow_hub as hub\n", 12 | "import tensorflow as tf\n", 13 | "\n", 14 | "from imutils import paths\n", 15 | "import json" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "id": "86454d69-4857-45db-a98f-abc5b1c97e3d", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "try:\n", 26 | " tpu = None\n", 27 | " tpu = tf.distribute.cluster_resolver.TPUClusterResolver()\n", 28 | " tf.config.experimental_connect_to_cluster(tpu)\n", 29 | " tf.tpu.experimental.initialize_tpu_system(tpu)\n", 30 | " strategy = tf.distribute.TPUStrategy(tpu)\n", 31 | "except ValueError:\n", 32 | " strategy = tf.distribute.MirroredStrategy()\n", 33 | "\n", 34 | "print(\"Number of accelerators: \", strategy.num_replicas_in_sync)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "c3f047d5-e54d-43c2-b19d-0ce5e56e9ff6", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "AUTO = tf.data.AUTOTUNE\n", 45 | "BATCH_SIZE = 128 * strategy.num_replicas_in_sync" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "c0bbd0fe-08b9-4d6a-ac19-ab1322e20594", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "with open(\"imagenet_class_index.json\", \"r\") as read_file:\n", 56 | " imagenet_labels = json.load(read_file)\n", 57 | "\n", 58 | "MAPPING_DICT = {}\n", 59 | "LABEL_NAMES = {}\n", 60 | "for label_id in list(imagenet_labels.keys()):\n", 61 | " MAPPING_DICT[imagenet_labels[label_id][0]] = int(label_id)\n", 62 | " 
LABEL_NAMES[int(label_id)] = imagenet_labels[label_id][1]" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "3c9ad7b4-ce65-42c9-b9f6-b7d9a3e2a030", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "all_val_paths = list(paths.list_images(\"val\"))\n", 73 | "all_val_labels = [MAPPING_DICT[x.split(\"/\")[1]] for x in all_val_paths]\n", 74 | "\n", 75 | "all_val_paths[:5], all_val_labels[:5]" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "3cb2086a-b442-4bd0-8b7d-f717b8816286", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "def load_and_prepare(path, label):\n", 86 | " image = tf.io.read_file(path)\n", 87 | " image = tf.image.decode_png(image, channels=3)\n", 88 | " image = tf.image.resize(image, (224, 224))\n", 89 | "\n", 90 | " return image, label" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "289f1a88-70b1-4a99-9a95-08daf4f38290", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "dataset = tf.data.Dataset.from_tensor_slices((all_val_paths, all_val_labels))\n", 101 | "\n", 102 | "dataset = dataset.map(load_and_prepare, num_parallel_calls=AUTO).batch(BATCH_SIZE)\n", 103 | "dataset = dataset.prefetch(AUTO)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "14bd0f87-24ce-48cf-b263-e878f260e618", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "def get_model(model_url=\"https://tfhub.dev/sayakpaul/vit_s16_classification/1\"):\n", 114 | " classification_model = tf.keras.Sequential(\n", 115 | " [\n", 116 | " tf.keras.layers.InputLayer((224, 224, 3)),\n", 117 | " tf.keras.layers.Rescaling(\n", 118 | " scale=1.0 / 127.5, offset=-1\n", 119 | " ), # Scales to [-1, 1].\n", 120 | " hub.KerasLayer(model_url),\n", 121 | " ]\n", 122 | " )\n", 123 | " return classification_model" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | 
"execution_count": null, 129 | "id": "d0de72ef-495d-41ae-9875-809101b3c177", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "def eval_util(model_url, arch):\n", 134 | " tb_callback = tf.keras.callbacks.TensorBoard(log_dir=f\"logs_{arch}\")\n", 135 | " with strategy.scope():\n", 136 | " model = get_model(model_url)\n", 137 | " model.compile(metrics=[\"accuracy\"])\n", 138 | " model.evaluate(dataset, callbacks=[tb_callback])" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "e2287612-8d1a-4131-87cd-5d0798a6b313", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "model_urls = [\n", 149 | " \"https://tfhub.dev/sayakpaul/vit_s16_classification/1\",\n", 150 | " \"https://tfhub.dev/sayakpaul/vit_r26_s32_lightaug_classification/1\",\n", 151 | "]\n", 152 | "\n", 153 | "archs = [\"s16\", \"r26_s32\"]\n", 154 | "\n", 155 | "for model_url, arch in zip(model_urls, archs):\n", 156 | " print(f\"Evaluating {arch}\")\n", 157 | " eval_util(model_url, arch)" 158 | ] 159 | } 160 | ], 161 | "metadata": { 162 | "environment": { 163 | "name": "tf2-gpu.2-6.m80", 164 | "type": "gcloud", 165 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-6:m80" 166 | }, 167 | "kernelspec": { 168 | "display_name": "Python 3 (ipykernel)", 169 | "language": "python", 170 | "name": "python3" 171 | }, 172 | "language_info": { 173 | "codemirror_mode": { 174 | "name": "ipython", 175 | "version": 3 176 | }, 177 | "file_extension": ".py", 178 | "mimetype": "text/x-python", 179 | "name": "python", 180 | "nbconvert_exporter": "python", 181 | "pygments_lexer": "ipython3", 182 | "version": "3.8.2" 183 | } 184 | }, 185 | "nbformat": 4, 186 | "nbformat_minor": 5 187 | } 188 | -------------------------------------------------------------------------------- /i1k_eval/imagenet_class_index.json: -------------------------------------------------------------------------------- 1 | {"0": ["n01440764", "tench"], "1": ["n01443537", 
"goldfish"], "2": ["n01484850", "great_white_shark"], "3": ["n01491361", "tiger_shark"], "4": ["n01494475", "hammerhead"], "5": ["n01496331", "electric_ray"], "6": ["n01498041", "stingray"], "7": ["n01514668", "cock"], "8": ["n01514859", "hen"], "9": ["n01518878", "ostrich"], "10": ["n01530575", "brambling"], "11": ["n01531178", "goldfinch"], "12": ["n01532829", "house_finch"], "13": ["n01534433", "junco"], "14": ["n01537544", "indigo_bunting"], "15": ["n01558993", "robin"], "16": ["n01560419", "bulbul"], "17": ["n01580077", "jay"], "18": ["n01582220", "magpie"], "19": ["n01592084", "chickadee"], "20": ["n01601694", "water_ouzel"], "21": ["n01608432", "kite"], "22": ["n01614925", "bald_eagle"], "23": ["n01616318", "vulture"], "24": ["n01622779", "great_grey_owl"], "25": ["n01629819", "European_fire_salamander"], "26": ["n01630670", "common_newt"], "27": ["n01631663", "eft"], "28": ["n01632458", "spotted_salamander"], "29": ["n01632777", "axolotl"], "30": ["n01641577", "bullfrog"], "31": ["n01644373", "tree_frog"], "32": ["n01644900", "tailed_frog"], "33": ["n01664065", "loggerhead"], "34": ["n01665541", "leatherback_turtle"], "35": ["n01667114", "mud_turtle"], "36": ["n01667778", "terrapin"], "37": ["n01669191", "box_turtle"], "38": ["n01675722", "banded_gecko"], "39": ["n01677366", "common_iguana"], "40": ["n01682714", "American_chameleon"], "41": ["n01685808", "whiptail"], "42": ["n01687978", "agama"], "43": ["n01688243", "frilled_lizard"], "44": ["n01689811", "alligator_lizard"], "45": ["n01692333", "Gila_monster"], "46": ["n01693334", "green_lizard"], "47": ["n01694178", "African_chameleon"], "48": ["n01695060", "Komodo_dragon"], "49": ["n01697457", "African_crocodile"], "50": ["n01698640", "American_alligator"], "51": ["n01704323", "triceratops"], "52": ["n01728572", "thunder_snake"], "53": ["n01728920", "ringneck_snake"], "54": ["n01729322", "hognose_snake"], "55": ["n01729977", "green_snake"], "56": ["n01734418", "king_snake"], "57": ["n01735189", 
"garter_snake"], "58": ["n01737021", "water_snake"], "59": ["n01739381", "vine_snake"], "60": ["n01740131", "night_snake"], "61": ["n01742172", "boa_constrictor"], "62": ["n01744401", "rock_python"], "63": ["n01748264", "Indian_cobra"], "64": ["n01749939", "green_mamba"], "65": ["n01751748", "sea_snake"], "66": ["n01753488", "horned_viper"], "67": ["n01755581", "diamondback"], "68": ["n01756291", "sidewinder"], "69": ["n01768244", "trilobite"], "70": ["n01770081", "harvestman"], "71": ["n01770393", "scorpion"], "72": ["n01773157", "black_and_gold_garden_spider"], "73": ["n01773549", "barn_spider"], "74": ["n01773797", "garden_spider"], "75": ["n01774384", "black_widow"], "76": ["n01774750", "tarantula"], "77": ["n01775062", "wolf_spider"], "78": ["n01776313", "tick"], "79": ["n01784675", "centipede"], "80": ["n01795545", "black_grouse"], "81": ["n01796340", "ptarmigan"], "82": ["n01797886", "ruffed_grouse"], "83": ["n01798484", "prairie_chicken"], "84": ["n01806143", "peacock"], "85": ["n01806567", "quail"], "86": ["n01807496", "partridge"], "87": ["n01817953", "African_grey"], "88": ["n01818515", "macaw"], "89": ["n01819313", "sulphur-crested_cockatoo"], "90": ["n01820546", "lorikeet"], "91": ["n01824575", "coucal"], "92": ["n01828970", "bee_eater"], "93": ["n01829413", "hornbill"], "94": ["n01833805", "hummingbird"], "95": ["n01843065", "jacamar"], "96": ["n01843383", "toucan"], "97": ["n01847000", "drake"], "98": ["n01855032", "red-breasted_merganser"], "99": ["n01855672", "goose"], "100": ["n01860187", "black_swan"], "101": ["n01871265", "tusker"], "102": ["n01872401", "echidna"], "103": ["n01873310", "platypus"], "104": ["n01877812", "wallaby"], "105": ["n01882714", "koala"], "106": ["n01883070", "wombat"], "107": ["n01910747", "jellyfish"], "108": ["n01914609", "sea_anemone"], "109": ["n01917289", "brain_coral"], "110": ["n01924916", "flatworm"], "111": ["n01930112", "nematode"], "112": ["n01943899", "conch"], "113": ["n01944390", "snail"], "114": 
["n01945685", "slug"], "115": ["n01950731", "sea_slug"], "116": ["n01955084", "chiton"], "117": ["n01968897", "chambered_nautilus"], "118": ["n01978287", "Dungeness_crab"], "119": ["n01978455", "rock_crab"], "120": ["n01980166", "fiddler_crab"], "121": ["n01981276", "king_crab"], "122": ["n01983481", "American_lobster"], "123": ["n01984695", "spiny_lobster"], "124": ["n01985128", "crayfish"], "125": ["n01986214", "hermit_crab"], "126": ["n01990800", "isopod"], "127": ["n02002556", "white_stork"], "128": ["n02002724", "black_stork"], "129": ["n02006656", "spoonbill"], "130": ["n02007558", "flamingo"], "131": ["n02009229", "little_blue_heron"], "132": ["n02009912", "American_egret"], "133": ["n02011460", "bittern"], "134": ["n02012849", "crane"], "135": ["n02013706", "limpkin"], "136": ["n02017213", "European_gallinule"], "137": ["n02018207", "American_coot"], "138": ["n02018795", "bustard"], "139": ["n02025239", "ruddy_turnstone"], "140": ["n02027492", "red-backed_sandpiper"], "141": ["n02028035", "redshank"], "142": ["n02033041", "dowitcher"], "143": ["n02037110", "oystercatcher"], "144": ["n02051845", "pelican"], "145": ["n02056570", "king_penguin"], "146": ["n02058221", "albatross"], "147": ["n02066245", "grey_whale"], "148": ["n02071294", "killer_whale"], "149": ["n02074367", "dugong"], "150": ["n02077923", "sea_lion"], "151": ["n02085620", "Chihuahua"], "152": ["n02085782", "Japanese_spaniel"], "153": ["n02085936", "Maltese_dog"], "154": ["n02086079", "Pekinese"], "155": ["n02086240", "Shih-Tzu"], "156": ["n02086646", "Blenheim_spaniel"], "157": ["n02086910", "papillon"], "158": ["n02087046", "toy_terrier"], "159": ["n02087394", "Rhodesian_ridgeback"], "160": ["n02088094", "Afghan_hound"], "161": ["n02088238", "basset"], "162": ["n02088364", "beagle"], "163": ["n02088466", "bloodhound"], "164": ["n02088632", "bluetick"], "165": ["n02089078", "black-and-tan_coonhound"], "166": ["n02089867", "Walker_hound"], "167": ["n02089973", "English_foxhound"], "168": 
["n02090379", "redbone"], "169": ["n02090622", "borzoi"], "170": ["n02090721", "Irish_wolfhound"], "171": ["n02091032", "Italian_greyhound"], "172": ["n02091134", "whippet"], "173": ["n02091244", "Ibizan_hound"], "174": ["n02091467", "Norwegian_elkhound"], "175": ["n02091635", "otterhound"], "176": ["n02091831", "Saluki"], "177": ["n02092002", "Scottish_deerhound"], "178": ["n02092339", "Weimaraner"], "179": ["n02093256", "Staffordshire_bullterrier"], "180": ["n02093428", "American_Staffordshire_terrier"], "181": ["n02093647", "Bedlington_terrier"], "182": ["n02093754", "Border_terrier"], "183": ["n02093859", "Kerry_blue_terrier"], "184": ["n02093991", "Irish_terrier"], "185": ["n02094114", "Norfolk_terrier"], "186": ["n02094258", "Norwich_terrier"], "187": ["n02094433", "Yorkshire_terrier"], "188": ["n02095314", "wire-haired_fox_terrier"], "189": ["n02095570", "Lakeland_terrier"], "190": ["n02095889", "Sealyham_terrier"], "191": ["n02096051", "Airedale"], "192": ["n02096177", "cairn"], "193": ["n02096294", "Australian_terrier"], "194": ["n02096437", "Dandie_Dinmont"], "195": ["n02096585", "Boston_bull"], "196": ["n02097047", "miniature_schnauzer"], "197": ["n02097130", "giant_schnauzer"], "198": ["n02097209", "standard_schnauzer"], "199": ["n02097298", "Scotch_terrier"], "200": ["n02097474", "Tibetan_terrier"], "201": ["n02097658", "silky_terrier"], "202": ["n02098105", "soft-coated_wheaten_terrier"], "203": ["n02098286", "West_Highland_white_terrier"], "204": ["n02098413", "Lhasa"], "205": ["n02099267", "flat-coated_retriever"], "206": ["n02099429", "curly-coated_retriever"], "207": ["n02099601", "golden_retriever"], "208": ["n02099712", "Labrador_retriever"], "209": ["n02099849", "Chesapeake_Bay_retriever"], "210": ["n02100236", "German_short-haired_pointer"], "211": ["n02100583", "vizsla"], "212": ["n02100735", "English_setter"], "213": ["n02100877", "Irish_setter"], "214": ["n02101006", "Gordon_setter"], "215": ["n02101388", "Brittany_spaniel"], "216": 
["n02101556", "clumber"], "217": ["n02102040", "English_springer"], "218": ["n02102177", "Welsh_springer_spaniel"], "219": ["n02102318", "cocker_spaniel"], "220": ["n02102480", "Sussex_spaniel"], "221": ["n02102973", "Irish_water_spaniel"], "222": ["n02104029", "kuvasz"], "223": ["n02104365", "schipperke"], "224": ["n02105056", "groenendael"], "225": ["n02105162", "malinois"], "226": ["n02105251", "briard"], "227": ["n02105412", "kelpie"], "228": ["n02105505", "komondor"], "229": ["n02105641", "Old_English_sheepdog"], "230": ["n02105855", "Shetland_sheepdog"], "231": ["n02106030", "collie"], "232": ["n02106166", "Border_collie"], "233": ["n02106382", "Bouvier_des_Flandres"], "234": ["n02106550", "Rottweiler"], "235": ["n02106662", "German_shepherd"], "236": ["n02107142", "Doberman"], "237": ["n02107312", "miniature_pinscher"], "238": ["n02107574", "Greater_Swiss_Mountain_dog"], "239": ["n02107683", "Bernese_mountain_dog"], "240": ["n02107908", "Appenzeller"], "241": ["n02108000", "EntleBucher"], "242": ["n02108089", "boxer"], "243": ["n02108422", "bull_mastiff"], "244": ["n02108551", "Tibetan_mastiff"], "245": ["n02108915", "French_bulldog"], "246": ["n02109047", "Great_Dane"], "247": ["n02109525", "Saint_Bernard"], "248": ["n02109961", "Eskimo_dog"], "249": ["n02110063", "malamute"], "250": ["n02110185", "Siberian_husky"], "251": ["n02110341", "dalmatian"], "252": ["n02110627", "affenpinscher"], "253": ["n02110806", "basenji"], "254": ["n02110958", "pug"], "255": ["n02111129", "Leonberg"], "256": ["n02111277", "Newfoundland"], "257": ["n02111500", "Great_Pyrenees"], "258": ["n02111889", "Samoyed"], "259": ["n02112018", "Pomeranian"], "260": ["n02112137", "chow"], "261": ["n02112350", "keeshond"], "262": ["n02112706", "Brabancon_griffon"], "263": ["n02113023", "Pembroke"], "264": ["n02113186", "Cardigan"], "265": ["n02113624", "toy_poodle"], "266": ["n02113712", "miniature_poodle"], "267": ["n02113799", "standard_poodle"], "268": ["n02113978", "Mexican_hairless"], 
"269": ["n02114367", "timber_wolf"], "270": ["n02114548", "white_wolf"], "271": ["n02114712", "red_wolf"], "272": ["n02114855", "coyote"], "273": ["n02115641", "dingo"], "274": ["n02115913", "dhole"], "275": ["n02116738", "African_hunting_dog"], "276": ["n02117135", "hyena"], "277": ["n02119022", "red_fox"], "278": ["n02119789", "kit_fox"], "279": ["n02120079", "Arctic_fox"], "280": ["n02120505", "grey_fox"], "281": ["n02123045", "tabby"], "282": ["n02123159", "tiger_cat"], "283": ["n02123394", "Persian_cat"], "284": ["n02123597", "Siamese_cat"], "285": ["n02124075", "Egyptian_cat"], "286": ["n02125311", "cougar"], "287": ["n02127052", "lynx"], "288": ["n02128385", "leopard"], "289": ["n02128757", "snow_leopard"], "290": ["n02128925", "jaguar"], "291": ["n02129165", "lion"], "292": ["n02129604", "tiger"], "293": ["n02130308", "cheetah"], "294": ["n02132136", "brown_bear"], "295": ["n02133161", "American_black_bear"], "296": ["n02134084", "ice_bear"], "297": ["n02134418", "sloth_bear"], "298": ["n02137549", "mongoose"], "299": ["n02138441", "meerkat"], "300": ["n02165105", "tiger_beetle"], "301": ["n02165456", "ladybug"], "302": ["n02167151", "ground_beetle"], "303": ["n02168699", "long-horned_beetle"], "304": ["n02169497", "leaf_beetle"], "305": ["n02172182", "dung_beetle"], "306": ["n02174001", "rhinoceros_beetle"], "307": ["n02177972", "weevil"], "308": ["n02190166", "fly"], "309": ["n02206856", "bee"], "310": ["n02219486", "ant"], "311": ["n02226429", "grasshopper"], "312": ["n02229544", "cricket"], "313": ["n02231487", "walking_stick"], "314": ["n02233338", "cockroach"], "315": ["n02236044", "mantis"], "316": ["n02256656", "cicada"], "317": ["n02259212", "leafhopper"], "318": ["n02264363", "lacewing"], "319": ["n02268443", "dragonfly"], "320": ["n02268853", "damselfly"], "321": ["n02276258", "admiral"], "322": ["n02277742", "ringlet"], "323": ["n02279972", "monarch"], "324": ["n02280649", "cabbage_butterfly"], "325": ["n02281406", "sulphur_butterfly"], "326": 
["n02281787", "lycaenid"], "327": ["n02317335", "starfish"], "328": ["n02319095", "sea_urchin"], "329": ["n02321529", "sea_cucumber"], "330": ["n02325366", "wood_rabbit"], "331": ["n02326432", "hare"], "332": ["n02328150", "Angora"], "333": ["n02342885", "hamster"], "334": ["n02346627", "porcupine"], "335": ["n02356798", "fox_squirrel"], "336": ["n02361337", "marmot"], "337": ["n02363005", "beaver"], "338": ["n02364673", "guinea_pig"], "339": ["n02389026", "sorrel"], "340": ["n02391049", "zebra"], "341": ["n02395406", "hog"], "342": ["n02396427", "wild_boar"], "343": ["n02397096", "warthog"], "344": ["n02398521", "hippopotamus"], "345": ["n02403003", "ox"], "346": ["n02408429", "water_buffalo"], "347": ["n02410509", "bison"], "348": ["n02412080", "ram"], "349": ["n02415577", "bighorn"], "350": ["n02417914", "ibex"], "351": ["n02422106", "hartebeest"], "352": ["n02422699", "impala"], "353": ["n02423022", "gazelle"], "354": ["n02437312", "Arabian_camel"], "355": ["n02437616", "llama"], "356": ["n02441942", "weasel"], "357": ["n02442845", "mink"], "358": ["n02443114", "polecat"], "359": ["n02443484", "black-footed_ferret"], "360": ["n02444819", "otter"], "361": ["n02445715", "skunk"], "362": ["n02447366", "badger"], "363": ["n02454379", "armadillo"], "364": ["n02457408", "three-toed_sloth"], "365": ["n02480495", "orangutan"], "366": ["n02480855", "gorilla"], "367": ["n02481823", "chimpanzee"], "368": ["n02483362", "gibbon"], "369": ["n02483708", "siamang"], "370": ["n02484975", "guenon"], "371": ["n02486261", "patas"], "372": ["n02486410", "baboon"], "373": ["n02487347", "macaque"], "374": ["n02488291", "langur"], "375": ["n02488702", "colobus"], "376": ["n02489166", "proboscis_monkey"], "377": ["n02490219", "marmoset"], "378": ["n02492035", "capuchin"], "379": ["n02492660", "howler_monkey"], "380": ["n02493509", "titi"], "381": ["n02493793", "spider_monkey"], "382": ["n02494079", "squirrel_monkey"], "383": ["n02497673", "Madagascar_cat"], "384": ["n02500267", 
"indri"], "385": ["n02504013", "Indian_elephant"], "386": ["n02504458", "African_elephant"], "387": ["n02509815", "lesser_panda"], "388": ["n02510455", "giant_panda"], "389": ["n02514041", "barracouta"], "390": ["n02526121", "eel"], "391": ["n02536864", "coho"], "392": ["n02606052", "rock_beauty"], "393": ["n02607072", "anemone_fish"], "394": ["n02640242", "sturgeon"], "395": ["n02641379", "gar"], "396": ["n02643566", "lionfish"], "397": ["n02655020", "puffer"], "398": ["n02666196", "abacus"], "399": ["n02667093", "abaya"], "400": ["n02669723", "academic_gown"], "401": ["n02672831", "accordion"], "402": ["n02676566", "acoustic_guitar"], "403": ["n02687172", "aircraft_carrier"], "404": ["n02690373", "airliner"], "405": ["n02692877", "airship"], "406": ["n02699494", "altar"], "407": ["n02701002", "ambulance"], "408": ["n02704792", "amphibian"], "409": ["n02708093", "analog_clock"], "410": ["n02727426", "apiary"], "411": ["n02730930", "apron"], "412": ["n02747177", "ashcan"], "413": ["n02749479", "assault_rifle"], "414": ["n02769748", "backpack"], "415": ["n02776631", "bakery"], "416": ["n02777292", "balance_beam"], "417": ["n02782093", "balloon"], "418": ["n02783161", "ballpoint"], "419": ["n02786058", "Band_Aid"], "420": ["n02787622", "banjo"], "421": ["n02788148", "bannister"], "422": ["n02790996", "barbell"], "423": ["n02791124", "barber_chair"], "424": ["n02791270", "barbershop"], "425": ["n02793495", "barn"], "426": ["n02794156", "barometer"], "427": ["n02795169", "barrel"], "428": ["n02797295", "barrow"], "429": ["n02799071", "baseball"], "430": ["n02802426", "basketball"], "431": ["n02804414", "bassinet"], "432": ["n02804610", "bassoon"], "433": ["n02807133", "bathing_cap"], "434": ["n02808304", "bath_towel"], "435": ["n02808440", "bathtub"], "436": ["n02814533", "beach_wagon"], "437": ["n02814860", "beacon"], "438": ["n02815834", "beaker"], "439": ["n02817516", "bearskin"], "440": ["n02823428", "beer_bottle"], "441": ["n02823750", "beer_glass"], "442": 
["n02825657", "bell_cote"], "443": ["n02834397", "bib"], "444": ["n02835271", "bicycle-built-for-two"], "445": ["n02837789", "bikini"], "446": ["n02840245", "binder"], "447": ["n02841315", "binoculars"], "448": ["n02843684", "birdhouse"], "449": ["n02859443", "boathouse"], "450": ["n02860847", "bobsled"], "451": ["n02865351", "bolo_tie"], "452": ["n02869837", "bonnet"], "453": ["n02870880", "bookcase"], "454": ["n02871525", "bookshop"], "455": ["n02877765", "bottlecap"], "456": ["n02879718", "bow"], "457": ["n02883205", "bow_tie"], "458": ["n02892201", "brass"], "459": ["n02892767", "brassiere"], "460": ["n02894605", "breakwater"], "461": ["n02895154", "breastplate"], "462": ["n02906734", "broom"], "463": ["n02909870", "bucket"], "464": ["n02910353", "buckle"], "465": ["n02916936", "bulletproof_vest"], "466": ["n02917067", "bullet_train"], "467": ["n02927161", "butcher_shop"], "468": ["n02930766", "cab"], "469": ["n02939185", "caldron"], "470": ["n02948072", "candle"], "471": ["n02950826", "cannon"], "472": ["n02951358", "canoe"], "473": ["n02951585", "can_opener"], "474": ["n02963159", "cardigan"], "475": ["n02965783", "car_mirror"], "476": ["n02966193", "carousel"], "477": ["n02966687", "carpenter's_kit"], "478": ["n02971356", "carton"], "479": ["n02974003", "car_wheel"], "480": ["n02977058", "cash_machine"], "481": ["n02978881", "cassette"], "482": ["n02979186", "cassette_player"], "483": ["n02980441", "castle"], "484": ["n02981792", "catamaran"], "485": ["n02988304", "CD_player"], "486": ["n02992211", "cello"], "487": ["n02992529", "cellular_telephone"], "488": ["n02999410", "chain"], "489": ["n03000134", "chainlink_fence"], "490": ["n03000247", "chain_mail"], "491": ["n03000684", "chain_saw"], "492": ["n03014705", "chest"], "493": ["n03016953", "chiffonier"], "494": ["n03017168", "chime"], "495": ["n03018349", "china_cabinet"], "496": ["n03026506", "Christmas_stocking"], "497": ["n03028079", "church"], "498": ["n03032252", "cinema"], "499": ["n03041632", 
"cleaver"], "500": ["n03042490", "cliff_dwelling"], "501": ["n03045698", "cloak"], "502": ["n03047690", "clog"], "503": ["n03062245", "cocktail_shaker"], "504": ["n03063599", "coffee_mug"], "505": ["n03063689", "coffeepot"], "506": ["n03065424", "coil"], "507": ["n03075370", "combination_lock"], "508": ["n03085013", "computer_keyboard"], "509": ["n03089624", "confectionery"], "510": ["n03095699", "container_ship"], "511": ["n03100240", "convertible"], "512": ["n03109150", "corkscrew"], "513": ["n03110669", "cornet"], "514": ["n03124043", "cowboy_boot"], "515": ["n03124170", "cowboy_hat"], "516": ["n03125729", "cradle"], "517": ["n03126707", "crane"], "518": ["n03127747", "crash_helmet"], "519": ["n03127925", "crate"], "520": ["n03131574", "crib"], "521": ["n03133878", "Crock_Pot"], "522": ["n03134739", "croquet_ball"], "523": ["n03141823", "crutch"], "524": ["n03146219", "cuirass"], "525": ["n03160309", "dam"], "526": ["n03179701", "desk"], "527": ["n03180011", "desktop_computer"], "528": ["n03187595", "dial_telephone"], "529": ["n03188531", "diaper"], "530": ["n03196217", "digital_clock"], "531": ["n03197337", "digital_watch"], "532": ["n03201208", "dining_table"], "533": ["n03207743", "dishrag"], "534": ["n03207941", "dishwasher"], "535": ["n03208938", "disk_brake"], "536": ["n03216828", "dock"], "537": ["n03218198", "dogsled"], "538": ["n03220513", "dome"], "539": ["n03223299", "doormat"], "540": ["n03240683", "drilling_platform"], "541": ["n03249569", "drum"], "542": ["n03250847", "drumstick"], "543": ["n03255030", "dumbbell"], "544": ["n03259280", "Dutch_oven"], "545": ["n03271574", "electric_fan"], "546": ["n03272010", "electric_guitar"], "547": ["n03272562", "electric_locomotive"], "548": ["n03290653", "entertainment_center"], "549": ["n03291819", "envelope"], "550": ["n03297495", "espresso_maker"], "551": ["n03314780", "face_powder"], "552": ["n03325584", "feather_boa"], "553": ["n03337140", "file"], "554": ["n03344393", "fireboat"], "555": ["n03345487", 
"fire_engine"], "556": ["n03347037", "fire_screen"], "557": ["n03355925", "flagpole"], "558": ["n03372029", "flute"], "559": ["n03376595", "folding_chair"], "560": ["n03379051", "football_helmet"], "561": ["n03384352", "forklift"], "562": ["n03388043", "fountain"], "563": ["n03388183", "fountain_pen"], "564": ["n03388549", "four-poster"], "565": ["n03393912", "freight_car"], "566": ["n03394916", "French_horn"], "567": ["n03400231", "frying_pan"], "568": ["n03404251", "fur_coat"], "569": ["n03417042", "garbage_truck"], "570": ["n03424325", "gasmask"], "571": ["n03425413", "gas_pump"], "572": ["n03443371", "goblet"], "573": ["n03444034", "go-kart"], "574": ["n03445777", "golf_ball"], "575": ["n03445924", "golfcart"], "576": ["n03447447", "gondola"], "577": ["n03447721", "gong"], "578": ["n03450230", "gown"], "579": ["n03452741", "grand_piano"], "580": ["n03457902", "greenhouse"], "581": ["n03459775", "grille"], "582": ["n03461385", "grocery_store"], "583": ["n03467068", "guillotine"], "584": ["n03476684", "hair_slide"], "585": ["n03476991", "hair_spray"], "586": ["n03478589", "half_track"], "587": ["n03481172", "hammer"], "588": ["n03482405", "hamper"], "589": ["n03483316", "hand_blower"], "590": ["n03485407", "hand-held_computer"], "591": ["n03485794", "handkerchief"], "592": ["n03492542", "hard_disc"], "593": ["n03494278", "harmonica"], "594": ["n03495258", "harp"], "595": ["n03496892", "harvester"], "596": ["n03498962", "hatchet"], "597": ["n03527444", "holster"], "598": ["n03529860", "home_theater"], "599": ["n03530642", "honeycomb"], "600": ["n03532672", "hook"], "601": ["n03534580", "hoopskirt"], "602": ["n03535780", "horizontal_bar"], "603": ["n03538406", "horse_cart"], "604": ["n03544143", "hourglass"], "605": ["n03584254", "iPod"], "606": ["n03584829", "iron"], "607": ["n03590841", "jack-o'-lantern"], "608": ["n03594734", "jean"], "609": ["n03594945", "jeep"], "610": ["n03595614", "jersey"], "611": ["n03598930", "jigsaw_puzzle"], "612": ["n03599486", 
"jinrikisha"], "613": ["n03602883", "joystick"], "614": ["n03617480", "kimono"], "615": ["n03623198", "knee_pad"], "616": ["n03627232", "knot"], "617": ["n03630383", "lab_coat"], "618": ["n03633091", "ladle"], "619": ["n03637318", "lampshade"], "620": ["n03642806", "laptop"], "621": ["n03649909", "lawn_mower"], "622": ["n03657121", "lens_cap"], "623": ["n03658185", "letter_opener"], "624": ["n03661043", "library"], "625": ["n03662601", "lifeboat"], "626": ["n03666591", "lighter"], "627": ["n03670208", "limousine"], "628": ["n03673027", "liner"], "629": ["n03676483", "lipstick"], "630": ["n03680355", "Loafer"], "631": ["n03690938", "lotion"], "632": ["n03691459", "loudspeaker"], "633": ["n03692522", "loupe"], "634": ["n03697007", "lumbermill"], "635": ["n03706229", "magnetic_compass"], "636": ["n03709823", "mailbag"], "637": ["n03710193", "mailbox"], "638": ["n03710637", "maillot"], "639": ["n03710721", "maillot"], "640": ["n03717622", "manhole_cover"], "641": ["n03720891", "maraca"], "642": ["n03721384", "marimba"], "643": ["n03724870", "mask"], "644": ["n03729826", "matchstick"], "645": ["n03733131", "maypole"], "646": ["n03733281", "maze"], "647": ["n03733805", "measuring_cup"], "648": ["n03742115", "medicine_chest"], "649": ["n03743016", "megalith"], "650": ["n03759954", "microphone"], "651": ["n03761084", "microwave"], "652": ["n03763968", "military_uniform"], "653": ["n03764736", "milk_can"], "654": ["n03769881", "minibus"], "655": ["n03770439", "miniskirt"], "656": ["n03770679", "minivan"], "657": ["n03773504", "missile"], "658": ["n03775071", "mitten"], "659": ["n03775546", "mixing_bowl"], "660": ["n03776460", "mobile_home"], "661": ["n03777568", "Model_T"], "662": ["n03777754", "modem"], "663": ["n03781244", "monastery"], "664": ["n03782006", "monitor"], "665": ["n03785016", "moped"], "666": ["n03786901", "mortar"], "667": ["n03787032", "mortarboard"], "668": ["n03788195", "mosque"], "669": ["n03788365", "mosquito_net"], "670": ["n03791053", 
"motor_scooter"], "671": ["n03792782", "mountain_bike"], "672": ["n03792972", "mountain_tent"], "673": ["n03793489", "mouse"], "674": ["n03794056", "mousetrap"], "675": ["n03796401", "moving_van"], "676": ["n03803284", "muzzle"], "677": ["n03804744", "nail"], "678": ["n03814639", "neck_brace"], "679": ["n03814906", "necklace"], "680": ["n03825788", "nipple"], "681": ["n03832673", "notebook"], "682": ["n03837869", "obelisk"], "683": ["n03838899", "oboe"], "684": ["n03840681", "ocarina"], "685": ["n03841143", "odometer"], "686": ["n03843555", "oil_filter"], "687": ["n03854065", "organ"], "688": ["n03857828", "oscilloscope"], "689": ["n03866082", "overskirt"], "690": ["n03868242", "oxcart"], "691": ["n03868863", "oxygen_mask"], "692": ["n03871628", "packet"], "693": ["n03873416", "paddle"], "694": ["n03874293", "paddlewheel"], "695": ["n03874599", "padlock"], "696": ["n03876231", "paintbrush"], "697": ["n03877472", "pajama"], "698": ["n03877845", "palace"], "699": ["n03884397", "panpipe"], "700": ["n03887697", "paper_towel"], "701": ["n03888257", "parachute"], "702": ["n03888605", "parallel_bars"], "703": ["n03891251", "park_bench"], "704": ["n03891332", "parking_meter"], "705": ["n03895866", "passenger_car"], "706": ["n03899768", "patio"], "707": ["n03902125", "pay-phone"], "708": ["n03903868", "pedestal"], "709": ["n03908618", "pencil_box"], "710": ["n03908714", "pencil_sharpener"], "711": ["n03916031", "perfume"], "712": ["n03920288", "Petri_dish"], "713": ["n03924679", "photocopier"], "714": ["n03929660", "pick"], "715": ["n03929855", "pickelhaube"], "716": ["n03930313", "picket_fence"], "717": ["n03930630", "pickup"], "718": ["n03933933", "pier"], "719": ["n03935335", "piggy_bank"], "720": ["n03937543", "pill_bottle"], "721": ["n03938244", "pillow"], "722": ["n03942813", "ping-pong_ball"], "723": ["n03944341", "pinwheel"], "724": ["n03947888", "pirate"], "725": ["n03950228", "pitcher"], "726": ["n03954731", "plane"], "727": ["n03956157", "planetarium"], "728": 
["n03958227", "plastic_bag"], "729": ["n03961711", "plate_rack"], "730": ["n03967562", "plow"], "731": ["n03970156", "plunger"], "732": ["n03976467", "Polaroid_camera"], "733": ["n03976657", "pole"], "734": ["n03977966", "police_van"], "735": ["n03980874", "poncho"], "736": ["n03982430", "pool_table"], "737": ["n03983396", "pop_bottle"], "738": ["n03991062", "pot"], "739": ["n03992509", "potter's_wheel"], "740": ["n03995372", "power_drill"], "741": ["n03998194", "prayer_rug"], "742": ["n04004767", "printer"], "743": ["n04005630", "prison"], "744": ["n04008634", "projectile"], "745": ["n04009552", "projector"], "746": ["n04019541", "puck"], "747": ["n04023962", "punching_bag"], "748": ["n04026417", "purse"], "749": ["n04033901", "quill"], "750": ["n04033995", "quilt"], "751": ["n04037443", "racer"], "752": ["n04039381", "racket"], "753": ["n04040759", "radiator"], "754": ["n04041544", "radio"], "755": ["n04044716", "radio_telescope"], "756": ["n04049303", "rain_barrel"], "757": ["n04065272", "recreational_vehicle"], "758": ["n04067472", "reel"], "759": ["n04069434", "reflex_camera"], "760": ["n04070727", "refrigerator"], "761": ["n04074963", "remote_control"], "762": ["n04081281", "restaurant"], "763": ["n04086273", "revolver"], "764": ["n04090263", "rifle"], "765": ["n04099969", "rocking_chair"], "766": ["n04111531", "rotisserie"], "767": ["n04116512", "rubber_eraser"], "768": ["n04118538", "rugby_ball"], "769": ["n04118776", "rule"], "770": ["n04120489", "running_shoe"], "771": ["n04125021", "safe"], "772": ["n04127249", "safety_pin"], "773": ["n04131690", "saltshaker"], "774": ["n04133789", "sandal"], "775": ["n04136333", "sarong"], "776": ["n04141076", "sax"], "777": ["n04141327", "scabbard"], "778": ["n04141975", "scale"], "779": ["n04146614", "school_bus"], "780": ["n04147183", "schooner"], "781": ["n04149813", "scoreboard"], "782": ["n04152593", "screen"], "783": ["n04153751", "screw"], "784": ["n04154565", "screwdriver"], "785": ["n04162706", "seat_belt"], 
"786": ["n04179913", "sewing_machine"], "787": ["n04192698", "shield"], "788": ["n04200800", "shoe_shop"], "789": ["n04201297", "shoji"], "790": ["n04204238", "shopping_basket"], "791": ["n04204347", "shopping_cart"], "792": ["n04208210", "shovel"], "793": ["n04209133", "shower_cap"], "794": ["n04209239", "shower_curtain"], "795": ["n04228054", "ski"], "796": ["n04229816", "ski_mask"], "797": ["n04235860", "sleeping_bag"], "798": ["n04238763", "slide_rule"], "799": ["n04239074", "sliding_door"], "800": ["n04243546", "slot"], "801": ["n04251144", "snorkel"], "802": ["n04252077", "snowmobile"], "803": ["n04252225", "snowplow"], "804": ["n04254120", "soap_dispenser"], "805": ["n04254680", "soccer_ball"], "806": ["n04254777", "sock"], "807": ["n04258138", "solar_dish"], "808": ["n04259630", "sombrero"], "809": ["n04263257", "soup_bowl"], "810": ["n04264628", "space_bar"], "811": ["n04265275", "space_heater"], "812": ["n04266014", "space_shuttle"], "813": ["n04270147", "spatula"], "814": ["n04273569", "speedboat"], "815": ["n04275548", "spider_web"], "816": ["n04277352", "spindle"], "817": ["n04285008", "sports_car"], "818": ["n04286575", "spotlight"], "819": ["n04296562", "stage"], "820": ["n04310018", "steam_locomotive"], "821": ["n04311004", "steel_arch_bridge"], "822": ["n04311174", "steel_drum"], "823": ["n04317175", "stethoscope"], "824": ["n04325704", "stole"], "825": ["n04326547", "stone_wall"], "826": ["n04328186", "stopwatch"], "827": ["n04330267", "stove"], "828": ["n04332243", "strainer"], "829": ["n04335435", "streetcar"], "830": ["n04336792", "stretcher"], "831": ["n04344873", "studio_couch"], "832": ["n04346328", "stupa"], "833": ["n04347754", "submarine"], "834": ["n04350905", "suit"], "835": ["n04355338", "sundial"], "836": ["n04355933", "sunglass"], "837": ["n04356056", "sunglasses"], "838": ["n04357314", "sunscreen"], "839": ["n04366367", "suspension_bridge"], "840": ["n04367480", "swab"], "841": ["n04370456", "sweatshirt"], "842": ["n04371430", 
"swimming_trunks"], "843": ["n04371774", "swing"], "844": ["n04372370", "switch"], "845": ["n04376876", "syringe"], "846": ["n04380533", "table_lamp"], "847": ["n04389033", "tank"], "848": ["n04392985", "tape_player"], "849": ["n04398044", "teapot"], "850": ["n04399382", "teddy"], "851": ["n04404412", "television"], "852": ["n04409515", "tennis_ball"], "853": ["n04417672", "thatch"], "854": ["n04418357", "theater_curtain"], "855": ["n04423845", "thimble"], "856": ["n04428191", "thresher"], "857": ["n04429376", "throne"], "858": ["n04435653", "tile_roof"], "859": ["n04442312", "toaster"], "860": ["n04443257", "tobacco_shop"], "861": ["n04447861", "toilet_seat"], "862": ["n04456115", "torch"], "863": ["n04458633", "totem_pole"], "864": ["n04461696", "tow_truck"], "865": ["n04462240", "toyshop"], "866": ["n04465501", "tractor"], "867": ["n04467665", "trailer_truck"], "868": ["n04476259", "tray"], "869": ["n04479046", "trench_coat"], "870": ["n04482393", "tricycle"], "871": ["n04483307", "trimaran"], "872": ["n04485082", "tripod"], "873": ["n04486054", "triumphal_arch"], "874": ["n04487081", "trolleybus"], "875": ["n04487394", "trombone"], "876": ["n04493381", "tub"], "877": ["n04501370", "turnstile"], "878": ["n04505470", "typewriter_keyboard"], "879": ["n04507155", "umbrella"], "880": ["n04509417", "unicycle"], "881": ["n04515003", "upright"], "882": ["n04517823", "vacuum"], "883": ["n04522168", "vase"], "884": ["n04523525", "vault"], "885": ["n04525038", "velvet"], "886": ["n04525305", "vending_machine"], "887": ["n04532106", "vestment"], "888": ["n04532670", "viaduct"], "889": ["n04536866", "violin"], "890": ["n04540053", "volleyball"], "891": ["n04542943", "waffle_iron"], "892": ["n04548280", "wall_clock"], "893": ["n04548362", "wallet"], "894": ["n04550184", "wardrobe"], "895": ["n04552348", "warplane"], "896": ["n04553703", "washbasin"], "897": ["n04554684", "washer"], "898": ["n04557648", "water_bottle"], "899": ["n04560804", "water_jug"], "900": ["n04562935", 
"water_tower"], "901": ["n04579145", "whiskey_jug"], "902": ["n04579432", "whistle"], "903": ["n04584207", "wig"], "904": ["n04589890", "window_screen"], "905": ["n04590129", "window_shade"], "906": ["n04591157", "Windsor_tie"], "907": ["n04591713", "wine_bottle"], "908": ["n04592741", "wing"], "909": ["n04596742", "wok"], "910": ["n04597913", "wooden_spoon"], "911": ["n04599235", "wool"], "912": ["n04604644", "worm_fence"], "913": ["n04606251", "wreck"], "914": ["n04612504", "yawl"], "915": ["n04613696", "yurt"], "916": ["n06359193", "web_site"], "917": ["n06596364", "comic_book"], "918": ["n06785654", "crossword_puzzle"], "919": ["n06794110", "street_sign"], "920": ["n06874185", "traffic_light"], "921": ["n07248320", "book_jacket"], "922": ["n07565083", "menu"], "923": ["n07579787", "plate"], "924": ["n07583066", "guacamole"], "925": ["n07584110", "consomme"], "926": ["n07590611", "hot_pot"], "927": ["n07613480", "trifle"], "928": ["n07614500", "ice_cream"], "929": ["n07615774", "ice_lolly"], "930": ["n07684084", "French_loaf"], "931": ["n07693725", "bagel"], "932": ["n07695742", "pretzel"], "933": ["n07697313", "cheeseburger"], "934": ["n07697537", "hotdog"], "935": ["n07711569", "mashed_potato"], "936": ["n07714571", "head_cabbage"], "937": ["n07714990", "broccoli"], "938": ["n07715103", "cauliflower"], "939": ["n07716358", "zucchini"], "940": ["n07716906", "spaghetti_squash"], "941": ["n07717410", "acorn_squash"], "942": ["n07717556", "butternut_squash"], "943": ["n07718472", "cucumber"], "944": ["n07718747", "artichoke"], "945": ["n07720875", "bell_pepper"], "946": ["n07730033", "cardoon"], "947": ["n07734744", "mushroom"], "948": ["n07742313", "Granny_Smith"], "949": ["n07745940", "strawberry"], "950": ["n07747607", "orange"], "951": ["n07749582", "lemon"], "952": ["n07753113", "fig"], "953": ["n07753275", "pineapple"], "954": ["n07753592", "banana"], "955": ["n07754684", "jackfruit"], "956": ["n07760859", "custard_apple"], "957": ["n07768694", 
"pomegranate"], "958": ["n07802026", "hay"], "959": ["n07831146", "carbonara"], "960": ["n07836838", "chocolate_sauce"], "961": ["n07860988", "dough"], "962": ["n07871810", "meat_loaf"], "963": ["n07873807", "pizza"], "964": ["n07875152", "potpie"], "965": ["n07880968", "burrito"], "966": ["n07892512", "red_wine"], "967": ["n07920052", "espresso"], "968": ["n07930864", "cup"], "969": ["n07932039", "eggnog"], "970": ["n09193705", "alp"], "971": ["n09229709", "bubble"], "972": ["n09246464", "cliff"], "973": ["n09256479", "coral_reef"], "974": ["n09288635", "geyser"], "975": ["n09332890", "lakeside"], "976": ["n09399592", "promontory"], "977": ["n09421951", "sandbar"], "978": ["n09428293", "seashore"], "979": ["n09468604", "valley"], "980": ["n09472597", "volcano"], "981": ["n09835506", "ballplayer"], "982": ["n10148035", "groom"], "983": ["n10565667", "scuba_diver"], "984": ["n11879895", "rapeseed"], "985": ["n11939491", "daisy"], "986": ["n12057211", "yellow_lady's_slipper"], "987": ["n12144580", "corn"], "988": ["n12267677", "acorn"], "989": ["n12620546", "hip"], "990": ["n12768682", "buckeye"], "991": ["n12985857", "coral_fungus"], "992": ["n12998815", "agaric"], "993": ["n13037406", "gyromitra"], "994": ["n13040303", "stinkhorn"], "995": ["n13044778", "earthstar"], "996": ["n13052670", "hen-of-the-woods"], "997": ["n13054560", "bolete"], "998": ["n13133613", "ear"], "999": ["n15075141", "toilet_tissue"]} -------------------------------------------------------------------------------- /model-selector.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "model-selector", 7 | "provenance": [], 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "display_name": "Python 3 (ipykernel)", 12 | "language": "python", 13 | "name": "python3" 14 | }, 15 | "language_info": { 16 | "codemirror_mode": { 17 | "name": "ipython", 18 | 
"version": 3 19 | }, 20 | "file_extension": ".py", 21 | "mimetype": "text/x-python", 22 | "name": "python", 23 | "nbconvert_exporter": "python", 24 | "pygments_lexer": "ipython3", 25 | "version": "3.8.2" 26 | } 27 | }, 28 | "cells": [ 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "view-in-github", 33 | "colab_type": "text" 34 | }, 35 | "source": [ 36 | "\"Open" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "id": "yXeTsp2E7Wxq" 43 | }, 44 | "source": [ 45 | "## Reference\n", 46 | "\n", 47 | "* [vit_jax_augreg.ipynb](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "metadata": { 53 | "id": "lIYdn1woOS1n" 54 | }, 55 | "source": [ 56 | "import tensorflow as tf\n", 57 | "import pandas as pd" 58 | ], 59 | "execution_count": null, 60 | "outputs": [] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "metadata": { 65 | "id": "KqgZLLHE7SC7" 66 | }, 67 | "source": [ 68 | "# Load master table from Cloud.\n", 69 | "with tf.io.gfile.GFile(\"gs://vit_models/augreg/index.csv\") as f:\n", 70 | " df = pd.read_csv(f)" 71 | ], 72 | "execution_count": null, 73 | "outputs": [] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "metadata": { 78 | "colab": { 79 | "base_uri": "https://localhost:8080/" 80 | }, 81 | "id": "bto9lWO17e4x", 82 | "outputId": "207ee223-8e25-4222-c3a7-999434ab91ef" 83 | }, 84 | "source": [ 85 | "df.columns" 86 | ], 87 | "execution_count": null, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "Index(['name', 'ds', 'epochs', 'lr', 'aug', 'wd', 'do', 'sd', 'best_val',\n", 93 | " 'final_val', 'final_test', 'adapt_ds', 'adapt_lr', 'adapt_steps',\n", 94 | " 'adapt_resolution', 'adapt_final_val', 'adapt_final_test', 'params',\n", 95 | " 'infer_samples_per_sec', 'filename', 'adapt_filename'],\n", 96 | " dtype='object')" 97 | ] 98 | }, 99 | "execution_count": 3, 100 | "metadata": {}, 101 | 
"output_type": "execute_result" 102 | } 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "metadata": { 108 | "colab": { 109 | "base_uri": "https://localhost:8080/" 110 | }, 111 | "id": "RdC3XNws74AY", 112 | "outputId": "59bca181-8985-4e0e-88b8-c9fbaf0a0c4a" 113 | }, 114 | "source": [ 115 | "# How many different pre-training datasets?\n", 116 | "df[\"ds\"].value_counts()" 117 | ], 118 | "execution_count": null, 119 | "outputs": [ 120 | { 121 | "data": { 122 | "text/plain": [ 123 | "i21k 17238\n", 124 | "i1k 17136\n", 125 | "i21k_30 17135\n", 126 | "Name: ds, dtype: int64" 127 | ] 128 | }, 129 | "execution_count": 4, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": { 138 | "id": "-_AWCq3d_G2I" 139 | }, 140 | "source": [ 141 | "Filter based on the following criteria:\n", 142 | "\n", 143 | "* Models should be pre-trained on ImageNet-21k and fine-tuned on ImageNet-1k.\n", 144 | "* The final ImageNet-1k validation accuracy should be at least 75%. \n", 145 | "* The transfer resolution should be 224 $\\times$ 224." 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "metadata": { 151 | "colab": { 152 | "base_uri": "https://localhost:8080/", 153 | "height": 632 154 | }, 155 | "id": "K7Dmm_aU7l2P", 156 | "outputId": "e3c17dd9-2983-4fa5-f2b4-df9515e1f510" 157 | }, 158 | "source": [ 159 | "i21k_i1k_models = df.query(\"ds=='i21k' & adapt_ds=='imagenet2012'\")\n", 160 | "models_ge_75 = i21k_i1k_models.query(\"adapt_final_test >= 0.75 & adapt_resolution==224\")\n", 161 | "models_ge_75.head()" 162 | ], 163 | "execution_count": null, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/html": [ 168 | "
\n", 169 | "\n", 182 | "\n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | "
namedsepochslraugwddosdbest_valfinal_val...adapt_dsadapt_lradapt_stepsadapt_resolutionadapt_final_valadapt_final_testparamsinfer_samples_per_secfilenameadapt_filename
5508R26+S/32i21k300.00.001light00.030.00.00.4653520.465049...imagenet20120.03200002240.8581020.7995436430000.01814.25R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0...R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0...
5509R26+S/32i21k300.00.001light00.030.00.00.4653520.465049...imagenet20120.01200002240.8553700.8018236430000.01814.25R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0...R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0...
5576R26+S/32i21k300.00.001medium20.030.10.10.4485740.448105...imagenet20120.03200002240.8303930.7963236430000.01814.25R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0....R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0....
5577R26+S/32i21k300.00.001medium20.030.10.10.4485740.448105...imagenet20120.01200002240.8215740.7879636430000.01814.25R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0....R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0....
5644R26+S/32i21k300.00.001light00.100.00.00.4783980.477715...imagenet20120.03200002240.8527940.8007436430000.01814.25R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.1...R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.1...
\n", 332 | "

5 rows × 21 columns

\n", 333 | "
" 334 | ], 335 | "text/plain": [ 336 | " name ds epochs lr aug wd do sd best_val \\\n", 337 | "5508 R26+S/32 i21k 300.0 0.001 light0 0.03 0.0 0.0 0.465352 \n", 338 | "5509 R26+S/32 i21k 300.0 0.001 light0 0.03 0.0 0.0 0.465352 \n", 339 | "5576 R26+S/32 i21k 300.0 0.001 medium2 0.03 0.1 0.1 0.448574 \n", 340 | "5577 R26+S/32 i21k 300.0 0.001 medium2 0.03 0.1 0.1 0.448574 \n", 341 | "5644 R26+S/32 i21k 300.0 0.001 light0 0.10 0.0 0.0 0.478398 \n", 342 | "\n", 343 | " final_val ... adapt_ds adapt_lr adapt_steps adapt_resolution \\\n", 344 | "5508 0.465049 ... imagenet2012 0.03 20000 224 \n", 345 | "5509 0.465049 ... imagenet2012 0.01 20000 224 \n", 346 | "5576 0.448105 ... imagenet2012 0.03 20000 224 \n", 347 | "5577 0.448105 ... imagenet2012 0.01 20000 224 \n", 348 | "5644 0.477715 ... imagenet2012 0.03 20000 224 \n", 349 | "\n", 350 | " adapt_final_val adapt_final_test params infer_samples_per_sec \\\n", 351 | "5508 0.858102 0.79954 36430000.0 1814.25 \n", 352 | "5509 0.855370 0.80182 36430000.0 1814.25 \n", 353 | "5576 0.830393 0.79632 36430000.0 1814.25 \n", 354 | "5577 0.821574 0.78796 36430000.0 1814.25 \n", 355 | "5644 0.852794 0.80074 36430000.0 1814.25 \n", 356 | "\n", 357 | " filename \\\n", 358 | "5508 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n", 359 | "5509 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n", 360 | "5576 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n", 361 | "5577 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n", 362 | "5644 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.1... \n", 363 | "\n", 364 | " adapt_filename \n", 365 | "5508 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n", 366 | "5509 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n", 367 | "5576 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n", 368 | "5577 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n", 369 | "5644 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.1... 
\n", 370 | "\n", 371 | "[5 rows x 21 columns]" 372 | ] 373 | }, 374 | "execution_count": 7, 375 | "metadata": {}, 376 | "output_type": "execute_result" 377 | } 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "colab": { 384 | "base_uri": "https://localhost:8080/" 385 | }, 386 | "id": "dEmy40f3pBxT", 387 | "outputId": "66d07a15-e092-4edc-a5a0-02544a4860d7" 388 | }, 389 | "source": [ 390 | "models_ge_75[\"name\"].value_counts()" 391 | ], 392 | "execution_count": null, 393 | "outputs": [ 394 | { 395 | "data": { 396 | "text/plain": [ 397 | "R26+S/32 56\n", 398 | "S/16 54\n", 399 | "R50+L/32 54\n", 400 | "B/16 54\n", 401 | "B/32 53\n", 402 | "L/16 52\n", 403 | "B/8 6\n", 404 | "Name: name, dtype: int64" 405 | ] 406 | }, 407 | "execution_count": 8, 408 | "metadata": {}, 409 | "output_type": "execute_result" 410 | } 411 | ] 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "metadata": { 416 | "id": "kOzlSwR-_G2K" 417 | }, 418 | "source": [ 419 | "Now, we first fetch the maximum accuracies with respect to a given model type and then we pick the underlying models. " 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "metadata": { 425 | "colab": { 426 | "base_uri": "https://localhost:8080/", 427 | "height": 1000 428 | }, 429 | "id": "F0CCk1yupOXC", 430 | "outputId": "47cd3859-fd71-471e-b4e6-b8a2a7bd33d7" 431 | }, 432 | "source": [ 433 | "best_scores_by_model_type = (\n", 434 | " models_ge_75.groupby(\"name\")[\"adapt_final_test\"].max().values\n", 435 | ")\n", 436 | "results = models_ge_75[\"adapt_final_test\"].apply(\n", 437 | " lambda x: x in best_scores_by_model_type\n", 438 | ")\n", 439 | "models_ge_75[results].sort_values(by=[\"adapt_final_test\"], ascending=False).head(10)" 440 | ], 441 | "execution_count": null, 442 | "outputs": [ 443 | { 444 | "data": { 445 | "text/html": [ 446 | "
\n", 447 | "\n", 460 | "\n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " 
\n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | "
namedsepochslraugwddosdbest_valfinal_val...adapt_dsadapt_lradapt_stepsadapt_resolutionadapt_final_valadapt_final_testparamsinfer_samples_per_secfilenameadapt_filename
50966B/8i21k300.00.001medium20.100.00.00.5214260.521006...imagenet20120.01200002240.8917420.85948NaNNaNB_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_...B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_...
26011L/16i21k300.00.001medium10.100.10.10.5122750.512275...imagenet20120.01200002240.9011080.85716304330000.0228.01L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do...L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do...
24854B/16i21k300.00.001medium20.030.00.00.5042580.503623...imagenet20120.03200002240.8824540.8401886570000.0658.56B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-d...B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-d...
26861R50+L/32i21k300.00.001medium10.100.10.10.5144530.513877...imagenet20120.01200002240.8979080.83784110950000.01046.83R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0....R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0....
6392R26+S/32i21k300.00.001light00.030.10.10.4778910.477471...imagenet20120.03200002240.8500620.8094436430000.01814.25R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0...R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0...
5849R26+S/32i21k300.00.001medium20.100.00.00.4625880.462373...imagenet20120.01200002240.8379640.8046236430000.01814.25R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0....R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0....
12410S/16i21k300.00.001light10.030.00.00.4726760.472402...imagenet20120.03200002240.8413210.8046222050000.01508.35S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do...S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do...
22610B/32i21k300.00.001medium10.030.00.00.4737890.473525...imagenet20120.03200002240.8441310.7943688220000.03597.19B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-d...B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-d...
\n", 682 | "

8 rows × 21 columns

\n", 683 | "
" 684 | ], 685 | "text/plain": [ 686 | " name ds epochs lr aug wd do sd best_val \\\n", 687 | "50966 B/8 i21k 300.0 0.001 medium2 0.10 0.0 0.0 0.521426 \n", 688 | "26011 L/16 i21k 300.0 0.001 medium1 0.10 0.1 0.1 0.512275 \n", 689 | "24854 B/16 i21k 300.0 0.001 medium2 0.03 0.0 0.0 0.504258 \n", 690 | "26861 R50+L/32 i21k 300.0 0.001 medium1 0.10 0.1 0.1 0.514453 \n", 691 | "6392 R26+S/32 i21k 300.0 0.001 light0 0.03 0.1 0.1 0.477891 \n", 692 | "5849 R26+S/32 i21k 300.0 0.001 medium2 0.10 0.0 0.0 0.462588 \n", 693 | "12410 S/16 i21k 300.0 0.001 light1 0.03 0.0 0.0 0.472676 \n", 694 | "22610 B/32 i21k 300.0 0.001 medium1 0.03 0.0 0.0 0.473789 \n", 695 | "\n", 696 | " final_val ... adapt_ds adapt_lr adapt_steps adapt_resolution \\\n", 697 | "50966 0.521006 ... imagenet2012 0.01 20000 224 \n", 698 | "26011 0.512275 ... imagenet2012 0.01 20000 224 \n", 699 | "24854 0.503623 ... imagenet2012 0.03 20000 224 \n", 700 | "26861 0.513877 ... imagenet2012 0.01 20000 224 \n", 701 | "6392 0.477471 ... imagenet2012 0.03 20000 224 \n", 702 | "5849 0.462373 ... imagenet2012 0.01 20000 224 \n", 703 | "12410 0.472402 ... imagenet2012 0.03 20000 224 \n", 704 | "22610 0.473525 ... imagenet2012 0.03 20000 224 \n", 705 | "\n", 706 | " adapt_final_val adapt_final_test params infer_samples_per_sec \\\n", 707 | "50966 0.891742 0.85948 NaN NaN \n", 708 | "26011 0.901108 0.85716 304330000.0 228.01 \n", 709 | "24854 0.882454 0.84018 86570000.0 658.56 \n", 710 | "26861 0.897908 0.83784 110950000.0 1046.83 \n", 711 | "6392 0.850062 0.80944 36430000.0 1814.25 \n", 712 | "5849 0.837964 0.80462 36430000.0 1814.25 \n", 713 | "12410 0.841321 0.80462 22050000.0 1508.35 \n", 714 | "22610 0.844131 0.79436 88220000.0 3597.19 \n", 715 | "\n", 716 | " filename \\\n", 717 | "50966 B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_... \n", 718 | "26011 L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do... \n", 719 | "24854 B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-d... 
\n", 720 | "26861 R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.... \n", 721 | "6392 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n", 722 | "5849 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n", 723 | "12410 S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do... \n", 724 | "22610 B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-d... \n", 725 | "\n", 726 | " adapt_filename \n", 727 | "50966 B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_... \n", 728 | "26011 L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do... \n", 729 | "24854 B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-d... \n", 730 | "26861 R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.... \n", 731 | "6392 R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.0... \n", 732 | "5849 R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.... \n", 733 | "12410 S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do... \n", 734 | "22610 B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-d... \n", 735 | "\n", 736 | "[8 rows x 21 columns]" 737 | ] 738 | }, 739 | "execution_count": 9, 740 | "metadata": {}, 741 | "output_type": "execute_result" 742 | } 743 | ] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "metadata": { 748 | "colab": { 749 | "base_uri": "https://localhost:8080/" 750 | }, 751 | "id": "kryppwagsWMg", 752 | "outputId": "c4469b79-e23d-4784-b02b-62ad67fb182b" 753 | }, 754 | "source": [ 755 | "models_ge_75[results].sort_values(by=[\"adapt_final_test\"], ascending=False).head(10)[\n", 756 | " \"adapt_filename\"\n", 757 | "].values.tolist()" 758 | ], 759 | "execution_count": null, 760 | "outputs": [ 761 | { 762 | "data": { 763 | "text/plain": [ 764 | "['B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224',\n", 765 | " 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224',\n", 766 | " 'B_16-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224',\n", 767 | " 
'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224',\n", 768 | " 'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224',\n", 769 | " 'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224',\n", 770 | " 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224',\n", 771 | " 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224']" 772 | ] 773 | }, 774 | "execution_count": 10, 775 | "metadata": {}, 776 | "output_type": "execute_result" 777 | } 778 | ] 779 | } 780 | ] 781 | } --------------------------------------------------------------------------------