├── .gitignore ├── LICENSE ├── README.md ├── dgconv.py ├── figs ├── Dynamic_Conv.png └── ablation.png └── g_resnext.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Duo LI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dgconv.pytorch 2 | PyTorch implementation of Dynamic Grouping Convolution and Groupable ConvNet in [Differentiable Learning-to-Group Channels via Groupable Convolutional Neural Networks](https://arxiv.org/abs/1908.05867). 3 | 4 |
![Dynamic Grouping Convolution](figs/Dynamic_Conv.png)
5 | 6 | * The *Kronecker product* is used to construct the sparse relation matrix efficiently and regularly. 7 | * The discrete (binary gate) optimization is handled with the *Straight-Through Estimator* trick. 8 | * The number of groups is learned automatically in an end-to-end differentiable fashion. 9 | 10 | ## ResNeXt-50 on ImageNet 11 | 12 | **DGConv** is used as a drop-in replacement for the grouped 3x3 convolution in the original ResNeXt bottleneck to build the G-ResNeXt-50/101 architectures; a minimal usage sketch is given below. The following table compares their ImageNet performance. 13 | 14 | | Architecture | LR decay strategy | Top-1 / Top-5 Accuracy | 15 | | ------------------------------------------------------------ | ------------------- | ---------------------- | 16 | | [ResNeXt-50 (32x4d)](https://drive.google.com/open?id=1zVQm-aoJV6GRi-mCds7B8HVcsI8Jbjim) | cosine (120 epochs) | 78.198 / 93.916 | 17 | | [G-ResNeXt-50](https://drive.google.com/open?id=1elM-FVacE-Pkin_hCiW24oKCogHEaYRn) | cosine (120 epochs) | 78.592 / 94.106 | 18 | 19 |
![Ablation results](figs/ablation.png)
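## Usage

A minimal sketch of how the models are meant to be consumed: every DGConv layer returns a complexity term along with its output, and the network sums these terms so they can be added to the training loss. The regularization weight `1e-4` below is an illustrative choice, not a value taken from the paper.

```python
import torch
from g_resnext import g_resnext50

model = g_resnext50(num_classes=1000)
images = torch.randn(2, 3, 224, 224)
targets = torch.randint(0, 1000, (2,))

# The forward pass returns the logits plus the summed group-complexity regularizer.
logits, U_regularizer = model(images)

# Sketch of a training step: penalizing the regularizer pushes the learned gates
# towards sparser (more grouped) convolutions; 1e-4 is an illustrative weight.
loss = torch.nn.functional.cross_entropy(logits, targets) + 1e-4 * U_regularizer
loss.backward()
```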
20 | 21 | ## Citation 22 | 23 | ```bibtex 24 | @InProceedings{Zhang_2019_ICCV, 25 | author = {Zhang, Zhaoyang and Li, Jingyu and Shao, Wenqi and Peng, Zhanglin and Zhang, Ruimao and Wang, Xiaogang and Luo, Ping}, 26 | title = {Differentiable Learning-to-Group Channels via Groupable Convolutional Neural Networks}, 27 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 28 | month = {Oct}, 29 | year = {2019} 30 | } 31 | ``` 32 | 33 | -------------------------------------------------------------------------------- /dgconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | from torch.nn import functional as F 5 | import math 6 | 7 | 8 | def aggregate(gate, D, I, K, sort=False): 9 | if sort: 10 | _, ind = gate.sort(descending=True) 11 | gate = gate[:, ind[0, :]] 12 | 13 | U = [(gate[0, i] * D + gate[1, i] * I) for i in range(K)] 14 | while len(U) != 1: 15 | temp = [] 16 | for i in range(0, len(U) - 1, 2): 17 | temp.append(kronecker_product(U[i], U[i + 1])) 18 | if len(U) % 2 != 0: 19 | temp.append(U[-1]) 20 | del U 21 | U = temp 22 | 23 | return U[0], gate 24 | 25 | 26 | def kronecker_product(mat1, mat2): 27 | return torch.ger(mat1.view(-1), mat2.view(-1)).reshape(*(mat1.size() + mat2.size())).permute( 28 | [0, 2, 1, 3]).reshape(mat1.size(0) * mat2.size(0), mat1.size(1) * mat2.size(1)) 29 | 30 | 31 | class DGConv2d(nn.Module): 32 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, sort=False): 33 | super(DGConv2d, self).__init__() 34 | self.register_buffer('D', torch.eye(2)) 35 | self.register_buffer('I', torch.ones(2, 2)) 36 | self.K = int(math.log2(in_channels)) 37 | eps = 1e-8 38 | gate_init = [eps * random.choice([-1, 1]) for _ in range(self.K)] 39 | self.register_parameter('gate', nn.Parameter(torch.Tensor(gate_init))) 40 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) 41 | self.in_channels = in_channels 42 | self.out_channels = out_channels 43 | self.sort = sort 44 | 45 | def forward(self, x): 46 | setattr(self.gate, 'org', self.gate.data.clone()) 47 | self.gate.data = ((self.gate.org - 0).sign() + 1) / 2. 
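        # The gates are now binarized to {0, 1} for the forward pass; `gate.org` keeps the
        # real-valued copy that is restored a few lines below (straight-through estimator).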
48 | U_regularizer = 2 ** (self.K + torch.sum(self.gate)) 49 | gate = torch.stack((1 - self.gate, self.gate)) 50 | self.gate.data = self.gate.org # Straight-Through Estimator 51 | U, gate = aggregate(gate, self.D, self.I, self.K, sort=self.sort) 52 | masked_weight = self.conv.weight * U.view(self.out_channels, self.in_channels, 1, 1) 53 | x = F.conv2d(x, masked_weight, self.conv.bias, self.conv.stride, self.conv.padding, self.conv.dilation) 54 | return x, U_regularizer 55 | 56 | -------------------------------------------------------------------------------- /figs/Dynamic_Conv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-li14/dgconv.pytorch/3b997d6cb04632f3d242e582488e19cdcd52374d/figs/Dynamic_Conv.png -------------------------------------------------------------------------------- /figs/ablation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-li14/dgconv.pytorch/3b997d6cb04632f3d242e582488e19cdcd52374d/figs/ablation.png -------------------------------------------------------------------------------- /g_resnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dgconv import DGConv2d 4 | 5 | 6 | __all__ = ['G_ResNet', 'g_resnext50', 'g_resnext101'] 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 10 | """3x3 convolution with padding""" 11 | return DGConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=dilation, bias=False, dilation=dilation) 13 | 14 | 15 | def conv1x1(in_planes, out_planes, stride=1): 16 | """1x1 convolution""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 18 | 19 | 20 | class Bottleneck(nn.Module): 21 | expansion = 4 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 24 | base_width=64, dilation=1, norm_layer=None): 25 | super(Bottleneck, self).__init__() 26 | if norm_layer is None: 27 | norm_layer = nn.BatchNorm2d 28 | width = int(planes * (base_width / 64.)) * groups 29 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 30 | self.conv1 = conv1x1(inplanes, width) 31 | self.bn1 = norm_layer(width) 32 | self.conv2 = conv3x3(width, width, stride, dilation) 33 | self.bn2 = norm_layer(width) 34 | self.conv3 = conv1x1(width, planes * self.expansion) 35 | self.bn3 = norm_layer(planes * self.expansion) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | U_regularizer_sum = 0 42 | if isinstance(x, tuple): 43 | x, U_regularizer_sum = x[0], x[1] 44 | identity = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out, U_regularizer = self.conv2(out) 51 | out = self.bn2(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv3(out) 55 | out = self.bn3(out) 56 | 57 | if self.downsample is not None: 58 | identity = self.downsample(x) 59 | 60 | out += identity 61 | out = self.relu(out) 62 | 63 | return out, U_regularizer_sum + U_regularizer 64 | 65 | 66 | class G_ResNet(nn.Module): 67 | 68 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 69 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 70 | norm_layer=None): 71 | super(G_ResNet, self).__init__() 72 | if norm_layer is None: 73 | norm_layer = nn.BatchNorm2d 74 | self._norm_layer = 
norm_layer 75 | 76 | self.inplanes = 64 77 | self.dilation = 1 78 | if replace_stride_with_dilation is None: 79 | # each element in the tuple indicates if we should replace 80 | # the 2x2 stride with a dilated convolution instead 81 | replace_stride_with_dilation = [False, False, False] 82 | if len(replace_stride_with_dilation) != 3: 83 | raise ValueError("replace_stride_with_dilation should be None " 84 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 85 | self.groups = groups 86 | self.base_width = width_per_group 87 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 88 | bias=False) 89 | self.bn1 = norm_layer(self.inplanes) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 92 | self.layer1 = self._make_layer(block, 64, layers[0]) 93 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 94 | dilate=replace_stride_with_dilation[0]) 95 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 96 | dilate=replace_stride_with_dilation[1]) 97 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 98 | dilate=replace_stride_with_dilation[2]) 99 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 100 | self.fc = nn.Linear(512 * block.expansion, num_classes) 101 | 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 105 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 106 | nn.init.constant_(m.weight, 1) 107 | nn.init.constant_(m.bias, 0) 108 | 109 | # Zero-initialize the last BN in each residual branch, 110 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 111 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 112 | if zero_init_residual: 113 | for m in self.modules(): 114 | if isinstance(m, Bottleneck): 115 | nn.init.constant_(m.bn3.weight, 0) 116 | # only Bottleneck blocks are used in G-ResNet, so there is no BasicBlock branch 117 | 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 120 | norm_layer = self._norm_layer 121 | downsample = None 122 | previous_dilation = self.dilation 123 | if dilate: 124 | self.dilation *= stride 125 | stride = 1 126 | if stride != 1 or self.inplanes != planes * block.expansion: 127 | downsample = nn.Sequential( 128 | conv1x1(self.inplanes, planes * block.expansion, stride), 129 | norm_layer(planes * block.expansion), 130 | ) 131 | 132 | layers = [] 133 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 134 | self.base_width, previous_dilation, norm_layer)) 135 | self.inplanes = planes * block.expansion 136 | for _ in range(1, blocks): 137 | layers.append(block(self.inplanes, planes, groups=self.groups, 138 | base_width=self.base_width, dilation=self.dilation, 139 | norm_layer=norm_layer)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.relu(x) 147 | x = self.maxpool(x) 148 | 149 | x, layer1_sum = self.layer1(x) 150 | x, layer2_sum = self.layer2(x) 151 | x, layer3_sum = self.layer3(x) 152 | x, layer4_sum = self.layer4(x) 153 | 154 | x = self.avgpool(x) 155 | x = torch.flatten(x, 1) 156 | x = self.fc(x) 157 | 158 | return x, layer1_sum + layer2_sum + layer3_sum + layer4_sum 159 | 160 | 161 | def g_resnext50(**kwargs): 162 | r"""ResNeXt-50 32x4d model from 163 | `"Aggregated Residual Transformations for Deep Neural Networks" 
<https://arxiv.org/abs/1611.05431>`_ 164 | """ 165 | kwargs['groups'] = 32 166 | kwargs['width_per_group'] = 4 167 | return G_ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 168 | 169 | 170 | def g_resnext101(**kwargs): 171 | r"""ResNeXt-101 32x8d model from 172 | `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_ 173 | """ 174 | kwargs['groups'] = 32 175 | kwargs['width_per_group'] = 8 176 | return G_ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 177 | 178 | --------------------------------------------------------------------------------
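For reference, a small sketch of how `aggregate` in `dgconv.py` composes the per-level 2x2 blocks into the relation matrix `U` through repeated Kronecker products, and how the binary gates encode the effective number of groups; the gate values below are hand-picked for illustration.

```python
import torch
from dgconv import aggregate

# K = 3 gate levels cover 2**3 = 8 channels. A gate of 1 picks the all-ones 2x2 block
# (channels stay connected); a gate of 0 picks the 2x2 identity (channels are split),
# so the effective number of groups is 2 ** (number of zero gates).
g = torch.tensor([1., 0., 1.])      # hand-picked binary gates
gate = torch.stack((1 - g, g))      # row 0 selects D (identity), row 1 selects I (ones)

D, I = torch.eye(2), torch.ones(2, 2)
U, _ = aggregate(gate, D, I, K=3)

print(U.shape)                      # torch.Size([8, 8])
print(int(U.sum()))                 # 32 = 2 ** (K + g.sum()): nonzero weights kept
print(int(2 ** (1 - g).sum()))      # 2: effective number of groups
```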