├── .gitignore
├── LICENSE
├── README.md
├── dgconv.py
├── figs
│   ├── Dynamic_Conv.png
│   └── ablation.png
└── g_resnext.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Duo LI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # dgconv.pytorch
2 | PyTorch implementation of Dynamic Grouping Convolution and Groupable ConvNet in [Differentiable Learning-to-Group Channels via Groupable Convolutional Neural Networks](https://arxiv.org/abs/1908.05867).
3 |
4 | ![Dynamic Grouping Convolution](figs/Dynamic_Conv.png)
5 |
6 | * The *Kronecker product* is used to construct the sparse channel-connectivity matrix `U` efficiently, with a regular block-diagonal structure.
7 | * The discrete gate optimization is handled with the *Straight-Through Estimator* trick.
8 | * The number of groups is learned automatically, in an end-to-end differentiable fashion (see the sketch below).
9 |
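For intuition, here is a minimal sketch of the mask construction (illustrative only: `build_mask` is a hypothetical helper, and it uses `torch.kron` from PyTorch >= 1.8 instead of the repo's own `kronecker_product`). Each of the K binary gates in `dgconv.py` contributes one 2x2 Kronecker factor: a gate of 1 contributes an all-ones factor that keeps channels mixed, while a gate of 0 contributes an identity factor that splits the channels into groups.

```python
# Illustrative sketch of the DGConv mask construction (not part of the repo).
import torch

def build_mask(gates):
    """gates: iterable of 0/1 values, one per 2x2 Kronecker factor."""
    D = torch.eye(2)      # identity factor: splits channels into groups
    I = torch.ones(2, 2)  # all-ones factor: keeps channels fully connected
    U = torch.ones(1, 1)
    for g in gates:
        U = torch.kron(U, (1 - g) * D + g * I)
    return U

print(build_mask([1, 1, 1]))  # 8x8 all-ones mask   -> 1 group (dense conv)
print(build_mask([0, 1, 1]))  # two 4x4 ones blocks -> 2 groups
print(build_mask([0, 0, 0]))  # 8x8 identity        -> 8 groups (depthwise)
```

The number of groups is `2 ** z`, where `z` is the number of gates binarized to 0, and the `U_regularizer` term `2 ** (K + sum(gate))` in `dgconv.py` counts the nonzero entries of the resulting mask.
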
10 | ## ResNeXt-50 on ImageNet
11 |
12 | **DGConv** is used as a drop-in replacement for the grouped 3x3 convolution in the original ResNeXt bottleneck to build the G-ResNeXt-50/101 network architectures. The table below compares their performance on ImageNet.
13 |
14 | | Architecture | LR decay strategy | Top-1 / Top-5 Accuracy (%) |
15 | | ------------------------------------------------------------ | ------------------- | ---------------------- |
16 | | [ResNeXt-50 (32x4d)](https://drive.google.com/open?id=1zVQm-aoJV6GRi-mCds7B8HVcsI8Jbjim) | cosine (120 epochs) | 78.198 / 93.916 |
17 | | [G-ResNeXt](https://drive.google.com/open?id=1elM-FVacE-Pkin_hCiW24oKCogHEaYRn) | cosine (120 epochs) | 78.592 / 94.106 |
18 |
19 | ![Ablation results](figs/ablation.png)
20 |
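A minimal usage sketch (illustrative; it only assumes that `dgconv.py` and `g_resnext.py` are importable): the forward pass returns the logits together with the summed `U_regularizer`, which can optionally be folded into the training loss as a complexity penalty.

```python
# Illustrative usage sketch (not a prescribed training recipe).
import torch
from g_resnext import g_resnext50

model = g_resnext50(num_classes=1000)
images = torch.randn(2, 3, 224, 224)

logits, u_reg = model(images)   # logits: (2, 1000); u_reg: summed U regularizer
print(logits.shape, u_reg.item())
```
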
21 | ## Citation
22 |
23 | ```bibtex
24 | @InProceedings{Zhang_2019_ICCV,
25 | author = {Zhang, Zhaoyang and Li, Jingyu and Shao, Wenqi and Peng, Zhanglin and Zhang, Ruimao and Wang, Xiaogang and Luo, Ping},
26 | title = {Differentiable Learning-to-Group Channels via Groupable Convolutional Neural Networks},
27 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
28 | month = {Oct},
29 | year = {2019}
30 | }
31 | ```
32 |
33 |
--------------------------------------------------------------------------------
/dgconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 | from torch.nn import functional as F
5 | import math
6 |
7 |
8 | def aggregate(gate, D, I, K, sort=False):
9 | if sort:
10 | _, ind = gate.sort(descending=True)
11 | gate = gate[:, ind[0, :]]
12 |
13 | U = [(gate[0, i] * D + gate[1, i] * I) for i in range(K)]
14 | while len(U) != 1:
15 | temp = []
16 | for i in range(0, len(U) - 1, 2):
17 | temp.append(kronecker_product(U[i], U[i + 1]))
18 | if len(U) % 2 != 0:
19 | temp.append(U[-1])
20 | del U
21 | U = temp
22 |
23 | return U[0], gate
24 |
25 |
26 | def kronecker_product(mat1, mat2):
27 | return torch.ger(mat1.view(-1), mat2.view(-1)).reshape(*(mat1.size() + mat2.size())).permute(
28 | [0, 2, 1, 3]).reshape(mat1.size(0) * mat2.size(0), mat1.size(1) * mat2.size(1))
29 |
30 |
31 | class DGConv2d(nn.Module):
32 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, sort=False):
33 | super(DGConv2d, self).__init__()
34 | self.register_buffer('D', torch.eye(2))
35 | self.register_buffer('I', torch.ones(2, 2))
36 |         self.K = int(math.log2(in_channels))  # number of 2x2 Kronecker factors; in_channels is assumed to be a power of two
37 | eps = 1e-8
38 | gate_init = [eps * random.choice([-1, 1]) for _ in range(self.K)]
39 | self.register_parameter('gate', nn.Parameter(torch.Tensor(gate_init)))
40 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
41 | self.in_channels = in_channels
42 | self.out_channels = out_channels
43 | self.sort = sort
44 |
45 | def forward(self, x):
46 | setattr(self.gate, 'org', self.gate.data.clone())
47 | self.gate.data = ((self.gate.org - 0).sign() + 1) / 2.
48 | U_regularizer = 2 ** (self.K + torch.sum(self.gate))
49 | gate = torch.stack((1 - self.gate, self.gate))
50 | self.gate.data = self.gate.org # Straight-Through Estimator
51 | U, gate = aggregate(gate, self.D, self.I, self.K, sort=self.sort)
52 |         masked_weight = self.conv.weight * U.view(self.out_channels, self.in_channels, 1, 1)  # U is in_channels x in_channels, so this reshape assumes out_channels == in_channels
53 | x = F.conv2d(x, masked_weight, self.conv.bias, self.conv.stride, self.conv.padding, self.conv.dilation)
54 | return x, U_regularizer
55 |
56 |
--------------------------------------------------------------------------------
/figs/Dynamic_Conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-li14/dgconv.pytorch/3b997d6cb04632f3d242e582488e19cdcd52374d/figs/Dynamic_Conv.png
--------------------------------------------------------------------------------
/figs/ablation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-li14/dgconv.pytorch/3b997d6cb04632f3d242e582488e19cdcd52374d/figs/ablation.png
--------------------------------------------------------------------------------
/g_resnext.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from dgconv import DGConv2d
4 |
5 |
6 | __all__ = ['G_ResNet', 'g_resnext50', 'g_resnext101']
7 |
8 |
9 | def conv3x3(in_planes, out_planes, stride=1, dilation=1):
10 | """3x3 convolution with padding"""
11 | return DGConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
12 | padding=dilation, bias=False, dilation=dilation)
13 |
14 |
15 | def conv1x1(in_planes, out_planes, stride=1):
16 | """1x1 convolution"""
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
18 |
19 |
20 | class Bottleneck(nn.Module):
21 | expansion = 4
22 |
23 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
24 | base_width=64, dilation=1, norm_layer=None):
25 | super(Bottleneck, self).__init__()
26 | if norm_layer is None:
27 | norm_layer = nn.BatchNorm2d
28 | width = int(planes * (base_width / 64.)) * groups
29 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
30 | self.conv1 = conv1x1(inplanes, width)
31 | self.bn1 = norm_layer(width)
32 | self.conv2 = conv3x3(width, width, stride, dilation)
33 | self.bn2 = norm_layer(width)
34 | self.conv3 = conv1x1(width, planes * self.expansion)
35 | self.bn3 = norm_layer(planes * self.expansion)
36 | self.relu = nn.ReLU(inplace=True)
37 | self.downsample = downsample
38 | self.stride = stride
39 |
40 | def forward(self, x):
41 | U_regularizer_sum = 0
42 | if isinstance(x, tuple):
43 | x, U_regularizer_sum = x[0], x[1]
44 | identity = x
45 |
46 | out = self.conv1(x)
47 | out = self.bn1(out)
48 | out = self.relu(out)
49 |
50 | out, U_regularizer = self.conv2(out)
51 | out = self.bn2(out)
52 | out = self.relu(out)
53 |
54 | out = self.conv3(out)
55 | out = self.bn3(out)
56 |
57 | if self.downsample is not None:
58 | identity = self.downsample(x)
59 |
60 | out += identity
61 | out = self.relu(out)
62 |
63 | return out, U_regularizer_sum + U_regularizer
64 |
65 |
66 | class G_ResNet(nn.Module):
67 |
68 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
69 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
70 | norm_layer=None):
71 | super(G_ResNet, self).__init__()
72 | if norm_layer is None:
73 | norm_layer = nn.BatchNorm2d
74 | self._norm_layer = norm_layer
75 |
76 | self.inplanes = 64
77 | self.dilation = 1
78 | if replace_stride_with_dilation is None:
79 | # each element in the tuple indicates if we should replace
80 | # the 2x2 stride with a dilated convolution instead
81 | replace_stride_with_dilation = [False, False, False]
82 | if len(replace_stride_with_dilation) != 3:
83 | raise ValueError("replace_stride_with_dilation should be None "
84 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
85 | self.groups = groups
86 | self.base_width = width_per_group
87 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
88 | bias=False)
89 | self.bn1 = norm_layer(self.inplanes)
90 | self.relu = nn.ReLU(inplace=True)
91 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
92 | self.layer1 = self._make_layer(block, 64, layers[0])
93 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
94 | dilate=replace_stride_with_dilation[0])
95 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
96 | dilate=replace_stride_with_dilation[1])
97 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
98 | dilate=replace_stride_with_dilation[2])
99 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
100 | self.fc = nn.Linear(512 * block.expansion, num_classes)
101 |
102 | for m in self.modules():
103 | if isinstance(m, nn.Conv2d):
104 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
105 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
106 | nn.init.constant_(m.weight, 1)
107 | nn.init.constant_(m.bias, 0)
108 |
109 | # Zero-initialize the last BN in each residual branch,
110 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
111 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
112 | if zero_init_residual:
113 | for m in self.modules():
114 |                 # Only Bottleneck blocks are used in G-ResNet (there is no
115 |                 # BasicBlock variant here), so only bn3 is zero-initialized.
116 |                 if isinstance(m, Bottleneck):
117 |                     nn.init.constant_(m.bn3.weight, 0)
118 |
119 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120 | norm_layer = self._norm_layer
121 | downsample = None
122 | previous_dilation = self.dilation
123 | if dilate:
124 | self.dilation *= stride
125 | stride = 1
126 | if stride != 1 or self.inplanes != planes * block.expansion:
127 | downsample = nn.Sequential(
128 | conv1x1(self.inplanes, planes * block.expansion, stride),
129 | norm_layer(planes * block.expansion),
130 | )
131 |
132 | layers = []
133 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
134 | self.base_width, previous_dilation, norm_layer))
135 | self.inplanes = planes * block.expansion
136 | for _ in range(1, blocks):
137 | layers.append(block(self.inplanes, planes, groups=self.groups,
138 | base_width=self.base_width, dilation=self.dilation,
139 | norm_layer=norm_layer))
140 |
141 | return nn.Sequential(*layers)
142 |
143 | def forward(self, x):
144 | x = self.conv1(x)
145 | x = self.bn1(x)
146 | x = self.relu(x)
147 | x = self.maxpool(x)
148 |
149 | x, layer1_sum = self.layer1(x)
150 | x, layer2_sum = self.layer2(x)
151 | x, layer3_sum = self.layer3(x)
152 | x, layer4_sum = self.layer4(x)
153 |
154 | x = self.avgpool(x)
155 | x = torch.flatten(x, 1)
156 | x = self.fc(x)
157 |
158 | return x, layer1_sum + layer2_sum + layer3_sum + layer4_sum
159 |
160 |
161 | def g_resnext50(**kwargs):
162 | r"""ResNeXt-50 32x4d model from
163 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
164 | """
165 | kwargs['groups'] = 32
166 | kwargs['width_per_group'] = 4
167 | return G_ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
168 |
169 |
170 | def g_resnext101(**kwargs):
171 | r"""ResNeXt-101 32x8d model from
172 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
173 | """
174 | kwargs['groups'] = 32
175 | kwargs['width_per_group'] = 8
176 | return G_ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
177 |
178 |
--------------------------------------------------------------------------------