├── LICENSE
├── README.md
├── rexnetv1.py
└── rexnetv1_lite.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020-present NAVER Corp.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | #### (NOTICE) All of the ReXNet-lite model files have been updated!
2 | #### (NOTICE) Our paper has been accepted at CVPR 2021! The updated paper is available on [arXiv](https://arxiv.org/pdf/2007.00992.pdf).
3 |
4 | ## Rethinking Channel Dimensions for Efficient Model Design
5 |
6 | **Dongyoon Han, Sangdoo Yun, Byeongho Heo, and YoungJoon Yoo** | [Paper](https://arxiv.org/abs/2007.00992) | [Pretrained Models](#pretrained)
7 |
8 | NAVER AI Lab
9 |
10 | ## Abstract
11 |
12 | Designing an efficient model within the limited computational cost is challenging. We argue the accuracy of a lightweight model has been further limited by the design convention: a stage-wise configuration of the channel dimensions, which looks like a piecewise linear function of the network stage. In this paper, we study an effective channel dimension configuration towards better performance than the convention. To this end, we empirically study how to design a single layer properly by analyzing the rank of the output feature. We then investigate the channel configuration of a model by searching network architectures concerning the channel configuration under the computational cost restriction. Based on the investigation, we propose a simple yet effective channel configuration that can be parameterized by the layer index. As a result, our proposed model following the channel parameterization achieves remarkable performance on ImageNet classification and transfer learning tasks including COCO object detection, COCO instance segmentation, and fine-grained classifications.
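The channel parameterization described above appears in `rexnetv1.py` as a linear schedule. Starting from `input_ch=16`, each of the 16 bottleneck blocks (stage depths `[1, 2, 2, 3, 3, 5]`) adds `final_ch / 16` channels (`final_ch=180`), and the result is then scaled by `width_mult` and rounded. The plain-Python sketch below re-derives that schedule so it can be inspected without PyTorch. The function name `rexnet_channel_schedule` is hypothetical and not part of the repository. It is meant to mirror the loop in `ReXNetV1.__init__`, not to replace it.

```python
from math import ceil


def rexnet_channel_schedule(input_ch=16, final_ch=180, width_mult=1.0, depth_mult=1.0):
    """Hypothetical helper mirroring the channel loop in ReXNetV1.__init__.

    Returns (in_channels, out_channels) lists, one entry per bottleneck block.
    """
    layers = [ceil(n * depth_mult) for n in [1, 2, 2, 3, 3, 5]]
    num_blocks = sum(layers)
    # Same stem/input adjustment as the model file when width_mult < 1.0.
    stem_channel = 32 / width_mult if width_mult < 1.0 else 32
    inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch

    in_channels, out_channels = [], []
    for i in range(num_blocks):
        if i == 0:
            # The first block maps the 32-channel stem output to input_ch.
            in_channels.append(int(round(stem_channel * width_mult)))
            out_channels.append(int(round(inplanes * width_mult)))
        else:
            in_channels.append(int(round(inplanes * width_mult)))
            # Linear growth: every block adds final_ch / num_blocks channels.
            inplanes += final_ch / num_blocks
            out_channels.append(int(round(inplanes * width_mult)))
    return in_channels, out_channels


ins, outs = rexnet_channel_schedule(width_mult=1.0)
print(outs)
# -> [16, 27, 38, 50, 61, 72, 84, 95, 106, 117, 128, 140, 151, 162, 174, 185]
```

Each block's input width equals the previous block's output width, so the per-block widths form a single increasing line over the layer index rather than a stage-wise staircase.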
13 |
14 | ## Model performance
15 | - We first illustrate our models' top-1 accuracy vs. computational cost in comparison with EfficientNets.
16 |
17 |
18 |
19 |
20 |
21 | ### Performance comparison
22 | #### ReXNets vs EfficientNets
23 | - The CPU latencies are measured on a Xeon E5-2630_v4 with a single image, and the GPU latencies are measured on a V100 GPU with **a batch size of 64**.
24 | - EfficientNets' scores are taken from [arxiv v3 of the paper](https://arxiv.org/pdf/1905.11946v3.pdf).
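The latency figures follow the protocol stated in these bullets: a single image on CPU and a batch of 64 on GPU. The authors' benchmarking script is not part of this repository, so the following is only a generic sketch of wall-clock timing with warm-up iterations and a median over repeated runs. The helper `measure_latency_ms` and its default iteration counts are assumptions. For GPU timing, the callable should call `torch.cuda.synchronize()` before returning, because CUDA kernels are launched asynchronously.

```python
import statistics
import time


def measure_latency_ms(fn, warmup=10, iters=50):
    """Hypothetical helper: median wall-clock latency of fn() in milliseconds.

    Warm-up calls are discarded so one-time costs (allocation, caching,
    lazy initialization) do not inflate the measurement.
    """
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()  # for CUDA work, fn should synchronize before returning
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)


if __name__ == "__main__":
    # Stand-in workload; in practice fn would run a model forward pass on a
    # (1, 3, 224, 224) CPU tensor or a (64, 3, 224, 224) GPU batch.
    print(f"{measure_latency_ms(lambda: sum(i * i for i in range(10_000))):.3f} ms")
```

The median is used instead of the mean so that occasional outliers, such as scheduler hiccups, do not dominate the reported number.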
25 |
26 | Model | Input Res. | Top-1 acc. | Top-5 acc. | FLOPs/params. | CPU Lat./ GPU Lat.
27 | :--: |:--:|:--:|:--:|:--:|:--:|
28 | **ReXNet_0.9** | 224x224 | 77.2 | 93.5 | 0.35B/4.1M | 45ms/20ms
29 | |||||
30 | EfficientNet-B0 | 224x224 | 77.3 | 93.5 | 0.39B/5.3M | 47ms/23ms
31 | **ReXNet_1.0** | 224x224 | 77.9 | 93.9 | 0.40B/4.8M | 47ms/21ms
32 | |||||
33 | EfficientNet-B1 | 240x240 | 79.2 | 94.5 | 0.70B/7.8M | 70ms/37ms
34 | **ReXNet_1.3** | 224x224 | 79.5 | 94.7 | 0.66B/7.6M | 55ms/28ms
35 | |||||
36 | EfficientNet-B2 | 260x260 | 80.3 | 95.0 | 1.0B/9.2M | 77ms/48ms
37 | **ReXNet_1.5** | 224x224 | 80.3 | 95.2 | 0.88B/9.7M | 59ms/31ms
38 | |||||
39 | EfficientNet-B3 | 300x300 | 81.7 | 95.6 | 1.8B/12M | 100ms/78ms
40 | **ReXNet_2.0** | 224x224 | 81.6 | 95.7 | 1.8B/19M | 69ms/40ms
41 |
42 | #### ReXNet-lites vs. EfficientNet-lites
43 | - ReXNet-lites use neither SE modules nor SiLU activations, aiming for faster training and inference.
44 | - We compare ReXNet-lites with [EfficientNet-lites](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite).
45 | - Here, the GPU latencies are measured on two M40 GPUs; we will update the numbers with measurements on a V100 GPU soon.
46 |
47 | Model | Input Res. | Top-1 acc. | Top-5 acc. | FLOPs/params | CPU Lat./ GPU Lat.
48 | :--: |:--:|:--:|:--:|:--:|:--:|
49 | EfficientNet-lite0 | 224x224 | 75.1 | - | 0.41B/4.7M | 30ms/49ms
50 | **ReXNet-lite_1.0** | 224x224 | 76.2 | 92.8 | 0.41B/4.7M | 31ms/49ms
51 | |||||
52 | EfficientNet-lite1 | 240x240 | 76.7 | - | 0.63B/5.4M | 44ms/73ms
53 | **ReXNet-lite_1.3** | 224x224 | 77.8 | 93.8 | 0.65B/6.8M | 36ms/61ms
54 | |||||
55 | EfficientNet-lite2 | 260x260 | 77.6 | - | 0.90B/6.1M | 48ms/93ms
56 | **ReXNet-lite_1.5** | 224x224 | 78.6 | 94.2 | 0.84B/8.3M | 39ms/68ms
57 | |||||
58 | EfficientNet-lite3 | 280x280 | 79.8 | - | 1.4B/8.2M | 60ms/131ms
59 | **ReXNet-lite_2.0** | 224x224 | 80.2 | 95.0 | 1.5B/13M | 49ms/90ms
60 |
61 | ## ImageNet-1k Pretrained models
62 | ### ImageNet classification results
63 |
64 | - Please refer to the following pretrained models. Top-1 and top-5 accuracies are reported together with the computational costs.
65 | - Note that all the models are trained and evaluated with 224x224 image size.
66 |
67 | Model | Input Res. | Top-1 acc. | Top-5 acc. | FLOPs/params |
68 | :--: |:--:|:--:|:--:|:--:
69 | [ReXNet_1.0](https://drive.google.com/file/d/1xeIJ3wb83uOowU008ykYj6wDX2dsncA9/view?usp=sharing) | 224x224 | 77.9 | 93.9 | 0.40B/4.8M |
70 | [ReXNet_1.3](https://drive.google.com/file/d/1x2ziK9Oyv66Y9NsxJxXsdjzpQF2uSJj0/view?usp=sharing) | 224x224 | 79.5 | 94.7 | 0.66B/7.6M |
71 | [ReXNet_1.5](https://drive.google.com/file/d/1TOBGsbDhTHWBgqcRnyKIR0tHsJTOPUIG/view?usp=sharing) | 224x224 | 80.3 | 95.2 | 0.88B/9.7M |
72 | [ReXNet_2.0](https://drive.google.com/file/d/1R1aOTKIe1Mvck86NanqcjWnlR8DY-Z4C/view?usp=sharing) | 224x224 | 81.6 | 95.7 | 1.5B/16M |
73 | [ReXNet_3.0](https://drive.google.com/file/d/1iXAsr8gs3pRz0QyHKomdj5SGVzPWbIs2/view?usp=sharing) | 224x224 | 82.8 | 96.2 | 3.4B/34M |
74 | ||||
75 | [ReXNet-lite_1.0](https://drive.google.com/file/d/1d9G4pLwZwkoDR2TRPCQlxiWiuC7R-Oqf/view?usp=sharing) | 224x224 | 76.2 | 92.8 | 0.41B/4.7M |
76 | [ReXNet-lite_1.3](https://drive.google.com/file/d/1NsbsdI8qAHG6HdMxmySXcrl9NdEx3s0L/view?usp=sharing) | 224x224 | 77.8 | 93.8 | 0.65B/6.8M |
77 | [ReXNet-lite_1.5](https://drive.google.com/file/d/12QzIh9A-U0PBGaLNOIr4gX2MoZEBnRjk/view?usp=sharing) | 224x224 | 78.6 | 94.2 | 0.84B/8.3M|
78 | [ReXNet-lite_2.0](https://drive.google.com/file/d/1pGdG9HWnqSAu1FajmaMJMK5JyOJaiFyW/view?usp=sharing) | 224x224 | 80.2 | 95.0 | 1.5B/13M |
79 |
80 | ### Finetuning results
81 | #### COCO Object detection
82 | - The following results are trained with **Faster R-CNN with FPN**:
83 |
84 | | Backbone |Img. Size| B_AP (%) | B_AP_0.5 (%) | B_AP_0.75 (%) | Params. |FLOPs | Eval. set|
85 | |:----:|:----:|:----:|:----:|:----:|:---:|:---:|:---:|
86 | | FBNet-C-FPN | 1200x800 | 35.1 | 57.4 | 37.2 | 21.4M | 119.0B | val2017 |
87 | | EfficientNetB0-FPN | 1200x800 | 38.0 | 60.1 | 40.4 | 21.0M | 123.0B | val2017|
88 | | ReXNet_0.9-FPN | 1200x800 | 38.0 | **60.6** | 40.8 | 20.1M | 123.0B | val2017|
89 | | ReXNet_1.0-FPN | 1200x800 | **38.5** | **60.6** | **41.5** | 20.7M | 124.1B | val2017|
90 | |||||||||
91 | | ResNet50-FPN | 1200x800 | 37.6| 58.2| 40.9 | 41.8M | 202.2B | val2017|
92 | | ResNeXt-101-FPN | 1200x800 | 40.3 | 62.1 | 44.1 | 60.4M | 272.4B | val2017|
93 | | ReXNet_2.2-FPN | 1200x800| **41.5** | **64.0** | **44.9** | 33.0M | 153.8B | val2017|
94 |
95 |
96 | #### COCO instance segmentation
97 | - The following results are trained with **Mask R-CNN with FPN**; S_AP and B_AP denote segmentation AP and box AP, respectively:
98 |
99 | | Backbone |Img. Size| S_AP (%) | S_AP_0.5 (%) | S_AP_0.75 (%) | B_AP (%) | B_AP_0.5 (%) | B_AP_0.75 (%) | Params. |FLOPs | Eval. set|
100 | |:----:|:----:|:----:|:----:|:----:|:---:|:---:|:---:|:---:|:---:|:---:|
101 | | EfficientNetB0-FPN | 1200x800 | 34.8 | 56.8 | 36.6 | 38.4 | 60.2 | 40.8 | 23.7M | 123.0B | val2017|
102 | | ReXNet_0.9-FPN | 1200x800 | **35.2** | **57.4**| **37.1** |**38.7** |**60.8**|**41.6**| 22.8M | 123.0B | val2017|
103 | | ReXNet_1.0-FPN | 1200x800 | 35.4 | 57.7 | 37.4 | 38.9 |61.1 | 42.1 | 23.3M | 124.1B | val2017|
104 | ||||||||||||
105 | | ResNet50-FPN | 1200x800 | 34.6 | 55.9 | 36.8 |38.5 |59.0|41.6| 44.2M | 207B | val2017|
106 | | ReXNet_2.2-FPN | 1200x800 | **37.8** | **61.0** | **40.2** | **42.0** | **64.5** | **45.6**| 35.6M | 153.8B | val2017|
107 |
108 | ## Getting Started
109 | ### Requirements
110 | - Python3
111 | - PyTorch (> 1.0)
112 | - Torchvision (> 0.2)
113 | - NumPy
114 |
115 | ### Using the pretrained models
116 | - [timm>=0.3.0](https://github.com/rwightman/pytorch-image-models) provides a convenient wrapper for our models, thanks to [Ross Wightman](https://github.com/rwightman). Alternatively, the models can be loaded as follows:
117 | - To use ReXNet on a GPU:
118 | ```python
119 | import torch
120 | import rexnetv1
121 |
122 | model = rexnetv1.ReXNetV1(width_mult=1.0).cuda()
123 | model.load_state_dict(torch.load('./rexnetv1_1.0.pth'))
124 | model.eval()
125 | print(model(torch.randn(1, 3, 224, 224).cuda()))
126 | ```
127 |
128 | - To use ReXNet-lite on a CPU:
129 | ```python
130 | import torch
131 | import rexnetv1_lite
132 |
133 | model = rexnetv1_lite.ReXNetV1_lite(multiplier=1.0)
134 | model.load_state_dict(torch.load('./rexnet_lite_1.0.pth', map_location=torch.device('cpu')))
135 | model.eval()
136 | print(model(torch.randn(1, 3, 224, 224)))
137 |
138 | ```
139 |
140 | ### Training your own ReXNet
141 |
142 | ReXNet can be trained with any PyTorch training code, including [ImageNet training in PyTorch](https://github.com/pytorch/examples/tree/master/imagenet), given the model file and proper arguments. Since the provided model file is simple, we simply convert the model to train ReXNet in other frameworks such as MXNet. For MXNet, we recommend [MXNet-GluonCV](https://gluon-cv.mxnet.io/model_zoo/classification.html) as the training code.
143 |
144 | Using PyTorch, we trained ReXNets with one of the popular ImageNet classification codebases, [Ross Wightman](https://github.com/rwightman)'s [pytorch-image-models](https://github.com/rwightman/pytorch-image-models), for more efficient training. After adding ReXNet's model file to the training code, one can train ReXNet-1.0x with the following command:
145 |
146 |     ./distributed_train.sh 4 /imagenet/ --model rexnetv1 --rex-width-mult 1.0 --opt sgd --amp \
147 |     --lr 0.5 --weight-decay 1e-5 \
148 |     --batch-size 128 --epochs 400 --sched cosine \
149 |     --remode pixel --reprob 0.2 --drop 0.2 --aa rand-m9-mstd0.5
150 |
151 | Drop path or MixUp may be needed when training bigger models.
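As a rough reading of the command above, `distributed_train.sh 4` launches four processes. In pytorch-image-models, `--batch-size` applies per process, so each optimizer step sees 4 x 128 = 512 images. With `--sched cosine`, the learning rate decays from `--lr 0.5` toward zero over the 400 epochs. The sketch below evaluates the textbook per-epoch cosine formula. It deliberately ignores the warm-up, minimum-LR, and per-step update options the training script may apply, so treat the values as approximate. `cosine_lr` is a hypothetical helper, not a timm function.

```python
import math


def cosine_lr(epoch, base_lr=0.5, total_epochs=400, min_lr=0.0):
    """Textbook cosine annealing without warm-up: base_lr at epoch 0, min_lr at the end."""
    progress = min(max(epoch / total_epochs, 0.0), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


num_procs, per_proc_batch = 4, 128  # from `distributed_train.sh 4` and `--batch-size 128`
print("global batch:", num_procs * per_proc_batch)  # -> global batch: 512
for epoch in (0, 100, 200, 300, 400):
    print(epoch, round(cosine_lr(epoch), 4))
```

Halfway through training (epoch 200), the learning rate is half the base value, 0.25.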
152 |
153 | ## License
154 |
155 | This project is distributed under the [MIT license](LICENSE).
156 |
157 |
158 | ## How to cite
159 |
160 | ```
161 | @misc{han2021rethinking,
162 | title={Rethinking Channel Dimensions for Efficient Model Design},
163 | author={Dongyoon Han and Sangdoo Yun and Byeongho Heo and YoungJoon Yoo},
164 | year={2021},
165 | eprint={2007.00992},
166 | archivePrefix={arXiv},
167 | primaryClass={cs.CV}
168 | }
169 | ```
170 |
--------------------------------------------------------------------------------
/rexnetv1.py:
--------------------------------------------------------------------------------
1 | """
2 | ReXNet
3 | Copyright (c) 2020-present NAVER Corp.
4 | MIT license
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 | from math import ceil
10 |
11 | # Memory-efficient Swish using torch.jit.script, borrowed from the code in (https://twitter.com/jeremyphoward/status/1188251041835315200)
12 | # The memory-efficient SiLU is used by default:
13 | USE_MEMORY_EFFICIENT_SiLU = True
14 |
15 | if USE_MEMORY_EFFICIENT_SiLU:
16 |     @torch.jit.script
17 |     def silu_fwd(x):
18 |         return x.mul(torch.sigmoid(x))
19 |
20 |
21 |     @torch.jit.script
22 |     def silu_bwd(x, grad_output):
23 |         x_sigmoid = torch.sigmoid(x)
24 |         return grad_output * (x_sigmoid * (1. + x * (1. - x_sigmoid)))
25 |
26 |
27 |     class SiLUJitImplementation(torch.autograd.Function):
28 |         @staticmethod
29 |         def forward(ctx, x):
30 |             ctx.save_for_backward(x)
31 |             return silu_fwd(x)
32 |
33 |         @staticmethod
34 |         def backward(ctx, grad_output):
35 |             x = ctx.saved_tensors[0]
36 |             return silu_bwd(x, grad_output)
37 |
38 |
39 |     def silu(x, inplace=False):
40 |         return SiLUJitImplementation.apply(x)
41 |
42 | else:
43 |     def silu(x, inplace=False):
44 |         return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
45 |
46 |
47 | class SiLU(nn.Module):
48 |     def __init__(self, inplace=True):
49 |         super(SiLU, self).__init__()
50 |         self.inplace = inplace
51 |
52 |     def forward(self, x):
53 |         return silu(x, self.inplace)
54 |
55 |
56 | def ConvBNAct(out, in_channels, channels, kernel=1, stride=1, pad=0,
57 |               num_group=1, active=True, relu6=False):
58 |     out.append(nn.Conv2d(in_channels, channels, kernel,
59 |                          stride, pad, groups=num_group, bias=False))
60 |     out.append(nn.BatchNorm2d(channels))
61 |     if active:
62 |         out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True))
63 |
64 |
65 | def ConvBNSiLU(out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1):
66 |     out.append(nn.Conv2d(in_channels, channels, kernel,
67 |                          stride, pad, groups=num_group, bias=False))
68 |     out.append(nn.BatchNorm2d(channels))
69 |     out.append(SiLU(inplace=True))
70 |
71 |
72 | class SE(nn.Module):
73 |     def __init__(self, in_channels, channels, se_ratio=12):
74 |         super(SE, self).__init__()
75 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
76 |         self.fc = nn.Sequential(
77 |             nn.Conv2d(in_channels, channels // se_ratio, kernel_size=1, padding=0),
78 |             nn.BatchNorm2d(channels // se_ratio),
79 |             nn.ReLU(inplace=True),
80 |             nn.Conv2d(channels // se_ratio, channels, kernel_size=1, padding=0),
81 |             nn.Sigmoid()
82 |         )
83 |
84 |     def forward(self, x):
85 |         y = self.avg_pool(x)
86 |         y = self.fc(y)
87 |         return x * y
88 |
89 |
90 | class LinearBottleneck(nn.Module):
91 |     def __init__(self, in_channels, channels, t, stride, use_se=True, se_ratio=12,
92 |                  **kwargs):
93 |         super(LinearBottleneck, self).__init__(**kwargs)
94 |         self.use_shortcut = stride == 1 and in_channels <= channels
95 |         self.in_channels = in_channels
96 |         self.out_channels = channels
97 |
98 |         out = []
99 |         if t != 1:
100 |             dw_channels = in_channels * t
101 |             ConvBNSiLU(out, in_channels=in_channels, channels=dw_channels)
102 |         else:
103 |             dw_channels = in_channels
104 |
105 |         ConvBNAct(out, in_channels=dw_channels, channels=dw_channels, kernel=3, stride=stride, pad=1,
106 |                   num_group=dw_channels, active=False)
107 |
108 |         if use_se:
109 |             out.append(SE(dw_channels, dw_channels, se_ratio))
110 |
111 |         out.append(nn.ReLU6())
112 |         ConvBNAct(out, in_channels=dw_channels, channels=channels, active=False, relu6=True)
113 |         self.out = nn.Sequential(*out)
114 |
115 |     def forward(self, x):
116 |         out = self.out(x)
117 |         if self.use_shortcut:
118 |             out[:, 0:self.in_channels] += x
119 |
120 |         return out
121 |
122 |
123 | class ReXNetV1(nn.Module):
124 |     def __init__(self, input_ch=16, final_ch=180, width_mult=1.0, depth_mult=1.0, classes=1000,
125 |                  use_se=True,
126 |                  se_ratio=12,
127 |                  dropout_ratio=0.2,
128 |                  bn_momentum=0.9):
129 |         super(ReXNetV1, self).__init__()
130 |
131 |         layers = [1, 2, 2, 3, 3, 5]
132 |         strides = [1, 2, 2, 2, 1, 2]
133 |         use_ses = [False, False, True, True, True, True]
134 |
135 |         layers = [ceil(element * depth_mult) for element in layers]
136 |         strides = sum([[element] + [1] * (layers[idx] - 1)
137 |                        for idx, element in enumerate(strides)], [])
138 |         if use_se:
139 |             use_ses = sum([[element] * layers[idx] for idx, element in enumerate(use_ses)], [])
140 |         else:
141 |             use_ses = [False] * sum(layers[:])
142 |         ts = [1] * layers[0] + [6] * sum(layers[1:])
143 |
144 |         self.depth = sum(layers[:]) * 3
145 |         stem_channel = 32 / width_mult if width_mult < 1.0 else 32
146 |         inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch
147 |
148 |         features = []
149 |         in_channels_group = []
150 |         channels_group = []
151 |
152 |         # The following channel configuration is a simple instance to make each layer become an expand layer.
153 |         for i in range(self.depth // 3):
154 |             if i == 0:
155 |                 in_channels_group.append(int(round(stem_channel * width_mult)))
156 |                 channels_group.append(int(round(inplanes * width_mult)))
157 |             else:
158 |                 in_channels_group.append(int(round(inplanes * width_mult)))
159 |                 inplanes += final_ch / (self.depth // 3 * 1.0)
160 |                 channels_group.append(int(round(inplanes * width_mult)))
161 |
162 |         ConvBNSiLU(features, 3, int(round(stem_channel * width_mult)), kernel=3, stride=2, pad=1)
163 |
164 |         for block_idx, (in_c, c, t, s, se) in enumerate(zip(in_channels_group, channels_group, ts, strides, use_ses)):
165 |             features.append(LinearBottleneck(in_channels=in_c,
166 |                                              channels=c,
167 |                                              t=t,
168 |                                              stride=s,
169 |                                              use_se=se, se_ratio=se_ratio))
170 |
171 |         pen_channels = int(1280 * width_mult)
172 |         ConvBNSiLU(features, c, pen_channels)
173 |
174 |         features.append(nn.AdaptiveAvgPool2d(1))
175 |         self.features = nn.Sequential(*features)
176 |         self.output = nn.Sequential(
177 |             nn.Dropout(dropout_ratio),
178 |             nn.Conv2d(pen_channels, classes, 1, bias=True))
179 |
180 |     def extract_features(self, x):
181 |         return self.features[:-1](x)
182 |
183 |     def forward(self, x):
184 |         x = self.features(x)
185 |         x = self.output(x).flatten(1)
186 |         return x
187 |
188 |
189 | if __name__ == '__main__':
190 |     model = ReXNetV1(width_mult=1.0)
191 |     out = model(torch.randn(2, 3, 224, 224))
192 |     loss = out.sum()
193 |     loss.backward()
194 |     print('Checked a single forward/backward iteration')
195 |
--------------------------------------------------------------------------------
/rexnetv1_lite.py:
--------------------------------------------------------------------------------
1 | """
2 | ReXNet_lite
3 | Copyright (c) 2021-present NAVER Corp.
4 | MIT license
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 | from math import ceil
10 |
11 |
12 | def _make_divisible(channel_size, divisor=None, min_value=None):
13 |     """
14 |     This function is taken from the original tf repo.
15 |     It ensures that all layers have a channel number that is divisible by `divisor` (8 by default in ReXNetV1_lite).
16 |     It can be seen here:
17 |     https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
18 |     """
19 |     if not divisor:
20 |         return channel_size
21 |
22 |     if min_value is None:
23 |         min_value = divisor
24 |     new_channel_size = max(min_value, int(channel_size + divisor / 2) // divisor * divisor)
25 |     # Make sure that round down does not go down by more than 10%.
26 |     if new_channel_size < 0.9 * channel_size:
27 |         new_channel_size += divisor
28 |     return new_channel_size
29 |
30 |
31 | def _add_conv(out, in_channels, channels, kernel=1, stride=1, pad=0,
32 |               num_group=1, active=True, relu6=True, bn_momentum=0.1, bn_eps=1e-5):
33 |     out.append(nn.Conv2d(in_channels, channels, kernel, stride, pad, groups=num_group, bias=False))
34 |     out.append(nn.BatchNorm2d(channels, momentum=bn_momentum, eps=bn_eps))
35 |     if active:
36 |         out.append(nn.ReLU6(inplace=True) if relu6 else nn.ReLU(inplace=True))
37 |
38 |
39 | class LinearBottleneck(nn.Module):
40 |     def __init__(self, in_channels, channels, t, kernel_size=3, stride=1,
41 |                  bn_momentum=0.1, bn_eps=1e-5,
42 |                  **kwargs):
43 |         super(LinearBottleneck, self).__init__(**kwargs)
44 |         self.conv_shortcut = None
45 |         self.use_shortcut = stride == 1 and in_channels <= channels
46 |         self.in_channels = in_channels
47 |         self.out_channels = channels
48 |         out = []
49 |         if t != 1:
50 |             dw_channels = in_channels * t
51 |             _add_conv(out, in_channels=in_channels, channels=dw_channels, bn_momentum=bn_momentum,
52 |                       bn_eps=bn_eps)
53 |         else:
54 |             dw_channels = in_channels
55 |
56 |         _add_conv(out, in_channels=dw_channels, channels=dw_channels * 1, kernel=kernel_size, stride=stride,
57 |                   pad=(kernel_size // 2),
58 |                   num_group=dw_channels, bn_momentum=bn_momentum, bn_eps=bn_eps)
59 |
60 |         _add_conv(out, in_channels=dw_channels, channels=channels, active=False, bn_momentum=bn_momentum,
61 |                   bn_eps=bn_eps)
62 |
63 |         self.out = nn.Sequential(*out)
64 |
65 |     def forward(self, x):
66 |         out = self.out(x)
67 |
68 |         if self.use_shortcut:
69 |             out[:, 0:self.in_channels] += x
70 |         return out
71 |
72 |
73 | class ReXNetV1_lite(nn.Module):
74 |     def __init__(self, fix_head_stem=False, divisible_value=8,
75 |                  input_ch=16, final_ch=164, multiplier=1.0, classes=1000,
76 |                  dropout_ratio=0.2,
77 |                  bn_momentum=0.1,
78 |                  bn_eps=1e-5, kernel_conf='333333'):
79 |         super(ReXNetV1_lite, self).__init__()
80 |
81 |         layers = [1, 2, 2, 3, 3, 5]
82 |         strides = [1, 2, 2, 2, 1, 2]
83 |         kernel_sizes = [int(element) for element in kernel_conf]
84 |
85 |         strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], [])
86 |         ts = [1] * layers[0] + [6] * sum(layers[1:])
87 |         kernel_sizes = sum([[element] * layers[idx] for idx, element in enumerate(kernel_sizes)], [])
88 |         self.num_convblocks = sum(layers[:])
89 |
90 |         features = []
91 |         inplanes = input_ch / multiplier if multiplier < 1.0 else input_ch
92 |         first_channel = 32 / multiplier if multiplier < 1.0 or fix_head_stem else 32
93 |         first_channel = _make_divisible(int(round(first_channel * multiplier)), divisible_value)
94 |
95 |         in_channels_group = []
96 |         channels_group = []
97 |
98 |         _add_conv(features, 3, first_channel, kernel=3, stride=2, pad=1,
99 |                   bn_momentum=bn_momentum, bn_eps=bn_eps)
100 |
101 |         for i in range(self.num_convblocks):
102 |             inplanes_divisible = _make_divisible(int(round(inplanes * multiplier)), divisible_value)
103 |             if i == 0:
104 |                 in_channels_group.append(first_channel)
105 |                 channels_group.append(inplanes_divisible)
106 |             else:
107 |                 in_channels_group.append(inplanes_divisible)
108 |                 inplanes += final_ch / (self.num_convblocks - 1 * 1.0)
109 |                 inplanes_divisible = _make_divisible(int(round(inplanes * multiplier)), divisible_value)
110 |                 channels_group.append(inplanes_divisible)
111 |
112 |         for block_idx, (in_c, c, t, k, s) in enumerate(
113 |                 zip(in_channels_group, channels_group, ts, kernel_sizes, strides)):
114 |             features.append(LinearBottleneck(in_channels=in_c,
115 |                                              channels=c,
116 |                                              t=t,
117 |                                              kernel_size=k,
118 |                                              stride=s,
119 |                                              bn_momentum=bn_momentum,
120 |                                              bn_eps=bn_eps))
121 |
122 |         pen_channels = int(1280 * multiplier) if multiplier > 1 and not fix_head_stem else 1280
123 |         _add_conv(features, c, pen_channels, bn_momentum=bn_momentum, bn_eps=bn_eps)
124 |
125 |         self.features = nn.Sequential(*features)
126 |         self.avgpool = nn.AdaptiveAvgPool2d(1)
127 |
128 |         self.output = nn.Sequential(
129 |             nn.Conv2d(pen_channels, 1024, 1, bias=True),
130 |             nn.BatchNorm2d(1024, momentum=bn_momentum, eps=bn_eps),
131 |             nn.ReLU6(inplace=True),
132 |             nn.Dropout(dropout_ratio),
133 |             nn.Conv2d(1024, classes, 1, bias=True))
134 |
135 |     def forward(self, x):
136 |         x = self.features(x)
137 |         x = self.avgpool(x)
138 |         x = self.output(x).flatten(1)
139 |         return x
140 |
141 | if __name__ == '__main__':
142 |     model = ReXNetV1_lite(multiplier=1.0)
143 |     out = model(torch.randn(2, 3, 224, 224))
144 |     loss = out.sum()
145 |     loss.backward()
146 |     print('Checked a single forward/backward iteration')
147 |
148 |
--------------------------------------------------------------------------------