├── .gitignore ├── fig ├── 2d-weightedshap.png ├── mnist-weightedshap.png └── inclusion-weightedshap.png ├── notebook ├── fraud_dataset.pkl └── fraud_example.pickle ├── weightedSHAP ├── third_party │ ├── __init__.py │ ├── behavior.py │ ├── image_imputers.py │ ├── resnet.py │ ├── utils.py │ ├── removal.py │ ├── surrogate.py │ └── image_surrogate.py ├── __init__.py ├── weightedSHAPEngine.py ├── utils.py ├── data.py └── train.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | .DS_Store 4 | .ipynb_checkpoints 5 | /*.ipynb 6 | *.pt 7 | clf_path 8 | -------------------------------------------------------------------------------- /fig/2d-weightedshap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ykwon0407/WeightedSHAP/HEAD/fig/2d-weightedshap.png -------------------------------------------------------------------------------- /fig/mnist-weightedshap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ykwon0407/WeightedSHAP/HEAD/fig/mnist-weightedshap.png -------------------------------------------------------------------------------- /notebook/fraud_dataset.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ykwon0407/WeightedSHAP/HEAD/notebook/fraud_dataset.pkl -------------------------------------------------------------------------------- /fig/inclusion-weightedshap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ykwon0407/WeightedSHAP/HEAD/fig/inclusion-weightedshap.png -------------------------------------------------------------------------------- /notebook/fraud_example.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ykwon0407/WeightedSHAP/HEAD/notebook/fraud_example.pickle -------------------------------------------------------------------------------- /weightedSHAP/third_party/__init__.py: -------------------------------------------------------------------------------- 1 | from . import behavior, image_imputers, image_surrogate, removal, resnet, surrogate, utils -------------------------------------------------------------------------------- /weightedSHAP/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data, third_party, train, utils, weightedSHAPEngine 2 | 3 | from .data import load_data 4 | from .train import create_model_to_explain, generate_coalition_function 5 | from .weightedSHAPEngine import compute_attributions -------------------------------------------------------------------------------- /weightedSHAP/third_party/behavior.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reference: https://github.com/iancovert/removal-explanations/blob/main/rexplain/behavior.py 3 | ''' 4 | import numpy as np 5 | 6 | class PredictionGame: 7 | ''' 8 | Cooperative game for an individual example's prediction. 9 | 10 | Args: 11 | extension: model extension (see removal.py). 12 | sample: numpy array representing a single model input. 13 | ''' 14 | 15 | def __init__(self, extension, sample, superpixel_size=1): 16 | # Add batch dimension to sample. 
17 | if sample.ndim == 1: 18 | sample = sample[np.newaxis] 19 | # elif sample.shape[0] != 1: 20 | # raise ValueError('sample must have shape (ndim,) or (1,ndim)') 21 | 22 | self.extension = extension 23 | self.sample = sample 24 | self.players = np.prod(sample.shape)//(superpixel_size**2)//sample.shape[0] # number of players (features, or superpixels for images) 25 | 26 | # Caching. 27 | self.sample_repeat = sample 28 | 29 | def __call__(self, S): 30 | # Return scalar if single subset. 31 | single_eval = (S.ndim == 1) 32 | if single_eval: 33 | S = S[np.newaxis] 34 | input_data = self.sample 35 | else: 36 | # Try to use caching for repeated data. 37 | if len(S) != len(self.sample_repeat): 38 | self.sample_repeat = self.sample.repeat(len(S), 0) 39 | input_data = self.sample_repeat 40 | 41 | # Evaluate. 42 | output = self.extension(input_data, S) 43 | if single_eval: 44 | output = output[0] 45 | return output 46 | 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WeightedSHAP: analyzing and improving Shapley based feature attributions 2 | 3 | This repository provides an implementation of the paper *[WeightedSHAP: analyzing and improving Shapley based feature attributions](https://arxiv.org/abs/2209.13429)* accepted at [NeurIPS 2022](https://nips.cc/Conferences/2022). We show the suboptimality of SHAP and propose **a new feature attribution method called WeightedSHAP**. WeightedSHAP is a generalization of SHAP and is more effective at capturing influential features. 4 | 5 | ### Quick start 6 | 7 | We provide an easy-to-follow [Jupyter notebook](notebook/Example_fraud_inclusion_AUC.ipynb), which shows how to compute WeightedSHAP on the Fraud dataset. 8 | 9 | ### Key results 10 | 11 |
![Suboptimality of SHAP when d=2](fig/2d-weightedshap.png)
14 | 15 | → Illustrations of the suboptimality of Shapley-based feature attributions (SHAP) when $d=2$. ***The Shapley value fails to assign large attributions to more influential features*** in the grey area. 16 | 17 |
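All attributions computed in this repository are semivalues, i.e., weighted averages of a feature's marginal contributions over coalition sizes (see `compute_weight_list` and `compute_semivalue_from_MC` in `weightedSHAP/utils.py`). As a reading aid (notation simplified; see the paper for the precise statement), with $d$ features and coalition value $v(S)$ given by the conditional prediction,

$$
\phi_j(w) \;=\; \sum_{k=0}^{d-1} w_k \, \mathbb{E}_{|S|=k,\; j\notin S}\big[\, v(S\cup\{j\}) - v(S) \,\big], \qquad \sum_{k=0}^{d-1} w_k = 1.
$$

SHAP is the special case of uniform weights $w_k = 1/d$ (the `Beta(1,1)` entry in the code), while WeightedSHAP selects the `Beta(alpha, beta)` weighting that minimizes the area between the prediction recovery curve and the original prediction (AUP), as implemented in `weightedSHAPEngine.compute_attributions`.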
![Prediction recovery error and Inclusion AUC curves](fig/inclusion-weightedshap.png)
20 | 21 | → Illustrations of the prediction recovery error curve and the Inclusion AUC curve as a function of the number of features added. ***WeightedSHAP effectively assigns larger values to more influential features*** and recovers the original prediction $\hat{f}(x)$ significantly faster than other state-of-the-art methods. 22 | 23 |
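These curves can be reproduced from any attribution vector by ranking features by absolute attribution, adding them back one at a time, and querying the coalition function for the conditional prediction. A minimal sketch (the helper below is hypothetical and mirrors `utils.compute_cond_pred_list`; `game` stands for a `behavior.PredictionGame`-style callable that returns the conditional prediction for a boolean coalition):

```python
import numpy as np

def recovery_curve(attributions, game, original_pred):
    '''Hypothetical helper: prediction recovery error as features are added, most important first.'''
    order = np.argsort(np.abs(attributions))[::-1]    # most influential features first
    n_features = len(attributions)
    errors = []
    for n_top in range(n_features + 1):
        S = np.zeros(n_features, dtype=bool)
        S[order[:n_top]] = True                       # keep only the top n_top features
        errors.append(abs(game(S) - original_pred))   # gap to the original prediction
    return np.array(errors)                           # summing this curve gives the AUP criterion
```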
![SHAP vs. WeightedSHAP attributions on an MNIST digit nine](fig/mnist-weightedshap.png)
26 | 27 | → ***WeightedSHAP can identify more interpretable features***. In particular, SHAP fails to capture the last stroke of digit nine, which is a crucially important stroke to differentiate from the digit zero. 28 | 29 | 30 | ### References 31 | 32 | This repository highly depends on the following two repositories. 33 | 34 | - Covert, I., Lundberg, S. M., & Lee, S. I. (2021). Explaining by Removing: A Unified Framework for Model Explanation. J. Mach. Learn. Res., 22, 209-1. [[GitHub]](https://github.com/iancovert/removal-explanations) 35 | 36 | - Jethani, N., Sudarshan, M., Covert, I. C., Lee, S. I., & Ranganath, R. (2021, September). FastSHAP: Real-Time Shapley Value Estimation. In International Conference on Learning Representations. [[GitHub]](https://github.com/iancovert/fastshap/tree/main/fastshap) 37 | 38 | ### Authors 39 | 40 | - Yongchan Kwon (yk3012 (at) columbia (dot) edu) 41 | 42 | - James Zou (jamesz (at) stanford (dot) edu) 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /weightedSHAP/third_party/image_imputers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reference: https://github.com/iancovert/fastshap/blob/main/fastshap/image_imputers.py 3 | ''' 4 | import torch.nn as nn 5 | 6 | class ImageImputer: 7 | ''' 8 | Image imputer base class. 9 | Args: 10 | width: image width. 11 | height: image height. 12 | superpixel_size: superpixel width/height (int). 13 | ''' 14 | 15 | def __init__(self, width, height, superpixel_size=1): 16 | # Verify arguments. 17 | assert width % superpixel_size == 0 18 | assert height % superpixel_size == 0 19 | 20 | # Set up superpixel upsampling. 21 | self.width = width 22 | self.height = height 23 | self.supsize = superpixel_size 24 | if superpixel_size == 1: 25 | self.upsample = nn.Identity() 26 | else: 27 | self.upsample = nn.Upsample( 28 | scale_factor=superpixel_size, mode='nearest') 29 | 30 | # Set up number of players. 31 | self.small_width = width // superpixel_size 32 | self.small_height = height // superpixel_size 33 | self.num_players = self.small_width * self.small_height 34 | 35 | def __call__(self, x, S): 36 | ''' 37 | Evaluate with subset of features. 38 | Args: 39 | x: input examples. 40 | S: coalitions. 41 | ''' 42 | raise NotImplementedError 43 | 44 | def resize(self, S): 45 | ''' 46 | Resize coalition variable S into grid shape. 47 | Args: 48 | S: coalitions. 49 | ''' 50 | if len(S.shape) == 2: 51 | S = S.reshape(S.shape[0], self.small_height, 52 | self.small_width).unsqueeze(1) 53 | return self.upsample(S) 54 | 55 | 56 | class BaselineImageImputer(ImageImputer): 57 | ''' 58 | Evaluate image model while replacing features with baseline values. 59 | Args: 60 | model: predictive model. 61 | baseline: baseline value(s). 62 | width: image width. 63 | height: image height. 64 | superpixel_size: superpixel width/height (int). 65 | link: link function (e.g., nn.Softmax). 66 | ''' 67 | 68 | def __init__(self, model, baseline, width, height, superpixel_size, 69 | link=None): 70 | super().__init__(width, height, superpixel_size) 71 | self.model = model 72 | self.baseline = baseline 73 | 74 | # Set up link. 75 | if link is None: 76 | self.link = nn.Identity() 77 | elif isinstance(link, nn.Module): 78 | self.link = link 79 | else: 80 | raise ValueError('unsupported link function: {}'.format(link)) 81 | 82 | def __call__(self, x, S): 83 | ''' 84 | Evaluate model using baseline values. 
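        Args:
            x: input examples.
            S: coalitions.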
85 | ''' 86 | S = self.resize(S) 87 | x_baseline = S * x + (1 - S) * self.baseline 88 | return self.link(self.model(x_baseline)) 89 | 90 | -------------------------------------------------------------------------------- /weightedSHAP/third_party/resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ResNet in PyTorch. 3 | This implementation is based on kuangliu's code 4 | https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 5 | and the PyTorch reference implementation 6 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 7 | ''' 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, in_planes, planes, stride=1): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 20 | stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != self.expansion*planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion*planes, 27 | kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 47 | stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion * 50 | planes, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 52 | 53 | self.shortcut = nn.Sequential() 54 | if stride != 1 or in_planes != self.expansion*planes: 55 | self.shortcut = nn.Sequential( 56 | nn.Conv2d(in_planes, self.expansion*planes, 57 | kernel_size=1, stride=stride, bias=False), 58 | nn.BatchNorm2d(self.expansion*planes) 59 | ) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = F.relu(self.bn2(self.conv2(out))) 64 | out = self.bn3(self.conv3(out)) 65 | out += self.shortcut(x) 66 | out = F.relu(out) 67 | return out 68 | 69 | 70 | class ResNet(nn.Module): 71 | def __init__(self, block, num_blocks, num_classes, in_channels): 72 | super(ResNet, self).__init__() 73 | self.in_planes = 64 74 | 75 | # Input conv. 76 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, 77 | stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(64) 79 | 80 | # Residual blocks. 81 | channels = 64 82 | stride = 1 83 | blocks = [] 84 | for num in num_blocks: 85 | blocks.append(self._make_layer(block, channels, num, stride=stride)) 86 | channels *= 2 87 | stride = 2 88 | self.layers = nn.ModuleList(blocks) 89 | 90 | # Output layer. 
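        # If num_classes is None, the pooling and linear head are skipped in forward(),
        # so the network can be used as a convolutional feature extractor.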
91 | self.num_classes = num_classes 92 | if num_classes is not None: 93 | self.linear = nn.Linear(512*block.expansion, num_classes) 94 | 95 | def _make_layer(self, block, planes, num_blocks, stride): 96 | strides = [stride] + [1]*(num_blocks-1) 97 | layers = [] 98 | for stride in strides: 99 | layers.append(block(self.in_planes, planes, stride)) 100 | self.in_planes = planes * block.expansion 101 | return nn.Sequential(*layers) 102 | 103 | def forward(self, x): 104 | # Input conv. 105 | out = F.relu(self.bn1(self.conv1(x))) 106 | 107 | # Residual blocks. 108 | for layer in self.layers: 109 | out = layer(out) 110 | 111 | # Output layer. 112 | if self.num_classes is not None: 113 | out = F.avg_pool2d(out, 4) 114 | out = out.view(out.size(0), -1) 115 | out = self.linear(out) 116 | 117 | return out 118 | 119 | 120 | def ResNet18(num_classes, in_channels=3): 121 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, in_channels) 122 | 123 | 124 | def ResNet34(num_classes, in_channels=3): 125 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes, in_channels) 126 | 127 | 128 | def ResNet50(num_classes, in_channels=3): 129 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, in_channels) 130 | -------------------------------------------------------------------------------- /weightedSHAP/weightedSHAPEngine.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | 4 | # custom modules 5 | from weightedSHAP import data, utils 6 | from weightedSHAP.third_party import behavior 7 | 8 | semivalue_list=[(-1,'LOO-First'), (1,32), (1,16), (1, 8), (1, 4), (1, 2), 9 | (1,1), (2,1), (4,1), (8, 1), (16, 1), (32,1), (-1,'LOO-Last')] 10 | attribution_list=['LOO-First', 'Beta(32,1)', 'Beta(16,1)', 'Beta(8,1)', 11 | 'Beta(4,1)', 'Beta(2,1)', 'Beta(1,1)', 'Beta(1,2)', 12 | 'Beta(1,4)', 'Beta(1,8)', 'Beta(1,16)', 'Beta(1,32)', 'LOO-Last'] 13 | 14 | def MarginalContributionValue(game, thresh=1.005, batch_size=1, n_check_period=100): 15 | '''Calculate feature attributions using the marginal contributions.''' 16 | 17 | # index of the added feature, cardinality of set 18 | arange = np.arange(batch_size) 19 | output = game(np.zeros(game.players, dtype=bool)) 20 | MC_size=[game.players, game.players] + list(output.shape) 21 | MC_mat=np.zeros(MC_size) 22 | MC_count=np.zeros((game.players, game.players)) 23 | 24 | converged = False 25 | n_iter = 0 26 | marginal_contribs=np.zeros([0, np.prod([game.players] + list(output.shape))]) 27 | while not converged: 28 | for _ in range(n_check_period): 29 | # Sample permutations. 30 | permutations = np.tile(np.arange(game.players), (batch_size, 1)) 31 | for row in permutations: 32 | np.random.shuffle(row) 33 | S = np.zeros((batch_size, game.players), dtype=bool) 34 | 35 | # Unroll permutations. 
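            # Each sampled permutation is unrolled one feature at a time: starting from the
            # empty coalition, game(S) is re-evaluated after adding the next feature and the
            # change in the conditional prediction is stored as that feature's marginal
            # contribution. MC_mat[i, j] accumulates the contribution of feature i when it is
            # added to a coalition of size j (j = 0 is LOO-First, j = players - 1 is LOO-Last).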
36 | prev_value = game(S) 37 | marginal_contribs_tmp = np.zeros(([batch_size, game.players] + list(output.shape))) 38 | for j in range(game.players): 39 | ''' 40 | Marginal contribution estimates with respect to j samples 41 | j = 0 means LOO-First 42 | ''' 43 | S[arange, permutations[:, j]] = 1 44 | next_value = game(S) 45 | MC_mat[permutations[:, j], j] += (next_value - prev_value) 46 | MC_count[permutations[:, j], j] += 1 47 | marginal_contribs_tmp[arange, permutations[:, j]] = (next_value - prev_value) 48 | 49 | # update 50 | prev_value = next_value 51 | marginal_contribs=np.concatenate([marginal_contribs, marginal_contribs_tmp.reshape(batch_size,-1)], axis=0) 52 | 53 | if (n_iter+1) == 100: 54 | converged=True 55 | elif (n_iter+1) >= 2: 56 | if utils.check_convergence(marginal_contribs) < thresh: 57 | print(f'Therehosld: {int(0.999*(game.players*n_check_period))}') 58 | converged=True 59 | else: 60 | pass 61 | 62 | n_iter += 1 63 | 64 | print(f'We have seen {((n_iter+1)*n_check_period*batch_size)} random subsets for each feature.') 65 | if len(MC_mat.shape) != 2: 66 | # classification case 67 | MC_count=np.repeat(MC_count, MC_mat.shape[-1], axis=-1).reshape(MC_mat.shape) 68 | 69 | return MC_mat, MC_count 70 | 71 | 72 | def compute_attributions(problem, ML_model, 73 | model_to_explain, conditional_extension, 74 | X_train, y_train, X_val, y_val, X_test, y_test, n_max=100): 75 | ''' 76 | Compute attribution values and evaluate its performance 77 | ''' 78 | pred_list, pred_masking = [], [] 79 | cond_pred_keep_absolute, cond_pred_remove_absolute=[], [] 80 | value_list=[] 81 | n_max=min(n_max, len(X_test)) 82 | for ind in tqdm(range(n_max)): 83 | # Store original prediction 84 | original_pred=utils.compute_predict(model_to_explain, X_test[ind,:].reshape(1,-1), problem, ML_model) 85 | pred_list.append(original_pred) 86 | 87 | # Estimate marginal contributions 88 | conditional_game=behavior.PredictionGame(conditional_extension, X_test[ind, :]) 89 | MC_conditional_mat, MC_conditional_count=MarginalContributionValue(conditional_game) 90 | MC_est=np.array(MC_conditional_mat/(MC_conditional_count+1e-16)) 91 | 92 | # Optimize weight for WeightedSHAP (By default, AUP is used) 93 | attribution_dict_all=utils.compute_semivalue_from_MC(MC_est, semivalue_list) 94 | cond_pred_keep_absolute_list=utils.compute_cond_pred_list(attribution_dict_all, conditional_game) 95 | AUP_list=np.sum(np.abs(np.array(cond_pred_keep_absolute_list)- original_pred), axis=1) 96 | WeightedSHAP_index=np.argmin(AUP_list) 97 | value_list.append(attribution_dict_all[attribution_list[WeightedSHAP_index]]) 98 | 99 | ''' 100 | Evaluation 101 | ''' 102 | # Conditional prediction from most important to least important (keep absolte) 103 | cond_pred_keep_absolute_list=utils.compute_cond_pred_list(attribution_dict_all, conditional_game) 104 | cond_pred_keep_absolute.append(cond_pred_keep_absolute_list) 105 | 106 | exp_dict=dict() 107 | exp_dict['value_list']=value_list 108 | exp_dict['true_list']=np.array(y_test)[:n_max] 109 | exp_dict['pred_list']=np.array(pred_list) 110 | exp_dict['input_list']=np.array(X_test)[:n_max] 111 | exp_dict['cond_pred_keep_absolute']=np.array(cond_pred_keep_absolute) 112 | 113 | return exp_dict 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /weightedSHAP/third_party/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reference: https://github.com/iancovert/fastshap/blob/main/fastshap/utils.py 3 | ''' 4 
| import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | import itertools 8 | from torch.utils.data import Dataset 9 | 10 | class MaskLayer1d(nn.Module): 11 | ''' 12 | Masking for 1d inputs. 13 | 14 | Args: 15 | append: whether to append the mask along channels dim. 16 | value: replacement value for held out features. 17 | ''' 18 | def __init__(self, append=True, value=0): 19 | super().__init__() 20 | self.append = append 21 | self.value = value 22 | 23 | def forward(self, input_tuple): 24 | x, S = input_tuple 25 | x = x * S + self.value * (1 - S) 26 | if self.append: 27 | x = torch.cat((x, S), dim=1) 28 | return x 29 | 30 | class MaskLayer2d(nn.Module): 31 | ''' 32 | Masking for 2d inputs. 33 | 34 | Args: 35 | append: whether to append the mask along channels dim. 36 | value: replacement value for held out features. 37 | ''' 38 | def __init__(self, append=True, value=0): 39 | super().__init__() 40 | self.append = append 41 | self.value = value 42 | 43 | def forward(self, input_tuple): 44 | ''' 45 | Apply mask to input. 46 | 47 | Args: 48 | input_tuple: tuple of input x and mask S. 49 | ''' 50 | x, S = input_tuple 51 | x = x * S + self.value * (1 - S) 52 | if self.append: 53 | x = torch.cat((x, S), dim=1) 54 | return x 55 | 56 | class KLDivLoss(nn.Module): 57 | ''' 58 | KL divergence loss that applies log softmax operation to predictions. 59 | Args: 60 | reduction: how to reduce loss value (e.g., 'batchmean'). 61 | log_target: whether the target is expected as a log probabilities (or as 62 | probabilities). 63 | ''' 64 | 65 | def __init__(self, reduction='batchmean', log_target=False): 66 | super().__init__() 67 | self.kld = nn.KLDivLoss(reduction=reduction, log_target=log_target) 68 | 69 | def forward(self, pred, target): 70 | ''' 71 | Evaluate loss. 72 | Args: 73 | pred: 74 | target: 75 | ''' 76 | return self.kld(pred.log_softmax(dim=1), target) 77 | 78 | class MSELoss(nn.Module): 79 | ''' 80 | MSE loss. 81 | Args: 82 | reduction: how to reduce loss value (e.g., 'batchmean'). 83 | log_target: whether the target is expected as a log probabilities (or as 84 | probabilities). 85 | ''' 86 | 87 | def __init__(self, reduction='mean'): 88 | super().__init__() 89 | self.mseloss = nn.MSELoss(reduction=reduction) 90 | 91 | def forward(self, pred, target): 92 | ''' 93 | Evaluate loss. 94 | Args: 95 | pred: 96 | target: 97 | ''' 98 | return self.mseloss(pred, target) 99 | 100 | class UniformSampler: 101 | ''' 102 | For sampling player subsets with cardinality chosen uniformly at random. 103 | Args: 104 | num_players: number of players. 105 | ''' 106 | 107 | def __init__(self, num_players): 108 | self.num_players = num_players 109 | 110 | def sample(self, batch_size): 111 | ''' 112 | Generate sample. 113 | Args: 114 | batch_size: 115 | ''' 116 | S = torch.ones(batch_size, self.num_players, dtype=torch.float32) 117 | num_included = (torch.rand(batch_size) * (self.num_players + 1)).int() 118 | # TODO ideally avoid for loops 119 | # TODO ideally pass buffer to assign samples in place 120 | for i in range(batch_size): 121 | S[i, num_included[i]:] = 0 122 | S[i] = S[i, torch.randperm(self.num_players)] 123 | 124 | return S 125 | 126 | class DatasetRepeat(Dataset): 127 | ''' 128 | A wrapper around multiple datasets that allows repeated elements when the 129 | dataset sizes don't match. The number of elements is the maximum dataset 130 | size, and all datasets must be broadcastable to the same size. 131 | Args: 132 | datasets: list of dataset objects. 
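        For example, wrapping datasets of lengths 4 and 2 yields a dataset of length 4 whose
        item i returns dataset1[i] together with dataset2[i % 2].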
133 | ''' 134 | 135 | def __init__(self, datasets): 136 | # Get maximum number of elements. 137 | assert np.all([isinstance(dset, Dataset) for dset in datasets]) 138 | items = [len(dset) for dset in datasets] 139 | num_items = np.max(items) 140 | 141 | # Ensure all datasets align. 142 | # assert np.all([num_items % num == 0 for num in items]) 143 | self.dsets = datasets 144 | self.num_items = num_items 145 | self.items = items 146 | 147 | def __getitem__(self, index): 148 | assert 0 <= index < self.num_items 149 | return_items = [dset[index % num] for dset, num in 150 | zip(self.dsets, self.items)] 151 | return tuple(itertools.chain(*return_items)) 152 | 153 | def __len__(self): 154 | return self.num_items 155 | 156 | class DatasetInputOnly(Dataset): 157 | ''' 158 | A wrapper around a dataset object to ensure that only the first element is 159 | returned. 160 | Args: 161 | dataset: dataset object. 162 | ''' 163 | 164 | def __init__(self, dataset): 165 | assert isinstance(dataset, Dataset) 166 | self.dataset = dataset 167 | 168 | def __getitem__(self, index): 169 | return (self.dataset[index][0],) 170 | 171 | def __len__(self): 172 | return len(self.dataset) 173 | 174 | -------------------------------------------------------------------------------- /weightedSHAP/utils.py: -------------------------------------------------------------------------------- 1 | import os, sys, inspect, pickle 2 | import numpy as np 3 | from weightedSHAP import train 4 | 5 | def crossentropyloss(pred, target): 6 | '''Cross entropy loss that does not average across samples.''' 7 | if pred.ndim == 1: 8 | pred = pred[:, np.newaxis] 9 | pred = np.concatenate((1 - pred, pred), axis=1) 10 | 11 | if pred.shape == target.shape: 12 | # Soft cross entropy loss. 13 | pred = np.clip(pred, a_min=1e-12, a_max=1-1e-12) 14 | return - np.sum(np.log(pred) * target, axis=1) 15 | else: 16 | # Standard cross entropy loss. 17 | return - np.log(pred[np.arange(len(pred)), target]) 18 | 19 | def mseloss(pred, target): 20 | '''MSE loss that does not average across samples.''' 21 | return np.sum((pred - target) ** 2, axis=1) 22 | 23 | def beta_constant(a, b): 24 | ''' 25 | the second argument (b; beta) should be integer in this function 26 | ''' 27 | beta_fct_value=1/a 28 | for i in range(1,b): 29 | beta_fct_value=beta_fct_value*(i/(a+i)) 30 | return beta_fct_value 31 | 32 | def compute_weight_list(m, alpha=1, beta=1): 33 | ''' 34 | Given a prior distribution (beta distribution (alpha,beta)) 35 | beta_constant(j+1, m-j) = j! (m-j-1)! / (m-1)! / m # which is exactly the Shapley weights. 36 | 37 | # weight_list[n] is a weight when baseline model uses 'n' samples (w^{(n)}(j)*binom{n-1}{j} in the paper). 
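    For example, alpha=beta=1 gives the uniform weights 1/m for every cardinality, which recovers the Shapley value.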
38 | ''' 39 | weight_list=np.zeros(m) 40 | normalizing_constant=1/beta_constant(alpha, beta) 41 | for j in np.arange(m): 42 | # when the cardinality of random sets is j 43 | weight_list[j]=beta_constant(j+alpha, m-j+beta-1)/beta_constant(j+1, m-j) 44 | weight_list[j]=normalizing_constant*weight_list[j] # we need this '/m' but omit for stability # normalizing 45 | return weight_list/np.sum(weight_list) 46 | 47 | def compute_semivalue_from_MC(marginal_contrib, semivalue_list): 48 | ''' 49 | With the marginal contribution values, it computes semivalues 50 | 51 | ''' 52 | semivalue_dict={} 53 | n_elements=marginal_contrib.shape[0] 54 | for weight in semivalue_list: 55 | alpha, beta=weight 56 | if alpha > 0: 57 | model_name=f'Beta({beta},{alpha})' 58 | weight_list=compute_weight_list(m=n_elements, alpha=alpha, beta=beta) 59 | else: 60 | if beta == 'LOO-First': 61 | model_name='LOO-First' 62 | weight_list=np.zeros(n_elements) 63 | weight_list[0]=1 64 | elif beta == 'LOO-Last': 65 | model_name='LOO-Last' 66 | weight_list=np.zeros(n_elements) 67 | weight_list[-1]=1 68 | 69 | if len(marginal_contrib.shape) == 2: 70 | semivalue_tmp=np.einsum('ij,j->i', marginal_contrib, weight_list) 71 | else: 72 | # classification case 73 | semivalue_tmp=np.einsum('ijk,j->ik', marginal_contrib, weight_list) 74 | semivalue_dict[model_name]=semivalue_tmp 75 | return semivalue_dict 76 | 77 | def check_convergence(mem, n_require=100): 78 | """ 79 | Compute Gelman-Rubin statistic 80 | Ref. https://arxiv.org/pdf/1812.09384.pdf (p.7, Eq.4) 81 | """ 82 | if len(mem) < n_require: 83 | return 100 84 | n_chains=10 85 | (N,n_to_be_valued)=mem.shape 86 | if (N % n_chains) == 0: 87 | n_MC_sample=N//n_chains 88 | offset=0 89 | else: 90 | n_MC_sample=N//n_chains 91 | offset=(N%n_chains) 92 | mem=mem[offset:] 93 | percent=25 94 | while True: 95 | IQR_contstant=np.percentile(mem.reshape(-1), 50+percent) - np.percentile(mem.reshape(-1), 50-percent) 96 | if IQR_contstant == 0: 97 | percent=(50+percent)//2 98 | if percent >= 49: 99 | assert False, 'CHECK!!! IQR is zero!!!' 
100 | else: 101 | break 102 | 103 | mem_tmp=mem.reshape(n_chains, n_MC_sample, n_to_be_valued) 104 | GR_list=[] 105 | for j in range(n_to_be_valued): 106 | mem_tmp_j_original=mem_tmp[:,:,j].T # now we have (n_MC_sample, n_chains) 107 | mem_tmp_j=mem_tmp_j_original/IQR_contstant 108 | mem_tmp_j_mean=np.mean(mem_tmp_j, axis=0) 109 | s_term=np.sum((mem_tmp_j-mem_tmp_j_mean)**2)/(n_chains*(n_MC_sample-1)) # + 1e-16 this could lead to wrong estimator 110 | if s_term == 0: 111 | continue 112 | mu_hat_j=np.mean(mem_tmp_j) 113 | B_term=n_MC_sample*np.sum((mem_tmp_j_mean-mu_hat_j)**2)/(n_chains-1) 114 | 115 | GR_stat=np.sqrt((n_MC_sample-1)/n_MC_sample + B_term/(s_term*n_MC_sample)) 116 | GR_list.append(GR_stat) 117 | GR_stat=np.max(GR_list) 118 | print(f'Total number of random sets: {len(mem)}, GR_stat: {GR_stat}', flush=True) 119 | return GR_stat 120 | 121 | def compute_cond_pred_list(attribution_dict, game, more_important_first=True): 122 | n_features=game.players 123 | n_max_features=n_features # min(n_features, 200) 124 | 125 | cond_pred_list=[] 126 | for method in attribution_dict.keys(): 127 | cond_pred_list_tmp=[] 128 | if more_important_first is True: 129 | # more important to less important (large to zero) 130 | sorted_index=np.argsort(np.abs(attribution_dict[method]))[::-1] 131 | else: 132 | # less important to more important (zero to large) 133 | sorted_index=np.argsort(np.abs(attribution_dict[method])) 134 | 135 | for n_top in range(n_max_features+1): 136 | top_index=sorted_index[:n_top] 137 | S=np.zeros(n_features, dtype=bool) 138 | S[top_index]=True 139 | 140 | # prediction recovery error 141 | cond_pred_list_tmp.append(game(S)) 142 | cond_pred_list.append(cond_pred_list_tmp) 143 | 144 | return cond_pred_list 145 | 146 | def compute_pred_maksing_list(attribution_dict, model_to_explain, x, problem, ML_model, more_important_first=True): 147 | n_features=x.shape[1] 148 | n_max_features=n_features # min(n_features, 200) 149 | 150 | pred_masking_list=[] 151 | for method in attribution_dict.keys(): 152 | pred_masking_list_tmp=[] 153 | if more_important_first is True: 154 | # more important to less important (large to zero) 155 | sorted_index=np.argsort(np.abs(attribution_dict[method]))[::-1] 156 | else: 157 | # less important to more important (zero to large) 158 | sorted_index=np.argsort(np.abs(attribution_dict[method])) 159 | 160 | for n_top in range(n_max_features+1): 161 | top_index=sorted_index[:n_top] 162 | curr_x=np.zeros((1,n_features)) # Input matrix is standardized 163 | curr_x[0, top_index] = x[0, top_index] 164 | 165 | # prediction recovery error 166 | curr_pred=compute_predict(model_to_explain, curr_x, problem, ML_model) 167 | pred_masking_list_tmp.append(curr_pred) 168 | pred_masking_list.append(pred_masking_list_tmp) 169 | 170 | return pred_masking_list 171 | 172 | def compute_predict(model_to_explain, x, problem, ML_model): 173 | if (ML_model == 'linear') and (problem == 'classification'): 174 | return float(model_to_explain.predict_proba(x)[:,1]) 175 | else: 176 | return float(model_to_explain.predict(x)) 177 | 178 | -------------------------------------------------------------------------------- /weightedSHAP/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pickle 4 | from sklearn.model_selection import train_test_split 5 | from sklearn.datasets import load_boston 6 | from sklearn.preprocessing import StandardScaler 7 | 8 | def load_data(problem, dataset, dir_path, 
random_factor=2, random_seed=2022): 9 | ''' 10 | Load a dataset 11 | We split datasets as Train:Val:Est:Test=7:1:1:1 12 | Train: to train a model 13 | Val: to optimize hyperparameters 14 | Est: to estimate coalition functions 15 | Test: to evaluate performance 16 | ''' 17 | print('-'*30) 18 | print('Load a dataset') 19 | print('-'*30) 20 | 21 | if problem=='regression': 22 | (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test)=load_regression_dataset(dataset=dataset, dir_path=dir_path, rid=random_seed) 23 | elif problem=='classification': 24 | (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test)=load_classification_dataset(dataset=dataset, dir_path=dir_path, rid=random_seed) 25 | else: 26 | raise NotImplementedError('Check problem') 27 | 28 | if random_factor != 0: 29 | # We add noisy features to the original dataset. 30 | print('-'*30) 31 | print('Before adding noise') 32 | print(f'Shape of X_train, X_val, X_est, X_test: {X_train.shape}, {X_val.shape}, {X_est.shape}, {X_test.shape}') 33 | print('-'*30) 34 | dim_noise=int(X_train.shape[1]*random_factor) 35 | X_train=extend_dataset(X_train, dim_noise, verbose=True) 36 | X_val=extend_dataset(X_val, dim_noise) 37 | X_est=extend_dataset(X_est, dim_noise) 38 | X_test=extend_dataset(X_test, dim_noise) 39 | print('After adding noise') 40 | print(f'Shape of X_train, X_val, X_est, X_test: {X_train.shape}, {X_val.shape}, {X_est.shape}, {X_test.shape}') 41 | print('-'*30) 42 | else: 43 | # We use the original dataset. 44 | print('-'*30) 45 | print(f'Shape of X_train, X_val, X_est, X_test: {X_train.shape}, {X_val.shape}, {X_est.shape}, {X_test.shape}') 46 | print('-'*30) 47 | 48 | return (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test) 49 | 50 | def load_regression_dataset(dataset='abalone', dir_path='dir_path', rid=1): 51 | ''' 52 | This function loads regression datasets. 53 | dir_path: path to regression datasets. 54 | 55 | You may need to download datasets first. Make sure to store in 'dir_path'. 56 | The datasets are avaiable at the following links. 
57 | abalone: https://archive.ics.uci.edu/ml/machine-learning-databases/abalone/ 58 | airfoil: https://archive.ics.uci.edu/ml/machine-learning-databases/00291/ 59 | whitewine: https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/ 60 | ''' 61 | np.random.seed(rid) 62 | 63 | if dataset == 'boston': 64 | print('-'*50) 65 | print('Boston') 66 | print('-'*50) 67 | data=load_boston() 68 | X, y=data['data'], data['target'] 69 | elif dataset == 'abalone': 70 | print('-'*50) 71 | print('Abalone') 72 | print('-'*50) 73 | raw_data = pd.read_csv(dir_path+"/abalone.data", header=None) 74 | raw_data.dropna(inplace=True) 75 | X, y = pd.get_dummies(raw_data.iloc[:,:-1],drop_first=True).values, raw_data.iloc[:,-1].values 76 | elif dataset == 'whitewine': 77 | print('-'*50) 78 | print('whitewine') 79 | print('-'*50) 80 | raw_data = pd.read_csv(dir_path+"/winequality-white.csv",delimiter=";") 81 | raw_data.dropna(inplace=True) 82 | X, y = raw_data.values[:,:-1], raw_data.values[:,-1] 83 | elif dataset == 'airfoil': 84 | print('-'*50) 85 | print('airfoil') 86 | print('-'*50) 87 | raw_data = pd.read_csv(dir_path+"/airfoil_self_noise.dat", sep='\t', names=['X1','X2,','X3','X4','X5','Y']) 88 | X, y = raw_data.values[:,:-1], raw_data.values[:,-1] 89 | else: 90 | raise NotImplementedError(f'Check {dataset}') 91 | 92 | X = standardize_data(X) 93 | X_train, X_test, y_train, y_test=train_test_split(X, y, test_size=0.1) 94 | X_train, X_val, y_train, y_val=train_test_split(X_train, y_train, test_size=float(1/9)) 95 | X_train, X_est, y_train, y_est=train_test_split(X_train, y_train, test_size=float(1/8)) 96 | 97 | return (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test) 98 | 99 | def load_classification_dataset(dataset='gaussian', dir_path='dir_path', rid=1): 100 | ''' 101 | This function loads classification datasets. 102 | dir_path: path to classification datasets. 
103 | ''' 104 | np.random.seed(rid) 105 | 106 | if dataset == 'gaussian': 107 | print('-'*50) 108 | print('Gaussian') 109 | print('-'*50) 110 | n, input_dim, rho=10000, 10, 0.25 111 | U_cov=np.diag((1-rho)*np.ones(input_dim))+rho 112 | U_mean=np.zeros(input_dim) 113 | X=np.random.multivariate_normal(U_mean, U_cov, n) 114 | 115 | beta_true=(np.linspace(input_dim,(41*input_dim/50),input_dim)/input_dim).reshape(input_dim,1) 116 | p_true=np.exp(X.dot(beta_true))/(1.+np.exp(X.dot(beta_true))) 117 | y=np.random.binomial(n=1, p=p_true).reshape(-1) 118 | elif dataset == 'fraud': 119 | print('-'*50) 120 | print('Fraud Detection') 121 | print('-'*50) 122 | data_dict=pickle.load(open(f'{dir_path}/fraud_dataset.pkl', 'rb')) 123 | data, target = data_dict['X_num'], data_dict['y'] 124 | target = (target == 1) + 0.0 125 | target = target.astype(np.int32) 126 | X, y=make_balance_sample(data, target) 127 | else: 128 | raise NotImplementedError(f'Check {dataset}') 129 | 130 | X = standardize_data(X) 131 | X_train, X_test, y_train, y_test=train_test_split(X, y, test_size=0.1) 132 | X_train, X_val, y_train, y_val=train_test_split(X_train, y_train, test_size=float(1/9)) 133 | X_train, X_est, y_train, y_est=train_test_split(X_train, y_train, test_size=float(1/8)) 134 | 135 | return (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test) 136 | 137 | 138 | ''' 139 | Data utils 140 | ''' 141 | 142 | def extend_dataset(X, d_to_add, verbose=False): 143 | n, d_prev = X.shape 144 | rho=(np.sum((X.T.dot(X)/n)[:d_prev,:d_prev])-d_prev)/(d_prev*(d_prev-1)) 145 | 146 | if -1/(4*(d_prev-1)+1e-16) > rho: 147 | if verbose is True: 148 | # if rho is too small, the sigma_square defined below can be negative 149 | print(f'Initial rho: {rho:.4f}') 150 | rho=max(-1/(4*(d_prev-1)+1e-16), rho) 151 | print(f'After fixing rho: {rho:.4f}') 152 | else: 153 | if verbose is True: 154 | print(f'Rho: {rho:.4f}') 155 | 156 | for _ in range(d_to_add): 157 | sigma_square=1-(rho**2)*(d_prev)/(1+rho*(d_prev-1)+1e-16) 158 | new_X = (rho/(1+rho*(d_prev-1)+1e-16))*X.dot(np.ones((d_prev,1))) 159 | new_X += np.random.normal(size=(n,1))*np.sqrt(sigma_square) 160 | X = np.concatenate((X, new_X), axis=1) 161 | d_prev += 1 162 | return X 163 | 164 | def make_balance_sample(data, target): 165 | p = np.mean(target) 166 | minor_class=1 if p < 0.5 else 0 167 | 168 | index_minor_class = np.where(target == minor_class)[0] 169 | n_minor_class=len(index_minor_class) 170 | n_major_class=len(target)-n_minor_class 171 | new_minor=np.random.choice(index_minor_class, size=n_major_class-n_minor_class, replace=True) 172 | 173 | data=np.concatenate([data, data[new_minor]]) 174 | target=np.concatenate([target, target[new_minor]]) 175 | return data, target 176 | 177 | def standardize_data(X): 178 | ss=StandardScaler() 179 | ss.fit(X) 180 | try: 181 | X = ss.transform(X.values) 182 | except: 183 | X = ss.transform(X) 184 | return X 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /weightedSHAP/third_party/removal.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reference: https://github.com/iancovert/removal-explanations/blob/main/rexplain/removal.py 3 | ''' 4 | import numpy as np 5 | 6 | class MarginalExtension: 7 | '''Extend a model by marginalizing out removed features using their 8 | marginal distribution.''' 9 | def __init__(self, data, model): 10 | self.model = model 11 | self.data = data 12 | self.data_repeat = data 13 | self.samples = len(data) 14 | 
# self.x_addr = None 15 | # self.x_repeat = None 16 | 17 | def __call__(self, x, S): 18 | # Prepare x and S. 19 | n = len(x) 20 | x = x.repeat(self.samples, 0) 21 | S = S.repeat(self.samples, 0) 22 | 23 | # Prepare samples. 24 | if len(self.data_repeat) != self.samples * n: 25 | self.data_repeat = np.tile(self.data, (n, 1)) 26 | 27 | # Replace specified indices. 28 | x_ = x.copy() 29 | x_[~S] = self.data_repeat[~S] 30 | 31 | # Make predictions. 32 | pred = self.model(x_) 33 | pred = pred.reshape(-1, self.samples, *pred.shape[1:]) 34 | return np.mean(pred, axis=1) 35 | 36 | 37 | class ConditionalSupervisedExtension: 38 | '''Extend a model using a supervised surrogate model.''' 39 | def __init__(self, surrogate): 40 | self.surrogate = surrogate 41 | 42 | def __call__(self, x, S): 43 | return self.surrogate(x, S) 44 | 45 | class DefaultExtension: 46 | '''Extend a model by replacing removed features with default values.''' 47 | def __init__(self, values, model): 48 | self.model = model 49 | if values.ndim == 1: 50 | values = values[np.newaxis] 51 | elif values[0] != 1: 52 | raise ValueError('values shape must be (dim,) or (1, dim)') 53 | self.values = values 54 | self.values_repeat = values 55 | 56 | def __call__(self, x, S): 57 | # Prepare x. 58 | if len(x) != len(self.values_repeat): 59 | self.values_repeat = self.values.repeat(len(x), 0) 60 | 61 | # Replace specified indices. 62 | x_ = x.copy() 63 | x_[~S] = self.values_repeat[~S] 64 | 65 | # Make predictions. 66 | return self.model(x_) 67 | 68 | class MarginalExtensionApprox: 69 | '''Extend a model by marginalizing out removed features using their 70 | marginal distribution.''' 71 | def __init__(self, data_mean, model, grad_array): 72 | self.model = model 73 | self.data_mean = data_mean 74 | self.grad=grad_array 75 | 76 | def __call__(self, x, S): 77 | # Prepare samples. 78 | n=len(x) 79 | if len(self.data_mean) != n: 80 | self.data_repeat = np.tile(self.data_mean, (n, 1)) 81 | 82 | # Replace specified indices. 83 | x_ = x.copy() 84 | x_[~S] = self.data_repeat[~S] 85 | 86 | # Make predictions. 87 | pred = self.model(x) 88 | pred += (x_-x).dot(self.grad) 89 | pred = pred.reshape(-1, 1, *pred.shape[1:]) 90 | return np.mean(pred, axis=1) 91 | 92 | class UniformExtension: 93 | '''Extend a model by marginalizing out removed features using a 94 | uniform distribution.''' 95 | def __init__(self, values, categorical_inds, samples, model): 96 | self.model = model 97 | self.values = values 98 | self.categorical_inds = categorical_inds 99 | self.samples = samples 100 | 101 | def __call__(self, x, S): 102 | # Prepare x and S. 103 | n = len(x) 104 | x = x.repeat(self.samples, 0) 105 | S = S.repeat(self.samples, 0) 106 | 107 | # Prepare samples. 108 | samples = np.zeros((n * self.samples, x.shape[1])) 109 | for i in range(x.shape[1]): 110 | if i in self.categorical_inds: 111 | inds = np.random.choice( 112 | len(self.values[i]), n * self.samples) 113 | samples[:, i] = self.values[i][inds] 114 | else: 115 | samples[:, i] = np.random.uniform( 116 | low=self.values[i][0], high=self.values[i][1], 117 | size=n * self.samples) 118 | 119 | # Replace specified indices. 120 | x_ = x.copy() 121 | x_[~S] = samples[~S] 122 | 123 | # Make predictions. 124 | pred = self.model(x_) 125 | pred = pred.reshape(-1, self.samples, *pred.shape[1:]) 126 | return np.mean(pred, axis=1) 127 | 128 | 129 | class UniformContinuousExtension: 130 | ''' 131 | Extend a model by marginalizing out removed features using a 132 | uniform distribution. Specific to sets of continuous features. 
133 | 134 | TODO: should we have caching here for repeating x? 135 | 136 | ''' 137 | def __init__(self, min_vals, max_vals, samples, model): 138 | self.model = model 139 | self.min = min_vals 140 | self.max = max_vals 141 | self.samples = samples 142 | 143 | def __call__(self, x, S): 144 | # Prepare x and S. 145 | x = x.repeat(self.samples, 0) 146 | S = S.repeat(self.samples, 0) 147 | 148 | # Prepare samples. 149 | u = np.random.uniform(size=x.shape) 150 | samples = u * self.min + (1 - u) * self.max 151 | 152 | # Replace specified indices. 153 | x_ = x.copy() 154 | x_[~S] = samples[~S] 155 | 156 | # Make predictions. 157 | pred = self.model(x_) 158 | pred = pred.reshape(-1, self.samples, *pred.shape[1:]) 159 | return np.mean(pred, axis=1) 160 | 161 | 162 | class ProductMarginalExtension: 163 | '''Extend a model by marginalizing out removed features the 164 | product of their marginal distributions.''' 165 | def __init__(self, data, samples, model): 166 | self.model = model 167 | self.data = data 168 | self.data_repeat = data 169 | self.samples = samples 170 | 171 | def __call__(self, x, S): 172 | # Prepare x and S. 173 | n = len(x) 174 | x = x.repeat(self.samples, 0) 175 | S = S.repeat(self.samples, 0) 176 | 177 | # Prepare samples. 178 | samples = np.zeros((n * self.samples, x.shape[1])) 179 | for i in range(x.shape[1]): 180 | inds = np.random.choice(len(self.data), n * self.samples) 181 | samples[:, i] = self.data[inds, i] 182 | 183 | # Replace specified indices. 184 | x_ = x.copy() 185 | x_[~S] = samples[~S] 186 | 187 | # Make predictions. 188 | pred = self.model(x_) 189 | pred = pred.reshape(-1, self.samples, *pred.shape[1:]) 190 | return np.mean(pred, axis=1) 191 | 192 | 193 | class SeparateModelExtension: 194 | '''Extend a model using separate models for each subset of features.''' 195 | def __init__(self, model_dict): 196 | self.model_dict = model_dict 197 | 198 | def __call__(self, x, S): 199 | output = [] 200 | for i in range(len(S)): 201 | # Extract model. 202 | row = S[i] 203 | model = self.model_dict[str(row)] 204 | 205 | # Make prediction. 206 | output.append(model(x[i:i+1, row])) 207 | 208 | return np.concatenate(output, axis=0) 209 | 210 | 211 | class ConditionalExtension: 212 | '''Extend a model by marginalizing out removed features using a model of 213 | their conditional distribution.''' 214 | def __init__(self, conditional_model, samples, model): 215 | self.model = model 216 | self.conditional_model = conditional_model 217 | self.samples = samples 218 | self.x_addr = None 219 | self.x_repeat = None 220 | 221 | def __call__(self, x, S): 222 | # Prepare x. 223 | if self.x_addr != id(x): 224 | self.x_addr = id(x) 225 | self.x_repeat = x.repeat(self.samples, 0) 226 | x = self.x_repeat 227 | 228 | # Prepare samples. 229 | S = S.repeat(self.samples, 0) 230 | samples = self.conditional_model(x, S) 231 | 232 | # Replace specified indices. 233 | x_ = x.copy() 234 | x_[~S] = samples[~S] 235 | 236 | # Make predictions. 
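        # Each input row was repeated `self.samples` times with a different conditional draw
        # for the removed features, so the reshape and mean below form a Monte Carlo estimate
        # of E[f(x) | x_S].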
237 | pred = self.model(x_) 238 | pred = pred.reshape(-1, self.samples, *pred.shape[1:]) 239 | return np.mean(pred, axis=1) 240 | 241 | 242 | 243 | -------------------------------------------------------------------------------- /weightedSHAP/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from time import time 3 | import torch 4 | import torch.nn as nn 5 | from scipy.stats import ttest_ind 6 | 7 | from weightedSHAP.third_party import removal, surrogate 8 | from weightedSHAP.third_party.utils import MaskLayer1d, KLDivLoss, MSELoss 9 | 10 | def create_boosting_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model): 11 | print('Train a model to explain: Boosting') 12 | import lightgbm as lgb 13 | if problem == 'classification': 14 | params = { 15 | "learning_rate": 0.005, 16 | "objective": "binary", 17 | "metric": ["binary_logloss", "binary_error"], 18 | "num_threads": 1, 19 | "verbose": -1, 20 | "num_leaves":15, 21 | "bagging_fraction":0.5, 22 | "feature_fraction":0.5, 23 | "lambda_l2": 1e-3 24 | } 25 | else: 26 | params = { 27 | "learning_rate": 0.005, 28 | "objective": "mean_squared_error", 29 | "metric": "mean_squared_error", 30 | "num_threads": 1, 31 | "verbose": -1, 32 | "num_leaves":15, 33 | "bagging_fraction":0.5, 34 | "feature_fraction":0.5, 35 | "lambda_l2": 1e-3 36 | } 37 | 38 | d_train = lgb.Dataset(X_train, label=y_train) 39 | d_val = lgb.Dataset(X_val, label=y_val) 40 | 41 | callbacks=[lgb.early_stopping(25)] 42 | model_to_explain=lgb.train(params, d_train, 43 | num_boost_round=1000, 44 | valid_sets=[d_val], 45 | callbacks=callbacks) 46 | 47 | return model_to_explain 48 | 49 | def create_linear_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model): 50 | print('Train a model to explain: Linear') 51 | from sklearn.linear_model import LinearRegression, LogisticRegression 52 | if problem == 'classification': 53 | model_to_explain=LogisticRegression() 54 | model_to_explain.fit(X_train, y_train) 55 | else: 56 | model_to_explain=LinearRegression() 57 | model_to_explain.fit(X_train, y_train) 58 | 59 | return model_to_explain 60 | 61 | def create_MLP_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model): 62 | print('Train a model to explain: MLP') 63 | device = torch.device('cpu') 64 | num_features=X_train.shape[1] 65 | n_output=2 if problem=='classification' else 1 66 | model_to_explain = nn.Sequential( 67 | nn.Linear(num_features, 128), 68 | nn.ReLU(inplace=True), 69 | nn.Linear(128, 128), 70 | nn.ReLU(inplace=True), 71 | nn.Linear(128, n_output)).to(device) 72 | 73 | # training part 74 | return model_to_explain 75 | 76 | def check_overfitting(X_train, y_train, X_val, y_val, model_to_explain, problem, ML_model): 77 | ''' 78 | Check overfitting 79 | ''' 80 | if problem == 'classification': 81 | tmp_err=lambda y1, y2: ((y1 > 0.5) != y2) + 0.0 82 | else: 83 | tmp_err=lambda y1, y2: (y1-y2)**2 84 | 85 | if (ML_model == 'linear') and (problem == 'classification'): 86 | tr_pred_error=tmp_err(model_to_explain.predict_proba(X_train)[:,1], y_train) 87 | val_pred_error=tmp_err(model_to_explain.predict_proba(X_val)[:,1], y_val) 88 | elif ML_model == 'MLP': 89 | # not used 90 | y_train_pred=model_to_explain(torch.from_numpy(X_train.astype(np.float32))).detach().numpy().reshape(-1) 91 | y_val_pred=model_to_explain(torch.from_numpy(X_val.astype(np.float32))).detach().numpy().reshape(-1) 92 | tr_pred_error=tmp_err(y_train_pred, y_train) 93 | val_pred_error=tmp_err(y_val_pred, y_val) 94 | 
else: 95 | tr_pred_error=tmp_err(model_to_explain.predict(X_train), y_train) 96 | val_pred_error=tmp_err(model_to_explain.predict(X_val), y_val) 97 | 98 | p_value=ttest_ind(tr_pred_error, val_pred_error)[1] 99 | overfitting_check='Not overfitted' if p_value > 0.01 else 'Overfitted' 100 | 101 | tr_err, val_err = np.mean(tr_pred_error), np.mean(val_pred_error) 102 | print(f'Overfitting? / P-value: {overfitting_check} / {p_value:.4f}') 103 | print(f'Tr error, Val error: {tr_err:.3f}, {val_err:.3f}') 104 | return tr_err, val_err 105 | 106 | def create_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model): 107 | print('-'*30) 108 | print('Train a model') 109 | start_time=time() 110 | if ML_model=='linear': 111 | model_to_explain=create_linear_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model) 112 | elif ML_model=='boosting': 113 | model_to_explain=create_boosting_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model) 114 | elif ML_model=='MLP': 115 | model_to_explain=create_MLP_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model) 116 | else: 117 | raise ValueError(f'Check ML_model: {ML_model}') 118 | 119 | elapsed_time_train=time()-start_time 120 | print(f'Elapsed time for training a model to explain: {elapsed_time_train:.2f} seconds') 121 | print('-'*30) 122 | return model_to_explain # , tr_err, val_err 123 | 124 | 125 | def create_surrogate_model(model_to_explain, X_train, X_est, problem='classification', ML_model='linear', verbose=False): 126 | start_time=time() 127 | # [Step 1] Create surrogate model 128 | device = torch.device('cpu') 129 | num_features=X_train.shape[1] 130 | n_output=2 if problem=='classification' else 1 131 | surrogate_model = nn.Sequential( 132 | MaskLayer1d(value=0, append=True), 133 | nn.Linear(2 * num_features, 128), 134 | nn.ELU(inplace=True), 135 | nn.Linear(128, 128), 136 | nn.ELU(inplace=True), 137 | nn.Linear(128, n_output)).to(device) 138 | 139 | # Set up surrogate object 140 | surrogate_object = surrogate.Surrogate(surrogate_model, num_features) 141 | 142 | if problem=='classification': 143 | loss_fn=KLDivLoss() 144 | # predict_proba and predict 145 | def original_model(x): 146 | if ML_model == 'linear': 147 | pred = (model_to_explain.predict_proba(x.cpu().numpy())[:,1]).reshape(-1) 148 | else: 149 | pred = model_to_explain.predict(x.cpu().numpy()) 150 | 151 | pred = np.stack([1 - pred, pred]).T 152 | return torch.tensor(pred, dtype=torch.float32, device=x.device) 153 | else: 154 | loss_fn=MSELoss() 155 | def original_model(x): 156 | pred = model_to_explain.predict(x.cpu().numpy()).reshape(-1, 1) 157 | return torch.tensor(pred, dtype=torch.float32, device=x.device) 158 | 159 | # Train 160 | surrogate_object.train_original_model(X_train, 161 | X_est, 162 | original_model, 163 | batch_size=64, 164 | max_epochs=100, 165 | loss_fn=loss_fn, 166 | validation_samples=128, 167 | validation_batch_size=(2**12), 168 | lookback=10, 169 | verbose=verbose) 170 | elapsed_time=time()-start_time 171 | print(f'Elapsed time for training a surrogate model: {elapsed_time:.2f} seconds') 172 | return surrogate_model 173 | 174 | 175 | def generate_coalition_function(model_to_explain, X_train, X_est, 176 | problem='classification', ML_model='linear', verbose=False): 177 | surrogate_model=create_surrogate_model(model_to_explain, X_train, X_est, problem, ML_model, verbose) 178 | 179 | device=torch.device('cpu') 180 | if problem == 'classification': 181 | model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, 
dtype=torch.float32, device=device), 182 | torch.tensor(S, dtype=torch.float32, device=device))).softmax(dim=-1).cpu().data.numpy().reshape(x.shape[0],-1)[:,1] 183 | elif problem == 'regression': 184 | model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device), 185 | torch.tensor(S, dtype=torch.float32, device=device))).cpu().data.numpy().reshape(x.shape[0],-1)[:,0] 186 | else: 187 | raise ValueError(f'Check problem: {problem}') 188 | conditional_extension=removal.ConditionalSupervisedExtension(model_condi_wrapper) 189 | 190 | return conditional_extension 191 | 192 | -------------------------------------------------------------------------------- /weightedSHAP/third_party/surrogate.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reference: https://github.com/iancovert/fastshap/blob/main/fastshap/surrogate.py 3 | ''' 4 | import torch 5 | import torch.optim as optim 6 | import numpy as np 7 | from torch.utils.data import Dataset, TensorDataset, DataLoader 8 | from torch.utils.data import RandomSampler, BatchSampler 9 | from copy import deepcopy 10 | from tqdm.auto import tqdm 11 | from .utils import UniformSampler, DatasetRepeat 12 | 13 | def validate(surrogate, loss_fn, data_loader): 14 | ''' 15 | Calculate mean validation loss. 16 | Args: 17 | loss_fn: loss function. 18 | data_loader: data loader. 19 | ''' 20 | with torch.no_grad(): 21 | # Setup. 22 | device = next(surrogate.surrogate.parameters()).device 23 | mean_loss = 0 24 | N = 0 25 | 26 | for x, y, S in data_loader: 27 | x = x.to(device) 28 | y = y.to(device) 29 | S = S.to(device) 30 | try: 31 | pred = surrogate(x, S) 32 | except: 33 | pred = surrogate(x, S.long()) 34 | loss = loss_fn(pred, y) 35 | N += len(x) 36 | mean_loss += len(x) * (loss - mean_loss) / N 37 | 38 | return mean_loss 39 | 40 | 41 | def generate_labels(dataset, model, batch_size): 42 | ''' 43 | Generate prediction labels for a set of inputs. 44 | Args: 45 | dataset: dataset object. 46 | model: predictive model. 47 | batch_size: minibatch size. 48 | ''' 49 | with torch.no_grad(): 50 | # Setup. 51 | preds = [] 52 | if isinstance(model, torch.nn.Module): 53 | device = next(model.parameters()).device 54 | else: 55 | device = torch.device('cpu') 56 | loader = DataLoader(dataset, batch_size=batch_size) 57 | 58 | for (x,) in loader: 59 | pred = model(x.to(device)).cpu() 60 | preds.append(pred) 61 | 62 | return torch.cat(preds) 63 | 64 | 65 | class Surrogate: 66 | ''' 67 | Wrapper around surrogate model. 68 | Args: 69 | surrogate: surrogate model. 70 | num_features: number of features. 71 | groups: (optional) feature groups, represented by a list of lists. 72 | ''' 73 | 74 | def __init__(self, surrogate, num_features, groups=None): 75 | # Store surrogate model. 76 | self.surrogate = surrogate 77 | 78 | # Store feature groups. 79 | if groups is None: 80 | self.num_players = num_features 81 | self.groups_matrix = None 82 | else: 83 | # Verify groups. 84 | inds_list = [] 85 | for group in groups: 86 | inds_list += list(group) 87 | assert np.all(np.sort(inds_list) == np.arange(num_features)) 88 | 89 | # Map groups to features. 
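            # groups_matrix is a (num_groups, num_features) 0/1 matrix; right-multiplying a
            # group-level coalition by it expands the mask to feature level (presumably how
            # grouped players are handled when the surrogate is evaluated).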
90 | self.num_players = len(groups) 91 | device = next(surrogate.parameters()).device 92 | self.groups_matrix = torch.zeros( 93 | len(groups), num_features, dtype=torch.float32, device=device) 94 | for i, group in enumerate(groups): 95 | self.groups_matrix[i, group] = 1 96 | 97 | def train(self, 98 | train_data, 99 | val_data, 100 | batch_size, 101 | max_epochs, 102 | loss_fn, 103 | validation_samples=1, 104 | validation_batch_size=None, 105 | lr=1e-3, 106 | min_lr=1e-5, 107 | lr_factor=0.5, 108 | lookback=5, 109 | training_seed=None, 110 | validation_seed=None, 111 | bar=False, 112 | verbose=False): 113 | ''' 114 | Train surrogate model. 115 | Args: 116 | train_data: training data with inputs and the original model's 117 | predictions (np.ndarray tuple, torch.Tensor tuple, 118 | torch.utils.data.Dataset). 119 | val_data: validation data with inputs and the original model's 120 | predictions (np.ndarray tuple, torch.Tensor tuple, 121 | torch.utils.data.Dataset). 122 | batch_size: minibatch size. 123 | max_epochs: maximum training epochs. 124 | loss_fn: loss function (e.g., fastshap.KLDivLoss). 125 | validation_samples: number of samples per validation example. 126 | validation_batch_size: validation minibatch size. 127 | lr: initial learning rate. 128 | min_lr: minimum learning rate. 129 | lr_factor: learning rate decrease factor. 130 | lookback: lookback window for early stopping. 131 | training_seed: random seed for training. 132 | validation_seed: random seed for generating validation data. 133 | verbose: verbosity. 134 | ''' 135 | # Set up train dataset. 136 | if isinstance(train_data, tuple): 137 | x_train, y_train = train_data 138 | if isinstance(x_train, np.ndarray): 139 | x_train = torch.tensor(x_train, dtype=torch.float32) 140 | y_train = torch.tensor(y_train, dtype=torch.float32) 141 | train_set = TensorDataset(x_train, y_train) 142 | elif isinstance(train_data, Dataset): 143 | train_set = train_data 144 | else: 145 | raise ValueError('train_data must be either tuple of tensors or a ' 146 | 'PyTorch Dataset') 147 | 148 | # Set up train data loader. 149 | random_sampler = RandomSampler( 150 | train_set, replacement=True, 151 | num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size) 152 | batch_sampler = BatchSampler( 153 | random_sampler, batch_size=batch_size, drop_last=True) 154 | train_loader = DataLoader(train_set, batch_sampler=batch_sampler) 155 | 156 | # Set up validation dataset. 157 | sampler = UniformSampler(self.num_players) 158 | if validation_seed is not None: 159 | torch.manual_seed(validation_seed) 160 | S_val = sampler.sample(len(val_data) * validation_samples) 161 | 162 | if isinstance(val_data, tuple): 163 | x_val, y_val = val_data 164 | if isinstance(x_val, np.ndarray): 165 | x_val = torch.tensor(x_val, dtype=torch.float32) 166 | y_val = torch.tensor(y_val, dtype=torch.float32) 167 | x_val_repeat = x_val.repeat(validation_samples, 1) 168 | y_val_repeat = y_val.repeat(validation_samples, 1) 169 | val_set = TensorDataset( 170 | x_val_repeat, y_val_repeat, S_val) 171 | elif isinstance(val_data, Dataset): 172 | val_set = DatasetRepeat([val_data, TensorDataset(S_val)]) 173 | else: 174 | raise ValueError('val_data must be either tuple of tensors or a ' 175 | 'PyTorch Dataset') 176 | 177 | if validation_batch_size is None: 178 | validation_batch_size = batch_size 179 | val_loader = DataLoader(val_set, batch_size=validation_batch_size) 180 | 181 | # Setup for training. 
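        # Training uses Adam with ReduceLROnPlateau on the validation loss; `lookback` sets
        # both the scheduler patience (lookback // 2) and the early-stopping window below.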
182 | surrogate = self.surrogate 183 | device = next(surrogate.parameters()).device 184 | optimizer = optim.Adam(surrogate.parameters(), lr=lr) 185 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 186 | optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr, 187 | verbose=verbose) 188 | best_loss = validate(self, loss_fn, val_loader).item() 189 | best_epoch = 0 190 | best_model = deepcopy(surrogate) 191 | loss_list = [best_loss] 192 | if training_seed is not None: 193 | torch.manual_seed(training_seed) 194 | 195 | for epoch in range(max_epochs): 196 | # Batch iterable. 197 | if bar: 198 | batch_iter = tqdm(train_loader, desc='Training epoch') 199 | else: 200 | batch_iter = train_loader 201 | 202 | for x, y in batch_iter: 203 | # Prepare data. 204 | x = x.to(device) 205 | y = y.to(device) 206 | 207 | # Generate subsets. 208 | S = sampler.sample(batch_size).to(device=device) 209 | 210 | # Make predictions. 211 | pred = self.__call__(x, S) 212 | loss = loss_fn(pred, y) 213 | 214 | # Optimizer step. 215 | loss.backward() 216 | optimizer.step() 217 | surrogate.zero_grad() 218 | 219 | # Evaluate validation loss. 220 | self.surrogate.eval() 221 | val_loss = validate(self, loss_fn, val_loader).item() 222 | self.surrogate.train() 223 | 224 | # Print progress. 225 | if verbose: 226 | print('----- Epoch = {} -----'.format(epoch + 1)) 227 | print('Val loss = {:.4f}'.format(val_loss)) 228 | print('') 229 | scheduler.step(val_loss) 230 | loss_list.append(val_loss) 231 | 232 | # Check if best model. 233 | if val_loss < best_loss: 234 | best_loss = val_loss 235 | best_model = deepcopy(surrogate) 236 | best_epoch = epoch 237 | if verbose: 238 | print('New best epoch, loss = {:.4f}'.format(val_loss)) 239 | print('') 240 | elif epoch - best_epoch == lookback: 241 | if verbose: 242 | print('Stopping early') 243 | break 244 | 245 | # Clean up. 246 | for param, best_param in zip(surrogate.parameters(), 247 | best_model.parameters()): 248 | param.data = best_param.data 249 | self.loss_list = loss_list 250 | self.surrogate.eval() 251 | 252 | def train_original_model(self, 253 | train_data, 254 | val_data, 255 | original_model, 256 | batch_size, 257 | max_epochs, 258 | loss_fn, 259 | validation_samples=1, 260 | validation_batch_size=None, 261 | lr=1e-3, 262 | min_lr=1e-5, 263 | lr_factor=0.5, 264 | lookback=5, 265 | training_seed=None, 266 | validation_seed=None, 267 | bar=False, 268 | verbose=False): 269 | ''' 270 | Train surrogate model with labels provided by the original model. 271 | Args: 272 | train_data: training data with inputs only (np.ndarray, torch.Tensor, 273 | torch.utils.data.Dataset). 274 | val_data: validation data with inputs only (np.ndarray, torch.Tensor, 275 | torch.utils.data.Dataset). 276 | original_model: original predictive model (e.g., torch.nn.Module). 277 | batch_size: minibatch size. 278 | max_epochs: maximum training epochs. 279 | loss_fn: loss function (e.g., fastshap.KLDivLoss). 280 | validation_samples: number of samples per validation example. 281 | validation_batch_size: validation minibatch size. 282 | lr: initial learning rate. 283 | min_lr: minimum learning rate. 284 | lr_factor: learning rate decrease factor. 285 | lookback: lookback window for early stopping. 286 | training_seed: random seed for training. 287 | validation_seed: random seed for generating validation data. 288 | verbose: verbosity. 289 | ''' 290 | # Set up train dataset. 
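        # Unlike train(), train_data here holds inputs only: targets are produced
        # by original_model inside the training loop (and once up front, via
        # generate_labels, for the validation set).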
291 | if isinstance(train_data, np.ndarray): 292 | train_data = torch.tensor(train_data, dtype=torch.float32) 293 | 294 | if isinstance(train_data, torch.Tensor): 295 | train_set = TensorDataset(train_data) 296 | elif isinstance(train_data, Dataset): 297 | train_set = train_data 298 | else: 299 | raise ValueError('train_data must be either tensor or a ' 300 | 'PyTorch Dataset') 301 | 302 | # Set up train data loader. 303 | random_sampler = RandomSampler( 304 | train_set, replacement=True, 305 | num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size) 306 | batch_sampler = BatchSampler( 307 | random_sampler, batch_size=batch_size, drop_last=True) 308 | train_loader = DataLoader(train_set, batch_sampler=batch_sampler) 309 | 310 | # Set up validation dataset. 311 | sampler = UniformSampler(self.num_players) 312 | if validation_seed is not None: 313 | torch.manual_seed(validation_seed) 314 | S_val = sampler.sample(len(val_data) * validation_samples) 315 | if validation_batch_size is None: 316 | validation_batch_size = batch_size 317 | 318 | if isinstance(val_data, np.ndarray): 319 | val_data = torch.tensor(val_data, dtype=torch.float32) 320 | 321 | if isinstance(val_data, torch.Tensor): 322 | # Generate validation labels. 323 | y_val = generate_labels(TensorDataset(val_data), original_model, 324 | validation_batch_size) 325 | y_val_repeat = y_val.repeat( 326 | validation_samples, *[1 for _ in y_val.shape[1:]]) 327 | 328 | # Create dataset. 329 | val_data_repeat = val_data.repeat(validation_samples, 1) 330 | val_set = TensorDataset(val_data_repeat, y_val_repeat, S_val) 331 | elif isinstance(val_data, Dataset): 332 | # Generate validation labels. 333 | y_val = generate_labels(val_data, original_model, 334 | validation_batch_size) 335 | y_val_repeat = y_val.repeat( 336 | validation_samples, *[1 for _ in y_val.shape[1:]]) 337 | 338 | # Create dataset. 339 | val_set = DatasetRepeat( 340 | [val_data, TensorDataset(y_val_repeat, S_val)]) 341 | else: 342 | raise ValueError('val_data must be either tuple of tensors or a ' 343 | 'PyTorch Dataset') 344 | 345 | val_loader = DataLoader(val_set, batch_size=validation_batch_size) 346 | 347 | # Setup for training. 348 | surrogate = self.surrogate 349 | device = next(surrogate.parameters()).device 350 | optimizer = optim.Adam(surrogate.parameters(), lr=lr) 351 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 352 | optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr, 353 | verbose=verbose) 354 | best_loss = validate(self, loss_fn, val_loader).item() 355 | best_epoch = 0 356 | best_model = deepcopy(surrogate) 357 | loss_list = [best_loss] 358 | if training_seed is not None: 359 | torch.manual_seed(training_seed) 360 | 361 | for epoch in range(max_epochs): 362 | # Batch iterable. 363 | if bar: 364 | batch_iter = tqdm(train_loader, desc='Training epoch') 365 | else: 366 | batch_iter = train_loader 367 | 368 | for (x,) in batch_iter: 369 | # Prepare data. 370 | x = x.to(device) 371 | 372 | # Get original model prediction. 373 | with torch.no_grad(): 374 | y = original_model(x) 375 | 376 | # Generate subsets. 377 | S = sampler.sample(batch_size).to(device=device) 378 | 379 | # Make predictions. 380 | pred = self.__call__(x, S) 381 | loss = loss_fn(pred, y) 382 | 383 | # Optimizer step. 384 | loss.backward() 385 | optimizer.step() 386 | surrogate.zero_grad() 387 | 388 | # Evaluate validation loss. 
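            # Switch to eval mode for validation, then back to train mode; the
            # scheduler step and the early-stopping check below both use this
            # validation loss.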
389 | self.surrogate.eval() 390 | val_loss = validate(self, loss_fn, val_loader).item() 391 | self.surrogate.train() 392 | 393 | # Print progress. 394 | if verbose: 395 | print('----- Epoch = {} -----'.format(epoch + 1)) 396 | print('Val loss = {:.4f}'.format(val_loss)) 397 | print('') 398 | scheduler.step(val_loss) 399 | loss_list.append(val_loss) 400 | 401 | # Check if best model. 402 | if val_loss < best_loss: 403 | best_loss = val_loss 404 | best_model = deepcopy(surrogate) 405 | best_epoch = epoch 406 | if verbose: 407 | print('New best epoch, loss = {:.4f}'.format(val_loss)) 408 | print('') 409 | elif epoch - best_epoch == lookback: 410 | if verbose: 411 | print('Stopping early') 412 | break 413 | 414 | # Clean up. 415 | for param, best_param in zip(surrogate.parameters(), 416 | best_model.parameters()): 417 | param.data = best_param.data 418 | self.loss_list = loss_list 419 | self.surrogate.eval() 420 | 421 | def __call__(self, x, S): 422 | ''' 423 | Evaluate surrogate model. 424 | Args: 425 | x: input examples. 426 | S: coalitions. 427 | ''' 428 | if self.groups_matrix is not None: 429 | S = torch.mm(S, self.groups_matrix) 430 | try: 431 | surr_value=self.surrogate((x, S)) 432 | except: 433 | surr_value=self.surrogate((x, S.long())) 434 | return surr_value 435 | 436 | -------------------------------------------------------------------------------- /weightedSHAP/third_party/image_surrogate.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Reference: https://github.com/iancovert/fastshap/blob/main/fastshap/image_surrogate.py 3 | ''' 4 | import torch 5 | import torch.optim as optim 6 | import numpy as np 7 | from torch.utils.data import Dataset, TensorDataset, DataLoader 8 | from torch.utils.data import RandomSampler, BatchSampler 9 | from .image_imputers import ImageImputer 10 | from .utils import UniformSampler, DatasetRepeat 11 | from tqdm.auto import tqdm 12 | from copy import deepcopy 13 | 14 | def validate(surrogate, loss_fn, data_loader): 15 | ''' 16 | Calculate mean validation loss. 17 | Args: 18 | loss_fn: loss function. 19 | data_loader: data loader. 20 | ''' 21 | with torch.no_grad(): 22 | # Setup. 23 | device = next(surrogate.surrogate.parameters()).device 24 | mean_loss = 0 25 | N = 0 26 | 27 | for x, y, S in data_loader: 28 | x = x.to(device) 29 | y = y.to(device) 30 | S = S.to(device) 31 | pred = surrogate(x, S) 32 | loss = loss_fn(pred, y) 33 | N += len(x) 34 | mean_loss += len(x) * (loss - mean_loss) / N 35 | 36 | return mean_loss 37 | 38 | 39 | def generate_labels(dataset, model, batch_size, num_workers): 40 | ''' 41 | Generate prediction labels for a set of inputs. 42 | Args: 43 | dataset: dataset object. 44 | model: predictive model. 45 | batch_size: minibatch size. 46 | num_workers: number of worker threads. 47 | ''' 48 | with torch.no_grad(): 49 | # Setup. 50 | preds = [] 51 | device = next(model.parameters()).device 52 | loader = DataLoader(dataset, batch_size=batch_size, pin_memory=True, 53 | num_workers=num_workers) 54 | 55 | for (x,) in loader: 56 | pred = model(x.to(device)).cpu() 57 | preds.append(pred) 58 | 59 | return torch.cat(preds) 60 | 61 | 62 | class ImageSurrogate(ImageImputer): 63 | ''' 64 | Wrapper around image surrogate model. 65 | Args: 66 | surrogate: surrogate model (torch.nn.Module). 67 | width: image width. 68 | height: image height. 69 | superpixel_size: superpixel width/height (int). 
70 | ''' 71 | 72 | def __init__(self, surrogate, width, height, superpixel_size): 73 | # Initialize for coalition resizing, number of players. 74 | super().__init__(width, height, superpixel_size) 75 | 76 | # Store surrogate model. 77 | self.surrogate = surrogate 78 | 79 | def train(self, 80 | train_data, 81 | val_data, 82 | batch_size, 83 | max_epochs, 84 | loss_fn, 85 | validation_samples=1, 86 | validation_batch_size=None, 87 | lr=1e-3, 88 | min_lr=1e-5, 89 | lr_factor=0.5, 90 | lookback=5, 91 | training_seed=None, 92 | validation_seed=None, 93 | num_workers=0, 94 | bar=False, 95 | verbose=False): 96 | ''' 97 | Train surrogate model. 98 | Args: 99 | train_data: training data with inputs and the original model's 100 | predictions (np.ndarray tuple, torch.Tensor tuple, 101 | torch.utils.data.Dataset). 102 | val_data: validation data with inputs and the original model's 103 | predictions (np.ndarray tuple, torch.Tensor tuple, 104 | torch.utils.data.Dataset). 105 | batch_size: minibatch size. 106 | max_epochs: max number of training epochs. 107 | loss_fn: loss function (e.g., fastshap.KLDivLoss) 108 | validation_samples: number of samples per validation example. 109 | validation_batch_size: validation minibatch size. 110 | lr: initial learning rate. 111 | min_lr: minimum learning rate. 112 | lr_factor: learning rate decrease factor. 113 | lookback: lookback window for early stopping. 114 | training_seed: random seed for training. 115 | validation_seed: random seed for generating validation data. 116 | num_workers: number of worker threads in data loader. 117 | bar: whether to show progress bar. 118 | verbose: verbosity. 119 | ''' 120 | # Set up train dataset. 121 | if isinstance(train_data, tuple): 122 | x_train, y_train = train_data 123 | if isinstance(x_train, np.ndarray): 124 | x_train = torch.tensor(x_train, dtype=torch.float32) 125 | y_train = torch.tensor(y_train, dtype=torch.float32) 126 | train_set = TensorDataset(x_train, y_train) 127 | elif isinstance(train_data, Dataset): 128 | train_set = train_data 129 | else: 130 | raise ValueError('train_data must be either tuple of tensors or a ' 131 | 'PyTorch Dataset') 132 | 133 | # Set up train data loader. 134 | random_sampler = RandomSampler( 135 | train_set, replacement=True, 136 | num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size) 137 | batch_sampler = BatchSampler( 138 | random_sampler, batch_size=batch_size, drop_last=True) 139 | train_loader = DataLoader(train_set, batch_sampler=batch_sampler, 140 | pin_memory=True, num_workers=num_workers) 141 | 142 | # Set up validation dataset. 
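        # A fixed set of random coalitions S_val is drawn once (reproducibly if
        # validation_seed is given) and paired with the repeated validation
        # examples, so the validation loss is comparable across epochs.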
143 | sampler = UniformSampler(self.num_players) 144 | if validation_seed is not None: 145 | torch.manual_seed(validation_seed) 146 | S_val = sampler.sample(len(val_data) * validation_samples) 147 | 148 | if isinstance(val_data, tuple): 149 | x_val, y_val = val_data 150 | if isinstance(x_val, np.ndarray): 151 | x_val = torch.tensor(x_val, dtype=torch.float32) 152 | y_val = torch.tensor(y_val, dtype=torch.float32) 153 | x_val_repeat = x_val.repeat(validation_samples, 1, 1, 1) 154 | y_val_repeat = y_val.repeat(validation_samples, 1) 155 | val_set = TensorDataset(x_val_repeat, y_val_repeat, S_val) 156 | elif isinstance(val_data, Dataset): 157 | val_set = DatasetRepeat([val_data, TensorDataset(S_val)]) 158 | else: 159 | raise ValueError('val_data must be either tuple of tensors or a ' 160 | 'PyTorch Dataset') 161 | 162 | if validation_batch_size is None: 163 | validation_batch_size = batch_size 164 | val_loader = DataLoader(val_set, batch_size=validation_batch_size, 165 | pin_memory=True, num_workers=num_workers) 166 | 167 | # Setup for training. 168 | surrogate = self.surrogate 169 | device = next(surrogate.parameters()).device 170 | optimizer = optim.Adam(surrogate.parameters(), lr=lr) 171 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 172 | optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr, 173 | verbose=verbose) 174 | best_loss = validate(self, loss_fn, val_loader).item() 175 | best_epoch = 0 176 | best_model = deepcopy(surrogate) 177 | loss_list = [best_loss] 178 | if training_seed is not None: 179 | torch.manual_seed(training_seed) 180 | 181 | for epoch in range(max_epochs): 182 | # Batch iterable. 183 | if bar: 184 | batch_iter = tqdm(train_loader, desc='Training epoch') 185 | else: 186 | batch_iter = train_loader 187 | 188 | for x, y in batch_iter: 189 | # Prepare data. 190 | x = x.to(device) 191 | y = y.to(device) 192 | 193 | # Generate subsets. 194 | S = sampler.sample(batch_size).to(device=device) 195 | 196 | # Make predictions. 197 | pred = self.__call__(x, S) 198 | loss = loss_fn(pred, y) 199 | 200 | # Optimizer step. 201 | loss.backward() 202 | optimizer.step() 203 | surrogate.zero_grad() 204 | 205 | # Evaluate validation loss. 206 | self.surrogate.eval() 207 | val_loss = validate(self, loss_fn, val_loader).item() 208 | self.surrogate.train() 209 | 210 | # Print progress. 211 | if verbose: 212 | print('----- Epoch = {} -----'.format(epoch + 1)) 213 | print('Val loss = {:.4f}'.format(val_loss)) 214 | print('') 215 | scheduler.step(val_loss) 216 | loss_list.append(val_loss) 217 | 218 | # Check if best model. 219 | if val_loss < best_loss: 220 | best_loss = val_loss 221 | best_model = deepcopy(surrogate) 222 | best_epoch = epoch 223 | if verbose: 224 | print('New best epoch, loss = {:.4f}'.format(val_loss)) 225 | print('') 226 | elif epoch - best_epoch == lookback: 227 | if verbose: 228 | print('Stopping early') 229 | break 230 | 231 | # Clean up. 
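        # Copy the best parameters found during training back into the surrogate,
        # keep the validation-loss history, and leave the model in eval mode.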
232 | for param, best_param in zip(surrogate.parameters(), 233 | best_model.parameters()): 234 | param.data = best_param.data 235 | self.loss_list = loss_list 236 | self.surrogate.eval() 237 | 238 | def train_original_model(self, 239 | train_data, 240 | val_data, 241 | original_model, 242 | batch_size, 243 | max_epochs, 244 | loss_fn, 245 | validation_samples=1, 246 | validation_batch_size=None, 247 | lr=1e-3, 248 | min_lr=1e-5, 249 | lr_factor=0.5, 250 | lookback=5, 251 | training_seed=None, 252 | validation_seed=None, 253 | num_workers=0, 254 | bar=False, 255 | verbose=False): 256 | ''' 257 | Train surrogate model with labels provided by the original model. This 258 | approach is designed for when data augmentations make the data loader 259 | non-deterministic, and labels (the original model's predictions) cannot 260 | be generated prior to training. 261 | Args: 262 | train_data: training data with inputs only (np.ndarray, torch.Tensor, 263 | torch.utils.data.Dataset). 264 | val_data: validation data with inputs only (np.ndarray, torch.Tensor, 265 | torch.utils.data.Dataset). 266 | original_model: original predictive model (e.g., torch.nn.Module). 267 | batch_size: minibatch size. 268 | max_epochs: max number of training epochs. 269 | loss_fn: loss function (e.g., fastshap.KLDivLoss) 270 | validation_samples: number of samples per validation example. 271 | validation_batch_size: validation minibatch size. 272 | lr: initial learning rate. 273 | min_lr: minimum learning rate. 274 | lr_factor: learning rate decrease factor. 275 | lookback: lookback window for early stopping. 276 | training_seed: random seed for training. 277 | validation_seed: random seed for generating validation data. 278 | num_workers: number of worker threads in data loader. 279 | bar: whether to show progress bar. 280 | verbose: verbosity. 281 | ''' 282 | # Set up train dataset. 283 | if isinstance(train_data, np.ndarray): 284 | train_data = torch.tensor(train_data, dtype=torch.float32) 285 | 286 | if isinstance(train_data, torch.Tensor): 287 | train_set = TensorDataset(train_data) 288 | elif isinstance(train_data, Dataset): 289 | train_set = train_data 290 | else: 291 | raise ValueError('train_data must be either tensor or a ' 292 | 'PyTorch Dataset') 293 | 294 | # Set up train data loader. 295 | random_sampler = RandomSampler( 296 | train_set, replacement=True, 297 | num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size) 298 | batch_sampler = BatchSampler( 299 | random_sampler, batch_size=batch_size, drop_last=True) 300 | train_loader = DataLoader(train_set, batch_sampler=batch_sampler, 301 | pin_memory=True, num_workers=num_workers) 302 | 303 | # Set up validation dataset. 304 | sampler = UniformSampler(self.num_players) 305 | if validation_seed is not None: 306 | torch.manual_seed(validation_seed) 307 | S_val = sampler.sample(len(val_data) * validation_samples) 308 | if validation_batch_size is None: 309 | validation_batch_size = batch_size 310 | 311 | if isinstance(val_data, np.ndarray): 312 | val_data = torch.tensor(val_data, dtype=torch.float32) 313 | 314 | if isinstance(val_data, torch.Tensor): 315 | # Generate validation labels. 316 | y_val = generate_labels(TensorDataset(val_data), original_model, 317 | validation_batch_size, num_workers) 318 | y_val_repeat = y_val.repeat( 319 | validation_samples, *[1 for _ in y_val.shape[1:]]) 320 | 321 | # Create dataset. 
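            # Repeat each validation image validation_samples times so that every
            # example can be evaluated under validation_samples random coalitions.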
322 | val_data_repeat = val_data.repeat(validation_samples, 1, 1, 1) 323 | val_set = TensorDataset(val_data_repeat, y_val_repeat, S_val) 324 | elif isinstance(val_data, Dataset): 325 | # Generate validation labels. 326 | y_val = generate_labels(val_data, original_model, 327 | validation_batch_size, num_workers) 328 | y_val_repeat = y_val.repeat( 329 | validation_samples, *[1 for _ in y_val.shape[1:]]) 330 | 331 | # Create dataset. 332 | val_set = DatasetRepeat( 333 | [val_data, TensorDataset(y_val_repeat, S_val)]) 334 | else: 335 | raise ValueError('val_data must be either tuple of tensors or a ' 336 | 'PyTorch Dataset') 337 | 338 | val_loader = DataLoader(val_set, batch_size=validation_batch_size, 339 | pin_memory=True, num_workers=num_workers) 340 | 341 | # Setup for training. 342 | surrogate = self.surrogate 343 | device = next(surrogate.parameters()).device 344 | optimizer = optim.Adam(surrogate.parameters(), lr=lr) 345 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 346 | optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr, 347 | verbose=verbose) 348 | best_loss = validate(self, loss_fn, val_loader).item() 349 | best_epoch = 0 350 | best_model = deepcopy(surrogate) 351 | loss_list = [best_loss] 352 | if training_seed is not None: 353 | torch.manual_seed(training_seed) 354 | 355 | for epoch in range(max_epochs): 356 | # Batch iterable. 357 | if bar: 358 | batch_iter = tqdm(train_loader, desc='Training epoch') 359 | else: 360 | batch_iter = train_loader 361 | 362 | for (x,) in batch_iter: 363 | # Prepare data. 364 | x = x.to(device) 365 | 366 | # Get original model prediction. 367 | with torch.no_grad(): 368 | y = original_model(x) 369 | 370 | # Generate subsets. 371 | S = sampler.sample(batch_size).to(device=device) 372 | 373 | # Make predictions. 374 | pred = self.__call__(x, S) 375 | loss = loss_fn(pred, y) 376 | 377 | # Optimizer step. 378 | loss.backward() 379 | optimizer.step() 380 | surrogate.zero_grad() 381 | 382 | # Evaluate validation loss. 383 | self.surrogate.eval() 384 | val_loss = validate(self, loss_fn, val_loader).item() 385 | self.surrogate.train() 386 | 387 | # Print progress. 388 | if verbose: 389 | print('----- Epoch = {} -----'.format(epoch + 1)) 390 | print('Val loss = {:.4f}'.format(val_loss)) 391 | print('') 392 | scheduler.step(val_loss) 393 | loss_list.append(val_loss) 394 | 395 | # Check if best model. 396 | if val_loss < best_loss: 397 | best_loss = val_loss 398 | best_model = deepcopy(surrogate) 399 | best_epoch = epoch 400 | if verbose: 401 | print('New best epoch, loss = {:.4f}'.format(val_loss)) 402 | print('') 403 | elif epoch - best_epoch == lookback: 404 | if verbose: 405 | print('Stopping early') 406 | break 407 | 408 | # Clean up. 409 | for param, best_param in zip(surrogate.parameters(), 410 | best_model.parameters()): 411 | param.data = best_param.data 412 | self.loss_list = loss_list 413 | self.surrogate.eval() 414 | 415 | def __call__(self, x, S): 416 | ''' 417 | Evaluate surrogate model. 418 | Args: 419 | x: input examples. 420 | S: coalitions. 421 | ''' 422 | S = self.resize(S) 423 | return self.surrogate((x, S)) 424 | 425 | 426 | --------------------------------------------------------------------------------
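Usage sketch (illustrative, not a repository file): the tabular `Surrogate` wrapper in `weightedSHAP/third_party/surrogate.py` is normally constructed and trained elsewhere in the package, but the minimal example below shows how it might be trained directly. `MaskedMLP`, the random tensors, and the MSE loss are assumptions made for illustration; the only interface taken from the file above is that the wrapped torch module accepts a tuple `(x, S)` and is fit to match the original model's predictions under randomly sampled coalitions.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from weightedSHAP.third_party.surrogate import Surrogate

class MaskedMLP(nn.Module):
    # Hypothetical surrogate network: removed features are zeroed out by the
    # coalition mask S, and S itself is appended as an extra input.
    def __init__(self, num_features, num_outputs, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * num_features, hidden), nn.ReLU(),
            nn.Linear(hidden, num_outputs))

    def forward(self, inputs):
        x, S = inputs
        return self.net(torch.cat([x * S, S], dim=-1))

num_features, num_outputs = 30, 2

# Random stand-ins: the first tensor in each dataset plays the role of the
# inputs, the second the original model's predictions (the surrogate's targets).
train_set = TensorDataset(torch.randn(512, num_features), torch.rand(512, num_outputs))
val_set = TensorDataset(torch.randn(128, num_features), torch.rand(128, num_outputs))

surrogate = Surrogate(MaskedMLP(num_features, num_outputs), num_features)
surrogate.train(train_set, val_set,
                batch_size=64, max_epochs=5,
                loss_fn=nn.MSELoss(), lookback=3, verbose=True)

In the package itself, the targets are the model-to-explain's predictions, and the trained surrogate is then wrapped into a conditional extension (see the `model_condi_wrapper` and `removal.ConditionalSupervisedExtension` call at the start of this section) before attributions are computed.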