├── .gitignore ├── LICENSE ├── README.md ├── distributed_train.sh ├── eval.sh ├── logs ├── vip_L7.log ├── vip_m7.log └── vip_s7.log ├── main.py ├── models ├── __init__.py └── vip.py ├── permute_mlp.png ├── transfer_learning.py ├── utils.py └── validate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # 118 | .vscode/ 119 | 120 | # sublime project settings 121 | *.sublime-workspace 122 | *.sublime-project 123 | sftp-config.json 124 | run.sh 125 | 126 | 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Qibin (Andrew) Hou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition ([arxiv](https://arxiv.org/abs/2106.12368)) 2 | 3 | This is a PyTorch implementation of our paper ViP, [IEEE TPAMI 2022](https://ieeexplore.ieee.org/abstract/document/9693166/). MindSpore and Jittor code will be released soon. We present Vision Permutator, a conceptually simple and data-efficient 4 | MLP-like architecture for visual recognition. We show that our Vision Permutators are formidable competitors to convolutional neural 5 | networks (CNNs) and vision transformers. 6 | 7 | We hope this work will encourage researchers to rethink the way of encoding spatial 8 | information and facilitate the development of MLP-like models. 9 | 10 | ![Compare](permute_mlp.png) 11 | 12 | Basic structure of the proposed Permute-MLP layer. It contains 13 | three branches that are responsible for encoding features along the height, width, and channel 14 | dimensions, respectively. The outputs from the three branches are then combined using element-wise addition, followed by a fully-connected layer for feature fusion.
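To make the description above concrete, here is a minimal, illustrative PyTorch sketch of the Permute-MLP idea. It is not the repository implementation (the actual layer lives in models/vip.py and additionally learns weights for re-combining the three branches); the class name `PermuteMLP`, the `segment_dim` argument, and the assumption that the feature map is square with `H == W == segment_dim` (so every permuted branch keeps a last dimension of size `dim`) are illustrative choices only.

```
# Illustrative sketch only; see models/vip.py for the actual implementation.
import torch
import torch.nn as nn

class PermuteMLP(nn.Module):
    """Three branches encode height, width and channel information; their
    outputs are summed element-wise and fused by a fully-connected layer."""

    def __init__(self, dim, segment_dim):
        super().__init__()
        self.segment_dim = segment_dim      # number of channel segments
        self.mlp_h = nn.Linear(dim, dim)    # mixes information along height
        self.mlp_w = nn.Linear(dim, dim)    # mixes information along width
        self.mlp_c = nn.Linear(dim, dim)    # mixes information along channels
        self.proj = nn.Linear(dim, dim)     # fully-connected feature fusion

    def forward(self, x):
        # x: (B, H, W, C); assumes H == W == segment_dim so that H * S == C.
        B, H, W, C = x.shape
        S = C // self.segment_dim

        # Height branch: split C into segments, swap the height and segment
        # axes, apply the linear layer, then permute back.
        h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S)
        h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)

        # Width branch: the same trick applied to the width axis.
        w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S)
        w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)

        # Channel branch: a plain per-location linear layer.
        c = self.mlp_c(x)

        # Element-wise addition of the three branches, then feature fusion.
        return self.proj(h + w + c)

x = torch.randn(2, 32, 32, 192)              # e.g. a 32x32 grid of 192-dim tokens (224x224 image, patch size 7)
y = PermuteMLP(dim=192, segment_dim=32)(x)   # -> torch.Size([2, 32, 32, 192])
```

The point of the permutations is that spatial mixing along height and width is done with ordinary linear layers applied after cheap dimension reshuffles, with no attention or convolution involved.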
15 | 16 | Our code is based on [pytorch-image-models](https://github.com/rwightman/pytorch-image-models), [Token Labeling](https://github.com/zihangJiang/TokenLabeling), and [T2T-ViT](https://github.com/yitu-opensource/T2T-ViT). 17 | 18 | ### Comparison with Recent MLP-like Models 19 | 20 | | Model | Parameters | Throughput | Image resolution | Top-1 Acc. | Download | Logs | 21 | | :------------------- | :--------- | :--------- | :--------------- | :--------- | :------- | :---- | 22 | | EAMLP-14 | 30M | 711 img/s | 224 | 78.9% | | | 23 | | gMLP-S | 20M | - | 224 | 79.6% | | | 24 | | ResMLP-S24 | 30M | 715 img/s | 224 | 79.4% | | | 25 | | ViP-Small/7 (ours) | 25M | 719 img/s | 224 | 81.5% | [link](https://drive.google.com/file/d/1cX6eauDrsGsLSZnqsX7cl0oiKX8Dzv5z/view?usp=sharing) | [log](https://github.com/Andrew-Qibin/VisionPermutator/blob/main/logs/vip_s7.log) | 26 | | EAMLP-19 | 55M | 464 img/s | 224 | 79.4% | | | 27 | | Mixer-B/16 | 59M | - | 224 | 78.5% | | | 28 | | ViP-Medium/7 (ours) | 55M | 418 img/s | 224 | 82.7% | [link](https://drive.google.com/file/d/15y5WMypthpbBFdc01E3mJCZit7q0Yn8m/view?usp=sharing) | [log](https://github.com/Andrew-Qibin/VisionPermutator/blob/main/logs/vip_m7.log) | 29 | | gMLP-B | 73M | - | 224 | 81.6% | | | 30 | | ResMLP-B24 | 116M | 231 img/s | 224 | 81.0% | | | 31 | | ViP-Large/7 (ours) | 88M | 298 img/s | 224 | 83.2% | [link](https://drive.google.com/file/d/14F5IXGXmB_3jrwK33Efae-WEb5D_G85c/view?usp=sharing) | [log](https://github.com/Andrew-Qibin/VisionPermutator/blob/main/logs/vip_L7.log) | 32 | 33 | Throughput is measured on a single machine with one V100 GPU (32GB) and a batch size of 32. 34 | 35 | Training ViP-Small/7 for 300 epochs on ImageNet takes less than 30 hours on a node with 8 A100 GPUs. 36 | 37 | ### Requirements 38 | 39 | ``` 40 | torch>=1.4.0 41 | torchvision>=0.5.0 42 | pyyaml 43 | timm==0.4.5 44 | apex if you use 'apex amp' 45 | ``` 46 | 47 | Data preparation: ImageNet with the following folder structure. You can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4). 48 | 49 | ``` 50 | │imagenet/ 51 | ├──train/ 52 | │ ├── n01440764 53 | │ │ ├── n01440764_10026.JPEG 54 | │ │ ├── n01440764_10027.JPEG 55 | │ │ ├── ...... 56 | │ ├── ...... 57 | ├──val/ 58 | │ ├── n01440764 59 | │ │ ├── ILSVRC2012_val_00000293.JPEG 60 | │ │ ├── ILSVRC2012_val_00002138.JPEG 61 | │ │ ├── ...... 62 | │ ├── ......
63 | ``` 64 | 65 | ### Validation 66 | Replace DATA_DIR with the path to your ImageNet validation set and MODEL_DIR with the checkpoint path. 67 | ``` 68 | CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint 69 | ``` 70 | 71 | ### Training 72 | 73 | Command line for training ViP-Small/7 on 8 V100 GPUs: 74 | ``` 75 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model vip_s7 -b 256 -j 8 --opt adamw --epochs 300 --sched cosine --apex-amp --img-size 224 --drop-path 0.1 --lr 2e-3 --weight-decay 0.05 --remode pixel --reprob 0.25 --aa rand-m9-mstd0.5-inc1 --smoothing 0.1 --mixup 0.8 --cutmix 1.0 --warmup-lr 1e-6 --warmup-epochs 20 76 | ``` 77 | 78 | 79 | ### Reference 80 | You may want to cite: 81 | ``` 82 | @article{hou2022vision, 83 | title={Vision permutator: A permutable mlp-like architecture for visual recognition}, 84 | author={Hou, Qibin and Jiang, Zihang and Yuan, Li and Cheng, Ming-Ming and Yan, Shuicheng and Feng, Jiashi}, 85 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 86 | year={2022}, 87 | publisher={IEEE} 88 | } 89 | ``` 90 | 91 | 92 | ### License 93 | This repository is released under the MIT License as found in the [LICENSE](LICENSE) file. Code in this repo is for non-commercial use only. 94 | -------------------------------------------------------------------------------- /distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC main.py "$@" 5 | 6 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | if [ ! $1 ]; 2 | then 3 | DATA_DIR=/path/to/imagenet/val 4 | else 5 | DATA_DIR="$1" 6 | fi 7 | if [ ! 
$2 ]; 8 | then 9 | MODEL_DIR=/path/to/checkpoint 10 | else 11 | MODEL_DIR="$2" 12 | fi 13 | python3 validate.py $DATA_DIR --model vip_s7 --checkpoint $MODEL_DIR/vip_s7.pth --no-test-pool --amp --img-size 224 -b 64 14 | 15 | -------------------------------------------------------------------------------- /logs/vip_L7.log: -------------------------------------------------------------------------------- 1 | epoch train_loss eval_loss eval_top1 eval_top5 2 | 0 6.913001299 6.857731318 0.4999999982 1.689999995 3 | 1 6.761281627 6.144085424 2.563999997 8.854000005 4 | 2 6.566323076 5.627922377 5.126000007 15.72399997 5 | 3 6.380618879 5.077777988 10.07400002 25.73200001 6 | 4 6.204351766 4.611636173 14.70800002 33.95399997 7 | 5 6.025766747 4.106199132 20.63799997 43.46599996 8 | 6 5.847623893 3.717910461 25.54600005 49.95000011 9 | 7 5.669392279 3.347409837 31.80600001 57.57600003 10 | 8 5.528334992 3.061046054 36.34800003 62.70599995 11 | 9 5.387733221 2.843697432 40.36200002 66.9140001 12 | 10 5.265279566 2.633993067 44.32599998 70.48200009 13 | 11 5.149167946 2.50467993 46.76200005 72.89 14 | 12 5.072068896 2.369067063 48.97600001 75.1800001 15 | 13 5.012365171 2.348375036 50.00200002 75.79800006 16 | 14 4.946704865 2.209346479 52.2639999 77.43000002 17 | 15 4.905805247 2.164400191 53.56600013 78.63800007 18 | 16 4.858599765 2.08581623 54.57199998 79.47599997 19 | 17 4.821058682 2.03694256 55.63000005 80.13399996 20 | 18 4.754349164 1.929476458 57.42200005 81.52200009 21 | 19 4.661197015 1.856204604 59.07400009 82.83800003 22 | 20 4.625498874 1.800021159 60.30800001 83.90200009 23 | 21 4.556209598 1.732825204 61.71799993 84.89400006 24 | 22 4.475701571 1.697331366 62.904 85.64600011 25 | 23 4.428713288 1.653596472 63.98799996 86.14000006 26 | 24 4.376069886 1.583566997 64.53599998 86.70800003 27 | 25 4.315484524 1.543828818 65.822 87.596 28 | 26 4.291441066 1.48466541 66.86400005 88.05200005 29 | 27 4.231677004 1.471131234 67.19599997 88.40400007 30 | 28 4.193298578 1.419725731 67.72399995 88.75199995 31 | 29 4.17669516 1.449831941 68.10999997 88.88999997 32 | 30 4.141682948 1.372756124 68.87000012 89.35800005 33 | 31 4.119020922 1.362842253 69.288 89.52999992 34 | 32 4.09656683 1.342648052 69.76600003 90.02800028 35 | 33 4.089684537 1.324026034 70.33799997 90.05800007 36 | 34 4.061986991 1.294390531 70.49000007 90.24000005 37 | 35 4.025167227 1.283889737 70.49000015 90.28800002 38 | 36 4.000577637 1.276931103 71.27000002 90.64400012 39 | 37 3.977683697 1.271825683 71.54400011 90.89799991 40 | 38 3.968775238 1.262204099 71.39400009 90.97800007 41 | 39 3.938168202 1.253261973 71.82800007 91.01399994 42 | 40 3.947777748 1.228888645 72.21200007 91.2120001 43 | 41 3.922934839 1.221194529 72.42599993 91.17600004 44 | 42 3.900672146 1.211756441 72.63199999 91.37200002 45 | 43 3.891432881 1.203859162 72.41400014 91.42400007 46 | 44 3.878281866 1.196791563 72.93999996 91.56400009 47 | 45 3.868952547 1.181129683 72.87000004 91.6680001 48 | 46 3.874262742 1.187774152 73.25999996 91.79200002 49 | 47 3.857880524 1.154163999 73.55600011 91.92999999 50 | 48 3.825644697 1.165146436 73.66400006 92.04800019 51 | 49 3.820438096 1.170657296 73.66399999 92.08000012 52 | 50 3.819287726 1.17037165 73.37400014 92.00999996 53 | 51 3.791288291 1.12613784 73.95799999 92.21400004 54 | 52 3.805008088 1.134835058 74.37000009 92.35399999 55 | 53 3.798599703 1.152775878 74.20000003 92.42200017 56 | 54 3.781377622 1.136407588 74.25199996 92.42599999 57 | 55 3.759510875 1.125298276 74.27400006 92.48200009 58 | 56 3.776419895 
1.096317926 74.49400004 92.57200019 59 | 57 3.767340081 1.125865166 74.66000004 92.5220001 60 | 58 3.754183054 1.114121914 74.60200014 92.63200012 61 | 59 3.762281026 1.113626309 74.85400007 92.75800017 62 | 60 3.73144269 1.110204132 74.65599998 92.69000004 63 | 61 3.737390927 1.124995146 74.85600001 92.67800012 64 | 62 3.707579664 1.064523263 75.06999999 92.73400004 65 | 63 3.717393058 1.117569078 74.978 92.79800009 66 | 64 3.711606213 1.107795517 75.28200011 92.91400006 67 | 65 3.717213852 1.086564629 75.11399998 92.76000006 68 | 66 3.702746834 1.082565982 75.22199993 92.96800001 69 | 67 3.695250835 1.065462854 75.37800008 93.13000014 70 | 68 3.692532216 1.076244655 75.47000009 92.90800014 71 | 69 3.694451588 1.075827158 75.72399993 93.11000001 72 | 70 3.690761907 1.06159098 75.522 93.15200007 73 | 71 3.66839041 1.066271703 75.50200011 93.09800004 74 | 72 3.672531298 1.056537478 75.67599996 93.16600004 75 | 73 3.650957993 1.075173085 75.69800004 93.07200006 76 | 74 3.663316539 1.073580273 75.68199996 93.17600001 77 | 75 3.663593786 1.052213494 75.58199993 93.21800011 78 | 76 3.628347175 1.03390845 76.20800011 93.44800009 79 | 77 3.640966398 1.052992863 75.99200003 93.45799996 80 | 78 3.646266154 1.043381317 76.18599998 93.33600009 81 | 79 3.63856663 1.047625149 76.07600017 93.32400001 82 | 80 3.630207624 1.023744758 76.30999988 93.45999998 83 | 81 3.638915079 1.042934901 76.34599998 93.47799988 84 | 82 3.607032367 1.044576016 76.14800003 93.45600001 85 | 83 3.633141313 1.036485971 76.20200011 93.40400007 86 | 84 3.611221535 1.050900771 76.29400006 93.50400011 87 | 85 3.616281237 1.024137265 76.50400006 93.54400022 88 | 86 3.609388522 1.028516903 76.7759999 93.58800009 89 | 87 3.595447625 1.011578943 76.63000009 93.63200004 90 | 88 3.579950503 1.028449685 76.82800006 93.63199999 91 | 89 3.58230991 1.013791361 76.61599993 93.62600006 92 | 90 3.590382729 1.020616848 76.84000008 93.71199993 93 | 91 3.588954909 0.9983490827 76.99399998 93.89000011 94 | 92 3.586752534 0.9968194655 76.89799998 93.88200017 95 | 93 3.580000911 0.9989759994 76.73000008 93.85600017 96 | 94 3.578679323 1.005947356 76.96400016 93.91200014 97 | 95 3.550036158 1.00907395 76.74800006 93.85600009 98 | 96 3.558552878 1.015046765 77.05399995 93.97199996 99 | 97 3.556361556 1.030329271 77.10800005 93.78800006 100 | 98 3.577016217 1.016495142 77.06799993 94.05199996 101 | 99 3.548931241 0.983356522 77.26000011 94.03199998 102 | 100 3.539836015 0.9986022405 77.11400013 93.91600011 103 | 101 3.564150453 1.00415012 77.32000006 93.91800001 104 | 102 3.536392842 1.001878855 77.18400009 94.00000006 105 | 103 3.538246717 1.014467212 77.19800008 93.96800004 106 | 104 3.523329241 1.001179792 77.39399992 94.13600003 107 | 105 3.529150265 0.9773992533 77.36400013 94.07000006 108 | 106 3.509649992 0.9851403767 77.552 94.09399993 109 | 107 3.516202126 0.9971003963 77.32599995 94.17000004 110 | 108 3.530516454 0.9734222259 77.85 94.13800009 111 | 109 3.519188728 0.9796550668 77.57399998 94.16999988 112 | 110 3.497916988 0.9668980101 77.86199982 94.23999996 113 | 111 3.499300906 0.9840671982 77.89 94.22800001 114 | 112 3.511266436 0.9770733947 77.88800008 94.22199998 115 | 113 3.495403273 0.9679898705 77.68600001 94.18399996 116 | 114 3.512358086 0.992894538 77.77600001 94.37000006 117 | 115 3.498695595 0.9589503782 77.65999992 94.20399996 118 | 116 3.473399912 0.9680331292 77.74800006 94.32000001 119 | 117 3.506208028 0.9782849104 77.7519999 94.32400006 120 | 118 3.465959498 0.9809470446 77.77600006 94.22199996 121 | 119 3.454306926 
0.9729351379 77.93600008 94.24800004 122 | 120 3.483699526 0.9714899434 78.06400003 94.33999993 123 | 121 3.476371356 0.9535773982 78.14400005 94.39800009 124 | 122 3.484441025 0.9761949054 78.08600013 94.42800009 125 | 123 3.445943952 0.9710110314 78.262 94.35799999 126 | 124 3.454047578 0.95120591 78.26800008 94.48600006 127 | 125 3.441857713 0.9487121837 78.26800008 94.40000004 128 | 126 3.451743586 0.9553972806 78.30799993 94.59199996 129 | 127 3.447823559 0.9553032318 78.21400001 94.43199986 130 | 128 3.436184032 0.9447731223 78.53800013 94.45600004 131 | 129 3.434735315 0.9190254749 78.58399998 94.57800003 132 | 130 3.417208638 0.9375618962 78.63400006 94.66000014 133 | 131 3.436141474 0.9513130292 78.45400013 94.66800001 134 | 132 3.406301022 0.9262450643 78.50200019 94.56799996 135 | 133 3.422926358 0.9423015359 78.78399996 94.68000006 136 | 134 3.406378508 0.9287419548 78.71800011 94.67800006 137 | 135 3.399285538 0.9205029363 78.75800023 94.85199996 138 | 136 3.393400192 0.9316855366 78.81000008 94.69600006 139 | 137 3.399185147 0.9497822449 78.87 94.78600004 140 | 138 3.391925335 0.9128809229 79.01200021 94.87000006 141 | 139 3.397752319 0.9146815884 78.98199995 94.74200001 142 | 140 3.38273532 0.9131451086 78.9600001 94.85200001 143 | 141 3.3854396 0.9181603865 79.044 94.78800006 144 | 142 3.367522717 0.928578095 79.11800005 94.9039999 145 | 143 3.377413699 0.9058879134 79.29399998 95.01600003 146 | 144 3.391084961 0.9179185202 79.0420001 94.79000003 147 | 145 3.352251904 0.9165119128 79.17400013 95.02600003 148 | 146 3.3656411 0.9264486267 79.29599998 94.96600009 149 | 147 3.346459423 0.925475686 79.31400011 95.08800003 150 | 148 3.345809289 0.9088662953 79.17800016 95.01399996 151 | 149 3.364534548 0.9240142171 79.49800013 95.03600009 152 | 150 3.337129559 0.8910718281 79.55400021 95.13400009 153 | 151 3.306680151 0.9141196345 79.6240001 94.99200006 154 | 152 3.322871804 0.9147890906 79.39399992 95.02600001 155 | 153 3.31878124 0.902114486 79.49600011 95.18600006 156 | 154 3.324711323 0.883527628 79.64800005 95.15800014 157 | 155 3.299105406 0.8955047981 79.57000011 95.08599996 158 | 156 3.317340425 0.8943205843 79.57600008 95.14200001 159 | 157 3.31820568 0.9037213893 79.7639999 95.10400006 160 | 158 3.285464134 0.8977480441 79.76600014 95.24000001 161 | 159 3.299160736 0.9142937015 79.80399998 95.13399999 162 | 160 3.281858615 0.8935115352 79.78200006 95.22599996 163 | 161 3.293683767 0.8958488972 79.67000009 95.14399998 164 | 162 3.288966775 0.8780349673 80.06400003 95.28199996 165 | 163 3.285949128 0.8947712902 80.01400016 95.19800009 166 | 164 3.276713899 0.8856335512 80.04600016 95.22200006 167 | 165 3.260910102 0.8872810849 79.9779999 95.22799996 168 | 166 3.281882593 0.8737985945 80.13800011 95.35800001 169 | 167 3.270663108 0.8740480864 80.14400003 95.42600003 170 | 168 3.263766766 0.8706478661 80.3220001 95.41599993 171 | 169 3.259786367 0.8543372419 80.21800001 95.47800001 172 | 170 3.262291346 0.8690291272 80.25200016 95.45799988 173 | 171 3.238998107 0.8638744375 80.22800003 95.43400001 174 | 172 3.233007669 0.8725711017 80.17800011 95.36999996 175 | 173 3.238121663 0.8822985034 80.13599995 95.45399998 176 | 174 3.236377239 0.8599990455 80.56400005 95.3519999 177 | 175 3.238717301 0.8410555789 80.6099999 95.54400006 178 | 176 3.222300512 0.8636243329 80.51800016 95.49199996 179 | 177 3.20611872 0.8510008482 80.62400008 95.47799993 180 | 178 3.209307517 0.8598693686 80.63199998 95.53999993 181 | 179 3.191443767 0.8585887837 80.80599998 95.55200001 182 | 180 
3.189323187 0.8487512378 80.55600013 95.53000006 183 | 181 3.192659276 0.8425362645 80.75400011 95.58400001 184 | 182 3.179363506 0.8526985534 80.60400003 95.5639999 185 | 183 3.18301441 0.8375100801 80.662 95.59999988 186 | 184 3.165303758 0.8594995 80.76400016 95.60400006 187 | 185 3.178292871 0.8481940637 80.71600016 95.52999998 188 | 186 3.172219975 0.8419358005 80.95400005 95.63200003 189 | 187 3.168930956 0.8459458539 80.94400008 95.65000009 190 | 188 3.157896008 0.8394384903 81.09599995 95.81400001 191 | 189 3.154213088 0.8263042637 81.08399998 95.79199996 192 | 190 3.172243851 0.828965089 81.26000018 95.78200003 193 | 191 3.148580909 0.841667545 81.00400006 95.65199996 194 | 192 3.157555495 0.8479487026 81.162 95.58600004 195 | 193 3.132151689 0.8412623523 80.91200005 95.59199993 196 | 194 3.136727486 0.8432483573 81.12599998 95.7279999 197 | 195 3.146040337 0.8273324599 81.49399992 95.68200022 198 | 196 3.116257651 0.8292310741 81.2660001 95.71200001 199 | 197 3.110119956 0.825994346 81.49400018 95.81200003 200 | 198 3.099743775 0.8264384639 81.26600006 95.79600001 201 | 199 3.100607872 0.8214785859 81.49000019 95.8919999 202 | 200 3.111358796 0.8342787673 81.26000006 95.79000001 203 | 201 3.106865798 0.830161839 81.44600002 95.81200009 204 | 202 3.087085145 0.8274200458 81.28200013 95.84400009 205 | 203 3.101616332 0.8088009297 81.56199997 95.97600003 206 | 204 3.092041884 0.8356901847 81.524 95.8779999 207 | 205 3.067691531 0.8132908139 81.64200018 95.93200001 208 | 206 3.051748054 0.8181383617 81.71599997 95.89600001 209 | 207 3.059259721 0.8168480108 81.75600008 95.85400006 210 | 208 3.048679352 0.805407403 81.80599992 95.97000006 211 | 209 3.051528692 0.8181106576 81.69000002 95.97599996 212 | 210 3.038665329 0.8124166599 81.66400011 96.01600003 213 | 211 3.034575292 0.8130917788 81.90000007 95.98999993 214 | 212 3.027209742 0.8060897728 81.83600021 95.93599993 215 | 213 3.034819484 0.82380004 81.85600002 95.99800003 216 | 214 3.00713035 0.812400895 81.966 96.11999998 217 | 215 3.017175095 0.8133901038 82.03600005 96.08399993 218 | 216 3.022750837 0.8086925764 82.00000011 95.97400003 219 | 217 3.012305856 0.8125396531 81.98000016 95.98800011 220 | 218 3.000207118 0.8148258285 82.02400005 96.00399998 221 | 219 2.996970688 0.7964179638 82.09400018 96.07599993 222 | 220 3.001877495 0.8124238705 82.07400002 96.06800009 223 | 221 2.982086931 0.8088283077 82.03999998 96.08400006 224 | 222 2.962717124 0.8138080914 82.10599995 96.07199998 225 | 223 2.952983022 0.8052984554 82.25000008 96.12400011 226 | 224 2.973296574 0.8040889742 82.07599998 96.05399993 227 | 225 2.935394764 0.7998646548 82.14599997 96.08800001 228 | 226 2.957448329 0.8063979068 82.2560001 96.22199993 229 | 227 2.944312147 0.806794986 82.19400016 96.13000006 230 | 228 2.947030595 0.8153845561 82.08800005 95.9879999 231 | 229 2.96601132 0.7976170387 82.30400016 96.12800009 232 | 230 2.928501589 0.8040821841 82.33599997 96.13600016 233 | 231 2.913145576 0.8071104817 82.26800006 96.13800014 234 | 232 2.915285724 0.8005180933 82.38799998 96.1819999 235 | 233 2.907242026 0.8082498322 82.47000008 96.13600003 236 | 234 2.917316352 0.7949280845 82.51800005 96.1919999 237 | 235 2.894564016 0.8044481371 82.46399998 96.2779999 238 | 236 2.90357089 0.7868865868 82.59400011 96.14200006 239 | 237 2.893367427 0.7910678092 82.51999995 96.19999993 240 | 238 2.871845739 0.7949643531 82.62400005 96.33599998 241 | 239 2.861688461 0.7864959856 82.55000005 96.18600003 242 | 240 2.896303024 0.7947511209 82.58000023 96.21600001 243 | 241 
2.86497329 0.7910422353 82.57400013 96.22199993 244 | 242 2.865310975 0.7945556509 82.5380001 96.24399998 245 | 243 2.863366434 0.7950671361 82.6460001 96.19800001 246 | 244 2.854070016 0.7968172623 82.62400006 96.19999998 247 | 245 2.847989508 0.7941819822 82.7440001 96.26400001 248 | 246 2.83945666 0.7977362152 82.59200002 96.21999998 249 | 247 2.832026737 0.7822571631 82.76600016 96.27600001 250 | 248 2.816397088 0.7954147977 82.64800026 96.28000003 251 | 249 2.816217218 0.7972653251 82.77400003 96.22799996 252 | 250 2.824791108 0.7949384631 82.64999998 96.29599998 253 | 251 2.823545456 0.7903238153 82.67600011 96.24600001 254 | 252 2.812996524 0.7974804704 82.71600003 96.25800001 255 | 253 2.812445027 0.796391237 82.80200016 96.23399998 256 | 254 2.803947568 0.786914629 82.71400013 96.24200006 257 | 255 2.793214032 0.7928770452 82.72200008 96.31799998 258 | 256 2.793913671 0.7915947765 82.86600018 96.35799993 259 | 257 2.768638185 0.7943645712 82.88400008 96.23400011 260 | 258 2.795628888 0.7880989613 82.88200018 96.27600009 261 | 259 2.77637262 0.7949189232 82.886 96.22600001 262 | 260 2.769878864 0.781623648 82.854 96.36799993 263 | 261 2.759905015 0.7830648042 82.9759999 96.32600009 264 | 262 2.763867208 0.794888597 82.922 96.37400009 265 | 263 2.757486463 0.7816529967 83.05000005 96.31400009 266 | 264 2.76139588 0.7776610531 83.10600005 96.34600001 267 | 265 2.754269889 0.7904224434 82.98200003 96.24400003 268 | 266 2.764200091 0.7818329756 82.95800011 96.23400001 269 | 267 2.754135898 0.7877558433 83.00799992 96.29399998 270 | 268 2.750979287 0.7811758689 82.97799998 96.34399998 271 | 269 2.726730347 0.7824476216 83.0320001 96.33000016 272 | 270 2.728987813 0.7851066397 83.12000008 96.32600003 273 | 271 2.727947882 0.7869710286 83.04000006 96.31200003 274 | 272 2.73354716 0.7826401623 82.98999998 96.26600009 275 | 273 2.720023768 0.7809933559 83.02400005 96.31999998 276 | 274 2.723090223 0.7858375484 83.01800011 96.29000001 277 | 275 2.714487995 0.7824015975 82.99800016 96.30200011 278 | 276 2.724763342 0.780421724 83.07000016 96.30000009 279 | 277 2.719096984 0.7835913246 83.18000003 96.27800001 280 | 278 2.706166182 0.7793887179 83.14800008 96.27400001 281 | 279 2.708738497 0.7816188398 83.07000011 96.26000001 282 | 280 2.709961108 0.783480232 83.09400005 96.31199993 283 | 281 2.701489329 0.7784310701 83.17400008 96.31800003 284 | 282 2.691452844 0.780681438 83.13000008 96.30200003 285 | 283 2.701883265 0.7784053961 83.03600006 96.32000003 286 | 284 2.71519053 0.7785520374 83.10800016 96.34800003 287 | 285 2.687014171 0.7781087155 83.14599995 96.34200009 288 | 286 2.696395142 0.7760199122 83.09800029 96.34200009 289 | 287 2.692885195 0.7785685784 83.10200011 96.32200009 290 | 288 2.68206346 0.7783894224 83.07200011 96.34400001 291 | 289 2.697129505 0.7761118614 83.124 96.34600006 292 | 290 2.671725699 0.7748139445 83.20399995 96.36200011 293 | 291 2.683646781 0.7776425358 83.11800023 96.34000006 294 | 292 2.683768323 0.779488044 83.08000013 96.36800009 295 | 293 2.68746531 0.7784028229 83.06600005 96.36200001 296 | 294 2.680313723 0.7764808601 83.07600013 96.37000006 297 | 295 2.682268364 0.7779367533 83.08400013 96.31200009 298 | 296 2.683300921 0.7768416742 83.05800013 96.32600009 299 | 297 2.659034457 0.7780973908 83.06399998 96.29800009 300 | 298 2.677618316 0.7755918687 83.06800008 96.28199998 301 | 299 2.693949223 0.778153858 83.12400008 96.34600001 302 | 300 2.667474576 0.7783629191 83.06400016 96.30200009 303 | 301 2.667407445 0.7755143827 83.12000016 96.39000006 304 | 
302 2.680549792 0.7777911545 83.06800005 96.41400009 305 | 303 2.667167885 0.7809245298 83.0740001 96.34000009 306 | 304 2.671986699 0.7787482977 83.13400013 96.36000009 307 | 305 2.665681907 0.7767296701 83.11200016 96.33800006 308 | 306 2.675118361 0.7812005786 83.09800005 96.29800006 309 | 307 2.674164465 0.7797712204 83.09 96.34800009 310 | 308 2.686720014 0.7786456819 83.09 96.33400009 311 | 309 2.685535261 0.7806094913 83.076 96.33000001 312 | -------------------------------------------------------------------------------- /logs/vip_m7.log: -------------------------------------------------------------------------------- 1 | epoch,train_loss,eval_loss,eval_top1,eval_top5 2 | 0,6.90932880129133,6.854047560424805,0.39599999786376955,1.7059999972534179 3 | 1,6.702896015984671,5.963005331573486,3.3299999923706056,11.133999990234376 4 | 2,6.456992932728359,5.390907523193359,7.442000018920899,20.434000021972658 5 | 3,6.23755533354623,4.737257749633789,13.332000017089843,31.90200005859375 6 | 4,6.0225352219172885,4.206198528442383,19.956000020751954,41.960000043945314 7 | 5,5.804148435592651,3.7243029332733153,25.99600001220703,50.56000003662109 8 | 6,5.613142762865339,3.3396494134521486,32.04399998535156,57.703999985351565 9 | 7,5.417514324188232,2.970649938354492,37.264,63.60399993652344 10 | 8,5.27363201550075,2.7333306427001953,41.982000034179684,68.48799994140624 11 | 9,5.093750170298985,2.525843430709839,45.62599998535156,72.02600009277344 12 | 10,4.980165481567383,2.352075763015747,48.75000004150391,74.76800009765626 13 | 11,4.872818742479597,2.2027560341644286,51.57999998535156,77.1940000732422 14 | 12,4.8050375665937155,2.1824634374618532,52.47600003173828,77.90800004150391 15 | 13,4.771648270743234,2.0674364700317382,54.08,79.27600001953125 16 | 14,4.719101973942348,2.0189061892700195,56.07400005615234,80.50400001708984 17 | 15,4.677061762128558,1.9760460157775879,56.84999999755859,81.0960000390625 18 | 16,4.624162673950195,1.8888648011398315,57.492000048828125,81.84400017089844 19 | 17,4.598656790597098,1.9013685061645509,58.242000068359374,82.04000006591797 20 | 18,4.555602993283953,1.8495422927093506,58.995999990234374,82.65799993164063 21 | 19,4.478095361164638,1.7710593632507323,60.964000046386715,84.16600003417969 22 | 20,4.428852898733957,1.7178370823287963,61.39600013671875,84.69800000732423 23 | 21,4.361710923058646,1.673677162399292,63.08000003417969,85.63999998046874 24 | 22,4.2959471089499335,1.6123182100868225,64.18200008300781,86.37200002929687 25 | 23,4.249774047306606,1.5636786961364746,65.17399998291016,86.96200002685546 26 | 24,4.1908472435815,1.546017219772339,65.49799997558594,87.28400000488281 27 | 25,4.147025789533343,1.4960719562721252,66.3779999951172,87.85599994873047 28 | 26,4.12639057636261,1.453491238899231,67.12000004394531,88.29800001953124 29 | 27,4.082164934703282,1.4732326899147035,67.67400009277344,88.50199989257813 30 | 28,4.050461070878165,1.402883006439209,67.8819999975586,88.77000001953125 31 | 29,4.032754489353725,1.3888416222953797,68.64999999023438,89.02800014648437 32 | 30,4.010578036308289,1.3422764243507386,69.48000004638672,89.41399991699218 33 | 31,3.9900671584265575,1.382127080154419,69.38799999511718,89.56800004394532 34 | 32,3.945267013141087,1.3416419777488708,69.75800001708984,89.91200007080079 35 | 33,3.9570285422461375,1.3223409216308595,70.17400008789062,90.03200001464843 36 | 34,3.9207507371902466,1.2980675039482117,70.47799997070312,90.26000017578124 37 | 
35,3.9023950610842024,1.28581027053833,70.73199996826172,90.18999997070313 38 | 36,3.877046755381993,1.2640974124526978,70.87399999267578,90.54800006835937 39 | 37,3.8551083973475864,1.2406789484405518,71.46400004150391,90.7480000415039 40 | 38,3.8666552816118513,1.2793489054107665,71.42199999023437,90.80799993896484 41 | 39,3.8173076936176846,1.2510529967308044,71.72000004638672,91.14800014892577 42 | 40,3.820861577987671,1.246612049369812,71.96800005126953,91.09200001953126 43 | 41,3.8092151199068343,1.220611818256378,72.24200001220703,91.14200022460938 44 | 42,3.807996835027422,1.2209717053985596,72.3039999609375,91.19200009277344 45 | 43,3.786864791597639,1.2195191497612,72.05600006347656,91.30799991210938 46 | 44,3.776286482810974,1.1800591128349305,72.82599998535156,91.59800009521484 47 | 45,3.7523889541625977,1.1763256419372559,73.25400009277344,91.64200009277344 48 | 46,3.7647524561200822,1.1948207719039916,72.71200003417968,91.47000001953126 49 | 47,3.7585670948028564,1.1500923590850831,73.23200006103515,91.8660001171875 50 | 48,3.7282666819436208,1.1795096869468689,73.0160001171875,91.80399996582031 51 | 49,3.7280845642089844,1.1735346340179444,73.24000001464844,91.75799998779297 52 | 50,3.7148827654974803,1.1504344519424439,73.2860001171875,91.69200009277344 53 | 51,3.6983796698706493,1.1542318394088744,73.38000000976562,91.89800001464843 54 | 52,3.7128142629350935,1.1372051441001891,73.98200001464843,92.0700000390625 55 | 53,3.711761474609375,1.1208502324104308,73.62000000976562,92.20600017089843 56 | 54,3.6802489246640886,1.1334939222717284,73.80600014648438,92.19399993652344 57 | 55,3.675034999847412,1.1469725336456298,73.78800001220704,92.34399999023438 58 | 56,3.6781928539276123,1.1121433991622924,74.17799995849609,92.4180000390625 59 | 57,3.660098212105887,1.1373679025650025,74.00400014160157,92.53000009277343 60 | 58,3.6428002970559255,1.0964150991439818,74.12400006591797,92.35600006835938 61 | 59,3.657994900430952,1.106780310974121,74.18400012207032,92.5460000390625 62 | 60,3.6275060176849365,1.109575862865448,74.38200000976562,92.4419999609375 63 | 61,3.6442099639347623,1.1107771522521972,74.26799998779296,92.4499999609375 64 | 62,3.628208347729274,1.0990574090576173,74.56200011230469,92.6280000390625 65 | 63,3.6321606125150407,1.1161306171035767,74.32399993164063,92.42800009277343 66 | 64,3.6263176713671004,1.1156763362693787,74.61999995605468,92.55600006835938 67 | 65,3.639007943017142,1.1101867294692993,74.74999990722657,92.64800006347656 68 | 66,3.622387102672032,1.097434368610382,74.71200003662109,92.60599998535156 69 | 67,3.6087682247161865,1.082902463645935,74.68800001220703,92.85999985839844 70 | 68,3.6143094641821727,1.1075475700569153,74.96400006103515,92.59599993652344 71 | 69,3.588876928601946,1.07731548122406,74.86200011230468,92.82400006347656 72 | 70,3.59721725327628,1.082208607635498,75.08400008300781,92.83200006591797 73 | 71,3.5702612740652904,1.0813466676330565,75.22600014160156,92.93600017089844 74 | 72,3.57477377142225,1.0802469022750854,75.17000010986328,92.92000006347656 75 | 73,3.5740824597222463,1.1001569610404969,75.20399995605469,92.79800009277344 76 | 74,3.5722209555762157,1.107181809539795,75.42400008789062,92.91600009033203 77 | 75,3.5605866398130144,1.0766574368667603,75.29400000732421,92.86000009033204 78 | 76,3.567481211253575,1.0460495560455323,75.61399990722656,93.07000006835938 79 | 77,3.5636866433279857,1.0638266690063476,75.60200003173829,93.1480001171875 80 | 
78,3.5676473549434116,1.0710258788108826,75.50400000732422,93.08800006347656 81 | 79,3.562150171824864,1.050010140953064,75.55199987792969,93.15199998535157 82 | 80,3.5466181550707137,1.0560595837783813,75.48600009277344,93.0800000390625 83 | 81,3.550299882888794,1.0575691982269286,75.89000003417969,93.1960001171875 84 | 82,3.5317522627966746,1.063086628189087,75.63800005859375,93.21400006835937 85 | 83,3.5324534688677107,1.031112244129181,76.054,93.3140000366211 86 | 84,3.5365114552634105,1.0493337981033326,75.80600005859375,93.1480000390625 87 | 85,3.540239487375532,1.0531185172271729,75.8260000390625,93.2240000390625 88 | 86,3.5424277101244246,1.04272143989563,76.15600003417968,93.2040000390625 89 | 87,3.5174287046704973,1.0542635342597961,75.87999995117187,93.29800014160156 90 | 88,3.5076360191617693,1.006296633834839,76.20000005859374,93.4240001171875 91 | 89,3.4970589876174927,1.009729640712738,76.19800008789062,93.48800004150391 92 | 90,3.4958191769463673,1.0283534535980225,76.15600003417968,93.4300000390625 93 | 91,3.4994895117623463,1.0401119533348084,76.05000016601562,93.38000001464843 94 | 92,3.514454194477626,1.0369789333724975,76.15800005859376,93.45400001220703 95 | 93,3.5095438276018416,1.0178065168380737,76.46000008789062,93.43000021972657 96 | 94,3.503518189702715,1.018520607433319,76.36000002929687,93.57000005859375 97 | 95,3.489128657749721,1.0121382236480714,76.1840000366211,93.56200009277343 98 | 96,3.4821067367281233,1.0147426868247986,76.47200005859375,93.49000006591797 99 | 97,3.505060519490923,1.011701576385498,76.61799992675782,93.6179999609375 100 | 98,3.48980826990945,1.0290394330215453,76.38800002929688,93.55000008789062 101 | 99,3.472050428390503,0.9971770275497437,76.44999998046875,93.66000008789062 102 | 100,3.475861600467137,1.006145211830139,76.54799995605468,93.67599995849609 103 | 101,3.474380578313555,1.001423595046997,76.85800008300781,93.80800011230468 104 | 102,3.4643759386880055,1.0055016045379639,76.77800013183594,93.86999998535157 105 | 103,3.463895014354161,1.0105851085853577,76.61199997558593,93.69600013671875 106 | 104,3.465873956680298,1.0173472805786132,76.5080000024414,93.61800001464844 107 | 105,3.4485316957746233,0.9927911795425415,76.82800003173828,93.64600000976563 108 | 106,3.429732952799116,0.9994239200210572,76.8820000024414,93.8700000366211 109 | 107,3.4361068861825124,1.0036898034286499,76.92399995117188,93.85400011230469 110 | 108,3.4474459886550903,1.0013488720321655,76.93000006347657,93.82000006347656 111 | 109,3.425798807825361,1.0093399936103822,76.920000078125,93.75800006103516 112 | 110,3.4449102708271573,1.002862581615448,76.82599990722656,93.75399998535157 113 | 111,3.4363930565970287,0.9945665563964844,76.8660000048828,93.8219999609375 114 | 112,3.4327633551188876,0.9896794754791259,77.05600002685547,93.91400006347656 115 | 113,3.4252477203096663,0.9945056452941895,76.99400000732422,93.9940001171875 116 | 114,3.415156807218279,0.9790984543800354,77.22200016113281,93.91600000976563 117 | 115,3.4009200675146922,0.9721055029678345,77.36000000732422,93.93000009033203 118 | 116,3.4166004146848405,0.9908977197647095,77.48999997558593,94.12600005859375 119 | 117,3.413839731897627,0.9708467753410339,77.56800005859375,94.12600009033203 120 | 118,3.3895688567842757,0.9621318748855591,77.42999998535156,94.12800016601562 121 | 119,3.3959894009998868,0.989285439376831,77.52000010986328,94.11200006347656 122 | 120,3.3995001145771573,0.9688023084831238,77.40999993164063,94.14400006347657 123 | 
121,3.3860636608941213,0.9830161957550049,77.58399998046875,94.0799999609375 124 | 122,3.3901074954441617,0.9809874724769593,77.53799998291015,94.0319999609375 125 | 123,3.367209025791713,0.971336136264801,77.7360000024414,94.17000008789063 126 | 124,3.37071887084416,0.9797874834442138,77.47199998046875,94.09600003417968 127 | 125,3.3855893101011003,0.9576681258583069,77.68600000732422,94.18600008789062 128 | 126,3.3691785505839755,0.9638239769744873,77.58200016113281,94.27400016601563 129 | 127,3.388491681643895,0.973994352684021,77.6920001586914,94.15000011474609 130 | 128,3.3621790749686107,0.9563417432022094,78.11400000488281,94.28600008789063 131 | 129,3.3581559487751553,0.968416593132019,77.83600018554688,94.20800006347656 132 | 130,3.360665134021214,0.9561886390113831,77.82800010986328,94.2459999609375 133 | 131,3.3736249719347273,0.9547506435012817,78.01799987548829,94.30400021484375 134 | 132,3.3275297369275774,0.9331438193130493,78.11800013671875,94.35200008789063 135 | 133,3.358065162386213,0.9460947372817993,77.79200000488281,94.17800003417969 136 | 134,3.3286341769354686,0.9399733685684204,78.11200002929688,94.37000008789063 137 | 135,3.3420989343098233,0.9497561868858337,77.99200018066406,94.40400000976562 138 | 136,3.33011840070997,0.9468776351356506,78.04200000488281,94.44800000976562 139 | 137,3.342316440173558,0.9319989150047302,78.07000010742188,94.3620001147461 140 | 138,3.3261644499642506,0.9220845895385742,78.48600002929687,94.61600006347656 141 | 139,3.31408052785056,0.9443053029251098,78.44999989746094,94.54200006347656 142 | 140,3.3290084430149625,0.910553692932129,78.36200005859375,94.57400001220704 143 | 141,3.321491684232439,0.9295531247329711,78.34200003173828,94.50600006347656 144 | 142,3.291707617895944,0.9368024371337891,78.30400000488281,94.52599990722656 145 | 143,3.3070097310202464,0.9397778823471069,78.41599997558593,94.63200006103516 146 | 144,3.323492339679173,0.9207230498886109,78.61600005615234,94.64199995605469 147 | 145,3.2997556584221974,0.9412972664833069,78.41000000244141,94.60200003417968 148 | 146,3.295029503958566,0.9245252856063843,78.53800010498047,94.60200000976562 149 | 147,3.2854332412992204,0.9361744274520875,78.61400002929688,94.69000006103515 150 | 148,3.2805512292044505,0.8999760466766358,78.74400008300782,94.7079999609375 151 | 149,3.285139186041696,0.9206268852996826,78.8680000024414,94.70600000976563 152 | 150,3.2906429767608643,0.9095589679908752,78.71800003417968,94.68600003417968 153 | 151,3.284801040376936,0.9345551473426819,78.69400000488281,94.67800013671875 154 | 152,3.2709765434265137,0.9270316875267028,78.69400008300781,94.71800005859374 155 | 153,3.234380074909755,0.9122170468902588,78.69199998535156,94.76200009277343 156 | 154,3.2663168736866544,0.9096678012275696,78.75599998046874,94.88000000976562 157 | 155,3.2503501176834106,0.8971937502479553,79.13000005615234,94.85400003662109 158 | 156,3.245058536529541,0.9038590211868286,79.03200003173828,94.85000006347656 159 | 157,3.241694348199027,0.9161998336791992,78.88000008300781,94.82400014160156 160 | 158,3.232398816517421,0.8999462019729614,79.22600008300782,94.97000008544921 161 | 159,3.237294282232012,0.8863958678436279,79.28199989990235,94.94199998291016 162 | 160,3.233746750014169,0.8991830952835083,79.30999992675781,94.93200008789063 163 | 161,3.239884240286691,0.9088663159942627,79.18800005371094,95.02000016113281 164 | 162,3.224090184484209,0.893155241394043,79.09600000488281,94.92800011230469 165 | 
163,3.217820508139474,0.8919475843811036,79.28400000488281,94.98600000976562 166 | 164,3.215689471789769,0.8882910530471801,79.59600008056641,95.10600000976562 167 | 165,3.2048411539622714,0.8880675336647034,79.31200002929687,94.98400000732421 168 | 166,3.1993136405944824,0.8894321461105347,79.36400012695313,95.01200005859376 169 | 167,3.2110706908362254,0.8932610053253174,79.49199995117188,94.97600011230469 170 | 168,3.208358781678336,0.8874205178070068,79.89799995117187,95.09200000976563 171 | 169,3.206747123173305,0.8836066618156433,79.55600003173828,95.11999990722656 172 | 170,3.1907899379730225,0.8707646526527405,79.744000078125,95.21400006347656 173 | 171,3.1834730420793806,0.8908653679466247,79.73999997314453,95.03400006103516 174 | 172,3.1747579063688005,0.8915778890037537,79.64999995117188,95.06600011230469 175 | 173,3.170030117034912,0.8718975839233398,79.91600002685547,95.13600010986328 176 | 174,3.1720229046685353,0.8816528468322754,79.785999921875,95.15800008789063 177 | 175,3.171130452837263,0.8661529981994629,79.83800005126953,95.30800013671875 178 | 176,3.1641653265271867,0.888923968963623,79.91600013671875,95.11000008789063 179 | 177,3.160358820642744,0.8699007495689393,79.85199987304688,95.22599995849609 180 | 178,3.1526336329323903,0.866831381778717,79.92000004882813,95.31200003173828 181 | 179,3.1281084503446306,0.8654502025222778,80.01599992675781,95.37000008544922 182 | 180,3.1415669407163347,0.8838525032043457,79.84799997558594,95.2800000830078 183 | 181,3.1517248323985507,0.8630585159301758,79.8979999975586,95.23400003417969 184 | 182,3.1488796302250455,0.8505402045059204,80.18800002929687,95.44200013671875 185 | 183,3.1292964730943953,0.8766146088027954,80.00800013183594,95.20199998535156 186 | 184,3.110183664730617,0.8738239919471741,80.06600010253906,95.42000018554687 187 | 185,3.1210454532078336,0.8683562798690796,80.00800005371094,95.3040000341797 188 | 186,3.1250904968806674,0.8596588902854919,80.34000007568359,95.52800013671875 189 | 187,3.117073127201625,0.8575765120887756,80.218000078125,95.31600005859374 190 | 188,3.1005845240184238,0.8571908009338379,80.39800002685547,95.39200003417969 191 | 189,3.1040914910180226,0.8538576941299438,80.39000000488281,95.40399993164063 192 | 190,3.0914864369801114,0.8639280267333984,80.350000078125,95.42400016357422 193 | 191,3.086129375866481,0.8557544534683228,80.51200002929687,95.33599995605469 194 | 192,3.0807894808905467,0.8370084446334839,80.28800005859375,95.37800006347656 195 | 193,3.07554863180433,0.8622823519325257,80.448,95.39600019042969 196 | 194,3.05668682711465,0.8474378676223755,80.49799992675781,95.4460000366211 197 | 195,3.067542655127389,0.8543004457092285,80.49200005371094,95.44000006103515 198 | 196,3.066994445664542,0.856601085357666,80.43400005371093,95.52400013671875 199 | 197,3.073182071958269,0.8498832931900024,80.74000015625,95.49000003417969 200 | 198,3.056281958307539,0.8409462097549438,80.87000005371094,95.54800000976563 201 | 199,3.050658004624503,0.8633793709945679,80.83400000488281,95.46799995605468 202 | 200,3.051530735833304,0.8355366100692749,80.67200000488282,95.6039999560547 203 | 201,3.026632717677525,0.853679780883789,80.78600010253906,95.57200013671876 204 | 202,3.0374182803290233,0.8304018156051636,81.18599997558594,95.58000021484375 205 | 203,3.043479715074812,0.8450591536712646,80.74199997558594,95.5360000341797 206 | 204,3.0228357825960432,0.8336448619842529,80.9920000024414,95.64200021484375 207 | 205,3.013875433376857,0.8384150375175476,81.08000013183593,95.65000021484374 208 
| 206,3.0022598675319125,0.8380760380744934,80.90599989746094,95.60600008300781 209 | 207,3.0194713047572543,0.831977225151062,81.13600000488282,95.67400019042968 210 | 208,2.9928110327039446,0.8454853986549378,81.024000078125,95.56000005859374 211 | 209,2.9901199340820312,0.833710662689209,81.10999997802735,95.78200013671875 212 | 210,2.9958644253867015,0.8277957198333741,81.17600005126953,95.66600005859375 213 | 211,2.993471009390695,0.8236359128570556,81.15200002929687,95.70000016601563 214 | 212,2.9929285730634416,0.8403115428161622,81.13800013183594,95.69799995605469 215 | 213,2.9833693163735524,0.8293285052108764,81.31000010742187,95.78200011230469 216 | 214,2.9747810874666487,0.8310634294128418,81.32399989746094,95.67600008544922 217 | 215,2.96452568258558,0.8179863739776612,81.28600002685548,95.71200013671876 218 | 216,2.9757233006613597,0.8274626565933227,81.26799989746094,95.67400011230468 219 | 217,2.9496664830616544,0.8227651625061035,81.42599997802735,95.67000005859374 220 | 218,2.942621656826564,0.815831360168457,81.533999921875,95.78399998291016 221 | 219,2.9372896466936385,0.8195946127700806,81.66599997070313,95.77200008544922 222 | 220,2.943613222667149,0.8090327165985107,81.58200005371094,95.78600011230469 223 | 221,2.937715836933681,0.8203316749954224,81.53800015625,95.80200006103516 224 | 222,2.922425525529044,0.809507099647522,81.49600012939453,95.82600019042968 225 | 223,2.903947642871312,0.8243318227386475,81.41200005371094,95.77400019042969 226 | 224,2.934469291142055,0.8180012958717346,81.5780000024414,95.79600008544922 227 | 225,2.9044903346470425,0.8277961733818054,81.66800010253907,95.7800000341797 228 | 226,2.9066080025264194,0.8131719005584717,81.75600005126954,95.85800005859375 229 | 227,2.897071055003575,0.8151784935569764,81.758000078125,95.84600013671874 230 | 228,2.896092210497175,0.816662384929657,81.672000078125,95.83800016113281 231 | 229,2.8886067867279053,0.8031469111061096,81.736000078125,95.96200000976563 232 | 230,2.8876868997301375,0.8095014908218384,81.77800017822265,95.95400008544922 233 | 231,2.8670308930533275,0.8181876503372192,81.60800005371094,95.89400013671874 234 | 232,2.8782890183585033,0.818968438835144,81.776000078125,95.89800003417969 235 | 233,2.8671919107437134,0.806286235408783,81.86400005371094,95.90400021484375 236 | 234,2.8603787081582204,0.8089845188522339,81.85200005371094,95.9380000341797 237 | 235,2.868993333407811,0.7960386561584473,81.93200015869141,95.94800008544922 238 | 236,2.8622406039919173,0.7991430662918091,82.00399992675781,95.95800003417969 239 | 237,2.8401939187731062,0.8049784683609009,82.01800005371094,95.87000006103516 240 | 238,2.8501793146133423,0.7990549187850953,81.92800005371093,95.92400003417968 241 | 239,2.8117557423455373,0.8094156219100952,82.05800005371094,95.86200021484375 242 | 240,2.864122373717172,0.8043149180984497,81.860000078125,95.85600019042968 243 | 241,2.8174088171550205,0.7992471007919312,81.960000078125,95.97200011230468 244 | 242,2.8268682105200633,0.8006832155990601,82.10400002929687,95.91400005859376 245 | 243,2.8138302053724016,0.7962913177490234,82.13600002685547,95.92200005859375 246 | 244,2.8205562148775374,0.798160097694397,82.08800002685547,95.96800013671874 247 | 245,2.8136788266045705,0.7950678706741333,82.162000078125,95.93200013671876 248 | 246,2.807972414152963,0.799595305480957,82.07200010253906,95.98800011230469 249 | 247,2.8016551562717984,0.8026992604064941,82.14200005126953,95.94200003417969 250 | 
248,2.7925712210791453,0.8045618480300903,82.07399997558593,95.84400003417969 251 | 249,2.7869423968451366,0.7923246504020691,82.13400018310547,95.9759999560547 252 | 250,2.7787772757666453,0.7910479118919372,82.16000010742188,95.99600021484375 253 | 251,2.7906706162861417,0.7914137632751465,82.15800015380859,95.93200008544922 254 | 252,2.7810133014406477,0.7929864317321778,82.12000015625,96.03600005859376 255 | 253,2.7773448058537076,0.7978285437774658,82.21999995361328,95.95400011230468 256 | 254,2.76039673600878,0.793301633052826,82.27400000244141,95.98800006103515 257 | 255,2.7485550301415578,0.7903708033943176,82.23400005126953,96.05200003417968 258 | 256,2.751392960548401,0.7913374218559265,82.392,96.04600003417968 259 | 257,2.758316023009164,0.7950823604202271,82.33200005371094,96.04600013671875 260 | 258,2.7543596029281616,0.795383340511322,82.28400010253907,96.01000005859375 261 | 259,2.7355248587472096,0.7958355389404297,82.32999997558593,95.9080000341797 262 | 260,2.750061239515032,0.8003365090942383,82.350000078125,95.94200008544922 263 | 261,2.725623573575701,0.7950514992141724,82.39800015625,95.93600016113281 264 | 262,2.735450421060835,0.7940209593391419,82.36800005371094,95.97600006103515 265 | 263,2.7374914033072337,0.7982821480751038,82.40200010498047,96.04400008544921 266 | 264,2.7270736694335938,0.7936787528228759,82.44400005371094,96.01600008544922 267 | 265,2.7175720248903548,0.7923360603523254,82.48400015625,95.98400003417969 268 | 266,2.7283434527260915,0.7923417463302612,82.54200010498047,96.00200008544923 269 | 267,2.7084175518580844,0.7970557256317139,82.43800018310547,95.98400013671875 270 | 268,2.6986820016588484,0.7975159497070312,82.49600002929688,95.98600013671874 271 | 269,2.697665435927255,0.7950924592590332,82.562000078125,96.04600008544922 272 | 270,2.7052313770566667,0.7863004384803772,82.568000234375,96.05600011230469 273 | 271,2.6990305355616977,0.7904127935409546,82.59800013183593,96.02000011230469 274 | 272,2.7026418277195523,0.789643627948761,82.50400010498046,96.04000003417968 275 | 273,2.7105905498777116,0.7932098225784302,82.52200002929688,96.05000003417969 276 | 274,2.681556923048837,0.7893576400184631,82.53600002929687,96.04600013671875 277 | 275,2.6751122304371426,0.7879151227378846,82.51000005371094,96.03600008544922 278 | 276,2.6849872895649503,0.7889139362335205,82.64600002929687,96.06800010986328 279 | 277,2.6787040403911044,0.7888236251449585,82.55200008300781,95.99400008544922 280 | 278,2.6720841271536693,0.7870496005249024,82.71400013183593,96.00600008544922 281 | 279,2.676495739391872,0.7854489376831054,82.63600010498047,96.05600013671875 282 | 280,2.677263515336173,0.7890159526062012,82.60400008056641,95.99599998046875 283 | 281,2.6784958839416504,0.7886858188247681,82.56000020996093,96.04400008544921 284 | 282,2.6628720419747487,0.787669633693695,82.74600005615234,96.05800005859375 285 | 283,2.666781646864755,0.7844901875305176,82.69400010742187,96.13200010986328 286 | 284,2.6683060782296315,0.7915953833389282,82.61600008056641,96.03000008544922 287 | 285,2.6620519501822337,0.788902483997345,82.60000018554688,96.05200003417968 288 | 286,2.669293863432748,0.7923585120773315,82.66000002929688,96.11200006103516 289 | 287,2.667246903691973,0.7872966697692871,82.6840000805664,96.09200011230469 290 | 288,2.656998191561018,0.7885919734573364,82.56000008056641,96.05400019042969 291 | 289,2.6536265441349576,0.7867336483383178,82.67200018310547,96.07800016357422 292 | 
290,2.6571514095578874,0.7854897700500488,82.62800008056641,96.14600011230469 293 | 291,2.6534872395651683,0.788360388622284,82.57600005371094,96.12000021484376 294 | 292,2.6571872404643466,0.7853965707969666,82.64000005615235,96.12400019042968 295 | 293,2.6569044249398366,0.7848235153579712,82.67200013183594,96.1080001123047 296 | 294,2.6470208849225725,0.785138702659607,82.67000002929687,96.06800019042969 297 | 295,2.6626992225646973,0.7841735245513916,82.73600015869141,96.07000019042968 298 | 296,2.659045491899763,0.7842865171432495,82.63400002929687,96.06800019042969 299 | 297,2.6465397902897427,0.7863325159835816,82.64000005371093,96.07800019042969 300 | 298,2.6474288361413136,0.7861149897384644,82.64600008056641,96.10400008544921 301 | 299,2.6600731100354875,0.7858068279266357,82.64600013183593,96.11200008544922 302 | 300,2.6368499483381,0.7834569989585877,82.61400018310547,96.10800003417968 303 | 301,2.656849111829485,0.7849034820556641,82.67000013183593,96.13200013671874 304 | 302,2.6492505414145335,0.783912653503418,82.59200020996094,96.07400013671875 305 | 303,2.6574468782969882,0.7871504070281983,82.63400005371093,96.06800013671875 306 | 304,2.663386957986014,0.785319022026062,82.67200015625,96.1020001123047 307 | 305,2.642723934991019,0.787382265586853,82.66600013183594,96.08800003417969 308 | 306,2.646093249320984,0.7869707857131958,82.62600013183594,96.11800011230468 309 | 307,2.6621898412704468,0.7888498119354248,82.59400005371094,96.04200019042969 310 | 308,2.6493602650506154,0.7859796387290955,82.61800013183594,96.09600003417968 311 | 309,2.6483495576041087,0.787670098953247,82.68200005371094,96.07400003417969 -------------------------------------------------------------------------------- /logs/vip_s7.log: -------------------------------------------------------------------------------- 1 | epoch,train_loss,eval_loss,eval_top1,eval_top5 2 | 0,6.910354852676392,6.871116542510986,0.2759999999332428,1.242000001373291 3 | 1,6.71982867377145,5.924816641235352,3.3279999920654295,10.678000010375976 4 | 2,6.547947951725551,5.489222848052979,6.406000028686523,18.003999963378906 5 | 3,6.291530098233904,4.778919605255127,13.257999993896485,30.625999959716797 6 | 4,6.065422262464251,4.354436435470581,18.4880000390625,39.28400002197266 7 | 5,5.855343341827393,3.8543875540161134,24.712000036621095,48.242000009765626 8 | 6,5.624480962753296,3.4115395363616945,31.09000003417969,56.33400001220703 9 | 7,5.405653442655291,3.1090993406677248,35.88799993164063,61.962000021972656 10 | 8,5.232038361685617,2.817085888595581,41.77600003417969,67.99799999511718 11 | 9,5.058237484523228,2.5883024393844605,45.716000095214845,71.65400013183594 12 | 10,4.912633725575039,2.41707734413147,48.91599996337891,74.45399997802734 13 | 11,4.791824647358486,2.2480293615722657,51.83399998291016,77.0840000732422 14 | 12,4.684575148991176,2.1050120391082765,54.60600010009765,79.26200001464844 15 | 13,4.601482255118234,2.0129325199127197,56.45199997314453,80.81000006591798 16 | 14,4.517404113497053,1.9534340384674072,57.78599991455078,81.77599993652343 17 | 15,4.447149447032383,1.8577291399765015,59.28200001464844,82.78800013671875 18 | 16,4.380970682416644,1.7844337257385254,60.38400015136719,83.83000013916016 19 | 17,4.324268647602627,1.7597603258514405,61.478000083007814,84.508000078125 20 | 18,4.284797668457031,1.7089714694595337,62.29000005859375,84.97799989990234 21 | 19,4.233554261071341,1.6634712961387634,63.12600001220703,85.83999997558594 22 | 
20,4.2021098136901855,1.613315393676758,63.784000107421875,86.045999921875 23 | 21,4.191308311053684,1.611048492641449,64.45799998046876,86.63200007324218 24 | 22,4.162858963012695,1.5814465935707092,65.13199994873047,86.9299999975586 25 | 23,4.120690277644566,1.5399954890823364,65.75600013671875,87.33800005371094 26 | 24,4.089610798018319,1.5567087444877625,65.83000007324219,87.23799992431641 27 | 25,4.072297777448382,1.5134691138839722,66.51200002685547,87.77400010009765 28 | 26,4.076983400753567,1.4935909121513367,66.8679999194336,88.12999996582032 29 | 27,4.035860538482666,1.5114798937034606,66.80400010253906,87.9840001538086 30 | 28,4.024090017591204,1.4639056073379517,67.2299999975586,88.44000001708984 31 | 29,4.018065435545785,1.4688979947090148,67.17799991699219,88.48000012207031 32 | 30,3.9966084446225847,1.4603412785720826,67.93799994384766,88.79399999511719 33 | 31,3.9733629396983554,1.449040442123413,68.38199996582031,89.01800017333984 34 | 32,3.9605880805424283,1.42925826751709,68.52200009277344,89.08600007080078 35 | 33,3.975845388003758,1.3933022938346862,68.80200012695313,89.35400004638672 36 | 34,3.9497000660215105,1.3879936796379089,68.7520000415039,89.45799999023437 37 | 35,3.9356855154037476,1.3892407301712035,68.63400012451172,89.26400004882812 38 | 36,3.9234460251671925,1.4041576303100587,69.22600002197265,89.54799996582031 39 | 37,3.894150580678667,1.372424445362091,69.56199999023437,89.8439999633789 40 | 38,3.9084739855357578,1.3876727807235718,69.4959999633789,89.64000007080078 41 | 39,3.8825084481920515,1.3591691400146484,69.48199999267578,89.6180000756836 42 | 40,3.877161979675293,1.3445187977027893,69.616000078125,89.82600009521484 43 | 41,3.864820650645665,1.3553534746170044,69.86800002441406,89.8960000439453 44 | 42,3.8607681138174876,1.336970921382904,69.91400006347656,90.02400007080078 45 | 43,3.859097719192505,1.3509306648635864,70.10800001464844,89.86000001708985 46 | 44,3.835395165852138,1.324034704055786,70.38800006835937,90.23000004638672 47 | 45,3.8224423272269115,1.3121458512496949,70.09999999511719,90.24599991455078 48 | 46,3.8311296020235335,1.310946946220398,70.10000007080077,90.08999997070312 49 | 47,3.8177223035267422,1.321527822380066,70.5020001196289,90.48400006591797 50 | 48,3.814053399222238,1.2725862459373474,70.90600012207031,90.57800006835937 51 | 49,3.793250322341919,1.3034312795066834,70.98800001708985,90.57400001953125 52 | 50,3.797738415854318,1.3096038312339782,70.69000006591797,90.5100000415039 53 | 51,3.778265084539141,1.2816421472930908,70.78799991210937,90.54999996582032 54 | 52,3.7980651344571794,1.2915517928504945,71.12400014648438,90.61400006591796 55 | 53,3.782908082008362,1.2694016401672363,71.09000004150391,90.70800006591797 56 | 54,3.759143352508545,1.276143349018097,71.10000006591797,90.82000006835938 57 | 55,3.7605906554630826,1.2652636802482604,71.56400009521484,90.8680000415039 58 | 56,3.768913779939924,1.2687879308509826,71.3820000390625,90.96200001464844 59 | 57,3.7650125537599837,1.2620381854438782,71.52999996826172,91.07999996826172 60 | 58,3.744565418788365,1.2498784512901306,71.77800012207031,91.21999999267578 61 | 59,3.7503815037863597,1.2635473303604126,71.76799999267578,91.02400001708985 62 | 60,3.723265358379909,1.2620270862388612,71.59799993652344,90.98400006835938 63 | 61,3.7402224200112477,1.2400294036483765,71.82000000976562,91.0240001171875 64 | 62,3.738108124051775,1.2260919912910462,71.79600006347657,91.20000014404297 65 | 63,3.7196504729134694,1.2604127614212035,71.53199999023437,91.08199996826171 66 | 
64,3.7104381322860718,1.2332775227355957,72.11200006347656,91.32800001464844 67 | 65,3.7230611188071117,1.2207095284080505,72.15600009521485,91.16800001464844 68 | 66,3.724540046283177,1.2408929682540895,72.25800009277344,91.36400016845703 69 | 67,3.7085669381277904,1.2237132588768005,72.28000012207032,91.52799999023438 70 | 68,3.712570939745222,1.2371479213142396,72.24000014404297,91.41400004394531 71 | 69,3.709454587527684,1.2106058625793457,72.57800006103515,91.40400009521484 72 | 70,3.714747667312622,1.222388146495819,72.53199993164063,91.3640001171875 73 | 71,3.6848203114100864,1.2133280154037476,72.5820001196289,91.63800011962891 74 | 72,3.6795212541307722,1.209502467842102,72.5040000415039,91.56199999023437 75 | 73,3.6712015526635304,1.194921001853943,72.37399993164063,91.41800004394531 76 | 74,3.6786709002086093,1.2171858721733093,72.64000012207032,91.45400017089844 77 | 75,3.704538515635899,1.2283701105499267,72.39199999023438,91.48200004150391 78 | 76,3.6698593412126814,1.189139610748291,72.91200000976562,91.71199999023437 79 | 77,3.664558802332197,1.1832522472381592,72.86000016357421,91.65799998779296 80 | 78,3.6730716739382063,1.1873121593856812,73.11400009033203,91.72400009521485 81 | 79,3.6714575631277904,1.19497297290802,73.00999986083984,91.6920000390625 82 | 80,3.652179683957781,1.2031186214637757,72.89400006591796,91.61000014648438 83 | 81,3.6477054187229703,1.2078253826141356,72.8820000415039,91.61400001708985 84 | 82,3.630286080496652,1.1940596000099182,72.88600003417969,91.84999999267578 85 | 83,3.6397500208445956,1.1715826042747497,73.42400006591797,91.8999999633789 86 | 84,3.6324156522750854,1.2031118625640869,73.11599993652344,91.84999999267578 87 | 85,3.6440361057009016,1.1838821751976014,73.38599998535156,91.96000001464844 88 | 86,3.651522159576416,1.1677544637489319,73.6659999609375,91.99799988769531 89 | 87,3.6378399303981235,1.1716902573776244,73.43799993408203,92.08199998779297 90 | 88,3.624962704522269,1.1783182759284974,73.38399998535156,91.94399998535157 91 | 89,3.6170146635600497,1.1617293420219421,73.76600013671874,92.04799998779296 92 | 90,3.61042514869145,1.172294651222229,73.72600011230469,92.18000014160157 93 | 91,3.6317668642316545,1.1791609832763672,73.59800004150391,92.10400004394532 94 | 92,3.6201460361480713,1.1480293505859376,73.61000000976563,92.23200001464843 95 | 93,3.6209590264729092,1.1779391680145264,73.72199998291016,92.1060001147461 96 | 94,3.613230058125087,1.1569854089546203,73.68000000488281,92.17000003662109 97 | 95,3.610241004398891,1.1557616481781006,73.82199998535157,92.39999988769532 98 | 96,3.6029657295772006,1.136378186893463,74.09800009277343,92.27199994140625 99 | 97,3.610003284045628,1.1518813787269593,73.80399998535157,92.1159999609375 100 | 98,3.602148277418954,1.1477407651901246,74.22800016357422,92.2520000390625 101 | 99,3.587556038584028,1.147612846698761,73.94999998291016,92.22599998535156 102 | 100,3.575423036302839,1.17546430147171,73.93199985839844,92.24600001464843 103 | 101,3.594241874558585,1.164679919052124,73.77200014404296,92.25400006591796 104 | 102,3.574530669621059,1.1410579643249512,74.12000006591796,92.3519999633789 105 | 103,3.577667934553964,1.1480506226921081,74.12800009033204,92.33000001708984 106 | 104,3.5631169251033237,1.1268435364151002,73.99600000488282,92.31200006591797 107 | 105,3.575550522123064,1.1371820400810242,74.20800010986328,92.33400001464844 108 | 106,3.5561102628707886,1.1347802314376831,74.03600006103515,92.38600001464843 109 | 
107,3.561202202524458,1.124841847114563,74.45000000488281,92.57599993164062 110 | 108,3.5690958499908447,1.1377327275848388,74.32800005859374,92.48999998779297 111 | 109,3.5459833656038557,1.131335128288269,74.19399998535157,92.5159999609375 112 | 110,3.5482756921223233,1.1220351646995543,74.30600001220704,92.48999988769532 113 | 111,3.5477688312530518,1.1355343063354493,74.46599998779297,92.70999986083984 114 | 112,3.539861832346235,1.1296864748764037,74.57200003173828,92.67599998535157 115 | 113,3.5269309963498796,1.1203201083946228,74.82999998535156,92.66399998535157 116 | 114,3.5320667369025096,1.1132809732437134,74.79999999023437,92.76200017089843 117 | 115,3.530100005013602,1.1111866358757019,74.76600003417968,92.75200006347656 118 | 116,3.528291497911726,1.136079112148285,74.7240001123047,92.78599998779296 119 | 117,3.5343415226255144,1.1261422640800476,74.97400003662109,92.77399998535157 120 | 118,3.5179230996540616,1.0970723267936706,74.85800000488281,92.59000006591796 121 | 119,3.506786346435547,1.1209634058189393,75.1799999584961,92.69800001464844 122 | 120,3.513591698237828,1.1081656586265565,75.09600009033203,92.83000001464843 123 | 121,3.5281908001218523,1.1150537035751342,75.00599995361328,92.92800006591797 124 | 122,3.50358898299081,1.1063740638160706,75.1240000366211,92.94999983154297 125 | 123,3.4942887340273177,1.1060146494102479,75.05199993164062,92.9859999609375 126 | 124,3.4984929902212962,1.1117961889457704,75.09600003417968,92.87599998779297 127 | 125,3.503525836127145,1.100393250656128,75.00600000488281,92.94000017089844 128 | 126,3.491240450314113,1.0967469749259948,75.24199995605468,92.9420001171875 129 | 127,3.4967160565512523,1.1104211112976075,75.52800006103516,93.0040000390625 130 | 128,3.4972742795944214,1.1027438097000122,75.27600003173828,92.99400008789063 131 | 129,3.480168955666678,1.1014834995651246,75.1060000366211,92.92000011962891 132 | 130,3.451824818338667,1.1072121849250793,75.45400003417969,92.9720000390625 133 | 131,3.4820427553994313,1.080589510383606,75.4399999609375,93.19600006835938 134 | 132,3.4663478306361606,1.0852017770004272,75.4540001171875,93.00000006591797 135 | 133,3.4664641278130666,1.0898101559066773,75.31000005859374,93.10199995605468 136 | 134,3.4608843496867587,1.0899819431304931,75.51200003662109,93.17799998779297 137 | 135,3.4607612064906528,1.0701756216049194,75.54399998291015,93.16600009277344 138 | 136,3.45158212525504,1.0783992675590515,75.56799995605469,93.10799985839844 139 | 137,3.463596837861197,1.095099306163788,75.49399998779298,93.03799991210937 140 | 138,3.448858073779515,1.0766115677261352,75.85400011230469,93.21600001220703 141 | 139,3.4481924261365617,1.0719881107330322,75.84400005615234,93.31600003662109 142 | 140,3.444954037666321,1.0571806546020508,75.86799998046875,93.24400000976563 143 | 141,3.4698727130889893,1.0761893390464783,76.04199990478516,93.3980001147461 144 | 142,3.4144953659602573,1.06947179895401,75.88600011230469,93.26200006347656 145 | 143,3.4266265630722046,1.0708520328140259,75.72600006103515,93.2700001147461 146 | 144,3.448715192931039,1.064021452484131,75.97800005859375,93.22400003662109 147 | 145,3.4183421305247714,1.0608906400680542,75.95399998291016,93.24600000976562 148 | 146,3.420613612447466,1.065034186630249,76.24200000732422,93.46599993408203 149 | 147,3.416309118270874,1.0385579908561706,76.15000008789063,93.35799990722656 150 | 148,3.4180603538240706,1.0550948586654663,76.49999990478516,93.4919998852539 151 | 
149,3.39834235395704,1.0568188303947448,76.16800000976562,93.45599993652344 152 | 150,3.421156185013907,1.0320460278701782,76.41200013427735,93.48200003662109 153 | 151,3.393655742917742,1.0543014291763306,76.32200013671876,93.57800000976563 154 | 152,3.402743475777762,1.026495654296875,76.7880000805664,93.4720001147461 155 | 153,3.3754907165254866,1.0559506983184814,76.39400000976562,93.7120000390625 156 | 154,3.3894948278154646,1.0255888409042357,76.47000003417969,93.61600006103515 157 | 155,3.3801317044666837,1.038236250591278,76.53000005859376,93.5659999609375 158 | 156,3.3776180233274187,1.0285412818336488,76.6440000024414,93.49800000732422 159 | 157,3.359822358403887,1.0299789972686768,76.67800016601562,93.75000001220702 160 | 158,3.3585726022720337,1.0387441951179504,76.64199993164063,93.67400001220703 161 | 159,3.3639749799455916,1.0356495218086244,76.71599995361328,93.61200001464844 162 | 160,3.3559899841036116,1.0380221781349181,76.86200005615234,93.74200006103516 163 | 161,3.374008740697588,1.0414112449455262,76.89400003173829,93.67199998535156 164 | 162,3.3787546157836914,1.0306082490730286,76.68199997802735,93.63200000976562 165 | 163,3.372366343225752,1.0280968539047242,76.75400003173829,93.76600005859375 166 | 164,3.3497800827026367,1.0359389358901978,77.06000005859374,93.85200000976562 167 | 165,3.3344835894448415,1.0007197301483155,77.01800005615235,93.84199995849609 168 | 166,3.3414117438452586,1.0291534510993958,77.11200010498047,93.92200005859375 169 | 167,3.3405251673289706,1.0049080563545227,77.19599990478515,93.96000008544922 170 | 168,3.3375538757869174,1.0196791957473754,77.09600019042969,93.88000016845703 171 | 169,3.339250292096819,1.0313093887710572,76.99800011230468,93.9060000366211 172 | 170,3.346706117902483,1.0118296506118774,77.18800005859374,93.99599990722656 173 | 171,3.329431022916521,1.0156722680473327,77.25999995117188,94.03999993164062 174 | 172,3.320967197418213,1.004298150844574,77.28400000976562,93.90599998779297 175 | 173,3.301955955369132,0.9765044348335266,77.23400012939453,94.05800008789062 176 | 174,3.3066959381103516,1.0013043937301636,77.48399998046875,94.0260001147461 177 | 175,3.306335415158953,0.9987107336425781,77.4440000341797,94.1159999560547 178 | 176,3.301999262401036,1.0165845069122315,77.22200003417969,93.96400003662109 179 | 177,3.310860804149083,0.9933593659973144,77.54000008544922,93.99400000976563 180 | 178,3.2820747068950107,0.9996557506561279,77.48000006103516,94.0519999584961 181 | 179,3.271898235593523,0.9985905562782288,77.74200005859375,94.12000006347657 182 | 180,3.260966028485979,0.9856931565856933,77.61400000732422,94.25400006103516 183 | 181,3.2826857226235524,0.983560849647522,77.74000000488282,94.14400011230468 184 | 182,3.286082250731332,0.9811837980079651,77.60000005615234,94.17999990722656 185 | 183,3.268093228340149,0.9945370296096802,77.63400003173828,94.1379998803711 186 | 184,3.2612391880580356,0.985034419631958,77.81999995605469,94.29400000976563 187 | 185,3.2700976133346558,0.976963846321106,77.93799987548829,94.42400006103516 188 | 186,3.278603434562683,0.9724673036575318,78.02400013427734,94.32599998535156 189 | 187,3.2424873283931186,0.9851917292404175,78.1199999584961,94.32200000976563 190 | 188,3.25116913659232,0.9664134226799012,78.1600000366211,94.33599998535156 191 | 189,3.2578448397772655,0.9595902676582336,78.32600013427735,94.48400008544922 192 | 190,3.235813958304269,0.9824482745170593,77.94800000732423,94.31400008789062 193 | 
191,3.2409620455333163,0.9589320360755921,78.34200008544921,94.42400006103516 194 | 192,3.2311534200395857,0.9556274572372436,78.15599998535156,94.49799998535157 195 | 193,3.215212787900652,0.9710889654541016,78.206000078125,94.46000011230468 196 | 194,3.226354326520647,0.9665467096710205,78.06000005371094,94.41800008789062 197 | 195,3.204403383391244,0.955241902179718,78.47600008300782,94.48000000976562 198 | 196,3.2072790350232805,0.9491239792442322,78.38999997314453,94.55000008544921 199 | 197,3.214749421392168,0.9599419301986695,78.44200000488281,94.56200008789062 200 | 198,3.194374782698495,0.9488386449432373,78.42200010742188,94.60400006103515 201 | 199,3.180394700595311,0.9421100485229492,78.44999992675781,94.65200005859376 202 | 200,3.18903340612139,0.9776787174797058,78.49200005859375,94.52200011474609 203 | 201,3.18525184903826,0.9519289845657348,78.50400005126953,94.61400000732422 204 | 202,3.190683671406337,0.9545643902015686,78.60999995361328,94.61800006103516 205 | 203,3.1987829548971995,0.9410179819107055,78.73000002685546,94.68200011230469 206 | 204,3.1837877886635915,0.9397942705535889,78.82000012695312,94.73400011230468 207 | 205,3.1784403324127197,0.9315377732467651,78.758,94.71600016113281 208 | 206,3.154120002474104,0.9204479489707946,78.82199997802735,94.76399990478515 209 | 207,3.155407564980643,0.9336190315055847,78.86200008300781,94.79200000732422 210 | 208,3.1445415871483937,0.9249671794509887,78.71200008056641,94.73799995849609 211 | 209,3.141798666545323,0.9233237899017334,78.95000013183594,94.63600003417969 212 | 210,3.134514127458845,0.9126926391410828,78.89799992431641,94.74600011230469 213 | 211,3.149206757545471,0.9174883917427062,78.83600002685547,94.7559999560547 214 | 212,3.143160649708339,0.9237904894828797,78.93400005859375,94.76200014160156 215 | 213,3.147749968937465,0.9151992947959899,79.03000015625,94.85599993164062 216 | 214,3.1234214987073625,0.912416845111847,79.192000078125,94.95600003417968 217 | 215,3.1288598435265675,0.9083706735038757,79.08400010742187,94.98200011230469 218 | 216,3.13330910887037,0.9001986584281921,79.21600005615234,94.94800006103516 219 | 217,3.1170419795172557,0.907843272895813,79.41000008300782,94.94599995605469 220 | 218,3.109780958720616,0.897127061328888,79.46600010986329,94.99400003662109 221 | 219,3.107724360057286,0.8937215794372558,79.54200013427734,95.04000000732422 222 | 220,3.1108348880495345,0.9028960170936584,79.44200010986329,94.9520000366211 223 | 221,3.0952113015311107,0.8936521735954285,79.37600015869141,94.95000016601563 224 | 222,3.0864976133619035,0.8805189507102966,79.53400008300781,95.00999993164062 225 | 223,3.078485369682312,0.8937882517433167,79.55199997558594,94.96800006103516 226 | 224,3.086329000336783,0.8978740680313111,79.45,95.08000006103515 227 | 225,3.0571650436946323,0.8800405916404724,79.66800005859375,95.16400008544922 228 | 226,3.0648894480296542,0.8822767544555664,79.64000005859376,95.11600000732422 229 | 227,3.0691287347248624,0.8875536430358887,79.70000002929687,95.15600003417968 230 | 228,3.065608399254935,0.8871162219238281,79.81399997558594,95.13800008544922 231 | 229,3.066498262541635,0.8710383775711059,79.84799997802735,95.17200003417969 232 | 230,3.050912022590637,0.8750371379852295,79.82800010498048,95.17600011230469 233 | 231,3.046704428536551,0.8716553355979919,79.92400016113281,95.15999995605469 234 | 232,3.0496502774102345,0.8711385144615174,79.98800005859376,95.30399995605468 235 | 233,3.0478583744594028,0.8636911138343811,80.1399999560547,95.2259999560547 236 | 
234,3.0329549653189525,0.8705187623786926,80.15400010986328,95.22200003417969 237 | 235,3.0256893975394115,0.8657631127738953,80.00000005615234,95.16000008789062 238 | 236,3.0303209168570384,0.8542928401947022,80.23600008300781,95.29399993164063 239 | 237,3.0169115577425276,0.8577844309997559,80.16400005859376,95.28199998535156 240 | 238,3.016129698072161,0.8544239654541016,80.29200016357422,95.28199988037109 241 | 239,3.0057530914034163,0.8579123154449463,80.2960000805664,95.3240000366211 242 | 240,3.0297729117529735,0.852004965057373,80.3680000024414,95.2479999584961 243 | 241,3.0139805930001393,0.8503977573013306,80.18000002929688,95.35200003662109 244 | 242,2.998254350253514,0.858746715259552,80.36000008544922,95.36800016601562 245 | 243,2.9913571902683804,0.853325531578064,80.43000000488281,95.30200006347657 246 | 244,2.9844330549240112,0.8470025736999511,80.45800002929687,95.31400011230468 247 | 245,2.9913331610815868,0.8380948524856567,80.54200000488281,95.41400008544922 248 | 246,2.9841435125895908,0.8404366687011718,80.6000000024414,95.38599998291015 249 | 247,2.968302386147635,0.8429741472434997,80.39400005859375,95.36800011230468 250 | 248,2.960051008633205,0.8380223566055298,80.54800018310547,95.42000008544922 251 | 249,2.9724017211369107,0.8329328897094727,80.49000005615234,95.46400003417969 252 | 250,2.9559037004198347,0.8426947736549377,80.54200008056641,95.41800006103516 253 | 251,2.9671612126486644,0.8383963717269898,80.6000000024414,95.37800006103515 254 | 252,2.9542258977890015,0.8353467890167237,80.77400000244141,95.52600006103516 255 | 253,2.9615061623709544,0.8334027414512635,80.68199995117187,95.44200000732422 256 | 254,2.9508489881243025,0.8319586064720154,80.75999995117188,95.45600003417968 257 | 255,2.9282792636326382,0.8328703302001953,80.8340000024414,95.53200005859375 258 | 256,2.9366190944399153,0.8400946211051941,80.78600013183593,95.52999990478516 259 | 257,2.9209174939564297,0.8330392538642883,80.81399992675782,95.52199993164062 260 | 258,2.9419304643358504,0.8300348459815979,80.95400013427735,95.53199995605469 261 | 259,2.9261889117104665,0.82850239944458,80.87800000244141,95.57800008544922 262 | 260,2.9221110684531078,0.821576551361084,80.81999997802734,95.57599995605469 263 | 261,2.9109401191983904,0.8324070910644531,80.88599998046875,95.5600000341797 264 | 262,2.901044641222273,0.8304000218582154,80.96600005615234,95.57600003417969 265 | 263,2.902141434805734,0.8277242088508606,80.92000005615235,95.59599993164062 266 | 264,2.9144849606922696,0.8224774540901184,80.94600015625,95.68400008544921 267 | 265,2.9092872653688704,0.8250922188377381,81.0340000805664,95.61600013916015 268 | 266,2.914496830531529,0.8196935567092896,81.11200002685547,95.64599995605468 269 | 267,2.8966016258512224,0.8289900078010559,81.10000013427734,95.63200003417968 270 | 268,2.8957821811948503,0.8225859983825684,81.16999995361329,95.66800016357422 271 | 269,2.8919225420270647,0.8262854515838624,80.9700000048828,95.71000008544922 272 | 270,2.9108768020357405,0.8263456849861145,81.21399992675781,95.64999998291016 273 | 271,2.889698658670698,0.8198468581581115,81.15400005615234,95.63799995605468 274 | 272,2.881616847855704,0.8244203399276734,81.17399995361328,95.67600016357422 275 | 273,2.8936853919710432,0.8262249835395813,81.22000005371093,95.67200016357422 276 | 274,2.8906019585473195,0.8214063010787964,81.2380000048828,95.68199990478516 277 | 275,2.865442224911281,0.8198296391677856,81.20599998046875,95.66799998291016 278 | 
276,2.880365388734,0.8241642014503479,81.25799997802734,95.66800003417968 279 | 277,2.8817200490406583,0.8221864530181885,81.23600010742187,95.68000006103516 280 | 278,2.8627356631415233,0.8228190303802491,81.22000016113282,95.70600021728515 281 | 279,2.8710689885275706,0.8224459475517273,81.29200003173828,95.68600000976562 282 | 280,2.86979717867715,0.8252169724082947,81.25600018554688,95.66800008544922 283 | 281,2.855398109980992,0.8161679304122925,81.34800003173828,95.71599995605469 284 | 282,2.854904294013977,0.8144530500411987,81.35000005859375,95.76400003417969 285 | 283,2.8524268184389387,0.8198967656517029,81.29600008300781,95.68800006103515 286 | 284,2.850534898894174,0.8213496792984009,81.36400008300781,95.71600016357422 287 | 285,2.8636700255530223,0.8167939529228211,81.32200003173828,95.74600011230469 288 | 286,2.848572696958269,0.8188931160736084,81.37600003173829,95.73600016357422 289 | 287,2.869384867804391,0.8193492330932617,81.34400002929688,95.73600016357422 290 | 288,2.8438764810562134,0.8165874895858765,81.42600000488281,95.77600011230469 291 | 289,2.8604059389659335,0.8173203623771668,81.35000005615234,95.7240001123047 292 | 290,2.8472452333995273,0.8173615158843994,81.35199997802735,95.77600006103516 293 | 291,2.835127898624965,0.8174559487724304,81.44800005615234,95.78600019042969 294 | 292,2.8445494515555247,0.8202477996635437,81.41600005615234,95.74600013916016 295 | 293,2.8429692813328336,0.8218114996147156,81.38000013427734,95.74600019042968 296 | 294,2.8448794909885953,0.8210028151702881,81.38200005615235,95.78600011230469 297 | 295,2.8422097137996127,0.8206802438354492,81.44999997802735,95.74400011230469 298 | 296,2.8342733553477695,0.8187796735572815,81.41399997802735,95.76200013916015 299 | 297,2.8418258939470564,0.8203711761474609,81.45600015869141,95.76000013916016 300 | 298,2.849673339298793,0.8192286546516419,81.40200000488281,95.82000013916016 301 | 299,2.840991190501622,0.8190435117149353,81.51800002929687,95.80800008789062 302 | 300,2.8413085086005077,0.8172658764839172,81.41800018554687,95.76600013916016 303 | 301,2.8393015691212247,0.819350092830658,81.45800005859375,95.78000013916015 304 | 302,2.841710771833147,0.8177475336837768,81.49200010742187,95.76200013916015 305 | 303,2.846070579120091,0.8183813068389892,81.45000013427735,95.74400013916015 306 | 304,2.847775493349348,0.8178584407806396,81.40799989990235,95.76800008789063 307 | 305,2.8315299408776418,0.817798441696167,81.4460001586914,95.75400008789063 308 | 306,2.851865257535662,0.8194912733840942,81.46800008056641,95.72600013916015 309 | 307,2.8322412967681885,0.8179017197608948,81.46799989990234,95.76800000976563 310 | 308,2.8463452543531145,0.8190416221046448,81.45199992675781,95.75000011230469 311 | 309,2.8444330521992276,0.8170373690986633,81.49800010742187,95.76400013916016 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | ViP training and evaluating script 3 | This script is modified from pytorch-image-models by Ross Wightman (https://github.com/rwightman/pytorch-image-models/) 4 | It was started from an early version of the PyTorch ImageNet example 5 | (https://github.com/pytorch/examples/tree/master/imagenet) 6 | """ 7 | import argparse 8 | import time 9 | import yaml 10 | import os 11 | import logging 12 | from collections import OrderedDict 13 | from contextlib import suppress 14 | from datetime import datetime 15 | import models 16 | 17 | import 
torch 18 | import torch.nn as nn 19 | import torchvision.utils 20 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 21 | 22 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 23 | from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model 24 | from timm.utils import * 25 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy 26 | from timm.optim import create_optimizer 27 | from timm.scheduler import create_scheduler 28 | from timm.utils import ApexScaler, NativeScaler 29 | 30 | torch.backends.cudnn.benchmark = True 31 | _logger = logging.getLogger('train') 32 | 33 | # The first arg parser parses out only the --config argument, this argument is used to 34 | # load a yaml file containing key-values that override the defaults for the main parser below 35 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 36 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 37 | help='YAML config file specifying default arguments') 38 | 39 | parser = argparse.ArgumentParser(description='ViP Training and Evaluating') 40 | 41 | # Dataset / Model parameters 42 | parser.add_argument('data', metavar='DIR', 43 | help='path to dataset') 44 | parser.add_argument('--dataset', '-d', metavar='NAME', default='', 45 | help='dataset type (default: ImageFolder/ImageTar if empty)') 46 | parser.add_argument('--train-split', metavar='NAME', default='train', 47 | help='dataset train split (default: train)') 48 | parser.add_argument('--val-split', metavar='NAME', default='validation', 49 | help='dataset validation split (default: validation)') 50 | parser.add_argument('--model', default='vip_s7', type=str, metavar='MODEL', 51 | help='Name of model to train (default: "countception"') 52 | parser.add_argument('--pretrained', action='store_true', default=False, 53 | help='Start with pretrained version of specified network (if avail)') 54 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 55 | help='Initialize model from this checkpoint (default: none)') 56 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 57 | help='Resume full model and optimizer state from checkpoint (default: none)') 58 | parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH', 59 | help='path to eval checkpoint (default: none)') 60 | parser.add_argument('--no-resume-opt', action='store_true', default=False, 61 | help='prevent resume of optimizer state when resuming model') 62 | parser.add_argument('--num-classes', type=int, default=1000, metavar='N', 63 | help='number of label classes (default: 1000)') 64 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 65 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.') 66 | parser.add_argument('--img-size', type=int, default=224, metavar='N', 67 | help='Image patch size (default: 224)') 68 | parser.add_argument('--crop-pct', default=None, type=float, 69 | metavar='N', help='Input image center crop percent (for validation only)') 70 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 71 | help='Override mean pixel value of dataset') 72 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 73 | help='Override std deviation of dataset') 74 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 75 | help='Image resize interpolation type (overrides model)') 76 | parser.add_argument('-b', '--batch-size', type=int, default=64, metavar='N', 77 | help='input batch size for training (default: 64)') 78 | parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N', 79 | help='ratio of validation batch size to training batch size (default: 1)') 80 | 81 | # Optimizer parameters 82 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 83 | help='Optimizer (default: "adamw")') 84 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 85 | help='Optimizer Epsilon (default: None, use opt default)') 86 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 87 | help='Optimizer Betas (default: None, use opt default)') 88 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 89 | help='Optimizer momentum (default: 0.9)') 90 | parser.add_argument('--weight-decay', type=float, default=0.05, 91 | help='weight decay (default: 0.05 for adamw)') 92 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 93 | help='Clip gradient norm (default: None, no clipping)') 94 | 95 | # Learning rate schedule parameters 96 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 97 | help='LR scheduler (default: "cosine")') 98 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 99 | help='learning rate (default: 5e-4)') 100 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 101 | help='learning rate noise on/off epoch percentages') 102 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 103 | help='learning rate noise limit percent (default: 0.67)') 104 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 105 | help='learning rate noise std-dev (default: 1.0)') 106 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 107 | help='learning rate cycle len multiplier (default: 1.0)') 108 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 109 | help='learning rate cycle limit') 110 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 111 | help='warmup learning rate (default: 1e-6)') 112 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 113 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 114 | parser.add_argument('--epochs', type=int, default=300, metavar='N', 115 | help='number of epochs to train (default: 300)') 116 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N', 117 | help='manual epoch number (useful on restarts)') 118 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 119 | help='epoch 
interval to decay LR') 120 | parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N', 121 | help='epochs to warmup LR, if scheduler supports') 122 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 123 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 124 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 125 | help='patience epochs for Plateau LR scheduler (default: 10)') 126 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 127 | help='LR decay rate (default: 0.1)') 128 | 129 | # Augmentation & regularization parameters 130 | parser.add_argument('--no-aug', action='store_true', default=False, 131 | help='Disable all training augmentation, override other train aug args') 132 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 133 | help='Random resize scale (default: 0.08 1.0)') 134 | parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', 135 | help='Random resize aspect ratio (default: 0.75 1.33)') 136 | parser.add_argument('--hflip', type=float, default=0.5, 137 | help='Horizontal flip training aug probability') 138 | parser.add_argument('--vflip', type=float, default=0., 139 | help='Vertical flip training aug probability') 140 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 141 | help='Color jitter factor (default: 0.4)') 142 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 143 | help='Use AutoAugment policy, e.g. "v0" or "original" (default: "rand-m9-mstd0.5-inc1")') 144 | parser.add_argument('--aug-splits', type=int, default=0, 145 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 146 | parser.add_argument('--jsd', action='store_true', default=False, 147 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 148 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 149 | help='Random erase prob (default: 0.25)') 150 | parser.add_argument('--remode', type=str, default='pixel', 151 | help='Random erase mode (default: "pixel")') 152 | parser.add_argument('--recount', type=int, default=1, 153 | help='Random erase count (default: 1)') 154 | parser.add_argument('--resplit', action='store_true', default=False, 155 | help='Do not random erase first (clean) augmentation split') 156 | parser.add_argument('--mixup', type=float, default=0.8, 157 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 158 | parser.add_argument('--cutmix', type=float, default=1.0, 159 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 160 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 161 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 162 | parser.add_argument('--mixup-prob', type=float, default=1.0, 163 | help='Probability of performing mixup or cutmix when either/both is enabled') 164 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 165 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 166 | parser.add_argument('--mixup-mode', type=str, default='batch', 167 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 168 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 169 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 170 | parser.add_argument('--smoothing', type=float, default=0.1, 171 | help='Label smoothing (default: 0.1)') 172 | parser.add_argument('--train-interpolation', type=str, default='random', 173 | help='Training interpolation (random, bilinear, bicubic default: "random")') 174 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 175 | help='Dropout rate (default: 0.0)') 176 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 177 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 178 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 179 | help='Drop path rate (default: None)') 180 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 181 | help='Drop block rate (default: None)') 182 | 183 | # Batch norm parameters (only works with gen_efficientnet based models currently) 184 | parser.add_argument('--bn-tf', action='store_true', default=False, 185 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') 186 | parser.add_argument('--bn-momentum', type=float, default=None, 187 | help='BatchNorm momentum override (if not None)') 188 | parser.add_argument('--bn-eps', type=float, default=None, 189 | help='BatchNorm epsilon override (if not None)') 190 | parser.add_argument('--sync-bn', action='store_true', 191 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 192 | parser.add_argument('--dist-bn', type=str, default='', 193 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 194 | parser.add_argument('--split-bn', action='store_true', 195 | help='Enable separate BN layers per augmentation split.') 196 | 197 | # Model Exponential Moving Average 198 | parser.add_argument('--model-ema', action='store_true', default=False, 199 | help='Enable tracking moving average of model weights') 200 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 201 | help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') 202 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, 203 | help='decay factor for model weights moving average (default: 0.99996)') 204 | 205 | # Misc 206 | parser.add_argument('--seed', type=int, default=42, metavar='S', 207 | help='random seed (default: 42)') 208 | parser.add_argument('--log-interval', type=int, default=50, metavar='N', 209 | help='how many batches to wait before logging training status') 210 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', 211 | help='how many batches to wait before writing recovery checkpoint') 212 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N', 213 | help='how many training processes to use (default: 8)') 214 | parser.add_argument('--num-gpu', type=int, default=1, 215 | help='Number of GPUs to use') 216 | parser.add_argument('--save-images', action='store_true', default=False, 217 | help='save images of input batches every log interval for debugging') 218 | parser.add_argument('--amp', action='store_true', default=False, 219 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 220 | parser.add_argument('--apex-amp', action='store_true', default=False, 221 | help='Use NVIDIA Apex AMP mixed precision') 222 | parser.add_argument('--native-amp', action='store_true', default=False, 223 | help='Use Native Torch AMP mixed precision') 224 | parser.add_argument('--channels-last', action='store_true', default=False, 225 | help='Use channels_last memory layout') 226 | parser.add_argument('--pin-mem', action='store_true', default=False, 227 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 228 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 229 | help='disable fast prefetcher') 230 | parser.add_argument('--output', default='', type=str, metavar='PATH', 231 | help='path to output folder (default: none, current dir)') 232 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 233 | help='Best metric (default: "top1")') 234 | parser.add_argument('--tta', type=int, default=0, metavar='N', 235 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') 236 | parser.add_argument("--local_rank", default=0, type=int) 237 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, 238 | help='use the multi-epochs-loader to save time at the beginning of every epoch') 239 | 240 | try: 241 | from apex import amp 242 | from apex.parallel import DistributedDataParallel as ApexDDP 243 | from apex.parallel import convert_syncbn_model 244 | 245 | has_apex = True 246 | except ImportError: 247 | has_apex = False 248 | 249 | has_native_amp = False 250 | try: 251 | if getattr(torch.cuda.amp, 'autocast') is not None: 252 | has_native_amp = True 253 | except AttributeError: 254 | pass 255 | 256 | def _parse_args(): 257 | # Do we have a config file to parse? 258 | args_config, remaining = config_parser.parse_known_args() 259 | if args_config.config: 260 | with open(args_config.config, 'r') as f: 261 | cfg = yaml.safe_load(f) 262 | parser.set_defaults(**cfg) 263 | 264 | # The main arg parser parses the rest of the args, the usual 265 | # defaults will have been overridden if config file specified. 
266 | args = parser.parse_args(remaining) 267 | 268 | # Cache the args as a text string to save them in the output dir later 269 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 270 | return args, args_text 271 | 272 | 273 | def main(): 274 | setup_default_logging() 275 | args, args_text = _parse_args() 276 | 277 | args.prefetcher = not args.no_prefetcher 278 | args.distributed = False 279 | if 'WORLD_SIZE' in os.environ: 280 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 281 | if args.distributed and args.num_gpu > 1: 282 | _logger.warning( 283 | 'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.') 284 | args.num_gpu = 1 285 | 286 | args.device = 'cuda:0' 287 | args.world_size = 1 288 | args.rank = 0 # global rank 289 | if args.distributed: 290 | args.num_gpu = 1 291 | args.device = 'cuda:%d' % args.local_rank 292 | torch.cuda.set_device(args.local_rank) 293 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 294 | args.world_size = torch.distributed.get_world_size() 295 | args.rank = torch.distributed.get_rank() 296 | assert args.rank >= 0 297 | 298 | if args.distributed: 299 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 300 | % (args.rank, args.world_size)) 301 | else: 302 | _logger.info('Training with a single process on %d GPUs.' % args.num_gpu) 303 | 304 | torch.manual_seed(args.seed + args.rank) 305 | 306 | model = create_model( 307 | args.model, 308 | pretrained=args.pretrained, 309 | num_classes=args.num_classes, 310 | drop_rate=args.drop, 311 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path 312 | drop_path_rate=args.drop_path, 313 | drop_block_rate=args.drop_block, 314 | global_pool=args.gp, 315 | bn_tf=args.bn_tf, 316 | bn_momentum=args.bn_momentum, 317 | bn_eps=args.bn_eps, 318 | checkpoint_path=args.initial_checkpoint, 319 | img_size=args.img_size) 320 | 321 | if args.local_rank == 0: 322 | _logger.info('Model %s created, param count: %d' % 323 | (args.model, sum([m.numel() for m in model.parameters()]))) 324 | 325 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) 326 | 327 | num_aug_splits = 0 328 | if args.aug_splits > 0: 329 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 330 | num_aug_splits = args.aug_splits 331 | 332 | if args.split_bn: 333 | assert num_aug_splits > 1 or args.resplit 334 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 335 | 336 | use_amp = None 337 | if args.amp: 338 | # for backwards compat, `--amp` arg tries apex before native amp 339 | if has_apex: 340 | args.apex_amp = True 341 | elif has_native_amp: 342 | args.native_amp = True 343 | if args.apex_amp and has_apex: 344 | use_amp = 'apex' 345 | elif args.native_amp and has_native_amp: 346 | use_amp = 'native' 347 | elif args.apex_amp or args.native_amp: 348 | _logger.warning("Neither APEX or native Torch AMP is available, using float32. " 349 | "Install NVIDA apex or upgrade to PyTorch 1.6") 350 | 351 | if args.num_gpu > 1: 352 | if use_amp == 'apex': 353 | _logger.warning( 354 | 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') 355 | use_amp = None 356 | model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() 357 | assert not args.channels_last, "Channels last not supported with DP, use DDP." 
358 | else: 359 | model.cuda() 360 | if args.channels_last: 361 | model = model.to(memory_format=torch.channels_last) 362 | 363 | optimizer = create_optimizer(args, model) 364 | 365 | amp_autocast = suppress # do nothing 366 | loss_scaler = None 367 | if use_amp == 'apex': 368 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 369 | loss_scaler = ApexScaler() 370 | if args.local_rank == 0: 371 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') 372 | elif use_amp == 'native': 373 | amp_autocast = torch.cuda.amp.autocast 374 | loss_scaler = NativeScaler() 375 | if args.local_rank == 0: 376 | _logger.info('Using native Torch AMP. Training in mixed precision.') 377 | else: 378 | if args.local_rank == 0: 379 | _logger.info('AMP not enabled. Training in float32.') 380 | 381 | # optionally resume from a checkpoint 382 | resume_epoch = None 383 | if args.resume: 384 | resume_epoch = resume_checkpoint( 385 | model, args.resume, 386 | optimizer=None if args.no_resume_opt else optimizer, 387 | loss_scaler=None if args.no_resume_opt else loss_scaler, 388 | log_info=args.local_rank == 0) 389 | 390 | model_ema = None 391 | if args.model_ema: 392 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 393 | model_ema = ModelEma( 394 | model, 395 | decay=args.model_ema_decay, 396 | device='cpu' if args.model_ema_force_cpu else '', 397 | resume=args.resume) 398 | 399 | if args.distributed: 400 | if args.sync_bn: 401 | assert not args.split_bn 402 | try: 403 | if has_apex and use_amp != 'native': 404 | # Apex SyncBN preferred unless native amp is activated 405 | model = convert_syncbn_model(model) 406 | else: 407 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 408 | if args.local_rank == 0: 409 | _logger.info( 410 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 411 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 412 | except Exception as e: 413 | _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') 414 | if has_apex and use_amp != 'native': 415 | # Apex DDP preferred unless native amp is activated 416 | if args.local_rank == 0: 417 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 418 | model = ApexDDP(model, delay_allreduce=True) 419 | else: 420 | if args.local_rank == 0: 421 | _logger.info("Using native Torch DistributedDataParallel.") 422 | model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 423 | # NOTE: EMA model does not need to be wrapped by DDP 424 | 425 | lr_scheduler, num_epochs = create_scheduler(args, optimizer) 426 | start_epoch = 0 427 | if args.start_epoch is not None: 428 | # a specified start_epoch will always override the resume epoch 429 | start_epoch = args.start_epoch 430 | elif resume_epoch is not None: 431 | start_epoch = resume_epoch 432 | if lr_scheduler is not None and start_epoch > 0: 433 | lr_scheduler.step(start_epoch) 434 | 435 | if args.local_rank == 0: 436 | _logger.info('Scheduled epochs: {}'.format(num_epochs)) 437 | 438 | dataset_train = create_dataset( 439 | args.dataset, root=args.data, split=args.train_split, is_training=True, batch_size=args.batch_size) 440 | dataset_eval = create_dataset( 441 | args.dataset, root=args.data, split=args.val_split, is_training=False, batch_size=args.batch_size) 442 | 443 | collate_fn = None 444 | mixup_fn = None 445 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 446 | if mixup_active: 447 | mixup_args = dict( 448 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 449 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 450 | label_smoothing=args.smoothing, num_classes=args.num_classes) 451 | if args.prefetcher: 452 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) 453 | collate_fn = FastCollateMixup(**mixup_args) 454 | else: 455 | mixup_fn = Mixup(**mixup_args) 456 | 457 | if num_aug_splits > 1: 458 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 459 | 460 | train_interpolation = args.train_interpolation 461 | if args.no_aug or not train_interpolation: 462 | train_interpolation = data_config['interpolation'] 463 | loader_train = create_loader( 464 | dataset_train, 465 | input_size=data_config['input_size'], 466 | batch_size=args.batch_size, 467 | is_training=True, 468 | use_prefetcher=args.prefetcher, 469 | no_aug=args.no_aug, 470 | re_prob=args.reprob, 471 | re_mode=args.remode, 472 | re_count=args.recount, 473 | re_split=args.resplit, 474 | scale=args.scale, 475 | ratio=args.ratio, 476 | hflip=args.hflip, 477 | vflip=args.vflip, 478 | color_jitter=args.color_jitter, 479 | auto_augment=args.aa, 480 | num_aug_splits=num_aug_splits, 481 | interpolation=train_interpolation, 482 | mean=data_config['mean'], 483 | std=data_config['std'], 484 | num_workers=args.workers, 485 | distributed=args.distributed, 486 | collate_fn=collate_fn, 487 | pin_memory=args.pin_mem, 488 | use_multi_epochs_loader=args.use_multi_epochs_loader 489 | ) 490 | 491 | loader_eval = create_loader( 492 | dataset_eval, 493 | input_size=data_config['input_size'], 494 | batch_size=args.validation_batch_size_multiplier * args.batch_size, 495 | is_training=False, 496 | use_prefetcher=args.prefetcher, 497 | interpolation=data_config['interpolation'], 498 | mean=data_config['mean'], 499 | std=data_config['std'], 500 | num_workers=args.workers, 501 | distributed=args.distributed, 502 | crop_pct=data_config['crop_pct'], 503 | pin_memory=args.pin_mem, 504 | ) 505 | 506 | if args.jsd: 507 | assert num_aug_splits > 1 # JSD only valid with aug splits set 508 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() 509 | elif mixup_active: 510 | # smoothing is handled with mixup target transform 511 | train_loss_fn = SoftTargetCrossEntropy().cuda() 512 | elif args.smoothing: 513 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() 514 | else: 515 | train_loss_fn = nn.CrossEntropyLoss().cuda() 516 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 517 | 518 | eval_metric = args.eval_metric 519 | best_metric = None 520 | best_epoch = None 521 | 522 | if args.eval_checkpoint: # evaluate the model 523 | load_checkpoint(model, args.eval_checkpoint, args.model_ema) 524 | val_metrics = validate(model, loader_eval, validate_loss_fn, args) 525 | print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%") 526 | return 527 | 528 | saver = None 529 | output_dir = '' 530 | if args.local_rank == 0: 531 | output_base = args.output if args.output else './output' 532 | exp_name = '-'.join([ 533 | datetime.now().strftime("%Y%m%d-%H%M%S"), 534 | args.model, 535 | str(data_config['input_size'][-1]) 536 | ]) 537 | output_dir = get_outdir(output_base, 'train', exp_name) 538 | decreasing = True if eval_metric == 'loss' else False 539 | saver = CheckpointSaver( 540 | model=model, 
optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, 541 | checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing) 542 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 543 | f.write(args_text) 544 | 545 | try: # train the model 546 | for epoch in range(start_epoch, num_epochs): 547 | if args.distributed: 548 | loader_train.sampler.set_epoch(epoch) 549 | 550 | train_metrics = train_epoch( 551 | epoch, model, loader_train, optimizer, train_loss_fn, args, 552 | lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, 553 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) 554 | 555 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 556 | if args.local_rank == 0: 557 | _logger.info("Distributing BatchNorm running means and vars") 558 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 559 | 560 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) 561 | 562 | if model_ema is not None and not args.model_ema_force_cpu: 563 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 564 | distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') 565 | ema_eval_metrics = validate( 566 | model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') 567 | eval_metrics = ema_eval_metrics 568 | 569 | if lr_scheduler is not None: 570 | # step LR for next epoch 571 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) 572 | 573 | update_summary( 574 | epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), 575 | write_header=best_metric is None) 576 | 577 | if saver is not None: 578 | # save proper checkpoint with eval metric 579 | save_metric = eval_metrics[eval_metric] 580 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) 581 | 582 | except KeyboardInterrupt: 583 | pass 584 | if best_metric is not None: 585 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) 586 | 587 | 588 | def train_epoch( 589 | epoch, model, loader, optimizer, loss_fn, args, 590 | lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress, 591 | loss_scaler=None, model_ema=None, mixup_fn=None): 592 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 593 | if args.prefetcher and loader.mixup_enabled: 594 | loader.mixup_enabled = False 595 | elif mixup_fn is not None: 596 | mixup_fn.mixup_enabled = False 597 | 598 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 599 | batch_time_m = AverageMeter() 600 | data_time_m = AverageMeter() 601 | losses_m = AverageMeter() 602 | top1_m = AverageMeter() 603 | top5_m = AverageMeter() 604 | 605 | model.train() 606 | 607 | end = time.time() 608 | last_idx = len(loader) - 1 609 | num_updates = epoch * len(loader) 610 | for batch_idx, (input, target) in enumerate(loader): 611 | last_batch = batch_idx == last_idx 612 | data_time_m.update(time.time() - end) 613 | if not args.prefetcher: 614 | input, target = input.cuda(), target.cuda() 615 | if mixup_fn is not None: 616 | input, target = mixup_fn(input, target) 617 | if args.channels_last: 618 | input = input.contiguous(memory_format=torch.channels_last) 619 | 620 | with amp_autocast(): 621 | output = model(input) 622 | loss = loss_fn(output, target) 623 | 624 | if not args.distributed: 625 | losses_m.update(loss.item(), input.size(0)) 626 | 627 | optimizer.zero_grad() 628 | if 
loss_scaler is not None: 629 | loss_scaler( 630 | loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) 631 | else: 632 | loss.backward(create_graph=second_order) 633 | if args.clip_grad is not None: 634 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) 635 | optimizer.step() 636 | 637 | torch.cuda.synchronize() 638 | if model_ema is not None: 639 | model_ema.update(model) 640 | num_updates += 1 641 | 642 | batch_time_m.update(time.time() - end) 643 | if last_batch or batch_idx % args.log_interval == 0: 644 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 645 | lr = sum(lrl) / len(lrl) 646 | 647 | if args.distributed: 648 | reduced_loss = reduce_tensor(loss.data, args.world_size) 649 | losses_m.update(reduced_loss.item(), input.size(0)) 650 | 651 | if args.local_rank == 0: 652 | _logger.info( 653 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 654 | 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' 655 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 656 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 657 | 'LR: {lr:.3e} ' 658 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 659 | epoch, 660 | batch_idx, len(loader), 661 | 100. * batch_idx / last_idx, 662 | loss=losses_m, 663 | batch_time=batch_time_m, 664 | rate=input.size(0) * args.world_size / batch_time_m.val, 665 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg, 666 | lr=lr, 667 | data_time=data_time_m)) 668 | 669 | if args.save_images and output_dir: 670 | torchvision.utils.save_image( 671 | input, 672 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), 673 | padding=0, 674 | normalize=True) 675 | 676 | if saver is not None and args.recovery_interval and ( 677 | last_batch or (batch_idx + 1) % args.recovery_interval == 0): 678 | saver.save_recovery(epoch, batch_idx=batch_idx) 679 | 680 | if lr_scheduler is not None: 681 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) 682 | 683 | end = time.time() 684 | # end for 685 | 686 | if hasattr(optimizer, 'sync_lookahead'): 687 | optimizer.sync_lookahead() 688 | 689 | return OrderedDict([('loss', losses_m.avg)]) 690 | 691 | 692 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): 693 | batch_time_m = AverageMeter() 694 | losses_m = AverageMeter() 695 | top1_m = AverageMeter() 696 | top5_m = AverageMeter() 697 | 698 | model.eval() 699 | 700 | end = time.time() 701 | last_idx = len(loader) - 1 702 | with torch.no_grad(): 703 | for batch_idx, (input, target) in enumerate(loader): 704 | last_batch = batch_idx == last_idx 705 | if not args.prefetcher: 706 | input = input.cuda() 707 | target = target.cuda() 708 | if args.channels_last: 709 | input = input.contiguous(memory_format=torch.channels_last) 710 | 711 | with amp_autocast(): 712 | output = model(input) 713 | if isinstance(output, (tuple, list)): 714 | output = output[0] 715 | 716 | # augmentation reduction 717 | reduce_factor = args.tta 718 | if reduce_factor > 1: 719 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) 720 | target = target[0:target.size(0):reduce_factor] 721 | 722 | loss = loss_fn(output, target) 723 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 724 | 725 | if args.distributed: 726 | reduced_loss = reduce_tensor(loss.data, args.world_size) 727 | acc1 = reduce_tensor(acc1, args.world_size) 728 | acc5 = reduce_tensor(acc5, args.world_size) 729 | else: 730 | reduced_loss = loss.data 731 | 732 | torch.cuda.synchronize() 733 | 734 | 
losses_m.update(reduced_loss.item(), input.size(0)) 735 | top1_m.update(acc1.item(), output.size(0)) 736 | top5_m.update(acc5.item(), output.size(0)) 737 | 738 | batch_time_m.update(time.time() - end) 739 | end = time.time() 740 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): 741 | log_name = 'Test' + log_suffix 742 | _logger.info( 743 | '{0}: [{1:>4d}/{2}] ' 744 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 745 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 746 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 747 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 748 | log_name, batch_idx, last_idx, batch_time=batch_time_m, 749 | loss=losses_m, top1=top1_m, top5=top5_m)) 750 | 751 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) 752 | 753 | return metrics 754 | 755 | 756 | if __name__ == '__main__': 757 | main() 758 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vip import * -------------------------------------------------------------------------------- /models/vip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 5 | from timm.models.layers import DropPath, trunc_normal_ 6 | from timm.models.registry import register_model 7 | 8 | def _cfg(url='', **kwargs): 9 | return { 10 | 'url': url, 11 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 12 | 'crop_pct': .96, 'interpolation': 'bicubic', 13 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', 14 | **kwargs 15 | } 16 | 17 | default_cfgs = { 18 | 'ViP_S': _cfg(crop_pct=0.9), 19 | 'ViP_M': _cfg(crop_pct=0.9), 20 | 'ViP_L': _cfg(crop_pct=0.875), 21 | } 22 | 23 | 24 | class Mlp(nn.Module): 25 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x): 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | 42 | class WeightedPermuteMLP(nn.Module): 43 | def __init__(self, dim, segment_dim=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 44 | super().__init__() 45 | self.segment_dim = segment_dim 46 | 47 | self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias) 48 | self.mlp_h = nn.Linear(dim, dim, bias=qkv_bias) 49 | self.mlp_w = nn.Linear(dim, dim, bias=qkv_bias) 50 | 51 | self.reweight = Mlp(dim, dim // 4, dim *3) 52 | 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | 57 | 58 | def forward(self, x): 59 | B, H, W, C = x.shape 60 | 61 | S = C // self.segment_dim 62 | h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H*S) 63 | h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C) 64 | 65 | w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W*S) 66 | w = self.mlp_w(w).reshape(B, H, 
self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C) 67 | 68 | c = self.mlp_c(x) 69 | 70 | a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2) 71 | a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2) 72 | 73 | x = h * a[0] + w * a[1] + c * a[2] 74 | 75 | x = self.proj(x) 76 | x = self.proj_drop(x) 77 | 78 | return x 79 | 80 | 81 | class PermutatorBlock(nn.Module): 82 | 83 | def __init__(self, dim, segment_dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 84 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip_lam=1.0, mlp_fn = WeightedPermuteMLP): 85 | super().__init__() 86 | self.norm1 = norm_layer(dim) 87 | self.attn = mlp_fn(dim, segment_dim=segment_dim, qkv_bias=qkv_bias, qk_scale=None, attn_drop=attn_drop) 88 | 89 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 90 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 91 | 92 | self.norm2 = norm_layer(dim) 93 | mlp_hidden_dim = int(dim * mlp_ratio) 94 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 95 | self.skip_lam = skip_lam 96 | 97 | def forward(self, x): 98 | x = x + self.drop_path(self.attn(self.norm1(x))) / self.skip_lam 99 | x = x + self.drop_path(self.mlp(self.norm2(x))) / self.skip_lam 100 | return x 101 | 102 | class PatchEmbed(nn.Module): 103 | """ Image to Patch Embedding 104 | """ 105 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 106 | super().__init__() 107 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 108 | 109 | def forward(self, x): 110 | x = self.proj(x) # B, C, H, W 111 | return x 112 | 113 | 114 | class Downsample(nn.Module): 115 | """ Image to Patch Embedding 116 | """ 117 | def __init__(self, in_embed_dim, out_embed_dim, patch_size): 118 | super().__init__() 119 | self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size) 120 | 121 | def forward(self, x): 122 | x = x.permute(0, 3, 1, 2) 123 | x = self.proj(x) # B, C, H, W 124 | x = x.permute(0, 2, 3, 1) 125 | return x 126 | 127 | def basic_blocks(dim, index, layers, segment_dim, mlp_ratio=3., qkv_bias=False, qk_scale=None, \ 128 | attn_drop=0, drop_path_rate=0., skip_lam=1.0, mlp_fn = WeightedPermuteMLP, **kwargs): 129 | blocks = [] 130 | 131 | for block_idx in range(layers[index]): 132 | block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) 133 | blocks.append(PermutatorBlock(dim, segment_dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\ 134 | attn_drop=attn_drop, drop_path=block_dpr, skip_lam=skip_lam, mlp_fn = mlp_fn)) 135 | 136 | blocks = nn.Sequential(*blocks) 137 | 138 | return blocks 139 | 140 | class VisionPermutator(nn.Module): 141 | """ Vision Permutator 142 | """ 143 | def __init__(self, layers, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 144 | embed_dims=None, transitions=None, segment_dim=None, mlp_ratios=None, skip_lam=1.0, 145 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 146 | norm_layer=nn.LayerNorm,mlp_fn = WeightedPermuteMLP): 147 | 148 | super().__init__() 149 | self.num_classes = num_classes 150 | 151 | self.patch_embed = PatchEmbed(img_size = img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) 152 | 153 | network = [] 154 | for i in range(len(layers)): 155 | stage = basic_blocks(embed_dims[i], i, layers, 
segment_dim[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, 156 | qk_scale=qk_scale, attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate, norm_layer=norm_layer, skip_lam=skip_lam, 157 | mlp_fn = mlp_fn) 158 | network.append(stage) 159 | if i >= len(layers) - 1: 160 | break 161 | if transitions[i] or embed_dims[i] != embed_dims[i+1]: 162 | patch_size = 2 if transitions[i] else 1 163 | network.append(Downsample(embed_dims[i], embed_dims[i+1], patch_size)) 164 | 165 | 166 | self.network = nn.ModuleList(network) 167 | 168 | self.norm = norm_layer(embed_dims[-1]) 169 | 170 | # Classifier head 171 | self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() 172 | self.apply(self._init_weights) 173 | 174 | def _init_weights(self, m): 175 | if isinstance(m, nn.Linear): 176 | trunc_normal_(m.weight, std=.02) 177 | if isinstance(m, nn.Linear) and m.bias is not None: 178 | nn.init.constant_(m.bias, 0) 179 | elif isinstance(m, nn.LayerNorm): 180 | nn.init.constant_(m.bias, 0) 181 | nn.init.constant_(m.weight, 1.0) 182 | 183 | def get_classifier(self): 184 | return self.head 185 | 186 | def reset_classifier(self, num_classes, global_pool=''): 187 | self.num_classes = num_classes 188 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 189 | 190 | def forward_embeddings(self, x): 191 | x = self.patch_embed(x) 192 | # B,C,H,W-> B,H,W,C 193 | x = x.permute(0, 2, 3, 1) 194 | return x 195 | 196 | def forward_tokens(self,x): 197 | for idx, block in enumerate(self.network): 198 | x = block(x) 199 | B, H, W, C = x.shape 200 | x = x.reshape(B, -1, C) 201 | return x 202 | 203 | def forward(self, x): 204 | x = self.forward_embeddings(x) 205 | # B, H, W, C -> B, N, C 206 | x = self.forward_tokens(x) 207 | x = self.norm(x) 208 | return self.head(x.mean(1)) 209 | 210 | 211 | 212 | 213 | @register_model 214 | def vip_s14(pretrained=False, **kwargs): 215 | layers = [4, 3, 8, 3] 216 | transitions = [False, False, False, False] 217 | segment_dim = [16, 16, 16, 16] 218 | mlp_ratios = [3, 3, 3, 3] 219 | embed_dims = [384, 384, 384, 384] 220 | model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=14, transitions=transitions, 221 | segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs) 222 | model.default_cfg = default_cfgs['ViP_S'] 223 | return model 224 | 225 | @register_model 226 | def vip_s7(pretrained=False, **kwargs): 227 | layers = [4, 3, 8, 3] 228 | transitions = [True, False, False, False] 229 | segment_dim = [32, 16, 16, 16] 230 | mlp_ratios = [3, 3, 3, 3] 231 | embed_dims = [192, 384, 384, 384] 232 | model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, 233 | segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs) 234 | model.default_cfg = default_cfgs['ViP_S'] 235 | return model 236 | 237 | @register_model 238 | def vip_m7(pretrained=False, **kwargs): 239 | # 55534632 240 | layers = [4, 3, 14, 3] 241 | transitions = [False, True, False, False] 242 | segment_dim = [32, 32, 16, 16] 243 | mlp_ratios = [3, 3, 3, 3] 244 | embed_dims = [256, 256, 512, 512] 245 | model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, 246 | segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs) 247 | model.default_cfg = default_cfgs['ViP_M'] 248 | return model 249 | 250 | 251 | @register_model 252 | def vip_l7(pretrained=False, **kwargs): 253 | layers = [8, 8, 16, 4] 254 | transitions = 
[True, False, False, False] 255 | segment_dim = [32, 16, 16, 16] 256 | mlp_ratios = [3, 3, 3, 3] 257 | embed_dims = [256, 512, 512, 512] 258 | model = VisionPermutator(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions, 259 | segment_dim=segment_dim, mlp_ratios=mlp_ratios, mlp_fn=WeightedPermuteMLP, **kwargs) 260 | model.default_cfg = default_cfgs['ViP_L'] 261 | return model 262 | -------------------------------------------------------------------------------- /permute_mlp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/houqb/VisionPermutator/2cab2cedecfcc6aa938a341276a61fc6d6505579/permute_mlp.png -------------------------------------------------------------------------------- /transfer_learning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) [2012]-[2021] YUAN Li@NUS. 2 | # 3 | # This source code is licensed under the Clear BSD License 4 | # LICENSE file in the root directory of this file 5 | # All rights reserved. 6 | 7 | '''Tranfer pretrained vip to downstream dataset: CIFAR10/CIFAR100.''' 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import torch.nn.functional as F 12 | import torch.backends.cudnn as cudnn 13 | 14 | import torchvision 15 | import torchvision.transforms as transforms 16 | 17 | import os 18 | import argparse 19 | 20 | from models import * 21 | from timm.models import * 22 | from utils import progress_bar 23 | from timm.models import create_model 24 | from utils import load_for_transfer_learning 25 | 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10/CIFAR100 Training') 28 | parser.add_argument('--lr', default=0.01, type=float, help='learning rate') 29 | parser.add_argument('--wd', default=5e-4, type=float, help='weight decay') 30 | parser.add_argument('--min-lr', default=2e-4, type=float, help='minimal learning rate') 31 | parser.add_argument('--dataset', type=str, default='cifar10', 32 | help='cifar10 or cifar100') 33 | parser.add_argument('--b', type=int, default=128, 34 | help='batch size') 35 | parser.add_argument('--resume', '-r', action='store_true', 36 | help='resume from checkpoint') 37 | parser.add_argument('--pretrained', action='store_true', default=False, 38 | help='Start with pretrained version of specified network (if avail)') 39 | parser.add_argument('--num-classes', type=int, default=10, metavar='N', 40 | help='number of label classes (default: 1000)') 41 | parser.add_argument('--model', default='vip_s7', type=str, metavar='MODEL', 42 | help='Name of model to train (default: "countception"') 43 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 44 | help='Dropout rate (default: 0.0)') 45 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 46 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 47 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 48 | help='Drop path rate (default: None)') 49 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 50 | help='Drop block rate (default: None)') 51 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 52 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.') 53 | parser.add_argument('--img-size', type=int, default=224, metavar='N', 54 | help='Image patch size (default: None => model default)') 55 | parser.add_argument('--bn-tf', action='store_true', default=False, 56 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') 57 | parser.add_argument('--bn-momentum', type=float, default=None, 58 | help='BatchNorm momentum override (if not None)') 59 | parser.add_argument('--bn-eps', type=float, default=None, 60 | help='BatchNorm epsilon override (if not None)') 61 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 62 | help='Initialize model from this checkpoint (default: none)') 63 | # Transfer learning 64 | parser.add_argument('--transfer-learning', default=False, 65 | help='Enable transfer learning') 66 | parser.add_argument('--transfer-model', type=str, default=None, 67 | help='Path to pretrained model for transfer learning') 68 | parser.add_argument('--transfer-ratio', type=float, default=0.01, 69 | help='lr ratio between classifier and backbone in transfer learning') 70 | 71 | args = parser.parse_args() 72 | 73 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 74 | best_acc = 0 # best test accuracy 75 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 76 | 77 | # Data 78 | print('==> Preparing data..') 79 | transform_train = transforms.Compose([ 80 | transforms.Resize(args.img_size), 81 | transforms.RandomCrop(args.img_size, padding=(args.img_size//8)), 82 | transforms.RandomHorizontalFlip(), 83 | transforms.ToTensor(), 84 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 85 | ]) 86 | 87 | transform_test = transforms.Compose([ 88 | transforms.Resize(args.img_size), 89 | transforms.ToTensor(), 90 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 91 | ]) 92 | 93 | if args.dataset=='cifar10': 94 | args.num_classes = 10 95 | trainset = torchvision.datasets.CIFAR10( 96 | root='./data', train=True, download=True, transform=transform_train) 97 | testset = torchvision.datasets.CIFAR10( 98 | root='./data', train=False, download=True, transform=transform_test) 99 | 100 | elif args.dataset=='cifar100': 101 | args.num_classes = 100 102 | trainset = torchvision.datasets.CIFAR100( 103 | root='./data', train=True, download=True, transform=transform_train) 104 | testset = torchvision.datasets.CIFAR100( 105 | root='./data', train=False, download=True, transform=transform_test) 106 | else: 107 | print('Please use cifar10 or cifar100 dataset.') 108 | 109 | trainloader = torch.utils.data.DataLoader( 110 | trainset, batch_size=args.b, shuffle=True, num_workers=8) 111 | testloader = torch.utils.data.DataLoader( 112 | testset, batch_size=100, shuffle=False, num_workers=8) 113 | 114 | print(f'learning rate:{args.lr}, weight decay: {args.wd}') 115 | # create model 116 | print('==> Building model..') 117 | net = create_model( 118 | args.model, 119 | pretrained=args.pretrained, 120 | num_classes=args.num_classes, 121 | drop_rate=args.drop, 122 | drop_connect_rate=args.drop_connect, 123 | drop_path_rate=args.drop_path, 124 | drop_block_rate=args.drop_block, 125 | global_pool=args.gp, 126 | bn_tf=args.bn_tf, 127 | bn_momentum=args.bn_momentum, 128 | bn_eps=args.bn_eps, 129 | checkpoint_path=args.initial_checkpoint, 130 | img_size=args.img_size) 131 | 132 | if args.transfer_learning: 133 | print('transfer learning, load vip pretrained model') 134 | load_for_transfer_learning(net, args.transfer_model, 
use_ema=True, strict=False, num_classes=args.num_classes) 135 | 136 | net = net.to(device) 137 | if device == 'cuda': 138 | net = torch.nn.DataParallel(net) 139 | cudnn.benchmark = True 140 | 141 | if args.resume: 142 | # Load checkpoint. 143 | print('==> Resuming from checkpoint..') 144 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 145 | checkpoint = torch.load('./checkpoint/ckpt.pth') 146 | net.load_state_dict(checkpoint['net']) 147 | best_acc = checkpoint['acc'] 148 | start_epoch = checkpoint['epoch'] 149 | 150 | criterion = nn.CrossEntropyLoss() 151 | 152 | # set optimizer 153 | if args.transfer_learning: 154 | #print(net) 155 | print('set different lr for the backbone and classifier(head) of vip') 156 | parameters = [{'params': net.module.network.parameters(), 'lr': args.transfer_ratio * args.lr}, 157 | {'params': net.module.head.parameters()}] 158 | else: 159 | parameters = net.parameters() 160 | 161 | optimizer = optim.SGD(parameters, lr=args.lr, 162 | momentum=0.9, weight_decay=args.wd) 163 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=args.min_lr, T_max=60) 164 | 165 | # Training 166 | def train(epoch): 167 | print('\nEpoch: %d' % epoch) 168 | net.train() 169 | train_loss = 0 170 | correct = 0 171 | total = 0 172 | for batch_idx, (inputs, targets) in enumerate(trainloader): 173 | inputs, targets = inputs.to(device), targets.to(device) 174 | optimizer.zero_grad() 175 | outputs = net(inputs) 176 | loss = criterion(outputs, targets) 177 | loss.backward() 178 | optimizer.step() 179 | 180 | train_loss += loss.item() 181 | _, predicted = outputs.max(1) 182 | total += targets.size(0) 183 | correct += predicted.eq(targets).sum().item() 184 | 185 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 186 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 187 | 188 | def test(epoch): 189 | global best_acc 190 | net.eval() 191 | test_loss = 0 192 | correct = 0 193 | total = 0 194 | with torch.no_grad(): 195 | for batch_idx, (inputs, targets) in enumerate(testloader): 196 | inputs, targets = inputs.to(device), targets.to(device) 197 | outputs = net(inputs) 198 | loss = criterion(outputs, targets) 199 | 200 | test_loss += loss.item() 201 | _, predicted = outputs.max(1) 202 | total += targets.size(0) 203 | correct += predicted.eq(targets).sum().item() 204 | 205 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 206 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 207 | 208 | # Save checkpoint. 209 | acc = 100.*correct/total 210 | if acc > best_acc: 211 | print('Saving..') 212 | state = { 213 | 'net': net.state_dict(), 214 | 'acc': acc, 215 | 'epoch': epoch, 216 | } 217 | if not os.path.isdir(f'checkpoint_{args.dataset}_{args.model}'): 218 | os.mkdir(f'checkpoint_{args.dataset}_{args.model}') 219 | torch.save(state, f'./checkpoint_{args.dataset}_{args.model}/ckpt_{args.lr}_{args.wd}_{acc}.pth') 220 | best_acc = acc 221 | 222 | 223 | for epoch in range(start_epoch, start_epoch+60): 224 | train(epoch) 225 | test(epoch) 226 | scheduler.step() 227 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) [2012]-[2021] YUAN Li@NUS. 2 | # 3 | # This source code is licensed under the Clear BSD License 4 | # LICENSE file in the root directory of this file 5 | # All rights reserved. 
6 | 7 | ''' 8 | - load_for_transfer_learning: load pretrained paramters to model in transfer learning 9 | - get_mean_and_std: calculate the mean and std value of dataset. 10 | - msr_init: net parameter initialization. 11 | - progress_bar: progress bar mimic xlua.progress. 12 | ''' 13 | import os 14 | import sys 15 | import time 16 | import torch 17 | 18 | import torch.nn as nn 19 | import torch.nn.init as init 20 | import logging 21 | import os 22 | from collections import OrderedDict 23 | 24 | _logger = logging.getLogger(__name__) 25 | 26 | 27 | def resize_pos_embed(posemb, posemb_new): # example: 224:(14x14+1)-> 384: (24x24+1) 28 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from 29 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 30 | ntok_new = posemb_new.shape[1] 31 | if True: 32 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] # posemb_tok is for cls token, posemb_grid for the following tokens 33 | ntok_new -= 1 34 | else: 35 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 36 | gs_old = int(math.sqrt(len(posemb_grid))) # 14 37 | gs_new = int(math.sqrt(ntok_new)) # 24 38 | _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) 39 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) # [1, 196, dim]->[1, 14, 14, dim]->[1, dim, 14, 14] 40 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24] 41 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) # [1, dim, 24, 24] -> [1, 24*24, dim] 42 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) # [1, 24*24+1, dim] 43 | return posemb 44 | 45 | def load_state_dict(checkpoint_path, model, num_classes, use_ema=False, del_posemb=False): 46 | if checkpoint_path and os.path.isfile(checkpoint_path): 47 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 48 | state_dict_key = 'state_dict' 49 | if isinstance(checkpoint, dict): 50 | if use_ema and 'state_dict_ema' in checkpoint: 51 | state_dict_key = 'state_dict_ema' 52 | if state_dict_key and state_dict_key in checkpoint: 53 | new_state_dict = OrderedDict() 54 | for k, v in checkpoint[state_dict_key].items(): 55 | # strip `module.` prefix 56 | name = k[7:] if k.startswith('module') else k 57 | new_state_dict[name] = v 58 | state_dict = new_state_dict 59 | else: 60 | state_dict = checkpoint 61 | _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) 62 | print(f'num classes: {num_classes}') 63 | if num_classes != 1000: 64 | # completely discard fully connected for all other differences between pretrained and created model 65 | print('delete the original class') 66 | del state_dict['head' + '.weight'] 67 | del state_dict['head' + '.bias'] 68 | 69 | if del_posemb==True: 70 | del state_dict['pos_embed'] 71 | 72 | #old_posemb = state_dict['pos_embed'] 73 | #if model.pos_embed.shape != old_posemb.shape: # need resize the position embedding by interpolate 74 | # new_posemb = resize_pos_embed(old_posemb, model.pos_embed) 75 | # state_dict['pos_embed'] = new_posemb 76 | 77 | return state_dict 78 | else: 79 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 80 | raise FileNotFoundError() 81 | 82 | 83 | 84 | def load_for_transfer_learning(model, checkpoint_path, num_classes, use_ema=False, strict=True): 85 | print(f'num classes: {num_classes}') 86 | state_dict = 
load_state_dict(checkpoint_path, use_ema, num_classes) 87 | model.load_state_dict(state_dict, strict=strict) 88 | 89 | 90 | def get_mean_and_std(dataset): 91 | '''Compute the mean and std value of dataset.''' 92 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 93 | mean = torch.zeros(3) 94 | std = torch.zeros(3) 95 | print('==> Computing mean and std..') 96 | for inputs, targets in dataloader: 97 | for i in range(3): 98 | mean[i] += inputs[:,i,:,:].mean() 99 | std[i] += inputs[:,i,:,:].std() 100 | mean.div_(len(dataset)) 101 | std.div_(len(dataset)) 102 | return mean, std 103 | 104 | def init_params(net): 105 | '''Init layer parameters.''' 106 | for m in net.modules(): 107 | if isinstance(m, nn.Conv2d): 108 | init.kaiming_normal(m.weight, mode='fan_out') 109 | if m.bias: 110 | init.constant(m.bias, 0) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | init.constant(m.weight, 1) 113 | init.constant(m.bias, 0) 114 | elif isinstance(m, nn.Linear): 115 | init.normal(m.weight, std=1e-3) 116 | if m.bias: 117 | init.constant(m.bias, 0) 118 | 119 | 120 | _, term_width = os.popen('stty size', 'r').read().split() 121 | term_width = int(term_width) 122 | 123 | TOTAL_BAR_LENGTH = 65. 124 | last_time = time.time() 125 | begin_time = last_time 126 | def progress_bar(current, total, msg=None): 127 | global last_time, begin_time 128 | if current == 0: 129 | begin_time = time.time() # Reset for new bar. 130 | 131 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 132 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 133 | 134 | sys.stdout.write(' [') 135 | for i in range(cur_len): 136 | sys.stdout.write('=') 137 | sys.stdout.write('>') 138 | for i in range(rest_len): 139 | sys.stdout.write('.') 140 | sys.stdout.write(']') 141 | 142 | cur_time = time.time() 143 | step_time = cur_time - last_time 144 | last_time = cur_time 145 | tot_time = cur_time - begin_time 146 | 147 | L = [] 148 | L.append(' Step: %s' % format_time(step_time)) 149 | L.append(' | Tot: %s' % format_time(tot_time)) 150 | if msg: 151 | L.append(' | ' + msg) 152 | 153 | msg = ''.join(L) 154 | sys.stdout.write(msg) 155 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 156 | sys.stdout.write(' ') 157 | 158 | # Go back to the center of the bar. 
159 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 160 | sys.stdout.write('\b') 161 | sys.stdout.write(' %d/%d ' % (current+1, total)) 162 | 163 | if current < total-1: 164 | sys.stdout.write('\r') 165 | else: 166 | sys.stdout.write('\n') 167 | sys.stdout.flush() 168 | 169 | def format_time(seconds): 170 | days = int(seconds / 3600/24) 171 | seconds = seconds - days*3600*24 172 | hours = int(seconds / 3600) 173 | seconds = seconds - hours*3600 174 | minutes = int(seconds / 60) 175 | seconds = seconds - minutes*60 176 | secondsf = int(seconds) 177 | seconds = seconds - secondsf 178 | millis = int(seconds*1000) 179 | 180 | f = '' 181 | i = 1 182 | if days > 0: 183 | f += str(days) + 'D' 184 | i += 1 185 | if hours > 0 and i <= 2: 186 | f += str(hours) + 'h' 187 | i += 1 188 | if minutes > 0 and i <= 2: 189 | f += str(minutes) + 'm' 190 | i += 1 191 | if secondsf > 0 and i <= 2: 192 | f += str(secondsf) + 's' 193 | i += 1 194 | if millis > 0 and i <= 2: 195 | f += str(millis) + 'ms' 196 | i += 1 197 | if f == '': 198 | f = '0ms' 199 | return f 200 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ ImageNet Validation Script 3 | Adapted from https://github.com/rwightman/pytorch-image-models 4 | The script is further extend to evaluate ViP models 5 | 6 | """ 7 | import argparse 8 | import os 9 | import csv 10 | import glob 11 | import time 12 | import logging 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | from collections import OrderedDict 17 | from contextlib import suppress 18 | 19 | from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models 20 | from timm.models.helpers import load_state_dict 21 | from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet 22 | from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy 23 | import models 24 | 25 | has_apex = False 26 | try: 27 | from apex import amp 28 | has_apex = True 29 | except ImportError: 30 | pass 31 | 32 | has_native_amp = False 33 | try: 34 | if getattr(torch.cuda.amp, 'autocast') is not None: 35 | has_native_amp = True 36 | except AttributeError: 37 | pass 38 | 39 | torch.backends.cudnn.benchmark = True 40 | _logger = logging.getLogger('validate') 41 | 42 | 43 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') 44 | parser.add_argument('data', metavar='DIR', 45 | help='path to dataset') 46 | parser.add_argument('--dataset', '-d', metavar='NAME', default='', 47 | help='dataset type (default: ImageFolder/ImageTar if empty)') 48 | parser.add_argument('--split', metavar='NAME', default='validation', 49 | help='dataset split (default: validation)') 50 | parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', 51 | help='model architecture (default: dpn92)') 52 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 53 | help='number of data loading workers (default: 2)') 54 | parser.add_argument('-b', '--batch-size', default=256, type=int, 55 | metavar='N', help='mini-batch size (default: 256)') 56 | parser.add_argument('--img-size', default=None, type=int, 57 | metavar='N', help='Input image dimension, uses model default if empty') 58 | parser.add_argument('--input-size', default=None, nargs=3, type=int, 59 | metavar='N N N', help='Input all image dimensions (d h w, 
e.g. --input-size 3 224 224), uses model default if empty') 60 | parser.add_argument('--crop-pct', default=None, type=float, 61 | metavar='N', help='Input image center crop pct') 62 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 63 | help='Override mean pixel value of dataset') 64 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 65 | help='Override std deviation of of dataset') 66 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 67 | help='Image resize interpolation type (overrides model)') 68 | parser.add_argument('--num-classes', type=int, default=None, 69 | help='Number classes in dataset') 70 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', 71 | help='path to class to idx mapping file (default: "")') 72 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 73 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 74 | parser.add_argument('--log-freq', default=50, type=int, 75 | metavar='N', help='batch logging frequency (default: 10)') 76 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 77 | help='path to latest checkpoint (default: none)') 78 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 79 | help='use pre-trained model') 80 | parser.add_argument('--num-gpu', type=int, default=1, 81 | help='Number of GPUS to use') 82 | parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', 83 | help='disable test time pool') 84 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 85 | help='disable fast prefetcher') 86 | parser.add_argument('--pin-mem', action='store_true', default=False, 87 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 88 | parser.add_argument('--channels-last', action='store_true', default=False, 89 | help='Use channels_last memory layout') 90 | parser.add_argument('--amp', action='store_true', default=False, 91 | help='Use AMP mixed precision. 
Defaults to Apex, fallback to native Torch AMP.') 92 | parser.add_argument('--apex-amp', action='store_true', default=False, 93 | help='Use NVIDIA Apex AMP mixed precision') 94 | parser.add_argument('--native-amp', action='store_true', default=False, 95 | help='Use Native Torch AMP mixed precision') 96 | parser.add_argument('--tf-preprocessing', action='store_true', default=False, 97 | help='Use Tensorflow preprocessing pipeline (require CPU TF installed') 98 | parser.add_argument('--use-ema', dest='use_ema', action='store_true', 99 | help='use ema version of weights if present') 100 | parser.add_argument('--torchscript', dest='torchscript', action='store_true', 101 | help='convert model torchscript for inference') 102 | parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true', 103 | help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance') 104 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 105 | help='Output csv file for validation results (summary)') 106 | parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', 107 | help='Real labels JSON file for imagenet evaluation') 108 | parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', 109 | help='Valid label indices txt file for validation of partial label space') 110 | 111 | 112 | def validate(args): 113 | # might as well try to validate something 114 | args.pretrained = args.pretrained or not args.checkpoint 115 | args.prefetcher = not args.no_prefetcher 116 | amp_autocast = suppress # do nothing 117 | if args.amp: 118 | if has_native_amp: 119 | args.native_amp = True 120 | elif has_apex: 121 | args.apex_amp = True 122 | else: 123 | _logger.warning("Neither APEX or Native Torch AMP is available.") 124 | assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." 125 | if args.native_amp: 126 | amp_autocast = torch.cuda.amp.autocast 127 | _logger.info('Validating in mixed precision with native PyTorch AMP.') 128 | elif args.apex_amp: 129 | _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') 130 | else: 131 | _logger.info('Validating in float32. AMP not enabled.') 132 | 133 | if args.legacy_jit: 134 | set_jit_legacy() 135 | 136 | # create model 137 | model = create_model( 138 | args.model, 139 | pretrained=args.pretrained, 140 | num_classes=args.num_classes, 141 | in_chans=3, 142 | global_pool=args.gp, 143 | scriptable=args.torchscript, 144 | img_size=args.img_size) 145 | if args.num_classes is None: 146 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
147 | args.num_classes = model.num_classes 148 | 149 | if args.checkpoint: 150 | load_checkpoint(model, args.checkpoint, args.use_ema, strict=False) 151 | 152 | param_count = sum([m.numel() for m in model.parameters()]) 153 | _logger.info('Model %s created, param count: %d' % (args.model, param_count)) 154 | 155 | data_config = resolve_data_config(vars(args), model=model, use_test_size=True) 156 | test_time_pool = False 157 | if not args.no_test_pool: 158 | model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) 159 | 160 | if args.torchscript: 161 | torch.jit.optimized_execution(True) 162 | model = torch.jit.script(model) 163 | 164 | model = model.cuda() 165 | if args.apex_amp: 166 | model = amp.initialize(model, opt_level='O1') 167 | 168 | if args.channels_last: 169 | model = model.to(memory_format=torch.channels_last) 170 | 171 | if args.num_gpu > 1: 172 | model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) 173 | 174 | criterion = nn.CrossEntropyLoss().cuda() 175 | 176 | dataset = create_dataset( 177 | root=args.data, name=args.dataset, split=args.split, 178 | load_bytes=args.tf_preprocessing, class_map=args.class_map) 179 | 180 | if args.valid_labels: 181 | with open(args.valid_labels, 'r') as f: 182 | valid_labels = {int(line.rstrip()) for line in f} 183 | valid_labels = [i in valid_labels for i in range(args.num_classes)] 184 | else: 185 | valid_labels = None 186 | 187 | if args.real_labels: 188 | real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels) 189 | else: 190 | real_labels = None 191 | 192 | crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] 193 | loader = create_loader( 194 | dataset, 195 | input_size=data_config['input_size'], 196 | batch_size=args.batch_size, 197 | use_prefetcher=args.prefetcher, 198 | interpolation=data_config['interpolation'], 199 | mean=data_config['mean'], 200 | std=data_config['std'], 201 | num_workers=args.workers, 202 | crop_pct=crop_pct, 203 | pin_memory=args.pin_mem, 204 | tf_preprocessing=args.tf_preprocessing) 205 | 206 | batch_time = AverageMeter() 207 | losses = AverageMeter() 208 | top1 = AverageMeter() 209 | top5 = AverageMeter() 210 | 211 | model.eval() 212 | with torch.no_grad(): 213 | # warmup, reduce variability of first batch time, especially for comparing torchscript vs non 214 | input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() 215 | if args.channels_last: 216 | input = input.contiguous(memory_format=torch.channels_last) 217 | model(input) 218 | end = time.time() 219 | for batch_idx, (input, target) in enumerate(loader): 220 | if args.no_prefetcher: 221 | target = target.cuda() 222 | input = input.cuda() 223 | if args.channels_last: 224 | input = input.contiguous(memory_format=torch.channels_last) 225 | 226 | # compute output 227 | with amp_autocast(): 228 | output = model(input) 229 | if isinstance(output, (tuple, list)): 230 | output = output[0] 231 | if valid_labels is not None: 232 | output = output[:, valid_labels] 233 | loss = criterion(output, target) 234 | 235 | if real_labels is not None: 236 | real_labels.add_result(output) 237 | 238 | # measure accuracy and record loss 239 | acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) 240 | losses.update(loss.item(), input.size(0)) 241 | top1.update(acc1.item(), input.size(0)) 242 | top5.update(acc5.item(), input.size(0)) 243 | 244 | # measure elapsed time 245 | batch_time.update(time.time() - end) 246 | end = time.time() 247 | 248 | if batch_idx % 
args.log_freq == 0: 249 | _logger.info( 250 | 'Test: [{0:>4d}/{1}] ' 251 | 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 252 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 253 | 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 254 | 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( 255 | batch_idx, len(loader), batch_time=batch_time, 256 | rate_avg=input.size(0) / batch_time.avg, 257 | loss=losses, top1=top1, top5=top5)) 258 | 259 | if real_labels is not None: 260 | # real labels mode replaces topk values at the end 261 | top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5) 262 | else: 263 | top1a, top5a = top1.avg, top5.avg 264 | results = OrderedDict( 265 | top1=round(top1a, 4), top1_err=round(100 - top1a, 4), 266 | top5=round(top5a, 4), top5_err=round(100 - top5a, 4), 267 | param_count=round(param_count / 1e6, 2), 268 | img_size=data_config['input_size'][-1], 269 | cropt_pct=crop_pct, 270 | interpolation=data_config['interpolation']) 271 | 272 | _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( 273 | results['top1'], results['top1_err'], results['top5'], results['top5_err'])) 274 | 275 | return results 276 | 277 | 278 | def main(): 279 | setup_default_logging() 280 | args = parser.parse_args() 281 | model_cfgs = [] 282 | model_names = [] 283 | if os.path.isdir(args.checkpoint): 284 | # validate all checkpoints in a path with same model 285 | checkpoints = glob.glob(args.checkpoint + '/*.pth.tar') 286 | checkpoints += glob.glob(args.checkpoint + '/*.pth') 287 | model_names = list_models(args.model) 288 | model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)] 289 | else: 290 | if args.model == 'all': 291 | # validate all models in a list of names with pretrained checkpoints 292 | args.pretrained = True 293 | model_names = list_models(pretrained=True, exclude_filters=['*in21k']) 294 | model_cfgs = [(n, '') for n in model_names] 295 | elif not is_model(args.model): 296 | # model name doesn't exist, try as wildcard filter 297 | model_names = list_models(args.model) 298 | model_cfgs = [(n, '') for n in model_names] 299 | 300 | if len(model_cfgs): 301 | results_file = args.results_file or './results-all.csv' 302 | _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) 303 | results = [] 304 | try: 305 | start_batch_size = args.batch_size 306 | for m, c in model_cfgs: 307 | batch_size = start_batch_size 308 | args.model = m 309 | args.checkpoint = c 310 | result = OrderedDict(model=args.model) 311 | r = {} 312 | while not r and batch_size >= args.num_gpu: 313 | torch.cuda.empty_cache() 314 | try: 315 | args.batch_size = batch_size 316 | print('Validating with batch size: %d' % args.batch_size) 317 | r = validate(args) 318 | except RuntimeError as e: 319 | if batch_size <= args.num_gpu: 320 | print("Validation failed with no ability to reduce batch size. 
Exiting.") 321 | raise e 322 | batch_size = max(batch_size // 2, args.num_gpu) 323 | print("Validation failed, reducing batch size by 50%") 324 | result.update(r) 325 | if args.checkpoint: 326 | result['checkpoint'] = args.checkpoint 327 | results.append(result) 328 | except KeyboardInterrupt as e: 329 | pass 330 | results = sorted(results, key=lambda x: x['top1'], reverse=True) 331 | if len(results): 332 | write_results(results_file, results) 333 | else: 334 | validate(args) 335 | 336 | 337 | def write_results(results_file, results): 338 | with open(results_file, mode='w') as cf: 339 | dw = csv.DictWriter(cf, fieldnames=results[0].keys()) 340 | dw.writeheader() 341 | for r in results: 342 | dw.writerow(r) 343 | cf.flush() 344 | 345 | 346 | if __name__ == '__main__': 347 | main() 348 | --------------------------------------------------------------------------------