├── .gitignore
├── LICENSE
├── README.md
├── efficient_attention.py
├── illustration.png
└── test.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Shen Zhuoran

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Efficient Attention

An implementation of the [efficient attention](https://arxiv.org/abs/1812.01243) module.

## Description

![](illustration.png)

Efficient attention is an attention mechanism that substantially reduces memory and computational costs while retaining **exactly** the same expressive power as conventional dot-product attention. The illustration above compares the two mechanisms. The efficient attention module is a drop-in replacement for the non-local module ([Wang et al., 2018](https://arxiv.org/abs/1711.07971)), while it:

- uses fewer resources to achieve the same accuracy;
- achieves higher accuracy under the same resource constraints (by allowing more insertions); and
- is applicable in domains and models where the non-local module is not (due to resource constraints).

## Resources

YouTube:

- Presentation: https://youtu.be/_wnjhTM04NM

bilibili (for users in Mainland China):

- Presentation: https://www.bilibili.com/video/BV1tK4y1f7Rm
- Presentation in Chinese: https://www.bilibili.com/video/bv1Gt4y1Y7E3

## Implementation details

This repository implements the efficient attention module with softmax normalization, output reprojection, and a residual connection.

## Features not in the paper

This repository additionally implements the multi-head mechanism, which was not in the paper. To learn more about the mechanism, refer to [Vaswani et al.](https://arxiv.org/abs/1706.03762)
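## Usage

A minimal usage sketch (added here for illustration; it relies only on the `EfficientAttention` constructor defined in `efficient_attention.py` below, and the batch and feature-map sizes are arbitrary). Note that this implementation expects `key_channels` and `value_channels` to be divisible by `head_count`:

```python
import torch
from efficient_attention import EfficientAttention

# A batch of 4 feature maps with 32 channels at 64 x 64 resolution (NCHW).
x = torch.randn(4, 32, 64, 64)

module = EfficientAttention(
    in_channels=32, key_channels=16, head_count=4, value_channels=16
)

y = module(x)
# The residual connection keeps the output shape equal to the input shape,
# which is what makes the module a drop-in replacement.
assert y.shape == x.shape
```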
## Citation

The [paper](https://arxiv.org/abs/1812.01243) will appear at WACV 2021. If you use, compare with, or refer to this work, please cite:

```bibtex
@inproceedings{shen2021efficient,
    author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
    title = {Efficient Attention: Attention with Linear Complexities},
    booktitle = {WACV},
    year = {2021},
}
```

--------------------------------------------------------------------------------
/efficient_attention.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as f


class EfficientAttention(nn.Module):

    def __init__(self, in_channels, key_channels, head_count, value_channels):
        super().__init__()
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels

        # 1 x 1 convolutions project the input into keys, queries, and
        # values; a final 1 x 1 convolution reprojects the aggregated
        # values back to the input channel count.
        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, input_):
        n, _, h, w = input_.size()
        # Flatten the spatial dimensions to (n, channels, h * w).
        keys = self.keys(input_).reshape(n, self.key_channels, h * w)
        queries = self.queries(input_).reshape(n, self.key_channels, h * w)
        values = self.values(input_).reshape(n, self.value_channels, h * w)
        head_key_channels = self.key_channels // self.head_count
        head_value_channels = self.value_channels // self.head_count

        attended_values = []
        for i in range(self.head_count):
            # Keys are normalized over positions, queries over channels.
            key = f.softmax(keys[
                :,
                i * head_key_channels: (i + 1) * head_key_channels,
                :
            ], dim=2)
            query = f.softmax(queries[
                :,
                i * head_key_channels: (i + 1) * head_key_channels,
                :
            ], dim=1)
            value = values[
                :,
                i * head_value_channels: (i + 1) * head_value_channels,
                :
            ]
            # The context matrix is only head_key_channels by
            # head_value_channels, independent of h * w, which is what
            # gives the module its linear complexity.
            context = key @ value.transpose(1, 2)
            attended_value = (
                context.transpose(1, 2) @ query
            ).reshape(n, head_value_channels, h, w)
            attended_values.append(attended_value)

        aggregated_values = torch.cat(attended_values, dim=1)
        reprojected_value = self.reprojection(aggregated_values)
        # Residual connection.
        attention = reprojected_value + input_

        return attention

--------------------------------------------------------------------------------
/illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmsflash/efficient-attention/46a5f9eaf09470affb0ab30932b7748cc3c871ef/illustration.png

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from efficient_attention import EfficientAttention


# Smoke test: a 1-channel 2 x 2 input through a 2-head module.
x = torch.tensor([[[[1, 1], [1, 1]]]], dtype=torch.float32)
print(EfficientAttention(1, 2, 2, 2)(x))

--------------------------------------------------------------------------------
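As an illustration of why the module above runs in linear rather than quadratic complexity, the following standalone sketch (not part of the repository; the sizes `n`, `d_k`, and `d_v` are arbitrary) checks the associativity identity behind the factorization, (Q K^T) V = Q (K^T V), under which the n x n matrix of pairwise similarities never has to be materialized:

```python
import torch

n, d_k, d_v = 1024, 32, 32  # positions, key channels, value channels
q = torch.randn(n, d_k, dtype=torch.float64)
k = torch.randn(n, d_k, dtype=torch.float64)
v = torch.randn(n, d_v, dtype=torch.float64)

quadratic = (q @ k.t()) @ v  # materializes an n x n similarity matrix
linear = q @ (k.t() @ v)     # materializes only a d_k x d_v context matrix

print(torch.allclose(quadratic, linear))  # True, up to floating-point error
```

The module's softmax variant normalizes keys and queries separately before multiplying, so it uses the same multiplication order, and hence retains the same linear complexity.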