├── utils.py
├── model_wrapper.py
├── activations_generator.py
├── cav.py
└── tcav.py

/utils.py:
--------------------------------------------------------------------------------
import os

from skimage import io
from torch.utils.data import Dataset


class CelebADataset(Dataset):
    """Dataset serving the images listed in the 'image' column of a dataframe."""

    def __init__(self, df, source_dir, transform=None):
        # Maps the positional index of the dataframe to the image filename.
        self.images = df.to_dict()['image']
        assert os.path.isdir(source_dir), 'Image source dir is absent!'
        self.source_dir = source_dir
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.source_dir, self.images[index])
        image = io.imread(img_path)
        if self.transform is not None:
            image = self.transform(image)
        return image
--------------------------------------------------------------------------------
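A minimal usage sketch for CelebADataset (not part of the repository): the CSV name, image directory and preprocessing pipeline below are assumptions; the dataframe only needs an 'image' column with filenames relative to source_dir.

# Hypothetical usage; 'concepts.csv' and 'img_align_celeba' are placeholders.
import pandas as pd
from torch.utils.data import DataLoader
from torchvision import transforms

from utils import CelebADataset

df = pd.read_csv('concepts.csv').reset_index(drop=True)  # __getitem__ indexes by position

preprocess = transforms.Compose([
    transforms.ToPILImage(),       # skimage.io.imread returns a numpy array
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = CelebADataset(df, source_dir='img_align_celeba', transform=preprocess)
loader = DataLoader(dataset, batch_size=16, num_workers=4)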
/model_wrapper.py:
--------------------------------------------------------------------------------
# PyTorch model wrapper.
import copy

import torch
from torch.autograd import grad


class ModelWrapper():
    """Simple wrapper that holds a PyTorch model and registers the hooks
    needed to access intermediate activations and their gradients.
    """

    def __init__(self, model=None, bottlenecks=None):
        """Initializes the wrapper with a model and hooks into the bottlenecks.

        Args:
            model (nn.Module): Model to test.
            bottlenecks (dict): Maps names of the model's top-level modules to
                the names under which their activations are stored.
        """
        self.ends = None
        self.y_input = None
        self.loss = None
        self.bottlenecks_gradients = None
        self.bottlenecks_tensors = {}
        self.model = copy.deepcopy(model)
        bottlenecks = bottlenecks or {}

        def save_activation(name):
            """Creates a forward hook that stores a layer's activations.

            Args:
                name (string): Name under which the activations are stored.
            """
            def hook(mod, inp, out):
                """Saves the layer output into the activations dictionary."""
                self.bottlenecks_tensors[name] = out
            return hook

        for name, mod in self.model._modules.items():
            if name in bottlenecks.keys():
                mod.register_forward_hook(save_activation(bottlenecks[name]))

    def _make_gradient_tensors(self, y, bottleneck_name):
        """Computes the gradient of logit y w.r.t. a bottleneck's activations.

        Args:
            y (int): Index of the logit (class).
            bottleneck_name (string): Name of the layer whose activations are used.

        Returns:
            (tuple of torch.Tensor): Gradients of the logit w.r.t. the activations.
        """
        acts = self.bottlenecks_tensors[bottleneck_name]
        # grad_outputs of ones makes this work for batches larger than one.
        return grad(self.ends[:, y], acts,
                    grad_outputs=torch.ones_like(self.ends[:, y]))

    def eval(self):
        """Sets the wrapped model to eval mode, as in plain PyTorch."""
        self.model.eval()

    def train(self):
        """Sets the wrapped model to train mode, as in plain PyTorch."""
        self.model.train()

    def __call__(self, x):
        """Runs a forward pass on the wrapped model and caches its outputs."""
        self.ends = self.model(x)
        return self.ends

    def get_gradient(self, y, bottleneck_name):
        """Returns the gradient at a given bottleneck.

        Args:
            y: Index of the logit (class).
            bottleneck_name: Name of the bottleneck to get gradients w.r.t.

        Returns:
            (tuple of torch.Tensor): Gradients of the logit at that layer.
        """
        self.y_input = y
        return self._make_gradient_tensors(y, bottleneck_name)
--------------------------------------------------------------------------------
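A minimal sketch of how the wrapper might be used (not part of the repository), assuming a torchvision ResNet-18 with its top-level module 'layer4' as the bottleneck; the model choice, bottleneck name and dummy input are illustrative only.

# Illustrative only: model, bottleneck name and input shape are assumptions.
import torch
from torchvision import models

from model_wrapper import ModelWrapper

resnet = models.resnet18()
# Keys are module names in resnet._modules; values are the storage names.
wrapper = ModelWrapper(resnet, bottlenecks={'layer4': 'layer4'})
wrapper.eval()

x = torch.randn(1, 3, 224, 224)            # dummy image batch
logits = wrapper(x)                        # forward pass fills bottlenecks_tensors
acts = wrapper.bottlenecks_tensors['layer4']
grads = wrapper.get_gradient(0, 'layer4')  # tuple from torch.autograd.grad; tensor is grads[0]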
/activations_generator.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from utils import CelebADataset

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class ActivationsGenerator():
    """Generates and stores bottleneck activations for each concept."""

    def __init__(self, model, source_dir, source_df, acts_dir,
                 bottleneck_names, concepts=None, transform=None,
                 max_examples=10):

        assert os.path.isfile(source_df), 'Concepts dataframe is missing!'
        self.concepts_df = pd.read_csv(source_df)
        # The first column holds the image path; the remaining columns are
        # binary concept indicators.
        if concepts is None:
            self.concepts = list(self.concepts_df.columns)[1:]
        else:
            self.concepts = concepts
        self.source_dir = source_dir
        self.acts_dir = acts_dir
        self.model = model
        self.bottleneck_names = bottleneck_names
        # Note: max_examples is counted in batches, not individual images.
        self.max_examples = max_examples
        self.transform = transform

    def generate_acts(self, batch_size=16, num_workers=8, verbose=False):
        """Generates the concept activations from the loaded dataframe and
        saves them to acts_dir.
        """
        # Create acts_dir if it does not exist yet.
        if not os.path.exists(self.acts_dir):
            os.makedirs(self.acts_dir)
        self.model.model.to(device)
        self.model.eval()
        for concept in self.concepts:
            sub_frame = self.concepts_df[self.concepts_df[concept] == 1]
            sub_frame.reset_index(drop=True, inplace=True)
            concept_dataset = CelebADataset(df=sub_frame,
                                            source_dir=self.source_dir,
                                            transform=self.transform)
            concept_loader = DataLoader(concept_dataset, batch_size=batch_size,
                                        num_workers=num_workers)
            acts = {}
            for idx, batch in enumerate(concept_loader):
                if idx == self.max_examples:
                    break
                batch = batch.to(device)
                # Run the batch through the model so the forward hooks
                # capture the bottleneck activations.
                _ = self.model(batch)
                for bottleneck in self.bottleneck_names:
                    if bottleneck not in acts.keys():
                        acts[bottleneck] = (self.model.
                                            bottlenecks_tensors[bottleneck].
                                            cpu().detach().numpy())
                    else:
                        acts[bottleneck] = np.append(
                            acts[bottleneck],
                            self.model.bottlenecks_tensors[bottleneck].cpu().
                            detach().numpy(), axis=0)
                if verbose:
                    print("[{}/{}]".format(idx, len(concept_loader)))

            for bottleneck in self.bottleneck_names:
                acts_path = os.path.join(self.acts_dir, 'acts_{}_{}'
                                         .format(concept, bottleneck))
                np.save(acts_path, acts[bottleneck])

    def load_activations(self):
        """Loads previously saved activations into a nested dictionary keyed
        by concept and bottleneck name.
        """
        acts = {}
        for concept in self.concepts:
            if concept not in acts:
                acts[concept] = {}
            for bottleneck in self.bottleneck_names:
                acts_path = os.path.join(self.acts_dir, 'acts_{}_{}.npy'
                                         .format(concept, bottleneck))
                acts[concept][bottleneck] = np.load(acts_path)

        return acts
--------------------------------------------------------------------------------
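A sketch of generating and loading activations (not part of the repository); the CSV, image directory, output directory and concept names are placeholders, and the CSV is assumed to have an 'image' column followed by binary concept columns.

# Illustrative only: paths, concept names and bottleneck name are assumptions.
from torchvision import models, transforms

from activations_generator import ActivationsGenerator
from model_wrapper import ModelWrapper

wrapper = ModelWrapper(models.resnet18(), bottlenecks={'layer4': 'layer4'})

preprocess = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

acts_gen = ActivationsGenerator(model=wrapper,
                                source_dir='img_align_celeba',
                                source_df='concepts.csv',
                                acts_dir='activations',
                                bottleneck_names=['layer4'],
                                concepts=['Smiling', 'Eyeglasses'],
                                transform=preprocess,
                                max_examples=10)   # counted in batches
acts_gen.generate_acts(batch_size=16, verbose=True)
acts = acts_gen.load_activations()                 # acts[concept][bottleneck]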
/cav.py:
--------------------------------------------------------------------------------
import pickle

import numpy as np
from sklearn import linear_model
from sklearn import metrics
from sklearn.model_selection import train_test_split


class CAV(object):

    @staticmethod
    def create_training_set(concepts, bottleneck, acts):
        """Creates and formats the training set for a given set of CAVs.

        Args:
            concepts (list): List of concepts in the CAV training dataset.
            bottleneck (string): Bottleneck name for the CAVs.
            acts (dict): Nested dictionary of activations, keyed by concept
                and bottleneck name.

        Returns:
            x (np.array): NumPy array of flattened activations.
            labels (np.array): Labels associated with the activations.
            labels2text (dict): Dictionary mapping labels to concept names.
        """
        x = []
        labels = []
        labels2text = {}
        # Balance the classes so training is not skewed by class size.
        min_data_points = np.min([acts[concept][bottleneck].shape[0]
                                  for concept in concepts])
        # Flatten the activations as input for the classifier.
        for idx, concept in enumerate(concepts):
            x.extend(acts[concept][bottleneck][:min_data_points].reshape(
                min_data_points, -1))
            labels.extend([idx] * min_data_points)
            labels2text[idx] = concept
        x = np.array(x)
        labels = np.array(labels)
        return x, labels, labels2text

    def __init__(self, concepts, bottleneck, hparams, save_path=None):

        self.bottleneck = bottleneck
        self.hparams = hparams
        self.save_path = save_path
        self.concepts = concepts

    def train(self, acts):
        x, labels, labels2text = CAV.create_training_set(self.concepts,
                                                         self.bottleneck,
                                                         acts)

        if self.hparams.model_type == 'linear':
            lm = linear_model.SGDClassifier(alpha=self.hparams.alpha)
        elif self.hparams.model_type == 'logistic':
            lm = linear_model.LogisticRegression()
        else:
            raise ValueError('Invalid model type: {}'.format(
                self.hparams.model_type))

        self.accuracies = self.train_lm(lm, x, labels, labels2text)
        if len(lm.coef_) == 1:
            # With only two labels the classifier learns a single hyperplane
            # and the concept is assigned to label 0 by default, so the
            # coefficients are flipped to get one CAV per label.
            self.cavs = [-1 * lm.coef_[0], lm.coef_[0]]
        else:
            self.cavs = [c for c in lm.coef_]
        self.save_cavs()

    def train_lm(self, lm, x, y, labels2text):
        x_train, x_test, y_train, y_test = train_test_split(x, y,
                                                            test_size=0.2,
                                                            stratify=y)
        lm.fit(x_train, y_train)
        y_pred = lm.predict(x_test)
        # Compute the accuracy for each class.
        num_classes = max(y) + 1
        acc = {}
        num_correct = 0
        for class_id in range(num_classes):
            # Indices of all test samples that belong to this class.
            idx = (y_test == class_id)
            acc[labels2text[class_id]] = metrics.accuracy_score(
                y_test[idx], y_pred[idx])
            # Overall correctness is weighted by the number of examples in
            # this class.
            num_correct += (sum(idx) * acc[labels2text[class_id]])
        acc['overall'] = float(num_correct) / float(len(y_test))
        return acc

    def save_cavs(self):
        """Saves a dictionary describing this CAV to a pickle file."""
        save_dict = {
            'concepts': self.concepts,
            'bottleneck': self.bottleneck,
            'hparams': self.hparams,
            'accuracies': self.accuracies,
            'cavs': self.cavs,
            'saved_path': self.save_path
        }
        if self.save_path is not None:
            with open(self.save_path, 'wb') as pkl_file:
                pickle.dump(save_dict, pkl_file)
        else:
            print('save_path is None. Not saving anything')
--------------------------------------------------------------------------------
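A sketch of training a CAV from loaded activations (not part of the repository): the hparams object only needs model_type and alpha attributes, so a SimpleNamespace stands in for whatever hyperparameter container the rest of the pipeline uses, and random arrays stand in for real activations.

# Illustrative only: the SimpleNamespace hparams, concept names and random
# stand-in activations are assumptions.
from types import SimpleNamespace

import numpy as np

from cav import CAV

# Fake activations shaped like ActivationsGenerator.load_activations() output:
# acts[concept][bottleneck] -> array of shape (num_images, C, H, W).
rng = np.random.default_rng(0)
acts = {c: {'layer4': rng.normal(size=(40, 64, 4, 4))}
        for c in ['Smiling', 'Eyeglasses']}

hparams = SimpleNamespace(model_type='linear', alpha=0.01)
cav = CAV(concepts=['Smiling', 'Eyeglasses'], bottleneck='layer4',
          hparams=hparams, save_path='cav_smiling_eyeglasses_layer4.pkl')
cav.train(acts)

print(cav.accuracies['overall'])   # held-out accuracy of the concept classifier
print(cav.cavs[0].shape)           # one flattened CAV direction per concept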
/tcav.py:
--------------------------------------------------------------------------------
import time
from cav import CAV
from cav import get_or_train_cav
import numpy as np
import run_params
import utils


class TCAV(object):
    """TCAV object: runs TCAV for one target and a set of concepts.

    The static methods (get_direction_dir_sign, compute_tcav_score,
    get_directional_dir) involve getting directional derivatives and
    calculating TCAV scores. They are static because they can be useful on
    their own, for instance when developing a new interpretability method
    based on CAVs.
    """

    @staticmethod
    def get_direction_dir_sign(mymodel, act, cav, concept, class_id):
        """Gets the sign of the directional derivative.

        Args:
            mymodel: a model wrapper instance.
            act: activations of one bottleneck to get the gradient with
                respect to.
            cav: an instance of CAV.
            concept: one concept.
            class_id: index of the class of interest (target) in the logit
                layer.

        Returns:
            Sign of the directional derivative: True when the dot product of
            the loss gradient with the CAV is negative.
        """
        # The loss gradient points in the direction that decreases the class
        # probability, so a negative dot product means the concept direction
        # increases it.
        grad = np.reshape(
            mymodel.get_gradient(act, [class_id], cav.bottleneck), -1)
        dot_prod = np.dot(grad, cav.get_direction(concept))
        return dot_prod < 0

    @staticmethod
    def compute_tcav_score(mymodel, target_class, concept, cav,
                           class_acts):
        """Computes the TCAV score.

        Args:
            mymodel: a model wrapper instance.
            target_class: one target class.
            concept: one concept.
            cav: an instance of CAV.
            class_acts: activations of the images in the target class.

        Returns:
            TCAV score, i.e. the fraction of target-class images whose loss
            gradient has a negative dot product with the CAV.
        """
        count = 0
        class_id = mymodel.label_to_id(target_class)
        for i in range(len(class_acts)):
            act = np.expand_dims(class_acts[i], 0)
            if TCAV.get_direction_dir_sign(mymodel, act, cav, concept,
                                           class_id):
                count += 1
        return float(count) / float(len(class_acts))

    def __init__(self, target, concepts, bottlenecks, activation_generator,
                 alphas, random_counterpart, cav_dir, num_random_exp,
                 random_concepts):
        """Initializes the TCAV class.

        Args:
            target: one target class.
            concepts: a list of concepts.
            bottlenecks: names of the bottlenecks of interest.
            activation_generator: an ActivationGeneratorInterface instance to
                return activations.
            alphas: list of hyperparameters to run.
            random_counterpart: the random concept to run against the concepts
                for statistical testing.
            cav_dir: the path where CAVs are stored.
            num_random_exp: number of random experiments to compare against.
            random_concepts: a list of names of random concepts for the random
                experiments to draw from. Optional; if not provided, the names
                will be random500_{i} for i in num_random_exp.
        """
        self.target = target
        self.concepts = concepts
        self.bottlenecks = bottlenecks
        self.activation_generator = activation_generator
        self.cav_dir = cav_dir
        self.alphas = alphas
        self.random_counterpart = random_counterpart
        self.mymodel = activation_generator.get_model()
        self.model_to_run = self.mymodel.model_name

        # Make the (target, concepts) pairs to test.
        self._process_what_to_run_expand(num_random_exp=num_random_exp,
                                         random_concepts=random_concepts)
        # Parameters for the individual runs.
        self.params = self.get_params()

    def run(self):
        """Runs TCAV for all parameter sets (concept and random).

        Returns:
            results: a list of result dictionaries, one per parameter set.
        """
        # For the random experiments, a machine with 30 CPUs, 300G of RAM,
        # 10G of disk and a pool of 50 workers seems to work.
        now = time.time()
        results = []
        for param in self.params:
            results.append(self._run_single_set(param))
        return results
    def _run_single_set(self, param):
        """Runs TCAV for one set of (target, concepts).

        Args:
            param: parameters of this run.

        Returns:
            a dictionary of results for this run.
        """
        bottleneck = param.bottleneck
        concepts = param.concepts
        target_class = param.target_class
        activation_generator = param.activation_generator
        alpha = param.alpha
        mymodel = param.model
        cav_dir = param.cav_dir

        # Get activations.
        acts = activation_generator.process_and_load_activations(
            [bottleneck], concepts + [target_class])
        # Get CAVs.
        cav_hparams = CAV.default_hparams()
        cav_hparams.alpha = alpha
        cav_instance = get_or_train_cav(
            concepts, bottleneck, acts, cav_dir=cav_dir,
            cav_hparams=cav_hparams)

        # Free the concept activations; only the target-class activations are
        # needed below.
        for c in concepts:
            del acts[c]

        # Hypothesis testing.
        a_cav_key = CAV.cav_key(concepts, bottleneck, cav_hparams.model_type,
                                cav_hparams.alpha)
        target_class_for_compute_tcav_score = target_class

        for cav_concept in concepts:
            if cav_concept == self.random_counterpart or 'random' not in cav_concept:
                i_up = self.compute_tcav_score(
                    mymodel, target_class_for_compute_tcav_score, cav_concept,
                    cav_instance, acts[target_class][cav_instance.bottleneck])
                val_directional_dirs = self.get_directional_dir(
                    mymodel, target_class_for_compute_tcav_score, cav_concept,
                    cav_instance, acts[target_class][cav_instance.bottleneck])
                result = {'cav_key': a_cav_key,
                          'cav_concept': cav_concept,
                          'target_class': target_class,
                          'i_up': i_up,
                          'val_directional_dirs_abs_mean':
                              np.mean(np.abs(val_directional_dirs)),
                          'val_directional_dirs_mean':
                              np.mean(val_directional_dirs),
                          'val_directional_dirs_std':
                              np.std(val_directional_dirs),
                          'note': 'alpha_%s ' % (alpha),
                          'alpha': alpha,
                          'bottleneck': bottleneck}
        del acts
        return result

    # def _process_what_to_run_expand(self, num_random_exp=100, random_concepts=None):
    #     """Gets tuples of parameters to run TCAV with.
    #
    #     TCAV builds random concepts to conduct statistical significance
    #     testing against the concept. To do this, we build many concept
    #     vectors and many random vectors. This function prepares the runs by
    #     expanding the parameters.
    #
    #     Args:
    #         num_random_exp: number of random experiments to run to compare.
    #         random_concepts: a list of names of random concepts for the
    #             random experiments to draw from. Optional; if not provided,
    #             the names will be random500_{i} for i in num_random_exp.
    #     """
    #     target_concept_pairs = [(self.target, self.concepts)]
    #
    #     all_concepts_concepts, pairs_to_run_concepts = (
    #         utils.process_what_to_run_expand(
    #             utils.process_what_to_run_concepts(target_concept_pairs),
    #             self.random_counterpart,
    #             num_random_exp=num_random_exp,
    #             random_concepts=random_concepts))
    #     all_concepts_randoms, pairs_to_run_randoms = (
    #         utils.process_what_to_run_expand(
    #             utils.process_what_to_run_randoms(target_concept_pairs,
    #                                               self.random_counterpart),
    #             self.random_counterpart,
    #             num_random_exp=num_random_exp,
    #             random_concepts=random_concepts))
    #     self.all_concepts = list(set(all_concepts_concepts + all_concepts_randoms))
    #     self.pairs_to_test = pairs_to_run_concepts + pairs_to_run_randoms
    def get_params(self):
        """Enumerates the parameter sets for the run function.

        Returns:
            a list of RunParams, one per (bottleneck, pair, alpha) combination.
        """
        params = []
        for bottleneck in self.bottlenecks:
            for target_in_test, concepts_in_test in self.pairs_to_test:
                for alpha in self.alphas:
                    params.append(
                        run_params.RunParams(bottleneck, concepts_in_test,
                                             target_in_test,
                                             self.activation_generator,
                                             self.cav_dir, alpha,
                                             self.mymodel))
        return params
--------------------------------------------------------------------------------
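Note that tcav.py still references helpers that are not included here (run_params, cav.get_or_train_cav, a get_directional_dir method, and the utils.process_what_to_run_* functions behind the commented-out _process_what_to_run_expand), so it does not run as-is. The score it ultimately computes is small enough to show on its own; below is a self-contained numpy sketch of the TCAV score in the spirit of compute_tcav_score and get_direction_dir_sign, using random stand-ins for the per-image loss gradients and the CAV.

# Self-contained sketch of the TCAV score; the gradients and the CAV are
# random stand-ins, not real model outputs.
import numpy as np

rng = np.random.default_rng(0)
num_images, num_features = 100, 512

# One flattened loss gradient per target-class image at the bottleneck.
loss_grads = rng.normal(size=(num_images, num_features))
# Concept activation vector (CAV) for one concept at the same bottleneck.
cav_direction = rng.normal(size=num_features)

# An image counts as sensitive to the concept when the dot product of its
# loss gradient with the CAV is negative, i.e. the concept direction
# increases the target-class probability.
dot_products = loss_grads @ cav_direction
tcav_score = float(np.mean(dot_products < 0))
print('TCAV score:', tcav_score)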