├── assets
│   ├── mixnet_results.JPG
│   └── mixnet_architecture.JPG
├── README.md
├── LICENSE
└── mixnet.py

/assets/mixnet_results.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/romulus0914/MixNet-PyTorch/HEAD/assets/mixnet_results.JPG
--------------------------------------------------------------------------------
/assets/mixnet_architecture.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/romulus0914/MixNet-PyTorch/HEAD/assets/mixnet_architecture.JPG
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MixNet-PyTorch
A PyTorch implementation of MixNet.

# Overview
A PyTorch implementation of the MixNet architecture: [MixNet: Mixed Depthwise Convolutional Kernels](https://arxiv.org/pdf/1907.09595.pdf).
MixNet builds on MobileNetV2 and is found by neural architecture search, replacing the standard depthwise convolution with the proposed mixed depthwise convolution (**MDConv**), which mixes several kernel sizes within a single depthwise layer.
Results: MixNet is more accurate than previous mobile models, improving ImageNet top-1 accuracy over MobileNetV2 (+4.2%), ShuffleNetV2 (+3.5%), MnasNet (+1.3%), ProxylessNAS (+2.2%), and FBNet (+2.0%).
MixNet-L achieves a new state-of-the-art 78.9% ImageNet top-1 accuracy under typical mobile settings (<600M FLOPs).

# MixNet Architecture
![Architecture](./assets/mixnet_architecture.JPG)

# MixNet Results
![Results](./assets/mixnet_results.JPG)
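
# Usage
A minimal example of constructing and running the model (illustrative only; it assumes `mixnet.py` is importable from the working directory):

```python
import torch
from mixnet import MixNet

# net_type can be 'mixnet_s', 'mixnet_m', or 'mixnet_l'
net = MixNet(net_type='mixnet_s', input_size=224, num_classes=1000)
net.eval()

with torch.no_grad():
    logits = net(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```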

# Disclaimer
Slightly modified from [MobileNetV3-PyTorch](https://github.com/AnjieZheng/MobileNetV3-PyTorch) by Anjie Zheng.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Romulus Hong

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/mixnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

import math

class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return x * self.sigmoid(x)

# Shared activation instances; both are stateless, so reuse across layers is safe.
NON_LINEARITY = {
    'ReLU': nn.ReLU(inplace=True),
    'Swish': Swish(),
}

def _RoundChannels(c, divisor=8, min_value=None):
    # round a channel count to the nearest multiple of `divisor`,
    # never dropping more than 10% below the requested value
    if min_value is None:
        min_value = divisor
    new_c = max(min_value, int(c + divisor / 2) // divisor * divisor)
    if new_c < 0.9 * c:
        new_c += divisor
    return new_c

def _SplitChannels(channels, num_groups):
    # split `channels` into `num_groups` nearly equal parts; the first group
    # absorbs the remainder so the parts always sum back to `channels`
    split_channels = [channels//num_groups for _ in range(num_groups)]
    split_channels[0] += channels - sum(split_channels)
    return split_channels

def Conv3x3Bn(in_channels, out_channels, stride, non_linear='ReLU'):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False),
        nn.BatchNorm2d(out_channels),
        NON_LINEARITY[non_linear]
    )

def Conv1x1Bn(in_channels, out_channels, non_linear='ReLU'):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
        nn.BatchNorm2d(out_channels),
        NON_LINEARITY[non_linear]
    )

class SqueezeAndExcite(nn.Module):
    def __init__(self, channels, squeeze_channels, se_ratio):
        super(SqueezeAndExcite, self).__init__()

        squeeze_channels = squeeze_channels * se_ratio
        if not squeeze_channels.is_integer():
            raise ValueError('channels must be divisible by 1/ratio')

        squeeze_channels = int(squeeze_channels)
        self.se_reduce = nn.Conv2d(channels, squeeze_channels, 1, 1, 0, bias=True)
        self.non_linear1 = NON_LINEARITY['Swish']
        self.se_expand = nn.Conv2d(squeeze_channels, channels, 1, 1, 0, bias=True)
        self.non_linear2 = nn.Sigmoid()

    def forward(self, x):
        y = torch.mean(x, (2, 3), keepdim=True)
        y = self.non_linear1(self.se_reduce(y))
        y = self.non_linear2(self.se_expand(y))
        y = x * y

        return y

class GroupedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(GroupedConv2d, self).__init__()

        self.num_groups = len(kernel_size)
        self.split_in_channels = _SplitChannels(in_channels, self.num_groups)
        self.split_out_channels = _SplitChannels(out_channels, self.num_groups)

        self.grouped_conv = nn.ModuleList()
        for i in range(self.num_groups):
            self.grouped_conv.append(nn.Conv2d(
                self.split_in_channels[i],
                self.split_out_channels[i],
                kernel_size[i],
                stride=stride,
                padding=padding,
                bias=False
            ))

    def forward(self, x):
        if self.num_groups == 1:
            return self.grouped_conv[0](x)

        x_split = torch.split(x, self.split_in_channels, dim=1)
        x = [conv(t) for conv, t in zip(self.grouped_conv, x_split)]
        x = torch.cat(x, dim=1)

        return x
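
# Illustrative example (comments only, not executed): _SplitChannels distributes
# channels as evenly as possible across kernel sizes, e.g.
#   _SplitChannels(40, 3) -> [14, 13, 13]
# so an MDConv over 40 channels with kernel_size=[3, 5, 7] applies a depthwise 3x3
# to the first 14 channels, a 5x5 to the next 13, and a 7x7 to the last 13,
# then concatenates the three outputs back into 40 channels.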

class MDConv(nn.Module):
    def __init__(self, channels, kernel_size, stride):
        super(MDConv, self).__init__()

        self.num_groups = len(kernel_size)
        self.split_channels = _SplitChannels(channels, self.num_groups)

        self.mixed_depthwise_conv = nn.ModuleList()
        for i in range(self.num_groups):
            self.mixed_depthwise_conv.append(nn.Conv2d(
                self.split_channels[i],
                self.split_channels[i],
                kernel_size[i],
                stride=stride,
                padding=kernel_size[i]//2,
                groups=self.split_channels[i],
                bias=False
            ))

    def forward(self, x):
        if self.num_groups == 1:
            return self.mixed_depthwise_conv[0](x)

        x_split = torch.split(x, self.split_channels, dim=1)
        x = [conv(t) for conv, t in zip(self.mixed_depthwise_conv, x_split)]
        x = torch.cat(x, dim=1)

        return x

class MixNetBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=[3],
        expand_ksize=[1],
        project_ksize=[1],
        stride=1,
        expand_ratio=1,
        non_linear='ReLU',
        se_ratio=0.0
    ):

        super(MixNetBlock, self).__init__()

        expand = (expand_ratio != 1)
        expand_channels = in_channels * expand_ratio
        se = (se_ratio != 0.0)
        self.residual_connection = (stride == 1 and in_channels == out_channels)

        conv = []

        if expand:
            # expansion phase
            pw_expansion = nn.Sequential(
                GroupedConv2d(in_channels, expand_channels, expand_ksize),
                nn.BatchNorm2d(expand_channels),
                NON_LINEARITY[non_linear]
            )
            conv.append(pw_expansion)

        # depthwise convolution phase
        dw = nn.Sequential(
            MDConv(expand_channels, kernel_size, stride),
            nn.BatchNorm2d(expand_channels),
            NON_LINEARITY[non_linear]
        )
        conv.append(dw)

        if se:
            # squeeze and excite
            squeeze_excite = SqueezeAndExcite(expand_channels, in_channels, se_ratio)
            conv.append(squeeze_excite)

        # projection phase
        pw_projection = nn.Sequential(
            GroupedConv2d(expand_channels, out_channels, project_ksize),
            nn.BatchNorm2d(out_channels)
        )
        conv.append(pw_projection)

        self.conv = nn.Sequential(*conv)

    def forward(self, x):
        if self.residual_connection:
            return x + self.conv(x)
        else:
            return self.conv(x)
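
# Illustrative example (comments only, not executed): one MixNet-S block
# configuration from the table below,
#   MixNetBlock(24, 40, kernel_size=[3, 5, 7], expand_ksize=[1], project_ksize=[1],
#               stride=2, expand_ratio=6, non_linear='Swish', se_ratio=0.5)
# expands 24 -> 144 channels with a pointwise conv, applies MDConv with mixed
# 3x3/5x5/7x7 depthwise kernels at stride 2, squeezes-and-excites with
# 24 * 0.5 = 12 hidden channels, then projects down to 40 channels.
# No residual connection is added because stride != 1 and in_channels != out_channels.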

class MixNet(nn.Module):
    # [in_channels, out_channels, kernel_size, expand_ksize, project_ksize, stride, expand_ratio, non_linear, se_ratio]
    mixnet_s = [(16, 16, [3], [1], [1], 1, 1, 'ReLU', 0.0),
                (16, 24, [3], [1, 1], [1, 1], 2, 6, 'ReLU', 0.0),
                (24, 24, [3], [1, 1], [1, 1], 1, 3, 'ReLU', 0.0),
                (24, 40, [3, 5, 7], [1], [1], 2, 6, 'Swish', 0.5),
                (40, 40, [3, 5], [1, 1], [1, 1], 1, 6, 'Swish', 0.5),
                (40, 40, [3, 5], [1, 1], [1, 1], 1, 6, 'Swish', 0.5),
                (40, 40, [3, 5], [1, 1], [1, 1], 1, 6, 'Swish', 0.5),
                (40, 80, [3, 5, 7], [1], [1, 1], 2, 6, 'Swish', 0.25),
                (80, 80, [3, 5], [1], [1, 1], 1, 6, 'Swish', 0.25),
                (80, 80, [3, 5], [1], [1, 1], 1, 6, 'Swish', 0.25),
                (80, 120, [3, 5, 7], [1, 1], [1, 1], 1, 6, 'Swish', 0.5),
                (120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, 'Swish', 0.5),
                (120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, 'Swish', 0.5),
                (120, 200, [3, 5, 7, 9, 11], [1], [1], 2, 6, 'Swish', 0.5),
                (200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, 'Swish', 0.5),
                (200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, 'Swish', 0.5)]

    mixnet_m = [(24, 24, [3], [1], [1], 1, 1, 'ReLU', 0.0),
                (24, 32, [3, 5, 7], [1, 1], [1, 1], 2, 6, 'ReLU', 0.0),
                (32, 32, [3], [1, 1], [1, 1], 1, 3, 'ReLU', 0.0),
                (32, 40, [3, 5, 7, 9], [1], [1], 2, 6, 'Swish', 0.5),
                (40, 40, [3, 5], [1, 1], [1, 1], 1, 6, 'Swish', 0.5),
                (40, 40, [3, 5], [1, 1], [1, 1], 1, 6, 'Swish', 0.5),
                (40, 40, [3, 5], [1, 1], [1, 1], 1, 6, 'Swish', 0.5),
                (40, 80, [3, 5, 7], [1], [1], 2, 6, 'Swish', 0.25),
                (80, 80, [3, 5, 7, 9], [1, 1], [1, 1], 1, 6, 'Swish', 0.25),
                (80, 80, [3, 5, 7, 9], [1, 1], [1, 1], 1, 6, 'Swish', 0.25),
                (80, 80, [3, 5, 7, 9], [1, 1], [1, 1], 1, 6, 'Swish', 0.25),
                (80, 120, [3], [1], [1], 1, 6, 'Swish', 0.5),
                (120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, 'Swish', 0.5),
                (120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, 'Swish', 0.5),
                (120, 120, [3, 5, 7, 9], [1, 1], [1, 1], 1, 3, 'Swish', 0.5),
                (120, 200, [3, 5, 7, 9], [1], [1], 2, 6, 'Swish', 0.5),
                (200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, 'Swish', 0.5),
                (200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, 'Swish', 0.5),
                (200, 200, [3, 5, 7, 9], [1], [1, 1], 1, 6, 'Swish', 0.5)]

    def __init__(self, net_type='mixnet_s', input_size=224, num_classes=1000, stem_channels=16, feature_size=1536, depth_multiplier=1.0):
        super(MixNet, self).__init__()

        if net_type == 'mixnet_s':
            config = self.mixnet_s
            stem_channels = 16
            dropout_rate = 0.2
        elif net_type == 'mixnet_m':
            config = self.mixnet_m
            stem_channels = 24
            dropout_rate = 0.25
        elif net_type == 'mixnet_l':
            # MixNet-L reuses the MixNet-M layout with a 1.3x depth multiplier
            config = self.mixnet_m
            stem_channels = 24
            depth_multiplier *= 1.3
            dropout_rate = 0.25
        else:
            raise ValueError('Unsupported MixNet type')

        assert input_size % 32 == 0

        # depth multiplier
        if depth_multiplier != 1.0:
            stem_channels = _RoundChannels(stem_channels*depth_multiplier)

        # scale a copy so the class-level config templates are not modified in place
        config = list(config)
        for i, conf in enumerate(config):
            conf_ls = list(conf)
            conf_ls[0] = _RoundChannels(conf_ls[0]*depth_multiplier)
            conf_ls[1] = _RoundChannels(conf_ls[1]*depth_multiplier)
            config[i] = tuple(conf_ls)

        # stem convolution
        self.stem_conv = Conv3x3Bn(3, stem_channels, 2)

        # building MixNet blocks
        layers = []
        for in_channels, out_channels, kernel_size, expand_ksize, project_ksize, stride, expand_ratio, non_linear, se_ratio in config:
            layers.append(MixNetBlock(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                expand_ksize=expand_ksize,
                project_ksize=project_ksize,
                stride=stride,
                expand_ratio=expand_ratio,
                non_linear=non_linear,
                se_ratio=se_ratio
            ))
        self.layers = nn.Sequential(*layers)

        # last several layers
        self.head_conv = Conv1x1Bn(config[-1][1], feature_size)

        self.avgpool = nn.AvgPool2d(input_size//32, stride=1)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(feature_size, num_classes)

        self._initialize_weights()

    def forward(self, x):
        x = self.stem_conv(x)
        x = self.layers(x)
        x = self.head_conv(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.classifier(x)

        return x
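
    # Illustrative note (comment only): with the default input_size=224, the stride-2 stem
    # plus the four stride-2 blocks in each config reduce the spatial resolution by a factor
    # of 32, so head_conv sees 7x7 feature maps and nn.AvgPool2d(224 // 32) pools them to
    # 1x1 before the dropout and the final linear classifier.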

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming-style init scaled by the fan-out of each convolution
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


if __name__ == '__main__':
    net = MixNet()
    x_image = torch.randn(1, 3, 224, 224)
    y = net(x_image)
    print(y.size())

--------------------------------------------------------------------------------