├── .gitignore ├── BENCHMARK.md ├── LICENSE ├── README.md ├── caffe2_benchmark.py ├── caffe2_validate.py ├── data ├── __init__.py ├── dataset.py ├── loader.py ├── tf_preprocessing.py └── transforms.py ├── geffnet ├── __init__.py ├── activations │ ├── __init__.py │ ├── activations.py │ ├── activations_jit.py │ └── activations_me.py ├── config.py ├── conv2d_layers.py ├── efficientnet_builder.py ├── gen_efficientnet.py ├── helpers.py ├── mobilenetv3.py ├── model_factory.py └── version.py ├── hubconf.py ├── onnx_export.py ├── onnx_optimize.py ├── onnx_to_caffe.py ├── onnx_validate.py ├── requirements.txt ├── setup.py ├── utils.py └── validate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # pytorch stuff 104 | *.pth 105 | *.onnx 106 | *.pb 107 | 108 | trained_models/ 109 | .fuse_hidden* 110 | -------------------------------------------------------------------------------- /BENCHMARK.md: -------------------------------------------------------------------------------- 1 | # Model Performance Benchmarks 2 | 3 | All benchmarks run as per: 4 | 5 | ``` 6 | python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx 7 | python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx 8 | python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3 9 | python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt 10 | python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb 11 | python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb 12 | ``` 13 | 14 | ## EfficientNet-B0 15 | 16 | ### Unoptimized 17 | ``` 18 | Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897 19 | Time per operator type: 20 | 29.7378 ms. 60.5145%. Conv 21 | 12.1785 ms. 24.7824%. 
Sigmoid 22 | 3.62811 ms. 7.38297%. SpatialBN 23 | 2.98444 ms. 6.07314%. Mul 24 | 0.326902 ms. 0.665225%. AveragePool 25 | 0.197317 ms. 0.401528%. FC 26 | 0.0852877 ms. 0.173555%. Add 27 | 0.0032607 ms. 0.00663532%. Squeeze 28 | 49.1416 ms in Total 29 | FLOP per operator type: 30 | 0.76907 GFLOP. 95.2696%. Conv 31 | 0.0269508 GFLOP. 3.33857%. SpatialBN 32 | 0.00846444 GFLOP. 1.04855%. Mul 33 | 0.002561 GFLOP. 0.317248%. FC 34 | 0.000210112 GFLOP. 0.0260279%. Add 35 | 0.807256 GFLOP in Total 36 | Feature Memory Read per operator type: 37 | 58.5253 MB. 43.0891%. Mul 38 | 43.2015 MB. 31.807%. Conv 39 | 27.2869 MB. 20.0899%. SpatialBN 40 | 5.12912 MB. 3.77631%. FC 41 | 1.6809 MB. 1.23756%. Add 42 | 135.824 MB in Total 43 | Feature Memory Written per operator type: 44 | 33.8578 MB. 38.1965%. Mul 45 | 26.9881 MB. 30.4465%. Conv 46 | 26.9508 MB. 30.4044%. SpatialBN 47 | 0.840448 MB. 0.948147%. Add 48 | 0.004 MB. 0.00451258%. FC 49 | 88.6412 MB in Total 50 | Parameter Memory per operator type: 51 | 15.8248 MB. 74.9391%. Conv 52 | 5.124 MB. 24.265%. FC 53 | 0.168064 MB. 0.795877%. SpatialBN 54 | 0 MB. 0%. Add 55 | 0 MB. 0%. Mul 56 | 21.1168 MB in Total 57 | ``` 58 | ### Optimized 59 | ``` 60 | Main run finished. Milliseconds per iter: 46.0838. Iters per second: 21.6996 61 | Time per operator type: 62 | 29.776 ms. 65.002%. Conv 63 | 12.2803 ms. 26.8084%. Sigmoid 64 | 3.15073 ms. 6.87815%. Mul 65 | 0.328651 ms. 0.717456%. AveragePool 66 | 0.186237 ms. 0.406563%. FC 67 | 0.0832429 ms. 0.181722%. Add 68 | 0.0026184 ms. 0.00571606%. Squeeze 69 | 45.8078 ms in Total 70 | FLOP per operator type: 71 | 0.76907 GFLOP. 98.5601%. Conv 72 | 0.00846444 GFLOP. 1.08476%. Mul 73 | 0.002561 GFLOP. 0.328205%. FC 74 | 0.000210112 GFLOP. 0.0269269%. Add 75 | 0.780305 GFLOP in Total 76 | Feature Memory Read per operator type: 77 | 58.5253 MB. 53.8803%. Mul 78 | 43.2855 MB. 39.8501%. Conv 79 | 5.12912 MB. 4.72204%. FC 80 | 1.6809 MB. 1.54749%. Add 81 | 108.621 MB in Total 82 | Feature Memory Written per operator type: 83 | 33.8578 MB. 54.8834%. Mul 84 | 26.9881 MB. 43.7477%. Conv 85 | 0.840448 MB. 1.36237%. Add 86 | 0.004 MB. 0.00648399%. FC 87 | 61.6904 MB in Total 88 | Parameter Memory per operator type: 89 | 15.8248 MB. 75.5403%. Conv 90 | 5.124 MB. 24.4597%. FC 91 | 0 MB. 0%. Add 92 | 0 MB. 0%. Mul 93 | 20.9488 MB in Total 94 | ``` 95 | 96 | ## EfficientNet-B1 97 | ### Optimized 98 | ``` 99 | Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256 100 | Time per operator type: 101 | 45.7915 ms. 66.3206%. Conv 102 | 17.8718 ms. 25.8841%. Sigmoid 103 | 4.44132 ms. 6.43244%. Mul 104 | 0.51001 ms. 0.738658%. AveragePool 105 | 0.233283 ms. 0.337868%. Add 106 | 0.194986 ms. 0.282402%. FC 107 | 0.00268255 ms. 0.00388519%. Squeeze 108 | 69.0456 ms in Total 109 | FLOP per operator type: 110 | 1.37105 GFLOP. 98.7673%. Conv 111 | 0.0138759 GFLOP. 0.99959%. Mul 112 | 0.002561 GFLOP. 0.184489%. FC 113 | 0.000674432 GFLOP. 0.0485847%. Add 114 | 1.38816 GFLOP in Total 115 | Feature Memory Read per operator type: 116 | 94.624 MB. 54.0789%. Mul 117 | 69.8255 MB. 39.9062%. Conv 118 | 5.39546 MB. 3.08357%. Add 119 | 5.12912 MB. 2.93136%. FC 120 | 174.974 MB in Total 121 | Feature Memory Written per operator type: 122 | 55.5035 MB. 54.555%. Mul 123 | 43.5333 MB. 42.7894%. Conv 124 | 2.69773 MB. 2.65163%. Add 125 | 0.004 MB. 0.00393165%. FC 126 | 101.739 MB in Total 127 | Parameter Memory per operator type: 128 | 25.7479 MB. 83.4024%. Conv 129 | 5.124 MB. 16.5976%. FC 130 | 0 MB. 0%. Add 131 | 0 MB. 0%. 
Mul 132 | 30.8719 MB in Total 133 | ``` 134 | 135 | ## EfficientNet-B2 136 | ### Optimized 137 | ``` 138 | Main run finished. Milliseconds per iter: 92.28. Iters per second: 10.8366 139 | Time per operator type: 140 | 61.4627 ms. 67.5845%. Conv 141 | 22.7458 ms. 25.0113%. Sigmoid 142 | 5.59931 ms. 6.15701%. Mul 143 | 0.642567 ms. 0.706568%. AveragePool 144 | 0.272795 ms. 0.299965%. Add 145 | 0.216178 ms. 0.237709%. FC 146 | 0.00268895 ms. 0.00295677%. Squeeze 147 | 90.942 ms in Total 148 | FLOP per operator type: 149 | 1.98431 GFLOP. 98.9343%. Conv 150 | 0.0177039 GFLOP. 0.882686%. Mul 151 | 0.002817 GFLOP. 0.140451%. FC 152 | 0.000853984 GFLOP. 0.0425782%. Add 153 | 2.00568 GFLOP in Total 154 | Feature Memory Read per operator type: 155 | 120.609 MB. 54.9637%. Mul 156 | 86.3512 MB. 39.3519%. Conv 157 | 6.83187 MB. 3.11341%. Add 158 | 5.64163 MB. 2.571%. FC 159 | 219.433 MB in Total 160 | Feature Memory Written per operator type: 161 | 70.8155 MB. 54.6573%. Mul 162 | 55.3273 MB. 42.7031%. Conv 163 | 3.41594 MB. 2.63651%. Add 164 | 0.004 MB. 0.00308731%. FC 165 | 129.563 MB in Total 166 | Parameter Memory per operator type: 167 | 30.4721 MB. 84.3913%. Conv 168 | 5.636 MB. 15.6087%. FC 169 | 0 MB. 0%. Add 170 | 0 MB. 0%. Mul 171 | 36.1081 MB in Total 172 | ``` 173 | 174 | ## MixNet-M 175 | ### Optimized 176 | ``` 177 | Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448 178 | Time per operator type: 179 | 48.1139 ms. 75.2052%. Conv 180 | 7.1341 ms. 11.1511%. Sigmoid 181 | 2.63706 ms. 4.12189%. SpatialBN 182 | 1.73186 ms. 2.70701%. Mul 183 | 1.38707 ms. 2.16809%. Split 184 | 1.29322 ms. 2.02139%. Concat 185 | 1.00093 ms. 1.56452%. Relu 186 | 0.235309 ms. 0.367803%. Add 187 | 0.221579 ms. 0.346343%. FC 188 | 0.219315 ms. 0.342803%. AveragePool 189 | 0.00250145 ms. 0.00390993%. Squeeze 190 | 63.9768 ms in Total 191 | FLOP per operator type: 192 | 0.675273 GFLOP. 95.5827%. Conv 193 | 0.0221072 GFLOP. 3.12921%. SpatialBN 194 | 0.00538445 GFLOP. 0.762152%. Mul 195 | 0.003073 GFLOP. 0.434973%. FC 196 | 0.000642488 GFLOP. 0.0909421%. Add 197 | 0 GFLOP. 0%. Concat 198 | 0 GFLOP. 0%. Relu 199 | 0.70648 GFLOP in Total 200 | Feature Memory Read per operator type: 201 | 46.8424 MB. 30.502%. Conv 202 | 36.8626 MB. 24.0036%. Mul 203 | 22.3152 MB. 14.5309%. SpatialBN 204 | 22.1074 MB. 14.3955%. Concat 205 | 14.1496 MB. 9.21372%. Relu 206 | 6.15414 MB. 4.00735%. FC 207 | 5.1399 MB. 3.34692%. Add 208 | 153.571 MB in Total 209 | Feature Memory Written per operator type: 210 | 32.7672 MB. 28.4331%. Conv 211 | 22.1072 MB. 19.1831%. Concat 212 | 22.1072 MB. 19.1831%. SpatialBN 213 | 21.5378 MB. 18.689%. Mul 214 | 14.1496 MB. 12.2781%. Relu 215 | 2.56995 MB. 2.23003%. Add 216 | 0.004 MB. 0.00347092%. FC 217 | 115.243 MB in Total 218 | Parameter Memory per operator type: 219 | 13.7059 MB. 68.674%. Conv 220 | 6.148 MB. 30.8049%. FC 221 | 0.104 MB. 0.521097%. SpatialBN 222 | 0 MB. 0%. Add 223 | 0 MB. 0%. Concat 224 | 0 MB. 0%. Mul 225 | 0 MB. 0%. Relu 226 | 19.9579 MB in Total 227 | ``` 228 | 229 | ## TF MobileNet-V3 Large 1.0 230 | 231 | ### Optimized 232 | ``` 233 | Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525 234 | Time per operator type: 235 | 17.437 ms. 80.0087%. Conv 236 | 1.27662 ms. 5.8577%. Add 237 | 1.12759 ms. 5.17387%. Div 238 | 0.701155 ms. 3.21721%. Mul 239 | 0.562654 ms. 2.58171%. Relu 240 | 0.431144 ms. 1.97828%. Clip 241 | 0.156902 ms. 0.719936%. FC 242 | 0.0996858 ms. 0.457402%. AveragePool 243 | 0.00112455 ms. 0.00515993%. 
Flatten 244 | 21.7939 ms in Total 245 | FLOP per operator type: 246 | 0.43062 GFLOP. 98.1484%. Conv 247 | 0.002561 GFLOP. 0.583713%. FC 248 | 0.00210867 GFLOP. 0.480616%. Mul 249 | 0.00193868 GFLOP. 0.441871%. Add 250 | 0.00151532 GFLOP. 0.345377%. Div 251 | 0 GFLOP. 0%. Relu 252 | 0.438743 GFLOP in Total 253 | Feature Memory Read per operator type: 254 | 34.7967 MB. 43.9391%. Conv 255 | 14.496 MB. 18.3046%. Mul 256 | 9.44828 MB. 11.9307%. Add 257 | 9.26157 MB. 11.6949%. Relu 258 | 6.0614 MB. 7.65395%. Div 259 | 5.12912 MB. 6.47673%. FC 260 | 79.193 MB in Total 261 | Feature Memory Written per operator type: 262 | 17.6247 MB. 35.8656%. Conv 263 | 9.26157 MB. 18.847%. Relu 264 | 8.43469 MB. 17.1643%. Mul 265 | 7.75472 MB. 15.7806%. Add 266 | 6.06128 MB. 12.3345%. Div 267 | 0.004 MB. 0.00813985%. FC 268 | 49.1409 MB in Total 269 | Parameter Memory per operator type: 270 | 16.6851 MB. 76.5052%. Conv 271 | 5.124 MB. 23.4948%. FC 272 | 0 MB. 0%. Add 273 | 0 MB. 0%. Div 274 | 0 MB. 0%. Mul 275 | 0 MB. 0%. Relu 276 | 21.8091 MB in Total 277 | ``` 278 | 279 | ## MobileNet-V3 (RW) 280 | 281 | ### Unoptimized 282 | ``` 283 | Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712 284 | Time per operator type: 285 | 15.9266 ms. 69.2624%. Conv 286 | 2.36551 ms. 10.2873%. SpatialBN 287 | 1.39102 ms. 6.04936%. Add 288 | 1.30327 ms. 5.66773%. Div 289 | 0.737014 ms. 3.20517%. Mul 290 | 0.639697 ms. 2.78195%. Relu 291 | 0.375681 ms. 1.63378%. Clip 292 | 0.153126 ms. 0.665921%. FC 293 | 0.0993787 ms. 0.432184%. AveragePool 294 | 0.0032632 ms. 0.0141912%. Squeeze 295 | 22.9946 ms in Total 296 | FLOP per operator type: 297 | 0.430616 GFLOP. 94.4041%. Conv 298 | 0.0175992 GFLOP. 3.85829%. SpatialBN 299 | 0.002561 GFLOP. 0.561449%. FC 300 | 0.00210961 GFLOP. 0.46249%. Mul 301 | 0.00173891 GFLOP. 0.381223%. Add 302 | 0.00151626 GFLOP. 0.33241%. Div 303 | 0 GFLOP. 0%. Relu 304 | 0.456141 GFLOP in Total 305 | Feature Memory Read per operator type: 306 | 34.7354 MB. 36.4363%. Conv 307 | 17.7944 MB. 18.6658%. SpatialBN 308 | 14.5035 MB. 15.2137%. Mul 309 | 9.25778 MB. 9.71113%. Relu 310 | 7.84641 MB. 8.23064%. Add 311 | 6.06516 MB. 6.36216%. Div 312 | 5.12912 MB. 5.38029%. FC 313 | 95.3317 MB in Total 314 | Feature Memory Written per operator type: 315 | 17.6246 MB. 26.7264%. Conv 316 | 17.5992 MB. 26.6878%. SpatialBN 317 | 9.25778 MB. 14.0387%. Relu 318 | 8.43843 MB. 12.7962%. Mul 319 | 6.95565 MB. 10.5477%. Add 320 | 6.06502 MB. 9.19713%. Div 321 | 0.004 MB. 0.00606568%. FC 322 | 65.9447 MB in Total 323 | Parameter Memory per operator type: 324 | 16.6778 MB. 76.1564%. Conv 325 | 5.124 MB. 23.3979%. FC 326 | 0.0976 MB. 0.445674%. SpatialBN 327 | 0 MB. 0%. Add 328 | 0 MB. 0%. Div 329 | 0 MB. 0%. Mul 330 | 0 MB. 0%. Relu 331 | 21.8994 MB in Total 332 | 333 | ``` 334 | ### Optimized 335 | 336 | ``` 337 | Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527 338 | Time per operator type: 339 | 17.146 ms. 78.8965%. Conv 340 | 1.38453 ms. 6.37084%. Add 341 | 1.30991 ms. 6.02749%. Div 342 | 0.685417 ms. 3.15391%. Mul 343 | 0.532589 ms. 2.45068%. Relu 344 | 0.418263 ms. 1.92461%. Clip 345 | 0.15128 ms. 0.696106%. FC 346 | 0.102065 ms. 0.469648%. AveragePool 347 | 0.0022143 ms. 0.010189%. Squeeze 348 | 21.7323 ms in Total 349 | FLOP per operator type: 350 | 0.430616 GFLOP. 98.1927%. Conv 351 | 0.002561 GFLOP. 0.583981%. FC 352 | 0.00210961 GFLOP. 0.481051%. Mul 353 | 0.00173891 GFLOP. 0.396522%. Add 354 | 0.00151626 GFLOP. 0.34575%. Div 355 | 0 GFLOP. 0%. 
Relu 356 | 0.438542 GFLOP in Total 357 | Feature Memory Read per operator type: 358 | 34.7842 MB. 44.833%. Conv 359 | 14.5035 MB. 18.6934%. Mul 360 | 9.25778 MB. 11.9323%. Relu 361 | 7.84641 MB. 10.1132%. Add 362 | 6.06516 MB. 7.81733%. Div 363 | 5.12912 MB. 6.61087%. FC 364 | 77.5861 MB in Total 365 | Feature Memory Written per operator type: 366 | 17.6246 MB. 36.4556%. Conv 367 | 9.25778 MB. 19.1492%. Relu 368 | 8.43843 MB. 17.4544%. Mul 369 | 6.95565 MB. 14.3874%. Add 370 | 6.06502 MB. 12.5452%. Div 371 | 0.004 MB. 0.00827378%. FC 372 | 48.3455 MB in Total 373 | Parameter Memory per operator type: 374 | 16.6778 MB. 76.4973%. Conv 375 | 5.124 MB. 23.5027%. FC 376 | 0 MB. 0%. Add 377 | 0 MB. 0%. Div 378 | 0 MB. 0%. Mul 379 | 0 MB. 0%. Relu 380 | 21.8018 MB in Total 381 | 382 | ``` 383 | 384 | ## MnasNet-A1 385 | 386 | ### Unoptimized 387 | ``` 388 | Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345 389 | Time per operator type: 390 | 24.4656 ms. 79.0905%. Conv 391 | 4.14958 ms. 13.4144%. SpatialBN 392 | 1.60598 ms. 5.19169%. Relu 393 | 0.295219 ms. 0.95436%. Mul 394 | 0.187609 ms. 0.606486%. FC 395 | 0.120556 ms. 0.389724%. AveragePool 396 | 0.09036 ms. 0.292109%. Add 397 | 0.015727 ms. 0.050841%. Sigmoid 398 | 0.00306205 ms. 0.00989875%. Squeeze 399 | 30.9337 ms in Total 400 | FLOP per operator type: 401 | 0.620598 GFLOP. 95.6434%. Conv 402 | 0.0248873 GFLOP. 3.8355%. SpatialBN 403 | 0.002561 GFLOP. 0.394688%. FC 404 | 0.000597408 GFLOP. 0.0920695%. Mul 405 | 0.000222656 GFLOP. 0.0343146%. Add 406 | 0 GFLOP. 0%. Relu 407 | 0.648867 GFLOP in Total 408 | Feature Memory Read per operator type: 409 | 35.5457 MB. 38.4109%. Conv 410 | 25.1552 MB. 27.1829%. SpatialBN 411 | 22.5235 MB. 24.339%. Relu 412 | 5.12912 MB. 5.54256%. FC 413 | 2.40586 MB. 2.59978%. Mul 414 | 1.78125 MB. 1.92483%. Add 415 | 92.5406 MB in Total 416 | Feature Memory Written per operator type: 417 | 24.9042 MB. 32.9424%. Conv 418 | 24.8873 MB. 32.92%. SpatialBN 419 | 22.5235 MB. 29.7932%. Relu 420 | 2.38963 MB. 3.16092%. Mul 421 | 0.890624 MB. 1.17809%. Add 422 | 0.004 MB. 0.00529106%. FC 423 | 75.5993 MB in Total 424 | Parameter Memory per operator type: 425 | 10.2732 MB. 66.1459%. Conv 426 | 5.124 MB. 32.9917%. FC 427 | 0.133952 MB. 0.86247%. SpatialBN 428 | 0 MB. 0%. Add 429 | 0 MB. 0%. Mul 430 | 0 MB. 0%. Relu 431 | 15.5312 MB in Total 432 | ``` 433 | 434 | ### Optimized 435 | ``` 436 | Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597 437 | Time per operator type: 438 | 22.0547 ms. 91.1375%. Conv 439 | 1.49096 ms. 6.16116%. Relu 440 | 0.253417 ms. 1.0472%. Mul 441 | 0.18506 ms. 0.76473%. FC 442 | 0.112942 ms. 0.466717%. AveragePool 443 | 0.086769 ms. 0.358559%. Add 444 | 0.0127889 ms. 0.0528479%. Sigmoid 445 | 0.0027346 ms. 0.0113003%. Squeeze 446 | 24.1994 ms in Total 447 | FLOP per operator type: 448 | 0.620598 GFLOP. 99.4581%. Conv 449 | 0.002561 GFLOP. 0.41043%. FC 450 | 0.000597408 GFLOP. 0.0957417%. Mul 451 | 0.000222656 GFLOP. 0.0356832%. Add 452 | 0 GFLOP. 0%. Relu 453 | 0.623979 GFLOP in Total 454 | Feature Memory Read per operator type: 455 | 35.6127 MB. 52.7968%. Conv 456 | 22.5235 MB. 33.3917%. Relu 457 | 5.12912 MB. 7.60406%. FC 458 | 2.40586 MB. 3.56675%. Mul 459 | 1.78125 MB. 2.64075%. Add 460 | 67.4524 MB in Total 461 | Feature Memory Written per operator type: 462 | 24.9042 MB. 49.1092%. Conv 463 | 22.5235 MB. 44.4145%. Relu 464 | 2.38963 MB. 4.71216%. Mul 465 | 0.890624 MB. 1.75624%. Add 466 | 0.004 MB. 0.00788768%. 
FC 467 | 50.712 MB in Total 468 | Parameter Memory per operator type: 469 | 10.2732 MB. 66.7213%. Conv 470 | 5.124 MB. 33.2787%. FC 471 | 0 MB. 0%. Add 472 | 0 MB. 0%. Mul 473 | 0 MB. 0%. Relu 474 | 15.3972 MB in Total 475 | ``` 476 | ## MnasNet-B1 477 | 478 | ### Unoptimized 479 | ``` 480 | Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322 481 | Time per operator type: 482 | 29.1121 ms. 83.3081%. Conv 483 | 4.14959 ms. 11.8746%. SpatialBN 484 | 1.35823 ms. 3.88675%. Relu 485 | 0.186188 ms. 0.532802%. FC 486 | 0.116244 ms. 0.332647%. Add 487 | 0.018641 ms. 0.0533437%. AveragePool 488 | 0.0040904 ms. 0.0117052%. Squeeze 489 | 34.9451 ms in Total 490 | FLOP per operator type: 491 | 0.626272 GFLOP. 96.2088%. Conv 492 | 0.0218266 GFLOP. 3.35303%. SpatialBN 493 | 0.002561 GFLOP. 0.393424%. FC 494 | 0.000291648 GFLOP. 0.0448034%. Add 495 | 0 GFLOP. 0%. Relu 496 | 0.650951 GFLOP in Total 497 | Feature Memory Read per operator type: 498 | 34.4354 MB. 41.3788%. Conv 499 | 22.1299 MB. 26.5921%. SpatialBN 500 | 19.1923 MB. 23.0622%. Relu 501 | 5.12912 MB. 6.16333%. FC 502 | 2.33318 MB. 2.80364%. Add 503 | 83.2199 MB in Total 504 | Feature Memory Written per operator type: 505 | 21.8266 MB. 34.0955%. Conv 506 | 21.8266 MB. 34.0955%. SpatialBN 507 | 19.1923 MB. 29.9805%. Relu 508 | 1.16659 MB. 1.82234%. Add 509 | 0.004 MB. 0.00624844%. FC 510 | 64.016 MB in Total 511 | Parameter Memory per operator type: 512 | 12.2576 MB. 69.9104%. Conv 513 | 5.124 MB. 29.2245%. FC 514 | 0.15168 MB. 0.865099%. SpatialBN 515 | 0 MB. 0%. Add 516 | 0 MB. 0%. Relu 517 | 17.5332 MB in Total 518 | ``` 519 | 520 | ### Optimized 521 | ``` 522 | Main run finished. Milliseconds per iter: 26.6364. Iters per second: 37.5426 523 | Time per operator type: 524 | 24.9888 ms. 94.0962%. Conv 525 | 1.26147 ms. 4.75011%. Relu 526 | 0.176234 ms. 0.663619%. FC 527 | 0.113309 ms. 0.426672%. Add 528 | 0.0138708 ms. 0.0522311%. AveragePool 529 | 0.00295685 ms. 0.0111341%. Squeeze 530 | 26.5566 ms in Total 531 | FLOP per operator type: 532 | 0.626272 GFLOP. 99.5466%. Conv 533 | 0.002561 GFLOP. 0.407074%. FC 534 | 0.000291648 GFLOP. 0.0463578%. Add 535 | 0 GFLOP. 0%. Relu 536 | 0.629124 GFLOP in Total 537 | Feature Memory Read per operator type: 538 | 34.5112 MB. 56.4224%. Conv 539 | 19.1923 MB. 31.3775%. Relu 540 | 5.12912 MB. 8.3856%. FC 541 | 2.33318 MB. 3.81452%. Add 542 | 61.1658 MB in Total 543 | Feature Memory Written per operator type: 544 | 21.8266 MB. 51.7346%. Conv 545 | 19.1923 MB. 45.4908%. Relu 546 | 1.16659 MB. 2.76513%. Add 547 | 0.004 MB. 0.00948104%. FC 548 | 42.1895 MB in Total 549 | Parameter Memory per operator type: 550 | 12.2576 MB. 70.5205%. Conv 551 | 5.124 MB. 29.4795%. FC 552 | 0 MB. 0%. Add 553 | 0 MB. 0%. Relu 554 | 17.3816 MB in Total 555 | ``` 556 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Ross Wightman 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
# (Generic) EfficientNets for PyTorch

```diff
-- **NOTE** This repo is not being maintained --
```
Please use [`timm`](https://github.com/huggingface/pytorch-image-models) instead. It includes all of these model definitions (compatible weights) and much, much more.

A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search.

All models are implemented with the GenEfficientNet or MobileNetV3 classes, using string-based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py)).

## What's New

### Aug 19, 2020
* Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1)
* Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1)
* Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX
* ONNX runtime based validation script added
* Activations (mostly) brought in sync with `timm` equivalents

### April 5, 2020
* Add some newly trained MobileNet-V2 models trained with the latest h-params and rand augment. They compare quite favourably to EfficientNet-Lite:
* 3.5M param MobileNet-V2 100 @ 73%
* 4.5M param MobileNet-V2 110d @ 75%
* 6.1M param MobileNet-V2 140 @ 76.5%
* 5.8M param MobileNet-V2 120d @ 77.3%

### March 23, 2020
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
* Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1
* IMPORTANT CHANGE (if training from scratch) - weight init changed to better match the Tensorflow impl; set `fix_group_fanout=False` in `initialize_weight_goog` for the old behavior

### Feb 12, 2020
* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
* Port new EfficientNet-B8 (RandAugment) weights from TF TPU; these differ from the B8 AdvProp weights and use a different input normalization
* Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin)

### Jan 22, 2020
* Update weights for EfficientNet B0, B2, B3 and MixNet-XL with the latest RandAugment trained weights. Trained with https://github.com/rwightman/pytorch-image-models
* Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict
* Test models, torchscript, onnx export with PyTorch 1.4 -- no issues

### Nov 22, 2019
* New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and the B8 model spec. Created a new set of `ap` models since they use a different preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights.

### Nov 15, 2019
* Ported official TF MobileNet-V3 float32 large/small/minimalistic weights
* Modifications to the MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine

### Oct 30, 2019
* Many of the models will now work with torch.jit.script, MixNet being the biggest exception
* Improved interface for enabling torchscript or ONNX export compatible modes (via config)
* Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to the plain memory-efficient autograd.fn
* Activation factory to select the best version of an activation by name, or override one globally
* Add pretrained checkpoint load helper that handles input conv and classifier changes

### Oct 27, 2019
* Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
* Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
* Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base
* Switch activations and global pooling to modules
* Add memory-efficient Swish/Mish impl
* Add as_sequential() method to all models and allow it as an argument in entrypoint fns
* Move MobileNetV3 into its own file since it has a different head
* Remove ChamNet, MobileNet V2/V1 since they will likely never be used here

## Models

Implemented models include:
* EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252)
* EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665)
* EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946)
* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
* EfficientNet-CondConv (https://arxiv.org/abs/1904.04971)
* EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
* MixNet (https://arxiv.org/abs/1907.09595)
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
* MobileNet-V3 (https://arxiv.org/abs/1905.02244)
* FBNet-C (https://arxiv.org/abs/1812.03443)
* Single-Path NAS (https://arxiv.org/abs/1904.02877)

I originally implemented and trained some of these models with the code [here](https://github.com/rwightman/pytorch-image-models); this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code.

## Pretrained

I've managed to train several of the models to accuracies close to or above the originating papers and official impl.
My training code is here: https://github.com/rwightman/pytorch-image-models 89 | 90 | 91 | |Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop | 92 | |---|---|---|---|---|---|---|---| 93 | | efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 | 94 | | efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 | 95 | | mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 | 96 | | efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 | 97 | | mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 | 98 | | efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 | 99 | | mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 | 100 | | efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 | 101 | | efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 | 102 | | efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 | 103 | | mobilenetv2_120d | 77.294 (22.706 | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 | 104 | | mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 | 105 | | mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 | 106 | | mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 | 107 | | mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 | 108 | | mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 | 109 | | efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 | 110 | | mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 | 111 | | fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 | 112 | | mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 | 113 | | mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 | 114 | | spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 | 115 | | mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 | 116 | 117 | 118 | More pretrained models to come... 119 | 120 | 121 | ## Ported Weights 122 | 123 | The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args. 124 | 125 | **IMPORTANT:** 126 | * Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std. 127 | * Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl. 
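For reference, here is a minimal sketch (a hypothetical helper, not the repo's actual `data/transforms.py`) of how the crop pct, interpolation, and mean/std settings in the tables combine into an eval transform:

```python
from PIL import Image
from torchvision import transforms

IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
# Inception-style normalization for the TF ported AP / EdgeTPU / CondConv /
# Lite / MobileNet-V3 weights noted above.
INCEPTION_MEAN = INCEPTION_STD = (0.5, 0.5, 0.5)

def build_eval_transform(img_size=224, crop_pct=0.875, inception_norm=False):
    # Resize the short edge to img_size / crop_pct, then center crop to img_size.
    mean, std = (INCEPTION_MEAN, INCEPTION_STD) if inception_norm else (IMAGENET_MEAN, IMAGENET_STD)
    return transforms.Compose([
        transforms.Resize(int(img_size / crop_pct), interpolation=Image.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
```

This mirrors what the `--img-size`, `--crop-pct`, `--interpolation`, `--mean`, and `--std` args to `validate.py` control in the commands below.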
128 | 129 | To run validation for tf_efficientnet_b5: 130 | `python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic` 131 | 132 | To run validation w/ TF preprocessing for tf_efficientnet_b5: 133 | `python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing` 134 | 135 | To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp: 136 | `python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5` 137 | 138 | |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop | 139 | |---|---|---|---|---|---|---| 140 | | tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A | 141 | | tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 | 142 | | tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 | 143 | | tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A | 144 | | tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A | 145 | | tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A | 146 | | tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A | 147 | | tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A | 148 | | tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A | 149 | | tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A | 150 | | tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A | 151 | | tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A | 152 | | tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 | 153 | | tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 | 154 | | tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A | 155 | | tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 | 156 | | tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A | 157 | | tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 | 158 | | tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A | 159 | | tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 | 160 | | tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 | 161 | | tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A | 162 | | tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A | 163 | | tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 | 164 | | tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A | 165 | | tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 | 166 | | tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A | 167 | | tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | .904 | 168 
| | tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A | 169 | | tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 | 170 | | tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A | 171 | | tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 | 172 | | tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 | 173 | | tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A | 174 | | tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A | 175 | | tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 | 176 | | tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | 177 | | tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 | 178 | | tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 | 179 | | tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | 180 | | tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 | 181 | | tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A | 182 | | tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A | 183 | | tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 | 184 | | tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 | 185 | | tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A | 186 | | tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A | 187 | | tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 | 188 | | tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A | 189 | | tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 | 190 | | tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 | 191 | | tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A | 192 | | tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A | 193 | | tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 | 194 | | tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 | 195 | | tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 | 196 | | tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | N/A | 197 | | tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A | 198 | | tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A | 199 | | tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A | 200 | | tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 | 201 | | tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 | 202 | | tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 | 203 | | tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 | 204 | | 
tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 | 205 | | tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 | 206 | | tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 | 207 | | tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 | 208 | | tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A | 209 | | tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A | 210 | | tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 | 211 | | tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A | 212 | | tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A | 213 | | tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A | 214 | | tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 | 215 | | tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A | 216 | | tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 | 217 | | tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 | 218 | | tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A | 219 | | tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 | 220 | | tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A | 221 | | tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A | 222 | | tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 | 223 | | tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 | 224 | | tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A | 225 | | tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 | 226 | | tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |N/A | 227 | | tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 | 228 | | tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A | 229 | | tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 | 230 | | tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042 | 2.54 | bilinear | 224 | N/A | 231 | | tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 | 232 | | tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A | 233 | | tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 | 234 | | tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A | 235 | | tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 | 236 | 237 | 238 | *tfp models validated with `tf-preprocessing` pipeline 239 | 240 | Google tf and tflite weights ported from official Tensorflow repositories 241 | * https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet 242 | * https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet 243 | * 
https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet

## Usage

### Environment

All development and testing have been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, and 3.8.x.

Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself.

PyTorch versions 1.4, 1.5, and 1.6 have been tested with this code.

I've tried to keep the dependencies minimal; the setup is as per the PyTorch default install instructions for Conda:
```
conda create -n torch-env
conda activate torch-env
conda install -c pytorch pytorch torchvision cudatoolkit=10.2
```

### PyTorch Hub

Models can be accessed via the PyTorch Hub API:

```
>>> torch.hub.list('rwightman/gen-efficientnet-pytorch')
['efficientnet_b0', ...]
>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True)
>>> model.eval()
>>> output = model(torch.randn(1,3,224,224))
```

### Pip
This package can be installed via pip.

Install (after conda env/install):
```
pip install geffnet
```

Eval use:
```
>>> import geffnet
>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
```

Train use:
```
>>> import geffnet
>>> # models can also be created by using the entrypoint directly
>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2)
>>> m.train()
```

Create in a nn.Sequential container, for fast.ai, etc:
```
>>> import geffnet
>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True)
```

### Exporting

Scripts are included to
* export models to ONNX (`onnx_export.py`)
* optimize the ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg)
* validate with ONNX runtime (`onnx_validate.py`)
* convert ONNX model to Caffe2 (`onnx_to_caffe.py`)
* validate in Caffe2 (`caffe2_validate.py`)
* benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`)

As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation:
```
python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx
python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx
```

These scripts were tested and working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible export now requires additional args mentioned in the export script (not needed in earlier versions).
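For reference, a minimal sketch of the programmatic equivalent of `onnx_export.py` (argument values and blob names here are illustrative, not the script's exact defaults; see the Export Notes below for why exportable mode is set first):

```python
import torch
import geffnet

# Required for TF 'SAME'-padded models; swaps conv layers for export-safe variants.
geffnet.config.set_exportable(True)

model = geffnet.create_model('mobilenetv3_large_100', pretrained=True)
model.eval()

# Fixed-size dummy input; for 'SAME'-padded models the padding is baked in
# at this resolution (see note 2 below).
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model, dummy, './mobilenetv3_100.onnx',
    input_names=['input0'], output_names=['output0'],
    opset_version=10)  # opset shown for illustration; check onnx_export.py args
```

The `onnx_validate.py` and `onnx_to_caffe.py` steps then consume the resulting `.onnx` file as in the commands above.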
#### Export Notes
1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless the `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script.
2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working.
3. The ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, the onnxruntime based inference is working very well now and includes on-the-fly optimization.
4. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here.


-------------------------------------------------------------------------------- /caffe2_benchmark.py: --------------------------------------------------------------------------------
""" Caffe2 benchmark script

This script runs a Caffe2 benchmark on an exported (ONNX-converted) model.
It is a useful tool for reporting model FLOPs.

Copyright 2020 Ross Wightman
"""
import argparse

from caffe2.python import core, workspace, model_helper
from caffe2.proto import caffe2_pb2


parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark')
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
                    help='caffe2 model pb name prefix')
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
                    help='caffe2 model init .pb')
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
                    help='caffe2 model predict .pb')
parser.add_argument('-b', '--batch-size', default=1, type=int,
                    metavar='N', help='mini-batch size (default: 1)')
parser.add_argument('--img-size', default=224, type=int,
                    metavar='N', help='Input image dimension (default: 224)')


def main():
    args = parser.parse_args()
    args.gpu_id = 0
    if args.c2_prefix:
        args.c2_init = args.c2_prefix + '.init.pb'
        args.c2_predict = args.c2_prefix + '.predict.pb'

    model = model_helper.ModelHelper(name="benchmark_net", init_params=False)

    # Bring in the init net from init_net.pb
    init_net_proto = caffe2_pb2.NetDef()
    with open(args.c2_init, "rb") as f:
        init_net_proto.ParseFromString(f.read())
    model.param_init_net = core.Net(init_net_proto)

    # Bring in the predict net from predict_net.pb
    predict_net_proto = caffe2_pb2.NetDef()
    with open(args.c2_predict, "rb") as f:
        predict_net_proto.ParseFromString(f.read())
    model.net = core.Net(predict_net_proto)

    # CUDA performance not impressive
    #device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
    #model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
    #model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)

    # Fill the input blob with random data so the net can be instantiated and run.
    input_blob = model.net.external_inputs[0]
    model.param_init_net.GaussianFill(
        [],
        input_blob.GetUnscopedName(),
        shape=(args.batch_size, 3, args.img_size, args.img_size),
        mean=0.0,
        std=1.0)
    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net, overwrite=True)
    # 5 warmup iters, 20 main iters, per-operator stats enabled
    workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True)


if __name__ == '__main__':
    main()
-------------------------------------------------------------------------------- /caffe2_validate.py: --------------------------------------------------------------------------------
61 |     workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True)
62 | 
63 | 
64 | if __name__ == '__main__':
65 |     main()
66 | 
--------------------------------------------------------------------------------
/caffe2_validate.py:
--------------------------------------------------------------------------------
1 | """ Caffe2 validation script
2 | 
3 | This script verifies exported ONNX models running in Caffe2.
4 | It utilizes the same PyTorch dataloader/processing pipeline for a
5 | fair comparison against the originals.
6 | 
7 | Copyright 2020 Ross Wightman
8 | """
9 | import argparse
10 | import numpy as np
11 | from caffe2.python import core, workspace, model_helper
12 | from caffe2.proto import caffe2_pb2
13 | from data import create_loader, resolve_data_config, Dataset
14 | from utils import AverageMeter
15 | import time
16 | 
17 | parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
18 | parser.add_argument('data', metavar='DIR',
19 |                     help='path to dataset')
20 | parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
21 |                     help='caffe2 model pb name prefix')
22 | parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
23 |                     help='caffe2 model init .pb')
24 | parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
25 |                     help='caffe2 model predict .pb')
26 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
27 |                     help='number of data loading workers (default: 2)')
28 | parser.add_argument('-b', '--batch-size', default=256, type=int,
29 |                     metavar='N', help='mini-batch size (default: 256)')
30 | parser.add_argument('--img-size', default=None, type=int,
31 |                     metavar='N', help='Input image dimension, uses model default if empty')
32 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
33 |                     help='Override mean pixel value of dataset')
34 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
35 |                     help='Override std deviation of dataset')
36 | parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
37 |                     help='Override default crop pct of 0.875')
38 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
39 |                     help='Image resize interpolation type (overrides model)')
40 | parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
41 |                     help='use tensorflow mnasnet preprocessing')
42 | parser.add_argument('--print-freq', '-p', default=10, type=int,
43 |                     metavar='N', help='print frequency (default: 10)')
44 | 
45 | 
46 | def main():
47 |     args = parser.parse_args()
48 |     args.gpu_id = 0
49 |     if args.c2_prefix:
50 |         args.c2_init = args.c2_prefix + '.init.pb'
51 |         args.c2_predict = args.c2_prefix + '.predict.pb'
52 | 
53 |     model = model_helper.ModelHelper(name="validation_net", init_params=False)
54 | 
55 |     # Bring in the init net from init_net.pb
56 |     init_net_proto = caffe2_pb2.NetDef()
57 |     with open(args.c2_init, "rb") as f:
58 |         init_net_proto.ParseFromString(f.read())
59 |     model.param_init_net = core.Net(init_net_proto)
60 | 
61 |     # bring in the predict net from predict_net.pb
62 |     predict_net_proto = caffe2_pb2.NetDef()
63 |     with open(args.c2_predict, "rb") as f:
64 |         predict_net_proto.ParseFromString(f.read())
65 |     model.net = core.Net(predict_net_proto)
66 | 
67 |     data_config = resolve_data_config(None, args)
68 |     loader = create_loader(
69 |         Dataset(args.data, load_bytes=args.tf_preprocessing),
70 |         input_size=data_config['input_size'],
71 |         batch_size=args.batch_size,
72 |         use_prefetcher=False,
73 |         interpolation=data_config['interpolation'],
74 |         mean=data_config['mean'],
75 |         std=data_config['std'],
76 |         num_workers=args.workers,
77 |         crop_pct=data_config['crop_pct'],
78 |         tensorflow_preprocessing=args.tf_preprocessing)
79 | 
80 |     # this is so obvious, wonderful interface
81 |     input_blob = model.net.external_inputs[0]
82 |     output_blob = model.net.external_outputs[0]
83 | 
84 |     if True:
85 |         device_opts = None
86 |     else:
87 |         # CUDA is crashing, no idea why, awesome error message, give it a try for kicks
88 |         device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
89 |         model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
90 |         model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
91 | 
92 |     model.param_init_net.GaussianFill(
93 |         [], input_blob.GetUnscopedName(),
94 |         shape=(1,) + data_config['input_size'], mean=0.0, std=1.0)
95 |     workspace.RunNetOnce(model.param_init_net)
96 |     workspace.CreateNet(model.net, overwrite=True)
97 | 
98 |     batch_time = AverageMeter()
99 |     top1 = AverageMeter()
100 |     top5 = AverageMeter()
101 |     end = time.time()
102 |     for i, (input, target) in enumerate(loader):
103 |         # run the net and return prediction
104 |         caffe2_in = input.data.numpy()
105 |         workspace.FeedBlob(input_blob, caffe2_in, device_opts)
106 |         workspace.RunNet(model.net, num_iter=1)
107 |         output = workspace.FetchBlob(output_blob)
108 | 
109 |         # measure accuracy and record loss
110 |         prec1, prec5 = accuracy_np(output.data, target.numpy())
111 |         top1.update(prec1.item(), input.size(0))
112 |         top5.update(prec5.item(), input.size(0))
113 | 
114 |         # measure elapsed time
115 |         batch_time.update(time.time() - end)
116 |         end = time.time()
117 | 
118 |         if i % args.print_freq == 0:
119 |             print('Test: [{0}/{1}]\t'
120 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
121 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
122 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
123 |                       i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
124 |                       ms_avg=1000 * batch_time.avg / input.size(0), top1=top1, top5=top5))
125 | 
126 |     print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
127 |         top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
128 | 
129 | 
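# top-k accuracy via argsort: class scores are sorted descending per sample,
# then we check whether the target index appears in the first k columns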
130 | def accuracy_np(output, target):
131 |     max_indices = np.argsort(output, axis=1)[:, ::-1]
132 |     top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
133 |     top1 = 100 * np.equal(max_indices[:, 0], target).mean()
134 |     return top1, top5
135 | 
136 | 
137 | if __name__ == '__main__':
138 |     main()
139 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import Dataset
2 | from .transforms import *
3 | from .loader import create_loader
4 | 
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | """ Quick n simple image folder dataset
2 | 
3 | Copyright 2020 Ross Wightman
4 | """
5 | import torch.utils.data as data
6 | 
7 | import os
8 | import re
9 | import torch
10 | from PIL import Image
11 | 
12 | 
13 | IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']
14 | 
15 | 
16 | def natural_key(string_):
17 |     """See http://www.codinghorror.com/blog/archives/001018.html"""
18 |     return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
19 | 
20 | 
21 | def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
22 |     if class_to_idx is None:
23 |         class_to_idx = dict()
24 |         build_class_idx = True
25 |     else:
26 |         build_class_idx = False
27 |     labels = []
28 |     filenames = []
29 |     for root, subdirs, files in os.walk(folder, topdown=False):
30 |         rel_path = os.path.relpath(root, folder) if (root != folder) else ''
31 |         label = os.path.basename(rel_path) if leaf_name_only else 
rel_path.replace(os.path.sep, '_') 32 | if build_class_idx and not subdirs: 33 | class_to_idx[label] = None 34 | for f in files: 35 | base, ext = os.path.splitext(f) 36 | if ext.lower() in types: 37 | filenames.append(os.path.join(root, f)) 38 | labels.append(label) 39 | if build_class_idx: 40 | classes = sorted(class_to_idx.keys(), key=natural_key) 41 | for idx, c in enumerate(classes): 42 | class_to_idx[c] = idx 43 | images_and_targets = zip(filenames, [class_to_idx[l] for l in labels]) 44 | if sort: 45 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 46 | if build_class_idx: 47 | return images_and_targets, classes, class_to_idx 48 | else: 49 | return images_and_targets 50 | 51 | 52 | class Dataset(data.Dataset): 53 | 54 | def __init__( 55 | self, 56 | root, 57 | transform=None, 58 | load_bytes=False): 59 | 60 | imgs, _, _ = find_images_and_targets(root) 61 | if len(imgs) == 0: 62 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 63 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 64 | self.root = root 65 | self.imgs = imgs 66 | self.transform = transform 67 | self.load_bytes = load_bytes 68 | 69 | def __getitem__(self, index): 70 | path, target = self.imgs[index] 71 | img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB') 72 | if self.transform is not None: 73 | img = self.transform(img) 74 | if target is None: 75 | target = torch.zeros(1).long() 76 | return img, target 77 | 78 | def __len__(self): 79 | return len(self.imgs) 80 | 81 | def filenames(self, indices=[], basename=False): 82 | if indices: 83 | if basename: 84 | return [os.path.basename(self.imgs[i][0]) for i in indices] 85 | else: 86 | return [self.imgs[i][0] for i in indices] 87 | else: 88 | if basename: 89 | return [os.path.basename(x[0]) for x in self.imgs] 90 | else: 91 | return [x[0] for x in self.imgs] 92 | -------------------------------------------------------------------------------- /data/loader.py: -------------------------------------------------------------------------------- 1 | """ Fast Collate, CUDA Prefetcher 2 | 3 | Prefetcher and Fast Collate inspired by NVIDIA APEX example at 4 | https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import torch 9 | import torch.utils.data 10 | from .transforms import * 11 | 12 | 13 | def fast_collate(batch): 14 | targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) 15 | batch_size = len(targets) 16 | tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 17 | for i in range(batch_size): 18 | tensor[i] += torch.from_numpy(batch[i][0]) 19 | 20 | return tensor, targets 21 | 22 | 23 | class PrefetchLoader: 24 | 25 | def __init__(self, 26 | loader, 27 | mean=IMAGENET_DEFAULT_MEAN, 28 | std=IMAGENET_DEFAULT_STD): 29 | self.loader = loader 30 | self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) 31 | self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) 32 | 33 | def __iter__(self): 34 | stream = torch.cuda.Stream() 35 | first = True 36 | 37 | for next_input, next_target in self.loader: 38 | with torch.cuda.stream(stream): 39 | next_input = next_input.cuda(non_blocking=True) 40 | next_target = next_target.cuda(non_blocking=True) 41 | next_input = next_input.float().sub_(self.mean).div_(self.std) 42 | 43 | if not first: 44 | yield input, target 45 | else: 46 | first = False 47 | 
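            # make the consumer's (current) stream wait for the async copy and
            # normalization queued on the side stream, then hold this batch so
            # it is yielded on the *next* iteration (keeps one batch in flight)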
48 | torch.cuda.current_stream().wait_stream(stream) 49 | input = next_input 50 | target = next_target 51 | 52 | yield input, target 53 | 54 | def __len__(self): 55 | return len(self.loader) 56 | 57 | @property 58 | def sampler(self): 59 | return self.loader.sampler 60 | 61 | 62 | def create_loader( 63 | dataset, 64 | input_size, 65 | batch_size, 66 | is_training=False, 67 | use_prefetcher=True, 68 | interpolation='bilinear', 69 | mean=IMAGENET_DEFAULT_MEAN, 70 | std=IMAGENET_DEFAULT_STD, 71 | num_workers=1, 72 | crop_pct=None, 73 | tensorflow_preprocessing=False 74 | ): 75 | if isinstance(input_size, tuple): 76 | img_size = input_size[-2:] 77 | else: 78 | img_size = input_size 79 | 80 | if tensorflow_preprocessing and use_prefetcher: 81 | from data.tf_preprocessing import TfPreprocessTransform 82 | transform = TfPreprocessTransform( 83 | is_training=is_training, size=img_size, interpolation=interpolation) 84 | else: 85 | transform = transforms_imagenet_eval( 86 | img_size, 87 | interpolation=interpolation, 88 | use_prefetcher=use_prefetcher, 89 | mean=mean, 90 | std=std, 91 | crop_pct=crop_pct) 92 | 93 | dataset.transform = transform 94 | 95 | loader = torch.utils.data.DataLoader( 96 | dataset, 97 | batch_size=batch_size, 98 | shuffle=False, 99 | num_workers=num_workers, 100 | collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate, 101 | ) 102 | if use_prefetcher: 103 | loader = PrefetchLoader( 104 | loader, 105 | mean=mean, 106 | std=std) 107 | 108 | return loader 109 | -------------------------------------------------------------------------------- /data/tf_preprocessing.py: -------------------------------------------------------------------------------- 1 | """ Tensorflow Preprocessing Adapter 2 | 3 | Allows use of Tensorflow preprocessing pipeline in PyTorch Transform 4 | 5 | Copyright of original Tensorflow code below. 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 22 | # ============================================================================== 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | import numpy as np 29 | 30 | IMAGE_SIZE = 224 31 | CROP_PADDING = 32 32 | 33 | 34 | def distorted_bounding_box_crop(image_bytes, 35 | bbox, 36 | min_object_covered=0.1, 37 | aspect_ratio_range=(0.75, 1.33), 38 | area_range=(0.05, 1.0), 39 | max_attempts=100, 40 | scope=None): 41 | """Generates cropped_image using one of the bboxes randomly distorted. 42 | 43 | See `tf.image.sample_distorted_bounding_box` for more documentation. 44 | 45 | Args: 46 | image_bytes: `Tensor` of binary image data. 47 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` 48 | where each coordinate is [0, 1) and the coordinates are arranged 49 | as `[ymin, xmin, ymax, xmax]`. 
If num_boxes is 0 then use the whole
50 |         image.
51 |     min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
52 |         area of the image must contain at least this fraction of any bounding
53 |         box supplied.
54 |     aspect_ratio_range: An optional list of `float`s. The cropped area of the
55 |         image must have an aspect ratio = width / height within this range.
56 |     area_range: An optional list of `float`s. The cropped area of the image
57 |         must contain a fraction of the supplied image within this range.
58 |     max_attempts: An optional `int`. Number of attempts at generating a cropped
59 |         region of the image of the specified constraints. After `max_attempts`
60 |         failures, return the entire image.
61 |     scope: Optional `str` for name scope.
62 |   Returns:
63 |     cropped image `Tensor`
64 |   """
65 |   with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]):
66 |     shape = tf.image.extract_jpeg_shape(image_bytes)
67 |     sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
68 |         shape,
69 |         bounding_boxes=bbox,
70 |         min_object_covered=min_object_covered,
71 |         aspect_ratio_range=aspect_ratio_range,
72 |         area_range=area_range,
73 |         max_attempts=max_attempts,
74 |         use_image_if_no_bounding_boxes=True)
75 |     bbox_begin, bbox_size, _ = sample_distorted_bounding_box
76 | 
77 |     # Crop the image to the specified bounding box.
78 |     offset_y, offset_x, _ = tf.unstack(bbox_begin)
79 |     target_height, target_width, _ = tf.unstack(bbox_size)
80 |     crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
81 |     image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
82 | 
83 |     return image
84 | 
85 | 
86 | def _at_least_x_are_equal(a, b, x):
87 |   """At least `x` of `a` and `b` `Tensors` are equal."""
88 |   match = tf.equal(a, b)
89 |   match = tf.cast(match, tf.int32)
90 |   return tf.greater_equal(tf.reduce_sum(match), x)
91 | 
92 | 
93 | def _decode_and_random_crop(image_bytes, image_size, resize_method):
94 |   """Make a random crop of image_size."""
95 |   bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
96 |   image = distorted_bounding_box_crop(
97 |       image_bytes,
98 |       bbox,
99 |       min_object_covered=0.1,
100 |       aspect_ratio_range=(3. / 4, 4. / 3.),
101 |       area_range=(0.08, 1.0),
102 |       max_attempts=10,
103 |       scope=None)
104 |   original_shape = tf.image.extract_jpeg_shape(image_bytes)
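  # sample_distorted_bounding_box returns the full image after max_attempts
  # failures; if the crop still matches the original shape on >= 3 dims,
  # treat it as failed and fall back to the deterministic center crop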
105 |   bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
106 | 
107 |   image = tf.cond(
108 |       bad,
109 |       lambda: _decode_and_center_crop(image_bytes, image_size, resize_method),
110 |       lambda: tf.image.resize([image], [image_size, image_size], resize_method)[0])
111 | 
112 |   return image
113 | 
114 | 
115 | def _decode_and_center_crop(image_bytes, image_size, resize_method):
116 |   """Crops to center of image with padding then scales image_size."""
117 |   shape = tf.image.extract_jpeg_shape(image_bytes)
118 |   image_height = shape[0]
119 |   image_width = shape[1]
120 | 
121 |   padded_center_crop_size = tf.cast(
122 |       ((image_size / (image_size + CROP_PADDING)) *
123 |        tf.cast(tf.minimum(image_height, image_width), tf.float32)),
124 |       tf.int32)
125 | 
126 |   offset_height = ((image_height - padded_center_crop_size) + 1) // 2
127 |   offset_width = ((image_width - padded_center_crop_size) + 1) // 2
128 |   crop_window = tf.stack([offset_height, offset_width,
129 |                           padded_center_crop_size, padded_center_crop_size])
130 |   image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
131 |   image = tf.image.resize([image], [image_size, image_size], resize_method)[0]
132 | 
133 |   return image
134 | 
135 | 
136 | def _flip(image):
137 |   """Random horizontal image flip."""
138 |   image = tf.image.random_flip_left_right(image)
139 |   return image
140 | 
141 | 
142 | def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
143 |   """Preprocesses the given image for training.
144 | 
145 |   Args:
146 |     image_bytes: `Tensor` representing an image binary of arbitrary size.
147 |     use_bfloat16: `bool` for whether to use bfloat16.
148 |     image_size: image size.
149 |     interpolation: image interpolation method
150 | 
151 |   Returns:
152 |     A preprocessed image `Tensor`.
153 |   """
154 |   resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
155 |   image = _decode_and_random_crop(image_bytes, image_size, resize_method)
156 |   image = _flip(image)
157 |   image = tf.reshape(image, [image_size, image_size, 3])
158 |   image = tf.image.convert_image_dtype(
159 |       image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
160 |   return image
161 | 
162 | 
163 | def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
164 |   """Preprocesses the given image for evaluation.
165 | 
166 |   Args:
167 |     image_bytes: `Tensor` representing an image binary of arbitrary size.
168 |     use_bfloat16: `bool` for whether to use bfloat16.
169 |     image_size: image size.
170 |     interpolation: image interpolation method
171 | 
172 |   Returns:
173 |     A preprocessed image `Tensor`.
174 |   """
175 |   resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
176 |   image = _decode_and_center_crop(image_bytes, image_size, resize_method)
177 |   image = tf.reshape(image, [image_size, image_size, 3])
178 |   image = tf.image.convert_image_dtype(
179 |       image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
180 |   return image
181 | 
182 | 
183 | def preprocess_image(image_bytes,
184 |                      is_training=False,
185 |                      use_bfloat16=False,
186 |                      image_size=IMAGE_SIZE,
187 |                      interpolation='bicubic'):
188 |   """Preprocesses the given image.
189 | 
190 |   Args:
191 |     image_bytes: `Tensor` representing an image binary of arbitrary size. 
192 | is_training: `bool` for whether the preprocessing is for training. 193 | use_bfloat16: `bool` for whether to use bfloat16. 194 | image_size: image size. 195 | interpolation: image interpolation method 196 | 197 | Returns: 198 | A preprocessed image `Tensor` with value range of [0, 255]. 199 | """ 200 | if is_training: 201 | return preprocess_for_train(image_bytes, use_bfloat16, image_size, interpolation) 202 | else: 203 | return preprocess_for_eval(image_bytes, use_bfloat16, image_size, interpolation) 204 | 205 | 206 | class TfPreprocessTransform: 207 | 208 | def __init__(self, is_training=False, size=224, interpolation='bicubic'): 209 | self.is_training = is_training 210 | self.size = size[0] if isinstance(size, tuple) else size 211 | self.interpolation = interpolation 212 | self._image_bytes = None 213 | self.process_image = self._build_tf_graph() 214 | self.sess = None 215 | 216 | def _build_tf_graph(self): 217 | with tf.device('/cpu:0'): 218 | self._image_bytes = tf.placeholder( 219 | shape=[], 220 | dtype=tf.string, 221 | ) 222 | img = preprocess_image( 223 | self._image_bytes, self.is_training, False, self.size, self.interpolation) 224 | return img 225 | 226 | def __call__(self, image_bytes): 227 | if self.sess is None: 228 | self.sess = tf.Session() 229 | img = self.sess.run(self.process_image, feed_dict={self._image_bytes: image_bytes}) 230 | img = img.round().clip(0, 255).astype(np.uint8) 231 | if img.ndim < 3: 232 | img = np.expand_dims(img, axis=-1) 233 | img = np.rollaxis(img, 2) # HWC to CHW 234 | return img 235 | -------------------------------------------------------------------------------- /data/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from PIL import Image 4 | import math 5 | import numpy as np 6 | 7 | DEFAULT_CROP_PCT = 0.875 8 | 9 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 10 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 11 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 12 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 13 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 14 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 15 | 16 | 17 | def resolve_data_config(model, args, default_cfg={}, verbose=True): 18 | new_config = {} 19 | default_cfg = default_cfg 20 | if not default_cfg and model is not None and hasattr(model, 'default_cfg'): 21 | default_cfg = model.default_cfg 22 | 23 | # Resolve input/image size 24 | # FIXME grayscale/chans arg to use different # channels? 
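    # each setting below resolves in the same priority order: explicit CLI
    # arg first, then the model's default_cfg entry, then a library default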
25 | in_chans = 3 26 | input_size = (in_chans, 224, 224) 27 | if args.img_size is not None: 28 | # FIXME support passing img_size as tuple, non-square 29 | assert isinstance(args.img_size, int) 30 | input_size = (in_chans, args.img_size, args.img_size) 31 | elif 'input_size' in default_cfg: 32 | input_size = default_cfg['input_size'] 33 | new_config['input_size'] = input_size 34 | 35 | # resolve interpolation method 36 | new_config['interpolation'] = 'bicubic' 37 | if args.interpolation: 38 | new_config['interpolation'] = args.interpolation 39 | elif 'interpolation' in default_cfg: 40 | new_config['interpolation'] = default_cfg['interpolation'] 41 | 42 | # resolve dataset + model mean for normalization 43 | new_config['mean'] = IMAGENET_DEFAULT_MEAN 44 | if args.mean is not None: 45 | mean = tuple(args.mean) 46 | if len(mean) == 1: 47 | mean = tuple(list(mean) * in_chans) 48 | else: 49 | assert len(mean) == in_chans 50 | new_config['mean'] = mean 51 | elif 'mean' in default_cfg: 52 | new_config['mean'] = default_cfg['mean'] 53 | 54 | # resolve dataset + model std deviation for normalization 55 | new_config['std'] = IMAGENET_DEFAULT_STD 56 | if args.std is not None: 57 | std = tuple(args.std) 58 | if len(std) == 1: 59 | std = tuple(list(std) * in_chans) 60 | else: 61 | assert len(std) == in_chans 62 | new_config['std'] = std 63 | elif 'std' in default_cfg: 64 | new_config['std'] = default_cfg['std'] 65 | 66 | # resolve default crop percentage 67 | new_config['crop_pct'] = DEFAULT_CROP_PCT 68 | if args.crop_pct is not None: 69 | new_config['crop_pct'] = args.crop_pct 70 | elif 'crop_pct' in default_cfg: 71 | new_config['crop_pct'] = default_cfg['crop_pct'] 72 | 73 | if verbose: 74 | print('Data processing configuration for current model + dataset:') 75 | for n, v in new_config.items(): 76 | print('\t%s: %s' % (n, str(v))) 77 | 78 | return new_config 79 | 80 | 81 | class ToNumpy: 82 | 83 | def __call__(self, pil_img): 84 | np_img = np.array(pil_img, dtype=np.uint8) 85 | if np_img.ndim < 3: 86 | np_img = np.expand_dims(np_img, axis=-1) 87 | np_img = np.rollaxis(np_img, 2) # HWC to CHW 88 | return np_img 89 | 90 | 91 | class ToTensor: 92 | 93 | def __init__(self, dtype=torch.float32): 94 | self.dtype = dtype 95 | 96 | def __call__(self, pil_img): 97 | np_img = np.array(pil_img, dtype=np.uint8) 98 | if np_img.ndim < 3: 99 | np_img = np.expand_dims(np_img, axis=-1) 100 | np_img = np.rollaxis(np_img, 2) # HWC to CHW 101 | return torch.from_numpy(np_img).to(dtype=self.dtype) 102 | 103 | 104 | def _pil_interp(method): 105 | if method == 'bicubic': 106 | return Image.BICUBIC 107 | elif method == 'lanczos': 108 | return Image.LANCZOS 109 | elif method == 'hamming': 110 | return Image.HAMMING 111 | else: 112 | # default bilinear, do we want to allow nearest? 
113 |         return Image.BILINEAR
114 | 
115 | 
116 | def transforms_imagenet_eval(
117 |         img_size=224,
118 |         crop_pct=None,
119 |         interpolation='bilinear',
120 |         use_prefetcher=False,
121 |         mean=IMAGENET_DEFAULT_MEAN,
122 |         std=IMAGENET_DEFAULT_STD):
123 |     crop_pct = crop_pct or DEFAULT_CROP_PCT
124 | 
125 |     if isinstance(img_size, tuple):
126 |         assert len(img_size) == 2
127 |         if img_size[-1] == img_size[-2]:
128 |             # fall-back to older behaviour so Resize scales to shortest edge if target is square
129 |             scale_size = int(math.floor(img_size[0] / crop_pct))
130 |         else:
131 |             scale_size = tuple([int(x / crop_pct) for x in img_size])
132 |     else:
133 |         scale_size = int(math.floor(img_size / crop_pct))
134 | 
135 |     tfl = [
136 |         transforms.Resize(scale_size, _pil_interp(interpolation)),
137 |         transforms.CenterCrop(img_size),
138 |     ]
139 |     if use_prefetcher:
140 |         # prefetcher and collate will handle tensor conversion and norm
141 |         tfl += [ToNumpy()]
142 |     else:
143 |         tfl += [
144 |             transforms.ToTensor(),
145 |             transforms.Normalize(
146 |                 mean=torch.tensor(mean),
147 |                 std=torch.tensor(std))
148 |         ]
149 | 
150 |     return transforms.Compose(tfl)
151 | 
--------------------------------------------------------------------------------
/geffnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .gen_efficientnet import *
2 | from .mobilenetv3 import *
3 | from .model_factory import create_model
4 | from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
5 | from .activations import *
--------------------------------------------------------------------------------
/geffnet/activations/__init__.py:
--------------------------------------------------------------------------------
1 | from geffnet import config
2 | from geffnet.activations.activations_me import *
3 | from geffnet.activations.activations_jit import *
4 | from geffnet.activations.activations import *
5 | import torch
6 | 
7 | _has_silu = 'silu' in dir(torch.nn.functional)
8 | 
9 | _ACT_FN_DEFAULT = dict(
10 |     silu=F.silu if _has_silu else swish,
11 |     swish=F.silu if _has_silu else swish,
12 |     mish=mish,
13 |     relu=F.relu,
14 |     relu6=F.relu6,
15 |     sigmoid=sigmoid,
16 |     tanh=tanh,
17 |     hard_sigmoid=hard_sigmoid,
18 |     hard_swish=hard_swish,
19 | )
20 | 
21 | _ACT_FN_JIT = dict(
22 |     silu=F.silu if _has_silu else swish_jit,
23 |     swish=F.silu if _has_silu else swish_jit,
24 |     mish=mish_jit,
25 | )
26 | 
27 | _ACT_FN_ME = dict(
28 |     silu=F.silu if _has_silu else swish_me,
29 |     swish=F.silu if _has_silu else swish_me,
30 |     mish=mish_me,
31 |     hard_swish=hard_swish_me,
32 |     hard_sigmoid=hard_sigmoid_me,
33 | )
34 | 
35 | _ACT_LAYER_DEFAULT = dict(
36 |     silu=nn.SiLU if _has_silu else Swish,
37 |     swish=nn.SiLU if _has_silu else Swish,
38 |     mish=Mish,
39 |     relu=nn.ReLU,
40 |     relu6=nn.ReLU6,
41 |     sigmoid=Sigmoid,
42 |     tanh=Tanh,
43 |     hard_sigmoid=HardSigmoid,
44 |     hard_swish=HardSwish,
45 | )
46 | 
47 | _ACT_LAYER_JIT = dict(
48 |     silu=nn.SiLU if _has_silu else SwishJit,
49 |     swish=nn.SiLU if _has_silu else SwishJit,
50 |     mish=MishJit,
51 | )
52 | 
53 | _ACT_LAYER_ME = dict(
54 |     silu=nn.SiLU if _has_silu else SwishMe,
55 |     swish=nn.SiLU if _has_silu else SwishMe,
56 |     mish=MishMe,
57 |     hard_swish=HardSwishMe,
58 |     hard_sigmoid=HardSigmoidMe
59 | )
60 | 
61 | _OVERRIDE_FN = dict()
62 | _OVERRIDE_LAYER = dict()
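# user-registered overrides take precedence over all of the tables above,
# e.g. add_override_act_layer('hard_swish', nn.Hardswish) forces PyTorch's
# native module for every model created afterwards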
63 | 
64 | 
65 | def add_override_act_fn(name, fn):
66 |     global _OVERRIDE_FN
67 |     _OVERRIDE_FN[name] = fn
68 | 
69 | 
70 | def update_override_act_fn(overrides):
71 |     assert isinstance(overrides, dict)
72 |     global _OVERRIDE_FN
73 |     _OVERRIDE_FN.update(overrides)
74 | 
75 | 
76 | def clear_override_act_fn():
77 |     global _OVERRIDE_FN
78 |     _OVERRIDE_FN = dict()
79 | 
80 | 
81 | def add_override_act_layer(name, fn):
82 |     _OVERRIDE_LAYER[name] = fn
83 | 
84 | 
85 | def update_override_act_layer(overrides):
86 |     assert isinstance(overrides, dict)
87 |     global _OVERRIDE_LAYER
88 |     _OVERRIDE_LAYER.update(overrides)
89 | 
90 | 
91 | def clear_override_act_layer():
92 |     global _OVERRIDE_LAYER
93 |     _OVERRIDE_LAYER = dict()
94 | 
95 | 
96 | def get_act_fn(name='relu'):
97 |     """ Activation Function Factory
98 |     Fetching activation fns by name with this function allows export or torch script friendly
99 |     functions to be returned dynamically based on current config.
100 |     """
101 |     if name in _OVERRIDE_FN:
102 |         return _OVERRIDE_FN[name]
103 |     use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
104 |     if use_me and name in _ACT_FN_ME:
105 |         # If not exporting or scripting the model, first look for a memory optimized version
106 |         # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin
107 |         return _ACT_FN_ME[name]
108 |     if config.is_exportable() and name in ('silu', 'swish'):
109 |         # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
110 |         return swish
111 |     use_jit = not (config.is_exportable() or config.is_no_jit())
112 |     # NOTE: export tracing should work with jit scripted components, but I keep running into issues
113 |     if use_jit and name in _ACT_FN_JIT:  # jit scripted models should be okay for export/scripting
114 |         return _ACT_FN_JIT[name]
115 |     return _ACT_FN_DEFAULT[name]
116 | 
117 | 
118 | def get_act_layer(name='relu'):
119 |     """ Activation Layer Factory
120 |     Fetching activation layers by name with this function allows export or torch script friendly
121 |     functions to be returned dynamically based on current config.
122 |     """
123 |     if name in _OVERRIDE_LAYER:
124 |         return _OVERRIDE_LAYER[name]
125 |     use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
126 |     if use_me and name in _ACT_LAYER_ME:
127 |         return _ACT_LAYER_ME[name]
128 |     if config.is_exportable() and name in ('silu', 'swish'):
129 |         # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
130 |         return Swish
131 |     use_jit = not (config.is_exportable() or config.is_no_jit())
132 |     # NOTE: export tracing should work with jit scripted components, but I keep running into issues
133 |     if use_jit and name in _ACT_LAYER_JIT:  # jit scripted models should be okay for export/scripting
134 |         return _ACT_LAYER_JIT[name]
135 |     return _ACT_LAYER_DEFAULT[name]
136 | 
137 | 
138 | 
--------------------------------------------------------------------------------
/geffnet/activations/activations.py:
--------------------------------------------------------------------------------
1 | """ Activations
2 | 
3 | A collection of activations fn and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 | 
6 | Copyright 2020 Ross Wightman
7 | """
8 | from torch import nn as nn
9 | from torch.nn import functional as F
10 | 
11 | 
12 | def swish(x, inplace: bool = False):
13 |     """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
14 |     and also as Swish (https://arxiv.org/abs/1710.05941).
15 | 
16 |     TODO Rename to SiLU with addition to PyTorch
17 |     """
18 |     return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
19 | 
20 | 
21 | class Swish(nn.Module):
22 |     def __init__(self, inplace: bool = False):
23 |         super(Swish, self).__init__()
24 |         self.inplace = inplace
25 | 
26 |     def forward(self, x):
27 |         return swish(x, self.inplace)
28 | 
29 | 
30 | def mish(x, inplace: bool = False):
31 |     """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
32 |     """
33 |     return x.mul(F.softplus(x).tanh())
34 | 
35 | 
36 | class Mish(nn.Module):
37 |     def __init__(self, inplace: bool = False):
38 |         super(Mish, self).__init__()
39 |         self.inplace = inplace
40 | 
41 |     def forward(self, x):
42 |         return mish(x, self.inplace)
43 | 
44 | 
45 | def sigmoid(x, inplace: bool = False):
46 |     return x.sigmoid_() if inplace else x.sigmoid()
47 | 
48 | 
49 | # PyTorch has this, but not with a consistent inplace argument interface
50 | class Sigmoid(nn.Module):
51 |     def __init__(self, inplace: bool = False):
52 |         super(Sigmoid, self).__init__()
53 |         self.inplace = inplace
54 | 
55 |     def forward(self, x):
56 |         return x.sigmoid_() if self.inplace else x.sigmoid()
57 | 
58 | 
59 | def tanh(x, inplace: bool = False):
60 |     return x.tanh_() if inplace else x.tanh()
61 | 
62 | 
63 | # PyTorch has this, but not with a consistent inplace argument interface
64 | class Tanh(nn.Module):
65 |     def __init__(self, inplace: bool = False):
66 |         super(Tanh, self).__init__()
67 |         self.inplace = inplace
68 | 
69 |     def forward(self, x):
70 |         return x.tanh_() if self.inplace else x.tanh()
71 | 
72 | 
73 | def hard_swish(x, inplace: bool = False):
74 |     inner = F.relu6(x + 3.).div_(6.)
75 |     return x.mul_(inner) if inplace else x.mul(inner)
76 | 
77 | 
78 | class HardSwish(nn.Module):
79 |     def __init__(self, inplace: bool = False):
80 |         super(HardSwish, self).__init__()
81 |         self.inplace = inplace
82 | 
83 |     def forward(self, x):
84 |         return hard_swish(x, self.inplace)
85 | 
86 | 
87 | def hard_sigmoid(x, inplace: bool = False):
88 |     if inplace:
89 |         return x.add_(3.).clamp_(0., 6.).div_(6.)
90 |     else:
91 |         return F.relu6(x + 3.) / 6.
92 | 
93 | 
94 | class HardSigmoid(nn.Module):
95 |     def __init__(self, inplace: bool = False):
96 |         super(HardSigmoid, self).__init__()
97 |         self.inplace = inplace
98 | 
99 |     def forward(self, x):
100 |         return hard_sigmoid(x, self.inplace)
101 | 
102 | 
103 | 
--------------------------------------------------------------------------------
/geffnet/activations/activations_jit.py:
--------------------------------------------------------------------------------
1 | """ Activations (jit)
2 | 
3 | A collection of jit-scripted activations fn and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 | 
6 | All jit scripted activations are lacking in-place variations on purpose; scripted kernel fusion does not
7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
8 | versions if they contain in-place ops.
9 | 10 | Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | __all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit', 18 | 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit'] 19 | 20 | 21 | @torch.jit.script 22 | def swish_jit(x, inplace: bool = False): 23 | """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) 24 | and also as Swish (https://arxiv.org/abs/1710.05941). 25 | 26 | TODO Rename to SiLU with addition to PyTorch 27 | """ 28 | return x.mul(x.sigmoid()) 29 | 30 | 31 | @torch.jit.script 32 | def mish_jit(x, _inplace: bool = False): 33 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 34 | """ 35 | return x.mul(F.softplus(x).tanh()) 36 | 37 | 38 | class SwishJit(nn.Module): 39 | def __init__(self, inplace: bool = False): 40 | super(SwishJit, self).__init__() 41 | 42 | def forward(self, x): 43 | return swish_jit(x) 44 | 45 | 46 | class MishJit(nn.Module): 47 | def __init__(self, inplace: bool = False): 48 | super(MishJit, self).__init__() 49 | 50 | def forward(self, x): 51 | return mish_jit(x) 52 | 53 | 54 | @torch.jit.script 55 | def hard_sigmoid_jit(x, inplace: bool = False): 56 | # return F.relu6(x + 3.) / 6. 57 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 58 | 59 | 60 | class HardSigmoidJit(nn.Module): 61 | def __init__(self, inplace: bool = False): 62 | super(HardSigmoidJit, self).__init__() 63 | 64 | def forward(self, x): 65 | return hard_sigmoid_jit(x) 66 | 67 | 68 | @torch.jit.script 69 | def hard_swish_jit(x, inplace: bool = False): 70 | # return x * (F.relu6(x + 3.) / 6) 71 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 72 | 73 | 74 | class HardSwishJit(nn.Module): 75 | def __init__(self, inplace: bool = False): 76 | super(HardSwishJit, self).__init__() 77 | 78 | def forward(self, x): 79 | return hard_swish_jit(x) 80 | -------------------------------------------------------------------------------- /geffnet/activations/activations_me.py: -------------------------------------------------------------------------------- 1 | """ Activations (memory-efficient w/ custom autograd) 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | These activations are not compatible with jit scripting or ONNX export of the model, please use either 7 | the JIT or basic versions of the activations. 
8 | 9 | Copyright 2020 Ross Wightman 10 | """ 11 | 12 | import torch 13 | from torch import nn as nn 14 | from torch.nn import functional as F 15 | 16 | 17 | __all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe', 18 | 'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe'] 19 | 20 | 21 | @torch.jit.script 22 | def swish_jit_fwd(x): 23 | return x.mul(torch.sigmoid(x)) 24 | 25 | 26 | @torch.jit.script 27 | def swish_jit_bwd(x, grad_output): 28 | x_sigmoid = torch.sigmoid(x) 29 | return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) 30 | 31 | 32 | class SwishJitAutoFn(torch.autograd.Function): 33 | """ torch.jit.script optimised Swish w/ memory-efficient checkpoint 34 | Inspired by conversation btw Jeremy Howard & Adam Pazske 35 | https://twitter.com/jeremyphoward/status/1188251041835315200 36 | 37 | Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) 38 | and also as Swish (https://arxiv.org/abs/1710.05941). 39 | 40 | TODO Rename to SiLU with addition to PyTorch 41 | """ 42 | 43 | @staticmethod 44 | def forward(ctx, x): 45 | ctx.save_for_backward(x) 46 | return swish_jit_fwd(x) 47 | 48 | @staticmethod 49 | def backward(ctx, grad_output): 50 | x = ctx.saved_tensors[0] 51 | return swish_jit_bwd(x, grad_output) 52 | 53 | 54 | def swish_me(x, inplace=False): 55 | return SwishJitAutoFn.apply(x) 56 | 57 | 58 | class SwishMe(nn.Module): 59 | def __init__(self, inplace: bool = False): 60 | super(SwishMe, self).__init__() 61 | 62 | def forward(self, x): 63 | return SwishJitAutoFn.apply(x) 64 | 65 | 66 | @torch.jit.script 67 | def mish_jit_fwd(x): 68 | return x.mul(torch.tanh(F.softplus(x))) 69 | 70 | 71 | @torch.jit.script 72 | def mish_jit_bwd(x, grad_output): 73 | x_sigmoid = torch.sigmoid(x) 74 | x_tanh_sp = F.softplus(x).tanh() 75 | return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) 76 | 77 | 78 | class MishJitAutoFn(torch.autograd.Function): 79 | """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 80 | A memory efficient, jit scripted variant of Mish 81 | """ 82 | @staticmethod 83 | def forward(ctx, x): 84 | ctx.save_for_backward(x) 85 | return mish_jit_fwd(x) 86 | 87 | @staticmethod 88 | def backward(ctx, grad_output): 89 | x = ctx.saved_tensors[0] 90 | return mish_jit_bwd(x, grad_output) 91 | 92 | 93 | def mish_me(x, inplace=False): 94 | return MishJitAutoFn.apply(x) 95 | 96 | 97 | class MishMe(nn.Module): 98 | def __init__(self, inplace: bool = False): 99 | super(MishMe, self).__init__() 100 | 101 | def forward(self, x): 102 | return MishJitAutoFn.apply(x) 103 | 104 | 105 | @torch.jit.script 106 | def hard_sigmoid_jit_fwd(x, inplace: bool = False): 107 | return (x + 3).clamp(min=0, max=6).div(6.) 108 | 109 | 110 | @torch.jit.script 111 | def hard_sigmoid_jit_bwd(x, grad_output): 112 | m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. 
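    # piecewise gradient of hard_sigmoid: 1/6 inside [-3, 3], zero elsewhere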
113 | return grad_output * m 114 | 115 | 116 | class HardSigmoidJitAutoFn(torch.autograd.Function): 117 | @staticmethod 118 | def forward(ctx, x): 119 | ctx.save_for_backward(x) 120 | return hard_sigmoid_jit_fwd(x) 121 | 122 | @staticmethod 123 | def backward(ctx, grad_output): 124 | x = ctx.saved_tensors[0] 125 | return hard_sigmoid_jit_bwd(x, grad_output) 126 | 127 | 128 | def hard_sigmoid_me(x, inplace: bool = False): 129 | return HardSigmoidJitAutoFn.apply(x) 130 | 131 | 132 | class HardSigmoidMe(nn.Module): 133 | def __init__(self, inplace: bool = False): 134 | super(HardSigmoidMe, self).__init__() 135 | 136 | def forward(self, x): 137 | return HardSigmoidJitAutoFn.apply(x) 138 | 139 | 140 | @torch.jit.script 141 | def hard_swish_jit_fwd(x): 142 | return x * (x + 3).clamp(min=0, max=6).div(6.) 143 | 144 | 145 | @torch.jit.script 146 | def hard_swish_jit_bwd(x, grad_output): 147 | m = torch.ones_like(x) * (x >= 3.) 148 | m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) 149 | return grad_output * m 150 | 151 | 152 | class HardSwishJitAutoFn(torch.autograd.Function): 153 | """A memory efficient, jit-scripted HardSwish activation""" 154 | @staticmethod 155 | def forward(ctx, x): 156 | ctx.save_for_backward(x) 157 | return hard_swish_jit_fwd(x) 158 | 159 | @staticmethod 160 | def backward(ctx, grad_output): 161 | x = ctx.saved_tensors[0] 162 | return hard_swish_jit_bwd(x, grad_output) 163 | 164 | 165 | def hard_swish_me(x, inplace=False): 166 | return HardSwishJitAutoFn.apply(x) 167 | 168 | 169 | class HardSwishMe(nn.Module): 170 | def __init__(self, inplace: bool = False): 171 | super(HardSwishMe, self).__init__() 172 | 173 | def forward(self, x): 174 | return HardSwishJitAutoFn.apply(x) 175 | -------------------------------------------------------------------------------- /geffnet/config.py: -------------------------------------------------------------------------------- 1 | """ Global layer config state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 
16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | 117 | 118 | def layer_config_kwargs(kwargs): 119 | """ Consume config kwargs and return contextmgr obj """ 120 | return set_layer_config( 121 | scriptable=kwargs.pop('scriptable', None), 122 | exportable=kwargs.pop('exportable', None), 123 | no_jit=kwargs.pop('no_jit', None)) 124 | -------------------------------------------------------------------------------- /geffnet/conv2d_layers.py: -------------------------------------------------------------------------------- 1 | """ Conv2D w/ SAME padding, CondConv, MixedConv 2 | 3 | A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and 4 | MobileNetV3 models that maintain weight compatibility with original Tensorflow models. 
5 | 6 | Copyright 2020 Ross Wightman 7 | """ 8 | import collections.abc 9 | import math 10 | from functools import partial 11 | from itertools import repeat 12 | from typing import Tuple, Optional 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | from .config import * 20 | 21 | 22 | # From PyTorch internals 23 | def _ntuple(n): 24 | def parse(x): 25 | if isinstance(x, collections.abc.Iterable): 26 | return x 27 | return tuple(repeat(x, n)) 28 | return parse 29 | 30 | 31 | _single = _ntuple(1) 32 | _pair = _ntuple(2) 33 | _triple = _ntuple(3) 34 | _quadruple = _ntuple(4) 35 | 36 | 37 | def _is_static_pad(kernel_size, stride=1, dilation=1, **_): 38 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 39 | 40 | 41 | def _get_padding(kernel_size, stride=1, dilation=1, **_): 42 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 43 | return padding 44 | 45 | 46 | def _calc_same_pad(i: int, k: int, s: int, d: int): 47 | return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0) 48 | 49 | 50 | def _same_pad_arg(input_size, kernel_size, stride, dilation): 51 | ih, iw = input_size 52 | kh, kw = kernel_size 53 | pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) 54 | pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) 55 | return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] 56 | 57 | 58 | def _split_channels(num_chan, num_groups): 59 | split = [num_chan // num_groups for _ in range(num_groups)] 60 | split[0] += num_chan - sum(split) 61 | return split 62 | 63 | 64 | def conv2d_same( 65 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 66 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 67 | ih, iw = x.size()[-2:] 68 | kh, kw = weight.size()[-2:] 69 | pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) 70 | pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) 71 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 72 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 73 | 74 | 75 | class Conv2dSame(nn.Conv2d): 76 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 77 | """ 78 | 79 | # pylint: disable=unused-argument 80 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 81 | padding=0, dilation=1, groups=1, bias=True): 82 | super(Conv2dSame, self).__init__( 83 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 84 | 85 | def forward(self, x): 86 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 87 | 88 | 89 | class Conv2dSameExport(nn.Conv2d): 90 | """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions 91 | 92 | NOTE: This does not currently work with torch.jit.script 93 | """ 94 | 95 | # pylint: disable=unused-argument 96 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 97 | padding=0, dilation=1, groups=1, bias=True): 98 | super(Conv2dSameExport, self).__init__( 99 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 100 | self.pad = None 101 | self.pad_input_size = (0, 0) 102 | 103 | def forward(self, x): 104 | input_size = x.size()[-2:] 105 | if self.pad is None: 106 | pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation) 107 | self.pad = nn.ZeroPad2d(pad_arg) 108 | self.pad_input_size = input_size 109 | 110 | if 
self.pad is not None: 111 | x = self.pad(x) 112 | return F.conv2d( 113 | x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 114 | 115 | 116 | def get_padding_value(padding, kernel_size, **kwargs): 117 | dynamic = False 118 | if isinstance(padding, str): 119 | # for any string padding, the padding will be calculated for you, one of three ways 120 | padding = padding.lower() 121 | if padding == 'same': 122 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 123 | if _is_static_pad(kernel_size, **kwargs): 124 | # static case, no extra overhead 125 | padding = _get_padding(kernel_size, **kwargs) 126 | else: 127 | # dynamic padding 128 | padding = 0 129 | dynamic = True 130 | elif padding == 'valid': 131 | # 'VALID' padding, same as padding=0 132 | padding = 0 133 | else: 134 | # Default to PyTorch style 'same'-ish symmetric padding 135 | padding = _get_padding(kernel_size, **kwargs) 136 | return padding, dynamic 137 | 138 | 139 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 140 | padding = kwargs.pop('padding', '') 141 | kwargs.setdefault('bias', False) 142 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 143 | if is_dynamic: 144 | if is_exportable(): 145 | assert not is_scriptable() 146 | return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs) 147 | else: 148 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 149 | else: 150 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 151 | 152 | 153 | class MixedConv2d(nn.ModuleDict): 154 | """ Mixed Grouped Convolution 155 | Based on MDConv and GroupedConv in MixNet impl: 156 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 157 | """ 158 | 159 | def __init__(self, in_channels, out_channels, kernel_size=3, 160 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 161 | super(MixedConv2d, self).__init__() 162 | 163 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 164 | num_groups = len(kernel_size) 165 | in_splits = _split_channels(in_channels, num_groups) 166 | out_splits = _split_channels(out_channels, num_groups) 167 | self.in_channels = sum(in_splits) 168 | self.out_channels = sum(out_splits) 169 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 170 | conv_groups = out_ch if depthwise else 1 171 | self.add_module( 172 | str(idx), 173 | create_conv2d_pad( 174 | in_ch, out_ch, k, stride=stride, 175 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 176 | ) 177 | self.splits = in_splits 178 | 179 | def forward(self, x): 180 | x_split = torch.split(x, self.splits, 1) 181 | x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())] 182 | x = torch.cat(x_out, 1) 183 | return x 184 | 185 | 186 | def get_condconv_initializer(initializer, num_experts, expert_shape): 187 | def condconv_initializer(weight): 188 | """CondConv initializer function.""" 189 | num_params = np.prod(expert_shape) 190 | if (len(weight.shape) != 2 or weight.shape[0] != num_experts or 191 | weight.shape[1] != num_params): 192 | raise (ValueError( 193 | 'CondConv variables must have shape [num_experts, num_params]')) 194 | for i in range(num_experts): 195 | initializer(weight[i].view(expert_shape)) 196 | return condconv_initializer 197 | 198 | 199 | class CondConv2d(nn.Module): 200 | """ Conditional Convolution 201 | Inspired by: 
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py 202 | 203 | Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: 204 | https://github.com/pytorch/pytorch/issues/17983 205 | """ 206 | __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] 207 | 208 | def __init__(self, in_channels, out_channels, kernel_size=3, 209 | stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): 210 | super(CondConv2d, self).__init__() 211 | 212 | self.in_channels = in_channels 213 | self.out_channels = out_channels 214 | self.kernel_size = _pair(kernel_size) 215 | self.stride = _pair(stride) 216 | padding_val, is_padding_dynamic = get_padding_value( 217 | padding, kernel_size, stride=stride, dilation=dilation) 218 | self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript 219 | self.padding = _pair(padding_val) 220 | self.dilation = _pair(dilation) 221 | self.groups = groups 222 | self.num_experts = num_experts 223 | 224 | self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size 225 | weight_num_param = 1 226 | for wd in self.weight_shape: 227 | weight_num_param *= wd 228 | self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) 229 | 230 | if bias: 231 | self.bias_shape = (self.out_channels,) 232 | self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) 233 | else: 234 | self.register_parameter('bias', None) 235 | 236 | self.reset_parameters() 237 | 238 | def reset_parameters(self): 239 | init_weight = get_condconv_initializer( 240 | partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) 241 | init_weight(self.weight) 242 | if self.bias is not None: 243 | fan_in = np.prod(self.weight_shape[1:]) 244 | bound = 1 / math.sqrt(fan_in) 245 | init_bias = get_condconv_initializer( 246 | partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) 247 | init_bias(self.bias) 248 | 249 | def forward(self, x, routing_weights): 250 | B, C, H, W = x.shape 251 | weight = torch.matmul(routing_weights, self.weight) 252 | new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size 253 | weight = weight.view(new_weight_shape) 254 | bias = None 255 | if self.bias is not None: 256 | bias = torch.matmul(routing_weights, self.bias) 257 | bias = bias.view(B * self.out_channels) 258 | # move batch elements with channels so each batch element can be efficiently convolved with separate kernel 259 | x = x.view(1, B * C, H, W) 260 | if self.dynamic_padding: 261 | out = conv2d_same( 262 | x, weight, bias, stride=self.stride, padding=self.padding, 263 | dilation=self.dilation, groups=self.groups * B) 264 | else: 265 | out = F.conv2d( 266 | x, weight, bias, stride=self.stride, padding=self.padding, 267 | dilation=self.dilation, groups=self.groups * B) 268 | out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) 269 | 270 | # Literal port (from TF definition) 271 | # x = torch.split(x, 1, 0) 272 | # weight = torch.split(weight, 1, 0) 273 | # if self.bias is not None: 274 | # bias = torch.matmul(routing_weights, self.bias) 275 | # bias = torch.split(bias, 1, 0) 276 | # else: 277 | # bias = [None] * B 278 | # out = [] 279 | # for xi, wi, bi in zip(x, weight, bias): 280 | # wi = wi.view(*self.weight_shape) 281 | # if bi is not None: 282 | # bi = 
bi.view(*self.bias_shape) 283 | # out.append(self.conv_fn( 284 | # xi, wi, bi, stride=self.stride, padding=self.padding, 285 | # dilation=self.dilation, groups=self.groups)) 286 | # out = torch.cat(out, 0) 287 | return out 288 | 289 | 290 | def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): 291 | assert 'groups' not in kwargs # only use 'depthwise' bool arg 292 | if isinstance(kernel_size, list): 293 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 294 | # We're going to use only lists for defining the MixedConv2d kernel groups, 295 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 296 | m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) 297 | else: 298 | depthwise = kwargs.pop('depthwise', False) 299 | groups = out_chs if depthwise else 1 300 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 301 | m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) 302 | else: 303 | m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) 304 | return m 305 | -------------------------------------------------------------------------------- /geffnet/efficientnet_builder.py: -------------------------------------------------------------------------------- 1 | """ EfficientNet / MobileNetV3 Blocks and Builder 2 | 3 | Copyright 2020 Ross Wightman 4 | """ 5 | import re 6 | from copy import deepcopy 7 | 8 | from .conv2d_layers import * 9 | from geffnet.activations import * 10 | 11 | __all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible', 12 | 'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 13 | 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def', 14 | 'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT' 15 | ] 16 | 17 | # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per 18 | # papers and TF reference implementations. 
PT momentum equiv for TF decay is (1 - TF decay)
19 | # NOTE: momentum varies btw .99 and .9997 depending on source
20 | #  .99 in official TF TPU impl
21 | #  .9997 (/w .999 in search space) for paper
22 | #
23 | # PyTorch defaults are momentum = .1, eps = 1e-5
24 | #
25 | BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
26 | BN_EPS_TF_DEFAULT = 1e-3
27 | _BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
28 | 
29 | 
30 | def get_bn_args_tf():
31 |     return _BN_ARGS_TF.copy()
32 | 
33 | 
34 | def resolve_bn_args(kwargs):
35 |     bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
36 |     bn_momentum = kwargs.pop('bn_momentum', None)
37 |     if bn_momentum is not None:
38 |         bn_args['momentum'] = bn_momentum
39 |     bn_eps = kwargs.pop('bn_eps', None)
40 |     if bn_eps is not None:
41 |         bn_args['eps'] = bn_eps
42 |     return bn_args
43 | 
44 | 
45 | _SE_ARGS_DEFAULT = dict(
46 |     gate_fn=sigmoid,
47 |     act_layer=None,  # None == use containing block's activation layer
48 |     reduce_mid=False,
49 |     divisor=1)
50 | 
51 | 
52 | def resolve_se_args(kwargs, in_chs, act_layer=None):
53 |     se_kwargs = kwargs.copy() if kwargs is not None else {}
54 |     # fill in args that aren't specified with the defaults
55 |     for k, v in _SE_ARGS_DEFAULT.items():
56 |         se_kwargs.setdefault(k, v)
57 |     # some models, like MobileNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
58 |     if not se_kwargs.pop('reduce_mid'):
59 |         se_kwargs['reduced_base_chs'] = in_chs
60 |     # act_layer override, if it remains None, the containing block's act_layer will be used
61 |     if se_kwargs['act_layer'] is None:
62 |         assert act_layer is not None
63 |         se_kwargs['act_layer'] = act_layer
64 |     return se_kwargs
65 | 
66 | 
67 | def resolve_act_layer(kwargs, default='relu'):
68 |     act_layer = kwargs.pop('act_layer', default)
69 |     if isinstance(act_layer, str):
70 |         act_layer = get_act_layer(act_layer)
71 |     return act_layer
72 | 
73 | 
74 | def make_divisible(v: int, divisor: int = 8, min_value: int = None):
75 |     min_value = min_value or divisor
76 |     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
77 |     if new_v < 0.9 * v:  # ensure round down does not go down by more than 10%.
78 |         new_v += divisor
79 |     return new_v
80 | 
81 | 
82 | def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
83 |     """Round number of filters based on depth multiplier."""
84 |     if not multiplier:
85 |         return channels
86 |     channels *= multiplier
87 |     return make_divisible(channels, divisor, channel_min)
88 | 
89 | 
90 | def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
91 |     """Apply drop connect."""
92 |     if not training:
93 |         return inputs
94 | 
95 |     keep_prob = 1 - drop_connect_rate
96 |     random_tensor = keep_prob + torch.rand(
97 |         (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
98 |     random_tensor.floor_()  # binarize
99 |     output = inputs.div(keep_prob) * random_tensor
100 |     return output
101 | 
102 | 
103 | class SqueezeExcite(nn.Module):
104 | 
105 |     def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1):
106 |         super(SqueezeExcite, self).__init__()
107 |         reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
108 |         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
109 |         self.act1 = act_layer(inplace=True)
110 |         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
111 |         self.gate_fn = gate_fn
112 | 
113 |     def forward(self, x):
114 |         x_se = x.mean((2, 3), keepdim=True)
115 |         x_se = self.conv_reduce(x_se)
116 |         x_se = self.act1(x_se)
117 |         x_se = self.conv_expand(x_se)
118 |         x = x * self.gate_fn(x_se)
119 |         return x
120 | 
121 | 
122 | class ConvBnAct(nn.Module):
123 |     def __init__(self, in_chs, out_chs, kernel_size,
124 |                  stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
125 |         super(ConvBnAct, self).__init__()
126 |         assert stride in [1, 2]
127 |         norm_kwargs = norm_kwargs or {}
128 |         self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
129 |         self.bn1 = norm_layer(out_chs, **norm_kwargs)
130 |         self.act1 = act_layer(inplace=True)
131 | 
132 |     def forward(self, x):
133 |         x = self.conv(x)
134 |         x = self.bn1(x)
135 |         x = self.act1(x)
136 |         return x
137 | 
138 | 
139 | class DepthwiseSeparableConv(nn.Module):
140 |     """ DepthwiseSeparable block
141 |     Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
142 |     factor of 1.0. This is an alternative to having an IR with an optional first pw conv.
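    A minimal usage sketch (channel/shape values illustrative; assumes torch is imported):

        block = DepthwiseSeparableConv(32, 16, dw_kernel_size=3, se_ratio=0.25)
        y = block(torch.randn(1, 32, 56, 56))  # -> (1, 16, 56, 56)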
143 | """ 144 | def __init__(self, in_chs, out_chs, dw_kernel_size=3, 145 | stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, 146 | pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, 147 | norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): 148 | super(DepthwiseSeparableConv, self).__init__() 149 | assert stride in [1, 2] 150 | norm_kwargs = norm_kwargs or {} 151 | self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip 152 | self.drop_connect_rate = drop_connect_rate 153 | 154 | self.conv_dw = select_conv2d( 155 | in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) 156 | self.bn1 = norm_layer(in_chs, **norm_kwargs) 157 | self.act1 = act_layer(inplace=True) 158 | 159 | # Squeeze-and-excitation 160 | if se_ratio is not None and se_ratio > 0.: 161 | se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) 162 | self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) 163 | else: 164 | self.se = nn.Identity() 165 | 166 | self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) 167 | self.bn2 = norm_layer(out_chs, **norm_kwargs) 168 | self.act2 = act_layer(inplace=True) if pw_act else nn.Identity() 169 | 170 | def forward(self, x): 171 | residual = x 172 | 173 | x = self.conv_dw(x) 174 | x = self.bn1(x) 175 | x = self.act1(x) 176 | 177 | x = self.se(x) 178 | 179 | x = self.conv_pw(x) 180 | x = self.bn2(x) 181 | x = self.act2(x) 182 | 183 | if self.has_residual: 184 | if self.drop_connect_rate > 0.: 185 | x = drop_connect(x, self.training, self.drop_connect_rate) 186 | x += residual 187 | return x 188 | 189 | 190 | class InvertedResidual(nn.Module): 191 | """ Inverted residual block w/ optional SE""" 192 | 193 | def __init__(self, in_chs, out_chs, dw_kernel_size=3, 194 | stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, 195 | exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, 196 | se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, 197 | conv_kwargs=None, drop_connect_rate=0.): 198 | super(InvertedResidual, self).__init__() 199 | norm_kwargs = norm_kwargs or {} 200 | conv_kwargs = conv_kwargs or {} 201 | mid_chs: int = make_divisible(in_chs * exp_ratio) 202 | self.has_residual = (in_chs == out_chs and stride == 1) and not noskip 203 | self.drop_connect_rate = drop_connect_rate 204 | 205 | # Point-wise expansion 206 | self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) 207 | self.bn1 = norm_layer(mid_chs, **norm_kwargs) 208 | self.act1 = act_layer(inplace=True) 209 | 210 | # Depth-wise convolution 211 | self.conv_dw = select_conv2d( 212 | mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs) 213 | self.bn2 = norm_layer(mid_chs, **norm_kwargs) 214 | self.act2 = act_layer(inplace=True) 215 | 216 | # Squeeze-and-excitation 217 | if se_ratio is not None and se_ratio > 0.: 218 | se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) 219 | self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) 220 | else: 221 | self.se = nn.Identity() # for jit.script compat 222 | 223 | # Point-wise linear projection 224 | self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) 225 | self.bn3 = norm_layer(out_chs, **norm_kwargs) 226 | 227 | def forward(self, x): 228 | residual = x 229 | 230 | # Point-wise expansion 231 | x = self.conv_pw(x) 232 | x = self.bn1(x) 233 | x = self.act1(x) 234 | 235 | # Depth-wise convolution 236 | x = 
self.conv_dw(x) 237 | x = self.bn2(x) 238 | x = self.act2(x) 239 | 240 | # Squeeze-and-excitation 241 | x = self.se(x) 242 | 243 | # Point-wise linear projection 244 | x = self.conv_pwl(x) 245 | x = self.bn3(x) 246 | 247 | if self.has_residual: 248 | if self.drop_connect_rate > 0.: 249 | x = drop_connect(x, self.training, self.drop_connect_rate) 250 | x += residual 251 | return x 252 | 253 | 254 | class CondConvResidual(InvertedResidual): 255 | """ Inverted residual block w/ CondConv routing""" 256 | 257 | def __init__(self, in_chs, out_chs, dw_kernel_size=3, 258 | stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, 259 | exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, 260 | se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, 261 | num_experts=0, drop_connect_rate=0.): 262 | 263 | self.num_experts = num_experts 264 | conv_kwargs = dict(num_experts=self.num_experts) 265 | 266 | super(CondConvResidual, self).__init__( 267 | in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type, 268 | act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, 269 | pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, 270 | norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, 271 | drop_connect_rate=drop_connect_rate) 272 | 273 | self.routing_fn = nn.Linear(in_chs, self.num_experts) 274 | 275 | def forward(self, x): 276 | residual = x 277 | 278 | # CondConv routing 279 | pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) 280 | routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) 281 | 282 | # Point-wise expansion 283 | x = self.conv_pw(x, routing_weights) 284 | x = self.bn1(x) 285 | x = self.act1(x) 286 | 287 | # Depth-wise convolution 288 | x = self.conv_dw(x, routing_weights) 289 | x = self.bn2(x) 290 | x = self.act2(x) 291 | 292 | # Squeeze-and-excitation 293 | x = self.se(x) 294 | 295 | # Point-wise linear projection 296 | x = self.conv_pwl(x, routing_weights) 297 | x = self.bn3(x) 298 | 299 | if self.has_residual: 300 | if self.drop_connect_rate > 0.: 301 | x = drop_connect(x, self.training, self.drop_connect_rate) 302 | x += residual 303 | return x 304 | 305 | 306 | class EdgeResidual(nn.Module): 307 | """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride""" 308 | 309 | def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, 310 | stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, 311 | se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): 312 | super(EdgeResidual, self).__init__() 313 | norm_kwargs = norm_kwargs or {} 314 | mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio) 315 | self.has_residual = (in_chs == out_chs and stride == 1) and not noskip 316 | self.drop_connect_rate = drop_connect_rate 317 | 318 | # Expansion convolution 319 | self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) 320 | self.bn1 = norm_layer(mid_chs, **norm_kwargs) 321 | self.act1 = act_layer(inplace=True) 322 | 323 | # Squeeze-and-excitation 324 | if se_ratio is not None and se_ratio > 0.: 325 | se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) 326 | self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) 327 | else: 328 | self.se = nn.Identity() 329 | 330 | # Point-wise linear projection 331 | self.conv_pwl = select_conv2d(mid_chs, out_chs, 
pw_kernel_size, stride=stride, padding=pad_type)
332 |         self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs)
333 | 
334 |     def forward(self, x):
335 |         residual = x
336 | 
337 |         # Expansion convolution
338 |         x = self.conv_exp(x)
339 |         x = self.bn1(x)
340 |         x = self.act1(x)
341 | 
342 |         # Squeeze-and-excitation
343 |         x = self.se(x)
344 | 
345 |         # Point-wise linear projection
346 |         x = self.conv_pwl(x)
347 |         x = self.bn2(x)
348 | 
349 |         if self.has_residual:
350 |             if self.drop_connect_rate > 0.:
351 |                 x = drop_connect(x, self.training, self.drop_connect_rate)
352 |             x += residual
353 | 
354 |         return x
355 | 
356 | 
357 | class EfficientNetBuilder:
358 |     """ Build Trunk Blocks for Efficient/Mobile Networks
359 | 
360 |     This ended up being somewhat of a cross between
361 |     https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
362 |     and
363 |     https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
364 | 
365 |     """
366 | 
367 |     def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
368 |                  pad_type='', act_layer=None, se_kwargs=None,
369 |                  norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
370 |         self.channel_multiplier = channel_multiplier
371 |         self.channel_divisor = channel_divisor
372 |         self.channel_min = channel_min
373 |         self.pad_type = pad_type
374 |         self.act_layer = act_layer
375 |         self.se_kwargs = se_kwargs
376 |         self.norm_layer = norm_layer
377 |         self.norm_kwargs = norm_kwargs
378 |         self.drop_connect_rate = drop_connect_rate
379 | 
380 |         # updated during build
381 |         self.in_chs = None
382 |         self.block_idx = 0
383 |         self.block_count = 0
384 | 
385 |     def _round_channels(self, chs):
386 |         return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
387 | 
388 |     def _make_block(self, ba):
389 |         bt = ba.pop('block_type')
390 |         ba['in_chs'] = self.in_chs
391 |         ba['out_chs'] = self._round_channels(ba['out_chs'])
392 |         if 'fake_in_chs' in ba and ba['fake_in_chs']:
393 |             # FIXME this is a hack to work around mismatch in original impl input filters for EdgeTPU
394 |             ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
395 |         ba['norm_layer'] = self.norm_layer
396 |         ba['norm_kwargs'] = self.norm_kwargs
397 |         ba['pad_type'] = self.pad_type
398 |         # block act fn overrides the model default
399 |         ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
400 |         assert ba['act_layer'] is not None
401 |         if bt == 'ir':
402 |             ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
403 |             ba['se_kwargs'] = self.se_kwargs
404 |             if ba.get('num_experts', 0) > 0:
405 |                 block = CondConvResidual(**ba)
406 |             else:
407 |                 block = InvertedResidual(**ba)
408 |         elif bt == 'ds' or bt == 'dsa':
409 |             ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
410 |             ba['se_kwargs'] = self.se_kwargs
411 |             block = DepthwiseSeparableConv(**ba)
412 |         elif bt == 'er':
413 |             ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
414 |             ba['se_kwargs'] = self.se_kwargs
415 |             block = EdgeResidual(**ba)
416 |         elif bt == 'cn':
417 |             block = ConvBnAct(**ba)
418 |         else:
419 |             assert False, 'Unknown block type (%s) while building model.' % bt
420 |         self.in_chs = ba['out_chs']  # update in_chs for arg of next block
421 |         return block
422 | 
423 |     def _make_stack(self, stack_args):
424 |         blocks = []
425 |         # each stack (stage) contains a list of block arguments
426 |         for i, ba in enumerate(stack_args):
427 |             if i >= 1:
428 |                 # only the first block in any stack can have a stride > 1
429 |                 ba['stride'] = 1
430 |             block = self._make_block(ba)
431 |             blocks.append(block)
432 |             self.block_idx += 1  # incr global idx (across all stacks)
433 |         return nn.Sequential(*blocks)
434 | 
435 |     def __call__(self, in_chs, block_args):
436 |         """ Build the blocks
437 |         Args:
438 |             in_chs: Number of input-channels passed to first block
439 |             block_args: A list of lists, outer list defines stages, inner
440 |                 list contains strings defining block configuration(s)
441 |         Return:
442 |             List of block stacks (each stack wrapped in nn.Sequential)
443 |         """
444 |         self.in_chs = in_chs
445 |         self.block_count = sum([len(x) for x in block_args])
446 |         self.block_idx = 0
447 |         blocks = []
448 |         # outer list of block_args defines the stacks ('stages' by some conventions)
449 |         for stack_idx, stack in enumerate(block_args):
450 |             assert isinstance(stack, list)
451 |             stack = self._make_stack(stack)
452 |             blocks.append(stack)
453 |         return blocks
454 | 
455 | 
456 | def _parse_ksize(ss):
457 |     if ss.isdigit():
458 |         return int(ss)
459 |     else:
460 |         return [int(k) for k in ss.split('.')]
461 | 
462 | 
463 | def _decode_block_str(block_str):
464 |     """ Decode block definition string
465 | 
466 |     Decodes a string notation of block arguments into a block args dict plus a repeat count.
467 |     E.g. ir_r2_k3_s2_e1_c16_se0.25_noskip
468 | 
469 |     All args can exist in any order with the exception of the leading string which
470 |     is assumed to indicate the block type.
471 | 
472 |     leading string - block type (
473 |       ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, er = EdgeResidual, cn = ConvBnAct)
474 |     r - number of repeat blocks,
475 |     k - kernel size,
476 |     s - strides (1-9),
477 |     e - expansion ratio,
478 |     c - output channels,
479 |     se - squeeze/excitation ratio
480 |     n - activation fn ('re', 'r6', 'hs', or 'sw')
481 |     Args:
482 |         block_str: a string representation of block arguments.
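            e.g. 'ir_r2_k3_s2_e6_c40_se0.25_nhs' decodes to an InvertedResidual repeated
            twice, with a 3x3 dw kernel, stride 2, expansion ratio 6, 40 output channels,
            SE ratio 0.25, and hard-swish activation.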
483 |     Returns:
484 |         A tuple of a block args dict and the number of times the block should be repeated
485 |     Raises:
486 |         ValueError: if the string def is not properly specified (TODO)
487 |     """
488 |     assert isinstance(block_str, str)
489 |     ops = block_str.split('_')
490 |     block_type = ops[0]  # take the block type off the front
491 |     ops = ops[1:]
492 |     options = {}
493 |     noskip = False
494 |     for op in ops:
495 |         # string options being checked on individual basis, combine if they grow
496 |         if op == 'noskip':
497 |             noskip = True
498 |         elif op.startswith('n'):
499 |             # activation fn
500 |             key = op[0]
501 |             v = op[1:]
502 |             if v == 're':
503 |                 value = get_act_layer('relu')
504 |             elif v == 'r6':
505 |                 value = get_act_layer('relu6')
506 |             elif v == 'hs':
507 |                 value = get_act_layer('hard_swish')
508 |             elif v == 'sw':
509 |                 value = get_act_layer('swish')
510 |             else:
511 |                 continue
512 |             options[key] = value
513 |         else:
514 |             # all numeric options
515 |             splits = re.split(r'(\d.*)', op)
516 |             if len(splits) >= 2:
517 |                 key, value = splits[:2]
518 |                 options[key] = value
519 | 
520 |     # if act_layer is None, the model default (passed to model init) will be used
521 |     act_layer = options['n'] if 'n' in options else None
522 |     exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
523 |     pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
524 |     fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
525 | 
526 |     num_repeat = int(options['r'])
527 |     # each type of block has different valid arguments, fill accordingly
528 |     if block_type == 'ir':
529 |         block_args = dict(
530 |             block_type=block_type,
531 |             dw_kernel_size=_parse_ksize(options['k']),
532 |             exp_kernel_size=exp_kernel_size,
533 |             pw_kernel_size=pw_kernel_size,
534 |             out_chs=int(options['c']),
535 |             exp_ratio=float(options['e']),
536 |             se_ratio=float(options['se']) if 'se' in options else None,
537 |             stride=int(options['s']),
538 |             act_layer=act_layer,
539 |             noskip=noskip,
540 |         )
541 |         if 'cc' in options:
542 |             block_args['num_experts'] = int(options['cc'])
543 |     elif block_type == 'ds' or block_type == 'dsa':
544 |         block_args = dict(
545 |             block_type=block_type,
546 |             dw_kernel_size=_parse_ksize(options['k']),
547 |             pw_kernel_size=pw_kernel_size,
548 |             out_chs=int(options['c']),
549 |             se_ratio=float(options['se']) if 'se' in options else None,
550 |             stride=int(options['s']),
551 |             act_layer=act_layer,
552 |             pw_act=block_type == 'dsa',
553 |             noskip=block_type == 'dsa' or noskip,
554 |         )
555 |     elif block_type == 'er':
556 |         block_args = dict(
557 |             block_type=block_type,
558 |             exp_kernel_size=_parse_ksize(options['k']),
559 |             pw_kernel_size=pw_kernel_size,
560 |             out_chs=int(options['c']),
561 |             exp_ratio=float(options['e']),
562 |             fake_in_chs=fake_in_chs,
563 |             se_ratio=float(options['se']) if 'se' in options else None,
564 |             stride=int(options['s']),
565 |             act_layer=act_layer,
566 |             noskip=noskip,
567 |         )
568 |     elif block_type == 'cn':
569 |         block_args = dict(
570 |             block_type=block_type,
571 |             kernel_size=int(options['k']),
572 |             out_chs=int(options['c']),
573 |             stride=int(options['s']),
574 |             act_layer=act_layer,
575 |         )
576 |     else:
577 |         assert False, 'Unknown block type (%s)' % block_type
578 | 
579 |     return block_args, num_repeat
580 | 
581 | 
582 | def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
583 |     """ Per-stage depth scaling
584 |     Scales the block repeats in each stage.
This depth scaling impl maintains 585 | compatibility with the EfficientNet scaling method, while allowing sensible 586 | scaling for other models that may have multiple block arg definitions in each stage. 587 | """ 588 | 589 | # We scale the total repeat count for each stage, there may be multiple 590 | # block arg defs per stage so we need to sum. 591 | num_repeat = sum(repeats) 592 | if depth_trunc == 'round': 593 | # Truncating to int by rounding allows stages with few repeats to remain 594 | # proportionally smaller for longer. This is a good choice when stage definitions 595 | # include single repeat stages that we'd prefer to keep that way as long as possible 596 | num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) 597 | else: 598 | # The default for EfficientNet truncates repeats to int via 'ceil'. 599 | # Any multiplier > 1.0 will result in an increased depth for every stage. 600 | num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) 601 | 602 | # Proportionally distribute repeat count scaling to each block definition in the stage. 603 | # Allocation is done in reverse as it results in the first block being less likely to be scaled. 604 | # The first block makes less sense to repeat in most of the arch definitions. 605 | repeats_scaled = [] 606 | for r in repeats[::-1]: 607 | rs = max(1, round((r / num_repeat * num_repeat_scaled))) 608 | repeats_scaled.append(rs) 609 | num_repeat -= r 610 | num_repeat_scaled -= rs 611 | repeats_scaled = repeats_scaled[::-1] 612 | 613 | # Apply the calculated scaling to each block arg in the stage 614 | sa_scaled = [] 615 | for ba, rep in zip(stack_args, repeats_scaled): 616 | sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) 617 | return sa_scaled 618 | 619 | 620 | def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): 621 | arch_args = [] 622 | for stack_idx, block_strings in enumerate(arch_def): 623 | assert isinstance(block_strings, list) 624 | stack_args = [] 625 | repeats = [] 626 | for block_str in block_strings: 627 | assert isinstance(block_str, str) 628 | ba, rep = _decode_block_str(block_str) 629 | if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: 630 | ba['num_experts'] *= experts_multiplier 631 | stack_args.append(ba) 632 | repeats.append(rep) 633 | if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): 634 | arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) 635 | else: 636 | arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) 637 | return arch_args 638 | 639 | 640 | def initialize_weight_goog(m, n='', fix_group_fanout=True): 641 | # weight init as per Tensorflow Official impl 642 | # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py 643 | if isinstance(m, CondConv2d): 644 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 645 | if fix_group_fanout: 646 | fan_out //= m.groups 647 | init_weight_fn = get_condconv_initializer( 648 | lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) 649 | init_weight_fn(m.weight) 650 | if m.bias is not None: 651 | m.bias.data.zero_() 652 | elif isinstance(m, nn.Conv2d): 653 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 654 | if fix_group_fanout: 655 | fan_out //= m.groups 656 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 657 | if m.bias is not None: 658 | m.bias.data.zero_() 659 | elif isinstance(m, 
nn.BatchNorm2d): 660 | m.weight.data.fill_(1.0) 661 | m.bias.data.zero_() 662 | elif isinstance(m, nn.Linear): 663 | fan_out = m.weight.size(0) # fan-out 664 | fan_in = 0 665 | if 'routing_fn' in n: 666 | fan_in = m.weight.size(1) 667 | init_range = 1.0 / math.sqrt(fan_in + fan_out) 668 | m.weight.data.uniform_(-init_range, init_range) 669 | m.bias.data.zero_() 670 | 671 | 672 | def initialize_weight_default(m, n=''): 673 | if isinstance(m, CondConv2d): 674 | init_fn = get_condconv_initializer(partial( 675 | nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) 676 | init_fn(m.weight) 677 | elif isinstance(m, nn.Conv2d): 678 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 679 | elif isinstance(m, nn.BatchNorm2d): 680 | m.weight.data.fill_(1.0) 681 | m.bias.data.zero_() 682 | elif isinstance(m, nn.Linear): 683 | nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') 684 | -------------------------------------------------------------------------------- /geffnet/helpers.py: -------------------------------------------------------------------------------- 1 | """ Checkpoint loading / state_dict helpers 2 | Copyright 2020 Ross Wightman 3 | """ 4 | import torch 5 | import os 6 | from collections import OrderedDict 7 | try: 8 | from torch.hub import load_state_dict_from_url 9 | except ImportError: 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 11 | 12 | 13 | def load_checkpoint(model, checkpoint_path): 14 | if checkpoint_path and os.path.isfile(checkpoint_path): 15 | print("=> Loading checkpoint '{}'".format(checkpoint_path)) 16 | checkpoint = torch.load(checkpoint_path) 17 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 18 | new_state_dict = OrderedDict() 19 | for k, v in checkpoint['state_dict'].items(): 20 | if k.startswith('module'): 21 | name = k[7:] # remove `module.` 22 | else: 23 | name = k 24 | new_state_dict[name] = v 25 | model.load_state_dict(new_state_dict) 26 | else: 27 | model.load_state_dict(checkpoint) 28 | print("=> Loaded checkpoint '{}'".format(checkpoint_path)) 29 | else: 30 | print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) 31 | raise FileNotFoundError() 32 | 33 | 34 | def load_pretrained(model, url, filter_fn=None, strict=True): 35 | if not url: 36 | print("=> Warning: Pretrained model URL is empty, using random initialization.") 37 | return 38 | 39 | state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu') 40 | 41 | input_conv = 'conv_stem' 42 | classifier = 'classifier' 43 | in_chans = getattr(model, input_conv).weight.shape[1] 44 | num_classes = getattr(model, classifier).weight.shape[0] 45 | 46 | input_conv_weight = input_conv + '.weight' 47 | pretrained_in_chans = state_dict[input_conv_weight].shape[1] 48 | if in_chans != pretrained_in_chans: 49 | if in_chans == 1: 50 | print('=> Converting pretrained input conv {} from {} to 1 channel'.format( 51 | input_conv_weight, pretrained_in_chans)) 52 | conv1_weight = state_dict[input_conv_weight] 53 | state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) 54 | else: 55 | print('=> Discarding pretrained input conv {} since input channel count != {}'.format( 56 | input_conv_weight, pretrained_in_chans)) 57 | del state_dict[input_conv_weight] 58 | strict = False 59 | 60 | classifier_weight = classifier + '.weight' 61 | pretrained_num_classes = state_dict[classifier_weight].shape[0] 62 | if num_classes != pretrained_num_classes: 63 | print('=> 
Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes))
64 |         del state_dict[classifier_weight]
65 |         del state_dict[classifier + '.bias']
66 |         strict = False
67 | 
68 |     if filter_fn is not None:
69 |         state_dict = filter_fn(state_dict)
70 | 
71 |     model.load_state_dict(state_dict, strict=strict)
72 | 
-------------------------------------------------------------------------------- /geffnet/mobilenetv3.py: --------------------------------------------------------------------------------
1 | """ MobileNet-V3
2 | 
3 | A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
4 | 
5 | Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
6 | 
7 | Hacked together by / Copyright 2020 Ross Wightman
8 | """
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | 
12 | from .activations import get_act_fn, get_act_layer, HardSwish
13 | from .config import layer_config_kwargs
14 | from .conv2d_layers import select_conv2d
15 | from .helpers import load_pretrained
16 | from .efficientnet_builder import *
17 | 
18 | __all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
19 |            'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',
20 |            'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100',
21 |            'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100']
22 | 
23 | model_urls = {
24 |     'mobilenetv3_rw':
25 |         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
26 |     'mobilenetv3_large_075': None,
27 |     'mobilenetv3_large_100':
28 |         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
29 |     'mobilenetv3_large_minimal_100': None,
30 |     'mobilenetv3_small_075': None,
31 |     'mobilenetv3_small_100': None,
32 |     'mobilenetv3_small_minimal_100': None,
33 |     'tf_mobilenetv3_large_075':
34 |         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
35 |     'tf_mobilenetv3_large_100':
36 |         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
37 |     'tf_mobilenetv3_large_minimal_100':
38 |         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
39 |     'tf_mobilenetv3_small_075':
40 |         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
41 |     'tf_mobilenetv3_small_100':
42 |         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
43 |     'tf_mobilenetv3_small_minimal_100':
44 |         'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
45 | }
46 | 
47 | 
48 | class MobileNetV3(nn.Module):
49 |     """ MobileNet-V3
50 | 
51 |     This model utilizes the MobileNet-V3 specific 'efficient head', where global pooling is done before the
52 |     head convolution without a final batch-norm layer before the classifier.
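    A minimal usage sketch (224x224 input assumed; torch imported by the caller):

        model = mobilenetv3_large_100(pretrained=False)
        logits = model(torch.randn(1, 3, 224, 224))  # -> (1, 1000)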
53 | 
54 |     Paper: https://arxiv.org/abs/1905.02244
55 |     """
56 | 
57 |     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
58 |                  channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0.,
59 |                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
60 |         super(MobileNetV3, self).__init__()
61 |         self.drop_rate = drop_rate
62 | 
63 |         stem_size = round_channels(stem_size, channel_multiplier)
64 |         self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
65 |         self.bn1 = nn.BatchNorm2d(stem_size, **(norm_kwargs or {}))
66 |         self.act1 = act_layer(inplace=True)
67 |         in_chs = stem_size
68 | 
69 |         builder = EfficientNetBuilder(
70 |             channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs,
71 |             norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate)
72 |         self.blocks = nn.Sequential(*builder(in_chs, block_args))
73 |         in_chs = builder.in_chs
74 | 
75 |         self.global_pool = nn.AdaptiveAvgPool2d(1)
76 |         self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias)
77 |         self.act2 = act_layer(inplace=True)
78 |         self.classifier = nn.Linear(num_features, num_classes)
79 | 
80 |         for m in self.modules():
81 |             if weight_init == 'goog':
82 |                 initialize_weight_goog(m)
83 |             else:
84 |                 initialize_weight_default(m)
85 | 
86 |     def as_sequential(self):
87 |         layers = [self.conv_stem, self.bn1, self.act1]
88 |         layers.extend(self.blocks)
89 |         layers.extend([
90 |             self.global_pool, self.conv_head, self.act2,
91 |             nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
92 |         return nn.Sequential(*layers)
93 | 
94 |     def features(self, x):
95 |         x = self.conv_stem(x)
96 |         x = self.bn1(x)
97 |         x = self.act1(x)
98 |         x = self.blocks(x)
99 |         x = self.global_pool(x)
100 |         x = self.conv_head(x)
101 |         x = self.act2(x)
102 |         return x
103 | 
104 |     def forward(self, x):
105 |         x = self.features(x)
106 |         x = x.flatten(1)
107 |         if self.drop_rate > 0.:
108 |             x = F.dropout(x, p=self.drop_rate, training=self.training)
109 |         return self.classifier(x)
110 | 
111 | 
112 | def _create_model(model_kwargs, variant, pretrained=False):
113 |     as_sequential = model_kwargs.pop('as_sequential', False)
114 |     model = MobileNetV3(**model_kwargs)
115 |     if pretrained and model_urls[variant]:
116 |         load_pretrained(model, model_urls[variant])
117 |     if as_sequential:
118 |         model = model.as_sequential()
119 |     return model
120 | 
121 | 
122 | def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
123 |     """Creates a MobileNet-V3 model (RW variant).
124 | 
125 |     Paper: https://arxiv.org/abs/1905.02244
126 | 
127 |     This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the
128 |     eventual Tensorflow reference impl but has a few differences:
129 |     1. This model has no bias on the head convolution
130 |     2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet
131 |     3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer
132 |        from their parent block
133 |     4. This model does not enforce divisible by 8 limitation on the SE reduction channel count
134 | 
135 |     Overall the changes are fairly minor and result in a very small parameter count difference and no
136 |     top-1/5 accuracy difference of note.
137 | 
138 |     Args:
139 |         channel_multiplier: multiplier to number of channels per layer.
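    Note: given differences 1-4 above, weights for this variant are not interchangeable with the
    TF-ported 'tf_' variants; the se_kwargs below accordingly omit the divisor=8 constraint.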
140 | """ 141 | arch_def = [ 142 | # stage 0, 112x112 in 143 | ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu 144 | # stage 1, 112x112 in 145 | ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu 146 | # stage 2, 56x56 in 147 | ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu 148 | # stage 3, 28x28 in 149 | ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish 150 | # stage 4, 14x14in 151 | ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish 152 | # stage 5, 14x14in 153 | ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish 154 | # stage 6, 7x7 in 155 | ['cn_r1_k1_s1_c960'], # hard-swish 156 | ] 157 | with layer_config_kwargs(kwargs): 158 | model_kwargs = dict( 159 | block_args=decode_arch_def(arch_def), 160 | head_bias=False, # one of my mistakes 161 | channel_multiplier=channel_multiplier, 162 | act_layer=resolve_act_layer(kwargs, 'hard_swish'), 163 | se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True), 164 | norm_kwargs=resolve_bn_args(kwargs), 165 | **kwargs, 166 | ) 167 | model = _create_model(model_kwargs, variant, pretrained) 168 | return model 169 | 170 | 171 | def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): 172 | """Creates a MobileNet-V3 large/small/minimal models. 173 | 174 | Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py 175 | Paper: https://arxiv.org/abs/1905.02244 176 | 177 | Args: 178 | channel_multiplier: multiplier to number of channels per layer. 179 | """ 180 | if 'small' in variant: 181 | num_features = 1024 182 | if 'minimal' in variant: 183 | act_layer = 'relu' 184 | arch_def = [ 185 | # stage 0, 112x112 in 186 | ['ds_r1_k3_s2_e1_c16'], 187 | # stage 1, 56x56 in 188 | ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], 189 | # stage 2, 28x28 in 190 | ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], 191 | # stage 3, 14x14 in 192 | ['ir_r2_k3_s1_e3_c48'], 193 | # stage 4, 14x14in 194 | ['ir_r3_k3_s2_e6_c96'], 195 | # stage 6, 7x7 in 196 | ['cn_r1_k1_s1_c576'], 197 | ] 198 | else: 199 | act_layer = 'hard_swish' 200 | arch_def = [ 201 | # stage 0, 112x112 in 202 | ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu 203 | # stage 1, 56x56 in 204 | ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu 205 | # stage 2, 28x28 in 206 | ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish 207 | # stage 3, 14x14 in 208 | ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish 209 | # stage 4, 14x14in 210 | ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish 211 | # stage 6, 7x7 in 212 | ['cn_r1_k1_s1_c576'], # hard-swish 213 | ] 214 | else: 215 | num_features = 1280 216 | if 'minimal' in variant: 217 | act_layer = 'relu' 218 | arch_def = [ 219 | # stage 0, 112x112 in 220 | ['ds_r1_k3_s1_e1_c16'], 221 | # stage 1, 112x112 in 222 | ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], 223 | # stage 2, 56x56 in 224 | ['ir_r3_k3_s2_e3_c40'], 225 | # stage 3, 28x28 in 226 | ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], 227 | # stage 4, 14x14in 228 | ['ir_r2_k3_s1_e6_c112'], 229 | # stage 5, 14x14in 230 | ['ir_r3_k3_s2_e6_c160'], 231 | # stage 6, 7x7 in 232 | ['cn_r1_k1_s1_c960'], 233 | ] 234 | else: 235 | act_layer = 'hard_swish' 236 | arch_def = [ 237 | # stage 0, 112x112 in 238 | ['ds_r1_k3_s1_e1_c16_nre'], # relu 239 | # stage 1, 112x112 in 240 | ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu 241 | # stage 2, 56x56 in 242 | ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu 243 | # stage 3, 28x28 in 244 | ['ir_r1_k3_s2_e6_c80', 
'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish 245 | # stage 4, 14x14in 246 | ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish 247 | # stage 5, 14x14in 248 | ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish 249 | # stage 6, 7x7 in 250 | ['cn_r1_k1_s1_c960'], # hard-swish 251 | ] 252 | with layer_config_kwargs(kwargs): 253 | model_kwargs = dict( 254 | block_args=decode_arch_def(arch_def), 255 | num_features=num_features, 256 | stem_size=16, 257 | channel_multiplier=channel_multiplier, 258 | act_layer=resolve_act_layer(kwargs, act_layer), 259 | se_kwargs=dict( 260 | act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8), 261 | norm_kwargs=resolve_bn_args(kwargs), 262 | **kwargs, 263 | ) 264 | model = _create_model(model_kwargs, variant, pretrained) 265 | return model 266 | 267 | 268 | def mobilenetv3_rw(pretrained=False, **kwargs): 269 | """ MobileNet-V3 RW 270 | Attn: See note in gen function for this variant. 271 | """ 272 | # NOTE for train set drop_rate=0.2 273 | if pretrained: 274 | # pretrained model trained with non-default BN epsilon 275 | kwargs['bn_eps'] = BN_EPS_TF_DEFAULT 276 | model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) 277 | return model 278 | 279 | 280 | def mobilenetv3_large_075(pretrained=False, **kwargs): 281 | """ MobileNet V3 Large 0.75""" 282 | # NOTE for train set drop_rate=0.2 283 | model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) 284 | return model 285 | 286 | 287 | def mobilenetv3_large_100(pretrained=False, **kwargs): 288 | """ MobileNet V3 Large 1.0 """ 289 | # NOTE for train set drop_rate=0.2 290 | model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) 291 | return model 292 | 293 | 294 | def mobilenetv3_large_minimal_100(pretrained=False, **kwargs): 295 | """ MobileNet V3 Large (Minimalistic) 1.0 """ 296 | # NOTE for train set drop_rate=0.2 297 | model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) 298 | return model 299 | 300 | 301 | def mobilenetv3_small_075(pretrained=False, **kwargs): 302 | """ MobileNet V3 Small 0.75 """ 303 | model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) 304 | return model 305 | 306 | 307 | def mobilenetv3_small_100(pretrained=False, **kwargs): 308 | """ MobileNet V3 Small 1.0 """ 309 | model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) 310 | return model 311 | 312 | 313 | def mobilenetv3_small_minimal_100(pretrained=False, **kwargs): 314 | """ MobileNet V3 Small (Minimalistic) 1.0 """ 315 | model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) 316 | return model 317 | 318 | 319 | def tf_mobilenetv3_large_075(pretrained=False, **kwargs): 320 | """ MobileNet V3 Large 0.75. Tensorflow compat variant. """ 321 | kwargs['bn_eps'] = BN_EPS_TF_DEFAULT 322 | kwargs['pad_type'] = 'same' 323 | model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) 324 | return model 325 | 326 | 327 | def tf_mobilenetv3_large_100(pretrained=False, **kwargs): 328 | """ MobileNet V3 Large 1.0. Tensorflow compat variant. 
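    Sets the TF default BN epsilon and 'same' padding below so the ported TF weights line up numerically.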
""" 329 | kwargs['bn_eps'] = BN_EPS_TF_DEFAULT 330 | kwargs['pad_type'] = 'same' 331 | model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) 332 | return model 333 | 334 | 335 | def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): 336 | """ MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """ 337 | kwargs['bn_eps'] = BN_EPS_TF_DEFAULT 338 | kwargs['pad_type'] = 'same' 339 | model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) 340 | return model 341 | 342 | 343 | def tf_mobilenetv3_small_075(pretrained=False, **kwargs): 344 | """ MobileNet V3 Small 0.75. Tensorflow compat variant. """ 345 | kwargs['bn_eps'] = BN_EPS_TF_DEFAULT 346 | kwargs['pad_type'] = 'same' 347 | model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) 348 | return model 349 | 350 | 351 | def tf_mobilenetv3_small_100(pretrained=False, **kwargs): 352 | """ MobileNet V3 Small 1.0. Tensorflow compat variant.""" 353 | kwargs['bn_eps'] = BN_EPS_TF_DEFAULT 354 | kwargs['pad_type'] = 'same' 355 | model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) 356 | return model 357 | 358 | 359 | def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): 360 | """ MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """ 361 | kwargs['bn_eps'] = BN_EPS_TF_DEFAULT 362 | kwargs['pad_type'] = 'same' 363 | model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) 364 | return model 365 | -------------------------------------------------------------------------------- /geffnet/model_factory.py: -------------------------------------------------------------------------------- 1 | from .config import set_layer_config 2 | from .helpers import load_checkpoint 3 | 4 | from .gen_efficientnet import * 5 | from .mobilenetv3 import * 6 | 7 | 8 | def create_model( 9 | model_name='mnasnet_100', 10 | pretrained=None, 11 | num_classes=1000, 12 | in_chans=3, 13 | checkpoint_path='', 14 | **kwargs): 15 | 16 | model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs) 17 | 18 | if model_name in globals(): 19 | create_fn = globals()[model_name] 20 | model = create_fn(**model_kwargs) 21 | else: 22 | raise RuntimeError('Unknown model (%s)' % model_name) 23 | 24 | if checkpoint_path and not pretrained: 25 | load_checkpoint(model, checkpoint_path) 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /geffnet/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.0.2' 2 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch', 'math'] 2 | 3 | from geffnet import efficientnet_b0 4 | from geffnet import efficientnet_b1 5 | from geffnet import efficientnet_b2 6 | from geffnet import efficientnet_b3 7 | 8 | from geffnet import efficientnet_es 9 | 10 | from geffnet import efficientnet_lite0 11 | 12 | from geffnet import mixnet_s 13 | from geffnet import mixnet_m 14 | from geffnet import mixnet_l 15 | from geffnet import mixnet_xl 16 | 17 | from geffnet import mobilenetv2_100 18 | from geffnet import mobilenetv2_110d 19 | from geffnet import mobilenetv2_120d 20 | from geffnet import mobilenetv2_140 21 | 22 | from geffnet import 
mobilenetv3_large_100 23 | from geffnet import mobilenetv3_rw 24 | from geffnet import mnasnet_a1 25 | from geffnet import mnasnet_b1 26 | from geffnet import fbnetc_100 27 | from geffnet import spnasnet_100 28 | 29 | from geffnet import tf_efficientnet_b0 30 | from geffnet import tf_efficientnet_b1 31 | from geffnet import tf_efficientnet_b2 32 | from geffnet import tf_efficientnet_b3 33 | from geffnet import tf_efficientnet_b4 34 | from geffnet import tf_efficientnet_b5 35 | from geffnet import tf_efficientnet_b6 36 | from geffnet import tf_efficientnet_b7 37 | from geffnet import tf_efficientnet_b8 38 | 39 | from geffnet import tf_efficientnet_b0_ap 40 | from geffnet import tf_efficientnet_b1_ap 41 | from geffnet import tf_efficientnet_b2_ap 42 | from geffnet import tf_efficientnet_b3_ap 43 | from geffnet import tf_efficientnet_b4_ap 44 | from geffnet import tf_efficientnet_b5_ap 45 | from geffnet import tf_efficientnet_b6_ap 46 | from geffnet import tf_efficientnet_b7_ap 47 | from geffnet import tf_efficientnet_b8_ap 48 | 49 | from geffnet import tf_efficientnet_b0_ns 50 | from geffnet import tf_efficientnet_b1_ns 51 | from geffnet import tf_efficientnet_b2_ns 52 | from geffnet import tf_efficientnet_b3_ns 53 | from geffnet import tf_efficientnet_b4_ns 54 | from geffnet import tf_efficientnet_b5_ns 55 | from geffnet import tf_efficientnet_b6_ns 56 | from geffnet import tf_efficientnet_b7_ns 57 | from geffnet import tf_efficientnet_l2_ns_475 58 | from geffnet import tf_efficientnet_l2_ns 59 | 60 | from geffnet import tf_efficientnet_es 61 | from geffnet import tf_efficientnet_em 62 | from geffnet import tf_efficientnet_el 63 | 64 | from geffnet import tf_efficientnet_cc_b0_4e 65 | from geffnet import tf_efficientnet_cc_b0_8e 66 | from geffnet import tf_efficientnet_cc_b1_8e 67 | 68 | from geffnet import tf_efficientnet_lite0 69 | from geffnet import tf_efficientnet_lite1 70 | from geffnet import tf_efficientnet_lite2 71 | from geffnet import tf_efficientnet_lite3 72 | from geffnet import tf_efficientnet_lite4 73 | 74 | from geffnet import tf_mixnet_s 75 | from geffnet import tf_mixnet_m 76 | from geffnet import tf_mixnet_l 77 | 78 | from geffnet import tf_mobilenetv3_large_075 79 | from geffnet import tf_mobilenetv3_large_100 80 | from geffnet import tf_mobilenetv3_large_minimal_100 81 | from geffnet import tf_mobilenetv3_small_075 82 | from geffnet import tf_mobilenetv3_small_100 83 | from geffnet import tf_mobilenetv3_small_minimal_100 84 | 85 | -------------------------------------------------------------------------------- /onnx_export.py: -------------------------------------------------------------------------------- 1 | """ ONNX export script 2 | 3 | Export PyTorch models as ONNX graphs. 4 | 5 | This export script originally started as an adaptation of code snippets found at 6 | https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html 7 | 8 | The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph 9 | for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible 10 | with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback 11 | flags are currently required. 12 | 13 | Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for 14 | caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime. 
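With a current PyTorch/ONNX, a caffe2-compatible export therefore looks like (model name and
output path illustrative):

    python onnx_export.py --model tf_mobilenetv3_large_100 --keep-init --aten-fallback ./model.onnx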
15 | 
16 | Most new releases of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models.
17 | Please do your research and search the ONNX and PyTorch issue trackers before asking me. Thanks.
18 | 
19 | Copyright 2020 Ross Wightman
20 | """
21 | import argparse
22 | import torch
23 | import numpy as np
24 | 
25 | import onnx
26 | import geffnet
27 | 
28 | parser = argparse.ArgumentParser(description='PyTorch ONNX Export')
29 | parser.add_argument('output', metavar='ONNX_FILE',
30 |                     help='output model filename')
31 | parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100',
32 |                     help='model architecture (default: mobilenetv3_large_100)')
33 | parser.add_argument('--opset', type=int, default=10,
34 |                     help='ONNX opset to use (default: 10)')
35 | parser.add_argument('--keep-init', action='store_true', default=False,
36 |                     help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.')
37 | parser.add_argument('--aten-fallback', action='store_true', default=False,
38 |                     help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.')
39 | parser.add_argument('--dynamic-size', action='store_true', default=False,
40 |                     help='Export model with dynamic width/height. Not recommended for "tf" models with SAME padding.')
41 | parser.add_argument('-b', '--batch-size', default=1, type=int,
42 |                     metavar='N', help='mini-batch size (default: 1)')
43 | parser.add_argument('--img-size', default=None, type=int,
44 |                     metavar='N', help='Input image dimension, uses model default if empty')
45 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
46 |                     help='Override mean pixel value of dataset')
47 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
48 |                     help='Override std deviation of dataset')
49 | parser.add_argument('--num-classes', type=int, default=1000,
50 |                     help='Number of classes in dataset')
51 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
52 |                     help='path to checkpoint (default: none)')
53 | 
54 | 
55 | def main():
56 |     args = parser.parse_args()
57 | 
58 |     args.pretrained = True
59 |     if args.checkpoint:
60 |         args.pretrained = False
61 | 
62 |     print("==> Creating PyTorch {} model".format(args.model))
63 |     # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
64 |     # for models using SAME padding
65 |     model = geffnet.create_model(
66 |         args.model,
67 |         num_classes=args.num_classes,
68 |         in_chans=3,
69 |         pretrained=args.pretrained,
70 |         checkpoint_path=args.checkpoint,
71 |         exportable=True)
72 | 
73 |     model.eval()
74 | 
75 |     example_input = torch.randn((args.batch_size, 3, args.img_size or 224, args.img_size or 224), requires_grad=True)
76 | 
77 |     # Run model once before export trace, sets padding for models with Conv2dSameExport. This means
78 |     # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for
79 |     # the input img_size specified in this script.
80 |     # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
81 |     # issues in the tracing of the dynamic padding or errors attempting to export the model after jit
82 |     # scripting it (an approach that should work). Perhaps in future PyTorch or ONNX versions...
83 |     model(example_input)
84 | 
85 |     print("==> Exporting model to ONNX format at '{}'".format(args.output))
86 |     input_names = ["input0"]
87 |     output_names = ["output0"]
88 |     dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
89 |     if args.dynamic_size:
90 |         dynamic_axes['input0'][2] = 'height'
91 |         dynamic_axes['input0'][3] = 'width'
92 |     if args.aten_fallback:
93 |         export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
94 |     else:
95 |         export_type = torch.onnx.OperatorExportTypes.ONNX
96 | 
97 |     torch_out = torch.onnx._export(
98 |         model, example_input, args.output, export_params=True, verbose=True, input_names=input_names,
99 |         output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes,
100 |         opset_version=args.opset, operator_export_type=export_type)
101 | 
102 |     print("==> Loading and checking exported model from '{}'".format(args.output))
103 |     onnx_model = onnx.load(args.output)
104 |     onnx.checker.check_model(onnx_model)  # assuming throw on error
105 |     print("==> Passed")
106 | 
107 |     if args.keep_init and args.aten_fallback:
108 |         import caffe2.python.onnx.backend as onnx_caffe2
109 |         # Caffe2 loading only works properly in newer PyTorch/ONNX combos when
110 |         # keep_initializers_as_inputs and aten_fallback are set to True.
111 |         print("==> Loading model into Caffe2 backend and comparing forward pass.")
112 |         caffe2_backend = onnx_caffe2.prepare(onnx_model)
113 |         B = {onnx_model.graph.input[0].name: example_input.data.numpy()}
114 |         c2_out = caffe2_backend.run(B)[0]
115 |         np.testing.assert_almost_equal(torch_out.data.numpy(), c2_out, decimal=5)
116 |         print("==> Passed")
117 | 
118 | 
119 | if __name__ == '__main__':
120 |     main()
121 | 
-------------------------------------------------------------------------------- /onnx_optimize.py: --------------------------------------------------------------------------------
1 | """ ONNX optimization script
2 | 
3 | Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc.
4 | 
5 | NOTE: This isn't working consistently in recent PyTorch/ONNX combos (i.e. PyTorch 1.6 and ONNX 1.7);
6 | it seems time to switch to using the onnxruntime online optimizer (which can also be saved for offline use).
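A minimal sketch of the onnxruntime route (paths illustrative; mirrors what onnx_validate.py sets up):

    import onnxruntime
    so = onnxruntime.SessionOptions()
    so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    so.optimized_model_filepath = './model-opt.onnx'  # optimized graph is saved here
    onnxruntime.InferenceSession('./model.onnx', so)  # creating the session runs the optimizer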
7 | 
8 | Copyright 2020 Ross Wightman
9 | """
10 | import argparse
11 | import warnings
12 | 
13 | import onnx
14 | from onnx import optimizer
15 | 
16 | 
17 | parser = argparse.ArgumentParser(description="Optimize ONNX model")
18 | 
19 | parser.add_argument("model", help="The ONNX model")
20 | parser.add_argument("--output", required=True, help="The optimized model output filename")
21 | 
22 | 
23 | def traverse_graph(graph, prefix=''):
24 |     content = []
25 |     indent = prefix + ' '
26 |     graphs = []
27 |     num_nodes = 0
28 |     for node in graph.node:
29 |         pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True)
30 |         assert isinstance(gs, list)
31 |         content.append(pn)
32 |         graphs.extend(gs)
33 |         num_nodes += 1
34 |     for g in graphs:
35 |         g_count, g_str = traverse_graph(g)
36 |         content.append('\n' + g_str)
37 |         num_nodes += g_count
38 |     return num_nodes, '\n'.join(content)
39 | 
40 | 
41 | def main():
42 |     args = parser.parse_args()
43 |     onnx_model = onnx.load(args.model)
44 |     num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph)
45 | 
46 |     # Optimizer passes to perform
47 |     passes = [
48 |         #'eliminate_deadend',
49 |         'eliminate_identity',
50 |         'eliminate_nop_dropout',
51 |         'eliminate_nop_pad',
52 |         'eliminate_nop_transpose',
53 |         'eliminate_unused_initializer',
54 |         'extract_constant_to_initializer',
55 |         'fuse_add_bias_into_conv',
56 |         'fuse_bn_into_conv',
57 |         'fuse_consecutive_concats',
58 |         'fuse_consecutive_reduce_unsqueeze',
59 |         'fuse_consecutive_squeezes',
60 |         'fuse_consecutive_transposes',
61 |         #'fuse_matmul_add_bias_into_gemm',
62 |         'fuse_pad_into_conv',
63 |         #'fuse_transpose_into_gemm',
64 |         #'lift_lexical_references',
65 |     ]
66 | 
67 |     # Apply the optimization on the original serialized model
68 |     # WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing
69 |     # 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401
70 |     # It may be better to rely on onnxruntime optimizations, see onnx_validate.py script.
71 |     warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX. "
72 |                   "Try onnxruntime optimization if this doesn't work.")
73 |     optimized_model = optimizer.optimize(onnx_model, passes)
74 | 
75 |     num_optimized_nodes, optimized_graph_str = traverse_graph(optimized_model.graph)
76 |     print('==> The model after optimization:\n{}\n'.format(optimized_graph_str))
77 |     print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes))
78 | 
79 |     # Save the ONNX model
80 |     onnx.save(optimized_model, args.output)
81 | 
82 | 
83 | if __name__ == "__main__":
84 |     main()
85 | 
-------------------------------------------------------------------------------- /onnx_to_caffe.py: --------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import onnx
4 | from caffe2.python.onnx.backend import Caffe2Backend
5 | 
6 | 
7 | parser = argparse.ArgumentParser(description="Convert ONNX to Caffe2")
8 | 
9 | parser.add_argument("model", help="The ONNX model")
10 | parser.add_argument("--c2-prefix", required=True,
11 |                     help="The output file prefix for the caffe2 model init and predict file.")
") 12 | 13 | 14 | def main(): 15 | args = parser.parse_args() 16 | onnx_model = onnx.load(args.model) 17 | caffe2_init, caffe2_predict = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model) 18 | caffe2_init_str = caffe2_init.SerializeToString() 19 | with open(args.c2_prefix + '.init.pb', "wb") as f: 20 | f.write(caffe2_init_str) 21 | caffe2_predict_str = caffe2_predict.SerializeToString() 22 | with open(args.c2_prefix + '.predict.pb', "wb") as f: 23 | f.write(caffe2_predict_str) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /onnx_validate.py: -------------------------------------------------------------------------------- 1 | """ ONNX-runtime validation script 2 | 3 | This script was created to verify accuracy and performance of exported ONNX 4 | models running with the onnxruntime. It utilizes the PyTorch dataloader/processing 5 | pipeline for a fair comparison against the originals. 6 | 7 | Copyright 2020 Ross Wightman 8 | """ 9 | import argparse 10 | import numpy as np 11 | import onnxruntime 12 | from data import create_loader, resolve_data_config, Dataset 13 | from utils import AverageMeter 14 | import time 15 | 16 | parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation') 17 | parser.add_argument('data', metavar='DIR', 18 | help='path to dataset') 19 | parser.add_argument('--onnx-input', default='', type=str, metavar='PATH', 20 | help='path to onnx model/weights file') 21 | parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH', 22 | help='path to output optimized onnx graph') 23 | parser.add_argument('--profile', action='store_true', default=False, 24 | help='Enable profiler output.') 25 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 26 | help='number of data loading workers (default: 2)') 27 | parser.add_argument('-b', '--batch-size', default=256, type=int, 28 | metavar='N', help='mini-batch size (default: 256)') 29 | parser.add_argument('--img-size', default=None, type=int, 30 | metavar='N', help='Input image dimension, uses model default if empty') 31 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 32 | help='Override mean pixel value of dataset') 33 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 34 | help='Override std deviation of of dataset') 35 | parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', 36 | help='Override default crop pct of 0.875') 37 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 38 | help='Image resize interpolation type (overrides model)') 39 | parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', 40 | help='use tensorflow mnasnet preporcessing') 41 | parser.add_argument('--print-freq', '-p', default=10, type=int, 42 | metavar='N', help='print frequency (default: 10)') 43 | 44 | 45 | def main(): 46 | args = parser.parse_args() 47 | args.gpu_id = 0 48 | 49 | # Set graph optimization level 50 | sess_options = onnxruntime.SessionOptions() 51 | sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 52 | if args.profile: 53 | sess_options.enable_profiling = True 54 | if args.onnx_output_opt: 55 | sess_options.optimized_model_filepath = args.onnx_output_opt 56 | 57 | session = onnxruntime.InferenceSession(args.onnx_input, sess_options) 58 | 59 | data_config = resolve_data_config(None, args) 60 | loader = 
--------------------------------------------------------------------------------
/onnx_validate.py:
--------------------------------------------------------------------------------
1 | """ ONNX-runtime validation script
2 | 
3 | This script was created to verify accuracy and performance of exported ONNX
4 | models running with the onnxruntime. It utilizes the PyTorch dataloader/processing
5 | pipeline for a fair comparison against the originals.
6 | 
7 | Copyright 2020 Ross Wightman
8 | """
9 | import argparse
10 | import numpy as np
11 | import onnxruntime
12 | from data import create_loader, resolve_data_config, Dataset
13 | from utils import AverageMeter
14 | import time
15 | 
16 | parser = argparse.ArgumentParser(description='ONNX Runtime ImageNet Validation')
17 | parser.add_argument('data', metavar='DIR',
18 |                     help='path to dataset')
19 | parser.add_argument('--onnx-input', default='', type=str, metavar='PATH',
20 |                     help='path to onnx model/weights file')
21 | parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
22 |                     help='path to output optimized onnx graph')
23 | parser.add_argument('--profile', action='store_true', default=False,
24 |                     help='Enable profiler output.')
25 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
26 |                     help='number of data loading workers (default: 2)')
27 | parser.add_argument('-b', '--batch-size', default=256, type=int,
28 |                     metavar='N', help='mini-batch size (default: 256)')
29 | parser.add_argument('--img-size', default=None, type=int,
30 |                     metavar='N', help='Input image dimension, uses model default if empty')
31 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
32 |                     help='Override mean pixel value of dataset')
33 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
34 |                     help='Override std deviation of dataset')
35 | parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
36 |                     help='Override default crop pct of 0.875')
37 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
38 |                     help='Image resize interpolation type (overrides model)')
39 | parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
40 |                     help='use tensorflow mnasnet preprocessing')
41 | parser.add_argument('--print-freq', '-p', default=10, type=int,
42 |                     metavar='N', help='print frequency (default: 10)')
43 | 
44 | 
45 | def main():
46 |     args = parser.parse_args()
47 |     args.gpu_id = 0
48 | 
49 |     # Set graph optimization level
50 |     sess_options = onnxruntime.SessionOptions()
51 |     sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
52 |     if args.profile:
53 |         sess_options.enable_profiling = True
54 |     if args.onnx_output_opt:
55 |         sess_options.optimized_model_filepath = args.onnx_output_opt
56 | 
57 |     session = onnxruntime.InferenceSession(args.onnx_input, sess_options)
58 | 
59 |     data_config = resolve_data_config(None, args)
60 |     loader = create_loader(
61 |         Dataset(args.data, load_bytes=args.tf_preprocessing),
62 |         input_size=data_config['input_size'],
63 |         batch_size=args.batch_size,
64 |         use_prefetcher=False,
65 |         interpolation=data_config['interpolation'],
66 |         mean=data_config['mean'],
67 |         std=data_config['std'],
68 |         num_workers=args.workers,
69 |         crop_pct=data_config['crop_pct'],
70 |         tensorflow_preprocessing=args.tf_preprocessing)
71 | 
72 |     input_name = session.get_inputs()[0].name
73 | 
74 |     batch_time = AverageMeter()
75 |     top1 = AverageMeter()
76 |     top5 = AverageMeter()
77 |     end = time.time()
78 |     for i, (input, target) in enumerate(loader):
79 |         # run the net and return prediction
80 |         output = session.run([], {input_name: input.data.numpy()})
81 |         output = output[0]
82 | 
83 |         # measure accuracy
84 |         prec1, prec5 = accuracy_np(output, target.numpy())
85 |         top1.update(prec1.item(), input.size(0))
86 |         top5.update(prec5.item(), input.size(0))
87 | 
88 |         # measure elapsed time
89 |         batch_time.update(time.time() - end)
90 |         end = time.time()
91 | 
92 |         if i % args.print_freq == 0:
93 |             print('Test: [{0}/{1}]\t'
94 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
95 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
96 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
97 |                       i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
98 |                       ms_avg=1000 * batch_time.avg / input.size(0), top1=top1, top5=top5))
99 | 
100 |     print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
101 |         top1=top1, top1a=100 - top1.avg, top5=top5, top5a=100. - top5.avg))
102 | 
103 | 
104 | def accuracy_np(output, target):
105 |     max_indices = np.argsort(output, axis=1)[:, ::-1]
106 |     top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
107 |     top1 = 100 * np.equal(max_indices[:, 0], target).mean()
108 |     return top1, top5
109 | 
110 | 
111 | if __name__ == '__main__':
112 |     main()
113 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.4.0
2 | torchvision>=0.5.0
3 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """ Setup
2 | """
3 | from setuptools import setup, find_packages
4 | from codecs import open
5 | from os import path
6 | 
7 | here = path.abspath(path.dirname(__file__))
8 | 
9 | # Get the long description from the README file
10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f:
11 |     long_description = f.read()
12 | 
13 | exec(open('geffnet/version.py').read())
14 | setup(
15 |     name='geffnet',
16 |     version=__version__,
17 |     description='(Generic) EfficientNets for PyTorch',
18 |     long_description=long_description,
19 |     long_description_content_type='text/markdown',
20 |     url='https://github.com/rwightman/gen-efficientnet-pytorch',
21 |     author='Ross Wightman',
22 |     author_email='hello@rwightman.com',
23 |     classifiers=[
24 |         # How mature is this project? Common values are
25 |         #   3 - Alpha
26 |         #   4 - Beta
27 |         #   5 - Production/Stable
28 |         'Development Status :: 3 - Alpha',
29 |         'Intended Audience :: Education',
30 |         'Intended Audience :: Science/Research',
31 |         'License :: OSI Approved :: Apache Software License',
32 |         'Programming Language :: Python :: 3.6',
33 |         'Programming Language :: Python :: 3.7',
34 |         'Programming Language :: Python :: 3.8',
35 |         'Topic :: Scientific/Engineering',
36 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
37 |         'Topic :: Software Development',
38 |         'Topic :: Software Development :: Libraries',
39 |         'Topic :: Software Development :: Libraries :: Python Modules',
40 |     ],
41 | 
42 |     # Note that this is a string of words separated by whitespace, not a list.
43 |     keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet',
44 |     packages=find_packages(exclude=['data']),
45 |     install_requires=['torch >= 1.4', 'torchvision'],
46 |     python_requires='>=3.6',
47 | )
48 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | class AverageMeter:
5 |     """Computes and stores the average and current value"""
6 |     def __init__(self):
7 |         self.reset()
8 | 
9 |     def reset(self):
10 |         self.val = 0
11 |         self.avg = 0
12 |         self.sum = 0
13 |         self.count = 0
14 | 
15 |     def update(self, val, n=1):
16 |         self.val = val
17 |         self.sum += val * n
18 |         self.count += n
19 |         self.avg = self.sum / self.count
20 | 
21 | 
22 | def accuracy(output, target, topk=(1,)):
23 |     """Computes the precision@k for the specified values of k"""
24 |     maxk = max(topk)
25 |     batch_size = target.size(0)
26 | 
27 |     _, pred = output.topk(maxk, 1, True, True)
28 |     pred = pred.t()
29 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
30 | 
31 |     res = []
32 |     for k in topk:
33 |         correct_k = correct[:k].reshape(-1).float().sum(0)
34 |         res.append(correct_k.mul_(100.0 / batch_size))
35 |     return res
36 | 
37 | 
38 | def get_outdir(path, *paths, inc=False):
39 |     """Create (if needed) and return an output dir; with inc=True, append an incrementing -N suffix instead of reusing an existing dir."""
40 |     outdir = os.path.join(path, *paths)
41 |     if not os.path.exists(outdir):
42 |         os.makedirs(outdir)
43 |     elif inc:
44 |         count = 1
45 |         outdir_inc = outdir + '-' + str(count)
46 |         while os.path.exists(outdir_inc):
47 |             count = count + 1
48 |             outdir_inc = outdir + '-' + str(count)
49 |             assert count < 100
50 |         outdir = outdir_inc
51 |         os.makedirs(outdir)
52 |     return outdir
53 | 
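A quick illustration of how these two helpers combine, as in the validation loops in onnx_validate.py and validate.py; a self-contained sketch with toy logits (all values made up):

```python
import torch
from utils import accuracy, AverageMeter

top1 = AverageMeter()
# Two samples, four classes; row 0 predicts class 2, row 1 predicts class 0
logits = torch.tensor([[0.1, 0.2, 0.9, 0.0],
                       [0.8, 0.1, 0.0, 0.1]])
target = torch.tensor([2, 3])  # second sample is misclassified
prec1, = accuracy(logits, target, topk=(1,))
top1.update(prec1.item(), logits.size(0))
print(top1.avg)  # 50.0 -- one of two samples correct
```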
--------------------------------------------------------------------------------
/validate.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import argparse
6 | import time
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | from contextlib import suppress
11 | 
12 | import geffnet
13 | from data import Dataset, create_loader, resolve_data_config
14 | from utils import accuracy, AverageMeter
15 | 
16 | has_native_amp = False
17 | try:
18 |     if getattr(torch.cuda.amp, 'autocast') is not None:
19 |         has_native_amp = True
20 | except AttributeError:
21 |     pass
22 | 
23 | torch.backends.cudnn.benchmark = True
24 | 
25 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
26 | parser.add_argument('data', metavar='DIR',
27 |                     help='path to dataset')
28 | parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet_100',
29 |                     help='model architecture (default: spnasnet_100)')
30 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
31 |                     help='number of data loading workers (default: 4)')
32 | parser.add_argument('-b', '--batch-size', default=256, type=int,
33 |                     metavar='N', help='mini-batch size (default: 256)')
34 | parser.add_argument('--img-size', default=None, type=int,
35 |                     metavar='N', help='Input image dimension, uses model default if empty')
36 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
37 |                     help='Override mean pixel value of dataset')
38 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
39 |                     help='Override std deviation of dataset')
40 | parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
41 |                     help='Override default crop pct of 0.875')
42 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
43 |                     help='Image resize interpolation type (overrides model)')
44 | parser.add_argument('--num-classes', type=int, default=1000,
45 |                     help='Number classes in dataset')
46 | parser.add_argument('--print-freq', '-p', default=10, type=int,
47 |                     metavar='N', help='print frequency (default: 10)')
48 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
49 |                     help='path to latest checkpoint (default: none)')
50 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
51 |                     help='use pre-trained model')
52 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
53 |                     help='convert model to torchscript for inference')
54 | parser.add_argument('--num-gpu', type=int, default=1,
55 |                     help='Number of GPUS to use')
56 | parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
57 |                     help='use tensorflow mnasnet preprocessing')
58 | parser.add_argument('--no-cuda', dest='no_cuda', action='store_true',
59 |                     help='run validation on CPU')
60 | parser.add_argument('--channels-last', action='store_true', default=False,
61 |                     help='Use channels_last memory layout')
62 | parser.add_argument('--amp', action='store_true', default=False,
63 |                     help='Use native Torch AMP mixed precision.')
64 | 
65 | 
66 | def main():
67 |     args = parser.parse_args()
68 | 
69 |     if not args.checkpoint and not args.pretrained:
70 |         args.pretrained = True
71 | 
72 |     amp_autocast = suppress  # do nothing
73 |     if args.amp:
74 |         if not has_native_amp:
75 |             print("Native Torch AMP is not available (requires torch >= 1.6), using FP32.")
76 |         else:
77 |             amp_autocast = torch.cuda.amp.autocast
78 | 
79 |     # create model
80 |     model = geffnet.create_model(
81 |         args.model,
82 |         num_classes=args.num_classes,
83 |         in_chans=3,
84 |         pretrained=args.pretrained,
85 |         checkpoint_path=args.checkpoint,
86 |         scriptable=args.torchscript)
87 | 
88 |     if args.channels_last:
89 |         model = model.to(memory_format=torch.channels_last)
90 | 
91 |     if args.torchscript:
92 |         torch.jit.optimized_execution(True)
93 |         model = torch.jit.script(model)
94 | 
95 |     print('Model %s created, param count: %d' %
96 |           (args.model, sum([m.numel() for m in model.parameters()])))
97 | 
98 |     data_config = resolve_data_config(model, args)
99 | 
100 |     criterion = nn.CrossEntropyLoss()
101 | 
102 |     if not args.no_cuda:
103 |         if args.num_gpu > 1:
104 |             model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
105 |         else:
106 |             model = model.cuda()
107 |         criterion = criterion.cuda()
108 | 
109 |     loader = create_loader(
110 |         Dataset(args.data, load_bytes=args.tf_preprocessing),
111 |         input_size=data_config['input_size'],
112 |         batch_size=args.batch_size,
113 |         use_prefetcher=not args.no_cuda,
114 |         interpolation=data_config['interpolation'],
115 |         mean=data_config['mean'],
116 |         std=data_config['std'],
117 |         num_workers=args.workers,
118 |         crop_pct=data_config['crop_pct'],
119 |         tensorflow_preprocessing=args.tf_preprocessing)
120 | 
121 |     batch_time = AverageMeter()
122 |     losses = AverageMeter()
123 |     top1 = AverageMeter()
124 |     top5 = AverageMeter()
125 | 
126 |     model.eval()
127 |     end = time.time()
128 |     with torch.no_grad():
129 |         for i, (input, target) in enumerate(loader):
130 |             if not args.no_cuda:
131 |                 target = target.cuda()
132 |                 input = input.cuda()
133 |             if args.channels_last:
134 |                 input = input.contiguous(memory_format=torch.channels_last)
135 | 
136 |             # compute output
137 |             with amp_autocast():
138 |                 output = model(input)
139 |                 loss = criterion(output, target)
140 | 
141 |             # measure accuracy and record loss
142 |             prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
143 |             losses.update(loss.item(), input.size(0))
144 |             top1.update(prec1.item(), input.size(0))
145 |             top5.update(prec5.item(), input.size(0))
146 | 
147 |             # measure elapsed time
148 |             batch_time.update(time.time() - end)
149 |             end = time.time()
150 | 
151 |             if i % args.print_freq == 0:
152 |                 print('Test: [{0}/{1}]\t'
153 |                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t'
154 |                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
155 |                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
156 |                       'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
157 |                           i, len(loader), batch_time=batch_time,
158 |                           rate_avg=input.size(0) / batch_time.avg,
159 |                           loss=losses, top1=top1, top5=top5))
160 | 
161 |     print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
162 |         top1=top1, top1a=100 - top1.avg, top5=top5, top5a=100. - top5.avg))
163 | 
164 | 
165 | if __name__ == '__main__':
166 |     main()
167 | 
--------------------------------------------------------------------------------
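For reference, a typical PyTorch-side validation run with this script looks like the following, assuming a local ImageNet validation folder (path and model name are placeholders; all flags are defined above):

```
python validate.py /path/to/imagenet/val --model efficientnet_b0 --pretrained -b 128 --amp
```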