├── LICENSE
├── README.md
└── Distillation_Toy_Example.ipynb
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Sayak Paul
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Knowledge-Distillation-in-Keras
2 | Demonstrates knowledge distillation (kd) for image-based models in Keras. To learn more, check out my blog post [Distilling Knowledge in Neural Networks](https://app.wandb.ai/authors/knowledge-distillation/reports/Distilling-Knowledge-in-Deep-Neural-Networks--VmlldzoyMjkxODk), which accompanies this repository. The blog post covers the following points:
3 |
4 | - What is softmax telling us?
5 | - Using the softmax information for teaching - Knowledge distillation
6 | - Loss functions in knowledge distillation
7 | - A few training recipes
8 | - Experimental results
9 | - Conclusion
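The distillation loss in `Distillation_Toy_Example.ipynb` (`get_kd_loss`, adapted from the SimCLR distillation colab) divides both sets of logits by a temperature T. Values above 1 soften the softmax and values below 1 sharpen it. The loss is then the cross-entropy between the teacher's softened probabilities and the student's softened logits, weighted by T². The pure-Python sketch below restates that per-example arithmetic so it is easy to inspect. It assumes plain lists of logits instead of TensorFlow tensors, and its helper names are illustrative, not part of the repository.

```python
import math

def log_softmax(logits, temperature=1.0):
    """Log of the temperature-scaled softmax, computed stably."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max before exponentiating to avoid overflow
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_norm for s in scaled]

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax: exp(z_i / T) / sum_j exp(z_j / T)."""
    return [math.exp(v) for v in log_softmax(logits, temperature)]

def kd_loss(student_logits, teacher_logits, temperature=0.5):
    """T^2 * cross-entropy between the softened teacher and student
    distributions for a single example (illustrative helper)."""
    teacher_probs = softmax(teacher_logits, temperature)
    student_log_probs = log_softmax(student_logits, temperature)
    cross_entropy = -sum(p * lq for p, lq in zip(teacher_probs, student_log_probs))
    return temperature ** 2 * cross_entropy

def batch_kd_loss(student_batch, teacher_batch, temperature=0.5):
    """Mean of the per-example losses over a batch (illustrative helper)."""
    losses = [kd_loss(s, t, temperature) for s, t in zip(student_batch, teacher_batch)]
    return sum(losses) / len(losses)

# Higher temperatures flatten the distribution over classes.
logits = [2.0, 1.0, 0.0]
print([round(p, 3) for p in softmax(logits, temperature=1.0)])
print([round(p, 3) for p in softmax(logits, temperature=5.0)])
```

When the student's logits equal the teacher's, the loss reduces to T² times the entropy of the softened teacher distribution. That is its minimum, so the student is pulled toward the teacher's whole output distribution, not just its top class. With its default reduction, `tf.compat.v1.losses.softmax_cross_entropy` should average this per-example term over the batch, as `batch_kd_loss` does.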
10 |
11 | ## About the notebooks
12 | - `Distillation_Toy_Example.ipynb` - kd on the FashionMNIST dataset
13 | - `Distillation_with_Transfer_Learning.ipynb` - kd (with the typical KD loss) on the Flowers dataset with a fine-tuned model
14 | - `Distillation_with_Transfer_Learning_MSE.ipynb` - kd (with an MSE loss) on the Flowers dataset with a fine-tuned model
15 | - `Effect_of_Data_Augmentation.ipynb` - studies the effect of data augmentation on kd
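Two of the notebooks differ only in the distillation objective: the "typical KD loss" versus an MSE loss. The pure-Python sketch below contrasts common forms of both objectives, but it does not reproduce the notebooks' exact losses. It assumes the typical loss is the Hinton-style mix of a hard-label cross-entropy and a T²-weighted soft-target term. It also assumes the MSE variant compares raw student and teacher logits. `alpha`, the default temperature, and every function name here are hypothetical.

```python
import math

def log_softmax(logits, temperature=1.0):
    """Log of the temperature-scaled softmax, computed stably."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_norm for s in scaled]

def hinton_kd_loss(student_logits, teacher_logits, label, alpha=0.1, temperature=4.0):
    """Hypothetical Hinton-style objective for one example:
    alpha * CE(hard label) + (1 - alpha) * T^2 * CE(soft teacher targets)."""
    hard_ce = -log_softmax(student_logits)[label]
    teacher_probs = [math.exp(v) for v in log_softmax(teacher_logits, temperature)]
    student_log_probs = log_softmax(student_logits, temperature)
    soft_ce = -sum(p * lq for p, lq in zip(teacher_probs, student_log_probs))
    return alpha * hard_ce + (1.0 - alpha) * temperature ** 2 * soft_ce

def logit_mse_loss(student_logits, teacher_logits):
    """Mean squared error between raw student and teacher logits (one example)."""
    if len(student_logits) != len(teacher_logits):
        raise ValueError("logit vectors must have the same length")
    squared = [(s - t) ** 2 for s, t in zip(student_logits, teacher_logits)]
    return sum(squared) / len(squared)

teacher = [4.0, 1.0, -2.0]
student = [1.0, 0.5, 0.0]
shifted = [z + 10.0 for z in student]  # same softmax as `student`, different logits
print(hinton_kd_loss(student, teacher, label=0), hinton_kd_loss(shifted, teacher, label=0))
print(logit_mse_loss(student, teacher), logit_mse_loss(shifted, teacher))
```

The two `hinton_kd_loss` values agree because softmax ignores a constant shift of the logits. The MSE loss changes, because matching logits asks the student to reproduce the teacher's raw values, not just their relative sizes.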
16 |
17 | ## Results
18 | Interact with all the results [here](https://app.wandb.ai/authors/knowledge-distillation).
19 |
20 | ## Acknowledgements
21 | I am grateful to [Aakash Kumar Nain](https://twitter.com/A_K_Nain) for providing valuable feedback on the code.
22 |
23 |
24 |
--------------------------------------------------------------------------------
/Distillation_Toy_Example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Distillation Toy Example.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyPsr2km9BDyh8GSWlvpA85Y",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | ""
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "id": "cBctyswACU4c",
33 | "colab_type": "code",
34 | "colab": {}
35 | },
36 | "source": [
37 | "# Imports\n",
38 | "import tensorflow as tf\n",
39 | "\n",
40 | "from tensorflow.keras import models\n",
41 | "from tensorflow.keras import layers\n",
42 | "\n",
43 | "tf.random.set_seed(666)"
44 | ],
45 | "execution_count": 1,
46 | "outputs": []
47 | },
48 | {
49 | "cell_type": "code",
50 | "metadata": {
51 | "id": "YrhXamQACk6S",
52 | "colab_type": "code",
53 | "colab": {
54 | "base_uri": "https://localhost:8080/",
55 | "height": 34
56 | },
57 | "outputId": "b2cea7cc-962f-44d4-bd25-7b8105504bd8"
58 | },
59 | "source": [
60 | "# Load the FashionMNIST dataset, scale the pixel values\n",
61 | "(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()\n",
62 | "X_train = X_train/255.\n",
63 | "X_test = X_test/255.\n",
64 | "\n",
65 | "X_train.shape, X_test.shape, y_train.shape, y_test.shape"
66 | ],
67 | "execution_count": 2,
68 | "outputs": [
69 | {
70 | "output_type": "execute_result",
71 | "data": {
72 | "text/plain": [
73 | "((60000, 28, 28), (10000, 28, 28), (60000,), (10000,))"
74 | ]
75 | },
76 | "metadata": {
77 | "tags": []
78 | },
79 | "execution_count": 2
80 | }
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "metadata": {
86 | "id": "ZYeuzIyPCor2",
87 | "colab_type": "code",
88 | "colab": {}
89 | },
90 | "source": [
91 | "# Change the pixel values to float32 and reshape input data\n",
92 | "X_train = X_train.astype(\"float32\").reshape(-1, 28, 28, 1)\n",
93 | "X_test = X_test.astype(\"float32\").reshape(-1, 28, 28, 1)"
94 | ],
95 | "execution_count": 3,
96 | "outputs": []
97 | },
98 | {
99 | "cell_type": "code",
100 | "metadata": {
101 | "id": "7R-PxhlfCqtu",
102 | "colab_type": "code",
103 | "colab": {}
104 | },
105 | "source": [
106 | "# Utility function for building the teacher: a shallow ConvNet\n",
107 | "def get_teacher_model():\n",
108 | "    model = models.Sequential()\n",
109 | "    model.add(layers.Conv2D(16, (5, 5), activation=\"relu\",\n",
110 | "                            input_shape=(28, 28, 1)))\n",
111 | "    model.add(layers.MaxPooling2D(pool_size=(2, 2)))\n",
112 | "    model.add(layers.Conv2D(32, (5, 5), activation=\"relu\"))\n",
113 | "    model.add(layers.MaxPooling2D(pool_size=(2, 2)))\n",
114 | "    model.add(layers.Dropout(0.2))\n",
115 | "    model.add(layers.Flatten())\n",
116 | "    model.add(layers.Dense(128, activation=\"relu\"))\n",
117 | "    model.add(layers.Dense(10))\n",
118 | "\n",
119 | "    return model"
120 | ],
121 | "execution_count": 4,
122 | "outputs": []
123 | },
124 | {
125 | "cell_type": "code",
126 | "metadata": {
127 | "id": "l07x1M5ADDWt",
128 | "colab_type": "code",
129 | "colab": {}
130 | },
131 | "source": [
132 | "# Define loss function and optimizer\n",
133 | "loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
134 | "optimizer = tf.keras.optimizers.Adam()"
135 | ],
136 | "execution_count": 5,
137 | "outputs": []
138 | },
139 | {
140 | "cell_type": "code",
141 | "metadata": {
142 | "id": "lcBBDW2JDI6y",
143 | "colab_type": "code",
144 | "colab": {
145 | "base_uri": "https://localhost:8080/",
146 | "height": 374
147 | },
148 | "outputId": "edac5d3c-b1fb-4d19-dc0d-1684bed5715a"
149 | },
150 | "source": [
151 | "# Prepare TF dataset\n",
152 | "train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(100).batch(64)\n",
153 | "test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(64)\n",
154 | "\n",
155 | "# Train the teacher model\n",
156 | "teacher_model = get_teacher_model()\n",
157 | "teacher_model.compile(loss=loss_func, optimizer=optimizer, metrics=[\"accuracy\"])\n",
158 | "teacher_model.fit(train_ds,\n",
159 | " validation_data=test_ds,\n",
160 | " epochs=10)"
161 | ],
162 | "execution_count": 6,
163 | "outputs": [
164 | {
165 | "output_type": "stream",
166 | "text": [
167 | "Epoch 1/10\n",
168 | "938/938 [==============================] - 3s 3ms/step - loss: 0.5794 - accuracy: 0.7885 - val_loss: 0.4403 - val_accuracy: 0.8405\n",
169 | "Epoch 2/10\n",
170 | "938/938 [==============================] - 3s 3ms/step - loss: 0.3885 - accuracy: 0.8584 - val_loss: 0.3942 - val_accuracy: 0.8509\n",
171 | "Epoch 3/10\n",
172 | "938/938 [==============================] - 3s 3ms/step - loss: 0.3375 - accuracy: 0.8763 - val_loss: 0.3468 - val_accuracy: 0.8737\n",
173 | "Epoch 4/10\n",
174 | "938/938 [==============================] - 3s 3ms/step - loss: 0.3070 - accuracy: 0.8873 - val_loss: 0.3303 - val_accuracy: 0.8798\n",
175 | "Epoch 5/10\n",
176 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2877 - accuracy: 0.8945 - val_loss: 0.3120 - val_accuracy: 0.8846\n",
177 | "Epoch 6/10\n",
178 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2703 - accuracy: 0.8995 - val_loss: 0.2943 - val_accuracy: 0.8920\n",
179 | "Epoch 7/10\n",
180 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2544 - accuracy: 0.9056 - val_loss: 0.2818 - val_accuracy: 0.8960\n",
181 | "Epoch 8/10\n",
182 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2427 - accuracy: 0.9098 - val_loss: 0.2795 - val_accuracy: 0.8969\n",
183 | "Epoch 9/10\n",
184 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2327 - accuracy: 0.9141 - val_loss: 0.2767 - val_accuracy: 0.8998\n",
185 | "Epoch 10/10\n",
186 | "938/938 [==============================] - 3s 3ms/step - loss: 0.2222 - accuracy: 0.9158 - val_loss: 0.2726 - val_accuracy: 0.9020\n"
187 | ],
188 | "name": "stdout"
189 | },
190 | {
191 | "output_type": "execute_result",
192 | "data": {
193 | "text/plain": [
194 | ""
195 | ]
196 | },
197 | "metadata": {
198 | "tags": []
199 | },
200 | "execution_count": 6
201 | }
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "metadata": {
207 | "id": "OXADzI35Dw3g",
208 | "colab_type": "code",
209 | "colab": {
210 | "base_uri": "https://localhost:8080/",
211 | "height": 51
212 | },
213 | "outputId": "b33dd53b-70ff-496b-b1ea-af4543f2b815"
214 | },
215 | "source": [
216 | "# Evaluate and serialize\n",
217 | "print(\"Test accuracy: {:.2f}\".format(teacher_model.evaluate(test_ds)[1]*100))\n",
218 | "teacher_model.save_weights(\"teacher_model.h5\")"
219 | ],
220 | "execution_count": 7,
221 | "outputs": [
222 | {
223 | "output_type": "stream",
224 | "text": [
225 | "157/157 [==============================] - 0s 2ms/step - loss: 0.2726 - accuracy: 0.9020\n",
226 | "Test accuracy: 90.20\n"
227 | ],
228 | "name": "stdout"
229 | }
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "metadata": {
235 | "id": "shnrhMFQFKwZ",
236 | "colab_type": "code",
237 | "colab": {}
238 | },
239 | "source": [
240 | "# Student model utility\n",
241 | "def get_student_model():\n",
242 | "    model = models.Sequential()\n",
243 | "    model.add(layers.Input(shape=(28, 28, 1)))\n",
244 | "    model.add(layers.Flatten())\n",
245 | "    model.add(layers.Dense(48, activation=\"relu\"))\n",
246 | "    model.add(layers.Dense(10))\n",
247 | "\n",
248 | "    return model"
249 | ],
250 | "execution_count": 8,
251 | "outputs": []
252 | },
253 | {
254 | "cell_type": "code",
255 | "metadata": {
256 | "id": "dPFOtO4mGLIr",
257 | "colab_type": "code",
258 | "colab": {}
259 | },
260 | "source": [
261 | "# Credits: https://github.com/google-research/simclr/blob/master/colabs/distillation_self_training.ipynb\n",
262 | "def get_kd_loss(student_logits, teacher_logits, temperature=0.5):\n",
263 | "    teacher_probs = tf.nn.softmax(teacher_logits / temperature)\n",
264 | "    kd_loss = tf.compat.v1.losses.softmax_cross_entropy(\n",
265 | "        teacher_probs, student_logits / temperature, temperature**2)\n",
266 | "    return kd_loss"
267 | ],
268 | "execution_count": 9,
269 | "outputs": []
270 | },
271 | {
272 | "cell_type": "code",
273 | "metadata": {
274 | "id": "KDZ5DWUeGkK2",
275 | "colab_type": "code",
276 | "colab": {}
277 | },
278 | "source": [
279 | "# Model, optimizer\n",
280 | "student_model = get_student_model()\n",
281 | "optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)\n",
282 | "\n",
283 | "# Running means of the per-batch losses over an epoch\n",
284 | "train_loss = tf.keras.metrics.Mean(name=\"train_loss\")\n",
285 | "valid_loss = tf.keras.metrics.Mean(name=\"valid_loss\")\n",
286 | "\n",
287 | "# Specify the performance metric\n",
288 | "train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name=\"train_acc\")\n",
289 | "valid_acc = tf.keras.metrics.SparseCategoricalAccuracy(name=\"valid_acc\")"
290 | ],
291 | "execution_count": 16,
292 | "outputs": []
293 | },
294 | {
295 | "cell_type": "code",
296 | "metadata": {
297 | "id": "5w1sCCqQGeTe",
298 | "colab_type": "code",
299 | "colab": {}
300 | },
301 | "source": [
302 | "# Train utils\n",
303 | "@tf.function\n",
304 | "def model_train(images, labels, teacher_model, \n",
305 | "                student_model, optimizer, temperature):\n",
306 | "    teacher_logits = teacher_model(images, training=False)  # frozen teacher: Dropout off\n",
307 | "\n",
308 | "    with tf.GradientTape() as tape:\n",
309 | "        student_logits = student_model(images, training=True)\n",
310 | "        loss = get_kd_loss(student_logits, teacher_logits, temperature)\n",
311 | "\n",
312 | "    gradients = tape.gradient(loss, student_model.trainable_variables)\n",
313 | "    optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))\n",
314 | "\n",
315 | "    train_loss(loss)\n",
316 | "    train_acc(labels, tf.nn.softmax(student_logits))"
317 | ],
318 | "execution_count": 17,
319 | "outputs": []
320 | },
321 | {
322 | "cell_type": "code",
323 | "metadata": {
324 | "id": "qXjapT-hHeP1",
325 | "colab_type": "code",
326 | "colab": {}
327 | },
328 | "source": [
329 | "# Validation utils\n",
330 | "@tf.function\n",
331 | "def model_validate(images, labels, teacher_model, \n",
332 | "                   student_model, temperature):\n",
333 | "    teacher_logits = teacher_model(images, training=False)\n",
334 | "\n",
335 | "    student_logits = student_model(images, training=False)\n",
336 | "    loss = get_kd_loss(student_logits, teacher_logits, temperature)\n",
337 | "\n",
338 | "    valid_loss(loss)\n",
339 | "    valid_acc(labels, tf.nn.softmax(student_logits))"
340 | ],
341 | "execution_count": 18,
342 | "outputs": []
343 | },
344 | {
345 | "cell_type": "code",
346 | "metadata": {
347 | "id": "ph4r4J_zHqFE",
348 | "colab_type": "code",
349 | "colab": {}
350 | },
351 | "source": [
352 | "# Tie everything together\n",
353 | "def train_model(epochs, teacher_model, student_model, optimizer, temperature=0.5):\n",
354 | "    for epoch in range(epochs):\n",
355 | "        for (images, labels) in train_ds:\n",
356 | "            model_train(images, labels, teacher_model, student_model, optimizer, temperature)\n",
357 | "\n",
358 | "        for (images, labels) in test_ds:\n",
359 | "            model_validate(images, labels, teacher_model, student_model, temperature)\n",
360 | "\n",
361 | "        (loss, acc) = train_loss.result(), train_acc.result()\n",
362 | "        (val_loss, val_acc) = valid_loss.result(), valid_acc.result()\n",
363 | "\n",
364 | "        train_loss.reset_states(), train_acc.reset_states()\n",
365 | "        valid_loss.reset_states(), valid_acc.reset_states()\n",
366 | "\n",
367 | "        template = \"Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}\"\n",
368 | "        print(template.format(epoch+1,\n",
369 | "                              loss,\n",
370 | "                              acc,\n",
371 | "                              val_loss,\n",
372 | "                              val_acc))\n",
373 | "\n",
374 | "\n",
375 | "    return teacher_model, student_model"
376 | ],
377 | "execution_count": 19,
378 | "outputs": []
379 | },
380 | {
381 | "cell_type": "code",
382 | "metadata": {
383 | "id": "O-breyl1dwNR",
384 | "colab_type": "code",
385 | "colab": {
386 | "base_uri": "https://localhost:8080/",
387 | "height": 187
388 | },
389 | "outputId": "78143bf4-f890-4d72-bbef-3003b0bcf627"
390 | },
391 | "source": [
392 | "_, student_model = train_model(10, teacher_model, student_model, optimizer)"
393 | ],
394 | "execution_count": 20,
395 | "outputs": [
396 | {
397 | "output_type": "stream",
398 | "text": [
399 | "Epoch 1, loss: 0.116, acc: 0.816, val_loss: 0.097, val_acc: 0.825\n",
400 | "Epoch 2, loss: 0.091, acc: 0.848, val_loss: 0.091, val_acc: 0.838\n",
401 | "Epoch 3, loss: 0.086, acc: 0.853, val_loss: 0.088, val_acc: 0.841\n",
402 | "Epoch 4, loss: 0.084, acc: 0.857, val_loss: 0.086, val_acc: 0.846\n",
403 | "Epoch 5, loss: 0.082, acc: 0.858, val_loss: 0.089, val_acc: 0.838\n",
404 | "Epoch 6, loss: 0.081, acc: 0.861, val_loss: 0.085, val_acc: 0.848\n",
405 | "Epoch 7, loss: 0.080, acc: 0.862, val_loss: 0.088, val_acc: 0.840\n",
406 | "Epoch 8, loss: 0.079, acc: 0.863, val_loss: 0.092, val_acc: 0.838\n",
407 | "Epoch 9, loss: 0.078, acc: 0.864, val_loss: 0.085, val_acc: 0.850\n",
408 | "Epoch 10, loss: 0.078, acc: 0.864, val_loss: 0.086, val_acc: 0.845\n"
409 | ],
410 | "name": "stdout"
411 | }
412 | ]
413 | },
414 | {
415 | "cell_type": "markdown",
416 | "metadata": {
417 | "id": "H0DHWweqcqIJ",
418 | "colab_type": "text"
419 | },
420 | "source": [
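421 | "After 10 epochs the student reaches 84.5% validation accuracy, compared with 90.2% for the teacher. The gap can likely be narrowed with longer training and more careful hyperparameter tuning (e.g., of the temperature and the learning rate)."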
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "metadata": {
427 | "id": "fmLLgmpybLYi",
428 | "colab_type": "code",
429 | "colab": {}
430 | },
431 | "source": [
432 | "# Serialize\n",
433 | "student_model.save_weights(\"student_model.h5\")"
434 | ],
435 | "execution_count": 21,
436 | "outputs": []
437 | },
438 | {
439 | "cell_type": "code",
440 | "metadata": {
441 | "id": "rJeZK9enJct9",
442 | "colab_type": "code",
443 | "colab": {
444 | "base_uri": "https://localhost:8080/",
445 | "height": 51
446 | },
447 | "outputId": "765a834b-ac7d-4818-ed86-b2953c678d17"
448 | },
449 | "source": [
450 | "# Compare the on-disk sizes of the saved weight files\n",
451 | "!ls -lh *.h5"
452 | ],
453 | "execution_count": 22,
454 | "outputs": [
455 | {
456 | "output_type": "stream",
457 | "text": [
458 | "-rw-r--r-- 1 root root 163K Aug 31 07:47 student_model.h5\n",
459 | "-rw-r--r-- 1 root root 335K Aug 31 07:44 teacher_model.h5\n"
460 | ],
461 | "name": "stdout"
462 | }
463 | ]
464 | },
465 | {
466 | "cell_type": "markdown",
467 | "metadata": {
468 | "id": "SNfgaGNncnSt",
469 | "colab_type": "text"
470 | },
471 | "source": [
472 | "Let's check the total number of trainable params."
473 | ]
474 | },
475 | {
476 | "cell_type": "code",
477 | "metadata": {
478 | "id": "cSHNvya8cYLP",
479 | "colab_type": "code",
480 | "colab": {
481 | "base_uri": "https://localhost:8080/",
482 | "height": 425
483 | },
484 | "outputId": "43bb53d7-5e8f-4d9a-bfdb-77e278d9ff20"
485 | },
486 | "source": [
487 | "teacher_model.summary()"
488 | ],
489 | "execution_count": 23,
490 | "outputs": [
491 | {
492 | "output_type": "stream",
493 | "text": [
494 | "Model: \"sequential\"\n",
495 | "_________________________________________________________________\n",
496 | "Layer (type) Output Shape Param # \n",
497 | "=================================================================\n",
498 | "conv2d (Conv2D) (None, 24, 24, 16) 416 \n",
499 | "_________________________________________________________________\n",
500 | "max_pooling2d (MaxPooling2D) (None, 12, 12, 16) 0 \n",
501 | "_________________________________________________________________\n",
502 | "conv2d_1 (Conv2D) (None, 8, 8, 32) 12832 \n",
503 | "_________________________________________________________________\n",
504 | "max_pooling2d_1 (MaxPooling2 (None, 4, 4, 32) 0 \n",
505 | "_________________________________________________________________\n",
506 | "dropout (Dropout) (None, 4, 4, 32) 0 \n",
507 | "_________________________________________________________________\n",
508 | "flatten (Flatten) (None, 512) 0 \n",
509 | "_________________________________________________________________\n",
510 | "dense (Dense) (None, 128) 65664 \n",
511 | "_________________________________________________________________\n",
512 | "dense_1 (Dense) (None, 10) 1290 \n",
513 | "=================================================================\n",
514 | "Total params: 80,202\n",
515 | "Trainable params: 80,202\n",
516 | "Non-trainable params: 0\n",
517 | "_________________________________________________________________\n"
518 | ],
519 | "name": "stdout"
520 | }
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "metadata": {
526 | "id": "T0-Y1gpDccZ_",
527 | "colab_type": "code",
528 | "colab": {
529 | "base_uri": "https://localhost:8080/",
530 | "height": 255
531 | },
532 | "outputId": "d7a27cdb-13b4-4736-a1a0-a8a336357b30"
533 | },
534 | "source": [
535 | "student_model.summary()"
536 | ],
537 | "execution_count": 24,
538 | "outputs": [
539 | {
540 | "output_type": "stream",
541 | "text": [
542 | "Model: \"sequential_2\"\n",
543 | "_________________________________________________________________\n",
544 | "Layer (type) Output Shape Param # \n",
545 | "=================================================================\n",
546 | "flatten_2 (Flatten) (None, 784) 0 \n",
547 | "_________________________________________________________________\n",
548 | "dense_4 (Dense) (None, 48) 37680 \n",
549 | "_________________________________________________________________\n",
550 | "dense_5 (Dense) (None, 10) 490 \n",
551 | "=================================================================\n",
552 | "Total params: 38,170\n",
553 | "Trainable params: 38,170\n",
554 | "Non-trainable params: 0\n",
555 | "_________________________________________________________________\n"
556 | ],
557 | "name": "stdout"
558 | }
559 | ]
560 | },
561 | {
562 | "cell_type": "markdown",
563 | "metadata": {
564 | "id": "AzC3KhO_J42N",
565 | "colab_type": "text"
566 | },
567 | "source": [
568 | "Further size reduction is possible by converting the models to TFLite with full-integer (int8) post-training quantization."
569 | ]
570 | },
571 | {
572 | "cell_type": "code",
573 | "metadata": {
574 | "id": "Z8d0R_ypVp8y",
575 | "colab_type": "code",
576 | "colab": {}
577 | },
578 | "source": [
579 | "# Credits: https://www.tensorflow.org/lite/performance/post_training_quant\n",
580 | "\n",
581 | "def representative_data_gen():\n",
582 | "    for input_value in tf.data.Dataset.from_tensor_slices(X_train).batch(1).take(100):\n",
583 | "        yield [input_value]\n",
584 | "\n",
585 | "def convert_to_tflite(model, tflite_file):\n",
586 | "    converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
587 | "    converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
588 | "    converter.representative_dataset = representative_data_gen\n",
589 | "    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n",
590 | "    converter.inference_input_type = tf.int8\n",
591 | "    converter.inference_output_type = tf.int8\n",
592 | "    tflite_quant_model = converter.convert()\n",
593 | "\n",
594 | "    open(tflite_file, 'wb').write(tflite_quant_model)"
595 | ],
596 | "execution_count": 25,
597 | "outputs": []
598 | },
599 | {
600 | "cell_type": "code",
601 | "metadata": {
602 | "id": "bZgxSge7Y3hU",
603 | "colab_type": "code",
604 | "colab": {
605 | "base_uri": "https://localhost:8080/",
606 | "height": 190
607 | },
608 | "outputId": "71470926-d629-4bdb-f207-a6ed22a6f599"
609 | },
610 | "source": [
611 | "convert_to_tflite(teacher_model, \"teacher.tflite\")\n",
612 | "convert_to_tflite(student_model, \"student.tflite\")"
613 | ],
614 | "execution_count": 26,
615 | "outputs": [
616 | {
617 | "output_type": "stream",
618 | "text": [
619 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
620 | "Instructions for updating:\n",
621 | "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
622 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
623 | "Instructions for updating:\n",
624 | "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
625 | "INFO:tensorflow:Assets written to: /tmp/tmp5020kbxi/assets\n",
626 | "INFO:tensorflow:Assets written to: /tmp/tmp2t19bpk6/assets\n"
627 | ],
628 | "name": "stdout"
629 | },
630 | {
631 | "output_type": "stream",
632 | "text": [
633 | "INFO:tensorflow:Assets written to: /tmp/tmp2t19bpk6/assets\n"
634 | ],
635 | "name": "stderr"
636 | }
637 | ]
638 | },
639 | {
640 | "cell_type": "code",
641 | "metadata": {
642 | "id": "-f8eGqRtZA-w",
643 | "colab_type": "code",
644 | "colab": {
645 | "base_uri": "https://localhost:8080/",
646 | "height": 51
647 | },
648 | "outputId": "2c34b9f9-b22a-4e2f-9385-893c778ed2ea"
649 | },
650 | "source": [
651 | "!ls -lh *.tflite"
652 | ],
653 | "execution_count": 27,
654 | "outputs": [
655 | {
656 | "output_type": "stream",
657 | "text": [
658 | "-rw-r--r-- 1 root root 40K Aug 31 07:48 student.tflite\n",
659 | "-rw-r--r-- 1 root root 85K Aug 31 07:48 teacher.tflite\n"
660 | ],
661 | "name": "stdout"
662 | }
663 | ]
664 | }
665 | ]
666 | }
667 |
--------------------------------------------------------------------------------