├── .gitignore ├── .ipynb_checkpoints ├── DeiT_example-checkpoint.ipynb └── example-checkpoint.ipynb ├── BERT_explainability.ipynb ├── BERT_explainability └── modules │ ├── BERT │ ├── BERT.py │ ├── BERT_cls_lrp.py │ ├── BERT_orig_lrp.py │ ├── BertForSequenceClassification.py │ └── ExplanationGenerator.py │ ├── __init__.py │ ├── layers_lrp.py │ └── layers_ours.py ├── BERT_params ├── boolq.json ├── boolq_baas.json ├── boolq_bert.json ├── boolq_soft.json ├── cose_bert.json ├── cose_multiclass.json ├── esnli_bert.json ├── evidence_inference.json ├── evidence_inference_bert.json ├── evidence_inference_soft.json ├── fever.json ├── fever_baas.json ├── fever_bert.json ├── fever_soft.json ├── movies.json ├── movies_baas.json ├── movies_bert.json ├── movies_soft.json ├── multirc.json ├── multirc_baas.json ├── multirc_bert.json └── multirc_soft.json ├── BERT_rationale_benchmark ├── __init__.py ├── metrics.py ├── models │ ├── model_utils.py │ ├── pipeline │ │ ├── __init__.py │ │ ├── bert_pipeline.py │ │ ├── pipeline_train.py │ │ └── pipeline_utils.py │ └── sequence_taggers.py └── utils.py ├── DeiT.PNG ├── DeiT_example.ipynb ├── LICENSE ├── README.md ├── Transformer_explainability.ipynb ├── baselines └── ViT │ ├── ViT_LRP.py │ ├── ViT_explanation_generator.py │ ├── ViT_new.py │ ├── ViT_orig_LRP.py │ ├── generate_visualizations.py │ ├── helpers.py │ ├── imagenet_seg_eval.py │ ├── layer_helpers.py │ ├── misc_functions.py │ ├── pertubation_eval_from_hdf5.py │ └── weight_init.py ├── data ├── Imagenet.py ├── VOC.py ├── __init__.py ├── imagenet.py ├── imagenet_utils.py └── transforms.py ├── dataset └── expl_hdf5.py ├── example.PNG ├── example.ipynb ├── method-page-001.jpg ├── modules ├── __init__.py ├── layers_lrp.py └── layers_ours.py ├── new_work.jpg ├── requirements.txt ├── samples ├── CLS2IDX.py ├── catdog.png ├── dogbird.png ├── dogcat2.png ├── el1.png ├── el2.png ├── el3.png ├── el4.png └── el5.png └── utils ├── __init__.py ├── confusionmatrix.py ├── iou.py ├── metric.py ├── metrices.py ├── parallel.py ├── render.py ├── saver.py └── summaries.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | all_good_vis/ 3 | __pycache__ 4 | *.tar 5 | .idea 6 | run/ 7 | baselines/ViT/experiments/ 8 | baselines/ViT/visualizations/ 9 | bert_models/ 10 | data/movies/ 11 | 12 | -------------------------------------------------------------------------------- /BERT_explainability/modules/BERT/BERT_cls_lrp.py: -------------------------------------------------------------------------------- 1 | from transformers import BertPreTrainedModel 2 | from transformers.utils import logging 3 | from BERT_explainability.modules.layers_lrp import * 4 | from BERT_explainability.modules.BERT.BERT_orig_lrp import BertModel 5 | from torch.nn import CrossEntropyLoss, MSELoss 6 | import torch.nn as nn 7 | from typing import List, Any 8 | import torch 9 | from BERT_rationale_benchmark.models.model_utils import PaddedSequence 10 | 11 | 12 | class BertForSequenceClassification(BertPreTrainedModel): 13 | def __init__(self, config): 14 | super().__init__(config) 15 | self.num_labels = config.num_labels 16 | 17 | self.bert = BertModel(config) 18 | self.dropout = Dropout(config.hidden_dropout_prob) 19 | self.classifier = Linear(config.hidden_size, config.num_labels) 20 | 21 | self.init_weights() 22 | 23 | def forward( 24 | self, 25 | input_ids=None, 26 | attention_mask=None, 27 | token_type_ids=None, 28 | position_ids=None, 29 | head_mask=None, 30 | inputs_embeds=None, 31 | labels=None, 32 | 
output_attentions=None, 33 | output_hidden_states=None, 34 | return_dict=None, 35 | ): 36 | r""" 37 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 38 | Labels for computing the sequence classification/regression loss. 39 | Indices should be in :obj:`[0, ..., config.num_labels - 1]`. 40 | If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 41 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 42 | """ 43 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 44 | 45 | outputs = self.bert( 46 | input_ids, 47 | attention_mask=attention_mask, 48 | token_type_ids=token_type_ids, 49 | position_ids=position_ids, 50 | head_mask=head_mask, 51 | inputs_embeds=inputs_embeds, 52 | output_attentions=output_attentions, 53 | output_hidden_states=output_hidden_states, 54 | return_dict=return_dict, 55 | ) 56 | 57 | pooled_output = outputs[1] 58 | 59 | pooled_output = self.dropout(pooled_output) 60 | logits = self.classifier(pooled_output) 61 | 62 | loss = None 63 | if labels is not None: 64 | if self.num_labels == 1: 65 | # We are doing regression 66 | loss_fct = MSELoss() 67 | loss = loss_fct(logits.view(-1), labels.view(-1)) 68 | else: 69 | loss_fct = CrossEntropyLoss() 70 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 71 | 72 | if not return_dict: 73 | output = (logits,) + outputs[2:] 74 | return ((loss,) + output) if loss is not None else output 75 | 76 | return SequenceClassifierOutput( 77 | loss=loss, 78 | logits=logits, 79 | hidden_states=outputs.hidden_states, 80 | attentions=outputs.attentions, 81 | ) 82 | 83 | def relprop(self, cam=None, **kwargs): 84 | cam = self.classifier.relprop(cam, **kwargs) 85 | cam = self.dropout.relprop(cam, **kwargs) 86 | cam = self.bert.relprop(cam, **kwargs) 87 | return cam 88 | 89 | 90 | # this is the actual classifier we will be using 91 | class BertClassifier(nn.Module): 92 | """Thin wrapper around BertForSequenceClassification""" 93 | 94 | def __init__(self, 95 | bert_dir: str, 96 | pad_token_id: int, 97 | cls_token_id: int, 98 | sep_token_id: int, 99 | num_labels: int, 100 | max_length: int = 512, 101 | use_half_precision=True): 102 | super(BertClassifier, self).__init__() 103 | bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels) 104 | if use_half_precision: 105 | import apex 106 | bert = bert.half() 107 | self.bert = bert 108 | self.pad_token_id = pad_token_id 109 | self.cls_token_id = cls_token_id 110 | self.sep_token_id = sep_token_id 111 | self.max_length = max_length 112 | 113 | def forward(self, 114 | query: List[torch.tensor], 115 | docids: List[Any], 116 | document_batch: List[torch.tensor]): 117 | assert len(query) == len(document_batch) 118 | print(query) 119 | # note about device management: 120 | # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module) 121 | # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access 122 | target_device = next(self.parameters()).device 123 | cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device) 124 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device) 125 | input_tensors = [] 126 | position_ids = [] 127 | for q, d in zip(query, document_batch): 128 | if len(q) + len(d) + 2 > self.max_length: 129 | d = d[:(self.max_length - len(q) - 2)] 130 | 
input_tensors.append(torch.cat([cls_token, q, sep_token, d])) 131 | position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))) 132 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, 133 | device=target_device) 134 | positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device) 135 | (classes,) = self.bert(bert_input.data, 136 | attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device), 137 | position_ids=positions.data) 138 | assert torch.all(classes == classes) # for nans 139 | 140 | print(input_tensors[0]) 141 | print(self.relprop()[0]) 142 | 143 | return classes 144 | 145 | def relprop(self, cam=None, **kwargs): 146 | return self.bert.relprop(cam, **kwargs) 147 | 148 | 149 | if __name__ == '__main__': 150 | from transformers import BertTokenizer 151 | import os 152 | 153 | class Config: 154 | def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, num_labels, 155 | hidden_dropout_prob): 156 | self.hidden_size = hidden_size 157 | self.num_attention_heads = num_attention_heads 158 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 159 | self.num_labels = num_labels 160 | self.hidden_dropout_prob = hidden_dropout_prob 161 | 162 | 163 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 164 | x = tokenizer.encode_plus("In this movie the acting is great. The movie is perfect! [sep]", 165 | add_special_tokens=True, 166 | max_length=512, 167 | return_token_type_ids=False, 168 | return_attention_mask=True, 169 | pad_to_max_length=True, 170 | return_tensors='pt', 171 | truncation=True) 172 | 173 | print(x['input_ids']) 174 | 175 | model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2) 176 | model_save_file = os.path.join('./BERT_explainability/output_bert/movies/classifier/', 'classifier.pt') 177 | model.load_state_dict(torch.load(model_save_file)) 178 | 179 | # x = torch.randint(100, (2, 20)) 180 | # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102, 181 | # 101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101, 182 | # 2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005, 183 | # 102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102, 184 | # 101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101, 185 | # 2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010, 186 | # 102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102, 187 | # 101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101, 188 | # 1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054, 189 | # 102, 101, 1012, 102]]) 190 | # x.requires_grad_() 191 | 192 | model.eval() 193 | 194 | y = model(x['input_ids'], x['attention_mask']) 195 | print(y) 196 | 197 | cam, _ = model.relprop() 198 | 199 | #print(cam.shape) 200 | 201 | cam = cam.sum(-1) 202 | #print(cam) 203 | -------------------------------------------------------------------------------- /BERT_explainability/modules/BERT/BertForSequenceClassification.py: -------------------------------------------------------------------------------- 1 | from transformers import BertPreTrainedModel 2 | from transformers.utils import logging 3 | from BERT_explainability.modules.layers_ours import * 4 | from BERT_explainability.modules.BERT.BERT import BertModel 5 | from torch.nn import CrossEntropyLoss, MSELoss 6 | import torch.nn as nn 7 | from typing import List, Any 8 | import torch 9 | from BERT_rationale_benchmark.models.model_utils import 
PaddedSequence 10 | 11 | 12 | class BertForSequenceClassification(BertPreTrainedModel): 13 | def __init__(self, config): 14 | super().__init__(config) 15 | self.num_labels = config.num_labels 16 | 17 | self.bert = BertModel(config) 18 | self.dropout = Dropout(config.hidden_dropout_prob) 19 | self.classifier = Linear(config.hidden_size, config.num_labels) 20 | 21 | self.init_weights() 22 | 23 | def forward( 24 | self, 25 | input_ids=None, 26 | attention_mask=None, 27 | token_type_ids=None, 28 | position_ids=None, 29 | head_mask=None, 30 | inputs_embeds=None, 31 | labels=None, 32 | output_attentions=None, 33 | output_hidden_states=None, 34 | return_dict=None, 35 | ): 36 | r""" 37 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 38 | Labels for computing the sequence classification/regression loss. 39 | Indices should be in :obj:`[0, ..., config.num_labels - 1]`. 40 | If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 41 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 42 | """ 43 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 44 | 45 | outputs = self.bert( 46 | input_ids, 47 | attention_mask=attention_mask, 48 | token_type_ids=token_type_ids, 49 | position_ids=position_ids, 50 | head_mask=head_mask, 51 | inputs_embeds=inputs_embeds, 52 | output_attentions=output_attentions, 53 | output_hidden_states=output_hidden_states, 54 | return_dict=return_dict, 55 | ) 56 | 57 | pooled_output = outputs[1] 58 | 59 | pooled_output = self.dropout(pooled_output) 60 | logits = self.classifier(pooled_output) 61 | 62 | loss = None 63 | if labels is not None: 64 | if self.num_labels == 1: 65 | # We are doing regression 66 | loss_fct = MSELoss() 67 | loss = loss_fct(logits.view(-1), labels.view(-1)) 68 | else: 69 | loss_fct = CrossEntropyLoss() 70 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 71 | 72 | if not return_dict: 73 | output = (logits,) + outputs[2:] 74 | return ((loss,) + output) if loss is not None else output 75 | 76 | return SequenceClassifierOutput( 77 | loss=loss, 78 | logits=logits, 79 | hidden_states=outputs.hidden_states, 80 | attentions=outputs.attentions, 81 | ) 82 | 83 | def relprop(self, cam=None, **kwargs): 84 | cam = self.classifier.relprop(cam, **kwargs) 85 | cam = self.dropout.relprop(cam, **kwargs) 86 | cam = self.bert.relprop(cam, **kwargs) 87 | # print("conservation: ", cam.sum()) 88 | return cam 89 | 90 | 91 | # this is the actual classifier we will be using 92 | class BertClassifier(nn.Module): 93 | """Thin wrapper around BertForSequenceClassification""" 94 | 95 | def __init__(self, 96 | bert_dir: str, 97 | pad_token_id: int, 98 | cls_token_id: int, 99 | sep_token_id: int, 100 | num_labels: int, 101 | max_length: int = 512, 102 | use_half_precision=True): 103 | super(BertClassifier, self).__init__() 104 | bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels) 105 | if use_half_precision: 106 | import apex 107 | bert = bert.half() 108 | self.bert = bert 109 | self.pad_token_id = pad_token_id 110 | self.cls_token_id = cls_token_id 111 | self.sep_token_id = sep_token_id 112 | self.max_length = max_length 113 | 114 | def forward(self, 115 | query: List[torch.tensor], 116 | docids: List[Any], 117 | document_batch: List[torch.tensor]): 118 | assert len(query) == len(document_batch) 119 | print(query) 120 | # note about device management: 121 | # since distributed training is enabled, the inputs to 
this module can be on *any* device (preferably cpu, since we wrap and unwrap the module) 122 | # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access 123 | target_device = next(self.parameters()).device 124 | cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device) 125 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device) 126 | input_tensors = [] 127 | position_ids = [] 128 | for q, d in zip(query, document_batch): 129 | if len(q) + len(d) + 2 > self.max_length: 130 | d = d[:(self.max_length - len(q) - 2)] 131 | input_tensors.append(torch.cat([cls_token, q, sep_token, d])) 132 | position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))) 133 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, 134 | device=target_device) 135 | positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device) 136 | (classes,) = self.bert(bert_input.data, 137 | attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device), 138 | position_ids=positions.data) 139 | assert torch.all(classes == classes) # for nans 140 | 141 | print(input_tensors[0]) 142 | print(self.relprop()[0]) 143 | 144 | return classes 145 | 146 | def relprop(self, cam=None, **kwargs): 147 | return self.bert.relprop(cam, **kwargs) 148 | 149 | 150 | if __name__ == '__main__': 151 | from transformers import BertTokenizer 152 | import os 153 | 154 | class Config: 155 | def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, num_labels, 156 | hidden_dropout_prob): 157 | self.hidden_size = hidden_size 158 | self.num_attention_heads = num_attention_heads 159 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 160 | self.num_labels = num_labels 161 | self.hidden_dropout_prob = hidden_dropout_prob 162 | 163 | 164 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 165 | x = tokenizer.encode_plus("In this movie the acting is great. The movie is perfect! 
[sep]", 166 | add_special_tokens=True, 167 | max_length=512, 168 | return_token_type_ids=False, 169 | return_attention_mask=True, 170 | pad_to_max_length=True, 171 | return_tensors='pt', 172 | truncation=True) 173 | 174 | print(x['input_ids']) 175 | 176 | model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2) 177 | model_save_file = os.path.join('./BERT_explainability/output_bert/movies/classifier/', 'classifier.pt') 178 | model.load_state_dict(torch.load(model_save_file)) 179 | 180 | # x = torch.randint(100, (2, 20)) 181 | # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102, 182 | # 101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101, 183 | # 2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005, 184 | # 102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102, 185 | # 101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101, 186 | # 2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010, 187 | # 102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102, 188 | # 101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101, 189 | # 1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054, 190 | # 102, 101, 1012, 102]]) 191 | # x.requires_grad_() 192 | 193 | model.eval() 194 | 195 | y = model(x['input_ids'], x['attention_mask']) 196 | print(y) 197 | 198 | cam, _ = model.relprop() 199 | 200 | #print(cam.shape) 201 | 202 | cam = cam.sum(-1) 203 | #print(cam) 204 | -------------------------------------------------------------------------------- /BERT_explainability/modules/BERT/ExplanationGenerator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import glob 5 | 6 | # compute rollout between attention layers 7 | def compute_rollout_attention(all_layer_matrices, start_layer=0): 8 | # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow 9 | num_tokens = all_layer_matrices[0].shape[1] 10 | batch_size = all_layer_matrices[0].shape[0] 11 | eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device) 12 | all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))] 13 | matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True) 14 | for i in range(len(all_layer_matrices))] 15 | joint_attention = matrices_aug[start_layer] 16 | for i in range(start_layer+1, len(matrices_aug)): 17 | joint_attention = matrices_aug[i].bmm(joint_attention) 18 | return joint_attention 19 | 20 | class Generator: 21 | def __init__(self, model): 22 | self.model = model 23 | self.model.eval() 24 | 25 | def forward(self, input_ids, attention_mask): 26 | return self.model(input_ids, attention_mask) 27 | 28 | def generate_LRP(self, input_ids, attention_mask, 29 | index=None, start_layer=11): 30 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0] 31 | kwargs = {"alpha": 1} 32 | 33 | if index == None: 34 | index = np.argmax(output.cpu().data.numpy(), axis=-1) 35 | 36 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) 37 | one_hot[0, index] = 1 38 | one_hot_vector = one_hot 39 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 40 | one_hot = torch.sum(one_hot.cuda() * output) 41 | 42 | self.model.zero_grad() 43 | one_hot.backward(retain_graph=True) 44 | 45 | self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs) 46 | 47 | cams = [] 48 | blocks = self.model.bert.encoder.layer 49 | for blk in 
blocks: 50 | grad = blk.attention.self.get_attn_gradients() 51 | cam = blk.attention.self.get_attn_cam() 52 | cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1]) 53 | grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1]) 54 | cam = grad * cam 55 | cam = cam.clamp(min=0).mean(dim=0) 56 | cams.append(cam.unsqueeze(0)) 57 | rollout = compute_rollout_attention(cams, start_layer=start_layer) 58 | rollout[:, 0, 0] = rollout[:, 0].min() 59 | return rollout[:, 0] 60 | 61 | 62 | def generate_LRP_last_layer(self, input_ids, attention_mask, 63 | index=None): 64 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0] 65 | kwargs = {"alpha": 1} 66 | if index == None: 67 | index = np.argmax(output.cpu().data.numpy(), axis=-1) 68 | 69 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) 70 | one_hot[0, index] = 1 71 | one_hot_vector = one_hot 72 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 73 | one_hot = torch.sum(one_hot.cuda() * output) 74 | 75 | self.model.zero_grad() 76 | one_hot.backward(retain_graph=True) 77 | 78 | self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs) 79 | 80 | cam = self.model.bert.encoder.layer[-1].attention.self.get_attn_cam()[0] 81 | cam = cam.clamp(min=0).mean(dim=0).unsqueeze(0) 82 | cam[:, 0, 0] = 0 83 | return cam[:, 0] 84 | 85 | def generate_full_lrp(self, input_ids, attention_mask, 86 | index=None): 87 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0] 88 | kwargs = {"alpha": 1} 89 | 90 | if index == None: 91 | index = np.argmax(output.cpu().data.numpy(), axis=-1) 92 | 93 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) 94 | one_hot[0, index] = 1 95 | one_hot_vector = one_hot 96 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 97 | one_hot = torch.sum(one_hot.cuda() * output) 98 | 99 | self.model.zero_grad() 100 | one_hot.backward(retain_graph=True) 101 | 102 | cam = self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs) 103 | cam = cam.sum(dim=2) 104 | cam[:, 0] = 0 105 | return cam 106 | 107 | def generate_attn_last_layer(self, input_ids, attention_mask, 108 | index=None): 109 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0] 110 | cam = self.model.bert.encoder.layer[-1].attention.self.get_attn()[0] 111 | cam = cam.mean(dim=0).unsqueeze(0) 112 | cam[:, 0, 0] = 0 113 | return cam[:, 0] 114 | 115 | def generate_rollout(self, input_ids, attention_mask, start_layer=0, index=None): 116 | self.model.zero_grad() 117 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0] 118 | blocks = self.model.bert.encoder.layer 119 | all_layer_attentions = [] 120 | for blk in blocks: 121 | attn_heads = blk.attention.self.get_attn() 122 | avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach() 123 | all_layer_attentions.append(avg_heads) 124 | rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer) 125 | rollout[:, 0, 0] = 0 126 | return rollout[:, 0] 127 | 128 | def generate_attn_gradcam(self, input_ids, attention_mask, index=None): 129 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0] 130 | kwargs = {"alpha": 1} 131 | 132 | if index == None: 133 | index = np.argmax(output.cpu().data.numpy(), axis=-1) 134 | 135 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) 136 | one_hot[0, index] = 1 137 | one_hot_vector = one_hot 138 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 139 | one_hot = 
torch.sum(one_hot.cuda() * output) 140 | 141 | self.model.zero_grad() 142 | one_hot.backward(retain_graph=True) 143 | 144 | self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs) 145 | 146 | cam = self.model.bert.encoder.layer[-1].attention.self.get_attn() 147 | grad = self.model.bert.encoder.layer[-1].attention.self.get_attn_gradients() 148 | 149 | cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1]) 150 | grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1]) 151 | grad = grad.mean(dim=[1, 2], keepdim=True) 152 | cam = (cam * grad).mean(0).clamp(min=0).unsqueeze(0) 153 | cam = (cam - cam.min()) / (cam.max() - cam.min()) 154 | cam[:, 0, 0] = 0 155 | return cam[:, 0] 156 | 157 | -------------------------------------------------------------------------------- /BERT_explainability/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/BERT_explainability/modules/__init__.py -------------------------------------------------------------------------------- /BERT_explainability/modules/layers_lrp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d', 6 | 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect', 7 | 'LayerNorm', 'AddEye', 'Tanh', 'MatMul', 'Mul'] 8 | 9 | 10 | def safe_divide(a, b): 11 | den = b.clamp(min=1e-9) + b.clamp(max=1e-9) 12 | den = den + den.eq(0).type(den.type()) * 1e-9 13 | return a / den * b.ne(0).type(b.type()) 14 | 15 | 16 | def forward_hook(self, input, output): 17 | if type(input[0]) in (list, tuple): 18 | self.X = [] 19 | for i in input[0]: 20 | x = i.detach() 21 | x.requires_grad = True 22 | self.X.append(x) 23 | else: 24 | self.X = input[0].detach() 25 | self.X.requires_grad = True 26 | 27 | self.Y = output 28 | 29 | 30 | def backward_hook(self, grad_input, grad_output): 31 | self.grad_input = grad_input 32 | self.grad_output = grad_output 33 | 34 | 35 | class RelProp(nn.Module): 36 | def __init__(self): 37 | super(RelProp, self).__init__() 38 | # if not self.training: 39 | self.register_forward_hook(forward_hook) 40 | 41 | def gradprop(self, Z, X, S): 42 | C = torch.autograd.grad(Z, X, S, retain_graph=True) 43 | return C 44 | 45 | def relprop(self, R, alpha): 46 | return R 47 | 48 | 49 | class RelPropSimple(RelProp): 50 | def relprop(self, R, alpha): 51 | Z = self.forward(self.X) 52 | S = safe_divide(R, Z) 53 | C = self.gradprop(Z, self.X, S) 54 | 55 | if torch.is_tensor(self.X) == False: 56 | outputs = [] 57 | outputs.append(self.X[0] * C[0]) 58 | outputs.append(self.X[1] * C[1]) 59 | else: 60 | outputs = self.X * (C[0]) 61 | return outputs 62 | 63 | class AddEye(RelPropSimple): 64 | # input of shape B, C, seq_len, seq_len 65 | def forward(self, input): 66 | return input + torch.eye(input.shape[2]).expand_as(input).to(input.device) 67 | 68 | class ReLU(nn.ReLU, RelProp): 69 | pass 70 | 71 | class Tanh(nn.Tanh, RelProp): 72 | pass 73 | 74 | class GELU(nn.GELU, RelProp): 75 | pass 76 | 77 | class Softmax(nn.Softmax, RelProp): 78 | pass 79 | 80 | class LayerNorm(nn.LayerNorm, RelProp): 81 | pass 82 | 83 | class Dropout(nn.Dropout, RelProp): 84 | pass 85 | 86 | 87 | class MaxPool2d(nn.MaxPool2d, 
RelPropSimple): 88 | pass 89 | 90 | class LayerNorm(nn.LayerNorm, RelProp): 91 | pass 92 | 93 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple): 94 | pass 95 | 96 | class MatMul(RelPropSimple): 97 | def forward(self, inputs): 98 | return torch.matmul(*inputs) 99 | 100 | class Mul(RelPropSimple): 101 | def forward(self, inputs): 102 | return torch.mul(*inputs) 103 | 104 | class AvgPool2d(nn.AvgPool2d, RelPropSimple): 105 | pass 106 | 107 | 108 | class Add(RelPropSimple): 109 | def forward(self, inputs): 110 | return torch.add(*inputs) 111 | 112 | class einsum(RelPropSimple): 113 | def __init__(self, equation): 114 | super().__init__() 115 | self.equation = equation 116 | def forward(self, *operands): 117 | return torch.einsum(self.equation, *operands) 118 | 119 | class IndexSelect(RelProp): 120 | def forward(self, inputs, dim, indices): 121 | self.__setattr__('dim', dim) 122 | self.__setattr__('indices', indices) 123 | 124 | return torch.index_select(inputs, dim, indices) 125 | 126 | def relprop(self, R, alpha): 127 | Z = self.forward(self.X, self.dim, self.indices) 128 | S = safe_divide(R, Z) 129 | C = self.gradprop(Z, self.X, S) 130 | 131 | if torch.is_tensor(self.X) == False: 132 | outputs = [] 133 | outputs.append(self.X[0] * C[0]) 134 | outputs.append(self.X[1] * C[1]) 135 | else: 136 | outputs = self.X * (C[0]) 137 | return outputs 138 | 139 | 140 | 141 | class Clone(RelProp): 142 | def forward(self, input, num): 143 | self.__setattr__('num', num) 144 | outputs = [] 145 | for _ in range(num): 146 | outputs.append(input) 147 | 148 | return outputs 149 | 150 | def relprop(self, R, alpha): 151 | Z = [] 152 | for _ in range(self.num): 153 | Z.append(self.X) 154 | S = [safe_divide(r, z) for r, z in zip(R, Z)] 155 | C = self.gradprop(Z, self.X, S)[0] 156 | 157 | R = self.X * C 158 | 159 | return R 160 | 161 | class Cat(RelProp): 162 | def forward(self, inputs, dim): 163 | self.__setattr__('dim', dim) 164 | return torch.cat(inputs, dim) 165 | 166 | def relprop(self, R, alpha): 167 | Z = self.forward(self.X, self.dim) 168 | S = safe_divide(R, Z) 169 | C = self.gradprop(Z, self.X, S) 170 | 171 | outputs = [] 172 | for x, c in zip(self.X, C): 173 | outputs.append(x * c) 174 | 175 | return outputs 176 | 177 | class Sequential(nn.Sequential): 178 | def relprop(self, R, alpha): 179 | for m in reversed(self._modules.values()): 180 | R = m.relprop(R, alpha) 181 | return R 182 | 183 | class BatchNorm2d(nn.BatchNorm2d, RelProp): 184 | def relprop(self, R, alpha): 185 | X = self.X 186 | beta = 1 - alpha 187 | weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / ( 188 | (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5)) 189 | Z = X * weight + 1e-9 190 | S = R / Z 191 | Ca = S * weight 192 | R = self.X * (Ca) 193 | return R 194 | 195 | 196 | class Linear(nn.Linear, RelProp): 197 | def relprop(self, R, alpha): 198 | beta = alpha - 1 199 | pw = torch.clamp(self.weight, min=0) 200 | nw = torch.clamp(self.weight, max=0) 201 | px = torch.clamp(self.X, min=0) 202 | nx = torch.clamp(self.X, max=0) 203 | 204 | def f(w1, w2, x1, x2): 205 | Z1 = F.linear(x1, w1) 206 | Z2 = F.linear(x2, w2) 207 | S1 = safe_divide(R, Z1) 208 | S2 = safe_divide(R, Z2) 209 | C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0] 210 | C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0] 211 | 212 | return C1 + C2 213 | 214 | activator_relevances = f(pw, nw, px, nx) 215 | inhibitor_relevances = f(nw, pw, px, nx) 216 | 217 | R = alpha * activator_relevances - beta * inhibitor_relevances 218 | 219 | 
return R 220 | 221 | class Conv2d(nn.Conv2d, RelProp): 222 | def gradprop2(self, DY, weight): 223 | Z = self.forward(self.X) 224 | 225 | output_padding = self.X.size()[2] - ( 226 | (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0]) 227 | 228 | return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding) 229 | 230 | def relprop(self, R, alpha): 231 | if self.X.shape[1] == 3: 232 | pw = torch.clamp(self.weight, min=0) 233 | nw = torch.clamp(self.weight, max=0) 234 | X = self.X 235 | L = self.X * 0 + \ 236 | torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 237 | keepdim=True)[0] 238 | H = self.X * 0 + \ 239 | torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 240 | keepdim=True)[0] 241 | Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \ 242 | torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \ 243 | torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9 244 | 245 | S = R / Za 246 | C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw) 247 | R = C 248 | else: 249 | beta = alpha - 1 250 | pw = torch.clamp(self.weight, min=0) 251 | nw = torch.clamp(self.weight, max=0) 252 | px = torch.clamp(self.X, min=0) 253 | nx = torch.clamp(self.X, max=0) 254 | 255 | def f(w1, w2, x1, x2): 256 | Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding) 257 | Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding) 258 | S1 = safe_divide(R, Z1) 259 | S2 = safe_divide(R, Z2) 260 | C1 = x1 * self.gradprop(Z1, x1, S1)[0] 261 | C2 = x2 * self.gradprop(Z2, x2, S2)[0] 262 | return C1 + C2 263 | 264 | activator_relevances = f(pw, nw, px, nx) 265 | inhibitor_relevances = f(nw, pw, px, nx) 266 | 267 | R = alpha * activator_relevances - beta * inhibitor_relevances 268 | return R -------------------------------------------------------------------------------- /BERT_explainability/modules/layers_ours.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d', 6 | 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect', 7 | 'LayerNorm', 'AddEye', 'Tanh', 'MatMul', 'Mul'] 8 | 9 | 10 | def safe_divide(a, b): 11 | den = b.clamp(min=1e-9) + b.clamp(max=1e-9) 12 | den = den + den.eq(0).type(den.type()) * 1e-9 13 | return a / den * b.ne(0).type(b.type()) 14 | 15 | 16 | def forward_hook(self, input, output): 17 | if type(input[0]) in (list, tuple): 18 | self.X = [] 19 | for i in input[0]: 20 | x = i.detach() 21 | x.requires_grad = True 22 | self.X.append(x) 23 | else: 24 | self.X = input[0].detach() 25 | self.X.requires_grad = True 26 | 27 | self.Y = output 28 | 29 | 30 | def backward_hook(self, grad_input, grad_output): 31 | self.grad_input = grad_input 32 | self.grad_output = grad_output 33 | 34 | 35 | class RelProp(nn.Module): 36 | def __init__(self): 37 | super(RelProp, self).__init__() 38 | # if not self.training: 39 | self.register_forward_hook(forward_hook) 40 | 41 | def gradprop(self, Z, X, S): 42 | C = torch.autograd.grad(Z, X, S, retain_graph=True) 43 | return C 44 | 45 | def relprop(self, R, alpha): 46 | return R 
47 | 48 | 49 | class RelPropSimple(RelProp): 50 | def relprop(self, R, alpha): 51 | Z = self.forward(self.X) 52 | S = safe_divide(R, Z) 53 | C = self.gradprop(Z, self.X, S) 54 | 55 | if torch.is_tensor(self.X) == False: 56 | outputs = [] 57 | outputs.append(self.X[0] * C[0]) 58 | outputs.append(self.X[1] * C[1]) 59 | else: 60 | outputs = self.X * (C[0]) 61 | return outputs 62 | 63 | class AddEye(RelPropSimple): 64 | # input of shape B, C, seq_len, seq_len 65 | def forward(self, input): 66 | return input + torch.eye(input.shape[2]).expand_as(input).to(input.device) 67 | 68 | class ReLU(nn.ReLU, RelProp): 69 | pass 70 | 71 | class GELU(nn.GELU, RelProp): 72 | pass 73 | 74 | class Softmax(nn.Softmax, RelProp): 75 | pass 76 | 77 | class Mul(RelPropSimple): 78 | def forward(self, inputs): 79 | return torch.mul(*inputs) 80 | 81 | class Tanh(nn.Tanh, RelProp): 82 | pass 83 | class LayerNorm(nn.LayerNorm, RelProp): 84 | pass 85 | 86 | class Dropout(nn.Dropout, RelProp): 87 | pass 88 | 89 | class MatMul(RelPropSimple): 90 | def forward(self, inputs): 91 | return torch.matmul(*inputs) 92 | 93 | class MaxPool2d(nn.MaxPool2d, RelPropSimple): 94 | pass 95 | 96 | class LayerNorm(nn.LayerNorm, RelProp): 97 | pass 98 | 99 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple): 100 | pass 101 | 102 | 103 | class AvgPool2d(nn.AvgPool2d, RelPropSimple): 104 | pass 105 | 106 | 107 | class Add(RelPropSimple): 108 | def forward(self, inputs): 109 | return torch.add(*inputs) 110 | 111 | def relprop(self, R, alpha): 112 | Z = self.forward(self.X) 113 | S = safe_divide(R, Z) 114 | C = self.gradprop(Z, self.X, S) 115 | 116 | a = self.X[0] * C[0] 117 | b = self.X[1] * C[1] 118 | 119 | a_sum = a.sum() 120 | b_sum = b.sum() 121 | 122 | a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() 123 | b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() 124 | 125 | a = a * safe_divide(a_fact, a.sum()) 126 | b = b * safe_divide(b_fact, b.sum()) 127 | 128 | outputs = [a, b] 129 | 130 | return outputs 131 | 132 | class einsum(RelPropSimple): 133 | def __init__(self, equation): 134 | super().__init__() 135 | self.equation = equation 136 | def forward(self, *operands): 137 | return torch.einsum(self.equation, *operands) 138 | 139 | class IndexSelect(RelProp): 140 | def forward(self, inputs, dim, indices): 141 | self.__setattr__('dim', dim) 142 | self.__setattr__('indices', indices) 143 | 144 | return torch.index_select(inputs, dim, indices) 145 | 146 | def relprop(self, R, alpha): 147 | Z = self.forward(self.X, self.dim, self.indices) 148 | S = safe_divide(R, Z) 149 | C = self.gradprop(Z, self.X, S) 150 | 151 | if torch.is_tensor(self.X) == False: 152 | outputs = [] 153 | outputs.append(self.X[0] * C[0]) 154 | outputs.append(self.X[1] * C[1]) 155 | else: 156 | outputs = self.X * (C[0]) 157 | return outputs 158 | 159 | 160 | 161 | class Clone(RelProp): 162 | def forward(self, input, num): 163 | self.__setattr__('num', num) 164 | outputs = [] 165 | for _ in range(num): 166 | outputs.append(input) 167 | 168 | return outputs 169 | 170 | def relprop(self, R, alpha): 171 | Z = [] 172 | for _ in range(self.num): 173 | Z.append(self.X) 174 | S = [safe_divide(r, z) for r, z in zip(R, Z)] 175 | C = self.gradprop(Z, self.X, S)[0] 176 | 177 | R = self.X * C 178 | 179 | return R 180 | 181 | 182 | class Cat(RelProp): 183 | def forward(self, inputs, dim): 184 | self.__setattr__('dim', dim) 185 | return torch.cat(inputs, dim) 186 | 187 | def relprop(self, R, alpha): 188 | Z = self.forward(self.X, 
self.dim) 189 | S = safe_divide(R, Z) 190 | C = self.gradprop(Z, self.X, S) 191 | 192 | outputs = [] 193 | for x, c in zip(self.X, C): 194 | outputs.append(x * c) 195 | 196 | return outputs 197 | 198 | 199 | class Sequential(nn.Sequential): 200 | def relprop(self, R, alpha): 201 | for m in reversed(self._modules.values()): 202 | R = m.relprop(R, alpha) 203 | return R 204 | 205 | 206 | class BatchNorm2d(nn.BatchNorm2d, RelProp): 207 | def relprop(self, R, alpha): 208 | X = self.X 209 | beta = 1 - alpha 210 | weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / ( 211 | (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5)) 212 | Z = X * weight + 1e-9 213 | S = R / Z 214 | Ca = S * weight 215 | R = self.X * (Ca) 216 | return R 217 | 218 | 219 | class Linear(nn.Linear, RelProp): 220 | def relprop(self, R, alpha): 221 | beta = alpha - 1 222 | pw = torch.clamp(self.weight, min=0) 223 | nw = torch.clamp(self.weight, max=0) 224 | px = torch.clamp(self.X, min=0) 225 | nx = torch.clamp(self.X, max=0) 226 | 227 | def f(w1, w2, x1, x2): 228 | Z1 = F.linear(x1, w1) 229 | Z2 = F.linear(x2, w2) 230 | S1 = safe_divide(R, Z1 + Z2) 231 | S2 = safe_divide(R, Z1 + Z2) 232 | C1 = x1 * self.gradprop(Z1, x1, S1)[0] 233 | C2 = x2 * self.gradprop(Z2, x2, S2)[0] 234 | 235 | return C1 + C2 236 | 237 | activator_relevances = f(pw, nw, px, nx) 238 | inhibitor_relevances = f(nw, pw, px, nx) 239 | 240 | R = alpha * activator_relevances - beta * inhibitor_relevances 241 | 242 | return R 243 | 244 | 245 | class Conv2d(nn.Conv2d, RelProp): 246 | def gradprop2(self, DY, weight): 247 | Z = self.forward(self.X) 248 | 249 | output_padding = self.X.size()[2] - ( 250 | (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0]) 251 | 252 | return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding) 253 | 254 | def relprop(self, R, alpha): 255 | if self.X.shape[1] == 3: 256 | pw = torch.clamp(self.weight, min=0) 257 | nw = torch.clamp(self.weight, max=0) 258 | X = self.X 259 | L = self.X * 0 + \ 260 | torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 261 | keepdim=True)[0] 262 | H = self.X * 0 + \ 263 | torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 264 | keepdim=True)[0] 265 | Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \ 266 | torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \ 267 | torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9 268 | 269 | S = R / Za 270 | C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw) 271 | R = C 272 | else: 273 | beta = alpha - 1 274 | pw = torch.clamp(self.weight, min=0) 275 | nw = torch.clamp(self.weight, max=0) 276 | px = torch.clamp(self.X, min=0) 277 | nx = torch.clamp(self.X, max=0) 278 | 279 | def f(w1, w2, x1, x2): 280 | Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding) 281 | Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding) 282 | S1 = safe_divide(R, Z1) 283 | S2 = safe_divide(R, Z2) 284 | C1 = x1 * self.gradprop(Z1, x1, S1)[0] 285 | C2 = x2 * self.gradprop(Z2, x2, S2)[0] 286 | return C1 + C2 287 | 288 | activator_relevances = f(pw, nw, px, nx) 289 | inhibitor_relevances = f(nw, pw, px, nx) 290 | 291 | R = alpha * activator_relevances - beta * inhibitor_relevances 292 | return R 
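# --- Minimal usage sketch (not part of the repository) ----------------------
# The relprop chain defined in layers_ours.py above composes as follows: the
# forward hook caches each layer's input in self.X, Sequential.relprop walks its
# children in reverse, and Linear.relprop applies the alpha-beta LRP rule to the
# cached inputs. The toy network, its dimensions, and the one-hot relevance
# below are invented for illustration; only the import path and the
# relprop(R, alpha) signature come from the code above.
import torch
from BERT_explainability.modules.layers_ours import Sequential, Linear, ReLU

torch.manual_seed(0)
toy = Sequential(Linear(4, 8), ReLU(), Linear(8, 2))   # hypothetical toy network

x = torch.randn(1, 4)
logits = toy(x)                        # forward hooks store self.X on each layer

R = torch.zeros_like(logits)           # one-hot relevance at the predicted class,
R[0, logits.argmax(dim=-1)] = 1.0      # mirroring ExplanationGenerator.generate_LRP
R = toy.relprop(R, alpha=1)            # alpha=1 (beta=0) LRP back to the 4 input features
print(R.shape)                         # expected: torch.Size([1, 4])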
-------------------------------------------------------------------------------- /BERT_params/boolq.json: -------------------------------------------------------------------------------- 1 | { 2 | "embeddings": { 3 | "embedding_file": "model_components/glove.6B.200d.txt", 4 | "dropout": 0.05 5 | }, 6 | "evidence_identifier": { 7 | "mlp_size": 128, 8 | "dropout": 0.2, 9 | "batch_size": 768, 10 | "epochs": 50, 11 | "patience": 10, 12 | "lr": 1e-3, 13 | "sampling_method": "random", 14 | "sampling_ratio": 1.0 15 | }, 16 | "evidence_classifier": { 17 | "classes": [ "False", "True" ], 18 | "mlp_size": 128, 19 | "dropout": 0.2, 20 | "batch_size": 768, 21 | "epochs": 50, 22 | "patience": 10, 23 | "lr": 1e-3, 24 | "sampling_method": "everything" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /BERT_params/boolq_baas.json: -------------------------------------------------------------------------------- 1 | { 2 | "start_server": 0, 3 | "bert_dir": "model_components/uncased_L-12_H-768_A-12/", 4 | "max_length": 512, 5 | "pooling_strategy": "CLS_TOKEN", 6 | "evidence_identifier": { 7 | "batch_size": 64, 8 | "epochs": 3, 9 | "patience": 10, 10 | "lr": 1e-3, 11 | "max_grad_norm": 1.0, 12 | "sampling_method": "random", 13 | "sampling_ratio": 1.0 14 | }, 15 | "evidence_classifier": { 16 | "classes": [ "False", "True" ], 17 | "batch_size": 64, 18 | "epochs": 3, 19 | "patience": 10, 20 | "lr": 1e-3, 21 | "max_grad_norm": 1.0, 22 | "sampling_method": "everything" 23 | } 24 | } 25 | 26 | 27 | -------------------------------------------------------------------------------- /BERT_params/boolq_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "bert_vocab": "bert-base-uncased", 4 | "bert_dir": "bert-base-uncased", 5 | "use_evidence_sentence_identifier": 1, 6 | "use_evidence_token_identifier": 0, 7 | "evidence_identifier": { 8 | "batch_size": 10, 9 | "epochs": 10, 10 | "patience": 10, 11 | "warmup_steps": 50, 12 | "lr": 1e-05, 13 | "max_grad_norm": 1, 14 | "sampling_method": "random", 15 | "sampling_ratio": 1, 16 | "use_half_precision": 0 17 | }, 18 | "evidence_classifier": { 19 | "classes": [ 20 | "False", 21 | "True" 22 | ], 23 | "batch_size": 10, 24 | "warmup_steps": 50, 25 | "epochs": 10, 26 | "patience": 10, 27 | "lr": 1e-05, 28 | "max_grad_norm": 1, 29 | "sampling_method": "everything", 30 | "use_half_precision": 0 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /BERT_params/boolq_soft.json: -------------------------------------------------------------------------------- 1 | { 2 | "embeddings": { 3 | "embedding_file": "model_components/glove.6B.200d.txt", 4 | "dropout": 0.2 5 | }, 6 | "classifier": { 7 | "classes": [ "False", "True" ], 8 | "has_query": 1, 9 | "hidden_size": 32, 10 | "mlp_size": 128, 11 | "dropout": 0.2, 12 | "batch_size": 16, 13 | "epochs": 50, 14 | "attention_epochs": 50, 15 | "patience": 10, 16 | "lr": 1e-3, 17 | "dropout": 0.2, 18 | "k_fraction": 0.07, 19 | "threshold": 0.1 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /BERT_params/cose_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "bert_vocab": "bert-base-uncased", 4 | "bert_dir": "bert-base-uncased", 5 | "use_evidence_sentence_identifier": 0, 6 | "use_evidence_token_identifier": 1, 7 | "evidence_token_identifier": { 8 | 
"batch_size": 32, 9 | "epochs": 10, 10 | "patience": 10, 11 | "warmup_steps": 10, 12 | "lr": 1e-05, 13 | "max_grad_norm": 0.5, 14 | "sampling_method": "everything", 15 | "use_half_precision": 0, 16 | "cose_data_hack": 1 17 | }, 18 | "evidence_classifier": { 19 | "classes": [ "false", "true"], 20 | "batch_size": 32, 21 | "warmup_steps": 10, 22 | "epochs": 10, 23 | "patience": 10, 24 | "lr": 1e-05, 25 | "max_grad_norm": 0.5, 26 | "sampling_method": "everything", 27 | "use_half_precision": 0, 28 | "cose_data_hack": 1 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /BERT_params/cose_multiclass.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "bert_vocab": "bert-base-uncased", 4 | "bert_dir": "bert-base-uncased", 5 | "use_evidence_sentence_identifier": 1, 6 | "use_evidence_token_identifier": 0, 7 | "evidence_identifier": { 8 | "batch_size": 32, 9 | "epochs": 10, 10 | "patience": 10, 11 | "warmup_steps": 50, 12 | "lr": 1e-05, 13 | "max_grad_norm": 1, 14 | "sampling_method": "random", 15 | "sampling_ratio": 1, 16 | "use_half_precision": 0 17 | }, 18 | "evidence_classifier": { 19 | "classes": [ 20 | "A", 21 | "B", 22 | "C", 23 | "D", 24 | "E" 25 | ], 26 | "batch_size": 10, 27 | "warmup_steps": 50, 28 | "epochs": 10, 29 | "patience": 10, 30 | "lr": 1e-05, 31 | "max_grad_norm": 1, 32 | "sampling_method": "everything", 33 | "use_half_precision": 0 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /BERT_params/esnli_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "bert_vocab": "bert-base-uncased", 4 | "bert_dir": "bert-base-uncased", 5 | "use_evidence_sentence_identifier": 0, 6 | "use_evidence_token_identifier": 1, 7 | "evidence_token_identifier": { 8 | "batch_size": 32, 9 | "epochs": 10, 10 | "patience": 10, 11 | "warmup_steps": 10, 12 | "lr": 1e-05, 13 | "max_grad_norm": 1, 14 | "sampling_method": "everything", 15 | "use_half_precision": 0 16 | }, 17 | "evidence_classifier": { 18 | "classes": [ "contradiction", "neutral", "entailment" ], 19 | "batch_size": 32, 20 | "warmup_steps": 10, 21 | "epochs": 10, 22 | "patience": 10, 23 | "lr": 1e-05, 24 | "max_grad_norm": 1, 25 | "sampling_method": "everything", 26 | "use_half_precision": 0 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /BERT_params/evidence_inference.json: -------------------------------------------------------------------------------- 1 | { 2 | "embeddings": { 3 | "embedding_file": "model_components/PubMed-w2v.bin", 4 | "dropout": 0.05 5 | }, 6 | "evidence_identifier": { 7 | "mlp_size": 128, 8 | "dropout": 0.05, 9 | "batch_size": 768, 10 | "epochs": 50, 11 | "patience": 10, 12 | "lr": 1e-3, 13 | "sampling_method": "random", 14 | "sampling_ratio": 1.0 15 | }, 16 | "evidence_classifier": { 17 | "classes": [ "significantly decreased", "no significant difference", "significantly increased" ], 18 | "mlp_size": 128, 19 | "dropout": 0.05, 20 | "batch_size": 768, 21 | "epochs": 50, 22 | "patience": 10, 23 | "lr": 1e-3, 24 | "sampling_method": "everything" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /BERT_params/evidence_inference_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "bert_vocab": 
"allenai/scibert_scivocab_uncased", 4 | "bert_dir": "allenai/scibert_scivocab_uncased", 5 | "use_evidence_sentence_identifier": 1, 6 | "use_evidence_token_identifier": 0, 7 | "evidence_identifier": { 8 | "batch_size": 10, 9 | "epochs": 10, 10 | "patience": 10, 11 | "warmup_steps": 10, 12 | "lr": 1e-05, 13 | "max_grad_norm": 1, 14 | "sampling_method": "random", 15 | "use_half_precision": 0, 16 | "sampling_ratio": 1 17 | }, 18 | "evidence_classifier": { 19 | "classes": [ 20 | "significantly decreased", 21 | "no significant difference", 22 | "significantly increased" 23 | ], 24 | "batch_size": 10, 25 | "warmup_steps": 10, 26 | "epochs": 10, 27 | "patience": 10, 28 | "lr": 1e-05, 29 | "max_grad_norm": 1, 30 | "sampling_method": "everything", 31 | "use_half_precision": 0 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /BERT_params/evidence_inference_soft.json: -------------------------------------------------------------------------------- 1 | { 2 | "embeddings": { 3 | "embedding_file": "model_components/PubMed-w2v.bin", 4 | "dropout": 0.2 5 | }, 6 | "classifier": { 7 | "classes": [ "significantly decreased", "no significant difference", "significantly increased" ], 8 | "use_token_selection": 1, 9 | "has_query": 1, 10 | "hidden_size": 32, 11 | "mlp_size": 128, 12 | "dropout": 0.2, 13 | "batch_size": 16, 14 | "epochs": 50, 15 | "attention_epochs": 0, 16 | "patience": 10, 17 | "lr": 1e-3, 18 | "dropout": 0.2, 19 | "k_fraction": 0.013, 20 | "threshold": 0.1 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /BERT_params/fever.json: -------------------------------------------------------------------------------- 1 | { 2 | "embeddings": { 3 | "embedding_file": "model_components/glove.6B.200d.txt", 4 | "dropout": 0.05 5 | }, 6 | "evidence_identifier": { 7 | "mlp_size": 128, 8 | "dropout": 0.05, 9 | "batch_size": 768, 10 | "epochs": 50, 11 | "patience": 10, 12 | "lr": 1e-3, 13 | "sampling_method": "random", 14 | "sampling_ratio": 1.0 15 | }, 16 | "evidence_classifier": { 17 | "classes": [ "SUPPORTS", "REFUTES" ], 18 | "mlp_size": 128, 19 | "dropout": 0.05, 20 | "batch_size": 768, 21 | "epochs": 50, 22 | "patience": 10, 23 | "lr": 1e-5, 24 | "sampling_method": "everything" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /BERT_params/fever_baas.json: -------------------------------------------------------------------------------- 1 | { 2 | "start_server": 0, 3 | "bert_dir": "model_components/uncased_L-12_H-768_A-12/", 4 | "max_length": 512, 5 | "pooling_strategy": "CLS_TOKEN", 6 | "evidence_identifier": { 7 | "batch_size": 64, 8 | "epochs": 3, 9 | "patience": 10, 10 | "lr": 1e-3, 11 | "max_grad_norm": 1.0, 12 | "sampling_method": "random", 13 | "sampling_ratio": 1.0 14 | }, 15 | "evidence_classifier": { 16 | "classes": [ "SUPPORTS", "REFUTES" ], 17 | "batch_size": 64, 18 | "epochs": 3, 19 | "patience": 10, 20 | "lr": 1e-3, 21 | "max_grad_norm": 1.0, 22 | "sampling_method": "everything" 23 | } 24 | } 25 | 26 | -------------------------------------------------------------------------------- /BERT_params/fever_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "bert_vocab": "bert-base-uncased", 4 | "bert_dir": "bert-base-uncased", 5 | "use_evidence_sentence_identifier": 1, 6 | "use_evidence_token_identifier": 0, 7 | "evidence_identifier": { 8 | "batch_size": 16, 9 | "epochs": 10, 10 
| "patience": 10, 11 | "warmup_steps": 10, 12 | "lr": 1e-05, 13 | "max_grad_norm": 1.0, 14 | "sampling_method": "random", 15 | "sampling_ratio": 1.0, 16 | "use_half_precision": 0 17 | }, 18 | "evidence_classifier": { 19 | "classes": [ 20 | "SUPPORTS", 21 | "REFUTES" 22 | ], 23 | "batch_size": 10, 24 | "warmup_steps": 10, 25 | "epochs": 10, 26 | "patience": 10, 27 | "lr": 1e-05, 28 | "max_grad_norm": 1.0, 29 | "sampling_method": "everything", 30 | "use_half_precision": 0 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /BERT_params/fever_soft.json: -------------------------------------------------------------------------------- 1 | { 2 | "embeddings": { 3 | "embedding_file": "model_components/glove.6B.200d.txt", 4 | "dropout": 0.2 5 | }, 6 | "classifier": { 7 | "classes": [ "SUPPORTS", "REFUTES" ], 8 | "has_query": 1, 9 | "hidden_size": 32, 10 | "mlp_size": 128, 11 | "dropout": 0.2, 12 | "batch_size": 128, 13 | "epochs": 50, 14 | "attention_epochs": 50, 15 | "patience": 10, 16 | "lr": 1e-3, 17 | "dropout": 0.2, 18 | "k_fraction": 0.07, 19 | "threshold": 0.1 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /BERT_params/movies.json: -------------------------------------------------------------------------------- 1 | { 2 | "embeddings": { 3 | "embedding_file": "model_components/glove.6B.200d.txt", 4 | "dropout": 0.05 5 | }, 6 | "evidence_identifier": { 7 | "mlp_size": 128, 8 | "dropout": 0.05, 9 | "batch_size": 768, 10 | "epochs": 50, 11 | "patience": 10, 12 | "lr": 1e-4, 13 | "sampling_method": "random", 14 | "sampling_ratio": 1.0 15 | }, 16 | "evidence_classifier": { 17 | "classes": [ "NEG", "POS" ], 18 | "mlp_size": 128, 19 | "dropout": 0.05, 20 | "batch_size": 768, 21 | "epochs": 50, 22 | "patience": 10, 23 | "lr": 1e-3, 24 | "sampling_method": "everything" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /BERT_params/movies_baas.json: -------------------------------------------------------------------------------- 1 | { 2 | "start_server": 0, 3 | "bert_dir": "model_components/uncased_L-12_H-768_A-12/", 4 | "max_length": 512, 5 | "pooling_strategy": "CLS_TOKEN", 6 | "evidence_identifier": { 7 | "batch_size": 64, 8 | "epochs": 3, 9 | "patience": 10, 10 | "lr": 1e-3, 11 | "max_grad_norm": 1.0, 12 | "sampling_method": "random", 13 | "sampling_ratio": 1.0 14 | }, 15 | "evidence_classifier": { 16 | "classes": [ "NEG", "POS" ], 17 | "batch_size": 64, 18 | "epochs": 3, 19 | "patience": 10, 20 | "lr": 1e-3, 21 | "max_grad_norm": 1.0, 22 | "sampling_method": "everything" 23 | } 24 | } 25 | 26 | 27 | -------------------------------------------------------------------------------- /BERT_params/movies_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "bert_vocab": "bert-base-uncased", 4 | "bert_dir": "bert-base-uncased", 5 | "use_evidence_sentence_identifier": 1, 6 | "use_evidence_token_identifier": 0, 7 | "evidence_identifier": { 8 | "batch_size": 16, 9 | "epochs": 10, 10 | "patience": 10, 11 | "warmup_steps": 50, 12 | "lr": 1e-05, 13 | "max_grad_norm": 1, 14 | "sampling_method": "random", 15 | "sampling_ratio": 1, 16 | "use_half_precision": 0 17 | }, 18 | "evidence_classifier": { 19 | "classes": [ 20 | "NEG", 21 | "POS" 22 | ], 23 | "batch_size": 10, 24 | "warmup_steps": 50, 25 | "epochs": 10, 26 | "patience": 10, 27 | "lr": 1e-05, 28 | "max_grad_norm": 1, 29 | 
"sampling_method": "everything", 30 | "use_half_precision": 0 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /BERT_params/movies_soft.json: -------------------------------------------------------------------------------- 1 | { 2 | "embeddings": { 3 | "embedding_file": "model_components/glove.6B.200d.txt", 4 | "dropout": 0.2 5 | }, 6 | "classifier": { 7 | "classes": [ "NEG", "POS" ], 8 | "has_query": 0, 9 | "hidden_size": 32, 10 | "mlp_size": 128, 11 | "dropout": 0.2, 12 | "batch_size": 16, 13 | "epochs": 50, 14 | "attention_epochs": 50, 15 | "patience": 10, 16 | "lr": 1e-3, 17 | "dropout": 0.2, 18 | "k_fraction": 0.07, 19 | "threshold": 0.1 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /BERT_params/multirc.json: -------------------------------------------------------------------------------- 1 | { 2 | "embeddings": { 3 | "embedding_file": "model_components/glove.6B.200d.txt", 4 | "dropout": 0.05 5 | }, 6 | "evidence_identifier": { 7 | "mlp_size": 128, 8 | "dropout": 0.05, 9 | "batch_size": 768, 10 | "epochs": 50, 11 | "patience": 10, 12 | "lr": 1e-3, 13 | "sampling_method": "random", 14 | "sampling_ratio": 1.0 15 | }, 16 | "evidence_classifier": { 17 | "classes": [ "False", "True" ], 18 | "mlp_size": 128, 19 | "dropout": 0.05, 20 | "batch_size": 768, 21 | "epochs": 50, 22 | "patience": 10, 23 | "lr": 1e-3, 24 | "sampling_method": "everything" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /BERT_params/multirc_baas.json: -------------------------------------------------------------------------------- 1 | { 2 | "start_server": 0, 3 | "bert_dir": "model_components/uncased_L-12_H-768_A-12/", 4 | "max_length": 512, 5 | "pooling_strategy": "CLS_TOKEN", 6 | "evidence_identifier": { 7 | "batch_size": 64, 8 | "epochs": 3, 9 | "patience": 10, 10 | "lr": 1e-3, 11 | "max_grad_norm": 1.0, 12 | "sampling_method": "random", 13 | "sampling_ratio": 1.0 14 | }, 15 | "evidence_classifier": { 16 | "classes": [ "False", "True" ], 17 | "batch_size": 64, 18 | "epochs": 3, 19 | "patience": 10, 20 | "lr": 1e-3, 21 | "max_grad_norm": 1.0, 22 | "sampling_method": "everything" 23 | } 24 | } 25 | 26 | 27 | -------------------------------------------------------------------------------- /BERT_params/multirc_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "bert_vocab": "bert-base-uncased", 4 | "bert_dir": "bert-base-uncased", 5 | "use_evidence_sentence_identifier": 1, 6 | "use_evidence_token_identifier": 0, 7 | "evidence_identifier": { 8 | "batch_size": 32, 9 | "epochs": 10, 10 | "patience": 10, 11 | "warmup_steps": 50, 12 | "lr": 1e-05, 13 | "max_grad_norm": 1, 14 | "sampling_method": "random", 15 | "sampling_ratio": 1, 16 | "use_half_precision": 0 17 | }, 18 | "evidence_classifier": { 19 | "classes": [ 20 | "False", 21 | "True" 22 | ], 23 | "batch_size": 32, 24 | "warmup_steps": 50, 25 | "epochs": 10, 26 | "patience": 10, 27 | "lr": 1e-05, 28 | "max_grad_norm": 1, 29 | "sampling_method": "everything", 30 | "use_half_precision": 0 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /BERT_params/multirc_soft.json: -------------------------------------------------------------------------------- 1 | { 2 | "embeddings": { 3 | "embedding_file": "model_components/glove.6B.200d.txt", 4 | "dropout": 0.2 5 | }, 6 | "classifier": { 7 | "classes": [ 
"False", "True" ], 8 | "has_query": 1, 9 | "hidden_size": 32, 10 | "mlp_size": 128, 11 | "dropout": 0.2, 12 | "batch_size": 16, 13 | "epochs": 50, 14 | "attention_epochs": 50, 15 | "patience": 10, 16 | "lr": 1e-3, 17 | "dropout": 0.2, 18 | "k_fraction": 0.07, 19 | "threshold": 0.1 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /BERT_rationale_benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/BERT_rationale_benchmark/__init__.py -------------------------------------------------------------------------------- /BERT_rationale_benchmark/models/model_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List, Set 3 | 4 | import numpy as np 5 | from gensim.models import KeyedVectors 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn.utils.rnn import pad_sequence, PackedSequence, pack_padded_sequence, pad_packed_sequence 10 | 11 | 12 | @dataclass(eq=True, frozen=True) 13 | class PaddedSequence: 14 | """A utility class for padding variable length sequences mean for RNN input 15 | This class is in the style of PackedSequence from the PyTorch RNN Utils, 16 | but is somewhat more manual in approach. It provides the ability to generate masks 17 | for outputs of the same input dimensions. 18 | The constructor should never be called directly and should only be called via 19 | the autopad classmethod. 20 | 21 | We'd love to delete this, but we pad_sequence, pack_padded_sequence, and 22 | pad_packed_sequence all require shuffling around tuples of information, and some 23 | convenience methods using these are nice to have. 
24 | """ 25 | 26 | data: torch.Tensor 27 | batch_sizes: torch.Tensor 28 | batch_first: bool = False 29 | 30 | @classmethod 31 | def autopad(cls, data, batch_first: bool = False, padding_value=0, device=None) -> 'PaddedSequence': 32 | # handle tensors of size 0 (single item) 33 | data_ = [] 34 | for d in data: 35 | if len(d.size()) == 0: 36 | d = d.unsqueeze(0) 37 | data_.append(d) 38 | padded = pad_sequence(data_, batch_first=batch_first, padding_value=padding_value) 39 | if batch_first: 40 | batch_lengths = torch.LongTensor([len(x) for x in data_]) 41 | if any([x == 0 for x in batch_lengths]): 42 | raise ValueError( 43 | "Found a 0 length batch element, this can't possibly be right: {}".format(batch_lengths)) 44 | else: 45 | # TODO actually test this codepath 46 | batch_lengths = torch.LongTensor([len(x) for x in data]) 47 | return PaddedSequence(padded, batch_lengths, batch_first).to(device=device) 48 | 49 | def pack_other(self, data: torch.Tensor): 50 | return pack_padded_sequence(data, self.batch_sizes, batch_first=self.batch_first, enforce_sorted=False) 51 | 52 | @classmethod 53 | def from_packed_sequence(cls, ps: PackedSequence, batch_first: bool, padding_value=0) -> 'PaddedSequence': 54 | padded, batch_sizes = pad_packed_sequence(ps, batch_first, padding_value) 55 | return PaddedSequence(padded, batch_sizes, batch_first) 56 | 57 | def cuda(self) -> 'PaddedSequence': 58 | return PaddedSequence(self.data.cuda(), self.batch_sizes.cuda(), batch_first=self.batch_first) 59 | 60 | def to(self, dtype=None, device=None, copy=False, non_blocking=False) -> 'PaddedSequence': 61 | # TODO make to() support all of the torch.Tensor to() variants 62 | return PaddedSequence( 63 | self.data.to(dtype=dtype, device=device, copy=copy, non_blocking=non_blocking), 64 | self.batch_sizes.to(device=device, copy=copy, non_blocking=non_blocking), 65 | batch_first=self.batch_first) 66 | 67 | def mask(self, on=int(0), off=int(0), device='cpu', size=None, dtype=None) -> torch.Tensor: 68 | if size is None: 69 | size = self.data.size() 70 | out_tensor = torch.zeros(*size, dtype=dtype) 71 | # TODO this can be done more efficiently 72 | out_tensor.fill_(off) 73 | # note to self: these are probably less efficient than explicilty populating the off values instead of the on values. 
74 | if self.batch_first: 75 | for i, bl in enumerate(self.batch_sizes): 76 | out_tensor[i, :bl] = on 77 | else: 78 | for i, bl in enumerate(self.batch_sizes): 79 | out_tensor[:bl, i] = on 80 | return out_tensor.to(device) 81 | 82 | def unpad(self, other: torch.Tensor) -> List[torch.Tensor]: 83 | out = [] 84 | for o, bl in zip(other, self.batch_sizes): 85 | out.append(o[:bl]) 86 | return out 87 | 88 | def flip(self) -> 'PaddedSequence': 89 | return PaddedSequence(self.data.transpose(0, 1), not self.batch_first, self.padding_value) 90 | 91 | 92 | def extract_embeddings(vocab: Set[str], embedding_file: str, unk_token: str = 'UNK', pad_token: str = 'PAD') -> ( 93 | nn.Embedding, Dict[str, int], List[str]): 94 | vocab = vocab | set([unk_token, pad_token]) 95 | if embedding_file.endswith('.bin'): 96 | WVs = KeyedVectors.load_word2vec_format(embedding_file, binary=True) 97 | 98 | word_to_vector = dict() 99 | WV_matrix = np.matrix([WVs[v] for v in WVs.vocab.keys()]) 100 | 101 | if unk_token not in WVs: 102 | mean_vector = np.mean(WV_matrix, axis=0) 103 | word_to_vector[unk_token] = mean_vector 104 | if pad_token not in WVs: 105 | word_to_vector[pad_token] = np.zeros(WVs.vector_size) 106 | 107 | for v in vocab: 108 | if v in WVs: 109 | word_to_vector[v] = WVs[v] 110 | 111 | interner = dict() 112 | deinterner = list() 113 | vectors = [] 114 | count = 0 115 | for word in [pad_token, unk_token] + sorted(list(word_to_vector.keys() - {unk_token, pad_token})): 116 | vector = word_to_vector[word] 117 | vectors.append(np.array(vector)) 118 | interner[word] = count 119 | deinterner.append(word) 120 | count += 1 121 | vectors = torch.FloatTensor(np.array(vectors)) 122 | embedding = nn.Embedding.from_pretrained(vectors, padding_idx=interner[pad_token]) 123 | embedding.weight.requires_grad = False 124 | return embedding, interner, deinterner 125 | elif embedding_file.endswith('.txt'): 126 | word_to_vector = dict() 127 | vector = [] 128 | with open(embedding_file, 'r') as inf: 129 | for line in inf: 130 | contents = line.strip().split() 131 | word = contents[0] 132 | vector = torch.tensor([float(v) for v in contents[1:]]).unsqueeze(0) 133 | word_to_vector[word] = vector 134 | embed_size = vector.size() 135 | if unk_token not in word_to_vector: 136 | mean_vector = torch.cat(list(word_to_vector.values()), dim=0).mean(dim=0) 137 | word_to_vector[unk_token] = mean_vector.unsqueeze(0) 138 | if pad_token not in word_to_vector: 139 | word_to_vector[pad_token] = torch.zeros(embed_size) 140 | interner = dict() 141 | deinterner = list() 142 | vectors = [] 143 | count = 0 144 | for word in [pad_token, unk_token] + sorted(list(word_to_vector.keys() - {unk_token, pad_token})): 145 | vector = word_to_vector[word] 146 | vectors.append(vector) 147 | interner[word] = count 148 | deinterner.append(word) 149 | count += 1 150 | vectors = torch.cat(vectors, dim=0) 151 | embedding = nn.Embedding.from_pretrained(vectors, padding_idx=interner[pad_token]) 152 | embedding.weight.requires_grad = False 153 | return embedding, interner, deinterner 154 | else: 155 | raise ValueError("Unable to open embeddings file {}".format(embedding_file)) 156 | -------------------------------------------------------------------------------- /BERT_rationale_benchmark/models/pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/BERT_rationale_benchmark/models/pipeline/__init__.py 
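
A minimal usage sketch of the `PaddedSequence` helper above (not a file in the repository; the token ids are made up, and it assumes the repository root is on `PYTHONPATH`):

```
import torch
from BERT_rationale_benchmark.models.model_utils import PaddedSequence

# three variable-length sequences of token ids (illustrative values)
docs = [torch.tensor([101, 7592, 102]),
        torch.tensor([101, 102]),
        torch.tensor([101, 2023, 2003, 102])]

# pad into a single (batch, max_len) tensor; 0 is the pad id here
batch = PaddedSequence.autopad(docs, batch_first=True, padding_value=0, device='cpu')

print(batch.data.shape)    # torch.Size([3, 4])
print(batch.batch_sizes)   # tensor([3, 2, 4]) -- the original lengths

mask = batch.mask(on=1, off=0, dtype=torch.float)  # 1 over real tokens, 0 over padding
originals = batch.unpad(batch.data)                # recovers the unpadded sequences
```

The same object is what `BertTagger` and `BertForSequenceClassification` feed to BERT, with `mask(on=0.0, off=float('-inf'))` used as an additive attention mask.
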
-------------------------------------------------------------------------------- /BERT_rationale_benchmark/models/pipeline/pipeline_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import random 5 | import os 6 | 7 | from itertools import chain 8 | from typing import Set 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from rationale_benchmark.utils import ( 14 | write_jsonl, 15 | load_datasets, 16 | load_documents, 17 | intern_documents, 18 | intern_annotations 19 | ) 20 | from rationale_benchmark.models.mlp import ( 21 | AttentiveClassifier, 22 | BahadanauAttention, 23 | RNNEncoder, 24 | WordEmbedder 25 | ) 26 | from rationale_benchmark.models.model_utils import extract_embeddings 27 | from rationale_benchmark.models.pipeline.evidence_identifier import train_evidence_identifier 28 | from rationale_benchmark.models.pipeline.evidence_classifier import train_evidence_classifier 29 | from rationale_benchmark.models.pipeline.pipeline_utils import decode 30 | 31 | logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s') 32 | # let's make this more or less deterministic (not resistant to restarts) 33 | random.seed(12345) 34 | np.random.seed(67890) 35 | torch.manual_seed(10111213) 36 | torch.backends.cudnn.deterministic = True 37 | torch.backends.cudnn.benchmark = False 38 | 39 | 40 | def initialize_models(params: dict, vocab: Set[str], batch_first: bool, unk_token='UNK'): 41 | # TODO this is obviously asking for some sort of dependency injection. implement if it saves me time. 42 | if 'embedding_file' in params['embeddings']: 43 | embeddings, word_interner, de_interner = extract_embeddings(vocab, params['embeddings']['embedding_file'], unk_token=unk_token) 44 | if torch.cuda.is_available(): 45 | embeddings = embeddings.cuda() 46 | else: 47 | raise ValueError("No 'embedding_file' found in params!") 48 | word_embedder = WordEmbedder(embeddings, params['embeddings']['dropout']) 49 | query_encoder = RNNEncoder(word_embedder, 50 | batch_first=batch_first, 51 | condition=False, 52 | attention_mechanism=BahadanauAttention(word_embedder.output_dimension)) 53 | document_encoder = RNNEncoder(word_embedder, 54 | batch_first=batch_first, 55 | condition=True, 56 | attention_mechanism=BahadanauAttention(word_embedder.output_dimension, 57 | query_size=query_encoder.output_dimension)) 58 | evidence_identifier = AttentiveClassifier(document_encoder, 59 | query_encoder, 60 | 2, 61 | params['evidence_identifier']['mlp_size'], 62 | params['evidence_identifier']['dropout']) 63 | query_encoder = RNNEncoder(word_embedder, 64 | batch_first=batch_first, 65 | condition=False, 66 | attention_mechanism=BahadanauAttention(word_embedder.output_dimension)) 67 | document_encoder = RNNEncoder(word_embedder, 68 | batch_first=batch_first, 69 | condition=True, 70 | attention_mechanism=BahadanauAttention(word_embedder.output_dimension, 71 | query_size=query_encoder.output_dimension)) 72 | evidence_classes = dict((y,x) for (x,y) in enumerate(params['evidence_classifier']['classes'])) 73 | evidence_classifier = AttentiveClassifier(document_encoder, 74 | query_encoder, 75 | len(evidence_classes), 76 | params['evidence_classifier']['mlp_size'], 77 | params['evidence_classifier']['dropout']) 78 | return evidence_identifier, evidence_classifier, word_interner, de_interner, evidence_classes 79 | 80 | 81 | def main(): 82 | parser = argparse.ArgumentParser(description="""Trains a pipeline 
model. 83 | 84 | Step 1 is evidence identification, that is identify if a given sentence is evidence or not 85 | Step 2 is evidence classification, that is given an evidence sentence, classify the final outcome for the final task (e.g. sentiment or significance). 86 | 87 | These models should be separated into two separate steps, but at the moment: 88 | * prep data (load, intern documents, load json) 89 | * convert data for evidence identification - in the case of training data we take all the positives and sample some negatives 90 | * side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a broader sampling of negative values. 91 | * train evidence identification 92 | * convert data for evidence classification - take all rationales + decisions and use this as input 93 | * train evidence classification 94 | * decode first the evidence, then run classification for each split 95 | 96 | """, formatter_class=argparse.RawTextHelpFormatter) 97 | parser.add_argument('--data_dir', dest='data_dir', required=True, 98 | help='Which directory contains a {train,val,test}.jsonl file?') 99 | parser.add_argument('--output_dir', dest='output_dir', required=True, 100 | help='Where shall we write intermediate models + final data to?') 101 | parser.add_argument('--model_params', dest='model_params', required=True, 102 | help='JSoN file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.') 103 | args = parser.parse_args() 104 | BATCH_FIRST = True 105 | 106 | with open(args.model_params, 'r') as fp: 107 | logging.debug(f'Loading model parameters from {args.model_params}') 108 | model_params = json.load(fp) 109 | train, val, test = load_datasets(args.data_dir) 110 | docids = set(e.docid for e in chain.from_iterable(chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test))))) 111 | documents = load_documents(args.data_dir, docids) 112 | document_vocab = set(chain.from_iterable(chain.from_iterable(documents.values()))) 113 | annotation_vocab = set(chain.from_iterable(e.query.split() for e in chain(train, val, test))) 114 | logging.debug(f'Loaded {len(documents)} documents with {len(document_vocab)} unique words') 115 | # this ignores the case where annotations don't align perfectly with token boundaries, but this isn't that important 116 | vocab = document_vocab | annotation_vocab 117 | unk_token = 'UNK' 118 | evidence_identifier, evidence_classifier, word_interner, de_interner, evidence_classes = \ 119 | initialize_models(model_params, vocab, batch_first=BATCH_FIRST, unk_token=unk_token) 120 | logging.debug(f'Including annotations, we have {len(vocab)} total words in the data, with embeddings for {len(word_interner)}') 121 | interned_documents = intern_documents(documents, word_interner, unk_token) 122 | interned_train = intern_annotations(train, word_interner, unk_token) 123 | interned_val = intern_annotations(val, word_interner, unk_token) 124 | interned_test = intern_annotations(test, word_interner, unk_token) 125 | assert BATCH_FIRST # for correctness of the split dimension for DataParallel 126 | evidence_identifier, evidence_ident_results = train_evidence_identifier(evidence_identifier.cuda(), 127 | args.output_dir, interned_train, 128 | interned_val, 129 | interned_documents, 130 | model_params, 131 | tensorize_model_inputs=True) 132 | evidence_classifier, evidence_class_results = train_evidence_classifier(evidence_classifier.cuda(), 133 | args.output_dir, 134 | interned_train, 135 | interned_val, 136 | 
interned_documents, 137 | model_params, 138 | class_interner=evidence_classes, 139 | tensorize_model_inputs=True) 140 | pipeline_batch_size = min([model_params['evidence_classifier']['batch_size'], 141 | model_params['evidence_identifier']['batch_size']]) 142 | pipeline_results, train_decoded, val_decoded, test_decoded = decode(evidence_identifier, 143 | evidence_classifier, 144 | interned_train, 145 | interned_val, 146 | interned_test, 147 | interned_documents, 148 | evidence_classes, 149 | pipeline_batch_size, 150 | tensorize_model_inputs=True) 151 | write_jsonl(train_decoded, os.path.join(args.output_dir, 'train_decoded.jsonl')) 152 | write_jsonl(val_decoded, os.path.join(args.output_dir, 'val_decoded.jsonl')) 153 | write_jsonl(test_decoded, os.path.join(args.output_dir, 'test_decoded.jsonl')) 154 | with open(os.path.join(args.output_dir, 'identifier_results.json'), 'w') as ident_output, \ 155 | open(os.path.join(args.output_dir, 'classifier_results.json'), 'w') as class_output: 156 | ident_output.write(json.dumps(evidence_ident_results)) 157 | class_output.write(json.dumps(evidence_class_results)) 158 | for k, v in pipeline_results.items(): 159 | if type(v) is dict: 160 | for k1, v1 in v.items(): 161 | logging.info(f'Pipeline results for {k}, {k1}={v1}') 162 | else: 163 | logging.info(f'Pipeline results {k}\t={v}') 164 | 165 | 166 | if __name__ == '__main__': 167 | main() 168 | -------------------------------------------------------------------------------- /BERT_rationale_benchmark/models/sequence_taggers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import List, Tuple, Any 5 | 6 | from transformers import BertModel 7 | 8 | from rationale_benchmark.models.model_utils import PaddedSequence 9 | 10 | 11 | class BertTagger(nn.Module): 12 | def __init__(self, 13 | bert_dir: str, 14 | pad_token_id: int, 15 | cls_token_id: int, 16 | sep_token_id: int, 17 | max_length: int=512, 18 | use_half_precision=True): 19 | super(BertTagger, self).__init__() 20 | self.sep_token_id = sep_token_id 21 | self.cls_token_id = cls_token_id 22 | self.pad_token_id = pad_token_id 23 | self.max_length = max_length 24 | bert = BertModel.from_pretrained(bert_dir) 25 | if use_half_precision: 26 | import apex 27 | bert = bert.half() 28 | self.bert = bert 29 | self.relevance_tagger = nn.Sequential( 30 | nn.Linear(self.bert.config.hidden_size, 1), 31 | nn.Sigmoid() 32 | ) 33 | 34 | def forward(self, 35 | query: List[torch.tensor], 36 | docids: List[Any], 37 | document_batch: List[torch.tensor], 38 | aggregate_spans: List[Tuple[int, int]]): 39 | assert len(query) == len(document_batch) 40 | # note about device management: since distributed training is enabled, the inputs to this module can be on 41 | # *any* device (preferably cpu, since we wrap and unwrap the module) we want to keep these params on the 42 | # input device (assuming CPU) for as long as possible for cheap memory access 43 | target_device = next(self.parameters()).device 44 | #cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device) 45 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device) 46 | input_tensors = [] 47 | query_lengths = [] 48 | for q, d in zip(query, document_batch): 49 | if len(q) + len(d) + 1 > self.max_length: 50 | d = d[:(self.max_length - len(q) - 1)] 51 | input_tensors.append(torch.cat([q, sep_token, d])) 52 | query_lengths.append(q.size()[0]) 53 | bert_input = 
PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, device=target_device) 54 | outputs = self.bert(bert_input.data, attention_mask=bert_input.mask(on=0.0, off=float('-inf'), dtype=torch.float, device=target_device)) 55 | hidden = outputs[0] 56 | classes = self.relevance_tagger(hidden) 57 | ret = [] 58 | for ql, cls, doc in zip(query_lengths, classes, document_batch): 59 | start = ql + 1 60 | end = start + len(doc) 61 | ret.append(cls[ql + 1:end]) 62 | return PaddedSequence.autopad(ret, batch_first=True, padding_value=0, device=target_device).data.squeeze(dim=-1) 63 | -------------------------------------------------------------------------------- /BERT_rationale_benchmark/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from dataclasses import dataclass, asdict, is_dataclass 5 | from itertools import chain 6 | from typing import Dict, List, Set, Tuple, Union, FrozenSet 7 | 8 | 9 | @dataclass(eq=True, frozen=True) 10 | class Evidence: 11 | """ 12 | (docid, start_token, end_token) form the only official Evidence; sentence level annotations are for convenience. 13 | Args: 14 | text: Some representation of the evidence text 15 | docid: Some identifier for the document 16 | start_token: The canonical start token, inclusive 17 | end_token: The canonical end token, exclusive 18 | start_sentence: Best guess start sentence, inclusive 19 | end_sentence: Best guess end sentence, exclusive 20 | """ 21 | text: Union[str, Tuple[int], Tuple[str]] 22 | docid: str 23 | start_token: int = -1 24 | end_token: int = -1 25 | start_sentence: int = -1 26 | end_sentence: int = -1 27 | 28 | 29 | @dataclass(eq=True, frozen=True) 30 | class Annotation: 31 | """ 32 | Args: 33 | annotation_id: unique ID for this annotation element 34 | query: some representation of a query string 35 | evidences: a set of "evidence groups". 36 | Each evidence group is: 37 | * sufficient to respond to the query (or justify an answer) 38 | * composed of one or more Evidences 39 | * may have multiple documents in it (depending on the dataset) 40 | - e-snli has multiple documents 41 | - other datasets do not 42 | classification: str 43 | query_type: Optional str, additional information about the query 44 | docids: a set of docids in which one may find evidence. 
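Example (illustrative values only, not taken from a real dataset): Annotation(annotation_id='posR_000', query='what is the sentiment of this review?', evidences=frozenset({(Evidence(text='a great movie', docid='posR_000', start_token=10, end_token=13),)}), classification='POS') 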
45 | """ 46 | annotation_id: str 47 | query: Union[str, Tuple[int]] 48 | evidences: Union[Set[Tuple[Evidence]], FrozenSet[Tuple[Evidence]]] 49 | classification: str 50 | query_type: str = None 51 | docids: Set[str] = None 52 | 53 | def all_evidences(self) -> Tuple[Evidence]: 54 | return tuple(list(chain.from_iterable(self.evidences))) 55 | 56 | 57 | def annotations_to_jsonl(annotations, output_file): 58 | with open(output_file, 'w') as of: 59 | for ann in sorted(annotations, key=lambda x: x.annotation_id): 60 | as_json = _annotation_to_dict(ann) 61 | as_str = json.dumps(as_json, sort_keys=True) 62 | of.write(as_str) 63 | of.write('\n') 64 | 65 | 66 | def _annotation_to_dict(dc): 67 | # convenience method 68 | if is_dataclass(dc): 69 | d = asdict(dc) 70 | ret = dict() 71 | for k, v in d.items(): 72 | ret[k] = _annotation_to_dict(v) 73 | return ret 74 | elif isinstance(dc, dict): 75 | ret = dict() 76 | for k, v in dc.items(): 77 | k = _annotation_to_dict(k) 78 | v = _annotation_to_dict(v) 79 | ret[k] = v 80 | return ret 81 | elif isinstance(dc, str): 82 | return dc 83 | elif isinstance(dc, (set, frozenset, list, tuple)): 84 | ret = [] 85 | for x in dc: 86 | ret.append(_annotation_to_dict(x)) 87 | return tuple(ret) 88 | else: 89 | return dc 90 | 91 | 92 | def load_jsonl(fp: str) -> List[dict]: 93 | ret = [] 94 | with open(fp, 'r') as inf: 95 | for line in inf: 96 | content = json.loads(line) 97 | ret.append(content) 98 | return ret 99 | 100 | 101 | def write_jsonl(jsonl, output_file): 102 | with open(output_file, 'w') as of: 103 | for js in jsonl: 104 | as_str = json.dumps(js, sort_keys=True) 105 | of.write(as_str) 106 | of.write('\n') 107 | 108 | 109 | def annotations_from_jsonl(fp: str) -> List[Annotation]: 110 | ret = [] 111 | with open(fp, 'r') as inf: 112 | for line in inf: 113 | content = json.loads(line) 114 | ev_groups = [] 115 | for ev_group in content['evidences']: 116 | ev_group = tuple([Evidence(**ev) for ev in ev_group]) 117 | ev_groups.append(ev_group) 118 | content['evidences'] = frozenset(ev_groups) 119 | ret.append(Annotation(**content)) 120 | return ret 121 | 122 | 123 | def load_datasets(data_dir: str) -> Tuple[List[Annotation], List[Annotation], List[Annotation]]: 124 | """Loads a training, validation, and test dataset 125 | 126 | Each dataset is assumed to have been serialized by annotations_to_jsonl, 127 | that is it is a list of json-serialized Annotation instances. 128 | """ 129 | train_data = annotations_from_jsonl(os.path.join(data_dir, 'train.jsonl')) 130 | val_data = annotations_from_jsonl(os.path.join(data_dir, 'val.jsonl')) 131 | test_data = annotations_from_jsonl(os.path.join(data_dir, 'test.jsonl')) 132 | return train_data, val_data, test_data 133 | 134 | 135 | def load_documents(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]: 136 | """Loads a subset of available documents from disk. 137 | 138 | Each document is assumed to be serialized as newline ('\n') separated sentences. 139 | Each sentence is assumed to be space (' ') joined tokens. 
140 | """ 141 | if os.path.exists(os.path.join(data_dir, 'docs.jsonl')): 142 | assert not os.path.exists(os.path.join(data_dir, 'docs')) 143 | return load_documents_from_file(data_dir, docids) 144 | 145 | docs_dir = os.path.join(data_dir, 'docs') 146 | res = dict() 147 | if docids is None: 148 | docids = sorted(os.listdir(docs_dir)) 149 | else: 150 | docids = sorted(set(str(d) for d in docids)) 151 | for d in docids: 152 | with open(os.path.join(docs_dir, d), 'r') as inf: 153 | res[d] = inf.read() 154 | return res 155 | 156 | 157 | def load_flattened_documents(data_dir: str, docids: Set[str]) -> Dict[str, List[str]]: 158 | """Loads a subset of available documents from disk. 159 | 160 | Returns a tokenized version of the document. 161 | """ 162 | unflattened_docs = load_documents(data_dir, docids) 163 | flattened_docs = dict() 164 | for doc, unflattened in unflattened_docs.items(): 165 | flattened_docs[doc] = list(chain.from_iterable(unflattened)) 166 | return flattened_docs 167 | 168 | 169 | def intern_documents(documents: Dict[str, List[List[str]]], word_interner: Dict[str, int], unk_token: str): 170 | """ 171 | Replaces every word with its index in an embeddings file. 172 | 173 | If a word is not found, uses the unk_token instead 174 | """ 175 | ret = dict() 176 | unk = word_interner[unk_token] 177 | for docid, sentences in documents.items(): 178 | ret[docid] = [[word_interner.get(w, unk) for w in s] for s in sentences] 179 | return ret 180 | 181 | 182 | def intern_annotations(annotations: List[Annotation], word_interner: Dict[str, int], unk_token: str): 183 | ret = [] 184 | for ann in annotations: 185 | ev_groups = [] 186 | for ev_group in ann.evidences: 187 | evs = [] 188 | for ev in ev_group: 189 | evs.append(Evidence( 190 | text=tuple([word_interner.get(t, word_interner[unk_token]) for t in ev.text.split()]), 191 | docid=ev.docid, 192 | start_token=ev.start_token, 193 | end_token=ev.end_token, 194 | start_sentence=ev.start_sentence, 195 | end_sentence=ev.end_sentence)) 196 | ev_groups.append(tuple(evs)) 197 | ret.append(Annotation(annotation_id=ann.annotation_id, 198 | query=tuple([word_interner.get(t, word_interner[unk_token]) for t in ann.query.split()]), 199 | evidences=frozenset(ev_groups), 200 | classification=ann.classification, 201 | query_type=ann.query_type)) 202 | return ret 203 | 204 | 205 | def load_documents_from_file(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]: 206 | """Loads a subset of available documents from 'docs.jsonl' file on disk. 207 | 208 | Each document is assumed to be serialized as newline ('\n') separated sentences. 209 | Each sentence is assumed to be space (' ') joined tokens. 
210 | """ 211 | docs_file = os.path.join(data_dir, 'docs.jsonl') 212 | documents = load_jsonl(docs_file) 213 | documents = {doc['docid']: doc['document'] for doc in documents} 214 | # res = dict() 215 | # if docids is None: 216 | # docids = sorted(list(documents.keys())) 217 | # else: 218 | # docids = sorted(set(str(d) for d in docids)) 219 | # for d in docids: 220 | # lines = documents[d].split('\n') 221 | # tokenized = [line.strip().split(' ') for line in lines] 222 | # res[d] = tokenized 223 | return documents 224 | -------------------------------------------------------------------------------- /DeiT.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/DeiT.PNG -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hila Chefer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of [Transformer Interpretability Beyond Attention Visualization](https://arxiv.org/abs/2012.09838) [CVPR 2021] 2 | 3 | #### Check out our new advancements- [Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers](https://github.com/hila-chefer/Transformer-MM-Explainability)! 4 | Faster, more general, and can be applied to *any* type of attention! 5 | Among the features: 6 | * We remove LRP for a simple and quick solution, and prove that the great results from our first paper still hold! 7 | * We expand our work to *any* type of Transformer- not just self-attention based encoders, but also co-attention encoders and encoder-decoders! 8 | * We show that VQA models can actually understand both image and text and make connections! 9 | * We use a DETR object detector and create segmentation masks from our explanations! 10 | * We provide a colab notebook with all the examples. You can very easily add images and questions of your own! 11 | 12 |

13 | 14 |

15 | 16 | --- 17 | ## ViT explainability notebook: 18 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/Transformer_explainability.ipynb) 19 | 20 | ## BERT explainability notebook: 21 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb) 22 | --- 23 | 24 | ## Updates 25 | April 5 2021: Check out this new [post](https://analyticsindiamag.com/compute-relevancy-of-transformer-networks-via-novel-interpretable-transformer/) about our paper! A great resource for understanding the main concepts behind our work. 26 | 27 | March 15 2021: [A Colab notebook for BERT for sentiment analysis added!](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb) 28 | 29 | Feb 28 2021: Our paper was accepted to CVPR 2021! 30 | 31 | Feb 17 2021: [A Colab notebook with all examples added!](https://github.com/hila-chefer/Transformer-Explainability/blob/main/Transformer_explainability.ipynb) 32 | 33 | Jan 5 2021: [A Jupyter notebook for DeiT added!](https://github.com/hila-chefer/Transformer-Explainability/blob/main/DeiT_example.ipynb) 34 | 35 | 36 |

37 | 38 |

39 | 40 | 41 | ## Introduction 42 | Official implementation of [Transformer Interpretability Beyond Attention Visualization](https://arxiv.org/abs/2012.09838). 43 | 44 | We introduce a novel method for visualizing the classifications made by a Transformer-based model, for both vision and NLP tasks. 45 | Our method also allows visualizing explanations per class. 46 | 47 | 
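For example, a class-specific relevance map can be produced with the `LRP` generator from `baselines/ViT`. The following is a minimal sketch (the random `image` tensor stands in for a real, normalized input, `class_index` is a placeholder you choose, and the repository root is assumed to be on `PYTHONPATH`):

```
import torch
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_explanation_generator import LRP

model = vit_LRP(pretrained=True).cuda().eval()
attribution_generator = LRP(model)

image = torch.rand(1, 3, 224, 224).cuda().requires_grad_()  # stand-in for a normalized input image
class_index = None                                          # None explains the top predicted class

relevance = attribution_generator.generate_LRP(image, method="transformer_attribution",
                                               index=class_index, start_layer=1)
relevance = relevance.reshape(1, 1, 14, 14)  # one relevance value per 16x16 patch
```

The 14x14 patch-level map can then be upsampled to the input resolution for overlay, as done in `baselines/ViT/generate_visualizations.py`.
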

48 | 49 |

50 | The method consists of 3 phases: 51 | 52 | 1. Calculating relevance for each attention matrix using our novel formulation of LRP. 53 | 54 | 2. Backpropagation of gradients for each attention matrix w.r.t. the visualized class. Gradients are used to average attention heads. 55 | 56 | 3. Layer aggregation with rollout. 57 | 58 | Please notice our [Jupyter notebook](https://github.com/hila-chefer/Transformer-Explainability/blob/main/example.ipynb), where you can run the two class-specific examples from the paper. 59 | 60 | 61 | ![alt text](https://github.com/hila-chefer/Transformer-Explainability/blob/main/example.PNG) 62 | 63 | To add another input image, simply add the image to the [samples folder](https://github.com/hila-chefer/Transformer-Explainability/tree/main/samples) and use the `generate_visualization` function for your selected class of interest (using `class_index={class_idx}`); not specifying the index will visualize the top predicted class. 64 | 65 | ## Credits 66 | ViT implementation is based on: 67 | - https://github.com/rwightman/pytorch-image-models 68 | - https://github.com/lucidrains/vit-pytorch 69 | - pretrained weights from: https://github.com/google-research/vision_transformer 70 | 71 | BERT implementation is taken from the huggingface Transformers library: 72 | https://huggingface.co/transformers/ 73 | 74 | ERASER benchmark code adapted from the ERASER GitHub implementation: https://github.com/jayded/eraserbenchmark 75 | 76 | Text visualizations in the supplementary were created using the TAHV heatmap generator for text: https://github.com/jiesutd/Text-Attention-Heatmap-Visualization 77 | 78 | ## Reproducing results on ViT 79 | 80 | ### Section A. Segmentation Results 81 | 82 | Example: 83 | ``` 84 | CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/imagenet_seg_eval.py --method transformer_attribution --imagenet-seg-path /path/to/gtsegs_ijcv.mat 85 | 86 | ``` 87 | [Link to download dataset](http://calvin-vision.net/bigstuff/proj-imagenet/data/gtsegs_ijcv.mat). 88 | 89 | In the example above we run a segmentation test with our method. Notice you can choose which method you wish to run using the `--method` argument. 90 | You must provide a path to the ImageNet segmentation data in `--imagenet-seg-path`. 91 | 92 | ### Section B. Perturbation Results 93 | 94 | Example: 95 | ``` 96 | CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/generate_visualizations.py --method transformer_attribution --imagenet-validation-path /path/to/imagenet_validation_directory 97 | ``` 98 | 99 | Notice that you can choose to visualize by the target or the top class using the `--vis-class` argument. 100 | 101 | Now, to run the perturbation test, run the following command: 102 | ``` 103 | CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/pertubation_eval_from_hdf5.py --method transformer_attribution 104 | ``` 105 | 106 | Notice that you can use the `--neg` argument to run either positive or negative perturbation. 107 | 108 | ## Reproducing results on BERT 109 | 110 | 1. Download the pretrained weights: 111 | 112 | - Download `classifier.zip` from https://drive.google.com/file/d/1kGMTr69UWWe70i-o2_JfjmWDQjT66xwQ/view?usp=sharing 113 | - mkdir -p `./bert_models/movies` 114 | - unzip classifier.zip -d ./bert_models/movies/ 115 | 116 | 2. Download the dataset pkl file: 117 | 118 | - Download `preprocessed.pkl` from https://drive.google.com/file/d/1-gfbTj6D87KIm_u1QMHGLKSL3e93hxBH/view?usp=sharing 119 | - mv preprocessed.pkl ./bert_models/movies 120 | 121 | 3. 
Download the dataset: 122 | 123 | - Download `movies.zip` from https://drive.google.com/file/d/11faFLGkc0hkw3wrGTYJBr1nIvkRb189F/view?usp=sharing 124 | - unzip movies.zip -d ./data/ 125 | 126 | 4. Now you can run the model. 127 | 128 | Example: 129 | ``` 130 | CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 BERT_rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/movies/ --output_dir bert_models/movies/ --model_params BERT_params/movies_bert.json 131 | ``` 132 | To control which algorithm to use for explanations change the `method` variable in `BERT_rationale_benchmark/models/pipeline/bert_pipeline.py` (Defaults to 'transformer_attribution' which is our method). 133 | Running this command will create a directory for the method in `bert_models/movies/`. 134 | 135 | In order to run f1 test with k, run the following command: 136 | ``` 137 | PYTHONPATH=./:$PYTHONPATH python3 BERT_rationale_benchmark/metrics.py --data_dir data/movies/ --split test --results bert_models/movies//identifier_results_k.json 138 | ``` 139 | 140 | Also, in the method directory there will be created `.tex` files containing the explanations extracted for each example. This corresponds to our visualizations in the supplementary. 141 | 142 | ## Citing our paper 143 | If you make use of our work, please cite our paper: 144 | ``` 145 | @InProceedings{Chefer_2021_CVPR, 146 | author = {Chefer, Hila and Gur, Shir and Wolf, Lior}, 147 | title = {Transformer Interpretability Beyond Attention Visualization}, 148 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 149 | month = {June}, 150 | year = {2021}, 151 | pages = {782-791} 152 | } 153 | ``` 154 | -------------------------------------------------------------------------------- /baselines/ViT/ViT_explanation_generator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | from numpy import * 5 | 6 | # compute rollout between attention layers 7 | def compute_rollout_attention(all_layer_matrices, start_layer=0): 8 | # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow 9 | num_tokens = all_layer_matrices[0].shape[1] 10 | batch_size = all_layer_matrices[0].shape[0] 11 | eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device) 12 | all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))] 13 | matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True) 14 | for i in range(len(all_layer_matrices))] 15 | joint_attention = matrices_aug[start_layer] 16 | for i in range(start_layer+1, len(matrices_aug)): 17 | joint_attention = matrices_aug[i].bmm(joint_attention) 18 | return joint_attention 19 | 20 | class LRP: 21 | def __init__(self, model): 22 | self.model = model 23 | self.model.eval() 24 | 25 | def generate_LRP(self, input, index=None, method="transformer_attribution", is_ablation=False, start_layer=0): 26 | output = self.model(input) 27 | kwargs = {"alpha": 1} 28 | if index == None: 29 | index = np.argmax(output.cpu().data.numpy(), axis=-1) 30 | 31 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) 32 | one_hot[0, index] = 1 33 | one_hot_vector = one_hot 34 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 35 | one_hot = torch.sum(one_hot.cuda() * output) 36 | 37 | self.model.zero_grad() 38 | one_hot.backward(retain_graph=True) 39 | 40 | 
return self.model.relprop(torch.tensor(one_hot_vector).to(input.device), method=method, is_ablation=is_ablation, 41 | start_layer=start_layer, **kwargs) 42 | 43 | 44 | 45 | class Baselines: 46 | def __init__(self, model): 47 | self.model = model 48 | self.model.eval() 49 | 50 | def generate_cam_attn(self, input, index=None): 51 | output = self.model(input.cuda(), register_hook=True) 52 | if index == None: 53 | index = np.argmax(output.cpu().data.numpy()) 54 | 55 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) 56 | one_hot[0][index] = 1 57 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 58 | one_hot = torch.sum(one_hot.cuda() * output) 59 | 60 | self.model.zero_grad() 61 | one_hot.backward(retain_graph=True) 62 | #################### attn 63 | grad = self.model.blocks[-1].attn.get_attn_gradients() 64 | cam = self.model.blocks[-1].attn.get_attention_map() 65 | cam = cam[0, :, 0, 1:].reshape(-1, 14, 14) 66 | grad = grad[0, :, 0, 1:].reshape(-1, 14, 14) 67 | grad = grad.mean(dim=[1, 2], keepdim=True) 68 | cam = (cam * grad).mean(0).clamp(min=0) 69 | cam = (cam - cam.min()) / (cam.max() - cam.min()) 70 | 71 | return cam 72 | #################### attn 73 | 74 | def generate_rollout(self, input, start_layer=0): 75 | self.model(input) 76 | blocks = self.model.blocks 77 | all_layer_attentions = [] 78 | for blk in blocks: 79 | attn_heads = blk.attn.get_attention_map() 80 | avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach() 81 | all_layer_attentions.append(avg_heads) 82 | rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer) 83 | return rollout[:,0, 1:] 84 | -------------------------------------------------------------------------------- /baselines/ViT/ViT_new.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from functools import partial 7 | from einops import rearrange 8 | 9 | from baselines.ViT.helpers import load_pretrained 10 | from baselines.ViT.weight_init import trunc_normal_ 11 | from baselines.ViT.layer_helpers import to_2tuple 12 | 13 | 14 | def _cfg(url='', **kwargs): 15 | return { 16 | 'url': url, 17 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 18 | 'crop_pct': .9, 'interpolation': 'bicubic', 19 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 20 | **kwargs 21 | } 22 | 23 | 24 | default_cfgs = { 25 | # patch models 26 | 'vit_small_patch16_224': _cfg( 27 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', 28 | ), 29 | 'vit_base_patch16_224': _cfg( 30 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', 31 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 32 | ), 33 | 'vit_large_patch16_224': _cfg( 34 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', 35 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 36 | } 37 | 38 | class Mlp(nn.Module): 39 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 40 | super().__init__() 41 | out_features = out_features or in_features 42 | hidden_features = hidden_features or in_features 43 | self.fc1 = nn.Linear(in_features, hidden_features) 44 | self.act = act_layer() 45 | self.fc2 = nn.Linear(hidden_features, 
out_features) 46 | self.drop = nn.Dropout(drop) 47 | 48 | def forward(self, x): 49 | x = self.fc1(x) 50 | x = self.act(x) 51 | x = self.drop(x) 52 | x = self.fc2(x) 53 | x = self.drop(x) 54 | return x 55 | 56 | 57 | class Attention(nn.Module): 58 | def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.): 59 | super().__init__() 60 | self.num_heads = num_heads 61 | head_dim = dim // num_heads 62 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 63 | self.scale = head_dim ** -0.5 64 | 65 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 66 | self.attn_drop = nn.Dropout(attn_drop) 67 | self.proj = nn.Linear(dim, dim) 68 | self.proj_drop = nn.Dropout(proj_drop) 69 | 70 | self.attn_gradients = None 71 | self.attention_map = None 72 | 73 | def save_attn_gradients(self, attn_gradients): 74 | self.attn_gradients = attn_gradients 75 | 76 | def get_attn_gradients(self): 77 | return self.attn_gradients 78 | 79 | def save_attention_map(self, attention_map): 80 | self.attention_map = attention_map 81 | 82 | def get_attention_map(self): 83 | return self.attention_map 84 | 85 | def forward(self, x, register_hook=False): 86 | b, n, _, h = *x.shape, self.num_heads 87 | 88 | # self.save_output(x) 89 | # x.register_hook(self.save_output_grad) 90 | 91 | qkv = self.qkv(x) 92 | q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h) 93 | 94 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 95 | 96 | attn = dots.softmax(dim=-1) 97 | attn = self.attn_drop(attn) 98 | 99 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 100 | 101 | self.save_attention_map(attn) 102 | if register_hook: 103 | attn.register_hook(self.save_attn_gradients) 104 | 105 | out = rearrange(out, 'b h n d -> b n (h d)') 106 | out = self.proj(out) 107 | out = self.proj_drop(out) 108 | return out 109 | 110 | 111 | class Block(nn.Module): 112 | 113 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 114 | super().__init__() 115 | self.norm1 = norm_layer(dim) 116 | self.attn = Attention( 117 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 118 | self.norm2 = norm_layer(dim) 119 | mlp_hidden_dim = int(dim * mlp_ratio) 120 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 121 | 122 | def forward(self, x, register_hook=False): 123 | x = x + self.attn(self.norm1(x), register_hook=register_hook) 124 | x = x + self.mlp(self.norm2(x)) 125 | return x 126 | 127 | 128 | class PatchEmbed(nn.Module): 129 | """ Image to Patch Embedding 130 | """ 131 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 132 | super().__init__() 133 | img_size = to_2tuple(img_size) 134 | patch_size = to_2tuple(patch_size) 135 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 136 | self.img_size = img_size 137 | self.patch_size = patch_size 138 | self.num_patches = num_patches 139 | 140 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 141 | 142 | def forward(self, x): 143 | B, C, H, W = x.shape 144 | # FIXME look at relaxing size constraints 145 | assert H == self.img_size[0] and W == self.img_size[1], \ 146 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
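# next: the strided conv projects each patch to embed_dim; flatten(2).transpose(1, 2) yields a (B, num_patches, embed_dim) token sequence 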
147 | x = self.proj(x).flatten(2).transpose(1, 2) 148 | return x 149 | 150 | class VisionTransformer(nn.Module): 151 | """ Vision Transformer 152 | """ 153 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 154 | num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., norm_layer=nn.LayerNorm): 155 | super().__init__() 156 | self.num_classes = num_classes 157 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 158 | self.patch_embed = PatchEmbed( 159 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 160 | num_patches = self.patch_embed.num_patches 161 | 162 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 163 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 164 | self.pos_drop = nn.Dropout(p=drop_rate) 165 | 166 | self.blocks = nn.ModuleList([ 167 | Block( 168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 169 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer) 170 | for i in range(depth)]) 171 | self.norm = norm_layer(embed_dim) 172 | 173 | # Classifier head 174 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 175 | 176 | trunc_normal_(self.pos_embed, std=.02) 177 | trunc_normal_(self.cls_token, std=.02) 178 | self.apply(self._init_weights) 179 | 180 | def _init_weights(self, m): 181 | if isinstance(m, nn.Linear): 182 | trunc_normal_(m.weight, std=.02) 183 | if isinstance(m, nn.Linear) and m.bias is not None: 184 | nn.init.constant_(m.bias, 0) 185 | elif isinstance(m, nn.LayerNorm): 186 | nn.init.constant_(m.bias, 0) 187 | nn.init.constant_(m.weight, 1.0) 188 | 189 | @torch.jit.ignore 190 | def no_weight_decay(self): 191 | return {'pos_embed', 'cls_token'} 192 | 193 | def forward(self, x, register_hook=False): 194 | B = x.shape[0] 195 | x = self.patch_embed(x) 196 | 197 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 198 | x = torch.cat((cls_tokens, x), dim=1) 199 | x = x + self.pos_embed 200 | x = self.pos_drop(x) 201 | 202 | for blk in self.blocks: 203 | x = blk(x, register_hook=register_hook) 204 | 205 | x = self.norm(x) 206 | x = x[:, 0] 207 | x = self.head(x) 208 | return x 209 | 210 | 211 | def _conv_filter(state_dict, patch_size=16): 212 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 213 | out_dict = {} 214 | for k, v in state_dict.items(): 215 | if 'patch_embed.proj.weight' in k: 216 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 217 | out_dict[k] = v 218 | return out_dict 219 | 220 | 221 | def vit_base_patch16_224(pretrained=False, **kwargs): 222 | model = VisionTransformer( 223 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 224 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 225 | model.default_cfg = default_cfgs['vit_base_patch16_224'] 226 | if pretrained: 227 | load_pretrained( 228 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) 229 | return model 230 | 231 | def vit_large_patch16_224(pretrained=False, **kwargs): 232 | model = VisionTransformer( 233 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 234 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 235 | model.default_cfg = default_cfgs['vit_large_patch16_224'] 236 | if pretrained: 237 | load_pretrained(model, 
num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 238 | return model 239 | -------------------------------------------------------------------------------- /baselines/ViT/generate_visualizations.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import h5py 4 | 5 | import argparse 6 | 7 | # Import saliency methods and models 8 | from misc_functions import * 9 | 10 | from ViT_explanation_generator import Baselines, LRP 11 | from ViT_new import vit_base_patch16_224 12 | from ViT_LRP import vit_base_patch16_224 as vit_LRP 13 | from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP 14 | 15 | from torchvision.datasets import ImageNet 16 | 17 | 18 | def normalize(tensor, 19 | mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 20 | dtype = tensor.dtype 21 | mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) 22 | std = torch.as_tensor(std, dtype=dtype, device=tensor.device) 23 | tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 24 | return tensor 25 | 26 | 27 | def compute_saliency_and_save(args): 28 | first = True 29 | with h5py.File(os.path.join(args.method_dir, 'results.hdf5'), 'a') as f: 30 | data_cam = f.create_dataset('vis', 31 | (1, 1, 224, 224), 32 | maxshape=(None, 1, 224, 224), 33 | dtype=np.float32, 34 | compression="gzip") 35 | data_image = f.create_dataset('image', 36 | (1, 3, 224, 224), 37 | maxshape=(None, 3, 224, 224), 38 | dtype=np.float32, 39 | compression="gzip") 40 | data_target = f.create_dataset('target', 41 | (1,), 42 | maxshape=(None,), 43 | dtype=np.int32, 44 | compression="gzip") 45 | for batch_idx, (data, target) in enumerate(tqdm(sample_loader)): 46 | if first: 47 | first = False 48 | data_cam.resize(data_cam.shape[0] + data.shape[0] - 1, axis=0) 49 | data_image.resize(data_image.shape[0] + data.shape[0] - 1, axis=0) 50 | data_target.resize(data_target.shape[0] + data.shape[0] - 1, axis=0) 51 | else: 52 | data_cam.resize(data_cam.shape[0] + data.shape[0], axis=0) 53 | data_image.resize(data_image.shape[0] + data.shape[0], axis=0) 54 | data_target.resize(data_target.shape[0] + data.shape[0], axis=0) 55 | 56 | # Add data 57 | data_image[-data.shape[0]:] = data.data.cpu().numpy() 58 | data_target[-data.shape[0]:] = target.data.cpu().numpy() 59 | 60 | target = target.to(device) 61 | 62 | data = normalize(data) 63 | data = data.to(device) 64 | data.requires_grad_() 65 | 66 | index = None 67 | if args.vis_class == 'target': 68 | index = target 69 | 70 | if args.method == 'rollout': 71 | Res = baselines.generate_rollout(data, start_layer=1).reshape(data.shape[0], 1, 14, 14) 72 | # Res = Res - Res.mean() 73 | 74 | elif args.method == 'lrp': 75 | Res = lrp.generate_LRP(data, start_layer=1, index=index).reshape(data.shape[0], 1, 14, 14) 76 | # Res = Res - Res.mean() 77 | 78 | elif args.method == 'transformer_attribution': 79 | Res = lrp.generate_LRP(data, start_layer=1, method="grad", index=index).reshape(data.shape[0], 1, 14, 14) 80 | # Res = Res - Res.mean() 81 | 82 | elif args.method == 'full_lrp': 83 | Res = orig_lrp.generate_LRP(data, method="full", index=index).reshape(data.shape[0], 1, 224, 224) 84 | # Res = Res - Res.mean() 85 | 86 | elif args.method == 'lrp_last_layer': 87 | Res = orig_lrp.generate_LRP(data, method="last_layer", is_ablation=args.is_ablation, index=index) \ 88 | .reshape(data.shape[0], 1, 14, 14) 89 | # Res = Res - Res.mean() 90 | 91 | elif args.method == 'attn_last_layer': 92 | Res = lrp.generate_LRP(data, 
method="last_layer_attn", is_ablation=args.is_ablation) \ 93 | .reshape(data.shape[0], 1, 14, 14) 94 | 95 | elif args.method == 'attn_gradcam': 96 | Res = baselines.generate_cam_attn(data, index=index).reshape(data.shape[0], 1, 14, 14) 97 | 98 | if args.method != 'full_lrp' and args.method != 'input_grads': 99 | Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').cuda() 100 | Res = (Res - Res.min()) / (Res.max() - Res.min()) 101 | 102 | data_cam[-data.shape[0]:] = Res.data.cpu().numpy() 103 | 104 | 105 | if __name__ == "__main__": 106 | parser = argparse.ArgumentParser(description='Train a segmentation') 107 | parser.add_argument('--batch-size', type=int, 108 | default=1, 109 | help='') 110 | parser.add_argument('--method', type=str, 111 | default='grad_rollout', 112 | choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer', 113 | 'attn_last_layer', 'attn_gradcam'], 114 | help='') 115 | parser.add_argument('--lmd', type=float, 116 | default=10, 117 | help='') 118 | parser.add_argument('--vis-class', type=str, 119 | default='top', 120 | choices=['top', 'target', 'index'], 121 | help='') 122 | parser.add_argument('--class-id', type=int, 123 | default=0, 124 | help='') 125 | parser.add_argument('--cls-agn', action='store_true', 126 | default=False, 127 | help='') 128 | parser.add_argument('--no-ia', action='store_true', 129 | default=False, 130 | help='') 131 | parser.add_argument('--no-fx', action='store_true', 132 | default=False, 133 | help='') 134 | parser.add_argument('--no-fgx', action='store_true', 135 | default=False, 136 | help='') 137 | parser.add_argument('--no-m', action='store_true', 138 | default=False, 139 | help='') 140 | parser.add_argument('--no-reg', action='store_true', 141 | default=False, 142 | help='') 143 | parser.add_argument('--is-ablation', type=bool, 144 | default=False, 145 | help='') 146 | parser.add_argument('--imagenet-validation-path', type=str, 147 | required=True, 148 | help='') 149 | args = parser.parse_args() 150 | 151 | # PATH variables 152 | PATH = os.path.dirname(os.path.abspath(__file__)) + '/' 153 | os.makedirs(os.path.join(PATH, 'visualizations'), exist_ok=True) 154 | 155 | try: 156 | os.remove(os.path.join(PATH, 'visualizations/{}/{}/results.hdf5'.format(args.method, 157 | args.vis_class))) 158 | except OSError: 159 | pass 160 | 161 | 162 | os.makedirs(os.path.join(PATH, 'visualizations/{}'.format(args.method)), exist_ok=True) 163 | if args.vis_class == 'index': 164 | os.makedirs(os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method, 165 | args.vis_class, 166 | args.class_id)), exist_ok=True) 167 | args.method_dir = os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method, 168 | args.vis_class, 169 | args.class_id)) 170 | else: 171 | ablation_fold = 'ablation' if args.is_ablation else 'not_ablation' 172 | os.makedirs(os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method, 173 | args.vis_class, ablation_fold)), exist_ok=True) 174 | args.method_dir = os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method, 175 | args.vis_class, ablation_fold)) 176 | 177 | cuda = torch.cuda.is_available() 178 | device = torch.device("cuda" if cuda else "cpu") 179 | 180 | # Model 181 | model = vit_base_patch16_224(pretrained=True).cuda() 182 | baselines = Baselines(model) 183 | 184 | # LRP 185 | model_LRP = vit_LRP(pretrained=True).cuda() 186 | model_LRP.eval() 187 | lrp = LRP(model_LRP) 188 | 189 | # orig LRP 190 | model_orig_LRP = vit_orig_LRP(pretrained=True).cuda() 191 | 
model_orig_LRP.eval() 192 | orig_lrp = LRP(model_orig_LRP) 193 | 194 | # Dataset loader for sample images 195 | transform = transforms.Compose([ 196 | transforms.Resize((224, 224)), 197 | transforms.ToTensor(), 198 | ]) 199 | 200 | imagenet_ds = ImageNet(args.imagenet_validation_path, split='val', download=False, transform=transform) 201 | sample_loader = torch.utils.data.DataLoader( 202 | imagenet_ds, 203 | batch_size=args.batch_size, 204 | shuffle=False, 205 | num_workers=4 206 | ) 207 | 208 | compute_saliency_and_save(args) 209 | -------------------------------------------------------------------------------- /baselines/ViT/helpers.py: -------------------------------------------------------------------------------- 1 | """ Model creation / weight loading / state_dict helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import logging 6 | import os 7 | import math 8 | from collections import OrderedDict 9 | from copy import deepcopy 10 | from typing import Callable 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.utils.model_zoo as model_zoo 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | def load_state_dict(checkpoint_path, use_ema=False): 20 | if checkpoint_path and os.path.isfile(checkpoint_path): 21 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 22 | state_dict_key = 'state_dict' 23 | if isinstance(checkpoint, dict): 24 | if use_ema and 'state_dict_ema' in checkpoint: 25 | state_dict_key = 'state_dict_ema' 26 | if state_dict_key and state_dict_key in checkpoint: 27 | new_state_dict = OrderedDict() 28 | for k, v in checkpoint[state_dict_key].items(): 29 | # strip `module.` prefix 30 | name = k[7:] if k.startswith('module') else k 31 | new_state_dict[name] = v 32 | state_dict = new_state_dict 33 | else: 34 | state_dict = checkpoint 35 | _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) 36 | return state_dict 37 | else: 38 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 39 | raise FileNotFoundError() 40 | 41 | 42 | def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): 43 | state_dict = load_state_dict(checkpoint_path, use_ema) 44 | model.load_state_dict(state_dict, strict=strict) 45 | 46 | 47 | def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): 48 | resume_epoch = None 49 | if os.path.isfile(checkpoint_path): 50 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 51 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 52 | if log_info: 53 | _logger.info('Restoring model state from checkpoint...') 54 | new_state_dict = OrderedDict() 55 | for k, v in checkpoint['state_dict'].items(): 56 | name = k[7:] if k.startswith('module') else k 57 | new_state_dict[name] = v 58 | model.load_state_dict(new_state_dict) 59 | 60 | if optimizer is not None and 'optimizer' in checkpoint: 61 | if log_info: 62 | _logger.info('Restoring optimizer state from checkpoint...') 63 | optimizer.load_state_dict(checkpoint['optimizer']) 64 | 65 | if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: 66 | if log_info: 67 | _logger.info('Restoring AMP loss scaler state from checkpoint...') 68 | loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) 69 | 70 | if 'epoch' in checkpoint: 71 | resume_epoch = checkpoint['epoch'] 72 | if 'version' in checkpoint and checkpoint['version'] > 1: 73 | resume_epoch += 1 # start at the next epoch, old checkpoints incremented before 
save 74 | 75 | if log_info: 76 | _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) 77 | else: 78 | model.load_state_dict(checkpoint) 79 | if log_info: 80 | _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) 81 | return resume_epoch 82 | else: 83 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 84 | raise FileNotFoundError() 85 | 86 | 87 | def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True): 88 | if cfg is None: 89 | cfg = getattr(model, 'default_cfg') 90 | if cfg is None or 'url' not in cfg or not cfg['url']: 91 | _logger.warning("Pretrained model URL is invalid, using random initialization.") 92 | return 93 | 94 | state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu') 95 | 96 | if filter_fn is not None: 97 | state_dict = filter_fn(state_dict) 98 | 99 | if in_chans == 1: 100 | conv1_name = cfg['first_conv'] 101 | _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) 102 | conv1_weight = state_dict[conv1_name + '.weight'] 103 | # Some weights are in torch.half, ensure it's float for sum on CPU 104 | conv1_type = conv1_weight.dtype 105 | conv1_weight = conv1_weight.float() 106 | O, I, J, K = conv1_weight.shape 107 | if I > 3: 108 | assert conv1_weight.shape[1] % 3 == 0 109 | # For models with space2depth stems 110 | conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) 111 | conv1_weight = conv1_weight.sum(dim=2, keepdim=False) 112 | else: 113 | conv1_weight = conv1_weight.sum(dim=1, keepdim=True) 114 | conv1_weight = conv1_weight.to(conv1_type) 115 | state_dict[conv1_name + '.weight'] = conv1_weight 116 | elif in_chans != 3: 117 | conv1_name = cfg['first_conv'] 118 | conv1_weight = state_dict[conv1_name + '.weight'] 119 | conv1_type = conv1_weight.dtype 120 | conv1_weight = conv1_weight.float() 121 | O, I, J, K = conv1_weight.shape 122 | if I != 3: 123 | _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name) 124 | del state_dict[conv1_name + '.weight'] 125 | strict = False 126 | else: 127 | # NOTE this strategy should be better than random init, but there could be other combinations of 128 | # the original RGB input layer weights that'd work better for specific cases. 129 | _logger.info('Repeating first conv (%s) weights in channel dim.' 
% conv1_name) 130 | repeat = int(math.ceil(in_chans / 3)) 131 | conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] 132 | conv1_weight *= (3 / float(in_chans)) 133 | conv1_weight = conv1_weight.to(conv1_type) 134 | state_dict[conv1_name + '.weight'] = conv1_weight 135 | 136 | classifier_name = cfg['classifier'] 137 | if num_classes == 1000 and cfg['num_classes'] == 1001: 138 | # special case for imagenet trained models with extra background class in pretrained weights 139 | classifier_weight = state_dict[classifier_name + '.weight'] 140 | state_dict[classifier_name + '.weight'] = classifier_weight[1:] 141 | classifier_bias = state_dict[classifier_name + '.bias'] 142 | state_dict[classifier_name + '.bias'] = classifier_bias[1:] 143 | elif num_classes != cfg['num_classes']: 144 | # completely discard fully connected for all other differences between pretrained and created model 145 | del state_dict[classifier_name + '.weight'] 146 | del state_dict[classifier_name + '.bias'] 147 | strict = False 148 | 149 | model.load_state_dict(state_dict, strict=strict) 150 | 151 | 152 | def extract_layer(model, layer): 153 | layer = layer.split('.') 154 | module = model 155 | if hasattr(model, 'module') and layer[0] != 'module': 156 | module = model.module 157 | if not hasattr(model, 'module') and layer[0] == 'module': 158 | layer = layer[1:] 159 | for l in layer: 160 | if hasattr(module, l): 161 | if not l.isdigit(): 162 | module = getattr(module, l) 163 | else: 164 | module = module[int(l)] 165 | else: 166 | return module 167 | return module 168 | 169 | 170 | def set_layer(model, layer, val): 171 | layer = layer.split('.') 172 | module = model 173 | if hasattr(model, 'module') and layer[0] != 'module': 174 | module = model.module 175 | lst_index = 0 176 | module2 = module 177 | for l in layer: 178 | if hasattr(module2, l): 179 | if not l.isdigit(): 180 | module2 = getattr(module2, l) 181 | else: 182 | module2 = module2[int(l)] 183 | lst_index += 1 184 | lst_index -= 1 185 | for l in layer[:lst_index]: 186 | if not l.isdigit(): 187 | module = getattr(module, l) 188 | else: 189 | module = module[int(l)] 190 | l = layer[lst_index] 191 | setattr(module, l, val) 192 | 193 | 194 | def adapt_model_from_string(parent_module, model_string): 195 | separator = '***' 196 | state_dict = {} 197 | lst_shape = model_string.split(separator) 198 | for k in lst_shape: 199 | k = k.split(':') 200 | key = k[0] 201 | shape = k[1][1:-1].split(',') 202 | if shape[0] != '': 203 | state_dict[key] = [int(i) for i in shape] 204 | 205 | new_module = deepcopy(parent_module) 206 | for n, m in parent_module.named_modules(): 207 | old_module = extract_layer(parent_module, n) 208 | if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): 209 | if isinstance(old_module, Conv2dSame): 210 | conv = Conv2dSame 211 | else: 212 | conv = nn.Conv2d 213 | s = state_dict[n + '.weight'] 214 | in_channels = s[1] 215 | out_channels = s[0] 216 | g = 1 217 | if old_module.groups > 1: 218 | in_channels = out_channels 219 | g = in_channels 220 | new_conv = conv( 221 | in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, 222 | bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, 223 | groups=g, stride=old_module.stride) 224 | set_layer(new_module, n, new_conv) 225 | if isinstance(old_module, nn.BatchNorm2d): 226 | new_bn = nn.BatchNorm2d( 227 | num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, 228 | 
affine=old_module.affine, track_running_stats=True) 229 | set_layer(new_module, n, new_bn) 230 | if isinstance(old_module, nn.Linear): 231 | # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? 232 | num_features = state_dict[n + '.weight'][1] 233 | new_fc = nn.Linear( 234 | in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) 235 | set_layer(new_module, n, new_fc) 236 | if hasattr(new_module, 'num_features'): 237 | new_module.num_features = num_features 238 | new_module.eval() 239 | parent_module.eval() 240 | 241 | return new_module 242 | 243 | 244 | def adapt_model_from_file(parent_module, model_variant): 245 | adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') 246 | with open(adapt_file, 'r') as f: 247 | return adapt_model_from_string(parent_module, f.read().strip()) 248 | 249 | 250 | def build_model_with_cfg( 251 | model_cls: Callable, 252 | variant: str, 253 | pretrained: bool, 254 | default_cfg: dict, 255 | model_cfg: dict = None, 256 | feature_cfg: dict = None, 257 | pretrained_strict: bool = True, 258 | pretrained_filter_fn: Callable = None, 259 | **kwargs): 260 | pruned = kwargs.pop('pruned', False) 261 | features = False 262 | feature_cfg = feature_cfg or {} 263 | 264 | if kwargs.pop('features_only', False): 265 | features = True 266 | feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) 267 | if 'out_indices' in kwargs: 268 | feature_cfg['out_indices'] = kwargs.pop('out_indices') 269 | 270 | model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) 271 | model.default_cfg = deepcopy(default_cfg) 272 | 273 | if pruned: 274 | model = adapt_model_from_file(model, variant) 275 | 276 | if pretrained: 277 | load_pretrained( 278 | model, 279 | num_classes=kwargs.get('num_classes', 0), 280 | in_chans=kwargs.get('in_chans', 3), 281 | filter_fn=pretrained_filter_fn, strict=pretrained_strict) 282 | 283 | if features: 284 | feature_cls = FeatureListNet 285 | if 'feature_cls' in feature_cfg: 286 | feature_cls = feature_cfg.pop('feature_cls') 287 | if isinstance(feature_cls, str): 288 | feature_cls = feature_cls.lower() 289 | if 'hook' in feature_cls: 290 | feature_cls = FeatureHookNet 291 | else: 292 | assert False, f'Unknown feature class {feature_cls}' 293 | model = feature_cls(model, **feature_cfg) 294 | 295 | return model -------------------------------------------------------------------------------- /baselines/ViT/layer_helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from itertools import repeat 5 | import collections.abc 6 | 7 | 8 | # From PyTorch internals 9 | def _ntuple(n): 10 | def parse(x): 11 | if isinstance(x, collections.abc.Iterable): 12 | return x 13 | return tuple(repeat(x, n)) 14 | return parse 15 | 16 | 17 | to_1tuple = _ntuple(1) 18 | to_2tuple = _ntuple(2) 19 | to_3tuple = _ntuple(3) 20 | to_4tuple = _ntuple(4) 21 | to_ntuple = _ntuple 22 | -------------------------------------------------------------------------------- /baselines/ViT/misc_functions.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/ 3 | # Written by Suraj Srinivas 4 | # 5 | 6 | """ Misc helper functions """ 7 | 8 | import cv2 9 | import numpy as np 10 | import subprocess 11 | 12 | import torch 13 
| import torchvision.transforms as transforms 14 | 15 | 16 | class NormalizeInverse(transforms.Normalize): 17 | # Undo normalization on images 18 | 19 | def __init__(self, mean, std): 20 | mean = torch.as_tensor(mean) 21 | std = torch.as_tensor(std) 22 | std_inv = 1 / (std + 1e-7) 23 | mean_inv = -mean * std_inv 24 | super(NormalizeInverse, self).__init__(mean=mean_inv, std=std_inv) 25 | 26 | def __call__(self, tensor): 27 | return super(NormalizeInverse, self).__call__(tensor.clone()) 28 | 29 | 30 | def create_folder(folder_name): 31 | try: 32 | subprocess.call(['mkdir', '-p', folder_name]) 33 | except OSError: 34 | None 35 | 36 | 37 | def save_saliency_map(image, saliency_map, filename): 38 | """ 39 | Save saliency map on image. 40 | 41 | Args: 42 | image: Tensor of size (3,H,W) 43 | saliency_map: Tensor of size (1,H,W) 44 | filename: string with complete path and file extension 45 | 46 | """ 47 | 48 | image = image.data.cpu().numpy() 49 | saliency_map = saliency_map.data.cpu().numpy() 50 | 51 | saliency_map = saliency_map - saliency_map.min() 52 | saliency_map = saliency_map / saliency_map.max() 53 | saliency_map = saliency_map.clip(0, 1) 54 | 55 | saliency_map = np.uint8(saliency_map * 255).transpose(1, 2, 0) 56 | saliency_map = cv2.resize(saliency_map, (224, 224)) 57 | 58 | image = np.uint8(image * 255).transpose(1, 2, 0) 59 | image = cv2.resize(image, (224, 224)) 60 | 61 | # Apply JET colormap 62 | color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET) 63 | 64 | # Combine image with heatmap 65 | img_with_heatmap = np.float32(color_heatmap) + np.float32(image) 66 | img_with_heatmap = img_with_heatmap / np.max(img_with_heatmap) 67 | 68 | cv2.imwrite(filename, np.uint8(255 * img_with_heatmap)) 69 | -------------------------------------------------------------------------------- /baselines/ViT/pertubation_eval_from_hdf5.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import os 4 | from tqdm import tqdm 5 | import numpy as np 6 | import argparse 7 | 8 | # Import saliency methods and models 9 | from ViT_explanation_generator import Baselines 10 | from ViT_new import vit_base_patch16_224 11 | # from models.vgg import vgg19 12 | import glob 13 | 14 | from dataset.expl_hdf5 import ImagenetResults 15 | 16 | 17 | def normalize(tensor, 18 | mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 19 | dtype = tensor.dtype 20 | mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) 21 | std = torch.as_tensor(std, dtype=dtype, device=tensor.device) 22 | tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 23 | return tensor 24 | 25 | 26 | def eval(args): 27 | num_samples = 0 28 | num_correct_model = np.zeros((len(imagenet_ds,))) 29 | dissimilarity_model = np.zeros((len(imagenet_ds,))) 30 | model_index = 0 31 | 32 | if args.scale == 'per': 33 | base_size = 224 * 224 34 | perturbation_steps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 35 | elif args.scale == '100': 36 | base_size = 100 37 | perturbation_steps = [5, 10, 15, 20, 25, 30, 35, 40, 45] 38 | else: 39 | raise Exception('scale not valid') 40 | 41 | num_correct_pertub = np.zeros((9, len(imagenet_ds))) 42 | dissimilarity_pertub = np.zeros((9, len(imagenet_ds))) 43 | logit_diff_pertub = np.zeros((9, len(imagenet_ds))) 44 | prob_diff_pertub = np.zeros((9, len(imagenet_ds))) 45 | perturb_index = 0 46 | 47 | for batch_idx, (data, vis, target) in enumerate(tqdm(sample_loader)): 48 | # Update the number of samples 49 | num_samples += len(data) 50 | 51 | data = 
data.to(device) 52 | vis = vis.to(device) 53 | target = target.to(device) 54 | norm_data = normalize(data.clone()) 55 | 56 | # Compute model accuracy 57 | pred = model(norm_data) 58 | pred_probabilities = torch.softmax(pred, dim=1) 59 | pred_org_logit = pred.data.max(1, keepdim=True)[0].squeeze(1) 60 | pred_org_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1) 61 | pred_class = pred.data.max(1, keepdim=True)[1].squeeze(1) 62 | tgt_pred = (target == pred_class).type(target.type()).data.cpu().numpy() 63 | num_correct_model[model_index:model_index+len(tgt_pred)] = tgt_pred 64 | 65 | probs = torch.softmax(pred, dim=1) 66 | target_probs = torch.gather(probs, 1, target[:, None])[:, 0] 67 | second_probs = probs.data.topk(2, dim=1)[0][:, 1] 68 | temp = torch.log(target_probs / second_probs).data.cpu().numpy() 69 | dissimilarity_model[model_index:model_index+len(temp)] = temp 70 | 71 | if args.wrong: 72 | wid = np.argwhere(tgt_pred == 0).flatten() 73 | if len(wid) == 0: 74 | continue 75 | wid = torch.from_numpy(wid).to(vis.device) 76 | vis = vis.index_select(0, wid) 77 | data = data.index_select(0, wid) 78 | target = target.index_select(0, wid) 79 | 80 | # Save original shape 81 | org_shape = data.shape 82 | 83 | if args.neg: 84 | vis = -vis 85 | 86 | vis = vis.reshape(org_shape[0], -1) 87 | 88 | for i in range(len(perturbation_steps)): 89 | _data = data.clone() 90 | 91 | _, idx = torch.topk(vis, int(base_size * perturbation_steps[i]), dim=-1) 92 | idx = idx.unsqueeze(1).repeat(1, org_shape[1], 1) 93 | _data = _data.reshape(org_shape[0], org_shape[1], -1) 94 | _data = _data.scatter_(-1, idx, 0) 95 | _data = _data.reshape(*org_shape) 96 | 97 | _norm_data = normalize(_data) 98 | 99 | out = model(_norm_data) 100 | 101 | pred_probabilities = torch.softmax(out, dim=1) 102 | pred_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1) 103 | diff = (pred_prob - pred_org_prob).data.cpu().numpy() 104 | prob_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff 105 | 106 | pred_logit = out.data.max(1, keepdim=True)[0].squeeze(1) 107 | diff = (pred_logit - pred_org_logit).data.cpu().numpy() 108 | logit_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff 109 | 110 | target_class = out.data.max(1, keepdim=True)[1].squeeze(1) 111 | temp = (target == target_class).type(target.type()).data.cpu().numpy() 112 | num_correct_pertub[i, perturb_index:perturb_index+len(temp)] = temp 113 | 114 | probs_pertub = torch.softmax(out, dim=1) 115 | target_probs = torch.gather(probs_pertub, 1, target[:, None])[:, 0] 116 | second_probs = probs_pertub.data.topk(2, dim=1)[0][:, 1] 117 | temp = torch.log(target_probs / second_probs).data.cpu().numpy() 118 | dissimilarity_pertub[i, perturb_index:perturb_index+len(temp)] = temp 119 | 120 | model_index += len(target) 121 | perturb_index += len(target) 122 | 123 | np.save(os.path.join(args.experiment_dir, 'model_hits.npy'), num_correct_model) 124 | np.save(os.path.join(args.experiment_dir, 'model_dissimilarities.npy'), dissimilarity_model) 125 | np.save(os.path.join(args.experiment_dir, 'perturbations_hits.npy'), num_correct_pertub[:, :perturb_index]) 126 | np.save(os.path.join(args.experiment_dir, 'perturbations_dissimilarities.npy'), dissimilarity_pertub[:, :perturb_index]) 127 | np.save(os.path.join(args.experiment_dir, 'perturbations_logit_diff.npy'), logit_diff_pertub[:, :perturb_index]) 128 | np.save(os.path.join(args.experiment_dir, 'perturbations_prob_diff.npy'), prob_diff_pertub[:, :perturb_index]) 129 | 130 | 
print(np.mean(num_correct_model), np.std(num_correct_model)) 131 | print(np.mean(dissimilarity_model), np.std(dissimilarity_model)) 132 | print(perturbation_steps) 133 | print(np.mean(num_correct_pertub, axis=1), np.std(num_correct_pertub, axis=1)) 134 | print(np.mean(dissimilarity_pertub, axis=1), np.std(dissimilarity_pertub, axis=1)) 135 | 136 | 137 | if __name__ == "__main__": 138 | parser = argparse.ArgumentParser(description='Train a segmentation') 139 | parser.add_argument('--batch-size', type=int, 140 | default=16, 141 | help='') 142 | parser.add_argument('--neg', type=bool, 143 | default=True, 144 | help='') 145 | parser.add_argument('--value', action='store_true', 146 | default=False, 147 | help='') 148 | parser.add_argument('--scale', type=str, 149 | default='per', 150 | choices=['per', '100'], 151 | help='') 152 | parser.add_argument('--method', type=str, 153 | default='grad_rollout', 154 | choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'v_gradcam', 'lrp_last_layer', 155 | 'lrp_second_layer', 'gradcam', 156 | 'attn_last_layer', 'attn_gradcam', 'input_grads'], 157 | help='') 158 | parser.add_argument('--vis-class', type=str, 159 | default='top', 160 | choices=['top', 'target', 'index'], 161 | help='') 162 | parser.add_argument('--wrong', action='store_true', 163 | default=False, 164 | help='') 165 | parser.add_argument('--class-id', type=int, 166 | default=0, 167 | help='') 168 | parser.add_argument('--is-ablation', type=bool, 169 | default=False, 170 | help='') 171 | args = parser.parse_args() 172 | 173 | torch.multiprocessing.set_start_method('spawn') 174 | 175 | # PATH variables 176 | PATH = os.path.dirname(os.path.abspath(__file__)) + '/' 177 | dataset = PATH + 'dataset/' 178 | os.makedirs(os.path.join(PATH, 'experiments'), exist_ok=True) 179 | os.makedirs(os.path.join(PATH, 'experiments/perturbations'), exist_ok=True) 180 | 181 | exp_name = args.method 182 | exp_name += '_neg' if args.neg else '_pos' 183 | print(exp_name) 184 | 185 | if args.vis_class == 'index': 186 | args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}_{}'.format(exp_name, 187 | args.vis_class, 188 | args.class_id)) 189 | else: 190 | ablation_fold = 'ablation' if args.is_ablation else 'not_ablation' 191 | args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}/{}'.format(exp_name, 192 | args.vis_class, ablation_fold)) 193 | # args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}'.format(exp_name, 194 | # args.vis_class)) 195 | 196 | if args.wrong: 197 | args.runs_dir += '_wrong' 198 | 199 | experiments = sorted(glob.glob(os.path.join(args.runs_dir, 'experiment_*'))) 200 | experiment_id = int(experiments[-1].split('_')[-1]) + 1 if experiments else 0 201 | args.experiment_dir = os.path.join(args.runs_dir, 'experiment_{}'.format(str(experiment_id))) 202 | os.makedirs(args.experiment_dir, exist_ok=True) 203 | 204 | cuda = torch.cuda.is_available() 205 | device = torch.device("cuda" if cuda else "cpu") 206 | 207 | if args.vis_class == 'index': 208 | vis_method_dir = os.path.join(PATH,'visualizations/{}/{}_{}'.format(args.method, 209 | args.vis_class, 210 | args.class_id)) 211 | else: 212 | ablation_fold = 'ablation' if args.is_ablation else 'not_ablation' 213 | vis_method_dir = os.path.join(PATH,'visualizations/{}/{}/{}'.format(args.method, 214 | args.vis_class, ablation_fold)) 215 | # vis_method_dir = os.path.join(PATH, 'visualizations/{}/{}'.format(args.method, 216 | # args.vis_class)) 217 | 218 | # imagenet_ds = 
ImagenetResults('visualizations/{}'.format(args.method)) 219 | imagenet_ds = ImagenetResults(vis_method_dir) 220 | 221 | # Model 222 | model = vit_base_patch16_224(pretrained=True).cuda() 223 | model.eval() 224 | 225 | save_path = PATH + 'results/' 226 | 227 | sample_loader = torch.utils.data.DataLoader( 228 | imagenet_ds, 229 | batch_size=args.batch_size, 230 | num_workers=2, 231 | shuffle=False) 232 | 233 | eval(args) 234 | -------------------------------------------------------------------------------- /baselines/ViT/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | 6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 9 | def norm_cdf(x): 10 | # Computes standard normal cumulative distribution function 11 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 12 | 13 | if (mean < a - 2 * std) or (mean > b + 2 * std): 14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 15 | "The distribution of values may be incorrect.", 16 | stacklevel=2) 17 | 18 | with torch.no_grad(): 19 | # Values are generated by using a truncated uniform distribution and 20 | # then using the inverse CDF for the normal distribution. 21 | # Get upper and lower cdf values 22 | l = norm_cdf((a - mean) / std) 23 | u = norm_cdf((b - mean) / std) 24 | 25 | # Uniformly fill tensor with values from [l, u], then translate to 26 | # [2l-1, 2u-1]. 27 | tensor.uniform_(2 * l - 1, 2 * u - 1) 28 | 29 | # Use inverse cdf transform for normal distribution to get truncated 30 | # standard normal 31 | tensor.erfinv_() 32 | 33 | # Transform to proper mean, std 34 | tensor.mul_(std * math.sqrt(2.)) 35 | tensor.add_(mean) 36 | 37 | # Clamp to ensure it's in the proper range 38 | tensor.clamp_(min=a, max=b) 39 | return tensor 40 | 41 | 42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 43 | # type: (Tensor, float, float, float, float) -> Tensor 44 | r"""Fills the input Tensor with values drawn from a truncated 45 | normal distribution. The values are effectively drawn from the 46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 47 | with values outside :math:`[a, b]` redrawn until they are within 48 | the bounds. The method used for generating the random values works 49 | best when :math:`a \leq \text{mean} \leq b`. 
50 | Args: 51 | tensor: an n-dimensional `torch.Tensor` 52 | mean: the mean of the normal distribution 53 | std: the standard deviation of the normal distribution 54 | a: the minimum cutoff value 55 | b: the maximum cutoff value 56 | Examples: 57 | >>> w = torch.empty(3, 5) 58 | >>> nn.init.trunc_normal_(w) 59 | """ 60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) -------------------------------------------------------------------------------- /data/Imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | import numpy as np 5 | import cv2 6 | 7 | from torchvision.datasets import ImageNet 8 | 9 | from PIL import Image, ImageFilter 10 | import h5py 11 | from glob import glob 12 | 13 | 14 | class ImageNet_blur(ImageNet): 15 | def __getitem__(self, index): 16 | """ 17 | Args: 18 | index (int): Index 19 | 20 | Returns: 21 | tuple: (sample, target) where target is class_index of the target class. 22 | """ 23 | path, target = self.samples[index] 24 | sample = self.loader(path) 25 | 26 | gauss_blur = ImageFilter.GaussianBlur(11) 27 | median_blur = ImageFilter.MedianFilter(11) 28 | 29 | blurred_img1 = sample.filter(gauss_blur) 30 | blurred_img2 = sample.filter(median_blur) 31 | blurred_img = Image.blend(blurred_img1, blurred_img2, 0.5) 32 | 33 | if self.transform is not None: 34 | sample = self.transform(sample) 35 | blurred_img = self.transform(blurred_img) 36 | if self.target_transform is not None: 37 | target = self.target_transform(target) 38 | 39 | return (sample, blurred_img), target 40 | 41 | 42 | class Imagenet_Segmentation(data.Dataset): 43 | CLASSES = 2 44 | 45 | def __init__(self, 46 | path, 47 | transform=None, 48 | target_transform=None): 49 | self.path = path 50 | self.transform = transform 51 | self.target_transform = target_transform 52 | # self.h5py = h5py.File(path, 'r+') 53 | self.h5py = None 54 | tmp = h5py.File(path, 'r') 55 | self.data_length = len(tmp['/value/img']) 56 | tmp.close() 57 | del tmp 58 | 59 | def __getitem__(self, index): 60 | 61 | if self.h5py is None: 62 | self.h5py = h5py.File(self.path, 'r') 63 | 64 | img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0)) 65 | target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0)) 66 | 67 | img = Image.fromarray(img).convert('RGB') 68 | target = Image.fromarray(target) 69 | 70 | if self.transform is not None: 71 | img = self.transform(img) 72 | 73 | if self.target_transform is not None: 74 | target = np.array(self.target_transform(target)).astype('int32') 75 | target = torch.from_numpy(target).long() 76 | 77 | return img, target 78 | 79 | def __len__(self): 80 | # return len(self.h5py['/value/img']) 81 | return self.data_length 82 | 83 | 84 | class Imagenet_Segmentation_Blur(data.Dataset): 85 | CLASSES = 2 86 | 87 | def __init__(self, 88 | path, 89 | transform=None, 90 | target_transform=None): 91 | self.path = path 92 | self.transform = transform 93 | self.target_transform = target_transform 94 | # self.h5py = h5py.File(path, 'r+') 95 | self.h5py = None 96 | tmp = h5py.File(path, 'r') 97 | self.data_length = len(tmp['/value/img']) 98 | tmp.close() 99 | del tmp 100 | 101 | def __getitem__(self, index): 102 | 103 | if self.h5py is None: 104 | self.h5py = h5py.File(self.path, 'r') 105 | 106 | img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0)) 107 | target = 
np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0)) 108 | 109 | img = Image.fromarray(img).convert('RGB') 110 | target = Image.fromarray(target) 111 | 112 | gauss_blur = ImageFilter.GaussianBlur(11) 113 | median_blur = ImageFilter.MedianFilter(11) 114 | 115 | blurred_img1 = img.filter(gauss_blur) 116 | blurred_img2 = img.filter(median_blur) 117 | blurred_img = Image.blend(blurred_img1, blurred_img2, 0.5) 118 | 119 | # blurred_img1 = cv2.GaussianBlur(img, (11, 11), 5) 120 | # blurred_img2 = np.float32(cv2.medianBlur(img, 11)) 121 | # blurred_img = (blurred_img1 + blurred_img2) / 2 122 | 123 | if self.transform is not None: 124 | img = self.transform(img) 125 | blurred_img = self.transform(blurred_img) 126 | 127 | if self.target_transform is not None: 128 | target = np.array(self.target_transform(target)).astype('int32') 129 | target = torch.from_numpy(target).long() 130 | 131 | return (img, blurred_img), target 132 | 133 | def __len__(self): 134 | # return len(self.h5py['/value/img']) 135 | return self.data_length 136 | 137 | 138 | class Imagenet_Segmentation_eval_dir(data.Dataset): 139 | CLASSES = 2 140 | 141 | def __init__(self, 142 | path, 143 | eval_path, 144 | transform=None, 145 | target_transform=None): 146 | self.transform = transform 147 | self.target_transform = target_transform 148 | self.h5py = h5py.File(path, 'r+') 149 | 150 | # 500 each file 151 | self.results = glob(os.path.join(eval_path, '*.npy')) 152 | 153 | def __getitem__(self, index): 154 | 155 | img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0)) 156 | target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0)) 157 | res = np.load(self.results[index]) 158 | 159 | img = Image.fromarray(img).convert('RGB') 160 | target = Image.fromarray(target) 161 | 162 | if self.transform is not None: 163 | img = self.transform(img) 164 | 165 | if self.target_transform is not None: 166 | target = np.array(self.target_transform(target)).astype('int32') 167 | target = torch.from_numpy(target).long() 168 | 169 | return img, target 170 | 171 | def __len__(self): 172 | return len(self.h5py['/value/img']) 173 | 174 | 175 | if __name__ == '__main__': 176 | import torchvision.transforms as transforms 177 | from tqdm import tqdm 178 | from imageio import imsave 179 | import scipy.io as sio 180 | 181 | # meta = sio.loadmat('/home/shirgur/ext/Data/Datasets/temp/ILSVRC2012_devkit_t12/data/meta.mat', squeeze_me=True)['synsets'] 182 | 183 | # Data 184 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 185 | std=[0.229, 0.224, 0.225]) 186 | test_img_trans = transforms.Compose([ 187 | transforms.Resize((224, 224)), 188 | transforms.ToTensor(), 189 | normalize, 190 | ]) 191 | test_lbl_trans = transforms.Compose([ 192 | transforms.Resize((224, 224), Image.NEAREST), 193 | ]) 194 | 195 | ds = Imagenet_Segmentation('/home/shirgur/ext/Data/Datasets/imagenet-seg/other/gtsegs_ijcv.mat', 196 | transform=test_img_trans, target_transform=test_lbl_trans) 197 | 198 | for i, (img, tgt) in enumerate(tqdm(ds)): 199 | tgt = (tgt.numpy() * 255).astype(np.uint8) 200 | imsave('/home/shirgur/ext/Code/C2S/run/imagenet/gt/{}.png'.format(i), tgt) 201 | 202 | print('here') 203 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/data/__init__.py -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | import numpy as np 5 | 6 | from PIL import Image 7 | import h5py 8 | 9 | __all__ = ['ImagenetResults'] 10 | 11 | 12 | class Imagenet_Segmentation(data.Dataset): 13 | CLASSES = 2 14 | 15 | def __init__(self, 16 | path, 17 | transform=None, 18 | target_transform=None): 19 | self.path = path 20 | self.transform = transform 21 | self.target_transform = target_transform 22 | self.h5py = None 23 | tmp = h5py.File(path, 'r') 24 | self.data_length = len(tmp['/value/img']) 25 | tmp.close() 26 | del tmp 27 | 28 | def __getitem__(self, index): 29 | 30 | if self.h5py is None: 31 | self.h5py = h5py.File(self.path, 'r') 32 | 33 | img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0)) 34 | target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0)) 35 | 36 | img = Image.fromarray(img).convert('RGB') 37 | target = Image.fromarray(target) 38 | 39 | if self.transform is not None: 40 | img = self.transform(img) 41 | 42 | if self.target_transform is not None: 43 | target = np.array(self.target_transform(target)).astype('int32') 44 | target = torch.from_numpy(target).long() 45 | 46 | return img, target 47 | 48 | def __len__(self): 49 | return self.data_length 50 | 51 | 52 | class ImagenetResults(data.Dataset): 53 | def __init__(self, path): 54 | super(ImagenetResults, self).__init__() 55 | 56 | self.path = os.path.join(path, 'results.hdf5') 57 | self.data = None 58 | 59 | print('Reading dataset length...') 60 | with h5py.File(self.path, 'r') as f: 61 | self.data_length = len(f['/image']) 62 | 63 | def __len__(self): 64 | return self.data_length 65 | 66 | def __getitem__(self, item): 67 | if self.data is None: 68 | self.data = h5py.File(self.path, 'r') 69 | 70 | image = torch.tensor(self.data['image'][item]) 71 | vis = torch.tensor(self.data['vis'][item]) 72 | target = torch.tensor(self.data['target'][item]).long() 73 | 74 | return image, vis, target 75 | -------------------------------------------------------------------------------- /dataset/expl_hdf5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import h5py 4 | import os 5 | 6 | 7 | class ImagenetResults(Dataset): 8 | def __init__(self, path): 9 | super(ImagenetResults, self).__init__() 10 | 11 | self.path = os.path.join(path, 'results.hdf5') 12 | self.data = None 13 | 14 | print('Reading dataset length...') 15 | with h5py.File(self.path , 'r') as f: 16 | # tmp = h5py.File(self.path , 'r') 17 | self.data_length = len(f['/image']) 18 | 19 | def __len__(self): 20 | return self.data_length 21 | 22 | def __getitem__(self, item): 23 | if self.data is None: 24 | self.data = h5py.File(self.path, 'r') 25 | 26 | image = torch.tensor(self.data['image'][item]) 27 | vis = torch.tensor(self.data['vis'][item]) 28 | target = torch.tensor(self.data['target'][item]).long() 29 | 30 | return image, vis, target 31 | 32 | 33 | if __name__ == '__main__': 34 | from utils import render 35 | import imageio 36 | import numpy as np 37 | 38 | ds = ImagenetResults('../visualizations/fullgrad') 39 | sample_loader = torch.utils.data.DataLoader( 40 | 
ds, 41 | batch_size=5, 42 | shuffle=False) 43 | 44 | iterator = iter(sample_loader) 45 | image, vis, target = next(iterator) 46 | 47 | maps = (render.hm_to_rgb(vis[0].data.cpu().numpy(), scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8) 48 | 49 | # imageio.imsave('../delete_hm.jpg', maps) 50 | 51 | print(len(ds)) -------------------------------------------------------------------------------- /example.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/example.PNG -------------------------------------------------------------------------------- /method-page-001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/method-page-001.jpg -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/modules/__init__.py -------------------------------------------------------------------------------- /modules/layers_lrp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d', 6 | 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect', 7 | 'LayerNorm', 'AddEye'] 8 | 9 | 10 | def safe_divide(a, b): 11 | den = b.clamp(min=1e-9) + b.clamp(max=1e-9) 12 | den = den + den.eq(0).type(den.type()) * 1e-9 13 | return a / den * b.ne(0).type(b.type()) 14 | 15 | 16 | def forward_hook(self, input, output): 17 | if type(input[0]) in (list, tuple): 18 | self.X = [] 19 | for i in input[0]: 20 | x = i.detach() 21 | x.requires_grad = True 22 | self.X.append(x) 23 | else: 24 | self.X = input[0].detach() 25 | self.X.requires_grad = True 26 | 27 | self.Y = output 28 | 29 | 30 | def backward_hook(self, grad_input, grad_output): 31 | self.grad_input = grad_input 32 | self.grad_output = grad_output 33 | 34 | 35 | class RelProp(nn.Module): 36 | def __init__(self): 37 | super(RelProp, self).__init__() 38 | # if not self.training: 39 | self.register_forward_hook(forward_hook) 40 | 41 | def gradprop(self, Z, X, S): 42 | C = torch.autograd.grad(Z, X, S, retain_graph=True) 43 | return C 44 | 45 | def relprop(self, R, alpha): 46 | return R 47 | 48 | 49 | class RelPropSimple(RelProp): 50 | def relprop(self, R, alpha): 51 | Z = self.forward(self.X) 52 | S = safe_divide(R, Z) 53 | C = self.gradprop(Z, self.X, S) 54 | 55 | if torch.is_tensor(self.X) == False: 56 | outputs = [] 57 | outputs.append(self.X[0] * C[0]) 58 | outputs.append(self.X[1] * C[1]) 59 | else: 60 | outputs = self.X * (C[0]) 61 | return outputs 62 | 63 | class AddEye(RelPropSimple): 64 | # input of shape B, C, seq_len, seq_len 65 | def forward(self, input): 66 | return input + torch.eye(input.shape[2]).expand_as(input).to(input.device) 67 | 68 | class ReLU(nn.ReLU, RelProp): 69 | pass 70 | 71 | class GELU(nn.GELU, RelProp): 72 | pass 73 | 74 | class Softmax(nn.Softmax, RelProp): 75 | pass 76 | 77 | class LayerNorm(nn.LayerNorm, RelProp): 78 | pass 79 | 80 | 
class Dropout(nn.Dropout, RelProp): 81 | pass 82 | 83 | 84 | class MaxPool2d(nn.MaxPool2d, RelPropSimple): 85 | pass 86 | 87 | class LayerNorm(nn.LayerNorm, RelProp): 88 | pass 89 | 90 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple): 91 | pass 92 | 93 | 94 | class AvgPool2d(nn.AvgPool2d, RelPropSimple): 95 | pass 96 | 97 | 98 | class Add(RelPropSimple): 99 | def forward(self, inputs): 100 | return torch.add(*inputs) 101 | 102 | class einsum(RelPropSimple): 103 | def __init__(self, equation): 104 | super().__init__() 105 | self.equation = equation 106 | def forward(self, *operands): 107 | return torch.einsum(self.equation, *operands) 108 | 109 | class IndexSelect(RelProp): 110 | def forward(self, inputs, dim, indices): 111 | self.__setattr__('dim', dim) 112 | self.__setattr__('indices', indices) 113 | 114 | return torch.index_select(inputs, dim, indices) 115 | 116 | def relprop(self, R, alpha): 117 | Z = self.forward(self.X, self.dim, self.indices) 118 | S = safe_divide(R, Z) 119 | C = self.gradprop(Z, self.X, S) 120 | 121 | if torch.is_tensor(self.X) == False: 122 | outputs = [] 123 | outputs.append(self.X[0] * C[0]) 124 | outputs.append(self.X[1] * C[1]) 125 | else: 126 | outputs = self.X * (C[0]) 127 | return outputs 128 | 129 | 130 | 131 | class Clone(RelProp): 132 | def forward(self, input, num): 133 | self.__setattr__('num', num) 134 | outputs = [] 135 | for _ in range(num): 136 | outputs.append(input) 137 | 138 | return outputs 139 | 140 | def relprop(self, R, alpha): 141 | Z = [] 142 | for _ in range(self.num): 143 | Z.append(self.X) 144 | S = [safe_divide(r, z) for r, z in zip(R, Z)] 145 | C = self.gradprop(Z, self.X, S)[0] 146 | 147 | R = self.X * C 148 | 149 | return R 150 | 151 | class Cat(RelProp): 152 | def forward(self, inputs, dim): 153 | self.__setattr__('dim', dim) 154 | return torch.cat(inputs, dim) 155 | 156 | def relprop(self, R, alpha): 157 | Z = self.forward(self.X, self.dim) 158 | S = safe_divide(R, Z) 159 | C = self.gradprop(Z, self.X, S) 160 | 161 | outputs = [] 162 | for x, c in zip(self.X, C): 163 | outputs.append(x * c) 164 | 165 | return outputs 166 | 167 | 168 | class Sequential(nn.Sequential): 169 | def relprop(self, R, alpha): 170 | for m in reversed(self._modules.values()): 171 | R = m.relprop(R, alpha) 172 | return R 173 | 174 | 175 | class BatchNorm2d(nn.BatchNorm2d, RelProp): 176 | def relprop(self, R, alpha): 177 | X = self.X 178 | beta = 1 - alpha 179 | weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / ( 180 | (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5)) 181 | Z = X * weight + 1e-9 182 | S = R / Z 183 | Ca = S * weight 184 | R = self.X * (Ca) 185 | return R 186 | 187 | 188 | class Linear(nn.Linear, RelProp): 189 | def relprop(self, R, alpha): 190 | beta = alpha - 1 191 | pw = torch.clamp(self.weight, min=0) 192 | nw = torch.clamp(self.weight, max=0) 193 | px = torch.clamp(self.X, min=0) 194 | nx = torch.clamp(self.X, max=0) 195 | 196 | def f(w1, w2, x1, x2): 197 | Z1 = F.linear(x1, w1) 198 | Z2 = F.linear(x2, w2) 199 | S1 = safe_divide(R, Z1) 200 | S2 = safe_divide(R, Z2) 201 | C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0] 202 | C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0] 203 | 204 | return C1 + C2 205 | 206 | activator_relevances = f(pw, nw, px, nx) 207 | inhibitor_relevances = f(nw, pw, px, nx) 208 | 209 | R = alpha * activator_relevances - beta * inhibitor_relevances 210 | 211 | return R 212 | 213 | 214 | class Conv2d(nn.Conv2d, RelProp): 215 | def gradprop2(self, DY, weight): 216 | 
Z = self.forward(self.X) 217 | 218 | output_padding = self.X.size()[2] - ( 219 | (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0]) 220 | 221 | return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding) 222 | 223 | def relprop(self, R, alpha): 224 | if self.X.shape[1] == 3: 225 | pw = torch.clamp(self.weight, min=0) 226 | nw = torch.clamp(self.weight, max=0) 227 | X = self.X 228 | L = self.X * 0 + \ 229 | torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 230 | keepdim=True)[0] 231 | H = self.X * 0 + \ 232 | torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 233 | keepdim=True)[0] 234 | Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \ 235 | torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \ 236 | torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9 237 | 238 | S = R / Za 239 | C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw) 240 | R = C 241 | else: 242 | beta = alpha - 1 243 | pw = torch.clamp(self.weight, min=0) 244 | nw = torch.clamp(self.weight, max=0) 245 | px = torch.clamp(self.X, min=0) 246 | nx = torch.clamp(self.X, max=0) 247 | 248 | def f(w1, w2, x1, x2): 249 | Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding) 250 | Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding) 251 | S1 = safe_divide(R, Z1) 252 | S2 = safe_divide(R, Z2) 253 | C1 = x1 * self.gradprop(Z1, x1, S1)[0] 254 | C2 = x2 * self.gradprop(Z2, x2, S2)[0] 255 | return C1 + C2 256 | 257 | activator_relevances = f(pw, nw, px, nx) 258 | inhibitor_relevances = f(nw, pw, px, nx) 259 | 260 | R = alpha * activator_relevances - beta * inhibitor_relevances 261 | return R -------------------------------------------------------------------------------- /modules/layers_ours.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d', 6 | 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect', 7 | 'LayerNorm', 'AddEye'] 8 | 9 | 10 | def safe_divide(a, b): 11 | den = b.clamp(min=1e-9) + b.clamp(max=1e-9) 12 | den = den + den.eq(0).type(den.type()) * 1e-9 13 | return a / den * b.ne(0).type(b.type()) 14 | 15 | 16 | def forward_hook(self, input, output): 17 | if type(input[0]) in (list, tuple): 18 | self.X = [] 19 | for i in input[0]: 20 | x = i.detach() 21 | x.requires_grad = True 22 | self.X.append(x) 23 | else: 24 | self.X = input[0].detach() 25 | self.X.requires_grad = True 26 | 27 | self.Y = output 28 | 29 | 30 | def backward_hook(self, grad_input, grad_output): 31 | self.grad_input = grad_input 32 | self.grad_output = grad_output 33 | 34 | 35 | class RelProp(nn.Module): 36 | def __init__(self): 37 | super(RelProp, self).__init__() 38 | # if not self.training: 39 | self.register_forward_hook(forward_hook) 40 | 41 | def gradprop(self, Z, X, S): 42 | C = torch.autograd.grad(Z, X, S, retain_graph=True) 43 | return C 44 | 45 | def relprop(self, R, alpha): 46 | return R 47 | 48 | class RelPropSimple(RelProp): 49 | def relprop(self, R, alpha): 50 | Z = self.forward(self.X) 51 | S = safe_divide(R, Z) 52 | C = 
self.gradprop(Z, self.X, S) 53 | 54 | if torch.is_tensor(self.X) == False: 55 | outputs = [] 56 | outputs.append(self.X[0] * C[0]) 57 | outputs.append(self.X[1] * C[1]) 58 | else: 59 | outputs = self.X * (C[0]) 60 | return outputs 61 | 62 | class AddEye(RelPropSimple): 63 | # input of shape B, C, seq_len, seq_len 64 | def forward(self, input): 65 | return input + torch.eye(input.shape[2]).expand_as(input).to(input.device) 66 | 67 | class ReLU(nn.ReLU, RelProp): 68 | pass 69 | 70 | class GELU(nn.GELU, RelProp): 71 | pass 72 | 73 | class Softmax(nn.Softmax, RelProp): 74 | pass 75 | 76 | class LayerNorm(nn.LayerNorm, RelProp): 77 | pass 78 | 79 | class Dropout(nn.Dropout, RelProp): 80 | pass 81 | 82 | 83 | class MaxPool2d(nn.MaxPool2d, RelPropSimple): 84 | pass 85 | 86 | class LayerNorm(nn.LayerNorm, RelProp): 87 | pass 88 | 89 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple): 90 | pass 91 | 92 | 93 | class AvgPool2d(nn.AvgPool2d, RelPropSimple): 94 | pass 95 | 96 | 97 | class Add(RelPropSimple): 98 | def forward(self, inputs): 99 | return torch.add(*inputs) 100 | 101 | def relprop(self, R, alpha): 102 | Z = self.forward(self.X) 103 | S = safe_divide(R, Z) 104 | C = self.gradprop(Z, self.X, S) 105 | 106 | a = self.X[0] * C[0] 107 | b = self.X[1] * C[1] 108 | 109 | a_sum = a.sum() 110 | b_sum = b.sum() 111 | 112 | a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() 113 | b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum() 114 | 115 | a = a * safe_divide(a_fact, a.sum()) 116 | b = b * safe_divide(b_fact, b.sum()) 117 | 118 | outputs = [a, b] 119 | 120 | return outputs 121 | 122 | class einsum(RelPropSimple): 123 | def __init__(self, equation): 124 | super().__init__() 125 | self.equation = equation 126 | def forward(self, *operands): 127 | return torch.einsum(self.equation, *operands) 128 | 129 | class IndexSelect(RelProp): 130 | def forward(self, inputs, dim, indices): 131 | self.__setattr__('dim', dim) 132 | self.__setattr__('indices', indices) 133 | 134 | return torch.index_select(inputs, dim, indices) 135 | 136 | def relprop(self, R, alpha): 137 | Z = self.forward(self.X, self.dim, self.indices) 138 | S = safe_divide(R, Z) 139 | C = self.gradprop(Z, self.X, S) 140 | 141 | if torch.is_tensor(self.X) == False: 142 | outputs = [] 143 | outputs.append(self.X[0] * C[0]) 144 | outputs.append(self.X[1] * C[1]) 145 | else: 146 | outputs = self.X * (C[0]) 147 | return outputs 148 | 149 | 150 | 151 | class Clone(RelProp): 152 | def forward(self, input, num): 153 | self.__setattr__('num', num) 154 | outputs = [] 155 | for _ in range(num): 156 | outputs.append(input) 157 | 158 | return outputs 159 | 160 | def relprop(self, R, alpha): 161 | Z = [] 162 | for _ in range(self.num): 163 | Z.append(self.X) 164 | S = [safe_divide(r, z) for r, z in zip(R, Z)] 165 | C = self.gradprop(Z, self.X, S)[0] 166 | 167 | R = self.X * C 168 | 169 | return R 170 | 171 | class Cat(RelProp): 172 | def forward(self, inputs, dim): 173 | self.__setattr__('dim', dim) 174 | return torch.cat(inputs, dim) 175 | 176 | def relprop(self, R, alpha): 177 | Z = self.forward(self.X, self.dim) 178 | S = safe_divide(R, Z) 179 | C = self.gradprop(Z, self.X, S) 180 | 181 | outputs = [] 182 | for x, c in zip(self.X, C): 183 | outputs.append(x * c) 184 | 185 | return outputs 186 | 187 | 188 | class Sequential(nn.Sequential): 189 | def relprop(self, R, alpha): 190 | for m in reversed(self._modules.values()): 191 | R = m.relprop(R, alpha) 192 | return R 193 | 194 | class 
BatchNorm2d(nn.BatchNorm2d, RelProp): 195 | def relprop(self, R, alpha): 196 | X = self.X 197 | beta = 1 - alpha 198 | weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / ( 199 | (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5)) 200 | Z = X * weight + 1e-9 201 | S = R / Z 202 | Ca = S * weight 203 | R = self.X * (Ca) 204 | return R 205 | 206 | 207 | class Linear(nn.Linear, RelProp): 208 | def relprop(self, R, alpha): 209 | beta = alpha - 1 210 | pw = torch.clamp(self.weight, min=0) 211 | nw = torch.clamp(self.weight, max=0) 212 | px = torch.clamp(self.X, min=0) 213 | nx = torch.clamp(self.X, max=0) 214 | 215 | def f(w1, w2, x1, x2): 216 | Z1 = F.linear(x1, w1) 217 | Z2 = F.linear(x2, w2) 218 | S1 = safe_divide(R, Z1 + Z2) 219 | S2 = safe_divide(R, Z1 + Z2) 220 | C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0] 221 | C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0] 222 | 223 | return C1 + C2 224 | 225 | activator_relevances = f(pw, nw, px, nx) 226 | inhibitor_relevances = f(nw, pw, px, nx) 227 | 228 | R = alpha * activator_relevances - beta * inhibitor_relevances 229 | 230 | return R 231 | 232 | 233 | class Conv2d(nn.Conv2d, RelProp): 234 | def gradprop2(self, DY, weight): 235 | Z = self.forward(self.X) 236 | 237 | output_padding = self.X.size()[2] - ( 238 | (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0]) 239 | 240 | return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding) 241 | 242 | def relprop(self, R, alpha): 243 | if self.X.shape[1] == 3: 244 | pw = torch.clamp(self.weight, min=0) 245 | nw = torch.clamp(self.weight, max=0) 246 | X = self.X 247 | L = self.X * 0 + \ 248 | torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 249 | keepdim=True)[0] 250 | H = self.X * 0 + \ 251 | torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 252 | keepdim=True)[0] 253 | Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \ 254 | torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \ 255 | torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9 256 | 257 | S = R / Za 258 | C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw) 259 | R = C 260 | else: 261 | beta = alpha - 1 262 | pw = torch.clamp(self.weight, min=0) 263 | nw = torch.clamp(self.weight, max=0) 264 | px = torch.clamp(self.X, min=0) 265 | nx = torch.clamp(self.X, max=0) 266 | 267 | def f(w1, w2, x1, x2): 268 | Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding) 269 | Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding) 270 | S1 = safe_divide(R, Z1) 271 | S2 = safe_divide(R, Z2) 272 | C1 = x1 * self.gradprop(Z1, x1, S1)[0] 273 | C2 = x2 * self.gradprop(Z2, x2, S2)[0] 274 | return C1 + C2 275 | 276 | activator_relevances = f(pw, nw, px, nx) 277 | inhibitor_relevances = f(nw, pw, px, nx) 278 | 279 | R = alpha * activator_relevances - beta * inhibitor_relevances 280 | return R 281 | -------------------------------------------------------------------------------- /new_work.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/new_work.jpg -------------------------------------------------------------------------------- 
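
The LRP-aware layers defined in modules/layers_ours.py above (and their counterparts in modules/layers_lrp.py) all share the same contract: forward_hook stores each module's input in self.X during the forward pass, and relprop(R, alpha) then maps relevance at the module's output back onto that stored input. The sketch below is illustrative only and not part of the repository: it assumes the repository root is on PYTHONPATH, and the TinyNet module, its layer sizes, and the toy input are invented for the example.

import torch
from modules.layers_ours import Linear, ReLU

class TinyNet(torch.nn.Module):
    # Toy two-layer classifier built from the LRP-aware layers above.
    def __init__(self):
        super().__init__()
        self.fc1 = Linear(16, 8)
        self.act = ReLU()
        self.fc2 = Linear(8, 3)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

    def relprop(self, R, alpha=1):
        # Walk the layers in reverse: each relprop call consumes the relevance
        # of a layer's output and returns the relevance of its input.
        R = self.fc2.relprop(R, alpha)
        R = self.act.relprop(R, alpha)
        R = self.fc1.relprop(R, alpha)
        return R

net = TinyNet()
x = torch.randn(2, 16)
logits = net(x)  # forward_hook records each layer's input in .X

# Start from one-hot relevance at the predicted class and propagate it back.
R = torch.zeros_like(logits)
R.scatter_(1, logits.argmax(dim=1, keepdim=True), 1.0)
input_relevance = net.relprop(R, alpha=1)  # shape (2, 16): one score per input feature
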
/requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow>=8.1.1 2 | einops == 0.3.0 3 | h5py == 2.8.0 4 | imageio == 2.9.0 5 | matplotlib == 3.3.2 6 | opencv_python 7 | scikit_image == 0.17.2 8 | scipy == 1.5.2 9 | sklearn 10 | torch == 1.7.0 11 | torchvision == 0.8.1 12 | tqdm == 4.51.0 13 | transformers == 3.5.1 14 | utils == 1.0.1 15 | Pygments>=2.7.4 16 | -------------------------------------------------------------------------------- /samples/catdog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/catdog.png -------------------------------------------------------------------------------- /samples/dogbird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/dogbird.png -------------------------------------------------------------------------------- /samples/dogcat2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/dogcat2.png -------------------------------------------------------------------------------- /samples/el1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/el1.png -------------------------------------------------------------------------------- /samples/el2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/el2.png -------------------------------------------------------------------------------- /samples/el3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/el3.png -------------------------------------------------------------------------------- /samples/el4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/el4.png -------------------------------------------------------------------------------- /samples/el5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/el5.png -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/utils/__init__.py -------------------------------------------------------------------------------- /utils/confusionmatrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from . 
import metric 4 | 5 | 6 | class ConfusionMatrix(metric.Metric): 7 | """Constructs a confusion matrix for multi-class classification problems. 8 | Does not support multi-label, multi-class problems. 9 | Keyword arguments: 10 | - num_classes (int): number of classes in the classification problem. 11 | - normalized (boolean, optional): Determines whether or not the confusion 12 | matrix is normalized or not. Default: False. 13 | Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py 14 | """ 15 | 16 | def __init__(self, num_classes, normalized=False): 17 | super().__init__() 18 | 19 | self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32) 20 | self.normalized = normalized 21 | self.num_classes = num_classes 22 | self.reset() 23 | 24 | def reset(self): 25 | self.conf.fill(0) 26 | 27 | def add(self, predicted, target): 28 | """Computes the confusion matrix. 29 | The shape of the confusion matrix is K x K, where K is the number 30 | of classes. 31 | Keyword arguments: 32 | - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of 33 | predicted scores obtained from the model for N examples and K classes, 34 | or an N-tensor/array of integer values between 0 and K-1. 35 | - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of 36 | ground-truth classes for N examples and K classes, or an N-tensor/array 37 | of integer values between 0 and K-1. 38 | """ 39 | # If target and/or predicted are tensors, convert them to numpy arrays 40 | if torch.is_tensor(predicted): 41 | predicted = predicted.cpu().numpy() 42 | if torch.is_tensor(target): 43 | target = target.cpu().numpy() 44 | 45 | assert predicted.shape[0] == target.shape[0], \ 46 | 'number of targets and predicted outputs do not match' 47 | 48 | if np.ndim(predicted) != 1: 49 | assert predicted.shape[1] == self.num_classes, \ 50 | 'number of predictions does not match size of confusion matrix' 51 | predicted = np.argmax(predicted, 1) 52 | else: 53 | assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \ 54 | 'predicted values are not between 0 and k-1' 55 | 56 | if np.ndim(target) != 1: 57 | assert target.shape[1] == self.num_classes, \ 58 | 'Onehot target does not match size of confusion matrix' 59 | assert (target >= 0).all() and (target <= 1).all(), \ 60 | 'in one-hot encoding, target values should be 0 or 1' 61 | assert (target.sum(1) == 1).all(), \ 62 | 'multi-label setting is not supported' 63 | target = np.argmax(target, 1) 64 | else: 65 | assert (target.max() < self.num_classes) and (target.min() >= 0), \ 66 | 'target values are not between 0 and k-1' 67 | 68 | # hack for bincounting 2 arrays together 69 | x = predicted + self.num_classes * target 70 | bincount_2d = np.bincount( 71 | x.astype(np.int32), minlength=self.num_classes**2) 72 | assert bincount_2d.size == self.num_classes**2 73 | conf = bincount_2d.reshape((self.num_classes, self.num_classes)) 74 | 75 | self.conf += conf 76 | 77 | def value(self): 78 | """ 79 | Returns: 80 | Confusion matrix of K rows and K columns, where rows correspond 81 | to ground-truth targets and columns correspond to predicted 82 | targets.
83 | """ 84 | if self.normalized: 85 | conf = self.conf.astype(np.float32) 86 | return conf / conf.sum(1).clip(min=1e-12)[:, None] 87 | else: 88 | return self.conf -------------------------------------------------------------------------------- /utils/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from . import metric 4 | from .confusionmatrix import ConfusionMatrix 5 | 6 | 7 | class IoU(metric.Metric): 8 | """Computes the intersection over union (IoU) per class and corresponding 9 | mean (mIoU). 10 | 11 | Intersection over union (IoU) is a common evaluation metric for semantic 12 | segmentation. The predictions are first accumulated in a confusion matrix 13 | and the IoU is computed from it as follows: 14 | 15 | IoU = true_positive / (true_positive + false_positive + false_negative). 16 | 17 | Keyword arguments: 18 | - num_classes (int): number of classes in the classification problem 19 | - normalized (boolean, optional): Determines whether or not the confusion 20 | matrix is normalized or not. Default: False. 21 | - ignore_index (int or iterable, optional): Index of the classes to ignore 22 | when computing the IoU. Can be an int, or any iterable of ints. 23 | """ 24 | 25 | def __init__(self, num_classes, normalized=False, ignore_index=None): 26 | super().__init__() 27 | self.conf_metric = ConfusionMatrix(num_classes, normalized) 28 | 29 | if ignore_index is None: 30 | self.ignore_index = None 31 | elif isinstance(ignore_index, int): 32 | self.ignore_index = (ignore_index,) 33 | else: 34 | try: 35 | self.ignore_index = tuple(ignore_index) 36 | except TypeError: 37 | raise ValueError("'ignore_index' must be an int or iterable") 38 | 39 | def reset(self): 40 | self.conf_metric.reset() 41 | 42 | def add(self, predicted, target): 43 | """Adds the predicted and target pair to the IoU metric. 44 | 45 | Keyword arguments: 46 | - predicted (Tensor): Can be a (N, K, H, W) tensor of 47 | predicted scores obtained from the model for N examples and K classes, 48 | or (N, H, W) tensor of integer values between 0 and K-1. 49 | - target (Tensor): Can be a (N, K, H, W) tensor of 50 | target scores for N examples and K classes, or (N, H, W) tensor of 51 | integer values between 0 and K-1. 52 | 53 | """ 54 | # Dimensions check 55 | assert predicted.size(0) == target.size(0), \ 56 | 'number of targets and predicted outputs do not match' 57 | assert predicted.dim() == 3 or predicted.dim() == 4, \ 58 | "predictions must be of dimension (N, H, W) or (N, K, H, W)" 59 | assert target.dim() == 3 or target.dim() == 4, \ 60 | "targets must be of dimension (N, H, W) or (N, K, H, W)" 61 | 62 | # If the tensor is in categorical format convert it to integer format 63 | if predicted.dim() == 4: 64 | _, predicted = predicted.max(1) 65 | if target.dim() == 4: 66 | _, target = target.max(1) 67 | 68 | self.conf_metric.add(predicted.view(-1), target.view(-1)) 69 | 70 | def value(self): 71 | """Computes the IoU and mean IoU. 72 | 73 | The mean computation ignores NaN elements of the IoU array. 74 | 75 | Returns: 76 | Tuple: (IoU, mIoU). The first output is the per class IoU, 77 | for K classes it's numpy.ndarray with K elements. The second output, 78 | is the mean IoU. 
79 | """ 80 | conf_matrix = self.conf_metric.value() 81 | if self.ignore_index is not None: 82 | for index in self.ignore_index: 83 | conf_matrix[:, self.ignore_index] = 0 84 | conf_matrix[self.ignore_index, :] = 0 85 | true_positive = np.diag(conf_matrix) 86 | false_positive = np.sum(conf_matrix, 0) - true_positive 87 | false_negative = np.sum(conf_matrix, 1) - true_positive 88 | 89 | # Just in case we get a division by 0, ignore/hide the error 90 | with np.errstate(divide='ignore', invalid='ignore'): 91 | iou = true_positive / (true_positive + false_positive + false_negative) 92 | 93 | return iou, np.nanmean(iou) -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | class Metric(object): 2 | """Base class for all metrics. 3 | From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py 4 | """ 5 | def reset(self): 6 | pass 7 | 8 | def add(self): 9 | pass 10 | 11 | def value(self): 12 | pass -------------------------------------------------------------------------------- /utils/metrices.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import f1_score, average_precision_score 4 | from sklearn.metrics import precision_recall_curve, roc_curve 5 | 6 | SMOOTH = 1e-6 7 | __all__ = ['get_f1_scores', 'get_ap_scores', 'batch_pix_accuracy', 'batch_intersection_union', 'get_iou', 'get_pr', 8 | 'get_roc', 'get_ap_multiclass'] 9 | 10 | 11 | def get_iou(outputs: torch.Tensor, labels: torch.Tensor): 12 | # You can comment out this line if you are passing tensors of equal shape 13 | # But if you are passing output from UNet or something it will most probably 14 | # be with the BATCH x 1 x H x W shape 15 | outputs = outputs.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W 16 | labels = labels.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W 17 | 18 | intersection = (outputs & labels).float().sum((1, 2)) # Will be zero if Truth=0 or Prediction=0 19 | union = (outputs | labels).float().sum((1, 2)) # Will be zzero if both are 0 20 | 21 | iou = (intersection + SMOOTH) / (union + SMOOTH) # We smooth our devision to avoid 0/0 22 | 23 | return iou.cpu().numpy() 24 | 25 | 26 | def get_f1_scores(predict, target, ignore_index=-1): 27 | # Tensor process 28 | batch_size = predict.shape[0] 29 | predict = predict.data.cpu().numpy().reshape(-1) 30 | target = target.data.cpu().numpy().reshape(-1) 31 | pb = predict[target != ignore_index].reshape(batch_size, -1) 32 | tb = target[target != ignore_index].reshape(batch_size, -1) 33 | 34 | total = [] 35 | for p, t in zip(pb, tb): 36 | total.append(np.nan_to_num(f1_score(t, p))) 37 | 38 | return total 39 | 40 | 41 | def get_roc(predict, target, ignore_index=-1): 42 | target_expand = target.unsqueeze(1).expand_as(predict) 43 | target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1) 44 | # Tensor process 45 | x = torch.zeros_like(target_expand) 46 | t = target.unsqueeze(1).clamp(min=0) 47 | target_1hot = x.scatter_(1, t, 1) 48 | batch_size = predict.shape[0] 49 | predict = predict.data.cpu().numpy().reshape(-1) 50 | target = target_1hot.data.cpu().numpy().reshape(-1) 51 | pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1) 52 | tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1) 53 | 54 | total = [] 55 | for p, t in zip(pb, tb): 56 | total.append(roc_curve(t, p)) 57 | 58 | return total 59 
| 60 | 61 | def get_pr(predict, target, ignore_index=-1): 62 | target_expand = target.unsqueeze(1).expand_as(predict) 63 | target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1) 64 | # Tensor process 65 | x = torch.zeros_like(target_expand) 66 | t = target.unsqueeze(1).clamp(min=0) 67 | target_1hot = x.scatter_(1, t, 1) 68 | batch_size = predict.shape[0] 69 | predict = predict.data.cpu().numpy().reshape(-1) 70 | target = target_1hot.data.cpu().numpy().reshape(-1) 71 | pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1) 72 | tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1) 73 | 74 | total = [] 75 | for p, t in zip(pb, tb): 76 | total.append(precision_recall_curve(t, p)) 77 | 78 | return total 79 | 80 | 81 | def get_ap_scores(predict, target, ignore_index=-1): 82 | total = [] 83 | for pred, tgt in zip(predict, target): 84 | target_expand = tgt.unsqueeze(0).expand_as(pred) 85 | target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1) 86 | 87 | # Tensor process 88 | x = torch.zeros_like(target_expand) 89 | t = tgt.unsqueeze(0).clamp(min=0).long() 90 | target_1hot = x.scatter_(0, t, 1) 91 | predict_flat = pred.data.cpu().numpy().reshape(-1) 92 | target_flat = target_1hot.data.cpu().numpy().reshape(-1) 93 | 94 | p = predict_flat[target_expand_numpy != ignore_index] 95 | t = target_flat[target_expand_numpy != ignore_index] 96 | 97 | total.append(np.nan_to_num(average_precision_score(t, p))) 98 | 99 | return total 100 | 101 | 102 | def get_ap_multiclass(predict, target): 103 | total = [] 104 | for pred, tgt in zip(predict, target): 105 | predict_flat = pred.data.cpu().numpy().reshape(-1) 106 | target_flat = tgt.data.cpu().numpy().reshape(-1) 107 | 108 | total.append(np.nan_to_num(average_precision_score(target_flat, predict_flat))) 109 | 110 | return total 111 | 112 | 113 | def batch_precision_recall(predict, target, thr=0.5): 114 | """Batch Precision Recall 115 | Args: 116 | predict: input 4D tensor 117 | target: label 4D tensor 118 | """ 119 | # _, predict = torch.max(predict, 1) 120 | 121 | predict = predict > thr 122 | predict = predict.data.cpu().numpy() + 1 123 | target = target.data.cpu().numpy() + 1 124 | 125 | tp = np.sum(((predict == 2) * (target == 2)) * (target > 0)) 126 | fp = np.sum(((predict == 2) * (target == 1)) * (target > 0)) 127 | fn = np.sum(((predict == 1) * (target == 2)) * (target > 0)) 128 | 129 | precision = float(np.nan_to_num(tp / (tp + fp))) 130 | recall = float(np.nan_to_num(tp / (tp + fn))) 131 | 132 | return precision, recall 133 | 134 | 135 | def batch_pix_accuracy(predict, target): 136 | """Batch Pixel Accuracy 137 | Args: 138 | predict: input 3D tensor 139 | target: label 3D tensor 140 | """ 141 | 142 | # for thr in np.linspace(0, 1, slices): 143 | 144 | _, predict = torch.max(predict, 0) 145 | predict = predict.cpu().numpy() + 1 146 | target = target.cpu().numpy() + 1 147 | pixel_labeled = np.sum(target > 0) 148 | pixel_correct = np.sum((predict == target) * (target > 0)) 149 | assert pixel_correct <= pixel_labeled, \ 150 | "Correct area should be smaller than Labeled" 151 | return pixel_correct, pixel_labeled 152 | 153 | 154 | def batch_intersection_union(predict, target, nclass): 155 | """Batch Intersection of Union 156 | Args: 157 | predict: input 3D tensor 158 | target: label 3D tensor 159 | nclass: number of categories (int) 160 | """ 161 | _, predict = torch.max(predict, 0) 162 | mini = 1 163 | maxi = nclass 164 | nbins = nclass 165 | predict = predict.cpu().numpy() + 1 166 | target = 
target.cpu().numpy() + 1 167 | 168 | predict = predict * (target > 0).astype(predict.dtype) 169 | intersection = predict * (predict == target) 170 | # areas of intersection and union 171 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi)) 172 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi)) 173 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi)) 174 | area_union = area_pred + area_lab - area_inter 175 | assert (area_inter <= area_union).all(), \ 176 | "Intersection area should be smaller than Union area" 177 | return area_inter, area_union 178 | 179 | 180 | # ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py 181 | def pixel_accuracy(im_pred, im_lab): 182 | im_pred = np.asarray(im_pred) 183 | im_lab = np.asarray(im_lab) 184 | 185 | # Remove classes from unlabeled pixels in gt image. 186 | # We should not penalize detections in unlabeled portions of the image. 187 | pixel_labeled = np.sum(im_lab > 0) 188 | pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0)) 189 | # pixel_accuracy = 1.0 * pixel_correct / pixel_labeled 190 | return pixel_correct, pixel_labeled 191 | 192 | 193 | def intersection_and_union(im_pred, im_lab, num_class): 194 | im_pred = np.asarray(im_pred) 195 | im_lab = np.asarray(im_lab) 196 | # Remove classes from unlabeled pixels in gt image. 197 | im_pred = im_pred * (im_lab > 0) 198 | # Compute area intersection: 199 | intersection = im_pred * (im_pred == im_lab) 200 | area_inter, _ = np.histogram(intersection, bins=num_class - 1, 201 | range=(1, num_class - 1)) 202 | # Compute area union: 203 | area_pred, _ = np.histogram(im_pred, bins=num_class - 1, 204 | range=(1, num_class - 1)) 205 | area_lab, _ = np.histogram(im_lab, bins=num_class - 1, 206 | range=(1, num_class - 1)) 207 | area_union = area_pred + area_lab - area_inter 208 | return area_inter, area_union 209 | -------------------------------------------------------------------------------- /utils/parallel.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | """Encoding Data Parallel""" 12 | import threading 13 | import functools 14 | import torch 15 | from torch.autograd import Variable, Function 16 | import torch.cuda.comm as comm 17 | from torch.nn.parallel.data_parallel import DataParallel 18 | from torch.nn.parallel.parallel_apply import get_a_var 19 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 20 | 21 | torch_ver = torch.__version__[:3] 22 | 23 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 24 | 'patch_replication_callback'] 25 | 26 | def allreduce(*inputs): 27 | """Cross GPU all reduce autograd operation for calculate mean and 28 | variance in SyncBN. 
29 | """ 30 | return AllReduce.apply(*inputs) 31 | 32 | class AllReduce(Function): 33 | @staticmethod 34 | def forward(ctx, num_inputs, *inputs): 35 | ctx.num_inputs = num_inputs 36 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] 37 | inputs = [inputs[i:i + num_inputs] 38 | for i in range(0, len(inputs), num_inputs)] 39 | # sort before reduce sum 40 | inputs = sorted(inputs, key=lambda i: i[0].get_device()) 41 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 42 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 43 | return tuple([t for tensors in outputs for t in tensors]) 44 | 45 | @staticmethod 46 | def backward(ctx, *inputs): 47 | inputs = [i.data for i in inputs] 48 | inputs = [inputs[i:i + ctx.num_inputs] 49 | for i in range(0, len(inputs), ctx.num_inputs)] 50 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 51 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 52 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) 53 | 54 | 55 | class Reduce(Function): 56 | @staticmethod 57 | def forward(ctx, *inputs): 58 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 59 | inputs = sorted(inputs, key=lambda i: i.get_device()) 60 | return comm.reduce_add(inputs) 61 | 62 | @staticmethod 63 | def backward(ctx, gradOutput): 64 | return Broadcast.apply(ctx.target_gpus, gradOutput) 65 | 66 | 67 | class DataParallelModel(DataParallel): 68 | """Implements data parallelism at the module level. 69 | 70 | This container parallelizes the application of the given module by 71 | splitting the input across the specified devices by chunking in the 72 | batch dimension. 73 | In the forward pass, the module is replicated on each device, 74 | and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module. 75 | Note that the outputs are not gathered, please use compatible 76 | :class:`encoding.parallel.DataParallelCriterion`. 77 | 78 | The batch size should be larger than the number of GPUs used. It should 79 | also be an integer multiple of the number of GPUs so that each chunk is 80 | the same size (so that each GPU processes the same number of samples). 81 | 82 | Args: 83 | module: module to be parallelized 84 | device_ids: CUDA devices (default: all devices) 85 | 86 | Reference: 87 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 88 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 89 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 90 | 91 | Example:: 92 | 93 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 94 | >>> y = net(x) 95 | """ 96 | def gather(self, outputs, output_device): 97 | return outputs 98 | 99 | def replicate(self, module, device_ids): 100 | modules = super(DataParallelModel, self).replicate(module, device_ids) 101 | execute_replication_callbacks(modules) 102 | return modules 103 | 104 | 105 | class DataParallelCriterion(DataParallel): 106 | """ 107 | Calculate loss in multiple-GPUs, which balance the memory usage for 108 | Semantic Segmentation. 109 | 110 | The targets are splitted across the specified devices by chunking in 111 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 112 | 113 | Reference: 114 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 115 | Amit Agrawal. 
“Context Encoding for Semantic Segmentation.” 116 |     *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 117 | 118 |     Example:: 119 | 120 |         >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 121 |         >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 122 |         >>> y = net(x) 123 |         >>> loss = criterion(y, target) 124 |     """ 125 |     def forward(self, inputs, *targets, **kwargs): 126 |         # input should already be scattered 127 |         # scattering the targets instead 128 |         if not self.device_ids: 129 |             return self.module(inputs, *targets, **kwargs) 130 |         targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 131 |         if len(self.device_ids) == 1: 132 |             return self.module(inputs, *targets[0], **kwargs[0]) 133 |         replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 134 |         outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) 135 |         return Reduce.apply(*outputs) / len(outputs) 136 |         #return self.gather(outputs, self.output_device).mean() 137 | 138 | 139 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 140 |     assert len(modules) == len(inputs) 141 |     assert len(targets) == len(inputs) 142 |     if kwargs_tup: 143 |         assert len(modules) == len(kwargs_tup) 144 |     else: 145 |         kwargs_tup = ({},) * len(modules) 146 |     if devices is not None: 147 |         assert len(modules) == len(devices) 148 |     else: 149 |         devices = [None] * len(modules) 150 | 151 |     lock = threading.Lock() 152 |     results = {} 153 |     if torch_ver != "0.3": 154 |         grad_enabled = torch.is_grad_enabled() 155 | 156 |     def _worker(i, module, input, target, kwargs, device=None): 157 |         if torch_ver != "0.3": 158 |             torch.set_grad_enabled(grad_enabled) 159 |         if device is None: 160 |             device = get_a_var(input).get_device() 161 |         try: 162 |             with torch.cuda.device(device): 163 |                 # this also avoids accidental slicing of `input` if it is a Tensor 164 |                 if not isinstance(input, (list, tuple)): 165 |                     input = (input,) 166 |                 if type(input) != type(target): 167 |                     if isinstance(target, tuple): 168 |                         input = tuple(input) 169 |                     elif isinstance(target, list): 170 |                         input = list(input) 171 |                     else: 172 |                         raise Exception("Types problem") 173 | 174 |                 output = module(*(input + target), **kwargs) 175 |             with lock: 176 |                 results[i] = output 177 |         except Exception as e: 178 |             with lock: 179 |                 results[i] = e 180 | 181 |     if len(modules) > 1: 182 |         threads = [threading.Thread(target=_worker, 183 |                                     args=(i, module, input, target, 184 |                                           kwargs, device),) 185 |                    for i, (module, input, target, kwargs, device) in 186 |                    enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 187 | 188 |         for thread in threads: 189 |             thread.start() 190 |         for thread in threads: 191 |             thread.join() 192 |     else: 193 |         _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0]) 194 | 195 |     outputs = [] 196 |     for i in range(len(inputs)): 197 |         output = results[i] 198 |         if isinstance(output, Exception): 199 |             raise output 200 |         outputs.append(output) 201 |     return outputs 202 | 203 | 204 | ########################################################################### 205 | # Adapted from Synchronized-BatchNorm-PyTorch. 206 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 207 | # 208 | class CallbackContext(object): 209 |     pass 210 | 211 | 212 | def execute_replication_callbacks(modules): 213 |     """ 214 |     Execute a replication callback `__data_parallel_replicate__` on each module created 215 |     by original replication.
216 | 217 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 218 | 219 | Note that, as all modules are isomorphism, we assign each sub-module with a context 220 | (shared among multiple copies of this module on different devices). 221 | Through this context, different copies can share some information. 222 | 223 | We guarantee that the callback on the master copy (the first copy) will be called ahead 224 | of calling the callback of any slave copies. 225 | """ 226 | master_copy = modules[0] 227 | nr_modules = len(list(master_copy.modules())) 228 | ctxs = [CallbackContext() for _ in range(nr_modules)] 229 | 230 | for i, module in enumerate(modules): 231 | for j, m in enumerate(module.modules()): 232 | if hasattr(m, '__data_parallel_replicate__'): 233 | m.__data_parallel_replicate__(ctxs[j], i) 234 | 235 | 236 | def patch_replication_callback(data_parallel): 237 | """ 238 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 239 | Useful when you have customized `DataParallel` implementation. 240 | 241 | Examples: 242 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 243 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 244 | > patch_replication_callback(sync_bn) 245 | # this is equivalent to 246 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 247 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 248 | """ 249 | 250 | assert isinstance(data_parallel, DataParallel) 251 | 252 | old_replicate = data_parallel.replicate 253 | 254 | @functools.wraps(old_replicate) 255 | def new_replicate(module, device_ids): 256 | modules = old_replicate(module, device_ids) 257 | execute_replication_callbacks(modules) 258 | return modules 259 | 260 | data_parallel.replicate = new_replicate 261 | -------------------------------------------------------------------------------- /utils/render.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.cm 3 | import skimage.io 4 | import skimage.feature 5 | import skimage.filters 6 | 7 | 8 | def vec2im(V, shape=()): 9 | ''' 10 | Transform an array V into a specified shape - or if no shape is given assume a square output format. 11 | 12 | Parameters 13 | ---------- 14 | 15 | V : numpy.ndarray 16 | an array either representing a matrix or vector to be reshaped into an two-dimensional image 17 | 18 | shape : tuple or list 19 | optional. containing the shape information for the output array if not given, the output is assumed to be square 20 | 21 | Returns 22 | ------- 23 | 24 | W : numpy.ndarray 25 | with W.shape = shape or W.shape = [np.sqrt(V.size)]*2 26 | 27 | ''' 28 | 29 | if len(shape) < 2: 30 | shape = [np.sqrt(V.size)] * 2 31 | shape = map(int, shape) 32 | return np.reshape(V, shape) 33 | 34 | 35 | def enlarge_image(img, scaling=3): 36 | ''' 37 | Enlarges a given input matrix by replicating each pixel value scaling times in horizontal and vertical direction. 
38 | 39 | Parameters 40 | ---------- 41 | 42 | img : numpy.ndarray 43 | array of shape [H x W] OR [H x W x D] 44 | 45 | scaling : int 46 | positive integer value > 0 47 | 48 | Returns 49 | ------- 50 | 51 | out : numpy.ndarray 52 | two-dimensional array of shape [scaling*H x scaling*W] 53 | OR 54 | three-dimensional array of shape [scaling*H x scaling*W x D] 55 | depending on the dimensionality of the input 56 | ''' 57 | 58 | if scaling < 1 or not isinstance(scaling, int): 59 | print('scaling factor needs to be an int >= 1') 60 | 61 | if len(img.shape) == 2: 62 | H, W = img.shape 63 | 64 | out = np.zeros((scaling * H, scaling * W)) 65 | for h in range(H): 66 | fh = scaling * h 67 | for w in range(W): 68 | fw = scaling * w 69 | out[fh:fh + scaling, fw:fw + scaling] = img[h, w] 70 | 71 | elif len(img.shape) == 3: 72 | H, W, D = img.shape 73 | 74 | out = np.zeros((scaling * H, scaling * W, D)) 75 | for h in range(H): 76 | fh = scaling * h 77 | for w in range(W): 78 | fw = scaling * w 79 | out[fh:fh + scaling, fw:fw + scaling, :] = img[h, w, :] 80 | 81 | return out 82 | 83 | 84 | def repaint_corner_pixels(rgbimg, scaling=3): 85 | ''' 86 | DEPRECATED/OBSOLETE. 87 | 88 | Recolors the top left and bottom right pixel (groups) with the average rgb value of its three neighboring pixel (groups). 89 | The recoloring visually masks the opposing pixel values which are a product of stabilizing the scaling. 90 | Assumes those image ares will pretty much never show evidence. 91 | 92 | Parameters 93 | ---------- 94 | 95 | rgbimg : numpy.ndarray 96 | array of shape [H x W x 3] 97 | 98 | scaling : int 99 | positive integer value > 0 100 | 101 | Returns 102 | ------- 103 | 104 | rgbimg : numpy.ndarray 105 | three-dimensional array of shape [scaling*H x scaling*W x 3] 106 | ''' 107 | 108 | # top left corner. 109 | rgbimg[0:scaling, 0:scaling, :] = (rgbimg[0, scaling, :] + rgbimg[scaling, 0, :] + rgbimg[scaling, scaling, 110 | :]) / 3.0 111 | # bottom right corner 112 | rgbimg[-scaling:, -scaling:, :] = (rgbimg[-1, -1 - scaling, :] + rgbimg[-1 - scaling, -1, :] + rgbimg[-1 - scaling, 113 | -1 - scaling, 114 | :]) / 3.0 115 | return rgbimg 116 | 117 | 118 | def digit_to_rgb(X, scaling=3, shape=(), cmap='binary'): 119 | ''' 120 | Takes as input an intensity array and produces a rgb image due to some color map 121 | 122 | Parameters 123 | ---------- 124 | 125 | X : numpy.ndarray 126 | intensity matrix as array of shape [M x N] 127 | 128 | scaling : int 129 | optional. positive integer value > 0 130 | 131 | shape: tuple or list of its , length = 2 132 | optional. if not given, X is reshaped to be square. 133 | 134 | cmap : str 135 | name of color map of choice. default is 'binary' 136 | 137 | Returns 138 | ------- 139 | 140 | image : numpy.ndarray 141 | three-dimensional array of shape [scaling*H x scaling*W x 3] , where H*W == M*N 142 | ''' 143 | 144 | # create color map object from name string 145 | cmap = eval('matplotlib.cm.{}'.format(cmap)) 146 | 147 | image = enlarge_image(vec2im(X, shape), scaling) # enlarge 148 | image = cmap(image.flatten())[..., 0:3].reshape([image.shape[0], image.shape[1], 3]) # colorize, reshape 149 | 150 | return image 151 | 152 | 153 | def hm_to_rgb(R, X=None, scaling=3, shape=(), sigma=2, cmap='bwr', normalize=True): 154 | ''' 155 | Takes as input an intensity array and produces a rgb image for the represented heatmap. 156 | optionally draws the outline of another input on top of it. 
157 | 158 |     Parameters 159 |     ---------- 160 | 161 |     R : numpy.ndarray 162 |         the heatmap to be visualized, shaped [M x N] 163 | 164 |     X : numpy.ndarray 165 |         optional. some input, usually the data point for which the heatmap R is for, which shall serve 166 |         as a template for a black outline to be drawn on top of the image 167 |         shaped [M x N] 168 | 169 |     scaling : int 170 |         factor, on how to enlarge the heatmap (to control resolution and as an inverse way to control outline thickness) 171 |         after reshaping it using shape. 172 | 173 |     shape : tuple or list, length = 2 174 |         optional. if not given, X is reshaped to be square. 175 | 176 |     sigma : double 177 |         optional. sigma-parameter for the canny algorithm used for edge detection. the found edges are drawn as outlines. 178 | 179 |     cmap : str 180 |         optional. color map of choice 181 | 182 |     normalize : bool 183 |         optional. whether to normalize the heatmap to [-1, 1] prior to colorization. 184 | 185 |     Returns 186 |     ------- 187 | 188 |     rgbimg : numpy.ndarray 189 |         three-dimensional array of shape [scaling*H x scaling*W x 3] , where H*W == M*N 190 |     ''' 191 | 192 |     # create color map object from name string 193 |     cmap = eval('matplotlib.cm.{}'.format(cmap)) 194 | 195 |     if normalize: 196 |         R = R / np.max(np.abs(R))  # normalize to [-1,1] wrt the max relevance magnitude 197 |         R = (R + 1.) / 2.  # shift/normalize to [0,1] for color mapping 198 | 199 |     R = enlarge_image(R, scaling) 200 |     rgb = cmap(R.flatten())[..., 0:3].reshape([R.shape[0], R.shape[1], 3]) 201 |     # rgb = repaint_corner_pixels(rgb, scaling) #obsolete due to directly calling the color map with [0,1]-normalized inputs 202 | 203 |     if X is not None:  # compute the outline of the input 204 |         # X = enlarge_image(vec2im(X,shape), scaling) 205 |         xdims = X.shape 206 |         Rdims = R.shape 207 | 208 |         # if not np.all(xdims == Rdims): 209 |         #     print 'transformed heatmap and data dimension mismatch. data dimensions differ?' 210 |         #     print 'R.shape = ',Rdims, 'X.shape = ', xdims 211 |         #     print 'skipping drawing of outline\n' 212 |         # else: 213 |         #     #edges = skimage.filters.canny(X, sigma=sigma) 214 |         #     edges = skimage.feature.canny(X, sigma=sigma) 215 |         #     edges = np.invert(np.dstack([edges]*3))*1.0 216 |         #     rgb *= edges # set outline pixels to black color 217 | 218 |     return rgb 219 | 220 | 221 | def save_image(rgb_images, path, gap=2): 222 |     ''' 223 |     Takes as input a list of rgb images, places them next to each other with a gap and writes out the result. 224 | 225 |     Parameters 226 |     ---------- 227 | 228 |     rgb_images : list, tuple or similar collection 229 |         each item in the collection is expected to be an rgb image of dimensions [H x _ x 3] 230 |         where the width is variable 231 | 232 |     path : str 233 |         the output path of the assembled image 234 | 235 |     gap : int 236 |         optional. sets the width of a black area of pixels realized as an image shaped [H x gap x 3] in between the input images 237 | 238 |     Returns 239 |     ------- 240 | 241 |     image : numpy.ndarray 242 |         the assembled image as written out to path 243 |     ''' 244 | 245 |     sz = [] 246 |     image = [] 247 |     for i in range(len(rgb_images)): 248 |         if not sz: 249 |             sz = rgb_images[i].shape 250 |             image = rgb_images[i] 251 |             gap = np.zeros((sz[0], gap, sz[2])) 252 |             continue 253 |         if not (sz[0] == rgb_images[i].shape[0] and sz[2] == rgb_images[i].shape[2]): 254 |             print('image', i, 'differs in size. 
unable to perform horizontal alignment') 255 | print('expected: Hx_xD = {0}x_x{1}'.format(sz[0], sz[1])) 256 | print('got : Hx_xD = {0}x_x{1}'.format(rgb_images[i].shape[0], rgb_images[i].shape[1])) 257 | print('skipping image\n') 258 | else: 259 | image = np.hstack((image, gap, rgb_images[i])) 260 | 261 | image *= 255 262 | image = image.astype(np.uint8) 263 | 264 | print('saving image to ', path) 265 | skimage.io.imsave(path, image) 266 | return image 267 | -------------------------------------------------------------------------------- /utils/saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | import glob 5 | 6 | 7 | class Saver(object): 8 | 9 | def __init__(self, args): 10 | self.args = args 11 | self.directory = os.path.join('run', args.train_dataset, args.checkname) 12 | self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*'))) 13 | run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0 14 | 15 | self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id))) 16 | if not os.path.exists(self.experiment_dir): 17 | os.makedirs(self.experiment_dir) 18 | 19 | def save_checkpoint(self, state, filename='checkpoint.pth.tar'): 20 | """Saves checkpoint to disk""" 21 | filename = os.path.join(self.experiment_dir, filename) 22 | torch.save(state, filename) 23 | 24 | def save_experiment_config(self): 25 | logfile = os.path.join(self.experiment_dir, 'parameters.txt') 26 | log_file = open(logfile, 'w') 27 | p = OrderedDict() 28 | p['train_dataset'] = self.args.train_dataset 29 | p['lr'] = self.args.lr 30 | p['epoch'] = self.args.epochs 31 | 32 | for key, val in p.items(): 33 | log_file.write(key + ':' + str(val) + '\n') 34 | log_file.close() 35 | -------------------------------------------------------------------------------- /utils/summaries.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.tensorboard import SummaryWriter 3 | 4 | 5 | class TensorboardSummary(object): 6 | def __init__(self, directory): 7 | self.directory = directory 8 | self.writer = SummaryWriter(log_dir=os.path.join(self.directory)) 9 | 10 | def add_scalar(self, *args): 11 | self.writer.add_scalar(*args) --------------------------------------------------------------------------------
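Usage sketch for utils/confusionmatrix.py and utils/iou.py (an illustration, not a repository file): a minimal example on random tensors, assuming a 3-class problem; the shapes follow the docstrings above (N x K scores or N integer labels for ConfusionMatrix, (N, K, H, W) scores or (N, H, W) label maps for IoU).

import torch
from utils.confusionmatrix import ConfusionMatrix
from utils.iou import IoU

num_classes = 3

# Confusion matrix over flat predictions: N x K scores vs. N integer targets.
cm = ConfusionMatrix(num_classes, normalized=False)
scores = torch.randn(8, num_classes)            # N x K predicted scores
labels = torch.randint(0, num_classes, (8,))    # N integer labels in [0, K-1]
cm.add(scores, labels)
print(cm.value())                               # K x K counts, rows = ground truth

# IoU over dense maps: (N, K, H, W) scores vs. (N, H, W) integer label maps.
iou_metric = IoU(num_classes, ignore_index=0)
pred_maps = torch.randn(2, num_classes, 16, 16)
label_maps = torch.randint(0, num_classes, (2, 16, 16))
iou_metric.add(pred_maps, label_maps)
per_class_iou, mean_iou = iou_metric.value()    # NaN entries are ignored in the mean
print(per_class_iou, mean_iou)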
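Usage sketch for utils/metrices.py (again an illustration with made-up inputs): get_iou expects 0/1 masks shaped BATCH x 1 x H x W, as its comments state, while batch_pix_accuracy takes class scores (K, H, W) and an integer label map (H, W).

import torch
from utils.metrices import get_iou, batch_pix_accuracy

# Binary masks, BATCH x 1 x H x W; the & and | operators require boolean or integer tensors.
pred_mask = (torch.rand(4, 1, 32, 32) > 0.5).int()
gt_mask = (torch.rand(4, 1, 32, 32) > 0.5).int()
print(get_iou(pred_mask, gt_mask))     # per-sample IoU, numpy array of length 4

# Pixel accuracy from raw class scores and a ground-truth label map.
scores = torch.randn(3, 32, 32)        # K x H x W
labels = torch.randint(0, 3, (32, 32))
correct, labeled = batch_pix_accuracy(scores, labels)
print(correct / labeled)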
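Usage sketch for utils/render.py: turning a relevance map into an RGB heatmap and writing it to disk. The 14 x 14 map and the output file name are placeholders; in this repository such maps would typically come from the relevance-propagation code under baselines/ViT or BERT_explainability.

import numpy as np
from utils.render import hm_to_rgb, save_image

relevance = np.random.uniform(-1, 1, size=(14, 14))   # placeholder relevance map

# Colorize with the 'bwr' map and enlarge each cell to a 16 x 16 pixel block.
heatmap_rgb = hm_to_rgb(relevance, scaling=16, cmap='bwr', normalize=True)

# Stack one or more RGB panels horizontally (with a black gap) and save.
save_image([heatmap_rgb], 'relevance_heatmap.png', gap=2)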
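Usage sketch for utils/saver.py and utils/summaries.py: Saver only reads train_dataset, checkname, lr and epochs from its args object, so a small namespace suffices here; the dataset name, checkpoint name and model are placeholders.

import argparse
import torch.nn as nn
from utils.saver import Saver
from utils.summaries import TensorboardSummary

args = argparse.Namespace(train_dataset='imagenet', checkname='vit_lrp', lr=1e-3, epochs=30)

saver = Saver(args)                       # creates run/imagenet/vit_lrp/experiment_N
saver.save_experiment_config()            # writes parameters.txt into the experiment dir
saver.save_checkpoint({'state_dict': nn.Linear(10, 2).state_dict(), 'epoch': 0})

summary = TensorboardSummary(saver.experiment_dir)
summary.add_scalar('train/loss', 0.5, 0)  # forwarded to SummaryWriter.add_scalar(tag, value, step)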