├── LICENSE
├── Readme.md
├── examples
│   ├── both.png
│   ├── both_attention_rollout_0.000_mean.png
│   ├── both_attention_rollout_0.000_min.png
│   ├── both_attention_rollout_0.900_max.png
│   ├── both_attention_rollout_0.950_max.png
│   ├── both_attention_rollout_0.990_max.png
│   ├── both_discard_ratio.gif
│   ├── both_grad_rollout_243_0.900_max.png
│   ├── both_grad_rollout_282_0.900_max.png
│   ├── dogbird.png
│   ├── dogbird_attention_rollout_0.000_mean.png
│   ├── dogbird_attention_rollout_0.900_max.png
│   ├── dogbird_grad_rollout_161_0.900_max.png
│   ├── dogbird_grad_rollout_87_0.900_max.png
│   ├── grad_rollout_161_0.000_max.png
│   ├── grad_rollout_161_0.500_max.png
│   ├── grad_rollout_161_0.900_max.png
│   ├── grad_rollout_87_0.000_max.png
│   ├── input.png
│   ├── plane.png
│   ├── plane2.png
│   ├── plane2_attention_rollout_0.000_mean.png
│   ├── plane2_attention_rollout_0.900_max.png
│   ├── plane_attention_rollout_0.000_max.png
│   ├── plane_attention_rollout_0.000_mean.png
│   ├── plane_attention_rollout_0.000_min.png
│   ├── plane_attention_rollout_0.900_max.png
│   ├── plane_attention_rollout_0.900_min.png
│   └── plane_discard_ratio.gif
├── vit_explain.py
├── vit_grad_rollout.py
└── vit_rollout.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Jacob Gildenblat

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
# Explainability for Vision Transformers (in PyTorch)

This repository implements methods for explainability in Vision Transformers.

See also https://jacobgil.github.io/deeplearning/vision-transformer-explainability


## Currently implemented:
- Attention Rollout (a minimal sketch of the recursion appears below).
- Gradient Attention Rollout for class-specific explainability.
*This is our attempt to further build upon and improve Attention Rollout.*

- TBD: Attention Flow is a work in progress.

Includes some tweaks and tricks to get it working:
- Different attention-head fusion methods.
- Removing the lowest attentions.
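For orientation, this is roughly the recursion that Attention Rollout performs — a minimal sketch that mirrors `rollout()` in `vit_rollout.py` below, assuming the per-layer attention maps have already been fused over the heads into `[tokens, tokens]` matrices (the `attention_rollout` helper is only illustrative, not part of the repository):

``` python
import torch

def attention_rollout(fused_attentions):
    # fused_attentions: one [tokens, tokens] matrix per layer,
    # already fused over the heads (mean/max/min).
    result = torch.eye(fused_attentions[0].size(-1))
    for a in fused_attentions:
        a = (a + torch.eye(a.size(-1))) / 2       # mix in the residual (identity) connection
        a = a / a.sum(dim=-1, keepdim=True)       # re-normalize every row
        result = a @ result                       # propagate attention through the layers
    # Row 0 is the class token; columns 1.. are the image patches.
    return result[0, 1:]
```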


## Usage

- From code
``` python
import torch
from vit_grad_rollout import VITAttentionGradRollout

model = torch.hub.load('facebookresearch/deit:main',
'deit_tiny_patch16_224', pretrained=True)
# input_tensor: a preprocessed image batch, e.g. shape [1, 3, 224, 224]
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9)
mask = grad_rollout(input_tensor, category_index=243)

```

- From the command line:

```
python vit_explain.py --image_path <path to image> --head_fusion <mean/max/min> --discard_ratio <fraction between 0 and 1> --category_index <category index>
```
If `category_index` isn't specified, Attention Rollout will be used;
otherwise Gradient Attention Rollout will be used.

Notice that by default, this uses the 'Tiny' model from [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)
hosted on torch hub.

## Where did the Transformer pay attention in this image?

| Image | Vanilla Attention Rollout | With discard_ratio+max fusion |
| -------------------------|-------------------------|------------------------- |
| ![](examples/both.png) | ![](examples/both_attention_rollout_0.000_mean.png) | ![](examples/both_attention_rollout_0.990_max.png) |
| ![](examples/plane.png) | ![](examples/plane_attention_rollout_0.000_mean.png) | ![](examples/plane_attention_rollout_0.900_max.png) |
| ![](examples/dogbird.png) | ![](examples/dogbird_attention_rollout_0.000_mean.png) | ![](examples/dogbird_attention_rollout_0.900_max.png) |
| ![](examples/plane2.png) | ![](examples/plane2_attention_rollout_0.000_mean.png) | ![](examples/plane2_attention_rollout_0.900_max.png) |

## Gradient Attention Rollout for class-specific explainability

The attention that flows through the transformer passes along information belonging to different classes.
Attention Rollout shows us which locations the network paid attention to,
but it tells us nothing about whether it actually used those locations for the final classification.

We can multiply the attention by the gradient of the target class output, and take the average over the attention heads (while masking out negative attentions), to keep only the attention that contributes to the target category (or categories).


### Where does the Transformer see a Dog (category 243) and a Cat (category 282)?
![](examples/both_grad_rollout_243_0.900_max.png) ![](examples/both_grad_rollout_282_0.900_max.png)

### Where does the Transformer see a Musket dog (category 161) and a Parrot (category 87)?
![](examples/dogbird_grad_rollout_161_0.900_max.png) ![](examples/dogbird_grad_rollout_87_0.900_max.png)


## Tricks and Tweaks to get this working

### Filtering the lowest attentions in every layer

`--discard_ratio <fraction between 0 and 1>`

Removes noise by keeping only the strongest attentions.

Results for different values:

![](examples/both_discard_ratio.gif) ![](examples/plane_discard_ratio.gif)

### Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention across the attention heads,
but empirically it looks like taking the minimum value, or the maximum value combined with `--discard_ratio`, works better.
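Concretely, the head fusion and discarding step looks roughly like this — a minimal sketch mirroring the logic in `vit_rollout.py`, assuming attention maps of shape `[batch, heads, tokens, tokens]` as captured from the `attn_drop` forward hook (the `fuse_heads` helper is illustrative, not part of the repository):

``` python
import torch

@torch.no_grad()
def fuse_heads(attention, head_fusion="max", discard_ratio=0.9):
    # attention: [batch, heads, tokens, tokens]
    if head_fusion == "mean":
        fused = attention.mean(dim=1)
    elif head_fusion == "max":
        fused = attention.max(dim=1)[0]
    elif head_fusion == "min":
        fused = attention.min(dim=1)[0]
    else:
        raise ValueError("head_fusion must be one of mean/max/min")

    # Zero out the lowest discard_ratio fraction of the fused attention values.
    # As in vit_rollout.py, flat index 0 (the class-token -> class-token entry)
    # is never dropped.
    flat = fused.view(fused.size(0), -1)
    _, indices = flat.topk(int(flat.size(-1) * discard_ratio), dim=-1, largest=False)
    indices = indices[indices != 0]
    flat[0, indices] = 0
    return fused
```

`--head_fusion` on the command line selects which of these branches is used.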

`--head_fusion <mean/max/min>`

| Image | Mean Fusion | Min Fusion |
| -------------------------|-------------------------|------------------------- |
| ![](examples/both.png) | ![](examples/both_attention_rollout_0.000_mean.png) | ![](examples/both_attention_rollout_0.000_min.png) |

## References
- [Quantifying Attention Flow in Transformers](https://arxiv.org/abs/2005.00928)
- [timm: a great collection of models in PyTorch](https://github.com/rwightman/pytorch-image-models),
and especially [the vision transformer implementation](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)

- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
- Credit to https://github.com/jeonsworld/ViT-pytorch for being a good starting point.

## Requirements
`pip install timm`

--------------------------------------------------------------------------------
/examples/both.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/both.png
--------------------------------------------------------------------------------
/examples/both_attention_rollout_0.000_mean.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/both_attention_rollout_0.000_mean.png
--------------------------------------------------------------------------------
/examples/both_attention_rollout_0.000_min.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/both_attention_rollout_0.000_min.png
--------------------------------------------------------------------------------
/examples/both_attention_rollout_0.900_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/both_attention_rollout_0.900_max.png
--------------------------------------------------------------------------------
/examples/both_attention_rollout_0.950_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/both_attention_rollout_0.950_max.png
--------------------------------------------------------------------------------
/examples/both_attention_rollout_0.990_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/both_attention_rollout_0.990_max.png
--------------------------------------------------------------------------------
/examples/both_discard_ratio.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/both_discard_ratio.gif
--------------------------------------------------------------------------------
/examples/both_grad_rollout_243_0.900_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/both_grad_rollout_243_0.900_max.png
--------------------------------------------------------------------------------
/examples/both_grad_rollout_282_0.900_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/both_grad_rollout_282_0.900_max.png
--------------------------------------------------------------------------------
/examples/dogbird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/dogbird.png
--------------------------------------------------------------------------------
/examples/dogbird_attention_rollout_0.000_mean.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/dogbird_attention_rollout_0.000_mean.png
--------------------------------------------------------------------------------
/examples/dogbird_attention_rollout_0.900_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/dogbird_attention_rollout_0.900_max.png
--------------------------------------------------------------------------------
/examples/dogbird_grad_rollout_161_0.900_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/dogbird_grad_rollout_161_0.900_max.png
--------------------------------------------------------------------------------
/examples/dogbird_grad_rollout_87_0.900_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/dogbird_grad_rollout_87_0.900_max.png
--------------------------------------------------------------------------------
/examples/grad_rollout_161_0.000_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/grad_rollout_161_0.000_max.png
--------------------------------------------------------------------------------
/examples/grad_rollout_161_0.500_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/grad_rollout_161_0.500_max.png
--------------------------------------------------------------------------------
/examples/grad_rollout_161_0.900_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/grad_rollout_161_0.900_max.png
--------------------------------------------------------------------------------
/examples/grad_rollout_87_0.000_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/grad_rollout_87_0.000_max.png
--------------------------------------------------------------------------------
/examples/input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/input.png
--------------------------------------------------------------------------------
/examples/plane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/plane.png
--------------------------------------------------------------------------------
/examples/plane2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/plane2.png
--------------------------------------------------------------------------------
/examples/plane2_attention_rollout_0.000_mean.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/plane2_attention_rollout_0.000_mean.png
--------------------------------------------------------------------------------
/examples/plane2_attention_rollout_0.900_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/plane2_attention_rollout_0.900_max.png
--------------------------------------------------------------------------------
/examples/plane_attention_rollout_0.000_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/plane_attention_rollout_0.000_max.png
--------------------------------------------------------------------------------
/examples/plane_attention_rollout_0.000_mean.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/plane_attention_rollout_0.000_mean.png
--------------------------------------------------------------------------------
/examples/plane_attention_rollout_0.000_min.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/plane_attention_rollout_0.000_min.png
--------------------------------------------------------------------------------
/examples/plane_attention_rollout_0.900_max.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/plane_attention_rollout_0.900_max.png
--------------------------------------------------------------------------------
/examples/plane_attention_rollout_0.900_min.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/plane_attention_rollout_0.900_min.png
--------------------------------------------------------------------------------
/examples/plane_discard_ratio.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobgil/vit-explain/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/examples/plane_discard_ratio.gif
--------------------------------------------------------------------------------
/vit_explain.py:
--------------------------------------------------------------------------------
import argparse
import sys
import torch
from PIL import Image
from torchvision import transforms
import numpy as np
import cv2

from vit_rollout import VITAttentionRollout
from vit_grad_rollout import VITAttentionGradRollout

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--use_cuda', action='store_true', default=False,
                        help='Use NVIDIA GPU acceleration')
    parser.add_argument('--image_path', type=str, default='./examples/both.png',
                        help='Input image path')
    parser.add_argument('--head_fusion', type=str, default='max',
                        help='How to fuse the attention heads for attention rollout. \
                        Can be mean/max/min')
    parser.add_argument('--discard_ratio', type=float, default=0.9,
                        help='Fraction of the lowest attention values to discard in every layer (between 0 and 1)')
    parser.add_argument('--category_index', type=int, default=None,
                        help='The category index for gradient rollout')
    args = parser.parse_args()
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.use_cuda:
        print("Using GPU")
    else:
        print("Using CPU")

    return args

def show_mask_on_image(img, mask):
    # Overlay the rollout mask on the image as a JET heatmap.
    img = np.float32(img) / 255
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

if __name__ == '__main__':
    args = get_args()
    model = torch.hub.load('facebookresearch/deit:main',
                           'deit_tiny_patch16_224', pretrained=True)
    model.eval()

    if args.use_cuda:
        model = model.cuda()

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    img = Image.open(args.image_path)
    img = img.resize((224, 224))
    input_tensor = transform(img).unsqueeze(0)
    if args.use_cuda:
        input_tensor = input_tensor.cuda()

    if args.category_index is None:
        print("Doing Attention Rollout")
        attention_rollout = VITAttentionRollout(model, head_fusion=args.head_fusion,
                                                discard_ratio=args.discard_ratio)
        mask = attention_rollout(input_tensor)
        name = "attention_rollout_{:.3f}_{}.png".format(args.discard_ratio, args.head_fusion)
    else:
        print("Doing Gradient Attention Rollout")
        grad_rollout = VITAttentionGradRollout(model, discard_ratio=args.discard_ratio)
        mask = grad_rollout(input_tensor, args.category_index)
        name = "grad_rollout_{}_{:.3f}_{}.png".format(args.category_index,
                                                      args.discard_ratio, args.head_fusion)

    np_img = np.array(img)[:, :, ::-1]
    mask = cv2.resize(mask, (np_img.shape[1], np_img.shape[0]))
    mask = show_mask_on_image(np_img, mask)
    cv2.imshow("Input Image", np_img)
    cv2.imshow(name, mask)
    cv2.imwrite("input.png", np_img)
    cv2.imwrite(name, mask)
    cv2.waitKey(-1)
--------------------------------------------------------------------------------
/vit_grad_rollout.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image
import numpy
import sys
from torchvision import transforms
import numpy as np
import cv2

def grad_rollout(attentions, gradients, discard_ratio):
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention, grad in zip(attentions, gradients):
            # Weight each attention map by its gradient, then average over the heads.
            weights = grad
            attention_heads_fused = (attention*weights).mean(axis=1)
            # Keep only positive contributions to the target category.
            attention_heads_fused[attention_heads_fused < 0] = 0

            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False)
            #indices = indices[indices != 0]
            flat[0, indices] = 0

            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0*I)/2
            # Normalize each row so it sums to 1 (keepdim keeps the broadcast per row).
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)

    # Look at the total attention between the class token,
    # and the image patches
    mask = result[0, 0, 1:]
    # In case of 224x224 image, this brings us from 196 to 14
    width = int(mask.size(-1)**0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask

class VITAttentionGradRollout:
    def __init__(self, model, attention_layer_name='attn_drop',
                 discard_ratio=0.9):
        self.model = model
        self.discard_ratio = discard_ratio
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(self.get_attention)
                # Note: register_backward_hook is deprecated in recent PyTorch;
                # register_full_backward_hook is the modern equivalent.
                module.register_backward_hook(self.get_attention_gradient)

        self.attentions = []
        self.attention_gradients = []

    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        self.attention_gradients.append(grad_input[0].cpu())

    def __call__(self, input_tensor, category_index):
        # Reset the stored maps so repeated calls don't accumulate.
        self.attentions = []
        self.attention_gradients = []
        self.model.zero_grad()
        output = self.model(input_tensor)
        # One-hot mask on the same device/dtype as the output logits.
        category_mask = torch.zeros_like(output)
        category_mask[:, category_index] = 1
        loss = (output*category_mask).sum()
        loss.backward()

        return grad_rollout(self.attentions, self.attention_gradients,
                            self.discard_ratio)
--------------------------------------------------------------------------------
/vit_rollout.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image
import numpy
import sys
from torchvision import transforms
import numpy as np
import cv2

def rollout(attentions, discard_ratio, head_fusion):
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            # Fuse the heads: [batch, heads, tokens, tokens] -> [batch, tokens, tokens]
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(axis=1)[0]
            elif head_fusion == "min":
                attention_heads_fused = attention.min(axis=1)[0]
            else:
                raise ValueError("Attention head fusion type not supported")

            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0

            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0*I)/2
            # Normalize each row so it sums to 1 (keepdim keeps the broadcast per row).
            a = a / a.sum(dim=-1, keepdim=True)

            result = torch.matmul(a, result)

    # Look at the total attention between the class token,
    # and the image patches
    mask = result[0, 0, 1:]
    # In case of 224x224 image, this brings us from 196 to 14
    width = int(mask.size(-1)**0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask

class VITAttentionRollout:
    def __init__(self, model, attention_layer_name='attn_drop', head_fusion="mean",
                 discard_ratio=0.9):
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(self.get_attention)

        self.attentions = []

    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())

    def __call__(self, input_tensor):
        self.attentions = []
        with torch.no_grad():
            output = self.model(input_tensor)

        return rollout(self.attentions, self.discard_ratio, self.head_fusion)
--------------------------------------------------------------------------------
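A side note on the row normalization used in `rollout()` and `grad_rollout()` above: for a `[batch, tokens, tokens]` tensor, `a.sum(dim=-1)` without `keepdim=True` broadcasts against the last axis rather than per row, which is why the normalization keeps the dimension. A small self-contained check (nothing assumed beyond `torch`):

``` python
import torch

a = torch.rand(1, 4, 4)                     # [batch, tokens, tokens]
wrong = a / a.sum(dim=-1)                   # entry [0, i, j] gets divided by the sum of row j
right = a / a.sum(dim=-1, keepdim=True)     # every row i is divided by its own sum

print(wrong.sum(dim=-1))                    # rows generally do not sum to 1
print(right.sum(dim=-1))                    # each row sums to 1, as the rollout recursion expects
```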