├── .gitignore ├── LICENSE ├── README.md └── effnetv2.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Duo Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **[NEW!]** Check out our latest work [involution](https://github.com/d-li14/involution), accepted to CVPR'21, which introduces a new neural operator beyond convolution and self-attention. 2 | 3 | --- 4 | 5 | # PyTorch implementation of EfficientNet V2 6 | 7 | Reproduction of the EfficientNet V2 architecture as described in [EfficientNetV2: Smaller Models and Faster Training](https://arxiv.org/abs/2104.00298) by Mingxing Tan and Quoc V. Le, using the [PyTorch](https://pytorch.org) framework. 8 | 9 | ## Models 10 | 11 | | Architecture | # Parameters | FLOPs | Top-1 Acc. (%) | 12 | | ----------------- | ------------ | ------ | -------------------------- | 13 | | EfficientNetV2-S | 22.10M | 8.42G @ 384 | | 14 | | EfficientNetV2-M | 55.30M | 24.74G @ 480 | | 15 | | EfficientNetV2-L | 119.36M | 56.13G @ 480 | | 16 | | EfficientNetV2-XL | 208.96M | 93.41G @ 512 | | 17 | 18 | Stay tuned for ImageNet pre-trained weights.
19 | 20 | ## Acknowledgement 21 | 22 | The implementation is heavily borrowed from [HBONet](https://github.com/d-li14/HBONet) or [MobileNetV2](https://github.com/d-li14/mobilenetv2.pytorch), please kindly consider citing the following 23 | 24 | ``` 25 | @InProceedings{Li_2019_ICCV, 26 | author = {Li, Duo and Zhou, Aojun and Yao, Anbang}, 27 | title = {HBONet: Harmonious Bottleneck on Two Orthogonal Dimensions}, 28 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 29 | month = {Oct}, 30 | year = {2019} 31 | } 32 | ``` 33 | ``` 34 | @InProceedings{Sandler_2018_CVPR, 35 | author = {Sandler, Mark and Howard, Andrew and Zhu, Menglong and Zhmoginov, Andrey and Chen, Liang-Chieh}, 36 | title = {MobileNetV2: Inverted Residuals and Linear Bottlenecks}, 37 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 38 | month = {June}, 39 | year = {2018} 40 | } 41 | ``` 42 | 43 | The official [TensorFlow implementation](https://github.com/google/automl/tree/master/efficientnetv2) by [@mingxingtan](https://github.com/mingxingtan). 44 | -------------------------------------------------------------------------------- /effnetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates a EfficientNetV2 Model as defined in: 3 | Mingxing Tan, Quoc V. Le. (2021). 4 | EfficientNetV2: Smaller Models and Faster Training 5 | arXiv preprint arXiv:2104.00298. 6 | import from https://github.com/d-li14/mobilenetv2.pytorch 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import math 12 | 13 | __all__ = ['effnetv2_s', 'effnetv2_m', 'effnetv2_l', 'effnetv2_xl'] 14 | 15 | 16 | def _make_divisible(v, divisor, min_value=None): 17 | """ 18 | This function is taken from the original tf repo. 
19 | It ensures that all layers have a channel number that is divisible by 8 20 | It can be seen here: 21 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 22 | :param v: 23 | :param divisor: 24 | :param min_value: 25 | :return: 26 | """ 27 | if min_value is None: 28 | min_value = divisor 29 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 30 | # Make sure that round down does not go down by more than 10%. 31 | if new_v < 0.9 * v: 32 | new_v += divisor 33 | return new_v 34 | 35 | 36 | # SiLU (Swish) activation function 37 | if hasattr(nn, 'SiLU'): 38 | SiLU = nn.SiLU 39 | else: 40 | # For compatibility with old PyTorch versions 41 | class SiLU(nn.Module): 42 | def forward(self, x): 43 | return x * torch.sigmoid(x) 44 | 45 | 46 | class SELayer(nn.Module): 47 | def __init__(self, inp, oup, reduction=4): 48 | super(SELayer, self).__init__() 49 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 50 | self.fc = nn.Sequential( 51 | nn.Linear(oup, _make_divisible(inp // reduction, 8)), 52 | SiLU(), 53 | nn.Linear(_make_divisible(inp // reduction, 8), oup), 54 | nn.Sigmoid() 55 | ) 56 | 57 | def forward(self, x): 58 | b, c, _, _ = x.size() 59 | y = self.avg_pool(x).view(b, c) 60 | y = self.fc(y).view(b, c, 1, 1) 61 | return x * y 62 | 63 | 64 | def conv_3x3_bn(inp, oup, stride): 65 | return nn.Sequential( 66 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 67 | nn.BatchNorm2d(oup), 68 | SiLU() 69 | ) 70 | 71 | 72 | def conv_1x1_bn(inp, oup): 73 | return nn.Sequential( 74 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 75 | nn.BatchNorm2d(oup), 76 | SiLU() 77 | ) 78 | 79 | 80 | class MBConv(nn.Module): 81 | def __init__(self, inp, oup, stride, expand_ratio, use_se): 82 | super(MBConv, self).__init__() 83 | assert stride in [1, 2] 84 | 85 | hidden_dim = round(inp * expand_ratio) 86 | self.identity = stride == 1 and inp == oup 87 | if use_se: 88 | self.conv = nn.Sequential( 89 | # pw 90 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 
bias=False), 91 | nn.BatchNorm2d(hidden_dim), 92 | SiLU(), 93 | # dw 94 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 95 | nn.BatchNorm2d(hidden_dim), 96 | SiLU(), 97 | SELayer(inp, hidden_dim), 98 | # pw-linear 99 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 100 | nn.BatchNorm2d(oup), 101 | ) 102 | else: 103 | self.conv = nn.Sequential( 104 | # fused 105 | nn.Conv2d(inp, hidden_dim, 3, stride, 1, bias=False), 106 | nn.BatchNorm2d(hidden_dim), 107 | SiLU(), 108 | # pw-linear 109 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 110 | nn.BatchNorm2d(oup), 111 | ) 112 | 113 | 114 | def forward(self, x): 115 | if self.identity: 116 | return x + self.conv(x) 117 | else: 118 | return self.conv(x) 119 | 120 | 121 | class EffNetV2(nn.Module): 122 | def __init__(self, cfgs, num_classes=1000, width_mult=1.): 123 | super(EffNetV2, self).__init__() 124 | self.cfgs = cfgs 125 | 126 | # building first layer 127 | input_channel = _make_divisible(24 * width_mult, 8) 128 | layers = [conv_3x3_bn(3, input_channel, 2)] 129 | # building inverted residual blocks 130 | block = MBConv 131 | for t, c, n, s, use_se in self.cfgs: 132 | output_channel = _make_divisible(c * width_mult, 8) 133 | for i in range(n): 134 | layers.append(block(input_channel, output_channel, s if i == 0 else 1, t, use_se)) 135 | input_channel = output_channel 136 | self.features = nn.Sequential(*layers) 137 | # building last several layers 138 | output_channel = _make_divisible(1792 * width_mult, 8) if width_mult > 1.0 else 1792 139 | self.conv = conv_1x1_bn(input_channel, output_channel) 140 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 141 | self.classifier = nn.Linear(output_channel, num_classes) 142 | 143 | self._initialize_weights() 144 | 145 | def forward(self, x): 146 | x = self.features(x) 147 | x = self.conv(x) 148 | x = self.avgpool(x) 149 | x = x.view(x.size(0), -1) 150 | x = self.classifier(x) 151 | return x 152 | 153 | def _initialize_weights(self): 154 | for 
m in self.modules(): 155 | if isinstance(m, nn.Conv2d): 156 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 157 | m.weight.data.normal_(0, math.sqrt(2. / n)) 158 | if m.bias is not None: 159 | m.bias.data.zero_() 160 | elif isinstance(m, nn.BatchNorm2d): 161 | m.weight.data.fill_(1) 162 | m.bias.data.zero_() 163 | elif isinstance(m, nn.Linear): 164 | m.weight.data.normal_(0, 0.001) 165 | m.bias.data.zero_() 166 | 167 | 168 | def effnetv2_s(**kwargs): 169 | """ 170 | Constructs a EfficientNetV2-S model 171 | """ 172 | cfgs = [ 173 | # t, c, n, s, SE 174 | [1, 24, 2, 1, 0], 175 | [4, 48, 4, 2, 0], 176 | [4, 64, 4, 2, 0], 177 | [4, 128, 6, 2, 1], 178 | [6, 160, 9, 1, 1], 179 | [6, 256, 15, 2, 1], 180 | ] 181 | return EffNetV2(cfgs, **kwargs) 182 | 183 | 184 | def effnetv2_m(**kwargs): 185 | """ 186 | Constructs a EfficientNetV2-M model 187 | """ 188 | cfgs = [ 189 | # t, c, n, s, SE 190 | [1, 24, 3, 1, 0], 191 | [4, 48, 5, 2, 0], 192 | [4, 80, 5, 2, 0], 193 | [4, 160, 7, 2, 1], 194 | [6, 176, 14, 1, 1], 195 | [6, 304, 18, 2, 1], 196 | [6, 512, 5, 1, 1], 197 | ] 198 | return EffNetV2(cfgs, **kwargs) 199 | 200 | 201 | def effnetv2_l(**kwargs): 202 | """ 203 | Constructs a EfficientNetV2-L model 204 | """ 205 | cfgs = [ 206 | # t, c, n, s, SE 207 | [1, 32, 4, 1, 0], 208 | [4, 64, 7, 2, 0], 209 | [4, 96, 7, 2, 0], 210 | [4, 192, 10, 2, 1], 211 | [6, 224, 19, 1, 1], 212 | [6, 384, 25, 2, 1], 213 | [6, 640, 7, 1, 1], 214 | ] 215 | return EffNetV2(cfgs, **kwargs) 216 | 217 | 218 | def effnetv2_xl(**kwargs): 219 | """ 220 | Constructs a EfficientNetV2-XL model 221 | """ 222 | cfgs = [ 223 | # t, c, n, s, SE 224 | [1, 32, 4, 1, 0], 225 | [4, 64, 8, 2, 0], 226 | [4, 96, 8, 2, 0], 227 | [4, 192, 16, 2, 1], 228 | [6, 256, 24, 1, 1], 229 | [6, 512, 32, 2, 1], 230 | [6, 640, 8, 1, 1], 231 | ] 232 | return EffNetV2(cfgs, **kwargs) 233 | --------------------------------------------------------------------------------