├── .gitignore
├── fig
│   ├── 2d-weightedshap.png
│   ├── mnist-weightedshap.png
│   └── inclusion-weightedshap.png
├── notebook
│   ├── fraud_dataset.pkl
│   └── fraud_example.pickle
├── weightedSHAP
│   ├── third_party
│   │   ├── __init__.py
│   │   ├── behavior.py
│   │   ├── image_imputers.py
│   │   ├── resnet.py
│   │   ├── utils.py
│   │   ├── removal.py
│   │   ├── surrogate.py
│   │   └── image_surrogate.py
│   ├── __init__.py
│   ├── weightedSHAPEngine.py
│   ├── utils.py
│   ├── data.py
│   └── train.py
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.pyc
3 | .DS_Store
4 | .ipynb_checkpoints
5 | /*.ipynb
6 | *.pt
7 | clf_path
8 |
--------------------------------------------------------------------------------
/fig/2d-weightedshap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ykwon0407/WeightedSHAP/HEAD/fig/2d-weightedshap.png
--------------------------------------------------------------------------------
/fig/mnist-weightedshap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ykwon0407/WeightedSHAP/HEAD/fig/mnist-weightedshap.png
--------------------------------------------------------------------------------
/notebook/fraud_dataset.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ykwon0407/WeightedSHAP/HEAD/notebook/fraud_dataset.pkl
--------------------------------------------------------------------------------
/fig/inclusion-weightedshap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ykwon0407/WeightedSHAP/HEAD/fig/inclusion-weightedshap.png
--------------------------------------------------------------------------------
/notebook/fraud_example.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ykwon0407/WeightedSHAP/HEAD/notebook/fraud_example.pickle
--------------------------------------------------------------------------------
/weightedSHAP/third_party/__init__.py:
--------------------------------------------------------------------------------
1 | from . import behavior, image_imputers, image_surrogate, removal, resnet, surrogate, utils
--------------------------------------------------------------------------------
/weightedSHAP/__init__.py:
--------------------------------------------------------------------------------
1 | from . import data, third_party, train, utils, weightedSHAPEngine
2 |
3 | from .data import load_data
4 | from .train import create_model_to_explain, generate_coalition_function
5 | from .weightedSHAPEngine import compute_attributions
--------------------------------------------------------------------------------
/weightedSHAP/third_party/behavior.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference: https://github.com/iancovert/removal-explanations/blob/main/rexplain/behavior.py
3 | '''
4 | import numpy as np
5 |
6 | class PredictionGame:
7 | '''
8 | Cooperative game for an individual example's prediction.
9 |
10 | Args:
11 | extension: model extension (see removal.py).
12 | sample: numpy array representing a single model input.
13 | '''
14 |
15 | def __init__(self, extension, sample, superpixel_size=1):
16 | # Add batch dimension to sample.
17 | if sample.ndim == 1:
18 | sample = sample[np.newaxis]
19 | # elif sample.shape[0] != 1:
20 | # raise ValueError('sample must have shape (ndim,) or (1,ndim)')
21 |
22 | self.extension = extension
23 | self.sample = sample
24 | self.players = np.prod(sample.shape)//(superpixel_size**2)//sample.shape[0] # sample.shape[1]
25 |
26 | # Caching.
27 | self.sample_repeat = sample
28 |
29 | def __call__(self, S):
30 | # Return scalar if single subset.
31 | single_eval = (S.ndim == 1)
32 | if single_eval:
33 | S = S[np.newaxis]
34 | input_data = self.sample
35 | else:
36 | # Try to use caching for repeated data.
37 | if len(S) != len(self.sample_repeat):
38 | self.sample_repeat = self.sample.repeat(len(S), 0)
39 | input_data = self.sample_repeat
40 |
41 | # Evaluate.
42 | output = self.extension(input_data, S)
43 | if single_eval:
44 | output = output[0]
45 | return output
46 |
47 |
--------------------------------------------------------------------------------
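
A minimal sketch of how PredictionGame wraps a model extension, assuming the package is importable. The lambda below is a toy stand-in for a removal.py extension and is not part of the repository.

    import numpy as np
    from weightedSHAP.third_party.behavior import PredictionGame

    # Toy extension f(x, S): sum of the kept features (stand-in for a removal.py extension).
    toy_extension = lambda x, S: (x * S).sum(axis=1)
    sample = np.array([1.0, 2.0, 3.0, 4.0])

    game = PredictionGame(toy_extension, sample)
    print(game.players)                                   # 4 players (features)
    print(game(np.array([True, False, True, False])))     # value of coalition {0, 2} -> 4.0
    print(game(np.zeros((2, 4), dtype=bool)))             # batched evaluation of two empty coalitions
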
/README.md:
--------------------------------------------------------------------------------
1 | # WeightedSHAP: analyzing and improving Shapley based feature attributions
2 |
3 | This repository provides an implementation of the paper *[WeightedSHAP: analyzing and improving Shapley based feature attributions](https://arxiv.org/abs/2209.13429)* accepted at [NeurIPS 2022](https://nips.cc/Conferences/2022). We show the suboptimality of SHAP and propose **a new feature attribution method called WeightedSHAP**. WeightedSHAP generalizes SHAP and is more effective at capturing influential features.
4 |
5 | ### Quick start
6 |
7 | We provide an easy-to-follow [Jupyter notebook](notebook/Example_fraud_inclusion_AUC.ipynb), which shows how to compute WeightedSHAP attributions on the Fraud dataset.
8 |
9 | ### Key results
10 |
11 |
12 |
13 |
14 |
15 | → Illustrations of the suboptimality of Shapley-based feature attributions (SHAP) when $d=2$. ***The Shapley value fails to assign larger attributions to more influential features*** in the grey area.
16 |
17 |
18 |
19 |
20 |
21 | → Illustrations of the prediction recovery error curve and the Inclusion AUC curve as a function of the number of features added. ***WeightedSHAP effectively assigns larger values to more influential features*** and recovers the original prediction $\hat{f}(x)$ significantly faster than other state-of-the-art methods.
22 |
23 |
24 |
25 |
26 |
27 | → ***WeightedSHAP can identify more interpretable features***. In particular, SHAP fails to capture the last stroke of the digit nine, which is crucial for distinguishing it from the digit zero.
28 |
29 |
30 | ### References
31 |
32 | This repository builds heavily on the following two repositories.
33 |
34 | - Covert, I., Lundberg, S. M., & Lee, S. I. (2021). Explaining by Removing: A Unified Framework for Model Explanation. Journal of Machine Learning Research, 22(209), 1-90. [[GitHub]](https://github.com/iancovert/removal-explanations)
35 |
36 | - Jethani, N., Sudarshan, M., Covert, I. C., Lee, S. I., & Ranganath, R. (2021, September). FastSHAP: Real-Time Shapley Value Estimation. In International Conference on Learning Representations. [[GitHub]](https://github.com/iancovert/fastshap/tree/main/fastshap)
37 |
38 | ### Authors
39 |
40 | - Yongchan Kwon (yk3012 (at) columbia (dot) edu)
41 |
42 | - James Zou (jamesz (at) stanford (dot) edu)
43 |
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
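
Condensed below is a sketch of the workflow the notebook walks through. It assumes the package and its dependencies are installed (including a scikit-learn version contemporary with the repository, since weightedSHAP.data imports load_boston at module level) and that it is run from the repository root so that notebook/fraud_dataset.pkl is found. Surrogate training and marginal-contribution estimation can take a while.

    from weightedSHAP import (load_data, create_model_to_explain,
                              generate_coalition_function, compute_attributions)

    problem, dataset, ML_model = 'classification', 'fraud', 'linear'
    (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test) = load_data(
        problem, dataset, dir_path='notebook', random_factor=0)

    # Model to explain and the conditional coalition function (surrogate model).
    model_to_explain = create_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model)
    conditional_extension = generate_coalition_function(model_to_explain, X_train, X_est,
                                                        problem, ML_model)

    # WeightedSHAP attributions for the first few test points.
    exp_dict = compute_attributions(problem, ML_model, model_to_explain, conditional_extension,
                                    X_train, y_train, X_val, y_val, X_test, y_test, n_max=5)
    print(exp_dict['value_list'][0])   # attribution vector for the first test example
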
/weightedSHAP/third_party/image_imputers.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference: https://github.com/iancovert/fastshap/blob/main/fastshap/image_imputers.py
3 | '''
4 | import torch.nn as nn
5 |
6 | class ImageImputer:
7 | '''
8 | Image imputer base class.
9 | Args:
10 | width: image width.
11 | height: image height.
12 | superpixel_size: superpixel width/height (int).
13 | '''
14 |
15 | def __init__(self, width, height, superpixel_size=1):
16 | # Verify arguments.
17 | assert width % superpixel_size == 0
18 | assert height % superpixel_size == 0
19 |
20 | # Set up superpixel upsampling.
21 | self.width = width
22 | self.height = height
23 | self.supsize = superpixel_size
24 | if superpixel_size == 1:
25 | self.upsample = nn.Identity()
26 | else:
27 | self.upsample = nn.Upsample(
28 | scale_factor=superpixel_size, mode='nearest')
29 |
30 | # Set up number of players.
31 | self.small_width = width // superpixel_size
32 | self.small_height = height // superpixel_size
33 | self.num_players = self.small_width * self.small_height
34 |
35 | def __call__(self, x, S):
36 | '''
37 | Evaluate with subset of features.
38 | Args:
39 | x: input examples.
40 | S: coalitions.
41 | '''
42 | raise NotImplementedError
43 |
44 | def resize(self, S):
45 | '''
46 | Resize coalition variable S into grid shape.
47 | Args:
48 | S: coalitions.
49 | '''
50 | if len(S.shape) == 2:
51 | S = S.reshape(S.shape[0], self.small_height,
52 | self.small_width).unsqueeze(1)
53 | return self.upsample(S)
54 |
55 |
56 | class BaselineImageImputer(ImageImputer):
57 | '''
58 | Evaluate image model while replacing features with baseline values.
59 | Args:
60 | model: predictive model.
61 | baseline: baseline value(s).
62 | width: image width.
63 | height: image height.
64 | superpixel_size: superpixel width/height (int).
65 | link: link function (e.g., nn.Softmax).
66 | '''
67 |
68 | def __init__(self, model, baseline, width, height, superpixel_size,
69 | link=None):
70 | super().__init__(width, height, superpixel_size)
71 | self.model = model
72 | self.baseline = baseline
73 |
74 | # Set up link.
75 | if link is None:
76 | self.link = nn.Identity()
77 | elif isinstance(link, nn.Module):
78 | self.link = link
79 | else:
80 | raise ValueError('unsupported link function: {}'.format(link))
81 |
82 | def __call__(self, x, S):
83 | '''
84 | Evaluate model using baseline values.
85 | '''
86 | S = self.resize(S)
87 | x_baseline = S * x + (1 - S) * self.baseline
88 | return self.link(self.model(x_baseline))
89 |
90 |
--------------------------------------------------------------------------------
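
A minimal sketch of BaselineImageImputer on synthetic data; the averaging lambda stands in for a real image model and is not part of the repository.

    import torch
    from weightedSHAP.third_party.image_imputers import BaselineImageImputer

    # Toy "model" that maps a (B, 1, 8, 8) batch to a (B, 1) score.
    toy_model = lambda x: x.mean(dim=(1, 2, 3)).unsqueeze(1)

    imputer = BaselineImageImputer(toy_model, baseline=0.0, width=8, height=8, superpixel_size=2)
    x = torch.rand(5, 1, 8, 8)                     # five 1-channel 8x8 images
    S = torch.ones(5, imputer.num_players)         # keep all 16 superpixels
    print(imputer(x, S).shape)                     # torch.Size([5, 1])
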
/weightedSHAP/third_party/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | ResNet in PyTorch.
3 | This implementation is based on kuangliu's code
4 | https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
5 | and the PyTorch reference implementation
6 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
7 | '''
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | class BasicBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self, in_planes, planes, stride=1):
15 | super(BasicBlock, self).__init__()
16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
20 | stride=1, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.shortcut = nn.Sequential()
24 | if stride != 1 or in_planes != self.expansion*planes:
25 | self.shortcut = nn.Sequential(
26 | nn.Conv2d(in_planes, self.expansion*planes,
27 | kernel_size=1, stride=stride, bias=False),
28 | nn.BatchNorm2d(self.expansion*planes)
29 | )
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = self.bn2(self.conv2(out))
34 | out += self.shortcut(x)
35 | out = F.relu(out)
36 | return out
37 |
38 |
39 | class Bottleneck(nn.Module):
40 | expansion = 4
41 |
42 | def __init__(self, in_planes, planes, stride=1):
43 | super(Bottleneck, self).__init__()
44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
45 | self.bn1 = nn.BatchNorm2d(planes)
46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
47 | stride=stride, padding=1, bias=False)
48 | self.bn2 = nn.BatchNorm2d(planes)
49 | self.conv3 = nn.Conv2d(planes, self.expansion *
50 | planes, kernel_size=1, bias=False)
51 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
52 |
53 | self.shortcut = nn.Sequential()
54 | if stride != 1 or in_planes != self.expansion*planes:
55 | self.shortcut = nn.Sequential(
56 | nn.Conv2d(in_planes, self.expansion*planes,
57 | kernel_size=1, stride=stride, bias=False),
58 | nn.BatchNorm2d(self.expansion*planes)
59 | )
60 |
61 | def forward(self, x):
62 | out = F.relu(self.bn1(self.conv1(x)))
63 | out = F.relu(self.bn2(self.conv2(out)))
64 | out = self.bn3(self.conv3(out))
65 | out += self.shortcut(x)
66 | out = F.relu(out)
67 | return out
68 |
69 |
70 | class ResNet(nn.Module):
71 | def __init__(self, block, num_blocks, num_classes, in_channels):
72 | super(ResNet, self).__init__()
73 | self.in_planes = 64
74 |
75 | # Input conv.
76 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3,
77 | stride=1, padding=1, bias=False)
78 | self.bn1 = nn.BatchNorm2d(64)
79 |
80 | # Residual blocks.
81 | channels = 64
82 | stride = 1
83 | blocks = []
84 | for num in num_blocks:
85 | blocks.append(self._make_layer(block, channels, num, stride=stride))
86 | channels *= 2
87 | stride = 2
88 | self.layers = nn.ModuleList(blocks)
89 |
90 | # Output layer.
91 | self.num_classes = num_classes
92 | if num_classes is not None:
93 | self.linear = nn.Linear(512*block.expansion, num_classes)
94 |
95 | def _make_layer(self, block, planes, num_blocks, stride):
96 | strides = [stride] + [1]*(num_blocks-1)
97 | layers = []
98 | for stride in strides:
99 | layers.append(block(self.in_planes, planes, stride))
100 | self.in_planes = planes * block.expansion
101 | return nn.Sequential(*layers)
102 |
103 | def forward(self, x):
104 | # Input conv.
105 | out = F.relu(self.bn1(self.conv1(x)))
106 |
107 | # Residual blocks.
108 | for layer in self.layers:
109 | out = layer(out)
110 |
111 | # Output layer.
112 | if self.num_classes is not None:
113 | out = F.avg_pool2d(out, 4)
114 | out = out.view(out.size(0), -1)
115 | out = self.linear(out)
116 |
117 | return out
118 |
119 |
120 | def ResNet18(num_classes, in_channels=3):
121 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, in_channels)
122 |
123 |
124 | def ResNet34(num_classes, in_channels=3):
125 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes, in_channels)
126 |
127 |
128 | def ResNet50(num_classes, in_channels=3):
129 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, in_channels)
130 |
--------------------------------------------------------------------------------
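
A quick sketch of a forward pass through the ResNet18 defined above with CIFAR-sized inputs (the 4x4 average pool at the end assumes inputs around 32x32).

    import torch
    from weightedSHAP.third_party.resnet import ResNet18

    model = ResNet18(num_classes=10, in_channels=3)
    x = torch.randn(2, 3, 32, 32)    # two 32x32 RGB images
    print(model(x).shape)            # torch.Size([2, 10])
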
/weightedSHAP/weightedSHAPEngine.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import numpy as np
3 |
4 | # custom modules
5 | from weightedSHAP import data, utils
6 | from weightedSHAP.third_party import behavior
7 |
8 | semivalue_list=[(-1,'LOO-First'), (1,32), (1,16), (1, 8), (1, 4), (1, 2),
9 | (1,1), (2,1), (4,1), (8, 1), (16, 1), (32,1), (-1,'LOO-Last')]
10 | attribution_list=['LOO-First', 'Beta(32,1)', 'Beta(16,1)', 'Beta(8,1)',
11 | 'Beta(4,1)', 'Beta(2,1)', 'Beta(1,1)', 'Beta(1,2)',
12 | 'Beta(1,4)', 'Beta(1,8)', 'Beta(1,16)', 'Beta(1,32)', 'LOO-Last']
13 |
14 | def MarginalContributionValue(game, thresh=1.005, batch_size=1, n_check_period=100):
15 | '''Calculate feature attributions using the marginal contributions.'''
16 |
17 | # index of the added feature, cardinality of set
18 | arange = np.arange(batch_size)
19 | output = game(np.zeros(game.players, dtype=bool))
20 | MC_size=[game.players, game.players] + list(output.shape)
21 | MC_mat=np.zeros(MC_size)
22 | MC_count=np.zeros((game.players, game.players))
23 |
24 | converged = False
25 | n_iter = 0
26 | marginal_contribs=np.zeros([0, np.prod([game.players] + list(output.shape))])
27 | while not converged:
28 | for _ in range(n_check_period):
29 | # Sample permutations.
30 | permutations = np.tile(np.arange(game.players), (batch_size, 1))
31 | for row in permutations:
32 | np.random.shuffle(row)
33 | S = np.zeros((batch_size, game.players), dtype=bool)
34 |
35 | # Unroll permutations.
36 | prev_value = game(S)
37 | marginal_contribs_tmp = np.zeros(([batch_size, game.players] + list(output.shape)))
38 | for j in range(game.players):
39 | '''
40 | Marginal contribution estimates with respect to j samples
41 | j = 0 means LOO-First
42 | '''
43 | S[arange, permutations[:, j]] = 1
44 | next_value = game(S)
45 | MC_mat[permutations[:, j], j] += (next_value - prev_value)
46 | MC_count[permutations[:, j], j] += 1
47 | marginal_contribs_tmp[arange, permutations[:, j]] = (next_value - prev_value)
48 |
49 | # update
50 | prev_value = next_value
51 | marginal_contribs=np.concatenate([marginal_contribs, marginal_contribs_tmp.reshape(batch_size,-1)], axis=0)
52 |
53 | if (n_iter+1) == 100:
54 | converged=True
55 | elif (n_iter+1) >= 2:
56 | if utils.check_convergence(marginal_contribs) < thresh:
57 | print(f'Threshold: {int(0.999*(game.players*n_check_period))}')
58 | converged=True
59 | else:
60 | pass
61 |
62 | n_iter += 1
63 |
64 | print(f'We have seen {((n_iter+1)*n_check_period*batch_size)} random subsets for each feature.')
65 | if len(MC_mat.shape) != 2:
66 | # classification case
67 | MC_count=np.repeat(MC_count, MC_mat.shape[-1], axis=-1).reshape(MC_mat.shape)
68 |
69 | return MC_mat, MC_count
70 |
71 |
72 | def compute_attributions(problem, ML_model,
73 | model_to_explain, conditional_extension,
74 | X_train, y_train, X_val, y_val, X_test, y_test, n_max=100):
75 | '''
76 | Compute attribution values and evaluate their performance
77 | '''
78 | pred_list, pred_masking = [], []
79 | cond_pred_keep_absolute, cond_pred_remove_absolute=[], []
80 | value_list=[]
81 | n_max=min(n_max, len(X_test))
82 | for ind in tqdm(range(n_max)):
83 | # Store original prediction
84 | original_pred=utils.compute_predict(model_to_explain, X_test[ind,:].reshape(1,-1), problem, ML_model)
85 | pred_list.append(original_pred)
86 |
87 | # Estimate marginal contributions
88 | conditional_game=behavior.PredictionGame(conditional_extension, X_test[ind, :])
89 | MC_conditional_mat, MC_conditional_count=MarginalContributionValue(conditional_game)
90 | MC_est=np.array(MC_conditional_mat/(MC_conditional_count+1e-16))
91 |
92 | # Optimize weight for WeightedSHAP (By default, AUP is used)
93 | attribution_dict_all=utils.compute_semivalue_from_MC(MC_est, semivalue_list)
94 | cond_pred_keep_absolute_list=utils.compute_cond_pred_list(attribution_dict_all, conditional_game)
95 | AUP_list=np.sum(np.abs(np.array(cond_pred_keep_absolute_list)- original_pred), axis=1)
96 | WeightedSHAP_index=np.argmin(AUP_list)
97 | value_list.append(attribution_dict_all[attribution_list[WeightedSHAP_index]])
98 |
99 | '''
100 | Evaluation
101 | '''
102 | # Conditional prediction from most important to least important (keep absolute)
103 | cond_pred_keep_absolute_list=utils.compute_cond_pred_list(attribution_dict_all, conditional_game)
104 | cond_pred_keep_absolute.append(cond_pred_keep_absolute_list)
105 |
106 | exp_dict=dict()
107 | exp_dict['value_list']=value_list
108 | exp_dict['true_list']=np.array(y_test)[:n_max]
109 | exp_dict['pred_list']=np.array(pred_list)
110 | exp_dict['input_list']=np.array(X_test)[:n_max]
111 | exp_dict['cond_pred_keep_absolute']=np.array(cond_pred_keep_absolute)
112 |
113 | return exp_dict
114 |
115 |
116 |
117 |
--------------------------------------------------------------------------------
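
The weight selection in compute_attributions reduces to the AUP criterion: among the candidate semivalues, pick the one whose prediction recovery curve (adding features from most to least important) stays closest to the original prediction. A small numeric sketch with synthetic curves:

    import numpy as np

    original_pred = 0.8
    # Conditional predictions after including 0, 1, 2, 3 top features (synthetic numbers),
    # one row per candidate attribution method.
    cond_pred_keep_absolute_list = np.array([[0.5, 0.60, 0.7, 0.8],    # candidate A
                                             [0.5, 0.75, 0.8, 0.8]])   # candidate B
    AUP_list = np.sum(np.abs(cond_pred_keep_absolute_list - original_pred), axis=1)
    print(AUP_list)             # [0.6  0.35]
    print(np.argmin(AUP_list))  # 1 -> candidate B would be chosen as WeightedSHAP
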
/weightedSHAP/third_party/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference: https://github.com/iancovert/fastshap/blob/main/fastshap/utils.py
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | import itertools
8 | from torch.utils.data import Dataset
9 |
10 | class MaskLayer1d(nn.Module):
11 | '''
12 | Masking for 1d inputs.
13 |
14 | Args:
15 | append: whether to append the mask along channels dim.
16 | value: replacement value for held out features.
17 | '''
18 | def __init__(self, append=True, value=0):
19 | super().__init__()
20 | self.append = append
21 | self.value = value
22 |
23 | def forward(self, input_tuple):
24 | x, S = input_tuple
25 | x = x * S + self.value * (1 - S)
26 | if self.append:
27 | x = torch.cat((x, S), dim=1)
28 | return x
29 |
30 | class MaskLayer2d(nn.Module):
31 | '''
32 | Masking for 2d inputs.
33 |
34 | Args:
35 | append: whether to append the mask along channels dim.
36 | value: replacement value for held out features.
37 | '''
38 | def __init__(self, append=True, value=0):
39 | super().__init__()
40 | self.append = append
41 | self.value = value
42 |
43 | def forward(self, input_tuple):
44 | '''
45 | Apply mask to input.
46 |
47 | Args:
48 | input_tuple: tuple of input x and mask S.
49 | '''
50 | x, S = input_tuple
51 | x = x * S + self.value * (1 - S)
52 | if self.append:
53 | x = torch.cat((x, S), dim=1)
54 | return x
55 |
56 | class KLDivLoss(nn.Module):
57 | '''
58 | KL divergence loss that applies log softmax operation to predictions.
59 | Args:
60 | reduction: how to reduce loss value (e.g., 'batchmean').
61 | log_target: whether the target is expected as log probabilities (or as
62 | probabilities).
63 | '''
64 |
65 | def __init__(self, reduction='batchmean', log_target=False):
66 | super().__init__()
67 | self.kld = nn.KLDivLoss(reduction=reduction, log_target=log_target)
68 |
69 | def forward(self, pred, target):
70 | '''
71 | Evaluate loss.
72 | Args:
73 | pred:
74 | target:
75 | '''
76 | return self.kld(pred.log_softmax(dim=1), target)
77 |
78 | class MSELoss(nn.Module):
79 | '''
80 | MSE loss.
81 | Args:
82 | reduction: how to reduce loss value (e.g., 'mean').
83 | Unlike KLDivLoss, this loss does not take a
84 | log_target argument.
85 | '''
86 |
87 | def __init__(self, reduction='mean'):
88 | super().__init__()
89 | self.mseloss = nn.MSELoss(reduction=reduction)
90 |
91 | def forward(self, pred, target):
92 | '''
93 | Evaluate loss.
94 | Args:
95 | pred:
96 | target:
97 | '''
98 | return self.mseloss(pred, target)
99 |
100 | class UniformSampler:
101 | '''
102 | For sampling player subsets with cardinality chosen uniformly at random.
103 | Args:
104 | num_players: number of players.
105 | '''
106 |
107 | def __init__(self, num_players):
108 | self.num_players = num_players
109 |
110 | def sample(self, batch_size):
111 | '''
112 | Generate sample.
113 | Args:
114 | batch_size:
115 | '''
116 | S = torch.ones(batch_size, self.num_players, dtype=torch.float32)
117 | num_included = (torch.rand(batch_size) * (self.num_players + 1)).int()
118 | # TODO ideally avoid for loops
119 | # TODO ideally pass buffer to assign samples in place
120 | for i in range(batch_size):
121 | S[i, num_included[i]:] = 0
122 | S[i] = S[i, torch.randperm(self.num_players)]
123 |
124 | return S
125 |
126 | class DatasetRepeat(Dataset):
127 | '''
128 | A wrapper around multiple datasets that allows repeated elements when the
129 | dataset sizes don't match. The number of elements is the maximum dataset
130 | size, and all datasets must be broadcastable to the same size.
131 | Args:
132 | datasets: list of dataset objects.
133 | '''
134 |
135 | def __init__(self, datasets):
136 | # Get maximum number of elements.
137 | assert np.all([isinstance(dset, Dataset) for dset in datasets])
138 | items = [len(dset) for dset in datasets]
139 | num_items = np.max(items)
140 |
141 | # Ensure all datasets align.
142 | # assert np.all([num_items % num == 0 for num in items])
143 | self.dsets = datasets
144 | self.num_items = num_items
145 | self.items = items
146 |
147 | def __getitem__(self, index):
148 | assert 0 <= index < self.num_items
149 | return_items = [dset[index % num] for dset, num in
150 | zip(self.dsets, self.items)]
151 | return tuple(itertools.chain(*return_items))
152 |
153 | def __len__(self):
154 | return self.num_items
155 |
156 | class DatasetInputOnly(Dataset):
157 | '''
158 | A wrapper around a dataset object to ensure that only the first element is
159 | returned.
160 | Args:
161 | dataset: dataset object.
162 | '''
163 |
164 | def __init__(self, dataset):
165 | assert isinstance(dataset, Dataset)
166 | self.dataset = dataset
167 |
168 | def __getitem__(self, index):
169 | return (self.dataset[index][0],)
170 |
171 | def __len__(self):
172 | return len(self.dataset)
173 |
174 |
--------------------------------------------------------------------------------
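
A minimal sketch of how the masking layer and subset sampler above are combined when training a surrogate; shapes only, no training loop.

    import torch
    from weightedSHAP.third_party.utils import MaskLayer1d, UniformSampler

    x = torch.randn(4, 6)                        # batch of 4 examples with 6 features
    S = UniformSampler(num_players=6).sample(4)  # one random coalition per example
    masked = MaskLayer1d(append=True, value=0)((x, S))
    print(masked.shape)                          # torch.Size([4, 12]): masked features + mask
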
/weightedSHAP/utils.py:
--------------------------------------------------------------------------------
1 | import os, sys, inspect, pickle
2 | import numpy as np
3 | from weightedSHAP import train
4 |
5 | def crossentropyloss(pred, target):
6 | '''Cross entropy loss that does not average across samples.'''
7 | if pred.ndim == 1:
8 | pred = pred[:, np.newaxis]
9 | pred = np.concatenate((1 - pred, pred), axis=1)
10 |
11 | if pred.shape == target.shape:
12 | # Soft cross entropy loss.
13 | pred = np.clip(pred, a_min=1e-12, a_max=1-1e-12)
14 | return - np.sum(np.log(pred) * target, axis=1)
15 | else:
16 | # Standard cross entropy loss.
17 | return - np.log(pred[np.arange(len(pred)), target])
18 |
19 | def mseloss(pred, target):
20 | '''MSE loss that does not average across samples.'''
21 | return np.sum((pred - target) ** 2, axis=1)
22 |
23 | def beta_constant(a, b):
24 | '''
25 | the second argument (b; beta) should be an integer in this function
26 | '''
27 | beta_fct_value=1/a
28 | for i in range(1,b):
29 | beta_fct_value=beta_fct_value*(i/(a+i))
30 | return beta_fct_value
31 |
32 | def compute_weight_list(m, alpha=1, beta=1):
33 | '''
34 | Given a prior distribution (beta distribution (alpha,beta))
35 | beta_constant(j+1, m-j) = j! (m-j-1)! / (m-1)! / m # which is exactly the Shapley weights.
36 |
37 | # weight_list[n] is a weight when baseline model uses 'n' samples (w^{(n)}(j)*binom{n-1}{j} in the paper).
38 | '''
39 | weight_list=np.zeros(m)
40 | normalizing_constant=1/beta_constant(alpha, beta)
41 | for j in np.arange(m):
42 | # when the cardinality of random sets is j
43 | weight_list[j]=beta_constant(j+alpha, m-j+beta-1)/beta_constant(j+1, m-j)
44 | weight_list[j]=normalizing_constant*weight_list[j] # we need this '/m' but omit for stability # normalizing
45 | return weight_list/np.sum(weight_list)
46 |
47 | def compute_semivalue_from_MC(marginal_contrib, semivalue_list):
48 | '''
49 | With the marginal contribution values, it computes semivalues
50 |
51 | '''
52 | semivalue_dict={}
53 | n_elements=marginal_contrib.shape[0]
54 | for weight in semivalue_list:
55 | alpha, beta=weight
56 | if alpha > 0:
57 | model_name=f'Beta({beta},{alpha})'
58 | weight_list=compute_weight_list(m=n_elements, alpha=alpha, beta=beta)
59 | else:
60 | if beta == 'LOO-First':
61 | model_name='LOO-First'
62 | weight_list=np.zeros(n_elements)
63 | weight_list[0]=1
64 | elif beta == 'LOO-Last':
65 | model_name='LOO-Last'
66 | weight_list=np.zeros(n_elements)
67 | weight_list[-1]=1
68 |
69 | if len(marginal_contrib.shape) == 2:
70 | semivalue_tmp=np.einsum('ij,j->i', marginal_contrib, weight_list)
71 | else:
72 | # classification case
73 | semivalue_tmp=np.einsum('ijk,j->ik', marginal_contrib, weight_list)
74 | semivalue_dict[model_name]=semivalue_tmp
75 | return semivalue_dict
76 |
77 | def check_convergence(mem, n_require=100):
78 | """
79 | Compute Gelman-Rubin statistic
80 | Ref. https://arxiv.org/pdf/1812.09384.pdf (p.7, Eq.4)
81 | """
82 | if len(mem) < n_require:
83 | return 100
84 | n_chains=10
85 | (N,n_to_be_valued)=mem.shape
86 | if (N % n_chains) == 0:
87 | n_MC_sample=N//n_chains
88 | offset=0
89 | else:
90 | n_MC_sample=N//n_chains
91 | offset=(N%n_chains)
92 | mem=mem[offset:]
93 | percent=25
94 | while True:
95 | IQR_constant=np.percentile(mem.reshape(-1), 50+percent) - np.percentile(mem.reshape(-1), 50-percent)
96 | if IQR_constant == 0:
97 | percent=(50+percent)//2
98 | if percent >= 49:
99 | assert False, 'CHECK!!! IQR is zero!!!'
100 | else:
101 | break
102 |
103 | mem_tmp=mem.reshape(n_chains, n_MC_sample, n_to_be_valued)
104 | GR_list=[]
105 | for j in range(n_to_be_valued):
106 | mem_tmp_j_original=mem_tmp[:,:,j].T # now we have (n_MC_sample, n_chains)
107 | mem_tmp_j=mem_tmp_j_original/IQR_constant
108 | mem_tmp_j_mean=np.mean(mem_tmp_j, axis=0)
109 | s_term=np.sum((mem_tmp_j-mem_tmp_j_mean)**2)/(n_chains*(n_MC_sample-1)) # + 1e-16 this could lead to wrong estimator
110 | if s_term == 0:
111 | continue
112 | mu_hat_j=np.mean(mem_tmp_j)
113 | B_term=n_MC_sample*np.sum((mem_tmp_j_mean-mu_hat_j)**2)/(n_chains-1)
114 |
115 | GR_stat=np.sqrt((n_MC_sample-1)/n_MC_sample + B_term/(s_term*n_MC_sample))
116 | GR_list.append(GR_stat)
117 | GR_stat=np.max(GR_list)
118 | print(f'Total number of random sets: {len(mem)}, GR_stat: {GR_stat}', flush=True)
119 | return GR_stat
120 |
121 | def compute_cond_pred_list(attribution_dict, game, more_important_first=True):
122 | n_features=game.players
123 | n_max_features=n_features # min(n_features, 200)
124 |
125 | cond_pred_list=[]
126 | for method in attribution_dict.keys():
127 | cond_pred_list_tmp=[]
128 | if more_important_first is True:
129 | # more important to less important (large to zero)
130 | sorted_index=np.argsort(np.abs(attribution_dict[method]))[::-1]
131 | else:
132 | # less important to more important (zero to large)
133 | sorted_index=np.argsort(np.abs(attribution_dict[method]))
134 |
135 | for n_top in range(n_max_features+1):
136 | top_index=sorted_index[:n_top]
137 | S=np.zeros(n_features, dtype=bool)
138 | S[top_index]=True
139 |
140 | # prediction recovery error
141 | cond_pred_list_tmp.append(game(S))
142 | cond_pred_list.append(cond_pred_list_tmp)
143 |
144 | return cond_pred_list
145 |
146 | def compute_pred_maksing_list(attribution_dict, model_to_explain, x, problem, ML_model, more_important_first=True):
147 | n_features=x.shape[1]
148 | n_max_features=n_features # min(n_features, 200)
149 |
150 | pred_masking_list=[]
151 | for method in attribution_dict.keys():
152 | pred_masking_list_tmp=[]
153 | if more_important_first is True:
154 | # more important to less important (large to zero)
155 | sorted_index=np.argsort(np.abs(attribution_dict[method]))[::-1]
156 | else:
157 | # less important to more important (zero to large)
158 | sorted_index=np.argsort(np.abs(attribution_dict[method]))
159 |
160 | for n_top in range(n_max_features+1):
161 | top_index=sorted_index[:n_top]
162 | curr_x=np.zeros((1,n_features)) # Input matrix is standardized
163 | curr_x[0, top_index] = x[0, top_index]
164 |
165 | # prediction recovery error
166 | curr_pred=compute_predict(model_to_explain, curr_x, problem, ML_model)
167 | pred_masking_list_tmp.append(curr_pred)
168 | pred_masking_list.append(pred_masking_list_tmp)
169 |
170 | return pred_masking_list
171 |
172 | def compute_predict(model_to_explain, x, problem, ML_model):
173 | if (ML_model == 'linear') and (problem == 'classification'):
174 | return float(model_to_explain.predict_proba(x)[:,1])
175 | else:
176 | return float(model_to_explain.predict(x))
177 |
178 |
--------------------------------------------------------------------------------
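
A worked sketch of the semivalue helpers above on a synthetic 3x3 marginal-contribution matrix; the numbers are purely illustrative.

    import numpy as np
    from weightedSHAP.utils import compute_weight_list, compute_semivalue_from_MC

    # Beta(1,1) recovers the Shapley weighting: each coalition size counts equally.
    print(compute_weight_list(m=3, alpha=1, beta=1))     # [0.333... 0.333... 0.333...]

    # MC_est[i, j]: estimated marginal contribution of feature i when j features precede it.
    MC_est = np.array([[0.5, 0.3, 0.1],
                       [0.2, 0.2, 0.2],
                       [0.0, 0.1, 0.4]])
    semivalues = compute_semivalue_from_MC(MC_est, [(1, 1), (1, 16), (-1, 'LOO-First')])
    print(semivalues['Beta(1,1)'])    # Shapley-style average over coalition sizes
    print(semivalues['Beta(16,1)'])   # puts most weight on small coalitions
    print(semivalues['LOO-First'])    # uses only the j=0 column: [0.5 0.2 0. ]
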
/weightedSHAP/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pickle
4 | from sklearn.model_selection import train_test_split
5 | from sklearn.datasets import load_boston
6 | from sklearn.preprocessing import StandardScaler
7 |
8 | def load_data(problem, dataset, dir_path, random_factor=2, random_seed=2022):
9 | '''
10 | Load a dataset
11 | We split datasets as Train:Val:Est:Test=7:1:1:1
12 | Train: to train a model
13 | Val: to optimize hyperparameters
14 | Est: to estimate coalition functions
15 | Test: to evaluate performance
16 | '''
17 | print('-'*30)
18 | print('Load a dataset')
19 | print('-'*30)
20 |
21 | if problem=='regression':
22 | (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test)=load_regression_dataset(dataset=dataset, dir_path=dir_path, rid=random_seed)
23 | elif problem=='classification':
24 | (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test)=load_classification_dataset(dataset=dataset, dir_path=dir_path, rid=random_seed)
25 | else:
26 | raise NotImplementedError('Check problem')
27 |
28 | if random_factor != 0:
29 | # We add noisy features to the original dataset.
30 | print('-'*30)
31 | print('Before adding noise')
32 | print(f'Shape of X_train, X_val, X_est, X_test: {X_train.shape}, {X_val.shape}, {X_est.shape}, {X_test.shape}')
33 | print('-'*30)
34 | dim_noise=int(X_train.shape[1]*random_factor)
35 | X_train=extend_dataset(X_train, dim_noise, verbose=True)
36 | X_val=extend_dataset(X_val, dim_noise)
37 | X_est=extend_dataset(X_est, dim_noise)
38 | X_test=extend_dataset(X_test, dim_noise)
39 | print('After adding noise')
40 | print(f'Shape of X_train, X_val, X_est, X_test: {X_train.shape}, {X_val.shape}, {X_est.shape}, {X_test.shape}')
41 | print('-'*30)
42 | else:
43 | # We use the original dataset.
44 | print('-'*30)
45 | print(f'Shape of X_train, X_val, X_est, X_test: {X_train.shape}, {X_val.shape}, {X_est.shape}, {X_test.shape}')
46 | print('-'*30)
47 |
48 | return (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test)
49 |
50 | def load_regression_dataset(dataset='abalone', dir_path='dir_path', rid=1):
51 | '''
52 | This function loads regression datasets.
53 | dir_path: path to regression datasets.
54 |
55 | You may need to download datasets first. Make sure to store in 'dir_path'.
56 | The datasets are available at the following links.
57 | abalone: https://archive.ics.uci.edu/ml/machine-learning-databases/abalone/
58 | airfoil: https://archive.ics.uci.edu/ml/machine-learning-databases/00291/
59 | whitewine: https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/
60 | '''
61 | np.random.seed(rid)
62 |
63 | if dataset == 'boston':
64 | print('-'*50)
65 | print('Boston')
66 | print('-'*50)
67 | data=load_boston()
68 | X, y=data['data'], data['target']
69 | elif dataset == 'abalone':
70 | print('-'*50)
71 | print('Abalone')
72 | print('-'*50)
73 | raw_data = pd.read_csv(dir_path+"/abalone.data", header=None)
74 | raw_data.dropna(inplace=True)
75 | X, y = pd.get_dummies(raw_data.iloc[:,:-1],drop_first=True).values, raw_data.iloc[:,-1].values
76 | elif dataset == 'whitewine':
77 | print('-'*50)
78 | print('whitewine')
79 | print('-'*50)
80 | raw_data = pd.read_csv(dir_path+"/winequality-white.csv",delimiter=";")
81 | raw_data.dropna(inplace=True)
82 | X, y = raw_data.values[:,:-1], raw_data.values[:,-1]
83 | elif dataset == 'airfoil':
84 | print('-'*50)
85 | print('airfoil')
86 | print('-'*50)
87 | raw_data = pd.read_csv(dir_path+"/airfoil_self_noise.dat", sep='\t', names=['X1','X2','X3','X4','X5','Y'])
88 | X, y = raw_data.values[:,:-1], raw_data.values[:,-1]
89 | else:
90 | raise NotImplementedError(f'Check {dataset}')
91 |
92 | X = standardize_data(X)
93 | X_train, X_test, y_train, y_test=train_test_split(X, y, test_size=0.1)
94 | X_train, X_val, y_train, y_val=train_test_split(X_train, y_train, test_size=float(1/9))
95 | X_train, X_est, y_train, y_est=train_test_split(X_train, y_train, test_size=float(1/8))
96 |
97 | return (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test)
98 |
99 | def load_classification_dataset(dataset='gaussian', dir_path='dir_path', rid=1):
100 | '''
101 | This function loads classification datasets.
102 | dir_path: path to classification datasets.
103 | '''
104 | np.random.seed(rid)
105 |
106 | if dataset == 'gaussian':
107 | print('-'*50)
108 | print('Gaussian')
109 | print('-'*50)
110 | n, input_dim, rho=10000, 10, 0.25
111 | U_cov=np.diag((1-rho)*np.ones(input_dim))+rho
112 | U_mean=np.zeros(input_dim)
113 | X=np.random.multivariate_normal(U_mean, U_cov, n)
114 |
115 | beta_true=(np.linspace(input_dim,(41*input_dim/50),input_dim)/input_dim).reshape(input_dim,1)
116 | p_true=np.exp(X.dot(beta_true))/(1.+np.exp(X.dot(beta_true)))
117 | y=np.random.binomial(n=1, p=p_true).reshape(-1)
118 | elif dataset == 'fraud':
119 | print('-'*50)
120 | print('Fraud Detection')
121 | print('-'*50)
122 | data_dict=pickle.load(open(f'{dir_path}/fraud_dataset.pkl', 'rb'))
123 | data, target = data_dict['X_num'], data_dict['y']
124 | target = (target == 1) + 0.0
125 | target = target.astype(np.int32)
126 | X, y=make_balance_sample(data, target)
127 | else:
128 | raise NotImplementedError(f'Check {dataset}')
129 |
130 | X = standardize_data(X)
131 | X_train, X_test, y_train, y_test=train_test_split(X, y, test_size=0.1)
132 | X_train, X_val, y_train, y_val=train_test_split(X_train, y_train, test_size=float(1/9))
133 | X_train, X_est, y_train, y_est=train_test_split(X_train, y_train, test_size=float(1/8))
134 |
135 | return (X_train, y_train), (X_val, y_val), (X_est, y_est), (X_test, y_test)
136 |
137 |
138 | '''
139 | Data utils
140 | '''
141 |
142 | def extend_dataset(X, d_to_add, verbose=False):
143 | n, d_prev = X.shape
144 | rho=(np.sum((X.T.dot(X)/n)[:d_prev,:d_prev])-d_prev)/(d_prev*(d_prev-1))
145 |
146 | if -1/(4*(d_prev-1)+1e-16) > rho:
147 | if verbose is True:
148 | # if rho is too small, the sigma_square defined below can be negative
149 | print(f'Initial rho: {rho:.4f}')
150 | rho=max(-1/(4*(d_prev-1)+1e-16), rho)
151 | print(f'After fixing rho: {rho:.4f}')
152 | else:
153 | if verbose is True:
154 | print(f'Rho: {rho:.4f}')
155 |
156 | for _ in range(d_to_add):
157 | sigma_square=1-(rho**2)*(d_prev)/(1+rho*(d_prev-1)+1e-16)
158 | new_X = (rho/(1+rho*(d_prev-1)+1e-16))*X.dot(np.ones((d_prev,1)))
159 | new_X += np.random.normal(size=(n,1))*np.sqrt(sigma_square)
160 | X = np.concatenate((X, new_X), axis=1)
161 | d_prev += 1
162 | return X
163 |
164 | def make_balance_sample(data, target):
165 | p = np.mean(target)
166 | minor_class=1 if p < 0.5 else 0
167 |
168 | index_minor_class = np.where(target == minor_class)[0]
169 | n_minor_class=len(index_minor_class)
170 | n_major_class=len(target)-n_minor_class
171 | new_minor=np.random.choice(index_minor_class, size=n_major_class-n_minor_class, replace=True)
172 |
173 | data=np.concatenate([data, data[new_minor]])
174 | target=np.concatenate([target, target[new_minor]])
175 | return data, target
176 |
177 | def standardize_data(X):
178 | ss=StandardScaler()
179 | ss.fit(X)
180 | try:
181 | X = ss.transform(X.values)
182 | except:
183 | X = ss.transform(X)
184 | return X
185 |
186 |
187 |
188 |
--------------------------------------------------------------------------------
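
A small sketch of the noisy-feature augmentation used by load_data, assuming weightedSHAP.data is importable (note it imports load_boston from scikit-learn at module level, so a scikit-learn version that still provides it is required). The toy matrix is synthetic.

    import numpy as np
    from weightedSHAP.data import extend_dataset

    np.random.seed(0)
    X = np.random.normal(size=(100, 5))
    X_ext = extend_dataset(X, d_to_add=2, verbose=True)   # appends 2 correlated noise columns
    print(X.shape, X_ext.shape)                           # (100, 5) (100, 7)
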
/weightedSHAP/third_party/removal.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference: https://github.com/iancovert/removal-explanations/blob/main/rexplain/removal.py
3 | '''
4 | import numpy as np
5 |
6 | class MarginalExtension:
7 | '''Extend a model by marginalizing out removed features using their
8 | marginal distribution.'''
9 | def __init__(self, data, model):
10 | self.model = model
11 | self.data = data
12 | self.data_repeat = data
13 | self.samples = len(data)
14 | # self.x_addr = None
15 | # self.x_repeat = None
16 |
17 | def __call__(self, x, S):
18 | # Prepare x and S.
19 | n = len(x)
20 | x = x.repeat(self.samples, 0)
21 | S = S.repeat(self.samples, 0)
22 |
23 | # Prepare samples.
24 | if len(self.data_repeat) != self.samples * n:
25 | self.data_repeat = np.tile(self.data, (n, 1))
26 |
27 | # Replace specified indices.
28 | x_ = x.copy()
29 | x_[~S] = self.data_repeat[~S]
30 |
31 | # Make predictions.
32 | pred = self.model(x_)
33 | pred = pred.reshape(-1, self.samples, *pred.shape[1:])
34 | return np.mean(pred, axis=1)
35 |
36 |
37 | class ConditionalSupervisedExtension:
38 | '''Extend a model using a supervised surrogate model.'''
39 | def __init__(self, surrogate):
40 | self.surrogate = surrogate
41 |
42 | def __call__(self, x, S):
43 | return self.surrogate(x, S)
44 |
45 | class DefaultExtension:
46 | '''Extend a model by replacing removed features with default values.'''
47 | def __init__(self, values, model):
48 | self.model = model
49 | if values.ndim == 1:
50 | values = values[np.newaxis]
51 | elif values.shape[0] != 1:
52 | raise ValueError('values shape must be (dim,) or (1, dim)')
53 | self.values = values
54 | self.values_repeat = values
55 |
56 | def __call__(self, x, S):
57 | # Prepare x.
58 | if len(x) != len(self.values_repeat):
59 | self.values_repeat = self.values.repeat(len(x), 0)
60 |
61 | # Replace specified indices.
62 | x_ = x.copy()
63 | x_[~S] = self.values_repeat[~S]
64 |
65 | # Make predictions.
66 | return self.model(x_)
67 |
68 | class MarginalExtensionApprox:
69 | '''Extend a model by approximating the marginalization of removed
70 | features with a first-order expansion around the data mean.'''
71 | def __init__(self, data_mean, model, grad_array):
72 | self.model = model
73 | self.data_mean = data_mean
74 | self.grad=grad_array
75 |
76 | def __call__(self, x, S):
77 | # Prepare samples.
78 | n=len(x)
79 | if len(self.data_mean) != n:
80 | self.data_repeat = np.tile(self.data_mean, (n, 1))
81 |
82 | # Replace specified indices.
83 | x_ = x.copy()
84 | x_[~S] = self.data_repeat[~S]
85 |
86 | # Make predictions.
87 | pred = self.model(x)
88 | pred += (x_-x).dot(self.grad)
89 | pred = pred.reshape(-1, 1, *pred.shape[1:])
90 | return np.mean(pred, axis=1)
91 |
92 | class UniformExtension:
93 | '''Extend a model by marginalizing out removed features using a
94 | uniform distribution.'''
95 | def __init__(self, values, categorical_inds, samples, model):
96 | self.model = model
97 | self.values = values
98 | self.categorical_inds = categorical_inds
99 | self.samples = samples
100 |
101 | def __call__(self, x, S):
102 | # Prepare x and S.
103 | n = len(x)
104 | x = x.repeat(self.samples, 0)
105 | S = S.repeat(self.samples, 0)
106 |
107 | # Prepare samples.
108 | samples = np.zeros((n * self.samples, x.shape[1]))
109 | for i in range(x.shape[1]):
110 | if i in self.categorical_inds:
111 | inds = np.random.choice(
112 | len(self.values[i]), n * self.samples)
113 | samples[:, i] = self.values[i][inds]
114 | else:
115 | samples[:, i] = np.random.uniform(
116 | low=self.values[i][0], high=self.values[i][1],
117 | size=n * self.samples)
118 |
119 | # Replace specified indices.
120 | x_ = x.copy()
121 | x_[~S] = samples[~S]
122 |
123 | # Make predictions.
124 | pred = self.model(x_)
125 | pred = pred.reshape(-1, self.samples, *pred.shape[1:])
126 | return np.mean(pred, axis=1)
127 |
128 |
129 | class UniformContinuousExtension:
130 | '''
131 | Extend a model by marginalizing out removed features using a
132 | uniform distribution. Specific to sets of continuous features.
133 |
134 | TODO: should we have caching here for repeating x?
135 |
136 | '''
137 | def __init__(self, min_vals, max_vals, samples, model):
138 | self.model = model
139 | self.min = min_vals
140 | self.max = max_vals
141 | self.samples = samples
142 |
143 | def __call__(self, x, S):
144 | # Prepare x and S.
145 | x = x.repeat(self.samples, 0)
146 | S = S.repeat(self.samples, 0)
147 |
148 | # Prepare samples.
149 | u = np.random.uniform(size=x.shape)
150 | samples = u * self.min + (1 - u) * self.max
151 |
152 | # Replace specified indices.
153 | x_ = x.copy()
154 | x_[~S] = samples[~S]
155 |
156 | # Make predictions.
157 | pred = self.model(x_)
158 | pred = pred.reshape(-1, self.samples, *pred.shape[1:])
159 | return np.mean(pred, axis=1)
160 |
161 |
162 | class ProductMarginalExtension:
163 | '''Extend a model by marginalizing out removed features using the
164 | product of their marginal distributions.'''
165 | def __init__(self, data, samples, model):
166 | self.model = model
167 | self.data = data
168 | self.data_repeat = data
169 | self.samples = samples
170 |
171 | def __call__(self, x, S):
172 | # Prepare x and S.
173 | n = len(x)
174 | x = x.repeat(self.samples, 0)
175 | S = S.repeat(self.samples, 0)
176 |
177 | # Prepare samples.
178 | samples = np.zeros((n * self.samples, x.shape[1]))
179 | for i in range(x.shape[1]):
180 | inds = np.random.choice(len(self.data), n * self.samples)
181 | samples[:, i] = self.data[inds, i]
182 |
183 | # Replace specified indices.
184 | x_ = x.copy()
185 | x_[~S] = samples[~S]
186 |
187 | # Make predictions.
188 | pred = self.model(x_)
189 | pred = pred.reshape(-1, self.samples, *pred.shape[1:])
190 | return np.mean(pred, axis=1)
191 |
192 |
193 | class SeparateModelExtension:
194 | '''Extend a model using separate models for each subset of features.'''
195 | def __init__(self, model_dict):
196 | self.model_dict = model_dict
197 |
198 | def __call__(self, x, S):
199 | output = []
200 | for i in range(len(S)):
201 | # Extract model.
202 | row = S[i]
203 | model = self.model_dict[str(row)]
204 |
205 | # Make prediction.
206 | output.append(model(x[i:i+1, row]))
207 |
208 | return np.concatenate(output, axis=0)
209 |
210 |
211 | class ConditionalExtension:
212 | '''Extend a model by marginalizing out removed features using a model of
213 | their conditional distribution.'''
214 | def __init__(self, conditional_model, samples, model):
215 | self.model = model
216 | self.conditional_model = conditional_model
217 | self.samples = samples
218 | self.x_addr = None
219 | self.x_repeat = None
220 |
221 | def __call__(self, x, S):
222 | # Prepare x.
223 | if self.x_addr != id(x):
224 | self.x_addr = id(x)
225 | self.x_repeat = x.repeat(self.samples, 0)
226 | x = self.x_repeat
227 |
228 | # Prepare samples.
229 | S = S.repeat(self.samples, 0)
230 | samples = self.conditional_model(x, S)
231 |
232 | # Replace specified indices.
233 | x_ = x.copy()
234 | x_[~S] = samples[~S]
235 |
236 | # Make predictions.
237 | pred = self.model(x_)
238 | pred = pred.reshape(-1, self.samples, *pred.shape[1:])
239 | return np.mean(pred, axis=1)
240 |
241 |
242 |
243 |
--------------------------------------------------------------------------------
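
A minimal sketch of MarginalExtension, the simplest of the extensions above; the linear lambda and the background sample are synthetic stand-ins, not part of the repository.

    import numpy as np
    from weightedSHAP.third_party.removal import MarginalExtension

    background = np.random.normal(size=(50, 3))           # background samples for marginalization
    toy_model = lambda x: x @ np.array([1.0, -2.0, 0.5])  # stand-in predictive model
    extension = MarginalExtension(background, toy_model)

    x = np.array([[0.2, 0.4, 0.6]])
    S = np.array([[True, False, True]])   # feature 1 is removed
    print(extension(x, S))                # prediction averaged over background values of feature 1
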
/weightedSHAP/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from time import time
3 | import torch
4 | import torch.nn as nn
5 | from scipy.stats import ttest_ind
6 |
7 | from weightedSHAP.third_party import removal, surrogate
8 | from weightedSHAP.third_party.utils import MaskLayer1d, KLDivLoss, MSELoss
9 |
10 | def create_boosting_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model):
11 | print('Train a model to explain: Boosting')
12 | import lightgbm as lgb
13 | if problem == 'classification':
14 | params = {
15 | "learning_rate": 0.005,
16 | "objective": "binary",
17 | "metric": ["binary_logloss", "binary_error"],
18 | "num_threads": 1,
19 | "verbose": -1,
20 | "num_leaves":15,
21 | "bagging_fraction":0.5,
22 | "feature_fraction":0.5,
23 | "lambda_l2": 1e-3
24 | }
25 | else:
26 | params = {
27 | "learning_rate": 0.005,
28 | "objective": "mean_squared_error",
29 | "metric": "mean_squared_error",
30 | "num_threads": 1,
31 | "verbose": -1,
32 | "num_leaves":15,
33 | "bagging_fraction":0.5,
34 | "feature_fraction":0.5,
35 | "lambda_l2": 1e-3
36 | }
37 |
38 | d_train = lgb.Dataset(X_train, label=y_train)
39 | d_val = lgb.Dataset(X_val, label=y_val)
40 |
41 | callbacks=[lgb.early_stopping(25)]
42 | model_to_explain=lgb.train(params, d_train,
43 | num_boost_round=1000,
44 | valid_sets=[d_val],
45 | callbacks=callbacks)
46 |
47 | return model_to_explain
48 |
49 | def create_linear_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model):
50 | print('Train a model to explain: Linear')
51 | from sklearn.linear_model import LinearRegression, LogisticRegression
52 | if problem == 'classification':
53 | model_to_explain=LogisticRegression()
54 | model_to_explain.fit(X_train, y_train)
55 | else:
56 | model_to_explain=LinearRegression()
57 | model_to_explain.fit(X_train, y_train)
58 |
59 | return model_to_explain
60 |
61 | def create_MLP_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model):
62 | print('Train a model to explain: MLP')
63 | device = torch.device('cpu')
64 | num_features=X_train.shape[1]
65 | n_output=2 if problem=='classification' else 1
66 | model_to_explain = nn.Sequential(
67 | nn.Linear(num_features, 128),
68 | nn.ReLU(inplace=True),
69 | nn.Linear(128, 128),
70 | nn.ReLU(inplace=True),
71 | nn.Linear(128, n_output)).to(device)
72 |
73 | # training part
74 | return model_to_explain
75 |
76 | def check_overfitting(X_train, y_train, X_val, y_val, model_to_explain, problem, ML_model):
77 | '''
78 | Check overfitting
79 | '''
80 | if problem == 'classification':
81 | tmp_err=lambda y1, y2: ((y1 > 0.5) != y2) + 0.0
82 | else:
83 | tmp_err=lambda y1, y2: (y1-y2)**2
84 |
85 | if (ML_model == 'linear') and (problem == 'classification'):
86 | tr_pred_error=tmp_err(model_to_explain.predict_proba(X_train)[:,1], y_train)
87 | val_pred_error=tmp_err(model_to_explain.predict_proba(X_val)[:,1], y_val)
88 | elif ML_model == 'MLP':
89 | # not used
90 | y_train_pred=model_to_explain(torch.from_numpy(X_train.astype(np.float32))).detach().numpy().reshape(-1)
91 | y_val_pred=model_to_explain(torch.from_numpy(X_val.astype(np.float32))).detach().numpy().reshape(-1)
92 | tr_pred_error=tmp_err(y_train_pred, y_train)
93 | val_pred_error=tmp_err(y_val_pred, y_val)
94 | else:
95 | tr_pred_error=tmp_err(model_to_explain.predict(X_train), y_train)
96 | val_pred_error=tmp_err(model_to_explain.predict(X_val), y_val)
97 |
98 | p_value=ttest_ind(tr_pred_error, val_pred_error)[1]
99 | overfitting_check='Not overfitted' if p_value > 0.01 else 'Overfitted'
100 |
101 | tr_err, val_err = np.mean(tr_pred_error), np.mean(val_pred_error)
102 | print(f'Overfitting? / P-value: {overfitting_check} / {p_value:.4f}')
103 | print(f'Tr error, Val error: {tr_err:.3f}, {val_err:.3f}')
104 | return tr_err, val_err
105 |
106 | def create_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model):
107 | print('-'*30)
108 | print('Train a model')
109 | start_time=time()
110 | if ML_model=='linear':
111 | model_to_explain=create_linear_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model)
112 | elif ML_model=='boosting':
113 | model_to_explain=create_boosting_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model)
114 | elif ML_model=='MLP':
115 | model_to_explain=create_MLP_model_to_explain(X_train, y_train, X_val, y_val, problem, ML_model)
116 | else:
117 | raise ValueError(f'Check ML_model: {ML_model}')
118 |
119 | elapsed_time_train=time()-start_time
120 | print(f'Elapsed time for training a model to explain: {elapsed_time_train:.2f} seconds')
121 | print('-'*30)
122 | return model_to_explain # , tr_err, val_err
123 |
124 |
125 | def create_surrogate_model(model_to_explain, X_train, X_est, problem='classification', ML_model='linear', verbose=False):
126 | start_time=time()
127 | # [Step 1] Create surrogate model
128 | device = torch.device('cpu')
129 | num_features=X_train.shape[1]
130 | n_output=2 if problem=='classification' else 1
131 | surrogate_model = nn.Sequential(
132 | MaskLayer1d(value=0, append=True),
133 | nn.Linear(2 * num_features, 128),
134 | nn.ELU(inplace=True),
135 | nn.Linear(128, 128),
136 | nn.ELU(inplace=True),
137 | nn.Linear(128, n_output)).to(device)
138 |
139 | # Set up surrogate object
140 | surrogate_object = surrogate.Surrogate(surrogate_model, num_features)
141 |
142 | if problem=='classification':
143 | loss_fn=KLDivLoss()
144 | # predict_proba and predict
145 | def original_model(x):
146 | if ML_model == 'linear':
147 | pred = (model_to_explain.predict_proba(x.cpu().numpy())[:,1]).reshape(-1)
148 | else:
149 | pred = model_to_explain.predict(x.cpu().numpy())
150 |
151 | pred = np.stack([1 - pred, pred]).T
152 | return torch.tensor(pred, dtype=torch.float32, device=x.device)
153 | else:
154 | loss_fn=MSELoss()
155 | def original_model(x):
156 | pred = model_to_explain.predict(x.cpu().numpy()).reshape(-1, 1)
157 | return torch.tensor(pred, dtype=torch.float32, device=x.device)
158 |
159 | # Train
160 | surrogate_object.train_original_model(X_train,
161 | X_est,
162 | original_model,
163 | batch_size=64,
164 | max_epochs=100,
165 | loss_fn=loss_fn,
166 | validation_samples=128,
167 | validation_batch_size=(2**12),
168 | lookback=10,
169 | verbose=verbose)
170 | elapsed_time=time()-start_time
171 | print(f'Elapsed time for training a surrogate model: {elapsed_time:.2f} seconds')
172 | return surrogate_model
173 |
174 |
175 | def generate_coalition_function(model_to_explain, X_train, X_est,
176 | problem='classification', ML_model='linear', verbose=False):
177 | surrogate_model=create_surrogate_model(model_to_explain, X_train, X_est, problem, ML_model, verbose)
178 |
179 | device=torch.device('cpu')
180 | if problem == 'classification':
181 | model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),
182 | torch.tensor(S, dtype=torch.float32, device=device))).softmax(dim=-1).cpu().data.numpy().reshape(x.shape[0],-1)[:,1]
183 | elif problem == 'regression':
184 | model_condi_wrapper = lambda x, S: surrogate_model((torch.tensor(x, dtype=torch.float32, device=device),
185 | torch.tensor(S, dtype=torch.float32, device=device))).cpu().data.numpy().reshape(x.shape[0],-1)[:,0]
186 | else:
187 | raise ValueError(f'Check problem: {problem}')
188 | conditional_extension=removal.ConditionalSupervisedExtension(model_condi_wrapper)
189 |
190 | return conditional_extension
191 |
192 |
--------------------------------------------------------------------------------
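
A short sketch of create_model_to_explain with the linear option on synthetic labels; importing from the package also pulls in torch and the other dependencies listed at the top of train.py.

    import numpy as np
    from weightedSHAP.train import create_model_to_explain

    np.random.seed(1)
    X_train, X_val = np.random.normal(size=(200, 5)), np.random.normal(size=(50, 5))
    y_train, y_val = (X_train[:, 0] > 0).astype(int), (X_val[:, 0] > 0).astype(int)

    model = create_model_to_explain(X_train, y_train, X_val, y_val,
                                    problem='classification', ML_model='linear')
    print(model.predict_proba(X_val[:2]))   # class probabilities from the LogisticRegression model
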
/weightedSHAP/third_party/surrogate.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference: https://github.com/iancovert/fastshap/blob/main/fastshap/surrogate.py
3 | '''
4 | import torch
5 | import torch.optim as optim
6 | import numpy as np
7 | from torch.utils.data import Dataset, TensorDataset, DataLoader
8 | from torch.utils.data import RandomSampler, BatchSampler
9 | from copy import deepcopy
10 | from tqdm.auto import tqdm
11 | from .utils import UniformSampler, DatasetRepeat
12 |
13 | def validate(surrogate, loss_fn, data_loader):
14 | '''
15 | Calculate mean validation loss.
16 | Args:
17 | loss_fn: loss function.
18 | data_loader: data loader.
19 | '''
20 | with torch.no_grad():
21 | # Setup.
22 | device = next(surrogate.surrogate.parameters()).device
23 | mean_loss = 0
24 | N = 0
25 |
26 | for x, y, S in data_loader:
27 | x = x.to(device)
28 | y = y.to(device)
29 | S = S.to(device)
30 | try:
31 | pred = surrogate(x, S)
32 | except:
33 | pred = surrogate(x, S.long())
34 | loss = loss_fn(pred, y)
35 | N += len(x)
36 | mean_loss += len(x) * (loss - mean_loss) / N
37 |
38 | return mean_loss
39 |
40 |
41 | def generate_labels(dataset, model, batch_size):
42 | '''
43 | Generate prediction labels for a set of inputs.
44 | Args:
45 | dataset: dataset object.
46 | model: predictive model.
47 | batch_size: minibatch size.
48 | '''
49 | with torch.no_grad():
50 | # Setup.
51 | preds = []
52 | if isinstance(model, torch.nn.Module):
53 | device = next(model.parameters()).device
54 | else:
55 | device = torch.device('cpu')
56 | loader = DataLoader(dataset, batch_size=batch_size)
57 |
58 | for (x,) in loader:
59 | pred = model(x.to(device)).cpu()
60 | preds.append(pred)
61 |
62 | return torch.cat(preds)
63 |
64 |
65 | class Surrogate:
66 | '''
67 | Wrapper around surrogate model.
68 | Args:
69 | surrogate: surrogate model.
70 | num_features: number of features.
71 | groups: (optional) feature groups, represented by a list of lists.
72 | '''
73 |
74 | def __init__(self, surrogate, num_features, groups=None):
75 | # Store surrogate model.
76 | self.surrogate = surrogate
77 |
78 | # Store feature groups.
79 | if groups is None:
80 | self.num_players = num_features
81 | self.groups_matrix = None
82 | else:
83 | # Verify groups.
84 | inds_list = []
85 | for group in groups:
86 | inds_list += list(group)
87 | assert np.all(np.sort(inds_list) == np.arange(num_features))
88 |
89 | # Map groups to features.
90 | self.num_players = len(groups)
91 | device = next(surrogate.parameters()).device
92 | self.groups_matrix = torch.zeros(
93 | len(groups), num_features, dtype=torch.float32, device=device)
94 | for i, group in enumerate(groups):
95 | self.groups_matrix[i, group] = 1
96 |
97 | def train(self,
98 | train_data,
99 | val_data,
100 | batch_size,
101 | max_epochs,
102 | loss_fn,
103 | validation_samples=1,
104 | validation_batch_size=None,
105 | lr=1e-3,
106 | min_lr=1e-5,
107 | lr_factor=0.5,
108 | lookback=5,
109 | training_seed=None,
110 | validation_seed=None,
111 | bar=False,
112 | verbose=False):
113 | '''
114 | Train surrogate model.
115 | Args:
116 | train_data: training data with inputs and the original model's
117 | predictions (np.ndarray tuple, torch.Tensor tuple,
118 | torch.utils.data.Dataset).
119 | val_data: validation data with inputs and the original model's
120 | predictions (np.ndarray tuple, torch.Tensor tuple,
121 | torch.utils.data.Dataset).
122 | batch_size: minibatch size.
123 | max_epochs: maximum training epochs.
124 | loss_fn: loss function (e.g., fastshap.KLDivLoss).
125 | validation_samples: number of samples per validation example.
126 | validation_batch_size: validation minibatch size.
127 | lr: initial learning rate.
128 | min_lr: minimum learning rate.
129 | lr_factor: learning rate decrease factor.
130 | lookback: lookback window for early stopping.
131 | training_seed: random seed for training.
132 | validation_seed: random seed for generating validation data.
133 | verbose: verbosity.
134 | '''
135 | # Set up train dataset.
136 | if isinstance(train_data, tuple):
137 | x_train, y_train = train_data
138 | if isinstance(x_train, np.ndarray):
139 | x_train = torch.tensor(x_train, dtype=torch.float32)
140 | y_train = torch.tensor(y_train, dtype=torch.float32)
141 | train_set = TensorDataset(x_train, y_train)
142 | elif isinstance(train_data, Dataset):
143 | train_set = train_data
144 | else:
145 | raise ValueError('train_data must be either tuple of tensors or a '
146 | 'PyTorch Dataset')
147 |
148 | # Set up train data loader.
149 | random_sampler = RandomSampler(
150 | train_set, replacement=True,
151 | num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size)
152 | batch_sampler = BatchSampler(
153 | random_sampler, batch_size=batch_size, drop_last=True)
154 | train_loader = DataLoader(train_set, batch_sampler=batch_sampler)
155 |
156 | # Set up validation dataset.
157 | sampler = UniformSampler(self.num_players)
158 | if validation_seed is not None:
159 | torch.manual_seed(validation_seed)
160 |         S_val = sampler.sample((len(val_data[0]) if isinstance(val_data, tuple) else len(val_data)) * validation_samples)
161 |
162 | if isinstance(val_data, tuple):
163 | x_val, y_val = val_data
164 | if isinstance(x_val, np.ndarray):
165 | x_val = torch.tensor(x_val, dtype=torch.float32)
166 | y_val = torch.tensor(y_val, dtype=torch.float32)
167 | x_val_repeat = x_val.repeat(validation_samples, 1)
168 | y_val_repeat = y_val.repeat(validation_samples, 1)
169 | val_set = TensorDataset(
170 | x_val_repeat, y_val_repeat, S_val)
171 | elif isinstance(val_data, Dataset):
172 | val_set = DatasetRepeat([val_data, TensorDataset(S_val)])
173 | else:
174 | raise ValueError('val_data must be either tuple of tensors or a '
175 | 'PyTorch Dataset')
176 |
177 | if validation_batch_size is None:
178 | validation_batch_size = batch_size
179 | val_loader = DataLoader(val_set, batch_size=validation_batch_size)
180 |
181 | # Setup for training.
182 | surrogate = self.surrogate
183 | device = next(surrogate.parameters()).device
184 | optimizer = optim.Adam(surrogate.parameters(), lr=lr)
185 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
186 | optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr,
187 | verbose=verbose)
188 | best_loss = validate(self, loss_fn, val_loader).item()
189 | best_epoch = 0
190 | best_model = deepcopy(surrogate)
191 | loss_list = [best_loss]
192 | if training_seed is not None:
193 | torch.manual_seed(training_seed)
194 |
195 | for epoch in range(max_epochs):
196 | # Batch iterable.
197 | if bar:
198 | batch_iter = tqdm(train_loader, desc='Training epoch')
199 | else:
200 | batch_iter = train_loader
201 |
202 | for x, y in batch_iter:
203 | # Prepare data.
204 | x = x.to(device)
205 | y = y.to(device)
206 |
207 | # Generate subsets.
208 | S = sampler.sample(batch_size).to(device=device)
209 |
210 | # Make predictions.
211 | pred = self.__call__(x, S)
212 | loss = loss_fn(pred, y)
213 |
214 | # Optimizer step.
215 | loss.backward()
216 | optimizer.step()
217 | surrogate.zero_grad()
218 |
219 | # Evaluate validation loss.
220 | self.surrogate.eval()
221 | val_loss = validate(self, loss_fn, val_loader).item()
222 | self.surrogate.train()
223 |
224 | # Print progress.
225 | if verbose:
226 | print('----- Epoch = {} -----'.format(epoch + 1))
227 | print('Val loss = {:.4f}'.format(val_loss))
228 | print('')
229 | scheduler.step(val_loss)
230 | loss_list.append(val_loss)
231 |
232 | # Check if best model.
233 | if val_loss < best_loss:
234 | best_loss = val_loss
235 | best_model = deepcopy(surrogate)
236 | best_epoch = epoch
237 | if verbose:
238 | print('New best epoch, loss = {:.4f}'.format(val_loss))
239 | print('')
240 | elif epoch - best_epoch == lookback:
241 | if verbose:
242 | print('Stopping early')
243 | break
244 |
245 | # Clean up.
246 | for param, best_param in zip(surrogate.parameters(),
247 | best_model.parameters()):
248 | param.data = best_param.data
249 | self.loss_list = loss_list
250 | self.surrogate.eval()
251 |
252 | def train_original_model(self,
253 | train_data,
254 | val_data,
255 | original_model,
256 | batch_size,
257 | max_epochs,
258 | loss_fn,
259 | validation_samples=1,
260 | validation_batch_size=None,
261 | lr=1e-3,
262 | min_lr=1e-5,
263 | lr_factor=0.5,
264 | lookback=5,
265 | training_seed=None,
266 | validation_seed=None,
267 | bar=False,
268 | verbose=False):
269 | '''
270 | Train surrogate model with labels provided by the original model.
271 | Args:
272 | train_data: training data with inputs only (np.ndarray, torch.Tensor,
273 | torch.utils.data.Dataset).
274 | val_data: validation data with inputs only (np.ndarray, torch.Tensor,
275 | torch.utils.data.Dataset).
276 | original_model: original predictive model (e.g., torch.nn.Module).
277 | batch_size: minibatch size.
278 | max_epochs: maximum training epochs.
279 | loss_fn: loss function (e.g., fastshap.KLDivLoss).
280 | validation_samples: number of samples per validation example.
281 | validation_batch_size: validation minibatch size.
282 | lr: initial learning rate.
283 | min_lr: minimum learning rate.
284 | lr_factor: learning rate decrease factor.
285 | lookback: lookback window for early stopping.
286 | training_seed: random seed for training.
287 | validation_seed: random seed for generating validation data.
288 | verbose: verbosity.
289 | '''
290 | # Set up train dataset.
291 | if isinstance(train_data, np.ndarray):
292 | train_data = torch.tensor(train_data, dtype=torch.float32)
293 |
294 | if isinstance(train_data, torch.Tensor):
295 | train_set = TensorDataset(train_data)
296 | elif isinstance(train_data, Dataset):
297 | train_set = train_data
298 | else:
299 | raise ValueError('train_data must be either tensor or a '
300 | 'PyTorch Dataset')
301 |
302 | # Set up train data loader.
303 | random_sampler = RandomSampler(
304 | train_set, replacement=True,
305 | num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size)
306 | batch_sampler = BatchSampler(
307 | random_sampler, batch_size=batch_size, drop_last=True)
308 | train_loader = DataLoader(train_set, batch_sampler=batch_sampler)
309 |
310 | # Set up validation dataset.
311 | sampler = UniformSampler(self.num_players)
312 | if validation_seed is not None:
313 | torch.manual_seed(validation_seed)
314 | S_val = sampler.sample(len(val_data) * validation_samples)
315 | if validation_batch_size is None:
316 | validation_batch_size = batch_size
317 |
318 | if isinstance(val_data, np.ndarray):
319 | val_data = torch.tensor(val_data, dtype=torch.float32)
320 |
321 | if isinstance(val_data, torch.Tensor):
322 | # Generate validation labels.
323 | y_val = generate_labels(TensorDataset(val_data), original_model,
324 | validation_batch_size)
325 | y_val_repeat = y_val.repeat(
326 | validation_samples, *[1 for _ in y_val.shape[1:]])
327 |
328 | # Create dataset.
329 | val_data_repeat = val_data.repeat(validation_samples, 1)
330 | val_set = TensorDataset(val_data_repeat, y_val_repeat, S_val)
331 | elif isinstance(val_data, Dataset):
332 | # Generate validation labels.
333 | y_val = generate_labels(val_data, original_model,
334 | validation_batch_size)
335 | y_val_repeat = y_val.repeat(
336 | validation_samples, *[1 for _ in y_val.shape[1:]])
337 |
338 | # Create dataset.
339 | val_set = DatasetRepeat(
340 | [val_data, TensorDataset(y_val_repeat, S_val)])
341 | else:
342 |             raise ValueError('val_data must be either tensor or a '
343 | 'PyTorch Dataset')
344 |
345 | val_loader = DataLoader(val_set, batch_size=validation_batch_size)
346 |
347 | # Setup for training.
348 | surrogate = self.surrogate
349 | device = next(surrogate.parameters()).device
350 | optimizer = optim.Adam(surrogate.parameters(), lr=lr)
351 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
352 | optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr,
353 | verbose=verbose)
354 | best_loss = validate(self, loss_fn, val_loader).item()
355 | best_epoch = 0
356 | best_model = deepcopy(surrogate)
357 | loss_list = [best_loss]
358 | if training_seed is not None:
359 | torch.manual_seed(training_seed)
360 |
361 | for epoch in range(max_epochs):
362 | # Batch iterable.
363 | if bar:
364 | batch_iter = tqdm(train_loader, desc='Training epoch')
365 | else:
366 | batch_iter = train_loader
367 |
368 | for (x,) in batch_iter:
369 | # Prepare data.
370 | x = x.to(device)
371 |
372 | # Get original model prediction.
373 | with torch.no_grad():
374 | y = original_model(x)
375 |
376 | # Generate subsets.
377 | S = sampler.sample(batch_size).to(device=device)
378 |
379 | # Make predictions.
380 | pred = self.__call__(x, S)
381 | loss = loss_fn(pred, y)
382 |
383 | # Optimizer step.
384 | loss.backward()
385 | optimizer.step()
386 | surrogate.zero_grad()
387 |
388 | # Evaluate validation loss.
389 | self.surrogate.eval()
390 | val_loss = validate(self, loss_fn, val_loader).item()
391 | self.surrogate.train()
392 |
393 | # Print progress.
394 | if verbose:
395 | print('----- Epoch = {} -----'.format(epoch + 1))
396 | print('Val loss = {:.4f}'.format(val_loss))
397 | print('')
398 | scheduler.step(val_loss)
399 | loss_list.append(val_loss)
400 |
401 | # Check if best model.
402 | if val_loss < best_loss:
403 | best_loss = val_loss
404 | best_model = deepcopy(surrogate)
405 | best_epoch = epoch
406 | if verbose:
407 | print('New best epoch, loss = {:.4f}'.format(val_loss))
408 | print('')
409 | elif epoch - best_epoch == lookback:
410 | if verbose:
411 | print('Stopping early')
412 | break
413 |
414 | # Clean up.
415 | for param, best_param in zip(surrogate.parameters(),
416 | best_model.parameters()):
417 | param.data = best_param.data
418 | self.loss_list = loss_list
419 | self.surrogate.eval()
420 |
421 | def __call__(self, x, S):
422 | '''
423 | Evaluate surrogate model.
424 | Args:
425 | x: input examples.
426 | S: coalitions.
427 | '''
428 | if self.groups_matrix is not None:
429 | S = torch.mm(S, self.groups_matrix)
430 | try:
431 |             surr_value = self.surrogate((x, S))
432 |         except Exception:  # some surrogates expect an integer-typed coalition mask
433 |             surr_value = self.surrogate((x, S.long()))
434 | return surr_value
435 |
436 |
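437 | # ---------------------------------------------------------------------------
438 | # Editorial usage sketch (an assumption, not part of the original file): one
439 | # plausible way to train the Surrogate wrapper above on tabular data. The toy
440 | # _ConcatMaskSurrogate network, the synthetic arrays, and the MSE loss are
441 | # illustrative choices only.
442 | # ---------------------------------------------------------------------------
443 | if __name__ == '__main__':
444 |     import torch.nn as nn
445 |
446 |     class _ConcatMaskSurrogate(nn.Module):
447 |         '''Toy surrogate: consumes an (x, S) tuple by masking x and appending S.'''
448 |
449 |         def __init__(self, num_features, num_outputs):
450 |             super().__init__()
451 |             self.net = nn.Sequential(nn.Linear(2 * num_features, 64),
452 |                                      nn.ReLU(),
453 |                                      nn.Linear(64, num_outputs))
454 |
455 |         def forward(self, inputs):
456 |             x, S = inputs
457 |             return self.net(torch.cat([x * S, S], dim=1))
458 |
459 |     num_features, num_outputs = 10, 1
460 |     x = np.random.randn(300, num_features).astype(np.float32)
461 |     y = np.random.randn(300, num_outputs).astype(np.float32)  # stands in for the original model's predictions
462 |
463 |     surrogate = Surrogate(_ConcatMaskSurrogate(num_features, num_outputs), num_features)
464 |     val_set = TensorDataset(torch.tensor(x[200:]), torch.tensor(y[200:]))
465 |     surrogate.train(train_data=(x[:200], y[:200]), val_data=val_set,
466 |                     batch_size=32, max_epochs=3, loss_fn=nn.MSELoss(), verbose=True)
467 |
468 |     # Query the trained surrogate on a random coalition of features.
469 |     S = UniformSampler(num_features).sample(4)
470 |     print(surrogate(torch.tensor(x[:4]), S))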
--------------------------------------------------------------------------------
/weightedSHAP/third_party/image_surrogate.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference: https://github.com/iancovert/fastshap/blob/main/fastshap/image_surrogate.py
3 | '''
4 | import torch
5 | import torch.optim as optim
6 | import numpy as np
7 | from torch.utils.data import Dataset, TensorDataset, DataLoader
8 | from torch.utils.data import RandomSampler, BatchSampler
9 | from .image_imputers import ImageImputer
10 | from .utils import UniformSampler, DatasetRepeat
11 | from tqdm.auto import tqdm
12 | from copy import deepcopy
13 |
14 | def validate(surrogate, loss_fn, data_loader):
15 | '''
16 | Calculate mean validation loss.
17 | Args:
18 | loss_fn: loss function.
19 | data_loader: data loader.
20 | '''
21 | with torch.no_grad():
22 | # Setup.
23 | device = next(surrogate.surrogate.parameters()).device
24 | mean_loss = 0
25 | N = 0
26 |
27 | for x, y, S in data_loader:
28 | x = x.to(device)
29 | y = y.to(device)
30 | S = S.to(device)
31 | pred = surrogate(x, S)
32 | loss = loss_fn(pred, y)
33 | N += len(x)
34 |             mean_loss += len(x) * (loss - mean_loss) / N  # running batch-size-weighted mean
35 |
36 | return mean_loss
37 |
38 |
39 | def generate_labels(dataset, model, batch_size, num_workers):
40 | '''
41 | Generate prediction labels for a set of inputs.
42 | Args:
43 | dataset: dataset object.
44 | model: predictive model.
45 | batch_size: minibatch size.
46 | num_workers: number of worker threads.
47 | '''
48 | with torch.no_grad():
49 | # Setup.
50 | preds = []
51 | device = next(model.parameters()).device
52 | loader = DataLoader(dataset, batch_size=batch_size, pin_memory=True,
53 | num_workers=num_workers)
54 |
55 | for (x,) in loader:
56 | pred = model(x.to(device)).cpu()
57 | preds.append(pred)
58 |
59 | return torch.cat(preds)
60 |
61 |
62 | class ImageSurrogate(ImageImputer):
63 | '''
64 | Wrapper around image surrogate model.
65 | Args:
66 | surrogate: surrogate model (torch.nn.Module).
67 | width: image width.
68 | height: image height.
69 | superpixel_size: superpixel width/height (int).
70 | '''
71 |
72 | def __init__(self, surrogate, width, height, superpixel_size):
73 | # Initialize for coalition resizing, number of players.
74 | super().__init__(width, height, superpixel_size)
75 |
76 | # Store surrogate model.
77 | self.surrogate = surrogate
78 |
79 | def train(self,
80 | train_data,
81 | val_data,
82 | batch_size,
83 | max_epochs,
84 | loss_fn,
85 | validation_samples=1,
86 | validation_batch_size=None,
87 | lr=1e-3,
88 | min_lr=1e-5,
89 | lr_factor=0.5,
90 | lookback=5,
91 | training_seed=None,
92 | validation_seed=None,
93 | num_workers=0,
94 | bar=False,
95 | verbose=False):
96 | '''
97 | Train surrogate model.
98 | Args:
99 | train_data: training data with inputs and the original model's
100 | predictions (np.ndarray tuple, torch.Tensor tuple,
101 | torch.utils.data.Dataset).
102 | val_data: validation data with inputs and the original model's
103 | predictions (np.ndarray tuple, torch.Tensor tuple,
104 | torch.utils.data.Dataset).
105 | batch_size: minibatch size.
106 | max_epochs: max number of training epochs.
107 |         loss_fn: loss function (e.g., fastshap.KLDivLoss).
108 | validation_samples: number of samples per validation example.
109 | validation_batch_size: validation minibatch size.
110 | lr: initial learning rate.
111 | min_lr: minimum learning rate.
112 | lr_factor: learning rate decrease factor.
113 | lookback: lookback window for early stopping.
114 | training_seed: random seed for training.
115 | validation_seed: random seed for generating validation data.
116 | num_workers: number of worker threads in data loader.
117 | bar: whether to show progress bar.
118 | verbose: verbosity.
119 | '''
120 | # Set up train dataset.
121 | if isinstance(train_data, tuple):
122 | x_train, y_train = train_data
123 | if isinstance(x_train, np.ndarray):
124 | x_train = torch.tensor(x_train, dtype=torch.float32)
125 | y_train = torch.tensor(y_train, dtype=torch.float32)
126 | train_set = TensorDataset(x_train, y_train)
127 | elif isinstance(train_data, Dataset):
128 | train_set = train_data
129 | else:
130 | raise ValueError('train_data must be either tuple of tensors or a '
131 | 'PyTorch Dataset')
132 |
133 | # Set up train data loader.
134 | random_sampler = RandomSampler(
135 | train_set, replacement=True,
136 | num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size)
137 | batch_sampler = BatchSampler(
138 | random_sampler, batch_size=batch_size, drop_last=True)
139 | train_loader = DataLoader(train_set, batch_sampler=batch_sampler,
140 | pin_memory=True, num_workers=num_workers)
141 |
142 | # Set up validation dataset.
143 | sampler = UniformSampler(self.num_players)
144 | if validation_seed is not None:
145 | torch.manual_seed(validation_seed)
146 |         S_val = sampler.sample((len(val_data[0]) if isinstance(val_data, tuple) else len(val_data)) * validation_samples)
147 |
148 | if isinstance(val_data, tuple):
149 | x_val, y_val = val_data
150 | if isinstance(x_val, np.ndarray):
151 | x_val = torch.tensor(x_val, dtype=torch.float32)
152 | y_val = torch.tensor(y_val, dtype=torch.float32)
153 | x_val_repeat = x_val.repeat(validation_samples, 1, 1, 1)
154 | y_val_repeat = y_val.repeat(validation_samples, 1)
155 | val_set = TensorDataset(x_val_repeat, y_val_repeat, S_val)
156 | elif isinstance(val_data, Dataset):
157 | val_set = DatasetRepeat([val_data, TensorDataset(S_val)])
158 | else:
159 | raise ValueError('val_data must be either tuple of tensors or a '
160 | 'PyTorch Dataset')
161 |
162 | if validation_batch_size is None:
163 | validation_batch_size = batch_size
164 | val_loader = DataLoader(val_set, batch_size=validation_batch_size,
165 | pin_memory=True, num_workers=num_workers)
166 |
167 | # Setup for training.
168 | surrogate = self.surrogate
169 | device = next(surrogate.parameters()).device
170 | optimizer = optim.Adam(surrogate.parameters(), lr=lr)
171 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
172 | optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr,
173 | verbose=verbose)
174 | best_loss = validate(self, loss_fn, val_loader).item()
175 | best_epoch = 0
176 | best_model = deepcopy(surrogate)
177 | loss_list = [best_loss]
178 | if training_seed is not None:
179 | torch.manual_seed(training_seed)
180 |
181 | for epoch in range(max_epochs):
182 | # Batch iterable.
183 | if bar:
184 | batch_iter = tqdm(train_loader, desc='Training epoch')
185 | else:
186 | batch_iter = train_loader
187 |
188 | for x, y in batch_iter:
189 | # Prepare data.
190 | x = x.to(device)
191 | y = y.to(device)
192 |
193 | # Generate subsets.
194 | S = sampler.sample(batch_size).to(device=device)
195 |
196 | # Make predictions.
197 | pred = self.__call__(x, S)
198 | loss = loss_fn(pred, y)
199 |
200 | # Optimizer step.
201 | loss.backward()
202 | optimizer.step()
203 | surrogate.zero_grad()
204 |
205 | # Evaluate validation loss.
206 | self.surrogate.eval()
207 | val_loss = validate(self, loss_fn, val_loader).item()
208 | self.surrogate.train()
209 |
210 | # Print progress.
211 | if verbose:
212 | print('----- Epoch = {} -----'.format(epoch + 1))
213 | print('Val loss = {:.4f}'.format(val_loss))
214 | print('')
215 | scheduler.step(val_loss)
216 | loss_list.append(val_loss)
217 |
218 | # Check if best model.
219 | if val_loss < best_loss:
220 | best_loss = val_loss
221 | best_model = deepcopy(surrogate)
222 | best_epoch = epoch
223 | if verbose:
224 | print('New best epoch, loss = {:.4f}'.format(val_loss))
225 | print('')
226 | elif epoch - best_epoch == lookback:
227 | if verbose:
228 | print('Stopping early')
229 | break
230 |
231 | # Clean up.
232 | for param, best_param in zip(surrogate.parameters(),
233 | best_model.parameters()):
234 | param.data = best_param.data
235 | self.loss_list = loss_list
236 | self.surrogate.eval()
237 |
238 | def train_original_model(self,
239 | train_data,
240 | val_data,
241 | original_model,
242 | batch_size,
243 | max_epochs,
244 | loss_fn,
245 | validation_samples=1,
246 | validation_batch_size=None,
247 | lr=1e-3,
248 | min_lr=1e-5,
249 | lr_factor=0.5,
250 | lookback=5,
251 | training_seed=None,
252 | validation_seed=None,
253 | num_workers=0,
254 | bar=False,
255 | verbose=False):
256 | '''
257 | Train surrogate model with labels provided by the original model. This
258 | approach is designed for when data augmentations make the data loader
259 | non-deterministic, and labels (the original model's predictions) cannot
260 | be generated prior to training.
261 | Args:
262 | train_data: training data with inputs only (np.ndarray, torch.Tensor,
263 | torch.utils.data.Dataset).
264 | val_data: validation data with inputs only (np.ndarray, torch.Tensor,
265 | torch.utils.data.Dataset).
266 | original_model: original predictive model (e.g., torch.nn.Module).
267 | batch_size: minibatch size.
268 | max_epochs: max number of training epochs.
269 |         loss_fn: loss function (e.g., fastshap.KLDivLoss).
270 | validation_samples: number of samples per validation example.
271 | validation_batch_size: validation minibatch size.
272 | lr: initial learning rate.
273 | min_lr: minimum learning rate.
274 | lr_factor: learning rate decrease factor.
275 | lookback: lookback window for early stopping.
276 | training_seed: random seed for training.
277 | validation_seed: random seed for generating validation data.
278 | num_workers: number of worker threads in data loader.
279 | bar: whether to show progress bar.
280 | verbose: verbosity.
281 | '''
282 | # Set up train dataset.
283 | if isinstance(train_data, np.ndarray):
284 | train_data = torch.tensor(train_data, dtype=torch.float32)
285 |
286 | if isinstance(train_data, torch.Tensor):
287 | train_set = TensorDataset(train_data)
288 | elif isinstance(train_data, Dataset):
289 | train_set = train_data
290 | else:
291 | raise ValueError('train_data must be either tensor or a '
292 | 'PyTorch Dataset')
293 |
294 | # Set up train data loader.
295 | random_sampler = RandomSampler(
296 | train_set, replacement=True,
297 | num_samples=int(np.ceil(len(train_set) / batch_size))*batch_size)
298 | batch_sampler = BatchSampler(
299 | random_sampler, batch_size=batch_size, drop_last=True)
300 | train_loader = DataLoader(train_set, batch_sampler=batch_sampler,
301 | pin_memory=True, num_workers=num_workers)
302 |
303 | # Set up validation dataset.
304 | sampler = UniformSampler(self.num_players)
305 | if validation_seed is not None:
306 | torch.manual_seed(validation_seed)
307 | S_val = sampler.sample(len(val_data) * validation_samples)
308 | if validation_batch_size is None:
309 | validation_batch_size = batch_size
310 |
311 | if isinstance(val_data, np.ndarray):
312 | val_data = torch.tensor(val_data, dtype=torch.float32)
313 |
314 | if isinstance(val_data, torch.Tensor):
315 | # Generate validation labels.
316 | y_val = generate_labels(TensorDataset(val_data), original_model,
317 | validation_batch_size, num_workers)
318 | y_val_repeat = y_val.repeat(
319 | validation_samples, *[1 for _ in y_val.shape[1:]])
320 |
321 | # Create dataset.
322 | val_data_repeat = val_data.repeat(validation_samples, 1, 1, 1)
323 | val_set = TensorDataset(val_data_repeat, y_val_repeat, S_val)
324 | elif isinstance(val_data, Dataset):
325 | # Generate validation labels.
326 | y_val = generate_labels(val_data, original_model,
327 | validation_batch_size, num_workers)
328 | y_val_repeat = y_val.repeat(
329 | validation_samples, *[1 for _ in y_val.shape[1:]])
330 |
331 | # Create dataset.
332 | val_set = DatasetRepeat(
333 | [val_data, TensorDataset(y_val_repeat, S_val)])
334 | else:
335 |             raise ValueError('val_data must be either tensor or a '
336 | 'PyTorch Dataset')
337 |
338 | val_loader = DataLoader(val_set, batch_size=validation_batch_size,
339 | pin_memory=True, num_workers=num_workers)
340 |
341 | # Setup for training.
342 | surrogate = self.surrogate
343 | device = next(surrogate.parameters()).device
344 | optimizer = optim.Adam(surrogate.parameters(), lr=lr)
345 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
346 | optimizer, factor=lr_factor, patience=lookback // 2, min_lr=min_lr,
347 | verbose=verbose)
348 | best_loss = validate(self, loss_fn, val_loader).item()
349 | best_epoch = 0
350 | best_model = deepcopy(surrogate)
351 | loss_list = [best_loss]
352 | if training_seed is not None:
353 | torch.manual_seed(training_seed)
354 |
355 | for epoch in range(max_epochs):
356 | # Batch iterable.
357 | if bar:
358 | batch_iter = tqdm(train_loader, desc='Training epoch')
359 | else:
360 | batch_iter = train_loader
361 |
362 | for (x,) in batch_iter:
363 | # Prepare data.
364 | x = x.to(device)
365 |
366 | # Get original model prediction.
367 | with torch.no_grad():
368 | y = original_model(x)
369 |
370 | # Generate subsets.
371 | S = sampler.sample(batch_size).to(device=device)
372 |
373 | # Make predictions.
374 | pred = self.__call__(x, S)
375 | loss = loss_fn(pred, y)
376 |
377 | # Optimizer step.
378 | loss.backward()
379 | optimizer.step()
380 | surrogate.zero_grad()
381 |
382 | # Evaluate validation loss.
383 | self.surrogate.eval()
384 | val_loss = validate(self, loss_fn, val_loader).item()
385 | self.surrogate.train()
386 |
387 | # Print progress.
388 | if verbose:
389 | print('----- Epoch = {} -----'.format(epoch + 1))
390 | print('Val loss = {:.4f}'.format(val_loss))
391 | print('')
392 | scheduler.step(val_loss)
393 | loss_list.append(val_loss)
394 |
395 | # Check if best model.
396 | if val_loss < best_loss:
397 | best_loss = val_loss
398 | best_model = deepcopy(surrogate)
399 | best_epoch = epoch
400 | if verbose:
401 | print('New best epoch, loss = {:.4f}'.format(val_loss))
402 | print('')
403 | elif epoch - best_epoch == lookback:
404 | if verbose:
405 | print('Stopping early')
406 | break
407 |
408 | # Clean up.
409 | for param, best_param in zip(surrogate.parameters(),
410 | best_model.parameters()):
411 | param.data = best_param.data
412 | self.loss_list = loss_list
413 | self.surrogate.eval()
414 |
415 | def __call__(self, x, S):
416 | '''
417 | Evaluate surrogate model.
418 | Args:
419 | x: input examples.
420 | S: coalitions.
421 | '''
422 | S = self.resize(S)
423 | return self.surrogate((x, S))
424 |
425 |
426 |
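427 | # ---------------------------------------------------------------------------
428 | # Editorial usage sketch (an assumption, not part of the original file): one
429 | # plausible way to train the ImageSurrogate wrapper above. The toy
430 | # _MaskedImageSurrogate CNN, the random image tensors, and the MSE loss are
431 | # illustrative choices only; the resized coalition mask is assumed to
432 | # broadcast against the image channels.
433 | # ---------------------------------------------------------------------------
434 | if __name__ == '__main__':
435 |     import torch.nn as nn
436 |
437 |     class _MaskedImageSurrogate(nn.Module):
438 |         '''Toy surrogate: masks the image with S and appends S as a channel.'''
439 |
440 |         def __init__(self, in_channels, num_outputs):
441 |             super().__init__()
442 |             self.net = nn.Sequential(nn.Conv2d(in_channels + 1, 16, 3, padding=1),
443 |                                      nn.ReLU(),
444 |                                      nn.AdaptiveAvgPool2d(1),
445 |                                      nn.Flatten(),
446 |                                      nn.Linear(16, num_outputs))
447 |
448 |         def forward(self, inputs):
449 |             x, S = inputs
450 |             if S.dim() == 3:      # (B, H, W) -> (B, 1, H, W)
451 |                 S = S.unsqueeze(1)
452 |             return self.net(torch.cat([x * S, S], dim=1))
453 |
454 |     width = height = 32
455 |     x = torch.randn(120, 3, height, width)
456 |     y = torch.softmax(torch.randn(120, 10), dim=1)  # stands in for the original model's class probabilities
457 |
458 |     surrogate = ImageSurrogate(_MaskedImageSurrogate(3, 10), width, height, superpixel_size=4)
459 |     val_set = TensorDataset(x[100:], y[100:])
460 |     surrogate.train(train_data=(x[:100], y[:100]), val_data=val_set,
461 |                     batch_size=20, max_epochs=2, loss_fn=nn.MSELoss(), verbose=True)
462 |
463 |     # Query the trained surrogate on a random coalition of superpixels.
464 |     S = UniformSampler(surrogate.num_players).sample(4)
465 |     print(surrogate(x[:4], S))
466 |
467 |     # With an augmented torch Dataset (labels not precomputable ahead of time), one
468 |     # would call surrogate.train_original_model(train_set, val_set, original_model, ...) instead.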
--------------------------------------------------------------------------------