├── README.md
├── global_local_attention_module_pytorch
│   ├── __init__.py
│   ├── glam.py
│   ├── global_channel_attention.py
│   ├── global_spatial_attention.py
│   ├── local_channel_attention.py
│   └── local_spatial_attention.py
└── setup.py

/README.md:
--------------------------------------------------------------------------------
# Global-Local-Attention-Module-Pytorch
### [This implementation has been moved to this [repo](https://github.com/LinkAnJarad/Torch-Modules-Compilation)]

Unofficial implementation of the [Global Local Attention Module (GLAM)](https://arxiv.org/pdf/2107.08000.pdf) in PyTorch

![image](https://user-images.githubusercontent.com/79294502/192976117-67fa4a17-eec0-4dda-987d-3c1fc2ffe554.png)

## [Paper](https://arxiv.org/pdf/2107.08000.pdf)

Song, C. H., Han, H. J., & Avrithis, Y. (2022). All the attention you need: Global-local, spatial-channel attention for image retrieval. In *Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision* (pp. 2754-2763).

## Usage

```python
import torch
from global_local_attention_module_pytorch import GLAM

feature_maps = torch.randn(16, 32, 8, 8)  # shape (batch_size, num_channels, height, width)
glam = GLAM(in_channels=32, num_reduced_channels=16, feature_map_size=8, kernel_size=5)

glam(feature_maps)  # shape (16, 32, 8, 8), same as input
```
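
Since GLAM returns a tensor with the same shape as its input, it can be dropped into a model wherever the channel count and spatial size of the feature map are known. Below is a minimal sketch of inserting it into a small CNN classifier; the backbone layers and sizes are illustrative assumptions, not part of this repository.

```python
import torch
from torch import nn
from global_local_attention_module_pytorch import GLAM

# Toy classifier: two conv stages, GLAM on the resulting 32x8x8 feature map,
# then global pooling and a linear head. All layer sizes are illustrative only.
model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),   # 3x32x32 -> 32x16x16
    nn.ReLU(),
    nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),  # 32x16x16 -> 32x8x8
    nn.ReLU(),
    GLAM(in_channels=32, num_reduced_channels=16, feature_map_size=8, kernel_size=5),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(32, 10),
)

images = torch.randn(4, 3, 32, 32)
logits = model(images)  # shape (4, 10)
```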
## Arguments

* `in_channels (int)`: number of channels of the input feature map
* `num_reduced_channels (int)`: number of channels that the local and global spatial attention modules reduce the input feature map to. Refer to Figures 3 and 5 in the paper.
* `feature_map_size (int)`: height/width of the feature map. *The height/width of the input feature maps must be at least 7, due to the 7x7 convolution (3x3 dilated conv) in the module.*
* `kernel_size (int)`: scope of the inter-channel attention
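
The four attention blocks that make up GLAM are also exported from the package (see `__init__.py` below), so they can be called directly. The following is a small sketch using the same shapes as above; the two spatial modules take the output of the matching channel module as their second argument, mirroring `GLAM.forward`.

```python
import torch
from global_local_attention_module_pytorch import (
    LocalChannelAttention, LocalSpatialAttention,
    GlobalChannelAttention, GlobalSpatialAttention,
)

x = torch.randn(16, 32, 8, 8)  # (batch_size, num_channels, height, width)

local_channel = LocalChannelAttention(feature_map_size=8, kernel_size=5)
local_spatial = LocalSpatialAttention(in_channels=32, num_reduced_channels=16)
global_channel = GlobalChannelAttention(feature_map_size=8, kernel_size=5)
global_spatial = GlobalSpatialAttention(in_channels=32, num_reduced_channels=16)

local_att = local_spatial(x, local_channel(x))     # local channel -> local spatial
global_att = global_spatial(x, global_channel(x))  # global channel -> global spatial

print(local_att.shape, global_att.shape)  # both torch.Size([16, 32, 8, 8])
```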

--------------------------------------------------------------------------------
/global_local_attention_module_pytorch/__init__.py:
--------------------------------------------------------------------------------
from .glam import GLAM
from .local_channel_attention import LocalChannelAttention
from .global_channel_attention import GlobalChannelAttention
from .local_spatial_attention import LocalSpatialAttention
from .global_spatial_attention import GlobalSpatialAttention

--------------------------------------------------------------------------------
/global_local_attention_module_pytorch/glam.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from .local_channel_attention import LocalChannelAttention
from .global_channel_attention import GlobalChannelAttention
from .local_spatial_attention import LocalSpatialAttention
from .global_spatial_attention import GlobalSpatialAttention


class GLAM(nn.Module):

    def __init__(self, in_channels, num_reduced_channels, feature_map_size, kernel_size):
        '''
        Song, C. H., Han, H. J., & Avrithis, Y. (2022). All the attention you need: Global-local, spatial-channel attention for image retrieval. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 2754-2763).

        Args:
            in_channels (int): number of channels of the input feature map
            num_reduced_channels (int): number of channels that the local and global spatial attention modules reduce the input feature map to. Refer to Figures 3 and 5 in the paper.
            feature_map_size (int): height/width of the feature map
            kernel_size (int): scope of the inter-channel attention
        '''

        super().__init__()

        self.local_channel_att = LocalChannelAttention(feature_map_size, kernel_size)
        self.local_spatial_att = LocalSpatialAttention(in_channels, num_reduced_channels)
        self.global_channel_att = GlobalChannelAttention(feature_map_size, kernel_size)
        self.global_spatial_att = GlobalSpatialAttention(in_channels, num_reduced_channels)

        self.fusion_weights = nn.Parameter(torch.Tensor([0.333, 0.333, 0.333]))  # equal initial weights

    def forward(self, x):
        local_channel_att = self.local_channel_att(x)              # local channel attention
        local_att = self.local_spatial_att(x, local_channel_att)   # local spatial attention
        global_channel_att = self.global_channel_att(x)            # global channel attention
        global_att = self.global_spatial_att(x, global_channel_att)  # global spatial attention

        local_att = local_att.unsqueeze(1)    # unsqueeze to prepare for concat
        global_att = global_att.unsqueeze(1)  # unsqueeze to prepare for concat
        x = x.unsqueeze(1)                    # unsqueeze to prepare for concat

        all_feature_maps = torch.cat((local_att, x, global_att), dim=1)
        weights = self.fusion_weights.softmax(-1).reshape(1, 3, 1, 1, 1)
        fused_feature_maps = (all_feature_maps * weights).sum(1)

        return fused_feature_maps

--------------------------------------------------------------------------------
/global_local_attention_module_pytorch/global_channel_attention.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class GlobalChannelAttention(nn.Module):
    def __init__(self, feature_map_size, kernel_size):
        super().__init__()
        assert (kernel_size % 2 == 1), "Kernel size must be odd"

        self.conv_q = nn.Conv1d(1, 1, kernel_size, 1, padding=(kernel_size - 1) // 2)
        self.conv_k = nn.Conv1d(1, 1, kernel_size, 1, padding=(kernel_size - 1) // 2)
        self.GAP = nn.AvgPool2d(feature_map_size)

    def forward(self, x):
        N, C, H, W = x.shape

        # global average pooling gives one descriptor per channel
        query = key = self.GAP(x).reshape(N, 1, C)
        query = self.conv_q(query).sigmoid()
        key = self.conv_k(key).sigmoid().permute(0, 2, 1)
        # channel-by-channel affinity matrix, normalized with a softmax
        query_key = torch.bmm(key, query).reshape(N, -1)
        query_key = query_key.softmax(-1).reshape(N, C, C)

        value = x.permute(0, 2, 3, 1).reshape(N, -1, C)
        att = torch.bmm(value, query_key).permute(0, 2, 1)
        att = att.reshape(N, C, H, W)
        return x * att

--------------------------------------------------------------------------------
/global_local_attention_module_pytorch/global_spatial_attention.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class GlobalSpatialAttention(nn.Module):
    def __init__(self, in_channels, num_reduced_channels):
        super().__init__()

        self.conv1x1_q = nn.Conv2d(in_channels, num_reduced_channels, 1, 1)
        self.conv1x1_k = nn.Conv2d(in_channels, num_reduced_channels, 1, 1)
        self.conv1x1_v = nn.Conv2d(in_channels, num_reduced_channels, 1, 1)
        self.conv1x1_att = nn.Conv2d(num_reduced_channels, in_channels, 1, 1)

    def forward(self, feature_maps, global_channel_output):
        query = self.conv1x1_q(feature_maps)
        N, C, H, W = query.shape
        query = query.reshape(N, C, -1)
        key = self.conv1x1_k(feature_maps).reshape(N, C, -1)

        # pixel-by-pixel affinity matrix, normalized with a softmax
        query_key = torch.bmm(key.permute(0, 2, 1), query)
        query_key = query_key.reshape(N, -1).softmax(-1)
        query_key = query_key.reshape(N, int(H * W), int(H * W))
        value = self.conv1x1_v(feature_maps).reshape(N, C, -1)
        att = torch.bmm(value, query_key).reshape(N, C, H, W)
        att = self.conv1x1_att(att)  # project back to the original number of channels

        return (global_channel_output * att) + global_channel_output

--------------------------------------------------------------------------------
/global_local_attention_module_pytorch/local_channel_attention.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class LocalChannelAttention(nn.Module):
    def __init__(self, feature_map_size, kernel_size):
        super().__init__()
        assert (kernel_size % 2 == 1), "Kernel size must be odd"

        self.conv = nn.Conv1d(1, 1, kernel_size, 1, padding=(kernel_size - 1) // 2)
        self.GAP = nn.AvgPool2d(feature_map_size)

    def forward(self, x):
        N, C, H, W = x.shape
        att = self.GAP(x).reshape(N, 1, C)
        att = self.conv(att).sigmoid()
        att = att.reshape(N, C, 1, 1)
        return (x * att) + x

--------------------------------------------------------------------------------
/global_local_attention_module_pytorch/local_spatial_attention.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class LocalSpatialAttention(nn.Module):
    def __init__(self, in_channels, num_reduced_channels):
        super().__init__()

        self.conv1x1_1 = nn.Conv2d(in_channels, num_reduced_channels, 1, 1)
        self.conv1x1_2 = nn.Conv2d(int(num_reduced_channels * 4), 1, 1, 1)

        # 3x3 convs with dilation 1/2/3 give effective receptive fields of 3x3, 5x5 and 7x7
        self.dilated_conv3x3 = nn.Conv2d(num_reduced_channels, num_reduced_channels, 3, 1, padding=1)
        self.dilated_conv5x5 = nn.Conv2d(num_reduced_channels, num_reduced_channels, 3, 1, padding=2, dilation=2)
        self.dilated_conv7x7 = nn.Conv2d(num_reduced_channels, num_reduced_channels, 3, 1, padding=3, dilation=3)

    def forward(self, feature_maps, local_channel_output):
        att = self.conv1x1_1(feature_maps)
        d1 = self.dilated_conv3x3(att)
        d2 = self.dilated_conv5x5(att)
        d3 = self.dilated_conv7x7(att)
        att = torch.cat((att, d1, d2, d3), dim=1)
        att = self.conv1x1_2(att)
        return (local_channel_output * att) + local_channel_output

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

VERSION = '0.0.1'
DESCRIPTION = 'Unofficial Implementation of the Global Local Attention Module (GLAM) in PyTorch'

setup(
    name="global_local_attention_module_pytorch",
    version=VERSION,
    author="Link An Jarad",
    description=DESCRIPTION,
    url="https://github.com/LinkAnJarad/global_local_attention_module_pytorch",
    packages=find_packages(),
    install_requires=['torch']
)
--------------------------------------------------------------------------------