├── .gitignore ├── LICENSE ├── README.md ├── configs ├── __init__.py └── train_cfg │ ├── imagenet_train_base_cfg.py │ ├── imagenet_train_ghostnet_cfg.py │ ├── imagenet_train_mbv2_cfg.py │ ├── imagenet_train_mbv3_cfg.py │ └── imagenet_train_res_cfg.py ├── dataset ├── __init__.py ├── im2rec.py ├── imagenet_data_dali.py ├── mk_img_list.py └── torchvision_extension.py ├── imgs └── dali-speed-comp.png ├── model_zoo ├── __init__.py ├── ghostnet.py └── mobilenetv3.py ├── models ├── __init__.py ├── model_derived.py └── operations.py ├── run_apis ├── __init__.py ├── train_dist.py └── trainer.py ├── scripts └── train_dist.sh └── tools ├── __init__.py ├── env.py ├── lr_scheduler.py ├── multadds_count.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dali-pytorch 2 | This is a PyTorch toolkit for accelerating distributed ImageNet training with [NVIDIA DALI](https://github.com/NVIDIA/DALI). Despite the extremely high training speed, the toolkit achieves accuracy similar to or higher than that reported in the original papers.
3 | 4 | ## Main Requirements (recommended) 5 | 6 | * python >= 3.7.0 7 | * pytorch >= 1.2.0 8 | * CUDA >= 10.0 9 | * nvidia-dali-cuda100 >= 0.23.0 10 | * mmcv >= 1.0.5 11 | * mxnet >= 1.6.0 (only used for preparing the dataset) 12 | 13 | ## Train 14 | 1. As shown in the [official evaluation](https://developer.nvidia.com/DALI), the MXNet data format shows notable speed advantages, so we use the MXNet `.rec` dataset format for training. Please prepare the ImageNet dataset as `.rec` following ["Create a Dataset Using RecordIO"](https://mxnet.apache.org/api/faq/recordio); example commands are sketched at the end of this section. 15 | ![speed_comp](./imgs/dali-speed-comp.png) 16 | 17 | 2. Train your model with the following script. It is recommended to run the training code on 8 GPUs with a total batch size of 128x8. Taking MobileNetV2 training as an example, 18 | ``` 19 | sh ./scripts/train_dist.sh 8 ./imagenet-rec imagenet_train_mbv2_cfg.py 20 | ``` 21 | 22 |
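For step 1, a minimal data-preparation sketch using the bundled `dataset/im2rec.py` is given below. The raw-ImageNet paths and the thread count are hypothetical and should be adapted to your setup; the `train`/`val` prefixes match the `train_pref`/`val_pref` names expected by the configs:
```
mkdir -p ./imagenet-rec
# build the image lists (labels are assigned per class sub-folder), then pack them into .rec/.idx
python dataset/im2rec.py --list --recursive ./imagenet-rec/train /path/to/imagenet/train
python dataset/im2rec.py --num-thread 16 ./imagenet-rec/train /path/to/imagenet/train
python dataset/im2rec.py --list --recursive ./imagenet-rec/val /path/to/imagenet/val
python dataset/im2rec.py --num-thread 16 ./imagenet-rec/val /path/to/imagenet/val
```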
36 | *"d": day; "h": hour.* 37 | 38 | ---------------------------------- 39 | 40 | Comparisons between training w/ and w/o dali: 41 | 42 | | Model | Method | Top-1(%) | Time | 43 | |----------------------|-------|---------|---------| 44 | | ResNet18 | w/ dali | 72.15 | **16h** | 45 | | | w/o dali | 72.01 | 47.5h | 46 | 47 | *The w\o dali setting is based on the lmdb data format from the codebase of [DenseNAS](https://github.com/JaminFong/DenseNAS).* 48 | 49 | ## Acknowledgement 50 | This project is in part supported by [Horizon Robotics](https://en.horizon.ai). 51 | 52 | The code of some models is from [ghostnet](https://github.com/huawei-noah/ghostnet/blob/master/pytorch/ghostnet.py) and [mobilenetv3](https://github.com/d-li14/mobilenetv3.pytorch). Thanks for the contribution of these repositories. 53 | 54 | ## Citation 55 | This repository is partially based on [DenseNAS](https://github.com/JaminFong/DenseNAS). If you find this toolkit helpful in your research, please cite 56 | ``` 57 | @inproceedings{fang2019densely, 58 | title={Densely connected search space for more flexible neural architecture search}, 59 | author={Fang, Jiemin and Sun, Yuzhu and Zhang, Qian and Li, Yuan and Liu, Wenyu and Wang, Xinggang}, 60 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 61 | year={2020} 62 | } 63 | ``` -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | from . import * -------------------------------------------------------------------------------- /configs/train_cfg/imagenet_train_base_cfg.py: -------------------------------------------------------------------------------- 1 | model_zoo=dict( 2 | use_model_zoo = False, 3 | model_name = '', 4 | cfg=dict( 5 | init_mode='he_fout', 6 | bn_momentum=0.1, 7 | bn_eps=1e-3 8 | ) 9 | ) 10 | 11 | net_type='mbv2' # mbv2 / res 12 | net_config="""""" 13 | 14 | train_params=dict( 15 | epochs=240, 16 | use_seed=True, 17 | seed=0 18 | ) 19 | 20 | optim=dict( 21 | init_lr=0.5, 22 | min_lr=1e-5, 23 | lr_schedule='cosine', # cosine poly 24 | momentum=0.9, 25 | weight_decay=4e-5, 26 | bn_wo_wd=False, 27 | label_smooth=True, 28 | smooth_alpha=0.1, 29 | use_grad_clip=False, 30 | grad_clip=10, 31 | if_resume=False, 32 | resume=dict( 33 | load_path='', 34 | load_epoch=191, 35 | ), 36 | use_warm_up=False, 37 | warm_up=dict( 38 | epoch=5, 39 | init_lr=1e-5, 40 | target_lr=0.5, 41 | ), 42 | use_multi_stage=False, 43 | multi_stage=dict( 44 | stage_epochs=330 45 | ), 46 | cosine=dict( 47 | use_restart=False, 48 | restart=dict( 49 | lr_period=[10, 20, 40, 80, 160, 320], 50 | lr_step=[0, 10, 30, 70, 150, 310], 51 | ) 52 | ) 53 | ) 54 | 55 | data=dict( 56 | use_dali=True, 57 | num_threads=4, 58 | resize_batch=16, # 32 in default 59 | batch_size=128, 60 | dataset='imagenet', #imagenet 61 | train_pref='train', 62 | val_pref='val', 63 | num_examples=1281167, 64 | input_size=(3,224,224), 65 | # type_interp='INTERP_TRIANGULAR' # INTERP_TRIANGULAR INTERP_LINEAR 66 | min_filter='INTERP_TRIANGULAR', # INTERP_TRIANGULAR INTERP_LINEAR 67 | mag_filter='INTERP_LANCZOS3', # INTERP_TRIANGULAR INTERP_LINEAR 68 | random_sized=dict( 69 | min_scale=0.08 70 | ), 71 | mean=[0.485, 0.456, 0.406], 72 | std=[0.229, 0.224, 0.225], 73 | color=False, 74 | val_shuffle=False 75 | ) 76 | -------------------------------------------------------------------------------- 
/configs/train_cfg/imagenet_train_ghostnet_cfg.py: -------------------------------------------------------------------------------- 1 | _base_ = './imagenet_train_base_cfg.py' 2 | 3 | model_zoo=dict( 4 | use_model_zoo = True, 5 | model_name = 'ghostnet', 6 | ) 7 | 8 | train_params=dict( 9 | epochs=400, 10 | ) 11 | 12 | optim=dict( 13 | weight_decay=4e-5, 14 | bn_wo_wd=True, 15 | if_resume=False, 16 | resume=dict( 17 | load_path='', 18 | load_epoch=191, 19 | ), 20 | use_warm_up=True, 21 | warm_up=dict( 22 | epoch=5, 23 | init_lr=1e-5, 24 | target_lr=0.5, 25 | ), 26 | ) 27 | 28 | data=dict( 29 | num_threads=4, 30 | resize_batch=16, # 32 in default 31 | batch_size=128, 32 | color=True, 33 | ) -------------------------------------------------------------------------------- /configs/train_cfg/imagenet_train_mbv2_cfg.py: -------------------------------------------------------------------------------- 1 | _base_ = './imagenet_train_base_cfg.py' 2 | 3 | model_zoo=dict( 4 | use_model_zoo = False, 5 | model_name = 'mobilenet_v2', 6 | cfg = None 7 | ) 8 | 9 | net_type='mbv2' # mbv2 / res 10 | net_config="""[[32, 16], 'mbconv_k3_t1', [], 0, 1]| 11 | [[16, 24], 'mbconv_k3_t6', ['mbconv_k3_t6'], 1, 2]| 12 | [[24, 32], 'mbconv_k3_t6', ['mbconv_k3_t6', 'mbconv_k3_t6'], 2, 2]| 13 | [[32, 64], 'mbconv_k3_t6', ['mbconv_k3_t6', 'mbconv_k3_t6', 'mbconv_k3_t6'], 3, 2]| 14 | [[64, 96], 'mbconv_k3_t6', ['mbconv_k3_t6', 'mbconv_k3_t6'], 2, 1]| 15 | [[96, 160], 'mbconv_k3_t6', ['mbconv_k3_t6', 'mbconv_k3_t6'], 2, 2]| 16 | [[160, 320], 'mbconv_k3_t6', [], 3, 1]| 17 | [[320, 1280], 'conv1_1']""" 18 | 19 | train_params=dict( 20 | epochs=240, 21 | ) 22 | 23 | optim=dict( 24 | weight_decay=4e-5, 25 | bn_wo_wd=False, 26 | if_resume=False, 27 | resume=dict( 28 | load_path='', 29 | load_epoch=191, 30 | ), 31 | use_warm_up=False, 32 | warm_up=dict( 33 | epoch=5, 34 | init_lr=1e-5, 35 | target_lr=0.5, 36 | ), 37 | use_se=False 38 | ) 39 | 40 | data=dict( 41 | num_threads=4, 42 | resize_batch=16, # 32 in default 43 | batch_size=128, 44 | color=False, 45 | ) -------------------------------------------------------------------------------- /configs/train_cfg/imagenet_train_mbv3_cfg.py: -------------------------------------------------------------------------------- 1 | _base_ = './imagenet_train_base_cfg.py' 2 | 3 | model_zoo=dict( 4 | use_model_zoo = True, 5 | model_name = 'mobilenetv3_large', 6 | ) 7 | 8 | train_params=dict( 9 | epochs=400, 10 | ) 11 | 12 | optim=dict( 13 | weight_decay=4e-5, 14 | bn_wo_wd=True, 15 | if_resume=False, 16 | resume=dict( 17 | load_path='', 18 | load_epoch=191, 19 | ), 20 | use_warm_up=True, 21 | warm_up=dict( 22 | epoch=5, 23 | init_lr=1e-5, 24 | target_lr=0.5, 25 | ), 26 | ) 27 | 28 | data=dict( 29 | num_threads=4, 30 | resize_batch=16, # 32 in default 31 | batch_size=128, 32 | color=True, 33 | ) -------------------------------------------------------------------------------- /configs/train_cfg/imagenet_train_res_cfg.py: -------------------------------------------------------------------------------- 1 | _base_ = './imagenet_train_base_cfg.py' 2 | 3 | model_zoo=dict( 4 | use_model_zoo = True, 5 | model_name = 'resnet18', 6 | cfg = None 7 | ) 8 | 9 | train_params=dict( 10 | epochs=240, 11 | ) 12 | 13 | optim=dict( 14 | weight_decay=4e-5, 15 | bn_wo_wd=False, 16 | if_resume=False, 17 | resume=dict( 18 | load_path='', 19 | load_epoch=191, 20 | ), 21 | use_warm_up=False, 22 | warm_up=dict( 23 | epoch=5, 24 | init_lr=1e-5, 25 | target_lr=0.5, 26 | ), 27 | ) 28 | 29 | data=dict( 30 | num_threads=4, 31 
| resize_batch=16, # 32 in default 32 | batch_size=128, 33 | color=False, 34 | ) -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaminFong/dali-pytorch/7bd5d2380d210a32d24c7309da69c8d2c5db8759/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/im2rec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | from __future__ import print_function 21 | import os 22 | import sys 23 | 24 | curr_path = os.path.abspath(os.path.dirname(__file__)) 25 | sys.path.append(os.path.join(curr_path, "../python")) 26 | import mxnet as mx 27 | import random 28 | import argparse 29 | import cv2 30 | import time 31 | import traceback 32 | 33 | try: 34 | import multiprocessing 35 | except ImportError: 36 | multiprocessing = None 37 | 38 | def list_image(root, recursive, exts): 39 | """Traverses the root directory that contains images and 40 | generates an image list iterator. 41 | Parameters 42 | ---------- 43 | root: string 44 | recursive: bool 45 | exts: string 46 | Returns 47 | ------- 48 | an image iterator that contains all the images under the specified path 49 | """ 50 | 51 | i = 0 52 | if recursive: 53 | cat = {} 54 | for path, dirs, files in os.walk(root, followlinks=True): 55 | dirs.sort() 56 | files.sort() 57 | for fname in files: 58 | fpath = os.path.join(path, fname) 59 | suffix = os.path.splitext(fname)[1].lower() 60 | if os.path.isfile(fpath) and (suffix in exts): 61 | if path not in cat: 62 | cat[path] = len(cat) 63 | yield (i, os.path.relpath(fpath, root), cat[path]) 64 | i += 1 65 | for k, v in sorted(cat.items(), key=lambda x: x[1]): 66 | print(os.path.relpath(k, root), v) 67 | else: 68 | for fname in sorted(os.listdir(root)): 69 | fpath = os.path.join(root, fname) 70 | suffix = os.path.splitext(fname)[1].lower() 71 | if os.path.isfile(fpath) and (suffix in exts): 72 | yield (i, os.path.relpath(fpath, root), 0) 73 | i += 1 74 |
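# Example rows of the .lst file produced from an ImageNet-style folder layout
# (the file names below are hypothetical), as written by write_list further down:
#   0	0.000000	n01440764/n01440764_10026.JPEG
#   1	0.000000	n01440764/n01440764_10027.JPEG
# i.e. an integer index, a float label, and the path relative to root, tab-separated.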
75 | def write_list(path_out, image_list): 76 | """Helper function to write the image list into a file. 77 | The format is as below, 78 | integer_image_index \t float_label_index \t path_to_image 79 | Note that the blank between number and tab is only used for readability. 80 | Parameters 81 | ---------- 82 | path_out: string 83 | image_list: list 84 | """ 85 | with open(path_out, 'w') as fout: 86 | for i, item in enumerate(image_list): 87 | line = '%d\t' % item[0] 88 | for j in item[2:]: 89 | line += '%f\t' % j 90 | line += '%s\n' % item[1] 91 | fout.write(line) 92 | 93 | def make_list(args): 94 | """Generates the .lst file. 95 | Parameters 96 | ---------- 97 | args: object that contains all the arguments 98 | """ 99 | image_list = list_image(args.root, args.recursive, args.exts) 100 | image_list = list(image_list) 101 | if args.shuffle is True: 102 | random.seed(100) 103 | random.shuffle(image_list) 104 | N = len(image_list) 105 | chunk_size = (N + args.chunks - 1) // args.chunks 106 | for i in range(args.chunks): 107 | chunk = image_list[i * chunk_size:(i + 1) * chunk_size] 108 | if args.chunks > 1: 109 | str_chunk = '_%d' % i 110 | else: 111 | str_chunk = '' 112 | sep = int(chunk_size * args.train_ratio) 113 | sep_test = int(chunk_size * args.test_ratio) 114 | if args.train_ratio == 1.0: 115 | write_list(args.prefix + str_chunk + '.lst', chunk) 116 | else: 117 | if args.test_ratio: 118 | write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test]) 119 | if args.train_ratio + args.test_ratio < 1.0: 120 | write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:]) 121 | write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep]) 122 | 123 | def read_list(path_in): 124 | """Reads the .lst file and generates the corresponding iterator. 125 | Parameters 126 | ---------- 127 | path_in: string 128 | Returns 129 | ------- 130 | an item iterator that contains the information in the .lst file 131 | """ 132 | with open(path_in) as fin: 133 | while True: 134 | line = fin.readline() 135 | if not line: 136 | break 137 | line = [i.strip() for i in line.strip().split('\t')] 138 | line_len = len(line) 139 | # check the data format of the .lst file 140 | if line_len < 3: 141 | print('lst should have at least three parts, but only has %s parts for %s' % (line_len, line)) 142 | continue 143 | try: 144 | item = [int(line[0])] + [line[-1]] + [float(i) for i in line[1:-1]] 145 | except Exception as e: 146 | print('Parsing lst met error for %s, detail: %s' % (line, e)) 147 | continue 148 | yield item 149 | 150 | def image_encode(args, i, item, q_out): 151 | """Reads, preprocesses, packs the image and puts it back in the output queue.
152 | Parameters 153 | ---------- 154 | args: object 155 | i: int 156 | item: list 157 | q_out: queue 158 | """ 159 | fullpath = os.path.join(args.root, item[1]) 160 | 161 | if len(item) > 3 and args.pack_label: 162 | header = mx.recordio.IRHeader(0, item[2:], item[0], 0) 163 | else: 164 | header = mx.recordio.IRHeader(0, item[2], item[0], 0) 165 | 166 | if args.pass_through: 167 | try: 168 | with open(fullpath, 'rb') as fin: 169 | img = fin.read() 170 | s = mx.recordio.pack(header, img) 171 | q_out.put((i, s, item)) 172 | except Exception as e: 173 | traceback.print_exc() 174 | print('pack_img error:', item[1], e) 175 | q_out.put((i, None, item)) 176 | return 177 | 178 | try: 179 | img = cv2.imread(fullpath, args.color) 180 | except: 181 | traceback.print_exc() 182 | print('imread error trying to load file: %s ' % fullpath) 183 | q_out.put((i, None, item)) 184 | return 185 | if img is None: 186 | print('imread read blank (None) image for file: %s' % fullpath) 187 | q_out.put((i, None, item)) 188 | return 189 | if args.center_crop: 190 | if img.shape[0] > img.shape[1]: 191 | margin = (img.shape[0] - img.shape[1]) // 2 192 | img = img[margin:margin + img.shape[1], :] 193 | else: 194 | margin = (img.shape[1] - img.shape[0]) // 2 195 | img = img[:, margin:margin + img.shape[0]] 196 | if args.resize: 197 | if img.shape[0] > img.shape[1]: 198 | newsize = (args.resize, img.shape[0] * args.resize // img.shape[1]) 199 | else: 200 | newsize = (img.shape[1] * args.resize // img.shape[0], args.resize) 201 | img = cv2.resize(img, newsize) 202 | 203 | try: 204 | s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding) 205 | q_out.put((i, s, item)) 206 | except Exception as e: 207 | traceback.print_exc() 208 | print('pack_img error on file: %s' % fullpath, e) 209 | q_out.put((i, None, item)) 210 | return 211 | 212 | def read_worker(args, q_in, q_out): 213 | """Function that will be spawned to fetch the image 214 | from the input queue and put it back to output queue. 215 | Parameters 216 | ---------- 217 | args: object 218 | q_in: queue 219 | q_out: queue 220 | """ 221 | while True: 222 | deq = q_in.get() 223 | if deq is None: 224 | break 225 | i, item = deq 226 | image_encode(args, i, item, q_out) 227 | 228 | def write_worker(q_out, fname, working_dir): 229 | """Function that will be spawned to fetch processed image 230 | from the output queue and write to the .rec file. 231 | Parameters 232 | ---------- 233 | q_out: queue 234 | fname: string 235 | working_dir: string 236 | """ 237 | pre_time = time.time() 238 | count = 0 239 | fname = os.path.basename(fname) 240 | fname_rec = os.path.splitext(fname)[0] + '.rec' 241 | fname_idx = os.path.splitext(fname)[0] + '.idx' 242 | record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx), 243 | os.path.join(working_dir, fname_rec), 'w') 244 | buf = {} 245 | more = True 246 | while more: 247 | deq = q_out.get() 248 | if deq is not None: 249 | i, s, item = deq 250 | buf[i] = (s, item) 251 | else: 252 | more = False 253 | while count in buf: 254 | s, item = buf[count] 255 | del buf[count] 256 | if s is not None: 257 | record.write_idx(item[0], s) 258 | 259 | if count % 1000 == 0: 260 | cur_time = time.time() 261 | print('time:', cur_time - pre_time, ' count:', count) 262 | pre_time = cur_time 263 | count += 1 264 | 265 | def parse_args(): 266 | """Defines all arguments. 
267 | Returns 268 | ------- 269 | args object that contains all the params 270 | """ 271 | parser = argparse.ArgumentParser( 272 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 273 | description='Create an image list or \ 274 | make a record database by reading from an image list') 275 | parser.add_argument('prefix', help='prefix of input/output lst and rec files.') 276 | parser.add_argument('root', help='path to folder containing images.') 277 | 278 | cgroup = parser.add_argument_group('Options for creating image lists') 279 | cgroup.add_argument('--list', action='store_true', 280 | help='If this is set im2rec will create image list(s) by traversing root folder\ 281 | and output to .lst.\ 282 | Otherwise im2rec will read .lst and create a database at .rec') 283 | cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg', '.png'], 284 | help='list of acceptable image extensions.') 285 | cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.') 286 | cgroup.add_argument('--train-ratio', type=float, default=1.0, 287 | help='Ratio of images to use for training.') 288 | cgroup.add_argument('--test-ratio', type=float, default=0, 289 | help='Ratio of images to use for testing.') 290 | cgroup.add_argument('--recursive', action='store_true', 291 | help='If true recursively walk through subdirs and assign an unique label\ 292 | to images in each folder. Otherwise only include images in the root folder\ 293 | and give them label 0.') 294 | cgroup.add_argument('--no-shuffle', dest='shuffle', action='store_false', 295 | help='If this is passed, \ 296 | im2rec will not randomize the image order in .lst') 297 | rgroup = parser.add_argument_group('Options for creating database') 298 | rgroup.add_argument('--pass-through', action='store_true', 299 | help='whether to skip transformation and save image as is') 300 | rgroup.add_argument('--resize', type=int, default=0, 301 | help='resize the shorter edge of image to the newsize, original images will\ 302 | be packed by default.') 303 | rgroup.add_argument('--center-crop', action='store_true', 304 | help='specify whether to crop the center image to make it rectangular.') 305 | rgroup.add_argument('--quality', type=int, default=95, 306 | help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9') 307 | rgroup.add_argument('--num-thread', type=int, default=1, 308 | help='number of thread to use for encoding. order of images will be different\ 309 | from the input list if >1. the input list will be modified to match the\ 310 | resulting order.') 311 | rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1], 312 | help='specify the color mode of the loaded image.\ 313 | 1: Loads a color image. Any transparency of image will be neglected. 
It is the default flag.\ 314 | 0: Loads image in grayscale mode.\ 315 | -1: Loads image as such, including the alpha channel.') 316 | rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'], 317 | help='specify the encoding of the images.') 318 | rgroup.add_argument('--pack-label', action='store_true', 319 | help='Whether to also pack a multi-dimensional label in the record file') 320 | args = parser.parse_args() 321 | args.prefix = os.path.abspath(args.prefix) 322 | args.root = os.path.abspath(args.root) 323 | return args 324 | 325 | if __name__ == '__main__': 326 | args = parse_args() 327 | # if '--list' is passed, generate the .lst file(s) 328 | if args.list: 329 | make_list(args) 330 | # otherwise read the .lst file(s) to generate the .rec file(s) 331 | else: 332 | if os.path.isdir(args.prefix): 333 | working_dir = args.prefix 334 | else: 335 | working_dir = os.path.dirname(args.prefix) 336 | files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir) 337 | if os.path.isfile(os.path.join(working_dir, fname))] 338 | count = 0 339 | for fname in files: 340 | if fname.startswith(args.prefix) and fname.endswith('.lst'): 341 | print('Creating .rec file from', fname, 'in', working_dir) 342 | count += 1 343 | image_list = read_list(fname) 344 | if args.shuffle: 345 | image_list = list(image_list) 346 | random.seed(100) 347 | random.shuffle(image_list) 348 | # -- write_record -- # 349 | if args.num_thread > 1 and multiprocessing is not None: 350 | q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)] 351 | q_out = multiprocessing.Queue(1024) 352 | # define the read processes 353 | read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out)) \ 354 | for i in range(args.num_thread)] 355 | # process images with num_thread processes 356 | for p in read_process: 357 | p.start() 358 | # only use one process to write the .rec to avoid race conditions 359 | write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir)) 360 | write_process.start() 361 | # put the image list into the input queues 362 | for i, item in enumerate(image_list): 363 | q_in[i % len(q_in)].put((i, item)) 364 | for q in q_in: 365 | q.put(None) 366 | for p in read_process: 367 | p.join() 368 | 369 | q_out.put(None) 370 | write_process.join() 371 | else: 372 | print('multiprocessing not available, falling back to single-threaded encoding') 373 | try: 374 | import Queue as queue 375 | except ImportError: 376 | import queue 377 | q_out = queue.Queue() 378 | fname = os.path.basename(fname) 379 | fname_rec = os.path.splitext(fname)[0] + '.rec' 380 | fname_idx = os.path.splitext(fname)[0] + '.idx' 381 | record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx), 382 | os.path.join(working_dir, fname_rec), 'w') 383 | cnt = 0 384 | pre_time = time.time() 385 | for i, item in enumerate(image_list): 386 | image_encode(args, i, item, q_out) 387 | if q_out.empty(): 388 | continue 389 | _, s, _ = q_out.get() 390 | record.write_idx(item[0], s) 391 | if cnt % 1000 == 0: 392 | cur_time = time.time() 393 | print('time:', cur_time - pre_time, ' count:', cnt) 394 | pre_time = cur_time 395 | cnt += 1 396 | if not count: 397 | print('Did not find any list file with prefix %s' % args.prefix) -------------------------------------------------------------------------------- /dataset/imagenet_data_dali.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | try: 3 | from nvidia.dali.plugin.pytorch import
DALIClassificationIterator 4 | from nvidia.dali.pipeline import Pipeline 5 | import nvidia.dali.ops as ops 6 | import nvidia.dali.types as types 7 | except ImportError: 8 | raise ImportError("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example.") 9 | 10 | class HybridTrainPipe(Pipeline): 11 | def __init__(self, batch_size, world_size, device_id, data_dir, 12 | data_cfg): 13 | super(HybridTrainPipe, self).__init__(batch_size, 14 | data_cfg.num_threads, 15 | device_id, 16 | seed=12 + device_id) 17 | self.input = ops.MXNetReader(path = osp.join(data_dir, data_cfg.train_pref+".rec"), 18 | index_path = osp.join(data_dir, data_cfg.train_pref+".idx"), 19 | random_shuffle = True, 20 | shard_id = device_id, 21 | num_shards = world_size) 22 | # let the user decide which decoding pipeline works best for the network version being run 23 | self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB) 24 | self.rrc = ops.RandomResizedCrop(device='gpu', 25 | size=data_cfg.input_size[1], 26 | random_area=[data_cfg.random_sized.min_scale, 1.0], 27 | # interp_type=getattr(types, data_cfg.type_interp), 28 | min_filter=getattr(types, data_cfg.min_filter), 29 | mag_filter=getattr(types, data_cfg.mag_filter), 30 | minibatch_size=data_cfg.resize_batch) 31 | if data_cfg.color: 32 | self.color = ops.ColorTwist(device="gpu") 33 | self.bright = ops.Uniform(range=[0.6, 1.4]) 34 | self.cont = ops.Uniform(range=[0.6, 1.4]) 35 | self.sat = ops.Uniform(range=[0.6, 1.4]) 36 | self.cmnp = ops.CropMirrorNormalize(device="gpu", 37 | output_layout=types.NCHW, 38 | mean=[0.485 * 255,0.456 * 255,0.406 * 255], 39 | std=[0.229 * 255,0.224 * 255,0.225 * 255]) 40 | self.coin = ops.CoinFlip(probability=0.5) 41 | 42 | def define_graph(self): 43 | rng = self.coin() 44 | self.jpegs, self.labels = self.input(name="Reader") 45 | images = self.decode(self.jpegs) 46 | images = self.rrc(images) 47 | if hasattr(self, 'color'): 48 | images = self.color(images, 49 | brightness=self.bright(), 50 | contrast=self.cont(), 51 | saturation=self.sat()) 52 | output = self.cmnp(images.gpu(), mirror=rng) 53 | return [output, self.labels] 54 | 55 | class HybridValPipe(Pipeline): 56 | def __init__(self, batch_size, world_size, device_id, data_dir, data_cfg): 57 | super(HybridValPipe, self).__init__(batch_size, 58 | data_cfg.num_threads, 59 | device_id, 60 | seed=12 + device_id) 61 | self.input = ops.MXNetReader(path = osp.join(data_dir, data_cfg.val_pref+".rec"), 62 | index_path = osp.join(data_dir, data_cfg.val_pref+".idx"), 63 | random_shuffle = data_cfg.val_shuffle, 64 | shard_id = device_id, 65 | num_shards = world_size) 66 | self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB) 67 | self.res = ops.Resize(device="gpu", 68 | resize_shorter=256, 69 | # interp_type=getattr(types, data_cfg.type_interp), 70 | min_filter=getattr(types, data_cfg.min_filter), 71 | mag_filter=getattr(types, data_cfg.mag_filter), 72 | minibatch_size=data_cfg.resize_batch) 73 | self.cmnp = ops.CropMirrorNormalize(device="gpu", 74 | output_layout=types.NCHW, 75 | crop=(data_cfg.input_size[1],)*2, 76 | mean=[0.485 * 255,0.456 * 255,0.406 * 255], 77 | std=[0.229 * 255,0.224 * 255,0.225 * 255]) 78 | 79 | def define_graph(self): 80 | self.jpegs, self.labels = self.input(name="Reader") 81 | images = self.decode(self.jpegs) 82 | images = self.res(images) 83 | output = self.cmnp(images) 84 | return [output, self.labels] 85 | 86 | class DALIClassificationIterator_(DALIClassificationIterator): 87 | def next_batch(self): 88 | try: 89 | data = self.__next__()
90 | input = data[0]["data"] 91 | target = data[0]["label"].squeeze().cuda().long() 92 | except StopIteration: 93 | input, target = None, None 94 | return input, target 95 | 96 | 97 | def get_train_loader(batch_size, world_size, device_id, 98 | data_dir, data_cfg): 99 | pipe = HybridTrainPipe(batch_size=batch_size, 100 | world_size=world_size, 101 | device_id=device_id, 102 | data_dir=data_dir, 103 | data_cfg=data_cfg,) 104 | pipe.build() 105 | train_loader = DALIClassificationIterator_(pipe, reader_name="Reader", 106 | fill_last_batch=False, auto_reset=True) 107 | return train_loader 108 | 109 | 110 | def get_val_loader(batch_size, world_size, device_id, 111 | data_dir, data_cfg): 112 | pipe = HybridValPipe(batch_size=batch_size, 113 | world_size=world_size, 114 | device_id=device_id, 115 | data_dir=data_dir, 116 | data_cfg=data_cfg) 117 | pipe.build() 118 | val_loader = DALIClassificationIterator_(pipe, reader_name="Reader", 119 | fill_last_batch=False, auto_reset=True) 120 | return val_loader 121 | 122 | def get_data_loader(data_cfg, world_size, device_id, data_dir): 123 | train_loader = get_train_loader(data_cfg.batch_size, 124 | world_size, device_id, 125 | data_dir, data_cfg) 126 | val_loader = get_val_loader(data_cfg.batch_size, 127 | world_size, device_id, 128 | data_dir, data_cfg) 129 | return train_loader, val_loader --------------------------------------------------------------------------------
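A hypothetical single-GPU usage sketch of the loaders defined above (in the repo they are driven by `run_apis/train_dist.py` under `torch.distributed`); it requires a GPU plus the DALI install from the requirements, and the paths are examples:
```python
from mmcv import Config
from dataset.imagenet_data_dali import get_data_loader

cfg = Config.fromfile('configs/train_cfg/imagenet_train_mbv2_cfg.py')
train_loader, val_loader = get_data_loader(cfg.data, world_size=1, device_id=0,
                                           data_dir='./imagenet-rec')
images, targets = train_loader.next_batch()  # CUDA tensors; (None, None) once the epoch is exhausted
```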
/dataset/mk_img_list.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | def get_list(data_path, output_path): 5 | for split in os.listdir(data_path): 6 | split_path = os.path.join(data_path, split) 7 | if not os.path.isdir(split_path): 8 | continue 9 | f = open(os.path.join(output_path, split + '_datalist'), 'a+') 10 | for sub in os.listdir(split_path): 11 | sub_path = os.path.join(split_path, sub) 12 | if not os.path.isdir(sub_path): 13 | continue 14 | for image in os.listdir(sub_path): 15 | image_name = sub + '/' + image 16 | f.writelines(image_name + '\n') 17 | f.close() 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser("Params") 22 | parser.add_argument('--image_path', type=str, default='', help='the path of the images') 23 | parser.add_argument('--output_path', type=str, default='', help='the output path of the image list file') 24 | args = parser.parse_args() 25 | 26 | get_list(args.image_path, args.output_path) 27 | -------------------------------------------------------------------------------- /dataset/torchvision_extension.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as transforms 3 | from torchvision.transforms import functional as F 4 | 5 | # In this file some more transformations (apart from the ones defined in torchvision.transforms) 6 | # are added. They are particularly helpful for training on ImageNet, in the style of the transforms 7 | # used by fb.resnet: https://github.com/facebook/fb.resnet.torch/blob/master/datasets/imagenet.lua 8 | 9 | # This file is taken from a proposed pull request on the torchvision github project. 10 | # At the moment this pull request has not been accepted yet, which is why it is reproduced here. 11 | # Link to the pull request: https://github.com/pytorch/vision/pull/27/files 12 | 13 | class Lighting(object): 14 | 15 | """Lighting noise (AlexNet-style PCA-based noise)""" 16 | 17 | def __init__(self, alphastd, eigval, eigvec): 18 | self.alphastd = alphastd 19 | self.eigval = eigval 20 | self.eigvec = eigvec 21 | 22 | def __call__(self, img): 23 | # img is supposed to be a torch tensor 24 | 25 | if self.alphastd == 0: 26 | return img 27 | 28 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 29 | rgb = self.eigvec.type_as(img).clone()\ 30 | .mul(alpha.view(1, 3).expand(3, 3))\ 31 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 32 | .sum(1).squeeze() 33 | 34 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 35 | 36 | 37 | class RandomScale(object): 38 | 39 | """ResNet style data augmentation""" 40 | 41 | def __init__(self, minSize, maxSize): 42 | self.minSize = minSize 43 | self.maxSize = maxSize 44 | 45 | def __call__(self, img): 46 | 47 | targetSz = int(round(random.uniform(self.minSize, self.maxSize))) 48 | 49 | return F.resize(img, targetSz) 50 | 51 | -------------------------------------------------------------------------------- /imgs/dali-speed-comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaminFong/dali-pytorch/7bd5d2380d210a32d24c7309da69c8d2c5db8759/imgs/dali-speed-comp.png -------------------------------------------------------------------------------- /model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .ghostnet import ghostnet 2 | from .mobilenetv3 import mobilenetv3_large 3 | from .mobilenetv3 import mobilenetv3_small 4 | 5 | from torchvision.models import mobilenet_v2 6 | from torchvision.models import resnet18, resnet34, resnet50 -------------------------------------------------------------------------------- /model_zoo/ghostnet.py: -------------------------------------------------------------------------------- 1 | # 2020.06.09-Changed for building GhostNet 2 | # Huawei Technologies Co., Ltd. 3 | """ 4 | From https://github.com/huawei-noah/ghostnet/blob/master/pytorch/ghostnet.py 5 | --------------------------------------------- 6 | Creates a GhostNet Model as defined in: 7 | GhostNet: More Features from Cheap Operations by Kai Han, Yunhe Wang, Qi Tian, Jianyuan Guo, Chunjing Xu, Chang Xu. 8 | https://arxiv.org/abs/1911.11907 9 | Modified from https://github.com/d-li14/mobilenetv3.pytorch and https://github.com/rwightman/pytorch-image-models 10 | """ 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import math 15 | 16 | 17 | __all__ = ['ghostnet'] 18 | 19 | 20 | def _make_divisible(v, divisor, min_value=None): 21 | """ 22 | This function is taken from the original tf repo. 23 | It ensures that all layers have a channel number that is divisible by 8 24 | It can be seen here: 25 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 26 | """ 27 | if min_value is None: 28 | min_value = divisor 29 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 30 | # Make sure that round down does not go down by more than 10%. 31 | if new_v < 0.9 * v: 32 | new_v += divisor 33 | return new_v 34 | 35 | 36 | def hard_sigmoid(x, inplace: bool = False): 37 | if inplace: 38 | return x.add_(3.).clamp_(0., 6.).div_(6.) 39 | else: 40 | return F.relu6(x + 3.) / 6.
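# A quick illustration of the two helpers above (hypothetical values):
#   _make_divisible(17, 8) == 16 — channel counts are rounded to the nearest multiple of
#   `divisor`, then bumped back up by one `divisor` if rounding lost more than 10% of v.
#   hard_sigmoid(x) == relu6(x + 3) / 6, the piecewise-linear gate used by SqueezeExcite below.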
41 | 42 | 43 | class SqueezeExcite(nn.Module): 44 | def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, 45 | act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_): 46 | super(SqueezeExcite, self).__init__() 47 | self.gate_fn = gate_fn 48 | reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) 49 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 50 | self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) 51 | self.act1 = act_layer(inplace=True) 52 | self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) 53 | 54 | def forward(self, x): 55 | x_se = self.avg_pool(x) 56 | x_se = self.conv_reduce(x_se) 57 | x_se = self.act1(x_se) 58 | x_se = self.conv_expand(x_se) 59 | x = x * self.gate_fn(x_se) 60 | return x 61 | 62 | 63 | class ConvBnAct(nn.Module): 64 | def __init__(self, in_chs, out_chs, kernel_size, 65 | stride=1, act_layer=nn.ReLU): 66 | super(ConvBnAct, self).__init__() 67 | self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False) 68 | self.bn1 = nn.BatchNorm2d(out_chs) 69 | self.act1 = act_layer(inplace=True) 70 | 71 | def forward(self, x): 72 | x = self.conv(x) 73 | x = self.bn1(x) 74 | x = self.act1(x) 75 | return x 76 | 77 | 78 | class GhostModule(nn.Module): 79 | def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): 80 | super(GhostModule, self).__init__() 81 | self.oup = oup 82 | init_channels = math.ceil(oup / ratio) 83 | new_channels = init_channels*(ratio-1) 84 | 85 | self.primary_conv = nn.Sequential( 86 | nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), 87 | nn.BatchNorm2d(init_channels), 88 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 89 | ) 90 | 91 | self.cheap_operation = nn.Sequential( 92 | nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), 93 | nn.BatchNorm2d(new_channels), 94 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 95 | ) 96 | 97 | def forward(self, x): 98 | x1 = self.primary_conv(x) 99 | x2 = self.cheap_operation(x1) 100 | out = torch.cat([x1,x2], dim=1) 101 | return out[:,:self.oup,:,:] 102 | 103 | 104 | class GhostBottleneck(nn.Module): 105 | """ Ghost bottleneck w/ optional SE""" 106 | 107 | def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3, 108 | stride=1, act_layer=nn.ReLU, se_ratio=0.): 109 | super(GhostBottleneck, self).__init__() 110 | has_se = se_ratio is not None and se_ratio > 0. 
111 | self.stride = stride 112 | 113 | # Point-wise expansion 114 | self.ghost1 = GhostModule(in_chs, mid_chs, relu=True) 115 | 116 | # Depth-wise convolution 117 | if self.stride > 1: 118 | self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride, 119 | padding=(dw_kernel_size-1)//2, 120 | groups=mid_chs, bias=False) 121 | self.bn_dw = nn.BatchNorm2d(mid_chs) 122 | 123 | # Squeeze-and-excitation 124 | if has_se: 125 | self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) 126 | else: 127 | self.se = None 128 | 129 | # Point-wise linear projection 130 | self.ghost2 = GhostModule(mid_chs, out_chs, relu=False) 131 | 132 | # shortcut 133 | if (in_chs == out_chs and self.stride == 1): 134 | self.shortcut = nn.Sequential() 135 | else: 136 | self.shortcut = nn.Sequential( 137 | nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride, 138 | padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), 139 | nn.BatchNorm2d(in_chs), 140 | nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), 141 | nn.BatchNorm2d(out_chs), 142 | ) 143 | 144 | 145 | def forward(self, x): 146 | residual = x 147 | 148 | # 1st ghost bottleneck 149 | x = self.ghost1(x) 150 | 151 | # Depth-wise convolution 152 | if self.stride > 1: 153 | x = self.conv_dw(x) 154 | x = self.bn_dw(x) 155 | 156 | # Squeeze-and-excitation 157 | if self.se is not None: 158 | x = self.se(x) 159 | 160 | # 2nd ghost bottleneck 161 | x = self.ghost2(x) 162 | 163 | x += self.shortcut(residual) 164 | return x 165 | 166 | 167 | class GhostNet(nn.Module): 168 | def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, 169 | init_mode='he_fout', bn_momentum=0.1, bn_eps=1e-3): 170 | super(GhostNet, self).__init__() 171 | # setting of inverted residual blocks 172 | self.cfgs = cfgs 173 | self.dropout = dropout 174 | 175 | # building first layer 176 | output_channel = _make_divisible(16 * width, 4) 177 | self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False) 178 | self.bn1 = nn.BatchNorm2d(output_channel) 179 | self.act1 = nn.ReLU(inplace=True) 180 | input_channel = output_channel 181 | 182 | # building inverted residual blocks 183 | stages = [] 184 | block = GhostBottleneck 185 | for cfg in self.cfgs: 186 | layers = [] 187 | for k, exp_size, c, se_ratio, s in cfg: 188 | output_channel = _make_divisible(c * width, 4) 189 | hidden_channel = _make_divisible(exp_size * width, 4) 190 | layers.append(block(input_channel, hidden_channel, output_channel, k, s, 191 | se_ratio=se_ratio)) 192 | input_channel = output_channel 193 | stages.append(nn.Sequential(*layers)) 194 | 195 | output_channel = _make_divisible(exp_size * width, 4) 196 | stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1))) 197 | input_channel = output_channel 198 | 199 | self.blocks = nn.Sequential(*stages) 200 | 201 | # building last several layers 202 | output_channel = 1280 203 | self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) 204 | self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True) 205 | self.act2 = nn.ReLU(inplace=True) 206 | self.classifier = nn.Linear(output_channel, num_classes) 207 | self._initialize_weights() 208 | self.set_bn_param(bn_momentum, bn_eps) 209 | 210 | def forward(self, x): 211 | x = self.conv_stem(x) 212 | x = self.bn1(x) 213 | x = self.act1(x) 214 | x = self.blocks(x) 215 | x = self.global_pool(x) 216 | x = self.conv_head(x) 217 | x = self.act2(x) 218 | x = x.view(x.size(0), -1) 219 | if self.dropout > 0.: 220 | x = F.dropout(x, p=self.dropout, training=self.training) 221 | x = 
self.classifier(x) 222 | return x 223 | 224 | def _initialize_weights(self): 225 | for m in self.modules(): 226 | if isinstance(m, nn.Conv2d): 227 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 228 | elif isinstance(m, nn.BatchNorm2d): 229 | m.weight.data.fill_(1) 230 | m.bias.data.zero_() 231 | 232 | def set_bn_param(self, bn_momentum, bn_eps): 233 | for m in self.modules(): 234 | if isinstance(m, nn.BatchNorm2d): 235 | m.momentum = bn_momentum 236 | m.eps = bn_eps 237 | return 238 | 239 | 240 | def ghostnet(**kwargs): 241 | """ 242 | Constructs a GhostNet model 243 | """ 244 | cfgs = [ 245 | # k, t, c, SE, s 246 | # stage1 247 | [[3, 16, 16, 0, 1]], 248 | # stage2 249 | [[3, 48, 24, 0, 2]], 250 | [[3, 72, 24, 0, 1]], 251 | # stage3 252 | [[5, 72, 40, 0.25, 2]], 253 | [[5, 120, 40, 0.25, 1]], 254 | # stage4 255 | [[3, 240, 80, 0, 2]], 256 | [[3, 200, 80, 0, 1], 257 | [3, 184, 80, 0, 1], 258 | [3, 184, 80, 0, 1], 259 | [3, 480, 112, 0.25, 1], 260 | [3, 672, 112, 0.25, 1] 261 | ], 262 | # stage5 263 | [[5, 672, 160, 0.25, 2]], 264 | [[5, 960, 160, 0, 1], 265 | [5, 960, 160, 0.25, 1], 266 | [5, 960, 160, 0, 1], 267 | [5, 960, 160, 0.25, 1] 268 | ] 269 | ] 270 | return GhostNet(cfgs, **kwargs) 271 | 272 | 273 | if __name__=='__main__': 274 | model = ghostnet() 275 | model.eval() 276 | print(model) 277 | input = torch.randn(32,3,320,256) 278 | y = model(input) 279 | print(y.size()) -------------------------------------------------------------------------------- /model_zoo/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/d-li14/mobilenetv3.pytorch 3 | ---------------------------------------------------- 4 | Creates a MobileNetV3 Model as defined in: 5 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019). 6 | Searching for MobileNetV3 7 | arXiv preprint arXiv:1905.02244. 8 | """ 9 | 10 | import torch.nn as nn 11 | import math 12 | 13 | 14 | __all__ = ['mobilenetv3_large', 'mobilenetv3_small'] 15 | 16 | def _make_divisible(v, divisor, min_value=None): 17 | """ 18 | This function is taken from the original tf repo. 19 | It ensures that all layers have a channel number that is divisible by 8 20 | It can be seen here: 21 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 22 | :param v: 23 | :param divisor: 24 | :param min_value: 25 | :return: 26 | """ 27 | if min_value is None: 28 | min_value = divisor 29 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 30 | # Make sure that round down does not go down by more than 10%. 
31 | if new_v < 0.9 * v: 32 | new_v += divisor 33 | return new_v 34 | 35 | 36 | class h_sigmoid(nn.Module): 37 | def __init__(self, inplace=True): 38 | super(h_sigmoid, self).__init__() 39 | self.relu = nn.ReLU6(inplace=inplace) 40 | 41 | def forward(self, x): 42 | return self.relu(x + 3) / 6 43 | 44 | 45 | class h_swish(nn.Module): 46 | def __init__(self, inplace=True): 47 | super(h_swish, self).__init__() 48 | self.sigmoid = h_sigmoid(inplace=inplace) 49 | 50 | def forward(self, x): 51 | return x * self.sigmoid(x) 52 | 53 | 54 | class SELayer(nn.Module): 55 | def __init__(self, channel, reduction=4): 56 | super(SELayer, self).__init__() 57 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 58 | self.fc = nn.Sequential( 59 | nn.Linear(channel, _make_divisible(channel // reduction, 8)), 60 | nn.ReLU(inplace=True), 61 | nn.Linear(_make_divisible(channel // reduction, 8), channel), 62 | h_sigmoid() 63 | ) 64 | 65 | def forward(self, x): 66 | b, c, _, _ = x.size() 67 | y = self.avg_pool(x).view(b, c) 68 | y = self.fc(y).view(b, c, 1, 1) 69 | return x * y 70 | 71 | 72 | def conv_3x3_bn(inp, oup, stride): 73 | return nn.Sequential( 74 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 75 | nn.BatchNorm2d(oup), 76 | h_swish() 77 | ) 78 | 79 | 80 | def conv_1x1_bn(inp, oup): 81 | return nn.Sequential( 82 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 83 | nn.BatchNorm2d(oup), 84 | h_swish() 85 | ) 86 | 87 | 88 | class InvertedResidual(nn.Module): 89 | def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs): 90 | super(InvertedResidual, self).__init__() 91 | assert stride in [1, 2] 92 | 93 | self.identity = stride == 1 and inp == oup 94 | 95 | if inp == hidden_dim: 96 | self.conv = nn.Sequential( 97 | # dw 98 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False), 99 | nn.BatchNorm2d(hidden_dim), 100 | h_swish() if use_hs else nn.ReLU(inplace=True), 101 | # Squeeze-and-Excite 102 | SELayer(hidden_dim) if use_se else nn.Identity(), 103 | # pw-linear 104 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 105 | nn.BatchNorm2d(oup), 106 | ) 107 | else: 108 | self.conv = nn.Sequential( 109 | # pw 110 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 111 | nn.BatchNorm2d(hidden_dim), 112 | h_swish() if use_hs else nn.ReLU(inplace=True), 113 | # dw 114 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False), 115 | nn.BatchNorm2d(hidden_dim), 116 | h_swish() if use_hs else nn.ReLU(inplace=True), 117 | # Squeeze-and-Excite 118 | SELayer(hidden_dim) if use_se else nn.Identity(), 119 | # pw-linear 120 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 121 | nn.BatchNorm2d(oup), 122 | ) 123 | 124 | def forward(self, x): 125 | if self.identity: 126 | return x + self.conv(x) 127 | else: 128 | return self.conv(x) 129 | 130 | 131 | class MobileNetV3(nn.Module): 132 | def __init__(self, cfgs, mode, num_classes=1000, width_mult=1., 133 | init_mode='he_fout', bn_momentum=0.1, bn_eps=1e-3): 134 | super(MobileNetV3, self).__init__() 135 | # setting of inverted residual blocks 136 | self.cfgs = cfgs 137 | assert mode in ['large', 'small'] 138 | 139 | # building first layer 140 | input_channel = _make_divisible(16 * width_mult, 8) 141 | layers = [conv_3x3_bn(3, input_channel, 2)] 142 | # building inverted residual blocks 143 | block = InvertedResidual 144 | for k, t, c, use_se, use_hs, s in self.cfgs: 145 | output_channel = _make_divisible(c * width_mult, 8) 146 | exp_size = 
_make_divisible(input_channel * t, 8) 147 | layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs)) 148 | input_channel = output_channel 149 | self.features = nn.Sequential(*layers) 150 | # building last several layers 151 | self.conv = conv_1x1_bn(input_channel, exp_size) 152 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 153 | output_channel = {'large': 1280, 'small': 1024} 154 | output_channel = _make_divisible(output_channel[mode] * width_mult, 8) if width_mult > 1.0 else output_channel[mode] 155 | self.classifier = nn.Sequential( 156 | nn.Linear(exp_size, output_channel), 157 | h_swish(), 158 | nn.Dropout(0.2), 159 | nn.Linear(output_channel, num_classes), 160 | ) 161 | 162 | self._initialize_weights() 163 | self.set_bn_param(bn_momentum, bn_eps) 164 | 165 | def forward(self, x): 166 | x = self.features(x) 167 | x = self.conv(x) 168 | x = self.avgpool(x) 169 | x = x.view(x.size(0), -1) 170 | x = self.classifier(x) 171 | return x 172 | 173 | def _initialize_weights(self): 174 | for m in self.modules(): 175 | if isinstance(m, nn.Conv2d): 176 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 177 | if m.bias is not None: 178 | m.bias.data.zero_() 179 | elif isinstance(m, nn.BatchNorm2d): 180 | m.weight.data.fill_(1) 181 | m.bias.data.zero_() 182 | elif isinstance(m, nn.Linear): 183 | m.weight.data.normal_(0, 0.01) 184 | m.bias.data.zero_() 185 | 186 | def set_bn_param(self, bn_momentum, bn_eps): 187 | for m in self.modules(): 188 | if isinstance(m, nn.BatchNorm2d): 189 | m.momentum = bn_momentum 190 | m.eps = bn_eps 191 | return 192 | 193 | 194 | def mobilenetv3_large(**kwargs): 195 | """ 196 | Constructs a MobileNetV3-Large model 197 | """ 198 | cfgs = [ 199 | # k, t, c, SE, HS, s 200 | [3, 1, 16, 0, 0, 1], 201 | [3, 4, 24, 0, 0, 2], 202 | [3, 3, 24, 0, 0, 1], 203 | [5, 3, 40, 1, 0, 2], 204 | [5, 3, 40, 1, 0, 1], 205 | [5, 3, 40, 1, 0, 1], 206 | [3, 6, 80, 0, 1, 2], 207 | [3, 2.5, 80, 0, 1, 1], 208 | [3, 2.3, 80, 0, 1, 1], 209 | [3, 2.3, 80, 0, 1, 1], 210 | [3, 6, 112, 1, 1, 1], 211 | [3, 6, 112, 1, 1, 1], 212 | [5, 6, 160, 1, 1, 2], 213 | [5, 6, 160, 1, 1, 1], 214 | [5, 6, 160, 1, 1, 1] 215 | ] 216 | return MobileNetV3(cfgs, mode='large', **kwargs) 217 | 218 | 219 | def mobilenetv3_small(**kwargs): 220 | """ 221 | Constructs a MobileNetV3-Small model 222 | """ 223 | cfgs = [ 224 | # k, t, c, SE, HS, s 225 | [3, 1, 16, 1, 0, 2], 226 | [3, 4.5, 24, 0, 0, 2], 227 | [3, 3.67, 24, 0, 0, 1], 228 | [5, 4, 40, 1, 1, 2], 229 | [5, 6, 40, 1, 1, 1], 230 | [5, 6, 40, 1, 1, 1], 231 | [5, 3, 48, 1, 1, 1], 232 | [5, 3, 48, 1, 1, 1], 233 | [5, 6, 96, 1, 1, 2], 234 | [5, 6, 96, 1, 1, 1], 235 | [5, 6, 96, 1, 1, 1], 236 | ] 237 | 238 | return MobileNetV3(cfgs, mode='small', **kwargs) 239 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaminFong/dali-pytorch/7bd5d2380d210a32d24c7309da69c8d2c5db8759/models/__init__.py -------------------------------------------------------------------------------- /models/model_derived.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .operations import OPS 7 | from tools.utils import parse_net_config 8 | 9 | 10 | class Block(nn.Module): 11 | 12 | def __init__(self, in_ch, block_ch, head_op, stack_ops, stride, 
use_se=False): 13 | super(Block, self).__init__() 14 | self.head_layer = OPS[head_op](in_ch, block_ch, stride, use_se=use_se, 15 | affine=True, track_running_stats=True) 16 | 17 | modules = [] 18 | for stack_op in stack_ops: 19 | modules.append(OPS[stack_op](block_ch, block_ch, 1, use_se=use_se, 20 | affine=True, track_running_stats=True)) 21 | self.stack_layers = nn.Sequential(*modules) 22 | 23 | def forward(self, x): 24 | x = self.head_layer(x) 25 | x = self.stack_layers(x) 26 | return x 27 | 28 | 29 | class Conv1_1_Block(nn.Module): 30 | 31 | def __init__(self, in_ch, block_ch): 32 | super(Conv1_1_Block, self).__init__() 33 | self.conv1_1 = nn.Sequential( 34 | nn.Conv2d(in_channels=in_ch, out_channels=block_ch, 35 | kernel_size=1, stride=1, padding=0, bias=False), 36 | nn.BatchNorm2d(block_ch), 37 | nn.ReLU6(inplace=True) 38 | ) 39 | 40 | def forward(self, x): 41 | return self.conv1_1(x) 42 | 43 | 44 | class MBV2_Net(nn.Module): 45 | def __init__(self, net_config, config=None): 46 | """ 47 | net_config=[[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride] 48 | """ 49 | super(MBV2_Net, self).__init__() 50 | self.config = config 51 | self.net_config = parse_net_config(net_config) 52 | self.in_chs = self.net_config[0][0][0] 53 | self._num_classes = 1000 54 | use_se = False 55 | if hasattr(config.optim, 'use_se'): 56 | use_se = config.optim.use_se 57 | 58 | self.input_block = nn.Sequential( 59 | nn.Conv2d(in_channels=3, out_channels=self.in_chs, kernel_size=3, 60 | stride=2, padding=1, bias=False), 61 | nn.BatchNorm2d(self.in_chs), 62 | nn.ReLU6(inplace=True) 63 | ) 64 | self.blocks = nn.ModuleList() 65 | for cfg in self.net_config: 66 | if cfg[1] == 'conv1_1': 67 | continue 68 | self.blocks.append(Block(cfg[0][0], cfg[0][1], 69 | cfg[1], cfg[2], cfg[-1], use_se)) 70 | 71 | if self.net_config[-1][1] == 'conv1_1': 72 | block_last_dim = self.net_config[-1][0][0] 73 | last_dim = self.net_config[-1][0][1] 74 | else: 75 | block_last_dim = last_dim = self.net_config[-1][0][1]  # also define last_dim here, so the 1x1 head and the classifier stay well-formed when the config has no trailing conv1_1 76 | self.conv1_1_block = Conv1_1_Block(block_last_dim, last_dim) 77 | 78 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 79 | self.classifier = nn.Linear(last_dim, self._num_classes) 80 | 81 | self.init_model() 82 | self.set_bn_param(0.1, 0.001) 83 | 84 | 85 | def forward(self, x): 86 | block_data = self.input_block(x) 87 | for i, block in enumerate(self.blocks): 88 | block_data = block(block_data) 89 | block_data = self.conv1_1_block(block_data) 90 | 91 | out = self.global_pooling(block_data) 92 | logits = self.classifier(out.view(out.size(0), -1)) 93 | 94 | return logits 95 | 96 | def init_model(self, model_init='he_fout', init_div_groups=True): 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | if model_init == 'he_fout': 100 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 101 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 102 | if init_div_groups: 103 | n /= m.groups 104 | m.weight.data.normal_(0, math.sqrt(2. / n)) 105 | elif model_init == 'he_fin': 106 | # nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 107 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 108 | if init_div_groups: 109 | n /= m.groups 110 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 111 | else: 112 | raise NotImplementedError 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | elif isinstance(m, nn.Linear): 117 | if m.bias is not None: 118 | m.bias.data.zero_() 119 | elif isinstance(m, nn.BatchNorm1d): 120 | m.weight.data.fill_(1) 121 | m.bias.data.zero_() 122 | 123 | def set_bn_param(self, bn_momentum, bn_eps): 124 | for m in self.modules(): 125 | if isinstance(m, nn.BatchNorm2d): 126 | m.momentum = bn_momentum 127 | m.eps = bn_eps 128 | return 129 | 130 | 131 | class RES_Net(nn.Module): 132 | def __init__(self, net_config, config=None): 133 | """ 134 | net_config=[[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride] 135 | """ 136 | super(RES_Net, self).__init__() 137 | self.config = config 138 | self.net_config = parse_net_config(net_config) 139 | self.in_chs = self.net_config[0][0][0] 140 | self._num_classes = 1000 141 | 142 | self.input_block = nn.Sequential( 143 | nn.Conv2d(in_channels=3, out_channels=self.in_chs, kernel_size=3, 144 | stride=2, padding=1, bias=False), 145 | nn.BatchNorm2d(self.in_chs), 146 | nn.ReLU6(inplace=True), 147 | ) 148 | self.blocks = nn.ModuleList() 149 | for config in self.net_config: 150 | self.blocks.append(Block(config[0][0], config[0][1], 151 | config[1], config[2], config[-1], use_se=False)) 152 | 153 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) 154 | if self.net_config[-1][1] == 'bottle_neck': 155 | last_dim = self.net_config[-1][0][-1] * 4 156 | else: 157 | last_dim = self.net_config[-1][0][1] 158 | self.classifier = nn.Linear(last_dim, self._num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | if m.affine==True: 165 | nn.init.constant_(m.weight, 1) 166 | nn.init.constant_(m.bias, 0) 167 | 168 | def forward(self, x): 169 | block_data = self.input_block(x) 170 | for i, block in enumerate(self.blocks): 171 | block_data = block(block_data) 172 | 173 | out = self.global_pooling(block_data) 174 | out = torch.flatten(out, 1) 175 | logits = self.classifier(out) 176 | return logits 177 | 178 | 179 | class PLAIN_Net(MBV2_Net): 180 | def __init__(self, *args, **kwargs): 181 | super(PLAIN_Net, self).__init__(*args, **kwargs) 182 | -------------------------------------------------------------------------------- /models/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | OPS = { 5 | 'mbconv_k3_t1': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=1, use_se=use_se, affine=affine, track_running_stats=track_running_stats), 6 | 'mbconv_k3_t3': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=3, use_se=use_se, affine=affine, track_running_stats=track_running_stats), 7 | 'mbconv_k3_t6': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=6, use_se=use_se, affine=affine, track_running_stats=track_running_stats), 8 | 'mbconv_k5_t1': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 5, stride, 2, t=1, use_se=use_se, affine=affine, track_running_stats=track_running_stats), 9 | 'mbconv_k5_t3': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 5, stride, 2, t=3, use_se=use_se, affine=affine, 
track_running_stats=track_running_stats), 10 | 'mbconv_k5_t6': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 5, stride, 2, t=6, use_se=use_se, affine=affine, track_running_stats=track_running_stats), 11 | 'mbconv_k7_t1': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 7, stride, 3, t=1, use_se=use_se, affine=affine, track_running_stats=track_running_stats), 12 | 'mbconv_k7_t3': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 7, stride, 3, t=3, use_se=use_se, affine=affine, track_running_stats=track_running_stats), 13 | 'mbconv_k7_t6': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 7, stride, 3, t=6, use_se=use_se, affine=affine, track_running_stats=track_running_stats), 14 | 'basic_block': lambda C_in, C_out, stride, use_se, affine, track_running_stats: BasicBlock(C_in, C_out, stride, affine=affine, track_running_stats=track_running_stats), 15 | 'bottle_neck': lambda C_in, C_out, stride, use_se, affine, track_running_stats: Bottleneck(C_in, C_out, stride, affine=affine, track_running_stats=track_running_stats), 16 | 'skip_connect': lambda C_in, C_out, stride, use_se, affine, track_running_stats: Skip(C_in, C_out, stride, affine=affine, track_running_stats=track_running_stats), 17 | } 18 | 19 | 20 | class MBConv(nn.Module): 21 | def __init__(self, C_in, C_out, kernel_size, stride, padding, t=3, affine=True, 22 | track_running_stats=True, use_se=False): 23 | super(MBConv, self).__init__() 24 | if t > 1: 25 | C_hidden = C_in*t 26 | self._expand_conv = nn.Sequential( 27 | nn.Conv2d(C_in, C_hidden, kernel_size=1, stride=1, padding=0, groups=1, bias=False), 28 | nn.BatchNorm2d(C_hidden, affine=affine, track_running_stats=track_running_stats), 29 | nn.ReLU6(inplace=True)) 30 | 31 | self._depthwise_conv = nn.Sequential( 32 | nn.Conv2d(C_hidden, C_hidden, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_hidden, bias=False), 33 | nn.BatchNorm2d(C_hidden, affine=affine, track_running_stats=track_running_stats), 34 | nn.ReLU6(inplace=True)) 35 | 36 | self._project_conv = nn.Sequential( 37 | nn.Conv2d(C_hidden, C_out, kernel_size=1, stride=1, padding=0, groups=1, bias=False), 38 | nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)) 39 | else: 40 | C_hidden = C_in 41 | self._expand_conv = None 42 | 43 | self._depthwise_conv = nn.Sequential( 44 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 45 | nn.BatchNorm2d(C_in, affine=affine, track_running_stats=track_running_stats), 46 | nn.ReLU6(inplace=True)) 47 | 48 | self._project_conv = nn.Sequential( 49 | nn.Conv2d(C_in, C_out, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(C_out)) 51 | 52 | self.se = SELayer(C_hidden) if use_se else None 53 | 54 | def forward(self, x): 55 | input_data = x 56 | if self._expand_conv is not None: 57 | x = self._expand_conv(x) 58 | x = self._depthwise_conv(x) 59 | if self.se is not None: 60 | x = self.se(x) 61 | out_data = self._project_conv(x) 62 | 63 | if out_data.shape == input_data.shape: 64 | return out_data + input_data 65 | else: 66 | return out_data 67 | 68 | 69 | class h_sigmoid(nn.Module): 70 | def __init__(self, inplace=True): 71 | super(h_sigmoid, self).__init__() 72 | self.relu = nn.ReLU6(inplace=inplace) 73 | 74 | def forward(self, x): 75 | return self.relu(x + 3) / 6 76 | 77 | def _make_divisible(v, divisor, min_value=None): 78 | """ 79 | This function is taken from 
the original tf repo. 80 | It ensures that all layers have a channel number that is divisible by 8 81 | It can be seen here: 82 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 83 | :param v: 84 | :param divisor: 85 | :param min_value: 86 | :return: 87 | """ 88 | if min_value is None: 89 | min_value = divisor 90 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 91 | # Make sure that round down does not go down by more than 10%. 92 | if new_v < 0.9 * v: 93 | new_v += divisor 94 | return new_v 95 | 96 | class SELayer(nn.Module): 97 | def __init__(self, channel, reduction=4): 98 | super(SELayer, self).__init__() 99 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 100 | self.fc = nn.Sequential( 101 | nn.Linear(channel, _make_divisible(channel // reduction, 8)), 102 | nn.ReLU(inplace=True), 103 | nn.Linear(_make_divisible(channel // reduction, 8), channel), 104 | h_sigmoid() 105 | ) 106 | 107 | def forward(self, x): 108 | b, c, _, _ = x.size() 109 | y = self.avg_pool(x).view(b, c) 110 | y = self.fc(y).view(b, c, 1, 1) 111 | return x * y 112 | 113 | 114 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 115 | """3x3 convolution with padding""" 116 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 117 | padding=dilation, groups=groups, bias=False, dilation=dilation) 118 | 119 | def conv1x1(in_planes, out_planes, stride=1): 120 | """1x1 convolution""" 121 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 122 | 123 | class BasicBlock(nn.Module): 124 | def __init__(self, inplanes, planes, stride=1, groups=1, 125 | base_width=64, dilation=1, norm_layer=None, 126 | affine=True, track_running_stats=True): 127 | super(BasicBlock, self).__init__() 128 | if norm_layer is None: 129 | norm_layer = nn.BatchNorm2d 130 | if groups != 1 or base_width != 64: 131 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 132 | if dilation > 1: 133 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 134 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 135 | self.conv1 = conv3x3(inplanes, planes, stride) 136 | self.bn1 = norm_layer(planes, affine=affine, track_running_stats=track_running_stats) 137 | self.relu = nn.ReLU(inplace=True) 138 | self.conv2 = conv3x3(planes, planes) 139 | self.bn2 = norm_layer(planes, affine=affine, track_running_stats=track_running_stats) 140 | self.downsample = None 141 | if stride != 1 or inplanes != planes: 142 | self.downsample = nn.Sequential( 143 | conv1x1(inplanes, planes, stride), 144 | norm_layer(planes, affine=affine, track_running_stats=track_running_stats), 145 | ) 146 | 147 | def forward(self, x): 148 | identity = x 149 | out = self.conv1(x) 150 | out = self.bn1(out) 151 | out = self.relu(out) 152 | out = self.conv2(out) 153 | out = self.bn2(out) 154 | 155 | if self.downsample is not None: 156 | identity = self.downsample(x) 157 | out += identity 158 | out = self.relu(out) 159 | 160 | return out 161 | 162 | 163 | class Bottleneck(nn.Module): 164 | def __init__(self, inplanes, planes, stride=1, affine=True, track_running_stats=True): 165 | super(Bottleneck, self).__init__() 166 | if inplanes != 32: 167 | inplanes *= 4 168 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 169 | self.bn1 = nn.BatchNorm2d(planes) 170 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 171 | padding=1, bias=False) 172 | self.bn2 = nn.BatchNorm2d(planes) 173 | 
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 174 | self.bn3 = nn.BatchNorm2d(planes * 4) 175 | self.relu = nn.ReLU(inplace=True) 176 | self.stride = stride 177 | self.downsample = None 178 | if stride != 1 or inplanes != planes*4: 179 | self.downsample = nn.Sequential( 180 | conv1x1(inplanes, planes * 4, stride), 181 | nn.BatchNorm2d(planes * 4, affine=affine, track_running_stats=track_running_stats), 182 | ) 183 | 184 | def forward(self, x): 185 | residual = x 186 | 187 | out = self.conv1(x) 188 | out = self.bn1(out) 189 | out = self.relu(out) 190 | 191 | out = self.conv2(out) 192 | out = self.bn2(out) 193 | out = self.relu(out) 194 | 195 | out = self.conv3(out) 196 | out = self.bn3(out) 197 | 198 | if self.downsample is not None: 199 | residual = self.downsample(x) 200 | 201 | out += residual 202 | out = self.relu(out) 203 | 204 | return out 205 | 206 | 207 | class Skip(nn.Module): 208 | def __init__(self, C_in, C_out, stride, affine=True, track_running_stats=True): 209 | super(Skip, self).__init__() 210 | if C_in!=C_out: 211 | skip_conv = nn.Sequential( 212 | nn.Conv2d(C_in, C_out, kernel_size=1, stride=stride, padding=0, groups=1, bias=False), 213 | nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)) 214 | stride = 1 215 | self.op=Identity(stride) 216 | 217 | if C_in!=C_out: 218 | self.op=nn.Sequential(skip_conv, self.op) 219 | 220 | def forward(self,x): 221 | return self.op(x) 222 | 223 | class Identity(nn.Module): 224 | def __init__(self, stride): 225 | super(Identity, self).__init__() 226 | self.stride = stride 227 | 228 | def forward(self, x): 229 | if self.stride == 1: 230 | return x 231 | else: 232 | return x[:, :, ::self.stride, ::self.stride] 233 | -------------------------------------------------------------------------------- /run_apis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaminFong/dali-pytorch/7bd5d2380d210a32d24c7309da69c8d2c5db8759/run_apis/__init__.py -------------------------------------------------------------------------------- /run_apis/train_dist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import logging 4 | import os 5 | import sys; sys.path.append(os.path.join(sys.path[0], '..')) 6 | import time 7 | 8 | import model_zoo 9 | import numpy as np 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.nn as nn 14 | from dataset import imagenet_data_dali 15 | from mmcv import Config 16 | from models import model_derived 17 | from tensorboardX import SummaryWriter 18 | from tools import env, utils 19 | from tools.lr_scheduler import get_lr_scheduler 20 | from tools.multadds_count import comp_multadds 21 | 22 | from trainer import Trainer 23 | 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser("Train_Params") 28 | parser.add_argument('--report_freq', type=float, default=500, help='report frequency') 29 | parser.add_argument('--data_path', type=str, default='../data', help='location of the data corpus') 30 | parser.add_argument('--load_path', type=str, default='./model_path', help='model loading path') 31 | parser.add_argument('--save', type=str, default='../', help='experiment name') 32 | parser.add_argument('--tb_path', type=str, default='', help='tensorboard output path') 33 | parser.add_argument('--meas_lat', type=ast.literal_eval, default='False', help='whether 
to measure the latency of the model') 34 | parser.add_argument('--job_name', type=str, default='', help='job_name') 35 | parser.add_argument('--port', type=int, default=23333, help='dist port') 36 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') 37 | parser.add_argument('--local_rank', default=0, type=int) 38 | parser.add_argument('--evaluation', type=ast.literal_eval, default='False', help='first evaluation') 39 | parser.add_argument('--config', type=str, default='', help='the file of the config') 40 | args = parser.parse_args() 41 | 42 | config = Config.fromfile(os.path.join('configs/train_cfg', args.config)) 43 | if config.net_config: 44 | net_config = config.pop('net_config') 45 | 46 | # init distributed env first, since logger depends on the dist info. 47 | if args.launcher == 'none': 48 | distributed = False 49 | raise EnvironmentError 50 | else: 51 | distributed = True 52 | os.environ['MASTER_ADDR'] = 'localhost' 53 | os.environ['MASTER_PORT'] = '%d' % args.port 54 | env.init_dist(args.launcher) 55 | local_rank = dist.get_rank() 56 | world_size = dist.get_world_size() 57 | 58 | if args.job_name != '': 59 | args.job_name = time.strftime("%Y%m%d-%H%M%S-") + args.job_name 60 | args.save = os.path.join(args.save, args.job_name) 61 | if local_rank == 0: 62 | utils.create_exp_dir(args.save) 63 | os.system('cp -r ./* '+args.save) 64 | else: 65 | args.save = os.path.join(args.save, 'output') 66 | if local_rank == 0: 67 | utils.create_exp_dir(args.save) 68 | 69 | if args.tb_path == '': 70 | args.tb_path = args.save 71 | 72 | env.get_root_logger(log_dir=args.save) 73 | cudnn.benchmark = True 74 | cudnn.enabled = True 75 | 76 | if config.train_params.use_seed: 77 | utils.set_seed(config.train_params.seed) 78 | 79 | logging.info("args = %s", args) 80 | logging.info('Training with config:') 81 | logging.info(config.pretty_text) 82 | writer = SummaryWriter(args.tb_path) 83 | 84 | if config.model_zoo.use_model_zoo: 85 | model = getattr(model_zoo, config.model_zoo.model_name)( 86 | **config.model_zoo.cfg if config.model_zoo.cfg else {}) 87 | else: 88 | if os.path.isfile(os.path.join(args.load_path, 'net_config')): 89 | net_config, config.net_type = utils.load_net_config( 90 | os.path.join(args.load_path, 'net_config')) 91 | derivedNetwork = getattr(model_derived, '%s_Net' % config.net_type.upper()) 92 | model = derivedNetwork(net_config, config=config) 93 | 94 | model.eval() 95 | if hasattr(model, 'net_config'): 96 | logging.info("Network Structure: \n" + '|\n'.join(map(str, model.net_config))) 97 | if args.meas_lat: 98 | latency_cpu = utils.latency_measure(model, (3, 224, 224), 1, 2000, mode='cpu') 99 | logging.info('latency_cpu (batch 1): %.2fms' % latency_cpu) 100 | latency_gpu = utils.latency_measure(model, (3, 224, 224), 32, 1000, mode='gpu') 101 | logging.info('latency_gpu (batch 32): %.2fms' % latency_gpu) 102 | 103 | params = utils.count_parameters_in_MB(model) 104 | logging.info("Params = %.2fMB" % params) 105 | mult_adds = comp_multadds(model, input_size=config.data.input_size) 106 | logging.info("Mult-Adds = %.2fMB" % mult_adds) 107 | 108 | model.cuda(local_rank) 109 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) 110 | 111 | # whether to resume from a checkpoint 112 | if config.optim.if_resume: 113 | utils.load_model(model, config.optim.resume.load_path, distributed) 114 | start_epoch = config.optim.resume.load_epoch + 1 115 | else: 116 | start_epoch = 0 117 | 118 | if 
config.optim.label_smooth: 119 | criterion = utils.cross_entropy_with_label_smoothing 120 | else: 121 | criterion = nn.CrossEntropyLoss() 122 | criterion = criterion.cuda() 123 | 124 | if config.optim.bn_wo_wd: 125 | parameters = utils.add_weight_decay(model, config.optim.weight_decay) 126 | else: 127 | parameters = model.parameters() 128 | optimizer = torch.optim.SGD( 129 | parameters, 130 | config.optim.init_lr, 131 | momentum=config.optim.momentum, 132 | weight_decay=config.optim.weight_decay 133 | ) 134 | 135 | train_loader, val_loader = imagenet_data_dali.get_data_loader( 136 | config.data, world_size, local_rank, args.data_path 137 | ) 138 | 139 | scheduler = get_lr_scheduler(config, optimizer, train_loader._size) 140 | scheduler.last_step = start_epoch * (train_loader._size // config.data.batch_size + 1)-1 141 | 142 | trainer = Trainer(train_loader, val_loader, optimizer, criterion, 143 | scheduler, config, args.report_freq, distributed) 144 | 145 | best_epoch = [0, 0, 0] # [epoch, acc_top1, acc_top5] 146 | if args.evaluation: 147 | with torch.no_grad(): 148 | val_acc_top1, val_acc_top5, batch_time, data_time = trainer.infer(model, start_epoch-1) 149 | if val_acc_top1 > best_epoch[1]: 150 | best_epoch = [start_epoch-1, val_acc_top1, val_acc_top5] 151 | logging.info('BEST EPOCH %d val_top1 %.2f val_top5 %.2f', best_epoch[0], best_epoch[1], best_epoch[2]) 152 | 153 | for epoch in range(start_epoch, config.train_params.epochs): 154 | train_acc_top1, train_acc_top5, train_obj, batch_time, data_time = trainer.train(model, epoch) 155 | 156 | with torch.no_grad(): 157 | val_acc_top1, val_acc_top5, batch_time, data_time = trainer.infer(model, epoch) 158 | if val_acc_top1 > best_epoch[1]: 159 | best_epoch = [epoch, val_acc_top1, val_acc_top5] 160 | if local_rank==0: 161 | utils.save(model, os.path.join(args.save, 'weights.pt')) 162 | logging.info('BEST EPOCH %d val_top1 %.2f val_top5 %.2f', best_epoch[0], best_epoch[1], best_epoch[2]) 163 | 164 | if local_rank == 0: 165 | writer.add_scalar('train_acc_top1', train_acc_top1, epoch) 166 | writer.add_scalar('train_loss', train_obj, epoch) 167 | writer.add_scalar('val_acc_top1', val_acc_top1, epoch) 168 | 169 | if hasattr(model.module, 'net_config'): 170 | logging.info("Network Structure: \n" + '|\n'.join(map(str, model.module.net_config))) 171 | logging.info("Params = %.2fMB" % params) 172 | logging.info("Mult-Adds = %.2fMB" % mult_adds) 173 | -------------------------------------------------------------------------------- /run_apis/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | from tools import env, utils 7 | 8 | 9 | class Trainer(object): 10 | def __init__(self, train_data, val_data, optimizer=None, criterion=None, 11 | scheduler=None, config=None, report_freq=None, distributed=False): 12 | self.train_data = train_data 13 | self.val_data = val_data 14 | self.optimizer = optimizer 15 | self.criterion = criterion 16 | self.scheduler = scheduler 17 | self.config = config 18 | self.report_freq = report_freq 19 | self.distributed = distributed 20 | 21 | def train(self, model, epoch): 22 | objs = utils.AverageMeter() 23 | top1 = utils.AverageMeter() 24 | top5 = utils.AverageMeter() 25 | data_time = utils.AverageMeter() 26 | batch_time = utils.AverageMeter() 27 | model.train() 28 | 29 | start = time.time() 30 | try: 31 | data = next(self.train_data) 32 | input = data[0]["data"] 33 | target = 
data[0]["label"].squeeze().cuda().long() 34 | except StopIteration: 35 | input, target = None, None 36 | step = 0 37 | while input is not None: 38 | data_t = time.time() - start 39 | self.scheduler.step() 40 | n = input.size(0) 41 | if step==0: 42 | logging.info('epoch %d lr %e', epoch, self.optimizer.param_groups[0]['lr']) 43 | self.optimizer.zero_grad() 44 | 45 | logits= model(input) 46 | if self.config.optim.label_smooth: 47 | loss = self.criterion(logits, target, self.config.optim.smooth_alpha) 48 | else: 49 | loss = self.criterion(logits, target) 50 | 51 | loss.backward() 52 | if self.config.optim.use_grad_clip: 53 | nn.utils.clip_grad_norm_(model.parameters(), self.config.optim.grad_clip) 54 | self.optimizer.step() 55 | 56 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 57 | 58 | if self.distributed: 59 | loss = env.reduce_tensor(loss.data) 60 | prec1 = env.reduce_tensor(prec1) 61 | prec5 = env.reduce_tensor(prec5) 62 | torch.cuda.synchronize() 63 | 64 | batch_t = time.time() - start 65 | start = time.time() 66 | objs.update(loss.item(), n) 67 | top1.update(prec1.item(), n) 68 | top5.update(prec5.item(), n) 69 | data_time.update(data_t) 70 | batch_time.update(batch_t) 71 | if step!=0 and step % self.report_freq == 0: 72 | logging.info( 73 | 'Train epoch %03d step %03d | loss %.4f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f', 74 | epoch, step, objs.avg, top1.avg, top5.avg, batch_time.avg, data_time.avg) 75 | 76 | try: 77 | data = next(self.train_data) 78 | input = data[0]["data"] 79 | target = data[0]["label"].squeeze().cuda().long() 80 | except StopIteration: 81 | input, target = None, None 82 | step += 1 83 | logging.info('EPOCH%d Train_acc top1 %.2f top5 %.2f batch_time %.3f data_time %.3f', 84 | epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg) 85 | 86 | return top1.avg, top5.avg, objs.avg, batch_time.avg, data_time.avg 87 | 88 | 89 | def infer(self, model, epoch=0): 90 | top1 = utils.AverageMeter() 91 | top5 = utils.AverageMeter() 92 | data_time = utils.AverageMeter() 93 | batch_time = utils.AverageMeter() 94 | model.eval() 95 | 96 | start = time.time() 97 | try: 98 | data = next(self.val_data) 99 | input = data[0]["data"] 100 | target = data[0]["label"].squeeze().cuda().long() 101 | except StopIteration: 102 | input, target = None, None 103 | step = 0 104 | while input is not None: 105 | step += 1 106 | data_t = time.time() - start 107 | n = input.size(0) 108 | logits = model(input) 109 | 110 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 111 | if self.distributed: 112 | prec1 = env.reduce_tensor(prec1) 113 | prec5 = env.reduce_tensor(prec5) 114 | torch.cuda.synchronize() 115 | batch_t = time.time() - start 116 | start = time.time() 117 | top1.update(prec1.item(), n) 118 | top5.update(prec5.item(), n) 119 | data_time.update(data_t) 120 | batch_time.update(batch_t) 121 | 122 | if step % self.report_freq == 0: 123 | logging.info( 124 | 'Val epoch %03d step %03d | top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f', 125 | epoch, step, top1.avg, top5.avg, batch_time.avg, data_time.avg) 126 | try: 127 | data = next(self.val_data) 128 | input = data[0]["data"] 129 | target = data[0]["label"].squeeze().cuda().long() 130 | except StopIteration: 131 | input, target = None, None 132 | 133 | logging.info('EPOCH%d Valid_acc top1 %.2f top5 %.2f batch_time %.3f data_time %.3f', 134 | epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg) 135 | return top1.avg, top5.avg, batch_time.avg, data_time.avg 136 | 137 | 
-------------------------------------------------------------------------------- /scripts/train_dist.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | python -m torch.distributed.launch --nproc_per_node=$1 run_apis/train_dist.py \ 4 | --launcher pytorch \ 5 | --report_freq 400 \ 6 | --data_path $2 \ 7 | --port 23333 \ 8 | --config $3 9 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JaminFong/dali-pytorch/7bd5d2380d210a32d24c7309da69c8d2c5db8759/tools/__init__.py -------------------------------------------------------------------------------- /tools/env.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import os.path as osp 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | import torch.multiprocessing as mp 10 | 11 | 12 | def init_dist(launcher, backend='nccl', **kwargs): 13 | # if mp.get_start_method(allow_none=True) is None: 14 | # mp.set_start_method('spawn') 15 | if launcher == 'pytorch': 16 | _init_dist_pytorch(backend, **kwargs) 17 | elif launcher == 'mpi': 18 | _init_dist_mpi(backend, **kwargs) 19 | elif launcher == 'slurm': 20 | _init_dist_slurm(backend, **kwargs) 21 | else: 22 | raise ValueError('Invalid launcher type: {}'.format(launcher)) 23 | 24 | 25 | def _init_dist_pytorch(backend, **kwargs): 26 | rank = int(os.environ['RANK']) 27 | num_gpus = torch.cuda.device_count() 28 | torch.cuda.set_device(rank % num_gpus) 29 | dist.init_process_group(backend=backend, **kwargs) 30 | 31 | 32 | def _init_dist_mpi(backend, **kwargs): 33 | raise NotImplementedError 34 | 35 | 36 | def _init_dist_slurm(backend, **kwargs): 37 | raise NotImplementedError 38 | 39 | 40 | def set_random_seed(seed): 41 | random.seed(seed) 42 | np.random.seed(seed) 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed_all(seed) 45 | 46 | 47 | def get_root_logger(log_dir=None, log_level=logging.INFO): 48 | logger = logging.getLogger() 49 | if not logger.hasHandlers(): 50 | logging.basicConfig( 51 | format='%(asctime)s - %(message)s', 52 | level=log_level, 53 | datefmt='%m/%d %I:%M:%S %p') 54 | rank = dist.get_rank() 55 | if rank != 0: 56 | logger.setLevel('ERROR') 57 | 58 | if log_dir and rank == 0: 59 | filename = 'train_dist.log' 60 | log_file = osp.join(log_dir, filename) 61 | _add_file_handler(logger, log_file, level=log_level) 62 | return logger 63 | 64 | def _add_file_handler(logger, 65 | filename=None, 66 | mode='w', 67 | level=logging.INFO): 68 | file_handler = logging.FileHandler(filename, mode) 69 | file_handler.setFormatter( 70 | logging.Formatter('%(asctime)s - %(message)s')) 71 | file_handler.setLevel(level) 72 | logger.addHandler(file_handler) 73 | return logger 74 | 75 | def reduce_tensor(tensor): 76 | rt = tensor.clone() 77 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 78 | rt /= dist.get_world_size() 79 | return rt -------------------------------------------------------------------------------- /tools/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import CosineAnnealingLR 3 | from torch.optim.optimizer import Optimizer 4 | import math 5 | 6 | class CosineRestartAnnealingLR(object): 7 | # decay as step 8 | # T_max refers to the max update step 9 | 
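    # Schedule summary (mirrors get_lr below): during linear warmup (step < warmup_steps),
    #   lr = warmup_startlr + (warmup_targetlr - warmup_startlr) * step / warmup_steps;
    # afterwards, within a cosine cycle of length `cycle_length` that began at step `cur`,
    #   lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * (step - cur) / cycle_length)) / 2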
10 | def __init__(self, optimizer, T_max, lr_period, lr_step, eta_min=0, last_step=-1, 11 | use_warmup=False, warmup_mode='linear', warmup_steps=0, warmup_startlr=0, 12 | warmup_targetlr=0, use_restart=False): 13 | 14 | self.use_warmup = use_warmup 15 | self.warmup_mode = warmup_mode 16 | self.warmup_steps = warmup_steps 17 | self.warmup_startlr = warmup_startlr 18 | self.warmup_targetlr = warmup_targetlr 19 | self.use_restart = use_restart 20 | self.T_max = T_max 21 | self.eta_min = eta_min 22 | 23 | if not self.use_restart: 24 | self.lr_period = [self.T_max - self.warmup_steps] 25 | self.lr_step = [self.warmup_steps] 26 | else: 27 | self.lr_period = lr_period 28 | self.lr_step = lr_step 29 | 30 | self.last_step = last_step 31 | self.cycle_length = self.lr_period[0] 32 | self.cur = 0 33 | 34 | if not isinstance(optimizer, Optimizer): 35 | raise TypeError('{} is not an Optimizer'.format( 36 | type(optimizer).__name__)) 37 | self.optimizer = optimizer 38 | if last_step == -1: 39 | for group in optimizer.param_groups: 40 | group.setdefault('initial_lr', group['lr']) 41 | else: 42 | for i, group in enumerate(optimizer.param_groups): 43 | if 'initial_lr' not in group: 44 | raise KeyError("param 'initial_lr' is not specified " 45 | "in param_groups[{}] when resuming an optimizer".format(i)) 46 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 47 | 48 | 49 | def step(self, step=None): 50 | 51 | if step is not None: 52 | self.last_step = step 53 | else: 54 | self.last_step += 1 55 | 56 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 57 | param_group['lr'] = lr 58 | 59 | 60 | def get_lr(self): 61 | 62 | lrs = [] 63 | for base_lr in self.base_lrs: 64 | if self.use_warmup and self.last_step < self.warmup_steps: 65 | if self.warmup_mode == 'constant': 66 | lrs.append(self.warmup_startlr) 67 | elif self.warmup_mode == 'linear': 68 | cur_lr = self.warmup_startlr + \ 69 | float(self.warmup_targetlr-self.warmup_startlr)/self.warmup_steps*self.last_step 70 | lrs.append(cur_lr) 71 | else: 72 | raise NotImplementedError 73 | 74 | else: 75 | if self.last_step in self.lr_step: 76 | self.cycle_length = self.lr_period[self.lr_step.index(self.last_step)] 77 | self.cur = self.last_step 78 | 79 | peri_iter = self.last_step-self.cur 80 | 81 | if peri_iter <= self.cycle_length: 82 | unit_cycle = (1 + math.cos(peri_iter * math.pi / self.cycle_length)) / 2 83 | adjusted_cycle = unit_cycle * (base_lr - self.eta_min) + self.eta_min 84 | lrs.append(adjusted_cycle) 85 | else: 86 | lrs.append(self.eta_min) 87 | 88 | return lrs 89 | 90 | 91 | def display_lr_curve(self, total_steps): 92 | lrs = [] 93 | for _ in range(total_steps): 94 | self.step() 95 | lrs.append(self.get_lr()[0]) 96 | import matplotlib.pyplot as plt 97 | plt.plot(lrs) 98 | plt.show() 99 | 100 | 101 | def get_lr_scheduler(config, optimizer, num_examples=None): 102 | 103 | if num_examples is None: 104 | num_examples = config.data.num_examples 105 | epoch_steps = math.ceil(num_examples / config.data.batch_size)  # true division: ceil over // would be a no-op 106 | 107 | if config.optim.use_multi_stage: 108 | max_steps = epoch_steps * config.optim.multi_stage.stage_epochs 109 | else: 110 | max_steps = epoch_steps * config.train_params.epochs 111 | 112 | period_steps = [epoch_steps * x for x in config.optim.cosine.restart.lr_period] 113 | step_steps = [epoch_steps * x for x in config.optim.cosine.restart.lr_step] 114 | 115 | init_lr = config.optim.init_lr 116 | 117 | use_warmup = config.optim.use_warm_up 118 | if use_warmup: 119 | 
warmup_steps = config.optim.warm_up.epoch * epoch_steps 120 | warmup_startlr = config.optim.warm_up.init_lr 121 | warmup_targetlr = config.optim.warm_up.target_lr 122 | else: 123 | warmup_steps = 0 124 | warmup_startlr = init_lr 125 | warmup_targetlr = init_lr 126 | 127 | if config.optim.lr_schedule == 'cosine': 128 | scheduler = CosineRestartAnnealingLR(optimizer, 129 | float(max_steps), 130 | period_steps, 131 | step_steps, 132 | eta_min=config.optim.min_lr, 133 | use_warmup=use_warmup, 134 | warmup_steps=warmup_steps, 135 | warmup_startlr=warmup_startlr, 136 | warmup_targetlr=warmup_targetlr, 137 | use_restart=config.optim.cosine.use_restart) 138 | # scheduler = CosineAnnealingLR(optimizer, config.train_params.epochs, config.optim.min_lr) 139 | elif config.optim.lr_schedule == 'poly': 140 | raise NotImplementedError 141 | else: 142 | raise NotImplementedError 143 | 144 | return scheduler 145 | 146 | -------------------------------------------------------------------------------- /tools/multadds_count.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # Original implementation: 3 | # https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py 4 | 5 | # ---- Public functions 6 | 7 | def comp_multadds(model, input_size=(3,224,224)): 8 | input_size = (1,) + tuple(input_size) 9 | model = model.cuda() 10 | input_data = torch.randn(input_size).cuda() 11 | model = add_flops_counting_methods(model) 12 | model.start_flops_count() 13 | with torch.no_grad(): 14 | _ = model(input_data) 15 | 16 | mult_adds = model.compute_average_flops_cost() / 1e6 17 | return mult_adds 18 | 19 | 20 | def comp_multadds_fw(model, input_data, use_gpu=True): 21 | model = add_flops_counting_methods(model) 22 | if use_gpu: 23 | model = model.cuda() 24 | model.start_flops_count() 25 | with torch.no_grad(): 26 | output_data = model(input_data) 27 | 28 | mult_adds = model.compute_average_flops_cost() / 1e6 29 | return mult_adds, output_data 30 | 31 | 32 | def add_flops_counting_methods(net_main_module): 33 | """Adds flops counting functions to an existing model. After that 34 | the flops count should be activated and the model should be run on an input 35 | image. 36 | Example: 37 | fcn = add_flops_counting_methods(fcn) 38 | fcn = fcn.cuda().train() 39 | fcn.start_flops_count() 40 | _ = fcn(batch) 41 | fcn.compute_average_flops_cost() / 1e9 / 2 # Result in GFLOPs per image in batch 42 | Important: dividing by 2 only works for resnet models -- see below for the details 43 | of flops computation. 44 | Attention: we are counting multiply-add as two flops in this work, because in 45 | most resnet models convolutions are bias-free (BN layers act as bias there) 46 | and it makes sense to count multiply and add as separate flops therefore. 47 | This is why in the above example we divide by 2 in order to be consistent with 48 | most modern benchmarks. For example in "Spatially Adaptive Computation Time for Residual 49 | Networks" by Figurnov et al. multiply-add was counted as two flops. 50 | This module computes the average flops which is necessary for dynamic networks which 51 | have different numbers of executed layers. For static networks it is enough to run the network 52 | once and get statistics (above example). 53 | Implementation: 54 | The module works by adding batch_count to the main module which tracks the sum 55 | of all batch sizes that were run through the network. 
56 | Also each convolutional layer of the network tracks the overall number of flops 57 | performed. 58 | The parameters are updated with the help of registered hook-functions which 59 | are being called each time the respective layer is executed. 60 | Parameters 61 | ---------- 62 | net_main_module : torch.nn.Module 63 | Main module containing network 64 | Returns 65 | ------- 66 | net_main_module : torch.nn.Module 67 | Updated main module with new methods/attributes that are used 68 | to compute flops. 69 | """ 70 | 71 | # adding additional methods to the existing module object, 72 | # this is done this way so that each function has access to the self object 73 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 74 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 75 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 76 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) 77 | 78 | net_main_module.reset_flops_count() 79 | 80 | # Adding variables necessary for masked flops computation 81 | net_main_module.apply(add_flops_mask_variable_or_reset) 82 | 83 | return net_main_module 84 | 85 | 86 | def compute_average_flops_cost(self): 87 | """ 88 | A method that will be available after add_flops_counting_methods() is called 89 | on a desired net object. 90 | Returns current mean flops consumption per image. 91 | """ 92 | 93 | batches_count = self.__batch_counter__ 94 | 95 | flops_sum = 0 96 | 97 | for module in self.modules(): 98 | 99 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 100 | flops_sum += module.__flops__ 101 | 102 | 103 | return flops_sum / batches_count 104 | 105 | 106 | def start_flops_count(self): 107 | """ 108 | A method that will be available after add_flops_counting_methods() is called 109 | on a desired net object. 110 | Activates the computation of mean flops consumption per image. 111 | Call it before you run the network. 112 | """ 113 | 114 | add_batch_counter_hook_function(self) 115 | 116 | self.apply(add_flops_counter_hook_function) 117 | 118 | 119 | def stop_flops_count(self): 120 | """ 121 | A method that will be available after add_flops_counting_methods() is called 122 | on a desired net object. 123 | Stops computing the mean flops consumption per image. 124 | Call whenever you want to pause the computation. 125 | """ 126 | 127 | remove_batch_counter_hook_function(self) 128 | 129 | self.apply(remove_flops_counter_hook_function) 130 | 131 | 132 | def reset_flops_count(self): 133 | """ 134 | A method that will be available after add_flops_counting_methods() is called 135 | on a desired net object. 136 | Resets statistics computed so far. 
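        Concretely, this zeroes the module-level __batch_counter__ and the
        __flops__ attribute of every Conv2d and Linear submodule.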
137 | """ 138 | 139 | add_batch_counter_variables_or_reset(self) 140 | 141 | self.apply(add_flops_counter_variable_or_reset) 142 | 143 | 144 | def add_flops_mask(module, mask): 145 | def add_flops_mask_func(module): 146 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 147 | module.__mask__ = mask 148 | 149 | module.apply(add_flops_mask_func) 150 | 151 | 152 | def remove_flops_mask(module): 153 | module.apply(add_flops_mask_variable_or_reset) 154 | 155 | 156 | # ---- Internal functions 157 | 158 | 159 | def conv_flops_counter_hook(conv_module, input, output): 160 | # Can have multiple inputs, getting the first one 161 | input = input[0] 162 | 163 | batch_size = input.shape[0] 164 | output_height, output_width = output.shape[2:] 165 | 166 | kernel_height, kernel_width = conv_module.kernel_size 167 | in_channels = conv_module.in_channels 168 | out_channels = conv_module.out_channels 169 | 170 | conv_per_position_flops = (kernel_height * kernel_width * in_channels * out_channels) / conv_module.groups 171 | 172 | active_elements_count = batch_size * output_height * output_width 173 | 174 | if conv_module.__mask__ is not None: 175 | # (b, 1, h, w) 176 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width) 177 | active_elements_count = flops_mask.sum() 178 | 179 | overall_conv_flops = conv_per_position_flops * active_elements_count 180 | 181 | bias_flops = 0 182 | 183 | if conv_module.bias is not None: 184 | bias_flops = out_channels * active_elements_count 185 | 186 | overall_flops = overall_conv_flops + bias_flops 187 | 188 | conv_module.__flops__ += overall_flops 189 | 190 | 191 | def linear_flops_counter_hook(linear_module, input, output): 192 | 193 | input = input[0] 194 | batch_size = input.shape[0] 195 | overall_flops = linear_module.in_features * linear_module.out_features * batch_size 196 | 197 | # bias_flops = 0 198 | 199 | # if conv_module.bias is not None: 200 | # bias_flops = out_channels * active_elements_count 201 | 202 | # overall_flops = overall_conv_flops + bias_flops 203 | 204 | linear_module.__flops__ += overall_flops 205 | 206 | 207 | def batch_counter_hook(module, input, output): 208 | # Can have multiple inputs, getting the first one 209 | input = input[0] 210 | 211 | batch_size = input.shape[0] 212 | 213 | module.__batch_counter__ += batch_size 214 | 215 | 216 | def add_batch_counter_variables_or_reset(module): 217 | module.__batch_counter__ = 0 218 | 219 | 220 | def add_batch_counter_hook_function(module): 221 | if hasattr(module, '__batch_counter_handle__'): 222 | return 223 | 224 | handle = module.register_forward_hook(batch_counter_hook) 225 | module.__batch_counter_handle__ = handle 226 | 227 | 228 | def remove_batch_counter_hook_function(module): 229 | if hasattr(module, '__batch_counter_handle__'): 230 | module.__batch_counter_handle__.remove() 231 | 232 | del module.__batch_counter_handle__ 233 | 234 | 235 | def add_flops_counter_variable_or_reset(module): 236 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 237 | module.__flops__ = 0 238 | 239 | 240 | def add_flops_counter_hook_function(module): 241 | if isinstance(module, torch.nn.Conv2d): 242 | if hasattr(module, '__flops_handle__'): 243 | return 244 | 245 | handle = module.register_forward_hook(conv_flops_counter_hook) 246 | module.__flops_handle__ = handle 247 | elif isinstance(module, torch.nn.Linear): 248 | 249 | if hasattr(module, '__flops_handle__'): 250 | return 251 | 252 | handle = 
module.register_forward_hook(linear_flops_counter_hook) 253 | module.__flops_handle__ = handle 254 | 255 | 256 | def remove_flops_counter_hook_function(module): 257 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 258 | 259 | if hasattr(module, '__flops_handle__'): 260 | module.__flops_handle__.remove() 261 | 262 | del module.__flops_handle__ 263 | 264 | 265 | # --- Masked flops counting 266 | 267 | 268 | # Also being run in the initialization 269 | def add_flops_mask_variable_or_reset(module): 270 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 271 | module.__mask__ = None 272 | 273 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import sys 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.distributed as dist 11 | 12 | 13 | class AverageMeter(object): 14 | def __init__(self): 15 | self.reset() 16 | 17 | def reset(self): 18 | self.avg = 0 19 | self.sum = 0 20 | self.cnt = 0 21 | 22 | def update(self, val, n=1): 23 | self.cur = val 24 | self.sum += val * n 25 | self.cnt += n 26 | self.avg = self.sum / self.cnt 27 | 28 | 29 | def accuracy(output, target, topk=(1, 5)): 30 | maxk = max(topk) 31 | batch_size = target.size(0) 32 | 33 | _, pred = output.topk(maxk, 1, True, True) 34 | pred = pred.t() 35 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 36 | 37 | res = [] 38 | for k in topk: 39 | correct_k = correct[:k].view(-1).float().sum(0) 40 | res.append(correct_k.mul_(100.0/batch_size)) 41 | return res 42 | 43 | 44 | def count_parameters_in_MB(model): 45 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "aux" not in name)/1e6 46 | 47 | 48 | def save_checkpoint(state, is_best, save): 49 | filename = os.path.join(save, 'checkpoint.pth.tar') 50 | torch.save(state, filename) 51 | if is_best: 52 | best_filename = os.path.join(save, 'model_best.pth.tar') 53 | shutil.copyfile(filename, best_filename) 54 | 55 | 56 | def save(model, model_path): 57 | torch.save(model.state_dict(), model_path) 58 | 59 | 60 | def load_net_config(path): 61 | with open(path, 'r') as f: 62 | net_config = '' 63 | while True: 64 | line = f.readline().strip() 65 | if 'net_type' in line: 66 | net_type = line.split(': ')[-1] 67 | break 68 | else: 69 | net_config += line 70 | return net_config, net_type 71 | 72 | 73 | def load_model(model, model_path, distributed=False): 74 | logging.info('Start loading the model from ' + model_path) 75 | if 'http' in model_path: 76 | model_addr = model_path 77 | model_path = model_path.split('/')[-1] 78 | if (not distributed) or (distributed and dist.get_rank()==0): 79 | if os.path.isfile(model_path): 80 | os.system('rm ' + model_path) 81 | os.system('wget -q ' + model_addr) 82 | if distributed: 83 | dist.barrier() 84 | model.load_state_dict(torch.load(model_path, 85 | map_location='cuda:{}'.format(dist.get_rank()) if distributed else None)) 86 | logging.info('Loading the model finished!') 87 | 88 | 89 | def create_exp_dir(path): 90 | if not os.path.exists(path): 91 | os.mkdir(path) 92 | 93 | def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.): 94 | """ 95 | Label smoothing implementation. 
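    With smoothing factor eps = label_smoothing and n_classes outputs, the one-hot
    target y is replaced by y * (1 - eps) + eps / n_classes, and the loss is
    mean(sum(-soft_target * log_softmax(pred), dim=1)).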
96 | This function is taken from https://github.com/MIT-HAN-LAB/ProxylessNAS/blob/master/proxyless_nas/utils.py 97 | """ 98 | 99 | logsoftmax = nn.LogSoftmax(dim=1).cuda() 100 | n_classes = pred.size(1) 101 | # convert to one-hot 102 | target = torch.unsqueeze(target, 1) 103 | soft_target = torch.zeros_like(pred) 104 | soft_target.scatter_(1, target, 1) 105 | # label smoothing 106 | soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes 107 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) 108 | 109 | 110 | def parse_net_config(net_config): 111 | str_configs = net_config.split('|') 112 | return [eval(str_config) for str_config in str_configs] 113 | 114 | 115 | def set_seed(seed): 116 | np.random.seed(seed) 117 | torch.manual_seed(seed) 118 | torch.cuda.manual_seed(seed) 119 | 120 | 121 | def set_logging(save_path, log_name='log.txt'): 122 | log_format = '%(asctime)s %(message)s' 123 | date_format = '%m/%d %H:%M:%S' 124 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 125 | format=log_format, datefmt=date_format) 126 | fh = logging.FileHandler(os.path.join(save_path, log_name)) 127 | fh.setFormatter(logging.Formatter(log_format, date_format)) 128 | logging.getLogger().addHandler(fh) 129 | 130 | 131 | def create_save_dir(save_path, job_name): 132 | if job_name != '': 133 | job_name = time.strftime("%Y%m%d-%H%M%S-") + job_name 134 | save_path = os.path.join(save_path, job_name) 135 | create_exp_dir(save_path) 136 | os.system('cp -r ./* '+save_path) 137 | save_path = os.path.join(save_path, 'output') 138 | create_exp_dir(save_path) 139 | else: 140 | save_path = os.path.join(save_path, 'output') 141 | create_exp_dir(save_path) 142 | return save_path, job_name 143 | 144 | 145 | def latency_measure(module, input_size, batch_size, meas_times, mode='gpu'): 146 | assert mode in ['gpu', 'cpu'] 147 | 148 | latency = [] 149 | module.eval() 150 | input_size = (batch_size,) + tuple(input_size) 151 | input_data = torch.randn(input_size) 152 | if mode=='gpu': 153 | input_data = input_data.cuda() 154 | module.cuda() 155 | 156 | for i in range(meas_times): 157 | with torch.no_grad(): 158 | start = time.time() 159 | _ = module(input_data) 160 | torch.cuda.synchronize() 161 | if i >= 100: 162 | latency.append(time.time() - start) 163 | # print(np.mean(latency) * 1e3, 'ms') 164 | return np.mean(latency) * 1e3 165 | 166 | 167 | def latency_measure_fw(module, input_data, meas_times): 168 | latency = [] 169 | module.eval() 170 | 171 | for i in range(meas_times): 172 | with torch.no_grad(): 173 | start = time.time() 174 | output_data = module(input_data) 175 | torch.cuda.synchronize() 176 | if i >= 100: 177 | latency.append(time.time() - start) 178 | # print(np.mean(latency) * 1e3, 'ms') 179 | return np.mean(latency) * 1e3, output_data 180 | 181 | 182 | def record_topk(k, rec_list, data, comp_attr, check_attr): 183 | def get_insert_idx(orig_list, data, comp_attr): 184 | start = 0 185 | end = len(orig_list) 186 | while start < end: 187 | mid = (start + end) // 2 188 | if data[comp_attr] < orig_list[mid][comp_attr]: 189 | start = mid + 1 190 | else: 191 | end = mid 192 | return start 193 | 194 | if_insert = False 195 | insert_idx = get_insert_idx(rec_list, data, comp_attr) 196 | if insert_idx < k: 197 | rec_list.insert(insert_idx, data) 198 | if_insert = True 199 | while len(rec_list) > k: 200 | rec_list.pop() 201 | return if_insert 202 | 203 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()): 204 | decay = [] 205 | no_decay = [] 206 | for name, param 
in model.named_parameters(): 207 | if not param.requires_grad: 208 | continue # frozen weights 209 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 210 | no_decay.append(param) 211 | else: 212 | decay.append(param) 213 | return [ 214 | {'params': no_decay, 'weight_decay': 0.}, 215 | {'params': decay, 'weight_decay': weight_decay}] 216 | --------------------------------------------------------------------------------
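As a closing usage note, add_weight_decay is the helper train_dist.py selects when config.optim.bn_wo_wd is true: it returns two parameter groups so that biases and 1-D parameters (e.g. BatchNorm affine weights) are exempt from weight decay. A short sketch of how it pairs with SGD follows; the toy model and hyper-parameter values are placeholders, not values from this repo.

import torch
import torch.nn as nn
from tools.utils import add_weight_decay

# Toy model: the 1-D BatchNorm parameters land in the no-decay group.
model = nn.Sequential(nn.Conv2d(3, 16, 3, bias=False), nn.BatchNorm2d(16))
param_groups = add_weight_decay(model, weight_decay=4e-5)
# Each group carries its own 'weight_decay', overriding the optimizer default.
optimizer = torch.optim.SGD(param_groups, lr=0.5, momentum=0.9)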