├── .gitignore ├── NN_pytorch.ipynb ├── README.md ├── autoencoder.ipynb ├── collab_filtering.ipynb ├── data ├── cancer.csv ├── fashion │ ├── t10k-images-idx3-ubyte.gz │ ├── t10k-labels-idx1-ubyte.gz │ ├── train-images-idx3-ubyte.gz │ └── train-labels-idx1-ubyte.gz ├── housing.csv └── nlp │ ├── KVH.txt │ ├── kvh_trn │ └── trn.txt │ ├── kvh_val │ └── val.txt │ ├── nietzsche.txt │ └── trump.txt ├── knn.ipynb ├── linear_regression.ipynb ├── logistic_regression.ipynb ├── model ├── __init__.py ├── activation_classes.py ├── activations.py ├── gradients.py ├── knn.py ├── linear_model.py ├── metrics.py ├── neural_network.py ├── optimizers.py ├── random_forest.py └── utils.py ├── neural_net_optimizers.ipynb ├── random_forest_classifier.ipynb ├── random_forest_regressor.ipynb ├── rnn-vietnamese.ipynb └── scratch_neural_net.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | */.ipynb_checkpoints/* 3 | *__pycache__/ 4 | *.vscode/ 5 | fastai 6 | 7 | data/large_ds 8 | data/nlp/models 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Machine Learning from scratch! 2 | 3 | Update: Code implementations have been moved to a Python module. The notebooks only show results and model comparisons. 4 | 5 | To refresh my knowledge, I will attempt to implement some basic machine learning algorithms from scratch using only Python and a limited set of numpy/pandas functions.
6 | My model implementations will be compared to existing models from a popular ML library (sklearn). 7 | - [Linear Regression with weight decay (L2 regularization)](https://github.com/anhquan0412/basic_model_scratch/blob/master/linear_regression.ipynb) 8 | - [Logistic Regression with weight decay](https://github.com/anhquan0412/basic_model_scratch/blob/master/logistic_regression.ipynb) 9 | - Random Forest with Permutation Feature Importances 10 | - [Random Forest Regressor](https://github.com/anhquan0412/basic_model_scratch/blob/master/random_forest_regressor.ipynb) 11 | - [Random Forest Classifier](https://github.com/anhquan0412/basic_model_scratch/blob/master/random_forest_classifier.ipynb) 12 | - [K Nearest Neighbors: supervised and unsupervised](https://github.com/anhquan0412/basic_model_scratch/blob/master/knn.ipynb) 13 | - [Neural network for classification](https://github.com/anhquan0412/basic_model_scratch/blob/master/scratch_neural_net.ipynb) 14 | - Stochastic Gradient Descent 15 | - Multiple hidden layers 16 | - A variety of activation functions + gradients (Sigmoid, Softmax, ReLU ...) customized for each hidden layer 17 | - L2 regularization 18 | - Dropout 19 | - [Dynamic learning rate optimizer](https://github.com/anhquan0412/basic_model_scratch/blob/master/neural_net_optimizers.ipynb) (momentum, RMSProp and Adam) 20 | - TODO: batchnorm 21 | 22 | The following notebooks use PyTorch libraries, so they are not implemented from scratch.
However, I try not to use any high level Pytorch function 23 | - [Pytorch Neural Network](https://github.com/anhquan0412/basic_model_scratch/blob/master/NN_pytorch.ipynb) with: 24 | - Custom Data Loader 25 | - Data Augmentation on 1 channel image: torchvision vs fastai 26 | - Shallow NN with batchnorm and dropout 27 | - Learning rate finder 28 | - [Auto Encoding](https://github.com/anhquan0412/basic_model_scratch/blob/master/autoencoder.ipynb) 29 | - [Collaborative Filtering](https://github.com/anhquan0412/basic_model_scratch/blob/master/collab_filtering.ipynb) 30 | - [Char RNN in Vietnamese (Fast.ai)](https://github.com/anhquan0412/basic_model_scratch/blob/master/rnn-vietnamese.ipynb) 31 | -------------------------------------------------------------------------------- /data/fashion/t10k-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anhquan0412/basic_model_scratch/fd83305ae69460cf7f2ea0f4c2bbbf633eed652c/data/fashion/t10k-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /data/fashion/t10k-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anhquan0412/basic_model_scratch/fd83305ae69460cf7f2ea0f4c2bbbf633eed652c/data/fashion/t10k-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /data/fashion/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anhquan0412/basic_model_scratch/fd83305ae69460cf7f2ea0f4c2bbbf633eed652c/data/fashion/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /data/fashion/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/anhquan0412/basic_model_scratch/fd83305ae69460cf7f2ea0f4c2bbbf633eed652c/data/fashion/train-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /data/housing.csv: -------------------------------------------------------------------------------- 1 | RM,LSTAT,PTRATIO,MEDV 2 | 6.575,4.98,15.3,504000.0 3 | 6.421,9.14,17.8,453600.0 4 | 7.185,4.03,17.8,728700.0 5 | 6.998,2.94,18.7,701400.0 6 | 7.147,5.33,18.7,760200.0 7 | 6.43,5.21,18.7,602700.0 8 | 6.012,12.43,15.2,480900.0 9 | 6.172,19.15,15.2,569100.0 10 | 5.631,29.93,15.2,346500.0 11 | 6.004,17.1,15.2,396900.0 12 | 6.377,20.45,15.2,315000.0 13 | 6.009,13.27,15.2,396900.0 14 | 5.889,15.71,15.2,455700.0 15 | 5.949,8.26,21.0,428400.0 16 | 6.096,10.26,21.0,382200.0 17 | 5.834,8.47,21.0,417900.0 18 | 5.935,6.58,21.0,485100.0 19 | 5.99,14.67,21.0,367500.0 20 | 5.456,11.69,21.0,424200.0 21 | 5.727,11.28,21.0,382200.0 22 | 5.57,21.02,21.0,285600.0 23 | 5.965,13.83,21.0,411600.0 24 | 6.142,18.72,21.0,319200.0 25 | 5.813,19.88,21.0,304500.0 26 | 5.924,16.3,21.0,327600.0 27 | 5.599,16.51,21.0,291900.0 28 | 5.813,14.81,21.0,348600.0 29 | 6.047,17.28,21.0,310800.0 30 | 6.495,12.8,21.0,386400.0 31 | 6.674,11.98,21.0,441000.0 32 | 5.713,22.6,21.0,266700.0 33 | 6.072,13.04,21.0,304500.0 34 | 5.95,27.71,21.0,277200.0 35 | 5.701,18.35,21.0,275100.0 36 | 6.096,20.34,21.0,283500.0 37 | 5.933,9.68,19.2,396900.0 38 | 5.841,11.41,19.2,420000.0 39 | 5.85,8.77,19.2,441000.0 40 | 5.966,10.13,19.2,518700.0 41 | 6.595,4.32,18.3,646800.0 42 | 7.024,1.98,18.3,732900.0 43 | 6.77,4.84,17.9,558600.0 44 | 6.169,5.81,17.9,531300.0 45 | 6.211,7.44,17.9,518700.0 46 | 6.069,9.55,17.9,445200.0 47 | 5.682,10.21,17.9,405300.0 48 | 5.786,14.15,17.9,420000.0 49 | 6.03,18.8,17.9,348600.0 50 | 5.399,30.81,17.9,302400.0 51 | 5.602,16.2,17.9,407400.0 52 | 5.963,13.45,16.8,413700.0 53 | 6.115,9.43,16.8,430500.0 54 | 6.511,5.28,16.8,525000.0 55 | 5.998,8.43,16.8,491400.0 56 | 
5.888,14.8,21.1,396900.0 57 | 7.249,4.81,17.9,743400.0 58 | 6.383,5.77,17.3,518700.0 59 | 6.816,3.95,15.1,663600.0 60 | 6.145,6.86,19.7,489300.0 61 | 5.927,9.22,19.7,411600.0 62 | 5.741,13.15,19.7,392700.0 63 | 5.966,14.44,19.7,336000.0 64 | 6.456,6.73,19.7,466200.0 65 | 6.762,9.5,19.7,525000.0 66 | 7.104,8.05,18.6,693000.0 67 | 6.29,4.67,16.1,493500.0 68 | 5.787,10.24,16.1,407400.0 69 | 5.878,8.1,18.9,462000.0 70 | 5.594,13.09,18.9,365400.0 71 | 5.885,8.79,18.9,438900.0 72 | 6.417,6.72,19.2,508200.0 73 | 5.961,9.88,19.2,455700.0 74 | 6.065,5.52,19.2,478800.0 75 | 6.245,7.54,19.2,491400.0 76 | 6.273,6.78,18.7,506100.0 77 | 6.286,8.94,18.7,449400.0 78 | 6.279,11.97,18.7,420000.0 79 | 6.14,10.27,18.7,436800.0 80 | 6.232,12.34,18.7,445200.0 81 | 5.874,9.1,18.7,426300.0 82 | 6.727,5.29,19.0,588000.0 83 | 6.619,7.22,19.0,501900.0 84 | 6.302,6.72,19.0,520800.0 85 | 6.167,7.51,19.0,480900.0 86 | 6.389,9.62,18.5,501900.0 87 | 6.63,6.53,18.5,558600.0 88 | 6.015,12.86,18.5,472500.0 89 | 6.121,8.44,18.5,466200.0 90 | 7.007,5.5,17.8,495600.0 91 | 7.079,5.7,17.8,602700.0 92 | 6.417,8.81,17.8,474600.0 93 | 6.405,8.2,17.8,462000.0 94 | 6.442,8.16,18.2,480900.0 95 | 6.211,6.21,18.2,525000.0 96 | 6.249,10.59,18.2,432600.0 97 | 6.625,6.65,18.0,596400.0 98 | 6.163,11.34,18.0,449400.0 99 | 8.069,4.21,18.0,812700.0 100 | 7.82,3.57,18.0,919800.0 101 | 7.416,6.19,18.0,697200.0 102 | 6.727,9.42,20.9,577500.0 103 | 6.781,7.67,20.9,556500.0 104 | 6.405,10.63,20.9,390600.0 105 | 6.137,13.44,20.9,405300.0 106 | 6.167,12.33,20.9,422100.0 107 | 5.851,16.47,20.9,409500.0 108 | 5.836,18.66,20.9,409500.0 109 | 6.127,14.09,20.9,428400.0 110 | 6.474,12.27,20.9,415800.0 111 | 6.229,15.55,20.9,407400.0 112 | 6.195,13.0,20.9,455700.0 113 | 6.715,10.16,17.8,478800.0 114 | 5.913,16.21,17.8,394800.0 115 | 6.092,17.09,17.8,392700.0 116 | 6.254,10.45,17.8,388500.0 117 | 5.928,15.76,17.8,384300.0 118 | 6.176,12.04,17.8,445200.0 119 | 6.021,10.3,17.8,403200.0 120 | 5.872,15.37,17.8,428400.0 121 | 
5.731,13.61,17.8,405300.0 122 | 5.87,14.37,19.1,462000.0 123 | 6.004,14.27,19.1,426300.0 124 | 5.961,17.93,19.1,430500.0 125 | 5.856,25.41,19.1,363300.0 126 | 5.879,17.58,19.1,394800.0 127 | 5.986,14.81,19.1,449400.0 128 | 5.613,27.26,19.1,329700.0 129 | 5.693,17.19,21.2,340200.0 130 | 6.431,15.39,21.2,378000.0 131 | 5.637,18.34,21.2,300300.0 132 | 6.458,12.6,21.2,403200.0 133 | 6.326,12.26,21.2,411600.0 134 | 6.372,11.12,21.2,483000.0 135 | 5.822,15.03,21.2,386400.0 136 | 5.757,17.31,21.2,327600.0 137 | 6.335,16.96,21.2,380100.0 138 | 5.942,16.9,21.2,365400.0 139 | 6.454,14.59,21.2,359100.0 140 | 5.857,21.32,21.2,279300.0 141 | 6.151,18.46,21.2,373800.0 142 | 6.174,24.16,21.2,294000.0 143 | 5.019,34.41,21.2,302400.0 144 | 5.403,26.82,14.7,281400.0 145 | 5.468,26.42,14.7,327600.0 146 | 4.903,29.29,14.7,247800.0 147 | 6.13,27.8,14.7,289800.0 148 | 5.628,16.65,14.7,327600.0 149 | 4.926,29.53,14.7,306600.0 150 | 5.186,28.32,14.7,373800.0 151 | 5.597,21.45,14.7,323400.0 152 | 6.122,14.1,14.7,451500.0 153 | 5.404,13.28,14.7,411600.0 154 | 5.012,12.12,14.7,321300.0 155 | 5.709,15.79,14.7,407400.0 156 | 6.129,15.12,14.7,357000.0 157 | 6.152,15.02,14.7,327600.0 158 | 5.272,16.14,14.7,275100.0 159 | 6.943,4.59,14.7,867300.0 160 | 6.066,6.43,14.7,510300.0 161 | 6.51,7.39,14.7,489300.0 162 | 6.25,5.5,14.7,567000.0 163 | 5.854,11.64,14.7,476700.0 164 | 6.101,9.81,14.7,525000.0 165 | 5.877,12.14,14.7,499800.0 166 | 6.319,11.1,14.7,499800.0 167 | 6.402,11.32,14.7,468300.0 168 | 5.875,14.43,14.7,365400.0 169 | 5.88,12.03,14.7,401100.0 170 | 5.572,14.69,16.6,485100.0 171 | 6.416,9.04,16.6,495600.0 172 | 5.859,9.64,16.6,474600.0 173 | 6.546,5.33,16.6,617400.0 174 | 6.02,10.11,16.6,487200.0 175 | 6.315,6.29,16.6,516600.0 176 | 6.86,6.92,16.6,627900.0 177 | 6.98,5.04,17.8,781200.0 178 | 7.765,7.56,17.8,835800.0 179 | 6.144,9.45,17.8,760200.0 180 | 7.155,4.82,17.8,795900.0 181 | 6.563,5.68,17.8,682500.0 182 | 5.604,13.98,17.8,554400.0 183 | 6.153,13.15,17.8,621600.0 184 | 
6.782,6.68,15.2,672000.0 185 | 6.556,4.56,15.2,625800.0 186 | 7.185,5.39,15.2,732900.0 187 | 6.951,5.1,15.2,777000.0 188 | 6.739,4.69,15.2,640500.0 189 | 7.178,2.87,15.2,764400.0 190 | 6.8,5.03,15.6,653100.0 191 | 6.604,4.38,15.6,611100.0 192 | 7.287,4.08,12.6,699300.0 193 | 7.107,8.61,12.6,636300.0 194 | 7.274,6.62,12.6,726600.0 195 | 6.975,4.56,17.0,732900.0 196 | 7.135,4.45,17.0,690900.0 197 | 6.162,7.43,14.7,506100.0 198 | 7.61,3.11,14.7,888300.0 199 | 7.853,3.81,14.7,1018500.0 200 | 5.891,10.87,18.6,474600.0 201 | 6.326,10.97,18.6,512400.0 202 | 5.783,18.06,18.6,472500.0 203 | 6.064,14.66,18.6,512400.0 204 | 5.344,23.09,18.6,420000.0 205 | 5.96,17.27,18.6,455700.0 206 | 5.404,23.98,18.6,405300.0 207 | 5.807,16.03,18.6,470400.0 208 | 6.375,9.38,18.6,590100.0 209 | 5.412,29.55,18.6,497700.0 210 | 6.182,9.47,18.6,525000.0 211 | 5.888,13.51,16.4,489300.0 212 | 6.642,9.69,16.4,602700.0 213 | 5.951,17.92,16.4,451500.0 214 | 6.373,10.5,16.4,483000.0 215 | 6.951,9.71,17.4,560700.0 216 | 6.164,21.46,17.4,455700.0 217 | 6.879,9.93,17.4,577500.0 218 | 6.618,7.6,17.4,632100.0 219 | 8.266,4.14,17.4,940800.0 220 | 8.04,3.13,17.4,789600.0 221 | 7.163,6.36,17.4,663600.0 222 | 7.686,3.92,17.4,980700.0 223 | 6.552,3.76,17.4,661500.0 224 | 5.981,11.65,17.4,510300.0 225 | 7.412,5.25,17.4,665700.0 226 | 8.337,2.47,17.4,875700.0 227 | 8.247,3.95,17.4,1014300.0 228 | 6.726,8.05,17.4,609000.0 229 | 6.086,10.88,17.4,504000.0 230 | 6.631,9.54,17.4,527100.0 231 | 7.358,4.73,17.4,661500.0 232 | 6.481,6.36,16.6,497700.0 233 | 6.606,7.37,16.6,489300.0 234 | 6.897,11.38,16.6,462000.0 235 | 6.095,12.4,16.6,422100.0 236 | 6.358,11.22,16.6,466200.0 237 | 6.393,5.19,16.6,497700.0 238 | 5.593,12.5,19.1,369600.0 239 | 5.605,18.46,19.1,388500.0 240 | 6.108,9.16,19.1,510300.0 241 | 6.226,10.15,19.1,430500.0 242 | 6.433,9.52,19.1,514500.0 243 | 6.718,6.56,19.1,550200.0 244 | 6.487,5.9,19.1,512400.0 245 | 6.438,3.59,19.1,520800.0 246 | 6.957,3.53,19.1,621600.0 247 | 8.259,3.54,19.1,898800.0 248 | 
6.108,6.57,16.4,459900.0 249 | 5.876,9.25,16.4,438900.0 250 | 7.454,3.11,15.9,924000.0 251 | 7.333,7.79,13.0,756000.0 252 | 6.842,6.9,13.0,632100.0 253 | 7.203,9.59,13.0,709800.0 254 | 7.52,7.26,13.0,905100.0 255 | 8.398,5.91,13.0,1024800.0 256 | 7.327,11.25,13.0,651000.0 257 | 7.206,8.1,13.0,766500.0 258 | 5.56,10.45,13.0,478800.0 259 | 7.014,14.79,13.0,644700.0 260 | 7.47,3.16,13.0,913500.0 261 | 5.92,13.65,18.6,434700.0 262 | 5.856,13.0,18.6,443100.0 263 | 6.24,6.59,18.6,529200.0 264 | 6.538,7.73,18.6,512400.0 265 | 7.691,6.58,18.6,739200.0 266 | 6.758,3.53,17.6,680400.0 267 | 6.854,2.98,17.6,672000.0 268 | 7.267,6.05,17.6,697200.0 269 | 6.826,4.16,17.6,695100.0 270 | 6.482,7.19,17.6,611100.0 271 | 6.812,4.85,14.9,737100.0 272 | 7.82,3.76,14.9,953400.0 273 | 6.968,4.59,14.9,743400.0 274 | 7.645,3.01,14.9,966000.0 275 | 7.088,7.85,15.3,676200.0 276 | 6.453,8.23,15.3,462000.0 277 | 6.23,12.93,18.2,422100.0 278 | 6.209,7.14,16.6,487200.0 279 | 6.315,7.6,16.6,468300.0 280 | 6.565,9.51,16.6,520800.0 281 | 6.861,3.33,19.2,598500.0 282 | 7.148,3.56,19.2,783300.0 283 | 6.63,4.7,19.2,585900.0 284 | 6.127,8.58,16.0,501900.0 285 | 6.009,10.4,16.0,455700.0 286 | 6.678,6.27,16.0,600600.0 287 | 6.549,7.39,16.0,569100.0 288 | 5.79,15.84,16.0,426300.0 289 | 6.345,4.97,14.8,472500.0 290 | 7.041,4.74,14.8,609000.0 291 | 6.871,6.07,14.8,520800.0 292 | 6.59,9.5,16.1,462000.0 293 | 6.495,8.67,16.1,554400.0 294 | 6.982,4.86,16.1,695100.0 295 | 7.236,6.93,18.4,758100.0 296 | 6.616,8.93,18.4,596400.0 297 | 7.42,6.47,18.4,701400.0 298 | 6.849,7.53,18.4,592200.0 299 | 6.635,4.54,18.4,478800.0 300 | 5.972,9.97,18.4,426300.0 301 | 4.973,12.64,18.4,338100.0 302 | 6.122,5.98,18.4,464100.0 303 | 6.023,11.72,18.4,407400.0 304 | 6.266,7.9,18.4,453600.0 305 | 6.567,9.28,18.4,499800.0 306 | 5.705,11.5,18.4,340200.0 307 | 5.914,18.33,18.4,373800.0 308 | 5.782,15.94,18.4,415800.0 309 | 6.382,10.36,18.4,485100.0 310 | 6.113,12.73,18.4,441000.0 311 | 6.426,7.2,19.6,499800.0 312 | 
6.376,6.87,19.6,485100.0 313 | 6.041,7.7,19.6,428400.0 314 | 5.708,11.74,19.6,388500.0 315 | 6.415,6.12,19.6,525000.0 316 | 6.431,5.08,19.6,516600.0 317 | 6.312,6.15,19.6,483000.0 318 | 6.083,12.79,19.6,466200.0 319 | 5.868,9.97,16.9,405300.0 320 | 6.333,7.34,16.9,474600.0 321 | 6.144,9.09,16.9,415800.0 322 | 5.706,12.43,16.9,359100.0 323 | 6.031,7.83,16.9,407400.0 324 | 6.316,5.68,20.2,466200.0 325 | 6.31,6.75,20.2,434700.0 326 | 6.037,8.01,20.2,443100.0 327 | 5.869,9.8,20.2,409500.0 328 | 5.895,10.56,20.2,388500.0 329 | 6.059,8.51,20.2,432600.0 330 | 5.985,9.74,20.2,399000.0 331 | 5.968,9.29,20.2,392700.0 332 | 7.241,5.49,15.5,686700.0 333 | 6.54,8.65,15.9,346500.0 334 | 6.696,7.18,17.6,501900.0 335 | 6.874,4.61,17.6,655200.0 336 | 6.014,10.53,18.8,367500.0 337 | 5.898,12.67,18.8,361200.0 338 | 6.516,6.36,17.9,485100.0 339 | 6.635,5.99,17.0,514500.0 340 | 6.939,5.89,19.7,558600.0 341 | 6.49,5.98,19.7,480900.0 342 | 6.579,5.49,18.3,506100.0 343 | 5.884,7.79,18.3,390600.0 344 | 6.728,4.5,17.0,632100.0 345 | 5.663,8.05,22.0,382200.0 346 | 5.936,5.57,22.0,432600.0 347 | 6.212,17.6,20.2,373800.0 348 | 6.395,13.27,20.2,455700.0 349 | 6.127,11.48,20.2,476700.0 350 | 6.112,12.67,20.2,474600.0 351 | 6.398,7.79,20.2,525000.0 352 | 6.251,14.19,20.2,417900.0 353 | 5.362,10.19,20.2,436800.0 354 | 5.803,14.64,20.2,352800.0 355 | 3.561,7.12,20.2,577500.0 356 | 4.963,14.0,20.2,459900.0 357 | 3.863,13.33,20.2,485100.0 358 | 4.906,34.77,20.2,289800.0 359 | 4.138,37.97,20.2,289800.0 360 | 7.313,13.44,20.2,315000.0 361 | 6.649,23.24,20.2,291900.0 362 | 6.794,21.24,20.2,279300.0 363 | 6.38,23.69,20.2,275100.0 364 | 6.223,21.78,20.2,214200.0 365 | 6.968,17.21,20.2,218400.0 366 | 6.545,21.08,20.2,228900.0 367 | 5.536,23.6,20.2,237300.0 368 | 5.52,24.56,20.2,258300.0 369 | 4.368,30.63,20.2,184800.0 370 | 5.277,30.81,20.2,151200.0 371 | 4.652,28.28,20.2,220500.0 372 | 5.0,31.99,20.2,155400.0 373 | 4.88,30.62,20.2,214200.0 374 | 5.39,20.85,20.2,241500.0 375 | 5.713,17.11,20.2,317100.0 376 
| 6.051,18.76,20.2,487200.0 377 | 5.036,25.68,20.2,203700.0 378 | 6.193,15.17,20.2,289800.0 379 | 5.887,16.35,20.2,266700.0 380 | 6.471,17.12,20.2,275100.0 381 | 6.405,19.37,20.2,262500.0 382 | 5.747,19.92,20.2,178500.0 383 | 5.453,30.59,20.2,105000.0 384 | 5.852,29.97,20.2,132300.0 385 | 5.987,26.77,20.2,117600.0 386 | 6.343,20.32,20.2,151200.0 387 | 6.404,20.31,20.2,254100.0 388 | 5.349,19.77,20.2,174300.0 389 | 5.531,27.38,20.2,178500.0 390 | 5.683,22.98,20.2,105000.0 391 | 4.138,23.34,20.2,249900.0 392 | 5.608,12.13,20.2,585900.0 393 | 5.617,26.4,20.2,361200.0 394 | 6.852,19.78,20.2,577500.0 395 | 5.757,10.11,20.2,315000.0 396 | 6.657,21.22,20.2,361200.0 397 | 4.628,34.37,20.2,375900.0 398 | 5.155,20.08,20.2,342300.0 399 | 4.519,36.98,20.2,147000.0 400 | 6.434,29.05,20.2,151200.0 401 | 6.782,25.79,20.2,157500.0 402 | 5.304,26.64,20.2,218400.0 403 | 5.957,20.62,20.2,184800.0 404 | 6.824,22.74,20.2,176400.0 405 | 6.411,15.02,20.2,350700.0 406 | 6.006,15.7,20.2,298200.0 407 | 5.648,14.1,20.2,436800.0 408 | 6.103,23.29,20.2,281400.0 409 | 5.565,17.16,20.2,245700.0 410 | 5.896,24.39,20.2,174300.0 411 | 5.837,15.69,20.2,214200.0 412 | 6.202,14.52,20.2,228900.0 413 | 6.193,21.52,20.2,231000.0 414 | 6.38,24.08,20.2,199500.0 415 | 6.348,17.64,20.2,304500.0 416 | 6.833,19.69,20.2,296100.0 417 | 6.425,12.03,20.2,338100.0 418 | 6.436,16.22,20.2,300300.0 419 | 6.208,15.17,20.2,245700.0 420 | 6.629,23.27,20.2,281400.0 421 | 6.461,18.05,20.2,201600.0 422 | 6.152,26.45,20.2,182700.0 423 | 5.935,34.02,20.2,176400.0 424 | 5.627,22.88,20.2,268800.0 425 | 5.818,22.11,20.2,220500.0 426 | 6.406,19.52,20.2,359100.0 427 | 6.219,16.59,20.2,386400.0 428 | 6.485,18.85,20.2,323400.0 429 | 5.854,23.79,20.2,226800.0 430 | 6.459,23.98,20.2,247800.0 431 | 6.341,17.79,20.2,312900.0 432 | 6.251,16.44,20.2,264600.0 433 | 6.185,18.13,20.2,296100.0 434 | 6.417,19.31,20.2,273000.0 435 | 6.749,17.44,20.2,281400.0 436 | 6.655,17.73,20.2,319200.0 437 | 6.297,17.27,20.2,338100.0 438 | 
7.393,16.74,20.2,373800.0 439 | 6.728,18.71,20.2,312900.0 440 | 6.525,18.13,20.2,296100.0 441 | 5.976,19.01,20.2,266700.0 442 | 5.936,16.94,20.2,283500.0 443 | 6.301,16.23,20.2,312900.0 444 | 6.081,14.7,20.2,420000.0 445 | 6.701,16.42,20.2,344400.0 446 | 6.376,14.65,20.2,371700.0 447 | 6.317,13.99,20.2,409500.0 448 | 6.513,10.29,20.2,424200.0 449 | 6.209,13.22,20.2,449400.0 450 | 5.759,14.13,20.2,417900.0 451 | 5.952,17.15,20.2,399000.0 452 | 6.003,21.32,20.2,401100.0 453 | 5.926,18.13,20.2,401100.0 454 | 5.713,14.76,20.2,422100.0 455 | 6.167,16.29,20.2,417900.0 456 | 6.229,12.87,20.2,411600.0 457 | 6.437,14.36,20.2,487200.0 458 | 6.98,11.66,20.2,625800.0 459 | 5.427,18.14,20.2,289800.0 460 | 6.162,24.1,20.2,279300.0 461 | 6.484,18.68,20.2,350700.0 462 | 5.304,24.91,20.2,252000.0 463 | 6.185,18.03,20.2,306600.0 464 | 6.229,13.11,20.2,449400.0 465 | 6.242,10.74,20.2,483000.0 466 | 6.75,7.74,20.2,497700.0 467 | 7.061,7.01,20.2,525000.0 468 | 5.762,10.42,20.2,457800.0 469 | 5.871,13.34,20.2,432600.0 470 | 6.312,10.58,20.2,445200.0 471 | 6.114,14.98,20.2,401100.0 472 | 5.905,11.45,20.2,432600.0 473 | 5.454,18.06,20.1,319200.0 474 | 5.414,23.97,20.1,147000.0 475 | 5.093,29.68,20.1,170100.0 476 | 5.983,18.07,20.1,285600.0 477 | 5.983,13.35,20.1,422100.0 478 | 5.707,12.01,19.2,457800.0 479 | 5.926,13.59,19.2,514500.0 480 | 5.67,17.6,19.2,485100.0 481 | 5.39,21.14,19.2,413700.0 482 | 5.794,14.1,19.2,384300.0 483 | 6.019,12.92,19.2,445200.0 484 | 5.569,15.1,19.2,367500.0 485 | 6.027,14.33,19.2,352800.0 486 | 6.593,9.67,21.0,470400.0 487 | 6.12,9.08,21.0,432600.0 488 | 6.976,5.64,21.0,501900.0 489 | 6.794,6.48,21.0,462000.0 490 | 6.03,7.88,21.0,249900.0 491 | -------------------------------------------------------------------------------- /knn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Goals: Build K nearest neighbors model from 
scratch and compare with sklearn model" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Unsupervised nearest neighbors" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%matplotlib inline\n", 24 | "%reload_ext autoreload\n", 25 | "%autoreload 2" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 26, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import numpy as np\n", 35 | "import pandas as pd\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "from model.metrics import accuracy" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "Toy example" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "toy = np.array([[-3, -2], [-2, -1], [-1, -1],[1, 1], [2, 1], [3, 2]])" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "[]" 65 | ] 66 | }, 67 | "execution_count": 4, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | }, 71 | { 72 | "data": { 73 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEg5JREFUeJzt3X+M5Hd93/Hn6w6b5AgtSb0Jxr69ddVTFJfQkI7cRFQVFSY5LOQLSZCMVo3TJFrR1gqRWgk3J4FKdVKiSGmUgAKbYMVUW5woxOUqjoIdiAhqTbxnHWBzOLlaPXt7VrxAgVib1rrw7h8zFnvX2du9/X53Z3c+z4c0mu/3M5+b9+ejs+c13x9zn1QVkqT2HJj0ACRJk2EASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhr1kkkP4GpuuOGGmpubm/QwJGnfOHPmzFeqamYrffd0AMzNzbG8vDzpYUjSvpHkwlb7egpIkhplAEhSowwASWqUASBJjTIAJKlRnQMgyeEkn05yLskTSd4xpk+S/GaS80m+kOSHu9aVpKmztARzc3DgwPB5aWlHy/VxG+gl4F9X1WNJXg6cSfJQVX1pXZ83AUdHj38E/PboWZIEww/7hQVYWxvuX7gw3AeYn9+Rkp2PAKrq2ap6bLT9V8A54KYruh0HPlRDjwCvSHJj19qSNDVOnPj2h/+L1taG7Tuk12sASeaA1wKfu+Klm4Bn1u2v8P+HxIvvsZBkOcny6upqn8OTpL3r6aevrb0HvQVAku8CPgL8UlV988qXx/yRsavRV9ViVQ2qajAzs6VfM0vS/jc7e23tPeglAJJcx/DDf6mq/mhMlxXg8Lr9m4GLfdSWpKlw8iQcOnR526FDw/Yd0sddQAE+CJyrql/foNsp4GdGdwP9CPCNqnq2a21Jmhrz87C4CEeOQDJ8XlzcsQvA0M9dQK8D/hnwxSRnR22/DMwCVNX7gdPAHcB5YA345z3UlaTpMj+/ox/4V+ocAFX1Wcaf41/fp4B/1bWWJKk//hJYkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRvW1JOR9SZ5L8vgGr78+yTeSnB093tVHXUnS9vWxIhjA7wHvBT50lT5/WlVv7qmeJKmjXo4AquozwNf6eC9J0u7YzWsAP5rk80k+nuTv72JdSdIYfZ0C2sxjwJGqej7JHcB/Bo6O65hkAVgAmJ2d3aXhSVJ7duUIoKq+WVXPj7ZPA9cluWGDvotVNaiqwczMzG4MT5KatCsBkOSVSTLavm1U96u7UVuSNF4vp4CSfBh4PXBDkhXg3cB1AFX1fuCngX+R5BLw18BdVVV91JYkbU8vAVBVb9vk9fcyvE1UkrRH+EtgSWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjegmAJPcleS7J4xu8niS/meR8ki8k+eE+6kqStq+vI4DfA45d5fU3AUdHjwXgt3uqK2m7lpZgbg4OHBg+Ly1NekTbMy3zmIC+loT8TJK5q3Q5DnxotA7wI0lekeTGqnq2j/qSrtHSEiwswNracP/CheE+wPz85MZ1raZlHhOyW9cAbgKeWbe/MmqTNAknTnz7Q/NFa2vD9v1kWuYxIbsVABnTVmM7JgtJlpMsr66u7vCwpEY9/fS1te9V0zKPCdmtAFgBDq/bvxm4OK5jVS1W1aCqBjMzM7syOKk5s7PX1r5XTcs8JmS3AuAU8DOju4F+BPiG5/+lCTp5Eg4durzt0KFh+34yLfOYkL5uA/0w8N+B70+ykuTnk7w9ydtHXU4DTwHngd8B/mUfdSVt0/w8LC7CkSOQDJ8XF/ffhdNpmceEZHhjzt40GAxqeXl50sO
QpH0jyZmqGmylr78ElqRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1qq8VwY4leTLJ+ST3jnn9Z5OsJjk7evxCH3UlSdv3kq5vkOQg8D7gjQwXf380yamq+tIVXX+/qu7pWk+S1I8+jgBuA85X1VNV9QLwAHC8h/eVJO2gPgLgJuCZdfsro7Yr/VSSLyT5wySHN3qzJAtJlpMsr66u9jA8SdI4fQRAxrRdudL8fwHmquo1wMPA/Ru9WVUtVtWgqgYzMzM9DE+SNE4fAbACrP9GfzNwcX2HqvpqVf3f0e7vAP+wh7qSpA76CIBHgaNJbklyPXAXcGp9hyQ3rtu9EzjXQ11JUged7wKqqktJ7gE+ARwE7quqJ5K8B1iuqlPALya5E7gEfA342a51JUndpOrK0/V7x2AwqOXl5UkPQ5L2jSRnqmqwlb7+EliSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRG9RIASY4leTLJ+ST3jnn9pUl+f/T655LM9VFXkrR9nQMgyUHgfcCbgFuBtyW59YpuPw/876r6e8B/AH61a11JUjd9HAHcBpyvqqeq6gXgAeD4FX2OA/ePtv8QeEOS9FBbkrRNfQTATcAz6/ZXRm1j+1TVJeAbwN/pobYkaZv6CIBx3+SvXGh4K32GHZOFJMtJlldXVzsPTpI0Xh8BsAIcXrd/M3Bxoz5JXgL8beBr496sqharalBVg5mZmR6GJ0kap48AeBQ4muSWJNcDdwGnruhzCrh7tP3TwKeqauwRgCRpd7yk6xtU1aUk9wCfAA4C91XVE0neAyxX1Sngg8B/THKe4Tf/u7rWlSR10zkAAKrqNHD6irZ3rdv+P8Bb+6glSeqHvwSWpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDWqUwAk+Z4kDyX5i9Hzd2/Q72+SnB09rlwuUpI0AV2PAO4F/riqjgJ/PNof56+r6odGjzs71pQk9aBrABwH7h9t3w/8RMf3kyTtkq4B8H1V9SzA6Pl7N+j3HUmWkzySxJCQpD1g00XhkzwMvHLMSyeuoc5sVV1M8neBTyX5YlX9jw3qLQALALOzs9dQQpJ0LTYNgKq6faPXkvxlkhur6tkkNwLPbfAeF0fPTyX5E+C1wNgAqKpFYBFgMBjUpjOQJG1L11NAp4C7R9t3Ax+9skOS707y0tH2DcDrgC91rCtJ6qhrAPwK8MYkfwG8cbRPkkGS3x31+QFgOcnngU8Dv1JVBoAkTdimp4Cupqq+CrxhTPsy8Auj7f8G/GCXOpKk/vlLYElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSozoFQJK3JnkiybeSDK7S71iSJ5OcT3Jvl5pNWVqCuTk4cGD4vLQ06RFtz7TMQ5oynVYEAx4HfhL4wEYdkhwE3sdwycgV4NEkp1wWchNLS7CwAGtrw/0LF4b7APPzkxvXtZqWeUhTqNMRQFWdq6onN+l2G3C+qp6qqheAB4DjXeo24cSJb39ovmhtbdi+n0zLPKQptBvXAG4Cnlm3vzJqGyvJQpLlJMurq6s7Prg96+mnr619r5qWeUhTaNMASPJwksfHPLb6LT5j2mqjzlW1WFWDqhrMzMxsscQUmp29tva9alrmIU2hTQOgqm6vqlePeXx0izVWgMPr9m8GLm5nsE05eRIOHbq87dChYft+Mi3zkKbQbpwCehQ4muSWJNcDdwGndqHu/jY/D4uLcOQIJMPnxcX9d+F0WuYhTaFUbXg2ZvM/nLwF+C1gBvg
6cLaqfjzJq4Dfrao7Rv3uAH4DOAjcV1Vb+vo3GAxqeXl52+OTpNYkOVNVG96Wv16n20Cr6kHgwTHtF4E71u2fBk53qSVJ6pe/BJakRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNapTACR5a5InknwryYYr0CT5n0m+mORsEpf4kqQ9oNOKYMDjwE8CH9hC339aVV/pWE+S1JOuS0KeA0jSz2gkSbtmt64BFPDJJGeSLFytY5KFJMtJlldXV3dpeJLUnk2PAJI8DLxyzEsnquqjW6zzuqq6mOR7gYeSfLmqPjOuY1UtAosAg8Ggtvj+kqRrtGkAVNXtXYtU1cXR83NJHgRuA8YGgCRpd+z4KaAkL0vy8he3gR9jePFYkjRBXW8DfUuSFeBHgY8l+cSo/VVJTo+6fR/w2SSfB/4M+FhV/dcudSVJ3XW9C+hB4MEx7ReBO0bbTwH/oEsdSVL//CWwJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjeq6IMyvJflyki8keTDJKzbodyzJk0nOJ7m3S01JUj+6HgE8BLy6ql4D/Dnwb6/skOQg8D7gTcCtwNuS3Nqx7saWlmBuDg4cGD4vLe1YKUnazzoFQFV9sqoujXYfAW4e0+024HxVPVVVLwAPAMe71N3Q0hIsLMCFC1A1fF5YMAQkaYw+rwH8HPDxMe03Ac+s218ZtfXvxAlYW7u8bW1t2C5JusymawIneRh45ZiXTlTVR0d9TgCXgHFftTOmra5SbwFYAJidnd1seJd7+ulra5ekhm0aAFV1+9VeT3I38GbgDVU17oN9BTi8bv9m4OJV6i0CiwCDwWDDoBhrdnZ42mdcuyTpMl3vAjoGvBO4s6rWNuj2KHA0yS1JrgfuAk51qbuhkyfh0KHL2w4dGrZLki7T9RrAe4GXAw8lOZvk/QBJXpXkNMDoIvE9wCeAc8AfVNUTHeuONz8Pi4tw5Agkw+fFxWG7JOkyGX/WZm8YDAa1vLw86WFI0r6R5ExVDbbS118CS1KjDABJapQBIEmNMgAkqVEGgCQ1ak/fBZRkFRjzy64tuQH4So/DmaRpmcu0zAOcy140LfOAbnM5UlUzW+m4pwOgiyTLW70Vaq+blrlMyzzAuexF0zIP2L25eApIkhplAEhSo6Y5ABYnPYAeTctcpmUe4Fz2ommZB+zSXKb2GoAk6eqm+QhAknQVUx0ASf79aMH6s0k+meRVkx7TdiT5tSRfHs3lwSSvmPSYtivJW5M8keRbSfbdHRtJjiV5Msn5JPdOejxdJLkvyXNJHp/0WLpIcjjJp5OcG/239Y5Jj2m7knxHkj9L8vnRXP7djtab5lNASf5WVX1ztP2LwK1V9fYJD+uaJfkx4FNVdSnJrwJU1TsnPKxtSfIDwLeADwD/pqr2zT/3muQg8OfAGxkudPQo8Laq+tJEB7ZNSf4J8Dzwoap69aTHs11JbgRurKrHkrwcOAP8xH78e0kS4GVV9XyS64DPAu+oqkd2ot5UHwG8+OE/8jKushTlXlZVnxytqwDwCMNV1falqjpXVU9OehzbdBtwvqqeqqoXgAeA4xMe07ZV1WeAr016HF1V1bNV9dho+68YrjuyM+uO77Aaen60e93osWOfW1MdAABJTiZ5BpgH3jXp8fTg54CPT3oQjboJeGbd/gr79INmWiWZA14LfG6yI9m+JAeTnAWeAx6qqh2by74PgCQPJ3l8zOM4QFWdqKrDDBesv2eyo93YZvMY9TkBXGI4lz1rK3PZpzKmbV8eVU6jJN8FfAT4pSuO/veVqvqbqvohhkf6tyXZsdNzmy4Kv9dttmj9Ov8J+Bjw7h0czrZtNo8kdwNvBt5Qe/zCzTX8new3K8Dhdfs3AxcnNBatMzpf/hF
gqar+aNLj6UNVfT3JnwDHgB25UL/vjwCuJsnRdbt3Al+e1Fi6SHIMeCdwZ1WtTXo8DXsUOJrkliTXA3cBpyY8puaNLpx+EDhXVb8+6fF0kWTmxbv8knwncDs7+Lk17XcBfQT4foZ3nVwA3l5V/2uyo7p2Sc4DLwW+Omp6ZD/ezQSQ5C3AbwEzwNeBs1X145Md1dYluQP4DeAgcF9VnZzwkLYtyYeB1zP8lyf/Enh3VX1wooPahiT/GPhT4IsM/18H+OWqOj25UW1PktcA9zP87+sA8AdV9Z4dqzfNASBJ2thUnwKSJG3MAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVH/DzE8KcwjdPNkAAAAAElFTkSuQmCC\n", 74 | "text/plain": [ 75 | "
" 76 | ] 77 | }, 78 | "metadata": {}, 79 | "output_type": "display_data" 80 | } 81 | ], 82 | "source": [ 83 | "plt.plot(toy[:,0],toy[:,1],'ro')" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "## 1. Sklearn unsupervised nearest neighbors" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 6, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "[[0. 1.41421356 2.23606798]\n", 103 | " [0. 1. 1.41421356]\n", 104 | " [0. 1. 2.23606798]\n", 105 | " [0. 1. 2.23606798]\n", 106 | " [0. 1. 1.41421356]\n", 107 | " [0. 1.41421356 2.23606798]]\n", 108 | "[[0 1 2]\n", 109 | " [1 2 0]\n", 110 | " [2 1 0]\n", 111 | " [3 4 5]\n", 112 | " [4 3 5]\n", 113 | " [5 4 3]]\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "from sklearn.neighbors import NearestNeighbors\n", 119 | "nbrs = NearestNeighbors(n_neighbors=3, algorithm='ball_tree').fit(toy) # 3 neighbors, including himself\n", 120 | "dist,idx=nbrs.kneighbors(toy)\n", 121 | "print(dist)\n", 122 | "print(idx)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 7, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "(array([[0.70710678, 0.70710678, 1.58113883]]), array([[2, 1, 0]], dtype=int64))\n" 135 | ] 136 | }, 137 | { 138 | "data": { 139 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEnNJREFUeJzt3X+MpVd93/H3ZxebZAgtST0Jxt7ZcdVVFJfSkF65iagqKkyyWMgb0iAZjRqnSTQirRUitRJuVgKVaqVEkdIoAQUmwYqJpjhRiMtWLAU7EBGUmnjWWsBmcbK12PV0rXiAArEmrbXxt3/cazE7vbPz4z4zd+ae90u6ep5znrP3nKO17+c+P+6eVBWSpPYcGvcAJEnjYQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGvWScQ/gWm644YaanZ0d9zAk6cA4e/bsV6tqeitt93UAzM7OsrS0NO5hSNKBkeTiVtt6CUiSGmUASFKjDABJapQBIEmNMgAkqVEjB0CSI0k+neR8kieSvGNImyT5jSQXknwhyQ+N2q8kTZzFRZidhUOH+tvFxV3trovHQK8A/66qHkvycuBskoeq6ktr2rwJODZ4/VPgtwZbSRL0P+zn52F1tV++eLFfBpib25UuRz4DqKpnquqxwf5fA+eBm9Y1OwF8qPoeAV6R5MZR+5akiXHy5Lc//F+0utqv3yWd3gNIMgu8FvjcukM3AU+vKS/z/4fEi+8xn2QpydLKykqXw5Ok/evSpe3Vd6CzAEjyXcBHgF+sqm+tPzzkjwxdjb6qFqqqV1W96ekt/ZpZkg6+mZnt1XegkwBIch39D//FqvqjIU2WgSNryjcDl7voW5ImwqlTMDV1dd3UVL9+l3TxFFCADwLnq+rXNmh2GvipwdNAPwx8s6qeGbVvSZoYc3OwsABHj0LS3y4s7NoNYOjmKaDXAf8K+GKSc4O6XwJmAKrq/cAZ4A7gArAK/OsO+pWkyTI3t6sf+OuNHABV9VmGX+Nf26aAfztqX5Kk7vhLYElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhrV1ZKQ9yV5NsnjGxx/fZJvJjk3eL2ri34lSTvXxYpgAL8LvBf40DXa/GlVvbmj/iRJI+rkDKCqPgN8vYv3kiTtjb28B/AjST6f5ONJ/uEe9itJGqKrS0CbeQw4WlXPJbkD+K/AsWENk8wD8wAzMzN7NDxJas+enAFU1beq6rnB/hnguiQ3bNB2oap6VdWbnp7ei+FJUpP2JACSvDJJBvu3Dfr92l70LUkarpNLQEk+DLweuCHJMvBu4DqAqno/8JPAzye5AvwNcFdVVRd9S5J2ppMAqKq3bXL8vfQfE5Uk7RP+EliSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1KhOAiDJfUmeTfL4BseT5DeSXEjyhSQ/1EW/kqSd6+oM4HeB49c4/ibg2OA1D/xWR/1K2qnFRZidhUOH+tvFxXGPaGcmZR5j0NWSkJ9JMnuNJieADw3WAX4kySuS3FhVz3TRv6RtWlyE+XlYXe2XL17slwHm5sY3ru2alHmMyV7dA7gJeHpNeXlQJ2kcTp789ofmi1ZX+/UHyaTMY0z2KgAypK6GNkzmkywlWVpZWdnlYUmNunRpe/X71aTMY0z2KgCWgSNryjcDl4c1rKqFqupVVW96enpPBic1Z2Zme/X71aTMY0z2KgBOAz81eBroh4Fvev1fGqNTp2Bq6uq6qal+/UEyKfMYk64eA/0w8D+A70+ynORnk7w9ydsHTc4ATwEXgN8G/k0X/Uraobk5WFiAo0ch6W8XFg7ejdNJmceYpP9gzv7U6/VqaWlp3MOQpAM
jydmq6m2lrb8ElqRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1qqsVwY4neTLJhST3Djn+00lWkpwbvH6ui34lSTv3klHfIMlh4H3AG+kv/v5oktNV9aV1TX+/qu4ZtT9JUje6OAO4DbhQVU9V1fPAA8CJDt5XkrSLugiAm4Cn15SXB3Xr/cskX0jyh0mObPRmSeaTLCVZWllZ6WB4kqRhugiADKlbv9L8fwNmq+o1wMPA/Ru9WVUtVFWvqnrT09MdDE+SNEwXAbAMrP1GfzNweW2DqvpaVf3fQfG3gX/SQb+SpBF0EQCPAseS3JLkeuAu4PTaBkluXFO8EzjfQb+SpBGM/BRQVV1Jcg/wCeAwcF9VPZHkPcBSVZ0GfiHJncAV4OvAT4/aryRpNKlaf7l+/+j1erW0tDTuYUjSgZHkbFX1ttLWXwJLUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUqE4CIMnxJE8muZDk3iHHX5rk9wfHP5dktot+JUk7N3IAJDkMvA94E3Ar8LYkt65r9rPA/66qfwD8Z+BXRu1XkjSaLs4AbgMuVNVTVfU88ABwYl2bE8D9g/0/BN6QJB30LUnaoS4C4Cbg6TXl5UHd0DZVdQX4JvD3OuhbkrRDXQTAsG/y6xca3kqbfsNkPslSkqWVlZWRBydJGq6LAFgGjqwp3wxc3qhNkpcAfxf4+rA3q6qFqupVVW96erqD4UmShukiAB4FjiW5Jcn1wF3A6XVtTgN3D/Z/EvhUVQ09A5Ak7Y2XjPoGVXUlyT3AJ4DDwH1V9USS9wBLVXUa+CDwe0ku0P/mf9eo/UqSRjNyAABU1RngzLq6d63Z/z/AW7voS5LUDX8JLEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElq1EgBkOR7kjyU5C8H2+/eoN3fJjk3eK1fLlKSNAajngHcC/xxVR0D/nhQHuZvquoHB687R+xTktSBUQPgBHD/YP9+4MdHfD9J0h4ZNQC+r6qeARhsv3eDdt+RZCnJI0kMCUnaBzZdFD7Jw8Arhxw6uY1+ZqrqcpK/D3wqyRer6n9u0N88MA8wMzOzjS4kSduxaQBU1e0bHUvyV0lurKpnktwIPLvBe1webJ9K8ifAa4GhAVBVC8ACQK/Xq01nIEnakVEvAZ0G7h7s3w18dH2DJN+d5KWD/RuA1wFfGrFfSdKIRg2AXwbemOQvgTcOyiTpJfmdQZsfAJaSfB74NPDLVWUASNKYbXoJ6Fqq6mvAG4bULwE/N9j/M+AfjdKPJKl7/hJYkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktSokQIgyVuTPJHkhSS9a7Q7nuTJJBeS3DtKn01ZXITZWTh0qL9dXBz3iHZmUuYhTZiRVgQDHgd+AvjARg2SHAbeR3/JyGXg0SSnXRZyE4uLMD8Pq6v98sWL/TLA3Nz4xrVdkzIPaQKNdAZQVeer6slNmt0GXKiqp6rqeeAB4MQo/Tbh5Mlvf2i+aHW1X3+QTMo8pAm0F/cAbgKeXlNeHtQNlWQ+yVKSpZWVlV0f3L516dL26verSZmHNIE2DYAkDyd5fMhrq9/iM6SuNmpcVQtV1auq3vT09Ba7mEAzM9ur368mZR7SBNo0AKrq9qp69ZDXR7fYxzJwZE35ZuDyTgbblFOnYGrq6rqpqX79QTIp85Am0F5cAnoUOJbkliTXA3cBp/eg34Ntbg4WFuDoUUj624WFg3fjdFLmIU2gVG14NWbzP5y8BfhNYBr4BnC
uqn4syauA36mqOwbt7gB+HTgM3FdVW/r61+v1amlpacfjk6TWJDlbVRs+lr/WSI+BVtWDwIND6i8Dd6wpnwHOjNKXJKlb/hJYkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktSokQIgyVuTPJHkhSQbrkCT5CtJvpjkXBKX+NpHFhdhdhYOHepvFxfHPSJJe2WkFcGAx4GfAD6whbb/oqq+OmJ/6tDiIszPw+pqv3zxYr8MLtkrtWCkM4CqOl9VT3Y1GO2tkye//eH/otXVfr2kybdX9wAK+GSSs0nmr9UwyXySpSRLKysrezS8Nl26tL16SZNl00tASR4GXjnk0Mmq+ugW+3ldVV1O8r3AQ0m+XFWfGdawqhaABYBer1dbfH/twMxM/7LPsHpJk2/TAKiq20ftpKouD7bPJnkQuA0YGgDaO6dOXX0PAGBqql8vafLt+iWgJC9L8vIX94EfpX/zWGM2NwcLC3D0KCT97cKCN4ClVoz6GOhbkiwDPwJ8LMknBvWvSnJm0Oz7gM8m+Tzw58DHquq/j9KvujM3B1/5CrzwQn/rh7/UjpEeA62qB4EHh9RfBu4Y7D8F/ONR+pEkdc9fAktSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktSoUReE+dUkX07yhSQPJnnFBu2OJ3kyyYUk947SpySpG6OeATwEvLqqXgP8BfAf1jdIchh4H/Am4FbgbUluHbHfjS0uwuwsHDrU3y4u7lpXknSQjRQAVfXJqroyKD4C3Dyk2W3Ahap6qqqeBx4ATozS74YWF/urnF+8CFX97fy8ISBJQ3R5D+BngI8Pqb8JeHpNeXlQ172TJ2F19eq61dV+vSTpKpuuCZzkYeCVQw6drKqPDtqcBK4Aw75qZ0hdXaO/eWAeYGZmZrPhXe3Spe3VS1LDNg2Aqrr9WseT3A28GXhDVQ37YF8Gjqwp3wxcvkZ/C8ACQK/X2zAohpqZ6V/2GVYvSbrKqE8BHQfeCdxZVasbNHsUOJbkliTXA3cBp0fpd0OnTsHU1NV1U1P9eknSVUa9B/Be4OXAQ0nOJXk/QJJXJTkDMLhJfA/wCeA88AdV9cSI/Q43NwcLC3D0KCT97cJCv16SdJUMv2qzP/R6vVpaWhr3MCTpwEhytqp6W2nrL4ElqVEGgCQ1ygCQpEYZAJLUKANAkhq1r58CSrICDPll15bcAHy1w+GM06TMZVLmAc5lP5qUecBoczlaVdNbabivA2AUSZa2+ijUfjcpc5mUeYBz2Y8mZR6wd3PxEpAkNcoAkKRGTXIALIx7AB2alLlMyjzAuexHkzIP2KO5TOw9AEnStU3yGYAk6RomOgCS/KfBgvXnknwyyavGPaadSPKrSb48mMuDSV4x7jHtVJK3JnkiyQtJDtwTG0mOJ3kyyYUk9457PKNIcl+SZ5M8Pu6xjCLJkSSfTnJ+8N/WO8Y9pp1K8h1J/jzJ5wdz+Y+72t8kXwJK8neq6luD/V8Abq2qt495WNuW5EeBT1XVlSS/AlBV7xzzsHYkyQ8ALwAfAP59VR2Yf+41yWHgL4A30l/o6FHgbVX1pbEObIeS/HPgOeBDVfXqcY9np5LcCNxYVY8leTlwFvjxg/j3kiTAy6rquSTXAZ8F3lFVj+xGfxN9BvDih//Ay7jGUpT7WVV9crCuAsAj9FdVO5Cq6nxVPTnucezQbcCFqnqqqp4HHgBOjHlMO1ZVnwG+Pu5xjKqqnqmqxwb7f01/3ZHdWXd8l1Xfc4PidYPXrn1uTXQAACQ5leRpYA5417jH04GfAT4+7kE06ibg6TXlZQ7oB82kSjILvBb43HhHsnNJDic5BzwLPFRVuzaXAx8ASR5O8viQ1wmAqjpZVUfoL1h/z3hHu7HN5jFocxK
4Qn8u+9ZW5nJAZUjdgTyrnERJvgv4CPCL687+D5Sq+tuq+kH6Z/q3Jdm1y3ObLgq/3222aP0a/wX4GPDuXRzOjm02jyR3A28G3lD7/MbNNv5ODppl4Mia8s3A5TGNRWsMrpd/BFisqj8a93i6UFXfSPInwHFgV27UH/gzgGtJcmxN8U7gy+MayyiSHAfeCdxZVavjHk/DHgWOJbklyfXAXcDpMY+peYMbpx8EzlfVr417PKNIMv3iU35JvhO4nV383Jr0p4A+Anw//adOLgJvr6r/Nd5RbV+SC8BLga8Nqh45iE8zASR5C/CbwDTwDeBcVf3YeEe1dUnuAH4dOAzcV1WnxjykHUvyYeD19P/lyb8C3l1VHxzroHYgyT8D/hT4Iv3/1wF+qarOjG9UO5PkNcD99P/7OgT8QVW9Z9f6m+QAkCRtbKIvAUmSNmYASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUqP8HKrVTm8232E8AAAAASUVORK5CYII=\n", 140 | "text/plain": [ 141 | "
" 142 | ] 143 | }, 144 | "metadata": {}, 145 | "output_type": "display_data" 146 | } 147 | ], 148 | "source": [ 149 | "test = [-1.5,-1.5]\n", 150 | "print(nbrs.kneighbors([test]))\n", 151 | "plt.plot(toy[:,0],toy[:,1],'ro')\n", 152 | "plt.plot(test[0],test[1],'bo')\n", 153 | "plt.show()" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "## 2. My Unsupervised KNN" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 11, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "from model.knn import CustomNearestNeighbor" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 12, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "c_nn = CustomNearestNeighbor(k=3)\n", 179 | "c_nn.fit(toy)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 18, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "c_dist,c_idx=c_nn.kneighbors(toy)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 16, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "text/plain": [ 199 | "array([[ True, True, True],\n", 200 | " [ True, True, True],\n", 201 | " [ True, True, True],\n", 202 | " [ True, True, True],\n", 203 | " [ True, True, True],\n", 204 | " [ True, True, True]])" 205 | ] 206 | }, 207 | "execution_count": 16, 208 | "metadata": {}, 209 | "output_type": "execute_result" 210 | } 211 | ], 212 | "source": [ 213 | "c_dist == dist" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 19, 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "data": { 223 | "text/plain": [ 224 | "array([[ True, True, True],\n", 225 | " [ True, True, True],\n", 226 | " [ True, True, True],\n", 227 | " [ True, True, True],\n", 228 | " [ True, True, True],\n", 229 | " [ True, True, True]])" 230 | ] 231 | }, 232 | "execution_count": 19, 233 | "metadata": {}, 234 | "output_type": 
"execute_result" 235 | } 236 | ], 237 | "source": [ 238 | "c_idx ==idx" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "Identical results to sklearn nearest neighbors" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "# Supervised nearest neighbors" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "We will use Fashion MNIST https://github.com/zalandoresearch/fashion-mnist for benchmarking" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 20, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "# Fashion MNIST\n", 269 | "def load_mnist(path, kind='train'):\n", 270 | " import os\n", 271 | " import struct\n", 272 | " import gzip\n", 273 | " import numpy as np\n", 274 | "\n", 275 | " \"\"\"Load MNIST data from `path`\"\"\"\n", 276 | " labels_path = os.path.join(path,\n", 277 | " '%s-labels-idx1-ubyte.gz'\n", 278 | " % kind)\n", 279 | " images_path = os.path.join(path,\n", 280 | " '%s-images-idx3-ubyte.gz'\n", 281 | " % kind)\n", 282 | "\n", 283 | " with gzip.open(labels_path, 'rb') as lbpath:\n", 284 | " struct.unpack('>II', lbpath.read(8))\n", 285 | " labels = np.frombuffer(lbpath.read(), dtype=np.uint8)\n", 286 | "\n", 287 | " with gzip.open(images_path, 'rb') as imgpath:\n", 288 | " struct.unpack(\">IIII\", imgpath.read(16))\n", 289 | " images = np.frombuffer(imgpath.read(), dtype=np.uint8).reshape(len(labels), 784)\n", 290 | "\n", 291 | " return images, labels" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 21, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "labels={\n", 301 | " 0:'T-shirt/top',\n", 302 | " 1:'Trouser',\n", 303 | " 2:'Pullover',\n", 304 | " 3:'Dress',\n", 305 | " 4:'Coat',\n", 306 | " 5:'Sandal',\n", 307 | " 6:'Shirt',\n", 308 | " 7:'Sneaker',\n", 309 | " 8:'Bag',\n", 310 | " 9:'Ankle boot',\n", 311 
| "}" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 22, 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "Training set size: (60000, 784)\n", 324 | "Testing set size: (10000, 784)\n" 325 | ] 326 | } 327 | ], 328 | "source": [ 329 | "X_train, y_train = load_mnist('data/fashion', kind='train')\n", 330 | "X_test, y_test = load_mnist('data/fashion', kind='t10k')\n", 331 | "y_train,y_test = y_train.astype(np.int64),y_test.astype(np.int64)\n", 332 | "print('Training set size: {}'.format(X_train.shape))\n", 333 | "print('Testing set size: {}'.format(X_test.shape))" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 23, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "#only take 5000 images for simplification\n", 343 | "np.random.seed(42)\n", 344 | "train_idx =np.random.permutation(5000)\n", 345 | "test_idx = np.random.permutation(1000)\n", 346 | "X_train = X_train[train_idx]\n", 347 | "y_train = y_train[train_idx]\n", 348 | "X_test = X_test[test_idx]\n", 349 | "y_test = y_test[test_idx]" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "Preprocess and split dataset" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 24, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "name": "stderr", 366 | "output_type": "stream", 367 | "text": [ 368 | "C:\\Users\\qtran\\AppData\\Local\\Continuum\\Miniconda3\\envs\\fastai-cpu\\lib\\site-packages\\sklearn\\utils\\validation.py:475: DataConversionWarning: Data with input dtype uint8 was converted to float64 by StandardScaler.\n", 369 | " warnings.warn(msg, DataConversionWarning)\n" 370 | ] 371 | } 372 | ], 373 | "source": [ 374 | "# from sklearn.model_selection import train_test_split\n", 375 | "from sklearn.metrics import accuracy_score\n", 376 | "from sklearn.preprocessing import StandardScaler\n", 377 
| "\n", 378 | "s = StandardScaler()\n", 379 | "X_train = s.fit_transform(X_train)\n", 380 | "X_test = s.transform(X_test)\n", 381 | "y_train = np.array(y_train)\n", 382 | "y_test = np.array(y_test)" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": {}, 388 | "source": [ 389 | "## 1. Sklearn KNeighborsClassifier" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 27, 395 | "metadata": {}, 396 | "outputs": [ 397 | { 398 | "data": { 399 | "text/plain": [ 400 | "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n", 401 | " metric_params=None, n_jobs=1, n_neighbors=10, p=2,\n", 402 | " weights='uniform')" 403 | ] 404 | }, 405 | "execution_count": 27, 406 | "metadata": {}, 407 | "output_type": "execute_result" 408 | } 409 | ], 410 | "source": [ 411 | "from sklearn.neighbors import KNeighborsClassifier\n", 412 | "\n", 413 | "nn = KNeighborsClassifier(n_neighbors=10)\n", 414 | "nn.fit(X_train,y_train)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 28, 420 | "metadata": {}, 421 | "outputs": [ 422 | { 423 | "data": { 424 | "text/plain": [ 425 | "0.807" 426 | ] 427 | }, 428 | "execution_count": 28, 429 | "metadata": {}, 430 | "output_type": "execute_result" 431 | } 432 | ], 433 | "source": [ 434 | "pred = nn.predict(X_test)\n", 435 | "accuracy(y_test,pred)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": {}, 441 | "source": [ 442 | "## 2. 
My supervised KNN" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 29, 448 | "metadata": {}, 449 | "outputs": [ 450 | { 451 | "name": "stdout", 452 | "output_type": "stream", 453 | "text": [ 454 | "(5000, 784)\n", 455 | "(1000, 784)\n" 456 | ] 457 | } 458 | ], 459 | "source": [ 460 | "print(X_train.shape)\n", 461 | "print(X_test.shape)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 32, 467 | "metadata": {}, 468 | "outputs": [], 469 | "source": [ 470 | "c_nn = CustomNearestNeighbor(k=10)\n", 471 | "c_nn.fit(X_train,y_train)" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "### uniform weights among class" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 33, 484 | "metadata": { 485 | "scrolled": false 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "y_pred,counter,class_sorted=c_nn.predict_classification(X_test,weighted=False)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 34, 495 | "metadata": {}, 496 | "outputs": [ 497 | { 498 | "data": { 499 | "text/plain": [ 500 | "0.804" 501 | ] 502 | }, 503 | "execution_count": 34, 504 | "metadata": {}, 505 | "output_type": "execute_result" 506 | } 507 | ], 508 | "source": [ 509 | "accuracy_score(y_test,y_pred)" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "metadata": {}, 515 | "source": [ 516 | "### Weight points by the inverse of distance" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": 35, 522 | "metadata": {}, 523 | "outputs": [], 524 | "source": [ 525 | "y_pred,weighted_counter,class_sorted=c_nn.predict_classification(X_test,weighted=True)" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 36, 531 | "metadata": {}, 532 | "outputs": [ 533 | { 534 | "data": { 535 | "text/plain": [ 536 | "0.809" 537 | ] 538 | }, 539 | "execution_count": 36, 540 | "metadata": {}, 541 | 
"output_type": "execute_result" 542 | } 543 | ], 544 | "source": [ 545 | "accuracy_score(y_test,y_pred)" 546 | ] 547 | }, 548 | { 549 | "cell_type": "markdown", 550 | "metadata": {}, 551 | "source": [ 552 | "Not bad! With weighted inverse-Euclidean distance, there is small boost in accuracy on test set" 553 | ] 554 | }, 555 | { 556 | "cell_type": "markdown", 557 | "metadata": {}, 558 | "source": [ 559 | "### Evaluate results" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 37, 565 | "metadata": {}, 566 | "outputs": [], 567 | "source": [ 568 | "from matplotlib import pyplot as plt,cm\n", 569 | "def evaluate(idx):\n", 570 | " img = X_test[idx].reshape([28,28])\n", 571 | " plt.imshow(img,cmap=cm.binary)\n", 572 | " print(f'Weights (in order): {sorted(weighted_counter[idx],reverse=True)}')\n", 573 | " print(f'Predictions (in order): {[labels[i] for i in class_sorted[idx]]}')\n", 574 | " print('Actual: ' + labels[y_test[0]])\n" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 38, 580 | "metadata": {}, 581 | "outputs": [ 582 | { 583 | "name": "stdout", 584 | "output_type": "stream", 585 | "text": [ 586 | "Weights (in order): [0.3040748811530332, 0.28658216317523766, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n", 587 | "Predictions (in order): ['Coat', 'Shirt', 'Ankle boot', 'Bag', 'Sneaker', 'Sandal', 'Dress', 'Pullover', 'Trouser', 'T-shirt/top']\n", 588 | "Actual: Coat\n" 589 | ] 590 | }, 591 | { 592 | "data": { 593 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAE8tJREFUeJzt3X+IleeVB/DvyUQdnRl1RuNvd9KtISQEYpdBlmRZsjRp0k3BNNDQIRQXSm2gwhYK2SCE5p9CWGy7ISwFu5EaqLGFNhtDQtsQGtzCIhklNBo3aTBaJzNxHJ1xRmeiUU//mNfsaOY958593vu+157vB8SZe+573+e+d87cmTnPcx5RVRBRPDdUPQAiqgaTnygoJj9RUEx+oqCY/ERBMfmJgmLyEwXF5CcKislPFNSNZZ6ss7NTV61aVeYpCyMiuTFvluQNN9jfY63HruX4ixcv5sY++eQT81iPd+65c+ea8ZaWltyYd928+OXLl5PijVTVzNmBgQGMjIzYX1CZpOQXkQcAPAOgBcB/qerT1v1XrVqFXbt2WY+XMpakuMdKAu+LbP78+WbcS6AFCxaY8eHh4dzYRx99ZB5rfeMAgIULF5rxNWvW1H28d+7z58+b8cnJSTM+MTGRG/O+Hhr9jck6PuUbR29vb833rfvHfhFpAfCfAL4M4HYAvSJye72PR0TlSvmdfwOA91X1iKpeALAbwMZihkVEjZaS/KsBHJ/2eX9221VEZLOI9IlI38jISMLpiKhIKck/0y9Nn/llRVW3q2qPqvZ0dnYmnI6IipSS/P0A1k77fA2AgbThEFFZUpL/TQC3iMjnRGQugK8D2FPMsIio0eou9anqRRHZAuC3mCr17VDVQ95xjSrnpZbyrHo0ANx4Y/6lmjdvnnmsV8o7fvy4Gd+/f78Z//jjj3NjbW1t5rHedbt06ZIZP3jwoBm3SqTeNffi999/vxmfM2dObswrI3plSO+6pF5XS1FzCJLq/Kr6KoBXCxkJEZWK03uJgmLyEwXF5CcKislPFBSTnygoJj9RUKWu56+SVzO2asIA0NraWvex27ZtM+PWHAIAuPPOO8340NBQbiy1Xu09N29Z7blz53Jj3hwE7zV75ZVXzPiDDz5oxi2pS8RT6vgpx84G3/mJgmLyEwXF5CcKislPFBSTnygoJj9RUE1V6kspr3gtpr1ymrfs1ip5ecd6br31VjPulbysTrErVqwwj+3o6DDjo6OjZvzDDz8049bYvOe1ZMkSM+6VMa2l1qnt0r0lwSm8JbtFlQL5zk8UFJOfKCgmP1FQTH6ioJj8REEx+YmCYvITBVV6nT+l/XbKTrkprbkB4IMPPsiN7d271zx20aJFZtxqvV1L3Br7sWPHzGO9HYS96zY+Pm7G29vbc2PeNR8bG0s695NPPpkb8+ZWPPzww2bcW+rsbY2esnV5UVuP852fKCgmP1FQTH6ioJj8REEx+YmCYvITBcXkJwoqqc4vIkcBjAO4BOCiqvbUcEzK+XJj3pp6r1798ssvm/EDBw7kxrx151atGwCGh4fNuPfcrHkAixcvNo+9cOGCGT9z5owZ99pvL1iwoO5ze2vmvXkC3d3dubHBwUHz2GeffdaMP/bYY2bc27bdeu4p/Rtmk19FTPL5J1W1v3qJqOnwx36ioFKTXwH8TkT2i8jmIgZEROVI/bH/blUdEJFlAF4Tkf9T1asmumffFDYDwMqVKxNPR0RFSXrnV9WB7P8hAC8C2DDDfbarao+q9nR2dqacjogKVHfyi0ibiHRc+RjAlwAcLGpgRNRYKT/2LwfwYlZauBHALlX9TSGjIqKGqzv5VfUIAHvv6FnyeqWn9O33aqf79u0z41b/e2+9vjc2r3e+x6qlW1tkA3492uv77z23iYmJ3Jg3/8H7NdGbB2DFvT4GJ06cMOO7du0y448++qgZt9bse70AvGteK5b6iIJi8hMFxeQnCorJTxQUk58oKCY
/UVDX1RbdVonDK3+88cYb9QzpU1ZpyCvNeCUtrxWz1ybaW9qacuy6devM+JEjR8z46tWrc2PeFttey3KvTGmV07xze2VGr214yrJcbwvulGXx0/GdnygoJj9RUEx+oqCY/ERBMfmJgmLyEwXF5CcKqtQ6v4iY9XivfpmypPfdd9814ym19JT5CbUc77Hq2d4cA6+m7G2T7bXutq6rd829WrnHanluLTUG/NdkdHTUjI+MjJhxaxm497xZ5yeiJEx+oqCY/ERBMfmJgmLyEwXF5CcKislPFNR1tZ7finvHvvfee2bca1Ftrdn31pV7rDo9kHZdvLF5j+3Vq70twC1e+2xvHsDk5KQZt66rtz24V2v3tk3v7+8349a27l6vgZQ8mI7v/ERBMfmJgmLyEwXF5CcKislPFBSTnygoJj9RUG6dX0R2APgKgCFVvSO7rQvALwDcDOAogEdU1S4I///j1TvWJAMDA2a8u7vbjFvr+b018R6vb783DyClR4I3D8CrOXtaW1tzY16t3JOyVbX3mnlj8+KnT58249Y8gmZaz/8zAA9cc9sTAF5X1VsAvJ59TkTXETf5VXUvgGu/jW0EsDP7eCeAhwoeFxE1WL0/Ny1X1UEAyP5fVtyQiKgMDf+Dn4hsFpE+Eenzfg8iovLUm/wnRGQlAGT/D+XdUVW3q2qPqvZ0dXXVeToiKlq9yb8HwKbs400AXipmOERUFjf5ReQFAP8L4FYR6ReRbwJ4GsB9IvInAPdlnxPRdcSt86tqb07oiwWPJal+ee7cOTPu7fXusWqvZ86cMY/t6Oio+7FriafsKeDVuxcsWGDGPdbYvJ7/3ti9NflW/Pz58+axXq8Bb46BN6/EOt57vVPmN1z1OIU8ChFdd5j8REEx+YmCYvITBcXkJwqKyU8UVOmtu1PaDlvxd955xzw2tXxite5OXdLrndsql3lx75p6y4VT49bYvKXMqcbHx3Nj3lJl73l5JdChodxJrwDStptPWcJ91ePUfE8i+qvC5CcKislPFBSTnygoJj9RUEx+oqCY/ERBNdUW3R6rhum1CLNaSAN+C2urVbM3h8CrKS9atMiMe3XflDbQ3tJVq1YO+NfNWrabet2819Saf5GyDXYt556YmDDjKecus3U3Ef0VYvITBcXkJwqKyU8UFJOfKCgmP1FQTH6ioEqv86e0LLbio6Oj5rFe+2yvjbRVM/bWds+ZM8eMey2ovZqyJbV1t1fHt/ocAHY93Ztj4D2295pbr4u3Xt97Tbwtur1eBdZ1914ztu4moiRMfqKgmPxEQTH5iYJi8hMFxeQnCorJTxSUW+cXkR0AvgJgSFXvyG57CsC3AJzM7rZVVV+t4bHMWr3Xn9461tuie/HixWbcq9taNeXu7m7zWI9XU/bmCaT0cfdq6am8uRsWb+zetutWn4Sbbrop6bG9r1Wvzn/q1Knc2JIlS8xjy1zP/zMAD8xw+49VdX32z018ImoubvKr6l4AdpscIrrupPzOv0VE/igiO0Sks7AREVEp6k3+nwD4PID1AAYB/DDvjiKyWUT6RKTP67NHROWpK/lV9YSqXlLVywB+CmCDcd/tqtqjqj1dXV31jpOIClZX8ovIymmffhXAwWKGQ0RlqaXU9wKAewAsFZF+AN8HcI+IrAegAI4C+HYDx0hEDeAmv6r2znDzc/WczKvzezVhq57t1WXb29vNuFfnt9ZfezVjb9355OSkGffWnls1ZW8OgdfnwJs/4fUasF4zb47B2bNnk+LLly/PjXl7JYyNjZlxr8+BF/f2DbBYdf7ZzAHgDD+ioJj8REEx+YmCYvITBcXkJwqKyU8UVOmtu1PKFFbZKHXLZe9469xWSQnwy2Xe8k9vSa9VCkxtA+0tXfVaf6eUdr2W6N51s8q3a9asMY89dOiQGfdKpN51sUrTXmm3KHznJwqKyU8UFJOfKCgmP1FQTH6ioJj8REEx+YmCKrXOr6pubdZi1ay92mjK0lPv8b3lwh6
v3u0ty7V4cwS8pdApS08Be9mtVwtP2f7be3yvlbv39eSN/fz582Y85TW1xjabOQJ85ycKislPFBSTnygoJj9RUEx+oqCY/ERBMfmJgip9Pb9Vh/TmAFjx1C2VPda4vdbdXk3Xq82mbHOdyqtXt7W1mXHruaX2CvDmAVjXfdmyZeax3tei95p48ysmJibqPnfKXJnp+M5PFBSTnygoJj9RUEx+oqCY/ERBMfmJgmLyEwXlFr9FZC2A5wGsAHAZwHZVfUZEugD8AsDNAI4CeERVR1IG49W7U/qZz2br4tmee+HCheaxqXsKeHXdlP0MPF6fAy9eZX96q5buref35hh4czescwPA8PBwbsx7zcqs818E8D1VvQ3A3wP4jojcDuAJAK+r6i0AXs8+J6LrhJv8qjqoqgeyj8cBHAawGsBGADuzu+0E8FCjBklExZvV7/wicjOALwDYB2C5qg4CU98gANjzJYmoqdSc/CLSDuBXAL6rqmOzOG6ziPSJSN/p06frGSMRNUBNyS8iczCV+D9X1V9nN58QkZVZfCWAoZmOVdXtqtqjqj1dXV1FjJmICuAmv0z9Kfo5AIdV9UfTQnsAbMo+3gTgpeKHR0SNUss617sBfAPA2yLyVnbbVgBPA/iliHwTwJ8BfK2WE6a0HbZKHKklrZTyirUVNJC+pDelzbO3bDa1rOSNzVp265U4veviLekdHR3NjXnP24t75/ZKhSdPnsyNpSzpnU351E1+Vf0DgLxX6Ys1n4mImgpn+BEFxeQnCorJTxQUk58oKCY/UVBMfqKgSm/dbdUovfqmVRf2WnfPnz/fjHtLMK26bmqL6dRls1abaG9sXl343LlzZnzevHlmPGVb9ZR5H4A9B8Gr43sty0dGklavJ21dztbdRJSEyU8UFJOfKCgmP1FQTH6ioJj8REEx+YmCKrXOr6pJW3RbNeOUNe+AX9e1Ht+r43u1cm8759bWVjNujc2rlXt1ei/uzY+wXlPveafOj7Be09T+D16d33vNrdbdKXX+2azn5zs/UVBMfqKgmPxEQTH5iYJi8hMFxeQnCorJTxRUU63n9+qbVp3f6tEOAEuXLjXjXk3ZqqV74/b6+nv9673Ht3rze337W1pazPjYmL0zm9cnwZqj4M3r8J63F7fWzHvnXrbM3nry2LFjZtybR2BtXcf1/ETUUEx+oqCY/ERBMfmJgmLyEwXF5CcKislPFJRb5xeRtQCeB7ACwGUA21X1GRF5CsC3AFzZaHyrqr7qPZ5Vo0ypd1v7nQN+Pdrr+2+tk/Zqut7YUnrfA/a6eG/Nu/fY3nPz+iiMj4/nxrx69uTkpBk/deqUGT9z5kzd5161apUZ37dvnxn35p0cOXIkN/b444+bxxZV569lks9FAN9T1QMi0gFgv4i8lsV+rKrbChkJEZXKTX5VHQQwmH08LiKHAaxu9MCIqLFm9Tu/iNwM4AsArvzMs0VE/igiO0SkM+eYzSLSJyJ91pRGIipXzckvIu0AfgXgu6o6BuAnAD4PYD2mfjL44UzHqep2Ve1R1Z6urq4ChkxERagp+UVkDqYS/+eq+msAUNUTqnpJVS8D+CmADY0bJhEVzU1+mfpz8HMADqvqj6bdvnLa3b4K4GDxwyOiRqnlr/13A/gGgLdF5K3stq0AekVkPQAFcBTAt70HUtWkMoVVljpw4IB5rLe0tb293Yxb5TSvHHbbbbeZ8cHBQTPuLbu1zu8tVfbKaV4pr62tzYxbr7f32B0dHWZ83bp1Znz9+vW5Ma/Ud9ddd5nxQ4cOmfH+/n4zvmLFityYlyNFte6u5a/9fwAwU9a5NX0ial6c4UcUFJOfKCgmP1FQTH6ioJj8REEx+YmCaqrW3bOpUV5r2zZ7ceHu3bvNeHd3txnfuHFjbsyrld97771m3Nuq2mv97c1hsFjtrWuJe0uhrRbY3lLmlHbqgD3/wRu3twR8y5YtZjxl7N6xKXkyHd/5iYJi8hMFxeQnCorJTxQUk58
oKCY/UVBMfqKgpKiaYU0nEzkJYPrexksBDJc2gNlp1rE167gAjq1eRY6tW1VvquWOpSb/Z04u0qeqPZUNwNCsY2vWcQEcW72qGht/7CcKislPFFTVyb+94vNbmnVszTougGOrVyVjq/R3fiKqTtXv/ERUkUqSX0QeEJF3ReR9EXmiijHkEZGjIvK2iLwlIn0Vj2WHiAyJyMFpt3WJyGsi8qfs/xm3SatobE+JyIfZtXtLRP65orGtFZHfi8hhETkkIv+a3V7ptTPGVcl1K/3HfhFpAfAegPsA9AN4E0Cvqr5T6kByiMhRAD2qWnlNWET+EcBZAM+r6h3Zbf8O4LSqPp194+xU1X9rkrE9BeBs1Ts3ZxvKrJy+szSAhwD8Cyq8dsa4HkEF162Kd/4NAN5X1SOqegHAbgD5nTICU9W9AK7d3XQjgJ3Zxzsx9cVTupyxNQVVHVTVA9nH4wCu7Cxd6bUzxlWJKpJ/NYDj0z7vR3Nt+a0Afici+0Vkc9WDmcHybNv0K9un57fKqYa7c3OZrtlZummuXT07XhetiuSfafefZio53K2qfwfgywC+k/14S7Wpaefmssyws3RTqHfH66JVkfz9ANZO+3wNgIEKxjEjVR3I/h8C8CKab/fhE1c2Sc3+H6p4PJ9qpp2bZ9pZGk1w7Zppx+sqkv9NALeIyOdEZC6ArwPYU8E4PkNE2rI/xEBE2gB8Cc23+/AeAJuyjzcBeKnCsVylWXZuzttZGhVfu2bb8bqSST5ZKeM/ALQA2KGqPyh9EDMQkb/F1Ls9MNXZeFeVYxORFwDcg6lVXycAfB/AfwP4JYC/AfBnAF9T1dL/8JYztnsw9aPrpzs3X/kdu+Sx/QOA/wHwNoAr7aK3Yur368qunTGuXlRw3TjDjygozvAjCorJTxQUk58oKCY/UVBMfqKgmPxEQTH5iYJi8hMF9RfIvWniLmVqwwAAAABJRU5ErkJggg==\n", 594 | "text/plain": [ 595 | "
" 596 | ] 597 | }, 598 | "metadata": {}, 599 | "output_type": "display_data" 600 | } 601 | ], 602 | "source": [ 603 | "# Correct prediction\n", 604 | "evaluate(0)" 605 | ] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "metadata": {}, 610 | "source": [ 611 | "Based on weighted_counter, my KNN classifier results in a close weight between class 4 (Coat) with weight .304 and 6 (Shirt) with weight .286." 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 39, 617 | "metadata": {}, 618 | "outputs": [ 619 | { 620 | "name": "stdout", 621 | "output_type": "stream", 622 | "text": [ 623 | "Weights (in order): [0.3754328103184546, 0.14807884211858424, 0.07343824003610218, 0.07208342067277425, 0.0704589146423518, 0.0, 0.0, 0.0, 0.0, 0.0]\n", 624 | "Predictions (in order): ['T-shirt/top', 'Coat', 'Pullover', 'Dress', 'Shirt', 'Ankle boot', 'Bag', 'Sneaker', 'Sandal', 'Trouser']\n", 625 | "Actual: Coat\n" 626 | ] 627 | }, 628 | { 629 | "data": { 630 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAE3VJREFUeJzt3V9sVPeVB/DvCX+NMYZgDObfpiERSkhUunJgJVarRFWqUFUiPDQqDxUrVaUPjbSV+rARL81LpWjVf3lYVaIbVCK1aZFaNjxEu42ijbJIqMQkiCTLsiBCwMaBJGAYmxhjOPvgS2WI7znj+c29d+j5fiRke8787v3N9Rxmxuf3R1QVRBTPPVV3gIiqweQnCorJTxQUk58oKCY/UVBMfqKgmPxEQTH5iYJi8hMFNbPMk82dO1c7OjoKObaIFHLcZpzbi99zj/1/8IwZMxo+ftHXxTv+8PBwbuzatWtm287OTjN+48YNM26NXvVGthY98rWo49dqNYyOjtb1S09KfhF5CsCLAGYA+DdVfcG6f0dHB7Zu3Wodr+G+eAmUmgTW8WfOtC/j7NmzzXhbW5sZX7BggRmfM2dOw+dOfRJ61/3gwYO5sVOnTpltN2/ebMZrtZoZ//zzz3Nj4+PjZtubN2+ace+6ee2teMp/TPv27TPbTtbw234RmQHgXwFsBvAwgG0i8nCjxyOicqV85t8A4KSqnlLVMQC/A7ClOd0ioqKlJP8KAGcn/dyf3XYbEdkhIn0i0jc6OppwOiJqppTkn+pD9Bc+jKjqLlXtVdXeuXPnJpyOiJopJfn7Aaya9PNKAOfSukNEZUlJ/rcBPCgiXxKR2QC+BWB/c7pFREVruNSnquMi8iyA/8REqW+3qn7gtUspuRVZs/bKddZHFq+c1t7ebsavX79uxg8dOmTGe3t7c2Nnz57NjQF+yWvRokVm3KrjA0B3d3duzPt9fvCB/XRat26dGbd+Z1euXD
HbppYCvTEIltQyYr2S6vyq+hqA15rSEyIqFYf3EgXF5CcKislPFBSTnygoJj9RUEx+oqBKnc8P2DXMlHnxXluvju9Nq7Xi3nx779gDAwNm3BsW/dFHH5lxS1dXlxn3as5e36zHPn/+fLPtuXP2gNEjR46Y8cceeyw35k1F9uaheHFv7Ib1fPXGGFh9n85YGL7yEwXF5CcKislPFBSTnygoJj9RUEx+oqBKL/WllOusuFdumzVrlhm3VsD14illQgA4c+aMGV+zZo0Zt0o/3nRj77p500e96zY2NpYbu3Tpktl24cKFZnxoaMiMW1OpvVKfp8ilv722XimwXnzlJwqKyU8UFJOfKCgmP1FQTH6ioJj8REEx+YmCuqvq/FZt1qtXe/VuL27Vs71zv/nmm2Z82bJlZtybNmv1zWvr1YxTa87Wdf3www/NtmvXrjXj3mOzdun1xid4j9tbmrvIpbtTjj0ZX/mJgmLyEwXF5CcKislPFBSTnygoJj9RUEx+oqCS6vwichpADcANAOOqmr9X9MT9C6vze/OzvaW7vfn+VnuvrTfvvKenp+FzA/Z187ai9sY3eNfVW6La2uLbq2d7zwdv2fFarZYb6+joMNt6tXTvuqWMn/DWUPDGldSrGYN8nlDVT5twHCIqEd/2EwWVmvwK4E8iclhEdjSjQ0RUjtS3/ZtU9ZyIdAN4XUT+V1XfmnyH7D+FHYC/PRMRlSfplV9Vz2VfLwDYB2DDFPfZpaq9qtrrLWRJROVpOPlFpF1EOm59D+BrAN5vVseIqFgpb/uXAtiXlWNmAvitqv5HU3pFRIVrOPlV9RSAL0+3nVW79WrKVtyrhXu1Ua+9de7h4WGzbXd3txn3xgl48ZR13L2ashdPWb/+5MmTZtsNG77wKfI23hgDK+7V6b3H7V1zr28pdX7r3Nyim4hcTH6ioJj8REEx+YmCYvITBcXkJwqq9KW7LV6pL6VM6JX6UkqB1hLRADBv3rykuDe91HrsqUtUe8tje1OGrW22ly5darb1eOU6a0qvVz71rnnqFHGrnOeVEadTzrPwlZ8oKCY/UVBMfqKgmPxEQTH5iYJi8hMFxeQnCqr0Or9Vk/bqlyl1/pRje/GxsTGz7aef2osbezXhJUuWmHFL6pRdb+m1q1evNnz8gYEBs633O2lvbzfjIyMjubHUadLeGAPvOWFdF286sPdcrxdf+YmCYvITBcXkJwqKyU8UFJOfKCgmP1FQTH6ioEqv86ds0W3NuU+dz58yTsCb0+7NiU/Zghuw5557j9ur81tz4gFgxYoVZtyqdx8/ftxse/nyZTPubbNt1fJTt+D2avEp8/m931lKDk3GV36ioJj8REEx+YmCYvITBcXkJwqKyU8UFJOfKCi3zi8iuwF8A8AFVX0ku+1eAL8HcB+A0wCeUdVLqZ1JnXNfVFuPV/P1auHeGIOUdfu9mrFXz/biXV1dZry/v7/htt6eAt6ce6u9N98+dWxGyj4QqceuVz2v/L8G8NQdtz0H4A1VfRDAG9nPRHQXcZNfVd8CcPGOm7cA2JN9vwfA003uFxEVrNHP/EtVdRAAsq/dzesSEZWh8D/4icgOEekTkT5vTzsiKk+jyX9eRHoAIPt6Ie+OqrpLVXtVtbetra3B0xFRszWa/PsBbM++3w7g1eZ0h4jK4ia/iLwC4CCAtSLSLyLfAfACgCdF5ASAJ7Ofiegu4tb5VXVbTuir0z2ZiCSt2+/VfYuUMmfeqyl7dV1rj3uvvTev3Lvm3rr8nZ2dZnxwcDA39sADD5htU8dmpNTDU2vtKetHeOdOyaHbjlP3PYnorwqTnygoJj9RUEx+oqCY/ERBMfmJgmqppbtT2hY5ZRcAzp49mxs7evSo2XbTpk1m3JvS65XrrOmn3pDq0dFRMz5v3jwz7l33BQsW5Ma8EubixYvNuDeVes6cObkxb5p0kaU8L55S6psOvvITBcXkJwqKyU8UFJOfKCgmP1FQTH6ioJj8REGVXu
e3pCzd7bVN2YIbAEZGRnJj3tLc3jbYXh3f67s1Zdir03s1ZY835Xf16tW5seXLl5ttvSm/hw8fNuPWY7927ZrZNnUZ+SKXmbfGGHBKLxG5mPxEQTH5iYJi8hMFxeQnCorJTxQUk58oqNLr/FbN2qt3pyxR7W01bc39Buzaqlfn92rp3pLkXu3WOn57e3uh5/aWJbe24ba27waAtWvXmnGvzm/x5vN7iqzzp6wFMK3zNOUoRHTXYfITBcXkJwqKyU8UFJOfKCgmP1FQTH6ioNw6v4jsBvANABdU9ZHstucBfBfAJ9nddqrqa3Ucy6xJe7V4q9aeMkagnvaW8fHxpGNb6+7XE+/o6MiNefVsr87vPTaP1ffjx4+bbRctWmTGveeL9di98Qme1Dp/yjbbzdqjop5X/l8DeGqK23+uquuzf27iE1FrcZNfVd8CcLGEvhBRiVI+8z8rIkdFZLeI2O/PiKjlNJr8vwSwBsB6AIMAfpp3RxHZISJ9ItLn7RtHROVpKPlV9byq3lDVmwB+BWCDcd9dqtqrqr1tbW2N9pOImqyh5BeRnkk/bgXwfnO6Q0RlqafU9wqAxwF0iUg/gB8BeFxE1gNQAKcBfK/APhJRAdzkV9VtU9z8UiMn8+r8KXuie229WrvX3up3ap3fW1vfW2vAqmen7gngPbbr16+bcatvXh2/p6fHjKf03dtLIVXKnPsixwjcdpy670lEf1WY/ERBMfmJgmLyEwXF5CcKislPFFRLLd3tlUdSSn1e3CuRWKMTL1++bLb1HpdXLkuZuupN2fWui9d3r5RYq9VyY0888YTZ1ttG22NtH+4tae6VAr3r6sWt4xddhryFr/xEQTH5iYJi8hMFxeQnCorJTxQUk58oKCY/UVCl1/lTaphWLd6rR6dMk6ynvSVlujDg14wtVq0b8H8f3hiEkZERM/7xxx/nxu6//36z7fDwsBn3WMvGeUuae/HUcQApba34dM7LV36ioJj8REEx+YmCYvITBcXkJwqKyU8UFJOfKKiWqvN7tdOUtinzq732o6OjZluvXu2NMfDaW0t/e4/LGwfg8ZYVt46/YMECs603hsC7LtbS3d6S5Kl1/pR46hiEevGVnygoJj9RUEx+oqCY/ERBMfmJgmLyEwXF5CcKyq3zi8gqAC8DWAbgJoBdqvqiiNwL4PcA7gNwGsAzqnrJOpaqJs1FTqnzp9ZtrXX7vbbeFtxee2/9euuxebXyoaEhM97R0WHGT5w4Yca7urpyY/Pnzzfben1LqfN74xu8dQy8cQIpW5unnLvZ8/nHAfxQVR8C8HcAvi8iDwN4DsAbqvoggDeyn4noLuEmv6oOquo72fc1AMcArACwBcCe7G57ADxdVCeJqPmm9ZlfRO4D8BUAfwawVFUHgYn/IAB0N7tzRFScupNfROYD+AOAH6jqlWm02yEifSLSZ62pRkTlqiv5RWQWJhL/N6r6x+zm8yLSk8V7AFyYqq2q7lLVXlXttf5oRkTlcpNfJpatfQnAMVX92aTQfgDbs++3A3i1+d0joqLUM6V3E4BvA3hPRI5kt+0E8AKAvSLyHQBnAHyznhNapYgqp/R6pUBr+W2vJOVt4b1y5Uoz7k19tbbw9pYN98pKXpnx4MGDZnzjxo25sUcffdRse+WK/enSW0794sWLubGFCxeabb3rUmQpMLUsXS83+VX1AIC8q/zVpvSCiErHEX5EQTH5iYJi8hMFxeQnCorJTxQUk58oqNKX7rZqmF5t1KrrerVR79hePdxaXtur0+/du9eMr1692oy/++67Znzx4sW5MW+MQGdnpxn3rotnbGwsN+YtWe4NB581a1bD5547d27DbYH055s1TiBlWXFu0U1ELiY/UVBMfqKgmPxEQTH5iYJi8hMFxeQnCqrUOr+3dHfK8ttF1/mtJazXrVtntvWW7vZ4NWmLNacdAAYGBsz4zJn2U8RaSwCw5+R7dX5v6/PPPvvMjNdqtdzY+fPnzbbLly8346
nPN6u9t1YA6/xElITJTxQUk58oKCY/UVBMfqKgmPxEQTH5iYJqqfn8Xn0zZT5/al3Wqp96a9unbLGdGvfq8N6xvTn1c+bMMePW79Qb1+E9Hx566CEzbu0Q1dPTY7Yt+vmUMmalWev285WfKCgmP1FQTH6ioJj8REEx+YmCYvITBcXkJwrKrfOLyCoALwNYBuAmgF2q+qKIPA/guwA+ye66U1Vfs46lqmYN09tvPWVP8yLr/AcOHDDbXrp0yYxb6+4D/jgBbx97i1ULB/w59Sk155S17QFgyZIlZry9vT03VvTzJSVe1rr99QzyGQfwQ1V9R0Q6ABwWkdez2M9V9Sd1n42IWoab/Ko6CGAw+74mIscArCi6Y0RUrGl95heR+wB8BcCfs5ueFZGjIrJbRBbltNkhIn0i0ucNFSWi8tSd/CIyH8AfAPxAVa8A+CWANQDWY+KdwU+naqequ1S1V1V7vc+XRFSeupJfRGZhIvF/o6p/BABVPa+qN1T1JoBfAdhQXDeJqNnc5JeJP8G/BOCYqv5s0u2Tp0VtBfB+87tHREWp56/9mwB8G8B7InIku20ngG0ish6AAjgN4Hv1nNAroVispZ5TSy9emdEqaXnbXG/ZssWMb9y40Yx718wq7wwNDZltvaW7vfbe33GsUqFXwvS2yfbOffXqVTNu8ZZyL7JUmHrsetXz1/4DAKbKDLOmT0StjSP8iIJi8hMFxeQnCorJTxQUk58oKCY/UVClb9Ft1cunMx3xTl6d3tsO2lvi2jr+ypUrzbZezdirV3vDoq2+e9uDd3V1mXGvpuw9tkOHDuXGvKnI3jiAkZERM+6N7bB4j8ubypwyhsHrd7NyiK/8REEx+YmCYvITBcXkJwqKyU8UFJOfKCgmP1FQklJbn/bJRD4B8NGkm7oAfFpaB6anVfvWqv0C2LdGNbNvf6Oq9prmmVKT/wsnF+lT1d7KOmBo1b61ar8A9q1RVfWNb/uJgmLyEwVVdfLvqvj8llbtW6v2C2DfGlVJ3yr9zE9E1an6lZ+IKlJJ8ovIUyJyXEROishzVfQhj4icFpH3ROSIiPRV3JfdInJBRN6fdNu9IvK6iJzIvk65TVpFfXteRAaya3dERL5eUd9Wich/icgxEflARP4pu73Sa2f0q5LrVvrbfhGZAeD/ADwJoB/A2wC2qer/lNqRHCJyGkCvqlZeExaRfwAwDOBlVX0ku+1fAFxU1Rey/zgXqeo/t0jfngcwXPXOzdmGMj2Td5YG8DSAf0SF187o1zOo4LpV8cq/AcBJVT2lqmMAfgfA3tUiKFV9C8DFO27eAmBP9v0eTDx5SpfTt5agqoOq+k72fQ3ArZ2lK712Rr8qUUXyrwBwdtLP/WitLb8VwJ9E5LCI7Ki6M1NYmm2bfmv79O6K+3Mnd+fmMt2xs3TLXLtGdrxutiqSf6r1sFqp5LBJVf8WwGYA38/e3lJ96tq5uSxT7CzdEhrd8brZqkj+fgCrJv28EsC5CvoxJVU9l329AGAfWm/34fO3NknNvl6ouD9/0Uo7N0+1szRa4Nq10o7XVST/2wAeFJEvichsAN8CsL+CfnyBiLRnf4iBiLQD+Bpab/fh/QC2Z99vB/BqhX25Tavs3Jy3szQqvnattuN1JYN8slLGLwDMALBbVX9ceiemICL3Y+LVHphY2fi3VfZNRF4B8DgmZn2dB/AjAP8OYC+A1QDOAPimqpb+h7ecvj2Oibeuf9m5+dZn7JL79vcA/hvAewBuLXW7ExOfryu7dka/tqGC68YRfkRBcYQfUVBMfqKgmPxEQTH5iYJi8hMFxeQnCorJTxQUk58oqP8HACBxv22h61sAAAAASUVORK5CYII=\n", 631 | "text/plain": [ 632 | "
" 633 | ] 634 | }, 635 | "metadata": {}, 636 | "output_type": "display_data" 637 | } 638 | ], 639 | "source": [ 640 | "# Incorrect prediction\n", 641 | "evaluate(6)" 642 | ] 643 | } 644 | ], 645 | "metadata": { 646 | "kernelspec": { 647 | "display_name": "Python 3", 648 | "language": "python", 649 | "name": "python3" 650 | }, 651 | "language_info": { 652 | "codemirror_mode": { 653 | "name": "ipython", 654 | "version": 3 655 | }, 656 | "file_extension": ".py", 657 | "mimetype": "text/x-python", 658 | "name": "python", 659 | "nbconvert_exporter": "python", 660 | "pygments_lexer": "ipython3", 661 | "version": "3.6.6" 662 | } 663 | }, 664 | "nbformat": 4, 665 | "nbformat_minor": 2 666 | } 667 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anhquan0412/basic_model_scratch/fd83305ae69460cf7f2ea0f4c2bbbf633eed652c/model/__init__.py -------------------------------------------------------------------------------- /model/activation_classes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | class Sigmoid(): 3 | def __call__(self,x): 4 | return 1/(1+np.exp(-x)) 5 | def grad(self,x): 6 | x_acted = self.__call__(x) 7 | return x_acted*(1-x_acted) 8 | 9 | class Softmax(): 10 | def __call__(self,x): 11 | # return np.exp(x) / np.sum(np.exp(x), axis=1)[:,None] 12 | 13 | #this is more stable. 
Avoid np.exp overflow problem 14 | e_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) 15 | return e_x / np.sum(e_x, axis=-1, keepdims=True) 16 | 17 | class Tanh(): 18 | def __call__(self, x): 19 | return 2/(1 + np.exp(-2*x)) - 1 20 | 21 | def grad(self, x): 22 | return 1 - np.power(self.__call__(x), 2) 23 | 24 | class ReLU(): 25 | def __call__(self, x): 26 | # return np.maximum(x,0) 27 | return np.where(x >= 0, x, 0) 28 | 29 | def grad(self, x): 30 | return np.where(x >= 0, 1, 0) 31 | 32 | class LeakyReLU(): 33 | def __init__(self, alpha=0.2): 34 | self.alpha = alpha 35 | 36 | def __call__(self, x): 37 | return np.where(x >= 0, x, self.alpha * x) 38 | 39 | def grad(self, x): 40 | return np.where(x >= 0, 1, self.alpha) 41 | -------------------------------------------------------------------------------- /model/activations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def sigmoid(x): 3 | return 1/(1+np.exp(-x)) 4 | -------------------------------------------------------------------------------- /model/gradients.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model.utils import onehot_array 3 | def MSE_grad(y,y_pred): 4 | ''' 5 | Derivative of MSE loss w.r.t y_pred (not w) 6 | ''' 7 | return (2/len(y))*(y_pred-y) 8 | 9 | def logloss_sigmoid_grad(y,y_pred): 10 | ''' 11 | Derivative of sigmoid + log loss combination is equivalent to derivative of MSE 12 | ''' 13 | return MSE_grad(y,y_pred)/2 14 | 15 | def logloss_softmax_grad(y,y_pred): 16 | y_onehot = onehot_array(y,y_pred.shape[1]) 17 | return (1/len(y_pred)) * (y_pred - y_onehot) -------------------------------------------------------------------------------- /model/knn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | class CustomNearestNeighbor(): 3 | def __init__(self,k): 4 | self.k = k 5 | self.eps=1e-6 6 | def 
fit(self,X,y=None): 7 | self.X_train = np.array(X) 8 | if y is not None: 9 | self.y_train = y 10 | self.n_classes = np.unique(y).shape[0] 11 | 12 | def kneighbors(self,X): 13 | ''' 14 | Return sorted k distance and k indices of input X 15 | ''' 16 | X = np.array(X) 17 | dist = np.zeros([len(X),self.k]) 18 | idxs = np.zeros([len(X),self.k],dtype=int) 19 | #euclidian distance 20 | for i,x in enumerate(X): 21 | temp_dist = np.linalg.norm(x - self.X_train,axis=1) 22 | idxs[i] = (np.argsort(temp_dist)[:self.k]) 23 | dist[i] = temp_dist[idxs[i]] 24 | return [dist,idxs] 25 | def predict_classification(self,X_test,weighted=False): 26 | if not hasattr(self,'y_train'): 27 | raise ValueError('y and n_class are undefined') 28 | 29 | dist,idxs = self.kneighbors(X_test) 30 | inv_dist = 1/(dist+ self.eps) # eps to avoid divided by 0 31 | 32 | y_pred = np.zeros(idxs.shape[0]) 33 | wc = np.zeros([idxs.shape[0],self.n_classes]) 34 | class_sorted = np.zeros([idxs.shape[0],self.n_classes],dtype=int) 35 | for i,idx in enumerate(idxs): 36 | # calculate weighted count (wc) 37 | class_counter = np.bincount(self.y_train[idx],weights=inv_dist[i] if weighted else None,minlength=self.n_classes) 38 | class_sorted[i] = (np.argsort(class_counter)[::-1]) 39 | wc[i] = class_counter 40 | y_pred[i]=class_sorted[i][0] 41 | return y_pred,wc,class_sorted -------------------------------------------------------------------------------- /model/linear_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model.utils import get_train_val,batch_iterator,plot_learning_curve 3 | from model.metrics import MSE 4 | from model.gradients import MSE_grad 5 | 6 | def initialize_weight(dim): 7 | W0 = np.array([[0]]) # bias, 1x1 8 | W= np.random.rand(dim,1) 9 | return np.concatenate((W0,W)) 10 | 11 | class CustomLinearModel(): 12 | def __init__(self,dim,is_reg,loss_fn,grad_fn,act_fn = lambda x: x): 13 | 
self.dim,self.act_fn,self.loss_fn,self.grad_fn,self.is_reg = dim,act_fn,loss_fn,grad_fn,is_reg 14 | 15 | self.W = initialize_weight(self.dim) 16 | self.train_losses=[] 17 | self.val_losses=[] 18 | def fit(self,X,y,lr,l2=0,n_iteration=50,val_ratio=.2): 19 | ''' 20 | Fit data using gradient descent and l2 regularization 21 | ''' 22 | X_train,y_train,X_val,y_val = get_train_val(X,y,val_ratio) 23 | for i in range(n_iteration): 24 | y_pred = self.act_fn(np.squeeze(X_train @ self.W)) 25 | # MSE loss for regression 26 | loss = self.loss_fn(y_train,y_pred) 27 | grad = self.grad_fn(y_train,y_pred) # shape (n,) 28 | grad_w = X_train.T @ grad # shape (dim,) 29 | 30 | if len(grad_w.shape)==1: grad_w = grad_w[:,None] # turn (dim,) to (dim,1) 31 | #ignore update of grad_w0 (bias term) since w0 does not contribute to regularization process 32 | grad_w[1:,:]+= 2*(l2/len(X_train))*self.W[1:,:] # (2 *lambda / m)* weight 33 | 34 | self.W-= lr*grad_w 35 | 36 | #save training loss 37 | self.train_losses.append(loss) 38 | #predict validation set 39 | y_pred = self.act_fn(np.squeeze(X_val @ self.W)) 40 | val_loss = self.loss_fn(y_val,y_pred) 41 | self.val_losses.append(val_loss) 42 | if (i+1) % 20 == 0: 43 | print(f'{i+1}. 
Training loss: {loss}, Val loss:{val_loss}') 44 | 45 | plot_learning_curve(self.train_losses,self.val_losses) 46 | 47 | def fit_epoch(self,X,y,lr,epochs,bs,l2=0,val_ratio=0.2): 48 | ''' 49 | Fit data using stochastic gradient descent and l2 regularization 50 | ''' 51 | X_train,y_train,X_val,y_val = get_train_val(X,y,val_ratio) 52 | for epoch in range(epochs): 53 | train_cumloss,val_cumloss = 0,0 54 | # get batch from train set 55 | for xb,yb in batch_iterator(X_train,y_train,bs): 56 | y_pred = self.act_fn(np.squeeze(xb @ self.W)) 57 | train_cumloss+= self.loss_fn(yb,y_pred) * len(xb) 58 | 59 | grad = self.grad_fn(yb,y_pred) 60 | grad_w = xb.T @ grad 61 | if len(grad_w.shape)==1: grad_w = grad_w[:,None] 62 | grad_w[1:,:]+= 2*(l2/len(xb))*self.W[1:,:] 63 | self.W-= lr*grad_w 64 | 65 | # get double of bs from validation set (since there's less calculation for prediction) 66 | for xb,yb in batch_iterator(X_val,y_val,bs*2): 67 | y_pred = self.act_fn(np.squeeze(xb @ self.W)) 68 | val_cumloss += self.loss_fn(yb,y_pred) * len(xb) 69 | 70 | self.train_losses.append(train_cumloss/ len(X_train)) 71 | self.val_losses.append(val_cumloss / len(X_val)) 72 | print(f'Epoch {epoch+1}. 
Training loss: {self.train_losses[-1]}, Val loss:{self.val_losses[-1]}') 73 | plot_learning_curve(self.train_losses,self.val_losses) 74 | 75 | def get_weight(self): 76 | return self.W 77 | def predict(self,X,thres=0.5): 78 | if X.shape[1] == self.dim: 79 | X0 = np.array([[1]*X.shape[0]]).T # nx1 80 | X = np.concatenate((X0,X),axis=1) 81 | y_pred= self.act_fn(np.squeeze(X @ self.W)) 82 | if not self.is_reg: 83 | y_pred = (y_pred >= thres).astype(np.uint8) 84 | return y_pred 85 | def predict_proba(self,X): 86 | if self.is_reg: 87 | raise Exception('Cannot predict probability for regression') 88 | if X.shape[1] == self.dim: 89 | X0 = np.array([[1]*X.shape[0]]).T # nx1 90 | X = np.concatenate((X0,X),axis=1) 91 | return self.act_fn(np.squeeze(X @ self.W)) -------------------------------------------------------------------------------- /model/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model.utils import onehot_array 3 | def MSE(y,y_pred): 4 | return np.mean((y_pred -y)**2) 5 | def logloss(y,y_pred): 6 | y_pred= np.clip(y_pred,1e-5,1-1e-5) 7 | return np.mean(-np.log(y_pred)*y - np.log(1-y_pred)*(1-y)) 8 | def multi_logloss(y,y_pred): 9 | y_pred= np.clip(y_pred,1e-5,1-1e-5) 10 | y_onehot = onehot_array(y,y_pred.shape[1]) 11 | return -np.mean(np.log(np.sum(y_onehot * y_pred,axis=1))) 12 | def accuracy(y,y_pred): 13 | return np.mean(y==y_pred) -------------------------------------------------------------------------------- /model/neural_network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model.utils import get_train_val,batch_iterator,plot_learning_curve 3 | from model.metrics import multi_logloss 4 | from model.gradients import MSE_grad 5 | from model.activation_classes import Softmax,ReLU,LeakyReLU 6 | from model.gradients import logloss_softmax_grad 7 | from IPython.core.debugger import set_trace 8 | from 
model.optimizers import GradientDescent 9 | 10 | def init_weight(shape): 11 | ''' 12 | Kaiming He normal initialization 13 | ''' 14 | np.random.seed(42) 15 | return [np.random.uniform(size=shape) * np.sqrt(2/shape[0]), np.zeros((1,shape[1]))] 16 | 17 | 18 | def drop_out(X_act,keep_prob): 19 | ''' 20 | Inverted dropout implementation 21 | ''' 22 | mask = np.random.rand(*X_act.shape) <= keep_prob 23 | X_act= (X_act*mask)/keep_prob 24 | return X_act,mask 25 | class CustomNeuralNetwork(): 26 | ''' 27 | Simple neural network for binary classification 28 | ''' 29 | def __init__(self,layers,act_obj,opt= GradientDescent,keep_prob=1.0): 30 | ''' 31 | Layers include output layer 32 | E.g for 10 output classification, input layer size 400 and 1 hidden layer size 200: [400,200,10] 33 | 34 | act_obj: object (or list of objects) from activation_classes module to apply before reaching each hidden layer (exclude the last SoftMax activation before loss calculation) 35 | keep_prob: keep probability (or list) for drop out at each hidden layer. 
Note that we don't drop out the first layer 36 | ''' 37 | self.act_objs= [act_obj]*(len(layers)-2) if type(act_obj)!=list else act_obj 38 | self.keep_probs = [keep_prob]*(len(layers)-2) if type(keep_prob)!=list else keep_prob 39 | # list of [weight,bias] 40 | self.weights = [init_weight((layers[i],layers[i+1])) for i in range(len(layers)-1)] 41 | self.opt = opt(layers) 42 | assert len(self.act_objs) == len(self.keep_probs),'# of activation objs and # of keep probs must be equal' 43 | assert len(self.act_objs) == len(layers)-2, 'We only need (# of hidden layers) activation objects' 44 | 45 | # class params 46 | self.train_losses,self.val_losses=[],[] 47 | self.X_inputs,self.X_acts,self.X_masks=[],[],[] 48 | 49 | def forward_pass(self,X,train): 50 | if train: 51 | self.X_inputs,self.X_acts,self.X_masks = [X],[X],[] 52 | inp = X 53 | for i,w in enumerate(self.weights): 54 | inp = inp @ w[0] + w[1] 55 | if idropout) 91 | grad_wrt_input = grad_wrt_input * self.X_masks[i] / self.keep_probs[i] # grad_wrt_activation, after factoring in the inverted dropout 92 | grad_wrt_input = grad_wrt_input * self.act_objs[i].grad(self.X_inputs[i+1]) # actual grad_wrt_input 93 | 94 | grad_wbias = np.sum(grad_wrt_input,axis=0) # (1,200) 95 | 96 | grad_w = self.X_acts[i].T @ grad_wrt_input # (400,200) 97 | 98 | grad_w,grad_wbias = self.opt.step(grad_w,grad_wbias,i,**kwargs) 99 | self.weights[i][0]-= lr*(grad_w + (l2/bs) * self.weights[i][0]) #update weight 100 | self.weights[i][1]-= lr*grad_wbias #update bias 101 | 102 | def fit_epoch(self,X_train,y_train,X_val,y_val,lr,epochs,bs=64,l2=0,beta1=0.9,beta2=0.99): 103 | ''' 104 | Fit data using stochastic gradient descent and l2 regularization 105 | ''' 106 | # set_trace() 107 | for epoch in range(epochs): 108 | train_cumloss,val_cumloss = 0,0 109 | # get batch from train set 110 | for xb,yb in batch_iterator(X_train,y_train,bs): 111 | y_pred = self.forward_pass(xb,True) 112 | train_cumloss+= multi_logloss(yb,y_pred) * len(xb) 113 | 114 | 
self.backward_pass(yb,y_pred,l2,lr,beta1=beta1,beta2=beta2) 115 | 116 | # get double of bs from validation set (since there's less calculation for prediction) 117 | for xb,yb in batch_iterator(X_val,y_val,bs*2): 118 | y_pred = self.forward_pass(xb,False) 119 | val_cumloss += multi_logloss(yb,y_pred) * len(xb) 120 | 121 | self.train_losses.append(train_cumloss/ len(X_train)) 122 | self.val_losses.append(val_cumloss / len(X_val)) 123 | print(f'Epoch {epoch+1}. Training loss: {self.train_losses[-1]}, Val loss:{self.val_losses[-1]}') 124 | plot_learning_curve(self.train_losses,self.val_losses) 125 | 126 | def predict(self,X): 127 | y_proba = self.forward_pass(X,False) 128 | return np.argmax(y_proba,axis=1) 129 | def predict_proba(self,X): 130 | return self.forward_pass(X,False) -------------------------------------------------------------------------------- /model/optimizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def init_exp_grad(shape): 4 | ''' 5 | Initialize value for exponential weight average 6 | ''' 7 | return [np.zeros(shape), np.zeros((1,shape[1]))] 8 | def bias_correction(t,beta): 9 | temp= max(beta**t,10e-6) 10 | return 1/(1-temp) 11 | class GradientDescent(): 12 | def __init__(self,layers): 13 | pass 14 | def step(self,grad_w,grad_wbias,layer,**kwargs): 15 | return grad_w,grad_wbias 16 | 17 | class Momentum(): 18 | def __init__(self,layers): 19 | self.exp_grad = [init_exp_grad((layers[i],layers[i+1])) for i in range(len(layers)-1)] 20 | self.t=[0]*(len(layers)-1) 21 | def step(self,grad_w,grad_wbias,layer,**kwargs): 22 | beta1 = kwargs['beta1'] 23 | self.t[layer]+=1 24 | self.exp_grad[layer][0] = beta1*self.exp_grad[layer][0] + (1-beta1)*grad_w 25 | self.exp_grad[layer][1] = beta1*self.exp_grad[layer][1] + (1-beta1)*grad_wbias 26 | bias_corr = bias_correction(self.t[layer],beta1) 27 | return bias_corr*self.exp_grad[layer][0],bias_corr*self.exp_grad[layer][1] 28 | 29 | class 
RMSProp(): 30 | def __init__(self,layers): 31 | self.exp_grad_sqr = [init_exp_grad((layers[i],layers[i+1])) for i in range(len(layers)-1)] 32 | self.t=[0]*(len(layers)-1) 33 | def step(self,grad_w,grad_wbias,layer,**kwargs): 34 | beta2 = kwargs['beta2'] 35 | eps = 10e-8 36 | self.t[layer]+=1 37 | self.exp_grad_sqr[layer][0] = beta2*self.exp_grad_sqr[layer][0] + (1-beta2)* grad_w**2 38 | self.exp_grad_sqr[layer][1] = beta2*self.exp_grad_sqr[layer][1] + (1-beta2)* grad_wbias**2 39 | bias_corr = bias_correction(self.t[layer],beta2) 40 | new_gradw = grad_w / (np.sqrt(bias_corr*self.exp_grad_sqr[layer][0]) + eps) 41 | new_gradwb = grad_wbias / (np.sqrt(bias_corr*self.exp_grad_sqr[layer][1]) + eps) 42 | return new_gradw,new_gradwb 43 | 44 | class Adam(): 45 | def __init__(self,layers): 46 | self.exp_grad = [init_exp_grad((layers[i],layers[i+1])) for i in range(len(layers)-1)] 47 | self.exp_grad_sqr = [init_exp_grad((layers[i],layers[i+1])) for i in range(len(layers)-1)] 48 | self.t=[0]*(len(layers)-1) 49 | def step(self,grad_w,grad_wbias,layer,**kwargs): 50 | beta1,beta2 = kwargs['beta1'],kwargs['beta2'] 51 | eps = 10e-8 52 | self.t[layer]+=1 53 | 54 | self.exp_grad[layer][0] = beta1*self.exp_grad[layer][0] + (1-beta1)*grad_w 55 | self.exp_grad[layer][1] = beta1*self.exp_grad[layer][1] + (1-beta1)*grad_wbias 56 | 57 | self.exp_grad_sqr[layer][0] = beta2*self.exp_grad_sqr[layer][0] + (1-beta2)* grad_w**2 58 | self.exp_grad_sqr[layer][1] = beta2*self.exp_grad_sqr[layer][1] + (1-beta2)* grad_wbias**2 59 | 60 | bias_corr1 = bias_correction(self.t[layer],beta1) 61 | bias_corr2 = bias_correction(self.t[layer],beta2) 62 | 63 | new_gradw = (bias_corr1*self.exp_grad[layer][0]) / (np.sqrt(bias_corr2*self.exp_grad_sqr[layer][0]) + eps) 64 | new_gradwb = (bias_corr1*self.exp_grad[layer][1]) / (np.sqrt(bias_corr2*self.exp_grad_sqr[layer][1]) + eps) 65 | return new_gradw,new_gradwb 66 | -------------------------------------------------------------------------------- 
/model/random_forest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from model.metrics import logloss, MSE 3 | 4 | def var_agg(n,s,s_squared): return (s_squared/n) - (s/n)**2 5 | 6 | class RandomForest(): 7 | def __init__(self, X, y, n_trees, sample_sz,is_reg=True, min_leaf=3,max_features=1): 8 | np.random.seed(42) 9 | if hasattr(y,'values'): y = y.values 10 | if hasattr(X,'values'): X = X.values 11 | self.X,self.y,self.sample_sz,self.min_leaf = X,y,sample_sz,min_leaf 12 | self.trees = [self.create_tree(is_reg,min_leaf,max_features) for i in range(n_trees)] # store roots of n_trees decision trees 13 | self.is_reg = is_reg 14 | 15 | def create_tree(self,is_reg,min_leaf,max_features): 16 | # generate random idxs with size sample_sz 17 | sample_idxs = np.random.permutation(len(self.y))[:self.sample_sz] 18 | return DecisionTreeNode(self.X[sample_idxs,:], self.y[sample_idxs], is_reg,min_leaf,max_features) 19 | def predict(self, X,thres=0.5): 20 | if hasattr(X,'values'): X = X.values 21 | y_pred = np.mean([t.predict(X) for t in self.trees], axis=0) 22 | if not self.is_reg: 23 | return (y_pred >= thres).astype(np.uint8) 24 | return y_pred 25 | def predict_proba(self,X): 26 | if self.is_reg: 27 | raise Exception('Cannot predict probability for regression') 28 | if hasattr(X,'values'): X = X.values 29 | return np.mean([t.predict(X) for t in self.trees], axis=0) 30 | 31 | 32 | class DecisionTreeNode(): 33 | def __init__(self, X, y, is_reg,min_leaf,max_features): 34 | self.X,self.y,self.min_leaf,self.max_features,self.is_reg = X,y,min_leaf,max_features,is_reg 35 | self.n,self.c = len(y), X.shape[1] 36 | if self.X.shape[0] != self.n: 37 | raise ValueError('X and y don\'t have the same size') 38 | self.val = np.mean(y) 39 | 40 | # Metric (loss score) 41 | self.score = float('inf') # initialize to infinity for a leaf 42 | 43 | self.col_idx= -1 # index of column chosen to split 44 | self.split_value = None # chosen 
split value from col with col_idx 45 | 46 | self.lhs_tree_node = None 47 | self.rhs_tree_node = None 48 | 49 | self.find_varsplit() # find best split and populate lhs + rhs tree node 50 | 51 | 52 | def find_varsplit(self): 53 | # exit clause when best split has been made from parent node 54 | if len(np.unique(self.y)) == 1: return 55 | 56 | # Assuming max_feature = self.c, as we consider all features for splitting 57 | n_col = int(self.c*self.max_features) 58 | 59 | for i in np.random.permutation(n_col): 60 | self.find_best_split_reg(i) if self.is_reg else self.find_best_split_clas(i) 61 | if self.is_leaf: return # exit clause when no split is made, because of min_leaf 62 | split_col = self.split_col 63 | lhs_idx = np.nonzero(split_col<=self.split_value)[0] 64 | rhs_idx = np.nonzero(split_col>self.split_value)[0] 65 | 66 | self.lhs_tree_node = DecisionTreeNode(self.X[lhs_idx,:], self.y[lhs_idx],self.is_reg,self.min_leaf,self.max_features) 67 | self.rhs_tree_node = DecisionTreeNode(self.X[rhs_idx,:], self.y[rhs_idx],self.is_reg,self.min_leaf,self.max_features) 68 | 69 | def find_best_split_reg(self, col_idx): 70 | x = self.X[:,col_idx] 71 | y = self.y 72 | 73 | sort_idx = np.argsort(x) 74 | sort_x,sort_y = x[sort_idx],y[sort_idx] 75 | rhs_cnt,rhs_sum,rhs_sum2 = self.n,sort_y.sum(), (sort_y**2).sum() 76 | lhs_cnt,lhs_sum,lhs_sum2=0,0.0,0.0 77 | for i in range(0,self.n- self.min_leaf): 78 | xi,yi = sort_x[i],sort_y[i] 79 | lhs_cnt += 1; rhs_cnt -= 1 80 | lhs_sum += yi; rhs_sum -= yi 81 | lhs_sum2 += yi**2; rhs_sum2 -= yi**2 82 | if i=0.5).astype(np.uint8) 41 | return X,y,W 42 | 43 | def plot_learning_curve(train_losses,val_losses): 44 | plt.plot(range(len(train_losses)),train_losses,'o-',color='r',label='Training loss',markersize=1) 45 | plt.plot(range(len(train_losses)),val_losses,'o-',color='g',label='Validation loss',markersize=1) 46 | plt.legend(loc="best") 47 | plt.show() 48 | 49 | def plot_feature_importances_rf(importances,col_names,figsize=(20,10)): 50 | 
fea_imp_df = pd.DataFrame(data={'Feature':col_names,'Importance':importances}).set_index('Feature') 51 | fea_imp_df = fea_imp_df.sort_values('Importance', ascending=True) 52 | fea_imp_df.plot(kind='barh',figsize=figsize) 53 | return fea_imp_df 54 | def permutation_importances(rf,X,y,metric,lowerisbetter=True): 55 | baseline = metric(rf,X,y) 56 | imp=[] 57 | for col in X.columns: 58 | save = X[col].copy() 59 | X[col] = np.random.permutation(X[col]) 60 | m = metric(rf,X,y) 61 | X[col] = save 62 | if lowerisbetter: 63 | imp.append(m-baseline) 64 | else: 65 | imp.append(baseline-m) 66 | fea_imp = np.array(imp) 67 | 68 | return plot_feature_importances_rf(fea_imp,X.columns.values) 69 | 70 | def draw_tree(t, df, size=10, ratio=0.6, precision=0): 71 | """ Draws a representation of a random forest in IPython. 72 | Parameters: 73 | ----------- 74 | t: The tree you wish to draw 75 | df: The data used to train the tree. This is used to get the names of the features. 76 | """ 77 | s=export_graphviz(t, out_file=None, feature_names=df.columns, filled=True, 78 | special_characters=True, rotate=True, precision=precision) 79 | IPython.display.display(graphviz.Source(re.sub('Tree {', 80 | f'Tree {{ size={size}; ratio={ratio}', s))) 81 | 82 | 83 | -------------------------------------------------------------------------------- /random_forest_regressor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Goals: rebuild Random Forest Regression from scratch\n", 8 | "\n", 9 | "Use only python and some basic numpy function (numpy slicing, np.mean, np.nonzero)\n", 10 | "\n", 11 | "Produce comparable results to Sklearn Ramdon Forest on some regression dataset" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "%load_ext autoreload\n", 21 | "%autoreload 2\n", 22 | 
"%matplotlib inline" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 12, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "\n", 32 | "import pandas as pd\n", 33 | "import numpy as np\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "from model.utils import draw_tree\n", 36 | "from model.metrics import *\n", 37 | "np.random.seed(42)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "# Generate a random-generated dataset" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 9, 50 | "metadata": { 51 | "scrolled": true 52 | }, 53 | "outputs": [ 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | " f0 f1 y\n", 59 | "0 2 9 5\n", 60 | "1 6 4 7\n", 61 | "2 9 7 1\n", 62 | "3 1 9 9\n", 63 | "4 4 9 3\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "nrows=20\n", 69 | "df = pd.DataFrame(np.random.randint(1,10,(nrows,3)),columns=['f0','f1','y'])\n", 70 | "print(df.head())\n", 71 | "y = df.y.values\n", 72 | "df.drop('y',axis=1,inplace=True)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "# Check whether my RF splits are the same as sklearn Random Forest's" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 31, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "from sklearn.ensemble import RandomForestRegressor\n", 89 | "from sklearn import metrics" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 10, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "data": { 99 | "image/svg+xml": [ 100 | "\r\n", 101 | "\r\n", 103 | "\r\n", 105 | "\r\n", 106 | "\r\n", 108 | "\r\n", 109 | "Tree\r\n", 110 | "\r\n", 111 | "\r\n", 112 | "0\r\n", 113 | "\r\n", 114 | "f1 ≤ 5.5\r\n", 115 | "mse = 7.628\r\n", 116 | "samples = 20\r\n", 117 | "value = 4.35\r\n", 118 | "\r\n", 119 | "\r\n", 120 | "1\r\n", 121 | "\r\n", 122 | "f1 ≤ 4.5\r\n", 123 | "mse = 4.36\r\n", 124 | 
"samples = 10\r\n", 125 | "value = 2.8\r\n", 126 | "\r\n", 127 | "\r\n", 128 | "0->1\r\n", 129 | "\r\n", 130 | "\r\n", 131 | "True\r\n", 132 | "\r\n", 133 | "\r\n", 134 | "6\r\n", 135 | "\r\n", 136 | "f1 ≤ 6.5\r\n", 137 | "mse = 6.09\r\n", 138 | "samples = 10\r\n", 139 | "value = 5.9\r\n", 140 | "\r\n", 141 | "\r\n", 142 | "0->6\r\n", 143 | "\r\n", 144 | "\r\n", 145 | "False\r\n", 146 | "\r\n", 147 | "\r\n", 148 | "2\r\n", 149 | "\r\n", 150 | "f0 ≤ 3.5\r\n", 151 | "mse = 4.245\r\n", 152 | "samples = 7\r\n", 153 | "value = 3.571\r\n", 154 | "\r\n", 155 | "\r\n", 156 | "1->2\r\n", 157 | "\r\n", 158 | "\r\n", 159 | "\r\n", 160 | "\r\n", 161 | "5\r\n", 162 | "\r\n", 163 | "mse = 0.0\r\n", 164 | "samples = 3\r\n", 165 | "value = 1.0\r\n", 166 | "\r\n", 167 | "\r\n", 168 | "1->5\r\n", 169 | "\r\n", 170 | "\r\n", 171 | "\r\n", 172 | "\r\n", 173 | "3\r\n", 174 | "\r\n", 175 | "mse = 2.667\r\n", 176 | "samples = 3\r\n", 177 | "value = 4.0\r\n", 178 | "\r\n", 179 | "\r\n", 180 | "2->3\r\n", 181 | "\r\n", 182 | "\r\n", 183 | "\r\n", 184 | "\r\n", 185 | "4\r\n", 186 | "\r\n", 187 | "mse = 5.188\r\n", 188 | "samples = 4\r\n", 189 | "value = 3.25\r\n", 190 | "\r\n", 191 | "\r\n", 192 | "2->4\r\n", 193 | "\r\n", 194 | "\r\n", 195 | "\r\n", 196 | "\r\n", 197 | "7\r\n", 198 | "\r\n", 199 | "mse = 0.222\r\n", 200 | "samples = 3\r\n", 201 | "value = 7.667\r\n", 202 | "\r\n", 203 | "\r\n", 204 | "6->7\r\n", 205 | "\r\n", 206 | "\r\n", 207 | "\r\n", 208 | "\r\n", 209 | "8\r\n", 210 | "\r\n", 211 | "f0 ≤ 3.5\r\n", 212 | "mse = 6.694\r\n", 213 | "samples = 7\r\n", 214 | "value = 5.143\r\n", 215 | "\r\n", 216 | "\r\n", 217 | "6->8\r\n", 218 | "\r\n", 219 | "\r\n", 220 | "\r\n", 221 | "\r\n", 222 | "9\r\n", 223 | "\r\n", 224 | "mse = 2.889\r\n", 225 | "samples = 3\r\n", 226 | "value = 6.667\r\n", 227 | "\r\n", 228 | "\r\n", 229 | "8->9\r\n", 230 | "\r\n", 231 | "\r\n", 232 | "\r\n", 233 | "\r\n", 234 | "10\r\n", 235 | "\r\n", 236 | "mse = 6.5\r\n", 237 | "samples = 4\r\n", 238 | "value = 
4.0\r\n", 239 | "\r\n", 240 | "\r\n", 241 | "8->10\r\n", 242 | "\r\n", 243 | "\r\n", 244 | "\r\n", 245 | "\r\n", 246 | "\r\n" 247 | ], 248 | "text/plain": [ 249 | "" 250 | ] 251 | }, 252 | "metadata": {}, 253 | "output_type": "display_data" 254 | } 255 | ], 256 | "source": [ 257 | "# sklearn RF split\n", 258 | "rf = RandomForestRegressor(n_estimators=1,max_depth=3,bootstrap=False,min_samples_leaf=3)\n", 259 | "rf.fit(df,y)\n", 260 | "draw_tree(rf.estimators_[0], df, precision=3)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 11, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "from model.random_forest import RandomForest" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 20, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "# My RF split\n", 279 | "ens = RandomForest(df, y,1, nrows)\n", 280 | "tree = ens.trees[0]" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 22, 286 | "metadata": {}, 287 | "outputs": [ 288 | { 289 | "data": { 290 | "text/plain": [ 291 | "Sample size: 20. Pred value: 4.35. Loss: 7.628\n", 292 | "Best split from feature 1 at value 5" 293 | ] 294 | }, 295 | "execution_count": 22, 296 | "metadata": {}, 297 | "output_type": "execute_result" 298 | } 299 | ], 300 | "source": [ 301 | "tree" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 23, 307 | "metadata": {}, 308 | "outputs": [ 309 | { 310 | "data": { 311 | "text/plain": [ 312 | "Sample size: 10. Pred value: 2.80. Loss: 4.360\n", 313 | "Best split from feature 1 at value 4" 314 | ] 315 | }, 316 | "execution_count": 23, 317 | "metadata": {}, 318 | "output_type": "execute_result" 319 | } 320 | ], 321 | "source": [ 322 | "tree.lhs_tree_node" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 24, 328 | "metadata": {}, 329 | "outputs": [ 330 | { 331 | "data": { 332 | "text/plain": [ 333 | "Sample size: 10. Pred value: 5.90. 
Loss: 6.090\n", 334 | "Best split from feature 1 at value 6" 335 | ] 336 | }, 337 | "execution_count": 24, 338 | "metadata": {}, 339 | "output_type": "execute_result" 340 | } 341 | ], 342 | "source": [ 343 | "tree.rhs_tree_node" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 27, 349 | "metadata": {}, 350 | "outputs": [ 351 | { 352 | "name": "stdout", 353 | "output_type": "stream", 354 | "text": [ 355 | "Sample size: 7. Pred value: 3.57. Loss: 4.245\n", 356 | "Best split from feature 0 at value 2\n", 357 | "Sample size: 3. Pred value: 1.00. Loss: 0.000\n", 358 | "\n" 359 | ] 360 | } 361 | ], 362 | "source": [ 363 | "print(tree.lhs_tree_node.lhs_tree_node)\n", 364 | "print(tree.lhs_tree_node.rhs_tree_node)" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": 28, 370 | "metadata": {}, 371 | "outputs": [ 372 | { 373 | "name": "stdout", 374 | "output_type": "stream", 375 | "text": [ 376 | "Sample size: 3. Pred value: 7.67. Loss: 0.222\n", 377 | "\n", 378 | "Sample size: 7. Pred value: 5.14. 
Loss: 6.694\n", 379 | "Best split from feature 0 at value 3\n" 380 | ] 381 | } 382 | ], 383 | "source": [ 384 | "print(tree.rhs_tree_node.lhs_tree_node)\n", 385 | "print(tree.rhs_tree_node.rhs_tree_node)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": {}, 391 | "source": [ 392 | "My RF regressor is able to produce the exact split like sklearn RF" 393 | ] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "metadata": {}, 398 | "source": [ 399 | "# Benchmarking on random-generated dataset" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 29, 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "nrows=1000\n", 409 | "ncols=10\n", 410 | "df = pd.DataFrame(np.random.randint(1,10,(nrows,ncols)))\n", 411 | "y = df[ncols-1].values\n", 412 | "df.drop(ncols-1,axis=1,inplace=True)" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 32, 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "name": "stdout", 422 | "output_type": "stream", 423 | "text": [ 424 | "MSE score: 2.0772524444444445\n", 425 | "R2 score: 0.6952187072704108\n" 426 | ] 427 | } 428 | ], 429 | "source": [ 430 | "rf = RandomForestRegressor(n_estimators=5,bootstrap=False,min_samples_leaf=3)\n", 431 | "rf.fit(df,y)\n", 432 | "\n", 433 | "rf_pred = rf.predict(df)\n", 434 | "print(f'MSE score: {MSE(y,rf_pred)}')\n", 435 | "print(f'R2 score: {metrics.r2_score(y,rf_pred)}')" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 33, 441 | "metadata": {}, 442 | "outputs": [ 443 | { 444 | "name": "stdout", 445 | "output_type": "stream", 446 | "text": [ 447 | "MSE score: 2.1058883111111113\n", 448 | "R2 score: 0.6910171589778858\n" 449 | ] 450 | } 451 | ], 452 | "source": [ 453 | "tb = RandomForest(df, y, n_trees=5, sample_sz=nrows)\n", 454 | "tb_pred = tb.predict(df)\n", 455 | "print(f'MSE score: {metrics.mean_squared_error(y,tb_pred)}')\n", 456 | "print(f'R2 score: {metrics.r2_score(y,tb_pred)}')" 
457 | ] 458 | }, 459 | { 460 | "cell_type": "markdown", 461 | "metadata": {}, 462 | "source": [ 463 | "# Benchmarking: Boston housing dataset" 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": {}, 469 | "source": [ 470 | "The dataset for this project originates from the UCI Machine Learning Repository. The Boston housing data was collected in 1978 and each of the 506 entries represent aggregated data about 14 features for homes from various suburbs in Boston, Massachusetts. For the purposes of this project, the following preprocessing steps have been made to the dataset:\n", 471 | "\n", 472 | "- 16 data points have an 'MEDV' value of 50.0. These data points likely contain missing or censored values and have been removed.\n", 473 | "- 1 data point has an 'RM' value of 8.78. This data point can be considered an outlier and has been removed.\n", 474 | "- The features 'RM', 'LSTAT', 'PTRATIO', and 'MEDV' are essential. The remaining non-relevant features have been excluded.\n", 475 | "- The feature 'MEDV' has been multiplicatively scaled to account for 35 years of market inflation." 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 34, 481 | "metadata": {}, 482 | "outputs": [ 483 | { 484 | "name": "stdout", 485 | "output_type": "stream", 486 | "text": [ 487 | "(489, 4)\n" 488 | ] 489 | }, 490 | { 491 | "data": { 492 | "text/html": [ 493 | "
\n", 494 | "\n", 507 | "\n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | "
RMLSTATPTRATIOMEDV
06.5754.9815.3504000.0
16.4219.1417.8453600.0
27.1854.0317.8728700.0
36.9982.9418.7701400.0
47.1475.3318.7760200.0
\n", 555 | "
" 556 | ], 557 | "text/plain": [ 558 | " RM LSTAT PTRATIO MEDV\n", 559 | "0 6.575 4.98 15.3 504000.0\n", 560 | "1 6.421 9.14 17.8 453600.0\n", 561 | "2 7.185 4.03 17.8 728700.0\n", 562 | "3 6.998 2.94 18.7 701400.0\n", 563 | "4 7.147 5.33 18.7 760200.0" 564 | ] 565 | }, 566 | "execution_count": 34, 567 | "metadata": {}, 568 | "output_type": "execute_result" 569 | } 570 | ], 571 | "source": [ 572 | "housing = pd.read_csv('data/housing.csv')\n", 573 | "print(housing.shape)\n", 574 | "housing.head()" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 35, 580 | "metadata": {}, 581 | "outputs": [], 582 | "source": [ 583 | "y = housing.MEDV.values\n", 584 | "X = housing.drop('MEDV',axis=1)" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": 36, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [ 593 | "from sklearn.model_selection import train_test_split\n", 594 | "def get_train_val(X,y):\n", 595 | " X_train,X_val,y_train,y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n", 596 | " return X_train.reset_index(drop=True),X_val.reset_index(drop=True),y_train,y_val" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 37, 602 | "metadata": {}, 603 | "outputs": [], 604 | "source": [ 605 | "X_train,X_val,y_train,y_val = get_train_val(X,y)" 606 | ] 607 | }, 608 | { 609 | "cell_type": "markdown", 610 | "metadata": {}, 611 | "source": [ 612 | "### sklearn Random Forest" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": 81, 618 | "metadata": {}, 619 | "outputs": [ 620 | { 621 | "name": "stdout", 622 | "output_type": "stream", 623 | "text": [ 624 | "MSE train score: 3100445764.0763206\n", 625 | "MSE val score: 3219791806.9669237\n", 626 | "R2 train score: 0.8905847020820434\n", 627 | "R2 val score: 0.8534966571980468\n" 628 | ] 629 | } 630 | ], 631 | "source": [ 632 | "rf = 
RandomForestRegressor(n_estimators=15,max_features=0.8,bootstrap=False,min_samples_leaf=10)\n", 633 | "rf.fit(X_train,y_train)\n", 634 | "\n", 635 | "rf_train_pred = rf.predict(X_train)\n", 636 | "rf_val_pred = rf.predict(X_val)\n", 637 | "\n", 638 | "print(f'MSE train score: {metrics.mean_squared_error(y_train,rf_train_pred)}')\n", 639 | "print(f'MSE val score: {metrics.mean_squared_error(y_val,rf_val_pred)}')\n", 640 | "print(f'R2 train score: {metrics.r2_score(y_train,rf_train_pred)}')\n", 641 | "print(f'R2 val score: {metrics.r2_score(y_val,rf_val_pred)}')" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 82, 647 | "metadata": {}, 648 | "outputs": [ 649 | { 650 | "data": { 651 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAD8CAYAAACyyUlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAIABJREFUeJztvXucVNWV6P9dDd10I9C8Gmy65ZG5JOIwxmibYGK8EUHQIer4GcZ4m6vjRbmT0YivMToaX4kZM+MAGrnmB75/9NUYk4yaG55Rx1yjUfAxQVqQBMGGBloeTTt004297h91uqnurlN1qs6pOqeq1/fzOZ+q2rXP2fucvc9ee+299l6iqhiGYRiGH4rCzoBhGIaR/5gwMQzDMHxjwsQwDMPwjQkTwzAMwzcmTAzDMAzfmDAxDMMwfGPCxDAMw/CNCRPDMAzDNyZMDMMwDN8MDDsDuWL06NE6ceLEsLNhGIaRV2zYsOETVa1IFa/fCJOJEyeyfv36sLNhGIaRV4jIdi/xUg5zichjIrJXRDbGhY0UkbUi8qHzOcIJFxF5UES2ish/iMipcedc7sT/UEQujws/TUT+4JzzoIhIpmkYhmEY4eBlzuQJYHavsFuA36jqZOA3zm+A84DJzrEAeBhiggG4E/gK8GXgzi7h4MRZEHfe7EzSMAzDMMIjpTBR1VeB/b2CLwSedL4/CVwUF/6UxngDGC4ilcAsYK2q7lfVA8BaYLbz3zBVfV1j2xc/1eta6aRhGIZhhESmcyZjVbURQFUbRWSME14FfBwXr8EJSxbekCA8kzQaM7wXwzAMVzo6OmhoaKCtrS3srGSV0tJSqqurKS4uzuj8oCfgJUGYZhCeSRp9I4osIDYUxvjx41Nc1jAMoy8NDQ0MHTqUiRMn4kzpFhyqyr59+2hoaGDSpEkZXSNTYbJHRCodjaES2OuENwAnxMWrBnY54d/oFf6KE16dIH4mafRBVZcBywBqamoi6QWsvrGZVRv3sPNgK1XDy5g9dSxTKsvDzpZhGA5tbW0FLUgARIRRo0bR1NSU8TUyXbT4AtBlkXU58Hxc+GWOxdU0oNkZqloNnCsiI5yJ93OB1c5/LSIyzbHiuqzXtdJJI++ob2xm2avbaG7toLK8lObWDpa9uo36xuaws2YYRhyFLEi68HuPKTUTEXmamFYxWkQaiFll3Qc8KyLzgR3AXCf6r4Hzga3AYeAKAFXdLyLfB95y4t2jql2T+t8mZjFWBq
x0DtJNIx9ZtXEP5WXFlJfFxii7Pldt3GPaiWEYeUVKYaKql7r8dU6CuApc7XKdx4DHEoSvB6YmCN+Xbhr5xs6DrVSWl/YIG1o6kJ0HW0PKkWEYUWPfvn2cc06sKdy9ezcDBgygoiK2IP3NN9+kpKQkzOx1029WwEeRquFlNLd2dGskAC1tR6kaXhZirgwjfWzuL3uMGjWKd999F4C77rqLIUOGcNNNN/WIo6qoKkVF4W23aBs9hsjsqWNpbu2gubWDTtXu77Onjg07a4bhGZv760l9YzOL127hpp+9x+K1W7L2HLZu3crUqVP5u7/7O0499VQ+/vhjhg8f3v3/M888w5VXXgnAnj17uPjii6mpqeHLX/4yb7zxRuD5MWESIlMqy1lw1iTKy4ppbG6jvKyYBWdNsh6dkVfEz/0ViXR/X7VxT9hZyzm5FqybNm1i/vz5vPPOO1RVVbnGu/baa7n55ptZv349zz77bLeQCRIb5gqZKZXlJjyMvMbm/o6Ra6OaP/uzP+P0009PGW/dunVs3ry5+/eBAwdobW2lrCy4IXUTJoZh+MLm/o6Ra8F63HHHdX8vKioiZp8UI37FvqpmfbLehrkMw/CFzf0do2p4GS1tR3uE5UqwFhUVMWLECD788EM6Ozv55S9/2f3fjBkzWLp0affvrgn9QNMP/IqGYfQrbO7vGGEL1h/96EfMnj2bc845h+rqY5uLLF26lNdee42TTz6Zk046ieXLlweetsSrRYVMTU2NmnMswzDSpb6+nilTpniPn8dm0onuVUQ2qGpNqnNtzsQwDCNA+qtRjQ1zGYZhGL4xYWIYhmH4xoSJYRiG4RsTJoZhGIZvTJiETV0dTJwIRUWxz7q6sHNkGOlj9bjfY8IkTOrqYMEC2L4dVGOfCxbYi2jkF1aPs86AAQM45ZRTmDp1KnPnzuXw4cMZX+uVV15hzpw5AeYuhgmTMLntNuhdKQ4fjoUbRr5g9TjrlJWV8e6777Jx40ZKSkr4yU9+0uN/VaWzszOk3MUwYRImO3akF24YUcTqcU+yPOT39a9/na1bt/LRRx8xZcoU/v7v/757C/o1a9ZwxhlncOqppzJ37lw+/fRTAFatWsWJJ57ImWeeyS9+8YtA89OFCZMwGT8+vXDDiCJWj4+R5SG/o0ePsnLlSv7iL/4CgM2bN3PZZZfxzjvvcNxxx/GDH/yAdevW8fbbb1NTU8OiRYtoa2vjqquu4sUXX+S3v/0tu3fvDiQvvTFhEib33guDB/cMGzw4Fm4Y+YLV42NkacivtbWVU045hZqaGsaPH8/8+fMBmDBhAtOmTQPgjTfeYNOmTXzta1/jlFNO4cknn2T79u188MEHTJo0icmTJyMizJs3z1de3LDtVMKktjb2edttsSGB8eNjL2BXuGHkA1aPj5GlIb+uOZPexG9Br6rMnDmTp59+ukecd999FxHxlb4XTDMJm9pa+Ogj6OyMffbHF9DIf6wexwhxyG/atGm89tprbN26FYDDhw+zZcsWTjzxRLZt28Yf//hHgD7CJihMmBiGYQRFiEN+FRUVPPHEE1x66aWcfPLJTJs2jQ8++IDS0lKWLVvGX/7lX3LmmWcyYcKErKRvW9AbhmEkId0t6Kmry9shP9uC3jAMIyrU1uaN8AgSG+YyDMMwfGPCxDAMIwX9YTrA7z2aMDEMw0hCaWkp+/btK2iBoqrs27eP0tLSjK9hcyaGYRhJqK6upqGhgaamprCzklVKS0uprq7O+HwTJoZhGEkoLi5m0qRJYWcj8tgwl2EYhuEbEyaGYRiGb0yYGIZhGL4xYWIYhmH4xoSJYRiG4RsTJoZhGIZvTJgYhmEYvvElTETkehF5X0Q2isjTIlIqIpNE5Pci8qGI/FRESpy4g5zfW53/J8Zd51YnfLOIzIoLn+2EbRWRW+LCE6ZhGIZhhEPGwkREqoBrgRpVnQoMAL4F/AhYrKqTgQPAfOeU+cABVf0vwGInHiJyknPenwOzgf
8lIgNEZACwFDgPOAm41IlLkjQMwzCMEPA7zDUQKBORgcBgoBGYDjzn/P8kcJHz/ULnN87/50jMl+SFwDOqekRVtwFbgS87x1ZV/ZOqtgPPABc657ilYRiGYYRAxsJEVXcC9wM7iAmRZmADcFBVjzrRGoAq53sV8LFz7lEn/qj48F7nuIWPSpKGYRiGEQJ+hrlGENMqJgHjgOOIDUn1pmurzUQe7TXA8ER5XCAi60VkfaFv0mYYhhEmfoa5ZgDbVLVJVTuAXwBfBYY7w14A1cAu53sDcAKA8385sD8+vNc5buGfJEmjB6q6TFVrVLWmoqLCx60ahmEYyfAjTHYA00RksDOPcQ6wCXgZ+GsnzuXA8873F5zfOP+/pDEHAS8A33KsvSYBk4E3gbeAyY7lVgmxSfoXnHPc0jAMwzBCwM+cye+JTYK/DfzBudYy4LvADSKyldj8xqPOKY8Co5zwG4BbnOu8DzxLTBCtAq5W1c+cOZFrgNVAPfCsE5ckaRiGYRghIIXsPSyempoaXb9+fdjZMAzDyCtEZIOq1qSKZyvgDcMwDN+YMDEMwzB8Y8LEMAzD8I0JE8MwDMM3JkwMwzAM3wxMHcUwklPf2MyqjXvYebCVquFlzJ46limV5Vk/1zCM6GCaieGL+sZmlr26jebWDirLS2lu7WDZq9uob2zO6rmGYUQL00wMX6zauIfysmLKy4oBuj9XbdyTUsPwc27QmIZkGP4wzcTwxc6DrZz+u5XMn3c21806kfnzzub0361k58FWT+cOLe3ZnxlaOtDTuYFQVwcTJ6JFRYw9eQrjV/3SNCTDyBATJoYvpm9Yx8wl32PY3l2IKsP27mLmku8xfcO6lOdWDS+jpe1oj7CWtqNUDS/LVnaPUVcHCxbA9u2IKiM/aeSCpXcx5aVfdWtLqzbuyX4+DKNAMGFi+GLGigcoOdLWI6zkSBszVjyQ8tzZU8fS3NpBc2sHnard32dPHZut7B7jttvg8OEeQcVH2jjz8UVAjjUkwygATJgYvijZ1ZBWeDxTKstZcNYkysuKaWxuo7ysmAVnTcrNXMWOHQmDhzY1AjnUkAyjQLAJeMMf48fD9u2Jwz0wpbI8nIlul3wfqqjs1pAuOb069/kyjDzFNBPDH/feC4MH9wwbPDgWHmUS5LtjUCm/uuSa3GpIhlEgmDAx/FFbC8uWwYQJIBL7XLYsFu4Fx6KKoqLYZ11dNnN7jAT5Ln70EWofvJXrZ37eBIlhpIn5MzHCo8uiKn4ifPDg9ISRYRhZxfyZGNGlSxuZN6+PRRWHD8csrcLKU641JMMoEGwC3sgtibSR3rhYWmWN3nnavj32G0xDMgyPmGZi5JYE6zv64NESLDAS5SksDckw8hQTJkZuSaV1hGEJ5panXGtIhpHHmDAxcksyrSNdS7CgcMtTrjUkw8hjTJgYueXee+ks67myvLOsDFasgI8+CmeOIl/XyhhGhDBhYuSU+ulzqJt/OwcrKlERDlZUUjf/duqnzwkvU37XyhiGYdZcRm5ZtXEPzTMv5PEL/ro7rLm1g0/C9mEy5nRmv/6eLVY0jAwxzcTIKaH7MHEwL4+GESwmTIycEqoPkzjivTwWiZgPE8PwiQkTI6eE6sMkjqhoSIZRKNiciZFTunyYxPtbv+T06pzPVVQNL6O5taPb7zxEz4eJ+aU38gkTJkbOCc2HSRyzp45l2avbgJhG0tJ2NFI+TLrmdMrLinvM6djW+EZUsWEuo18SqpdHD9icjpFvmGZi9FuioCG5sfNgK5XlpT3CbE7HiDKmmRhGBImK1ZtheMU0E8MzNiGcO6I+p2MYvTHNxPCELfLLLVGf0zGM3phmYngifkIY6P5cFcI2KP2FKM/pGEZvTDMxPJHXi/zMJa9hZB3TTAxP5MMiv4SYS17DyAm+NBMRGS4iz4nIByJSLyJniMhIEVkrIh86nyOcuCIiD4rIVhH5DxE5Ne46lzvxPxSRy+PCTxORPzjnPCgi4oQnTM
PIHlHZBiVtzCWvYeQEv8NcDwCrVPVE4ItAPXAL8BtVnQz8xvkNcB4w2TkWAA9DTDAAdwJfAb4M3BknHB524nadN9sJd0vDyBJ5OyFsLnkNIydkPMwlIsOAs4C/BVDVdqBdRC4EvuFEexJ4BfgucCHwlKoq8Iaj1VQ6cdeq6n7numuB2SLyCjBMVV93wp8CLgJWOtdKlIaRRfJyQnj8+NjQVqJwwzACw49m8jmgCXhcRN4RkUdE5DhgrKo2AjifY5z4VcDHcec3OGHJwhsShJMkjR6IyAIRWS8i65uamjK/UyN/MZe82aPADRvqG5tZvHYLN/3sPRav3WJm8CnwI0wGAqcCD6vql4D/JPlwkyQI0wzCPaOqy1S1RlVrKioq0jnVKBTy2SVvlBvrLsOG7dtB9ZhhQ5Ty6ANbV5U+foRJA9Cgqr93fj9HTLjscYavcD73xsU/Ie78amBXivDqBOEkScMw+lJbCx99BJ2dsc98ESRRbqwL3LDB80abURb4OSZjYaKqu4GPReQLTtA5wCbgBaDLIuty4Hnn+wvAZY5V1zSg2RmiWg2cKyIjnIn3c4HVzn8tIjLNseK6rNe1EqVROFgl7d9EvbEucMMGT+uqoi7wc4zfdSbfAepEpAT4E3AFMQH1rIjMB3YAc524vwbOB7YCh524qOp+Efk+8JYT756uyXjg28ATQBmxifeVTvh9LmkUBrY2woh6Y13ghg2e1lUlE/j98D2VmHFV4VNTU6Pr168POxvemDgx8Ys6YUJsmMYofKJeB3p3eCBm2JAv81EpiHdOFr/RZg9z+KKimEbSG5HYkGqBICIbVLUmVTzbTiWKRL1XamSfqFuh5bNhgxtxQ8tTzvgiN32yPvm6KjctrEC0s3Sx7VSiSIEPIXilX29539Uo33ZbrBMxfnxMkESosa6fPodVy0/vWT5hZypTEgwtV928kOuTCch7702snUVF4OcY00yiSNR7pTnATDOJjhVaAmOQgiufTAweClE784EJkyhildR8oEcFF4ulbQ88Uljlk+nQclQEfgQwYRJV+nklzest7wsJlx77tEf/tbDKx+Y/fGPCxIgk5gM9Irj0zEfs211Y5WNDy74xYWJEkrzd8r7QcOmZd4yrLqzysaFl39g6EyOy9GtrrqiQZD1J/fQ5Vj79AK/rTMw02IgsU176FVMSmcbW1UXaZDaKZCyYa2tj81f33MnQpkZaKippueNuqmprmQImPIxuTDMxoolbj/jyy+HJJwt25XU28LSaOwvnRgnTcjPHVsBnE9uEMfu42f0vWxbtDRCzhB/fGn7MrAvBRLvg1sQ4RM3figmTdLGdQnODixWRfvZZ4vAC3mrGb2Po2cw6QSfJr4l2FBq8UARiljucURSQJkzSJepbg/di59JHODS2Ci0q4tDYKnYufSTsLHnDxYpIiwYkDG+pqOwZEKL2GHQD6rcxrBpexsTV/8b8eWdz3awTmT/vbCau/reeZrwunaTpG9albwLsPHstKmLsyVMYv+qXoTZ42VyzlLCsc9DhXLVxD2e+uZobrjqXG2ZP4YarzuXMN1eHqjGaMEmXPNqEcefSR6i44TsM27sLUWXY3l1U3PCd/BAoLnb/v5txMR2DSnsEdwwq5cVLrjkWkOBl7ph/JXXX/lPWe8fZ6DH6bQz/ZsurfPOhu3rUg28+dBd/s+XVY5FcOkkzVjyQnglw3LMXVUZ+0sjFi27hhlknhtbgJRWmPjodbmXd/t1bs97hHPPiz7lgac8yvWDpXYx58eeBpZEuNgGfJu3V4ynZ+XHf8KoTKGmIlkA5NLaKYXt39Q0fM45he3aGkKM0SWC1tXjM6Yxf9Utmrniw27po7bxr2TH7r7h+5udj57ls3948ZhyLl6/J6gTy4rVb+vjB6Prdnb9cX9PLdvZJtlOv33nA++S1W1oOHYNKeXbBHdQ+eGvqfAdEV6eqpL2tO6y9pJRPL61l5M+eztiYw61c7rzoL5Asb02fy3fbJu
CzxLp5C2nv1TNuH1TKunkLQ8qRO0ObGtMKjxwJtpSZPXUs//fLs1i0fA2LVtWzdt61nP3UA1w368RjPUsXLXFYU2PWx8uzMaTieQGnWy/bizadZDuRKZXlXD/z89w/94tcP/PzyYVwCg29+Egb3/zpQ0njBE3Vv/yghyABKGlvY+SKJ3xpEG5lfWDU8YlPCHBrlii+27bOJE1eOm0Gep3y9ccXd/eMf3vF9bx82gzODzCdIEwZWyoqE/ZeWioqGRZURnPMlMpyFpw1iVUb98RU/WX3UHzEaSi6xqZHjoR9+/qc2zWvks09pKqGlyXUnIpm/1XG15xSWc5Nn6zvu9aj8pRjkZJ55/Ti0iCo7dTd0ooj5w2em4BzMebwOmTt5o3xjfk3cv6P70j5LP284+LynCXEvcRMM0mTquFlvPXV83h0xcssWf0Bj654mbe+el6gexIFNe7ecsfdtJf00qJKSmm54+7A8hoGXT3l2hf+v2OCpIvDhzn6mfa577biQaz6b7F5lWzuIeVpfiJd6uqounlhj2tW3byw5/h+MsMQL/tOBbWdSKK0etHV4OXM0sutgR2Q2JjDqwbhpjFOWnhlymfp+x2P4F5iJkzSJBd7RgVlylh19ZU0Lfoxh8aMQ0U4NGYcTYt+TNXVVwaW11Bx6UEOaD7Ai9fc1X3fe0aMZcnf/APPfeG/Zn0PKbchlap/+UHmF/ViQZhsKMuroAhip+r4tCCWXjxOg5dT01aXhnf/vL9N2Nna+Q+3e7psl5ac0Btjimfp+x2P4F5iNgGfAdleTXvTz96jsryUorgXsVOVxuY27p/7xcDSyXtcJnv3j67k8bqXu59fU0sbW/d+yp6WI1x0SlV2Vz9nwy+4l2tGyGd8/PsxfcM6Zqx4gJJdDT22vsmGoUJSMjXmyBL59I7b3lxZZEpleVa3YnAbi83b7b2zhcs4/xvzb6Sl7Wj386sYWkrJwAGcka2GKp5suFz2cM2d/3B7Qoulpn+4nSqPyQTRSYrffqWyvJQ3vzqbtV86p4/13M6DrVSW99QKsuoPpba2T69958/eo3PWRTwaN5/VqUpjhnlI5/kV4jtuw1wRxLZf90htLTv/+YEew3g7//kBJi28Mrznl42xbA/XfPbzZ/UY2js0ZhwvXnMXz37+LE9JBDXs5HX4Jhf+alLNyQSZh3SfXyG+4yZMIkjSsVijm/rGZu4fXdNtJrxo+RruHx3TxkN7ftkYy/ZwzZ0HW/lo1kU9DEM+mnWR555+UPN0Xk2js92Yemnce+dhW9OnvPHHfWxyhFA29z/z/I7n0T6ANswVUbI9lFYIxL/AQPfnqo17Uq+HyCYJhlSyfU2/wyZBDTt5zUe8iXfXsNAlp1cHVmbJ6kZXGvF5eH9XMw0HWvnC8UMYP+q4buHjtROSyfNL+Y4nM/eO4A7ZppkYeYv5iT+G355+UEM+6eQjrcWQXnF68tfNOpHrrzqXL/zmxe6/EtWNrjz8+bhypn1uFBNHD8l4/7PAh+3ybB9AEyZGH6Kw06sXzE/8MfwOjQY17BTqEG2vfcHK9+5i5pLbuwVKsrrht2MS5JBZN3m0DyCYMDF60XuseVvTp1z/zHtc9dRb6b8UWR7vLcRJTD/46ekHKQSyonGQZAfsrno2b16fnnzxkTa+9viilHXDb8ck/vnVNx5iy95P+cLxQzjx+GGZr6FJssVNFDFhYvQgfqx536dH2LL3UxBoPtyR3kuRg224M2oAozKhGZV8xJEtIRAEbjtg7//bK4/VMxe69mRLVjeC6JikPWSWqg5EcJV7UlS1XxynnXaaGqm58dl39f7VH+iiNZt17k9+p/df9j3dO3KsfoZo85hx+twN9+miNZvdL7BiheqECaoxEZL4mDAhFi/XrFihOnhwz7wMHpz7vCTKh4j3Z9P1jEVUR42KHSK5ea7xaecwvU6XuvRZ0YDkda3rmXpg066DumjNZr3x2Xd10ZrNumnXwaR5cnsG8e9Q13H/6g
/0xmff7XkNL3Ux1887AcB69dDGht7I5+rIC2ESgYqzaM1mvfP5jbpozWb90bzvaWvJoB4VvtM5EuYv0QvidoTRiLsJuWSNTTbKJJWwTfZsUj1jL88103vKtTD2UJ/chEzW8ufhGcS/Q11H1+9uMqmLqfKVpbbDhEm+CZOI9Jo37Tqo1z3zjt75/EbdO3Jsei9qqkYyqBcnU7p6/70PkcTxs1Umbvnw8my8PONUwjHTewq6Acw0vbjDTTPpBP3PocO1Y8TI3HQEuq4/YYJ2iui+0ZX63A336f2rP9A7n9+o1z3zTk9NJ926mIwstx0mTHodoQuTVD2HXL+oSehS9ztJs9Hz0kj6fXH8kO4zDqBMEg6deBEIbs/GyzNO9lx93FOnS9qdAZdjw0PLtXnMuJRax5GSUt029zI9UlLaI7y1eJC+NmuuHhnUMzzrHYFeDXr7oFJd8Z0fJh4y8yCUPAvBLLcdJkx6HaEKEy89hyB7KkGRbqPnEt+1Uci1oEy3B5eqTFLMXWzadVCfuuZePVBR6fRWj9d/qr1dl/3Pe7S9d0PnV8B5fa4+6lnzmHEJz20eMy7luV5peGh5H+GQqD41jxmnDQ8t10VrNutzN9ynzWPG6WeI7hkxtnuez3djnSie2/Mf4DJ306ss4gVln87a4MGq3/522lpGtoW8CZNeR6jCxEvPwWvvIpfzKl7mQOLyl6ghyGovMdN78vr8UvUeU8xdvD3n0j733Vo8SJdccac+d8N9um90ZUzQ9m4Msjln4qMXu+I7P+wjBLt630HhJrDc7i9+svu8Ja/q3Idf07kPv6afJdOqvU58J9A0fnvu3L4dgWTlEdegJ3o/+sxBZlA+2RbyJkx6HaEKEy+9wWS95vhKlk7DkymJetwe0o7vJXaK6P7RlXr/Zd/T85a8qv/nu/frkaoTQjUuSJtkZeJBQzhaVJQwfPeIsT0nZdPtIPix5vIxvt67fD1Z96WJay87vsHtlaeu5zj3J7/TC378W73gx79110w8ahBu5XvQued9o2PaZjpCwFOjn4HmmG0h71WYmD+TXODV10QCnwtA323WE1yn/vX3gvGx0ns/IIjZti9bFvveO39xewTlk48GzyQqk9padx8jcSggCcI7ER5Y80F4z8btnlIQv7380NKBtLQdpbm1I/MV7gnyceiGmxO6mj40ZhzD9uxMmqe2jqO8te0ACixoeJ2Lf3IPJfGeOAcPRg8fTlgmKoLE+5txKV8VYcnqD/r6Xkn23jjPVouKYvpSsrST+Oh58n+/kvC9Xrx2S1b9snj1Z+JbmIjIAGA9sFNV54jIJOAZYCTwNvDfVbVdRAYBTwGnAfuAS1T1I+catwLzgc+Aa1V1tRM+G3gAGAA8oqr3OeEJ00iWT7/CxJevBw8VzRU3QRSHinDD028H85L7cLKUc4dHIdJePZ6SnR8njaMDBiAJ/IwfqKjkibpX8vLZBOYYzuWd2D/3UoY8XdfXL0sSD6HxeSoZIAhw5DNN6JjLs7ByeQ8OjRnHfQ+vTOxsLYWQPjS2KnXaCZ5L+6BS1l73fd766nkJ3+vAhXwvvAqTIFbALwTq437/CFisqpOBA8SEBM7nAVX9L8BiJx4ichLwLeDPgdnA/xKRAY6QWgqcB5wEXOrETZZGVvDt68HPtuQe9uFpqagMZPvwpOl5yEc+bG8S1L5j6+YtpH1QqXuEwYORBQvoLOu5HUd7SSm/+x83RPLZeCGwVfIumxiOfGVd2q6m4/P0w4tP5t6LT+b+uV/k/PtupKRhRw/XuS9ecg0dvcqtY1ApL15yTc+LJlh93jGolJWXXsPbOw5yqO0oY4cO6tkWpHDV23LH3QndBLfnk2YwAAAU6UlEQVTccfexgLi2QkU4WFHJuut+wJZzLnB9r6PissKXZiIi1cCTwL3ADcA3gSbgeFU9KiJnAHep6iwRWe18f11EBgK7gQrgFgBV/SfnmquBu5wk7lLVWU74rU7YfW5pJMurH80kzB
53qh5we0kpT17xjxy6+G+CGV7y6f412y6N/RBkD+6mn73H2RvW8vXHFzO0qZHWIeUgUNbSjMT3SuN6q+3jqlk3byEvnTYj1GcTiTLKhntjD6Q1JOSUne7YwYFRx/PyZQv56eSvc8jZw+u0CcMZPaQ0rbZg59JHGHrPnd1pt9xxt6ugjMqwca7c9i4BbgaGOr9HAQdVtWvHtAbo9hpaBXwM4AiBZid+FfBG3DXjz/m4V/hXUqTRAxFZACwAGO9jc7QgfD1k+gKvm7eQGUtu7zH220lsLL5lzDjWzruWtZ8/i5Pi3NSCj91zXVzhet0PKMp+WLz4uPBK1fAy3vrqeWw554LusISNSpwfkhLgfOcIi95uddP12xEY2XBv7IHZU8eybP8s/vBf5/TsUCTSEJ2yE2BPYzM7Nu5hz7s7GTt0EJPHDmH0kFibkE5bUHX1leAIj2HO4Ro3z1z7ZjzMJSJzgL2quiE+OEFUTfFfUOF9A1WXqWqNqtZUVFQkiuIJvzuK+hkme+m0Gay97vscGjOOToSmkWP58RV3cv6Sf+/2pjesdGBww0vZ8BQYEYL0fxLkkF4ut/wPyqOib0LaxDDTIaGuobSLTqnipHHl3YIEsudu+P1dzbzxp3189MmnkR02jsePZvI14AIROR8oJSZklwDDRWSgozlUA10zTg3ACUCDM8xVDuyPC+8i/pxE4Z8kSSMrzJ46lmWvbgPo0Zu55PRqT+f76RHH94Bf/9M+jnTEJnSHFQ8AYhX5z8eVM3vq2OC81mXDU2AECLKnF5SnwFxrCkF5VPRNV/3KwKrML360Z79tQSri68OUymEMLh7A5t2fcrjjM06qLA/UG2XQZCxMVPVW4FYAEfkGcJOq1orIz4C/JmZtdTnwvHPKC87v153/X1JVFZEXgP8tIouAccBk4E1iGshkx3JrJ7FJ+v/mnPOySxpZwW/D4ecFjq+8nxs9uNv08aRxQ7t7Kl15iWoliwpBNwRBPPMgh9680FugNrW08f6uQ3R8pixeuyW38yd52GnJtbvhSRVDGDlkUF5Y/WXDB/x3gWdE5AfAO8CjTvijwP8vIluJaSTfAlDV90XkWWATcBS4WlU/AxCRa4DVxEyDH1PV91OkkTX8NBx+esTxlffTI0f5yudGdps+jhlaHOmeStTIdkOQCbnWFOIFavy6jK98bkR48yd5RjY7bpHRHDPAFi3mAK9WRJGwsjEyItOyC8NSsCuvazbtpmRAEVOrhnXPAeTj2pdCIoprtXK5zsRIgZdJP99rWYzQ8FN2YazNifcIeNbnK3pMJketF5xL44QokA9rtdzIxjCXkYBUqnGux879km5PvJC1Lj9lF+bQW9RNTyNjxpxDojgU6xUTJhEhn8ZK033JA2kUMtxPKhPSFXyZlF0UhGu2LZP8km8drKDIV2MaG+aKCH7XslBXF1u9XlQU+6yrCzyPXaS7VsH32oau/Yq2b4+tmt6+PfY7C/dY39jM+h8u5Yrab/Avl3yJK2q/wfofLk06vOK57Jwy0qIixp48hfGrfhnqkGZUtuFwI8h1QUb2MWESEXyNleawsYX0X3LfjYLLPk7cdpvnPHtl2wOPcMny7zO8qRFRZXhTI5cs/z7bHnjE9RxPZRdXRqLKyE8auWDpXUx56VfhLRwkwL22soDvDpaRU0yYRARfvcQcNraQ/kvuu1Hwsflkukx79F97blsOlBxpY9qj/+p6jqeyS1BGxUfaOPPxRUB6wjXQSekcarTpklRI5zrfEX5OUcGESYTIuJfotbEN6IVIV4vybaHitl9TFvZxGrFvt3t4kueXsuxcymhoUyPgXbgGavWXY402XVyF9Eu/ym2+I/6cooKtMykEvOz0m8h/RHExDBsG+/enPamdU2uuujo6r7qKotZjPffOsjKKli8PfBLebZfmo8NHMrC9LTOfNOBaRs1jxrF4+RrPuxentQ4hldGCzx2iQyPX+c7X5xQQXteZhO5ON1dHqG57s40XV6
we3MyG5pc9BZt2HdSnrrlXD1TEXKUeqKjUp665VzftOtgzYrrubxOxYoV+VlbW47l8VlZ2zHVxGr65++QtgU/xFd/5oS5as7nvvbgQ7/N80ZrNeuvP39O5D7+mZ/3zSz2v46VOZOAiNhLkOt9+0wuiXgZ5nTTBfMAHLExCKkjPpMqf2wuRaeOYQ+L9fHcd3f7Tu/Dh27wPiZ5lEA1YAHUo/t5v/fl7Omvxv+usxf+ucx9+Te98fqNe98w7MYHiwSd5sjgNDy3v4eu94aHlaec1a3i5t6ik57VepqobQdbvNPEqTGyYywtJ3O7WT58T+noBT3hw/wtk3TlRJnhyEpTtoYiIDHXEr9nZtKvZ3VHTrBNTO58K0HVuTvHjBjvX6WU6BN37+iHWP9tOJUhcrKXav3tr2msSQiOR/4hEZNk5USZ4sgbLksVXl+VU3QX/s4+711z43+hN/KT0npYjDCsd2C1IIM4qzIvRgovvmoErV/YQJAAl7W0MvefObN1WeuTa50423G7Hh3uxxsyhRWOmmGbiBRcXo0rML3S8KWn7oFLWXfcDzr/vxgxzmkXiJ2RHjoSWFmhvP/Z/Nnt3PvC0UWYWem690524+t84+6kHGLFvd0/XvCGRdDJ+71sZGy1oURGSqL6LIBHTWqOOm0FHe9UJMf/04M2FsWkmBYJLL6+zqCjtNQmhUlsbq3idnfDJJ/DYY3nhUdHTOo4seO7rvXJ/x+y/4sFH17Jk9Qex5xjys0pmcl0/fQ5182/nYEUlKsLBikrq5t9O/fQ5Ka/bUlGZVrjhzrp5C2nvpdG2Dypl3byFxwK8aJEheaZMBxMmXnApyCKXXprbWoXIES9c0m0cc7yIK+U6jiwMfUR9O49kQnbVxj38ceaFPF73CktWf8Djda/wx5kXelpl33LH3bSX9GoAS0ppuePubN1KwRLvdltFODRmHGuv+z4vnTbjWCQvgiIP3GnbRo9ecHEx2vHdWxOqsB3jqinJcRYzIeO1H70nDLsWcUG4lTtgz31R31UX3DcF9LNxaNXVV7ITGHrPnQxtaqSlopKWO+6OxuR7nhHvdruL5tYOquLqlGcXxhH3TGlzJn7I4WK6oPHqsCshEbFsyja+nlHIRNHJUn8kyDoU1k7TNmeSC2prY4IjTvXMB0ECPnfyzQPLkiCI+q66ychnJ0teyBenWUHVoXxwnmeaST/F09oNN/qJZpLvRMFnSjbIZ40xU8LUNL1qJjZnUoB4aUR8zQfce2/iRVYRsiwx8tfJUir6o9OsfHCeZ8KkwPDq1dCXlz2vE4Y5plB74kZP8qFhTZdUdTcfjEFMmIRAUI1eouus2riHM99czcwVD3Zb4qyddy2rRv5VjzR6+5qevmEdM1Y8QMmuBm/CIWKWJf3RX3h/JR8a1nTwUnej7mIZbAI+5wQ1keZ2nSE//ykXLL2LYXt3IaoM27uLC5bexZgXf95n0hKIrd1o38j5P74jZuas+emvwbdrYCNvKDTjAi91Nx+MQUwzyTFBjfe6Xeeiny2luNeq/OIjbZz/9I+552vnJe79JNsbKELaRzIKcejDSExvrbpqeBmXnF4dqYY1HbzW3ajPgZkwyTFBNXpu1xm9f2/C+MP37XYXYgVg6ltoQx9GcqLesKZDodRdG+bKMb79oae4TsuYxPsnHRh1vPvWIDl0i5stCm3ow+g/FErdNWGSRRItrAqq4rhdp+WOuxPu8/PG/BvdhVgebCKXilyMKefLQjkjv8iH+RAv2KLFLFHf2Mz6Hy7lmz99iPJPdtM8+nhevOQaav7xaoCsWXNNqSxP6Pu7fvqc5Au9UvkL7+f0x4VyhgHeFy2aMAmSuAb58JBySlr/k4FHO7r/DtzXSZoCwNZhZI7tdRUA1mHJS2wFfK7ptZPu4JaDfaJ0+zoJQphksHNvIU1a5hqzFvNJqvpqgibvMc0kKDz6WA/MW13A+2OZ1pKcXGkmBVsOyeqr2/Y8EfPXkTcELJht1+Bc49
GMtmNcdTCOpQI0582HHUnDJhcWNwVdDsnqqxcf6IY3ujTA7dtzvgDZhElQeDCj7Swro+SCOcEUdoDmvLZ6PDW5sLgp6HJIVl8LYJ1TZAhRMJswCYpE5rXFxTBqVE9fJ7/+dTCFHaA5b2Tc0+bYFXC6pHQd7JOdB1s5/XcrmT/vbK6bdSLz553N6b9bWRjzMsnqawGsc4oMIQpmEyZBkchH8+OPwyef9PSxHlRhB+gTOqiFlL4IUT2PCtM3rGPmku/12Fdt5pLvMX3DurCz5p9k9bUA1jlFhhAFs03A55oIOpaKxBqKCD6XXNNePT622Wbv8KoTKGko8CEfs+YKht5Wc+DbmCHrE/AicoKIvCwi9SLyvogsdMJHishaEfnQ+RzhhIuIPCgiW0XkP0Tk1LhrXe7E/1BELo8LP01E/uCc86BIzC2gWxp5QQR7YZFYgWvj5rHt/9MILyhqa2Odhngt3kifAEcs0kZVMzqASuBU5/tQYAtwEvDPwC1O+C3Aj5zv5wMrAQGmAb93wkcCf3I+RzjfRzj/vQmc4ZyzEjjPCU+YRrLjtNNO08iwYoXqhAmqIrHPFSvCzlH4TJigGhvg6nlMmBB2znKHPQMjggDr1YNMyFgzUdVGVX3b+d4C1ANVwIXAk060J4GLnO8XAk85+XsDGC4ilcAsYK2q7lfVA8BaYLbz3zBVfd25oad6XStRGvmB9cL6EkGNLefYMzDymEAm4EVkIvAl4PfAWFVthJjAAcY40aqA+AHhBicsWXhDgnCSpNE7XwtEZL2IrG9qasr09oxcEKZ6HhXsGRh5jO/tVERkCPBz4DpVPeRMaySMmiBMMwj3jKouA5ZBbAI+nXONEIiYK+BQ6KfPoGBX/vcjfGkmIlJMTJDUqeovnOA9zhAVzmeXt6YG4IS406uBXSnCqxOEJ0vDMIw8o6BX/vcj/FhzCfAoUK+qi+L+egHossi6HHg+Lvwyx6prGtDsDFGtBs4VkRGOVda5wGrnvxYRmeakdVmvayVKwzCMPKOgV/73I/wMc30N+O/AH0TkXSfsH4H7gGdFZD6wA5jr/PdrYhZdW4HDwBUAqrpfRL4PvOXEu0dV9zvfvw08AZQRs+Za6YS7pWEYRp5hOzIXBhkLE1X9vySe1wA4J0F8Ba52udZjwGMJwtcDUxOE70uUhmEY+Ueh+EDv79h2KoZhhEqh+EDv75gwMQwjVCKxA4PhG/O0aBhG6JgX0OAIy8zahEmAmK28Ybhj70f2id+0Nd7MOheang1zBYTZyhuGO/Z+5IYwzaxNmASE2cobhjv2fuSGMB3dmTAJiMh4KzSMCGLvR24I09GdCZOAiIS3QsOIKPZ+5IYwzaxNmASE2cobhjv2fuSGMM2szW1vgJi1imG4Y+9HfuLVba+ZBgeI2cobhjv2fhQ2NsxlGIZh+MaEiWEYhuEbG+YyIoONqRtG/mKaiREJbIW0YeQ3JkyMSGArpA0jvzFhYkQCWyFtGPmNCRMjEtgKacPIb0yYGJHAVkgbRn5jwsSIBOZtzzDyGzMNNiKDrZA2jPzFNBPDMAzDNyZMDMMwDN+YMDEMwzB8Y8LEMAzD8I0JE8MwDMM3/cY5log0AdsDuNRo4JMArpNv9Mf77o/3DP3zvvvjPYO3+56gqhWpLtRvhElQiMh6L17HCo3+eN/98Z6hf953f7xnCPa+bZjLMAzD8I0JE8MwDMM3JkzSZ1nYGQiJ/njf/fGeoX/ed3+8Zwjwvm3OxDAMw/CNaSaGYRiGb0yYpIGIzBaRzSKyVURuCTs/2UBEThCRl0WkXkTeF5GFTvhIEVkrIh86nyPCzms2EJEBIvKOiPzK+T1JRH7v3PdPRaQk7DwGiYgMF5HnROQDp8zP6A9lLSLXO/V7o4g8LSKlhVjWIvKYiOwVkY1xYQnLV2I86LRv/yEip6aTlgkTj4jIAGApcB5wEnCpiJwUbq6ywlHgRlWdAkwDrnbu8xbgN6o6GfiN87sQWQjUx/3+Eb
DYue8DwPxQcpU9HgBWqeqJwBeJ3XtBl7WIVAHXAjWqOhUYAHyLwizrJ4DZvcLcyvc8YLJzLAAeTichEybe+TKwVVX/pKrtwDPAhSHnKXBUtVFV33a+txBrXKqI3euTTrQngYvCyWH2EJFq4C+BR5zfAkwHnnOiFNR9i8gw4CzgUQBVbVfVg/SDsibmfqNMRAYCg4FGCrCsVfVVYH+vYLfyvRB4SmO8AQwXkUqvaZkw8U4V8HHc7wYnrGARkYnAl4DfA2NVtRFiAgcYE17OssYS4Gag0/k9Cjioql3+hAutzD8HNAGPO0N7j4jIcRR4WavqTuB+YAcxIdIMbKCwyzoet/L11caZMPGOJAgrWFM4ERkC/By4TlUPhZ2fbCMic4C9qrohPjhB1EIq84HAqcDDqvol4D8psCGtRDhzBBcCk4BxwHHEhnh6U0hl7QVf9d2EiXcagBPiflcDu0LKS1YRkWJigqROVX/hBO/pUnmdz71h5S9LfA24QEQ+IjaEOZ2YpjLcGQqBwivzBqBBVX/v/H6OmHAp9LKeAWxT1SZV7QB+AXyVwi7reNzK11cbZ8LEO28Bkx2LjxJiE3YvhJynwHHmCR4F6lV1UdxfLwCXO98vB57Pdd6yiareqqrVqjqRWNm+pKq1wMvAXzvRCuq+VXU38LGIfMEJOgfYRIGXNbHhrWkiMtip7133XbBl3Qu38n0BuMyx6poGNHcNh3nBFi2mgYicT6y3OgB4TFXvDTlLgSMiZwK/Bf7AsbmDfyQ2b/IsMJ7YyzhXVXtP7BUEIvIN4CZVnSMinyOmqYwE3gHmqeqRMPMXJCJyCjGDgxLgT8AVxDqZBV3WInI3cAkx68V3gCuJzQ8UVFmLyNPAN4jtDrwHuBP4NxKUryNYHyJm/XUYuEJV13tOy4SJYRiG4Rcb5jIMwzB8Y8LEMAzD8I0JE8MwDMM3JkwMwzAM35gwMQzDMHxjwsQwDMPwjQkTwzAMwzcmTAzDMAzf/D/mXk3Y7uFU6QAAAABJRU5ErkJggg==\n", 652 | "text/plain": [ 653 | "
" 654 | ] 655 | }, 656 | "metadata": {}, 657 | "output_type": "display_data" 658 | } 659 | ], 660 | "source": [ 661 | "plt.scatter(range(0,len(y_val)),y_val,alpha=0.5,label='True')\n", 662 | "plt.scatter(range(0,len(rf_val_pred)),rf_val_pred,c='r',label='Pred')\n", 663 | "plt.legend()\n", 664 | "plt.show()" 665 | ] 666 | }, 667 | { 668 | "cell_type": "markdown", 669 | "metadata": {}, 670 | "source": [ 671 | "## My Random Forest" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": 86, 677 | "metadata": {}, 678 | "outputs": [ 679 | { 680 | "data": { 681 | "text/plain": [ 682 | "391" 683 | ] 684 | }, 685 | "execution_count": 86, 686 | "metadata": {}, 687 | "output_type": "execute_result" 688 | } 689 | ], 690 | "source": [ 691 | "X_train.shape[0]" 692 | ] 693 | }, 694 | { 695 | "cell_type": "markdown", 696 | "metadata": {}, 697 | "source": [ 698 | "One possible thing I can tune here but not in sklearn Random Forest is the size of data subset for each tree (sample_sz)" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 134, 704 | "metadata": {}, 705 | "outputs": [ 706 | { 707 | "name": "stdout", 708 | "output_type": "stream", 709 | "text": [ 710 | "MSE train score: 3372489559.856298\n", 711 | "MSE val score: 3477783760.5038266\n", 712 | "R2 train score: 0.8809842267868835\n", 713 | "R2 val score: 0.8417577977080082\n", 714 | "Wall time: 221 ms\n" 715 | ] 716 | } 717 | ], 718 | "source": [ 719 | "%%time\n", 720 | "tb = RandomForest(X_train, y_train, n_trees=20, sample_sz=150,min_leaf=4,max_features=1)\n", 721 | "tb_train_pred = tb.predict(X_train)\n", 722 | "tb_val_pred = tb.predict(X_val)\n", 723 | "print(f'MSE train score: {metrics.mean_squared_error(y_train,tb_train_pred)}')\n", 724 | "print(f'MSE val score: {metrics.mean_squared_error(y_val,tb_val_pred)}')\n", 725 | "print(f'R2 train score: {metrics.r2_score(y_train,tb_train_pred)}')\n", 726 | "print(f'R2 val score: {metrics.r2_score(y_val,tb_val_pred)}')" 727 | ] 
728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": 135, 732 | "metadata": {}, 733 | "outputs": [ 734 | { 735 | "data": { 736 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAD8CAYAAACyyUlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAIABJREFUeJztnXt8VPW16L8rmJCgJLwChkQC9mLFcnwGi63HUx9gtPhobz3WGyrXE+C09Q3eorVWa7XVU4uPyrEf8FG85CjU2vo4CoLaa2uLCoqWgigtggkxhEcCloSA+d0/9p4wM8ye7Jm9Z/aeyfp+PvOZ2Wv247efa//WWr+1xBiDoiiKonihIOgGKIqiKLmPKhNFURTFM6pMFEVRFM+oMlEURVE8o8pEURRF8YwqE0VRFMUzqkwURVEUz6gyURRFUTyjykRRFEXxzGFBNyBbDBs2zIwePTroZiiKouQUq1ev3m6MKe9tvj6jTEaPHs2qVauCboaiKEpOISKb3czXq5lLRB4VkW0isjZKNkRElovIh/b3YFsuIvKAiGwUkfdE5OSoZabZ838oItOi5KeIyF/sZR4QEUl3G4qiKEowuPGZ/AqojZPdCLxsjBkLvGxPA5wHjLU/M4GHwFIMwK3AF4FTgVsjysGeZ2bUcrXpbENRFEUJjl6ViTHmNWBnnPgiYKH9eyFwcZT8cWOxEhgkIhXAucByY8xOY8wuYDlQa/9Xaoz5s7HSFz8et65UtqEoiqIERLo+kxHGmGYAY0yziAy35ZXAx1HzNdqyZPLGBPJ0ttGc5r4oiqI4sn//fhobG+ns7Ay6KRmluLiYqqoqCgsL01rebwe8JJCZNOTpbOPQGUVmYpnCGDVqVC+rVRRFOZTGxkYGDhzI6NGjsV26eYcxhh07dtDY2MiYMWPSWke6yqRFRCrsHkMFsM2WNwJHRc1XBWy15V+Jk//ellclmD+dbRyCMWY+MB+gpqYmlFXA1je3s3RtC01tHVQOKqF2/AjGVZQF3SxFUWw6OzvzWpEAiAhDhw6ltbU17XWkO2jxWSASkTUNeCZKfrkdcTURaLdNVcuAySIy2Ha8TwaW2f/tEZGJdhTX5XHrSmUbOcf65nbmv7aJ9o79VJQV096xn/mvbWJ9c3vQTVMUJYp8ViQRvO5jrz0TEXkCq1cxTEQasaKy7gKWiEg9sAW4xJ79BeB8YCOwF7gCwBizU0R+DLxlz3e7MSbi1P8OVsRYCfCi/SHVbeQiS9e2UFZSSFmJZaOMfC9d26K9E0VRcopelYkx5jKHv85OMK8BrnRYz6PAownkq4DxCeQ7Ut1GrtHU1kFFWXGMbGDxYTS1dQTUIkVRwsaOHTs4+2zrUfjJJ5/Qr18/ysutAelvvvkmRUVFQTavhz4zAj6MVA4qob1jf0+PBGBP5wEqB5UE2CpFSR31/WWOoUOHsmbNGgBuu+02jjjiCG644YaYeYwxGGMoKAgu3aImegyQ2vEjaO/YT3vHfrqN6fldO35E0E1TFNeo7y+W9c3t3Lv8A2749bvcu/yDjB2HjRs3Mn78eL797W9z8skn8/HHHzNo0KCe/5988kmmT58OQEtLC1//+tepqanh1FNPZeXKlb63R5VJgIyrKGPmGWMoKymkub2TspJCZp4xRt/olJwi2vdXINLze+nalqCblnWyrVjXrVtHfX0977zzDpWVlY7zXXPNNXzve99j1apVLFmypEfJ+ImauQJmXEWZKg8lp1Hf30GyHVTzuc99jgkTJvQ634oVK9iwYUPP9K5du+jo6KCkxD+TuioTRVE8ob6/g2RbsR5++OE9vwsKC
rDikyyiR+wbYzLurFczl6IonlDf30EqB5Wwp/NAjCxbirWgoIDBgwfz4Ycf0t3dzW9/+9ue/8455xzmzZvXMx1x6Pu6fd/XqChKn0J9fwcJWrHefffd1NbWcvbZZ1NVdTC5yLx583j99dc5/vjjOe6441iwYIHv25boblE+U1NTY7Q4lqIoqbJ+/XrGjRvnfv4cDpNOtK8istoYU9PbsuozURRF8ZG+GlSjZi5FURTFM6pMFEVRFM+oMlEURVE8o8pEURTvNDTA6NFQUGB9NzQE3SIly6gDXlEUbzQ0wMyZsHevNb15szUNUFcXXLuUrKI9E0VRvHHzzQcVSYS9ey254gv9+vXjxBNPZPz48VxyySXsjT/eKfD73/+eKVOm+Ng6C1UmQaPmASXX2bIlNbmSMiUlJaxZs4a1a9dSVFTEL3/5y5j/jTF0d3cH1DoLVSZBEjEPbN4Mxhw0D6hCUXKJUaNSk+c7GX5B/Od//mc2btzIRx99xLhx4/jud7/bk4L+pZde4rTTTuPkk0/mkksu4dNPPwVg6dKlHHvssZx++uk8/fTTvrYngiqTIFHzgJIP3HknDBgQKxswwJL3NTL8gnjgwAFefPFF/umf/gmADRs2cPnll/POO+9w+OGHc8cdd7BixQrefvttampqmDt3Lp2dncyYMYPnnnuOP/zhD3zyySe+tCUeVSZBouYBJR+oq4P586G6GkSs7/nz+6bzPUMviB0dHZx44onU1NQwatQo6uvrAaiurmbixIkArFy5knXr1vHlL3+ZE088kYULF7J582bef/99xowZw9ixYxERpk6d6qktTmg0V5CMGmW9uSSSK0ouUVfXN5VHPBl6QYz4TOKJTkFvjGHSpEk88cQTMfOsWbMGEfG0fTdozyRI1DygKPlFgP6jiRMn8vrrr7Nx40YA9u7dywcffMCxxx7Lpk2b+Nvf/gZwiLLxC1UmQaLmAUXJLwJ8QSwvL+dXv/oVl112GccffzwTJ07k/fffp7i4mPnz5/PVr36V008/nerq6oxsX1PQK4qiJCHVFPQ0NFg+ki1brB7JnXfmzAuipqBXFEUJC33Uf6RmLkVRFMUzqkwURVF6oS+4A7zuoyoTRVGUJBQXF7Njx468VijGGHbs2EFxcXHa61CfiaIoShKqqqpobGyktbU16KZklOLiYqqqqtJeXpWJoihKEgoLCxkzZkzQzQg9auZSFEVRPKPKRFEURfGMKhNFURTFM6pMFEVRFM+oMlEURVE8o8pEURRF8YwqE0VRFMUznpSJiFwvIn8VkbUi8oSIFIvIGBF5Q0Q+FJHFIlJkz9vfnt5o/z86aj032fINInJulLzWlm0UkRuj5Am3oSiKogRD2spERCqBa4AaY8x4oB/wTeBu4F5jzFhgF1BvL1IP7DLG/A/gXns+ROQ4e7kvALXAf4pIPxHpB8wDzgOOAy6z5yXJNhRFUZQA8GrmOgwoEZHDgAFAM3AW8JT9/0LgYvv3RfY09v9ni1VL8iLgSWPMPmPMJmAjcKr92WiM+bsxpgt4ErjIXsZpG4qiKEoApK1MjDFNwD3AFiwl0g6sBtqMMQfs2RqBSvt3JfCxvewBe/6h0fK4ZZzkQ5NsQ1EURQkAL2auwVi9ijHASOBwLJNUPJFUm4kq2hsf5YnaOFNEVonIqnxP0qYoihIkXsxc5wCbjDGtxpj9wNPAl4BBttkLoArYav9uBI4CsP8vA3ZGy+OWcZJvT7KNGIwx840xNcaYmvLycg+7qiiKoiTDizLZAkwUkQG2H+NsYB3wKvANe55pwDP272ftaez/XzFWgYBngW/a0V5jgLHAm8BbwFg7cqsIy0n/rL2M0zYURVGUAPDiM3kDywn+NvAXe13zgTnALBHZiOXfeMRe5BFgqC2fBdxor+evwBIsRbQUuNIY85ntE7kKWAasB5bY85JkG4qiKEoASD5XD4umpqbGrFq1KuhmKIqi5BQistoYU9PbfDoCXlEURfGMKhOl7
9LQAKNHQ0GB9d3QEHSLFCVn0bK9St+koQFmzoS9e63pzZutaYC6uuDapSg5ivZMlL7JzTcfVCQR9u615IqipIwqE8U7uWgu2rIlNbmiKElRM5fijYYGumfMoKCjw5revNmaBlfmovXN7Sxd20JTWweVg0qoHT+CcRVlGW0yAKNGWaatRHJFUVJGeyaKJ7rm3HRQkdgUdHTQNeemXpdd39zO/Nc20d6xn4qyYto79jP/tU2sb27PVHMPcuedMGBArGzAAEuuKErKqDJRPFG4tTEleTRL17ZQVlJIWUkhBSI9v5eubfG7mYdSVwfz50N1NYjQVXkUL1x9OzcUjefe5R9kR6EpSh6hykTxxK6hR6Ykj6aprYOBxbGW1oHFh9HU1uGwhM/U1cFHH7G+aRdzfv4sb36pNvs9JEXJE1SZKJ5YWT+brv7FMbKu/sWsrJ/d67KVg0rY03kgRran8wCVg0p8bWNvBNpDUpQ8QZWJ4okx105n8YxbaCuvwIjQVl7B4hm3MOba6b0uWzt+BO0d+2nv2E+3MT2/a8ePyELLDxJ4D0lR8gCN5lI8Ma6iDL5/JY9d+I2UI7LGVZQx84wxMdFcl06oyk40VxSVg0po79hPWUlhjyyIHpKi5DKqTBTPjKsoS1sBeFnWL2rHj2D+a5sAq0eyp/MA7R37uXRCVaDtUpRcQs1cSp8n0kMqKymkub2TspJCZp4xJnAlpyi5hCoTJVhCMnp+3CvPc/2Mydxz6UlcP2My4155PpB2KEquomYuJTjCkmwxLO1QlBxGi2MpwTF6dOKUJtXV8NFHfa8dihJCtDiWEn7CkmwxLO1QlBxGlYkSHE5JFbOdbDEs7VCUHEaViRIcYUm2GJZ2KEoOo8pECY64ZItUV1vT2XZ6h6UdipLDqANeyTqB1TBRFCVl1AGvhJJAa5goipIxdJyJklWiM/QCPd9L17ZkvXeiPSRF8Q/tmShZpamtgwl/epH6qWdy3bnHUj/1TCb86cWsZ+jVHpKi+IsqEyWrnLV6BZPuu4XSbVsRYyjdtpVJ993CWatXZLUdWsNEUfxFlYmSVc5ZdD9F+zpjZEX7Ojln0f1ZbYfWMFEUf1GfiZJVihxqwzvJM0Uu1DBRn46SS2jPRMkuIRltHpYqj06oT0fJNVSZKNklJKPNw17DRH06Sq6hZi4lu0RGld98s5VIcdQoS5EEMNo8DFUenWhq66CirDhGpj4dJcyoMlGyT12dpirphVzw6ShKNKpMFNeoQzh7aF16JddQn4niCnUIZ5ew+3QUJR7tmSiuCFMalL5CmH06ihKP9kwUV+T0IL+GBqs0b0GB9d3QEHSLFCXvUGWiuKJyUAl7Og/EyHLCIdzQADNnWjXejbG+Z87MDYWiSlDJITwpExEZJCJPicj7IrJeRE4TkSEislxEPrS/B9vziog8ICIbReQ9ETk5aj3T7Pk/FJFpUfJTROQv9jIPiIjY8oTbUDJH2Af5OXLzzbB3b6xs715LHmZyWQkqfRKvPZP7gaXGmGOBE4D1wI3Ay8aYscDL9jTAecBY+zMTeAgsxQDcCnwROBW4NUo5PGTPG1mu1pY7bUPJEDnrEN6yJTV5WMhVJaj0WdJWJiJSCpwBPAJgjOkyxrQBFwEL7dkWAhfbvy8CHjcWK4FBIlIBnAssN8bsNMbsApYDtfZ/pcaYPxurHOTjcetKtI38IYQmjnEVZVw/6RjuueQErp90TPgVCYQmfUvK5KoSVPosXqK5jgZagcdE5ARgNXAtMMIY0wxgjGkWkeH2/JXAx1HLN9qyZPLGBHKSbCMGEZmJ1bNhVNgfHtFETByRN9OIiQN0sF+q3Hln7LGEQNK3pMyoUdZ5TyRXsoKOq0oNL2auw4CTgYeMMScB/yC5uUkSyEwactcYY+YbY2qMMTXl5eWpLBosauLwj7o6mD8fqqtBxPqePz/8SjkkOcz6KjquKnW8KJNGoNEY84Y9/RSWcmmxTVTY39ui5j8qavkqYGsv8
qoEcpJsIz9QE4e/1NXBRx9Bd7f1HXZFArmhBENoivULTbSZOmkrE2PMJ8DHIvJ5W3Q2sA54FohEZE0DnrF/Pwtcbkd1TQTabVPVMmCyiAy2He+TgWX2f3tEZKIdxXV53LoSbSM/yFU7v+IvYVaCeR5tltPjqgLCazTX1UCDiLwHnAj8BLgLmCQiHwKT7GmAF4C/AxuBBcB3AYwxO4EfA2/Zn9ttGcB3gIftZf4GvGjLnbaRH6iJQwk7eW6KdT2uKo97Z6kiVqBU/lNTU2NWrVoVdDPc09AQijTtipKQggKrRxKPiNWTynEiPpOyksKYRJsx4fDxgTJgvfSFzRzpERFZbYyp6W0+HQEfVsJs4lCUfDTFRvUyxp12AjdsX5V8XFWe985SRXsmSmjR0MwQ09BA94wZFHQc9CF0l5RQsGBBbr74pNPLyPPeWQTtmSg5jYZmhpv1Z02hof4HtJVXYERoK6+gof4HrD9rStBNS490ehn52DvzgKagV0LJ0rUtnP7mMiYteoCBrc3sKa9g+dRrWDrka9o7CQFL17bQPukiHrvwGz2y9o79bM/VkgTphOPn6oDYDKE9EyWUDH/uN1w47zZKt21FjKF021YunHcbw5/7TdBNU8jD0Nl0ehm5MBYoi6gyUULJBYsfpHBfZ4yscF8nFyx+MKAWKdHkbEkCJ9INx9dAmR5UmSihZGBrc0pyJbvkbEkCJ7SX4RlVJkooEQfzgpNcySAJBublbEmCZGgvwxPqgFfCiTo3fSXtMOv4EODNm61pYFxdXW4rD8VXtGeihBM1O/iGlzDrrjk3xYwlASjo6KBrzk3uNh6SdCPrm9u5d/kH3PDrd7l3+QeZDzEPyX5nEx20qOQcfXEwo5d9vnf5B7R37KespLBHFpm+ftIxSZc1BQVIgmeEEUGcBuZFUgFt3my9CEQvH0C6EVepUfwkS2lWmuY9zMDbb+0Jnd/zwx9ReeV039YfQQctKnnJ+uZ2Vv1kHlfUfYWfXXoSV9R9hVU/mZfXgxm9DuBsautgwp9epH7qmVx37rHUTz2TCX960VUY766hR6Ykj8kmDIeOEA8g3UhkzNKsGZOZVTuOWTMmc/qbyzKXTj4LaVaa5j1M+ayrY0Lny2ddTdO8h33bRqqoMslzmuY9zO4RlZiCAnaPqAz0YvODTfc/zKULfsyg1mbEGAa1NnPpgh+z6f7w7Jffx9xrbY2zVq9g0n23xDx4Jt13C2etXhE7YwLTzMr62XT1L46Zrat/MSvrZyfeWKIHaRwmy3V5ko5Z8miOSmg+y0I9ooG330pRV2zofFFXJwNvv9W3baSMMaZPfE455RTjF40PLjDtw0eabhHTPnykaXxwgW/r9pPGBxeYfUXFxljvh8aA2VdUHNr2umHHsIqY/Yl8dgyrcFxm3dY2M/elDWb2kjVm7ksbzLqtbRlrXyaO+ewla8zzc34Wc809P+dnZvaSNa6W31d5VMJjtq/yqIMzLVpkzIABsfMMGGAaH1xgHr/qTrOrvMJ0i5hd5RXm8avudD6GIgm3Ff1pHz4y7WORDu3DRyZsxz8GDkq4z2bRIlfrXbe1zVz35Dvm1mfWmnuWvW9ufWatue7JdxyPt6mu9m2fuh2Oc7eIb9uIAKwyLp6xgT/ks/XxS5nk0gPa6SbK9s3sJ443UeRmjTwIFi0yprradIuYHcMqzFOz7oq54TOlUDJxzP97zj1mX/+4a65/sfnvOfe4W4HTAz76wVNd7fgATEkZO63H/nT1LzaLrv5J2sciHZJeMx4e+nNf2mBufWatmfvShp7Prc+stc6LByXlhmze226ViTrgU2T3iEpKt209VD58JKUtTZ7X7ydpOU9DTlfVKIqaPnaeYcAAmDYNFi6MMbfs71/M8uvuYMPZF7h2PqdDJo650z53VR5FUaML08no0Qd9GNFUV1vjKcC/DLgJnM+Rte4ZPpLlU69hS+3XMnLsHXHafydc7vMNv36XirJiCkR6ZN3G0Nzey
T1dazNajyjiM4k2dXUVFdM69xe+O+HdOuB1nEmKZGtkth8RS3vKKxIqvj3lFZT61dAsU3T3Tw9JfR7D3r1W1Mxnn8WIC/d1cvpjc9lw9gUZzSGViWNetLUxJfkhuBmzM2pU4gduqoNEIw/Mm2/GbNnCrqFH8url1/LRuRcfjKLK9ih5p/0vKYEdOw6d3+U+Vw4qOSRKrielzKS6XpWHl3u88srpNEFWorncog74FNlTXpGSPB38Sr++54c/oqsoznlaVMyeH/7It7Zmnbo6q2ZGdbXjLCZOkUSIKPxM5pDKyDH3murczZgdP0tF2yPJpbublvfWs6X2awlHyWdt7IfT/t9/P90lsddBd0mJ6332klLGj3u88srplLY0Id3dlLY0BapIQJVJymTjAe01eidC5ZXTaZ37C3YPH4kRYffwkRnpBmedSNoLB4ViCvollG8fPDzjOaQycsz9eND3liokQ4NEx1WUcf2kY7jnkhO4ftIxMYokq/VqEuy/15osXlLK+HWPhwn1maRBpgcLJbXFXnKCb9vJeRwGh/3x9K/yxf/3XEzW4c6i/vzH12cz8N+m5eYgx8hAwCQ2+GwNYusNN+YbLwMp/SLINuTSPa4+kwxSeeV0sG/SUvvj6/qT2WJzCRcPQE9E2eejt/HW8Am0jD/lkMJaA7Pt+PWTuuQ2+HiHbOm2rRTPupomcK1Q/PDTRY82j+5xxL+xN7V1UFEW28PPdj0Uv9uQyvHLm3s8ClUmIaR2/Ajmv7YJICb9w6UTqgJuWQrE9xo2b7amwVeFsv6sKSxdMCHmBq4F5u88l7/8y5TY9Bm5mh7dBUkHsblQJm6VQG9Em2+Anu+lcRUYs/Ew7e3h7mcbUj1+eXGPx6E+kxAS2vTeqYwWzkJKCSe7OxDO45dBvEYZ+mXDd1uBMdP1UNz4ZOLbsKn1U1b+bQfr7MCAVPw3qR6/0N7jHlBlElKcHJeBEZ1zyZiDPY14hRJROE5x/T6mlEh2AzsevzzN5uo1ytB1Gd5ejp/bCoyZfpi6ebhHt2F9824+2PYpnz/yCI49sjSt/GepljH29R4PwXWtykRJTuQinTq1955GfJK/RPhY3CrlG9itQsxBvEYZulICLo5fKj2OlB+mTg/MBHK310akDV8YWcbEo4cyetgRafXM0ipj7JcCCMl1rcpEOYRI/H/DNT9lf/305MohuqfRW5I/n4tbpXwDZ8H0FhReQ5JdKQEXxy9jPQ6nB+Z3v5tQftbqFSldG+n0LKJJ2WTmpwIIyXWtocFKDNGOxOtnTKYswWjuGNyk5IjM53M0V8p1KvxKGRImfIyY6zUaKYjjF10bJRH9+h2S7QCsVDNzfv6s62sjpTBhh2PeE5q9rZltg4ez7LKrab3oG4m37SbFjVsyfF60nomSFtG25tLenLeJUnIkIqJIbr7ZV5tuym/B6YwkD4Et2hGfzRu9mp28jsRPFRdmU6dsB0VNH3P37As59U9LXV0brs1zSXpIld+71kpzj2HErhYue+QOxr3yfGKTmZ9p6rN9Xpxwkw0yHz5+pqDPZ2YvWWPuWfa+mfvSBrNtyAjnbKvRGXojJEhj3tW/2Pxh8iWmKy7rrd9ZVF3hkGbdsR2pzp9qW6qrrcy9Q4daH5HEx9WJJJl+M0Imj0cieslAbMB8VtAv+TwpppTvNTuyU5v6JW5H+/CRZu5LG8w9y96PLRng9txFXydO10aGzwuagl6VSTpEp9W+e+otpqOo/yHKIWkK8QSp39sc0mVn7KGXDDc3Z4RUHtaprDfRzZ/Og8BNanm/SWU/vdJLbZTIi0rSY+n3deaiXkvMi5dIzD3VgxsFkIqSyOB5UWWiyiQtogv+XPLQ6+andT8wnwwe0VOU6alZd8XeFA5E30BO9SQy+tDzA7ftTvXN0MUbt6sHYLo9Ew8PnmwWGnPav277jT9yLfYUq3M6ln5eZyn2TNqGj3SuodPbech2z9MBt8pEfSbZIsy29yjGVZRxw/ZVz
JoxmcXfOZ0rXljAssuuZu7S9cxd8BJ/PPVcVwPLoqNjHMc6ZNummypubdFJomlSKusaTbJ5osfyROV2AnqPmPPgZ8lWcsaessebN2OI3b+u/sW8MOdnPdfiMSMO555hNcxd8BJtwxJfZ8YY/+45p6SbM2ceIt/fv5jnL73K2V/jlHwzjbFaWcvAnARVJtkgJHHgrmhoSOhIPPL5p12FeUYu6r9ubee1D1rZ/mknf7xiFvvj6oj7HSacEdxm63V48JstW2IevptaP+X6J99l+xAXo7ydFFm8U9qYgwrFTaZfD2Gk2ch0G8kxZl1/IBgMYLAitFZcdwevnjKp51r8oOUfPe14/IKZdBb2P2SdAv7dc07Zlf/zPw+RFz7yMHUP3BQTzNCjKAsK2D2ikqZ5D8euP42xWlnPwOyAhgZnA69hgJlOmBiNh7ZGh+p27j/AW5t2YYAvHj2YmteXcubj9zN4xydIpvfBT9wce4djtnv4SOYueImykkJa93Ty9pY2AM577xW+8193U7Sv85BlAEthOSkFr9eShzDSbGS6TbWSaXSblq9r4dx3X+Z/PfNLhu1sievT2KQTeos/STBdVUfsrSpkgmsj09mPNTQ4THgJA0zUq/nWt6ybPxPmMg9tjX5zHVFawhePHsLA4sN49+PdbKn9Gi3vrbdK1yaqpxFWeqsDAo49mOcuvYoJf3qR+qlncsf/PJHHb/9XznvvFZaecBYrrrujp44GQ4daHzd1RLyGlHoII01rlHeKpJpjLLpNRxQfxisnn8P0Hz11iHmshy1beu8dRLDNTaaggBHHj2PU0t96evNPmowzqn2OOFwbXgdc+oUqk2zgJQ48kVki8mYZ1XX3zWbqoa3xF3X5wGLOOKac40aWhiO/mM/0HPOi8bxw9e10VR4VoxAGDyhi0n23ULptKwUYhu9q4d8X3cV5777ChrMv4JFFr/J/Fr8D27dbHzeK1uuYAg+FtjKdnBFSzzEW3aajhw3g084D1iDBYUcmnP9A2eCDZjRjKN22lfJZVyc1N4kxDNnezIXzbnMeN+ICV4rS4TzuHFbBvQteSli4KxtK3g2elYmI9BORd0TkeXt6jIi8ISIfishiESmy5f3t6Y32/6Oj1nGTLd8gIudGyWtt2UYRuTFKnnAbmcTTw9pLpbze3jj37qVrzk3+2Uw9tDUsF3U2iLdTv/mlWub8/FnWN+3qUQjnLLr/EFNW8f59fOv5+UCax8Zr1UUPFRWzkek21Rxj0W060A1fPHoIpx09hBcuuzrOU3cUAAAV2klEQVShn66ru7v33gEkfIkr3NfJ6Y/NpXVPJ+u2tvO7NU0pPQtcKcoE57erfzF/rp/leF9nQ8m7wbPPRERmATVAqTFmiogsAZ42xjwpIr8E3jXGPCQi3wWON8Z8W0S+CXzNGHOpiBwHPAGcCowEVgARQ98HwCSgEXgLuMwYs85pG8na6cVnknLajkSk6/fozYYKGBF+9Lu/+GczTbOtvhynDOOH7Rtc2qkd/BOR85X2scmmDy0AfKsYmeA4mW99C3E4JxLtM0py7s6b+3sASosP47iRZa7PoyufSVS7zZYttA87kj/922w2nH0B4Hxf+3VdJ8Ktz8STMhGRKmAhcCcwC7gAaAWONMYcEJHTgNuMMeeKyDL7959F5DDgE6AcuBHAGPNTe53LgNvsTdxmjDnXlt9ky+5y2kaytnpRJkGW90x0AcazfeiRLPyv34eiBGgmL2qv+Knsbvj1u5y5ejn//Ni9PQ+9P1xxPa+eMungMXd4Edg5rIKF//X7wI5NmM9RpnHt4Hc4d61DRvCtW5YAcEr1IIYdUZzSsyAVRRmW0r7ZcsDfB3wPiKj0oUCbMSZi62gEKu3flcDHAPb/7fb8PfK4ZZzkybYRg4jMFJFVIrKqtbU13X30xcGVrplsyTFn8NxVt7F7+Ei6sUIko+nqX0zDhf8eGvNS6OqwROFnaOtZq1f0+EMitvdJ993CWatXHJzJw
SQ15L6fBXZswhJGGhSuzWgJzt3+/sU8NLme0uLDehQJpPYsqLxyOqUtTUh3N6UtTUl7XLlmNk5bmYjIFGCbMWZ1tDjBrKaX//ySHyo0Zr4xpsYYU1NeXp5oFld4PalebuCmtg4+OvdiHln0KlPue40HrriV1iEj6MZKM778uh/zl3/5aihspmHHz6iXRP6Qon2dnLPo/oMCl/6JbA44y8ZYkTDjOlV/gnNX+MjDDPy3aRw3sqxHkUBmyg1Hxmqt/PsOPtr+aU7c115qwH8ZuFBEzgeKgVKsnsogETnM7jlUAZE+ZSNwFNBom7nKgJ1R8gjRyySSb0+yjYzgtV6z27rYiYiuUx0JfXzl5HMoLuzHxKOH0t6xny+UFFI7fkSM6eLSCVWh6hWEAT9rfhdtbXQnr6tL6s/wq/a6W5raOqgoi30zDyKMNEgqr5wOtvIotT8JSXDuau3zBZmp3R59PYyrKGVAYT82fPIpe/d/xnEVZaG+r9PumRhjbjLGVBljRgPfBF4xxtQBrwLfsGebBjxj/37Wnsb+/xU778uzwDftaK8xwFjgTSyH+1g7cqvI3saz9jJO28gIXqNYvLwRO4U+Hl0+IOZNJczmpbDga9SLT2m/s91TiO9lt+7p5LUPWlm3dXdgaThyiWyXGx5TfgQTPzeU4+z7O8z3tZeeiRNzgCdF5A7gHeARW/4I8H9FZCNWj+SbAMaYv9rRWeuAA8CVxpjPAETkKmAZ0A941Bjz1162kTHGVZSlfSK9vBFHLt6la1v4dN8Bvnj0EATY95lh+MDCUL+phI3oY+m5B3fnndY4hOjw0TRSxGS7pxDdy47PUpDpXlG+4OVZ0Bu53HPUdCpZwG0UUV+OsslJokJPu0ZWsWLqtbxyyjnO5y5BqOq9wydkPVIwcp29tO4TivoVML6ytMcHkK0oRSUxQUaOOpGV0OBcIuiyvb0pilwYo6EkxtW5i4yojuvJNP3H/dwzrCaQ8x6W0NNk9LUXrDA+B9wqk0yYuZQE9NY1Xrq2hdPfXMakRQ/0xKAvn3oNS4d8LZQ3T6o3eU49FFIcFOgqwMIhW2/lz+5g5p/fDSR4ws+AhEyQ7eCEMOCrKTbLqDIJCcOf+w0Xzr+dQjvctHTbVi6cdxtL9nfDpJt6WTq7pHqT59RDoaGB7hkzKOiwbdSbN1vT4KhQXNm5k6SpX7q2JRDl6jVKMdN4iYLMZTLpk8kkmugxJFyw+MEeRRKhcF8nFyx+MKAWOZNqBFIujW3omnPTQUViU9DRQdccZ4XuahySQ5TX7vKKwAYOZiPXlhfCkg1XcYcqk5CQaurtIEn1Js+lh0Khw/gRJzm4DDl2GFH9+hWzAlWugYaU91J9NNdGgGeckFdrVWUSEsThzdVJnhY+XYyp3uRZeyj4sH+7hiZOXe4kB5dv+FEjqo0I7cNHsvy6O3oS+KWiXMNQotUzLqqPhiUbbijIgWqtqkzCgtfU4r09SH28GFO9ybPyUPBp/1bWz6YrLnV5V/9iVtbPTrqcqzd8u9DWfcve594FL/UoEnCvXH3PrRXU266L8sFJlXTI39J9x0O55WyhocFhIt3U4g5hpzF5oLyWe40j29FcvS7v0/6tb25n1U/mccHiBynb/gntw47kuUuvoub7V/pmAvIS/unrOAQ3102m8FA+ONB2B4WX4+URHWcSR04ok3Rx8yB1uhjBuiBDXBfD1cPXx5stG2HM6W4jfmxI655ONm77lJY9+7j4xMrY9fT2cuLzC0ZKeNl2kO0OigD3WceZ9CXc1AUfNcq5yFa0WQhCp1BchYg67V8aPqdshGamu43osSGtezp5e0sbACMG9o8NuX7l+di390TnN8l1k3GF6iUdjZvrPdfoTfH7lL4nk6jPxC0ONtqmeQ+ze0QlpqCA3SMqD60lnQ3cJB1M5JOJJ2Q22AiuosG8+pxyhGj/08Ztn/bIx444IjYqzI2N3eG66RpZlfmaJx7KB/uVZDNr+OHP9HK8soSau
dzgYKPdecllHPFEQ+9lODNNovaJWBdmdfXBt5zot59kJq8M22BTxbWfIAPlbMM4cj/Spt+taWLEwP6MHXFET26tnnQol57Uu9nP4bp+4erbefNLtaHKDxVDLvlM4gfBAt0lJRQsWJAxf6bfqM8kDk/KxOFkdxf0o6D7s0Pkh5QAzQaRB+nmzQcVSYREN1rIL+BogspXFL/d0ct+x5mP38/gHZ9YIdsB+5iSKtkZk92d3wQK+Iai8b2XJA6aDLw4ZIKuqlEUNX18qLzyKIoabbNcgM51N6gyicOTMnE42Y5lH0WQoC4Ct0oil97uCKaHEP2w/vzLzzHpvh/EZikI+HglU7KlT/+a8llXp9VrfuHGn3POfT+IqSTZ1b+YFdfdwfl3JQ+RVmIxBQVIomdH9DMi5C922aoB3zdwsMWagn4J5XvKKzLZmuS4dU7mgA02mnGvPM/1MyZzz6Uncf2MyZaDOcNE+2pOf2zuIelugvYxJRuHseSYM3juqttiytM+d9VtLDnmjF7X66okseIKV4Ng88Tfp9FcbnCIpGhz8Jns+eGPnEuBZpoUoprWnzWFpQsmxL7tZ6GJKRPfi8pS5Fl05JRjWpuAI4icosKa2jroPvdiHqn9Wo+s2xiaXYyyd12SWOmVlfWzE/byVtbP5vyIIHIN54DZLhnaM3GDw1v8kF89TOvcX8S8/WXd+R6Py7cczyOpszkCOaDRv9GRU7udepshjSDylMIm16KlQsyYa6ezeMYttJVXYERoK69g8YxbGHNt3DPCzo5Ad7f1nUCRhD2NjvpM8hEXzklPI6mz7W8J0EEZ8dUMf+43/GtUiQAgcJ9JMjwFLeSAPy2MUXZO+NHWIItmqQM+jj6lTFzgqcpeth2GYXFQ5kgEUQRPD7EQ72sYqxFmmiDL+eoI+D6Mm4eIpyp72R6BHJbRv3V1oXmgusHTSP4Q72tfLJrlqgBbwKgyyTPcVjX0VGXPx9QlrnDpoMwl04eSPrnwYE2V3q7dsJdYBnXAB4JfjrRE63Fb1dBTlb0gQhl7cVD6nppdCS35VjTLzbWbC7VdVJlkGb8eek7rWdfc7pjHKl75AAdrcGx7i3GnneAuOiuEY1RyqTSw4o1AHqwZjF50c+2GvcQyqJnLX+KdluefDy+8EGOaWTp8gi/2Xie7cVNbB3s6DxzSHS7qJ87mLzcZZuMJmU09H00fSmIiD9Zos9ClE6oy92DN8Dgnt9duNrJZe0GViV8kuuAeeujg//YFOLz+B3RP+XrMouk89JwuwNLiw2jv2N8zHfGFDCgscFZiycZxhEhhJCMXbMqKf2T1wZrh+yNfrl01c/lFogsunr17uWDxg4xe9jvqp57JdeceS/3UMxm97HcpXzhOduMvjEzcHd73mXFO454H9SFywaas5CgZvj/y5drVcSZ+kaySYRQG2F9U7Dltfaqx9r5kmA05mY7m0mixPkoWxjmF+drSQYtxZFyZOF1wcZh+/ZDPDk1bn86FmcoFmFT5xPtMIHQjnoOmLw6UU2xyICNAJtFBi9kmwcC6+BT1Xf2LD808GyGNLnMqduOkTss8STSXSfriQDnFRu8PV6jPxC/iwmV3DqtgzZTLYpJALr/ux+wa5mPCwBTDFcdVlB0MBZ50TOxD0EWiub6Mq9LBfpDNBJqKe/T+6BXtmfhJVLjswoiP4prbev5u79iPIJz/ix96Tw3ic7himG22YSArETfxJV43b7amIT8eXiHO96V4R3smGcIpQmPMtdP9GfDnY1p2HT3eO9mIuOmac1NMrXCAgo4Ouubc5Ns2AiPy8rN5sxWoEnn50Z5X3qAO+AyS0bd9H9OyB5mRNJfIdO/NVYnXXCUsmZ/7Aj73ANUBHwIyOrDKx2SLOnrcHZkeKLdr6JEM2X5oRcddQ49kSMa2miV6G6uhJjB/CKgqKaiZK3fxMdliviXOy1VW1s+mq3+sUo+UeM15klVvVBOYfwRUlRRUmeQuPiZbzJcRuLmO6xKvu
Uiyl58AH4B5R4DZLNL2mYjIUcDjwJFANzDfGHO/iAwBFgOjgY+AfzXG7BIRAe4Hzgf2Av/bGPO2va5pwA/sVd9hjFloy08BfgWUAC8A1xpjjNM2krVXKy0mJxTRXGrqCMd5yBRO5zfAssx5RwZ8U259Jhhj0voAFcDJ9u+BwAfAccB/ADfa8huBu+3f5wMvYo3jmwi8YcuHAH+3vwfbvwfb/70JnGYv8yJwni1PuI1kn1NOOcUoIWbRImMGDDDGeqxYnwEDLLmS31RXx573yKe6OuiW5R4ZuI+AVcaFTkjbzGWMaTZ2z8IYswdYD1QCFwEL7dkWAhfbvy8CHrfbtxIYJCIVwLnAcmPMTmP1LpYDtfZ/pcaYP9s79HjcuhJtQ8lV1NTRdwmi2Fq+EmCtIV+iuURkNHAS8AYwwhjTDJbCEZHh9myVwMdRizXasmTyxgRykmwjvl0zgZkAozJVUlbxhzzIXKykiaYr8ZeAag15ViYicgTwG+A6Y8xuyzWSeNYEsvj0VW7krjHGzAfmg+UzSWVZJctku668EirWnzWFpQsmxPqKgm6UkhKeorlEpBBLkTQYY562xS22iQr7e5stbwSOilq8Ctjai7wqgTzZNpRcRU0dfRbNwJAfpK1M7OisR4D1xpi5UX89C0yzf08DnomSXy4WE4F221S1DJgsIoNFZDAwGVhm/7dHRCba27o8bl2JtqHkKiGsK69kBzc10JXw48XM9WXgW8BfRGSNLfs+cBewRETqgS3AJfZ/L2BFdG3ECg2+AsAYs1NEfgy8Zc93uzFmp/37OxwMDX7R/pBkG0ouE7K68kp20AwM+UHaysQY80cS+zUAzk4wvwGudFjXo8CjCeSrgPEJ5DsSbUNRlNwjX2qg93V0BLyiKIGiGRjyA1UmiqIESqQKaFlJIc3tnZSVFGo55BxEswYrihI4mc7I3JcIKiWPKhMfyeu8SoriEb0/Mk8kzLqspDAmzDobPT01c/mExsorijN6f2SHIMOsVZn4hMbKK4ozen9kh6a2DgYWxxqcshVmrcrEJ4I8iYoSdvT+yA5BFrpTZeITWq1QUZzR+yM7BBlmrcrEJzRWXlGc0fsjOwQZZp12pcVcIxuVFjVaRVGc0fsjN3FbaVFDg31EY+UVxRm9P/IbNXMpiqIonlFloiiKonhGzVxKaFCbuqLkLtozUUKBjpBWlNxGlYkSCnSEtKLkNqpMlFCgI6QVJbdRZaKEAh0hrSi5jSoTJRToCGlFyW1UmSihQKvtKUpuo6HBSmjQEdKKkrtoz0RRFEXxjCoTRVEUxTOqTBRFURTPqDJRFEVRPKPKRFEURfFMnymOJSKtwGYfVjUM2O7DenKNvrjffXGfoW/ud1/cZ3C339XGmPLeVtRnlIlfiMgqN1XH8o2+uN99cZ+hb+53X9xn8He/1cylKIqieEaViaIoiuIZVSapMz/oBgREX9zvvrjP0Df3uy/uM/i43+ozURRFUTyjPRNFURTFM6pMUkBEakVkg4hsFJEbg25PJhCRo0TkVRFZLyJ/FZFrbfkQEVkuIh/a34ODbmsmEJF+IvKOiDxvT48RkTfs/V4sIkVBt9FPRGSQiDwlIu/b5/y0vnCuReR6+/peKyJPiEhxPp5rEXlURLaJyNooWcLzKxYP2M+390Tk5FS2pcrEJSLSD5gHnAccB1wmIscF26qMcACYbYwZB0wErrT380bgZWPMWOBlezofuRZYHzV9N3Cvvd+7gPpAWpU57geWGmOOBU7A2ve8PtciUglcA9QYY8YD/YBvkp/n+ldAbZzM6fyeB4y1PzOBh1LZkCoT95wKbDTG/N0Y0wU8CVwUcJt8xxjTbIx52/69B+vhUom1rwvt2RYCFwfTwswhIlXAV4GH7WkBzgKesmfJq/0WkVLgDOARAGNMlzGmjT5wrrHKb5SIyGHAAKCZPDzXxpjXgJ1xYqfzexHwuLFYCQwSkQq321Jl4p5K4OOo6UZblreIyGjgJOANY
IQxphkshQMMD65lGeM+4HtAtz09FGgzxkTqCefbOT8aaAUes017D4vI4eT5uTbGNAH3AFuwlEg7sJr8PtfROJ1fT884VSbukQSyvA2FE5EjgN8A1xljdgfdnkwjIlOAbcaY1dHiBLPm0zk/DDgZeMgYcxLwD/LMpJUI20dwETAGGAkcjmXiiSefzrUbPF3vqkzc0wgcFTVdBWwNqC0ZRUQKsRRJgzHmaVvcEuny2t/bgmpfhvgycKGIfIRlwjwLq6cyyDaFQP6d80ag0Rjzhj39FJZyyfdzfQ6wyRjTaozZDzwNfIn8PtfROJ1fT884VSbueQsYa0d8FGE57J4NuE2+Y/sJHgHWG2PmRv31LDDN/j0NeCbbbcskxpibjDFVxpjRWOf2FWNMHfAq8A17trzab2PMJ8DHIvJ5W3Q2sI48P9dY5q2JIjLAvt4j+5235zoOp/P7LHC5HdU1EWiPmMPcoIMWU0BEzsd6W+0HPGqMuTPgJvmOiJwO/AH4Cwd9B9/H8pssAUZh3YyXGGPiHXt5gYh8BbjBGDNFRI7G6qkMAd4Bphpj9gXZPj8RkROxAg6KgL8DV2C9ZOb1uRaRHwGXYkUvvgNMx/IP5NW5FpEngK9gZQduAW4FfkeC82sr1gexor/2AlcYY1a53pYqE0VRFMUrauZSFEVRPKPKRFEURfGMKhNFURTFM6pMFEVRFM+oMlEURVE8o8pEURRF8YwqE0VRFMUzqkwURVEUz/x/epmlDpV4XFoAAAAASUVORK5CYII=\n", 737 | "text/plain": [ 738 | "
" 739 | ] 740 | }, 741 | "metadata": {}, 742 | "output_type": "display_data" 743 | } 744 | ], 745 | "source": [ 746 | "plt.scatter(range(0,len(y_val)),y_val,alpha=0.5,label='True')\n", 747 | "plt.scatter(range(0,len(tb_val_pred)),tb_val_pred,c='r',label='Pred')\n", 748 | "plt.legend()\n", 749 | "plt.show()" 750 | ] 751 | }, 752 | { 753 | "cell_type": "markdown", 754 | "metadata": {}, 755 | "source": [ 756 | "The scratch Random Forest achieves results comparable to sklearn's Random Forest" 757 | ] 758 | }, 759 | { 760 | "cell_type": "code", 761 | "execution_count": 136, 762 | "metadata": {}, 763 | "outputs": [ 764 | { 765 | "data": { 766 | "text/plain": [ 767 | "count 98.000000\n", 768 | "mean 14091.515729\n", 769 | "std 13859.367850\n", 770 | "min 255.530303\n", 771 | "25% 5550.851892\n", 772 | "50% 10468.054696\n", 773 | "75% 18474.863616\n", 774 | "max 75983.185499\n", 775 | "dtype: float64" 776 | ] 777 | }, 778 | "execution_count": 136, 779 | "metadata": {}, 780 | "output_type": "execute_result" 781 | } 782 | ], 783 | "source": [ 784 | "# Take a look at prediction differences b/t 2 models\n", 785 | "pd.Series(abs(rf_val_pred-tb_val_pred)).describe()" 786 | ] 787 | }, 788 | { 789 | "cell_type": "markdown", 790 | "metadata": {}, 791 | "source": [ 792 | "# Calculate feature importance" 793 | ] 794 | }, 795 | { 796 | "cell_type": "markdown", 797 | "metadata": {}, 798 | "source": [ 799 | "Feature importance will be calculated using **permutation importance**. As argued in this [doc](http://parrt.cs.usfca.edu/doc/rf-importance/index.html#3), the feature importance computed by sklearn can be misleading. Sklearn uses **mean decrease in impurity (Gini importance)**, which does not always give an accurate picture of importance. 
\n", 800 | "\n", 801 | "On the other hand, **permutation importance** is calculated as follows: Record a baseline accuracy (classifier) or R2 score (regressor) by passing a validation set or the out-of-bag (OOB) samples through the Random Forest. Permute the column values of a single predictor feature, then pass the same samples back through the Random Forest and recompute the accuracy or R2. The importance of that feature is the drop in accuracy or R2 caused by permuting the column, i.e. the baseline score minus the permuted score. This is more expensive to compute, but the results are more reliable." 802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": 138, 807 | "metadata": {}, 808 | "outputs": [], 809 | "source": [ 810 | "from model.utils import permutation_importances" 811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": 139, 816 | "metadata": {}, 817 | "outputs": [], 818 | "source": [ 819 | "??permutation_importances" 820 | ] 821 | }, 822 | { 823 | "cell_type": "markdown", 824 | "metadata": {}, 825 | "source": [ 826 | "## Feature importance from Sklearn's RF model\n" 827 | ] 828 | }, 829 | { 830 | "cell_type": "code", 831 | "execution_count": 143, 832 | "metadata": {}, 833 | "outputs": [], 834 | "source": [ 835 | "def metric(rf,X,y):\n", 836 | " y_pred = rf.predict(X)\n", 837 | " return metrics.r2_score(y,y_pred)" 838 | ] 839 | }, 840 | { 841 | "cell_type": "code", 842 | "execution_count": 144, 843 | "metadata": {}, 844 | "outputs": [ 845 | { 846 | "data": { 847 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABK0AAAJCCAYAAAAC+zS/AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAIABJREFUeJzt3X+w3XV95/HXOwkYEIwLahdJl4DLD1sDyCZMp6wKawm2iLYrKtTdFndpSxGcbatb2nVn9kdny27ponUcf3RtaR2sFLpaf2xrtRXHX62Eyo8iCIIpDfaHxjaggBL47B+5oZcYkpOQe887nsdjJpPzPed7zn3fO585SZ75fr+nxhgBAAAAgE6WTHsAAAAAANieaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQzrJpD9DZ0572tLFq1appjwEAAADwHeP666//6hjj6bvaT7TaiVWrVmX9+vXTHgMAAADgO0ZV/eUk+zk9EAAAAIB2RCsAAAAA2hGtAAAAAGjHNa0AAACAmfDQQw9l48aNefDBB6c9ykxYvnx5Vq5cmf3222+Pni9aAQAAADNh48aNOfjgg7Nq1apU1bTH+Y42xsimTZuycePGHHnkkXv0Gk4PBAAAAGbCgw8+mEMPPVSwWgRVlUMPPfQJHdUmWgEAAAAzQ7BaPE/0Zy1aAQAAANCOa1oBAAAAM2nVJR/aq6+34dIzd7nPQQcdlK9//et79evuzIYNG/LpT386P/qjP7poX3NvcaQVAAAAwHegLVu2ZMOGDXn3u9897VH2iGgFAAAAsMiuvfbavOAFL8grXvGKHHPMMbnkkkty5ZVX5uSTT87q1atz5513JknOO++8XHDBBXne856XY445Jh/84AeTbL2o/Ktf/eqsXr06z33uc/Oxj30sSXLFFVfk5S9/ec4666ysW7cul1xyST7xiU/kxBNPzOWXX54NGzbkec97Xk466aScdNJJ+fSnP/3oPKeeemrOPvvsHHfccXnVq16VMUaS5Lrrrsv3f//354QTTsjJJ5+c++67Lw8//HBe//rXZ+3atTn++OPz9re/fa//jJweCAAAADAFN954Y2699dYccsghOeqoo3L++efns5/9bN70pjflzW9+c974xjcm2XqK38c//vHceeedOe200/LFL34xb3nLW5IkN998c2677basW7cut99+e5LkM5/5TG666aYccsghufbaa3PZZZc9Grvuv//+fOQjH8ny5ctzxx135Nxzz8369euTJJ/73Odyyy235JnPfGZOOeWUfOpTn8rJJ5+cV77ylbnqqquydu3a3HvvvTnggAPyzne+MytWrMh1112Xb37zmznllFOybt26HHnkkXvt5yNaAQAAAEzB2rVrc9hhhyVJnvWsZ2XdunVJktWrVz965FSSvOIVr8iSJUty9NFH56ijjsptt92WT37yk7n44ouTJMcdd1yOOOKIR6PV6aefnkMOOWSHX/Ohhx7KRRddlBtuuCFLly599DlJcvLJJ2flypVJkhNPPDEbNmzIihUrcthhh2Xt2rVJkqc85SlJkj/6oz/KTTfdlGuuuSZJsnnz5txxxx2iFQAAAMC+7klPetKjt5csWfLo9pIlS7Jly5ZHH6uqxzyvqh49dW9HnvzkJz/uY5dffnm+67u+KzfeeGMeeeSRLF++fIfzLF26NFu2bMkY49u+fpKMMfLmN785Z5xxxk6+wyfGNa0AAAAAGrv66qvzyCOP5M4778xdd92VY489Ns9//vNz5ZVXJkluv/323H333Tn22GO/7bkHH3xw7rvvvke3N2/enMMOOyxLlizJu971rjz88MM7/drHHXdcvvzlL+e6665Lktx3333ZsmVLzjjjjLz1rW/NQw899OgM3/jGN/bWt5zEkVYAAADAjNpw6ZnTHmEixx57bF7wghfkb//2b/O2t70
ty5cvz4UXXpgLLrggq1evzrJly3LFFVc85kipbY4//vgsW7YsJ5xwQs4777xceOGFednLXparr746p5122k6PykqS/fffP1dddVUuvvjiPPDAAznggAPy0Y9+NOeff342bNiQk046KWOMPP3pT8/73ve+vfp9184OJ5t1a9asGdsuRgYAAADs22699dY8+9nPnvYYu+W8887Li1/84px99tnTHmWP7OhnXlXXjzHW7Oq5Tg8EAAAAoB2nBwIAAAA0dcUVV0x7hKlxpBUAAAAwM1wmafE80Z+1aAUAAADMhOXLl2fTpk3C1SIYY2TTpk1Zvnz5Hr+G0wMBAACAmbBy5cps3LgxX/nKV6Y9ykxYvnx5Vq5cucfPF60AAACAmbDffvvlyCOPnPYYTMjpgQAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtLNs2gN0dvM9m7Pqkg9NewwAAABgBm249MxpjzBVjrQCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADaEa0AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADaEa0AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADaEa0AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGinTbSqqq/v4L5jq+raqrqhqm6tqndU1Rlz2zdU1der6gtzt3973vPeVFX3VNWSue1Xz3vOt6rq5rnbly7m9wgAAADAZJZNe4Bd+LUkl48xfj9Jqmr1GOPmJB+e2742yevGGOu3PWEuVP1Ikr9K8vwk144xfjPJb849viHJaWOMry7i9wEAAADAbmhzpNXjOCzJxm0bc8FqV05L8hdJ3prk3AWaCwAAAIAF1D1aXZ7kT6rqD6rqZ6rqqRM859wkv5PkvUleXFX7LeiEAAAAAOx1raPV3Gl9z05ydZJTk/xpVT3p8favqv2T/FCS940x7k3yZ0nW7c7XrKqfrKr1VbX+4fs37/HsAAAAAOy51tEqScYYXx5j/MYY46VJtiR5zk52f1GSFUlunrt21b/Mbp4iOMZ4xxhjzRhjzdIDV+zp2AAAAAA8Aa2jVVW9aNvpfVX1T5McmuSenTzl3CTnjzFWjTFWJTkyybqqOnDBhwUAAABgr+n06YEHVtXGedv/O8nKJG+qqgfn7nv9GONvdvTkuTB1RpKf2nbfGOMbVfXJJGcluWphxgYAAABgb2sTrcYYj3fU18/u5Dmnzrt9f5JDdrDPv95ue9WeTQgAAADAYml9eiAAAAAAs0m0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADaEa0AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADaEa0AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADaEa0
AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANpZNu0BOlt9+Iqsv/TMaY8BAAAAMHMcaQUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7Syb9gCd3XzP5qy65EPTHgMAgD204dIzpz0CALCHHGkFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANDORNGqqo6pqj+uqr+Y2z6+qt6wsKMBAAAAMKsmPdLq15P8QpKHkmSMcVOScxZqKAAAAABm26TR6sAxxme3u2/L3h4GAAAAAJLJo9VXq+pZSUaSVNXZSf56waYCAAAAYKYtm3C/1yR5R5LjquqeJF9K8qoFmwoAAACAmbbLaFVVS5KsGWP8QFU9OcmSMcZ9Cz8aAAAAALNql6cHjjEeSXLR3O1vCFYAAAAALLRJr2n1kap6XVV9d1Udsu3Xgk4GAAAAwMya9JpW/27u99fMu28kOWrvjgMAAAAAE0arMcaRCz0IAAAAAGwzUbSqqh/b0f1jjN/eu+MAAAAAwOSnB66dd3t5khcm+fMkohUAAAAAe92kpwdePH+7qlYkedeCTLQHqurhJDdn6/fzpST/dozxD1W1am77l8YY/3lu36cl+eskbx9jXDSdiQEAAADYmUk/PXB79yc5em8O8gQ9MMY4cYzxnCRfy2MvGH9XkhfP2355klsWczgAAAAAds+k17T6QLZ+WmCyNXR9T5KrF2qoJ+gzSY6ft/1Akluras0YY32SVyb53STPnMZwAAAAAOzapNe0umze7S1J/nKMsXEB5nlCqmpptl5v653bPfSeJOdU1d8keTjJlyNaAQAAALQ16emBPzTG+Pjcr0+NMTZW1f9c0Ml2zwFVdUOSTUkOSfKR7R7/wySnJzk3yVU7e6Gq+smqWl9V6x++f/OCDAsAAADAzk0arU7fwX0/uDcHeYIeGGOcmOS
IJPvnsde0yhjjW0muT/JzSX5vZy80xnjHGGPNGGPN0gNXLNS8AAAAAOzETk8PrKqfTnJhkqOq6qZ5Dx2c5FMLOdieGGNsrqrXJvn9qnrrdg//apKPjzE2VdUUpgMAAABgUru6ptW7k/xBkl9Ocsm8++8bY3xtwaZ6AsYYn6uqG5Ock+QT8+6/JT41EAAAAGCfsNNoNcbYnGRztl4LKlX1jCTLkxxUVQeNMe5e+BF3bYxx0HbbZ83bfM4O9r8iyRULOxUAAAAAe2qia1pV1VlVdUeSLyX5eJIN2XoEFgAAAADsdZNeiP2XknxfktvHGEcmeWEaXtMKAAAAgO8Mk0arh8YYm5IsqaolY4yPJTlxAecCAAAAYIbt6kLs2/xDVR2UrRc2v7Kq/i7JloUbCwAAAIBZNumRVi9Ncn+S/5DkD5PcmeSsnT4DAAAAAPbQREdajTG+UVVHJDl6jPFbVXVgkqULOxoAAAAAs2rSTw/8iSTXJHn73F2HJ3nfQg0FAAAAwGyb9PTA1yQ5Jcm9STLGuCPJMxZqKAAAAABm26TR6ptjjG9t26iqZUnGwowEAAAAwKybNFp9vKp+MckBVXV6kquTfGDhxgIAAABglk0arS5J8pUkNyf5qST/L8kbFmooAAAAAGbbTj89sKr+2Rjj7jHGI0l+fe4XAAAAACyoXR1p9egnBFbV7y3wLAAAAACQZNfRqubdPmohBwEAAACAbXYVrcbj3AYAAACABbPTa1olOaGq7s3WI64OmLudue0xxnjKgk4HAAAAwEzaabQaYyxdrEEAAAAAYJtdnR4IAAAAAItOtAIAAACgHdEKAAAAgHZEKwAAAADaEa0AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADaEa0AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2lk17gM5WH74i6y89c9pjAAAAAMwcR1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDvLpj1AZzffszmrLvnQtMeA7ygbLj1z2iMAAACwD3CkFQAAAADtiFYAAAAAtCNaAQAAANCOaAU
AAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO0sWLSqqoer6oaq+ouqurqqDp/bvqGq/qaq7pm3vf92+3+gqp663ev9TFU9WFUr5rbPmPf8r1fVF+Zu/3ZVnVpVH5z33B+uqpuq6raqurmqfnihvm8AAAAAnriFPNLqgTHGiWOM5yT5VpJXzm2fmORtSS7ftj3G+NZ2+38tyWu2e71zk1yX5EeSZIzx4Xmvtz7Jq+a2f2z+k6rqhCSXJXnpGOO4JC9JcllVHb9w3zoAAAAAT8RinR74iST/fDf2/0ySw7dtVNWzkhyU5A3ZGq92x+uS/I8xxpeSZO73X07y+t18HQAAAAAWyYJHq6paluQHk9w84f5Lk7wwyfvn3X1ukt/J1vh1bFU9YzdG+N4k12933/q5+3f09X+yqtZX1fqH79+8G18GAAAAgL1lIaPVAVV1Q7YGoruTvHPC/TclOSTJR+Y9dk6S94wxHknyf5O8fDfmqCRjgvuSJGOMd4wx1owx1iw9cMVufBkAAAAA9pZlC/jaD8xdb2q39p+70PoHs/WaVr82d+2po5N8pKqSZP8kdyV5y4Sve0uSNUlumnffSUk+vxuzAQAAALCIFuuaVhMbY2xO8tokr6uq/bL11MD/MsZYNffrmUkOr6ojJnzJy5L8QlWtSpK5338xya/u5dEBAAAA2EvaRaskGWN8LsmN2Xpa4DlJ3rvdLu+du3+S17ohyc8n+UBV3ZbkA0n+49z9AAAAADRUY+zw0k4kedJhR4/DfvyN0x4DvqNsuPTMaY8AAADAFFXV9WOMNbvar+WRVgAAAADMNtEKAAAAgHZEKwAAAADaEa0AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADaEa0AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADaEa0AAAAAaEe0AgAAAKAd0QoAAACAdkQrAAAAANoRrQAAAABoR7QCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADaEa0AAAAAaGfZtAfobPXhK7L+0jOnPQYAAADAzHGkFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7ohUAAAAA7YhWAAAAALQjWgEAAADQjmgFAAAAQDuiFQAAAADtiFYAAAAAtCNaAQAAANCOaAUAAABAO6IVAAAAAO2IVgAAAAC0I1oBAAAA0I5oBQAAAEA7NcaY9gxtVdV9Sb4w7TlgJ56
W5KvTHgJ2wTplX2Cd0p01yr7AOmVfYJ32cMQY4+m72mnZYkyyD/vCGGPNtIeAx1NV661RurNO2RdYp3RnjbIvsE7ZF1in+xanBwIAAADQjmgFAAAAQDui1c69Y9oDwC5Yo+wLrFP2BdYp3Vmj7AusU/YF1uk+xIXYAQAAAGjHkVYAAAAAtDPz0aqqXlRVX6iqL1bVJTt4/ElVddXc439WVasWf0pm3QTr9PlV9edVtaWqzp7GjDDBOv3Zqvp8Vd1UVX9cVUdMY05m1wRr9IKqurmqbqiqT1bV90xjTmbbrtbpvP3OrqpRVT4Bi0U3wfvpeVX1lbn30xuq6vxpzMnsmuS9tKpeMfd301uq6t2LPSOTmenTA6tqaZLbk5yeZGOS65KcO8b4/Lx9Lkxy/Bjjgqo6J8mPjDFeOZWBmUkTrtNVSZ6S5HVJ3j/GuGbxJ2WWTbhOT0vyZ2OM+6vqp5Oc6v2UxTLhGn3KGOPeudsvSXLhGONF05iX2TTJOp3b7+AkH0qyf5KLxhjrF3tWZteE76fnJVkzxrhoKkMy0yZco0cn+d0k/2qM8fdV9Ywxxt9NZWB2ataPtDo5yRfHGHeNMb6V5D1JXrrdPi9N8ltzt69J8sKqqkWcEXa5TscYG8YYNyV5ZBoDQiZbpx8bY9w/t/mnSVYu8ozMtknW6L3zNp+cZHb/Z49pmeTvpkny35P8ryQPLuZwMGfSdQrTMska/Ykkbxlj/H2SCFZ9zXq0OjzJX83b3jh33w73GWNsSbI5yaGLMh1sNck6hWnb3XX675P8wYJOBI810RqtqtdU1Z3ZGgReu0izwTa7XKdV9dwk3z3G+OBiDgbzTPpn/svmLglwTVV99+KMBkkmW6PHJDmmqj5VVX9aVY6sbmrWo9WOjpja/n9VJ9kHFpI1yL5g4nVaVf8myZokv7KgE8FjTbRGxxhvGWM8K8nPJ3nDgk8Fj7XTdVpVS5JcnuTnFm0i+HaTvJ9+IMmqMcbxST6afzxzBRbDJGt0WZKjk5ya5Nwk/6eqnrrAc7EHZj1abUwyv/qvTPLlx9unqpYlWZHka4syHWw1yTqFaZtonVbVDyT5T0leMsb45iLNBsnuv5e+J8kPL+hE8O1MI/39AAABoklEQVR2tU4PTvKcJNdW1YYk35fk/S7GziLb5fvpGGPTvD/nfz3Jv1ik2SCZ/N/5vz/GeGiM8aUkX8jWiEUzsx6trktydFUdWVX7Jzknyfu32+f9SX587vbZSf5kzPLV65mGSdYpTNsu1+ncKS1vz9Zg5boBLLZJ1uj8v6yemeSORZwPkl2s0zHG5jHG08YYq8YYq7L1+oAvcSF2Ftkk76eHzdt8SZJbF3E+mOTfT+9LclqSVNXTsvV0wbsWdUomsmzaA0zTGGNLVV2U5MNJlib5jTHGLVX135KsH2O8P8k7k7yrqr6YrUdYnTO9iZlFk6zTqlqb5L1J/kmSs6rqv44xvneKYzNjJnw//ZUkByW5eu7zLO4eY7xkakMzUyZcoxfNHQ34UJK/zz/+pxUsignXKUzVhOv0tXOfwrolW/8Ndd7UBmbmTLhGP5xkXVV9PsnDSV4/xtg0val5POWgIQAAAAC6mfXTAwEAAABoSLQCAAAAoB3RCgAAAIB2RCsAAAAA2hGtAAAAAGhHtAIAAACgHdEKAAAAgHZEKwAAAADa+f+VFFVJ+9cdTwAAAABJRU5ErkJggg==\n", 848 | "text/plain": [ 849 | "
" 850 | ] 851 | }, 852 | "metadata": {}, 853 | "output_type": "display_data" 854 | } 855 | ], 856 | "source": [ 857 | "fea_imp_df = permutation_importances(rf,X_val,y_val,metric,False)" 858 | ] 859 | }, 860 | { 861 | "cell_type": "markdown", 862 | "metadata": {}, 863 | "source": [ 864 | "## Feature importance from 'scratch' RF" 865 | ] 866 | }, 867 | { 868 | "cell_type": "code", 869 | "execution_count": 145, 870 | "metadata": {}, 871 | "outputs": [ 872 | { 873 | "data": { 874 | "image/png": "iVBORw0KGgoAAAANSUhEUgAABK4AAAJCCAYAAADpzI+8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAIABJREFUeJzt3X+w5XV93/HXe1lwQXBTQFNkUxYsP0xdQLowmVAVagATRJOKCLFNMCUJQXCaRBuS2plO22loQ4rGcRRTGxIHI4E0xh9NDCbC+CuRJfIjCILgxizmh67JggLKwqd/7IVeN8vuWfaee9659/GY2dnzPfd7zn3fO585d+9zv9/vqTFGAAAAAKCbFbMeAAAAAAB2RLgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaGnlrAfo7OCDDx5r166d9RgAAAAAS8bNN9/81THGsyfZV7jaibVr12bDhg2zHgMAAABgyaiqP590X6cKAgAAANCScAUAAABAS8IVAAAAAC25xhUAAACwLDz66KPZtGlTHnnkkVmPsiysWrUqa9asyd577/20n0O4AgAAAJaFTZs25YADDsjatWtTVbMeZ0kbY2Tz5s3ZtGlTDj/88Kf9PE4VBAAAAJaFRx55JAcddJBotQiqKgcddNAeH90mXAEAAADLhmi1eBbiey1cAQAAANCSa1wBAAAAy9LaSz+8oM+38bIzd7nP/vvvn69//esL+nl3ZuPGjfnUpz6VH/7hH160z7mQHHEFAAAAsARt3bo1GzduzHvf+95Zj/K0CVcAAAAAi+yGG27IS17ykpxzzjk56qijcumll+bqq6/OSSedlHXr1uXee+9Nkpx//vm58MIL86IXvShHHXVUPvShDyXZdqH5173udVm3bl1e+MIX5mMf+1iS5KqrrsqrX/3qnHXWWTn99NNz6aWX5uMf/3iOP/74XHHFFdm4cWNe9KIX5YQTTsgJJ5yQT33qU0/Oc8opp+Tss8/OMccck9e+9rUZYyRJbrrppnzv935vjjvuuJx00kl58MEH89hjj+VNb3pTTjzxxBx77LG58sorp/J9cqogAAAAwAzceuutufPOO3PggQfmiCOOyAUXXJDPfOYzeetb35q3ve1tectb3pJk2+l+N954Y+69996ceuqp+cIXvpC3v/3tSZLbb789d911V04//fTcfffdSZJPf/rTue2223LggQfmhhtuyOWXX/5k8HrooYdy/fXXZ9WqVbnnnnty3nnnZcOGDUmSz372s7njjjvy3Oc+NyeffHI++clP5qSTTsprXvOaXHPNNTnxxBPzwAMPZN9998273/3urF69OjfddFO++c1v5uSTT87pp5+eww8/fEG/R8IVAAAAwAyceOKJOeSQQ5Ikz3ve83L66acnSdatW/fkEVRJcs4552TFihU58sgjc8QRR+Suu+7KJz7
xiVxyySVJkmOOOSaHHXbYk+HqtNNOy4EHHrjDz/noo4/m4osvzi233JK99trrycckyUknnZQ1a9YkSY4//vhs3Lgxq1evziGHHJITTzwxSfKsZz0rSfIHf/AHue2223LdddclSbZs2ZJ77rlHuAIAAABYCp7xjGc8eXvFihVPbq9YsSJbt2598mNV9W2Pq6onT+PbkWc+85lP+bErrrgi3/md35lbb701jz/+eFatWrXDefbaa69s3bo1Y4y/9/mTZIyRt73tbTnjjDN28hXuOde4AgAAAGjs2muvzeOPP55777039913X44++ui8+MUvztVXX50kufvuu/OlL30pRx999N977AEHHJAHH3zwye0tW7bkkEMOyYoVK/Ke97wnjz322E4/9zHHHJMvf/nLuemmm5IkDz74YLZu3Zozzjgj73jHO/Loo48+OcM3vvGNhfqSn+SIKwAAAGBZ2njZmbMeYSJHH310XvKSl+Sv//qv8853vjOrVq3KRRddlAsvvDDr1q3LypUrc9VVV33bEVNPOPbYY7Ny5cocd9xxOf/883PRRRflVa96Va699tqceuqpOz06K0n22WefXHPNNbnkkkvy8MMPZ999981HP/rRXHDBBdm4cWNOOOGEjDHy7Gc/O+9///sX/GuvnR1attytX79+PHGBMgAAAOAftjvvvDPPf/7zZz3Gbjn//PPz8pe/PGefffasR3ladvQ9r6qbxxjrJ3m8UwUBAAAAaMmpggAAAABNXXXVVbMeYaYccQUAAAAsGy6ZtHgW4nstXAEAAADLwqpVq7J582bxahGMMbJ58+asWrVqj57HqYIAAADAsrBmzZps2rQpX/nKV2Y9yrKwatWqrFmzZo+eQ7gCAAAAloW99947hx9++KzHYDc4VRAAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJZWznqAzm6/f0vWXvrhWY8BAAAALEMbLztz1iPMnCOuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaahOuqurrO7jv6Kq6oapuqao7q+pdVXXG3PYtVfX1qvr83O3fmPe4t1bV/VW1Ym77dfMe862qun3u9mWL+TUCAAAAMLmVsx5gF34lyRVjjN9NkqpaN8a4PclH5rZvSPLGMcaGJx4wF6t+KMlfJHlxkhvGGL+W5NfmPr4xyaljjK8u4tcBAAAAwG5qc8TVUzgkyaYnNuai1a6cmuTPkrwjyXlTmgsAAACAKeserq5I8kdV9XtV9dNV9R0TPOa8JL+Z5HeSvLyq9p7qhAAAAABMRetwNXeK3/OTXJvklCR/XFXPeKr9q2qfJD+Q5P1jjAeS/EmS03fnc1bVT1TVhqra8NhDW5727AAAAADsmdbhKknGGF8eY/zvMcYrk2xN8oKd7P6yJKuT3D53Lat/kd08XXCM8a4xxvoxxvq99lv9dMcGAAAAYA+1DldV9bInTvWrqn+c5KAk9+/kIecluWCMsXa
MsTbJ4UlOr6r9pj4sAAAAAAuq07sK7ldVm+Zt/88ka5K8taoembvvTWOMv9rRg+fi1BlJfvKJ+8YY36iqTyQ5K8k10xkbAAAAgGloE67GGE919NfP7OQxp8y7/VCSA3ewz7/abnvt05sQAAAAgMXU+lRBAAAAAJYv4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoaeWsB+hs3aGrs+GyM2c9BgAAAMCy5IgrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhp5awH6Oz2+7dk7aUfnvUYAADsxMbLzpz1CADAlDjiCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgpYnCVVUdVVV/WFV/Nrd9bFW9ebqjAQAAALCcTXrE1a8m+fkkjybJGOO2JOdOaygAAAAAmDRc7TfG+Mx2921d6GEAAAAA4AmThquvVtXzkowkqaqzk/zl1KYCAAAAYNlbOeF+r0/yriTHVNX9Sb6Y5LVTmwoAAACAZW+X4aqqViRZP8b4vqp6ZpIVY4wHpz8aAAAAAMvZLk8VHGM8nuTiudvfEK0AAAAAWAyTXuPq+qp6Y1V9V1Ud+MSfqU4GAAA
AwLI26TWufmzu79fPu28kOWJhxwEAAACAbSYKV2OMw6c9CAAAAADMN1G4qqof2dH9Y4zfWNhxAAAAAGCbSU8VPHHe7VVJXprkT5MIVwAAAABMxaSnCl4yf7uqVid5z1Qmehqq6rEkt2fb1/PFJP9mjPF3VbV2bvu/jjH+49y+Byf5yyRXjjEuns3EAAAAAOzKpO8quL2Hkhy5kIPsoYfHGMePMV6Q5Gv59ovI35fk5fO2X53kjsUcDgAAAIDdN+k1rj6Ybe8imGyLXd+d5NppDbWHPp3k2HnbDye5s6rWjzE2JHlNkt9K8txZDAcAAADAZCa9xtXl825vTfLnY4xNU5hnj1TVXtl2/a13b/eh9yU5t6r+KsljSb4c4QoAAACgtUlPFfyBMcaNc38+OcbYVFX/faqT7Z59q+qWJJuTHJjk+u0+/vtJTktyXpJrdvZEVfUTVbWhqjY89tCWqQwLAAAAwK5NGq5O28F937+Qg+yhh8cYxyc5LMk++fZrXGWM8a0kNyf52SS/vbMnGmO8a4yxfoyxfq/9Vk9rXgAAAAB2YaenClbVTyW5KMkRVXXbvA8dkOST0xzs6RhjbKmqNyT53ap6x3Yf/uUkN44xNlfVDKYDAAAAYHfs6hpX703ye0l+Mcml8+5/cIzxtalNtQfGGJ+tqluTnJvk4/PuvyPeTRAAAADgH4ydhqsxxpYkW7Lt2lCpquckWZVk/6raf4zxpemPuGtjjP232z5r3uYLdrD/VUmumu5UAAAAAOyJia5xVVVnVdU9Sb6Y5MYkG7PtSCwAAAAAmIpJL87+X5N8T5K7xxiHJ3lpGl7jCgAAAIClY9Jw9egYY3OSFVW1YozxsSTHT3EuAAAAAJa5XV2c/Ql/V1X7Z9vFzq+uqr9JsnV6YwEAAACw3E16xNUrkzyU5N8l+f0k9yY5a6ePAAAAAIA9MNERV2OMb1TVYUmOHGP8elXtl2Sv6Y4GAAAAwHI26bsK/niS65JcOXfXoUneP62hAAAAAGDSUwVfn+TkJA8kyRjjniTPmdZQAAAAADBpuPrmGONbT2xU1cokYzojAQAAAMDk4erGqvqFJPtW1WlJrk3ywemNBQAAAMByN2m4ujTJV5LcnuQnk/zfJG+e1lAAAAAAsNN3FayqfzLG+NIY4/Ekvzr3BwAAAACmbldHXD35zoFV9dtTngUAAAAAnrSrcFXzbh8xzUEAAAAAYL5dhavxFLcBAAAAYKp2eo2rJMdV1QPZduTVvnO3M7c9xhjPmup0AAAAACxbOw1XY4y9FmsQAAAAAJhvV6cKAgAAAMBMCFcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQ0spZD9DZukNXZ8NlZ856DAAAAIBlyRFXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQA
AAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANDSylkP0Nnt92/J2ks/POsxYLdsvOzMWY8AAAAAC8IRVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQkXAEAAADQknAFAAAAQEvCFQAAAAAtCVcAAAAAtCRcAQAAANCScAUAAABAS8IVAAAAAC0JVwAAAAC0JFwBAAAA0JJwBQAAAEBLwhUAAAAALQlXAAAAALQ0tXBVVY9V1S1V9WdVdW1VHTq3fUtV/VVV3T9ve5/t9v9gVX3Hds/301X1SFWtnts+Y97jv15Vn5+7/RtVdUpVfWjeY3+wqm6rqruq6vaq+sFpfd0AAAAALIxpHnH18Bjj+DHGC5J8K8lr5raPT/LOJFc8sT3G+NZ2+38tyeu3e77zktyU5IeSZIzxkXnPtyHJa+e2f2T+g6rquCSXJ3nlGOOYJK9IcnlVHTu9Lx0AAACAPbVYpwp+PMk/3Y39P53k0Cc2qup5SfZP8uZsC1i7441J/tsY44tJMvf3LyZ5024+DwAAAACLaOrhqqpWJvn+JLdPuP9eSV6a5APz7j4vyW9mWwA7uqqesxsj/LMkN29334a5+3f0+X+iqjZU1YbHHtqyG58GAAAAgIU0zXC1b1Xdkm2R6EtJ3j3h/puTHJjk+nkfOzfJ+8YYjyf5P0levRtzVJIxwX1JkjHGu8YY68cY6/fab/VufBoAAAAAFtLKKT73w3PXn9qt/ecuvv6hbLvG1a/MXYvqyCTXV1WS7JPkviRvn/B570iyPslt8+47IcnndmM2AAAAABbZYl3jamJjjC1J3pDkjVW1d7adJvifxhhr5/48N8mhVXXYhE95eZKfr6q1STL39y8k+eUFHh0AAACABdQuXCXJGOOzSW7NtlMEz03yO9vt8jtz90/yXLck+bkkH6yqu5J8MMm/n7sfAAAAgKZqjB1e6okkzzjkyHHIj75l1mPAbtl42ZmzHgEAAACeUlXdPMZYP8m+LY+4AgAAAADhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAA
AWhKuAAAAAGhp5awH6Gzdoauz4bIzZz0GAAAAwLLkiCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAloQrAAAAAFoSrgAAAABoSbgCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABaEq4AAAAAaEm4AgAAAKAl4QoAAACAlmqMMesZ2qqqB5N8ftZzsCwdnOSrsx6CZcnaY1asPWbF2mNWrD1mxdpjVuavvcPGGM+e5EErpzfPkvD5Mcb6WQ/B8lNVG6w9ZsHaY1asPWbF2mNWrD1mxdpjVp7u2nOqIAAAAAAtCVcAAAAAtCRc7dy7Zj0Ay5a1x6xYe8yKtcesWHvMirXHrFh7zMrTWnsuzg4AAABAS464AgAAAKClZR+uquplVfX5qvpCVV26g48/o6qumfv4n1TV2sWfkqVogrX34qr606raWlVnz2JGlqYJ1t7PVNXnquq2qvrDqjpsFnOy9Eyw9i6sqtur6paq+kRVffcs5mTp2dXam7ff2VU1qsq7bbEgJnjdO7+qvjL3undLVV0wizlZeiZ53auqc+b+zXdHVb13sWdkaZrgde+Kea95d1fV3+3yOZfzqYJVtVeSu5OclmRTkpuSnDfG+Ny8fS5KcuwY48KqOjfJD40xXjOTgVkyJlx7a5M8K8kbk3xgjHHd4k/KUjPh2js1yZ+MMR6qqp9KcorXPfbUhGvvWWOMB+ZuvyLJRWOMl81iXpaOSdbe3H4HJPlwkn2SXDzG2LDYs7K0TPi6d36S9WOMi2cyJEvShGvvyCS/leRfjjH+tqqeM8b4m5kMzJIx6c/ceftfkuSFY4wf29nzLvcjrk5K8oUxxn1jjG8leV+SV263zyuT/Prc7euSvLSqahFnZGna5dobY2wcY9yW5PFZDMiSNcna+9gY46G5zT9OsmaRZ2RpmmTtPTBv85lJlu//rrGQJvn3XpL8lyT/I8kjizkcS9qkaw8W2iRr78eTvH2M8bdJIlqxQHb3de+8JL+5qydd7uHq0CR/MW9709x9O9xnjLE1yZYkBy3KdCxlk6w9mIbdXXv/NsnvTXUilouJ1l5Vvb6q7s22gPCGRZqNpW2Xa6+qXpjku8YYH1rMwVjyJv2Z+6q50/Ovq6rvWpzRWOImWXtHJTmqqj5ZVX9cVY5wZiFM/LvG3OVIDk/yR7t60uUernZ05NT2/7s7yT6wu6wrZmXitVdV/zrJ+iS/NNWJWC4mWntjjLePMZ6X5OeSvHnqU7Ec7HTtVdWKJFck+dlFm4jlYpLXvQ8mWTvGODbJR/P/z/SAPTHJ2luZ5Mgkp2TbUS//q6q+Y8rGLJUyAAACC0lEQVRzsfTtzu+55ya5bozx2K6edLmHq01J5v+vxpokX36qfapqZZLVSb62KNOxlE2y9mAaJlp7VfV9Sf5DkleMMb65SLOxtO3u6977kvzgVCdiudjV2jsgyQuS3FBVG5N8T5IPuEA7C2CXr3tjjM3zfs7+apJ/vkizsbRN+nvu744xHh1jfDHJ57MtZMGe2J1/752bCU4TTISrm5IcWVWHV9U+2faN+8B2+3wgyY/O3T47yR+N5XxFexbKJGsPpmGXa2/ulJkrsy1aud4BC2WStTf/H8xnJrlnEedj6drp2htjbBljHDzGWDvGWJtt1/Z7hYuzswAmed07ZN7mK5LcuYjzsXRN8rvG+5OcmiRVdXC2nTp436J
OyVI00e+5VXV0kn+U5NOTPOmyDldz16y6OMlHsu2HxG+NMe6oqv88925GSfLuJAdV1ReS/EySp3wLZZjUJGuvqk6sqk1JXp3kyqq6Y3YTs1RM+Lr3S0n2T3Lt3NvUiqrssQnX3sVzb8l9S7b9zP3Rp3g6mNiEaw8W3IRr7w1zr3u3Ztt1/c6fzbQsJROuvY8k2VxVn0vysSRvGmNsns3ELBW78TP3vCTvm/SgoHLwEAAAAAAdLesjrgAAAADoS7gCAAAAoCXhCgAAAICWhCsAAAAAWhKuAAAAAGhJuAIAAACgJeEKAAAAgJaEKwAAAABa+n8DEXopmkujlgAAAABJRU5ErkJggg==\n", 875 | "text/plain": [ 876 | "
" 877 | ] 878 | }, 879 | "metadata": {}, 880 | "output_type": "display_data" 881 | } 882 | ], 883 | "source": [ 884 | "fea_imp_df = permutation_importances(tb,X_val,y_val,metric,False)" 885 | ] 886 | } 887 | ], 888 | "metadata": { 889 | "kernelspec": { 890 | "display_name": "Python 3", 891 | "language": "python", 892 | "name": "python3" 893 | }, 894 | "language_info": { 895 | "codemirror_mode": { 896 | "name": "ipython", 897 | "version": 3 898 | }, 899 | "file_extension": ".py", 900 | "mimetype": "text/x-python", 901 | "name": "python", 902 | "nbconvert_exporter": "python", 903 | "pygments_lexer": "ipython3", 904 | "version": "3.6.6" 905 | } 906 | }, 907 | "nbformat": 4, 908 | "nbformat_minor": 2 909 | } 910 | --------------------------------------------------------------------------------