├── .gitignore ├── NN_pytorch.ipynb ├── README.md ├── autoencoder.ipynb ├── collab_filtering.ipynb ├── data ├── cancer.csv ├── fashion │ ├── t10k-images-idx3-ubyte.gz │ ├── t10k-labels-idx1-ubyte.gz │ ├── train-images-idx3-ubyte.gz │ └── train-labels-idx1-ubyte.gz ├── housing.csv └── nlp │ ├── KVH.txt │ ├── kvh_trn │ └── trn.txt │ ├── kvh_val │ └── val.txt │ ├── nietzsche.txt │ └── trump.txt ├── knn.ipynb ├── linear_regression.ipynb ├── logistic_regression.ipynb ├── model ├── __init__.py ├── activation_classes.py ├── activations.py ├── gradients.py ├── knn.py ├── linear_model.py ├── metrics.py ├── neural_network.py ├── optimizers.py ├── random_forest.py └── utils.py ├── neural_net_optimizers.ipynb ├── random_forest_classifier.ipynb ├── random_forest_regressor.ipynb ├── rnn-vietnamese.ipynb └── scratch_neural_net.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | */.ipynb_checkpoints/* 3 | *__pycache__/ 4 | *.vscode/ 5 | fastai 6 | 7 | data/large_ds 8 | data/nlp/models 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Machine Learning from scratch! 2 | 3 | Update: Code implementations have been moved to python module. Notebook will only show results and model comparison 4 | 5 | To refresh my knowledge, I will attempt to implement some basic machine learning algorithms from scratch using only python and limited numpy/pandas function. 
6 | My model implementations will be compared to existing models from a popular ML library (sklearn) 7 | - [Linear Regression with weight decay (L2 regularization)](https://github.com/anhquan0412/basic_model_scratch/blob/master/linear_regression.ipynb) 8 | - [Logistic Regression with weight decay](https://github.com/anhquan0412/basic_model_scratch/blob/master/logistic_regression.ipynb) 9 | - Random Forest with Permutation Feature Importances 10 | - [Random Forest Regressor](https://github.com/anhquan0412/basic_model_scratch/blob/master/random_forest_regressor.ipynb) 11 | - [Random Forest Classifier](https://github.com/anhquan0412/basic_model_scratch/blob/master/random_forest_classifier.ipynb) 12 | - [K Nearest Neighbors: supervised and unsupervised](https://github.com/anhquan0412/basic_model_scratch/blob/master/knn.ipynb) 13 | - [Neural network for classification](https://github.com/anhquan0412/basic_model_scratch/blob/master/scratch_neural_net.ipynb) 14 | - Stochastic Gradient Descent 15 | - Multiple hidden layers 16 | - A variety of activation functions + gradients (Sigmoid, Softmax, ReLU ...) customized for each hidden layer 17 | - L2 regularization 18 | - Dropout 19 | - [Dynamic learning rate optimizer](https://github.com/anhquan0412/basic_model_scratch/blob/master/neural_net_optimizers.ipynb) (momentum, RMSProp and Adam) 20 | - TODO: batchnorm 21 | 22 | The following notebooks use PyTorch libraries, so they are not implemented from scratch. 
However, I try not to use any high level Pytorch function 23 | - [Pytorch Neural Network](https://github.com/anhquan0412/basic_model_scratch/blob/master/NN_pytorch.ipynb) with: 24 | - Custom Data Loader 25 | - Data Augmentation on 1 channel image: torchvision vs fastai 26 | - Shallow NN with batchnorm and dropout 27 | - Learning rate finder 28 | - [Auto Encoding](https://github.com/anhquan0412/basic_model_scratch/blob/master/autoencoder.ipynb) 29 | - [Collaborative Filtering](https://github.com/anhquan0412/basic_model_scratch/blob/master/collab_filtering.ipynb) 30 | - [Char RNN in Vietnamese (Fast.ai)](https://github.com/anhquan0412/basic_model_scratch/blob/master/rnn-vietnamese.ipynb) 31 | -------------------------------------------------------------------------------- /data/fashion/t10k-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anhquan0412/basic_model_scratch/fd83305ae69460cf7f2ea0f4c2bbbf633eed652c/data/fashion/t10k-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /data/fashion/t10k-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anhquan0412/basic_model_scratch/fd83305ae69460cf7f2ea0f4c2bbbf633eed652c/data/fashion/t10k-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /data/fashion/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anhquan0412/basic_model_scratch/fd83305ae69460cf7f2ea0f4c2bbbf633eed652c/data/fashion/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /data/fashion/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/anhquan0412/basic_model_scratch/fd83305ae69460cf7f2ea0f4c2bbbf633eed652c/data/fashion/train-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /data/housing.csv: -------------------------------------------------------------------------------- 1 | RM,LSTAT,PTRATIO,MEDV 2 | 6.575,4.98,15.3,504000.0 3 | 6.421,9.14,17.8,453600.0 4 | 7.185,4.03,17.8,728700.0 5 | 6.998,2.94,18.7,701400.0 6 | 7.147,5.33,18.7,760200.0 7 | 6.43,5.21,18.7,602700.0 8 | 6.012,12.43,15.2,480900.0 9 | 6.172,19.15,15.2,569100.0 10 | 5.631,29.93,15.2,346500.0 11 | 6.004,17.1,15.2,396900.0 12 | 6.377,20.45,15.2,315000.0 13 | 6.009,13.27,15.2,396900.0 14 | 5.889,15.71,15.2,455700.0 15 | 5.949,8.26,21.0,428400.0 16 | 6.096,10.26,21.0,382200.0 17 | 5.834,8.47,21.0,417900.0 18 | 5.935,6.58,21.0,485100.0 19 | 5.99,14.67,21.0,367500.0 20 | 5.456,11.69,21.0,424200.0 21 | 5.727,11.28,21.0,382200.0 22 | 5.57,21.02,21.0,285600.0 23 | 5.965,13.83,21.0,411600.0 24 | 6.142,18.72,21.0,319200.0 25 | 5.813,19.88,21.0,304500.0 26 | 5.924,16.3,21.0,327600.0 27 | 5.599,16.51,21.0,291900.0 28 | 5.813,14.81,21.0,348600.0 29 | 6.047,17.28,21.0,310800.0 30 | 6.495,12.8,21.0,386400.0 31 | 6.674,11.98,21.0,441000.0 32 | 5.713,22.6,21.0,266700.0 33 | 6.072,13.04,21.0,304500.0 34 | 5.95,27.71,21.0,277200.0 35 | 5.701,18.35,21.0,275100.0 36 | 6.096,20.34,21.0,283500.0 37 | 5.933,9.68,19.2,396900.0 38 | 5.841,11.41,19.2,420000.0 39 | 5.85,8.77,19.2,441000.0 40 | 5.966,10.13,19.2,518700.0 41 | 6.595,4.32,18.3,646800.0 42 | 7.024,1.98,18.3,732900.0 43 | 6.77,4.84,17.9,558600.0 44 | 6.169,5.81,17.9,531300.0 45 | 6.211,7.44,17.9,518700.0 46 | 6.069,9.55,17.9,445200.0 47 | 5.682,10.21,17.9,405300.0 48 | 5.786,14.15,17.9,420000.0 49 | 6.03,18.8,17.9,348600.0 50 | 5.399,30.81,17.9,302400.0 51 | 5.602,16.2,17.9,407400.0 52 | 5.963,13.45,16.8,413700.0 53 | 6.115,9.43,16.8,430500.0 54 | 6.511,5.28,16.8,525000.0 55 | 5.998,8.43,16.8,491400.0 56 | 
5.888,14.8,21.1,396900.0 57 | 7.249,4.81,17.9,743400.0 58 | 6.383,5.77,17.3,518700.0 59 | 6.816,3.95,15.1,663600.0 60 | 6.145,6.86,19.7,489300.0 61 | 5.927,9.22,19.7,411600.0 62 | 5.741,13.15,19.7,392700.0 63 | 5.966,14.44,19.7,336000.0 64 | 6.456,6.73,19.7,466200.0 65 | 6.762,9.5,19.7,525000.0 66 | 7.104,8.05,18.6,693000.0 67 | 6.29,4.67,16.1,493500.0 68 | 5.787,10.24,16.1,407400.0 69 | 5.878,8.1,18.9,462000.0 70 | 5.594,13.09,18.9,365400.0 71 | 5.885,8.79,18.9,438900.0 72 | 6.417,6.72,19.2,508200.0 73 | 5.961,9.88,19.2,455700.0 74 | 6.065,5.52,19.2,478800.0 75 | 6.245,7.54,19.2,491400.0 76 | 6.273,6.78,18.7,506100.0 77 | 6.286,8.94,18.7,449400.0 78 | 6.279,11.97,18.7,420000.0 79 | 6.14,10.27,18.7,436800.0 80 | 6.232,12.34,18.7,445200.0 81 | 5.874,9.1,18.7,426300.0 82 | 6.727,5.29,19.0,588000.0 83 | 6.619,7.22,19.0,501900.0 84 | 6.302,6.72,19.0,520800.0 85 | 6.167,7.51,19.0,480900.0 86 | 6.389,9.62,18.5,501900.0 87 | 6.63,6.53,18.5,558600.0 88 | 6.015,12.86,18.5,472500.0 89 | 6.121,8.44,18.5,466200.0 90 | 7.007,5.5,17.8,495600.0 91 | 7.079,5.7,17.8,602700.0 92 | 6.417,8.81,17.8,474600.0 93 | 6.405,8.2,17.8,462000.0 94 | 6.442,8.16,18.2,480900.0 95 | 6.211,6.21,18.2,525000.0 96 | 6.249,10.59,18.2,432600.0 97 | 6.625,6.65,18.0,596400.0 98 | 6.163,11.34,18.0,449400.0 99 | 8.069,4.21,18.0,812700.0 100 | 7.82,3.57,18.0,919800.0 101 | 7.416,6.19,18.0,697200.0 102 | 6.727,9.42,20.9,577500.0 103 | 6.781,7.67,20.9,556500.0 104 | 6.405,10.63,20.9,390600.0 105 | 6.137,13.44,20.9,405300.0 106 | 6.167,12.33,20.9,422100.0 107 | 5.851,16.47,20.9,409500.0 108 | 5.836,18.66,20.9,409500.0 109 | 6.127,14.09,20.9,428400.0 110 | 6.474,12.27,20.9,415800.0 111 | 6.229,15.55,20.9,407400.0 112 | 6.195,13.0,20.9,455700.0 113 | 6.715,10.16,17.8,478800.0 114 | 5.913,16.21,17.8,394800.0 115 | 6.092,17.09,17.8,392700.0 116 | 6.254,10.45,17.8,388500.0 117 | 5.928,15.76,17.8,384300.0 118 | 6.176,12.04,17.8,445200.0 119 | 6.021,10.3,17.8,403200.0 120 | 5.872,15.37,17.8,428400.0 121 | 
5.731,13.61,17.8,405300.0 122 | 5.87,14.37,19.1,462000.0 123 | 6.004,14.27,19.1,426300.0 124 | 5.961,17.93,19.1,430500.0 125 | 5.856,25.41,19.1,363300.0 126 | 5.879,17.58,19.1,394800.0 127 | 5.986,14.81,19.1,449400.0 128 | 5.613,27.26,19.1,329700.0 129 | 5.693,17.19,21.2,340200.0 130 | 6.431,15.39,21.2,378000.0 131 | 5.637,18.34,21.2,300300.0 132 | 6.458,12.6,21.2,403200.0 133 | 6.326,12.26,21.2,411600.0 134 | 6.372,11.12,21.2,483000.0 135 | 5.822,15.03,21.2,386400.0 136 | 5.757,17.31,21.2,327600.0 137 | 6.335,16.96,21.2,380100.0 138 | 5.942,16.9,21.2,365400.0 139 | 6.454,14.59,21.2,359100.0 140 | 5.857,21.32,21.2,279300.0 141 | 6.151,18.46,21.2,373800.0 142 | 6.174,24.16,21.2,294000.0 143 | 5.019,34.41,21.2,302400.0 144 | 5.403,26.82,14.7,281400.0 145 | 5.468,26.42,14.7,327600.0 146 | 4.903,29.29,14.7,247800.0 147 | 6.13,27.8,14.7,289800.0 148 | 5.628,16.65,14.7,327600.0 149 | 4.926,29.53,14.7,306600.0 150 | 5.186,28.32,14.7,373800.0 151 | 5.597,21.45,14.7,323400.0 152 | 6.122,14.1,14.7,451500.0 153 | 5.404,13.28,14.7,411600.0 154 | 5.012,12.12,14.7,321300.0 155 | 5.709,15.79,14.7,407400.0 156 | 6.129,15.12,14.7,357000.0 157 | 6.152,15.02,14.7,327600.0 158 | 5.272,16.14,14.7,275100.0 159 | 6.943,4.59,14.7,867300.0 160 | 6.066,6.43,14.7,510300.0 161 | 6.51,7.39,14.7,489300.0 162 | 6.25,5.5,14.7,567000.0 163 | 5.854,11.64,14.7,476700.0 164 | 6.101,9.81,14.7,525000.0 165 | 5.877,12.14,14.7,499800.0 166 | 6.319,11.1,14.7,499800.0 167 | 6.402,11.32,14.7,468300.0 168 | 5.875,14.43,14.7,365400.0 169 | 5.88,12.03,14.7,401100.0 170 | 5.572,14.69,16.6,485100.0 171 | 6.416,9.04,16.6,495600.0 172 | 5.859,9.64,16.6,474600.0 173 | 6.546,5.33,16.6,617400.0 174 | 6.02,10.11,16.6,487200.0 175 | 6.315,6.29,16.6,516600.0 176 | 6.86,6.92,16.6,627900.0 177 | 6.98,5.04,17.8,781200.0 178 | 7.765,7.56,17.8,835800.0 179 | 6.144,9.45,17.8,760200.0 180 | 7.155,4.82,17.8,795900.0 181 | 6.563,5.68,17.8,682500.0 182 | 5.604,13.98,17.8,554400.0 183 | 6.153,13.15,17.8,621600.0 184 | 
6.782,6.68,15.2,672000.0 185 | 6.556,4.56,15.2,625800.0 186 | 7.185,5.39,15.2,732900.0 187 | 6.951,5.1,15.2,777000.0 188 | 6.739,4.69,15.2,640500.0 189 | 7.178,2.87,15.2,764400.0 190 | 6.8,5.03,15.6,653100.0 191 | 6.604,4.38,15.6,611100.0 192 | 7.287,4.08,12.6,699300.0 193 | 7.107,8.61,12.6,636300.0 194 | 7.274,6.62,12.6,726600.0 195 | 6.975,4.56,17.0,732900.0 196 | 7.135,4.45,17.0,690900.0 197 | 6.162,7.43,14.7,506100.0 198 | 7.61,3.11,14.7,888300.0 199 | 7.853,3.81,14.7,1018500.0 200 | 5.891,10.87,18.6,474600.0 201 | 6.326,10.97,18.6,512400.0 202 | 5.783,18.06,18.6,472500.0 203 | 6.064,14.66,18.6,512400.0 204 | 5.344,23.09,18.6,420000.0 205 | 5.96,17.27,18.6,455700.0 206 | 5.404,23.98,18.6,405300.0 207 | 5.807,16.03,18.6,470400.0 208 | 6.375,9.38,18.6,590100.0 209 | 5.412,29.55,18.6,497700.0 210 | 6.182,9.47,18.6,525000.0 211 | 5.888,13.51,16.4,489300.0 212 | 6.642,9.69,16.4,602700.0 213 | 5.951,17.92,16.4,451500.0 214 | 6.373,10.5,16.4,483000.0 215 | 6.951,9.71,17.4,560700.0 216 | 6.164,21.46,17.4,455700.0 217 | 6.879,9.93,17.4,577500.0 218 | 6.618,7.6,17.4,632100.0 219 | 8.266,4.14,17.4,940800.0 220 | 8.04,3.13,17.4,789600.0 221 | 7.163,6.36,17.4,663600.0 222 | 7.686,3.92,17.4,980700.0 223 | 6.552,3.76,17.4,661500.0 224 | 5.981,11.65,17.4,510300.0 225 | 7.412,5.25,17.4,665700.0 226 | 8.337,2.47,17.4,875700.0 227 | 8.247,3.95,17.4,1014300.0 228 | 6.726,8.05,17.4,609000.0 229 | 6.086,10.88,17.4,504000.0 230 | 6.631,9.54,17.4,527100.0 231 | 7.358,4.73,17.4,661500.0 232 | 6.481,6.36,16.6,497700.0 233 | 6.606,7.37,16.6,489300.0 234 | 6.897,11.38,16.6,462000.0 235 | 6.095,12.4,16.6,422100.0 236 | 6.358,11.22,16.6,466200.0 237 | 6.393,5.19,16.6,497700.0 238 | 5.593,12.5,19.1,369600.0 239 | 5.605,18.46,19.1,388500.0 240 | 6.108,9.16,19.1,510300.0 241 | 6.226,10.15,19.1,430500.0 242 | 6.433,9.52,19.1,514500.0 243 | 6.718,6.56,19.1,550200.0 244 | 6.487,5.9,19.1,512400.0 245 | 6.438,3.59,19.1,520800.0 246 | 6.957,3.53,19.1,621600.0 247 | 8.259,3.54,19.1,898800.0 248 | 
6.108,6.57,16.4,459900.0 249 | 5.876,9.25,16.4,438900.0 250 | 7.454,3.11,15.9,924000.0 251 | 7.333,7.79,13.0,756000.0 252 | 6.842,6.9,13.0,632100.0 253 | 7.203,9.59,13.0,709800.0 254 | 7.52,7.26,13.0,905100.0 255 | 8.398,5.91,13.0,1024800.0 256 | 7.327,11.25,13.0,651000.0 257 | 7.206,8.1,13.0,766500.0 258 | 5.56,10.45,13.0,478800.0 259 | 7.014,14.79,13.0,644700.0 260 | 7.47,3.16,13.0,913500.0 261 | 5.92,13.65,18.6,434700.0 262 | 5.856,13.0,18.6,443100.0 263 | 6.24,6.59,18.6,529200.0 264 | 6.538,7.73,18.6,512400.0 265 | 7.691,6.58,18.6,739200.0 266 | 6.758,3.53,17.6,680400.0 267 | 6.854,2.98,17.6,672000.0 268 | 7.267,6.05,17.6,697200.0 269 | 6.826,4.16,17.6,695100.0 270 | 6.482,7.19,17.6,611100.0 271 | 6.812,4.85,14.9,737100.0 272 | 7.82,3.76,14.9,953400.0 273 | 6.968,4.59,14.9,743400.0 274 | 7.645,3.01,14.9,966000.0 275 | 7.088,7.85,15.3,676200.0 276 | 6.453,8.23,15.3,462000.0 277 | 6.23,12.93,18.2,422100.0 278 | 6.209,7.14,16.6,487200.0 279 | 6.315,7.6,16.6,468300.0 280 | 6.565,9.51,16.6,520800.0 281 | 6.861,3.33,19.2,598500.0 282 | 7.148,3.56,19.2,783300.0 283 | 6.63,4.7,19.2,585900.0 284 | 6.127,8.58,16.0,501900.0 285 | 6.009,10.4,16.0,455700.0 286 | 6.678,6.27,16.0,600600.0 287 | 6.549,7.39,16.0,569100.0 288 | 5.79,15.84,16.0,426300.0 289 | 6.345,4.97,14.8,472500.0 290 | 7.041,4.74,14.8,609000.0 291 | 6.871,6.07,14.8,520800.0 292 | 6.59,9.5,16.1,462000.0 293 | 6.495,8.67,16.1,554400.0 294 | 6.982,4.86,16.1,695100.0 295 | 7.236,6.93,18.4,758100.0 296 | 6.616,8.93,18.4,596400.0 297 | 7.42,6.47,18.4,701400.0 298 | 6.849,7.53,18.4,592200.0 299 | 6.635,4.54,18.4,478800.0 300 | 5.972,9.97,18.4,426300.0 301 | 4.973,12.64,18.4,338100.0 302 | 6.122,5.98,18.4,464100.0 303 | 6.023,11.72,18.4,407400.0 304 | 6.266,7.9,18.4,453600.0 305 | 6.567,9.28,18.4,499800.0 306 | 5.705,11.5,18.4,340200.0 307 | 5.914,18.33,18.4,373800.0 308 | 5.782,15.94,18.4,415800.0 309 | 6.382,10.36,18.4,485100.0 310 | 6.113,12.73,18.4,441000.0 311 | 6.426,7.2,19.6,499800.0 312 | 
6.376,6.87,19.6,485100.0 313 | 6.041,7.7,19.6,428400.0 314 | 5.708,11.74,19.6,388500.0 315 | 6.415,6.12,19.6,525000.0 316 | 6.431,5.08,19.6,516600.0 317 | 6.312,6.15,19.6,483000.0 318 | 6.083,12.79,19.6,466200.0 319 | 5.868,9.97,16.9,405300.0 320 | 6.333,7.34,16.9,474600.0 321 | 6.144,9.09,16.9,415800.0 322 | 5.706,12.43,16.9,359100.0 323 | 6.031,7.83,16.9,407400.0 324 | 6.316,5.68,20.2,466200.0 325 | 6.31,6.75,20.2,434700.0 326 | 6.037,8.01,20.2,443100.0 327 | 5.869,9.8,20.2,409500.0 328 | 5.895,10.56,20.2,388500.0 329 | 6.059,8.51,20.2,432600.0 330 | 5.985,9.74,20.2,399000.0 331 | 5.968,9.29,20.2,392700.0 332 | 7.241,5.49,15.5,686700.0 333 | 6.54,8.65,15.9,346500.0 334 | 6.696,7.18,17.6,501900.0 335 | 6.874,4.61,17.6,655200.0 336 | 6.014,10.53,18.8,367500.0 337 | 5.898,12.67,18.8,361200.0 338 | 6.516,6.36,17.9,485100.0 339 | 6.635,5.99,17.0,514500.0 340 | 6.939,5.89,19.7,558600.0 341 | 6.49,5.98,19.7,480900.0 342 | 6.579,5.49,18.3,506100.0 343 | 5.884,7.79,18.3,390600.0 344 | 6.728,4.5,17.0,632100.0 345 | 5.663,8.05,22.0,382200.0 346 | 5.936,5.57,22.0,432600.0 347 | 6.212,17.6,20.2,373800.0 348 | 6.395,13.27,20.2,455700.0 349 | 6.127,11.48,20.2,476700.0 350 | 6.112,12.67,20.2,474600.0 351 | 6.398,7.79,20.2,525000.0 352 | 6.251,14.19,20.2,417900.0 353 | 5.362,10.19,20.2,436800.0 354 | 5.803,14.64,20.2,352800.0 355 | 3.561,7.12,20.2,577500.0 356 | 4.963,14.0,20.2,459900.0 357 | 3.863,13.33,20.2,485100.0 358 | 4.906,34.77,20.2,289800.0 359 | 4.138,37.97,20.2,289800.0 360 | 7.313,13.44,20.2,315000.0 361 | 6.649,23.24,20.2,291900.0 362 | 6.794,21.24,20.2,279300.0 363 | 6.38,23.69,20.2,275100.0 364 | 6.223,21.78,20.2,214200.0 365 | 6.968,17.21,20.2,218400.0 366 | 6.545,21.08,20.2,228900.0 367 | 5.536,23.6,20.2,237300.0 368 | 5.52,24.56,20.2,258300.0 369 | 4.368,30.63,20.2,184800.0 370 | 5.277,30.81,20.2,151200.0 371 | 4.652,28.28,20.2,220500.0 372 | 5.0,31.99,20.2,155400.0 373 | 4.88,30.62,20.2,214200.0 374 | 5.39,20.85,20.2,241500.0 375 | 5.713,17.11,20.2,317100.0 376 
| 6.051,18.76,20.2,487200.0 377 | 5.036,25.68,20.2,203700.0 378 | 6.193,15.17,20.2,289800.0 379 | 5.887,16.35,20.2,266700.0 380 | 6.471,17.12,20.2,275100.0 381 | 6.405,19.37,20.2,262500.0 382 | 5.747,19.92,20.2,178500.0 383 | 5.453,30.59,20.2,105000.0 384 | 5.852,29.97,20.2,132300.0 385 | 5.987,26.77,20.2,117600.0 386 | 6.343,20.32,20.2,151200.0 387 | 6.404,20.31,20.2,254100.0 388 | 5.349,19.77,20.2,174300.0 389 | 5.531,27.38,20.2,178500.0 390 | 5.683,22.98,20.2,105000.0 391 | 4.138,23.34,20.2,249900.0 392 | 5.608,12.13,20.2,585900.0 393 | 5.617,26.4,20.2,361200.0 394 | 6.852,19.78,20.2,577500.0 395 | 5.757,10.11,20.2,315000.0 396 | 6.657,21.22,20.2,361200.0 397 | 4.628,34.37,20.2,375900.0 398 | 5.155,20.08,20.2,342300.0 399 | 4.519,36.98,20.2,147000.0 400 | 6.434,29.05,20.2,151200.0 401 | 6.782,25.79,20.2,157500.0 402 | 5.304,26.64,20.2,218400.0 403 | 5.957,20.62,20.2,184800.0 404 | 6.824,22.74,20.2,176400.0 405 | 6.411,15.02,20.2,350700.0 406 | 6.006,15.7,20.2,298200.0 407 | 5.648,14.1,20.2,436800.0 408 | 6.103,23.29,20.2,281400.0 409 | 5.565,17.16,20.2,245700.0 410 | 5.896,24.39,20.2,174300.0 411 | 5.837,15.69,20.2,214200.0 412 | 6.202,14.52,20.2,228900.0 413 | 6.193,21.52,20.2,231000.0 414 | 6.38,24.08,20.2,199500.0 415 | 6.348,17.64,20.2,304500.0 416 | 6.833,19.69,20.2,296100.0 417 | 6.425,12.03,20.2,338100.0 418 | 6.436,16.22,20.2,300300.0 419 | 6.208,15.17,20.2,245700.0 420 | 6.629,23.27,20.2,281400.0 421 | 6.461,18.05,20.2,201600.0 422 | 6.152,26.45,20.2,182700.0 423 | 5.935,34.02,20.2,176400.0 424 | 5.627,22.88,20.2,268800.0 425 | 5.818,22.11,20.2,220500.0 426 | 6.406,19.52,20.2,359100.0 427 | 6.219,16.59,20.2,386400.0 428 | 6.485,18.85,20.2,323400.0 429 | 5.854,23.79,20.2,226800.0 430 | 6.459,23.98,20.2,247800.0 431 | 6.341,17.79,20.2,312900.0 432 | 6.251,16.44,20.2,264600.0 433 | 6.185,18.13,20.2,296100.0 434 | 6.417,19.31,20.2,273000.0 435 | 6.749,17.44,20.2,281400.0 436 | 6.655,17.73,20.2,319200.0 437 | 6.297,17.27,20.2,338100.0 438 | 
7.393,16.74,20.2,373800.0 439 | 6.728,18.71,20.2,312900.0 440 | 6.525,18.13,20.2,296100.0 441 | 5.976,19.01,20.2,266700.0 442 | 5.936,16.94,20.2,283500.0 443 | 6.301,16.23,20.2,312900.0 444 | 6.081,14.7,20.2,420000.0 445 | 6.701,16.42,20.2,344400.0 446 | 6.376,14.65,20.2,371700.0 447 | 6.317,13.99,20.2,409500.0 448 | 6.513,10.29,20.2,424200.0 449 | 6.209,13.22,20.2,449400.0 450 | 5.759,14.13,20.2,417900.0 451 | 5.952,17.15,20.2,399000.0 452 | 6.003,21.32,20.2,401100.0 453 | 5.926,18.13,20.2,401100.0 454 | 5.713,14.76,20.2,422100.0 455 | 6.167,16.29,20.2,417900.0 456 | 6.229,12.87,20.2,411600.0 457 | 6.437,14.36,20.2,487200.0 458 | 6.98,11.66,20.2,625800.0 459 | 5.427,18.14,20.2,289800.0 460 | 6.162,24.1,20.2,279300.0 461 | 6.484,18.68,20.2,350700.0 462 | 5.304,24.91,20.2,252000.0 463 | 6.185,18.03,20.2,306600.0 464 | 6.229,13.11,20.2,449400.0 465 | 6.242,10.74,20.2,483000.0 466 | 6.75,7.74,20.2,497700.0 467 | 7.061,7.01,20.2,525000.0 468 | 5.762,10.42,20.2,457800.0 469 | 5.871,13.34,20.2,432600.0 470 | 6.312,10.58,20.2,445200.0 471 | 6.114,14.98,20.2,401100.0 472 | 5.905,11.45,20.2,432600.0 473 | 5.454,18.06,20.1,319200.0 474 | 5.414,23.97,20.1,147000.0 475 | 5.093,29.68,20.1,170100.0 476 | 5.983,18.07,20.1,285600.0 477 | 5.983,13.35,20.1,422100.0 478 | 5.707,12.01,19.2,457800.0 479 | 5.926,13.59,19.2,514500.0 480 | 5.67,17.6,19.2,485100.0 481 | 5.39,21.14,19.2,413700.0 482 | 5.794,14.1,19.2,384300.0 483 | 6.019,12.92,19.2,445200.0 484 | 5.569,15.1,19.2,367500.0 485 | 6.027,14.33,19.2,352800.0 486 | 6.593,9.67,21.0,470400.0 487 | 6.12,9.08,21.0,432600.0 488 | 6.976,5.64,21.0,501900.0 489 | 6.794,6.48,21.0,462000.0 490 | 6.03,7.88,21.0,249900.0 491 | -------------------------------------------------------------------------------- /knn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Goals: Build K nearest neighbors model from 
scratch and compare with sklearn model" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Unsupervised nearest neighbors" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%matplotlib inline\n", 24 | "%reload_ext autoreload\n", 25 | "%autoreload 2" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 26, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import numpy as np\n", 35 | "import pandas as pd\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "from model.metrics import accuracy" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "Toy example" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "toy = np.array([[-3, -2], [-2, -1], [-1, -1],[1, 1], [2, 1], [3, 2]])" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "[]" 65 | ] 66 | }, 67 | "execution_count": 4, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | }, 71 | { 72 | "data": { 73 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEg5JREFUeJzt3X+M5Hd93/Hn6w6b5AgtSb0Jxr69ddVTFJfQkI7cRFQVFSY5LOQLSZCMVo3TJFrR1gqRWgk3J4FKdVKiSGmUgAKbYMVUW5woxOUqjoIdiAhqTbxnHWBzOLlaPXt7VrxAgVib1rrw7h8zFnvX2du9/X53Z3c+z4c0mu/3M5+b9+ejs+c13x9zn1QVkqT2HJj0ACRJk2EASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhr1kkkP4GpuuOGGmpubm/QwJGnfOHPmzFeqamYrffd0AMzNzbG8vDzpYUjSvpHkwlb7egpIkhplAEhSowwASWqUASBJjTIAJKlRnQMgyeEkn05yLskTSd4xpk+S/GaS80m+kOSHu9aVpKmztARzc3DgwPB5aWlHy/VxG+gl4F9X1WNJXg6cSfJQVX1pXZ83AUdHj38E/PboWZIEww/7hQVYWxvuX7gw3AeYn9+Rkp2PAKrq2ap6bLT9V8A54KYruh0HPlRDjwCvSHJj19qSNDVOnPj2h/+L1taG7Tuk12sASeaA1wKfu+Klm4Bn1u2v8P+HxIvvsZBkOcny6upqn8OTpL3r6aevrb0HvQVAku8CPgL8UlV988qXx/yRsavRV9ViVQ2qajAzs6VfM0vS/jc7e23tPeglAJJcx/DDf6mq/mhMlxXg8Lr9m4GLfdSWpKlw8iQcOnR526FDw/Yd0sddQAE+CJyrql/foNsp4GdGdwP9CPCNqnq2a21Jmhrz87C4CEeOQDJ8XlzcsQvA0M9dQK8D/hnwxSRnR22/DMwCVNX7gdPAHcB5YA345z3UlaTpMj+/ox/4V+ocAFX1Wcaf41/fp4B/1bWWJKk//hJYkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRvW1JOR9SZ5L8vgGr78+yTeSnB093tVHXUnS9vWxIhjA7wHvBT50lT5/WlVv7qmeJKmjXo4AquozwNf6eC9J0u7YzWsAP5rk80k+nuTv72JdSdIYfZ0C2sxjwJGqej7JHcB/Bo6O65hkAVgAmJ2d3aXhSVJ7duUIoKq+WVXPj7ZPA9cluWGDvotVNaiqwczMzG4MT5KatCsBkOSVSTLavm1U96u7UVuSNF4vp4CSfBh4PXBDkhXg3cB1AFX1fuCngX+R5BLw18BdVVV91JYkbU8vAVBVb9vk9fcyvE1UkrRH+EtgSWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjegmAJPcleS7J4xu8niS/meR8ki8k+eE+6kqStq+vI4DfA45d5fU3AUdHjwXgt3uqK2m7lpZgbg4OHBg+Ly1NekTbMy3zmIC+loT8TJK5q3Q5DnxotA7wI0lekeTGqnq2j/qSrtHSEiwswNracP/CheE+wPz85MZ1raZlHhOyW9cAbgKeWbe/MmqTNAknTnz7Q/NFa2vD9v1kWuYxIbsVABnTVmM7JgtJlpMsr66u7vCwpEY9/fS1te9V0zKPCdmtAFgBDq/bvxm4OK5jVS1W1aCqBjMzM7syOKk5s7PX1r5XTcs8JmS3AuAU8DOju4F+BPiG5/+lCTp5Eg4durzt0KFh+34yLfOYkL5uA/0w8N+B70+ykuTnk7w9ydtHXU4DTwHngd8B/mUfdSVt0/w8LC7CkSOQDJ8XF/ffhdNpmceEZHhjzt40GAxqeXl50sO
QpH0jyZmqGmylr78ElqRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1qq8VwY4leTLJ+ST3jnn9Z5OsJjk7evxCH3UlSdv3kq5vkOQg8D7gjQwXf380yamq+tIVXX+/qu7pWk+S1I8+jgBuA85X1VNV9QLwAHC8h/eVJO2gPgLgJuCZdfsro7Yr/VSSLyT5wySHN3qzJAtJlpMsr66u9jA8SdI4fQRAxrRdudL8fwHmquo1wMPA/Ru9WVUtVtWgqgYzMzM9DE+SNE4fAbACrP9GfzNwcX2HqvpqVf3f0e7vAP+wh7qSpA76CIBHgaNJbklyPXAXcGp9hyQ3rtu9EzjXQ11JUged7wKqqktJ7gE+ARwE7quqJ5K8B1iuqlPALya5E7gEfA342a51JUndpOrK0/V7x2AwqOXl5UkPQ5L2jSRnqmqwlb7+EliSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRG9RIASY4leTLJ+ST3jnn9pUl+f/T655LM9VFXkrR9nQMgyUHgfcCbgFuBtyW59YpuPw/876r6e8B/AH61a11JUjd9HAHcBpyvqqeq6gXgAeD4FX2OA/ePtv8QeEOS9FBbkrRNfQTATcAz6/ZXRm1j+1TVJeAbwN/pobYkaZv6CIBx3+SvXGh4K32GHZOFJMtJlldXVzsPTpI0Xh8BsAIcXrd/M3Bxoz5JXgL8beBr496sqharalBVg5mZmR6GJ0kap48AeBQ4muSWJNcDdwGnruhzCrh7tP3TwKeqauwRgCRpd7yk6xtU1aUk9wCfAA4C91XVE0neAyxX1Sngg8B/THKe4Tf/u7rWlSR10zkAAKrqNHD6irZ3rdv+P8Bb+6glSeqHvwSWpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDWqUwAk+Z4kDyX5i9Hzd2/Q72+SnB09rlwuUpI0AV2PAO4F/riqjgJ/PNof56+r6odGjzs71pQk9aBrABwH7h9t3w/8RMf3kyTtkq4B8H1V9SzA6Pl7N+j3HUmWkzySxJCQpD1g00XhkzwMvHLMSyeuoc5sVV1M8neBTyX5YlX9jw3qLQALALOzs9dQQpJ0LTYNgKq6faPXkvxlkhur6tkkNwLPbfAeF0fPTyX5E+C1wNgAqKpFYBFgMBjUpjOQJG1L11NAp4C7R9t3Ax+9skOS707y0tH2DcDrgC91rCtJ6qhrAPwK8MYkfwG8cbRPkkGS3x31+QFgOcnngU8Dv1JVBoAkTdimp4Cupqq+CrxhTPsy8Auj7f8G/GCXOpKk/vlLYElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSozoFQJK3JnkiybeSDK7S71iSJ5OcT3Jvl5pNWVqCuTk4cGD4vLQ06RFtz7TMQ5oynVYEAx4HfhL4wEYdkhwE3sdwycgV4NEkp1wWchNLS7CwAGtrw/0LF4b7APPzkxvXtZqWeUhTqNMRQFWdq6onN+l2G3C+qp6qqheAB4DjXeo24cSJb39ovmhtbdi+n0zLPKQptBvXAG4Cnlm3vzJqGyvJQpLlJMurq6s7Prg96+mnr619r5qWeUhTaNMASPJwksfHPLb6LT5j2mqjzlW1WFWDqhrMzMxsscQUmp29tva9alrmIU2hTQOgqm6vqlePeXx0izVWgMPr9m8GLm5nsE05eRIOHbq87dChYft+Mi3zkKbQbpwCehQ4muSWJNcDdwGndqHu/jY/D4uLcOQIJMPnxcX9d+F0WuYhTaFUbXg2ZvM/nLwF+C1gBvg
6cLaqfjzJq4Dfrao7Rv3uAH4DOAjcV1Vb+vo3GAxqeXl52+OTpNYkOVNVG96Wv16n20Cr6kHgwTHtF4E71u2fBk53qSVJ6pe/BJakRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNapTACR5a5InknwryYYr0CT5n0m+mORsEpf4kqQ9oNOKYMDjwE8CH9hC339aVV/pWE+S1JOuS0KeA0jSz2gkSbtmt64BFPDJJGeSLFytY5KFJMtJlldXV3dpeJLUnk2PAJI8DLxyzEsnquqjW6zzuqq6mOR7gYeSfLmqPjOuY1UtAosAg8Ggtvj+kqRrtGkAVNXtXYtU1cXR83NJHgRuA8YGgCRpd+z4KaAkL0vy8he3gR9jePFYkjRBXW8DfUuSFeBHgY8l+cSo/VVJTo+6fR/w2SSfB/4M+FhV/dcudSVJ3XW9C+hB4MEx7ReBO0bbTwH/oEsdSVL//CWwJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjeq6IMyvJflyki8keTDJKzbodyzJk0nOJ7m3S01JUj+6HgE8BLy6ql4D/Dnwb6/skOQg8D7gTcCtwNuS3Nqx7saWlmBuDg4cGD4vLe1YKUnazzoFQFV9sqoujXYfAW4e0+024HxVPVVVLwAPAMe71N3Q0hIsLMCFC1A1fF5YMAQkaYw+rwH8HPDxMe03Ac+s218ZtfXvxAlYW7u8bW1t2C5JusymawIneRh45ZiXTlTVR0d9TgCXgHFftTOmra5SbwFYAJidnd1seJd7+ulra5ekhm0aAFV1+9VeT3I38GbgDVU17oN9BTi8bv9m4OJV6i0CiwCDwWDDoBhrdnZ42mdcuyTpMl3vAjoGvBO4s6rWNuj2KHA0yS1JrgfuAk51qbuhkyfh0KHL2w4dGrZLki7T9RrAe4GXAw8lOZvk/QBJXpXkNMDoIvE9wCeAc8AfVNUTHeuONz8Pi4tw5Agkw+fFxWG7JOkyGX/WZm8YDAa1vLw86WFI0r6R5ExVDbbS118CS1KjDABJapQBIEmNMgAkqVEGgCQ1ak/fBZRkFRjzy64tuQH4So/DmaRpmcu0zAOcy140LfOAbnM5UlUzW+m4pwOgiyTLW70Vaq+blrlMyzzAuexF0zIP2L25eApIkhplAEhSo6Y5ABYnPYAeTctcpmUe4Fz2ommZB+zSXKb2GoAk6eqm+QhAknQVUx0ASf79aMH6s0k+meRVkx7TdiT5tSRfHs3lwSSvmPSYtivJW5M8keRbSfbdHRtJjiV5Msn5JPdOejxdJLkvyXNJHp/0WLpIcjjJp5OcG/239Y5Jj2m7knxHkj9L8vnRXP7djtab5lNASf5WVX1ztP2LwK1V9fYJD+uaJfkx4FNVdSnJrwJU1TsnPKxtSfIDwLeADwD/pqr2zT/3muQg8OfAGxkudPQo8Laq+tJEB7ZNSf4J8Dzwoap69aTHs11JbgRurKrHkrwcOAP8xH78e0kS4GVV9XyS64DPAu+oqkd2ot5UHwG8+OE/8jKushTlXlZVnxytqwDwCMNV1falqjpXVU9OehzbdBtwvqqeqqoXgAeA4xMe07ZV1WeAr016HF1V1bNV9dho+68YrjuyM+uO77Aaen60e93osWOfW1MdAABJTiZ5BpgH3jXp8fTg54CPT3oQjboJeGbd/gr79INmWiWZA14LfG6yI9m+JAeTnAWeAx6qqh2by74PgCQPJ3l8zOM4QFWdqKrDDBesv2eyo93YZvMY9TkBXGI4lz1rK3PZpzKmbV8eVU6jJN8FfAT4pSuO/veVqvqbqvohhkf6tyXZsdNzmy4Kv9dttmj9Ov8J+Bjw7h0czrZtNo8kdwNvBt5Qe/zCzTX8new3K8Dhdfs3AxcnNBatMzpf/hF
gqar+aNLj6UNVfT3JnwDHgB25UL/vjwCuJsnRdbt3Al+e1Fi6SHIMeCdwZ1WtTXo8DXsUOJrkliTXA3cBpyY8puaNLpx+EDhXVb8+6fF0kWTmxbv8knwncDs7+Lk17XcBfQT4foZ3nVwA3l5V/2uyo7p2Sc4DLwW+Omp6ZD/ezQSQ5C3AbwEzwNeBs1X145Md1dYluQP4DeAgcF9VnZzwkLYtyYeB1zP8lyf/Enh3VX1wooPahiT/GPhT4IsM/18H+OWqOj25UW1PktcA9zP87+sA8AdV9Z4dqzfNASBJ2thUnwKSJG3MAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVH/DzE8KcwjdPNkAAAAAElFTkSuQmCC\n", 74 | "text/plain": [ 75 | "
" 76 | ] 77 | }, 78 | "metadata": {}, 79 | "output_type": "display_data" 80 | } 81 | ], 82 | "source": [ 83 | "plt.plot(toy[:,0],toy[:,1],'ro')" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "## 1. Sklearn unsupervised nearest neighbors" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 6, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "[[0. 1.41421356 2.23606798]\n", 103 | " [0. 1. 1.41421356]\n", 104 | " [0. 1. 2.23606798]\n", 105 | " [0. 1. 2.23606798]\n", 106 | " [0. 1. 1.41421356]\n", 107 | " [0. 1.41421356 2.23606798]]\n", 108 | "[[0 1 2]\n", 109 | " [1 2 0]\n", 110 | " [2 1 0]\n", 111 | " [3 4 5]\n", 112 | " [4 3 5]\n", 113 | " [5 4 3]]\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "from sklearn.neighbors import NearestNeighbors\n", 119 | "nbrs = NearestNeighbors(n_neighbors=3, algorithm='ball_tree').fit(toy) # 3 neighbors, including himself\n", 120 | "dist,idx=nbrs.kneighbors(toy)\n", 121 | "print(dist)\n", 122 | "print(idx)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 7, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "(array([[0.70710678, 0.70710678, 1.58113883]]), array([[2, 1, 0]], dtype=int64))\n" 135 | ] 136 | }, 137 | { 138 | "data": { 139 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEnNJREFUeJzt3X+MpVd93/H3ZxebZAgtST0Jxt7ZcdVVFJfSkF65iagqKkyyWMgb0iAZjRqnSTQirRUitRJuVgKVaqVEkdIoAQUmwYqJpjhRiMtWLAU7EBGUmnjWWsBmcbK12PV0rXiAArEmrbXxt3/cazE7vbPz4z4zd+ae90u6ep5znrP3nKO17+c+P+6eVBWSpPYcGvcAJEnjYQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGvWScQ/gWm644YaanZ0d9zAk6cA4e/bsV6tqeitt93UAzM7OsrS0NO5hSNKBkeTiVtt6CUiSGmUASFKjDABJapQBIEmNMgAkqVEjB0CSI0k+neR8kieSvGNImyT5jSQXknwhyQ+N2q8kTZzFRZidhUOH+tvFxV3trovHQK8A/66qHkvycuBskoeq6ktr2rwJODZ4/VPgtwZbSRL0P+zn52F1tV++eLFfBpib25UuRz4DqKpnquqxwf5fA+eBm9Y1OwF8qPoeAV6R5MZR+5akiXHy5Lc//F+0utqv3yWd3gNIMgu8FvjcukM3AU+vKS/z/4fEi+8xn2QpydLKykqXw5Ok/evSpe3Vd6CzAEjyXcBHgF+sqm+tPzzkjwxdjb6qFqqqV1W96ekt/ZpZkg6+mZnt1XegkwBIch39D//FqvqjIU2WgSNryjcDl7voW5ImwqlTMDV1dd3UVL9+l3TxFFCADwLnq+rXNmh2GvipwdNAPwx8s6qeGbVvSZoYc3OwsABHj0LS3y4s7NoNYOjmKaDXAf8K+GKSc4O6XwJmAKrq/cAZ4A7gArAK/OsO+pWkyTI3t6sf+OuNHABV9VmGX+Nf26aAfztqX5Kk7vhLYElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhrV1ZKQ9yV5NsnjGxx/fZJvJjk3eL2ri34lSTvXxYpgAL8LvBf40DXa/GlVvbmj/iRJI+rkDKCqPgN8vYv3kiTtjb28B/AjST6f5ONJ/uEe9itJGqKrS0CbeQw4WlXPJbkD+K/AsWENk8wD8wAzMzN7NDxJas+enAFU1beq6rnB/hnguiQ3bNB2oap6VdWbnp7ei+FJUpP2JACSvDJJBvu3Dfr92l70LUkarpNLQEk+DLweuCHJMvBu4DqAqno/8JPAzye5AvwNcFdVVRd9S5J2ppMAqKq3bXL8vfQfE5Uk7RP+EliSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1KhOAiDJfUmeTfL4BseT5DeSXEjyhSQ/1EW/kqSd6+oM4HeB49c4/ibg2OA1D/xWR/1K2qnFRZidhUOH+tvFxXGPaGcmZR5j0NWSkJ9JMnuNJieADw3WAX4kySuS3FhVz3TRv6RtWlyE+XlYXe2XL17slwHm5sY3ru2alHmMyV7dA7gJeHpNeXlQJ2kcTp789ofmi1ZX+/UHyaTMY0z2KgAypK6GNkzmkywlWVpZWdnlYUmNunRpe/X71aTMY0z2KgCWgSNryjcDl4c1rKqFqupVVW96enpPBic1Z2Zme/X71aTMY0z2KgBOAz81eBroh4Fvev1fGqNTp2Bq6uq6qal+/UEyKfMYk64eA/0w8D+A70+ynORnk7w9ydsHTc4ATwEXgN8G/k0X/Uraobk5WFiAo0ch6W8XFg7ejdNJmceYpP9gzv7U6/VqaWlp3MOQpAM
jydmq6m2lrb8ElqRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1qqsVwY4neTLJhST3Djn+00lWkpwbvH6ui34lSTv3klHfIMlh4H3AG+kv/v5oktNV9aV1TX+/qu4ZtT9JUje6OAO4DbhQVU9V1fPAA8CJDt5XkrSLugiAm4Cn15SXB3Xr/cskX0jyh0mObPRmSeaTLCVZWllZ6WB4kqRhugiADKlbv9L8fwNmq+o1wMPA/Ru9WVUtVFWvqnrT09MdDE+SNEwXAbAMrP1GfzNweW2DqvpaVf3fQfG3gX/SQb+SpBF0EQCPAseS3JLkeuAu4PTaBkluXFO8EzjfQb+SpBGM/BRQVV1Jcg/wCeAwcF9VPZHkPcBSVZ0GfiHJncAV4OvAT4/aryRpNKlaf7l+/+j1erW0tDTuYUjSgZHkbFX1ttLWXwJLUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUqE4CIMnxJE8muZDk3iHHX5rk9wfHP5dktot+JUk7N3IAJDkMvA94E3Ar8LYkt65r9rPA/66qfwD8Z+BXRu1XkjSaLs4AbgMuVNVTVfU88ABwYl2bE8D9g/0/BN6QJB30LUnaoS4C4Cbg6TXl5UHd0DZVdQX4JvD3OuhbkrRDXQTAsG/y6xca3kqbfsNkPslSkqWVlZWRBydJGq6LAFgGjqwp3wxc3qhNkpcAfxf4+rA3q6qFqupVVW96erqD4UmShukiAB4FjiW5Jcn1wF3A6XVtTgN3D/Z/EvhUVQ09A5Ak7Y2XjPoGVXUlyT3AJ4DDwH1V9USS9wBLVXUa+CDwe0ku0P/mf9eo/UqSRjNyAABU1RngzLq6d63Z/z/AW7voS5LUDX8JLEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElq1EgBkOR7kjyU5C8H2+/eoN3fJjk3eK1fLlKSNAajngHcC/xxVR0D/nhQHuZvquoHB687R+xTktSBUQPgBHD/YP9+4MdHfD9J0h4ZNQC+r6qeARhsv3eDdt+RZCnJI0kMCUnaBzZdFD7Jw8Arhxw6uY1+ZqrqcpK/D3wqyRer6n9u0N88MA8wMzOzjS4kSduxaQBU1e0bHUvyV0lurKpnktwIPLvBe1webJ9K8ifAa4GhAVBVC8ACQK/Xq01nIEnakVEvAZ0G7h7s3w18dH2DJN+d5KWD/RuA1wFfGrFfSdKIRg2AXwbemOQvgTcOyiTpJfmdQZsfAJaSfB74NPDLVWUASNKYbXoJ6Fqq6mvAG4bULwE/N9j/M+AfjdKPJKl7/hJYkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktSokQIgyVuTPJHkhSS9a7Q7nuTJJBeS3DtKn01ZXITZWTh0qL9dXBz3iHZmUuYhTZiRVgQDHgd+AvjARg2SHAbeR3/JyGXg0SSnXRZyE4uLMD8Pq6v98sWL/TLA3Nz4xrVdkzIPaQKNdAZQVeer6slNmt0GXKiqp6rqeeAB4MQo/Tbh5Mlvf2i+aHW1X3+QTMo8pAm0F/cAbgKeXlNeHtQNlWQ+yVKSpZWVlV0f3L516dL26verSZmHNIE2DYAkDyd5fMhrq9/iM6SuNmpcVQtV1auq3vT09Ba7mEAzM9ur368mZR7SBNo0AKrq9qp69ZDXR7fYxzJwZE35ZuDyTgbblFOnYGrq6rqpqX79QTIp85Am0F5cAnoUOJbkliTXA3cBp/eg34Ntbg4WFuDoUUj624WFg3fjdFLmIU2gVG14NWbzP5y8BfhNYBr4BnC
uqn4syauA36mqOwbt7gB+HTgM3FdVW/r61+v1amlpacfjk6TWJDlbVRs+lr/WSI+BVtWDwIND6i8Dd6wpnwHOjNKXJKlb/hJYkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktSokQIgyVuTPJHkhSQbrkCT5CtJvpjkXBKX+NpHFhdhdhYOHepvFxfHPSJJe2WkFcGAx4GfAD6whbb/oqq+OmJ/6tDiIszPw+pqv3zxYr8MLtkrtWCkM4CqOl9VT3Y1GO2tkye//eH/otXVfr2kybdX9wAK+GSSs0nmr9UwyXySpSRLKysrezS8Nl26tL16SZNl00tASR4GXjnk0Mmq+ugW+3ldVV1O8r3AQ0m+XFWfGdawqhaABYBer1dbfH/twMxM/7LPsHpJk2/TAKiq20ftpKouD7bPJnkQuA0YGgDaO6dOXX0PAGBqql8vafLt+iWgJC9L8vIX94EfpX/zWGM2NwcLC3D0KCT97cKCN4ClVoz6GOhbkiwDPwJ8LMknBvWvSnJm0Oz7gM8m+Tzw58DHquq/j9KvujM3B1/5CrzwQn/rh7/UjpEeA62qB4EHh9RfBu4Y7D8F/ONR+pEkdc9fAktSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktSoUReE+dUkX07yhSQPJnnFBu2OJ3kyyYUk947SpySpG6OeATwEvLqqXgP8BfAf1jdIchh4H/Am4FbgbUluHbHfjS0uwuwsHDrU3y4u7lpXknSQjRQAVfXJqroyKD4C3Dyk2W3Ahap6qqqeBx4ATozS74YWF/urnF+8CFX97fy8ISBJQ3R5D+BngI8Pqb8JeHpNeXlQ172TJ2F19eq61dV+vSTpKpuuCZzkYeCVQw6drKqPDtqcBK4Aw75qZ0hdXaO/eWAeYGZmZrPhXe3Spe3VS1LDNg2Aqrr9WseT3A28GXhDVQ37YF8Gjqwp3wxcvkZ/C8ACQK/X2zAohpqZ6V/2GVYvSbrKqE8BHQfeCdxZVasbNHsUOJbkliTXA3cBp0fpd0OnTsHU1NV1U1P9eknSVUa9B/Be4OXAQ0nOJXk/QJJXJTkDMLhJfA/wCeA88AdV9cSI/Q43NwcLC3D0KCT97cJCv16SdJUMv2qzP/R6vVpaWhr3MCTpwEhytqp6W2nrL4ElqVEGgCQ1ygCQpEYZAJLUKANAkhq1r58CSrICDPll15bcAHy1w+GM06TMZVLmAc5lP5qUecBoczlaVdNbabivA2AUSZa2+ijUfjcpc5mUeYBz2Y8mZR6wd3PxEpAkNcoAkKRGTXIALIx7AB2alLlMyjzAuexHkzIP2KO5TOw9AEnStU3yGYAk6RomOgCS/KfBgvXnknwyyavGPaadSPKrSb48mMuDSV4x7jHtVJK3JnkiyQtJDtwTG0mOJ3kyyYUk9457PKNIcl+SZ5M8Pu6xjCLJkSSfTnJ+8N/WO8Y9pp1K8h1J/jzJ5wdz+Y+72t8kXwJK8neq6luD/V8Abq2qt495WNuW5EeBT1XVlSS/AlBV7xzzsHYkyQ8ALwAfAP59VR2Yf+41yWHgL4A30l/o6FHgbVX1pbEObIeS/HPgOeBDVfXqcY9np5LcCNxYVY8leTlwFvjxg/j3kiTAy6rquSTXAZ8F3lFVj+xGfxN9BvDih//Ay7jGUpT7WVV9crCuAsAj9FdVO5Cq6nxVPTnucezQbcCFqnqqqp4HHgBOjHlMO1ZVnwG+Pu5xjKqqnqmqxwb7f01/3ZHdWXd8l1Xfc4PidYPXrn1uTXQAACQ5leRpYA5417jH04GfAT4+7kE06ibg6TXlZQ7oB82kSjILvBb43HhHsnNJDic5BzwLPFRVuzaXAx8ASR5O8viQ1wmAqjpZVUfoL1h/z3hHu7HN5jFocxK
4Qn8u+9ZW5nJAZUjdgTyrnERJvgv4CPCL687+D5Sq+tuq+kH6Z/q3Jdm1y3ObLgq/3222aP0a/wX4GPDuXRzOjm02jyR3A28G3lD7/MbNNv5ODppl4Mia8s3A5TGNRWsMrpd/BFisqj8a93i6UFXfSPInwHFgV27UH/gzgGtJcmxN8U7gy+MayyiSHAfeCdxZVavjHk/DHgWOJbklyfXAXcDpMY+peYMbpx8EzlfVr417PKNIMv3iU35JvhO4nV383Jr0p4A+Anw//adOLgJvr6r/Nd5RbV+SC8BLga8Nqh45iE8zASR5C/CbwDTwDeBcVf3YeEe1dUnuAH4dOAzcV1WnxjykHUvyYeD19P/lyb8C3l1VHxzroHYgyT8D/hT4Iv3/1wF+qarOjG9UO5PkNcD99P/7OgT8QVW9Z9f6m+QAkCRtbKIvAUmSNmYASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUqP8HKrVTm8232E8AAAAASUVORK5CYII=\n", 140 | "text/plain": [ 141 | "
" 142 | ] 143 | }, 144 | "metadata": {}, 145 | "output_type": "display_data" 146 | } 147 | ], 148 | "source": [ 149 | "test = [-1.5,-1.5]\n", 150 | "print(nbrs.kneighbors([test]))\n", 151 | "plt.plot(toy[:,0],toy[:,1],'ro')\n", 152 | "plt.plot(test[0],test[1],'bo')\n", 153 | "plt.show()" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "## 2. My Unsupervised KNN" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 11, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "from model.knn import CustomNearestNeighbor" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 12, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "c_nn = CustomNearestNeighbor(k=3)\n", 179 | "c_nn.fit(toy)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 18, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "c_dist,c_idx=c_nn.kneighbors(toy)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 16, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "text/plain": [ 199 | "array([[ True, True, True],\n", 200 | " [ True, True, True],\n", 201 | " [ True, True, True],\n", 202 | " [ True, True, True],\n", 203 | " [ True, True, True],\n", 204 | " [ True, True, True]])" 205 | ] 206 | }, 207 | "execution_count": 16, 208 | "metadata": {}, 209 | "output_type": "execute_result" 210 | } 211 | ], 212 | "source": [ 213 | "c_dist == dist" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 19, 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "data": { 223 | "text/plain": [ 224 | "array([[ True, True, True],\n", 225 | " [ True, True, True],\n", 226 | " [ True, True, True],\n", 227 | " [ True, True, True],\n", 228 | " [ True, True, True],\n", 229 | " [ True, True, True]])" 230 | ] 231 | }, 232 | "execution_count": 19, 233 | "metadata": {}, 234 | "output_type": 
"execute_result" 235 | } 236 | ], 237 | "source": [ 238 | "c_idx ==idx" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "Identical results to sklearn nearest neighbors" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "# Supervised nearest neighbors" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "We will use Fashion MNIST https://github.com/zalandoresearch/fashion-mnist for benchmarking" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 20, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "# Fashion MNIST\n", 269 | "def load_mnist(path, kind='train'):\n", 270 | " import os\n", 271 | " import struct\n", 272 | " import gzip\n", 273 | " import numpy as np\n", 274 | "\n", 275 | " \"\"\"Load MNIST data from `path`\"\"\"\n", 276 | " labels_path = os.path.join(path,\n", 277 | " '%s-labels-idx1-ubyte.gz'\n", 278 | " % kind)\n", 279 | " images_path = os.path.join(path,\n", 280 | " '%s-images-idx3-ubyte.gz'\n", 281 | " % kind)\n", 282 | "\n", 283 | " with gzip.open(labels_path, 'rb') as lbpath:\n", 284 | " struct.unpack('>II', lbpath.read(8))\n", 285 | " labels = np.frombuffer(lbpath.read(), dtype=np.uint8)\n", 286 | "\n", 287 | " with gzip.open(images_path, 'rb') as imgpath:\n", 288 | " struct.unpack(\">IIII\", imgpath.read(16))\n", 289 | " images = np.frombuffer(imgpath.read(), dtype=np.uint8).reshape(len(labels), 784)\n", 290 | "\n", 291 | " return images, labels" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 21, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "labels={\n", 301 | " 0:'T-shirt/top',\n", 302 | " 1:'Trouser',\n", 303 | " 2:'Pullover',\n", 304 | " 3:'Dress',\n", 305 | " 4:'Coat',\n", 306 | " 5:'Sandal',\n", 307 | " 6:'Shirt',\n", 308 | " 7:'Sneaker',\n", 309 | " 8:'Bag',\n", 310 | " 9:'Ankle boot',\n", 311 
| "}" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 22, 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "Training set size: (60000, 784)\n", 324 | "Testing set size: (10000, 784)\n" 325 | ] 326 | } 327 | ], 328 | "source": [ 329 | "X_train, y_train = load_mnist('data/fashion', kind='train')\n", 330 | "X_test, y_test = load_mnist('data/fashion', kind='t10k')\n", 331 | "y_train,y_test = y_train.astype(np.int64),y_test.astype(np.int64)\n", 332 | "print('Training set size: {}'.format(X_train.shape))\n", 333 | "print('Testing set size: {}'.format(X_test.shape))" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 23, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "#only take 5000 images for simplification\n", 343 | "np.random.seed(42)\n", 344 | "train_idx =np.random.permutation(5000)\n", 345 | "test_idx = np.random.permutation(1000)\n", 346 | "X_train = X_train[train_idx]\n", 347 | "y_train = y_train[train_idx]\n", 348 | "X_test = X_test[test_idx]\n", 349 | "y_test = y_test[test_idx]" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "Preprocess and split dataset" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 24, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "name": "stderr", 366 | "output_type": "stream", 367 | "text": [ 368 | "C:\\Users\\qtran\\AppData\\Local\\Continuum\\Miniconda3\\envs\\fastai-cpu\\lib\\site-packages\\sklearn\\utils\\validation.py:475: DataConversionWarning: Data with input dtype uint8 was converted to float64 by StandardScaler.\n", 369 | " warnings.warn(msg, DataConversionWarning)\n" 370 | ] 371 | } 372 | ], 373 | "source": [ 374 | "# from sklearn.model_selection import train_test_split\n", 375 | "from sklearn.metrics import accuracy_score\n", 376 | "from sklearn.preprocessing import StandardScaler\n", 377 
| "\n", 378 | "s = StandardScaler()\n", 379 | "X_train = s.fit_transform(X_train)\n", 380 | "X_test = s.transform(X_test)\n", 381 | "y_train = np.array(y_train)\n", 382 | "y_test = np.array(y_test)" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": {}, 388 | "source": [ 389 | "## 1. Sklearn KNeighborsClassifier" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 27, 395 | "metadata": {}, 396 | "outputs": [ 397 | { 398 | "data": { 399 | "text/plain": [ 400 | "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n", 401 | " metric_params=None, n_jobs=1, n_neighbors=10, p=2,\n", 402 | " weights='uniform')" 403 | ] 404 | }, 405 | "execution_count": 27, 406 | "metadata": {}, 407 | "output_type": "execute_result" 408 | } 409 | ], 410 | "source": [ 411 | "from sklearn.neighbors import KNeighborsClassifier\n", 412 | "\n", 413 | "nn = KNeighborsClassifier(n_neighbors=10)\n", 414 | "nn.fit(X_train,y_train)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 28, 420 | "metadata": {}, 421 | "outputs": [ 422 | { 423 | "data": { 424 | "text/plain": [ 425 | "0.807" 426 | ] 427 | }, 428 | "execution_count": 28, 429 | "metadata": {}, 430 | "output_type": "execute_result" 431 | } 432 | ], 433 | "source": [ 434 | "pred = nn.predict(X_test)\n", 435 | "accuracy(y_test,pred)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": {}, 441 | "source": [ 442 | "## 2. 
My supervised KNN" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 29, 448 | "metadata": {}, 449 | "outputs": [ 450 | { 451 | "name": "stdout", 452 | "output_type": "stream", 453 | "text": [ 454 | "(5000, 784)\n", 455 | "(1000, 784)\n" 456 | ] 457 | } 458 | ], 459 | "source": [ 460 | "print(X_train.shape)\n", 461 | "print(X_test.shape)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 32, 467 | "metadata": {}, 468 | "outputs": [], 469 | "source": [ 470 | "c_nn = CustomNearestNeighbor(k=10)\n", 471 | "c_nn.fit(X_train,y_train)" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "### uniform weights among class" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 33, 484 | "metadata": { 485 | "scrolled": false 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "y_pred,counter,class_sorted=c_nn.predict_classification(X_test,weighted=False)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 34, 495 | "metadata": {}, 496 | "outputs": [ 497 | { 498 | "data": { 499 | "text/plain": [ 500 | "0.804" 501 | ] 502 | }, 503 | "execution_count": 34, 504 | "metadata": {}, 505 | "output_type": "execute_result" 506 | } 507 | ], 508 | "source": [ 509 | "accuracy_score(y_test,y_pred)" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "metadata": {}, 515 | "source": [ 516 | "### Weight points by the inverse of distance" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": 35, 522 | "metadata": {}, 523 | "outputs": [], 524 | "source": [ 525 | "y_pred,weighted_counter,class_sorted=c_nn.predict_classification(X_test,weighted=True)" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 36, 531 | "metadata": {}, 532 | "outputs": [ 533 | { 534 | "data": { 535 | "text/plain": [ 536 | "0.809" 537 | ] 538 | }, 539 | "execution_count": 36, 540 | "metadata": {}, 541 | 
"output_type": "execute_result" 542 | } 543 | ], 544 | "source": [ 545 | "accuracy_score(y_test,y_pred)" 546 | ] 547 | }, 548 | { 549 | "cell_type": "markdown", 550 | "metadata": {}, 551 | "source": [ 552 | "Not bad! With weighted inverse-Euclidean distance, there is small boost in accuracy on test set" 553 | ] 554 | }, 555 | { 556 | "cell_type": "markdown", 557 | "metadata": {}, 558 | "source": [ 559 | "### Evaluate results" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 37, 565 | "metadata": {}, 566 | "outputs": [], 567 | "source": [ 568 | "from matplotlib import pyplot as plt,cm\n", 569 | "def evaluate(idx):\n", 570 | " img = X_test[idx].reshape([28,28])\n", 571 | " plt.imshow(img,cmap=cm.binary)\n", 572 | " print(f'Weights (in order): {sorted(weighted_counter[idx],reverse=True)}')\n", 573 | " print(f'Predictions (in order): {[labels[i] for i in class_sorted[idx]]}')\n", 574 | " print('Actual: ' + labels[y_test[0]])\n" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 38, 580 | "metadata": {}, 581 | "outputs": [ 582 | { 583 | "name": "stdout", 584 | "output_type": "stream", 585 | "text": [ 586 | "Weights (in order): [0.3040748811530332, 0.28658216317523766, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n", 587 | "Predictions (in order): ['Coat', 'Shirt', 'Ankle boot', 'Bag', 'Sneaker', 'Sandal', 'Dress', 'Pullover', 'Trouser', 'T-shirt/top']\n", 588 | "Actual: Coat\n" 589 | ] 590 | }, 591 | { 592 | "data": { 593 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAE8tJREFUeJzt3X+IleeVB/DvyUQdnRl1RuNvd9KtISQEYpdBlmRZsjRp0k3BNNDQIRQXSm2gwhYK2SCE5p9CWGy7ISwFu5EaqLGFNhtDQtsQGtzCIhklNBo3aTBaJzNxHJ1xRmeiUU//mNfsaOY958593vu+157vB8SZe+573+e+d87cmTnPcx5RVRBRPDdUPQAiqgaTnygoJj9RUEx+oqCY/ERBMfmJgmLyEwXF5CcKislPFNSNZZ6ss7NTV61aVeYpCyMiuTFvluQNN9jfY63HruX4ixcv5sY++eQT81iPd+65c+ea8ZaWltyYd928+OXLl5PijVTVzNmBgQGMjIzYX1CZpOQXkQcAPAOgBcB/qerT1v1XrVqFXbt2WY+XMpakuMdKAu+LbP78+WbcS6AFCxaY8eHh4dzYRx99ZB5rfeMAgIULF5rxNWvW1H28d+7z58+b8cnJSTM+MTGRG/O+Hhr9jck6PuUbR29vb833rfvHfhFpAfCfAL4M4HYAvSJye72PR0TlSvmdfwOA91X1iKpeALAbwMZihkVEjZaS/KsBHJ/2eX9221VEZLOI9IlI38jISMLpiKhIKck/0y9Nn/llRVW3q2qPqvZ0dnYmnI6IipSS/P0A1k77fA2AgbThEFFZUpL/TQC3iMjnRGQugK8D2FPMsIio0eou9anqRRHZAuC3mCr17VDVQ95xjSrnpZbyrHo0ANx4Y/6lmjdvnnmsV8o7fvy4Gd+/f78Z//jjj3NjbW1t5rHedbt06ZIZP3jwoBm3SqTeNffi999/vxmfM2dObswrI3plSO+6pF5XS1FzCJLq/Kr6KoBXCxkJEZWK03uJgmLyEwXF5CcKislPFBSTnygoJj9RUKWu56+SVzO2asIA0NraWvex27ZtM+PWHAIAuPPOO8340NBQbiy1Xu09N29Z7blz53Jj3hwE7zV75ZVXzPiDDz5oxi2pS8RT6vgpx84G3/mJgmLyEwXF5CcKislPFBSTnygoJj9RUE1V6kspr3gtpr1ymrfs1ip5ecd6br31VjPulbysTrErVqwwj+3o6DDjo6OjZvzDDz8049bYvOe1ZMkSM+6VMa2l1qnt0r0lwSm8JbtFlQL5zk8UFJOfKCgmP1FQTH6ioJj8REEx+YmCYvITBVV6nT+l/XbKTrkprbkB4IMPPsiN7d271zx20aJFZtxqvV1L3Br7sWPHzGO9HYS96zY+Pm7G29vbc2PeNR8bG0s695NPPpkb8+ZWPPzww2bcW+rsbY2esnV5UVuP852fKCgmP1FQTH6ioJj8REEx+YmCYvITBcXkJwoqqc4vIkcBjAO4BOCiqvbUcEzK+XJj3pp6r1798ssvm/EDBw7kxrx151atGwCGh4fNuPfcrHkAixcvNo+9cOGCGT9z5owZ99pvL1iwoO5ze2vmvXkC3d3dubHBwUHz2GeffdaMP/bYY2bc27bdeu4p/Rtmk19FTPL5J1W1v3qJqOnwx36ioFKTXwH8TkT2i8jmIgZEROVI/bH/blUdEJFlAF4Tkf9T1asmumffFDYDwMqVKxNPR0RFSXrnV9WB7P8hAC8C2DDDfbarao+q9nR2dqacjogKVHfyi0ibiHRc+RjAlwAcLGpgRNRYKT/2LwfwYlZauBHALlX9TSGjIqKGqzv5VfUIAHvv6FnyeqWn9O33aqf79u0z41b/e2+9vjc2r3e+x6qlW1tkA3492uv77z23iYmJ3Jg3/8H7NdGbB2DFvT4GJ06cMOO7du0y448++qgZt9bse70AvGteK5b6iIJi8hMFxeQnCorJTxQUk58oKCY
/UVDX1RbdVonDK3+88cYb9QzpU1ZpyCvNeCUtrxWz1ybaW9qacuy6devM+JEjR8z46tWrc2PeFttey3KvTGmV07xze2VGr214yrJcbwvulGXx0/GdnygoJj9RUEx+oqCY/ERBMfmJgmLyEwXF5CcKqtQ6v4iY9XivfpmypPfdd9814ym19JT5CbUc77Hq2d4cA6+m7G2T7bXutq6rd829WrnHanluLTUG/NdkdHTUjI+MjJhxaxm497xZ5yeiJEx+oqCY/ERBMfmJgmLyEwXF5CcKislPFNR1tZ7finvHvvfee2bca1Ftrdn31pV7rDo9kHZdvLF5j+3Vq70twC1e+2xvHsDk5KQZt66rtz24V2v3tk3v7+8349a27l6vgZQ8mI7v/ERBMfmJgmLyEwXF5CcKislPFBSTnygoJj9RUG6dX0R2APgKgCFVvSO7rQvALwDcDOAogEdU1S4I///j1TvWJAMDA2a8u7vbjFvr+b018R6vb783DyClR4I3D8CrOXtaW1tzY16t3JOyVbX3mnlj8+KnT58249Y8gmZaz/8zAA9cc9sTAF5X1VsAvJ59TkTXETf5VXUvgGu/jW0EsDP7eCeAhwoeFxE1WL0/Ny1X1UEAyP5fVtyQiKgMDf+Dn4hsFpE+Eenzfg8iovLUm/wnRGQlAGT/D+XdUVW3q2qPqvZ0dXXVeToiKlq9yb8HwKbs400AXipmOERUFjf5ReQFAP8L4FYR6ReRbwJ4GsB9IvInAPdlnxPRdcSt86tqb07oiwWPJal+ee7cOTPu7fXusWqvZ86cMY/t6Oio+7FriafsKeDVuxcsWGDGPdbYvJ7/3ti9NflW/Pz58+axXq8Bb46BN6/EOt57vVPmN1z1OIU8ChFdd5j8REEx+YmCYvITBcXkJwqKyU8UVOmtu1PaDlvxd955xzw2tXxite5OXdLrndsql3lx75p6y4VT49bYvKXMqcbHx3Nj3lJl73l5JdChodxJrwDStptPWcJ91ePUfE8i+qvC5CcKislPFBSTnygoJj9RUEx+oqCY/ERBNdUW3R6rhum1CLNaSAN+C2urVbM3h8CrKS9atMiMe3XflDbQ3tJVq1YO+NfNWrabet2819Saf5GyDXYt556YmDDjKecus3U3Ef0VYvITBcXkJwqKyU8UFJOfKCgmP1FQTH6ioEqv86e0LLbio6Oj5rFe+2yvjbRVM/bWds+ZM8eMey2ovZqyJbV1t1fHt/ocAHY93Ztj4D2295pbr4u3Xt97Tbwtur1eBdZ1914ztu4moiRMfqKgmPxEQTH5iYJi8hMFxeQnCorJTxSUW+cXkR0AvgJgSFXvyG57CsC3AJzM7rZVVV+t4bHMWr3Xn9461tuie/HixWbcq9taNeXu7m7zWI9XU/bmCaT0cfdq6am8uRsWb+zetutWn4Sbbrop6bG9r1Wvzn/q1Knc2JIlS8xjy1zP/zMAD8xw+49VdX32z018ImoubvKr6l4AdpscIrrupPzOv0VE/igiO0Sks7AREVEp6k3+nwD4PID1AAYB/DDvjiKyWUT6RKTP67NHROWpK/lV9YSqXlLVywB+CmCDcd/tqtqjqj1dXV31jpOIClZX8ovIymmffhXAwWKGQ0RlqaXU9wKAewAsFZF+AN8HcI+IrAegAI4C+HYDx0hEDeAmv6r2znDzc/WczKvzezVhq57t1WXb29vNuFfnt9ZfezVjb9355OSkGffWnls1ZW8OgdfnwJs/4fUasF4zb47B2bNnk+LLly/PjXl7JYyNjZlxr8+BF/f2DbBYdf7ZzAHgDD+ioJj8REEx+YmCYvITBcXkJwqKyU8UVOmtu1PKFFbZKHXLZe9469xWSQnwy2Xe8k9vSa9VCkxtA+0tXfVaf6eUdr2W6N51s8q3a9asMY89dOiQGfdKpN51sUrTXmm3KHznJwqKyU8UFJOfKCgmP1FQTH6ioJj8REEx+YmCKrXOr6pubdZi1ay92mjK0lPv8b3lwh6
v3u0ty7V4cwS8pdApS08Be9mtVwtP2f7be3yvlbv39eSN/fz582Y85TW1xjabOQJ85ycKislPFBSTnygoJj9RUEx+oqCY/ERBMfmJgip9Pb9Vh/TmAFjx1C2VPda4vdbdXk3Xq82mbHOdyqtXt7W1mXHruaX2CvDmAVjXfdmyZeax3tei95p48ysmJibqPnfKXJnp+M5PFBSTnygoJj9RUEx+oqCY/ERBMfmJgmLyEwXlFr9FZC2A5wGsAHAZwHZVfUZEugD8AsDNAI4CeERVR1IG49W7U/qZz2br4tmee+HCheaxqXsKeHXdlP0MPF6fAy9eZX96q5buref35hh4czescwPA8PBwbsx7zcqs818E8D1VvQ3A3wP4jojcDuAJAK+r6i0AXs8+J6LrhJv8qjqoqgeyj8cBHAawGsBGADuzu+0E8FCjBklExZvV7/wicjOALwDYB2C5qg4CU98gANjzJYmoqdSc/CLSDuBXAL6rqmOzOG6ziPSJSN/p06frGSMRNUBNyS8iczCV+D9X1V9nN58QkZVZfCWAoZmOVdXtqtqjqj1dXV1FjJmICuAmv0z9Kfo5AIdV9UfTQnsAbMo+3gTgpeKHR0SNUss617sBfAPA2yLyVnbbVgBPA/iliHwTwJ8BfK2WE6a0HbZKHKklrZTyirUVNJC+pDelzbO3bDa1rOSNzVp265U4veviLekdHR3NjXnP24t75/ZKhSdPnsyNpSzpnU351E1+Vf0DgLxX6Ys1n4mImgpn+BEFxeQnCorJTxQUk58oKCY/UVBMfqKgSm/dbdUovfqmVRf2WnfPnz/fjHtLMK26bmqL6dRls1abaG9sXl343LlzZnzevHlmPGVb9ZR5H4A9B8Gr43sty0dGklavJ21dztbdRJSEyU8UFJOfKCgmP1FQTH6ioJj8REEx+YmCKrXOr6pJW3RbNeOUNe+AX9e1Ht+r43u1cm8759bWVjNujc2rlXt1ei/uzY+wXlPveafOj7Be09T+D16d33vNrdbdKXX+2azn5zs/UVBMfqKgmPxEQTH5iYJi8hMFxeQnCorJTxRUU63n9+qbVp3f6tEOAEuXLjXjXk3ZqqV74/b6+nv9673Ht3rze337W1pazPjYmL0zm9cnwZqj4M3r8J63F7fWzHvnXrbM3nry2LFjZtybR2BtXcf1/ETUUEx+oqCY/ERBMfmJgmLyEwXF5CcKislPFJRb5xeRtQCeB7ACwGUA21X1GRF5CsC3AFzZaHyrqr7qPZ5Vo0ypd1v7nQN+Pdrr+2+tk/Zqut7YUnrfA/a6eG/Nu/fY3nPz+iiMj4/nxrx69uTkpBk/deqUGT9z5kzd5161apUZ37dvnxn35p0cOXIkN/b444+bxxZV569lks9FAN9T1QMi0gFgv4i8lsV+rKrbChkJEZXKTX5VHQQwmH08LiKHAaxu9MCIqLFm9Tu/iNwM4AsArvzMs0VE/igiO0SkM+eYzSLSJyJ91pRGIipXzckvIu0AfgXgu6o6BuAnAD4PYD2mfjL44UzHqep2Ve1R1Z6urq4ChkxERagp+UVkDqYS/+eq+msAUNUTqnpJVS8D+CmADY0bJhEVzU1+mfpz8HMADqvqj6bdvnLa3b4K4GDxwyOiRqnlr/13A/gGgLdF5K3stq0AekVkPQAFcBTAt70HUtWkMoVVljpw4IB5rLe0tb293Yxb5TSvHHbbbbeZ8cHBQTPuLbu1zu8tVfbKaV4pr62tzYxbr7f32B0dHWZ83bp1Znz9+vW5Ma/Ud9ddd5nxQ4cOmfH+/n4zvmLFityYlyNFte6u5a/9fwAwU9a5NX0ial6c4UcUFJOfKCgmP1FQTH6ioJj8REEx+YmCaqrW3bOpUV5r2zZ7ceHu3bvNeHd3txnfuHFjbsyrld97771m3Nuq2mv97c1hsFjtrWuJe0uhrRbY3lLmlHbqgD3/wRu3twR8y5YtZjxl7N6xKXkyHd/5iYJi8hMFxeQnCorJTxQUk58
oKCY/UVBMfqKgpKiaYU0nEzkJYPrexksBDJc2gNlp1rE167gAjq1eRY6tW1VvquWOpSb/Z04u0qeqPZUNwNCsY2vWcQEcW72qGht/7CcKislPFFTVyb+94vNbmnVszTougGOrVyVjq/R3fiKqTtXv/ERUkUqSX0QeEJF3ReR9EXmiijHkEZGjIvK2iLwlIn0Vj2WHiAyJyMFpt3WJyGsi8qfs/xm3SatobE+JyIfZtXtLRP65orGtFZHfi8hhETkkIv+a3V7ptTPGVcl1K/3HfhFpAfAegPsA9AN4E0Cvqr5T6kByiMhRAD2qWnlNWET+EcBZAM+r6h3Zbf8O4LSqPp194+xU1X9rkrE9BeBs1Ts3ZxvKrJy+szSAhwD8Cyq8dsa4HkEF162Kd/4NAN5X1SOqegHAbgD5nTICU9W9AK7d3XQjgJ3Zxzsx9cVTupyxNQVVHVTVA9nH4wCu7Cxd6bUzxlWJKpJ/NYDj0z7vR3Nt+a0Afici+0Vkc9WDmcHybNv0K9un57fKqYa7c3OZrtlZummuXT07XhetiuSfafefZio53K2qfwfgywC+k/14S7Wpaefmssyws3RTqHfH66JVkfz9ANZO+3wNgIEKxjEjVR3I/h8C8CKab/fhE1c2Sc3+H6p4PJ9qpp2bZ9pZGk1w7Zppx+sqkv9NALeIyOdEZC6ArwPYU8E4PkNE2rI/xEBE2gB8Cc23+/AeAJuyjzcBeKnCsVylWXZuzttZGhVfu2bb8bqSST5ZKeM/ALQA2KGqPyh9EDMQkb/F1Ls9MNXZeFeVYxORFwDcg6lVXycAfB/AfwP4JYC/AfBnAF9T1dL/8JYztnsw9aPrpzs3X/kdu+Sx/QOA/wHwNoAr7aK3Yur368qunTGuXlRw3TjDjygozvAjCorJTxQUk58oKCY/UVBMfqKgmPxEQTH5iYJi8hMF9RfIvWniLmVqwwAAAABJRU5ErkJggg==\n", 594 | "text/plain": [ 595 | "
" 596 | ] 597 | }, 598 | "metadata": {}, 599 | "output_type": "display_data" 600 | } 601 | ], 602 | "source": [ 603 | "# Correct prediction\n", 604 | "evaluate(0)" 605 | ] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "metadata": {}, 610 | "source": [ 611 | "Based on weighted_counter, my KNN classifier results in a close weight between class 4 (Coat) with weight .304 and 6 (Shirt) with weight .286." 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 39, 617 | "metadata": {}, 618 | "outputs": [ 619 | { 620 | "name": "stdout", 621 | "output_type": "stream", 622 | "text": [ 623 | "Weights (in order): [0.3754328103184546, 0.14807884211858424, 0.07343824003610218, 0.07208342067277425, 0.0704589146423518, 0.0, 0.0, 0.0, 0.0, 0.0]\n", 624 | "Predictions (in order): ['T-shirt/top', 'Coat', 'Pullover', 'Dress', 'Shirt', 'Ankle boot', 'Bag', 'Sneaker', 'Sandal', 'Trouser']\n", 625 | "Actual: Coat\n" 626 | ] 627 | }, 628 | { 629 | "data": { 630 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAE3VJREFUeJzt3V9sVPeVB/DvCX+NMYZgDObfpiERSkhUunJgJVarRFWqUFUiPDQqDxUrVaUPjbSV+rARL81LpWjVf3lYVaIbVCK1aZFaNjxEu42ijbJIqMQkiCTLsiBCwMaBJGAYmxhjOPvgS2WI7znj+c29d+j5fiRke8787v3N9Rxmxuf3R1QVRBTPPVV3gIiqweQnCorJTxQUk58oKCY/UVBMfqKgmPxEQTH5iYJi8hMFNbPMk82dO1c7OjoKObaIFHLcZpzbi99zj/1/8IwZMxo+ftHXxTv+8PBwbuzatWtm287OTjN+48YNM26NXvVGthY98rWo49dqNYyOjtb1S09KfhF5CsCLAGYA+DdVfcG6f0dHB7Zu3Wodr+G+eAmUmgTW8WfOtC/j7NmzzXhbW5sZX7BggRmfM2dOw+dOfRJ61/3gwYO5sVOnTpltN2/ebMZrtZoZ//zzz3Nj4+PjZtubN2+ace+6ee2teMp/TPv27TPbTtbw234RmQHgXwFsBvAwgG0i8nCjxyOicqV85t8A4KSqnlLVMQC/A7ClOd0ioqKlJP8KAGcn/dyf3XYbEdkhIn0i0jc6OppwOiJqppTkn+pD9Bc+jKjqLlXtVdXeuXPnJpyOiJopJfn7Aaya9PNKAOfSukNEZUlJ/rcBPCgiXxKR2QC+BWB/c7pFREVruNSnquMi8iyA/8REqW+3qn7gtUspuRVZs/bKddZHFq+c1t7ebsavX79uxg8dOmTGe3t7c2Nnz57NjQF+yWvRokVm3KrjA0B3d3duzPt9fvCB/XRat26dGbd+Z1euXD
HbppYCvTEIltQyYr2S6vyq+hqA15rSEyIqFYf3EgXF5CcKislPFBSTnygoJj9RUEx+oqBKnc8P2DXMlHnxXluvju9Nq7Xi3nx779gDAwNm3BsW/dFHH5lxS1dXlxn3as5e36zHPn/+fLPtuXP2gNEjR46Y8cceeyw35k1F9uaheHFv7Ib1fPXGGFh9n85YGL7yEwXF5CcKislPFBSTnygoJj9RUEx+oqBKL/WllOusuFdumzVrlhm3VsD14illQgA4c+aMGV+zZo0Zt0o/3nRj77p500e96zY2NpYbu3Tpktl24cKFZnxoaMiMW1OpvVKfp8ilv722XimwXnzlJwqKyU8UFJOfKCgmP1FQTH6ioJj8REEx+YmCuqvq/FZt1qtXe/VuL27Vs71zv/nmm2Z82bJlZtybNmv1zWvr1YxTa87Wdf3www/NtmvXrjXj3mOzdun1xid4j9tbmrvIpbtTjj0ZX/mJgmLyEwXF5CcKislPFBSTnygoJj9RUEx+oqCS6vwichpADcANAOOqmr9X9MT9C6vze/OzvaW7vfn+VnuvrTfvvKenp+FzA/Z187ai9sY3eNfVW6La2uLbq2d7zwdv2fFarZYb6+joMNt6tXTvuqWMn/DWUPDGldSrGYN8nlDVT5twHCIqEd/2EwWVmvwK4E8iclhEdjSjQ0RUjtS3/ZtU9ZyIdAN4XUT+V1XfmnyH7D+FHYC/PRMRlSfplV9Vz2VfLwDYB2DDFPfZpaq9qtrrLWRJROVpOPlFpF1EOm59D+BrAN5vVseIqFgpb/uXAtiXlWNmAvitqv5HU3pFRIVrOPlV9RSAL0+3nVW79WrKVtyrhXu1Ua+9de7h4WGzbXd3txn3xgl48ZR13L2ashdPWb/+5MmTZtsNG77wKfI23hgDK+7V6b3H7V1zr28pdX7r3Nyim4hcTH6ioJj8REEx+YmCYvITBcXkJwqq9KW7LV6pL6VM6JX6UkqB1hLRADBv3rykuDe91HrsqUtUe8tje1OGrW22ly5darb1eOU6a0qvVz71rnnqFHGrnOeVEadTzrPwlZ8oKCY/UVBMfqKgmPxEQTH5iYJi8hMFxeQnCqr0Or9Vk/bqlyl1/pRje/GxsTGz7aef2osbezXhJUuWmHFL6pRdb+m1q1evNnz8gYEBs633O2lvbzfjIyMjubHUadLeGAPvOWFdF286sPdcrxdf+YmCYvITBcXkJwqKyU8UFJOfKCgmP1FQTH6ioEqv86ds0W3NuU+dz58yTsCb0+7NiU/Zghuw5557j9ur81tz4gFgxYoVZtyqdx8/ftxse/nyZTPubbNt1fJTt+D2avEp8/m931lKDk3GV36ioJj8REEx+YmCYvITBcXkJwqKyU8UFJOfKCi3zi8iuwF8A8AFVX0ku+1eAL8HcB+A0wCeUdVLqZ1JnXNfVFuPV/P1auHeGIOUdfu9mrFXz/biXV1dZry/v7/htt6eAt6ce6u9N98+dWxGyj4QqceuVz2v/L8G8NQdtz0H4A1VfRDAG9nPRHQXcZNfVd8CcPGOm7cA2JN9vwfA003uFxEVrNHP/EtVdRAAsq/dzesSEZWh8D/4icgOEekTkT5vTzsiKk+jyX9eRHoAIPt6Ie+OqrpLVXtVtbetra3B0xFRszWa/PsBbM++3w7g1eZ0h4jK4ia/iLwC4CCAtSLSLyLfAfACgCdF5ASAJ7Ofiegu4tb5VXVbTuir0z2ZiCSt2+/VfYuUMmfeqyl7dV1rj3uvvTev3Lvm3rr8nZ2dZnxwcDA39sADD5htU8dmpNTDU2vtKetHeOdOyaHbjlP3PYnorwqTnygoJj9RUEx+oqCY/ERBMfmJgmqppbtT2hY5ZRcAzp49mxs7evSo2XbTpk1m3JvS65XrrOmn3pDq0dFRMz5v3jwz7l33BQsW5Ma8EubixYvNuDeVes6cObkxb5p0kaU8L55S6psOvvITBcXkJwqKyU8UFJOfKCgmP1FQTH6ioJj8REGVXu
e3pCzd7bVN2YIbAEZGRnJj3tLc3jbYXh3f67s1Zdir03s1ZY835Xf16tW5seXLl5ttvSm/hw8fNuPWY7927ZrZNnUZ+SKXmbfGGHBKLxG5mPxEQTH5iYJi8hMFxeQnCorJTxQUk58oqNLr/FbN2qt3pyxR7W01bc39Buzaqlfn92rp3pLkXu3WOn57e3uh5/aWJbe24ba27waAtWvXmnGvzm/x5vN7iqzzp6wFMK3zNOUoRHTXYfITBcXkJwqKyU8UFJOfKCgmP1FQTH6ioNw6v4jsBvANABdU9ZHstucBfBfAJ9nddqrqa3Ucy6xJe7V4q9aeMkagnvaW8fHxpGNb6+7XE+/o6MiNefVsr87vPTaP1ffjx4+bbRctWmTGveeL9di98Qme1Dp/yjbbzdqjop5X/l8DeGqK23+uquuzf27iE1FrcZNfVd8CcLGEvhBRiVI+8z8rIkdFZLeI2O/PiKjlNJr8vwSwBsB6AIMAfpp3RxHZISJ9ItLn7RtHROVpKPlV9byq3lDVmwB+BWCDcd9dqtqrqr1tbW2N9pOImqyh5BeRnkk/bgXwfnO6Q0RlqafU9wqAxwF0iUg/gB8BeFxE1gNQAKcBfK/APhJRAdzkV9VtU9z8UiMn8+r8KXuie229WrvX3up3ap3fW1vfW2vAqmen7gngPbbr16+bcatvXh2/p6fHjKf03dtLIVXKnPsixwjcdpy670lEf1WY/ERBMfmJgmLyEwXF5CcKislPFFRLLd3tlUdSSn1e3CuRWKMTL1++bLb1HpdXLkuZuupN2fWui9d3r5RYq9VyY0888YTZ1ttG22NtH+4tae6VAr3r6sWt4xddhryFr/xEQTH5iYJi8hMFxeQnCorJTxQUk58oKCY/UVCl1/lTaphWLd6rR6dMk6ynvSVlujDg14wtVq0b8H8f3hiEkZERM/7xxx/nxu6//36z7fDwsBn3WMvGeUuae/HUcQApba34dM7LV36ioJj8REEx+YmCYvITBcXkJwqKyU8UFJOfKKiWqvN7tdOUtinzq732o6OjZluvXu2NMfDaW0t/e4/LGwfg8ZYVt46/YMECs603hsC7LtbS3d6S5Kl1/pR46hiEevGVnygoJj9RUEx+oqCY/ERBMfmJgmLyEwXF5CcKyq3zi8gqAC8DWAbgJoBdqvqiiNwL4PcA7gNwGsAzqnrJOpaqJs1FTqnzp9ZtrXX7vbbeFtxee2/9euuxebXyoaEhM97R0WHGT5w4Yca7urpyY/Pnzzfben1LqfN74xu8dQy8cQIpW5unnLvZ8/nHAfxQVR8C8HcAvi8iDwN4DsAbqvoggDeyn4noLuEmv6oOquo72fc1AMcArACwBcCe7G57ADxdVCeJqPmm9ZlfRO4D8BUAfwawVFUHgYn/IAB0N7tzRFScupNfROYD+AOAH6jqlWm02yEifSLSZ62pRkTlqiv5RWQWJhL/N6r6x+zm8yLSk8V7AFyYqq2q7lLVXlXttf5oRkTlcpNfJpatfQnAMVX92aTQfgDbs++3A3i1+d0joqLUM6V3E4BvA3hPRI5kt+0E8AKAvSLyHQBnAHyznhNapYgqp/R6pUBr+W2vJOVt4b1y5Uoz7k19tbbw9pYN98pKXpnx4MGDZnzjxo25sUcffdRse+WK/enSW0794sWLubGFCxeabb3rUmQpMLUsXS83+VX1AIC8q/zVpvSCiErHEX5EQTH5iYJi8hMFxeQnCorJTxQUk58oqNKX7rZqmF5t1KrrerVR79hePdxaXtur0+/du9eMr1692oy/++67Znzx4sW5MW+MQGdnpxn3rotnbGwsN+YtWe4NB581a1bD5547d27DbYH055s1TiBlWXFu0U1ELiY/UVBMfqKgmPxEQTH5iYJi8hMFxeQnCqrUOr+3dHfK8ttF1/mtJazXrVtntvWW7vZ4NWmLNacdAAYGBsz4zJn2U8RaSwCw5+R7dX5v6/PPPvvMjNdqtdzY+fPnzbbLly8346
nPN6u9t1YA6/xElITJTxQUk58oKCY/UVBMfqKgmPxEQTH5iYJqqfn8Xn0zZT5/al3Wqp96a9unbLGdGvfq8N6xvTn1c+bMMePW79Qb1+E9Hx566CEzbu0Q1dPTY7Yt+vmUMmalWev285WfKCgmP1FQTH6ioJj8REEx+YmCYvITBcXkJwrKrfOLyCoALwNYBuAmgF2q+qKIPA/guwA+ye66U1Vfs46lqmYN09tvPWVP8yLr/AcOHDDbXrp0yYxb6+4D/jgBbx97i1ULB/w59Sk155S17QFgyZIlZry9vT03VvTzJSVe1rr99QzyGQfwQ1V9R0Q6ABwWkdez2M9V9Sd1n42IWoab/Ko6CGAw+74mIscArCi6Y0RUrGl95heR+wB8BcCfs5ueFZGjIrJbRBbltNkhIn0i0ucNFSWi8tSd/CIyH8AfAPxAVa8A+CWANQDWY+KdwU+naqequ1S1V1V7vc+XRFSeupJfRGZhIvF/o6p/BABVPa+qN1T1JoBfAdhQXDeJqNnc5JeJP8G/BOCYqv5s0u2Tp0VtBfB+87tHREWp56/9mwB8G8B7InIku20ngG0ish6AAjgN4Hv1nNAroVispZ5TSy9emdEqaXnbXG/ZssWMb9y40Yx718wq7wwNDZltvaW7vfbe33GsUqFXwvS2yfbOffXqVTNu8ZZyL7JUmHrsetXz1/4DAKbKDLOmT0StjSP8iIJi8hMFxeQnCorJTxQUk58oKCY/UVClb9Ft1cunMx3xTl6d3tsO2lvi2jr+ypUrzbZezdirV3vDoq2+e9uDd3V1mXGvpuw9tkOHDuXGvKnI3jiAkZERM+6N7bB4j8ubypwyhsHrd7NyiK/8REEx+YmCYvITBcXkJwqKyU8UFJOfKCgmP1FQklJbn/bJRD4B8NGkm7oAfFpaB6anVfvWqv0C2LdGNbNvf6Oq9prmmVKT/wsnF+lT1d7KOmBo1b61ar8A9q1RVfWNb/uJgmLyEwVVdfLvqvj8llbtW6v2C2DfGlVJ3yr9zE9E1an6lZ+IKlJJ8ovIUyJyXEROishzVfQhj4icFpH3ROSIiPRV3JfdInJBRN6fdNu9IvK6iJzIvk65TVpFfXteRAaya3dERL5eUd9Wich/icgxEflARP4pu73Sa2f0q5LrVvrbfhGZAeD/ADwJoB/A2wC2qer/lNqRHCJyGkCvqlZeExaRfwAwDOBlVX0ku+1fAFxU1Rey/zgXqeo/t0jfngcwXPXOzdmGMj2Td5YG8DSAf0SF187o1zOo4LpV8cq/AcBJVT2lqmMAfgfA3tUiKFV9C8DFO27eAmBP9v0eTDx5SpfTt5agqoOq+k72fQ3ArZ2lK712Rr8qUUXyrwBwdtLP/WitLb8VwJ9E5LCI7Ki6M1NYmm2bfmv79O6K+3Mnd+fmMt2xs3TLXLtGdrxutiqSf6r1sFqp5LBJVf8WwGYA38/e3lJ96tq5uSxT7CzdEhrd8brZqkj+fgCrJv28EsC5CvoxJVU9l329AGAfWm/34fO3NknNvl6ouD9/0Uo7N0+1szRa4Nq10o7XVST/2wAeFJEvichsAN8CsL+CfnyBiLRnf4iBiLQD+Bpab/fh/QC2Z99vB/BqhX25Tavs3Jy3szQqvnattuN1JYN8slLGLwDMALBbVX9ceiemICL3Y+LVHphY2fi3VfZNRF4B8DgmZn2dB/AjAP8OYC+A1QDOAPimqpb+h7ecvj2Oibeuf9m5+dZn7JL79vcA/hvAewBuLXW7ExOfryu7dka/tqGC68YRfkRBcYQfUVBMfqKgmPxEQTH5iYJi8hMFxeQnCorJTxQUk58oqP8HACBxv22h61sAAAAASUVORK5CYII=\n", 631 | "text/plain": [ 632 | "
" 633 | ] 634 | }, 635 | "metadata": {}, 636 | "output_type": "display_data" 637 | } 638 | ], 639 | "source": [ 640 | "# Incorrect prediction\n", 641 | "evaluate(6)" 642 | ] 643 | } 644 | ], 645 | "metadata": { 646 | "kernelspec": { 647 | "display_name": "Python 3", 648 | "language": "python", 649 | "name": "python3" 650 | }, 651 | "language_info": { 652 | "codemirror_mode": { 653 | "name": "ipython", 654 | "version": 3 655 | }, 656 | "file_extension": ".py", 657 | "mimetype": "text/x-python", 658 | "name": "python", 659 | "nbconvert_exporter": "python", 660 | "pygments_lexer": "ipython3", 661 | "version": "3.6.6" 662 | } 663 | }, 664 | "nbformat": 4, 665 | "nbformat_minor": 2 666 | } 667 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anhquan0412/basic_model_scratch/fd83305ae69460cf7f2ea0f4c2bbbf633eed652c/model/__init__.py -------------------------------------------------------------------------------- /model/activation_classes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | class Sigmoid(): 3 | def __call__(self,x): 4 | return 1/(1+np.exp(-x)) 5 | def grad(self,x): 6 | x_acted = self.__call__(x) 7 | return x_acted*(1-x_acted) 8 | 9 | class Softmax(): 10 | def __call__(self,x): 11 | # return np.exp(x) / np.sum(np.exp(x), axis=1)[:,None] 12 | 13 | #this is more stable. 
Avoid np.exp overflow problem 14 | e_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) 15 | return e_x / np.sum(e_x, axis=-1, keepdims=True) 16 | 17 | class Tanh(): 18 | def __call__(self, x): 19 | return 2/(1 + np.exp(-2*x)) - 1 20 | 21 | def grad(self, x): 22 | return 1 - np.power(self.__call__(x), 2) 23 | 24 | class ReLU(): 25 | def __call__(self, x): 26 | # return np.maximum(x,0) 27 | return np.where(x >= 0, x, 0) 28 | 29 | def grad(self, x): 30 | return np.where(x >= 0, 1, 0) 31 | 32 | class LeakyReLU(): 33 | def __init__(self, alpha=0.2): 34 | self.alpha = alpha 35 | 36 | def __call__(self, x): 37 | return np.where(x >= 0, x, self.alpha * x) 38 | 39 | def grad(self, x): 40 | return np.where(x >= 0, 1, self.alpha) 41 | -------------------------------------------------------------------------------- /model/activations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def sigmoid(x): 3 | return 1/(1+np.exp(-x)) 4 | -------------------------------------------------------------------------------- /model/gradients.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model.utils import onehot_array 3 | def MSE_grad(y,y_pred): 4 | ''' 5 | Derivative of MSE loss w.r.t y_pred (not w) 6 | ''' 7 | return (2/len(y))*(y_pred-y) 8 | 9 | def logloss_sigmoid_grad(y,y_pred): 10 | ''' 11 | Derivative of sigmoid + log loss combination is equivalent to derivative of MSE 12 | ''' 13 | return MSE_grad(y,y_pred)/2 14 | 15 | def logloss_softmax_grad(y,y_pred): 16 | y_onehot = onehot_array(y,y_pred.shape[1]) 17 | return (1/len(y_pred)) * (y_pred - y_onehot) -------------------------------------------------------------------------------- /model/knn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | class CustomNearestNeighbor(): 3 | def __init__(self,k): 4 | self.k = k 5 | self.eps=1e-6 6 | def 
fit(self,X,y=None): 7 | self.X_train = np.array(X) 8 | if y is not None: 9 | self.y_train = y 10 | self.n_classes = np.unique(y).shape[0] 11 | 12 | def kneighbors(self,X): 13 | ''' 14 | Return sorted k distance and k indices of input X 15 | ''' 16 | X = np.array(X) 17 | dist = np.zeros([len(X),self.k]) 18 | idxs = np.zeros([len(X),self.k],dtype=int) 19 | #euclidian distance 20 | for i,x in enumerate(X): 21 | temp_dist = np.linalg.norm(x - self.X_train,axis=1) 22 | idxs[i] = (np.argsort(temp_dist)[:self.k]) 23 | dist[i] = temp_dist[idxs[i]] 24 | return [dist,idxs] 25 | def predict_classification(self,X_test,weighted=False): 26 | if not hasattr(self,'y_train'): 27 | raise ValueError('y and n_class are undefined') 28 | 29 | dist,idxs = self.kneighbors(X_test) 30 | inv_dist = 1/(dist+ self.eps) # eps to avoid divided by 0 31 | 32 | y_pred = np.zeros(idxs.shape[0]) 33 | wc = np.zeros([idxs.shape[0],self.n_classes]) 34 | class_sorted = np.zeros([idxs.shape[0],self.n_classes],dtype=int) 35 | for i,idx in enumerate(idxs): 36 | # calculate weighted count (wc) 37 | class_counter = np.bincount(self.y_train[idx],weights=inv_dist[i] if weighted else None,minlength=self.n_classes) 38 | class_sorted[i] = (np.argsort(class_counter)[::-1]) 39 | wc[i] = class_counter 40 | y_pred[i]=class_sorted[i][0] 41 | return y_pred,wc,class_sorted -------------------------------------------------------------------------------- /model/linear_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model.utils import get_train_val,batch_iterator,plot_learning_curve 3 | from model.metrics import MSE 4 | from model.gradients import MSE_grad 5 | 6 | def initialize_weight(dim): 7 | W0 = np.array([[0]]) # bias, 1x1 8 | W= np.random.rand(dim,1) 9 | return np.concatenate((W0,W)) 10 | 11 | class CustomLinearModel(): 12 | def __init__(self,dim,is_reg,loss_fn,grad_fn,act_fn = lambda x: x): 13 | 
self.dim,self.act_fn,self.loss_fn,self.grad_fn,self.is_reg = dim,act_fn,loss_fn,grad_fn,is_reg 14 | 15 | self.W = initialize_weight(self.dim) 16 | self.train_losses=[] 17 | self.val_losses=[] 18 | def fit(self,X,y,lr,l2=0,n_iteration=50,val_ratio=.2): 19 | ''' 20 | Fit data using gradient descent and l2 regularization 21 | ''' 22 | X_train,y_train,X_val,y_val = get_train_val(X,y,val_ratio) 23 | for i in range(n_iteration): 24 | y_pred = self.act_fn(np.squeeze(X_train @ self.W)) 25 | # MSE loss for regression 26 | loss = self.loss_fn(y_train,y_pred) 27 | grad = self.grad_fn(y_train,y_pred) # shape (n,) 28 | grad_w = X_train.T @ grad # shape (dim,) 29 | 30 | if len(grad_w.shape)==1: grad_w = grad_w[:,None] # turn (dim,) to (dim,1) 31 | #ignore update of grad_w0 (bias term) since w0 does not contribute to regularization process 32 | grad_w[1:,:]+= 2*(l2/len(X_train))*self.W[1:,:] # (2 *lambda / m)* weight 33 | 34 | self.W-= lr*grad_w 35 | 36 | #save training loss 37 | self.train_losses.append(loss) 38 | #predict validation set 39 | y_pred = self.act_fn(np.squeeze(X_val @ self.W)) 40 | val_loss = self.loss_fn(y_val,y_pred) 41 | self.val_losses.append(val_loss) 42 | if (i+1) % 20 == 0: 43 | print(f'{i+1}. 
Training loss: {loss}, Val loss:{val_loss}') 44 | 45 | plot_learning_curve(self.train_losses,self.val_losses) 46 | 47 | def fit_epoch(self,X,y,lr,epochs,bs,l2=0,val_ratio=0.2): 48 | ''' 49 | Fit data using stochastic gradient descent and l2 regularization 50 | ''' 51 | X_train,y_train,X_val,y_val = get_train_val(X,y,val_ratio) 52 | for epoch in range(epochs): 53 | train_cumloss,val_cumloss = 0,0 54 | # get batch from train set 55 | for xb,yb in batch_iterator(X_train,y_train,bs): 56 | y_pred = self.act_fn(np.squeeze(xb @ self.W)) 57 | train_cumloss+= self.loss_fn(yb,y_pred) * len(xb) 58 | 59 | grad = self.grad_fn(yb,y_pred) 60 | grad_w = xb.T @ grad 61 | if len(grad_w.shape)==1: grad_w = grad_w[:,None] 62 | grad_w[1:,:]+= 2*(l2/len(xb))*self.W[1:,:] 63 | self.W-= lr*grad_w 64 | 65 | # get double of bs from validation set (since there's less calculation for prediction) 66 | for xb,yb in batch_iterator(X_val,y_val,bs*2): 67 | y_pred = self.act_fn(np.squeeze(xb @ self.W)) 68 | val_cumloss += self.loss_fn(yb,y_pred) * len(xb) 69 | 70 | self.train_losses.append(train_cumloss/ len(X_train)) 71 | self.val_losses.append(val_cumloss / len(X_val)) 72 | print(f'Epoch {epoch+1}. 
Training loss: {self.train_losses[-1]}, Val loss:{self.val_losses[-1]}') 73 | plot_learning_curve(self.train_losses,self.val_losses) 74 | 75 | def get_weight(self): 76 | return self.W 77 | def predict(self,X,thres=0.5): 78 | if X.shape[1] == self.dim: 79 | X0 = np.array([[1]*X.shape[0]]).T # nx1 80 | X = np.concatenate((X0,X),axis=1) 81 | y_pred= self.act_fn(np.squeeze(X @ self.W)) 82 | if not self.is_reg: 83 | y_pred = (y_pred >= thres).astype(np.uint8) 84 | return y_pred 85 | def predict_proba(self,X): 86 | if self.is_reg: 87 | raise Exception('Cannot predict probability for regression') 88 | if X.shape[1] == self.dim: 89 | X0 = np.array([[1]*X.shape[0]]).T # nx1 90 | X = np.concatenate((X0,X),axis=1) 91 | return self.act_fn(np.squeeze(X @ self.W)) -------------------------------------------------------------------------------- /model/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model.utils import onehot_array 3 | def MSE(y,y_pred): 4 | return np.mean((y_pred -y)**2) 5 | def logloss(y,y_pred): 6 | y_pred= np.clip(y_pred,1e-5,1-1e-5) 7 | return np.mean(-np.log(y_pred)*y - np.log(1-y_pred)*(1-y)) 8 | def multi_logloss(y,y_pred): 9 | y_pred= np.clip(y_pred,1e-5,1-1e-5) 10 | y_onehot = onehot_array(y,y_pred.shape[1]) 11 | return -np.mean(np.log(np.sum(y_onehot * y_pred,axis=1))) 12 | def accuracy(y,y_pred): 13 | return np.mean(y==y_pred) -------------------------------------------------------------------------------- /model/neural_network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model.utils import get_train_val,batch_iterator,plot_learning_curve 3 | from model.metrics import multi_logloss 4 | from model.gradients import MSE_grad 5 | from model.activation_classes import Softmax,ReLU,LeakyReLU 6 | from model.gradients import logloss_softmax_grad 7 | from IPython.core.debugger import set_trace 8 | from 
model.optimizers import GradientDescent 9 | 10 | def init_weight(shape): 11 | ''' 12 | Kaiming He normal initialization 13 | ''' 14 | np.random.seed(42) 15 | return [np.random.uniform(size=shape) * np.sqrt(2/shape[0]), np.zeros((1,shape[1]))] 16 | 17 | 18 | def drop_out(X_act,keep_prob): 19 | ''' 20 | Inverted dropout implementation 21 | ''' 22 | mask = np.random.rand(*X_act.shape) <= keep_prob 23 | X_act= (X_act*mask)/keep_prob 24 | return X_act,mask 25 | class CustomNeuralNetwork(): 26 | ''' 27 | Simple neural network for binary classification 28 | ''' 29 | def __init__(self,layers,act_obj,opt= GradientDescent,keep_prob=1.0): 30 | ''' 31 | Layers include output layer 32 | E.g for 10 output classification, input layer size 400 and 1 hidden layer size 200: [400,200,10] 33 | 34 | act_obj: object (or list of objects) from activation_classes module to apply before reaching each hidden layer (exclude the last SoftMax activation before loss calculation) 35 | keep_prob: keep probability (or list) for drop out at each hidden layer. 
Note that we don't drop out the first layer 36 | ''' 37 | self.act_objs= [act_obj]*(len(layers)-2) if type(act_obj)!=list else act_obj 38 | self.keep_probs = [keep_prob]*(len(layers)-2) if type(keep_prob)!=list else keep_prob 39 | # list of [weight,bias] 40 | self.weights = [init_weight((layers[i],layers[i+1])) for i in range(len(layers)-1)] 41 | self.opt = opt(layers) 42 | assert len(self.act_objs) == len(self.keep_probs),'# of activation objs and # of keep probs must be equal' 43 | assert len(self.act_objs) == len(layers)-2, 'We only need (# of hidden layers) activation objects' 44 | 45 | # class params 46 | self.train_losses,self.val_losses=[],[] 47 | self.X_inputs,self.X_acts,self.X_masks=[],[],[] 48 | 49 | def forward_pass(self,X,train): 50 | if train: 51 | self.X_inputs,self.X_acts,self.X_masks = [X],[X],[] 52 | inp = X 53 | for i,w in enumerate(self.weights): 54 | inp = inp @ w[0] + w[1] 55 | if idropout) 91 | grad_wrt_input = grad_wrt_input * self.X_masks[i] / self.keep_probs[i] # grad_wrt_activation, after factoring in the inverted dropout 92 | grad_wrt_input = grad_wrt_input * self.act_objs[i].grad(self.X_inputs[i+1]) # actual grad_wrt_input 93 | 94 | grad_wbias = np.sum(grad_wrt_input,axis=0) # (1,200) 95 | 96 | grad_w = self.X_acts[i].T @ grad_wrt_input # (400,200) 97 | 98 | grad_w,grad_wbias = self.opt.step(grad_w,grad_wbias,i,**kwargs) 99 | self.weights[i][0]-= lr*(grad_w + (l2/bs) * self.weights[i][0]) #update weight 100 | self.weights[i][1]-= lr*grad_wbias #update bias 101 | 102 | def fit_epoch(self,X_train,y_train,X_val,y_val,lr,epochs,bs=64,l2=0,beta1=0.9,beta2=0.99): 103 | ''' 104 | Fit data using stochastic gradient descent and l2 regularization 105 | ''' 106 | # set_trace() 107 | for epoch in range(epochs): 108 | train_cumloss,val_cumloss = 0,0 109 | # get batch from train set 110 | for xb,yb in batch_iterator(X_train,y_train,bs): 111 | y_pred = self.forward_pass(xb,True) 112 | train_cumloss+= multi_logloss(yb,y_pred) * len(xb) 113 | 114 | 
self.backward_pass(yb,y_pred,l2,lr,beta1=beta1,beta2=beta2) 115 | 116 | # get double of bs from validation set (since there's less calculation for prediction) 117 | for xb,yb in batch_iterator(X_val,y_val,bs*2): 118 | y_pred = self.forward_pass(xb,False) 119 | val_cumloss += multi_logloss(yb,y_pred) * len(xb) 120 | 121 | self.train_losses.append(train_cumloss/ len(X_train)) 122 | self.val_losses.append(val_cumloss / len(X_val)) 123 | print(f'Epoch {epoch+1}. Training loss: {self.train_losses[-1]}, Val loss:{self.val_losses[-1]}') 124 | plot_learning_curve(self.train_losses,self.val_losses) 125 | 126 | def predict(self,X): 127 | y_proba = self.forward_pass(X,False) 128 | return np.argmax(y_proba,axis=1) 129 | def predict_proba(self,X): 130 | return self.forward_pass(X,False) -------------------------------------------------------------------------------- /model/optimizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def init_exp_grad(shape): 4 | ''' 5 | Initialize value for exponential weight average 6 | ''' 7 | return [np.zeros(shape), np.zeros((1,shape[1]))] 8 | def bias_correction(t,beta): 9 | temp= max(beta**t,10e-6) 10 | return 1/(1-temp) 11 | class GradientDescent(): 12 | def __init__(self,layers): 13 | pass 14 | def step(self,grad_w,grad_wbias,layer,**kwargs): 15 | return grad_w,grad_wbias 16 | 17 | class Momentum(): 18 | def __init__(self,layers): 19 | self.exp_grad = [init_exp_grad((layers[i],layers[i+1])) for i in range(len(layers)-1)] 20 | self.t=[0]*(len(layers)-1) 21 | def step(self,grad_w,grad_wbias,layer,**kwargs): 22 | beta1 = kwargs['beta1'] 23 | self.t[layer]+=1 24 | self.exp_grad[layer][0] = beta1*self.exp_grad[layer][0] + (1-beta1)*grad_w 25 | self.exp_grad[layer][1] = beta1*self.exp_grad[layer][1] + (1-beta1)*grad_wbias 26 | bias_corr = bias_correction(self.t[layer],beta1) 27 | return bias_corr*self.exp_grad[layer][0],bias_corr*self.exp_grad[layer][1] 28 | 29 | class 
RMSProp(): 30 | def __init__(self,layers): 31 | self.exp_grad_sqr = [init_exp_grad((layers[i],layers[i+1])) for i in range(len(layers)-1)] 32 | self.t=[0]*(len(layers)-1) 33 | def step(self,grad_w,grad_wbias,layer,**kwargs): 34 | beta2 = kwargs['beta2'] 35 | eps = 10e-8 36 | self.t[layer]+=1 37 | self.exp_grad_sqr[layer][0] = beta2*self.exp_grad_sqr[layer][0] + (1-beta2)* grad_w**2 38 | self.exp_grad_sqr[layer][1] = beta2*self.exp_grad_sqr[layer][1] + (1-beta2)* grad_wbias**2 39 | bias_corr = bias_correction(self.t[layer],beta2) 40 | new_gradw = grad_w / (np.sqrt(bias_corr*self.exp_grad_sqr[layer][0]) + eps) 41 | new_gradwb = grad_wbias / (np.sqrt(bias_corr*self.exp_grad_sqr[layer][1]) + eps) 42 | return new_gradw,new_gradwb 43 | 44 | class Adam(): 45 | def __init__(self,layers): 46 | self.exp_grad = [init_exp_grad((layers[i],layers[i+1])) for i in range(len(layers)-1)] 47 | self.exp_grad_sqr = [init_exp_grad((layers[i],layers[i+1])) for i in range(len(layers)-1)] 48 | self.t=[0]*(len(layers)-1) 49 | def step(self,grad_w,grad_wbias,layer,**kwargs): 50 | beta1,beta2 = kwargs['beta1'],kwargs['beta2'] 51 | eps = 10e-8 52 | self.t[layer]+=1 53 | 54 | self.exp_grad[layer][0] = beta1*self.exp_grad[layer][0] + (1-beta1)*grad_w 55 | self.exp_grad[layer][1] = beta1*self.exp_grad[layer][1] + (1-beta1)*grad_wbias 56 | 57 | self.exp_grad_sqr[layer][0] = beta2*self.exp_grad_sqr[layer][0] + (1-beta2)* grad_w**2 58 | self.exp_grad_sqr[layer][1] = beta2*self.exp_grad_sqr[layer][1] + (1-beta2)* grad_wbias**2 59 | 60 | bias_corr1 = bias_correction(self.t[layer],beta1) 61 | bias_corr2 = bias_correction(self.t[layer],beta2) 62 | 63 | new_gradw = (bias_corr1*self.exp_grad[layer][0]) / (np.sqrt(bias_corr2*self.exp_grad_sqr[layer][0]) + eps) 64 | new_gradwb = (bias_corr1*self.exp_grad[layer][1]) / (np.sqrt(bias_corr2*self.exp_grad_sqr[layer][1]) + eps) 65 | return new_gradw,new_gradwb 66 | -------------------------------------------------------------------------------- 
/model/random_forest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model.metrics import logloss, MSE 3 | 4 | def var_agg(n,s,s_squared): return (s_squared/n) - (s/n)**2 5 | 6 | class RandomForest(): 7 | def __init__(self, X, y, n_trees, sample_sz,is_reg=True, min_leaf=3,max_features=1): 8 | np.random.seed(42) 9 | if hasattr(y,'values'): y = y.values 10 | if hasattr(X,'values'): X = X.values 11 | self.X,self.y,self.sample_sz,self.min_leaf = X,y,sample_sz,min_leaf 12 | self.trees = [self.create_tree(is_reg,min_leaf,max_features) for i in range(n_trees)] # store roots of n_trees decision trees 13 | self.is_reg = is_reg 14 | 15 | def create_tree(self,is_reg,min_leaf,max_features): 16 | # generate random idxs with size sample_sz 17 | sample_idxs = np.random.permutation(len(self.y))[:self.sample_sz] 18 | return DecisionTreeNode(self.X[sample_idxs,:], self.y[sample_idxs], is_reg,min_leaf,max_features) 19 | def predict(self, X,thres=0.5): 20 | if hasattr(X,'values'): X = X.values 21 | y_pred = np.mean([t.predict(X) for t in self.trees], axis=0) 22 | if not self.is_reg: 23 | return (y_pred >= thres).astype(np.uint8) 24 | return y_pred 25 | def predict_proba(self,X): 26 | if self.is_reg: 27 | raise Exception('Cannot predict probability for regression') 28 | if hasattr(X,'values'): X = X.values 29 | return np.mean([t.predict(X) for t in self.trees], axis=0) 30 | 31 | 32 | class DecisionTreeNode(): 33 | def __init__(self, X, y, is_reg,min_leaf,max_features): 34 | self.X,self.y,self.min_leaf,self.max_features,self.is_reg = X,y,min_leaf,max_features,is_reg 35 | self.n,self.c = len(y), X.shape[1] 36 | if self.X.shape[0] != self.n: 37 | raise ValueError('X and y don\'t have the same size') 38 | self.val = np.mean(y) 39 | 40 | # Metric (loss score) 41 | self.score = float('inf') # initialize to infinity for a leaf 42 | 43 | self.col_idx= -1 # index of column chosen to split 44 | self.split_value = None # chosen 
split value from col with col_idx 45 | 46 | self.lhs_tree_node = None 47 | self.rhs_tree_node = None 48 | 49 | self.find_varsplit() # find best split and populate lhs + rhs tree node 50 | 51 | 52 | def find_varsplit(self): 53 | # exit clause when best split has been made from parent node 54 | if len(np.unique(self.y)) == 1: return 55 | 56 | # Assuming max_feature = self.c, as we consider all features for splitting 57 | n_col = int(self.c*self.max_features) 58 | 59 | for i in np.random.permutation(n_col): 60 | self.find_best_split_reg(i) if self.is_reg else self.find_best_split_clas(i) 61 | if self.is_leaf: return # exit clause when no split is made, because of min_leaf 62 | split_col = self.split_col 63 | lhs_idx = np.nonzero(split_col<=self.split_value)[0] 64 | rhs_idx = np.nonzero(split_col>self.split_value)[0] 65 | 66 | self.lhs_tree_node = DecisionTreeNode(self.X[lhs_idx,:], self.y[lhs_idx],self.is_reg,self.min_leaf,self.max_features) 67 | self.rhs_tree_node = DecisionTreeNode(self.X[rhs_idx,:], self.y[rhs_idx],self.is_reg,self.min_leaf,self.max_features) 68 | 69 | def find_best_split_reg(self, col_idx): 70 | x = self.X[:,col_idx] 71 | y = self.y 72 | 73 | sort_idx = np.argsort(x) 74 | sort_x,sort_y = x[sort_idx],y[sort_idx] 75 | rhs_cnt,rhs_sum,rhs_sum2 = self.n,sort_y.sum(), (sort_y**2).sum() 76 | lhs_cnt,lhs_sum,lhs_sum2=0,0.0,0.0 77 | for i in range(0,self.n- self.min_leaf): 78 | xi,yi = sort_x[i],sort_y[i] 79 | lhs_cnt += 1; rhs_cnt -= 1 80 | lhs_sum += yi; rhs_sum -= yi 81 | lhs_sum2 += yi**2; rhs_sum2 -= yi**2 82 | if i=0.5).astype(np.uint8) 41 | return X,y,W 42 | 43 | def plot_learning_curve(train_losses,val_losses): 44 | plt.plot(range(len(train_losses)),train_losses,'o-',color='r',label='Training loss',markersize=1) 45 | plt.plot(range(len(train_losses)),val_losses,'o-',color='g',label='Validation loss',markersize=1) 46 | plt.legend(loc="best") 47 | plt.show() 48 | 49 | def plot_feature_importances_rf(importances,col_names,figsize=(20,10)): 50 | 
fea_imp_df = pd.DataFrame(data={'Feature':col_names,'Importance':importances}).set_index('Feature') 51 | fea_imp_df = fea_imp_df.sort_values('Importance', ascending=True) 52 | fea_imp_df.plot(kind='barh',figsize=figsize) 53 | return fea_imp_df 54 | def permutation_importances(rf,X,y,metric,lowerisbetter=True): 55 | baseline = metric(rf,X,y) 56 | imp=[] 57 | for col in X.columns: 58 | save = X[col].copy() 59 | X[col] = np.random.permutation(X[col]) 60 | m = metric(rf,X,y) 61 | X[col] = save 62 | if lowerisbetter: 63 | imp.append(m-baseline) 64 | else: 65 | imp.append(baseline-m) 66 | fea_imp = np.array(imp) 67 | 68 | return plot_feature_importances_rf(fea_imp,X.columns.values) 69 | 70 | def draw_tree(t, df, size=10, ratio=0.6, precision=0): 71 | """ Draws a representation of a random forest in IPython. 72 | Parameters: 73 | ----------- 74 | t: The tree you wish to draw 75 | df: The data used to train the tree. This is used to get the names of the features. 76 | """ 77 | s=export_graphviz(t, out_file=None, feature_names=df.columns, filled=True, 78 | special_characters=True, rotate=True, precision=precision) 79 | IPython.display.display(graphviz.Source(re.sub('Tree {', 80 | f'Tree {{ size={size}; ratio={ratio}', s))) 81 | 82 | 83 | -------------------------------------------------------------------------------- /random_forest_regressor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Goals: rebuild Random Forest Regression from scratch\n", 8 | "\n", 9 | "Use only python and some basic numpy function (numpy slicing, np.mean, np.nonzero)\n", 10 | "\n", 11 | "Produce comparable results to Sklearn Ramdon Forest on some regression dataset" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "%load_ext autoreload\n", 21 | "%autoreload 2\n", 22 | 
"%matplotlib inline" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 12, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "\n", 32 | "import pandas as pd\n", 33 | "import numpy as np\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "from model.utils import draw_tree\n", 36 | "from model.metrics import *\n", 37 | "np.random.seed(42)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "# Generate a random-generated dataset" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 9, 50 | "metadata": { 51 | "scrolled": true 52 | }, 53 | "outputs": [ 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | " f0 f1 y\n", 59 | "0 2 9 5\n", 60 | "1 6 4 7\n", 61 | "2 9 7 1\n", 62 | "3 1 9 9\n", 63 | "4 4 9 3\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "nrows=20\n", 69 | "df = pd.DataFrame(np.random.randint(1,10,(nrows,3)),columns=['f0','f1','y'])\n", 70 | "print(df.head())\n", 71 | "y = df.y.values\n", 72 | "df.drop('y',axis=1,inplace=True)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "# Check whether my RF splits are the same as sklearn Random Forest's" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 31, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "from sklearn.ensemble import RandomForestRegressor\n", 89 | "from sklearn import metrics" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 10, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "data": { 99 | "image/svg+xml": [ 100 | "\r\n", 101 | "\r\n", 103 | "\r\n", 105 | "\r\n", 106 | "\r\n", 108 | "\r\n", 109 | "Tree\r\n", 110 | "\r\n", 111 | "\r\n", 112 | "0\r\n", 113 | "\r\n", 114 | "f1 ≤ 5.5\r\n", 115 | "mse = 7.628\r\n", 116 | "samples = 20\r\n", 117 | "value = 4.35\r\n", 118 | "\r\n", 119 | "\r\n", 120 | "1\r\n", 121 | "\r\n", 122 | "f1 ≤ 4.5\r\n", 123 | "mse = 4.36\r\n", 124 | 
"samples = 10\r\n", 125 | "value = 2.8\r\n", 126 | "\r\n", 127 | "\r\n", 128 | "0->1\r\n", 129 | "\r\n", 130 | "\r\n", 131 | "True\r\n", 132 | "\r\n", 133 | "\r\n", 134 | "6\r\n", 135 | "\r\n", 136 | "f1 ≤ 6.5\r\n", 137 | "mse = 6.09\r\n", 138 | "samples = 10\r\n", 139 | "value = 5.9\r\n", 140 | "\r\n", 141 | "\r\n", 142 | "0->6\r\n", 143 | "\r\n", 144 | "\r\n", 145 | "False\r\n", 146 | "\r\n", 147 | "\r\n", 148 | "2\r\n", 149 | "\r\n", 150 | "f0 ≤ 3.5\r\n", 151 | "mse = 4.245\r\n", 152 | "samples = 7\r\n", 153 | "value = 3.571\r\n", 154 | "\r\n", 155 | "\r\n", 156 | "1->2\r\n", 157 | "\r\n", 158 | "\r\n", 159 | "\r\n", 160 | "\r\n", 161 | "5\r\n", 162 | "\r\n", 163 | "mse = 0.0\r\n", 164 | "samples = 3\r\n", 165 | "value = 1.0\r\n", 166 | "\r\n", 167 | "\r\n", 168 | "1->5\r\n", 169 | "\r\n", 170 | "\r\n", 171 | "\r\n", 172 | "\r\n", 173 | "3\r\n", 174 | "\r\n", 175 | "mse = 2.667\r\n", 176 | "samples = 3\r\n", 177 | "value = 4.0\r\n", 178 | "\r\n", 179 | "\r\n", 180 | "2->3\r\n", 181 | "\r\n", 182 | "\r\n", 183 | "\r\n", 184 | "\r\n", 185 | "4\r\n", 186 | "\r\n", 187 | "mse = 5.188\r\n", 188 | "samples = 4\r\n", 189 | "value = 3.25\r\n", 190 | "\r\n", 191 | "\r\n", 192 | "2->4\r\n", 193 | "\r\n", 194 | "\r\n", 195 | "\r\n", 196 | "\r\n", 197 | "7\r\n", 198 | "\r\n", 199 | "mse = 0.222\r\n", 200 | "samples = 3\r\n", 201 | "value = 7.667\r\n", 202 | "\r\n", 203 | "\r\n", 204 | "6->7\r\n", 205 | "\r\n", 206 | "\r\n", 207 | "\r\n", 208 | "\r\n", 209 | "8\r\n", 210 | "\r\n", 211 | "f0 ≤ 3.5\r\n", 212 | "mse = 6.694\r\n", 213 | "samples = 7\r\n", 214 | "value = 5.143\r\n", 215 | "\r\n", 216 | "\r\n", 217 | "6->8\r\n", 218 | "\r\n", 219 | "\r\n", 220 | "\r\n", 221 | "\r\n", 222 | "9\r\n", 223 | "\r\n", 224 | "mse = 2.889\r\n", 225 | "samples = 3\r\n", 226 | "value = 6.667\r\n", 227 | "\r\n", 228 | "\r\n", 229 | "8->9\r\n", 230 | "\r\n", 231 | "\r\n", 232 | "\r\n", 233 | "\r\n", 234 | "10\r\n", 235 | "\r\n", 236 | "mse = 6.5\r\n", 237 | "samples = 4\r\n", 238 | "value = 
4.0\r\n", 239 | "\r\n", 240 | "\r\n", 241 | "8->10\r\n", 242 | "\r\n", 243 | "\r\n", 244 | "\r\n", 245 | "\r\n", 246 | "\r\n" 247 | ], 248 | "text/plain": [ 249 | "" 250 | ] 251 | }, 252 | "metadata": {}, 253 | "output_type": "display_data" 254 | } 255 | ], 256 | "source": [ 257 | "# sklearn RF split\n", 258 | "rf = RandomForestRegressor(n_estimators=1,max_depth=3,bootstrap=False,min_samples_leaf=3)\n", 259 | "rf.fit(df,y)\n", 260 | "draw_tree(rf.estimators_[0], df, precision=3)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 11, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "from model.random_forest import RandomForest" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 20, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "# My RF split\n", 279 | "ens = RandomForest(df, y,1, nrows)\n", 280 | "tree = ens.trees[0]" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 22, 286 | "metadata": {}, 287 | "outputs": [ 288 | { 289 | "data": { 290 | "text/plain": [ 291 | "Sample size: 20. Pred value: 4.35. Loss: 7.628\n", 292 | "Best split from feature 1 at value 5" 293 | ] 294 | }, 295 | "execution_count": 22, 296 | "metadata": {}, 297 | "output_type": "execute_result" 298 | } 299 | ], 300 | "source": [ 301 | "tree" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 23, 307 | "metadata": {}, 308 | "outputs": [ 309 | { 310 | "data": { 311 | "text/plain": [ 312 | "Sample size: 10. Pred value: 2.80. Loss: 4.360\n", 313 | "Best split from feature 1 at value 4" 314 | ] 315 | }, 316 | "execution_count": 23, 317 | "metadata": {}, 318 | "output_type": "execute_result" 319 | } 320 | ], 321 | "source": [ 322 | "tree.lhs_tree_node" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 24, 328 | "metadata": {}, 329 | "outputs": [ 330 | { 331 | "data": { 332 | "text/plain": [ 333 | "Sample size: 10. Pred value: 5.90. 
Loss: 6.090\n", 334 | "Best split from feature 1 at value 6" 335 | ] 336 | }, 337 | "execution_count": 24, 338 | "metadata": {}, 339 | "output_type": "execute_result" 340 | } 341 | ], 342 | "source": [ 343 | "tree.rhs_tree_node" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 27, 349 | "metadata": {}, 350 | "outputs": [ 351 | { 352 | "name": "stdout", 353 | "output_type": "stream", 354 | "text": [ 355 | "Sample size: 7. Pred value: 3.57. Loss: 4.245\n", 356 | "Best split from feature 0 at value 2\n", 357 | "Sample size: 3. Pred value: 1.00. Loss: 0.000\n", 358 | "\n" 359 | ] 360 | } 361 | ], 362 | "source": [ 363 | "print(tree.lhs_tree_node.lhs_tree_node)\n", 364 | "print(tree.lhs_tree_node.rhs_tree_node)" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": 28, 370 | "metadata": {}, 371 | "outputs": [ 372 | { 373 | "name": "stdout", 374 | "output_type": "stream", 375 | "text": [ 376 | "Sample size: 3. Pred value: 7.67. Loss: 0.222\n", 377 | "\n", 378 | "Sample size: 7. Pred value: 5.14. 
Loss: 6.694\n", 379 | "Best split from feature 0 at value 3\n" 380 | ] 381 | } 382 | ], 383 | "source": [ 384 | "print(tree.rhs_tree_node.lhs_tree_node)\n", 385 | "print(tree.rhs_tree_node.rhs_tree_node)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": {}, 391 | "source": [ 392 | "My RF regressor is able to produce the exact split like sklearn RF" 393 | ] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "metadata": {}, 398 | "source": [ 399 | "# Benchmarking on random-generated dataset" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 29, 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "nrows=1000\n", 409 | "ncols=10\n", 410 | "df = pd.DataFrame(np.random.randint(1,10,(nrows,ncols)))\n", 411 | "y = df[ncols-1].values\n", 412 | "df.drop(ncols-1,axis=1,inplace=True)" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 32, 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "name": "stdout", 422 | "output_type": "stream", 423 | "text": [ 424 | "MSE score: 2.0772524444444445\n", 425 | "R2 score: 0.6952187072704108\n" 426 | ] 427 | } 428 | ], 429 | "source": [ 430 | "rf = RandomForestRegressor(n_estimators=5,bootstrap=False,min_samples_leaf=3)\n", 431 | "rf.fit(df,y)\n", 432 | "\n", 433 | "rf_pred = rf.predict(df)\n", 434 | "print(f'MSE score: {MSE(y,rf_pred)}')\n", 435 | "print(f'R2 score: {metrics.r2_score(y,rf_pred)}')" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 33, 441 | "metadata": {}, 442 | "outputs": [ 443 | { 444 | "name": "stdout", 445 | "output_type": "stream", 446 | "text": [ 447 | "MSE score: 2.1058883111111113\n", 448 | "R2 score: 0.6910171589778858\n" 449 | ] 450 | } 451 | ], 452 | "source": [ 453 | "tb = RandomForest(df, y, n_trees=5, sample_sz=nrows)\n", 454 | "tb_pred = tb.predict(df)\n", 455 | "print(f'MSE score: {metrics.mean_squared_error(y,tb_pred)}')\n", 456 | "print(f'R2 score: {metrics.r2_score(y,tb_pred)}')" 
457 | ] 458 | }, 459 | { 460 | "cell_type": "markdown", 461 | "metadata": {}, 462 | "source": [ 463 | "# Benchmarking: Boston housing dataset" 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": {}, 469 | "source": [ 470 | "The dataset for this project originates from the UCI Machine Learning Repository. The Boston housing data was collected in 1978 and each of the 506 entries represents aggregated data about 14 features for homes from various suburbs in Boston, Massachusetts. For the purposes of this project, the following preprocessing steps have been applied to the dataset:\n", 471 | "\n", 472 | "- 16 data points have an 'MEDV' value of 50.0. These data points likely contain missing or censored values and have been removed.\n", 473 | "- 1 data point has an 'RM' value of 8.78. This data point can be considered an outlier and has been removed.\n", 474 | "- The features 'RM', 'LSTAT', 'PTRATIO', and 'MEDV' are essential. The remaining non-relevant features have been excluded.\n", 475 | "- The feature 'MEDV' has been multiplicatively scaled to account for 35 years of market inflation." 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 34, 481 | "metadata": {}, 482 | "outputs": [ 483 | { 484 | "name": "stdout", 485 | "output_type": "stream", 486 | "text": [ 487 | "(489, 4)\n" 488 | ] 489 | }, 490 | { 491 | "data": { 492 | "text/html": [ 493 | "
\n", 494 | "\n", 507 | "\n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | "
RMLSTATPTRATIOMEDV
06.5754.9815.3504000.0
16.4219.1417.8453600.0
27.1854.0317.8728700.0
36.9982.9418.7701400.0
47.1475.3318.7760200.0
\n", 555 | "
" 556 | ], 557 | "text/plain": [ 558 | " RM LSTAT PTRATIO MEDV\n", 559 | "0 6.575 4.98 15.3 504000.0\n", 560 | "1 6.421 9.14 17.8 453600.0\n", 561 | "2 7.185 4.03 17.8 728700.0\n", 562 | "3 6.998 2.94 18.7 701400.0\n", 563 | "4 7.147 5.33 18.7 760200.0" 564 | ] 565 | }, 566 | "execution_count": 34, 567 | "metadata": {}, 568 | "output_type": "execute_result" 569 | } 570 | ], 571 | "source": [ 572 | "housing = pd.read_csv('data/housing.csv')\n", 573 | "print(housing.shape)\n", 574 | "housing.head()" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 35, 580 | "metadata": {}, 581 | "outputs": [], 582 | "source": [ 583 | "y = housing.MEDV.values\n", 584 | "X = housing.drop('MEDV',axis=1)" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": 36, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [ 593 | "from sklearn.model_selection import train_test_split\n", 594 | "def get_train_val(X,y):\n", 595 | " X_train,X_val,y_train,y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n", 596 | " return X_train.reset_index(drop=True),X_val.reset_index(drop=True),y_train,y_val" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 37, 602 | "metadata": {}, 603 | "outputs": [], 604 | "source": [ 605 | "X_train,X_val,y_train,y_val = get_train_val(X,y)" 606 | ] 607 | }, 608 | { 609 | "cell_type": "markdown", 610 | "metadata": {}, 611 | "source": [ 612 | "### sklearn Random Forest" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": 81, 618 | "metadata": {}, 619 | "outputs": [ 620 | { 621 | "name": "stdout", 622 | "output_type": "stream", 623 | "text": [ 624 | "MSE train score: 3100445764.0763206\n", 625 | "MSE val score: 3219791806.9669237\n", 626 | "R2 train score: 0.8905847020820434\n", 627 | "R2 val score: 0.8534966571980468\n" 628 | ] 629 | } 630 | ], 631 | "source": [ 632 | "rf = 
RandomForestRegressor(n_estimators=15,max_features=0.8,bootstrap=False,min_samples_leaf=10)\n", 633 | "rf.fit(X_train,y_train)\n", 634 | "\n", 635 | "rf_train_pred = rf.predict(X_train)\n", 636 | "rf_val_pred = rf.predict(X_val)\n", 637 | "\n", 638 | "print(f'MSE train score: {metrics.mean_squared_error(y_train,rf_train_pred)}')\n", 639 | "print(f'MSE val score: {metrics.mean_squared_error(y_val,rf_val_pred)}')\n", 640 | "print(f'R2 train score: {metrics.r2_score(y_train,rf_train_pred)}')\n", 641 | "print(f'R2 val score: {metrics.r2_score(y_val,rf_val_pred)}')" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 82, 647 | "metadata": {}, 648 | "outputs": [ 649 | { 650 | "data": { 651 | "image/png": "\n", 652 | "text/plain": [ 653 | "
" 654 | ] 655 | }, 656 | "metadata": {}, 657 | "output_type": "display_data" 658 | } 659 | ], 660 | "source": [ 661 | "plt.scatter(range(0,len(y_val)),y_val,alpha=0.5,label='True')\n", 662 | "plt.scatter(range(0,len(rf_val_pred)),rf_val_pred,c='r',label='Pred')\n", 663 | "plt.legend()\n", 664 | "plt.show()" 665 | ] 666 | }, 667 | { 668 | "cell_type": "markdown", 669 | "metadata": {}, 670 | "source": [ 671 | "## My Random Forest" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": 86, 677 | "metadata": {}, 678 | "outputs": [ 679 | { 680 | "data": { 681 | "text/plain": [ 682 | "391" 683 | ] 684 | }, 685 | "execution_count": 86, 686 | "metadata": {}, 687 | "output_type": "execute_result" 688 | } 689 | ], 690 | "source": [ 691 | "X_train.shape[0]" 692 | ] 693 | }, 694 | { 695 | "cell_type": "markdown", 696 | "metadata": {}, 697 | "source": [ 698 | "One parameter I can tune here, but not in (this version of) sklearn's Random Forest, is the size of the data subset used to build each tree (sample_sz). Newer sklearn releases (0.22+) expose a similar `max_samples` option, but it only applies when bootstrap=True." 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 134, 704 | "metadata": {}, 705 | "outputs": [ 706 | { 707 | "name": "stdout", 708 | "output_type": "stream", 709 | "text": [ 710 | "MSE train score: 3372489559.856298\n", 711 | "MSE val score: 3477783760.5038266\n", 712 | "R2 train score: 0.8809842267868835\n", 713 | "R2 val score: 0.8417577977080082\n", 714 | "Wall time: 221 ms\n" 715 | ] 716 | } 717 | ], 718 | "source": [ 719 | "%%time\n", 720 | "tb = RandomForest(X_train, y_train, n_trees=20, sample_sz=150,min_leaf=4,max_features=1)\n", 721 | "tb_train_pred = tb.predict(X_train)\n", 722 | "tb_val_pred = tb.predict(X_val)\n", 723 | "print(f'MSE train score: {metrics.mean_squared_error(y_train,tb_train_pred)}')\n", 724 | "print(f'MSE val score: {metrics.mean_squared_error(y_val,tb_val_pred)}')\n", 725 | "print(f'R2 train score: {metrics.r2_score(y_train,tb_train_pred)}')\n", 726 | "print(f'R2 val score: {metrics.r2_score(y_val,tb_val_pred)}')" 727 | ]
728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": 135, 732 | "metadata": {}, 733 | "outputs": [ 734 | { 735 | "data": { 736 | "image/png": "\n", 737 | "text/plain": [ 738 | "
" 739 | ] 740 | }, 741 | "metadata": {}, 742 | "output_type": "display_data" 743 | } 744 | ], 745 | "source": [ 746 | "plt.scatter(range(0,len(y_val)),y_val,alpha=0.5,label='True')\n", 747 | "plt.scatter(range(0,len(tb_val_pred)),tb_val_pred,c='r',label='Pred')\n", 748 | "plt.legend()\n", 749 | "plt.show()" 750 | ] 751 | }, 752 | { 753 | "cell_type": "markdown", 754 | "metadata": {}, 755 | "source": [ 756 | "The scratch Random Forest achieves results comparable to sklearn's Random Forest." 757 | ] 758 | }, 759 | { 760 | "cell_type": "code", 761 | "execution_count": 136, 762 | "metadata": {}, 763 | "outputs": [ 764 | { 765 | "data": { 766 | "text/plain": [ 767 | "count 98.000000\n", 768 | "mean 14091.515729\n", 769 | "std 13859.367850\n", 770 | "min 255.530303\n", 771 | "25% 5550.851892\n", 772 | "50% 10468.054696\n", 773 | "75% 18474.863616\n", 774 | "max 75983.185499\n", 775 | "dtype: float64" 776 | ] 777 | }, 778 | "execution_count": 136, 779 | "metadata": {}, 780 | "output_type": "execute_result" 781 | } 782 | ], 783 | "source": [ 784 | "# Take a look at prediction differences b/t 2 models\n", 785 | "pd.Series(abs(rf_val_pred-tb_val_pred)).describe()" 786 | ] 787 | }, 788 | { 789 | "cell_type": "markdown", 790 | "metadata": {}, 791 | "source": [ 792 | "# Calculate feature importance" 793 | ] 794 | }, 795 | { 796 | "cell_type": "markdown", 797 | "metadata": {}, 798 | "source": [ 799 | "Feature importance will be calculated using **permutation importance**. As argued in this [doc](http://parrt.cs.usfca.edu/doc/rf-importance/index.html#3), the default feature importance computed by sklearn can be misleading: sklearn uses **mean decrease in impurity (Gini importance)**, which does not always give an accurate picture of importance. 
\n", 800 | "\n", 801 | "On the other hand, **permutation importance** is calculated as follows: record a baseline accuracy (classifier) or R2 score (regressor) by passing a validation set or the out-of-bag (OOB) samples through the Random Forest. Then permute the values of a single predictor column, pass the same samples back through the Random Forest, and recompute the accuracy or R2. The importance of that feature is the drop in accuracy or R2 caused by permuting the column, i.e. the baseline score minus the permuted score. This is more expensive to compute, but the results are more reliable." 802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": 138, 807 | "metadata": {}, 808 | "outputs": [], 809 | "source": [ 810 | "from model.utils import permutation_importances" 811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": 139, 816 | "metadata": {}, 817 | "outputs": [], 818 | "source": [ 819 | "??permutation_importances" 820 | ] 821 | }, 822 | { 823 | "cell_type": "markdown", 824 | "metadata": {}, 825 | "source": [ 826 | "## Feature importance from Sklearn's RF model\n" 827 | ] 828 | }, 829 | { 830 | "cell_type": "code", 831 | "execution_count": 143, 832 | "metadata": {}, 833 | "outputs": [], 834 | "source": [ 835 | "def metric(rf,X,y):\n", 836 | " y_pred = rf.predict(X)\n", 837 | " return metrics.r2_score(y,y_pred)" 838 | ] 839 | }, 840 | { 841 | "cell_type": "code", 842 | "execution_count": 144, 843 | "metadata": {}, 844 | "outputs": [ 845 | { 846 | "data": { 847 | "image/png": "\n", 848 | "text/plain": [ 849 | "
" 850 | ] 851 | }, 852 | "metadata": {}, 853 | "output_type": "display_data" 854 | } 855 | ], 856 | "source": [ 857 | "fea_imp_df = permutation_importances(rf,X_val,y_val,metric,False)" 858 | ] 859 | }, 860 | { 861 | "cell_type": "markdown", 862 | "metadata": {}, 863 | "source": [ 864 | "## Feature importance from 'scratch' RF" 865 | ] 866 | }, 867 | { 868 | "cell_type": "code", 869 | "execution_count": 145, 870 | "metadata": {}, 871 | "outputs": [ 872 | { 873 | "data": { 874 | "image/png": "\n", 875 | "text/plain": [ 876 | "
" 877 | ] 878 | }, 879 | "metadata": {}, 880 | "output_type": "display_data" 881 | } 882 | ], 883 | "source": [ 884 | "fea_imp_df = permutation_importances(tb,X_val,y_val,metric,False)" 885 | ] 886 | } 887 | ], 888 | "metadata": { 889 | "kernelspec": { 890 | "display_name": "Python 3", 891 | "language": "python", 892 | "name": "python3" 893 | }, 894 | "language_info": { 895 | "codemirror_mode": { 896 | "name": "ipython", 897 | "version": 3 898 | }, 899 | "file_extension": ".py", 900 | "mimetype": "text/x-python", 901 | "name": "python", 902 | "nbconvert_exporter": "python", 903 | "pygments_lexer": "ipython3", 904 | "version": "3.6.6" 905 | } 906 | }, 907 | "nbformat": 4, 908 | "nbformat_minor": 2 909 | } 910 | --------------------------------------------------------------------------------