├── LICENSE ├── README.md └── Distillation_Toy_Example.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Sayak Paul 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Knowledge-Distillation-in-Keras 2 | Demonstrates knowledge distillation (KD) for image-based models in Keras. To learn more, check out my blog post [Distilling Knowledge in Neural Networks](https://app.wandb.ai/authors/knowledge-distillation/reports/Distilling-Knowledge-in-Deep-Neural-Networks--VmlldzoyMjkxODk) that accompanies this repository. The blog post covers the following points: 3 | 4 | - What is softmax telling us? 5 | - Using the softmax information for teaching - Knowledge distillation 6 | - Loss functions in knowledge distillation 7 | - A few training recipes 8 | - Experimental results 9 | - Conclusion 10 | 11 | ## About the notebooks 12 | - `Distillation_Toy_Example.ipynb` - KD on the FashionMNIST dataset 13 | - `Distillation_with_Transfer_Learning.ipynb` - KD (with the typical KD loss) on the Flowers dataset with a fine-tuned model 14 | - `Distillation_with_Transfer_Learning_MSE.ipynb` - KD (with an MSE loss) on the Flowers dataset with a fine-tuned model 15 | - `Effect_of_Data_Augmentation.ipynb` - studies the effect of data augmentation on KD 16 | 17 | ## Results 18 | Interact with all the results [here](https://app.wandb.ai/authors/knowledge-distillation). 19 | 20 | ## Acknowledgements 21 | I am grateful to [Aakash Kumar Nain](https://twitter.com/A_K_Nain) for providing valuable feedback on the code. 22 | 23 |
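## The KD loss at a glance
For a quick sense of what the notebooks do, here is a condensed sketch of the distillation step from `Distillation_Toy_Example.ipynb`: the teacher's logits are softened with a temperature, and the student is trained to match the softened distribution through a softmax cross-entropy scaled by `temperature**2`. Treat this as an illustrative summary rather than a drop-in module; `get_kd_loss` mirrors the notebook, while the `distill_step` wrapper is written just for this README.

```python
import tensorflow as tf

def get_kd_loss(student_logits, teacher_logits, temperature=0.5):
    # Soften the teacher's logits into a probability distribution and penalize
    # the student's (equally softened) logits with softmax cross-entropy.
    # The temperature**2 weight keeps gradient magnitudes comparable across temperatures.
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    return tf.compat.v1.losses.softmax_cross_entropy(
        teacher_probs, student_logits / temperature, weights=temperature**2)

@tf.function
def distill_step(images, teacher_model, student_model, optimizer, temperature=0.5):
    # The teacher only supplies soft targets; gradients flow through the student alone.
    teacher_logits = teacher_model(images, training=False)
    with tf.GradientTape() as tape:
        student_logits = student_model(images, training=True)
        loss = get_kd_loss(student_logits, teacher_logits, temperature)
    grads = tape.gradient(loss, student_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, student_model.trainable_variables))
    return loss
```

`Distillation_with_Transfer_Learning_MSE.ipynb` follows roughly the same loop but swaps this loss for a mean squared error between the teacher's and student's outputs.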
24 | -------------------------------------------------------------------------------- /Distillation_Toy_Example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Distillation Toy Example.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyPsr2km9BDyh8GSWlvpA85Y", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "id": "cBctyswACU4c", 33 | "colab_type": "code", 34 | "colab": {} 35 | }, 36 | "source": [ 37 | "# Imports\n", 38 | "import tensorflow as tf\n", 39 | "\n", 40 | "from tensorflow.keras import models\n", 41 | "from tensorflow.keras import layers\n", 42 | "\n", 43 | "tf.random.set_seed(666)" 44 | ], 45 | "execution_count": 1, 46 | "outputs": [] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "metadata": { 51 | "id": "YrhXamQACk6S", 52 | "colab_type": "code", 53 | "colab": { 54 | "base_uri": "https://localhost:8080/", 55 | "height": 34 56 | }, 57 | "outputId": "b2cea7cc-962f-44d4-bd25-7b8105504bd8" 58 | }, 59 | "source": [ 60 | "# Load the FashionMNIST dataset, scale the pixel values\n", 61 | "(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()\n", 62 | "X_train = X_train/255.\n", 63 | "X_test = X_test/255.\n", 64 | "\n", 65 | "X_train.shape, X_test.shape, y_train.shape, y_test.shape" 66 | ], 67 | "execution_count": 2, 68 | "outputs": [ 69 | { 70 | "output_type": "execute_result", 71 | "data": { 72 | "text/plain": [ 73 | "((60000, 28, 28), (10000, 28, 28), (60000,), (10000,))" 74 | ] 75 | }, 76 | "metadata": { 77 | "tags": [] 78 | }, 79 | "execution_count": 2 80 | } 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "metadata": { 86 | "id": "ZYeuzIyPCor2", 87 | "colab_type": "code", 88 | "colab": {} 89 | }, 90 | "source": [ 91 | "# Change the pixel values to float32 and reshape input data\n", 92 | "X_train = X_train.astype(\"float32\").reshape(-1, 28, 28, 1)\n", 93 | "X_test = X_test.astype(\"float32\").reshape(-1, 28, 28, 1)" 94 | ], 95 | "execution_count": 3, 96 | "outputs": [] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "metadata": { 101 | "id": "7R-PxhlfCqtu", 102 | "colab_type": "code", 103 | "colab": {} 104 | }, 105 | "source": [ 106 | "# Define utility function for building a basic shallow Convnet \n", 107 | "def get_teacher_model():\n", 108 | " model = models.Sequential()\n", 109 | " model.add(layers.Conv2D(16, (5, 5), activation=\"relu\",\n", 110 | " input_shape=(28, 28, 1)))\n", 111 | " model.add(layers.MaxPooling2D(pool_size=(2, 2)))\n", 112 | " model.add(layers.Conv2D(32, (5, 5), activation=\"relu\"))\n", 113 | " model.add(layers.MaxPooling2D(pool_size=(2, 2)))\n", 114 | " model.add(layers.Dropout(0.2))\n", 115 | " model.add(layers.Flatten())\n", 116 | " model.add(layers.Dense(128, activation=\"relu\"))\n", 117 | " model.add(layers.Dense(10))\n", 118 | " \n", 119 | " return model" 120 | ], 121 | "execution_count": 4, 122 | "outputs": [] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "metadata": { 127 | "id": "l07x1M5ADDWt", 128 | "colab_type": "code", 129 | "colab": {} 130 | }, 131 | "source": [ 132 | "# Define loss function and 
optimizer\n", 133 | "loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", 134 | "optimizer = tf.keras.optimizers.Adam()" 135 | ], 136 | "execution_count": 5, 137 | "outputs": [] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "metadata": { 142 | "id": "lcBBDW2JDI6y", 143 | "colab_type": "code", 144 | "colab": { 145 | "base_uri": "https://localhost:8080/", 146 | "height": 374 147 | }, 148 | "outputId": "edac5d3c-b1fb-4d19-dc0d-1684bed5715a" 149 | }, 150 | "source": [ 151 | "# Prepare TF dataset\n", 152 | "train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(100).batch(64)\n", 153 | "test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(64)\n", 154 | "\n", 155 | "# Train the teacher model\n", 156 | "teacher_model = get_teacher_model()\n", 157 | "teacher_model.compile(loss=loss_func, optimizer=optimizer, metrics=[\"accuracy\"])\n", 158 | "teacher_model.fit(train_ds,\n", 159 | " validation_data=test_ds,\n", 160 | " epochs=10)" 161 | ], 162 | "execution_count": 6, 163 | "outputs": [ 164 | { 165 | "output_type": "stream", 166 | "text": [ 167 | "Epoch 1/10\n", 168 | "938/938 [==============================] - 3s 3ms/step - loss: 0.5794 - accuracy: 0.7885 - val_loss: 0.4403 - val_accuracy: 0.8405\n", 169 | "Epoch 2/10\n", 170 | "938/938 [==============================] - 3s 3ms/step - loss: 0.3885 - accuracy: 0.8584 - val_loss: 0.3942 - val_accuracy: 0.8509\n", 171 | "Epoch 3/10\n", 172 | "938/938 [==============================] - 3s 3ms/step - loss: 0.3375 - accuracy: 0.8763 - val_loss: 0.3468 - val_accuracy: 0.8737\n", 173 | "Epoch 4/10\n", 174 | "938/938 [==============================] - 3s 3ms/step - loss: 0.3070 - accuracy: 0.8873 - val_loss: 0.3303 - val_accuracy: 0.8798\n", 175 | "Epoch 5/10\n", 176 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2877 - accuracy: 0.8945 - val_loss: 0.3120 - val_accuracy: 0.8846\n", 177 | "Epoch 6/10\n", 178 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2703 - accuracy: 0.8995 - val_loss: 0.2943 - val_accuracy: 0.8920\n", 179 | "Epoch 7/10\n", 180 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2544 - accuracy: 0.9056 - val_loss: 0.2818 - val_accuracy: 0.8960\n", 181 | "Epoch 8/10\n", 182 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2427 - accuracy: 0.9098 - val_loss: 0.2795 - val_accuracy: 0.8969\n", 183 | "Epoch 9/10\n", 184 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2327 - accuracy: 0.9141 - val_loss: 0.2767 - val_accuracy: 0.8998\n", 185 | "Epoch 10/10\n", 186 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2222 - accuracy: 0.9158 - val_loss: 0.2726 - val_accuracy: 0.9020\n" 187 | ], 188 | "name": "stdout" 189 | }, 190 | { 191 | "output_type": "execute_result", 192 | "data": { 193 | "text/plain": [ 194 | "" 195 | ] 196 | }, 197 | "metadata": { 198 | "tags": [] 199 | }, 200 | "execution_count": 6 201 | } 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "metadata": { 207 | "id": "OXADzI35Dw3g", 208 | "colab_type": "code", 209 | "colab": { 210 | "base_uri": "https://localhost:8080/", 211 | "height": 51 212 | }, 213 | "outputId": "b33dd53b-70ff-496b-b1ea-af4543f2b815" 214 | }, 215 | "source": [ 216 | "# Evaluate and serialize\n", 217 | "print(\"Test accuracy: {:.2f}\".format(teacher_model.evaluate(test_ds)[1]*100))\n", 218 | "teacher_model.save_weights(\"teacher_model.h5\")" 219 | ], 220 | "execution_count": 7, 221 | "outputs": [ 222 | { 223 | 
"output_type": "stream", 224 | "text": [ 225 | "157/157 [==============================] - 0s 2ms/step - loss: 0.2726 - accuracy: 0.9020\n", 226 | "Test accuracy: 90.20\n" 227 | ], 228 | "name": "stdout" 229 | } 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "metadata": { 235 | "id": "shnrhMFQFKwZ", 236 | "colab_type": "code", 237 | "colab": {} 238 | }, 239 | "source": [ 240 | "# Student model utility\n", 241 | "def get_student_model():\n", 242 | " model = models.Sequential()\n", 243 | " model.add(layers.Input(shape=(28, 28, 1)))\n", 244 | " model.add(layers.Flatten())\n", 245 | " model.add(layers.Dense(48, activation=\"relu\"))\n", 246 | " model.add(layers.Dense(10))\n", 247 | " \n", 248 | " return model" 249 | ], 250 | "execution_count": 8, 251 | "outputs": [] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "metadata": { 256 | "id": "dPFOtO4mGLIr", 257 | "colab_type": "code", 258 | "colab": {} 259 | }, 260 | "source": [ 261 | "# Credits: https://github.com/google-research/simclr/blob/master/colabs/distillation_self_training.ipynb\n", 262 | "def get_kd_loss(student_logits, teacher_logits, temperature=0.5):\n", 263 | " teacher_probs = tf.nn.softmax(teacher_logits / temperature)\n", 264 | " kd_loss = tf.compat.v1.losses.softmax_cross_entropy(\n", 265 | " teacher_probs, student_logits / temperature, temperature**2)\n", 266 | " return kd_loss" 267 | ], 268 | "execution_count": 9, 269 | "outputs": [] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "metadata": { 274 | "id": "KDZ5DWUeGkK2", 275 | "colab_type": "code", 276 | "colab": {} 277 | }, 278 | "source": [ 279 | "# Model, optimizer\n", 280 | "student_model = get_student_model()\n", 281 | "optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n", 282 | "\n", 283 | "# Average the loss across the batch size within an epoch\n", 284 | "train_loss = tf.keras.metrics.Mean(name=\"train_loss\")\n", 285 | "valid_loss = tf.keras.metrics.Mean(name=\"test_loss\")\n", 286 | "\n", 287 | "# Specify the performance metric\n", 288 | "train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name=\"train_acc\")\n", 289 | "valid_acc = tf.keras.metrics.SparseCategoricalAccuracy(name=\"valid_acc\")" 290 | ], 291 | "execution_count": 16, 292 | "outputs": [] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "metadata": { 297 | "id": "5w1sCCqQGeTe", 298 | "colab_type": "code", 299 | "colab": {} 300 | }, 301 | "source": [ 302 | "# Train utils\n", 303 | "@tf.function\n", 304 | "def model_train(images, labels, teacher_model, \n", 305 | " student_model, optimizer, temperature):\n", 306 | " teacher_logits = teacher_model(images)\n", 307 | "\n", 308 | " with tf.GradientTape() as tape:\n", 309 | " student_logits = student_model(images)\n", 310 | " loss = get_kd_loss(student_logits, teacher_logits, temperature)\n", 311 | " \n", 312 | " gradients = tape.gradient(loss, student_model.trainable_variables)\n", 313 | " optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))\n", 314 | "\n", 315 | " train_loss(loss)\n", 316 | " train_acc(labels, tf.nn.softmax(student_logits))" 317 | ], 318 | "execution_count": 17, 319 | "outputs": [] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "metadata": { 324 | "id": "qXjapT-hHeP1", 325 | "colab_type": "code", 326 | "colab": {} 327 | }, 328 | "source": [ 329 | "# Validation utils\n", 330 | "@tf.function\n", 331 | "def model_validate(images, labels, teacher_model, \n", 332 | " student_model, temperature):\n", 333 | " teacher_logits = teacher_model(images)\n", 334 | "\n", 335 | " student_logits = 
student_model(images)\n", 336 | " loss = get_kd_loss(student_logits, teacher_logits, temperature)\n", 337 | "\n", 338 | " valid_loss(loss)\n", 339 | " valid_acc(labels, tf.nn.softmax(student_logits))" 340 | ], 341 | "execution_count": 18, 342 | "outputs": [] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "metadata": { 347 | "id": "ph4r4J_zHqFE", 348 | "colab_type": "code", 349 | "colab": {} 350 | }, 351 | "source": [ 352 | "# Tie everything together\n", 353 | "def train_model(epochs, teacher_model, student_model, optimizer, temperature=0.5):\n", 354 | " for epoch in range(epochs):\n", 355 | " for (images, labels) in train_ds:\n", 356 | " model_train(images, labels, teacher_model, student_model, optimizer, temperature)\n", 357 | "\n", 358 | " for (images, labels) in test_ds:\n", 359 | " model_validate(images, labels, teacher_model, student_model, temperature)\n", 360 | " \n", 361 | " (loss, acc) = train_loss.result(), train_acc.result()\n", 362 | " (val_loss, val_acc) = valid_loss.result(), valid_acc.result()\n", 363 | " \n", 364 | " train_loss.reset_states(), train_acc.reset_states()\n", 365 | " valid_loss.reset_states(), valid_acc.reset_states()\n", 366 | " \n", 367 | " template = \"Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}\"\n", 368 | " print (template.format(epoch+1,\n", 369 | " loss,\n", 370 | " acc,\n", 371 | " val_loss,\n", 372 | " val_acc))\n", 373 | " \n", 374 | " \n", 375 | " return teacher_model, student_model" 376 | ], 377 | "execution_count": 19, 378 | "outputs": [] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "id": "O-breyl1dwNR", 384 | "colab_type": "code", 385 | "colab": { 386 | "base_uri": "https://localhost:8080/", 387 | "height": 187 388 | }, 389 | "outputId": "78143bf4-f890-4d72-bbef-3003b0bcf627" 390 | }, 391 | "source": [ 392 | "_, student_model = train_model(10, teacher_model, student_model, optimizer)" 393 | ], 394 | "execution_count": 20, 395 | "outputs": [ 396 | { 397 | "output_type": "stream", 398 | "text": [ 399 | "Epoch 1, loss: 0.116, acc: 0.816, val_loss: 0.097, val_acc: 0.825\n", 400 | "Epoch 2, loss: 0.091, acc: 0.848, val_loss: 0.091, val_acc: 0.838\n", 401 | "Epoch 3, loss: 0.086, acc: 0.853, val_loss: 0.088, val_acc: 0.841\n", 402 | "Epoch 4, loss: 0.084, acc: 0.857, val_loss: 0.086, val_acc: 0.846\n", 403 | "Epoch 5, loss: 0.082, acc: 0.858, val_loss: 0.089, val_acc: 0.838\n", 404 | "Epoch 6, loss: 0.081, acc: 0.861, val_loss: 0.085, val_acc: 0.848\n", 405 | "Epoch 7, loss: 0.080, acc: 0.862, val_loss: 0.088, val_acc: 0.840\n", 406 | "Epoch 8, loss: 0.079, acc: 0.863, val_loss: 0.092, val_acc: 0.838\n", 407 | "Epoch 9, loss: 0.078, acc: 0.864, val_loss: 0.085, val_acc: 0.850\n", 408 | "Epoch 10, loss: 0.078, acc: 0.864, val_loss: 0.086, val_acc: 0.845\n" 409 | ], 410 | "name": "stdout" 411 | } 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "id": "H0DHWweqcqIJ", 418 | "colab_type": "text" 419 | }, 420 | "source": [ 421 | "This can be further improved with longer training time and more careful hyperparameter tuning. 
" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "metadata": { 427 | "id": "fmLLgmpybLYi", 428 | "colab_type": "code", 429 | "colab": {} 430 | }, 431 | "source": [ 432 | "# Serialize\n", 433 | "student_model.save_weights(\"student_model.h5\")" 434 | ], 435 | "execution_count": 21, 436 | "outputs": [] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "metadata": { 441 | "id": "rJeZK9enJct9", 442 | "colab_type": "code", 443 | "colab": { 444 | "base_uri": "https://localhost:8080/", 445 | "height": 51 446 | }, 447 | "outputId": "765a834b-ac7d-4818-ed86-b2953c678d17" 448 | }, 449 | "source": [ 450 | "# Investigate the sizes\n", 451 | "!ls -lh *.h5" 452 | ], 453 | "execution_count": 22, 454 | "outputs": [ 455 | { 456 | "output_type": "stream", 457 | "text": [ 458 | "-rw-r--r-- 1 root root 163K Aug 31 07:47 student_model.h5\n", 459 | "-rw-r--r-- 1 root root 335K Aug 31 07:44 teacher_model.h5\n" 460 | ], 461 | "name": "stdout" 462 | } 463 | ] 464 | }, 465 | { 466 | "cell_type": "markdown", 467 | "metadata": { 468 | "id": "SNfgaGNncnSt", 469 | "colab_type": "text" 470 | }, 471 | "source": [ 472 | "Let's check the total number of trainable params." 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "metadata": { 478 | "id": "cSHNvya8cYLP", 479 | "colab_type": "code", 480 | "colab": { 481 | "base_uri": "https://localhost:8080/", 482 | "height": 425 483 | }, 484 | "outputId": "43bb53d7-5e8f-4d9a-bfdb-77e278d9ff20" 485 | }, 486 | "source": [ 487 | "teacher_model.summary()" 488 | ], 489 | "execution_count": 23, 490 | "outputs": [ 491 | { 492 | "output_type": "stream", 493 | "text": [ 494 | "Model: \"sequential\"\n", 495 | "_________________________________________________________________\n", 496 | "Layer (type) Output Shape Param # \n", 497 | "=================================================================\n", 498 | "conv2d (Conv2D) (None, 24, 24, 16) 416 \n", 499 | "_________________________________________________________________\n", 500 | "max_pooling2d (MaxPooling2D) (None, 12, 12, 16) 0 \n", 501 | "_________________________________________________________________\n", 502 | "conv2d_1 (Conv2D) (None, 8, 8, 32) 12832 \n", 503 | "_________________________________________________________________\n", 504 | "max_pooling2d_1 (MaxPooling2 (None, 4, 4, 32) 0 \n", 505 | "_________________________________________________________________\n", 506 | "dropout (Dropout) (None, 4, 4, 32) 0 \n", 507 | "_________________________________________________________________\n", 508 | "flatten (Flatten) (None, 512) 0 \n", 509 | "_________________________________________________________________\n", 510 | "dense (Dense) (None, 128) 65664 \n", 511 | "_________________________________________________________________\n", 512 | "dense_1 (Dense) (None, 10) 1290 \n", 513 | "=================================================================\n", 514 | "Total params: 80,202\n", 515 | "Trainable params: 80,202\n", 516 | "Non-trainable params: 0\n", 517 | "_________________________________________________________________\n" 518 | ], 519 | "name": "stdout" 520 | } 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "metadata": { 526 | "id": "T0-Y1gpDccZ_", 527 | "colab_type": "code", 528 | "colab": { 529 | "base_uri": "https://localhost:8080/", 530 | "height": 255 531 | }, 532 | "outputId": "d7a27cdb-13b4-4736-a1a0-a8a336357b30" 533 | }, 534 | "source": [ 535 | "student_model.summary()" 536 | ], 537 | "execution_count": 24, 538 | "outputs": [ 539 | { 540 | "output_type": "stream", 541 | "text": [ 542 | "Model: 
\"sequential_2\"\n", 543 | "_________________________________________________________________\n", 544 | "Layer (type) Output Shape Param # \n", 545 | "=================================================================\n", 546 | "flatten_2 (Flatten) (None, 784) 0 \n", 547 | "_________________________________________________________________\n", 548 | "dense_4 (Dense) (None, 48) 37680 \n", 549 | "_________________________________________________________________\n", 550 | "dense_5 (Dense) (None, 10) 490 \n", 551 | "=================================================================\n", 552 | "Total params: 38,170\n", 553 | "Trainable params: 38,170\n", 554 | "Non-trainable params: 0\n", 555 | "_________________________________________________________________\n" 556 | ], 557 | "name": "stdout" 558 | } 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": { 564 | "id": "AzC3KhO_J42N", 565 | "colab_type": "text" 566 | }, 567 | "source": [ 568 | "Further size decrease is possible with TFLite. " 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "metadata": { 574 | "id": "Z8d0R_ypVp8y", 575 | "colab_type": "code", 576 | "colab": {} 577 | }, 578 | "source": [ 579 | "# Credits: https://www.tensorflow.org/lite/performance/post_training_quant\n", 580 | "\n", 581 | "def representative_data_gen():\n", 582 | " for input_value in tf.data.Dataset.from_tensor_slices(X_train).batch(1).take(100):\n", 583 | " yield [input_value]\n", 584 | "\n", 585 | "def convert_to_tflite(model, tflite_file):\n", 586 | " converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", 587 | " converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", 588 | " converter.representative_dataset = representative_data_gen\n", 589 | " converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n", 590 | " converter.inference_input_type = tf.int8\n", 591 | " converter.inference_output_type = tf.int8\n", 592 | " tflite_quant_model = converter.convert()\n", 593 | "\n", 594 | " open(tflite_file, 'wb').write(tflite_quant_model)" 595 | ], 596 | "execution_count": 25, 597 | "outputs": [] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "metadata": { 602 | "id": "bZgxSge7Y3hU", 603 | "colab_type": "code", 604 | "colab": { 605 | "base_uri": "https://localhost:8080/", 606 | "height": 190 607 | }, 608 | "outputId": "71470926-d629-4bdb-f207-a6ed22a6f599" 609 | }, 610 | "source": [ 611 | "convert_to_tflite(teacher_model, \"teacher.tflite\")\n", 612 | "convert_to_tflite(student_model, \"student.tflite\")" 613 | ], 614 | "execution_count": 26, 615 | "outputs": [ 616 | { 617 | "output_type": "stream", 618 | "text": [ 619 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n", 620 | "Instructions for updating:\n", 621 | "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n", 622 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n", 623 | "Instructions for updating:\n", 624 | "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n", 625 | "INFO:tensorflow:Assets written to: /tmp/tmp5020kbxi/assets\n", 626 | "INFO:tensorflow:Assets written to: /tmp/tmp2t19bpk6/assets\n" 
627 | ], 628 | "name": "stdout" 629 | }, 630 | { 631 | "output_type": "stream", 632 | "text": [ 633 | "INFO:tensorflow:Assets written to: /tmp/tmp2t19bpk6/assets\n" 634 | ], 635 | "name": "stderr" 636 | } 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "metadata": { 642 | "id": "-f8eGqRtZA-w", 643 | "colab_type": "code", 644 | "colab": { 645 | "base_uri": "https://localhost:8080/", 646 | "height": 51 647 | }, 648 | "outputId": "2c34b9f9-b22a-4e2f-9385-893c778ed2ea" 649 | }, 650 | "source": [ 651 | "!ls -lh *.tflite" 652 | ], 653 | "execution_count": 27, 654 | "outputs": [ 655 | { 656 | "output_type": "stream", 657 | "text": [ 658 | "-rw-r--r-- 1 root root 40K Aug 31 07:48 student.tflite\n", 659 | "-rw-r--r-- 1 root root 85K Aug 31 07:48 teacher.tflite\n" 660 | ], 661 | "name": "stdout" 662 | } 663 | ] 664 | } 665 | ] 666 | } 667 | --------------------------------------------------------------------------------