├── imgs
│   ├── .DS_Store
│   ├── Figure_1.png
│   ├── task_aware.png
│   ├── DyHead_Block.png
│   ├── scale_attention.png
│   └── spatial_attention.png
├── CODE_OF_CONDUCT.md
├── torch
│   ├── DyHead.py
│   ├── concat_fpn_output.py
│   └── attention_layers.py
├── LICENSE
└── README.md
/imgs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/.DS_Store
--------------------------------------------------------------------------------
/imgs/Figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/Figure_1.png
--------------------------------------------------------------------------------
/imgs/task_aware.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/task_aware.png
--------------------------------------------------------------------------------
/imgs/DyHead_Block.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/DyHead_Block.png
--------------------------------------------------------------------------------
/imgs/scale_attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/scale_attention.png
--------------------------------------------------------------------------------
/imgs/spatial_attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Coldestadam/DynamicHead/HEAD/imgs/spatial_attention.png
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/torch/DyHead.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from attention_layers import Scale_Aware_Layer, Spatial_Aware_Layer, Task_Aware_Layer
3 | from collections import OrderedDict
4 |
5 | class DyHead_Block(nn.Module):
6 | def __init__(self, L, S, C):
7 | super(DyHead_Block, self).__init__()
8 | # Saving all dimension sizes of F
9 | self.L_size = L
10 | self.S_size = S
11 | self.C_size = C
12 |
13 |         # Initializing all attention layers
14 | self.scale_attention = Scale_Aware_Layer(s_size=self.S_size)
15 | self.spatial_attention = Spatial_Aware_Layer(L_size=self.L_size)
16 | self.task_attention = Task_Aware_Layer(num_channels=self.C_size)
17 |
18 | def forward(self, F_tensor):
19 | scale_output = self.scale_attention(F_tensor)
20 |         spatial_output = self.spatial_attention(scale_output)
21 |         task_output = self.task_attention(spatial_output)
22 |
23 | return task_output
24 |
25 | def DyHead(num_blocks, L, S, C):
26 |     blocks = [('Block_{}'.format(i+1), DyHead_Block(L, S, C)) for i in range(num_blocks)]
27 |
28 | return nn.Sequential(OrderedDict(blocks))
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/torch/concat_fpn_output.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class concat_feature_maps(nn.Module):
7 | def __init__(self):
8 | super(concat_feature_maps, self).__init__()
9 |
10 | def forward(self, fpn_output):
11 |         # Calculating the median height that each FPN level will be upsampled or downsampled to
12 | heights = []
13 | level_tensors = []
14 | for key, values in fpn_output.items():
15 | if key != 'pool':
16 | heights.append(values.shape[2])
17 | level_tensors.append(values)
18 | median_height = int(np.median(heights))
19 |
20 |         # Upsampling or downsampling tensors to the median height and width
21 |         for i in range(len(level_tensors)):
22 |             level = level_tensors[i]
23 |             # Resize any level whose spatial size differs from the median;
24 |             # F.interpolate with mode='nearest' handles both downsampling and upsampling
25 |             if level.shape[2] != median_height:
26 |                 level = F.interpolate(input=level, size=(median_height, median_height), mode='nearest')
27 |             level_tensors[i] = level
30 |
31 |         # Stacking all levels into one tensor with dimensions (batch_size, levels, C, H, W)
32 | concat_levels = torch.stack(level_tensors, dim=1)
33 |
34 | # Reshaping tensor from (batch_size, levels, C, H, W) to (batch_size, levels, HxW=S, C)
35 | concat_levels = concat_levels.flatten(start_dim=3).transpose(dim0=2, dim1=3)
36 | return concat_levels
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Dynamic Head: Unifying Object Detection Heads with Attentions
5 |
6 |
7 | [Object Detection on COCO minival (Papers with Code)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=dynamic-head-unifying-object-detection-heads)
8 | [Object Detection on COCO test-dev (Papers with Code)](https://paperswithcode.com/sota/object-detection-on-coco?p=dynamic-head-unifying-object-detection-heads)
9 |
10 | https://user-images.githubusercontent.com/1438231/122347136-9282e900-cefe-11eb-8b36-ebe08736ec97.mp4
11 |
12 |
13 | This is the official implementation of the CVPR 2021 paper "Dynamic Head: Unifying Object Detection Heads with Attentions".
14 |
15 | _"In this paper, we present a novel dynamic head framework to unify object detection heads with attentions.
16 | By coherently combining multiple self-attention mechanisms between feature levels for scale-awareness, among spatial locations for spatial-awareness, and within output channels for task-awareness, the proposed approach significantly improves the representation ability of object detection heads without any computational overhead."_
17 |
18 |
19 | >[**Dynamic Head: Unifying Object Detection Heads With Attentions**](https://arxiv.org/pdf/2106.08322.pdf)
20 | >
21 | >[Xiyang Dai](https://scholar.google.com/citations?user=QC8RwcoAAAAJ&hl=en), [Yinpeng Chen](https://scholar.google.com/citations?user=V_VpLksAAAAJ&hl=en), [Bin Xiao](https://scholar.google.com/citations?user=t5HZdzoAAAAJ&hl=en), [Dongdong Chen](https://scholar.google.com/citations?user=sYKpKqEAAAAJ&hl=zh-CN), [Mengchen Liu](https://scholar.google.com/citations?user=cOPQtYgAAAAJ&hl=zh-CN), [Lu Yuan](https://scholar.google.com/citations?user=k9TsUVsAAAAJ&hl=en), [Lei Zhang](https://scholar.google.com/citations?user=fIlGZToAAAAJ&hl=en)
22 |
23 |
24 |
25 | ### Model Zoo
26 |
27 | Code and models are under internal review and will be released soon. Stay tuned!
28 |
29 |
30 | ### Citation
31 |
32 | ```BibTeX
33 | @InProceedings{Dai_2021_CVPR,
34 | author = {Dai, Xiyang and Chen, Yinpeng and Xiao, Bin and Chen, Dongdong and Liu, Mengchen and Yuan, Lu and Zhang, Lei},
35 | title = {Dynamic Head: Unifying Object Detection Heads With Attentions},
36 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
37 | month = {June},
38 | year = {2021},
39 | pages = {7373-7382}
40 | }
41 | ```
42 |
43 |
44 |
45 | ### Contributing
46 |
47 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
48 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
49 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
50 |
51 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
52 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
53 | provided by the bot. You will only need to do this once across all repos using our CLA.
54 |
55 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
56 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
57 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
58 |
59 | ------
60 | # My Notes
61 |
62 | Hi there, I recently graduated from my undergraduate program and am currently looking for ML positions. I have always wanted to learn how to implement code from a paper, and I was happy to implement the DyHead attachment so that others can use it.
63 |
64 | All the code I wrote uses PyTorch. These are the modules (a usage sketch follows this list):
65 | 1. [`concat_fpn_output.py`](./torch/concat_fpn_output.py) - This takes the output of the FPN, resizes every level to the median height and width of all the levels via upsampling or downsampling, and concatenates the levels into a single tensor.
66 | 2. [`attention_layers.py`](./torch/attention_layers.py) - This contains all the classes for the three attention mechanisms.
67 |    - Big thanks to GitHub user [Islanna](https://github.com/Islanna/), who implemented code from the [Dynamic ReLU paper](https://arxiv.org/pdf/2003.10027.pdf). The Task-aware attention layer uses the DyReLU-A technique, which constructs a dynamic ReLU function that is both spatially and channel-wise shared. I used her code to understand how to implement it and applied the same techniques, simplifying the code for my own learning process; all credit goes to her. This is her repository: https://github.com/Islanna/DynamicReLU.
68 | 
69 | 3. [`DyHead.py`](./torch/DyHead.py) - This contains the classes to construct a single DyHead block or the entire DyHead.
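
To give a quick sense of how these modules fit together, here is a minimal sketch. It uses randomly generated feature maps in the dictionary format produced by torchvision's FPN backbones (string keys plus a `'pool'` level, which `concat_feature_maps` skips), picks arbitrary example sizes, and assumes it is run from inside the `torch/` folder so the local imports resolve:

```python
import torch
from collections import OrderedDict

from concat_fpn_output import concat_feature_maps
from DyHead import DyHead

# Fake FPN output: four levels plus 'pool', each shaped (batch_size, C, H, W) with C=256
fpn_output = OrderedDict([
    ('0', torch.randn(2, 256, 32, 32)),
    ('1', torch.randn(2, 256, 16, 16)),
    ('2', torch.randn(2, 256, 8, 8)),
    ('3', torch.randn(2, 256, 4, 4)),
    ('pool', torch.randn(2, 256, 2, 2)),  # ignored by concat_feature_maps
])

# Resize every level to the median size and stack them into (batch_size, L, S, C)
F_tensor = concat_feature_maps()(fpn_output)   # here: (2, 4, 12*12, 256)

# Build a DyHead with 6 blocks and pass the concatenated features through it
L, S, C = F_tensor.shape[1], F_tensor.shape[2], F_tensor.shape[3]
dyhead = DyHead(num_blocks=6, L=L, S=S, C=C)
output = dyhead(F_tensor)                       # same shape as F_tensor
print(output.shape)                             # torch.Size([2, 4, 144, 256])
```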
70 |
71 | The [`DyHead_Example.ipynb`](./torch/DyHead_Example.ipynb) notebook demonstrates how all the classes above work; I would encourage you to have a look.
72 |
73 | The code is not the most efficient, but it is well documented and easy to understand. Making it more efficient should not be a problem.
74 |
75 | ## Future Additions
76 | The code does not construct a full object detection model with a DyHead. This is partly because I currently need to focus on finding a new position, but also because I was confused about how ROI pooling would be applied to the tensor *F*: its dimensions no longer contain the spatial dimensions, since it was reshaped to LxSxC rather than LxHxWxC. I would like to hear more about how this is implemented.
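
One possible workaround, given that `concat_fpn_output.py` resizes every level to a square of side `median_height` (so S = H x W with H = W), is to reshape the tensor back into per-level spatial maps before any ROI pooling. The helper below is purely hypothetical and not part of this repository; it only illustrates the reshape:

```python
import math
import torch

def restore_spatial_dims(F_tensor):
    """Reshape (batch_size, L, S, C) back to (batch_size, L, C, H, W), assuming S = H * W with H == W."""
    batch_size, L, S, C = F_tensor.shape
    H = W = math.isqrt(S)
    assert H * W == S, "S must be a perfect square, i.e. every level was resized to a square feature map"
    # (batch_size, L, S, C) -> (batch_size, L, C, S) -> (batch_size, L, C, H, W)
    return F_tensor.transpose(2, 3).reshape(batch_size, L, C, H, W)

# Example: a DyHead output with L=4 levels, S=12*12 and C=256 channels
maps = restore_spatial_dims(torch.randn(2, 4, 144, 256))
print(maps.shape)  # torch.Size([2, 4, 256, 12, 12])
```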
77 |
78 | In the future, when I have more time and a better understanding, I would like to implement both one-stage and two-stage detectors using PyTorch's built-in Faster R-CNN modules, so that DyHead can easily be included for detection purposes.
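
As a rough sketch of that direction, torchvision's Faster R-CNN FPN backbone already returns features in the dictionary format that `concat_feature_maps` expects, so the pieces can be chained like this (assuming torchvision >= 0.13 for the `weights` arguments, and again running from the `torch/` folder; this only illustrates the data flow and is not a full detector):

```python
import torch
import torchvision

from concat_fpn_output import concat_feature_maps
from DyHead import DyHead

# Take the FPN backbone from torchvision's Faster R-CNN (no pretrained weights for a quick test)
detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)
backbone = detector.backbone   # maps an image batch to an OrderedDict of FPN levels, including 'pool'

images = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    fpn_output = backbone(images)                  # keys: '0', '1', '2', '3', 'pool'
    F_tensor = concat_feature_maps()(fpn_output)   # (1, 4, S, 256)
    dyhead = DyHead(num_blocks=6, L=F_tensor.shape[1], S=F_tensor.shape[2], C=F_tensor.shape[3])
    refined = dyhead(F_tensor)                     # refined features to hand off to a detection head
print(refined.shape)
```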
79 |
--------------------------------------------------------------------------------
/torch/attention_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.ops import DeformConv2d
5 |
6 |
7 | class Scale_Aware_Layer(nn.Module):
8 | # Constructor
9 | def __init__(self, s_size):
10 | super(Scale_Aware_Layer, self).__init__()
11 |
12 | # Average Pooling
13 | self.avg_layer = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
14 |
15 | #1x1 Conv layer
16 | self.conv = nn.Conv2d(in_channels=s_size, out_channels=1, kernel_size=1)
17 |
18 | # Hard Sigmoid
19 | self.hard_sigmoid = nn.Hardsigmoid()
20 |
21 | # ReLU function
22 | self.relu = nn.ReLU()
23 |
24 |     def forward(self, F_tensor):
25 | 
26 |         # Transposing input from (batch_size, L, S, C) to (batch_size, S, L, C) so we can use a convolutional layer over the level dimension L
27 |         x = F_tensor.transpose(dim0=2, dim1=1)
28 |
29 | # Passing tensor through avg pool layer
30 | x = self.avg_layer(x)
31 |
32 | # Passing tensor through Conv layer
33 | x = self.conv(x)
34 |
35 | # Reshaping Tensor from (batch_size, 1, L, C) to (batch_size, L, 1, C) to then be multiplied to F
36 | x = x.transpose(dim0=1, dim1=2)
37 |
38 | # Passing conv output to relu
39 | x = self.relu(x)
40 |
41 | # Passing tensor to hard sigmoid function
42 | pi_L = self.hard_sigmoid(x)
43 |
44 | # pi_L: (batch_size, L, 1, C)
45 |         # F_tensor: (batch_size, L, S, C)
46 |         return pi_L * F_tensor
47 |
48 | class Spatial_Aware_Layer(nn.Module):
49 | # Constructor
50 | def __init__(self, L_size, kernel_height=3, kernel_width=3, padding=1, stride=1, dilation=1, groups=1):
51 | super(Spatial_Aware_Layer, self).__init__()
52 |
53 | self.in_channels = L_size
54 | self.out_channels = L_size
55 |
56 | self.kernel_size = (kernel_height, kernel_width)
57 | self.padding = padding
58 | self.stride = stride
59 | self.dilation = dilation
60 | self.K = kernel_height * kernel_width
61 | self.groups = groups
62 |
63 | # 3x3 Convolution with 3K out_channel output as described in Deform Conv2 paper
64 | self.offset_and_mask_conv = nn.Conv2d(in_channels=self.in_channels,
65 | out_channels=3*self.K, #3K depth
66 | kernel_size=self.kernel_size,
67 | stride=self.stride,
68 | padding=self.padding,
69 | dilation=dilation)
70 |
71 | self.deform_conv = DeformConv2d(in_channels=self.in_channels,
72 | out_channels=self.out_channels,
73 | kernel_size=self.kernel_size,
74 | stride=self.stride,
75 | padding=self.padding,
76 | dilation=self.dilation,
77 | groups=self.groups)
78 |     def forward(self, F_tensor):
79 |         # Generating offsets and masks (or modulators) for the deformable convolution operation
80 |         offsets_and_masks = self.offset_and_mask_conv(F_tensor)
81 |
82 | # Separating offsets and masks as described in Deform Conv v2 paper
83 | offset = offsets_and_masks[:, :2*self.K, :, :] # First 2K channels
84 | mask = torch.sigmoid(offsets_and_masks[:, 2*self.K:, : , :]) # Last 1K channels and passing it through sigmoid
85 |
86 |         # Passing F_tensor, offsets, and masks into the deform conv layer
87 |         spatial_output = self.deform_conv(F_tensor, offset, mask)
88 |         return spatial_output
89 |
90 | # DyReLUA technique from Dynamic ReLU paper
91 | class DyReLUA(nn.Module):
92 | def __init__(self, channels, reduction=8, k=2, lambdas=None, init_values=None):
93 | super(DyReLUA, self).__init__()
94 |
95 | self.fc1 = nn.Linear(channels, channels // reduction)
96 | self.fc2 = nn.Linear(channels//reduction, 2*k)
97 | self.relu = nn.ReLU(inplace=True)
98 | self.sigmoid = nn.Sigmoid()
99 |
100 |         # Defining lambdas in form of [La1, La2, Lb1, Lb2]
101 |         # (registered as buffers so they are moved to the same device as the module)
102 |         if lambdas is None:
103 |             # Default lambdas from DyReLU paper
104 |             lambdas = torch.tensor([1.0, 1.0, 0.5, 0.5], dtype=torch.float)
105 |         self.register_buffer('lambdas', lambdas)
106 | 
107 |         # Defining initialization values in form of [alpha1, alpha2, Beta1, Beta2]
108 |         if init_values is None:
109 |             # Default initialization values from DyReLU paper
110 |             init_values = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float)
111 |         self.register_buffer('init_values', init_values)
113 |
114 | def forward(self, F_tensor):
115 |
116 | # Global Averaging F
117 | kernel_size = F_tensor.shape[2:] # Getting HxW of F
118 | gap_output = F.avg_pool2d(F_tensor, kernel_size)
119 |
120 | # Flattening gap_output from (batch_size, C, 1, 1) to (batch_size, C)
121 | gap_output = gap_output.flatten(start_dim=1)
122 |
123 | # Passing Global Average output through Fully-Connected Layers
124 | x = self.relu(self.fc1(gap_output))
125 | x = self.fc2(x)
126 |
127 | # Normalization between (-1, 1)
128 | residuals = 2 * self.sigmoid(x) - 1
129 |
130 |         # Getting values of theta per sample, and separating alphas and betas
131 |         theta = self.init_values + self.lambdas * residuals  # (batch_size, 4): [alpha1(x), alpha2(x), Beta1(x), Beta2(x)]
132 |         alphas = theta[:, :2].view(-1, 2, 1, 1, 1)  # reshaped to broadcast over (C, L, S)
133 |         betas = theta[:, 2:].view(-1, 2, 1, 1, 1)
134 | 
135 |         # Performing maximum on both piecewise functions
136 |         output = torch.maximum((alphas[:, 0] * F_tensor + betas[:, 0]), (alphas[:, 1] * F_tensor + betas[:, 1]))
137 |
138 | return output
139 |
140 | class Task_Aware_Layer(nn.Module):
141 | # Defining constructor
142 | def __init__(self, num_channels):
143 | super(Task_Aware_Layer, self).__init__()
144 |
145 | # DyReLUA relu
146 | self.dynamic_relu = DyReLUA(num_channels)
147 |
148 | def forward(self, F_tensor):
149 |         # Permuting F from (batch_size, L, S, C) to (batch_size, C, L, S) so we can reduce the dimensions over LxS
150 | F_tensor = F_tensor.permute(0, 3, 1, 2)
151 |
152 | output = self.dynamic_relu(F_tensor)
153 |
154 | # Reversing the permutation
155 | output = output.permute(0, 2, 3, 1)
156 |
157 | return output
--------------------------------------------------------------------------------