├── Breast_cancer_data.csv ├── K-means Clustering.ipynb ├── KNN Classification + Regression.ipynb ├── README.md ├── airfoil_noise_data.csv ├── decision tree classification.ipynb ├── decision tree regression.ipynb ├── linear regression.ipynb ├── logistic regression.ipynb ├── naive bayes.ipynb └── sgd.ipynb /Breast_cancer_data.csv: -------------------------------------------------------------------------------- 1 | mean_radius,mean_texture,mean_perimeter,mean_area,mean_smoothness,diagnosis 2 | 17.99,10.38,122.8,1001.0,0.1184,0 3 | 20.57,17.77,132.9,1326.0,0.08474,0 4 | 19.69,21.25,130.0,1203.0,0.1096,0 5 | 11.42,20.38,77.58,386.1,0.1425,0 6 | 20.29,14.34,135.1,1297.0,0.1003,0 7 | 12.45,15.7,82.57,477.1,0.1278,0 8 | 18.25,19.98,119.6,1040.0,0.09463,0 9 | 13.71,20.83,90.2,577.9,0.1189,0 10 | 13.0,21.82,87.5,519.8,0.1273,0 11 | 12.46,24.04,83.97,475.9,0.1186,0 12 | 16.02,23.24,102.7,797.8,0.08206,0 13 | 15.78,17.89,103.6,781.0,0.0971,0 14 | 19.17,24.8,132.4,1123.0,0.0974,0 15 | 15.85,23.95,103.7,782.7,0.08401,0 16 | 13.73,22.61,93.6,578.3,0.1131,0 17 | 14.54,27.54,96.73,658.8,0.1139,0 18 | 14.68,20.13,94.74,684.5,0.09867,0 19 | 16.13,20.68,108.1,798.8,0.117,0 20 | 19.81,22.15,130.0,1260.0,0.09831,0 21 | 13.54,14.36,87.46,566.3,0.09779,1 22 | 13.08,15.71,85.63,520.0,0.1075,1 23 | 9.504,12.44,60.34,273.9,0.1024,1 24 | 15.34,14.26,102.5,704.4,0.1073,0 25 | 21.16,23.04,137.2,1404.0,0.09428,0 26 | 16.65,21.38,110.0,904.6,0.1121,0 27 | 17.14,16.4,116.0,912.7,0.1186,0 28 | 14.58,21.53,97.41,644.8,0.1054,0 29 | 18.61,20.25,122.1,1094.0,0.0944,0 30 | 15.3,25.27,102.4,732.4,0.1082,0 31 | 17.57,15.05,115.0,955.1,0.09847,0 32 | 18.63,25.11,124.8,1088.0,0.1064,0 33 | 11.84,18.7,77.93,440.6,0.1109,0 34 | 17.02,23.98,112.8,899.3,0.1197,0 35 | 19.27,26.47,127.9,1162.0,0.09401,0 36 | 16.13,17.88,107.0,807.2,0.104,0 37 | 16.74,21.59,110.1,869.5,0.0961,0 38 | 14.25,21.72,93.63,633.0,0.09823,0 39 | 13.03,18.42,82.61,523.8,0.08983,1 40 | 14.99,25.2,95.54,698.8,0.09387,0 41 | 
13.48,20.82,88.4,559.2,0.1016,0 42 | 13.44,21.58,86.18,563.0,0.08162,0 43 | 10.95,21.35,71.9,371.1,0.1227,0 44 | 19.07,24.81,128.3,1104.0,0.09081,0 45 | 13.28,20.28,87.32,545.2,0.1041,0 46 | 13.17,21.81,85.42,531.5,0.09714,0 47 | 18.65,17.6,123.7,1076.0,0.1099,0 48 | 8.196,16.84,51.71,201.9,0.086,1 49 | 13.17,18.66,85.98,534.6,0.1158,0 50 | 12.05,14.63,78.04,449.3,0.1031,1 51 | 13.49,22.3,86.91,561.0,0.08752,1 52 | 11.76,21.6,74.72,427.9,0.08637,1 53 | 13.64,16.34,87.21,571.8,0.07685,1 54 | 11.94,18.24,75.71,437.6,0.08261,1 55 | 18.22,18.7,120.3,1033.0,0.1148,0 56 | 15.1,22.02,97.26,712.8,0.09056,0 57 | 11.52,18.75,73.34,409.0,0.09524,1 58 | 19.21,18.57,125.5,1152.0,0.1053,0 59 | 14.71,21.59,95.55,656.9,0.1137,0 60 | 13.05,19.31,82.61,527.2,0.0806,1 61 | 8.618,11.79,54.34,224.5,0.09752,1 62 | 10.17,14.88,64.55,311.9,0.1134,1 63 | 8.598,20.98,54.66,221.8,0.1243,1 64 | 14.25,22.15,96.42,645.7,0.1049,0 65 | 9.173,13.86,59.2,260.9,0.07721,1 66 | 12.68,23.84,82.69,499.0,0.1122,0 67 | 14.78,23.94,97.4,668.3,0.1172,0 68 | 9.465,21.01,60.11,269.4,0.1044,1 69 | 11.31,19.04,71.8,394.1,0.08139,1 70 | 9.029,17.33,58.79,250.5,0.1066,1 71 | 12.78,16.49,81.37,502.5,0.09831,1 72 | 18.94,21.31,123.6,1130.0,0.09009,0 73 | 8.888,14.64,58.79,244.0,0.09783,1 74 | 17.2,24.52,114.2,929.4,0.1071,0 75 | 13.8,15.79,90.43,584.1,0.1007,0 76 | 12.31,16.52,79.19,470.9,0.09172,1 77 | 16.07,19.65,104.1,817.7,0.09168,0 78 | 13.53,10.94,87.91,559.2,0.1291,1 79 | 18.05,16.15,120.2,1006.0,0.1065,0 80 | 20.18,23.97,143.7,1245.0,0.1286,0 81 | 12.86,18.0,83.19,506.3,0.09934,1 82 | 11.45,20.97,73.81,401.5,0.1102,1 83 | 13.34,15.86,86.49,520.0,0.1078,1 84 | 25.22,24.91,171.5,1878.0,0.1063,0 85 | 19.1,26.29,129.1,1132.0,0.1215,0 86 | 12.0,15.65,76.95,443.3,0.09723,1 87 | 18.46,18.52,121.1,1075.0,0.09874,0 88 | 14.48,21.46,94.25,648.2,0.09444,0 89 | 19.02,24.59,122.0,1076.0,0.09029,0 90 | 12.36,21.8,79.78,466.1,0.08772,1 91 | 14.64,15.24,95.77,651.9,0.1132,1 92 | 14.62,24.02,94.57,662.7,0.08974,1 93 | 
15.37,22.76,100.2,728.2,0.092,0 94 | 13.27,14.76,84.74,551.7,0.07355,1 95 | 13.45,18.3,86.6,555.1,0.1022,1 96 | 15.06,19.83,100.3,705.6,0.1039,0 97 | 20.26,23.03,132.4,1264.0,0.09078,0 98 | 12.18,17.84,77.79,451.1,0.1045,1 99 | 9.787,19.94,62.11,294.5,0.1024,1 100 | 11.6,12.84,74.34,412.6,0.08983,1 101 | 14.42,19.77,94.48,642.5,0.09752,0 102 | 13.61,24.98,88.05,582.7,0.09488,0 103 | 6.981,13.43,43.79,143.5,0.117,1 104 | 12.18,20.52,77.22,458.7,0.08013,1 105 | 9.876,19.4,63.95,298.3,0.1005,1 106 | 10.49,19.29,67.41,336.1,0.09989,1 107 | 13.11,15.56,87.21,530.2,0.1398,0 108 | 11.64,18.33,75.17,412.5,0.1142,1 109 | 12.36,18.54,79.01,466.7,0.08477,1 110 | 22.27,19.67,152.8,1509.0,0.1326,0 111 | 11.34,21.26,72.48,396.5,0.08759,1 112 | 9.777,16.99,62.5,290.2,0.1037,1 113 | 12.63,20.76,82.15,480.4,0.09933,1 114 | 14.26,19.65,97.83,629.9,0.07837,1 115 | 10.51,20.19,68.64,334.2,0.1122,1 116 | 8.726,15.83,55.84,230.9,0.115,1 117 | 11.93,21.53,76.53,438.6,0.09768,1 118 | 8.95,15.76,58.74,245.2,0.09462,1 119 | 14.87,16.67,98.64,682.5,0.1162,0 120 | 15.78,22.91,105.7,782.6,0.1155,0 121 | 17.95,20.01,114.2,982.0,0.08402,0 122 | 11.41,10.82,73.34,403.3,0.09373,1 123 | 18.66,17.12,121.4,1077.0,0.1054,0 124 | 24.25,20.2,166.2,1761.0,0.1447,0 125 | 14.5,10.89,94.28,640.7,0.1101,1 126 | 13.37,16.39,86.1,553.5,0.07115,1 127 | 13.85,17.21,88.44,588.7,0.08785,1 128 | 13.61,24.69,87.76,572.6,0.09258,0 129 | 19.0,18.91,123.4,1138.0,0.08217,0 130 | 15.1,16.39,99.58,674.5,0.115,1 131 | 19.79,25.12,130.4,1192.0,0.1015,0 132 | 12.19,13.29,79.08,455.8,0.1066,1 133 | 15.46,19.48,101.7,748.9,0.1092,0 134 | 16.16,21.54,106.2,809.8,0.1008,0 135 | 15.71,13.93,102.0,761.7,0.09462,1 136 | 18.45,21.91,120.2,1075.0,0.0943,0 137 | 12.77,22.47,81.72,506.3,0.09055,0 138 | 11.71,16.67,74.72,423.6,0.1051,1 139 | 11.43,15.39,73.06,399.8,0.09639,1 140 | 14.95,17.57,96.85,678.1,0.1167,0 141 | 11.28,13.39,73.0,384.8,0.1164,1 142 | 9.738,11.97,61.24,288.5,0.0925,1 143 | 16.11,18.05,105.1,813.0,0.09721,0 144 | 
11.43,17.31,73.66,398.0,0.1092,1 145 | 12.9,15.92,83.74,512.2,0.08677,1 146 | 10.75,14.97,68.26,355.3,0.07793,1 147 | 11.9,14.65,78.11,432.8,0.1152,1 148 | 11.8,16.58,78.99,432.0,0.1091,0 149 | 14.95,18.77,97.84,689.5,0.08138,1 150 | 14.44,15.18,93.97,640.1,0.0997,1 151 | 13.74,17.91,88.12,585.0,0.07944,1 152 | 13.0,20.78,83.51,519.4,0.1135,1 153 | 8.219,20.7,53.27,203.9,0.09405,1 154 | 9.731,15.34,63.78,300.2,0.1072,1 155 | 11.15,13.08,70.87,381.9,0.09754,1 156 | 13.15,15.34,85.31,538.9,0.09384,1 157 | 12.25,17.94,78.27,460.3,0.08654,1 158 | 17.68,20.74,117.4,963.7,0.1115,0 159 | 16.84,19.46,108.4,880.2,0.07445,1 160 | 12.06,12.74,76.84,448.6,0.09311,1 161 | 10.9,12.96,68.69,366.8,0.07515,1 162 | 11.75,20.18,76.1,419.8,0.1089,1 163 | 19.19,15.94,126.3,1157.0,0.08694,0 164 | 19.59,18.15,130.7,1214.0,0.112,0 165 | 12.34,22.22,79.85,464.5,0.1012,1 166 | 23.27,22.04,152.1,1686.0,0.08439,0 167 | 14.97,19.76,95.5,690.2,0.08421,1 168 | 10.8,9.71,68.77,357.6,0.09594,1 169 | 16.78,18.8,109.3,886.3,0.08865,0 170 | 17.47,24.68,116.1,984.6,0.1049,0 171 | 14.97,16.95,96.22,685.9,0.09855,1 172 | 12.32,12.39,78.85,464.1,0.1028,1 173 | 13.43,19.63,85.84,565.4,0.09048,0 174 | 15.46,11.89,102.5,736.9,0.1257,0 175 | 11.08,14.71,70.21,372.7,0.1006,1 176 | 10.66,15.15,67.49,349.6,0.08792,1 177 | 8.671,14.45,54.42,227.2,0.09138,1 178 | 9.904,18.06,64.6,302.4,0.09699,1 179 | 16.46,20.11,109.3,832.9,0.09831,0 180 | 13.01,22.22,82.01,526.4,0.06251,1 181 | 12.81,13.06,81.29,508.8,0.08739,1 182 | 27.22,21.87,182.1,2250.0,0.1094,0 183 | 21.09,26.57,142.7,1311.0,0.1141,0 184 | 15.7,20.31,101.2,766.6,0.09597,0 185 | 11.41,14.92,73.53,402.0,0.09059,1 186 | 15.28,22.41,98.92,710.6,0.09057,0 187 | 10.08,15.11,63.76,317.5,0.09267,1 188 | 18.31,18.58,118.6,1041.0,0.08588,0 189 | 11.71,17.19,74.68,420.3,0.09774,1 190 | 11.81,17.39,75.27,428.9,0.1007,1 191 | 12.3,15.9,78.83,463.7,0.0808,1 192 | 14.22,23.12,94.37,609.9,0.1075,0 193 | 12.77,21.41,82.02,507.4,0.08749,1 194 | 
9.72,18.22,60.73,288.1,0.0695,1 195 | 12.34,26.86,81.15,477.4,0.1034,0 196 | 14.86,23.21,100.4,671.4,0.1044,0 197 | 12.91,16.33,82.53,516.4,0.07941,1 198 | 13.77,22.29,90.63,588.9,0.12,0 199 | 18.08,21.84,117.4,1024.0,0.07371,0 200 | 19.18,22.49,127.5,1148.0,0.08523,0 201 | 14.45,20.22,94.49,642.7,0.09872,0 202 | 12.23,19.56,78.54,461.0,0.09586,1 203 | 17.54,19.32,115.1,951.6,0.08968,0 204 | 23.29,26.67,158.9,1685.0,0.1141,0 205 | 13.81,23.75,91.56,597.8,0.1323,0 206 | 12.47,18.6,81.09,481.9,0.09965,1 207 | 15.12,16.68,98.78,716.6,0.08876,0 208 | 9.876,17.27,62.92,295.4,0.1089,1 209 | 17.01,20.26,109.7,904.3,0.08772,0 210 | 13.11,22.54,87.02,529.4,0.1002,1 211 | 15.27,12.91,98.17,725.5,0.08182,1 212 | 20.58,22.14,134.7,1290.0,0.0909,0 213 | 11.84,18.94,75.51,428.0,0.08871,1 214 | 28.11,18.47,188.5,2499.0,0.1142,0 215 | 17.42,25.56,114.5,948.0,0.1006,0 216 | 14.19,23.81,92.87,610.7,0.09463,0 217 | 13.86,16.93,90.96,578.9,0.1026,0 218 | 11.89,18.35,77.32,432.2,0.09363,1 219 | 10.2,17.48,65.05,321.2,0.08054,1 220 | 19.8,21.56,129.7,1230.0,0.09383,0 221 | 19.53,32.47,128.0,1223.0,0.0842,0 222 | 13.65,13.16,87.88,568.9,0.09646,1 223 | 13.56,13.9,88.59,561.3,0.1051,1 224 | 10.18,17.53,65.12,313.1,0.1061,1 225 | 15.75,20.25,102.6,761.3,0.1025,0 226 | 13.27,17.02,84.55,546.4,0.08445,1 227 | 14.34,13.47,92.51,641.2,0.09906,1 228 | 10.44,15.46,66.62,329.6,0.1053,1 229 | 15.0,15.51,97.45,684.5,0.08371,1 230 | 12.62,23.97,81.35,496.4,0.07903,1 231 | 12.83,22.33,85.26,503.2,0.1088,0 232 | 17.05,19.08,113.4,895.0,0.1141,0 233 | 11.32,27.08,71.76,395.7,0.06883,1 234 | 11.22,33.81,70.79,386.8,0.0778,1 235 | 20.51,27.81,134.4,1319.0,0.09159,0 236 | 9.567,15.91,60.21,279.6,0.08464,1 237 | 14.03,21.25,89.79,603.4,0.0907,1 238 | 23.21,26.97,153.5,1670.0,0.09509,0 239 | 20.48,21.46,132.5,1306.0,0.08355,0 240 | 14.22,27.85,92.55,623.9,0.08223,1 241 | 17.46,39.28,113.4,920.6,0.09812,0 242 | 13.64,15.6,87.38,575.3,0.09423,1 243 | 12.42,15.04,78.61,476.5,0.07926,1 244 | 
11.3,18.19,73.93,389.4,0.09592,1 245 | 13.75,23.77,88.54,590.0,0.08043,1 246 | 19.4,23.5,129.1,1155.0,0.1027,0 247 | 10.48,19.86,66.72,337.7,0.107,1 248 | 13.2,17.43,84.13,541.6,0.07215,1 249 | 12.89,14.11,84.95,512.2,0.0876,1 250 | 10.65,25.22,68.01,347.0,0.09657,1 251 | 11.52,14.93,73.87,406.3,0.1013,1 252 | 20.94,23.56,138.9,1364.0,0.1007,0 253 | 11.5,18.45,73.28,407.4,0.09345,1 254 | 19.73,19.82,130.7,1206.0,0.1062,0 255 | 17.3,17.08,113.0,928.2,0.1008,0 256 | 19.45,19.33,126.5,1169.0,0.1035,0 257 | 13.96,17.05,91.43,602.4,0.1096,0 258 | 19.55,28.77,133.6,1207.0,0.0926,0 259 | 15.32,17.27,103.2,713.3,0.1335,0 260 | 15.66,23.2,110.2,773.5,0.1109,0 261 | 15.53,33.56,103.7,744.9,0.1063,0 262 | 20.31,27.06,132.9,1288.0,0.1,0 263 | 17.35,23.06,111.0,933.1,0.08662,0 264 | 17.29,22.13,114.4,947.8,0.08999,0 265 | 15.61,19.38,100.0,758.6,0.0784,0 266 | 17.19,22.07,111.6,928.3,0.09726,0 267 | 20.73,31.12,135.7,1419.0,0.09469,0 268 | 10.6,18.95,69.28,346.4,0.09688,1 269 | 13.59,21.84,87.16,561.0,0.07956,1 270 | 12.87,16.21,82.38,512.2,0.09425,1 271 | 10.71,20.39,69.5,344.9,0.1082,1 272 | 14.29,16.82,90.3,632.6,0.06429,1 273 | 11.29,13.04,72.23,388.0,0.09834,1 274 | 21.75,20.99,147.3,1491.0,0.09401,0 275 | 9.742,15.67,61.5,289.9,0.09037,1 276 | 17.93,24.48,115.2,998.9,0.08855,0 277 | 11.89,17.36,76.2,435.6,0.1225,1 278 | 11.33,14.16,71.79,396.6,0.09379,1 279 | 18.81,19.98,120.9,1102.0,0.08923,0 280 | 13.59,17.84,86.24,572.3,0.07948,1 281 | 13.85,15.18,88.99,587.4,0.09516,1 282 | 19.16,26.6,126.2,1138.0,0.102,0 283 | 11.74,14.02,74.24,427.3,0.07813,1 284 | 19.4,18.18,127.2,1145.0,0.1037,0 285 | 16.24,18.77,108.8,805.1,0.1066,0 286 | 12.89,15.7,84.08,516.6,0.07818,1 287 | 12.58,18.4,79.83,489.0,0.08393,1 288 | 11.94,20.76,77.87,441.0,0.08605,1 289 | 12.89,13.12,81.89,515.9,0.06955,1 290 | 11.26,19.96,73.72,394.1,0.0802,1 291 | 11.37,18.89,72.17,396.0,0.08713,1 292 | 14.41,19.73,96.03,651.0,0.08757,1 293 | 14.96,19.1,97.03,687.3,0.08992,1 294 | 
12.95,16.02,83.14,513.7,0.1005,1 295 | 11.85,17.46,75.54,432.7,0.08372,1 296 | 12.72,13.78,81.78,492.1,0.09667,1 297 | 13.77,13.27,88.06,582.7,0.09198,1 298 | 10.91,12.35,69.14,363.7,0.08518,1 299 | 11.76,18.14,75.0,431.1,0.09968,0 300 | 14.26,18.17,91.22,633.1,0.06576,1 301 | 10.51,23.09,66.85,334.2,0.1015,1 302 | 19.53,18.9,129.5,1217.0,0.115,0 303 | 12.46,19.89,80.43,471.3,0.08451,1 304 | 20.09,23.86,134.7,1247.0,0.108,0 305 | 10.49,18.61,66.86,334.3,0.1068,1 306 | 11.46,18.16,73.59,403.1,0.08853,1 307 | 11.6,24.49,74.23,417.2,0.07474,1 308 | 13.2,15.82,84.07,537.3,0.08511,1 309 | 9.0,14.4,56.36,246.3,0.07005,1 310 | 13.5,12.71,85.69,566.2,0.07376,1 311 | 13.05,13.84,82.71,530.6,0.08352,1 312 | 11.7,19.11,74.33,418.7,0.08814,1 313 | 14.61,15.69,92.68,664.9,0.07618,1 314 | 12.76,13.37,82.29,504.1,0.08794,1 315 | 11.54,10.72,73.73,409.1,0.08597,1 316 | 8.597,18.6,54.09,221.2,0.1074,1 317 | 12.49,16.85,79.19,481.6,0.08511,1 318 | 12.18,14.08,77.25,461.4,0.07734,1 319 | 18.22,18.87,118.7,1027.0,0.09746,0 320 | 9.042,18.9,60.07,244.5,0.09968,1 321 | 12.43,17.0,78.6,477.3,0.07557,1 322 | 10.25,16.18,66.52,324.2,0.1061,1 323 | 20.16,19.66,131.1,1274.0,0.0802,0 324 | 12.86,13.32,82.82,504.8,0.1134,1 325 | 20.34,21.51,135.9,1264.0,0.117,0 326 | 12.2,15.21,78.01,457.9,0.08673,1 327 | 12.67,17.3,81.25,489.9,0.1028,1 328 | 14.11,12.88,90.03,616.5,0.09309,1 329 | 12.03,17.93,76.09,446.0,0.07683,1 330 | 16.27,20.71,106.9,813.7,0.1169,0 331 | 16.26,21.88,107.5,826.8,0.1165,0 332 | 16.03,15.51,105.8,793.2,0.09491,0 333 | 12.98,19.35,84.52,514.0,0.09579,1 334 | 11.22,19.86,71.94,387.3,0.1054,1 335 | 11.25,14.78,71.38,390.0,0.08306,1 336 | 12.3,19.02,77.88,464.4,0.08313,1 337 | 17.06,21.0,111.8,918.6,0.1119,0 338 | 12.99,14.23,84.08,514.3,0.09462,1 339 | 18.77,21.43,122.9,1092.0,0.09116,0 340 | 10.05,17.53,64.41,310.8,0.1007,1 341 | 23.51,24.27,155.1,1747.0,0.1069,0 342 | 14.42,16.54,94.15,641.2,0.09751,1 343 | 9.606,16.84,61.64,280.5,0.08481,1 344 | 
11.06,14.96,71.49,373.9,0.1033,1 345 | 19.68,21.68,129.9,1194.0,0.09797,0 346 | 11.71,15.45,75.03,420.3,0.115,1 347 | 10.26,14.71,66.2,321.6,0.09882,1 348 | 12.06,18.9,76.66,445.3,0.08386,1 349 | 14.76,14.74,94.87,668.7,0.08875,1 350 | 11.47,16.03,73.02,402.7,0.09076,1 351 | 11.95,14.96,77.23,426.7,0.1158,1 352 | 11.66,17.07,73.7,421.0,0.07561,1 353 | 15.75,19.22,107.1,758.6,0.1243,0 354 | 25.73,17.46,174.2,2010.0,0.1149,0 355 | 15.08,25.74,98.0,716.6,0.1024,0 356 | 11.14,14.07,71.24,384.6,0.07274,1 357 | 12.56,19.07,81.92,485.8,0.0876,1 358 | 13.05,18.59,85.09,512.0,0.1082,1 359 | 13.87,16.21,88.52,593.7,0.08743,1 360 | 8.878,15.49,56.74,241.0,0.08293,1 361 | 9.436,18.32,59.82,278.6,0.1009,1 362 | 12.54,18.07,79.42,491.9,0.07436,1 363 | 13.3,21.57,85.24,546.1,0.08582,1 364 | 12.76,18.84,81.87,496.6,0.09676,1 365 | 16.5,18.29,106.6,838.1,0.09686,1 366 | 13.4,16.95,85.48,552.4,0.07937,1 367 | 20.44,21.78,133.8,1293.0,0.0915,0 368 | 20.2,26.83,133.7,1234.0,0.09905,0 369 | 12.21,18.02,78.31,458.4,0.09231,1 370 | 21.71,17.25,140.9,1546.0,0.09384,0 371 | 22.01,21.9,147.2,1482.0,0.1063,0 372 | 16.35,23.29,109.0,840.4,0.09742,0 373 | 15.19,13.21,97.65,711.8,0.07963,1 374 | 21.37,15.1,141.3,1386.0,0.1001,0 375 | 20.64,17.35,134.8,1335.0,0.09446,0 376 | 13.69,16.07,87.84,579.1,0.08302,1 377 | 16.17,16.07,106.3,788.5,0.0988,1 378 | 10.57,20.22,70.15,338.3,0.09073,1 379 | 13.46,28.21,85.89,562.1,0.07517,1 380 | 13.66,15.15,88.27,580.6,0.08268,1 381 | 11.08,18.83,73.3,361.6,0.1216,0 382 | 11.27,12.96,73.16,386.3,0.1237,1 383 | 11.04,14.93,70.67,372.7,0.07987,1 384 | 12.05,22.72,78.75,447.8,0.06935,1 385 | 12.39,17.48,80.64,462.9,0.1042,1 386 | 13.28,13.72,85.79,541.8,0.08363,1 387 | 14.6,23.29,93.97,664.7,0.08682,0 388 | 12.21,14.09,78.78,462.0,0.08108,1 389 | 13.88,16.16,88.37,596.6,0.07026,1 390 | 11.27,15.5,73.38,392.0,0.08365,1 391 | 19.55,23.21,128.9,1174.0,0.101,0 392 | 10.26,12.22,65.75,321.6,0.09996,1 393 | 8.734,16.84,55.27,234.3,0.1039,1 394 | 
15.49,19.97,102.4,744.7,0.116,0 395 | 21.61,22.28,144.4,1407.0,0.1167,0 396 | 12.1,17.72,78.07,446.2,0.1029,1 397 | 14.06,17.18,89.75,609.1,0.08045,1 398 | 13.51,18.89,88.1,558.1,0.1059,1 399 | 12.8,17.46,83.05,508.3,0.08044,1 400 | 11.06,14.83,70.31,378.2,0.07741,1 401 | 11.8,17.26,75.26,431.9,0.09087,1 402 | 17.91,21.02,124.4,994.0,0.123,0 403 | 11.93,10.91,76.14,442.7,0.08872,1 404 | 12.96,18.29,84.18,525.2,0.07351,1 405 | 12.94,16.17,83.18,507.6,0.09879,1 406 | 12.34,14.95,78.29,469.1,0.08682,1 407 | 10.94,18.59,70.39,370.0,0.1004,1 408 | 16.14,14.86,104.3,800.0,0.09495,1 409 | 12.85,21.37,82.63,514.5,0.07551,1 410 | 17.99,20.66,117.8,991.7,0.1036,0 411 | 12.27,17.92,78.41,466.1,0.08685,1 412 | 11.36,17.57,72.49,399.8,0.08858,1 413 | 11.04,16.83,70.92,373.2,0.1077,1 414 | 9.397,21.68,59.75,268.8,0.07969,1 415 | 14.99,22.11,97.53,693.7,0.08515,1 416 | 15.13,29.81,96.71,719.5,0.0832,0 417 | 11.89,21.17,76.39,433.8,0.09773,1 418 | 9.405,21.7,59.6,271.2,0.1044,1 419 | 15.5,21.08,102.9,803.1,0.112,0 420 | 12.7,12.17,80.88,495.0,0.08785,1 421 | 11.16,21.41,70.95,380.3,0.1018,1 422 | 11.57,19.04,74.2,409.7,0.08546,1 423 | 14.69,13.98,98.22,656.1,0.1031,1 424 | 11.61,16.02,75.46,408.2,0.1088,1 425 | 13.66,19.13,89.46,575.3,0.09057,1 426 | 9.742,19.12,61.93,289.7,0.1075,1 427 | 10.03,21.28,63.19,307.3,0.08117,1 428 | 10.48,14.98,67.49,333.6,0.09816,1 429 | 10.8,21.98,68.79,359.9,0.08801,1 430 | 11.13,16.62,70.47,381.1,0.08151,1 431 | 12.72,17.67,80.98,501.3,0.07896,1 432 | 14.9,22.53,102.1,685.0,0.09947,0 433 | 12.4,17.68,81.47,467.8,0.1054,1 434 | 20.18,19.54,133.8,1250.0,0.1133,0 435 | 18.82,21.97,123.7,1110.0,0.1018,0 436 | 14.86,16.94,94.89,673.7,0.08924,1 437 | 13.98,19.62,91.12,599.5,0.106,0 438 | 12.87,19.54,82.67,509.2,0.09136,1 439 | 14.04,15.98,89.78,611.2,0.08458,1 440 | 13.85,19.6,88.68,592.6,0.08684,1 441 | 14.02,15.66,89.59,606.5,0.07966,1 442 | 10.97,17.2,71.73,371.5,0.08915,1 443 | 17.27,25.42,112.4,928.8,0.08331,0 444 | 13.78,15.79,88.37,585.9,0.08817,1 
445 | 10.57,18.32,66.82,340.9,0.08142,1 446 | 18.03,16.85,117.5,990.0,0.08947,0 447 | 11.99,24.89,77.61,441.3,0.103,1 448 | 17.75,28.03,117.3,981.6,0.09997,0 449 | 14.8,17.66,95.88,674.8,0.09179,1 450 | 14.53,19.34,94.25,659.7,0.08388,1 451 | 21.1,20.52,138.1,1384.0,0.09684,0 452 | 11.87,21.54,76.83,432.0,0.06613,1 453 | 19.59,25.0,127.7,1191.0,0.1032,0 454 | 12.0,28.23,76.77,442.5,0.08437,1 455 | 14.53,13.98,93.86,644.2,0.1099,1 456 | 12.62,17.15,80.62,492.9,0.08583,1 457 | 13.38,30.72,86.34,557.2,0.09245,1 458 | 11.63,29.29,74.87,415.1,0.09357,1 459 | 13.21,25.25,84.1,537.9,0.08791,1 460 | 13.0,25.13,82.61,520.2,0.08369,1 461 | 9.755,28.2,61.68,290.9,0.07984,1 462 | 17.08,27.15,111.2,930.9,0.09898,0 463 | 27.42,26.27,186.9,2501.0,0.1084,0 464 | 14.4,26.99,92.25,646.1,0.06995,1 465 | 11.6,18.36,73.88,412.7,0.08508,1 466 | 13.17,18.22,84.28,537.3,0.07466,1 467 | 13.24,20.13,86.87,542.9,0.08284,1 468 | 13.14,20.74,85.98,536.9,0.08675,1 469 | 9.668,18.1,61.06,286.3,0.08311,1 470 | 17.6,23.33,119.0,980.5,0.09289,0 471 | 11.62,18.18,76.38,408.8,0.1175,1 472 | 9.667,18.49,61.49,289.1,0.08946,1 473 | 12.04,28.14,76.85,449.9,0.08752,1 474 | 14.92,14.93,96.45,686.9,0.08098,1 475 | 12.27,29.97,77.42,465.4,0.07699,1 476 | 10.88,15.62,70.41,358.9,0.1007,1 477 | 12.83,15.73,82.89,506.9,0.0904,1 478 | 14.2,20.53,92.41,618.4,0.08931,1 479 | 13.9,16.62,88.97,599.4,0.06828,1 480 | 11.49,14.59,73.99,404.9,0.1046,1 481 | 16.25,19.51,109.8,815.8,0.1026,0 482 | 12.16,18.03,78.29,455.3,0.09087,1 483 | 13.9,19.24,88.73,602.9,0.07991,1 484 | 13.47,14.06,87.32,546.3,0.1071,1 485 | 13.7,17.64,87.76,571.1,0.0995,1 486 | 15.73,11.28,102.8,747.2,0.1043,1 487 | 12.45,16.41,82.85,476.7,0.09514,1 488 | 14.64,16.85,94.21,666.0,0.08641,1 489 | 19.44,18.82,128.1,1167.0,0.1089,0 490 | 11.68,16.17,75.49,420.5,0.1128,1 491 | 16.69,20.2,107.1,857.6,0.07497,0 492 | 12.25,22.44,78.18,466.5,0.08192,1 493 | 17.85,13.23,114.6,992.1,0.07838,1 494 | 18.01,20.56,118.4,1007.0,0.1001,0 495 | 
12.46,12.83,78.83,477.3,0.07372,1 496 | 13.16,20.54,84.06,538.7,0.07335,1 497 | 14.87,20.21,96.12,680.9,0.09587,1 498 | 12.65,18.17,82.69,485.6,0.1076,1 499 | 12.47,17.31,80.45,480.1,0.08928,1 500 | 18.49,17.52,121.3,1068.0,0.1012,0 501 | 20.59,21.24,137.8,1320.0,0.1085,0 502 | 15.04,16.74,98.73,689.4,0.09883,1 503 | 13.82,24.49,92.33,595.9,0.1162,0 504 | 12.54,16.32,81.25,476.3,0.1158,1 505 | 23.09,19.83,152.1,1682.0,0.09342,0 506 | 9.268,12.87,61.49,248.7,0.1634,1 507 | 9.676,13.14,64.12,272.5,0.1255,1 508 | 12.22,20.04,79.47,453.1,0.1096,1 509 | 11.06,17.12,71.25,366.5,0.1194,1 510 | 16.3,15.7,104.7,819.8,0.09427,1 511 | 15.46,23.95,103.8,731.3,0.1183,0 512 | 11.74,14.69,76.31,426.0,0.08099,1 513 | 14.81,14.7,94.66,680.7,0.08472,1 514 | 13.4,20.52,88.64,556.7,0.1106,0 515 | 14.58,13.66,94.29,658.8,0.09832,1 516 | 15.05,19.07,97.26,701.9,0.09215,0 517 | 11.34,18.61,72.76,391.2,0.1049,1 518 | 18.31,20.58,120.8,1052.0,0.1068,0 519 | 19.89,20.26,130.5,1214.0,0.1037,0 520 | 12.88,18.22,84.45,493.1,0.1218,1 521 | 12.75,16.7,82.51,493.8,0.1125,1 522 | 9.295,13.9,59.96,257.8,0.1371,1 523 | 24.63,21.6,165.5,1841.0,0.103,0 524 | 11.26,19.83,71.3,388.1,0.08511,1 525 | 13.71,18.68,88.73,571.0,0.09916,1 526 | 9.847,15.68,63.0,293.2,0.09492,1 527 | 8.571,13.1,54.53,221.3,0.1036,1 528 | 13.46,18.75,87.44,551.1,0.1075,1 529 | 12.34,12.27,78.94,468.5,0.09003,1 530 | 13.94,13.17,90.31,594.2,0.1248,1 531 | 12.07,13.44,77.83,445.2,0.11,1 532 | 11.75,17.56,75.89,422.9,0.1073,1 533 | 11.67,20.02,75.21,416.2,0.1016,1 534 | 13.68,16.33,87.76,575.5,0.09277,1 535 | 20.47,20.67,134.7,1299.0,0.09156,0 536 | 10.96,17.62,70.79,365.6,0.09687,1 537 | 20.55,20.86,137.8,1308.0,0.1046,0 538 | 14.27,22.55,93.77,629.8,0.1038,0 539 | 11.69,24.44,76.37,406.4,0.1236,1 540 | 7.729,25.49,47.98,178.8,0.08098,1 541 | 7.691,25.44,48.34,170.4,0.08668,1 542 | 11.54,14.44,74.65,402.9,0.09984,1 543 | 14.47,24.99,95.81,656.4,0.08837,1 544 | 14.74,25.42,94.7,668.6,0.08275,1 545 | 
13.21,28.06,84.88,538.4,0.08671,1 546 | 13.87,20.7,89.77,584.8,0.09578,1 547 | 13.62,23.23,87.19,573.2,0.09246,1 548 | 10.32,16.35,65.31,324.9,0.09434,1 549 | 10.26,16.58,65.85,320.8,0.08877,1 550 | 9.683,19.34,61.05,285.7,0.08491,1 551 | 10.82,24.21,68.89,361.6,0.08192,1 552 | 10.86,21.48,68.51,360.5,0.07431,1 553 | 11.13,22.44,71.49,378.4,0.09566,1 554 | 12.77,29.43,81.35,507.9,0.08276,1 555 | 9.333,21.94,59.01,264.0,0.0924,1 556 | 12.88,28.92,82.5,514.3,0.08123,1 557 | 10.29,27.61,65.67,321.4,0.0903,1 558 | 10.16,19.59,64.73,311.7,0.1003,1 559 | 9.423,27.88,59.26,271.3,0.08123,1 560 | 14.59,22.68,96.39,657.1,0.08473,1 561 | 11.51,23.93,74.52,403.5,0.09261,1 562 | 14.05,27.15,91.38,600.4,0.09929,1 563 | 11.2,29.37,70.67,386.0,0.07449,1 564 | 15.22,30.62,103.4,716.9,0.1048,0 565 | 20.92,25.09,143.0,1347.0,0.1099,0 566 | 21.56,22.39,142.0,1479.0,0.111,0 567 | 20.13,28.25,131.2,1261.0,0.0978,0 568 | 16.6,28.08,108.3,858.1,0.08455,0 569 | 20.6,29.33,140.1,1265.0,0.1178,0 570 | 7.76,24.54,47.92,181.0,0.05263,1 571 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ML_from_Scratch 2 | Trying to implement basic ml algorithms from scratch in python. I have also made videos explaining these algorithms... 
3 | 4 | ## Gradient Descent 5 | [![Gradient Descent video 1](https://i.ytimg.com/vi/36zkIAAUcZ4/mqdefault.jpg)](https://youtu.be/36zkIAAUcZ4) 6 | [![Gradient Descent video 2](https://i.ytimg.com/vi/41BiBUZbg9U/mqdefault.jpg)](https://youtu.be/41BiBUZbg9U) 7 | 8 | 9 | ## Linear Regression 10 | [![Linear Regression video](https://i.ytimg.com/vi/fnDO1s4fzi4/mqdefault.jpg)](https://youtu.be/fnDO1s4fzi4) 11 | 12 | ## Logistic Regression 13 | [![Logistic Regression video](https://i.ytimg.com/vi/NtjAeXppomA/mqdefault.jpg)](https://youtu.be/NtjAeXppomA) 14 | 15 | ## Stochastic Gradient Descent 16 | [![Stochastic Gradient Descent video](https://i.ytimg.com/vi/V8InSDYHG4s/mqdefault.jpg)](https://youtu.be/V8InSDYHG4s) 17 | 18 | ## KNN 19 | [![KNN video](https://i.ytimg.com/vi/0RwM2BaLNkE/mqdefault.jpg)](https://youtu.be/0RwM2BaLNkE) 20 | 21 | ## K-means 22 | [![K-means video](https://i.ytimg.com/vi/IB9WfafBmjk/mqdefault.jpg)](https://youtu.be/IB9WfafBmjk) 23 | 24 | ## Decision Tree Classification 25 | [![Decision Tree Classification video 1](https://i.ytimg.com/vi/ZVR2Way4nwQ/mqdefault.jpg)](https://youtu.be/ZVR2Way4nwQ) 26 | [![Decision Tree Classification video 2](https://i.ytimg.com/vi/sgQAhG5Q7iY/mqdefault.jpg)](https://youtu.be/sgQAhG5Q7iY) 27 | 28 | 29 | ## Decision Tree Regression 30 | [![Decision Tree Regression video 1](https://i.ytimg.com/vi/UhY5vPfQIrA/mqdefault.jpg)](https://youtu.be/UhY5vPfQIrA) 31 | [![Decision Tree Regression video 2](https://i.ytimg.com/vi/P2ZB8c5Ha1Q/mqdefault.jpg)](https://youtu.be/P2ZB8c5Ha1Q) 32 | 33 | ## Naive Bayes Classification 34 | [![Naive Bayes Classification video 1](https://i.ytimg.com/vi/lFJbZ6LVxN8/mqdefault.jpg)](https://youtu.be/lFJbZ6LVxN8) 35 | [![Naive Bayes Classification video 2](https://i.ytimg.com/vi/3I8oX3OUL6I/mqdefault.jpg)](https://youtu.be/3I8oX3OUL6I) 36 | 37 | ## Data source 38 | - airfoil_noise_data.csv (converted from the .dat file available at https://archive.ics.uci.edu/ml/datasets/airfoil+self-noise) 39 | 40 | Donor: 41 | Dr Roberto Lopez 42 | robertolopez '@' intelnics.com 43 | Intelnics 44 | 45 | Creators: 46 | 
Thomas F. Brooks, D. Stuart Pope and Michael A. Marcolini 47 | NASA 48 | 49 | - Breast_cancer_data.csv (taken from https://www.kaggle.com/merishnasuwal/breast-cancer-prediction-dataset) 50 | -------------------------------------------------------------------------------- /decision tree classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Import tools" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import pandas as pd" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "## Get the data" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": { 31 | "scrolled": false 32 | }, 33 | "outputs": [ 34 | { 35 | "data": { 36 | "text/html": [ 37 | "
\n", 38 | "\n", 51 | "\n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | "
sepal_lengthsepal_widthpetal_lengthpetal_widthtype
05.13.51.40.20
14.93.01.40.20
24.73.21.30.20
34.63.11.50.20
45.03.61.40.20
55.43.91.70.40
64.63.41.40.30
75.03.41.50.20
84.42.91.40.20
94.93.11.50.10
\n", 145 | "
" 146 | ], 147 | "text/plain": [ 148 | " sepal_length sepal_width petal_length petal_width type\n", 149 | "0 5.1 3.5 1.4 0.2 0\n", 150 | "1 4.9 3.0 1.4 0.2 0\n", 151 | "2 4.7 3.2 1.3 0.2 0\n", 152 | "3 4.6 3.1 1.5 0.2 0\n", 153 | "4 5.0 3.6 1.4 0.2 0\n", 154 | "5 5.4 3.9 1.7 0.4 0\n", 155 | "6 4.6 3.4 1.4 0.3 0\n", 156 | "7 5.0 3.4 1.5 0.2 0\n", 157 | "8 4.4 2.9 1.4 0.2 0\n", 158 | "9 4.9 3.1 1.5 0.1 0" 159 | ] 160 | }, 161 | "execution_count": 2, 162 | "metadata": {}, 163 | "output_type": "execute_result" 164 | } 165 | ], 166 | "source": [ 167 | "col_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'type']\n", 168 | "data = pd.read_csv(\"iris.csv\", skiprows=1, header=None, names=col_names)\n", 169 | "data.head(10)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "## Node class" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 3, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "class Node():\n", 186 | " def __init__(self, feature_index=None, threshold=None, left=None, right=None, info_gain=None, value=None):\n", 187 | " ''' constructor ''' \n", 188 | " \n", 189 | " # for decision node\n", 190 | " self.feature_index = feature_index\n", 191 | " self.threshold = threshold\n", 192 | " self.left = left\n", 193 | " self.right = right\n", 194 | " self.info_gain = info_gain\n", 195 | " \n", 196 | " # for leaf node\n", 197 | " self.value = value" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "## Tree class" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 4, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "class DecisionTreeClassifier():\n", 214 | " def __init__(self, min_samples_split=2, max_depth=2):\n", 215 | " ''' constructor '''\n", 216 | " \n", 217 | " # initialize the root of the tree \n", 218 | " self.root = None\n", 219 | " \n", 220 | " # 
stopping conditions\n", 221 | " self.min_samples_split = min_samples_split\n", 222 | " self.max_depth = max_depth\n", 223 | " \n", 224 | " def build_tree(self, dataset, curr_depth=0):\n", 225 | " ''' recursive function to build the tree ''' \n", 226 | " \n", 227 | " X, Y = dataset[:,:-1], dataset[:,-1]\n", 228 | " num_samples, num_features = np.shape(X)\n", 229 | " \n", 230 | " # split until stopping conditions are met\n", 231 | " if num_samples>=self.min_samples_split and curr_depth<=self.max_depth:\n", 232 | " # find the best split\n", 233 | " best_split = self.get_best_split(dataset, num_samples, num_features)\n", 234 | " # check if information gain is positive\n", 235 | " if best_split[\"info_gain\"]>0:\n", 236 | " # recur left\n", 237 | " left_subtree = self.build_tree(best_split[\"dataset_left\"], curr_depth+1)\n", 238 | " # recur right\n", 239 | " right_subtree = self.build_tree(best_split[\"dataset_right\"], curr_depth+1)\n", 240 | " # return decision node\n", 241 | " return Node(best_split[\"feature_index\"], best_split[\"threshold\"], \n", 242 | " left_subtree, right_subtree, best_split[\"info_gain\"])\n", 243 | " \n", 244 | " # compute leaf node\n", 245 | " leaf_value = self.calculate_leaf_value(Y)\n", 246 | " # return leaf node\n", 247 | " return Node(value=leaf_value)\n", 248 | " \n", 249 | " def get_best_split(self, dataset, num_samples, num_features):\n", 250 | " ''' function to find the best split '''\n", 251 | " \n", 252 | " # dictionary to store the best split\n", 253 | " best_split = {}\n", 254 | " max_info_gain = -float(\"inf\")\n", 255 | " \n", 256 | " # loop over all the features\n", 257 | " for feature_index in range(num_features):\n", 258 | " feature_values = dataset[:, feature_index]\n", 259 | " possible_thresholds = np.unique(feature_values)\n", 260 | " # loop over all the feature values present in the data\n", 261 | " for threshold in possible_thresholds:\n", 262 | " # get current split\n", 263 | " dataset_left, dataset_right = 
self.split(dataset, feature_index, threshold)\n", 264 | " # check if childs are not null\n", 265 | " if len(dataset_left)>0 and len(dataset_right)>0:\n", 266 | " y, left_y, right_y = dataset[:, -1], dataset_left[:, -1], dataset_right[:, -1]\n", 267 | " # compute information gain\n", 268 | " curr_info_gain = self.information_gain(y, left_y, right_y, \"gini\")\n", 269 | " # update the best split if needed\n", 270 | " if curr_info_gain>max_info_gain:\n", 271 | " best_split[\"feature_index\"] = feature_index\n", 272 | " best_split[\"threshold\"] = threshold\n", 273 | " best_split[\"dataset_left\"] = dataset_left\n", 274 | " best_split[\"dataset_right\"] = dataset_right\n", 275 | " best_split[\"info_gain\"] = curr_info_gain\n", 276 | " max_info_gain = curr_info_gain\n", 277 | " \n", 278 | " # return best split\n", 279 | " return best_split\n", 280 | " \n", 281 | " def split(self, dataset, feature_index, threshold):\n", 282 | " ''' function to split the data '''\n", 283 | " \n", 284 | " dataset_left = np.array([row for row in dataset if row[feature_index]<=threshold])\n", 285 | " dataset_right = np.array([row for row in dataset if row[feature_index]>threshold])\n", 286 | " return dataset_left, dataset_right\n", 287 | " \n", 288 | " def information_gain(self, parent, l_child, r_child, mode=\"entropy\"):\n", 289 | " ''' function to compute information gain '''\n", 290 | " \n", 291 | " weight_l = len(l_child) / len(parent)\n", 292 | " weight_r = len(r_child) / len(parent)\n", 293 | " if mode==\"gini\":\n", 294 | " gain = self.gini_index(parent) - (weight_l*self.gini_index(l_child) + weight_r*self.gini_index(r_child))\n", 295 | " else:\n", 296 | " gain = self.entropy(parent) - (weight_l*self.entropy(l_child) + weight_r*self.entropy(r_child))\n", 297 | " return gain\n", 298 | " \n", 299 | " def entropy(self, y):\n", 300 | " ''' function to compute entropy '''\n", 301 | " \n", 302 | " class_labels = np.unique(y)\n", 303 | " entropy = 0\n", 304 | " for cls in class_labels:\n", 
305 | " p_cls = len(y[y == cls]) / len(y)\n", 306 | " entropy += -p_cls * np.log2(p_cls)\n", 307 | " return entropy\n", 308 | " \n", 309 | " def gini_index(self, y):\n", 310 | " ''' function to compute gini index '''\n", 311 | " \n", 312 | " class_labels = np.unique(y)\n", 313 | " gini = 0\n", 314 | " for cls in class_labels:\n", 315 | " p_cls = len(y[y == cls]) / len(y)\n", 316 | " gini += p_cls**2\n", 317 | " return 1 - gini\n", 318 | " \n", 319 | " def calculate_leaf_value(self, Y):\n", 320 | " ''' function to compute leaf node '''\n", 321 | " \n", 322 | " Y = list(Y)\n", 323 | " return max(Y, key=Y.count)\n", 324 | " \n", 325 | " def print_tree(self, tree=None, indent=\" \"):\n", 326 | " ''' function to print the tree '''\n", 327 | " \n", 328 | " if not tree:\n", 329 | " tree = self.root\n", 330 | "\n", 331 | " if tree.value is not None:\n", 332 | " print(tree.value)\n", 333 | "\n", 334 | " else:\n", 335 | " print(\"X_\"+str(tree.feature_index), \"<=\", tree.threshold, \"?\", tree.info_gain)\n", 336 | " print(\"%sleft:\" % (indent), end=\"\")\n", 337 | " self.print_tree(tree.left, indent + indent)\n", 338 | " print(\"%sright:\" % (indent), end=\"\")\n", 339 | " self.print_tree(tree.right, indent + indent)\n", 340 | " \n", 341 | " def fit(self, X, Y):\n", 342 | " ''' function to train the tree '''\n", 343 | " \n", 344 | " dataset = np.concatenate((X, Y), axis=1)\n", 345 | " self.root = self.build_tree(dataset)\n", 346 | " \n", 347 | " def predict(self, X):\n", 348 | " ''' function to predict new dataset '''\n", 349 | " \n", 350 | " preditions = [self.make_prediction(x, self.root) for x in X]\n", 351 | " return preditions\n", 352 | " \n", 353 | " def make_prediction(self, x, tree):\n", 354 | " ''' function to predict a single data point '''\n", 355 | " \n", 356 | " if tree.value!=None: return tree.value\n", 357 | " feature_val = x[tree.feature_index]\n", 358 | " if feature_val<=tree.threshold:\n", 359 | " return self.make_prediction(x, tree.left)\n", 360 | " 
else:\n", 361 | " return self.make_prediction(x, tree.right)" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "metadata": {}, 367 | "source": [ 368 | "## Train-Test split" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 5, 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "X = data.iloc[:, :-1].values\n", 378 | "Y = data.iloc[:, -1].values.reshape(-1,1)\n", 379 | "from sklearn.model_selection import train_test_split\n", 380 | "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2, random_state=41)" 381 | ] 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "metadata": {}, 386 | "source": [ 387 | "## Fit the model" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 6, 393 | "metadata": {}, 394 | "outputs": [ 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "X_2 <= 1.9 ? 0.33741385372714494\n", 400 | " left:0.0\n", 401 | " right:X_3 <= 1.5 ? 0.427106638180289\n", 402 | " left:X_2 <= 4.9 ? 0.05124653739612173\n", 403 | " left:1.0\n", 404 | " right:2.0\n", 405 | " right:X_2 <= 5.0 ? 0.019631171921475288\n", 406 | " left:X_1 <= 2.8 ? 
0.20833333333333334\n", 407 | " left:2.0\n", 408 | " right:1.0\n", 409 | " right:2.0\n" 410 | ] 411 | } 412 | ], 413 | "source": [ 414 | "classifier = DecisionTreeClassifier(min_samples_split=3, max_depth=3)\n", 415 | "classifier.fit(X_train,Y_train)\n", 416 | "classifier.print_tree()" 417 | ] 418 | }, 419 | { 420 | "cell_type": "markdown", 421 | "metadata": {}, 422 | "source": [ 423 | "## Test the model" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": 7, 429 | "metadata": {}, 430 | "outputs": [ 431 | { 432 | "data": { 433 | "text/plain": [ 434 | "0.9333333333333333" 435 | ] 436 | }, 437 | "execution_count": 7, 438 | "metadata": {}, 439 | "output_type": "execute_result" 440 | } 441 | ], 442 | "source": [ 443 | "Y_pred = classifier.predict(X_test) \n", 444 | "from sklearn.metrics import accuracy_score\n", 445 | "accuracy_score(Y_test, Y_pred)" 446 | ] 447 | } 448 | ], 449 | "metadata": { 450 | "kernelspec": { 451 | "display_name": "Python 3", 452 | "language": "python", 453 | "name": "python3" 454 | }, 455 | "language_info": { 456 | "codemirror_mode": { 457 | "name": "ipython", 458 | "version": 3 459 | }, 460 | "file_extension": ".py", 461 | "mimetype": "text/x-python", 462 | "name": "python", 463 | "nbconvert_exporter": "python", 464 | "pygments_lexer": "ipython3", 465 | "version": "3.8.5" 466 | } 467 | }, 468 | "nbformat": 4, 469 | "nbformat_minor": 4 470 | } 471 | -------------------------------------------------------------------------------- /decision tree regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Import tools" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import pandas as pd" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | 
"source": [ 24 | "## Get the data" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "data": { 34 | "text/html": [ 35 | "
\n", 36 | "\n", 49 | "\n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | "
x0x1x2x3x4y
08000.00.304871.30.002663126.201
110000.00.304871.30.002663125.201
212500.00.304871.30.002663125.951
316000.00.304871.30.002663127.591
420000.00.304871.30.002663127.461
\n", 109 | "
" 110 | ], 111 | "text/plain": [ 112 | " x0 x1 x2 x3 x4 y\n", 113 | "0 800 0.0 0.3048 71.3 0.002663 126.201\n", 114 | "1 1000 0.0 0.3048 71.3 0.002663 125.201\n", 115 | "2 1250 0.0 0.3048 71.3 0.002663 125.951\n", 116 | "3 1600 0.0 0.3048 71.3 0.002663 127.591\n", 117 | "4 2000 0.0 0.3048 71.3 0.002663 127.461" 118 | ] 119 | }, 120 | "execution_count": 2, 121 | "metadata": {}, 122 | "output_type": "execute_result" 123 | } 124 | ], 125 | "source": [ 126 | "data = pd.read_csv(\"airfoil_noise_data.csv\")\n", 127 | "data.head(5)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "## Node class" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 3, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "class Node():\n", 144 | " def __init__(self, feature_index=None, threshold=None, left=None, right=None, var_red=None, value=None):\n", 145 | " ''' constructor ''' \n", 146 | " \n", 147 | " # for decision node\n", 148 | " self.feature_index = feature_index\n", 149 | " self.threshold = threshold\n", 150 | " self.left = left\n", 151 | " self.right = right\n", 152 | " self.var_red = var_red\n", 153 | " \n", 154 | " # for leaf node\n", 155 | " self.value = value" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "## Tree class" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 4, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "class DecisionTreeRegressor():\n", 172 | " def __init__(self, min_samples_split=2, max_depth=2):\n", 173 | " ''' constructor '''\n", 174 | " \n", 175 | " # initialize the root of the tree \n", 176 | " self.root = None\n", 177 | " \n", 178 | " # stopping conditions\n", 179 | " self.min_samples_split = min_samples_split\n", 180 | " self.max_depth = max_depth\n", 181 | " \n", 182 | " def build_tree(self, dataset, curr_depth=0):\n", 183 | " ''' recursive function to build 
the tree '''\n", 184 | " \n", 185 | " X, Y = dataset[:,:-1], dataset[:,-1]\n", 186 | " num_samples, num_features = np.shape(X)\n", 187 | " best_split = {}\n", 188 | " # split until stopping conditions are met\n", 189 | " if num_samples>=self.min_samples_split and curr_depth<=self.max_depth:\n", 190 | " # find the best split\n", 191 | " best_split = self.get_best_split(dataset, num_samples, num_features)\n", 192 | " # check if variance reduction is positive\n", 193 | " if best_split[\"var_red\"]>0:\n", 194 | " # recur left\n", 195 | " left_subtree = self.build_tree(best_split[\"dataset_left\"], curr_depth+1)\n", 196 | " # recur right\n", 197 | " right_subtree = self.build_tree(best_split[\"dataset_right\"], curr_depth+1)\n", 198 | " # return decision node\n", 199 | " return Node(best_split[\"feature_index\"], best_split[\"threshold\"], \n", 200 | " left_subtree, right_subtree, best_split[\"var_red\"])\n", 201 | " \n", 202 | " # compute leaf node\n", 203 | " leaf_value = self.calculate_leaf_value(Y)\n", 204 | " # return leaf node\n", 205 | " return Node(value=leaf_value)\n", 206 | " \n", 207 | " def get_best_split(self, dataset, num_samples, num_features):\n", 208 | " ''' function to find the best split '''\n", 209 | " \n", 210 | " # dictionary to store the best split\n", 211 | " best_split = {}\n", 212 | " max_var_red = -float(\"inf\")\n", 213 | " # loop over all the features\n", 214 | " for feature_index in range(num_features):\n", 215 | " feature_values = dataset[:, feature_index]\n", 216 | " possible_thresholds = np.unique(feature_values)\n", 217 | " # loop over all the feature values present in the data\n", 218 | " for threshold in possible_thresholds:\n", 219 | " # get current split\n", 220 | " dataset_left, dataset_right = self.split(dataset, feature_index, threshold)\n", 221 | " # check if children are not empty\n", 222 | " if len(dataset_left)>0 and len(dataset_right)>0:\n", 223 | " y, left_y, right_y = dataset[:, -1], dataset_left[:, -1], dataset_right[:, 
-1]\n", 224 | " # compute variance reduction\n", 225 | " curr_var_red = self.variance_reduction(y, left_y, right_y)\n", 226 | " # update the best split if needed\n", 227 | " if curr_var_red>max_var_red:\n", 228 | " best_split[\"feature_index\"] = feature_index\n", 229 | " best_split[\"threshold\"] = threshold\n", 230 | " best_split[\"dataset_left\"] = dataset_left\n", 231 | " best_split[\"dataset_right\"] = dataset_right\n", 232 | " best_split[\"var_red\"] = curr_var_red\n", 233 | " max_var_red = curr_var_red\n", 234 | " \n", 235 | " # return best split\n", 236 | " return best_split\n", 237 | " \n", 238 | " def split(self, dataset, feature_index, threshold):\n", 239 | " ''' function to split the data '''\n", 240 | " \n", 241 | " dataset_left = np.array([row for row in dataset if row[feature_index]<=threshold])\n", 242 | " dataset_right = np.array([row for row in dataset if row[feature_index]>threshold])\n", 243 | " return dataset_left, dataset_right\n", 244 | " \n", 245 | " def variance_reduction(self, parent, l_child, r_child):\n", 246 | " ''' function to compute variance reduction '''\n", 247 | " \n", 248 | " weight_l = len(l_child) / len(parent)\n", 249 | " weight_r = len(r_child) / len(parent)\n", 250 | " reduction = np.var(parent) - (weight_l * np.var(l_child) + weight_r * np.var(r_child))\n", 251 | " return reduction\n", 252 | " \n", 253 | " def calculate_leaf_value(self, Y):\n", 254 | " ''' function to compute leaf node '''\n", 255 | " \n", 256 | " val = np.mean(Y)\n", 257 | " return val\n", 258 | " \n", 259 | " def print_tree(self, tree=None, indent=\" \"):\n", 260 | " ''' function to print the tree '''\n", 261 | " \n", 262 | " if not tree:\n", 263 | " tree = self.root\n", 264 | "\n", 265 | " if tree.value is not None:\n", 266 | " print(tree.value)\n", 267 | "\n", 268 | " else:\n", 269 | " print(\"X_\"+str(tree.feature_index), \"<=\", tree.threshold, \"?\", tree.var_red)\n", 270 | " print(\"%sleft:\" % (indent), end=\"\")\n", 271 | " 
self.print_tree(tree.left, indent + indent)\n", 272 | " print(\"%sright:\" % (indent), end=\"\")\n", 273 | " self.print_tree(tree.right, indent + indent)\n", 274 | " \n", 275 | " def fit(self, X, Y):\n", 276 | " ''' function to train the tree '''\n", 277 | " \n", 278 | " dataset = np.concatenate((X, Y), axis=1)\n", 279 | " self.root = self.build_tree(dataset)\n", 280 | " \n", 281 | " def make_prediction(self, x, tree):\n", 282 | " ''' function to predict a single data point '''\n", 283 | " \n", 284 | " if tree.value!=None: return tree.value\n", 285 | " feature_val = x[tree.feature_index]\n", 286 | " if feature_val<=tree.threshold:\n", 287 | " return self.make_prediction(x, tree.left)\n", 288 | " else:\n", 289 | " return self.make_prediction(x, tree.right)\n", 290 | " \n", 291 | " def predict(self, X):\n", 292 | " ''' function to predict new dataset '''\n", 293 | " \n", 294 | " preditions = [self.make_prediction(x, self.root) for x in X]\n", 295 | " return preditions" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "## Train-Test split" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 5, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "X = data.iloc[:, :-1].values\n", 312 | "Y = data.iloc[:, -1].values.reshape(-1,1)\n", 313 | "from sklearn.model_selection import train_test_split\n", 314 | "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2, random_state=41)" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "metadata": {}, 320 | "source": [ 321 | "## Fit the model" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 6, 327 | "metadata": {}, 328 | "outputs": [ 329 | { 330 | "name": "stdout", 331 | "output_type": "stream", 332 | "text": [ 333 | "X_0 <= 3150.0 ? 7.132048702017748\n", 334 | " left:X_4 <= 0.033779199999999995 ? 3.5903305690676675\n", 335 | " left:X_3 <= 55.5 ? 
1.1789899981318328\n", 336 | " left:X_4 <= 0.00251435 ? 1.614396721819876\n", 337 | " left:128.9919833333333\n", 338 | " right:125.90953579676673\n", 339 | " right:X_1 <= 15.4 ? 2.2342245360792994\n", 340 | " left:129.39160280373832\n", 341 | " right:123.80422222222222\n", 342 | " right:X_0 <= 1250.0 ? 9.970884020498875\n", 343 | " left:X_4 <= 0.0483159 ? 6.355275159824863\n", 344 | " left:124.38024528301887\n", 345 | " right:118.30039999999998\n", 346 | " right:X_3 <= 39.6 ? 5.036286657241022\n", 347 | " left:113.58091666666667\n", 348 | " right:118.07284615384614\n", 349 | " right:X_4 <= 0.00146332 ? 29.082992105065273\n", 350 | " left:X_0 <= 8000.0 ? 11.886497073996967\n", 351 | " left:X_2 <= 0.0508 ? 7.608945827689513\n", 352 | " left:134.04247500000002\n", 353 | " right:127.33581818181818\n", 354 | " right:X_4 <= 0.00076193 ? 10.622919322400815\n", 355 | " left:128.94078571428574\n", 356 | " right:122.4076875\n", 357 | " right:X_4 <= 0.022902799999999997 ? 5.638575922510647\n", 358 | " left:X_0 <= 6300.0 ? 5.985051045988911\n", 359 | " left:120.04740816326529\n", 360 | " right:114.67370491803278\n", 361 | " right:X_4 <= 0.0368233 ? 
8.63874479304644\n", 362 | " left:113.83169565217393\n", 363 | " right:107.6395833333333\n" 364 | ] 365 | } 366 | ], 367 | "source": [ 368 | "regressor = DecisionTreeRegressor(min_samples_split=3, max_depth=3)\n", 369 | "regressor.fit(X_train,Y_train)\n", 370 | "regressor.print_tree()" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "metadata": {}, 376 | "source": [ 377 | "## Test the model" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 7, 383 | "metadata": {}, 384 | "outputs": [ 385 | { 386 | "data": { 387 | "text/plain": [ 388 | "4.851358097184457" 389 | ] 390 | }, 391 | "execution_count": 7, 392 | "metadata": {}, 393 | "output_type": "execute_result" 394 | } 395 | ], 396 | "source": [ 397 | "Y_pred = regressor.predict(X_test) \n", 398 | "from sklearn.metrics import mean_squared_error\n", 399 | "np.sqrt(mean_squared_error(Y_test, Y_pred))" 400 | ] 401 | } 402 | ], 403 | "metadata": { 404 | "kernelspec": { 405 | "display_name": "Python 3", 406 | "language": "python", 407 | "name": "python3" 408 | }, 409 | "language_info": { 410 | "codemirror_mode": { 411 | "name": "ipython", 412 | "version": 3 413 | }, 414 | "file_extension": ".py", 415 | "mimetype": "text/x-python", 416 | "name": "python", 417 | "nbconvert_exporter": "python", 418 | "pygments_lexer": "ipython3", 419 | "version": "3.8.5" 420 | } 421 | }, 422 | "nbformat": 4, 423 | "nbformat_minor": 4 424 | } 425 | -------------------------------------------------------------------------------- /linear regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Generating Fake Data" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 16, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from sklearn.datasets.samples_generator import make_regression\n", 17 | "X, y = make_regression(n_samples=200, 
n_features=1, n_informative=1, noise=6, bias=30, random_state=200)\n", 18 | "m = 200" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Visualizing the Data" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 25, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "data": { 35 | "image/png": "\n", 36 | "text/plain": [ 37 | "
" 38 | ] 39 | }, 40 | "metadata": {}, 41 | "output_type": "display_data" 42 | } 43 | ], 44 | "source": [ 45 | "from matplotlib import pyplot as plt\n", 46 | "plt.scatter(X,y, c = \"red\",alpha=.5, marker = 'o')\n", 47 | "plt.xlabel(\"X\")\n", 48 | "plt.ylabel(\"Y\")\n", 49 | "plt.show()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## Linear Model" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 18, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "import numpy as np\n", 66 | "def h(X,w):\n", 67 | " return (w[1]*np.array(X[:,0])+w[0])" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## Cost Function" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 19, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "def cost(w,X,y):\n", 84 | " return (.5/m) * np.sum(np.square(h(X,w)-np.array(y)))" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "## Gradient Descent " 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 20, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "def grad(w,X,y):\n", 101 | " g = [0]*2\n", 102 | " g[0] = (1/m) * np.sum(h(X,w)-np.array(y))\n", 103 | " g[1] = (1/m) * np.sum((h(X,w)-np.array(y))*np.array(X[:,0]))\n", 104 | " return g\n" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 21, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "def descent(w_new, w_prev, lr):\n", 114 | " print(w_prev)\n", 115 | " print(cost(w_prev,X,y))\n", 116 | " j=0\n", 117 | " while True:\n", 118 | " w_prev = w_new\n", 119 | " w0 = w_prev[0] - lr*grad(w_prev,X,y)[0]\n", 120 | " w1 = w_prev[1] - lr*grad(w_prev,X,y)[1]\n", 121 | " w_new = [w0, w1]\n", 122 | " print(w_new)\n", 123 | " print(cost(w_new,X,y))\n", 124 | " if (w_new[0]-w_prev[0])**2 + (w_new[1]-w_prev[1])**2 <= 
pow(10,-6):\n", 125 | " return w_new\n", 126 | " if j>500: \n", 127 | " return w_new\n", 128 | " j+=1 " 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "## Initializing Parameters" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 22, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "w = [0,-1]" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "## Training the Model" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 23, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stdout", 161 | "output_type": "stream", 162 | "text": [ 163 | "[0, -1]\n", 164 | "540.5360663843456\n", 165 | "[3.0956308633447547, 0.11442770988081663]\n", 166 | "437.91139336428444\n", 167 | "[5.873446610978822, 1.1023454281382854]\n", 168 | "355.5039050187037\n", 169 | "[8.366165526017987, 1.9778657783247602]\n", 170 | "289.3267499184995\n", 171 | "[10.603129563187093, 2.753547324958939]\n", 172 | "236.1799750745718\n", 173 | "[12.610653489037027, 3.440564026385428]\n", 174 | "193.49509649539323\n", 175 | "[14.412337853388406, 4.048856351454087]\n", 176 | "159.2103901995911\n", 177 | "[16.0293495446536, 4.587266032213945]\n", 178 | "131.6708284668908\n", 179 | "[17.480673291820082, 5.063656213710697]\n", 180 | "109.54778810165583\n", 181 | "[18.7833371265594, 5.485018573380515]\n", 182 | "91.77462156224563\n", 183 | "[19.952614505935692, 5.857568814053481]\n", 184 | "77.49495508304668\n", 185 | "[21.002205515744066, 6.186831784078626]\n", 186 | "66.02119816099949\n", 187 | "[21.944399323224108, 6.4777173436470505]\n", 188 | "56.801246289923824\n", 189 | "[22.79021982273288, 6.734587976310905]\n", 190 | "49.39175789964725\n", 191 | "[23.549556216205993, 6.961319037445921]\n", 192 | "43.436706577550574\n", 193 | "[24.23128008944935, 7.1613524356181975]\n", 194 | "38.6501664442448\n", 195 | 
"[24.843350383306017, 7.337744457271138]\n", 196 | "34.802494555336104\n", 197 | "[25.39290751357782, 7.493208368754656]\n", 198 | "31.709239459080347\n", 199 | "[25.88635776349851, 7.630152361500887]\n", 200 | "29.22223761683571\n", 201 | "[26.329448955986066, 7.750713345236864]\n", 202 | "27.22246575546644\n", 203 | "[26.727338308440384, 7.85678703973609]\n", 204 | "25.614302554322332\n", 205 | "[27.084653279240058, 7.950054767051352]\n", 206 | "24.320921533807876\n", 207 | "[27.405546131200424, 8.032007302818467]\n", 208 | "23.280591944370723\n", 209 | "[27.69374286207346, 8.10396610651881]\n", 210 | "22.443708530267372\n", 211 | "[27.952587084793493, 8.167102216040767]\n", 212 | "21.77040640846802\n", 213 | "[28.1850793797893, 8.222453061042739]\n", 214 | "21.22864568158778\n", 215 | "[28.393912587566714, 8.270937422096216]\n", 216 | "20.792673176173675\n", 217 | "[28.581503461264624, 8.313368738022628]\n", 218 | "20.44178697219165\n", 219 | "[28.75002105541824, 8.350466941915155]\n", 220 | "20.159344055135406\n", 221 | "[28.901412188203544, 8.382868986773879]\n", 222 | "19.93196319200716\n", 223 | "[29.03742427951789, 8.411138204226754]\n", 224 | "19.74888457866804\n", 225 | "[29.159625835953786, 8.435772624233959]\n", 226 | "19.601455387770752\n", 227 | "[29.2694248256702, 8.45721236977788]\n", 228 | "19.482716431969585\n", 229 | "[29.368085161021078, 8.475846228144906]\n", 230 | "19.387070041859424\n", 231 | "[29.45674148426251, 8.492017489347575]\n", 232 | "19.310013179228125\n", 233 | "[29.536412431457247, 8.506029132372674]\n", 234 | "19.24792295396926\n", 235 | "[29.60801253158606, 8.518148431144402]\n", 236 | "19.19788424005612\n", 237 | "[29.672362881642048, 8.528611044246796]\n", 238 | "19.157551114830053\n", 239 | "[29.73020072393229, 8.537624645454327]\n", 240 | "19.125035474814123\n", 241 | "[29.782188038766204, 8.545372145882025]\n", 242 | "19.09881748922257\n", 243 | "[29.828919254015993, 8.552014553005522]\n", 244 | "19.077673602614112\n", 245 | 
"[29.87092816255075, 8.557693506843796]\n", 246 | "19.06061864154971\n", 247 | "[29.90869412914744, 8.562533529178292]\n", 248 | "19.046859257450855\n", 249 | "[29.942647660055936, 8.566644017743355]\n", 250 | "19.035756481849372\n", 251 | "[29.97317540084111, 8.57012101381262]\n", 252 | "19.026795607153897\n", 253 | "[30.000624621352255, 8.573048768478005]\n", 254 | "19.019561957024997\n", 255 | "[30.025307240597662, 8.57550113013079]\n", 256 | "19.013721392385282\n", 257 | "[30.047503438857856, 8.577542773171217]\n", 258 | "19.009004625587373\n", 259 | "[30.067464899489234, 8.579230285760998]\n", 260 | "19.005194597236063\n", 261 | "[30.085417718492856, 8.580613132462974]\n", 262 | "19.00211631637512\n", 263 | "[30.101565015998244, 8.581734505857439]\n", 264 | "18.99962868224005\n", 265 | "[30.11608928029284, 8.582632079662158]\n", 266 | "18.99761790019706\n", 267 | "[30.12915447187187, 8.583338674491873]\n", 268 | "18.995992180371438\n", 269 | "[30.14090791215345, 8.583882846154577]\n", 270 | "18.994677468461337\n", 271 | "[30.151481978966018, 8.584289405279359]\n", 272 | "18.993614007260764\n", 273 | "[30.160995628639302, 8.58457987608947]\n", 274 | "18.99275356682987\n", 275 | "[30.169555762489143, 8.584772901261081]\n", 276 | "18.992057212939287\n", 277 | "[30.177258453655917, 8.584884599031389]\n", 278 | "18.991493508895633\n", 279 | "[30.184190048614877, 8.584928878028586]\n", 280 | "18.991037066345022\n", 281 | "[30.190428156204195, 8.584917714681556]\n", 282 | "18.99066737713041\n", 283 | "[30.196042535696083, 8.584861397520429]\n", 284 | "18.99036787153328\n", 285 | "[30.201095894251807, 8.584768742193141]\n", 286 | "18.990125158892102\n", 287 | "[30.205644603039065, 8.58464728059093]\n", 288 | "18.989928415168315\n", 289 | "[30.20973934033721, 8.584503427091704]\n", 290 | "18.98976888893204\n", 291 | "[30.21342566910097, 8.58434262458882]\n", 292 | "18.98963950279431\n", 293 | "[30.216744555686475, 8.584169472669567]\n", 294 | "18.989534531782173\n", 
295 | "[30.219732835755565, 8.58398784003826]\n", 296 | "18.9894493437512\n", 297 | "[30.22242363275721, 8.583800963039536]\n", 298 | "18.989380189826502\n", 299 | "[30.22484673383125, 8.583611531925033]\n", 300 | "18.989324035195228\n", 301 | "[30.227028927482966, 8.583421766317985]\n", 302 | "18.989278422451623\n", 303 | "[30.228994306931416, 8.583233481162887]\n", 304 | "18.98924136120799\n", 305 | "[30.23076454263459, 8.583048144298838]\n", 306 | "18.98921123890338\n", 307 | "[30.23235912713573, 8.582866926663364]\n", 308 | "18.98918674872298\n", 309 | "[30.233795595053294, 8.582690746016727]\n", 310 | "18.98916683133221\n", 311 | "[30.235089720748128, 8.582520304973011]\n", 312 | "18.989150627766747\n", 313 | "[30.23625569594233, 8.582356124032495]\n", 314 | "18.98913744133326\n", 315 | "[30.237306289331627, 8.582198570228412]\n", 316 | "18.989126706789886\n", 317 | "[30.238252990024396, 8.582047881929046]\n", 318 | "18.98911796540923\n", 319 | "[30.238252990024396, 8.582047881929046]\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "w = descent(w,w,.1)\n", 325 | "print(w)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": {}, 331 | "source": [ 332 | "## Visualizing the Result" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 24, 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "data": { 342 | "image/png": "\n", 343 | "text/plain": [ 344 | "
" 345 | ] 346 | }, 347 | "metadata": {}, 348 | "output_type": "display_data" 349 | } 350 | ], 351 | "source": [ 352 | "def graph(formula, x_range): \n", 353 | " x = np.array(x_range) \n", 354 | " y = formula(x) \n", 355 | " plt.plot(x, y) \n", 356 | " \n", 357 | "def my_formula(x):\n", 358 | " return w[0]+w[1]*x\n", 359 | "\n", 360 | "plt.scatter(X,y, c = \"red\",alpha=.5, marker = 'o')\n", 361 | "graph(my_formula, range(-5,5))\n", 362 | "plt.xlabel('X')\n", 363 | "plt.ylabel('Y')\n", 364 | "plt.show()" 365 | ] 366 | } 367 | ], 368 | "metadata": { 369 | "kernelspec": { 370 | "display_name": "Python 3", 371 | "language": "python", 372 | "name": "python3" 373 | }, 374 | "language_info": { 375 | "codemirror_mode": { 376 | "name": "ipython", 377 | "version": 3 378 | }, 379 | "file_extension": ".py", 380 | "mimetype": "text/x-python", 381 | "name": "python", 382 | "nbconvert_exporter": "python", 383 | "pygments_lexer": "ipython3", 384 | "version": "3.6.5" 385 | } 386 | }, 387 | "nbformat": 4, 388 | "nbformat_minor": 2 389 | } 390 | -------------------------------------------------------------------------------- /logistic regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Creating Fake Data" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 26, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from sklearn.datasets.samples_generator import make_blobs\n", 17 | "X, Y = make_blobs(n_samples=200, centers=2, n_features=2, cluster_std=5, random_state=11)\n", 18 | "m = 200" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Visualizing the Data" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 28, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "data": { 35 | "image/png": "\n", 36 | "text/plain": [ 37 | "
" 38 | ] 39 | }, 40 | "metadata": {}, 41 | "output_type": "display_data" 42 | } 43 | ], 44 | "source": [ 45 | "from matplotlib import pyplot as plt\n", 46 | "from pandas import DataFrame \n", 47 | "df = DataFrame(dict(x=X[:,0], y=X[:,1], label=Y))\n", 48 | "colors = {0:'blue', 1:'orange'}\n", 49 | "fig, ax = plt.subplots()\n", 50 | "grouped = df.groupby('label')\n", 51 | "for key, group in grouped:\n", 52 | " group.plot(ax=ax, kind='scatter', x='x', y='y', label=key, color=colors[key])\n", 53 | "plt.xlabel('X_0')\n", 54 | "plt.ylabel('X_1')\n", 55 | "plt.show()" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "## Logistic Model" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 6, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "import numpy as np\n", 72 | "def sigmoid(z):\n", 73 | " return 1 / (1 + np.exp(-z))" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 7, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def hx(w,X):\n", 83 | " z = np.array(w[0] + w[1]*np.array(X[:,0]) + w[2]*np.array(X[:,1]))\n", 84 | " return sigmoid(z)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "## Cost Function - Binary Cross Entropy " 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 8, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "def cost(w, X, Y):\n", 101 | " y_pred = hx(w,X)\n", 102 | " return -1 * sum(Y*np.log(y_pred) + (1-Y)*np.log(1-y_pred))" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "## Gradient Descent" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 9, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "def grad(w, X, Y):\n", 119 | " y_pred = hx(w,X)\n", 120 | " g = [0]*3\n", 121 | " g[0] = -1 * sum(Y*(1-y_pred) - (1-Y)*y_pred)\n", 122 | " g[1] = -1 * 
sum(Y*(1-y_pred)*X[:,0] - (1-Y)*y_pred*X[:,0])\n", 123 | " g[2] = -1 * sum(Y*(1-y_pred)*X[:,1] - (1-Y)*y_pred*X[:,1])\n", 124 | " return g" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 37, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "def descent(w_new, w_prev, lr):\n", 134 | " print(w_prev)\n", 135 | " print(cost(w_prev, X, Y))\n", 136 | " j=0\n", 137 | " while True:\n", 138 | " w_prev = w_new\n", 139 | " w0 = w_prev[0] - lr*grad(w_prev, X, Y)[0]\n", 140 | " w1 = w_prev[1] - lr*grad(w_prev, X, Y)[1]\n", 141 | " w2 = w_prev[2] - lr*grad(w_prev, X, Y)[2]\n", 142 | " w_new = [w0, w1, w2]\n", 143 | " print(w_new)\n", 144 | " print(cost(w_new, X, Y))\n", 145 | " if (w_new[0]-w_prev[0])**2 + (w_new[1]-w_prev[1])**2 + (w_new[2]-w_prev[2])**2 100: \n", 149 | " return w_new\n", 150 | " j+=1" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "## Initializing Parameters" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 38, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "w=[1,1,1] " 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "## Training the Model" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 39, 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "name": "stdout", 183 | "output_type": "stream", 184 | "text": [ 185 | "[1, 1, 1]\n", 186 | "126.96627984087802\n", 187 | "[1.2539422898588644, -0.5512710240837779, 1.1843328109648115]\n", 188 | "112.81815690482537\n", 189 | "[1.2400289717256012, 1.3365539538600526, 1.3041857280167006]\n", 190 | "168.5749357223354\n", 191 | "[1.496038556709955, -0.26091766794801896, 1.498589948744304]\n", 192 | "75.77036936925771\n", 193 | "[1.552998152697941, 0.5805338886016641, 1.2059963033700993]\n", 194 | "67.88063552779857\n", 195 | "[1.7129970823802407, -0.12884219161477717, 0.9205552719943874]\n", 
196 | "49.631255078140065\n", 197 | "[1.721522374547689, 0.8081464107255731, 0.784948881261807]\n", 198 | "84.65164447415044\n", 199 | "[1.9318934810214365, -0.5853214217947893, 0.9260582355386137]\n", 200 | "137.02236299684375\n", 201 | "[1.792209433897121, 2.259519445068961, 1.8717281361304319]\n", 202 | "296.375960910099\n", 203 | "[2.0717704014100984, 0.4671501798576154, 2.2375086018139254]\n", 204 | "96.21982855549399\n", 205 | "[2.2275891867070294, 0.17556932949773574, 1.7757719140648793]\n", 206 | "66.93614685472798\n", 207 | "[2.34635073219192, 0.10042351859759649, 1.3377088708014218]\n", 208 | "47.88369966457564\n", 209 | "[2.420814222298708, 0.2006649466226462, 0.9803329461122527]\n", 210 | "35.74059972551982\n", 211 | "[2.47109414690982, 0.15761477718705647, 0.7206148121331153]\n", 212 | "31.156227520259954\n", 213 | "[2.4669494157236738, 0.33368915465087157, 0.6530382051207929]\n", 214 | "31.825794017269445\n", 215 | "[2.5018185122166057, 0.052068234061564855, 0.5947139224875355]\n", 216 | "37.13648828690444\n", 217 | "[2.4168290256808196, 0.8618663362335011, 0.8721726092434906]\n", 218 | "77.31322559093016\n", 219 | "[2.596954895024072, -0.39576768312727784, 0.9305031886901464]\n", 220 | "97.53803980330012\n", 221 | "[2.4680716616586964, 1.8009384856190926, 1.4733863829082396]\n", 222 | "210.1650208652238\n", 223 | "[2.7256313646387795, 0.10398042243284711, 1.7582377785757068]\n", 224 | "60.930800473556\n", 225 | "[2.8139079434551295, 0.18042029298994794, 1.356581836503274]\n", 226 | "44.87910324512945\n", 227 | "[2.8764378665294585, 0.18788023515027294, 1.0240956114450044]\n", 228 | "35.1115767203633\n", 229 | "[2.9004091563475756, 0.25373283853477424, 0.8052448292395792]\n", 230 | "31.203934623682823\n", 231 | "[2.905564316215357, 0.23573784244902107, 0.6972405539065145]\n", 232 | "30.64594404630533\n", 233 | "[2.884978695284491, 0.3156866373830539, 0.7013654663287641]\n", 234 | "30.87299094573297\n", 235 | "[2.8899757309235032, 0.18571374355707734, 
0.662108112619276]\n", 236 | "31.58765075793203\n", 237 | "[2.8456086105455767, 0.44695414312741255, 0.7472822400464028]\n", 238 | "34.997339236852504\n", 239 | "[2.8946903864992204, 0.0311385451581056, 0.652724208526957]\n", 240 | "41.16553917784135\n", 241 | "[2.78333314375564, 1.0205125031757631, 0.9858502424130702]\n", 242 | "90.41042650497221\n", 243 | "[2.970171901127539, -0.2925927602835001, 1.054289078329687]\n", 244 | "77.28627608547319\n", 245 | "[2.8765338440725268, 1.3870481637289969, 1.2593201244024514]\n", 246 | "137.884699579421\n", 247 | "[3.096205693561861, -0.09815951484676533, 1.3875700666825428]\n", 248 | "55.28433950652438\n", 249 | "[3.091222620686947, 0.691357878366354, 1.1851902704372872]\n", 250 | "54.56897648020144\n", 251 | "[3.197031648206172, 0.04853253754001363, 0.9522363399023503]\n", 252 | "38.879877723561194\n", 253 | "[3.152849114551307, 0.6705086356955698, 0.9042855300104767]\n", 254 | "47.42326651334623\n", 255 | "[3.249401718635574, -0.05586144519042935, 0.7968950006624693]\n", 256 | "50.77734982683568\n", 257 | "[3.1265348460183477, 1.2297186855362119, 1.1081714720550426]\n", 258 | "111.36985561835434\n", 259 | "[3.326563756861511, -0.1791741101134301, 1.2218462740483753]\n", 260 | "62.108694704750796\n", 261 | "[3.269128921912794, 1.0293263482837012, 1.2008084141382773]\n", 262 | "82.99082598516479\n", 263 | "[3.433600163161756, -0.10416961202267583, 1.1426909126167342]\n", 264 | "53.969099415361306\n", 265 | "[3.372434392083106, 0.9606133194736919, 1.1258764505552485]\n", 266 | "73.24867011604414\n", 267 | "[3.5227671344926526, -0.10541269443776047, 1.0639829014706683]\n", 268 | "54.95877911862518\n", 269 | "[3.4402314176940667, 1.0678781146572354, 1.1359656744505753]\n", 270 | "84.39209613357035\n", 271 | "[3.6053162123525917, -0.1094419475498285, 1.1193961332548408]\n", 272 | "55.78053131843188\n", 273 | "[3.5284998867311446, 1.035838685329749, 1.1603344516799623]\n", 274 | "79.31512100875807\n", 275 | "[3.6850794814765253, 
-0.0823331021711422, 1.117327385696941]\n", 276 | "53.365825253435844\n", 277 | "[3.6094245208717335, 0.990136688364013, 1.1479057502919368]\n", 278 | "73.00073094363131\n", 279 | "[3.7550639240710013, -0.06190456364194685, 1.0874826200950347]\n", 280 | "51.85370536629623\n", 281 | "[3.6740149106228186, 0.9863883933259856, 1.1383256238195103]\n", 282 | "71.6077027550773\n", 283 | "[3.816365454730692, -0.053415680920452235, 1.079972647866281]\n", 284 | "51.512482717018116\n", 285 | "[3.7321415001226725, 0.9874518828741854, 1.1419582058166857]\n", 286 | "70.91147940312463\n", 287 | "[3.8721953817202945, -0.040975086743274325, 1.0817922008147334]\n", 288 | "50.68840558905084\n", 289 | "[3.7880511473342113, 0.9676622239349191, 1.1410140920986547]\n", 290 | "68.11602429098141\n", 291 | "[3.9218160093807857, -0.02032252700940551, 1.069864113475001]\n", 292 | "49.12965800771219\n", 293 | "[3.8372555826614154, 0.9397003151318558, 1.1301065005574003]\n", 294 | "64.66526827228455\n", 295 | "[3.963079147188273, 0.0015189750142303726, 1.048002481200122]\n", 296 | "47.55953959619178\n", 297 | "[3.8768575558853064, 0.9161026456599739, 1.116464165382025]\n", 298 | "61.870753988367085\n", 299 | "[3.995752432150913, 0.019597576030487285, 1.027414752411626]\n", 300 | "46.36190019402692\n", 301 | "[3.907745460549524, 0.89829389197684, 1.1055063163003485]\n", 302 | "59.82357707095895\n", 303 | "[4.0211666051250194, 0.03440304270300665, 1.0117153695411087]\n", 304 | "45.42762114954028\n", 305 | "[3.932061436494771, 0.8820448670623741, 1.0967233693735197]\n", 306 | "58.07265242127726\n", 307 | "[4.040538211472588, 0.0486222507940266, 0.9980600115598323]\n", 308 | "44.528626105329884\n", 309 | "[3.95104075816509, 0.8635822178845587, 1.087464778932447]\n", 310 | "56.25191093140636\n", 311 | "[4.0542064633885655, 0.06390897971924736, 0.9830845369826282]\n", 312 | "43.560297250359895\n", 313 | "[3.9647146116297285, 0.8419896969170714, 1.0762339319145042]\n", 314 | "54.28859760963312\n", 315 
| "[4.061974777022511, 0.08010117719043697, 0.9658476947904583]\n", 316 | "42.55609321933739\n", 317 | "[3.9726570293297416, 0.8185528342778258, 1.0632653646122094]\n", 318 | "52.30757518813522\n", 319 | "[4.063741047353516, 0.09591304118161548, 0.9474917553090927]\n", 320 | "41.60869408164448\n", 321 | "[3.9746698824165576, 0.7955306511635729, 1.0497944865210715]\n", 322 | "50.49160062072334\n", 323 | "[4.0598875620346835, 0.10973094543337403, 0.929938489713735]\n", 324 | "40.807102034142616\n", 325 | "[3.9710681877912237, 0.7754202550430537, 1.0373138237663093]\n", 326 | "49.006821287773285\n", 327 | "[4.051367062032833, 0.120083598318384, 0.9149680408447317]\n", 328 | "40.21376802733146\n", 329 | "[3.9626934207233138, 0.7605664236412619, 1.027175552319945]\n", 330 | "47.97972673492794\n", 331 | "[4.03960146305846, 0.1258626202000106, 0.9038701158750538]\n", 332 | "39.8665848590012\n", 333 | "[3.9508175989838605, 0.7529117846030396, 1.0204483002710898]\n", 334 | "47.49986495259877\n", 335 | "[4.026319016314264, 0.12633822832598507, 0.8975122703008471]\n", 336 | "39.79309105387893\n", 337 | "[3.937022647315148, 0.7539702241394387, 1.0179583095352271]\n", 338 | "47.64232237902202\n", 339 | "[4.013431600037268, 0.12097838008325779, 0.8966869188055052]\n", 340 | "40.03060914399324\n", 341 | "[3.9231193858382776, 0.7650499798316024, 1.0204780871163988]\n", 342 | "48.50370516914084\n", 343 | "[4.0030116224672625, 0.10920033650354244, 0.9025652832558652]\n", 344 | "40.647028942768124\n", 345 | "[3.911168807082645, 0.7874617085157639, 1.0289496414232508]\n", 346 | "50.233655133405684\n", 347 | "[3.9973475854531766, 0.09037455328475408, 0.9170124097333662]\n", 348 | "41.74920834355003\n", 349 | "[3.9036347048532614, 0.8219989928942072, 1.0444815131818708]\n", 350 | "53.00756665023323\n", 351 | "[3.998896442510712, 0.06491278493707686, 0.9421306486550264]\n", 352 | "43.414481126056494\n", 353 | "[3.903578932273919, 0.8658652317559935, 1.0673333907773224]\n", 354 | 
"56.743118405663786\n", 355 | "[4.009605120891904, 0.03790132000571389, 0.9771595777097101]\n", 356 | "45.367245465512056\n", 357 | "[3.9142374607656074, 0.9053571447819756, 1.0933903406414625]\n", 358 | "60.26445149748922\n", 359 | "[4.028982704706953, 0.022311918251399665, 1.0115367561804436]\n", 360 | "46.56452396428107\n", 361 | "[3.936118279261484, 0.9161216408470005, 1.1107981056735687]\n", 362 | "61.119030592288695\n", 363 | "[4.05221859760455, 0.029367074818002936, 1.0249045662087954]\n", 364 | "46.03001820966252\n", 365 | "[3.962804479787527, 0.890313219168536, 1.1082732172211933]\n", 366 | "58.54793332826261\n", 367 | "[4.071694739617031, 0.05576987084898566, 1.0076561457699287]\n", 368 | "44.146310350361176\n", 369 | "[3.9843650919995444, 0.8441134205516451, 1.087132296857637]\n", 370 | "54.38313567139385\n", 371 | "[4.081221437775141, 0.08935078570508614, 0.9713207459960232]\n", 372 | "42.00179619891018\n", 373 | "[3.9945804561228613, 0.7946114356824799, 1.058305857789166]\n", 374 | "50.363531876210516\n", 375 | "[4.078719695444127, 0.11976411937445264, 0.9330027012706759]\n", 376 | "40.284159514532284\n", 377 | "[3.992398365094767, 0.7520453654158086, 1.0315755035230465]\n", 378 | "47.27452627208367\n", 379 | "[4.065870596044042, 0.14186053507189011, 0.9025860178561832]\n", 380 | "39.138492702670455\n", 381 | "[3.9799787694173245, 0.7210735715432988, 1.0111931600659616]\n", 382 | "45.255334753917104\n", 383 | "[4.0460976445817955, 0.15448991706397763, 0.8819752531296552]\n", 384 | "38.48142071458852\n", 385 | "[3.960514974560398, 0.7030948526309135, 0.9977984989102618]\n", 386 | "44.1983732970012\n", 387 | "[4.022890960756249, 0.1583065776331426, 0.8696992149881329]\n", 388 | "38.21770810120381\n", 389 | "[3.937085053426291, 0.6982103483166531, 0.9909832248890901]\n", 390 | "43.98944447469072\n", 391 | "[3.937085053426291, 0.6982103483166531, 0.9909832248890901]\n" 392 | ] 393 | } 394 | ], 395 | "source": [ 396 | "w = descent(w,w,.0099)\n", 397 | 
"print(w)" 398 | ] 399 | }, 400 | { 401 | "cell_type": "markdown", 402 | "metadata": {}, 403 | "source": [ 404 | "## Visualizing the Result" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 40, 410 | "metadata": {}, 411 | "outputs": [ 412 | { 413 | "data": { 414 | "image/png": "\n", 415 | "text/plain": [ 416 | "
" 417 | ] 418 | }, 419 | "metadata": {}, 420 | "output_type": "display_data" 421 | } 422 | ], 423 | "source": [ 424 | "def graph(formula, x_range): \n", 425 | " x = np.array(x_range) \n", 426 | " y = formula(x) \n", 427 | " plt.plot(x, y) \n", 428 | " \n", 429 | "def my_formula(x):\n", 430 | " return (-w[0]-w[1]*x)/w[2]\n", 431 | "\n", 432 | "from matplotlib import pyplot as plt\n", 433 | "from pandas import DataFrame \n", 434 | "df = DataFrame(dict(x=X[:,0], y=X[:,1], label=Y))\n", 435 | "colors = {0:'blue', 1:'orange'}\n", 436 | "fig, ax = plt.subplots()\n", 437 | "grouped = df.groupby('label')\n", 438 | "for key, group in grouped:\n", 439 | " group.plot(ax=ax, kind='scatter', x='x', y='y', label=key, color=colors[key])\n", 440 | "graph(my_formula, range(-20,15))\n", 441 | "plt.xlabel('X_0')\n", 442 | "plt.ylabel('X_1')\n", 443 | "plt.show()" 444 | ] 445 | } 446 | ], 447 | "metadata": { 448 | "kernelspec": { 449 | "display_name": "Python 3", 450 | "language": "python", 451 | "name": "python3" 452 | }, 453 | "language_info": { 454 | "codemirror_mode": { 455 | "name": "ipython", 456 | "version": 3 457 | }, 458 | "file_extension": ".py", 459 | "mimetype": "text/x-python", 460 | "name": "python", 461 | "nbconvert_exporter": "python", 462 | "pygments_lexer": "ipython3", 463 | "version": "3.6.5" 464 | } 465 | }, 466 | "nbformat": 4, 467 | "nbformat_minor": 2 468 | } 469 | -------------------------------------------------------------------------------- /sgd.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Creating Fake Data" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "from sklearn.utils import shuffle\n", 18 | "from sklearn.datasets.samples_generator import make_blobs\n", 19 | "X, Y = make_blobs(n_samples=300, 
centers=2, n_features=2, cluster_std=5, random_state=11)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "## Visualizing the Data" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "data": { 36 | "image/png": "\n", 37 | "text/plain": [ 38 | "
" 39 | ] 40 | }, 41 | "metadata": {}, 42 | "output_type": "display_data" 43 | } 44 | ], 45 | "source": [ 46 | "from matplotlib import pyplot as plt\n", 47 | "from pandas import DataFrame \n", 48 | "df = DataFrame(dict(x=X[:,0], y=X[:,1], label=Y))\n", 49 | "colors = {0:'blue', 1:'orange'}\n", 50 | "fig, ax = plt.subplots()\n", 51 | "grouped = df.groupby('label')\n", 52 | "for key, group in grouped:\n", 53 | " group.plot(ax=ax, kind='scatter', x='x', y='y', label=key, color=colors[key])\n", 54 | "plt.xlabel('X_1')\n", 55 | "plt.ylabel('X_2')\n", 56 | "plt.show()" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Splitting into batches" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "def next_batch(X, Y, batch_size):\n", 73 | " for i in np.arange(0, X.shape[0], batch_size):\n", 74 | " yield (X[i:i + batch_size], Y[i:i + batch_size])" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## Adding column of 1's " 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "(300, 3)" 93 | ] 94 | }, 95 | "execution_count": 5, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "X = np.c_[np.ones((X.shape[0])), X]\n", 102 | "X.shape" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "## Logistic Model" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 6, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "import numpy as np\n", 119 | "def sigmoid(z):\n", 120 | " return 1 / (1 + np.exp(-z))" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 7, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "def hx(W,X):\n", 
130 | " return sigmoid(np.dot(X,W))" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "## Cost Function - Binary Cross Entropy " 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 8, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "def cost(W, X, Y):\n", 147 | " y_pred = hx(W,X)\n", 148 | " return -1 * sum(Y*np.log(y_pred) + (1-Y)*np.log(1-y_pred))" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": {}, 154 | "source": [ 155 | "## Stochastic Gradient Descent" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 9, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "def grad(W, X, Y):\n", 165 | " y_pred = hx(W,X)\n", 166 | " A = (Y*(1-y_pred) - (1-Y)*y_pred)\n", 167 | " g = -1* np.dot(A.T,X)\n", 168 | " return g" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 10, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "def sgd(W_new, W_prev, lr, batch_size, epochs):\n", 178 | " X_, Y_ = shuffle(X, Y, random_state=0)\n", 179 | " for e in range(epochs):\n", 180 | " epoch_loss = []\n", 181 | " X_, Y_ = shuffle(X_, Y_, random_state=0)\n", 182 | " for (batchX, batchY) in next_batch(X_, Y_, batch_size):\n", 183 | " W_prev = W_new\n", 184 | " epoch_loss.append(cost(W_prev, batchX, batchY))\n", 185 | " gradients = grad(W_prev, batchX, batchY)\n", 186 | " W_new = W_prev - lr*gradients\n", 187 | " print(np.average(epoch_loss))\n", 188 | " return W_new" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "## Initializing Weights & Bias" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 11, 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "data": { 205 | "text/plain": [ 206 | "(3,)" 207 | ] 208 | }, 209 | "execution_count": 11, 210 | "metadata": {}, 211 | "output_type": "execute_result" 212 | } 213 
| ], 214 | "source": [ 215 | "W = np.random.uniform(size=(X.shape[1],))\n", 216 | "W.shape" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "## Training the Model" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 12, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "[0.79618896 0.4163837 0.42391672]\n", 236 | "6.277384446178855\n", 237 | "5.743854541290189\n", 238 | "5.639645113634937\n", 239 | "5.470853335391083\n", 240 | "5.667962099033007\n", 241 | "5.28981499238036\n", 242 | "5.271342725834096\n", 243 | "5.060368761481374\n", 244 | "5.446976524735743\n", 245 | "5.066447704414417\n", 246 | "5.09547652563199\n", 247 | "5.027047324020139\n", 248 | "5.231413878163542\n", 249 | "4.988553649645022\n", 250 | "5.045966239495252\n", 251 | "5.114093205881277\n", 252 | "5.24313649711587\n", 253 | "5.3309800879511595\n", 254 | "5.2369933729138705\n", 255 | "4.931357222546931\n", 256 | "5.141072913940974\n", 257 | "4.865177946360237\n", 258 | "5.159259956251705\n", 259 | "4.9850013574989935\n", 260 | "5.075992543432262\n", 261 | "5.257537980731875\n", 262 | "5.15220734911969\n", 263 | "5.376113028226739\n", 264 | "5.01790696023796\n", 265 | "4.91743150129637\n", 266 | "5.155592756181093\n", 267 | "4.778496875740468\n", 268 | "5.170830308497642\n", 269 | "4.982057737386109\n", 270 | "5.065383224896023\n", 271 | "5.130301783329287\n", 272 | "5.0560422794900335\n", 273 | "4.751542045996459\n", 274 | "5.044836198314092\n", 275 | "5.200065208594642\n", 276 | "5.285956700133218\n", 277 | "4.69607623817095\n", 278 | "5.1110131343684895\n", 279 | "4.929261653440459\n", 280 | "5.508577234972032\n", 281 | "4.9295442093408335\n", 282 | "5.35853341436791\n", 283 | "5.109882297828997\n", 284 | "5.268900506923824\n", 285 | "5.0553568625861995\n", 286 | "5.127108057053507\n", 287 | "5.256180863878795\n", 288 | 
"5.2017470229439295\n", 289 | "5.29706022610866\n", 290 | "4.842578732589242\n", 291 | "5.216497984078021\n", 292 | "5.167172158840554\n", 293 | "5.030913544155235\n", 294 | "4.9208154561894215\n", 295 | "5.0286895579435775\n", 296 | "5.254927983516582\n", 297 | "5.01657084246088\n", 298 | "4.902782163580072\n", 299 | "4.952060238428389\n", 300 | "4.9191214387250035\n", 301 | "5.027098170280329\n", 302 | "5.077946337475286\n", 303 | "4.880098287861289\n", 304 | "4.944972140093644\n", 305 | "5.052666303168795\n", 306 | "5.081615975832211\n", 307 | "5.182569862898558\n", 308 | "4.9697239615921625\n", 309 | "5.056706134894904\n", 310 | "5.110368058329958\n", 311 | "4.996883358551884\n", 312 | "4.989696640078847\n", 313 | "4.858898302441649\n", 314 | "4.930715531754518\n", 315 | "5.009314794537656\n", 316 | "5.127984153485331\n", 317 | "5.0231049124602825\n", 318 | "5.254008279820028\n", 319 | "5.014073801135062\n", 320 | "4.97986671455312\n", 321 | "5.05243719762335\n", 322 | "5.0420997708089415\n", 323 | "5.069451589063102\n", 324 | "5.150257226697879\n", 325 | "5.164400634677876\n", 326 | "5.039134823534483\n", 327 | "4.9883474157043795\n", 328 | "5.132433782589542\n", 329 | "4.891167385232467\n", 330 | "5.062085674381252\n", 331 | "5.2087055067642165\n", 332 | "4.955120478076216\n", 333 | "4.88214965225135\n", 334 | "4.935119155279668\n", 335 | "5.4527388101645755\n", 336 | "4.902647474574163\n", 337 | "4.973218626980669\n", 338 | "4.969125747240431\n", 339 | "5.034819845305565\n", 340 | "5.11158072777152\n", 341 | "5.052709686640736\n", 342 | "4.968840505649011\n", 343 | "5.088612053925859\n", 344 | "4.915818444090352\n", 345 | "5.093157727293972\n", 346 | "5.033660011718062\n", 347 | "5.223412724166921\n", 348 | "5.1077816665230475\n", 349 | "5.135494856877415\n", 350 | "4.989777059244171\n", 351 | "5.165527490515357\n", 352 | "5.327764783138228\n", 353 | "5.019080784749545\n", 354 | "5.0239554594473805\n", 355 | "5.026682139880581\n", 356 | 
"5.159710856114193\n", 357 | "5.027795318281852\n", 358 | "4.988648104694423\n", 359 | "5.012928309587274\n", 360 | "5.081666714554431\n", 361 | "5.272117622609143\n", 362 | "4.851869972177484\n", 363 | "4.956969286562633\n", 364 | "5.025160410234134\n", 365 | "4.993677569943672\n", 366 | "5.099870584911604\n", 367 | "5.096181311604027\n", 368 | "5.236923964162504\n", 369 | "4.903674873885619\n", 370 | "5.075807820064275\n", 371 | "4.701673307737802\n", 372 | "5.299083765240716\n", 373 | "5.112201692629844\n", 374 | "5.057198863500682\n", 375 | "5.213612968874003\n", 376 | "5.1198465312517225\n", 377 | "4.928097399784619\n", 378 | "4.991964359552881\n", 379 | "4.994824548764347\n", 380 | "5.049272339126923\n", 381 | "4.942713247937617\n", 382 | "4.91913758023255\n", 383 | "5.43143326229052\n", 384 | "5.116114463071936\n", 385 | "4.9625898877661445\n", 386 | "5.042137105313292\n", 387 | "5.157805681662321\n", 388 | "5.04896325685934\n", 389 | "5.018982423826876\n", 390 | "5.185034061711793\n", 391 | "5.032465453425752\n", 392 | "4.897108003638827\n", 393 | "4.95421678149162\n", 394 | "5.259006286037971\n", 395 | "5.182717981721985\n", 396 | "4.884826442197778\n", 397 | "5.156842877561891\n", 398 | "5.047645060816286\n", 399 | "5.050121866311711\n", 400 | "4.869794749038038\n", 401 | "4.9957190876138835\n", 402 | "5.022988953794398\n", 403 | "5.002839753440677\n", 404 | "4.881319490767778\n", 405 | "5.055591006717888\n", 406 | "5.015670050297785\n", 407 | "5.013392499945725\n", 408 | "4.989724363452584\n", 409 | "4.880412359751283\n", 410 | "5.416224099805669\n", 411 | "5.2867392415724135\n", 412 | "5.11414942585176\n", 413 | "5.079966239068073\n", 414 | "4.925427043815789\n", 415 | "5.145176957835869\n", 416 | "5.130160451160554\n", 417 | "5.024814422423505\n", 418 | "4.958769936817604\n", 419 | "4.81843182784822\n", 420 | "5.34842232805925\n", 421 | "5.279335054775109\n", 422 | "5.107921381756868\n", 423 | "5.15079791674754\n", 424 | "5.096406330816437\n", 425 | 
"5.00585692466524\n", 426 | "5.200218860079871\n", 427 | "4.96229644142354\n", 428 | "4.937523034036624\n", 429 | "5.174090668819386\n", 430 | "5.1400977102309175\n", 431 | "5.1319926651455186\n", 432 | "5.048127875476851\n", 433 | "5.416048064604175\n", 434 | "5.057120736099771\n", 435 | "4.864273959307718\n", 436 | "[2.38150361 0.22263106 0.53666744]\n" 437 | ] 438 | } 439 | ], 440 | "source": [ 441 | "print(W)\n", 442 | "W = sgd(W, W, .009, 32, 200)\n", 443 | "print(W)" 444 | ] 445 | }, 446 | { 447 | "cell_type": "markdown", 448 | "metadata": {}, 449 | "source": [ 450 | "## Visualizing the Result" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 13, 456 | "metadata": {}, 457 | "outputs": [ 458 | { 459 | "data": { 460 | "image/png": "\n", 461 | "text/plain": [ 462 | "
" 463 | ] 464 | }, 465 | "metadata": {}, 466 | "output_type": "display_data" 467 | } 468 | ], 469 | "source": [ 470 | "def graph(formula, x_range): \n", 471 | " x = np.array(x_range) \n", 472 | " y = formula(x) \n", 473 | " plt.plot(x, y) \n", 474 | " \n", 475 | "def my_formula(x):\n", 476 | " return (-W[0]-W[1]*x)/W[2]\n", 477 | "\n", 478 | "from matplotlib import pyplot as plt\n", 479 | "from pandas import DataFrame \n", 480 | "df = DataFrame(dict(x=X[:,1], y=X[:,2], label=Y))\n", 481 | "colors = {0:'blue', 1:'orange'}\n", 482 | "fig, ax = plt.subplots()\n", 483 | "grouped = df.groupby('label')\n", 484 | "for key, group in grouped:\n", 485 | " group.plot(ax=ax, kind='scatter', x='x', y='y', label=key, color=colors[key])\n", 486 | "graph(my_formula, range(-20,15))\n", 487 | "plt.xlabel('X_1')\n", 488 | "plt.ylabel('X_2')\n", 489 | "plt.show()" 490 | ] 491 | } 492 | ], 493 | "metadata": { 494 | "kernelspec": { 495 | "display_name": "Python 3", 496 | "language": "python", 497 | "name": "python3" 498 | }, 499 | "language_info": { 500 | "codemirror_mode": { 501 | "name": "ipython", 502 | "version": 3 503 | }, 504 | "file_extension": ".py", 505 | "mimetype": "text/x-python", 506 | "name": "python", 507 | "nbconvert_exporter": "python", 508 | "pygments_lexer": "ipython3", 509 | "version": "3.6.5" 510 | } 511 | }, 512 | "nbformat": 4, 513 | "nbformat_minor": 2 514 | } 515 | --------------------------------------------------------------------------------