├── README.md
├── LICENSE
└── convert_jax_weights_tf.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # BiT-jax2tf
2 | This repository hosts code for porting the NumPy model weights of BiT-ResNets [1] to the TensorFlow SavedModel format. The models
3 | were produced in [2], and the original model weights come from [3].
4 |
5 | Huge thanks to [Willi Gierke](https://ch.linkedin.com/in/willi-gierke) (of Google) for helping with the porting.
6 |
7 | The TensorFlow SavedModels are available on TensorFlow Hub as a collection: https://tfhub.dev/sayakpaul/collections/bit-resnet/1. A total of 8 models are available:
8 |
9 | | Model Name | Input Resolution | Classifier | Feature Extractor |
10 | |:---------------: |:-------------------: |:--------------------------------------------------------------------------: |:--------------------------------------------------------------------------: |
11 | | BiT-ResNet152x2 | 384 | [Link](https://tfhub.dev/sayakpaul/bit_resnet152x2_384_classification/1) | [Link](https://tfhub.dev/sayakpaul/bit_r152x2_384_feature_extraction/1) |
12 | | BiT-ResNet152x2 | 224 | [Link](https://tfhub.dev/sayakpaul/bit_resnet152x2_224_classification/1) | [Link](https://tfhub.dev/sayakpaul/bit_r152x2_224_feature_extraction/1) |
13 | | BiT-ResNet50x1 | 224 | [Link](https://tfhub.dev/sayakpaul/distill_bit_r50x1_224_classification/1) | [Link](https://tfhub.dev/sayakpaul/distill_bit_r50x1_224_classification/1) |
14 | | BiT-ResNet50x1 | 160 | [Link](https://tfhub.dev/sayakpaul/distill_bit_r50x1_160_classification/1) | [Link](https://tfhub.dev/sayakpaul/distill_bit_r50x1_160_classification/1) |
15 |
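As a rough illustration of how one of these SavedModels might be consumed, the sketch below wraps a handle from the table in a `hub.KerasLayer`. It assumes the `tensorflow_hub` package is installed and that the handles still resolve. The `HANDLES` dictionary and the `load_bit_layer` helper are names made up for this example, and only the two BiT-ResNet152x2 rows are included.

```python
# Handles copied from the table above (BiT-ResNet152x2 rows only).
# HANDLES and load_bit_layer are hypothetical names used for illustration.
HANDLES = {
    ("BiT-ResNet152x2", 384): {
        "classifier": "https://tfhub.dev/sayakpaul/bit_resnet152x2_384_classification/1",
        "feature_extractor": "https://tfhub.dev/sayakpaul/bit_r152x2_384_feature_extraction/1",
    },
    ("BiT-ResNet152x2", 224): {
        "classifier": "https://tfhub.dev/sayakpaul/bit_resnet152x2_224_classification/1",
        "feature_extractor": "https://tfhub.dev/sayakpaul/bit_r152x2_224_feature_extraction/1",
    },
}


def load_bit_layer(name, resolution, kind="classifier"):
    """Wrap the selected TF Hub handle in a Keras layer.

    tensorflow_hub is imported lazily so that the lookup table above can be
    inspected without TensorFlow installed.
    """
    import tensorflow_hub as hub

    return hub.KerasLayer(HANDLES[(name, resolution)][kind])
```

For instance, `load_bit_layer("BiT-ResNet152x2", 384)` would return a Keras layer wrapping the 384-resolution classifier. The expected input scaling is not stated in this README, so check the model page on TensorFlow Hub before feeding images.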
16 | You can use the `convert_jax_weights_tf.ipynb` notebook to see how the weights are ported from JAX to TensorFlow. The JAX team
17 | also maintains an experimental tool called `jax2tf`, documented in [its README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
18 |
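The notebook matches each entry of the `.npz` checkpoint to a TensorFlow variable by name. As a minimal sketch of that naming scheme, the snippet below enumerates the keys for the residual blocks. The key pattern and the per-block unit counts are taken from the notebook for BiT-ResNet152x2, while the `checkpoint_keys` helper is a hypothetical name introduced here. The optional projection kernels (`a/proj/standardized_conv2d/kernel`), which the notebook assigns only when they are present in the checkpoint, are left out.

```python
# Units per block for BiT-ResNet152x2, as used in the notebook.
UNITS_BY_BLOCK = {1: 3, 2: 8, 3: 36, 4: 3}


def checkpoint_keys(units_by_block=UNITS_BY_BLOCK):
    """Yield the group-norm and conv-kernel keys of every residual unit.

    Hypothetical helper: it only builds the key strings and does not read
    the checkpoint or touch any TensorFlow variable.
    """
    for block_nr, units in units_by_block.items():
        for unit_nr in range(1, units + 1):
            prefix = f"resnet/block{block_nr}/unit{unit_nr:02d}"
            for part in ("a", "b", "c"):
                yield f"{prefix}/{part}/group_norm/beta"
                yield f"{prefix}/{part}/group_norm/gamma"
                yield f"{prefix}/{part}/standardized_conv2d/kernel"


keys = list(checkpoint_keys())
print(len(keys))  # 50 units x 3 sub-units x 3 tensors = 450
print(keys[0])    # resnet/block1/unit01/a/group_norm/beta
```

Comparing such a list against the keys of the loaded checkpoint is one way to spot tensors that were never assigned.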
19 | ## References
20 |
21 | [1] [Big Transfer (BiT): General Visual Representation Learning by Kolesnikov et al.](https://arxiv.org/abs/1912.11370)
22 |
23 | [2] [Knowledge distillation: A good teacher is patient and consistent by Beyer et al.](https://arxiv.org/abs/2106.05237)
24 |
25 | [3] [BiT GitHub](https://github.com/google-research/big_transfer)
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/convert_jax_weights_tf.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | ""
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "id": "xpJr0WgojrSK"
17 | },
18 | "source": [
19 | "This notebook shows how to instantiate [BiT-ResNet models](https://arxiv.org/abs/1912.11370) in TensorFlow using code from the official repository [google-research/big_transfer](https://github.com/google-research/big_transfer) and load the original JAX weights into them. \n",
20 | "\n",
21 | "_**Note**: This notebook is authored by [Willi Gierke](https://ch.linkedin.com/in/willi-gierke) from Google. An initial version of the notebook was developed by Sayak Paul._"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {
28 | "colab": {
29 | "base_uri": "https://localhost:8080/"
30 | },
31 | "id": "VjXviEYKmV0T",
32 | "outputId": "f2f42bf7-42ba-4b96-a63a-35a1dc6d9b70"
33 | },
34 | "outputs": [
35 | {
36 | "name": "stdout",
37 | "output_type": "stream",
38 | "text": [
39 | "--2021-08-25 04:33:59-- https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz\n",
40 | "Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.170.128, 74.125.31.128, 173.194.210.128, ...\n",
41 | "Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.170.128|:443... connected.\n",
42 | "HTTP request sent, awaiting response... 200 OK\n",
43 | "Length: 945485848 (902M) [application/octet-stream]\n",
44 | "Saving to: ‘R152x2_T_384.npz’\n",
45 | "\n",
46 | "R152x2_T_384.npz 100%[===================>] 901.69M 68.7MB/s in 19s \n",
47 | "\n",
48 | "2021-08-25 04:34:20 (46.8 MB/s) - ‘R152x2_T_384.npz’ saved [945485848/945485848]\n",
49 | "\n",
50 | "Cloning into 'big_transfer'...\n",
51 | "remote: Enumerating objects: 31, done.\u001b[K\n",
52 | "remote: Counting objects: 100% (31/31), done.\u001b[K\n",
53 | "remote: Compressing objects: 100% (27/27), done.\u001b[K\n",
54 | "remote: Total 31 (delta 1), reused 23 (delta 1), pack-reused 0\u001b[K\n",
55 | "Unpacking objects: 100% (31/31), done.\n"
56 | ]
57 | }
58 | ],
59 | "source": [
60 | "# For demonstration purposes, we will be operating with a BiT-ResNet152x2 model.\n",
61 | "!wget https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz\n",
62 | "\n",
63 | "!git clone --depth 1 https://github.com/google-research/big_transfer\n",
64 | "\n",
65 | "import sys\n",
66 | "\n",
67 | "sys.path.append(\"big_transfer\")\n",
68 | "\n",
69 | "from bit_tf2 import models\n",
70 | "import tensorflow as tf\n",
71 | "import numpy as np\n",
72 | "\n",
73 | "from PIL import Image\n",
74 | "from io import BytesIO\n",
75 | "import requests\n",
76 | "\n",
77 | "\n",
78 | "def preprocess_image(image):\n",
79 | "    image = np.array(image)\n",
80 | "    # Resize to (384, 384) and scale pixel values from [0, 255] to [-1, 1].\n",
81 | "    image_resized = tf.image.resize(image, (384, 384))\n",
82 | "    image_resized = tf.cast(image_resized, tf.float32)\n",
83 | "    image_resized = (image_resized - 127.5) / 127.5\n",
84 | "    return tf.expand_dims(image_resized, 0).numpy()\n",
85 | "\n",
86 | "\n",
87 | "def load_image_from_url(url):\n",
88 | "    \"\"\"Returns an image with shape [1, height, width, num_channels].\"\"\"\n",
89 | "    response = requests.get(url)\n",
90 | "    image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
91 | "    image = preprocess_image(image)\n",
92 | "    return image\n",
93 | "\n",
94 | "\n",
95 | "def assert_valid_variables(model):\n",
96 | "    \"\"\"Raises an error if any weight consists only of 0s or only of 1s.\"\"\"\n",
97 | "    for i, layer in enumerate(model.layers):\n",
98 | "        print(f\"Layer {i}: {layer.name}\")\n",
99 | "        if not hasattr(layer, \"layers\"):\n",
100 | "            print(f\"{layer.name} has no .layers\")\n",
101 | "            continue\n",
102 | "        for j, sublayer in enumerate(layer.layers):\n",
103 | "            print(f\"Sublayer {j}: {sublayer.name}\")\n",
104 | "            for w in sublayer.get_weights():\n",
105 | "                print(w.shape)\n",
106 | "                if (w == 1.0).all() or (w == 0.0).all():\n",
107 | "                    raise RuntimeError(f\"PROBLEM in {layer.name}.{sublayer.name}: {w}\")"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {
114 | "colab": {
115 | "base_uri": "https://localhost:8080/"
116 | },
117 | "id": "S0xpekbRme1V",
118 | "outputId": "c6792e63-7e54-43f2-acf7-54dbc277473e"
119 | },
120 | "outputs": [
121 | {
122 | "name": "stdout",
123 | "output_type": "stream",
124 | "text": [
125 | "--2021-08-25 04:34:29-- https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt\n",
126 | "Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.204.128, 172.217.203.128, 173.194.213.128, ...\n",
127 | "Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.204.128|:443... connected.\n",
128 | "HTTP request sent, awaiting response... 200 OK\n",
129 | "Length: 21675 (21K) [text/plain]\n",
130 | "Saving to: ‘ilsvrc2012_wordnet_lemmas.txt’\n",
131 | "\n",
132 | "ilsvrc2012_wordnet_ 100%[===================>] 21.17K --.-KB/s in 0s \n",
133 | "\n",
134 | "2021-08-25 04:34:29 (112 MB/s) - ‘ilsvrc2012_wordnet_lemmas.txt’ saved [21675/21675]\n",
135 | "\n",
136 | "Model: \"resnet\"\n",
137 | "_________________________________________________________________\n",
138 | "Layer (type) Output Shape Param # \n",
139 | "=================================================================\n",
140 | "root_block (Sequential) (None, 96, 96, 128) 18816 \n",
141 | "_________________________________________________________________\n",
142 | "block1 (Sequential) (None, 96, 96, 512) 855808 \n",
143 | "_________________________________________________________________\n",
144 | "block2 (Sequential) (None, 48, 48, 1024) 9329664 \n",
145 | "_________________________________________________________________\n",
146 | "block3 (Sequential) (None, 24, 24, 2048) 162224128 \n",
147 | "_________________________________________________________________\n",
148 | "block4 (Sequential) (None, 12, 12, 4096) 59801600 \n",
149 | "_________________________________________________________________\n",
150 | "group_norm (GroupNormalizati multiple 8192 \n",
151 | "_________________________________________________________________\n",
152 | "re_lu_150 (ReLU) multiple 0 \n",
153 | "_________________________________________________________________\n",
154 | "global_average_pooling2d (Gl multiple 0 \n",
155 | "_________________________________________________________________\n",
156 | "head/dense (Dense) multiple 4097000 \n",
157 | "=================================================================\n",
158 | "Total params: 236,335,208\n",
159 | "Trainable params: 236,335,208\n",
160 | "Non-trainable params: 0\n",
161 | "_________________________________________________________________\n"
162 | ]
163 | }
164 | ],
165 | "source": [
166 | "# Load the labels.\n",
167 | "!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt\n",
168 | "\n",
169 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n",
170 | "    lines = f.readlines()\n",
171 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n",
172 | "\n",
173 | "# Load image (image provided is CC0 licensed)\n",
174 | "img_url = \"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n",
175 | "image = load_image_from_url(img_url)\n",
176 | "\n",
177 | "model = models.ResnetV2(\n",
178 | "    num_units=(3, 8, 36, 3),\n",
179 | "    num_outputs=1000,\n",
180 | "    filters_factor=8,\n",
181 | "    name=\"resnet\",\n",
182 | "    trainable=True,\n",
183 | "    dtype=tf.float32,\n",
184 | ")\n",
185 | "\n",
186 | "model.build((None, 384, 384, 3))\n",
187 | "model.summary()\n",
188 | "\n",
189 | "# Print smaller numpy arrays.\n",
190 | "np.set_printoptions(threshold=3, edgeitems=1)"
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": null,
196 | "metadata": {
197 | "id": "qC2XOUbPmZ_I"
198 | },
199 | "outputs": [],
200 | "source": [
201 | "# Load the weights.\n",
202 | "with open(\"R152x2_T_384.npz\", \"rb\") as f:\n",
203 | "    params_tf = np.load(f)\n",
204 | "    params_tf = dict(zip(params_tf.keys(), params_tf.values()))"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {
211 | "id": "gvGqDZcnmg3D"
212 | },
213 | "outputs": [],
214 | "source": [
215 | "# Assign the weights of each block to the matching TF variables. Check params_tf for details.\n",
216 | "units_by_block_nr = {1: 3, 2: 8, 3: 36, 4: 3}\n",
217 | "\n",
218 | "for block_nr, units in units_by_block_nr.items():\n",
219 | "    for unit_nr in range(units):\n",
220 | "        model.layers[block_nr].layers[unit_nr]._unit_a.layers[0]._beta.assign(\n",
221 | "            tf.Variable(\n",
222 | "                params_tf[\n",
223 | "                    f\"resnet/block{block_nr}/unit{unit_nr + 1:02d}/a/group_norm/beta\"\n",
224 | "                ]\n",
225 | "            )\n",
226 | "        )\n",
227 | "        model.layers[block_nr].layers[unit_nr]._unit_a.layers[0]._gamma.assign(\n",
228 | "            tf.Variable(\n",
229 | "                params_tf[\n",
230 | "                    f\"resnet/block{block_nr}/unit{unit_nr + 1:02d}/a/group_norm/gamma\"\n",
231 | "                ]\n",
232 | "            )\n",
233 | "        )\n",
234 | "        var_name = (\n",
235 | "            f\"resnet/block{block_nr}/unit{unit_nr + 1:02d}/a/standardized_conv2d/kernel\"\n",
236 | "        )\n",
237 | "        if var_name in params_tf:\n",
238 | "            model.layers[block_nr].layers[unit_nr]._unit_a_conv.kernel.assign(\n",
239 | "                tf.Variable(params_tf[var_name])\n",
240 | "            )\n",
241 | "\n",
242 | "        var_name = f\"resnet/block{block_nr}/unit{unit_nr + 1:02d}/a/proj/standardized_conv2d/kernel\"\n",
243 | "        if var_name in params_tf:\n",
244 | "            model.layers[block_nr].layers[unit_nr]._proj.kernel.assign(\n",
245 | "                tf.Variable(params_tf[var_name])\n",
246 | "            )\n",
247 | "\n",
248 | "        model.layers[block_nr].layers[unit_nr]._unit_b.layers[0]._beta.assign(\n",
249 | "            tf.Variable(\n",
250 | "                params_tf[\n",
251 | "                    f\"resnet/block{block_nr}/unit{unit_nr + 1:02d}/b/group_norm/beta\"\n",
252 | "                ]\n",
253 | "            )\n",
254 | "        )\n",
255 | "        model.layers[block_nr].layers[unit_nr]._unit_b.layers[0]._gamma.assign(\n",
256 | "            tf.Variable(\n",
257 | "                params_tf[\n",
258 | "                    f\"resnet/block{block_nr}/unit{unit_nr + 1:02d}/b/group_norm/gamma\"\n",
259 | "                ]\n",
260 | "            )\n",
261 | "        )\n",
262 | "        var_name = (\n",
263 | "            f\"resnet/block{block_nr}/unit{unit_nr + 1:02d}/b/standardized_conv2d/kernel\"\n",
264 | "        )\n",
265 | "        if var_name in params_tf:\n",
266 | "            model.layers[block_nr].layers[unit_nr]._unit_b.layers[-1].kernel.assign(\n",
267 | "                tf.Variable(params_tf[var_name])\n",
268 | "            )\n",
269 | "\n",
270 | "        model.layers[block_nr].layers[unit_nr]._unit_c.layers[0]._beta.assign(\n",
271 | "            tf.Variable(\n",
272 | "                params_tf[\n",
273 | "                    f\"resnet/block{block_nr}/unit{unit_nr + 1:02d}/c/group_norm/beta\"\n",
274 | "                ]\n",
275 | "            )\n",
276 | "        )\n",
277 | "        model.layers[block_nr].layers[unit_nr]._unit_c.layers[0]._gamma.assign(\n",
278 | "            tf.Variable(\n",
279 | "                params_tf[\n",
280 | "                    f\"resnet/block{block_nr}/unit{unit_nr + 1:02d}/c/group_norm/gamma\"\n",
281 | "                ]\n",
282 | "            )\n",
283 | "        )\n",
284 | "        var_name = (\n",
285 | "            f\"resnet/block{block_nr}/unit{unit_nr + 1:02d}/c/standardized_conv2d/kernel\"\n",
286 | "        )\n",
287 | "        if var_name in params_tf:\n",
288 | "            model.layers[block_nr].layers[unit_nr]._unit_c.layers[-1].kernel.assign(\n",
289 | "                tf.Variable(params_tf[var_name])\n",
290 | "            )"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": null,
296 | "metadata": {
297 | "colab": {
298 | "base_uri": "https://localhost:8080/"
299 | },
300 | "id": "oniPOxhlmuAn",
301 | "outputId": "44027624-fa17-4a6b-8081-0311a184dc28"
302 | },
303 | "outputs": [
304 | {
305 | "data": {
306 | "text/plain": [
307 | ""
308 | ]
309 | },
310 | "execution_count": 5,
311 | "metadata": {},
312 | "output_type": "execute_result"
313 | }
314 | ],
315 | "source": [
316 | "# Set the variables not included in the blocks.\n",
317 | "model.layers[0].layers[1].kernel.assign(\n",
318 | "    tf.Variable(params_tf[\"resnet/root_block/standardized_conv2d/kernel\"])\n",
319 | ")\n",
320 | "\n",
321 | "model.layers[5]._gamma.assign(tf.Variable(params_tf[\"resnet/group_norm/gamma\"]))\n",
322 | "model.layers[5]._beta.assign(tf.Variable(params_tf[\"resnet/group_norm/beta\"]))\n",
323 | "\n",
324 | "model.layers[-1].kernel.assign(\n",
325 | "    tf.Variable(params_tf[\"resnet/head/conv2d/kernel\"].reshape(4096, 1000))\n",
326 | ")\n",
327 | "model.layers[-1].bias.assign(tf.Variable(params_tf[\"resnet/head/conv2d/bias\"]))"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": null,
333 | "metadata": {
334 | "id": "EEDWeTZ5m0vX"
335 | },
336 | "outputs": [],
337 | "source": [
338 | "# Verify that it works.\n",
339 | "logits = model.predict(image)\n",
340 | "s = tf.nn.softmax(logits, 1)\n",
341 | "assert (\n",
342 | "    imagenet_int_to_str[tf.argmax(s, -1).numpy()[0]]\n",
343 | "    == \"Indian_elephant, Elephas_maximus\"\n",
344 | ")"
345 | ]
346 | }
347 | ],
348 | "metadata": {
349 | "colab": {
350 | "collapsed_sections": [],
351 | "include_colab_link": true,
352 | "name": "convert_jax_weights_tf",
353 | "provenance": []
354 | },
355 | "kernelspec": {
356 | "display_name": "Python 3 (ipykernel)",
357 | "language": "python",
358 | "name": "python3"
359 | },
360 | "language_info": {
361 | "codemirror_mode": {
362 | "name": "ipython",
363 | "version": 3
364 | },
365 | "file_extension": ".py",
366 | "mimetype": "text/x-python",
367 | "name": "python",
368 | "nbconvert_exporter": "python",
369 | "pygments_lexer": "ipython3",
370 | "version": "3.8.2"
371 | }
372 | },
373 | "nbformat": 4,
374 | "nbformat_minor": 1
375 | }
376 |
--------------------------------------------------------------------------------