├── README.md
├── LICENSE
├── DAN_Keras.ipynb
└── DAN_keras_core.ipynb

/README.md:
--------------------------------------------------------------------------------
1 | ### Problem Statement
2 | 
3 | Inferring emotions from facial expressions is critical for social communication, and deficits in facial emotion recognition are an important marker in the diagnosis of autism spectrum disorder. This project uses AI to help autistic individuals recognize emotions in facial expressions.
4 | 
5 | 
6 | ### Model
7 | 
8 | Keras and Keras Core implementations of the paper below:
9 | [Distract Your Attention: Multi-head Cross Attention Network for Facial Expression Recognition](https://arxiv.org/pdf/2109.07270.pdf)
10 | 
11 | ### Dataset
12 | 
13 | AffectNet is a database of facial expressions containing more than 1M facial images, collected from the Internet by querying three major search engines with 1,250 emotion-related keywords in six different languages. The labeled emotions are neutral, happy, angry, sad, fear, surprise, disgust, and contempt.
14 | 
15 | [AffectNet](http://mohammadmahoor.com/affectnet/)
16 | 
17 | [Mediapipe Demo](https://codepen.io/Tensor-Girl/pen/ExdJmmP)
18 | 
19 | ### Intel Extension for TensorFlow
20 | 
21 | To use the Intel Extension for TensorFlow, run `pip install --upgrade intel-extension-for-tensorflow[cpu]` at the beginning of the notebook.
22 | 
23 | ### References
24 | 
25 | [Distract Your Attention: Multi-head Cross Attention Network for Facial Expression Recognition](https://arxiv.org/pdf/2109.07270.pdf)
26 | 
27 | [PyTorch Implementation of DAN](https://github.com/yaoing/DAN)
28 | 
29 | [Official Keras Core Documentation](https://keras.io/keras_core/)
30 | 
31 | [Facial expression recognition as a candidate marker for autism spectrum disorder](https://molecularautism.biomedcentral.com/articles/10.1186/s13229-018-0187-7)
32 | 
33 | ### My Advocacy in Autism
34 | 
35 | #### NeuroAI
36 | 
37 | https://humansofdata.atlan.com/2019/08/unravel-the-mystery-of-the-human-brain-at-neuroai/
38 | 
39 | #### Neurodiversity India Summit
40 | 
41 | [2022](https://neuroaiworld.com/neurodiversity-india-summit-2022/)
42 | [2021](https://neuroaiworld.com/neurodiversity-india-summit-2021/)
43 | [2020](https://neuroaiworld.com/neurodiversity-india-summit-2020/)
44 | 
45 | 
46 | 
47 | 
48 | 
49 | 
50 | 
51 | 
52 | 
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /DAN_Keras.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "## Downloading data from google cloud" 7 | ], 8 | "metadata": { 9 | "id": "IR7K14_Ov9YI" 10 | } 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "colab": { 17 | "base_uri": "https://localhost:8080/" 18 | }, 19 | "id": "TyHKa3bNhty5", 20 | "outputId": "241bb2bb-467c-4601-e918-3129e9b2b608" 21 | }, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "--2023-07-19 06:54:36-- https://storage.googleapis.com/kerascvnlp_data/young-affectnet-hq.zip\n", 28 | "Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.170.128, 142.251.175.128, 172.253.118.128, ...\n", 29 | "Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.170.128|:443... connected.\n", 30 | "HTTP request sent, awaiting response... 200 OK\n", 31 | "Length: 5441294496 (5.1G) [application/zip]\n", 32 | "Saving to: ‘young-affectnet-hq.zip’\n", 33 | "\n", 34 | "young-affectnet-hq. 100%[===================>] 5.07G 18.9MB/s in 4m 44s \n", 35 | "\n", 36 | "2023-07-19 06:59:20 (18.3 MB/s) - ‘young-affectnet-hq.zip’ saved [5441294496/5441294496]\n", 37 | "\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "!wget https://storage.googleapis.com/kerascvnlp_data/young-affectnet-hq.zip" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "id": "-CRW5ufU3Qsj" 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "!unzip -q /content/young-affectnet-hq.zip -d data/" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "source": [ 59 | "## Building the [DAN model ](https://arxiv.org/pdf/2109.07270.pdf)" 60 | ], 61 | "metadata": { 62 | "id": "-ZhCIAYowH-h" 63 | } 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "colab": { 70 | "base_uri": "https://localhost:8080/" 71 | }, 72 | "id": "ALbPisTtWxPr", 73 | "outputId": "1e665d52-c401-45d8-a175-ee0f0602a56d" 74 | }, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "(10, 8)\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "import tensorflow as tf\n", 86 | "from tensorflow.keras import Model\n", 87 | "from tensorflow.keras.layers import Layer\n", 88 | "from tensorflow.keras import Sequential\n", 89 | "import tensorflow.keras.layers as nn\n", 90 | "import keras\n", 91 | "\n", 92 | "class ChannelAttn(Layer):\n", 93 | " def __init__(self, c=512) -> None:\n", 94 | " super(ChannelAttn,self).__init__()\n", 95 | " self.gap = nn.AveragePooling2D(7)\n", 96 | " self.attention = Sequential([\n", 97 | " nn.Dense(32),\n", 98 | " nn.BatchNormalization(),\n", 99 | " nn.ReLU(),\n", 100 | " nn.Dense(c,activation='sigmoid')]\n", 101 | " )\n", 102 | "\n", 103 | " def call(self, x):\n", 104 | "\n", 105 | " x = self.gap(x)\n", 106 | " x = nn.Flatten()(x)\n", 107 | " y = self.attention(x)\n", 108 | " return x * y\n", 109 | "\n", 110 | "\n", 111 | "class SpatialAttn(Layer):\n", 112 | " def __init__(self, c=512):\n", 113 | " super(SpatialAttn,self).__init__()\n", 114 | " self.conv1x1 = Sequential([\n", 115 | " nn.Conv2D(256, 1),\n", 116 | " nn.BatchNormalization()]\n", 117 | " )\n", 118 | " self.conv_3x3 = Sequential([\n", 119 | " nn.ZeroPadding2D(padding=(1, 1)),\n", 120 | " nn.Conv2D(512, 3,1),\n", 121 | " 
nn.BatchNormalization()]\n",
122 |    "        )\n",
123 |    "        self.conv_1x3 = Sequential([\n",
124 |    "            nn.ZeroPadding2D(padding=(0, 1)),\n",
125 |    "            nn.Conv2D(512, (1,3)),\n",
126 |    "            nn.BatchNormalization()]\n",
127 |    "        )\n",
128 |    "        self.conv_3x1 = Sequential([\n",
129 |    "            nn.ZeroPadding2D(padding=(1, 0)),\n",
130 |    "            nn.Conv2D(512,(3,1)),\n",
131 |    "            nn.BatchNormalization()]\n",
132 |    "        )\n",
133 |    "        self.norm = nn.ReLU()\n",
134 |    "\n",
135 |    "    def call(self, x):\n",
136 |    "        y = self.conv1x1(x)\n",
137 |    "        y = self.norm(self.conv_3x3(y) + self.conv_1x3(y) + self.conv_3x1(y))\n",
138 |    "        y = tf.math.reduce_sum(y,axis=1, keepdims=True)\n",
139 |    "        return x*y\n",
140 |    "\n",
141 |    "\n",
142 |    "class CrossAttnHead(Layer):\n",
143 |    "    def __init__(self, c=512):\n",
144 |    "        super(CrossAttnHead,self).__init__()\n",
145 |    "        self.sa = SpatialAttn(c)\n",
146 |    "        self.ca = ChannelAttn(c)\n",
147 |    "\n",
148 |    "    def call(self, x):\n",
149 |    "        return self.ca(self.sa(x))\n",
150 |    "\n",
151 |    "\n",
152 |    "\n",
153 |    "class DAN(Model):\n",
154 |    "    def __init__(self, num_classes=8):\n",
155 |    "        super(DAN,self).__init__()\n",
156 |    "        self.mod = tf.keras.applications.ResNet50(\n",
157 |    "            include_top=False,\n",
158 |    "            weights=\"imagenet\",\n",
159 |    "            input_shape=(224,224,3)\n",
160 |    "        )\n",
161 |    "        self.mod.trainable = False\n",
162 |    "        self.num_head = 4\n",
163 |    "        # four cross-attention heads run in parallel on the shared feature map\n",
164 |    "        self.hd = []\n",
165 |    "        for _ in range(self.num_head):\n",
166 |    "            self.hd.append(CrossAttnHead())\n",
167 |    "        self.features = nn.Conv2D(512, 1,padding='same')\n",
168 |    "        self.fc = nn.Dense(num_classes)\n",
169 |    "        self.bn = nn.BatchNormalization()\n",
170 |    "\n",
171 |    "    def call(self, x):\n",
172 |    "        x = self.mod(x)\n",
173 |    "        x = self.features(x)\n",
174 |    "        heads = []\n",
175 |    "        for h in self.hd:\n",
176 |    "            heads.append(h(x))\n",
177 |    "\n",
178 |    "        heads = tf.transpose(tf.stack(heads),perm=(1,0,2))\n",
179 |    "        heads = tf.nn.log_softmax(heads, axis=1)\n",
180 |    "        out = self.bn(self.fc(tf.math.reduce_sum(heads,axis=1)))\n",
181 |    "        return out\n",
182 |    "\n",
183 |    "model = DAN()\n",
184 |    "img = tf.random.normal(shape=[10, 224, 224, 3])\n",
185 |    "preds = model(img)\n",
186 |    "print(preds.shape)"
187 |   ]
188 |  },
189 |  {
190 |   "cell_type": "markdown",
191 |   "source": [
192 |    "Creating the image dataloader with the tf.keras ImageDataGenerator utility"
193 |   ],
194 |   "metadata": {
195 |    "id": "wcVEXZpmwR1P"
196 |   }
197 |  },
198 |  {
199 |   "cell_type": "code",
200 |   "execution_count": null,
201 |   "metadata": {
202 |    "colab": {
203 |     "base_uri": "https://localhost:8080/"
204 |    },
205 |    "id": "y-lfM07X386v",
206 |    "outputId": "4d0e46d8-7ce6-49f1-becb-a6c544951bef"
207 |   },
208 |   "outputs": [
209 |    {
210 |     "name": "stdout",
211 |     "output_type": "stream",
212 |     "text": [
213 |      "Found 14648 images belonging to 8 classes.\n"
214 |     ]
215 |    }
216 |   ],
217 |   "source": [
218 |    "train_datagen = tf.keras.preprocessing.image.ImageDataGenerator()\n",
219 |    "train_generator = train_datagen.flow_from_directory(\n",
220 |    "    directory=\"data/\",\n",
221 |    "    target_size=(224, 224),\n",
222 |    "    batch_size=32,\n",
223 |    "    class_mode=\"categorical\",\n",
224 |    "    shuffle=True,\n",
225 |    "    seed=42\n",
226 |    ")"
227 |   ]
228 |  },
229 |  {
230 |   "cell_type": "code",
231 |   "execution_count": null,
232 |   "metadata": {
233 |    "colab": {
234 |     "background_save": true
235 |    },
236 |    "id": "S-fYX5bf-Qhz"
237 |   },
238 |   "outputs": [],
239 |   "source": [
240 |    "model.compile(optimizer='adam', loss=keras.losses.CategoricalCrossentropy(from_logits=True))\n",
241 | 
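   "# from_logits=True matches the DAN head above, which ends in Dense + BatchNormalization with no softmax\n",
   "# the ResNet50 backbone is frozen (mod.trainable = False), so fit() trains only the DAN layers on top\n",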
"model.fit(train_generator,epochs=1)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "source": [ 247 | "## Saving and loading the model using inbuilt functions of keras_core" 248 | ], 249 | "metadata": { 250 | "id": "rYr1Xx99vt3R" 251 | } 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": { 257 | "id": "lKxEstbn-zud" 258 | }, 259 | "outputs": [], 260 | "source": [ 261 | "model.save('weights/')" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "metadata": { 268 | "id": "dYqcWuqkOJW7" 269 | }, 270 | "outputs": [], 271 | "source": [ 272 | "pb = keras.models.load_model('weights/')" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": { 279 | "id": "d2Kb428rOgga" 280 | }, 281 | "outputs": [], 282 | "source": [ 283 | "pred = pb(img)\n", 284 | "pred.shape" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": { 291 | "id": "bzObYUiROqnJ" 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "!zip -r wgts.zip weights/" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": { 302 | "id": "KVUFuilbizMP" 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "!wget https://storage.googleapis.com/kerascvnlp_data/danwgts/wgts.zip\n", 307 | "!unzip wgts.zip" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": { 314 | "colab": { 315 | "base_uri": "https://localhost:8080/" 316 | }, 317 | "id": "OgLYD0Kd1Z_X", 318 | "outputId": "c963fe41-60f0-434b-e783-c2ea720236b6" 319 | }, 320 | "outputs": [ 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", 326 | "Collecting tflite-support\n", 327 | " Downloading tflite-support-0.1.0a1.tar.gz (390 kB)\n", 328 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m390.3/390.3 kB\u001b[0m \u001b[31m26.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 329 | "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 330 | "Collecting pybind11>=2.4 (from tflite-support)\n", 331 | " Using cached pybind11-2.10.4-py3-none-any.whl (222 kB)\n", 332 | "Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tflite-support) (1.4.0)\n", 333 | "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from tflite-support) (1.22.4)\n", 334 | "Building wheels for collected packages: tflite-support\n", 335 | " Building wheel for tflite-support (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 336 | " Created wheel for tflite-support: filename=tflite_support-0.1.0a1-cp310-cp310-linux_x86_64.whl size=5942401 sha256=c7e9b35b7492f34e69eee330b1b65e327b281488484e256ff208e30908d0ddbb\n", 337 | " Stored in directory: /root/.cache/pip/wheels/71/5c/da/9e5e661ec26e03ee57e69428d40fffbefe3c0aff649c55776d\n", 338 | "Successfully built tflite-support\n", 339 | "Installing collected packages: pybind11, tflite-support\n", 340 | "Successfully installed pybind11-2.10.4 tflite-support-0.1.0a1\n" 341 | ] 342 | } 343 | ], 344 | "source": [ 345 | "!pip install tflite-support" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "metadata": { 352 | "colab": { 353 | "base_uri": "https://localhost:8080/" 354 | }, 355 | "id": "7ODbDxBn2hKV", 356 | "outputId": "8c1e5577-b1ad-445d-f3fe-61b8c70b5431" 357 | }, 358 | "outputs": [ 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | "2023-05-26 16:26:33.200042: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", 364 | "Finished populating metadata and associated file to the model:\n", 365 | "model.tflite\n", 366 | "The metadata json file has been saved to:\n", 367 | "/content/oiut/model.json\n", 368 | "The associated file that has been been packed to the model is:\n", 369 | "['labels.txt']\n" 370 | ] 371 | } 372 | ], 373 | "source": [ 374 | "!python ./metadata_writer_for_image_classifier.py \\\n", 375 | " --model_file=model.tflite \\\n", 376 | " --label_file=labels.txt \\\n", 377 | " --export_directory=/content/oiut" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": { 384 | "colab": { 385 | "base_uri": "https://localhost:8080/" 386 | }, 387 | "id": "SrDY-KX8Pify", 388 | "outputId": "606c229d-fb65-47b8-9e32-a4920cb7c273" 389 | }, 390 | "outputs": [ 391 | { 392 | "name": "stderr", 393 | "output_type": "stream", 394 | "text": [ 395 | "WARNING:absl:Found untraced functions such as conv2d_41_layer_call_fn, conv2d_41_layer_call_and_return_conditional_losses, _jit_compiled_convolution_op, dense_21_layer_call_fn, dense_21_layer_call_and_return_conditional_losses while saving (showing 5 of 99). 
These functions will not be directly callable after loading.\n"
396 |    ]
397 |   }
398 |  ],
399 |  "source": [
400 |   "import tensorflow as tf\n",
401 |   "\n",
402 |   "# Convert the model\n",
403 |   "pb = keras.models.load_model('weights/')\n",
404 |   "converter = tf.lite.TFLiteConverter.from_keras_model(pb) # build the converter from the loaded Keras model\n",
405 |   "tflite_model = converter.convert()\n",
406 |   "\n",
407 |   "# Save the model.\n",
408 |   "with open('model.tflite', 'wb') as f:\n",
409 |   "  f.write(tflite_model)"
410 |  ]
411 | },
412 | {
413 |  "cell_type": "code",
414 |  "execution_count": null,
415 |  "metadata": {
416 |   "id": "eRabRedAC7mo"
417 |  },
418 |  "outputs": [],
419 |  "source": []
420 | }
421 | ],
422 | "metadata": {
423 |  "accelerator": "GPU",
424 |  "colab": {
425 |   "provenance": []
426 |  },
427 |  "gpuClass": "standard",
428 |  "kernelspec": {
429 |   "display_name": "Python 3",
430 |   "name": "python3"
431 |  },
432 |  "language_info": {
433 |   "name": "python"
434 |  }
435 | },
436 | "nbformat": 4,
437 | "nbformat_minor": 0
438 | }
--------------------------------------------------------------------------------

/DAN_keras_core.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "source": [
6 |     "## Downloading data from google cloud"
7 |    ],
8 |    "metadata": {
9 |     "id": "IR7K14_Ov9YI"
10 |    }
11 |   },
12 |   {
13 |    "cell_type": "code",
14 |    "execution_count": null,
15 |    "metadata": {
16 |     "colab": {
17 |      "base_uri": "https://localhost:8080/"
18 |     },
19 |     "id": "TyHKa3bNhty5",
20 |     "outputId": "18be82ec-7136-4ea3-f5cc-0fee1ced6620"
21 |    },
22 |    "outputs": [
23 |     {
24 |      "output_type": "stream",
25 |      "name": "stdout",
26 |      "text": [
27 |       "--2023-07-24 17:29:19--  https://storage.googleapis.com/kerascvnlp_data/young-affectnet-hq.zip\n",
28 |       "Resolving storage.googleapis.com (storage.googleapis.com)... 172.253.117.128, 142.250.99.128, 142.250.107.128, ...\n",
29 |       "Connecting to storage.googleapis.com (storage.googleapis.com)|172.253.117.128|:443... connected.\n",
30 |       "HTTP request sent, awaiting response... 200 OK\n",
31 |       "Length: 5441294496 (5.1G) [application/zip]\n",
32 |       "Saving to: ‘young-affectnet-hq.zip’\n",
33 |       "\n",
34 |       "young-affectnet-hq. 
100%[===================>] 5.07G 177MB/s in 46s \n", 35 | "\n", 36 | "2023-07-24 17:30:06 (112 MB/s) - ‘young-affectnet-hq.zip’ saved [5441294496/5441294496]\n", 37 | "\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "!wget https://storage.googleapis.com/kerascvnlp_data/young-affectnet-hq.zip" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "id": "-CRW5ufU3Qsj" 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "!unzip -q /content/young-affectnet-hq.zip -d data/" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "source": [ 59 | "## installing and importing required libraries" 60 | ], 61 | "metadata": { 62 | "id": "mZlQ6Ir3wCA0" 63 | } 64 | }, 65 | { 66 | "cell_type": "code", 67 | "source": [ 68 | "!pip install -q keras-core" 69 | ], 70 | "metadata": { 71 | "colab": { 72 | "base_uri": "https://localhost:8080/" 73 | }, 74 | "id": "rOtfVr7a9Hf6", 75 | "outputId": "4bf1838e-9232-42ac-afba-61b1aa29f661" 76 | }, 77 | "execution_count": null, 78 | "outputs": [ 79 | { 80 | "output_type": "stream", 81 | "name": "stdout", 82 | "text": [ 83 | "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/753.1 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.8/753.1 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━\u001b[0m \u001b[32m583.7/753.1 kB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m753.1/753.1 kB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 84 | "\u001b[?25h" 85 | ] 86 | } 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "source": [ 92 | "## Building the [DAN model ](https://arxiv.org/pdf/2109.07270.pdf)" 93 | ], 94 | "metadata": { 95 | "id": "-ZhCIAYowH-h" 96 | } 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": { 102 | "colab": { 103 | "base_uri": "https://localhost:8080/" 104 | }, 105 | "id": "ALbPisTtWxPr", 106 | "outputId": "3457e7f8-db47-435e-8ad1-0c4c5eab2333" 107 | }, 108 | "outputs": [ 109 | { 110 | "output_type": "stream", 111 | "name": "stdout", 112 | "text": [ 113 | "(10, 8)\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "import tensorflow as tf\n", 119 | "from keras_core import Model, Sequential\n", 120 | "from keras_core.layers import Layer\n", 121 | "import keras_core.layers as nn\n", 122 | "import keras_core as keras\n", 123 | "\n", 124 | "class ChannelAttn(Layer):\n", 125 | " def __init__(self, c=512) -> None:\n", 126 | " super(ChannelAttn,self).__init__()\n", 127 | " self.gap = nn.AveragePooling2D(7)\n", 128 | " self.attention = Sequential([\n", 129 | " nn.Dense(32),\n", 130 | " nn.BatchNormalization(),\n", 131 | " nn.ReLU(),\n", 132 | " nn.Dense(c,activation='sigmoid')]\n", 133 | " )\n", 134 | "\n", 135 | " def call(self, x):\n", 136 | "\n", 137 | " x = self.gap(x)\n", 138 | " x = nn.Flatten()(x)\n", 139 | " y = self.attention(x)\n", 140 | " return x * y\n", 141 | "\n", 142 | "\n", 143 | "class SpatialAttn(Layer):\n", 144 | " def __init__(self, c=512):\n", 145 | " super(SpatialAttn,self).__init__()\n", 146 | " self.conv1x1 = Sequential([\n", 147 | " nn.Conv2D(256, 1),\n", 148 | " nn.BatchNormalization()]\n", 149 | " )\n", 150 | " self.conv_3x3 = 
Sequential([\n",
151 |    "            nn.ZeroPadding2D(padding=(1, 1)),\n",
152 |    "            nn.Conv2D(512, 3,1),\n",
153 |    "            nn.BatchNormalization()]\n",
154 |    "        )\n",
155 |    "        self.conv_1x3 = Sequential([\n",
156 |    "            nn.ZeroPadding2D(padding=(0, 1)),\n",
157 |    "            nn.Conv2D(512, (1,3)),\n",
158 |    "            nn.BatchNormalization()]\n",
159 |    "        )\n",
160 |    "        self.conv_3x1 = Sequential([\n",
161 |    "            nn.ZeroPadding2D(padding=(1, 0)),\n",
162 |    "            nn.Conv2D(512,(3,1)),\n",
163 |    "            nn.BatchNormalization()]\n",
164 |    "        )\n",
165 |    "        self.norm = nn.ReLU()\n",
166 |    "\n",
167 |    "    def call(self, x):\n",
168 |    "        y = self.conv1x1(x)\n",
169 |    "        y = self.norm(self.conv_3x3(y) + self.conv_1x3(y) + self.conv_3x1(y))\n",
170 |    "        y = tf.math.reduce_sum(y,axis=1, keepdims=True)\n",
171 |    "        return x*y\n",
172 |    "\n",
173 |    "\n",
174 |    "class CrossAttnHead(Layer):\n",
175 |    "    def __init__(self, c=512):\n",
176 |    "        super(CrossAttnHead,self).__init__()\n",
177 |    "        self.sa = SpatialAttn(c)\n",
178 |    "        self.ca = ChannelAttn(c)\n",
179 |    "\n",
180 |    "    def call(self, x):\n",
181 |    "        return self.ca(self.sa(x))\n",
182 |    "\n",
183 |    "\n",
184 |    "@keras.saving.register_keras_serializable(package='custom')\n",
185 |    "class DAN(Model):\n",
186 |    "    def __init__(self, num_classes=8, trainable=True, dtype='float32'):\n",
187 |    "        super(DAN,self).__init__()\n",
188 |    "        self.mod = keras.applications.ResNet50(\n",
189 |    "            include_top=False,\n",
190 |    "            weights=\"imagenet\",\n",
191 |    "            input_shape=(224,224,3)\n",
192 |    "        )\n",
193 |    "        self.mod.trainable = False\n",
194 |    "        self.num_head = 4\n",
195 |    "        # four cross-attention heads run in parallel on the shared feature map\n",
196 |    "        self.hd = []\n",
197 |    "        for _ in range(self.num_head):\n",
198 |    "            self.hd.append(CrossAttnHead())\n",
199 |    "        self.features = nn.Conv2D(512, 1,padding='same')\n",
200 |    "        self.fc = nn.Dense(num_classes)\n",
201 |    "        self.bn = nn.BatchNormalization()\n",
202 |    "\n",
203 |    "    def call(self, x):\n",
204 |    "        x = self.mod(x)\n",
205 |    "        x = self.features(x)\n",
206 |    "        heads = []\n",
207 |    "        for h in self.hd:\n",
208 |    "            heads.append(h(x))\n",
209 |    "\n",
210 |    "        heads = tf.transpose(tf.stack(heads),perm=(1,0,2))\n",
211 |    "        heads = keras.ops.log_softmax(heads, axis=1)\n",
212 |    "        out = self.bn(self.fc(tf.math.reduce_sum(heads,axis=1)))\n",
213 |    "        return out\n",
214 |    "\n",
215 |    "model = DAN()\n",
216 |    "img = tf.random.normal(shape=[10, 224, 224, 3])\n",
217 |    "preds = model(img)\n",
218 |    "print(preds.shape)"
219 |   ]
220 |  },
221 |  {
222 |   "cell_type": "markdown",
223 |   "source": [
224 |    "Creating the image dataloader with the keras_core image_dataset_from_directory utility"
225 |   ],
226 |   "metadata": {
227 |    "id": "wcVEXZpmwR1P"
228 |   }
229 |  },
230 |  {
231 |   "cell_type": "code",
232 |   "execution_count": null,
233 |   "metadata": {
234 |    "colab": {
235 |     "base_uri": "https://localhost:8080/"
236 |    },
237 |    "id": "y-lfM07X386v",
238 |    "outputId": "813829ba-e96d-4f0b-c49b-31d4a36da66a"
239 |   },
240 |   "outputs": [
241 |    {
242 |     "output_type": "stream",
243 |     "name": "stdout",
244 |     "text": [
245 |      "Found 14648 files belonging to 8 classes.\n"
246 |     ]
247 |    }
248 |   ],
249 |   "source": [
250 |    "train_ds = keras.utils.image_dataset_from_directory(\n",
251 |    "    directory=\"data/\",\n",
252 |    "    labels='inferred',\n",
253 |    "    label_mode='categorical',\n",
254 |    "    batch_size=32,\n",
255 |    "    image_size=(224, 224))"
256 |   ]
257 |  },
258 |  {
259 |   "cell_type": "code",
260 |   "execution_count": null,
261 |   "metadata": {
262 |    "id": "S-fYX5bf-Qhz"
263 |   },
264 |   "outputs": [],
265 |   "source": [
266 | 
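   "# the DAN head ends in Dense + BatchNormalization with no softmax, so the loss is computed from logits\n",
   "# the ResNet50 backbone is frozen (mod.trainable = False), so fit() trains only the DAN layers on top\n",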
"model.compile(optimizer='adam',loss=keras.losses.CategoricalCrossentropy())\n", 267 | "model.fit(train_ds,epochs=1)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "source": [ 273 | "## Saving and loading the model using inbuilt functions of keras_core" 274 | ], 275 | "metadata": { 276 | "id": "rYr1Xx99vt3R" 277 | } 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "id": "lKxEstbn-zud" 284 | }, 285 | "outputs": [], 286 | "source": [ 287 | "model.save('weights.keras')" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": { 294 | "id": "dYqcWuqkOJW7", 295 | "colab": { 296 | "base_uri": "https://localhost:8080/" 297 | }, 298 | "outputId": "a3715156-4564-456f-fc66-8038c62cda0e" 299 | }, 300 | "outputs": [ 301 | { 302 | "output_type": "stream", 303 | "name": "stderr", 304 | "text": [ 305 | "/usr/local/lib/python3.10/dist-packages/keras_core/src/saving/saving_lib.py:338: UserWarning: Skipping variable loading for optimizer 'adam', because it has 190 variables whereas the saved optimizer has 2 variables. \n", 306 | " trackable.load_own_variables(weights_store.get(inner_path))\n" 307 | ] 308 | } 309 | ], 310 | "source": [ 311 | "pb = keras.saving.load_model('weights.keras')" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": { 318 | "id": "d2Kb428rOgga", 319 | "colab": { 320 | "base_uri": "https://localhost:8080/" 321 | }, 322 | "outputId": "8ea79cca-e201-4e4f-be22-320af4af4105" 323 | }, 324 | "outputs": [ 325 | { 326 | "output_type": "execute_result", 327 | "data": { 328 | "text/plain": [ 329 | "TensorShape([10, 8])" 330 | ] 331 | }, 332 | "metadata": {}, 333 | "execution_count": 37 334 | } 335 | ], 336 | "source": [ 337 | "pred = pb(img)\n", 338 | "pred.shape" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": { 345 | "id": "bzObYUiROqnJ" 346 | }, 347 | "outputs": [], 348 | "source": [ 349 | "!zip -r wgts.zip weights.keras" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": null, 355 | "metadata": { 356 | "id": "KVUFuilbizMP" 357 | }, 358 | "outputs": [], 359 | "source": [ 360 | "!wget https://storage.googleapis.com/kerascvnlp_data/danwgts/wgts.zip\n", 361 | "!unzip wgts.zip" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": { 368 | "colab": { 369 | "base_uri": "https://localhost:8080/" 370 | }, 371 | "id": "OgLYD0Kd1Z_X", 372 | "outputId": "c963fe41-60f0-434b-e783-c2ea720236b6" 373 | }, 374 | "outputs": [ 375 | { 376 | "name": "stdout", 377 | "output_type": "stream", 378 | "text": [ 379 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", 380 | "Collecting tflite-support\n", 381 | " Downloading tflite-support-0.1.0a1.tar.gz (390 kB)\n", 382 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m390.3/390.3 kB\u001b[0m \u001b[31m26.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 383 | "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 384 | "Collecting pybind11>=2.4 (from tflite-support)\n", 385 | " Using cached pybind11-2.10.4-py3-none-any.whl (222 kB)\n", 386 | "Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tflite-support) (1.4.0)\n", 387 | "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from tflite-support) (1.22.4)\n", 388 | "Building wheels for collected packages: tflite-support\n", 389 | " Building wheel for tflite-support (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 390 | " Created wheel for tflite-support: filename=tflite_support-0.1.0a1-cp310-cp310-linux_x86_64.whl size=5942401 sha256=c7e9b35b7492f34e69eee330b1b65e327b281488484e256ff208e30908d0ddbb\n", 391 | " Stored in directory: /root/.cache/pip/wheels/71/5c/da/9e5e661ec26e03ee57e69428d40fffbefe3c0aff649c55776d\n", 392 | "Successfully built tflite-support\n", 393 | "Installing collected packages: pybind11, tflite-support\n", 394 | "Successfully installed pybind11-2.10.4 tflite-support-0.1.0a1\n" 395 | ] 396 | } 397 | ], 398 | "source": [ 399 | "!pip install tflite-support" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": { 406 | "colab": { 407 | "base_uri": "https://localhost:8080/" 408 | }, 409 | "id": "7ODbDxBn2hKV", 410 | "outputId": "8c1e5577-b1ad-445d-f3fe-61b8c70b5431" 411 | }, 412 | "outputs": [ 413 | { 414 | "name": "stdout", 415 | "output_type": "stream", 416 | "text": [ 417 | "2023-05-26 16:26:33.200042: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", 418 | "Finished populating metadata and associated file to the model:\n", 419 | "model.tflite\n", 420 | "The metadata json file has been saved to:\n", 421 | "/content/oiut/model.json\n", 422 | "The associated file that has been been packed to the model is:\n", 423 | "['labels.txt']\n" 424 | ] 425 | } 426 | ], 427 | "source": [ 428 | "!python ./metadata_writer_for_image_classifier.py \\\n", 429 | " --model_file=model.tflite \\\n", 430 | " --label_file=labels.txt \\\n", 431 | " --export_directory=/content/oiut" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": { 438 | "colab": { 439 | "base_uri": "https://localhost:8080/" 440 | }, 441 | "id": "SrDY-KX8Pify", 442 | "outputId": "606c229d-fb65-47b8-9e32-a4920cb7c273" 443 | }, 444 | "outputs": [ 445 | { 446 | "name": "stderr", 447 | "output_type": "stream", 448 | "text": [ 449 | "WARNING:absl:Found untraced functions such as conv2d_41_layer_call_fn, conv2d_41_layer_call_and_return_conditional_losses, _jit_compiled_convolution_op, dense_21_layer_call_fn, dense_21_layer_call_and_return_conditional_losses while saving (showing 5 of 99). 
These functions will not be directly callable after loading.\n"
450 |    ]
451 |   }
452 |  ],
453 |  "source": [
454 |   "import tensorflow as tf\n",
455 |   "\n",
456 |   "# Convert the model\n",
457 |   "pb = keras.saving.load_model('weights.keras')\n",
458 |   "converter = tf.lite.TFLiteConverter.from_keras_model(pb) # build the converter from the loaded Keras model\n",
459 |   "tflite_model = converter.convert()\n",
460 |   "\n",
461 |   "# Save the model.\n",
462 |   "with open('model.tflite', 'wb') as f:\n",
463 |   "  f.write(tflite_model)"
464 |  ]
465 | },
466 | {
467 |  "cell_type": "markdown",
468 |  "source": [
469 |   "## References\n",
470 |   "\n",
471 |   "[Distract Your Attention: Multi-head Cross Attention Network for Facial Expression Recognition](https://arxiv.org/pdf/2109.07270.pdf)\n",
472 |   "\n",
473 |   "[PyTorch Implementation of DAN](https://github.com/yaoing/DAN)\n",
474 |   "\n",
475 |   "[Official Keras Core Documentation](https://keras.io/keras_core/)"
476 |  ],
477 |  "metadata": {
478 |   "id": "m68bfzJJ8gVS"
479 |  }
480 | },
481 | {
482 |  "cell_type": "code",
483 |  "source": [],
484 |  "metadata": {
485 |   "id": "ZnkiAi6O81D1"
486 |  },
487 |  "execution_count": null,
488 |  "outputs": []
489 | }
490 | ],
491 | "metadata": {
492 |  "accelerator": "GPU",
493 |  "colab": {
494 |   "provenance": []
495 |  },
496 |  "gpuClass": "standard",
497 |  "kernelspec": {
498 |   "display_name": "Python 3",
499 |   "name": "python3"
500 |  },
501 |  "language_info": {
502 |   "name": "python"
503 |  }
504 | },
505 | "nbformat": 4,
506 | "nbformat_minor": 0
507 | }
--------------------------------------------------------------------------------