├── RCAV_fig2.png ├── README.md ├── TFMNIST.py ├── TFMNIST_weights.pt ├── inception_mixup.py ├── main.py ├── rcav.py ├── rcav_env.yml ├── rcav_utils.py ├── requirements.txt ├── textures ├── dots1_256.png ├── dots2_256.png ├── spiral1_256.png ├── spiral2_256.png ├── spiral3_256.png ├── stripes1_256.png ├── stripes2_256.png ├── zigzag1_256.png ├── zigzag2_256.png └── zigzag3_256.png ├── train-images-idx3-ubyte ├── train-labels-idx1-ubyte └── utils.py /RCAV_fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/RCAV_fig2.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robust Semantic Interpretability: Revisiting Concept Activation Vectors (ICML WHI 2020) 2 | 3 | This repository contains the official pytorch implementation of RCAV and the accompanying TFMNIST and Biased-CAMELYON16 datasets. 4 | 5 | Robust Concept Activation Vectors (RCAV) quantifies the effects of semantic concepts on individual model predictions and on model behavior as a whole. By generalizing previous work on concept activation vectors to account for model non-linearity, and by introducing stricter hypothesis testing, we show that RCAV yields interpretations which are both more accurate at the image level and robust at the dataset level. RCAV, like saliency methods, supports the fine-grained interpretation of individual predictions. 6 | 7 | The TFMNIST and B-CAMELYON16 datasets may be used as benchmarks for semantic interpretability methods. 8 | 9 | ### Run main.py to reproduce the results shown in Figure 2 of the paper. Note that main.py also accepts command line arguments, e.g. if you wish to retrain the model instead of loading the trained weights. 10 | 11 | # Usage: 12 | 13 | main.py will save your reproduced copy of figure 2 to the file RCAV_fig2.png 14 | 15 | rcav.py and rcav_utils.py contains the code for running RCAV on any model. 16 | 17 | Note use of rcav.py on another model requires adding latent augmentation functionality as is done in lines 150, 156, etc. of inception_mixup.py. 18 | 19 | # Requirements: 20 | Please FIRST download model weights from https://zenodo.org/record/3889104 and put the file in the same directory as main.py 21 | 22 | Requirements for these scripts may be installed by pip or conda using the requirements.txt or rcav_env.yml files. 23 | 24 | # Datasets: 25 | TFMNIST.py contains the code for creating the TFMNIST dataset note that the split is not the same as used for model training. 26 | 27 | The B-CAMELYON16 dataset described in the paper will be made available shortly. 28 | 29 | The unnaugmented data for CAMELYON16 is available at http://gigadb.org/dataset/view/id/100439/ 30 | 31 | # TODO: 32 | Upload B-CAMELYON16 33 | 34 | Upload split used for TFMNIST 35 | 36 | Link RCAV to arxiv posting 37 | -------------------------------------------------------------------------------- /TFMNIST.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import struct 3 | import os 4 | import cv2 5 | import sys 6 | import numpy as np 7 | from PIL import Image 8 | import random 9 | from tqdm import tqdm 10 | 11 | random.seed(1) 12 | 13 | PATH_TO_TEXTURES = "textures" 14 | THRESHOLD = 10 15 | 16 | DELTAS = ( 17 | (-1, 0), 18 | (-1, 1), 19 | (-1, -1), 20 | (1, 0), 21 | (1, 1), 22 | (1, -1), 23 | (0, 1), 24 | (0, -1) 25 | ) 26 | SIZE = 28 27 | 28 | def binary_threshold(x): 29 | if x < THRESHOLD: 30 | return 0 31 | return 255 32 | 33 | def read_idx(filename): 34 | with open(filename, 'rb') as f: 35 | _, _, dims = struct.unpack('>HBB', f.read(4)) 36 | shape = tuple(struct.unpack('>I', f.read(4))[0] for d in range(dims)) 37 | return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape) 38 | 39 | def in_bounds(coords): 40 | return coords[0] >= 0 and coords[0] < SIZE and coords[1] >= 0 and coords[1] < SIZE 41 | 42 | def is_boundary_pixel(coords, bitmap): 43 | x = coords[0] 44 | y = coords[1] 45 | for deltaX, deltaY in DELTAS: 46 | neighbors = (x + deltaX, y + deltaY) 47 | if in_bounds(neighbors): 48 | if bitmap[neighbors[1]][neighbors[0]] == 0: 49 | return True 50 | else: 51 | return True 52 | return False 53 | 54 | def inside_boundary(coord, boundary_set): 55 | #hotspot 56 | x0 = coord[0] 57 | y0 = coord[1] 58 | values = [False, False, False, False] 59 | for iv, delta in enumerate(((1, 0), (0, 1), (-1, 0), (0, -1))): 60 | dx = delta[0] 61 | dy = delta[1] 62 | x = x0 63 | y = y0 64 | while True: 65 | if in_bounds((x, y)): 66 | if boundary_set[y][x] == 1: 67 | values[iv] = True 68 | break 69 | x += dx 70 | y += dy 71 | else: 72 | break 73 | return values[0] and values[1] and values[2] and values[3] 74 | 75 | def fill_bitmap(arr): 76 | #given a numpy array, return a filled version 77 | height = arr.shape[0] 78 | width = arr.shape[1] 79 | binary_set = np.array(arr) 80 | for y in range(height): 81 | for x in range(width): 82 | binary_set[y][x] = binary_threshold(binary_set[y][x]) 83 | 84 | #run DFS to find disjoint sets 85 | visited = np.zeros((height, width)) 86 | def explore(coords, _set): 87 | visited[coords[1]][coords[0]] = 1 88 | _set.add(coords) 89 | color = binary_set[coords[1]][coords[0]] 90 | for deltaX, deltaY in DELTAS: 91 | neighbors = (coords[0] + deltaX, coords[1] + deltaY) 92 | if in_bounds(neighbors) \ 93 | and color == binary_set[neighbors[1]][neighbors[0]] \ 94 | and visited[neighbors[1]][neighbors[0]] == 0: 95 | explore(neighbors, _set) 96 | 97 | disjoint_sets = [] 98 | for y in range(height): 99 | flag = False 100 | for x in range(width): 101 | if visited[y][x] == 0: 102 | _set = set() 103 | explore((x, y), _set) 104 | disjoint_sets.append(_set) 105 | test_item = next(iter(_set)) 106 | if len(_set) > 392 and binary_set[test_item[1]][test_item[0]] == 255: 107 | flag = True 108 | break 109 | if flag: 110 | break 111 | 112 | def key(_set): 113 | coord = next(iter(_set)) 114 | if binary_set[coord[1]][coord[0]] == 255: 115 | return len(_set) 116 | return 0 117 | 118 | max_set = max(disjoint_sets, key=key) 119 | plt_arr = np.zeros((height, width)) 120 | result = np.zeros((height, width)) 121 | 122 | #find outline of max set 123 | boundary_set = np.zeros((height, width)) 124 | for coord in max_set: 125 | plt_arr[coord[1]][coord[0]] = 200 126 | result[coord[1]][coord[0]] = 255 127 | if is_boundary_pixel(coord, binary_set): #hotspot 128 | boundary_set[coord[1]][coord[0]] = 1 129 | plt_arr[coord[1]][coord[0]] = 255 130 | 131 | #for each node not in max set, if inside boundary, then flip it on 132 | #to check if in boundary, extend rays in 4 directions. If all 4 hit a boundary, then 133 | # it is inside. 134 | for _set in disjoint_sets: 135 | if _set != max_set: 136 | for coord in _set: 137 | plt_arr[coord[1]][coord[0]] = 100 138 | if inside_boundary(coord, boundary_set): 139 | plt_arr[coord[1]][coord[0]] = 200 140 | result[coord[1]][coord[0]] = 255 141 | return result 142 | 143 | def load_texture(filename): 144 | print(filename) 145 | im = np.array(Image.open(filename)) 146 | return im 147 | 148 | def scale_texture(texture, scale=1): 149 | if scale > 1 or scale < 0: 150 | raise Exception("invalid scale: must be in [0, 1]") 151 | size_new = int(32 + ((256-32)*(1-scale))) 152 | res = cv2.resize(texture, dsize=(size_new, size_new), interpolation=cv2.INTER_CUBIC) 153 | #crop to center 154 | resize = (size_new-32)//2 155 | return res[resize:resize+32, resize:resize+32] 156 | 157 | 158 | def add_to_filled(filled, texture, offset=(0,0)): 159 | result = np.zeros(filled.shape) 160 | ry = 0 161 | rx = 0 162 | for y in range(offset[0], filled.shape[0]+offset[0]): 163 | rx = 0 164 | for x in range(offset[1], filled.shape[1]+offset[1]): 165 | if (filled[ry][rx] == 255): 166 | result[ry][rx] = texture[y][x] 167 | rx += 1 168 | ry += 1 169 | return result 170 | 171 | def show_bitmap(arr1): 172 | # plot the sample 173 | plt.imshow(arr1, cmap='gray') 174 | plt.show() 175 | 176 | def show_bitmap_sbs(arr1, arr2): 177 | # plot the sample 178 | plt.imshow([np.concatenate([arr1[i], arr2[i]]) for i in range(arr1.shape[0])], cmap='gray') 179 | plt.show() 180 | 181 | def add_noise(arr, intensity=0): 182 | NOISE_LEVEL = 30 183 | CEIL = np.ones(arr.shape)*255 184 | FLOOR = np.zeros(arr.shape) 185 | noise = np.random.randint(-NOISE_LEVEL,NOISE_LEVEL,size=arr.shape) 186 | result = arr+noise 187 | result = np.minimum(result, CEIL) 188 | result = np.maximum(result, FLOOR).astype(np.uint8) 189 | return result 190 | 191 | class TexturedFMNIST(): 192 | def __init__(self, texture_dir=PATH_TO_TEXTURES, fmnist_dir='.'): 193 | self.imgs = read_idx(os.path.join(fmnist_dir, "train-images-idx3-ubyte")) 194 | self.labels = read_idx(os.path.join(fmnist_dir, "train-labels-idx1-ubyte")) 195 | 196 | 197 | self.textures = [load_texture(os.path.join(texture_dir, i)) for i in os.listdir(texture_dir)] 198 | new_textures = [] 199 | for text in self.textures: 200 | new_textures.append(text+(text<30).astype(np.uint8)*30) 201 | self.textures = new_textures 202 | self.train_inds = range(0,50000) 203 | self.val_inds = range(50000,60000) 204 | print("TexturedFMNIST initialized") 205 | 206 | def build_class(self, class_num, train, num_samples=None, offset=None, return_inds=True, get_meta=True, texture_choices=[], alpha=0, texture_rescale=True, texture_aug=True, aug_intensity=0.5): 207 | ''' 208 | texture_rescale: either bool or float oor string? controls how to rescale textures when applying 209 | texture_aug: whether to add some noise to texture scaling and orientation before application 210 | ''' 211 | # Loop over self.get_textured_sample randomly sampling noise for texture application if texture_aug is True 212 | result = [] 213 | inds = [] 214 | meta_list = [] 215 | if train: to_build_inds = self.train_inds 216 | else: to_build_inds = self.val_inds 217 | 218 | if train==False: 219 | offset = (0,0) 220 | texture_aug = False 221 | for index in tqdm(to_build_inds, desc='Building TFMNIST class'): 222 | if self.labels[index] == class_num: 223 | i0, label, meta = self.get_textured_sample(index, texture_choices=texture_choices, alpha=alpha, texture_rescale=texture_rescale, texture_aug=texture_aug, aug_intensity=aug_intensity) 224 | result.append(i0) 225 | if return_inds: inds.append(index) 226 | if get_meta: meta_list.append(meta) 227 | if not(num_samples is None) and len(result)>=num_samples: break 228 | if not return_inds: 229 | return result 230 | else: 231 | return (result, inds, meta_list) 232 | 233 | def get_textured_sample(self, img_index, offset=None, texture_choices=[], randomize_textures=True, alpha=0, texture_rescale=True, texture_aug=True, aug_intensity=0.5): 234 | 235 | if texture_aug: 236 | if offset is None: offset = (random.randint(0, 4), random.randint(0, 4)) 237 | noise = aug_intensity 238 | else: 239 | offset = (0,0) 240 | noise = 0 241 | 242 | if randomize_textures: 243 | t1_ind = random.choice(texture_choices) 244 | t2_ind = random.choice(texture_choices) 245 | t1 = self.textures[t1_ind] 246 | t2 = self.textures[t2_ind] 247 | else: 248 | t1_ind = texture_choices[0] 249 | t2_ind = texture_choices[1] 250 | t1 = self.textures[t1_ind] 251 | t2 = self.textures[t2_ind] 252 | 253 | if texture_rescale: 254 | t1 = scale_texture(t1, random.random()) 255 | t2 = scale_texture(t2, random.random()) 256 | else: 257 | t1 = scale_texture(t1, 1) 258 | t2 = scale_texture(t2, 1) 259 | 260 | texture = None 261 | if alpha: 262 | texture = self.interpolate_textures([t1, t2], alpha) #add_noise(, noise) 263 | else: 264 | texture = t1 265 | if noise>0: 266 | texture = add_noise(texture, noise) 267 | i1 = add_to_filled(fill_bitmap(self.imgs[img_index]), texture, offset) 268 | meta = {'offset':offset, 'textures':(t1_ind,t2_ind)} 269 | return i1, self.labels[img_index], meta 270 | 271 | def interpolate_textures(self, texture_list, alpha): 272 | result = alpha*texture_list[0]+(1-alpha)*texture_list[1] 273 | CEIL = np.ones(texture_list[0].shape)*255 274 | FLOOR = np.zeros(texture_list[0].shape) 275 | result = np.minimum(result, CEIL) 276 | result = np.maximum(result, FLOOR).astype(np.uint8) 277 | return result 278 | 279 | def test(): 280 | tf = TexturedFMNIST() 281 | tf.build_class(0, tf.textures, False, False, False) -------------------------------------------------------------------------------- /TFMNIST_weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/TFMNIST_weights.pt -------------------------------------------------------------------------------- /inception_mixup.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import warnings 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.jit.annotations import Optional 7 | from torch import Tensor 8 | from torch.hub import load_state_dict_from_url 9 | import numpy as np 10 | 11 | 12 | __all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs'] 13 | 14 | 15 | model_urls = { 16 | # Inception v3 ported from TensorFlow 17 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 18 | } 19 | 20 | # InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits']) 21 | # InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]} 22 | 23 | # # Script annotations failed with _GoogleNetOutputs = namedtuple ... 24 | # # _InceptionOutputs set here for backwards compat 25 | # _InceptionOutputs = InceptionOutputs 26 | 27 | def mixup_data(x, y, alpha): 28 | 29 | '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda''' 30 | if alpha > 0.: 31 | lam = np.random.beta(alpha, alpha) 32 | else: 33 | lam = 1. 34 | batch_size = x.size()[0] 35 | index = torch.randperm(batch_size).cuda() 36 | mixed_x = lam * x + (1 - lam) * x[index,:] 37 | y_a, y_b = y, y[index] 38 | return mixed_x, y_a, y_b, lam 39 | 40 | def inception_v3(pretrained=False, progress=True, **kwargs): 41 | r"""Inception v3 model architecture from 42 | `"Rethinking the Inception Architecture for Computer Vision" `_. 43 | 44 | .. note:: 45 | **Important**: In contrast to the other models the inception_v3 expects tensors with a size of 46 | N x 3 x 299 x 299, so ensure your images are sized accordingly. 47 | 48 | Args: 49 | pretrained (bool): If True, returns a model pre-trained on ImageNet 50 | progress (bool): If True, displays a progress bar of the download to stderr 51 | aux_logits (bool): If True, add an auxiliary branch that can improve training. 52 | Default: *True* 53 | transform_input (bool): If True, preprocesses the input according to the method with which it 54 | was trained on ImageNet. Default: *False* 55 | """ 56 | if pretrained: 57 | if 'transform_input' not in kwargs: 58 | kwargs['transform_input'] = True 59 | if 'aux_logits' in kwargs: 60 | original_aux_logits = kwargs['aux_logits'] 61 | kwargs['aux_logits'] = True 62 | else: 63 | original_aux_logits = True 64 | model = Inception3(**kwargs) 65 | state_dict = load_state_dict_from_url(model_urls['inception_v3_google'], 66 | progress=progress) 67 | model.load_state_dict(state_dict) 68 | if not original_aux_logits: 69 | model.aux_logits = False 70 | del model.AuxLogits 71 | return model 72 | 73 | return Inception3(**kwargs) 74 | 75 | 76 | class Inception3(nn.Module): 77 | 78 | def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, 79 | inception_blocks=None, init_weights=True, 80 | manifold_mix=None, mixup_alpha=None, latent_aug=None): 81 | super(Inception3, self).__init__() 82 | if inception_blocks is None: 83 | inception_blocks = [ 84 | BasicConv2d, InceptionA, InceptionB, InceptionC, 85 | InceptionD, InceptionE, InceptionAux 86 | ] 87 | assert len(inception_blocks) == 7 88 | conv_block = inception_blocks[0] 89 | inception_a = inception_blocks[1] 90 | inception_b = inception_blocks[2] 91 | inception_c = inception_blocks[3] 92 | inception_d = inception_blocks[4] 93 | inception_e = inception_blocks[5] 94 | inception_aux = inception_blocks[6] 95 | 96 | self.manifold_mix = manifold_mix 97 | self.mixup_alpha = mixup_alpha 98 | self.aux_logits = aux_logits 99 | self.transform_input = transform_input 100 | self.latent_aug = latent_aug 101 | 102 | self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2) 103 | self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) 104 | self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) 105 | self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) 106 | self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) 107 | self.Mixed_5b = inception_a(192, pool_features=32) 108 | self.Mixed_5c = inception_a(256, pool_features=64) 109 | self.Mixed_5d = inception_a(288, pool_features=64) 110 | self.Mixed_6a = inception_b(288) 111 | self.Mixed_6b = inception_c(768, channels_7x7=128) 112 | self.Mixed_6c = inception_c(768, channels_7x7=160) 113 | self.Mixed_6d = inception_c(768, channels_7x7=160) 114 | self.Mixed_6e = inception_c(768, channels_7x7=192) 115 | if aux_logits: 116 | self.AuxLogits = inception_aux(768, num_classes) 117 | self.Mixed_7a = inception_d(768) 118 | self.Mixed_7b = inception_e(1280) 119 | self.Mixed_7c = inception_e(2048) 120 | self.fc = nn.Linear(2048, num_classes) 121 | if init_weights: 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 124 | import scipy.stats as stats 125 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1 126 | X = stats.truncnorm(-2, 2, scale=stddev) 127 | values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 128 | values = values.view(m.weight.size()) 129 | with torch.no_grad(): 130 | m.weight.copy_(values) 131 | elif isinstance(m, nn.BatchNorm2d): 132 | nn.init.constant_(m.weight, 1) 133 | nn.init.constant_(m.bias, 0) 134 | 135 | def _transform_input(self, x): 136 | if self.transform_input: 137 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 138 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 139 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 140 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 141 | return x 142 | 143 | def _forward(self, x, target=None, aug=None): 144 | # N x 3 x 299 x 299 145 | y_a = None 146 | if self.manifold_mix=='input' and not (target is None): 147 | x, y_a, y_b, lam = mixup_data(x, target, self.mixup_alpha) 148 | assert(not y_a is None) 149 | x = self.Conv2d_1a_3x3(x) 150 | if self.latent_aug=='Conv2d_1a_3x3' and not(aug is None): 151 | x = x+aug 152 | # N x 32 x 149 x 149 153 | x = self.Conv2d_2a_3x3(x) 154 | # N x 32 x 147 x 147 155 | x = self.Conv2d_2b_3x3(x) 156 | if self.latent_aug=='Conv2d_2b_3x3' and not(aug is None): 157 | x = x+aug 158 | # N x 64 x 147 x 147 159 | x = F.max_pool2d(x, kernel_size=3, stride=2) 160 | # N x 64 x 73 x 73 161 | x = self.Conv2d_3b_1x1(x) 162 | if self.latent_aug=='Conv2d_3b_1x1' and not(aug is None): 163 | x = x+aug 164 | # N x 80 x 73 x 73 165 | x = self.Conv2d_4a_3x3(x) 166 | # N x 192 x 71 x 71 167 | x = F.max_pool2d(x, kernel_size=3, stride=2) 168 | # N x 192 x 35 x 35 169 | x = self.Mixed_5b(x) 170 | if self.latent_aug=='Mixed_5b' and not(aug is None): 171 | x = x+aug 172 | # N x 256 x 35 x 35 173 | x = self.Mixed_5c(x) 174 | if self.latent_aug=='Mixed_5c' and not(aug is None): 175 | x = x+aug 176 | # N x 288 x 35 x 35 177 | x = self.Mixed_5d(x) 178 | if self.latent_aug=='Mixed_5d' and not(aug is None): 179 | x = x+aug 180 | # N x 288 x 35 x 35 181 | x = self.Mixed_6a(x) 182 | if self.latent_aug=='Mixed_6a' and not(aug is None): 183 | x = x+aug 184 | # N x 768 x 17 x 17 185 | x = self.Mixed_6b(x) 186 | if self.latent_aug=='Mixed_6b' and not(aug is None): 187 | x = x+aug 188 | # N x 768 x 17 x 17 189 | x = self.Mixed_6c(x) 190 | if self.latent_aug=='Mixed_6c' and not(aug is None): 191 | x = x+aug 192 | # N x 768 x 17 x 17 193 | x = self.Mixed_6d(x) 194 | if self.latent_aug=='Mixed_6d' and not(aug is None): 195 | x = x+aug 196 | # N x 768 x 17 x 17 197 | x = self.Mixed_6e(x) 198 | if self.latent_aug=='Mixed_6e' and not(aug is None): 199 | x = x+aug 200 | # N x 768 x 17 x 17 201 | aux_defined = self.training and self.aux_logits 202 | if aux_defined: 203 | aux = self.AuxLogits(x) 204 | else: 205 | aux = None 206 | # N x 768 x 17 x 17 207 | x = self.Mixed_7a(x) 208 | if self.latent_aug=='Mixed_7a' and not(aug is None): 209 | x = x+aug 210 | # N x 1280 x 8 x 8 211 | x = self.Mixed_7b(x) 212 | if self.latent_aug=='Mixed_7b' and not(aug is None): 213 | x = x+aug 214 | # N x 2048 x 8 x 8 215 | x = self.Mixed_7c(x) 216 | if self.latent_aug=='Mixed_7c' and not(aug is None): 217 | x = x+aug 218 | # N x 2048 x 8 x 8 219 | # Adaptive average pooling 220 | x = F.adaptive_avg_pool2d(x, (1, 1)) 221 | # N x 2048 x 1 x 1 222 | x = F.dropout(x, training=self.training) 223 | # N x 2048 x 1 x 1 224 | x = torch.flatten(x, 1) 225 | # N x 2048 226 | x = self.fc(x) 227 | # N x 1000 (num_classes) 228 | if not (target is None): return x, (y_a, y_b, lam) 229 | else: return x 230 | 231 | @torch.jit.unused 232 | def eager_outputs(self, x, aux): 233 | # type: (Tensor, Optional[Tensor]) -> InceptionOutputs 234 | if self.training and self.aux_logits: 235 | return InceptionOutputs(x, aux) 236 | else: 237 | return x 238 | 239 | def forward(self, x, target=None, aug=None): 240 | x = self._transform_input(x,) 241 | x = self._forward(x, target=target, aug=aug) 242 | return x 243 | 244 | 245 | class InceptionA(nn.Module): 246 | 247 | def __init__(self, in_channels, pool_features, conv_block=None): 248 | super(InceptionA, self).__init__() 249 | if conv_block is None: 250 | conv_block = BasicConv2d 251 | self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) 252 | 253 | self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) 254 | self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) 255 | 256 | self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) 257 | self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) 258 | self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) 259 | 260 | self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) 261 | 262 | def _forward(self, x): 263 | branch1x1 = self.branch1x1(x) 264 | 265 | branch5x5 = self.branch5x5_1(x) 266 | branch5x5 = self.branch5x5_2(branch5x5) 267 | 268 | branch3x3dbl = self.branch3x3dbl_1(x) 269 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 270 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 271 | 272 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 273 | branch_pool = self.branch_pool(branch_pool) 274 | 275 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 276 | return outputs 277 | 278 | def forward(self, x): 279 | outputs = self._forward(x) 280 | return torch.cat(outputs, 1) 281 | 282 | 283 | class InceptionB(nn.Module): 284 | 285 | def __init__(self, in_channels, conv_block=None): 286 | super(InceptionB, self).__init__() 287 | if conv_block is None: 288 | conv_block = BasicConv2d 289 | self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) 290 | 291 | self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) 292 | self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) 293 | self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) 294 | 295 | def _forward(self, x): 296 | branch3x3 = self.branch3x3(x) 297 | 298 | branch3x3dbl = self.branch3x3dbl_1(x) 299 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 300 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 301 | 302 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 303 | 304 | outputs = [branch3x3, branch3x3dbl, branch_pool] 305 | return outputs 306 | 307 | def forward(self, x): 308 | outputs = self._forward(x) 309 | return torch.cat(outputs, 1) 310 | 311 | 312 | class InceptionC(nn.Module): 313 | 314 | def __init__(self, in_channels, channels_7x7, conv_block=None): 315 | super(InceptionC, self).__init__() 316 | if conv_block is None: 317 | conv_block = BasicConv2d 318 | self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) 319 | 320 | c7 = channels_7x7 321 | self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) 322 | self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 323 | self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 324 | 325 | self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) 326 | self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 327 | self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 328 | self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 329 | self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 330 | 331 | self.branch_pool = conv_block(in_channels, 192, kernel_size=1) 332 | 333 | def _forward(self, x): 334 | branch1x1 = self.branch1x1(x) 335 | 336 | branch7x7 = self.branch7x7_1(x) 337 | branch7x7 = self.branch7x7_2(branch7x7) 338 | branch7x7 = self.branch7x7_3(branch7x7) 339 | 340 | branch7x7dbl = self.branch7x7dbl_1(x) 341 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 342 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 343 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 344 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 345 | 346 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 347 | branch_pool = self.branch_pool(branch_pool) 348 | 349 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 350 | return outputs 351 | 352 | def forward(self, x): 353 | outputs = self._forward(x) 354 | return torch.cat(outputs, 1) 355 | 356 | 357 | class InceptionD(nn.Module): 358 | 359 | def __init__(self, in_channels, conv_block=None): 360 | super(InceptionD, self).__init__() 361 | if conv_block is None: 362 | conv_block = BasicConv2d 363 | self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) 364 | self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) 365 | 366 | self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) 367 | self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) 368 | self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) 369 | self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) 370 | 371 | def _forward(self, x): 372 | branch3x3 = self.branch3x3_1(x) 373 | branch3x3 = self.branch3x3_2(branch3x3) 374 | 375 | branch7x7x3 = self.branch7x7x3_1(x) 376 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 377 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 378 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 379 | 380 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 381 | outputs = [branch3x3, branch7x7x3, branch_pool] 382 | return outputs 383 | 384 | def forward(self, x): 385 | outputs = self._forward(x) 386 | return torch.cat(outputs, 1) 387 | 388 | 389 | class InceptionE(nn.Module): 390 | 391 | def __init__(self, in_channels, conv_block=None): 392 | super(InceptionE, self).__init__() 393 | if conv_block is None: 394 | conv_block = BasicConv2d 395 | self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) 396 | 397 | self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) 398 | self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) 399 | self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) 400 | 401 | self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) 402 | self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) 403 | self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) 404 | self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) 405 | 406 | self.branch_pool = conv_block(in_channels, 192, kernel_size=1) 407 | 408 | def _forward(self, x): 409 | branch1x1 = self.branch1x1(x) 410 | 411 | branch3x3 = self.branch3x3_1(x) 412 | branch3x3 = [ 413 | self.branch3x3_2a(branch3x3), 414 | self.branch3x3_2b(branch3x3), 415 | ] 416 | branch3x3 = torch.cat(branch3x3, 1) 417 | 418 | branch3x3dbl = self.branch3x3dbl_1(x) 419 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 420 | branch3x3dbl = [ 421 | self.branch3x3dbl_3a(branch3x3dbl), 422 | self.branch3x3dbl_3b(branch3x3dbl), 423 | ] 424 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 425 | 426 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 427 | branch_pool = self.branch_pool(branch_pool) 428 | 429 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 430 | return outputs 431 | 432 | def forward(self, x): 433 | outputs = self._forward(x) 434 | return torch.cat(outputs, 1) 435 | 436 | 437 | class InceptionAux(nn.Module): 438 | 439 | def __init__(self, in_channels, num_classes, conv_block=None): 440 | super(InceptionAux, self).__init__() 441 | if conv_block is None: 442 | conv_block = BasicConv2d 443 | self.conv0 = conv_block(in_channels, 128, kernel_size=1) 444 | self.conv1 = conv_block(128, 768, kernel_size=5) 445 | self.conv1.stddev = 0.01 446 | self.fc = nn.Linear(768, num_classes) 447 | self.fc.stddev = 0.001 448 | 449 | def forward(self, x): 450 | # N x 768 x 17 x 17 451 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 452 | # N x 768 x 5 x 5 453 | x = self.conv0(x) 454 | # N x 128 x 5 x 5 455 | x = self.conv1(x) 456 | # N x 768 x 1 x 1 457 | # Adaptive average pooling 458 | x = F.adaptive_avg_pool2d(x, (1, 1)) 459 | # N x 768 x 1 x 1 460 | x = torch.flatten(x, 1) 461 | # N x 768 462 | x = self.fc(x) 463 | # N x 1000 464 | return x 465 | 466 | 467 | class BasicConv2d(nn.Module): 468 | 469 | def __init__(self, in_channels, out_channels, **kwargs): 470 | super(BasicConv2d, self).__init__() 471 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 472 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 473 | 474 | def forward(self, x): 475 | x = self.conv(x) 476 | x = self.bn(x) 477 | return F.relu(x, inplace=True) 478 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Import 2 | import copy as cp 3 | import pickle 4 | import numpy as np 5 | import pandas as pd 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | import PIL 9 | from scipy.special import softmax 10 | import scipy.stats as stats 11 | from scipy.spatial.distance import cosine 12 | import random 13 | import math 14 | import sklearn 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torchvision 20 | import torchvision.transforms as transforms 21 | import torch.utils.data as torchUtils 22 | import torch.optim as optim 23 | from torchvision import models 24 | 25 | import utils 26 | import inception_mixup 27 | import rcav 28 | import rcav_utils 29 | import TFMNIST 30 | import struct 31 | from tqdm import tqdm 32 | import argparse 33 | 34 | if __name__ == '__main__': 35 | ''' 36 | Note that all hyper parameters are set according to the experiment shown in paper figure 2. A number of additional hyperparameters for FMNIST and RCAV not set by the argparser can be found labelled by '#HYPERPARAMETER' in the code below. 37 | ''' 38 | parser = argparse.ArgumentParser(description='Run RCAV on TFMNIST to reproduce figure 2 of paper') 39 | parser.add_argument('-t', '--train', type=bool, default=False, help='Whether to train or load trained model') 40 | parser.add_argument('-b', '--batch_size', type=int, default=64, help='Batch size if training') 41 | parser.add_argument('-n', '--n_permutations', type=int, default=500, help='Number of permutations drawn to approximate RCAV p-value') 42 | parser.add_argument('-e', '--epochs', type=int, default=50, help='Number of epochs to train for') 43 | 44 | kwargs = parser.parse_args() 45 | train=kwargs.train 46 | load_saved_model=(not train) 47 | batch_size=kwargs.batch_size 48 | n_perm = kwargs.n_permutations 49 | epochs = kwargs.epochs 50 | 51 | 52 | fmnist_dir = './' 53 | texture_dir = './textures' 54 | save_dir=None 55 | counterfactual_aug=True 56 | 57 | #################### Construct TFMNIST ############################### 58 | class_textures = {0:[1,6,4], 6:[2,5,9]} #HYPERPARAMETER FOR TFMNIST 59 | for i in range(10): 60 | if not i in [0,6]: class_textures[i]=[0,3,7,8] 61 | 62 | if train: 63 | train_ims, train_inds, train_meta = [], [], [] 64 | texturizer = TFMNIST.TexturedFMNIST(texture_dir = texture_dir, fmnist_dir=fmnist_dir) 65 | for cl in range(10): 66 | textures = class_textures[cl] 67 | data = texturizer.build_class(cl, train=True, texture_choices=textures, alpha=False, texture_rescale=False, texture_aug=True, aug_intensity=0.5) 68 | train_ims.append(data[0]) 69 | train_inds.append(data[1]) 70 | train_meta.append(data[2]) 71 | 72 | train_im_list, train_label_list,train_ind_list, train_meta_list = [], [], [], [] 73 | for i,images in enumerate(train_ims): 74 | train_im_list.extend(images) 75 | train_label_list.extend(len(images)*[i]) 76 | train_meta_list.extend(train_meta[i]) 77 | train_ind_list.extend(train_inds[i]) 78 | 79 | train_im_array = np.stack(train_im_list).astype(np.uint8) 80 | train_label_array = np.array(train_label_list) 81 | train_ind_array = np.array(train_ind_list) 82 | train_texture_array = np.array([traindict['textures'][0] for traindict in train_meta_list]) 83 | 84 | shuffle_inds_train = list(range(len(train_im_array))) 85 | random.shuffle(shuffle_inds_train) 86 | train_im_array = train_im_array[shuffle_inds_train] 87 | train_label_array = train_label_array[shuffle_inds_train] 88 | train_ind_array = train_ind_array[shuffle_inds_train] 89 | train_texture_array = train_texture_array[shuffle_inds_train] 90 | train_array_dict = {'X':train_im_array ,'Y':train_label_array ,'meta_inds':train_ind_array, 'meta_texts':train_texture_array} 91 | 92 | train_transforms = transforms.Compose([transforms.ToPILImage(), transforms.Resize((224,224)), transforms.RandomAffine(50, translate=(0.05,0.05), scale=(0.9,1.1), shear=(-10,10)), transforms.ToTensor()]) 93 | train_set = utils.im_dataset(train_im_array, train_label_array, transform = train_transforms) 94 | train_loader = torchUtils.DataLoader(train_set, batch_size=batch_size, shuffle=True, sampler=None, 95 | batch_sampler=None, num_workers=1, collate_fn=None, 96 | pin_memory=False, drop_last=False, timeout=0, 97 | worker_init_fn=None) 98 | 99 | val_ims, val_inds, val_meta = [], [], [] 100 | texturizer = TFMNIST.TexturedFMNIST(texture_dir = texture_dir, fmnist_dir=fmnist_dir) 101 | for cl in range(10): 102 | textures = class_textures[cl] 103 | data = texturizer.build_class(cl, train=False, texture_choices=textures, alpha=False, texture_rescale=False, texture_aug=True, aug_intensity=0.5) 104 | val_ims.append(data[0]) 105 | val_inds.append(data[1]) 106 | val_meta.append(data[2]) 107 | 108 | val_im_list, val_label_list,val_ind_list, val_meta_list = [], [], [], [] 109 | for i,images in enumerate(val_ims): 110 | val_im_list.extend(images) 111 | val_label_list.extend(len(images)*[i]) 112 | val_meta_list.extend(val_meta[i]) 113 | val_ind_list.extend(val_inds[i]) 114 | 115 | val_im_array = np.stack(val_im_list).astype(np.uint8) 116 | val_label_array = np.array(val_label_list) 117 | val_ind_array = np.array(val_ind_list) 118 | val_texture_array = np.array([valdict['textures'][0] for valdict in val_meta_list]) 119 | 120 | shuffle_inds_val = list(range(len(val_im_array))) 121 | random.shuffle(shuffle_inds_val) 122 | val_im_array = val_im_array[shuffle_inds_val] 123 | val_label_array = val_label_array[shuffle_inds_val] 124 | val_ind_array = val_ind_array[shuffle_inds_val] 125 | val_texture_array = val_texture_array[shuffle_inds_val] 126 | 127 | val_array_dict = {'X':val_im_array ,'Y':val_label_array ,'meta_inds':val_ind_array, 'meta_texts':val_texture_array} 128 | val_transforms = transforms.Compose([transforms.ToPILImage(), transforms.Resize((224,224)), transforms.ToTensor()]) 129 | val_set = utils.im_dataset(val_im_array, val_label_array, transform = val_transforms) 130 | 131 | val_loader = torchUtils.DataLoader(val_set, batch_size=batch_size, shuffle=False, sampler=None, 132 | batch_sampler=None, num_workers=1, collate_fn=None, 133 | pin_memory=False, drop_last=False, timeout=0, 134 | worker_init_fn=None) 135 | 136 | #################### Load or train model ############################### 137 | print('Loading model') 138 | 139 | #HYPERPARAMETERS MODEL TRAINING 140 | lr_factor = 0.2 141 | lr_patience = 10 142 | decay = 0.98 143 | stop_patience = 15 144 | 145 | model = inception_mixup.inception_v3(pretrained=True, aux_logits=False, manifold_mix='input', mixup_alpha=0.2, transform_input=False) 146 | model.fc = nn.Linear(2048, 10) #FMNIST has 10 classes 147 | model.Conv2d_1a_3x3 = utils.BasicConv2d(1, 32, kernel_size=3, stride=2) #FMNIST has only one channel 148 | model = model.cuda() 149 | 150 | BCE = nn.BCELoss() 151 | softmax = nn.Softmax(dim=1) 152 | criterion = lambda x,y: BCE(softmax(x),y) 153 | optimizer = optim.Adam(model.parameters(), lr=0.001) 154 | plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=lr_factor, patience=lr_patience, verbose=True, threshold=-0.01) 155 | decay_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, decay, last_epoch=-1) 156 | schedulers = {'decay':[decay_scheduler,None], 'plateau':[plateau_scheduler,'acc'],} 157 | 158 | runner = utils.RunNet(model, criterion, optimizer, 10, 159 | schedulers=schedulers, save_dir=save_dir, mixup=True) 160 | if load_saved_model: model.load_state_dict(torch.load('TFMNIST_weights.pt')) 161 | if train: 162 | for epoch in tqdm(range(epochs)): 163 | print('Epoch ', epoch+1) 164 | runner.do_epoch(train_loader, train=True) 165 | print('train loss', runner.losses) 166 | runner.do_epoch(val_loader, train=False) 167 | runner.get_results('val') 168 | runner.schedule() 169 | print('val set performance', [(metric, runner.results['val'][metric][-1]) for metric in runner.results['val']]) 170 | print() 171 | 172 | #################### Compute ground truth softmax differences ############################### 173 | print('Computing ground truth concept sensitivity') 174 | if counterfactual_aug: 175 | #HYPERPARAMETERS RCAV 176 | benchmark_classes = [0] 177 | benchmark_class = benchmark_classes[0] 178 | target_class = 6 179 | textures_to_interp = [2,5,9] 180 | val_interp_list, val_baseline_list = [], [] 181 | texturizer = TFMNIST.TexturedFMNIST(texture_dir = texture_dir, fmnist_dir=fmnist_dir) 182 | 183 | for i,ind in enumerate(val_ind_array): 184 | if val_label_array[i] in benchmark_classes: 185 | baseline_texture = val_texture_array[i] 186 | interp_texture = random.choice(textures_to_interp) 187 | val_interp_list.append(texturizer.get_textured_sample(ind, offset=(0,0), texture_choices=[baseline_texture, interp_texture], 188 | randomize_textures=False, alpha=0.90, texture_rescale=False, texture_aug=False, aug_intensity=0)[0].astype(np.uint8)) 189 | val_baseline_list.append(texturizer.get_textured_sample(ind, offset=(0,0), texture_choices=[baseline_texture, interp_texture], 190 | randomize_textures=False, alpha=False, texture_rescale=False, texture_aug=False, aug_intensity=0)[0].astype(np.uint8)) 191 | 192 | labels = len(val_interp_list)*[benchmark_class] 193 | val_transforms = transforms.Compose([transforms.ToPILImage(), transforms.Resize((224,224)), transforms.ToTensor()]) 194 | baseline_val_set = utils.im_dataset(val_baseline_list, labels, transform = val_transforms) 195 | interp_val_set = utils.im_dataset(val_interp_list, labels, transform = val_transforms) 196 | 197 | baseline_loader = torchUtils.DataLoader(baseline_val_set, batch_size=batch_size, shuffle=False, sampler=None, 198 | batch_sampler=None, num_workers=0, collate_fn=None, 199 | pin_memory=False, drop_last=False, timeout=0, 200 | worker_init_fn=None) 201 | interp_loader = torchUtils.DataLoader(interp_val_set, batch_size=batch_size, shuffle=False, sampler=None, 202 | batch_sampler=None, num_workers=0, collate_fn=None, 203 | pin_memory=False, drop_last=False, timeout=0, 204 | worker_init_fn=None) 205 | 206 | runner = utils.RunNet(model, None, None, 2, 207 | schedulers=dict(), save_dir=None) 208 | runner.do_epoch(interp_loader, False) 209 | runner.get_results('interp',format_only=True) 210 | interp_preds = cp.copy(runner.preds) 211 | interp_labels = cp.copy(runner.label_list) 212 | 213 | runner = utils.RunNet(model, None, None, 2, 214 | schedulers=dict(), save_dir=None) 215 | runner.do_epoch(baseline_loader, False) 216 | runner.get_results('baseline',format_only=True) 217 | baseline_preds = cp.copy(runner.preds) 218 | baseline_labels = cp.copy(runner.label_list) 219 | 220 | #################### Construct CAV Concept Set ############################### 221 | print('Building CAV concept set') 222 | pos_cav_ims, neg_cav_ims = [], [] 223 | texturizer = TFMNIST.TexturedFMNIST(texture_dir = texture_dir, fmnist_dir=fmnist_dir) 224 | #HYPERPARAMETERS RCAV 225 | pos_texture = [1,6,4] 226 | neg_texture = [2,5,9] 227 | classes = 10 228 | samples = 50 229 | label_to_cav_class = {0:'Spiral', 1:'Zigzag'} 230 | 231 | for cl in range(classes): 232 | data = texturizer.build_class(cl, num_samples=samples, train=False, texture_choices=pos_texture, alpha=False, texture_rescale=False, texture_aug=False, aug_intensity=0.5) 233 | pos_cav_ims.extend(data[0]) 234 | 235 | for cl in range(classes): 236 | data = texturizer.build_class(cl, num_samples=samples, train=False, texture_choices=neg_texture, alpha=False, texture_rescale=False, texture_aug=False, aug_intensity=0.5) 237 | neg_cav_ims.extend(data[0]) 238 | 239 | random.shuffle(pos_cav_ims) 240 | random.shuffle(neg_cav_ims) 241 | all_cav_ims = pos_cav_ims[:250]+neg_cav_ims[:250] 242 | all_cav_labels = 250*[1]+250*[0] 243 | all_cav_ims = np.stack(all_cav_ims).astype(np.uint8) 244 | all_cav_labels = np.array(all_cav_labels) 245 | 246 | cav_set = utils.im_dataset(all_cav_ims, all_cav_labels, transform = val_transforms) 247 | cav_loader = torchUtils.DataLoader(cav_set, batch_size=batch_size, shuffle=False, sampler=None, 248 | batch_sampler=None, num_workers=0, collate_fn=None, 249 | pin_memory=False, drop_last=False, timeout=0, 250 | worker_init_fn=None) 251 | 252 | #################### Run RCAV to get sensitivity scores ############################### 253 | print('Running RCAV') 254 | #HYPERPARAMETERS RCAV 255 | step_size = 10 256 | RCAV = True 257 | layers_to_test = ['Mixed_6a'] 258 | concepts_to_test = [0] 259 | num_tests = 1 260 | ground_truth_diffs = np.array(interp_preds[:,target_class]-baseline_preds[:,target_class]) 261 | 262 | model.eval() 263 | for layer in layers_to_test: 264 | print(layer) 265 | test = rcav.RCAV(model, layer, cav_set, baseline_val_set, baseline_loader, all_cav_labels, 266 | num_classes=10, class_nums=[benchmark_class], target_class_num=target_class, multiple_tests_num=num_tests, TCAV=False) 267 | for pos_class in concepts_to_test: 268 | cav_score, significance = test.run(pos_class, n_random=n_perm, null_hypothesis='permutation', step_size=step_size, early_stop=False) 269 | pval = test.benchmark_sample_correlation(ground_truth_diffs) 270 | print('Image-level correlation has tau={0:.4f}, with un-adjusted p={1:.4f}'.format(test.trained_tau[0], pval)) #Note this raw p-value is only accurate down to p=1/(n_perm+1) 271 | fig = plt.figure(figsize=(14,7)) 272 | plt.ylim(-0.001,0.012) 273 | p = sns.regplot(x=ground_truth_diffs, y=test.cav_diffs, label='RCAV', scatter_kws={'alpha':0.5}) 274 | p = p.get_figure() 275 | p.savefig('RCAV_fig2.png') -------------------------------------------------------------------------------- /rcav.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import linear_model 3 | from sklearn.model_selection import train_test_split 4 | import rcav_utils 5 | import scipy.stats as stats 6 | import scipy.spatial as spatial 7 | import random 8 | import torch 9 | from torch.nn import Softmax 10 | from scipy.special import softmax 11 | import copy as cp 12 | import pickle 13 | from tqdm import tqdm 14 | import math 15 | 16 | class RCAV(): 17 | def __init__(self, model, layer_name, train_dataset, val_dataset, val_loader, concept_labels, num_classes, class_nums, target_class_num, multiple_tests_num=1, TCAV=False, **logistic_regression_kwargs): 18 | ''' 19 | Args: 20 | model (pytorch model): trained model 21 | layer_name (string): layer choice 22 | train_dataset (torch.utils.data.Dataset): concept set consisting of (image, concept_label) pairs 23 | val_dataset (torch.utils.data.Dataset): dataset consisting of (image, label) pairs i.e. a subset of the val set corrresponding the train set on which the model was trained. 24 | val_loader (torch.utils.data.DataLoader): loader corresponding to val_dataset. Note that sampling order must be consistent across passes i.e. SHUFFLE=FALSE. 25 | concept_labels (np array): array of concept labels in the same order as used in train_dataset 26 | num_classes (int): number of classes 27 | class_nums (list of int): labels (classes) over which TCAV will be calculated 28 | target_class_num (int): label (class) number for which TCAV concept sensitivity will be calculated 29 | multiple_tests_num (int): Number of hypothesis tests carried out in this experiment. Used to define early stopping criterion for non-significance. Returned p-values are non-adjusted. 30 | TCAV (bool): Whether to calculate concept sensitivity using TCAV method or RCAV method. Default is RCAV. 31 | ''' 32 | self.model = model 33 | self.layer_name = layer_name 34 | self.train_dataset = train_dataset 35 | self.val_dataset = val_dataset 36 | self.val_loader = val_loader 37 | self.concept_labels = concept_labels 38 | self.num_classes = num_classes 39 | self.class_nums = class_nums 40 | self.class_num = class_nums[0] 41 | self.target_class_num = target_class_num 42 | self.multiple_tests_num = multiple_tests_num 43 | self.TCAV = TCAV 44 | self.logistic_regression_kwargs = logistic_regression_kwargs 45 | self.step_size = None 46 | 47 | self._reset_fields(size_change=True) 48 | self.grads = None 49 | self.concept_acts = None 50 | 51 | self.random_cav_scores = [] 52 | self.random_cav_accs = [] 53 | self.random_cavs = [] 54 | 55 | self.significance_dict = dict() 56 | self.model.eval() 57 | 58 | def _reset_fields(self, size_change): 59 | ''' 60 | Resets fields between runs. Used if testing the same layer and target class, but on a different concept or with a different step size. 61 | ''' 62 | self.sample_cav = None 63 | self.cav_score = None 64 | self.cav_scores = [] 65 | self.cavs = [] 66 | self.cav_accs = [] 67 | if size_change: 68 | self.random_cav_scores = [] 69 | self.random_latent_aug_preds = [] 70 | 71 | def _get_acts_subset(self, pos_concept_num, sample_size=None, random_concept=True, bootstrap=True, randomized_concept_labels=None): 72 | ''' 73 | Gets activations for either the subset excluding pos_concept_num (if random_concept=True), otherwise the subset of only pos_concept_num. 74 | 75 | Args: 76 | pos_concept_num: int, the concept number for this run of TCAV 77 | sample_size: int, the bootstrap sample size to draw 78 | random_concept: Bool, whether to sample from pos_concept or random concepts 79 | bootstrap: Bool, whether to use the whole sample or resample by boostrapping 80 | randomized_concept_labels: list of int, an alternative to self.concept_labels to use for generating null distribution 81 | ''' 82 | if sample_size is None: 83 | if random_concept: sample_size = sum(self.concept_labels!=pos_concept_num) 84 | else: sample_size = sum(self.concept_labels==pos_concept_num) 85 | 86 | if randomized_concept_labels: weight = [0 if random_concept^(label!=pos_concept_num) else 1 for label in randomized_concept_labels] 87 | else: weight = [0 if random_concept^(label!=pos_concept_num) else 1 for label in self.concept_labels] 88 | tot = sum(weight) 89 | weight = [indicator/tot for indicator in weight] 90 | act_inds = np.random.choice(np.arange(len(self.concept_labels)), size=sample_size, replace=bootstrap, p=weight) 91 | acts_subset = [self.concept_acts[ind] for ind in act_inds] 92 | if random_concept: labelled_acts_subset = [[act[0], 0] for act in acts_subset] 93 | else: labelled_acts_subset = [[act[0], 1] for act in acts_subset] 94 | return labelled_acts_subset 95 | 96 | def _get_cav(self, acts,): 97 | cav = CAV() 98 | cav.train(acts, **self.logistic_regression_kwargs) 99 | return cav 100 | 101 | def _get_TCAV_score_significance(self, null_hypothesis, early_stop=False): 102 | ''' 103 | Computes TCAV sensitivity scores 104 | ''' 105 | if early_stop: raise NotImplementedError() 106 | if self.grads is None: self.grads = rcav_utils.get_grads(self.model, self.layer_name, self.val_dataset, self.class_nums, self.target_class_num, self.num_classes) 107 | self.cav_scores = [sum([1 if np.dot(grad,cav.vec)>0 else 0 for grad in self.grads])/len(self.grads) for cav in self.cavs] 108 | self.cav_score = self.cav_scores[0] 109 | if null_hypothesis=='ttest_onesamp': 110 | self.significance = stats.ttest_1samp(self.cav_scores, 0.5) 111 | self.significance = (self.significance[0], min(1,self.significance[1])) 112 | self.significance_dict[self.pos_concept_num] = cp.deepcopy((self.cav_score, self.significance)) 113 | return self.cav_score, self.significance 114 | elif null_hypothesis=='permutation' or null_hypothesis=='gaussian_null': 115 | if null_hypothesis=='permutation': 116 | for i in tqdm(range(self.n_random), desc='Sampling Null Distribution'): 117 | randomized_concept_labels = [act[1] for act in self.concept_acts] 118 | random.shuffle(randomized_concept_labels) 119 | pos_acts = self._get_acts_subset(self.pos_concept_num, sample_size=self.sample_size, randomized_concept_labels=randomized_concept_labels, bootstrap=False, random_concept=False) 120 | rand_acts = self._get_acts_subset(self.pos_concept_num, sample_size=self.sample_size, randomized_concept_labels=randomized_concept_labels, bootstrap=False) 121 | new_cav = self._get_cav(rand_acts+pos_acts) 122 | self.random_cavs.append(new_cav) 123 | self.random_cav_accs.append(new_cav.acc) 124 | self.random_cav_scores = [sum([1 if np.dot(grad,cav.vec)>0 else 0 for grad in self.grads])/len(self.grads) for cav in self.random_cavs] 125 | cav_from_50 = [np.abs(score-0.5) for score in self.cav_scores][0] 126 | random_cav_from_50 = [np.abs(score-0.5) for score in self.random_cav_scores] 127 | self.significance = (cav_from_50, min(1,sum([int(random_score>cav_from_50) for random_score in random_cav_from_50])/len(random_cav_from_50))) 128 | self.significance_dict[self.pos_concept_num] = cp.deepcopy((self.cav_score, self.significance)) 129 | return self.cav_score, self.significance 130 | else: 131 | raise ValueError('null_hypothesis must be in ["ttest_onesamp","permutation","gaussian_null"]') 132 | 133 | def _get_RCAV_score_significance(self, null_hypothesis, early_stop=True): 134 | ''' 135 | Computes RCAV sensitivity scores 136 | ''' 137 | if null_hypothesis=='ttest_onesamp': 138 | assert early_stop==False 139 | 140 | # Compute RCAV sensitivity scores for the actual CAV 141 | pair_softmax = Softmax(dim=-1) 142 | self.baseline_preds, self.cav_latent_aug_preds = [], [] 143 | with torch.no_grad(): 144 | self.model.latent_aug = None 145 | for batch in self.val_loader: 146 | if type(batch)==dict: inputs = batch['image'].cuda() 147 | else: inputs = batch[0].cuda() 148 | self.baseline_preds.append(self.model(inputs, aug=None).cpu()) 149 | self.baseline_preds = softmax(np.vstack(torch.cat(self.baseline_preds).numpy()), axis=1) 150 | with torch.no_grad(): 151 | aug_tensor = torch.Tensor(self.sample_cav.vec.reshape(self.acts_dimensions)) 152 | aug_tensor = self.step_size*aug_tensor/torch.norm(aug_tensor) 153 | self.model.latent_aug = self.layer_name 154 | for batch in self.val_loader: 155 | if type(batch)==dict: inputs = batch['image'].cuda() 156 | else: inputs = batch[0].cuda() 157 | aug_batch = torch.cat(inputs.shape[0]*[aug_tensor]).cuda() 158 | self.cav_latent_aug_preds.append(self.model(inputs, aug=aug_batch).cpu()) 159 | self.cav_latent_aug_preds = softmax(np.vstack(torch.cat(self.cav_latent_aug_preds).numpy()), axis=1) 160 | 161 | self.cav_scores = [sum([1 if self.cav_latent_aug_preds[p][self.target_class_num]>=baseline_pred[self.target_class_num] else 0 for p,baseline_pred in enumerate(self.baseline_preds)])/len(self.baseline_preds) 162 | for cav in self.cavs] 163 | self.cav_score = self.cav_scores[0] 164 | 165 | # Compute RCAV sensitivity scores for null set CAVs 166 | null_threshold = math.ceil(0.05*self.n_random/self.multiple_tests_num) 167 | cav_from_50 = [np.abs(score-0.5) for score in self.cav_scores][0] 168 | self.model.latent_aug = self.layer_name 169 | null_count = 0 170 | randomized_concept_labels = [act[1] for act in self.concept_acts] 171 | for i in tqdm(range(self.n_random), desc='Sampling Null Distribution'): 172 | if len(self.random_cav_scores) < i+1 or null_hypothesis=='ttest_onesamp': 173 | if null_hypothesis=='permutation' and len(self.random_cavs)baseline_pred[self.target_class_num] else 0 for p,baseline_pred in 194 | enumerate(self.baseline_preds)])/len(self.baseline_preds)) 195 | if early_stop: 196 | null_check = np.abs(self.random_cav_scores[i]-0.5)>=cav_from_50 197 | if null_check: null_count = null_count+1 198 | if null_count>=null_threshold: 199 | random_cav_from_50 = [np.abs(score-0.5) for score in self.random_cav_scores] 200 | self.significance = (cav_from_50, min(1,sum([int(random_score>=cav_from_50) for random_score in random_cav_from_50])/len(random_cav_from_50))) 201 | self.significance_dict[self.pos_concept_num] = cp.deepcopy((self.cav_score, self.significance)) 202 | return self.cav_score, self.significance 203 | 204 | if null_hypothesis in ['permutation', 'gaussian_null']: 205 | random_cav_from_50 = [np.abs(score-0.5) for score in self.random_cav_scores] 206 | self.significance = (cav_from_50, min(1,sum([int(random_score>=cav_from_50) for random_score in random_cav_from_50])/len(random_cav_from_50))) 207 | self.significance_dict[self.pos_concept_num] = cp.deepcopy((self.cav_score, self.significance)) 208 | return self.cav_score, self.significance 209 | elif null_hypothesis=='ttest_onesamp': 210 | self.significance = stats.ttest_1samp(self.random_cav_scores, 0.5) 211 | self.significance_dict[self.pos_concept_num] = cp.deepcopy((self.cav_score, self.significance)) 212 | return self.cav_score, self.significance 213 | else: 214 | raise ValueError('null_hypothesis must be in ["ttest_onesamp","permutation", "gaussian_null"]') 215 | 216 | def benchmark_sample_correlation(self, ground_truth_score_delta, hypothesis_test=stats.kendalltau): 217 | ''' 218 | Given ground truth for image-level concept sensitivity, computes performance metrics 219 | Args: 220 | ground_truth_score_delta: (np array) ground truth softmax differences 221 | hypothesis_test: function to use for hypothesis test statistic. this function will _not_ be used for the p-value, only the statistic. 222 | Returns: 223 | Hypothesis test p-value 224 | ''' 225 | if not self.TCAV: 226 | self.cav_diffs = [self.cav_latent_aug_preds[p][self.target_class_num]-baseline_pred[self.target_class_num] for p,baseline_pred in enumerate(self.baseline_preds)] 227 | self.trained_tau = hypothesis_test(ground_truth_score_delta, self.cav_diffs) 228 | self.random_tau = [] 229 | for aug_preds in self.random_latent_aug_preds: 230 | diffs = [aug_preds[p][self.target_class_num]-baseline_pred[self.target_class_num] for p,baseline_pred in enumerate(self.baseline_preds)] 231 | self.random_tau.append(hypothesis_test(ground_truth_score_delta, diffs)) 232 | else: 233 | self.cav_diffs = np.array([spatial.distance.cosine(self.sample_cav.vec,grad) for grad in self.grads]) 234 | self.trained_tau = hypothesis_test(ground_truth_score_delta, self.cav_diffs) 235 | self.random_tau = [] 236 | for cav in self.random_cavs: 237 | rand_cav_diffs = np.array([spatial.distance.cosine(cav.vec,grad) for grad in self.grads]) 238 | self.random_tau.append(hypothesis_test(ground_truth_score_delta, rand_cav_diffs)) 239 | 240 | return sum([1 for rt in self.random_tau if self.trained_tau[0]<=rt[0]])/len(self.random_tau) 241 | 242 | def save(self, save_dir, save_prefix): 243 | ''' 244 | Note that saving clears all high memory usage fields. 245 | ''' 246 | self.cavs = [self.cavs[0]] 247 | self.val_loader = None 248 | self.val_dataset = None 249 | self.grads = None 250 | self.model = None 251 | self.concept_acts = [] 252 | self.random_cavs = [] 253 | self.train_dataset=None 254 | save_loc = save_dir+save_prefix+'_{0}.pkl'.format(self.layer_name) 255 | with open(save_loc, 'wb') as f: pickle.dump(self.__dict__,f) 256 | 257 | 258 | def run(self, pos_concept_num, sample_size=None, n_random=500, step_size=10, null_hypothesis='permutation', early_stop=True): 259 | ''' 260 | Builds CAVs and calculates scores using the subset of val_dataset corresponding to class_num. If doing multiple tests on the same layer you can use run() without re-initializing the class instance. 261 | Args: 262 | sample_size: int or None, size of bootstrap and permutation set null CAV training sets. If None defaults to the size of the dataset. 263 | n_random: int, number of boostrap or permutation set null samples to use 264 | step_size: float, step size for RCAV 265 | null_hypothesis: str in ['ttest_onesamp','permutation','gaussian_null'], defines the hypothesis test used 266 | early_stop: bool, whether to early stop when non-significance level is reached for gaussian null or permutation test. 267 | Returns: 268 | sensitivity score, (test statistic, p-value) 269 | ''' 270 | self._reset_fields(size_change=(self.step_size!=step_size)) 271 | self.n_random = n_random 272 | self.step_size = step_size 273 | self.pos_concept_num = pos_concept_num 274 | self.sample_size = sample_size 275 | if self.concept_acts is None: self.concept_acts, self.acts_dimensions = rcav_utils.get_acts(self.model, self.layer_name, self.train_dataset, self.concept_labels) 276 | 277 | # First get cav score on given samples 278 | pos_acts = self._get_acts_subset(pos_concept_num, sample_size=sample_size, random_concept=False, bootstrap=False) 279 | rand_acts = self._get_acts_subset(pos_concept_num, sample_size=sample_size, bootstrap=False) 280 | self.sample_cav = self._get_cav(rand_acts+pos_acts) 281 | self.cavs.append(self.sample_cav) 282 | self.cav_accs.append(self.sample_cav.acc) 283 | 284 | if null_hypothesis in ['ttest_onesamp']: 285 | for _ in tqdm(range(n_random), desc='Training Random CAVs'): 286 | pos_acts = self._get_acts_subset(pos_concept_num, sample_size=sample_size, random_concept=False) 287 | rand_acts = self._get_acts_subset(pos_concept_num, sample_size=sample_size) 288 | self.cavs.append(self._get_cav(rand_acts+pos_acts)) 289 | self.cav_accs.append(self.cavs[-1].acc) 290 | 291 | if null_hypothesis=='gaussian_null': 292 | self.random_cavs = [CAV() for i in range(n_random)] 293 | for cav in self.random_cavs: cav.vec = np.random.normal(size=self.sample_cav.vec.shape) 294 | 295 | if self.cav_accs[0] < 0.8: print('Warning: the CAV accuracy of {0} is low, so TCAV results may not be meaningful'.format(self.cav_accs[0])) 296 | if self.TCAV: return self._get_TCAV_score_significance(null_hypothesis, early_stop=early_stop) 297 | else: return self._get_RCAV_score_significance(null_hypothesis, early_stop=early_stop) 298 | 299 | 300 | class CAV(): 301 | def __init__(self,): 302 | ''' 303 | ''' 304 | self.vec = None 305 | self.acc = None 306 | self.class_balance = None 307 | 308 | def train(self, acts, **logistic_regression_kwargs): 309 | ''' 310 | Args: 311 | acts: list of [act,concept] pairs. 312 | logistic_regression_kwargs: hyper_parameters for linear model 313 | ''' 314 | # convert acts into arrays for training 315 | X,Y = [],[] 316 | for act,concept in acts: 317 | X.append(act) 318 | Y.append(concept) 319 | X,Y = np.array(X),np.array(Y) 320 | 321 | X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.33, stratify=Y) 322 | if 'alpha' not in logistic_regression_kwargs: lm = linear_model.SGDClassifier(alpha=.01, max_iter=1000, tol=1e-3, **logistic_regression_kwargs) 323 | else: lm = linear_model.SGDClassifier(max_iter=1000, tol=1e-3, **logistic_regression_kwargs) 324 | lm.fit(X_train, Y_train) 325 | self.acc = lm.score(X_test,Y_test) 326 | self.vec = lm.coef_[0] 327 | self.class_balance = np.unique(Y,return_counts=True)[1][1]/len(Y) -------------------------------------------------------------------------------- /rcav_env.yml: -------------------------------------------------------------------------------- 1 | name: rcav_env 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - cudatoolkit 9 | - cudnn 10 | - numpy 11 | - numpy-base 12 | - opencv 13 | - pandas 14 | - pandoc 15 | - pillow 16 | - pip 17 | - pylint 18 | - python>3.7.0 19 | - python-dateutil 20 | - pytorch>1.3.0 21 | - scikit-learn 22 | - scipy 23 | - seaborn 24 | - torchvision 25 | - matplotlib 26 | - pip 27 | - tqdm -------------------------------------------------------------------------------- /rcav_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class ActList(): 4 | ''' 5 | Doubles as gradient recording when used for backward hooks 6 | Args: 7 | forward: Bool, whether to forward or backward hook 8 | Returns: 9 | ''' 10 | def __init__(self, forward=True): 11 | self.acts = [] 12 | self.forward = forward 13 | self.dimensions = None 14 | 15 | def record_output(self, module, input, output): 16 | ''' 17 | Doubles as gradient recording when used for backward hook. Appends activations into self.acts. 18 | ''' 19 | if self.forward and self.dimensions is None: self.dimensions = output.data.cpu().numpy().shape 20 | if self.forward: self.acts.append(output.data.cpu().reshape(-1).numpy()) 21 | else: self.acts.append(output[0].cpu().reshape(-1).numpy()) 22 | 23 | 24 | def get_acts(model, layer_name, dataset, concept_labels,): 25 | ''' 26 | returns list of [act,concept_name] pairs 27 | ''' 28 | act_log = ActList() 29 | hook = None 30 | for name, mod in model.named_modules(): 31 | if name==layer_name: hook = mod.register_forward_hook(act_log.record_output) 32 | if hook is None: 33 | raise NameError(layer_name+'Not found') 34 | model.eval() 35 | 36 | sample = dataset[0] 37 | if type(sample)==dict: 38 | batch_shape = [1]+list(dataset[0]['image'].shape) 39 | else: 40 | batch_shape = [1]+list(dataset[0][0].shape) 41 | for ind in range(len(dataset)): 42 | with torch.no_grad(): 43 | im = dataset.__getitem__(ind) 44 | if type(im)==dict: 45 | im = im['image'].view(*batch_shape).cuda() 46 | else: 47 | im = im[0].view(*batch_shape).cuda() 48 | model(im) 49 | 50 | acts = zip(act_log.acts, concept_labels) 51 | acts = list(map(list,acts)) 52 | 53 | hook.remove() 54 | return acts, act_log.dimensions 55 | 56 | 57 | def get_grads(model, layer_name, dataset, true_class_nums, grad_class_num, num_classes): 58 | ''' 59 | Returns list of gradients for all samples in dataset of class class_num ordered by index 60 | ''' 61 | grad_log = ActList(forward=False) 62 | for name, mod in model.named_modules(): 63 | if name==layer_name: hook = mod.register_backward_hook(grad_log.record_output) 64 | model.eval() 65 | 66 | target = torch.zeros([num_classes]) 67 | target[grad_class_num]=1 68 | target = target.cuda() 69 | sample = dataset[0] 70 | if type(sample)==dict: 71 | batch_shape = [1]+list(dataset[0]['image'].shape) 72 | else: 73 | batch_shape = [1]+list(dataset[0][0].shape) 74 | 75 | for ind in range(len(dataset)): 76 | model.zero_grad() 77 | item = dataset.__getitem__(ind) 78 | if type(item)==dict: 79 | if not item['labels'] in true_class_nums: continue 80 | im = item['image'].view(*batch_shape).cuda() 81 | else: 82 | if not item[1] in true_class_nums: continue 83 | im = item[0].view(*batch_shape).cuda() 84 | out = model(im) 85 | masked_out = torch.sum(out*target) 86 | masked_out.backward() 87 | 88 | grads = grad_log.acts 89 | hook.remove() 90 | return grads -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | pandas 4 | Pillow 5 | pytorch 6 | scikit-learn 7 | scipy 8 | seaborn 9 | torchvision 10 | matplotlib 11 | tqdm -------------------------------------------------------------------------------- /textures/dots1_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/textures/dots1_256.png -------------------------------------------------------------------------------- /textures/dots2_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/textures/dots2_256.png -------------------------------------------------------------------------------- /textures/spiral1_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/textures/spiral1_256.png -------------------------------------------------------------------------------- /textures/spiral2_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/textures/spiral2_256.png -------------------------------------------------------------------------------- /textures/spiral3_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/textures/spiral3_256.png -------------------------------------------------------------------------------- /textures/stripes1_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/textures/stripes1_256.png -------------------------------------------------------------------------------- /textures/stripes2_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/textures/stripes2_256.png -------------------------------------------------------------------------------- /textures/zigzag1_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/textures/zigzag1_256.png -------------------------------------------------------------------------------- /textures/zigzag2_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/textures/zigzag2_256.png -------------------------------------------------------------------------------- /textures/zigzag3_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/textures/zigzag3_256.png -------------------------------------------------------------------------------- /train-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/train-images-idx3-ubyte -------------------------------------------------------------------------------- /train-labels-idx1-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keiserlab/rcav/2ac533e531a59607495a292fe6e598c8f7e34b57/train-labels-idx1-ubyte -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import torch.utils.data as utils 7 | import torch.optim as optim 8 | from torchvision import models 9 | import numpy as np 10 | import pickle 11 | from scipy.special import softmax 12 | import sklearn 13 | import copy as cp 14 | 15 | 16 | class BasicConv2d(nn.Module): 17 | def __init__(self, in_channels, out_channels, **kwargs): 18 | super(BasicConv2d, self).__init__() 19 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 20 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | x = self.bn(x) 25 | return F.relu(x, inplace=True) 26 | 27 | 28 | class im_dataset(utils.Dataset): 29 | def __init__(self, X, Y, transform=None): 30 | assert len(X)==len(Y) 31 | self.X = X 32 | self.Y = Y 33 | self.transform = transform 34 | 35 | def __getitem__(self, index): 36 | img = self.X[index] 37 | labels = self.Y[index] 38 | 39 | if self.transform: return self.transform(img), labels 40 | else: return img, labels 41 | 42 | def __len__(self): 43 | return len(self.X) 44 | 45 | class RunNet(): 46 | def __init__(self, model, criterion, optimizer, n_classes, 47 | schedulers=dict(), save_dir=None, mixup=False): 48 | ''' 49 | schedulers: dict, of the form 'name':[object, metric] where metric may be None 50 | ''' 51 | self.model = model 52 | self.criterion = criterion 53 | self.optimizer = optimizer 54 | self.n_classes = n_classes 55 | 56 | self.schedulers = schedulers 57 | self.results = dict() 58 | self.save_dir = save_dir 59 | self.stop = False # used to flag early stopping i.e. if 'early_stop' in self.schedulers.keys(): 60 | 61 | self.mixup = mixup 62 | 63 | def do_batch(self, inputs, labels, train, other=None): 64 | ''' 65 | ''' 66 | inputs, labels = inputs.cuda(), labels.cuda() 67 | if not self.mixup or not train: 68 | outputs = self.model(inputs) 69 | return outputs, labels 70 | else: 71 | outputs, mix_info = self.model(inputs, target=labels) 72 | y_a, y_b, lam = mix_info 73 | if len(y_a.shape)<2: 74 | zerosa = torch.zeros(len(y_a),self.n_classes).cuda() 75 | zerosa = zerosa.scatter(1,y_a.reshape(-1,1),1) 76 | zerosb = torch.zeros(len(y_b),self.n_classes).cuda() 77 | zerosb = zerosb.scatter(1,y_b.reshape(-1,1),1) 78 | mixed_labels = lam*zerosa+(1-lam)*zerosb 79 | else: 80 | mixed_labels = lam*y_a+(1-lam)*y_b 81 | return outputs, mixed_labels 82 | 83 | def do_epoch(self, loader, train): 84 | ''' 85 | train: bool, whether or not to keep track of and apply gradients 86 | ''' 87 | self.preds, self.label_list, self.losses = [], [], [] 88 | # Loop through loader 89 | if train: 90 | self.model.train() 91 | for i, data in enumerate(loader): 92 | self.optimizer.zero_grad() 93 | if type(data)==dict: 94 | outputs, labels = self.do_batch(data['image'], data['labels'], train) 95 | else: 96 | outputs, labels = self.do_batch(data[0], data[1], train) 97 | if len(labels.shape)==1: loss = self.criterion(outputs, labels) 98 | else: loss = self.criterion(outputs, labels) #Note this labels shape is non-standard 99 | loss.backward() 100 | self.losses.append(loss.item()) 101 | self.optimizer.step() 102 | self.preds.append(outputs.detach()) 103 | self.label_list.append(labels) 104 | self.losses = np.mean(self.losses) 105 | else: 106 | self.model.eval() 107 | with torch.no_grad(): 108 | for i, data in enumerate(loader): 109 | if type(data)==dict: 110 | outputs, labels = self.do_batch(data['image'], data['labels'], train) 111 | else: 112 | outputs, labels = self.do_batch(data[0], data[1], train) 113 | self.preds.append(outputs) 114 | self.label_list.append(labels) 115 | 116 | def get_results(self, name, format_only=False): 117 | ''' 118 | name: str, one of 'train', 'val', 'test', or a variant thereof 119 | ''' 120 | if not name in self.results.keys(): self.results[name] = dict() 121 | self.label_list = np.vstack(torch.cat(self.label_list).cpu().numpy()) 122 | self.preds = softmax(np.vstack(torch.cat(self.preds).cpu().numpy()), axis=1) 123 | if format_only: return 124 | if len(self.label_list.shape)==1: 125 | epoch_results = get_performance_metrics(num_classes=self.n_classes,preds=self.preds,label_list=self.label_list) 126 | else: 127 | epoch_results = get_performance_metrics(num_classes=self.n_classes,preds=self.preds,label_list=self.label_list[:,0]) 128 | if not name in self.results: self.results[name] = dict() 129 | for result in epoch_results.keys(): 130 | if not result in self.results[name]: self.results[name][result] = [] 131 | self.results[name][result].append(epoch_results[result]) 132 | 133 | def schedule(self): 134 | # Apply schedulers 135 | if 'decay' in self.schedulers.keys(): 136 | self.schedulers['decay'][0].step() 137 | if 'plateau' in self.schedulers.keys(): 138 | self.schedulers['plateau'][0].step(self.results['val'][self.schedulers['plateau'][1]][-1]) 139 | if 'early_stop' in self.schedulers.keys(): 140 | self.stop = self.schedulers['early_stop'][0](self.results['val'][self.schedulers['early_stop'][1]][-1], self.model) 141 | 142 | def save(self, fold): 143 | torch.save(self.model.state_dict(), self.save_dir+'/'+'fold'+str(fold)) 144 | with open(self.save_dir+'/'+"results"+str(fold)+".pkl", "wb") as file: 145 | pickle.dump(self.results,file) 146 | 147 | 148 | def get_performance_metrics(num_classes, preds, label_list, metrics=['acc','auprc','auroc','log_loss'], rounding=4): 149 | ''' 150 | num_classes: integer 151 | metrics: list of strings. subset of ['acc','auprc','auroc','log_loss'] 152 | ''' 153 | 154 | results = dict() 155 | if 'auprc' in metrics or 'auroc' in metrics: 156 | if num_classes>2: 157 | relevant_classes = list(np.unique(label_list)) 158 | Y_byclass = [np.array([[0,1] if row==l else [1,0] for row in label_list]) for l in range(num_classes)] 159 | pred_byclass = [cp.deepcopy(preds) for i in range(num_classes)] 160 | for j in range(num_classes): 161 | other = np.sum(pred_byclass[j][:,[i for i in range(num_classes) if i!=j]],axis=1) 162 | pred_byclass[j][:,1] = pred_byclass[j][:,j] 163 | pred_byclass[j][:,0] = other 164 | pred_byclass[j] = pred_byclass[j][:,[0,1]] 165 | 166 | for metric in metrics: 167 | if metric=='acc': results[metric] = round(sklearn.metrics.accuracy_score(label_list,np.argmax(preds,axis=1)),rounding) 168 | elif metric=='auprc': 169 | if num_classes>2: results[metric] = round(np.mean([sklearn.metrics.average_precision_score(Y_byclass[cl],pred_byclass[cl]) for cl in range(num_classes) if cl in relevant_classes]),rounding) 170 | else: results[metric] = round(sklearn.metrics.average_precision_score(label_list,preds[:,1]),rounding) 171 | elif metric=='auroc': 172 | if num_classes>2: results[metric] = round(np.mean([sklearn.metrics.roc_auc_score(Y_byclass[cl],pred_byclass[cl]) for cl in range(num_classes) if cl in relevant_classes]),rounding) 173 | else: results[metric] = round(sklearn.metrics.roc_auc_score(label_list,preds[:,1]),rounding) 174 | elif metric=='log_loss': 175 | sklearn.metrics.log_loss(label_list,preds,labels=np.array(list(range(num_classes)))) 176 | return results --------------------------------------------------------------------------------