├── explanations.py ├── links.py ├── datatypes.py ├── common.py ├── README.md ├── dynamic_shap_plots.py └── shap_plots.py /explanations.py: -------------------------------------------------------------------------------- 1 | from common import Model, Instance 2 | from datatypes import Data 3 | from links import Link 4 | 5 | class Explanation: 6 | def __init__(self): 7 | pass 8 | 9 | class AdditiveExplanation(Explanation): 10 | def __init__(self, base_value, out_value, effects, effects_var, instance, link, model, data): 11 | self.base_value = base_value 12 | self.out_value = out_value 13 | self.effects = effects 14 | self.effects_var = effects_var 15 | # assert isinstance(instance, Instance) 16 | self.instance = instance 17 | # assert isinstance(link, Link) 18 | self.link = link 19 | # assert isinstance(model, Model) 20 | self.model = model 21 | # assert isinstance(data, Data) 22 | self.data = data 23 | -------------------------------------------------------------------------------- /links.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Link: 5 | def __init__(self): 6 | pass 7 | 8 | 9 | class IdentityLink(Link): 10 | def __str__(self): 11 | return "identity" 12 | 13 | @staticmethod 14 | def f(x): 15 | return x 16 | 17 | @staticmethod 18 | def finv(x): 19 | return x 20 | 21 | 22 | class LogitLink(Link): 23 | def __str__(self): 24 | return "logit" 25 | 26 | @staticmethod 27 | def f(x): 28 | return np.log(x/(1-x)) 29 | 30 | @staticmethod 31 | def finv(x): 32 | return 1/(1+np.exp(-x)) 33 | 34 | 35 | def convert_to_link(val): 36 | if isinstance(val, Link): 37 | return val 38 | elif val == "identity": 39 | return IdentityLink() 40 | elif val == "logit": 41 | return LogitLink() 42 | else: 43 | assert False, "Passed link object must be a subclass of iml.Link" 44 | -------------------------------------------------------------------------------- /datatypes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | try: 3 | import pandas as pd 4 | except ImportError: 5 | pass 6 | 7 | class Data: 8 | def __init__(self): 9 | pass 10 | 11 | 12 | class DenseData(Data): 13 | def __init__(self, data, group_names, *args): 14 | self.groups = args[0] if len(args) > 0 and args[0] != None else [np.array([i]) for i in range(len(group_names))] 15 | 16 | l = sum(len(g) for g in self.groups) 17 | num_samples = data.shape[0] 18 | t = False 19 | if l != data.shape[1]: 20 | t = True 21 | num_samples = data.shape[1] 22 | 23 | valid = (not t and l == data.shape[1]) or (t and l == data.shape[0]) 24 | assert valid, "# of names must match data matrix!" 25 | 26 | self.weights = args[1] if len(args) > 1 else np.ones(num_samples) 27 | self.weights /= np.sum(self.weights) 28 | wl = len(self.weights) 29 | valid = (not t and wl == data.shape[0]) or (t and wl == data.shape[1]) 30 | assert valid, "# weights must match data matrix!" 31 | 32 | self.transposed = t 33 | self.group_names = group_names 34 | self.data = data 35 | 36 | 37 | class DenseDataWithIndex(DenseData): 38 | def __init__(self, data, group_names, index, index_name, *args): 39 | DenseData.__init__(self, data, group_names, *args) 40 | self.index_value = index 41 | self.index_name = index_name 42 | 43 | def convert_to_df(self): 44 | data = pd.DataFrame(self.data, columns=self.group_names) 45 | index = pd.DataFrame(self.index_value, columns=[self.index_name]) 46 | df = pd.concat([index, data], axis=1) 47 | df = df.set_index(self.index_name) 48 | return df 49 | 50 | 51 | def convert_to_data(val, keep_index=False): 52 | if isinstance(val, Data): 53 | return val 54 | elif type(val) == np.ndarray: 55 | return DenseData(val, [str(i) for i in range(val.shape[1])]) 56 | elif str(type(val)).endswith("'pandas.core.series.Series'>"): 57 | return DenseData(val.as_matrix().reshape((1,len(val))), list(val.index)) 58 | elif str(type(val)).endswith("'pandas.core.frame.DataFrame'>"): 59 | if keep_index: 60 | return DenseDataWithIndex(val.as_matrix(), list(val.columns), val.index.values, val.index.name) 61 | else: 62 | return DenseData(val.as_matrix(), list(val.columns)) 63 | else: 64 | assert False, "Unknown type passed as data object: "+str(type(val)) 65 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | from datatypes import DenseData, DenseDataWithIndex 2 | import re 3 | import pandas as pd 4 | 5 | class Instance: 6 | def __init__(self, x, group_display_values): 7 | self.x = x 8 | self.group_display_values = group_display_values 9 | 10 | 11 | def convert_to_instance(val): 12 | if isinstance(val, Instance): 13 | return val 14 | else: 15 | return Instance(val, None) 16 | 17 | 18 | class InstanceWithIndex(Instance): 19 | def __init__(self, x, column_name, index_value, index_name, group_display_values): 20 | Instance.__init__(self, x, group_display_values) 21 | self.index_value = index_value 22 | self.index_name = index_name 23 | self.column_name = column_name 24 | 25 | def convert_to_df(self): 26 | index = pd.DataFrame(self.index_value, columns=[self.index_name]) 27 | data = pd.DataFrame(self.x, columns=self.column_name) 28 | df = pd.concat([index, data], axis=1) 29 | df = df.set_index(self.index_name) 30 | return df 31 | 32 | 33 | def convert_to_instance_with_index(val, column_name, index_value, index_name): 34 | return InstanceWithIndex(val, column_name, index_value, index_name, None) 35 | 36 | 37 | def match_instance_to_data(instance, data): 38 | assert isinstance(instance, Instance), "instance must be of type Instance!" 39 | 40 | if isinstance(data, DenseData): 41 | if instance.group_display_values is None: 42 | instance.group_display_values = [instance.x[0, group[0]] if len(group) == 1 else "" for group in data.groups] 43 | assert len(instance.group_display_values) == len(data.groups) 44 | instance.groups = data.groups 45 | 46 | 47 | class Model: 48 | def __init__(self, f, out_names): 49 | self.f = f 50 | self.out_names = out_names 51 | 52 | 53 | def convert_to_model(val): 54 | if isinstance(val, Model): 55 | return val 56 | else: 57 | return Model(val, None) 58 | 59 | 60 | def match_model_to_data(model, data): 61 | assert isinstance(model, Model), "model must be of type Model!" 62 | 63 | if isinstance(data, DenseData): 64 | try: 65 | if isinstance(data, DenseDataWithIndex): 66 | out_val = model.f(data.convert_to_df()) 67 | else: 68 | out_val = model.f(data.data) 69 | except: 70 | print("Provided model function fails when applied to the provided data set.") 71 | raise 72 | 73 | if model.out_names is None: 74 | if len(out_val.shape) == 1: 75 | model.out_names = ["output value"] 76 | else: 77 | model.out_names = ["output value "+str(i) for i in range(out_val.shape[0])] 78 | 79 | def verify_valid_cmap(cmap): 80 | assert (isinstance(cmap,str) 81 | or isinstance(cmap,list) 82 | ),"Plot color map must be string or list!" 83 | if isinstance(cmap,list): 84 | assert (len(cmap) > 1),"Color map must be at least two colors." 85 | _rgbstring = re.compile(r'#[a-fA-F0-9]{6}$') 86 | for color in cmap: 87 | assert(bool(_rgbstring.match(color))),"Invalid color found in CMAP." 88 | 89 | return cmap 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic SHAP Plots 2 | This project enables interactive plotting of the visualizations from the [SHAP](https://github.com/slundberg/shap) project. The plots show the relative importances of the feature variables in a dataset when making predictions on the target variable. 3 | 4 | The core functions to calculate the SHAP values are taken from the SHAP library, and modified to return the matplotlib figure objects instead of plotting them. The file `dynamic_shap_plots.py` binds them all together to produce the interactive visualizations with the Plotly library. 5 | 6 | ## Requirements: 7 | ``` 8 | shap 9 | plotly 10 | pandas 11 | sklearn 12 | matplotlib 13 | xgboost 14 | iml 15 | scipy 16 | numpy 17 | ``` 18 | This package has been built and tested on Windows 10 with Python 3.5. Slight modifications may be needed in case of errors when using in Linux or Mac OS. 19 | 20 | ## Some dynamic SHAP visualizations in Jupyter notebook: 21 | 22 | ### 1. Summary Plot: 23 | ``` 24 | from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot 25 | from dynamic_shap_plots import summary_plot_plotly_fig as sum_plot 26 | import warnings 27 | warnings.filterwarnings('ignore') 28 | 29 | plotly_fig = sum_plot(r'path\to\dataset.csv', target='target column') 30 | 31 | init_notebook_mode(connected=True) 32 | iplot(plotly_fig, show_link=False) 33 | ``` 34 | 35 | ![](https://user-images.githubusercontent.com/39755678/62591715-16cbb480-b903-11e9-818f-82ce793af4b1.png) 36 | 37 | To save the figure: 38 | ``` 39 | plot(plotly_fig, show_link=False, filename=r'path\to\save\figure.html') 40 | ``` 41 | 42 | ### 2. Dependence Plot: 43 | ``` 44 | from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot 45 | from dynamic_shap_plots import dependence_plot_to_plotly_fig as dep_plot 46 | from shap_plots import shap_summary_plot, shap_dependence_plot 47 | import warnings 48 | warnings.filterwarnings('ignore') 49 | 50 | lis, features = dep_plot(r'path\to\dataset.csv', target='target column', max_display=20) 51 | 52 | init_notebook_mode(connected=True) 53 | for i in range(len(lis)): 54 | iplot(lis[i], show_link=False) 55 | ``` 56 | Alternately, you can also plot for specific features: 57 | 58 | ``` 59 | >>> features.index('Q2FC - Timeliness of billing notices/statements') 60 | 15 61 | >>> iplot(lis[15], show_link=False) 62 | ``` 63 | 64 | ![](https://user-images.githubusercontent.com/39755678/62591656-e257f880-b902-11e9-8a44-d5f75ad2304e.png) 65 | 66 | ### 3. Interaction Plot: 67 | ``` 68 | from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot 69 | from dynamic_shap_plots import dependence_plot_to_plotly_fig as dep_plot 70 | from dynamic_shap_plots import interaction_plot_to_plotly_fig as int_plot 71 | from shap_plots import shap_summary_plot, shap_dependence_plot 72 | import warnings 73 | warnings.filterwarnings('ignore') 74 | 75 | lis, features = int_plot(r'path\to\dataset.csv', target='target column', max_display=20) 76 | 77 | init_notebook_mode(connected=True) 78 | for i in range(len(lis)): 79 | iplot(lis[i], show_link=False) 80 | ``` 81 | Alternately, you can also plot for specific features: 82 | 83 | ``` 84 | >>> features.index('QCF - Caring company') 85 | 262 86 | >>> iplot(lis[262], show_link=False) 87 | ``` 88 | 89 | ![](https://user-images.githubusercontent.com/39755678/62591749-39f66400-b903-11e9-9400-5c0eaec4c35d.png) 90 | -------------------------------------------------------------------------------- /dynamic_shap_plots.py: -------------------------------------------------------------------------------- 1 | from shap_plots import shap_summary_plot, shap_dependence_plot 2 | import plotly.tools as tls 3 | import dash_core_components as dcc 4 | import pandas as pd 5 | from sklearn.model_selection import train_test_split 6 | import numpy as np 7 | import xgboost 8 | import shap 9 | import matplotlib 10 | import plotly.graph_objs as go 11 | try: 12 | import matplotlib.pyplot as pl 13 | from matplotlib.colors import LinearSegmentedColormap 14 | from matplotlib.ticker import MaxNLocator 15 | except ImportError: 16 | pass 17 | from sklearn import preprocessing 18 | 19 | cdict1 = { 20 | 'red': ((0.0, 0.11764705882352941, 0.11764705882352941), 21 | (1.0, 0.9607843137254902, 0.9607843137254902)), 22 | 23 | 'green': ((0.0, 0.5333333333333333, 0.5333333333333333), 24 | (1.0, 0.15294117647058825, 0.15294117647058825)), 25 | 26 | 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745), 27 | (1.0, 0.3411764705882353, 0.3411764705882353)), 28 | 29 | 'alpha': ((0.0, 1, 1), 30 | (0.5, 1, 1), 31 | (1.0, 1, 1)) 32 | } # #1E88E5 -> #ff0052 33 | red_blue = LinearSegmentedColormap('RedBlue', cdict1) 34 | 35 | def matplotlib_to_plotly(cmap, pl_entries): 36 | h = 1.0/(pl_entries-1) 37 | pl_colorscale = [] 38 | 39 | for k in range(pl_entries): 40 | C = list(map(np.uint8, np.array(cmap(k*h)[:3])*255)) 41 | pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))]) 42 | 43 | return pl_colorscale 44 | 45 | red_blue = matplotlib_to_plotly(red_blue, 255) 46 | 47 | def summary_plot_plotly_fig(dataset, target='target column', max_display = 20): 48 | data = pd.read_csv(dataset, encoding="ISO-8859-1") 49 | X = data.drop(['target column'], axis=1) 50 | 51 | y = data[target] 52 | y = y/max(y) 53 | 54 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7) 55 | 56 | X_train.fillna((-999), inplace=True) 57 | X_test.fillna((-999), inplace=True) 58 | 59 | _, shap_values, feature_names = train_model_and_return_shap_values(X, y, target) 60 | 61 | mpl_fig = shap_summary_plot(shap_values, pd.DataFrame(X_train, columns=X.columns), feature_names=feature_names, max_display=20) 62 | 63 | plotly_fig = tls.mpl_to_plotly(mpl_fig) 64 | 65 | plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}} 66 | 67 | feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1]) 68 | feature_order = feature_order[-min(max_display, len(feature_order)):] 69 | text = [feature_names[i] for i in feature_order] 70 | text = iter(text) 71 | 72 | for i in range(1, len(plotly_fig['data']), 2): 73 | t = text.__next__() 74 | plotly_fig['data'][i]['name'] = '' 75 | plotly_fig['data'][i]['text'] = t 76 | plotly_fig['data'][i]['hoverinfo'] = 'text' 77 | 78 | colorbar_trace = go.Scatter(x=[None], 79 | y=[None], 80 | mode='markers', 81 | marker=dict( 82 | colorscale=red_blue, 83 | showscale=True, 84 | cmin=-5, 85 | cmax=5, 86 | colorbar=dict(thickness=5, tickvals=[-5, 5], ticktext=['Low', 'High'], outlinewidth=0) 87 | ), 88 | hoverinfo='none' 89 | ) 90 | 91 | plotly_fig['layout']['showlegend'] = False 92 | plotly_fig['layout']['hovermode'] = 'closest' 93 | plotly_fig['layout']['height']=600 94 | plotly_fig['layout']['width']=500 95 | 96 | plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True, ticklen=4, showgrid=False) 97 | plotly_fig['layout']['yaxis'].update(dict(visible=False)) 98 | plotly_fig.add_trace(colorbar_trace) 99 | plotly_fig.layout.update( 100 | annotations=[dict( 101 | x=1.18, 102 | align="right", 103 | valign="top", 104 | text='Feature value', 105 | showarrow=False, 106 | xref="paper", 107 | yref="paper", 108 | xanchor="right", 109 | yanchor="middle", 110 | textangle=-90, 111 | font=dict(family='Calibri', size=14) 112 | ) 113 | ], 114 | margin=dict(t=20) 115 | ) 116 | return plotly_fig 117 | 118 | def train_model_and_return_shap_values(X, y, target): 119 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7) 120 | 121 | X_train.fillna((-999), inplace=True) 122 | X_test.fillna((-999), inplace=True) 123 | 124 | # Some of values are float or integer and some object. This is why we need to cast them: 125 | for f in X_train.columns: 126 | if X_train[f].dtype=='object': 127 | lbl = preprocessing.LabelEncoder() 128 | lbl.fit(list(X_train[f].values)) 129 | X_train[f] = lbl.transform(list(X_train[f].values)) 130 | 131 | for f in X_test.columns: 132 | if X_test[f].dtype=='object': 133 | lbl = preprocessing.LabelEncoder() 134 | lbl.fit(list(X_test[f].values)) 135 | X_test[f] = lbl.transform(list(X_test[f].values)) 136 | 137 | X_train=np.array(X_train) 138 | X_test=np.array(X_test) 139 | X_train = X_train.astype(float) 140 | X_test = X_test.astype(float) 141 | 142 | d_train = xgboost.DMatrix(X_train, label=y_train, feature_names=list(X)) 143 | d_test = xgboost.DMatrix(X_test, label=y_test, feature_names=list(X)) 144 | 145 | # train the model 146 | params = { 147 | "eta": 0.01, 148 | "subsample": 0.5, 149 | "base_score": np.mean(y_train), 150 | "silent": 1 151 | } 152 | 153 | model = xgboost.train(params, d_train, 5000, evals = [(d_test, "test")], verbose_eval=None, early_stopping_rounds=50) 154 | feature_names = model.feature_names 155 | shap_values = shap.TreeExplainer(model).shap_values(pd.DataFrame(X_train, columns=X.columns)) 156 | return model, shap_values, feature_names 157 | 158 | def dependence_plot_to_plotly_fig(dataset, target='target column', max_display=10): 159 | data = pd.read_csv(dataset, encoding="ISO-8859-1") 160 | X = data.drop(['target column'], axis=1) 161 | y = data[target] 162 | y = y/max(y) 163 | 164 | xgb_full = xgboost.DMatrix(X, label=y) 165 | 166 | # create a train/test split 167 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7) 168 | xgb_train = xgboost.DMatrix(X_train, label=y_train) 169 | xgb_test = xgboost.DMatrix(X_test, label=y_test) 170 | 171 | # use validation set to choose # of trees 172 | params = { 173 | # "eta": 0.002, 174 | # "max_depth": 3, 175 | # "subsample": 0.5, 176 | "silent": 1 177 | } 178 | model_train = xgboost.train(params, xgb_train, 3000, evals = [(xgb_test, "test")], verbose_eval=None) 179 | 180 | # train final model on the full data set 181 | params = { 182 | # "eta": 0.002, 183 | # "max_depth": 3, 184 | # "subsample": 0.5, 185 | "silent": 1 186 | } 187 | model = xgboost.train(params, xgb_full, 1500, evals = [(xgb_full, "test")], verbose_eval=None) 188 | features = model.feature_names 189 | shap_values = shap.TreeExplainer(model).shap_values(X) 190 | 191 | feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1]) 192 | feature_order = feature_order[-min(max_display, len(feature_order)):] 193 | features = [features[i] for i in feature_order[::-1]] 194 | 195 | lis = [] 196 | for i in features: 197 | mpl_fig, interaction_index = shap_dependence_plot(i, shap_values, X) 198 | plotly_fig = tls.mpl_to_plotly(mpl_fig) 199 | 200 | # The x-tick labels start by default from 0, which is not necessarily the min value of the feature. 201 | # So, we need to increment the x-tick labels by 1. But while doing so, the y-axis gets shifted. 202 | # To prevent that, we need to manually control the x-axis range from r_min to r_max 203 | new_x = [] 204 | for j in plotly_fig['data'][0]['x']: 205 | new_x.append(j) 206 | 207 | r_min = min(plotly_fig['data'][0]['x']) 208 | r_max = max(plotly_fig['data'][0]['x']) 209 | 210 | plotly_fig['layout']['xaxis'].update(range=[r_min-1, r_max+1]) 211 | plotly_fig['data'][0]['x'] = tuple(new_x) 212 | 213 | # Define the colorbar 214 | colorbar_trace = go.Scatter(x=[None], 215 | y=[None], 216 | mode='markers', 217 | marker=dict( 218 | colorscale=red_blue, 219 | showscale=True, 220 | colorbar=dict(thickness=5, outlinewidth=0), 221 | color=[min(X[X.columns[interaction_index]]), max(X[X.columns[interaction_index]])], 222 | ), 223 | hoverinfo='none' 224 | ) 225 | 226 | plotly_fig['layout']['showlegend'] = False 227 | plotly_fig['layout']['hovermode'] = 'closest' 228 | plotly_fig['layout']['height']=380 229 | plotly_fig['layout']['width']=450 230 | plotly_fig['layout']['xaxis'].update(zeroline=True, 231 | showline=True, 232 | ticklen=4, 233 | showgrid=False, 234 | tickmode='linear') 235 | title = plotly_fig['layout']['yaxis']['title'] 236 | plotly_fig['layout']['yaxis'].update(title=title.split(' -')[0]) 237 | 238 | plotly_fig.add_trace(colorbar_trace) 239 | plotly_fig.layout.update( 240 | annotations=[dict( 241 | x=1.23, 242 | align="right", 243 | valign="top", 244 | text=X.columns[interaction_index], 245 | showarrow=False, 246 | xref="paper", 247 | yref="paper", 248 | xanchor="right", 249 | yanchor="middle", 250 | textangle=-90, 251 | font=dict(family='Calibri', size=14) 252 | ) 253 | ], 254 | margin=dict(t=50, b=50, l=50, r=80) 255 | ) 256 | lis.append(plotly_fig) 257 | return lis, features 258 | 259 | def interaction_plot_to_plotly_fig(dataset, target_col='target column', max_display=10): 260 | data = pd.read_csv(dataset, encoding="ISO-8859-1") 261 | X = data.drop(['target column'], axis=1) 262 | y = data[target_col] 263 | y = y/max(y) 264 | 265 | xgb_full = xgboost.DMatrix(X, label=y) 266 | 267 | # create a train/test split 268 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7) 269 | xgb_train = xgboost.DMatrix(X_train, label=y_train) 270 | xgb_test = xgboost.DMatrix(X_test, label=y_test) 271 | 272 | # use validation set to choose # of trees 273 | params = { 274 | # "eta": 0.002, 275 | # "max_depth": 3, 276 | # "subsample": 0.5, 277 | "silent": 1 278 | } 279 | model_train = xgboost.train(params, xgb_train, 3000, evals = [(xgb_test, "test")], verbose_eval=None) 280 | 281 | # train final model on the full data set 282 | params = { 283 | # "eta": 0.002, 284 | # "max_depth": 3, 285 | # "subsample": 0.5, 286 | "silent": 1 287 | } 288 | model = xgboost.train(params, xgb_full, 1500, evals = [(xgb_full, "test")], verbose_eval=None) 289 | features = model.feature_names 290 | shap_values = shap.TreeExplainer(model).shap_values(X) 291 | 292 | feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1]) 293 | feature_order = feature_order[-min(max_display, len(feature_order)):] 294 | features = [features[i] for i in feature_order[::-1]] 295 | 296 | shap_interaction_values = shap.TreeExplainer(model).shap_interaction_values(X) 297 | 298 | lis = [] 299 | for i in features: 300 | for j in features: 301 | mpl_fig = pl.figure() 302 | ax = mpl_fig.add_subplot(111) 303 | _, interaction_index = shap_dependence_plot ( (i, j), shap_interaction_values, X.iloc[:2000,:] ) 304 | plotly_fig = tls.mpl_to_plotly(mpl_fig) 305 | 306 | r_min = min(plotly_fig['data'][0]['x']) 307 | r_max = max(plotly_fig['data'][0]['x']) 308 | 309 | plotly_fig['layout']['xaxis'].update(range=[r_min-1, r_max+1]) 310 | plotly_fig['layout']['showlegend'] = False 311 | plotly_fig['layout']['hovermode'] = 'closest' 312 | plotly_fig['layout']['height']=380 313 | plotly_fig['layout']['width']=450 314 | plotly_fig['layout']['xaxis'].update(zeroline=True, 315 | showline=True, 316 | ticklen=4, 317 | showgrid=False, 318 | tickmode='linear') 319 | plotly_fig['layout']['yaxis'].update(showline=True) 320 | 321 | if i!=j: 322 | # plotly_fig['layout']['height']=380 323 | plotly_fig['layout']['width']=480 324 | plotly_fig['layout']['yaxis']['title'] = "SHAP interaction value for {} and {}".format(i.split('-')[0], j.split('-')[0]) 325 | # Define the colorbar 326 | colorbar_trace = go.Scatter(x=[None], 327 | y=[None], 328 | mode='markers', 329 | marker=dict( 330 | colorscale=red_blue, 331 | showscale=True, 332 | colorbar=dict(thickness=5, outlinewidth=0), 333 | color=[min(X[X.columns[interaction_index]]), max(X[X.columns[interaction_index]])], 334 | ), 335 | hoverinfo='none' 336 | ) 337 | plotly_fig.add_trace(colorbar_trace) 338 | plotly_fig.layout.update( 339 | annotations=[dict( 340 | x=1.23, 341 | align="right", 342 | valign="top", 343 | text=X.columns[interaction_index], 344 | showarrow=False, 345 | xref="paper", 346 | yref="paper", 347 | xanchor="right", 348 | yanchor="middle", 349 | textangle=-90, 350 | font=dict(family='Calibri', size=14) 351 | ) 352 | ], 353 | margin=dict(t=30, b=30, l=60, r=80) 354 | ) 355 | else: 356 | plotly_fig['layout']['yaxis']['title'] = "SHAP main effect value for {}".format(i.split('-')[0]) 357 | lis.append(plotly_fig) 358 | return lis, features 359 | -------------------------------------------------------------------------------- /shap_plots.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import iml 3 | import numpy as np 4 | from iml import Instance, Model 5 | from iml.datatypes import DenseData 6 | from iml.explanations import AdditiveExplanation 7 | from iml.links import IdentityLink 8 | from scipy.stats import gaussian_kde 9 | import matplotlib 10 | try: 11 | import matplotlib.pyplot as pl 12 | from matplotlib.colors import LinearSegmentedColormap 13 | from matplotlib.ticker import MaxNLocator 14 | 15 | cdict1 = { 16 | 'red': ((0.0, 0.11764705882352941, 0.11764705882352941), 17 | (1.0, 0.9607843137254902, 0.9607843137254902)), 18 | 19 | 'green': ((0.0, 0.5333333333333333, 0.5333333333333333), 20 | (1.0, 0.15294117647058825, 0.15294117647058825)), 21 | 22 | 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745), 23 | (1.0, 0.3411764705882353, 0.3411764705882353)), 24 | 25 | 'alpha': ((0.0, 1, 1), 26 | (0.5, 0.3, 0.3), 27 | (1.0, 1, 1)) 28 | } # #1E88E5 -> #ff0052 29 | red_blue = LinearSegmentedColormap('RedBlue', cdict1) 30 | 31 | cdict1 = { 32 | 'red': ((0.0, 0.11764705882352941, 0.11764705882352941), 33 | (1.0, 0.9607843137254902, 0.9607843137254902)), 34 | 35 | 'green': ((0.0, 0.5333333333333333, 0.5333333333333333), 36 | (1.0, 0.15294117647058825, 0.15294117647058825)), 37 | 38 | 'blue': ((0.0, 0.8980392156862745, 0.8980392156862745), 39 | (1.0, 0.3411764705882353, 0.3411764705882353)), 40 | 41 | 'alpha': ((0.0, 1, 1), 42 | (0.5, 1, 1), 43 | (1.0, 1, 1)) 44 | } # #1E88E5 -> #ff0052 45 | red_blue_solid = LinearSegmentedColormap('RedBlue', cdict1) 46 | except ImportError: 47 | pass 48 | 49 | labels = { 50 | 'MAIN_EFFECT': "SHAP main effect value for\n%s", 51 | 'INTERACTION_VALUE': "SHAP interaction value", 52 | 'INTERACTION_EFFECT': "SHAP interaction value for\n%s and %s", 53 | 'VALUE': "SHAP value (impact on model output)", 54 | 'VALUE_FOR': "SHAP value for\n%s", 55 | 'PLOT_FOR': "SHAP plot for %s", 56 | 'FEATURE': "Feature %s", 57 | 'FEATURE_VALUE': "Feature value", 58 | 'FEATURE_VALUE_LOW': "Low", 59 | 'FEATURE_VALUE_HIGH': "High", 60 | 'JOINT_VALUE': "Joint SHAP value" 61 | } 62 | 63 | def shap_summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type="dot", 64 | color=None, axis_color="#333333", title=None, alpha=1, show=True, sort=True, 65 | color_bar=True, auto_size_plot=True, layered_violin_max_num_bins=20): 66 | """Create a SHAP summary plot, colored by feature values when they are provided. 67 | 68 | Parameters 69 | ---------- 70 | shap_values : numpy.array 71 | Matrix of SHAP values (# samples x # features) 72 | 73 | features : numpy.array or pandas.DataFrame or list 74 | Matrix of feature values (# samples x # features) or a feature_names list as shorthand 75 | 76 | feature_names : list 77 | Names of the features (length # features) 78 | 79 | max_display : int 80 | How many top features to include in the plot (default is 20, or 7 for interaction plots) 81 | 82 | plot_type : "dot" (default) or "violin" 83 | What type of summary plot to produce 84 | """ 85 | 86 | assert len(shap_values.shape) != 1, "Summary plots need a matrix of shap_values, not a vector." 87 | 88 | # default color: 89 | if color is None: 90 | color = "coolwarm" if plot_type == 'layered_violin' else "#ff0052" 91 | 92 | # convert from a DataFrame or other types 93 | if str(type(features)) == "": 94 | if feature_names is None: 95 | feature_names = features.columns 96 | features = features.values 97 | elif str(type(features)) == "": 98 | if feature_names is None: 99 | feature_names = features 100 | features = None 101 | elif (features is not None) and len(features.shape) == 1 and feature_names is None: 102 | feature_names = features 103 | features = None 104 | 105 | if feature_names is None: 106 | feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)] 107 | 108 | mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1)) 109 | 110 | # plotting SHAP interaction values 111 | if len(shap_values.shape) == 3: 112 | if max_display is None: 113 | max_display = 7 114 | else: 115 | max_display = min(len(feature_names), max_display) 116 | 117 | sort_inds = np.argsort(-np.abs(shap_values[:, :-1, :-1].sum(1)).sum(0)) 118 | 119 | # get plotting limits 120 | delta = 1.0 / (shap_values.shape[1] ** 2) 121 | slow = np.nanpercentile(shap_values, delta) 122 | shigh = np.nanpercentile(shap_values, 100 - delta) 123 | v = max(abs(slow), abs(shigh)) 124 | slow = -0.2 125 | shigh = 0.2 126 | 127 | # mpl_fig = pl.figure(figsize=(1.5 * max_display + 1, 1 * max_display + 1)) 128 | ax = mpl_fig.subplot(1, max_display, 1) 129 | proj_shap_values = shap_values[:, sort_inds[0], np.hstack((sort_inds, len(sort_inds)))] 130 | proj_shap_values[:, 1:] *= 2 # because off diag effects are split in half 131 | shap_summary_plot( 132 | proj_shap_values, features[:, sort_inds], 133 | feature_names=feature_names[sort_inds], 134 | sort=False, show=False, color_bar=False, 135 | auto_size_plot=False, 136 | max_display=max_display 137 | ) 138 | pl.xlim((slow, shigh)) 139 | pl.xlabel("") 140 | title_length_limit = 11 141 | pl.title(shorten_text(feature_names[sort_inds[0]], title_length_limit)) 142 | for i in range(1, max_display): 143 | ind = sort_inds[i] 144 | pl.subplot(1, max_display, i + 1) 145 | proj_shap_values = shap_values[:, ind, np.hstack((sort_inds, len(sort_inds)))] 146 | proj_shap_values *= 2 147 | proj_shap_values[:, i] /= 2 # because only off diag effects are split in half 148 | shap_summary_plot( 149 | proj_shap_values, features[:, sort_inds], 150 | sort=False, 151 | feature_names=["" for i in range(features.shape[1])], 152 | show=False, 153 | color_bar=False, 154 | auto_size_plot=False, 155 | max_display=max_display 156 | ) 157 | pl.xlim((slow, shigh)) 158 | pl.xlabel("") 159 | if i == max_display // 2: 160 | pl.xlabel(labels['INTERACTION_VALUE']) 161 | pl.title(shorten_text(feature_names[ind], title_length_limit)) 162 | pl.tight_layout(pad=0, w_pad=0, h_pad=0.0) 163 | pl.subplots_adjust(hspace=0, wspace=0.1) 164 | # if show: 165 | # # pl.show() 166 | return mpl_fig 167 | 168 | if max_display is None: 169 | max_display = 20 170 | 171 | if sort: 172 | # order features by the sum of their effect magnitudes 173 | feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1]) 174 | feature_order = feature_order[-min(max_display, len(feature_order)):] 175 | else: 176 | feature_order = np.flip(np.arange(min(max_display, shap_values.shape[1] - 1)), 0) 177 | 178 | row_height = 0.4 179 | if auto_size_plot: 180 | pl.gcf().set_size_inches(8, len(feature_order) * row_height + 1.5) 181 | pl.axvline(x=0, color="#999999", zorder=-1) 182 | 183 | if plot_type == "dot": 184 | for pos, i in enumerate(feature_order): 185 | pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1) 186 | shaps = shap_values[:, i] 187 | values = None if features is None else features[:, i] 188 | inds = np.arange(len(shaps)) 189 | np.random.shuffle(inds) 190 | if values is not None: 191 | values = values[inds] 192 | shaps = shaps[inds] 193 | colored_feature = True 194 | try: 195 | values = np.array(values, dtype=np.float64) # make sure this can be numeric 196 | except: 197 | colored_feature = False 198 | N = len(shaps) 199 | # hspacing = (np.max(shaps) - np.min(shaps)) / 200 200 | # curr_bin = [] 201 | nbins = 100 202 | quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8)) 203 | inds = np.argsort(quant + np.random.randn(N) * 1e-6) 204 | layer = 0 205 | last_bin = -1 206 | ys = np.zeros(N) 207 | for ind in inds: 208 | if quant[ind] != last_bin: 209 | layer = 0 210 | ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1) 211 | layer += 1 212 | last_bin = quant[ind] 213 | ys *= 0.9 * (row_height / np.max(ys + 1)) 214 | 215 | if features is not None and colored_feature: 216 | # trim the color range, but prevent the color range from collapsing 217 | vmin = np.nanpercentile(values, 5) 218 | vmax = np.nanpercentile(values, 95) 219 | if vmin == vmax: 220 | vmin = np.nanpercentile(values, 1) 221 | vmax = np.nanpercentile(values, 99) 222 | if vmin == vmax: 223 | vmin = np.min(values) 224 | vmax = np.max(values) 225 | 226 | assert features.shape[0] == len(shaps), "Feature and SHAP matrices must have the same number of rows!" 227 | nan_mask = np.isnan(values) 228 | pl.scatter(shaps[nan_mask], pos + ys[nan_mask], color="#777777", vmin=vmin, 229 | vmax=vmax, s=16, alpha=alpha, linewidth=0, 230 | zorder=3, rasterized=len(shaps) > 500) 231 | pl.scatter(shaps[np.invert(nan_mask)], pos + ys[np.invert(nan_mask)], 232 | cmap=red_blue, vmin=vmin, vmax=vmax, s=16, 233 | c=values[np.invert(nan_mask)], alpha=alpha, linewidth=0, 234 | zorder=3, rasterized=len(shaps) > 500) 235 | else: 236 | 237 | pl.scatter(shaps, pos + ys, s=16, alpha=alpha, linewidth=0, zorder=3, 238 | color=color if colored_feature else "#777777", rasterized=len(shaps) > 500) 239 | 240 | elif plot_type == "violin": 241 | for pos, i in enumerate(feature_order): 242 | pl.axhline(y=pos, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1) 243 | 244 | if features is not None: 245 | global_low = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 1) 246 | global_high = np.nanpercentile(shap_values[:, :len(feature_names)].flatten(), 99) 247 | for pos, i in enumerate(feature_order): 248 | shaps = shap_values[:, i] 249 | shap_min, shap_max = np.min(shaps), np.max(shaps) 250 | rng = shap_max - shap_min 251 | xs = np.linspace(np.min(shaps) - rng * 0.2, np.max(shaps) + rng * 0.2, 100) 252 | if np.std(shaps) < (global_high - global_low) / 100: 253 | ds = gaussian_kde(shaps + np.random.randn(len(shaps)) * (global_high - global_low) / 100)(xs) 254 | else: 255 | ds = gaussian_kde(shaps)(xs) 256 | ds /= np.max(ds) * 3 257 | 258 | values = features[:, i] 259 | window_size = max(10, len(values) // 20) 260 | smooth_values = np.zeros(len(xs) - 1) 261 | sort_inds = np.argsort(shaps) 262 | trailing_pos = 0 263 | leading_pos = 0 264 | running_sum = 0 265 | back_fill = 0 266 | for j in range(len(xs) - 1): 267 | 268 | while leading_pos < len(shaps) and xs[j] >= shaps[sort_inds[leading_pos]]: 269 | running_sum += values[sort_inds[leading_pos]] 270 | leading_pos += 1 271 | if leading_pos - trailing_pos > 20: 272 | running_sum -= values[sort_inds[trailing_pos]] 273 | trailing_pos += 1 274 | if leading_pos - trailing_pos > 0: 275 | smooth_values[j] = running_sum / (leading_pos - trailing_pos) 276 | for k in range(back_fill): 277 | smooth_values[j - k - 1] = smooth_values[j] 278 | else: 279 | back_fill += 1 280 | 281 | vmin = np.nanpercentile(values, 5) 282 | vmax = np.nanpercentile(values, 95) 283 | if vmin == vmax: 284 | vmin = np.nanpercentile(values, 1) 285 | vmax = np.nanpercentile(values, 99) 286 | if vmin == vmax: 287 | vmin = np.min(values) 288 | vmax = np.max(values) 289 | pl.scatter(shaps, np.ones(shap_values.shape[0]) * pos, s=9, cmap=red_blue_solid, vmin=vmin, vmax=vmax, 290 | c=values, alpha=alpha, linewidth=0, zorder=1) 291 | # smooth_values -= nxp.nanpercentile(smooth_values, 5) 292 | # smooth_values /= np.nanpercentile(smooth_values, 95) 293 | smooth_values -= vmin 294 | if vmax - vmin > 0: 295 | smooth_values /= vmax - vmin 296 | for i in range(len(xs) - 1): 297 | if ds[i] > 0.05 or ds[i + 1] > 0.05: 298 | pl.fill_between([xs[i], xs[i + 1]], [pos + ds[i], pos + ds[i + 1]], 299 | [pos - ds[i], pos - ds[i + 1]], color=red_blue_solid(smooth_values[i]), 300 | zorder=2) 301 | 302 | else: 303 | parts = pl.violinplot(shap_values[:, feature_order], range(len(feature_order)), points=200, vert=False, 304 | widths=0.7, 305 | showmeans=False, showextrema=False, showmedians=False) 306 | 307 | for pc in parts['bodies']: 308 | pc.set_facecolor(color) 309 | pc.set_edgecolor('none') 310 | pc.set_alpha(alpha) 311 | 312 | elif plot_type == "layered_violin": # courtesy of @kodonnell 313 | num_x_points = 200 314 | bins = np.linspace(0, features.shape[0], layered_violin_max_num_bins + 1).round(0).astype( 315 | 'int') # the indices of the feature data corresponding to each bin 316 | shap_min, shap_max = np.min(shap_values[:, :-1]), np.max(shap_values[:, :-1]) 317 | x_points = np.linspace(shap_min, shap_max, num_x_points) 318 | 319 | # loop through each feature and plot: 320 | for pos, ind in enumerate(feature_order): 321 | # decide how to handle: if #unique < layered_violin_max_num_bins then split by unique value, otherwise use bins/percentiles. 322 | # to keep simpler code, in the case of uniques, we just adjust the bins to align with the unique counts. 323 | feature = features[:, ind] 324 | unique, counts = np.unique(feature, return_counts=True) 325 | if unique.shape[0] <= layered_violin_max_num_bins: 326 | order = np.argsort(unique) 327 | thesebins = np.cumsum(counts[order]) 328 | thesebins = np.insert(thesebins, 0, 0) 329 | else: 330 | thesebins = bins 331 | nbins = thesebins.shape[0] - 1 332 | # order the feature data so we can apply percentiling 333 | order = np.argsort(feature) 334 | # x axis is located at y0 = pos, with pos being there for offset 335 | y0 = np.ones(num_x_points) * pos 336 | # calculate kdes: 337 | ys = np.zeros((nbins, num_x_points)) 338 | for i in range(nbins): 339 | # get shap values in this bin: 340 | shaps = shap_values[order[thesebins[i]:thesebins[i + 1]], ind] 341 | # if there's only one element, then we can't 342 | if shaps.shape[0] == 1: 343 | warnings.warn( 344 | "not enough data in bin #%d for feature %s, so it'll be ignored. Try increasing the number of records to plot." 345 | % (i, feature_names[ind])) 346 | # to ignore it, just set it to the previous y-values (so the area between them will be zero). Not ys is already 0, so there's 347 | # nothing to do if i == 0 348 | if i > 0: 349 | ys[i, :] = ys[i - 1, :] 350 | continue 351 | # save kde of them: note that we add a tiny bit of gaussian noise to avoid singular matrix errors 352 | ys[i, :] = gaussian_kde(shaps + np.random.normal(loc=0, scale=0.001, size=shaps.shape[0]))(x_points) 353 | # scale it up so that the 'size' of each y represents the size of the bin. For continuous data this will 354 | # do nothing, but when we've gone with the unqique option, this will matter - e.g. if 99% are male and 1% 355 | # female, we want the 1% to appear a lot smaller. 356 | size = thesebins[i + 1] - thesebins[i] 357 | bin_size_if_even = features.shape[0] / nbins 358 | relative_bin_size = size / bin_size_if_even 359 | ys[i, :] *= relative_bin_size 360 | # now plot 'em. We don't plot the individual strips, as this can leave whitespace between them. 361 | # instead, we plot the full kde, then remove outer strip and plot over it, etc., to ensure no 362 | # whitespace 363 | ys = np.cumsum(ys, axis=0) 364 | width = 0.8 365 | scale = ys.max() * 2 / width # 2 is here as we plot both sides of x axis 366 | for i in range(nbins - 1, -1, -1): 367 | y = ys[i, :] / scale 368 | c = pl.get_cmap(color)(i / ( 369 | nbins - 1)) if color in pl.cm.datad else color # if color is a cmap, use it, otherwise use a color 370 | pl.fill_between(x_points, pos - y, pos + y, facecolor=c) 371 | pl.xlim(shap_min, shap_max) 372 | 373 | # draw the color bar 374 | if color_bar and features is not None and (plot_type != "layered_violin" or color in pl.cm.datad): 375 | import matplotlib.cm as cm 376 | m = cm.ScalarMappable(cmap=red_blue_solid if plot_type != "layered_violin" else pl.get_cmap(color)) 377 | m.set_array([0, 1]) 378 | cb = pl.colorbar(m, ticks=[0, 1], aspect=1000) 379 | cb.set_ticklabels([labels['FEATURE_VALUE_LOW'], labels['FEATURE_VALUE_HIGH']]) 380 | cb.set_label(labels['FEATURE_VALUE'], size=12, labelpad=0) 381 | cb.ax.tick_params(labelsize=11, length=0) 382 | cb.set_alpha(1) 383 | cb.outline.set_visible(False) 384 | bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted()) 385 | cb.ax.set_aspect((bbox.height - 0.9) * 20) 386 | # cb.draw_all() 387 | 388 | pl.gca().xaxis.set_ticks_position('bottom') 389 | pl.gca().yaxis.set_ticks_position('none') 390 | pl.gca().spines['right'].set_visible(False) 391 | pl.gca().spines['top'].set_visible(False) 392 | pl.gca().spines['left'].set_visible(False) 393 | pl.gca().tick_params(color=axis_color, labelcolor=axis_color) 394 | pl.yticks(range(len(feature_order)), [feature_names[i] for i in feature_order], fontsize=13) 395 | pl.gca().tick_params('y', length=20, width=0.5, which='major') 396 | pl.gca().tick_params('x', labelsize=11) 397 | pl.ylim(-1, len(feature_order)) 398 | pl.xlabel(labels['VALUE'], fontsize=13) 399 | pl.tight_layout() 400 | # if show: 401 | # pl.show() 402 | return mpl_fig 403 | 404 | 405 | 406 | 407 | 408 | 409 | def approx_interactions(index, shap_values, X): 410 | """ Order other features by how much interaction they seem to have with the feature at the given index. 411 | 412 | This just bins the SHAP values for a feature along that feature's value. For true Shapley interaction 413 | index values for SHAP see the interaction_contribs option implemented in XGBoost. 414 | """ 415 | 416 | if X.shape[0] > 10000: 417 | a = np.arange(X.shape[0]) 418 | np.random.shuffle(a) 419 | inds = a[:10000] 420 | else: 421 | inds = np.arange(X.shape[0]) 422 | 423 | x = X[inds, index] 424 | srt = np.argsort(x) 425 | shap_ref = shap_values[inds, index] 426 | shap_ref = shap_ref[srt] 427 | inc = max(min(int(len(x) / 10.0), 50), 1) 428 | interactions = [] 429 | for i in range(X.shape[1]): 430 | val_other = X[inds, i][srt].astype(np.float) 431 | v = 0.0 432 | if not (i == index or np.sum(np.abs(val_other)) < 1e-8): 433 | for j in range(0, len(x), inc): 434 | if np.std(val_other[j:j + inc]) > 0 and np.std(shap_ref[j:j + inc]) > 0: 435 | v += abs(np.corrcoef(shap_ref[j:j + inc], val_other[j:j + inc])[0, 1]) 436 | interactions.append(v) 437 | 438 | return np.argsort(-np.abs(interactions)) 439 | 440 | 441 | 442 | 443 | 444 | 445 | 446 | def shap_dependence_plot(ind, shap_values, features, feature_names=None, display_features=None, 447 | interaction_index="auto", color="#1E88E5", axis_color="#333333", 448 | dot_size=16, alpha=1, title=None, show=True): 449 | """ 450 | Create a SHAP dependence plot, colored by an interaction feature. 451 | 452 | Parameters 453 | ---------- 454 | ind : int 455 | Index of the feature to plot. 456 | 457 | shap_values : numpy.array 458 | Matrix of SHAP values (# samples x # features) 459 | 460 | features : numpy.array or pandas.DataFrame 461 | Matrix of feature values (# samples x # features) 462 | 463 | feature_names : list 464 | Names of the features (length # features) 465 | 466 | display_features : numpy.array or pandas.DataFrame 467 | Matrix of feature values for visual display (such as strings instead of coded values) 468 | 469 | interaction_index : "auto", None, or int 470 | The index of the feature used to color the plot. 471 | """ 472 | 473 | # convert from DataFrames if we got any 474 | if str(type(features)).endswith("'pandas.core.frame.DataFrame'>"): 475 | if feature_names is None: 476 | feature_names = features.columns 477 | features = features.values 478 | if str(type(display_features)).endswith("'pandas.core.frame.DataFrame'>"): 479 | if feature_names is None: 480 | feature_names = display_features.columns 481 | display_features = display_features.values 482 | elif display_features is None: 483 | display_features = features 484 | 485 | if feature_names is None: 486 | feature_names = [labels['FEATURE'] % str(i) for i in range(shap_values.shape[1] - 1)] 487 | 488 | # allow vectors to be passed 489 | if len(shap_values.shape) == 1: 490 | shap_values = np.reshape(shap_values, len(shap_values), 1) 491 | if len(features.shape) == 1: 492 | features = np.reshape(features, len(features), 1) 493 | 494 | def convert_name(ind): 495 | if type(ind) == str: 496 | nzinds = np.where(feature_names == ind)[0] 497 | if len(nzinds) == 0: 498 | print("Could not find feature named: " + ind) 499 | return None 500 | else: 501 | return nzinds[0] 502 | else: 503 | return ind 504 | 505 | ind = convert_name(ind) 506 | 507 | mpl_fig = pl.gcf() 508 | ax = mpl_fig.gca() 509 | 510 | # plotting SHAP interaction values 511 | if len(shap_values.shape) == 3 and len(ind) == 2: 512 | ind1 = convert_name(ind[0]) 513 | ind2 = convert_name(ind[1]) 514 | if ind1 == ind2: 515 | proj_shap_values = shap_values[:, ind2, :] 516 | else: 517 | proj_shap_values = shap_values[:, ind2, :] * 2 # off-diag values are split in half 518 | 519 | # TODO: remove recursion; generally the functions should be shorter for more maintainable code 520 | return shap_dependence_plot( 521 | ind1, proj_shap_values, features, feature_names=feature_names, 522 | interaction_index=ind2, display_features=display_features, show=False 523 | ) 524 | 525 | assert shap_values.shape[0] == features.shape[0], \ 526 | "'shap_values' and 'features' values must have the same number of rows!" 527 | assert shap_values.shape[1] == features.shape[1], \ 528 | "'shap_values' must have the same number of columns as 'features'!" 529 | 530 | # get both the raw and display feature values 531 | xv = features[:, ind] 532 | xd = display_features[:, ind] 533 | s = shap_values[:, ind] 534 | if type(xd[0]) == str: 535 | name_map = {} 536 | for i in range(len(xv)): 537 | name_map[xd[i]] = xv[i] 538 | xnames = list(name_map.keys()) 539 | 540 | # allow a single feature name to be passed alone 541 | if type(feature_names) == str: 542 | feature_names = [feature_names] 543 | name = feature_names[ind] 544 | 545 | # guess what other feature as the stongest interaction with the plotted feature 546 | if interaction_index == "auto": 547 | interaction_index = approx_interactions(ind, shap_values, features)[0] 548 | interaction_index = convert_name(interaction_index) 549 | categorical_interaction = False 550 | 551 | # get both the raw and display color values 552 | if interaction_index is not None: 553 | cv = features[:, interaction_index] 554 | cd = display_features[:, interaction_index] 555 | clow = np.nanpercentile(features[:, interaction_index].astype(np.float), 5) 556 | chigh = np.nanpercentile(features[:, interaction_index].astype(np.float), 95) 557 | if type(cd[0]) == str: 558 | cname_map = {} 559 | for i in range(len(cv)): 560 | cname_map[cd[i]] = cv[i] 561 | cnames = list(cname_map.keys()) 562 | categorical_interaction = True 563 | elif clow % 1 == 0 and chigh % 1 == 0 and len(set(features[:, interaction_index])) < 50: 564 | categorical_interaction = True 565 | 566 | # discritize colors for categorical features 567 | color_norm = None 568 | if categorical_interaction and clow != chigh: 569 | bounds = np.linspace(clow, chigh, chigh - clow + 2) 570 | color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N) 571 | 572 | # the actual scatter plot, TODO: adapt the dot_size to the number of data points? 573 | if interaction_index is not None: 574 | pl.scatter(xv, s, s=dot_size, linewidth=0, c=features[:, interaction_index], cmap=red_blue, 575 | alpha=alpha, vmin=clow, vmax=chigh, norm=color_norm, rasterized=len(xv) > 500) 576 | else: 577 | pl.scatter(xv, s, s=dot_size, linewidth=0, color="#1E88E5", 578 | alpha=alpha, rasterized=len(xv) > 500) 579 | 580 | if interaction_index != ind and interaction_index is not None: 581 | # draw the color bar 582 | if type(cd[0]) == str: 583 | tick_positions = [cname_map[n] for n in cnames] 584 | if len(tick_positions) == 2: 585 | tick_positions[0] -= 0.25 586 | tick_positions[1] += 0.25 587 | cb = pl.colorbar(ticks=tick_positions) 588 | cb.set_ticklabels(cnames) 589 | else: 590 | cb = pl.colorbar() 591 | 592 | cb.set_label(feature_names[interaction_index], size=13) 593 | cb.ax.tick_params(labelsize=11) 594 | if categorical_interaction: 595 | cb.ax.tick_params(length=0) 596 | cb.set_alpha(1) 597 | cb.outline.set_visible(False) 598 | bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted()) 599 | cb.ax.set_aspect((bbox.height - 0.7) * 20) 600 | 601 | # make the plot more readable 602 | if interaction_index != ind: 603 | pl.gcf().set_size_inches(7.5, 5) 604 | else: 605 | pl.gcf().set_size_inches(6, 5) 606 | # pl.xlabel(name, color=axis_color, fontsize=13) 607 | # pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13) 608 | if title is not None: 609 | pl.title(title, color=axis_color, fontsize=13) 610 | pl.gca().xaxis.set_ticks_position('bottom') 611 | pl.gca().yaxis.set_ticks_position('left') 612 | pl.gca().spines['right'].set_visible(False) 613 | pl.gca().spines['top'].set_visible(False) 614 | pl.gca().tick_params(color=axis_color, labelcolor=axis_color, labelsize=11) 615 | for spine in pl.gca().spines.values(): 616 | spine.set_edgecolor(axis_color) 617 | if type(xd[0]) == str: 618 | pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11) 619 | # if show: 620 | # pl.show() 621 | 622 | 623 | if ind1 == ind2: 624 | pl.ylabel(labels['MAIN_EFFECT'] % feature_names[ind1]) 625 | else: 626 | pl.ylabel(labels['INTERACTION_EFFECT'] % (feature_names[ind1], feature_names[ind2])) 627 | 628 | return mpl_fig, interaction_index 629 | 630 | 631 | # # if show: 632 | # # pl.show() 633 | # return 634 | # return mpl_fig 635 | 636 | # assert shap_values.shape[0] == features.shape[0], "'shap_values' and 'features' values must have the same number of rows!" 637 | # assert shap_values.shape[1] == features.shape[1] + 1, "'shap_values' must have one more column than 'features'!" 638 | 639 | # get both the raw and display feature values 640 | xv = features[:, ind] 641 | xd = display_features[:, ind] 642 | s = shap_values[:, ind] 643 | if type(xd[0]) == str: 644 | name_map = {} 645 | for i in range(len(xv)): 646 | name_map[xd[i]] = xv[i] 647 | xnames = list(name_map.keys()) 648 | 649 | # allow a single feature name to be passed alone 650 | if type(feature_names) == str: 651 | feature_names = [feature_names] 652 | name = feature_names[ind] 653 | 654 | # guess what other feature as the stongest interaction with the plotted feature 655 | if interaction_index == "auto": 656 | interaction_index = approx_interactions(ind, shap_values, features)[0] 657 | interaction_index = convert_name(interaction_index) 658 | categorical_interaction = False 659 | 660 | # get both the raw and display color values 661 | if interaction_index is not None: 662 | cv = features[:, interaction_index] 663 | cd = display_features[:, interaction_index] 664 | clow = np.nanpercentile(features[:, interaction_index].astype(np.float), 5) 665 | chigh = np.nanpercentile(features[:, interaction_index].astype(np.float), 95) 666 | if type(cd[0]) == str: 667 | cname_map = {} 668 | for i in range(len(cv)): 669 | cname_map[cd[i]] = cv[i] 670 | cnames = list(cname_map.keys()) 671 | categorical_interaction = True 672 | elif clow % 1 == 0 and chigh % 1 == 0 and len(set(features[:, interaction_index])) < 50: 673 | categorical_interaction = True 674 | 675 | # discritize colors for categorical features 676 | color_norm = None 677 | if categorical_interaction and clow != chigh: 678 | bounds = np.linspace(clow, chigh, chigh - clow + 2) 679 | color_norm = matplotlib.colors.BoundaryNorm(bounds, red_blue.N) 680 | 681 | # the actual scatter plot, TODO: adapt the dot_size to the number of data points? 682 | if interaction_index is not None: 683 | pl.scatter(xv, s, s=dot_size, linewidth=0, c=features[:, interaction_index], cmap=red_blue, 684 | alpha=alpha, vmin=clow, vmax=chigh, norm=color_norm, rasterized=len(xv) > 500) 685 | else: 686 | pl.scatter(xv, s, s=dot_size, linewidth=0, color="#1E88E5", 687 | alpha=alpha, rasterized=len(xv) > 500) 688 | 689 | if interaction_index != ind and interaction_index is not None: 690 | # draw the color bar 691 | if type(cd[0]) == str: 692 | tick_positions = [cname_map[n] for n in cnames] 693 | if len(tick_positions) == 2: 694 | tick_positions[0] -= 0.25 695 | tick_positions[1] += 0.25 696 | cb = pl.colorbar(ticks=tick_positions) 697 | cb.set_ticklabels(cnames) 698 | else: 699 | cb = pl.colorbar() 700 | 701 | cb.set_label(feature_names[interaction_index], size=13) 702 | cb.ax.tick_params(labelsize=11) 703 | if categorical_interaction: 704 | cb.ax.tick_params(length=0) 705 | cb.set_alpha(1) 706 | cb.outline.set_visible(False) 707 | bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted()) 708 | cb.ax.set_aspect((bbox.height - 0.7) * 20) 709 | 710 | # make the plot more readable 711 | if interaction_index != ind: 712 | pl.gcf().set_size_inches(7.5, 5) 713 | else: 714 | pl.gcf().set_size_inches(6, 5) 715 | pl.xlabel(name, color=axis_color, fontsize=13) 716 | pl.ylabel(labels['VALUE_FOR'] % name, color=axis_color, fontsize=13) 717 | if title is not None: 718 | pl.title(title, color=axis_color, fontsize=13) 719 | pl.gca().xaxis.set_ticks_position('bottom') 720 | pl.gca().yaxis.set_ticks_position('left') 721 | pl.gca().spines['right'].set_visible(False) 722 | pl.gca().spines['top'].set_visible(False) 723 | pl.gca().tick_params(color=axis_color, labelcolor=axis_color, labelsize=11) 724 | for spine in pl.gca().spines.values(): 725 | spine.set_edgecolor(axis_color) 726 | if type(xd[0]) == str: 727 | pl.xticks([name_map[n] for n in xnames], xnames, rotation='vertical', fontsize=11) 728 | # if show: 729 | # pl.show() 730 | return mpl_fig, interaction_index --------------------------------------------------------------------------------