├── doc
│   └── sand_glass.png
├── README.md
└── mobilenext.py

/doc/sand_glass.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RangiLyu/mobilenext/HEAD/doc/sand_glass.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MobileNeXt: Rethinking Bottleneck Structure for Efficient Mobile Network Design

## Introduction
This is an unofficial PyTorch implementation of the MobileNeXt model from the paper [Rethinking Bottleneck Structure for Efficient Mobile Network Design](https://arxiv.org/abs/2007.02269).

## Details

### Architecture
The table below lists the MobileNeXt architecture. Here t is the channel reduction ratio of the sand glass block, b is the number of times the block is repeated, s is the stride of the first block in each group, and k is the number of output classes.

| No. | t | Out-Dim | s | b | Inp-Dim | Operator |
| :---- | :---: | :------: | :---: | :---: | :------: | :------: |
| 1 | - | 112 × 112 × 32 | 2 | 1 | 224 × 224 × 3 | conv2d 3x3 |
| 2 | 2 | 56 × 56 × 96 | 2 | 1 | 112 × 112 × 32 | sandglass block |
| 3 | 6 | 56 × 56 × 144 | 1 | 1 | 56 × 56 × 96 | sandglass block |
| 4 | 6 | 28 × 28 × 192 | 2 | 3 | 56 × 56 × 144 | sandglass block |
| 5 | 6 | 14 × 14 × 288 | 2 | 3 | 28 × 28 × 192 | sandglass block |
| 6 | 6 | 14 × 14 × 384 | 1 | 4 | 14 × 14 × 288 | sandglass block |
| 7 | 6 | 7 × 7 × 576 | 2 | 4 | 14 × 14 × 384 | sandglass block |
| 8 | 6 | 7 × 7 × 960 | 1 | 2 | 7 × 7 × 576 | sandglass block |
| 9 | 6 | 7 × 7 × 1280 | 1 | 1 | 7 × 7 × 960 | sandglass block |
| 10 | - | 1 × 1 × 1280 | - | 1 | 7 × 7 × 1280 | avgpool 7x7 |
| 11 | - | k | - | 1 | 1 × 1 × 1280 | conv2d 1x1 |

### Sand Glass Module
![sandglass_image](doc/sand_glass.png)
--------------------------------------------------------------------------------
/mobilenext.py:
--------------------------------------------------------------------------------
"""
Unofficial PyTorch implementation of MobileNeXt from the paper:
Rethinking Bottleneck Structure for Efficient Mobile Network Design
https://arxiv.org/abs/2007.02269

Modified from the torchvision MobileNetV2 implementation:
https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py
"""

import math
import torch
from torch import nn


__all__ = ['MobileNeXt', 'mobilenext']


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by `divisor`.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v: original channel count
    :param divisor: the value the result must be divisible by
    :param min_value: optional lower bound for the result (defaults to divisor)
    :return: the adjusted channel count
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
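    # Worked example (illustrative values, not from the original repo):
    # _make_divisible(10, 8) first rounds 10 + 4 = 14 down to 8; since
    # 8 < 0.9 * 10, the result is bumped to the next multiple, 16.
    # _make_divisible(37, 8) returns 40 directly, because 40 >= 0.9 * 37.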
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    """Conv2d -> norm layer -> ReLU6, with 'same' padding for odd kernel sizes."""

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            norm_layer(out_planes),
            nn.ReLU6(inplace=True)
        )


class SandGlass(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, identity_tensor_multiplier=1.0, norm_layer=None, keep_3x3=False):
        super(SandGlass, self).__init__()
        self.stride = stride
        assert stride in [1, 2]
        # A multiplier below 1.0 enables the partial residual connection,
        # which adds the shortcut on only a fraction of the channels.
        self.use_identity = identity_tensor_multiplier != 1.0
        self.identity_tensor_channels = int(round(inp * identity_tensor_multiplier))

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        # Unlike the inverted residual, the sand glass block reduces the width
        # in the middle: expand_ratio acts as a channel reduction ratio.
        hidden_dim = inp // expand_ratio
        if hidden_dim < oup / 6.:
            hidden_dim = math.ceil(oup / 6.)
            hidden_dim = _make_divisible(hidden_dim, 16)

        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        # dw
        if expand_ratio == 2 or inp == oup or keep_3x3:
            layers.append(ConvBNReLU(inp, inp, kernel_size=3, stride=1, groups=inp, norm_layer=norm_layer))
        if expand_ratio != 1:
            # pw-linear
            layers.extend([
                nn.Conv2d(inp, hidden_dim, kernel_size=1, stride=1, padding=0, groups=1, bias=False),
                norm_layer(hidden_dim),
            ])
        layers.extend([
            # pw
            ConvBNReLU(hidden_dim, oup, kernel_size=1, stride=1, groups=1, norm_layer=norm_layer),
        ])
        if expand_ratio == 2 or inp == oup or keep_3x3 or stride == 2:
            layers.extend([
                # dw-linear
                nn.Conv2d(oup, oup, kernel_size=3, stride=stride, groups=oup, padding=1, bias=False),
                norm_layer(oup),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            if self.use_identity:
                # Partial residual: add the shortcut on the first
                # identity_tensor_channels channels only.
                identity_tensor = x[:, :self.identity_tensor_channels, :, :] + out[:, :self.identity_tensor_channels, :, :]
                out = torch.cat([identity_tensor, out[:, self.identity_tensor_channels:, :, :]], dim=1)
            else:
                out = x + out
            return out
        else:
            return out


class MobileNeXt(nn.Module):
    def __init__(self,
                 num_classes=1000,
                 width_mult=1.0,
                 identity_tensor_multiplier=1.0,
                 sand_glass_setting=None,
                 round_nearest=8,
                 block=None,
                 norm_layer=None):
        """
        MobileNeXt main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts the number of channels in each layer by this amount
            identity_tensor_multiplier (float): Identity tensor multiplier - reduces the number of element-wise additions in each block
            sand_glass_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number.
                Set to 1 to turn off rounding
            block: Module specifying the sand glass building block to use
            norm_layer: Module specifying the normalization layer to use
        """
        super(MobileNeXt, self).__init__()

        if block is None:
            block = SandGlass

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

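        # Note (added comment): the stem and classifier widths below follow the
        # paper's 1.0x configuration; both are scaled by width_mult and rounded
        # to a multiple of round_nearest via _make_divisible just below.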
        input_channel = 32
        last_channel = 1280

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]

        if sand_glass_setting is None:
            sand_glass_setting = [
                # t, c, b, s
                [2, 96, 1, 2],
                [6, 144, 1, 1],
                [6, 192, 3, 2],
                [6, 288, 3, 2],
                [6, 384, 4, 1],
                [6, 576, 4, 2],
                [6, 960, 2, 1],
                # dividing by width_mult here cancels the multiplication below,
                # so the final stage always ends at exactly self.last_channel
                [6, self.last_channel / width_mult, 1, 1],
            ]

        # only check the first element, assuming the user knows that t, c, b, s are required
        if len(sand_glass_setting) == 0 or len(sand_glass_setting[0]) != 4:
            raise ValueError("sand_glass_setting should be non-empty "
                             "and each element should be a 4-element list, "
                             "got {}".format(sand_glass_setting))

        # building sand glass blocks
        for t, c, b, s in sand_glass_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(b):
                # only the first block in each group uses the given stride
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t,
                                      identity_tensor_multiplier=identity_tensor_multiplier,
                                      norm_layer=norm_layer,
                                      keep_3x3=(b == 1 and s == 1 and i == 0)))
                input_channel = output_channel

        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes)
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)


def mobilenext(**kwargs):
    """Constructs a MobileNeXt model.

    This factory is exported in __all__ but was missing from the original
    file; it is added so that `from mobilenext import *` works as declared.
    """
    return MobileNeXt(**kwargs)


if __name__ == "__main__":
    model = MobileNeXt(num_classes=1000, width_mult=1.0, identity_tensor_multiplier=1.0)
    print(model)

    test_data = torch.rand(1, 3, 224, 224)
    test_outputs = model(test_data)
    print(test_outputs.size())
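
    # Rough sanity check (added here; not part of the original script):
    # count the parameters of the width_mult=1.0 model.
    num_params = sum(p.numel() for p in model.parameters())
    print("Parameters: {:.2f}M".format(num_params / 1e6))
--------------------------------------------------------------------------------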