├── LICENSE
├── README.md
├── rexnetv1.py
└── rexnetv1_lite.py

/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2020-present NAVER Corp.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
#### (NOTICE) All the ReXNet-lite model files have been updated!
#### (NOTICE) Our paper has been accepted at CVPR 2021! The updated paper is available on [arXiv](https://arxiv.org/pdf/2007.00992.pdf)!

## Rethinking Channel Dimensions for Efficient Model Design

**Dongyoon Han, Sangdoo Yun, Byeongho Heo, and YoungJoon Yoo** | [Paper](https://arxiv.org/abs/2007.00992) | [Pretrained Models](#pretrained)

NAVER AI Lab

## Abstract

Designing an efficient model within a limited computational cost is challenging. We argue that the accuracy of a lightweight model has been further limited by the design convention: a stage-wise configuration of the channel dimensions, which looks like a piecewise linear function of the network stage. In this paper, we study an effective channel dimension configuration towards better performance than the convention. To this end, we empirically study how to design a single layer properly by analyzing the rank of the output feature. We then investigate the channel configuration of a model by searching network architectures concerning the channel configuration under the computational cost restriction. Based on this investigation, we propose a simple yet effective channel configuration that can be parameterized by the layer index. As a result, our proposed model following the channel parameterization achieves remarkable performance on ImageNet classification and transfer learning tasks including COCO object detection, COCO instance segmentation, and fine-grained classifications.

## Model performance
- We first illustrate our models' top-1 accuracy vs. computational cost compared with EfficientNets.

*(Figure: top-1 accuracy vs. FLOPs and latency, ReXNets vs. EfficientNets)*

### Performance comparison
#### ReXNets vs. EfficientNets
- The CPU latencies are measured on a Xeon E5-2630_v4 with a single image, and the GPU latencies are measured on a V100 GPU with **a batch size of 64**.
- EfficientNets' scores are taken from [arXiv v3 of the paper](https://arxiv.org/pdf/1905.11946v3.pdf).

Model | Input Res. | Top-1 acc. | Top-5 acc. | FLOPs/params | CPU Lat. / GPU Lat.
:--: |:--:|:--:|:--:|:--:|:--:|
**ReXNet_0.9** | 224x224 | 77.2 | 93.5 | 0.35B/4.1M | 45ms/20ms
|||||
EfficientNet-B0 | 224x224 | 77.3 | 93.5 | 0.39B/5.3M | 47ms/23ms
**ReXNet_1.0** | 224x224 | 77.9 | 93.9 | 0.40B/4.8M | 47ms/21ms
|||||
EfficientNet-B1 | 240x240 | 79.2 | 94.5 | 0.70B/7.8M | 70ms/37ms
**ReXNet_1.3** | 224x224 | 79.5 | 94.7 | 0.66B/7.6M | 55ms/28ms
|||||
EfficientNet-B2 | 260x260 | 80.3 | 95.0 | 1.0B/9.2M | 77ms/48ms
**ReXNet_1.5** | 224x224 | 80.3 | 95.2 | 0.88B/9.7M | 59ms/31ms
|||||
EfficientNet-B3 | 300x300 | 81.7 | 95.6 | 1.8B/12M | 100ms/78ms
**ReXNet_2.0** | 224x224 | 81.6 | 95.7 | 1.8B/19M | 69ms/40ms

#### ReXNet-lites vs. EfficientNet-lites
- ReXNet-lites do not use SE blocks or SiLU activations, aiming at faster training and inference speed.
- We compare ReXNet-lites with [EfficientNet-lites](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite).
- Here the GPU latencies are measured on two M40 GPUs; we will update the numbers measured on a V100 GPU soon.

Model | Input Res. | Top-1 acc. | Top-5 acc. | FLOPs/params | CPU Lat. / GPU Lat.
:--: |:--:|:--:|:--:|:--:|:--:|
EfficientNet-lite0 | 224x224 | 75.1 | - | 0.41B/4.7M | 30ms/49ms
**ReXNet-lite_1.0** | 224x224 | 76.2 | 92.8 | 0.41B/4.7M | 31ms/49ms
|||||
EfficientNet-lite1 | 240x240 | 76.7 | - | 0.63B/5.4M | 44ms/73ms
**ReXNet-lite_1.3** | 224x224 | 77.8 | 93.8 | 0.65B/6.8M | 36ms/61ms
|||||
EfficientNet-lite2 | 260x260 | 77.6 | - | 0.90B/6.1M | 48ms/93ms
**ReXNet-lite_1.5** | 224x224 | 78.6 | 94.2 | 0.84B/8.3M | 39ms/68ms
|||||
EfficientNet-lite3 | 280x280 | 79.8 | - | 1.4B/8.2M | 60ms/131ms
**ReXNet-lite_2.0** | 224x224 | 80.2 | 95.0 | 1.5B/13M | 49ms/90ms

## ImageNet-1k Pretrained models

### ImageNet classification results

- Please refer to the following pretrained models. Top-1 and top-5 accuracies are reported together with the computational costs.
- Note that all the models are trained and evaluated with a 224x224 image size.

Model | Input Res. | Top-1 acc. | Top-5 acc. | FLOPs/params |
:--: |:--:|:--:|:--:|:--:
[ReXNet_1.0](https://drive.google.com/file/d/1xeIJ3wb83uOowU008ykYj6wDX2dsncA9/view?usp=sharing) | 224x224 | 77.9 | 93.9 | 0.40B/4.8M |
[ReXNet_1.3](https://drive.google.com/file/d/1x2ziK9Oyv66Y9NsxJxXsdjzpQF2uSJj0/view?usp=sharing) | 224x224 | 79.5 | 94.7 | 0.66B/7.6M |
[ReXNet_1.5](https://drive.google.com/file/d/1TOBGsbDhTHWBgqcRnyKIR0tHsJTOPUIG/view?usp=sharing) | 224x224 | 80.3 | 95.2 | 0.88B/9.7M |
[ReXNet_2.0](https://drive.google.com/file/d/1R1aOTKIe1Mvck86NanqcjWnlR8DY-Z4C/view?usp=sharing) | 224x224 | 81.6 | 95.7 | 1.5B/16M |
[ReXNet_3.0](https://drive.google.com/file/d/1iXAsr8gs3pRz0QyHKomdj5SGVzPWbIs2/view?usp=sharing) | 224x224 | 82.8 | 96.2 | 3.4B/34M |
||||
[ReXNet-lite_1.0](https://drive.google.com/file/d/1d9G4pLwZwkoDR2TRPCQlxiWiuC7R-Oqf/view?usp=sharing) | 224x224 | 76.2 | 92.8 | 0.41B/4.7M |
[ReXNet-lite_1.3](https://drive.google.com/file/d/1NsbsdI8qAHG6HdMxmySXcrl9NdEx3s0L/view?usp=sharing) | 224x224 | 77.8 | 93.8 | 0.65B/6.8M |
[ReXNet-lite_1.5](https://drive.google.com/file/d/12QzIh9A-U0PBGaLNOIr4gX2MoZEBnRjk/view?usp=sharing) | 224x224 | 78.6 | 94.2 | 0.84B/8.3M |
[ReXNet-lite_2.0](https://drive.google.com/file/d/1pGdG9HWnqSAu1FajmaMJMK5JyOJaiFyW/view?usp=sharing) | 224x224 | 80.2 | 95.0 | 1.5B/13M |

### Finetuning results
#### COCO object detection
- The following results are trained with **Faster R-CNN with FPN**:

| Backbone | Img. Size | B_AP (%) | B_AP_0.5 (%) | B_AP_0.75 (%) | Params. | FLOPs | Eval. set |
|:----:|:----:|:----:|:----:|:----:|:---:|:---:|:---:|
| FBNet-C-FPN | 1200x800 | 35.1 | 57.4 | 37.2 | 21.4M | 119.0B | val2017 |
| EfficientNetB0-FPN | 1200x800 | 38.0 | 60.1 | 40.4 | 21.0M | 123.0B | val2017 |
| ReXNet_0.9-FPN | 1200x800 | 38.0 | **60.6** | 40.8 | 20.1M | 123.0B | val2017 |
| ReXNet_1.0-FPN | 1200x800 | **38.5** | **60.6** | **41.5** | 20.7M | 124.1B | val2017 |
|||||||||
| ResNet50-FPN | 1200x800 | 37.6 | 58.2 | 40.9 | 41.8M | 202.2B | val2017 |
| ResNeXt-101-FPN | 1200x800 | 40.3 | 62.1 | 44.1 | 60.4M | 272.4B | val2017 |
| ReXNet_2.2-FPN | 1200x800 | **41.5** | **64.0** | **44.9** | 33.0M | 153.8B | val2017 |

#### COCO instance segmentation
- The following results are trained with **Mask R-CNN with FPN**; S_AP and B_AP denote segmentation AP and box AP, respectively:

| Backbone | Img. Size | S_AP (%) | S_AP_0.5 (%) | S_AP_0.75 (%) | B_AP (%) | B_AP_0.5 (%) | B_AP_0.75 (%) | Params. | FLOPs | Eval. set |
|:----:|:----:|:----:|:----:|:----:|:---:|:---:|:---:|:---:|:---:|:---:|
| EfficientNetB0_FPN | 1200x800 | 34.8 | 56.8 | 36.6 | 38.4 | 60.2 | 40.8 | 23.7M | 123.0B | val2017 |
| ReXNet_0.9-FPN | 1200x800 | **35.2** | **57.4** | **37.1** | **38.7** | **60.8** | **41.6** | 22.8M | 123.0B | val2017 |
| ReXNet_1.0-FPN | 1200x800 | 35.4 | 57.7 | 37.4 | 38.9 | 61.1 | 42.1 | 23.3M | 124.1B | val2017 |
||||||||||||
| ResNet50-FPN | 1200x800 | 34.6 | 55.9 | 36.8 | 38.5 | 59.0 | 41.6 | 44.2M | 207B | val2017 |
| ReXNet_2.2-FPN | 1200x800 | **37.8** | **61.0** | **40.2** | **42.0** | **64.5** | **45.6** | 35.6M | 153.8B | val2017 |

## Getting Started
### Requirements
- Python 3
- PyTorch (> 1.0)
- Torchvision (> 0.2)
- NumPy

### Using the pretrained models
- [timm>=0.3.0](https://github.com/rwightman/pytorch-image-models) provides a wonderful wrap-up of our models, thanks to [Ross Wightman](https://github.com/rwightman) (see the timm sketch after the examples below). Otherwise, the models can be loaded as follows:
- To use ReXNet on a GPU:
```python
import torch
import rexnetv1

model = rexnetv1.ReXNetV1(width_mult=1.0).cuda()
model.load_state_dict(torch.load('./rexnetv1_1.0.pth'))
model.eval()
print(model(torch.randn(1, 3, 224, 224).cuda()))
```

- To use ReXNet-lite on a CPU:
```python
import torch
import rexnetv1_lite

model = rexnetv1_lite.ReXNetV1_lite(multiplier=1.0)
model.load_state_dict(torch.load('./rexnet_lite_1.0.pth', map_location=torch.device('cpu')))
model.eval()
print(model(torch.randn(1, 3, 224, 224)))
```
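
- Alternatively, if `timm` is installed, the pretrained weights can be pulled directly from its model registry. A minimal sketch, assuming timm registers our models under names like `rexnet_100` (naming can differ across timm versions; check `timm.list_models('rexnet*')`):
```python
import timm
import torch

# Create ReXNet_1.0 with timm's pretrained ImageNet weights.
model = timm.create_model('rexnet_100', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])
```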

### Training your own ReXNet

ReXNet can be trained with any PyTorch training code, including [ImageNet training in PyTorch](https://github.com/pytorch/examples/tree/master/imagenet), using the model file and proper arguments. Since the provided model file is not complicated, it can easily be converted to train a ReXNet in other frameworks such as MXNet; for MXNet, we recommend [MXNet GluonCV](https://gluon-cv.mxnet.io/model_zoo/classification.html) as a training codebase.

Using PyTorch, we trained ReXNets with one of the popular ImageNet classification codebases, [Ross Wightman](https://github.com/rwightman)'s [pytorch-image-models](https://github.com/rwightman/pytorch-image-models), for more efficient training. After including ReXNet's model file in the training code, one can train ReXNet_1.0 with the following command line:

    ./distributed_train.sh 4 /imagenet/ --model rexnetv1 --rex-width-mult 1.0 --opt sgd --amp \
     --lr 0.5 --weight-decay 1e-5 \
     --batch-size 128 --epochs 400 --sched cosine \
     --remode pixel --reprob 0.2 --drop 0.2 --aa rand-m9-mstd0.5

Drop-path or MixUp may be needed when training the bigger models; see the sketch of the extra flags below.
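
For the larger models, the extra regularization flags would look roughly like the following. This is a sketch only: `--drop-path` and `--mixup` are pytorch-image-models options whose exact names and good values depend on the timm version and the model size, so treat them as assumptions rather than our exact recipe:

    ./distributed_train.sh 4 /imagenet/ --model rexnetv1 --rex-width-mult 2.0 --opt sgd --amp \
     --lr 0.5 --weight-decay 1e-5 \
     --batch-size 128 --epochs 400 --sched cosine \
     --remode pixel --reprob 0.2 --drop 0.2 --aa rand-m9-mstd0.5 \
     --drop-path 0.1 --mixup 0.2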

## License

This project is distributed under the [MIT license](LICENSE).

## How to cite

```
@misc{han2021rethinking,
    title={Rethinking Channel Dimensions for Efficient Model Design},
    author={Dongyoon Han and Sangdoo Yun and Byeongho Heo and YoungJoon Yoo},
    year={2021},
    eprint={2007.00992},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```
--------------------------------------------------------------------------------
/rexnetv1.py:
--------------------------------------------------------------------------------
"""
ReXNet
Copyright (c) 2020-present NAVER Corp.
MIT license
"""

import torch
import torch.nn as nn
from math import ceil

# Memory-efficient Swish/SiLU using torch.jit.script, borrowed from the code in
# https://twitter.com/jeremyphoward/status/1188251041835315200
# The memory-efficient SiLU is currently used as the default:
USE_MEMORY_EFFICIENT_SiLU = True

if USE_MEMORY_EFFICIENT_SiLU:
    @torch.jit.script
    def silu_fwd(x):
        return x.mul(torch.sigmoid(x))


    @torch.jit.script
    def silu_bwd(x, grad_output):
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        x_sigmoid = torch.sigmoid(x)
        return grad_output * (x_sigmoid * (1. + x * (1. - x_sigmoid)))


    class SiLUJitImplementation(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return silu_fwd(x)

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            return silu_bwd(x, grad_output)


    def silu(x, inplace=False):
        return SiLUJitImplementation.apply(x)

else:
    def silu(x, inplace=False):
        return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())


class SiLU(nn.Module):
    def __init__(self, inplace=True):
        super(SiLU, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return silu(x, self.inplace)
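

# --- Added sanity-check sketch (not part of the original file): the custom
# autograd SiLU above should agree with the reference x * sigmoid(x) in both
# the forward and the backward pass. Call this helper manually to verify.
def _check_silu_matches_reference():
    x = torch.randn(16, dtype=torch.double, requires_grad=True)
    y_ref = x * torch.sigmoid(x)
    y_jit = silu(x)
    assert torch.allclose(y_ref, y_jit), 'forward mismatch'
    g_ref, = torch.autograd.grad(y_ref.sum(), x)
    g_jit, = torch.autograd.grad(y_jit.sum(), x)
    assert torch.allclose(g_ref, g_jit), 'backward mismatch'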


def ConvBNAct(out, in_channels, channels, kernel=1, stride=1, pad=0,
              num_group=1, active=True, relu6=False):
    out.append(nn.Conv2d(in_channels, channels, kernel,
                         stride, pad, groups=num_group, bias=False))
    out.append(nn.BatchNorm2d(channels))
    if active:
        out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True))


def ConvBNSiLU(out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1):
    out.append(nn.Conv2d(in_channels, channels, kernel,
                         stride, pad, groups=num_group, bias=False))
    out.append(nn.BatchNorm2d(channels))
    out.append(SiLU(inplace=True))


class SE(nn.Module):
    def __init__(self, in_channels, channels, se_ratio=12):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, channels // se_ratio, kernel_size=1, padding=0),
            nn.BatchNorm2d(channels // se_ratio),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // se_ratio, channels, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.fc(y)
        return x * y


class LinearBottleneck(nn.Module):
    def __init__(self, in_channels, channels, t, stride, use_se=True, se_ratio=12,
                 **kwargs):
        super(LinearBottleneck, self).__init__(**kwargs)
        self.use_shortcut = stride == 1 and in_channels <= channels
        self.in_channels = in_channels
        self.out_channels = channels

        out = []
        if t != 1:
            dw_channels = in_channels * t
            ConvBNSiLU(out, in_channels=in_channels, channels=dw_channels)
        else:
            dw_channels = in_channels

        ConvBNAct(out, in_channels=dw_channels, channels=dw_channels, kernel=3, stride=stride, pad=1,
                  num_group=dw_channels, active=False)

        if use_se:
            out.append(SE(dw_channels, dw_channels, se_ratio))

        out.append(nn.ReLU6())
        ConvBNAct(out, in_channels=dw_channels, channels=channels, active=False, relu6=True)
        self.out = nn.Sequential(*out)

    def forward(self, x):
        out = self.out(x)
        if self.use_shortcut:
            # Partial residual: the input is added only to the first
            # in_channels channels of the (wider) output.
            out[:, 0:self.in_channels] += x

        return out


class ReXNetV1(nn.Module):
    def __init__(self, input_ch=16, final_ch=180, width_mult=1.0, depth_mult=1.0, classes=1000,
                 use_se=True,
                 se_ratio=12,
                 dropout_ratio=0.2,
                 bn_momentum=0.9):
        super(ReXNetV1, self).__init__()

        layers = [1, 2, 2, 3, 3, 5]
        strides = [1, 2, 2, 2, 1, 2]
        use_ses = [False, False, True, True, True, True]

        layers = [ceil(element * depth_mult) for element in layers]
        strides = sum([[element] + [1] * (layers[idx] - 1)
                       for idx, element in enumerate(strides)], [])
        if use_se:
            use_ses = sum([[element] * layers[idx] for idx, element in enumerate(use_ses)], [])
        else:
            use_ses = [False] * sum(layers[:])
        ts = [1] * layers[0] + [6] * sum(layers[1:])

        self.depth = sum(layers[:]) * 3
        stem_channel = 32 / width_mult if width_mult < 1.0 else 32
        inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch

        features = []
        in_channels_group = []
        channels_group = []

        # The following channel configuration is a simple instance that makes
        # each layer an expand layer: the channel count grows linearly with
        # the block index.
        for i in range(self.depth // 3):
            if i == 0:
                in_channels_group.append(int(round(stem_channel * width_mult)))
                channels_group.append(int(round(inplanes * width_mult)))
            else:
                in_channels_group.append(int(round(inplanes * width_mult)))
                inplanes += final_ch / (self.depth // 3 * 1.0)
                channels_group.append(int(round(inplanes * width_mult)))

        ConvBNSiLU(features, 3, int(round(stem_channel * width_mult)), kernel=3, stride=2, pad=1)

        for block_idx, (in_c, c, t, s, se) in enumerate(zip(in_channels_group, channels_group, ts, strides, use_ses)):
            features.append(LinearBottleneck(in_channels=in_c,
                                             channels=c,
                                             t=t,
                                             stride=s,
                                             use_se=se, se_ratio=se_ratio))

        pen_channels = int(1280 * width_mult)
        # c is the output channel count of the last bottleneck above.
        ConvBNSiLU(features, c, pen_channels)

        features.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*features)
        self.output = nn.Sequential(
            nn.Dropout(dropout_ratio),
            nn.Conv2d(pen_channels, classes, 1, bias=True))

    def extract_features(self, x):
        return self.features[:-1](x)

    def forward(self, x):
        x = self.features(x)
        x = self.output(x).flatten(1)
        return x
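

# --- Added illustration sketch (not part of the original file): the paper's
# key idea is that the bottleneck output channels grow roughly linearly with
# the block index, rather than doubling stage-wise. This helper prints the
# progression that the constructor above builds.
def print_channel_progression(width_mult=1.0):
    model = ReXNetV1(width_mult=width_mult)
    for idx, m in enumerate(model.features):
        if isinstance(m, LinearBottleneck):
            print('block {}: {} -> {} channels'.format(idx, m.in_channels, m.out_channels))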


if __name__ == '__main__':
    model = ReXNetV1(width_mult=1.0)
    out = model(torch.randn(2, 3, 224, 224))
    loss = out.sum()
    loss.backward()
    print('Checked a single forward/backward iteration')
--------------------------------------------------------------------------------
/rexnetv1_lite.py:
--------------------------------------------------------------------------------
"""
ReXNet_lite
Copyright (c) 2021-present NAVER Corp.
MIT license
"""

import torch
import torch.nn as nn
from math import ceil


def _make_divisible(channel_size, divisor=None, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by the divisor.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if not divisor:
        return channel_size

    if min_value is None:
        min_value = divisor
    new_channel_size = max(min_value, int(channel_size + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_channel_size < 0.9 * channel_size:
        new_channel_size += divisor
    return new_channel_size


def _add_conv(out, in_channels, channels, kernel=1, stride=1, pad=0,
              num_group=1, active=True, relu6=True, bn_momentum=0.1, bn_eps=1e-5):
    out.append(nn.Conv2d(in_channels, channels, kernel, stride, pad, groups=num_group, bias=False))
    out.append(nn.BatchNorm2d(channels, momentum=bn_momentum, eps=bn_eps))
    if active:
        out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True))


class LinearBottleneck(nn.Module):
    def __init__(self, in_channels, channels, t, kernel_size=3, stride=1,
                 bn_momentum=0.1, bn_eps=1e-5,
                 **kwargs):
        super(LinearBottleneck, self).__init__(**kwargs)
        self.conv_shortcut = None
        self.use_shortcut = stride == 1 and in_channels <= channels
        self.in_channels = in_channels
        self.out_channels = channels
        out = []
        if t != 1:
            dw_channels = in_channels * t
            _add_conv(out, in_channels=in_channels, channels=dw_channels, bn_momentum=bn_momentum,
                      bn_eps=bn_eps)
        else:
            dw_channels = in_channels

        _add_conv(out, in_channels=dw_channels, channels=dw_channels, kernel=kernel_size, stride=stride,
                  pad=(kernel_size // 2),
                  num_group=dw_channels, bn_momentum=bn_momentum, bn_eps=bn_eps)

        _add_conv(out, in_channels=dw_channels, channels=channels, active=False, bn_momentum=bn_momentum,
                  bn_eps=bn_eps)

        self.out = nn.Sequential(*out)

    def forward(self, x):
        out = self.out(x)

        if self.use_shortcut:
            # Partial residual: add the input to the first in_channels channels only.
            out[:, 0:self.in_channels] += x
        return out


class ReXNetV1_lite(nn.Module):
    def __init__(self, fix_head_stem=False, divisible_value=8,
                 input_ch=16, final_ch=164, multiplier=1.0, classes=1000,
                 dropout_ratio=0.2,
                 bn_momentum=0.1,
                 bn_eps=1e-5, kernel_conf='333333'):
        super(ReXNetV1_lite, self).__init__()

        layers = [1, 2, 2, 3, 3, 5]
        strides = [1, 2, 2, 2, 1, 2]
        kernel_sizes = [int(element) for element in kernel_conf]

        strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], [])
        ts = [1] * layers[0] + [6] * sum(layers[1:])
        kernel_sizes = sum([[element] * layers[idx] for idx, element in enumerate(kernel_sizes)], [])
        self.num_convblocks = sum(layers[:])

        features = []
        inplanes = input_ch / multiplier if multiplier < 1.0 else input_ch
        first_channel = 32 / multiplier if multiplier < 1.0 or fix_head_stem else 32
        first_channel = _make_divisible(int(round(first_channel * multiplier)), divisible_value)

        in_channels_group = []
        channels_group = []

        _add_conv(features, 3, first_channel, kernel=3, stride=2, pad=1,
                  bn_momentum=bn_momentum, bn_eps=bn_eps)

        for i in range(self.num_convblocks):
            inplanes_divisible = _make_divisible(int(round(inplanes * multiplier)), divisible_value)
            if i == 0:
                in_channels_group.append(first_channel)
                channels_group.append(inplanes_divisible)
            else:
                in_channels_group.append(inplanes_divisible)
                inplanes += final_ch / (self.num_convblocks - 1)
                inplanes_divisible = _make_divisible(int(round(inplanes * multiplier)), divisible_value)
                channels_group.append(inplanes_divisible)

        for block_idx, (in_c, c, t, k, s) in enumerate(
                zip(in_channels_group, channels_group, ts, kernel_sizes, strides)):
            features.append(LinearBottleneck(in_channels=in_c,
                                             channels=c,
                                             t=t,
                                             kernel_size=k,
                                             stride=s,
                                             bn_momentum=bn_momentum,
                                             bn_eps=bn_eps))

        pen_channels = int(1280 * multiplier) if multiplier > 1 and not fix_head_stem else 1280
        # c is the output channel count of the last bottleneck above.
        _add_conv(features, c, pen_channels, bn_momentum=bn_momentum, bn_eps=bn_eps)

        self.features = nn.Sequential(*features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.output = nn.Sequential(
            nn.Conv2d(pen_channels, 1024, 1, bias=True),
            nn.BatchNorm2d(1024, momentum=bn_momentum, eps=bn_eps),
            nn.ReLU6(inplace=True),
            nn.Dropout(dropout_ratio),
            nn.Conv2d(1024, classes, 1, bias=True))

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = self.output(x).flatten(1)
        return x


if __name__ == '__main__':
    model = ReXNetV1_lite(multiplier=1.0)
    out = model(torch.randn(2, 3, 224, 224))
    loss = out.sum()
    loss.backward()
    print('Checked a single forward/backward iteration')
--------------------------------------------------------------------------------
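
As a final sanity check against the parameter counts reported in the README tables, both model files can be instantiated directly. A minimal sketch, assuming both files are importable from the working directory:

```python
import rexnetv1
import rexnetv1_lite

for name, net in [('ReXNet_1.0', rexnetv1.ReXNetV1(width_mult=1.0)),
                  ('ReXNet-lite_1.0', rexnetv1_lite.ReXNetV1_lite(multiplier=1.0))]:
    n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('{}: {:.1f}M parameters'.format(name, n_params / 1e6))
# The printed counts should land close to the README figures
# (about 4.8M and 4.7M, respectively).
```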