├── Breast_cancer_data.csv
├── K-means Clustering.ipynb
├── KNN Classification + Regression.ipynb
├── README.md
├── airfoil_noise_data.csv
├── decision tree classification.ipynb
├── decision tree regression.ipynb
├── linear regression.ipynb
├── logistic regression.ipynb
├── naive bayes.ipynb
└── sgd.ipynb
/Breast_cancer_data.csv:
--------------------------------------------------------------------------------
1 | mean_radius,mean_texture,mean_perimeter,mean_area,mean_smoothness,diagnosis
2 | 17.99,10.38,122.8,1001.0,0.1184,0
3 | 20.57,17.77,132.9,1326.0,0.08474,0
4 | 19.69,21.25,130.0,1203.0,0.1096,0
5 | 11.42,20.38,77.58,386.1,0.1425,0
6 | 20.29,14.34,135.1,1297.0,0.1003,0
7 | 12.45,15.7,82.57,477.1,0.1278,0
8 | 18.25,19.98,119.6,1040.0,0.09463,0
9 | 13.71,20.83,90.2,577.9,0.1189,0
10 | 13.0,21.82,87.5,519.8,0.1273,0
11 | 12.46,24.04,83.97,475.9,0.1186,0
12 | 16.02,23.24,102.7,797.8,0.08206,0
13 | 15.78,17.89,103.6,781.0,0.0971,0
14 | 19.17,24.8,132.4,1123.0,0.0974,0
15 | 15.85,23.95,103.7,782.7,0.08401,0
16 | 13.73,22.61,93.6,578.3,0.1131,0
17 | 14.54,27.54,96.73,658.8,0.1139,0
18 | 14.68,20.13,94.74,684.5,0.09867,0
19 | 16.13,20.68,108.1,798.8,0.117,0
20 | 19.81,22.15,130.0,1260.0,0.09831,0
21 | 13.54,14.36,87.46,566.3,0.09779,1
22 | 13.08,15.71,85.63,520.0,0.1075,1
23 | 9.504,12.44,60.34,273.9,0.1024,1
24 | 15.34,14.26,102.5,704.4,0.1073,0
25 | 21.16,23.04,137.2,1404.0,0.09428,0
26 | 16.65,21.38,110.0,904.6,0.1121,0
27 | 17.14,16.4,116.0,912.7,0.1186,0
28 | 14.58,21.53,97.41,644.8,0.1054,0
29 | 18.61,20.25,122.1,1094.0,0.0944,0
30 | 15.3,25.27,102.4,732.4,0.1082,0
31 | 17.57,15.05,115.0,955.1,0.09847,0
32 | 18.63,25.11,124.8,1088.0,0.1064,0
33 | 11.84,18.7,77.93,440.6,0.1109,0
34 | 17.02,23.98,112.8,899.3,0.1197,0
35 | 19.27,26.47,127.9,1162.0,0.09401,0
36 | 16.13,17.88,107.0,807.2,0.104,0
37 | 16.74,21.59,110.1,869.5,0.0961,0
38 | 14.25,21.72,93.63,633.0,0.09823,0
39 | 13.03,18.42,82.61,523.8,0.08983,1
40 | 14.99,25.2,95.54,698.8,0.09387,0
41 | 13.48,20.82,88.4,559.2,0.1016,0
42 | 13.44,21.58,86.18,563.0,0.08162,0
43 | 10.95,21.35,71.9,371.1,0.1227,0
44 | 19.07,24.81,128.3,1104.0,0.09081,0
45 | 13.28,20.28,87.32,545.2,0.1041,0
46 | 13.17,21.81,85.42,531.5,0.09714,0
47 | 18.65,17.6,123.7,1076.0,0.1099,0
48 | 8.196,16.84,51.71,201.9,0.086,1
49 | 13.17,18.66,85.98,534.6,0.1158,0
50 | 12.05,14.63,78.04,449.3,0.1031,1
51 | 13.49,22.3,86.91,561.0,0.08752,1
52 | 11.76,21.6,74.72,427.9,0.08637,1
53 | 13.64,16.34,87.21,571.8,0.07685,1
54 | 11.94,18.24,75.71,437.6,0.08261,1
55 | 18.22,18.7,120.3,1033.0,0.1148,0
56 | 15.1,22.02,97.26,712.8,0.09056,0
57 | 11.52,18.75,73.34,409.0,0.09524,1
58 | 19.21,18.57,125.5,1152.0,0.1053,0
59 | 14.71,21.59,95.55,656.9,0.1137,0
60 | 13.05,19.31,82.61,527.2,0.0806,1
61 | 8.618,11.79,54.34,224.5,0.09752,1
62 | 10.17,14.88,64.55,311.9,0.1134,1
63 | 8.598,20.98,54.66,221.8,0.1243,1
64 | 14.25,22.15,96.42,645.7,0.1049,0
65 | 9.173,13.86,59.2,260.9,0.07721,1
66 | 12.68,23.84,82.69,499.0,0.1122,0
67 | 14.78,23.94,97.4,668.3,0.1172,0
68 | 9.465,21.01,60.11,269.4,0.1044,1
69 | 11.31,19.04,71.8,394.1,0.08139,1
70 | 9.029,17.33,58.79,250.5,0.1066,1
71 | 12.78,16.49,81.37,502.5,0.09831,1
72 | 18.94,21.31,123.6,1130.0,0.09009,0
73 | 8.888,14.64,58.79,244.0,0.09783,1
74 | 17.2,24.52,114.2,929.4,0.1071,0
75 | 13.8,15.79,90.43,584.1,0.1007,0
76 | 12.31,16.52,79.19,470.9,0.09172,1
77 | 16.07,19.65,104.1,817.7,0.09168,0
78 | 13.53,10.94,87.91,559.2,0.1291,1
79 | 18.05,16.15,120.2,1006.0,0.1065,0
80 | 20.18,23.97,143.7,1245.0,0.1286,0
81 | 12.86,18.0,83.19,506.3,0.09934,1
82 | 11.45,20.97,73.81,401.5,0.1102,1
83 | 13.34,15.86,86.49,520.0,0.1078,1
84 | 25.22,24.91,171.5,1878.0,0.1063,0
85 | 19.1,26.29,129.1,1132.0,0.1215,0
86 | 12.0,15.65,76.95,443.3,0.09723,1
87 | 18.46,18.52,121.1,1075.0,0.09874,0
88 | 14.48,21.46,94.25,648.2,0.09444,0
89 | 19.02,24.59,122.0,1076.0,0.09029,0
90 | 12.36,21.8,79.78,466.1,0.08772,1
91 | 14.64,15.24,95.77,651.9,0.1132,1
92 | 14.62,24.02,94.57,662.7,0.08974,1
93 | 15.37,22.76,100.2,728.2,0.092,0
94 | 13.27,14.76,84.74,551.7,0.07355,1
95 | 13.45,18.3,86.6,555.1,0.1022,1
96 | 15.06,19.83,100.3,705.6,0.1039,0
97 | 20.26,23.03,132.4,1264.0,0.09078,0
98 | 12.18,17.84,77.79,451.1,0.1045,1
99 | 9.787,19.94,62.11,294.5,0.1024,1
100 | 11.6,12.84,74.34,412.6,0.08983,1
101 | 14.42,19.77,94.48,642.5,0.09752,0
102 | 13.61,24.98,88.05,582.7,0.09488,0
103 | 6.981,13.43,43.79,143.5,0.117,1
104 | 12.18,20.52,77.22,458.7,0.08013,1
105 | 9.876,19.4,63.95,298.3,0.1005,1
106 | 10.49,19.29,67.41,336.1,0.09989,1
107 | 13.11,15.56,87.21,530.2,0.1398,0
108 | 11.64,18.33,75.17,412.5,0.1142,1
109 | 12.36,18.54,79.01,466.7,0.08477,1
110 | 22.27,19.67,152.8,1509.0,0.1326,0
111 | 11.34,21.26,72.48,396.5,0.08759,1
112 | 9.777,16.99,62.5,290.2,0.1037,1
113 | 12.63,20.76,82.15,480.4,0.09933,1
114 | 14.26,19.65,97.83,629.9,0.07837,1
115 | 10.51,20.19,68.64,334.2,0.1122,1
116 | 8.726,15.83,55.84,230.9,0.115,1
117 | 11.93,21.53,76.53,438.6,0.09768,1
118 | 8.95,15.76,58.74,245.2,0.09462,1
119 | 14.87,16.67,98.64,682.5,0.1162,0
120 | 15.78,22.91,105.7,782.6,0.1155,0
121 | 17.95,20.01,114.2,982.0,0.08402,0
122 | 11.41,10.82,73.34,403.3,0.09373,1
123 | 18.66,17.12,121.4,1077.0,0.1054,0
124 | 24.25,20.2,166.2,1761.0,0.1447,0
125 | 14.5,10.89,94.28,640.7,0.1101,1
126 | 13.37,16.39,86.1,553.5,0.07115,1
127 | 13.85,17.21,88.44,588.7,0.08785,1
128 | 13.61,24.69,87.76,572.6,0.09258,0
129 | 19.0,18.91,123.4,1138.0,0.08217,0
130 | 15.1,16.39,99.58,674.5,0.115,1
131 | 19.79,25.12,130.4,1192.0,0.1015,0
132 | 12.19,13.29,79.08,455.8,0.1066,1
133 | 15.46,19.48,101.7,748.9,0.1092,0
134 | 16.16,21.54,106.2,809.8,0.1008,0
135 | 15.71,13.93,102.0,761.7,0.09462,1
136 | 18.45,21.91,120.2,1075.0,0.0943,0
137 | 12.77,22.47,81.72,506.3,0.09055,0
138 | 11.71,16.67,74.72,423.6,0.1051,1
139 | 11.43,15.39,73.06,399.8,0.09639,1
140 | 14.95,17.57,96.85,678.1,0.1167,0
141 | 11.28,13.39,73.0,384.8,0.1164,1
142 | 9.738,11.97,61.24,288.5,0.0925,1
143 | 16.11,18.05,105.1,813.0,0.09721,0
144 | 11.43,17.31,73.66,398.0,0.1092,1
145 | 12.9,15.92,83.74,512.2,0.08677,1
146 | 10.75,14.97,68.26,355.3,0.07793,1
147 | 11.9,14.65,78.11,432.8,0.1152,1
148 | 11.8,16.58,78.99,432.0,0.1091,0
149 | 14.95,18.77,97.84,689.5,0.08138,1
150 | 14.44,15.18,93.97,640.1,0.0997,1
151 | 13.74,17.91,88.12,585.0,0.07944,1
152 | 13.0,20.78,83.51,519.4,0.1135,1
153 | 8.219,20.7,53.27,203.9,0.09405,1
154 | 9.731,15.34,63.78,300.2,0.1072,1
155 | 11.15,13.08,70.87,381.9,0.09754,1
156 | 13.15,15.34,85.31,538.9,0.09384,1
157 | 12.25,17.94,78.27,460.3,0.08654,1
158 | 17.68,20.74,117.4,963.7,0.1115,0
159 | 16.84,19.46,108.4,880.2,0.07445,1
160 | 12.06,12.74,76.84,448.6,0.09311,1
161 | 10.9,12.96,68.69,366.8,0.07515,1
162 | 11.75,20.18,76.1,419.8,0.1089,1
163 | 19.19,15.94,126.3,1157.0,0.08694,0
164 | 19.59,18.15,130.7,1214.0,0.112,0
165 | 12.34,22.22,79.85,464.5,0.1012,1
166 | 23.27,22.04,152.1,1686.0,0.08439,0
167 | 14.97,19.76,95.5,690.2,0.08421,1
168 | 10.8,9.71,68.77,357.6,0.09594,1
169 | 16.78,18.8,109.3,886.3,0.08865,0
170 | 17.47,24.68,116.1,984.6,0.1049,0
171 | 14.97,16.95,96.22,685.9,0.09855,1
172 | 12.32,12.39,78.85,464.1,0.1028,1
173 | 13.43,19.63,85.84,565.4,0.09048,0
174 | 15.46,11.89,102.5,736.9,0.1257,0
175 | 11.08,14.71,70.21,372.7,0.1006,1
176 | 10.66,15.15,67.49,349.6,0.08792,1
177 | 8.671,14.45,54.42,227.2,0.09138,1
178 | 9.904,18.06,64.6,302.4,0.09699,1
179 | 16.46,20.11,109.3,832.9,0.09831,0
180 | 13.01,22.22,82.01,526.4,0.06251,1
181 | 12.81,13.06,81.29,508.8,0.08739,1
182 | 27.22,21.87,182.1,2250.0,0.1094,0
183 | 21.09,26.57,142.7,1311.0,0.1141,0
184 | 15.7,20.31,101.2,766.6,0.09597,0
185 | 11.41,14.92,73.53,402.0,0.09059,1
186 | 15.28,22.41,98.92,710.6,0.09057,0
187 | 10.08,15.11,63.76,317.5,0.09267,1
188 | 18.31,18.58,118.6,1041.0,0.08588,0
189 | 11.71,17.19,74.68,420.3,0.09774,1
190 | 11.81,17.39,75.27,428.9,0.1007,1
191 | 12.3,15.9,78.83,463.7,0.0808,1
192 | 14.22,23.12,94.37,609.9,0.1075,0
193 | 12.77,21.41,82.02,507.4,0.08749,1
194 | 9.72,18.22,60.73,288.1,0.0695,1
195 | 12.34,26.86,81.15,477.4,0.1034,0
196 | 14.86,23.21,100.4,671.4,0.1044,0
197 | 12.91,16.33,82.53,516.4,0.07941,1
198 | 13.77,22.29,90.63,588.9,0.12,0
199 | 18.08,21.84,117.4,1024.0,0.07371,0
200 | 19.18,22.49,127.5,1148.0,0.08523,0
201 | 14.45,20.22,94.49,642.7,0.09872,0
202 | 12.23,19.56,78.54,461.0,0.09586,1
203 | 17.54,19.32,115.1,951.6,0.08968,0
204 | 23.29,26.67,158.9,1685.0,0.1141,0
205 | 13.81,23.75,91.56,597.8,0.1323,0
206 | 12.47,18.6,81.09,481.9,0.09965,1
207 | 15.12,16.68,98.78,716.6,0.08876,0
208 | 9.876,17.27,62.92,295.4,0.1089,1
209 | 17.01,20.26,109.7,904.3,0.08772,0
210 | 13.11,22.54,87.02,529.4,0.1002,1
211 | 15.27,12.91,98.17,725.5,0.08182,1
212 | 20.58,22.14,134.7,1290.0,0.0909,0
213 | 11.84,18.94,75.51,428.0,0.08871,1
214 | 28.11,18.47,188.5,2499.0,0.1142,0
215 | 17.42,25.56,114.5,948.0,0.1006,0
216 | 14.19,23.81,92.87,610.7,0.09463,0
217 | 13.86,16.93,90.96,578.9,0.1026,0
218 | 11.89,18.35,77.32,432.2,0.09363,1
219 | 10.2,17.48,65.05,321.2,0.08054,1
220 | 19.8,21.56,129.7,1230.0,0.09383,0
221 | 19.53,32.47,128.0,1223.0,0.0842,0
222 | 13.65,13.16,87.88,568.9,0.09646,1
223 | 13.56,13.9,88.59,561.3,0.1051,1
224 | 10.18,17.53,65.12,313.1,0.1061,1
225 | 15.75,20.25,102.6,761.3,0.1025,0
226 | 13.27,17.02,84.55,546.4,0.08445,1
227 | 14.34,13.47,92.51,641.2,0.09906,1
228 | 10.44,15.46,66.62,329.6,0.1053,1
229 | 15.0,15.51,97.45,684.5,0.08371,1
230 | 12.62,23.97,81.35,496.4,0.07903,1
231 | 12.83,22.33,85.26,503.2,0.1088,0
232 | 17.05,19.08,113.4,895.0,0.1141,0
233 | 11.32,27.08,71.76,395.7,0.06883,1
234 | 11.22,33.81,70.79,386.8,0.0778,1
235 | 20.51,27.81,134.4,1319.0,0.09159,0
236 | 9.567,15.91,60.21,279.6,0.08464,1
237 | 14.03,21.25,89.79,603.4,0.0907,1
238 | 23.21,26.97,153.5,1670.0,0.09509,0
239 | 20.48,21.46,132.5,1306.0,0.08355,0
240 | 14.22,27.85,92.55,623.9,0.08223,1
241 | 17.46,39.28,113.4,920.6,0.09812,0
242 | 13.64,15.6,87.38,575.3,0.09423,1
243 | 12.42,15.04,78.61,476.5,0.07926,1
244 | 11.3,18.19,73.93,389.4,0.09592,1
245 | 13.75,23.77,88.54,590.0,0.08043,1
246 | 19.4,23.5,129.1,1155.0,0.1027,0
247 | 10.48,19.86,66.72,337.7,0.107,1
248 | 13.2,17.43,84.13,541.6,0.07215,1
249 | 12.89,14.11,84.95,512.2,0.0876,1
250 | 10.65,25.22,68.01,347.0,0.09657,1
251 | 11.52,14.93,73.87,406.3,0.1013,1
252 | 20.94,23.56,138.9,1364.0,0.1007,0
253 | 11.5,18.45,73.28,407.4,0.09345,1
254 | 19.73,19.82,130.7,1206.0,0.1062,0
255 | 17.3,17.08,113.0,928.2,0.1008,0
256 | 19.45,19.33,126.5,1169.0,0.1035,0
257 | 13.96,17.05,91.43,602.4,0.1096,0
258 | 19.55,28.77,133.6,1207.0,0.0926,0
259 | 15.32,17.27,103.2,713.3,0.1335,0
260 | 15.66,23.2,110.2,773.5,0.1109,0
261 | 15.53,33.56,103.7,744.9,0.1063,0
262 | 20.31,27.06,132.9,1288.0,0.1,0
263 | 17.35,23.06,111.0,933.1,0.08662,0
264 | 17.29,22.13,114.4,947.8,0.08999,0
265 | 15.61,19.38,100.0,758.6,0.0784,0
266 | 17.19,22.07,111.6,928.3,0.09726,0
267 | 20.73,31.12,135.7,1419.0,0.09469,0
268 | 10.6,18.95,69.28,346.4,0.09688,1
269 | 13.59,21.84,87.16,561.0,0.07956,1
270 | 12.87,16.21,82.38,512.2,0.09425,1
271 | 10.71,20.39,69.5,344.9,0.1082,1
272 | 14.29,16.82,90.3,632.6,0.06429,1
273 | 11.29,13.04,72.23,388.0,0.09834,1
274 | 21.75,20.99,147.3,1491.0,0.09401,0
275 | 9.742,15.67,61.5,289.9,0.09037,1
276 | 17.93,24.48,115.2,998.9,0.08855,0
277 | 11.89,17.36,76.2,435.6,0.1225,1
278 | 11.33,14.16,71.79,396.6,0.09379,1
279 | 18.81,19.98,120.9,1102.0,0.08923,0
280 | 13.59,17.84,86.24,572.3,0.07948,1
281 | 13.85,15.18,88.99,587.4,0.09516,1
282 | 19.16,26.6,126.2,1138.0,0.102,0
283 | 11.74,14.02,74.24,427.3,0.07813,1
284 | 19.4,18.18,127.2,1145.0,0.1037,0
285 | 16.24,18.77,108.8,805.1,0.1066,0
286 | 12.89,15.7,84.08,516.6,0.07818,1
287 | 12.58,18.4,79.83,489.0,0.08393,1
288 | 11.94,20.76,77.87,441.0,0.08605,1
289 | 12.89,13.12,81.89,515.9,0.06955,1
290 | 11.26,19.96,73.72,394.1,0.0802,1
291 | 11.37,18.89,72.17,396.0,0.08713,1
292 | 14.41,19.73,96.03,651.0,0.08757,1
293 | 14.96,19.1,97.03,687.3,0.08992,1
294 | 12.95,16.02,83.14,513.7,0.1005,1
295 | 11.85,17.46,75.54,432.7,0.08372,1
296 | 12.72,13.78,81.78,492.1,0.09667,1
297 | 13.77,13.27,88.06,582.7,0.09198,1
298 | 10.91,12.35,69.14,363.7,0.08518,1
299 | 11.76,18.14,75.0,431.1,0.09968,0
300 | 14.26,18.17,91.22,633.1,0.06576,1
301 | 10.51,23.09,66.85,334.2,0.1015,1
302 | 19.53,18.9,129.5,1217.0,0.115,0
303 | 12.46,19.89,80.43,471.3,0.08451,1
304 | 20.09,23.86,134.7,1247.0,0.108,0
305 | 10.49,18.61,66.86,334.3,0.1068,1
306 | 11.46,18.16,73.59,403.1,0.08853,1
307 | 11.6,24.49,74.23,417.2,0.07474,1
308 | 13.2,15.82,84.07,537.3,0.08511,1
309 | 9.0,14.4,56.36,246.3,0.07005,1
310 | 13.5,12.71,85.69,566.2,0.07376,1
311 | 13.05,13.84,82.71,530.6,0.08352,1
312 | 11.7,19.11,74.33,418.7,0.08814,1
313 | 14.61,15.69,92.68,664.9,0.07618,1
314 | 12.76,13.37,82.29,504.1,0.08794,1
315 | 11.54,10.72,73.73,409.1,0.08597,1
316 | 8.597,18.6,54.09,221.2,0.1074,1
317 | 12.49,16.85,79.19,481.6,0.08511,1
318 | 12.18,14.08,77.25,461.4,0.07734,1
319 | 18.22,18.87,118.7,1027.0,0.09746,0
320 | 9.042,18.9,60.07,244.5,0.09968,1
321 | 12.43,17.0,78.6,477.3,0.07557,1
322 | 10.25,16.18,66.52,324.2,0.1061,1
323 | 20.16,19.66,131.1,1274.0,0.0802,0
324 | 12.86,13.32,82.82,504.8,0.1134,1
325 | 20.34,21.51,135.9,1264.0,0.117,0
326 | 12.2,15.21,78.01,457.9,0.08673,1
327 | 12.67,17.3,81.25,489.9,0.1028,1
328 | 14.11,12.88,90.03,616.5,0.09309,1
329 | 12.03,17.93,76.09,446.0,0.07683,1
330 | 16.27,20.71,106.9,813.7,0.1169,0
331 | 16.26,21.88,107.5,826.8,0.1165,0
332 | 16.03,15.51,105.8,793.2,0.09491,0
333 | 12.98,19.35,84.52,514.0,0.09579,1
334 | 11.22,19.86,71.94,387.3,0.1054,1
335 | 11.25,14.78,71.38,390.0,0.08306,1
336 | 12.3,19.02,77.88,464.4,0.08313,1
337 | 17.06,21.0,111.8,918.6,0.1119,0
338 | 12.99,14.23,84.08,514.3,0.09462,1
339 | 18.77,21.43,122.9,1092.0,0.09116,0
340 | 10.05,17.53,64.41,310.8,0.1007,1
341 | 23.51,24.27,155.1,1747.0,0.1069,0
342 | 14.42,16.54,94.15,641.2,0.09751,1
343 | 9.606,16.84,61.64,280.5,0.08481,1
344 | 11.06,14.96,71.49,373.9,0.1033,1
345 | 19.68,21.68,129.9,1194.0,0.09797,0
346 | 11.71,15.45,75.03,420.3,0.115,1
347 | 10.26,14.71,66.2,321.6,0.09882,1
348 | 12.06,18.9,76.66,445.3,0.08386,1
349 | 14.76,14.74,94.87,668.7,0.08875,1
350 | 11.47,16.03,73.02,402.7,0.09076,1
351 | 11.95,14.96,77.23,426.7,0.1158,1
352 | 11.66,17.07,73.7,421.0,0.07561,1
353 | 15.75,19.22,107.1,758.6,0.1243,0
354 | 25.73,17.46,174.2,2010.0,0.1149,0
355 | 15.08,25.74,98.0,716.6,0.1024,0
356 | 11.14,14.07,71.24,384.6,0.07274,1
357 | 12.56,19.07,81.92,485.8,0.0876,1
358 | 13.05,18.59,85.09,512.0,0.1082,1
359 | 13.87,16.21,88.52,593.7,0.08743,1
360 | 8.878,15.49,56.74,241.0,0.08293,1
361 | 9.436,18.32,59.82,278.6,0.1009,1
362 | 12.54,18.07,79.42,491.9,0.07436,1
363 | 13.3,21.57,85.24,546.1,0.08582,1
364 | 12.76,18.84,81.87,496.6,0.09676,1
365 | 16.5,18.29,106.6,838.1,0.09686,1
366 | 13.4,16.95,85.48,552.4,0.07937,1
367 | 20.44,21.78,133.8,1293.0,0.0915,0
368 | 20.2,26.83,133.7,1234.0,0.09905,0
369 | 12.21,18.02,78.31,458.4,0.09231,1
370 | 21.71,17.25,140.9,1546.0,0.09384,0
371 | 22.01,21.9,147.2,1482.0,0.1063,0
372 | 16.35,23.29,109.0,840.4,0.09742,0
373 | 15.19,13.21,97.65,711.8,0.07963,1
374 | 21.37,15.1,141.3,1386.0,0.1001,0
375 | 20.64,17.35,134.8,1335.0,0.09446,0
376 | 13.69,16.07,87.84,579.1,0.08302,1
377 | 16.17,16.07,106.3,788.5,0.0988,1
378 | 10.57,20.22,70.15,338.3,0.09073,1
379 | 13.46,28.21,85.89,562.1,0.07517,1
380 | 13.66,15.15,88.27,580.6,0.08268,1
381 | 11.08,18.83,73.3,361.6,0.1216,0
382 | 11.27,12.96,73.16,386.3,0.1237,1
383 | 11.04,14.93,70.67,372.7,0.07987,1
384 | 12.05,22.72,78.75,447.8,0.06935,1
385 | 12.39,17.48,80.64,462.9,0.1042,1
386 | 13.28,13.72,85.79,541.8,0.08363,1
387 | 14.6,23.29,93.97,664.7,0.08682,0
388 | 12.21,14.09,78.78,462.0,0.08108,1
389 | 13.88,16.16,88.37,596.6,0.07026,1
390 | 11.27,15.5,73.38,392.0,0.08365,1
391 | 19.55,23.21,128.9,1174.0,0.101,0
392 | 10.26,12.22,65.75,321.6,0.09996,1
393 | 8.734,16.84,55.27,234.3,0.1039,1
394 | 15.49,19.97,102.4,744.7,0.116,0
395 | 21.61,22.28,144.4,1407.0,0.1167,0
396 | 12.1,17.72,78.07,446.2,0.1029,1
397 | 14.06,17.18,89.75,609.1,0.08045,1
398 | 13.51,18.89,88.1,558.1,0.1059,1
399 | 12.8,17.46,83.05,508.3,0.08044,1
400 | 11.06,14.83,70.31,378.2,0.07741,1
401 | 11.8,17.26,75.26,431.9,0.09087,1
402 | 17.91,21.02,124.4,994.0,0.123,0
403 | 11.93,10.91,76.14,442.7,0.08872,1
404 | 12.96,18.29,84.18,525.2,0.07351,1
405 | 12.94,16.17,83.18,507.6,0.09879,1
406 | 12.34,14.95,78.29,469.1,0.08682,1
407 | 10.94,18.59,70.39,370.0,0.1004,1
408 | 16.14,14.86,104.3,800.0,0.09495,1
409 | 12.85,21.37,82.63,514.5,0.07551,1
410 | 17.99,20.66,117.8,991.7,0.1036,0
411 | 12.27,17.92,78.41,466.1,0.08685,1
412 | 11.36,17.57,72.49,399.8,0.08858,1
413 | 11.04,16.83,70.92,373.2,0.1077,1
414 | 9.397,21.68,59.75,268.8,0.07969,1
415 | 14.99,22.11,97.53,693.7,0.08515,1
416 | 15.13,29.81,96.71,719.5,0.0832,0
417 | 11.89,21.17,76.39,433.8,0.09773,1
418 | 9.405,21.7,59.6,271.2,0.1044,1
419 | 15.5,21.08,102.9,803.1,0.112,0
420 | 12.7,12.17,80.88,495.0,0.08785,1
421 | 11.16,21.41,70.95,380.3,0.1018,1
422 | 11.57,19.04,74.2,409.7,0.08546,1
423 | 14.69,13.98,98.22,656.1,0.1031,1
424 | 11.61,16.02,75.46,408.2,0.1088,1
425 | 13.66,19.13,89.46,575.3,0.09057,1
426 | 9.742,19.12,61.93,289.7,0.1075,1
427 | 10.03,21.28,63.19,307.3,0.08117,1
428 | 10.48,14.98,67.49,333.6,0.09816,1
429 | 10.8,21.98,68.79,359.9,0.08801,1
430 | 11.13,16.62,70.47,381.1,0.08151,1
431 | 12.72,17.67,80.98,501.3,0.07896,1
432 | 14.9,22.53,102.1,685.0,0.09947,0
433 | 12.4,17.68,81.47,467.8,0.1054,1
434 | 20.18,19.54,133.8,1250.0,0.1133,0
435 | 18.82,21.97,123.7,1110.0,0.1018,0
436 | 14.86,16.94,94.89,673.7,0.08924,1
437 | 13.98,19.62,91.12,599.5,0.106,0
438 | 12.87,19.54,82.67,509.2,0.09136,1
439 | 14.04,15.98,89.78,611.2,0.08458,1
440 | 13.85,19.6,88.68,592.6,0.08684,1
441 | 14.02,15.66,89.59,606.5,0.07966,1
442 | 10.97,17.2,71.73,371.5,0.08915,1
443 | 17.27,25.42,112.4,928.8,0.08331,0
444 | 13.78,15.79,88.37,585.9,0.08817,1
445 | 10.57,18.32,66.82,340.9,0.08142,1
446 | 18.03,16.85,117.5,990.0,0.08947,0
447 | 11.99,24.89,77.61,441.3,0.103,1
448 | 17.75,28.03,117.3,981.6,0.09997,0
449 | 14.8,17.66,95.88,674.8,0.09179,1
450 | 14.53,19.34,94.25,659.7,0.08388,1
451 | 21.1,20.52,138.1,1384.0,0.09684,0
452 | 11.87,21.54,76.83,432.0,0.06613,1
453 | 19.59,25.0,127.7,1191.0,0.1032,0
454 | 12.0,28.23,76.77,442.5,0.08437,1
455 | 14.53,13.98,93.86,644.2,0.1099,1
456 | 12.62,17.15,80.62,492.9,0.08583,1
457 | 13.38,30.72,86.34,557.2,0.09245,1
458 | 11.63,29.29,74.87,415.1,0.09357,1
459 | 13.21,25.25,84.1,537.9,0.08791,1
460 | 13.0,25.13,82.61,520.2,0.08369,1
461 | 9.755,28.2,61.68,290.9,0.07984,1
462 | 17.08,27.15,111.2,930.9,0.09898,0
463 | 27.42,26.27,186.9,2501.0,0.1084,0
464 | 14.4,26.99,92.25,646.1,0.06995,1
465 | 11.6,18.36,73.88,412.7,0.08508,1
466 | 13.17,18.22,84.28,537.3,0.07466,1
467 | 13.24,20.13,86.87,542.9,0.08284,1
468 | 13.14,20.74,85.98,536.9,0.08675,1
469 | 9.668,18.1,61.06,286.3,0.08311,1
470 | 17.6,23.33,119.0,980.5,0.09289,0
471 | 11.62,18.18,76.38,408.8,0.1175,1
472 | 9.667,18.49,61.49,289.1,0.08946,1
473 | 12.04,28.14,76.85,449.9,0.08752,1
474 | 14.92,14.93,96.45,686.9,0.08098,1
475 | 12.27,29.97,77.42,465.4,0.07699,1
476 | 10.88,15.62,70.41,358.9,0.1007,1
477 | 12.83,15.73,82.89,506.9,0.0904,1
478 | 14.2,20.53,92.41,618.4,0.08931,1
479 | 13.9,16.62,88.97,599.4,0.06828,1
480 | 11.49,14.59,73.99,404.9,0.1046,1
481 | 16.25,19.51,109.8,815.8,0.1026,0
482 | 12.16,18.03,78.29,455.3,0.09087,1
483 | 13.9,19.24,88.73,602.9,0.07991,1
484 | 13.47,14.06,87.32,546.3,0.1071,1
485 | 13.7,17.64,87.76,571.1,0.0995,1
486 | 15.73,11.28,102.8,747.2,0.1043,1
487 | 12.45,16.41,82.85,476.7,0.09514,1
488 | 14.64,16.85,94.21,666.0,0.08641,1
489 | 19.44,18.82,128.1,1167.0,0.1089,0
490 | 11.68,16.17,75.49,420.5,0.1128,1
491 | 16.69,20.2,107.1,857.6,0.07497,0
492 | 12.25,22.44,78.18,466.5,0.08192,1
493 | 17.85,13.23,114.6,992.1,0.07838,1
494 | 18.01,20.56,118.4,1007.0,0.1001,0
495 | 12.46,12.83,78.83,477.3,0.07372,1
496 | 13.16,20.54,84.06,538.7,0.07335,1
497 | 14.87,20.21,96.12,680.9,0.09587,1
498 | 12.65,18.17,82.69,485.6,0.1076,1
499 | 12.47,17.31,80.45,480.1,0.08928,1
500 | 18.49,17.52,121.3,1068.0,0.1012,0
501 | 20.59,21.24,137.8,1320.0,0.1085,0
502 | 15.04,16.74,98.73,689.4,0.09883,1
503 | 13.82,24.49,92.33,595.9,0.1162,0
504 | 12.54,16.32,81.25,476.3,0.1158,1
505 | 23.09,19.83,152.1,1682.0,0.09342,0
506 | 9.268,12.87,61.49,248.7,0.1634,1
507 | 9.676,13.14,64.12,272.5,0.1255,1
508 | 12.22,20.04,79.47,453.1,0.1096,1
509 | 11.06,17.12,71.25,366.5,0.1194,1
510 | 16.3,15.7,104.7,819.8,0.09427,1
511 | 15.46,23.95,103.8,731.3,0.1183,0
512 | 11.74,14.69,76.31,426.0,0.08099,1
513 | 14.81,14.7,94.66,680.7,0.08472,1
514 | 13.4,20.52,88.64,556.7,0.1106,0
515 | 14.58,13.66,94.29,658.8,0.09832,1
516 | 15.05,19.07,97.26,701.9,0.09215,0
517 | 11.34,18.61,72.76,391.2,0.1049,1
518 | 18.31,20.58,120.8,1052.0,0.1068,0
519 | 19.89,20.26,130.5,1214.0,0.1037,0
520 | 12.88,18.22,84.45,493.1,0.1218,1
521 | 12.75,16.7,82.51,493.8,0.1125,1
522 | 9.295,13.9,59.96,257.8,0.1371,1
523 | 24.63,21.6,165.5,1841.0,0.103,0
524 | 11.26,19.83,71.3,388.1,0.08511,1
525 | 13.71,18.68,88.73,571.0,0.09916,1
526 | 9.847,15.68,63.0,293.2,0.09492,1
527 | 8.571,13.1,54.53,221.3,0.1036,1
528 | 13.46,18.75,87.44,551.1,0.1075,1
529 | 12.34,12.27,78.94,468.5,0.09003,1
530 | 13.94,13.17,90.31,594.2,0.1248,1
531 | 12.07,13.44,77.83,445.2,0.11,1
532 | 11.75,17.56,75.89,422.9,0.1073,1
533 | 11.67,20.02,75.21,416.2,0.1016,1
534 | 13.68,16.33,87.76,575.5,0.09277,1
535 | 20.47,20.67,134.7,1299.0,0.09156,0
536 | 10.96,17.62,70.79,365.6,0.09687,1
537 | 20.55,20.86,137.8,1308.0,0.1046,0
538 | 14.27,22.55,93.77,629.8,0.1038,0
539 | 11.69,24.44,76.37,406.4,0.1236,1
540 | 7.729,25.49,47.98,178.8,0.08098,1
541 | 7.691,25.44,48.34,170.4,0.08668,1
542 | 11.54,14.44,74.65,402.9,0.09984,1
543 | 14.47,24.99,95.81,656.4,0.08837,1
544 | 14.74,25.42,94.7,668.6,0.08275,1
545 | 13.21,28.06,84.88,538.4,0.08671,1
546 | 13.87,20.7,89.77,584.8,0.09578,1
547 | 13.62,23.23,87.19,573.2,0.09246,1
548 | 10.32,16.35,65.31,324.9,0.09434,1
549 | 10.26,16.58,65.85,320.8,0.08877,1
550 | 9.683,19.34,61.05,285.7,0.08491,1
551 | 10.82,24.21,68.89,361.6,0.08192,1
552 | 10.86,21.48,68.51,360.5,0.07431,1
553 | 11.13,22.44,71.49,378.4,0.09566,1
554 | 12.77,29.43,81.35,507.9,0.08276,1
555 | 9.333,21.94,59.01,264.0,0.0924,1
556 | 12.88,28.92,82.5,514.3,0.08123,1
557 | 10.29,27.61,65.67,321.4,0.0903,1
558 | 10.16,19.59,64.73,311.7,0.1003,1
559 | 9.423,27.88,59.26,271.3,0.08123,1
560 | 14.59,22.68,96.39,657.1,0.08473,1
561 | 11.51,23.93,74.52,403.5,0.09261,1
562 | 14.05,27.15,91.38,600.4,0.09929,1
563 | 11.2,29.37,70.67,386.0,0.07449,1
564 | 15.22,30.62,103.4,716.9,0.1048,0
565 | 20.92,25.09,143.0,1347.0,0.1099,0
566 | 21.56,22.39,142.0,1479.0,0.111,0
567 | 20.13,28.25,131.2,1261.0,0.0978,0
568 | 16.6,28.08,108.3,858.1,0.08455,0
569 | 20.6,29.33,140.1,1265.0,0.1178,0
570 | 7.76,24.54,47.92,181.0,0.05263,1
571 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ML_from_Scratch
2 | Implementations of basic machine learning algorithms from scratch in Python. Each section below links to the video(s) explaining the corresponding algorithm.
3 |
4 | ## Gradient Descent
5 | [Video 1: Gradient Descent](https://youtu.be/36zkIAAUcZ4)
6 | [Video 2: Gradient Descent](https://youtu.be/41BiBUZbg9U)
7 |
8 |
9 | ## Linear Regression
10 | [Video: Linear Regression](https://youtu.be/fnDO1s4fzi4)
11 |
12 | ## Logistic Regression
13 | [Video: Logistic Regression](https://youtu.be/NtjAeXppomA)
14 |
15 | ## Stochastic Gradient Descent
16 | [Video: Stochastic Gradient Descent](https://youtu.be/V8InSDYHG4s)
17 |
18 | ## KNN
19 | [Video: KNN](https://youtu.be/0RwM2BaLNkE)
20 |
21 | ## K-means
22 | [Video: K-means](https://youtu.be/IB9WfafBmjk)
23 |
24 | ## Decision Tree Classification
25 | [Video 1: Decision Tree Classification](https://youtu.be/ZVR2Way4nwQ)
26 | [Video 2: Decision Tree Classification](https://youtu.be/sgQAhG5Q7iY)
27 |
28 |
29 | ## Decision Tree Regression
30 | [Video 1: Decision Tree Regression](https://youtu.be/UhY5vPfQIrA)
31 | [Video 2: Decision Tree Regression](https://youtu.be/P2ZB8c5Ha1Q)
32 |
33 | ## Naive Bayes Classification
34 | [Video 1: Naive Bayes Classification](https://youtu.be/lFJbZ6LVxN8)
35 | [Video 2: Naive Bayes Classification](https://youtu.be/3I8oX3OUL6I)
36 |
37 | ## Data source
38 | - airfoil_noise_data.csv (converted from the .dat file available at https://archive.ics.uci.edu/ml/datasets/airfoil+self-noise)
39 |
40 | Donor:
41 | Dr Roberto Lopez
42 | robertolopez '@' intelnics.com
43 | Intelnics
44 |
45 | Creators:
46 | Thomas F. Brooks, D. Stuart Pope and Michael A. Marcolini
47 | NASA
48 |
49 | - Breast_cancer_data.csv (taken from https://www.kaggle.com/merishnasuwal/breast-cancer-prediction-dataset)
50 |
--------------------------------------------------------------------------------
/decision tree classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Import tools"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "import pandas as pd"
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {},
23 | "source": [
24 | "## Get the data"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "metadata": {
31 | "scrolled": false
32 | },
33 | "outputs": [
34 | {
35 | "data": {
36 | "text/html": [
37 | "
\n",
38 | "\n",
51 | "
\n",
52 | " \n",
53 | " \n",
54 | " | \n",
55 | " sepal_length | \n",
56 | " sepal_width | \n",
57 | " petal_length | \n",
58 | " petal_width | \n",
59 | " type | \n",
60 | "
\n",
61 | " \n",
62 | " \n",
63 | " \n",
64 | " 0 | \n",
65 | " 5.1 | \n",
66 | " 3.5 | \n",
67 | " 1.4 | \n",
68 | " 0.2 | \n",
69 | " 0 | \n",
70 | "
\n",
71 | " \n",
72 | " 1 | \n",
73 | " 4.9 | \n",
74 | " 3.0 | \n",
75 | " 1.4 | \n",
76 | " 0.2 | \n",
77 | " 0 | \n",
78 | "
\n",
79 | " \n",
80 | " 2 | \n",
81 | " 4.7 | \n",
82 | " 3.2 | \n",
83 | " 1.3 | \n",
84 | " 0.2 | \n",
85 | " 0 | \n",
86 | "
\n",
87 | " \n",
88 | " 3 | \n",
89 | " 4.6 | \n",
90 | " 3.1 | \n",
91 | " 1.5 | \n",
92 | " 0.2 | \n",
93 | " 0 | \n",
94 | "
\n",
95 | " \n",
96 | " 4 | \n",
97 | " 5.0 | \n",
98 | " 3.6 | \n",
99 | " 1.4 | \n",
100 | " 0.2 | \n",
101 | " 0 | \n",
102 | "
\n",
103 | " \n",
104 | " 5 | \n",
105 | " 5.4 | \n",
106 | " 3.9 | \n",
107 | " 1.7 | \n",
108 | " 0.4 | \n",
109 | " 0 | \n",
110 | "
\n",
111 | " \n",
112 | " 6 | \n",
113 | " 4.6 | \n",
114 | " 3.4 | \n",
115 | " 1.4 | \n",
116 | " 0.3 | \n",
117 | " 0 | \n",
118 | "
\n",
119 | " \n",
120 | " 7 | \n",
121 | " 5.0 | \n",
122 | " 3.4 | \n",
123 | " 1.5 | \n",
124 | " 0.2 | \n",
125 | " 0 | \n",
126 | "
\n",
127 | " \n",
128 | " 8 | \n",
129 | " 4.4 | \n",
130 | " 2.9 | \n",
131 | " 1.4 | \n",
132 | " 0.2 | \n",
133 | " 0 | \n",
134 | "
\n",
135 | " \n",
136 | " 9 | \n",
137 | " 4.9 | \n",
138 | " 3.1 | \n",
139 | " 1.5 | \n",
140 | " 0.1 | \n",
141 | " 0 | \n",
142 | "
\n",
143 | " \n",
144 | "
\n",
145 | "
"
146 | ],
147 | "text/plain": [
148 | " sepal_length sepal_width petal_length petal_width type\n",
149 | "0 5.1 3.5 1.4 0.2 0\n",
150 | "1 4.9 3.0 1.4 0.2 0\n",
151 | "2 4.7 3.2 1.3 0.2 0\n",
152 | "3 4.6 3.1 1.5 0.2 0\n",
153 | "4 5.0 3.6 1.4 0.2 0\n",
154 | "5 5.4 3.9 1.7 0.4 0\n",
155 | "6 4.6 3.4 1.4 0.3 0\n",
156 | "7 5.0 3.4 1.5 0.2 0\n",
157 | "8 4.4 2.9 1.4 0.2 0\n",
158 | "9 4.9 3.1 1.5 0.1 0"
159 | ]
160 | },
161 | "execution_count": 2,
162 | "metadata": {},
163 | "output_type": "execute_result"
164 | }
165 | ],
166 | "source": [
167 | "col_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'type']\n",
168 | "data = pd.read_csv(\"iris.csv\", skiprows=1, header=None, names=col_names)\n",
169 | "data.head(10)"
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {},
175 | "source": [
176 | "## Node class"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 3,
182 | "metadata": {},
183 | "outputs": [],
184 | "source": [
185 | "class Node():\n",
186 | " def __init__(self, feature_index=None, threshold=None, left=None, right=None, info_gain=None, value=None):\n",
187 | " ''' constructor ''' \n",
188 | " \n",
189 | " # for decision node\n",
190 | " self.feature_index = feature_index\n",
191 | " self.threshold = threshold\n",
192 | " self.left = left\n",
193 | " self.right = right\n",
194 | " self.info_gain = info_gain\n",
195 | " \n",
196 | " # for leaf node\n",
197 | " self.value = value"
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "## Tree class"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 4,
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "class DecisionTreeClassifier():\n",
214 | " def __init__(self, min_samples_split=2, max_depth=2):\n",
215 | " ''' constructor '''\n",
216 | " \n",
217 | " # initialize the root of the tree \n",
218 | " self.root = None\n",
219 | " \n",
220 | " # stopping conditions\n",
221 | " self.min_samples_split = min_samples_split\n",
222 | " self.max_depth = max_depth\n",
223 | " \n",
224 | " def build_tree(self, dataset, curr_depth=0):\n",
225 | " ''' recursive function to build the tree ''' \n",
226 | " \n",
227 | " X, Y = dataset[:,:-1], dataset[:,-1]\n",
228 | " num_samples, num_features = np.shape(X)\n",
229 | " \n",
230 | " # split until stopping conditions are met\n",
231 | " if num_samples>=self.min_samples_split and curr_depth<=self.max_depth:\n",
232 | " # find the best split\n",
233 | " best_split = self.get_best_split(dataset, num_samples, num_features)\n",
234 | " # check if information gain is positive\n",
235 | " if best_split[\"info_gain\"]>0:\n",
236 | " # recur left\n",
237 | " left_subtree = self.build_tree(best_split[\"dataset_left\"], curr_depth+1)\n",
238 | " # recur right\n",
239 | " right_subtree = self.build_tree(best_split[\"dataset_right\"], curr_depth+1)\n",
240 | " # return decision node\n",
241 | " return Node(best_split[\"feature_index\"], best_split[\"threshold\"], \n",
242 | " left_subtree, right_subtree, best_split[\"info_gain\"])\n",
243 | " \n",
244 | " # compute leaf node\n",
245 | " leaf_value = self.calculate_leaf_value(Y)\n",
246 | " # return leaf node\n",
247 | " return Node(value=leaf_value)\n",
248 | " \n",
249 | " def get_best_split(self, dataset, num_samples, num_features):\n",
250 | " ''' function to find the best split '''\n",
251 | " \n",
252 | " # dictionary to store the best split\n",
253 | " best_split = {}\n",
254 | " max_info_gain = -float(\"inf\")\n",
255 | " \n",
256 | " # loop over all the features\n",
257 | " for feature_index in range(num_features):\n",
258 | " feature_values = dataset[:, feature_index]\n",
259 | " possible_thresholds = np.unique(feature_values)\n",
260 | " # loop over all the feature values present in the data\n",
261 | " for threshold in possible_thresholds:\n",
262 | " # get current split\n",
263 | " dataset_left, dataset_right = self.split(dataset, feature_index, threshold)\n",
264 | " # check that both child datasets are non-empty\n",
265 | " if len(dataset_left)>0 and len(dataset_right)>0:\n",
266 | " y, left_y, right_y = dataset[:, -1], dataset_left[:, -1], dataset_right[:, -1]\n",
267 | " # compute information gain\n",
268 | " curr_info_gain = self.information_gain(y, left_y, right_y, \"gini\")\n",
269 | " # update the best split if needed\n",
270 | " if curr_info_gain>max_info_gain:\n",
271 | " best_split[\"feature_index\"] = feature_index\n",
272 | " best_split[\"threshold\"] = threshold\n",
273 | " best_split[\"dataset_left\"] = dataset_left\n",
274 | " best_split[\"dataset_right\"] = dataset_right\n",
275 | " best_split[\"info_gain\"] = curr_info_gain\n",
276 | " max_info_gain = curr_info_gain\n",
277 | " \n",
278 | " # return best split\n",
279 | " return best_split\n",
280 | " \n",
281 | " def split(self, dataset, feature_index, threshold):\n",
282 | " ''' function to split the data '''\n",
283 | " \n",
284 | " dataset_left = np.array([row for row in dataset if row[feature_index]<=threshold])\n",
285 | " dataset_right = np.array([row for row in dataset if row[feature_index]>threshold])\n",
286 | " return dataset_left, dataset_right\n",
287 | " \n",
288 | " def information_gain(self, parent, l_child, r_child, mode=\"entropy\"):\n",
289 | " ''' function to compute information gain '''\n",
290 | " \n",
291 | " weight_l = len(l_child) / len(parent)\n",
292 | " weight_r = len(r_child) / len(parent)\n",
293 | " if mode==\"gini\":\n",
294 | " gain = self.gini_index(parent) - (weight_l*self.gini_index(l_child) + weight_r*self.gini_index(r_child))\n",
295 | " else:\n",
296 | " gain = self.entropy(parent) - (weight_l*self.entropy(l_child) + weight_r*self.entropy(r_child))\n",
297 | " return gain\n",
298 | " \n",
299 | " def entropy(self, y):\n",
300 | " ''' function to compute entropy '''\n",
301 | " \n",
302 | " class_labels = np.unique(y)\n",
303 | " entropy = 0\n",
304 | " for cls in class_labels:\n",
305 | " p_cls = len(y[y == cls]) / len(y)\n",
306 | " entropy += -p_cls * np.log2(p_cls)\n",
307 | " return entropy\n",
308 | " \n",
309 | " def gini_index(self, y):\n",
310 | " ''' function to compute gini index '''\n",
311 | " \n",
312 | " class_labels = np.unique(y)\n",
313 | " gini = 0\n",
314 | " for cls in class_labels:\n",
315 | " p_cls = len(y[y == cls]) / len(y)\n",
316 | " gini += p_cls**2\n",
317 | " return 1 - gini\n",
318 | " \n",
319 | " def calculate_leaf_value(self, Y):\n",
320 | " ''' function to compute the leaf node value (most frequent class label) '''\n",
321 | " \n",
322 | " Y = list(Y)\n",
323 | " return max(Y, key=Y.count)\n",
324 | " \n",
325 | " def print_tree(self, tree=None, indent=\" \"):\n",
326 | " ''' function to print the tree '''\n",
327 | " \n",
328 | " if not tree:\n",
329 | " tree = self.root\n",
330 | "\n",
331 | " if tree.value is not None:\n",
332 | " print(tree.value)\n",
333 | "\n",
334 | " else:\n",
335 | " print(\"X_\"+str(tree.feature_index), \"<=\", tree.threshold, \"?\", tree.info_gain)\n",
336 | " print(\"%sleft:\" % (indent), end=\"\")\n",
337 | " self.print_tree(tree.left, indent + indent)\n",
338 | " print(\"%sright:\" % (indent), end=\"\")\n",
339 | " self.print_tree(tree.right, indent + indent)\n",
340 | " \n",
341 | " def fit(self, X, Y):\n",
342 | " ''' function to train the tree '''\n",
343 | " \n",
344 | " dataset = np.concatenate((X, Y), axis=1)\n",
345 | " self.root = self.build_tree(dataset)\n",
346 | " \n",
347 | " def predict(self, X):\n",
348 | " ''' function to predict new dataset '''\n",
349 | " \n",
350 | " preditions = [self.make_prediction(x, self.root) for x in X]\n",
351 | " return preditions\n",
352 | " \n",
353 | " def make_prediction(self, x, tree):\n",
354 | " ''' function to predict a single data point '''\n",
355 | " \n",
356 | " if tree.value!=None: return tree.value\n",
357 | " feature_val = x[tree.feature_index]\n",
358 | " if feature_val<=tree.threshold:\n",
359 | " return self.make_prediction(x, tree.left)\n",
360 | " else:\n",
361 | " return self.make_prediction(x, tree.right)"
362 | ]
363 | },
364 | {
365 | "cell_type": "markdown",
366 | "metadata": {},
367 | "source": [
368 | "## Train-Test split"
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": 5,
374 | "metadata": {},
375 | "outputs": [],
376 | "source": [
377 | "X = data.iloc[:, :-1].values\n",
378 | "Y = data.iloc[:, -1].values.reshape(-1,1)\n",
379 | "from sklearn.model_selection import train_test_split\n",
380 | "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2, random_state=41)"
381 | ]
382 | },
383 | {
384 | "cell_type": "markdown",
385 | "metadata": {},
386 | "source": [
387 | "## Fit the model"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 6,
393 | "metadata": {},
394 | "outputs": [
395 | {
396 | "name": "stdout",
397 | "output_type": "stream",
398 | "text": [
399 | "X_2 <= 1.9 ? 0.33741385372714494\n",
400 | " left:0.0\n",
401 | " right:X_3 <= 1.5 ? 0.427106638180289\n",
402 | " left:X_2 <= 4.9 ? 0.05124653739612173\n",
403 | " left:1.0\n",
404 | " right:2.0\n",
405 | " right:X_2 <= 5.0 ? 0.019631171921475288\n",
406 | " left:X_1 <= 2.8 ? 0.20833333333333334\n",
407 | " left:2.0\n",
408 | " right:1.0\n",
409 | " right:2.0\n"
410 | ]
411 | }
412 | ],
413 | "source": [
414 | "classifier = DecisionTreeClassifier(min_samples_split=3, max_depth=3)\n",
415 | "classifier.fit(X_train,Y_train)\n",
416 | "classifier.print_tree()"
417 | ]
418 | },
419 | {
420 | "cell_type": "markdown",
421 | "metadata": {},
422 | "source": [
423 | "## Test the model"
424 | ]
425 | },
426 | {
427 | "cell_type": "code",
428 | "execution_count": 7,
429 | "metadata": {},
430 | "outputs": [
431 | {
432 | "data": {
433 | "text/plain": [
434 | "0.9333333333333333"
435 | ]
436 | },
437 | "execution_count": 7,
438 | "metadata": {},
439 | "output_type": "execute_result"
440 | }
441 | ],
442 | "source": [
443 | "Y_pred = classifier.predict(X_test) \n",
444 | "from sklearn.metrics import accuracy_score\n",
445 | "accuracy_score(Y_test, Y_pred)"
446 | ]
447 | }
448 | ],
449 | "metadata": {
450 | "kernelspec": {
451 | "display_name": "Python 3",
452 | "language": "python",
453 | "name": "python3"
454 | },
455 | "language_info": {
456 | "codemirror_mode": {
457 | "name": "ipython",
458 | "version": 3
459 | },
460 | "file_extension": ".py",
461 | "mimetype": "text/x-python",
462 | "name": "python",
463 | "nbconvert_exporter": "python",
464 | "pygments_lexer": "ipython3",
465 | "version": "3.8.5"
466 | }
467 | },
468 | "nbformat": 4,
469 | "nbformat_minor": 4
470 | }
471 |
--------------------------------------------------------------------------------
/decision tree regression.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Import tools"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "import pandas as pd"
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {},
23 | "source": [
24 | "## Get the data"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "metadata": {},
31 | "outputs": [
32 | {
33 | "data": {
34 | "text/html": [
35 | "\n",
36 | "\n",
49 | "
\n",
50 | " \n",
51 | " \n",
52 | " | \n",
53 | " x0 | \n",
54 | " x1 | \n",
55 | " x2 | \n",
56 | " x3 | \n",
57 | " x4 | \n",
58 | " y | \n",
59 | "
\n",
60 | " \n",
61 | " \n",
62 | " \n",
63 | " 0 | \n",
64 | " 800 | \n",
65 | " 0.0 | \n",
66 | " 0.3048 | \n",
67 | " 71.3 | \n",
68 | " 0.002663 | \n",
69 | " 126.201 | \n",
70 | "
\n",
71 | " \n",
72 | " 1 | \n",
73 | " 1000 | \n",
74 | " 0.0 | \n",
75 | " 0.3048 | \n",
76 | " 71.3 | \n",
77 | " 0.002663 | \n",
78 | " 125.201 | \n",
79 | "
\n",
80 | " \n",
81 | " 2 | \n",
82 | " 1250 | \n",
83 | " 0.0 | \n",
84 | " 0.3048 | \n",
85 | " 71.3 | \n",
86 | " 0.002663 | \n",
87 | " 125.951 | \n",
88 | "
\n",
89 | " \n",
90 | " 3 | \n",
91 | " 1600 | \n",
92 | " 0.0 | \n",
93 | " 0.3048 | \n",
94 | " 71.3 | \n",
95 | " 0.002663 | \n",
96 | " 127.591 | \n",
97 | "
\n",
98 | " \n",
99 | " 4 | \n",
100 | " 2000 | \n",
101 | " 0.0 | \n",
102 | " 0.3048 | \n",
103 | " 71.3 | \n",
104 | " 0.002663 | \n",
105 | " 127.461 | \n",
106 | "
\n",
107 | " \n",
108 | "
\n",
109 | "
"
110 | ],
111 | "text/plain": [
112 | " x0 x1 x2 x3 x4 y\n",
113 | "0 800 0.0 0.3048 71.3 0.002663 126.201\n",
114 | "1 1000 0.0 0.3048 71.3 0.002663 125.201\n",
115 | "2 1250 0.0 0.3048 71.3 0.002663 125.951\n",
116 | "3 1600 0.0 0.3048 71.3 0.002663 127.591\n",
117 | "4 2000 0.0 0.3048 71.3 0.002663 127.461"
118 | ]
119 | },
120 | "execution_count": 2,
121 | "metadata": {},
122 | "output_type": "execute_result"
123 | }
124 | ],
125 | "source": [
126 | "data = pd.read_csv(\"airfoil_noise_data.csv\")\n",
127 | "data.head(5)"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {},
133 | "source": [
134 | "## Node class"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 3,
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 | "class Node():\n",
144 | " def __init__(self, feature_index=None, threshold=None, left=None, right=None, var_red=None, value=None):\n",
145 | " ''' constructor ''' \n",
146 | " \n",
147 | " # for decision node\n",
148 | " self.feature_index = feature_index\n",
149 | " self.threshold = threshold\n",
150 | " self.left = left\n",
151 | " self.right = right\n",
152 | " self.var_red = var_red\n",
153 | " \n",
154 | " # for leaf node\n",
155 | " self.value = value"
156 | ]
157 | },
158 | {
159 | "cell_type": "markdown",
160 | "metadata": {},
161 | "source": [
162 | "## Tree class"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 4,
168 | "metadata": {},
169 | "outputs": [],
170 | "source": [
171 | "class DecisionTreeRegressor():\n",
172 | " def __init__(self, min_samples_split=2, max_depth=2):\n",
173 | " ''' constructor '''\n",
174 | " \n",
175 | " # initialize the root of the tree \n",
176 | " self.root = None\n",
177 | " \n",
178 | " # stopping conditions\n",
179 | " self.min_samples_split = min_samples_split\n",
180 | " self.max_depth = max_depth\n",
181 | " \n",
182 | " def build_tree(self, dataset, curr_depth=0):\n",
183 | " ''' recursive function to build the tree '''\n",
184 | " \n",
185 | " X, Y = dataset[:,:-1], dataset[:,-1]\n",
186 | " num_samples, num_features = np.shape(X)\n",
187 | " best_split = {}\n",
188 | " # split until stopping conditions are met\n",
189 | " if num_samples>=self.min_samples_split and curr_depth<=self.max_depth:\n",
190 | " # find the best split\n",
191 | " best_split = self.get_best_split(dataset, num_samples, num_features)\n",
192 | " # check if variance reduction is positive\n",
193 | " if best_split[\"var_red\"]>0:\n",
194 | " # recur left\n",
195 | " left_subtree = self.build_tree(best_split[\"dataset_left\"], curr_depth+1)\n",
196 | " # recur right\n",
197 | " right_subtree = self.build_tree(best_split[\"dataset_right\"], curr_depth+1)\n",
198 | " # return decision node\n",
199 | " return Node(best_split[\"feature_index\"], best_split[\"threshold\"], \n",
200 | " left_subtree, right_subtree, best_split[\"var_red\"])\n",
201 | " \n",
202 | " # compute leaf node\n",
203 | " leaf_value = self.calculate_leaf_value(Y)\n",
204 | " # return leaf node\n",
205 | " return Node(value=leaf_value)\n",
206 | " \n",
207 | " def get_best_split(self, dataset, num_samples, num_features):\n",
208 | " ''' function to find the best split '''\n",
209 | " \n",
210 | " # dictionary to store the best split\n",
211 | " best_split = {}\n",
212 | " max_var_red = -float(\"inf\")\n",
213 | " # loop over all the features\n",
214 | " for feature_index in range(num_features):\n",
215 | " feature_values = dataset[:, feature_index]\n",
216 | " possible_thresholds = np.unique(feature_values)\n",
217 | " # loop over all the feature values present in the data\n",
218 | " for threshold in possible_thresholds:\n",
219 | " # get current split\n",
220 | " dataset_left, dataset_right = self.split(dataset, feature_index, threshold)\n",
221 | " # check if children are not null\n",
222 | " if len(dataset_left)>0 and len(dataset_right)>0:\n",
223 | " y, left_y, right_y = dataset[:, -1], dataset_left[:, -1], dataset_right[:, -1]\n",
224 | " # compute variance reduction\n",
225 | " curr_var_red = self.variance_reduction(y, left_y, right_y)\n",
226 | " # update the best split if needed\n",
227 | " if curr_var_red>max_var_red:\n",
228 | " best_split[\"feature_index\"] = feature_index\n",
229 | " best_split[\"threshold\"] = threshold\n",
230 | " best_split[\"dataset_left\"] = dataset_left\n",
231 | " best_split[\"dataset_right\"] = dataset_right\n",
232 | " best_split[\"var_red\"] = curr_var_red\n",
233 | " max_var_red = curr_var_red\n",
234 | " \n",
235 | " # return best split\n",
236 | " return best_split\n",
237 | " \n",
238 | " def split(self, dataset, feature_index, threshold):\n",
239 | " ''' function to split the data '''\n",
240 | " \n",
241 | " dataset_left = np.array([row for row in dataset if row[feature_index]<=threshold])\n",
242 | " dataset_right = np.array([row for row in dataset if row[feature_index]>threshold])\n",
243 | " return dataset_left, dataset_right\n",
244 | " \n",
245 | " def variance_reduction(self, parent, l_child, r_child):\n",
246 | " ''' function to compute variance reduction '''\n",
247 | " \n",
248 | " weight_l = len(l_child) / len(parent)\n",
249 | " weight_r = len(r_child) / len(parent)\n",
250 | " reduction = np.var(parent) - (weight_l * np.var(l_child) + weight_r * np.var(r_child))\n",
251 | " return reduction\n",
252 | " \n",
253 | " def calculate_leaf_value(self, Y):\n",
254 | " ''' function to compute the leaf node value (mean of the targets) '''\n",
255 | " \n",
256 | " val = np.mean(Y)\n",
257 | " return val\n",
258 | " \n",
259 | " def print_tree(self, tree=None, indent=\" \"):\n",
260 | " ''' function to print the tree '''\n",
261 | " \n",
262 | " if not tree:\n",
263 | " tree = self.root\n",
264 | "\n",
265 | " if tree.value is not None:\n",
266 | " print(tree.value)\n",
267 | "\n",
268 | " else:\n",
269 | " print(\"X_\"+str(tree.feature_index), \"<=\", tree.threshold, \"?\", tree.var_red)\n",
270 | " print(\"%sleft:\" % (indent), end=\"\")\n",
271 | " self.print_tree(tree.left, indent + indent)\n",
272 | " print(\"%sright:\" % (indent), end=\"\")\n",
273 | " self.print_tree(tree.right, indent + indent)\n",
274 | " \n",
275 | " def fit(self, X, Y):\n",
276 | " ''' function to train the tree '''\n",
277 | " \n",
278 | " dataset = np.concatenate((X, Y), axis=1)\n",
279 | " self.root = self.build_tree(dataset)\n",
280 | " \n",
281 | " def make_prediction(self, x, tree):\n",
282 | " ''' function to predict a single data point '''\n",
283 | " \n",
284 | " if tree.value!=None: return tree.value\n",
285 | " feature_val = x[tree.feature_index]\n",
286 | " if feature_val<=tree.threshold:\n",
287 | " return self.make_prediction(x, tree.left)\n",
288 | " else:\n",
289 | " return self.make_prediction(x, tree.right)\n",
290 | " \n",
291 | " def predict(self, X):\n",
292 | " ''' function to predict new dataset '''\n",
293 | " \n",
294 | " preditions = [self.make_prediction(x, self.root) for x in X]\n",
295 | " return preditions"
296 | ]
297 | },
298 | {
299 | "cell_type": "markdown",
300 | "metadata": {},
301 | "source": [
302 | "## Train-Test split"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 5,
308 | "metadata": {},
309 | "outputs": [],
310 | "source": [
311 | "X = data.iloc[:, :-1].values\n",
312 | "Y = data.iloc[:, -1].values.reshape(-1,1)\n",
313 | "from sklearn.model_selection import train_test_split\n",
314 | "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2, random_state=41)"
315 | ]
316 | },
317 | {
318 | "cell_type": "markdown",
319 | "metadata": {},
320 | "source": [
321 | "## Fit the model"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": 6,
327 | "metadata": {},
328 | "outputs": [
329 | {
330 | "name": "stdout",
331 | "output_type": "stream",
332 | "text": [
333 | "X_0 <= 3150.0 ? 7.132048702017748\n",
334 | " left:X_4 <= 0.033779199999999995 ? 3.5903305690676675\n",
335 | " left:X_3 <= 55.5 ? 1.1789899981318328\n",
336 | " left:X_4 <= 0.00251435 ? 1.614396721819876\n",
337 | " left:128.9919833333333\n",
338 | " right:125.90953579676673\n",
339 | " right:X_1 <= 15.4 ? 2.2342245360792994\n",
340 | " left:129.39160280373832\n",
341 | " right:123.80422222222222\n",
342 | " right:X_0 <= 1250.0 ? 9.970884020498875\n",
343 | " left:X_4 <= 0.0483159 ? 6.355275159824863\n",
344 | " left:124.38024528301887\n",
345 | " right:118.30039999999998\n",
346 | " right:X_3 <= 39.6 ? 5.036286657241022\n",
347 | " left:113.58091666666667\n",
348 | " right:118.07284615384614\n",
349 | " right:X_4 <= 0.00146332 ? 29.082992105065273\n",
350 | " left:X_0 <= 8000.0 ? 11.886497073996967\n",
351 | " left:X_2 <= 0.0508 ? 7.608945827689513\n",
352 | " left:134.04247500000002\n",
353 | " right:127.33581818181818\n",
354 | " right:X_4 <= 0.00076193 ? 10.622919322400815\n",
355 | " left:128.94078571428574\n",
356 | " right:122.4076875\n",
357 | " right:X_4 <= 0.022902799999999997 ? 5.638575922510647\n",
358 | " left:X_0 <= 6300.0 ? 5.985051045988911\n",
359 | " left:120.04740816326529\n",
360 | " right:114.67370491803278\n",
361 | " right:X_4 <= 0.0368233 ? 8.63874479304644\n",
362 | " left:113.83169565217393\n",
363 | " right:107.6395833333333\n"
364 | ]
365 | }
366 | ],
367 | "source": [
368 | "regressor = DecisionTreeRegressor(min_samples_split=3, max_depth=3)\n",
369 | "regressor.fit(X_train,Y_train)\n",
370 | "regressor.print_tree()"
371 | ]
372 | },
373 | {
374 | "cell_type": "markdown",
375 | "metadata": {},
376 | "source": [
377 | "## Test the model"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": 7,
383 | "metadata": {},
384 | "outputs": [
385 | {
386 | "data": {
387 | "text/plain": [
388 | "4.851358097184457"
389 | ]
390 | },
391 | "execution_count": 7,
392 | "metadata": {},
393 | "output_type": "execute_result"
394 | }
395 | ],
396 | "source": [
397 | "Y_pred = regressor.predict(X_test) \n",
398 | "from sklearn.metrics import mean_squared_error\n",
399 | "np.sqrt(mean_squared_error(Y_test, Y_pred))"
400 | ]
401 | }
402 | ],
403 | "metadata": {
404 | "kernelspec": {
405 | "display_name": "Python 3",
406 | "language": "python",
407 | "name": "python3"
408 | },
409 | "language_info": {
410 | "codemirror_mode": {
411 | "name": "ipython",
412 | "version": 3
413 | },
414 | "file_extension": ".py",
415 | "mimetype": "text/x-python",
416 | "name": "python",
417 | "nbconvert_exporter": "python",
418 | "pygments_lexer": "ipython3",
419 | "version": "3.8.5"
420 | }
421 | },
422 | "nbformat": 4,
423 | "nbformat_minor": 4
424 | }
425 |
--------------------------------------------------------------------------------
/linear regression.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Generating Fake Data"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 16,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "from sklearn.datasets.samples_generator import make_regression\n",
17 | "X, y = make_regression(n_samples=200, n_features=1, n_informative=1, noise=6, bias=30, random_state=200)\n",
18 | "m = 200"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## Visualizing the Data"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 25,
31 | "metadata": {},
32 | "outputs": [
33 | {
34 | "data": {
35 | "image/png": "\n",
36 | "text/plain": [
37 | ""
38 | ]
39 | },
40 | "metadata": {},
41 | "output_type": "display_data"
42 | }
43 | ],
44 | "source": [
45 | "from matplotlib import pyplot as plt\n",
46 | "plt.scatter(X,y, c = \"red\",alpha=.5, marker = 'o')\n",
47 | "plt.xlabel(\"X\")\n",
48 | "plt.ylabel(\"Y\")\n",
49 | "plt.show()"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "## Linear Model"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 18,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "import numpy as np\n",
66 | "def h(X,w):\n",
67 | " return (w[1]*np.array(X[:,0])+w[0])"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "## Cost Function"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 19,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "def cost(w,X,y):\n",
84 | " return (.5/m) * np.sum(np.square(h(X,w)-np.array(y)))"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 | "## Gradient Descent "
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 20,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "def grad(w,X,y):\n",
101 | " g = [0]*2\n",
102 | " g[0] = (1/m) * np.sum(h(X,w)-np.array(y))\n",
103 | " g[1] = (1/m) * np.sum((h(X,w)-np.array(y))*np.array(X[:,0]))\n",
104 | " return g\n"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 21,
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "def descent(w_new, w_prev, lr):\n",
114 | " print(w_prev)\n",
115 | " print(cost(w_prev,X,y))\n",
116 | " j=0\n",
117 | " while True:\n",
118 | " w_prev = w_new\n",
119 | " w0 = w_prev[0] - lr*grad(w_prev,X,y)[0]\n",
120 | " w1 = w_prev[1] - lr*grad(w_prev,X,y)[1]\n",
121 | " w_new = [w0, w1]\n",
122 | " print(w_new)\n",
123 | " print(cost(w_new,X,y))\n",
124 | " if (w_new[0]-w_prev[0])**2 + (w_new[1]-w_prev[1])**2 <= pow(10,-6):\n",
125 | " return w_new\n",
126 | " if j>500: \n",
127 | " return w_new\n",
128 | " j+=1 "
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "## Initializing Parameters"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 22,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "w = [0,-1]"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {},
150 | "source": [
151 | "## Training the Model"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 23,
157 | "metadata": {},
158 | "outputs": [
159 | {
160 | "name": "stdout",
161 | "output_type": "stream",
162 | "text": [
163 | "[0, -1]\n",
164 | "540.5360663843456\n",
165 | "[3.0956308633447547, 0.11442770988081663]\n",
166 | "437.91139336428444\n",
167 | "[5.873446610978822, 1.1023454281382854]\n",
168 | "355.5039050187037\n",
169 | "[8.366165526017987, 1.9778657783247602]\n",
170 | "289.3267499184995\n",
171 | "[10.603129563187093, 2.753547324958939]\n",
172 | "236.1799750745718\n",
173 | "[12.610653489037027, 3.440564026385428]\n",
174 | "193.49509649539323\n",
175 | "[14.412337853388406, 4.048856351454087]\n",
176 | "159.2103901995911\n",
177 | "[16.0293495446536, 4.587266032213945]\n",
178 | "131.6708284668908\n",
179 | "[17.480673291820082, 5.063656213710697]\n",
180 | "109.54778810165583\n",
181 | "[18.7833371265594, 5.485018573380515]\n",
182 | "91.77462156224563\n",
183 | "[19.952614505935692, 5.857568814053481]\n",
184 | "77.49495508304668\n",
185 | "[21.002205515744066, 6.186831784078626]\n",
186 | "66.02119816099949\n",
187 | "[21.944399323224108, 6.4777173436470505]\n",
188 | "56.801246289923824\n",
189 | "[22.79021982273288, 6.734587976310905]\n",
190 | "49.39175789964725\n",
191 | "[23.549556216205993, 6.961319037445921]\n",
192 | "43.436706577550574\n",
193 | "[24.23128008944935, 7.1613524356181975]\n",
194 | "38.6501664442448\n",
195 | "[24.843350383306017, 7.337744457271138]\n",
196 | "34.802494555336104\n",
197 | "[25.39290751357782, 7.493208368754656]\n",
198 | "31.709239459080347\n",
199 | "[25.88635776349851, 7.630152361500887]\n",
200 | "29.22223761683571\n",
201 | "[26.329448955986066, 7.750713345236864]\n",
202 | "27.22246575546644\n",
203 | "[26.727338308440384, 7.85678703973609]\n",
204 | "25.614302554322332\n",
205 | "[27.084653279240058, 7.950054767051352]\n",
206 | "24.320921533807876\n",
207 | "[27.405546131200424, 8.032007302818467]\n",
208 | "23.280591944370723\n",
209 | "[27.69374286207346, 8.10396610651881]\n",
210 | "22.443708530267372\n",
211 | "[27.952587084793493, 8.167102216040767]\n",
212 | "21.77040640846802\n",
213 | "[28.1850793797893, 8.222453061042739]\n",
214 | "21.22864568158778\n",
215 | "[28.393912587566714, 8.270937422096216]\n",
216 | "20.792673176173675\n",
217 | "[28.581503461264624, 8.313368738022628]\n",
218 | "20.44178697219165\n",
219 | "[28.75002105541824, 8.350466941915155]\n",
220 | "20.159344055135406\n",
221 | "[28.901412188203544, 8.382868986773879]\n",
222 | "19.93196319200716\n",
223 | "[29.03742427951789, 8.411138204226754]\n",
224 | "19.74888457866804\n",
225 | "[29.159625835953786, 8.435772624233959]\n",
226 | "19.601455387770752\n",
227 | "[29.2694248256702, 8.45721236977788]\n",
228 | "19.482716431969585\n",
229 | "[29.368085161021078, 8.475846228144906]\n",
230 | "19.387070041859424\n",
231 | "[29.45674148426251, 8.492017489347575]\n",
232 | "19.310013179228125\n",
233 | "[29.536412431457247, 8.506029132372674]\n",
234 | "19.24792295396926\n",
235 | "[29.60801253158606, 8.518148431144402]\n",
236 | "19.19788424005612\n",
237 | "[29.672362881642048, 8.528611044246796]\n",
238 | "19.157551114830053\n",
239 | "[29.73020072393229, 8.537624645454327]\n",
240 | "19.125035474814123\n",
241 | "[29.782188038766204, 8.545372145882025]\n",
242 | "19.09881748922257\n",
243 | "[29.828919254015993, 8.552014553005522]\n",
244 | "19.077673602614112\n",
245 | "[29.87092816255075, 8.557693506843796]\n",
246 | "19.06061864154971\n",
247 | "[29.90869412914744, 8.562533529178292]\n",
248 | "19.046859257450855\n",
249 | "[29.942647660055936, 8.566644017743355]\n",
250 | "19.035756481849372\n",
251 | "[29.97317540084111, 8.57012101381262]\n",
252 | "19.026795607153897\n",
253 | "[30.000624621352255, 8.573048768478005]\n",
254 | "19.019561957024997\n",
255 | "[30.025307240597662, 8.57550113013079]\n",
256 | "19.013721392385282\n",
257 | "[30.047503438857856, 8.577542773171217]\n",
258 | "19.009004625587373\n",
259 | "[30.067464899489234, 8.579230285760998]\n",
260 | "19.005194597236063\n",
261 | "[30.085417718492856, 8.580613132462974]\n",
262 | "19.00211631637512\n",
263 | "[30.101565015998244, 8.581734505857439]\n",
264 | "18.99962868224005\n",
265 | "[30.11608928029284, 8.582632079662158]\n",
266 | "18.99761790019706\n",
267 | "[30.12915447187187, 8.583338674491873]\n",
268 | "18.995992180371438\n",
269 | "[30.14090791215345, 8.583882846154577]\n",
270 | "18.994677468461337\n",
271 | "[30.151481978966018, 8.584289405279359]\n",
272 | "18.993614007260764\n",
273 | "[30.160995628639302, 8.58457987608947]\n",
274 | "18.99275356682987\n",
275 | "[30.169555762489143, 8.584772901261081]\n",
276 | "18.992057212939287\n",
277 | "[30.177258453655917, 8.584884599031389]\n",
278 | "18.991493508895633\n",
279 | "[30.184190048614877, 8.584928878028586]\n",
280 | "18.991037066345022\n",
281 | "[30.190428156204195, 8.584917714681556]\n",
282 | "18.99066737713041\n",
283 | "[30.196042535696083, 8.584861397520429]\n",
284 | "18.99036787153328\n",
285 | "[30.201095894251807, 8.584768742193141]\n",
286 | "18.990125158892102\n",
287 | "[30.205644603039065, 8.58464728059093]\n",
288 | "18.989928415168315\n",
289 | "[30.20973934033721, 8.584503427091704]\n",
290 | "18.98976888893204\n",
291 | "[30.21342566910097, 8.58434262458882]\n",
292 | "18.98963950279431\n",
293 | "[30.216744555686475, 8.584169472669567]\n",
294 | "18.989534531782173\n",
295 | "[30.219732835755565, 8.58398784003826]\n",
296 | "18.9894493437512\n",
297 | "[30.22242363275721, 8.583800963039536]\n",
298 | "18.989380189826502\n",
299 | "[30.22484673383125, 8.583611531925033]\n",
300 | "18.989324035195228\n",
301 | "[30.227028927482966, 8.583421766317985]\n",
302 | "18.989278422451623\n",
303 | "[30.228994306931416, 8.583233481162887]\n",
304 | "18.98924136120799\n",
305 | "[30.23076454263459, 8.583048144298838]\n",
306 | "18.98921123890338\n",
307 | "[30.23235912713573, 8.582866926663364]\n",
308 | "18.98918674872298\n",
309 | "[30.233795595053294, 8.582690746016727]\n",
310 | "18.98916683133221\n",
311 | "[30.235089720748128, 8.582520304973011]\n",
312 | "18.989150627766747\n",
313 | "[30.23625569594233, 8.582356124032495]\n",
314 | "18.98913744133326\n",
315 | "[30.237306289331627, 8.582198570228412]\n",
316 | "18.989126706789886\n",
317 | "[30.238252990024396, 8.582047881929046]\n",
318 | "18.98911796540923\n",
319 | "[30.238252990024396, 8.582047881929046]\n"
320 | ]
321 | }
322 | ],
323 | "source": [
324 | "w = descent(w,w,.1)\n",
325 | "print(w)"
326 | ]
327 | },
328 | {
329 | "cell_type": "markdown",
330 | "metadata": {},
331 | "source": [
332 | "## Visualizing the Result"
333 | ]
334 | },
335 | {
336 | "cell_type": "code",
337 | "execution_count": 24,
338 | "metadata": {},
339 | "outputs": [
340 | {
341 | "data": {
342 | "image/png": "\n",
343 | "text/plain": [
344 | ""
345 | ]
346 | },
347 | "metadata": {},
348 | "output_type": "display_data"
349 | }
350 | ],
351 | "source": [
352 | "def graph(formula, x_range): \n",
353 | " x = np.array(x_range) \n",
354 | " y = formula(x) \n",
355 | " plt.plot(x, y) \n",
356 | " \n",
357 | "def my_formula(x):\n",
358 | " return w[0]+w[1]*x\n",
359 | "\n",
360 | "plt.scatter(X,y, c = \"red\",alpha=.5, marker = 'o')\n",
361 | "graph(my_formula, range(-5,5))\n",
362 | "plt.xlabel('X')\n",
363 | "plt.ylabel('Y')\n",
364 | "plt.show()"
365 | ]
366 | }
367 | ],
368 | "metadata": {
369 | "kernelspec": {
370 | "display_name": "Python 3",
371 | "language": "python",
372 | "name": "python3"
373 | },
374 | "language_info": {
375 | "codemirror_mode": {
376 | "name": "ipython",
377 | "version": 3
378 | },
379 | "file_extension": ".py",
380 | "mimetype": "text/x-python",
381 | "name": "python",
382 | "nbconvert_exporter": "python",
383 | "pygments_lexer": "ipython3",
384 | "version": "3.6.5"
385 | }
386 | },
387 | "nbformat": 4,
388 | "nbformat_minor": 2
389 | }
390 |
--------------------------------------------------------------------------------
/logistic regression.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Creating Fake Data"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 26,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "from sklearn.datasets.samples_generator import make_blobs\n",
17 | "X, Y = make_blobs(n_samples=200, centers=2, n_features=2, cluster_std=5, random_state=11)\n",
18 | "m = 200"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## Visualizing the Data"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 28,
31 | "metadata": {},
32 | "outputs": [
33 | {
34 | "data": {
35 | "image/png": "\n",
36 | "text/plain": [
37 | ""
38 | ]
39 | },
40 | "metadata": {},
41 | "output_type": "display_data"
42 | }
43 | ],
44 | "source": [
45 | "from matplotlib import pyplot as plt\n",
46 | "from pandas import DataFrame \n",
47 | "df = DataFrame(dict(x=X[:,0], y=X[:,1], label=Y))\n",
48 | "colors = {0:'blue', 1:'orange'}\n",
49 | "fig, ax = plt.subplots()\n",
50 | "grouped = df.groupby('label')\n",
51 | "for key, group in grouped:\n",
52 | " group.plot(ax=ax, kind='scatter', x='x', y='y', label=key, color=colors[key])\n",
53 | "plt.xlabel('X_0')\n",
54 | "plt.ylabel('X_1')\n",
55 | "plt.show()"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "## Logistic Model"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 6,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "import numpy as np\n",
72 | "def sigmoid(z):\n",
73 | " return 1 / (1 + np.exp(-z))"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 7,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "def hx(w,X):\n",
83 | " z = np.array(w[0] + w[1]*np.array(X[:,0]) + w[2]*np.array(X[:,1]))\n",
84 | " return sigmoid(z)"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 | "## Cost Function - Binary Cross Entropy "
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 8,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "def cost(w, X, Y):\n",
101 | " y_pred = hx(w,X)\n",
102 | " return -1 * sum(Y*np.log(y_pred) + (1-Y)*np.log(1-y_pred))"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "metadata": {},
108 | "source": [
109 | "## Gradient Descent"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 9,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "def grad(w, X, Y):\n",
119 | " y_pred = hx(w,X)\n",
120 | " g = [0]*3\n",
121 | " g[0] = -1 * sum(Y*(1-y_pred) - (1-Y)*y_pred)\n",
122 | " g[1] = -1 * sum(Y*(1-y_pred)*X[:,0] - (1-Y)*y_pred*X[:,0])\n",
123 | " g[2] = -1 * sum(Y*(1-y_pred)*X[:,1] - (1-Y)*y_pred*X[:,1])\n",
124 | " return g"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 37,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "def descent(w_new, w_prev, lr):\n",
134 | " print(w_prev)\n",
135 | " print(cost(w_prev, X, Y))\n",
136 | " j=0\n",
137 | " while True:\n",
138 | " w_prev = w_new\n",
139 | " w0 = w_prev[0] - lr*grad(w_prev, X, Y)[0]\n",
140 | " w1 = w_prev[1] - lr*grad(w_prev, X, Y)[1]\n",
141 | " w2 = w_prev[2] - lr*grad(w_prev, X, Y)[2]\n",
142 | " w_new = [w0, w1, w2]\n",
143 | " print(w_new)\n",
144 | " print(cost(w_new, X, Y))\n",
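145 | " if (w_new[0]-w_prev[0])**2 + (w_new[1]-w_prev[1])**2 + (w_new[2]-w_prev[2])**2 < 1e-6:\n",
146 | " return w_new\n",
147 | " # NOTE(review): the convergence tolerance on the test above was lost in extraction; 1e-6 is assumed - confirm against the original notebook\n",
148 | " if j>100: \n",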
149 | " return w_new\n",
150 | " j+=1"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {},
156 | "source": [
157 | "## Initializing Parameters"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 38,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "w=[1,1,1] "
167 | ]
168 | },
169 | {
170 | "cell_type": "markdown",
171 | "metadata": {},
172 | "source": [
173 | "## Training the Model"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": 39,
179 | "metadata": {},
180 | "outputs": [
181 | {
182 | "name": "stdout",
183 | "output_type": "stream",
184 | "text": [
185 | "[1, 1, 1]\n",
186 | "126.96627984087802\n",
187 | "[1.2539422898588644, -0.5512710240837779, 1.1843328109648115]\n",
188 | "112.81815690482537\n",
189 | "[1.2400289717256012, 1.3365539538600526, 1.3041857280167006]\n",
190 | "168.5749357223354\n",
191 | "[1.496038556709955, -0.26091766794801896, 1.498589948744304]\n",
192 | "75.77036936925771\n",
193 | "[1.552998152697941, 0.5805338886016641, 1.2059963033700993]\n",
194 | "67.88063552779857\n",
195 | "[1.7129970823802407, -0.12884219161477717, 0.9205552719943874]\n",
196 | "49.631255078140065\n",
197 | "[1.721522374547689, 0.8081464107255731, 0.784948881261807]\n",
198 | "84.65164447415044\n",
199 | "[1.9318934810214365, -0.5853214217947893, 0.9260582355386137]\n",
200 | "137.02236299684375\n",
201 | "[1.792209433897121, 2.259519445068961, 1.8717281361304319]\n",
202 | "296.375960910099\n",
203 | "[2.0717704014100984, 0.4671501798576154, 2.2375086018139254]\n",
204 | "96.21982855549399\n",
205 | "[2.2275891867070294, 0.17556932949773574, 1.7757719140648793]\n",
206 | "66.93614685472798\n",
207 | "[2.34635073219192, 0.10042351859759649, 1.3377088708014218]\n",
208 | "47.88369966457564\n",
209 | "[2.420814222298708, 0.2006649466226462, 0.9803329461122527]\n",
210 | "35.74059972551982\n",
211 | "[2.47109414690982, 0.15761477718705647, 0.7206148121331153]\n",
212 | "31.156227520259954\n",
213 | "[2.4669494157236738, 0.33368915465087157, 0.6530382051207929]\n",
214 | "31.825794017269445\n",
215 | "[2.5018185122166057, 0.052068234061564855, 0.5947139224875355]\n",
216 | "37.13648828690444\n",
217 | "[2.4168290256808196, 0.8618663362335011, 0.8721726092434906]\n",
218 | "77.31322559093016\n",
219 | "[2.596954895024072, -0.39576768312727784, 0.9305031886901464]\n",
220 | "97.53803980330012\n",
221 | "[2.4680716616586964, 1.8009384856190926, 1.4733863829082396]\n",
222 | "210.1650208652238\n",
223 | "[2.7256313646387795, 0.10398042243284711, 1.7582377785757068]\n",
224 | "60.930800473556\n",
225 | "[2.8139079434551295, 0.18042029298994794, 1.356581836503274]\n",
226 | "44.87910324512945\n",
227 | "[2.8764378665294585, 0.18788023515027294, 1.0240956114450044]\n",
228 | "35.1115767203633\n",
229 | "[2.9004091563475756, 0.25373283853477424, 0.8052448292395792]\n",
230 | "31.203934623682823\n",
231 | "[2.905564316215357, 0.23573784244902107, 0.6972405539065145]\n",
232 | "30.64594404630533\n",
233 | "[2.884978695284491, 0.3156866373830539, 0.7013654663287641]\n",
234 | "30.87299094573297\n",
235 | "[2.8899757309235032, 0.18571374355707734, 0.662108112619276]\n",
236 | "31.58765075793203\n",
237 | "[2.8456086105455767, 0.44695414312741255, 0.7472822400464028]\n",
238 | "34.997339236852504\n",
239 | "[2.8946903864992204, 0.0311385451581056, 0.652724208526957]\n",
240 | "41.16553917784135\n",
241 | "[2.78333314375564, 1.0205125031757631, 0.9858502424130702]\n",
242 | "90.41042650497221\n",
243 | "[2.970171901127539, -0.2925927602835001, 1.054289078329687]\n",
244 | "77.28627608547319\n",
245 | "[2.8765338440725268, 1.3870481637289969, 1.2593201244024514]\n",
246 | "137.884699579421\n",
247 | "[3.096205693561861, -0.09815951484676533, 1.3875700666825428]\n",
248 | "55.28433950652438\n",
249 | "[3.091222620686947, 0.691357878366354, 1.1851902704372872]\n",
250 | "54.56897648020144\n",
251 | "[3.197031648206172, 0.04853253754001363, 0.9522363399023503]\n",
252 | "38.879877723561194\n",
253 | "[3.152849114551307, 0.6705086356955698, 0.9042855300104767]\n",
254 | "47.42326651334623\n",
255 | "[3.249401718635574, -0.05586144519042935, 0.7968950006624693]\n",
256 | "50.77734982683568\n",
257 | "[3.1265348460183477, 1.2297186855362119, 1.1081714720550426]\n",
258 | "111.36985561835434\n",
259 | "[3.326563756861511, -0.1791741101134301, 1.2218462740483753]\n",
260 | "62.108694704750796\n",
261 | "[3.269128921912794, 1.0293263482837012, 1.2008084141382773]\n",
262 | "82.99082598516479\n",
263 | "[3.433600163161756, -0.10416961202267583, 1.1426909126167342]\n",
264 | "53.969099415361306\n",
265 | "[3.372434392083106, 0.9606133194736919, 1.1258764505552485]\n",
266 | "73.24867011604414\n",
267 | "[3.5227671344926526, -0.10541269443776047, 1.0639829014706683]\n",
268 | "54.95877911862518\n",
269 | "[3.4402314176940667, 1.0678781146572354, 1.1359656744505753]\n",
270 | "84.39209613357035\n",
271 | "[3.6053162123525917, -0.1094419475498285, 1.1193961332548408]\n",
272 | "55.78053131843188\n",
273 | "[3.5284998867311446, 1.035838685329749, 1.1603344516799623]\n",
274 | "79.31512100875807\n",
275 | "[3.6850794814765253, -0.0823331021711422, 1.117327385696941]\n",
276 | "53.365825253435844\n",
277 | "[3.6094245208717335, 0.990136688364013, 1.1479057502919368]\n",
278 | "73.00073094363131\n",
279 | "[3.7550639240710013, -0.06190456364194685, 1.0874826200950347]\n",
280 | "51.85370536629623\n",
281 | "[3.6740149106228186, 0.9863883933259856, 1.1383256238195103]\n",
282 | "71.6077027550773\n",
283 | "[3.816365454730692, -0.053415680920452235, 1.079972647866281]\n",
284 | "51.512482717018116\n",
285 | "[3.7321415001226725, 0.9874518828741854, 1.1419582058166857]\n",
286 | "70.91147940312463\n",
287 | "[3.8721953817202945, -0.040975086743274325, 1.0817922008147334]\n",
288 | "50.68840558905084\n",
289 | "[3.7880511473342113, 0.9676622239349191, 1.1410140920986547]\n",
290 | "68.11602429098141\n",
291 | "[3.9218160093807857, -0.02032252700940551, 1.069864113475001]\n",
292 | "49.12965800771219\n",
293 | "[3.8372555826614154, 0.9397003151318558, 1.1301065005574003]\n",
294 | "64.66526827228455\n",
295 | "[3.963079147188273, 0.0015189750142303726, 1.048002481200122]\n",
296 | "47.55953959619178\n",
297 | "[3.8768575558853064, 0.9161026456599739, 1.116464165382025]\n",
298 | "61.870753988367085\n",
299 | "[3.995752432150913, 0.019597576030487285, 1.027414752411626]\n",
300 | "46.36190019402692\n",
301 | "[3.907745460549524, 0.89829389197684, 1.1055063163003485]\n",
302 | "59.82357707095895\n",
303 | "[4.0211666051250194, 0.03440304270300665, 1.0117153695411087]\n",
304 | "45.42762114954028\n",
305 | "[3.932061436494771, 0.8820448670623741, 1.0967233693735197]\n",
306 | "58.07265242127726\n",
307 | "[4.040538211472588, 0.0486222507940266, 0.9980600115598323]\n",
308 | "44.528626105329884\n",
309 | "[3.95104075816509, 0.8635822178845587, 1.087464778932447]\n",
310 | "56.25191093140636\n",
311 | "[4.0542064633885655, 0.06390897971924736, 0.9830845369826282]\n",
312 | "43.560297250359895\n",
313 | "[3.9647146116297285, 0.8419896969170714, 1.0762339319145042]\n",
314 | "54.28859760963312\n",
315 | "[4.061974777022511, 0.08010117719043697, 0.9658476947904583]\n",
316 | "42.55609321933739\n",
317 | "[3.9726570293297416, 0.8185528342778258, 1.0632653646122094]\n",
318 | "52.30757518813522\n",
319 | "[4.063741047353516, 0.09591304118161548, 0.9474917553090927]\n",
320 | "41.60869408164448\n",
321 | "[3.9746698824165576, 0.7955306511635729, 1.0497944865210715]\n",
322 | "50.49160062072334\n",
323 | "[4.0598875620346835, 0.10973094543337403, 0.929938489713735]\n",
324 | "40.807102034142616\n",
325 | "[3.9710681877912237, 0.7754202550430537, 1.0373138237663093]\n",
326 | "49.006821287773285\n",
327 | "[4.051367062032833, 0.120083598318384, 0.9149680408447317]\n",
328 | "40.21376802733146\n",
329 | "[3.9626934207233138, 0.7605664236412619, 1.027175552319945]\n",
330 | "47.97972673492794\n",
331 | "[4.03960146305846, 0.1258626202000106, 0.9038701158750538]\n",
332 | "39.8665848590012\n",
333 | "[3.9508175989838605, 0.7529117846030396, 1.0204483002710898]\n",
334 | "47.49986495259877\n",
335 | "[4.026319016314264, 0.12633822832598507, 0.8975122703008471]\n",
336 | "39.79309105387893\n",
337 | "[3.937022647315148, 0.7539702241394387, 1.0179583095352271]\n",
338 | "47.64232237902202\n",
339 | "[4.013431600037268, 0.12097838008325779, 0.8966869188055052]\n",
340 | "40.03060914399324\n",
341 | "[3.9231193858382776, 0.7650499798316024, 1.0204780871163988]\n",
342 | "48.50370516914084\n",
343 | "[4.0030116224672625, 0.10920033650354244, 0.9025652832558652]\n",
344 | "40.647028942768124\n",
345 | "[3.911168807082645, 0.7874617085157639, 1.0289496414232508]\n",
346 | "50.233655133405684\n",
347 | "[3.9973475854531766, 0.09037455328475408, 0.9170124097333662]\n",
348 | "41.74920834355003\n",
349 | "[3.9036347048532614, 0.8219989928942072, 1.0444815131818708]\n",
350 | "53.00756665023323\n",
351 | "[3.998896442510712, 0.06491278493707686, 0.9421306486550264]\n",
352 | "43.414481126056494\n",
353 | "[3.903578932273919, 0.8658652317559935, 1.0673333907773224]\n",
354 | "56.743118405663786\n",
355 | "[4.009605120891904, 0.03790132000571389, 0.9771595777097101]\n",
356 | "45.367245465512056\n",
357 | "[3.9142374607656074, 0.9053571447819756, 1.0933903406414625]\n",
358 | "60.26445149748922\n",
359 | "[4.028982704706953, 0.022311918251399665, 1.0115367561804436]\n",
360 | "46.56452396428107\n",
361 | "[3.936118279261484, 0.9161216408470005, 1.1107981056735687]\n",
362 | "61.119030592288695\n",
363 | "[4.05221859760455, 0.029367074818002936, 1.0249045662087954]\n",
364 | "46.03001820966252\n",
365 | "[3.962804479787527, 0.890313219168536, 1.1082732172211933]\n",
366 | "58.54793332826261\n",
367 | "[4.071694739617031, 0.05576987084898566, 1.0076561457699287]\n",
368 | "44.146310350361176\n",
369 | "[3.9843650919995444, 0.8441134205516451, 1.087132296857637]\n",
370 | "54.38313567139385\n",
371 | "[4.081221437775141, 0.08935078570508614, 0.9713207459960232]\n",
372 | "42.00179619891018\n",
373 | "[3.9945804561228613, 0.7946114356824799, 1.058305857789166]\n",
374 | "50.363531876210516\n",
375 | "[4.078719695444127, 0.11976411937445264, 0.9330027012706759]\n",
376 | "40.284159514532284\n",
377 | "[3.992398365094767, 0.7520453654158086, 1.0315755035230465]\n",
378 | "47.27452627208367\n",
379 | "[4.065870596044042, 0.14186053507189011, 0.9025860178561832]\n",
380 | "39.138492702670455\n",
381 | "[3.9799787694173245, 0.7210735715432988, 1.0111931600659616]\n",
382 | "45.255334753917104\n",
383 | "[4.0460976445817955, 0.15448991706397763, 0.8819752531296552]\n",
384 | "38.48142071458852\n",
385 | "[3.960514974560398, 0.7030948526309135, 0.9977984989102618]\n",
386 | "44.1983732970012\n",
387 | "[4.022890960756249, 0.1583065776331426, 0.8696992149881329]\n",
388 | "38.21770810120381\n",
389 | "[3.937085053426291, 0.6982103483166531, 0.9909832248890901]\n",
390 | "43.98944447469072\n",
391 | "[3.937085053426291, 0.6982103483166531, 0.9909832248890901]\n"
392 | ]
393 | }
394 | ],
395 | "source": [
396 | "w = descent(w,w,.0099)\n",
397 | "print(w)"
398 | ]
399 | },
400 | {
401 | "cell_type": "markdown",
402 | "metadata": {},
403 | "source": [
404 | "## Visualizing the Result"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 40,
410 | "metadata": {},
411 | "outputs": [
412 | {
413 | "data": {
414 | "image/png": "\n",
415 | "text/plain": [
416 | ""
417 | ]
418 | },
419 | "metadata": {},
420 | "output_type": "display_data"
421 | }
422 | ],
423 | "source": [
424 | "def graph(formula, x_range): \n",
425 | " x = np.array(x_range) \n",
426 | " y = formula(x) \n",
427 | " plt.plot(x, y) \n",
428 | " \n",
429 | "def my_formula(x):\n",
430 | " return (-w[0]-w[1]*x)/w[2]\n",
431 | "\n",
432 | "from matplotlib import pyplot as plt\n",
433 | "from pandas import DataFrame \n",
434 | "df = DataFrame(dict(x=X[:,0], y=X[:,1], label=Y))\n",
435 | "colors = {0:'blue', 1:'orange'}\n",
436 | "fig, ax = plt.subplots()\n",
437 | "grouped = df.groupby('label')\n",
438 | "for key, group in grouped:\n",
439 | " group.plot(ax=ax, kind='scatter', x='x', y='y', label=key, color=colors[key])\n",
440 | "graph(my_formula, range(-20,15))\n",
441 | "plt.xlabel('X_0')\n",
442 | "plt.ylabel('X_1')\n",
443 | "plt.show()"
444 | ]
445 | }
446 | ],
447 | "metadata": {
448 | "kernelspec": {
449 | "display_name": "Python 3",
450 | "language": "python",
451 | "name": "python3"
452 | },
453 | "language_info": {
454 | "codemirror_mode": {
455 | "name": "ipython",
456 | "version": 3
457 | },
458 | "file_extension": ".py",
459 | "mimetype": "text/x-python",
460 | "name": "python",
461 | "nbconvert_exporter": "python",
462 | "pygments_lexer": "ipython3",
463 | "version": "3.6.5"
464 | }
465 | },
466 | "nbformat": 4,
467 | "nbformat_minor": 2
468 | }
469 |
--------------------------------------------------------------------------------
/sgd.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Creating Fake Data"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "from sklearn.utils import shuffle\n",
18 | "from sklearn.datasets import make_blobs\n",
19 | "X, Y = make_blobs(n_samples=300, centers=2, n_features=2, cluster_std=5, random_state=11)"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 | "## Visualizing the Data"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 3,
32 | "metadata": {},
33 | "outputs": [
34 | {
35 | "data": {
36 | "image/png": "\n",
37 | "text/plain": [
38 | ""
39 | ]
40 | },
41 | "metadata": {},
42 | "output_type": "display_data"
43 | }
44 | ],
45 | "source": [
46 | "from matplotlib import pyplot as plt\n",
47 | "from pandas import DataFrame \n",
48 | "df = DataFrame(dict(x=X[:,0], y=X[:,1], label=Y))\n",
49 | "colors = {0:'blue', 1:'orange'}\n",
50 | "fig, ax = plt.subplots()\n",
51 | "grouped = df.groupby('label')\n",
52 | "for key, group in grouped:\n",
53 | " group.plot(ax=ax, kind='scatter', x='x', y='y', label=key, color=colors[key])\n",
54 | "plt.xlabel('X_1')\n",
55 | "plt.ylabel('X_2')\n",
56 | "plt.show()"
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {},
62 | "source": [
63 | "## Splitting into batches"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 4,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "def next_batch(X, Y, batch_size):\n",
73 | " for i in np.arange(0, X.shape[0], batch_size):\n",
74 | " yield (X[i:i + batch_size], Y[i:i + batch_size])"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {},
80 | "source": [
81 | "## Adding a column of 1's (bias term)"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 5,
87 | "metadata": {},
88 | "outputs": [
89 | {
90 | "data": {
91 | "text/plain": [
92 | "(300, 3)"
93 | ]
94 | },
95 | "execution_count": 5,
96 | "metadata": {},
97 | "output_type": "execute_result"
98 | }
99 | ],
100 | "source": [
101 | "X = np.c_[np.ones((X.shape[0])), X]\n",
102 | "X.shape"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "metadata": {},
108 | "source": [
109 | "## Logistic Model"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 6,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "import numpy as np\n",
119 | "def sigmoid(z):\n",
120 | " return 1 / (1 + np.exp(-z))"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 7,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "def hx(W,X):\n",
130 | " return sigmoid(np.dot(X,W))"
131 | ]
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "metadata": {},
136 | "source": [
137 | "## Cost Function - Binary Cross Entropy "
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 8,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "def cost(W, X, Y):\n",
147 | " y_pred = hx(W,X)\n",
148 | " return -1 * sum(Y*np.log(y_pred) + (1-Y)*np.log(1-y_pred))"
149 | ]
150 | },
151 | {
152 | "cell_type": "markdown",
153 | "metadata": {},
154 | "source": [
155 | "## Stochastic Gradient Descent"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 9,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "def grad(W, X, Y):\n",
165 | " y_pred = hx(W,X)\n",
166 | " A = (Y*(1-y_pred) - (1-Y)*y_pred)\n",
167 | " g = -1* np.dot(A.T,X)\n",
168 | " return g"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 10,
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "def sgd(W_new, W_prev, lr, batch_size, epochs):\n",
178 | " X_, Y_ = shuffle(X, Y, random_state=0)\n",
179 | " for e in range(epochs):\n",
180 | " epoch_loss = []\n",
181 | " X_, Y_ = shuffle(X_, Y_, random_state=0)\n",
182 | " for (batchX, batchY) in next_batch(X_, Y_, batch_size):\n",
183 | " W_prev = W_new\n",
184 | " epoch_loss.append(cost(W_prev, batchX, batchY))\n",
185 | " gradients = grad(W_prev, batchX, batchY)\n",
186 | " W_new = W_prev - lr*gradients\n",
187 | " print(np.average(epoch_loss))\n",
188 | " return W_new"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "metadata": {},
194 | "source": [
195 | "## Initializing Weights & Bias"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 11,
201 | "metadata": {},
202 | "outputs": [
203 | {
204 | "data": {
205 | "text/plain": [
206 | "(3,)"
207 | ]
208 | },
209 | "execution_count": 11,
210 | "metadata": {},
211 | "output_type": "execute_result"
212 | }
213 | ],
214 | "source": [
215 | "W = np.random.uniform(size=(X.shape[1],))\n",
216 | "W.shape"
217 | ]
218 | },
219 | {
220 | "cell_type": "markdown",
221 | "metadata": {},
222 | "source": [
223 | "## Training the Model"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": 12,
229 | "metadata": {},
230 | "outputs": [
231 | {
232 | "name": "stdout",
233 | "output_type": "stream",
234 | "text": [
235 | "[0.79618896 0.4163837 0.42391672]\n",
236 | "6.277384446178855\n",
237 | "5.743854541290189\n",
238 | "5.639645113634937\n",
239 | "5.470853335391083\n",
240 | "5.667962099033007\n",
241 | "5.28981499238036\n",
242 | "5.271342725834096\n",
243 | "5.060368761481374\n",
244 | "5.446976524735743\n",
245 | "5.066447704414417\n",
246 | "5.09547652563199\n",
247 | "5.027047324020139\n",
248 | "5.231413878163542\n",
249 | "4.988553649645022\n",
250 | "5.045966239495252\n",
251 | "5.114093205881277\n",
252 | "5.24313649711587\n",
253 | "5.3309800879511595\n",
254 | "5.2369933729138705\n",
255 | "4.931357222546931\n",
256 | "5.141072913940974\n",
257 | "4.865177946360237\n",
258 | "5.159259956251705\n",
259 | "4.9850013574989935\n",
260 | "5.075992543432262\n",
261 | "5.257537980731875\n",
262 | "5.15220734911969\n",
263 | "5.376113028226739\n",
264 | "5.01790696023796\n",
265 | "4.91743150129637\n",
266 | "5.155592756181093\n",
267 | "4.778496875740468\n",
268 | "5.170830308497642\n",
269 | "4.982057737386109\n",
270 | "5.065383224896023\n",
271 | "5.130301783329287\n",
272 | "5.0560422794900335\n",
273 | "4.751542045996459\n",
274 | "5.044836198314092\n",
275 | "5.200065208594642\n",
276 | "5.285956700133218\n",
277 | "4.69607623817095\n",
278 | "5.1110131343684895\n",
279 | "4.929261653440459\n",
280 | "5.508577234972032\n",
281 | "4.9295442093408335\n",
282 | "5.35853341436791\n",
283 | "5.109882297828997\n",
284 | "5.268900506923824\n",
285 | "5.0553568625861995\n",
286 | "5.127108057053507\n",
287 | "5.256180863878795\n",
288 | "5.2017470229439295\n",
289 | "5.29706022610866\n",
290 | "4.842578732589242\n",
291 | "5.216497984078021\n",
292 | "5.167172158840554\n",
293 | "5.030913544155235\n",
294 | "4.9208154561894215\n",
295 | "5.0286895579435775\n",
296 | "5.254927983516582\n",
297 | "5.01657084246088\n",
298 | "4.902782163580072\n",
299 | "4.952060238428389\n",
300 | "4.9191214387250035\n",
301 | "5.027098170280329\n",
302 | "5.077946337475286\n",
303 | "4.880098287861289\n",
304 | "4.944972140093644\n",
305 | "5.052666303168795\n",
306 | "5.081615975832211\n",
307 | "5.182569862898558\n",
308 | "4.9697239615921625\n",
309 | "5.056706134894904\n",
310 | "5.110368058329958\n",
311 | "4.996883358551884\n",
312 | "4.989696640078847\n",
313 | "4.858898302441649\n",
314 | "4.930715531754518\n",
315 | "5.009314794537656\n",
316 | "5.127984153485331\n",
317 | "5.0231049124602825\n",
318 | "5.254008279820028\n",
319 | "5.014073801135062\n",
320 | "4.97986671455312\n",
321 | "5.05243719762335\n",
322 | "5.0420997708089415\n",
323 | "5.069451589063102\n",
324 | "5.150257226697879\n",
325 | "5.164400634677876\n",
326 | "5.039134823534483\n",
327 | "4.9883474157043795\n",
328 | "5.132433782589542\n",
329 | "4.891167385232467\n",
330 | "5.062085674381252\n",
331 | "5.2087055067642165\n",
332 | "4.955120478076216\n",
333 | "4.88214965225135\n",
334 | "4.935119155279668\n",
335 | "5.4527388101645755\n",
336 | "4.902647474574163\n",
337 | "4.973218626980669\n",
338 | "4.969125747240431\n",
339 | "5.034819845305565\n",
340 | "5.11158072777152\n",
341 | "5.052709686640736\n",
342 | "4.968840505649011\n",
343 | "5.088612053925859\n",
344 | "4.915818444090352\n",
345 | "5.093157727293972\n",
346 | "5.033660011718062\n",
347 | "5.223412724166921\n",
348 | "5.1077816665230475\n",
349 | "5.135494856877415\n",
350 | "4.989777059244171\n",
351 | "5.165527490515357\n",
352 | "5.327764783138228\n",
353 | "5.019080784749545\n",
354 | "5.0239554594473805\n",
355 | "5.026682139880581\n",
356 | "5.159710856114193\n",
357 | "5.027795318281852\n",
358 | "4.988648104694423\n",
359 | "5.012928309587274\n",
360 | "5.081666714554431\n",
361 | "5.272117622609143\n",
362 | "4.851869972177484\n",
363 | "4.956969286562633\n",
364 | "5.025160410234134\n",
365 | "4.993677569943672\n",
366 | "5.099870584911604\n",
367 | "5.096181311604027\n",
368 | "5.236923964162504\n",
369 | "4.903674873885619\n",
370 | "5.075807820064275\n",
371 | "4.701673307737802\n",
372 | "5.299083765240716\n",
373 | "5.112201692629844\n",
374 | "5.057198863500682\n",
375 | "5.213612968874003\n",
376 | "5.1198465312517225\n",
377 | "4.928097399784619\n",
378 | "4.991964359552881\n",
379 | "4.994824548764347\n",
380 | "5.049272339126923\n",
381 | "4.942713247937617\n",
382 | "4.91913758023255\n",
383 | "5.43143326229052\n",
384 | "5.116114463071936\n",
385 | "4.9625898877661445\n",
386 | "5.042137105313292\n",
387 | "5.157805681662321\n",
388 | "5.04896325685934\n",
389 | "5.018982423826876\n",
390 | "5.185034061711793\n",
391 | "5.032465453425752\n",
392 | "4.897108003638827\n",
393 | "4.95421678149162\n",
394 | "5.259006286037971\n",
395 | "5.182717981721985\n",
396 | "4.884826442197778\n",
397 | "5.156842877561891\n",
398 | "5.047645060816286\n",
399 | "5.050121866311711\n",
400 | "4.869794749038038\n",
401 | "4.9957190876138835\n",
402 | "5.022988953794398\n",
403 | "5.002839753440677\n",
404 | "4.881319490767778\n",
405 | "5.055591006717888\n",
406 | "5.015670050297785\n",
407 | "5.013392499945725\n",
408 | "4.989724363452584\n",
409 | "4.880412359751283\n",
410 | "5.416224099805669\n",
411 | "5.2867392415724135\n",
412 | "5.11414942585176\n",
413 | "5.079966239068073\n",
414 | "4.925427043815789\n",
415 | "5.145176957835869\n",
416 | "5.130160451160554\n",
417 | "5.024814422423505\n",
418 | "4.958769936817604\n",
419 | "4.81843182784822\n",
420 | "5.34842232805925\n",
421 | "5.279335054775109\n",
422 | "5.107921381756868\n",
423 | "5.15079791674754\n",
424 | "5.096406330816437\n",
425 | "5.00585692466524\n",
426 | "5.200218860079871\n",
427 | "4.96229644142354\n",
428 | "4.937523034036624\n",
429 | "5.174090668819386\n",
430 | "5.1400977102309175\n",
431 | "5.1319926651455186\n",
432 | "5.048127875476851\n",
433 | "5.416048064604175\n",
434 | "5.057120736099771\n",
435 | "4.864273959307718\n",
436 | "[2.38150361 0.22263106 0.53666744]\n"
437 | ]
438 | }
439 | ],
440 | "source": [
441 | "print(W)\n",
442 | "W = sgd(W, W, .009, 32, 200)\n",
443 | "print(W)"
444 | ]
445 | },
446 | {
447 | "cell_type": "markdown",
448 | "metadata": {},
449 | "source": [
450 | "## Visualizing the Result"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": 13,
456 | "metadata": {},
457 | "outputs": [
458 | {
459 | "data": {
460 | "image/png": "\n",
461 | "text/plain": [
462 | ""
463 | ]
464 | },
465 | "metadata": {},
466 | "output_type": "display_data"
467 | }
468 | ],
469 | "source": [
470 | "def graph(formula, x_range): \n",
471 | " x = np.array(x_range) \n",
472 | " y = formula(x) \n",
473 | " plt.plot(x, y) \n",
474 | " \n",
475 | "def my_formula(x):\n",
476 | " return (-W[0]-W[1]*x)/W[2]\n",
477 | "\n",
478 | "from matplotlib import pyplot as plt\n",
479 | "from pandas import DataFrame \n",
480 | "df = DataFrame(dict(x=X[:,1], y=X[:,2], label=Y))\n",
481 | "colors = {0:'blue', 1:'orange'}\n",
482 | "fig, ax = plt.subplots()\n",
483 | "grouped = df.groupby('label')\n",
484 | "for key, group in grouped:\n",
485 | " group.plot(ax=ax, kind='scatter', x='x', y='y', label=key, color=colors[key])\n",
486 | "graph(my_formula, range(-20,15))\n",
487 | "plt.xlabel('X_1')\n",
488 | "plt.ylabel('X_2')\n",
489 | "plt.show()"
490 | ]
491 | }
492 | ],
493 | "metadata": {
494 | "kernelspec": {
495 | "display_name": "Python 3",
496 | "language": "python",
497 | "name": "python3"
498 | },
499 | "language_info": {
500 | "codemirror_mode": {
501 | "name": "ipython",
502 | "version": 3
503 | },
504 | "file_extension": ".py",
505 | "mimetype": "text/x-python",
506 | "name": "python",
507 | "nbconvert_exporter": "python",
508 | "pygments_lexer": "ipython3",
509 | "version": "3.6.5"
510 | }
511 | },
512 | "nbformat": 4,
513 | "nbformat_minor": 2
514 | }
515 |
--------------------------------------------------------------------------------