├── MnasNet.py
└── README.md

/MnasNet.py:

import math

import torch
import torch.nn as nn


def Conv_3x3(inp, oup, stride):
    # regular 3x3 convolution + batch norm + ReLU6
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def Conv_1x1(inp, oup):
    # pointwise 1x1 convolution + batch norm + ReLU6
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def SepConv_3x3(inp, oup):
    # depthwise separable 3x3 convolution, used in the stem (32 -> 16 channels)
    return nn.Sequential(
        # dw
        nn.Conv2d(inp, inp, 3, 1, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU6(inplace=True),
        # pw-linear
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
    )


class InvertedResidual(nn.Module):
    """MBConv block: 1x1 expansion, kxk depthwise conv, 1x1 linear projection."""

    def __init__(self, inp, oup, stride, expand_ratio, kernel):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        # residual connection only when the block preserves resolution and width
        self.use_res_connect = self.stride == 1 and inp == oup

        self.conv = nn.Sequential(
            # pw
            nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
            nn.BatchNorm2d(inp * expand_ratio),
            nn.ReLU6(inplace=True),
            # dw
            nn.Conv2d(inp * expand_ratio, inp * expand_ratio, kernel, stride, kernel // 2,
                      groups=inp * expand_ratio, bias=False),
            nn.BatchNorm2d(inp * expand_ratio),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MnasNet(nn.Module):
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MnasNet, self).__init__()

        # settings of the inverted residual (MBConv) stages:
        # t: expansion factor, c: output channels, n: number of blocks,
        # s: stride of the first block, k: depthwise kernel size
        self.inverted_residual_setting = [
            [3, 24, 3, 2, 3],   # -> 56x56
            [3, 40, 3, 2, 5],   # -> 28x28
            [6, 80, 3, 2, 5],   # -> 14x14
            [6, 96, 2, 1, 3],   # -> 14x14
            [6, 192, 4, 2, 5],  # -> 7x7
            [6, 320, 1, 1, 3],  # -> 7x7
        ]

        assert input_size % 32 == 0
        input_channel = int(32 * width_mult)
        self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280

        # build the stem: a strided 3x3 conv followed by a separable conv
        self.features = [Conv_3x3(3, input_channel, 2), SepConv_3x3(input_channel, 16)]
        input_channel = 16

        # build the inverted residual (MBConv) stages
        for t, c, n, s, k in self.inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # only the first block of each stage uses the stage stride
                stride = s if i == 0 else 1
                self.features.append(InvertedResidual(input_channel, output_channel, stride, t, k))
                input_channel = output_channel

        # build the head: 1x1 conv to last_channel, then global average pooling
        self.features.append(Conv_1x1(input_channel, self.last_channel))
        self.features.append(nn.AdaptiveAvgPool2d(1))

        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # build the classifier
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, self.last_channel)  # flatten (N, C, 1, 1) -> (N, C)
        x = self.classifier(x)
        return x
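
    # Weight initialization below follows the scheme in the reference
    # pytorch-mobilenet-v2 code this repo was adapted from: He (Kaiming)
    # normal for convolutions, constant 1 / 0 for BatchNorm weight / bias,
    # and N(0, 0.01) for the linear classifier.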
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


if __name__ == '__main__':
    # quick smoke test: one forward pass on a dummy 224x224 image
    net = MnasNet()
    x_image = torch.randn(1, 3, 224, 224)
    y = net(x_image)
    print(y.size())  # expected: torch.Size([1, 1000])

/README.md:

# Overview
A PyTorch implementation of the MnasNet architecture found by neural architecture search in [MnasNet: Platform-Aware Neural Architecture Search for Mobile](https://arxiv.org/abs/1807.11626). On the ImageNet classification task, the model achieves 74.0% top-1 accuracy with 76 ms latency on a Pixel phone, which is 1.5× faster than MobileNetV2.

# MnasNet Architecture
![MnasNet architecture](https://i.imgur.com/ryyU8cP.png)

# Disclaimer
Code adapted from [pytorch-mobilenet-v2](https://github.com/tonylins/pytorch-mobilenet-v2).
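
# Usage
A minimal usage sketch (not part of the original repo): it assumes `MnasNet.py` is importable from the working directory, and the batch size and `width_mult` value are illustrative.

```python
import torch
from MnasNet import MnasNet

# build the model at the default width and run a dummy batch through it
net = MnasNet(n_class=1000, input_size=224, width_mult=1.0)
net.eval()  # disable dropout for deterministic inference
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = net(x)
print(logits.shape)  # torch.Size([1, 1000])
```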