├── .gitignore
├── .ipynb_checkpoints
│   ├── DeiT_example-checkpoint.ipynb
│   └── example-checkpoint.ipynb
├── BERT_explainability.ipynb
├── BERT_explainability
│   └── modules
│       ├── BERT
│       │   ├── BERT.py
│       │   ├── BERT_cls_lrp.py
│       │   ├── BERT_orig_lrp.py
│       │   ├── BertForSequenceClassification.py
│       │   └── ExplanationGenerator.py
│       ├── __init__.py
│       ├── layers_lrp.py
│       └── layers_ours.py
├── BERT_params
│   ├── boolq.json
│   ├── boolq_baas.json
│   ├── boolq_bert.json
│   ├── boolq_soft.json
│   ├── cose_bert.json
│   ├── cose_multiclass.json
│   ├── esnli_bert.json
│   ├── evidence_inference.json
│   ├── evidence_inference_bert.json
│   ├── evidence_inference_soft.json
│   ├── fever.json
│   ├── fever_baas.json
│   ├── fever_bert.json
│   ├── fever_soft.json
│   ├── movies.json
│   ├── movies_baas.json
│   ├── movies_bert.json
│   ├── movies_soft.json
│   ├── multirc.json
│   ├── multirc_baas.json
│   ├── multirc_bert.json
│   └── multirc_soft.json
├── BERT_rationale_benchmark
│   ├── __init__.py
│   ├── metrics.py
│   ├── models
│   │   ├── model_utils.py
│   │   ├── pipeline
│   │   │   ├── __init__.py
│   │   │   ├── bert_pipeline.py
│   │   │   ├── pipeline_train.py
│   │   │   └── pipeline_utils.py
│   │   └── sequence_taggers.py
│   └── utils.py
├── DeiT.PNG
├── DeiT_example.ipynb
├── LICENSE
├── README.md
├── Transformer_explainability.ipynb
├── baselines
│   └── ViT
│       ├── ViT_LRP.py
│       ├── ViT_explanation_generator.py
│       ├── ViT_new.py
│       ├── ViT_orig_LRP.py
│       ├── generate_visualizations.py
│       ├── helpers.py
│       ├── imagenet_seg_eval.py
│       ├── layer_helpers.py
│       ├── misc_functions.py
│       ├── pertubation_eval_from_hdf5.py
│       └── weight_init.py
├── data
│   ├── Imagenet.py
│   ├── VOC.py
│   ├── __init__.py
│   ├── imagenet.py
│   ├── imagenet_utils.py
│   └── transforms.py
├── dataset
│   └── expl_hdf5.py
├── example.PNG
├── example.ipynb
├── method-page-001.jpg
├── modules
│   ├── __init__.py
│   ├── layers_lrp.py
│   └── layers_ours.py
├── new_work.jpg
├── requirements.txt
├── samples
│   ├── CLS2IDX.py
│   ├── catdog.png
│   ├── dogbird.png
│   ├── dogcat2.png
│   ├── el1.png
│   ├── el2.png
│   ├── el3.png
│   ├── el4.png
│   └── el5.png
└── utils
    ├── __init__.py
    ├── confusionmatrix.py
    ├── iou.py
    ├── metric.py
    ├── metrices.py
    ├── parallel.py
    ├── render.py
    ├── saver.py
    └── summaries.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | all_good_vis/
3 | __pycache__
4 | *.tar
5 | .idea
6 | run/
7 | baselines/ViT/experiments/
8 | baselines/ViT/visualizations/
9 | bert_models/
10 | data/movies/
11 |
12 |
--------------------------------------------------------------------------------
/BERT_explainability/modules/BERT/BERT_cls_lrp.py:
--------------------------------------------------------------------------------
1 | from transformers import BertPreTrainedModel
2 | from transformers.modeling_outputs import SequenceClassifierOutput  # required for the return_dict branch below
3 | from BERT_explainability.modules.layers_lrp import *
4 | from BERT_explainability.modules.BERT.BERT_orig_lrp import BertModel
5 | from torch.nn import CrossEntropyLoss, MSELoss
6 | import torch.nn as nn
7 | from typing import List, Any
8 | import torch
9 | from BERT_rationale_benchmark.models.model_utils import PaddedSequence
10 |
11 |
12 | class BertForSequenceClassification(BertPreTrainedModel):
13 | def __init__(self, config):
14 | super().__init__(config)
15 | self.num_labels = config.num_labels
16 |
17 | self.bert = BertModel(config)
18 | self.dropout = Dropout(config.hidden_dropout_prob)
19 | self.classifier = Linear(config.hidden_size, config.num_labels)
20 |
21 | self.init_weights()
22 |
23 | def forward(
24 | self,
25 | input_ids=None,
26 | attention_mask=None,
27 | token_type_ids=None,
28 | position_ids=None,
29 | head_mask=None,
30 | inputs_embeds=None,
31 | labels=None,
32 | output_attentions=None,
33 | output_hidden_states=None,
34 | return_dict=None,
35 | ):
36 | r"""
37 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
38 | Labels for computing the sequence classification/regression loss.
39 | Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
40 | If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
41 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
42 | """
43 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
44 |
45 | outputs = self.bert(
46 | input_ids,
47 | attention_mask=attention_mask,
48 | token_type_ids=token_type_ids,
49 | position_ids=position_ids,
50 | head_mask=head_mask,
51 | inputs_embeds=inputs_embeds,
52 | output_attentions=output_attentions,
53 | output_hidden_states=output_hidden_states,
54 | return_dict=return_dict,
55 | )
56 |
57 | pooled_output = outputs[1]
58 |
59 | pooled_output = self.dropout(pooled_output)
60 | logits = self.classifier(pooled_output)
61 |
62 | loss = None
63 | if labels is not None:
64 | if self.num_labels == 1:
65 | # We are doing regression
66 | loss_fct = MSELoss()
67 | loss = loss_fct(logits.view(-1), labels.view(-1))
68 | else:
69 | loss_fct = CrossEntropyLoss()
70 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
71 |
72 | if not return_dict:
73 | output = (logits,) + outputs[2:]
74 | return ((loss,) + output) if loss is not None else output
75 |
76 | return SequenceClassifierOutput(
77 | loss=loss,
78 | logits=logits,
79 | hidden_states=outputs.hidden_states,
80 | attentions=outputs.attentions,
81 | )
82 |
83 | def relprop(self, cam=None, **kwargs):
84 | cam = self.classifier.relprop(cam, **kwargs)
85 | cam = self.dropout.relprop(cam, **kwargs)
86 | cam = self.bert.relprop(cam, **kwargs)
87 | return cam
88 |
89 |
90 | # this is the actual classifier we will be using
91 | class BertClassifier(nn.Module):
92 | """Thin wrapper around BertForSequenceClassification"""
93 |
94 | def __init__(self,
95 | bert_dir: str,
96 | pad_token_id: int,
97 | cls_token_id: int,
98 | sep_token_id: int,
99 | num_labels: int,
100 | max_length: int = 512,
101 | use_half_precision=True):
102 | super(BertClassifier, self).__init__()
103 | bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels)
104 | if use_half_precision:
105 | import apex
106 | bert = bert.half()
107 | self.bert = bert
108 | self.pad_token_id = pad_token_id
109 | self.cls_token_id = cls_token_id
110 | self.sep_token_id = sep_token_id
111 | self.max_length = max_length
112 |
113 | def forward(self,
114 | query: List[torch.tensor],
115 | docids: List[Any],
116 | document_batch: List[torch.tensor]):
117 | assert len(query) == len(document_batch)
118 | print(query)
119 | # note about device management:
120 | # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
121 | # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
122 | target_device = next(self.parameters()).device
123 | cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
124 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
125 | input_tensors = []
126 | position_ids = []
127 | for q, d in zip(query, document_batch):
128 | if len(q) + len(d) + 2 > self.max_length:
129 | d = d[:(self.max_length - len(q) - 2)]
130 | input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
131 | position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
132 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id,
133 | device=target_device)
134 | positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
135 | (classes,) = self.bert(bert_input.data,
136 | attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device),
137 | position_ids=positions.data)
138 | assert torch.all(classes == classes) # for nans
139 |
140 | print(input_tensors[0])
141 | print(self.relprop()[0])
142 |
143 | return classes
144 |
145 | def relprop(self, cam=None, **kwargs):
146 | return self.bert.relprop(cam, **kwargs)
147 |
148 |
149 | if __name__ == '__main__':
150 | from transformers import BertTokenizer
151 | import os
152 |
153 | class Config:
154 | def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, num_labels,
155 | hidden_dropout_prob):
156 | self.hidden_size = hidden_size
157 | self.num_attention_heads = num_attention_heads
158 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
159 | self.num_labels = num_labels
160 | self.hidden_dropout_prob = hidden_dropout_prob
161 |
162 |
163 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
164 | x = tokenizer.encode_plus("In this movie the acting is great. The movie is perfect! [sep]",
165 | add_special_tokens=True,
166 | max_length=512,
167 | return_token_type_ids=False,
168 | return_attention_mask=True,
169 | pad_to_max_length=True,
170 | return_tensors='pt',
171 | truncation=True)
172 |
173 | print(x['input_ids'])
174 |
175 | model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
176 | model_save_file = os.path.join('./BERT_explainability/output_bert/movies/classifier/', 'classifier.pt')
177 | model.load_state_dict(torch.load(model_save_file))
178 |
179 | # x = torch.randint(100, (2, 20))
180 | # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
181 | # 101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
182 | # 2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
183 | # 102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
184 | # 101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
185 | # 2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
186 | # 102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
187 | # 101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
188 | # 1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
189 | # 102, 101, 1012, 102]])
190 | # x.requires_grad_()
191 |
192 | model.eval()
193 |
194 | y = model(x['input_ids'], x['attention_mask'])
195 | print(y)
196 |
197 | cam, _ = model.relprop()
198 |
199 | #print(cam.shape)
200 |
201 | cam = cam.sum(-1)
202 | #print(cam)
203 |
--------------------------------------------------------------------------------
/BERT_explainability/modules/BERT/BertForSequenceClassification.py:
--------------------------------------------------------------------------------
1 | from transformers import BertPreTrainedModel
2 | from transformers.modeling_outputs import SequenceClassifierOutput  # required for the return_dict branch below
3 | from BERT_explainability.modules.layers_ours import *
4 | from BERT_explainability.modules.BERT.BERT import BertModel
5 | from torch.nn import CrossEntropyLoss, MSELoss
6 | import torch.nn as nn
7 | from typing import List, Any
8 | import torch
9 | from BERT_rationale_benchmark.models.model_utils import PaddedSequence
10 |
11 |
12 | class BertForSequenceClassification(BertPreTrainedModel):
13 | def __init__(self, config):
14 | super().__init__(config)
15 | self.num_labels = config.num_labels
16 |
17 | self.bert = BertModel(config)
18 | self.dropout = Dropout(config.hidden_dropout_prob)
19 | self.classifier = Linear(config.hidden_size, config.num_labels)
20 |
21 | self.init_weights()
22 |
23 | def forward(
24 | self,
25 | input_ids=None,
26 | attention_mask=None,
27 | token_type_ids=None,
28 | position_ids=None,
29 | head_mask=None,
30 | inputs_embeds=None,
31 | labels=None,
32 | output_attentions=None,
33 | output_hidden_states=None,
34 | return_dict=None,
35 | ):
36 | r"""
37 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
38 | Labels for computing the sequence classification/regression loss.
39 | Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
40 | If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
41 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
42 | """
43 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
44 |
45 | outputs = self.bert(
46 | input_ids,
47 | attention_mask=attention_mask,
48 | token_type_ids=token_type_ids,
49 | position_ids=position_ids,
50 | head_mask=head_mask,
51 | inputs_embeds=inputs_embeds,
52 | output_attentions=output_attentions,
53 | output_hidden_states=output_hidden_states,
54 | return_dict=return_dict,
55 | )
56 |
57 | pooled_output = outputs[1]
58 |
59 | pooled_output = self.dropout(pooled_output)
60 | logits = self.classifier(pooled_output)
61 |
62 | loss = None
63 | if labels is not None:
64 | if self.num_labels == 1:
65 | # We are doing regression
66 | loss_fct = MSELoss()
67 | loss = loss_fct(logits.view(-1), labels.view(-1))
68 | else:
69 | loss_fct = CrossEntropyLoss()
70 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
71 |
72 | if not return_dict:
73 | output = (logits,) + outputs[2:]
74 | return ((loss,) + output) if loss is not None else output
75 |
76 | return SequenceClassifierOutput(
77 | loss=loss,
78 | logits=logits,
79 | hidden_states=outputs.hidden_states,
80 | attentions=outputs.attentions,
81 | )
82 |
83 | def relprop(self, cam=None, **kwargs):
84 | cam = self.classifier.relprop(cam, **kwargs)
85 | cam = self.dropout.relprop(cam, **kwargs)
86 | cam = self.bert.relprop(cam, **kwargs)
87 | # print("conservation: ", cam.sum())
88 | return cam
89 |
90 |
91 | # this is the actual classifier we will be using
92 | class BertClassifier(nn.Module):
93 | """Thin wrapper around BertForSequenceClassification"""
94 |
95 | def __init__(self,
96 | bert_dir: str,
97 | pad_token_id: int,
98 | cls_token_id: int,
99 | sep_token_id: int,
100 | num_labels: int,
101 | max_length: int = 512,
102 | use_half_precision=True):
103 | super(BertClassifier, self).__init__()
104 | bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels)
105 | if use_half_precision:
106 | import apex
107 | bert = bert.half()
108 | self.bert = bert
109 | self.pad_token_id = pad_token_id
110 | self.cls_token_id = cls_token_id
111 | self.sep_token_id = sep_token_id
112 | self.max_length = max_length
113 |
114 | def forward(self,
115 | query: List[torch.tensor],
116 | docids: List[Any],
117 | document_batch: List[torch.tensor]):
118 | assert len(query) == len(document_batch)
119 | print(query)
120 | # note about device management:
121 | # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
122 | # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
123 | target_device = next(self.parameters()).device
124 | cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
125 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
126 | input_tensors = []
127 | position_ids = []
128 | for q, d in zip(query, document_batch):
129 | if len(q) + len(d) + 2 > self.max_length:
130 | d = d[:(self.max_length - len(q) - 2)]
131 | input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
132 | position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
133 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id,
134 | device=target_device)
135 | positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
136 | (classes,) = self.bert(bert_input.data,
137 | attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device),
138 | position_ids=positions.data)
139 | assert torch.all(classes == classes) # for nans
140 |
141 | print(input_tensors[0])
142 | print(self.relprop()[0])
143 |
144 | return classes
145 |
146 | def relprop(self, cam=None, **kwargs):
147 | return self.bert.relprop(cam, **kwargs)
148 |
149 |
150 | if __name__ == '__main__':
151 | from transformers import BertTokenizer
152 | import os
153 |
154 | class Config:
155 | def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, num_labels,
156 | hidden_dropout_prob):
157 | self.hidden_size = hidden_size
158 | self.num_attention_heads = num_attention_heads
159 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
160 | self.num_labels = num_labels
161 | self.hidden_dropout_prob = hidden_dropout_prob
162 |
163 |
164 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
165 | x = tokenizer.encode_plus("In this movie the acting is great. The movie is perfect! [sep]",
166 | add_special_tokens=True,
167 | max_length=512,
168 | return_token_type_ids=False,
169 | return_attention_mask=True,
170 | pad_to_max_length=True,
171 | return_tensors='pt',
172 | truncation=True)
173 |
174 | print(x['input_ids'])
175 |
176 | model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
177 | model_save_file = os.path.join('./BERT_explainability/output_bert/movies/classifier/', 'classifier.pt')
178 | model.load_state_dict(torch.load(model_save_file))
179 |
180 | # x = torch.randint(100, (2, 20))
181 | # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
182 | # 101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
183 | # 2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
184 | # 102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
185 | # 101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
186 | # 2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
187 | # 102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
188 | # 101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
189 | # 1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
190 | # 102, 101, 1012, 102]])
191 | # x.requires_grad_()
192 |
193 | model.eval()
194 |
195 | y = model(x['input_ids'], x['attention_mask'])
196 | print(y)
197 |
198 | cam, _ = model.relprop()
199 |
200 | #print(cam.shape)
201 |
202 | cam = cam.sum(-1)
203 | #print(cam)
204 |
--------------------------------------------------------------------------------
/BERT_explainability/modules/BERT/ExplanationGenerator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | import glob
5 |
6 | # compute rollout between attention layers
7 | def compute_rollout_attention(all_layer_matrices, start_layer=0):
8 | # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
9 | num_tokens = all_layer_matrices[0].shape[1]
10 | batch_size = all_layer_matrices[0].shape[0]
11 | eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
12 | all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
13 | matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
14 | for i in range(len(all_layer_matrices))]
15 | joint_attention = matrices_aug[start_layer]
16 | for i in range(start_layer+1, len(matrices_aug)):
17 | joint_attention = matrices_aug[i].bmm(joint_attention)
18 | return joint_attention
19 |
20 | class Generator:
21 | def __init__(self, model):
22 | self.model = model
23 | self.model.eval()
24 |
25 | def forward(self, input_ids, attention_mask):
26 | return self.model(input_ids, attention_mask)
27 |
28 | def generate_LRP(self, input_ids, attention_mask,
29 | index=None, start_layer=11):
30 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
31 | kwargs = {"alpha": 1}
32 |
33 | if index == None:
34 | index = np.argmax(output.cpu().data.numpy(), axis=-1)
35 |
36 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
37 | one_hot[0, index] = 1
38 | one_hot_vector = one_hot
39 | one_hot = torch.from_numpy(one_hot).requires_grad_(True)
40 | one_hot = torch.sum(one_hot.cuda() * output)
41 |
42 | self.model.zero_grad()
43 | one_hot.backward(retain_graph=True)
44 |
45 | self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
46 |
47 | cams = []
48 | blocks = self.model.bert.encoder.layer
49 | for blk in blocks:
50 | grad = blk.attention.self.get_attn_gradients()
51 | cam = blk.attention.self.get_attn_cam()
52 | cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
53 | grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
54 | cam = grad * cam
55 | cam = cam.clamp(min=0).mean(dim=0)
56 | cams.append(cam.unsqueeze(0))
57 | rollout = compute_rollout_attention(cams, start_layer=start_layer)
58 | rollout[:, 0, 0] = rollout[:, 0].min()
59 | return rollout[:, 0]
60 |
61 |
62 | def generate_LRP_last_layer(self, input_ids, attention_mask,
63 | index=None):
64 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
65 | kwargs = {"alpha": 1}
66 | if index == None:
67 | index = np.argmax(output.cpu().data.numpy(), axis=-1)
68 |
69 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
70 | one_hot[0, index] = 1
71 | one_hot_vector = one_hot
72 | one_hot = torch.from_numpy(one_hot).requires_grad_(True)
73 | one_hot = torch.sum(one_hot.cuda() * output)
74 |
75 | self.model.zero_grad()
76 | one_hot.backward(retain_graph=True)
77 |
78 | self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
79 |
80 | cam = self.model.bert.encoder.layer[-1].attention.self.get_attn_cam()[0]
81 | cam = cam.clamp(min=0).mean(dim=0).unsqueeze(0)
82 | cam[:, 0, 0] = 0
83 | return cam[:, 0]
84 |
85 | def generate_full_lrp(self, input_ids, attention_mask,
86 | index=None):
87 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
88 | kwargs = {"alpha": 1}
89 |
90 | if index == None:
91 | index = np.argmax(output.cpu().data.numpy(), axis=-1)
92 |
93 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
94 | one_hot[0, index] = 1
95 | one_hot_vector = one_hot
96 | one_hot = torch.from_numpy(one_hot).requires_grad_(True)
97 | one_hot = torch.sum(one_hot.cuda() * output)
98 |
99 | self.model.zero_grad()
100 | one_hot.backward(retain_graph=True)
101 |
102 | cam = self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
103 | cam = cam.sum(dim=2)
104 | cam[:, 0] = 0
105 | return cam
106 |
107 | def generate_attn_last_layer(self, input_ids, attention_mask,
108 | index=None):
109 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
110 | cam = self.model.bert.encoder.layer[-1].attention.self.get_attn()[0]
111 | cam = cam.mean(dim=0).unsqueeze(0)
112 | cam[:, 0, 0] = 0
113 | return cam[:, 0]
114 |
115 | def generate_rollout(self, input_ids, attention_mask, start_layer=0, index=None):
116 | self.model.zero_grad()
117 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
118 | blocks = self.model.bert.encoder.layer
119 | all_layer_attentions = []
120 | for blk in blocks:
121 | attn_heads = blk.attention.self.get_attn()
122 | avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
123 | all_layer_attentions.append(avg_heads)
124 | rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
125 | rollout[:, 0, 0] = 0
126 | return rollout[:, 0]
127 |
128 | def generate_attn_gradcam(self, input_ids, attention_mask, index=None):
129 | output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
130 | kwargs = {"alpha": 1}
131 |
132 | if index == None:
133 | index = np.argmax(output.cpu().data.numpy(), axis=-1)
134 |
135 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
136 | one_hot[0, index] = 1
137 | one_hot_vector = one_hot
138 | one_hot = torch.from_numpy(one_hot).requires_grad_(True)
139 | one_hot = torch.sum(one_hot.cuda() * output)
140 |
141 | self.model.zero_grad()
142 | one_hot.backward(retain_graph=True)
143 |
144 | self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), **kwargs)
145 |
146 | cam = self.model.bert.encoder.layer[-1].attention.self.get_attn()
147 | grad = self.model.bert.encoder.layer[-1].attention.self.get_attn_gradients()
148 |
149 | cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
150 | grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
151 | grad = grad.mean(dim=[1, 2], keepdim=True)
152 | cam = (cam * grad).mean(0).clamp(min=0).unsqueeze(0)
153 | cam = (cam - cam.min()) / (cam.max() - cam.min())
154 | cam[:, 0, 0] = 0
155 | return cam[:, 0]
156 |
157 |
--------------------------------------------------------------------------------
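
Usage note (not a file in the repository): a minimal sketch of wiring the LRP-enabled BertForSequenceClassification above into Generator.generate_LRP. The checkpoint name, the example sentence, and the min-max rescaling are illustrative assumptions; in practice a classifier fine-tuned on the target task would be loaded. The generator moves its one-hot target to the GPU, so a CUDA device is assumed.

    import torch
    from transformers import BertTokenizer
    from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
    from BERT_explainability.modules.BERT.ExplanationGenerator import Generator

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # assumption: normally a task-specific fine-tuned checkpoint would be loaded here
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2).cuda()
    model.eval()

    encoding = tokenizer("The acting was brilliant.", return_tensors="pt")
    input_ids = encoding["input_ids"].cuda()
    attention_mask = encoding["attention_mask"].cuda()

    # per-token relevance for the predicted class ([CLS] row of the rollout)
    generator = Generator(model)
    scores = generator.generate_LRP(input_ids, attention_mask, start_layer=11)[0]
    scores = (scores - scores.min()) / (scores.max() - scores.min())  # rescale for display
    for token, score in zip(tokenizer.convert_ids_to_tokens(input_ids[0].tolist()), scores):
        print(f"{token:>10s} {score.item():.3f}")
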
/BERT_explainability/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/BERT_explainability/modules/__init__.py
--------------------------------------------------------------------------------
/BERT_explainability/modules/layers_lrp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6 | 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7 | 'LayerNorm', 'AddEye', 'Tanh', 'MatMul', 'Mul']
8 |
9 |
10 | def safe_divide(a, b):
11 | den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12 | den = den + den.eq(0).type(den.type()) * 1e-9
13 | return a / den * b.ne(0).type(b.type())
14 |
15 |
16 | def forward_hook(self, input, output):
17 | if type(input[0]) in (list, tuple):
18 | self.X = []
19 | for i in input[0]:
20 | x = i.detach()
21 | x.requires_grad = True
22 | self.X.append(x)
23 | else:
24 | self.X = input[0].detach()
25 | self.X.requires_grad = True
26 |
27 | self.Y = output
28 |
29 |
30 | def backward_hook(self, grad_input, grad_output):
31 | self.grad_input = grad_input
32 | self.grad_output = grad_output
33 |
34 |
35 | class RelProp(nn.Module):
36 | def __init__(self):
37 | super(RelProp, self).__init__()
38 | # if not self.training:
39 | self.register_forward_hook(forward_hook)
40 |
41 | def gradprop(self, Z, X, S):
42 | C = torch.autograd.grad(Z, X, S, retain_graph=True)
43 | return C
44 |
45 | def relprop(self, R, alpha):
46 | return R
47 |
48 |
49 | class RelPropSimple(RelProp):
50 | def relprop(self, R, alpha):
51 | Z = self.forward(self.X)
52 | S = safe_divide(R, Z)
53 | C = self.gradprop(Z, self.X, S)
54 |
55 | if torch.is_tensor(self.X) == False:
56 | outputs = []
57 | outputs.append(self.X[0] * C[0])
58 | outputs.append(self.X[1] * C[1])
59 | else:
60 | outputs = self.X * (C[0])
61 | return outputs
62 |
63 | class AddEye(RelPropSimple):
64 | # input of shape B, C, seq_len, seq_len
65 | def forward(self, input):
66 | return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
67 |
68 | class ReLU(nn.ReLU, RelProp):
69 | pass
70 |
71 | class Tanh(nn.Tanh, RelProp):
72 | pass
73 |
74 | class GELU(nn.GELU, RelProp):
75 | pass
76 |
77 | class Softmax(nn.Softmax, RelProp):
78 | pass
79 |
80 | class LayerNorm(nn.LayerNorm, RelProp):
81 | pass
82 |
83 | class Dropout(nn.Dropout, RelProp):
84 | pass
85 |
86 |
87 | class MaxPool2d(nn.MaxPool2d, RelPropSimple):
88 | pass
89 |
90 | class LayerNorm(nn.LayerNorm, RelProp):
91 | pass
92 |
93 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
94 | pass
95 |
96 | class MatMul(RelPropSimple):
97 | def forward(self, inputs):
98 | return torch.matmul(*inputs)
99 |
100 | class Mul(RelPropSimple):
101 | def forward(self, inputs):
102 | return torch.mul(*inputs)
103 |
104 | class AvgPool2d(nn.AvgPool2d, RelPropSimple):
105 | pass
106 |
107 |
108 | class Add(RelPropSimple):
109 | def forward(self, inputs):
110 | return torch.add(*inputs)
111 |
112 | class einsum(RelPropSimple):
113 | def __init__(self, equation):
114 | super().__init__()
115 | self.equation = equation
116 | def forward(self, *operands):
117 | return torch.einsum(self.equation, *operands)
118 |
119 | class IndexSelect(RelProp):
120 | def forward(self, inputs, dim, indices):
121 | self.__setattr__('dim', dim)
122 | self.__setattr__('indices', indices)
123 |
124 | return torch.index_select(inputs, dim, indices)
125 |
126 | def relprop(self, R, alpha):
127 | Z = self.forward(self.X, self.dim, self.indices)
128 | S = safe_divide(R, Z)
129 | C = self.gradprop(Z, self.X, S)
130 |
131 | if torch.is_tensor(self.X) == False:
132 | outputs = []
133 | outputs.append(self.X[0] * C[0])
134 | outputs.append(self.X[1] * C[1])
135 | else:
136 | outputs = self.X * (C[0])
137 | return outputs
138 |
139 |
140 |
141 | class Clone(RelProp):
142 | def forward(self, input, num):
143 | self.__setattr__('num', num)
144 | outputs = []
145 | for _ in range(num):
146 | outputs.append(input)
147 |
148 | return outputs
149 |
150 | def relprop(self, R, alpha):
151 | Z = []
152 | for _ in range(self.num):
153 | Z.append(self.X)
154 | S = [safe_divide(r, z) for r, z in zip(R, Z)]
155 | C = self.gradprop(Z, self.X, S)[0]
156 |
157 | R = self.X * C
158 |
159 | return R
160 |
161 | class Cat(RelProp):
162 | def forward(self, inputs, dim):
163 | self.__setattr__('dim', dim)
164 | return torch.cat(inputs, dim)
165 |
166 | def relprop(self, R, alpha):
167 | Z = self.forward(self.X, self.dim)
168 | S = safe_divide(R, Z)
169 | C = self.gradprop(Z, self.X, S)
170 |
171 | outputs = []
172 | for x, c in zip(self.X, C):
173 | outputs.append(x * c)
174 |
175 | return outputs
176 |
177 | class Sequential(nn.Sequential):
178 | def relprop(self, R, alpha):
179 | for m in reversed(self._modules.values()):
180 | R = m.relprop(R, alpha)
181 | return R
182 |
183 | class BatchNorm2d(nn.BatchNorm2d, RelProp):
184 | def relprop(self, R, alpha):
185 | X = self.X
186 | beta = 1 - alpha
187 | weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
188 | (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
189 | Z = X * weight + 1e-9
190 | S = R / Z
191 | Ca = S * weight
192 | R = self.X * (Ca)
193 | return R
194 |
195 |
196 | class Linear(nn.Linear, RelProp):
197 | def relprop(self, R, alpha):
198 | beta = alpha - 1
199 | pw = torch.clamp(self.weight, min=0)
200 | nw = torch.clamp(self.weight, max=0)
201 | px = torch.clamp(self.X, min=0)
202 | nx = torch.clamp(self.X, max=0)
203 |
204 | def f(w1, w2, x1, x2):
205 | Z1 = F.linear(x1, w1)
206 | Z2 = F.linear(x2, w2)
207 | S1 = safe_divide(R, Z1)
208 | S2 = safe_divide(R, Z2)
209 | C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
210 | C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
211 |
212 | return C1 + C2
213 |
214 | activator_relevances = f(pw, nw, px, nx)
215 | inhibitor_relevances = f(nw, pw, px, nx)
216 |
217 | R = alpha * activator_relevances - beta * inhibitor_relevances
218 |
219 | return R
220 |
221 | class Conv2d(nn.Conv2d, RelProp):
222 | def gradprop2(self, DY, weight):
223 | Z = self.forward(self.X)
224 |
225 | output_padding = self.X.size()[2] - (
226 | (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
227 |
228 | return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
229 |
230 | def relprop(self, R, alpha):
231 | if self.X.shape[1] == 3:
232 | pw = torch.clamp(self.weight, min=0)
233 | nw = torch.clamp(self.weight, max=0)
234 | X = self.X
235 | L = self.X * 0 + \
236 | torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
237 | keepdim=True)[0]
238 | H = self.X * 0 + \
239 | torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
240 | keepdim=True)[0]
241 | Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
242 | torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
243 | torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
244 |
245 | S = R / Za
246 | C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
247 | R = C
248 | else:
249 | beta = alpha - 1
250 | pw = torch.clamp(self.weight, min=0)
251 | nw = torch.clamp(self.weight, max=0)
252 | px = torch.clamp(self.X, min=0)
253 | nx = torch.clamp(self.X, max=0)
254 |
255 | def f(w1, w2, x1, x2):
256 | Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
257 | Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
258 | S1 = safe_divide(R, Z1)
259 | S2 = safe_divide(R, Z2)
260 | C1 = x1 * self.gradprop(Z1, x1, S1)[0]
261 | C2 = x2 * self.gradprop(Z2, x2, S2)[0]
262 | return C1 + C2
263 |
264 | activator_relevances = f(pw, nw, px, nx)
265 | inhibitor_relevances = f(nw, pw, px, nx)
266 |
267 | R = alpha * activator_relevances - beta * inhibitor_relevances
268 | return R
--------------------------------------------------------------------------------
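
Usage note (not a file in the repository): the wrappers above register a forward hook that caches each layer's input in self.X, so relevance can later be pushed back through the recorded layers in reverse order. A standalone sketch of that pattern with toy dimensions and alpha=1:

    import torch
    from BERT_explainability.modules.layers_lrp import Linear, ReLU, Sequential

    model = Sequential(Linear(4, 8), ReLU(), Linear(8, 2))
    x = torch.randn(1, 4)
    logits = model(x)                        # forward hooks record each layer's input

    # start relevance at the predicted logit
    R = torch.zeros_like(logits)
    R[0, logits.argmax()] = logits[0, logits.argmax()]

    R_input = model.relprop(R, alpha=1)      # Sequential.relprop walks the layers in reverse
    print(R_input.shape)                     # torch.Size([1, 4]) -- relevance per input feature
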
/BERT_explainability/modules/layers_ours.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6 | 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7 | 'LayerNorm', 'AddEye', 'Tanh', 'MatMul', 'Mul']
8 |
9 |
10 | def safe_divide(a, b):
11 | den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12 | den = den + den.eq(0).type(den.type()) * 1e-9
13 | return a / den * b.ne(0).type(b.type())
14 |
15 |
16 | def forward_hook(self, input, output):
17 | if type(input[0]) in (list, tuple):
18 | self.X = []
19 | for i in input[0]:
20 | x = i.detach()
21 | x.requires_grad = True
22 | self.X.append(x)
23 | else:
24 | self.X = input[0].detach()
25 | self.X.requires_grad = True
26 |
27 | self.Y = output
28 |
29 |
30 | def backward_hook(self, grad_input, grad_output):
31 | self.grad_input = grad_input
32 | self.grad_output = grad_output
33 |
34 |
35 | class RelProp(nn.Module):
36 | def __init__(self):
37 | super(RelProp, self).__init__()
38 | # if not self.training:
39 | self.register_forward_hook(forward_hook)
40 |
41 | def gradprop(self, Z, X, S):
42 | C = torch.autograd.grad(Z, X, S, retain_graph=True)
43 | return C
44 |
45 | def relprop(self, R, alpha):
46 | return R
47 |
48 |
49 | class RelPropSimple(RelProp):
50 | def relprop(self, R, alpha):
51 | Z = self.forward(self.X)
52 | S = safe_divide(R, Z)
53 | C = self.gradprop(Z, self.X, S)
54 |
55 | if torch.is_tensor(self.X) == False:
56 | outputs = []
57 | outputs.append(self.X[0] * C[0])
58 | outputs.append(self.X[1] * C[1])
59 | else:
60 | outputs = self.X * (C[0])
61 | return outputs
62 |
63 | class AddEye(RelPropSimple):
64 | # input of shape B, C, seq_len, seq_len
65 | def forward(self, input):
66 | return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
67 |
68 | class ReLU(nn.ReLU, RelProp):
69 | pass
70 |
71 | class GELU(nn.GELU, RelProp):
72 | pass
73 |
74 | class Softmax(nn.Softmax, RelProp):
75 | pass
76 |
77 | class Mul(RelPropSimple):
78 | def forward(self, inputs):
79 | return torch.mul(*inputs)
80 |
81 | class Tanh(nn.Tanh, RelProp):
82 | pass
83 | class LayerNorm(nn.LayerNorm, RelProp):
84 | pass
85 |
86 | class Dropout(nn.Dropout, RelProp):
87 | pass
88 |
89 | class MatMul(RelPropSimple):
90 | def forward(self, inputs):
91 | return torch.matmul(*inputs)
92 |
93 | class MaxPool2d(nn.MaxPool2d, RelPropSimple):
94 | pass
95 |
96 | class LayerNorm(nn.LayerNorm, RelProp):
97 | pass
98 |
99 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
100 | pass
101 |
102 |
103 | class AvgPool2d(nn.AvgPool2d, RelPropSimple):
104 | pass
105 |
106 |
107 | class Add(RelPropSimple):
108 | def forward(self, inputs):
109 | return torch.add(*inputs)
110 |
111 | def relprop(self, R, alpha):
112 | Z = self.forward(self.X)
113 | S = safe_divide(R, Z)
114 | C = self.gradprop(Z, self.X, S)
115 |
116 | a = self.X[0] * C[0]
117 | b = self.X[1] * C[1]
118 |
119 | a_sum = a.sum()
120 | b_sum = b.sum()
121 |
122 | a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
123 | b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
124 |
125 | a = a * safe_divide(a_fact, a.sum())
126 | b = b * safe_divide(b_fact, b.sum())
127 |
128 | outputs = [a, b]
129 |
130 | return outputs
131 |
132 | class einsum(RelPropSimple):
133 | def __init__(self, equation):
134 | super().__init__()
135 | self.equation = equation
136 | def forward(self, *operands):
137 | return torch.einsum(self.equation, *operands)
138 |
139 | class IndexSelect(RelProp):
140 | def forward(self, inputs, dim, indices):
141 | self.__setattr__('dim', dim)
142 | self.__setattr__('indices', indices)
143 |
144 | return torch.index_select(inputs, dim, indices)
145 |
146 | def relprop(self, R, alpha):
147 | Z = self.forward(self.X, self.dim, self.indices)
148 | S = safe_divide(R, Z)
149 | C = self.gradprop(Z, self.X, S)
150 |
151 | if torch.is_tensor(self.X) == False:
152 | outputs = []
153 | outputs.append(self.X[0] * C[0])
154 | outputs.append(self.X[1] * C[1])
155 | else:
156 | outputs = self.X * (C[0])
157 | return outputs
158 |
159 |
160 |
161 | class Clone(RelProp):
162 | def forward(self, input, num):
163 | self.__setattr__('num', num)
164 | outputs = []
165 | for _ in range(num):
166 | outputs.append(input)
167 |
168 | return outputs
169 |
170 | def relprop(self, R, alpha):
171 | Z = []
172 | for _ in range(self.num):
173 | Z.append(self.X)
174 | S = [safe_divide(r, z) for r, z in zip(R, Z)]
175 | C = self.gradprop(Z, self.X, S)[0]
176 |
177 | R = self.X * C
178 |
179 | return R
180 |
181 |
182 | class Cat(RelProp):
183 | def forward(self, inputs, dim):
184 | self.__setattr__('dim', dim)
185 | return torch.cat(inputs, dim)
186 |
187 | def relprop(self, R, alpha):
188 | Z = self.forward(self.X, self.dim)
189 | S = safe_divide(R, Z)
190 | C = self.gradprop(Z, self.X, S)
191 |
192 | outputs = []
193 | for x, c in zip(self.X, C):
194 | outputs.append(x * c)
195 |
196 | return outputs
197 |
198 |
199 | class Sequential(nn.Sequential):
200 | def relprop(self, R, alpha):
201 | for m in reversed(self._modules.values()):
202 | R = m.relprop(R, alpha)
203 | return R
204 |
205 |
206 | class BatchNorm2d(nn.BatchNorm2d, RelProp):
207 | def relprop(self, R, alpha):
208 | X = self.X
209 | beta = 1 - alpha
210 | weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
211 | (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
212 | Z = X * weight + 1e-9
213 | S = R / Z
214 | Ca = S * weight
215 | R = self.X * (Ca)
216 | return R
217 |
218 |
219 | class Linear(nn.Linear, RelProp):
220 | def relprop(self, R, alpha):
221 | beta = alpha - 1
222 | pw = torch.clamp(self.weight, min=0)
223 | nw = torch.clamp(self.weight, max=0)
224 | px = torch.clamp(self.X, min=0)
225 | nx = torch.clamp(self.X, max=0)
226 |
227 | def f(w1, w2, x1, x2):
228 | Z1 = F.linear(x1, w1)
229 | Z2 = F.linear(x2, w2)
230 | S1 = safe_divide(R, Z1 + Z2)
231 | S2 = safe_divide(R, Z1 + Z2)
232 | C1 = x1 * self.gradprop(Z1, x1, S1)[0]
233 | C2 = x2 * self.gradprop(Z2, x2, S2)[0]
234 |
235 | return C1 + C2
236 |
237 | activator_relevances = f(pw, nw, px, nx)
238 | inhibitor_relevances = f(nw, pw, px, nx)
239 |
240 | R = alpha * activator_relevances - beta * inhibitor_relevances
241 |
242 | return R
243 |
244 |
245 | class Conv2d(nn.Conv2d, RelProp):
246 | def gradprop2(self, DY, weight):
247 | Z = self.forward(self.X)
248 |
249 | output_padding = self.X.size()[2] - (
250 | (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
251 |
252 | return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
253 |
254 | def relprop(self, R, alpha):
255 | if self.X.shape[1] == 3:
256 | pw = torch.clamp(self.weight, min=0)
257 | nw = torch.clamp(self.weight, max=0)
258 | X = self.X
259 | L = self.X * 0 + \
260 | torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
261 | keepdim=True)[0]
262 | H = self.X * 0 + \
263 | torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
264 | keepdim=True)[0]
265 | Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
266 | torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
267 | torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
268 |
269 | S = R / Za
270 | C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
271 | R = C
272 | else:
273 | beta = alpha - 1
274 | pw = torch.clamp(self.weight, min=0)
275 | nw = torch.clamp(self.weight, max=0)
276 | px = torch.clamp(self.X, min=0)
277 | nx = torch.clamp(self.X, max=0)
278 |
279 | def f(w1, w2, x1, x2):
280 | Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
281 | Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
282 | S1 = safe_divide(R, Z1)
283 | S2 = safe_divide(R, Z2)
284 | C1 = x1 * self.gradprop(Z1, x1, S1)[0]
285 | C2 = x2 * self.gradprop(Z2, x2, S2)[0]
286 | return C1 + C2
287 |
288 | activator_relevances = f(pw, nw, px, nx)
289 | inhibitor_relevances = f(nw, pw, px, nx)
290 |
291 | R = alpha * activator_relevances - beta * inhibitor_relevances
292 | return R
--------------------------------------------------------------------------------
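
Usage note (not a file in the repository): relative to layers_lrp, the variants here change how relevance is redistributed. Linear normalizes both branches by the shared denominator Z1 + Z2, and Add rescales the two summands (e.g. the two sides of a residual connection) so that the incoming relevance total is preserved. A quick numerical check of that conservation property with toy tensors and alpha=1:

    import torch
    from BERT_explainability.modules.layers_ours import Add

    add = Add()
    a = torch.randn(1, 5)
    b = torch.randn(1, 5)
    _ = add([a, b])                   # forward hook caches both inputs

    R = torch.rand(1, 5)
    Ra, Rb = add.relprop(R, alpha=1)
    # the rescaling keeps the total outgoing relevance (approximately) equal to R.sum()
    print(R.sum().item(), (Ra.sum() + Rb.sum()).item())
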
/BERT_params/boolq.json:
--------------------------------------------------------------------------------
1 | {
2 | "embeddings": {
3 | "embedding_file": "model_components/glove.6B.200d.txt",
4 | "dropout": 0.05
5 | },
6 | "evidence_identifier": {
7 | "mlp_size": 128,
8 | "dropout": 0.2,
9 | "batch_size": 768,
10 | "epochs": 50,
11 | "patience": 10,
12 | "lr": 1e-3,
13 | "sampling_method": "random",
14 | "sampling_ratio": 1.0
15 | },
16 | "evidence_classifier": {
17 | "classes": [ "False", "True" ],
18 | "mlp_size": 128,
19 | "dropout": 0.2,
20 | "batch_size": 768,
21 | "epochs": 50,
22 | "patience": 10,
23 | "lr": 1e-3,
24 | "sampling_method": "everything"
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/BERT_params/boolq_baas.json:
--------------------------------------------------------------------------------
1 | {
2 | "start_server": 0,
3 | "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
4 | "max_length": 512,
5 | "pooling_strategy": "CLS_TOKEN",
6 | "evidence_identifier": {
7 | "batch_size": 64,
8 | "epochs": 3,
9 | "patience": 10,
10 | "lr": 1e-3,
11 | "max_grad_norm": 1.0,
12 | "sampling_method": "random",
13 | "sampling_ratio": 1.0
14 | },
15 | "evidence_classifier": {
16 | "classes": [ "False", "True" ],
17 | "batch_size": 64,
18 | "epochs": 3,
19 | "patience": 10,
20 | "lr": 1e-3,
21 | "max_grad_norm": 1.0,
22 | "sampling_method": "everything"
23 | }
24 | }
25 |
26 |
27 |
--------------------------------------------------------------------------------
/BERT_params/boolq_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "bert_vocab": "bert-base-uncased",
4 | "bert_dir": "bert-base-uncased",
5 | "use_evidence_sentence_identifier": 1,
6 | "use_evidence_token_identifier": 0,
7 | "evidence_identifier": {
8 | "batch_size": 10,
9 | "epochs": 10,
10 | "patience": 10,
11 | "warmup_steps": 50,
12 | "lr": 1e-05,
13 | "max_grad_norm": 1,
14 | "sampling_method": "random",
15 | "sampling_ratio": 1,
16 | "use_half_precision": 0
17 | },
18 | "evidence_classifier": {
19 | "classes": [
20 | "False",
21 | "True"
22 | ],
23 | "batch_size": 10,
24 | "warmup_steps": 50,
25 | "epochs": 10,
26 | "patience": 10,
27 | "lr": 1e-05,
28 | "max_grad_norm": 1,
29 | "sampling_method": "everything",
30 | "use_half_precision": 0
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
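
Usage note (not a file in the repository): the *_bert.json files parameterize the BERT evidence identifier and evidence classifier used by the rationale-benchmark pipeline. A minimal sketch of reading one; the field names come straight from the JSON above, while how the pipeline consumes them is assumed rather than shown here.

    import json

    with open("BERT_params/boolq_bert.json") as f:
        params = json.load(f)

    print(params["bert_dir"])                          # "bert-base-uncased"
    print(params["evidence_classifier"]["classes"])    # ["False", "True"]
    print(params["evidence_identifier"]["lr"])         # 1e-05
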
/BERT_params/boolq_soft.json:
--------------------------------------------------------------------------------
1 | {
2 | "embeddings": {
3 | "embedding_file": "model_components/glove.6B.200d.txt",
4 | "dropout": 0.2
5 | },
6 | "classifier": {
7 | "classes": [ "False", "True" ],
8 | "has_query": 1,
9 | "hidden_size": 32,
10 | "mlp_size": 128,
11 | "dropout": 0.2,
12 | "batch_size": 16,
13 | "epochs": 50,
14 | "attention_epochs": 50,
15 | "patience": 10,
16 | "lr": 1e-3,
17 | "dropout": 0.2,
18 | "k_fraction": 0.07,
19 | "threshold": 0.1
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/BERT_params/cose_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "bert_vocab": "bert-base-uncased",
4 | "bert_dir": "bert-base-uncased",
5 | "use_evidence_sentence_identifier": 0,
6 | "use_evidence_token_identifier": 1,
7 | "evidence_token_identifier": {
8 | "batch_size": 32,
9 | "epochs": 10,
10 | "patience": 10,
11 | "warmup_steps": 10,
12 | "lr": 1e-05,
13 | "max_grad_norm": 0.5,
14 | "sampling_method": "everything",
15 | "use_half_precision": 0,
16 | "cose_data_hack": 1
17 | },
18 | "evidence_classifier": {
19 | "classes": [ "false", "true"],
20 | "batch_size": 32,
21 | "warmup_steps": 10,
22 | "epochs": 10,
23 | "patience": 10,
24 | "lr": 1e-05,
25 | "max_grad_norm": 0.5,
26 | "sampling_method": "everything",
27 | "use_half_precision": 0,
28 | "cose_data_hack": 1
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/BERT_params/cose_multiclass.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "bert_vocab": "bert-base-uncased",
4 | "bert_dir": "bert-base-uncased",
5 | "use_evidence_sentence_identifier": 1,
6 | "use_evidence_token_identifier": 0,
7 | "evidence_identifier": {
8 | "batch_size": 32,
9 | "epochs": 10,
10 | "patience": 10,
11 | "warmup_steps": 50,
12 | "lr": 1e-05,
13 | "max_grad_norm": 1,
14 | "sampling_method": "random",
15 | "sampling_ratio": 1,
16 | "use_half_precision": 0
17 | },
18 | "evidence_classifier": {
19 | "classes": [
20 | "A",
21 | "B",
22 | "C",
23 | "D",
24 | "E"
25 | ],
26 | "batch_size": 10,
27 | "warmup_steps": 50,
28 | "epochs": 10,
29 | "patience": 10,
30 | "lr": 1e-05,
31 | "max_grad_norm": 1,
32 | "sampling_method": "everything",
33 | "use_half_precision": 0
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/BERT_params/esnli_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "bert_vocab": "bert-base-uncased",
4 | "bert_dir": "bert-base-uncased",
5 | "use_evidence_sentence_identifier": 0,
6 | "use_evidence_token_identifier": 1,
7 | "evidence_token_identifier": {
8 | "batch_size": 32,
9 | "epochs": 10,
10 | "patience": 10,
11 | "warmup_steps": 10,
12 | "lr": 1e-05,
13 | "max_grad_norm": 1,
14 | "sampling_method": "everything",
15 | "use_half_precision": 0
16 | },
17 | "evidence_classifier": {
18 | "classes": [ "contradiction", "neutral", "entailment" ],
19 | "batch_size": 32,
20 | "warmup_steps": 10,
21 | "epochs": 10,
22 | "patience": 10,
23 | "lr": 1e-05,
24 | "max_grad_norm": 1,
25 | "sampling_method": "everything",
26 | "use_half_precision": 0
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/BERT_params/evidence_inference.json:
--------------------------------------------------------------------------------
1 | {
2 | "embeddings": {
3 | "embedding_file": "model_components/PubMed-w2v.bin",
4 | "dropout": 0.05
5 | },
6 | "evidence_identifier": {
7 | "mlp_size": 128,
8 | "dropout": 0.05,
9 | "batch_size": 768,
10 | "epochs": 50,
11 | "patience": 10,
12 | "lr": 1e-3,
13 | "sampling_method": "random",
14 | "sampling_ratio": 1.0
15 | },
16 | "evidence_classifier": {
17 | "classes": [ "significantly decreased", "no significant difference", "significantly increased" ],
18 | "mlp_size": 128,
19 | "dropout": 0.05,
20 | "batch_size": 768,
21 | "epochs": 50,
22 | "patience": 10,
23 | "lr": 1e-3,
24 | "sampling_method": "everything"
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/BERT_params/evidence_inference_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "bert_vocab": "allenai/scibert_scivocab_uncased",
4 | "bert_dir": "allenai/scibert_scivocab_uncased",
5 | "use_evidence_sentence_identifier": 1,
6 | "use_evidence_token_identifier": 0,
7 | "evidence_identifier": {
8 | "batch_size": 10,
9 | "epochs": 10,
10 | "patience": 10,
11 | "warmup_steps": 10,
12 | "lr": 1e-05,
13 | "max_grad_norm": 1,
14 | "sampling_method": "random",
15 | "use_half_precision": 0,
16 | "sampling_ratio": 1
17 | },
18 | "evidence_classifier": {
19 | "classes": [
20 | "significantly decreased",
21 | "no significant difference",
22 | "significantly increased"
23 | ],
24 | "batch_size": 10,
25 | "warmup_steps": 10,
26 | "epochs": 10,
27 | "patience": 10,
28 | "lr": 1e-05,
29 | "max_grad_norm": 1,
30 | "sampling_method": "everything",
31 | "use_half_precision": 0
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/BERT_params/evidence_inference_soft.json:
--------------------------------------------------------------------------------
1 | {
2 | "embeddings": {
3 | "embedding_file": "model_components/PubMed-w2v.bin",
4 | "dropout": 0.2
5 | },
6 | "classifier": {
7 | "classes": [ "significantly decreased", "no significant difference", "significantly increased" ],
8 | "use_token_selection": 1,
9 | "has_query": 1,
10 | "hidden_size": 32,
11 | "mlp_size": 128,
12 | "dropout": 0.2,
13 | "batch_size": 16,
14 | "epochs": 50,
15 | "attention_epochs": 0,
16 | "patience": 10,
17 | "lr": 1e-3,
18 | "dropout": 0.2,
19 | "k_fraction": 0.013,
20 | "threshold": 0.1
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/BERT_params/fever.json:
--------------------------------------------------------------------------------
1 | {
2 | "embeddings": {
3 | "embedding_file": "model_components/glove.6B.200d.txt",
4 | "dropout": 0.05
5 | },
6 | "evidence_identifier": {
7 | "mlp_size": 128,
8 | "dropout": 0.05,
9 | "batch_size": 768,
10 | "epochs": 50,
11 | "patience": 10,
12 | "lr": 1e-3,
13 | "sampling_method": "random",
14 | "sampling_ratio": 1.0
15 | },
16 | "evidence_classifier": {
17 | "classes": [ "SUPPORTS", "REFUTES" ],
18 | "mlp_size": 128,
19 | "dropout": 0.05,
20 | "batch_size": 768,
21 | "epochs": 50,
22 | "patience": 10,
23 | "lr": 1e-5,
24 | "sampling_method": "everything"
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/BERT_params/fever_baas.json:
--------------------------------------------------------------------------------
1 | {
2 | "start_server": 0,
3 | "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
4 | "max_length": 512,
5 | "pooling_strategy": "CLS_TOKEN",
6 | "evidence_identifier": {
7 | "batch_size": 64,
8 | "epochs": 3,
9 | "patience": 10,
10 | "lr": 1e-3,
11 | "max_grad_norm": 1.0,
12 | "sampling_method": "random",
13 | "sampling_ratio": 1.0
14 | },
15 | "evidence_classifier": {
16 | "classes": [ "SUPPORTS", "REFUTES" ],
17 | "batch_size": 64,
18 | "epochs": 3,
19 | "patience": 10,
20 | "lr": 1e-3,
21 | "max_grad_norm": 1.0,
22 | "sampling_method": "everything"
23 | }
24 | }
25 |
26 |
--------------------------------------------------------------------------------
/BERT_params/fever_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "bert_vocab": "bert-base-uncased",
4 | "bert_dir": "bert-base-uncased",
5 | "use_evidence_sentence_identifier": 1,
6 | "use_evidence_token_identifier": 0,
7 | "evidence_identifier": {
8 | "batch_size": 16,
9 | "epochs": 10,
10 | "patience": 10,
11 | "warmup_steps": 10,
12 | "lr": 1e-05,
13 | "max_grad_norm": 1.0,
14 | "sampling_method": "random",
15 | "sampling_ratio": 1.0,
16 | "use_half_precision": 0
17 | },
18 | "evidence_classifier": {
19 | "classes": [
20 | "SUPPORTS",
21 | "REFUTES"
22 | ],
23 | "batch_size": 10,
24 | "warmup_steps": 10,
25 | "epochs": 10,
26 | "patience": 10,
27 | "lr": 1e-05,
28 | "max_grad_norm": 1.0,
29 | "sampling_method": "everything",
30 | "use_half_precision": 0
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/BERT_params/fever_soft.json:
--------------------------------------------------------------------------------
1 | {
2 | "embeddings": {
3 | "embedding_file": "model_components/glove.6B.200d.txt",
4 | "dropout": 0.2
5 | },
6 | "classifier": {
7 | "classes": [ "SUPPORTS", "REFUTES" ],
8 | "has_query": 1,
9 | "hidden_size": 32,
10 | "mlp_size": 128,
11 | "dropout": 0.2,
12 | "batch_size": 128,
13 | "epochs": 50,
14 | "attention_epochs": 50,
15 | "patience": 10,
16 | "lr": 1e-3,
17 | "dropout": 0.2,
18 | "k_fraction": 0.07,
19 | "threshold": 0.1
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/BERT_params/movies.json:
--------------------------------------------------------------------------------
1 | {
2 | "embeddings": {
3 | "embedding_file": "model_components/glove.6B.200d.txt",
4 | "dropout": 0.05
5 | },
6 | "evidence_identifier": {
7 | "mlp_size": 128,
8 | "dropout": 0.05,
9 | "batch_size": 768,
10 | "epochs": 50,
11 | "patience": 10,
12 | "lr": 1e-4,
13 | "sampling_method": "random",
14 | "sampling_ratio": 1.0
15 | },
16 | "evidence_classifier": {
17 | "classes": [ "NEG", "POS" ],
18 | "mlp_size": 128,
19 | "dropout": 0.05,
20 | "batch_size": 768,
21 | "epochs": 50,
22 | "patience": 10,
23 | "lr": 1e-3,
24 | "sampling_method": "everything"
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/BERT_params/movies_baas.json:
--------------------------------------------------------------------------------
1 | {
2 | "start_server": 0,
3 | "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
4 | "max_length": 512,
5 | "pooling_strategy": "CLS_TOKEN",
6 | "evidence_identifier": {
7 | "batch_size": 64,
8 | "epochs": 3,
9 | "patience": 10,
10 | "lr": 1e-3,
11 | "max_grad_norm": 1.0,
12 | "sampling_method": "random",
13 | "sampling_ratio": 1.0
14 | },
15 | "evidence_classifier": {
16 | "classes": [ "NEG", "POS" ],
17 | "batch_size": 64,
18 | "epochs": 3,
19 | "patience": 10,
20 | "lr": 1e-3,
21 | "max_grad_norm": 1.0,
22 | "sampling_method": "everything"
23 | }
24 | }
25 |
26 |
27 |
--------------------------------------------------------------------------------
/BERT_params/movies_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "bert_vocab": "bert-base-uncased",
4 | "bert_dir": "bert-base-uncased",
5 | "use_evidence_sentence_identifier": 1,
6 | "use_evidence_token_identifier": 0,
7 | "evidence_identifier": {
8 | "batch_size": 16,
9 | "epochs": 10,
10 | "patience": 10,
11 | "warmup_steps": 50,
12 | "lr": 1e-05,
13 | "max_grad_norm": 1,
14 | "sampling_method": "random",
15 | "sampling_ratio": 1,
16 | "use_half_precision": 0
17 | },
18 | "evidence_classifier": {
19 | "classes": [
20 | "NEG",
21 | "POS"
22 | ],
23 | "batch_size": 10,
24 | "warmup_steps": 50,
25 | "epochs": 10,
26 | "patience": 10,
27 | "lr": 1e-05,
28 | "max_grad_norm": 1,
29 | "sampling_method": "everything",
30 | "use_half_precision": 0
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/BERT_params/movies_soft.json:
--------------------------------------------------------------------------------
1 | {
2 | "embeddings": {
3 | "embedding_file": "model_components/glove.6B.200d.txt",
4 | "dropout": 0.2
5 | },
6 | "classifier": {
7 | "classes": [ "NEG", "POS" ],
8 | "has_query": 0,
9 | "hidden_size": 32,
10 | "mlp_size": 128,
11 | "dropout": 0.2,
12 | "batch_size": 16,
13 | "epochs": 50,
14 | "attention_epochs": 50,
15 | "patience": 10,
16 | "lr": 1e-3,
17 | "dropout": 0.2,
18 | "k_fraction": 0.07,
19 | "threshold": 0.1
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/BERT_params/multirc.json:
--------------------------------------------------------------------------------
1 | {
2 | "embeddings": {
3 | "embedding_file": "model_components/glove.6B.200d.txt",
4 | "dropout": 0.05
5 | },
6 | "evidence_identifier": {
7 | "mlp_size": 128,
8 | "dropout": 0.05,
9 | "batch_size": 768,
10 | "epochs": 50,
11 | "patience": 10,
12 | "lr": 1e-3,
13 | "sampling_method": "random",
14 | "sampling_ratio": 1.0
15 | },
16 | "evidence_classifier": {
17 | "classes": [ "False", "True" ],
18 | "mlp_size": 128,
19 | "dropout": 0.05,
20 | "batch_size": 768,
21 | "epochs": 50,
22 | "patience": 10,
23 | "lr": 1e-3,
24 | "sampling_method": "everything"
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/BERT_params/multirc_baas.json:
--------------------------------------------------------------------------------
1 | {
2 | "start_server": 0,
3 | "bert_dir": "model_components/uncased_L-12_H-768_A-12/",
4 | "max_length": 512,
5 | "pooling_strategy": "CLS_TOKEN",
6 | "evidence_identifier": {
7 | "batch_size": 64,
8 | "epochs": 3,
9 | "patience": 10,
10 | "lr": 1e-3,
11 | "max_grad_norm": 1.0,
12 | "sampling_method": "random",
13 | "sampling_ratio": 1.0
14 | },
15 | "evidence_classifier": {
16 | "classes": [ "False", "True" ],
17 | "batch_size": 64,
18 | "epochs": 3,
19 | "patience": 10,
20 | "lr": 1e-3,
21 | "max_grad_norm": 1.0,
22 | "sampling_method": "everything"
23 | }
24 | }
25 |
26 |
27 |
--------------------------------------------------------------------------------
/BERT_params/multirc_bert.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "bert_vocab": "bert-base-uncased",
4 | "bert_dir": "bert-base-uncased",
5 | "use_evidence_sentence_identifier": 1,
6 | "use_evidence_token_identifier": 0,
7 | "evidence_identifier": {
8 | "batch_size": 32,
9 | "epochs": 10,
10 | "patience": 10,
11 | "warmup_steps": 50,
12 | "lr": 1e-05,
13 | "max_grad_norm": 1,
14 | "sampling_method": "random",
15 | "sampling_ratio": 1,
16 | "use_half_precision": 0
17 | },
18 | "evidence_classifier": {
19 | "classes": [
20 | "False",
21 | "True"
22 | ],
23 | "batch_size": 32,
24 | "warmup_steps": 50,
25 | "epochs": 10,
26 | "patience": 10,
27 | "lr": 1e-05,
28 | "max_grad_norm": 1,
29 | "sampling_method": "everything",
30 | "use_half_precision": 0
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/BERT_params/multirc_soft.json:
--------------------------------------------------------------------------------
1 | {
2 | "embeddings": {
3 | "embedding_file": "model_components/glove.6B.200d.txt",
4 | "dropout": 0.2
5 | },
6 | "classifier": {
7 | "classes": [ "False", "True" ],
8 | "has_query": 1,
9 | "hidden_size": 32,
10 | "mlp_size": 128,
11 | "dropout": 0.2,
12 | "batch_size": 16,
13 | "epochs": 50,
14 | "attention_epochs": 50,
15 | "patience": 10,
16 | "lr": 1e-3,
17 |         "k_fraction": 0.07,
18 |         "threshold": 0.1
19 |     }
20 | }
21 |
--------------------------------------------------------------------------------
/BERT_rationale_benchmark/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/BERT_rationale_benchmark/__init__.py
--------------------------------------------------------------------------------
/BERT_rationale_benchmark/models/model_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, List, Set
3 |
4 | import numpy as np
5 | from gensim.models import KeyedVectors
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn.utils.rnn import pad_sequence, PackedSequence, pack_padded_sequence, pad_packed_sequence
10 |
11 |
12 | @dataclass(eq=True, frozen=True)
13 | class PaddedSequence:
14 |     """A utility class for padding variable length sequences meant for RNN input
15 | This class is in the style of PackedSequence from the PyTorch RNN Utils,
16 | but is somewhat more manual in approach. It provides the ability to generate masks
17 | for outputs of the same input dimensions.
18 | The constructor should never be called directly and should only be called via
19 | the autopad classmethod.
20 |
21 |     We'd love to delete this, but pad_sequence, pack_padded_sequence, and
22 | pad_packed_sequence all require shuffling around tuples of information, and some
23 | convenience methods using these are nice to have.
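    Illustrative usage (a minimal sketch of the autopad/mask behavior implemented below):

        >>> import torch
        >>> seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
        >>> ps = PaddedSequence.autopad(seqs, batch_first=True, padding_value=0)
        >>> ps.data
        tensor([[1, 2, 3],
                [4, 5, 0]])
        >>> ps.mask(on=1, off=0, dtype=torch.long)
        tensor([[1, 1, 1],
                [1, 1, 0]])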
24 | """
25 |
26 | data: torch.Tensor
27 | batch_sizes: torch.Tensor
28 | batch_first: bool = False
29 |
30 | @classmethod
31 | def autopad(cls, data, batch_first: bool = False, padding_value=0, device=None) -> 'PaddedSequence':
32 | # handle tensors of size 0 (single item)
33 | data_ = []
34 | for d in data:
35 | if len(d.size()) == 0:
36 | d = d.unsqueeze(0)
37 | data_.append(d)
38 | padded = pad_sequence(data_, batch_first=batch_first, padding_value=padding_value)
39 | if batch_first:
40 | batch_lengths = torch.LongTensor([len(x) for x in data_])
41 | if any([x == 0 for x in batch_lengths]):
42 | raise ValueError(
43 | "Found a 0 length batch element, this can't possibly be right: {}".format(batch_lengths))
44 | else:
45 | # TODO actually test this codepath
46 | batch_lengths = torch.LongTensor([len(x) for x in data])
47 | return PaddedSequence(padded, batch_lengths, batch_first).to(device=device)
48 |
49 | def pack_other(self, data: torch.Tensor):
50 | return pack_padded_sequence(data, self.batch_sizes, batch_first=self.batch_first, enforce_sorted=False)
51 |
52 | @classmethod
53 | def from_packed_sequence(cls, ps: PackedSequence, batch_first: bool, padding_value=0) -> 'PaddedSequence':
54 | padded, batch_sizes = pad_packed_sequence(ps, batch_first, padding_value)
55 | return PaddedSequence(padded, batch_sizes, batch_first)
56 |
57 | def cuda(self) -> 'PaddedSequence':
58 | return PaddedSequence(self.data.cuda(), self.batch_sizes.cuda(), batch_first=self.batch_first)
59 |
60 | def to(self, dtype=None, device=None, copy=False, non_blocking=False) -> 'PaddedSequence':
61 | # TODO make to() support all of the torch.Tensor to() variants
62 | return PaddedSequence(
63 | self.data.to(dtype=dtype, device=device, copy=copy, non_blocking=non_blocking),
64 | self.batch_sizes.to(device=device, copy=copy, non_blocking=non_blocking),
65 | batch_first=self.batch_first)
66 |
67 | def mask(self, on=int(0), off=int(0), device='cpu', size=None, dtype=None) -> torch.Tensor:
68 | if size is None:
69 | size = self.data.size()
70 | out_tensor = torch.zeros(*size, dtype=dtype)
71 | # TODO this can be done more efficiently
72 | out_tensor.fill_(off)
73 |         # note to self: these are probably less efficient than explicitly populating the off values instead of the on values.
74 | if self.batch_first:
75 | for i, bl in enumerate(self.batch_sizes):
76 | out_tensor[i, :bl] = on
77 | else:
78 | for i, bl in enumerate(self.batch_sizes):
79 | out_tensor[:bl, i] = on
80 | return out_tensor.to(device)
81 |
82 | def unpad(self, other: torch.Tensor) -> List[torch.Tensor]:
83 | out = []
84 | for o, bl in zip(other, self.batch_sizes):
85 | out.append(o[:bl])
86 | return out
87 |
88 | def flip(self) -> 'PaddedSequence':
89 |         return PaddedSequence(self.data.transpose(0, 1), self.batch_sizes, not self.batch_first)
90 |
91 |
92 | def extract_embeddings(vocab: Set[str], embedding_file: str, unk_token: str = 'UNK', pad_token: str = 'PAD') -> (
93 | nn.Embedding, Dict[str, int], List[str]):
94 | vocab = vocab | set([unk_token, pad_token])
95 | if embedding_file.endswith('.bin'):
96 | WVs = KeyedVectors.load_word2vec_format(embedding_file, binary=True)
97 |
98 | word_to_vector = dict()
99 | WV_matrix = np.matrix([WVs[v] for v in WVs.vocab.keys()])
100 |
101 | if unk_token not in WVs:
102 | mean_vector = np.mean(WV_matrix, axis=0)
103 | word_to_vector[unk_token] = mean_vector
104 | if pad_token not in WVs:
105 | word_to_vector[pad_token] = np.zeros(WVs.vector_size)
106 |
107 | for v in vocab:
108 | if v in WVs:
109 | word_to_vector[v] = WVs[v]
110 |
111 | interner = dict()
112 | deinterner = list()
113 | vectors = []
114 | count = 0
115 | for word in [pad_token, unk_token] + sorted(list(word_to_vector.keys() - {unk_token, pad_token})):
116 | vector = word_to_vector[word]
117 | vectors.append(np.array(vector))
118 | interner[word] = count
119 | deinterner.append(word)
120 | count += 1
121 | vectors = torch.FloatTensor(np.array(vectors))
122 | embedding = nn.Embedding.from_pretrained(vectors, padding_idx=interner[pad_token])
123 | embedding.weight.requires_grad = False
124 | return embedding, interner, deinterner
125 | elif embedding_file.endswith('.txt'):
126 | word_to_vector = dict()
127 | vector = []
128 | with open(embedding_file, 'r') as inf:
129 | for line in inf:
130 | contents = line.strip().split()
131 | word = contents[0]
132 | vector = torch.tensor([float(v) for v in contents[1:]]).unsqueeze(0)
133 | word_to_vector[word] = vector
134 | embed_size = vector.size()
135 | if unk_token not in word_to_vector:
136 | mean_vector = torch.cat(list(word_to_vector.values()), dim=0).mean(dim=0)
137 | word_to_vector[unk_token] = mean_vector.unsqueeze(0)
138 | if pad_token not in word_to_vector:
139 | word_to_vector[pad_token] = torch.zeros(embed_size)
140 | interner = dict()
141 | deinterner = list()
142 | vectors = []
143 | count = 0
144 | for word in [pad_token, unk_token] + sorted(list(word_to_vector.keys() - {unk_token, pad_token})):
145 | vector = word_to_vector[word]
146 | vectors.append(vector)
147 | interner[word] = count
148 | deinterner.append(word)
149 | count += 1
150 | vectors = torch.cat(vectors, dim=0)
151 | embedding = nn.Embedding.from_pretrained(vectors, padding_idx=interner[pad_token])
152 | embedding.weight.requires_grad = False
153 | return embedding, interner, deinterner
154 | else:
155 | raise ValueError("Unable to open embeddings file {}".format(embedding_file))
156 |
--------------------------------------------------------------------------------
/BERT_rationale_benchmark/models/pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/BERT_rationale_benchmark/models/pipeline/__init__.py
--------------------------------------------------------------------------------
/BERT_rationale_benchmark/models/pipeline/pipeline_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import random
5 | import os
6 |
7 | from itertools import chain
8 | from typing import Set
9 |
10 | import numpy as np
11 | import torch
12 |
13 | from rationale_benchmark.utils import (
14 | write_jsonl,
15 | load_datasets,
16 | load_documents,
17 | intern_documents,
18 | intern_annotations
19 | )
20 | from rationale_benchmark.models.mlp import (
21 | AttentiveClassifier,
22 | BahadanauAttention,
23 | RNNEncoder,
24 | WordEmbedder
25 | )
26 | from rationale_benchmark.models.model_utils import extract_embeddings
27 | from rationale_benchmark.models.pipeline.evidence_identifier import train_evidence_identifier
28 | from rationale_benchmark.models.pipeline.evidence_classifier import train_evidence_classifier
29 | from rationale_benchmark.models.pipeline.pipeline_utils import decode
30 |
31 | logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s')
32 | # let's make this more or less deterministic (not resistant to restarts)
33 | random.seed(12345)
34 | np.random.seed(67890)
35 | torch.manual_seed(10111213)
36 | torch.backends.cudnn.deterministic = True
37 | torch.backends.cudnn.benchmark = False
38 |
39 |
40 | def initialize_models(params: dict, vocab: Set[str], batch_first: bool, unk_token='UNK'):
41 | # TODO this is obviously asking for some sort of dependency injection. implement if it saves me time.
42 | if 'embedding_file' in params['embeddings']:
43 | embeddings, word_interner, de_interner = extract_embeddings(vocab, params['embeddings']['embedding_file'], unk_token=unk_token)
44 | if torch.cuda.is_available():
45 | embeddings = embeddings.cuda()
46 | else:
47 | raise ValueError("No 'embedding_file' found in params!")
48 | word_embedder = WordEmbedder(embeddings, params['embeddings']['dropout'])
49 | query_encoder = RNNEncoder(word_embedder,
50 | batch_first=batch_first,
51 | condition=False,
52 | attention_mechanism=BahadanauAttention(word_embedder.output_dimension))
53 | document_encoder = RNNEncoder(word_embedder,
54 | batch_first=batch_first,
55 | condition=True,
56 | attention_mechanism=BahadanauAttention(word_embedder.output_dimension,
57 | query_size=query_encoder.output_dimension))
58 | evidence_identifier = AttentiveClassifier(document_encoder,
59 | query_encoder,
60 | 2,
61 | params['evidence_identifier']['mlp_size'],
62 | params['evidence_identifier']['dropout'])
63 | query_encoder = RNNEncoder(word_embedder,
64 | batch_first=batch_first,
65 | condition=False,
66 | attention_mechanism=BahadanauAttention(word_embedder.output_dimension))
67 | document_encoder = RNNEncoder(word_embedder,
68 | batch_first=batch_first,
69 | condition=True,
70 | attention_mechanism=BahadanauAttention(word_embedder.output_dimension,
71 | query_size=query_encoder.output_dimension))
72 | evidence_classes = dict((y,x) for (x,y) in enumerate(params['evidence_classifier']['classes']))
73 | evidence_classifier = AttentiveClassifier(document_encoder,
74 | query_encoder,
75 | len(evidence_classes),
76 | params['evidence_classifier']['mlp_size'],
77 | params['evidence_classifier']['dropout'])
78 | return evidence_identifier, evidence_classifier, word_interner, de_interner, evidence_classes
79 |
80 |
81 | def main():
82 | parser = argparse.ArgumentParser(description="""Trains a pipeline model.
83 |
84 |     Step 1 is evidence identification, that is, identifying whether a given sentence is evidence or not.
85 |     Step 2 is evidence classification, that is, given an evidence sentence, classifying the final outcome for the end task (e.g. sentiment or significance).
86 |
87 | These models should be separated into two separate steps, but at the moment:
88 | * prep data (load, intern documents, load json)
89 | * convert data for evidence identification - in the case of training data we take all the positives and sample some negatives
90 | * side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a broader sampling of negative values.
91 | * train evidence identification
92 | * convert data for evidence classification - take all rationales + decisions and use this as input
93 | * train evidence classification
94 | * decode first the evidence, then run classification for each split
95 |
96 | """, formatter_class=argparse.RawTextHelpFormatter)
97 | parser.add_argument('--data_dir', dest='data_dir', required=True,
98 | help='Which directory contains a {train,val,test}.jsonl file?')
99 | parser.add_argument('--output_dir', dest='output_dir', required=True,
100 | help='Where shall we write intermediate models + final data to?')
101 | parser.add_argument('--model_params', dest='model_params', required=True,
102 |                         help='JSON file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.)')
103 | args = parser.parse_args()
104 | BATCH_FIRST = True
105 |
106 | with open(args.model_params, 'r') as fp:
107 | logging.debug(f'Loading model parameters from {args.model_params}')
108 | model_params = json.load(fp)
109 | train, val, test = load_datasets(args.data_dir)
110 | docids = set(e.docid for e in chain.from_iterable(chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test)))))
111 | documents = load_documents(args.data_dir, docids)
112 | document_vocab = set(chain.from_iterable(chain.from_iterable(documents.values())))
113 | annotation_vocab = set(chain.from_iterable(e.query.split() for e in chain(train, val, test)))
114 | logging.debug(f'Loaded {len(documents)} documents with {len(document_vocab)} unique words')
115 | # this ignores the case where annotations don't align perfectly with token boundaries, but this isn't that important
116 | vocab = document_vocab | annotation_vocab
117 | unk_token = 'UNK'
118 | evidence_identifier, evidence_classifier, word_interner, de_interner, evidence_classes = \
119 | initialize_models(model_params, vocab, batch_first=BATCH_FIRST, unk_token=unk_token)
120 | logging.debug(f'Including annotations, we have {len(vocab)} total words in the data, with embeddings for {len(word_interner)}')
121 | interned_documents = intern_documents(documents, word_interner, unk_token)
122 | interned_train = intern_annotations(train, word_interner, unk_token)
123 | interned_val = intern_annotations(val, word_interner, unk_token)
124 | interned_test = intern_annotations(test, word_interner, unk_token)
125 | assert BATCH_FIRST # for correctness of the split dimension for DataParallel
126 | evidence_identifier, evidence_ident_results = train_evidence_identifier(evidence_identifier.cuda(),
127 | args.output_dir, interned_train,
128 | interned_val,
129 | interned_documents,
130 | model_params,
131 | tensorize_model_inputs=True)
132 | evidence_classifier, evidence_class_results = train_evidence_classifier(evidence_classifier.cuda(),
133 | args.output_dir,
134 | interned_train,
135 | interned_val,
136 | interned_documents,
137 | model_params,
138 | class_interner=evidence_classes,
139 | tensorize_model_inputs=True)
140 | pipeline_batch_size = min([model_params['evidence_classifier']['batch_size'],
141 | model_params['evidence_identifier']['batch_size']])
142 | pipeline_results, train_decoded, val_decoded, test_decoded = decode(evidence_identifier,
143 | evidence_classifier,
144 | interned_train,
145 | interned_val,
146 | interned_test,
147 | interned_documents,
148 | evidence_classes,
149 | pipeline_batch_size,
150 | tensorize_model_inputs=True)
151 | write_jsonl(train_decoded, os.path.join(args.output_dir, 'train_decoded.jsonl'))
152 | write_jsonl(val_decoded, os.path.join(args.output_dir, 'val_decoded.jsonl'))
153 | write_jsonl(test_decoded, os.path.join(args.output_dir, 'test_decoded.jsonl'))
154 | with open(os.path.join(args.output_dir, 'identifier_results.json'), 'w') as ident_output, \
155 | open(os.path.join(args.output_dir, 'classifier_results.json'), 'w') as class_output:
156 | ident_output.write(json.dumps(evidence_ident_results))
157 | class_output.write(json.dumps(evidence_class_results))
158 | for k, v in pipeline_results.items():
159 | if type(v) is dict:
160 | for k1, v1 in v.items():
161 | logging.info(f'Pipeline results for {k}, {k1}={v1}')
162 | else:
163 | logging.info(f'Pipeline results {k}\t={v}')
164 |
165 |
166 | if __name__ == '__main__':
167 | main()
168 |
--------------------------------------------------------------------------------
/BERT_rationale_benchmark/models/sequence_taggers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from typing import List, Tuple, Any
5 |
6 | from transformers import BertModel
7 |
8 | from rationale_benchmark.models.model_utils import PaddedSequence
9 |
10 |
11 | class BertTagger(nn.Module):
12 | def __init__(self,
13 | bert_dir: str,
14 | pad_token_id: int,
15 | cls_token_id: int,
16 | sep_token_id: int,
17 | max_length: int=512,
18 | use_half_precision=True):
19 | super(BertTagger, self).__init__()
20 | self.sep_token_id = sep_token_id
21 | self.cls_token_id = cls_token_id
22 | self.pad_token_id = pad_token_id
23 | self.max_length = max_length
24 | bert = BertModel.from_pretrained(bert_dir)
25 | if use_half_precision:
26 | import apex
27 | bert = bert.half()
28 | self.bert = bert
29 | self.relevance_tagger = nn.Sequential(
30 | nn.Linear(self.bert.config.hidden_size, 1),
31 | nn.Sigmoid()
32 | )
33 |
34 | def forward(self,
35 | query: List[torch.tensor],
36 | docids: List[Any],
37 | document_batch: List[torch.tensor],
38 | aggregate_spans: List[Tuple[int, int]]):
39 | assert len(query) == len(document_batch)
40 | # note about device management: since distributed training is enabled, the inputs to this module can be on
41 | # *any* device (preferably cpu, since we wrap and unwrap the module) we want to keep these params on the
42 | # input device (assuming CPU) for as long as possible for cheap memory access
43 | target_device = next(self.parameters()).device
44 | #cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
45 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
46 | input_tensors = []
47 | query_lengths = []
48 | for q, d in zip(query, document_batch):
49 | if len(q) + len(d) + 1 > self.max_length:
50 | d = d[:(self.max_length - len(q) - 1)]
51 | input_tensors.append(torch.cat([q, sep_token, d]))
52 | query_lengths.append(q.size()[0])
53 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, device=target_device)
54 | outputs = self.bert(bert_input.data, attention_mask=bert_input.mask(on=0.0, off=float('-inf'), dtype=torch.float, device=target_device))
55 | hidden = outputs[0]
56 | classes = self.relevance_tagger(hidden)
57 | ret = []
58 | for ql, cls, doc in zip(query_lengths, classes, document_batch):
59 | start = ql + 1
60 | end = start + len(doc)
61 | ret.append(cls[ql + 1:end])
62 | return PaddedSequence.autopad(ret, batch_first=True, padding_value=0, device=target_device).data.squeeze(dim=-1)
63 |
--------------------------------------------------------------------------------
/BERT_rationale_benchmark/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from dataclasses import dataclass, asdict, is_dataclass
5 | from itertools import chain
6 | from typing import Dict, List, Set, Tuple, Union, FrozenSet
7 |
8 |
9 | @dataclass(eq=True, frozen=True)
10 | class Evidence:
11 | """
12 | (docid, start_token, end_token) form the only official Evidence; sentence level annotations are for convenience.
13 | Args:
14 | text: Some representation of the evidence text
15 | docid: Some identifier for the document
16 | start_token: The canonical start token, inclusive
17 | end_token: The canonical end token, exclusive
18 | start_sentence: Best guess start sentence, inclusive
19 | end_sentence: Best guess end sentence, exclusive
20 | """
21 | text: Union[str, Tuple[int], Tuple[str]]
22 | docid: str
23 | start_token: int = -1
24 | end_token: int = -1
25 | start_sentence: int = -1
26 | end_sentence: int = -1
27 |
28 |
29 | @dataclass(eq=True, frozen=True)
30 | class Annotation:
31 | """
32 | Args:
33 | annotation_id: unique ID for this annotation element
34 | query: some representation of a query string
35 | evidences: a set of "evidence groups".
36 | Each evidence group is:
37 | * sufficient to respond to the query (or justify an answer)
38 | * composed of one or more Evidences
39 | * may have multiple documents in it (depending on the dataset)
40 | - e-snli has multiple documents
41 | - other datasets do not
42 | classification: str
43 | query_type: Optional str, additional information about the query
44 | docids: a set of docids in which one may find evidence.
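    Illustrative example (a minimal sketch; the field values here are invented):
        Annotation(annotation_id='review_001',
                   query='What is the sentiment of this review?',
                   evidences=frozenset({(Evidence(text='an absolute masterpiece',
                                                  docid='review_001',
                                                  start_token=10, end_token=13),)}),
                   classification='POS')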
45 | """
46 | annotation_id: str
47 | query: Union[str, Tuple[int]]
48 | evidences: Union[Set[Tuple[Evidence]], FrozenSet[Tuple[Evidence]]]
49 | classification: str
50 | query_type: str = None
51 | docids: Set[str] = None
52 |
53 | def all_evidences(self) -> Tuple[Evidence]:
54 | return tuple(list(chain.from_iterable(self.evidences)))
55 |
56 |
57 | def annotations_to_jsonl(annotations, output_file):
58 | with open(output_file, 'w') as of:
59 | for ann in sorted(annotations, key=lambda x: x.annotation_id):
60 | as_json = _annotation_to_dict(ann)
61 | as_str = json.dumps(as_json, sort_keys=True)
62 | of.write(as_str)
63 | of.write('\n')
64 |
65 |
66 | def _annotation_to_dict(dc):
67 | # convenience method
68 | if is_dataclass(dc):
69 | d = asdict(dc)
70 | ret = dict()
71 | for k, v in d.items():
72 | ret[k] = _annotation_to_dict(v)
73 | return ret
74 | elif isinstance(dc, dict):
75 | ret = dict()
76 | for k, v in dc.items():
77 | k = _annotation_to_dict(k)
78 | v = _annotation_to_dict(v)
79 | ret[k] = v
80 | return ret
81 | elif isinstance(dc, str):
82 | return dc
83 | elif isinstance(dc, (set, frozenset, list, tuple)):
84 | ret = []
85 | for x in dc:
86 | ret.append(_annotation_to_dict(x))
87 | return tuple(ret)
88 | else:
89 | return dc
90 |
91 |
92 | def load_jsonl(fp: str) -> List[dict]:
93 | ret = []
94 | with open(fp, 'r') as inf:
95 | for line in inf:
96 | content = json.loads(line)
97 | ret.append(content)
98 | return ret
99 |
100 |
101 | def write_jsonl(jsonl, output_file):
102 | with open(output_file, 'w') as of:
103 | for js in jsonl:
104 | as_str = json.dumps(js, sort_keys=True)
105 | of.write(as_str)
106 | of.write('\n')
107 |
108 |
109 | def annotations_from_jsonl(fp: str) -> List[Annotation]:
110 | ret = []
111 | with open(fp, 'r') as inf:
112 | for line in inf:
113 | content = json.loads(line)
114 | ev_groups = []
115 | for ev_group in content['evidences']:
116 | ev_group = tuple([Evidence(**ev) for ev in ev_group])
117 | ev_groups.append(ev_group)
118 | content['evidences'] = frozenset(ev_groups)
119 | ret.append(Annotation(**content))
120 | return ret
121 |
122 |
123 | def load_datasets(data_dir: str) -> Tuple[List[Annotation], List[Annotation], List[Annotation]]:
124 | """Loads a training, validation, and test dataset
125 |
126 | Each dataset is assumed to have been serialized by annotations_to_jsonl,
127 | that is it is a list of json-serialized Annotation instances.
128 | """
129 | train_data = annotations_from_jsonl(os.path.join(data_dir, 'train.jsonl'))
130 | val_data = annotations_from_jsonl(os.path.join(data_dir, 'val.jsonl'))
131 | test_data = annotations_from_jsonl(os.path.join(data_dir, 'test.jsonl'))
132 | return train_data, val_data, test_data
133 |
134 |
135 | def load_documents(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]:
136 | """Loads a subset of available documents from disk.
137 |
138 | Each document is assumed to be serialized as newline ('\n') separated sentences.
139 | Each sentence is assumed to be space (' ') joined tokens.
140 | """
141 | if os.path.exists(os.path.join(data_dir, 'docs.jsonl')):
142 | assert not os.path.exists(os.path.join(data_dir, 'docs'))
143 | return load_documents_from_file(data_dir, docids)
144 |
145 | docs_dir = os.path.join(data_dir, 'docs')
146 | res = dict()
147 | if docids is None:
148 | docids = sorted(os.listdir(docs_dir))
149 | else:
150 | docids = sorted(set(str(d) for d in docids))
151 | for d in docids:
152 | with open(os.path.join(docs_dir, d), 'r') as inf:
153 | res[d] = inf.read()
154 | return res
155 |
156 |
157 | def load_flattened_documents(data_dir: str, docids: Set[str]) -> Dict[str, List[str]]:
158 | """Loads a subset of available documents from disk.
159 |
160 | Returns a tokenized version of the document.
161 | """
162 | unflattened_docs = load_documents(data_dir, docids)
163 | flattened_docs = dict()
164 | for doc, unflattened in unflattened_docs.items():
165 | flattened_docs[doc] = list(chain.from_iterable(unflattened))
166 | return flattened_docs
167 |
168 |
169 | def intern_documents(documents: Dict[str, List[List[str]]], word_interner: Dict[str, int], unk_token: str):
170 | """
171 | Replaces every word with its index in an embeddings file.
172 |
173 | If a word is not found, uses the unk_token instead
174 | """
175 | ret = dict()
176 | unk = word_interner[unk_token]
177 | for docid, sentences in documents.items():
178 | ret[docid] = [[word_interner.get(w, unk) for w in s] for s in sentences]
179 | return ret
180 |
181 |
182 | def intern_annotations(annotations: List[Annotation], word_interner: Dict[str, int], unk_token: str):
183 | ret = []
184 | for ann in annotations:
185 | ev_groups = []
186 | for ev_group in ann.evidences:
187 | evs = []
188 | for ev in ev_group:
189 | evs.append(Evidence(
190 | text=tuple([word_interner.get(t, word_interner[unk_token]) for t in ev.text.split()]),
191 | docid=ev.docid,
192 | start_token=ev.start_token,
193 | end_token=ev.end_token,
194 | start_sentence=ev.start_sentence,
195 | end_sentence=ev.end_sentence))
196 | ev_groups.append(tuple(evs))
197 | ret.append(Annotation(annotation_id=ann.annotation_id,
198 | query=tuple([word_interner.get(t, word_interner[unk_token]) for t in ann.query.split()]),
199 | evidences=frozenset(ev_groups),
200 | classification=ann.classification,
201 | query_type=ann.query_type))
202 | return ret
203 |
204 |
205 | def load_documents_from_file(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]:
206 | """Loads a subset of available documents from 'docs.jsonl' file on disk.
207 |
208 | Each document is assumed to be serialized as newline ('\n') separated sentences.
209 | Each sentence is assumed to be space (' ') joined tokens.
210 | """
211 | docs_file = os.path.join(data_dir, 'docs.jsonl')
212 | documents = load_jsonl(docs_file)
213 | documents = {doc['docid']: doc['document'] for doc in documents}
214 | # res = dict()
215 | # if docids is None:
216 | # docids = sorted(list(documents.keys()))
217 | # else:
218 | # docids = sorted(set(str(d) for d in docids))
219 | # for d in docids:
220 | # lines = documents[d].split('\n')
221 | # tokenized = [line.strip().split(' ') for line in lines]
222 | # res[d] = tokenized
223 | return documents
224 |
--------------------------------------------------------------------------------
/DeiT.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/DeiT.PNG
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Hila Chefer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch Implementation of [Transformer Interpretability Beyond Attention Visualization](https://arxiv.org/abs/2012.09838) [CVPR 2021]
2 |
3 | #### Check out our new advancements: [Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers](https://github.com/hila-chefer/Transformer-MM-Explainability)!
4 | Faster, more general, and can be applied to *any* type of attention!
5 | Among the features:
6 | * We remove LRP for a simple and quick solution, and prove that the great results from our first paper still hold!
7 | * We expand our work to *any* type of Transformer: not just self-attention based encoders, but also co-attention encoders and encoder-decoders!
8 | * We show that VQA models can actually understand both image and text and make connections!
9 | * We use a DETR object detector and create segmentation masks from our explanations!
10 | * We provide a colab notebook with all the examples. You can very easily add images and questions of your own!
11 |
12 |
13 |
14 |
15 |
16 | ---
17 | ## ViT explainability notebook:
18 | [](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/Transformer_explainability.ipynb)
19 |
20 | ## BERT explainability notebook:
21 | [](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb)
22 | ---
23 |
24 | ## Updates
25 | April 5 2021: Check out this new [post](https://analyticsindiamag.com/compute-relevancy-of-transformer-networks-via-novel-interpretable-transformer/) about our paper! A great resource for understanding the main concepts behind our work.
26 |
27 | March 15 2021: [A Colab notebook for BERT for sentiment analysis added!](https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb)
28 |
29 | Feb 28 2021: Our paper was accepted to CVPR 2021!
30 |
31 | Feb 17 2021: [A Colab notebook with all examples added!](https://github.com/hila-chefer/Transformer-Explainability/blob/main/Transformer_explainability.ipynb)
32 |
33 | Jan 5 2021: [A Jupyter notebook for DeiT added!](https://github.com/hila-chefer/Transformer-Explainability/blob/main/DeiT_example.ipynb)
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | ## Introduction
42 | Official implementation of [Transformer Interpretability Beyond Attention Visualization](https://arxiv.org/abs/2012.09838).
43 |
44 | We introduce a novel method which allows visualizing the classifications made by a Transformer-based model, for both vision and NLP tasks.
45 | Our method also allows visualizing explanations per class.
46 |
47 |
48 |
49 |
50 | The method consists of 3 phases (a simplified code sketch follows the list):
51 |
52 | 1. Calculating relevance for each attention matrix using our novel formulation of LRP.
53 |
54 | 2. Backpropagation of gradients for each attention matrix w.r.t. the visualized class. Gradients are used to average attention heads.
55 |
56 | 3. Layer aggregation with rollout.
57 |
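Below is a simplified sketch of phases 2 and 3 only, built from the hooks in `baselines/ViT/ViT_new.py` and the `compute_rollout_attention` helper in `baselines/ViT/ViT_explanation_generator.py`. Phase 1 (the LRP relevance propagation) is skipped here and the raw attention maps are used in its place, so treat this as an illustration of the pipeline rather than the full method (see `baselines/ViT/ViT_LRP.py` and the notebooks for that):

```
import torch
from baselines.ViT.ViT_new import vit_base_patch16_224
from baselines.ViT.ViT_explanation_generator import compute_rollout_attention

model = vit_base_patch16_224(pretrained=True).eval()
image = torch.randn(1, 3, 224, 224)          # stand-in for a preprocessed input image

# Phase 2: backpropagate w.r.t. the visualized class and record attention gradients.
logits = model(image, register_hook=True)
one_hot = torch.zeros_like(logits)
one_hot[0, logits.argmax(dim=-1)] = 1.0
(one_hot * logits).sum().backward()

weighted = []
for blk in model.blocks:
    attn = blk.attn.get_attention_map()       # (batch, heads, tokens, tokens)
    grad = blk.attn.get_attn_gradients()
    # gradients are used to weight and average the attention heads
    weighted.append((grad * attn).clamp(min=0).mean(dim=1))

# Phase 3: aggregate the per-layer maps with rollout.
rollout = compute_rollout_attention(weighted, start_layer=1)
relevance = rollout[:, 0, 1:].detach()        # (1, 196): one score per 16x16 patch w.r.t. the CLS token
```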
58 | Please see our [Jupyter notebook](https://github.com/hila-chefer/Transformer-Explainability/blob/main/example.ipynb), where you can run the two class-specific examples from the paper.
59 |
60 |
61 | 
62 |
63 | To add another input image, simply add the image to the [samples folder](https://github.com/hila-chefer/Transformer-Explainability/tree/main/samples) and use the `generate_visualization` function for your selected class of interest (using `class_index={class_idx}`); not specifying the index will visualize the top predicted class.
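For example, the following minimal sketch produces a 224x224 relevance map for one of the sample images. It mirrors the preprocessing and upsampling in `baselines/ViT/generate_visualizations.py`; the choice of sample file and the use of a GPU are assumptions, and `index` plays the role of `class_index` here:

```
import torch
from PIL import Image
from torchvision import transforms
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_explanation_generator import LRP

model = vit_LRP(pretrained=True).cuda().eval()
attribution_generator = LRP(model)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
image = transform(Image.open('samples/catdog.png').convert('RGB')).unsqueeze(0).cuda()

# index=None explains the top predicted class; pass an ImageNet class index otherwise.
relevance = attribution_generator.generate_LRP(
    image, method="transformer_attribution", index=None).detach()

heatmap = relevance.reshape(1, 1, 14, 14)
heatmap = torch.nn.functional.interpolate(heatmap, scale_factor=16, mode='bilinear')
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())  # 224x224 map in [0, 1]
```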
64 |
65 | ## Credits
66 | ViT implementation is based on:
67 | - https://github.com/rwightman/pytorch-image-models
68 | - https://github.com/lucidrains/vit-pytorch
69 | - pretrained weights from: https://github.com/google-research/vision_transformer
70 |
71 | BERT implementation is taken from the huggingface Transformers library:
72 | https://huggingface.co/transformers/
73 |
74 | ERASER benchmark code adapted from the ERASER GitHub implementation: https://github.com/jayded/eraserbenchmark
75 |
76 | Text visualizations in supplementary were created using TAHV heatmap generator for text: https://github.com/jiesutd/Text-Attention-Heatmap-Visualization
77 |
78 | ## Reproducing results on ViT
79 |
80 | ### Section A. Segmentation Results
81 |
82 | Example:
83 | ```
84 | CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/imagenet_seg_eval.py --method transformer_attribution --imagenet-seg-path /path/to/gtsegs_ijcv.mat
85 |
86 | ```
87 | [Link to download dataset](http://calvin-vision.net/bigstuff/proj-imagenet/data/gtsegs_ijcv.mat).
88 |
89 | In the example above we run a segmentation test with our method. Note that you can choose which method to run using the `--method` argument.
90 | You must provide a path to the ImageNet segmentation data via `--imagenet-seg-path`.
91 |
92 | ### Section B. Perturbation Results
93 |
94 | Example:
95 | ```
96 | CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/generate_visualizations.py --method transformer_attribution --imagenet-validation-path /path/to/imagenet_validation_directory
97 | ```
98 |
99 | Note that you can choose to visualize the target class or the top predicted class using the `--vis-class` argument.
100 |
101 | Now, to run the perturbation test, run the following command:
102 | ```
103 | CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/pertubation_eval_from_hdf5.py --method transformer_attribution
104 | ```
105 |
106 | Note that you can use the `--neg` argument to run either positive or negative perturbation.
107 |
108 | ## Reproducing results on BERT
109 |
110 | 1. Download the pretrained weights:
111 |
112 | - Download `classifier.zip` from https://drive.google.com/file/d/1kGMTr69UWWe70i-o2_JfjmWDQjT66xwQ/view?usp=sharing
113 | - mkdir -p `./bert_models/movies`
114 | - unzip classifier.zip -d ./bert_models/movies/
115 |
116 | 2. Download the dataset pkl file:
117 |
118 | - Download `preprocessed.pkl` from https://drive.google.com/file/d/1-gfbTj6D87KIm_u1QMHGLKSL3e93hxBH/view?usp=sharing
119 | - mv preprocessed.pkl ./bert_models/movies
120 |
121 | 3. Download the dataset:
122 |
123 | - Download `movies.zip` from https://drive.google.com/file/d/11faFLGkc0hkw3wrGTYJBr1nIvkRb189F/view?usp=sharing
124 | - unzip movies.zip -d ./data/
125 |
126 | 4. Now you can run the model.
127 |
128 | Example:
129 | ```
130 | CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 BERT_rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/movies/ --output_dir bert_models/movies/ --model_params BERT_params/movies_bert.json
131 | ```
132 | To control which algorithm is used for explanations, change the `method` variable in `BERT_rationale_benchmark/models/pipeline/bert_pipeline.py` (defaults to 'transformer_attribution', which is our method).
133 | Running this command will create a directory for the method in `bert_models/movies/`.
134 |
135 | To run the F1 test with top-k, run the following command:
136 | ```
137 | PYTHONPATH=./:$PYTHONPATH python3 BERT_rationale_benchmark/metrics.py --data_dir data/movies/ --split test --results bert_models/movies//identifier_results_k.json
138 | ```
139 |
140 | In addition, `.tex` files containing the explanations extracted for each example will be created in the method directory. These correspond to our visualizations in the supplementary material.
141 |
142 | ## Citing our paper
143 | If you make use of our work, please cite our paper:
144 | ```
145 | @InProceedings{Chefer_2021_CVPR,
146 | author = {Chefer, Hila and Gur, Shir and Wolf, Lior},
147 | title = {Transformer Interpretability Beyond Attention Visualization},
148 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
149 | month = {June},
150 | year = {2021},
151 | pages = {782-791}
152 | }
153 | ```
154 |
--------------------------------------------------------------------------------
/baselines/ViT/ViT_explanation_generator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import numpy as np
4 | from numpy import *
5 |
6 | # compute rollout between attention layers
7 | def compute_rollout_attention(all_layer_matrices, start_layer=0):
8 | # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
9 | num_tokens = all_layer_matrices[0].shape[1]
10 | batch_size = all_layer_matrices[0].shape[0]
11 | eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
12 | all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
13 | matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
14 | for i in range(len(all_layer_matrices))]
15 | joint_attention = matrices_aug[start_layer]
16 | for i in range(start_layer+1, len(matrices_aug)):
17 | joint_attention = matrices_aug[i].bmm(joint_attention)
18 | return joint_attention
19 |
20 | class LRP:
21 | def __init__(self, model):
22 | self.model = model
23 | self.model.eval()
24 |
25 | def generate_LRP(self, input, index=None, method="transformer_attribution", is_ablation=False, start_layer=0):
26 | output = self.model(input)
27 | kwargs = {"alpha": 1}
28 |         if index is None:
29 | index = np.argmax(output.cpu().data.numpy(), axis=-1)
30 |
31 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
32 | one_hot[0, index] = 1
33 | one_hot_vector = one_hot
34 | one_hot = torch.from_numpy(one_hot).requires_grad_(True)
35 | one_hot = torch.sum(one_hot.cuda() * output)
36 |
37 | self.model.zero_grad()
38 | one_hot.backward(retain_graph=True)
39 |
40 | return self.model.relprop(torch.tensor(one_hot_vector).to(input.device), method=method, is_ablation=is_ablation,
41 | start_layer=start_layer, **kwargs)
42 |
43 |
44 |
45 | class Baselines:
46 | def __init__(self, model):
47 | self.model = model
48 | self.model.eval()
49 |
50 | def generate_cam_attn(self, input, index=None):
51 | output = self.model(input.cuda(), register_hook=True)
52 |         if index is None:
53 | index = np.argmax(output.cpu().data.numpy())
54 |
55 | one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
56 | one_hot[0][index] = 1
57 | one_hot = torch.from_numpy(one_hot).requires_grad_(True)
58 | one_hot = torch.sum(one_hot.cuda() * output)
59 |
60 | self.model.zero_grad()
61 | one_hot.backward(retain_graph=True)
62 | #################### attn
63 | grad = self.model.blocks[-1].attn.get_attn_gradients()
64 | cam = self.model.blocks[-1].attn.get_attention_map()
65 | cam = cam[0, :, 0, 1:].reshape(-1, 14, 14)
66 | grad = grad[0, :, 0, 1:].reshape(-1, 14, 14)
67 | grad = grad.mean(dim=[1, 2], keepdim=True)
68 | cam = (cam * grad).mean(0).clamp(min=0)
69 | cam = (cam - cam.min()) / (cam.max() - cam.min())
70 |
71 | return cam
72 | #################### attn
73 |
74 | def generate_rollout(self, input, start_layer=0):
75 | self.model(input)
76 | blocks = self.model.blocks
77 | all_layer_attentions = []
78 | for blk in blocks:
79 | attn_heads = blk.attn.get_attention_map()
80 | avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
81 | all_layer_attentions.append(avg_heads)
82 | rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
83 | return rollout[:,0, 1:]
84 |
--------------------------------------------------------------------------------
/baselines/ViT/ViT_new.py:
--------------------------------------------------------------------------------
1 | """ Vision Transformer (ViT) in PyTorch
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import torch
5 | import torch.nn as nn
6 | from functools import partial
7 | from einops import rearrange
8 |
9 | from baselines.ViT.helpers import load_pretrained
10 | from baselines.ViT.weight_init import trunc_normal_
11 | from baselines.ViT.layer_helpers import to_2tuple
12 |
13 |
14 | def _cfg(url='', **kwargs):
15 | return {
16 | 'url': url,
17 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
18 | 'crop_pct': .9, 'interpolation': 'bicubic',
19 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
20 | **kwargs
21 | }
22 |
23 |
24 | default_cfgs = {
25 | # patch models
26 | 'vit_small_patch16_224': _cfg(
27 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
28 | ),
29 | 'vit_base_patch16_224': _cfg(
30 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
31 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
32 | ),
33 | 'vit_large_patch16_224': _cfg(
34 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
35 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
36 | }
37 |
38 | class Mlp(nn.Module):
39 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
40 | super().__init__()
41 | out_features = out_features or in_features
42 | hidden_features = hidden_features or in_features
43 | self.fc1 = nn.Linear(in_features, hidden_features)
44 | self.act = act_layer()
45 | self.fc2 = nn.Linear(hidden_features, out_features)
46 | self.drop = nn.Dropout(drop)
47 |
48 | def forward(self, x):
49 | x = self.fc1(x)
50 | x = self.act(x)
51 | x = self.drop(x)
52 | x = self.fc2(x)
53 | x = self.drop(x)
54 | return x
55 |
56 |
57 | class Attention(nn.Module):
58 | def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
59 | super().__init__()
60 | self.num_heads = num_heads
61 | head_dim = dim // num_heads
62 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
63 | self.scale = head_dim ** -0.5
64 |
65 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
66 | self.attn_drop = nn.Dropout(attn_drop)
67 | self.proj = nn.Linear(dim, dim)
68 | self.proj_drop = nn.Dropout(proj_drop)
69 |
70 | self.attn_gradients = None
71 | self.attention_map = None
72 |
73 | def save_attn_gradients(self, attn_gradients):
74 | self.attn_gradients = attn_gradients
75 |
76 | def get_attn_gradients(self):
77 | return self.attn_gradients
78 |
79 | def save_attention_map(self, attention_map):
80 | self.attention_map = attention_map
81 |
82 | def get_attention_map(self):
83 | return self.attention_map
84 |
85 | def forward(self, x, register_hook=False):
86 | b, n, _, h = *x.shape, self.num_heads
87 |
88 | # self.save_output(x)
89 | # x.register_hook(self.save_output_grad)
90 |
91 | qkv = self.qkv(x)
92 | q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
93 |
94 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
95 |
96 | attn = dots.softmax(dim=-1)
97 | attn = self.attn_drop(attn)
98 |
99 | out = torch.einsum('bhij,bhjd->bhid', attn, v)
100 |
101 | self.save_attention_map(attn)
102 | if register_hook:
103 | attn.register_hook(self.save_attn_gradients)
104 |
105 | out = rearrange(out, 'b h n d -> b n (h d)')
106 | out = self.proj(out)
107 | out = self.proj_drop(out)
108 | return out
109 |
110 |
111 | class Block(nn.Module):
112 |
113 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
114 | super().__init__()
115 | self.norm1 = norm_layer(dim)
116 | self.attn = Attention(
117 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
118 | self.norm2 = norm_layer(dim)
119 | mlp_hidden_dim = int(dim * mlp_ratio)
120 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
121 |
122 | def forward(self, x, register_hook=False):
123 | x = x + self.attn(self.norm1(x), register_hook=register_hook)
124 | x = x + self.mlp(self.norm2(x))
125 | return x
126 |
127 |
128 | class PatchEmbed(nn.Module):
129 | """ Image to Patch Embedding
130 | """
131 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
132 | super().__init__()
133 | img_size = to_2tuple(img_size)
134 | patch_size = to_2tuple(patch_size)
135 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
136 | self.img_size = img_size
137 | self.patch_size = patch_size
138 | self.num_patches = num_patches
139 |
140 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
141 |
142 | def forward(self, x):
143 | B, C, H, W = x.shape
144 | # FIXME look at relaxing size constraints
145 | assert H == self.img_size[0] and W == self.img_size[1], \
146 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
147 | x = self.proj(x).flatten(2).transpose(1, 2)
148 | return x
149 |
150 | class VisionTransformer(nn.Module):
151 | """ Vision Transformer
152 | """
153 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
154 | num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., norm_layer=nn.LayerNorm):
155 | super().__init__()
156 | self.num_classes = num_classes
157 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
158 | self.patch_embed = PatchEmbed(
159 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
160 | num_patches = self.patch_embed.num_patches
161 |
162 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
163 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
164 | self.pos_drop = nn.Dropout(p=drop_rate)
165 |
166 | self.blocks = nn.ModuleList([
167 | Block(
168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
169 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
170 | for i in range(depth)])
171 | self.norm = norm_layer(embed_dim)
172 |
173 | # Classifier head
174 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
175 |
176 | trunc_normal_(self.pos_embed, std=.02)
177 | trunc_normal_(self.cls_token, std=.02)
178 | self.apply(self._init_weights)
179 |
180 | def _init_weights(self, m):
181 | if isinstance(m, nn.Linear):
182 | trunc_normal_(m.weight, std=.02)
183 | if isinstance(m, nn.Linear) and m.bias is not None:
184 | nn.init.constant_(m.bias, 0)
185 | elif isinstance(m, nn.LayerNorm):
186 | nn.init.constant_(m.bias, 0)
187 | nn.init.constant_(m.weight, 1.0)
188 |
189 | @torch.jit.ignore
190 | def no_weight_decay(self):
191 | return {'pos_embed', 'cls_token'}
192 |
193 | def forward(self, x, register_hook=False):
194 | B = x.shape[0]
195 | x = self.patch_embed(x)
196 |
197 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
198 | x = torch.cat((cls_tokens, x), dim=1)
199 | x = x + self.pos_embed
200 | x = self.pos_drop(x)
201 |
202 | for blk in self.blocks:
203 | x = blk(x, register_hook=register_hook)
204 |
205 | x = self.norm(x)
206 | x = x[:, 0]
207 | x = self.head(x)
208 | return x
209 |
210 |
211 | def _conv_filter(state_dict, patch_size=16):
212 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
213 | out_dict = {}
214 | for k, v in state_dict.items():
215 | if 'patch_embed.proj.weight' in k:
216 | v = v.reshape((v.shape[0], 3, patch_size, patch_size))
217 | out_dict[k] = v
218 | return out_dict
219 |
220 |
221 | def vit_base_patch16_224(pretrained=False, **kwargs):
222 | model = VisionTransformer(
223 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
224 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
225 | model.default_cfg = default_cfgs['vit_base_patch16_224']
226 | if pretrained:
227 | load_pretrained(
228 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
229 | return model
230 |
231 | def vit_large_patch16_224(pretrained=False, **kwargs):
232 | model = VisionTransformer(
233 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
234 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
235 | model.default_cfg = default_cfgs['vit_large_patch16_224']
236 | if pretrained:
237 | load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
238 | return model
239 |
--------------------------------------------------------------------------------
/baselines/ViT/generate_visualizations.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import h5py
4 |
5 | import argparse
6 |
7 | # Import saliency methods and models
8 | from misc_functions import *
9 |
10 | from ViT_explanation_generator import Baselines, LRP
11 | from ViT_new import vit_base_patch16_224
12 | from ViT_LRP import vit_base_patch16_224 as vit_LRP
13 | from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
14 |
15 | from torchvision.datasets import ImageNet
16 |
17 |
18 | def normalize(tensor,
19 | mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
20 | dtype = tensor.dtype
21 | mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
22 | std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
23 | tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
24 | return tensor
25 |
26 |
27 | def compute_saliency_and_save(args):
28 | first = True
29 | with h5py.File(os.path.join(args.method_dir, 'results.hdf5'), 'a') as f:
30 | data_cam = f.create_dataset('vis',
31 | (1, 1, 224, 224),
32 | maxshape=(None, 1, 224, 224),
33 | dtype=np.float32,
34 | compression="gzip")
35 | data_image = f.create_dataset('image',
36 | (1, 3, 224, 224),
37 | maxshape=(None, 3, 224, 224),
38 | dtype=np.float32,
39 | compression="gzip")
40 | data_target = f.create_dataset('target',
41 | (1,),
42 | maxshape=(None,),
43 | dtype=np.int32,
44 | compression="gzip")
45 | for batch_idx, (data, target) in enumerate(tqdm(sample_loader)):
46 | if first:
47 | first = False
48 | data_cam.resize(data_cam.shape[0] + data.shape[0] - 1, axis=0)
49 | data_image.resize(data_image.shape[0] + data.shape[0] - 1, axis=0)
50 | data_target.resize(data_target.shape[0] + data.shape[0] - 1, axis=0)
51 | else:
52 | data_cam.resize(data_cam.shape[0] + data.shape[0], axis=0)
53 | data_image.resize(data_image.shape[0] + data.shape[0], axis=0)
54 | data_target.resize(data_target.shape[0] + data.shape[0], axis=0)
55 |
56 | # Add data
57 | data_image[-data.shape[0]:] = data.data.cpu().numpy()
58 | data_target[-data.shape[0]:] = target.data.cpu().numpy()
59 |
60 | target = target.to(device)
61 |
62 | data = normalize(data)
63 | data = data.to(device)
64 | data.requires_grad_()
65 |
66 | index = None
67 | if args.vis_class == 'target':
68 | index = target
69 |
70 | if args.method == 'rollout':
71 | Res = baselines.generate_rollout(data, start_layer=1).reshape(data.shape[0], 1, 14, 14)
72 | # Res = Res - Res.mean()
73 |
74 | elif args.method == 'lrp':
75 | Res = lrp.generate_LRP(data, start_layer=1, index=index).reshape(data.shape[0], 1, 14, 14)
76 | # Res = Res - Res.mean()
77 |
78 | elif args.method == 'transformer_attribution':
79 | Res = lrp.generate_LRP(data, start_layer=1, method="grad", index=index).reshape(data.shape[0], 1, 14, 14)
80 | # Res = Res - Res.mean()
81 |
82 | elif args.method == 'full_lrp':
83 | Res = orig_lrp.generate_LRP(data, method="full", index=index).reshape(data.shape[0], 1, 224, 224)
84 | # Res = Res - Res.mean()
85 |
86 | elif args.method == 'lrp_last_layer':
87 | Res = orig_lrp.generate_LRP(data, method="last_layer", is_ablation=args.is_ablation, index=index) \
88 | .reshape(data.shape[0], 1, 14, 14)
89 | # Res = Res - Res.mean()
90 |
91 | elif args.method == 'attn_last_layer':
92 | Res = lrp.generate_LRP(data, method="last_layer_attn", is_ablation=args.is_ablation) \
93 | .reshape(data.shape[0], 1, 14, 14)
94 |
95 | elif args.method == 'attn_gradcam':
96 | Res = baselines.generate_cam_attn(data, index=index).reshape(data.shape[0], 1, 14, 14)
97 |
98 | if args.method != 'full_lrp' and args.method != 'input_grads':
99 | Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').cuda()
100 | Res = (Res - Res.min()) / (Res.max() - Res.min())
101 |
102 | data_cam[-data.shape[0]:] = Res.data.cpu().numpy()
103 |
104 |
105 | if __name__ == "__main__":
106 |     parser = argparse.ArgumentParser(description='Generate ViT explanation visualizations')
107 | parser.add_argument('--batch-size', type=int,
108 | default=1,
109 | help='')
110 | parser.add_argument('--method', type=str,
111 |                         default='transformer_attribution',
112 | choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
113 | 'attn_last_layer', 'attn_gradcam'],
114 | help='')
115 | parser.add_argument('--lmd', type=float,
116 | default=10,
117 | help='')
118 | parser.add_argument('--vis-class', type=str,
119 | default='top',
120 | choices=['top', 'target', 'index'],
121 | help='')
122 | parser.add_argument('--class-id', type=int,
123 | default=0,
124 | help='')
125 | parser.add_argument('--cls-agn', action='store_true',
126 | default=False,
127 | help='')
128 | parser.add_argument('--no-ia', action='store_true',
129 | default=False,
130 | help='')
131 | parser.add_argument('--no-fx', action='store_true',
132 | default=False,
133 | help='')
134 | parser.add_argument('--no-fgx', action='store_true',
135 | default=False,
136 | help='')
137 | parser.add_argument('--no-m', action='store_true',
138 | default=False,
139 | help='')
140 | parser.add_argument('--no-reg', action='store_true',
141 | default=False,
142 | help='')
143 | parser.add_argument('--is-ablation', type=bool,
144 | default=False,
145 | help='')
146 | parser.add_argument('--imagenet-validation-path', type=str,
147 | required=True,
148 | help='')
149 | args = parser.parse_args()
150 |
151 | # PATH variables
152 | PATH = os.path.dirname(os.path.abspath(__file__)) + '/'
153 | os.makedirs(os.path.join(PATH, 'visualizations'), exist_ok=True)
154 |
155 | try:
156 | os.remove(os.path.join(PATH, 'visualizations/{}/{}/results.hdf5'.format(args.method,
157 | args.vis_class)))
158 | except OSError:
159 | pass
160 |
161 |
162 | os.makedirs(os.path.join(PATH, 'visualizations/{}'.format(args.method)), exist_ok=True)
163 | if args.vis_class == 'index':
164 | os.makedirs(os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method,
165 | args.vis_class,
166 | args.class_id)), exist_ok=True)
167 | args.method_dir = os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method,
168 | args.vis_class,
169 | args.class_id))
170 | else:
171 | ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
172 | os.makedirs(os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method,
173 | args.vis_class, ablation_fold)), exist_ok=True)
174 | args.method_dir = os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method,
175 | args.vis_class, ablation_fold))
176 |
177 | cuda = torch.cuda.is_available()
178 | device = torch.device("cuda" if cuda else "cpu")
179 |
180 | # Model
181 | model = vit_base_patch16_224(pretrained=True).cuda()
182 | baselines = Baselines(model)
183 |
184 | # LRP
185 | model_LRP = vit_LRP(pretrained=True).cuda()
186 | model_LRP.eval()
187 | lrp = LRP(model_LRP)
188 |
189 | # orig LRP
190 | model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
191 | model_orig_LRP.eval()
192 | orig_lrp = LRP(model_orig_LRP)
193 |
194 | # Dataset loader for sample images
195 | transform = transforms.Compose([
196 | transforms.Resize((224, 224)),
197 | transforms.ToTensor(),
198 | ])
199 |
200 | imagenet_ds = ImageNet(args.imagenet_validation_path, split='val', download=False, transform=transform)
201 | sample_loader = torch.utils.data.DataLoader(
202 | imagenet_ds,
203 | batch_size=args.batch_size,
204 | shuffle=False,
205 | num_workers=4
206 | )
207 |
208 | compute_saliency_and_save(args)
209 |
--------------------------------------------------------------------------------
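
For orientation, here is a minimal sketch of exercising the `transformer_attribution` branch of the script above outside of `compute_saliency_and_save`. The import paths and the sample image are assumptions based on the repository layout; `method="grad"` and `start_layer=1` mirror the call made in the script, and the upsampling/normalization steps copy the post-processing above.

import torch
from PIL import Image
from torchvision import transforms

# Assumed import paths (repository root on PYTHONPATH).
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_explanation_generator import LRP

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_LRP = vit_LRP(pretrained=True).to(device).eval()
lrp = LRP(model_LRP)

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

img = preprocess(Image.open('samples/catdog.png').convert('RGB')).unsqueeze(0).to(device)
img.requires_grad_()

# One relevance value per patch token (14x14 patches for ViT-B/16 at 224x224),
# upsampled by a factor of 16 and min-max normalized, as in the script above.
rel = lrp.generate_LRP(img, start_layer=1, method="grad").reshape(1, 1, 14, 14)
rel = torch.nn.functional.interpolate(rel, scale_factor=16, mode='bilinear')
rel = (rel - rel.min()) / (rel.max() - rel.min())
print(rel.shape)  # torch.Size([1, 1, 224, 224])
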
/baselines/ViT/helpers.py:
--------------------------------------------------------------------------------
1 | """ Model creation / weight loading / state_dict helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import logging
6 | import os
7 | import math
8 | from collections import OrderedDict
9 | from copy import deepcopy
10 | from typing import Callable
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.utils.model_zoo as model_zoo
15 |
16 | _logger = logging.getLogger(__name__)
17 |
18 |
19 | def load_state_dict(checkpoint_path, use_ema=False):
20 | if checkpoint_path and os.path.isfile(checkpoint_path):
21 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
22 | state_dict_key = 'state_dict'
23 | if isinstance(checkpoint, dict):
24 | if use_ema and 'state_dict_ema' in checkpoint:
25 | state_dict_key = 'state_dict_ema'
26 | if state_dict_key and state_dict_key in checkpoint:
27 | new_state_dict = OrderedDict()
28 | for k, v in checkpoint[state_dict_key].items():
29 | # strip `module.` prefix
30 | name = k[7:] if k.startswith('module') else k
31 | new_state_dict[name] = v
32 | state_dict = new_state_dict
33 | else:
34 | state_dict = checkpoint
35 | _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
36 | return state_dict
37 | else:
38 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
39 | raise FileNotFoundError()
40 |
41 |
42 | def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
43 | state_dict = load_state_dict(checkpoint_path, use_ema)
44 | model.load_state_dict(state_dict, strict=strict)
45 |
46 |
47 | def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
48 | resume_epoch = None
49 | if os.path.isfile(checkpoint_path):
50 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
51 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
52 | if log_info:
53 | _logger.info('Restoring model state from checkpoint...')
54 | new_state_dict = OrderedDict()
55 | for k, v in checkpoint['state_dict'].items():
56 | name = k[7:] if k.startswith('module') else k
57 | new_state_dict[name] = v
58 | model.load_state_dict(new_state_dict)
59 |
60 | if optimizer is not None and 'optimizer' in checkpoint:
61 | if log_info:
62 | _logger.info('Restoring optimizer state from checkpoint...')
63 | optimizer.load_state_dict(checkpoint['optimizer'])
64 |
65 | if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
66 | if log_info:
67 | _logger.info('Restoring AMP loss scaler state from checkpoint...')
68 | loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
69 |
70 | if 'epoch' in checkpoint:
71 | resume_epoch = checkpoint['epoch']
72 | if 'version' in checkpoint and checkpoint['version'] > 1:
73 | resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
74 |
75 | if log_info:
76 | _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
77 | else:
78 | model.load_state_dict(checkpoint)
79 | if log_info:
80 | _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
81 | return resume_epoch
82 | else:
83 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
84 | raise FileNotFoundError()
85 |
86 |
87 | def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
88 | if cfg is None:
89 | cfg = getattr(model, 'default_cfg')
90 | if cfg is None or 'url' not in cfg or not cfg['url']:
91 | _logger.warning("Pretrained model URL is invalid, using random initialization.")
92 | return
93 |
94 | state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
95 |
96 | if filter_fn is not None:
97 | state_dict = filter_fn(state_dict)
98 |
99 | if in_chans == 1:
100 | conv1_name = cfg['first_conv']
101 | _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
102 | conv1_weight = state_dict[conv1_name + '.weight']
103 | # Some weights are in torch.half, ensure it's float for sum on CPU
104 | conv1_type = conv1_weight.dtype
105 | conv1_weight = conv1_weight.float()
106 | O, I, J, K = conv1_weight.shape
107 | if I > 3:
108 | assert conv1_weight.shape[1] % 3 == 0
109 | # For models with space2depth stems
110 | conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
111 | conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
112 | else:
113 | conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
114 | conv1_weight = conv1_weight.to(conv1_type)
115 | state_dict[conv1_name + '.weight'] = conv1_weight
116 | elif in_chans != 3:
117 | conv1_name = cfg['first_conv']
118 | conv1_weight = state_dict[conv1_name + '.weight']
119 | conv1_type = conv1_weight.dtype
120 | conv1_weight = conv1_weight.float()
121 | O, I, J, K = conv1_weight.shape
122 | if I != 3:
123 | _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
124 | del state_dict[conv1_name + '.weight']
125 | strict = False
126 | else:
127 | # NOTE this strategy should be better than random init, but there could be other combinations of
128 | # the original RGB input layer weights that'd work better for specific cases.
129 | _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
130 | repeat = int(math.ceil(in_chans / 3))
131 | conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
132 | conv1_weight *= (3 / float(in_chans))
133 | conv1_weight = conv1_weight.to(conv1_type)
134 | state_dict[conv1_name + '.weight'] = conv1_weight
135 |
136 | classifier_name = cfg['classifier']
137 | if num_classes == 1000 and cfg['num_classes'] == 1001:
138 | # special case for imagenet trained models with extra background class in pretrained weights
139 | classifier_weight = state_dict[classifier_name + '.weight']
140 | state_dict[classifier_name + '.weight'] = classifier_weight[1:]
141 | classifier_bias = state_dict[classifier_name + '.bias']
142 | state_dict[classifier_name + '.bias'] = classifier_bias[1:]
143 | elif num_classes != cfg['num_classes']:
144 | # completely discard fully connected for all other differences between pretrained and created model
145 | del state_dict[classifier_name + '.weight']
146 | del state_dict[classifier_name + '.bias']
147 | strict = False
148 |
149 | model.load_state_dict(state_dict, strict=strict)
150 |
151 |
152 | def extract_layer(model, layer):
153 | layer = layer.split('.')
154 | module = model
155 | if hasattr(model, 'module') and layer[0] != 'module':
156 | module = model.module
157 | if not hasattr(model, 'module') and layer[0] == 'module':
158 | layer = layer[1:]
159 | for l in layer:
160 | if hasattr(module, l):
161 | if not l.isdigit():
162 | module = getattr(module, l)
163 | else:
164 | module = module[int(l)]
165 | else:
166 | return module
167 | return module
168 |
169 |
170 | def set_layer(model, layer, val):
171 | layer = layer.split('.')
172 | module = model
173 | if hasattr(model, 'module') and layer[0] != 'module':
174 | module = model.module
175 | lst_index = 0
176 | module2 = module
177 | for l in layer:
178 | if hasattr(module2, l):
179 | if not l.isdigit():
180 | module2 = getattr(module2, l)
181 | else:
182 | module2 = module2[int(l)]
183 | lst_index += 1
184 | lst_index -= 1
185 | for l in layer[:lst_index]:
186 | if not l.isdigit():
187 | module = getattr(module, l)
188 | else:
189 | module = module[int(l)]
190 | l = layer[lst_index]
191 | setattr(module, l, val)
192 |
193 |
194 | def adapt_model_from_string(parent_module, model_string):
195 | separator = '***'
196 | state_dict = {}
197 | lst_shape = model_string.split(separator)
198 | for k in lst_shape:
199 | k = k.split(':')
200 | key = k[0]
201 | shape = k[1][1:-1].split(',')
202 | if shape[0] != '':
203 | state_dict[key] = [int(i) for i in shape]
204 |
205 | new_module = deepcopy(parent_module)
206 | for n, m in parent_module.named_modules():
207 | old_module = extract_layer(parent_module, n)
208 | if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
209 | if isinstance(old_module, Conv2dSame):
210 | conv = Conv2dSame
211 | else:
212 | conv = nn.Conv2d
213 | s = state_dict[n + '.weight']
214 | in_channels = s[1]
215 | out_channels = s[0]
216 | g = 1
217 | if old_module.groups > 1:
218 | in_channels = out_channels
219 | g = in_channels
220 | new_conv = conv(
221 | in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
222 | bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
223 | groups=g, stride=old_module.stride)
224 | set_layer(new_module, n, new_conv)
225 | if isinstance(old_module, nn.BatchNorm2d):
226 | new_bn = nn.BatchNorm2d(
227 | num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
228 | affine=old_module.affine, track_running_stats=True)
229 | set_layer(new_module, n, new_bn)
230 | if isinstance(old_module, nn.Linear):
231 | # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
232 | num_features = state_dict[n + '.weight'][1]
233 | new_fc = nn.Linear(
234 | in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
235 | set_layer(new_module, n, new_fc)
236 | if hasattr(new_module, 'num_features'):
237 | new_module.num_features = num_features
238 | new_module.eval()
239 | parent_module.eval()
240 |
241 | return new_module
242 |
243 |
244 | def adapt_model_from_file(parent_module, model_variant):
245 | adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
246 | with open(adapt_file, 'r') as f:
247 | return adapt_model_from_string(parent_module, f.read().strip())
248 |
249 |
250 | def build_model_with_cfg(
251 | model_cls: Callable,
252 | variant: str,
253 | pretrained: bool,
254 | default_cfg: dict,
255 | model_cfg: dict = None,
256 | feature_cfg: dict = None,
257 | pretrained_strict: bool = True,
258 | pretrained_filter_fn: Callable = None,
259 | **kwargs):
260 | pruned = kwargs.pop('pruned', False)
261 | features = False
262 | feature_cfg = feature_cfg or {}
263 |
264 | if kwargs.pop('features_only', False):
265 | features = True
266 | feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
267 | if 'out_indices' in kwargs:
268 | feature_cfg['out_indices'] = kwargs.pop('out_indices')
269 |
270 | model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
271 | model.default_cfg = deepcopy(default_cfg)
272 |
273 | if pruned:
274 | model = adapt_model_from_file(model, variant)
275 |
276 | if pretrained:
277 | load_pretrained(
278 | model,
279 | num_classes=kwargs.get('num_classes', 0),
280 | in_chans=kwargs.get('in_chans', 3),
281 | filter_fn=pretrained_filter_fn, strict=pretrained_strict)
282 |
283 | if features:
284 | feature_cls = FeatureListNet
285 | if 'feature_cls' in feature_cfg:
286 | feature_cls = feature_cfg.pop('feature_cls')
287 | if isinstance(feature_cls, str):
288 | feature_cls = feature_cls.lower()
289 | if 'hook' in feature_cls:
290 | feature_cls = FeatureHookNet
291 | else:
292 | assert False, f'Unknown feature class {feature_cls}'
293 | model = feature_cls(model, **feature_cfg)
294 |
295 | return model
--------------------------------------------------------------------------------
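
`build_model_with_cfg` is the factory hook that model definitions call: it instantiates the architecture, attaches a `default_cfg`, optionally adapts a pruned variant, and loads pretrained weights when a valid URL is configured. Below is a minimal sketch of such a factory; the import path, the `TinyNet` module, and the config values are placeholders (no real pretrained URL), not part of this repository.

import torch.nn as nn

from baselines.ViT.helpers import build_model_with_cfg  # assumed import path

# Placeholder config; a real entry would point 'url' at downloadable weights.
_cfg = dict(url='', num_classes=1000, first_conv='conv1', classifier='fc')


class TinyNet(nn.Module):
    def __init__(self, num_classes=1000, in_chans=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_chans, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(self.pool(self.conv1(x)).flatten(1))


def tinynet(pretrained=False, **kwargs):
    # With pretrained=True and a valid cfg['url'], load_pretrained() would be
    # invoked to fetch and, if needed, reshape the downloaded state dict.
    return build_model_with_cfg(TinyNet, 'tinynet', pretrained, default_cfg=_cfg, **kwargs)


model = tinynet(pretrained=False, num_classes=10)
print(model.default_cfg['classifier'])  # 'fc'
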
/baselines/ViT/layer_helpers.py:
--------------------------------------------------------------------------------
1 | """ Layer/Module Helpers
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from itertools import repeat
5 | import collections.abc
6 |
7 |
8 | # From PyTorch internals
9 | def _ntuple(n):
10 | def parse(x):
11 | if isinstance(x, collections.abc.Iterable):
12 | return x
13 | return tuple(repeat(x, n))
14 | return parse
15 |
16 |
17 | to_1tuple = _ntuple(1)
18 | to_2tuple = _ntuple(2)
19 | to_3tuple = _ntuple(3)
20 | to_4tuple = _ntuple(4)
21 | to_ntuple = _ntuple
22 |
--------------------------------------------------------------------------------
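
`_ntuple` broadcasts a scalar into an n-tuple while passing iterables through untouched, which is how ViT code accepts either `16` or `(16, 16)` as a patch size. A quick check, assuming the module is importable from the repository root:

from baselines.ViT.layer_helpers import to_2tuple  # assumed import path

assert to_2tuple(16) == (16, 16)        # scalar is broadcast to a pair
assert to_2tuple((14, 16)) == (14, 16)  # iterables are returned as-is
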
/baselines/ViT/misc_functions.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/
3 | # Written by Suraj Srinivas
4 | #
5 |
6 | """ Misc helper functions """
7 |
8 | import cv2
9 | import numpy as np
10 | import subprocess
11 |
12 | import torch
13 | import torchvision.transforms as transforms
14 |
15 |
16 | class NormalizeInverse(transforms.Normalize):
17 | # Undo normalization on images
18 |
19 | def __init__(self, mean, std):
20 | mean = torch.as_tensor(mean)
21 | std = torch.as_tensor(std)
22 | std_inv = 1 / (std + 1e-7)
23 | mean_inv = -mean * std_inv
24 | super(NormalizeInverse, self).__init__(mean=mean_inv, std=std_inv)
25 |
26 | def __call__(self, tensor):
27 | return super(NormalizeInverse, self).__call__(tensor.clone())
28 |
29 |
30 | def create_folder(folder_name):
31 | try:
32 | subprocess.call(['mkdir', '-p', folder_name])
33 | except OSError:
34 | None
35 |
36 |
37 | def save_saliency_map(image, saliency_map, filename):
38 | """
39 | Save saliency map on image.
40 |
41 | Args:
42 | image: Tensor of size (3,H,W)
43 | saliency_map: Tensor of size (1,H,W)
44 | filename: string with complete path and file extension
45 |
46 | """
47 |
48 | image = image.data.cpu().numpy()
49 | saliency_map = saliency_map.data.cpu().numpy()
50 |
51 | saliency_map = saliency_map - saliency_map.min()
52 | saliency_map = saliency_map / saliency_map.max()
53 | saliency_map = saliency_map.clip(0, 1)
54 |
55 | saliency_map = np.uint8(saliency_map * 255).transpose(1, 2, 0)
56 | saliency_map = cv2.resize(saliency_map, (224, 224))
57 |
58 | image = np.uint8(image * 255).transpose(1, 2, 0)
59 | image = cv2.resize(image, (224, 224))
60 |
61 | # Apply JET colormap
62 | color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)
63 |
64 | # Combine image with heatmap
65 | img_with_heatmap = np.float32(color_heatmap) + np.float32(image)
66 | img_with_heatmap = img_with_heatmap / np.max(img_with_heatmap)
67 |
68 | cv2.imwrite(filename, np.uint8(255 * img_with_heatmap))
69 |
--------------------------------------------------------------------------------
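
A small sketch tying `NormalizeInverse` and `save_saliency_map` together: undo a 0.5/0.5 normalization (the values used elsewhere in this repository) and overlay a relevance map on the image. The tensors and the output file name are illustrative stand-ins, and the import path is assumed.

import torch

from baselines.ViT.misc_functions import NormalizeInverse, save_saliency_map  # assumed path

unnormalize = NormalizeInverse(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

normalized = torch.rand(3, 224, 224) * 2 - 1   # stand-in for a normalized input image
image = unnormalize(normalized).clamp(0, 1)    # back to [0, 1] for display
saliency = torch.rand(1, 224, 224)             # stand-in for a relevance map

save_saliency_map(image, saliency, 'example_heatmap.png')
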
/baselines/ViT/pertubation_eval_from_hdf5.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import os
4 | from tqdm import tqdm
5 | import numpy as np
6 | import argparse
7 |
8 | # Import saliency methods and models
9 | from ViT_explanation_generator import Baselines
10 | from ViT_new import vit_base_patch16_224
11 | # from models.vgg import vgg19
12 | import glob
13 |
14 | from dataset.expl_hdf5 import ImagenetResults
15 |
16 |
17 | def normalize(tensor,
18 | mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
19 | dtype = tensor.dtype
20 | mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
21 | std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
22 | tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
23 | return tensor
24 |
25 |
26 | def eval(args):
27 | num_samples = 0
28 | num_correct_model = np.zeros((len(imagenet_ds,)))
29 | dissimilarity_model = np.zeros((len(imagenet_ds,)))
30 | model_index = 0
31 |
32 | if args.scale == 'per':
33 | base_size = 224 * 224
34 | perturbation_steps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
35 | elif args.scale == '100':
36 | base_size = 100
37 | perturbation_steps = [5, 10, 15, 20, 25, 30, 35, 40, 45]
38 | else:
39 | raise Exception('scale not valid')
40 |
41 | num_correct_pertub = np.zeros((9, len(imagenet_ds)))
42 | dissimilarity_pertub = np.zeros((9, len(imagenet_ds)))
43 | logit_diff_pertub = np.zeros((9, len(imagenet_ds)))
44 | prob_diff_pertub = np.zeros((9, len(imagenet_ds)))
45 | perturb_index = 0
46 |
47 | for batch_idx, (data, vis, target) in enumerate(tqdm(sample_loader)):
48 | # Update the number of samples
49 | num_samples += len(data)
50 |
51 | data = data.to(device)
52 | vis = vis.to(device)
53 | target = target.to(device)
54 | norm_data = normalize(data.clone())
55 |
56 | # Compute model accuracy
57 | pred = model(norm_data)
58 | pred_probabilities = torch.softmax(pred, dim=1)
59 | pred_org_logit = pred.data.max(1, keepdim=True)[0].squeeze(1)
60 | pred_org_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
61 | pred_class = pred.data.max(1, keepdim=True)[1].squeeze(1)
62 | tgt_pred = (target == pred_class).type(target.type()).data.cpu().numpy()
63 | num_correct_model[model_index:model_index+len(tgt_pred)] = tgt_pred
64 |
65 | probs = torch.softmax(pred, dim=1)
66 | target_probs = torch.gather(probs, 1, target[:, None])[:, 0]
67 | second_probs = probs.data.topk(2, dim=1)[0][:, 1]
68 | temp = torch.log(target_probs / second_probs).data.cpu().numpy()
69 | dissimilarity_model[model_index:model_index+len(temp)] = temp
70 |
71 | if args.wrong:
72 | wid = np.argwhere(tgt_pred == 0).flatten()
73 | if len(wid) == 0:
74 | continue
75 | wid = torch.from_numpy(wid).to(vis.device)
76 | vis = vis.index_select(0, wid)
77 | data = data.index_select(0, wid)
78 | target = target.index_select(0, wid)
79 |
80 | # Save original shape
81 | org_shape = data.shape
82 |
83 | if args.neg:
84 | vis = -vis
85 |
86 | vis = vis.reshape(org_shape[0], -1)
87 |
88 | for i in range(len(perturbation_steps)):
89 | _data = data.clone()
90 |
91 | _, idx = torch.topk(vis, int(base_size * perturbation_steps[i]), dim=-1)
92 | idx = idx.unsqueeze(1).repeat(1, org_shape[1], 1)
93 | _data = _data.reshape(org_shape[0], org_shape[1], -1)
94 | _data = _data.scatter_(-1, idx, 0)
95 | _data = _data.reshape(*org_shape)
96 |
97 | _norm_data = normalize(_data)
98 |
99 | out = model(_norm_data)
100 |
101 | pred_probabilities = torch.softmax(out, dim=1)
102 | pred_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
103 | diff = (pred_prob - pred_org_prob).data.cpu().numpy()
104 | prob_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff
105 |
106 | pred_logit = out.data.max(1, keepdim=True)[0].squeeze(1)
107 | diff = (pred_logit - pred_org_logit).data.cpu().numpy()
108 | logit_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff
109 |
110 | target_class = out.data.max(1, keepdim=True)[1].squeeze(1)
111 | temp = (target == target_class).type(target.type()).data.cpu().numpy()
112 | num_correct_pertub[i, perturb_index:perturb_index+len(temp)] = temp
113 |
114 | probs_pertub = torch.softmax(out, dim=1)
115 | target_probs = torch.gather(probs_pertub, 1, target[:, None])[:, 0]
116 | second_probs = probs_pertub.data.topk(2, dim=1)[0][:, 1]
117 | temp = torch.log(target_probs / second_probs).data.cpu().numpy()
118 | dissimilarity_pertub[i, perturb_index:perturb_index+len(temp)] = temp
119 |
120 | model_index += len(target)
121 | perturb_index += len(target)
122 |
123 | np.save(os.path.join(args.experiment_dir, 'model_hits.npy'), num_correct_model)
124 | np.save(os.path.join(args.experiment_dir, 'model_dissimilarities.npy'), dissimilarity_model)
125 | np.save(os.path.join(args.experiment_dir, 'perturbations_hits.npy'), num_correct_pertub[:, :perturb_index])
126 | np.save(os.path.join(args.experiment_dir, 'perturbations_dissimilarities.npy'), dissimilarity_pertub[:, :perturb_index])
127 | np.save(os.path.join(args.experiment_dir, 'perturbations_logit_diff.npy'), logit_diff_pertub[:, :perturb_index])
128 | np.save(os.path.join(args.experiment_dir, 'perturbations_prob_diff.npy'), prob_diff_pertub[:, :perturb_index])
129 |
130 | print(np.mean(num_correct_model), np.std(num_correct_model))
131 | print(np.mean(dissimilarity_model), np.std(dissimilarity_model))
132 | print(perturbation_steps)
133 | print(np.mean(num_correct_pertub, axis=1), np.std(num_correct_pertub, axis=1))
134 | print(np.mean(dissimilarity_pertub, axis=1), np.std(dissimilarity_pertub, axis=1))
135 |
136 |
137 | if __name__ == "__main__":
138 | parser = argparse.ArgumentParser(description='Perturbation evaluation from saved explanation maps')
139 | parser.add_argument('--batch-size', type=int,
140 | default=16,
141 | help='')
142 | parser.add_argument('--neg', type=bool,
143 | default=True,
144 | help='')
145 | parser.add_argument('--value', action='store_true',
146 | default=False,
147 | help='')
148 | parser.add_argument('--scale', type=str,
149 | default='per',
150 | choices=['per', '100'],
151 | help='')
152 | parser.add_argument('--method', type=str,
153 | default='transformer_attribution',
154 | choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'v_gradcam', 'lrp_last_layer',
155 | 'lrp_second_layer', 'gradcam',
156 | 'attn_last_layer', 'attn_gradcam', 'input_grads'],
157 | help='')
158 | parser.add_argument('--vis-class', type=str,
159 | default='top',
160 | choices=['top', 'target', 'index'],
161 | help='')
162 | parser.add_argument('--wrong', action='store_true',
163 | default=False,
164 | help='')
165 | parser.add_argument('--class-id', type=int,
166 | default=0,
167 | help='')
168 | parser.add_argument('--is-ablation', type=bool,
169 | default=False,
170 | help='')
171 | args = parser.parse_args()
172 |
173 | torch.multiprocessing.set_start_method('spawn')
174 |
175 | # PATH variables
176 | PATH = os.path.dirname(os.path.abspath(__file__)) + '/'
177 | dataset = PATH + 'dataset/'
178 | os.makedirs(os.path.join(PATH, 'experiments'), exist_ok=True)
179 | os.makedirs(os.path.join(PATH, 'experiments/perturbations'), exist_ok=True)
180 |
181 | exp_name = args.method
182 | exp_name += '_neg' if args.neg else '_pos'
183 | print(exp_name)
184 |
185 | if args.vis_class == 'index':
186 | args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}_{}'.format(exp_name,
187 | args.vis_class,
188 | args.class_id))
189 | else:
190 | ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
191 | args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}/{}'.format(exp_name,
192 | args.vis_class, ablation_fold))
193 | # args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}'.format(exp_name,
194 | # args.vis_class))
195 |
196 | if args.wrong:
197 | args.runs_dir += '_wrong'
198 |
199 | experiments = sorted(glob.glob(os.path.join(args.runs_dir, 'experiment_*')))
200 | experiment_id = int(experiments[-1].split('_')[-1]) + 1 if experiments else 0
201 | args.experiment_dir = os.path.join(args.runs_dir, 'experiment_{}'.format(str(experiment_id)))
202 | os.makedirs(args.experiment_dir, exist_ok=True)
203 |
204 | cuda = torch.cuda.is_available()
205 | device = torch.device("cuda" if cuda else "cpu")
206 |
207 | if args.vis_class == 'index':
208 | vis_method_dir = os.path.join(PATH,'visualizations/{}/{}_{}'.format(args.method,
209 | args.vis_class,
210 | args.class_id))
211 | else:
212 | ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
213 | vis_method_dir = os.path.join(PATH,'visualizations/{}/{}/{}'.format(args.method,
214 | args.vis_class, ablation_fold))
215 | # vis_method_dir = os.path.join(PATH, 'visualizations/{}/{}'.format(args.method,
216 | # args.vis_class))
217 |
218 | # imagenet_ds = ImagenetResults('visualizations/{}'.format(args.method))
219 | imagenet_ds = ImagenetResults(vis_method_dir)
220 |
221 | # Model
222 | model = vit_base_patch16_224(pretrained=True).cuda()
223 | model.eval()
224 |
225 | save_path = PATH + 'results/'
226 |
227 | sample_loader = torch.utils.data.DataLoader(
228 | imagenet_ds,
229 | batch_size=args.batch_size,
230 | num_workers=2,
231 | shuffle=False)
232 |
233 | eval(args)
234 |
--------------------------------------------------------------------------------
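
The core of the perturbation test above is the masking step: flatten each image, take the top-k most relevant pixel indices from the relevance map, and zero them across all channels with `scatter_`. A toy-sized sketch of just that step; the shapes and the 25% step are illustrative.

import torch

data = torch.rand(2, 3, 4, 4) + 0.1      # strictly positive toy images (batch, C, H, W)
vis = torch.rand(2, 4 * 4)               # one relevance score per pixel
step = 0.25                              # remove the top 25% most relevant pixels

_, idx = torch.topk(vis, int(4 * 4 * step), dim=-1)
idx = idx.unsqueeze(1).repeat(1, data.shape[1], 1)   # broadcast pixel indices over channels

masked = data.clone().reshape(2, 3, -1).scatter_(-1, idx, 0).reshape_as(data)
assert (masked == 0).sum() == 2 * 3 * 4   # 4 pixels zeroed per image, in every channel
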
/baselines/ViT/weight_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import warnings
4 |
5 |
6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
9 | def norm_cdf(x):
10 | # Computes standard normal cumulative distribution function
11 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
12 |
13 | if (mean < a - 2 * std) or (mean > b + 2 * std):
14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
15 | "The distribution of values may be incorrect.",
16 | stacklevel=2)
17 |
18 | with torch.no_grad():
19 | # Values are generated by using a truncated uniform distribution and
20 | # then using the inverse CDF for the normal distribution.
21 | # Get upper and lower cdf values
22 | l = norm_cdf((a - mean) / std)
23 | u = norm_cdf((b - mean) / std)
24 |
25 | # Uniformly fill tensor with values from [l, u], then translate to
26 | # [2l-1, 2u-1].
27 | tensor.uniform_(2 * l - 1, 2 * u - 1)
28 |
29 | # Use inverse cdf transform for normal distribution to get truncated
30 | # standard normal
31 | tensor.erfinv_()
32 |
33 | # Transform to proper mean, std
34 | tensor.mul_(std * math.sqrt(2.))
35 | tensor.add_(mean)
36 |
37 | # Clamp to ensure it's in the proper range
38 | tensor.clamp_(min=a, max=b)
39 | return tensor
40 |
41 |
42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
43 | # type: (Tensor, float, float, float, float) -> Tensor
44 | r"""Fills the input Tensor with values drawn from a truncated
45 | normal distribution. The values are effectively drawn from the
46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
47 | with values outside :math:`[a, b]` redrawn until they are within
48 | the bounds. The method used for generating the random values works
49 | best when :math:`a \leq \text{mean} \leq b`.
50 | Args:
51 | tensor: an n-dimensional `torch.Tensor`
52 | mean: the mean of the normal distribution
53 | std: the standard deviation of the normal distribution
54 | a: the minimum cutoff value
55 | b: the maximum cutoff value
56 | Examples:
57 | >>> w = torch.empty(3, 5)
58 | >>> nn.init.trunc_normal_(w)
59 | """
60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
--------------------------------------------------------------------------------
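
A quick sanity check of `trunc_normal_` (the 0.02 std and the ±2-std window match typical ViT weight initialization; the tensor size and import path are assumptions):

import torch

from baselines.ViT.weight_init import trunc_normal_  # assumed import path

w = torch.empty(10000)
trunc_normal_(w, mean=0.0, std=0.02, a=-0.04, b=0.04)

assert w.min() >= -0.04 and w.max() <= 0.04   # samples are clamped to [a, b]
print(w.std())  # slightly below 0.02, since the tails are truncated
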
/data/Imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 | import numpy as np
5 | import cv2
6 |
7 | from torchvision.datasets import ImageNet
8 |
9 | from PIL import Image, ImageFilter
10 | import h5py
11 | from glob import glob
12 |
13 |
14 | class ImageNet_blur(ImageNet):
15 | def __getitem__(self, index):
16 | """
17 | Args:
18 | index (int): Index
19 |
20 | Returns:
21 | tuple: (sample, target) where target is class_index of the target class.
22 | """
23 | path, target = self.samples[index]
24 | sample = self.loader(path)
25 |
26 | gauss_blur = ImageFilter.GaussianBlur(11)
27 | median_blur = ImageFilter.MedianFilter(11)
28 |
29 | blurred_img1 = sample.filter(gauss_blur)
30 | blurred_img2 = sample.filter(median_blur)
31 | blurred_img = Image.blend(blurred_img1, blurred_img2, 0.5)
32 |
33 | if self.transform is not None:
34 | sample = self.transform(sample)
35 | blurred_img = self.transform(blurred_img)
36 | if self.target_transform is not None:
37 | target = self.target_transform(target)
38 |
39 | return (sample, blurred_img), target
40 |
41 |
42 | class Imagenet_Segmentation(data.Dataset):
43 | CLASSES = 2
44 |
45 | def __init__(self,
46 | path,
47 | transform=None,
48 | target_transform=None):
49 | self.path = path
50 | self.transform = transform
51 | self.target_transform = target_transform
52 | # self.h5py = h5py.File(path, 'r+')
53 | self.h5py = None
54 | tmp = h5py.File(path, 'r')
55 | self.data_length = len(tmp['/value/img'])
56 | tmp.close()
57 | del tmp
58 |
59 | def __getitem__(self, index):
60 |
61 | if self.h5py is None:
62 | self.h5py = h5py.File(self.path, 'r')
63 |
64 | img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
65 | target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
66 |
67 | img = Image.fromarray(img).convert('RGB')
68 | target = Image.fromarray(target)
69 |
70 | if self.transform is not None:
71 | img = self.transform(img)
72 |
73 | if self.target_transform is not None:
74 | target = np.array(self.target_transform(target)).astype('int32')
75 | target = torch.from_numpy(target).long()
76 |
77 | return img, target
78 |
79 | def __len__(self):
80 | # return len(self.h5py['/value/img'])
81 | return self.data_length
82 |
83 |
84 | class Imagenet_Segmentation_Blur(data.Dataset):
85 | CLASSES = 2
86 |
87 | def __init__(self,
88 | path,
89 | transform=None,
90 | target_transform=None):
91 | self.path = path
92 | self.transform = transform
93 | self.target_transform = target_transform
94 | # self.h5py = h5py.File(path, 'r+')
95 | self.h5py = None
96 | tmp = h5py.File(path, 'r')
97 | self.data_length = len(tmp['/value/img'])
98 | tmp.close()
99 | del tmp
100 |
101 | def __getitem__(self, index):
102 |
103 | if self.h5py is None:
104 | self.h5py = h5py.File(self.path, 'r')
105 |
106 | img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
107 | target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
108 |
109 | img = Image.fromarray(img).convert('RGB')
110 | target = Image.fromarray(target)
111 |
112 | gauss_blur = ImageFilter.GaussianBlur(11)
113 | median_blur = ImageFilter.MedianFilter(11)
114 |
115 | blurred_img1 = img.filter(gauss_blur)
116 | blurred_img2 = img.filter(median_blur)
117 | blurred_img = Image.blend(blurred_img1, blurred_img2, 0.5)
118 |
119 | # blurred_img1 = cv2.GaussianBlur(img, (11, 11), 5)
120 | # blurred_img2 = np.float32(cv2.medianBlur(img, 11))
121 | # blurred_img = (blurred_img1 + blurred_img2) / 2
122 |
123 | if self.transform is not None:
124 | img = self.transform(img)
125 | blurred_img = self.transform(blurred_img)
126 |
127 | if self.target_transform is not None:
128 | target = np.array(self.target_transform(target)).astype('int32')
129 | target = torch.from_numpy(target).long()
130 |
131 | return (img, blurred_img), target
132 |
133 | def __len__(self):
134 | # return len(self.h5py['/value/img'])
135 | return self.data_length
136 |
137 |
138 | class Imagenet_Segmentation_eval_dir(data.Dataset):
139 | CLASSES = 2
140 |
141 | def __init__(self,
142 | path,
143 | eval_path,
144 | transform=None,
145 | target_transform=None):
146 | self.transform = transform
147 | self.target_transform = target_transform
148 | self.h5py = h5py.File(path, 'r+')
149 |
150 | # 500 each file
151 | self.results = glob(os.path.join(eval_path, '*.npy'))
152 |
153 | def __getitem__(self, index):
154 |
155 | img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
156 | target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
157 | res = np.load(self.results[index])
158 |
159 | img = Image.fromarray(img).convert('RGB')
160 | target = Image.fromarray(target)
161 |
162 | if self.transform is not None:
163 | img = self.transform(img)
164 |
165 | if self.target_transform is not None:
166 | target = np.array(self.target_transform(target)).astype('int32')
167 | target = torch.from_numpy(target).long()
168 |
169 | return img, target
170 |
171 | def __len__(self):
172 | return len(self.h5py['/value/img'])
173 |
174 |
175 | if __name__ == '__main__':
176 | import torchvision.transforms as transforms
177 | from tqdm import tqdm
178 | from imageio import imsave
179 | import scipy.io as sio
180 |
181 | # meta = sio.loadmat('/home/shirgur/ext/Data/Datasets/temp/ILSVRC2012_devkit_t12/data/meta.mat', squeeze_me=True)['synsets']
182 |
183 | # Data
184 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
185 | std=[0.229, 0.224, 0.225])
186 | test_img_trans = transforms.Compose([
187 | transforms.Resize((224, 224)),
188 | transforms.ToTensor(),
189 | normalize,
190 | ])
191 | test_lbl_trans = transforms.Compose([
192 | transforms.Resize((224, 224), Image.NEAREST),
193 | ])
194 |
195 | ds = Imagenet_Segmentation('/home/shirgur/ext/Data/Datasets/imagenet-seg/other/gtsegs_ijcv.mat',
196 | transform=test_img_trans, target_transform=test_lbl_trans)
197 |
198 | for i, (img, tgt) in enumerate(tqdm(ds)):
199 | tgt = (tgt.numpy() * 255).astype(np.uint8)
200 | imsave('/home/shirgur/ext/Code/C2S/run/imagenet/gt/{}.png'.format(i), tgt)
201 |
202 | print('here')
203 |
--------------------------------------------------------------------------------
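
A minimal sketch of loading `Imagenet_Segmentation` for evaluation, mirroring the transforms in the `__main__` block above. The `.mat` path is a placeholder for a local copy of the ImageNet-segmentation ground-truth file (`gtsegs_ijcv.mat`), and the import path is assumed.

import torch
import torchvision.transforms as transforms
from PIL import Image

from data.Imagenet import Imagenet_Segmentation  # assumed import path

img_trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
lbl_trans = transforms.Compose([transforms.Resize((224, 224), Image.NEAREST)])

ds = Imagenet_Segmentation('/path/to/gtsegs_ijcv.mat',       # placeholder path
                           transform=img_trans, target_transform=lbl_trans)
loader = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)

img, gt = next(iter(loader))
print(img.shape, gt.shape)  # torch.Size([1, 3, 224, 224]) torch.Size([1, 224, 224])
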
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/data/__init__.py
--------------------------------------------------------------------------------
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as data
4 | import numpy as np
5 |
6 | from PIL import Image
7 | import h5py
8 |
9 | __all__ = ['ImagenetResults']
10 |
11 |
12 | class Imagenet_Segmentation(data.Dataset):
13 | CLASSES = 2
14 |
15 | def __init__(self,
16 | path,
17 | transform=None,
18 | target_transform=None):
19 | self.path = path
20 | self.transform = transform
21 | self.target_transform = target_transform
22 | self.h5py = None
23 | tmp = h5py.File(path, 'r')
24 | self.data_length = len(tmp['/value/img'])
25 | tmp.close()
26 | del tmp
27 |
28 | def __getitem__(self, index):
29 |
30 | if self.h5py is None:
31 | self.h5py = h5py.File(self.path, 'r')
32 |
33 | img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
34 | target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
35 |
36 | img = Image.fromarray(img).convert('RGB')
37 | target = Image.fromarray(target)
38 |
39 | if self.transform is not None:
40 | img = self.transform(img)
41 |
42 | if self.target_transform is not None:
43 | target = np.array(self.target_transform(target)).astype('int32')
44 | target = torch.from_numpy(target).long()
45 |
46 | return img, target
47 |
48 | def __len__(self):
49 | return self.data_length
50 |
51 |
52 | class ImagenetResults(data.Dataset):
53 | def __init__(self, path):
54 | super(ImagenetResults, self).__init__()
55 |
56 | self.path = os.path.join(path, 'results.hdf5')
57 | self.data = None
58 |
59 | print('Reading dataset length...')
60 | with h5py.File(self.path, 'r') as f:
61 | self.data_length = len(f['/image'])
62 |
63 | def __len__(self):
64 | return self.data_length
65 |
66 | def __getitem__(self, item):
67 | if self.data is None:
68 | self.data = h5py.File(self.path, 'r')
69 |
70 | image = torch.tensor(self.data['image'][item])
71 | vis = torch.tensor(self.data['vis'][item])
72 | target = torch.tensor(self.data['target'][item]).long()
73 |
74 | return image, vis, target
75 |
--------------------------------------------------------------------------------
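
`ImagenetResults` expects `results.hdf5` to contain parallel datasets named `image`, `vis`, and `target` (the keys read in `__getitem__` above); `generate_visualizations.py` fills such a file with the saliency maps it computes. A toy file with the same layout can be written like this (directory name, sizes, and values are illustrative):

import os

import h5py
import numpy as np

os.makedirs('toy_vis', exist_ok=True)
with h5py.File(os.path.join('toy_vis', 'results.hdf5'), 'w') as f:
    f.create_dataset('image', data=np.random.rand(8, 3, 224, 224).astype(np.float32))
    f.create_dataset('vis', data=np.random.rand(8, 1, 224, 224).astype(np.float32))
    f.create_dataset('target', data=np.random.randint(0, 1000, size=(8,)))
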
/dataset/expl_hdf5.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import h5py
4 | import os
5 |
6 |
7 | class ImagenetResults(Dataset):
8 | def __init__(self, path):
9 | super(ImagenetResults, self).__init__()
10 |
11 | self.path = os.path.join(path, 'results.hdf5')
12 | self.data = None
13 |
14 | print('Reading dataset length...')
15 | with h5py.File(self.path , 'r') as f:
16 | # tmp = h5py.File(self.path , 'r')
17 | self.data_length = len(f['/image'])
18 |
19 | def __len__(self):
20 | return self.data_length
21 |
22 | def __getitem__(self, item):
23 | if self.data is None:
24 | self.data = h5py.File(self.path, 'r')
25 |
26 | image = torch.tensor(self.data['image'][item])
27 | vis = torch.tensor(self.data['vis'][item])
28 | target = torch.tensor(self.data['target'][item]).long()
29 |
30 | return image, vis, target
31 |
32 |
33 | if __name__ == '__main__':
34 | from utils import render
35 | import imageio
36 | import numpy as np
37 |
38 | ds = ImagenetResults('../visualizations/fullgrad')
39 | sample_loader = torch.utils.data.DataLoader(
40 | ds,
41 | batch_size=5,
42 | shuffle=False)
43 |
44 | iterator = iter(sample_loader)
45 | image, vis, target = next(iterator)
46 |
47 | maps = (render.hm_to_rgb(vis[0].data.cpu().numpy(), scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8)
48 |
49 | # imageio.imsave('../delete_hm.jpg', maps)
50 |
51 | print(len(ds))
--------------------------------------------------------------------------------
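
Reading such a results file back through this `ImagenetResults` and a `DataLoader`, continuing the toy `toy_vis/results.hdf5` written above (the directory name is illustrative and the import path is assumed):

import torch

from dataset.expl_hdf5 import ImagenetResults  # assumed import path

ds = ImagenetResults('toy_vis')                 # directory containing results.hdf5
loader = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=False)

image, vis, target = next(iter(loader))
print(image.shape, vis.shape, target.shape)
# torch.Size([4, 3, 224, 224]) torch.Size([4, 1, 224, 224]) torch.Size([4])
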
/example.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/example.PNG
--------------------------------------------------------------------------------
/method-page-001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/method-page-001.jpg
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/modules/__init__.py
--------------------------------------------------------------------------------
/modules/layers_lrp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6 | 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7 | 'LayerNorm', 'AddEye']
8 |
9 |
10 | def safe_divide(a, b):
11 | den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12 | den = den + den.eq(0).type(den.type()) * 1e-9
13 | return a / den * b.ne(0).type(b.type())
14 |
15 |
16 | def forward_hook(self, input, output):
17 | if type(input[0]) in (list, tuple):
18 | self.X = []
19 | for i in input[0]:
20 | x = i.detach()
21 | x.requires_grad = True
22 | self.X.append(x)
23 | else:
24 | self.X = input[0].detach()
25 | self.X.requires_grad = True
26 |
27 | self.Y = output
28 |
29 |
30 | def backward_hook(self, grad_input, grad_output):
31 | self.grad_input = grad_input
32 | self.grad_output = grad_output
33 |
34 |
35 | class RelProp(nn.Module):
36 | def __init__(self):
37 | super(RelProp, self).__init__()
38 | # if not self.training:
39 | self.register_forward_hook(forward_hook)
40 |
41 | def gradprop(self, Z, X, S):
42 | C = torch.autograd.grad(Z, X, S, retain_graph=True)
43 | return C
44 |
45 | def relprop(self, R, alpha):
46 | return R
47 |
48 |
49 | class RelPropSimple(RelProp):
50 | def relprop(self, R, alpha):
51 | Z = self.forward(self.X)
52 | S = safe_divide(R, Z)
53 | C = self.gradprop(Z, self.X, S)
54 |
55 | if torch.is_tensor(self.X) == False:
56 | outputs = []
57 | outputs.append(self.X[0] * C[0])
58 | outputs.append(self.X[1] * C[1])
59 | else:
60 | outputs = self.X * (C[0])
61 | return outputs
62 |
63 | class AddEye(RelPropSimple):
64 | # input of shape B, C, seq_len, seq_len
65 | def forward(self, input):
66 | return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
67 |
68 | class ReLU(nn.ReLU, RelProp):
69 | pass
70 |
71 | class GELU(nn.GELU, RelProp):
72 | pass
73 |
74 | class Softmax(nn.Softmax, RelProp):
75 | pass
76 |
77 | class LayerNorm(nn.LayerNorm, RelProp):
78 | pass
79 |
80 | class Dropout(nn.Dropout, RelProp):
81 | pass
82 |
83 |
84 | class MaxPool2d(nn.MaxPool2d, RelPropSimple):
85 | pass
86 |
87 | class LayerNorm(nn.LayerNorm, RelProp):
88 | pass
89 |
90 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
91 | pass
92 |
93 |
94 | class AvgPool2d(nn.AvgPool2d, RelPropSimple):
95 | pass
96 |
97 |
98 | class Add(RelPropSimple):
99 | def forward(self, inputs):
100 | return torch.add(*inputs)
101 |
102 | class einsum(RelPropSimple):
103 | def __init__(self, equation):
104 | super().__init__()
105 | self.equation = equation
106 | def forward(self, *operands):
107 | return torch.einsum(self.equation, *operands)
108 |
109 | class IndexSelect(RelProp):
110 | def forward(self, inputs, dim, indices):
111 | self.__setattr__('dim', dim)
112 | self.__setattr__('indices', indices)
113 |
114 | return torch.index_select(inputs, dim, indices)
115 |
116 | def relprop(self, R, alpha):
117 | Z = self.forward(self.X, self.dim, self.indices)
118 | S = safe_divide(R, Z)
119 | C = self.gradprop(Z, self.X, S)
120 |
121 | if torch.is_tensor(self.X) == False:
122 | outputs = []
123 | outputs.append(self.X[0] * C[0])
124 | outputs.append(self.X[1] * C[1])
125 | else:
126 | outputs = self.X * (C[0])
127 | return outputs
128 |
129 |
130 |
131 | class Clone(RelProp):
132 | def forward(self, input, num):
133 | self.__setattr__('num', num)
134 | outputs = []
135 | for _ in range(num):
136 | outputs.append(input)
137 |
138 | return outputs
139 |
140 | def relprop(self, R, alpha):
141 | Z = []
142 | for _ in range(self.num):
143 | Z.append(self.X)
144 | S = [safe_divide(r, z) for r, z in zip(R, Z)]
145 | C = self.gradprop(Z, self.X, S)[0]
146 |
147 | R = self.X * C
148 |
149 | return R
150 |
151 | class Cat(RelProp):
152 | def forward(self, inputs, dim):
153 | self.__setattr__('dim', dim)
154 | return torch.cat(inputs, dim)
155 |
156 | def relprop(self, R, alpha):
157 | Z = self.forward(self.X, self.dim)
158 | S = safe_divide(R, Z)
159 | C = self.gradprop(Z, self.X, S)
160 |
161 | outputs = []
162 | for x, c in zip(self.X, C):
163 | outputs.append(x * c)
164 |
165 | return outputs
166 |
167 |
168 | class Sequential(nn.Sequential):
169 | def relprop(self, R, alpha):
170 | for m in reversed(self._modules.values()):
171 | R = m.relprop(R, alpha)
172 | return R
173 |
174 |
175 | class BatchNorm2d(nn.BatchNorm2d, RelProp):
176 | def relprop(self, R, alpha):
177 | X = self.X
178 | beta = 1 - alpha
179 | weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
180 | (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
181 | Z = X * weight + 1e-9
182 | S = R / Z
183 | Ca = S * weight
184 | R = self.X * (Ca)
185 | return R
186 |
187 |
188 | class Linear(nn.Linear, RelProp):
189 | def relprop(self, R, alpha):
190 | beta = alpha - 1
191 | pw = torch.clamp(self.weight, min=0)
192 | nw = torch.clamp(self.weight, max=0)
193 | px = torch.clamp(self.X, min=0)
194 | nx = torch.clamp(self.X, max=0)
195 |
196 | def f(w1, w2, x1, x2):
197 | Z1 = F.linear(x1, w1)
198 | Z2 = F.linear(x2, w2)
199 | S1 = safe_divide(R, Z1)
200 | S2 = safe_divide(R, Z2)
201 | C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
202 | C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
203 |
204 | return C1 + C2
205 |
206 | activator_relevances = f(pw, nw, px, nx)
207 | inhibitor_relevances = f(nw, pw, px, nx)
208 |
209 | R = alpha * activator_relevances - beta * inhibitor_relevances
210 |
211 | return R
212 |
213 |
214 | class Conv2d(nn.Conv2d, RelProp):
215 | def gradprop2(self, DY, weight):
216 | Z = self.forward(self.X)
217 |
218 | output_padding = self.X.size()[2] - (
219 | (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
220 |
221 | return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
222 |
223 | def relprop(self, R, alpha):
224 | if self.X.shape[1] == 3:
225 | pw = torch.clamp(self.weight, min=0)
226 | nw = torch.clamp(self.weight, max=0)
227 | X = self.X
228 | L = self.X * 0 + \
229 | torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
230 | keepdim=True)[0]
231 | H = self.X * 0 + \
232 | torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
233 | keepdim=True)[0]
234 | Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
235 | torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
236 | torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
237 |
238 | S = R / Za
239 | C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
240 | R = C
241 | else:
242 | beta = alpha - 1
243 | pw = torch.clamp(self.weight, min=0)
244 | nw = torch.clamp(self.weight, max=0)
245 | px = torch.clamp(self.X, min=0)
246 | nx = torch.clamp(self.X, max=0)
247 |
248 | def f(w1, w2, x1, x2):
249 | Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
250 | Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
251 | S1 = safe_divide(R, Z1)
252 | S2 = safe_divide(R, Z2)
253 | C1 = x1 * self.gradprop(Z1, x1, S1)[0]
254 | C2 = x2 * self.gradprop(Z2, x2, S2)[0]
255 | return C1 + C2
256 |
257 | activator_relevances = f(pw, nw, px, nx)
258 | inhibitor_relevances = f(nw, pw, px, nx)
259 |
260 | R = alpha * activator_relevances - beta * inhibitor_relevances
261 | return R
--------------------------------------------------------------------------------
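
These LRP wrappers cache each layer's forward input via `forward_hook`, so `relprop` can later redistribute relevance from a layer's output back to its input. A toy chain follows; the import path, shapes, and `alpha=1` are illustrative assumptions.

import torch
import torch.nn.functional as F

from modules.layers_lrp import Sequential, Linear, ReLU  # assumed import path

# forward() caches inputs on each layer; relprop() walks the chain in reverse.
model = Sequential(Linear(16, 8), ReLU(), Linear(8, 4))

x = torch.rand(2, 16)
out = model(x)

R = F.one_hot(out.argmax(dim=1), num_classes=4).float()  # relevance at the output
R_in = model.relprop(R, alpha=1)                         # propagated back to the input
print(R_in.shape)  # torch.Size([2, 16])
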
/modules/layers_ours.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6 | 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7 | 'LayerNorm', 'AddEye']
8 |
9 |
10 | def safe_divide(a, b):
11 | den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12 | den = den + den.eq(0).type(den.type()) * 1e-9
13 | return a / den * b.ne(0).type(b.type())
14 |
15 |
16 | def forward_hook(self, input, output):
17 | if type(input[0]) in (list, tuple):
18 | self.X = []
19 | for i in input[0]:
20 | x = i.detach()
21 | x.requires_grad = True
22 | self.X.append(x)
23 | else:
24 | self.X = input[0].detach()
25 | self.X.requires_grad = True
26 |
27 | self.Y = output
28 |
29 |
30 | def backward_hook(self, grad_input, grad_output):
31 | self.grad_input = grad_input
32 | self.grad_output = grad_output
33 |
34 |
35 | class RelProp(nn.Module):
36 | def __init__(self):
37 | super(RelProp, self).__init__()
38 | # if not self.training:
39 | self.register_forward_hook(forward_hook)
40 |
41 | def gradprop(self, Z, X, S):
42 | C = torch.autograd.grad(Z, X, S, retain_graph=True)
43 | return C
44 |
45 | def relprop(self, R, alpha):
46 | return R
47 |
48 | class RelPropSimple(RelProp):
49 | def relprop(self, R, alpha):
50 | Z = self.forward(self.X)
51 | S = safe_divide(R, Z)
52 | C = self.gradprop(Z, self.X, S)
53 |
54 | if torch.is_tensor(self.X) == False:
55 | outputs = []
56 | outputs.append(self.X[0] * C[0])
57 | outputs.append(self.X[1] * C[1])
58 | else:
59 | outputs = self.X * (C[0])
60 | return outputs
61 |
62 | class AddEye(RelPropSimple):
63 | # input of shape B, C, seq_len, seq_len
64 | def forward(self, input):
65 | return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
66 |
67 | class ReLU(nn.ReLU, RelProp):
68 | pass
69 |
70 | class GELU(nn.GELU, RelProp):
71 | pass
72 |
73 | class Softmax(nn.Softmax, RelProp):
74 | pass
75 |
76 | class LayerNorm(nn.LayerNorm, RelProp):
77 | pass
78 |
79 | class Dropout(nn.Dropout, RelProp):
80 | pass
81 |
82 |
83 | class MaxPool2d(nn.MaxPool2d, RelPropSimple):
84 | pass
85 |
86 | class LayerNorm(nn.LayerNorm, RelProp):
87 | pass
88 |
89 | class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
90 | pass
91 |
92 |
93 | class AvgPool2d(nn.AvgPool2d, RelPropSimple):
94 | pass
95 |
96 |
97 | class Add(RelPropSimple):
98 | def forward(self, inputs):
99 | return torch.add(*inputs)
100 |
101 | def relprop(self, R, alpha):
102 | Z = self.forward(self.X)
103 | S = safe_divide(R, Z)
104 | C = self.gradprop(Z, self.X, S)
105 |
106 | a = self.X[0] * C[0]
107 | b = self.X[1] * C[1]
108 |
109 | a_sum = a.sum()
110 | b_sum = b.sum()
111 |
112 | a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
113 | b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
114 |
115 | a = a * safe_divide(a_fact, a.sum())
116 | b = b * safe_divide(b_fact, b.sum())
117 |
118 | outputs = [a, b]
119 |
120 | return outputs
121 |
122 | class einsum(RelPropSimple):
123 | def __init__(self, equation):
124 | super().__init__()
125 | self.equation = equation
126 | def forward(self, *operands):
127 | return torch.einsum(self.equation, *operands)
128 |
129 | class IndexSelect(RelProp):
130 | def forward(self, inputs, dim, indices):
131 | self.__setattr__('dim', dim)
132 | self.__setattr__('indices', indices)
133 |
134 | return torch.index_select(inputs, dim, indices)
135 |
136 | def relprop(self, R, alpha):
137 | Z = self.forward(self.X, self.dim, self.indices)
138 | S = safe_divide(R, Z)
139 | C = self.gradprop(Z, self.X, S)
140 |
141 | if torch.is_tensor(self.X) == False:
142 | outputs = []
143 | outputs.append(self.X[0] * C[0])
144 | outputs.append(self.X[1] * C[1])
145 | else:
146 | outputs = self.X * (C[0])
147 | return outputs
148 |
149 |
150 |
151 | class Clone(RelProp):
152 | def forward(self, input, num):
153 | self.__setattr__('num', num)
154 | outputs = []
155 | for _ in range(num):
156 | outputs.append(input)
157 |
158 | return outputs
159 |
160 | def relprop(self, R, alpha):
161 | Z = []
162 | for _ in range(self.num):
163 | Z.append(self.X)
164 | S = [safe_divide(r, z) for r, z in zip(R, Z)]
165 | C = self.gradprop(Z, self.X, S)[0]
166 |
167 | R = self.X * C
168 |
169 | return R
170 |
171 | class Cat(RelProp):
172 | def forward(self, inputs, dim):
173 | self.__setattr__('dim', dim)
174 | return torch.cat(inputs, dim)
175 |
176 | def relprop(self, R, alpha):
177 | Z = self.forward(self.X, self.dim)
178 | S = safe_divide(R, Z)
179 | C = self.gradprop(Z, self.X, S)
180 |
181 | outputs = []
182 | for x, c in zip(self.X, C):
183 | outputs.append(x * c)
184 |
185 | return outputs
186 |
187 |
188 | class Sequential(nn.Sequential):
189 | def relprop(self, R, alpha):
190 | for m in reversed(self._modules.values()):
191 | R = m.relprop(R, alpha)
192 | return R
193 |
194 | class BatchNorm2d(nn.BatchNorm2d, RelProp):
195 | def relprop(self, R, alpha):
196 | X = self.X
197 | beta = 1 - alpha
198 | weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
199 | (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
200 | Z = X * weight + 1e-9
201 | S = R / Z
202 | Ca = S * weight
203 | R = self.X * (Ca)
204 | return R
205 |
206 |
207 | class Linear(nn.Linear, RelProp):
208 | def relprop(self, R, alpha):
209 | beta = alpha - 1
210 | pw = torch.clamp(self.weight, min=0)
211 | nw = torch.clamp(self.weight, max=0)
212 | px = torch.clamp(self.X, min=0)
213 | nx = torch.clamp(self.X, max=0)
214 |
215 | def f(w1, w2, x1, x2):
216 | Z1 = F.linear(x1, w1)
217 | Z2 = F.linear(x2, w2)
218 | S1 = safe_divide(R, Z1 + Z2)
219 | S2 = safe_divide(R, Z1 + Z2)
220 | C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
221 | C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
222 |
223 | return C1 + C2
224 |
225 | activator_relevances = f(pw, nw, px, nx)
226 | inhibitor_relevances = f(nw, pw, px, nx)
227 |
228 | R = alpha * activator_relevances - beta * inhibitor_relevances
229 |
230 | return R
231 |
232 |
233 | class Conv2d(nn.Conv2d, RelProp):
234 | def gradprop2(self, DY, weight):
235 | Z = self.forward(self.X)
236 |
237 | output_padding = self.X.size()[2] - (
238 | (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
239 |
240 | return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
241 |
242 | def relprop(self, R, alpha):
243 | if self.X.shape[1] == 3:
244 | pw = torch.clamp(self.weight, min=0)
245 | nw = torch.clamp(self.weight, max=0)
246 | X = self.X
247 | L = self.X * 0 + \
248 | torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
249 | keepdim=True)[0]
250 | H = self.X * 0 + \
251 | torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
252 | keepdim=True)[0]
253 | Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
254 | torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
255 | torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
256 |
257 | S = R / Za
258 | C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
259 | R = C
260 | else:
261 | beta = alpha - 1
262 | pw = torch.clamp(self.weight, min=0)
263 | nw = torch.clamp(self.weight, max=0)
264 | px = torch.clamp(self.X, min=0)
265 | nx = torch.clamp(self.X, max=0)
266 |
267 | def f(w1, w2, x1, x2):
268 | Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
269 | Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
270 | S1 = safe_divide(R, Z1)
271 | S2 = safe_divide(R, Z2)
272 | C1 = x1 * self.gradprop(Z1, x1, S1)[0]
273 | C2 = x2 * self.gradprop(Z2, x2, S2)[0]
274 | return C1 + C2
275 |
276 | activator_relevances = f(pw, nw, px, nx)
277 | inhibitor_relevances = f(nw, pw, px, nx)
278 |
279 | R = alpha * activator_relevances - beta * inhibitor_relevances
280 | return R
281 |
--------------------------------------------------------------------------------
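
Compared with `layers_lrp.py`, the `Add` here additionally renormalizes the relevance split between its two inputs so that the total relevance arriving at the addition is conserved. A small check of that property on toy tensors (import path assumed):

import torch

from modules.layers_ours import Add  # assumed import path

add = Add()
a, b = torch.rand(1, 4), torch.rand(1, 4)
out = add([a, b])                    # forward_hook caches both inputs

R = torch.rand(1, 4)                 # toy relevance arriving at the Add output
Ra, Rb = add.relprop(R, alpha=1)

# Branch relevances sum back to the incoming relevance (up to float error).
print(R.sum().item(), (Ra.sum() + Rb.sum()).item())
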
/new_work.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/new_work.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Pillow>=8.1.1
2 | einops == 0.3.0
3 | h5py == 2.8.0
4 | imageio == 2.9.0
5 | matplotlib == 3.3.2
6 | opencv_python
7 | scikit_image == 0.17.2
8 | scipy == 1.5.2
9 | scikit-learn
10 | torch == 1.7.0
11 | torchvision == 0.8.1
12 | tqdm == 4.51.0
13 | transformers == 3.5.1
14 | utils == 1.0.1
15 | Pygments>=2.7.4
16 |
--------------------------------------------------------------------------------
/samples/catdog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/catdog.png
--------------------------------------------------------------------------------
/samples/dogbird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/dogbird.png
--------------------------------------------------------------------------------
/samples/dogcat2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/dogcat2.png
--------------------------------------------------------------------------------
/samples/el1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/el1.png
--------------------------------------------------------------------------------
/samples/el2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/el2.png
--------------------------------------------------------------------------------
/samples/el3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/el3.png
--------------------------------------------------------------------------------
/samples/el4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/el4.png
--------------------------------------------------------------------------------
/samples/el5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/samples/el5.png
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hila-chefer/Transformer-Explainability/c3e578f76b954e8528afeaaee26de3f07e3fe559/utils/__init__.py
--------------------------------------------------------------------------------
/utils/confusionmatrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from . import metric
4 |
5 |
6 | class ConfusionMatrix(metric.Metric):
7 | """Constructs a confusion matrix for a multi-class classification problems.
8 | Does not support multi-label, multi-class problems.
9 | Keyword arguments:
10 | - num_classes (int): number of classes in the classification problem.
11 |     - normalized (boolean, optional): Determines whether the confusion
12 |     matrix is normalized. Default: False.
13 | Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
14 | """
15 |
16 | def __init__(self, num_classes, normalized=False):
17 | super().__init__()
18 |
19 | self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32)
20 | self.normalized = normalized
21 | self.num_classes = num_classes
22 | self.reset()
23 |
24 | def reset(self):
25 | self.conf.fill(0)
26 |
27 | def add(self, predicted, target):
28 | """Computes the confusion matrix
29 | The shape of the confusion matrix is K x K, where K is the number
30 | of classes.
31 | Keyword arguments:
32 | - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
33 | predicted scores obtained from the model for N examples and K classes,
34 | or an N-tensor/array of integer values between 0 and K-1.
35 | - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
36 | ground-truth classes for N examples and K classes, or an N-tensor/array
37 | of integer values between 0 and K-1.
38 | """
39 | # If target and/or predicted are tensors, convert them to numpy arrays
40 | if torch.is_tensor(predicted):
41 | predicted = predicted.cpu().numpy()
42 | if torch.is_tensor(target):
43 | target = target.cpu().numpy()
44 |
45 | assert predicted.shape[0] == target.shape[0], \
46 | 'number of targets and predicted outputs do not match'
47 |
48 | if np.ndim(predicted) != 1:
49 | assert predicted.shape[1] == self.num_classes, \
50 | 'number of predictions does not match size of confusion matrix'
51 | predicted = np.argmax(predicted, 1)
52 | else:
53 | assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \
54 | 'predicted values are not between 0 and k-1'
55 |
56 | if np.ndim(target) != 1:
57 | assert target.shape[1] == self.num_classes, \
58 | 'Onehot target does not match size of confusion matrix'
59 | assert (target >= 0).all() and (target <= 1).all(), \
60 | 'in one-hot encoding, target values should be 0 or 1'
61 | assert (target.sum(1) == 1).all(), \
62 | 'multi-label setting is not supported'
63 | target = np.argmax(target, 1)
64 | else:
65 | assert (target.max() < self.num_classes) and (target.min() >= 0), \
66 | 'target values are not between 0 and k-1'
67 |
68 | # hack for bincounting 2 arrays together
69 | x = predicted + self.num_classes * target
70 | bincount_2d = np.bincount(
71 | x.astype(np.int32), minlength=self.num_classes**2)
72 | assert bincount_2d.size == self.num_classes**2
73 | conf = bincount_2d.reshape((self.num_classes, self.num_classes))
74 |
75 | self.conf += conf
76 |
77 | def value(self):
78 | """
79 | Returns:
80 |             Confusion matrix of K rows and K columns, where rows correspond
81 |             to ground-truth targets and columns correspond to predicted
82 | targets.
83 | """
84 | if self.normalized:
85 | conf = self.conf.astype(np.float32)
86 | return conf / conf.sum(1).clip(min=1e-12)[:, None]
87 | else:
88 | return self.conf
--------------------------------------------------------------------------------
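
A short usage sketch of this metric (illustrative tensors; the import assumes the repository root is on PYTHONPATH):

    import torch
    from utils.confusionmatrix import ConfusionMatrix

    cm = ConfusionMatrix(num_classes=3)
    # predictions and targets may be class indices ...
    cm.add(torch.tensor([0, 2, 1, 1]), torch.tensor([0, 2, 2, 1]))
    # ... or N x K score matrices; the argmax is taken internally
    cm.add(torch.tensor([[0.9, 0.1, 0.0],
                         [0.2, 0.7, 0.1]]), torch.tensor([0, 1]))
    print(cm.value())  # 3 x 3 array: rows are ground truth, columns are predictions
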
/utils/iou.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from . import metric
4 | from .confusionmatrix import ConfusionMatrix
5 |
6 |
7 | class IoU(metric.Metric):
8 | """Computes the intersection over union (IoU) per class and corresponding
9 | mean (mIoU).
10 |
11 | Intersection over union (IoU) is a common evaluation metric for semantic
12 | segmentation. The predictions are first accumulated in a confusion matrix
13 | and the IoU is computed from it as follows:
14 |
15 | IoU = true_positive / (true_positive + false_positive + false_negative).
16 |
17 | Keyword arguments:
18 | - num_classes (int): number of classes in the classification problem
19 |     - normalized (boolean, optional): Determines whether the confusion
20 |     matrix is normalized. Default: False.
21 | - ignore_index (int or iterable, optional): Index of the classes to ignore
22 | when computing the IoU. Can be an int, or any iterable of ints.
23 | """
24 |
25 | def __init__(self, num_classes, normalized=False, ignore_index=None):
26 | super().__init__()
27 | self.conf_metric = ConfusionMatrix(num_classes, normalized)
28 |
29 | if ignore_index is None:
30 | self.ignore_index = None
31 | elif isinstance(ignore_index, int):
32 | self.ignore_index = (ignore_index,)
33 | else:
34 | try:
35 | self.ignore_index = tuple(ignore_index)
36 | except TypeError:
37 | raise ValueError("'ignore_index' must be an int or iterable")
38 |
39 | def reset(self):
40 | self.conf_metric.reset()
41 |
42 | def add(self, predicted, target):
43 | """Adds the predicted and target pair to the IoU metric.
44 |
45 | Keyword arguments:
46 | - predicted (Tensor): Can be a (N, K, H, W) tensor of
47 | predicted scores obtained from the model for N examples and K classes,
48 | or (N, H, W) tensor of integer values between 0 and K-1.
49 | - target (Tensor): Can be a (N, K, H, W) tensor of
50 | target scores for N examples and K classes, or (N, H, W) tensor of
51 | integer values between 0 and K-1.
52 |
53 | """
54 | # Dimensions check
55 | assert predicted.size(0) == target.size(0), \
56 | 'number of targets and predicted outputs do not match'
57 | assert predicted.dim() == 3 or predicted.dim() == 4, \
58 | "predictions must be of dimension (N, H, W) or (N, K, H, W)"
59 | assert target.dim() == 3 or target.dim() == 4, \
60 | "targets must be of dimension (N, H, W) or (N, K, H, W)"
61 |
62 |         # If the tensors hold (N, K, H, W) class scores, convert them to (N, H, W) label maps
63 | if predicted.dim() == 4:
64 | _, predicted = predicted.max(1)
65 | if target.dim() == 4:
66 | _, target = target.max(1)
67 |
68 | self.conf_metric.add(predicted.view(-1), target.view(-1))
69 |
70 | def value(self):
71 | """Computes the IoU and mean IoU.
72 |
73 | The mean computation ignores NaN elements of the IoU array.
74 |
75 | Returns:
76 |             Tuple: (IoU, mIoU). The first output is the per-class IoU;
77 |             for K classes it is a numpy.ndarray with K elements. The second
78 |             output is the mean IoU.
79 | """
80 | conf_matrix = self.conf_metric.value()
81 |         if self.ignore_index is not None:
82 |             for index in self.ignore_index:
83 |                 conf_matrix[:, index] = 0
84 |                 conf_matrix[index, :] = 0
85 | true_positive = np.diag(conf_matrix)
86 | false_positive = np.sum(conf_matrix, 0) - true_positive
87 | false_negative = np.sum(conf_matrix, 1) - true_positive
88 |
89 | # Just in case we get a division by 0, ignore/hide the error
90 | with np.errstate(divide='ignore', invalid='ignore'):
91 | iou = true_positive / (true_positive + false_positive + false_negative)
92 |
93 | return iou, np.nanmean(iou)
--------------------------------------------------------------------------------
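
A usage sketch with random label maps (illustrative shapes; not taken from the repository's evaluation scripts):

    import torch
    from utils.iou import IoU

    metric = IoU(num_classes=2)
    pred = torch.randint(0, 2, (4, 224, 224))   # (N, H, W) predicted labels
    gt = torch.randint(0, 2, (4, 224, 224))     # (N, H, W) ground-truth labels
    metric.add(pred, gt)
    per_class_iou, miou = metric.value()
    print(per_class_iou, miou)
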
/utils/metric.py:
--------------------------------------------------------------------------------
1 | class Metric(object):
2 | """Base class for all metrics.
3 | From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py
4 | """
5 | def reset(self):
6 | pass
7 |
8 | def add(self):
9 | pass
10 |
11 | def value(self):
12 | pass
--------------------------------------------------------------------------------
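
The three methods above define the whole contract; a new metric only has to override them. A minimal, hypothetical example:

    from utils import metric

    class RunningMean(metric.Metric):
        """Toy metric illustrating the reset/add/value interface."""
        def __init__(self):
            self.total, self.count = 0.0, 0

        def reset(self):
            self.total, self.count = 0.0, 0

        def add(self, value):
            self.total += float(value)
            self.count += 1

        def value(self):
            return self.total / max(self.count, 1)
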
/utils/metrices.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.metrics import f1_score, average_precision_score
4 | from sklearn.metrics import precision_recall_curve, roc_curve
5 |
6 | SMOOTH = 1e-6
7 | __all__ = ['get_f1_scores', 'get_ap_scores', 'batch_pix_accuracy', 'batch_intersection_union', 'get_iou', 'get_pr',
8 | 'get_roc', 'get_ap_multiclass']
9 |
10 |
11 | def get_iou(outputs: torch.Tensor, labels: torch.Tensor):
12 | # You can comment out this line if you are passing tensors of equal shape
13 | # But if you are passing output from UNet or something it will most probably
14 | # be with the BATCH x 1 x H x W shape
15 | outputs = outputs.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
16 | labels = labels.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
17 |
18 | intersection = (outputs & labels).float().sum((1, 2)) # Will be zero if Truth=0 or Prediction=0
19 |     union = (outputs | labels).float().sum((1, 2))         # Will be zero if both are 0
20 |
21 |     iou = (intersection + SMOOTH) / (union + SMOOTH)  # We smooth our division to avoid 0/0
22 |
23 | return iou.cpu().numpy()
24 |
25 |
26 | def get_f1_scores(predict, target, ignore_index=-1):
27 | # Tensor process
28 | batch_size = predict.shape[0]
29 | predict = predict.data.cpu().numpy().reshape(-1)
30 | target = target.data.cpu().numpy().reshape(-1)
31 | pb = predict[target != ignore_index].reshape(batch_size, -1)
32 | tb = target[target != ignore_index].reshape(batch_size, -1)
33 |
34 | total = []
35 | for p, t in zip(pb, tb):
36 | total.append(np.nan_to_num(f1_score(t, p)))
37 |
38 | return total
39 |
40 |
41 | def get_roc(predict, target, ignore_index=-1):
42 | target_expand = target.unsqueeze(1).expand_as(predict)
43 | target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
44 | # Tensor process
45 | x = torch.zeros_like(target_expand)
46 | t = target.unsqueeze(1).clamp(min=0)
47 | target_1hot = x.scatter_(1, t, 1)
48 | batch_size = predict.shape[0]
49 | predict = predict.data.cpu().numpy().reshape(-1)
50 | target = target_1hot.data.cpu().numpy().reshape(-1)
51 | pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
52 | tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)
53 |
54 | total = []
55 | for p, t in zip(pb, tb):
56 | total.append(roc_curve(t, p))
57 |
58 | return total
59 |
60 |
61 | def get_pr(predict, target, ignore_index=-1):
62 | target_expand = target.unsqueeze(1).expand_as(predict)
63 | target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
64 | # Tensor process
65 | x = torch.zeros_like(target_expand)
66 | t = target.unsqueeze(1).clamp(min=0)
67 | target_1hot = x.scatter_(1, t, 1)
68 | batch_size = predict.shape[0]
69 | predict = predict.data.cpu().numpy().reshape(-1)
70 | target = target_1hot.data.cpu().numpy().reshape(-1)
71 | pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
72 | tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)
73 |
74 | total = []
75 | for p, t in zip(pb, tb):
76 | total.append(precision_recall_curve(t, p))
77 |
78 | return total
79 |
80 |
81 | def get_ap_scores(predict, target, ignore_index=-1):
82 | total = []
83 | for pred, tgt in zip(predict, target):
84 | target_expand = tgt.unsqueeze(0).expand_as(pred)
85 | target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
86 |
87 | # Tensor process
88 | x = torch.zeros_like(target_expand)
89 | t = tgt.unsqueeze(0).clamp(min=0).long()
90 | target_1hot = x.scatter_(0, t, 1)
91 | predict_flat = pred.data.cpu().numpy().reshape(-1)
92 | target_flat = target_1hot.data.cpu().numpy().reshape(-1)
93 |
94 | p = predict_flat[target_expand_numpy != ignore_index]
95 | t = target_flat[target_expand_numpy != ignore_index]
96 |
97 | total.append(np.nan_to_num(average_precision_score(t, p)))
98 |
99 | return total
100 |
101 |
102 | def get_ap_multiclass(predict, target):
103 | total = []
104 | for pred, tgt in zip(predict, target):
105 | predict_flat = pred.data.cpu().numpy().reshape(-1)
106 | target_flat = tgt.data.cpu().numpy().reshape(-1)
107 |
108 | total.append(np.nan_to_num(average_precision_score(target_flat, predict_flat)))
109 |
110 | return total
111 |
112 |
113 | def batch_precision_recall(predict, target, thr=0.5):
114 | """Batch Precision Recall
115 | Args:
116 | predict: input 4D tensor
117 | target: label 4D tensor
118 | """
119 | # _, predict = torch.max(predict, 1)
120 |
121 | predict = predict > thr
122 | predict = predict.data.cpu().numpy() + 1
123 | target = target.data.cpu().numpy() + 1
124 |
125 | tp = np.sum(((predict == 2) * (target == 2)) * (target > 0))
126 | fp = np.sum(((predict == 2) * (target == 1)) * (target > 0))
127 | fn = np.sum(((predict == 1) * (target == 2)) * (target > 0))
128 |
129 | precision = float(np.nan_to_num(tp / (tp + fp)))
130 | recall = float(np.nan_to_num(tp / (tp + fn)))
131 |
132 | return precision, recall
133 |
134 |
135 | def batch_pix_accuracy(predict, target):
136 | """Batch Pixel Accuracy
137 | Args:
138 | predict: input 3D tensor
139 | target: label 3D tensor
140 | """
141 |
142 | # for thr in np.linspace(0, 1, slices):
143 |
144 | _, predict = torch.max(predict, 0)
145 | predict = predict.cpu().numpy() + 1
146 | target = target.cpu().numpy() + 1
147 | pixel_labeled = np.sum(target > 0)
148 | pixel_correct = np.sum((predict == target) * (target > 0))
149 | assert pixel_correct <= pixel_labeled, \
150 | "Correct area should be smaller than Labeled"
151 | return pixel_correct, pixel_labeled
152 |
153 |
154 | def batch_intersection_union(predict, target, nclass):
155 | """Batch Intersection of Union
156 | Args:
157 | predict: input 3D tensor
158 | target: label 3D tensor
159 | nclass: number of categories (int)
160 | """
161 | _, predict = torch.max(predict, 0)
162 | mini = 1
163 | maxi = nclass
164 | nbins = nclass
165 | predict = predict.cpu().numpy() + 1
166 | target = target.cpu().numpy() + 1
167 |
168 | predict = predict * (target > 0).astype(predict.dtype)
169 | intersection = predict * (predict == target)
170 | # areas of intersection and union
171 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
172 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
173 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
174 | area_union = area_pred + area_lab - area_inter
175 | assert (area_inter <= area_union).all(), \
176 | "Intersection area should be smaller than Union area"
177 | return area_inter, area_union
178 |
179 |
180 | # ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
181 | def pixel_accuracy(im_pred, im_lab):
182 | im_pred = np.asarray(im_pred)
183 | im_lab = np.asarray(im_lab)
184 |
185 | # Remove classes from unlabeled pixels in gt image.
186 | # We should not penalize detections in unlabeled portions of the image.
187 | pixel_labeled = np.sum(im_lab > 0)
188 | pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
189 | # pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
190 | return pixel_correct, pixel_labeled
191 |
192 |
193 | def intersection_and_union(im_pred, im_lab, num_class):
194 | im_pred = np.asarray(im_pred)
195 | im_lab = np.asarray(im_lab)
196 | # Remove classes from unlabeled pixels in gt image.
197 | im_pred = im_pred * (im_lab > 0)
198 | # Compute area intersection:
199 | intersection = im_pred * (im_pred == im_lab)
200 | area_inter, _ = np.histogram(intersection, bins=num_class - 1,
201 | range=(1, num_class - 1))
202 | # Compute area union:
203 | area_pred, _ = np.histogram(im_pred, bins=num_class - 1,
204 | range=(1, num_class - 1))
205 | area_lab, _ = np.histogram(im_lab, bins=num_class - 1,
206 | range=(1, num_class - 1))
207 | area_union = area_pred + area_lab - area_inter
208 | return area_inter, area_union
209 |
--------------------------------------------------------------------------------
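
Two quick usage sketches for the helpers above (random masks and scores, purely illustrative):

    import torch
    from utils.metrices import get_iou, batch_pix_accuracy

    # get_iou expects binary masks shaped (BATCH, 1, H, W); the channel dim is squeezed away
    pred = (torch.rand(4, 1, 64, 64) > 0.5).long()
    gt = (torch.rand(4, 1, 64, 64) > 0.5).long()
    print(get_iou(pred, gt))                 # per-image IoU, numpy array of length 4

    # batch_pix_accuracy expects per-pixel class scores (K, H, W) and labels (H, W)
    scores = torch.rand(2, 64, 64)
    labels = torch.randint(0, 2, (64, 64))
    correct, labeled = batch_pix_accuracy(scores, labels)
    print(correct / max(labeled, 1))
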
/utils/parallel.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## ECE Department, Rutgers University
4 | ## Email: zhang.hang@rutgers.edu
5 | ## Copyright (c) 2017
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
11 | """Encoding Data Parallel"""
12 | import threading
13 | import functools
14 | import torch
15 | from torch.autograd import Variable, Function
16 | import torch.cuda.comm as comm
17 | from torch.nn.parallel.data_parallel import DataParallel
18 | from torch.nn.parallel.parallel_apply import get_a_var
19 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
20 |
21 | torch_ver = torch.__version__[:3]
22 |
23 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
24 | 'patch_replication_callback']
25 |
26 | def allreduce(*inputs):
27 | """Cross GPU all reduce autograd operation for calculate mean and
28 | variance in SyncBN.
29 | """
30 | return AllReduce.apply(*inputs)
31 |
32 | class AllReduce(Function):
33 | @staticmethod
34 | def forward(ctx, num_inputs, *inputs):
35 | ctx.num_inputs = num_inputs
36 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
37 | inputs = [inputs[i:i + num_inputs]
38 | for i in range(0, len(inputs), num_inputs)]
39 | # sort before reduce sum
40 | inputs = sorted(inputs, key=lambda i: i[0].get_device())
41 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
42 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
43 | return tuple([t for tensors in outputs for t in tensors])
44 |
45 | @staticmethod
46 | def backward(ctx, *inputs):
47 | inputs = [i.data for i in inputs]
48 | inputs = [inputs[i:i + ctx.num_inputs]
49 | for i in range(0, len(inputs), ctx.num_inputs)]
50 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
51 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
52 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
53 |
54 |
55 | class Reduce(Function):
56 | @staticmethod
57 | def forward(ctx, *inputs):
58 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
59 | inputs = sorted(inputs, key=lambda i: i.get_device())
60 | return comm.reduce_add(inputs)
61 |
62 | @staticmethod
63 | def backward(ctx, gradOutput):
64 | return Broadcast.apply(ctx.target_gpus, gradOutput)
65 |
66 |
67 | class DataParallelModel(DataParallel):
68 | """Implements data parallelism at the module level.
69 |
70 | This container parallelizes the application of the given module by
71 | splitting the input across the specified devices by chunking in the
72 | batch dimension.
73 | In the forward pass, the module is replicated on each device,
74 |     and each replica handles a portion of the input. During the backward pass, gradients from each replica are summed into the original module.
75 |     Note that the outputs are not gathered; use the compatible
76 |     :class:`encoding.parallel.DataParallelCriterion`.
77 |
78 | The batch size should be larger than the number of GPUs used. It should
79 | also be an integer multiple of the number of GPUs so that each chunk is
80 | the same size (so that each GPU processes the same number of samples).
81 |
82 | Args:
83 | module: module to be parallelized
84 | device_ids: CUDA devices (default: all devices)
85 |
86 | Reference:
87 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
88 |         Amit Agrawal. “Context Encoding for Semantic Segmentation.”
89 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
90 |
91 | Example::
92 |
93 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
94 | >>> y = net(x)
95 | """
96 | def gather(self, outputs, output_device):
97 | return outputs
98 |
99 | def replicate(self, module, device_ids):
100 | modules = super(DataParallelModel, self).replicate(module, device_ids)
101 | execute_replication_callbacks(modules)
102 | return modules
103 |
104 |
105 | class DataParallelCriterion(DataParallel):
106 | """
107 |     Calculates the loss on multiple GPUs, which balances the memory usage for
108 |     semantic segmentation.
109 |
110 |     The targets are split across the specified devices by chunking in
111 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
112 |
113 | Reference:
114 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
115 |         Amit Agrawal. “Context Encoding for Semantic Segmentation.”
116 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
117 |
118 | Example::
119 |
120 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
121 | >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
122 | >>> y = net(x)
123 | >>> loss = criterion(y, target)
124 | """
125 | def forward(self, inputs, *targets, **kwargs):
126 |         # input should already be scattered
127 | # scattering the targets instead
128 | if not self.device_ids:
129 | return self.module(inputs, *targets, **kwargs)
130 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
131 | if len(self.device_ids) == 1:
132 | return self.module(inputs, *targets[0], **kwargs[0])
133 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
134 | outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
135 | return Reduce.apply(*outputs) / len(outputs)
136 | #return self.gather(outputs, self.output_device).mean()
137 |
138 |
139 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
140 | assert len(modules) == len(inputs)
141 | assert len(targets) == len(inputs)
142 | if kwargs_tup:
143 | assert len(modules) == len(kwargs_tup)
144 | else:
145 | kwargs_tup = ({},) * len(modules)
146 | if devices is not None:
147 | assert len(modules) == len(devices)
148 | else:
149 | devices = [None] * len(modules)
150 |
151 | lock = threading.Lock()
152 | results = {}
153 | if torch_ver != "0.3":
154 | grad_enabled = torch.is_grad_enabled()
155 |
156 | def _worker(i, module, input, target, kwargs, device=None):
157 | if torch_ver != "0.3":
158 | torch.set_grad_enabled(grad_enabled)
159 | if device is None:
160 | device = get_a_var(input).get_device()
161 | try:
162 | with torch.cuda.device(device):
163 | # this also avoids accidental slicing of `input` if it is a Tensor
164 | if not isinstance(input, (list, tuple)):
165 | input = (input,)
166 | if type(input) != type(target):
167 | if isinstance(target, tuple):
168 | input = tuple(input)
169 | elif isinstance(target, list):
170 | input = list(input)
171 | else:
172 | raise Exception("Types problem")
173 |
174 | output = module(*(input + target), **kwargs)
175 | with lock:
176 | results[i] = output
177 | except Exception as e:
178 | with lock:
179 | results[i] = e
180 |
181 | if len(modules) > 1:
182 | threads = [threading.Thread(target=_worker,
183 | args=(i, module, input, target,
184 | kwargs, device),)
185 | for i, (module, input, target, kwargs, device) in
186 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
187 |
188 | for thread in threads:
189 | thread.start()
190 | for thread in threads:
191 | thread.join()
192 | else:
193 |         _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
194 |
195 | outputs = []
196 | for i in range(len(inputs)):
197 | output = results[i]
198 | if isinstance(output, Exception):
199 | raise output
200 | outputs.append(output)
201 | return outputs
202 |
203 |
204 | ###########################################################################
205 | # Adapted from Synchronized-BatchNorm-PyTorch.
206 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
207 | #
208 | class CallbackContext(object):
209 | pass
210 |
211 |
212 | def execute_replication_callbacks(modules):
213 | """
214 |     Execute a replication callback `__data_parallel_replicate__` on each module created
215 |     by the original replication.
216 |
217 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
218 |
219 |     Note that, as all modules are isomorphic, we assign each sub-module a context
220 | (shared among multiple copies of this module on different devices).
221 | Through this context, different copies can share some information.
222 |
223 | We guarantee that the callback on the master copy (the first copy) will be called ahead
224 |     of the callbacks on the other copies.
225 | """
226 | master_copy = modules[0]
227 | nr_modules = len(list(master_copy.modules()))
228 | ctxs = [CallbackContext() for _ in range(nr_modules)]
229 |
230 | for i, module in enumerate(modules):
231 | for j, m in enumerate(module.modules()):
232 | if hasattr(m, '__data_parallel_replicate__'):
233 | m.__data_parallel_replicate__(ctxs[j], i)
234 |
235 |
236 | def patch_replication_callback(data_parallel):
237 | """
238 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
239 | Useful when you have customized `DataParallel` implementation.
240 |
241 | Examples:
242 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
243 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
244 | > patch_replication_callback(sync_bn)
245 | # this is equivalent to
246 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
247 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
248 | """
249 |
250 | assert isinstance(data_parallel, DataParallel)
251 |
252 | old_replicate = data_parallel.replicate
253 |
254 | @functools.wraps(old_replicate)
255 | def new_replicate(module, device_ids):
256 | modules = old_replicate(module, device_ids)
257 | execute_replication_callbacks(modules)
258 | return modules
259 |
260 | data_parallel.replicate = new_replicate
261 |
--------------------------------------------------------------------------------
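
The docstring examples above refer to `encoding.nn`; in this repository the same classes live in `utils.parallel`. A hedged, self-contained sketch, assuming at least two visible CUDA devices (model, sizes, and tensors are illustrative):

    import torch
    import torch.nn as nn
    from utils.parallel import DataParallelModel, DataParallelCriterion

    if torch.cuda.device_count() >= 2:
        device_ids = [0, 1]
        model = DataParallelModel(nn.Linear(10, 3).cuda(), device_ids=device_ids)
        criterion = DataParallelCriterion(nn.CrossEntropyLoss(), device_ids=device_ids)

        inputs = torch.randn(8, 10).cuda()
        targets = torch.randint(0, 3, (8,)).cuda()

        outputs = model(inputs)             # list with one output chunk per GPU (not gathered)
        loss = criterion(outputs, targets)  # targets are scattered internally, losses reduced
        loss.backward()
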
/utils/render.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.cm
3 | import skimage.io
4 | import skimage.feature
5 | import skimage.filters
6 |
7 |
8 | def vec2im(V, shape=()):
9 | '''
10 | Transform an array V into a specified shape - or if no shape is given assume a square output format.
11 |
12 | Parameters
13 | ----------
14 |
15 | V : numpy.ndarray
16 |         an array representing either a matrix or a vector, to be reshaped into a two-dimensional image
17 |
18 | shape : tuple or list
19 |         optional. Contains the shape information for the output array; if not given, the output is assumed to be square
20 |
21 | Returns
22 | -------
23 |
24 | W : numpy.ndarray
25 | with W.shape = shape or W.shape = [np.sqrt(V.size)]*2
26 |
27 | '''
28 |
29 |     if len(shape) < 2:
30 |         shape = [np.sqrt(V.size)] * 2
31 |         shape = tuple(map(int, shape))
32 |     return np.reshape(V, shape)
33 |
34 |
35 | def enlarge_image(img, scaling=3):
36 | '''
37 | Enlarges a given input matrix by replicating each pixel value scaling times in horizontal and vertical direction.
38 |
39 | Parameters
40 | ----------
41 |
42 | img : numpy.ndarray
43 | array of shape [H x W] OR [H x W x D]
44 |
45 | scaling : int
46 | positive integer value > 0
47 |
48 | Returns
49 | -------
50 |
51 | out : numpy.ndarray
52 | two-dimensional array of shape [scaling*H x scaling*W]
53 | OR
54 | three-dimensional array of shape [scaling*H x scaling*W x D]
55 | depending on the dimensionality of the input
56 | '''
57 |
58 |     if scaling < 1 or not isinstance(scaling, int):
59 |         raise ValueError('scaling factor needs to be an int >= 1')
60 |
61 | if len(img.shape) == 2:
62 | H, W = img.shape
63 |
64 | out = np.zeros((scaling * H, scaling * W))
65 | for h in range(H):
66 | fh = scaling * h
67 | for w in range(W):
68 | fw = scaling * w
69 | out[fh:fh + scaling, fw:fw + scaling] = img[h, w]
70 |
71 | elif len(img.shape) == 3:
72 | H, W, D = img.shape
73 |
74 | out = np.zeros((scaling * H, scaling * W, D))
75 | for h in range(H):
76 | fh = scaling * h
77 | for w in range(W):
78 | fw = scaling * w
79 | out[fh:fh + scaling, fw:fw + scaling, :] = img[h, w, :]
80 |
81 | return out
82 |
83 |
84 | def repaint_corner_pixels(rgbimg, scaling=3):
85 | '''
86 | DEPRECATED/OBSOLETE.
87 |
88 | Recolors the top left and bottom right pixel (groups) with the average rgb value of its three neighboring pixel (groups).
89 | The recoloring visually masks the opposing pixel values which are a product of stabilizing the scaling.
90 |     Assumes those image areas will practically never show evidence.
91 |
92 | Parameters
93 | ----------
94 |
95 | rgbimg : numpy.ndarray
96 | array of shape [H x W x 3]
97 |
98 | scaling : int
99 | positive integer value > 0
100 |
101 | Returns
102 | -------
103 |
104 | rgbimg : numpy.ndarray
105 | three-dimensional array of shape [scaling*H x scaling*W x 3]
106 | '''
107 |
108 | # top left corner.
109 | rgbimg[0:scaling, 0:scaling, :] = (rgbimg[0, scaling, :] + rgbimg[scaling, 0, :] + rgbimg[scaling, scaling,
110 | :]) / 3.0
111 | # bottom right corner
112 | rgbimg[-scaling:, -scaling:, :] = (rgbimg[-1, -1 - scaling, :] + rgbimg[-1 - scaling, -1, :] + rgbimg[-1 - scaling,
113 | -1 - scaling,
114 | :]) / 3.0
115 | return rgbimg
116 |
117 |
118 | def digit_to_rgb(X, scaling=3, shape=(), cmap='binary'):
119 | '''
120 |     Takes as input an intensity array and produces an RGB image according to the chosen color map
121 |
122 | Parameters
123 | ----------
124 |
125 | X : numpy.ndarray
126 | intensity matrix as array of shape [M x N]
127 |
128 | scaling : int
129 | optional. positive integer value > 0
130 |
131 |     shape: tuple or list of ints, length = 2
132 | optional. if not given, X is reshaped to be square.
133 |
134 | cmap : str
135 | name of color map of choice. default is 'binary'
136 |
137 | Returns
138 | -------
139 |
140 | image : numpy.ndarray
141 | three-dimensional array of shape [scaling*H x scaling*W x 3] , where H*W == M*N
142 | '''
143 |
144 | # create color map object from name string
145 |     cmap = matplotlib.cm.get_cmap(cmap)
146 |
147 | image = enlarge_image(vec2im(X, shape), scaling) # enlarge
148 | image = cmap(image.flatten())[..., 0:3].reshape([image.shape[0], image.shape[1], 3]) # colorize, reshape
149 |
150 | return image
151 |
152 |
153 | def hm_to_rgb(R, X=None, scaling=3, shape=(), sigma=2, cmap='bwr', normalize=True):
154 | '''
155 |     Takes as input an intensity array and produces an RGB image for the represented heatmap.
156 |     Optionally draws the outline of another input on top of it.
157 |
158 | Parameters
159 | ----------
160 |
161 | R : numpy.ndarray
162 | the heatmap to be visualized, shaped [M x N]
163 |
164 | X : numpy.ndarray
165 |         optional. Some input, usually the data point the heatmap R belongs to, which serves
166 |         as a template for a black outline to be drawn on top of the image,
167 |         shaped [M x N]
168 |
169 | scaling: int
170 |         factor by which to enlarge the heatmap (to control resolution and, inversely, the outline thickness)
171 | after reshaping it using shape.
172 |
173 | shape: tuple or list, length = 2
174 | optional. if not given, X is reshaped to be square.
175 |
176 | sigma : double
177 | optional. sigma-parameter for the canny algorithm used for edge detection. the found edges are drawn as outlines.
178 |
179 | cmap : str
180 | optional. color map of choice
181 |
182 | normalize : bool
183 |         optional. Whether to normalize the heatmap to [-1, 1] prior to colorization.
184 |
185 | Returns
186 | -------
187 |
188 | rgbimg : numpy.ndarray
189 | three-dimensional array of shape [scaling*H x scaling*W x 3] , where H*W == M*N
190 | '''
191 |
192 | # create color map object from name string
193 |     cmap = matplotlib.cm.get_cmap(cmap)
194 |
195 | if normalize:
196 | R = R / np.max(np.abs(R)) # normalize to [-1,1] wrt to max relevance magnitude
197 | R = (R + 1.) / 2. # shift/normalize to [0,1] for color mapping
198 |
199 | R = enlarge_image(R, scaling)
200 | rgb = cmap(R.flatten())[..., 0:3].reshape([R.shape[0], R.shape[1], 3])
201 | # rgb = repaint_corner_pixels(rgb, scaling) #obsolete due to directly calling the color map with [0,1]-normalized inputs
202 |
203 |     if X is not None:  # compute the outline of the input
204 | # X = enlarge_image(vec2im(X,shape), scaling)
205 | xdims = X.shape
206 | Rdims = R.shape
207 |
208 | # if not np.all(xdims == Rdims):
209 | # print 'transformed heatmap and data dimension mismatch. data dimensions differ?'
210 | # print 'R.shape = ',Rdims, 'X.shape = ', xdims
211 | # print 'skipping drawing of outline\n'
212 | # else:
213 | # #edges = skimage.filters.canny(X, sigma=sigma)
214 | # edges = skimage.feature.canny(X, sigma=sigma)
215 | # edges = np.invert(np.dstack([edges]*3))*1.0
216 | # rgb *= edges # set outline pixels to black color
217 |
218 | return rgb
219 |
220 |
221 | def save_image(rgb_images, path, gap=2):
222 | '''
223 | Takes as input a list of rgb images, places them next to each other with a gap and writes out the result.
224 |
225 | Parameters
226 | ----------
227 |
228 |     rgb_images : list or tuple
229 | each item in the collection is expected to be an rgb image of dimensions [H x _ x 3]
230 | where the width is variable
231 |
232 | path : str
233 | the output path of the assembled image
234 |
235 | gap : int
236 | optional. sets the width of a black area of pixels realized as an image shaped [H x gap x 3] in between the input images
237 |
238 | Returns
239 | -------
240 |
241 | image : numpy.ndarray
242 | the assembled image as written out to path
243 | '''
244 |
245 | sz = []
246 | image = []
247 | for i in range(len(rgb_images)):
248 | if not sz:
249 | sz = rgb_images[i].shape
250 | image = rgb_images[i]
251 | gap = np.zeros((sz[0], gap, sz[2]))
252 | continue
253 |         if not (sz[0] == rgb_images[i].shape[0] and sz[2] == rgb_images[i].shape[2]):
254 |             print('image', i, 'differs in size. unable to perform horizontal alignment')
255 |             print('expected: Hx_xD = {0}x_x{1}'.format(sz[0], sz[2]))
256 |             print('got : Hx_xD = {0}x_x{1}'.format(rgb_images[i].shape[0], rgb_images[i].shape[2]))
257 | print('skipping image\n')
258 | else:
259 | image = np.hstack((image, gap, rgb_images[i]))
260 |
261 | image *= 255
262 | image = image.astype(np.uint8)
263 |
264 | print('saving image to ', path)
265 | skimage.io.imsave(path, image)
266 | return image
267 |
--------------------------------------------------------------------------------
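
A small end-to-end sketch of the two most used helpers here (toy relevance map; the output file name is illustrative):

    import numpy as np
    from utils.render import hm_to_rgb, save_image

    R = np.random.randn(14, 14)                     # toy 14 x 14 relevance map
    rgb = hm_to_rgb(R, scaling=16, cmap='bwr')      # (224, 224, 3) float image in [0, 1]
    save_image([rgb, rgb], 'heatmaps.png', gap=4)   # writes both images side by side
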
/utils/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | import glob
5 |
6 |
7 | class Saver(object):
8 |
9 | def __init__(self, args):
10 | self.args = args
11 | self.directory = os.path.join('run', args.train_dataset, args.checkname)
12 | self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*')))
13 | run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0
14 |
15 | self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)))
16 | if not os.path.exists(self.experiment_dir):
17 | os.makedirs(self.experiment_dir)
18 |
19 | def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
20 | """Saves checkpoint to disk"""
21 | filename = os.path.join(self.experiment_dir, filename)
22 | torch.save(state, filename)
23 |
24 | def save_experiment_config(self):
25 | logfile = os.path.join(self.experiment_dir, 'parameters.txt')
26 | log_file = open(logfile, 'w')
27 | p = OrderedDict()
28 | p['train_dataset'] = self.args.train_dataset
29 | p['lr'] = self.args.lr
30 | p['epoch'] = self.args.epochs
31 |
32 | for key, val in p.items():
33 | log_file.write(key + ':' + str(val) + '\n')
34 | log_file.close()
35 |
--------------------------------------------------------------------------------
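
A usage sketch with a hand-built argument namespace (the attribute values are illustrative; Saver only reads train_dataset, checkname, lr, and epochs):

    from argparse import Namespace
    from utils.saver import Saver

    args = Namespace(train_dataset='imagenet', checkname='vit_base', lr=3e-4, epochs=30)
    saver = Saver(args)                    # creates run/imagenet/vit_base/experiment_<id>
    saver.save_experiment_config()         # writes parameters.txt into the experiment dir
    saver.save_checkpoint({'epoch': 0})    # saves checkpoint.pth.tar into the experiment dir
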
/utils/summaries.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.tensorboard import SummaryWriter
3 |
4 |
5 | class TensorboardSummary(object):
6 | def __init__(self, directory):
7 | self.directory = directory
8 | self.writer = SummaryWriter(log_dir=os.path.join(self.directory))
9 |
10 | def add_scalar(self, *args):
11 | self.writer.add_scalar(*args)
--------------------------------------------------------------------------------