├── imgs
│   ├── .DS_Store
│   ├── Figure_1.png
│   ├── task_aware.png
│   ├── DyHead_Block.png
│   ├── scale_attention.png
│   └── spatial_attention.png
├── CODE_OF_CONDUCT.md
├── torch
│   ├── DyHead.py
│   ├── concat_fpn_output.py
│   └── attention_layers.py
├── LICENSE
└── README.md

/imgs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/.DS_Store
--------------------------------------------------------------------------------
/imgs/Figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/Figure_1.png
--------------------------------------------------------------------------------
/imgs/task_aware.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/task_aware.png
--------------------------------------------------------------------------------
/imgs/DyHead_Block.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/DyHead_Block.png
--------------------------------------------------------------------------------
/imgs/scale_attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/scale_attention.png
--------------------------------------------------------------------------------
/imgs/spatial_attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/spatial_attention.png
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
--------------------------------------------------------------------------------
/torch/DyHead.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from attention_layers import Scale_Aware_Layer, Spatial_Aware_Layer, Task_Aware_Layer
from collections import OrderedDict

class DyHead_Block(nn.Module):
    def __init__(self, L, S, C):
        super(DyHead_Block, self).__init__()
        # Saving all dimension sizes of F
        self.L_size = L
        self.S_size = S
        self.C_size = C

        # Initializing all attention layers
        self.scale_attention = Scale_Aware_Layer(s_size=self.S_size)
        self.spatial_attention = Spatial_Aware_Layer(L_size=self.L_size)
        self.task_attention = Task_Aware_Layer(num_channels=self.C_size)

    def forward(self, F_tensor):
        scale_output = self.scale_attention(F_tensor)
        spatial_output = self.spatial_attention(scale_output)
        task_output = self.task_attention(spatial_output)

        return task_output

def DyHead(num_blocks, L, S, C):
    blocks = [('Block_{}'.format(i+1), DyHead_Block(L, S, C)) for i in range(num_blocks)]

    return nn.Sequential(OrderedDict(blocks))
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/torch/concat_fpn_output.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class concat_feature_maps(nn.Module):
    def __init__(self):
        super(concat_feature_maps, self).__init__()

    def forward(self, fpn_output):
        # Collecting the height of each FPN level (excluding the extra 'pool' level) to find the median
        heights = []
        level_tensors = []
        for key, values in fpn_output.items():
            if key != 'pool':
                heights.append(values.shape[2])
                level_tensors.append(values)
        median_height = int(np.median(heights))

        # Resizing every level to (median_height, median_height); F.interpolate
        # upsamples levels that are smaller than the median and downsamples levels that are larger
        for i in range(len(level_tensors)):
            level_tensors[i] = F.interpolate(input=level_tensors[i], size=(median_height, median_height), mode='nearest')

        # Stacking all levels into a tensor of shape (batch_size, levels, C, H, W)
        concat_levels = torch.stack(level_tensors, dim=1)

        # Reshaping the tensor from (batch_size, levels, C, H, W) to (batch_size, levels, HxW=S, C)
        concat_levels = concat_levels.flatten(start_dim=3).transpose(dim0=2, dim1=3)
        return concat_levels
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------


# Dynamic Head: Unifying Object Detection Heads with Attentions

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/dynamic-head-unifying-object-detection-heads/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=dynamic-head-unifying-object-detection-heads)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/dynamic-head-unifying-object-detection-heads/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=dynamic-head-unifying-object-detection-heads)

https://user-images.githubusercontent.com/1438231/122347136-9282e900-cefe-11eb-8b36-ebe08736ec97.mp4

This is the official implementation of the CVPR 2021 paper "Dynamic Head: Unifying Object Detection Heads with Attentions".

_"In this paper, we present a novel dynamic head framework to unify object detection heads with attentions. By coherently combining multiple self-attention mechanisms between feature levels for scale-awareness, among spatial locations for spatial-awareness, and within output channels for task-awareness, the proposed approach significantly improves the representation ability of object detection heads without any computational overhead."_

>[**Dynamic Head: Unifying Object Detection Heads With Attentions**](https://arxiv.org/pdf/2106.08322.pdf)
>
>[Xiyang Dai](https://scholar.google.com/citations?user=QC8RwcoAAAAJ&hl=en), [Yinpeng Chen](https://scholar.google.com/citations?user=V_VpLksAAAAJ&hl=en), [Bin Xiao](https://scholar.google.com/citations?user=t5HZdzoAAAAJ&hl=en), [Dongdong Chen](https://scholar.google.com/citations?user=sYKpKqEAAAAJ&hl=zh-CN), [Mengchen Liu](https://scholar.google.com/citations?user=cOPQtYgAAAAJ&hl=zh-CN), [Lu Yuan](https://scholar.google.com/citations?user=k9TsUVsAAAAJ&hl=en), [Lei Zhang](https://scholar.google.com/citations?user=fIlGZToAAAAJ&hl=en)

### Model Zoo

Code and models are under internal review and will be released soon. Stay tuned!

### Citation

```BibTeX
@InProceedings{Dai_2021_CVPR,
    author    = {Dai, Xiyang and Chen, Yinpeng and Xiao, Bin and Chen, Dongdong and Liu, Mengchen and Yuan, Lu and Zhang, Lei},
    title     = {Dynamic Head: Unifying Object Detection Heads With Attentions},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {7373-7382}
}
```

### Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

------
# My Notes

Hi there, I am a recent undergraduate and am currently looking for ML positions. I had always wanted to learn how to implement code from a paper, and I was happy to implement the DyHead attachment so that it can be used by others.

All the code I wrote uses PyTorch; these are the modules:
1. [`concat_fpn_output.py`](./torch/concat_fpn_output.py) - Takes the output of the FPN, resizes every level to the median height and width of the levels via upsampling or downsampling, and stacks them into a single tensor.
2. [`attention_layers.py`](./torch/attention_layers.py) - Contains the classes for the three attention mechanisms.
    - Big thanks to GitHub user [Islanna](https://github.com/Islanna/), who implemented code from the [Dynamic ReLU paper](https://arxiv.org/pdf/2003.10027.pdf). The Task-aware Attention layer uses the same technique as DyReLU-A, which constructs a dynamic ReLU function that is both spatially and channel-wise shared. I used her code to understand how to implement it and applied the same techniques, but simplified the code for my own learning process; all credit goes to her. This is her repository: https://github.com/Islanna/DynamicReLU.
3. [`DyHead.py`](./torch/DyHead.py) - Contains the classes to construct a single DyHead block or the entire DyHead.

The [`DyHead_Example.ipynb`](./torch/DyHead_Example.ipynb) notebook demonstrates how all the classes above work; I would encourage you to have a look. A short usage sketch is also included at the end of these notes.

The code is not the most efficient, but it is well documented and easy to understand. I am sure it would not be a problem to make it more efficient.

## Future Additions:
The code does not construct a full object detection model with a DyHead. This is partly because I currently need to focus on finding a new position, but also because I was confused about how ROI Pooling is applied to the tensor *F*: its dimensions no longer contain the spatial dimensions, since it was reshaped to be LxSxC rather than LxHxWxC. I would like to hear more about how this is implemented.

So in the future, when I have more time and a better understanding, I would like to implement both one-stage and two-stage detectors using PyTorch's built-in Faster R-CNN modules to easily adapt the inclusion of DyHead for detection purposes.
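
## Usage Sketch

Below is a minimal sketch of how the pieces connect end to end. It is not taken from the example notebook: the FPN output is faked with random tensors, and the batch size, level sizes, and number of blocks are made-up values chosen only for illustration. It also assumes the scripts are run from the [`torch/`](./torch) folder so that the plain `import` statements resolve.

```python
# Illustrative sketch: fake FPN output -> concat_feature_maps -> DyHead blocks.
# All shapes below are assumptions for demonstration, not values from the repo.
import torch
from collections import OrderedDict

from concat_fpn_output import concat_feature_maps
from DyHead import DyHead

# Pretend FPN output: four pyramid levels plus the extra 'pool' level,
# each with C=256 channels ('pool' is skipped by concat_feature_maps).
fpn_output = OrderedDict([
    ('0', torch.randn(2, 256, 32, 32)),
    ('1', torch.randn(2, 256, 16, 16)),
    ('2', torch.randn(2, 256, 8, 8)),
    ('3', torch.randn(2, 256, 4, 4)),
    ('pool', torch.randn(2, 256, 2, 2)),
])

# Resize every level to the median spatial size and flatten to (batch, L, S, C)
F_tensor = concat_feature_maps()(fpn_output)
print(F_tensor.shape)  # torch.Size([2, 4, 144, 256]) -> L=4, S=12*12=144, C=256

# Stack a couple of DyHead blocks and pass the general view F through them
head = DyHead(num_blocks=2, L=4, S=144, C=256)
output = head(F_tensor)
print(output.shape)    # same shape as the input: torch.Size([2, 4, 144, 256])
```

With these heights the median is 12, so S = 12x12 = 144. Each DyHead block keeps the (batch, L, S, C) shape, which is why the blocks can simply be chained inside `nn.Sequential`.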

--------------------------------------------------------------------------------
/torch/attention_layers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import DeformConv2d


class Scale_Aware_Layer(nn.Module):
    # Constructor
    def __init__(self, s_size):
        super(Scale_Aware_Layer, self).__init__()

        # Average Pooling
        self.avg_layer = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

        # 1x1 Conv layer
        self.conv = nn.Conv2d(in_channels=s_size, out_channels=1, kernel_size=1)

        # Hard Sigmoid
        self.hard_sigmoid = nn.Hardsigmoid()

        # ReLU function
        self.relu = nn.ReLU()

    def forward(self, F_tensor):

        # Transposing the input from (batch_size, L, S, C) to (batch_size, S, L, C) so the convolutional layer operates over the level dimension L
        x = F_tensor.transpose(dim0=2, dim1=1)

        # Passing the tensor through the avg pool layer
        x = self.avg_layer(x)

        # Passing the tensor through the Conv layer
        x = self.conv(x)

        # Reshaping the tensor from (batch_size, 1, L, C) to (batch_size, L, 1, C) so it can be multiplied with F
        x = x.transpose(dim0=1, dim1=2)

        # Passing the conv output to ReLU
        x = self.relu(x)

        # Passing the tensor to the hard sigmoid function
        pi_L = self.hard_sigmoid(x)

        # pi_L:     (batch_size, L, 1, C)
        # F_tensor: (batch_size, L, S, C)
        return pi_L * F_tensor

class Spatial_Aware_Layer(nn.Module):
    # Constructor
    def __init__(self, L_size, kernel_height=3, kernel_width=3, padding=1, stride=1, dilation=1, groups=1):
        super(Spatial_Aware_Layer, self).__init__()

        self.in_channels = L_size
        self.out_channels = L_size

        self.kernel_size = (kernel_height, kernel_width)
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.K = kernel_height * kernel_width
        self.groups = groups

        # 3x3 convolution with a 3K-channel output as described in the Deformable ConvNets v2 paper
        self.offset_and_mask_conv = nn.Conv2d(in_channels=self.in_channels,
                                              out_channels=3*self.K,  # 3K depth
                                              kernel_size=self.kernel_size,
                                              stride=self.stride,
                                              padding=self.padding,
                                              dilation=dilation)

        self.deform_conv = DeformConv2d(in_channels=self.in_channels,
                                        out_channels=self.out_channels,
                                        kernel_size=self.kernel_size,
                                        stride=self.stride,
                                        padding=self.padding,
                                        dilation=self.dilation,
                                        groups=self.groups)

    def forward(self, F_tensor):
        # Generating offsets and masks (or modulators) for the convolution operation
        offsets_and_masks = self.offset_and_mask_conv(F_tensor)

        # Separating offsets and masks as described in the Deformable ConvNets v2 paper
        offset = offsets_and_masks[:, :2*self.K, :, :]  # First 2K channels
        mask = torch.sigmoid(offsets_and_masks[:, 2*self.K:, :, :])  # Last K channels, passed through a sigmoid

        # Passing offsets, masks, and F into the deformable conv layer
        spatial_output = self.deform_conv(F_tensor, offset, mask)
        return spatial_output

# DyReLU-A technique from the Dynamic ReLU paper
class DyReLUA(nn.Module):
    def __init__(self, channels, reduction=8, k=2, lambdas=None, init_values=None):
        super(DyReLUA, self).__init__()

        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, 2*k)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

        # Defining lambdas in the form [La1, La2, Lb1, Lb2]
        if lambdas is not None:
            self.lambdas = lambdas
        else:
            # Default lambdas from the DyReLU paper
            self.lambdas = torch.tensor([1.0, 1.0, 0.5, 0.5], dtype=torch.float)

        # Defining initialization values in the form [alpha1, alpha2, beta1, beta2]
        if init_values is not None:
            self.init_values = init_values
        else:
            # Default initialization values from the DyReLU paper
            self.init_values = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float)

    def forward(self, F_tensor):

        # Global average pooling over F
        kernel_size = F_tensor.shape[2:]  # Getting HxW of F
        gap_output = F.avg_pool2d(F_tensor, kernel_size)

        # Flattening gap_output from (batch_size, C, 1, 1) to (batch_size, C)
        gap_output = gap_output.flatten(start_dim=1)

        # Passing the global average output through the fully-connected layers
        x = self.relu(self.fc1(gap_output))
        x = self.fc2(x)

        # Normalization between (-1, 1)
        residuals = 2 * self.sigmoid(x) - 1

        # Getting the values of theta, then separating alphas and betas
        theta = self.init_values + self.lambdas * residuals  # Contains [alpha1(x), alpha2(x), beta1(x), beta2(x)]
        alphas = theta[0, :2]
        betas = theta[0, 2:]

        # Taking the element-wise maximum of the two piecewise linear functions
        output = torch.maximum((alphas[0] * F_tensor + betas[0]), (alphas[1] * F_tensor + betas[1]))

        return output

class Task_Aware_Layer(nn.Module):
    # Defining constructor
    def __init__(self, num_channels):
        super(Task_Aware_Layer, self).__init__()

        # DyReLU-A activation
        self.dynamic_relu = DyReLUA(num_channels)

    def forward(self, F_tensor):
        # Permuting F from (batch_size, L, S, C) to (batch_size, C, L, S) so we can reduce over the LxS dimensions
        F_tensor = F_tensor.permute(0, 3, 1, 2)

        output = self.dynamic_relu(F_tensor)

        # Reversing the permutation
        output = output.permute(0, 2, 3, 1)

        return output
--------------------------------------------------------------------------------