├── .gitignore
├── Classification
│   ├── decision_trees.ipynb
│   ├── gradient_boosted_decision_trees.ipynb
│   ├── kernelized_svm.ipynb
│   ├── linear_svm.ipynb
│   ├── logistic_regression.ipynb
│   ├── mlp_pytorch.ipynb
│   ├── naive_bayes_classifiers.ipynb
│   ├── neural_networks_classifier.ipynb
│   ├── random_forests.ipynb
│   └── rnn_pytorch.ipynb
├── README.md
└── Regression
    ├── lasso_regression.ipynb
    ├── linear_regression.ipynb
    ├── linear_vs_polynomial_regressions.ipynb
    ├── neural_networks_regressor.ipynb
    └── ridge_regression.ipynb
/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .idea/* 3 | .DS_Store* -------------------------------------------------------------------------------- /Classification/decision_trees.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Supervised Learning Algorithms: Decision Trees" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further instructions). None of the other sections needs to be adjusted. As a data input example, a .csv file from the IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. Libraries" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "*Run to import the required libraries.*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 10, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib notebook\n", 38 | "import numpy as np\n", 39 | "import pandas as pd\n", 40 | "from sklearn.model_selection import train_test_split" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## 2. Data Input and Variables" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "*Define the data input as well as the input (X) and target (y) variables and run the code. Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "### Data Input\n", 64 | "# df = \n", 65 | "\n", 66 | "### Defining Variables \n", 67 | "# X = \n", 68 | "# y = \n", 69 | "\n", 70 | "### Data Input Example \n", 71 | "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n", 72 | "\n", 73 | "X = df[['horsepower', 'normalized-losses']]\n", 74 | "y = df['price']" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## 3. 
The Model" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "*Run to build the model.*" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 8, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Accuracy of Decision Tree classifier on training set: 0.09\n", 101 | "Accuracy of Decision Tree classifier on test set: 0.00\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "from sklearn.tree import DecisionTreeClassifier\n", 107 | "from sklearn.model_selection import train_test_split\n", 108 | "\n", 109 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 3)\n", 110 | "clf = DecisionTreeClassifier(max_depth = 4, min_samples_leaf = 8,random_state = 0).fit(X_train, y_train).fit(X_train, y_train)\n", 111 | "\n", 112 | "print('Accuracy of Decision Tree classifier on training set: {:.2f}'\n", 113 | " .format(clf.score(X_train, y_train)))\n", 114 | "print('Accuracy of Decision Tree classifier on test set: {:.2f}'\n", 115 | " .format(clf.score(X_test, y_test)))" 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "Python 3", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.6.4" 136 | } 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 2 140 | } 141 | -------------------------------------------------------------------------------- /Classification/gradient_boosted_decision_trees.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Supervised Learning Algorithms: Gradient-boosted Decision Trees" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further instructions). None of the other sections needs to be adjusted. As a data input example, .csv file from IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. Libraries" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "*Run to import the required libraries.*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib notebook\n", 38 | "import numpy as np\n", 39 | "import pandas as pd\n", 40 | "from sklearn.model_selection import train_test_split" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## 2. Data Input and Variables" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "*Define the data input as well as the input (X) and target (y) variables and run the code. 
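The decision-tree template above fixes max_depth and min_samples_leaf by hand; a common refinement is to compare a few values with cross-validation and to inspect which features the fitted tree actually uses. A minimal sketch, separate from the template: the synthetic dataset and the depth values below are illustrative assumptions, not values taken from the notebooks.

# Hedged sketch: comparing tree depths with cross-validation and inspecting
# feature importances. Dataset and depth values are assumptions.
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = make_classification(n_samples=500, n_features=8, random_state=0)

for depth in [2, 4, 8]:
    scores = cross_val_score(DecisionTreeClassifier(max_depth=depth, random_state=0),
                             X_demo, y_demo, cv=5)
    print('max_depth={}: mean CV accuracy {:.2f}'.format(depth, scores.mean()))

clf_demo = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_demo, y_demo)
print('Feature importances:', clf_demo.feature_importances_)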
Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "### Data Input\n", 64 | "# df = \n", 65 | "\n", 66 | "### Defining Variables \n", 67 | "# X = \n", 68 | "# y = \n", 69 | "\n", 70 | "### Data Input Example \n", 71 | "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n", 72 | "\n", 73 | "X = df[['horsepower', 'normalized-losses']]\n", 74 | "y = df['price']" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## 3. The Model" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "*Run to build the model.*" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 6, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Car dataset (learning_rate=0.1, max_depth=3)\n", 101 | "Accuracy of GBDT classifier on training set: 0.64\n", 102 | "Accuracy of GBDT classifier on test set: 0.02\n", 103 | "\n", 104 | "Car dataset (learning_rate=0.01, max_depth=2)\n", 105 | "Accuracy of GBDT classifier on training set: 0.67\n", 106 | "Accuracy of GBDT classifier on test set: 0.02\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "from sklearn.ensemble import GradientBoostingClassifier\n", 112 | "\n", 113 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)\n", 114 | "\n", 115 | "#### COMPARE YOUR MODELS ####\n", 116 | "\n", 117 | "# Model with the learning rate = 0.1 and max_dept = 3 (default settings)\n", 118 | "clf = GradientBoostingClassifier(random_state = 0).fit(X_train, y_train)\n", 119 | "\n", 120 | "print('Car dataset (learning_rate=0.1, max_depth=3)')\n", 121 | "print('Accuracy of GBDT classifier on training set: {:.2f}'\n", 122 | " .format(clf.score(X_train, y_train)))\n", 123 | "print('Accuracy of GBDT classifier on test set: {:.2f}\\n'\n", 124 | " .format(clf.score(X_test, y_test)))\n", 125 | "\n", 126 | "# Model with the learning rate = 0.01 and max_dept = 2\n", 127 | "clf = GradientBoostingClassifier(learning_rate = 0.01, max_depth = 2, random_state = 0).fit(X_train, y_train)\n", 128 | "\n", 129 | "print('Car dataset (learning_rate=0.01, max_depth=2)')\n", 130 | "print('Accuracy of GBDT classifier on training set: {:.2f}'\n", 131 | " .format(clf.score(X_train, y_train)))\n", 132 | "print('Accuracy of GBDT classifier on test set: {:.2f}'\n", 133 | " .format(clf.score(X_test, y_test)))" 134 | ] 135 | } 136 | ], 137 | "metadata": { 138 | "kernelspec": { 139 | "display_name": "Python 3", 140 | "language": "python", 141 | "name": "python3" 142 | }, 143 | "language_info": { 144 | "codemirror_mode": { 145 | "name": "ipython", 146 | "version": 3 147 | }, 148 | "file_extension": ".py", 149 | "mimetype": "text/x-python", 150 | "name": "python", 151 | "nbconvert_exporter": "python", 152 | "pygments_lexer": "ipython3", 153 | "version": "3.6.4" 154 | } 155 | }, 156 | "nbformat": 4, 157 | "nbformat_minor": 2 158 | } 159 | -------------------------------------------------------------------------------- /Classification/kernelized_svm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Supervised Learning Algorithms: Kernelized Support Vector Machines" 8 | ] 9 | 
}, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further instructions). None of the other sections needs to be adjusted. As a data input example, .csv file from IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. Libraries" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "*Run to import the required libraries.*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib notebook\n", 38 | "import numpy as np\n", 39 | "import pandas as pd\n", 40 | "from sklearn.model_selection import train_test_split" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## 2. Data Input and Variables" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "*Define the data input as well as the input (X) and target (y) variables and run the code. Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "### Data Input\n", 64 | "# df = \n", 65 | "\n", 66 | "### Defining Variables \n", 67 | "# X = \n", 68 | "# y = \n", 69 | "\n", 70 | "### Data Input Example \n", 71 | "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n", 72 | "\n", 73 | "X = df[['horsepower']]\n", 74 | "y = df['price']" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## 3. The Model" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "*Run to build the SVM with both default radial basis function (RBF) and polynomial kernel.*" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Accuracy of RBF-kernel SVC on training set: 0.37\n", 101 | "Accuracy of RBF-kernel SVC on test set: 0.00\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "from sklearn.svm import SVC\n", 107 | "\n", 108 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)\n", 109 | "\n", 110 | "\n", 111 | "\n", 112 | "# The default SVC kernel is radial basis function (RBF)\n", 113 | "clf = SVC().fit(X_train, y_train)\n", 114 | "\n", 115 | "print('Accuracy of RBF-kernel SVC on training set: {:.2f}'\n", 116 | " .format(clf.score(X_train, y_train)))\n", 117 | "print('Accuracy of RBF-kernel SVC on test set: {:.2f}'\n", 118 | " .format(clf.score(X_test, y_test)))\n", 119 | "\n", 120 | "### THIS MIGHT TAKE A WHILE\n", 121 | "# # Compare decision boundries with polynomial kernel, degree = 3\n", 122 | "# clf = SVC(kernel = 'poly', degree = 3).fit(X_train, y_train)\n", 123 | "\n", 124 | "# print('Accuracy of poly-kernel SVC on training set: {:.2f}'\n", 125 | "# .format(clf.score(X_train, y_train)))\n", 126 | "# print('Accuracy of poly-kernel SVC on test set: {:.2f}'\n", 127 | "# .format(clf.score(X_test, y_test)))" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "### 3.1. 
Support Vector Machine with RBF kernel: gamma parameter" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 7, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "name": "stdout", 144 | "output_type": "stream", 145 | "text": [ 146 | "SVM (RBF) with gamma = 1e-05\n", 147 | "Accuracy of SVM (RBF) classifier on training set: 0.09\n", 148 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 149 | "\n", 150 | "SVM (RBF) with gamma = 100\n", 151 | "Accuracy of SVM (RBF) classifier on training set: 0.37\n", 152 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 153 | "\n" 154 | ] 155 | } 156 | ], 157 | "source": [ 158 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)\n", 159 | "\n", 160 | "for this_gamma in [0.00001, 100]:\n", 161 | " clf = SVC(kernel = 'rbf', gamma=this_gamma).fit(X_train, y_train)\n", 162 | " print('SVM (RBF) with gamma = {}'.format(this_gamma))\n", 163 | " print('Accuracy of SVM (RBF) classifier on training set: {:.2f}'\n", 164 | " .format(clf.score(X_train, y_train)))\n", 165 | " print('Accuracy of SVM (RBF) classifier on test set: {:.2f}\\n'\n", 166 | " .format(clf.score(X_test, y_test)))" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "### 3.2. Support Vector Machine with RBF kernel: using both C and gamma parameter" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 10, 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "name": "stdout", 183 | "output_type": "stream", 184 | "text": [ 185 | "SVM (RBF) with gamma = 0.01 and C = 0.1\n", 186 | "Accuracy of SVM (RBF) classifier on training set: 0.09\n", 187 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 188 | "\n", 189 | "SVM (RBF) with gamma = 0.01 and C = 1\n", 190 | "Accuracy of SVM (RBF) classifier on training set: 0.15\n", 191 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 192 | "\n", 193 | "SVM (RBF) with gamma = 0.01 and C = 15\n", 194 | "Accuracy of SVM (RBF) classifier on training set: 0.33\n", 195 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 196 | "\n", 197 | "SVM (RBF) with gamma = 0.01 and C = 250\n", 198 | "Accuracy of SVM (RBF) classifier on training set: 0.37\n", 199 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 200 | "\n", 201 | "SVM (RBF) with gamma = 1 and C = 0.1\n", 202 | "Accuracy of SVM (RBF) classifier on training set: 0.11\n", 203 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 204 | "\n", 205 | "SVM (RBF) with gamma = 1 and C = 1\n", 206 | "Accuracy of SVM (RBF) classifier on training set: 0.37\n", 207 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 208 | "\n", 209 | "SVM (RBF) with gamma = 1 and C = 15\n", 210 | "Accuracy of SVM (RBF) classifier on training set: 0.37\n", 211 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 212 | "\n", 213 | "SVM (RBF) with gamma = 1 and C = 250\n", 214 | "Accuracy of SVM (RBF) classifier on training set: 0.37\n", 215 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 216 | "\n", 217 | "SVM (RBF) with gamma = 5 and C = 0.1\n", 218 | "Accuracy of SVM (RBF) classifier on training set: 0.11\n", 219 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 220 | "\n", 221 | "SVM (RBF) with gamma = 5 and C = 1\n", 222 | "Accuracy of SVM (RBF) classifier on training set: 0.37\n", 223 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 224 | "\n", 225 | "SVM (RBF) with gamma = 5 and C = 15\n", 226 | "Accuracy of SVM (RBF) classifier on 
training set: 0.37\n", 227 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 228 | "\n", 229 | "SVM (RBF) with gamma = 5 and C = 250\n", 230 | "Accuracy of SVM (RBF) classifier on training set: 0.37\n", 231 | "Accuracy of SVM (RBF) classifier on test set: 0.00\n", 232 | "\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "from sklearn.svm import SVC\n", 238 | "from sklearn.model_selection import train_test_split\n", 239 | "\n", 240 | "\n", 241 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)\n", 242 | "\n", 243 | "for this_gamma in [0.01, 1, 5]:\n", 244 | " \n", 245 | " for this_C in [0.1, 1, 15, 250]:\n", 246 | " title = 'gamma = {:.2f}, C = {:.2f}'.format(this_gamma, this_C)\n", 247 | " clf = SVC(kernel = 'rbf', gamma = this_gamma, C = this_C).fit(X_train, y_train)\n", 248 | " print('SVM (RBF) with gamma = {} and C = {}'.format(this_gamma, this_C))\n", 249 | " print('Accuracy of SVM (RBF) classifier on training set: {:.2f}'\n", 250 | " .format(clf.score(X_train, y_train)))\n", 251 | " print('Accuracy of SVM (RBF) classifier on test set: {:.2f}\\n'\n", 252 | " .format(clf.score(X_test, y_test)))" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "### 3.3. SVMs with normalized data (feature preprocessing using minmax scaling)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 12, 265 | "metadata": {}, 266 | "outputs": [ 267 | { 268 | "name": "stdout", 269 | "output_type": "stream", 270 | "text": [ 271 | "Cars dataset (normalized with MinMax scaling)\n", 272 | "RBF-kernel SVC (with MinMax scaling) training set accuracy: 0.09\n", 273 | "RBF-kernel SVC (with MinMax scaling) test set accuracy: 0.00\n" 274 | ] 275 | } 276 | ], 277 | "source": [ 278 | "from sklearn.preprocessing import MinMaxScaler\n", 279 | "scaler = MinMaxScaler()\n", 280 | "X_train_scaled = scaler.fit_transform(X_train)\n", 281 | "X_test_scaled = scaler.transform(X_test)\n", 282 | "\n", 283 | "clf = SVC(C=10).fit(X_train_scaled, y_train)\n", 284 | "print('Cars dataset (normalized with MinMax scaling)')\n", 285 | "print('RBF-kernel SVC (with MinMax scaling) training set accuracy: {:.2f}'\n", 286 | " .format(clf.score(X_train_scaled, y_train)))\n", 287 | "print('RBF-kernel SVC (with MinMax scaling) test set accuracy: {:.2f}'\n", 288 | " .format(clf.score(X_test_scaled, y_test)))" 289 | ] 290 | } 291 | ], 292 | "metadata": { 293 | "kernelspec": { 294 | "display_name": "Python 3", 295 | "language": "python", 296 | "name": "python3" 297 | }, 298 | "language_info": { 299 | "codemirror_mode": { 300 | "name": "ipython", 301 | "version": 3 302 | }, 303 | "file_extension": ".py", 304 | "mimetype": "text/x-python", 305 | "name": "python", 306 | "nbconvert_exporter": "python", 307 | "pygments_lexer": "ipython3", 308 | "version": "3.6.4" 309 | } 310 | }, 311 | "nbformat": 4, 312 | "nbformat_minor": 2 313 | } 314 | -------------------------------------------------------------------------------- /Classification/linear_svm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Supervised Learning Algorithms: Linear Support Vector Machines" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further instructions). 
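Sections 3.1-3.3 of the kernelized-SVM template vary gamma, C and MinMax scaling in separate cells; one way to tune them together, with the scaler refit on each training fold, is a scaler-plus-SVC pipeline searched over a grid. A minimal sketch on a synthetic dataset: the grid reuses the template's gamma and C values, but the dataset and pipeline names are assumptions, not part of the template.

# Hedged sketch: MinMax scaling and C/gamma search combined in one pipeline,
# so scaling statistics are computed only from each training fold.
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.svm import SVC

X_demo, y_demo = make_classification(n_samples=400, n_features=6, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, random_state=0)

pipe = Pipeline([('scale', MinMaxScaler()), ('svc', SVC(kernel='rbf'))])
search = GridSearchCV(pipe,
                      param_grid={'svc__gamma': [0.01, 1, 5],
                                  'svc__C': [0.1, 1, 15, 250]},
                      cv=5)
search.fit(X_tr, y_tr)

print('Best parameters:', search.best_params_)
print('Test-set accuracy: {:.2f}'.format(search.score(X_te, y_te)))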
None of the other sections needs to be adjusted. As a data input example, .csv file from IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. Libraries" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "*Run to import the required libraries.*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib notebook\n", 38 | "import pandas as pd\n", 39 | "from sklearn.model_selection import train_test_split" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## 2. Data Input and Variables" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "*Define the data input as well as the input (X) and target (y) variables and run the code. Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "### Data Input\n", 63 | "# df = \n", 64 | "\n", 65 | "### Defining Variables \n", 66 | "# X = \n", 67 | "# y = \n", 68 | "\n", 69 | "### Data Input Example \n", 70 | "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n", 71 | "\n", 72 | "X = df[['horsepower']]\n", 73 | "y = df['price']\n", 74 | "\n", 75 | "X_2 = df[['horsepower', 'normalized-losses']]\n", 76 | "y_2 = df['price']" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## 3. The Model" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "*Run to build the model.*" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 4, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Linear SVC, C = 1.000\n", 103 | "Accuracy of Linear SVC classifier on training set: 0.27\n", 104 | "Accuracy of Linear SVC classifier on test set: 0.00\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "from sklearn.svm import SVC\n", 110 | "\n", 111 | "\n", 112 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)\n", 113 | "\n", 114 | "#C parameter\n", 115 | "this_C = 1.0\n", 116 | "\n", 117 | "#model\n", 118 | "clf = SVC(kernel = 'linear', C=this_C).fit(X_train, y_train)\n", 119 | "print('Linear SVC, C = {:.3f}'.format(this_C))\n", 120 | "print('Accuracy of Linear SVC classifier on training set: {:.2f}'\n", 121 | " .format(clf.score(X_train, y_train)))\n", 122 | "print('Accuracy of Linear SVC classifier on test set: {:.2f}'\n", 123 | " .format(clf.score(X_test, y_test)))" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "### 3.1. 
Linear SVM regularization: C parameter" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 5, 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "Linear SVM with C = 1e-05\n", 143 | "Accuracy of Linear SVM classifier on training set: 0.08\n", 144 | "Accuracy of Linear SVM classifier on test set: 0.00\n", 145 | "\n", 146 | "Linear SVM with C = 100\n", 147 | "Accuracy of Linear SVM classifier on training set: 0.27\n", 148 | "Accuracy of Linear SVM classifier on test set: 0.00\n", 149 | "\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "for this_C in [0.00001, 100]:\n", 155 | " clf = SVC(kernel = 'linear', C=this_C).fit(X_train, y_train)\n", 156 | " print('Linear SVM with C = {}'.format(this_C))\n", 157 | " print('Accuracy of Linear SVM classifier on training set: {:.2f}'\n", 158 | " .format(clf.score(X_train, y_train)))\n", 159 | " print('Accuracy of Linear SVM classifier on test set: {:.2f}\\n'\n", 160 | " .format(clf.score(X_test, y_test)))" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "### 3.2. LinearSVC with M classes generates M one vs rest classifiers" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 6, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "Coefficients:\n", 180 | " [[-2.14396025e-01 8.42666479e-02]\n", 181 | " [-1.17382373e-02 -2.65144462e-03]\n", 182 | " [-2.50881992e-02 2.65386917e-03]\n", 183 | " [-1.61677700e-02 -4.31542442e-04]\n", 184 | " [-1.29296866e-02 -2.01693724e-03]\n", 185 | " [-1.66081379e-02 -9.12590959e-04]\n", 186 | " [-1.30616076e-02 -7.62856148e-04]\n", 187 | " [-2.49290194e-02 4.42599109e-03]\n", 188 | " [-9.83381629e-03 -5.10230399e-03]\n", 189 | " [-1.68464031e-02 -8.35901198e-04]\n", 190 | " [-2.78919731e-02 5.84844881e-03]\n", 191 | " [-6.36812068e-03 -6.71373644e-03]\n", 192 | " [-4.65289397e-03 -7.14568659e-03]\n", 193 | " [-6.60199135e-03 -5.15677778e-03]\n", 194 | " [-1.33483761e-02 -1.89186669e-03]\n", 195 | " [-1.19880856e-02 1.23360392e-02]\n", 196 | " [-1.41249931e-02 -2.08341604e-03]\n", 197 | " [-3.17240951e-03 1.03903547e-02]\n", 198 | " [-7.43591448e-05 9.68438389e-03]\n", 199 | " [-1.51057191e-02 -4.54332614e-04]\n", 200 | " [-1.25707877e-02 -2.12566598e-03]\n", 201 | " [-3.60300645e-03 -8.41033278e-03]\n", 202 | " [-1.00827917e-02 -4.32172697e-03]\n", 203 | " [-5.60104665e-03 -7.46541361e-03]\n", 204 | " [-2.45725160e-02 2.27565777e-03]\n", 205 | " [-6.99547020e-03 -5.29453577e-03]\n", 206 | " [-7.50729974e-03 -5.09490567e-03]\n", 207 | " [-3.24504011e-03 -1.04632317e-02]\n", 208 | " [-1.28763717e-02 -9.87222966e-04]\n", 209 | " [-6.07320581e-03 -8.16852041e-03]\n", 210 | " [-1.14403763e-02 -3.56539990e-03]\n", 211 | " [-2.54676635e-03 -8.81117306e-03]\n", 212 | " [-1.41243331e-02 -1.37035804e-03]\n", 213 | " [-1.04894179e-02 -2.51157592e-03]\n", 214 | " [-2.34506717e-02 5.23119610e-03]\n", 215 | " [-6.58648642e-03 -4.80296672e-03]\n", 216 | " [-6.19437987e-03 -5.47020849e-03]\n", 217 | " [-1.29508907e-02 -6.66307160e-04]\n", 218 | " [-1.46862189e-02 -1.36133299e-03]\n", 219 | " [-5.42030942e-03 1.08281515e-02]\n", 220 | " [-8.26091254e-03 -5.53571258e-03]\n", 221 | " [-7.66136228e-03 -4.89077984e-03]\n", 222 | " [-5.49772048e-03 -7.01275571e-03]\n", 223 | " [-6.61433720e-03 -6.73789908e-03]\n", 224 | " [-4.42695314e-03 -7.30171625e-03]\n", 225 | " [-3.04956992e-02 
7.10798933e-03]\n", 226 | " [-5.85170608e-02 1.86262800e-02]\n", 227 | " [-5.59840492e-03 -7.05964138e-03]\n", 228 | " [-6.99770750e-03 -4.44528750e-03]\n", 229 | " [-1.27172386e-02 -4.28149541e-04]\n", 230 | " [-1.29193902e-02 -2.34219412e-03]\n", 231 | " [-1.01739989e-02 -4.71094162e-03]\n", 232 | " [-3.04443035e-03 -6.63552890e-03]\n", 233 | " [-1.35599853e-03 -9.95826962e-03]\n", 234 | " [ 1.05448520e-03 -1.52497983e-02]\n", 235 | " [-6.24920490e-03 -7.41309348e-03]\n", 236 | " [-4.42914032e-03 -7.86928949e-03]\n", 237 | " [-7.66454752e-03 -4.25159517e-03]\n", 238 | " [-7.31605833e-03 -3.53749412e-03]\n", 239 | " [-7.00465766e-03 -5.06211133e-03]\n", 240 | " [-8.71377324e-03 -3.38372357e-03]\n", 241 | " [-1.11287485e-02 -2.99285388e-03]\n", 242 | " [-6.59286070e-03 -6.49724262e-03]\n", 243 | " [-5.36317851e-03 -6.93586317e-03]\n", 244 | " [-5.60898206e-03 -6.52954538e-03]\n", 245 | " [-5.90987143e-03 -6.12003125e-03]\n", 246 | " [-6.09953582e-03 -6.71876601e-03]\n", 247 | " [-1.57286444e-01 5.42901625e-02]\n", 248 | " [ 3.19573943e-03 -1.85483445e-02]\n", 249 | " [-7.71413066e-03 -6.01581538e-03]\n", 250 | " [-6.51019805e-03 -5.64191123e-03]\n", 251 | " [-8.83430764e-03 -4.56927425e-03]\n", 252 | " [-6.69844756e-04 -1.02585488e-02]\n", 253 | " [-2.13128021e-03 -9.67105366e-03]\n", 254 | " [-9.05475948e-03 -5.19140124e-03]\n", 255 | " [ 8.61079610e-04 -1.67023602e-02]\n", 256 | " [-8.66844150e-03 -3.75916976e-03]\n", 257 | " [-6.89763057e-03 -4.70872459e-03]\n", 258 | " [-9.97967547e-03 -4.21565130e-03]\n", 259 | " [-3.02611832e-04 -1.26405448e-02]\n", 260 | " [-2.83870866e-03 -1.07118996e-02]\n", 261 | " [-7.40407867e-03 -5.55672738e-03]\n", 262 | " [-1.19757403e-02 -2.25919117e-03]\n", 263 | " [-5.13761390e-03 -6.85032416e-03]\n", 264 | " [-1.02962746e-02 -3.43697459e-03]\n", 265 | " [-3.62255103e-03 -8.55366780e-03]\n", 266 | " [-1.51061530e-03 -1.23978298e-02]\n", 267 | " [ 7.86857423e-04 -1.30321535e-02]\n", 268 | " [-6.83083311e-04 -1.40166734e-02]\n", 269 | " [-2.79258959e-03 -1.04487412e-02]\n", 270 | " [-1.16856500e-02 -1.52966488e-03]\n", 271 | " [-2.55972446e-03 -1.09740127e-02]\n", 272 | " [ 4.39861013e-03 -1.86212443e-02]\n", 273 | " [-5.22265914e-03 -6.83751272e-03]\n", 274 | " [ 4.64324265e-04 -1.39573899e-02]\n", 275 | " [-1.02126926e-02 -4.12281740e-03]\n", 276 | " [-1.10669241e-02 -2.18740278e-03]\n", 277 | " [-2.78613015e-03 8.09413145e-03]\n", 278 | " [ 1.71666488e-03 -1.71445757e-02]\n", 279 | " [-1.33500081e-03 -1.08136856e-02]\n", 280 | " [-5.33720757e-03 -6.83340622e-03]\n", 281 | " [-1.15343213e-02 -2.61764653e-03]\n", 282 | " [ 8.80104659e-03 -3.24897962e-02]\n", 283 | " [-3.03489303e-03 -9.31547554e-03]\n", 284 | " [-6.61707597e-03 -6.71709528e-03]\n", 285 | " [-2.66395778e-02 7.56858869e-03]\n", 286 | " [-1.09366213e-03 -9.29695361e-03]\n", 287 | " [-4.85258990e-04 -1.39407479e-02]\n", 288 | " [ 4.12234642e-03 -1.82806261e-02]\n", 289 | " [-6.80462932e-03 -4.66299049e-03]\n", 290 | " [-9.88867572e-03 -2.68646108e-03]\n", 291 | " [-1.00024630e-02 -3.83673441e-03]\n", 292 | " [-1.22635736e-03 -1.08491021e-02]\n", 293 | " [-1.10121437e-02 -3.11299716e-03]\n", 294 | " [-1.80993941e-02 3.34098339e-04]\n", 295 | " [-3.85633237e-03 -8.83564642e-03]\n", 296 | " [-8.63796232e-03 -5.50291310e-03]\n", 297 | " [-4.34400676e-03 -6.36634295e-03]\n", 298 | " [-2.13682496e-03 -1.15543806e-02]\n", 299 | " [-8.06640570e-03 -4.06879875e-03]\n", 300 | " [ 3.51706540e-03 -2.07316261e-02]\n", 301 | " [ 3.06854737e-03 -1.77404008e-02]\n", 302 | " [ 3.32589323e-02 
-6.88808017e-02]\n", 303 | " [ 1.27295521e-02 -3.30280315e-02]\n", 304 | " [-1.76104178e-03 -8.82591129e-03]\n", 305 | " [-5.37159351e-03 -6.20607657e-03]\n", 306 | " [ 1.53382641e-03 -1.64067923e-02]\n", 307 | " [-3.34038363e-03 -8.55443995e-03]\n", 308 | " [-2.61047796e-03 -1.11025386e-02]\n", 309 | " [-1.53337737e-03 -1.00464040e-02]\n", 310 | " [ 1.22716544e-02 -5.23036614e-03]\n", 311 | " [-1.93853798e-03 -1.21084622e-02]\n", 312 | " [-7.81190224e-04 -1.35544584e-02]\n", 313 | " [ 1.90665924e-05 -1.30879064e-02]\n", 314 | " [ 1.04533406e-02 -2.96454993e-02]\n", 315 | " [-1.81622660e-03 -1.06519005e-02]\n", 316 | " [-2.27613192e-03 -1.07474677e-02]\n", 317 | " [ 1.20922516e-02 -3.97719005e-02]\n", 318 | " [-2.56049890e-04 -1.43044696e-02]\n", 319 | " [ 8.60548143e-04 -1.58779161e-02]]\n", 320 | "Intercepts:\n", 321 | " [ 0.01254532 -0.01454626 -0.01742132 -0.00252403 -0.0219357 -0.02337793\n", 322 | " -0.01586994 -0.02287372 0.00462123 -0.02079588 -0.00366046 0.02279043\n", 323 | " -0.0233296 0.00046263 -0.02249522 -0.01843829 -0.01764555 -0.03075016\n", 324 | " -0.01540735 -0.02104688 -0.03006956 0.01809862 -0.04031767 -0.00578173\n", 325 | " 0.00217753 -0.02792996 -0.01019084 -0.02167765 -0.02163611 -0.01680886\n", 326 | " -0.01798087 -0.01957478 -0.021573 -0.00600343 -0.02241012 -0.0506601\n", 327 | " -0.009647 -0.02082834 0.01308944 -0.02199827 -0.03231453 -0.04597162\n", 328 | " -0.01635609 -0.01548489 -0.02525962 -0.02066032 -0.01925746 -0.00933718\n", 329 | " -0.05496637 -0.03550692 -0.04130853 0.00032973 -0.05029722 -0.03640648\n", 330 | " 0.00937706 -0.03975363 -0.02762816 -0.00993848 -0.09582902 -0.04863484\n", 331 | " -0.05534066 -0.0555024 -0.03867695 -0.05522551 -0.04797136 -0.05834683\n", 332 | " -0.03622433 -0.01635038 0.00950989 -0.05647114 -0.03288603 -0.03548065\n", 333 | " -0.01935263 -0.04079287 -0.03721051 0.00887693 -0.05078675 -0.04107982\n", 334 | " -0.05497602 0.00954188 -0.041121 -0.05613876 -0.05085163 -0.05580706\n", 335 | " -0.0478918 -0.05325284 -0.05366367 -0.0423538 -0.02429554 -0.06294562\n", 336 | " -0.04611254 -0.05036665 -0.01003085 -0.05061034 -0.04175053 -0.05127845\n", 337 | " -0.01984979 -0.05024378 -0.04189298 -0.0543486 -0.04276961 -0.04568472\n", 338 | " -0.01872168 -0.04196594 -0.07080196 -0.03628768 -0.09213109 -0.04748736\n", 339 | " -0.00951601 -0.07047381 -0.04788268 -0.04483294 -0.03499559 -0.04654704\n", 340 | " -0.04621714 -0.07035922 -0.05542719 -0.11588529 -0.05119463 -0.06921828\n", 341 | " -0.03411805 -0.03606885 0.01110691 -0.02331986 -0.0733811 -0.05888178\n", 342 | " -0.03363236 -0.03436679 -0.03638426 -0.05357524 -0.03252641 -0.03277087\n", 343 | " -0.04527914 -0.03301596 -0.03374925 -0.05100191 -0.06174497 -0.01300505\n", 344 | " -0.04467762 -0.04521454]\n" 345 | ] 346 | } 347 | ], 348 | "source": [ 349 | "from sklearn.svm import LinearSVC\n", 350 | "\n", 351 | "X_train, X_test, y_train, y_test = train_test_split(X_2, y_2, random_state = 0)\n", 352 | "\n", 353 | "clf = LinearSVC(C=5, random_state = 67).fit(X_train, y_train)\n", 354 | "print('Coefficients:\\n', clf.coef_)\n", 355 | "print('Intercepts:\\n', clf.intercept_)" 356 | ] 357 | } 358 | ], 359 | "metadata": { 360 | "kernelspec": { 361 | "display_name": "Python 3", 362 | "language": "python", 363 | "name": "python3" 364 | }, 365 | "language_info": { 366 | "codemirror_mode": { 367 | "name": "ipython", 368 | "version": 3 369 | }, 370 | "file_extension": ".py", 371 | "mimetype": "text/x-python", 372 | "name": "python", 373 | "nbconvert_exporter": "python", 374 | 
"pygments_lexer": "ipython3", 375 | "version": "3.6.4" 376 | } 377 | }, 378 | "nbformat": 4, 379 | "nbformat_minor": 2 380 | } 381 | -------------------------------------------------------------------------------- /Classification/logistic_regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Supervised Learning Algorithms: Logistic regression" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further instructions). None of the other sections needs to be adjusted. As a data input example, .csv file from IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. Libraries" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "*Run to import the required libraries.*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib notebook\n", 38 | "import numpy as np\n", 39 | "import pandas as pd\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "from sklearn.preprocessing import StandardScaler\n", 42 | "from sklearn.model_selection import train_test_split" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "## 2. Data Input and Variables" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "*Define the data input as well as the input (X) and target (y) variables and run the code. Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 6, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "### Data Input\n", 66 | "# df = \n", 67 | "\n", 68 | "### Defining Variables \n", 69 | "# X = \n", 70 | "# y = \n", 71 | "\n", 72 | "### Data Input Example \n", 73 | "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n", 74 | "\n", 75 | "X = df[['horsepower']]\n", 76 | "y = df['price']" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## 3. 
The Model" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "*Run to build the model.*" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 17, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Accuracy of Logistic regression classifier on training set: 0.01\n", 103 | "Accuracy of Logistic regression classifier on test set: 0.00\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "from sklearn.linear_model import LogisticRegression\n", 109 | "\n", 110 | "X_train, X_test, y_train, y_test = train_test_split(X, y,\n", 111 | " random_state = 0)\n", 112 | "\n", 113 | "clf = LogisticRegression().fit(X_train, y_train)\n", 114 | "\n", 115 | "print('Accuracy of Logistic regression classifier on training set: {:.2f}'\n", 116 | " .format(clf.score(X_train, y_train)))\n", 117 | "print('Accuracy of Logistic regression classifier on test set: {:.2f}'\n", 118 | " .format(clf.score(X_test, y_test)))" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "### 3.1. Logistic regression regularization: C parameter" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 22, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "Logistic Regression with C = 0.1\n", 138 | "Accuracy of Logistic regression classifier on training set: 0.01\n", 139 | "Accuracy of Logistic regression classifier on test set: 0.00\n", 140 | "\n", 141 | "Logistic Regression with C = 1\n", 142 | "Accuracy of Logistic regression classifier on training set: 0.01\n", 143 | "Accuracy of Logistic regression classifier on test set: 0.00\n", 144 | "\n", 145 | "Logistic Regression with C = 100\n", 146 | "Accuracy of Logistic regression classifier on training set: 0.07\n", 147 | "Accuracy of Logistic regression classifier on test set: 0.00\n", 148 | "\n" 149 | ] 150 | } 151 | ], 152 | "source": [ 153 | "for this_C in [0.1, 1, 100]:\n", 154 | " clf = LogisticRegression(C=this_C).fit(X_train, y_train)\n", 155 | " print('Logistic Regression with C = {}'.format(this_C))\n", 156 | " print('Accuracy of Logistic regression classifier on training set: {:.2f}'\n", 157 | " .format(clf.score(X_train, y_train)))\n", 158 | " print('Accuracy of Logistic regression classifier on test set: {:.2f}\\n'\n", 159 | " .format(clf.score(X_test, y_test)))" 160 | ] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "Python 3", 166 | "language": "python", 167 | "name": "python3" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.6.4" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 2 184 | } 185 | -------------------------------------------------------------------------------- /Classification/mlp_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "mlp_pytorch.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | 
"cells": [ 18 | { 19 | "metadata": { 20 | "collapsed": true, 21 | "id": "jqZrGQOG2ng-", 22 | "colab_type": "text" 23 | }, 24 | "cell_type": "markdown", 25 | "source": [ 26 | "# MultiLayer Perceptron" 27 | ] 28 | }, 29 | { 30 | "metadata": { 31 | "id": "EB7M0gMP2ng_", 32 | "colab_type": "text" 33 | }, 34 | "cell_type": "markdown", 35 | "source": [ 36 | "## 1. Libraries\n", 37 | "*Installing and importing necessary packages*\n", 38 | "\n", 39 | "*Working with **Python 3.6** and **PyTorch 1.0.1** *" 40 | ] 41 | }, 42 | { 43 | "metadata": { 44 | "id": "Xdemin4U2nhA", 45 | "colab_type": "code", 46 | "colab": {} 47 | }, 48 | "cell_type": "code", 49 | "source": [ 50 | "import sys\n", 51 | "import os\n", 52 | "# !{sys.executable} -m pip install http://download.pytorch.org/whl/cu80/torch-0.4.0-cp36-cp36m-linux_x86_64.whl\n", 53 | "# !{sys.executable} -m pip install torch torchvision matplotlib\n", 54 | "!{sys.executable} -m pip install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl\n", 55 | "!{sys.executable} -m pip install torchvision matplotlib\n", 56 | "\n", 57 | "import torch\n", 58 | "import torch.nn as nn\n", 59 | "import torchvision.datasets as datasets\n", 60 | "import torchvision.transforms as transforms\n", 61 | "import torch.nn.functional as F\n", 62 | "from torch.autograd import Variable\n", 63 | "\n", 64 | "%matplotlib inline\n", 65 | "import matplotlib\n", 66 | "import matplotlib.pyplot as plt\n", 67 | "import numpy as np\n", 68 | "\n", 69 | "from timeit import default_timer as timer\n", 70 | "\n", 71 | "print(\"PyTorch version: {}\".format(torch.__version__))\n", 72 | "cudnn_enabled = torch.backends.cudnn.enabled\n", 73 | "print(\"CuDNN enabled\" if cudnn_enabled else \"CuDNN disabled\")" 74 | ], 75 | "execution_count": 0, 76 | "outputs": [] 77 | }, 78 | { 79 | "metadata": { 80 | "id": "ateauqQT2nhF", 81 | "colab_type": "text" 82 | }, 83 | "cell_type": "markdown", 84 | "source": [ 85 | "## 2. 
Variables\n", 86 | "*Indicate the root directory where the data must be downloaded, the directory where the results should be saved and the type of RNN (conventional, LSTM, GRU) and its respective hyper-parameters*" 87 | ] 88 | }, 89 | { 90 | "metadata": { 91 | "id": "Jt3gMpMY2nhG", 92 | "colab_type": "code", 93 | "outputId": "cd3b3fa6-02db-4cc7-d198-7cc9d10e454a", 94 | "colab": { 95 | "base_uri": "https://localhost:8080/", 96 | "height": 34 97 | } 98 | }, 99 | "cell_type": "code", 100 | "source": [ 101 | "# Make reproducible run\n", 102 | "torch.manual_seed(1)\n", 103 | "\n", 104 | "# Settable parameters\n", 105 | "params = {'root': './data/',\n", 106 | " 'results_dir': './results/',\n", 107 | " 'hidden_size': [1024, 2048],\n", 108 | " 'input_size': 784, # MNIST data input (img shape: 28*28)\n", 109 | " 'lr': 1e-3,\n", 110 | " 'weight_decay': 1e-10, # 5e-4, # 1e-10,\n", 111 | " 'momentum': 0.9,\n", 112 | " 'num_classes': 10, # class 0-9\n", 113 | " 'batch_size': 128,\n", 114 | " 'optim_type': 'Adam', # Options = [Adam, SGD, RMSprop]\n", 115 | " 'criterion_type': 'CrossEntropyLoss', # Options = [L1Loss, SmoothL1Loss, NLLLoss, CrossEntropyLoss]\n", 116 | " 'epochs': 15,\n", 117 | " 'save_step': 200,\n", 118 | " 'use_cuda': True,\n", 119 | " }\n", 120 | "\n", 121 | "# GPU usage\n", 122 | "print(\"GPU: {}, number: {}\".format(torch.cuda.is_available(), torch.cuda.device_count()))\n", 123 | "device = torch.device('cuda') if params['use_cuda'] and torch.cuda.is_available() else torch.device('cpu')\n", 124 | "\n", 125 | "# Ensure results directory exists\n", 126 | "if not os.path.exists(params['results_dir']):\n", 127 | " os.mkdir(params['results_dir'])" 128 | ], 129 | "execution_count": 0, 130 | "outputs": [ 131 | { 132 | "output_type": "stream", 133 | "text": [ 134 | "GPU: True, number: 1\n" 135 | ], 136 | "name": "stdout" 137 | } 138 | ] 139 | }, 140 | { 141 | "metadata": { 142 | "id": "BamOukfB2nhL", 143 | "colab_type": "text" 144 | }, 145 | "cell_type": "markdown", 146 | "source": [ 147 | "## 3. 
Dataset\n", 148 | "\n", 149 | "*Normalizing between (0.1307, 0.3081): global mean and standard deviation of the MNIST dataset*" 150 | ] 151 | }, 152 | { 153 | "metadata": { 154 | "id": "jpbW8mtb2nhN", 155 | "colab_type": "code", 156 | "outputId": "9d5d2d71-86ef-4184-d1d5-55c13b1b8537", 157 | "colab": { 158 | "base_uri": "https://localhost:8080/", 159 | "height": 34 160 | } 161 | }, 162 | "cell_type": "code", 163 | "source": [ 164 | "# Get train and test datasets\n", 165 | "# trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])\n", 166 | "trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n", 167 | "# trans = transforms.Compose([transforms.ToTensor()])\n", 168 | "mnist_train = datasets.MNIST(\n", 169 | " root=params['root'], # directory where the data is or where it will be saved\n", 170 | " train=True, # train dataset\n", 171 | " download=True, # download if you don't have it\n", 172 | " transform=trans) # converts PIL.image or np.ndarray to torch.FloatTensor of shape (C, H, W) and normalizes from (0.0, 1.0)\n", 173 | "mnist_test = datasets.MNIST(root=params['root'], train=False, download=True, transform=trans) # transforms.ToTensor()\n", 174 | "print(\"MNIST Train {}, Test {}\".format(len(mnist_train), len(mnist_test)))\n", 175 | "\n", 176 | "# Dataloader: mini-batch during training\n", 177 | "mnist_train_dataloader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=params['batch_size'], shuffle=True)\n", 178 | "mnist_test_dataloader = torch.utils.data.DataLoader(dataset=mnist_test, batch_size=params['batch_size'], shuffle=True)" 179 | ], 180 | "execution_count": 0, 181 | "outputs": [ 182 | { 183 | "output_type": "stream", 184 | "text": [ 185 | "MNIST Train 60000, Test 10000\n" 186 | ], 187 | "name": "stdout" 188 | } 189 | ] 190 | }, 191 | { 192 | "metadata": { 193 | "id": "QhKhpBRaU4b4", 194 | "colab_type": "text" 195 | }, 196 | "cell_type": "markdown", 197 | "source": [ 198 | "*Dataset examples*" 199 | ] 200 | }, 201 | { 202 | "metadata": { 203 | "id": "ikWHBC-2U6-B", 204 | "colab_type": "code", 205 | "outputId": "ee7d39d0-a513-4361-d3ea-83b720b2a377", 206 | "colab": { 207 | "base_uri": "https://localhost:8080/", 208 | "height": 140 209 | } 210 | }, 211 | "cell_type": "code", 212 | "source": [ 213 | "# Plot examples\n", 214 | "examples = enumerate(mnist_test_dataloader)\n", 215 | "batch_idx, (example_data, example_targets) = next(examples)\n", 216 | "\n", 217 | "fig, axes = plt.subplots(nrows=1, ncols=4)\n", 218 | "for i, ax in enumerate(axes.flat):\n", 219 | " ax.imshow(example_data[i][0]) \n", 220 | " ax.set_title('{}'.format(example_targets[i]))\n", 221 | " ax.set_xticks([])\n", 222 | " ax.set_yticks([])\n", 223 | " plt.tight_layout()" 224 | ], 225 | "execution_count": 0, 226 | "outputs": [ 227 | { 228 | "output_type": "display_data", 229 | "data": { 230 | "image/png": 
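The (0.1307, 0.3081) pair passed to transforms.Normalize is the mean and standard deviation of the raw MNIST training pixels. A small sketch for recomputing them, assuming torchvision can download MNIST into the same './data/' root used above:

# Hedged sketch: recomputing the normalization constants used above.
# Downloads MNIST into ./data/ if it is not already there.
import torch
from torchvision import datasets, transforms

raw_train = datasets.MNIST(root='./data/', train=True, download=True,
                           transform=transforms.ToTensor())
pixels = torch.stack([img for img, _ in raw_train])   # (60000, 1, 28, 28), values in [0, 1]
print('mean {:.4f}, std {:.4f}'.format(pixels.mean().item(), pixels.std().item()))
# Prints approximately: mean 0.1307, std 0.3081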
"iVBORw0KGgoAAAANSUhEUgAAAagAAAB7CAYAAAAhbxT1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADzhJREFUeJzt3X+wlVPfx/HPKf2gRHFOftSIcBXR\ndJD4A49GzfO4yaj7jsFDJPlZzTMaJMN4ECNTmUJ+FXfRmTQ43OE2UuE4JkVp6kqlntL4TT+oker5\nA8t3XZ192uecffa19j7v11/f1Vp778V19lnnWmtd31WyZ88eAQAQmmZpdwAAgJowQAEAgsQABQAI\nEgMUACBIDFAAgCAxQAEAgrRf2h3ItSiKBkn63+Q/S2oXx/HWFLqEGkRRdISk6ZKOk7RF0k1xHC9I\nt1f4E9+j8EVR1EXS55LWmH/+KI7j/06nR7lXdANUHMezJc3+sxxF0T8kDeZLFZzpkubGcXxeFEX/\nIekmSQxQgeB7VDC+jOO4W9qdaCxFN0BZURS11u9/Bf5n2n3BX6Io6izpFEn/JUlxHM+TNC/VTiEj\nvkdIS1EPUJKukfR+HMdr9tkS+dRT0heSxkVR9DdJX0kaGcfxknS7hQz4HoWrXRRFL0vqJmmdpFFx\nHK9It0u5U7SbJKIoaibpfyQ9nHZfsJeDJZ0kaUEcx5Gkf0qaE0VRsf/BVHD4HgVtq6SZkkZKOkHS\nvyW9Ukzfo6IdoCSdIWlbHMfL0+4I9rJZ0tdxHL/yR/kpSR0kHZ9el5AB36NAxXH8fRzHN8VxvC6O\n492SHpHUUUX0PSrmAepvkv6VdidQo/WSDvzjr3PFcbxH0m5Ju1LtFWrC9yhQURS1j6Lo6MQ/N5e0\nM43+NIZiHqB6Siqaudgis0zSJklDJSmKor9L+lH+dlmEge9RuE6T9E4URaV/lK+V9H+S1qbXpdwq\n5gGqk35ffEdg/rhjGiRpaBRFa/X7Gsff4zj+Ld2eoQZ8jwIVx/FbkqZIej+KopWSBksaGMdx0cxE\nlHAeFAAgRMV8BwUAKGAMUACAIDFAAQCCxAAFAAjSvp44ZgdFekrq0JbrlJ5srxPXKD1co/DVeI24\ngwIABIkBCgAQJAYoAECQGKAAAEFigAIABIkBCgAQJAYoAECQGKAAAEFigAIABIkBCgAQJAYoAECQ\nGKAAAEFigAIABGlf2cyBepk7d65XjqLIxWVlZS5u27Zt3voEoLBwBwUACBIDFAAgSEzxoU4qKyu9\n8rx581y8ZcsWF0+fPt1r16JFCxc3b97cxcuWLfPadenSJRfdBJAF+51dtGiRV1dVVZXxdU899ZSL\n169f7+I2bdp47VavXu3ijh071rl/3EEBAILEAAUACBJTfKiTlStXeuXZs2e7+Msvv8z4urPOOsvF\nRx99tIsff/xxr924ceMa2sUm6ZJLLnHxm2++6dXZ63LAAQfkrU9ouM2bN3vl+fPnu3jdunUuXrJk\nScZ2e/bs8epKSkpcvHXrVhd///33DeqrJP38889e+ccff3QxU3wAgKLBAAUACBIDFAAgSDlZgxo7\ndqxXPuecc1zct2/fXHwEUvTVV1+5eMKECV7dN9984+Lrr7/exWPGjPHadejQwcX77ffXj91vv/2W\ns342ZcuXL3dxct3Cfgdr2zqMMMycOdPFkydP9uoa8/q1bt3aK/fs2dPFn332mVc3aNAgF+/YscPF\nvXr18to19LER7qAAAEFigAIABKkkuQUxodZK9yZm26IkHX744S62TycfccQRdepcaOxU14wZM7y6\nG2+80cXJW+V6Ktl3Eyer65StZLaIoUOHuvi7777z6uy03qRJk1zcrFmT+dsn2+uU02uU9OCDD7r4\n9ttvz9iue/fuLk5Ow/bv37/B/Tj44INdbDOGpCyIa1Qbuz27U6dOLk5O12ZiH92QpC+++MLFZ555\npld36aWXuri8vNzFJ554oteuXbt2WX12jtR4jZrMbxEAQGFhgAIABIkBCgAQpJxsM2/fvr1X/vrr\nr11cUVHhYrtOI/kZrkOxfft2FydTxtj1Fru9WpI2bdrk4vHjxzdS7/LDZiiX9l53suw6RhNadwpO\naWlpxjq7Jrp27VoXX3755Tnvh12btWsdqN2qVatcXNu6k11fvOaaa1xs160kadeuXS5O/p61j3mE\njt8oAIAgMUABAIKUk23mn3zyiVc+5ZRTamw3bNgwr3zHHXe42G5PlaQDDzwwm4/Omr1tXrFihVc3\nZ84cF7/22msujuM46/fv1q2bi+1T/Q2Q2jbz/fff3yv/+uuvLrbTnJI0ceJEFwe0rTifgtjCbJ/0\nP/nkk706O1VttxLbn3VJOumkk1ycPHjOsr8zbNYYSercubOLly5duo9e500Q16g21dXVLj7jjDNc\nnJw2t79re/To0fgdyx+2mQMACgcDFAAgSDmZ4tu9e7dXvvLKK138wgsv/PVmic+yU0ktW7b06lq1\nauXiY445xsW//PKL187uZLGSh2/ZJKdbtmyp8TUN8cADD7h49OjRuXjLvE7x/fDDDy62mUAkP6Hr\nxo0bvbpk23yxPwfJ3aF2qip5QJ892C85rVxPQUwfTZkyxcX33HOPV2d31eZacveg/ewbbrih0T63\njoK4RlYySXLXrl1dvGHDBheXlZV57Ww2myLDFB8AoHAwQAEAgsQABQAIUk4eKU5uhXz++edd3Lt3\nbxfff//9XjubjcFmcEhKZm2wRowYkXU//5Rct7LbbWubr7fbqKdNm+bVDRw4sM79CMm4ceNcXNsh\ngjlat2mwp59+2sXJjB/PPfdcxtdNnTrVxTZjxkEHHZTD3qUrebpAY0p+9/P52YXMZi+X/HUna/bs\n2fnoTrC4gwIABIkBCgAQpEbPGnjzzTe7+Pzzz/fqZs6c6WK7DVzyE1zarcyLFy/22mXKXnDrrbd6\nZfvZkydP9uoybTtPJlW0GSjs1vdisH79ehcfddRRXl1ya3la7FTvggULXJyclrWPKNiD/CRp5MiR\nLn711VddfMUVV+Ssn2no27eviz/44ING/Syb2NQ+noDs2Z/fJLt132bmkKSdO3e6OMRk27nGHRQA\nIEgMUACAIOUkk0Rjs2fY2HOXpMy7ymxSTMnPaJHtdM4zzzzjlW2GjDzIayYJu/uqe/fuXt3q1atd\nnDyrJplYNpeS51D17NnTxbU9UW/P4xo+fLhXd+yxx7rYTh2/8847Xrs6JCsOLktBY/v0009d3KtX\nL6/OTp8nEwunKLhrVFlZ6ZUHDBiQ1etsItlHH33UxeXl5bnpWHrIJAEAKBwMUACAIDFAAQCCVBCH\n09st3dlu706urVVUVGT1Ops92Ga+LnY2G/udd97p1dn/l8kt+bleg7KHIybXoDKtO9l5eUkaMmSI\ni+3jCpK/LnLXXXe5eNasWV67oUOHZtnjpuett95KuwsF
77TTTvPK9me4qqoq4+ts3amnnuri008/\n3Wt36KGHujh5qKTNvpN8lCY03EEBAILEAAUACFLY93cNkEwwm9zWaXXo0MHFCxcudLHNSFDsbrvt\nNhe//PLLXt3HH3/s4vvuu8+rswmA27Zt2+B+2ES1L730UsZ2o0aNcvHYsWO9Opv41U4ZStL8+fNr\nfL/GPNQvVDt27PDKNoHpIYcckvF1NrNIcip98ODBOepdcTvssMO88rvvvuvie++918UfffSR1+7D\nDz90sZ1ur66uzvhZr7/+uldet26dix955BEXh5iZgjsoAECQGKAAAEFigAIABKkgUh3Vxy233OKV\nbQqWdu3aeXX28LuLL764cTuWvbymOqpN165dXWznryXphBNOcPHcuXNd3L59e69dmzZt6vy5NsO6\n5D9i0K1bNxcvWbLEa9eyZcuM72nX0Ow2c7stV6rTmlRwaXSyNWzYMK9sTxewh3pedtllXruJEye6\n2KYQk/zDSmtbI7HrIjblVPI9bHqrBijYa5T07bffunjbtm0unjNnjtfuiSeecLFNVZZkH7tInvKQ\nZ6Q6AgAUDgYoAECQimqKz26bTU7j7dq1y8VPPvmkV3f11Vc3bsfqJ5gpPpvBoU+fPl7dhg0banzN\n8ccf75UvvPDCOn+unc6QpOnTp9fY7qKLLvLKNmN50tKlS11sMyLYKRGpTpkkCnb6KPn/d/To0S7O\n9P+6vuzWf8mfQkxmDLngggtcPHDgwFx8fMFeo/r66aefXHz33Xd7dZMmTXKxnXrfunVro/erFkzx\nAQAKBwMUACBIRZVJYsqUKS62U3pJzZs3z0d3ioZ96t0+yS75h6Y99NBDLl61apXX7uGHH26k3u2d\n+aI25513nouPPPJIF/fv3z+nfSoEpaWlXtlOfdvko9OmTfPa2QMLk+8xdepUF9uDL5MHQNrsLcg9\ne5Br586dM7br169fPrpTb9xBAQCCxAAFAAgSAxQAIEgFv83cbnO2mQZ2796d8TXJp64HDBiQ+441\nXDDbzLP19ttvu3jevHle3WOPPVbn9xs5cqRXfu+991y8aNEiF1977bVeO5vJ/uyzz/bqli1b5uLx\n48fXuU81aHJbmO11WblypVf3xhtv5Ls72QjiGq1Zs8bFyewn9lGJ+hwiuHPnTq/87LPPunj48OEZ\nX2czp48ZM6bOn5tDbDMHABQOBigAQJAKforPZo/o3bu3i5cvX57xNZ9//rlXtlODASm4Kb4mKojp\no3waMWKEi1esWOHV2ewcAUnlGm3atMkr20cZkr+frrrqKhfbxzVsFhfJT8C7YMECF8+aNctrZx8F\nSP6O79Kli4vt9WvduvVe/w15xBQfAKBwMEABAILEAAUACFLBpzpq0aKFi5MZzC07v2pT3ACom5KS\nv5YLmjXjb9xMNm/e7JVrWxe36aRmzJjh4mTKttoen8kkeXjoqFGjXJzyutM+8dMFAAgSAxQAIEgF\nP8W3ceNGF1dVVWVsZw9Ia9WqVaP2CWgqFi5c6JUXL17s4vLy8nx3JygdO3b0yp06dXKx/b2VlMwK\nkQ2bvVySzj33XBdfd911Xp3N6B867qAAAEFigAIABKngp/iyNWTIkLS7ABSd7du3e+Xq6moXN/Up\nvuShjBUVFS5+8cUXvbrKykoX9+jRw8XHHXec185mvbHTeMnpxOTOvULFHRQAIEgMUACAIDFAAQCC\nVPBrUGVlZS7u16+fi+3hdsl2AJBvffr0qTGWpAkTJuS7OwWBOygAQJAYoAAAQSr4Awstu+V127Zt\nXl1paWm+u9NQHFhYGJrcgYUFiGsUPg4sBAAUDgYoAECQGKAAAEEqqjWoIsMaVGFgfSN8XKPwsQYF\nACgcDFAAgCDta4oPAIBUcAcFAAgSAxQAIEgMUACAIDFAAQCCxAAFAAgSAxQAIEj/D49OSB5oVIuT\nAAAAAElFTkSuQmCC\n", 231 | "text/plain": [ 232 | "
" 233 | ] 234 | }, 235 | "metadata": { 236 | "tags": [] 237 | } 238 | } 239 | ] 240 | }, 241 | { 242 | "metadata": { 243 | "id": "dWCCngwF2nhS", 244 | "colab_type": "text" 245 | }, 246 | "cell_type": "markdown", 247 | "source": [ 248 | "## 4. The Model: MLP\n", 249 | "$y_t = \\sigma(W x_t)$" 250 | ] 251 | }, 252 | { 253 | "metadata": { 254 | "id": "N2UAizM92nhT", 255 | "colab_type": "code", 256 | "colab": {} 257 | }, 258 | "cell_type": "code", 259 | "source": [ 260 | "class MLP(nn.Module):\n", 261 | " def __init__(self):\n", 262 | " super(MLP, self).__init__()\n", 263 | " self.input_size = params['input_size']\n", 264 | " self.hidden_size = params['hidden_size']\n", 265 | " self.mlp = nn.Sequential(\n", 266 | " nn.Linear(self.input_size, self.hidden_size[0]),\n", 267 | " nn.Linear(self.hidden_size[0], self.hidden_size[1]),\n", 268 | " nn.Linear(self.hidden_size[1], params['num_classes'])\n", 269 | " )\n", 270 | " self.softmax = nn.LogSoftmax() # nn.ReLU() # nn.LogSoftmax() # Softmax()\n", 271 | " \n", 272 | " def forward(self, x): \n", 273 | " out = self.mlp(x)\n", 274 | " out = self.softmax(out)\n", 275 | " return out" 276 | ], 277 | "execution_count": 0, 278 | "outputs": [] 279 | }, 280 | { 281 | "metadata": { 282 | "id": "cw-AqMCfOgPN", 283 | "colab_type": "text" 284 | }, 285 | "cell_type": "markdown", 286 | "source": [ 287 | "*Instantiate model and optimizer*" 288 | ] 289 | }, 290 | { 291 | "metadata": { 292 | "id": "9T6W1JeOOlTW", 293 | "colab_type": "code", 294 | "outputId": "116e71fc-f672-4a88-8b4a-15debb207249", 295 | "colab": { 296 | "base_uri": "https://localhost:8080/", 297 | "height": 34 298 | } 299 | }, 300 | "cell_type": "code", 301 | "source": [ 302 | "# Instantiate model\n", 303 | "model = MLP()\n", 304 | "\n", 305 | "# Transfer model to device (CPU or GPU according to your preference and what's available)\n", 306 | "model = model.to(device)\n", 307 | "\n", 308 | "# Loss criterion\n", 309 | "if 'CrossEntropyLoss' in params['criterion_type']:\n", 310 | " criterion = nn.CrossEntropyLoss()\n", 311 | "elif 'L1Loss' in params['criterion_type']:\n", 312 | " criterion = nn.L1Loss()\n", 313 | "elif 'SmoothL1Loss' in params['criterion_type']:\n", 314 | " criterion = nn.SmoothL1Loss()\n", 315 | "else: # NLLLoss\n", 316 | " criterion = nn.NLLLoss() \n", 317 | "\n", 318 | "# Optimizer\n", 319 | "if 'Adam' in params['optim_type']:\n", 320 | " optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])\n", 321 | "elif 'SGD' in params['optim_type']:\n", 322 | " optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'], momentum=params['momentum'])\n", 323 | "elif 'RMSprop' in params['optim_type']:\n", 324 | " optimizer = torch.optim.RMSprop(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'], momentum=params['momentum'])\n", 325 | "\n", 326 | "# New results dir based on model's parameters\n", 327 | "res_dir = params['results_dir'] + 'mlp_{}_{}_lr{}_weight{}_trainSize_{}_testSize_{}/'.\\\n", 328 | " format(params['criterion_type'], params['optim_type'], params['lr'], params['weight_decay'], len(mnist_train), len(mnist_test))\n", 329 | "\n", 330 | "if not os.path.exists(res_dir):\n", 331 | " os.mkdir(res_dir)\n", 332 | "\n", 333 | "print(\"res_dir: {}\".format(res_dir))\n", 334 | "log_file = open(res_dir + 'log.txt', 'w')" 335 | ], 336 | "execution_count": 0, 337 | "outputs": [ 338 | { 339 | "output_type": "stream", 340 | "text": [ 341 | "res_dir: 
./results/mlp_CrossEntropyLoss_Adam_lr0.001_weight1e-10_trainSize_60000_testSize_10000/\n" 342 | ], 343 | "name": "stdout" 344 | } 345 | ] 346 | }, 347 | { 348 | "metadata": { 349 | "id": "t7ZCKLry2nhV", 350 | "colab_type": "text" 351 | }, 352 | "cell_type": "markdown", 353 | "source": [ 354 | "## 5. Train" 355 | ] 356 | }, 357 | { 358 | "metadata": { 359 | "id": "qaLuhdcu2nhW", 360 | "colab_type": "code", 361 | "colab": {} 362 | }, 363 | "cell_type": "code", 364 | "source": [ 365 | "start_timer = timer()\n", 366 | "\n", 367 | "loss_arr = []\n", 368 | "train_acc_arr = []\n", 369 | "first_time = True\n", 370 | "total_num_steps = len(mnist_train_dataloader)\n", 371 | "\n", 372 | "# model.train()\n", 373 | "model.zero_grad()\n", 374 | "optimizer.zero_grad()\n", 375 | "for e in range(1, params['epochs']+1):\n", 376 | " for i, (img, label) in enumerate(mnist_train_dataloader):\n", 377 | " img = Variable(torch.squeeze(img)).to(device)\n", 378 | " img = img.view(-1, params['input_size'])\n", 379 | " label = Variable(label).to(device)\n", 380 | " \n", 381 | " # Forward\n", 382 | " out = model(img)\n", 383 | " loss = criterion(out, label)\n", 384 | " \n", 385 | " # Backward\n", 386 | " optimizer.zero_grad()\n", 387 | " loss.backward()\n", 388 | " optimizer.step()\n", 389 | " \n", 390 | " loss_arr.append(loss.item())\n", 391 | " \n", 392 | " if i % params['save_step'] == 0:\n", 393 | " # Train Accuracy\n", 394 | " _, predicted = torch.max(out.data, 1)\n", 395 | " total = label.size(0)\n", 396 | " correct = (predicted == label).sum().item()\n", 397 | " acc = 100 * correct / total\n", 398 | " train_acc_arr.append(acc)\n", 399 | " # Print update\n", 400 | " perc = 100 * ((e-1)*total_num_steps + (i+1))/float(params['epochs'] * total_num_steps)\n", 401 | " str_res = \"Completed {:.2f}%: Epoch/step [{}/{} - {}/{}], loss {:.4f}, acc {:.2f}, best acc {:.2f}\".format(perc, e, params['epochs'], i+1, total_num_steps, loss.item(), acc, max(train_acc_arr))\n", 402 | " print(str_res)\n", 403 | " # Save log\n", 404 | " log_file.write(str_res)\n", 405 | " \n", 406 | "# Save model checkpoint\n", 407 | "torch.save(model.state_dict(), res_dir + 'model.ckpt')\n", 408 | "plt.show()\n", 409 | "log_file.close()\n", 410 | "\n", 411 | "end_timer = timer() - start_timer\n", 412 | "print(\"Model took {:.4f} mins ({:.4f} hrs) to finish training with best train accuracy of {:.4f}%\".format(end_timer/60, end_timer/3600, max(train_acc_arr)))" 413 | ], 414 | "execution_count": 0, 415 | "outputs": [] 416 | }, 417 | { 418 | "metadata": { 419 | "id": "h9iBENdtaX_o", 420 | "colab_type": "text" 421 | }, 422 | "cell_type": "markdown", 423 | "source": [ 424 | "*Plot training loss curve*" 425 | ] 426 | }, 427 | { 428 | "metadata": { 429 | "id": "zMRvLLKPX63B", 430 | "colab_type": "code", 431 | "outputId": "56fe4381-c4a6-4b18-fb5c-80d24fa50b87", 432 | "colab": { 433 | "base_uri": "https://localhost:8080/", 434 | "height": 265 435 | } 436 | }, 437 | "cell_type": "code", 438 | "source": [ 439 | "# Save training loss\n", 440 | "plt.plot(loss_arr)\n", 441 | "# plt.semilogy(range(len(loss_arr)), loss_arr)\n", 442 | "plt.savefig(res_dir + 'loss.png')" 443 | ], 444 | "execution_count": 0, 445 | "outputs": [ 446 | { 447 | "output_type": "display_data", 448 | "data": { 449 | "image/png": 
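The training loop above reports only training-set accuracy. A follow-up cell for held-out accuracy could look like the sketch below; it is an assumption about how one might extend the notebook, reusing the model, mnist_test_dataloader, device and params names defined in earlier cells, and should be run after training.

# Hedged sketch: accuracy of the trained MLP on the MNIST test set.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for img, label in mnist_test_dataloader:
        img = img.view(-1, params['input_size']).to(device)   # flatten 28x28 images
        label = label.to(device)
        out = model(img)                  # log-probabilities from the final LogSoftmax
        _, predicted = torch.max(out, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()
print('Test accuracy: {:.2f}%'.format(100 * correct / total))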
"iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XeYU1X6B/BvpjDDVIZhmKE34dCr\nSBMY7CiWVXR3dddFWfvaFXX1x1rWsrgsirKgiA3d1VUBGygKIiC9M5SD9KEMM8P0wtT8/kiZ3OQm\nuckkkxz4fp7HR3Jzk/tmcvPec99z7rkms9kMIiJSV0SoAyAiosZhIiciUhwTORGR4pjIiYgUx0RO\nRKS4qKbeYF5eqd/DZFJS4lBYWBHIcIJKpXgZa3CoFCugVrznWqxpaYkmd88p1SKPiooMdQg+USle\nxhocKsUKqBUvY22gVCInIiJXTORERIpjIiciUhwTORGR4pjIiYgUx0RORKQ4JnIiIsUpk8hPF5/B\nB9/uRlV1XahDISIKK8ok8k0yF58v/xX7jhWFOhQiorCiTCKvt94Ao66eN8IgInKkTCI3we00A0RE\n5zRlErkdG+RERBrqJXIiItJQLpGb2SQnItJQJpGbbCVy5nEiIg11EnmoAyAiClPKJHIbNsiJiLTU\nSeQmtsmJiPSok8itzGySExFpKJPI2R4nItKnTCJvwCY5EZEjdRK5tUnO0goRkZYyiZylFSIifcok\nciIi0qdMIjdx+CERkS5lErkNS+RERFrqJXL2dhIRaSiTyFlZISLSp0wiJyIifUzkRESKUyaRs7JC\nRKQvyshKQohpAEZb139ZSrnA4blLALwEoA7AYinlC8EIlIiI9HltkQshxgHoK6UcAeAKAK85rTIT\nwA0ARgG4TAjRO+BROuCgFSIiLSOllZUAbrT+uwhAvBAiEgCEEF0BFEgps6WU9QAWA7g4KJFy2AoR\nkS6vpRUpZR2AcuvDybCUT+qsjzMA5Dmsngugm6f3S0mJQ1RUpM+BJibEAACSkmKRlpbo8+tDhbEG\nB2MNHpXiZawWhmrkACCEuBaWRH6Zh9W8NpsLCyuMblKjtKwKAFBcUom8vFK/3qOppaUlMtYgYKzB\no1K851qsng4ERjs7LwfwNIArpJTFDk+dgKVVbtPOuoyIiJqIkc7OZACvApggpSxwfE5KeRhAkhCi\nsxAiCsAEAEuDESgr5ERE+oy0yH8LoBWA/wkhbMuWA9gppVwI4B4A/7Uu/1RKuS/gURIRkVtGOjvf\nBvC2h+dXAhgRyKA84vBDIiINZa7sZG2FiEifOomciIh0KZfIWVkhItJSJpGzskJEpE+ZRE5ERPrU\nS+SsrRARaSiTyE2cNIuISJcyiZyIiPQpl8jNrK0QEWkol8iJiEiLiZyISHHKJXLe6o2ISEuZRM4x\nK0RE+pRJ5EREpI+JnIhIceokctZWiIh0qZPIiYhIl3KJnKNWiIi0lEnkJtZWiIh0KZPIiYhIn3KJ\nnHOtEBFpKZPIOYstEZE+ZRI5ERHpUy+Rs7JCRKShXiInIiINJnIiIsUpl8hZWSEi0lImkXPUChGR\nPmUSORER6WMiJyJSnDKJnHOtEBHpUyaRExGRPuUSuZnz2BIRaaiTyFlZISLSpU4iJyIiXcolchZW\niIi0ooysJIToC+BLADOklG86PXcYQDaAOuuiW6SUxwMYIwBWVoiI3PGayIUQ8QDeALDMw2rjpZRl\nAYuKiIgMM1JaqQJwJYATQY7FGNZWiIg0vLbIpZS1AGqFEJ5WmyOE6AxgNYCnpJRu021KShyioiJ9\njRNJScUAgITEWKSlJfr8+lBhrMHBWINHpXgZq4WhGrkXUwF8B6AAwCIANwD43N3KhYUVfm2kpLQS\nAFBWegZ5eaV+vUdTS0tLZKxBwFiDR6V4z7VYPR0IGp3IpZQf2v4thFgMoB88JPLGYmWFiEirUcMP\nhRDJQojvhRDNrIvGAshqfFiuONcKEZE+I6NWhgCYDqAzgBohxEQAXwE4JKVcaG2FrxNCVALYiiC2\nxomIyJWRzs7NADI9PP86gNcDGBMREflAuSs7wUmziIg0lEnkvNUbEZE+ZRI5ERHpUy6Rs7BCRKSl\nXCInIiItJnIiIsUpl8g5aIWISEuZRG7isBUiIl3KJHIiItLHRE5EpDhlEjkLK0RE+pRJ5EREpE+5\nRG7msBUiIg3lEjkREWkxkRMRKU65RM7CChGRljKJnNcDERHpUyaRExGRPvUSOWsrREQaCiVy1laI\niPQolMiJiEiPcomclRUiIi1lEjlHrRAR6VMmkRMRkT71EjnnWiEi0lAmkbOyQkSkT5lETkRE+pRL\n5CysEBFpqZPIWVshItKlTiK3+nT5fhw9VRrqMIiIwoZyiRwAnntvY6hDICIKG8okcpNDbYV1ciKi\nBsokciIi0sdETkSkOHUSOUetEBHpUieRExGRrigjKwkh+gL4EsAMKeWbTs9dAuAlAHUAFkspXwh4\nlERE5JbXFrkQIh7AGwCWuVllJoAbAIwCcJkQonfgwmvAygoRkT4jpZUqAFcCOOH8hBCiK4ACKWW2\nlLIewGIAFwc2RCIi8sRraUVKWQugVgih93QGgDyHx7kAunl6v5SUOERFRfoSIwAgOa9c8zgtLdHn\n9wgFVeIEGGuwqBQroFa8jNXCUI3cB14rIIWFFX69cUlJpeZxXl74X6aflpaoRJwAYw0WlWIF1Ir3\nXIvV04GgsaNWTsDSKrdpB50SDBERBU+jErmU8jCAJCFEZyFEFIAJAJYGIjAiIjLGa2lFCDEEwHQA\nnQHUCCEmAvgKwCEp5UIA9wD4r3X1T6WU+4IRqIl3XyYi0mWks3MzgEwPz68EMCKAMeliHici0qfM\nlZ2qtMhPni7H2qycUIdBROeQQI9aCRpVjjhPz10PAOjWPlmpoVFEpC5V8qMyLXKbM1W1oQ6BiM4R\nCiVy31+z+3ABnp67DgUlZwIfEBE1mcqqWtTX85Yy7iiUyH3P5DO/2IGTpyvw46ZjQYiIiJpCbV09\n7puxEs9/wFs8uqNQIg91BEQUClU1dQCAo6fKQhxJ+FIokTOTExHpUSiR+/9aM2/XTERnMWUSeQRb\n5ETnJDPbYV4pk8gbk8dNvC0FEZ3F1EnkTsn4i58PhCgS/9XW1ePdb/dg/7HiUIdCpAyejHunTiJ3\n+jK/XXvE8GvDpUa+fX8+Vu88iZc+2hzqUIjoLKJMIvenRh5uJZWa2vpQh0CkHNbIvVMmkftzeqXX\nEj95uhzzvtmN8jM1AYiKiCj0FErkxjL5xr25+GTZr9rXOrTMZ36xE79k5fhUmgkUNiyaRl19Pcxs\nxp3VThVWYOanW9kgs1IokRtbb/aiLCzdmI1Kh0mrHFvm5ZWWL952tZg/vlp9yOVgQeGhtq4ed0xb\ngTe+2BnqUChA9H77sxZk4YcNR/H1L4ebPJ5wpEwiD6ca
+aLVh7B0YzaKy6tx4DhHoISTM9WWA/S2\n/fkhjoQCRe/kqqyyGgBwppqzjAIKJXJ/auSNaXUb8cScNXhx/maUVFQbewHP9okoCBRK5K6ZfM/h\nghBE0qC6xjIKpeIMWwXhgrVxOhcplMhdl736ybamD4Sa3Ma9uZj79S4m6TCwdV8eVmw73qTb5AVB\n3imTyN3VyDfLXPywMRuApaMrFIzuZ+FyYZJRNbV1+H7DUZSUGywdBUBJRTVO5Jdrls1elIW1u04h\nt6jS6+vV+gur540FO/Hhd7JJtxlux+9lm4/h/tdWhtWZuDKJ3N3ww1kLs/Bf6wiSUF1wE2b7WcB8\nvyEbny7fj7e+2tVk23xo5mo88856/bvBnIV/6OP55fhpa9O2cMPd0o3ZOJJTGvTtbNqbi+mfbvOp\nAbh8yzF8/MM+lJ+pxd6jhUGMzjfK3HzZyOlVuJ96Gwlv6695aJeWgNYtmgc/IC9OW2+R59xCDqTD\nOSUoLKlC13bJyD7V8OOtN5sR4XSuI7OLUFVTd1bd1Pr/3rHcrLt7+2S0T0sI+vZyCytwIr8CA7u3\nCvq2/JFTUGEf2vvukxe5PJ9fXInv12ejIgD3xP33oiwAwN6jhejbJdXQaz5aus/+73BKNwolcu+Z\n3NMt/errzS6ljcLSKiQnNAubKXILS6vs45/1duKz0fPvbzK87vtL9gIAvp5+bbDCCRnbsMlAqDeb\n8e2awxgsWqNdq3jNc0++tQ4A8Nr9FyIpvlnAthkIhaVVeMHL7dzmfLkLB0+UBHS7/g5TDqeGo0Kl\nlca9/ok5a3HP9JX2x0dzSvHorF8w75s9bl/T1F9UZQBaGYFk+/jlZ2qQX+y9Ph0WDHxlB0+UnNVX\nBGYdPI2Fqw5h6rz1btcJ9tBcf3y5+iAqq1zjcvztF5c1XX+NvzbLPHz20/4m3aYyidxIq9lT4j1d\nckZTCztgPaqv3ZWju/62X/Mx+R8/GbrgJ1BXl4XJiYGL2jozpsxeG+owAiK3qBJ//3ATnn337L2R\nr63s4KkdEogmyrG8soD2S7k7o3b8HJ5+I009NYPzlmrr6nGqsAKzFu7EkvVHm7Rhpkwi95bk1mbl\n4L3Fe+2P6xz2CqPf7ZGcUqzNsiT2z1ZYjqhLrSNiPG7bzcHAUb2XIPKLKr2Wj8oqa7y+TyAF8sCy\n53ABHnh9FY7nBe4GukdySrFs8zGfXlNorfvb6v9NZcu+PBQ00TaXbwl+5+n+48WYOm8DZlvrzIHQ\n2N3tjmkr8I//bG3Ue+QWVeLF+Ztw9JT3zlbng8brn+/AU9bSld7zwaROIvfy/Nxvdmsuy/5xU0MC\n/nnbCUPbeO79jZj7zW63w4oWrTqI219Zbui9HH3x8wH8+R8/ub0CdOu+PEyZsxZfrHB/s4y/zFiJ\nB15fhTlf+jeCpOJMDXIKKnx6TSD3w/eW7EVZZQ0WrwvcZGXPvb8RH/+wD7mFDZ8rfKqWDY7nleHN\nBTvx9DvuSx1GlJ+pwbxvdnv9HpvixiW2USWBnArBSMPB2zr7sos0j/cfK/bpwsHPlu/HgeMlmPv1\nbq/rOv8+dh3Sbqcp90VlErmvThU21HR9rQfW1tfbvyTHYXBf+VlCsc20eOC4fidNlnUH2Lwvz+17\n2E6XN+3N1Sxfm5VjaEd9fPYa/PXtdaj28Lf4eOk+zP3a/6GGpwor8N36o15aIoGvH1XVGD+9D0X/\nVEmFdaK2RnZofrPmMH7JysEbX+xodEyl5dWNuu7Cn7O1yqpaL+UG/TdtzJnhSx9t9nrhoN51Esfz\ny/HaZ9s9vs5oi7uqpg7frT0c1FKLQok8NAXkzfvy9Mc0O6mvNyPfwAUrelbtMHbGoGfuN7sNXeFq\n60Sq9lDTXLblGNbuOmVouzW1dXjmnfWa0saz723E/37aj50HtQeWk6fLkV9sKSsY/VFW1dQZ7mD1\n5RS2MXk8FKMUKqtq8cvOk6iprbMfCGwzePrq5OmGYaQvzt+Ml+b7f6cqwxfBmc2YtWAnVmw9jvtm\nrMR9M1a6XXf9Hv19T1Mjd9ryyu0nXRo3eiqralHq5ox47jf6re8dB07b696LVh1EXb32t2N0b/hy\n9SHM+nw7/vPjPu8r+0mZRO7rUXn9bmMJSc/arBzN6auRDp33Fu/BlDmuHYL1mlq9/ldfWxe4BHEs\nr8zjuO9ZC3YiO7ehTp1bVIl92UU+zyJ3OKcUJ/LL8fEP++wHMHuicRoR8vTchpKC0a/xqbfWYcrs\ntYbOpj6wXmm450ghvlp9yL78VKFrCaIxyfjpuevxz08aV4N1y01YH/+wD/O+3WM5q7P+CPz9BI7f\nA2D5DoOttLIGm/fl4cPvPV8NeiK/3NgZi84O9G8Ddfr7ZqzEgzNXe39/J2Yz8PL8zfjql8NY59zI\n8fJFvPzRFpRV1uCk9fd4LDd412Mok8ibcqz3p8t9Hzr0S5Z+h+cuPyf2KnNodeX50NKfOm8DnvFQ\ni5XZRZpTxifnrMUrH2/BjU9961ecgM5nDMBxyfb5jYyvPnSyBLsOFeDV/27VXCX51FvrcCSn1Glu\nevfbW7LuiMfT35yCCuw+7Plqvl2HCjyOdFqx7Th+1pmrpPxMDcxmM8xmM16avxlLrH0JtmR7OKfU\nXhbz5Vi0Juskvllz2PgLYCnhTPvPlkYd9P769jq8+637ob3OPvfQP+Qo2FnA+RObTA2lMee+M29T\nbpzIL8dPW3zrjPeXMok8pllkQN6nzI/T0uraOpwu9m/EQbWb+q23Cx9s26utq8cTOi19Z+vcjJwp\nLK3Cq//VtiKN12pdd9QjOaV4f8ke1DqcpTiPtnHueNYwWUbo1NTWQTbiEmfnJDP9U/3y0nPvb9SW\nENz89j78bi8+W3EAi1Yd0l/BoOmfbsOL8zfjjmk/YbN0PeX/8DtpP4NwPNV//fMdmP+9xPH8cuw/\nXozPnBLbjgOnscZNY8GTd77ZgwUrDxpat+JMLT5f/isWrDyIvUeL3O67ADSnyFXVdTjmNBopp6AC\nq3ee9Dler3xo0O0PwL0CXvRQfjJynMspqMT2A6ct6wex+1OZRB5sjvVDZw/OXI3HZ6/x6f2efW8j\njua4vwLt0MlSlx52R/VmMzbuzUVpheuBx1aXdkxmbzv1stfW1ePk6XJMmb0Ge47oJ8yaWvcJvbqm\nTndHfeGDTVi5/SSWeRni9q2bVmBBSRWmzFmLu/75s99DxZ6Z8wv+Otf4CJDjDqUmxx/Ti/M32evw\nts7x/OJK1Neb8eOmbJfk5I1jCa6u3qwZDqvH+VR/xbYThvpjamrrsftwge66jWlFf7r8V3zwbcN+\ntHlfLorLqnTXdUynr3y8BVPnbcApH0dFOTI6+sVoGt91qMBtH8C2/fmGh606zvlSUKptzNXVm93W\n3W2MDE0OBGUu0Q+m6Z9uQ2JcdMDf94Nv92CoSLM/dv6NvfLxFreX4q/JysGyzcfQo32yy3Mf/7AP\nKYkxeHOB9nZ
mjjvVna+u8Brf/a+t0l2+50ihSyvexjaOfYvDCBu9H9cBN5dR5xY2/grR7b/6N+St\nrLJG03dy4HgJFqw8iDuv7mP/DGYz8OdpP9nXeffJi1BVXefxjPBYbhlSkmLw8BvaxOwppTZmUqiq\nmjr885Nt+ONlPTB2UDtN2bExbb58p7POd77Zg9SkGLx67yjXlR2+9CPWMde5RZVIbxnXiAi0CkrO\nYMZn23HdhV0bNmswkzv2Azmb+bl/o36+36C9puT9JXvx/hLgtQcuRFJcaKc7MJTIhRAzAAyHZT95\nUEq50eG5wwCyAdiad7dIKZWazs15/GegbNidgw27G47ItfXGh3vZWgz73IwJ/q9OD7ivnTnuRrC4\nS+Ju+VC4jIoKzUngydPl+OA76XIWZHL6h3PLcNfhAkz/ZBsmjOys+755RZWY+u4Gt9utrqnDFuk6\nrPS59xt/Zen8pfvwv58OYPajYxsWBvjs/XRJFX49VoTu7VtolutNZTvjf9sxtGdr3HNdX/uykkZc\nUv/d+qM4nleOWQu93381mBO7eZNfdCbkidzrr0oIMRZAdynlCACTAczUWW28lDLT+p9SSbwpZR0M\n5AGj6Ydjuqtp+zLpUHRkaIaRPj13vcdSljvrrSMVHDsMyypr7InD85BTM+Yv2YNlPnR42SYGM6qx\nc6as3ZVjL7G5a+2+/NEWw++30WkooKeDnDe+HJM8dfAHWzhMnmWkeXQxgEUAIKXcAyBFCJEU1KjO\nIU8a6MjUE4p5WdzVtL9eY7yDMDLCtxZ5XRPcLMRsNuPoKeP18EfetMyZ7qmPAbCM3V/0s7HRGDaO\nQwJPni73uaXpa4fa3K9321vXnnapQCarjXtzcff0Fd5HY+ls0sgsqP5ozNQXhl8ZxHxv5FeVAcDx\n3DDPuszRHCHEaiHEK0KIMJ36KTwZueuNnsb8rmoCnBzzis5gTZbrCIUdB/Lx0ExtHd7X3+Fj//at\nk9lX63ad8nlEkm3cf01tfVCPqM7jvo3wZ7+wDZ31lCT1ZiX01+xFWaiuqceqHZ5HteidyRg9oLg7\noDlf1GOzcOVB1JvN2OtmYICXjRly1EPdvrH86ex0/ranAvgOQAEsLfcbAHzu7sUpKXGIigrMUMJz\nWWMmfaqprUd1gEsz7+hMB/zaZ66dStHR4fXdmwHdC7lsDnnolExtlYiiM6GfDvbjZb9i0lV90CIx\nxu1l96mpnm9akZaWiJgY9+kgISkWyQkxhuIxeuOPAydKcPsry9HOhxtqnDztfWTM3uMlKHXzvVS4\nuUxg495cdMhI8uuGFcktmhv+zMG6KYqRRH4C2hZ4WwD2Q6mU8kPbv4UQiwH0g4dEXqhztR01vQXL\ngne5sCdNUSoJJE+zNZ7OL0VxUej352Ubs7Fi8zHMengMIiL0D9CPzFjh8T3y8kpRVOq+cZCfX4a/\n/vsXQ/Fs2GGsm8w2R1AgZ8QEgGnz3d+s5MF/rdBdnltYiTkGOlX1FBZWIC/BWGfn1t0n/b4TlKeD\ngJHSylIAEwFACDEYwAkpZan1cbIQ4nshhO1TjAUQuHktKWiMTM8bDOE657o/Qt/F1aCu3owvfj7o\ntrTibjioTX5xpdtJ3QBgyfqjhuv1L3xg/K5P55qp8/zv/PXEa4tcSrlGCLFZCLEGQD2A+4QQkwAU\nSykXWlvh64QQlQC2wkNrnChYnVWhcP9rqzB+WMdQh2H3w6ZsxPp5BbS3G4f4Ou/7uWSzzEOn9NDe\nR9bU1ENn8vJK/d7glNlrXC5aILX0aJ/sdmw8kapMJuMdzf7ejzctLdFtK0ipS/TfefrSUIdAjcQk\nTmejUA8lVyqRR0YqFS4RUZNgZiQiUhwTORGR4pjIiYgUx0RORKQ4JnIiIsUxkRMRKY6JnIhIccol\n8owA3kqKiOhsoFwi//ufh4U6BCKisKJcInc3TScR0blKuURORERaTORERIpjIiciUhwTORGR4pjI\niYgUd9Yl8vtv6Nfo95gwsnPjAyEiaiJnXSIf1D1Nd/llQzsYfo+2rXjRERGp46xL5AAQF+N6T+nf\nXdwd7dPiDb0+OtK/G9gSEYWCkon8rccycfkF7lvY//rLKN3lLRJj7P++bGgHXNivDXp1SsGM+y9E\nlMNt5AZ2T0WfLi0DFzARURC5Nl0VEB0Vgf5dU/H9hmzd55tF67eob7m0Bz5Yshd/vFygTaq2dd6z\nYwtkHSpAlzZJiIyIwKO/HYhNe3Px70VZAY+fiCiQlGyRA0DL5FgAQLPoCNwwtqvHda8fY3k+PSUO\nU24e7JLEAQA6V/4PFq719qmTzrf/e8yAtvjbpKFIcWjpn2siGzllwn2/6dvoGC4e3L7R70GkMmUT\neXpKHP7vT+fjX/dd6HG9GzO7GRuFYnZdFGHSJqk+XVqic0YSbr1cAABG9s1Ap4xETL9Pv5TjyQW9\nWvu0/pgBbXSXP/WHwbhtfE9D79E6pTl+d9F5LsvHDmyLlknGDka3XNoDk6/q5bD9IYZelzmwre7y\nIcK3v4Mex/l3hvTQ7+z21btPXoT0MJ9p88J++vsEha/enVOC8r7KJnIA6NImCXGxUejfrRUA4Gqd\nhK2Tnz0yuWlgTr6qFx6+cQAAIHNQO7z9eCZ6dGhhf37S+J4YP6yj/XGzKPd/2o6tE3BjpmtCBYDO\nGYl498mL8PbjmXjm1obW/6TxvdC6RXPNuu3T4tG9fQuMHqCfJN9+PNOedAf3SMOU3w9Cu7QEl/XG\nDGiLD/52Bd598iKX5+68ujdaWc9+AKBvl5YY1a8Nbr1cYNzgdujaNgl3XN3b7We1aeUUu6Onb9U/\nGNx6hfD6vgAQ4eNe/O6TF+Efd4/QfW7yVb1wp/XzPBCAoaw270wZh2dvG4onbh7k8py/Cbl3F9ek\n4LhPNqWE5tFun7v72j6Nfv/R/dtgaM/Wmr4sFT1y08CgvK+SNXJnHVonYM6jYzW18ZF9M7AmKwfd\n2iYFZBuJcdGalp/zDjXGmkyXrD8KAJjzWCaSWsRh3sId9mVXjeiEb9ceQceMRN1t3DTuPPswyajI\nCJeDivPjEX0yXN7jjYdG4/7XVtnfY1S/NhjRJ8Me+7G8ciMf1254nwwM75OBIzml2JddZG+lZg5q\n5zYumwkjOwEAYptF4dLz2+PzFQc0z9sOAN3aJtuXpSTGoLC0CgAwdkBbXNCzNU6XVOH9JXtw6GSp\n7nbMjkdrE9Cncwp2HS7UXfd3F3cHAMS46UcZ5ZBU631tBXgQEWFCx3T97330gDZYvfOk29deObwT\nFq87Yn98Yb82GD+8I47kuP49nrxlMADg9leWa5af1z4Z+48V+xO6i47pCTh6qkyzrGVSDC4a3A4r\nth7HRUPaY9GqQwCAv00aig7pCbiloga/HivChj25fm3zj5cLREVGYPonWzXfbUbLOOQUVPj0XjeN\nOw//+2m/X3FM+f0giI4t8NRb65BbVIkubRKRnhKHdbtP2dd56Mb+WJOV
o/tZgzV7q9qHNwfOHZy3\nXdkTr9w1HKKjb6cyZqcfry1xtUyK1Vnb1R0TeuM31pp8THSkJmFcM6ozJl/VCzdf0h3xzV2PoVcM\n66j5op0T5D3Xea8nm2DCC38ehudvv8C+zPE9650/oBv3XtcXz09ueI9OGYm41M1Y/L5dUtEsKgI3\nX9Idv3Uo3Vw/phuuH9MNVw7vhOioSDw/+QJN69PxQBTbzPJ3GiLScN3oLhjaszVMJhPiYqPRoXUC\nHrKeDQ3t2RqXD+9k7xdJTYrRfKZrR3XBAxP747nbL8CbD43WxPnITQN8up6gpZu+jzsmeD8D8eQ3\nY7oi1WF/MpuBgee1crv+xMxumsfjh3dEm9R4l321V6eGff2KYR3xh8t6NLzmgo7wh96Q3bED2+ms\nCVw3uitee2A0hvVKty/rlJGICJMJFw9pjzuvaWiZt9c5M3RnRJ8Me8Pp7uv64qZxDfvYLZc2fEZ3\n9yq4wuFMedbDYzSP01NczxQ93fMgMS4aJocfZnpKnOZzAUD/bq1w97V9XfrOfC2n+uKsSeTOIiMi\n0DrFeI3TlqhTk7UJ+5lbh+CJmwcZ3vFG9M3QlngcknF0VCRG9WuD2GZRiG0WhX/eO1K3HGQT6VQz\n6JieqFv+AICHbxqA6y7sgrjYKLRrFY/2rb3He2NmN6Qmxej+WHt0bGH4Myc0j8acxzJxyfkdcLmH\nhNE+LQE9O+mf+tt+GyaYcM2HtNNDAAAMa0lEQVSoLi4HrcS4Zpj96FjcfW0f/OXGgbhqRGfMfHA0\nXr5rBMYP64SOrRMs31PrBERHRaJD6wTExUYj2lri6tM5BX27pjps0DWGSU59Dc1jojDr4TEY1N2S\nZP/6hyG4/4Z+GNYn3eW1Uyed71JeeOCG/rqf9eqRnfHqvSPtZZC0Fs1x/w398OZDozHYQI3f1lnv\neAB75e4RePimAfbHN407Dxc5dALH6lxb4S6xaEtlJrRtZeD6C4eDiq2/ZYTT30nb5+S+QXHl8E72\nf996hcDtVzV8L/Gx0bhsaAekpzTHTZf00AwTbtsqHlGRrl/sEIdBC82d/g5XjuikeTygW6rm8w7r\nrf0MyQmWz9a5jeXsqk2q+xwzweG9F/zjatx1TeNLTO6cFaWVQLhpXDe0SGiGi4doR0DEx0b73Ko3\nqmVSLC67oANkdhF+M7qLy/PuLmDq1i4JB46XoLVDa6Jf11T0c0xUbvTt0hKDurfCuEHt0LdrKsYP\n1+7IPTq0wL7sIt2LqgJhcI80dGt3HFcO0273oRsH4JNl+zWtJWfO5RBb4kxJjMGzDmcgjqb8fhAW\nrDzo0mqKj41CbLNIDO+djhXbTgBoKI85ah4ThXt/0xeFpVVoldzw905oHo2yyhoMPK8VJmZ2Q9tW\n8fjXX0ah/EwtHn5jNQBgYPdWMJm0LWVHj/1uIEorauwtt7jYaPzl+n4uZREjnPtPbJpFRaC6th6J\nDgeZC/u1QbPoCFw9qgtG9m2D1z7bjmdvG4rC0io0j4tBj7aJmPv1buvaZjx3+1DcMW2F/fV6xQHH\nM7HoqEi8M2Wcbhnhrmv64Ju1h9G/Wyu3ZT6zQ5Jv0zLOpUETEWHCy3eNQFpaIvLySjF10vn2g8Sr\n946y//1tbGc/nR1Kmg9O7I/v1h/FBT3TseDngygurwYA3He9tl9k4thuWG8tm/zz3pH2fe5PV/RE\nv66puKCX60FdT3RUhKYlH2hM5FZxsdG4brTnYYzBEB8bba9rOjOZTGiXFo9Ipx3gwYkDsPdIoaHW\nm7OoyAjc76alCABTbh6E2tr6RnUq3X1tHxx38yONbRaFp/94vsvy7u1b4P/+5Lq8sbq1S8bjv3ft\nYIyMiMCsh8fAZDLZE7k7kRERmiQOWE6xyyprEBUVYW/BRUVGIDm+mWa9hdOuwel8/dp+VGSE7tDV\nS85vjz1HCjV/wwcm9sdbX+3SjDqKj3XfwWjz4h3DcfBkCdq3TkB8bBTKz9Sibat4+wGzf7dU+1le\nx/REe3K8fkxXLFh5EH+e0BuRERGIMJlQbzZjZN8Ml7b0W49l2s98bNzVgof1Tsew3ulYm5WjWf72\n45l49r2NOJFfjgSHz2Uk+XXOaOgHc/77Zw5sixYJMZh2zwgkxzf8rQec1woDrOWs9JZxKC6vxjWj\nOtv3+0nje2Lbr/lISYrBnEfHAtCWb5vHRGn6U/QEsIvFKybyIBveOx2LVh3CbVcaGyLo7HmdlmZC\n82ic3zM49bYIk8ntBVVGXdArHejlfb1QsyWJ1KQYnC6pCso2IiNMPrfEbr7EUvd1bJkPPK8VZj8y\nVrNe//O8n4GlJsfay4VTbh6MpRuPYtwg/Rq3owkjO+sO23U+8LhreXszrHc6qmvr8MF3EoDloPbQ\nxP74efsJXDykPdq2isfaXTk4r12yl3dy75Ih7XGztYbufCB2dPe1fbBq+wlc5lAWHDOgrf0Mzcjv\noWPrBBzNLfO6XrAwkQdZ65Q4zHtinN+nVcE8HSOLV+4egdo639pPF/RKx5erD6G/Tjnr8d8PQlPc\nWjbCZOnYNrqtDq0TMPkq/zpqB5yXiq2/5qND6wR0bWNpAV83uovfozAiIkwYO7AdBnZPs8ffqkVz\n3DC2m3V7DS1mXz3+u4GIiDAZHorZIiEGV49yLW364rHfD8IDr69CH4dx4rZ6vK0jP5hMZoOjGAIl\nL6/U7w3aTvtUoVK8jNU39WYzcgsrkZ7S3OPBtjGxfr/hKCqrapu05Ocu3qrqOuw/UYzenVJgMplQ\nX28O+Y3Qw2E/cFRxphaxMZH2en1dfT0WrTqEkX0z0L9nRqNjTUtLdPsHZ4ucyA8RJhMygnzlp6cR\nQE0tplkk+nRuGCES6iQejuJitek0MiLCfoYRbGft8EMionMFEzkRkeIMlVaEEDMADIdlRM2DUsqN\nDs9dAuAlAHUAFkspXwhGoEREpM9ri1wIMRZAdynlCACTAcx0WmUmgBsAjAJwmRCicdcvExGRT4yU\nVi4GsAgApJR7AKQIIZIAQAjRFUCBlDJbSlkPYLF1fSIiaiJGSisZADY7PM6zLiux/j/P4blcAB67\naVNS4hAV5f+4yrQ0/RnkwpVK8TLW4FApVkCteBmrhT/DDz2NO/I6Jqmw0LcpJx2F27hRb1SKl7EG\nh0qxAmrFe67F6ulAYKS0cgKWlrdNWwAn3TzXzrqMiIiaiNcrO4UQIwE8J6W8VAgxGMBMKeWFDs/v\nAnAVgGMA1gK4RUq5L4gxExGRA0OX6AshXgEwBkA9gPsADAJQLKVcKIQYA+Af1lW/kFL+M1jBEhGR\nqyafa4WIiAKLV3YSESmOiZyISHFM5EREimMiJyJSHBM5EZHilLmxhKcZGEMQS18AXwKYIaV8UwjR\nAcB8AJGwXCz1RylllRDiFgAPwTJs820p5TwhRDSA9wF0gmXGyNuklAeDGOs0AKNh+a5fBrAxHGMV\nQsRZt5UOIBbACwC2h2OsDjE
3B5BljXVZuMYqhMgE8BmAXdZFOwFMC+N4bwEwBUAtgKkAdoRxrJMB\n/NFh0fmwTCA4G5ZctUNKeY913ccB3Ghd/pyUcrEQIhnAfwAkAygDcLOUssDXOJRokRuYgbEpY4kH\n8AYsP1yb5wHMklKOBrAfwO3W9aYCuARAJoCHhRAtAdwMoMh6UdWLsCTXYMU6DkBf69/tCgCvhWus\nAK4GsElKORbATQD+Fcax2jwDwPajC/dYf5ZSZlr/uz9c4xVCpAL4G4ALAUwAcG24xgoAUsp5tr+r\nNe4PYPmdPSilHAUgWQgxXgjRBcDvHD7Xv4QQkbAciFZY410A4Al/4lAikcPDDIwhUAXgSminIsgE\n8JX131/DsnMNA7BRSlkspawE8AssR+qLASy0rvujdVmwrISlBQAARQDiwzVWKeWnUspp1ocdYLlS\nOCxjBQAhRE8AvQF8a10UtrG6kYnwjPcSAD9KKUullCellHeGcazOpsJycWQXh4qBLd5xAJZIKaul\nlHkAjsCy/zjGa1vXZ6okcudZFm0zMDY5KWWtdcdxFC+lrLL+OxdAG+jPDKlZbp361yyEaBakWOuk\nlOXWh5NhmWY4LGO1EUKsgeVU86Ewj3U6gEccHodzrADQWwjxlRBitRDi0jCOtzOAOGusq4QQF4dx\nrHZCiKEAsmEpBxV6isvDctsyn6mSyJ2F851f3cXm6/KAEUJcC0si/4vBbYcsVinlSADXAPjIaXth\nE6sQ4lYAa6WUh3zcdqj+rr8CeA6WMsWfAMyDtn8snOI1AUgFcD2ASQDeQ5juB07+DEtt3uj29Zb7\nHasqidzTDIzhoMza8QU0zADpbmZI+3Jrx4xJSlkdrMCEEJcDeBrAeCllcbjGKoQYYu00hpRyGyyJ\npjQcY4VlkrhrhRDrYPkB/x/C9O8KAFLK49bSlVlKeQBADizlyXCM9xSANdYz3wMAShG++4GjTABr\nYGldp3qKy8Nyv2ePVSWRLwUwEQCsMzCekFKG00TEP8JyuztY//8dgPUAhgohWgghEmCp1a2C5bPY\n6tZXA/gpWEFZe8RfBTDBoSc8LGOFZVK2R61xpwNICNdYpZS/lVIOlVIOB/AOLKNWwjJWwDIKRAjx\nmPXfGbCMDHovTONdCuAiIUSEteMzbPcDGyFEWwBl1vp3DYC9QgjbDLHXW+NdDuAqIUQz6/rtAOx2\nitf22XymzKRZzjMwSim3hyiOIbDURzsDqAFwHMAtsJxWxcLSiXGblLJGCDERwOOwDDd6Q0r5sbWn\n+h0A3WHpOJ0kpcwOUqx3AngWgOO0wn+ybj/cYm0Oyyl/BwDNYSkFbALwYbjF6hT3swAOA/g+XGMV\nQiTC0u/QAkAzWP62W8M43rtgKQUCwN9hGTIblrFa4x0C4O9SyvHWx70BvAVLQ3m9lPIR6/L7YckV\nZgDPSCmXWQ9CH8HSii8C8AfrmbNPlEnkRESkT5XSChERucFETkSkOCZyIiLFMZETESmOiZyISHFM\n5EREimMiJyJS3P8DZc58q6idw8QAAAAASUVORK5CYII=\n", 450 | "text/plain": [ 451 | "
" 452 | ] 453 | }, 454 | "metadata": { 455 | "tags": [] 456 | } 457 | } 458 | ] 459 | }, 460 | { 461 | "metadata": { 462 | "id": "bUXii0BC2nhZ", 463 | "colab_type": "text" 464 | }, 465 | "cell_type": "markdown", 466 | "source": [ 467 | "## 6. Test" 468 | ] 469 | }, 470 | { 471 | "metadata": { 472 | "id": "q3Qlsqzo2nha", 473 | "colab_type": "code", 474 | "outputId": "776414c0-e9de-46a7-f14a-54e2f8c67f0d", 475 | "colab": { 476 | "base_uri": "https://localhost:8080/", 477 | "height": 382 478 | } 479 | }, 480 | "cell_type": "code", 481 | "source": [ 482 | "# Testing doesn't require the use of gradients since weights aren't being updated\n", 483 | "model.eval()\n", 484 | "with torch.no_grad():\n", 485 | " correct = 0\n", 486 | " total = 0\n", 487 | " \n", 488 | " for img, label in mnist_test_dataloader:\n", 489 | " img = torch.squeeze(img).to(device)\n", 490 | " img = img.view(-1, params['input_size'])\n", 491 | " label = label.to(device)\n", 492 | " \n", 493 | " # Forward\n", 494 | " out = model(img)\n", 495 | " \n", 496 | " # Test\n", 497 | " _, predicted = torch.max(out.data, 1)\n", 498 | " total += label.size(0)\n", 499 | " correct += (predicted == label).sum().item()\n", 500 | "\n", 501 | " # Accuracy\n", 502 | " print('Test Accuracy: {:.4f} %'.format(100 * correct / total)) \n", 503 | "\n", 504 | " # Show 4 test images\n", 505 | " fig, axes = plt.subplots(nrows=2, ncols=2)\n", 506 | " img_cpu = img.cpu()\n", 507 | " label_cpu = label.cpu()\n", 508 | " print(\"Label: {}\".format(label))\n", 509 | " print(\"Predicted: {}\".format(predicted))\n", 510 | " for i, ax in enumerate(axes.flat):\n", 511 | " img_cpu_i = img_cpu[i]\n", 512 | " img_cpu_i = img_cpu_i.view(28, 28)\n", 513 | " ax.imshow(img_cpu_i) \n", 514 | " ax.set_title('Target: {} - Prediction: {}'.format(label_cpu[i], predicted[i]))\n", 515 | " ax.set_xticks([])\n", 516 | " ax.set_yticks([])\n", 517 | " plt.tight_layout()" 518 | ], 519 | "execution_count": 0, 520 | "outputs": [ 521 | { 522 | "output_type": "stream", 523 | "text": [ 524 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:15: UserWarning: Implicit dimension choice for log_softmax has been deprecated. 
Change the call to include dim=X as an argument.\n", 525 | "  from ipykernel import kernelapp as app\n" 526 | ], 527 | "name": "stderr" 528 | }, 529 | { 530 | "output_type": "stream", 531 | "text": [ 532 | "Test Accuracy: 90.6600 %\n", 533 | "Label: tensor([7, 2, 6, 2, 7, 1, 3, 0, 8, 6, 3, 0, 4, 7, 7, 9], device='cuda:0')\n", 534 | "Predicted: tensor([7, 2, 6, 2, 7, 1, 3, 0, 8, 5, 5, 0, 4, 7, 7, 9], device='cuda:0')\n" 535 | ], 536 | "name": "stdout" 537 | }, 538 | { 539 | "output_type": "display_data", 540 | "data": { 541 | "image/png": "<base64-encoded PNG omitted: 2x2 grid of test digits titled 'Target: ... - Prediction: ...'>", 542 | "text/plain": [ 543 | "
" 544 | ] 545 | }, 546 | "metadata": { 547 | "tags": [] 548 | } 549 | } 550 | ] 551 | } 552 | ] 553 | } -------------------------------------------------------------------------------- /Classification/naive_bayes_classifiers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Supervised Learning Algorithms: Naive Bayes classifiers" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further instructions). None of the other sections needs to be adjusted. As a data input example, .csv file from IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. Libraries" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "*Run to import the required libraries.*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 4, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib notebook\n", 38 | "import pandas as pd\n", 39 | "from sklearn.model_selection import train_test_split" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## 2. Data Input and Variables" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "*Define the data input as well as the input (X) and target (y) variables and run the code. Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 5, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "### Data Input\n", 63 | "# df = \n", 64 | "\n", 65 | "### Defining Variables \n", 66 | "# X = \n", 67 | "# y = \n", 68 | "\n", 69 | "### Data Input Example \n", 70 | "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n", 71 | "\n", 72 | "X = df[['horsepower', 'normalized-losses']]\n", 73 | "y = df['price']" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "## 3. 
The Model" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "*Run to build the model.*" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 7, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "Breast cancer dataset\n", 100 | "Accuracy of GaussianNB classifier on training set: 0.67\n", 101 | "Accuracy of GaussianNB classifier on test set: 0.00\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "from sklearn.naive_bayes import GaussianNB\n", 107 | "\n", 108 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)\n", 109 | "\n", 110 | "nbclf = GaussianNB().fit(X_train, y_train)\n", 111 | "print('Breast cancer dataset')\n", 112 | "print('Accuracy of GaussianNB classifier on training set: {:.2f}'\n", 113 | " .format(nbclf.score(X_train, y_train)))\n", 114 | "print('Accuracy of GaussianNB classifier on test set: {:.2f}'\n", 115 | " .format(nbclf.score(X_test, y_test)))" 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "Python 3", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.6.4" 136 | } 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 2 140 | } 141 | -------------------------------------------------------------------------------- /Classification/random_forests.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Supervised Learning Algorithms: Random forests" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further instructions). None of the other sections needs to be adjusted. As a data input example, .csv file from IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. Libraries" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "*Run to import the required libraries.*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib notebook\n", 38 | "import pandas as pd\n", 39 | "from sklearn.model_selection import train_test_split" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## 2. Data Input and Variables" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "*Define the data input as well as the input (X) and target (y) variables and run the code. 
Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 9, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "### Data Input\n", 63 | "# df = \n", 64 | "\n", 65 | "### Defining Variables \n", 66 | "# X = \n", 67 | "# y = \n", 68 | "\n", 69 | "### Data Input Example \n", 70 | "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n", 71 | "\n", 72 | "X = df[['horsepower', 'normalized-losses']]\n", 73 | "y = df['price']" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "## 3. The Model" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "*Run to build the model.*" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 11, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "Accuracy of RF classifier on training set: 0.66\n", 100 | "Accuracy of RF classifier on test set: 0.02\n" 101 | ] 102 | } 103 | ], 104 | "source": [ 105 | "from sklearn.ensemble import RandomForestClassifier\n", 106 | "\n", 107 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)\n", 108 | "\n", 109 | "clf = RandomForestClassifier(max_features = 2, random_state = 0).fit(X_train, y_train)\n", 110 | "\n", 111 | "print('Accuracy of RF classifier on training set: {:.2f}'\n", 112 | " .format(clf.score(X_train, y_train)))\n", 113 | "print('Accuracy of RF classifier on test set: {:.2f}'\n", 114 | " .format(clf.score(X_test, y_test)))" 115 | ] 116 | } 117 | ], 118 | "metadata": { 119 | "kernelspec": { 120 | "display_name": "Python 3", 121 | "language": "python", 122 | "name": "python3" 123 | }, 124 | "language_info": { 125 | "codemirror_mode": { 126 | "name": "ipython", 127 | "version": 3 128 | }, 129 | "file_extension": ".py", 130 | "mimetype": "text/x-python", 131 | "name": "python", 132 | "nbconvert_exporter": "python", 133 | "pygments_lexer": "ipython3", 134 | "version": "3.6.4" 135 | } 136 | }, 137 | "nbformat": 4, 138 | "nbformat_minor": 2 139 | } 140 | -------------------------------------------------------------------------------- /Classification/rnn_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "rnn_pytorch.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "metadata": { 20 | "collapsed": true, 21 | "id": "jqZrGQOG2ng-", 22 | "colab_type": "text" 23 | }, 24 | "cell_type": "markdown", 25 | "source": [ 26 | "# Recurrent Neural Network" 27 | ] 28 | }, 29 | { 30 | "metadata": { 31 | "id": "EB7M0gMP2ng_", 32 | "colab_type": "text" 33 | }, 34 | "cell_type": "markdown", 35 | "source": [ 36 | "## 1. 
Libraries\n", 37 | "*Installing and importing necessary packages*\n", 38 | "\n", 39 | "*Working with **Python 3.6** and **PyTorch 1.0.1** *" 40 | ] 41 | }, 42 | { 43 | "metadata": { 44 | "id": "-HDxCZeVAPb-", 45 | "colab_type": "code", 46 | "outputId": "08d09105-fa52-4196-d4d7-cca9cd92c630", 47 | "colab": { 48 | "base_uri": "https://localhost:8080/", 49 | "height": 289 50 | } 51 | }, 52 | "cell_type": "code", 53 | "source": [ 54 | "!nvidia-smi" 55 | ], 56 | "execution_count": 11, 57 | "outputs": [ 58 | { 59 | "output_type": "stream", 60 | "text": [ 61 | "Thu Apr 4 12:34:13 2019 \n", 62 | "+-----------------------------------------------------------------------------+\n", 63 | "| NVIDIA-SMI 418.56 Driver Version: 410.79 CUDA Version: 10.0 |\n", 64 | "|-------------------------------+----------------------+----------------------+\n", 65 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 66 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 67 | "|===============================+======================+======================|\n", 68 | "| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n", 69 | "| N/A 73C P0 73W / 149W | 2430MiB / 11441MiB | 0% Default |\n", 70 | "+-------------------------------+----------------------+----------------------+\n", 71 | " \n", 72 | "+-----------------------------------------------------------------------------+\n", 73 | "| Processes: GPU Memory |\n", 74 | "| GPU PID Type Process name Usage |\n", 75 | "|=============================================================================|\n", 76 | "+-----------------------------------------------------------------------------+\n" 77 | ], 78 | "name": "stdout" 79 | } 80 | ] 81 | }, 82 | { 83 | "metadata": { 84 | "id": "Xdemin4U2nhA", 85 | "colab_type": "code", 86 | "outputId": "066af4ef-98e5-463b-e84f-e062e16a45e0", 87 | "colab": { 88 | "base_uri": "https://localhost:8080/", 89 | "height": 292 90 | } 91 | }, 92 | "cell_type": "code", 93 | "source": [ 94 | "import sys\n", 95 | "import os\n", 96 | "# !{sys.executable} -m pip install http://download.pytorch.org/whl/cu80/torch-0.4.0-cp36-cp36m-linux_x86_64.whl\n", 97 | "# !{sys.executable} -m pip install torch torchvision matplotlib\n", 98 | "!{sys.executable} -m pip install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl\n", 99 | "!{sys.executable} -m pip install torchvision matplotlib\n", 100 | "\n", 101 | "import torch\n", 102 | "import torch.nn as nn\n", 103 | "import torchvision.datasets as datasets\n", 104 | "import torchvision.transforms as transforms\n", 105 | "import torch.nn.functional as F\n", 106 | "from torch.autograd import Variable\n", 107 | "\n", 108 | "%matplotlib inline\n", 109 | "import matplotlib\n", 110 | "import matplotlib.pyplot as plt\n", 111 | "import numpy as np\n", 112 | "\n", 113 | "from timeit import default_timer as timer\n", 114 | "\n", 115 | "print(\"PyTorch version: {}\".format(torch.__version__))\n", 116 | "cudnn_enabled = torch.backends.cudnn.enabled\n", 117 | "print(\"CuDNN enabled\" if cudnn_enabled else \"CuDNN disabled\")" 118 | ], 119 | "execution_count": 12, 120 | "outputs": [ 121 | { 122 | "output_type": "stream", 123 | "text": [ 124 | "Requirement already satisfied: torch==1.0.1.post2 from https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl in /usr/local/lib/python3.6/dist-packages (1.0.1.post2)\n", 125 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (0.2.2.post3)\n", 126 | 
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (3.0.3)\n", 127 | "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (4.1.1)\n", 128 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.11.0)\n", 129 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.14.6)\n", 130 | "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.0.1.post2)\n", 131 | "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (2.5.3)\n", 132 | "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (2.3.1)\n", 133 | "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (1.0.1)\n", 134 | "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (0.10.0)\n", 135 | "Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from pillow>=4.1.1->torchvision) (0.46)\n", 136 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib) (40.8.0)\n", 137 | "PyTorch version: 1.0.1.post2\n", 138 | "CuDNN enabled\n" 139 | ], 140 | "name": "stdout" 141 | } 142 | ] 143 | }, 144 | { 145 | "metadata": { 146 | "id": "ateauqQT2nhF", 147 | "colab_type": "text" 148 | }, 149 | "cell_type": "markdown", 150 | "source": [ 151 | "## 2. Variables\n", 152 | "*Indicate the root directory where the data must be downloaded, the directory where the results should be saved and the type of RNN (conventional, LSTM, GRU) and its respective hyper-parameters*" 153 | ] 154 | }, 155 | { 156 | "metadata": { 157 | "id": "Jt3gMpMY2nhG", 158 | "colab_type": "code", 159 | "outputId": "7c0dcb74-324f-4c13-9ee7-3f843aa0fc86", 160 | "colab": { 161 | "base_uri": "https://localhost:8080/", 162 | "height": 34 163 | } 164 | }, 165 | "cell_type": "code", 166 | "source": [ 167 | "# Make reproducible run\n", 168 | "torch.manual_seed(1)\n", 169 | "\n", 170 | "# Settable parameters\n", 171 | "params = {'root': './data/',\n", 172 | " 'results_dir': './results/',\n", 173 | " 'hidden_size': 128,\n", 174 | " 'input_size': 28, # MNIST data input (img shape: 28*28)\n", 175 | " 'sequence_length': 28,\n", 176 | " 'lr': 1e-3,\n", 177 | " 'weight_decay': 1e-10, # 5e-4, # 1e-10,\n", 178 | " 'momentum': 0.9,\n", 179 | " 'num_classes': 10, # class 0-9\n", 180 | " 'batch_size': 128,\n", 181 | " 'model_type': 'GRU', # Options = [RNN, LSTM, GRU]\n", 182 | " 'optim_type': 'Adam', # Options = [Adam, SGD, RMSprop]\n", 183 | " 'criterion_type': 'CrossEntropyLoss', # Options = [L1Loss, SmoothL1Loss, NLLLoss, CrossEntropyLoss]\n", 184 | " 'num_layers': 1,\n", 185 | " 'epochs': 30,\n", 186 | " 'save_step': 200,\n", 187 | " 'use_cuda': True,\n", 188 | " }\n", 189 | "\n", 190 | "# GPU usage\n", 191 | "print(\"GPU: {}, number: {}\".format(torch.cuda.is_available(), torch.cuda.device_count()))\n", 192 | "device = torch.device('cuda') if params['use_cuda'] and torch.cuda.is_available() else torch.device('cpu')\n", 193 | "\n", 194 | "# Ensure results directory exists\n", 195 | "if not os.path.exists(params['results_dir']):\n", 196 | " os.mkdir(params['results_dir'])" 197 | ], 198 | "execution_count": 138, 199 | "outputs": [ 200 | { 201 | 
"output_type": "stream", 202 | "text": [ 203 | "GPU: True, number: 1\n" 204 | ], 205 | "name": "stdout" 206 | } 207 | ] 208 | }, 209 | { 210 | "metadata": { 211 | "id": "BamOukfB2nhL", 212 | "colab_type": "text" 213 | }, 214 | "cell_type": "markdown", 215 | "source": [ 216 | "## 3. Dataset\n", 217 | "\n", 218 | "*Normalizing between (0.1307, 0.3081): global mean and standard deviation of the MNIST dataset*" 219 | ] 220 | }, 221 | { 222 | "metadata": { 223 | "id": "jpbW8mtb2nhN", 224 | "colab_type": "code", 225 | "outputId": "b0bc736b-2f2e-4544-83ab-f9ca88489ee7", 226 | "colab": { 227 | "base_uri": "https://localhost:8080/", 228 | "height": 34 229 | } 230 | }, 231 | "cell_type": "code", 232 | "source": [ 233 | "# Get train and test datasets\n", 234 | "# trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])\n", 235 | "trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n", 236 | "# trans = transforms.Compose([transforms.ToTensor()])\n", 237 | "mnist_train = datasets.MNIST(\n", 238 | " root=params['root'], # directory where the data is or where it will be saved\n", 239 | " train=True, # train dataset\n", 240 | " download=True, # download if you don't have it\n", 241 | " transform=trans) # converts PIL.image or np.ndarray to torch.FloatTensor of shape (C, H, W) and normalizes from (0.0, 1.0)\n", 242 | "mnist_test = datasets.MNIST(root=params['root'], train=False, download=True, transform=trans) # transforms.ToTensor()\n", 243 | "print(\"MNIST Train {}, Test {}\".format(len(mnist_train), len(mnist_test)))\n", 244 | "\n", 245 | "# Dataloader: mini-batch during training\n", 246 | "mnist_train_dataloader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=params['batch_size'], shuffle=True)\n", 247 | "mnist_test_dataloader = torch.utils.data.DataLoader(dataset=mnist_test, batch_size=params['batch_size'], shuffle=True)" 248 | ], 249 | "execution_count": 117, 250 | "outputs": [ 251 | { 252 | "output_type": "stream", 253 | "text": [ 254 | "MNIST Train 60000, Test 10000\n" 255 | ], 256 | "name": "stdout" 257 | } 258 | ] 259 | }, 260 | { 261 | "metadata": { 262 | "id": "QhKhpBRaU4b4", 263 | "colab_type": "text" 264 | }, 265 | "cell_type": "markdown", 266 | "source": [ 267 | "*Dataset examples*" 268 | ] 269 | }, 270 | { 271 | "metadata": { 272 | "id": "ikWHBC-2U6-B", 273 | "colab_type": "code", 274 | "outputId": "750904ab-edcf-4c33-b891-6c9c044833cd", 275 | "colab": { 276 | "base_uri": "https://localhost:8080/", 277 | "height": 140 278 | } 279 | }, 280 | "cell_type": "code", 281 | "source": [ 282 | "# Plot examples\n", 283 | "examples = enumerate(mnist_test_dataloader)\n", 284 | "batch_idx, (example_data, example_targets) = next(examples)\n", 285 | "\n", 286 | "fig, axes = plt.subplots(nrows=1, ncols=4)\n", 287 | "for i, ax in enumerate(axes.flat):\n", 288 | " ax.imshow(example_data[i][0]) \n", 289 | " ax.set_title('{}'.format(example_targets[i]))\n", 290 | " ax.set_xticks([])\n", 291 | " ax.set_yticks([])\n", 292 | " plt.tight_layout()" 293 | ], 294 | "execution_count": 72, 295 | "outputs": [ 296 | { 297 | "output_type": "display_data", 298 | "data": { 299 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAagAAAB7CAYAAAAhbxT1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADzhJREFUeJzt3X+wlVPfx/HPKf2gRHFOftSIcBXR\ndJD4A49GzfO4yaj7jsFDJPlZzTMaJMN4ECNTmUJ+FXfRmTQ43OE2UuE4JkVp6kqlntL4TT+oker5\nA8t3XZ192uecffa19j7v11/f1Vp778V19lnnWmtd31WyZ88eAQAQmmZpdwAAgJowQAEAgsQABQAI\nEgMUACBIDFAAgCAxQAEAgrRf2h3ItSiKBkn63+Q/S2oXx/HWFLqEGkRRdISk6ZKOk7RF0k1xHC9I\nt1f4E9+j8EVR1EXS55LWmH/+KI7j/06nR7lXdANUHMezJc3+sxxF0T8kDeZLFZzpkubGcXxeFEX/\nIekmSQxQgeB7VDC+jOO4W9qdaCxFN0BZURS11u9/Bf5n2n3BX6Io6izpFEn/JUlxHM+TNC/VTiEj\nvkdIS1EPUJKukfR+HMdr9tkS+dRT0heSxkVR9DdJX0kaGcfxknS7hQz4HoWrXRRFL0vqJmmdpFFx\nHK9It0u5U7SbJKIoaibpfyQ9nHZfsJeDJZ0kaUEcx5Gkf0qaE0VRsf/BVHD4HgVtq6SZkkZKOkHS\nvyW9Ukzfo6IdoCSdIWlbHMfL0+4I9rJZ0tdxHL/yR/kpSR0kHZ9el5AB36NAxXH8fRzHN8VxvC6O\n492SHpHUUUX0PSrmAepvkv6VdidQo/WSDvzjr3PFcbxH0m5Ju1LtFWrC9yhQURS1j6Lo6MQ/N5e0\nM43+NIZiHqB6Siqaudgis0zSJklDJSmKor9L+lH+dlmEge9RuE6T9E4URaV/lK+V9H+S1qbXpdwq\n5gGqk35ffEdg/rhjGiRpaBRFa/X7Gsff4zj+Ld2eoQZ8jwIVx/FbkqZIej+KopWSBksaGMdx0cxE\nlHAeFAAgRMV8BwUAKGAMUACAIDFAAQCCxAAFAAjSvp44ZgdFekrq0JbrlJ5srxPXKD1co/DVeI24\ngwIABIkBCgAQJAYoAECQGKAAAEFigAIABIkBCgAQJAYoAECQGKAAAEFigAIABIkBCgAQJAYoAECQ\nGKAAAEFigAIABGlf2cyBepk7d65XjqLIxWVlZS5u27Zt3voEoLBwBwUACBIDFAAgSEzxoU4qKyu9\n8rx581y8ZcsWF0+fPt1r16JFCxc3b97cxcuWLfPadenSJRfdBJAF+51dtGiRV1dVVZXxdU899ZSL\n169f7+I2bdp47VavXu3ijh071rl/3EEBAILEAAUACBJTfKiTlStXeuXZs2e7+Msvv8z4urPOOsvF\nRx99tIsff/xxr924ceMa2sUm6ZJLLnHxm2++6dXZ63LAAQfkrU9ouM2bN3vl+fPnu3jdunUuXrJk\nScZ2e/bs8epKSkpcvHXrVhd///33DeqrJP38889e+ccff3QxU3wAgKLBAAUACBIDFAAgSDlZgxo7\ndqxXPuecc1zct2/fXHwEUvTVV1+5eMKECV7dN9984+Lrr7/exWPGjPHadejQwcX77ffXj91vv/2W\ns342ZcuXL3dxct3Cfgdr2zqMMMycOdPFkydP9uoa8/q1bt3aK/fs2dPFn332mVc3aNAgF+/YscPF\nvXr18to19LER7qAAAEFigAIABKkkuQUxodZK9yZm26IkHX744S62TycfccQRdepcaOxU14wZM7y6\nG2+80cXJW+V6Ktl3Eyer65StZLaIoUOHuvi7777z6uy03qRJk1zcrFmT+dsn2+uU02uU9OCDD7r4\n9ttvz9iue/fuLk5Ow/bv37/B/Tj44INdbDOGpCyIa1Qbuz27U6dOLk5O12ZiH92QpC+++MLFZ555\npld36aWXuri8vNzFJ554oteuXbt2WX12jtR4jZrMbxEAQGFhgAIABIkBCgAQpJxsM2/fvr1X/vrr\nr11cUVHhYrtOI/kZrkOxfft2FydTxtj1Fru9WpI2bdrk4vHjxzdS7/LDZiiX9l53suw6RhNadwpO\naWlpxjq7Jrp27VoXX3755Tnvh12btWsdqN2qVatcXNu6k11fvOaaa1xs160kadeuXS5O/p61j3mE\njt8oAIAgMUABAIKUk23mn3zyiVc+5ZRTamw3bNgwr3zHHXe42G5PlaQDDzwwm4/Omr1tXrFihVc3\nZ84cF7/22msujuM46/fv1q2bi+1T/Q2Q2jbz/fff3yv/+uuvLrbTnJI0ceJEFwe0rTifgtjCbJ/0\nP/nkk706O1VttxLbn3VJOumkk1ycPHjOsr8zbNYYSercubOLly5duo9e500Q16g21dXVLj7jjDNc\nnJw2t79re/To0fgdyx+2mQMACgcDFAAgSDmZ4tu9e7dXvvLKK138wgsv/PVmic+yU0ktW7b06lq1\nauXiY445xsW//PKL187uZLGSh2/ZJKdbtmyp8TUN8cADD7h49OjRuXjLvE7x/fDDDy62mUAkP6Hr\nxo0bvbpk23yxPwfJ3aF2qip5QJ892C85rVxPQUwfTZkyxcX33HOPV2d31eZacveg/ewbbrih0T63\njoK4RlYySXLXrl1dvGHDBheXlZV57Ww2myLDFB8AoHAwQAEAgsQABQAIUk4eKU5uhXz++edd3Lt3\nbxfff//9XjubjcFmcEhKZm2wRowYkXU//5Rct7LbbWubr7fbqKdNm+bVDRw4sM79CMm4ceNcXNsh\ngjlat2mwp59+2sXJjB/PPfdcxtdNnTrVxTZjxkEHHZTD3qUrebpAY0p+9/P52YXMZi+X/HUna/bs\n2fnoTrC4gwIABIkBCgAQpEbPGnjzzTe7+Pzzz/fqZs6c6WK7DVzyE1zarcyLFy/22mXKXnDrrbd6\nZfvZkydP9uoybTtPJlW0GSjs1vdisH79ehcfddRRXl1ya3la7FTvggULXJyclrWPKNiD/CRp5MiR\nLn711VddfMUVV+Ssn2no27eviz/44ING/Syb2NQ+noDs2Z/fJLt132bmkKSdO3e6OMRk27nGHRQA\nIEgMUACAIOUkk0Rjs2fY2HOXpMy7ymxSTMnPaJHtdM4zzzzjlW2GjDzIayYJu/uqe/fuXt3q1atd\nnDyrJplYNpeS51D17NnTxbU9UW/P4xo+fLhXd+yxx7rYTh2/8847Xrs6JCsOLktBY/v0009d3KtX\nL6/OTp8nEwunKLhrVFlZ6ZUHDBiQ1etsItlHH33UxeXl5bnpWHrIJAEAKBwMUACAIDFAAQCCVBCH\n09st3dlu706urVVUVGT1Ops92Ga+LnY2G/udd97p1dn/l8kt+bleg7KHIybXoDKtO9l5eUkaMmSI\ni+3jCpK/LnLXXXe5eNasWV67oUOHZtnjpuett95KuwsF
77TTTvPK9me4qqoq4+ts3amnnuri008/\n3Wt36KGHujh5qKTNvpN8lCY03EEBAILEAAUACFLY93cNkEwwm9zWaXXo0MHFCxcudLHNSFDsbrvt\nNhe//PLLXt3HH3/s4vvuu8+rswmA27Zt2+B+2ES1L730UsZ2o0aNcvHYsWO9Opv41U4ZStL8+fNr\nfL/GPNQvVDt27PDKNoHpIYcckvF1NrNIcip98ODBOepdcTvssMO88rvvvuvie++918UfffSR1+7D\nDz90sZ1ur66uzvhZr7/+uldet26dix955BEXh5iZgjsoAECQGKAAAEFigAIABKkgUh3Vxy233OKV\nbQqWdu3aeXX28LuLL764cTuWvbymOqpN165dXWznryXphBNOcPHcuXNd3L59e69dmzZt6vy5NsO6\n5D9i0K1bNxcvWbLEa9eyZcuM72nX0Ow2c7stV6rTmlRwaXSyNWzYMK9sTxewh3pedtllXruJEye6\n2KYQk/zDSmtbI7HrIjblVPI9bHqrBijYa5T07bffunjbtm0unjNnjtfuiSeecLFNVZZkH7tInvKQ\nZ6Q6AgAUDgYoAECQimqKz26bTU7j7dq1y8VPPvmkV3f11Vc3bsfqJ5gpPpvBoU+fPl7dhg0banzN\n8ccf75UvvPDCOn+unc6QpOnTp9fY7qKLLvLKNmN50tKlS11sMyLYKRGpTpkkCnb6KPn/d/To0S7O\n9P+6vuzWf8mfQkxmDLngggtcPHDgwFx8fMFeo/r66aefXHz33Xd7dZMmTXKxnXrfunVro/erFkzx\nAQAKBwMUACBIRZVJYsqUKS62U3pJzZs3z0d3ioZ96t0+yS75h6Y99NBDLl61apXX7uGHH26k3u2d\n+aI25513nouPPPJIF/fv3z+nfSoEpaWlXtlOfdvko9OmTfPa2QMLk+8xdepUF9uDL5MHQNrsLcg9\ne5Br586dM7br169fPrpTb9xBAQCCxAAFAAgSAxQAIEgFv83cbnO2mQZ2796d8TXJp64HDBiQ+441\nXDDbzLP19ttvu3jevHle3WOPPVbn9xs5cqRXfu+991y8aNEiF1977bVeO5vJ/uyzz/bqli1b5uLx\n48fXuU81aHJbmO11WblypVf3xhtv5Ls72QjiGq1Zs8bFyewn9lGJ+hwiuHPnTq/87LPPunj48OEZ\nX2czp48ZM6bOn5tDbDMHABQOBigAQJAKforPZo/o3bu3i5cvX57xNZ9//rlXtlODASm4Kb4mKojp\no3waMWKEi1esWOHV2ewcAUnlGm3atMkr20cZkr+frrrqKhfbxzVsFhfJT8C7YMECF8+aNctrZx8F\nSP6O79Kli4vt9WvduvVe/w15xBQfAKBwMEABAILEAAUACFLBpzpq0aKFi5MZzC07v2pT3ACom5KS\nv5YLmjXjb9xMNm/e7JVrWxe36aRmzJjh4mTKttoen8kkeXjoqFGjXJzyutM+8dMFAAgSAxQAIEgF\nP8W3ceNGF1dVVWVsZw9Ia9WqVaP2CWgqFi5c6JUXL17s4vLy8nx3JygdO3b0yp06dXKx/b2VlMwK\nkQ2bvVySzj33XBdfd911Xp3N6B867qAAAEFigAIABKngp/iyNWTIkLS7ABSd7du3e+Xq6moXN/Up\nvuShjBUVFS5+8cUXvbrKykoX9+jRw8XHHXec185mvbHTeMnpxOTOvULFHRQAIEgMUACAIDFAAQCC\nVPBrUGVlZS7u16+fi+3hdsl2AJBvffr0qTGWpAkTJuS7OwWBOygAQJAYoAAAQSr4Awstu+V127Zt\nXl1paWm+u9NQHFhYGJrcgYUFiGsUPg4sBAAUDgYoAECQGKAAAEEqqjWoIsMaVGFgfSN8XKPwsQYF\nACgcDFAAgCDta4oPAIBUcAcFAAgSAxQAIEgMUACAIDFAAQCCxAAFAAgSAxQAIEj/D49OSB5oVIuT\nAAAAAElFTkSuQmCC\n", 300 | "text/plain": [ 301 | "
" 302 | ] 303 | }, 304 | "metadata": { 305 | "tags": [] 306 | } 307 | } 308 | ] 309 | }, 310 | { 311 | "metadata": { 312 | "id": "dWCCngwF2nhS", 313 | "colab_type": "text" 314 | }, 315 | "cell_type": "markdown", 316 | "source": [ 317 | "## 4. The Model: RNN\n", 318 | "$h_t = \\sigma(W x_t + U h_{t-1})$\n", 319 | "\n", 320 | "Some important information is: \n", 321 | "* Input size: number of expected features in input $x$\n", 322 | "* Hidden size: number of features in hidden state $h$\n", 323 | "* After forward propagation, output has shape (batch_size, seq_length, hidden_size)\n", 324 | "* If you want to initialize of RNN with hidden and cell states different than zero, modify variables $h0$ and $c0$. Otherwise, you may set them as *None*." 325 | ] 326 | }, 327 | { 328 | "metadata": { 329 | "id": "N2UAizM92nhT", 330 | "colab_type": "code", 331 | "colab": {} 332 | }, 333 | "cell_type": "code", 334 | "source": [ 335 | "class RNN(nn.Module):\n", 336 | " def __init__(self):\n", 337 | " super(RNN, self).__init__()\n", 338 | " self.input_size = params['input_size']\n", 339 | " self.hidden_size = params['hidden_size']\n", 340 | " self.num_layers = params['num_layers']\n", 341 | " self.model_type = params['model_type']\n", 342 | " \n", 343 | " if self.model_type == 'RNN':\n", 344 | " self.rnn = nn.RNN(self.input_size, self.hidden_size, num_layers=self.num_layers, bias=True, nonlinearity='tanh', dropout=0.2, batch_first=True)\n", 345 | " elif self.model_type == 'GRU':\n", 346 | " self.rnn = nn.GRU(self.input_size, self.hidden_size, num_layers=self.num_layers, bias=True, dropout=0.2, batch_first=True)\n", 347 | " else: # 'LSTM'\n", 348 | " self.rnn = nn.LSTM(self.input_size, self.hidden_size, num_layers=self.num_layers, dropout=0.2, batch_first=True)\n", 349 | "\n", 350 | " self.bn = nn.BatchNorm1d(self.hidden_size)\n", 351 | " self.fc = nn.Linear(self.hidden_size, params['num_classes'])\n", 352 | " self.softmax = nn.LogSoftmax() # nn.ReLU() # nn.LogSoftmax() # Softmax()\n", 353 | " \n", 354 | " def forward(self, x): \n", 355 | " # Set initial hidden state $h0$ and cell state $c0$\n", 356 | " h0 = torch.zeros(self.num_layers, self.input_size, self.hidden_size, dtype=torch.float32)\n", 357 | " c0 = torch.zeros(self.num_layers, self.input_size, self.hidden_size, dtype=torch.float32)\n", 358 | "\n", 359 | " # Forward propagate RNN\n", 360 | " if self.model_type == 'LSTM':\n", 361 | " out, hidden = self.rnn(x, None) # (h0, c0)) \n", 362 | " else:\n", 363 | " out, hidden = self.rnn(x, None) # h0)\n", 364 | " \n", 365 | " # Decode last hidden state\n", 366 | " out_fc = self.bn(out[:, -1, :])\n", 367 | " out_fc = self.fc(out_fc)\n", 368 | " out_fc = self.softmax(out_fc)\n", 369 | " # print(\"Input {} -> Output shape {} -> {} | Last shape {}\".format(x.shape, out.shape, out_fc.shape, out[:, -1, :].shape))\n", 370 | " return out_fc" 371 | ], 372 | "execution_count": 0, 373 | "outputs": [] 374 | }, 375 | { 376 | "metadata": { 377 | "id": "cw-AqMCfOgPN", 378 | "colab_type": "text" 379 | }, 380 | "cell_type": "markdown", 381 | "source": [ 382 | "*Instantiate model and optimizer*" 383 | ] 384 | }, 385 | { 386 | "metadata": { 387 | "id": "9T6W1JeOOlTW", 388 | "colab_type": "code", 389 | "outputId": "15884eb1-b75a-4dce-c0fc-5c2f20db7884", 390 | "colab": { 391 | "base_uri": "https://localhost:8080/", 392 | "height": 88 393 | } 394 | }, 395 | "cell_type": "code", 396 | "source": [ 397 | "# Instantiate model\n", 398 | "model = RNN()\n", 399 | "# Allow for parallelism if multiple GPUs are detected\n", 400 | "# model 
= nn.DataParallel(model)\n", 401 | "# Transfer model to device (CPU or GPU according to your preference and what's available)\n", 402 | "model = model.to(device)\n", 403 | "\n", 404 | "# Loss criterion\n", 405 | "if 'CrossEntropyLoss' in params['criterion_type']:\n", 406 | " criterion = nn.CrossEntropyLoss()\n", 407 | "elif 'L1Loss' in params['criterion_type']:\n", 408 | " criterion = nn.L1Loss()\n", 409 | "elif 'SmoothL1Loss' in params['criterion_type']:\n", 410 | " criterion = nn.SmoothL1Loss()\n", 411 | "else: # NLLLoss\n", 412 | " criterion = nn.NLLLoss() \n", 413 | "\n", 414 | "# Optimizer\n", 415 | "if 'Adam' in params['optim_type']:\n", 416 | " optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])\n", 417 | "elif 'SGD' in params['optim_type']:\n", 418 | " optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'], momentum=params['momentum'])\n", 419 | "elif 'RMSprop' in params['optim_type']:\n", 420 | " optimizer = torch.optim.RMSprop(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'], momentum=params['momentum'])\n", 421 | "\n", 422 | "# Scheduler to reduce learning rate after it plateaus\n", 423 | "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)\n", 424 | "\n", 425 | "# New results dir based on model's parameters\n", 426 | "res_dir = params['results_dir'] + '{}_{}layers_sgd_lr{}_weight{}_trainSize_{}_testSize_{}/'.\\\n", 427 | " format(params['model_type'], params['num_layers'], params['lr'],\n", 428 | " params['weight_decay'], len(mnist_train), len(mnist_test))\n", 429 | "\n", 430 | "if not os.path.exists(res_dir):\n", 431 | " os.mkdir(res_dir)\n", 432 | "\n", 433 | "print(\"res_dir: {}\".format(res_dir))\n", 434 | "log_file = open(res_dir + 'log.txt', 'w')" 435 | ], 436 | "execution_count": 139, 437 | "outputs": [ 438 | { 439 | "output_type": "stream", 440 | "text": [ 441 | "res_dir: ./results/GRU_1layers_sgd_lr0.001_weight1e-10_trainSize_60000_testSize_10000/\n" 442 | ], 443 | "name": "stdout" 444 | }, 445 | { 446 | "output_type": "stream", 447 | "text": [ 448 | "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:46: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1\n", 449 | " \"num_layers={}\".format(dropout, num_layers))\n" 450 | ], 451 | "name": "stderr" 452 | } 453 | ] 454 | }, 455 | { 456 | "metadata": { 457 | "id": "t7ZCKLry2nhV", 458 | "colab_type": "text" 459 | }, 460 | "cell_type": "markdown", 461 | "source": [ 462 | "## 5. 
Train" 463 | ] 464 | }, 465 | { 466 | "metadata": { 467 | "id": "qaLuhdcu2nhW", 468 | "colab_type": "code", 469 | "outputId": "7e0dd432-08ef-4edb-9d5e-b9d9f7fb9b4d", 470 | "colab": { 471 | "base_uri": "https://localhost:8080/", 472 | "height": 1849 473 | } 474 | }, 475 | "cell_type": "code", 476 | "source": [ 477 | "start_timer = timer()\n", 478 | "\n", 479 | "loss_arr = []\n", 480 | "train_acc_arr = []\n", 481 | "first_time = True\n", 482 | "total_num_steps = len(mnist_train_dataloader)\n", 483 | "\n", 484 | "# model.train()\n", 485 | "model.zero_grad()\n", 486 | "optimizer.zero_grad()\n", 487 | "for e in range(1, params['epochs']+1):\n", 488 | " for i, (img, label) in enumerate(mnist_train_dataloader):\n", 489 | " img = Variable(torch.squeeze(img)).to(device)\n", 490 | " label = Variable(label).to(device)\n", 491 | " \n", 492 | " # Forward\n", 493 | " out = model(img)\n", 494 | " loss = criterion(out, label)\n", 495 | " \n", 496 | " # Backward\n", 497 | " optimizer.zero_grad()\n", 498 | " \n", 499 | " # start debugger\n", 500 | " # import pdb; pdb.set_trace()\n", 501 | "\n", 502 | " loss.backward()\n", 503 | " optimizer.step()\n", 504 | " scheduler.step(loss)\n", 505 | "\n", 506 | " loss_arr.append(loss.item())\n", 507 | " \n", 508 | " if i % params['save_step'] == 0:\n", 509 | " # Train Accuracy\n", 510 | " _, predicted = torch.max(out.data, 1)\n", 511 | " total = label.size(0)\n", 512 | " correct = (predicted == label).sum().item()\n", 513 | " acc = 100 * correct / total\n", 514 | " train_acc_arr.append(acc)\n", 515 | " # Print update\n", 516 | " perc = 100 * ((e-1)*total_num_steps + (i+1))/float(params['epochs'] * total_num_steps)\n", 517 | " str_res = \"Completed {:.2f}%: Epoch/step [{}/{} - {}/{}], loss {:.4f}, acc {:.2f}, best acc {:.2f}\".format(perc, e, params['epochs'], i+1, total_num_steps, loss.item(), acc, max(train_acc_arr))\n", 518 | " print(str_res) # print(\"\\r\" + str_res, end=\"\")\n", 519 | " # Save log\n", 520 | " log_file.write(str_res)\n", 521 | " \n", 522 | "# Save training loss\n", 523 | "plt.plot(loss_arr)\n", 524 | "# plt.semilogy(range(len(loss_arr)), loss_arr)\n", 525 | "plt.savefig(res_dir + 'loss.png')\n", 526 | "\n", 527 | "# Save model checkpoint\n", 528 | "torch.save(model.state_dict(), res_dir + 'model.ckpt')\n", 529 | "plt.show()\n", 530 | "log_file.close()\n", 531 | "\n", 532 | "end_timer = timer() - start_timer\n", 533 | "print(\"Model took {:.4f} mins ({:.4f} hrs) to finish training with best train accuracy of {:.4f}%\".format(end_timer/60, end_timer/3600, max(train_acc_arr)))" 534 | ], 535 | "execution_count": 140, 536 | "outputs": [ 537 | { 538 | "output_type": "stream", 539 | "text": [ 540 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:34: UserWarning: Implicit dimension choice for log_softmax has been deprecated. 
Change the call to include dim=X as an argument.\n" 541 | ], 542 | "name": "stderr" 543 | }, 544 | { 545 | "output_type": "stream", 546 | "text": [ 547 | "Completed 0.01%: Epoch/step [1/30 - 1/469], loss 2.3744, acc 7.81, best acc 7.81\n", 548 | "Completed 1.43%: Epoch/step [1/30 - 201/469], loss 0.5953, acc 82.81, best acc 82.81\n", 549 | "Completed 2.85%: Epoch/step [1/30 - 401/469], loss 0.6471, acc 82.81, best acc 82.81\n", 550 | "Completed 3.34%: Epoch/step [2/30 - 1/469], loss 0.6236, acc 82.03, best acc 82.81\n", 551 | "Completed 4.76%: Epoch/step [2/30 - 201/469], loss 0.6567, acc 81.25, best acc 82.81\n", 552 | "Completed 6.18%: Epoch/step [2/30 - 401/469], loss 0.5539, acc 85.94, best acc 85.94\n", 553 | "Completed 6.67%: Epoch/step [3/30 - 1/469], loss 0.6836, acc 77.34, best acc 85.94\n", 554 | "Completed 8.10%: Epoch/step [3/30 - 201/469], loss 0.6340, acc 78.91, best acc 85.94\n", 555 | "Completed 9.52%: Epoch/step [3/30 - 401/469], loss 0.6380, acc 82.03, best acc 85.94\n", 556 | "Completed 10.01%: Epoch/step [4/30 - 1/469], loss 0.5994, acc 82.81, best acc 85.94\n", 557 | "Completed 11.43%: Epoch/step [4/30 - 201/469], loss 0.5376, acc 88.28, best acc 88.28\n", 558 | "Completed 12.85%: Epoch/step [4/30 - 401/469], loss 0.5952, acc 85.94, best acc 88.28\n", 559 | "Completed 13.34%: Epoch/step [5/30 - 1/469], loss 0.6287, acc 77.34, best acc 88.28\n", 560 | "Completed 14.76%: Epoch/step [5/30 - 201/469], loss 0.6367, acc 82.81, best acc 88.28\n", 561 | "Completed 16.18%: Epoch/step [5/30 - 401/469], loss 0.6259, acc 82.81, best acc 88.28\n", 562 | "Completed 16.67%: Epoch/step [6/30 - 1/469], loss 0.6444, acc 85.94, best acc 88.28\n", 563 | "Completed 18.10%: Epoch/step [6/30 - 201/469], loss 0.6338, acc 77.34, best acc 88.28\n", 564 | "Completed 19.52%: Epoch/step [6/30 - 401/469], loss 0.6168, acc 85.16, best acc 88.28\n", 565 | "Completed 20.01%: Epoch/step [7/30 - 1/469], loss 0.6262, acc 81.25, best acc 88.28\n", 566 | "Completed 21.43%: Epoch/step [7/30 - 201/469], loss 0.5777, acc 85.94, best acc 88.28\n", 567 | "Completed 22.85%: Epoch/step [7/30 - 401/469], loss 0.6334, acc 79.69, best acc 88.28\n", 568 | "Completed 23.34%: Epoch/step [8/30 - 1/469], loss 0.5715, acc 82.03, best acc 88.28\n", 569 | "Completed 24.76%: Epoch/step [8/30 - 201/469], loss 0.5828, acc 83.59, best acc 88.28\n", 570 | "Completed 26.18%: Epoch/step [8/30 - 401/469], loss 0.6105, acc 83.59, best acc 88.28\n", 571 | "Completed 26.67%: Epoch/step [9/30 - 1/469], loss 0.5210, acc 86.72, best acc 88.28\n", 572 | "Completed 28.10%: Epoch/step [9/30 - 201/469], loss 0.5730, acc 88.28, best acc 88.28\n", 573 | "Completed 29.52%: Epoch/step [9/30 - 401/469], loss 0.7762, acc 74.22, best acc 88.28\n", 574 | "Completed 30.01%: Epoch/step [10/30 - 1/469], loss 0.7092, acc 75.78, best acc 88.28\n", 575 | "Completed 31.43%: Epoch/step [10/30 - 201/469], loss 0.5960, acc 85.94, best acc 88.28\n", 576 | "Completed 32.85%: Epoch/step [10/30 - 401/469], loss 0.6185, acc 86.72, best acc 88.28\n", 577 | "Completed 33.34%: Epoch/step [11/30 - 1/469], loss 0.6776, acc 77.34, best acc 88.28\n", 578 | "Completed 34.76%: Epoch/step [11/30 - 201/469], loss 0.5619, acc 87.50, best acc 88.28\n", 579 | "Completed 36.18%: Epoch/step [11/30 - 401/469], loss 0.6417, acc 84.38, best acc 88.28\n", 580 | "Completed 36.67%: Epoch/step [12/30 - 1/469], loss 0.6395, acc 81.25, best acc 88.28\n", 581 | "Completed 38.10%: Epoch/step [12/30 - 201/469], loss 0.5894, acc 82.81, best acc 88.28\n", 582 | "Completed 39.52%: Epoch/step 
[12/30 - 401/469], loss 0.6575, acc 82.03, best acc 88.28\n", 583 | "Completed 40.01%: Epoch/step [13/30 - 1/469], loss 0.6557, acc 82.03, best acc 88.28\n", 584 | "Completed 41.43%: Epoch/step [13/30 - 201/469], loss 0.5732, acc 86.72, best acc 88.28\n", 585 | "Completed 42.85%: Epoch/step [13/30 - 401/469], loss 0.6268, acc 82.81, best acc 88.28\n", 586 | "Completed 43.34%: Epoch/step [14/30 - 1/469], loss 0.5557, acc 87.50, best acc 88.28\n", 587 | "Completed 44.76%: Epoch/step [14/30 - 201/469], loss 0.5617, acc 85.16, best acc 88.28\n", 588 | "Completed 46.18%: Epoch/step [14/30 - 401/469], loss 0.6255, acc 80.47, best acc 88.28\n", 589 | "Completed 46.67%: Epoch/step [15/30 - 1/469], loss 0.4741, acc 89.06, best acc 89.06\n", 590 | "Completed 48.10%: Epoch/step [15/30 - 201/469], loss 0.5339, acc 87.50, best acc 89.06\n", 591 | "Completed 49.52%: Epoch/step [15/30 - 401/469], loss 0.5534, acc 81.25, best acc 89.06\n", 592 | "Completed 50.01%: Epoch/step [16/30 - 1/469], loss 0.5734, acc 85.94, best acc 89.06\n", 593 | "Completed 51.43%: Epoch/step [16/30 - 201/469], loss 0.6231, acc 82.81, best acc 89.06\n", 594 | "Completed 52.85%: Epoch/step [16/30 - 401/469], loss 0.6168, acc 80.47, best acc 89.06\n", 595 | "Completed 53.34%: Epoch/step [17/30 - 1/469], loss 0.6456, acc 78.12, best acc 89.06\n", 596 | "Completed 54.76%: Epoch/step [17/30 - 201/469], loss 0.5973, acc 82.03, best acc 89.06\n", 597 | "Completed 56.18%: Epoch/step [17/30 - 401/469], loss 0.5220, acc 85.16, best acc 89.06\n", 598 | "Completed 56.67%: Epoch/step [18/30 - 1/469], loss 0.6613, acc 81.25, best acc 89.06\n", 599 | "Completed 58.10%: Epoch/step [18/30 - 201/469], loss 0.5319, acc 87.50, best acc 89.06\n", 600 | "Completed 59.52%: Epoch/step [18/30 - 401/469], loss 0.6313, acc 82.81, best acc 89.06\n", 601 | "Completed 60.01%: Epoch/step [19/30 - 1/469], loss 0.6134, acc 82.81, best acc 89.06\n", 602 | "Completed 61.43%: Epoch/step [19/30 - 201/469], loss 0.6066, acc 85.94, best acc 89.06\n", 603 | "Completed 62.85%: Epoch/step [19/30 - 401/469], loss 0.5751, acc 85.16, best acc 89.06\n", 604 | "Completed 63.34%: Epoch/step [20/30 - 1/469], loss 0.5277, acc 85.16, best acc 89.06\n", 605 | "Completed 64.76%: Epoch/step [20/30 - 201/469], loss 0.7158, acc 78.91, best acc 89.06\n", 606 | "Completed 66.18%: Epoch/step [20/30 - 401/469], loss 0.5503, acc 86.72, best acc 89.06\n", 607 | "Completed 66.67%: Epoch/step [21/30 - 1/469], loss 0.7054, acc 79.69, best acc 89.06\n", 608 | "Completed 68.10%: Epoch/step [21/30 - 201/469], loss 0.6488, acc 75.78, best acc 89.06\n", 609 | "Completed 69.52%: Epoch/step [21/30 - 401/469], loss 0.6451, acc 84.38, best acc 89.06\n", 610 | "Completed 70.01%: Epoch/step [22/30 - 1/469], loss 0.6415, acc 77.34, best acc 89.06\n", 611 | "Completed 71.43%: Epoch/step [22/30 - 201/469], loss 0.6619, acc 79.69, best acc 89.06\n", 612 | "Completed 72.85%: Epoch/step [22/30 - 401/469], loss 0.6728, acc 79.69, best acc 89.06\n", 613 | "Completed 73.34%: Epoch/step [23/30 - 1/469], loss 0.6324, acc 82.03, best acc 89.06\n", 614 | "Completed 74.76%: Epoch/step [23/30 - 201/469], loss 0.5901, acc 84.38, best acc 89.06\n", 615 | "Completed 76.18%: Epoch/step [23/30 - 401/469], loss 0.5975, acc 83.59, best acc 89.06\n", 616 | "Completed 76.67%: Epoch/step [24/30 - 1/469], loss 0.6897, acc 77.34, best acc 89.06\n", 617 | "Completed 78.10%: Epoch/step [24/30 - 201/469], loss 0.6770, acc 82.03, best acc 89.06\n", 618 | "Completed 79.52%: Epoch/step [24/30 - 401/469], loss 0.6065, acc 84.38, best 
acc 89.06\n", 619 | "Completed 80.01%: Epoch/step [25/30 - 1/469], loss 0.5896, acc 83.59, best acc 89.06\n", 620 | "Completed 81.43%: Epoch/step [25/30 - 201/469], loss 0.6912, acc 82.03, best acc 89.06\n", 621 | "Completed 82.85%: Epoch/step [25/30 - 401/469], loss 0.5715, acc 84.38, best acc 89.06\n", 622 | "Completed 83.34%: Epoch/step [26/30 - 1/469], loss 0.5760, acc 84.38, best acc 89.06\n", 623 | "Completed 84.76%: Epoch/step [26/30 - 201/469], loss 0.6141, acc 82.03, best acc 89.06\n", 624 | "Completed 86.18%: Epoch/step [26/30 - 401/469], loss 0.5529, acc 85.16, best acc 89.06\n", 625 | "Completed 86.67%: Epoch/step [27/30 - 1/469], loss 0.6197, acc 79.69, best acc 89.06\n", 626 | "Completed 88.10%: Epoch/step [27/30 - 201/469], loss 0.6477, acc 77.34, best acc 89.06\n", 627 | "Completed 89.52%: Epoch/step [27/30 - 401/469], loss 0.6163, acc 85.16, best acc 89.06\n", 628 | "Completed 90.01%: Epoch/step [28/30 - 1/469], loss 0.7884, acc 74.22, best acc 89.06\n", 629 | "Completed 91.43%: Epoch/step [28/30 - 201/469], loss 0.6281, acc 82.03, best acc 89.06\n", 630 | "Completed 92.85%: Epoch/step [28/30 - 401/469], loss 0.5852, acc 83.59, best acc 89.06\n", 631 | "Completed 93.34%: Epoch/step [29/30 - 1/469], loss 0.5584, acc 84.38, best acc 89.06\n", 632 | "Completed 94.76%: Epoch/step [29/30 - 201/469], loss 0.6711, acc 79.69, best acc 89.06\n", 633 | "Completed 96.18%: Epoch/step [29/30 - 401/469], loss 0.6539, acc 82.81, best acc 89.06\n", 634 | "Completed 96.67%: Epoch/step [30/30 - 1/469], loss 0.6186, acc 83.59, best acc 89.06\n", 635 | "Completed 98.10%: Epoch/step [30/30 - 201/469], loss 0.6031, acc 82.03, best acc 89.06\n", 636 | "Completed 99.52%: Epoch/step [30/30 - 401/469], loss 0.5808, acc 82.81, best acc 89.06\n" 637 | ], 638 | "name": "stdout" 639 | }, 640 | { 641 | "output_type": "display_data", 642 | "data": { 643 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XeYVNX5wPHvbIOtsMDAUqQJHEFA\nhSAg0kSxBDtqjCUajYlg7FES87NGY1TEmhhii41YUSwgKCBNBBHpHOoK7LKwsH3ZvvP7496Znb6z\nuzO7e8f38zw87Nz6TrnvPfecc++xORwOhBBCRJeYlg5ACCFE+ElyF0KIKCTJXQghopAkdyGEiEKS\n3IUQIgrFtXQATrm5xY3utpOenkR+/rFwhhMxVooVrBWvxBoZVooVrBVvOGK121Nt/qZHRck9Li62\npUMImZViBWvFK7FGhpViBWvFG8lYoyK5CyGE8CTJXQghopAkdyGEiEKS3IUQIgpJchdCiCgkyV0I\nIaKQJHchhIhCreYmpsZap3OJz8xnaO/0lg5FCCFaDcuX3D9ZsYfXP9/a0mEIIUSrYvnkXuuAmhoZ\ncEQIIdxZPrkbD1WQ5C6EEO4sn9yxgYwUKIQQniyf3G1IuV0IIbxZPrlL0V0IIXxZPrnbbFJyF0II\nb9ZP7kjBXQghvFk/uduk1l0IIbxZPrljM/q6CyGEqBPS4weUUk8AY83l/661/sht3kTg70ANoIEb\ngXHA+8AWc7FNWus/hjFuF6mWEUIIX/UmdzN5D9Zaj1ZKdQTWAx+5LTIbmKi1PqCUeh84BzgGfKO1\nnhqJoN3ZJLsLIYSPUKpllgGXmX8XAMlKKfdRXYdrrQ+Yf+cCHcMYXwhsUuMuhBBebI4GlHqVUjcB\nY7XW1/iZ1xVYDowEhgD/BHYBHYCHtNaLgm27urrG0ZiRwO9+dhm7swqZ+8T5DV5XCCGigM3fxJAf\n+auUuhC4AZjsZ15n4FNgmtb6qFJqJ/AQ8B7QF1iilOqnta4MtP38/GOhhuKhuroGcJCbW9yo9Zub\n3Z5qmVjBWvFKrJFhpVjBWvGGI1a7PdXv9FAbVM8G7gPO0VoXes1LA+YD92mtFwJorbOAd81Fdiul\ncoDuwN5GRR+M3KAqhBA+6q1zV0q1A54Epmit8/wsMhOYpbVe4LbOVUqpu82/M4AuQFZ4QvZkkzp3\nIYTwEUrJ/QqgE/CeUso5bTGwCfgSuBbor5S60Zz3DjAHeMesykkAbg5WJdMk0ltGCCF81Jvctdaz\nMbo7BtImwPRmaeGU+1OFEMKX5e9QlYK7EEL4snxyN+5iEkII4c7yyd2Z2hvSX18IIaKd9ZO7md0l\ntQshRB3LJ3cXye5CCOFi+eRuM4vuDsnuQgjhYvnk7iRV7kIIUcfyyV06ywghhC/rJ3fzfym5CyFE\nHcsn97qiu2R3IYRwsnxyl5K7EEL4snxyR/q5CyGED8snd5tkdyGE8GH95O7K7ZLdhRDCyfLJ3Unq\n3IUQoo7lk7t0cxdCCF+hjqH6BDDWXP7vWuuP3OadCTwG1ABfaK0fMafPAkZh1IbfprVeG+bYAbfH\nD0jJXQghXEIZQ3UiMFhrPRo4B3jGa5HngEuBMcBkpdQgpdR4oL+5zg3mMhEm2V0IIZxCqZZZBlxm\n/l0AJCulYgGUUn2BPK31fq11LfAFMMn89zGA1nobkK6USgt38CCP/BVCCH9CGUO1Big1X96AUfVS\nY77OAHLdFj8MHI8xoPY6t+m55rJFgfaTnp5EXFxs6JGb2rQx3kLHjimkJiU0eP2WYLentnQIDWKl\neCXWyLBSrGCteCMVa0h17gBKqQsxkvvkIIsFat+st90zP/9YqKF4qKw0zjNHjpRQnhjfqG00J7s9\nldzc4pYOI2RWildijQwrxQrWijccsQY6OYTaoHo2cB9wjta60G1WNkaJ3Km7Oa3Sa3o34GAD4g2Z\nDLMnhBC+QmlQbQc8CUzRWue5z9NaZwJpSqneSqk4YAqw0Pw31Vx/GJCttY7MqVTq3IUQwkcoJfcr\nMOrQ31NKOactBjZprecCNwNzzOnvaq13ADuUUuuUUquAWmB6eMOuY5MWVSGE8BFKg+psYHaQ+cuA\n0X6mz2haaKGRahkhhPBl/TtUpeAuhBA+LJ/cnaTgLoQQdSyf3G0yiKoQQviwfHLfuPsoADU1tS0c\niRBCtB6WT+5lFdUA7M2xxk0LQgjRHCyf3LukJwLQs3NKC0cihBCth+WTu+qZDkhvGSGEcGf55B5j\ntqfW1kp6F0IIJ8sn97rBOiS5CyGEk+WTe4yMxCSEED4sn9yd3dxrJbsLIYRLFCR3KbkLIYS3KEju\nxv8O6S8jhBAulk/uUucuhBC+LJ/cv92SA8CqzTktHIkQQrQelk/uhaWVACzfkN3CkQghROth+eTu\nJPcwCSFEnVAHyB4MfALM0lq/4Da9O/C226J9gRlAAvAIsNucvkhr/WhYIg5AbmISQog69SZ3pVQy\n8Dzwtfc8rXUWMMFcLg5YCszDGBz7Xa313WGMNSjJ7UIIUSeUapkK4Dygvkrt64APtdYlTQ2qMeQm\nJiGEqBPKANnVQLVSqr5FbwQmu70er5RaAMQDd2ut1wdbOT09ibi42Pr2EZTdntqk9ZuLVeJ0slK8\nEmtkWClWsFa8kYo1pDr3+iilRgPbtdZF5qTVQK7W+nNz3hvAkGDbyM8/1uQ4cnNb/4AddnuqJeJ0\nslK8EmtkWClWsFa84Yg10MkhXL1lpgBfOV9orbdrrT83//4WsCulmlYsF0IIEbJwJfcRwAbnC6XU\nPUqpK82/B2OU4mvCtC8hhBD1CKW3zHBgJtAbqFJKTcXoEbNXaz3XXKwrcNhttXeAN5VSfzD3cUM4\ngxZCCBFcKA2q6zC7OwZZZojX6wPAxCZFJoQQotGi5g5VIYQQdSS5CyFEFJLkLoQQUcjyyX3s0K4A\njBmc0cKRCCFE62H55J59pBSAlfI8dyGEcLF8cs/MscadaEII0Zwsn9ydY6gKIYSoY/nkDpLdhRDC\nm+WTe4zkdiGE8GH55C4FdyGE8GX55G6T7C6EED6sn9wltwshhA9J7kIIEYUsn9wT4mUMECGE8Gb5\n5H75xH4AjDqxSwtHIoQQrYflk3t6ShsAOrdPbOFIhBCi9QhpgGxzqLxPgFla6xe85mUC+wHnMHpX\naa2zlFKzgFGAA7hNa702XEG7c9a5OxyR2LoQQlhTKMPsJQPPA18HWexcrXWJ2zrjgf5a69FKqYHA\nq8Dopgbrj83M7g4kuwshhFMo1TIVwHlAdgO2Own4GEBrvQ1IV0qlNTy8+knJXQghfIUyhmo1UK2U\nCrbYS0qp3sAK4M9ABrDObX6uOa0o0AbS05OIi2t4z5ejpVUAJCYmYLenNnj9lmCVOJ2sFK/EGhlW\nihWsFW+kYg2pzr0e9wMLgDyM0vqlfpaptzd6fv6xRu28
oMBYr6S0gtzc1v/4X7s91RJxOlkpXok1\nMqwUK1gr3nDEGujk0OTeMlrrN7TWh80S/hfAEIwqHPehkboBB5u6L39KyoyS+4Lv9kVi80IIYUlN\nSu5KqXZKqS+VUgnmpPHAZmAhMNVcZhiQrbWOyKk0J69xJX4hhIhmofSWGQ7MBHoDVUqpqcA8YK/W\neq5S6gtgtVKqDFgPfKC1diil1imlVgG1wPSIvQMhhBA+QmlQXQdMCDL/WeBZP9NnNCmyEMmjZYQQ\nwpfl71AVQgjhS5K7EEJEIesnd3nmrxBC+LB8cpfULoQQviyf3O3p8jRIIYTwZvnk3rNzCgBdOiS1\ncCRCCNF6WD65l5ZXA3BIbmYSQggXyyf3iqqa+hcSQoifGcsnd3mMuxBC+LJ8cpdBOoQQwpflk3ti\nQjieWiyEENHF8sm9SwfpCimEEN4sn9zlNiYhhPBl+eQuTx8QQghflk/uQgghfFk+ucdI0V0IIXyE\n1NVEKTUY+ASYpbV+wWveRODvQA2ggRuBccD7wBZzsU1a6z+GK2gPktuFEMJHKMPsJQPPA18HWGQ2\nMFFrfUAp9T5wDnAM+EZrPTVskQYguV0IIXyFUi1TAZwHZAeYP1xrfcD8OxfoGI7AQmWTahkhhPBR\nb3LXWldrrcuCzC8CUEp1BSYDX5izBiml5imlViilzgpLtEIIIUISlts7lVKdgU+BaVrro0qpncBD\nwHtAX2CJUqqf1roy0DbS05OIi4ttUhx2e2qT1m8uVonTyUrxSqyRYaVYwVrxRirWJid3pVQaMB+4\nT2u9EEBrnQW8ay6yWymVA3QH9gbaTn5+4x/Zm9w2jg5pbcnNLW70NpqL3Z5qiTidrBSvxBoZVooV\nrBVvOGINdHIIR1fImRi9aBY4JyilrlJK3W3+nQF0AbLCsK+AHPL8MCGEcAmlt8xwjATeG6hSSk0F\n5mGUwr8ErgX6K6VuNFd5B5gDvKOUuhBIAG4OViXTZNKoKoQQHupN7lrrdcCEIIu0CTD9/MYEJIQQ\nouksf4eqEEIIX1GU3KXSXQghnKJipIvSsipKy6paOgwhhGg1oqjkLoQQwkmSuxBCRCFJ7kIIEYUk\nuQshRBSS5C6EEFFIkrsQQkQhSe5CCBGFoiK5J7Zp2qOChRAi2kRFcu/bvb0MtyeEEG6iIrmD8fAB\nhzz3VwghgChJ7lv2HAUgv7iihSMRQojWISqSu1P20dKWDkEIIVqFqErutbUtHYEQQrQO0ZXcpc5d\nCCGAEB/5q5QaDHyCMVbqC17zzgQeA2qAL7TWj5jTZwGjMNo6b9Narw1n4P44aiW5CyEEhDaGajLw\nPPB1gEWeA87GGAD7G6XUh4Ad6K+1Hq2UGgi8CowOT8iBSW4XQghDKNUyFcB5QLb3DKVUXyBPa71f\na10LfAFMMv99DKC13gakK6XSwhZ1ANIVUgghDKEMkF0NVCul/M3OAHLdXh8Gjgc6AevcpueayxYF\n2k96ehJxcU270zQlpS12e2qTttEcrBCjOyvFK7FGhpViBWvFG6lYwz3MXqAbReu9gTQ//1iTd15Q\ndIzc3OImbyeS7PbUVh+jOyvFK7FGhpViBWvFG45YA50cmprcszFK5E7dzWmVXtO7AQebuK96OaQr\npBBCAE3sCqm1zgTSlFK9lVJxwBRgoflvKoBSahiQrbWO+Km0rTxATAghgNB6ywwHZgK9gSql1FRg\nHrBXaz0XuBmYYy7+rtZ6B7BDKbVOKbUKqAWmRyJ4p5P6d2LDziN0SU+K5G6EEMIyQmlQXQdMCDJ/\nGX66OWqtZzQpsgbo1imFDTuPNNfuhBCi1YuKO1RzC8oAuUNVCCGcoiK5f7/tEAAz//djC0cihBCt\nQ1Qkd6fC0sqWDkEIIVqFqEruQgghDFGV3ONio+rtCCFEo0VVNmxtz5apqq5p6RCEED9TUZbcWzqC\nOp+tyuT3T31DZk7Ax+mIMCkpq2rpEIRodaIrudN6svtHy/YA8KP0v4+oLZl53Prscuat3NvSodSr\nVp5J3WCHC8ooPiYdJRojupJ7E4+dDbuO8IeZS8k6ImOxWsWPO4yT56K1+1s4kuDmrdjLjU8s4Uhh\nWcT2se9QMet0bv0LWsiMl77ltudWtHQYlhRVyb2pXp+/ncqqWr76vnUnClGnNV2tBfPxCuPKYmtm\nfsT28eBra3lx7iZqZDDhVuHLNft4c6HmWHmV3xssyyqqI7p/Se4WVVvrYOGafeQVlbd0KC3KecjY\nbPU+VbrZVNfUNrlx/0hhWaOrcVpT25NTSVkVFZU/rw4G7y7exZIfsrjlmeXMem+Dx7y12w8zfdYy\nFn73U8T2L8ndjfMmqNZ4cHj7btsh/rd4F/9454ewb7uqupY92UVs3ns04DI5ecdY+mNW2Pfd3ErK\nqjgchrEEnKprarnpyaU898HGRm9jd1Yh9/zrW175fKtr2lsLNZ+uyvTYj5Uakm99djm3PLOs2fdb\nU+t5os3MKeKdRTua3P5x8GgpL3y0icKSipCW37I3j/99vdP1euUm4wno81dFrq1IknszWqdz+fdH\nGz1+bLuyCvnt44vZmpnXoG0VlhgnotwC35J79pFSqmuCX5rvzi5k1Wb/j9h/8LU1/O2N73n63Q0c\nDVBH/JfZq3ljgW7W3kD+Lm3X72haHfNtzy5nxr9XN2kb7pwJd8Nu/yfGQyGcSHYeKATg2y2HXNMW\n/5DFXLORHuD+V9Zw67PL6/2e/amqro14lYBTQUkFB48abVg1QRJqWUU1DoeDHfsLwnrS+sNT3/DQ\na2tdrx9+/Xu+WneAH3c1rKNDdU0tBw6XuF6/9MkWftiRy4ff7AmylqeFzdwuFJXJfe6yPQ3qpeJw\nOLxKbw62ZOah94W3fvTFuZv4bOVeCkrqWv8/XZkJwAdLd3OsvIrS8qb9sHcdKOSvL3/HS59s8ZlX\nWVVDgVnSePSNdbz82Ta/yeHg0brPor4Draw8skli056jZOYU8dGy3dz4jyUepVnA47NsDGe62Xeo\nacMNlFdWM/PdH9mxvyDocvNX73P9nVdUzm8fX8yS9Q2/AsrJM76jiqrQqjpqHQ5XFd5tzy1n+izf\nUvSyDdm8+vm2BscSzJ0vrOS+/3wXdJms3BKmz1rG0+9t4PG3f+Ch19aEbf81tQ72uSVlp0o/96C4\nHwvrd+aS7dax4uXPtnL/q2vYYhbCKs3P3d92GiKSlQRRl9zLKqr5dFUmz324kc17j7Irq7DeddZu\nP+xRelu24SAz//cj/3hnfaNicB54APnFFT71r/7qY202uOWZ5fzxmeUh7cM7GVXX1PLknPV88M1u\nAH7wU6L98+zV3PnCSo8fcb1VUG7za2r9lPiC1HUvXLufV7/wTBafrNjLCx9tqmendWa9t4GHX/+e\nz1YZdZMrN+UAsPdgEZv2+JaO84srQno66NbMPI/v6cHX1roO2EAcDgc/7MilpKyK5z/cyLebc1zz\nVm7KYcvePL8
n1UDWmd/Rm19qz/1E4JD/7/zt3P3PVew8UEB5gLrv1+dvZ8WmgyGfMJrK+Tvcvs84\nIW7ZayTOo0UVHCuvYv5qz/rowtJKdh4IfvJsrIVr93PTk0tZviGb8spqnv9wE399ue6ktGbbYQAy\nD5pXqk1s42mOFqJwj6Ha4tyP66ffNRoxXp1xRtB1/CWJQGodDrZl5tO/Rzt2ZRXSvVMy7VLauOYf\nzj/GX2bXnSiWbzxIYps4fjWpv2vanmzjB9IhrW3QA7mmtpacvDK6dUxyNRg6HA6Wbchm9da6S/aS\nsir2Hixi20/BrzTyi41Se2VVXXKvrXWw5IcDDBtg93gfTs7oKiprePi/azl49Bgv3zPRNT/Yj9RZ\nx/jb8wa6pn2ywn8dY1lFNd/rw4wa1IX4EAZKf+S/3/tM27k/n7teXEnn9ESKj1Xx2E2jaJec4Hdf\nT/l5gmhldS0J8YH3/dbCHSxZn0Vim1jKKmpYv/MIowcbo0lu93OV982PWby3ZBd//73PcAc+7n/l\nO36hOnPB6X0aXJx7b/EubG7FtJWbDjL+5O4eyyzfaFTBhVLYWbHxIPb2ifTqksL3Ope3F+1g0vAe\nXHXWAI/lKqtq2HGggEG9OhAT4/tLCHbV97c3vmdPdhEPXDfC7/x3vtrJKreTZ3VNLTP+/S0VlTXM\nnD6G9FTf32owL3y0ifEndws43/lbfW3+dk4ZYA+6rbKKag6ZBYM12w5z0wUOYhqZ7CPZvhdScldK\nzQJGYfzsbtNarzWndwfedlu0LzADSAAeAXab0xdprR8NV9De+nRLY6+ZMNfvDFwHW1hSwZL1WZx9\nak8S2zTuvLZ8Qzb/XVBX0oqLjWH2nyZQWl5FaVkVOXm+ddRL1md5JPd/frwZgH/dOd5tqbofx4ff\n7OaisX14Z9FOlqzP4ndTBjGwdzrtU9rw464jHvsH+Ptb67jijP6Eru4XtXLzQd5auIO3Fu7glRln\n+L2qqHU4uPnpb1yv3etO/f2mHQ4HC77b5zvDS2FJBT8dKmHo8R15e9EOVm3OIbegnEvG9Q26nr8e\nQiVlVaw3+3gfzje+gzueX+H3xB6o6utwfhk/7DAux2tqHR7JbPPeo67qk7IK35Ktv/7lzu9pU4D6\nd/ckfiC3lAO5e7ng9D4+ud39MRbVNbUede/V1bUsWOP5Wf93gWbsSd2IsdmorXV4fEe2AKfjtdsP\nu/5+e9EOn/lfrzvgk9zfWrSDFRsPcs3kAUwc1sNnnXtfWuUzzeFwUFpe7SrgLN+YTZcOviOoOb9D\np5ueXOr6+64XV/LyvRMblFB/2JHr92oWfNtyvAt73lWXf/mPZxvNvkPF9M5IY96KvaQkxTPh5O71\nXj26ene1ZHJXSo0H+mutRyulBgKvYo68pLXOwhylyRxDdSnGEHxTMYbcuzsyYXtq41baeiVIneF/\nPtvK1sx85pn13C/cPpbisqqgH/C8FXs5bUgGVdW11NQ6yMzxrQ4BuO3ZFdQ6HFxxRr+Q47756W8Y\n1Dsd8EySn3/7E907JbuSyX8+M+qZX51xBkcKfRObex25u4KSCvS+Ak4d2DlgV0HnQeQAfsop5qHX\n1/os49uzwPcDyysqp31qG2JsNvZkF/H+0t0+yzgHVXF64LW1FJVW8vANp/KT+bkedKvnDFRNcu9L\n3/qd/uZ83+++tLyKjbuP4nA46NQukeqaWr+ldjBKk+7ck5nzKtCf8srg7Q7OnhGhKnUr8R44XML9\nr9bVQbsnOYCZAeLak13E8d3SuPnpb8hwS57uBeyi0krSzCsb7+qzUDirUbyPCSd/J8E5X+/kq+8P\nuF4v/sF/e0N9VVO1tQ5iYhtfuVFcWvcZz/lqp8e8/3zq2a5T5PYocX8NqM7bCpz3Miz4bh9HA3RR\nzj5SSnLbuAY36DZGKMXXScDHAFrrbUqpdKVUmtbau5vEdcCHWusSpVSYwwyuvj7O7y3exSXj+/ok\nl1tCqN/+eMVevt16yHUZNsHPpV1trcN1pn538S6f+VXVtazekuMzHYySl/v/Tv4aCn/7+OIgkfoe\nDI++sY6jReWkJScwsFe63yXdW/D9JXaHw8FafdhnupPNZuN/X+9k4dr9nNg7nVunnsSjb67zWW7z\nnqM87dbX91DeMddBk1dU4bor2GaDO19YQfdOyUy7eIjffQbrdeHtpU+2uJJQQ+UVldMhrW3QZcor\nq5n2dPAufs46ZW/+3kV1TS3z3a563BO7PwdyfRsLAR5z+w72uzUouu9zd1Yhpwyws/dgUaP6oDt/\n8+6H30fL9pCSGM+Xa/xfubkn9qYoKasixmYjLTmBqupaNuw6QkVVDas25/Dn608F/F/hOc35eicH\n845x7dmKxT8EjmlLZh6rNvk/dp28T0T+CmBO7vX4AHuy668ma6xQknsG4H605prTvJP7jcBkt9fj\nlVILgHjgbq110NbJ9PQk4kKoa22MBWv2cWJ/O8XHGtcT5ZBbw9vSH7N95rdrX//A3LO9SgNOO8xu\nb94t+u8t8T1JBBMb7/tVOksPT85Zz6M3n+aa/uLczQ3atndJplOnVNffj79d189+S2Y+v39qqc/6\ndnsqP3lV0/zZrV3iO7cqgTZt4ykoqaSgpJIn5jSuQdtdKHXMgcx6fwP/undS0MLDwYLQ+jm769Ax\nBYAUP20c3iXzcHMvfPz7063MeeRcHglaaKiTkJhATIyNguIKbp251HXVmpiYwOHiSjI6JvGZW1/8\npqiv3eXOF1YC8PBNo7l/tueV3Ptf7+Smi4bUUxiCpeuzuP3KYUHrvUMZ3c0RE8OxmsbXr9jtqfUv\n1AiNqXj2+aUrpUYD291K86uBXK315+a8NwD/xTBTfhNuJAml6u27jdkBewk01ZEj/ktPzWlmPTcz\n3fevuvrP7fU0vLq7deZSn2lHjjSs2+Cib/dSVBw4CW5yu0Rd43aF81OAy/2GaMpdkVm5pVxw9zye\nvfX0gMs88mrwbn7+XPSnecy6fTzLgpQYm0NlVQ0ffe1bvx7I1Q8s8Dv9y9U/8aVXz5amCvU3+sUK\n32qST5fvoX1SfEjrX3TPpw2Ky5/H/FzxNsQ/31/PZRNCr871FujkEEpXyGyMkrpTN8C7EnEK8JXz\nhdZ6u9b6c/PvbwG7UioyxfIQrWhgvWdDfLjMt345mm3a07Bqjuc+2MiiIM/rce9VUVXd+p6LEokH\nV93xzDf19olvDv4aT63kSICqF+/upa2Z+70P4RRKcl+I0UCKUmoYkK219i5SjQBcFapKqXuUUlea\nfw/GKMVHrPNsSz8uIFz1iFbRkH7qQkTSrgORq7O2unqrZbTWq5RS65RSq4BaYLpS6jqgUGs911ys\nK+De6vYO8KZS6g/mPm4Ib9hCCCGCCanOXWs9w2vSBq/5Q7xeHwAmIoQQokVE3eMHhBBCSHIXQoio\nJMldCCGiUFQk96aOeiOEEC0pEgOvREVy75xe/x2iQgjRWjX2ERnBREVy
b5PQovdHCSFEk4QyBkFD\nRUVyF0IIK2vqmK7+REVynzyyV0uHIIQQjRaJZsOoSO79jmvf0iEIIUSjRaJTSFQkd+ktI4SwMqlz\nD6g5hpsVQojIqG7C8+ADiYrkHh8XFW9DCPEztfNA+B//LFlRCCFaWGVV+McxkOQuhBAtLJTR5BpK\nkrsQQkQhSe5CCNHCpJ97lElPbdPSITSrn9v7FaIlhZTclVKzlFLfKqVWKaVGeM3LVEotV0otNf91\nr2+dSEhs03qfL/PMraf7nT6sv72ZI2mc6RcPqX+hEDxw3QiO754Wlm1Fo/492rV0CKIZtU9JiOj2\n603uSqnxQH+t9WiMsVCf87PYuVrrCea/rBDXCasnbj6NLumJDVpn6PEdIxSNp7SkBK6ZPMB3Rpga\nUR67aRRJbUIaMbFRhis7CfENv8i7dHxf7O3bAtArI5W05ATuu+YX9MpIDXeIYfHqjDPqXWb4ADu/\nmtQ/7Pt+atppzLhqWMD5iRH8fkMRzva+my4YFMat1elhT2bECZ0btM71554QkVhC8fQt/gt94RLK\nETsJ+BhAa70NSFdK1Vf8asw6TZLcNp4Hrq+7QPjjpYFLm906JTNmcAZXneUn4XqZNKxHSPsP1Ne+\nhz0FgInDevC3G0d6zmxCPduwAXYun9iPaRcNJqNDEi/cMa7xGwvBA9eN4IIxvV2vn7ttbL3rnDeq\nF4//fjTTLhrMXVec7LGtV2dGT+i2AAARLElEQVScwaO/GxlkbTjrF8f5nX7+ab09rgAuOr0PAO8/\n9kvXtLuuODnggdsuSInp4RtO5dHfjeT528fy+O9H8fBvT/WYP/2SIYw+sUvQuIN5+pYxPtNibDY6\npLXFZrPx1LTT6JhWV33Vq0sqT007jaemneaxzgPXeV4MX3uO4p93jgvpBOXt8on96u2t8fQtY7hs\nwvE+03vYkwFIS04I6QRw/mm9GT4gcAJObtv4k9jDN4wkJSk+6DLXnXsCv5syiOduG8vM6WMYPTij\n2a76p044ntl/msCdV5zk8/3VF3djhPJJZgDr3F7nmtOK3Ka9pJTqDawA/hziOh7S05OIi2v8h2y3\ne5YGe3YL/LyZp24bR2pSAlXVNQCMO6U7g/t2pLC0Ev1TPt9vO+Ra9varhvPr8wZyw98W0aNzCgcO\nl/jd5iUT+1FyrIrPV+4F4PopJzJqSAYd0trSNiHOFeNbD53D1Q8sACAx0fcLtacn8upfJ/Pcu+tZ\ntGafz/wJw3pw11XDg30UvrFN6MdHS3e5Xo89uTvLf8xy7S83vyzgun/+zQjs9lTs9lSGnpDB5NP6\nsC+nmD49O/gsO+PaEQzs04HfPPQlAJ07Gwn43M7+z+t2eypD+3Vi464jrve29IcDANx6+cmcNbIX\ni77f71r+zBE9ue1XpwAw85117M4yfk43XDyUGy4e6rHtMcN6EB8Xy7b9BazenOOa/tStY1G9OnD+\nXZ94LJ+WnOB6n97uuHIYs+b84Io5oaQi0Mfl1xsPnM215mfSt1dH5j5xPh8u3snO/QV8tyWHi8Yf\n79qv3Z7KZZMG8NLcTQBMu+wkVB/fK8xfDOnm8XrKuH6u0v2VkxVzFuqA8Qw/oTPrth92vb5myolc\nM+VEj8/khT9N5JYnl7he9+vTiX59OvH+0t0A3HP1L8gvLmfSiJ7M/ngTU8/oz8OvrCbn6DEum9Sf\nE3p1IC0lgb++tIqKSuM4mzC8B1edN8jjKnBAz/bEx8WyZc9RAGb/5Syuun9+wNgvntCPuUt38dxd\nE7h15lKPeXZ7KtdfMJiiY1WMPDGDf3640TVvxKAu3Hf9SGJjfE9B7z02hdpaBxf+aZ5r2pxHzuXK\n//ON46N/TOGSez9zvf7g8SlUVtXw3eaDPPvuj67p54/ty6fL9wBGoUPvy2dov07YbDa6ZvhWv52s\nuvj97TVFY06T3p/O/cACIA+jtH5pCOv4yM8/1ohQDHZ7Krm5xZ7bK6jb3sM3nMrseVs5kFtCt07J\nlJdWUF5qHKCv3DsRm1uxZdLJ3VjYLY3/fb0TgNzcYmxuy/328cWAUQLIKypn3spMACYP647NZmNY\nv450SGtLu+QEcDgoLizDPTL3L/CUfh1Z/P1+enROYcd+4w61Wy8dSm5uMeeP7kVhcTlrthkH4eN/\nGE1K23gS28T6vFenf905HoDPvs2k1uFg/up9pCUnMGVUT3458jhWbc5h9ZYcfjXxeL7flsPU8ccz\nYmAXftiRy+vztwNw2cTjeX+JcQDbgNOGdvPYX1KsjRO6p5GbW0xcrI3qGgeThvfg12f2x2azUVNR\nN6JMoDjdnT44g427jjCgRzs6t2vrmn5y3w7k5hYzqHc6WzPzAchIb+vaZlVltc9+3D/bI0dKiIuN\n4TeTFeOGdOWJOesB43eW66eUdOWk/gHjHdyzHVNO682gXunk5hYHHTWnXUoChSWVXDyuL3OXGQd3\ntdtncuRIMTE2G1ecpTh8uIiLT+9Np/aJHvse1q8jl0/sx8hBXUhPSfCJa9gAO7m5xTx4/QiOFpVz\nUr9OlBSV4Sx2pJgl0cQ2sZRVGIl1YK90tv1kfI7TLxrs+h3369HOZ/v/vnsC8bGeh6z3Mn06J3NC\njzSOlZRz9ZlGNdXtU4eyfONBzjylu+tK9i9XD+eBV9dw9eQBnDGsB8dKyimurbthZ8avhzF32R5X\ncq84VsFtU4fy7AcbSU9tQ35x3Yn0xD4dOH9UT84f1ROANvGxVFTVMGlYDy4/o58rxgd/N9on3uqq\nGvKO+i+Y+VNWWsEtlwzh01WZ9OmaxtL1WQzp25G8vFKP5YrMPHNSH8/Cjvvvs7iojG7t23LkiO/+\nrzl3IG/O30avTkkhHS/+BDophJLcszFK3U7dgIPOF1rrN5x/K6W+AIbUt04kpSbFU3ysirTkukvv\nHvYU1xm7k1sCATwSu9PkEccx8ZRuxMXGBFxu3EndWLYh22d+n66h1z45q1Mqq2p456sdnDGsB907\nGZe5KYnx/OHCwazZZhyEyW3jSKrnktU5aMml443L5zGDu9I+pY0rvjFDujJmSFcAXrxjvMd7ySsq\np/9x7RnYK52fcopJSYx3bSeQ2NgYqmtqiI2x+f0cQzFc2XnsplF0Tk+kvKKaH3fmcoFZzQJG9cq+\nQyUUlFQwpG9dCfaSccdTUFLJFRP7eWzvqrMGsDur0PV9t0mI5YRe6Y2Kzclms3HJuL6u187L+OED\n7PTr0Y53F+/ixD4dOHdkT/r3aMeO/YWonu05nHfMp33B5rVdf6OIxcXGcM7IngHjudpsv+nZJZWe\nXXwP7FMHdubg0VJGn5jBff/5DoBfKLsruYPxOb29aAe/P/9En/Wdifn803rz6apMjyqb/j3asfNA\nod9qyM7pST6/meM6p/hUFcXGxPDAdSNcx6j3T+ekfp14+d6JxLgVpmbdMobUJM/qtBgzBFuM/2rR\nC0/vwycr9pr7qP/3ee+vT+E
f76x3vR42wM6wAXYcDgfnjexJx3ZG1dmU03rz2apMn/XbpyRQUFIJ\nwMRTuvPlmn1cNLavz3LuLj9zAOOHZhATgbuYQknuC4GHgH8rpYYB2VrrYgClVDvgPeB8rXUlMB74\nAMgKtE6kPXLjSA4eKaVz+0T+eu0vXD8Ah1nBHepHGB+giuiZP55OeZVRGho1qAtbM/M4M0DdcCCz\nbhnD0aIK12V0Qnws15070O+yT007jdyCMpLbNrxOrpt5ogiF+4/wDxcObvC+3M364+kcKw9tTEib\nzUZGByPBJbWNZ8bVw33m98pIpReeSSw9tY1HPb7TpOE9mDQ8cDuJs5ljwHHtXVdLxn5CChcwktPL\n90wkJsZGrcNB145JDDiuvav67USzFHfDlLqGw+S2cZSWVzf6JAhGsl2z/TBpScF7WcTGxHDJOM8k\nm9ExmYtO7+M60U0a3oMzzKtNpweuG0FVdV2p2l+T0IyrhlFd4yDGT/VGQ9TXqO5Mdi/eMY4Ymy3o\naGu2AEe1Z3KvP6auAY4Xm81Gp/Z1nTXGDMngs1WZpHldAT41bYwrzzh/I6F835FI7BBCctdar1JK\nrVNKrQJqgelKqeuAQq31XLO0vlopVQasBz7QWju814lI9H6kJSWQ1tP48fft5laKNn+pTTm4wKib\ndW41IT62UYmwXUob2qWE1ue7Q1pbOqS1rX/BVqRdcoJRLdWKnNCzPdv3FWA3D9IZVw2j1uHgxn8s\nqWdN/5zJLcZmY+jxnepd/qnpYyg3654b6+Jxfbl4XPCSoLe7f3Uyq7ccQplXZe68jwXvhDvh5G6s\n2nyQ6VNP9lgnPi68yejUgV2YtzLTdUXiLlgvoVsvHcqcr3Zy9qn1F65COe7TkhK44/KT6Noh+JjM\nXdKTuOfKU3wKT8Zvom4/Tc01TRVSnbvWeobXpA1u854Fng1hnRaVap5lU/w0YorGaZeUwOHKMtpa\nYAzbu351MqVl1R7VdZEqMfnTJj6WNvHN/zkN6t2BQb19G79D0SGtLU9NG+O3TSucunVKdlXDNITq\nmc6DXr2ZAjl7RGhX1+5Vf8E0taqvOfxs7lC9/ryBTBzWncsmBq9HFqG7/fKTmHByN84+NXD9cGsR\nGxPjkdidBvc1El93s8uqaBmRPtF2t4deRRktWvbOiGbUIa0t10xWLR1GVMnokMS157TcTSDhcOul\nQ8kvrnBV14jo8rspg9iVVehqD/k5+fm9YyHcxMXGSGKPYqMHZzB6cEb9C0ahn021jBBC/JxIchdC\niCgkyV0IIaKQJHchhIhCktyFECIKSXIXQogoJMldCCGikCR3IYSIQjZHJIbdFkII0aKk5C6EEFFI\nkrsQQkQhSe5CCBGFJLkLIUQUkuQuhBBRSJK7EEJEIUnuQggRhSw/WIdSahYwCmMI7Nu01mtbMJYn\ngLEYn+vfgbXAm0AscBC4RmtdoZS6CrgdY/Dw2VrrV5RS8cDrQC+gBrhea70nwvEmApuBR4CvW2us\nZgz3ANXA/cDG1hirUioFeANIB9oADwE5wL8wfp8btdY3m8v+CbjMnP6Q1voLpVQ74B2gHVAC/Fpr\nnReBOAcDnwCztNYvKKWOo4mfp1LqJH/vM0KxvgbEA1XA1VrrnNYQq7943aafDSzQWtvM1xGP19Il\nd6XUeKC/1no0cAPwXAvGMhEYbMZyDvAM8DDwotZ6LLAL+K1SKhkjQZ0JTADuUEp1AH4NFGitTwce\nxTg5RNpfAWfyaJWxKqU6Ag8ApwNTgAtba6zAdYDWWk8EpmIMHP8MRqFjDNBOKXWuUqoP8Cu39/S0\nUioW42Bfasb6EXBvuAM0P6fnMU7mTuH4PH3eZ4Ri/RtGMhwPzAXubA2xBokXpVRb4M8YJ06aK15L\nJ3dgEvAxgNZ6G5CulEproViWYZTEAAqAZIwvbp457VOML3MksFZrXai1LgNWAmMw3stcc9mvzGkR\no5Q6ARgEfG5Oaq2xngl8pbUu1lof1Frf1IpjPQJ0NP9Oxzhx9nG7mnTGOhGYr7Wu1FrnAj9hfBfu\nsTqXDbcK4Dwg223aBJrweSqlEvD/PiMR6zTgQ/PvXIzPuzXEGihegL8ALwKV5utmidfqyT0D4wt2\nyjWnNTutdY3WutR8eQPwBZCsta4wpx0GuuIbs890rXUt4DC/2EiZCdzp9rq1xtobSFJKzVNKLVdK\nTWqtsWqt/wf0VErtwjjZ3w3kB4spyHTntHDHWG0mFHdN+jzNaf7eZ9hj1VqXaq1rzCud6RjVWC0e\na6B4lVIDgJO01u+7TW6WeK2e3L3ZWjoApdSFGMn9Fq9ZgWJr6PQmU0pdC3yrtd7bwH03e6zmtjsC\nl2BUe7zmtb9WE6tS6mpgn9a6H3AG8FYTYmqp33I4Ps+Ixm4m9jeBxVrrr/0s0mpiBWbhWYjyJyLx\nWj25Z+NZUu+GWa/VEsxGk/uAc7XWhUCJ2WgJ0B0jXu+YfaabDSs2rXUlkfFL4EKl1GrgRuD/WnGs\nh4BVZqloN1AMFLfSWMcAXwJorTcAiUCnYDEFme6c1hya9N1jHHMd/SwbKa8BO7XWD5mvW2WsSqnu\nwAnA2+ax1lUp9U1zxWv15L4Qo+EKpdQwIFtrXdwSgZg9HZ4Eprj1cPgKuNT8+1JgAfAdMEIp1d7s\nXTEGWI7xXpx19ucDSyIVq9b6Cq31CK31KOBljN4yrTJWc19nKKVizMbVlFYc6y6M+lSUUr0wTkTb\nlFKnm/MvMWNdDPxSKZWglOqGccBu9YrV+b6aQ5M+T611FbDdz/sMO7OXSaXW+gG3ya0yVq11ltb6\neK31KPNYO2g2BDdLvJZ/5K9S6nFgHEaXoulmiakl4rgJeBDY4Tb5NxjJsy1Go9n1WusqpdRU4E8Y\ndWrPa63fNi81Xwb6YzTMXKe13t8McT8IZGKUON9ojbEqpX6PUdUFRm+Jta0xVvNAfRXogtEd9v8w\nukL+G6Mg9Z3W+k5z2T8CV5mx/lVr/bW5/lsYJbUCjG5+hWGOcThGe0tvjK6EWWYcr9OEz1MpNcjf\n+4xArJ2BcqDIXGyr1npaS8caJN5LnIU9pVSm1rq3+XfE47V8chdCCOHL6tUyQggh/JDkLoQQUUiS\nuxBCRCFJ7kIIEYUkuQshRBSS5C6EEFFIkrsQQkSh/wecT9/Avx2ZGwAAAABJRU5ErkJggg==\n", 644 | "text/plain": [ 645 | "
" 646 | ] 647 | }, 648 | "metadata": { 649 | "tags": [] 650 | } 651 | }, 652 | { 653 | "output_type": "stream", 654 | "text": [ 655 | "Model took 7.3842 mins (0.1231 hrs) to finish training with best train accuracy of 89.0625%\n" 656 | ], 657 | "name": "stdout" 658 | } 659 | ] 660 | }, 661 | { 662 | "metadata": { 663 | "id": "bUXii0BC2nhZ", 664 | "colab_type": "text" 665 | }, 666 | "cell_type": "markdown", 667 | "source": [ 668 | "## 6. Test" 669 | ] 670 | }, 671 | { 672 | "metadata": { 673 | "id": "q3Qlsqzo2nha", 674 | "colab_type": "code", 675 | "outputId": "6dffc841-769c-40b1-e470-0efcde353ce1", 676 | "colab": { 677 | "base_uri": "https://localhost:8080/", 678 | "height": 402 679 | } 680 | }, 681 | "cell_type": "code", 682 | "source": [ 683 | "# Testing doesn't require the use of gradients since weights aren't being updated\n", 684 | "model.eval()\n", 685 | "with torch.no_grad():\n", 686 | " correct = 0\n", 687 | " total = 0\n", 688 | " \n", 689 | " for img, label in mnist_test_dataloader:\n", 690 | " img = torch.squeeze(img).to(device)\n", 691 | " label = label.to(device)\n", 692 | " \n", 693 | " # Forward\n", 694 | " out = model(img)\n", 695 | " \n", 696 | " # Test\n", 697 | " _, predicted = torch.max(out.data, 1)\n", 698 | " total += label.size(0)\n", 699 | " correct += (predicted == label).sum().item()\n", 700 | "\n", 701 | " # Accuracy\n", 702 | " print('Test Accuracy: {:.4f} %'.format(100 * correct / total)) \n", 703 | "\n", 704 | " # Show 4 test images\n", 705 | " fig, axes = plt.subplots(nrows=2, ncols=2)\n", 706 | " img_cpu = img.cpu()\n", 707 | " label_cpu = label.cpu()\n", 708 | " print(\"Shapes: img {}, label {}, predicted {}\".format(img_cpu.size(), label_cpu.size(), predicted.size()))\n", 709 | " print(\"Label: {}\".format(label))\n", 710 | " print(\"Predicted: {}\".format(predicted))\n", 711 | " for i, ax in enumerate(axes.flat):\n", 712 | " ax.imshow(img_cpu[i]) \n", 713 | " ax.set_title('Target: {} - Prediction: {}'.format(label_cpu[i], predicted[i]))\n", 714 | " ax.set_xticks([])\n", 715 | " ax.set_yticks([])\n", 716 | " plt.tight_layout()" 717 | ], 718 | "execution_count": 141, 719 | "outputs": [ 720 | { 721 | "output_type": "stream", 722 | "text": [ 723 | "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:34: UserWarning: Implicit dimension choice for log_softmax has been deprecated. 
Change the call to include dim=X as an argument.\n" 724 | ], 725 | "name": "stderr" 726 | }, 727 | { 728 | "output_type": "stream", 729 | "text": [ 730 | "Test Accuracy: 82.5300 %\n", 731 | "Shapes: img torch.Size([16, 28, 28]), label torch.Size([16]), predicted torch.Size([16])\n", 732 | "Label: tensor([2, 0, 0, 6, 2, 4, 9, 6, 2, 2, 7, 7, 0, 4, 8, 9], device='cuda:0')\n", 733 | "Predicted: tensor([2, 0, 0, 6, 2, 1, 9, 8, 2, 2, 7, 7, 0, 4, 5, 9], device='cuda:0')\n" 734 | ], 735 | "name": "stdout" 736 | }, 737 | { 738 | "output_type": "display_data", 739 | "data": { 740 | "image/png": "[base64-encoded PNG omitted: 2x2 grid of sample test digits, each titled with its target and predicted label]\n", 741 | "text/plain": [ 742 | "
" 743 | ] 744 | }, 745 | "metadata": { 746 | "tags": [] 747 | } 748 | } 749 | ] 750 | }, 751 | { 752 | "metadata": { 753 | "id": "LYYZW0YSYH2f", 754 | "colab_type": "text" 755 | }, 756 | "cell_type": "markdown", 757 | "source": [ 758 | "## 7. Summary\n", 759 | "*Steps per epoch on 128 batch_size = 469, using GPU*\n", 760 | "\n", 761 | "\n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | "
| Model | Epoch | Train Acc (%) | Test Acc (%) | Training Time (min) |
|-------|-------|---------------|--------------|---------------------|
| RNN   | 10    | 81.25         | 75.06        | 2.23                |
| RNN   | 30    | 83.59         | 75.14        | 6.90                |
| LSTM  | 10    | 94.53         | 90.04        | 2.55                |
| LSTM  | 30    | 95.31         | 90.10        | 7.62                |
| GRU   | 10    | 97.66         | 92.45        | 2.47                |
| GRU   | 30    | 89.06         | 82.53        | 7.38                |
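
*To reuse one of the trained models summarized above without retraining, the checkpoint written in section 5 (`res_dir + 'model.ckpt'`) can be restored. A minimal sketch, assuming the `RNN` class and the `params` values match those used for the training run:*

```python
# Minimal sketch (assumption: params and the RNN definition match the training run).
# Restores the weights that section 5 saved to res_dir + 'model.ckpt'.
model = RNN()
model.load_state_dict(torch.load(res_dir + 'model.ckpt', map_location=device))
model = model.to(device)
model.eval()  # disable dropout layers before running inference
```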
" 784 | ] 785 | } 786 | ] 787 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Info 2 | ML is a repository that contains supervised learning algorithm templates. The algorithms are divided by regression and classification, as shown below in "List of Algorithms". 3 | 4 | # How To Use 5 | Each algorithm is divided in different sections in Jupyter Notebook with appropriate instructions. None of the sections needs to be adjusted but the data input and input/target variables definition. Note that all algorithms use one dataset as an example. Therefore, in some algorithms, the quality of the input and target variables of being suitable or proper in the circumstances might be very low. 6 | 7 | # Author 8 | The author is responsible for the content and quality of the code. Please refer to The Learning Machine (thelearningmachine.ai) for any remarks. 9 | 10 | # List of Algorithms 11 | ## Classification 12 | - [Neural Network Classifier](./Classification/neural_networks_classifier.ipynb) 13 | - [Linear Support Vector Machines (LSVM)](./Classification/linear_svm.ipynb) 14 | - [Logistic Regression](./Classification/logistic_regression.ipynb) 15 | - [MultiLayer Perceptron (MLP)](./Classification/mlp_pytorch.ipynb) 16 | - [Kernelized Support Vector Machines (SVM)](./Classification/kernelized_svm.ipynb) 17 | - [Decision Trees](./Classification/decision_trees.ipynb) 18 | - [Naive Bayes Classifiers](./Classification/naive_bayes_classifiers.ipynb) 19 | - [Random Forests](./Classification/random_forests.ipynb) 20 | - [Recurrent Neural Network (RNN)](./Classification/rnn_pytorch.ipynb): Applied to MNIST Classification Task 21 | 22 | ## Regression 23 | - [Neural Network Regressor](./Regression/neural_networks_regressor.ipynb) 24 | - [Ridge Regression](./Regression/ridge_regression.ipynb) 25 | - [Lasso Regression](./Regression/lasso_regression.ipynb) 26 | - [Polynomial Regression](./Regression/linear_vs_%20polynomial_regressions.ipynb) 27 | - [Linear Regression](./Regression/linear_regression.ipynb) 28 | -------------------------------------------------------------------------------- /Regression/lasso_regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Supervised Learning Algorithms: Lasso Regression" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further instructions). None of the other sections needs to be adjusted. As a data input example, .csv file from IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. 
Libraries" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "*Run to import the required libraries.*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib notebook\n", 38 | "import numpy as np\n", 39 | "import pandas as pd\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "from sklearn.preprocessing import MinMaxScaler\n", 42 | "from sklearn.model_selection import train_test_split" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "## 2. Data Input and Variables" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "*Define the data input as well as the input (X) and target (y) variables and run the code. Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "### Data Input\n", 66 | "# df = \n", 67 | "\n", 68 | "### Defining Variables \n", 69 | "# X = \n", 70 | "# y = \n", 71 | "\n", 72 | "### Data Input Example \n", 73 | "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n", 74 | "\n", 75 | "X = df[['horsepower']]\n", 76 | "y = df['price']" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## 3. The Model" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "*Run to build the model.*" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 8, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "lasso regression linear model intercept: 3984.172831038173\n", 103 | "lasso regression linear model coeff:[33549.01456544]\n", 104 | "\n", 105 | "Non-zero features: 1\n", 106 | "\n", 107 | "R-squared score (training): 0.623\n", 108 | "R-squared score (test): 0.666\n", 109 | "\n", 110 | "Features with non-zero weight (sorted by absolute magnitude):\n", 111 | "\thorsepower, 33549.015\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "from sklearn.linear_model import Lasso\n", 117 | "\n", 118 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)\n", 119 | "scaler = MinMaxScaler()\n", 120 | "\n", 121 | "X_train_scaled = scaler.fit_transform(X_train)\n", 122 | "X_test_scaled = scaler.transform(X_test)\n", 123 | "\n", 124 | "linlasso = Lasso(alpha=2.0, max_iter = 10000).fit(X_train_scaled, y_train)\n", 125 | "\n", 126 | "### Intercept & coefficient, # of non-zero features & weights, R-squared for training & test data set\n", 127 | "print('lasso regression linear model intercept: {}'\n", 128 | " .format(linlasso.intercept_))\n", 129 | "print('lasso regression linear model coeff:{}'\n", 130 | " .format(linlasso.coef_))\n", 131 | "print('\\nNon-zero features: {}'\n", 132 | " .format(np.sum(linlasso.coef_ != 0)))\n", 133 | "print('\\nR-squared score (training): {:.3f}'\n", 134 | " .format(linlasso.score(X_train_scaled, y_train)))\n", 135 | "print('R-squared score (test): {:.3f}\\n'\n", 136 | " .format(linlasso.score(X_test_scaled, y_test)))\n", 137 | "print('Features with non-zero weight (sorted by absolute magnitude):')\n", 138 | "\n", 139 | "for e in sorted (list(zip(list(X), linlasso.coef_)),\n", 140 | " key = lambda e: -abs(e[1])):\n", 141 | " if e[1] != 
0:\n", 142 | " print('\\t{}, {:.3f}'.format(e[0], e[1]))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "### 3.1. Regularization parameter alpha on R-squared" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "*Run to check how alpha affects the model score.*" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 5, 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "Lasso regression: effect of alpha regularization\n", 169 | "parameter on number of features kept in final model\n", 170 | "\n", 171 | "Alpha = 0.50\n", 172 | "Features kept: 1, r-squared training: 0.62, r-squared test: 0.67\n", 173 | "\n", 174 | "Alpha = 1.00\n", 175 | "Features kept: 1, r-squared training: 0.62, r-squared test: 0.67\n", 176 | "\n", 177 | "Alpha = 2.00\n", 178 | "Features kept: 1, r-squared training: 0.62, r-squared test: 0.67\n", 179 | "\n", 180 | "Alpha = 3.00\n", 181 | "Features kept: 1, r-squared training: 0.62, r-squared test: 0.67\n", 182 | "\n", 183 | "Alpha = 5.00\n", 184 | "Features kept: 1, r-squared training: 0.62, r-squared test: 0.66\n", 185 | "\n", 186 | "Alpha = 10.00\n", 187 | "Features kept: 1, r-squared training: 0.62, r-squared test: 0.66\n", 188 | "\n", 189 | "Alpha = 20.00\n", 190 | "Features kept: 1, r-squared training: 0.62, r-squared test: 0.66\n", 191 | "\n", 192 | "Alpha = 50.00\n", 193 | "Features kept: 1, r-squared training: 0.62, r-squared test: 0.65\n", 194 | "\n" 195 | ] 196 | } 197 | ], 198 | "source": [ 199 | "print('Lasso regression: effect of alpha regularization\\n\\\n", 200 | "parameter on number of features kept in final model\\n')\n", 201 | "\n", 202 | "for alpha in [0.5, 1, 2, 3, 5, 10, 20, 50]:\n", 203 | " linlasso = Lasso(alpha, max_iter = 10000).fit(X_train_scaled, y_train)\n", 204 | " r2_train = linlasso.score(X_train_scaled, y_train)\n", 205 | " r2_test = linlasso.score(X_test_scaled, y_test)\n", 206 | " \n", 207 | " print('Alpha = {:.2f}\\nFeatures kept: {}, r-squared training: {:.2f}, \\\n", 208 | "r-squared test: {:.2f}\\n'\n", 209 | " .format(alpha, np.sum(linlasso.coef_ != 0), r2_train, r2_test))" 210 | ] 211 | } 212 | ], 213 | "metadata": { 214 | "kernelspec": { 215 | "display_name": "Python 3", 216 | "language": "python", 217 | "name": "python3" 218 | }, 219 | "language_info": { 220 | "codemirror_mode": { 221 | "name": "ipython", 222 | "version": 3 223 | }, 224 | "file_extension": ".py", 225 | "mimetype": "text/x-python", 226 | "name": "python", 227 | "nbconvert_exporter": "python", 228 | "pygments_lexer": "ipython3", 229 | "version": "3.6.4" 230 | } 231 | }, 232 | "nbformat": 4, 233 | "nbformat_minor": 2 234 | } 235 | -------------------------------------------------------------------------------- /Regression/linear_vs_ polynomial_regressions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Supervised Learning Algorithms: Linear vs Polynomial Regression" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further instructions). None of the other sections needs to be adjusted. 
As a data input example, .csv file from IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. Libraries" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "*Run to import the required libraries.*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib notebook\n", 38 | "import numpy as np\n", 39 | "import pandas as pd\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "from sklearn.preprocessing import StandardScaler\n", 42 | "from sklearn.model_selection import train_test_split" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "## 2. Data Input and Variables" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "*Define the data input as well as the input (X) and target (y) variables and run the code. Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "### Data Input\n", 66 | "# df = \n", 67 | "\n", 68 | "### Defining Variables \n", 69 | "# X = \n", 70 | "# y = \n", 71 | "\n", 72 | "### Data Input Example \n", 73 | "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n", 74 | "\n", 75 | "X = df[['horsepower']]\n", 76 | "y = df['price']" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## 3. The Models" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "### 3.1. Linear Regression" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "*Run to build the Linear Regression model.*" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 4, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "linear model coeff (w): [157.09522969]\n", 110 | "linear model intercept (b): -3574.121\n", 111 | "R-squared score (training): 0.623\n", 112 | "R-squared score (test): 0.666\n" 113 | ] 114 | } 115 | ], 116 | "source": [ 117 | "from sklearn.linear_model import LinearRegression\n", 118 | "\n", 119 | "# train_test_split\n", 120 | "X_train, X_test, y_train, y_test = train_test_split(X, y,\n", 121 | " random_state = 0)\n", 122 | "\n", 123 | "# Linear regression def\n", 124 | "linreg = LinearRegression().fit(X_train, y_train)\n", 125 | "\n", 126 | "### intercept & coefficient, R-squared for training & test data set\n", 127 | "print('linear model coeff (w): {}'\n", 128 | " .format(linreg.coef_))\n", 129 | "print('linear model intercept (b): {:.3f}'\n", 130 | " .format(linreg.intercept_))\n", 131 | "print('R-squared score (training): {:.3f}'\n", 132 | " .format(linreg.score(X_train, y_train)))\n", 133 | "print('R-squared score (test): {:.3f}'\n", 134 | " .format(linreg.score(X_test, y_test)))" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "### 3.2. 
Polynomial Regression" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "*Run to build the Polynomial Regression model.*" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 6, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "(poly deg 2) linear model coeff (w):\n", 161 | "[0.00000000e+00 1.40923904e+02 6.48999886e-02]\n", 162 | "(poly deg 2) linear model intercept (b): -2683.607\n", 163 | "(poly deg 2) R-squared score (training): 0.623\n", 164 | "(poly deg 2) R-squared score (test): 0.670\n", 165 | "\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "from sklearn.preprocessing import PolynomialFeatures\n", 171 | "\n", 172 | "'''\n", 173 | "Now we transform the original input data to add\n", 174 | "polynomial features up to degree 2\n", 175 | "\n", 176 | "'''\n", 177 | "\n", 178 | "poly = PolynomialFeatures(degree=2)\n", 179 | "X_poly = poly.fit_transform(X)\n", 180 | "\n", 181 | "# train_test_split\n", 182 | "X_train, X_test, y_train, y_test = train_test_split(X_poly, y,\n", 183 | " random_state = 0)\n", 184 | "# Polynomial regression def\n", 185 | "linreg = LinearRegression().fit(X_train, y_train)\n", 186 | "\n", 187 | "### intercept & coefficient, R-squared for training & test data set\n", 188 | "print('(poly deg 2) linear model coeff (w):\\n{}'\n", 189 | " .format(linreg.coef_))\n", 190 | "print('(poly deg 2) linear model intercept (b): {:.3f}'\n", 191 | " .format(linreg.intercept_))\n", 192 | "print('(poly deg 2) R-squared score (training): {:.3f}'\n", 193 | " .format(linreg.score(X_train, y_train)))\n", 194 | "print('(poly deg 2) R-squared score (test): {:.3f}\\n'\n", 195 | " .format(linreg.score(X_test, y_test)))" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "### 3.3. Polynomial Regression with Regularization" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "Run to build the Polynomial Regression model with a regularization penalty." 
210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 7, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "(poly deg 2 + ridge) linear model coeff (w):\n", 222 | "[0.00000000e+00 1.40908895e+02 6.49573697e-02]\n", 223 | "(poly deg 2 + ridge) linear model intercept (b): -2682.747\n", 224 | "(poly deg 2 + ridge) R-squared score (training): 0.623\n", 225 | "(poly deg 2 + ridge) R-squared score (test): 0.670\n" 226 | ] 227 | } 228 | ], 229 | "source": [ 230 | "from sklearn.linear_model import Ridge\n", 231 | "\n", 232 | "'''\n", 233 | "Addition of many polynomial features often leads to\n", 234 | "overfitting, so we often use polynomial features in combination\n", 235 | "with regression that has a regularization penalty, like ridge\n", 236 | "regression.\n", 237 | "'''\n", 238 | "\n", 239 | "X_train, X_test, y_train, y_test = train_test_split(X_poly, y,\n", 240 | " random_state = 0)\n", 241 | "linreg = Ridge().fit(X_train, y_train)\n", 242 | "\n", 243 | "### intercept & coefficient, R-squared for training & test data set\n", 244 | "print('(poly deg 2 + ridge) linear model coeff (w):\\n{}'\n", 245 | " .format(linreg.coef_))\n", 246 | "print('(poly deg 2 + ridge) linear model intercept (b): {:.3f}'\n", 247 | " .format(linreg.intercept_))\n", 248 | "print('(poly deg 2 + ridge) R-squared score (training): {:.3f}'\n", 249 | " .format(linreg.score(X_train, y_train)))\n", 250 | "print('(poly deg 2 + ridge) R-squared score (test): {:.3f}'\n", 251 | " .format(linreg.score(X_test, y_test)))" 252 | ] 253 | } 254 | ], 255 | "metadata": { 256 | "kernelspec": { 257 | "display_name": "Python 3", 258 | "language": "python", 259 | "name": "python3" 260 | }, 261 | "language_info": { 262 | "codemirror_mode": { 263 | "name": "ipython", 264 | "version": 3 265 | }, 266 | "file_extension": ".py", 267 | "mimetype": "text/x-python", 268 | "name": "python", 269 | "nbconvert_exporter": "python", 270 | "pygments_lexer": "ipython3", 271 | "version": "3.6.4" 272 | } 273 | }, 274 | "nbformat": 4, 275 | "nbformat_minor": 2 276 | } 277 | -------------------------------------------------------------------------------- /Regression/neural_networks_regressor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Supervised Learning Algorithms: Neural Networks - Regression" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further instructions). None of the other sections needs to be adjusted. As a data input example, .csv file from IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. 
Libraries"
22 |    ]
23 |   },
24 |   {
25 |    "cell_type": "markdown",
26 |    "metadata": {},
27 |    "source": [
28 |     "*Run to import the required libraries.*"
29 |    ]
30 |   },
31 |   {
32 |    "cell_type": "code",
33 |    "execution_count": 1,
34 |    "metadata": {},
35 |    "outputs": [],
36 |    "source": [
37 |     "%matplotlib notebook\n",
38 |     "import numpy as np\n",
39 |     "import pandas as pd\n",
40 |     "import matplotlib.pyplot as plt\n",
41 |     "from sklearn.model_selection import train_test_split"
42 |    ]
43 |   },
44 |   {
45 |    "cell_type": "markdown",
46 |    "metadata": {},
47 |    "source": [
48 |     "## 2. Data Input and Variables"
49 |    ]
50 |   },
51 |   {
52 |    "cell_type": "markdown",
53 |    "metadata": {},
54 |    "source": [
55 |     "*Define the data input as well as the input (X) and target (y) variables and run the code. Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*"
56 |    ]
57 |   },
58 |   {
59 |    "cell_type": "code",
60 |    "execution_count": 3,
61 |    "metadata": {},
62 |    "outputs": [],
63 |    "source": [
64 |     "### Data Input\n",
65 |     "# df = \n",
66 |     "\n",
67 |     "### Defining Variables \n",
68 |     "# X = \n",
69 |     "# y = \n",
70 |     "\n",
71 |     "### Data Input Example \n",
72 |     "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n",
73 |     "\n",
74 |     "X = df[['horsepower', 'normalized-losses']]\n",
75 |     "y = df['price']"
76 |    ]
77 |   },
78 |   {
79 |    "cell_type": "markdown",
80 |    "metadata": {},
81 |    "source": [
82 |     "## 3. The Model"
83 |    ]
84 |   },
85 |   {
86 |    "cell_type": "code",
87 |    "execution_count": 9,
88 |    "metadata": {},
89 |    "outputs": [
90 |     {
91 |      "name": "stdout",
92 |      "output_type": "stream",
93 |      "text": [
94 |       "R-squared score of NN regressor on training set: 0.65\n",
95 |       "R-squared score of NN regressor on test set: 0.68\n"
96 |      ]
97 |     }
98 |    ],
99 |    "source": [
100 |     "from sklearn.neural_network import MLPRegressor\n",
101 |     "from sklearn.preprocessing import MinMaxScaler\n",
102 |     "\n",
103 |     "# feature normalization (min-max scaling)\n",
104 |     "scaler = MinMaxScaler()\n",
105 |     "\n",
106 |     "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)\n",
107 |     "X_train_scaled = scaler.fit_transform(X_train)\n",
108 |     "X_test_scaled = scaler.transform(X_test)\n",
109 |     "\n",
110 |     "# model definition and fitting (score() of a regressor returns R-squared, not accuracy)\n",
111 |     "clf = MLPRegressor(hidden_layer_sizes = [100, 100], alpha = 5.0,\n",
112 |     "                   random_state = 0, solver = 'lbfgs').fit(X_train_scaled, y_train)\n",
113 |     "\n",
114 |     "print('R-squared score of NN regressor on training set: {:.2f}'\n",
115 |     "      .format(clf.score(X_train_scaled, y_train)))\n",
116 |     "print('R-squared score of NN regressor on test set: {:.2f}'\n",
117 |     "      .format(clf.score(X_test_scaled, y_test)))"
118 |    ]
119 |   },
120 |   {
121 |    "cell_type": "markdown",
122 |    "metadata": {},
123 |    "source": [
124 |     "### 3.1. 
R-squared score with different activation functions and regularization parameter alpha"
125 |    ]
126 |   },
127 |   {
128 |    "cell_type": "code",
129 |    "execution_count": null,
130 |    "metadata": {},
131 |    "outputs": [],
132 |    "source": [
133 |     "# train/test split\n",
134 |     "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)\n",
135 |     "\n",
136 |     "# feature normalization (min-max scaling)\n",
137 |     "scaler = MinMaxScaler()\n",
138 |     "X_train_scaled = scaler.fit_transform(X_train)\n",
139 |     "X_test_scaled = scaler.transform(X_test)\n",
140 |     "\n",
141 |     "# R-squared score for each combination of activation function and regularization parameter alpha\n",
142 |     "for thisactivation in ['tanh', 'relu']:\n",
143 |     "    for thisalpha in [0.0001, 1.0, 100]:\n",
144 |     "        mlpreg = MLPRegressor(hidden_layer_sizes = [100,100],\n",
145 |     "                              activation = thisactivation,\n",
146 |     "                              alpha = thisalpha,\n",
147 |     "                              solver = 'lbfgs').fit(X_train_scaled, y_train)\n",
148 |     "        print('R-squared score of NN regressor with activation function = {} and alpha = {} on training set: {:.2f}'.format(thisactivation, thisalpha, mlpreg.score(X_train_scaled, y_train)))\n",
149 |     "        print('R-squared score of NN regressor on test set: {:.2f}\\n'.format(mlpreg.score(X_test_scaled, y_test)))"
150 |    ]
151 |   }
152 |  ],
153 |  "metadata": {
154 |   "kernelspec": {
155 |    "display_name": "Python 3",
156 |    "language": "python",
157 |    "name": "python3"
158 |   },
159 |   "language_info": {
160 |    "codemirror_mode": {
161 |     "name": "ipython",
162 |     "version": 3
163 |    },
164 |    "file_extension": ".py",
165 |    "mimetype": "text/x-python",
166 |    "name": "python",
167 |    "nbconvert_exporter": "python",
168 |    "pygments_lexer": "ipython3",
169 |    "version": "3.6.4"
170 |   }
171 |  },
172 |  "nbformat": 4,
173 |  "nbformat_minor": 2
174 | }
175 | 
--------------------------------------------------------------------------------
/Regression/ridge_regression.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Supervised Learning Algorithms: Ridge Regression"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "markdown",
12 |    "metadata": {},
13 |    "source": [
14 |     "*In this template, only **data input** and **input/target variables** need to be specified (see \"Data Input & Variables\" section for further 
instructions). None of the other sections needs to be adjusted. As a data input example, .csv file from IBM Box web repository is used.*" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## 1. Libraries" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "*Run to import the required libraries.*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 7, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib notebook\n", 38 | "import numpy as np\n", 39 | "import pandas as pd\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "from sklearn.preprocessing import StandardScaler\n", 42 | "from sklearn.model_selection import train_test_split" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "## 2. Data Input and Variables" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "*Define the data input as well as the input (X) and target (y) variables and run the code. Do not change the data & variable names **['df', 'X', 'y']** as they are used in further sections.*" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 11, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "### Data Input\n", 66 | "# df = \n", 67 | "\n", 68 | "### Defining Variables \n", 69 | "# X = \n", 70 | "# y = \n", 71 | "\n", 72 | "### Data Input Example \n", 73 | "df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')\n", 74 | "\n", 75 | "X = df[['horsepower']]\n", 76 | "y = df['price']" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## 3. The Model" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "*Run to build the model.*" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 19, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Ridge regression linear model intercept: 12555.906666666666\n", 103 | "Ridge regression linear model coeff: [5036.92170125]\n", 104 | "\n", 105 | "R-squared score (training): 0.614\n", 106 | "R-squared score (test): 0.625\n", 107 | "\n", 108 | "Number of non-zero features: 1\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "from sklearn.linear_model import Ridge\n", 114 | "\n", 115 | "# train_test_split\n", 116 | "X_train, X_test, y_train, y_test = train_test_split(X, y,\n", 117 | " random_state = 0)\n", 118 | "# feature normalization\n", 119 | "scaler = StandardScaler()\n", 120 | "\n", 121 | "X_train_scaled = scaler.fit_transform(X_train)\n", 122 | "X_test_scaled = scaler.transform(X_test)\n", 123 | "\n", 124 | "# ridge regression def\n", 125 | "linridge = Ridge(alpha=20.0).fit(X_train_scaled, y_train)\n", 126 | "\n", 127 | "### intercept & coefficient, # of non-zero features & weights, R-squared for training & test data set\n", 128 | "print('Ridge regression linear model intercept: {}'\n", 129 | " .format(linridge.intercept_))\n", 130 | "print('Ridge regression linear model coeff: {}\\n'\n", 131 | " .format(linridge.coef_))\n", 132 | "print('R-squared score (training): {:.3f}'\n", 133 | " .format(linridge.score(X_train_scaled, y_train)))\n", 134 | "print('R-squared score (test): {:.3f}\\n'\n", 135 | " .format(linridge.score(X_test_scaled, y_test)))\n", 136 | "print('Number of non-zero features: {}'\n", 137 | " 
.format(np.sum(linridge.coef_ != 0)))" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "### 3.1. Regularization parameter alpha on R-squared" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "*Run to check how alpha affects the model score.*" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 22, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stdout", 161 | "output_type": "stream", 162 | "text": [ 163 | "Ridge regression: effect of alpha regularization parameter\n", 164 | "\n", 165 | "Alpha = 0.00\n", 166 | "num abs(coeff) > 1.0: 1, r-squared training: 0.62, r-squared test: 0.67\n", 167 | "\n", 168 | "Alpha = 1.00\n", 169 | "num abs(coeff) > 1.0: 1, r-squared training: 0.62, r-squared test: 0.66\n", 170 | "\n", 171 | "Alpha = 10.00\n", 172 | "num abs(coeff) > 1.0: 1, r-squared training: 0.62, r-squared test: 0.65\n", 173 | "\n", 174 | "Alpha = 20.00\n", 175 | "num abs(coeff) > 1.0: 1, r-squared training: 0.61, r-squared test: 0.63\n", 176 | "\n", 177 | "Alpha = 50.00\n", 178 | "num abs(coeff) > 1.0: 1, r-squared training: 0.58, r-squared test: 0.56\n", 179 | "\n", 180 | "Alpha = 100.00\n", 181 | "num abs(coeff) > 1.0: 1, r-squared training: 0.52, r-squared test: 0.48\n", 182 | "\n", 183 | "Alpha = 1000.00\n", 184 | "num abs(coeff) > 1.0: 1, r-squared training: 0.15, r-squared test: 0.07\n", 185 | "\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "print('Ridge regression: effect of alpha regularization parameter\\n')\n", 191 | "for this_alpha in [0, 1, 10, 20, 50, 100, 1000]:\n", 192 | " linridge = Ridge(alpha = this_alpha).fit(X_train_scaled, y_train)\n", 193 | " r2_train = linridge.score(X_train_scaled, y_train)\n", 194 | " r2_test = linridge.score(X_test_scaled, y_test)\n", 195 | " num_coeff_bigger = np.sum(abs(linridge.coef_) > 1.0)\n", 196 | " print('Alpha = {:.2f}\\nnum abs(coeff) > 1.0: {}, \\\n", 197 | "r-squared training: {:.2f}, r-squared test: {:.2f}\\n'\n", 198 | " .format(this_alpha, num_coeff_bigger, r2_train, r2_test))" 199 | ] 200 | } 201 | ], 202 | "metadata": { 203 | "kernelspec": { 204 | "display_name": "Python 3", 205 | "language": "python", 206 | "name": "python3" 207 | }, 208 | "language_info": { 209 | "codemirror_mode": { 210 | "name": "ipython", 211 | "version": 3 212 | }, 213 | "file_extension": ".py", 214 | "mimetype": "text/x-python", 215 | "name": "python", 216 | "nbconvert_exporter": "python", 217 | "pygments_lexer": "ipython3", 218 | "version": "3.6.4" 219 | } 220 | }, 221 | "nbformat": 4, 222 | "nbformat_minor": 2 223 | } 224 | --------------------------------------------------------------------------------
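
Appendix: selecting alpha by cross-validation. The ridge and lasso templates above score each alpha against a single train/test split. A cross-validated search over a similar grid is a common alternative; the sketch below is an illustration, not part of any notebook in the repository. It reuses the templates' example CSV and the df/X/y naming convention, and the five-fold cv setting and alpha grid are assumptions rather than values taken from the templates.

import pandas as pd
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Same example data and variables as the regression templates above
df = pd.read_csv('https://ibm.box.com/shared/static/q6iiqb1pd7wo8r3q28jvgsrprzezjqk3.csv')
X = df[['horsepower']]
y = df['price']

# Hold out a test set, then standardize features using training data only
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# RidgeCV picks the alpha with the best cross-validated score on the training data
# (the alpha grid and cv = 5 below are illustrative choices)
alphas = [0.01, 0.1, 1, 10, 20, 50, 100, 1000]
linridge_cv = RidgeCV(alphas = alphas, cv = 5).fit(X_train_scaled, y_train)

print('Best alpha: {}'.format(linridge_cv.alpha_))
print('R-squared score (training): {:.3f}'.format(linridge_cv.score(X_train_scaled, y_train)))
print('R-squared score (test): {:.3f}'.format(linridge_cv.score(X_test_scaled, y_test)))

Fitting the scaler on the training split only keeps test-set statistics out of the alpha selection, mirroring the scaling practice already used in the ridge and lasso templates; the held-out test score is then reported once for the selected alpha.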