├── LICENSE ├── README.md ├── data ├── abalone.data ├── interaction_dataset.csv └── mushrooms.csv ├── requirements.txt └── src ├── additional_resources ├── CatBoostClassifier.ipynb ├── IsolationForest.ipynb └── RandomForestRegressor.ipynb ├── archive ├── image_data.ipynb ├── project_1.ipynb ├── project_1_solution.ipynb ├── project_2.ipynb └── project_2_solution.ipynb ├── kernel_vs_tree.ipynb └── shap_tutorial.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Conor O'Sullivan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SHAP-tutorial 2 | 3 | To access the course: https://adataodyssey.com/courses/shap-with-python/ 4 | 5 | Watch the course outline here: https://youtu.be/n98pFxcD73w 6 | -------------------------------------------------------------------------------- /data/interaction_dataset.csv: -------------------------------------------------------------------------------- 1 | "experience" "degree" "performance" "sales" "days_late" "bonus" 2 | 31 1 6.11 29 14 197 3 | 35 1 9.55 44 8 314 4 | 9 1 2.64 26 20 88 5 | 40 1 0.22 13 7 233 6 | 18 1 6.46 11 13 108 7 | 0 1 4.57 72 0 101 8 | 31 0 3.38 83 9 61 9 | 14 1 5.63 27 9 112 10 | 37 1 0.03 13 7 183 11 | 16 1 1.1 7 16 92 12 | 36 1 5.3 18 18 225 13 | 11 1 5.35 17 18 107 14 | 34 1 3.88 53 4 234 15 | 28 1 1.52 58 16 161 16 | 5 1 2.96 33 7 52 17 | 13 1 8.01 23 17 129 18 | 1 0 1.02 71 17 17 19 | 29 1 7.08 99 1 346 20 | 17 1 9.15 61 0 227 21 | 2 1 4.43 14 5 37 22 | 20 1 7.28 82 2 263 23 | 34 1 2.45 97 6 260 24 | 16 1 6.03 64 2 182 25 | 39 1 4.9 5 3 235 26 | 18 1 4.87 87 9 203 27 | 8 1 7.12 28 6 110 28 | 24 0 3.37 16 0 14 29 | 33 0 9.76 52 6 134 30 | 38 1 7.55 46 11 278 31 | 24 0 1.9 61 20 22 32 | 27 1 1.55 57 15 152 33 | 37 0 6.24 42 3 88 34 | 37 1 8.07 30 17 243 35 | 29 0 0.1 14 4 0 36 | 11 1 0.18 6 19 62 37 | 37 1 4.6 33 5 259 38 | 8 1 9.11 11 16 78 39 | 21 0 9.87 81 14 196 40 | 23 0 9.43 99 3 213 41 | 31 1 1.01 51 0 193 42 | 21 1 5.1 90 19 225 43 | 7 1 9.08 77 11 220 44 | 26 1 5.97 68 15 250 45 | 10 1 1.98 24 20 49 46 | 33 1 8.98 58 5 310 47 | 8 1 1.85 13 6 72 48 | 3 0 9.95 42 1 107 49 | 10 1 7.01 9 19 75 50 | 5 1 9.83 8 13 85 51 | 32 1 1.99 8 16 181 52 | 1 1 1.41 25 4 38 53 | 21 1 9.16 35 10 221 54 | 1 0 2.35 80 19 45 55 | 36 0 9.76 6 9 10 56 | 10 0 8.26 52 9 122 57 | 10 1 9.4 4 12 90 58 | 30 1 5.4 79 8 269 59 | 31 1 6.06 93 6 297 60 | 18 1 1.08 23 16 100 61 
| 20 0 7.94 13 12 42 62 | 40 1 4 35 12 259 63 | 36 1 3.5 93 19 285 64 | 13 1 3.69 3 17 66 65 | 11 0 7.29 91 11 159 66 | 16 1 4.61 61 6 178 67 | 9 1 7.62 90 9 230 68 | 4 1 4.63 31 1 90 69 | 25 0 2.39 9 8 0 70 | 26 0 3.25 3 6 11 71 | 20 1 7.69 49 18 189 72 | 32 1 8.42 36 14 266 73 | 36 0 6.82 24 7 49 74 | 38 1 1.98 90 15 258 75 | 21 1 2.27 37 14 132 76 | 28 1 5.12 32 2 214 77 | 29 1 6.41 10 14 190 78 | 40 0 1.46 35 0 3 79 | 23 1 8.28 50 15 211 80 | 24 0 6.03 76 1 114 81 | 0 0 2.61 90 11 51 82 | 31 0 1.62 52 5 13 83 | 36 0 8.29 3 6 6 84 | 19 0 3.75 53 11 51 85 | 33 1 1.94 56 5 235 86 | 15 1 9.72 96 5 323 87 | 14 1 3.2 18 2 109 88 | 23 1 2.47 57 20 156 89 | 38 0 4.73 40 4 59 90 | 30 1 6.03 66 0 263 91 | 22 1 4.55 34 0 182 92 | 29 1 6.04 75 10 291 93 | 6 1 0.43 30 4 42 94 | 33 1 1.86 68 1 221 95 | 37 0 8.79 39 9 70 96 | 0 1 5.5 26 5 53 97 | 10 1 1.88 87 19 97 98 | 33 0 4.14 31 10 12 99 | 16 0 2.22 26 11 4 100 | 25 1 5.14 39 20 203 101 | 36 1 1.8 16 18 211 102 | 18 1 2.24 9 1 135 103 | 24 0 0.95 90 4 52 104 | 5 1 6.63 95 13 192 105 | 7 0 1.09 56 19 1 106 | 21 1 6.42 99 15 283 107 | 3 1 3.09 49 16 84 108 | 26 1 5.47 19 7 159 109 | 38 1 1.86 54 0 225 110 | 39 0 3.89 91 0 96 111 | 28 1 1.43 86 5 199 112 | 27 1 8.35 59 16 257 113 | 13 1 4.02 12 17 102 114 | 5 1 2.78 7 19 26 115 | 10 1 0.63 90 20 65 116 | 33 1 2.73 29 0 198 117 | 32 1 8.67 99 13 386 118 | 1 1 8.68 18 6 86 119 | 23 1 5.75 68 15 214 120 | 22 1 4.16 3 13 109 121 | 33 1 4.59 82 2 284 122 | 11 1 0.7 67 9 71 123 | 35 0 1.17 50 15 0 124 | 7 1 4.84 61 15 110 125 | 4 1 1.79 57 10 48 126 | 13 1 7.69 55 19 179 127 | 15 1 4.63 100 6 217 128 | 27 1 9.6 77 3 348 129 | 37 1 8.02 54 16 284 130 | 4 1 9.6 15 12 90 131 | 11 1 0.33 5 19 40 132 | 3 1 2.98 21 19 47 133 | 9 1 3.68 93 13 146 134 | 23 1 9.36 37 16 228 135 | 26 1 1.35 65 16 174 136 | 2 1 7.83 60 16 133 137 | 18 1 5.13 69 5 180 138 | 21 1 4.81 47 1 193 139 | 6 1 8.18 93 0 253 140 | 23 1 0.64 62 16 154 141 | 35 1 5.31 20 9 236 142 | 17 1 9.63 72 13 248 143 | 12 1 0.61 
29 19 68 144 | 35 1 3.56 59 16 247 145 | 30 0 5.98 92 7 145 146 | 7 1 1.8 38 2 73 147 | 16 1 1.7 60 5 149 148 | 30 0 8.69 12 18 28 149 | 14 1 8.87 10 6 141 150 | 34 1 7.31 17 13 211 151 | 11 0 9.67 14 12 41 152 | 19 0 9.81 33 4 91 153 | 20 1 1.1 87 16 134 154 | 17 1 2.43 54 8 130 155 | 22 1 5.11 39 16 162 156 | 16 1 5.69 60 16 191 157 | 27 0 2.13 42 0 26 158 | 31 0 6.35 98 19 149 159 | 36 0 3.09 61 7 41 160 | 19 1 3.8 36 8 131 161 | 22 0 6.6 7 14 2 162 | 7 1 4.1 88 0 169 163 | 4 1 1.26 83 14 69 164 | 20 1 8.97 73 12 256 165 | 11 0 8.85 11 0 53 166 | 16 1 2.45 16 8 110 167 | 4 1 2.31 54 10 58 168 | 6 1 4.07 44 5 81 169 | 31 1 4.83 18 13 199 170 | 32 1 7.19 39 8 265 171 | 33 1 9.45 4 5 213 172 | 20 0 9.1 61 1 123 173 | 0 1 6.69 100 15 173 174 | 33 0 1.82 44 9 2 175 | 2 1 5.19 91 17 135 176 | 26 0 3.02 35 8 9 177 | 12 0 9.26 62 9 151 178 | 27 1 0.08 30 19 127 179 | 11 1 2.03 100 8 121 180 | 8 1 3.8 79 11 146 181 | 25 1 0.12 39 15 131 182 | 31 1 5.54 92 18 288 183 | 0 1 5.35 44 7 73 184 | 21 1 7.61 53 0 250 185 | 19 1 6.33 59 5 213 186 | 32 0 7.5 6 3 34 187 | 39 0 5.81 47 11 53 188 | 21 1 7.51 38 19 197 189 | 7 1 1.1 61 19 56 190 | 13 1 3.53 71 16 137 191 | 13 1 7.14 57 0 209 192 | 32 1 1.16 58 14 210 193 | 28 1 9.33 56 8 269 194 | 12 1 5.77 35 3 150 195 | 34 1 4.6 63 6 247 196 | 31 0 3.41 48 4 41 197 | 20 1 7 75 14 243 198 | 35 1 6.28 27 9 244 199 | 38 1 6.78 10 7 240 200 | 39 0 5.95 16 1 32 201 | 22 1 5.25 70 18 222 202 | 9 1 7.16 27 5 128 203 | 8 1 4.68 52 7 102 204 | 10 1 3.92 45 9 99 205 | 29 1 6.63 99 11 338 206 | 12 1 2.97 55 5 114 207 | 14 1 1.16 35 18 99 208 | 24 0 5.85 93 8 119 209 | 22 1 1.52 11 18 132 210 | 36 1 4.45 10 18 186 211 | 8 0 0.65 27 13 0 212 | 9 1 2.17 94 13 111 213 | 37 1 1.44 41 6 240 214 | 14 1 5.24 1 1 109 215 | 2 1 0.97 60 12 28 216 | 24 1 4.17 89 0 232 217 | 7 1 8.69 70 4 187 218 | 1 0 5.83 35 13 48 219 | 1 1 5.91 29 9 58 220 | 24 1 8.64 94 2 339 221 | 24 1 8.93 43 17 246 222 | 20 1 4.55 11 7 127 223 | 5 1 3.29 75 17 89 224 | 23 1 7.99 13 
15 149 225 | 23 0 7.52 23 12 57 226 | 18 1 3.27 2 4 112 227 | 37 0 1.06 80 10 24 228 | 5 1 6.36 6 19 61 229 | 3 0 3.99 74 2 92 230 | 5 1 2.31 8 0 72 231 | 34 1 7.45 53 4 292 232 | 21 1 9.25 83 3 316 233 | 8 1 1.04 9 10 67 234 | 35 0 4.3 37 0 42 235 | 38 0 2.97 58 16 40 236 | 28 1 5.47 27 19 169 237 | 24 1 3.67 7 8 163 238 | 11 1 2.68 61 4 108 239 | 26 0 5.22 82 13 115 240 | 40 1 7.92 33 12 279 241 | 36 0 0.76 87 17 5 242 | 25 1 7.47 54 17 244 243 | 27 0 9.04 13 14 34 244 | 18 1 1.8 98 19 145 245 | 13 1 1.09 19 20 71 246 | 8 1 9.89 25 8 129 247 | 31 0 0.58 91 4 20 248 | 14 1 1.41 5 19 70 249 | 17 1 8.95 88 19 281 250 | 17 1 5.96 53 17 173 251 | 11 1 4.63 42 5 121 252 | 28 0 8.44 98 5 189 253 | 4 0 9.54 45 8 111 254 | 37 1 0.79 41 5 217 255 | 39 0 8.65 27 11 52 256 | 23 1 1.82 50 4 151 257 | 12 1 0.22 6 18 61 258 | 38 0 3.45 99 8 92 259 | 28 1 9.74 90 3 361 260 | 40 0 6.29 24 10 28 261 | 14 1 3.34 55 12 141 262 | 16 1 1.27 27 0 130 263 | 14 1 5.34 58 18 161 264 | 23 1 5.01 31 4 190 265 | 37 1 6.75 3 5 229 266 | 27 1 2 74 12 195 267 | 30 1 1.4 26 17 151 268 | 15 1 7.61 4 0 104 269 | 19 0 0.14 30 9 7 270 | 39 1 6.57 74 0 337 271 | 27 1 7.8 34 1 244 272 | 27 1 2.31 11 1 178 273 | 23 0 7.74 69 8 137 274 | 23 0 6.15 62 13 86 275 | 2 1 7.14 91 10 189 276 | 37 1 9.4 62 3 335 277 | 4 1 1.86 59 19 46 278 | 2 0 8.85 51 18 108 279 | 13 1 2.58 61 12 128 280 | 34 0 5.9 5 8 28 281 | 17 0 5.89 5 17 16 282 | 20 1 6.14 33 4 160 283 | 26 1 0.04 5 11 133 284 | 25 1 8.71 9 7 192 285 | 25 1 0.17 45 5 133 286 | 19 1 6.56 20 6 165 287 | 32 1 5.25 4 8 170 288 | 33 1 1.78 92 20 227 289 | 0 1 9.19 12 20 46 290 | 13 1 3.24 8 7 90 291 | 6 1 9.74 97 0 297 292 | 36 1 6.6 61 8 295 293 | 15 1 9.05 45 14 187 294 | 28 1 1.25 41 4 183 295 | 7 1 0.95 53 9 84 296 | 6 1 3.51 12 2 50 297 | 29 1 4.15 83 18 226 298 | 30 1 9.46 44 6 270 299 | 36 1 7.39 51 8 297 300 | 9 1 2.58 91 20 106 301 | 20 0 2.54 93 11 66 302 | 26 0 7.77 39 4 72 303 | 37 1 4.39 47 13 232 304 | 2 1 1.16 50 5 46 305 | 11 1 5.96 22 4 102 
306 | 40 1 2.72 94 5 303 307 | 26 1 7.52 65 0 273 308 | 11 0 8.07 15 17 35 309 | 36 0 4.77 11 12 15 310 | 9 0 6 58 17 89 311 | 40 1 4.66 20 20 238 312 | 35 1 7.77 74 3 327 313 | 29 1 3.68 11 13 172 314 | 17 1 7.37 93 6 255 315 | 18 0 2.9 22 6 4 316 | 31 1 9.02 100 20 386 317 | 26 1 4.34 16 2 183 318 | 32 1 5.13 47 18 244 319 | 10 1 5.48 92 9 206 320 | 25 0 9.11 68 7 134 321 | 34 0 5.22 11 8 0 322 | 38 0 6.66 67 6 124 323 | 3 0 6.36 7 18 0 324 | 10 1 9.17 21 3 146 325 | 29 0 9.33 3 1 15 326 | 28 1 2.64 46 13 185 327 | 10 1 4.1 49 12 114 328 | 10 0 5.46 98 11 126 329 | 22 1 4.07 100 8 237 330 | 7 1 0.25 27 9 41 331 | 22 1 7.75 37 10 207 332 | 24 1 2.8 12 7 138 333 | 38 1 7.12 4 14 233 334 | 18 1 8.4 7 12 123 335 | 32 0 4.34 10 1 33 336 | 10 1 0.47 21 9 83 337 | 5 1 2.87 12 20 35 338 | 40 1 1.21 61 10 252 339 | 31 1 6.12 81 6 278 340 | 1 1 9.8 63 20 143 341 | 11 1 3.42 83 13 136 342 | 37 0 0.86 45 0 2 343 | 22 0 1.73 66 14 19 344 | 32 1 9.73 59 3 334 345 | 25 0 2.95 81 7 78 346 | 29 1 0.65 56 17 174 347 | 14 0 0.82 48 10 2 348 | 14 1 5.17 90 5 202 349 | 23 1 0.54 30 8 149 350 | 29 0 0.8 49 18 9 351 | 27 1 0.3 30 20 145 352 | 35 1 8.11 79 5 340 353 | 17 1 9.21 78 5 273 354 | 5 1 9.7 100 18 278 355 | 1 1 7.53 49 2 125 356 | 16 1 9.3 47 19 197 357 | 39 0 1.33 20 17 0 358 | 10 1 6.74 13 12 88 359 | 10 1 8.18 55 17 170 360 | 18 0 8.44 37 12 85 361 | 2 0 5.22 27 12 21 362 | 7 0 1.15 2 4 0 363 | 0 0 5.14 13 2 5 364 | 26 0 2.55 85 18 34 365 | 5 0 9.59 20 7 60 366 | 13 1 0.8 34 0 100 367 | 22 1 0.57 89 15 145 368 | 30 1 8.27 22 9 200 369 | 16 1 5.93 16 17 110 370 | 23 1 9.43 19 15 174 371 | 28 1 1.17 63 3 191 372 | 4 0 4.02 69 5 64 373 | 27 1 9.6 54 8 287 374 | 27 0 6.49 97 12 128 375 | 6 0 9.95 94 3 225 376 | 25 1 2.34 1 14 143 377 | 9 1 5.04 40 3 129 378 | 14 1 2.77 81 11 135 379 | 22 0 2.87 1 18 3 380 | 23 1 9.75 81 4 316 381 | 21 1 2.62 77 4 182 382 | 37 1 1.24 55 5 244 383 | 32 1 4.64 37 13 208 384 | 29 0 9.87 31 2 66 385 | 22 1 4.26 72 7 211 386 | 9 1 8.4 16 1 108 387 | 
27 1 0.02 22 9 134 388 | 13 1 4.62 62 19 139 389 | 1 1 4.44 68 12 114 390 | 15 1 2.25 43 14 122 391 | 20 1 7.93 28 15 186 392 | 34 1 7.37 13 4 217 393 | 25 0 4.81 99 0 129 394 | 34 1 3.84 24 2 235 395 | 27 1 1.11 46 12 148 396 | 16 0 0.5 74 4 16 397 | 38 1 4.41 98 20 318 398 | 14 1 5.41 79 10 177 399 | 16 0 0.43 61 19 3 400 | 27 0 1.22 75 11 11 401 | 9 1 9.4 74 12 207 402 | 26 0 2.59 0 12 0 403 | 23 1 3.41 45 17 166 404 | 18 0 6.48 99 6 158 405 | 22 1 9.16 90 2 334 406 | 21 0 3.91 74 2 70 407 | 4 1 5.56 90 19 168 408 | 40 1 8.98 25 18 276 409 | 2 1 7.48 52 5 143 410 | 34 1 8.89 6 8 230 411 | 12 1 5.53 97 2 215 412 | 15 1 9.23 54 10 199 413 | 33 1 8.95 65 2 343 414 | 37 1 1.2 17 8 193 415 | 0 1 7.26 27 6 72 416 | 7 1 0.24 96 8 58 417 | 36 1 2.26 0 16 182 418 | 8 0 8.03 69 12 125 419 | 21 1 2.45 72 0 172 420 | 40 0 9.96 53 0 146 421 | 34 1 1.82 74 14 239 422 | 29 1 5.88 80 2 304 423 | 6 0 3.03 96 14 72 424 | 7 1 5.96 34 3 118 425 | 15 1 2.59 32 0 113 426 | 2 0 8.9 92 18 197 427 | 3 1 5.9 90 18 150 428 | 7 0 3.98 38 8 36 429 | 32 0 0.25 69 2 30 430 | 3 1 8.06 78 16 186 431 | 30 0 9.58 59 11 114 432 | 28 0 3.52 3 9 0 433 | 6 0 8.29 9 3 44 434 | 15 1 0.52 56 3 118 435 | 33 0 9.56 63 10 161 436 | 0 0 4.6 17 12 23 437 | 36 1 2.13 65 8 243 438 | 36 1 7.05 88 9 339 439 | 6 1 1.39 95 12 76 440 | 40 1 9.31 2 7 241 441 | 2 1 0.61 5 19 5 442 | 38 0 5.6 41 2 70 443 | 25 1 4 61 14 181 444 | 36 1 5.36 20 20 199 445 | 6 1 9.47 32 12 106 446 | 40 1 3.43 92 16 302 447 | 18 0 1.56 76 12 22 448 | 34 1 5.37 40 3 264 449 | 1 0 2.69 19 2 35 450 | 20 0 4.43 62 15 42 451 | 12 1 6.95 22 13 114 452 | 6 1 3.22 9 1 51 453 | 38 1 5.32 27 18 246 454 | 7 1 6.47 30 6 123 455 | 10 1 8.69 43 20 143 456 | 24 1 2.6 15 9 142 457 | 20 1 9.37 84 8 312 458 | 26 1 8.78 56 18 269 459 | 7 1 1.42 11 3 63 460 | 25 1 5.71 74 9 247 461 | 21 1 1.68 12 9 132 462 | 32 0 9.69 31 15 75 463 | 8 1 4.39 38 9 97 464 | 15 1 4.6 4 10 95 465 | 13 1 2.88 92 4 162 466 | 28 1 9.21 6 16 194 467 | 2 1 5.99 44 18 90 468 | 26 0 
0.08 38 20 0 469 | 9 1 0.79 34 20 57 470 | 27 1 1.26 40 3 188 471 | 22 1 2.37 27 7 160 472 | 18 0 2.03 71 6 25 473 | 19 1 1.77 64 19 123 474 | 30 1 0.65 66 9 176 475 | 26 0 8.47 78 18 149 476 | 39 1 0.72 54 13 233 477 | 35 0 0.51 82 7 12 478 | 4 1 1.8 13 15 30 479 | 34 1 3.29 21 11 215 480 | 11 1 0.4 78 9 98 481 | 27 0 8.94 27 20 35 482 | 17 1 1.11 69 6 139 483 | 18 1 7.4 15 15 116 484 | 6 0 6.11 46 10 77 485 | 37 0 3.38 35 1 29 486 | 16 1 6.5 52 8 195 487 | 34 1 6.72 93 10 328 488 | 36 1 5.07 96 10 328 489 | 0 0 4.72 63 1 86 490 | 5 1 8.63 1 10 62 491 | 15 0 0.33 64 17 15 492 | 1 1 2.83 67 6 85 493 | 38 0 4.45 52 18 67 494 | 20 1 3.43 1 15 109 495 | 39 0 4 54 16 56 496 | 11 0 5.34 41 10 61 497 | 26 1 8.62 94 2 351 498 | 30 1 0.02 11 20 150 499 | 30 1 6.75 78 15 305 500 | 18 1 9.88 36 16 186 501 | 21 0 8.63 58 20 129 502 | 28 1 5.38 36 4 224 503 | 15 1 4.72 41 6 134 504 | 12 1 6.66 19 8 108 505 | 17 1 7.11 0 20 85 506 | 7 1 7.51 8 0 78 507 | 39 1 2.77 71 14 253 508 | 4 1 1.09 99 15 88 509 | 26 0 9.94 68 0 178 510 | 39 1 5.84 22 2 237 511 | 13 1 5.81 29 18 106 512 | 4 1 3.28 100 3 132 513 | 3 1 9.24 44 9 121 514 | 24 1 6.75 60 10 225 515 | 30 1 5.94 68 3 267 516 | 7 1 0.62 71 20 63 517 | 31 0 4.43 34 9 45 518 | 21 1 4.57 85 5 239 519 | 23 0 5.59 100 0 123 520 | 17 0 4.93 77 2 116 521 | 8 1 1.36 74 13 99 522 | 10 1 5.24 95 12 195 523 | 7 1 0.14 91 11 73 524 | 17 0 4.57 6 3 0 525 | 26 1 3 38 12 162 526 | 35 1 9.03 84 5 364 527 | 5 1 6.04 67 12 150 528 | 10 1 0.53 35 17 78 529 | 38 1 1.95 91 8 241 530 | 11 1 8.16 93 3 270 531 | 2 0 6.9 68 4 136 532 | 33 1 8.55 12 15 218 533 | 40 1 7.55 63 8 333 534 | 7 1 8.75 20 5 97 535 | 26 1 5.03 81 12 253 536 | 30 1 3.1 54 12 221 537 | 6 0 2.58 7 9 0 538 | 4 1 6.38 34 1 117 539 | 37 1 5.25 80 7 327 540 | 9 0 5.37 71 16 104 541 | 33 0 2.91 78 13 54 542 | 4 1 7.8 17 11 85 543 | 0 1 0.24 72 17 33 544 | 19 1 5.74 26 17 136 545 | 33 1 3.9 51 18 214 546 | 3 1 1.91 23 9 25 547 | 14 0 0.56 32 8 0 548 | 13 1 6.85 91 7 239 549 | 19 1 8.5 87 
8 305 550 | 37 1 2.11 24 6 236 551 | 2 1 7.55 75 5 181 552 | 17 1 2.56 42 11 135 553 | 18 0 1.28 9 5 7 554 | 17 1 4.24 24 13 108 555 | 30 1 2.24 9 20 141 556 | 11 1 5.42 97 0 216 557 | 25 1 6.38 84 18 268 558 | 26 1 3.94 15 5 176 559 | 19 0 1.92 1 0 1 560 | 21 1 0.58 52 17 112 561 | 31 1 3.76 69 8 233 562 | 8 1 9.65 2 12 52 563 | 21 0 6.47 94 2 163 564 | 19 1 9.21 65 8 248 565 | 1 0 5.47 66 4 74 566 | 5 0 4.75 93 3 109 567 | 13 1 0.06 7 6 74 568 | 20 1 1.75 82 5 161 569 | 30 1 2.28 63 20 215 570 | 0 1 2.65 47 5 58 571 | 9 1 1.26 76 18 74 572 | 33 1 5.83 92 1 334 573 | 2 1 1.44 18 19 27 574 | 29 1 8.44 19 14 196 575 | 7 1 4.9 90 20 133 576 | 21 1 8.72 42 0 213 577 | 30 1 7.7 75 3 328 578 | 6 1 0.12 5 6 62 579 | 40 1 3.9 91 2 329 580 | 33 1 2.03 70 2 210 581 | 2 0 9.26 17 20 42 582 | 12 1 1.29 33 10 92 583 | 5 0 0.24 50 20 1 584 | 22 1 1.43 73 7 171 585 | 35 1 5.36 75 18 270 586 | 4 1 0.45 51 7 54 587 | 31 1 5.69 100 0 330 588 | 27 1 3.6 88 10 243 589 | 6 1 1.44 50 7 89 590 | 31 1 8.99 44 19 266 591 | 0 1 6.42 24 14 72 592 | 35 1 6.44 1 16 174 593 | 10 1 6.74 93 17 218 594 | 33 1 8.54 85 3 347 595 | 4 0 9.43 85 4 201 596 | 7 1 9.1 12 7 83 597 | 5 0 4.51 84 16 86 598 | 28 1 5.96 21 3 184 599 | 29 1 5.12 93 16 262 600 | 10 1 3.17 13 2 77 601 | 1 1 2.06 97 8 100 602 | 4 1 6.74 57 17 110 603 | 24 0 5.62 51 2 68 604 | 0 0 2.84 34 4 14 605 | 11 0 8.69 49 0 131 606 | 25 1 1.05 18 4 156 607 | 19 1 6.13 88 0 233 608 | 34 0 9.47 34 7 71 609 | 25 1 7.13 0 1 159 610 | 6 1 0.79 21 7 47 611 | 26 1 1.39 51 0 191 612 | 1 0 4.44 50 7 72 613 | 31 1 1.28 20 17 174 614 | 6 1 4.77 73 10 143 615 | 20 1 3.06 9 8 110 616 | 19 1 4.21 27 0 154 617 | 37 0 5.52 31 7 47 618 | 22 1 6.28 16 19 154 619 | 17 1 2.66 48 12 123 620 | 26 1 6.67 28 0 188 621 | 19 1 2.23 66 5 147 622 | 25 1 1.11 9 12 143 623 | 37 1 1.66 2 20 188 624 | 15 1 7.5 9 1 105 625 | 0 1 7.18 59 8 116 626 | 24 0 4.14 53 9 72 627 | 22 1 5.43 37 16 167 628 | 20 1 6.61 22 18 161 629 | 0 1 1.44 9 19 8 630 | 33 0 2.07 41 6 13 631 | 18 1 
2.89 41 20 128 632 | 1 1 2.28 57 3 47 633 | 8 1 8.02 100 15 257 634 | 23 1 2.99 34 17 160 635 | 38 1 1.95 32 0 231 636 | 21 1 7.43 35 0 179 637 | 35 1 0.46 27 18 198 638 | 34 1 1.71 82 12 228 639 | 8 0 7.7 77 8 143 640 | 28 1 9.95 44 12 280 641 | 23 1 8.81 26 10 197 642 | 7 1 7.79 72 12 176 643 | 23 1 9.36 38 3 225 644 | 32 0 4.8 80 6 103 645 | 25 1 7.47 93 1 332 646 | 18 1 4.93 14 9 144 647 | 34 1 3.52 76 7 272 648 | 19 1 1.55 40 3 136 649 | 21 1 2.21 30 16 145 650 | 39 1 8.94 83 18 396 651 | 13 0 0.93 12 4 20 652 | 22 1 4.26 50 8 167 653 | 22 0 7.54 91 14 151 654 | 27 1 5.94 96 0 303 655 | 17 0 3.01 47 17 7 656 | 6 0 2.19 92 7 40 657 | 1 1 4 8 6 41 658 | 29 0 0.65 24 12 0 659 | 14 1 2.79 29 8 103 660 | 27 1 2.42 51 13 180 661 | 11 1 7.28 52 0 160 662 | 19 0 3.88 98 11 78 663 | 8 1 9.4 87 1 272 664 | 34 1 6.43 14 1 216 665 | 9 1 0.18 13 17 58 666 | 24 1 1.29 25 20 146 667 | 15 1 0.22 38 0 92 668 | 24 1 2.43 74 7 193 669 | 29 1 3.79 64 13 206 670 | 3 1 8.5 77 3 196 671 | 24 0 8.58 25 10 68 672 | 40 1 0.52 74 19 237 673 | 7 1 3.68 6 16 62 674 | 34 0 9.34 45 6 125 675 | 9 1 5.86 83 8 179 676 | 33 1 7.93 0 2 215 677 | 18 1 2.69 11 19 124 678 | 2 1 3.11 17 5 34 679 | 9 1 4.62 60 19 140 680 | 6 1 8.42 56 12 180 681 | 36 1 6.39 6 3 215 682 | 15 1 8.78 98 1 285 683 | 34 1 0.21 7 18 170 684 | 3 1 5.45 26 11 64 685 | 26 0 3.52 71 1 60 686 | 18 1 8.84 47 10 219 687 | 3 1 4.25 19 14 49 688 | 4 0 7.1 23 0 53 689 | 19 1 0.7 47 11 102 690 | 7 1 3.08 62 19 78 691 | 38 1 5.52 87 19 311 692 | 20 1 4.53 63 19 188 693 | 1 1 6.47 68 10 133 694 | 24 1 1.05 58 5 152 695 | 14 0 3.27 69 4 72 696 | 35 0 5.75 34 9 61 697 | 39 1 1.78 83 7 241 698 | 20 0 8.13 73 16 151 699 | 13 0 7.07 83 19 112 700 | 15 1 3.38 89 9 151 701 | 35 1 5.06 66 1 293 702 | 5 1 6.22 35 5 99 703 | 23 0 8.74 71 8 137 704 | 17 1 1.32 37 1 129 705 | 19 1 3.09 97 1 192 706 | 29 1 0.59 18 17 169 707 | 19 1 2.04 100 8 171 708 | 24 1 0.37 46 13 143 709 | 6 1 9.33 60 16 167 710 | 34 1 7.21 82 11 336 711 | 33 1 6.22 44 5 257 
712 | 32 1 3.66 10 16 181 713 | 21 1 2.18 75 20 138 714 | 0 1 8.29 67 11 149 715 | 7 1 2.38 4 7 70 716 | 29 1 2.36 82 6 213 717 | 15 1 2.13 61 12 109 718 | 16 1 6.53 28 12 148 719 | 32 1 8.54 7 4 194 720 | 15 1 1.94 75 16 142 721 | 17 1 8.72 43 6 219 722 | 15 0 1.06 45 20 0 723 | 32 1 4.51 20 12 195 724 | 37 0 5.77 23 7 32 725 | 32 1 6.35 34 11 239 726 | 40 1 0.78 57 17 240 727 | 32 1 7.67 61 15 282 728 | 4 1 4.64 33 5 74 729 | 8 0 5.53 48 4 62 730 | 33 1 3.96 13 4 198 731 | 5 1 6.27 50 14 103 732 | 18 1 7.19 48 11 193 733 | 36 0 3.03 53 0 67 734 | 33 1 4.61 51 0 249 735 | 16 1 1.93 15 4 90 736 | 33 1 5.47 90 18 281 737 | 20 1 7.91 72 8 274 738 | 18 1 9.6 93 8 312 739 | 7 1 7.91 74 8 176 740 | 37 1 4.01 72 14 270 741 | 7 1 2.9 78 2 102 742 | 29 1 5.59 54 19 234 743 | 5 1 5.17 60 1 132 744 | 12 1 1.24 69 11 115 745 | 14 1 7.39 86 10 252 746 | 1 1 0.09 63 17 18 747 | 30 1 4.93 75 2 262 748 | 31 1 4.42 92 20 272 749 | 30 1 8.96 8 18 203 750 | 36 0 5.3 94 9 125 751 | 28 1 7.11 74 16 259 752 | 7 1 9.91 68 7 220 753 | 36 1 0.48 46 16 213 754 | 36 0 6.79 46 0 88 755 | 10 1 1.21 7 5 61 756 | 18 1 8.06 71 15 252 757 | 6 0 0.68 89 5 35 758 | 11 1 1.14 47 20 71 759 | 17 1 9.1 67 17 250 760 | 39 0 4.73 90 19 102 761 | 8 1 2.79 41 17 93 762 | 0 1 1.03 89 12 57 763 | 8 1 1.77 25 5 62 764 | 1 1 4.14 31 11 61 765 | 7 1 4.33 65 9 105 766 | 7 1 9.47 100 7 274 767 | 5 1 5.57 35 7 88 768 | 33 1 1.84 95 17 212 769 | 38 0 6.11 78 3 139 770 | 38 1 1.82 80 15 229 771 | 35 1 5.36 42 11 264 772 | 2 1 8.13 94 15 205 773 | 34 1 5.88 97 19 306 774 | 40 1 2.36 26 12 231 775 | 13 1 2.15 64 13 134 776 | 8 1 0.79 47 2 96 777 | 39 0 1.06 12 11 13 778 | 37 0 0.15 23 8 9 779 | 17 1 1.34 91 14 135 780 | 37 1 0.51 18 15 216 781 | 20 1 8.01 52 6 229 782 | 24 1 8.84 16 15 182 783 | 0 0 4.81 28 20 34 784 | 15 1 9.18 13 6 127 785 | 24 1 8.01 4 2 154 786 | 25 1 4 70 8 234 787 | 18 1 6.67 84 15 222 788 | 25 1 5.02 92 10 275 789 | 27 1 9.44 13 8 198 790 | 17 0 5 30 15 21 791 | 20 1 3.66 67 5 188 792 | 5 0 
7.34 12 5 19 793 | 0 1 9.51 16 16 69 794 | 27 0 6.23 5 18 9 795 | 29 1 6.51 82 15 289 796 | 38 0 3.05 71 4 50 797 | 25 0 6.36 38 5 80 798 | 32 1 1.86 43 3 210 799 | 6 1 7.26 59 13 156 800 | 40 1 4.86 79 16 291 801 | 5 0 4.33 89 6 101 802 | 5 1 4.51 99 16 153 803 | 37 1 4.74 92 20 290 804 | 30 1 3.73 64 4 224 805 | 8 1 4.2 95 19 157 806 | 35 0 2.85 96 12 58 807 | 2 0 9.8 76 11 169 808 | 31 1 2.91 29 20 185 809 | 17 1 9.76 30 7 202 810 | 0 1 0.64 8 12 0 811 | 0 1 9.57 34 12 111 812 | 21 1 4.35 29 1 180 813 | 14 1 0.18 92 6 93 814 | 9 1 5.14 81 7 152 815 | 21 1 7.05 32 18 192 816 | 10 0 8.99 94 20 168 817 | 38 1 1.19 74 2 253 818 | 6 1 9.5 45 18 137 819 | 9 0 0.37 78 12 4 820 | 14 1 4.21 54 15 133 821 | 16 1 6.07 96 8 258 822 | 11 0 7.29 81 5 153 823 | 31 1 1.34 83 8 198 824 | 21 1 2.45 46 13 142 825 | 22 1 0.29 61 9 148 826 | 27 1 0.14 13 11 148 827 | 26 1 6.26 87 4 300 828 | 1 0 8.21 73 11 148 829 | 19 1 3.16 28 2 160 830 | 15 0 5.59 1 2 13 831 | 30 1 1.4 7 10 169 832 | 15 1 3.96 16 0 111 833 | 18 1 7.94 14 13 143 834 | 0 1 8.44 15 20 34 835 | 20 0 0.32 84 12 27 836 | 12 1 9.34 92 5 269 837 | 25 1 1.96 2 7 159 838 | 17 1 8.31 56 18 189 839 | 31 1 5.61 4 18 183 840 | 15 1 7.08 78 10 211 841 | 38 0 5.65 14 15 25 842 | 39 1 4.67 8 12 221 843 | 10 0 8.09 16 18 30 844 | 6 0 8.22 39 6 71 845 | 20 1 4.57 2 12 118 846 | 28 1 8.74 86 14 349 847 | 13 1 9.04 49 12 174 848 | 24 1 2.54 66 8 185 849 | 5 0 6.41 41 4 66 850 | 40 1 0.59 68 1 232 851 | 23 1 2.05 43 11 144 852 | 13 1 4.75 83 20 172 853 | 16 1 5.1 86 11 220 854 | 28 0 2.34 87 0 64 855 | 26 0 2.6 98 0 76 856 | 30 0 1.48 48 17 0 857 | 38 1 7.72 3 15 218 858 | 7 1 2.81 3 5 74 859 | 40 1 5.67 69 6 324 860 | 31 1 2.99 36 12 215 861 | 30 1 6.18 47 20 219 862 | 28 1 5.02 36 4 224 863 | 23 1 0.62 54 9 156 864 | 31 0 6.08 11 2 35 865 | 16 1 7.27 49 7 192 866 | 20 1 3.66 28 8 132 867 | 5 1 9.93 52 11 177 868 | 36 0 0.02 92 4 12 869 | 39 1 7.88 52 0 317 870 | 29 0 1.64 73 15 21 871 | 37 0 8.97 78 2 171 872 | 5 1 1.3 39 11 67 873 
| 14 1 1.32 24 20 80 874 | 20 1 9.35 19 14 180 875 | 24 1 8.41 47 19 210 876 | 4 1 0.84 76 17 48 877 | 15 1 1.17 94 18 115 878 | 22 1 8.08 95 20 292 879 | 27 1 3.39 28 9 170 880 | 5 1 2.03 91 17 77 881 | 17 1 8.91 10 14 122 882 | 40 0 6.76 95 9 151 883 | 7 1 7.79 2 13 72 884 | 11 1 3.69 61 20 134 885 | 26 0 5.28 1 18 0 886 | 14 1 6.1 5 5 88 887 | 29 1 5.17 10 10 195 888 | 26 0 4.28 91 14 88 889 | 8 1 4.04 73 3 128 890 | 11 1 2.18 61 10 94 891 | 2 1 7.6 77 9 157 892 | 37 0 0.33 22 5 0 893 | 27 1 3.22 23 7 182 894 | 16 1 2.01 72 6 144 895 | 22 1 8.74 19 19 165 896 | 28 1 3.37 92 9 240 897 | 32 1 4.58 1 1 174 898 | 27 1 1.92 71 17 176 899 | 11 1 2.23 14 12 82 900 | 24 0 5.05 24 12 42 901 | 9 1 5.38 16 7 79 902 | 27 1 0.42 88 1 164 903 | 33 1 9.22 53 12 305 904 | 14 1 2.52 100 11 173 905 | 11 0 2.87 82 3 83 906 | 16 1 3.85 66 9 146 907 | 13 1 8.9 91 6 269 908 | 8 0 3.59 70 15 50 909 | 6 1 0.9 67 2 92 910 | 26 0 2.87 89 13 82 911 | 1 1 1.41 6 17 5 912 | 28 1 9.72 6 18 183 913 | 23 1 3.71 83 12 204 914 | 7 0 3.69 10 3 0 915 | 29 1 8.21 39 8 251 916 | 32 1 3.97 17 7 213 917 | 16 1 5.7 22 13 142 918 | 33 0 5.86 42 7 67 919 | 6 0 9.84 89 13 209 920 | 29 1 9.27 39 10 259 921 | 39 0 2.16 57 20 12 922 | 17 1 6.01 71 3 210 923 | 33 1 5.44 73 18 255 924 | 16 1 7.39 68 20 196 925 | 18 1 8.06 33 8 190 926 | 37 0 9.9 11 1 48 927 | 36 0 1.67 2 5 0 928 | 39 1 2.53 82 11 286 929 | 9 1 9.57 8 4 84 930 | 22 1 7.01 54 16 218 931 | 19 1 9.01 96 4 314 932 | 21 1 9.77 78 14 314 933 | 19 0 1.4 86 10 41 934 | 10 0 4.59 28 3 36 935 | 21 0 7.62 3 15 0 936 | 7 0 7.26 26 4 65 937 | 16 1 4.35 66 16 154 938 | 10 0 2.8 92 13 52 939 | 33 1 0.04 67 6 211 940 | 20 1 6.15 33 4 176 941 | 39 1 9.27 79 4 382 942 | 26 1 5.65 63 15 225 943 | 20 1 1.09 15 16 103 944 | 4 0 2.14 85 2 45 945 | 5 1 5.77 54 17 124 946 | 16 1 3.49 51 7 149 947 | 6 1 7.36 84 9 184 948 | 0 1 9.26 64 8 164 949 | 40 1 6.27 27 17 274 950 | 40 1 0.67 36 6 214 951 | 38 1 2.12 88 0 256 952 | 30 1 0.27 99 15 175 953 | 32 0 4.96 52 14 45 954 
| 40 0 0.93 26 13 11 955 | 26 1 6.38 66 9 246 956 | 36 1 6.96 40 7 270 957 | 10 1 5.01 46 10 130 958 | 0 1 9.51 50 11 144 959 | 39 0 7.98 32 19 66 960 | 35 1 2.15 24 2 207 961 | 9 1 9.91 45 1 164 962 | 14 0 4.1 48 16 57 963 | 24 1 8.78 20 9 194 964 | 15 1 0.32 53 13 78 965 | 23 1 8.5 45 16 233 966 | 38 1 4.37 68 4 305 967 | 7 1 9.79 38 7 165 968 | 16 1 4.64 98 1 234 969 | 25 1 0.36 6 14 139 970 | 33 1 1.78 36 0 192 971 | 33 1 1.26 58 11 210 972 | 3 1 8.12 76 9 182 973 | 6 1 3.33 40 11 67 974 | 8 1 1.37 33 9 88 975 | 15 0 1.5 14 14 0 976 | 25 1 6.08 91 2 303 977 | 16 1 8.4 92 11 289 978 | 36 1 1.14 0 0 187 979 | 3 0 0.78 30 1 31 980 | 5 0 1.97 53 17 10 981 | 34 1 0.99 92 3 246 982 | 30 1 7.52 57 16 248 983 | 11 0 0.13 8 16 0 984 | 40 1 8.95 95 5 439 985 | 38 1 2.97 38 4 235 986 | 36 1 9.47 55 20 332 987 | 19 0 2.43 46 8 43 988 | 17 1 1.35 20 7 127 989 | 2 1 4.8 91 5 160 990 | 27 1 9.9 92 7 380 991 | 33 1 4.14 23 10 224 992 | 16 0 3.19 15 14 12 993 | 22 0 4.65 38 2 59 994 | 24 0 7.81 43 3 100 995 | 14 1 7.26 66 1 212 996 | 19 1 3.84 39 17 148 997 | 8 1 4.1 76 17 142 998 | 25 0 1.66 13 5 24 999 | 20 1 8.95 52 11 230 1000 | 16 1 9.4 91 15 284 1001 | 39 1 2.58 47 0 241 1002 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Automatically generated by https://github.com/damnever/pigar. 
2 | 3 | catboost==1.1.1 4 | matplotlib==3.6.0 5 | numpy==1.23.3 6 | opencv-python==4.6.0.66 7 | pandas==1.5.0 8 | Pillow==9.2.0 9 | scikit-learn==1.2.0 10 | seaborn==0.12.2 11 | shap==0.41.0 12 | torch==1.13.1 13 | torchvision==0.14.1 14 | xgboost==1.7.3 15 | -------------------------------------------------------------------------------- /src/additional_resources/CatBoostClassifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "#imports\n", 10 | "import pandas as pd\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "\n", 13 | "from catboost import CatBoostClassifier\n", 14 | "import xgboost as xgb\n", 15 | "\n", 16 | "import shap\n", 17 | "\n", 18 | "from sklearn.metrics import accuracy_score,confusion_matrix\n", 19 | "\n", 20 | "#set figure background to white\n", 21 | "plt.rcParams.update({'figure.facecolor':'white'})" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Dataset" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "#load data \n", 38 | "data = pd.read_csv(\"../../data/mushrooms.csv\")\n", 39 | "\n", 40 | "#get features\n", 41 | "y = data['class']\n", 42 | "y = y.astype('category').cat.codes\n", 43 | "X = data.drop('class', axis=1)\n", 44 | "\n", 45 | "print(len(data))\n", 46 | "data.head()" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "# XGBoost" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# Create dummy variables for the categorical features\n", 63 | "X_dummy = pd.get_dummies(X)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | 
"outputs": [], 71 | "source": [ 72 | "# Fit model\n", 73 | "model = xgb.XGBClassifier()\n", 74 | "model.fit(X_dummy, y)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "#Get SHAP values\n", 84 | "explainer = shap.Explainer(model)\n", 85 | "shap_values = explainer(X_dummy)\n", 86 | "\n", 87 | "# Display SHAP values for the first observation\n", 88 | "shap.plots.waterfall(shap_values[0])" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "# CatBoost" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "model = CatBoostClassifier(iterations=20,\n", 105 | " learning_rate=0.01,\n", 106 | " depth=3)\n", 107 | "\n", 108 | "# train model\n", 109 | "cat_features = list(range(len(X.columns)))\n", 110 | "model.fit(X, y, cat_features)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "#Get SHAP values\n", 120 | "explainer = shap.Explainer(model)\n", 121 | "shap_values = explainer(X)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "# Display SHAP values for the first observation\n", 131 | "shap.plots.waterfall(shap_values[0])" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "# Mean SHAP \n", 141 | "shap.plots.bar(shap_values)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "# Beeswarm plot \n", 151 | "shap.plots.beeswarm(shap_values)" 152 | ] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "shap", 158 | "language": 
"python", 159 | "name": "shap" 160 | }, 161 | "language_info": { 162 | "codemirror_mode": { 163 | "name": "ipython", 164 | "version": 3 165 | }, 166 | "file_extension": ".py", 167 | "mimetype": "text/x-python", 168 | "name": "python", 169 | "nbconvert_exporter": "python", 170 | "pygments_lexer": "ipython3", 171 | "version": "3.10.4" 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 2 176 | } 177 | -------------------------------------------------------------------------------- /src/additional_resources/IsolationForest.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Imports \n", 10 | "import pandas as pd\n", 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import seaborn as sns\n", 14 | "\n", 15 | "import shap\n", 16 | "\n", 17 | "from sklearn.ensemble import IsolationForest\n", 18 | "from ucimlrepo import fetch_ucirepo\n", 19 | "\n", 20 | "# Set figure background to white\n", 21 | "plt.rcParams.update({'figure.facecolor':'white'})" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Data Cleaning and Feature Engineering" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# Fetch dataset from UCI repository\n", 38 | "power_consumption = fetch_ucirepo(id=235)\n", 39 | "\n", 40 | "print(power_consumption.variables) " 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# Get all features\n", 50 | "data = power_consumption.data.features\n", 51 | "data['Date'] = pd.to_datetime(data['Date'], format='%d/%m/%Y')\n", 52 | "\n", 53 | "# List of features to check\n", 54 | "feature_columns = ['Global_active_power', 'Global_reactive_power', 
'Voltage', \n", 55 | " 'Global_intensity', 'Sub_metering_1', 'Sub_metering_2', 'Sub_metering_3']\n", 56 | "\n", 57 | "# Convert feature columns to numeric and replace any errors with NaN\n", 58 | "data[feature_columns] = data[feature_columns].apply(pd.to_numeric, errors='coerce')\n", 59 | "\n", 60 | "# Drop rows where all feature columns are missing (NaN) \n", 61 | "data_cleaned = data.dropna(subset=feature_columns, how='all')\n", 62 | "\n", 63 | "# Drop rows where ALL feature columns are NaN\n", 64 | "data_cleaned.head()" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Group by 'Date' and calculate mean and standard deviation (ignore NaN values)\n", 74 | "data_aggregated = data_cleaned.groupby('Date')[feature_columns].agg(['mean', 'std'])\n", 75 | "\n", 76 | "# Rename columns to the desired format (MEAN_ColumnName, STD_ColumnName)\n", 77 | "data_aggregated.columns = [\n", 78 | " f'{agg_type.upper()}_{col}' for col, agg_type in data_aggregated.columns\n", 79 | "]\n", 80 | "\n", 81 | "# Reset the index\n", 82 | "data_aggregated.reset_index(inplace=True)\n", 83 | "\n", 84 | "# Display the result\n", 85 | "print(data_aggregated.shape)\n", 86 | "data_aggregated.head()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "# Train IsolationForest" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 10, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "# Parameters\n", 103 | "n_estimators = 100 # Number of trees\n", 104 | "sample_size = 256 # Number of samples used to train each tree\n", 105 | "contamination = 0.02 # Expected proportion of anomalies" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "# Select Features\n", 115 | "features = data_aggregated.drop('Date', axis=1)\n", 116 | "\n", 117 | "# 
Train Isolation Forest\n", 118 | "iso_forest = IsolationForest(n_estimators=n_estimators, \n", 119 | " contamination=contamination, \n", 120 | " max_samples=sample_size,\n", 121 | " random_state=42)\n", 122 | "\n", 123 | "iso_forest.fit(features)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "data_aggregated['anomaly_score'] = iso_forest.decision_function(features)\n", 133 | "data_aggregated['anomaly'] = iso_forest.predict(features)\n", 134 | "\n", 135 | "data_aggregated['anomaly'].value_counts()" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "# Visualization of the results\n", 145 | "plt.figure(figsize=(10, 5))\n", 146 | "\n", 147 | "# Plot normal instances\n", 148 | "normal = data_aggregated[data_aggregated['anomaly'] == 1]\n", 149 | "plt.scatter(normal['Date'], normal['anomaly_score'], label='Normal')\n", 150 | "\n", 151 | "# Plot anomalies\n", 152 | "anomalies = data_aggregated[data_aggregated['anomaly'] == -1]\n", 153 | "plt.scatter(anomalies['Date'], anomalies['anomaly_score'], label='Anomaly')\n", 154 | "\n", 155 | "plt.xlabel(\"Instance\")\n", 156 | "plt.ylabel(\"Anomaly Score\")\n", 157 | "plt.legend()" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "# KernelSHAP with Anomaly Score\n" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "# Using the anomaly score and TreeSHAP (this code won't work)\n", 174 | "explainer = shap.TreeExplainer(iso_forest.decision_function, features)\n", 175 | "shap_values = explainer(features)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "# Select all anomalies and 100 
random normal instances\n", 185 | "normal_sample = np.random.choice(normal.index,size=100,replace=False)\n", 186 | "sample = np.append(anomalies.index,normal_sample)\n", 187 | "\n", 188 | "len(sample) # 129" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "# Using the anomaly score and KernelSHAP\n", 198 | "explainer = shap.Explainer(iso_forest.decision_function, features)\n", 199 | "shap_values = explainer(features.iloc[sample])" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "# Plot waterfall plot of an anomaly\n", 209 | "shap.plots.waterfall(shap_values[0])" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "# Plot waterfall plot of a normal instance\n", 219 | "shap.plots.waterfall(shap_values[100])" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "# MeanSHAP Plot\n", 229 | "shap.plots.bar(shap_values)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "# Beeswarm plot\n", 239 | "shap.plots.beeswarm(shap_values)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "# TreeSHAP with Path Length" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 22, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "# Calculate SHAP values\n", 256 | "explainer = shap.TreeExplainer(iso_forest)\n", 257 | "shap_values = explainer(features)" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "# 
Waterfall plot for an anomaly\n", 267 | "shap.plots.waterfall(shap_values[0])" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "# Waterfall plot for a normal instance\n", 277 | "shap.plots.waterfall(shap_values[2])" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "# Calculate f(x)\n", 287 | "path_length = shap_values.base_values + shap_values.values.sum(axis=1)\n", 288 | "\n", 289 | "# Get f(x) for anomalies and normal instances\n", 290 | "anomalies = data_aggregated[data_aggregated['anomaly'] == -1]\n", 291 | "path_length_anomalies = path_length[anomalies.index]\n", 292 | "\n", 293 | "normal = data_aggregated[data_aggregated['anomaly'] == 1]\n", 294 | "path_length_normal = path_length[normal.index]\n", 295 | "\n", 296 | "# Plot boxplots for f(x)\n", 297 | "plt.figure(figsize=(10, 5))\n", 298 | "plt.boxplot([path_length_anomalies, path_length_normal], labels=['Anomaly','Normal'])\n", 299 | "plt.ylabel(\"Average Path Length f(x)\")" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "# MeanSHAP\n", 309 | "shap.plots.bar(shap_values)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "# MeanSHAP\n", 319 | "shap.plots.beeswarm(shap_values)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 26, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "# Interaction values\n", 329 | "shap_interaction_values = explainer.shap_interaction_values(features)" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "# Get absolute mean of matrices\n", 
339 | "mean_shap = np.abs(shap_interaction_values).mean(0)\n", 340 | "mean_shap = np.round(mean_shap, 1)\n", 341 | "\n", 342 | "df = pd.DataFrame(mean_shap, index=features.columns, columns=features.columns)\n", 343 | "\n", 344 | "# Times off diagonal by 2\n", 345 | "df.where(df.values == np.diagonal(df), df.values * 2, inplace=True)\n", 346 | "\n", 347 | "# Display\n", 348 | "sns.set(font_scale=1)\n", 349 | "sns.heatmap(df, cmap=\"coolwarm\", annot=True)\n", 350 | "plt.yticks(rotation=0)" 351 | ] 352 | } 353 | ], 354 | "metadata": { 355 | "kernelspec": { 356 | "display_name": "shap", 357 | "language": "python", 358 | "name": "shap" 359 | }, 360 | "language_info": { 361 | "codemirror_mode": { 362 | "name": "ipython", 363 | "version": 3 364 | }, 365 | "file_extension": ".py", 366 | "mimetype": "text/x-python", 367 | "name": "python", 368 | "nbconvert_exporter": "python", 369 | "pygments_lexer": "ipython3", 370 | "version": "3.10.4" 371 | } 372 | }, 373 | "nbformat": 4, 374 | "nbformat_minor": 2 375 | } 376 | -------------------------------------------------------------------------------- /src/additional_resources/RandomForestRegressor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SHAP for RandomForestRegressor\n", 8 | "
\n", 9 | "Dataset: https://www.kaggle.com/datasets/conorsully1/interaction-dataset" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# Imports\n", 19 | "import pandas as pd\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "\n", 22 | "from sklearn.ensemble import RandomForestRegressor\n", 23 | "\n", 24 | "import shap\n", 25 | "shap.initjs() \n", 26 | "\n", 27 | "# Set figure background to white\n", 28 | "plt.rcParams.update({'figure.facecolor':'white'})" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Dataset" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "#import dataset\n", 45 | "data = pd.read_csv(\"../../data/interaction_dataset.csv\",sep='\\t')\n", 46 | "\n", 47 | "y = data['bonus']\n", 48 | "X = data.drop('bonus', axis=1)\n", 49 | "\n", 50 | "print(len(data))\n", 51 | "data.head()" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "# Modelling" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "#Train model\n", 68 | "model = RandomForestRegressor(n_estimators=100) \n", 69 | "model.fit(X, y)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "# SHAP Values" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "#Get SHAP values\n", 86 | "explainer = shap.Explainer(model)\n", 87 | "shap_values = explainer(X)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "# Plot waterfall\n", 97 | "shap.plots.waterfall(shap_values[0])" 98 | ] 99 | } 100 | ], 101 | "metadata": { 102 | 
"kernelspec": { 103 | "display_name": "shap", 104 | "language": "python", 105 | "name": "shap" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.10.4" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 2 122 | } 123 | -------------------------------------------------------------------------------- /src/archive/image_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Using SHAP to debug a PyTorch Image Regression Model" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 2, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "# Imports\n", 18 | "import numpy as np\n", 19 | "import pandas as pd\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "\n", 22 | "import glob \n", 23 | "import random \n", 24 | "\n", 25 | "from PIL import Image\n", 26 | "import cv2\n", 27 | "\n", 28 | "import torch\n", 29 | "import torchvision\n", 30 | "from torchvision import transforms\n", 31 | "from torch.utils.data import DataLoader\n", 32 | "\n", 33 | "import shap\n", 34 | "from sklearn.metrics import mean_squared_error" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#Load example image\n", 44 | "name = \"32_50_c78164b4-40d2-11ed-a47b-a46bb6070c92.jpg\"\n", 45 | "x = int(name.split(\"_\")[0])\n", 46 | "y = int(name.split(\"_\")[1])\n", 47 | "\n", 48 | "img = Image.open(\"../data/room_1/\" + name)\n", 49 | "img = np.array(img)\n", 50 | "cv2.circle(img, (x, y), 8, (0, 255, 0), 3)\n", 51 | "\n", 52 | "plt.imshow(img)\n", 53 | "\n", 54 | "path = 
\"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/example.png\"\n", 55 | "plt.savefig(path, bbox_inches='tight',facecolor='w', edgecolor='w', transparent=False,dpi=200)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# Model Training" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 5, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "class ImageDataset(torch.utils.data.Dataset):\n", 72 | " def __init__(self, paths, transform):\n", 73 | "\n", 74 | " self.transform = transform\n", 75 | " self.paths = paths\n", 76 | "\n", 77 | " def __getitem__(self, idx):\n", 78 | " \"\"\"Get image and target (x, y) coordinates\"\"\"\n", 79 | "\n", 80 | " # Read image\n", 81 | " path = self.paths[idx]\n", 82 | " image = cv2.imread(path, cv2.IMREAD_COLOR)\n", 83 | " image = Image.fromarray(image)\n", 84 | "\n", 85 | " # Transform image\n", 86 | " image = self.transform(image)\n", 87 | " \n", 88 | " # Get target\n", 89 | " target = self.get_target(path)\n", 90 | " target = torch.Tensor(target)\n", 91 | "\n", 92 | " return image, target\n", 93 | " \n", 94 | " def get_target(self,path):\n", 95 | " \"\"\"Get the target (x, y) coordinates from path\"\"\"\n", 96 | "\n", 97 | " name = os.path.basename(path)\n", 98 | " items = name.split('_')\n", 99 | " x = items[0]\n", 100 | " y = items[1]\n", 101 | "\n", 102 | " # Scale between -1 and 1\n", 103 | " x = 2.0 * (int(x)/ 224 - 0.5) # -1 left, +1 right\n", 104 | " y = 2.0 * (int(y) / 244 -0.5)# -1 top, +1 bottom\n", 105 | "\n", 106 | " return [x, y]\n", 107 | "\n", 108 | " def __len__(self):\n", 109 | " return len(self.paths)\n" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "TRANSFORMS = transforms.Compose([\n", 119 | " transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),\n", 120 | " transforms.Resize((224, 224)),\n", 121 | " 
transforms.ToTensor(),\n", 122 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 123 | "])\n", 124 | "\n", 125 | "all_rooms = False # Change if you want to use all the data\n", 126 | "\n", 127 | "paths = glob.glob('../data/room_1/*')\n", 128 | "if all_rooms:\n", 129 | " paths = paths + glob.glob('../data/room_2/*') + glob.glob('../data/room_3/*')\n", 130 | "\n", 131 | "# Shuffle the paths\n", 132 | "random.shuffle(paths)\n", 133 | "\n", 134 | "# Create a datasets for training and validation\n", 135 | "split = int(0.8 * len(paths))\n", 136 | "train_data = ImageDataset(paths[:split], TRANSFORMS)\n", 137 | "valid_data = ImageDataset(paths[split:], TRANSFORMS)\n", 138 | "\n", 139 | "# Prepare data for Pytorch model\n", 140 | "train_loader = DataLoader(train_data, batch_size=32, shuffle=True)\n", 141 | "valid_loader = DataLoader(valid_data, batch_size=valid_data.__len__())\n", 142 | "\n", 143 | "print(train_data.__len__())\n", 144 | "print(valid_data.__len__())" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 63, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "output_dim = 2 # x, y\n", 154 | "device = torch.device('cpu') # or 'cuda' if you have a GPU\n", 155 | "\n", 156 | "# RESNET 18\n", 157 | "model = torchvision.models.resnet18(pretrained=True)\n", 158 | "model.fc = torch.nn.Linear(512, output_dim)\n", 159 | "model = model.to(device)\n", 160 | "\n", 161 | "optimizer = torch.optim.Adam(model.parameters())" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "name = \"direction_model_1\" # Change this to save a new model\n", 171 | "\n", 172 | "# Train the model\n", 173 | "min_loss = np.inf\n", 174 | "for epoch in range(10):\n", 175 | "\n", 176 | " model = model.train()\n", 177 | " for images, target in iter(train_loader):\n", 178 | "\n", 179 | " images = images.to(device)\n", 180 | " target = 
target.to(device)\n", 181 | " \n", 182 | " # Zero gradients of parameters\n", 183 | " optimizer.zero_grad() \n", 184 | "\n", 185 | " # Execute model to get outputs\n", 186 | " output = model(images)\n", 187 | "\n", 188 | " # Calculate loss\n", 189 | " loss = torch.nn.functional.mse_loss(output, target)\n", 190 | "\n", 191 | " # Run backpropogation to accumulate gradients\n", 192 | " loss.backward()\n", 193 | "\n", 194 | " # Update model parameters\n", 195 | " optimizer.step()\n", 196 | "\n", 197 | " # Calculate validation loss\n", 198 | " model = model.eval()\n", 199 | "\n", 200 | " images, target = next(iter(valid_loader))\n", 201 | " images = images.to(device)\n", 202 | " target = target.to(device)\n", 203 | "\n", 204 | " output = model(images)\n", 205 | " valid_loss = torch.nn.functional.mse_loss(output, target)\n", 206 | "\n", 207 | " print(\"Epoch: {}, Validation Loss: {}\".format(epoch, valid_loss.item()))\n", 208 | " \n", 209 | " if valid_loss < min_loss:\n", 210 | " print(\"Saving model\")\n", 211 | " torch.save(model, '../models/{}.pth'.format(name))\n", 212 | "\n", 213 | " min_loss = valid_loss" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "# Model Evaluation" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 9, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "def model_evaluation(loaders,labels,save_path = None):\n", 230 | "\n", 231 | " \"\"\"Evaluate direction models with mse and scatter plots\n", 232 | " loaders: list of data loaders\n", 233 | " labels: list of labels for plot title\"\"\"\n", 234 | "\n", 235 | " n = len(loaders)\n", 236 | " fig, axs = plt.subplots(1, n, figsize=(7*n, 6))\n", 237 | " fig.patch.set_facecolor('xkcd:white')\n", 238 | "\n", 239 | " # Evalution metrics\n", 240 | " for i, loader in enumerate(loaders):\n", 241 | "\n", 242 | " # Load all data\n", 243 | " images, target = next(iter(loader))\n", 244 | " images = 
images.to(device)\n", 245 | " target = target.to(device)\n", 246 | "\n", 247 | " output=model(images)\n", 248 | "\n", 249 | " # Get x predictions\n", 250 | " x_pred=output.detach().cpu().numpy()[:,0]\n", 251 | " x_target=target.cpu().numpy()[:,0]\n", 252 | "\n", 253 | " # Calculate MSE\n", 254 | " mse = mean_squared_error(x_target, x_pred)\n", 255 | "\n", 256 | " # Plot predcitons\n", 257 | " axs[i].scatter(x_target,x_pred)\n", 258 | " axs[i].plot([-1, 1], \n", 259 | " [-1, 1], \n", 260 | " color='r', \n", 261 | " linestyle='-', \n", 262 | " linewidth=2)\n", 263 | "\n", 264 | " axs[i].set_ylabel('Predicted x', size =15)\n", 265 | " axs[i].set_xlabel('Actual x', size =15)\n", 266 | " axs[i].set_title(\"{0} MSE: {1:.4f}\".format(labels[i], mse),size = 18)\n", 267 | "\n", 268 | " if save_path != None:\n", 269 | " fig.savefig(save_path)\n" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "# Load saved model \n", 279 | "model = torch.load('../models/direction_model_1.pth')\n", 280 | "model.eval()\n", 281 | "model.to(device)\n", 282 | "\n", 283 | "# Create new loader for all data\n", 284 | "train_loader = DataLoader(train_data, batch_size=train_data.__len__())\n", 285 | "\n", 286 | "# Evaluate model on training and validation set\n", 287 | "loaders = [train_loader,valid_loader]\n", 288 | "labels = [\"Train\",\"Validation\"]\n", 289 | "\n", 290 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/evaluation_1.png\"\n", 291 | "model_evaluation(loaders,labels,save_path=path)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "# Evaluate on data for additonal rooms\n", 301 | "room_2 = glob.glob('../data/room_2/*')\n", 302 | "room_3 = glob.glob('../data/room_3/*')\n", 303 | "\n", 304 | "room_2_data = ImageDataset(room_2, TRANSFORMS)\n", 305 | 
"room_3_data = ImageDataset(room_3, TRANSFORMS)\n", 306 | "\n", 307 | "room_2_loader = DataLoader(room_2_data, batch_size=room_2_data.__len__())\n", 308 | "room_3_loader = DataLoader(room_3_data, batch_size=room_3_data.__len__())\n", 309 | "\n", 310 | "# Evaluate model on training and validation set\n", 311 | "loaders = [room_2_loader ,room_3_loader]\n", 312 | "labels = [\"Room 2\",\"Room 3\"]\n", 313 | "\n", 314 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/evaluation_2.png\"\n", 315 | "model_evaluation(loaders,labels, save_path=path)" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "# Load saved model \n", 325 | "model = torch.load('../models/direction_model_2.pth')\n", 326 | "\n", 327 | "model.eval()\n", 328 | "model.to(device)\n", 329 | "\n", 330 | "# Evaluate model on training and validation set\n", 331 | "loaders = [room_2_loader ,room_3_loader]\n", 332 | "labels = [\"Room 2\",\"Room 3\"]\n", 333 | "\n", 334 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/evaluation_3.png\"\n", 335 | "model_evaluation(loaders,labels,save_path=path)" 336 | ] 337 | }, 338 | { 339 | "attachments": {}, 340 | "cell_type": "markdown", 341 | "metadata": {}, 342 | "source": [ 343 | "# SHAP Explainer " 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 7, 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "# Load saved model \n", 353 | "model = torch.load('../models/direction_model_1.pth') #change for different model\n", 354 | "model.eval()\n", 355 | "\n", 356 | "# Use CPU\n", 357 | "device = torch.device('cpu')\n", 358 | "model = model.to(device)" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 8, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "#Load 100 images for background\n", 368 | "shap_loader = DataLoader(train_data, 
batch_size=100, shuffle=True)\n", 369 | "background, _ = next(iter(shap_loader))\n", 370 | "background = background.to(device)\n", 371 | "\n", 372 | "#Create SHAP explainer \n", 373 | "explainer = shap.DeepExplainer(model, background)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "# Load test images of right and left turn\n", 383 | "paths = glob.glob('../data/room_1/*')\n", 384 | "test_images = [Image.open(paths[0]), Image.open(paths[3])]\n", 385 | "test_images = np.array(test_images)\n", 386 | "\n", 387 | "test_input = [TRANSFORMS(img) for img in test_images]\n", 388 | "test_input = torch.stack(test_input).to(device)\n", 389 | "\n", 390 | "# Get SHAP values\n", 391 | "shap_values = explainer.shap_values(test_input)\n", 392 | "\n", 393 | "# Reshape shap values and images for plotting\n", 394 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n", 395 | "test_numpy = np.array([np.array(img) for img in test_images])\n", 396 | "\n", 397 | "shap.image_plot(shap_numpy, test_numpy,show=False)" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "# Using gradient explainer\n", 407 | "explainer = shap.GradientExplainer(model, background)\n", 408 | "shap_values = explainer.shap_values(test_input)\n", 409 | "\n", 410 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n", 411 | "\n", 412 | "shap.image_plot(shap_numpy, test_numpy)" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "metadata": {}, 419 | "outputs": [], 420 | "source": [ 421 | "# Load model trained on room 1, 2 and 3\n", 422 | "model = torch.load('../models/direction_model_2.pth') #change for different model\n", 423 | "\n", 424 | "# Use CPU\n", 425 | "device = torch.device('cpu')\n", 426 | "model = model.to(device)\n", 427 | "\n", 428 | "#Load 100 
images for background\n", 429 | "shap_loader = DataLoader(train_data, batch_size=100, shuffle=True)\n", 430 | "background, _ = next(iter(shap_loader))\n", 431 | "background = background.to(device)\n", 432 | "\n", 433 | "#Create SHAP explainer \n", 434 | "explainer = shap.DeepExplainer(model, background)\n", 435 | "\n", 436 | "# Load test images of right and left turn\n", 437 | "paths = glob.glob('../data/room_1/*')\n", 438 | "test_images = [Image.open(paths[0]), Image.open(paths[3])]\n", 439 | "test_images = np.array(test_images)\n", 440 | "\n", 441 | "# Transform images\n", 442 | "test_input = [TRANSFORMS(img) for img in test_images]\n", 443 | "test_input = torch.stack(test_input).to(device)\n", 444 | "\n", 445 | "# Get SHAP values\n", 446 | "shap_values = explainer.shap_values(test_input)\n", 447 | "\n", 448 | "# Reshape shap values and images for plotting\n", 449 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n", 450 | "test_numpy = np.array([np.array(img) for img in test_images])\n", 451 | "\n", 452 | "shap.image_plot(shap_numpy, test_numpy,show=False)\n", 453 | "plt.savefig(\"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/shap_plot_2.png\",facecolor='white',dpi=300,bbox_inches='tight')" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "metadata": {}, 460 | "outputs": [], 461 | "source": [ 462 | "# Using Gradient Explainer\n", 463 | "e = shap.GradientExplainer(model, background)\n", 464 | "shap_values = e.shap_values(test_input)\n", 465 | "\n", 466 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n", 467 | "test_numpy = np.array([np.array(img) for img in test_images])\n", 468 | "\n", 469 | "shap.image_plot(shap_numpy, test_numpy)" 470 | ] 471 | } 472 | ], 473 | "metadata": { 474 | "kernelspec": { 475 | "display_name": "pytorch", 476 | "language": "python", 477 | "name": "pytorch" 478 | }, 479 | "language_info": { 480 | "codemirror_mode": { 481 | "name": "ipython", 482 | 
"version": 3 483 | }, 484 | "file_extension": ".py", 485 | "mimetype": "text/x-python", 486 | "name": "python", 487 | "nbconvert_exporter": "python", 488 | "pygments_lexer": "ipython3", 489 | "version": "3.10.4 (main, Mar 31 2022, 03:37:37) [Clang 12.0.0 ]" 490 | }, 491 | "orig_nbformat": 4, 492 | "vscode": { 493 | "interpreter": { 494 | "hash": "3c0d4fcf1a0a408688084e944cab5ef64e86c1ae9800e884f9b7a2ac0ee51db6" 495 | } 496 | } 497 | }, 498 | "nbformat": 4, 499 | "nbformat_minor": 2 500 | } 501 | -------------------------------------------------------------------------------- /src/archive/project_1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Project 1: salary bonus\n", 8 | "
\n", 9 | "Use SHAP to answer the following questions:\n", 10 | "
    \n", 11 | "
  1. Which features do NOT have a significant relationship with bonus?\n", 12 | "
  2. What tends to happen to an employee's bonus as they gain more experience? \n", 13 | "
  3. Are there any potential interactions in the dataset? \n", 14 | "
\n", 15 | "
\n", 16 | "Dataset: https://www.kaggle.com/conorsully1/interaction-dataset" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "#imports\n", 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import seaborn as sns\n", 30 | "\n", 31 | "from sklearn.ensemble import RandomForestRegressor\n", 32 | "\n", 33 | "import shap\n", 34 | "shap.initjs()\n", 35 | "\n", 36 | "path = \"/Users/conorosully/Google Drive/My Drive/Medium/SHAP Interactions/Figures/{}\"" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "## Dataset" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "#import dataset\n", 53 | "data = pd.read_csv(\"../data/interaction_dataset.csv\",sep='\\t')\n", 54 | "\n", 55 | "y = data['bonus']\n", 56 | "X = data.drop('bonus', axis=1)\n", 57 | "\n", 58 | "print(len(data))\n", 59 | "data.head()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Modelling" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "#Train model\n", 76 | "model = RandomForestRegressor(n_estimators=100) \n", 77 | "model.fit(X, y)\n", 78 | "\n", 79 | "#Get predictions\n", 80 | "y_pred = model.predict(X)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "#Model evaluation\n", 90 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,8))\n", 91 | "\n", 92 | "plt.scatter(y,y_pred)\n", 93 | "plt.plot([0, 400], [0, 400], color='r', linestyle='-', linewidth=2)\n", 94 | "\n", 95 | "plt.ylabel('Predicted',size=20)\n", 96 | "plt.xlabel('Actual',size=20)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 
| "metadata": {}, 102 | "source": [ 103 | "## Standard SHAP values" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 5, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "#Get SHAP values\n", 113 | "explainer = shap.Explainer(model,X[0:10])\n", 114 | "shap_values = explainer(X)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "# Which features do NOT have a significant relationship with bonus?" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "# What tends to happens to an employee's bonus as they gain more experience? " 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## SHAP interaction values" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "#Get SHAP interaction values\n", 149 | "explainer = shap.Explainer(model)\n", 150 | "shap_interaction = explainer.shap_interaction_values(X)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "# Are there any potential interactions in the dataset? 
" 160 | ] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "SHAP", 166 | "language": "python", 167 | "name": "shap" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.10.6" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 2 184 | } 185 | -------------------------------------------------------------------------------- /src/archive/project_1_solution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Project 1: salary bonus\n", 8 | "
\n", 9 | "Use the SHAP analysis to answer the following questions:\n", 10 | "
    \n", 11 | "
  1. Which features do NOT have a significant relationship with bonus?\n", 12 | "
  2. What tends to happen to an employee's bonus as they gain more experience? \n", 13 | "
  3. Are there any potential interactions in the dataset? \n", 14 | "
\n", 15 | "
\n", 16 | "Dataset: https://www.kaggle.com/conorsully1/interaction-dataset" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "#imports\n", 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import seaborn as sns\n", 30 | "\n", 31 | "from sklearn.ensemble import RandomForestRegressor\n", 32 | "\n", 33 | "import shap\n", 34 | "shap.initjs()\n", 35 | "\n", 36 | "path = \"/Users/conorosully/Google Drive/My Drive/Medium/SHAP Interactions/Figures/{}\"" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "## Dataset" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "#import dataset\n", 53 | "data = pd.read_csv(\"../data/interaction_dataset.csv\",sep='\\t')\n", 54 | "\n", 55 | "y = data['bonus']\n", 56 | "X = data.drop('bonus', axis=1)\n", 57 | "\n", 58 | "print(len(data))\n", 59 | "data.head()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Modelling" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "#Train model\n", 76 | "model = RandomForestRegressor(n_estimators=100) \n", 77 | "model.fit(X, y)\n", 78 | "\n", 79 | "#Get predictions\n", 80 | "y_pred = model.predict(X)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "#Model evaluation\n", 90 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,8))\n", 91 | "\n", 92 | "plt.scatter(y,y_pred)\n", 93 | "plt.plot([0, 400], [0, 400], color='r', linestyle='-', linewidth=2)\n", 94 | "\n", 95 | "plt.ylabel('Predicted',size=20)\n", 96 | "plt.xlabel('Actual',size=20)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 
| "metadata": {}, 102 | "source": [ 103 | "## Standard SHAP values" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 7, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "#Get SHAP values\n", 113 | "explainer = shap.Explainer(model,X[0:10])\n", 114 | "shap_values = explainer(X)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "# waterfall plot for first observation\n", 124 | "shap.plots.waterfall(shap_values[0])" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "# Which features do NOT have a significant relationship with bonus?\n", 134 | "# Answer: days_late\n", 135 | "shap.plots.bar(shap_values)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "#What tends to happens to an employee's bonus as they gain more experience? \n", 145 | "# Answer: their bonus increases\n", 146 | "# You could have also used a dependency plot\n", 147 | "shap.plots.beeswarm(shap_values)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "## SHAP interaction values" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 9, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "#Get SHAP interaction values\n", 164 | "explainer = shap.Explainer(model)\n", 165 | "shap_interaction = explainer.shap_interaction_values(X)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "# Are there any potential interactions in the dataset? 
\n", 175 | "# Answer: yes - experience.degree & performance.sales" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "# Get absolute mean of matrices\n", 185 | "mean_shap = np.abs(shap_interaction).mean(0)\n", 186 | "df = pd.DataFrame(mean_shap,index=X.columns,columns=X.columns)\n", 187 | "\n", 188 | "# times off diagonal by 2\n", 189 | "df.where(df.values == np.diagonal(df),df.values*2,inplace=True)\n", 190 | "\n", 191 | "# display \n", 192 | "plt.figure(figsize=(10, 10), facecolor='w', edgecolor='k')\n", 193 | "sns.set(font_scale=1.5)\n", 194 | "sns.heatmap(df,cmap='coolwarm',annot=True,fmt='.3g',cbar=False)\n", 195 | "plt.yticks(rotation=0) " 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "# Experience-degree depenence plot\n", 205 | "shap.dependence_plot(\n", 206 | " (\"experience\", \"degree\"),\n", 207 | " shap_interaction, X,\n", 208 | " display_features=X)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "# Performance-sales depenence plot\n", 218 | "shap.dependence_plot(\n", 219 | " (\"performance\", \"sales\"),\n", 220 | " shap_interaction, X,\n", 221 | " display_features=X)" 222 | ] 223 | } 224 | ], 225 | "metadata": { 226 | "kernelspec": { 227 | "display_name": "shap", 228 | "language": "python", 229 | "name": "shap" 230 | }, 231 | "language_info": { 232 | "codemirror_mode": { 233 | "name": "ipython", 234 | "version": 3 235 | }, 236 | "file_extension": ".py", 237 | "mimetype": "text/x-python", 238 | "name": "python", 239 | "nbconvert_exporter": "python", 240 | "pygments_lexer": "ipython3", 241 | "version": "3.10.4" 242 | } 243 | }, 244 | "nbformat": 4, 245 | "nbformat_minor": 2 246 | } 247 | 
-------------------------------------------------------------------------------- /src/archive/project_2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Project 2: mushroom classification\n", 8 | "
\n", 9 | "Use the SHAP analysis to answer the following questions:\n", 10 | "
    \n", 11 | "
  1. For the first prediction, which feature has the most significant contribution?\n", 12 | "
  2. Overall, which feature has the most significant contributions? \n", 13 | "
  3. Which odors are associated with poisonous mushrooms? \n", 14 | "
\n", 15 | "\n", 16 | "Dataset: https://www.kaggle.com/datasets/uciml/mushroom-classification" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "#imports\n", 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "\n", 30 | "from catboost import CatBoostClassifier\n", 31 | "\n", 32 | "import shap\n", 33 | "\n", 34 | "from sklearn.metrics import accuracy_score,confusion_matrix" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#load data \n", 44 | "data = pd.read_csv(\"../data/mushrooms.csv\")\n", 45 | "\n", 46 | "#get features\n", 47 | "y = data['class']\n", 48 | "y = y.astype('category').cat.codes\n", 49 | "X = data.drop('class', axis=1)\n", 50 | "\n", 51 | "\n", 52 | "print(len(data))\n", 53 | "data.head()" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "scrolled": true 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "model = CatBoostClassifier(iterations=20,\n", 65 | " learning_rate=0.01,\n", 66 | " depth=3)\n", 67 | "\n", 68 | "# train model\n", 69 | "cat_features = list(range(len(X.columns)))\n", 70 | "model.fit(X, y, cat_features)\n", 71 | "\n", 72 | "#Get predictions\n", 73 | "y_pred = model.predict(X)\n", 74 | "\n", 75 | "print(confusion_matrix(y, y_pred))\n", 76 | "accuracy_score(y, y_pred)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "# Standard SHAP values" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 4, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "# get shap values\n", 93 | "explainer = shap.Explainer(model)\n", 94 | "shap_values = explainer(X)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | 
"source": [ 103 | "#For the first prediction, which feature has the most significant contribution?" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "#Overall, which feature has the most significant contributions?" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "#Which odors are associated with poisonous mushrooms?" 122 | ] 123 | } 124 | ], 125 | "metadata": { 126 | "kernelspec": { 127 | "display_name": "SHAP", 128 | "language": "python", 129 | "name": "shap" 130 | }, 131 | "language_info": { 132 | "codemirror_mode": { 133 | "name": "ipython", 134 | "version": 3 135 | }, 136 | "file_extension": ".py", 137 | "mimetype": "text/x-python", 138 | "name": "python", 139 | "nbconvert_exporter": "python", 140 | "pygments_lexer": "ipython3", 141 | "version": "3.10.6" 142 | } 143 | }, 144 | "nbformat": 4, 145 | "nbformat_minor": 2 146 | } 147 | -------------------------------------------------------------------------------- /src/archive/project_2_solution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Project 2: mushroom classification\n", 8 | "
\n", 9 | "Use the SHAP analysis to answer the following questions:\n", 10 | "
    \n", 11 | "
  1. For the first prediction, which feature has the most significant contribution?\n", 12 | "
  2. Overall, which feature has the most significant contributions? \n", 13 | "
  3. Which odors are associated with poisonous mushrooms? \n", 14 | "
\n", 15 | "\n", 16 | "Dataset: https://www.kaggle.com/datasets/uciml/mushroom-classification" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "#imports\n", 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "\n", 30 | "from catboost import CatBoostClassifier\n", 31 | "\n", 32 | "import shap\n", 33 | "\n", 34 | "from sklearn.metrics import accuracy_score,confusion_matrix" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#load data \n", 44 | "data = pd.read_csv(\"../data/mushrooms.csv\")\n", 45 | "\n", 46 | "#get features\n", 47 | "y = data['class']\n", 48 | "y = y.astype('category').cat.codes\n", 49 | "X = data.drop('class', axis=1)\n", 50 | "\n", 51 | "# replace all categorical features with integer values\n", 52 | "for col in X.columns:\n", 53 | " X[col] = X[col].astype('category').cat.codes\n", 54 | "\n", 55 | "\n", 56 | "print(len(data))\n", 57 | "X.head()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "# Standard SHAP values" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 4, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# get shap values\n", 74 | "explainer = shap.Explainer(model)\n", 75 | "shap_values = explainer(X)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "#For the first prediction, which feature has the most significant contribution?\n", 85 | "#Answer: odor\n", 86 | "shap.plots.waterfall(shap_values[0],max_display=5)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "#Overall, which feature has the most significant contributions?\n", 96 | 
"#Answer: odor\n", 97 | "shap.plots.bar(shap_values,show=False)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "shap.plots.beeswarm(shap_values)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "#Which odors are associated with poisonous mushrooms?\n", 116 | "#All the odors with SHAP values > 0 " 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "#get shaply values and data\n", 126 | "odor_values = shap_values[:,4].values\n", 127 | "odor_data = X['odor']\n", 128 | "unique_odor = set(X['odor'])\n", 129 | "\n", 130 | "#split odor shap values based on odor category\n", 131 | "odor_categories = list(set(odor_data))\n", 132 | "\n", 133 | "odor_groups = []\n", 134 | "for o in odor_categories:\n", 135 | " relevant_values = odor_values[odor_data == o]\n", 136 | " odor_groups.append(relevant_values)\n", 137 | " \n", 138 | "#replace categories with labels\n", 139 | "odor_labels = {'a':'almond',\n", 140 | " 'l':'anise', \n", 141 | " 'c':'creosote', \n", 142 | " 'y':'fishy', \n", 143 | " 'f':'foul', \n", 144 | " 'm':'musty', \n", 145 | " 'n':'none', \n", 146 | " 'p':'pungent', \n", 147 | " 's':'spicy'}\n", 148 | "\n", 149 | "labels = [odor_labels[u] for u in unique_odor]\n", 150 | "\n", 151 | "#plot boxplot\n", 152 | "plt.figure(figsize=(8, 5))\n", 153 | "\n", 154 | "plt.boxplot(odor_groups,labels=labels)\n", 155 | "\n", 156 | "plt.ylabel('SHAP values',size=15)\n", 157 | "plt.xlabel('Odor',size=15)" 158 | ] 159 | } 160 | ], 161 | "metadata": { 162 | "kernelspec": { 163 | "display_name": "xai", 164 | "language": "python", 165 | "name": "xai" 166 | }, 167 | "language_info": { 168 | "codemirror_mode": { 169 | "name": "ipython", 170 | "version": 3 171 | }, 172 | "file_extension": 
".py", 173 | "mimetype": "text/x-python", 174 | "name": "python", 175 | "nbconvert_exporter": "python", 176 | "pygments_lexer": "ipython3", 177 | "version": "3.9.12" 178 | } 179 | }, 180 | "nbformat": 4, 181 | "nbformat_minor": 2 182 | } 183 | -------------------------------------------------------------------------------- /src/kernel_vs_tree.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Kernel SHAP vs Tree SHAP\n", 8 | "Experiments to understand the time complexity of SHAP approximations" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "#imports\n", 18 | "import pandas as pd\n", 19 | "import numpy as np\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "\n", 22 | "#import xgboost as xgb\n", 23 | "from sklearn.ensemble import RandomForestRegressor\n", 24 | "import sklearn.datasets as ds\n", 25 | "\n", 26 | "import datetime\n", 27 | "\n", 28 | "import shap\n", 29 | "shap.initjs()" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# Functions\n", 39 | "def runSHAP(n,kernel=True): \n", 40 | " \"\"\"\n", 41 | " Calculate shap values and return time taken\n", 42 | " n: number of SHAP values to calculate\n", 43 | " kernel: set False if using TreeSHAP \n", 44 | " \"\"\"\n", 45 | " \n", 46 | " x_sample = X[np.random.choice(X.shape[0], n, replace=True)]\n", 47 | " \n", 48 | " begin = datetime.datetime.now()\n", 49 | " if kernel:\n", 50 | " #Caculate SHAP values using KernelSHAP\n", 51 | " shap_values = kernelSHAP.shap_values(x_sample,l1_reg=False)\n", 52 | " time = datetime.datetime.now() - begin\n", 53 | " print(\"Kernel {}: \".format(n), time)\n", 54 | " else:\n", 55 | " #Caculate SHAP values using TreeSHAP\n", 56 | " shap_values = treeSHAP(x_sample)\n", 
57 | " time = datetime.datetime.now() - begin\n", 58 | " print(\"Tree {}: \".format(n), time)\n", 59 | " \n", 60 | " return time\n", 61 | "\n", 62 | "def model_properties(model):\n", 63 | " \"\"\"Returns average depth and number of features and leaves of a random forest\"\"\"\n", 64 | " \n", 65 | " depths = []\n", 66 | " features = []\n", 67 | " leaves = []\n", 68 | " \n", 69 | " for tree in model.estimators_:\n", 70 | " depths.append(tree.get_depth())\n", 71 | " leaves.append(tree.get_n_leaves())\n", 72 | " n_feat = len(set(tree.tree_.feature)) -1 \n", 73 | " features.append(n_feat)\n", 74 | " \n", 75 | " return np.mean(depths), np.mean(features), np.mean(leaves)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "## Experiment 1: Number of samples" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "#Simulate regression data\n", 92 | "data = ds.make_regression(n_samples=10000, n_features=10, n_informative=8, n_targets=1)\n", 93 | "\n", 94 | "y= data[1]\n", 95 | "X = data[0]\n", 96 | "\n", 97 | "feature_names = range(len(X))" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "#Train model\n", 107 | "model = RandomForestRegressor(n_estimators=100,max_depth=4,random_state=0)\n", 108 | "model.fit(X, y)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 5, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "#Get shap estimators\n", 118 | "kernelSHAP = shap.KernelExplainer(model.predict,shap.sample(X, 10))\n", 119 | "treeSHAP = shap.TreeExplainer(model)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "results = []\n", 129 | "for n in [10,100,1000,2000,5000,10000]*3:\n", 130 | " #Calculate SHAP 
Values\n", 131 | " kernel_time = runSHAP(n=n)\n", 132 | " tree_time = runSHAP(n=n,kernel=False)\n", 133 | " \n", 134 | " result = [n,kernel_time,tree_time]\n", 135 | " results.append(result)\n", 136 | " \n", 137 | "results_1 = pd.DataFrame(results,columns = ['n','kernelSHAP','treeSHAP'])" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "avg_1 = results_1.groupby(by='n',as_index=False).mean()\n", 147 | "avg_1" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "k_sec" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "#Find average run time\n", 166 | "avg_1 = results_1.groupby(by='n',as_index=False).mean()\n", 167 | "\n", 168 | "k_sec = [t.total_seconds() for t in avg_1['kernelSHAP']]\n", 169 | "t_sec = [t.total_seconds() for t in avg_1['treeSHAP']]\n", 170 | "n = avg_1['n']\n", 171 | "\n", 172 | "#Proportional run time\n", 173 | "print((k_sec/n)/(t_sec/n))\n", 174 | "\n", 175 | "#Plot run time by number of observations\n", 176 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,6))\n", 177 | "\n", 178 | "plt.plot(n, k_sec, linestyle='-', linewidth=2,marker='o',label = 'KernelSHAP')\n", 179 | "plt.plot(n, t_sec, linestyle='-', linewidth=2,marker='o',label = 'TreeSHAP')\n", 180 | "\n", 181 | "plt.ylabel('Time (seconds)',size=20)\n", 182 | "plt.xlabel('Number of observations',size=20)\n", 183 | "plt.legend(fontsize=15)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "#Number of observations\n", 193 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,6))\n", 194 | "\n", 195 | "plt.plot(n, t_sec, linestyle='-', color='#F87F0E',linewidth=2,marker='o',label = 
'TreeSHAP')\n", 196 | "\n", 197 | "plt.ylabel('Time (seconds)',size=20)\n", 198 | "plt.xlabel('Number of observations',size=20)\n", 199 | "plt.legend(fontsize=15)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "## Experiment 2: number of features\n", 207 | " " 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": { 214 | "scrolled": true 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "results = []\n", 219 | "\n", 220 | "for n_features, n_informative in zip([2,4,6,8,10,12,13,14,16,18,20]*3,[2,4,6,8,10,12,13,14,16,18,20]*3):\n", 221 | " \n", 222 | " #Simulate regression data\n", 223 | " data = ds.make_regression(n_samples=10000, n_features=n_features, n_informative=n_informative, n_targets=1,noise=0.1)\n", 224 | "\n", 225 | " y= data[1]\n", 226 | " X = data[0]\n", 227 | "\n", 228 | " feature_names = range(len(X))\n", 229 | "\n", 230 | " #Train model\n", 231 | " model = RandomForestRegressor(n_estimators=100,max_depth=10,random_state=0)\n", 232 | " model.fit(X, y)\n", 233 | " \n", 234 | " #get model properties\n", 235 | " avg_depth, avg_feat, avg_leaves = model_properties(model)\n", 236 | " \n", 237 | " #Get shap estimators\n", 238 | " kernelSHAP = shap.KernelExplainer(model.predict,shap.sample(X, 10))\n", 239 | " treeSHAP = shap.TreeExplainer(model)\n", 240 | " \n", 241 | " #Calculate SHAP values\n", 242 | " kernel_time = runSHAP(n=100)\n", 243 | " tree_time = runSHAP(n=100,kernel=False)\n", 244 | " \n", 245 | " result = [n_features, avg_depth, avg_feat, avg_leaves, kernel_time,tree_time]\n", 246 | " results.append(result)\n", 247 | "\n", 248 | "results_2 = pd.DataFrame(results,columns = ['n_features','avg_depth', 'avg_feat', 'avg_leaves','kernelSHAP','treeSHAP'])\n" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "#Get average run time\n", 258 | "avg_2 = 
results_2[['n_features','kernelSHAP','treeSHAP']].groupby(by='n_features',as_index=False).mean()\n", 259 | "\n", 260 | "k_sec = [t.total_seconds() for t in avg_2['kernelSHAP']]\n", 261 | "t_sec = [t.total_seconds() for t in avg_2['treeSHAP']]\n", 262 | "n = avg_2['n_features']\n", 263 | "\n", 264 | "print((k_sec/n)/(t_sec/n))\n", 265 | "\n", 266 | "#Plot run time by number of features\n", 267 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,6))\n", 268 | "\n", 269 | "plt.plot(n, k_sec, linestyle='-', linewidth=2,marker='o',label = 'KernelSHAP')\n", 270 | "plt.plot(n, t_sec, linestyle='-', linewidth=2,marker='o',label = 'TreeSHAP')\n", 271 | "\n", 272 | "plt.ylabel('Time (seconds)',size=20)\n", 273 | "plt.xlabel('Number of features',size=20)\n", 274 | "plt.legend(fontsize=15)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "metadata": {}, 280 | "source": [ 281 | "## Experiment 3: number of trees" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "#Simulate regression data\n", 291 | "data = ds.make_regression(n_samples=10000, n_features=10, n_informative=8, n_targets=1)\n", 292 | "\n", 293 | "y= data[1]\n", 294 | "X = data[0]\n", 295 | "\n", 296 | "feature_names = range(len(X))" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "results = []\n", 306 | "\n", 307 | "for trees in [10,20,50,100,200,500,1000]*3:\n", 308 | " #Train model\n", 309 | " model = RandomForestRegressor(n_estimators=trees,max_depth=4,random_state=0)\n", 310 | " model.fit(X, y)\n", 311 | " \n", 312 | " #Get shap estimators\n", 313 | " kernelSHAP = shap.KernelExplainer(model.predict,shap.sample(X, 10))\n", 314 | " treeSHAP = shap.TreeExplainer(model)\n", 315 | " \n", 316 | " #Calculate SHAP Values\n", 317 | " kernel_time = runSHAP(n=100)\n", 318 | " tree_time = 
runSHAP(n=100,kernel=False)\n", 319 | " \n", 320 | " result = [trees,kernel_time,tree_time]\n", 321 | " results.append(result)\n", 322 | "\n", 323 | "results_3 = pd.DataFrame(results,columns = ['trees','kernelSHAP','treeSHAP'])" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "#Get average run time\n", 333 | "avg_3 = results_3.groupby(by='trees',as_index=False).mean()\n", 334 | "\n", 335 | "k_sec = [t.total_seconds() for t in avg_3['kernelSHAP']]\n", 336 | "t_sec = [t.total_seconds() for t in avg_3['treeSHAP']]\n", 337 | "trees = avg_3['trees']\n", 338 | "\n", 339 | "print((k_sec/trees)/(t_sec/trees))\n", 340 | "\n", 341 | "#Plot run time by number of trees\n", 342 | "fig, ax = plt.subplots(nrows=1, ncols=2,figsize=(20,10))\n", 343 | "\n", 344 | "ax[0].plot(trees, k_sec, linestyle='-', linewidth=2,marker='o',label = 'KernelSHAP')\n", 345 | "ax[0].set_ylabel('Time (seconds)',size=20)\n", 346 | "ax[0].set_xlabel('Number of trees',size=20)\n", 347 | "ax[0].legend(fontsize=15)\n", 348 | "\n", 349 | "ax[1].plot(trees, t_sec, color='#F87F0E', linewidth=2,marker='o',label = 'TreeSHAP')\n", 350 | "ax[1].set_ylabel('Time (seconds)',size=20)\n", 351 | "ax[1].set_xlabel('Number of trees',size=20)\n", 352 | "ax[1].legend(fontsize=15)" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "metadata": {}, 358 | "source": [ 359 | "## Experiment 4: tree depth" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": { 366 | "scrolled": true 367 | }, 368 | "outputs": [], 369 | "source": [ 370 | "#Simulate regression data\n", 371 | "data = ds.make_regression(n_samples=10000, n_features=10, n_informative=8, n_targets=1)\n", 372 | "\n", 373 | "y= data[1]\n", 374 | "X = data[0]\n", 375 | "\n", 376 | "feature_names = range(len(X))\n", 377 | "\n", 378 | "results = []\n", 379 | "\n", 380 | "#for depth in [2,4,6]:\n", 381 | "for depth in 
[2,4,6,8,10,15,20]*3:\n", 382 | "\n", 383 | " #Train model\n", 384 | " model = RandomForestRegressor(n_estimators=100,max_depth=depth,random_state=0)\n", 385 | " model.fit(X, y)\n", 386 | " \n", 387 | " #get model properties\n", 388 | " avg_depth, avg_feat, avg_leaves = model_properties(model)\n", 389 | " \n", 390 | " #Get shap estimators\n", 391 | " kernelSHAP = shap.KernelExplainer(model.predict,shap.sample(X, 10))\n", 392 | " treeSHAP = shap.TreeExplainer(model)\n", 393 | " \n", 394 | " #Calculate SHAP values\n", 395 | " kernel_time = runSHAP(n=100)\n", 396 | " tree_time = runSHAP(n=100,kernel=False)\n", 397 | " \n", 398 | " result = [depth, avg_depth, avg_feat, avg_leaves, kernel_time,tree_time]\n", 399 | " results.append(result)\n", 400 | "\n", 401 | "results_4 = pd.DataFrame(results,columns = ['depth','avg_depth', 'avg_feat', 'avg_leaves','kernelSHAP','treeSHAP'])" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "#Get average run time\n", 411 | "avg_4 = results_4[['depth','kernelSHAP','treeSHAP']].groupby(by='depth',as_index=False).mean()\n", 412 | "\n", 413 | "k_sec = [t.total_seconds() for t in avg_4['kernelSHAP']]\n", 414 | "t_sec = [t.total_seconds() for t in avg_4['treeSHAP']]\n", 415 | "depth = avg_4['depth']\n", 416 | "\n", 417 | "#Plot run tume by tree depth\n", 418 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,6))\n", 419 | "\n", 420 | "plt.plot(depth, k_sec, linestyle='-', linewidth=2,marker='o',label = 'KernelSHAP')\n", 421 | "plt.plot(depth, t_sec, linestyle='-', linewidth=2,marker='o',label = 'TreeSHAP')\n", 422 | "plt.legend(fontsize=15)\n", 423 | "\n", 424 | "plt.ylabel('Time (seconds)',size=20)\n", 425 | "plt.xlabel('Tree depth',size=20)" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": null, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "#Other factors\n", 435 | "r4 = 
results_4[['depth','avg_depth','avg_feat','avg_leaves']].groupby(by='depth',as_index=False).mean()\n", 436 | "\n", 437 | "fig, ax = plt.subplots(nrows=1, ncols=2,figsize=(20,10))\n", 438 | "\n", 439 | "ax[0].plot(r4['depth'], r4['avg_feat'], linestyle='-', linewidth=2,marker='o')\n", 440 | "ax[0].set_ylabel('Average features',size=20)\n", 441 | "ax[0].set_xlabel('Tree depth',size=20)\n", 442 | "\n", 443 | "ax[1].plot(r4['depth'], r4['avg_leaves'], color='#F87F0E', linewidth=2,marker='o')\n", 444 | "ax[1].set_ylabel('Average leaves',size=20)\n", 445 | "ax[1].set_xlabel('Tree depth',size=20)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "metadata": {}, 451 | "source": [ 452 | "# Archive " 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [ 461 | "#\n", 462 | "data = ds.make_regression(n_samples=10000, n_features=10, n_informative=8, n_targets=1)\n", 463 | "\n", 464 | "y= data[1]\n", 465 | "X = data[0]\n", 466 | "\n", 467 | "feature_names = range(len(X))\n", 468 | "\n", 469 | "depth = 10 # vary this value \n", 470 | "model = RandomForestRegressor(n_estimators=100,max_depth=depth,random_state=0)\n", 471 | "model.fit(X, y)\n", 472 | "\n", 473 | "model_properties(model)" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "metadata": {}, 480 | "outputs": [], 481 | "source": [ 482 | "#Simulate regression data\n", 483 | "data = ds.make_regression(n_samples=10000, n_features=20, n_informative=20, n_targets=1,noise=0.1)\n", 484 | "\n", 485 | "y= data[1]\n", 486 | "X = data[0]\n", 487 | "\n", 488 | "feature_names = range(len(X))\n", 489 | "\n", 490 | "#Train model\n", 491 | "model = RandomForestRegressor(n_estimators=100,max_depth=10,random_state=0)\n", 492 | "model.fit(X, y)\n", 493 | "\n", 494 | "#get model properties\n", 495 | "avg_depth, avg_feat, avg_leaves = model_properties(model)\n", 496 | "\n", 497 | "\n", 498 | "#Get shap 
estimators\n", 499 | "treeSHAP = shap.TreeExplainer(model)\n", 500 | "kernelSHAP = shap.KernelExplainer(model.predict,shap.sample(X, 20))\n", 501 | "\n", 502 | "#get shap values \n", 503 | "x_sample = X[np.random.choice(X.shape[0], 100, replace=True)]\n", 504 | "sv_tree = treeSHAP.shap_values(x_sample)\n", 505 | "sv_kernel = kernelSHAP.shap_values(x_sample,l1_reg=0.1)\n", 506 | "\n", 507 | "print(len(sv_tree[0]),len(sv_kernel[0]))" 508 | ] 509 | } 510 | ], 511 | "metadata": { 512 | "kernelspec": { 513 | "display_name": "SHAP", 514 | "language": "python", 515 | "name": "shap" 516 | }, 517 | "language_info": { 518 | "codemirror_mode": { 519 | "name": "ipython", 520 | "version": 3 521 | }, 522 | "file_extension": ".py", 523 | "mimetype": "text/x-python", 524 | "name": "python", 525 | "nbconvert_exporter": "python", 526 | "pygments_lexer": "ipython3", 527 | "version": "3.10.6" 528 | } 529 | }, 530 | "nbformat": 4, 531 | "nbformat_minor": 2 532 | } 533 | -------------------------------------------------------------------------------- /src/shap_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SHAP Tutorial\n", 8 | "\n", 9 | "
\n", 10 | "Course sections:\n", 11 | "
    \n", 12 | "
  1. SHAP values\n", 13 | "
  2. SHAP aggregations\n", 14 | "
      \n", 15 | "
    1. Force plots\n", 16 | "
    2. Mean SHAP\n", 17 | "
    3. Beeswarm\n", 18 | "
    4. Violin\n", 19 | "
    5. Heatmap\n", 20 | "
    6. Dependence\n", 21 | "
    \n", 22 | "
  3. Custom SHAP plots\n", 23 | "
  4. Binary and multiclass target variables \n", 24 | "
  5. SHAP interaction values\n", 25 | "
  6. Categorical features\n", 26 | "
\n", 27 | "
\n", 28 | "Dataset: https://archive.ics.uci.edu/ml/datasets/Abalone\n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# imports\n", 38 | "import pandas as pd\n", 39 | "import numpy as np\n", 40 | "\n", 41 | "import matplotlib.pyplot as plt\n", 42 | "import seaborn as sns\n", 43 | "\n", 44 | "import xgboost as xgb\n", 45 | "\n", 46 | "import shap\n", 47 | "\n", 48 | "shap.initjs()" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "# Dataset\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# import dataset\n", 65 | "data = pd.read_csv(\n", 66 | " \"../data/abalone.data\",\n", 67 | " names=[\n", 68 | " \"sex\",\n", 69 | " \"length\",\n", 70 | " \"diameter\",\n", 71 | " \"height\",\n", 72 | " \"whole weight\",\n", 73 | " \"shucked weight\",\n", 74 | " \"viscera weight\",\n", 75 | " \"shell weight\",\n", 76 | " \"rings\",\n", 77 | " ],\n", 78 | ")\n", 79 | "\n", 80 | "print(len(data))\n", 81 | "data.head()" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "# plot 1: whole weight\n", 91 | "plt.scatter(data[\"whole weight\"], data[\"rings\"])\n", 92 | "plt.ylabel(\"rings\", size=20)\n", 93 | "plt.xlabel(\"whole weight\", size=20)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "# plot 2: sex\n", 103 | "plt.boxplot(data[data.sex == \"I\"][\"rings\"], positions=[1])\n", 104 | "plt.boxplot(data[data.sex == \"M\"][\"rings\"], positions=[2])\n", 105 | "plt.boxplot(data[data.sex == \"F\"][\"rings\"], positions=[3])\n", 106 | "\n", 107 | "plt.xticks(ticks=[1, 2, 3], labels=[\"I\", \"M\", \"F\"], size=15)\n", 108 | "plt.ylabel(\"rings\", size=20)\n", 109 | 
"plt.xlabel(\"sex\", size=20)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "# plot 3: Correlation heatmap\n", 119 | "cont = [\n", 120 | " \"length\",\n", 121 | " \"diameter\",\n", 122 | " \"height\",\n", 123 | " \"whole weight\",\n", 124 | " \"shucked weight\",\n", 125 | " \"viscera weight\",\n", 126 | " \"shell weight\",\n", 127 | " \"rings\",\n", 128 | "]\n", 129 | "corr_matrix = pd.DataFrame(data[cont], columns=cont).corr()\n", 130 | "\n", 131 | "sns.heatmap(corr_matrix, cmap=\"coolwarm\", center=0, annot=True, fmt=\".1g\")" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "# Feature Engineering\n" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 6, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "y = data[\"rings\"]\n", 148 | "X = data[[\"sex\", \"length\", \"height\", \"shucked weight\", \"viscera weight\", \"shell weight\"]]" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "# create dummy variables\n", 158 | "X[\"sex.M\"] = [1 if s == \"M\" else 0 for s in X[\"sex\"]]\n", 159 | "X[\"sex.F\"] = [1 if s == \"F\" else 0 for s in X[\"sex\"]]\n", 160 | "X[\"sex.I\"] = [1 if s == \"I\" else 0 for s in X[\"sex\"]]\n", 161 | "X = X.drop(\"sex\", axis=1)\n", 162 | "\n", 163 | "X.head()" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "features = X.copy()\n", 173 | "features['y'] = y\n", 174 | "\n", 175 | "features.head()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "# Modelling\n" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 
| "source": [ 191 | "# train model\n", 192 | "model = xgb.XGBRegressor(objective=\"reg:squarederror\")\n", 193 | "model.fit(X, y)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "# get predictions\n", 203 | "y_pred = model.predict(X)\n", 204 | "\n", 205 | "# model evaluation\n", 206 | "plt.figure(figsize=(5, 5))\n", 207 | "\n", 208 | "plt.scatter(y, y_pred)\n", 209 | "plt.plot([0, 30], [0, 30], color=\"r\", linestyle=\"-\", linewidth=2)\n", 210 | "\n", 211 | "plt.ylabel(\"Predicted\", size=20)\n", 212 | "plt.xlabel(\"Actual\", size=20)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "# 1) Standard SHAP values\n" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 11, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "# get shap values\n", 229 | "explainer = shap.Explainer(model)\n", 230 | "shap_values = explainer(X)\n", 231 | "\n", 232 | "# shap_values = explainer(X[0:100])" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "np.shape(shap_values.values)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "## Waterfall plot\n" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "# waterfall plot for first observation\n", 258 | "shap.plots.waterfall(shap_values[0])" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "# waterfall plot for first observation\n", 268 | "shap.plots.waterfall(shap_values[1], max_display=4)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": {}, 274 | "source": [ 275 | "# 2) SHAP 
aggregations\n" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "## Force plot\n" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "# force plot\n", 292 | "shap.plots.force(shap_values[0])" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "metadata": {}, 298 | "source": [ 299 | "## Stacked force plot\n" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "# stacked force plot\n", 309 | "shap.plots.force(shap_values[0:100])" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "## Absolute Mean SHAP\n" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "# mean SHAP\n", 326 | "shap.plots.bar(shap_values)" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "## Beeswarm plot\n" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "# beeswarm plot\n", 343 | "shap.plots.beeswarm(shap_values)" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": {}, 349 | "source": [ 350 | "## Violin plot\n" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "# violin plot\n", 360 | "shap.plots.violin(shap_values)" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "# layered violin plot\n", 370 | "shap.plots.violin(shap_values, plot_type=\"layered_violin\")" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | 
"metadata": {}, 376 | "source": [ 377 | "## Heamap\n" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": {}, 384 | "outputs": [], 385 | "source": [ 386 | "# heatmap\n", 387 | "shap.plots.heatmap(shap_values)" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "# order by predictions\n", 397 | "order = np.argsort(y_pred)\n", 398 | "shap.plots.heatmap(shap_values, instance_order=order)" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "metadata": {}, 405 | "outputs": [], 406 | "source": [ 407 | "# order by shell weight value\n", 408 | "order = np.argsort(data[\"shell weight\"])\n", 409 | "shap.plots.heatmap(shap_values, instance_order=order)" 410 | ] 411 | }, 412 | { 413 | "cell_type": "markdown", 414 | "metadata": {}, 415 | "source": [ 416 | "## Dependence plots\n" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "# plot 1: shell weight\n", 426 | "shap.plots.scatter(shap_values[:, \"shell weight\"])" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "shap.plots.scatter(\n", 436 | " shap_values[:, \"shell weight\"], color=shap_values[:, \"shucked weight\"]\n", 437 | ")" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": {}, 444 | "outputs": [], 445 | "source": [ 446 | "# plot 2: shucked weight\n", 447 | "shap.plots.scatter(shap_values[:, \"shucked weight\"])" 448 | ] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "metadata": {}, 453 | "source": [ 454 | "# 3) Custom Plots\n" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "metadata": {}, 461 | "outputs": [], 462 | "source": [ 463 | "# 
output SHAP object\n", 464 | "shap_values" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": null, 470 | "metadata": {}, 471 | "outputs": [], 472 | "source": [ 473 | "np.shape(shap_values.values)" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "metadata": {}, 480 | "outputs": [], 481 | "source": [ 482 | "# SHAP correlation plot\n", 483 | "corr_matrix = pd.DataFrame(shap_values.values, columns=X.columns).corr()\n", 484 | "\n", 485 | "sns.set(font_scale=1)\n", 486 | "sns.heatmap(corr_matrix, cmap=\"coolwarm\", center=0, annot=True, fmt=\".1g\")" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": {}, 492 | "source": [ 493 | "# 4) Binary and categorical target variables\n" 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "metadata": {}, 499 | "source": [ 500 | "### Binary target variable\n" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": 30, 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [ 509 | "# binary target varibale\n", 510 | "y_bin = [1 if y_ > 10 else 0 for y_ in y]" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "metadata": {}, 517 | "outputs": [], 518 | "source": [ 519 | "# train model\n", 520 | "model_bin = xgb.XGBClassifier(objective=\"binary:logistic\")\n", 521 | "model_bin.fit(X, y_bin)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": {}, 528 | "outputs": [], 529 | "source": [ 530 | "# get shap values\n", 531 | "explainer = shap.Explainer(model_bin)\n", 532 | "shap_values_bin = explainer(X)\n", 533 | "\n", 534 | "print(shap_values_bin.shape)" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": null, 540 | "metadata": {}, 541 | "outputs": [], 542 | "source": [ 543 | "# waterfall plot for first observation\n", 544 | "shap.plots.waterfall(shap_values_bin[0])" 545 | ] 546 | }, 547 | { 548 | 
"cell_type": "code", 549 | "execution_count": null, 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "# waterfall plot for first observation\n", 554 | "shap.plots.force(shap_values_bin[0], link=\"logit\")" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": null, 560 | "metadata": {}, 561 | "outputs": [], 562 | "source": [ 563 | "# waterfall plot for first observation\n", 564 | "shap.plots.bar(shap_values_bin)" 565 | ] 566 | }, 567 | { 568 | "cell_type": "markdown", 569 | "metadata": {}, 570 | "source": [ 571 | "### Categorical target variables\n" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": null, 577 | "metadata": {}, 578 | "outputs": [], 579 | "source": [ 580 | "# categorical target varibale\n", 581 | "y_cat = [2 if y_ > 12 else 1 if y_ > 8 else 0 for y_ in y]\n", 582 | "\n", 583 | "# train model\n", 584 | "model_cat = xgb.XGBClassifier(objective=\"binary:logistic\")\n", 585 | "model_cat.fit(X, y_cat)" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": null, 591 | "metadata": {}, 592 | "outputs": [], 593 | "source": [ 594 | "# get probability predictions\n", 595 | "model_cat.predict_proba(X)[0]" 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": null, 601 | "metadata": {}, 602 | "outputs": [], 603 | "source": [ 604 | "# get shap values\n", 605 | "explainer = shap.Explainer(model_cat)\n", 606 | "shap_values_cat = explainer(X)\n", 607 | "\n", 608 | "print(np.shape(shap_values_cat))" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": null, 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "# waterfall plot for first observation\n", 618 | "shap.plots.waterfall(shap_values_cat[0, :, 0])\n", 619 | "\n", 620 | "# waterfall plot for first observation\n", 621 | "shap.plots.waterfall(shap_values_cat[0, :, 1])\n", 622 | "\n", 623 | "# waterfall plot for first observation\n", 624 | 
"shap.plots.waterfall(shap_values_cat[0, :, 2])" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": null, 630 | "metadata": {}, 631 | "outputs": [], 632 | "source": [ 633 | "def softmax(x):\n", 634 | " \"\"\"Compute softmax values for each sets of scores in x.\"\"\"\n", 635 | " e_x = np.exp(x - np.max(x))\n", 636 | " return e_x / e_x.sum(axis=0)\n", 637 | "\n", 638 | "\n", 639 | "# convert softmax to probability\n", 640 | "x = [0.383, -0.106, 1.211]\n", 641 | "softmax(x)" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": null, 647 | "metadata": {}, 648 | "outputs": [], 649 | "source": [ 650 | "# calculate mean SHAP values for each class\n", 651 | "mean_0 = np.mean(np.abs(shap_values_cat.values[:, :, 0]), axis=0)\n", 652 | "mean_1 = np.mean(np.abs(shap_values_cat.values[:, :, 1]), axis=0)\n", 653 | "mean_2 = np.mean(np.abs(shap_values_cat.values[:, :, 2]), axis=0)\n", 654 | "\n", 655 | "df = pd.DataFrame({\"young\": mean_0, \"medium\": mean_1, \"old\": mean_2})\n", 656 | "\n", 657 | "# plot mean SHAP values\n", 658 | "fig, ax = plt.subplots(1, 1, figsize=(20, 10))\n", 659 | "df.plot.bar(ax=ax)\n", 660 | "\n", 661 | "ax.set_ylabel(\"Mean SHAP\", size=30)\n", 662 | "ax.set_xticklabels(X.columns, rotation=45, size=20)\n", 663 | "ax.legend(fontsize=30)" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": null, 669 | "metadata": {}, 670 | "outputs": [], 671 | "source": [ 672 | "# get model predictions\n", 673 | "preds = model_cat.predict(X)\n", 674 | "\n", 675 | "new_shap_values = []\n", 676 | "for i, pred in enumerate(preds):\n", 677 | " # get shap values for predicted class\n", 678 | " new_shap_values.append(shap_values_cat.values[i][:, pred])\n", 679 | "\n", 680 | "# replace shap values\n", 681 | "shap_values_cat.values = np.array(new_shap_values)\n", 682 | "print(shap_values_cat.shape)" 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "execution_count": null, 688 | "metadata": {}, 
689 | "outputs": [], 690 | "source": [ 691 | "shap.plots.bar(shap_values_cat)" 692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "execution_count": null, 697 | "metadata": {}, 698 | "outputs": [], 699 | "source": [ 700 | "shap.plots.beeswarm(shap_values_cat)" 701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "metadata": {}, 706 | "source": [ 707 | "# 5) SHAP interaction value\n" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": 45, 713 | "metadata": {}, 714 | "outputs": [], 715 | "source": [ 716 | "# get SHAP interaction values\n", 717 | "explainer = shap.Explainer(model)\n", 718 | "shap_interaction = explainer.shap_interaction_values(X)" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": null, 724 | "metadata": {}, 725 | "outputs": [], 726 | "source": [ 727 | "# get shape of interaction values\n", 728 | "np.shape(shap_interaction)" 729 | ] 730 | }, 731 | { 732 | "cell_type": "code", 733 | "execution_count": null, 734 | "metadata": {}, 735 | "outputs": [], 736 | "source": [ 737 | "# SHAP interaction values for first employee\n", 738 | "shap_0 = np.round(shap_interaction[0], 2)\n", 739 | "pd.DataFrame(shap_0, index=X.columns, columns=X.columns)" 740 | ] 741 | }, 742 | { 743 | "cell_type": "markdown", 744 | "metadata": {}, 745 | "source": [ 746 | "## Mean SHAP interaction values\n" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "execution_count": null, 752 | "metadata": {}, 753 | "outputs": [], 754 | "source": [ 755 | "# get absolute mean of matrices\n", 756 | "mean_shap = np.abs(shap_interaction).mean(0)\n", 757 | "mean_shap = np.round(mean_shap, 1)\n", 758 | "\n", 759 | "df = pd.DataFrame(mean_shap, index=X.columns, columns=X.columns)\n", 760 | "\n", 761 | "# times off diagonal by 2\n", 762 | "df.where(df.values == np.diagonal(df), df.values * 2, inplace=True)\n", 763 | "\n", 764 | "# display\n", 765 | "sns.set(font_scale=1)\n", 766 | "sns.heatmap(df, cmap=\"coolwarm\", 
annot=True)\n", 767 | "plt.yticks(rotation=0)" 768 | ] 769 | }, 770 | { 771 | "cell_type": "markdown", 772 | "metadata": {}, 773 | "source": [ 774 | "## Dependence plot\n" 775 | ] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "execution_count": null, 780 | "metadata": {}, 781 | "outputs": [], 782 | "source": [ 783 | "shap.dependence_plot(\n", 784 | " (\"shell weight\", \"shucked weight\"), shap_interaction, X, display_features=X\n", 785 | ")" 786 | ] 787 | }, 788 | { 789 | "cell_type": "code", 790 | "execution_count": null, 791 | "metadata": {}, 792 | "outputs": [], 793 | "source": [ 794 | "# interaction between shell weight and shucked weight\n", 795 | "plt.scatter(data[\"shell weight\"], data[\"shucked weight\"], c=data[\"rings\"], cmap=\"bwr\")\n", 796 | "plt.colorbar(label=\"Number of Rings\", orientation=\"vertical\")\n", 797 | "\n", 798 | "plt.xlabel(\"shucked weight\", size=15)\n", 799 | "plt.ylabel(\"shell weight\", size=15)" 800 | ] 801 | }, 802 | { 803 | "cell_type": "markdown", 804 | "metadata": {}, 805 | "source": [ 806 | "# 6) SHAP for categorical variables\n" 807 | ] 808 | }, 809 | { 810 | "cell_type": "code", 811 | "execution_count": null, 812 | "metadata": {}, 813 | "outputs": [], 814 | "source": [ 815 | "X.head()" 816 | ] 817 | }, 818 | { 819 | "cell_type": "code", 820 | "execution_count": null, 821 | "metadata": {}, 822 | "outputs": [], 823 | "source": [ 824 | "# Waterfall plot for first observation\n", 825 | "shap.plots.waterfall(shap_values[0])" 826 | ] 827 | }, 828 | { 829 | "cell_type": "code", 830 | "execution_count": null, 831 | "metadata": {}, 832 | "outputs": [], 833 | "source": [ 834 | "new_shap_values = []\n", 835 | "\n", 836 | "# loop over all shap values:\n", 837 | "for values in shap_values.values:\n", 838 | " # sum SHAP values for sex\n", 839 | " sv = list(values)\n", 840 | " sv = sv[0:5] + [sum(sv[5:8])]\n", 841 | "\n", 842 | " new_shap_values.append(sv)" 843 | ] 844 | }, 845 | { 846 | "cell_type": "code", 847 | 
"execution_count": null, 848 | "metadata": {}, 849 | "outputs": [], 850 | "source": [ 851 | "# replace shap values\n", 852 | "shap_values.values = np.array(new_shap_values)\n", 853 | "\n", 854 | "# replace data with categorical feature values\n", 855 | "X_cat = data[\n", 856 | " [\"length\", \"height\", \"shucked weight\", \"viscera weight\", \"shell weight\", \"sex\"]\n", 857 | "]\n", 858 | "shap_values.data = np.array(X_cat)\n", 859 | "\n", 860 | "# update feature names\n", 861 | "shap_values.feature_names = list(X_cat.columns)" 862 | ] 863 | }, 864 | { 865 | "cell_type": "code", 866 | "execution_count": null, 867 | "metadata": {}, 868 | "outputs": [], 869 | "source": [ 870 | "shap.plots.waterfall(shap_values[0])" 871 | ] 872 | }, 873 | { 874 | "cell_type": "code", 875 | "execution_count": null, 876 | "metadata": {}, 877 | "outputs": [], 878 | "source": [ 879 | "shap.plots.bar(shap_values)" 880 | ] 881 | }, 882 | { 883 | "cell_type": "code", 884 | "execution_count": null, 885 | "metadata": {}, 886 | "outputs": [], 887 | "source": [ 888 | "shap.plots.beeswarm(shap_values)" 889 | ] 890 | }, 891 | { 892 | "cell_type": "code", 893 | "execution_count": null, 894 | "metadata": {}, 895 | "outputs": [], 896 | "source": [ 897 | "# get shaply values and data\n", 898 | "sex_values = shap_values[:, \"sex\"].values\n", 899 | "sex_data = shap_values[:, \"sex\"].data\n", 900 | "sex_categories = [\"I\", \"M\", \"F\"]\n", 901 | "\n", 902 | "# split sex shap values based on category\n", 903 | "sex_groups = []\n", 904 | "for s in sex_categories:\n", 905 | " relevant_values = sex_values[sex_data == s]\n", 906 | " sex_groups.append(relevant_values)\n", 907 | "\n", 908 | "# plot boxplot\n", 909 | "plt.boxplot(sex_groups, labels=sex_categories)\n", 910 | "\n", 911 | "plt.ylabel(\"SHAP values\", size=15)\n", 912 | "plt.xlabel(\"Sex\", size=15)" 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "execution_count": null, 918 | "metadata": {}, 919 | "outputs": [], 920 | "source": [ 
921 | "# create for placeholder SHAP values\n", 922 | "shap_values_sex = explainer(X)\n", 923 | "\n", 924 | "# get shaply values and data\n", 925 | "sex_values = shap_values[:, \"sex\"].values\n", 926 | "sex_data = shap_values[:, \"sex\"].data\n", 927 | "sex_categories = [\"I\", \"M\", \"F\"]\n", 928 | "\n", 929 | "# create new SHAP values array\n", 930 | "\n", 931 | "# split odor SHAP values by unique odor categories\n", 932 | "new_shap_values = [\n", 933 | " np.array(pd.Series(sex_values)[sex_data == s]) for s in sex_categories\n", 934 | "]\n", 935 | "\n", 936 | "# each sublist needs to be the same length\n", 937 | "max_len = max([len(v) for v in new_shap_values])\n", 938 | "new_shap_values = [\n", 939 | " np.append(vs, [np.nan] * (max_len - len(vs))) for vs in new_shap_values\n", 940 | "]\n", 941 | "new_shap_values = np.array(new_shap_values)\n", 942 | "\n", 943 | "# transpost matrix so categories are columns and SHAP values are rows\n", 944 | "new_shap_values = new_shap_values.transpose()\n", 945 | "\n", 946 | "# replace shap values\n", 947 | "shap_values_sex.values = np.array(new_shap_values)\n", 948 | "\n", 949 | "# replace data with placeholder array\n", 950 | "shap_values_sex.data = np.array([[0] * len(sex_categories)] * max_len)\n", 951 | "\n", 952 | "# replace base data with placeholder array\n", 953 | "shap_values_sex.base = np.array([0] * max_len)\n", 954 | "\n", 955 | "# replace feature names with category labels\n", 956 | "shap_values_sex.feature_names = list(sex_categories)\n", 957 | "\n", 958 | "# use beeswarm as before\n", 959 | "shap.plots.beeswarm(shap_values_sex, color_bar=False)" 960 | ] 961 | }, 962 | { 963 | "cell_type": "code", 964 | "execution_count": null, 965 | "metadata": {}, 966 | "outputs": [], 967 | "source": [ 968 | "import warnings\n", 969 | "\n", 970 | "warnings.filterwarnings(\"ignore\")" 971 | ] 972 | } 973 | ], 974 | "metadata": { 975 | "kernelspec": { 976 | "display_name": "shap", 977 | "language": "python", 978 | "name": 
"shap" 979 | }, 980 | "language_info": { 981 | "codemirror_mode": { 982 | "name": "ipython", 983 | "version": 3 984 | }, 985 | "file_extension": ".py", 986 | "mimetype": "text/x-python", 987 | "name": "python", 988 | "nbconvert_exporter": "python", 989 | "pygments_lexer": "ipython3", 990 | "version": "3.10.4" 991 | } 992 | }, 993 | "nbformat": 4, 994 | "nbformat_minor": 2 995 | } 996 | --------------------------------------------------------------------------------