├── README.md
├── requirements.txt
├── LICENSE
├── src
│   ├── additional_resources
│   │   ├── RandomForestRegressor.ipynb
│   │   ├── CatBoostClassifier.ipynb
│   │   └── IsolationForest.ipynb
│   ├── archive
│   │   ├── project_2.ipynb
│   │   ├── project_1.ipynb
│   │   ├── project_2_solution.ipynb
│   │   ├── project_1_solution.ipynb
│   │   └── image_data.ipynb
│   ├── kernel_vs_tree.ipynb
│   └── shap_tutorial.ipynb
└── data
    └── interaction_dataset.csv
/README.md:
--------------------------------------------------------------------------------
1 | # SHAP-tutorial
2 |
3 | To access the course: https://adataodyssey.com/courses/shap-with-python/
4 |
5 | Watch the course outline here: https://youtu.be/n98pFxcD73w
6 |
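7 | The notebooks are in the `src` folder and the data in `data`. To run them locally, the dependencies listed in `requirements.txt` can be installed with, for example, `pip install -r requirements.txt`.
8 | 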
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Automatically generated by https://github.com/damnever/pigar.
2 |
3 | catboost==1.1.1
4 | matplotlib==3.6.0
5 | numpy==1.23.3
6 | opencv-python==4.6.0.66
7 | pandas==1.5.0
8 | Pillow==9.2.0
9 | scikit-learn==1.2.0
10 | seaborn==0.12.2
11 | shap==0.41.0
12 | torch==1.13.1
13 | torchvision==0.14.1
14 | xgboost==1.7.3
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Conor O'Sullivan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/additional_resources/RandomForestRegressor.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# SHAP for RandomForestRegressor\n",
8 | "
\n",
9 | "Dataset: https://www.kaggle.com/datasets/conorsully1/interaction-dataset"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "# Imports\n",
19 | "import pandas as pd\n",
20 | "import matplotlib.pyplot as plt\n",
21 | "\n",
22 | "from sklearn.ensemble import RandomForestRegressor\n",
23 | "\n",
24 | "import shap\n",
25 | "shap.initjs() \n",
26 | "\n",
27 | "# Set figure background to white\n",
28 | "plt.rcParams.update({'figure.facecolor':'white'})"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "## Dataset"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "#import dataset\n",
45 | "data = pd.read_csv(\"../../data/interaction_dataset.csv\",sep='\\t')\n",
46 | "\n",
47 | "y = data['bonus']\n",
48 | "X = data.drop('bonus', axis=1)\n",
49 | "\n",
50 | "print(len(data))\n",
51 | "data.head()"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {},
57 | "source": [
58 | "# Modelling"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "#Train model\n",
68 | "model = RandomForestRegressor(n_estimators=100) \n",
69 | "model.fit(X, y)"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {},
75 | "source": [
76 | "# SHAP Values"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 4,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "#Get SHAP values\n",
86 | "explainer = shap.Explainer(model)\n",
87 | "shap_values = explainer(X)"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "# Plot waterfall\n",
97 | "shap.plots.waterfall(shap_values[0])"
98 | ]
99 | }
100 | ],
101 | "metadata": {
102 | "kernelspec": {
103 | "display_name": "shap",
104 | "language": "python",
105 | "name": "shap"
106 | },
107 | "language_info": {
108 | "codemirror_mode": {
109 | "name": "ipython",
110 | "version": 3
111 | },
112 | "file_extension": ".py",
113 | "mimetype": "text/x-python",
114 | "name": "python",
115 | "nbconvert_exporter": "python",
116 | "pygments_lexer": "ipython3",
117 | "version": "3.10.4"
118 | }
119 | },
120 | "nbformat": 4,
121 | "nbformat_minor": 2
122 | }
123 |
--------------------------------------------------------------------------------
/src/archive/project_2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Project 2: mushroom classification\n",
8 | "
\n",
9 | "Use the SHAP analysis to answer the following questions:\n",
10 | "
\n",
11 | "- For the first prediction, which feature has the most significant contribution?\n",
12 | "
- Overall, which feature has the most significant contributions? \n",
13 | "
- Which odors are associated with poisonous mushrooms? \n",
14 | "
\n",
15 | "\n",
16 | "Dataset: https://www.kaggle.com/datasets/uciml/mushroom-classification"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "#imports\n",
26 | "import pandas as pd\n",
27 | "import numpy as np\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "\n",
30 | "from catboost import CatBoostClassifier\n",
31 | "\n",
32 | "import shap\n",
33 | "\n",
34 | "from sklearn.metrics import accuracy_score,confusion_matrix"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "#load data \n",
44 | "data = pd.read_csv(\"../data/mushrooms.csv\")\n",
45 | "\n",
46 | "#get features\n",
47 | "y = data['class']\n",
48 | "y = y.astype('category').cat.codes\n",
49 | "X = data.drop('class', axis=1)\n",
50 | "\n",
51 | "\n",
52 | "print(len(data))\n",
53 | "data.head()"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {
60 | "scrolled": true
61 | },
62 | "outputs": [],
63 | "source": [
64 | "model = CatBoostClassifier(iterations=20,\n",
65 | " learning_rate=0.01,\n",
66 | " depth=3)\n",
67 | "\n",
68 | "# train model\n",
69 | "cat_features = list(range(len(X.columns)))\n",
70 | "model.fit(X, y, cat_features)\n",
71 | "\n",
72 | "#Get predictions\n",
73 | "y_pred = model.predict(X)\n",
74 | "\n",
75 | "print(confusion_matrix(y, y_pred))\n",
76 | "accuracy_score(y, y_pred)"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {},
82 | "source": [
83 | "# Standard SHAP values"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 4,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "# get shap values\n",
93 | "explainer = shap.Explainer(model)\n",
94 | "shap_values = explainer(X)"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "#For the first prediction, which feature has the most significant contribution?"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "#Overall, which feature has the most significant contributions?"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "#Which odors are associated with poisonous mushrooms?"
122 | ]
123 | }
124 | ],
125 | "metadata": {
126 | "kernelspec": {
127 | "display_name": "SHAP",
128 | "language": "python",
129 | "name": "shap"
130 | },
131 | "language_info": {
132 | "codemirror_mode": {
133 | "name": "ipython",
134 | "version": 3
135 | },
136 | "file_extension": ".py",
137 | "mimetype": "text/x-python",
138 | "name": "python",
139 | "nbconvert_exporter": "python",
140 | "pygments_lexer": "ipython3",
141 | "version": "3.10.6"
142 | }
143 | },
144 | "nbformat": 4,
145 | "nbformat_minor": 2
146 | }
147 |
--------------------------------------------------------------------------------
/src/additional_resources/CatBoostClassifier.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "#imports\n",
10 | "import pandas as pd\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "\n",
13 | "from catboost import CatBoostClassifier\n",
14 | "import xgboost as xgb\n",
15 | "\n",
16 | "import shap\n",
17 | "\n",
18 | "from sklearn.metrics import accuracy_score,confusion_matrix\n",
19 | "\n",
20 | "#set figure background to white\n",
21 | "plt.rcParams.update({'figure.facecolor':'white'})"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "# Dataset"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "#load data \n",
38 | "data = pd.read_csv(\"../../data/mushrooms.csv\")\n",
39 | "\n",
40 | "#get features\n",
41 | "y = data['class']\n",
42 | "y = y.astype('category').cat.codes\n",
43 | "X = data.drop('class', axis=1)\n",
44 | "\n",
45 | "print(len(data))\n",
46 | "data.head()"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "# XGBoost"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "# Create dummy variables for the categorical features\n",
63 | "X_dummy = pd.get_dummies(X)"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "# Fit model\n",
73 | "model = xgb.XGBClassifier()\n",
74 | "model.fit(X_dummy, y)"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "#Get SHAP values\n",
84 | "explainer = shap.Explainer(model)\n",
85 | "shap_values = explainer(X_dummy)\n",
86 | "\n",
87 | "# Display SHAP values for the first observation\n",
88 | "shap.plots.waterfall(shap_values[0])"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {},
94 | "source": [
95 | "# CatBoost"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "model = CatBoostClassifier(iterations=20,\n",
105 | " learning_rate=0.01,\n",
106 | " depth=3)\n",
107 | "\n",
108 | "# train model\n",
109 | "cat_features = list(range(len(X.columns)))\n",
110 | "model.fit(X, y, cat_features)"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "#Get SHAP values\n",
120 | "explainer = shap.Explainer(model)\n",
121 | "shap_values = explainer(X)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "# Display SHAP values for the first observation\n",
131 | "shap.plots.waterfall(shap_values[0])"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "# Mean SHAP \n",
141 | "shap.plots.bar(shap_values)"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "# Beeswarm plot \n",
151 | "shap.plots.beeswarm(shap_values)"
152 | ]
153 | }
154 | ],
155 | "metadata": {
156 | "kernelspec": {
157 | "display_name": "shap",
158 | "language": "python",
159 | "name": "shap"
160 | },
161 | "language_info": {
162 | "codemirror_mode": {
163 | "name": "ipython",
164 | "version": 3
165 | },
166 | "file_extension": ".py",
167 | "mimetype": "text/x-python",
168 | "name": "python",
169 | "nbconvert_exporter": "python",
170 | "pygments_lexer": "ipython3",
171 | "version": "3.10.4"
172 | }
173 | },
174 | "nbformat": 4,
175 | "nbformat_minor": 2
176 | }
177 |
--------------------------------------------------------------------------------
/src/archive/project_1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Project 1: salary bonus\n",
8 | "
\n",
9 | "Use SHAP to answer the following questions:\n",
10 | "\n",
11 | "- Which features do NOT have a significant relationship with bonus?\n",
12 | "
- What tends to happens to an employee's bonus as they gain more experience? \n",
13 | "
- Are there any potential interactions in the dataset? \n",
14 | "
\n",
15 | "
\n",
16 | "Dataset: https://www.kaggle.com/conorsully1/interaction-dataset"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "#imports\n",
26 | "import pandas as pd\n",
27 | "import numpy as np\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "import seaborn as sns\n",
30 | "\n",
31 | "from sklearn.ensemble import RandomForestRegressor\n",
32 | "\n",
33 | "import shap\n",
34 | "shap.initjs()\n",
35 | "\n",
36 | "path = \"/Users/conorosully/Google Drive/My Drive/Medium/SHAP Interactions/Figures/{}\""
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "## Dataset"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "#import dataset\n",
53 | "data = pd.read_csv(\"../data/interaction_dataset.csv\",sep='\\t')\n",
54 | "\n",
55 | "y = data['bonus']\n",
56 | "X = data.drop('bonus', axis=1)\n",
57 | "\n",
58 | "print(len(data))\n",
59 | "data.head()"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "## Modelling"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "#Train model\n",
76 | "model = RandomForestRegressor(n_estimators=100) \n",
77 | "model.fit(X, y)\n",
78 | "\n",
79 | "#Get predictions\n",
80 | "y_pred = model.predict(X)"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "#Model evaluation\n",
90 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,8))\n",
91 | "\n",
92 | "plt.scatter(y,y_pred)\n",
93 | "plt.plot([0, 400], [0, 400], color='r', linestyle='-', linewidth=2)\n",
94 | "\n",
95 | "plt.ylabel('Predicted',size=20)\n",
96 | "plt.xlabel('Actual',size=20)"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {},
102 | "source": [
103 | "## Standard SHAP values"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 5,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "#Get SHAP values\n",
113 | "explainer = shap.Explainer(model,X[0:10])\n",
114 | "shap_values = explainer(X)"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "# Which features do NOT have a significant relationship with bonus?"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "# What tends to happens to an employee's bonus as they gain more experience? "
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {},
138 | "source": [
139 | "## SHAP interaction values"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 6,
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "#Get SHAP interaction values\n",
149 | "explainer = shap.Explainer(model)\n",
150 | "shap_interaction = explainer.shap_interaction_values(X)"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 | "# Are there any potential interactions in the dataset? "
160 | ]
161 | }
162 | ],
163 | "metadata": {
164 | "kernelspec": {
165 | "display_name": "SHAP",
166 | "language": "python",
167 | "name": "shap"
168 | },
169 | "language_info": {
170 | "codemirror_mode": {
171 | "name": "ipython",
172 | "version": 3
173 | },
174 | "file_extension": ".py",
175 | "mimetype": "text/x-python",
176 | "name": "python",
177 | "nbconvert_exporter": "python",
178 | "pygments_lexer": "ipython3",
179 | "version": "3.10.6"
180 | }
181 | },
182 | "nbformat": 4,
183 | "nbformat_minor": 2
184 | }
185 |
--------------------------------------------------------------------------------
/src/archive/project_2_solution.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Project 2: mushroom classification\n",
8 | "
\n",
9 | "Use the SHAP analysis to answer the following questions:\n",
10 | "\n",
11 | "- For the first prediction, which feature has the most significant contibution?\n",
12 | "
- Overall, which feature has the most significant contributions? \n",
13 | "
- Which odors are associated with poisonous mushrooms? \n",
14 | "
\n",
15 | "\n",
16 | "Dataset: https://www.kaggle.com/datasets/uciml/mushroom-classification"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 1,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "#imports\n",
26 | "import pandas as pd\n",
27 | "import numpy as np\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "\n",
30 | "from catboost import CatBoostClassifier\n",
31 | "\n",
32 | "import shap\n",
33 | "\n",
34 | "from sklearn.metrics import accuracy_score,confusion_matrix"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "#load data \n",
44 | "data = pd.read_csv(\"../data/mushrooms.csv\")\n",
45 | "\n",
46 | "#get features\n",
47 | "y = data['class']\n",
48 | "y = y.astype('category').cat.codes\n",
49 | "X = data.drop('class', axis=1)\n",
50 | "\n",
51 | "# replace all categorical features with integer values\n",
52 | "for col in X.columns:\n",
53 | " X[col] = X[col].astype('category').cat.codes\n",
54 | "\n",
55 | "\n",
56 | "print(len(data))\n",
57 | "X.head()"
58 | ]
59 | },
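  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train model (this cell is assumed; the hyperparameters mirror project_2.ipynb)\n",
    "model = CatBoostClassifier(iterations=20,\n",
    "                           learning_rate=0.01,\n",
    "                           depth=3)\n",
    "\n",
    "cat_features = list(range(len(X.columns)))\n",
    "model.fit(X, y, cat_features)"
   ]
  },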
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "# Standard SHAP values"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 4,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "# get shap values\n",
74 | "explainer = shap.Explainer(model)\n",
75 | "shap_values = explainer(X)"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "#For the first prediction, which feature has the most significant contribution?\n",
85 | "#Answer: odor\n",
86 | "shap.plots.waterfall(shap_values[0],max_display=5)"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "#Overall, which feature has the most significant contributions?\n",
96 | "#Answer: odor\n",
97 | "shap.plots.bar(shap_values,show=False)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "shap.plots.beeswarm(shap_values)"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "#Which odors are associated with poisonous mushrooms?\n",
116 | "#All the odors with SHAP values > 0 "
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "#get shaply values and data\n",
126 | "odor_values = shap_values[:,4].values\n",
127 | "odor_data = X['odor']\n",
128 | "unique_odor = set(X['odor'])\n",
129 | "\n",
130 | "#split odor shap values based on odor category\n",
131 | "odor_categories = list(set(odor_data))\n",
132 | "\n",
133 | "odor_groups = []\n",
134 | "for o in odor_categories:\n",
135 | " relevant_values = odor_values[odor_data == o]\n",
136 | " odor_groups.append(relevant_values)\n",
137 | " \n",
138 | "#replace categories with labels\n",
139 | "odor_labels = {'a':'almond',\n",
140 | " 'l':'anise', \n",
141 | " 'c':'creosote', \n",
142 | " 'y':'fishy', \n",
143 | " 'f':'foul', \n",
144 | " 'm':'musty', \n",
145 | " 'n':'none', \n",
146 | " 'p':'pungent', \n",
147 | " 's':'spicy'}\n",
148 | "\n",
149 | "labels = [odor_labels[u] for u in unique_odor]\n",
150 | "\n",
151 | "#plot boxplot\n",
152 | "plt.figure(figsize=(8, 5))\n",
153 | "\n",
154 | "plt.boxplot(odor_groups,labels=labels)\n",
155 | "\n",
156 | "plt.ylabel('SHAP values',size=15)\n",
157 | "plt.xlabel('Odor',size=15)"
158 | ]
159 | }
160 | ],
161 | "metadata": {
162 | "kernelspec": {
163 | "display_name": "xai",
164 | "language": "python",
165 | "name": "xai"
166 | },
167 | "language_info": {
168 | "codemirror_mode": {
169 | "name": "ipython",
170 | "version": 3
171 | },
172 | "file_extension": ".py",
173 | "mimetype": "text/x-python",
174 | "name": "python",
175 | "nbconvert_exporter": "python",
176 | "pygments_lexer": "ipython3",
177 | "version": "3.9.12"
178 | }
179 | },
180 | "nbformat": 4,
181 | "nbformat_minor": 2
182 | }
183 |
--------------------------------------------------------------------------------
/src/archive/project_1_solution.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Project 1: salary bonus\n",
8 | "
\n",
9 | "Use the SHAP analysis to answer the following questions:\n",
10 | "\n",
11 | "- Which features does NOT have a significant relationship with bonus?\n",
12 | "
- What tends to happens to an employee's bonus as they gain more experience? \n",
13 | "
- Are there any potential interactions in the dataset? \n",
14 | "
\n",
15 | "
\n",
16 | "Dataset: https://www.kaggle.com/conorsully1/interaction-dataset"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "#imports\n",
26 | "import pandas as pd\n",
27 | "import numpy as np\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "import seaborn as sns\n",
30 | "\n",
31 | "from sklearn.ensemble import RandomForestRegressor\n",
32 | "\n",
33 | "import shap\n",
34 | "shap.initjs()\n",
35 | "\n",
36 | "path = \"/Users/conorosully/Google Drive/My Drive/Medium/SHAP Interactions/Figures/{}\""
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "## Dataset"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "#import dataset\n",
53 | "data = pd.read_csv(\"../data/interaction_dataset.csv\",sep='\\t')\n",
54 | "\n",
55 | "y = data['bonus']\n",
56 | "X = data.drop('bonus', axis=1)\n",
57 | "\n",
58 | "print(len(data))\n",
59 | "data.head()"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "## Modelling"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "#Train model\n",
76 | "model = RandomForestRegressor(n_estimators=100) \n",
77 | "model.fit(X, y)\n",
78 | "\n",
79 | "#Get predictions\n",
80 | "y_pred = model.predict(X)"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "#Model evaluation\n",
90 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,8))\n",
91 | "\n",
92 | "plt.scatter(y,y_pred)\n",
93 | "plt.plot([0, 400], [0, 400], color='r', linestyle='-', linewidth=2)\n",
94 | "\n",
95 | "plt.ylabel('Predicted',size=20)\n",
96 | "plt.xlabel('Actual',size=20)"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {},
102 | "source": [
103 | "## Standard SHAP values"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 7,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "#Get SHAP values\n",
113 | "explainer = shap.Explainer(model,X[0:10])\n",
114 | "shap_values = explainer(X)"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "# waterfall plot for first observation\n",
124 | "shap.plots.waterfall(shap_values[0])"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "# Which features do NOT have a significant relationship with bonus?\n",
134 | "# Answer: days_late\n",
135 | "shap.plots.bar(shap_values)"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "#What tends to happens to an employee's bonus as they gain more experience? \n",
145 | "# Answer: their bonus increases\n",
146 | "# You could have also used a dependency plot\n",
147 | "shap.plots.beeswarm(shap_values)"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "metadata": {},
153 | "source": [
154 | "## SHAP interaction values"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 9,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "#Get SHAP interaction values\n",
164 | "explainer = shap.Explainer(model)\n",
165 | "shap_interaction = explainer.shap_interaction_values(X)"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "metadata": {},
172 | "outputs": [],
173 | "source": [
174 | "# Are there any potential interactions in the dataset? \n",
175 | "# Answer: yes - experience.degree & performance.sales"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "# Get absolute mean of matrices\n",
185 | "mean_shap = np.abs(shap_interaction).mean(0)\n",
186 | "df = pd.DataFrame(mean_shap,index=X.columns,columns=X.columns)\n",
187 | "\n",
188 | "# times off diagonal by 2\n",
189 | "df.where(df.values == np.diagonal(df),df.values*2,inplace=True)\n",
190 | "\n",
191 | "# display \n",
192 | "plt.figure(figsize=(10, 10), facecolor='w', edgecolor='k')\n",
193 | "sns.set(font_scale=1.5)\n",
194 | "sns.heatmap(df,cmap='coolwarm',annot=True,fmt='.3g',cbar=False)\n",
195 | "plt.yticks(rotation=0) "
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "# Experience-degree depenence plot\n",
205 | "shap.dependence_plot(\n",
206 | " (\"experience\", \"degree\"),\n",
207 | " shap_interaction, X,\n",
208 | " display_features=X)"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": null,
214 | "metadata": {},
215 | "outputs": [],
216 | "source": [
217 | "# Performance-sales depenence plot\n",
218 | "shap.dependence_plot(\n",
219 | " (\"performance\", \"sales\"),\n",
220 | " shap_interaction, X,\n",
221 | " display_features=X)"
222 | ]
223 | }
224 | ],
225 | "metadata": {
226 | "kernelspec": {
227 | "display_name": "shap",
228 | "language": "python",
229 | "name": "shap"
230 | },
231 | "language_info": {
232 | "codemirror_mode": {
233 | "name": "ipython",
234 | "version": 3
235 | },
236 | "file_extension": ".py",
237 | "mimetype": "text/x-python",
238 | "name": "python",
239 | "nbconvert_exporter": "python",
240 | "pygments_lexer": "ipython3",
241 | "version": "3.10.4"
242 | }
243 | },
244 | "nbformat": 4,
245 | "nbformat_minor": 2
246 | }
247 |
--------------------------------------------------------------------------------
/src/additional_resources/IsolationForest.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 5,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# Imports \n",
10 | "import pandas as pd\n",
11 | "import numpy as np\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "import seaborn as sns\n",
14 | "\n",
15 | "import shap\n",
16 | "\n",
17 | "from sklearn.ensemble import IsolationForest\n",
18 | "from ucimlrepo import fetch_ucirepo\n",
19 | "\n",
20 | "# Set figure background to white\n",
21 | "plt.rcParams.update({'figure.facecolor':'white'})"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "# Data Cleaning and Feature Engineering"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "# Fetch dataset from UCI repository\n",
38 | "power_consumption = fetch_ucirepo(id=235)\n",
39 | "\n",
40 | "print(power_consumption.variables) "
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "# Get all features\n",
50 | "data = power_consumption.data.features\n",
51 | "data['Date'] = pd.to_datetime(data['Date'], format='%d/%m/%Y')\n",
52 | "\n",
53 | "# List of features to check\n",
54 | "feature_columns = ['Global_active_power', 'Global_reactive_power', 'Voltage', \n",
55 | " 'Global_intensity', 'Sub_metering_1', 'Sub_metering_2', 'Sub_metering_3']\n",
56 | "\n",
57 | "# Convert feature columns to numeric and replace any errors with NaN\n",
58 | "data[feature_columns] = data[feature_columns].apply(pd.to_numeric, errors='coerce')\n",
59 | "\n",
60 | "# Drop rows where all feature columns are missing (NaN) \n",
61 | "data_cleaned = data.dropna(subset=feature_columns, how='all')\n",
62 | "\n",
63 | "# Drop rows where ALL feature columns are NaN\n",
64 | "data_cleaned.head()"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "# Group by 'Date' and calculate mean and standard deviation (ignore NaN values)\n",
74 | "data_aggregated = data_cleaned.groupby('Date')[feature_columns].agg(['mean', 'std'])\n",
75 | "\n",
76 | "# Rename columns to the desired format (MEAN_ColumnName, STD_ColumnName)\n",
77 | "data_aggregated.columns = [\n",
78 | " f'{agg_type.upper()}_{col}' for col, agg_type in data_aggregated.columns\n",
79 | "]\n",
80 | "\n",
81 | "# Reset the index\n",
82 | "data_aggregated.reset_index(inplace=True)\n",
83 | "\n",
84 | "# Display the result\n",
85 | "print(data_aggregated.shape)\n",
86 | "data_aggregated.head()"
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "metadata": {},
92 | "source": [
93 | "# Train IsolationForest"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 10,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "# Parameters\n",
103 | "n_estimators = 100 # Number of trees\n",
104 | "sample_size = 256 # Number of samples used to train each tree\n",
105 | "contamination = 0.02 # Expected proportion of anomalies"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "# Select Features\n",
115 | "features = data_aggregated.drop('Date', axis=1)\n",
116 | "\n",
117 | "# Train Isolation Forest\n",
118 | "iso_forest = IsolationForest(n_estimators=n_estimators, \n",
119 | " contamination=contamination, \n",
120 | " max_samples=sample_size,\n",
121 | " random_state=42)\n",
122 | "\n",
123 | "iso_forest.fit(features)"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "data_aggregated['anomaly_score'] = iso_forest.decision_function(features)\n",
133 | "data_aggregated['anomaly'] = iso_forest.predict(features)\n",
134 | "\n",
135 | "data_aggregated['anomaly'].value_counts()"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "# Visualization of the results\n",
145 | "plt.figure(figsize=(10, 5))\n",
146 | "\n",
147 | "# Plot normal instances\n",
148 | "normal = data_aggregated[data_aggregated['anomaly'] == 1]\n",
149 | "plt.scatter(normal['Date'], normal['anomaly_score'], label='Normal')\n",
150 | "\n",
151 | "# Plot anomalies\n",
152 | "anomalies = data_aggregated[data_aggregated['anomaly'] == -1]\n",
153 | "plt.scatter(anomalies['Date'], anomalies['anomaly_score'], label='Anomaly')\n",
154 | "\n",
155 | "plt.xlabel(\"Instance\")\n",
156 | "plt.ylabel(\"Anomaly Score\")\n",
157 | "plt.legend()"
158 | ]
159 | },
160 | {
161 | "cell_type": "markdown",
162 | "metadata": {},
163 | "source": [
164 | "# KernelSHAP with Anomaly Score\n"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "metadata": {},
171 | "outputs": [],
172 | "source": [
173 | "# Using the anomaly score and TreeSHAP (this code won't work)\n",
174 | "explainer = shap.TreeExplainer(iso_forest.decision_function, features)\n",
175 | "shap_values = explainer(features)"
176 | ]
177 | },
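  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above fails because `shap.TreeExplainer` expects a supported tree-based model object (such as the fitted `IsolationForest` itself), not an arbitrary function like `iso_forest.decision_function`. To explain the anomaly score directly we instead use the model-agnostic `shap.Explainer` below, which only needs a callable and background data. Since this is much slower, only a sample of instances is explained."
   ]
  },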
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "# Select all anomalies and 100 random normal instances\n",
185 | "normal_sample = np.random.choice(normal.index,size=100,replace=False)\n",
186 | "sample = np.append(anomalies.index,normal_sample)\n",
187 | "\n",
188 | "len(sample) # 129"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "metadata": {},
195 | "outputs": [],
196 | "source": [
197 | "# Using the anomaly score and KernelSHAP\n",
198 | "explainer = shap.Explainer(iso_forest.decision_function, features)\n",
199 | "shap_values = explainer(features.iloc[sample])"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "# Plot waterfall plot of an anomaly\n",
209 | "shap.plots.waterfall(shap_values[0])"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "# Plot waterfall plot of a normal instance\n",
219 | "shap.plots.waterfall(shap_values[100])"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "metadata": {},
226 | "outputs": [],
227 | "source": [
228 | "# MeanSHAP Plot\n",
229 | "shap.plots.bar(shap_values)"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "# Beeswarm plot\n",
239 | "shap.plots.beeswarm(shap_values)"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {},
245 | "source": [
246 | "# TreeSHAP with Path Length"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 22,
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "# Calculate SHAP values\n",
256 | "explainer = shap.TreeExplainer(iso_forest)\n",
257 | "shap_values = explainer(features)"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": null,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "# Waterfall plot for an anomaly\n",
267 | "shap.plots.waterfall(shap_values[0])"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "# Waterfall plot for a normal instance\n",
277 | "shap.plots.waterfall(shap_values[2])"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "metadata": {},
284 | "outputs": [],
285 | "source": [
286 | "# Calculate f(x)\n",
287 | "path_length = shap_values.base_values + shap_values.values.sum(axis=1)\n",
288 | "\n",
289 | "# Get f(x) for anomalies and normal instances\n",
290 | "anomalies = data_aggregated[data_aggregated['anomaly'] == -1]\n",
291 | "path_length_anomalies = path_length[anomalies.index]\n",
292 | "\n",
293 | "normal = data_aggregated[data_aggregated['anomaly'] == 1]\n",
294 | "path_length_normal = path_length[normal.index]\n",
295 | "\n",
296 | "# Plot boxplots for f(x)\n",
297 | "plt.figure(figsize=(10, 5))\n",
298 | "plt.boxplot([path_length_anomalies, path_length_normal], labels=['Anomaly','Normal'])\n",
299 | "plt.ylabel(\"Average Path Length f(x)\")"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": null,
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "# MeanSHAP\n",
309 | "shap.plots.bar(shap_values)"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": null,
315 | "metadata": {},
316 | "outputs": [],
317 | "source": [
318 | "# MeanSHAP\n",
319 | "shap.plots.beeswarm(shap_values)"
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": 26,
325 | "metadata": {},
326 | "outputs": [],
327 | "source": [
328 | "# Interaction values\n",
329 | "shap_interaction_values = explainer.shap_interaction_values(features)"
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": null,
335 | "metadata": {},
336 | "outputs": [],
337 | "source": [
338 | "# Get absolute mean of matrices\n",
339 | "mean_shap = np.abs(shap_interaction_values).mean(0)\n",
340 | "mean_shap = np.round(mean_shap, 1)\n",
341 | "\n",
342 | "df = pd.DataFrame(mean_shap, index=features.columns, columns=features.columns)\n",
343 | "\n",
344 | "# Times off diagonal by 2\n",
345 | "df.where(df.values == np.diagonal(df), df.values * 2, inplace=True)\n",
346 | "\n",
347 | "# Display\n",
348 | "sns.set(font_scale=1)\n",
349 | "sns.heatmap(df, cmap=\"coolwarm\", annot=True)\n",
350 | "plt.yticks(rotation=0)"
351 | ]
352 | }
353 | ],
354 | "metadata": {
355 | "kernelspec": {
356 | "display_name": "shap",
357 | "language": "python",
358 | "name": "shap"
359 | },
360 | "language_info": {
361 | "codemirror_mode": {
362 | "name": "ipython",
363 | "version": 3
364 | },
365 | "file_extension": ".py",
366 | "mimetype": "text/x-python",
367 | "name": "python",
368 | "nbconvert_exporter": "python",
369 | "pygments_lexer": "ipython3",
370 | "version": "3.10.4"
371 | }
372 | },
373 | "nbformat": 4,
374 | "nbformat_minor": 2
375 | }
376 |
--------------------------------------------------------------------------------
/src/archive/image_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Using SHAP to debug a PyTorch Image Regression Model"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 2,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "# Imports\n",
18 | "import numpy as np\n",
19 | "import pandas as pd\n",
20 | "import matplotlib.pyplot as plt\n",
21 | "\n",
22 | "import glob \n",
23 | "import random \n",
24 | "\n",
25 | "from PIL import Image\n",
26 | "import cv2\n",
27 | "\n",
28 | "import torch\n",
29 | "import torchvision\n",
30 | "from torchvision import transforms\n",
31 | "from torch.utils.data import DataLoader\n",
32 | "\n",
33 | "import shap\n",
34 | "from sklearn.metrics import mean_squared_error"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "#Load example image\n",
44 | "name = \"32_50_c78164b4-40d2-11ed-a47b-a46bb6070c92.jpg\"\n",
45 | "x = int(name.split(\"_\")[0])\n",
46 | "y = int(name.split(\"_\")[1])\n",
47 | "\n",
48 | "img = Image.open(\"../data/room_1/\" + name)\n",
49 | "img = np.array(img)\n",
50 | "cv2.circle(img, (x, y), 8, (0, 255, 0), 3)\n",
51 | "\n",
52 | "plt.imshow(img)\n",
53 | "\n",
54 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/example.png\"\n",
55 | "plt.savefig(path, bbox_inches='tight',facecolor='w', edgecolor='w', transparent=False,dpi=200)"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "# Model Training"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 5,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "class ImageDataset(torch.utils.data.Dataset):\n",
72 | " def __init__(self, paths, transform):\n",
73 | "\n",
74 | " self.transform = transform\n",
75 | " self.paths = paths\n",
76 | "\n",
77 | " def __getitem__(self, idx):\n",
78 | " \"\"\"Get image and target (x, y) coordinates\"\"\"\n",
79 | "\n",
80 | " # Read image\n",
81 | " path = self.paths[idx]\n",
82 | " image = cv2.imread(path, cv2.IMREAD_COLOR)\n",
83 | " image = Image.fromarray(image)\n",
84 | "\n",
85 | " # Transform image\n",
86 | " image = self.transform(image)\n",
87 | " \n",
88 | " # Get target\n",
89 | " target = self.get_target(path)\n",
90 | " target = torch.Tensor(target)\n",
91 | "\n",
92 | " return image, target\n",
93 | " \n",
94 | " def get_target(self,path):\n",
95 | " \"\"\"Get the target (x, y) coordinates from path\"\"\"\n",
96 | "\n",
97 | " name = os.path.basename(path)\n",
98 | " items = name.split('_')\n",
99 | " x = items[0]\n",
100 | " y = items[1]\n",
101 | "\n",
102 | " # Scale between -1 and 1\n",
103 | " x = 2.0 * (int(x)/ 224 - 0.5) # -1 left, +1 right\n",
104 | " y = 2.0 * (int(y) / 244 -0.5)# -1 top, +1 bottom\n",
105 | "\n",
106 | " return [x, y]\n",
107 | "\n",
108 | " def __len__(self):\n",
109 | " return len(self.paths)\n"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "TRANSFORMS = transforms.Compose([\n",
119 | " transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),\n",
120 | " transforms.Resize((224, 224)),\n",
121 | " transforms.ToTensor(),\n",
122 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
123 | "])\n",
124 | "\n",
125 | "all_rooms = False # Change if you want to use all the data\n",
126 | "\n",
127 | "paths = glob.glob('../data/room_1/*')\n",
128 | "if all_rooms:\n",
129 | " paths = paths + glob.glob('../data/room_2/*') + glob.glob('../data/room_3/*')\n",
130 | "\n",
131 | "# Shuffle the paths\n",
132 | "random.shuffle(paths)\n",
133 | "\n",
134 | "# Create a datasets for training and validation\n",
135 | "split = int(0.8 * len(paths))\n",
136 | "train_data = ImageDataset(paths[:split], TRANSFORMS)\n",
137 | "valid_data = ImageDataset(paths[split:], TRANSFORMS)\n",
138 | "\n",
139 | "# Prepare data for Pytorch model\n",
140 | "train_loader = DataLoader(train_data, batch_size=32, shuffle=True)\n",
141 | "valid_loader = DataLoader(valid_data, batch_size=valid_data.__len__())\n",
142 | "\n",
143 | "print(train_data.__len__())\n",
144 | "print(valid_data.__len__())"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 63,
150 | "metadata": {},
151 | "outputs": [],
152 | "source": [
153 | "output_dim = 2 # x, y\n",
154 | "device = torch.device('cpu') # or 'cuda' if you have a GPU\n",
155 | "\n",
156 | "# RESNET 18\n",
157 | "model = torchvision.models.resnet18(pretrained=True)\n",
158 | "model.fc = torch.nn.Linear(512, output_dim)\n",
159 | "model = model.to(device)\n",
160 | "\n",
161 | "optimizer = torch.optim.Adam(model.parameters())"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "name = \"direction_model_1\" # Change this to save a new model\n",
171 | "\n",
172 | "# Train the model\n",
173 | "min_loss = np.inf\n",
174 | "for epoch in range(10):\n",
175 | "\n",
176 | " model = model.train()\n",
177 | " for images, target in iter(train_loader):\n",
178 | "\n",
179 | " images = images.to(device)\n",
180 | " target = target.to(device)\n",
181 | " \n",
182 | " # Zero gradients of parameters\n",
183 | " optimizer.zero_grad() \n",
184 | "\n",
185 | " # Execute model to get outputs\n",
186 | " output = model(images)\n",
187 | "\n",
188 | " # Calculate loss\n",
189 | " loss = torch.nn.functional.mse_loss(output, target)\n",
190 | "\n",
191 | " # Run backpropogation to accumulate gradients\n",
192 | " loss.backward()\n",
193 | "\n",
194 | " # Update model parameters\n",
195 | " optimizer.step()\n",
196 | "\n",
197 | " # Calculate validation loss\n",
198 | " model = model.eval()\n",
199 | "\n",
200 | " images, target = next(iter(valid_loader))\n",
201 | " images = images.to(device)\n",
202 | " target = target.to(device)\n",
203 | "\n",
204 | " output = model(images)\n",
205 | " valid_loss = torch.nn.functional.mse_loss(output, target)\n",
206 | "\n",
207 | " print(\"Epoch: {}, Validation Loss: {}\".format(epoch, valid_loss.item()))\n",
208 | " \n",
209 | " if valid_loss < min_loss:\n",
210 | " print(\"Saving model\")\n",
211 | " torch.save(model, '../models/{}.pth'.format(name))\n",
212 | "\n",
213 | " min_loss = valid_loss"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {},
219 | "source": [
220 | "# Model Evaluation"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 9,
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "def model_evaluation(loaders,labels,save_path = None):\n",
230 | "\n",
231 | " \"\"\"Evaluate direction models with mse and scatter plots\n",
232 | " loaders: list of data loaders\n",
233 | " labels: list of labels for plot title\"\"\"\n",
234 | "\n",
235 | " n = len(loaders)\n",
236 | " fig, axs = plt.subplots(1, n, figsize=(7*n, 6))\n",
237 | " fig.patch.set_facecolor('xkcd:white')\n",
238 | "\n",
239 | " # Evalution metrics\n",
240 | " for i, loader in enumerate(loaders):\n",
241 | "\n",
242 | " # Load all data\n",
243 | " images, target = next(iter(loader))\n",
244 | " images = images.to(device)\n",
245 | " target = target.to(device)\n",
246 | "\n",
247 | " output=model(images)\n",
248 | "\n",
249 | " # Get x predictions\n",
250 | " x_pred=output.detach().cpu().numpy()[:,0]\n",
251 | " x_target=target.cpu().numpy()[:,0]\n",
252 | "\n",
253 | " # Calculate MSE\n",
254 | " mse = mean_squared_error(x_target, x_pred)\n",
255 | "\n",
256 | " # Plot predcitons\n",
257 | " axs[i].scatter(x_target,x_pred)\n",
258 | " axs[i].plot([-1, 1], \n",
259 | " [-1, 1], \n",
260 | " color='r', \n",
261 | " linestyle='-', \n",
262 | " linewidth=2)\n",
263 | "\n",
264 | " axs[i].set_ylabel('Predicted x', size =15)\n",
265 | " axs[i].set_xlabel('Actual x', size =15)\n",
266 | " axs[i].set_title(\"{0} MSE: {1:.4f}\".format(labels[i], mse),size = 18)\n",
267 | "\n",
268 | " if save_path != None:\n",
269 | " fig.savefig(save_path)\n"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": null,
275 | "metadata": {},
276 | "outputs": [],
277 | "source": [
278 | "# Load saved model \n",
279 | "model = torch.load('../models/direction_model_1.pth')\n",
280 | "model.eval()\n",
281 | "model.to(device)\n",
282 | "\n",
283 | "# Create new loader for all data\n",
284 | "train_loader = DataLoader(train_data, batch_size=train_data.__len__())\n",
285 | "\n",
286 | "# Evaluate model on training and validation set\n",
287 | "loaders = [train_loader,valid_loader]\n",
288 | "labels = [\"Train\",\"Validation\"]\n",
289 | "\n",
290 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/evaluation_1.png\"\n",
291 | "model_evaluation(loaders,labels,save_path=path)"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": null,
297 | "metadata": {},
298 | "outputs": [],
299 | "source": [
300 | "# Evaluate on data for additonal rooms\n",
301 | "room_2 = glob.glob('../data/room_2/*')\n",
302 | "room_3 = glob.glob('../data/room_3/*')\n",
303 | "\n",
304 | "room_2_data = ImageDataset(room_2, TRANSFORMS)\n",
305 | "room_3_data = ImageDataset(room_3, TRANSFORMS)\n",
306 | "\n",
307 | "room_2_loader = DataLoader(room_2_data, batch_size=room_2_data.__len__())\n",
308 | "room_3_loader = DataLoader(room_3_data, batch_size=room_3_data.__len__())\n",
309 | "\n",
310 | "# Evaluate model on training and validation set\n",
311 | "loaders = [room_2_loader ,room_3_loader]\n",
312 | "labels = [\"Room 2\",\"Room 3\"]\n",
313 | "\n",
314 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/evaluation_2.png\"\n",
315 | "model_evaluation(loaders,labels, save_path=path)"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "# Load saved model \n",
325 | "model = torch.load('../models/direction_model_2.pth')\n",
326 | "\n",
327 | "model.eval()\n",
328 | "model.to(device)\n",
329 | "\n",
330 | "# Evaluate model on training and validation set\n",
331 | "loaders = [room_2_loader ,room_3_loader]\n",
332 | "labels = [\"Room 2\",\"Room 3\"]\n",
333 | "\n",
334 | "path = \"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/evaluation_3.png\"\n",
335 | "model_evaluation(loaders,labels,save_path=path)"
336 | ]
337 | },
338 | {
339 | "attachments": {},
340 | "cell_type": "markdown",
341 | "metadata": {},
342 | "source": [
343 | "# SHAP Explainer "
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "execution_count": 7,
349 | "metadata": {},
350 | "outputs": [],
351 | "source": [
352 | "# Load saved model \n",
353 | "model = torch.load('../models/direction_model_1.pth') #change for different model\n",
354 | "model.eval()\n",
355 | "\n",
356 | "# Use CPU\n",
357 | "device = torch.device('cpu')\n",
358 | "model = model.to(device)"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 8,
364 | "metadata": {},
365 | "outputs": [],
366 | "source": [
367 | "#Load 100 images for background\n",
368 | "shap_loader = DataLoader(train_data, batch_size=100, shuffle=True)\n",
369 | "background, _ = next(iter(shap_loader))\n",
370 | "background = background.to(device)\n",
371 | "\n",
372 | "#Create SHAP explainer \n",
373 | "explainer = shap.DeepExplainer(model, background)"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": null,
379 | "metadata": {},
380 | "outputs": [],
381 | "source": [
382 | "# Load test images of right and left turn\n",
383 | "paths = glob.glob('../data/room_1/*')\n",
384 | "test_images = [Image.open(paths[0]), Image.open(paths[3])]\n",
385 | "test_images = np.array(test_images)\n",
386 | "\n",
387 | "test_input = [TRANSFORMS(img) for img in test_images]\n",
388 | "test_input = torch.stack(test_input).to(device)\n",
389 | "\n",
390 | "# Get SHAP values\n",
391 | "shap_values = explainer.shap_values(test_input)\n",
392 | "\n",
393 | "# Reshape shap values and images for plotting\n",
394 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n",
395 | "test_numpy = np.array([np.array(img) for img in test_images])\n",
396 | "\n",
397 | "shap.image_plot(shap_numpy, test_numpy,show=False)"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": null,
403 | "metadata": {},
404 | "outputs": [],
405 | "source": [
406 | "# Using gradient explainer\n",
407 | "explainer = shap.GradientExplainer(model, background)\n",
408 | "shap_values = explainer.shap_values(test_input)\n",
409 | "\n",
410 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n",
411 | "\n",
412 | "shap.image_plot(shap_numpy, test_numpy)"
413 | ]
414 | },
415 | {
416 | "cell_type": "code",
417 | "execution_count": null,
418 | "metadata": {},
419 | "outputs": [],
420 | "source": [
421 | "# Load model trained on room 1, 2 and 3\n",
422 | "model = torch.load('../models/direction_model_2.pth') #change for different model\n",
423 | "\n",
424 | "# Use CPU\n",
425 | "device = torch.device('cpu')\n",
426 | "model = model.to(device)\n",
427 | "\n",
428 | "#Load 100 images for background\n",
429 | "shap_loader = DataLoader(train_data, batch_size=100, shuffle=True)\n",
430 | "background, _ = next(iter(shap_loader))\n",
431 | "background = background.to(device)\n",
432 | "\n",
433 | "#Create SHAP explainer \n",
434 | "explainer = shap.DeepExplainer(model, background)\n",
435 | "\n",
436 | "# Load test images of right and left turn\n",
437 | "paths = glob.glob('../data/room_1/*')\n",
438 | "test_images = [Image.open(paths[0]), Image.open(paths[3])]\n",
439 | "test_images = np.array(test_images)\n",
440 | "\n",
441 | "# Transform images\n",
442 | "test_input = [TRANSFORMS(img) for img in test_images]\n",
443 | "test_input = torch.stack(test_input).to(device)\n",
444 | "\n",
445 | "# Get SHAP values\n",
446 | "shap_values = explainer.shap_values(test_input)\n",
447 | "\n",
448 | "# Reshape shap values and images for plotting\n",
449 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n",
450 | "test_numpy = np.array([np.array(img) for img in test_images])\n",
451 | "\n",
452 | "shap.image_plot(shap_numpy, test_numpy,show=False)\n",
453 | "plt.savefig(\"/Users/conorosullivan/Google Drive/My Drive/Medium/shap_imagedata/shap_plot_2.png\",facecolor='white',dpi=300,bbox_inches='tight')"
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "execution_count": null,
459 | "metadata": {},
460 | "outputs": [],
461 | "source": [
462 | "# Using Gradient Explainer\n",
463 | "e = shap.GradientExplainer(model, background)\n",
464 | "shap_values = e.shap_values(test_input)\n",
465 | "\n",
466 | "shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))\n",
467 | "test_numpy = np.array([np.array(img) for img in test_images])\n",
468 | "\n",
469 | "shap.image_plot(shap_numpy, test_numpy)"
470 | ]
471 | }
472 | ],
473 | "metadata": {
474 | "kernelspec": {
475 | "display_name": "pytorch",
476 | "language": "python",
477 | "name": "pytorch"
478 | },
479 | "language_info": {
480 | "codemirror_mode": {
481 | "name": "ipython",
482 | "version": 3
483 | },
484 | "file_extension": ".py",
485 | "mimetype": "text/x-python",
486 | "name": "python",
487 | "nbconvert_exporter": "python",
488 | "pygments_lexer": "ipython3",
489 | "version": "3.10.4 (main, Mar 31 2022, 03:37:37) [Clang 12.0.0 ]"
490 | },
491 | "orig_nbformat": 4,
492 | "vscode": {
493 | "interpreter": {
494 | "hash": "3c0d4fcf1a0a408688084e944cab5ef64e86c1ae9800e884f9b7a2ac0ee51db6"
495 | }
496 | }
497 | },
498 | "nbformat": 4,
499 | "nbformat_minor": 2
500 | }
501 |
--------------------------------------------------------------------------------
/src/kernel_vs_tree.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Kernel SHAP vs Tree SHAP\n",
8 | "Experiments to understand the time complexity of SHAP approximations"
9 | ]
10 | },
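  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For context: TreeSHAP computes exact SHAP values for tree ensembles in roughly $O(TLD^2)$ time per explained instance ($T$ trees, $L$ leaves per tree, $D$ maximum depth), whereas KernelSHAP estimates them by re-evaluating the model on many sampled feature coalitions, so its cost grows with the number of instances explained, the number of features and the size of the background data. The experiments below compare the two empirically."
   ]
  },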
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "#imports\n",
18 | "import pandas as pd\n",
19 | "import numpy as np\n",
20 | "import matplotlib.pyplot as plt\n",
21 | "\n",
22 | "#import xgboost as xgb\n",
23 | "from sklearn.ensemble import RandomForestRegressor\n",
24 | "import sklearn.datasets as ds\n",
25 | "\n",
26 | "import datetime\n",
27 | "\n",
28 | "import shap\n",
29 | "shap.initjs()"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "# Functions\n",
39 | "def runSHAP(n,kernel=True): \n",
40 | " \"\"\"\n",
41 | " Calculate shap values and return time taken\n",
42 | " n: number of SHAP values to calculate\n",
43 | " kernel: set False if using TreeSHAP \n",
44 | " \"\"\"\n",
45 | " \n",
46 | " x_sample = X[np.random.choice(X.shape[0], n, replace=True)]\n",
47 | " \n",
48 | " begin = datetime.datetime.now()\n",
49 | " if kernel:\n",
50 | " #Caculate SHAP values using KernelSHAP\n",
51 | " shap_values = kernelSHAP.shap_values(x_sample,l1_reg=False)\n",
52 | " time = datetime.datetime.now() - begin\n",
53 | " print(\"Kernel {}: \".format(n), time)\n",
54 | " else:\n",
55 | " #Caculate SHAP values using TreeSHAP\n",
56 | " shap_values = treeSHAP(x_sample)\n",
57 | " time = datetime.datetime.now() - begin\n",
58 | " print(\"Tree {}: \".format(n), time)\n",
59 | " \n",
60 | " return time\n",
61 | "\n",
62 | "def model_properties(model):\n",
63 | " \"\"\"Returns average depth and number of features and leaves of a random forest\"\"\"\n",
64 | " \n",
65 | " depths = []\n",
66 | " features = []\n",
67 | " leaves = []\n",
68 | " \n",
69 | " for tree in model.estimators_:\n",
70 | " depths.append(tree.get_depth())\n",
71 | " leaves.append(tree.get_n_leaves())\n",
72 | " n_feat = len(set(tree.tree_.feature)) -1 \n",
73 | " features.append(n_feat)\n",
74 | " \n",
75 | " return np.mean(depths), np.mean(features), np.mean(leaves)"
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {},
81 | "source": [
82 | "## Experiment 1: Number of samples"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 3,
88 | "metadata": {},
89 | "outputs": [],
90 | "source": [
91 | "#Simulate regression data\n",
92 | "data = ds.make_regression(n_samples=10000, n_features=10, n_informative=8, n_targets=1)\n",
93 | "\n",
94 | "y= data[1]\n",
95 | "X = data[0]\n",
96 | "\n",
97 | "feature_names = range(len(X))"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "#Train model\n",
107 | "model = RandomForestRegressor(n_estimators=100,max_depth=4,random_state=0)\n",
108 | "model.fit(X, y)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 5,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "#Get shap estimators\n",
118 | "kernelSHAP = shap.KernelExplainer(model.predict,shap.sample(X, 10))\n",
119 | "treeSHAP = shap.TreeExplainer(model)"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "results = []\n",
129 | "for n in [10,100,1000,2000,5000,10000]*3:\n",
130 | " #Calculate SHAP Values\n",
131 | " kernel_time = runSHAP(n=n)\n",
132 | " tree_time = runSHAP(n=n,kernel=False)\n",
133 | " \n",
134 | " result = [n,kernel_time,tree_time]\n",
135 | " results.append(result)\n",
136 | " \n",
137 | "results_1 = pd.DataFrame(results,columns = ['n','kernelSHAP','treeSHAP'])"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "avg_1 = results_1.groupby(by='n',as_index=False).mean()\n",
147 | "avg_1"
148 | ]
149 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "#Find average run time\n",
166 | "avg_1 = results_1.groupby(by='n',as_index=False).mean()\n",
167 | "\n",
168 | "k_sec = [t.total_seconds() for t in avg_1['kernelSHAP']]\n",
169 | "t_sec = [t.total_seconds() for t in avg_1['treeSHAP']]\n",
170 | "n = avg_1['n']\n",
171 | "\n",
172 | "#Proportional run time\n",
173 | "print((k_sec/n)/(t_sec/n))\n",
174 | "\n",
175 | "#Plot run time by number of observations\n",
176 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,6))\n",
177 | "\n",
178 | "plt.plot(n, k_sec, linestyle='-', linewidth=2,marker='o',label = 'KernelSHAP')\n",
179 | "plt.plot(n, t_sec, linestyle='-', linewidth=2,marker='o',label = 'TreeSHAP')\n",
180 | "\n",
181 | "plt.ylabel('Time (seconds)',size=20)\n",
182 | "plt.xlabel('Number of observations',size=20)\n",
183 | "plt.legend(fontsize=15)"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": null,
189 | "metadata": {},
190 | "outputs": [],
191 | "source": [
192 | "#Number of observations\n",
193 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,6))\n",
194 | "\n",
195 | "plt.plot(n, t_sec, linestyle='-', color='#F87F0E',linewidth=2,marker='o',label = 'TreeSHAP')\n",
196 | "\n",
197 | "plt.ylabel('Time (seconds)',size=20)\n",
198 | "plt.xlabel('Number of observations',size=20)\n",
199 | "plt.legend(fontsize=15)"
200 | ]
201 | },
202 | {
203 | "cell_type": "markdown",
204 | "metadata": {},
205 | "source": [
206 | "## Experiment 2: number of features\n",
207 | " "
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": null,
213 | "metadata": {
214 | "scrolled": true
215 | },
216 | "outputs": [],
217 | "source": [
218 | "results = []\n",
219 | "\n",
220 | "for n_features, n_informative in zip([2,4,6,8,10,12,13,14,16,18,20]*3,[2,4,6,8,10,12,13,14,16,18,20]*3):\n",
221 | " \n",
222 | " #Simulate regression data\n",
223 | " data = ds.make_regression(n_samples=10000, n_features=n_features, n_informative=n_informative, n_targets=1,noise=0.1)\n",
224 | "\n",
225 | " y= data[1]\n",
226 | " X = data[0]\n",
227 | "\n",
228 | " feature_names = range(len(X))\n",
229 | "\n",
230 | " #Train model\n",
231 | " model = RandomForestRegressor(n_estimators=100,max_depth=10,random_state=0)\n",
232 | " model.fit(X, y)\n",
233 | " \n",
234 | " #get model properties\n",
235 | " avg_depth, avg_feat, avg_leaves = model_properties(model)\n",
236 | " \n",
237 | " #Get shap estimators\n",
238 | " kernelSHAP = shap.KernelExplainer(model.predict,shap.sample(X, 10))\n",
239 | " treeSHAP = shap.TreeExplainer(model)\n",
240 | " \n",
241 | " #Calculate SHAP values\n",
242 | " kernel_time = runSHAP(n=100)\n",
243 | " tree_time = runSHAP(n=100,kernel=False)\n",
244 | " \n",
245 | " result = [n_features, avg_depth, avg_feat, avg_leaves, kernel_time,tree_time]\n",
246 | " results.append(result)\n",
247 | "\n",
248 | "results_2 = pd.DataFrame(results,columns = ['n_features','avg_depth', 'avg_feat', 'avg_leaves','kernelSHAP','treeSHAP'])\n"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "#Get average run time\n",
258 | "avg_2 = results_2[['n_features','kernelSHAP','treeSHAP']].groupby(by='n_features',as_index=False).mean()\n",
259 | "\n",
260 | "k_sec = [t.total_seconds() for t in avg_2['kernelSHAP']]\n",
261 | "t_sec = [t.total_seconds() for t in avg_2['treeSHAP']]\n",
262 | "n = avg_2['n_features']\n",
263 | "\n",
264 | "print((k_sec/n)/(t_sec/n))\n",
265 | "\n",
266 | "#Plot run time by number of features\n",
267 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,6))\n",
268 | "\n",
269 | "plt.plot(n, k_sec, linestyle='-', linewidth=2,marker='o',label = 'KernelSHAP')\n",
270 | "plt.plot(n, t_sec, linestyle='-', linewidth=2,marker='o',label = 'TreeSHAP')\n",
271 | "\n",
272 | "plt.ylabel('Time (seconds)',size=20)\n",
273 | "plt.xlabel('Number of features',size=20)\n",
274 | "plt.legend(fontsize=15)"
275 | ]
276 | },
277 | {
278 | "cell_type": "markdown",
279 | "metadata": {},
280 | "source": [
281 | "## Experiment 3: number of trees"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "metadata": {},
288 | "outputs": [],
289 | "source": [
290 | "#Simulate regression data\n",
291 | "data = ds.make_regression(n_samples=10000, n_features=10, n_informative=8, n_targets=1)\n",
292 | "\n",
293 | "y= data[1]\n",
294 | "X = data[0]\n",
295 | "\n",
296 | "feature_names = range(len(X))"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": null,
302 | "metadata": {},
303 | "outputs": [],
304 | "source": [
305 | "results = []\n",
306 | "\n",
307 | "for trees in [10,20,50,100,200,500,1000]*3:\n",
308 | " #Train model\n",
309 | " model = RandomForestRegressor(n_estimators=trees,max_depth=4,random_state=0)\n",
310 | " model.fit(X, y)\n",
311 | " \n",
312 | " #Get shap estimators\n",
313 | " kernelSHAP = shap.KernelExplainer(model.predict,shap.sample(X, 10))\n",
314 | " treeSHAP = shap.TreeExplainer(model)\n",
315 | " \n",
316 | " #Calculate SHAP Values\n",
317 | " kernel_time = runSHAP(n=100)\n",
318 | " tree_time = runSHAP(n=100,kernel=False)\n",
319 | " \n",
320 | " result = [trees,kernel_time,tree_time]\n",
321 | " results.append(result)\n",
322 | "\n",
323 | "results_3 = pd.DataFrame(results,columns = ['trees','kernelSHAP','treeSHAP'])"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": null,
329 | "metadata": {},
330 | "outputs": [],
331 | "source": [
332 | "#Get average run time\n",
333 | "avg_3 = results_3.groupby(by='trees',as_index=False).mean()\n",
334 | "\n",
335 | "k_sec = [t.total_seconds() for t in avg_3['kernelSHAP']]\n",
336 | "t_sec = [t.total_seconds() for t in avg_3['treeSHAP']]\n",
337 | "trees = avg_3['trees']\n",
338 | "\n",
339 | "print((k_sec/trees)/(t_sec/trees))\n",
340 | "\n",
341 | "#Plot run time by number of trees\n",
342 | "fig, ax = plt.subplots(nrows=1, ncols=2,figsize=(20,10))\n",
343 | "\n",
344 | "ax[0].plot(trees, k_sec, linestyle='-', linewidth=2,marker='o',label = 'KernelSHAP')\n",
345 | "ax[0].set_ylabel('Time (seconds)',size=20)\n",
346 | "ax[0].set_xlabel('Number of trees',size=20)\n",
347 | "ax[0].legend(fontsize=15)\n",
348 | "\n",
349 | "ax[1].plot(trees, t_sec, color='#F87F0E', linewidth=2,marker='o',label = 'TreeSHAP')\n",
350 | "ax[1].set_ylabel('Time (seconds)',size=20)\n",
351 | "ax[1].set_xlabel('Number of trees',size=20)\n",
352 | "ax[1].legend(fontsize=15)"
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "metadata": {},
358 | "source": [
359 | "## Experiment 4: tree depth"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": null,
365 | "metadata": {
366 | "scrolled": true
367 | },
368 | "outputs": [],
369 | "source": [
370 | "#Simulate regression data\n",
371 | "data = ds.make_regression(n_samples=10000, n_features=10, n_informative=8, n_targets=1)\n",
372 | "\n",
373 | "y= data[1]\n",
374 | "X = data[0]\n",
375 | "\n",
376 | "feature_names = range(len(X))\n",
377 | "\n",
378 | "results = []\n",
379 | "\n",
380 | "#for depth in [2,4,6]:\n",
381 | "for depth in [2,4,6,8,10,15,20]*3:\n",
382 | "\n",
383 | " #Train model\n",
384 | " model = RandomForestRegressor(n_estimators=100,max_depth=depth,random_state=0)\n",
385 | " model.fit(X, y)\n",
386 | " \n",
387 | " #get model properties\n",
388 | " avg_depth, avg_feat, avg_leaves = model_properties(model)\n",
389 | " \n",
390 | " #Get shap estimators\n",
391 | " kernelSHAP = shap.KernelExplainer(model.predict,shap.sample(X, 10))\n",
392 | " treeSHAP = shap.TreeExplainer(model)\n",
393 | " \n",
394 | " #Calculate SHAP values\n",
395 | " kernel_time = runSHAP(n=100)\n",
396 | " tree_time = runSHAP(n=100,kernel=False)\n",
397 | " \n",
398 | " result = [depth, avg_depth, avg_feat, avg_leaves, kernel_time,tree_time]\n",
399 | " results.append(result)\n",
400 | "\n",
401 | "results_4 = pd.DataFrame(results,columns = ['depth','avg_depth', 'avg_feat', 'avg_leaves','kernelSHAP','treeSHAP'])"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": null,
407 | "metadata": {},
408 | "outputs": [],
409 | "source": [
410 | "#Get average run time\n",
411 | "avg_4 = results_4[['depth','kernelSHAP','treeSHAP']].groupby(by='depth',as_index=False).mean()\n",
412 | "\n",
413 | "k_sec = [t.total_seconds() for t in avg_4['kernelSHAP']]\n",
414 | "t_sec = [t.total_seconds() for t in avg_4['treeSHAP']]\n",
415 | "depth = avg_4['depth']\n",
416 | "\n",
417 | "#Plot run tume by tree depth\n",
418 | "fig, ax = plt.subplots(nrows=1, ncols=1,figsize=(8,6))\n",
419 | "\n",
420 | "plt.plot(depth, k_sec, linestyle='-', linewidth=2,marker='o',label = 'KernelSHAP')\n",
421 | "plt.plot(depth, t_sec, linestyle='-', linewidth=2,marker='o',label = 'TreeSHAP')\n",
422 | "plt.legend(fontsize=15)\n",
423 | "\n",
424 | "plt.ylabel('Time (seconds)',size=20)\n",
425 | "plt.xlabel('Tree depth',size=20)"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": null,
431 | "metadata": {},
432 | "outputs": [],
433 | "source": [
434 | "#Other factors\n",
435 | "r4 = results_4[['depth','avg_depth','avg_feat','avg_leaves']].groupby(by='depth',as_index=False).mean()\n",
436 | "\n",
437 | "fig, ax = plt.subplots(nrows=1, ncols=2,figsize=(20,10))\n",
438 | "\n",
439 | "ax[0].plot(r4['depth'], r4['avg_feat'], linestyle='-', linewidth=2,marker='o')\n",
440 | "ax[0].set_ylabel('Average features',size=20)\n",
441 | "ax[0].set_xlabel('Tree depth',size=20)\n",
442 | "\n",
443 | "ax[1].plot(r4['depth'], r4['avg_leaves'], color='#F87F0E', linewidth=2,marker='o')\n",
444 | "ax[1].set_ylabel('Average leaves',size=20)\n",
445 | "ax[1].set_xlabel('Tree depth',size=20)"
446 | ]
447 | },
448 | {
449 | "cell_type": "markdown",
450 | "metadata": {},
451 | "source": [
452 | "# Archive "
453 | ]
454 | },
455 | {
456 | "cell_type": "code",
457 | "execution_count": null,
458 | "metadata": {},
459 | "outputs": [],
460 | "source": [
461 | "#\n",
462 | "data = ds.make_regression(n_samples=10000, n_features=10, n_informative=8, n_targets=1)\n",
463 | "\n",
464 | "y= data[1]\n",
465 | "X = data[0]\n",
466 | "\n",
467 | "feature_names = range(len(X))\n",
468 | "\n",
469 | "depth = 10 # vary this value \n",
470 | "model = RandomForestRegressor(n_estimators=100,max_depth=depth,random_state=0)\n",
471 | "model.fit(X, y)\n",
472 | "\n",
473 | "model_properties(model)"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": null,
479 | "metadata": {},
480 | "outputs": [],
481 | "source": [
482 | "#Simulate regression data\n",
483 | "data = ds.make_regression(n_samples=10000, n_features=20, n_informative=20, n_targets=1,noise=0.1)\n",
484 | "\n",
485 | "y= data[1]\n",
486 | "X = data[0]\n",
487 | "\n",
488 | "feature_names = range(len(X))\n",
489 | "\n",
490 | "#Train model\n",
491 | "model = RandomForestRegressor(n_estimators=100,max_depth=10,random_state=0)\n",
492 | "model.fit(X, y)\n",
493 | "\n",
494 | "#get model properties\n",
495 | "avg_depth, avg_feat, avg_leaves = model_properties(model)\n",
496 | "\n",
497 | "\n",
498 | "#Get shap estimators\n",
499 | "treeSHAP = shap.TreeExplainer(model)\n",
500 | "kernelSHAP = shap.KernelExplainer(model.predict,shap.sample(X, 20))\n",
501 | "\n",
502 | "#get shap values \n",
503 | "x_sample = X[np.random.choice(X.shape[0], 100, replace=True)]\n",
504 | "sv_tree = treeSHAP.shap_values(x_sample)\n",
505 | "sv_kernel = kernelSHAP.shap_values(x_sample,l1_reg=0.1)\n",
506 | "\n",
507 | "print(len(sv_tree[0]),len(sv_kernel[0]))"
508 | ]
509 | }
510 | ],
511 | "metadata": {
512 | "kernelspec": {
513 | "display_name": "SHAP",
514 | "language": "python",
515 | "name": "shap"
516 | },
517 | "language_info": {
518 | "codemirror_mode": {
519 | "name": "ipython",
520 | "version": 3
521 | },
522 | "file_extension": ".py",
523 | "mimetype": "text/x-python",
524 | "name": "python",
525 | "nbconvert_exporter": "python",
526 | "pygments_lexer": "ipython3",
527 | "version": "3.10.6"
528 | }
529 | },
530 | "nbformat": 4,
531 | "nbformat_minor": 2
532 | }
533 |
--------------------------------------------------------------------------------
/data/interaction_dataset.csv:
--------------------------------------------------------------------------------
1 | "experience" "degree" "performance" "sales" "days_late" "bonus"
2 | 31 1 6.11 29 14 197
3 | 35 1 9.55 44 8 314
4 | 9 1 2.64 26 20 88
5 | 40 1 0.22 13 7 233
6 | 18 1 6.46 11 13 108
7 | 0 1 4.57 72 0 101
8 | 31 0 3.38 83 9 61
9 | 14 1 5.63 27 9 112
10 | 37 1 0.03 13 7 183
11 | 16 1 1.1 7 16 92
12 | 36 1 5.3 18 18 225
13 | 11 1 5.35 17 18 107
14 | 34 1 3.88 53 4 234
15 | 28 1 1.52 58 16 161
16 | 5 1 2.96 33 7 52
17 | 13 1 8.01 23 17 129
18 | 1 0 1.02 71 17 17
19 | 29 1 7.08 99 1 346
20 | 17 1 9.15 61 0 227
21 | 2 1 4.43 14 5 37
22 | 20 1 7.28 82 2 263
23 | 34 1 2.45 97 6 260
24 | 16 1 6.03 64 2 182
25 | 39 1 4.9 5 3 235
26 | 18 1 4.87 87 9 203
27 | 8 1 7.12 28 6 110
28 | 24 0 3.37 16 0 14
29 | 33 0 9.76 52 6 134
30 | 38 1 7.55 46 11 278
31 | 24 0 1.9 61 20 22
32 | 27 1 1.55 57 15 152
33 | 37 0 6.24 42 3 88
34 | 37 1 8.07 30 17 243
35 | 29 0 0.1 14 4 0
36 | 11 1 0.18 6 19 62
37 | 37 1 4.6 33 5 259
38 | 8 1 9.11 11 16 78
39 | 21 0 9.87 81 14 196
40 | 23 0 9.43 99 3 213
41 | 31 1 1.01 51 0 193
42 | 21 1 5.1 90 19 225
43 | 7 1 9.08 77 11 220
44 | 26 1 5.97 68 15 250
45 | 10 1 1.98 24 20 49
46 | 33 1 8.98 58 5 310
47 | 8 1 1.85 13 6 72
48 | 3 0 9.95 42 1 107
49 | 10 1 7.01 9 19 75
50 | 5 1 9.83 8 13 85
51 | 32 1 1.99 8 16 181
52 | 1 1 1.41 25 4 38
53 | 21 1 9.16 35 10 221
54 | 1 0 2.35 80 19 45
55 | 36 0 9.76 6 9 10
56 | 10 0 8.26 52 9 122
57 | 10 1 9.4 4 12 90
58 | 30 1 5.4 79 8 269
59 | 31 1 6.06 93 6 297
60 | 18 1 1.08 23 16 100
61 | 20 0 7.94 13 12 42
62 | 40 1 4 35 12 259
63 | 36 1 3.5 93 19 285
64 | 13 1 3.69 3 17 66
65 | 11 0 7.29 91 11 159
66 | 16 1 4.61 61 6 178
67 | 9 1 7.62 90 9 230
68 | 4 1 4.63 31 1 90
69 | 25 0 2.39 9 8 0
70 | 26 0 3.25 3 6 11
71 | 20 1 7.69 49 18 189
72 | 32 1 8.42 36 14 266
73 | 36 0 6.82 24 7 49
74 | 38 1 1.98 90 15 258
75 | 21 1 2.27 37 14 132
76 | 28 1 5.12 32 2 214
77 | 29 1 6.41 10 14 190
78 | 40 0 1.46 35 0 3
79 | 23 1 8.28 50 15 211
80 | 24 0 6.03 76 1 114
81 | 0 0 2.61 90 11 51
82 | 31 0 1.62 52 5 13
83 | 36 0 8.29 3 6 6
84 | 19 0 3.75 53 11 51
85 | 33 1 1.94 56 5 235
86 | 15 1 9.72 96 5 323
87 | 14 1 3.2 18 2 109
88 | 23 1 2.47 57 20 156
89 | 38 0 4.73 40 4 59
90 | 30 1 6.03 66 0 263
91 | 22 1 4.55 34 0 182
92 | 29 1 6.04 75 10 291
93 | 6 1 0.43 30 4 42
94 | 33 1 1.86 68 1 221
95 | 37 0 8.79 39 9 70
96 | 0 1 5.5 26 5 53
97 | 10 1 1.88 87 19 97
98 | 33 0 4.14 31 10 12
99 | 16 0 2.22 26 11 4
100 | 25 1 5.14 39 20 203
101 | 36 1 1.8 16 18 211
102 | 18 1 2.24 9 1 135
103 | 24 0 0.95 90 4 52
104 | 5 1 6.63 95 13 192
105 | 7 0 1.09 56 19 1
106 | 21 1 6.42 99 15 283
107 | 3 1 3.09 49 16 84
108 | 26 1 5.47 19 7 159
109 | 38 1 1.86 54 0 225
110 | 39 0 3.89 91 0 96
111 | 28 1 1.43 86 5 199
112 | 27 1 8.35 59 16 257
113 | 13 1 4.02 12 17 102
114 | 5 1 2.78 7 19 26
115 | 10 1 0.63 90 20 65
116 | 33 1 2.73 29 0 198
117 | 32 1 8.67 99 13 386
118 | 1 1 8.68 18 6 86
119 | 23 1 5.75 68 15 214
120 | 22 1 4.16 3 13 109
121 | 33 1 4.59 82 2 284
122 | 11 1 0.7 67 9 71
123 | 35 0 1.17 50 15 0
124 | 7 1 4.84 61 15 110
125 | 4 1 1.79 57 10 48
126 | 13 1 7.69 55 19 179
127 | 15 1 4.63 100 6 217
128 | 27 1 9.6 77 3 348
129 | 37 1 8.02 54 16 284
130 | 4 1 9.6 15 12 90
131 | 11 1 0.33 5 19 40
132 | 3 1 2.98 21 19 47
133 | 9 1 3.68 93 13 146
134 | 23 1 9.36 37 16 228
135 | 26 1 1.35 65 16 174
136 | 2 1 7.83 60 16 133
137 | 18 1 5.13 69 5 180
138 | 21 1 4.81 47 1 193
139 | 6 1 8.18 93 0 253
140 | 23 1 0.64 62 16 154
141 | 35 1 5.31 20 9 236
142 | 17 1 9.63 72 13 248
143 | 12 1 0.61 29 19 68
144 | 35 1 3.56 59 16 247
145 | 30 0 5.98 92 7 145
146 | 7 1 1.8 38 2 73
147 | 16 1 1.7 60 5 149
148 | 30 0 8.69 12 18 28
149 | 14 1 8.87 10 6 141
150 | 34 1 7.31 17 13 211
151 | 11 0 9.67 14 12 41
152 | 19 0 9.81 33 4 91
153 | 20 1 1.1 87 16 134
154 | 17 1 2.43 54 8 130
155 | 22 1 5.11 39 16 162
156 | 16 1 5.69 60 16 191
157 | 27 0 2.13 42 0 26
158 | 31 0 6.35 98 19 149
159 | 36 0 3.09 61 7 41
160 | 19 1 3.8 36 8 131
161 | 22 0 6.6 7 14 2
162 | 7 1 4.1 88 0 169
163 | 4 1 1.26 83 14 69
164 | 20 1 8.97 73 12 256
165 | 11 0 8.85 11 0 53
166 | 16 1 2.45 16 8 110
167 | 4 1 2.31 54 10 58
168 | 6 1 4.07 44 5 81
169 | 31 1 4.83 18 13 199
170 | 32 1 7.19 39 8 265
171 | 33 1 9.45 4 5 213
172 | 20 0 9.1 61 1 123
173 | 0 1 6.69 100 15 173
174 | 33 0 1.82 44 9 2
175 | 2 1 5.19 91 17 135
176 | 26 0 3.02 35 8 9
177 | 12 0 9.26 62 9 151
178 | 27 1 0.08 30 19 127
179 | 11 1 2.03 100 8 121
180 | 8 1 3.8 79 11 146
181 | 25 1 0.12 39 15 131
182 | 31 1 5.54 92 18 288
183 | 0 1 5.35 44 7 73
184 | 21 1 7.61 53 0 250
185 | 19 1 6.33 59 5 213
186 | 32 0 7.5 6 3 34
187 | 39 0 5.81 47 11 53
188 | 21 1 7.51 38 19 197
189 | 7 1 1.1 61 19 56
190 | 13 1 3.53 71 16 137
191 | 13 1 7.14 57 0 209
192 | 32 1 1.16 58 14 210
193 | 28 1 9.33 56 8 269
194 | 12 1 5.77 35 3 150
195 | 34 1 4.6 63 6 247
196 | 31 0 3.41 48 4 41
197 | 20 1 7 75 14 243
198 | 35 1 6.28 27 9 244
199 | 38 1 6.78 10 7 240
200 | 39 0 5.95 16 1 32
201 | 22 1 5.25 70 18 222
202 | 9 1 7.16 27 5 128
203 | 8 1 4.68 52 7 102
204 | 10 1 3.92 45 9 99
205 | 29 1 6.63 99 11 338
206 | 12 1 2.97 55 5 114
207 | 14 1 1.16 35 18 99
208 | 24 0 5.85 93 8 119
209 | 22 1 1.52 11 18 132
210 | 36 1 4.45 10 18 186
211 | 8 0 0.65 27 13 0
212 | 9 1 2.17 94 13 111
213 | 37 1 1.44 41 6 240
214 | 14 1 5.24 1 1 109
215 | 2 1 0.97 60 12 28
216 | 24 1 4.17 89 0 232
217 | 7 1 8.69 70 4 187
218 | 1 0 5.83 35 13 48
219 | 1 1 5.91 29 9 58
220 | 24 1 8.64 94 2 339
221 | 24 1 8.93 43 17 246
222 | 20 1 4.55 11 7 127
223 | 5 1 3.29 75 17 89
224 | 23 1 7.99 13 15 149
225 | 23 0 7.52 23 12 57
226 | 18 1 3.27 2 4 112
227 | 37 0 1.06 80 10 24
228 | 5 1 6.36 6 19 61
229 | 3 0 3.99 74 2 92
230 | 5 1 2.31 8 0 72
231 | 34 1 7.45 53 4 292
232 | 21 1 9.25 83 3 316
233 | 8 1 1.04 9 10 67
234 | 35 0 4.3 37 0 42
235 | 38 0 2.97 58 16 40
236 | 28 1 5.47 27 19 169
237 | 24 1 3.67 7 8 163
238 | 11 1 2.68 61 4 108
239 | 26 0 5.22 82 13 115
240 | 40 1 7.92 33 12 279
241 | 36 0 0.76 87 17 5
242 | 25 1 7.47 54 17 244
243 | 27 0 9.04 13 14 34
244 | 18 1 1.8 98 19 145
245 | 13 1 1.09 19 20 71
246 | 8 1 9.89 25 8 129
247 | 31 0 0.58 91 4 20
248 | 14 1 1.41 5 19 70
249 | 17 1 8.95 88 19 281
250 | 17 1 5.96 53 17 173
251 | 11 1 4.63 42 5 121
252 | 28 0 8.44 98 5 189
253 | 4 0 9.54 45 8 111
254 | 37 1 0.79 41 5 217
255 | 39 0 8.65 27 11 52
256 | 23 1 1.82 50 4 151
257 | 12 1 0.22 6 18 61
258 | 38 0 3.45 99 8 92
259 | 28 1 9.74 90 3 361
260 | 40 0 6.29 24 10 28
261 | 14 1 3.34 55 12 141
262 | 16 1 1.27 27 0 130
263 | 14 1 5.34 58 18 161
264 | 23 1 5.01 31 4 190
265 | 37 1 6.75 3 5 229
266 | 27 1 2 74 12 195
267 | 30 1 1.4 26 17 151
268 | 15 1 7.61 4 0 104
269 | 19 0 0.14 30 9 7
270 | 39 1 6.57 74 0 337
271 | 27 1 7.8 34 1 244
272 | 27 1 2.31 11 1 178
273 | 23 0 7.74 69 8 137
274 | 23 0 6.15 62 13 86
275 | 2 1 7.14 91 10 189
276 | 37 1 9.4 62 3 335
277 | 4 1 1.86 59 19 46
278 | 2 0 8.85 51 18 108
279 | 13 1 2.58 61 12 128
280 | 34 0 5.9 5 8 28
281 | 17 0 5.89 5 17 16
282 | 20 1 6.14 33 4 160
283 | 26 1 0.04 5 11 133
284 | 25 1 8.71 9 7 192
285 | 25 1 0.17 45 5 133
286 | 19 1 6.56 20 6 165
287 | 32 1 5.25 4 8 170
288 | 33 1 1.78 92 20 227
289 | 0 1 9.19 12 20 46
290 | 13 1 3.24 8 7 90
291 | 6 1 9.74 97 0 297
292 | 36 1 6.6 61 8 295
293 | 15 1 9.05 45 14 187
294 | 28 1 1.25 41 4 183
295 | 7 1 0.95 53 9 84
296 | 6 1 3.51 12 2 50
297 | 29 1 4.15 83 18 226
298 | 30 1 9.46 44 6 270
299 | 36 1 7.39 51 8 297
300 | 9 1 2.58 91 20 106
301 | 20 0 2.54 93 11 66
302 | 26 0 7.77 39 4 72
303 | 37 1 4.39 47 13 232
304 | 2 1 1.16 50 5 46
305 | 11 1 5.96 22 4 102
306 | 40 1 2.72 94 5 303
307 | 26 1 7.52 65 0 273
308 | 11 0 8.07 15 17 35
309 | 36 0 4.77 11 12 15
310 | 9 0 6 58 17 89
311 | 40 1 4.66 20 20 238
312 | 35 1 7.77 74 3 327
313 | 29 1 3.68 11 13 172
314 | 17 1 7.37 93 6 255
315 | 18 0 2.9 22 6 4
316 | 31 1 9.02 100 20 386
317 | 26 1 4.34 16 2 183
318 | 32 1 5.13 47 18 244
319 | 10 1 5.48 92 9 206
320 | 25 0 9.11 68 7 134
321 | 34 0 5.22 11 8 0
322 | 38 0 6.66 67 6 124
323 | 3 0 6.36 7 18 0
324 | 10 1 9.17 21 3 146
325 | 29 0 9.33 3 1 15
326 | 28 1 2.64 46 13 185
327 | 10 1 4.1 49 12 114
328 | 10 0 5.46 98 11 126
329 | 22 1 4.07 100 8 237
330 | 7 1 0.25 27 9 41
331 | 22 1 7.75 37 10 207
332 | 24 1 2.8 12 7 138
333 | 38 1 7.12 4 14 233
334 | 18 1 8.4 7 12 123
335 | 32 0 4.34 10 1 33
336 | 10 1 0.47 21 9 83
337 | 5 1 2.87 12 20 35
338 | 40 1 1.21 61 10 252
339 | 31 1 6.12 81 6 278
340 | 1 1 9.8 63 20 143
341 | 11 1 3.42 83 13 136
342 | 37 0 0.86 45 0 2
343 | 22 0 1.73 66 14 19
344 | 32 1 9.73 59 3 334
345 | 25 0 2.95 81 7 78
346 | 29 1 0.65 56 17 174
347 | 14 0 0.82 48 10 2
348 | 14 1 5.17 90 5 202
349 | 23 1 0.54 30 8 149
350 | 29 0 0.8 49 18 9
351 | 27 1 0.3 30 20 145
352 | 35 1 8.11 79 5 340
353 | 17 1 9.21 78 5 273
354 | 5 1 9.7 100 18 278
355 | 1 1 7.53 49 2 125
356 | 16 1 9.3 47 19 197
357 | 39 0 1.33 20 17 0
358 | 10 1 6.74 13 12 88
359 | 10 1 8.18 55 17 170
360 | 18 0 8.44 37 12 85
361 | 2 0 5.22 27 12 21
362 | 7 0 1.15 2 4 0
363 | 0 0 5.14 13 2 5
364 | 26 0 2.55 85 18 34
365 | 5 0 9.59 20 7 60
366 | 13 1 0.8 34 0 100
367 | 22 1 0.57 89 15 145
368 | 30 1 8.27 22 9 200
369 | 16 1 5.93 16 17 110
370 | 23 1 9.43 19 15 174
371 | 28 1 1.17 63 3 191
372 | 4 0 4.02 69 5 64
373 | 27 1 9.6 54 8 287
374 | 27 0 6.49 97 12 128
375 | 6 0 9.95 94 3 225
376 | 25 1 2.34 1 14 143
377 | 9 1 5.04 40 3 129
378 | 14 1 2.77 81 11 135
379 | 22 0 2.87 1 18 3
380 | 23 1 9.75 81 4 316
381 | 21 1 2.62 77 4 182
382 | 37 1 1.24 55 5 244
383 | 32 1 4.64 37 13 208
384 | 29 0 9.87 31 2 66
385 | 22 1 4.26 72 7 211
386 | 9 1 8.4 16 1 108
387 | 27 1 0.02 22 9 134
388 | 13 1 4.62 62 19 139
389 | 1 1 4.44 68 12 114
390 | 15 1 2.25 43 14 122
391 | 20 1 7.93 28 15 186
392 | 34 1 7.37 13 4 217
393 | 25 0 4.81 99 0 129
394 | 34 1 3.84 24 2 235
395 | 27 1 1.11 46 12 148
396 | 16 0 0.5 74 4 16
397 | 38 1 4.41 98 20 318
398 | 14 1 5.41 79 10 177
399 | 16 0 0.43 61 19 3
400 | 27 0 1.22 75 11 11
401 | 9 1 9.4 74 12 207
402 | 26 0 2.59 0 12 0
403 | 23 1 3.41 45 17 166
404 | 18 0 6.48 99 6 158
405 | 22 1 9.16 90 2 334
406 | 21 0 3.91 74 2 70
407 | 4 1 5.56 90 19 168
408 | 40 1 8.98 25 18 276
409 | 2 1 7.48 52 5 143
410 | 34 1 8.89 6 8 230
411 | 12 1 5.53 97 2 215
412 | 15 1 9.23 54 10 199
413 | 33 1 8.95 65 2 343
414 | 37 1 1.2 17 8 193
415 | 0 1 7.26 27 6 72
416 | 7 1 0.24 96 8 58
417 | 36 1 2.26 0 16 182
418 | 8 0 8.03 69 12 125
419 | 21 1 2.45 72 0 172
420 | 40 0 9.96 53 0 146
421 | 34 1 1.82 74 14 239
422 | 29 1 5.88 80 2 304
423 | 6 0 3.03 96 14 72
424 | 7 1 5.96 34 3 118
425 | 15 1 2.59 32 0 113
426 | 2 0 8.9 92 18 197
427 | 3 1 5.9 90 18 150
428 | 7 0 3.98 38 8 36
429 | 32 0 0.25 69 2 30
430 | 3 1 8.06 78 16 186
431 | 30 0 9.58 59 11 114
432 | 28 0 3.52 3 9 0
433 | 6 0 8.29 9 3 44
434 | 15 1 0.52 56 3 118
435 | 33 0 9.56 63 10 161
436 | 0 0 4.6 17 12 23
437 | 36 1 2.13 65 8 243
438 | 36 1 7.05 88 9 339
439 | 6 1 1.39 95 12 76
440 | 40 1 9.31 2 7 241
441 | 2 1 0.61 5 19 5
442 | 38 0 5.6 41 2 70
443 | 25 1 4 61 14 181
444 | 36 1 5.36 20 20 199
445 | 6 1 9.47 32 12 106
446 | 40 1 3.43 92 16 302
447 | 18 0 1.56 76 12 22
448 | 34 1 5.37 40 3 264
449 | 1 0 2.69 19 2 35
450 | 20 0 4.43 62 15 42
451 | 12 1 6.95 22 13 114
452 | 6 1 3.22 9 1 51
453 | 38 1 5.32 27 18 246
454 | 7 1 6.47 30 6 123
455 | 10 1 8.69 43 20 143
456 | 24 1 2.6 15 9 142
457 | 20 1 9.37 84 8 312
458 | 26 1 8.78 56 18 269
459 | 7 1 1.42 11 3 63
460 | 25 1 5.71 74 9 247
461 | 21 1 1.68 12 9 132
462 | 32 0 9.69 31 15 75
463 | 8 1 4.39 38 9 97
464 | 15 1 4.6 4 10 95
465 | 13 1 2.88 92 4 162
466 | 28 1 9.21 6 16 194
467 | 2 1 5.99 44 18 90
468 | 26 0 0.08 38 20 0
469 | 9 1 0.79 34 20 57
470 | 27 1 1.26 40 3 188
471 | 22 1 2.37 27 7 160
472 | 18 0 2.03 71 6 25
473 | 19 1 1.77 64 19 123
474 | 30 1 0.65 66 9 176
475 | 26 0 8.47 78 18 149
476 | 39 1 0.72 54 13 233
477 | 35 0 0.51 82 7 12
478 | 4 1 1.8 13 15 30
479 | 34 1 3.29 21 11 215
480 | 11 1 0.4 78 9 98
481 | 27 0 8.94 27 20 35
482 | 17 1 1.11 69 6 139
483 | 18 1 7.4 15 15 116
484 | 6 0 6.11 46 10 77
485 | 37 0 3.38 35 1 29
486 | 16 1 6.5 52 8 195
487 | 34 1 6.72 93 10 328
488 | 36 1 5.07 96 10 328
489 | 0 0 4.72 63 1 86
490 | 5 1 8.63 1 10 62
491 | 15 0 0.33 64 17 15
492 | 1 1 2.83 67 6 85
493 | 38 0 4.45 52 18 67
494 | 20 1 3.43 1 15 109
495 | 39 0 4 54 16 56
496 | 11 0 5.34 41 10 61
497 | 26 1 8.62 94 2 351
498 | 30 1 0.02 11 20 150
499 | 30 1 6.75 78 15 305
500 | 18 1 9.88 36 16 186
501 | 21 0 8.63 58 20 129
502 | 28 1 5.38 36 4 224
503 | 15 1 4.72 41 6 134
504 | 12 1 6.66 19 8 108
505 | 17 1 7.11 0 20 85
506 | 7 1 7.51 8 0 78
507 | 39 1 2.77 71 14 253
508 | 4 1 1.09 99 15 88
509 | 26 0 9.94 68 0 178
510 | 39 1 5.84 22 2 237
511 | 13 1 5.81 29 18 106
512 | 4 1 3.28 100 3 132
513 | 3 1 9.24 44 9 121
514 | 24 1 6.75 60 10 225
515 | 30 1 5.94 68 3 267
516 | 7 1 0.62 71 20 63
517 | 31 0 4.43 34 9 45
518 | 21 1 4.57 85 5 239
519 | 23 0 5.59 100 0 123
520 | 17 0 4.93 77 2 116
521 | 8 1 1.36 74 13 99
522 | 10 1 5.24 95 12 195
523 | 7 1 0.14 91 11 73
524 | 17 0 4.57 6 3 0
525 | 26 1 3 38 12 162
526 | 35 1 9.03 84 5 364
527 | 5 1 6.04 67 12 150
528 | 10 1 0.53 35 17 78
529 | 38 1 1.95 91 8 241
530 | 11 1 8.16 93 3 270
531 | 2 0 6.9 68 4 136
532 | 33 1 8.55 12 15 218
533 | 40 1 7.55 63 8 333
534 | 7 1 8.75 20 5 97
535 | 26 1 5.03 81 12 253
536 | 30 1 3.1 54 12 221
537 | 6 0 2.58 7 9 0
538 | 4 1 6.38 34 1 117
539 | 37 1 5.25 80 7 327
540 | 9 0 5.37 71 16 104
541 | 33 0 2.91 78 13 54
542 | 4 1 7.8 17 11 85
543 | 0 1 0.24 72 17 33
544 | 19 1 5.74 26 17 136
545 | 33 1 3.9 51 18 214
546 | 3 1 1.91 23 9 25
547 | 14 0 0.56 32 8 0
548 | 13 1 6.85 91 7 239
549 | 19 1 8.5 87 8 305
550 | 37 1 2.11 24 6 236
551 | 2 1 7.55 75 5 181
552 | 17 1 2.56 42 11 135
553 | 18 0 1.28 9 5 7
554 | 17 1 4.24 24 13 108
555 | 30 1 2.24 9 20 141
556 | 11 1 5.42 97 0 216
557 | 25 1 6.38 84 18 268
558 | 26 1 3.94 15 5 176
559 | 19 0 1.92 1 0 1
560 | 21 1 0.58 52 17 112
561 | 31 1 3.76 69 8 233
562 | 8 1 9.65 2 12 52
563 | 21 0 6.47 94 2 163
564 | 19 1 9.21 65 8 248
565 | 1 0 5.47 66 4 74
566 | 5 0 4.75 93 3 109
567 | 13 1 0.06 7 6 74
568 | 20 1 1.75 82 5 161
569 | 30 1 2.28 63 20 215
570 | 0 1 2.65 47 5 58
571 | 9 1 1.26 76 18 74
572 | 33 1 5.83 92 1 334
573 | 2 1 1.44 18 19 27
574 | 29 1 8.44 19 14 196
575 | 7 1 4.9 90 20 133
576 | 21 1 8.72 42 0 213
577 | 30 1 7.7 75 3 328
578 | 6 1 0.12 5 6 62
579 | 40 1 3.9 91 2 329
580 | 33 1 2.03 70 2 210
581 | 2 0 9.26 17 20 42
582 | 12 1 1.29 33 10 92
583 | 5 0 0.24 50 20 1
584 | 22 1 1.43 73 7 171
585 | 35 1 5.36 75 18 270
586 | 4 1 0.45 51 7 54
587 | 31 1 5.69 100 0 330
588 | 27 1 3.6 88 10 243
589 | 6 1 1.44 50 7 89
590 | 31 1 8.99 44 19 266
591 | 0 1 6.42 24 14 72
592 | 35 1 6.44 1 16 174
593 | 10 1 6.74 93 17 218
594 | 33 1 8.54 85 3 347
595 | 4 0 9.43 85 4 201
596 | 7 1 9.1 12 7 83
597 | 5 0 4.51 84 16 86
598 | 28 1 5.96 21 3 184
599 | 29 1 5.12 93 16 262
600 | 10 1 3.17 13 2 77
601 | 1 1 2.06 97 8 100
602 | 4 1 6.74 57 17 110
603 | 24 0 5.62 51 2 68
604 | 0 0 2.84 34 4 14
605 | 11 0 8.69 49 0 131
606 | 25 1 1.05 18 4 156
607 | 19 1 6.13 88 0 233
608 | 34 0 9.47 34 7 71
609 | 25 1 7.13 0 1 159
610 | 6 1 0.79 21 7 47
611 | 26 1 1.39 51 0 191
612 | 1 0 4.44 50 7 72
613 | 31 1 1.28 20 17 174
614 | 6 1 4.77 73 10 143
615 | 20 1 3.06 9 8 110
616 | 19 1 4.21 27 0 154
617 | 37 0 5.52 31 7 47
618 | 22 1 6.28 16 19 154
619 | 17 1 2.66 48 12 123
620 | 26 1 6.67 28 0 188
621 | 19 1 2.23 66 5 147
622 | 25 1 1.11 9 12 143
623 | 37 1 1.66 2 20 188
624 | 15 1 7.5 9 1 105
625 | 0 1 7.18 59 8 116
626 | 24 0 4.14 53 9 72
627 | 22 1 5.43 37 16 167
628 | 20 1 6.61 22 18 161
629 | 0 1 1.44 9 19 8
630 | 33 0 2.07 41 6 13
631 | 18 1 2.89 41 20 128
632 | 1 1 2.28 57 3 47
633 | 8 1 8.02 100 15 257
634 | 23 1 2.99 34 17 160
635 | 38 1 1.95 32 0 231
636 | 21 1 7.43 35 0 179
637 | 35 1 0.46 27 18 198
638 | 34 1 1.71 82 12 228
639 | 8 0 7.7 77 8 143
640 | 28 1 9.95 44 12 280
641 | 23 1 8.81 26 10 197
642 | 7 1 7.79 72 12 176
643 | 23 1 9.36 38 3 225
644 | 32 0 4.8 80 6 103
645 | 25 1 7.47 93 1 332
646 | 18 1 4.93 14 9 144
647 | 34 1 3.52 76 7 272
648 | 19 1 1.55 40 3 136
649 | 21 1 2.21 30 16 145
650 | 39 1 8.94 83 18 396
651 | 13 0 0.93 12 4 20
652 | 22 1 4.26 50 8 167
653 | 22 0 7.54 91 14 151
654 | 27 1 5.94 96 0 303
655 | 17 0 3.01 47 17 7
656 | 6 0 2.19 92 7 40
657 | 1 1 4 8 6 41
658 | 29 0 0.65 24 12 0
659 | 14 1 2.79 29 8 103
660 | 27 1 2.42 51 13 180
661 | 11 1 7.28 52 0 160
662 | 19 0 3.88 98 11 78
663 | 8 1 9.4 87 1 272
664 | 34 1 6.43 14 1 216
665 | 9 1 0.18 13 17 58
666 | 24 1 1.29 25 20 146
667 | 15 1 0.22 38 0 92
668 | 24 1 2.43 74 7 193
669 | 29 1 3.79 64 13 206
670 | 3 1 8.5 77 3 196
671 | 24 0 8.58 25 10 68
672 | 40 1 0.52 74 19 237
673 | 7 1 3.68 6 16 62
674 | 34 0 9.34 45 6 125
675 | 9 1 5.86 83 8 179
676 | 33 1 7.93 0 2 215
677 | 18 1 2.69 11 19 124
678 | 2 1 3.11 17 5 34
679 | 9 1 4.62 60 19 140
680 | 6 1 8.42 56 12 180
681 | 36 1 6.39 6 3 215
682 | 15 1 8.78 98 1 285
683 | 34 1 0.21 7 18 170
684 | 3 1 5.45 26 11 64
685 | 26 0 3.52 71 1 60
686 | 18 1 8.84 47 10 219
687 | 3 1 4.25 19 14 49
688 | 4 0 7.1 23 0 53
689 | 19 1 0.7 47 11 102
690 | 7 1 3.08 62 19 78
691 | 38 1 5.52 87 19 311
692 | 20 1 4.53 63 19 188
693 | 1 1 6.47 68 10 133
694 | 24 1 1.05 58 5 152
695 | 14 0 3.27 69 4 72
696 | 35 0 5.75 34 9 61
697 | 39 1 1.78 83 7 241
698 | 20 0 8.13 73 16 151
699 | 13 0 7.07 83 19 112
700 | 15 1 3.38 89 9 151
701 | 35 1 5.06 66 1 293
702 | 5 1 6.22 35 5 99
703 | 23 0 8.74 71 8 137
704 | 17 1 1.32 37 1 129
705 | 19 1 3.09 97 1 192
706 | 29 1 0.59 18 17 169
707 | 19 1 2.04 100 8 171
708 | 24 1 0.37 46 13 143
709 | 6 1 9.33 60 16 167
710 | 34 1 7.21 82 11 336
711 | 33 1 6.22 44 5 257
712 | 32 1 3.66 10 16 181
713 | 21 1 2.18 75 20 138
714 | 0 1 8.29 67 11 149
715 | 7 1 2.38 4 7 70
716 | 29 1 2.36 82 6 213
717 | 15 1 2.13 61 12 109
718 | 16 1 6.53 28 12 148
719 | 32 1 8.54 7 4 194
720 | 15 1 1.94 75 16 142
721 | 17 1 8.72 43 6 219
722 | 15 0 1.06 45 20 0
723 | 32 1 4.51 20 12 195
724 | 37 0 5.77 23 7 32
725 | 32 1 6.35 34 11 239
726 | 40 1 0.78 57 17 240
727 | 32 1 7.67 61 15 282
728 | 4 1 4.64 33 5 74
729 | 8 0 5.53 48 4 62
730 | 33 1 3.96 13 4 198
731 | 5 1 6.27 50 14 103
732 | 18 1 7.19 48 11 193
733 | 36 0 3.03 53 0 67
734 | 33 1 4.61 51 0 249
735 | 16 1 1.93 15 4 90
736 | 33 1 5.47 90 18 281
737 | 20 1 7.91 72 8 274
738 | 18 1 9.6 93 8 312
739 | 7 1 7.91 74 8 176
740 | 37 1 4.01 72 14 270
741 | 7 1 2.9 78 2 102
742 | 29 1 5.59 54 19 234
743 | 5 1 5.17 60 1 132
744 | 12 1 1.24 69 11 115
745 | 14 1 7.39 86 10 252
746 | 1 1 0.09 63 17 18
747 | 30 1 4.93 75 2 262
748 | 31 1 4.42 92 20 272
749 | 30 1 8.96 8 18 203
750 | 36 0 5.3 94 9 125
751 | 28 1 7.11 74 16 259
752 | 7 1 9.91 68 7 220
753 | 36 1 0.48 46 16 213
754 | 36 0 6.79 46 0 88
755 | 10 1 1.21 7 5 61
756 | 18 1 8.06 71 15 252
757 | 6 0 0.68 89 5 35
758 | 11 1 1.14 47 20 71
759 | 17 1 9.1 67 17 250
760 | 39 0 4.73 90 19 102
761 | 8 1 2.79 41 17 93
762 | 0 1 1.03 89 12 57
763 | 8 1 1.77 25 5 62
764 | 1 1 4.14 31 11 61
765 | 7 1 4.33 65 9 105
766 | 7 1 9.47 100 7 274
767 | 5 1 5.57 35 7 88
768 | 33 1 1.84 95 17 212
769 | 38 0 6.11 78 3 139
770 | 38 1 1.82 80 15 229
771 | 35 1 5.36 42 11 264
772 | 2 1 8.13 94 15 205
773 | 34 1 5.88 97 19 306
774 | 40 1 2.36 26 12 231
775 | 13 1 2.15 64 13 134
776 | 8 1 0.79 47 2 96
777 | 39 0 1.06 12 11 13
778 | 37 0 0.15 23 8 9
779 | 17 1 1.34 91 14 135
780 | 37 1 0.51 18 15 216
781 | 20 1 8.01 52 6 229
782 | 24 1 8.84 16 15 182
783 | 0 0 4.81 28 20 34
784 | 15 1 9.18 13 6 127
785 | 24 1 8.01 4 2 154
786 | 25 1 4 70 8 234
787 | 18 1 6.67 84 15 222
788 | 25 1 5.02 92 10 275
789 | 27 1 9.44 13 8 198
790 | 17 0 5 30 15 21
791 | 20 1 3.66 67 5 188
792 | 5 0 7.34 12 5 19
793 | 0 1 9.51 16 16 69
794 | 27 0 6.23 5 18 9
795 | 29 1 6.51 82 15 289
796 | 38 0 3.05 71 4 50
797 | 25 0 6.36 38 5 80
798 | 32 1 1.86 43 3 210
799 | 6 1 7.26 59 13 156
800 | 40 1 4.86 79 16 291
801 | 5 0 4.33 89 6 101
802 | 5 1 4.51 99 16 153
803 | 37 1 4.74 92 20 290
804 | 30 1 3.73 64 4 224
805 | 8 1 4.2 95 19 157
806 | 35 0 2.85 96 12 58
807 | 2 0 9.8 76 11 169
808 | 31 1 2.91 29 20 185
809 | 17 1 9.76 30 7 202
810 | 0 1 0.64 8 12 0
811 | 0 1 9.57 34 12 111
812 | 21 1 4.35 29 1 180
813 | 14 1 0.18 92 6 93
814 | 9 1 5.14 81 7 152
815 | 21 1 7.05 32 18 192
816 | 10 0 8.99 94 20 168
817 | 38 1 1.19 74 2 253
818 | 6 1 9.5 45 18 137
819 | 9 0 0.37 78 12 4
820 | 14 1 4.21 54 15 133
821 | 16 1 6.07 96 8 258
822 | 11 0 7.29 81 5 153
823 | 31 1 1.34 83 8 198
824 | 21 1 2.45 46 13 142
825 | 22 1 0.29 61 9 148
826 | 27 1 0.14 13 11 148
827 | 26 1 6.26 87 4 300
828 | 1 0 8.21 73 11 148
829 | 19 1 3.16 28 2 160
830 | 15 0 5.59 1 2 13
831 | 30 1 1.4 7 10 169
832 | 15 1 3.96 16 0 111
833 | 18 1 7.94 14 13 143
834 | 0 1 8.44 15 20 34
835 | 20 0 0.32 84 12 27
836 | 12 1 9.34 92 5 269
837 | 25 1 1.96 2 7 159
838 | 17 1 8.31 56 18 189
839 | 31 1 5.61 4 18 183
840 | 15 1 7.08 78 10 211
841 | 38 0 5.65 14 15 25
842 | 39 1 4.67 8 12 221
843 | 10 0 8.09 16 18 30
844 | 6 0 8.22 39 6 71
845 | 20 1 4.57 2 12 118
846 | 28 1 8.74 86 14 349
847 | 13 1 9.04 49 12 174
848 | 24 1 2.54 66 8 185
849 | 5 0 6.41 41 4 66
850 | 40 1 0.59 68 1 232
851 | 23 1 2.05 43 11 144
852 | 13 1 4.75 83 20 172
853 | 16 1 5.1 86 11 220
854 | 28 0 2.34 87 0 64
855 | 26 0 2.6 98 0 76
856 | 30 0 1.48 48 17 0
857 | 38 1 7.72 3 15 218
858 | 7 1 2.81 3 5 74
859 | 40 1 5.67 69 6 324
860 | 31 1 2.99 36 12 215
861 | 30 1 6.18 47 20 219
862 | 28 1 5.02 36 4 224
863 | 23 1 0.62 54 9 156
864 | 31 0 6.08 11 2 35
865 | 16 1 7.27 49 7 192
866 | 20 1 3.66 28 8 132
867 | 5 1 9.93 52 11 177
868 | 36 0 0.02 92 4 12
869 | 39 1 7.88 52 0 317
870 | 29 0 1.64 73 15 21
871 | 37 0 8.97 78 2 171
872 | 5 1 1.3 39 11 67
873 | 14 1 1.32 24 20 80
874 | 20 1 9.35 19 14 180
875 | 24 1 8.41 47 19 210
876 | 4 1 0.84 76 17 48
877 | 15 1 1.17 94 18 115
878 | 22 1 8.08 95 20 292
879 | 27 1 3.39 28 9 170
880 | 5 1 2.03 91 17 77
881 | 17 1 8.91 10 14 122
882 | 40 0 6.76 95 9 151
883 | 7 1 7.79 2 13 72
884 | 11 1 3.69 61 20 134
885 | 26 0 5.28 1 18 0
886 | 14 1 6.1 5 5 88
887 | 29 1 5.17 10 10 195
888 | 26 0 4.28 91 14 88
889 | 8 1 4.04 73 3 128
890 | 11 1 2.18 61 10 94
891 | 2 1 7.6 77 9 157
892 | 37 0 0.33 22 5 0
893 | 27 1 3.22 23 7 182
894 | 16 1 2.01 72 6 144
895 | 22 1 8.74 19 19 165
896 | 28 1 3.37 92 9 240
897 | 32 1 4.58 1 1 174
898 | 27 1 1.92 71 17 176
899 | 11 1 2.23 14 12 82
900 | 24 0 5.05 24 12 42
901 | 9 1 5.38 16 7 79
902 | 27 1 0.42 88 1 164
903 | 33 1 9.22 53 12 305
904 | 14 1 2.52 100 11 173
905 | 11 0 2.87 82 3 83
906 | 16 1 3.85 66 9 146
907 | 13 1 8.9 91 6 269
908 | 8 0 3.59 70 15 50
909 | 6 1 0.9 67 2 92
910 | 26 0 2.87 89 13 82
911 | 1 1 1.41 6 17 5
912 | 28 1 9.72 6 18 183
913 | 23 1 3.71 83 12 204
914 | 7 0 3.69 10 3 0
915 | 29 1 8.21 39 8 251
916 | 32 1 3.97 17 7 213
917 | 16 1 5.7 22 13 142
918 | 33 0 5.86 42 7 67
919 | 6 0 9.84 89 13 209
920 | 29 1 9.27 39 10 259
921 | 39 0 2.16 57 20 12
922 | 17 1 6.01 71 3 210
923 | 33 1 5.44 73 18 255
924 | 16 1 7.39 68 20 196
925 | 18 1 8.06 33 8 190
926 | 37 0 9.9 11 1 48
927 | 36 0 1.67 2 5 0
928 | 39 1 2.53 82 11 286
929 | 9 1 9.57 8 4 84
930 | 22 1 7.01 54 16 218
931 | 19 1 9.01 96 4 314
932 | 21 1 9.77 78 14 314
933 | 19 0 1.4 86 10 41
934 | 10 0 4.59 28 3 36
935 | 21 0 7.62 3 15 0
936 | 7 0 7.26 26 4 65
937 | 16 1 4.35 66 16 154
938 | 10 0 2.8 92 13 52
939 | 33 1 0.04 67 6 211
940 | 20 1 6.15 33 4 176
941 | 39 1 9.27 79 4 382
942 | 26 1 5.65 63 15 225
943 | 20 1 1.09 15 16 103
944 | 4 0 2.14 85 2 45
945 | 5 1 5.77 54 17 124
946 | 16 1 3.49 51 7 149
947 | 6 1 7.36 84 9 184
948 | 0 1 9.26 64 8 164
949 | 40 1 6.27 27 17 274
950 | 40 1 0.67 36 6 214
951 | 38 1 2.12 88 0 256
952 | 30 1 0.27 99 15 175
953 | 32 0 4.96 52 14 45
954 | 40 0 0.93 26 13 11
955 | 26 1 6.38 66 9 246
956 | 36 1 6.96 40 7 270
957 | 10 1 5.01 46 10 130
958 | 0 1 9.51 50 11 144
959 | 39 0 7.98 32 19 66
960 | 35 1 2.15 24 2 207
961 | 9 1 9.91 45 1 164
962 | 14 0 4.1 48 16 57
963 | 24 1 8.78 20 9 194
964 | 15 1 0.32 53 13 78
965 | 23 1 8.5 45 16 233
966 | 38 1 4.37 68 4 305
967 | 7 1 9.79 38 7 165
968 | 16 1 4.64 98 1 234
969 | 25 1 0.36 6 14 139
970 | 33 1 1.78 36 0 192
971 | 33 1 1.26 58 11 210
972 | 3 1 8.12 76 9 182
973 | 6 1 3.33 40 11 67
974 | 8 1 1.37 33 9 88
975 | 15 0 1.5 14 14 0
976 | 25 1 6.08 91 2 303
977 | 16 1 8.4 92 11 289
978 | 36 1 1.14 0 0 187
979 | 3 0 0.78 30 1 31
980 | 5 0 1.97 53 17 10
981 | 34 1 0.99 92 3 246
982 | 30 1 7.52 57 16 248
983 | 11 0 0.13 8 16 0
984 | 40 1 8.95 95 5 439
985 | 38 1 2.97 38 4 235
986 | 36 1 9.47 55 20 332
987 | 19 0 2.43 46 8 43
988 | 17 1 1.35 20 7 127
989 | 2 1 4.8 91 5 160
990 | 27 1 9.9 92 7 380
991 | 33 1 4.14 23 10 224
992 | 16 0 3.19 15 14 12
993 | 22 0 4.65 38 2 59
994 | 24 0 7.81 43 3 100
995 | 14 1 7.26 66 1 212
996 | 19 1 3.84 39 17 148
997 | 8 1 4.1 76 17 142
998 | 25 0 1.66 13 5 24
999 | 20 1 8.95 52 11 230
1000 | 16 1 9.4 91 15 284
1001 | 39 1 2.58 47 0 241
1002 |
--------------------------------------------------------------------------------
/src/shap_tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# SHAP Tutorial\n",
8 | "\n",
9 | "
\n",
10 | "Course sections:\n",
11 | "\n",
12 | "- SHAP values\n",
13 | "
- SHAP aggregations\n",
14 | "
\n",
15 | " - Force plots\n",
16 | "
- Mean SHAP\n",
17 | "
- Beeswarm\n",
18 | "
- Violin\n",
19 | "
- Heatmap\n",
20 | "
- Dependence\n",
21 | "
\n",
22 | " - Custom SHAP plots\n",
23 | "
- Binary and mutliclass target variables \n",
24 | "
- SHAP interaction values\n",
25 | "
- Categorical features\n",
26 | "
\n",
27 | "
\n",
28 | "Dataset: https://archive.ics.uci.edu/ml/datasets/Abalone\n"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "# imports\n",
38 | "import pandas as pd\n",
39 | "import numpy as np\n",
40 | "\n",
41 | "import matplotlib.pyplot as plt\n",
42 | "import seaborn as sns\n",
43 | "\n",
44 | "import xgboost as xgb\n",
45 | "\n",
46 | "import shap\n",
47 | "\n",
48 | "shap.initjs()"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "# Dataset\n"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "# import dataset\n",
65 | "data = pd.read_csv(\n",
66 | " \"../data/abalone.data\",\n",
67 | " names=[\n",
68 | " \"sex\",\n",
69 | " \"length\",\n",
70 | " \"diameter\",\n",
71 | " \"height\",\n",
72 | " \"whole weight\",\n",
73 | " \"shucked weight\",\n",
74 | " \"viscera weight\",\n",
75 | " \"shell weight\",\n",
76 | " \"rings\",\n",
77 | " ],\n",
78 | ")\n",
79 | "\n",
80 | "print(len(data))\n",
81 | "data.head()"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "# plot 1: whole weight\n",
91 | "plt.scatter(data[\"whole weight\"], data[\"rings\"])\n",
92 | "plt.ylabel(\"rings\", size=20)\n",
93 | "plt.xlabel(\"whole weight\", size=20)"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "# plot 2: sex\n",
103 | "plt.boxplot(data[data.sex == \"I\"][\"rings\"], positions=[1])\n",
104 | "plt.boxplot(data[data.sex == \"M\"][\"rings\"], positions=[2])\n",
105 | "plt.boxplot(data[data.sex == \"F\"][\"rings\"], positions=[3])\n",
106 | "\n",
107 | "plt.xticks(ticks=[1, 2, 3], labels=[\"I\", \"M\", \"F\"], size=15)\n",
108 | "plt.ylabel(\"rings\", size=20)\n",
109 | "plt.xlabel(\"sex\", size=20)"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "# plot 3: Correlation heatmap\n",
119 | "cont = [\n",
120 | " \"length\",\n",
121 | " \"diameter\",\n",
122 | " \"height\",\n",
123 | " \"whole weight\",\n",
124 | " \"shucked weight\",\n",
125 | " \"viscera weight\",\n",
126 | " \"shell weight\",\n",
127 | " \"rings\",\n",
128 | "]\n",
129 | "corr_matrix = pd.DataFrame(data[cont], columns=cont).corr()\n",
130 | "\n",
131 | "sns.heatmap(corr_matrix, cmap=\"coolwarm\", center=0, annot=True, fmt=\".1g\")"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {},
137 | "source": [
138 | "# Feature Engineering\n"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 6,
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "y = data[\"rings\"]\n",
148 | "X = data[[\"sex\", \"length\", \"height\", \"shucked weight\", \"viscera weight\", \"shell weight\"]]"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "# create dummy variables\n",
158 | "X[\"sex.M\"] = [1 if s == \"M\" else 0 for s in X[\"sex\"]]\n",
159 | "X[\"sex.F\"] = [1 if s == \"F\" else 0 for s in X[\"sex\"]]\n",
160 | "X[\"sex.I\"] = [1 if s == \"I\" else 0 for s in X[\"sex\"]]\n",
161 | "X = X.drop(\"sex\", axis=1)\n",
162 | "\n",
163 | "X.head()"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "features = X.copy()\n",
173 | "features['y'] = y\n",
174 | "\n",
175 | "features.head()"
176 | ]
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {},
181 | "source": [
182 | "# Modelling\n"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "# train model\n",
192 | "model = xgb.XGBRegressor(objective=\"reg:squarederror\")\n",
193 | "model.fit(X, y)"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": null,
199 | "metadata": {},
200 | "outputs": [],
201 | "source": [
202 | "# get predictions\n",
203 | "y_pred = model.predict(X)\n",
204 | "\n",
205 | "# model evaluation\n",
206 | "plt.figure(figsize=(5, 5))\n",
207 | "\n",
208 | "plt.scatter(y, y_pred)\n",
209 | "plt.plot([0, 30], [0, 30], color=\"r\", linestyle=\"-\", linewidth=2)\n",
210 | "\n",
211 | "plt.ylabel(\"Predicted\", size=20)\n",
212 | "plt.xlabel(\"Actual\", size=20)"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {},
218 | "source": [
219 | "# 1) Standard SHAP values\n"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": 11,
225 | "metadata": {},
226 | "outputs": [],
227 | "source": [
228 | "# get shap values\n",
229 | "explainer = shap.Explainer(model)\n",
230 | "shap_values = explainer(X)\n",
231 | "\n",
232 | "# shap_values = explainer(X[0:100])"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "metadata": {},
239 | "outputs": [],
240 | "source": [
241 | "np.shape(shap_values.values)"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "## Waterfall plot\n"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "# waterfall plot for first observation\n",
258 | "shap.plots.waterfall(shap_values[0])"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "# waterfall plot for first observation\n",
268 | "shap.plots.waterfall(shap_values[1], max_display=4)"
269 | ]
270 | },
271 | {
272 | "cell_type": "markdown",
273 | "metadata": {},
274 | "source": [
275 | "# 2) SHAP aggregations\n"
276 | ]
277 | },
278 | {
279 | "cell_type": "markdown",
280 | "metadata": {},
281 | "source": [
282 | "## Force plot\n"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": null,
288 | "metadata": {},
289 | "outputs": [],
290 | "source": [
291 | "# force plot\n",
292 | "shap.plots.force(shap_values[0])"
293 | ]
294 | },
295 | {
296 | "cell_type": "markdown",
297 | "metadata": {},
298 | "source": [
299 | "## Stacked force plot\n"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": null,
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "# stacked force plot\n",
309 | "shap.plots.force(shap_values[0:100])"
310 | ]
311 | },
312 | {
313 | "cell_type": "markdown",
314 | "metadata": {},
315 | "source": [
316 | "## Absolute Mean SHAP\n"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": null,
322 | "metadata": {},
323 | "outputs": [],
324 | "source": [
325 | "# mean SHAP\n",
326 | "shap.plots.bar(shap_values)"
327 | ]
328 | },
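329 | {
330 | "cell_type": "markdown",
331 | "metadata": {},
332 | "source": [
333 | "*A quick cross-check (sketch):* the bar plot above shows each feature's mean absolute SHAP value; the cell below recomputes it by hand, assuming `shap_values.values` is an `(n_samples, n_features)` array.\n"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": null,
339 | "metadata": {},
340 | "outputs": [],
341 | "source": [
342 | "# manual cross-check of the bar plot: mean(|SHAP value|) per feature\n",
343 | "mean_abs_shap = np.abs(shap_values.values).mean(axis=0)\n",
344 | "pd.Series(mean_abs_shap, index=X.columns).sort_values(ascending=False)"
345 | ]
346 | },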
329 | {
330 | "cell_type": "markdown",
331 | "metadata": {},
332 | "source": [
333 | "## Beeswarm plot\n"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": null,
339 | "metadata": {},
340 | "outputs": [],
341 | "source": [
342 | "# beeswarm plot\n",
343 | "shap.plots.beeswarm(shap_values)"
344 | ]
345 | },
346 | {
347 | "cell_type": "markdown",
348 | "metadata": {},
349 | "source": [
350 | "## Violin plot\n"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": null,
356 | "metadata": {},
357 | "outputs": [],
358 | "source": [
359 | "# violin plot\n",
360 | "shap.plots.violin(shap_values)"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": null,
366 | "metadata": {},
367 | "outputs": [],
368 | "source": [
369 | "# layered violin plot\n",
370 | "shap.plots.violin(shap_values, plot_type=\"layered_violin\")"
371 | ]
372 | },
373 | {
374 | "cell_type": "markdown",
375 | "metadata": {},
376 | "source": [
377 | "## Heamap\n"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "metadata": {},
384 | "outputs": [],
385 | "source": [
386 | "# heatmap\n",
387 | "shap.plots.heatmap(shap_values)"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": null,
393 | "metadata": {},
394 | "outputs": [],
395 | "source": [
396 | "# order by predictions\n",
397 | "order = np.argsort(y_pred)\n",
398 | "shap.plots.heatmap(shap_values, instance_order=order)"
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": null,
404 | "metadata": {},
405 | "outputs": [],
406 | "source": [
407 | "# order by shell weight value\n",
408 | "order = np.argsort(data[\"shell weight\"])\n",
409 | "shap.plots.heatmap(shap_values, instance_order=order)"
410 | ]
411 | },
412 | {
413 | "cell_type": "markdown",
414 | "metadata": {},
415 | "source": [
416 | "## Dependence plots\n"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": null,
422 | "metadata": {},
423 | "outputs": [],
424 | "source": [
425 | "# plot 1: shell weight\n",
426 | "shap.plots.scatter(shap_values[:, \"shell weight\"])"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": null,
432 | "metadata": {},
433 | "outputs": [],
434 | "source": [
435 | "shap.plots.scatter(\n",
436 | " shap_values[:, \"shell weight\"], color=shap_values[:, \"shucked weight\"]\n",
437 | ")"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": null,
443 | "metadata": {},
444 | "outputs": [],
445 | "source": [
446 | "# plot 2: shucked weight\n",
447 | "shap.plots.scatter(shap_values[:, \"shucked weight\"])"
448 | ]
449 | },
450 | {
451 | "cell_type": "markdown",
452 | "metadata": {},
453 | "source": [
454 | "# 3) Custom Plots\n"
455 | ]
456 | },
457 | {
458 | "cell_type": "code",
459 | "execution_count": null,
460 | "metadata": {},
461 | "outputs": [],
462 | "source": [
463 | "# output SHAP object\n",
464 | "shap_values"
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": null,
470 | "metadata": {},
471 | "outputs": [],
472 | "source": [
473 | "np.shape(shap_values.values)"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": null,
479 | "metadata": {},
480 | "outputs": [],
481 | "source": [
482 | "# SHAP correlation plot\n",
483 | "corr_matrix = pd.DataFrame(shap_values.values, columns=X.columns).corr()\n",
484 | "\n",
485 | "sns.set(font_scale=1)\n",
486 | "sns.heatmap(corr_matrix, cmap=\"coolwarm\", center=0, annot=True, fmt=\".1g\")"
487 | ]
488 | },
489 | {
490 | "cell_type": "markdown",
491 | "metadata": {},
492 | "source": [
493 | "# 4) Binary and categorical target variables\n"
494 | ]
495 | },
496 | {
497 | "cell_type": "markdown",
498 | "metadata": {},
499 | "source": [
500 | "### Binary target variable\n"
501 | ]
502 | },
503 | {
504 | "cell_type": "code",
505 | "execution_count": 30,
506 | "metadata": {},
507 | "outputs": [],
508 | "source": [
509 | "# binary target varibale\n",
510 | "y_bin = [1 if y_ > 10 else 0 for y_ in y]"
511 | ]
512 | },
513 | {
514 | "cell_type": "code",
515 | "execution_count": null,
516 | "metadata": {},
517 | "outputs": [],
518 | "source": [
519 | "# train model\n",
520 | "model_bin = xgb.XGBClassifier(objective=\"binary:logistic\")\n",
521 | "model_bin.fit(X, y_bin)"
522 | ]
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": null,
527 | "metadata": {},
528 | "outputs": [],
529 | "source": [
530 | "# get shap values\n",
531 | "explainer = shap.Explainer(model_bin)\n",
532 | "shap_values_bin = explainer(X)\n",
533 | "\n",
534 | "print(shap_values_bin.shape)"
535 | ]
536 | },
537 | {
538 | "cell_type": "code",
539 | "execution_count": null,
540 | "metadata": {},
541 | "outputs": [],
542 | "source": [
543 | "# waterfall plot for first observation\n",
544 | "shap.plots.waterfall(shap_values_bin[0])"
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": null,
550 | "metadata": {},
551 | "outputs": [],
552 | "source": [
553 | "# waterfall plot for first observation\n",
554 | "shap.plots.force(shap_values_bin[0], link=\"logit\")"
555 | ]
556 | },
557 | {
558 | "cell_type": "code",
559 | "execution_count": null,
560 | "metadata": {},
561 | "outputs": [],
562 | "source": [
563 | "# waterfall plot for first observation\n",
564 | "shap.plots.bar(shap_values_bin)"
565 | ]
566 | },
567 | {
568 | "cell_type": "markdown",
569 | "metadata": {},
570 | "source": [
571 | "### Categorical target variables\n"
572 | ]
573 | },
574 | {
575 | "cell_type": "code",
576 | "execution_count": null,
577 | "metadata": {},
578 | "outputs": [],
579 | "source": [
580 | "# categorical target varibale\n",
581 | "y_cat = [2 if y_ > 12 else 1 if y_ > 8 else 0 for y_ in y]\n",
582 | "\n",
583 | "# train model\n",
584 | "model_cat = xgb.XGBClassifier(objective=\"binary:logistic\")\n",
585 | "model_cat.fit(X, y_cat)"
586 | ]
587 | },
588 | {
589 | "cell_type": "code",
590 | "execution_count": null,
591 | "metadata": {},
592 | "outputs": [],
593 | "source": [
594 | "# get probability predictions\n",
595 | "model_cat.predict_proba(X)[0]"
596 | ]
597 | },
598 | {
599 | "cell_type": "code",
600 | "execution_count": null,
601 | "metadata": {},
602 | "outputs": [],
603 | "source": [
604 | "# get shap values\n",
605 | "explainer = shap.Explainer(model_cat)\n",
606 | "shap_values_cat = explainer(X)\n",
607 | "\n",
608 | "print(np.shape(shap_values_cat))"
609 | ]
610 | },
611 | {
612 | "cell_type": "code",
613 | "execution_count": null,
614 | "metadata": {},
615 | "outputs": [],
616 | "source": [
617 | "# waterfall plot for first observation\n",
618 | "shap.plots.waterfall(shap_values_cat[0, :, 0])\n",
619 | "\n",
620 | "# waterfall plot for first observation\n",
621 | "shap.plots.waterfall(shap_values_cat[0, :, 1])\n",
622 | "\n",
623 | "# waterfall plot for first observation\n",
624 | "shap.plots.waterfall(shap_values_cat[0, :, 2])"
625 | ]
626 | },
627 | {
628 | "cell_type": "code",
629 | "execution_count": null,
630 | "metadata": {},
631 | "outputs": [],
632 | "source": [
633 | "def softmax(x):\n",
634 | " \"\"\"Compute softmax values for each sets of scores in x.\"\"\"\n",
635 | " e_x = np.exp(x - np.max(x))\n",
636 | " return e_x / e_x.sum(axis=0)\n",
637 | "\n",
638 | "\n",
639 | "# convert softmax to probability\n",
640 | "x = [0.383, -0.106, 1.211]\n",
641 | "softmax(x)"
642 | ]
643 | },
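644 | {
645 | "cell_type": "markdown",
646 | "metadata": {},
647 | "source": [
648 | "*A sanity check (sketch):* the scores hard-coded above come from one particular run. Each class's raw score can also be rebuilt from the SHAP object itself (base value plus the sum of that class's SHAP values), and the softmax of the three scores should roughly match `model_cat.predict_proba`, assuming the explainer returns raw log-odds scores (the default for XGBoost).\n"
649 | ]
650 | },
651 | {
652 | "cell_type": "code",
653 | "execution_count": null,
654 | "metadata": {},
655 | "outputs": [],
656 | "source": [
657 | "# rebuild the raw class scores for the first observation from its SHAP values\n",
658 | "raw_scores = shap_values_cat[0].values.sum(axis=0) + shap_values_cat[0].base_values\n",
659 | "\n",
660 | "# softmax of the raw scores should be close to the model's predicted probabilities\n",
661 | "print(softmax(raw_scores))\n",
662 | "print(model_cat.predict_proba(X)[0])"
663 | ]
664 | },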
644 | {
645 | "cell_type": "code",
646 | "execution_count": null,
647 | "metadata": {},
648 | "outputs": [],
649 | "source": [
650 | "# calculate mean SHAP values for each class\n",
651 | "mean_0 = np.mean(np.abs(shap_values_cat.values[:, :, 0]), axis=0)\n",
652 | "mean_1 = np.mean(np.abs(shap_values_cat.values[:, :, 1]), axis=0)\n",
653 | "mean_2 = np.mean(np.abs(shap_values_cat.values[:, :, 2]), axis=0)\n",
654 | "\n",
655 | "df = pd.DataFrame({\"young\": mean_0, \"medium\": mean_1, \"old\": mean_2})\n",
656 | "\n",
657 | "# plot mean SHAP values\n",
658 | "fig, ax = plt.subplots(1, 1, figsize=(20, 10))\n",
659 | "df.plot.bar(ax=ax)\n",
660 | "\n",
661 | "ax.set_ylabel(\"Mean SHAP\", size=30)\n",
662 | "ax.set_xticklabels(X.columns, rotation=45, size=20)\n",
663 | "ax.legend(fontsize=30)"
664 | ]
665 | },
666 | {
667 | "cell_type": "code",
668 | "execution_count": null,
669 | "metadata": {},
670 | "outputs": [],
671 | "source": [
672 | "# get model predictions\n",
673 | "preds = model_cat.predict(X)\n",
674 | "\n",
675 | "new_shap_values = []\n",
676 | "for i, pred in enumerate(preds):\n",
677 | " # get shap values for predicted class\n",
678 | " new_shap_values.append(shap_values_cat.values[i][:, pred])\n",
679 | "\n",
680 | "# replace shap values\n",
681 | "shap_values_cat.values = np.array(new_shap_values)\n",
682 | "print(shap_values_cat.shape)"
683 | ]
684 | },
685 | {
686 | "cell_type": "code",
687 | "execution_count": null,
688 | "metadata": {},
689 | "outputs": [],
690 | "source": [
691 | "shap.plots.bar(shap_values_cat)"
692 | ]
693 | },
694 | {
695 | "cell_type": "code",
696 | "execution_count": null,
697 | "metadata": {},
698 | "outputs": [],
699 | "source": [
700 | "shap.plots.beeswarm(shap_values_cat)"
701 | ]
702 | },
703 | {
704 | "cell_type": "markdown",
705 | "metadata": {},
706 | "source": [
707 | "# 5) SHAP interaction value\n"
708 | ]
709 | },
710 | {
711 | "cell_type": "code",
712 | "execution_count": 45,
713 | "metadata": {},
714 | "outputs": [],
715 | "source": [
716 | "# get SHAP interaction values\n",
717 | "explainer = shap.Explainer(model)\n",
718 | "shap_interaction = explainer.shap_interaction_values(X)"
719 | ]
720 | },
721 | {
722 | "cell_type": "code",
723 | "execution_count": null,
724 | "metadata": {},
725 | "outputs": [],
726 | "source": [
727 | "# get shape of interaction values\n",
728 | "np.shape(shap_interaction)"
729 | ]
730 | },
731 | {
732 | "cell_type": "code",
733 | "execution_count": null,
734 | "metadata": {},
735 | "outputs": [],
736 | "source": [
737 | "# SHAP interaction values for first employee\n",
738 | "shap_0 = np.round(shap_interaction[0], 2)\n",
739 | "pd.DataFrame(shap_0, index=X.columns, columns=X.columns)"
740 | ]
741 | },
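{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick additivity check (a sketch, assuming `explainer` and `shap_interaction` from the cells above): summing each observation's interaction matrix along one axis should recover the ordinary SHAP values.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the rows of each interaction matrix should sum to the ordinary SHAP values\n",
"shap_values_check = explainer(X).values\n",
"print(np.allclose(shap_interaction.sum(axis=2), shap_values_check, atol=1e-4))"
]
},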
742 | {
743 | "cell_type": "markdown",
744 | "metadata": {},
745 | "source": [
746 | "## Mean SHAP interaction values\n"
747 | ]
748 | },
749 | {
750 | "cell_type": "code",
751 | "execution_count": null,
752 | "metadata": {},
753 | "outputs": [],
754 | "source": [
755 | "# get absolute mean of matrices\n",
756 | "mean_shap = np.abs(shap_interaction).mean(0)\n",
757 | "mean_shap = np.round(mean_shap, 1)\n",
758 | "\n",
759 | "df = pd.DataFrame(mean_shap, index=X.columns, columns=X.columns)\n",
760 | "\n",
761 | "# times off diagonal by 2\n",
762 | "df.where(df.values == np.diagonal(df), df.values * 2, inplace=True)\n",
763 | "\n",
764 | "# display\n",
765 | "sns.set(font_scale=1)\n",
766 | "sns.heatmap(df, cmap=\"coolwarm\", annot=True)\n",
767 | "plt.yticks(rotation=0)"
768 | ]
769 | },
770 | {
771 | "cell_type": "markdown",
772 | "metadata": {},
773 | "source": [
774 | "## Dependence plot\n"
775 | ]
776 | },
777 | {
778 | "cell_type": "code",
779 | "execution_count": null,
780 | "metadata": {},
781 | "outputs": [],
782 | "source": [
783 | "shap.dependence_plot(\n",
784 | " (\"shell weight\", \"shucked weight\"), shap_interaction, X, display_features=X\n",
785 | ")"
786 | ]
787 | },
788 | {
789 | "cell_type": "code",
790 | "execution_count": null,
791 | "metadata": {},
792 | "outputs": [],
793 | "source": [
794 | "# interaction between shell weight and shucked weight\n",
795 | "plt.scatter(data[\"shell weight\"], data[\"shucked weight\"], c=data[\"rings\"], cmap=\"bwr\")\n",
796 | "plt.colorbar(label=\"Number of Rings\", orientation=\"vertical\")\n",
797 | "\n",
798 | "plt.xlabel(\"shucked weight\", size=15)\n",
799 | "plt.ylabel(\"shell weight\", size=15)"
800 | ]
801 | },
802 | {
803 | "cell_type": "markdown",
804 | "metadata": {},
805 | "source": [
806 | "# 6) SHAP for categorical variables\n"
807 | ]
808 | },
809 | {
810 | "cell_type": "code",
811 | "execution_count": null,
812 | "metadata": {},
813 | "outputs": [],
814 | "source": [
815 | "X.head()"
816 | ]
817 | },
818 | {
819 | "cell_type": "code",
820 | "execution_count": null,
821 | "metadata": {},
822 | "outputs": [],
823 | "source": [
824 | "# Waterfall plot for first observation\n",
825 | "shap.plots.waterfall(shap_values[0])"
826 | ]
827 | },
828 | {
829 | "cell_type": "code",
830 | "execution_count": null,
831 | "metadata": {},
832 | "outputs": [],
833 | "source": [
834 | "new_shap_values = []\n",
835 | "\n",
836 | "# loop over all shap values:\n",
837 | "for values in shap_values.values:\n",
838 | " # sum SHAP values for sex\n",
839 | " sv = list(values)\n",
840 | " sv = sv[0:5] + [sum(sv[5:8])]\n",
841 | "\n",
842 | " new_shap_values.append(sv)"
843 | ]
844 | },
845 | {
846 | "cell_type": "code",
847 | "execution_count": null,
848 | "metadata": {},
849 | "outputs": [],
850 | "source": [
851 | "# replace shap values\n",
852 | "shap_values.values = np.array(new_shap_values)\n",
853 | "\n",
854 | "# replace data with categorical feature values\n",
855 | "X_cat = data[\n",
856 | " [\"length\", \"height\", \"shucked weight\", \"viscera weight\", \"shell weight\", \"sex\"]\n",
857 | "]\n",
858 | "shap_values.data = np.array(X_cat)\n",
859 | "\n",
860 | "# update feature names\n",
861 | "shap_values.feature_names = list(X_cat.columns)"
862 | ]
863 | },
864 | {
865 | "cell_type": "code",
866 | "execution_count": null,
867 | "metadata": {},
868 | "outputs": [],
869 | "source": [
870 | "shap.plots.waterfall(shap_values[0])"
871 | ]
872 | },
873 | {
874 | "cell_type": "code",
875 | "execution_count": null,
876 | "metadata": {},
877 | "outputs": [],
878 | "source": [
879 | "shap.plots.bar(shap_values)"
880 | ]
881 | },
882 | {
883 | "cell_type": "code",
884 | "execution_count": null,
885 | "metadata": {},
886 | "outputs": [],
887 | "source": [
888 | "shap.plots.beeswarm(shap_values)"
889 | ]
890 | },
891 | {
892 | "cell_type": "code",
893 | "execution_count": null,
894 | "metadata": {},
895 | "outputs": [],
896 | "source": [
897 | "# get shaply values and data\n",
898 | "sex_values = shap_values[:, \"sex\"].values\n",
899 | "sex_data = shap_values[:, \"sex\"].data\n",
900 | "sex_categories = [\"I\", \"M\", \"F\"]\n",
901 | "\n",
902 | "# split sex shap values based on category\n",
903 | "sex_groups = []\n",
904 | "for s in sex_categories:\n",
905 | " relevant_values = sex_values[sex_data == s]\n",
906 | " sex_groups.append(relevant_values)\n",
907 | "\n",
908 | "# plot boxplot\n",
909 | "plt.boxplot(sex_groups, labels=sex_categories)\n",
910 | "\n",
911 | "plt.ylabel(\"SHAP values\", size=15)\n",
912 | "plt.xlabel(\"Sex\", size=15)"
913 | ]
914 | },
915 | {
916 | "cell_type": "code",
917 | "execution_count": null,
918 | "metadata": {},
919 | "outputs": [],
920 | "source": [
921 | "# create for placeholder SHAP values\n",
922 | "shap_values_sex = explainer(X)\n",
923 | "\n",
924 | "# get shaply values and data\n",
925 | "sex_values = shap_values[:, \"sex\"].values\n",
926 | "sex_data = shap_values[:, \"sex\"].data\n",
927 | "sex_categories = [\"I\", \"M\", \"F\"]\n",
928 | "\n",
929 | "# create new SHAP values array\n",
930 | "\n",
931 | "# split odor SHAP values by unique odor categories\n",
932 | "new_shap_values = [\n",
933 | " np.array(pd.Series(sex_values)[sex_data == s]) for s in sex_categories\n",
934 | "]\n",
935 | "\n",
936 | "# each sublist needs to be the same length\n",
937 | "max_len = max([len(v) for v in new_shap_values])\n",
938 | "new_shap_values = [\n",
939 | " np.append(vs, [np.nan] * (max_len - len(vs))) for vs in new_shap_values\n",
940 | "]\n",
941 | "new_shap_values = np.array(new_shap_values)\n",
942 | "\n",
943 | "# transpost matrix so categories are columns and SHAP values are rows\n",
944 | "new_shap_values = new_shap_values.transpose()\n",
945 | "\n",
946 | "# replace shap values\n",
947 | "shap_values_sex.values = np.array(new_shap_values)\n",
948 | "\n",
949 | "# replace data with placeholder array\n",
950 | "shap_values_sex.data = np.array([[0] * len(sex_categories)] * max_len)\n",
951 | "\n",
952 | "# replace base data with placeholder array\n",
953 | "shap_values_sex.base = np.array([0] * max_len)\n",
954 | "\n",
955 | "# replace feature names with category labels\n",
956 | "shap_values_sex.feature_names = list(sex_categories)\n",
957 | "\n",
958 | "# use beeswarm as before\n",
959 | "shap.plots.beeswarm(shap_values_sex, color_bar=False)"
960 | ]
961 | },
962 | {
963 | "cell_type": "code",
964 | "execution_count": null,
965 | "metadata": {},
966 | "outputs": [],
967 | "source": [
968 | "import warnings\n",
969 | "\n",
970 | "warnings.filterwarnings(\"ignore\")"
971 | ]
972 | }
973 | ],
974 | "metadata": {
975 | "kernelspec": {
976 | "display_name": "shap",
977 | "language": "python",
978 | "name": "shap"
979 | },
980 | "language_info": {
981 | "codemirror_mode": {
982 | "name": "ipython",
983 | "version": 3
984 | },
985 | "file_extension": ".py",
986 | "mimetype": "text/x-python",
987 | "name": "python",
988 | "nbconvert_exporter": "python",
989 | "pygments_lexer": "ipython3",
990 | "version": "3.10.4"
991 | }
992 | },
993 | "nbformat": 4,
994 | "nbformat_minor": 2
995 | }
996 |
--------------------------------------------------------------------------------