├── Digit Recognition ├── Dataset │ ├── test.csv.zip │ └── train.csv.zip ├── digit reco.ipynb └── digit.py ├── FIFA 2019 ├── data.csv ├── final with deep.ipynb ├── final2.ipynb └── mywork.ipynb ├── LICENSE └── README.md /Digit Recognition/Dataset/test.csv.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shsarv/Deep-Learning-Projects/c38d54bb470952a9da2b8f8da8fada89d79d66ee/Digit Recognition/Dataset/test.csv.zip -------------------------------------------------------------------------------- /Digit Recognition/Dataset/train.csv.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shsarv/Deep-Learning-Projects/c38d54bb470952a9da2b8f8da8fada89d79d66ee/Digit Recognition/Dataset/train.csv.zip -------------------------------------------------------------------------------- /Digit Recognition/digit reco.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "

Digit Recognizer

" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Goal is to correctly identify digits from a dataset of tens of thousands of handwritten images. " 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Configuration" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": { 28 | "id": "4JIP7spmj-Lm" 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | " ! pip install -q kaggle" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": { 39 | "id": "md-msCPmkS-B" 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "from google.colab import files " 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": { 50 | "colab": { 51 | "base_uri": "https://localhost:8080/", 52 | "height": 92, 53 | "resources": { 54 | "http://localhost:8080/nbextensions/google.colab/files.js": { 55 | "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3Bhb
icpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nf
SBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlL
nRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK", 56 | "headers": [ 57 | [ 58 | "content-type", 59 | "application/javascript" 60 | ] 61 | ], 62 | "ok": 
true, 63 | "status": 200, 64 | "status_text": "OK" 65 | } 66 | } 67 | }, 68 | "id": "B_r-0yX-kZBh", 69 | "outputId": "ca91e665-60fd-4621-f3d6-031f00f363d8" 70 | }, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/html": [ 75 | "\n", 76 | " \n", 78 | " \n", 79 | " Upload widget is only available when the cell has been executed in the\n", 80 | " current browser session. Please rerun this cell to enable.\n", 81 | " \n", 82 | " " 83 | ], 84 | "text/plain": [ 85 | "" 86 | ] 87 | }, 88 | "metadata": { 89 | "tags": [] 90 | }, 91 | "output_type": "display_data" 92 | }, 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "Saving kaggle.json to kaggle.json\n" 98 | ] 99 | }, 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "{'kaggle.json': b'{}'}" 104 | ] 105 | }, 106 | "execution_count": 3, 107 | "metadata": { 108 | "tags": [] 109 | }, 110 | "output_type": "execute_result" 111 | } 112 | ], 113 | "source": [ 114 | "files.upload()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 4, 120 | "metadata": { 121 | "id": "XhwIblLik6HD" 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "! mkdir ~/.kaggle " 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 5, 131 | "metadata": { 132 | "id": "hN-kfrS6lDqx" 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "! cp kaggle.json ~/.kaggle/" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 6, 142 | "metadata": { 143 | "id": "-38XiQgOlIvV" 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "! 
chmod 600 ~/.kaggle/kaggle.json" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 7, 153 | "metadata": { 154 | "colab": { 155 | "base_uri": "https://localhost:8080/" 156 | }, 157 | "id": "ehUbKerMlPa9", 158 | "outputId": "6acb1624-4f25-484f-fe78-1d73a8b2e9ed" 159 | }, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "Warning: Looks like you're using an outdated API Version, please consider updating (server 1.5.9 / client 1.5.4)\n", 166 | "ref title size lastUpdated downloadCount \n", 167 | "------------------------------------------------------ ------------------------------------------- ----- ------------------- ------------- \n", 168 | "manchunhui/us-election-2020-tweets US Election 2020 Tweets 353MB 2020-11-09 18:51:59 1289 \n", 169 | "unanimad/us-election-2020 US Election 2020 417KB 2020-11-09 13:52:09 1340 \n", 170 | "headsortails/us-election-2020-presidential-debates US Election 2020 - Presidential Debates 199MB 2020-10-23 16:56:10 301 \n", 171 | "antgoldbloom/covid19-data-from-john-hopkins-university COVID-19 data from John Hopkins University 2MB 2020-11-09 06:07:13 82 \n", 172 | "etsc9287/2020-general-election-polls 2020 General Election Polls 109KB 2020-02-09 08:20:59 454 \n", 173 | "radustoicescu/2020-united-states-presidential-election 2020 United States presidential election 11MB 2019-07-04 15:00:45 629 \n", 174 | "shivamb/netflix-shows Netflix Movies and TV Shows 971KB 2020-01-20 07:33:56 56294 \n", 175 | "terenceshin/covid19s-impact-on-airport-traffic COVID-19's Impact on Airport Traffic 106KB 2020-10-19 12:40:17 3284 \n", 176 | "sootersaalu/amazon-top-50-bestselling-books-2009-2019 Amazon Top 50 Bestselling Books 2009 - 2019 15KB 2020-10-13 09:39:21 3192 \n", 177 | "nehaprabhavalkar/indian-food-101 Indian Food 101 7KB 2020-09-30 06:23:43 6654 \n", 178 | "karangadiya/fifa19 FIFA 19 complete player dataset 2MB 2018-12-21 03:52:59 103016 \n", 179 | 
"omarhanyy/500-greatest-songs-of-all-time 500 Greatest Songs of All Time 33KB 2020-10-26 13:36:09 1018 \n", 180 | "heeraldedhia/groceries-dataset Groceries dataset 257KB 2020-09-17 04:36:08 7060 \n", 181 | "andrewmvd/trip-advisor-hotel-reviews Trip Advisor Hotel Reviews 5MB 2020-09-30 08:31:20 4745 \n", 182 | "docstein/brics-world-bank-indicators BRICS World Bank Indicators 4MB 2020-10-22 12:18:40 837 \n", 183 | "google/tinyquickdraw QuickDraw Sketches 11GB 2018-04-18 19:38:04 2434 \n", 184 | "datasnaek/youtube-new Trending YouTube Video Statistics 201MB 2019-06-03 00:56:47 114155 \n", 185 | "uciml/mushroom-classification Mushroom Classification 34KB 2016-12-01 23:08:00 53724 \n", 186 | "anikannal/solar-power-generation-data Solar Power Generation Data 2MB 2020-08-18 15:52:03 9300 \n", 187 | "zynicide/wine-reviews Wine Reviews 51MB 2017-11-27 17:08:04 118271 \n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | " ! kaggle datasets list" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 9, 198 | "metadata": { 199 | "colab": { 200 | "base_uri": "https://localhost:8080/" 201 | }, 202 | "id": "rz3bR4w3lT2Z", 203 | "outputId": "4591be06-3532-4507-ab9d-47c32635d799" 204 | }, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "Warning: Looks like you're using an outdated API Version, please consider updating (server 1.5.9 / client 1.5.4)\n", 211 | "Downloading train.csv.zip to /content\n", 212 | " 87% 8.00M/9.16M [00:00<00:00, 81.4MB/s]\n", 213 | "100% 9.16M/9.16M [00:00<00:00, 84.2MB/s]\n", 214 | "Downloading test.csv.zip to /content\n", 215 | " 82% 5.00M/6.09M [00:00<00:00, 35.0MB/s]\n", 216 | "100% 6.09M/6.09M [00:00<00:00, 38.7MB/s]\n", 217 | "Downloading sample_submission.csv to /content\n", 218 | " 0% 0.00/235k [00:00\n", 370 | "\n", 383 | "\n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 
395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " 
\n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | 
" \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 
| " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | "
labelpixel0pixel1pixel2pixel3pixel4pixel5pixel6pixel7pixel8pixel9pixel10pixel11pixel12pixel13pixel14pixel15pixel16pixel17pixel18pixel19pixel20pixel21pixel22pixel23pixel24pixel25pixel26pixel27pixel28pixel29pixel30pixel31pixel32pixel33pixel34pixel35pixel36pixel37pixel38...pixel744pixel745pixel746pixel747pixel748pixel749pixel750pixel751pixel752pixel753pixel754pixel755pixel756pixel757pixel758pixel759pixel760pixel761pixel762pixel763pixel764pixel765pixel766pixel767pixel768pixel769pixel770pixel771pixel772pixel773pixel774pixel775pixel776pixel777pixel778pixel779pixel780pixel781pixel782pixel783
01000000000000000000000000000000000000000...0000000000000000000000000000000000000000
10000000000000000000000000000000000000000...0000000000000000000000000000000000000000
21000000000000000000000000000000000000000...0000000000000000000000000000000000000000
34000000000000000000000000000000000000000...0000000000000000000000000000000000000000
40000000000000000000000000000000000000000...0000000000000000000000000000000000000000
\n", 893 | "

5 rows × 785 columns

\n", 894 | "" 895 | ], 896 | "text/plain": [ 897 | " label pixel0 pixel1 pixel2 ... pixel780 pixel781 pixel782 pixel783\n", 898 | "0 1 0 0 0 ... 0 0 0 0\n", 899 | "1 0 0 0 0 ... 0 0 0 0\n", 900 | "2 1 0 0 0 ... 0 0 0 0\n", 901 | "3 4 0 0 0 ... 0 0 0 0\n", 902 | "4 0 0 0 0 ... 0 0 0 0\n", 903 | "\n", 904 | "[5 rows x 785 columns]" 905 | ] 906 | }, 907 | "execution_count": 20, 908 | "metadata": { 909 | "tags": [] 910 | }, 911 | "output_type": "execute_result" 912 | } 913 | ], 914 | "source": [ 915 | "train.head()" 916 | ] 917 | }, 918 | { 919 | "cell_type": "markdown", 920 | "metadata": { 921 | "id": "1D2COwUxnWDp" 922 | }, 923 | "source": [ 924 | "## 2. transform Dataset" 925 | ] 926 | }, 927 | { 928 | "cell_type": "code", 929 | "execution_count": 21, 930 | "metadata": { 931 | "id": "8gJ8BcU3naPR" 932 | }, 933 | "outputs": [], 934 | "source": [ 935 | "# Convert Dataframe into format ready for training\n", 936 | "def createImageData(raw: pd.DataFrame):\n", 937 | " y = raw['label'].values\n", 938 | " y.resize(y.shape[0],1)\n", 939 | " x = raw[[i for i in raw.columns if i != 'label']].values\n", 940 | " x = x.reshape([-1,1, 28, 28])\n", 941 | " y = y.astype(int).reshape(-1)\n", 942 | " x = x.astype(float)\n", 943 | " return x, y\n", 944 | "\n", 945 | "## Convert to One Hot Encoding\n", 946 | "def one_hot_embedding(labels, num_classes=10):\n", 947 | " y = torch.eye(num_classes) \n", 948 | " return y[labels] " 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "execution_count": 22, 954 | "metadata": { 955 | "colab": { 956 | "base_uri": "https://localhost:8080/" 957 | }, 958 | "id": "8XcrqCheng1w", 959 | "outputId": "f81d53d7-3868-4cae-deda-a8c0ad256ec8" 960 | }, 961 | "outputs": [ 962 | { 963 | "data": { 964 | "text/plain": [ 965 | "((42000, 1, 28, 28), (42000,))" 966 | ] 967 | }, 968 | "execution_count": 22, 969 | "metadata": { 970 | "tags": [] 971 | }, 972 | "output_type": "execute_result" 973 | } 974 | ], 975 | "source": [ 976 | "x_train, y_train = 
createImageData(train)\n", 977 | "#x_train, x_val, y_train, y_val = train_test_split(x,y, test_size=0.02)\n", 978 | "\n", 979 | "#x_train.shape, y_train.shape, x_val.shape, y_val.shape\n", 980 | "x_train.shape, y_train.shape" 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": 23, 986 | "metadata": { 987 | "id": "HUiAWVIKnkTv" 988 | }, 989 | "outputs": [], 990 | "source": [ 991 | "# Normalization\n", 992 | "mean = x_train.mean()\n", 993 | "std = x_train.std()\n", 994 | "x_train = (x_train-mean)/std\n", 995 | "#x_val = (x_val-mean)/std\n", 996 | "\n", 997 | "# Numpy to Torch Tensor\n", 998 | "x_train = torch.from_numpy(np.float32(x_train)).to(device)\n", 999 | "y_train = torch.from_numpy(y_train.astype(np.long)).to(device)\n", 1000 | "y_train = one_hot_embedding(y_train)\n", 1001 | "#x_val = torch.from_numpy(np.float32(x_val))\n", 1002 | "#y_val = torch.from_numpy(y_val.astype(np.long))" 1003 | ] 1004 | }, 1005 | { 1006 | "cell_type": "markdown", 1007 | "metadata": { 1008 | "id": "RglJneAZnqyg" 1009 | }, 1010 | "source": [ 1011 | "# 3. Loading Dataset" 1012 | ] 1013 | }, 1014 | { 1015 | "cell_type": "code", 1016 | "execution_count": 24, 1017 | "metadata": { 1018 | "id": "pkg5IlqknzYk" 1019 | }, 1020 | "outputs": [], 1021 | "source": [ 1022 | "# Convert into Torch Dataset\n", 1023 | "train_ds = TensorDataset(x_train, y_train)\n", 1024 | "#val_ds = TensorDataset(x_val,y_val)" 1025 | ] 1026 | }, 1027 | { 1028 | "cell_type": "code", 1029 | "execution_count": 25, 1030 | "metadata": { 1031 | "id": "arOGS2lDn_os" 1032 | }, 1033 | "outputs": [], 1034 | "source": [ 1035 | "# Make Data Loader\n", 1036 | "train_dl = DataLoader(train_ds, batch_size=64)" 1037 | ] 1038 | }, 1039 | { 1040 | "cell_type": "markdown", 1041 | "metadata": { 1042 | "id": "tzyFHYqtoHIT" 1043 | }, 1044 | "source": [ 1045 | "## 4. 
EDA" 1046 | ] 1047 | }, 1048 | { 1049 | "cell_type": "code", 1050 | "execution_count": 26, 1051 | "metadata": { 1052 | "colab": { 1053 | "base_uri": "https://localhost:8080/", 1054 | "height": 283 1055 | }, 1056 | "id": "9X9yvJvyoJkx", 1057 | "outputId": "83607741-6dc7-4562-8933-e4af4db0d8a8" 1058 | }, 1059 | "outputs": [ 1060 | { 1061 | "name": "stdout", 1062 | "output_type": "stream", 1063 | "text": [ 1064 | "tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n" 1065 | ] 1066 | }, 1067 | { 1068 | "data": { 1069 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOU0lEQVR4nO3df4wUdZrH8c+jLDECCuMPMmE5ZTcmBs/IXiZ4yU0uXpZFFBNcCQZCDJfbZEiEyCZngtn7YzWXTdZT9v4wEYRgdk45VxD3RCQCR8h5xgQdf5yiHuoZCIwj409GiMmKPPdHF5sBp789Vld19fC8X8mku+uZ7u9Dw4eqruqqr7m7AJz7zqu6AQCtQdiBIAg7EARhB4Ig7EAQ41o5mJmx6x8ombvbSMubWrOb2TwzO2BmH5jZPc28FoByWd7j7GZ2vqT3JP1M0hFJr0ha4u7vJJ7Dmh0oWRlr9tmSPnD3D939T5L+IGlBE68HoETNhH2apMPDHh/Jlp3BzHrMrM/M+poYC0CTSt9B5+7rJa2X2IwHqtTMmr1f0vRhj3+YLQPQhpoJ+yuSrjKzGWY2XtJiSduKaQtA0XJvxrv7STNbKWmnpPMlPerubxfWGYBC5T70lmswPrMDpSvlSzUAxg7CDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Jo6ZTNyKejoyNZnzhxYt3aihUrmhr7+uuvT9YffvjhZH1oaKhubefOncnntvLKxxGwZgeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIDjO3gKTJk1K1m+66aZk/fHHH0/Wx42r7q+xs7MzWZ8+fXrdWm9vb/K5999/f7J+8ODBZB1naupfiZkdlPSVpG8lnXT3riKaAlC8IlYJf+funxbwOgBKxGd2IIhmw+6SdpnZq2bWM9IvmFmPmfWZWV+TYwFoQrOb8d3u3m9ml0vabWb/6+4vDP8Fd18vab0kmRlnNgAVaWrN7u792e2gpD9Kml1EUwCKlzvsZjbBzCadvi9prqT9RTUGoFiW95xhM/uRamtzqfZx4N/d/TcNnnNObsZPnjw5WX/ssceS9fnz5xfZzjnj6NGjyfqCBQuS9QMHDtStHTt2LFdPY4G720jLc39md/cPJV2XuyMALcWhNyAIwg4EQdiBIAg7EARhB4LIfegt12Dn6KG3efPmJes7duxoUScY7s4776xbW7duXQs7aa16h95YswNBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEFxKepS6u7vr1lavXt3CToq1atWqZP2jjz5K1u++++5kvdGUz2V64IEH6tY+++yz5HO3bNlSdDuVY80OBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0FwPvsoPfXUU3Vrt912W6lj9/WlZ87
at29f7td+5JFHkvX9+9NTAUyYMCFZ7+joqFtrdCx79uzy5hzZunVrsr5o0aLSxi4b57MDwRF2IAjCDgRB2IEgCDsQBGEHgiDsQBCcz54xG/HQ5J+dd155/y8uXbo0WR8cHEzW9+zZU2Q738uJEydy159//vnkc7u6upL1Zv5Orr766mT9lltuSda3b9+ee+yqNHy3zOxRMxs0s/3DlnWY2W4zez+7nVJumwCaNZr/Gn8v6ewpT+6RtMfdr5K0J3sMoI01DLu7vyDp87MWL5DUm93vlXRrwX0BKFjez+xT3X0gu/+xpKn1ftHMeiT15BwHQEGa3kHn7p46wcXd10taL43tE2GAsS7v7syjZtYpSdltencxgMrlDfs2Scuy+8skPVNMOwDK0vB8djN7QtINki6VdFTSryX9h6TNkv5C0iFJt7v72TvxRnqttt2Mv+6665L1119/vbSxr7jiimT98OHDpY3dzhYuXJisl3lt9w0bNiTry5cvL23sZtU7n73hZ3Z3X1Kn9NOmOgLQUnxdFgiCsANBEHYgCMIOBEHYgSA4xTUzY8aM0l57aGgoWf/mm29KG3sse+mll5L1Ru/rRRddVGQ7Yx5rdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IguPsmS+//LK013755ZeT9S+++KK0sceygYGBZH3Hjh3J+uLFi3OPfeONNybrEydOTNaPHz+ee+yysGYHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAaXkq60MEqvJR0o3Ob33vvvWT98ssvL7KdM3Ap6Xzmz5+frD/77LOljX3JJZck61V+d6LepaRZswNBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEGHOZx83Lv1HLfM4OsrR399fdQtjSsM1u5k9amaDZrZ/2LJ7zazfzN7Ifm4ut00AzRrNZvzvJc0bYfm/uvus7Cd9yRAAlWsYdnd/QdLnLegFQIma2UG30szezDbzp9T7JTPrMbM+M+trYiwATcob9rWSfixplqQBSWvq/aK7r3f3LnfvyjkWgALkCru7H3X3b939lKQNkmYX2xaAouUKu5l1Dnv4c0n76/0ugPbQ8Di7mT0h6QZJl5rZEUm/lnSDmc2S5JIOSlpeYo+FaHRd+E2bNiXrS5cuLbIdoOUaht3dl4yweGMJvQAoEV+XBYIg7EAQhB0IgrADQRB2IIgwp7ieOnUqWd+9e3eyXuahty1btiTrc+bMSdbbcXrgIkyePDlZ7+3tLW3sdevWJetlTvFdFtbsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxBEmCmbG7n44ouT9b1799atzZo1q+h2ztDXl76i1+rVq+vWUn1X7bLLLkvWH3zwwWT9jjvuyD32119/nazPnDkzWT906FDuscvGlM1AcIQdCIKwA0EQdiAIwg4EQdiBIAg7EATH2Uepu7u7bm3t2rXJ515zzTVFt3OGF198sW7trrvuauq1h4aGkvXx48cn6xdccEHdWqPz0a+99tpkvRlbt25N1hctWlTa2GXjODsQHGEHgiDsQBCEHQiCsANBEHYgCMIOBMFx9gLcfvvtyfrGjelJbydMmFBkO4X65JNPkvULL7wwWW/XP9vixYuT9c2bN7eok+LlPs5uZtPNbK+ZvWNmb5vZqmx5h5ntNrP3s9spRTcNoDij2Yw/Kekf3X2mpL+WtMLMZkq6R9Ied79K0p7sMYA21TDs7j7g7q9l97+S9K6kaZIWSDr9fcdeSbeW1SSA5n2vud7M7EpJP5G0T9JUdx/ISh9LmlrnOT2SevK3CKAIo94bb2YTJW2V9Et3P+PsCK/t5Rtx55u7r3f3LnfvaqpTAE0ZVdjN7AeqBX2Tuz+dLT5qZp1ZvVPSYDktAihCw814MzNJGyW96+6/G1baJmmZpN9mt8+U0uEY0OgwzbRp05L1NWvWFNlOoRpd7rlKx44dS9aXL19et/bcc88V3U7bG81n9r+RdIekt8zsjWzZr1QL+WYz+4W
kQ5LSB5sBVKph2N39RUkjHqSX9NNi2wFQFr4uCwRB2IEgCDsQBGEHgiDsQBCc4toCkyZNStaffPLJZH3evHlFtjNmnDhxIllfuHBhsr5r164i2xkzuJQ0EBxhB4Ig7EAQhB0IgrADQRB2IAjCDgTBcfY2kJrWWJLmzJmTrM+dO7dubeXKlcnn1i5XUF+jfx+Nnv/QQw/Vrd13333J5548eTJZb3Q+e1QcZweCI+xAEIQdCIKwA0EQdiAIwg4EQdiBIDjODpxjOM4OBEfYgSAIOxAEYQeCIOxAEIQdCIKwA0E0DLuZTTezvWb2jpm9bWarsuX3mlm/mb2R/dxcfrsA8mr4pRoz65TU6e6vmdkkSa9KulW1+diPu/uDox6ML9UApav3pZrRzM8+IGkgu/+Vmb0raVqx7QEo2/f6zG5mV0r6iaR92aKVZvammT1qZlPqPKfHzPrMrK+pTgE0ZdTfjTeziZL+S9Jv3P1pM5sq6VNJLumfVdvU/4cGr8FmPFCyepvxowq7mf1A0nZJO939dyPUr5S03d3/ssHrEHagZLlPhLHa5UM3Snp3eNCzHXen/VzS/mabBFCe0eyN75b035LeknQqW/wrSUskzVJtM/6gpOXZzrzUa7FmB0rW1GZ8UQg7UD7OZweCI+xAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgTR8IKTBftU0qFhjy/NlrWjdu2tXfuS6C2vInu7ol6hpeezf2dwsz5376qsgYR27a1d+5LoLa9W9cZmPBAEYQeCqDrs6yseP6Vde2vXviR6y6slvVX6mR1A61S9ZgfQIoQdCKKSsJvZPDM7YGYfmNk9VfRQj5kdNLO3smmoK52fLptDb9DM9g9b1mFmu83s/ex2xDn2KuqtLabxTkwzXul7V/X05y3/zG5m50t6T9LPJB2R9IqkJe7+TksbqcPMDkrqcvfKv4BhZn8r6bikfzs9tZaZ/Yukz939t9l/lFPcfXWb9Havvuc03iX1Vm+a8b9Xhe9dkdOf51HFmn22pA/c/UN3/5OkP0haUEEfbc/dX5D0+VmLF0jqze73qvaPpeXq9NYW3H3A3V/L7n8l6fQ045W+d4m+WqKKsE+TdHjY4yNqr/neXdIuM3vVzHqqbmYEU4dNs/WxpKlVNjOChtN4t9JZ04y3zXuXZ/rzZrGD7ru63f2vJN0kaUW2udqWvPYZrJ2Ona6V9GPV5gAckLSmymayaca3Svqluw8Nr1X53o3QV0vetyrC3i9p+rDHP8yWtQV3789uByX9UbWPHe3k6OkZdLPbwYr7+TN3P+ru37r7KUkbVOF7l00zvlXSJnd/Oltc+Xs3Ul+tet+qCPsrkq4ysxlmNl7SYknbKujjO8xsQrbjRGY2QdJctd9U1NskLcvuL5P0TIW9nKFdpvGuN824Kn7vKp/+3N1b/iPpZtX2yP+fpH+qooc6ff1I0v9kP29X3ZukJ1TbrPtGtX0bv5B0iaQ9kt6X9J+SOtqot8dUm9r7TdWC1VlRb92qbaK/KemN7Ofmqt+7RF8ted/4uiwQBDvogCAIOxAEYQeCIOxAEIQdCIKwA0EQdiCI/wcKi4arh8ukpwAAAABJRU5ErkJggg==\n", 1070 | "text/plain": [ 1071 | "
" 1072 | ] 1073 | }, 1074 | "metadata": { 1075 | "needs_background": "light", 1076 | "tags": [] 1077 | }, 1078 | "output_type": "display_data" 1079 | } 1080 | ], 1081 | "source": [ 1082 | "index = 1\n", 1083 | "pyplot.imshow(x_train.cpu()[index].reshape((28, 28)), cmap=\"gray\")\n", 1084 | "print(y_train[index])" 1085 | ] 1086 | }, 1087 | { 1088 | "cell_type": "markdown", 1089 | "metadata": { 1090 | "id": "MbkREYI5oS5i" 1091 | }, 1092 | "source": [ 1093 | "## 5. Model" 1094 | ] 1095 | }, 1096 | { 1097 | "cell_type": "code", 1098 | "execution_count": 27, 1099 | "metadata": { 1100 | "id": "nFtHBNDkoMzV" 1101 | }, 1102 | "outputs": [], 1103 | "source": [ 1104 | "# Helper Functions\n", 1105 | "\n", 1106 | "## Initialize weight with xavier_uniform\n", 1107 | "def init_weights(m):\n", 1108 | " if type(m) == nn.Linear:\n", 1109 | " torch.nn.init.xavier_uniform(m.weight)\n", 1110 | " m.bias.data.fill_(0.01)\n", 1111 | "\n", 1112 | "## Flatten Later\n", 1113 | "class Flatten(nn.Module):\n", 1114 | " def forward(self, input):\n", 1115 | " return input.view(input.size(0), -1)\n", 1116 | "\n", 1117 | "# Train the network and print accuracy and loss overtime\n", 1118 | "def fit(train_dl, model, loss, optim, epochs=10):\n", 1119 | " model = model.to(device)\n", 1120 | " print('Epoch\\tAccuracy\\tLoss')\n", 1121 | " accuracy_overtime = []\n", 1122 | " loss_overtime = []\n", 1123 | " for epoch in range(epochs):\n", 1124 | " avg_loss = 0\n", 1125 | " correct = 0\n", 1126 | " total=0\n", 1127 | " for x, y in train_dl: # Iterate over Data Loder\n", 1128 | " \n", 1129 | " # Forward pass\n", 1130 | " yhat = model(x) \n", 1131 | " l = loss(y, yhat)\n", 1132 | " \n", 1133 | " #Metrics\n", 1134 | " avg_loss+=l.item()\n", 1135 | " \n", 1136 | " # Backward pass\n", 1137 | " optim.zero_grad()\n", 1138 | " l.backward()\n", 1139 | " optim.step()\n", 1140 | " \n", 1141 | " # Metrics\n", 1142 | " _, original = torch.max(y, 1)\n", 1143 | " _, predicted = torch.max(yhat.data, 1)\n", 1144 | " total 
+= y.size(0)\n", 1145 | " correct = correct + (original == predicted).sum().item()\n", 1146 | " \n", 1147 | " accuracy_overtime.append(correct/total)\n", 1148 | " loss_overtime.append(avg_loss/len(train_dl))\n", 1149 | " print(epoch,accuracy_overtime[-1], loss_overtime[-1], sep='\\t')\n", 1150 | " return accuracy_overtime, loss_overtime\n", 1151 | "\n", 1152 | "# Plot Accuracy and Loss of Model\n", 1153 | "def plot_accuracy_loss(accuracy, loss):\n", 1154 | " f = pyplot.figure(figsize=(15,5))\n", 1155 | " ax1 = f.add_subplot(121)\n", 1156 | " ax2 = f.add_subplot(122)\n", 1157 | " ax1.title.set_text(\"Accuracy over epochs\")\n", 1158 | " ax2.title.set_text(\"Loss over epochs\")\n", 1159 | " ax1.plot(accuracy)\n", 1160 | " ax2.plot(loss, 'r:')\n", 1161 | "\n", 1162 | "# Take an array and show what model predicts \n", 1163 | "def predict_for_index(array, model, index):\n", 1164 | " testing = array[index].view(1,28,28)\n", 1165 | " pyplot.imshow(x_train[index].reshape((28, 28)), cmap=\"gray\")\n", 1166 | " print(x_train[index].shape)\n", 1167 | " a = model(testing.float())\n", 1168 | " print('Prediction',torch.argmax(a,1))" 1169 | ] 1170 | }, 1171 | { 1172 | "cell_type": "code", 1173 | "execution_count": 28, 1174 | "metadata": { 1175 | "id": "WY6sE59AobHx" 1176 | }, 1177 | "outputs": [], 1178 | "source": [ 1179 | "# Define the model\n", 1180 | "\n", 1181 | "ff_model = nn.Sequential(\n", 1182 | " Flatten(),\n", 1183 | " nn.Linear(28*28, 100),\n", 1184 | " nn.ReLU(),\n", 1185 | " nn.Linear(100, 10),\n", 1186 | " nn.Softmax(1),\n", 1187 | ").to(device)" 1188 | ] 1189 | }, 1190 | { 1191 | "cell_type": "code", 1192 | "execution_count": 29, 1193 | "metadata": { 1194 | "colab": { 1195 | "base_uri": "https://localhost:8080/" 1196 | }, 1197 | "id": "Ppi6Ai9Yofzq", 1198 | "outputId": "a65069ae-36d9-4865-ee71-873cf3133fae" 1199 | }, 1200 | "outputs": [ 1201 | { 1202 | "name": "stderr", 1203 | "output_type": "stream", 1204 | "text": [ 1205 | 
"/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:6: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n", 1206 | " \n" 1207 | ] 1208 | }, 1209 | { 1210 | "data": { 1211 | "text/plain": [ 1212 | "Sequential(\n", 1213 | " (0): Flatten()\n", 1214 | " (1): Linear(in_features=784, out_features=100, bias=True)\n", 1215 | " (2): ReLU()\n", 1216 | " (3): Linear(in_features=100, out_features=10, bias=True)\n", 1217 | " (4): Softmax(dim=1)\n", 1218 | ")" 1219 | ] 1220 | }, 1221 | "execution_count": 29, 1222 | "metadata": { 1223 | "tags": [] 1224 | }, 1225 | "output_type": "execute_result" 1226 | } 1227 | ], 1228 | "source": [ 1229 | "# Initialize model with xavier initialization which is recommended for ReLu\n", 1230 | "ff_model.apply(init_weights)" 1231 | ] 1232 | }, 1233 | { 1234 | "cell_type": "code", 1235 | "execution_count": 30, 1236 | "metadata": { 1237 | "colab": { 1238 | "base_uri": "https://localhost:8080/", 1239 | "height": 572 1240 | }, 1241 | "id": "RWtKHEgboiy0", 1242 | "outputId": "cab946e0-8d39-4640-d683-653606cbc30d" 1243 | }, 1244 | "outputs": [ 1245 | { 1246 | "name": "stdout", 1247 | "output_type": "stream", 1248 | "text": [ 1249 | "Epoch\tAccuracy\tLoss\n", 1250 | "0\t0.9070476190476191\t0.013667682191970811\n", 1251 | "1\t0.9577142857142857\t0.006563970649066431\n", 1252 | "2\t0.9688333333333333\t0.004935050102914159\n", 1253 | "3\t0.9768095238095238\t0.0038218928349876007\n", 1254 | "4\t0.9801428571428571\t0.003255577986532627\n", 1255 | "5\t0.9834285714285714\t0.002722567654970757\n", 1256 | "6\t0.9863571428571428\t0.002326337982528403\n", 1257 | "7\t0.9872380952380952\t0.0021555991658734047\n", 1258 | "8\t0.9878809523809524\t0.0019994406545720267\n", 1259 | "9\t0.9885952380952381\t0.0018727844953365877\n", 1260 | "10\t0.989547619047619\t0.0017084493914981042\n", 1261 | "11\t0.9913095238095239\t0.0014923217874675437\n" 1262 | ] 1263 | }, 1264 | { 1265 | "data": { 1266 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAA3AAAAE/CAYAAAAHeyFHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeXwV9b3/8dcnK4SEAEmIAmFRFokIWCJupViXurXSumNda9W2au211uqttVZrrb12ub3a/or7jl5svWi17taliqASlU0iogkIhCUhCYRsn98fM4FjzAYkmSTn/Xw8zoOZ73xnzmdONN98zncZc3dERERERESk+0uIOgARERERERFpHyVwIiIiIiIiPYQSOBERERERkR5CCZyIiIiIiEgPoQRORERERESkh1ACJyIiIiIi0kMogRORTmNmK83syKjjEBER6WhmdpiZlUQdh8QfJXDSrZjZy2a2ycxSo45FRESkp9IXaCK9lxI46TbMbCQwDXDghC5+76SufL/O1tvuR0REpJHaOIl3SuCkOzkbeBO4Bzgn9oCZ5ZnZ38ys1Mw2mNmtMccuMLMlZlZhZovN7EthuZvZ6Jh695jZr8Ltw8ysxMx+amZrgLvNbKCZPRm+x6Zwe1jM+YPM7G4zWx0efzws/8DMvhFTL9nM1pvZ/s3dZBhvkZltNLO5ZjYkLP+Lmd3SpO7/mdnl4fYQM3ssjO9jM/thTL3rzGyOmT1gZpuBc5t531Qzu8XMPjWztWb2/8ysb5PP4z/D2Fea2bdjzs00s/vC9/7EzK4xs4SY483+DEKTzew9Mys3s0fMrE94Tnb4GZeFn8WrsdcUEZGOF7YFfwzbstXhdmp4rMXfy2F7uSr8Pb/MzI5o4frNthfh+5aZ2YSYujlmttXMBof7XzezhWG9f5vZxJi6K8MY3gOqmkvizGwfM3sujH2ZmZ0ac+yesN17LryHf5nZiJjjh5jZ/LCtmm9mh8Qca7b9jzn+YzNbZ2afmdl5MeXHhW1iRfjZXbFTPyyRFuiPJelOzgYeDF9Hm1kugJklAk8CnwAjgaHA7PDYKcB14bn9CXruNrTz/fYABgEjgAsJ/n+4O9wfDmwFbo2pfz+QBuwLDAb+EJbfB5wZU+844DN3f7fpG5rZ4cBNwKnAnuE9zQ4PPwycZmYW1h0IfA2YHTagTwCF4f0fAfzIzI6OufwMYA4wgOAzbOo3wFhgMjA6vM61TT6P7LD8HGCWmY0Lj/0PkAnsBUwn+LzPC+Ns62dwKnAMMAqYyI7k8sdACZAD5AL/SdD7KiIinednwEEEbcEkYCpwTXis2d/LYVtwCXCAu2cARwMrW7h+s+2Fu28D/gbMjKl7KvAvd19nwZeedwEXAVnAX4G59vkpFTOB44EB7l4X+6Zm1g94DniIoI0+HfizmeXHVPs2cANBW7eQsK00s0HAP4A/he/9e+AfZpYVntdS+w9B25lJ0HaeD9wWtt8AdwIXhZ/ZBODFFj4zkZ3j7nrpFfkL+DJQC2SH+0uB/wi3DwZKgaRmznsGuKyFazowOmb/HuBX4fZhQA3Qp5WYJgObwu09gQZgYDP1hgAVQP9wfw5wZQvXvBP4bcx+enjfIwEDPgW+Eh67AHgx3D4Q+LTJta4G7g63rwNeaeVeDKgC9o4pOxj4OObzqAP6xRx/FPg5kBh+Vvkxxy4CXm7Hz2AlcGbM/m+B/xduXw/8X+zPSC+99NJLr455hb9/j2ym/CPguJj9o4GV4Xazv5cJvvRbBxwJJLfynm21F0cCH8Ucex04O9z+C3BDk+stA6bH3M93Wnnv04BXm5T9FfhFuH0PMDvmWDpQD+QBZwFvNTn3DYIvHFtr/w8j+LI3KaZsHXBQuP1peP/9o/7vQa/e9VIPnHQX5wDPuvv6cP8hdgyjzAM+8SbftsUc+2gX37PU3asbd8wszcz+Gg752Ay8AgwIewDzgI3uvqnpRdx9NUEjdJKZDQC
OpfkeMAiSvU9izq0k6K0a6u5O0BvX+O3kGTHXGQEMCYeVlJlZGcE3o7kx1y5u5V5zCL49fDvm/H+G5Y02uXtVzP4nYbzZQHJs3OH20HC7rZ/BmpjtLQSNJsB/AUXAs2a2wsyuauUaIiLSMT7XDrHjdz208HvZ3YuAHxF8WbjOzGZbOPy/ibbai5eANDM70IJ575OBv4fHRgA/btLO5cXEBq23cyOAA5uc/22CHrIvnB+2vxvD6zf9TGLjbrH9D21o8vdJbDt3EsGonE/CIZsHtxK/SLspgZPIWTAP61RgupmtsWBO2n8Ak8xsEsEv3OHNjXcPj+3dwqW3ECQtjfZocrzpcL0fA+OAA929P/CVxhDD9xkUJmjNuZdgGOUpwBvuvqqFeqsJGpngwsGQjyygsf7DwMnhuPwDgcfC8mKC3rIBMa8Mdz+ulfuJtZ7gW8J9Y87PdPf0mDoDw3gaDQ/jXU/QSziiybHGmFv7GbTI3Svc/cfuvhfBsMvLW5pTISIiHeZz7RA7fte3+nvZ3R9y9y+H5zpwczPXbrW9cPd6gtEdM8PXk+5eEdYrBm5s0s6lufvDMddqrZ0rJhiOGXt+urt/P6ZOXuOGmaUTTKNY3cxnEht3W+1/i9x9vrvPIBh2+Xh47yK7TQmcdAffJBjGkE/wbdxkYDzwKsHY+beAz4DfmFk/M+tjZoeG594BXGFmUywwOmZS8kLgDDNLNLNjCMbityaDIMkpC8fD/6LxgLt/BjxNMJ5+oAULlXwl5tzHgS8BlxHMiWvJw8B5ZjY5HNf/a2Ceu68M3+ddggbwDuAZdy8Lz3sLqAgncPcN72mCmR3Qxj01xt8A3A78IWay+NAmc+gAfmlmKWY2Dfg68L8xDe6NZpYRfr6XAw+E57T2M2iRBZPVR4dz/soJ/htoaM/9iIhIuySHbWbjK4mgHbrGggVEsgnmQj8ALf9eNrNxZnZ42G5VE7SVX/h93Y72AoIRNqcR9I49FFN+O/C9sHfOwvb+eDPLaOe9PgmMNbOzwjY62cwOMLPxMXWOM7Mvm1kKwVy4N929GHgqPPcMM0sys9MI/iZ5sh3tf7PCtvTbZpbp7rXA5uY+M5FdoQROuoNzCOZyferuaxpfBAuIfJugB+wbBGPwPyWYYH0agLv/L3AjQSNQQZBIDQqve1l4XuMwis+tGtWMPwJ9CRKoNwmGGMY6i+CbxaUEY9x/1HjA3bcS9JaNIpik3Sx3f55gXtljBEnp3gQTrWM9RDBP4KGY8+oJEqrJwMfsSPIy27inWD8lGBrzZjhE9HmCHsdGa4BNBN9EPgh8z92XhscuJZhDtwJ4LYztrjC21n4GrRkTxlBJMNfgz+7+0k7cj4iItO4pgmSr8XUd8CtgAfAe8D7wTlgGLf9eTiVYCGs9QVsxmGAednNabC8A3H1eeHwIQWLUWL6AYO73rQRtURHNrKjckrAn72sEberqMM6bw9gbPUTw5exGYArhAmTuvoGgjf0xwbSGK4Gvx0zraLH9b8NZwMqwzf0ewd8iIrvNgmk3IrK7zOxaYKy7n9lm5W7GzA4DHnD3YW3VFRER6WnM7B6gxN2vaauuSHenByGKdIBwyOX5BN+2iYiIiIh0Cg2hFNlNZnYBwSTnp939lajjEREREZHeS0MoRUREREREegj1wImIiIiIiPQQSuBERERERER6iG63iEl2draPHDky6jBERKQLvP322+vdPSfqOHoKtZEiIvGhtfax2yVwI0eOZMGCBVGHISIiXcDMPok6hp5EbaSISHxorX3UEEoREZE2mNkxZrbMzIrM7Kpmjqea2SPh8XlmNjIszzKzl8ys0sxubeHac83sg869AxER6S2UwImIiLTCzBKB24BjgXxgppnlN6l2PrDJ3UcDfwBuDsurgZ8DV7Rw7ROBys6IW0REeiclcCIiIq2bChS5+wp3rwFmAzOa1JkB3BtuzwGOMDNz9yp3f40gkfscM0sHLgd+1Xmhi4hIb6METkR
EpHVDgeKY/ZKwrNk67l4HlANZbVz3BuB3wJaOCVNEROKBEjgREZEuZmaTgb3d/e/tqHuhmS0wswWlpaVdEJ2IiHRnSuBERERatwrIi9kfFpY1W8fMkoBMYEMr1zwYKDCzlcBrwFgze7m5iu4+y90L3L0gJ0dPXBARiXdK4ERERFo3HxhjZqPMLAU4HZjbpM5c4Jxw+2TgRXf3li7o7n9x9yHuPhL4MvChux/W4ZGLiEiv0+2eAyciItKduHudmV0CPAMkAne5+yIzux5Y4O5zgTuB+82sCNhIkOQBEPay9QdSzOybwNfcfXFX34eIiPQOSuBERETa4O5PAU81Kbs2ZrsaOKWFc0e2ce2VwITdDlJEROKCEjgREWm38q21FK2rYPnaSg7eO4sRWf2iDklERCSuKIETEZEv2FRVw/J1lSwPk7XGf9dVbNte56YT91MC19PMmwe//jXMmgW5uVFHIyIiu0AJnIhInHJ31lfWsHxdBUXrKrcnakXrKllfWbO9XlpKImMGpzNtTA5jctMZm5vOmMEZDB3QN8LoZZfU18OyZbB6tRI4EZEeSgmciEgv5+6sq9jG8rWVfLi2guXrKoNhkOsqKdtSu71eRp8kxgxO54h9chmTm87owemMyc1gSGYfzCzCO5AOc8ghsHRp1FGIiMhuUAInItJLuDury6tZvvbzPWrL11VSUV23vd6AtGTGDs7guP32ZMzgoDdtTG46gzNSlaiJiIh0c0rgRER6kPoGZ9OWGtZXbmPVpq3BPLW1QY9a0bpKqmrqt9fNTk9h9OB0vjl56I4etcEZZKenKFGLZ/ffD3/8I7z1FiQmRh2NiIjsJCVwIiIRq6lrYEPVNtZXBIlZ8Aq2N8Rsr6/cxsaqGhqaPB46t38qYwZncEpBHmPC+WmjB6czqF9KNDck3Vt6OgwbBmVlkJUVdTQiIrKTlMCJiHSCLTV1rK+oobSZJGxDZVC+vnIb6yu2sTlmeGOstJREstJTyE5PJW9QGvsPH0h2uJ+dnsoemX0YPTidzL7JXXx30qN961vBS0REeiQlcCIibdhWV0/51lo2b62lPPa1pZYNVTWf6zELkrIattbWN3ut/n2SyM4IErDxe/Qna/SOhCw7PYWs9FRy0lPJzkghLUW/oqUTuYOG0oqI9Dj660BE4kJrSVj51jrKt9ZStrXmi8e31lJd29DidRMMBvXbkYQNH572uYRs+3ZGCln9UklJSujCuxZpwRVXwIsvwjvvRB2JiIjsJCVwItIjNDQ4lTV1VFTXsXlr7fZ/N1d/MeHa2SQMID01icy+yfTvm0xm3yRGZfcjs2/y5179m+xn9k1mQFoKiQnqxZAeZtKkYAET9cKJiPQ4SuBEpNO5O9W1DVRUBwnX5uomiVh1LRXVtZ8r21Ee/Fu5rQ731t9nZ5KwAWkpO8r6JJGUqJ4xiSNnnRV1BCIisouUwIlIu9XUNbCxqoYNVcFqiBurathQWUP59qSr9nNJV0VMolbXdOnEJhITjIw+SWT0SaJ/n2Qy+iQxfFAaGeF2/zDR2nF8R7mSMJFd4A4VFdC/f9SRiIjITlACJxLHttbUb0/GNlTVsLGyZsd2bHl4rGJb86slQtD7FZt85aSnsndOepiUJW8v7983/LdJIpaWkqhnk4l0pWnTYMAAePLJqCMREZGdoAROpJdwdyq31bWYjG1PxMJes41VLa+UmJxoDOqXwqB+qWT1SyFvYBqD+qWQ1S+FQenhv/1St5f175useWAiPc3550NqatRRiIjITlICJ9JDrS7byqvLS3ll+XoWflpGaeU2auqaX6ijT3ICWWHCNahfCqNzgoc8N5eMDUpPISM1Sb1hIr3deedFHYGIiOwCJXAiPUTVtjrmfbyBVz5cz6vLS/motAqA3P6pHDgqiz0z+2xP0LLSd/SeZaXreWIi0oL166GmBoYMiToSERFpJ/1VJ9JNNTQ4i1Zv5pXlpby6vJS3P9lEbb3TJzmBA0dlMXPqcL4
yNocxg9PVWyYiO6+hAfbaC84+G269NepoRESknZTAiXQjn5Vv5dUP1/Nq0XpeW17Kpi21AOTv2Z/vfHkUXxmTw5QRA+mTnBhxpCLS4yUkwF//CmPHRh2JiIjsBCVwIhHaUlPHvBUbw1629RStqwRgcEYqh++Ty7Qx2Rw6OpucDC00ICKdYObMqCMQEZGdpAROpAs1NDiLPwuHRX64nrc/2URNfQOpSQkcuFcWpx+Qx7QxOYzN1bBIEekCtbWwYAHk5cGwYVFHIyIi7aAETqSTrSmv5tWwh+21ovVsrKoBYPye/Tnv0JFMG5NDwUgNixSRCGzcCIccArfcAj/+cdTRiIhIOyiBE+lgW2vqmffxBl5dHqwW+eHaYFhkTkYqh43NYdrYYFjk4Iw+EUcqInEvNxf+8Q844ICoIxERkXZSAieym9ydpWsq+NeHwWqR8z/eMSxy6qhBnDIlj2ljsxmXm6FhkSLS/Rx3XNQRiIjITlACJ7ILausbeOvjjTy3eC3PLV7LqrKtAOyzRwbnHjqSaWOyOWDkIA2LFJHub/16ePJJOP54yMmJOhoREWmDEjiRdirfWsu/Pizl+cVreWnZOiqq6+iTnMCXR+fwwyNG89VxgxncX8MiRaSH+egjOO88eOwxOPHEqKMREZE2KIETaUXJpi08v3gtzy9Zx5srNlDX4GSnp3DchD05Mj+XL4/Opm+KetlEpAfbf39YvBjGjYs6EhERaQclcCIx3J0PVm3muSXB0Mgln20GYPTgdL47bS+Oys9lct4AEhM0l01EeomUFBg/PuooRESknZTASdzbVlfPGx9t4Pkla3l+8TrWbK4mwaBgxCB+dtx4jszPZVR2v6jDFBHpPO+/D7Nnw3XXQXJy1NGIiEgrlMBJXCrbUsNLy9bx3OK1/GtZKVU19aSlJPKVMTkcmZ/L4fsMZlC/lKjDFBHpGosWwW9/C2eeqd44EZFuTgmcxI1PNlTx3OK1PL9kLfNXbqK+wRmckcoJk4fytfxcDt47S6tGikh8+uY3YfNm6Ns36khERKQN7UrgzOwY4L+BROAOd/9Nk+MjgLuAHGAjcKa7l4THfgscDyQAzwGXubt32B2ItKChwSksKduetDU+UHufPTL4/vS9OSo/l/2GZpKg+Wwi0oZ2tIOpwH3AFGADcJq7rzSzLGAOcABwj7tfEtZPA/4X2BuoB55w96u66n6+oI9W0BUR6SnaTODMLBG4DTgKKAHmm9lcd18cU+0W4D53v9fMDgduAs4ys0OAQ4GJYb3XgOnAyx13CyI7VNfW83rR+mA+25J1lFZsIzHBmDpyENd+fThHjs9leFZa1GGKSA/SznbwfGCTu482s9OBm4HTgGrg58CE8BXrFnd/ycxSgBfM7Fh3f7qz76dF//d/8NxzcOutkYUgIiJta08P3FSgyN1XAJjZbGAGENtw5QOXh9svAY+H2w70AVIAA5KBtbsftsgOG6tqeCFcNfLV5evZWltPemoS08flcNT4XL46bjCZaZqULyK7rD3t4AzgunB7DnCrmZm7VwGvmdno2Au6+xaC9hJ3rzGzd4BhnXoXbVmyBP75T9i6VUMpRUS6sfYkcEOB4pj9EuDAJnUKgRMJhpd8C8gwsyx3f8PMXgI+I0jgbnX3JbsftsS7Tzds4dnFa3h28VoWrNxIg8OemX04ecowjsrP5aC9skhJSog6TBHpHdrTDm6v4+51ZlYOZAHr27q4mQ0AvkHQhkbnyivhquhGcYqISPt01CImVxB823gu8AqwCqgPv3Ecz45vFZ8zs2nu/mrsyWZ2IXAhwPDhwzsoJOlN3J1Fqzfz7KIgaVu6pgII5rNd8tXRHJW/BxOG9sdM89lEpOcwsyTgYeBPjT18zdTpmjYyQV96iYj0BO1J4FYBeTH7w8Ky7dx9NUEPHGaWDpzk7mVmdgHwprtXhseeBg4GXm1y/ixgFkBBQYEWOBEAausbmLdiI88tXsNzi9eyujx8PtvIQVxz/Hi
+lr+H5rOJSFdosx2MqVMSJmWZBIuZtGUWsNzd/9hShS5tI2+8EVavhttu69S3ERGRXdeeBG4+MMbMRhE0UKcDZ8RWMLNsYKO7NwBXE6xICfApcIGZ3UQwhHI60GIjJVK5rY5/LSvl2cVreGnpOjZX19EnOYFpY3L4j6PGcsT4XD2fTUS6WpvtIDAXOAd4AzgZeLGtFZfN7FcEid53OzziXVVeDhs3Rh2FiIi0os0ELhzLfwnwDMHyyXe5+yIzux5Y4O5zgcOAm8zMCYZQXhyePgc4HHifYEGTf7r7Ex1/G9KTrauo5vnF63h28Rr+XbSBmvoGBqYlc/S+e3BUfi7TxuTQN0XPZxORaLSzHbwTuN/Miggep3N64/lmthLoD6SY2TeBrwGbgZ8BS4F3wuHft7r7HV13Z8347W8jfXsREWmbdbdHshUUFPiCBQuiDkM6WdG6Sp5bvJZnF69hYXEZ7jB8UBpfy8/lqPxcpowYSFKi5mOI9HZm9ra7F0QdR0+hNlJEJD601j521CImIq1qaHDeLS7j2XA+24rSKgD2G5rJ5UeO5Wv77sHY3HQtQiIiErUZM2DvveH3v486EhERaYYSOOk01bX1vPHRhjBpW8f6ym0kJRgH753FuYeM5MjxuQwZoGcNiYh0K6NGwdChUUchIiItUAInHap8Sy0vLQvms/1rWSlVNfX0S0nksH0G87X8XA4bN5jMvnqotohIt/VHrTUmItKdKYGTDlG5rY5fzl3E399dRV2Dk5ORyoz9h/K1/FwO3juL1CQtQiIi0mO4Q10dJOsLNxGR7kYJnOy2xas3c8lD77ByQxVnHzySGZOHMGnYABISNJ9NRKTH2bIlmAP3ox/BT38adTQiItKEEjjZZe7Og/M+5fonFzMwLZmHLjiIg/bKijosERHZHWlpcMYZMHFi1JGIiEgzlMDJLtlcXcvVf3uff7z3GdPH5vD7UyeRlZ4adVgiItIRfve7qCMQEZEWKIGTnfZeSRmXPPQuq8q28tNj9uGir+yl4ZIiIr1NWRmkpAQ9ciIi0m3oScnSbu7O3a9/zEl/+Td19Q08etFBfP+wvZW8iYj0Nu+/DwMHwhNPRB2JiIg0oR44aZfyLbX8ZE4hzy5ey5HjB/NfJ09iYL+UqMMSEZHOsM8+cMMNMGlS1JGIiEgTSuCkTe98uolLH3qXdRXV/Pzr+Xzn0JGYqddNRKTXSk6Ga66JOgoREWmGEjhpUUODc/urK/ivZ5ax54A+zPneIUzKGxB1WCIi0hVqauC994LVKFM04kJEpLvQHDhp1saqGs6/dz43Pb2Uo/JzefLSaUreRETiydy5cMABUFgYdSQiIhJDPXDyBW99vJEfPvwuG6tquGHGvpx50AgNmRQRiTeHHQaPPgqjR0cdiYiIxFACJ9s1NDh/frmI3z/3IcMHpfG3HxzChKGZUYclIiJRyM6GU06JOgoREWlCCZwAUFqxjcsfXciry9dzwqQh/PrE/UhP1X8eIiJxbdUqmDcPTjwx6khERCSkv9CF14vW86NHFrJ5ay03n7QfpxbkacikiIjAww/DT34Ca9fC4MFRRyMiIiiBi2v1Dc5/v7Cc/3lxOXvnpPPA+Qcybo+MqMMSEZHu4owz4KijICsr6khERCSkBC5Ord1czQ8ffpd5H2/k5CnDuH7GvqSl6D8HERGJMWRI8BIRkW5Df7HHoZeXrePyRwvZWlPP706ZxElThkUdkoiIdFf//jcsWwbnnRd1JCIigp4DF1dq6xu4+Z9LOffu+QzOSOWJS7+s5E1ERFr34INw+eXQ0BB1JCIighK4uLG6bCunz3qTv7z8EWccOJzHLz6U0YPTow5LRES6u1/8AoqLIUF/MoiIdAcaQhkHnl+8livmFFJX7/xp5v6cMEnzGUREpJ20+qSISLeiBK4Xq6kLhkze+drHTBjan1tnfomR2f2iDktERHqaO++E+nq48MKoIxERiXtK4Hqp4o1buOShdygsKefcQ0Z
y9XH7kJqUGHVYIiLSE/3tb7BtmxI4EZFuQAlcL/T0+59x5WPvAfD/zvwSx0zYM+KIRESkR5szB/r2jToKERFBCVyvUlffwA1PLubeNz5hUt4Abp25P3mD0qIOS0REejolbyIi3YaWlOol6hucn8x5j3vf+ITvfnkU/3vRwUreRESkYzQ0wCWXwN13Rx2JiEjcUw9cL9DQ4Pzs7+/z93dX8ZOjx3HxV0dHHZKIiPQmCQkwfz4MGBB1JCIicU8JXA/n7vzyiUXMnl/MpYePVvImIiKd4803wSzqKERE4p6GUPZg7s5vnl7KvW98wgXTRnH5UWOjDklERHorJW8iIt2CErge7A/PL+evr6zgrING8J/HjcfUuIqISGdZtw6OOgoeeyzqSERE4poSuB7qzy8X8acXlnNqwTB+ecK+St5ERDqRmR1jZsvMrMjMrmrmeKqZPRIen2dmI8PyLDN7ycwqzezWJudMMbP3w3P+ZN39F/mgQbBlS/BAbxERiYwSuB7ortc+5rf/XMaMyUO46cSJJCR07zZfRKQnM7NE4DbgWCAfmGlm+U2qnQ9scvfRwB+Am8PyauDnwBXNXPovwAXAmPB1TMdH34GSkuD11+HUU6OOREQkrimB62Eemvcp1z+5mGP23YPfnTKJRCVvIiKdbSpQ5O4r3L0GmA3MaFJnBnBvuD0HOMLMzN2r3P01gkRuOzPbE+jv7m+6uwP3Ad/s1LvoSO5RRyAiEreUwPUgj71dws8ef5/D9xnMn2buT1KifnwiIl1gKFAcs18SljVbx93rgHIgq41rlrRxze5n/nwYPhzeeCPqSERE4pYygB7iicLV/GROIYfunc2fv/0lUpL0oxMRiQdmdqGZLTCzBaWlpdEGk5cHhx4KffpEG4eISBxTFtADPLtoDT96ZCEFIwYx6+wp9ElOjDokEZF4sgrIi9kfFpY1W8fMkoBMYEMb1xzWxjUBcPdZ7l7g7gU5OTk7GXoH22MPePhh+NKXoo1DRCSOKYHr5l5eto5LHnqX/YZmctd5B5CWomevi4h0sfnAGDMbZWYpwOnA3CZ15gLnhNsnAy+Gc9ua5e6fAZvN7KBw9cmzgf/r+NA7SVVV1BGIiMQtJXz/YJMAACAASURBVHDd2L+L1nPR/W8zenA69543lfRUJW8iIl0tnNN2CfAMsAR41N0Xmdn1ZnZCWO1OIMvMioDLge2PGjCzlcDvgXPNrCRmBcsfAHcARcBHwNNdcT+77b77oH9/WL066khEROKSMoJuasHKjZx/7wJGZKXxwHcPJDMtOeqQRETilrs/BTzVpOzamO1q4JQWzh3ZQvkCYELHRdlFCgrgmmsgQd8Bi4hEQQlcN1RYXMa5d89nz8w+PPDdAxnULyXqkERERAL5+fDLX0YdhYhI3NLXZ93M4tWbOfuutxjYL5kHLziQwRla6UtERLqZujr46KOooxARiUtK4LqR5WsrOOvOeaSlJPLQdw9iz8y+UYckIiLyRVdeCRMnBomciIh0KQ2h7CZWrq/i23fMIyHBeOiCg8gblBZ1SCIiIs074wyYMgXq6yFJf0qIiHQl/dbtBoo3buGM29+krsF55MKDGJXdL+qQREREWlZQELxERKTLaQhlxNaUV/PtO+ZRua2O+8+fypjcjKhDEhERaVtJCcyfH3UUIiJxRz1wESqt2MYZd7zJxqoaHvjugew7JDPqkERERNrn4oth2TJYujTqSERE4ooSuIhsrKrhzDvm8VlZNfedP5XJeQOiDklERKT9rr0W3KOOQkQk7rRrCKWZHWNmy8ysyMyuaub4CDN7wczeM7OXzWxYzLHhZvasmS0xs8VmNrLjwu+ZyrfWctad8/h4QxV3nFPAASMHRR2SiIjIzpkyRfPgREQi0GYCZ2aJwG3AsUA+MNPM8ptUuwW4z90nAtcDN8Ucuw/4L3cfD0wF1nVE4D1V5bY6zr37LT5cW8Ffz5rCoaOzow5JRERk17z0ErzyStRRiIjElfb
0wE0Fitx9hbvXALOBGU3q5AMvhtsvNR4PE70kd38OwN0r3X1Lh0TeA22tqec798znvZJy/mfml/jquMFRhyQiIrLrLrsMfv3rqKMQEYkr7UnghgLFMfslYVmsQuDEcPtbQIaZZQFjgTIz+5uZvWtm/xX26MWd6tp6Lrx/AQtWbuQPp03mmAl7RB2SiIjI7pk9O3iJiEiX6ajHCFwBTDezd4HpwCqgnmCRlGnh8QOAvYBzm55sZhea2QIzW1BaWtpBIXUfNXUNXPzgO7y6fD03nzSREyYNiTokERGR3ZefDwO0CJeISFdqTwK3CsiL2R8Wlm3n7qvd/UR33x/4WVhWRtBbtzAcflkHPA58qekbuPssdy9w94KcnJxdvJXuqa6+gctmv8sLS9dxwzcncEpBXtsniYiI9ATbtsGf/wyvvhp1JCIicaM9Cdx8YIyZjTKzFOB0YG5sBTPLNrPGa10N3BVz7gAza8zKDgcW737YPUN9g3PF/xby9AdruOb48Zx10IioQxIREek4yclw9dXw+ONRRyIiEjfafA6cu9eZ2SXAM0AicJe7LzKz64EF7j4XOAy4ycwceAW4ODy33syuAF4wMwPeBm7vnFvpXhoanP/82/s8vnA1Pzl6HN+dtlfUIYmIiHSshARYvhx62egZEZHurF0P8nb3p4CnmpRdG7M9B5jTwrnPARN3I8Yex9257olFPLKgmEsPH83FXx0ddUgiIiKdY7BWVBYR6UodtYiJhNydm55eyn1vfMIF00Zx+VFjow5JRESk83z2GfzoR/DOO1FHIiISF5TAdbCH3ypm1isrOPvgEfznceMJRo6KiIj0UklJcPvtsGRJ1JGIiMSFdg2hlPZ7ZtEaRg9O57pv7KvkTUREer+cHCgvDxI5ERHpdOqB60DuTmFJGQUjBpKQoORNRETihJI3EZEuowSuA326cQtlW2qZlKeHmoqISByZPx+OOw6Ki6OORESk11MC14EWFpcBMGmYEjgREYkjCQlB8rZ2bdSRiIj0ehrz0IEWFpfRNzmRsbnpUYciIiLSdaZMgfffjzoKEZG4oB64DlRYXMZ+QzNJStTHKiIiIiIiHU+ZRgeprW/gg9WbmZSXGXUoIiIiXe/BByE/H2pqoo5ERKRXUwLXQZatqaCmrkELmIiISHwaNChI4MrKoo5ERKRX0xy4DvKuFjAREZF4duyxwUtERDqVeuA6SGFxGVn9Uhg2sG/UoYiIiESnvj7qCEREejUlcB2ksLiMyXkDMNMDvEVEJE5deSVMmBB1FCIivZqGUHaAiupaikor+cakIVGHIiIiEp2CAkhKCnrhEhOjjkZEpFdSAtcB3l9VjjtawEREROLbqacGLxER6TQaQtkBCovLAZg0TI8QEBHpjczsGDNbZmZFZnZVM8dTzeyR8Pg8MxsZc+zqsHyZmR0dU/4fZrbIzD4ws4fNrE/X3E0na2iADRuijkJEpNdSAtcBFhZvYmRWGgPSUqIORUREOpiZJQK3AccC+cBMM8tvUu18YJO7jwb+ANwcnpsPnA7sCxwD/NnMEs1sKPBDoMDdJwCJYb2e76tfhdNOizoKEZFeS0MoO0BhcTkH7jUo6jBERKRzTAWK3H0FgJnNBmYAi2PqzACuC7fnALdasKrVDGC2u28DPjazovB6nxK0wX3NrBZIA1Z3wb10vu99D7Sgl4hIp1ECt5vWlFezZnM1kzX/TUSktxoKFMfslwAHtlTH3evMrBzICsvfbHLuUHd/w8xuIUjktgLPuvuzzb25mV0IXAgwfPjw3b+bzjZzZtQRiIj0ahpCuZsKS8IHeCuBExGRdjKzgQS9c6OAIUA/MzuzubruPsvdC9y9ICcnpyvD3HVr1sCnn0YdhYhIr6QEbjcVFpeRlGDk79k/6lBERKRzrALyYvaHhWXN1jGzJCAT2NDKuUcCH7t7qbvXAn8DDumU6LuaO+Tnww03RB2JiEivpARuNxWWlDF+z/70SdbzbkREeqn5wBg
zG2VmKQSLjcxtUmcucE64fTLwort7WH56uErlKGAM8BbB0MmDzCwtnCt3BLCkC+6l85nB7bfDD34QdSQiIr2S5sDthoYG573icmbsrwd4i4j0VuGctkuAZwhWi7zL3ReZ2fXAAnefC9wJ3B8uUrKRcEXJsN6jBAue1AEXu3s9MM/M5gDvhOXvArO6+t46zUknRR2BiEivpQRuN6xYX0nFtjom5w2MOhQREelE7v4U8FSTsmtjtquBU1o490bgxmbKfwH8omMj7SZqauDNN2H4cBg5MupoRER6FQ2h3A0Lwwd4T87TA7xFRES2q6qC6dPhoYeijkREpNdRD9xuKCwuIz01ib2y06MORUREpPsYOBBeeAEmTYo6EhGRXkcJ3G4oLClj4rBMEhL0wFIREZHPOfzwqCMQEemVNIRyF1XX1rPks816/puIiEhzSkuD1Sg/+yzqSEREehUlcLto8Webqa13Jg1TAiciIvIFq1bBhRfCK69EHYmISK+iIZS7qLC4DID9hyuBExER+YIJE2D5cth776gjERHpVZTA7aLC4jL26N+H3P59og5FRESk+0lKgtGjo45CRKTX0RDKXVRYUs4kPT5ARESkZe+/Dz/5CVRXRx2JiEivoQRuF5RtqeHj9VVawERERKQ1H30Ef/pTMJRSREQ6hIZQ7oLCkvAB3lrAREREpGXHHQebN0NqatSRiIj0GkrgdkFhcRlmsN8wDaEUERFpUUpK1BGIiPQ6GkK5CwqLyxidk05Gn+SoQxEREeneHn8czj8/6ihERHoNJXA7yd0pLCnT/DcREZH2WLkSXn8dKiujjkREpFdQAreTVpVtZX1ljRI4ERGR9rjsMli6FNLTo45ERKRXUAK3kwqLtYCJiIhIu5lFHYGISK+iBG4nLSzeREpSAuP2yIg6FBERkZ7hxhvhvPOijkJEpFdQAreTCovLmTCkPylJ+uhERETapbZWD/MWEekgeozATqirb+D9VeWcPjUv6lBERER6juuuizoCEZFeQ91IO2H5ukq21tYzWQuYiIiIiIhIBJTA7YTC4jIAJmkBExERkZ0zYwZ873tRRyEi0uNpCOVOWFhcRmbfZEZkpUUdioiISM+Snw85OVFHISLS4ymB2wkLi4MHeJuWRBYREdk5N90UdQQiIr2ChlC205aaOj5cW6H5byIiIrvKHbZtizoKEZEeTQlcO32wajMNDpPzMqMORUREpOfZtg2GDYOTToo6EhGRHk1DKNupcQGTiVrAREREZOelpgbDKEeNijoSEZEeTQlcOy0sKWPYwL5kp6dGHYqIiEjPdPbZO7bfegumTIHExOjiERHpgdo1hNLMjjGzZWZWZGZXNXN8hJm9YGbvmdnLZjasyfH+ZlZiZrd2VOBdbeGnwQImIiIispsWL4ZDDoHf/S7qSEREepw2EzgzSwRuA44F8oGZZpbfpNotwH3uPhG4Hmi61NQNwCu7H240Siu2sapsK/srgRMREdl9+fkwaxZ8//tRRyIi0uO0pwduKlDk7ivcvQaYDcxoUicfeDHcfin2uJlNAXKBZ3c/3Gi8VxI+wFsJnIiISMf4zncgIwNqa+Gxx6KORkSkx2hPAjcUKI7ZLwnLYhUCJ4bb3wIyzCzLzBKA3wFX7G6gUSosLiMxwdh3SP+oQxEREeldbr8dTj45mBMnIiJt6qjHCFwBTDezd4HpwCqgHvgB8JS7l7R2spldaGYLzGxBaWlpB4XUcRaWlDM2N4O0FK35IiIi0qEuugj++U+YOjXqSEREeoT2JHCrgLyY/WFh2XbuvtrdT3T3/YGfhWVlwMHAJWa2kmCe3Nlm9pumb+Dus9y9wN0LcnJydu1OOom7U1hcpue/iYjEsXYs5pVqZo+Ex+eZ2ciYY1eH5cvM7OiY8gFmNsfMlprZEjM7uGvupptJTISjw49lyRJ45JFo4xER6eba06U0HxhjZqMIErfTgTNiK5hZNrDR3RuAq4G7ANz92zF1zgUK3P0LDV9
3tnLDFsq31jJJz38TEYlLMYt5HUUwjWC+mc1198Ux1c4HNrn7aDM7HbgZOC1c9Ot0YF9gCPC8mY1193rgv4F/uvvJZpYCpHXhbXVPv/wlvP46fOMbkKaPQ0SkOW32wLl7HXAJ8AywBHjU3ReZ2fVmdkJY7TBgmZl9SLBgyY2dFG+Xa3yA9+ThSuBEROJUexbzmgHcG27PAY4wMwvLZ7v7Nnf/GCgCpppZJvAV4E4Ad68JR67EtzvugH/9S8mbiEgr2jWpy92fAp5qUnZtzPYcggartWvcA9yz0xFGbGFxGWkpiYwZnBF1KCIiEo3mFvM6sKU67l5nZuVAVlj+ZpNzhwJbgVLgbjObBLwNXObuVZ1yBz1FenrwAvjv/4ZJk+CwwyINSUSku+moRUx6rcKSMiYMzSQxwaIORUREeo8k4EvAX8L541VAs1MMuvtCX51i69bgOXH33tt2XRGROKMErhU1dQ0sWr2ZyXr+m4hIPGtzMa/YOmaWBGQCG1o5twQocfd5YfkcgoTuC7rzQl+dpm/fYCjl7bdHHYmISLejBK4VS9dspqauQQuYiIjEt+2LeYWLjZwOzG1SZy5wTrh9MvCiu3tYfnq4SuUoYAzwlruvAYrNbFx4zhHAYmSH7GxISoKyMvjWt2DZsqgjEhHpFvRgs1Y0LmAySY8QEBGJW+GctsbFvBKBuxoX8wIWuPtcgsVI7jezImAjQZJHWO9RguSsDrg4XIES4FLgwTApXAGc16U31lOsXw9vvx0kcOPGtV1fRKSXUwLXioXF5WSnpzJ0QN+oQxERkQi1YzGvauCUFs69kWZWZ3b3hUBBx0baC40eDR9+CH36BPvuYJqXLiLxS0MoW1FYEjzA29RQiIiIRKcxeXvhBZg6FeJlMRcRkWYogWvB5upaPiqt1Pw3ERGR7iIpCRISoL6+7boiIr2UhlC24IOSctxhklagFBER6R6mT4c33wyGULrDtm07eudEROKEeuBa8G7jAibqgRMREek+Gqc1XHopHHcc1NREG4+ISBdTD1wLCovL2Cu7H5lpyVGHIiIiIk0ddBBkZkKy2mkRiS9K4FpQWFLGIXtnRx2GiIiINOfMM3dsl5bCoEGQmBhdPCIiXURDKJuxpryatZu3MWmYnv8mIiLSrW3eHPTGXXZZ1JGIiHQJ9cA1Y+H2B3hr/puIiEi31r8/XHRRsMCJiEgcUALXjMKSMpITjfF79o86FBEREWnLlVfu2H7vPZg4MbpYREQ6mYZQNmPhp2WM37M/fZI1ll5ERKTHeO01mDwZ7r8/6khERDqNErgm6huc91eVM1nDJ0VERHqWQw6BW26Bk06KOhIRkU6jBK6JFaWVVG6r0/PfREREepqEBLj8ckhLCx7y/eyzUUckItLhlMA1oQVMREREeoGbbgoe9F1UFHUkIiIdSouYNFFYUkZGahJ7ZfeLOhQRERHZVVdeCVOmwOjRUUciItKh1APXxMLiMibmZZKQYFGHIiIiIrsqLQ2+8Y1gu7AQXnwx2nhERDqIErgY1bX1LP2sQvPfREREegv34CHfP/gB1NVFHY2IyG7TEMoYi1Zvpq7BtQKliIhIb2EGjzwCW7dCUhLU1weLnZhG2ohIz6QeuBiF4QImSuBERER6kdxcGDky2P75z+Hkk6G2NtKQRER2lXrgYhSWlLFnZh8G9+8TdSgiIiLSGbKzoawMkpOjjkREZJcogYtRWFym+W8iIiK92eWX79j++GO45x645holdCLSY2gIZWhTVQ0rN2zR899ERETixWOPwR//CJ99FnUkIiLtpgQuVFii+W8iIiJx5YorYPFiGD482P/3v4NVK0VEujElcKHC4nLMYL9hmVGHIiIiIl1l6NDg3+efh0MPhUcfjTYeEZE2KIELFZaUMWZwOumpmhYoIiISd776VZg1C048MdivqYk2HhGRFiiBA9xdC5iIiIjEs8REuOCCYDGTLVvggAPgT3+KOioRkS9QAgeUbNrKhqoaLWA
iIiIiwcO+J06E8eOjjkRE5As0XhBYqAd4i4iISKOMDLj//h37d98N/frBqadGF5OISEgJHMHz31KTEhi3R0bUoYiIiEh34g733Qd9+sApp4BZ1BGJSJxTAkewgMmEoZkkJ2pEqYiIiMQwg2efhaqqYHvjRlixAgoKoo5MROJU3GcsdfUNvL+qXAuYiIiISPOSk2FA+HfCNdfA9Omwfn20MYlI3Ir7HrgP11ZSXdvApDw9/01ERETacOONcMwxkJ0d7G/ZAmlp0cYkInEl7nvgtICJiIiItNvAgXDCCcH266/DqFHw1lvRxiQicSXuE7jC4jIGpiUzfJC+PRMRkeaZ2TFmtszMiszsqmaOp5rZI+HxeWY2MubY1WH5MjM7usl5iWb2rpk92fl3IR0uJwemTYN99ok6EhGJI0rgSsqYlDcA06pSIiLSDDNLBG4DjgXygZlmlt+k2vnAJncfDfwBuDk8Nx84HdgXOAb4c3i9RpcBSzr3DqTTjB0Lc+ZA//7Bs+MuvBAKC6OOSkR6ubhO4Kq21fHh2gotYCIiIq2ZChS5+wp3rwFmAzOa1JkB3BtuzwGOsOCbwRnAbHff5u4fA0Xh9TCzYcDxwB1dcA/S2VauhCeegHffjToSEenl4noRkw9WldPgmv8mIiKtGgoUx+yXAAe2VMfd68ysHMgKy99scu7QcPuPwJWAHkLaG+y9NyxdGvTGAfz730FZbm60cYlIrxPXPXCFJcECJhOHaQVKERHpOmb2dWCdu7/djroXmtkCM1tQWlraBdHJLsvMDJ4VV1MDp50G55wTdUQi0gvFdQ/cwuIy8gb1JSs9NepQRESk+1oF5MXsDwvLmqtTYmZJQCawoZVzTwBOMLPjgD5AfzN7wN3PbPrm7j4LmAVQUFDgHXJH0rlSUuDpp4PnxwFs2xb8m6q/N0Rk98V3D1xxOZPzBkYdhoiIdG/zgTFmNsrMUggWJZnbpM5coLG75WTgRXf3sPz0cJXKUcAY4C13v9rdh7n7yPB6LzaXvEkPNmECjBsXbP/0p3DYYeDKv0Vk98VtAreuoppVZVuZpOGTIiLSCnevAy4BniFYMfJRd19kZtebWfhAMO4EssysCLgcuCo8dxHwKLAY+CdwsbvXd/U9SMSOPBIOPDAYXgnwl79osRMR2WVxO4TyveJyQAuYiIhI29z9KeCpJmXXxmxXA6e0cO6NwI2tXPtl4OWOiFO6qa9/PXgBVFbClVfCJZfA/vsHvXIlJZCX1/o1RERCcZvAFZaUkZhg7DtEPXAiIiLSRdLTg4SttjbYX7AApk6Fxx+HGU2fTiEi8kVxO4RyYXEZ43Iz6JuS2HZlERERkY6SmQnZ2cF2Xh786lcwfXqw//e/w8yZsHFjdPGJSLcWlwlcQ4NTWFzGJA2fFBERkSjtsQf87GcwIPybpLQ0eJ5cZjhC6IUXNF9ORD4nLhO4lRuq2Fxdx/5K4ERERKQ7ufBCeOcdSAxHCF1xBfzwhzuOV1REE5eIdBvtSuDM7BgzW2ZmRWZ2VTPHR5jZC2b2npm9bGbDwvLJZvaGmS0Kj53W0TewKxof4K0eOBEREel2GlerhKAH7q9/Dba3boURI+Dmm6OJS0S6hTYTODNLBG4DjgXygZlmlt+k2i3Afe4+EbgeuCks3wKc7e77AscAfzSzyLOmwuJy0lISGT04PepQRERERFo2aBDkh3921dTApZfCtGnBfnExnH02FBVFF5+IdLn29MBNBYrcfYW71wCzgabLJOUDL4bbLzUed/cP3X15uL0aWAfkdETgu2NhcRn7Dc0kMcHariwiIiLSHWRmwi9/CYccEux/8AE88QQkhH/OFRUFZSLSq7UngRsKFMfsl4RlsQqBE8PtbwEZZpYVW8HMpgIpwEdN38DMLjSzBWa2oLS0tL2x75JtdfUsXr1Zz38TERGRnu3YY2HtWthrr2D/t7+Fgw6CLVuC/Xo9M16kN+qoRUyuAKab2bvAdGAVsP23hpntCdw
PnOfuDU1PdvdZ7l7g7gU5OZ3bQbf0swpq6huUwImIiEjPl5KyY/tXv4LHHoO0tGD/+OPh+9+PJi4R6TTtSeBWAXkx+8PCsu3cfbW7n+ju+wM/C8vKAMysP/AP4Gfu/maHRL0btICJiIiI9EqDB8PRRwfb7lBQABMm7Ni/9FKYNy+6+ESkQ7QngZsPjDGzUWaWApwOzI2tYGbZZtZ4rauBu8LyFODvBAuczOm4sHfdwuIycjJS2TOzT9ShiIiIiHQOs6BH7uKLg/2VK+HBB+HDD4P9iopghUsNsxTpcdpM4Ny9DrgEeAZYAjzq7ovM7HozOyGsdhiwzMw+BHKBG8PyU4GvAOea2cLwNbmjb2JnFBaXMWnYAMy0gImIiIjEiVGjYM0aOPXUYP/xx+HII+Gtt4L9rVuh4QuzXESkG0pqTyV3fwp4qknZtTHbc4Av9LC5+wPAA7sZY4fZXF3LR6VVfGv/pmuwiIiIiPRysfPlTj4Z+vcPFj0B+M1v4L77YPFi6Ns3mvhEpF3alcD1Fu8VlwOa/yYiIiJxrm9fmBHzVKipU4PhlI3J21VXQXY2XHFFNPGJSIviKoFrXMBk4jAlcCIiIiLbHX988IJgwZP/397dB1dVXnsc/64kEMJLAXlTSFTeRNDyYqlFaC2Crai0OK060NLSomW04NXa1gpt753eqRVH5XLHotUiXqdSFSlU6+ALo05vlRFFIIVAqYhKgiAoJLwohCTr/vGc051wwU6BnJ2z9+8zk8le+5wT16MZlovn2c+zaRPs3Ru9Pn8+jB0blmKKSKxO1jECeWFtZTV9urWjY0mruFMRERERaZnMYOlSmDcvxO+9B9OmwRNPhLiuDqqq4stPJOVS08C5O2srqxmq2TcRERGRfy674VvPnvDuu3DNNSF+6SUoKwvfRSTnUtPA7dh7kF37Dun5NxEREZF/VVkZdOkSrgcOhNtugwsuCPFvfwtf/nI4mkBEml1qGrjySh3gLSIiInLCSkth1ixokzlTt7Aw7HDZvn2IH3ooWm4pIiddahq4NZXVtC4sYOBpHeJORURERCQ5pk6Fp5+Ollzefz880ugUqRUroKYmntxEEig1DVx5ZTUDe36K4qLCuFMRERERSa4VK2DBgnBdXQ0XXQS33x5id1i+HA4ciC8/kTyXigauvsFZV1XD0NKOcaciIiIikmwFBdHzcu3bw3PPhVk6gIqK8LxcdollTU3YDOXgwXhyFclDqWjg3tq1nwO19Xr+TURERCSXiopg9Gg466wQ9+0bGrrsmXPLl8OYMbBmTYi3boVXXglHFYjIUaWigVurDUxERERE4ldSEmbgunUL8SWXwJ/+BMOHh/h3v4PPfz4svQRYtw5efz0svRQRIEUNXIc2RfTu0i7uVEREREQkq0MHGD8eWrUK8fXXhxm6rl1DfOedYbYu28C98kpo6kRSLBUNXHllNUPLOlFQYHGnIiIiIiLHcsopYYYu6847YenS8FwdwM03hyYva9kyePPN3OYoErPEN3AHD9fztx37GFKq5ZMiIiIieaVHDxg1KooXL4Zf/zpcNzTA5Mkwe3b0+mOPQWVlbnMUybHEN3AV79VQ3+B6/k1EREQk35WVwdCh4dosPB83c2aI338fJk2CRx8N8aFD8PDD8MEH8eQq0kwS38CtrQwHRw7REQIiIiIiyWEWdrXs1y/E3bvDxo3w7W+H+LXX4DvfgZUrQ7xqFQwYAK++GuKtW8Oh4zt3hri+XpulSF5IfANXXllNz45t6P6pNnGnIiIiIiLNxQzOPhtOPTXEo0ZBeTmMHBni1q1hyJDwnB2E2bvrroPt20O8dCm0bRuawOzrs2bB7t0hrqmBXbvU5EnsEt/Ara2s1vJJERERkbQpKIDBg6Fz5xAPHgyLFkVn0k2YAFVVMHBgiPv2henToWfPEJeXh01UshYsCLN8e/aEePHi8AzeoUMhfuutMMunBk+aWaIbuN0Hatm6+yOGqoETEZETYGbjzGy
TmW02s1uP8nqxmT2eeX2lmZ3Z6LWZmfubzOySzL0yM3vJzDaYWYWZ3Zi7skOXFAAAC01JREFU0YgIEA4Z79UrzMwBDBsGd90FHTOP3Vx7LRw8GDWAY8fCPfdE8Y4dsHp19Pl774ULL4x+/u23w7hxUfzii/DEE1FcW6tmT45Lohu48iod4C0iIifGzAqBecClwCBgkpkNOuJt1wB73L0f8F/AHZnPDgImAucA44B7Mz+vDvihuw8CRgDTj/IzRSRuhYVhaSaEGbwZM6J4xgzYsCGKv/99ePLJKG7fPjrPDuA3v4Gf/zyKv/lN+Mxnonju3KYzfhUV8M47J31Ikv+S3cBVVlNg8Ole2sBERESO2/nAZnff4u61wGPAhCPeMwF4OHO9GBhrZpa5/5i7H3L3t4HNwPnuvt3dVwO4+z5gI9ArB2MRkebSty986UtRfMMN8MgjUfzgg7B8eRR/7Wthli9rxQr4y1+i+Npr4Xvfi+KJE+GWW6L497+HP/85irNLOSXxiuJOoDmVV1bTv3sH2hUnepgiItK8egGND5aqAj53rPe4e52Z1QBdMvdfPeKzTRq1zHLLYcDKk5m0iLQwHTqEr6xJk5q+vmhR0/juu5susezSBTo1WlX2k5/AxRfDF78Y4r594StfgfvuC/GMGTB6NFx5ZYiffTY8/9enT4jdo9lCySuJnYFzd8qrahhSptk3ERFpmcysPfAH4CZ333uM90wzs1VmtmrXrl25TVBE4jNyZNNDzOfNC7tiZv31r3DHHVF8441w+eXhuqEBnnkm2lGzvh4uuyyciwdw+DCUlMCcOSE+dAi+8Q14/vkQ19bCkiXhqAUIzV5Dw8kfoxyXxDZwlbs/ZveBWoaWdY47FRERyW/bgLJGcWnm3lHfY2ZFQEfgw0/6rJm1IjRvC919ybH+4e7+gLsPd/fh3bp1O8GhiEhidO4cdsXM+vGPYfz4cF1QEHbFbPzM3cqVMHVquD58GG66KWzcArB3bzg3b8eOEG/fDl//etTQvf02tGoVlm1C2L3zqquiM/Z274bHH48+X1cXvqRZJLaBW/uPDUw0AyciIifkdaC/mfU2s9aETUmeOuI9TwFTMtdXAi+6u2fuT8zsUtkb6A+8lnk+7kFgo7vPyckoRCS9Cgvhs5+FM84Icdu2MHs2XHRRiLt1g82bo0PQTzsN1qwJRy0AtGsHM2fCueeGuLoa1q+HAwdCXFERntFbty7EL78cdufMPqO3fn3Y5CW7KcuePWF2sLa2WYedVIlt4Morq2nTqoCzenT4528WERE5BnevA2YAzxE2G1nk7hVm9p9m9tXM2x4EupjZZuBm4NbMZyuARcAG4FlgurvXA6OAbwFjzGxt5uuynA5MRORYWreGoUNDYwfQowf88pdhJ04IjdzGjTBmTIiHDw9N2ogRIS4rC7N//fqF+J13wgzdxx+HeNkyGDQItmwJ8dNPh+f5sjN4mzbBH/+ojVmOIbG7e5RXVnNuz460KkxsjyoiIjni7suAZUfc+/dG1weBq47x2duA24649zKg3QNEJBlKSuCcc6K4b1/4xS+iePx4+PDDKP7CF2DhwmhGsK4OPvooHL0AsHRpmPHbvx+Ki8OGLnPnhlnC4uKwm+eaNWHZqFmYEWzdOswspkAiu5vD9Q2sf69G57+JiIiIiLQ0p58eNk0pKQnxFVeEYxSyDdx118Ebb4SlmwADBoQNWoqLQ/zMM+Gg9Owumj/9KZSWRj9/3jz40Y+i+P33Q4OYEIls4Dbt2MfBww1q4ERERERE8k2nTnDeeVE8fnw4CD1rzhzY1mgvqauuaroj55YtsHp1FE+bBp9rdPrLvHkwf34U59lSzUQuoSzPbGAyTA2ciIiIiEjyNF4uOXp0+Mq6++6m773++rDTZtaSJWF2L3uQ+ogR4Yy8xx8P8f33h+f3xo5tjsxPWDIbuMpqTmnXmtLOJXGnIiIiIiIicRo3rmn8wgtNjzmYOjXasAXCksyrr44
auP79YcoU+NnPQrxwYdi4ZcCA5s37GBK5hLK8soYhpR0xnS4vIiIiIiJHKmo0j3XDDeEYhKxt2+BXvwrX9fWhATz77BDv3w+TJ4dZPAg7aw4eDIsWhbi2Fj74oHlTb9afHoP9h+r4+859XPrpU+NORURERERE8k1xcbRhSmEh3HNP9FrbtuGQ9OwGK/v2Qe/e0CFzdNmWLeEg9a5dmy29xDVwbVsVsvwHF9KuOHFDExERERGROBUUQJ8+Udy9Ozz5ZBSffno40qAZJa7LKSgw+nXX4d0iIiIiIpJjOTiLLpHPwImIiIiIiCSRGjgREREREZE8oQZOREREREQkT6iBExERERERyRNq4ERERERERPKEGjgREREREZE8oQZOREREREQkT6iBExERERERyRNq4ERERERERPKEGjgREREREZE8Ye4edw5NmNku4N2T8KO6Ah+chJ+TjzT29EnruEFjz/exn+Hu3eJOIl+cpBqZhN+b46Wxp5PGnk75PvZj1scW18CdLGa2yt2Hx51HHDT29I09reMGjT2tY5fjl+bfG41dY08bjT2ZY9cSShERERERkTyhBk5ERERERCRPJLmBeyDuBGKksadPWscNGrvIvyrNvzcaezpp7OmU2LEn9hk4ERERERGRpEnyDJyIiIiIiEiiJK6BM7NxZrbJzDab2a1x55MrZlZmZi+Z2QYzqzCzG+POKdfMrNDM1pjZ03Hnkktm1snMFpvZ38xso5ldEHdOuWJmP8j8vq83s0fNrE3cOTUXM1tgZjvNbH2je6eY2XIzezPzvXOcOUrLpxqZzhqZ1voI6a2RaaqPkL4amagGzswKgXnApcAgYJKZDYo3q5ypA37o7oOAEcD0FI0960ZgY9xJxOC/gWfd/WxgCCn5d2BmvYB/A4a7+7lAITAx3qya1f8A4464dyvwgrv3B17IxCJHpRqZ6hqZ1voIKayRKayPkLIamagGDjgf2OzuW9y9FngMmBBzTjnh7tvdfXXmeh/hD6he8WaVO2ZWClwOzI87l1wys47AhcCDAO5e6+7V8WaVU0VAiZkVAW2B92LOp9m4+/8Cu4+4PQF4OHP9MHBFTpOSfKMaSfpqZFrrI6S+RqamPkL6amTSGrheQGWjuIqU/AHdmJmdCQwDVsabSU7NBW4BGuJOJMd6A7uAhzLLY+abWbu4k8oFd98G3AVsBbYDNe7+fLxZ5VwPd9+eud4B9IgzGWnxVCNJZY1Ma32ElNZI1cd/SGyNTFoDl3pm1h74A3CTu++NO59cMLPxwE53fyPuXGJQBJwH3Ofuw4ADJGiJwCfJrGWfQCjQPYF2ZjY53qzi42FLYW0rLPIJ0lYjU14fIaU1UvXx/0tajUxaA7cNKGsUl2bupYKZtSIUpoXuviTufHJoFPBVM3uHsCRojJk9Em9KOVMFVLl79m+SFxOKVRpcDLzt7rvc/TCwBBgZc0659r6ZnQaQ+b4z5nykZVONTF+NTHN9hPTWSNXHILE1MmkN3OtAfzPrbWatCQ9sPhVzTjlhZkZY473R3efEnU8uuftMdy919zMJ/81fdPdU/E2Tu+8AKs1sQObWWGBDjCnl0lZghJm1zfz+jyUFD6cf4SlgSuZ6CvBkjLlIy6cambIameb6CKmukaqPQWJrZFHcCZxM7l5nZjOA5wg77ixw94qY08qVUcC3gHVmtjZzb5a7L4sxJ8mNG4CFmf8h2wJ8N+Z8csLdV5rZYmA1YYe5NcAD8WbVfMzsUWA00NXMqoD/AGYDi8zsGuBd4Or4MpSWTjVSNTKlUlcj01YfIX010sKSUBEREREREWnpkraEUkREREREJLHUwImIiIiIiOQJNXAiIiIiIiJ5Qg2ciIiIiIhInlADJyIiIiIikifUwImIiIiIiOQJNXAiIiIiIiJ5Qg2ciIiIiIhInvg/hBLrMoEcBaMAAAAASUVORK5CYII=\n", 1267 | "text/plain": [ 1268 | "
" 1269 | ] 1270 | }, 1271 | "metadata": { 1272 | "needs_background": "light", 1273 | "tags": [] 1274 | }, 1275 | "output_type": "display_data" 1276 | } 1277 | ], 1278 | "source": [ 1279 | "optim = Adam(ff_model.parameters())\n", 1280 | "loss = nn.MSELoss()\n", 1281 | "output = fit(train_dl, ff_model, loss, optim, epochs)\n", 1282 | "plot_accuracy_loss(*output)" 1283 | ] 1284 | }, 1285 | { 1286 | "cell_type": "code", 1287 | "execution_count": 31, 1288 | "metadata": { 1289 | "colab": { 1290 | "base_uri": "https://localhost:8080/", 1291 | "height": 301 1292 | }, 1293 | "id": "N4nulJc4omhV", 1294 | "outputId": "f3f820ac-5261-4e15-d848-e0d0f3d8e74a" 1295 | }, 1296 | "outputs": [ 1297 | { 1298 | "name": "stdout", 1299 | "output_type": "stream", 1300 | "text": [ 1301 | "torch.Size([1, 28, 28])\n", 1302 | "Prediction tensor([0])\n" 1303 | ] 1304 | }, 1305 | { 1306 | "data": { 1307 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOCElEQVR4nO3df4xV9ZnH8c+zlmoCTQSJ42hRuoQYa4m0QbPJ6sZNhbiaCMSkKSbrmJIMxoolYix2E0tCNjHussY/DGYaETSstQpUrZttFcnaatI4iosjyI8lKODAxOWP2j8UB5794x66A9zzvZd7z7nnDs/7lUzm3vPMvefJCR/Oued77vmauwvAue+vqm4AQGcQdiAIwg4EQdiBIAg7EMTXOrkyM+PUP1Ayd7d6y9vas5vZzWa2y8z2mtmKdt4LQLms1XF2MztP0m5JcyUdlPSOpEXuviPxGvbsQMnK2LNfJ2mvu+9z92OSfilpfhvvB6BE7YT9MkkHxjw/mC07hZn1m9mgmQ22sS4AbSr9BJ27D0gakDiMB6rUzp79kKRpY55/M1sGoAu1E/Z3JM00s2+Z2dcl/VDSy8W0BaBoLR/Gu/uomd0r6beSzpO01t0/LKwzAIVqeeitpZXxmR0oXSkX1QAYPwg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCKKjUzajPrO6NwP9i0suuSRZv+eee3Jrvb29ydcuXrw4WW/X008/nVtbuXJl8rUHDx5M1k+cONFKS2GxZweCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIJjFtQMuuOCCZL2vry9ZX7NmTZHtjBvLly9P1h9//PFkPeo4fN4srm1dVGNm+yV9Lum4pFF3n9PO+wEoTxFX0P29u39WwPsAKBGf2YEg2g27S/qdmb1rZv31/sDM+s1s0MwG21wXgDa0exh/vbsfMrOLJb1mZh+5+5tj/8DdByQNSHFP0AHdoK09u7sfyn6PSNos6boimgJQvJbDbmYTzewbJx9LmidpqK
jGABSrncP4Hkmbs+9if03Sv7v7fxbS1TgzceLEZP3tt99O1mfNmlVkO+eM1atXJ+vHjh1L1p944oki2xn3Wg67u++TdE2BvQAoEUNvQBCEHQiCsANBEHYgCMIOBMGtpAswderUZJ2htXIsXbo0WU8Nza1duzb52uPHj7fUUzdjzw4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQXAr6Sb19PTk1l5//fXka6+++uqi2znFV199lVt7/vnnk6+94YYb2lp3o+mkzz///LbevyxXXXVVsr5r164OdVK8vFtJs2cHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSD4PnuT7r///txa2ePohw8fTtaXLFmSW3vllVeKbucU8+bNS9ZTt3OeMWNG0e007aWXXkrWV61alaxv2LChyHY6gj07EARhB4Ig7EAQhB0IgrADQRB2IAjCDgTB99kzEyZMSNa3b9+eW7vyyiuLbucUb731VrLe7nfSy3T33Xfn1h566KHka6dNm1Z0O03bvXt3sj537txk/cCBA0W2c1Za/j67ma01sxEzGxqzbIqZvWZme7Lfk4tsFkDxmjmMXyfp5tOWrZC0xd1nStqSPQfQxRqG3d3flHT0tMXzJa3PHq+XtKDgvgAUrNVr43vcfTh7fFhS7g3azKxfUn+L6wFQkLa/COPunjrx5u4Dkgak7j5BB5zrWh16O2JmvZKU/R4priUAZWg17C9L6sse90lKf18QQOUajrOb2XOSbpQ0VdIRST+X9GtJv5J0uaSPJf3A3U8/iVfvvbr2MP6BBx5I1h999NHS1p2aR1ySbr/99mT91VdfLbKdjrn00kuT9c2bNyfr1157bZHtnJU9e/Yk643ucTA6OlpkO6fIG2dv+Jnd3RfllL7fVkcAOorLZYEgCDsQBGEHgiDsQBCEHQiCr7hmGm2HMrfTeP4Ka5nG89Bco6mqU9Nst4spm4HgCDsQBGEHgiDsQBCEHQiCsANBEHYgCKZs7gLr1q2ruoWu9OmnnybrCxakb324bdu23NrFF1/cUk/NuuKKK5L1vXv3lrr+etizA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQjLNj3BoeHk7Wv/jiiw51cqY777wzWX/44Yc71Mn/Y88OBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0Ewzo5zVuo+AVWMc1et4Z7dzNaa2YiZDY1ZttLMDpnZ+9nPLeW2CaBdzRzGr5N0c53lj7n77OznP4ptC0DRGobd3d+UdLQDvQAoUTsn6O41s+3ZYf7kvD8ys34zGzSzwTbWBaBNrYZ9jaQZkmZLGpa0Ou8P3X3A3ee4+5wW1wWgAC2F3d2PuPtxdz8h6ReSriu2LQBFaynsZtY75ulCSUN5fwugOzQcZzez5yTdKGmqmR2U9HNJN5rZbEkuab+kJSX2CLRk0qRJla17586dla07T8Owu/uiOoufKqEXACXiclkgCMIOBEHYgSAIOxAEYQeC4CuuGLduu+22ZH3p0qUd6uRML774YmXrzsOeHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCYJy9Czz44IPJ+tatW5P1ffv2FdlO15g+fXqyfuuttybrEyZMKLCbUzUawx8dHS1t3a1izw4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQZi7d25lZp1b2Vnatm1bsn7NNdd0qJMzPfbYY8n68uXLO9TJ2bv88stza/fdd1/ytX19fcn6RRdd1FJPzXjqqfQNlJcsSd89/cSJE0W2c1bc3eotZ88OBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0Ewzp658MILk/U33ngjtzZ79uyi2znF8ePHk/UdO3bk1p588smi2znFXXfdlazPnDkzt9Zom5dpaGgoWb/pppuS9ZGRkSLbKVTL4+xmNs3MtprZDjP70Mx+ki2fYmavmdme7PfkopsGUJxmDuNHJS13929L+htJPzazb0taIWmLu8+UtCV7DqBLNQy7uw+7+3vZ48
8l7ZR0maT5ktZnf7Ze0oKymgTQvrO6B52ZTZf0XUl/lNTj7sNZ6bCknpzX9Evqb71FAEVo+my8mU2StFHSMnf/09ia187y1T355u4D7j7H3ee01SmAtjQVdjOboFrQN7j7pmzxETPrzeq9krr39CSAxkNvZmaqfSY/6u7Lxiz/F0n/6+6PmNkKSVPcPXlP5G4eemtk4cKFubWNGzd2sBM0KzW8Np6H1hrJG3pr5jP730r6R0kfmNn72bKfSXpE0q/MbLGkjyX9oIhGAZSjYdjd/Q+S6v5PIen7xbYDoCxcLgsEQdiBIAg7EARhB4Ig7EAQfMW1SbXLDeq74447kq999tlni24nhI8++ihZX7VqVbK+adOm3NqXX37ZUk/jAbeSBoIj7EAQhB0IgrADQRB2IAjCDgRB2IEgGGcvQGoMXpImT07feHfZsmXJ+vz585P1WbNmJetleuaZZ5L1Tz75JLe2c+fO5GtfeOGFZH10dDRZj4pxdiA4wg4EQdiBIAg7EARhB4Ig7EAQhB0IgnF24BzDODsQHGEHgiDsQBCEHQiCsANBEHYgCMIOBNEw7GY2zcy2mtkOM/vQzH6SLV9pZofM7P3s55by2wXQqoYX1ZhZr6Red3/PzL4h6V1JC1Sbj/3P7v6vTa+Mi2qA0uVdVNPM/OzDkoazx5+b2U5JlxXbHoCyndVndjObLum7kv6YLbrXzLab2Vozq3vvJTPrN7NBMxtsq1MAbWn62ngzmyTpvyT9s7tvMrMeSZ9JckmrVDvU/1GD9+AwHihZ3mF8U2E3swmSfiPpt+7+b3Xq0yX9xt2/0+B9CDtQspa/CGO1W6c+JWnn2KBnJ+5OWihpqN0mAZSnmbPx10v6vaQPJJ3IFv9M0iJJs1U7jN8vaUl2Mi/1XuzZgZK1dRhfFMIOlI/vswPBEXYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4JoeMPJgn0m6eMxz6dmy7pRt/bWrX1J9NaqInu7Iq/Q0e+zn7Fys0F3n1NZAwnd2lu39iXRW6s61RuH8UAQhB0IouqwD1S8/pRu7a1b+5LorVUd6a3Sz+wAOqfqPTuADiHsQBCVhN3MbjazXWa218xWVNFDHjPbb2YfZNNQVzo/XTaH3oiZDY1ZNsXMXjOzPdnvunPsVdRbV0zjnZhmvNJtV/X05x3/zG5m50naLWmupIOS3pG0yN13dLSRHGa2X9Icd6/8Agwz+ztJf5b0zMmptczsUUlH3f2R7D/Kye7+0y7pbaXOchrvknrLm2b8LlW47Yqc/rwVVezZr5O01933ufsxSb+UNL+CPrqeu78p6ehpi+dLWp89Xq/aP5aOy+mtK7j7sLu/lz3+XNLJacYr3XaJvjqiirBfJunAmOcH1V3zvbuk35nZu2bWX3UzdfSMmWbrsKSeKpupo+E03p102jTjXbPtWpn+vF2coDvT9e7+PUn/IOnH2eFqV/LaZ7BuGjtdI2mGanMADktaXWUz2TTjGyUtc/c/ja1Vue3q9NWR7VZF2A9Jmjbm+TezZV3B3Q9lv0ckbVbtY0c3OXJyBt3s90jF/fyFux9x9+PufkLSL1ThtsumGd8oaYO7b8oWV77t6vXVqe1WRdjfkTTTzL5lZl+X9ENJL1fQxxnMbGJ24kRmNlHSPHXfVNQvS+rLHvdJeqnCXk7RLdN4500zroq3XeXTn7t7x38k3aLaGfn/kfRPVfSQ09dfS/rv7OfDqnuT9Jxqh3VfqXZuY7GkiyRtkbRH0uuSpnRRb8+qNrX3dtWC1VtRb9erdoi+XdL72c8tVW+7RF8d2W5cLgsEwQk6IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQji/wCywnjAVGwN4gAAAABJRU5ErkJggg==\n", 1308 | "text/plain": [ 1309 | "
" 1310 | ] 1311 | }, 1312 | "metadata": { 1313 | "needs_background": "light", 1314 | "tags": [] 1315 | }, 1316 | "output_type": "display_data" 1317 | } 1318 | ], 1319 | "source": [ 1320 | "index = 4\n", 1321 | "predict_for_index(x_train, ff_model, index)" 1322 | ] 1323 | }, 1324 | { 1325 | "cell_type": "code", 1326 | "execution_count": 32, 1327 | "metadata": { 1328 | "id": "bnfFFyHxotZC" 1329 | }, 1330 | "outputs": [], 1331 | "source": [ 1332 | "# A too simple NN taken from pytorch.org/tutorials\n", 1333 | "class Mnist_CNN(nn.Module):\n", 1334 | " def __init__(self):\n", 1335 | " super().__init__()\n", 1336 | " self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)\n", 1337 | " self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)\n", 1338 | " self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)\n", 1339 | "\n", 1340 | " def forward(self, xb):\n", 1341 | " xb = xb.view(-1, 1, 28, 28)\n", 1342 | " xb = F.relu(self.conv1(xb))\n", 1343 | " xb = F.relu(self.conv2(xb))\n", 1344 | " xb = F.relu(self.conv3(xb))\n", 1345 | " xb = F.avg_pool2d(xb, 4)\n", 1346 | " return xb.view(-1, xb.size(1))\n", 1347 | "\n", 1348 | "class LeNet5(nn.Module):\n", 1349 | " def __init__(self):\n", 1350 | " super().__init__()\n", 1351 | " self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1)\n", 1352 | " self.average1 = nn.AvgPool2d(2, stride=2)\n", 1353 | " self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)\n", 1354 | " self.average2 = nn.AvgPool2d(2, stride=2)\n", 1355 | " self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1)\n", 1356 | " \n", 1357 | " self.flatten = Flatten()\n", 1358 | " \n", 1359 | " self.fc1 = nn.Linear(120, 82)\n", 1360 | " self.fc2 = nn.Linear(82,10)\n", 1361 | "\n", 1362 | " def forward(self, xb):\n", 1363 | " xb = xb.view(-1, 1, 28, 28)\n", 1364 | " xb = torch.tanh(self.conv1(xb))\n", 1365 | " xb = self.average1(xb)\n", 1366 | " xb = torch.tanh(self.conv2(xb))\n", 1367 | " xb = self.average2(xb)\n", 1368 | " xb = 
torch.tanh(self.conv3(xb))\n", 1369 | " xb = xb.view(-1, xb.shape[1])\n", 1370 | " xb = F.relu(self.fc1(xb))\n", 1371 | " xb = self.fc2(xb) # raw class scores: a final ReLU can permanently zero a class\n", 1372 | " return xb" 1373 | ] 1374 | }, 1375 | { 1376 | "cell_type": "code", 1377 | "execution_count": 33, 1378 | "metadata": { 1379 | "colab": { 1380 | "base_uri": "https://localhost:8080/", 1381 | "height": 665 1382 | }, 1383 | "id": "cCj66rlTo1OK", 1384 | "outputId": "70c3c2f1-0155-4d34-9c0c-1b4444e95880" 1385 | }, 1386 | "outputs": [ 1387 | { 1388 | "name": "stderr", 1389 | "output_type": "stream", 1390 | "text": [ 1391 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:6: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n", 1392 | " \n", 1393 | "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1628: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n", 1394 | " warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n" 1395 | ] 1396 | }, 1397 | { 1398 | "name": "stdout", 1399 | "output_type": "stream", 1400 | "text": [ 1401 | "Epoch\tAccuracy\tLoss\n", 1402 | "0\t0.8295\t0.024257624548075996\n", 1403 | "1\t0.8807142857142857\t0.01573720137351951\n", 1404 | "2\t0.8872380952380953\t0.014220154276146034\n", 1405 | "3\t0.8907142857142857\t0.013434058933092147\n", 1406 | "4\t0.8926666666666667\t0.012940065951555786\n", 1407 | "5\t0.8939047619047619\t0.012577981816015037\n", 1408 | "6\t0.8947380952380952\t0.012296840460273569\n", 1409 | "7\t0.8956666666666667\t0.012080316865941456\n", 1410 | "8\t0.8961904761904762\t0.01189652137327911\n", 1411 | "9\t0.8965714285714286\t0.011736341546568878\n", 1412 | "10\t0.8971190476190476\t0.011606817156759685\n", 1413 | "11\t0.8975238095238095\t0.011493785651981876\n" 1414 | ] 1415 | }, 1416 | { 1417 | "data": { 1418 | "image/png": "\n", 1419 | "text/plain": [ 1420 | "
" 1421 | ] 1422 | }, 1423 | "metadata": { 1424 | "needs_background": "light", 1425 | "tags": [] 1426 | }, 1427 | "output_type": "display_data" 1428 | } 1429 | ], 1430 | "source": [ 1431 | "conv_model = LeNet5()\n", 1432 | "conv_model.apply(init_weights)\n", 1433 | "loss = nn.MSELoss()\n", 1434 | "optim = SGD(conv_model.parameters(), lr=0.1, momentum=0.9)\n", 1435 | "plot_accuracy_loss(*fit(train_dl, conv_model,loss,optim,epochs))" 1436 | ] 1437 | }, 1438 | { 1439 | "cell_type": "markdown", 1440 | "metadata": { 1441 | "id": "_R9-T2RQpfya" 1442 | }, 1443 | "source": [ 1444 | "## Working on test data" 1445 | ] 1446 | }, 1447 | { 1448 | "cell_type": "markdown", 1449 | "metadata": { 1450 | "id": "SjpG8AHgprei" 1451 | }, 1452 | "source": [ 1453 | "### Normalization" 1454 | ] 1455 | }, 1456 | { 1457 | "cell_type": "code", 1458 | "execution_count": 34, 1459 | "metadata": { 1460 | "colab": { 1461 | "base_uri": "https://localhost:8080/" 1462 | }, 1463 | "id": "mFl5z4rppu3P", 1464 | "outputId": "d360fe1e-9d51-4e4d-a03d-a6c9a815f107" 1465 | }, 1466 | "outputs": [ 1467 | { 1468 | "data": { 1469 | "text/plain": [ 1470 | "torch.Size([28000, 28, 28])" 1471 | ] 1472 | }, 1473 | "execution_count": 34, 1474 | "metadata": { 1475 | "tags": [] 1476 | }, 1477 | "output_type": "execute_result" 1478 | } 1479 | ], 1480 | "source": [ 1481 | "x_test = test.values\n", 1482 | "x_test = x_test.reshape([-1, 28, 28]).astype(float)\n", 1483 | "x_test = (x_test-mean)/std\n", 1484 | "x_test = torch.from_numpy(np.float32(x_test))\n", 1485 | "x_test.shape" 1486 | ] 1487 | }, 1488 | { 1489 | "cell_type": "markdown", 1490 | "metadata": { 1491 | "id": "kKd3tLlFp1Sx" 1492 | }, 1493 | "source": [ 1494 | "#### Prediction" 1495 | ] 1496 | }, 1497 | { 1498 | "cell_type": "code", 1499 | "execution_count": 35, 1500 | "metadata": { 1501 | "colab": { 1502 | "base_uri": "https://localhost:8080/", 1503 | "height": 394 1504 | }, 1505 | "id": "HKTVh8rDpxeW", 1506 | "outputId": "a13e0c85-2056-41d2-b1c6-a0bc2398e0d3" 
1507 | }, 1508 | "outputs": [ 1509 | { 1510 | "name": "stdout", 1511 | "output_type": "stream", 1512 | "text": [ 1513 | "torch.Size([1, 28, 28])\n", 1514 | "Prediction tensor([3])\n", 1515 | "torch.Size([1, 28, 28])\n", 1516 | "Prediction tensor([3])\n" 1517 | ] 1518 | }, 1519 | { 1520 | "name": "stderr", 1521 | "output_type": "stream", 1522 | "text": [ 1523 | "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1628: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n", 1524 | " warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n" 1525 | ] 1526 | }, 1527 | { 1528 | "data": { 1529 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOv0lEQVR4nO3db4xUZZbH8d9ZFjTCaHDRDjLdDk5MDK6BUQIk4IaNgKyR4ETFIWbFSOwRBx3iRBfdF6gvdLJxmGw0IWmiGVhnJaMzIsaJTEswijGjqK22SGNrUGxbegkJiBIBOfuiL6bFvk+1Vbf+wPl+kk5V3VO36lj68966T937mLsLwMnvH+rdAIDaIOxAEIQdCIKwA0EQdiCIf6zlm5kZh/6BKnN3G2x5RVt2M5trZl1m1m1myyt5LQDVZeWOs5vZMEk7JM2W9Kmk1yUtdPdtiXXYsgNVVo0t+xRJ3e7+kbsfkrRO0vwKXg9AFVUS9nGSdg14/Gm27DvMrNXMtprZ1greC0CFqn6Azt3bJLVJ7MYD9VTJlr1HUvOAxz/OlgFoQJWE/XVJ55vZeDMbIekXkjYU0xaAopW9G+/uR8xsqaSNkoZJeszd3yusMwCFKnvoraw34zs7UHVV+VENgBMHYQeCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiCIsudnlyQz2ynpC0nfSDri7pOLaApA8SoKe+Zf3X1PAa8DoIrYjQeCqDTsLulvZvaGmbUO9gQzazWzrWa2tcL3AlABc/fyVzYb5+49Zna2pHZJt7n7S4nnl/9mAIbE3W2w5RVt2d29J7vtk/S0pCmVvB6A6ik77GY20sx+dOy+pDmSOotqDECxKjka3yTpaTM79jr/6+7PF9JVMGeddVayfttttyXrM2bMyK3NnDmznJa+deTIkWT9ueeeS9a3b9+eW+vq6iqrp2PWr1+frB84cCC3Vuqf62RUdtjd/SNJEwvsBUAVMfQGBEHYgSAIOxAEYQeCIOxAEBX9gu4Hv9kJ/Au6c845J7d25ZVXJte95pprkvVZs2aV1dMxhw4dyq199tlnFb32sGHDkvXm5uaKXr+aOjo6cmtr165NrvvII48k6408dFeVX9ABOHEQdiAIwg4EQdiBIAg7EARhB4Ig7EAQRVxwMoTUqZwTJ1Z28t+zzz6brG/ZsiVZ37BhQ26t0tNIp02blqy/+OKLyfrtt9+eW3
vttdfKaelbU6dOTdYXLlyYW1u5cmVy3aampmT97rvvTtYbEVt2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiC89mH6Prrr8+tjRkzJrluqcstd3d3l9VTLcydOzdZL/XP/vjjjxfZzg8yatSo3FpnZ3qKg/379yfrl1xySbJ++PDhZL2aOJ8dCI6wA0EQdiAIwg4EQdiBIAg7EARhB4JgnB0N6+KLL07WU+erS9LNN9+cWzv99NOT61522WXJ+ubNm5P1eip7nN3MHjOzPjPrHLDsTDNrN7MPstvRRTYLoHhD2Y3/g6Tjf0a1XNImdz9f0qbsMYAGVjLs7v6SpL3HLZ4vaU12f42kqwruC0DByr0GXZO792b3P5eUe8EuM2uV1Frm+wAoSMUXnHR3Tx14c/c2SW0SB+iAeip36G23mY2VpOy2r7iWAFRDuWHfIGlRdn+RpGeKaQdAtZQcZzezJyTNlDRG0m5JKyStl/QnSS2SPpa0wN2PP4g32GuxG3+SOeWUU5L1O+64I7e2ePHi5LrnnXdesv7ll18m62+99VZubd68ecl19+3bl6w3srxx9pLf2d0975cL6V8dAGgo/FwWCIKwA0EQdiAIwg4EQdiBIJiyuQCnnnpqsl5qiGn48OFFtvMdvb29yfrYsWOT9ebm5mS91BBWS0tLbm3jxo3JdW+55ZZkvaOjI1nfs2dPsh4NW3YgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIJx9gLMnj07WU+d5ilJ48ePL7KdQu3atStZf/DBB5P11CWXu7q6yuoJ5WHLDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBMGVzDZx22mnJ+tlnn12jTr7vpptuStavvfbaZL3UOeNLly7Nrb399tvJdVGesqdsBnByIOxAEIQdCIKwA0EQdiAIwg4EQdiBIBhnR9KIESOS9SVLliTry5cvz629+uqryXWvu+66ZP3w4cPJelRlj7Ob2WNm1mdmnQOW3WtmPWbWkf1dUWSzAIo3lN34P0iaO8jy37v7pOzvr8W2BaBoJcPu7i9J2luDXgBUUSUH6Jaa2TvZbv7ovCeZWauZbTWzrRW8F4AKlRv2VZJ+KmmSpF5Jv8t7oru3uftkd59c5nsBKEBZYXf33e7+jbsflbRa0pRi2wJQtLLCbmYD5/n9uaTOvOcCaAwlx9nN7AlJMyWNkbRb0ors8SRJLmmnpF+6e3oicDX2OPvEiROT9dT10/fu5fhlngsuuCC31t7enly3r68vWV+wYEGy/uGHHybrJ6u8cfaSk0S4+8JBFj9acUcAaoqfywJBEHYgCMIOBEHYgSAIOxBEmFNcS12uubMz/VOBmTNn5ta2bdtWTkvhTZs2LVlfvXp1sn7GGWck67Nmzcqt7dixI7nuiYxLSQPBEXYgCMIOBEHYgSAIOxAEYQeCIOxAEGHG2W+88cZk/dJLL03WFy9eXGA3GIqWlpZkfePGjcl6T09Pbm3evHnJdQ8ePJisNzLG2YHgCDsQBGEHgiDsQBCEHQiCsANBEHYgiJJXl41i37599W4Bx/nkk0+S9RUrViTr69aty61Nnz49ue4LL7yQrJ+I2LIDQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBBhxtl7e9MzSt96663Jeuoa5YzR18f69euT9e3bt+fWrr766uS6IcfZzazZzDab2TYze8/Mfp0tP9PM2s3sg+x2dPXbBVCuoezGH5H0G3efIGmapF+Z2QRJyyVtcvfzJW3KHgNoUCXD7u697v5mdv8LSe9LGidpvqQ12dPWSLqqWk0CqNwP+s5uZj+R9DNJf5fU5O7Hvgh/LqkpZ51WSa3ltwigCEM+Gm9moyT9WdIyd98/sOb9V60c9GKS7t7m7pPdfXJFnQKoyJDCbmbD1R/0P7r7X7LFu81sbFYfK6mvOi0CKELJ3XgzM0mPSnrf3VcOKG2QtEjSb7PbZ6rSYUFefvnlZL25uTlZv/zyy3NrTz31VHLdo0ePJu
soz6FDh5L13bt359ZKTRd9MhrKd/bpkv5d0rtm1pEtu0f9If+TmS2W9LGkBdVpEUARSobd3bdIGvSi85IuK7YdANXCz2WBIAg7EARhB4Ig7EAQhB0IIswprl999VWyftdddyXra9euza1deOGFyXUfeOCBZP3rr79O1jG4O++8M1mfOHFibu3+++8vup2Gx5YdCIKwA0EQdiAIwg4EQdiBIAg7EARhB4Kw/ovM1OjNzGr3ZgW74YYbcmttbW3Jdbu6upL15cvT1+osdS7+gQMHkvVGNWHChGR9yZIlFdUfeuih3Np9992XXPfgwYPJeiNz90HPUmXLDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBMM5egEmTJiXry5YtS9anTp2arKemi5ak559/Prf25JNPJtctNZ7c0tKSrE+fPj1ZnzNnTm5t3LhxyXW7u7uT9YcffjhZX7VqVbJ+smKcHQiOsANBEHYgCMIOBEHYgSAIOxAEYQeCKDnObmbNktZKapLkktrc/b/N7F5JN0v6v+yp97j7X0u81kk5zl6pkSNHJuulrmk/Y8aM3NpFF12UXLfU9fTPPffcZL3UufZbtmzJrb3yyivJddvb25P1UvOzR5U3zj6USSKOSPqNu79pZj+S9IaZHfu38Ht3z79CAICGMZT52Xsl9Wb3vzCz9yWlf/oEoOH8oO/sZvYTST+T9Pds0VIze8fMHjOz0TnrtJrZVjPbWlGnACoy5LCb2ShJf5a0zN33S1ol6aeSJql/y/+7wdZz9zZ3n+zukwvoF0CZhhR2Mxuu/qD/0d3/Iknuvtvdv3H3o5JWS5pSvTYBVKpk2M3MJD0q6X13Xzlg+dgBT/u5pM7i2wNQlKEMvc2Q9LKkdyUdzRbfI2mh+nfhXdJOSb/MDualXouhN6DK8obeOJ8dOMlwPjsQHGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiCIoVxdtkh7JH084PGYbFkjatTeGrUvid7KVWRvudf+run57N97c7OtjXptukbtrVH7kuitXLXqjd14IAjCDgRR77C31fn9Uxq1t0btS6K3ctWkt7p+ZwdQO/XesgOoEcIOBFGXsJvZXDPrMrNuM1tejx7ymNlOM3vXzDrqPT9dNoden5l1Dlh2ppm1m9kH2e2gc+zVqbd7zawn++w6zOyKOvXWbGabzWybmb1nZr/Oltf1s0v0VZPPrebf2c1smKQdkmZL+lTS65IWuvu2mjaSw8x2Sprs7nX/AYaZ/YukA5LWuvs/Z8v+S9Jed/9t9j/K0e7+Hw3S272SDtR7Gu9stqKxA6cZl3SVpBtVx88u0dcC1eBzq8eWfYqkbnf/yN0PSVonaX4d+mh47v6SpL3HLZ4vaU12f436/2OpuZzeGoK797r7m9n9LyQdm2a8rp9doq+aqEfYx0naNeDxp2qs+d5d0t/M7A0za613M4NoGjDN1ueSmurZzCBKTuNdS8dNM94wn105059XigN03zfD3S+W9G+SfpXtrjYk7/8O1khjp0OaxrtWBplm/Fv1/OzKnf68UvUIe4+k5gGPf5wtawju3pPd9kl6Wo03FfXuYzPoZrd9de7nW400jfdg04yrAT67ek5/Xo+wvy7pfDMbb2YjJP1C0oY69PE9ZjYyO3AiMxspaY4abyrqDZIWZfcXSXqmjr18R6NM4503zbjq/NnVffpzd6/5n6Qr1H9E/kNJ/1mPHnL6Ok/S29nfe/XuTdIT6t+tO6z+YxuLJf2TpE2SPpD0gqQzG6i3/1H/1N7vqD9YY+vU2wz176K/I6kj+7ui3p9doq+afG78XBYIggN0QBCEHQiCsANBEHYgCMIOBEHYgSAIOxDE/wPJ4cK3w7TDGQAAAABJRU5ErkJggg==\n", 1530 | "text/plain": [ 1531 | "
" 1532 | ] 1533 | }, 1534 | "metadata": { 1535 | "needs_background": "light", 1536 | "tags": [] 1537 | }, 1538 | "output_type": "display_data" 1539 | } 1540 | ], 1541 | "source": [ 1542 | "index = 7\n", 1543 | "predict_for_index(x_test, ff_model, index)\n", 1544 | "predict_for_index(x_test, conv_model, index)" 1545 | ] 1546 | }, 1547 | { 1548 | "cell_type": "code", 1549 | "execution_count": 36, 1550 | "metadata": { 1551 | "id": "q1BxSZ5ip87W" 1552 | }, 1553 | "outputs": [], 1554 | "source": [ 1555 | "# Export predictions to CSV in the submission format (1-based ImageId, Label)\n", 1556 | "def export_csv(model_name, predictions, commit_no):\n", 1557 | " df = pd.DataFrame(predictions.tolist(), columns=['Label']) # use the argument, not a global\n", 1558 | " df['ImageId'] = df.index + 1\n", 1559 | " file_name = f'submission_{model_name}_v{commit_no}.csv'\n", 1560 | " print('Saving ',file_name)\n", 1561 | " df[['ImageId','Label']].to_csv(file_name, index = False)" 1562 | ] 1563 | }, 1564 | { 1565 | "cell_type": "code", 1566 | "execution_count": 37, 1567 | "metadata": { 1568 | "colab": { 1569 | "base_uri": "https://localhost:8080/", 1570 | "height": 255 1571 | }, 1572 | "id": "IUz9SXcxqDIK", 1573 | "outputId": "b8368360-b45e-465f-c46e-e86cbd7989d9" 1574 | }, 1575 | "outputs": [ 1576 | { 1577 | "data": { 1578 | "text/html": [ 1579 | "
\n", 1580 | "\n", 1593 | "\n", 1594 | " \n", 1595 | " \n", 1596 | " \n", 1597 | " \n", 1598 | " \n", 1599 | " \n", 1600 | " \n", 1601 | " \n", 1602 | " \n", 1603 | " \n", 1604 | " \n", 1605 | " \n", 1606 | " \n", 1607 | " \n", 1608 | " \n", 1609 | " \n", 1610 | " \n", 1611 | " \n", 1612 | " \n", 1613 | " \n", 1614 | " \n", 1615 | " \n", 1616 | " \n", 1617 | " \n", 1618 | " \n", 1619 | " \n", 1620 | " \n", 1621 | " \n", 1622 | " \n", 1623 | " \n", 1624 | " \n", 1625 | " \n", 1626 | " \n", 1627 | " \n", 1628 | " \n", 1629 | " \n", 1630 | " \n", 1631 | " \n", 1632 | " \n", 1633 | " \n", 1634 | " \n", 1635 | " \n", 1636 | " \n", 1637 | " \n", 1638 | " \n", 1639 | " \n", 1640 | " \n", 1641 | " \n", 1642 | " \n", 1643 | " \n", 1644 | " \n", 1645 | " \n", 1646 | " \n", 1647 | " \n", 1648 | " \n", 1649 | " \n", 1650 | " \n", 1651 | " \n", 1652 | " \n", 1653 | " \n", 1654 | " \n", 1655 | " \n", 1656 | " \n", 1657 | " \n", 1658 | " \n", 1659 | " \n", 1660 | " \n", 1661 | " \n", 1662 | " \n", 1663 | " \n", 1664 | " \n", 1665 | " \n", 1666 | " \n", 1667 | " \n", 1668 | " \n", 1669 | " \n", 1670 | " \n", 1671 | " \n", 1672 | " \n", 1673 | " \n", 1674 | " \n", 1675 | " \n", 1676 | " \n", 1677 | " \n", 1678 | " \n", 1679 | " \n", 1680 | " \n", 1681 | " \n", 1682 | " \n", 1683 | " \n", 1684 | " \n", 1685 | " \n", 1686 | " \n", 1687 | " \n", 1688 | " \n", 1689 | " \n", 1690 | " \n", 1691 | " \n", 1692 | " \n", 1693 | " \n", 1694 | " \n", 1695 | " \n", 1696 | " \n", 1697 | " \n", 1698 | " \n", 1699 | " \n", 1700 | " \n", 1701 | " \n", 1702 | " \n", 1703 | " \n", 1704 | " \n", 1705 | " \n", 1706 | " \n", 1707 | " \n", 1708 | " \n", 1709 | " \n", 1710 | " \n", 1711 | " \n", 1712 | " \n", 1713 | " \n", 1714 | " \n", 1715 | " \n", 1716 | " \n", 1717 | " \n", 1718 | " \n", 1719 | " \n", 1720 | " \n", 1721 | " \n", 1722 | " \n", 1723 | " \n", 1724 | " \n", 1725 | " \n", 1726 | " \n", 1727 | " \n", 1728 | " \n", 1729 | " \n", 1730 | " \n", 1731 | " \n", 1732 | " \n", 1733 | " \n", 1734 | " 
\n", 1735 | " \n", 1736 | " \n", 1737 | " \n", 1738 | " \n", 1739 | " \n", 1740 | " \n", 1741 | " \n", 1742 | " \n", 1743 | " \n", 1744 | " \n", 1745 | " \n", 1746 | " \n", 1747 | " \n", 1748 | " \n", 1749 | " \n", 1750 | " \n", 1751 | " \n", 1752 | " \n", 1753 | " \n", 1754 | " \n", 1755 | " \n", 1756 | " \n", 1757 | " \n", 1758 | " \n", 1759 | " \n", 1760 | " \n", 1761 | " \n", 1762 | " \n", 1763 | " \n", 1764 | " \n", 1765 | " \n", 1766 | " \n", 1767 | " \n", 1768 | " \n", 1769 | " \n", 1770 | " \n", 1771 | " \n", 1772 | " \n", 1773 | " \n", 1774 | " \n", 1775 | " \n", 1776 | " \n", 1777 | " \n", 1778 | " \n", 1779 | " \n", 1780 | " \n", 1781 | " \n", 1782 | " \n", 1783 | " \n", 1784 | " \n", 1785 | " \n", 1786 | " \n", 1787 | " \n", 1788 | " \n", 1789 | " \n", 1790 | " \n", 1791 | " \n", 1792 | " \n", 1793 | " \n", 1794 | " \n", 1795 | " \n", 1796 | " \n", 1797 | " \n", 1798 | " \n", 1799 | " \n", 1800 | " \n", 1801 | " \n", 1802 | " \n", 1803 | " \n", 1804 | " \n", 1805 | " \n", 1806 | " \n", 1807 | " \n", 1808 | " \n", 1809 | " \n", 1810 | " \n", 1811 | " \n", 1812 | " \n", 1813 | " \n", 1814 | " \n", 1815 | " \n", 1816 | " \n", 1817 | " \n", 1818 | " \n", 1819 | " \n", 1820 | " \n", 1821 | " \n", 1822 | " \n", 1823 | " \n", 1824 | " \n", 1825 | " \n", 1826 | " \n", 1827 | " \n", 1828 | " \n", 1829 | " \n", 1830 | " \n", 1831 | " \n", 1832 | " \n", 1833 | " \n", 1834 | " \n", 1835 | " \n", 1836 | " \n", 1837 | " \n", 1838 | " \n", 1839 | " \n", 1840 | " \n", 1841 | " \n", 1842 | " \n", 1843 | " \n", 1844 | " \n", 1845 | " \n", 1846 | " \n", 1847 | " \n", 1848 | " \n", 1849 | " \n", 1850 | " \n", 1851 | " \n", 1852 | " \n", 1853 | " \n", 1854 | " \n", 1855 | " \n", 1856 | " \n", 1857 | " \n", 1858 | " \n", 1859 | " \n", 1860 | " \n", 1861 | " \n", 1862 | " \n", 1863 | " \n", 1864 | " \n", 1865 | " \n", 1866 | " \n", 1867 | " \n", 1868 | " \n", 1869 | " \n", 1870 | " \n", 1871 | " \n", 1872 | " \n", 1873 | " \n", 1874 | " \n", 1875 | " \n", 1876 | " \n", 1877 | 
" \n", 1878 | " \n", 1879 | " \n", 1880 | " \n", 1881 | " \n", 1882 | " \n", 1883 | " \n", 1884 | " \n", 1885 | " \n", 1886 | " \n", 1887 | " \n", 1888 | " \n", 1889 | " \n", 1890 | " \n", 1891 | " \n", 1892 | " \n", 1893 | " \n", 1894 | " \n", 1895 | " \n", 1896 | " \n", 1897 | " \n", 1898 | " \n", 1899 | " \n", 1900 | " \n", 1901 | " \n", 1902 | " \n", 1903 | " \n", 1904 | " \n", 1905 | " \n", 1906 | " \n", 1907 | " \n", 1908 | " \n", 1909 | " \n", 1910 | " \n", 1911 | " \n", 1912 | " \n", 1913 | " \n", 1914 | " \n", 1915 | " \n", 1916 | " \n", 1917 | " \n", 1918 | " \n", 1919 | " \n", 1920 | " \n", 1921 | " \n", 1922 | " \n", 1923 | " \n", 1924 | " \n", 1925 | " \n", 1926 | " \n", 1927 | " \n", 1928 | " \n", 1929 | " \n", 1930 | " \n", 1931 | " \n", 1932 | " \n", 1933 | " \n", 1934 | " \n", 1935 | " \n", 1936 | " \n", 1937 | " \n", 1938 | " \n", 1939 | " \n", 1940 | " \n", 1941 | " \n", 1942 | " \n", 1943 | " \n", 1944 | " \n", 1945 | " \n", 1946 | " \n", 1947 | " \n", 1948 | " \n", 1949 | " \n", 1950 | " \n", 1951 | " \n", 1952 | " \n", 1953 | " \n", 1954 | " \n", 1955 | " \n", 1956 | " \n", 1957 | " \n", 1958 | " \n", 1959 | " \n", 1960 | " \n", 1961 | " \n", 1962 | " \n", 1963 | " \n", 1964 | " \n", 1965 | " \n", 1966 | " \n", 1967 | " \n", 1968 | " \n", 1969 | " \n", 1970 | " \n", 1971 | " \n", 1972 | " \n", 1973 | " \n", 1974 | " \n", 1975 | " \n", 1976 | " \n", 1977 | " \n", 1978 | " \n", 1979 | " \n", 1980 | " \n", 1981 | " \n", 1982 | " \n", 1983 | " \n", 1984 | " \n", 1985 | " \n", 1986 | " \n", 1987 | " \n", 1988 | " \n", 1989 | " \n", 1990 | " \n", 1991 | " \n", 1992 | " \n", 1993 | " \n", 1994 | " \n", 1995 | " \n", 1996 | " \n", 1997 | " \n", 1998 | " \n", 1999 | " \n", 2000 | " \n", 2001 | " \n", 2002 | " \n", 2003 | " \n", 2004 | " \n", 2005 | " \n", 2006 | " \n", 2007 | " \n", 2008 | " \n", 2009 | " \n", 2010 | " \n", 2011 | " \n", 2012 | " \n", 2013 | " \n", 2014 | " \n", 2015 | " \n", 2016 | " \n", 2017 | " \n", 2018 | " \n", 2019 | " \n", 2020 
| " \n", 2021 | " \n", 2022 | " \n", 2023 | " \n", 2024 | " \n", 2025 | " \n", 2026 | " \n", 2027 | " \n", 2028 | " \n", 2029 | " \n", 2030 | " \n", 2031 | " \n", 2032 | " \n", 2033 | " \n", 2034 | " \n", 2035 | " \n", 2036 | " \n", 2037 | " \n", 2038 | " \n", 2039 | " \n", 2040 | " \n", 2041 | " \n", 2042 | " \n", 2043 | " \n", 2044 | " \n", 2045 | " \n", 2046 | " \n", 2047 | " \n", 2048 | " \n", 2049 | " \n", 2050 | " \n", 2051 | " \n", 2052 | " \n", 2053 | " \n", 2054 | " \n", 2055 | " \n", 2056 | " \n", 2057 | " \n", 2058 | " \n", 2059 | " \n", 2060 | " \n", 2061 | " \n", 2062 | " \n", 2063 | " \n", 2064 | " \n", 2065 | " \n", 2066 | " \n", 2067 | " \n", 2068 | " \n", 2069 | " \n", 2070 | " \n", 2071 | " \n", 2072 | " \n", 2073 | " \n", 2074 | " \n", 2075 | " \n", 2076 | " \n", 2077 | " \n", 2078 | " \n", 2079 | " \n", 2080 | " \n", 2081 | " \n", 2082 | " \n", 2083 | " \n", 2084 | " \n", 2085 | " \n", 2086 | " \n", 2087 | " \n", 2088 | " \n", 2089 | " \n", 2090 | " \n", 2091 | " \n", 2092 | " \n", 2093 | " \n", 2094 | " \n", 2095 | " \n", 2096 | " \n", 2097 | " \n", 2098 | " \n", 2099 | " \n", 2100 | " \n", 2101 | " \n", 2102 | "
pixel0pixel1pixel2pixel3pixel4pixel5pixel6pixel7pixel8pixel9pixel10pixel11pixel12pixel13pixel14pixel15pixel16pixel17pixel18pixel19pixel20pixel21pixel22pixel23pixel24pixel25pixel26pixel27pixel28pixel29pixel30pixel31pixel32pixel33pixel34pixel35pixel36pixel37pixel38pixel39...pixel744pixel745pixel746pixel747pixel748pixel749pixel750pixel751pixel752pixel753pixel754pixel755pixel756pixel757pixel758pixel759pixel760pixel761pixel762pixel763pixel764pixel765pixel766pixel767pixel768pixel769pixel770pixel771pixel772pixel773pixel774pixel775pixel776pixel777pixel778pixel779pixel780pixel781pixel782pixel783
00000000000000000000000000000000000000000...0000000000000000000000000000000000000000
10000000000000000000000000000000000000000...0000000000000000000000000000000000000000
20000000000000000000000000000000000000000...0000000000000000000000000000000000000000
30000000000000000000000000000000000000000...0000000000000000000000000000000000000000
40000000000000000000000000000000000000000...0000000000000000000000000000000000000000
\n", 2103 | "

5 rows × 784 columns

\n", 2104 | "
" 2105 | ], 2106 | "text/plain": [ 2107 | " pixel0 pixel1 pixel2 pixel3 ... pixel780 pixel781 pixel782 pixel783\n", 2108 | "0 0 0 0 0 ... 0 0 0 0\n", 2109 | "1 0 0 0 0 ... 0 0 0 0\n", 2110 | "2 0 0 0 0 ... 0 0 0 0\n", 2111 | "3 0 0 0 0 ... 0 0 0 0\n", 2112 | "4 0 0 0 0 ... 0 0 0 0\n", 2113 | "\n", 2114 | "[5 rows x 784 columns]" 2115 | ] 2116 | }, 2117 | "execution_count": 37, 2118 | "metadata": { 2119 | "tags": [] 2120 | }, 2121 | "output_type": "execute_result" 2122 | } 2123 | ], 2124 | "source": [ 2125 | "test.head()" 2126 | ] 2127 | }, 2128 | { 2129 | "cell_type": "code", 2130 | "execution_count": 38, 2131 | "metadata": { 2132 | "id": "eKPTP4pBqFyt" 2133 | }, 2134 | "outputs": [], 2135 | "source": [ 2136 | "# just to make output easier to read\n", 2137 | "commit_no = 17" 2138 | ] 2139 | }, 2140 | { 2141 | "cell_type": "code", 2142 | "execution_count": 39, 2143 | "metadata": { 2144 | "colab": { 2145 | "base_uri": "https://localhost:8080/" 2146 | }, 2147 | "id": "3QWmDqOiqKC5", 2148 | "outputId": "28d57d95-31be-4dcb-fc6d-f4105fc8f389" 2149 | }, 2150 | "outputs": [ 2151 | { 2152 | "name": "stdout", 2153 | "output_type": "stream", 2154 | "text": [ 2155 | "Prediction tensor([2, 0, 9, ..., 3, 9, 2])\n", 2156 | "Saving submission_ff_model_v17.csv\n" 2157 | ] 2158 | } 2159 | ], 2160 | "source": [ 2161 | "ff_test_yhat = ff_model(x_test.float())\n", 2162 | "prediction = torch.argmax(ff_test_yhat,1)\n", 2163 | "print('Prediction',prediction)\n", 2164 | "export_csv('ff_model',prediction, commit_no=commit_no)" 2165 | ] 2166 | }, 2167 | { 2168 | "cell_type": "code", 2169 | "execution_count": 40, 2170 | "metadata": { 2171 | "colab": { 2172 | "base_uri": "https://localhost:8080/" 2173 | }, 2174 | "id": "LXdZ2a-FqNRR", 2175 | "outputId": "d83480c3-a897-403c-c262-29c46c6a0193" 2176 | }, 2177 | "outputs": [ 2178 | { 2179 | "name": "stderr", 2180 | "output_type": "stream", 2181 | "text": [ 2182 | "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1628: UserWarning: 
nn.functional.tanh is deprecated. Use torch.tanh instead.\n", 2183 | " warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n" 2184 | ] 2185 | }, 2186 | { 2187 | "name": "stdout", 2188 | "output_type": "stream", 2189 | "text": [ 2190 | "Saving submission_lenet_model_v17.csv\n" 2191 | ] 2192 | } 2193 | ], 2194 | "source": [ 2195 | "cn_train_yhat = conv_model(x_test)\n", 2196 | "prediction = torch.argmax(cn_train_yhat,1)\n", 2197 | "yo = torch.argmax(y_train,1)\n", 2198 | "export_csv('lenet_model',prediction, commit_no=commit_no)" 2199 | ] 2200 | }, 2201 | { 2202 | "cell_type": "markdown", 2203 | "metadata": { 2204 | "id": "z2Ppk7N1qWzq" 2205 | }, 2206 | "source": [ 2207 | "### Ensembling" 2208 | ] 2209 | }, 2210 | { 2211 | "cell_type": "code", 2212 | "execution_count": 41, 2213 | "metadata": { 2214 | "colab": { 2215 | "base_uri": "https://localhost:8080/" 2216 | }, 2217 | "id": "7B7zAMxiqQn-", 2218 | "outputId": "2afadec2-d152-4757-ab86-505c6ce73783" 2219 | }, 2220 | "outputs": [ 2221 | { 2222 | "name": "stdout", 2223 | "output_type": "stream", 2224 | "text": [ 2225 | "Ensemble No 1\n", 2226 | "Epoch\tAccuracy\tLoss\n" 2227 | ] 2228 | }, 2229 | { 2230 | "name": "stderr", 2231 | "output_type": "stream", 2232 | "text": [ 2233 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:6: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n", 2234 | " \n", 2235 | "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1628: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n", 2236 | " warnings.warn(\"nn.functional.tanh is deprecated. 
Use torch.tanh instead.\")\n" 2237 | ] 2238 | }, 2239 | { 2240 | "name": "stdout", 2241 | "output_type": "stream", 2242 | "text": [ 2243 | "0\t0.9111428571428571\t0.016511163488243063\n", 2244 | "1\t0.9700238095238095\t0.007111747115510239\n", 2245 | "2\t0.9797380952380952\t0.005202083471174461\n", 2246 | "3\t0.9842619047619048\t0.004263598822100482\n", 2247 | "4\t0.9865238095238095\t0.003668796987828474\n", 2248 | "5\t0.9885714285714285\t0.0032527321415189285\n", 2249 | "6\t0.9895\t0.002938218746886091\n", 2250 | "7\t0.9904285714285714\t0.002678092770890476\n", 2251 | "8\t0.9911904761904762\t0.002474854236732335\n", 2252 | "9\t0.9917142857142857\t0.0023038261132371503\n", 2253 | "10\t0.9921904761904762\t0.002157256974708615\n", 2254 | "11\t0.9927142857142857\t0.0020299362510590156\n", 2255 | "Ensemble No 2\n", 2256 | "Epoch\tAccuracy\tLoss\n", 2257 | "0\t0.9172619047619047\t0.01556841353070745\n", 2258 | "1\t0.9723095238095238\t0.006644158197788003\n", 2259 | "2\t0.9799285714285715\t0.004954955967010638\n", 2260 | "3\t0.9838333333333333\t0.004082293982533922\n", 2261 | "4\t0.9864285714285714\t0.003539374203345337\n", 2262 | "5\t0.9879523809523809\t0.00315711912165695\n", 2263 | "6\t0.9893095238095239\t0.0028619879904193193\n", 2264 | "7\t0.9900238095238095\t0.002624628800757745\n", 2265 | "8\t0.9909523809523809\t0.002428901577410905\n", 2266 | "9\t0.9915714285714285\t0.002260973717438984\n", 2267 | "10\t0.9922380952380953\t0.0021149535993712316\n", 2268 | "11\t0.9927619047619047\t0.001978734256543752\n", 2269 | "Ensemble No 3\n", 2270 | "Epoch\tAccuracy\tLoss\n", 2271 | "0\t0.8974285714285715\t0.01978621597201925\n", 2272 | "1\t0.9721904761904762\t0.00684904772770567\n", 2273 | "2\t0.9796428571428571\t0.005134945574799153\n", 2274 | "3\t0.9835238095238096\t0.004289438445163576\n", 2275 | "4\t0.985547619047619\t0.0037338908047123303\n", 2276 | "5\t0.9873333333333333\t0.0033411253987006345\n", 2277 | "6\t0.9886428571428572\t0.0030370050722476564\n", 2278 | 
"7\t0.9898809523809524\t0.002805216822105558\n", 2279 | "8\t0.9906666666666667\t0.002602088730605078\n", 2280 | "9\t0.991452380952381\t0.002425416489220345\n", 2281 | "10\t0.9918809523809524\t0.0022839747143014867\n", 2282 | "11\t0.9921904761904762\t0.002156054163053972\n", 2283 | "Ensemble No 4\n", 2284 | "Epoch\tAccuracy\tLoss\n", 2285 | "0\t0.8132142857142857\t0.02603206173282916\n", 2286 | "1\t0.9691904761904762\t0.007353543678845762\n", 2287 | "2\t0.979047619047619\t0.005255459206776986\n", 2288 | "3\t0.9833333333333333\t0.004305251172879821\n", 2289 | "4\t0.9859285714285714\t0.0037251151663171837\n", 2290 | "5\t0.9877857142857143\t0.003324402994306154\n", 2291 | "6\t0.9890714285714286\t0.003014813542746637\n", 2292 | "7\t0.9898809523809524\t0.0027784892140648113\n", 2293 | "8\t0.990547619047619\t0.0025671080004759103\n", 2294 | "9\t0.9913095238095239\t0.002393849937592901\n", 2295 | "10\t0.9918095238095238\t0.002241629978257006\n", 2296 | "11\t0.9923809523809524\t0.002111142203817782\n", 2297 | "Ensemble No 5\n", 2298 | "Epoch\tAccuracy\tLoss\n", 2299 | "0\t0.8883809523809524\t0.018638266524430392\n", 2300 | "1\t0.9749285714285715\t0.006399773116913855\n", 2301 | "2\t0.9809761904761904\t0.0047689435957141415\n", 2302 | "3\t0.9848333333333333\t0.0039175816491257565\n", 2303 | "4\t0.9873809523809524\t0.0033779697038397714\n", 2304 | "5\t0.9887380952380952\t0.0029985817179226998\n", 2305 | "6\t0.9898809523809524\t0.0027158053929310097\n", 2306 | "7\t0.9910238095238095\t0.0024959628998806565\n", 2307 | "8\t0.991904761904762\t0.002313806876354841\n", 2308 | "9\t0.9925\t0.0021705634898985943\n", 2309 | "10\t0.9929285714285714\t0.002038880472395376\n", 2310 | "11\t0.9934761904761905\t0.0019274493881403447\n", 2311 | "Ensemble No 6\n", 2312 | "Epoch\tAccuracy\tLoss\n", 2313 | "0\t0.9310952380952381\t0.014221827570799756\n", 2314 | "1\t0.9762142857142857\t0.005909903033791022\n", 2315 | "2\t0.9814285714285714\t0.004611867855361817\n", 2316 | 
"3\t0.9846904761904762\t0.0039124973296383436\n", 2317 | "4\t0.9864761904761905\t0.0034358133105639103\n", 2318 | "5\t0.9880952380952381\t0.0030757175465936587\n", 2319 | "6\t0.9893571428571428\t0.0027871272213759873\n", 2320 | "7\t0.9902380952380953\t0.002560089231905339\n", 2321 | "8\t0.990904761904762\t0.0023720302001347757\n", 2322 | "9\t0.991547619047619\t0.00220930353609283\n", 2323 | "10\t0.9922619047619048\t0.0020755580866745143\n", 2324 | "11\t0.9926428571428572\t0.001957762911181317\n", 2325 | "Ensemble No 7\n", 2326 | "Epoch\tAccuracy\tLoss\n", 2327 | "0\t0.8910476190476191\t0.01866057389727659\n", 2328 | "1\t0.9674761904761905\t0.007664416734869145\n", 2329 | "2\t0.9764761904761905\t0.005824241854351076\n", 2330 | "3\t0.980904761904762\t0.0048237866475750076\n", 2331 | "4\t0.9840952380952381\t0.00416487974497097\n", 2332 | "5\t0.9862857142857143\t0.0036850035536284396\n", 2333 | "6\t0.988\t0.0033138244230508564\n", 2334 | "7\t0.989452380952381\t0.0030112102648751494\n", 2335 | "8\t0.9904047619047619\t0.002767998953488098\n", 2336 | "9\t0.9909761904761905\t0.002565392035929999\n", 2337 | "10\t0.9916428571428572\t0.0023878570784257565\n", 2338 | "11\t0.9923333333333333\t0.00224290590781486\n", 2339 | "Ensemble No 8\n", 2340 | "Epoch\tAccuracy\tLoss\n", 2341 | "0\t0.7372619047619048\t0.03339018757473479\n", 2342 | "1\t0.7841190476190476\t0.025281104073495445\n", 2343 | "2\t0.8697857142857143\t0.01613007787660422\n", 2344 | "3\t0.8743809523809524\t0.015026655783125481\n", 2345 | "4\t0.8762857142857143\t0.014440044820949412\n", 2346 | "5\t0.8781904761904762\t0.014032664391264615\n", 2347 | "6\t0.8794047619047619\t0.013743549205796563\n", 2348 | "7\t0.8799761904761905\t0.013518933168168731\n", 2349 | "8\t0.8807619047619047\t0.013339132827152378\n", 2350 | "9\t0.8812619047619048\t0.013192019805232774\n", 2351 | "10\t0.8816904761904761\t0.013058571646392595\n", 2352 | "11\t0.8820238095238095\t0.012946484402414622\n", 2353 | "Ensemble No 9\n", 2354 | 
"Epoch\tAccuracy\tLoss\n", 2355 | "0\t0.9156904761904762\t0.015428369204957418\n", 2356 | "1\t0.9750238095238095\t0.006170348094881171\n", 2357 | "2\t0.9814761904761905\t0.004768365250419842\n", 2358 | "3\t0.9846190476190476\t0.004006376401312179\n", 2359 | "4\t0.9870238095238095\t0.003503682073404165\n", 2360 | "5\t0.9882619047619048\t0.0031340198478480032\n", 2361 | "6\t0.989547619047619\t0.0028492020501732496\n", 2362 | "7\t0.9907142857142858\t0.0026232664765456276\n", 2363 | "8\t0.9912142857142857\t0.002431873612380849\n", 2364 | "9\t0.9919285714285714\t0.0022671252587520272\n", 2365 | "10\t0.9925238095238095\t0.002121901073435653\n", 2366 | "11\t0.9930238095238095\t0.0019987044014111786\n", 2367 | "Ensemble No 10\n", 2368 | "Epoch\tAccuracy\tLoss\n", 2369 | "0\t0.9145238095238095\t0.0159442980924006\n", 2370 | "1\t0.972547619047619\t0.006671273740131799\n", 2371 | "2\t0.9807619047619047\t0.004963866395654676\n", 2372 | "3\t0.9848571428571429\t0.004075142005554158\n", 2373 | "4\t0.9876428571428572\t0.003527888545342629\n", 2374 | "5\t0.9894285714285714\t0.003130926226199813\n", 2375 | "6\t0.9902380952380953\t0.0028284889235851585\n", 2376 | "7\t0.9910952380952381\t0.0025910547014135906\n", 2377 | "8\t0.9917142857142857\t0.002396904655614913\n", 2378 | "9\t0.9923095238095239\t0.002235876975185557\n", 2379 | "10\t0.9928809523809524\t0.0021000990945855766\n", 2380 | "11\t0.9932142857142857\t0.0019793774044180863\n", 2381 | "Ensemble No 11\n", 2382 | "Epoch\tAccuracy\tLoss\n", 2383 | "0\t0.9174761904761904\t0.015469770621270213\n", 2384 | "1\t0.9757380952380953\t0.006070246864528405\n", 2385 | "2\t0.9810714285714286\t0.004717265213223081\n", 2386 | "3\t0.9839761904761904\t0.004000431453473186\n", 2387 | "4\t0.9856904761904762\t0.0035248718145969673\n", 2388 | "5\t0.9871904761904762\t0.0031690288804998817\n", 2389 | "6\t0.9889523809523809\t0.002886462336228211\n", 2390 | "7\t0.9896190476190476\t0.0026626641961056286\n", 2391 | 
"8\t0.9906428571428572\t0.002469589142766717\n", 2392 | "9\t0.9914285714285714\t0.0023042640880992725\n", 2393 | "10\t0.991904761904762\t0.0021531063695813856\n", 2394 | "11\t0.9923333333333333\t0.0020310532989081765\n", 2395 | "Ensemble No 12\n", 2396 | "Epoch\tAccuracy\tLoss\n", 2397 | "0\t0.6915714285714286\t0.03770584853357512\n", 2398 | "1\t0.7699285714285714\t0.0267291182249535\n", 2399 | "2\t0.7757142857142857\t0.02526128294777063\n", 2400 | "3\t0.7780238095238096\t0.024551701106845517\n", 2401 | "4\t0.8675714285714285\t0.016064464039476856\n", 2402 | "5\t0.8760238095238095\t0.01465215212509569\n", 2403 | "6\t0.8776904761904762\t0.01421523496687208\n", 2404 | "7\t0.8788095238095238\t0.013916610044468324\n", 2405 | "8\t0.879547619047619\t0.013690601253975608\n", 2406 | "9\t0.8801190476190476\t0.013505928569205412\n", 2407 | "10\t0.8807857142857143\t0.013350672716887235\n", 2408 | "11\t0.881547619047619\t0.013213022782848624\n", 2409 | "Ensemble No 13\n", 2410 | "Epoch\tAccuracy\tLoss\n", 2411 | "0\t0.9043095238095238\t0.017345776722257563\n", 2412 | "1\t0.9696666666666667\t0.007376836091698157\n", 2413 | "2\t0.9795\t0.005394170997923742\n", 2414 | "3\t0.9833095238095239\t0.004454156959215335\n", 2415 | "4\t0.9853571428571428\t0.003862627853513263\n", 2416 | "5\t0.987452380952381\t0.003439586514361129\n", 2417 | "6\t0.9888095238095238\t0.0031042444393714914\n", 2418 | "7\t0.9898571428571429\t0.002830519109164879\n", 2419 | "8\t0.9910238095238095\t0.00261569553774447\n", 2420 | "9\t0.9917380952380952\t0.002434119211932959\n", 2421 | "10\t0.9923809523809524\t0.0022757627282403636\n", 2422 | "11\t0.9927857142857143\t0.0021390763814289966\n", 2423 | "Ensemble No 14\n", 2424 | "Epoch\tAccuracy\tLoss\n", 2425 | "0\t0.9013809523809524\t0.01742521362149552\n", 2426 | "1\t0.9739523809523809\t0.006462201707864597\n", 2427 | "2\t0.9809285714285715\t0.004908477179628132\n", 2428 | "3\t0.9843333333333333\t0.004113419807326088\n", 2429 | 
"4\t0.9862857142857143\t0.003581440640766531\n", 2430 | "5\t0.9878333333333333\t0.003194685795474117\n", 2431 | "6\t0.9891904761904762\t0.002906707553933637\n", 2432 | "7\t0.9901666666666666\t0.0026741618498035704\n", 2433 | "8\t0.9910476190476191\t0.002484290521713101\n", 2434 | "9\t0.9916190476190476\t0.0023193022649440364\n", 2435 | "10\t0.9923809523809524\t0.00217233649657896\n", 2436 | "11\t0.9927857142857143\t0.0020456477902940215\n", 2437 | "Ensemble No 15\n", 2438 | "Epoch\tAccuracy\tLoss\n", 2439 | "0\t0.8313333333333334\t0.023864344868786302\n", 2440 | "1\t0.8828095238095238\t0.015125647762399082\n", 2441 | "2\t0.8873095238095238\t0.01389738075652045\n", 2442 | "3\t0.8892142857142857\t0.0132795855097467\n", 2443 | "4\t0.8906666666666667\t0.012871225074265306\n", 2444 | "5\t0.8917857142857143\t0.012570265633064001\n", 2445 | "6\t0.8927142857142857\t0.012342315672186617\n", 2446 | "7\t0.8936190476190476\t0.01215927037686858\n", 2447 | "8\t0.8939047619047619\t0.012005833990423964\n", 2448 | "9\t0.8942857142857142\t0.011875185593140497\n", 2449 | "10\t0.8947857142857143\t0.011763219567812424\n", 2450 | "11\t0.8951904761904762\t0.011662756972428825\n" 2451 | ] 2452 | } 2453 | ], 2454 | "source": [ 2455 | "models = []\n", 2456 | "optims = []\n", 2457 | "loss = nn.MSELoss()\n", 2458 | "ensembles = 15\n", 2459 | "import sys\n", 2460 | "for i in range(ensembles):\n", 2461 | " sys.stdout.write(f'Ensemble No {i+1}\\n')\n", 2462 | " model = LeNet5()\n", 2463 | " model.apply(init_weights)\n", 2464 | " #optim = Adam(model.parameters())\n", 2465 | " optim = SGD(model.parameters(), lr=0.1, momentum=0.9)\n", 2466 | "\n", 2467 | " accuracy, _ = fit(train_dl, model,loss,optim,epochs)\n", 2468 | " if accuracy[-1] > 95:\n", 2469 | " models.append(model)\n", 2470 | " optims.append(optim)" 2471 | ] 2472 | }, 2473 | { 2474 | "cell_type": "code", 2475 | "execution_count": 43, 2476 | "metadata": { 2477 | "colab": { 2478 | "base_uri": "https://localhost:8080/" 2479 | }, 2480 | 
"id": "0fryczCwqdXT", 2481 | "outputId": "3ef08228-e52a-4dd4-c568-6e43f3017345" 2482 | }, 2483 | "outputs": [ 2484 | { 2485 | "name": "stdout", 2486 | "output_type": "stream", 2487 | "text": [ 2488 | "Saving submission_ensemble_15_LeNets_v17.csv\n" 2489 | ] 2490 | } 2491 | ], 2492 | "source": [ 2493 | "ensemble = cn_train_yhat\n", 2494 | "\n", 2495 | "for model in models:\n", 2496 | " ensemble+=model(x_test)\n", 2497 | "\n", 2498 | "ensemble_one_hot = torch.argmax(ensemble,1) # Find argmax\n", 2499 | "export_csv(f'ensemble_{ensembles}_LeNets',ensemble_one_hot, commit_no=commit_no)" 2500 | ] 2501 | }, 2502 | { 2503 | "cell_type": "code", 2504 | "execution_count": 44, 2505 | "metadata": { 2506 | "colab": { 2507 | "base_uri": "https://localhost:8080/" 2508 | }, 2509 | "id": "vvk0d4AQqiWt", 2510 | "outputId": "615dcb9e-1e07-4e9e-ff8c-4983602e08e1" 2511 | }, 2512 | "outputs": [ 2513 | { 2514 | "data": { 2515 | "text/plain": [ 2516 | "tensor([2, 0, 9, ..., 3, 9, 2])" 2517 | ] 2518 | }, 2519 | "execution_count": 44, 2520 | "metadata": { 2521 | "tags": [] 2522 | }, 2523 | "output_type": "execute_result" 2524 | } 2525 | ], 2526 | "source": [ 2527 | "ensemble_one_hot" 2528 | ] 2529 | }, 2530 | { 2531 | "cell_type": "code", 2532 | "execution_count": 46, 2533 | "metadata": { 2534 | "colab": { 2535 | "base_uri": "https://localhost:8080/" 2536 | }, 2537 | "id": "Qsa9B0TCqlpS", 2538 | "outputId": "918bf481-ea32-4017-940c-b8614cd37c98" 2539 | }, 2540 | "outputs": [ 2541 | { 2542 | "name": "stdout", 2543 | "output_type": "stream", 2544 | "text": [ 2545 | "Saving submission_ensemble_v17.csv\n" 2546 | ] 2547 | } 2548 | ], 2549 | "source": [ 2550 | "ensemble = ff_test_yhat + cn_train_yhat # Add probabilities of individual predictions\n", 2551 | "ensemble_one_hot = torch.argmax(y_train,1) # Find argmax\n", 2552 | "export_csv('ensemble',ensemble_one_hot, commit_no=commit_no)" 2553 | ] 2554 | }, 2555 | { 2556 | "cell_type": "markdown", 2557 | "metadata": { 2558 | "id": "nfsVXTWFqpH8" 
2559 | }, 2560 | "source": [ 2561 | "### Final Accuracy achieved= 98.94%" 2562 | ] 2563 | }, 2564 | { 2565 | "cell_type": "markdown", 2566 | "metadata": {}, 2567 | "source": [ 2568 | "Thanks" 2569 | ] 2570 | } 2571 | ], 2572 | "metadata": { 2573 | "colab": { 2574 | "name": "digit.ipynb", 2575 | "provenance": [], 2576 | "toc_visible": true 2577 | }, 2578 | "kernelspec": { 2579 | "display_name": "Python 3", 2580 | "language": "python", 2581 | "name": "python3" 2582 | }, 2583 | "language_info": { 2584 | "codemirror_mode": { 2585 | "name": "ipython", 2586 | "version": 3 2587 | }, 2588 | "file_extension": ".py", 2589 | "mimetype": "text/x-python", 2590 | "name": "python", 2591 | "nbconvert_exporter": "python", 2592 | "pygments_lexer": "ipython3", 2593 | "version": "3.7.8" 2594 | } 2595 | }, 2596 | "nbformat": 4, 2597 | "nbformat_minor": 4 2598 | } 2599 | -------------------------------------------------------------------------------- /Digit Recognition/digit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | #

Digit recognition

# Our goal is to correctly identify digits from a dataset of tens of thousands of
# handwritten images. We've curated a set of tutorial-style kernels which cover
# everything from regression to neural networks. We encourage you to experiment
# with different algorithms to learn first-hand what works well and how
# techniques compare.

# ## Kaggle config

# In[1]:


get_ipython().system(' pip install -q kaggle')


# In[2]:


from google.colab import files


# In[3]:


# Upload kaggle.json (the Kaggle API token) through Colab's file picker; the
# following cells install it where the Kaggle CLI looks for it.
files.upload()


# In[4]:


# -p: do not fail if the directory already exists (e.g. when re-running).
get_ipython().system(' mkdir -p ~/.kaggle')


# In[5]:


get_ipython().system(' cp kaggle.json ~/.kaggle/')


# In[6]:


# Restrict the token to the current user; the Kaggle CLI warns otherwise.
get_ipython().system(' chmod 600 ~/.kaggle/kaggle.json')


# In[7]:


get_ipython().system(' kaggle datasets list')


# In[9]:


get_ipython().system('kaggle competitions download -c digit-recognizer')


# In[10]:


get_ipython().system(' mkdir -p train')


# In[12]:


# -o: overwrite without prompting so the cell can be re-run non-interactively.
get_ipython().system(' unzip -o train.csv.zip -d train')


# In[13]:


get_ipython().system(' mkdir -p test')


# In[14]:


get_ipython().system(' unzip -o test.csv.zip -d test')


# ## Importing Essential Libraries

# In[17]:


# Basic Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import TensorDataset
from torch.optim import Adam, SGD

# Basic Numeric Computation
import numpy as np
import pandas as pd

# Look at data
from matplotlib import pyplot

# Easy way to split train data (only used by the commented-out validation split)
from sklearn.model_selection import train_test_split

# # Looking at directory
# import os
# base_dir = "../input"
# print(os.listdir(base_dir))

# Everything runs on the CPU; swap in the commented expression to use a GPU.
device = torch.device("cpu")  # if torch.cuda.is_available() else torch.device("cpu")
device
# Number of passes over the training set for every model trained below.
epochs = 12


# In[19]:


train = pd.read_csv('/content/train/train.csv')
test = pd.read_csv('/content/test/test.csv')


# In[20]:


train.head()


# ## 2. Transforming Data

# In[21]:


def createImageData(raw: pd.DataFrame):
    """Split a digit-recognizer frame into image and label arrays.

    Args:
        raw: DataFrame with a ``label`` column; every other column is a pixel.
            Assumes 28 * 28 = 784 pixel columns per row (TODO confirm for
            other datasets) so each row reshapes to one image.

    Returns:
        ``(x, y)``: ``x`` is a float64 array of shape ``(n, 1, 28, 28)`` and
        ``y`` is an int array of shape ``(n,)``.
    """
    # Non-mutating conversion: ``ndarray.resize`` on ``.values`` would change
    # the shape of an array pandas may still be using internally, and the
    # result was flattened again immediately anyway.
    y = raw['label'].to_numpy().astype(int).reshape(-1)
    x = raw.drop(columns='label').to_numpy()
    x = x.reshape([-1, 1, 28, 28]).astype(float)
    return x, y


def one_hot_embedding(labels, num_classes=10):
    """Return a float one-hot matrix with one row per entry of *labels*.

    *labels* are integer class indices in ``[0, num_classes)``; the result has
    shape ``(len(labels), num_classes)`` and lives on the CPU.
    """
    return torch.eye(num_classes)[labels]


# In[22]:


x_train, y_train = createImageData(train)
#x_train, x_val, y_train, y_val = train_test_split(x,y, test_size=0.02)

#x_train.shape, y_train.shape, x_val.shape, y_val.shape
x_train.shape, y_train.shape


# In[23]:


# Normalization: one global mean/std over every training pixel. The same
# statistics are reused later to normalize the test images.
mean = x_train.mean()
std = x_train.std()
x_train = (x_train - mean) / std
#x_val = (x_val-mean)/std

# Numpy to Torch Tensor. ``np.long`` (a deprecated alias of ``int``) was
# removed from NumPy, so request the 64-bit integer dtype explicitly.
x_train = torch.from_numpy(np.float32(x_train)).to(device)
y_train = torch.from_numpy(y_train.astype(np.int64)).to(device)
# One-hot targets because the models are trained with MSELoss; move them back
# to ``device`` since ``one_hot_embedding`` builds its result on the CPU.
y_train = one_hot_embedding(y_train).to(device)
#x_val = torch.from_numpy(np.float32(x_val))
#y_val = torch.from_numpy(y_val.astype(np.int64))
# ## 3. Loading Dataset

# In[24]:


# Convert into Torch Dataset
train_ds = TensorDataset(x_train, y_train)
#val_ds = TensorDataset(x_val,y_val)


# In[25]:


# Make Data Loader. ``shuffle`` defaults to False, so batches come in file
# order in every epoch.
train_dl = DataLoader(train_ds, batch_size=64)


# ## 4. EDA

# In[26]:


index = 1
pyplot.imshow(x_train.cpu()[index].reshape((28, 28)), cmap="gray")
print(y_train[index])


# ## 5. Model

# In[27]:


# Helper Functions

def init_weights(m):
    """Xavier-uniform initialize an ``nn.Linear`` layer and set its bias to 0.01.

    Meant for ``model.apply(init_weights)``. Any other module type (including
    convolutions) keeps PyTorch's default initialization.
    """
    if isinstance(m, nn.Linear):
        # ``xavier_uniform`` (no trailing underscore) is a deprecated alias.
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # Linear(..., bias=False) has no bias tensor
            nn.init.constant_(m.bias, 0.01)


class Flatten(nn.Module):
    """Flatten every dimension except the leading batch dimension."""

    def forward(self, input):
        return input.view(input.size(0), -1)


def fit(train_dl, model, loss, optim, epochs=10):
    """Train *model* in place, printing per-epoch accuracy and mean loss.

    Args:
        train_dl: iterable of ``(x, y)`` batches with one-hot targets ``y``.
        model: module mapping a batch of images to ``(batch, classes)`` scores.
        loss: criterion, called as ``loss(prediction, target)``.
        optim: optimizer over ``model``'s parameters.
        epochs: number of passes over *train_dl*.

    Returns:
        ``(accuracy_overtime, loss_overtime)`` with one entry per epoch.
        Accuracy is a fraction in ``[0, 1]`` measured on the training batches.
    """
    model = model.to(device)
    model.train()
    print('Epoch\tAccuracy\tLoss')
    accuracy_overtime = []
    loss_overtime = []
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        batches = 0
        for x, y in train_dl:  # Iterate over Data Loader
            x, y = x.to(device), y.to(device)

            # Forward pass; PyTorch criteria take ``(input, target)``.
            yhat = model(x)
            l = loss(yhat, y)

            # Backward pass
            optim.zero_grad()
            l.backward()
            optim.step()

            # Metrics, from the predictions made before this step's update.
            total_loss += l.item()
            batches += 1
            total += y.size(0)
            correct += (y.argmax(1) == yhat.detach().argmax(1)).sum().item()

        # ``max(..., 1)`` keeps an empty loader from raising ZeroDivisionError.
        accuracy_overtime.append(correct / max(total, 1))
        loss_overtime.append(total_loss / max(batches, 1))
        print(epoch, accuracy_overtime[-1], loss_overtime[-1], sep='\t')
    return accuracy_overtime, loss_overtime


def plot_accuracy_loss(accuracy, loss):
    """Plot per-epoch accuracy (left) and loss (right, dotted red) side by side."""
    f = pyplot.figure(figsize=(15, 5))
    ax1 = f.add_subplot(121)
    ax2 = f.add_subplot(122)
    ax1.title.set_text("Accuracy over epochs")
    ax2.title.set_text("Loss over epochs")
    ax1.plot(accuracy)
    ax2.plot(loss, 'r:')


def predict_for_index(array, model, index):
    """Show image *index* of *array* and print the digit *model* predicts for it."""
    testing = array[index].reshape(1, 28, 28)
    # Display and describe the same sample that is fed to the model (not a
    # sample from the training set).
    pyplot.imshow(testing.cpu().reshape((28, 28)), cmap="gray")
    print(array[index].shape)
    with torch.no_grad():
        a = model(testing.float().to(device))
    print('Prediction', torch.argmax(a, 1))


# In[28]:


# Define the model: one hidden layer, softmax outputs.
ff_model = nn.Sequential(
    Flatten(),
    nn.Linear(28 * 28, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
    nn.Softmax(1),
).to(device)


# In[29]:


# Xavier initialization for the linear layers. (Xavier/Glorot was derived for
# tanh/sigmoid-like activations; Kaiming/He is the usual choice for ReLU.)
ff_model.apply(init_weights)


# In[30]:


optim = Adam(ff_model.parameters())
loss = nn.MSELoss()
output = fit(train_dl, ff_model, loss, optim, epochs)
plot_accuracy_loss(*output)


# In[31]:


index = 4
predict_for_index(x_train, ff_model, index)


# In[32]:


# A too simple NN taken from pytorch.org/tutorials
class Mnist_CNN(nn.Module):
    """Three stride-2 3x3 convs (28 -> 14 -> 7 -> 4) then 4x4 average pooling.

    Produces 10 scores per image. Not used by the rest of this script.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)

    def forward(self, xb):
        xb = xb.view(-1, 1, 28, 28)
        xb = F.relu(self.conv1(xb))
        xb = F.relu(self.conv2(xb))
        xb = F.relu(self.conv3(xb))
        xb = F.avg_pool2d(xb, 4)
        return xb.view(-1, xb.size(1))


class LeNet5(nn.Module):
    """LeNet-5-style CNN for single-channel 28x28 digits with 10 outputs.

    Feature shapes: 1x28x28 -conv5-> 6x24x24 -pool-> 6x12x12 -conv5-> 16x8x8
    -pool-> 16x4x4 -conv4-> 120x1x1, then 120 -> 82 -> 10 linear layers.
    The final ReLU makes the outputs non-negative scores rather than logits
    or probabilities.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1)
        self.average1 = nn.AvgPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.average2 = nn.AvgPool2d(2, stride=2)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=4, stride=1)

        self.flatten = Flatten()

        self.fc1 = nn.Linear(120, 82)
        self.fc2 = nn.Linear(82, 10)

    def forward(self, xb):
        xb = xb.view(-1, 1, 28, 28)
        xb = F.tanh(self.conv1(xb))
        xb = self.average1(xb)
        xb = F.tanh(self.conv2(xb))
        xb = self.average2(xb)
        xb = F.tanh(self.conv3(xb))
        xb = xb.view(-1, xb.shape[1])
        xb = F.relu(self.fc1(xb))
        xb = F.relu(self.fc2(xb))
        return xb


# In[33]:


conv_model = LeNet5()
conv_model.apply(init_weights)
loss = nn.MSELoss()
optim = SGD(conv_model.parameters(), lr=0.1, momentum=0.9)
plot_accuracy_loss(*fit(train_dl, conv_model, loss, optim, epochs))


# ## Working on test data

# ### Normalization

# In[34]:


# Normalize with the *training* mean/std so test pixels share the train scale.
x_test = test.values
x_test = x_test.reshape([-1, 28, 28]).astype(float)
x_test = (x_test - mean) / std
x_test = torch.from_numpy(np.float32(x_test)).to(device)
x_test.shape


# #### Prediction

# In[35]:


index = 7
predict_for_index(x_test, ff_model, index)
predict_for_index(x_test, conv_model, index)


# In[36]:


def export_csv(model_name, predictions, commit_no):
    """Write *predictions* as a submission CSV with columns ``ImageId,Label``.

    ``ImageId`` is 1-based. Assumes *predictions* are in ``test.csv`` row
    order. The file ``submission_{model_name}_v{commit_no}.csv`` is written
    to the current directory.
    """
    # Read the ``predictions`` argument, not a same-named global, so every
    # caller (including the ensembles) exports its own predictions.
    df = pd.DataFrame(predictions.tolist(), columns=['Label'])
    df['ImageId'] = df.index + 1
    file_name = f'submission_{model_name}_v{commit_no}.csv'
    print('Saving ', file_name)
    df[['ImageId', 'Label']].to_csv(file_name, index=False)


# In[37]:


test.head()


# In[38]:
# Just to make output easier to read: version tag appended to submission names.
commit_no = 17


# In[39]:


# Inference only: no_grad avoids building autograd graphs for the test set.
with torch.no_grad():
    ff_test_yhat = ff_model(x_test.float())
prediction = torch.argmax(ff_test_yhat, 1)
print('Prediction', prediction)
export_csv('ff_model', prediction, commit_no=commit_no)


# In[40]:


# LeNet scores on the *test* images.
with torch.no_grad():
    cn_test_yhat = conv_model(x_test)
prediction = torch.argmax(cn_test_yhat, 1)
export_csv('lenet_model', prediction, commit_no=commit_no)


# ### Ensembling

# In[41]:


import sys

models = []
optims = []
loss = nn.MSELoss()
ensembles = 15
# ``fit`` reports accuracy as a fraction in [0, 1], so the acceptance threshold
# must be a fraction too (a value like 95 rejects every model). Runs whose
# training stalls well below this, such as members plateauing near 0.88, are
# still left out of the ensemble.
min_accuracy = 0.95
for i in range(ensembles):
    sys.stdout.write(f'Ensemble No {i+1}\n')
    model = LeNet5()
    model.apply(init_weights)
    #optim = Adam(model.parameters())
    optim = SGD(model.parameters(), lr=0.1, momentum=0.9)

    accuracy, _ = fit(train_dl, model, loss, optim, epochs)
    if accuracy[-1] > min_accuracy:
        models.append(model)
        optims.append(optim)


# In[43]:


# Sum every accepted member's scores on top of the single LeNet's. ``clone``
# stops the in-place ``+=`` from modifying ``cn_test_yhat``, which the last
# ensemble below reuses.
with torch.no_grad():
    ensemble = cn_test_yhat.clone()
    for model in models:
        ensemble += model(x_test)

ensemble_one_hot = torch.argmax(ensemble, 1)  # Find argmax
export_csv(f'ensemble_{ensembles}_LeNets', ensemble_one_hot, commit_no=commit_no)


# In[44]:


ensemble_one_hot


# In[46]:


# Add the two models' outputs as-is: ff_model emits softmax probabilities,
# LeNet5 non-negative MSE-trained scores. The argmax must be taken over this
# combined *test* ensemble, not over ``y_train`` (the one-hot training labels).
ensemble = ff_test_yhat + cn_test_yhat
ensemble_one_hot = torch.argmax(ensemble, 1)  # Find argmax
export_csv('ensemble', ensemble_one_hot, commit_no=commit_no)

Final Accuracy Achieved 98.94%

487 | 488 | # Thanks 489 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Sarvesh Kumar Sharma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep-Learining-Projects 2 | In this repository, I will keep my all Deep Learning Projects implementations. 3 | --------------------------------------------------------------------------------