├── daspTorch
│   ├── __init__.py
│   └── dasp.py
└── README.md

/daspTorch/__init__.py:
--------------------------------------------------------------------------------
from .dasp import DASP
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep-Approximate-Shapley-Propagation

A PyTorch implementation of the DASP algorithm from the paper [Explaining Deep Neural Networks with a Polynomial Time Algorithm for Shapley Value Approximation](https://icml.cc/media/Slides/icml/2019/grandball(13-09-00)-13-09-25-4776-explaining_deep.pdf) (Ancona et al., ICML 2019).
--------------------------------------------------------------------------------
/daspTorch/dasp.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod

import numpy as np
import torch
import torch.nn as nn

from lightprobnets.contrib import adf


def spaced_elements(array, num_elems=4):
    # Pick num_elems roughly evenly spaced elements (the middle element of each split).
    return [x[len(x) // 2] for x in np.array_split(np.array(array), num_elems)]


class AbstractPlayerIterator(ABC):
    """Iterates over the players (the input features to attribute), yielding one mask per player."""

    def __init__(self, inputs, random=False):
        self._assert_input_compatibility(inputs)
        self.input_shape = inputs.shape[1:]
        self.random = random
        self.n_players = self._get_number_of_players_from_shape()
        self.permutation = np.array(range(self.n_players), 'int32')
        if random:
            self.permutation = np.random.permutation(self.permutation)
        self.i = 0
        self.kn = self.n_players
        # Coalition sizes at which the marginal contributions are evaluated.
        self.ks = spaced_elements(range(self.n_players), self.kn)

    def set_n_steps(self, steps):
        self.kn = steps
        self.ks = spaced_elements(range(self.n_players), self.kn)

    def get_number_of_players(self):
        return self.n_players

    def get_explanation_shape(self):
        return self.input_shape

    def get_coalition_size(self):
        return 1

    def get_steps_list(self):
        return self.ks

    def __iter__(self):
        self.i = 0
        return self

    def __next__(self):
        if self.i == self.n_players:
            raise StopIteration
        m = self._get_masks_for_index(self.i)
        self.i = self.i + 1
        return m

    @abstractmethod
    def _assert_input_compatibility(self, inputs):
        pass

    @abstractmethod
    def _get_masks_for_index(self, i):
        pass

    @abstractmethod
    def _get_number_of_players_from_shape(self):
        pass


class DefaultPlayerIterator(AbstractPlayerIterator):
    """Treats every scalar input feature as an individual player."""

    def _assert_input_compatibility(self, inputs):
        assert len(inputs.shape) > 1, 'DefaultPlayerIterator requires an input with 2 or more dimensions'

    def _get_number_of_players_from_shape(self):
        return int(np.prod(self.input_shape))

    def _get_masks_for_index(self, i):
        mask = np.zeros(self.n_players, dtype='int32')
        mask[self.permutation[i]] = 1
        return mask.reshape(self.input_shape), mask.reshape(self.input_shape)


def keep_variance(x, min_variance):
    return x + min_variance


def convert_2_lpdn(model: nn.Module, convert_weights: bool = True) -> nn.Module:
    """
    Convert the model into an LPDN (Lightweight Probabilistic Deep Network).
    Conversion code skeleton from https://discuss.pytorch.org/t/how-can-i-replace-an-intermediate-layer-in-a-pre-trained-network/3586/7
    :param model: the model to convert
    :param convert_weights: if True, copy the weights and biases of the original layers into the converted layers
    :return: the converted LPDN
    """
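    # Each supported nn layer is replaced below by its counterpart from the
    # lightprobnets adf module, which propagates a (mean, variance) pair instead
    # of a single activation. keep_variance_fn adds a small minimum variance so
    # the propagated variances stay strictly positive.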
    min_variance = 1e-3
    keep_variance_fn = lambda x: keep_variance(x, min_variance)
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            # Recurse into container modules.
            model._modules[name] = convert_2_lpdn(module, convert_weights)
        else:
            if isinstance(module, nn.Conv2d):
                layer_new = adf.Conv2d(module.in_channels, module.out_channels, module.kernel_size, module.stride,
                                       module.padding, module.dilation, module.groups,
                                       module.bias is not None, module.padding_mode, keep_variance_fn=keep_variance_fn)
            elif isinstance(module, nn.Linear):
                layer_new = adf.Linear(module.in_features, module.out_features, module.bias is not None,
                                       keep_variance_fn=keep_variance_fn)
            elif isinstance(module, nn.ReLU):
                layer_new = adf.ReLU(keep_variance_fn=keep_variance_fn)
            elif isinstance(module, nn.LeakyReLU):
                layer_new = adf.LeakyReLU(negative_slope=module.negative_slope, keep_variance_fn=keep_variance_fn)
            elif isinstance(module, nn.Dropout):
                layer_new = adf.Dropout(module.p, keep_variance_fn=keep_variance_fn)
            elif isinstance(module, nn.MaxPool2d):
                layer_new = adf.MaxPool2d(keep_variance_fn=keep_variance_fn)
            elif isinstance(module, nn.ConvTranspose2d):
                layer_new = adf.ConvTranspose2d(module.in_channels, module.out_channels, module.kernel_size,
                                                module.stride, module.padding, module.output_padding, module.groups,
                                                module.bias is not None, module.dilation,
                                                keep_variance_fn=keep_variance_fn)
            else:
                raise NotImplementedError(f"Layer type {module} not supported")
            layer_old = module
            try:
                if convert_weights:
                    layer_new.weight = layer_old.weight
                    layer_new.bias = layer_old.bias
            except AttributeError:
                # Parameter-free layers (activations, dropout, pooling) have no weights to copy.
                pass

            model._modules[name] = layer_new

    return model


class DASPModel(nn.Module):
    def __init__(self, first_layer, lpdn_model):
        super(DASPModel, self).__init__()
        # ProbDenseInput (the probabilistic input layer from the DASP formulation)
        # is not defined in this file and must be provided by the surrounding package.
        self.first_layer = ProbDenseInput(first_layer.in_features, first_layer.out_features,
                                          bias=first_layer.bias is not None)
        self.lpdn_model = lpdn_model
        self.first_layer.weight = first_layer.weight
        self.first_layer.bias = first_layer.bias

    def forward(self, inputs: torch.Tensor, mask: torch.Tensor, k: int):
        # The input layer produces two output distributions: one for coalitions of
        # size k without the masked player (x1), one with the player added (x2).
        x1_mean, x1_var, x2_mean, x2_var = self.first_layer(inputs, mask, k)
        y1_mean, y1_var = self.lpdn_model(x1_mean, x1_var)
        y2_mean, y2_var = self.lpdn_model(x2_mean, x2_var)

        return torch.stack([y1_mean, y1_var], -1), torch.stack([y2_mean, y2_var], -1)


class DASP(object):
    def __init__(self, model: nn.Module):
        self.model = model
        self._build_dasp_model()

    def _build_dasp_model(self):
        # Assumes the model exposes its first dense layer as `linear1`.
        first_layer: nn.Linear = self.model.linear1
        lpdn_model = self._convert_to_lpdn(self.model)
        lpdn_model.noise_variance = 1e-3
        self.dasp_model = DASPModel(first_layer, lpdn_model=lpdn_model)

    def _convert_to_lpdn(self, model: nn.Module):
        return convert_2_lpdn(model, True)

    def __call__(self, x, steps=None):
        player_generator = DefaultPlayerIterator(x)
        if steps is None:
            steps = x.shape[1]
        player_generator.set_n_steps(min(steps, x.shape[1]))
        ks = player_generator.get_steps_list()
        result = None
        tile_input = [len(ks)] + (len(x.shape) - 1) * [1]
        tile_mask = [len(ks) * x.shape[0]] + (len(x.shape) - 1) * [1]
        for i, (mask, mask_output) in enumerate(player_generator):
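            # Added note: each iteration estimates the marginal contribution of one
            # player (input feature). y1 below holds the output distribution when the
            # player is excluded from a coalition of size k, y2 when it is included;
            # averaging their difference over the sampled coalition sizes ks gives the
            # DASP approximation of that player's Shapley value.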
            # Workaround carried over from the Keras implementation (to be revisited):
            # tile the input and mask and repeat the ks vector so that all coalition
            # sizes can be evaluated in a single batched forward pass.
            y1, y2 = self.dasp_model(inputs=torch.tensor(np.tile(x, tile_input)),
                                     mask=torch.tensor(np.tile(mask, tile_mask)),
                                     k=torch.tensor(np.repeat(ks, x.shape[0])))
            y1 = y1.reshape(len(ks), x.shape[0], -1, 2)
            y2 = y2.reshape(len(ks), x.shape[0], -1, 2)
            # Average the difference of the output means over the sampled coalition sizes.
            y = torch.mean(y2[..., 0] - y1[..., 0], 0)
            if torch.isnan(y).any():
                raise RuntimeError('Result contains NaNs; this should not happen.')

            # Accumulate the Shapley value estimate of the current player into the result.
            if result is None:
                result = torch.zeros(y.shape + mask_output.shape)

            shape_mask = [1] * len(y.shape)
            shape_mask += list(mask_output.shape)

            shape_out = list(y.shape)
            shape_out += [1] * len(mask_output.shape)

            result += torch.reshape(y, shape_out) * torch.tensor(mask_output)

        return result


if __name__ == "__main__":
    pass
--------------------------------------------------------------------------------
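Usage note (added, not part of the repository): DASP is constructed as DASP(model) and called as explainer(x, steps=k), where x is a NumPy batch of inputs and steps caps the number of coalition sizes that are evaluated. Since the repository ships no runnable example and ProbDenseInput is not included in this file, the self-contained sketch below only illustrates the accumulation performed inside DASP.__call__, using random stand-ins for the two LPDN output distributions; the names batch, n_outputs and n_players are illustrative.

import numpy as np
import torch

batch, n_outputs, n_players, n_ks = 4, 1, 10, 4
input_shape = (n_players,)
result = None

for i in range(n_players):
    # Stand-ins for the two LPDN passes: output (mean, variance) pairs for
    # coalitions without player i (y1) and with player i added (y2).
    y1 = torch.rand(n_ks, batch, n_outputs, 2)
    y2 = torch.rand(n_ks, batch, n_outputs, 2)
    # Average the difference of the means over the sampled coalition sizes.
    y = torch.mean(y2[..., 0] - y1[..., 0], 0)  # shape: (batch, n_outputs)

    # One-hot output mask marking the current player, as DefaultPlayerIterator yields it.
    mask_output = np.zeros(n_players, dtype='int32').reshape(input_shape)
    mask_output[i] = 1

    if result is None:
        result = torch.zeros(y.shape + mask_output.shape)

    # Broadcast y over the player dimensions and scatter it into player i's slot.
    shape_out = list(y.shape) + [1] * len(mask_output.shape)
    result += torch.reshape(y, shape_out) * torch.tensor(mask_output)

print(result.shape)  # torch.Size([4, 1, 10]): per sample, per output, per input feature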