├── LICENSE
├── README.md
├── data
├── abalone.data
├── interaction_dataset.csv
└── mushrooms.csv
├── requirements.txt
└── src
├── additional_resources
├── CatBoostClassifier.ipynb
├── IsolationForest.ipynb
└── RandomForestRegressor.ipynb
├── archive
├── image_data.ipynb
├── project_1.ipynb
├── project_1_solution.ipynb
├── project_2.ipynb
└── project_2_solution.ipynb
├── kernel_vs_tree.ipynb
└── shap_tutorial.ipynb
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Conor O'Sullivan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SHAP-tutorial
2 |
3 | To access the course: https://adataodyssey.com/courses/shap-with-python/
4 |
5 | Watch the course outline here: https://youtu.be/n98pFxcD73w
6 |
--------------------------------------------------------------------------------
/data/interaction_dataset.csv:
--------------------------------------------------------------------------------
1 | "experience" "degree" "performance" "sales" "days_late" "bonus"
2 | 31 1 6.11 29 14 197
3 | 35 1 9.55 44 8 314
4 | 9 1 2.64 26 20 88
5 | 40 1 0.22 13 7 233
6 | 18 1 6.46 11 13 108
7 | 0 1 4.57 72 0 101
8 | 31 0 3.38 83 9 61
9 | 14 1 5.63 27 9 112
10 | 37 1 0.03 13 7 183
11 | 16 1 1.1 7 16 92
12 | 36 1 5.3 18 18 225
13 | 11 1 5.35 17 18 107
14 | 34 1 3.88 53 4 234
15 | 28 1 1.52 58 16 161
16 | 5 1 2.96 33 7 52
17 | 13 1 8.01 23 17 129
18 | 1 0 1.02 71 17 17
19 | 29 1 7.08 99 1 346
20 | 17 1 9.15 61 0 227
21 | 2 1 4.43 14 5 37
22 | 20 1 7.28 82 2 263
23 | 34 1 2.45 97 6 260
24 | 16 1 6.03 64 2 182
25 | 39 1 4.9 5 3 235
26 | 18 1 4.87 87 9 203
27 | 8 1 7.12 28 6 110
28 | 24 0 3.37 16 0 14
29 | 33 0 9.76 52 6 134
30 | 38 1 7.55 46 11 278
31 | 24 0 1.9 61 20 22
32 | 27 1 1.55 57 15 152
33 | 37 0 6.24 42 3 88
34 | 37 1 8.07 30 17 243
35 | 29 0 0.1 14 4 0
36 | 11 1 0.18 6 19 62
37 | 37 1 4.6 33 5 259
38 | 8 1 9.11 11 16 78
39 | 21 0 9.87 81 14 196
40 | 23 0 9.43 99 3 213
41 | 31 1 1.01 51 0 193
42 | 21 1 5.1 90 19 225
43 | 7 1 9.08 77 11 220
44 | 26 1 5.97 68 15 250
45 | 10 1 1.98 24 20 49
46 | 33 1 8.98 58 5 310
47 | 8 1 1.85 13 6 72
48 | 3 0 9.95 42 1 107
49 | 10 1 7.01 9 19 75
50 | 5 1 9.83 8 13 85
51 | 32 1 1.99 8 16 181
52 | 1 1 1.41 25 4 38
53 | 21 1 9.16 35 10 221
54 | 1 0 2.35 80 19 45
55 | 36 0 9.76 6 9 10
56 | 10 0 8.26 52 9 122
57 | 10 1 9.4 4 12 90
58 | 30 1 5.4 79 8 269
59 | 31 1 6.06 93 6 297
60 | 18 1 1.08 23 16 100
61 | 20 0 7.94 13 12 42
62 | 40 1 4 35 12 259
63 | 36 1 3.5 93 19 285
64 | 13 1 3.69 3 17 66
65 | 11 0 7.29 91 11 159
66 | 16 1 4.61 61 6 178
67 | 9 1 7.62 90 9 230
68 | 4 1 4.63 31 1 90
69 | 25 0 2.39 9 8 0
70 | 26 0 3.25 3 6 11
71 | 20 1 7.69 49 18 189
72 | 32 1 8.42 36 14 266
73 | 36 0 6.82 24 7 49
74 | 38 1 1.98 90 15 258
75 | 21 1 2.27 37 14 132
76 | 28 1 5.12 32 2 214
77 | 29 1 6.41 10 14 190
78 | 40 0 1.46 35 0 3
79 | 23 1 8.28 50 15 211
80 | 24 0 6.03 76 1 114
81 | 0 0 2.61 90 11 51
82 | 31 0 1.62 52 5 13
83 | 36 0 8.29 3 6 6
84 | 19 0 3.75 53 11 51
85 | 33 1 1.94 56 5 235
86 | 15 1 9.72 96 5 323
87 | 14 1 3.2 18 2 109
88 | 23 1 2.47 57 20 156
89 | 38 0 4.73 40 4 59
90 | 30 1 6.03 66 0 263
91 | 22 1 4.55 34 0 182
92 | 29 1 6.04 75 10 291
93 | 6 1 0.43 30 4 42
94 | 33 1 1.86 68 1 221
95 | 37 0 8.79 39 9 70
96 | 0 1 5.5 26 5 53
97 | 10 1 1.88 87 19 97
98 | 33 0 4.14 31 10 12
99 | 16 0 2.22 26 11 4
100 | 25 1 5.14 39 20 203
101 | 36 1 1.8 16 18 211
102 | 18 1 2.24 9 1 135
103 | 24 0 0.95 90 4 52
104 | 5 1 6.63 95 13 192
105 | 7 0 1.09 56 19 1
106 | 21 1 6.42 99 15 283
107 | 3 1 3.09 49 16 84
108 | 26 1 5.47 19 7 159
109 | 38 1 1.86 54 0 225
110 | 39 0 3.89 91 0 96
111 | 28 1 1.43 86 5 199
112 | 27 1 8.35 59 16 257
113 | 13 1 4.02 12 17 102
114 | 5 1 2.78 7 19 26
115 | 10 1 0.63 90 20 65
116 | 33 1 2.73 29 0 198
117 | 32 1 8.67 99 13 386
118 | 1 1 8.68 18 6 86
119 | 23 1 5.75 68 15 214
120 | 22 1 4.16 3 13 109
121 | 33 1 4.59 82 2 284
122 | 11 1 0.7 67 9 71
123 | 35 0 1.17 50 15 0
124 | 7 1 4.84 61 15 110
125 | 4 1 1.79 57 10 48
126 | 13 1 7.69 55 19 179
127 | 15 1 4.63 100 6 217
128 | 27 1 9.6 77 3 348
129 | 37 1 8.02 54 16 284
130 | 4 1 9.6 15 12 90
131 | 11 1 0.33 5 19 40
132 | 3 1 2.98 21 19 47
133 | 9 1 3.68 93 13 146
134 | 23 1 9.36 37 16 228
135 | 26 1 1.35 65 16 174
136 | 2 1 7.83 60 16 133
137 | 18 1 5.13 69 5 180
138 | 21 1 4.81 47 1 193
139 | 6 1 8.18 93 0 253
140 | 23 1 0.64 62 16 154
141 | 35 1 5.31 20 9 236
142 | 17 1 9.63 72 13 248
143 | 12 1 0.61 29 19 68
144 | 35 1 3.56 59 16 247
145 | 30 0 5.98 92 7 145
146 | 7 1 1.8 38 2 73
147 | 16 1 1.7 60 5 149
148 | 30 0 8.69 12 18 28
149 | 14 1 8.87 10 6 141
150 | 34 1 7.31 17 13 211
151 | 11 0 9.67 14 12 41
152 | 19 0 9.81 33 4 91
153 | 20 1 1.1 87 16 134
154 | 17 1 2.43 54 8 130
155 | 22 1 5.11 39 16 162
156 | 16 1 5.69 60 16 191
157 | 27 0 2.13 42 0 26
158 | 31 0 6.35 98 19 149
159 | 36 0 3.09 61 7 41
160 | 19 1 3.8 36 8 131
161 | 22 0 6.6 7 14 2
162 | 7 1 4.1 88 0 169
163 | 4 1 1.26 83 14 69
164 | 20 1 8.97 73 12 256
165 | 11 0 8.85 11 0 53
166 | 16 1 2.45 16 8 110
167 | 4 1 2.31 54 10 58
168 | 6 1 4.07 44 5 81
169 | 31 1 4.83 18 13 199
170 | 32 1 7.19 39 8 265
171 | 33 1 9.45 4 5 213
172 | 20 0 9.1 61 1 123
173 | 0 1 6.69 100 15 173
174 | 33 0 1.82 44 9 2
175 | 2 1 5.19 91 17 135
176 | 26 0 3.02 35 8 9
177 | 12 0 9.26 62 9 151
178 | 27 1 0.08 30 19 127
179 | 11 1 2.03 100 8 121
180 | 8 1 3.8 79 11 146
181 | 25 1 0.12 39 15 131
182 | 31 1 5.54 92 18 288
183 | 0 1 5.35 44 7 73
184 | 21 1 7.61 53 0 250
185 | 19 1 6.33 59 5 213
186 | 32 0 7.5 6 3 34
187 | 39 0 5.81 47 11 53
188 | 21 1 7.51 38 19 197
189 | 7 1 1.1 61 19 56
190 | 13 1 3.53 71 16 137
191 | 13 1 7.14 57 0 209
192 | 32 1 1.16 58 14 210
193 | 28 1 9.33 56 8 269
194 | 12 1 5.77 35 3 150
195 | 34 1 4.6 63 6 247
196 | 31 0 3.41 48 4 41
197 | 20 1 7 75 14 243
198 | 35 1 6.28 27 9 244
199 | 38 1 6.78 10 7 240
200 | 39 0 5.95 16 1 32
201 | 22 1 5.25 70 18 222
202 | 9 1 7.16 27 5 128
203 | 8 1 4.68 52 7 102
204 | 10 1 3.92 45 9 99
205 | 29 1 6.63 99 11 338
206 | 12 1 2.97 55 5 114
207 | 14 1 1.16 35 18 99
208 | 24 0 5.85 93 8 119
209 | 22 1 1.52 11 18 132
210 | 36 1 4.45 10 18 186
211 | 8 0 0.65 27 13 0
212 | 9 1 2.17 94 13 111
213 | 37 1 1.44 41 6 240
214 | 14 1 5.24 1 1 109
215 | 2 1 0.97 60 12 28
216 | 24 1 4.17 89 0 232
217 | 7 1 8.69 70 4 187
218 | 1 0 5.83 35 13 48
219 | 1 1 5.91 29 9 58
220 | 24 1 8.64 94 2 339
221 | 24 1 8.93 43 17 246
222 | 20 1 4.55 11 7 127
223 | 5 1 3.29 75 17 89
224 | 23 1 7.99 13 15 149
225 | 23 0 7.52 23 12 57
226 | 18 1 3.27 2 4 112
227 | 37 0 1.06 80 10 24
228 | 5 1 6.36 6 19 61
229 | 3 0 3.99 74 2 92
230 | 5 1 2.31 8 0 72
231 | 34 1 7.45 53 4 292
232 | 21 1 9.25 83 3 316
233 | 8 1 1.04 9 10 67
234 | 35 0 4.3 37 0 42
235 | 38 0 2.97 58 16 40
236 | 28 1 5.47 27 19 169
237 | 24 1 3.67 7 8 163
238 | 11 1 2.68 61 4 108
239 | 26 0 5.22 82 13 115
240 | 40 1 7.92 33 12 279
241 | 36 0 0.76 87 17 5
242 | 25 1 7.47 54 17 244
243 | 27 0 9.04 13 14 34
244 | 18 1 1.8 98 19 145
245 | 13 1 1.09 19 20 71
246 | 8 1 9.89 25 8 129
247 | 31 0 0.58 91 4 20
248 | 14 1 1.41 5 19 70
249 | 17 1 8.95 88 19 281
250 | 17 1 5.96 53 17 173
251 | 11 1 4.63 42 5 121
252 | 28 0 8.44 98 5 189
253 | 4 0 9.54 45 8 111
254 | 37 1 0.79 41 5 217
255 | 39 0 8.65 27 11 52
256 | 23 1 1.82 50 4 151
257 | 12 1 0.22 6 18 61
258 | 38 0 3.45 99 8 92
259 | 28 1 9.74 90 3 361
260 | 40 0 6.29 24 10 28
261 | 14 1 3.34 55 12 141
262 | 16 1 1.27 27 0 130
263 | 14 1 5.34 58 18 161
264 | 23 1 5.01 31 4 190
265 | 37 1 6.75 3 5 229
266 | 27 1 2 74 12 195
267 | 30 1 1.4 26 17 151
268 | 15 1 7.61 4 0 104
269 | 19 0 0.14 30 9 7
270 | 39 1 6.57 74 0 337
271 | 27 1 7.8 34 1 244
272 | 27 1 2.31 11 1 178
273 | 23 0 7.74 69 8 137
274 | 23 0 6.15 62 13 86
275 | 2 1 7.14 91 10 189
276 | 37 1 9.4 62 3 335
277 | 4 1 1.86 59 19 46
278 | 2 0 8.85 51 18 108
279 | 13 1 2.58 61 12 128
280 | 34 0 5.9 5 8 28
281 | 17 0 5.89 5 17 16
282 | 20 1 6.14 33 4 160
283 | 26 1 0.04 5 11 133
284 | 25 1 8.71 9 7 192
285 | 25 1 0.17 45 5 133
286 | 19 1 6.56 20 6 165
287 | 32 1 5.25 4 8 170
288 | 33 1 1.78 92 20 227
289 | 0 1 9.19 12 20 46
290 | 13 1 3.24 8 7 90
291 | 6 1 9.74 97 0 297
292 | 36 1 6.6 61 8 295
293 | 15 1 9.05 45 14 187
294 | 28 1 1.25 41 4 183
295 | 7 1 0.95 53 9 84
296 | 6 1 3.51 12 2 50
297 | 29 1 4.15 83 18 226
298 | 30 1 9.46 44 6 270
299 | 36 1 7.39 51 8 297
300 | 9 1 2.58 91 20 106
301 | 20 0 2.54 93 11 66
302 | 26 0 7.77 39 4 72
303 | 37 1 4.39 47 13 232
304 | 2 1 1.16 50 5 46
305 | 11 1 5.96 22 4 102
306 | 40 1 2.72 94 5 303
307 | 26 1 7.52 65 0 273
308 | 11 0 8.07 15 17 35
309 | 36 0 4.77 11 12 15
310 | 9 0 6 58 17 89
311 | 40 1 4.66 20 20 238
312 | 35 1 7.77 74 3 327
313 | 29 1 3.68 11 13 172
314 | 17 1 7.37 93 6 255
315 | 18 0 2.9 22 6 4
316 | 31 1 9.02 100 20 386
317 | 26 1 4.34 16 2 183
318 | 32 1 5.13 47 18 244
319 | 10 1 5.48 92 9 206
320 | 25 0 9.11 68 7 134
321 | 34 0 5.22 11 8 0
322 | 38 0 6.66 67 6 124
323 | 3 0 6.36 7 18 0
324 | 10 1 9.17 21 3 146
325 | 29 0 9.33 3 1 15
326 | 28 1 2.64 46 13 185
327 | 10 1 4.1 49 12 114
328 | 10 0 5.46 98 11 126
329 | 22 1 4.07 100 8 237
330 | 7 1 0.25 27 9 41
331 | 22 1 7.75 37 10 207
332 | 24 1 2.8 12 7 138
333 | 38 1 7.12 4 14 233
334 | 18 1 8.4 7 12 123
335 | 32 0 4.34 10 1 33
336 | 10 1 0.47 21 9 83
337 | 5 1 2.87 12 20 35
338 | 40 1 1.21 61 10 252
339 | 31 1 6.12 81 6 278
340 | 1 1 9.8 63 20 143
341 | 11 1 3.42 83 13 136
342 | 37 0 0.86 45 0 2
343 | 22 0 1.73 66 14 19
344 | 32 1 9.73 59 3 334
345 | 25 0 2.95 81 7 78
346 | 29 1 0.65 56 17 174
347 | 14 0 0.82 48 10 2
348 | 14 1 5.17 90 5 202
349 | 23 1 0.54 30 8 149
350 | 29 0 0.8 49 18 9
351 | 27 1 0.3 30 20 145
352 | 35 1 8.11 79 5 340
353 | 17 1 9.21 78 5 273
354 | 5 1 9.7 100 18 278
355 | 1 1 7.53 49 2 125
356 | 16 1 9.3 47 19 197
357 | 39 0 1.33 20 17 0
358 | 10 1 6.74 13 12 88
359 | 10 1 8.18 55 17 170
360 | 18 0 8.44 37 12 85
361 | 2 0 5.22 27 12 21
362 | 7 0 1.15 2 4 0
363 | 0 0 5.14 13 2 5
364 | 26 0 2.55 85 18 34
365 | 5 0 9.59 20 7 60
366 | 13 1 0.8 34 0 100
367 | 22 1 0.57 89 15 145
368 | 30 1 8.27 22 9 200
369 | 16 1 5.93 16 17 110
370 | 23 1 9.43 19 15 174
371 | 28 1 1.17 63 3 191
372 | 4 0 4.02 69 5 64
373 | 27 1 9.6 54 8 287
374 | 27 0 6.49 97 12 128
375 | 6 0 9.95 94 3 225
376 | 25 1 2.34 1 14 143
377 | 9 1 5.04 40 3 129
378 | 14 1 2.77 81 11 135
379 | 22 0 2.87 1 18 3
380 | 23 1 9.75 81 4 316
381 | 21 1 2.62 77 4 182
382 | 37 1 1.24 55 5 244
383 | 32 1 4.64 37 13 208
384 | 29 0 9.87 31 2 66
385 | 22 1 4.26 72 7 211
386 | 9 1 8.4 16 1 108
387 | 27 1 0.02 22 9 134
388 | 13 1 4.62 62 19 139
389 | 1 1 4.44 68 12 114
390 | 15 1 2.25 43 14 122
391 | 20 1 7.93 28 15 186
392 | 34 1 7.37 13 4 217
393 | 25 0 4.81 99 0 129
394 | 34 1 3.84 24 2 235
395 | 27 1 1.11 46 12 148
396 | 16 0 0.5 74 4 16
397 | 38 1 4.41 98 20 318
398 | 14 1 5.41 79 10 177
399 | 16 0 0.43 61 19 3
400 | 27 0 1.22 75 11 11
401 | 9 1 9.4 74 12 207
402 | 26 0 2.59 0 12 0
403 | 23 1 3.41 45 17 166
404 | 18 0 6.48 99 6 158
405 | 22 1 9.16 90 2 334
406 | 21 0 3.91 74 2 70
407 | 4 1 5.56 90 19 168
408 | 40 1 8.98 25 18 276
409 | 2 1 7.48 52 5 143
410 | 34 1 8.89 6 8 230
411 | 12 1 5.53 97 2 215
412 | 15 1 9.23 54 10 199
413 | 33 1 8.95 65 2 343
414 | 37 1 1.2 17 8 193
415 | 0 1 7.26 27 6 72
416 | 7 1 0.24 96 8 58
417 | 36 1 2.26 0 16 182
418 | 8 0 8.03 69 12 125
419 | 21 1 2.45 72 0 172
420 | 40 0 9.96 53 0 146
421 | 34 1 1.82 74 14 239
422 | 29 1 5.88 80 2 304
423 | 6 0 3.03 96 14 72
424 | 7 1 5.96 34 3 118
425 | 15 1 2.59 32 0 113
426 | 2 0 8.9 92 18 197
427 | 3 1 5.9 90 18 150
428 | 7 0 3.98 38 8 36
429 | 32 0 0.25 69 2 30
430 | 3 1 8.06 78 16 186
431 | 30 0 9.58 59 11 114
432 | 28 0 3.52 3 9 0
433 | 6 0 8.29 9 3 44
434 | 15 1 0.52 56 3 118
435 | 33 0 9.56 63 10 161
436 | 0 0 4.6 17 12 23
437 | 36 1 2.13 65 8 243
438 | 36 1 7.05 88 9 339
439 | 6 1 1.39 95 12 76
440 | 40 1 9.31 2 7 241
441 | 2 1 0.61 5 19 5
442 | 38 0 5.6 41 2 70
443 | 25 1 4 61 14 181
444 | 36 1 5.36 20 20 199
445 | 6 1 9.47 32 12 106
446 | 40 1 3.43 92 16 302
447 | 18 0 1.56 76 12 22
448 | 34 1 5.37 40 3 264
449 | 1 0 2.69 19 2 35
450 | 20 0 4.43 62 15 42
451 | 12 1 6.95 22 13 114
452 | 6 1 3.22 9 1 51
453 | 38 1 5.32 27 18 246
454 | 7 1 6.47 30 6 123
455 | 10 1 8.69 43 20 143
456 | 24 1 2.6 15 9 142
457 | 20 1 9.37 84 8 312
458 | 26 1 8.78 56 18 269
459 | 7 1 1.42 11 3 63
460 | 25 1 5.71 74 9 247
461 | 21 1 1.68 12 9 132
462 | 32 0 9.69 31 15 75
463 | 8 1 4.39 38 9 97
464 | 15 1 4.6 4 10 95
465 | 13 1 2.88 92 4 162
466 | 28 1 9.21 6 16 194
467 | 2 1 5.99 44 18 90
468 | 26 0 0.08 38 20 0
469 | 9 1 0.79 34 20 57
470 | 27 1 1.26 40 3 188
471 | 22 1 2.37 27 7 160
472 | 18 0 2.03 71 6 25
473 | 19 1 1.77 64 19 123
474 | 30 1 0.65 66 9 176
475 | 26 0 8.47 78 18 149
476 | 39 1 0.72 54 13 233
477 | 35 0 0.51 82 7 12
478 | 4 1 1.8 13 15 30
479 | 34 1 3.29 21 11 215
480 | 11 1 0.4 78 9 98
481 | 27 0 8.94 27 20 35
482 | 17 1 1.11 69 6 139
483 | 18 1 7.4 15 15 116
484 | 6 0 6.11 46 10 77
485 | 37 0 3.38 35 1 29
486 | 16 1 6.5 52 8 195
487 | 34 1 6.72 93 10 328
488 | 36 1 5.07 96 10 328
489 | 0 0 4.72 63 1 86
490 | 5 1 8.63 1 10 62
491 | 15 0 0.33 64 17 15
492 | 1 1 2.83 67 6 85
493 | 38 0 4.45 52 18 67
494 | 20 1 3.43 1 15 109
495 | 39 0 4 54 16 56
496 | 11 0 5.34 41 10 61
497 | 26 1 8.62 94 2 351
498 | 30 1 0.02 11 20 150
499 | 30 1 6.75 78 15 305
500 | 18 1 9.88 36 16 186
501 | 21 0 8.63 58 20 129
502 | 28 1 5.38 36 4 224
503 | 15 1 4.72 41 6 134
504 | 12 1 6.66 19 8 108
505 | 17 1 7.11 0 20 85
506 | 7 1 7.51 8 0 78
507 | 39 1 2.77 71 14 253
508 | 4 1 1.09 99 15 88
509 | 26 0 9.94 68 0 178
510 | 39 1 5.84 22 2 237
511 | 13 1 5.81 29 18 106
512 | 4 1 3.28 100 3 132
513 | 3 1 9.24 44 9 121
514 | 24 1 6.75 60 10 225
515 | 30 1 5.94 68 3 267
516 | 7 1 0.62 71 20 63
517 | 31 0 4.43 34 9 45
518 | 21 1 4.57 85 5 239
519 | 23 0 5.59 100 0 123
520 | 17 0 4.93 77 2 116
521 | 8 1 1.36 74 13 99
522 | 10 1 5.24 95 12 195
523 | 7 1 0.14 91 11 73
524 | 17 0 4.57 6 3 0
525 | 26 1 3 38 12 162
526 | 35 1 9.03 84 5 364
527 | 5 1 6.04 67 12 150
528 | 10 1 0.53 35 17 78
529 | 38 1 1.95 91 8 241
530 | 11 1 8.16 93 3 270
531 | 2 0 6.9 68 4 136
532 | 33 1 8.55 12 15 218
533 | 40 1 7.55 63 8 333
534 | 7 1 8.75 20 5 97
535 | 26 1 5.03 81 12 253
536 | 30 1 3.1 54 12 221
537 | 6 0 2.58 7 9 0
538 | 4 1 6.38 34 1 117
539 | 37 1 5.25 80 7 327
540 | 9 0 5.37 71 16 104
541 | 33 0 2.91 78 13 54
542 | 4 1 7.8 17 11 85
543 | 0 1 0.24 72 17 33
544 | 19 1 5.74 26 17 136
545 | 33 1 3.9 51 18 214
546 | 3 1 1.91 23 9 25
547 | 14 0 0.56 32 8 0
548 | 13 1 6.85 91 7 239
549 | 19 1 8.5 87 8 305
550 | 37 1 2.11 24 6 236
551 | 2 1 7.55 75 5 181
552 | 17 1 2.56 42 11 135
553 | 18 0 1.28 9 5 7
554 | 17 1 4.24 24 13 108
555 | 30 1 2.24 9 20 141
556 | 11 1 5.42 97 0 216
557 | 25 1 6.38 84 18 268
558 | 26 1 3.94 15 5 176
559 | 19 0 1.92 1 0 1
560 | 21 1 0.58 52 17 112
561 | 31 1 3.76 69 8 233
562 | 8 1 9.65 2 12 52
563 | 21 0 6.47 94 2 163
564 | 19 1 9.21 65 8 248
565 | 1 0 5.47 66 4 74
566 | 5 0 4.75 93 3 109
567 | 13 1 0.06 7 6 74
568 | 20 1 1.75 82 5 161
569 | 30 1 2.28 63 20 215
570 | 0 1 2.65 47 5 58
571 | 9 1 1.26 76 18 74
572 | 33 1 5.83 92 1 334
573 | 2 1 1.44 18 19 27
574 | 29 1 8.44 19 14 196
575 | 7 1 4.9 90 20 133
576 | 21 1 8.72 42 0 213
577 | 30 1 7.7 75 3 328
578 | 6 1 0.12 5 6 62
579 | 40 1 3.9 91 2 329
580 | 33 1 2.03 70 2 210
581 | 2 0 9.26 17 20 42
582 | 12 1 1.29 33 10 92
583 | 5 0 0.24 50 20 1
584 | 22 1 1.43 73 7 171
585 | 35 1 5.36 75 18 270
586 | 4 1 0.45 51 7 54
587 | 31 1 5.69 100 0 330
588 | 27 1 3.6 88 10 243
589 | 6 1 1.44 50 7 89
590 | 31 1 8.99 44 19 266
591 | 0 1 6.42 24 14 72
592 | 35 1 6.44 1 16 174
593 | 10 1 6.74 93 17 218
594 | 33 1 8.54 85 3 347
595 | 4 0 9.43 85 4 201
596 | 7 1 9.1 12 7 83
597 | 5 0 4.51 84 16 86
598 | 28 1 5.96 21 3 184
599 | 29 1 5.12 93 16 262
600 | 10 1 3.17 13 2 77
601 | 1 1 2.06 97 8 100
602 | 4 1 6.74 57 17 110
603 | 24 0 5.62 51 2 68
604 | 0 0 2.84 34 4 14
605 | 11 0 8.69 49 0 131
606 | 25 1 1.05 18 4 156
607 | 19 1 6.13 88 0 233
608 | 34 0 9.47 34 7 71
609 | 25 1 7.13 0 1 159
610 | 6 1 0.79 21 7 47
611 | 26 1 1.39 51 0 191
612 | 1 0 4.44 50 7 72
613 | 31 1 1.28 20 17 174
614 | 6 1 4.77 73 10 143
615 | 20 1 3.06 9 8 110
616 | 19 1 4.21 27 0 154
617 | 37 0 5.52 31 7 47
618 | 22 1 6.28 16 19 154
619 | 17 1 2.66 48 12 123
620 | 26 1 6.67 28 0 188
621 | 19 1 2.23 66 5 147
622 | 25 1 1.11 9 12 143
623 | 37 1 1.66 2 20 188
624 | 15 1 7.5 9 1 105
625 | 0 1 7.18 59 8 116
626 | 24 0 4.14 53 9 72
627 | 22 1 5.43 37 16 167
628 | 20 1 6.61 22 18 161
629 | 0 1 1.44 9 19 8
630 | 33 0 2.07 41 6 13
631 | 18 1 2.89 41 20 128
632 | 1 1 2.28 57 3 47
633 | 8 1 8.02 100 15 257
634 | 23 1 2.99 34 17 160
635 | 38 1 1.95 32 0 231
636 | 21 1 7.43 35 0 179
637 | 35 1 0.46 27 18 198
638 | 34 1 1.71 82 12 228
639 | 8 0 7.7 77 8 143
640 | 28 1 9.95 44 12 280
641 | 23 1 8.81 26 10 197
642 | 7 1 7.79 72 12 176
643 | 23 1 9.36 38 3 225
644 | 32 0 4.8 80 6 103
645 | 25 1 7.47 93 1 332
646 | 18 1 4.93 14 9 144
647 | 34 1 3.52 76 7 272
648 | 19 1 1.55 40 3 136
649 | 21 1 2.21 30 16 145
650 | 39 1 8.94 83 18 396
651 | 13 0 0.93 12 4 20
652 | 22 1 4.26 50 8 167
653 | 22 0 7.54 91 14 151
654 | 27 1 5.94 96 0 303
655 | 17 0 3.01 47 17 7
656 | 6 0 2.19 92 7 40
657 | 1 1 4 8 6 41
658 | 29 0 0.65 24 12 0
659 | 14 1 2.79 29 8 103
660 | 27 1 2.42 51 13 180
661 | 11 1 7.28 52 0 160
662 | 19 0 3.88 98 11 78
663 | 8 1 9.4 87 1 272
664 | 34 1 6.43 14 1 216
665 | 9 1 0.18 13 17 58
666 | 24 1 1.29 25 20 146
667 | 15 1 0.22 38 0 92
668 | 24 1 2.43 74 7 193
669 | 29 1 3.79 64 13 206
670 | 3 1 8.5 77 3 196
671 | 24 0 8.58 25 10 68
672 | 40 1 0.52 74 19 237
673 | 7 1 3.68 6 16 62
674 | 34 0 9.34 45 6 125
675 | 9 1 5.86 83 8 179
676 | 33 1 7.93 0 2 215
677 | 18 1 2.69 11 19 124
678 | 2 1 3.11 17 5 34
679 | 9 1 4.62 60 19 140
680 | 6 1 8.42 56 12 180
681 | 36 1 6.39 6 3 215
682 | 15 1 8.78 98 1 285
683 | 34 1 0.21 7 18 170
684 | 3 1 5.45 26 11 64
685 | 26 0 3.52 71 1 60
686 | 18 1 8.84 47 10 219
687 | 3 1 4.25 19 14 49
688 | 4 0 7.1 23 0 53
689 | 19 1 0.7 47 11 102
690 | 7 1 3.08 62 19 78
691 | 38 1 5.52 87 19 311
692 | 20 1 4.53 63 19 188
693 | 1 1 6.47 68 10 133
694 | 24 1 1.05 58 5 152
695 | 14 0 3.27 69 4 72
696 | 35 0 5.75 34 9 61
697 | 39 1 1.78 83 7 241
698 | 20 0 8.13 73 16 151
699 | 13 0 7.07 83 19 112
700 | 15 1 3.38 89 9 151
701 | 35 1 5.06 66 1 293
702 | 5 1 6.22 35 5 99
703 | 23 0 8.74 71 8 137
704 | 17 1 1.32 37 1 129
705 | 19 1 3.09 97 1 192
706 | 29 1 0.59 18 17 169
707 | 19 1 2.04 100 8 171
708 | 24 1 0.37 46 13 143
709 | 6 1 9.33 60 16 167
710 | 34 1 7.21 82 11 336
711 | 33 1 6.22 44 5 257
712 | 32 1 3.66 10 16 181
713 | 21 1 2.18 75 20 138
714 | 0 1 8.29 67 11 149
715 | 7 1 2.38 4 7 70
716 | 29 1 2.36 82 6 213
717 | 15 1 2.13 61 12 109
718 | 16 1 6.53 28 12 148
719 | 32 1 8.54 7 4 194
720 | 15 1 1.94 75 16 142
721 | 17 1 8.72 43 6 219
722 | 15 0 1.06 45 20 0
723 | 32 1 4.51 20 12 195
724 | 37 0 5.77 23 7 32
725 | 32 1 6.35 34 11 239
726 | 40 1 0.78 57 17 240
727 | 32 1 7.67 61 15 282
728 | 4 1 4.64 33 5 74
729 | 8 0 5.53 48 4 62
730 | 33 1 3.96 13 4 198
731 | 5 1 6.27 50 14 103
732 | 18 1 7.19 48 11 193
733 | 36 0 3.03 53 0 67
734 | 33 1 4.61 51 0 249
735 | 16 1 1.93 15 4 90
736 | 33 1 5.47 90 18 281
737 | 20 1 7.91 72 8 274
738 | 18 1 9.6 93 8 312
739 | 7 1 7.91 74 8 176
740 | 37 1 4.01 72 14 270
741 | 7 1 2.9 78 2 102
742 | 29 1 5.59 54 19 234
743 | 5 1 5.17 60 1 132
744 | 12 1 1.24 69 11 115
745 | 14 1 7.39 86 10 252
746 | 1 1 0.09 63 17 18
747 | 30 1 4.93 75 2 262
748 | 31 1 4.42 92 20 272
749 | 30 1 8.96 8 18 203
750 | 36 0 5.3 94 9 125
751 | 28 1 7.11 74 16 259
752 | 7 1 9.91 68 7 220
753 | 36 1 0.48 46 16 213
754 | 36 0 6.79 46 0 88
755 | 10 1 1.21 7 5 61
756 | 18 1 8.06 71 15 252
757 | 6 0 0.68 89 5 35
758 | 11 1 1.14 47 20 71
759 | 17 1 9.1 67 17 250
760 | 39 0 4.73 90 19 102
761 | 8 1 2.79 41 17 93
762 | 0 1 1.03 89 12 57
763 | 8 1 1.77 25 5 62
764 | 1 1 4.14 31 11 61
765 | 7 1 4.33 65 9 105
766 | 7 1 9.47 100 7 274
767 | 5 1 5.57 35 7 88
768 | 33 1 1.84 95 17 212
769 | 38 0 6.11 78 3 139
770 | 38 1 1.82 80 15 229
771 | 35 1 5.36 42 11 264
772 | 2 1 8.13 94 15 205
773 | 34 1 5.88 97 19 306
774 | 40 1 2.36 26 12 231
775 | 13 1 2.15 64 13 134
776 | 8 1 0.79 47 2 96
777 | 39 0 1.06 12 11 13
778 | 37 0 0.15 23 8 9
779 | 17 1 1.34 91 14 135
780 | 37 1 0.51 18 15 216
781 | 20 1 8.01 52 6 229
782 | 24 1 8.84 16 15 182
783 | 0 0 4.81 28 20 34
784 | 15 1 9.18 13 6 127
785 | 24 1 8.01 4 2 154
786 | 25 1 4 70 8 234
787 | 18 1 6.67 84 15 222
788 | 25 1 5.02 92 10 275
789 | 27 1 9.44 13 8 198
790 | 17 0 5 30 15 21
791 | 20 1 3.66 67 5 188
792 | 5 0 7.34 12 5 19
793 | 0 1 9.51 16 16 69
794 | 27 0 6.23 5 18 9
795 | 29 1 6.51 82 15 289
796 | 38 0 3.05 71 4 50
797 | 25 0 6.36 38 5 80
798 | 32 1 1.86 43 3 210
799 | 6 1 7.26 59 13 156
800 | 40 1 4.86 79 16 291
801 | 5 0 4.33 89 6 101
802 | 5 1 4.51 99 16 153
803 | 37 1 4.74 92 20 290
804 | 30 1 3.73 64 4 224
805 | 8 1 4.2 95 19 157
806 | 35 0 2.85 96 12 58
807 | 2 0 9.8 76 11 169
808 | 31 1 2.91 29 20 185
809 | 17 1 9.76 30 7 202
810 | 0 1 0.64 8 12 0
811 | 0 1 9.57 34 12 111
812 | 21 1 4.35 29 1 180
813 | 14 1 0.18 92 6 93
814 | 9 1 5.14 81 7 152
815 | 21 1 7.05 32 18 192
816 | 10 0 8.99 94 20 168
817 | 38 1 1.19 74 2 253
818 | 6 1 9.5 45 18 137
819 | 9 0 0.37 78 12 4
820 | 14 1 4.21 54 15 133
821 | 16 1 6.07 96 8 258
822 | 11 0 7.29 81 5 153
823 | 31 1 1.34 83 8 198
824 | 21 1 2.45 46 13 142
825 | 22 1 0.29 61 9 148
826 | 27 1 0.14 13 11 148
827 | 26 1 6.26 87 4 300
828 | 1 0 8.21 73 11 148
829 | 19 1 3.16 28 2 160
830 | 15 0 5.59 1 2 13
831 | 30 1 1.4 7 10 169
832 | 15 1 3.96 16 0 111
833 | 18 1 7.94 14 13 143
834 | 0 1 8.44 15 20 34
835 | 20 0 0.32 84 12 27
836 | 12 1 9.34 92 5 269
837 | 25 1 1.96 2 7 159
838 | 17 1 8.31 56 18 189
839 | 31 1 5.61 4 18 183
840 | 15 1 7.08 78 10 211
841 | 38 0 5.65 14 15 25
842 | 39 1 4.67 8 12 221
843 | 10 0 8.09 16 18 30
844 | 6 0 8.22 39 6 71
845 | 20 1 4.57 2 12 118
846 | 28 1 8.74 86 14 349
847 | 13 1 9.04 49 12 174
848 | 24 1 2.54 66 8 185
849 | 5 0 6.41 41 4 66
850 | 40 1 0.59 68 1 232
851 | 23 1 2.05 43 11 144
852 | 13 1 4.75 83 20 172
853 | 16 1 5.1 86 11 220
854 | 28 0 2.34 87 0 64
855 | 26 0 2.6 98 0 76
856 | 30 0 1.48 48 17 0
857 | 38 1 7.72 3 15 218
858 | 7 1 2.81 3 5 74
859 | 40 1 5.67 69 6 324
860 | 31 1 2.99 36 12 215
861 | 30 1 6.18 47 20 219
862 | 28 1 5.02 36 4 224
863 | 23 1 0.62 54 9 156
864 | 31 0 6.08 11 2 35
865 | 16 1 7.27 49 7 192
866 | 20 1 3.66 28 8 132
867 | 5 1 9.93 52 11 177
868 | 36 0 0.02 92 4 12
869 | 39 1 7.88 52 0 317
870 | 29 0 1.64 73 15 21
871 | 37 0 8.97 78 2 171
872 | 5 1 1.3 39 11 67
873 | 14 1 1.32 24 20 80
874 | 20 1 9.35 19 14 180
875 | 24 1 8.41 47 19 210
876 | 4 1 0.84 76 17 48
877 | 15 1 1.17 94 18 115
878 | 22 1 8.08 95 20 292
879 | 27 1 3.39 28 9 170
880 | 5 1 2.03 91 17 77
881 | 17 1 8.91 10 14 122
882 | 40 0 6.76 95 9 151
883 | 7 1 7.79 2 13 72
884 | 11 1 3.69 61 20 134
885 | 26 0 5.28 1 18 0
886 | 14 1 6.1 5 5 88
887 | 29 1 5.17 10 10 195
888 | 26 0 4.28 91 14 88
889 | 8 1 4.04 73 3 128
890 | 11 1 2.18 61 10 94
891 | 2 1 7.6 77 9 157
892 | 37 0 0.33 22 5 0
893 | 27 1 3.22 23 7 182
894 | 16 1 2.01 72 6 144
895 | 22 1 8.74 19 19 165
896 | 28 1 3.37 92 9 240
897 | 32 1 4.58 1 1 174
898 | 27 1 1.92 71 17 176
899 | 11 1 2.23 14 12 82
900 | 24 0 5.05 24 12 42
901 | 9 1 5.38 16 7 79
902 | 27 1 0.42 88 1 164
903 | 33 1 9.22 53 12 305
904 | 14 1 2.52 100 11 173
905 | 11 0 2.87 82 3 83
906 | 16 1 3.85 66 9 146
907 | 13 1 8.9 91 6 269
908 | 8 0 3.59 70 15 50
909 | 6 1 0.9 67 2 92
910 | 26 0 2.87 89 13 82
911 | 1 1 1.41 6 17 5
912 | 28 1 9.72 6 18 183
913 | 23 1 3.71 83 12 204
914 | 7 0 3.69 10 3 0
915 | 29 1 8.21 39 8 251
916 | 32 1 3.97 17 7 213
917 | 16 1 5.7 22 13 142
918 | 33 0 5.86 42 7 67
919 | 6 0 9.84 89 13 209
920 | 29 1 9.27 39 10 259
921 | 39 0 2.16 57 20 12
922 | 17 1 6.01 71 3 210
923 | 33 1 5.44 73 18 255
924 | 16 1 7.39 68 20 196
925 | 18 1 8.06 33 8 190
926 | 37 0 9.9 11 1 48
927 | 36 0 1.67 2 5 0
928 | 39 1 2.53 82 11 286
929 | 9 1 9.57 8 4 84
930 | 22 1 7.01 54 16 218
931 | 19 1 9.01 96 4 314
932 | 21 1 9.77 78 14 314
933 | 19 0 1.4 86 10 41
934 | 10 0 4.59 28 3 36
935 | 21 0 7.62 3 15 0
936 | 7 0 7.26 26 4 65
937 | 16 1 4.35 66 16 154
938 | 10 0 2.8 92 13 52
939 | 33 1 0.04 67 6 211
940 | 20 1 6.15 33 4 176
941 | 39 1 9.27 79 4 382
942 | 26 1 5.65 63 15 225
943 | 20 1 1.09 15 16 103
944 | 4 0 2.14 85 2 45
945 | 5 1 5.77 54 17 124
946 | 16 1 3.49 51 7 149
947 | 6 1 7.36 84 9 184
948 | 0 1 9.26 64 8 164
949 | 40 1 6.27 27 17 274
950 | 40 1 0.67 36 6 214
951 | 38 1 2.12 88 0 256
952 | 30 1 0.27 99 15 175
953 | 32 0 4.96 52 14 45
954 | 40 0 0.93 26 13 11
955 | 26 1 6.38 66 9 246
956 | 36 1 6.96 40 7 270
957 | 10 1 5.01 46 10 130
958 | 0 1 9.51 50 11 144
959 | 39 0 7.98 32 19 66
960 | 35 1 2.15 24 2 207
961 | 9 1 9.91 45 1 164
962 | 14 0 4.1 48 16 57
963 | 24 1 8.78 20 9 194
964 | 15 1 0.32 53 13 78
965 | 23 1 8.5 45 16 233
966 | 38 1 4.37 68 4 305
967 | 7 1 9.79 38 7 165
968 | 16 1 4.64 98 1 234
969 | 25 1 0.36 6 14 139
970 | 33 1 1.78 36 0 192
971 | 33 1 1.26 58 11 210
972 | 3 1 8.12 76 9 182
973 | 6 1 3.33 40 11 67
974 | 8 1 1.37 33 9 88
975 | 15 0 1.5 14 14 0
976 | 25 1 6.08 91 2 303
977 | 16 1 8.4 92 11 289
978 | 36 1 1.14 0 0 187
979 | 3 0 0.78 30 1 31
980 | 5 0 1.97 53 17 10
981 | 34 1 0.99 92 3 246
982 | 30 1 7.52 57 16 248
983 | 11 0 0.13 8 16 0
984 | 40 1 8.95 95 5 439
985 | 38 1 2.97 38 4 235
986 | 36 1 9.47 55 20 332
987 | 19 0 2.43 46 8 43
988 | 17 1 1.35 20 7 127
989 | 2 1 4.8 91 5 160
990 | 27 1 9.9 92 7 380
991 | 33 1 4.14 23 10 224
992 | 16 0 3.19 15 14 12
993 | 22 0 4.65 38 2 59
994 | 24 0 7.81 43 3 100
995 | 14 1 7.26 66 1 212
996 | 19 1 3.84 39 17 148
997 | 8 1 4.1 76 17 142
998 | 25 0 1.66 13 5 24
999 | 20 1 8.95 52 11 230
1000 | 16 1 9.4 91 15 284
1001 | 39 1 2.58 47 0 241
1002 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Automatically generated by https://github.com/damnever/pigar.
2 |
3 | catboost==1.1.1
4 | matplotlib==3.6.0
5 | numpy==1.23.3
6 | opencv-python==4.6.0.66
7 | pandas==1.5.0
8 | Pillow==9.2.0
9 | scikit-learn==1.2.0
10 | seaborn==0.12.2
11 | shap==0.41.0
12 | torch==1.13.1
13 | torchvision==0.14.1
14 | ucimlrepo==0.0.7
15 | xgboost==1.7.3
16 | 
--------------------------------------------------------------------------------
/src/additional_resources/CatBoostClassifier.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "#imports\n",
10 | "import pandas as pd\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "\n",
13 | "from catboost import CatBoostClassifier\n",
14 | "import xgboost as xgb\n",
15 | "\n",
16 | "import shap\n",
17 | "\n",
18 | "from sklearn.metrics import accuracy_score,confusion_matrix\n",
19 | "\n",
20 | "#set figure background to white\n",
21 | "plt.rcParams.update({'figure.facecolor':'white'})"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "# Dataset"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "#load data \n",
38 | "data = pd.read_csv(\"../../data/mushrooms.csv\")\n",
39 | "\n",
40 | "#get features\n",
41 | "y = data['class']\n",
42 | "y = y.astype('category').cat.codes\n",
43 | "X = data.drop('class', axis=1)\n",
44 | "\n",
45 | "print(len(data))\n",
46 | "data.head()"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "# XGBoost"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "# Create dummy variables for the categorical features\n",
63 | "X_dummy = pd.get_dummies(X)"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "# Fit model\n",
73 | "model = xgb.XGBClassifier()\n",
74 | "model.fit(X_dummy, y)"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "#Get SHAP values\n",
84 | "explainer = shap.Explainer(model)\n",
85 | "shap_values = explainer(X_dummy)\n",
86 | "\n",
87 | "# Display SHAP values for the first observation\n",
88 | "shap.plots.waterfall(shap_values[0])"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {},
94 | "source": [
95 | "# CatBoost"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "model = CatBoostClassifier(iterations=20,\n",
105 | " learning_rate=0.01,\n",
106 | " depth=3)\n",
107 | "\n",
108 | "# train model\n",
109 | "cat_features = list(range(len(X.columns)))\n",
110 | "model.fit(X, y, cat_features)"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "#Get SHAP values\n",
120 | "explainer = shap.Explainer(model)\n",
121 | "shap_values = explainer(X)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "# Display SHAP values for the first observation\n",
131 | "shap.plots.waterfall(shap_values[0])"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "# Mean SHAP \n",
141 | "shap.plots.bar(shap_values)"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "# Beeswarm plot \n",
151 | "shap.plots.beeswarm(shap_values)"
152 | ]
153 | }
154 | ],
155 | "metadata": {
156 | "kernelspec": {
157 | "display_name": "shap",
158 | "language": "python",
159 | "name": "shap"
160 | },
161 | "language_info": {
162 | "codemirror_mode": {
163 | "name": "ipython",
164 | "version": 3
165 | },
166 | "file_extension": ".py",
167 | "mimetype": "text/x-python",
168 | "name": "python",
169 | "nbconvert_exporter": "python",
170 | "pygments_lexer": "ipython3",
171 | "version": "3.10.4"
172 | }
173 | },
174 | "nbformat": 4,
175 | "nbformat_minor": 2
176 | }
177 |
--------------------------------------------------------------------------------
/src/additional_resources/IsolationForest.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 5,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# Imports \n",
10 | "import pandas as pd\n",
11 | "import numpy as np\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "import seaborn as sns\n",
14 | "\n",
15 | "import shap\n",
16 | "\n",
17 | "from sklearn.ensemble import IsolationForest\n",
18 | "from ucimlrepo import fetch_ucirepo\n",
19 | "\n",
20 | "# Set figure background to white\n",
21 | "plt.rcParams.update({'figure.facecolor':'white'})"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "# Data Cleaning and Feature Engineering"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "# Fetch dataset from UCI repository\n",
38 | "power_consumption = fetch_ucirepo(id=235)\n",
39 | "\n",
40 | "print(power_consumption.variables) "
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "# Get all features\n",
50 | "data = power_consumption.data.features\n",
51 | "data['Date'] = pd.to_datetime(data['Date'], format='%d/%m/%Y')\n",
52 | "\n",
53 | "# List of features to check\n",
54 | "feature_columns = ['Global_active_power', 'Global_reactive_power', 'Voltage', \n",
55 | " 'Global_intensity', 'Sub_metering_1', 'Sub_metering_2', 'Sub_metering_3']\n",
56 | "\n",
57 | "# Convert feature columns to numeric and replace any errors with NaN\n",
58 | "data[feature_columns] = data[feature_columns].apply(pd.to_numeric, errors='coerce')\n",
59 | "\n",
60 | "# Drop rows where all feature columns are missing (NaN) \n",
61 | "data_cleaned = data.dropna(subset=feature_columns, how='all')\n",
62 | "\n",
63 | "# Drop rows where ALL feature columns are NaN\n",
64 | "data_cleaned.head()"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "# Group by 'Date' and calculate mean and standard deviation (ignore NaN values)\n",
74 | "data_aggregated = data_cleaned.groupby('Date')[feature_columns].agg(['mean', 'std'])\n",
75 | "\n",
76 | "# Rename columns to the desired format (MEAN_ColumnName, STD_ColumnName)\n",
77 | "data_aggregated.columns = [\n",
78 | " f'{agg_type.upper()}_{col}' for col, agg_type in data_aggregated.columns\n",
79 | "]\n",
80 | "\n",
81 | "# Reset the index\n",
82 | "data_aggregated.reset_index(inplace=True)\n",
83 | "\n",
84 | "# Display the result\n",
85 | "print(data_aggregated.shape)\n",
86 | "data_aggregated.head()"
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "metadata": {},
92 | "source": [
93 | "# Train IsolationForest"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 10,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "# Parameters\n",
103 | "n_estimators = 100 # Number of trees\n",
104 | "sample_size = 256 # Number of samples used to train each tree\n",
105 | "contamination = 0.02 # Expected proportion of anomalies"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "# Select Features\n",
115 | "features = data_aggregated.drop('Date', axis=1)\n",
116 | "\n",
117 | "# Train Isolation Forest\n",
118 | "iso_forest = IsolationForest(n_estimators=n_estimators, \n",
119 | " contamination=contamination, \n",
120 | " max_samples=sample_size,\n",
121 | " random_state=42)\n",
122 | "\n",
123 | "iso_forest.fit(features)"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "data_aggregated['anomaly_score'] = iso_forest.decision_function(features)\n",
133 | "data_aggregated['anomaly'] = iso_forest.predict(features)\n",
134 | "\n",
135 | "data_aggregated['anomaly'].value_counts()"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "# Visualization of the results\n",
145 | "plt.figure(figsize=(10, 5))\n",
146 | "\n",
147 | "# Plot normal instances\n",
148 | "normal = data_aggregated[data_aggregated['anomaly'] == 1]\n",
149 | "plt.scatter(normal['Date'], normal['anomaly_score'], label='Normal')\n",
150 | "\n",
151 | "# Plot anomalies\n",
152 | "anomalies = data_aggregated[data_aggregated['anomaly'] == -1]\n",
153 | "plt.scatter(anomalies['Date'], anomalies['anomaly_score'], label='Anomaly')\n",
154 | "\n",
155 | "plt.xlabel(\"Instance\")\n",
156 | "plt.ylabel(\"Anomaly Score\")\n",
157 | "plt.legend()"
158 | ]
159 | },
160 | {
161 | "cell_type": "markdown",
162 | "metadata": {},
163 | "source": [
164 | "# KernelSHAP with Anomaly Score\n"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "metadata": {},
171 | "outputs": [],
172 | "source": [
173 | "# Using the anomaly score and TreeSHAP (this code won't work)\n",
174 | "explainer = shap.TreeExplainer(iso_forest.decision_function, features)\n",
175 | "shap_values = explainer(features)"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "# Select all anomalies and 100 random normal instances\n",
185 | "normal_sample = np.random.choice(normal.index,size=100,replace=False)\n",
186 | "sample = np.append(anomalies.index,normal_sample)\n",
187 | "\n",
188 | "len(sample) # 129"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "metadata": {},
195 | "outputs": [],
196 | "source": [
197 | "# Using the anomaly score and KernelSHAP\n",
198 | "explainer = shap.Explainer(iso_forest.decision_function, features)\n",
199 | "shap_values = explainer(features.iloc[sample])"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "# Plot waterfall plot of an anomaly\n",
209 | "shap.plots.waterfall(shap_values[0])"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "# Plot waterfall plot of a normal instance\n",
219 | "shap.plots.waterfall(shap_values[100])"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "metadata": {},
226 | "outputs": [],
227 | "source": [
228 | "# MeanSHAP Plot\n",
229 | "shap.plots.bar(shap_values)"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "# Beeswarm plot\n",
239 | "shap.plots.beeswarm(shap_values)"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {},
245 | "source": [
246 | "# TreeSHAP with Path Length"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 22,
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "# Calculate SHAP values\n",
256 | "explainer = shap.TreeExplainer(iso_forest)\n",
257 | "shap_values = explainer(features)"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": null,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "# Waterfall plot for an anomaly\n",
267 | "shap.plots.waterfall(shap_values[0])"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "# Waterfall plot for a normal instance\n",
277 | "shap.plots.waterfall(shap_values[2])"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "metadata": {},
284 | "outputs": [],
285 | "source": [
286 | "# Calculate f(x)\n",
287 | "path_length = shap_values.base_values + shap_values.values.sum(axis=1)\n",
288 | "\n",
289 | "# Get f(x) for anomalies and normal instances\n",
290 | "anomalies = data_aggregated[data_aggregated['anomaly'] == -1]\n",
291 | "path_length_anomalies = path_length[anomalies.index]\n",
292 | "\n",
293 | "normal = data_aggregated[data_aggregated['anomaly'] == 1]\n",
294 | "path_length_normal = path_length[normal.index]\n",
295 | "\n",
296 | "# Plot boxplots for f(x)\n",
297 | "plt.figure(figsize=(10, 5))\n",
298 | "plt.boxplot([path_length_anomalies, path_length_normal], labels=['Anomaly','Normal'])\n",
299 | "plt.ylabel(\"Average Path Length f(x)\")"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": null,
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "# MeanSHAP\n",
309 | "shap.plots.bar(shap_values)"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": null,
315 | "metadata": {},
316 | "outputs": [],
317 | "source": [
318 | "# MeanSHAP\n",
319 | "shap.plots.beeswarm(shap_values)"
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": 26,
325 | "metadata": {},
326 | "outputs": [],
327 | "source": [
328 | "# Interaction values\n",
329 | "shap_interaction_values = explainer.shap_interaction_values(features)"
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": null,
335 | "metadata": {},
336 | "outputs": [],
337 | "source": [
338 | "# Get absolute mean of matrices\n",
339 | "mean_shap = np.abs(shap_interaction_values).mean(0)\n",
340 | "mean_shap = np.round(mean_shap, 1)\n",
341 | "\n",
342 | "df = pd.DataFrame(mean_shap, index=features.columns, columns=features.columns)\n",
343 | "\n",
344 | "# Times off diagonal by 2\n",
345 | "df.where(df.values == np.diagonal(df), df.values * 2, inplace=True)\n",
346 | "\n",
347 | "# Display\n",
348 | "sns.set(font_scale=1)\n",
349 | "sns.heatmap(df, cmap=\"coolwarm\", annot=True)\n",
350 | "plt.yticks(rotation=0)"
351 | ]
352 | }
353 | ],
354 | "metadata": {
355 | "kernelspec": {
356 | "display_name": "shap",
357 | "language": "python",
358 | "name": "shap"
359 | },
360 | "language_info": {
361 | "codemirror_mode": {
362 | "name": "ipython",
363 | "version": 3
364 | },
365 | "file_extension": ".py",
366 | "mimetype": "text/x-python",
367 | "name": "python",
368 | "nbconvert_exporter": "python",
369 | "pygments_lexer": "ipython3",
370 | "version": "3.10.4"
371 | }
372 | },
373 | "nbformat": 4,
374 | "nbformat_minor": 2
375 | }
376 |
--------------------------------------------------------------------------------
/src/additional_resources/RandomForestRegressor.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# SHAP for RandomForestRegressor\n",
8 | "
\n",
9 | "Dataset: https://www.kaggle.com/datasets/conorsully1/interaction-dataset"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "# Imports\n",
19 | "import pandas as pd\n",
20 | "import matplotlib.pyplot as plt\n",
21 | "\n",
22 | "from sklearn.ensemble import RandomForestRegressor\n",
23 | "\n",
24 | "import shap\n",
25 | "shap.initjs() \n",
26 | "\n",
27 | "# Set figure background to white\n",
28 | "plt.rcParams.update({'figure.facecolor':'white'})"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "## Dataset"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "#import dataset\n",
45 | "data = pd.read_csv(\"../../data/interaction_dataset.csv\",sep='\\t')\n",
46 | "\n",
47 | "y = data['bonus']\n",
48 | "X = data.drop('bonus', axis=1)\n",
49 | "\n",
50 | "print(len(data))\n",
51 | "data.head()"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {},
57 | "source": [
58 | "# Modelling"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "#Train model\n",
68 | "model = RandomForestRegressor(n_estimators=100) \n",
69 | "model.fit(X, y)"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {},
75 | "source": [
76 | "# SHAP Values"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 4,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "#Get SHAP values\n",
86 | "explainer = shap.Explainer(model)\n",
87 | "shap_values = explainer(X)"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "# Plot waterfall\n",
97 | "shap.plots.waterfall(shap_values[0])"
98 | ]
99 | }
100 | ],
101 | "metadata": {
102 | "kernelspec": {
103 | "display_name": "shap",
104 | "language": "python",
105 | "name": "shap"
106 | },
107 | "language_info": {
108 | "codemirror_mode": {
109 | "name": "ipython",
110 | "version": 3
111 | },
112 | "file_extension": ".py",
113 | "mimetype": "text/x-python",
114 | "name": "python",
115 | "nbconvert_exporter": "python",
116 | "pygments_lexer": "ipython3",
117 | "version": "3.10.4"
118 | }
119 | },
120 | "nbformat": 4,
121 | "nbformat_minor": 2
122 | }
123 |
--------------------------------------------------------------------------------
/src/archive/image_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Using SHAP to debug a PyTorch Image Regression Model"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 2,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "# Imports\n",
18 | "import numpy as np\n",
19 | "import pandas as pd\n",
20 | "import matplotlib.pyplot as plt\n",
21 | "\n",
22 | "import glob \n",
23 | "import random \n",
24 | "\n",
25 | "from PIL import Image\n",
26 | "import cv2\n",
27 | "\n",
28 | "import torch\n",
29 | "import torchvision\n",
30 | "from torchvision import transforms\n",
31 | "from torch.utils.data import DataLoader\n",
32 | "\n",
33 | "import shap\n",
34 | "from sklearn.metrics import mean_squared_error"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "#Load example image\n",
44 | "name = \"32_50_c78164b4-40d2-11ed-a47b-a46bb6070c92.jpg\"\n",
45 | "x = int(name.split(\"_\")[0])\n",
46 | "y = int(name.split(\"_\")[1])\n",
47 | "\n",
48 | "img = Image.open(\"../data/room_1/\" + name)\n",
49 | "img = np.array(img)\n",
50 | "cv2.circle(img, (x, y), 8, (0, 255, 0), 3)\n",
51 | "\n",
52 | "plt.imshow(img)\n",
53 | "\n",
54 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/example.png\"\n",
55 | "plt.savefig(path, bbox_inches='tight',facecolor='w', edgecolor='w', transparent=False,dpi=200)"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "# Model Training"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 5,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "class ImageDataset(torch.utils.data.Dataset):\n",
72 | " def __init__(self, paths, transform):\n",
73 | "\n",
74 | " self.transform = transform\n",
75 | " self.paths = paths\n",
76 | "\n",
77 | " def __getitem__(self, idx):\n",
78 | " \"\"\"Get image and target (x, y) coordinates\"\"\"\n",
79 | "\n",
80 | " # Read image\n",
81 | " path = self.paths[idx]\n",
82 | " image = cv2.imread(path, cv2.IMREAD_COLOR)\n",
83 | " image = Image.fromarray(image)\n",
84 | "\n",
85 | " # Transform image\n",
86 | " image = self.transform(image)\n",
87 | " \n",
88 | " # Get target\n",
89 | " target = self.get_target(path)\n",
90 | " target = torch.Tensor(target)\n",
91 | "\n",
92 | " return image, target\n",
93 | " \n",
94 | " def get_target(self,path):\n",
95 | " \"\"\"Get the target (x, y) coordinates from path\"\"\"\n",
96 | "\n",
97 | " name = os.path.basename(path)\n",
98 | " items = name.split('_')\n",
99 | " x = items[0]\n",
100 | " y = items[1]\n",
101 | "\n",
102 | " # Scale between -1 and 1\n",
103 | " x = 2.0 * (int(x)/ 224 - 0.5) # -1 left, +1 right\n",
104 | " y = 2.0 * (int(y) / 244 -0.5)# -1 top, +1 bottom\n",
105 | "\n",
106 | " return [x, y]\n",
107 | "\n",
108 | " def __len__(self):\n",
109 | " return len(self.paths)\n"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "TRANSFORMS = transforms.Compose([\n",
119 | " transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),\n",
120 | " transforms.Resize((224, 224)),\n",
121 | " transforms.ToTensor(),\n",
122 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
123 | "])\n",
124 | "\n",
125 | "all_rooms = False # Change if you want to use all the data\n",
126 | "\n",
127 | "paths = glob.glob('../data/room_1/*')\n",
128 | "if all_rooms:\n",
129 | " paths = paths + glob.glob('../data/room_2/*') + glob.glob('../data/room_3/*')\n",
130 | "\n",
131 | "# Shuffle the paths\n",
132 | "random.shuffle(paths)\n",
133 | "\n",
134 | "# Create a datasets for training and validation\n",
135 | "split = int(0.8 * len(paths))\n",
136 | "train_data = ImageDataset(paths[:split], TRANSFORMS)\n",
137 | "valid_data = ImageDataset(paths[split:], TRANSFORMS)\n",
138 | "\n",
139 | "# Prepare data for Pytorch model\n",
140 | "train_loader = DataLoader(train_data, batch_size=32, shuffle=True)\n",
141 | "valid_loader = DataLoader(valid_data, batch_size=valid_data.__len__())\n",
142 | "\n",
143 | "print(train_data.__len__())\n",
144 | "print(valid_data.__len__())"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 63,
150 | "metadata": {},
151 | "outputs": [],
152 | "source": [
153 | "output_dim = 2 # x, y\n",
154 | "device = torch.device('cpu') # or 'cuda' if you have a GPU\n",
155 | "\n",
156 | "# RESNET 18\n",
157 | "model = torchvision.models.resnet18(pretrained=True)\n",
158 | "model.fc = torch.nn.Linear(512, output_dim)\n",
159 | "model = model.to(device)\n",
160 | "\n",
161 | "optimizer = torch.optim.Adam(model.parameters())"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "name = \"direction_model_1\" # Change this to save a new model\n",
171 | "\n",
172 | "# Train the model\n",
173 | "min_loss = np.inf\n",
174 | "for epoch in range(10):\n",
175 | "\n",
176 | " model = model.train()\n",
177 | " for images, target in iter(train_loader):\n",
178 | "\n",
179 | " images = images.to(device)\n",
180 | " target = target.to(device)\n",
181 | " \n",
182 | " # Zero gradients of parameters\n",
183 | " optimizer.zero_grad() \n",
184 | "\n",
185 | " # Execute model to get outputs\n",
186 | " output = model(images)\n",
187 | "\n",
188 | " # Calculate loss\n",
189 | " loss = torch.nn.functional.mse_loss(output, target)\n",
190 | "\n",
191 | " # Run backpropogation to accumulate gradients\n",
192 | " loss.backward()\n",
193 | "\n",
194 | " # Update model parameters\n",
195 | " optimizer.step()\n",
196 | "\n",
197 | " # Calculate validation loss\n",
198 | " model = model.eval()\n",
199 | "\n",
200 | " images, target = next(iter(valid_loader))\n",
201 | " images = images.to(device)\n",
202 | " target = target.to(device)\n",
203 | "\n",
204 | " output = model(images)\n",
205 | " valid_loss = torch.nn.functional.mse_loss(output, target)\n",
206 | "\n",
207 | " print(\"Epoch: {}, Validation Loss: {}\".format(epoch, valid_loss.item()))\n",
208 | " \n",
209 | " if valid_loss < min_loss:\n",
210 | " print(\"Saving model\")\n",
211 | " torch.save(model, '../models/{}.pth'.format(name))\n",
212 | "\n",
213 | " min_loss = valid_loss"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {},
219 | "source": [
220 | "# Model Evaluation"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 9,
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "def model_evaluation(loaders,labels,save_path = None):\n",
230 | "\n",
231 | " \"\"\"Evaluate direction models with mse and scatter plots\n",
232 | " loaders: list of data loaders\n",
233 | " labels: list of labels for plot title\"\"\"\n",
234 | "\n",
235 | " n = len(loaders)\n",
236 | " fig, axs = plt.subplots(1, n, figsize=(7*n, 6))\n",
237 | " fig.patch.set_facecolor('xkcd:white')\n",
238 | "\n",
239 | " # Evalution metrics\n",
240 | " for i, loader in enumerate(loaders):\n",
241 | "\n",
242 | " # Load all data\n",
243 | " images, target = next(iter(loader))\n",
244 | " images = images.to(device)\n",
245 | " target = target.to(device)\n",
246 | "\n",
247 | " output=model(images)\n",
248 | "\n",
249 | " # Get x predictions\n",
250 | " x_pred=output.detach().cpu().numpy()[:,0]\n",
251 | " x_target=target.cpu().numpy()[:,0]\n",
252 | "\n",
253 | " # Calculate MSE\n",
254 | " mse = mean_squared_error(x_target, x_pred)\n",
255 | "\n",
256 | " # Plot predcitons\n",
257 | " axs[i].scatter(x_target,x_pred)\n",
258 | " axs[i].plot([-1, 1], \n",
259 | " [-1, 1], \n",
260 | " color='r', \n",
261 | " linestyle='-', \n",
262 | " linewidth=2)\n",
263 | "\n",
264 | " axs[i].set_ylabel('Predicted x', size =15)\n",
265 | " axs[i].set_xlabel('Actual x', size =15)\n",
266 | " axs[i].set_title(\"{0} MSE: {1:.4f}\".format(labels[i], mse),size = 18)\n",
267 | "\n",
268 | " if save_path != None:\n",
269 | " fig.savefig(save_path)\n"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": null,
275 | "metadata": {},
276 | "outputs": [],
277 | "source": [
278 | "# Load saved model \n",
279 | "model = torch.load('../models/direction_model_1.pth')\n",
280 | "model.eval()\n",
281 | "model.to(device)\n",
282 | "\n",
283 | "# Create new loader for all data\n",
284 | "train_loader = DataLoader(train_data, batch_size=train_data.__len__())\n",
285 | "\n",
286 | "# Evaluate model on training and validation set\n",
287 | "loaders = [train_loader,valid_loader]\n",
288 | "labels = [\"Train\",\"Validation\"]\n",
289 | "\n",
290 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/evaluation_1.png\"\n",
291 | "model_evaluation(loaders,labels,save_path=path)"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": null,
297 | "metadata": {},
298 | "outputs": [],
299 | "source": [
300 | "# Evaluate on data for additonal rooms\n",
301 | "room_2 = glob.glob('../data/room_2/*')\n",
302 | "room_3 = glob.glob('../data/room_3/*')\n",
303 | "\n",
304 | "room_2_data = ImageDataset(room_2, TRANSFORMS)\n",
305 | "room_3_data = ImageDataset(room_3, TRANSFORMS)\n",
306 | "\n",
307 | "room_2_loader = DataLoader(room_2_data, batch_size=room_2_data.__len__())\n",
308 | "room_3_loader = DataLoader(room_3_data, batch_size=room_3_data.__len__())\n",
309 | "\n",
310 | "# Evaluate model on training and validation set\n",
311 | "loaders = [room_2_loader ,room_3_loader]\n",
312 | "labels = [\"Room 2\",\"Room 3\"]\n",
313 | "\n",
314 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/evaluation_2.png\"\n",
315 | "model_evaluation(loaders,labels, save_path=path)"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "# Load saved model \n",
325 | "model = torch.load('../models/direction_model_2.pth')\n",
326 | "\n",
327 | "model.eval()\n",
328 | "model.to(device)\n",
329 | "\n",
330 | "# Evaluate model on training and validation set\n",
331 | "loaders = [room_2_loader ,room_3_loader]\n",
332 | "labels = [\"Room 2\",\"Room 3\"]\n",
333 | "\n",
334 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/evaluation_3.png\"\n",
335 | "model_evaluation(loaders,labels,save_path=path)"
336 | ]
337 | },
338 | {
339 | "attachments": {},
340 | "cell_type": "markdown",
341 | "metadata": {},
342 | "source": [
343 | "# SHAP Explainer "
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "execution_count": 7,
349 | "metadata": {},
350 | "outputs": [],
351 | "source": [
352 | "# Load saved model \n",
353 | "model = torch.load('../models/direction_model_1.pth') #change for different model\n",
354 | "model.eval()\n",
355 | "\n",
356 | "# Use CPU\n",
357 | "device = torch.device('cpu')\n",
358 | "model = model.to(device)"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 8,
364 | "metadata": {},
365 | "outputs": [],
366 | "source": [
367 | "#Load 100 images for background\n",
368 | "shap_loader = DataLoader(train_data, batch_size=100, shuffle=True)\n",
369 | "background, _ = next(iter(shap_loader))\n",
370 | "background = background.to(device)\n",
371 | "\n",
372 | "#Create SHAP explainer \n",
373 | "explainer = shap.DeepExplainer(model, background)"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": null,
379 | "metadata": {},
380 | "outputs": [],
381 | "source": [
382 | "# Load test images of right and left turn\n",
383 | "paths = glob.glob('../data/room_1/*')\n",
384 | "test_images = [Image.open(paths[0]), Image.open(paths[3])]\n",
385 | "test_images = np.array(test_images)\n",
386 | "\n",
387 | "test_input = [TRANSFORMS(img) for img in test_images]\n",
388 | "test_input = torch.stack(test_input).to(device)\n",
389 | "\n",
390 | "# Get SHAP values\n",
391 | "shap_values = explainer.shap_values(test_input)\n",
392 | "\n",
393 | "# Reshape shap values and images for plotting\n",
394 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n",
395 | "test_numpy = np.array([np.array(img) for img in test_images])\n",
396 | "\n",
397 | "shap.image_plot(shap_numpy, test_numpy,show=False)"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": null,
403 | "metadata": {},
404 | "outputs": [],
405 | "source": [
406 | "# Using gradient explainer\n",
407 | "explainer = shap.GradientExplainer(model, background)\n",
408 | "shap_values = explainer.shap_values(test_input)\n",
409 | "\n",
410 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n",
411 | "\n",
412 | "shap.image_plot(shap_numpy, test_numpy)"
413 | ]
414 | },
415 | {
416 | "cell_type": "code",
417 | "execution_count": null,
418 | "metadata": {},
419 | "outputs": [],
420 | "source": [
421 | "# Load model trained on room 1, 2 and 3\n",
422 | "model = torch.load('../models/direction_model_2.pth') #change for different model\n",
423 | "\n",
424 | "# Use CPU\n",
425 | "device = torch.device('cpu')\n",
426 | "model = model.to(device)\n",
427 | "\n",
428 | "#Load 100 images for background\n",
429 | "shap_loader = DataLoader(train_data, batch_size=100, shuffle=True)\n",
430 | "background, _ = next(iter(shap_loader))\n",
431 | "background = background.to(device)\n",
432 | "\n",
433 | "#Create SHAP explainer \n",
434 | "explainer = shap.DeepExplainer(model, background)\n",
435 | "\n",
436 | "# Load test images of right and left turn\n",
437 | "paths = glob.glob('../data/room_1/*')\n",
438 | "test_images = [Image.open(paths[0]), Image.open(paths[3])]\n",
439 | "test_images = np.array(test_images)\n",
440 | "\n",
441 | "# Transform images\n",
442 | "test_input = [TRANSFORMS(img) for img in test_images]\n",
443 | "test_input = torch.stack(test_input).to(device)\n",
444 | "\n",
445 | "# Get SHAP values\n",
446 | "shap_values = explainer.shap_values(test_input)\n",
447 | "\n",
448 | "# Reshape shap values and images for plotting\n",
449 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n",
450 | "test_numpy = np.array([np.array(img) for img in test_images])\n",
451 | "\n",
452 | "shap.image_plot(shap_numpy, test_numpy,show=False)\n",
453 | "plt.savefig(\"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/shap_plot_2.png\",facecolor='white',dpi=300,bbox_inches='tight')"
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "execution_count": null,
459 | "metadata": {},
460 | "outputs": [],
461 | "source": [
462 | "# Using Gradient Explainer\n",
463 | "e = shap.GradientExplainer(model, background)\n",
464 | "shap_values = e.shap_values(test_input)\n",
465 | "\n",
466 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n",
467 | "test_numpy = np.array([np.array(img) for img in test_images])\n",
468 | "\n",
469 | "shap.image_plot(shap_numpy, test_numpy)"
470 | ]
471 | }
472 | ],
473 | "metadata": {
474 | "kernelspec": {
475 | "display_name": "pytorch",
476 | "language": "python",
477 | "name": "pytorch"
478 | },
479 | "language_info": {
480 | "codemirror_mode": {
481 | "name": "ipython",
482 | "version": 3
483 | },
484 | "file_extension": ".py",
485 | "mimetype": "text/x-python",
486 | "name": "python",
487 | "nbconvert_exporter": "python",
488 | "pygments_lexer": "ipython3",
489 | "version": "3.10.4 (main, Mar 31 2022, 03:37:37) [Clang 12.0.0 ]"
490 | },
491 | "orig_nbformat": 4,
492 | "vscode": {
493 | "interpreter": {
494 | "hash": "3c0d4fcf1a0a408688084e944cab5ef64e86c1ae9800e884f9b7a2ac0ee51db6"
495 | }
496 | }
497 | },
498 | "nbformat": 4,
499 | "nbformat_minor": 2
500 | }
501 |
--------------------------------------------------------------------------------
/src/archive/project_1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Project 1: salary bonus\n",
8 | "
\n",
9 | "Use SHAP to answer the following questions:\n",
10 | "