├── .gitignore ├── LICENSE ├── acd ├── __init__.py ├── agglomeration │ ├── __init__.py │ ├── agg_1d.py │ └── agg_2d.py ├── readme.md ├── scores │ ├── __init__.py │ ├── cd.py │ ├── cd_architecture_specific.py │ ├── cd_propagate.py │ └── score_funcs.py └── util │ ├── __init__.py │ ├── conv2dnp.py │ ├── tiling_1d.py │ └── tiling_2d.py ├── citation.bib ├── docs ├── agglomeration │ ├── agg_1d.html │ ├── agg_2d.html │ └── index.html ├── build_docs.sh ├── index.html ├── intro.svg ├── scores │ ├── cd.html │ ├── cd_architecture_specific.html │ ├── cd_propagate.html │ ├── index.html │ └── score_funcs.html ├── style_docs.py └── util │ ├── conv2dnp.html │ ├── index.html │ ├── tiling_1d.html │ └── tiling_2d.html ├── dsets ├── imagenet │ ├── dset.py │ └── imnet_dict.pkl ├── mnist │ ├── dset.py │ ├── mnist.model │ ├── model.py │ └── readme.md └── sst │ ├── dset.py │ ├── model.py │ ├── readme.md │ ├── sst_vocab.pkl │ ├── state_dict.pth │ └── train.py ├── readme.md ├── reproduce_figs ├── figs │ ├── fig_2.png │ ├── fig_s2.png │ └── fig_s3.png ├── imagenet_fig3,s1,s2.ipynb ├── mnist_figs3,s4.ipynb ├── readme.md └── text_fig2.ipynb ├── setup.py ├── tests └── test_cd.py └── visualization ├── viz_1d.py └── viz_2d.py /.gitignore: -------------------------------------------------------------------------------- 1 | **.DS_STORE 2 | __pycache__ 3 | .idea 4 | **Icon* 5 | **.ipynb_checkpoints 6 | **.pyc 7 | **.swp 8 | **cache* 9 | reproduce_figs/mnist_data 10 | reproduce_figs/mnist/data 11 | **egg* 12 | build 13 | dist 14 | venv 15 | test-output.xml 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Chandan Singh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including 
without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /acd/__init__.py: -------------------------------------------------------------------------------- 1 | '''Module for computing hierarchical interpretations of neural network predictions 2 | .. 
include:: ../readme.md 3 | ''' 4 | 5 | from .scores.cd import * 6 | from .scores.cd_propagate import * 7 | from .scores.score_funcs import * 8 | from .agglomeration import agg_1d, agg_2d 9 | from .util import * -------------------------------------------------------------------------------- /acd/agglomeration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csinva/hierarchical-dnn-interpretations/f3a79868420a9f51c825085d62bdff16f9e1a8f3/acd/agglomeration/__init__.py -------------------------------------------------------------------------------- /acd/agglomeration/agg_1d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from skimage import measure 4 | 5 | from ..scores import score_funcs 6 | from ..util import tiling_1d as tiling 7 | 8 | 9 | def agglomerate(model, batch, percentile_include, method, sweep_dim, 10 | label, num_iters=5, subtract=True, absolute=True, device='cuda'): 11 | '''Agglomerative sweep - black out selected pixels from before and resweep over the entire image 12 | 13 | Returns 14 | ------- 15 | r: dict 16 | r['comps_list'] - arrs of components with diff number for each comp 17 | r['comp_scores_list'] - dicts with score for each comp 18 | r['mask_list'] - boolean arrs of selected 19 | r['scores_list'] - arrs of scores (nan for selected) 20 | r['score_orig'] - original score 21 | ''' 22 | # get original text and score 23 | text_orig = batch.text.data.cpu().numpy() 24 | score_orig = score_funcs.get_scores_1d(batch, model, method, label, only_one=True, 25 | score_orig=None, text_orig=text_orig, subtract=subtract, device=device)[0] 26 | 27 | # get scores 28 | texts = tiling.gen_tiles(text_orig, method=method, sweep_dim=sweep_dim) 29 | texts = texts.transpose() 30 | batch.text.data = torch.LongTensor(texts).to(device) 31 | scores = score_funcs.get_scores_1d(batch, model, method, label, 
only_one=False, 32 | score_orig=score_orig, text_orig=text_orig, subtract=subtract, device=device) 33 | 34 | # threshold scores 35 | mask = threshold_scores(scores, percentile_include, absolute=absolute) 36 | 37 | # initialize lists 38 | scores_list = [np.copy(scores)] 39 | mask_list = [mask] 40 | comps_list = [] 41 | comp_scores_list = [{0: score_orig}] 42 | 43 | # iterate 44 | for step in range(num_iters): 45 | # find connected components for regions 46 | comps = np.copy(measure.label(mask_list[-1], background=0, connectivity=1)) 47 | 48 | # loop over components 49 | comp_scores_dict = {} 50 | for comp_num in range(1, np.max(comps) + 1): 51 | 52 | # make component tile 53 | comp_tile_bool = (comps == comp_num) 54 | comp_tile = tiling.gen_tile_from_comp(text_orig, comp_tile_bool, method) 55 | 56 | # make neighboring tiles around component 57 | border_tiles = tiling.gen_tiles_around_baseline(text_orig, comp_tile_bool, 58 | method=method, 59 | sweep_dim=sweep_dim) 60 | 61 | # predict for all tiles 62 | # format tiles into batch 63 | tiles_concat = np.hstack((comp_tile, np.squeeze(border_tiles[0]).transpose())) 64 | batch.text.data = torch.LongTensor(tiles_concat).to(device) 65 | 66 | # get scores for this component tile (index 0) and border tiles (indexes 1 onwards) 67 | scores_all = score_funcs.get_scores_1d(batch, model, method, label, only_one=False, 68 | score_orig=score_orig, text_orig=text_orig, subtract=subtract, 69 | device=device) 70 | score_comp = np.copy(scores_all[0]) 71 | scores_border_tiles = np.copy(scores_all[1:]) 72 | 73 | # store the predicted class scores 74 | comp_scores_dict[comp_num] = np.copy(score_comp) 75 | 76 | # update scores for different indexes 77 | tiles_idxs = border_tiles[1] 78 | for i, idx in enumerate(tiles_idxs): 79 | scores[idx] = scores_border_tiles[i] - score_comp 80 | 81 | # get class preds and thresholded image 82 | scores[mask_list[-1]] = np.nan 83 | mask = threshold_scores(scores, percentile_include, absolute=absolute) 84 | 
def collapse_tree(lists):
    '''Removes redundant joins from final hierarchy

    Each component found at agglomeration step i is re-assigned to the lowest
    level at which it does not skip a merge step: a component that overlaps a
    single label of the previous step keeps that label's level, and a component
    that overlaps several labels is placed one level above the highest of them.

    Params
    ------
    lists: dict
        Dictionary of lists output by agglomerate; the keys 'comps_list'
        (label arrays, one per step) and 'comp_scores_list' (dicts mapping
        component number to score, one per step) are read

    Returns
    -------
    lists: dict
        The same dictionary object (modified in place) with 'comps_list' and
        'comp_scores_list' replaced by their collapsed versions,
        i.e. merge whenever possible, if it doesn't skip a merge step
    '''
    num_iters = len(lists['comps_list'])
    num_words = len(lists['comps_list'][0])

    # need to update comp_scores_list, comps_list
    # (builtin int replaces the np.int alias, which was removed in NumPy 1.24)
    comps_list = [np.zeros(num_words, dtype=int) for _ in range(num_iters)]
    comp_scores_list = [{0: 0} for _ in range(num_iters)]
    # level assigned to every component of every step; used to decide where
    # components of the following step are placed
    comp_levels_list = [{0: 0} for _ in range(num_iters)]

    # initialize first level: every word is its own component at level 0
    comps_list[0] = np.arange(num_words)
    comp_levels_list[0] = {i: 0 for i in range(num_words)}

    # iterate over levels
    for i in range(1, num_iters):
        comps = lists['comps_list'][i]
        comps_old = lists['comps_list'][i - 1]
        comp_scores = lists['comp_scores_list'][i]

        # iterate over number of components
        for comp_num in range(1, np.max(comps) + 1):
            comp = comps == comp_num
            comp_size = np.sum(comp)
            if comp_size == 1:
                comp_levels_list[i][comp_num] = 0  # single words stay at level 0
            else:
                # labels of the previous step covered by this component
                # (the background label 0 counts as a match as well)
                matches = np.unique(comps_old[comp])
                num_matches = matches.size

                # if 0 matches (only possible for an empty component), level is 1
                if num_matches == 0:
                    level = 1

                # if 1 match, maintain level
                elif num_matches == 1:
                    level = comp_levels_list[i - 1][matches[0]]

                # if >1 match, take highest level + 1
                else:
                    level = np.max([comp_levels_list[i - 1][match] for match in matches]) + 1

                comp_levels_list[i][comp_num] = level
                new_comp_num = int(np.max(comps_list[level]) + 1)
                comps_list[level][comp] = new_comp_num  # update comp
                comp_scores_list[level][new_comp_num] = comp_scores[comp_num]  # update comp score

    # remove unnecessary iters: keep the leading levels that contain a component
    # (the bound is checked first so a fully populated list cannot be over-indexed)
    num_iters = 0
    while num_iters < len(comps_list) and np.sum(comps_list[num_iters] > 0):
        num_iters += 1

    # populate lists
    lists['comps_list'] = comps_list[:num_iters]
    lists['comp_scores_list'] = comp_scores_list[:num_iters]
    return lists
score_funcs.get_scores_2d(model, 'build_up', 42 | ims=np.expand_dims(im_orig, 0), # score for full image 43 | im_torch=im_torch, pred_ims=pred_ims, 44 | model_type=model_type, device=device)[0]}] 45 | else: 46 | comp_scores_raw_list = [{0: score_funcs.get_scores_2d(model, method, 47 | ims=np.expand_dims(np.ones(im_orig.transpose().shape), 0), 48 | # score for full image 49 | im_torch=im_torch, pred_ims=pred_ims, 50 | model_type=model_type, device=device)[0]}] 51 | comp_scores_raw_combined_list = [] 52 | 53 | # iterate 54 | for step in range(num_iters): 55 | # if already selected all pixels then break 56 | if np.sum(im_thresh_list[-1]) == R * C: 57 | break 58 | 59 | # find connected components for regions 60 | comps = np.copy(measure.label(im_thresh_list[-1], background=0, connectivity=2)) 61 | 62 | # establish correspondence with components from previous iteration 63 | if step > 0: 64 | comps_orig = np.copy(comps) 65 | try: 66 | comps = establish_correspondence(comps_list[-1], comps_orig) 67 | except: 68 | comps = comps_orig 69 | 70 | comp_tiles = {} # stores tiles corresponding to each tile 71 | if not method == 'cd': 72 | comps_combined_tile = np.zeros(shape=im_orig.shape) # stores all comp tiles combined 73 | else: 74 | comps_combined_tile = np.zeros(shape=(R, C)) # stores all comp tiles combined 75 | comp_surround_tiles = {} # stores tiles around comp_tiles 76 | comp_surround_idxs = {} 77 | 78 | # make tiles 79 | comp_nums = np.unique(comps) 80 | comp_nums = comp_nums[comp_nums > 0] # remove 0 81 | for comp_num in comp_nums: 82 | if comp_num > 0: 83 | # make component tile 84 | comp_tile_downsampled = (comps == comp_num) 85 | comp_tiles[comp_num] = tiling.gen_tile_from_comp(im_orig, comp_tile_downsampled, 86 | sweep_dim, method) # this is full size 87 | comp_tile_binary = tiling.gen_tile_from_comp(im_orig, comp_tile_downsampled, 88 | sweep_dim, 'cd') # this is full size 89 | # print('comps sizes', comps_combined_tile.shape, comp_tiles[comp_num].shape) 90 | 
comps_combined_tile += comp_tiles[comp_num] 91 | 92 | # generate tiles and corresponding idxs around component 93 | comp_surround_tiles[comp_num], comp_surround_idxs[comp_num] = \ 94 | tiling.gen_tiles_around_baseline(im_orig, comp_tile_binary, method=method, sweep_dim=sweep_dim) 95 | 96 | # predict for all tiles 97 | comp_scores_raw_dict = {} # dictionary of {comp_num: comp_score} 98 | for comp_num in comp_nums: 99 | tiles = np.concatenate((np.expand_dims(comp_tiles[comp_num], 0), # baseline tile at 0 100 | np.expand_dims(comps_combined_tile, 0), # combined tile at 1 101 | comp_surround_tiles[comp_num])) # all others afterwards 102 | scores_raw = score_funcs.get_scores_2d(model, method, ims=tiles, im_torch=im_torch, 103 | pred_ims=pred_ims, model_type=model_type) 104 | 105 | # decipher scores 106 | score_comp = np.copy(refine_scores(scores_raw, lab_num)[0]) 107 | scores_tiles = np.copy(refine_scores(scores_raw, lab_num)[2:]) 108 | 109 | # store the predicted class scores 110 | comp_scores_raw_dict[comp_num] = np.copy(scores_raw[0]) 111 | score_comps_raw_combined = np.copy(scores_raw[1]) 112 | 113 | # update pixel scores 114 | tiles_idxs = comp_surround_idxs[comp_num] 115 | for i in range(len(scores_tiles)): 116 | (r, c) = tiles_idxs[i] 117 | scores_track[r, c] = np.max(scores_tiles[i] - score_comp) # todo: subtract off previous comp / weight? 
# higher scores are more likely to be picked
def threshold_scores(scores, percentile_include):
    '''Threshold a score map at a percentile of its not-yet-picked entries

    Params
    ------
    scores: np.ndarray
        float array of scores; NaN marks entries that were already picked
    percentile_include: float
        percentile (0-100) of the non-NaN scores used as the threshold

    Returns
    -------
    im_thresh: np.ndarray
        boolean array (same shape as scores), True where the score is not NaN
        and is at least the threshold; all False when every score is NaN
    '''
    not_picked = ~np.isnan(scores)

    # nothing left to pick: return an empty mask instead of lowering the
    # percentile until np.nanpercentile rejects it
    if not np.any(not_picked):
        return np.zeros(scores.shape, dtype=bool)

    # pick more when more is already picked
    num_picked = scores.size - np.sum(not_picked)
    if num_picked > scores.size / 3:
        percentile_include -= 15
    # the adjustment above must not leave the valid percentile range
    percentile_include = max(percentile_include, 0)

    thresh = np.nanpercentile(scores, percentile_include)
    im_thresh = np.logical_and(scores >= thresh, not_picked)

    # make sure we pick something: lower the bar until a pixel passes,
    # stopping at percentile 0 (the minimum of the remaining scores)
    while np.sum(im_thresh) == 0 and percentile_include > 0:
        percentile_include = max(percentile_include - 4, 0)
        thresh = np.nanpercentile(scores, percentile_include)
        im_thresh = np.logical_and(scores >= thresh, not_picked)
    return im_thresh
pred_ims, percentile_include, method, sweep_dim, 224 | im_orig, lab_num, num_iters=5, im_torch=None, model_type='mnist'): 225 | '''Postprocess the final segmentation by joining the remaining segments 226 | ''' 227 | # while multiple types of blobs 228 | while (np.unique(lists['comps_list'][-1]).size > 2): 229 | # for q in range(3): 230 | comps = np.copy(lists['comps_list'][-1]) 231 | comp_scores_raw_dict = deepcopy(lists['comp_scores_raw_list'][-1]) 232 | 233 | # todo: initially merge really small blobs with nearest big blobs 234 | # if q == 0: 235 | 236 | # make tiles by combining pairs in comps 237 | comp_tiles = {} # stores tiles corresponding to each tile 238 | for comp_num in np.unique(comps): 239 | if comp_num > 0: 240 | # make component tile 241 | comp_tile_downsampled = (comps == comp_num) 242 | comp_tiles[comp_num] = tiling.gen_tile_from_comp(im_orig, comp_tile_downsampled, 243 | sweep_dim, method) # this is full size 244 | 245 | # make combined tiles 246 | comp_tiles_comb = {} 247 | for comp_num1 in np.unique(comps): 248 | for comp_num2 in np.unique(comps): 249 | if 0 < comp_num1 < comp_num2: 250 | comp_tiles_comb[(comp_num1, comp_num2)] = tiling.combine_tiles(comp_tiles[comp_num1], 251 | comp_tiles[comp_num2], method) 252 | 253 | # predict for all tiles 254 | comp_max_score_diff = -1e10 255 | comp_max_key_pair = None 256 | comp_max_scores_raw = None 257 | for key in comp_tiles_comb.keys(): 258 | # calculate scores 259 | tiles = 1.0 * np.expand_dims(comp_tiles_comb[key], 0) 260 | scores_raw = score_funcs.get_scores_2d(model, method, ims=tiles, im_torch=im_torch, 261 | pred_ims=pred_ims, model_type=model_type) 262 | 263 | # refine scores for correct class - todo this doesn't work with refine_scores 264 | score_comp = np.copy(refine_scores(scores_raw, lab_num)[0]) 265 | # score_orig = np.max(refine_scores(np.expand_dims(comp_scores_raw_dict[key[0]], 0), lab_num)[0], 266 | # refine_scores(np.expand_dims(comp_scores_raw_dict[key[1]], 0), lab_num)[0]) 267 | 
score_orig = max(comp_scores_raw_dict[key[0]][lab_num], comp_scores_raw_dict[key[1]][lab_num]) 268 | score_diff = score_comp - score_orig 269 | 270 | # find best score 271 | if score_diff > comp_max_score_diff: 272 | comp_max_score_diff = score_diff 273 | comp_max_key_pair = key 274 | comp_max_scores_raw = np.copy(scores_raw[0]) # store the predicted class scores 275 | 276 | # merge highest scoring blob pair 277 | comps[comps == comp_max_key_pair[1]] = comp_max_key_pair[0] 278 | 279 | # update highest scoring blob pair score 280 | comp_scores_raw_dict[comp_max_key_pair[0]] = comp_max_scores_raw 281 | comp_scores_raw_dict.pop(comp_max_key_pair[1]) 282 | 283 | # add to lists 284 | lists['comps_list'].append(comps) 285 | lists['comp_scores_raw_list'].append(comp_scores_raw_dict) 286 | lists['scores_list'].append(lists['scores_list'][-1]) 287 | lists['im_thresh_list'].append(lists['im_thresh_list'][-1]) 288 | lists['comp_scores_raw_combined_list'].append(lists['comp_scores_raw_combined_list'][-1]) 289 | 290 | return lists 291 | -------------------------------------------------------------------------------- /acd/readme.md: -------------------------------------------------------------------------------- 1 | # source for calculating ACD interpretations 2 | 3 | - [scores](scores) folder contains code for calculating different importance 4 | - `cd.py` file is the entry point for calculating CD scores 5 | - `cd_propagate.py` files contain code to calculate CD score across individual layers 6 | - `cd_architecture_specific.py` contains implementations for some specific architectures 7 | - `score_funcs.py` contains implementations of baselines and wrappers to return different scores 8 | - [agglomeration](agglomeration) folder contains code for aggregating scores to produce hierarchical interpretations 9 | - `agg_1d` is for text-like inputs and produces a sequence of 1-d components 10 | - `agg_2d` is for image-like inputs and produces a sequence of image segmentations 11 | - 
def cd(im_torch: torch.Tensor, model, mask=None, model_type=None, device='cuda', transform=None):
    '''Get contextual decomposition scores for some set of inputs for a specific image

    Params
    ------
    im_torch: torch.Tensor
        example to interpret - usually has shape (batch_size, num_channels, height, width)
    model: pytorch model
    mask: array_like (values in {0, 1})
        required unless transform is supplied
        array with 1s marking the locations of relevant pixels, 0s marking the background
        shape should match the shape of im_torch or just H x W
    model_type: str, optional
        usually should just leave this blank
        if this is == 'mnist', uses CD for a specific mnist model
        if this is == 'resnet18', uses resnet18 model
    device: str, optional
    transform: function, optional
        function applied to the original image to produce the relevant part
        only used if mask is not passed

    Returns
    -------
    relevant: torch.Tensor
        class-wise scores for relevant mask
    irrelevant: torch.Tensor
        class-wise scores for everything but the relevant mask

    Raises
    ------
    ValueError
        if neither mask nor transform is supplied
    '''
    # fail early with a clear error instead of printing a message and then
    # crashing on the undefined relevant/irrelevant parts
    if mask is None and transform is None:
        raise ValueError('mask or transform arguments required!')

    # set up model
    model.eval()
    model = model.to(device)
    im_torch = im_torch.to(device)

    # set up relevant/irrelevant based on mask
    if mask is not None:
        # as_tensor accepts lists, arrays of any numeric dtype and tensors
        mask = torch.as_tensor(mask, dtype=torch.float32, device=device)
        relevant = mask * im_torch
        irrelevant = (1 - mask) * im_torch
    else:
        relevant = transform(im_torch).to(device)
        if len(relevant.shape) < 4:
            # a 2d result is expanded with leading batch and channel axes
            relevant = relevant.reshape(1, 1, relevant.shape[0], relevant.shape[1])
        irrelevant = im_torch - relevant
    relevant = relevant.to(device)
    irrelevant = irrelevant.to(device)

    # deal with specific architectures which cannot be handled generically
    if model_type == 'mnist':
        return cd_propagate_mnist(relevant, irrelevant, model)
    if model_type == 'resnet18':
        return cd_propagate_resnet(relevant, irrelevant, model)

    # try the generic case
    mods = list(model.modules())
    relevant, irrelevant = cd_generic(mods, relevant, irrelevant)
    return relevant, irrelevant
'ReshapeLayer' in t or ('modularize' in t and 'Transform' in t): # custom layers 86 | relevant, irrelevant = propagate_independent(relevant, irrelevant, mod) 87 | elif 'Pool' in t and not 'AvgPool' in t: 88 | relevant, irrelevant = propagate_pooling(relevant, irrelevant, mod) 89 | elif 'BatchNorm2d' in t: 90 | relevant, irrelevant = propagate_batchnorm2d(relevant, irrelevant, mod) 91 | return relevant, irrelevant 92 | 93 | 94 | def cd_text(batch, model, start, stop, return_irrel_scores=False): 95 | '''Get contextual decomposition scores for substring of a text sequence 96 | 97 | Params 98 | ------ 99 | batch: torchtext batch 100 | really only requires that batch.text is the string input to be interpreted 101 | start: int 102 | beginning index of substring to be interpreted (inclusive) 103 | stop: int 104 | ending index of substring to be interpreted (inclusive) 105 | 106 | Returns 107 | ------- 108 | scores: torch.Tensor 109 | class-wise scores for relevant substring 110 | ''' 111 | weights = model.lstm.state_dict() 112 | 113 | # Index one = word vector (i) or hidden state (h), index two = gate 114 | W_ii, W_if, W_ig, W_io = np.split(weights['weight_ih_l0'], 4, 0) 115 | W_hi, W_hf, W_hg, W_ho = np.split(weights['weight_hh_l0'], 4, 0) 116 | b_i, b_f, b_g, b_o = np.split(weights['bias_ih_l0'].cpu().numpy() + weights['bias_hh_l0'].cpu().numpy(), 4) 117 | word_vecs = model.embed(batch.text)[:, 0].data 118 | T = word_vecs.size(0) 119 | relevant = np.zeros((T, model.hidden_dim)) 120 | irrelevant = np.zeros((T, model.hidden_dim)) 121 | relevant_h = np.zeros((T, model.hidden_dim)) 122 | irrelevant_h = np.zeros((T, model.hidden_dim)) 123 | for i in range(T): 124 | if i > 0: 125 | prev_rel_h = relevant_h[i - 1] 126 | prev_irrel_h = irrelevant_h[i - 1] 127 | else: 128 | prev_rel_h = np.zeros(model.hidden_dim) 129 | prev_irrel_h = np.zeros(model.hidden_dim) 130 | 131 | rel_i = np.dot(W_hi, prev_rel_h) 132 | rel_g = np.dot(W_hg, prev_rel_h) 133 | rel_f = np.dot(W_hf, 
prev_rel_h) 134 | rel_o = np.dot(W_ho, prev_rel_h) 135 | irrel_i = np.dot(W_hi, prev_irrel_h) 136 | irrel_g = np.dot(W_hg, prev_irrel_h) 137 | irrel_f = np.dot(W_hf, prev_irrel_h) 138 | irrel_o = np.dot(W_ho, prev_irrel_h) 139 | 140 | if i >= start and i <= stop: 141 | rel_i = rel_i + np.dot(W_ii, word_vecs[i]) 142 | rel_g = rel_g + np.dot(W_ig, word_vecs[i]) 143 | rel_f = rel_f + np.dot(W_if, word_vecs[i]) 144 | rel_o = rel_o + np.dot(W_io, word_vecs[i]) 145 | else: 146 | irrel_i = irrel_i + np.dot(W_ii, word_vecs[i]) 147 | irrel_g = irrel_g + np.dot(W_ig, word_vecs[i]) 148 | irrel_f = irrel_f + np.dot(W_if, word_vecs[i]) 149 | irrel_o = irrel_o + np.dot(W_io, word_vecs[i]) 150 | 151 | rel_contrib_i, irrel_contrib_i, bias_contrib_i = propagate_three(rel_i, irrel_i, b_i, sigmoid) 152 | rel_contrib_g, irrel_contrib_g, bias_contrib_g = propagate_three(rel_g, irrel_g, b_g, np.tanh) 153 | 154 | relevant[i] = rel_contrib_i * (rel_contrib_g + bias_contrib_g) + bias_contrib_i * rel_contrib_g 155 | irrelevant[i] = irrel_contrib_i * (rel_contrib_g + irrel_contrib_g + bias_contrib_g) + ( 156 | rel_contrib_i + bias_contrib_i) * irrel_contrib_g 157 | 158 | if i >= start and i <= stop: 159 | relevant[i] += bias_contrib_i * bias_contrib_g 160 | else: 161 | irrelevant[i] += bias_contrib_i * bias_contrib_g 162 | 163 | if i > 0: 164 | rel_contrib_f, irrel_contrib_f, bias_contrib_f = propagate_three(rel_f, irrel_f, b_f, sigmoid) 165 | relevant[i] += (rel_contrib_f + bias_contrib_f) * relevant[i - 1] 166 | irrelevant[i] += (rel_contrib_f + irrel_contrib_f + bias_contrib_f) * irrelevant[i - 1] + irrel_contrib_f * \ 167 | relevant[i - 1] 168 | 169 | o = sigmoid(np.dot(W_io, word_vecs[i]) + np.dot(W_ho, prev_rel_h + prev_irrel_h) + b_o) 170 | rel_contrib_o, irrel_contrib_o, bias_contrib_o = propagate_three(rel_o, irrel_o, b_o, sigmoid) 171 | new_rel_h, new_irrel_h = propagate_tanh_two(relevant[i], irrelevant[i]) 172 | # relevant_h[i] = new_rel_h * (rel_contrib_o + bias_contrib_o) 173 | 
def cd_propagate_resnet(rel, irrel, model):
    '''Propagate relevant / irrelevant parts through a BasicBlock-based resnet

    The forward pass being mirrored is
        conv1 -> bn1 -> relu -> maxpool          (mods[1:5])
        layer1 -> layer2 -> layer3 -> layer4     (sequences of BasicBlocks)
        avgpool -> fc                            (mods[-2:])
    Each BasicBlock passes its input through to its output (might need to downsample).
    note: the bigger resnets use BottleNeck instead of BasicBlock

    Params
    ------
    rel, irrel
        relevant and irrelevant parts of the network input
    model
        resnet whose list(model.modules()) follows the layout above

    Returns
    -------
    rel, irrel
        relevant and irrelevant parts after the last module
    '''
    # imported inside the function, exactly as the original did
    # (presumably to avoid a circular import with .cd -- TODO confirm)
    from .cd import cd_generic

    mods = list(model.modules())

    # stem: everything before the first residual stage
    rel, irrel = cd_generic(mods[1:5], rel, irrel)

    # Residual stages. Prefer the named attributes so that any BasicBlock
    # resnet depth works; the hard-coded indices are only valid for the
    # module layout the original code was written against.
    stage_names = ('layer1', 'layer2', 'layer3', 'layer4')
    if all(hasattr(model, name) for name in stage_names):
        stages = [getattr(model, name) for name in stage_names]
    else:
        stages = [mods[idx] for idx in (5, 18, 34, 50)]

    for stage in stages:
        for basic_block in stage:
            rel, irrel = propagate_basic_block(rel, irrel, basic_block)

    # final things after BasicBlocks
    rel, irrel = cd_generic(mods[-2:], rel, irrel)
    return rel, irrel
def cd_track_vgg(blob, im_torch, model, model_type='vgg'):
    '''Run CD through a VGG-style network, keeping the scores at intermediate layers

    This implementation of cd is spelled out layer by layer so that we can view
    CD at intermediate layers. In reality, one should use the loop contained in
    the cd function.

    Expected layout of list(model.modules())[2:]
        0-30  : five feature blocks of (conv, relu) pairs followed by a max-pool,
                with 2, 2, 3, 3 and 3 pairs respectively
        31    : skipped (presumably the classifier container -- TODO confirm)
        32-38 : linear, relu, dropout, linear, relu, dropout, linear

    Params
    ------
    blob
        mask selecting the relevant part of the image (1 = relevant)
    im_torch: torch.Tensor
        input image; the mask is moved to this tensor's device
    model
        network to propagate through (put into eval mode here)
    model_type: str
        unused, kept for interface compatibility

    Returns
    -------
    relevant, irrelevant: torch.Tensor
        decomposition of the final layer's output
    scores: list
        (relevant, irrelevant) clones taken right after every conv layer
    '''
    # set up model
    model.eval()

    # set up blobs; follow the image's device instead of assuming CUDA
    blob = torch.as_tensor(blob, dtype=torch.float32, device=im_torch.device)
    relevant = blob * im_torch
    irrelevant = (1 - blob) * im_torch

    mods = list(model.modules())[2:]
    scores = []

    # feature extractor: mods[0] .. mods[30]
    convs_per_block = (2, 2, 3, 3, 3)
    idx = 0
    for num_convs in convs_per_block:
        for _ in range(num_convs):
            relevant, irrelevant = propagate_conv_linear(relevant, irrelevant, mods[idx])
            scores.append((relevant.clone(), irrelevant.clone()))
            relevant, irrelevant = propagate_relu(relevant, irrelevant, mods[idx + 1])
            idx += 2
        relevant, irrelevant = propagate_pooling(relevant, irrelevant, mods[idx])
        idx += 1

    # flatten before the fully connected layers
    relevant = relevant.view(relevant.size(0), -1)
    irrelevant = irrelevant.view(irrelevant.size(0), -1)

    # classifier: mods[32] .. mods[38]
    # Dropout acts on each part independently, so propagate_independent is used
    # (cd_propagate defines no propagate_dropout).
    relevant, irrelevant = propagate_conv_linear(relevant, irrelevant, mods[32])
    relevant, irrelevant = propagate_relu(relevant, irrelevant, mods[33])
    relevant, irrelevant = propagate_independent(relevant, irrelevant, mods[34])
    relevant, irrelevant = propagate_conv_linear(relevant, irrelevant, mods[35])
    relevant, irrelevant = propagate_relu(relevant, irrelevant, mods[36])
    relevant, irrelevant = propagate_independent(relevant, irrelevant, mods[37])
    relevant, irrelevant = propagate_conv_linear(relevant, irrelevant, mods[38])

    return relevant, irrelevant, scores
def propagate_conv_linear(relevant, irrelevant, module):
    '''Propagate convolutional or linear layer
    Apply linear part to both pieces
    Split bias elementwise, in proportion to the absolute value of each piece

    Params
    ------
    relevant, irrelevant: torch.Tensor
        the two parts of the layer input (same shape)
    module
        layer to apply; called on zeros to isolate its bias term

    Returns
    -------
    rel, irrel: torch.Tensor
        parts of the layer output; rel + irrel equals module(relevant + irrelevant)
    '''
    # The response to an all-zero input is exactly the bias contribution.
    # Match the input's dtype and device so non-float32 layers work too.
    zeros = torch.zeros(irrelevant.size(), dtype=relevant.dtype, device=relevant.device)
    bias = module(zeros)
    rel = module(relevant) - bias
    irrel = module(irrelevant) - bias

    # elementwise proportional
    prop_rel = torch.abs(rel) + 1e-20  # add a small constant so we don't divide by 0
    prop_irrel = torch.abs(irrel) + 1e-20  # add a small constant so we don't divide by 0
    prop_sum = prop_rel + prop_irrel
    prop_rel = torch.div(prop_rel, prop_sum)
    prop_irrel = torch.div(prop_irrel, prop_sum)
    return rel + torch.mul(prop_rel, bias), irrel + torch.mul(prop_irrel, bias)
def propagate_relu(relevant, irrelevant, activation):
    '''propagate ReLu nonlinearity

    The relevant score is the activation of the relevant part alone; the
    irrelevant score is whatever remains of the activation of the full input.

    Params
    ------
    relevant, irrelevant: torch.Tensor
        the two parts of the activation's input
    activation
        callable nonlinearity; may be a module with an `inplace` flag
        or a plain function such as F.relu

    Returns
    -------
    rel_score, irrel_score: torch.Tensor
    '''
    # An in-place activation would overwrite `relevant`, so switch it off
    # for the duration of the call. Plain functions have no such flag.
    was_inplace = bool(getattr(activation, 'inplace', False))
    if was_inplace:
        activation.inplace = False
    try:
        rel_score = activation(relevant)
        irrel_score = activation(relevant + irrelevant) - rel_score
    finally:
        # restore the caller's module even if the activation raised
        if was_inplace:
            activation.inplace = True
    return rel_score, irrel_score
def propagate_lstm(x, module, start: int, stop: int, my_device=0):
    '''Propagate a relevant / irrelevant decomposition through an lstm layer

    Time steps start..stop (both inclusive) are treated as relevant, all
    others as irrelevant.

    Params
    ------
    x: torch.Tensor
        (batch_size, seq_len, num_channels)
        warning: default lstm uses shape (seq_len, batch_size, num_channels)
    module: lstm layer
        must expose weight_ih_l0, weight_hh_l0, bias_ih_l0 and bias_hh_l0
    start: int
        start of relevant sequence
    stop: int
        end of relevant sequence (inclusive)
    my_device
        device on which the state tensors are allocated

    Returns
    -------
    rel_h, irrel_h: torch.Tensor
        (batch_size, 1, num_hidden_lstm)
        relevant and irrelevant parts of the hidden state after the last step
    '''
    # extract out weights (gate order in the stacked matrices: i, f, g, o)
    W_ii, W_if, W_ig, W_io = torch.chunk(module.weight_ih_l0, 4, 0)
    W_hi, W_hf, W_hg, W_ho = torch.chunk(module.weight_hh_l0, 4, 0)
    b_i, b_f, b_g, b_o = torch.chunk(module.bias_ih_l0 + module.bias_hh_l0, 4)
    # column vectors so the biases broadcast over the batch dimension
    b_i, b_f, b_g, b_o = b_i[:, None], b_f[:, None], b_g[:, None], b_o[:, None]

    # prepare input x
    x = x.permute(1, 2, 0)  # convert to (seq_len, num_channels, batch_size)
    seq_len = x.shape[0]
    batch_size = x.shape[2]
    output_dim = W_ho.shape[1]

    device = torch.device(my_device)
    state_shape = (output_dim, batch_size)
    relevant_h = torch.zeros(state_shape, device=device, requires_grad=False)
    irrelevant_h = torch.zeros(state_shape, device=device, requires_grad=False)
    prev_rel = torch.zeros(state_shape, device=device, requires_grad=False)
    prev_irrel = torch.zeros(state_shape, device=device, requires_grad=False)

    for i in range(seq_len):
        in_phrase = start <= i <= stop
        x_i = x[i]
        prev_rel_h = relevant_h
        prev_irrel_h = irrelevant_h

        # recurrent contributions to the input, cell-candidate and forget gates
        rel_i = torch.matmul(W_hi, prev_rel_h)
        rel_g = torch.matmul(W_hg, prev_rel_h)
        rel_f = torch.matmul(W_hf, prev_rel_h)
        irrel_i = torch.matmul(W_hi, prev_irrel_h)
        irrel_g = torch.matmul(W_hg, prev_irrel_h)
        irrel_f = torch.matmul(W_hf, prev_irrel_h)

        # the current input goes wholly to one side
        if in_phrase:
            rel_i = rel_i + torch.matmul(W_ii, x_i)
            rel_g = rel_g + torch.matmul(W_ig, x_i)
            rel_f = rel_f + torch.matmul(W_if, x_i)
        else:
            irrel_i = irrel_i + torch.matmul(W_ii, x_i)
            irrel_g = irrel_g + torch.matmul(W_ig, x_i)
            irrel_f = irrel_f + torch.matmul(W_if, x_i)

        # torch activations keep the computation on the tensors' device
        rel_contrib_i, irrel_contrib_i, bias_contrib_i = propagate_three(rel_i, irrel_i, b_i, torch.sigmoid)
        rel_contrib_g, irrel_contrib_g, bias_contrib_g = propagate_three(rel_g, irrel_g, b_g, torch.tanh)

        relevant = rel_contrib_i * (rel_contrib_g + bias_contrib_g) + bias_contrib_i * rel_contrib_g
        irrelevant = irrel_contrib_i * (rel_contrib_g + irrel_contrib_g + bias_contrib_g) + (
                rel_contrib_i + bias_contrib_i) * irrel_contrib_g

        # the pure bias-bias product follows the same side as the input;
        # same inclusive window as above (matches cd_text)
        if in_phrase:
            relevant = relevant + bias_contrib_i * bias_contrib_g
        else:
            irrelevant = irrelevant + bias_contrib_i * bias_contrib_g

        # carry over the previous cell state through the forget gate
        if i > 0:
            rel_contrib_f, irrel_contrib_f, bias_contrib_f = propagate_three(rel_f, irrel_f, b_f, torch.sigmoid)
            relevant = relevant + (rel_contrib_f + bias_contrib_f) * prev_rel
            irrelevant = irrelevant + (
                    rel_contrib_f + irrel_contrib_f + bias_contrib_f) * prev_irrel + irrel_contrib_f * prev_rel

        # the output gate is computed on the full (undecomposed) input
        o = torch.sigmoid(torch.matmul(W_io, x_i) + torch.matmul(W_ho, prev_rel_h + prev_irrel_h) + b_o)

        # two-part tanh split, same formula as propagate_tanh_two
        tanh_rel = torch.tanh(relevant)
        tanh_irrel = torch.tanh(irrelevant)
        tanh_both = torch.tanh(relevant + irrelevant)
        new_rel_h = 0.5 * (tanh_rel + (tanh_both - tanh_irrel))
        new_irrel_h = 0.5 * (tanh_irrel + (tanh_both - tanh_rel))

        relevant_h = o * new_rel_h
        irrelevant_h = o * new_irrel_h
        prev_rel = relevant
        prev_irrel = irrelevant

    # reshape output
    rel_h = relevant_h.transpose(0, 1).unsqueeze(1)
    irrel_h = irrelevant_h.transpose(0, 1).unsqueeze(1)
    return rel_h, irrel_h
x_irrel.permute(1, 2, 0) # convert to (seq_len, num_channels, batch_size) 265 | x = x_rel + x_irrel 266 | # print('shapes', x_rel.shape, x_irrel.shape, x.shape) 267 | seq_len = x_rel.shape[0] 268 | batch_size = x_rel.shape[2] 269 | output_dim = W_ho.shape[1] 270 | relevant_h = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False) 271 | irrelevant_h = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False) 272 | prev_rel = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False) 273 | prev_irrel = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False) 274 | for i in range(seq_len): 275 | prev_rel_h = relevant_h 276 | prev_irrel_h = irrelevant_h 277 | rel_i = torch.matmul(W_hi, prev_rel_h) 278 | rel_g = torch.matmul(W_hg, prev_rel_h) 279 | rel_f = torch.matmul(W_hf, prev_rel_h) 280 | rel_o = torch.matmul(W_ho, prev_rel_h) 281 | irrel_i = torch.matmul(W_hi, prev_irrel_h) 282 | irrel_g = torch.matmul(W_hg, prev_irrel_h) 283 | irrel_f = torch.matmul(W_hf, prev_irrel_h) 284 | irrel_o = torch.matmul(W_ho, prev_irrel_h) 285 | 286 | # relevant parts 287 | rel_i = rel_i + torch.matmul(W_ii, x_rel[i]) 288 | rel_g = rel_g + torch.matmul(W_ig, x_rel[i]) 289 | rel_f = rel_f + torch.matmul(W_if, x_rel[i]) 290 | # rel_o = rel_o + torch.matmul(W_io, x[i]) 291 | 292 | # irrelevant parts 293 | irrel_i = irrel_i + torch.matmul(W_ii, x_irrel[i]) 294 | irrel_g = irrel_g + torch.matmul(W_ig, x_irrel[i]) 295 | irrel_f = irrel_f + torch.matmul(W_if, x_irrel[i]) 296 | # irrel_o = irrel_o + torch.matmul(W_io, x[i]) 297 | 298 | rel_contrib_i, irrel_contrib_i, bias_contrib_i = propagate_three(rel_i, irrel_i, b_i[:, None], sigmoid) 299 | rel_contrib_g, irrel_contrib_g, bias_contrib_g = propagate_three(rel_g, irrel_g, b_g[:, None], tanh) 300 | 301 | relevant = rel_contrib_i * (rel_contrib_g + bias_contrib_g) + \ 302 | bias_contrib_i * rel_contrib_g 303 | 
def gradient_times_input_scores(im: torch.Tensor, ind: int, model, device='cuda'):
    '''Compute gradient * input attribution scores

    Params
    ------
    im: torch.Tensor
        Image to get scores with respect to
        (must require grad, since im.grad is read after the backward pass)
    ind: int
        Which class to take gradient with respect to
    model
        callable whose output is fed to NLLLoss
    device
        device on which the target index is placed

    Returns
    -------
    np.ndarray
        (im.grad * im)[0, 0] as a numpy array
    '''
    # builtin int: the np.int alias no longer exists in current NumPy
    ind = torch.LongTensor([int(ind)]).to(device)

    # clear any gradient left over from a previous call
    if im.grad is not None:
        im.grad.data.zero_()

    pred = model(im)
    crit = nn.NLLLoss()
    loss = crit(pred, ind)
    loss.backward()

    res = im.grad * im
    return res.data.cpu().numpy()[0, 0]
def ig_scores_1d(batch, model, inputs, device='cuda'):
    '''Compute integrated gradients scores (1D input)

    The path runs in M steps from the baseline (every token replaced by '.')
    to the actual embeddings. Only column 0 of the batch is used, as in the
    final attribution product.

    Params
    ------
    batch
        object whose .text holds the token ids, one sequence per column
        (assumed (seq_len, batch) -- TODO confirm against caller)
    model
        must expose embed, lstm, hidden_to_label, hidden_dim and parameters()
    inputs
        field object; inputs.vocab.stoi['.'] gives the baseline token id
    device
        device on which the interpolated inputs are built

    Returns
    -------
    scores: np.ndarray
        one score per position of the sequence
    '''
    for p in model.parameters():
        if p.grad is not None:
            p.grad.data.zero_()

    M = 1000  # number of interpolation steps along the path
    criterion = torch.nn.L1Loss(reduction='sum')

    word_vecs = model.embed(batch.text).data
    baseline_text = copy.deepcopy(batch.text)
    baseline_text.data[:, :] = inputs.vocab.stoi['.']
    baseline = model.embed(baseline_text).data

    # (seq_len, emb) views of the single sequence being explained
    word_vecs_0 = word_vecs[:, 0].to(device)
    baseline_0 = baseline[:, 0].to(device)
    delta = word_vecs_0 - baseline_0

    # all M interpolated inputs at once: (seq_len, M, emb)
    mult_grid = torch.arange(M, dtype=baseline_0.dtype, device=device) / (M - 1)
    input_vecs = baseline_0.unsqueeze(1) + mult_grid.view(1, M, 1) * delta.unsqueeze(1)
    # gradients w.r.t. the inputs are read below, so they must be tracked
    input_vecs = input_vecs.detach().requires_grad_(True)

    hidden = (torch.zeros(1, M, model.hidden_dim).to(device),
              torch.zeros(1, M, model.hidden_dim).to(device))
    lstm_out, hidden = model.lstm(input_vecs, hidden)
    logits = F.softmax(model.hidden_to_label(lstm_out[-1]), dim=1)[:, 0]
    loss = criterion(logits, torch.zeros(M).to(device))
    loss.backward()

    # average gradient along the path times (input - baseline)
    imps = input_vecs.grad.mean(1).data * delta
    scores = imps.sum(1)
    return scores.cpu().numpy()
cd, occlusion) 125 | 126 | Returns 127 | ------- 128 | scores: np.ndarray 129 | Higher scores are more important 130 | ''' 131 | # calculate scores 132 | if method == 'cd': 133 | if only_one: 134 | num_words = batch.text.data.cpu().numpy().shape[0] 135 | scores = np.expand_dims(cd_text(batch, model, start=0, stop=num_words), axis=0) 136 | else: 137 | starts, stops = tiles_to_cd(batch) 138 | batch.text.data = torch.LongTensor(text_orig).to(device) 139 | scores = np.array([cd_text(batch, model, start=starts[i], stop=stops[i]) 140 | for i in range(len(starts))]) 141 | else: 142 | scores = model(batch).data.cpu().numpy() 143 | if method == 'occlusion' and not only_one: 144 | scores = score_orig - scores 145 | 146 | # get score for other class 147 | if subtract: 148 | return scores[:, label] - scores[:, int(1 - label)] 149 | else: 150 | return scores[:, label] 151 | 152 | 153 | def get_scores_2d(model, method, ims, im_torch=None, pred_ims=None, model_type=None, device='cuda'): 154 | '''Return attribution scores for 2D input 155 | Params 156 | ------ 157 | method: str 158 | What type of method to use for attribution (e.g. 
def tiles_to_cd(batch):
    '''Converts build up tiles into indices for cd

    Build up tiles are of the form [0, 0, 12, 35, 0, 0], one tile per column
    of batch.text. For every tile, find the first and the last non-zero
    position; both indices are inclusive.

    Params
    ------
    batch
        object whose .text.data is a 2D tensor with one tile per column

    Returns
    -------
    starts, stops: list of int
        index of the first / last non-zero entry of each column

    Raises
    ------
    IndexError
        if a column contains no non-zero entry
    '''
    tiles = batch.text.data.cpu().numpy()
    starts, stops = [], []
    for c in range(tiles.shape[1]):
        nonzero = np.flatnonzero(tiles[:, c])
        if nonzero.size == 0:
            # same exception type the original scan ran into, with a usable message
            raise IndexError('tile %d has no non-zero entries' % c)
        starts.append(int(nonzero[0]))
        stops.append(int(nonzero[-1]))
    return starts, stops
def calc_size(h, kh, pad, sh):
    """Calculate output image size on one dimension.

    Args:
      h: input image size.
      kh: kernel size.
      pad: padding strategy, "VALID", "SAME", or a manually specified
        padding width.
      sh: stride.

    Returns:
      s: output size, always a Python int.
    """
    if pad == 'VALID':
        span = h - kh + 1
    elif pad == 'SAME':
        span = h
    else:
        span = h - kh + pad + 1
    # int() in every branch so the result is usable as an array dimension
    return int(np.ceil(span / sh))
55 | 56 | Args: 57 | x: [N, H, W, C] 58 | k: [KH, KW] 59 | pad: [PH, PW] 60 | stride: [SH, SW] 61 | 62 | Returns: 63 | y: [N, H', W', KH, KW, C] 64 | """ 65 | n = x.shape[0] 66 | h = x.shape[1] 67 | w = x.shape[2] 68 | c = x.shape[3] 69 | kh = ksize[0] 70 | kw = ksize[1] 71 | sh = stride[0] 72 | sw = stride[1] 73 | 74 | h2 = orig_size[0] 75 | w2 = orig_size[1] 76 | ph = int(calc_pad(pad, h, h2, 1, ((kh - 1) * sh + 1))) 77 | pw = int(calc_pad(pad, w, w2, 1, ((kw - 1) * sw + 1))) 78 | 79 | ph2 = int(np.ceil(ph / 2)) 80 | ph3 = int(np.floor(ph / 2)) 81 | pw2 = int(np.ceil(pw / 2)) 82 | pw3 = int(np.floor(pw / 2)) 83 | if floor_first: 84 | pph = (ph3, ph2) 85 | ppw = (pw3, pw2) 86 | else: 87 | pph = (ph2, ph3) 88 | ppw = (pw2, pw3) 89 | x = np.pad( 90 | x, ((0, 0), (ph3, ph2), (pw3, pw2), (0, 0)), 91 | mode='constant', 92 | constant_values=(0.0,)) 93 | p2h = (-x.shape[1]) % sh 94 | p2w = (-x.shape[2]) % sw 95 | if p2h > 0 or p2w > 0: 96 | x = np.pad( 97 | x, ((0, 0), (0, p2h), (0, p2w), (0, 0)), 98 | mode='constant', 99 | constant_values=(0.0,)) 100 | x = x.reshape([n, int(x.shape[1] / sh), sh, int(x.shape[2] / sw), sw, c]) 101 | 102 | y = np.zeros([n, h2, w2, kh, kw, c]) 103 | for ii in range(h2): 104 | for jj in range(w2): 105 | h0 = int(np.floor(ii / sh)) 106 | w0 = int(np.floor(jj / sw)) 107 | y[:, ii, jj, :, :, :] = x[:, h0:h0 + kh, ii % sh, w0:w0 + kw, jj % 108 | sw, :] 109 | return y 110 | 111 | 112 | def extract_sliding_windows_gradx(x, 113 | ksize, 114 | pad, 115 | stride, 116 | orig_size, 117 | floor_first=False): 118 | """Extracts windows on a dilated image. 
119 | 120 | Args: 121 | x: [N, H', W', C] (usually dy) 122 | k: [KH, KW] 123 | pad: [PH, PW] 124 | stride: [SH, SW] 125 | orig_size: [H, W] 126 | 127 | Returns: 128 | y: [N, H, W, KH, KW, C] 129 | """ 130 | n = x.shape[0] 131 | h = x.shape[1] 132 | w = x.shape[2] 133 | c = x.shape[3] 134 | kh = ksize[0] 135 | kw = ksize[1] 136 | ph = pad[0] 137 | pw = pad[1] 138 | sh = stride[0] 139 | sw = stride[1] 140 | h2 = orig_size[0] 141 | w2 = orig_size[1] 142 | xs = np.zeros([n, x.shape[1], sh, x.shape[2], sw, c]) 143 | xs[:, :, 0, :, 0, :] = x 144 | xss = xs.shape 145 | x = xs.reshape([xss[0], xss[1] * xss[2], xss[3] * xss[4], xss[5]]) 146 | x = x[:, :h2, :w2, :] 147 | 148 | ph2 = int(np.ceil(ph / 2)) 149 | ph3 = int(np.floor(ph / 2)) 150 | pw2 = int(np.ceil(pw / 2)) 151 | pw3 = int(np.floor(pw / 2)) 152 | if floor_first: 153 | pph = (ph3, ph2) 154 | ppw = (pw3, pw2) 155 | else: 156 | pph = (ph2, ph3) 157 | ppw = (pw2, pw3) 158 | x = np.pad( 159 | x, ((0, 0), pph, ppw, (0, 0)), 160 | mode='constant', 161 | constant_values=(0.0,)) 162 | y = np.zeros([n, h2, w2, kh, kw, c]) 163 | 164 | for ii in range(h2): 165 | for jj in range(w2): 166 | y[:, ii, jj, :, :, :] = x[:, ii:ii + kh, jj:jj + kw, :] 167 | return y 168 | 169 | 170 | def extract_sliding_windows(x, ksize, pad, stride, floor_first=True): 171 | """Converts a tensor to sliding windows. 
172 | 173 | Args: 174 | x: [N, H, W, C] 175 | k: [KH, KW] 176 | pad: [PH, PW] 177 | stride: [SH, SW] 178 | 179 | Returns: 180 | y: [N, (H-KH+PH+1)/SH, (W-KW+PW+1)/SW, KH * KW, C] 181 | """ 182 | n = x.shape[0] 183 | h = x.shape[1] 184 | w = x.shape[2] 185 | c = x.shape[3] 186 | kh = ksize[0] 187 | kw = ksize[1] 188 | sh = stride[0] 189 | sw = stride[1] 190 | 191 | h2 = int(calc_size(h, kh, pad, sh)) 192 | w2 = int(calc_size(w, kw, pad, sw)) 193 | ph = int(calc_pad(pad, h, h2, sh, kh)) 194 | pw = int(calc_pad(pad, w, w2, sw, kw)) 195 | 196 | ph0 = int(np.floor(ph / 2)) 197 | ph1 = int(np.ceil(ph / 2)) 198 | pw0 = int(np.floor(pw / 2)) 199 | pw1 = int(np.ceil(pw / 2)) 200 | 201 | if floor_first: 202 | pph = (ph0, ph1) 203 | ppw = (pw0, pw1) 204 | else: 205 | pph = (ph1, ph0) 206 | ppw = (pw1, pw0) 207 | x = np.pad( 208 | x, ((0, 0), pph, ppw, (0, 0)), 209 | mode='constant', 210 | constant_values=(0.0,)) 211 | 212 | y = np.zeros([n, h2, w2, kh, kw, c]) 213 | for ii in range(h2): 214 | for jj in range(w2): 215 | xx = ii * sh 216 | yy = jj * sw 217 | y[:, ii, jj, :, :, :] = x[:, xx:xx + kh, yy:yy + kw, :] 218 | return y 219 | 220 | 221 | def conv2dnp(x, w, pad='SAME', stride=(1, 1)): 222 | """2D convolution (technically speaking, correlation). 223 | 224 | Args: 225 | x: [N, H, W, C] 226 | w: [I, J, C, K] 227 | pad: [PH, PW] 228 | stride: [SH, SW] 229 | 230 | Returns: 231 | y: [N, H', W', K] 232 | """ 233 | ksize = w.shape[:2] 234 | x = extract_sliding_windows(x, ksize, pad, stride) 235 | ws = w.shape 236 | w = w.reshape([ws[0] * ws[1] * ws[2], ws[3]]) 237 | xs = x.shape 238 | x = x.reshape([xs[0] * xs[1] * xs[2], -1]) 239 | y = x.dot(w) 240 | y = y.reshape([xs[0], xs[1], xs[2], -1]) 241 | return y 242 | 243 | 244 | def conv2d_gradw(x, dy, ksize, pad='SAME', stride=(1, 1)): 245 | """2D convolution gradient wrt. filters. 246 | 247 | Args: 248 | dy: [N, H', W', K] 249 | x: [N, H, W, C] 250 | ksize: original w ksize [I, J]. 
251 | 252 | Returns: 253 | dw: [I, J, C, K] 254 | """ 255 | dy = np.transpose(dy, [1, 2, 0, 3]) 256 | x = np.transpose(x, [3, 1, 2, 0]) 257 | ksize2 = dy.shape[:2] 258 | x = extract_sliding_windows_gradw(x, ksize2, pad, stride, ksize) 259 | dys = dy.shape 260 | dy = dy.reshape([dys[0] * dys[1] * dys[2], dys[3]]) 261 | xs = x.shape 262 | x = x.reshape([xs[0] * xs[1] * xs[2], -1]) 263 | dw = x.dot(dy) 264 | dw = dw.reshape([xs[0], xs[1], xs[2], -1]) 265 | dw = np.transpose(dw, [1, 2, 0, 3]) 266 | dw = dw[:ksize[0], :ksize[1], :, :] 267 | return dw 268 | 269 | 270 | def conv2d_gradx(w, dy, xsize, pad='SAME', stride=(1, 1)): 271 | """2D convolution gradient wrt. input. 272 | 273 | Args: 274 | dy: [N, H', W', K] 275 | w: [I, J, C, K] 276 | xsize: Original image size, [H, W] 277 | 278 | Returns: 279 | dx: [N, H, W, C] 280 | """ 281 | ksize = w.shape[:2] 282 | 283 | if pad == 'SAME': 284 | dys = dy.shape[1:3] 285 | pad2h = int( 286 | calc_pad('SAME', 287 | max(dys[0], dys[0] * stride[0] - 1), xsize[0], 1, ksize[ 288 | 0])) 289 | pad2w = int( 290 | calc_pad('SAME', 291 | max(dys[0], dys[0] * stride[1] - 1), xsize[1], 1, ksize[ 292 | 1])) 293 | pad2 = (pad2h, pad2w) 294 | elif pad == 'VALID': 295 | pad2 = (int(calc_pad('SAME', 0, 0, 1, ksize[0])), 296 | int(calc_pad('SAME', 0, 0, 1, ksize[1]))) 297 | pad2 = (pad2[0] * 2, pad2[1] * 2) 298 | else: 299 | pad2 = pad 300 | w = np.transpose(w, [0, 1, 3, 2]) 301 | ksize = w.shape[:2] 302 | dx = extract_sliding_windows_gradx(dy, ksize, pad2, stride, xsize) 303 | dxs = dx.shape 304 | dx = dx.reshape([dxs[0] * dxs[1] * dxs[2], -1]) 305 | w = w[::-1, ::-1, :, :] 306 | ws = w.shape 307 | w = w.reshape([ws[0] * ws[1] * ws[2], ws[3]]) 308 | dx = dx.dot(w) 309 | return dx.reshape([dxs[0], dxs[1], dxs[2], -1]) 310 | -------------------------------------------------------------------------------- /acd/util/tiling_1d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # 
# pytorch needs to return each input as a column
def gen_tiles(text, fill=0,
              method='occlusion', prev_text=None, sweep_dim=1):
    '''Generate one tile for every window of width sweep_dim slid over text

    Params
    ------
    text: np.ndarray
        sequence of length L (read flattened)
    fill: int
        value written over the window for 'occlusion'
    method: str
        'occlusion' keeps the text and blanks the window;
        'build_up' / 'cd' keep only the window and zero everything else
    prev_text:
        unused, kept for interface compatibility
    sweep_dim: int
        width of the window

    Returns
    -------
    texts: np.ndarray
        batch_size x L
    '''
    if method not in ('occlusion', 'build_up', 'cd'):
        raise ValueError('unknown method: %r' % (method,))
    L = text.shape[0]
    flat = np.asarray(text).flatten()  # column-shaped input is read the same way for every method
    # builtin int: the np.int alias was removed from numpy
    texts = np.zeros((L - sweep_dim + 1, L), dtype=int)
    for start in range(L - sweep_dim + 1):
        end = start + sweep_dim
        if method == 'occlusion':
            text_new = np.copy(flat)
            text_new[start:end] = fill
        else:  # 'build_up' or 'cd'
            text_new = np.zeros(L)
            text_new[start:end] = flat[start:end]
        texts[start] = text_new
    return texts


def gen_tile_from_comp(text_orig, comp_tile, method, fill=0):
    '''return tile representing component

    Params
    ------
    text_orig: np.ndarray
        original sequence
    comp_tile: np.ndarray
        index / boolean mask selecting the component inside text_orig
    method: str
        'occlusion' blanks the component with fill;
        'build_up' / 'cd' keep only the component
    fill: int
        value used by 'occlusion'
    '''
    if method == 'occlusion':
        tile_new = np.copy(text_orig).flatten()
        tile_new[comp_tile] = fill
    elif method == 'build_up' or method == 'cd':
        tile_new = np.zeros(text_orig.shape)
        tile_new[comp_tile] = text_orig[comp_tile]
    else:
        raise ValueError('unknown method: %r' % (method,))
    return tile_new


def gen_tiles_around_baseline(text_orig, comp_tile, method='build_up', sweep_dim=1, fill=0):
    '''generate tiles around a component (varies based on method)

    Two tiles are produced: the component extended by the position sweep_dim
    to the left of its first element, and by the position sweep_dim to the
    right of its last element (both clipped to the sequence).

    Returns
    -------
    tiles: np.ndarray
        2 x L
    idxs: list
        [left, right] positions that were added to the component
    '''
    if method not in ('occlusion', 'build_up', 'cd'):
        raise ValueError('unknown method: %r' % (method,))
    if not np.any(comp_tile):
        # the scans below would otherwise run off the end of comp_tile
        raise IndexError('comp_tile does not select any position')
    L = text_orig.shape[0]
    left = 0
    right = L - 1
    while not comp_tile[left]:
        left += 1
    while not comp_tile[right]:
        right -= 1
    left = max(0, left - sweep_dim)
    right = min(L - 1, right + sweep_dim)
    tiles = []
    for x in [left, right]:
        if method == 'occlusion':
            tile_new = np.copy(text_orig).flatten()
            tile_new[comp_tile] = fill
            tile_new[x] = fill
        else:  # 'build_up' or 'cd'
            tile_new = np.zeros(text_orig.shape)
            tile_new[comp_tile] = text_orig[comp_tile]
            tile_new[x] = text_orig[x]
        tiles.append(tile_new)
    return np.array(tiles), [left, right]
# --------------------------------------------------------------------------------
# /acd/util/tiling_2d.py:
# --------------------------------------------------------------------------------
from math import ceil

import numpy as np

_METHODS = ('occlusion', 'build_up', 'cd')


def _check_method(method):
    '''Raise ValueError unless method is one of the supported tiling methods
    '''
    if method not in _METHODS:
        raise ValueError('unknown method: %r' % (method,))


def _upsample_mask(comp_tile_downsampled, sweep_dim, R, C):
    '''Expand a downsampled component grid into a full-size (R, C) boolean mask
    '''
    mask = np.zeros((R, C), dtype=np.bool_)
    for r in range(comp_tile_downsampled.shape[0]):
        for c in range(comp_tile_downsampled.shape[1]):
            if comp_tile_downsampled[r, c]:
                # slices clip at the image border for partial edge tiles
                mask[r * sweep_dim: (r + 1) * sweep_dim, c * sweep_dim: (c + 1) * sweep_dim] = True
    return mask


def gen_tiles(image, fill=0, method='occlusion', prev_im=None,
              sweep_dim=1, num_ims=None, im_num_start=0):
    '''Generate all possible tilings given a granularity of sweep_dim and for a particular method

    Params
    ------
    image: np.ndarray
        R x C (mnist case) or R x C x 3 (imagenet case)
    fill:
        value written over the tile for 'occlusion'
    method: str
        'occlusion' blanks the tile, 'build_up' keeps only the tile,
        'cd' returns a 0/1 mask of the tile
    prev_im: np.ndarray, optional
        mask of previously selected pixels that is merged into every tile
    sweep_dim: int
        side length of a tile
    num_ims: int, optional
        maximum number of tiles to return (defaults to all of them)
    im_num_start: int
        index of the first tile to return, used to generate tiles in batches

    Returns
    -------
    ims: np.ndarray
        num_ims x R x C (x 3 for non-cd tiles of a 3d image); slots beyond
        the last generated tile are zero
    '''
    _check_method(method)
    R = image.shape[0]
    C = image.shape[1]

    if num_ims is None:  # check if theres a limit on how many ims to have
        num_ims = ceil(R / sweep_dim) * ceil(C / sweep_dim)
    if image.ndim == 2 or method == 'cd':  # mnist case / masks
        tile_shape = (R, C)
    else:  # imagenet case
        tile_shape = (R, C, 3)
    # zeros rather than empty: a trailing batch may fill fewer than num_ims slots
    ims = np.zeros((num_ims,) + tile_shape)

    i = 0
    # iterate over top, left indexes
    for rmin in range(0, R, sweep_dim):
        for cmin in range(0, C, sweep_dim):
            if im_num_start <= i < im_num_start + num_ims:

                # calculate bounds of box
                rmax = min(rmin + sweep_dim, R)
                cmax = min(cmin + sweep_dim, C)

                # create appropriate images
                if method == 'occlusion':
                    im = np.copy(image)
                    im[rmin:rmax, cmin:cmax] = fill
                    if prev_im is not None:
                        im[prev_im] = fill
                elif method == 'build_up':
                    im = np.zeros(image.shape)
                    im[rmin:rmax, cmin:cmax] = image[rmin:rmax, cmin:cmax]
                    if prev_im is not None:
                        im[prev_im] = image[prev_im]
                else:  # 'cd'
                    im = np.zeros((R, C))
                    im[rmin:rmax, cmin:cmax] = 1
                    if prev_im is not None:
                        im[prev_im] = 1
                ims[i - im_num_start] = im
            i += 1
    return ims


def gen_tiles_around_baseline(im_orig, comp_tile, fill=0,
                              method='occlusion', sweep_dim=3):
    '''Generate tiles that extend comp_tile by one bordering block each

    A block is used when its top-left pixel is outside comp_tile and the pixel
    sweep_dim above, below, left or right of it (clipped to the image) is
    inside comp_tile.

    Returns
    -------
    ims: np.ndarray
        one tile per bordering block, each containing comp_tile plus that block
    idxs: list
        (row, col) of each block in the downsampled grid
    '''
    _check_method(method)
    R = im_orig.shape[0]
    C = im_orig.shape[1]
    ims, idxs = [], []
    # iterate over top, left indexes
    for r_downsampled, rmin in enumerate(range(0, R, sweep_dim)):
        for c_downsampled, cmin in enumerate(range(0, C, sweep_dim)):

            rmax = min(rmin + sweep_dim, R)
            cmax = min(cmin + sweep_dim, C)

            # calculate bounds of new block + boundaries
            rminus = max(rmin - sweep_dim, 0)
            cminus = max(cmin - sweep_dim, 0)
            rplus = min(rmin + sweep_dim, R - 1)
            cplus = min(cmin + sweep_dim, C - 1)

            # new block is already in old block
            if comp_tile[rmin, cmin]:
                continue
            # new block doesn't border old block
            if not (comp_tile[rminus, cmin] or comp_tile[rmin, cminus]
                    or comp_tile[rplus, cmin] or comp_tile[rmin, cplus]):
                continue

            if method == 'occlusion':
                im = np.copy(im_orig)  # im_orig background
                im[rmin:rmax, cmin:cmax] = fill  # black out new block
                im[comp_tile] = fill  # black out comp_tile
            elif method == 'build_up':
                im = np.zeros(im_orig.shape)  # zero background
                im[rmin:rmax, cmin:cmax] = im_orig[rmin:rmax, cmin:cmax]  # im_orig at new block
                im[comp_tile] = im_orig[comp_tile]  # im_orig at comp_tile
            else:  # 'cd'
                im = np.zeros((R, C))  # zero background
                im[rmin:rmax, cmin:cmax] = 1  # 1 at new block
                im[comp_tile] = 1  # 1 at comp_tile
            ims.append(im)
            idxs.append((r_downsampled, c_downsampled))
    return np.array(ims), idxs


def gen_tile_from_comp(im_orig, comp_tile_downsampled, sweep_dim, method, fill=0):
    '''generates full-size tile from comp which could be downsampled

    Returns
    -------
    im: np.ndarray
        'occlusion': copy of im_orig with the component set to fill;
        'build_up': zeros with im_orig copied in at the component;
        'cd': R x C boolean mask of the component
    '''
    _check_method(method)
    R = im_orig.shape[0]
    C = im_orig.shape[1]
    mask = _upsample_mask(comp_tile_downsampled, sweep_dim, R, C)
    if method == 'occlusion':
        im = np.copy(im_orig)
        im[mask] = fill
    elif method == 'build_up':
        im = np.zeros(im_orig.shape)
        im[mask] = im_orig[mask]
    else:  # 'cd'
        im = mask
    return im


def combine_tiles(tile1, tile2, method='cd'):
    '''Sum two tiles; tiles are not combined for 'occlusion' (returns None)
    '''
    if not method == 'occlusion':
        return tile1 + tile2
    return None


# --------------------------------------------------------------------------------
# /citation.bib:
# --------------------------------------------------------------------------------
# @inproceedings{
# singh2019hierarchical,
# title={Hierarchical interpretations for neural network predictions},
# author={Chandan Singh and W. James Murdoch and Bin Yu},
# booktitle={International Conference on Learning Representations},
# year={2019},
# url={https://openreview.net/forum?id=SkEqro0ctQ},
# }
# --------------------------------------------------------------------------------
# /docs/agglomeration/index.html:
# --------------------------------------------------------------------------------
# acd.agglomeration API documentation
18 |
19 |
20 |

Module acd.agglomeration

21 |
22 |
23 |
24 |
25 |

Sub-modules

26 |
27 |
acd.agglomeration.agg_1d
28 |
29 |
30 |
31 |
acd.agglomeration.agg_2d
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 | 63 |
64 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /docs/build_docs.sh: -------------------------------------------------------------------------------- 1 | cd ../acd 2 | pdoc --html . --output-dir ../docs --template-dir . 3 | cp -rf ../docs/acd/* ../docs/ 4 | rm -rf ../docs/acd 5 | cd ../docs 6 | python3 style_docs.py -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | acd API documentation 8 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 |
20 |
21 | 22 |
23 |
24 |

25 |

Hierarchical neural-net interpretations (ACD) 🧠

26 |

Produces hierarchical interpretations for a single prediction made by a pytorch neural network. Official code for Hierarchical interpretations for neural network predictions (ICLR 2019 pdf).

27 |

28 | 29 | 30 | 31 | 32 | 33 | 34 |

35 |

36 | Documentation • 37 | Demo notebooks 38 |


39 |

40 | Note: this repo is actively maintained. For any questions please file an issue. 41 |

42 |

43 |

examples/documentation

44 |
    45 |
  • installation: pip install acd (or clone and run python setup.py install)
  • 46 |
  • examples: the reproduce_figs folder has notebooks with many demos
  • 47 |
  • src: the acd folder contains the source for the method implementation
  • 48 |
  • allows for different types of interpretations by changing hyperparameters (explained in examples)
  • 49 |
  • all required data/models/code for reproducing are included in the dsets folder
  • 50 |
51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 |
Inspecting NLP sentiment modelsDetecting adversarial examplesAnalyzing imagenet models
67 |

notes on using ACD on your own data

68 |
    69 |
  • the current CD implementation often works out-of-the-box, especially for networks built on common layers, such as alexnet/vgg/resnet. However, if you have custom layers or layers not accessible in net.modules(), you may need to write a custom function to iterate through some layers of your network (for examples see cd.py).
  • 70 |
  • to use baselines such as build-up and occlusion, replace the pred_ims function with a function that gets predictions from your model given a batch of examples.
  • 71 |
72 |

related work

73 |
    74 |
  • CDEP (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • 75 |
  • TRIM (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies)
  • 76 |
  • PDR framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning
  • 77 |
  • DAC (arXiv 2019 pdf, github) - finds disentangled interpretations for random forests
  • 78 |
  • Baseline interpretability methods - the file scores/score_funcs.py also contains simple pytorch implementations of integrated gradients and the simple interpretation technique gradient * input
  • 79 |
80 |

reference

81 |
    82 |
  • feel free to use/share this code openly
  • 83 |
  • if you find this code useful for your research, please cite the following:
  • 84 |
85 |

86 | @inproceedings{ 87 | singh2019hierarchical, 88 | title={Hierarchical interpretations for neural network predictions}, 89 | author={Chandan Singh and W. James Murdoch and Bin Yu}, 90 | booktitle={International Conference on Learning Representations}, 91 | year={2019}, 92 | url={https://openreview.net/forum?id=SkEqro0ctQ}, 93 | }

94 |
95 | 96 | Expand source code 97 | 98 |
'''
 99 | .. include:: ../readme.md
100 | '''
101 | 
102 | from .scores.cd import *
103 | from .scores.cd_propagate import *
104 | from .scores.score_funcs import *
105 | from .agglomeration import agg_1d, agg_2d
106 | from .util import *
107 |
108 |
109 |
110 |

Sub-modules

111 |
112 |
acd.agglomeration
113 |
114 |
115 |
116 |
acd.scores
117 |
118 |
119 |
120 |
acd.util
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |

Functions

130 |
131 |
132 | def tanh(...) 133 |
134 |
135 |

tanh(input, out=None) -> Tensor

136 |

Returns a new tensor with the hyperbolic tangent of the elements 137 | of :attr:input.

138 |

\[ \text{out}_{i} = \tanh(\text{input}_{i}) \]

139 |

Args

140 |
141 |
input : Tensor
142 |
the input tensor.
143 |
out : Tensor, optional
144 |
the output tensor.
145 |
146 |

Example::

147 |
>>> a = torch.randn(4)
148 | >>> a
149 | tensor([ 0.8986, -0.7279,  1.1745,  0.2611])
150 | >>> torch.tanh(a)
151 | tensor([ 0.7156, -0.6218,  0.8257,  0.2553])
152 | 
153 |
154 |
155 |
156 |
157 |
158 |
159 | 184 |
185 | 188 | 189 | 190 | 191 | -------------------------------------------------------------------------------- /docs/scores/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | acd.scores API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 |
18 |
19 |
20 |

Module acd.scores

21 |
22 |
23 |
24 |
25 |

Sub-modules

26 |
27 |
acd.scores.cd
28 |
29 |
30 |
31 |
acd.scores.cd_architecture_specific
32 |
33 |
34 |
35 |
acd.scores.cd_propagate
36 |
37 |
38 |
39 |
acd.scores.score_funcs
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 | 73 |
74 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /docs/style_docs.py: -------------------------------------------------------------------------------- 1 | # Read in the file 2 | with open('index.html', 'r') as f: 3 | data = f.read() 4 | 5 | 6 | # data = data.replace('.html">imodels.', '.html">') 7 | data = data.replace('

Module acd

', '') # remove header 8 | data = data.replace('Module for computing hierarchical interpretations of neural network predictions', '') 9 | # data = data.replace('Reference', 'Reference') 10 | 11 | # add github corner 12 | # data = data.replace('', "\n\n") 13 | data += '' 14 | 15 | 16 | data += '' 17 | 18 | # Write the file out again 19 | with open('index.html', 'w') as f: 20 | f.write(data) -------------------------------------------------------------------------------- /docs/util/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | acd.util API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 |
18 |
19 |
20 |

Module acd.util

21 |
22 |
23 |
24 |
25 |

Sub-modules

26 |
27 |
acd.util.conv2dnp
28 |
29 |

code from https://github.com/renmengye/np-conv2d

30 |
31 |
acd.util.tiling_1d
32 |
33 |
34 |
35 |
acd.util.tiling_2d
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 | 68 |
69 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /docs/util/tiling_1d.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | acd.util.tiling_1d API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 |
18 |
19 |
20 |

Module acd.util.tiling_1d

21 |
22 |
23 |
24 | 25 | Expand source code 26 | 27 |
import numpy as np
 28 | 
 29 | 
 30 | # pytorch needs to return each input as a column
 31 | def gen_tiles(text, fill=0,
 32 |               method='occlusion', prev_text=None, sweep_dim=1):
 33 |     '''
 34 |     Returns
 35 |     -------
 36 |     texts: np.ndarray
 37 |         batch_size x L
 38 |     '''
 39 |     L = text.shape[0]
 40 |     texts = np.zeros((L - sweep_dim + 1, L), dtype=np.int)
 41 |     for start in range(L - sweep_dim + 1):
 42 |         end = start + sweep_dim
 43 |         if method == 'occlusion':
 44 |             text_new = np.copy(text).flatten()
 45 |             text_new[start:end] = fill
 46 |         elif method == 'build_up' or method == 'cd':
 47 |             text_new = np.zeros(L)
 48 |             text_new[start:end] = text[start:end]
 49 |         texts[start] = np.copy(text_new)
 50 |     return texts
 51 | 
 52 | def gen_tile_from_comp(text_orig, comp_tile, method, fill=0):
 53 |     '''return tile representing component
 54 |     '''
 55 |     if method == 'occlusion':
 56 |         tile_new = np.copy(text_orig).flatten()
 57 |         tile_new[comp_tile] = fill
 58 |     elif method == 'build_up' or method == 'cd':
 59 |         tile_new = np.zeros(text_orig.shape)
 60 |         tile_new[comp_tile] = text_orig[comp_tile]
 61 |     return tile_new
 62 | 
 63 | 
 64 | 
 65 | def gen_tiles_around_baseline(text_orig, comp_tile, method='build_up', sweep_dim=1, fill=0):
 66 |     '''generate tiles around a component (varies based on method)
 67 |     '''
 68 |     L = text_orig.shape[0]
 69 |     left = 0
 70 |     right = L - 1
 71 |     while not comp_tile[left]:
 72 |         left += 1
 73 |     while not comp_tile[right]:
 74 |         right -= 1
 75 |     left = max(0, left - sweep_dim)
 76 |     right = min(L - 1, right + sweep_dim)
 77 |     tiles = []
 78 |     for x in [left, right]:
 79 |         if method == 'occlusion':
 80 |             tile_new = np.copy(text_orig).flatten()
 81 |             tile_new[comp_tile] = fill
 82 |             tile_new[x] = fill
 83 |         elif method == 'build_up' or method == 'cd':
 84 |             tile_new = np.zeros(text_orig.shape)
 85 |             tile_new[comp_tile] = text_orig[comp_tile]
 86 |             tile_new[x] = text_orig[x]
 87 |         tiles.append(tile_new)
 88 |     return np.array(tiles), [left, right]
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |

Functions

97 |
98 |
99 | def gen_tile_from_comp(text_orig, comp_tile, method, fill=0) 100 |
101 |
102 |

return tile representing component

103 |
104 | 105 | Expand source code 106 | 107 |
def gen_tile_from_comp(text_orig, comp_tile, method, fill=0):
108 |     '''return tile representing component
109 |     '''
110 |     if method == 'occlusion':
111 |         tile_new = np.copy(text_orig).flatten()
112 |         tile_new[comp_tile] = fill
113 |     elif method == 'build_up' or method == 'cd':
114 |         tile_new = np.zeros(text_orig.shape)
115 |         tile_new[comp_tile] = text_orig[comp_tile]
116 |     return tile_new
117 |
118 |
119 |
120 | def gen_tiles(text, fill=0, method='occlusion', prev_text=None, sweep_dim=1) 121 |
122 |
123 |

Returns

124 |
125 |
texts : np.ndarray
126 |
batch_size x L
127 |
128 |
129 | 130 | Expand source code 131 | 132 |
def gen_tiles(text, fill=0,
133 |               method='occlusion', prev_text=None, sweep_dim=1):
134 |     '''
135 |     Returns
136 |     -------
137 |     texts: np.ndarray
138 |         batch_size x L
139 |     '''
140 |     L = text.shape[0]
141 |     texts = np.zeros((L - sweep_dim + 1, L), dtype=np.int)
142 |     for start in range(L - sweep_dim + 1):
143 |         end = start + sweep_dim
144 |         if method == 'occlusion':
145 |             text_new = np.copy(text).flatten()
146 |             text_new[start:end] = fill
147 |         elif method == 'build_up' or method == 'cd':
148 |             text_new = np.zeros(L)
149 |             text_new[start:end] = text[start:end]
150 |         texts[start] = np.copy(text_new)
151 |     return texts
152 |
153 |
154 |
155 | def gen_tiles_around_baseline(text_orig, comp_tile, method='build_up', sweep_dim=1, fill=0) 156 |
157 |
158 |

generate tiles around a component (varies based on method)

159 |
160 | 161 | Expand source code 162 | 163 |
def gen_tiles_around_baseline(text_orig, comp_tile, method='build_up', sweep_dim=1, fill=0):
164 |     '''generate tiles around a component (varies based on method)
165 |     '''
166 |     L = text_orig.shape[0]
167 |     left = 0
168 |     right = L - 1
169 |     while not comp_tile[left]:
170 |         left += 1
171 |     while not comp_tile[right]:
172 |         right -= 1
173 |     left = max(0, left - sweep_dim)
174 |     right = min(L - 1, right + sweep_dim)
175 |     tiles = []
176 |     for x in [left, right]:
177 |         if method == 'occlusion':
178 |             tile_new = np.copy(text_orig).flatten()
179 |             tile_new[comp_tile] = fill
180 |             tile_new[x] = fill
181 |         elif method == 'build_up' or method == 'cd':
182 |             tile_new = np.zeros(text_orig.shape)
183 |             tile_new[comp_tile] = text_orig[comp_tile]
184 |             tile_new[x] = text_orig[x]
185 |         tiles.append(tile_new)
186 |     return np.array(tiles), [left, right]
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 | 214 |
215 | 218 | 219 | 220 | 221 | -------------------------------------------------------------------------------- /docs/util/tiling_2d.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | acd.util.tiling_2d API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 |
18 |
19 |
20 |

Module acd.util.tiling_2d

21 |
22 |
23 |
24 | 25 | Expand source code 26 | 27 |
from math import ceil
 28 | 
 29 | import numpy as np
 30 | 
 31 | 
 32 | def gen_tiles(image, fill=0, method='occlusion', prev_im=None,
 33 |               sweep_dim=1, num_ims=None, im_num_start=0):
 34 |     '''Generate all possible tilings given a granularity of sweep_dim and for a particular method
 35 |     '''
 36 |     R = image.shape[0]
 37 |     C = image.shape[1]
 38 | 
 39 |     if image.ndim == 2:  # mnist case
 40 |         if num_ims is None:  # check if theres a limit on how many ims to have
 41 |             num_ims = ceil(R / sweep_dim) * ceil(C / sweep_dim)
 42 |         # print('sizes', R, C, num_ims)
 43 |         ims = np.empty((num_ims, R, C))
 44 |     else:  # imagenet case
 45 |         if num_ims is None:  # check if theres a limit on how many ims to have
 46 |             num_ims = ceil(R / sweep_dim) * ceil(C / sweep_dim)
 47 |         if method == 'cd':
 48 |             ims = np.empty((num_ims, R, C))
 49 |         else:
 50 |             ims = np.empty((num_ims, R, C, 3))
 51 | 
 52 |     i = 0
 53 |     # iterate over top, left indexes
 54 |     for rmin in range(0, R, sweep_dim):
 55 |         for cmin in range(0, C, sweep_dim):
 56 |             if im_num_start <= i < im_num_start + num_ims:
 57 | 
 58 |                 # calculate bounds of box
 59 |                 rmax = min(rmin + sweep_dim, R)
 60 |                 cmax = min(cmin + sweep_dim, C)
 61 | 
 62 |                 # create appropriate images
 63 |                 if method == 'occlusion':
 64 |                     im = np.copy(image)
 65 |                     im[rmin:rmax, cmin:cmax] = fill  # image[r-1:r+1, c-1:c+1]
 66 |                     if not prev_im is None:
 67 |                         im[prev_im] = fill
 68 |                 elif method == 'build_up':
 69 |                     im = np.zeros(image.shape)
 70 |                     im[rmin:rmax, cmin:cmax] = image[rmin:rmax, cmin:cmax]
 71 |                     if not prev_im is None:
 72 |                         im[prev_im] = image[prev_im]
 73 |                 elif method == 'cd':
 74 |                     im = np.zeros((R, C))
 75 |                     im[rmin:rmax, cmin:cmax] = 1
 76 |                     if not prev_im is None:
 77 |                         im[prev_im] = 1
 78 |                 ims[i - im_num_start] = np.copy(im)
 79 |             i += 1
 80 |     return ims
 81 | 
 82 | 
 83 | def gen_tiles_around_baseline(im_orig, comp_tile, fill=0,
 84 |                               method='occlusion', sweep_dim=3):
 85 |     R = im_orig.shape[0]
 86 |     C = im_orig.shape[1]
 87 |     dim_2 = (sweep_dim // 2)  # note the +1 for adjacent, but non-overlapping tiles
 88 |     ims, idxs = [], []
 89 |     # iterate over top, left indexes
 90 |     for r_downsampled, rmin in enumerate(range(0, R, sweep_dim)):
 91 |         for c_downsampled, cmin in enumerate(range(0, C, sweep_dim)):
 92 | 
 93 |             rmax = min(rmin + sweep_dim, R)
 94 |             cmax = min(cmin + sweep_dim, C)
 95 | 
 96 |             # calculate bounds of new block + boundaries
 97 |             rminus = max(rmin - sweep_dim, 0)
 98 |             cminus = max(cmin - sweep_dim, 0)
 99 |             rplus = min(rmin + sweep_dim, R - 1)
100 |             cplus = min(cmin + sweep_dim, C - 1)
101 | 
102 |             # new block isn't in old block
103 |             if not comp_tile[rmin, cmin]:
104 |                 # new block borders old block
105 |                 if comp_tile[rminus, cmin] or comp_tile[rmin, cminus] or comp_tile[rplus, cmin] or comp_tile[
106 |                     rmin, cplus]:
107 |                     if method == 'occlusion':
108 |                         im = np.copy(im_orig)  # im_orig background
109 |                         im[rmin:rmax, cmin:cmax] = fill  # black out new block
110 |                         im[comp_tile] = fill  # black out comp_tile
111 |                     elif method == 'build_up':
112 |                         im = np.zeros(im_orig.shape)  # zero background
113 |                         im[rmin:rmax, cmin:cmax] = im_orig[rmin:rmax, cmin:cmax]  # im_orig at new block
114 |                         im[comp_tile] = im_orig[comp_tile]  # im_orig at comp_tile
115 |                     elif method == 'cd':
116 |                         im = np.zeros((R, C))  # zero background
117 |                         im[rmin:rmax, cmin:cmax] = 1  # 1 at new block
118 |                         im[comp_tile] = 1  # 1 at comp_tile
119 |                     ims.append(im)
120 |                     idxs.append((r_downsampled, c_downsampled))
121 |     return np.array(ims), idxs
122 | 
123 | 
def gen_tile_from_comp(im_orig, comp_tile_downsampled, sweep_dim, method, fill=0):
    '''generates full-size tile from comp which could be downsampled

    Every truthy entry (r, c) of comp_tile_downsampled selects the block of
    pixels rows r * sweep_dim : (r + 1) * sweep_dim and columns
    c * sweep_dim : (c + 1) * sweep_dim (clipped by the image size).

    Params
    ------
    im_orig: np.ndarray
        image whose first two axes are rows and columns
    comp_tile_downsampled: np.ndarray
        2d array with one entry per sweep_dim x sweep_dim block
    sweep_dim: int
        side length of a block in pixels
    method: str
        'occlusion': copy of im_orig with the selected blocks set to fill
        'build_up': zero array shaped like im_orig holding im_orig at the selected blocks
        'cd': boolean (R, C) mask which is True at the selected blocks
    fill
        value written into the selected blocks (only used by 'occlusion')

    Returns
    -------
    im: np.ndarray
        the full-size tile

    Raises
    ------
    ValueError
        if method is not one of 'occlusion', 'build_up', 'cd'
    '''
    # background of the tile
    if method == 'occlusion':
        im = np.copy(im_orig)
    elif method == 'build_up':
        im = np.zeros(im_orig.shape)
    elif method == 'cd':
        im = np.zeros((im_orig.shape[0], im_orig.shape[1]), dtype=np.bool_)
    else:
        # previously this fell through to `return im` with im unbound
        raise ValueError("method must be 'occlusion', 'build_up' or 'cd', got {!r}".format(method))

    # paint every selected block onto the background
    for r in range(comp_tile_downsampled.shape[0]):
        rows = slice(r * sweep_dim, (r + 1) * sweep_dim)
        for c in range(comp_tile_downsampled.shape[1]):
            if not comp_tile_downsampled[r, c]:
                continue
            cols = slice(c * sweep_dim, (c + 1) * sweep_dim)
            if method == 'occlusion':
                im[rows, cols] = fill
            elif method == 'build_up':
                im[rows, cols] = im_orig[rows, cols]
            else:  # 'cd'
                im[rows, cols] = 1
    return im
156 | 
157 | 
def combine_tiles(tile1, tile2, method='cd'):
    '''Merge two tiles by adding them

    For every method other than 'occlusion' the result is tile1 + tile2;
    for 'occlusion' nothing is combined and None is returned.
    '''
    if method == 'occlusion':
        return None
    return tile1 + tile2
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |

Functions

169 |
170 |
171 | def combine_tiles(tile1, tile2, method='cd') 172 |
173 |
174 |
175 |
176 | 177 | Expand source code 178 | 179 |
def combine_tiles(tile1, tile2, method='cd'):
    '''Merge two tiles by adding them

    For every method other than 'occlusion' the result is tile1 + tile2;
    for 'occlusion' nothing is combined and None is returned.
    '''
    if method == 'occlusion':
        return None
    return tile1 + tile2
182 |
183 |
184 |
185 | def gen_tile_from_comp(im_orig, comp_tile_downsampled, sweep_dim, method, fill=0) 186 |
187 |
188 |

generates full-size tile from comp which could be downsampled

189 |
190 | 191 | Expand source code 192 | 193 |
def gen_tile_from_comp(im_orig, comp_tile_downsampled, sweep_dim, method, fill=0):
    '''generates full-size tile from comp which could be downsampled

    Every truthy entry (r, c) of comp_tile_downsampled selects the block of
    pixels rows r * sweep_dim : (r + 1) * sweep_dim and columns
    c * sweep_dim : (c + 1) * sweep_dim (clipped by the image size).

    Params
    ------
    im_orig: np.ndarray
        image whose first two axes are rows and columns
    comp_tile_downsampled: np.ndarray
        2d array with one entry per sweep_dim x sweep_dim block
    sweep_dim: int
        side length of a block in pixels
    method: str
        'occlusion': copy of im_orig with the selected blocks set to fill
        'build_up': zero array shaped like im_orig holding im_orig at the selected blocks
        'cd': boolean (R, C) mask which is True at the selected blocks
    fill
        value written into the selected blocks (only used by 'occlusion')

    Returns
    -------
    im: np.ndarray
        the full-size tile

    Raises
    ------
    ValueError
        if method is not one of 'occlusion', 'build_up', 'cd'
    '''
    # background of the tile
    if method == 'occlusion':
        im = np.copy(im_orig)
    elif method == 'build_up':
        im = np.zeros(im_orig.shape)
    elif method == 'cd':
        im = np.zeros((im_orig.shape[0], im_orig.shape[1]), dtype=np.bool_)
    else:
        # previously this fell through to `return im` with im unbound
        raise ValueError("method must be 'occlusion', 'build_up' or 'cd', got {!r}".format(method))

    # paint every selected block onto the background
    for r in range(comp_tile_downsampled.shape[0]):
        rows = slice(r * sweep_dim, (r + 1) * sweep_dim)
        for c in range(comp_tile_downsampled.shape[1]):
            if not comp_tile_downsampled[r, c]:
                continue
            cols = slice(c * sweep_dim, (c + 1) * sweep_dim)
            if method == 'occlusion':
                im[rows, cols] = fill
            elif method == 'build_up':
                im[rows, cols] = im_orig[rows, cols]
            else:  # 'cd'
                im[rows, cols] = 1
    return im
225 |
226 |
227 |
228 | def gen_tiles(image, fill=0, method='occlusion', prev_im=None, sweep_dim=1, num_ims=None, im_num_start=0) 229 |
230 |
231 |

Generate all possible tilings given a granularity of sweep_dim and for a particular method

232 |
233 | 234 | Expand source code 235 | 236 |
def gen_tiles(image, fill=0, method='occlusion', prev_im=None,
              sweep_dim=1, num_ims=None, im_num_start=0):
    '''Generate all possible tilings given a granularity of sweep_dim and for a particular method

    The image is swept row by row in blocks of sweep_dim x sweep_dim pixels
    (blocks on the bottom / right edge are clipped to the image) and one tile is
    produced per block.

    Params
    ------
    image: np.ndarray
        2d image, or an image with a trailing channel axis
        (the output buffer is allocated with 3 channels in that case)
    fill
        value written into the occluded block (only used by 'occlusion')
    method: str
        'occlusion': copy of image with the block (and prev_im) set to fill
        'build_up': zero image holding image only at the block (and prev_im)
        'cd': zero (R, C) array holding 1 at the block (and prev_im)
    prev_im: np.ndarray, optional
        index (presumably a boolean mask - TODO confirm) of pixels which are
        treated like the block in every tile
    sweep_dim: int
        side length of a block
    num_ims: int, optional
        number of tiles to return, defaults to the number of blocks
    im_num_start: int
        index of the first block for which a tile is returned

    Returns
    -------
    ims: np.ndarray
        tiles stacked along the first axis; slots for which no block exists
        (num_ims larger than the number of remaining blocks) are all zero

    Raises
    ------
    ValueError
        if method is not one of 'occlusion', 'build_up', 'cd'
    '''
    # fail early: previously an unknown method surfaced as an unbound-local
    # error in the middle of the sweep
    if method not in ('occlusion', 'build_up', 'cd'):
        raise ValueError("method must be 'occlusion', 'build_up' or 'cd', got {!r}".format(method))

    R = image.shape[0]
    C = image.shape[1]

    if num_ims is None:  # check if theres a limit on how many ims to have
        num_ims = ceil(R / sweep_dim) * ceil(C / sweep_dim)

    # zeros rather than empty: slots which are never written must not expose
    # uninitialised memory
    if image.ndim == 2 or method == 'cd':  # mnist case / cd masks are always 2d
        ims = np.zeros((num_ims, R, C))
    else:  # imagenet case
        ims = np.zeros((num_ims, R, C, 3))

    i = 0
    # iterate over top, left indexes
    for rmin in range(0, R, sweep_dim):
        rmax = min(rmin + sweep_dim, R)
        for cmin in range(0, C, sweep_dim):
            if im_num_start <= i < im_num_start + num_ims:
                cmax = min(cmin + sweep_dim, C)

                # create appropriate images
                if method == 'occlusion':
                    im = np.copy(image)
                    im[rmin:rmax, cmin:cmax] = fill
                    if prev_im is not None:
                        im[prev_im] = fill
                elif method == 'build_up':
                    im = np.zeros(image.shape)
                    im[rmin:rmax, cmin:cmax] = image[rmin:rmax, cmin:cmax]
                    if prev_im is not None:
                        im[prev_im] = image[prev_im]
                else:  # 'cd'
                    im = np.zeros((R, C))
                    im[rmin:rmax, cmin:cmax] = 1
                    if prev_im is not None:
                        im[prev_im] = 1
                # assignment copies the values into the output buffer
                ims[i - im_num_start] = im
            i += 1
    return ims
285 |
286 |
287 |
288 | def gen_tiles_around_baseline(im_orig, comp_tile, fill=0, method='occlusion', sweep_dim=3) 289 |
290 |
291 |
292 |
293 | 294 | Expand source code 295 | 296 |
def gen_tiles_around_baseline(im_orig, comp_tile, fill=0,
                              method='occlusion', sweep_dim=3):
    '''Generate one tile for every block that borders, but is not part of, comp_tile

    The image is swept in blocks of sweep_dim x sweep_dim pixels (blocks on the
    bottom / right edge are clipped to the image). A block is selected when the
    pixel at its top-left corner is not set in comp_tile while the pixel one
    block above, left, below or right of that corner (clamped to the image
    bounds) is set.

    Params
    ------
    im_orig: np.ndarray
        image whose first two axes are rows and columns
    comp_tile: np.ndarray
        mask of the current component, indexed as comp_tile[row, col] and also
        used to index the image directly
        (presumably a boolean array with the image's row / column shape - TODO confirm)
    fill
        value written into the occluded pixels (only used by 'occlusion')
    method: str
        'occlusion': copy of im_orig with the new block and comp_tile set to fill
        'build_up': zero image holding im_orig only at the new block and comp_tile
        'cd': zero (R, C) array holding 1 at the new block and comp_tile
    sweep_dim: int
        side length of a block

    Returns
    -------
    ims: np.ndarray
        the generated tiles stacked along a new first axis (empty if no block qualifies)
    idxs: list
        (row, column) position of each selected block in the block grid

    Raises
    ------
    ValueError
        if method is not one of 'occlusion', 'build_up', 'cd'
    '''
    # fail early: previously an unknown method surfaced as an unbound-local
    # error in the middle of the sweep
    if method not in ('occlusion', 'build_up', 'cd'):
        raise ValueError("method must be 'occlusion', 'build_up' or 'cd', got {!r}".format(method))

    R = im_orig.shape[0]
    C = im_orig.shape[1]
    ims, idxs = [], []

    # iterate over top, left indexes
    for r_downsampled, rmin in enumerate(range(0, R, sweep_dim)):
        rmax = min(rmin + sweep_dim, R)
        # rows one block above / below, clamped so the lookup stays inside the image
        rminus = max(rmin - sweep_dim, 0)
        rplus = min(rmin + sweep_dim, R - 1)

        for c_downsampled, cmin in enumerate(range(0, C, sweep_dim)):
            # new block is already in old block
            if comp_tile[rmin, cmin]:
                continue

            cmax = min(cmin + sweep_dim, C)
            # columns one block left / right, clamped like the rows
            cminus = max(cmin - sweep_dim, 0)
            cplus = min(cmin + sweep_dim, C - 1)

            # new block must border old block
            borders_comp = (comp_tile[rminus, cmin] or comp_tile[rmin, cminus]
                            or comp_tile[rplus, cmin] or comp_tile[rmin, cplus])
            if not borders_comp:
                continue

            if method == 'occlusion':
                im = np.copy(im_orig)  # im_orig background
                im[rmin:rmax, cmin:cmax] = fill  # black out new block
                im[comp_tile] = fill  # black out comp_tile
            elif method == 'build_up':
                im = np.zeros(im_orig.shape)  # zero background
                im[rmin:rmax, cmin:cmax] = im_orig[rmin:rmax, cmin:cmax]  # im_orig at new block
                im[comp_tile] = im_orig[comp_tile]  # im_orig at comp_tile
            else:  # 'cd'
                im = np.zeros((R, C))  # zero background
                im[rmin:rmax, cmin:cmax] = 1  # 1 at new block
                im[comp_tile] = 1  # 1 at comp_tile
            ims.append(im)
            idxs.append((r_downsampled, c_downsampled))
    return np.array(ims), idxs
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 | 363 |
364 | 367 | 368 | 369 | 370 | -------------------------------------------------------------------------------- /dsets/imagenet/imnet_dict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csinva/hierarchical-dnn-interpretations/f3a79868420a9f51c825085d62bdff16f9e1a8f3/dsets/imagenet/imnet_dict.pkl -------------------------------------------------------------------------------- /dsets/mnist/dset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import datasets, transforms 7 | import numpy as np 8 | 9 | 10 | # Training settings 11 | def get_args(): 12 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 13 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 14 | help='input batch size for training (default: 64)') 15 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 16 | help='input batch size for testing (default: 1000)') 17 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 18 | help='number of epochs to train (default: 10)') 19 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 20 | help='learning rate (default: 0.01)') 21 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 22 | help='SGD momentum (default: 0.5)') 23 | parser.add_argument('--no-cuda', action='store_true', default=False, 24 | help='disables CUDA training') 25 | parser.add_argument('--seed', type=int, default=1, metavar='S', 26 | help='random seed (default: 1)') 27 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 28 | help='how many batches to wait before logging training status') 29 | return parser.parse_args("") 30 | 31 | 32 | # load data 33 | def load_data(train_batch_size, 
test_batch_size, device, data_dir='data', shuffle=False): 34 | kwargs = {} #{'num_workers': 1, 'pin_memory': True} if device == 'cuda' else {} 35 | train_loader = torch.utils.data.DataLoader( 36 | datasets.MNIST(data_dir, train=True, download=True, 37 | transform=transforms.Compose([ 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.1307,), (0.3081,)) 40 | ])), 41 | batch_size=train_batch_size, shuffle=shuffle, **kwargs) 42 | test_loader = torch.utils.data.DataLoader( 43 | datasets.MNIST(data_dir, train=False, transform=transforms.Compose([ 44 | transforms.ToTensor(), 45 | transforms.Normalize((0.1307,), (0.3081,)) 46 | ])), 47 | batch_size=test_batch_size, shuffle=shuffle, **kwargs) 48 | return train_loader, test_loader 49 | 50 | 51 | def train(epoch, train_loader): 52 | model.train() 53 | for batch_idx, (data, target) in enumerate(train_loader): 54 | if args.cuda: 55 | data, target = data.cuda(), target.cuda() 56 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 57 | optimizer.zero_grad() 58 | output = model(data) 59 | loss = F.nll_loss(output, target) 60 | loss.backward() 61 | optimizer.step() 62 | if batch_idx % args.log_interval == 0: 63 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 64 | epoch, batch_idx * len(data), len(train_loader.dataset), 65 | 100. 
* batch_idx / len(train_loader), loss.data[0])) 66 | return model 67 | 68 | 69 | def test(model, test_loader): 70 | model.eval() 71 | test_loss = 0 72 | correct = 0 73 | for data, target in test_loader: 74 | if args.cuda: 75 | data, target = data.cuda(), target.cuda() 76 | output = model(data) 77 | test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss 78 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 79 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 80 | 81 | test_loss /= len(test_loader.dataset) 82 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 83 | test_loss, correct, len(test_loader.dataset), 84 | 100. * correct / len(test_loader.dataset))) 85 | 86 | 87 | def get_im_and_label(num, device='cuda'): 88 | torch.manual_seed(130) 89 | _, data_loader = load_data(train_batch_size=1, test_batch_size=1, 90 | device=device, data_dir='mnist/data', 91 | shuffle=False) 92 | for i, im in enumerate(data_loader): 93 | if i == num: 94 | return im[0].to(device), im[0].numpy().squeeze(), im[1].numpy()[0] 95 | 96 | 97 | def pred_ims(model, ims, layer='softmax', device='cuda'): 98 | if len(ims.shape) == 2: 99 | ims = np.expand_dims(ims, 0) 100 | ims_torch = torch.unsqueeze(torch.Tensor(ims), 1).float().to(device) # cuda() 101 | preds = model(ims_torch) 102 | 103 | # todo - build in logit support 104 | # logits = model.logits(t) 105 | return preds.data.cpu().numpy() 106 | 107 | 108 | if __name__ == '__main__': 109 | from model import Net 110 | args = get_args() 111 | args.cuda = not args.no_cuda and torch.cuda.is_available() 112 | torch.manual_seed(args.seed) 113 | if args.cuda: 114 | torch.cuda.manual_seed(args.seed) 115 | train_loader, test_loader = load_data(args.batch_size, args.test_batch_size, args.cuda) 116 | 117 | # create model 118 | model = Net() 119 | if args.cuda: 120 | model.cuda() 121 | 122 | # train 123 | for epoch in range(1, args.epochs + 1): 124 | 
model = train(epoch, train_loader) 125 | test(model, test_loader) 126 | 127 | # save 128 | torch.save(model.state_dict(), 'mnist.model') 129 | # load and test 130 | # model_loaded = Net().cuda() 131 | # model_loaded.load_state_dict(torch.load('mnist.model')) 132 | # test(model_loaded, test_loader) 133 | -------------------------------------------------------------------------------- /dsets/mnist/mnist.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csinva/hierarchical-dnn-interpretations/f3a79868420a9f51c825085d62bdff16f9e1a8f3/dsets/mnist/mnist.model -------------------------------------------------------------------------------- /dsets/mnist/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Net(nn.Module): 7 | '''A simple conv net 8 | ''' 9 | def __init__(self): 10 | super(Net, self).__init__() 11 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 12 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 13 | self.conv2_drop = nn.Dropout2d() 14 | self.fc1 = nn.Linear(320, 50) 15 | self.fc2 = nn.Linear(50, 10) 16 | 17 | def forward(self, x): 18 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 19 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 20 | x = x.view(-1, 320) 21 | x = F.relu(self.fc1(x)) 22 | x = F.dropout(x, training=self.training) 23 | x = self.fc2(x) 24 | return F.log_softmax(x, dim=1) 25 | 26 | def logits(self, x): 27 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 28 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 29 | x = x.view(-1, 320) 30 | x = F.relu(self.fc1(x)) 31 | x = F.dropout(x, training=self.training) 32 | x = self.fc2(x) 33 | return x 34 | 35 | def predicted_class(self, x): 36 | pred = self.forward(x) 37 | _, pred = pred[0].max(0) 38 | return pred.item() #data[0] 39 | 
-------------------------------------------------------------------------------- /dsets/mnist/readme.md: -------------------------------------------------------------------------------- 1 | - code adapted from https://github.com/pytorch/examples/tree/master/mnist -------------------------------------------------------------------------------- /dsets/sst/dset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchtext import data, datasets, vocab 4 | import random 5 | import os 6 | import pickle as pkl 7 | import sys 8 | path_to_file = os.path.dirname(__file__) 9 | 10 | # deal with different torchtext versions 11 | try: 12 | vocab._default_unk_index 13 | except AttributeError: 14 | def _default_unk_index(): 15 | return 0 16 | vocab._default_unk_index = _default_unk_index 17 | 18 | 19 | # set up data loaders 20 | def get_sst(): 21 | inputs = data.Field(lower='preserve-case') 22 | answers = data.Field(sequential=False, unk_token=None) 23 | 24 | # build with subtrees so inputs are right 25 | train_s, dev_s, test_s = datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=True, 26 | filter_pred=lambda ex: ex.label != 'neutral') 27 | inputs.build_vocab(train_s, dev_s, test_s) 28 | answers.build_vocab(train_s) 29 | 30 | # rebuild without subtrees to get longer sentences 31 | train, dev, test = datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=False, 32 | filter_pred=lambda ex: ex.label != 'neutral') 33 | 34 | train_iter, dev_iter, test_iter = data.BucketIterator.splits( 35 | (train, dev, test), batch_size=1, device=0) 36 | 37 | return inputs, answers, train_iter, dev_iter 38 | 39 | 40 | # get specific batches 41 | def get_batches(batch_nums, train_iterator, dev_iterator, dset='dev'): 42 | print('getting batches...') 43 | np.random.seed(13) 44 | random.seed(13) 45 | 46 | # pick data_iterator 47 | if dset == 'train': 48 | data_iterator = train_iterator 
49 | elif dset == 'dev': 50 | data_iterator = dev_iterator 51 | 52 | # actually get batches 53 | num = 0 54 | batches = {} 55 | data_iterator.init_epoch() 56 | for batch_idx, batch in enumerate(data_iterator): 57 | if batch_idx == batch_nums[num]: 58 | batches[batch_idx] = batch 59 | num += 1 60 | 61 | if num == max(batch_nums): 62 | break 63 | elif num == len(batch_nums): 64 | print('found them all') 65 | break 66 | return batches 67 | 68 | def load_vocab(): 69 | return pkl.load(open(os.path.join(path_to_file, 'sst_vocab.pkl'), 'rb')) 70 | 71 | def load_model(): 72 | model = LSTMSentiment() 73 | 74 | def batch_from_str_list(s, vocab, device='cpu'): 75 | '''Put text into .text attribute of a batch 76 | ''' 77 | batch = lambda: None # placeholder which holds .text attribute 78 | nums = np.expand_dims(np.array([vocab['stoi'][x] for x in s]).transpose(), 79 | axis=1) 80 | batch.text = torch.LongTensor(nums).to(device) #cuda() 81 | return batch -------------------------------------------------------------------------------- /dsets/sst/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LSTMSentiment(nn.Module): 6 | def __init__(self, config=None, d_hidden=128, n_embed=18844, d_embed=300, d_out=2, batch_size=50): 7 | super(LSTMSentiment, self).__init__() 8 | if config is not None: 9 | self.hidden_dim = config.d_hidden 10 | self.vocab_size = config.n_embed 11 | self.emb_dim = config.d_embed 12 | self.num_out = config.d_out 13 | self.batch_size = config.batch_size 14 | else: 15 | self.hidden_dim = d_hidden 16 | self.vocab_size = n_embed 17 | self.emb_dim = d_embed 18 | self.num_out = d_out 19 | self.batch_size = batch_size 20 | self.use_gpu = True # config.use_gpu 21 | self.num_labels = 2 22 | self.embed = nn.Embedding(self.vocab_size, self.emb_dim) 23 | self.lstm = nn.LSTM(input_size=self.emb_dim, hidden_size=self.hidden_dim) 24 | self.hidden_to_label = 
nn.Linear(self.hidden_dim, self.num_labels) 25 | 26 | def forward(self, batch): 27 | self.hidden = (torch.zeros(1, batch.text.shape[1], self.hidden_dim), 28 | torch.zeros(1, batch.text.shape[1], self.hidden_dim)) 29 | vecs = self.embed(batch.text) 30 | lstm_out, self.hidden = self.lstm(vecs, self.hidden) 31 | logits = self.hidden_to_label(lstm_out[-1]) 32 | # log_probs = self.log_softmax(logits) 33 | # return log_probs 34 | return logits 35 | 36 | def predict(self, batch): 37 | pred = self.forward(batch) 38 | _, pred = pred[0].max(0) 39 | return pred.data[0] 40 | -------------------------------------------------------------------------------- /dsets/sst/readme.md: -------------------------------------------------------------------------------- 1 | - code adapted from https://github.com/clairett/pytorch-sentiment-classification -------------------------------------------------------------------------------- /dsets/sst/sst_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csinva/hierarchical-dnn-interpretations/f3a79868420a9f51c825085d62bdff16f9e1a8f3/dsets/sst/sst_vocab.pkl -------------------------------------------------------------------------------- /dsets/sst/state_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csinva/hierarchical-dnn-interpretations/f3a79868420a9f51c825085d62bdff16f9e1a8f3/dsets/sst/state_dict.pth -------------------------------------------------------------------------------- /dsets/sst/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import glob 4 | 5 | import torch 6 | import torch.optim as O 7 | import torch.nn as nn 8 | from argparse import ArgumentParser 9 | 10 | from torchtext import data 11 | from torchtext import datasets 12 | 13 | from model import LSTMSentiment 14 | 15 | 16 | def get_args(): 17 | parser = 
ArgumentParser(description='PyTorch/torchtext SST') 18 | parser.add_argument('--epochs', type=int, default=5) 19 | parser.add_argument('--batch_size', type=int, default=50) 20 | parser.add_argument('--d_embed', type=int, default=300) 21 | parser.add_argument('--d_proj', type=int, default=300) 22 | parser.add_argument('--d_hidden', type=int, default=128) 23 | parser.add_argument('--n_layers', type=int, default=1) 24 | parser.add_argument('--log_every', type=int, default=1000) 25 | parser.add_argument('--lr', type=float, default=.001) 26 | parser.add_argument('--dev_every', type=int, default=1000) 27 | parser.add_argument('--save_every', type=int, default=1000) 28 | parser.add_argument('--dp_ratio', type=int, default=0.2) 29 | parser.add_argument('--no-bidirectional', action='store_false', dest='birnn') 30 | parser.add_argument('--preserve-case', action='store_false', dest='lower') 31 | parser.add_argument('--no-projection', action='store_false', dest='projection') 32 | parser.add_argument('--train_embed', action='store_false', dest='fix_emb') 33 | parser.add_argument('--gpu', type=int, default=0) 34 | parser.add_argument('--save_path', type=str, default='results') 35 | parser.add_argument('--vector_cache', type=str, default=os.path.join(os.getcwd(), '.vector_cache/input_vectors.pt')) 36 | parser.add_argument('--word_vectors', type=str, default='glove.6B.300d') 37 | parser.add_argument('--resume_snapshot', type=str, default='') 38 | parser.add_argument('--bad', dest='bad', action='store_true') 39 | parser.set_defaults(bad=False) 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | def makedirs(name): 45 | """helper function for python 2 and 3 to call os.makedirs() 46 | avoiding an error if the directory to be created already exists""" 47 | 48 | import os, errno 49 | 50 | try: 51 | os.makedirs(name) 52 | except OSError as ex: 53 | if ex.errno == errno.EEXIST and os.path.isdir(name): 54 | # ignore existing directory 55 | pass 56 | else: 57 | # a different 
error happened 58 | raise 59 | 60 | 61 | args = get_args() 62 | torch.cuda.set_device(args.gpu) 63 | 64 | inputs = data.Field(lower=args.lower) 65 | answers = data.Field(sequential=False, unk_token=None) 66 | 67 | train, dev, test = datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=True, 68 | filter_pred=lambda ex: ex.label != 'neutral') 69 | 70 | inputs.build_vocab(train, dev, test) 71 | if args.word_vectors: 72 | if os.path.isfile(args.vector_cache): 73 | inputs.vocab.vectors = torch.load(args.vector_cache) 74 | else: 75 | inputs.vocab.load_vectors(args.word_vectors) 76 | makedirs(os.path.dirname(args.vector_cache)) 77 | torch.save(inputs.vocab.vectors, args.vector_cache) 78 | answers.build_vocab(train) 79 | 80 | train_iter, dev_iter, test_iter = data.BucketIterator.splits( 81 | (train, dev, test), batch_size=args.batch_size, device=args.gpu) 82 | 83 | config = args 84 | config.n_embed = len(inputs.vocab) 85 | config.d_out = len(answers.vocab) 86 | config.n_cells = config.n_layers 87 | 88 | if args.bad: 89 | config.d_hidden = 10 90 | 91 | # double the number of cells for bidirectional networks 92 | if config.birnn: 93 | config.n_cells *= 2 94 | 95 | if args.resume_snapshot: 96 | model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage.cuda(args.gpu)) 97 | else: 98 | model = LSTMSentiment(config=config) 99 | if args.word_vectors: 100 | model.embed.weight.data = inputs.vocab.vectors 101 | model.cuda() 102 | 103 | criterion = nn.CrossEntropyLoss() 104 | opt = O.Adam(model.parameters()) # , lr=args.lr) 105 | # model.embed.requires_grad = False 106 | 107 | iterations = 0 108 | start = time.time() 109 | best_dev_acc = -1 110 | train_iter.repeat = False 111 | header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy' 112 | dev_log_template = ' '.join( 113 | '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(',')) 114 | log_template = ' 
'.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(',')) 115 | makedirs(args.save_path) 116 | print(header) 117 | 118 | all_break = False 119 | for epoch in range(args.epochs): 120 | if all_break: 121 | break 122 | train_iter.init_epoch() 123 | n_correct, n_total = 0, 0 124 | for batch_idx, batch in enumerate(train_iter): 125 | 126 | # switch model to training mode, clear gradient accumulators 127 | model.train(); 128 | opt.zero_grad() 129 | 130 | iterations += 1 131 | 132 | # forward pass 133 | answer = model(batch) 134 | 135 | # calculate accuracy of predictions in the current batch 136 | n_correct += (torch.max(answer, 1)[1].view(batch.label.size()).data == batch.label.data).sum() 137 | n_total += batch.batch_size 138 | train_acc = 100. * n_correct / n_total 139 | 140 | # calculate loss of the network output with respect to training labels 141 | loss = criterion(answer, batch.label) 142 | 143 | # backpropagate and update optimizer learning rate 144 | loss.backward(); 145 | opt.step() 146 | 147 | # checkpoint model periodically 148 | if iterations % args.save_every == 0: 149 | snapshot_prefix = os.path.join(args.save_path, 'snapshot') 150 | if args.bad: 151 | snapshot_prefix += '_bad' 152 | snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], 153 | iterations) 154 | torch.save(model, snapshot_path) 155 | for f in glob.glob(snapshot_prefix + '*'): 156 | if f != snapshot_path: 157 | os.remove(f) 158 | 159 | # evaluate performance on validation set periodically 160 | # if iterations % args.dev_every == 0 or (args.bad and iterations % (args.dev_every / 10) == 0): 161 | if iterations % args.dev_every == 0: 162 | 163 | # switch model to evaluation mode 164 | model.eval(); 165 | dev_iter.init_epoch() 166 | 167 | # calculate accuracy on validation set 168 | n_dev_correct, dev_loss = 0, 0 169 | for dev_batch_idx, dev_batch in enumerate(dev_iter): 170 | answer = 
model(dev_batch) 171 | n_dev_correct += ( 172 | torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum() 173 | dev_loss = criterion(answer, dev_batch.label) 174 | dev_acc = 100. * n_dev_correct / len(dev) 175 | 176 | print(dev_log_template.format(time.time() - start, 177 | epoch, iterations, 1 + batch_idx, len(train_iter), 178 | 100. * (1 + batch_idx) / len(train_iter), loss.data[0], dev_loss.data[0], 179 | train_acc, dev_acc)) 180 | 181 | # update best valiation set accuracy 182 | if dev_acc > best_dev_acc: 183 | 184 | best_dev_acc = dev_acc 185 | snapshot_prefix = os.path.join(args.save_path, 'best_snapshot') 186 | if args.bad: 187 | snapshot_prefix += '_bad' 188 | snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, 189 | dev_loss.data[0], 190 | iterations) 191 | 192 | # save model, delete previous 'best_snapshot' files 193 | torch.save(model, snapshot_path) 194 | print("Saved", snapshot_path, iterations) 195 | for f in glob.glob(snapshot_prefix + '*'): 196 | if f != snapshot_path and ((args.bad and 'bad' not in f) or (not args.bad and 'bad' in f)): 197 | os.remove(f) 198 | 199 | # If we want a bad model, quit early 200 | if False and args.bad and best_dev_acc > 0.65: 201 | all_break = True 202 | break 203 | 204 | elif iterations % args.log_every == 0: 205 | 206 | # print progress message 207 | print(log_template.format(time.time() - start, 208 | epoch, iterations, 1 + batch_idx, len(train_iter), 209 | 100. * (1 + batch_idx) / len(train_iter), loss.data[0], ' ' * 8, 210 | n_correct / n_total * 100, ' ' * 12)) 211 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |

Hierarchical neural-net interpretations (ACD) 🧠

2 | 3 |

Produces hierarchical interpretations for a single prediction made by a pytorch neural network. Official code for Hierarchical interpretations for neural network predictions (ICLR 2019 pdf).

4 | 5 |

6 | 7 | 8 | 9 | 10 | 11 | 12 |

13 |

14 | Documentation • 15 | Demo notebooks 16 |

17 |

18 | Note: this repo is actively maintained. For any questions please file an issue. 19 |

20 | 21 | 22 | ![](https://csinva.io/hierarchical-dnn-interpretations/intro.svg?sanitize=True) 23 | 24 | 25 | 26 | # examples/documentation 27 | 28 | - **installation**: `pip install acd` (or clone and run `python setup.py install`) 29 | - **examples**: the [reproduce_figs](https://github.com/csinva/hierarchical-dnn-interpretations/tree/master/reproduce_figs) folder has notebooks with many demos 30 | - **src**: the [acd](acd) folder contains the source for the method implementation 31 | - allows for different types of interpretations by changing hyperparameters (explained in examples) 32 | - all required data/models/code for reproducing are included in the [dsets](dsets) folder 33 | 34 | | Inspecting NLP sentiment models | Detecting adversarial examples | Analyzing imagenet models | 35 | | ---------------------------------- | ----------------------------------- | ----------------------------------- | 36 | | ![](reproduce_figs/figs/fig_2.png) | ![](reproduce_figs/figs/fig_s3.png) | ![](reproduce_figs/figs/fig_s2.png) | 37 | 38 | 39 | # notes on using ACD on your own data 40 | - the current CD implementation often works out of the box, especially for networks built on common layers, such as alexnet/vgg/resnet. However, if you have custom layers or layers not accessible in `net.modules()`, you may need to write a custom function to iterate through some layers of your network (for examples see `cd.py`). 41 | - to use baselines such as build-up and occlusion, replace the `pred_ims` function with a function that gets predictions from your model given a batch of examples. 
42 | 43 | 44 | # related work 45 | 46 | - CDEP (ICML 2020 [pdf](https://arxiv.org/abs/1909.13584), [github](https://github.com/laura-rieger/deep-explanation-penalization)) - penalizes CD / ACD scores during training to make models generalize better 47 | - TRIM (ICLR 2020 workshop [pdf](https://arxiv.org/abs/2003.01926), [github](https://github.com/csinva/transformation-importance)) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies) 48 | - PDR framework (PNAS 2019 [pdf](https://arxiv.org/abs/1901.04592)) - an overarching framework for guiding and framing interpretable machine learning 49 | - DAC (arXiv 2019 [pdf](https://arxiv.org/abs/1905.07631), [github](https://github.com/csinva/disentangled-attribution-curves)) - finds disentangled interpretations for random forests 50 | - Baseline interpretability methods - the file `scores/score_funcs.py` also contains simple pytorch implementations of [integrated gradients](https://arxiv.org/abs/1703.01365) and the simple interpretation technique `gradient * input` 51 | 52 | # reference 53 | 54 | - feel free to use/share this code openly 55 | - if you find this code useful for your research, please cite the following: 56 | 57 | ```r 58 | @inproceedings{ 59 | singh2019hierarchical, 60 | title={Hierarchical interpretations for neural network predictions}, 61 | author={Chandan Singh and W. 
James Murdoch and Bin Yu}, 62 | booktitle={International Conference on Learning Representations}, 63 | year={2019}, 64 | url={https://openreview.net/forum?id=SkEqro0ctQ}, 65 | } 66 | ``` 67 | 68 | -------------------------------------------------------------------------------- /reproduce_figs/figs/fig_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csinva/hierarchical-dnn-interpretations/f3a79868420a9f51c825085d62bdff16f9e1a8f3/reproduce_figs/figs/fig_2.png -------------------------------------------------------------------------------- /reproduce_figs/figs/fig_s2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csinva/hierarchical-dnn-interpretations/f3a79868420a9f51c825085d62bdff16f9e1a8f3/reproduce_figs/figs/fig_s2.png -------------------------------------------------------------------------------- /reproduce_figs/figs/fig_s3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csinva/hierarchical-dnn-interpretations/f3a79868420a9f51c825085d62bdff16f9e1a8f3/reproduce_figs/figs/fig_s3.png -------------------------------------------------------------------------------- /reproduce_figs/imagenet_fig3,s1,s2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Downloading: \"https://download.pytorch.org/models/vgg16-397923af.pth\" to /tmp/.xdg_cache_vision/torch/hub/checkpoints/vgg16-397923af.pth\n" 13 | ] 14 | }, 15 | { 16 | "data": { 17 | "application/vnd.jupyter.widget-view+json": { 18 | "model_id": "9fa39cf2b232455ab53f07256ee083a9", 19 | "version_major": 2, 20 | "version_minor": 0 21 | }, 22 | "text/plain": [ 23 | 
"HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))" 24 | ] 25 | }, 26 | "metadata": {}, 27 | "output_type": "display_data" 28 | }, 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "%load_ext autoreload\n", 39 | "%autoreload 2\n", 40 | "%matplotlib inline\n", 41 | "import numpy as np\n", 42 | "import sys\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "import torch\n", 45 | "\n", 46 | "sys.path.append('..')\n", 47 | "sys.path.append('../visualization')\n", 48 | "\n", 49 | "import viz_2d as viz\n", 50 | "import acd\n", 51 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 52 | "\n", 53 | "# get dataset\n", 54 | "import pickle as pkl\n", 55 | "from dsets.imagenet import dset\n", 56 | "imnet_dict = pkl.load(open('../dsets/imagenet/imnet_dict.pkl', 'rb')) # contains 6 images (keys: 9, 10, 34, 20, 36, 32)\n", 57 | "\n", 58 | "# get model\n", 59 | "from torchvision import models\n", 60 | "model = models.vgg16(pretrained=True).to(device).eval()\n", 61 | "model_type='vgg' # alexnet, vgg" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "# fig 3 - recreate hockey example" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# hyperparameters\n", 78 | "num_iters = 5 # number of iterations to agglomerate before merging remaning blobs (fig uses 4)\n", 79 | "percentile_include = 95 # values above this percentile will be added at each iteration (fig uses 95)\n", 80 | "method = 'cd' # method to rank importance ('cd' works best, 'build_up' or 'occlusion' are simplest)\n", 81 | "sweep_dim = 14 # importances are calculated by blocks of sweep_dim x sweep_dim (14 yields good results for imagenet)\n", 82 | "im_torch, im_orig, lab_num_correct = imnet_dict[9] # the hockey example\n", 83 | "lab_pred = np.argmax(dset.pred_ims(model, 
np.copy(im_orig)))\n", 84 | "\n", 85 | "lists = acd.agg_2d.agglomerate(model, dset.pred_ims, percentile_include,\n", 86 | " method, sweep_dim, im_orig, lab_pred,\n", 87 | " num_iters=num_iters, im_torch=im_torch, model_type=model_type) \n", 88 | "\n", 89 | "\n", 90 | "# visualize\n", 91 | "plt.figure(figsize=(12, 5), facecolor='white', dpi=100)\n", 92 | "rows = 3 \n", 93 | "num_ims = len(lists['scores_list'])\n", 94 | "\n", 95 | "# original plots\n", 96 | "ind, labs = viz.visualize_original_preds(im_orig, lab_num_correct, \n", 97 | " lists['comp_scores_raw_list'], lists['scores_orig_raw'],\n", 98 | " subplot_rows=rows, dset=dset)\n", 99 | "\n", 100 | "# comp plots\n", 101 | "viz.visualize_ims_list(lists['comps_list'],\n", 102 | " title='Chosen blobs',\n", 103 | " subplot_row=1, subplot_rows=rows, colorbar=False, im_orig=im_orig, plot_overlay=True)\n", 104 | "\n", 105 | "# dict plots\n", 106 | "viz.visualize_dict_list_top(lists['comp_scores_raw_list'], method,\n", 107 | " subplot_row=2, subplot_rows=rows, ind=ind, labs=labs, use_orig_top=True)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "# fig s1 - compare different scores" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "def get_diff_scores(im_torch, im_orig, label_num, model, preds, sweep_dim):\n", 124 | " scores = []\n", 125 | "\n", 126 | " # cd\n", 127 | " method = 'cd'\n", 128 | " tiles = acd.tiling_2d.gen_tiles(im_orig, fill=0, method=method, sweep_dim=sweep_dim)\n", 129 | " scores_cd = acd.get_scores_2d(model, method=method, ims=tiles, \n", 130 | " im_torch=im_torch, model_type=model_type, device=device)\n", 131 | " scores.append(scores_cd)\n", 132 | " for method in ['occlusion', 'build_up']: # 'build_up'\n", 133 | " tiles_break = acd.tiling_2d.gen_tiles(im_orig, fill=0, method=method, sweep_dim=sweep_dim)\n", 134 | " preds_break = 
acd.get_scores_2d(model, method=method, ims=tiles_break, \n", 135 | " im_torch=im_torch, pred_ims=dset.pred_ims)\n", 136 | " if method == 'occlusion':\n", 137 | " preds_break += preds\n", 138 | " scores.append(np.copy(preds_break))\n", 139 | " \n", 140 | " # get integrated gradients scores\n", 141 | " scores.append(acd.ig_scores_2d(model, im_torch, num_classes=1000, \n", 142 | " im_size=224, sweep_dim=sweep_dim, ind=[label_num], device=device))\n", 143 | " \n", 144 | " return scores\n", 145 | "\n", 146 | "\n", 147 | "\n", 148 | "# pick an image + get scores\n", 149 | "im_nums = [34, 20, 36, 32] # 34: screen, 20: snake, 36: trash can, 32: crane\n", 150 | "sweep_dim = 14\n", 151 | "# sweep_dim = 56\n", 152 | "fig = plt.figure(figsize=(10, 8), facecolor='white')\n", 153 | "\n", 154 | "for x, im_num in enumerate(im_nums):\n", 155 | "\n", 156 | " im_torch, im_orig, label_num = imnet_dict[im_num] # remember torch is H x W x C\n", 157 | " print('lab', dset.lab_dict[label_num])\n", 158 | " # viz.visualize_ims_tiled(tiling.gen_tiles(im_orig, fill=np.nan))\n", 159 | " preds = dset.pred_ims(model, im_orig).flatten()\n", 160 | " ind = np.argpartition(preds, -8)[-8:] # top-scoring indexes\n", 161 | " ind = ind[np.argsort(preds[ind])][::-1] # sort the indexes\n", 162 | " scores = get_diff_scores(im_torch, im_orig, label_num, model, preds, sweep_dim)\n", 163 | "\n", 164 | " # plot raw image\n", 165 | " num_rows = len(im_nums)\n", 166 | " num_cols = len(scores) + 1\n", 167 | " plt.subplot(num_rows, num_cols, 1 + x * num_cols)\n", 168 | " plt.imshow(im_orig)\n", 169 | "# plt.axis('off')\n", 170 | " plt.gca().xaxis.set_visible(False)\n", 171 | " plt.yticks([])\n", 172 | " if x == 0:\n", 173 | " plt.title('Image', fontsize=16)\n", 174 | "\n", 175 | " if x == 0:\n", 176 | " plt.ylabel('CRT screen', fontsize=15)\n", 177 | " elif x == 1:\n", 178 | " plt.ylabel('Green mamba', fontsize=15)\n", 179 | " elif x == 2:\n", 180 | " plt.ylabel('Trash can', fontsize=15)\n", 181 | " elif x == 
3:\n", 182 | " plt.ylabel('Crane', fontsize=15)\n", 183 | "\n", 184 | "\n", 185 | " # plot scores\n", 186 | " vmax = max([np.max(scores[i]) for i in range(len(scores))])\n", 187 | " vmin = min([np.min(scores[i]) for i in range(len(scores))])\n", 188 | " vabs = max(abs(vmax), abs(vmin))\n", 189 | " for i, tit in enumerate(['CD', 'Occlusion', 'Build-Up', 'IG']):\n", 190 | " plt.subplot(num_rows, num_cols, 2 + i + x * num_cols)\n", 191 | " if i == 0:\n", 192 | " plt.ylabel('pred: ' + dset.lab_dict[ind[0]][:16] + '...', fontsize=15) \n", 193 | " if x == 0:\n", 194 | " plt.title(tit, fontsize=16)\n", 195 | " p = viz.visualize_preds(scores[i], num=label_num, cbar=False) #axis_off=False, vabs=vabs)\n", 196 | " plt.xticks([])\n", 197 | " plt.yticks([])\n", 198 | "# divider = make_axes_locatable(plt.gca())\n", 199 | "# cax = divider.append_axes(\"right\", size=\"2%\", pad=0.05)\n", 200 | "# plt.colorbar(p, cax=cax)\n", 201 | " \n", 202 | "plt.tight_layout()\n", 203 | "plt.subplots_adjust(wspace=0, hspace=0)\n", 204 | "plt.show()" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "# cd propagation fig\n", 212 | "Tracks CD scores layer-by-layer through VGG" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "def vgg_track(im_torch, model):\n", 222 | " mods = list(model.modules())[2:]\n", 223 | " scores = []\n", 224 | " x = im_torch.clone()\n", 225 | " for i in range(30):\n", 226 | " x = mods[i](x)\n", 227 | " if i in [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]: # all the conv 2d layers\n", 228 | " scores.append(x.clone())\n", 229 | " return scores\n", 230 | "\n", 231 | "\n", 232 | "# calculate cd score\n", 233 | "f = 16\n", 234 | "# im_num = 23\n", 235 | "im_torch, im_orig, lab_num_correct = imnet_dict[9] # the hockey example\n", 236 | "# im_torch, im_orig, label_num = dset.get_im_and_label(im_num) # remember torch is H x 
W x C\n", 237 | "\n", 238 | "# set up blob\n", 239 | "blob = np.zeros((3, 224, 224))\n", 240 | "blob[:, 60:80, 150:200] = 1\n", 241 | "blob[:, 60:120, 180:200] = 1\n", 242 | "\n", 243 | "rel1, irrel1, scores = acd.cd_track_vgg(blob, im_torch, model)\n", 244 | "scores = [scores[i] for i in np.arange(0, len(scores), 2)] # pick every other\n", 245 | "\n", 246 | "# calculate build_up score\n", 247 | "im_torch2, im_orig, label_num = dset.get_im_and_label(im_num) # remember torch is H x W x C\n", 248 | "im_torch2[0, :, :60, :] = 0\n", 249 | "im_torch2[0, :, :, :150] = 0\n", 250 | "im_torch2[0, :, 80:120, :180] = 0\n", 251 | "im_torch2[0, :, :, 200:] = 0\n", 252 | "im_torch2[0, :, 120:, :] = 0\n", 253 | "scores2 = vgg_track(im_torch2, model)\n", 254 | "scores2 = [scores2[i] for i in np.arange(0, len(scores2), 2)] # pick every other\n", 255 | "\n", 256 | "# calculate occlusion score\n", 257 | "im_torch, im_orig, label_num = dset.get_im_and_label(im_num) # remember torch is H x W x C\n", 258 | "\n", 259 | "im_torch[0, :, 60:80, 150:200] = 0\n", 260 | "im_torch[0, :, 60:120, 180:200] = 0\n", 261 | "\n", 262 | "scores3 = vgg_track(im_torch, model)\n", 263 | "scores3 = [scores3[i] for i in np.arange(0, len(scores3), 2)] # pick every other\n", 264 | "\n", 265 | "plt.figure(figsize=(16, 8))\n", 266 | "num_rows = 4\n", 267 | "num_cols = len(scores) + 1\n", 268 | "\n", 269 | "# show original ims\n", 270 | "# plt.subplot2grid((num_rows, num_cols), (0, 0), rowspan=num_rows)\n", 271 | "# plt.gcf().text(0.18, 0.85, 'Blob', fontsize=14)\n", 272 | "# plt.gcf().text(0.16, 0.15, 'Non-blob', fontsize=14)\n", 273 | "plt.subplot(num_rows, num_cols, 1)\n", 274 | "plt.imshow(im_orig)\n", 275 | "blob_show = np.copy(blob[0])\n", 276 | "blob_show[blob_show == 0] = np.nan\n", 277 | "plt.imshow(blob_show, alpha=0.6, cmap='Greens')\n", 278 | "plt.ylabel('CD $\\\\beta$', fontsize=f)\n", 279 | "plt.xticks([])\n", 280 | "plt.yticks([])\n", 281 | "\n", 282 | "\n", 283 | "plt.subplot(num_rows, num_cols, 
num_cols + 1)\n", 284 | "plt.imshow(im_orig, cmap='Greens')\n", 285 | "plt.imshow(blob_show, alpha=0.6, cmap='Greens')\n", 286 | "plt.ylabel('CD $\\\\gamma$', fontsize=f)\n", 287 | "plt.xticks([])\n", 288 | "plt.yticks([])\n", 289 | "\n", 290 | "plt.subplot(num_rows, num_cols, num_cols * 3 + 1)\n", 291 | "im_blob = np.copy(im_orig)\n", 292 | "blob_idxs = blob.astype(np.int).transpose((1, 2, 0))\n", 293 | "im_blob[blob_idxs] = 0\n", 294 | "im_blob[60:80, 150:200] = 0\n", 295 | "im_blob[60:120, 180:200] = 0\n", 296 | "plt.imshow(im_blob)\n", 297 | "plt.ylabel('Occlusion', fontsize=f)\n", 298 | "plt.xticks([])\n", 299 | "plt.yticks([])\n", 300 | "\n", 301 | "plt.subplot(num_rows, num_cols, num_cols * 2 + 1)\n", 302 | "im_blob = np.copy(im_orig)\n", 303 | "blob_idxs = blob.astype(np.int).transpose((1, 2, 0))\n", 304 | "# im_blob[blob_idxs] = 0\n", 305 | "# im_blob[60:80, 150:200] = 0\n", 306 | "# im_blob[60:120, 180:200] = 0\n", 307 | "\n", 308 | "im_blob[:60, :] = 0\n", 309 | "im_blob[:, :150] = 0\n", 310 | "im_blob[80:120, :180] = 0\n", 311 | "im_blob[:, 200:] = 0\n", 312 | "im_blob[120:, :] = 0\n", 313 | "plt.imshow(im_blob)\n", 314 | "plt.ylabel('Build up', fontsize=f)\n", 315 | "plt.xticks([])\n", 316 | "plt.yticks([])\n", 317 | "\n", 318 | "# show propagating images\n", 319 | "for i in range(len(scores)):\n", 320 | " rel1, irrel1 = scores[i]\n", 321 | " x = np.squeeze(rel1.data.cpu().numpy())\n", 322 | " x = np.sum(np.abs(x), axis=0)\n", 323 | "\n", 324 | " y = np.squeeze(irrel1.data.cpu().numpy())\n", 325 | " y = np.sum(np.abs(y), axis=0)\n", 326 | "\n", 327 | " rel2 = scores2[i]\n", 328 | " z = np.squeeze(rel2.data.cpu().numpy())\n", 329 | " z = np.sum(np.abs(z), axis=0)\n", 330 | "\n", 331 | " rel3 = scores3[i]\n", 332 | " zz = np.squeeze(rel3.data.cpu().numpy())\n", 333 | " zz = np.sum(np.abs(zz), axis=0)\n", 334 | " \n", 335 | " vmax1, vmin1 = max(np.max(x), np.max(z)), min(np.min(x), np.min(z))\n", 336 | " vmax2, vmin2 = max(np.max(y), np.max(zz)), 
min(np.min(y), np.min(zz))\n", 337 | "\n", 338 | " # top row\n", 339 | " plt.subplot(num_rows, num_cols, i + 2)\n", 340 | " plt.imshow(x, interpolation='None', vmin=vmin1, vmax=vmax1)\n", 341 | " plt.axis('off')\n", 342 | " plt.title('Conv ' + str(2*i+1), fontsize=f)\n", 343 | " \n", 344 | " # plot 2\n", 345 | " plt.subplot(num_rows, num_cols, num_cols + i + 2)\n", 346 | " plt.imshow(y, interpolation='None', vmin=vmin2, vmax=vmax2)\n", 347 | " plt.axis('off')\n", 348 | "\n", 349 | " # plot 3\n", 350 | " plt.subplot(num_rows, num_cols, num_cols * 2 + i + 2)\n", 351 | " plt.imshow(z, interpolation='None', cmap='viridis', vmin=vmin1, vmax=vmax1)\n", 352 | " plt.axis('off') \n", 353 | "\n", 354 | " # plot 4\n", 355 | " plt.subplot(num_rows, num_cols, num_cols * 3 + i + 2)\n", 356 | " plt.imshow(zz, interpolation='None', cmap='viridis', vmin=vmin2, vmax=vmax2)\n", 357 | " plt.axis('off')\n", 358 | "\n", 359 | "plt.subplots_adjust(hspace=0, wspace=0)\n", 360 | "plt.show()" 361 | ] 362 | } 363 | ], 364 | "metadata": { 365 | "anaconda-cloud": {}, 366 | "kernelspec": { 367 | "display_name": "Python 3", 368 | "language": "python", 369 | "name": "python3" 370 | }, 371 | "language_info": { 372 | "codemirror_mode": { 373 | "name": "ipython", 374 | "version": 3 375 | }, 376 | "file_extension": ".py", 377 | "mimetype": "text/x-python", 378 | "name": "python", 379 | "nbconvert_exporter": "python", 380 | "pygments_lexer": "ipython3", 381 | "version": "3.8.3" 382 | }, 383 | "pycharm": { 384 | "stem_cell": { 385 | "cell_type": "raw", 386 | "source": [], 387 | "metadata": { 388 | "collapsed": false 389 | } 390 | } 391 | } 392 | }, 393 | "nbformat": 4, 394 | "nbformat_minor": 4 395 | } -------------------------------------------------------------------------------- /reproduce_figs/readme.md: -------------------------------------------------------------------------------- 1 | **This folder contains notebooks to reproduce / extend the results in the paper.** 2 | 3 | The [text 
notebook](text_fig2.ipynb) contains code to load a pretrained model on the SST dataset. Then, you can give it different sentences and observe the hierarchical interpretations it produces. 4 | 5 | ![](figs/fig_2.png) 6 | 7 | 8 | # mnist 9 | 10 | The [mnist notebook](mnist_figs3,s4.ipynb) contains code for analyzing the mnist dataset with ACD. Running this notebook will download the MNIST dataset, if you do not already have it. 11 | 12 | - note: adversarial attacks require the `foolbox` and `randomgen` python packages (installable via pip) 13 | - 'boundary attack' is currently commented out as a result of an error in the `randomgen` package 14 | 15 | ![](figs/fig_s3.png) 16 | 17 | 18 | # imagenet 19 | The [imagenet notebook](imagenet_fig3,s1,s2.ipynb) contains code for using CD on CNN models. It comes with a pickle file containing a few imagenet images for testing out. 20 | 21 | - note: redoing the imagenet results will be very slow if using cpu instead of gpu 22 | 23 | ![](figs/fig_s2.png) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | import setuptools 3 | 4 | with open("readme.md", "r") as fh: 5 | long_description = fh.read() 6 | 7 | setup( 8 | name='acd', 9 | version='0.0.2', 10 | author="Chandan Singh", 11 | description="Hierarchical interpretatations and contextual decomposition in pytorch", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/csinva/hierarchical-dnn-interpretations", 15 | packages=setuptools.find_packages(), 16 | install_requires=[ 17 | 'scipy', 18 | 'numpy', 19 | 'scikit-image', 20 | 'torch', 21 | 'tqdm', 22 | ], 23 | classifiers=[ 24 | "Programming Language :: Python :: 3", 25 | "License :: OSI Approved :: MIT License", 26 | "Operating System :: OS Independent", 27 | ], 28 | ) 29 | 
# --------------------------------------------------------------------------------
# /tests/test_cd.py:
# --------------------------------------------------------------------------------
import pickle as pkl
import sys
import warnings

import numpy as np
import torch

import acd

warnings.filterwarnings("ignore")


def test_sst(device='cpu'):
    """Check contextual-decomposition (CD) scores on the pretrained SST model.

    Asserts that
    * with the whole sentence marked relevant, the CD score matches the
      prediction (with the output-layer bias subtracted) and the irrelevant
      score is ~0;
    * for a partial span, relevant + irrelevant scores sum to that prediction.

    Args:
        device: torch device string the model and the batch are placed on.
    """
    # load the model and data
    sys.path.append('../dsets/sst')
    with open('../dsets/sst/sst_vocab.pkl', 'rb') as f:
        sst_pkl = pkl.load(f)
    # NOTE(review): relative paths mean this must be run from inside tests/.
    model = torch.load('../dsets/sst/sst.model', map_location=device)
    model.device = device

    # text and label
    sentence = ['a', 'great', 'ensemble', 'cast', 'ca', 'n\'t', 'lift', 'this', 'heartfelt', 'enterprise', 'out', 'of',
                'the', 'familiar', '.']  # note this is a real example from the dataset

    def batch_from_str_list(s):
        """Wrap a list of tokens in a minimal object exposing a `.text` index tensor."""

        # form class to hold data
        class B:
            text = torch.zeros(1).to(device)

        batch = B()
        # vocabulary indexes as a column: shape (len(s), 1)
        nums = np.expand_dims(np.array([sst_pkl['stoi'][x] for x in s]).transpose(), axis=1)
        batch.text = torch.LongTensor(nums).to(device)
        return batch

    # prepare inputs
    batch = batch_from_str_list(sentence)
    preds = model(batch).data.cpu().numpy()[0]  # predict

    # check that full sentence = prediction
    # (.cpu() so the bias can be converted to numpy on any device)
    preds = preds - model.hidden_to_label.bias.detach().cpu().numpy()
    cd_score, irrel_scores = acd.cd_text(batch, model, start=0, stop=len(sentence), return_irrel_scores=True)
    assert (np.allclose(cd_score, preds, atol=1e-2))
    assert (np.allclose(irrel_scores, irrel_scores * 0, atol=1e-2))

    # check that rel + irrel = prediction for another subset
    cd_score, irrel_scores = acd.cd_text(batch, model, start=3, stop=len(sentence), return_irrel_scores=True)
    assert (np.allclose(cd_score + irrel_scores, preds, atol=1e-2))


def test_mnist(device='cuda'):
    """Check CD scores on the pretrained MNIST model using a random input image.

    Asserts that a full-image mask reproduces the logits with ~0 irrelevant
    score, and that relevant + irrelevant scores sum to the logits when only
    the top half of the image is marked relevant.

    Args:
        device: torch device string used for the model and the input
            (previously ignored because it was overwritten with 'cuda').
    """
    # load the dataset
    sys.path.append('../dsets/mnist')
    import dsets.mnist.model
    im_torch = torch.randn(1, 1, 28, 28).to(device)

    # load the model
    model = dsets.mnist.model.Net().to(device)
    model.load_state_dict(torch.load('../dsets/mnist/mnist.model', map_location=device))
    model = model.eval()

    # check that full image mask = prediction
    preds = model.logits(im_torch).cpu().detach().numpy()
    cd_score, irrel_scores = acd.cd(im_torch, model, mask=np.ones((1, 1, 28, 28)), model_type='mnist', device=device)
    cd_score = cd_score.cpu().detach().numpy()
    irrel_scores = irrel_scores.cpu().detach().numpy()
    assert (np.allclose(cd_score, preds, atol=1e-2))
    assert (np.allclose(irrel_scores, irrel_scores * 0, atol=1e-2))

    # check that rel + irrel = prediction for another subset (top 14 rows relevant)
    mask = np.zeros((28, 28))
    mask[:14] = 1
    cd_score, irrel_scores = acd.cd(im_torch, model, mask=mask, model_type='mnist', device=device)
    cd_score = cd_score.cpu().detach().numpy()
    irrel_scores = irrel_scores.cpu().detach().numpy()
    assert (np.allclose(cd_score + irrel_scores, preds, atol=1e-2))


def test_imagenet(device='cuda', arch='vgg'):
    """Check CD scores on a pretrained torchvision model using a random input image.

    Args:
        device: torch device string used for the model and the input.
        arch: one of 'vgg', 'alexnet', 'resnet18'; also passed to `acd.cd`
            as `model_type`.

    Raises:
        ValueError: if `arch` is not one of the supported names.
    """
    # get dataset
    from torchvision import models
    with open('../dsets/imagenet/imnet_dict.pkl', 'rb') as f:
        # contains 6 images (keys: 9, 10, 34, 20, 36, 32); loaded but not used by the checks below
        imnet_dict = pkl.load(f)

    # get model and image
    constructors = {
        'vgg': models.vgg16,
        'alexnet': models.alexnet,
        'resnet18': models.resnet18,
    }
    if arch not in constructors:
        raise ValueError('unsupported arch %r, expected one of %s' % (arch, sorted(constructors)))
    model = constructors[arch](pretrained=True).to(device).eval()
    im_torch = torch.randn(1, 3, 224, 224).to(device)

    # get predictions
    preds = model(im_torch).cpu().detach().numpy()

    # check that rel + irrel = prediction for another subset
    # (the mask starts from zeros; starting from ones made this identical to the full-mask check)
    mask = np.zeros((1, 3, 224, 224))
    mask[:, :, :14] = 1
    cd_score, irrel_scores = acd.cd(im_torch, model, mask=mask, device=device, model_type=arch)
    cd_score = cd_score.cpu().detach().numpy()
    irrel_scores = irrel_scores.cpu().detach().numpy()
    assert (np.allclose(cd_score + irrel_scores, preds, atol=1e-2))

    # check that full image mask = prediction
    cd_score, irrel_scores = acd.cd(im_torch, model, mask=np.ones((1, 3, 224, 224)), device=device, model_type=arch)
    cd_score = cd_score.cpu().detach().numpy()
    irrel_scores = irrel_scores.cpu().detach().numpy()
    assert (np.allclose(cd_score, preds, atol=1e-2))
    assert (np.allclose(irrel_scores, irrel_scores * 0, atol=1e-2))


if __name__ == '__main__':
    with torch.no_grad():
        print('testing sst...')
        test_sst()
        print('testing mnist...')
        test_mnist()
        print('testing imagenet vgg...')
        test_imagenet(arch='vgg')
        print('testing imagenet alexnet...')
        test_imagenet(arch='alexnet')
        print('testing imagenet resnet18...')
        test_imagenet(arch='resnet18')
        print('all tests passed!')

# --------------------------------------------------------------------------------
# /visualization/viz_1d.py:
# --------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import sys
from os.path import join as oj
import math
import matplotlib.colors as colors


def visualize_scores(scores, label, text_orig, score_orig, sweep_dim=1, method='break_down'):
    """Draw a horizontal bar chart of per-token scores on a new figure.

    Args:
        scores: either an object supporting `scores.data.cpu().numpy()[:, label]`
            (e.g. a torch tensor) or an array that is plotted as-is.
        label: column of `scores` to plot; 0 is titled "pos", anything else "neg".
        text_orig: sequence of tokens used as bar labels; the first
            `sweep_dim - 1` tokens are dropped.
        score_orig: unused.
        sweep_dim: see `text_orig`.
        method: prefix for the plot title.
    """
    plt.figure(figsize=(2, 10))
    try:
        p = scores.data.cpu().numpy()[:, label]
    except (AttributeError, IndexError):
        # not a 2-D torch-like object: plot the values directly
        p = scores

    # plot with labels
    text_orig = text_orig[sweep_dim - 1:]  # todo - don't do this, deal with edges better
    plt.barh(range(p.size), p[::-1], align='center', tick_label=text_orig[::-1])
    c = "pos" if label == 0 else "neg"
    plt.title(method + ' class ' + c + '\n(higher is more important)')  # pretty sure 1 is positive, 2 is negative


def print_scores(lists, text_orig, num_iters):
    """Print the original score, the tokens, and the scored components of each iteration.

    Args:
        lists: dict with keys 'score_orig', 'scores_list', 'comps_list'
            (per-iteration integer component id for every token) and
            'comp_scores_list' (per-iteration dict: component id -> score).
        text_orig: sequence of tokens.
        num_iters: iterations 1 .. num_iters - 1 are printed.
    """
    text_orig = np.array(text_orig)
    print('score_orig', lists['score_orig'])

    print(text_orig)
    print(lists['scores_list'][0])

    # print out blobs and corresponding scores
    for i in range(1, num_iters):
        print('iter', i)
        comps = lists['comps_list'][i]
        comp_scores_list = lists['comp_scores_list'][i]

        # look scores up by component id (not by dict order) and sort in decreasing order
        comps_with_scores = sorted(
            ((comp_num, comp_scores_list[comp_num])
             for comp_num in range(1, np.max(comps) + 1)
             if comp_num in comp_scores_list),
            key=lambda x: x[1], reverse=True)

        for comp_num, comp_score in comps_with_scores:
            print(comp_num, '\t%.3f, %s' % (comp_score, str(text_orig[comps == comp_num])))


def word_heatmap(text_orig, lists, label_pred, label, method=None, subtract=True, data=None, fontsize=9):
    """Plot a (iteration x word) heatmap of CD scores with the words drawn in the cells.

    Row 0 holds the per-word scores; row i > 0 holds, for every word, the
    score of the component it belongs to at iteration i. Cells without a
    score are set to 0.

    Args:
        text_orig: sequence of words.
        lists: dict with keys 'comps_list', 'scores_list', 'comp_scores_list'
            (same layout as in `print_scores`).
        label_pred: if equal to 1, all scores are negated before plotting.
        label, method, subtract: unused.
        data: optional precomputed (iteration x word) array; it is copied, so
            the caller's array is never modified.
        fontsize: font size of the words.
    """
    text_orig = np.array(text_orig)
    num_words = text_orig.size
    num_iters = len(lists['comps_list'])

    # populate data
    if data is None:
        data = np.full((num_iters, num_words), np.nan)
        data[0, :] = lists['scores_list'][0]
        for i in range(1, num_iters):
            comps = lists['comps_list'][i]
            comp_scores_list = lists['comp_scores_list'][i]

            for comp_num in range(1, np.max(comps) + 1):
                data[i][comps == comp_num] = comp_scores_list[comp_num]
    else:
        # work on a copy: the NaN fill and sign flip below are in-place
        data = np.array(data, dtype=float)

    data[np.isnan(data)] = 0
    plt.figure(figsize=(16, 1 if num_iters == 1 else 3), dpi=300)

    class MidpointNormalize(colors.Normalize):
        """Piecewise-linear norm mapping vmin -> 0, midpoint -> 0.5, vmax -> 1."""

        def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
            self.midpoint = midpoint
            colors.Normalize.__init__(self, vmin, vmax, clip)

        def __call__(self, value, clip=None):
            x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
            return np.ma.masked_array(np.interp(value, x, y))

    cmap = plt.get_cmap('RdBu')
    if label_pred == 1:
        data *= -1
    abs_lim = max(abs(np.nanmax(data)), abs(np.nanmin(data)))
    if abs_lim == 0:
        # all-zero scores: keep the interpolation knots strictly increasing
        abs_lim = 1.0

    c = plt.pcolor(data,
                   edgecolors='k',
                   linewidths=0,
                   norm=MidpointNormalize(vmin=abs_lim * -1, midpoint=0., vmax=abs_lim),
                   cmap=cmap)

    def show_values(pc, text_orig, data, fontsize, fmt="%s", **kw):
        """Write each word at the centre of its cell, in black or white depending on the cell colour."""
        pc.update_scalarmappable()
        ax = pc.axes

        for p, color in zip(pc.get_paths(), pc.get_facecolors()):
            x, y = p.vertices[:-2, :].mean(0)
            # pick color for text: black on light cells, white on dark ones
            if np.all(color[:3] > 0.5):
                color = (0.0, 0.0, 0.0)
            else:
                color = (1.0, 1.0, 1.0)
            x_ind = math.floor(x)
            y_ind = math.floor(y)

            # only the first row shows words whose score is exactly 0
            if y_ind == 0 or data[y_ind, x_ind] != 0:
                ax.text(x, y, fmt % text_orig[x_ind],
                        ha="center", va="center",
                        color=color, fontsize=fontsize, **kw)

    show_values(c, text_orig, data, fontsize)
    cb = plt.colorbar(c, extend='both')
    cb.outline.set_visible(False)
    plt.xlim((0, num_words))
    plt.ylim((0, num_iters))
    plt.yticks([])
    plt.plot([0, num_words], [1, 1], color='black')  # line separating the word row from the rest
    plt.xticks([])

    cb.ax.set_title('CD score')

# --------------------------------------------------------------------------------
# /visualization/viz_2d.py:
# --------------------------------------------------------------------------------
import copy

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from cycler import cycler
from mpl_toolkits.axes_grid1 import make_axes_locatable
import math
from skimage.transform import resize
import random


def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    The N sample positions are shuffled with a fixed seed so that consecutive
    bins do not get consecutive colors; the first bin is then pinned to the
    bottom (position 0) of the base map.

    Args:
        N: number of bins.
        base_cmap: colormap name or object accepted by `plt.get_cmap`.

    Returns:
        (color_list, cmap): the N RGBA colors and a colormap built from them.
    """
    base = plt.get_cmap(base_cmap)
    nums = np.linspace(1 / N, 1, N)
    # shuffle in place so colors aren't consecutive, 9 for imagenet figs, now set for mnist figs
    random.Random(10).shuffle(nums)
    nums[0] = 0
    color_list = base(nums)
    cmap_name = base.name + str(N)
    # call from_list on the class so base maps without that method also work
    return color_list, matplotlib.colors.LinearSegmentedColormap.from_list(cmap_name, color_list, N)


def _top_k_indices(scores, k):
    """Return the indexes of the (at most) k largest entries of a 1-D array, best first."""
    scores = np.asarray(scores)
    k = min(k, scores.size)
    if k <= 0:
        return np.array([], dtype=int)
    top = np.argpartition(scores, -k)[-k:]  # top-scoring indexes
    return top[np.argsort(scores[top])][::-1]  # sort the indexes


# cmap (copied so set_bad never alters the globally registered 'RdBu')
cmap = copy.copy(plt.get_cmap('RdBu'))
cmap.set_bad(color='#60ff16')  # bright green
N_COLORS = 11
cmap_comp = discrete_cmap(N_COLORS, 'jet')[1]
cmap_comp.set_under(color='#ffffff')  # transparent for lowest value


def visualize_ims_tiled(ims_tiled):
    """Show up to the first 25 images of `ims_tiled` on a 5 x 5 subplot grid."""
    num_ims = 25
    D = 5
    for i, im in enumerate(ims_tiled[:D * (num_ims // D)]):
        plt.subplot(D, num_ims // D, 1 + i)
        plt.imshow(im, cmap=cmap, interpolation='None')
        plt.axis('off')
    plt.subplots_adjust(wspace=None, hspace=None)


def visualize_preds(preds, num, N=28, prev_im=None, cbar=True, vabs=None, axis_off=True):
    """Show column `num` of `preds` as a square heatmap on the current axes.

    Args:
        preds: 2-D array indexed [pixel, class]; the number of rows must be a
            perfect square. It is not modified.
        num: column (class) to show.
        N: ignored; the side length is recomputed as sqrt(preds.shape[0]).
        prev_im: optional index/mask of cells to blank out (set to NaN).
        cbar: whether to append a colorbar to the right of the axes.
        vabs: color limits are (-vabs, vabs); defaults to the largest
            absolute non-NaN value.
        axis_off: whether to hide the axes.

    Returns:
        The image object returned by `plt.imshow`.
    """
    N = int(math.sqrt(preds.shape[0]))
    # copy so the NaN write below cannot reach the caller's array
    preds = np.array(preds[:, num], dtype=float).reshape(N, N)
    if prev_im is not None:
        preds[prev_im] = np.nan

    ax = plt.gca()

    if vabs is None:
        vabs = max(abs(np.nanmin(preds)), abs(np.nanmax(preds)))
    p = plt.imshow(preds, cmap=cmap,
                   vmin=-1 * vabs, vmax=vabs, interpolation='None')
    if axis_off:
        plt.axis('off')

    # colorbar
    if cbar:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2%", pad=0.05)
        plt.colorbar(p, cax=cax)

    return p


def visualize_batch_preds(preds, prev_im=None, N=28, im_num_start=0):
    """Place `preds` into a zero-filled N x N image starting at flat offset `im_num_start` and show it.

    Args:
        preds: array-like of values; flattened before insertion.
        prev_im: optional index/mask of cells to blank out (set to NaN).
        N: side length of the image.
        im_num_start: flat (row-major) position of the first value.

    Returns:
        The N x N array that was shown.
    """
    flat = np.asarray(preds).ravel()
    preds_reshaped = np.zeros(N * N)
    preds_reshaped[im_num_start: im_num_start + flat.size] = flat
    preds_reshaped = preds_reshaped.reshape(N, N)
    if prev_im is not None:
        preds_reshaped[prev_im] = np.nan
    plt.imshow(preds_reshaped)
    return preds_reshaped


def visualize_ims_list(ims_list, title='', cmap_new=None, subplot_row=None, subplot_rows=3, colorbar=True, im_orig=None,
                       plot_overlay=False, mturk=False, num_ims=None, comp_scores_raw=None, lab_num_correct=None,
                       skip_first=False, mnist=False):
    """Plot a row of images (score maps or component segmentations).

    NOTE(review): the nesting of this function was reconstructed from a
    whitespace-collapsed source; confirm against the repository, in particular
    that the axis/colorbar calls run once per subplot.

    Args:
        ims_list: list of 2-D arrays, one per subplot.
        title: unused.
        cmap_new: 'redwhiteblue' plots every image with the diverging `cmap`
            and limits shared across `ims_list`; anything else plots
            component maps.
        subplot_row: row of the subplot grid to draw into; if None a new
            figure is created and row 1 is used.
        subplot_rows: number of rows of the subplot grid.
        colorbar: whether to call `plt.colorbar()` for each subplot.
        im_orig: background image shown when `plot_overlay` and not `mnist`.
        plot_overlay: overlay the component map (resized to 224 x 224) on
            `im_orig` instead of plotting it directly.
        mturk: color components by `comp_scores_raw[i][comp][lab_num_correct]`
            instead of by component id.
        num_ims: number of columns; defaults to len(ims_list).
        comp_scores_raw: per-image dict component id -> scores (used if mturk).
        lab_num_correct: index into each component's scores (used if mturk).
        skip_first: also draw the overlay for the first image.
        mnist: shifts every subplot one position left and skips the background.
    """
    if subplot_row is None:
        plt.figure(figsize=(12, 2), facecolor='white')
        subplot_row = 1
    if num_ims is None:
        num_ims = len(ims_list)
    num_shown = min(num_ims, len(ims_list))

    # symmetric limits shared by all images (ignoring NaNs); loop-invariant
    vabs_shared = None
    if cmap_new == 'redwhiteblue' and num_shown > 0:
        finite_vals = [im[np.logical_not(np.isnan(im))] for im in ims_list]
        vmin = min(np.min(vals) for vals in finite_vals)
        vmax = max(np.max(vals) for vals in finite_vals)
        vabs_shared = max(abs(vmin), abs(vmax))

    for i in range(num_shown):
        plt.subplot(subplot_rows, num_ims, num_ims * subplot_row + i + 1 - mnist)
        if cmap_new == 'redwhiteblue':
            plt.imshow(ims_list[i], cmap=cmap,
                       vmin=-1 * vabs_shared, vmax=vabs_shared, interpolation='nearest')
        elif plot_overlay:
            # color images
            if not mnist:
                plt.imshow(im_orig)  # plot image as background
            # overlay component comps
            if i > 0 or skip_first:
                if mturk:
                    # need to map this to values of comps not comp_num
                    im_nums = np.copy(ims_list[i]).astype(np.float32)
                    comp_ids = im_nums.astype(int)
                    comp_to_score = comp_scores_raw[i]
                    for comp_num in np.unique(comp_ids[comp_ids > 0]):
                        im_nums[comp_ids == comp_num] = comp_to_score[int(comp_num)][lab_num_correct]

                    # NOTE(review): raw scores are passed to the colormap without normalisation
                    im = cmap(im_nums)
                    im[im[:, :, 1] == 0, 3] = 0  # fully transparent where the green channel is 0

                    vmin = min([comp_to_score[comp_num][lab_num_correct]
                                for comp_to_score in comp_scores_raw[1:]
                                for comp_num in comp_to_score.keys()])
                    vmax = max([comp_to_score[comp_num][lab_num_correct]
                                for comp_to_score in comp_scores_raw[1:]
                                for comp_num in comp_to_score.keys()])
                    vabs = max(abs(vmin), abs(vmax))
                else:
                    im_seg = ims_list[i]
                    im = cmap_comp(im_seg)
                    im[np.asarray(im_seg) == 0, 3] = 0  # background (component 0) is transparent
                map_reshaped = resize(im, (224, 224, 4), mode='symmetric', order=0)
                if mturk:
                    plt.imshow(map_reshaped, alpha=0.9, interpolation='None', vmin=-1 * vabs, vmax=vabs)
                else:
                    plt.imshow(map_reshaped, alpha=0.7)
        else:
            # not color
            plt.imshow(ims_list[i],
                       cmap=discrete_cmap(N_COLORS, 'jet')[1],
                       vmin=0, vmax=N_COLORS, interpolation='None')

        plt.axis('off')

        # colorbar
        if colorbar:
            plt.colorbar()

    plt.subplots_adjust(wspace=0, hspace=0)


def visualize_dict_list(dict_list, method='break-down / build-up',
                        subplot_row=None, subplot_rows=3, lab_num=None, bar_graph=False):
    """Plot, for each iteration, the scores of every component.

    Args:
        dict_list: list of dicts (component id -> score array); entry 0 must
            have key 0 holding the full-image scores.
        method: unused.
        subplot_row: row of the subplot grid to draw into; if None a new
            figure is created and row 1 is used.
        subplot_rows: number of rows of the subplot grid.
        lab_num: if given, only the score at this index is kept for every
            component.
        bar_graph: draw one bar per component instead of one marker series
            per component.
    """
    # if passed lab_num, plot only lab_num
    if lab_num is not None:
        dict_list = [{key: np.array(d[key][lab_num]) for key in d} for d in dict_list]

    if subplot_row is None:
        plt.figure(figsize=(12, 2), facecolor='white')
        subplot_row = 1
    num_ims = len(dict_list)
    preds_orig = dict_list[0][0]

    # shared y-limits over all later iterations (defaults cover a single-entry list)
    vmin = min([np.min(d[key]) for d in dict_list[1:] for key in d], default=0) - 1
    vmax = max([np.max(d[key]) for d in dict_list[1:] for key in d], default=0) + 1
    if lab_num is None:
        vmin = min(vmin, np.min(preds_orig))
        vmax = max(vmax, np.max(preds_orig))

    # plot 1st preds
    plt.subplot(subplot_rows, num_ims, num_ims * subplot_row + 1)
    if lab_num is None:
        plt.bar(range(preds_orig.size), preds_orig, color='black')
        plt.ylabel('raw score full image')
    else:
        plt.ylabel('cd blob scores')
    plt.ylim((vmin, vmax))

    comp_colors = discrete_cmap(N_COLORS, 'jet')[0][1:]
    for i in range(1, num_ims):
        p = plt.subplot(subplot_rows, num_ims, num_ims * subplot_row + i + 1)
        p.set_prop_cycle(cycler('color', comp_colors))

        if bar_graph:
            region_nums = sorted(dict_list[i])
            vals = [dict_list[i][region_num] for region_num in region_nums]
            plt.bar(region_nums, vals, color=comp_colors)

            plt.plot(region_nums, vals, '_', color='black')
            plt.ylim((vmin - 1, vmax + 1))
        else:
            for region_num in sorted(dict_list[i]):
                plt.plot(dict_list[i][region_num], '_', markeredgewidth=2.5)
            plt.ylim((vmin, vmax))

        cur_axes = plt.gca()
        cur_axes.yaxis.set_visible(False)
        if lab_num is None:
            # fix tick positions first, then label them
            cur_axes.xaxis.set_ticks(np.arange(0, 10, 2))
            cur_axes.xaxis.set_ticklabels(np.arange(0, 10, 2))
            cur_axes.xaxis.grid()
        else:
            cur_axes.xaxis.set_visible(False)
    plt.subplots_adjust(wspace=0, hspace=0)


def visualize_arr_list(arr_list, method='break-down / build-up',
                       subplot_row=None, subplot_rows=3):
    """Plot each array of `arr_list` as a black bar chart with shared y-limits.

    Column 0 of the row is left empty; array k is drawn in column k + 1.

    Args:
        arr_list: list of 1-D arrays.
        method: used in the y-label of the first chart.
        subplot_row: row of the subplot grid to draw into; if None a new
            figure is created and row 1 is used.
        subplot_rows: number of rows of the subplot grid.
    """
    if subplot_row is None:
        plt.figure(figsize=(12, 2), facecolor='white')
        subplot_row = 1
    num_ims = len(arr_list) + 1

    vmin = min([np.min(d) for d in arr_list], default=0)
    vmax = max([np.max(d) for d in arr_list], default=0)

    for i in range(1, num_ims):
        plt.subplot(subplot_rows, num_ims, num_ims * subplot_row + i + 1)
        arr = arr_list[i - 1]
        plt.bar(np.arange(arr.size), arr, color='black')
        plt.ylim((vmin, vmax))
        cur_axes = plt.gca()
        if i == 1:
            # first chart is the only one with a visible y-axis
            # (this label was previously guarded by an unreachable `i == 0`)
            plt.ylabel('raw combined score for ' + method)
        else:
            cur_axes.yaxis.set_visible(False)
        # fix tick positions first, then label them
        cur_axes.xaxis.set_ticks(np.arange(0, 10, 2))
        cur_axes.xaxis.set_ticklabels(np.arange(0, 10, 2))
        cur_axes.xaxis.grid()
    plt.subplots_adjust(wspace=0, hspace=0)


def visualize_original_preds(im_orig, lab_num, comp_scores_raw_list, scores_orig_raw,
                             subplot_rows=5, dset=None, mturk=False, tits=None):
    """Draw the top row of a figure: original image, top-5 logits, and one CD heatmap per top class.

    Args:
        im_orig: image shown in the first column.
        lab_num: label whose name (from `dset.lab_dict`) titles the image when
            `tits` is None.
        comp_scores_raw_list: `comp_scores_raw_list[0][0]` holds the per-class
            scores of the full image.
        scores_orig_raw: 2-D array [pixel, class] passed to `visualize_preds`.
        subplot_rows: number of rows of the subplot grid.
        dset: object with a `lab_dict` mapping class index -> name.
        mturk: drop the logits column (one column fewer in the grid).
        tits: optional list of titles; `tits[0]` for the image and
            `tits[i + 2]` for the i-th heatmap.

    Returns:
        (ind, labs): indexes of the top classes (best first) and their
        names truncated to 12 characters.
    """
    num_cols = 7 - mturk
    plt.subplot(subplot_rows, num_cols, 1)
    plt.imshow(im_orig)
    if tits is not None:
        plt.title(tits[0])
    else:
        plt.title(dset.lab_dict[lab_num].split(',')[0])
    plt.axis('off')

    num_top = 5
    preds = comp_scores_raw_list[0][0]
    ind = _top_k_indices(preds, num_top)
    labs = [dset.lab_dict[x][:12] for x in ind]
    vals = preds[ind]

    # plotting
    if not mturk:
        plt.subplot(subplot_rows, num_cols, 2)
        plt.barh(np.arange(len(ind)), vals, color='#2ea9e888', edgecolor='#2ea9e888', fill=True, linewidth=1)

        for row, lab in enumerate(labs):
            lab = str(lab)
            if 'puck' in lab:
                lab = 'puck'
            plt.text(s=lab, x=1, y=row, color="black", verticalalignment="center", size=10)
        ax = plt.gca()
        ax.invert_yaxis()  # labels read top-to-bottom
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        plt.title('prediction logits')

    # symmetric color limits shared by all heatmaps
    vmin = min([np.nanmin(scores_orig_raw[:, x]) for x in ind])
    vmax = max([np.nanmax(scores_orig_raw[:, x]) for x in ind])
    vabs = max(abs(vmin), abs(vmax))

    for i, x in enumerate(ind):
        plt.subplot(subplot_rows, num_cols, i + 3 - mturk)
        visualize_preds(scores_orig_raw, num=x, cbar=False, vabs=vabs)
        if mturk:
            plt.title(dset.lab_dict[x][:14] + '...')
        elif tits is not None:
            plt.title(tits[i + 2])
        else:
            plt.title('CD (' + dset.lab_dict[x][:10] + ')')

    return ind, labs


def visualize_dict_list_top(dict_list, method='break-down / build-up',
                            subplot_row=None, subplot_rows=3, lab_num=None,
                            ind=None, labs=None, num_top=5, dset=None, use_orig_top=True,
                            num_ims=None, skip_first=False, vmin=None, vmax=None):
    """Plot, for each iteration, every component's scores restricted to the top classes.

    NOTE(review): the nesting of the axis-visibility section at the end of the
    loop was reconstructed from a whitespace-collapsed source; confirm.

    Args:
        dict_list: list of dicts (component id -> per-class score array).
        method, lab_num: unused.
        subplot_row: row of the subplot grid to draw into; if None a new
            figure is created and row 1 is used.
        subplot_rows: number of rows of the subplot grid.
        ind: class indexes to plot when `use_orig_top`.
        labs: tick labels matching `ind`.
        num_top: number of classes picked from component 1 when not
            `use_orig_top`.
        dset: object with `lab_dict` (needed when not `use_orig_top`).
        use_orig_top: plot the classes in `ind`; otherwise only component 1 is
            plotted, at its own top `num_top` classes.
        num_ims: number of columns; defaults to len(dict_list).
        skip_first: shift subplots one position left and plot one more iteration.
        vmin, vmax: y-limits. If `vmin` is None it is computed from the data,
            and so is `vmax` unless it was given explicitly.
    """
    if subplot_row is None:
        plt.figure(figsize=(12, 2), facecolor='white')
        subplot_row = 1
    if num_ims is None:
        num_ims = len(dict_list)

    if vmin is None:
        shown = dict_list[1:num_ims + 1]
        vmin = min([np.min(d[key]) for d in shown for key in d], default=0) - 1
        if vmax is None:  # an explicitly passed vmax is no longer overwritten
            vmax = max([np.max(d[key]) for d in shown for key in d], default=0) + 1

    comp_colors = discrete_cmap(N_COLORS, 'jet')[0][1:]
    for i in range(1, min(num_ims + skip_first, len(dict_list))):
        p = plt.subplot(subplot_rows, num_ims, num_ims * subplot_row + i + 1 - skip_first)
        p.set_prop_cycle(cycler('color', comp_colors))
        comp_scores = dict_list[i]

        for region_num in range(1, max(comp_scores.keys(), default=0) + 1):
            if region_num in comp_scores:  # check if present
                if use_orig_top:
                    region_arr = comp_scores[region_num][ind]
                    plt.plot(region_arr, '_', markeredgewidth=2)
                    plt.xticks(np.arange(region_arr.size), labs, rotation='vertical')
                    plt.xlim((-1, region_arr.size))
                elif region_num == 1:
                    region_arr = comp_scores[region_num]
                    ind = _top_k_indices(region_arr, num_top)
                    labs = [dset.lab_dict[x][:12] for x in ind]
                    vals = region_arr[ind]
                    plt.plot(vals, '_', markeredgewidth=1)
                    plt.xticks(np.arange(ind.size), labs, rotation='vertical')
                    plt.xlim((-1, ind.size))
                plt.ylim((vmin, vmax))
            else:  # plot blank just to match with color cycle
                plt.plot(-1, 0)

        cur_axes = plt.gca()
        if not i == 1:
            cur_axes.yaxis.set_visible(False)

            if use_orig_top:
                cur_axes.xaxis.set_visible(False)
        else:
            plt.ylabel('patch importance')
    plt.subplots_adjust(wspace=0, hspace=0)


def visualize_top_classes(model, dset, im_orig, scores_orig_raw):
    """Show CD heatmaps for the (up to) 8 classes with the highest predictions on a new figure.

    Args:
        model: passed to `dset.pred_ims`.
        dset: object providing `pred_ims(model, im)` and `lab_dict`.
        im_orig: image passed to `dset.pred_ims`.
        scores_orig_raw: 2-D array [pixel, class] passed to `visualize_preds`.
    """
    preds = dset.pred_ims(model, im_orig)
    ind = _top_k_indices(preds, 8)

    plt.figure(figsize=(14, 4))
    for i, x in enumerate(ind):
        plt.subplot(1, 8, i + 1)
        visualize_preds(scores_orig_raw, num=x)
        plt.title(dset.lab_dict[x][:12] + '\n' + str(preds[x]))


def visualize_original_preds_mnist(im_orig, lab_num, comp_scores_raw_list, scores_orig_raw,
                                   subplot_rows=5, dset=None, mturk=False, use_vmax=True):
    """MNIST variant of `visualize_original_preds`: class indexes are used directly as labels.

    Args:
        im_orig: image shown (in gray) in the first column.
        lab_num, dset: unused.
        comp_scores_raw_list: `comp_scores_raw_list[0][0]` holds the per-class
            scores of the full image.
        scores_orig_raw: 2-D array [pixel, class] passed to `visualize_preds`.
        subplot_rows: number of rows of the subplot grid.
        mturk: drop the logits column (one column fewer in the grid).
        use_vmax: when `mturk`, whether all heatmaps share the same color
            limits; without `mturk` they always do.

    Returns:
        (ind, labs): indexes of the top classes (best first), twice.
    """
    num_cols = 7 - mturk
    plt.subplot(subplot_rows, num_cols, 1)
    plt.imshow(im_orig, interpolation='None', cmap='gray')
    plt.title('Original image')
    plt.axis('off')

    num_top = 5
    preds = comp_scores_raw_list[0][0]
    ind = _top_k_indices(preds, num_top)
    labs = ind
    vals = preds[ind]

    # plotting
    if not mturk:
        plt.subplot(subplot_rows, num_cols, 2)
        plt.barh(np.arange(len(ind)), vals, color='#2ea9e888', edgecolor='#2ea9e888', fill=False, linewidth=1)
        for row, lab in enumerate(labs):
            plt.text(s=str(lab), x=1, y=row, color="black", verticalalignment="center", size=10)
        ax = plt.gca()
        ax.invert_yaxis()  # labels read top-to-bottom
        ax.get_yaxis().set_visible(False)
        plt.title('logits')

    # symmetric color limits shared by all heatmaps
    vmin = min([np.nanmin(scores_orig_raw[:, x]) for x in ind])
    vmax = max([np.nanmax(scores_orig_raw[:, x]) for x in ind])
    vabs = max(abs(vmin), abs(vmax))

    for i, x in enumerate(ind):
        plt.subplot(subplot_rows, num_cols, i + 3 - mturk)
        if mturk and not use_vmax:
            visualize_preds(scores_orig_raw, num=x, cbar=False)
        else:
            visualize_preds(scores_orig_raw, num=x, cbar=False, vabs=vabs)
        plt.title(x)

    return ind, labs

# --------------------------------------------------------------------------------