├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── __init__.py
│   └── train_cfg
│       ├── imagenet_train_base_cfg.py
│       ├── imagenet_train_ghostnet_cfg.py
│       ├── imagenet_train_mbv2_cfg.py
│       ├── imagenet_train_mbv3_cfg.py
│       └── imagenet_train_res_cfg.py
├── dataset
│   ├── __init__.py
│   ├── im2rec.py
│   ├── imagenet_data_dali.py
│   ├── mk_img_list.py
│   └── torchvision_extension.py
├── imgs
│   └── dali-speed-comp.png
├── model_zoo
│   ├── __init__.py
│   ├── ghostnet.py
│   └── mobilenetv3.py
├── models
│   ├── __init__.py
│   ├── model_derived.py
│   └── operations.py
├── run_apis
│   ├── __init__.py
│   ├── train_dist.py
│   └── trainer.py
├── scripts
│   └── train_dist.sh
└── tools
    ├── __init__.py
    ├── env.py
    ├── lr_scheduler.py
    ├── multadds_count.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # dali-pytorch
2 | This is a PyTorch toolkit for accelerating ImageNet training in distributed mode with [NVIDIA DALI](https://github.com/NVIDIA/DALI). Despite the extremely high training speed, the toolkit achieves accuracy similar to or higher than that reported in the original papers.
3 |
4 | ## Main Requirements (recommended)
5 |
6 | * python >= 3.7.0
7 | * pytorch >= 1.2.0
8 | * CUDA >= 10.0
9 | * nvidia-dali-cuda100 >= 0.23.0
10 | * mmcv >= 1.0.5
11 | * mxnet >= 1.6.0 (only used for preparing the dataset)
12 |
13 | ## Train
14 | 1. As shown in the [official evaluation](https://developer.nvidia.com/DALI), the MXNet data format offers notable speed advantages, so we use the MXNet `.rec` dataset format for training. Please prepare the ImageNet dataset as `.rec` following ["Create a Dataset Using RecordIO"](https://mxnet.apache.org/api/faq/recordio); a command sketch is provided right after this README.
15 | 
16 |
17 | 2. Train your model with the following script. It is recommended to run the training code on 8 GPUs with a total batch size of 128x8 (128 per GPU). Taking MobileNetV2 training as an example:
18 | ```
19 | sh ./scripts/train_dist.sh 8 ./imagenet-rec imagenet_train_mbv2_cfg.py
20 | ```
21 |
22 |
23 | ## Results
24 | We evaluate our toolkit on several popular networks as follows; it achieves accuracy similar to or higher than that reported in the original papers. All experiments are performed on 8 TITAN-XP GPUs.
25 |
26 | Our pre-trained models and corresponding training logs can be downloaded at [DALI_MODEL_ZOO](https://drive.google.com/drive/folders/1-KHHupxYmZClxlnmrZ5X5ndObFcp98Lv?usp=sharing).
27 |
28 | | Model | Reported Top-1(%) | Top-1(%) | Epochs | Time w/dali |
29 | |----------------------|-------|---------|---------|---------|
30 | | ResNet18 | 69.76* | 72.15 | 240 | 16h |
31 | | MobileNetV2 | 72.0 | 72.94 | 240 | 1d4.5h |
32 | | MobileNetV3 | 75.2 | 75.07 | 400 | 2d1.5h |
33 | | GhostNet 1.0x | 73.9 | 73.97 | 400 | 1d21.5h |
34 |
35 | *\* denotes the result reported in the torchvision model zoo.*
36 | *"d": day; "h": hour.*
37 |
38 | ----------------------------------
39 |
40 | Comparisons between training w/ and w/o dali:
41 |
42 | | Model | Method | Top-1(%) | Time |
43 | |----------------------|-------|---------|---------|
44 | | ResNet18 | w/ dali | 72.15 | **16h** |
45 | | | w/o dali | 72.01 | 47.5h |
46 |
47 | *The w/o dali setting is based on the lmdb data format from the codebase of [DenseNAS](https://github.com/JaminFong/DenseNAS).*
48 |
49 | ## Acknowledgement
50 | This project is in part supported by [Horizon Robotics](https://en.horizon.ai).
51 |
52 | The code of some models is adapted from [ghostnet](https://github.com/huawei-noah/ghostnet/blob/master/pytorch/ghostnet.py) and [mobilenetv3](https://github.com/d-li14/mobilenetv3.pytorch). Thanks to these repositories for their contributions.
53 |
54 | ## Citation
55 | This repository is partially based on [DenseNAS](https://github.com/JaminFong/DenseNAS). If you find this toolkit helpful in your research, please cite
56 | ```
57 | @inproceedings{fang2019densely,
58 | title={Densely connected search space for more flexible neural architecture search},
59 | author={Fang, Jiemin and Sun, Yuzhu and Zhang, Qian and Li, Yuan and Liu, Wenyu and Wang, Xinggang},
60 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
61 | year={2020}
62 | }
63 | ```
--------------------------------------------------------------------------------
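
To make step 1 of the Train section above concrete, here is a hedged command sketch using the bundled dataset/im2rec.py (all paths are placeholders; images are assumed to be organized as <root>/<class>/<image>):

```
# generate train.lst / val.lst by walking the class sub-folders
python dataset/im2rec.py --list --recursive ./imagenet-rec/train /path/to/imagenet/train
python dataset/im2rec.py --list --recursive ./imagenet-rec/val /path/to/imagenet/val
# pack the listed images into train.rec/.idx and val.rec/.idx pairs
python dataset/im2rec.py --num-thread 16 ./imagenet-rec/train /path/to/imagenet/train
python dataset/im2rec.py --num-thread 16 ./imagenet-rec/val /path/to/imagenet/val
```
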
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | from . import *
--------------------------------------------------------------------------------
/configs/train_cfg/imagenet_train_base_cfg.py:
--------------------------------------------------------------------------------
1 | model_zoo=dict(
2 | use_model_zoo = False,
3 | model_name = '',
4 | cfg=dict(
5 | init_mode='he_fout',
6 | bn_momentum=0.1,
7 | bn_eps=1e-3
8 | )
9 | )
10 |
11 | net_type='mbv2' # mbv2 / res
12 | net_config=""""""
13 |
14 | train_params=dict(
15 | epochs=240,
16 | use_seed=True,
17 | seed=0
18 | )
19 |
20 | optim=dict(
21 | init_lr=0.5,
22 | min_lr=1e-5,
23 | lr_schedule='cosine', # cosine poly
24 | momentum=0.9,
25 | weight_decay=4e-5,
26 | bn_wo_wd=False,
27 | label_smooth=True,
28 | smooth_alpha=0.1,
29 | use_grad_clip=False,
30 | grad_clip=10,
31 | if_resume=False,
32 | resume=dict(
33 | load_path='',
34 | load_epoch=191,
35 | ),
36 | use_warm_up=False,
37 | warm_up=dict(
38 | epoch=5,
39 | init_lr=1e-5,
40 | target_lr=0.5,
41 | ),
42 | use_multi_stage=False,
43 | multi_stage=dict(
44 | stage_epochs=330
45 | ),
46 | cosine=dict(
47 | use_restart=False,
48 | restart=dict(
49 | lr_period=[10, 20, 40, 80, 160, 320],
50 | lr_step=[0, 10, 30, 70, 150, 310],
51 | )
52 | )
53 | )
54 |
55 | data=dict(
56 | use_dali=True,
57 | num_threads=4,
58 | resize_batch=16, # 32 in default
59 | batch_size=128,
60 | dataset='imagenet', #imagenet
61 | train_pref='train',
62 | val_pref='val',
63 | num_examples=1281167,
64 | input_size=(3,224,224),
65 | # type_interp='INTERP_TRIANGULAR' # INTERP_TRIANGULAR INTERP_LINEAR
66 | min_filter='INTERP_TRIANGULAR', # INTERP_TRIANGULAR INTERP_LINEAR
67 | mag_filter='INTERP_LANCZOS3', # INTERP_TRIANGULAR INTERP_LINEAR
68 | random_sized=dict(
69 | min_scale=0.08
70 | ),
71 | mean=[0.485, 0.456, 0.406],
72 | std=[0.229, 0.224, 0.225],
73 | color=False,
74 | val_shuffle=False
75 | )
76 |
--------------------------------------------------------------------------------
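
The per-model configs that follow inherit from this base file through their `_base_` entries. A minimal sketch of loading such a config, assuming mmcv's `Config` is the mechanism (consistent with the mmcv requirement in the README; the actual loading code lives in run_apis/train_dist.py, which is not shown in this dump):

```
from mmcv import Config

# keys from `_base_` are merged first; keys set in the child file override them
cfg = Config.fromfile('configs/train_cfg/imagenet_train_mbv2_cfg.py')
print(cfg.optim.init_lr)        # 0.5, inherited from the base config
print(cfg.data.batch_size)      # 128, the per-GPU batch size
print(cfg.train_params.epochs)  # 240, set by the child config
```
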
/configs/train_cfg/imagenet_train_ghostnet_cfg.py:
--------------------------------------------------------------------------------
1 | _base_ = './imagenet_train_base_cfg.py'
2 |
3 | model_zoo=dict(
4 | use_model_zoo = True,
5 | model_name = 'ghostnet',
6 | )
7 |
8 | train_params=dict(
9 | epochs=400,
10 | )
11 |
12 | optim=dict(
13 | weight_decay=4e-5,
14 | bn_wo_wd=True,
15 | if_resume=False,
16 | resume=dict(
17 | load_path='',
18 | load_epoch=191,
19 | ),
20 | use_warm_up=True,
21 | warm_up=dict(
22 | epoch=5,
23 | init_lr=1e-5,
24 | target_lr=0.5,
25 | ),
26 | )
27 |
28 | data=dict(
29 | num_threads=4,
30 | resize_batch=16, # 32 in default
31 | batch_size=128,
32 | color=True,
33 | )
--------------------------------------------------------------------------------
/configs/train_cfg/imagenet_train_mbv2_cfg.py:
--------------------------------------------------------------------------------
1 | _base_ = './imagenet_train_base_cfg.py'
2 |
3 | model_zoo=dict(
4 | use_model_zoo = False,
5 | model_name = 'mobilenet_v2',
6 | cfg = None
7 | )
8 |
9 | net_type='mbv2' # mbv2 / res
10 | net_config="""[[32, 16], 'mbconv_k3_t1', [], 0, 1]|
11 | [[16, 24], 'mbconv_k3_t6', ['mbconv_k3_t6'], 1, 2]|
12 | [[24, 32], 'mbconv_k3_t6', ['mbconv_k3_t6', 'mbconv_k3_t6'], 2, 2]|
13 | [[32, 64], 'mbconv_k3_t6', ['mbconv_k3_t6', 'mbconv_k3_t6', 'mbconv_k3_t6'], 3, 2]|
14 | [[64, 96], 'mbconv_k3_t6', ['mbconv_k3_t6', 'mbconv_k3_t6'], 2, 1]|
15 | [[96, 160], 'mbconv_k3_t6', ['mbconv_k3_t6', 'mbconv_k3_t6'], 2, 2]|
16 | [[160, 320], 'mbconv_k3_t6', [], 3, 1]|
17 | [[320, 1280], 'conv1_1']"""
18 |
19 | train_params=dict(
20 | epochs=240,
21 | )
22 |
23 | optim=dict(
24 | weight_decay=4e-5,
25 | bn_wo_wd=False,
26 | if_resume=False,
27 | resume=dict(
28 | load_path='',
29 | load_epoch=191,
30 | ),
31 | use_warm_up=False,
32 | warm_up=dict(
33 | epoch=5,
34 | init_lr=1e-5,
35 | target_lr=0.5,
36 | ),
37 | use_se=False
38 | )
39 |
40 | data=dict(
41 | num_threads=4,
42 | resize_batch=16, # 32 in default
43 | batch_size=128,
44 | color=False,
45 | )
--------------------------------------------------------------------------------
/configs/train_cfg/imagenet_train_mbv3_cfg.py:
--------------------------------------------------------------------------------
1 | _base_ = './imagenet_train_base_cfg.py'
2 |
3 | model_zoo=dict(
4 | use_model_zoo = True,
5 | model_name = 'mobilenetv3_large',
6 | )
7 |
8 | train_params=dict(
9 | epochs=400,
10 | )
11 |
12 | optim=dict(
13 | weight_decay=4e-5,
14 | bn_wo_wd=True,
15 | if_resume=False,
16 | resume=dict(
17 | load_path='',
18 | load_epoch=191,
19 | ),
20 | use_warm_up=True,
21 | warm_up=dict(
22 | epoch=5,
23 | init_lr=1e-5,
24 | target_lr=0.5,
25 | ),
26 | )
27 |
28 | data=dict(
29 | num_threads=4,
30 | resize_batch=16, # 32 in default
31 | batch_size=128,
32 | color=True,
33 | )
--------------------------------------------------------------------------------
/configs/train_cfg/imagenet_train_res_cfg.py:
--------------------------------------------------------------------------------
1 | _base_ = './imagenet_train_base_cfg.py'
2 |
3 | model_zoo=dict(
4 | use_model_zoo = True,
5 | model_name = 'resnet18',
6 | cfg = None
7 | )
8 |
9 | train_params=dict(
10 | epochs=240,
11 | )
12 |
13 | optim=dict(
14 | weight_decay=4e-5,
15 | bn_wo_wd=False,
16 | if_resume=False,
17 | resume=dict(
18 | load_path='',
19 | load_epoch=191,
20 | ),
21 | use_warm_up=False,
22 | warm_up=dict(
23 | epoch=5,
24 | init_lr=1e-5,
25 | target_lr=0.5,
26 | ),
27 | )
28 |
29 | data=dict(
30 | num_threads=4,
31 | resize_batch=16, # 32 in default
32 | batch_size=128,
33 | color=False,
34 | )
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/dali-pytorch/7bd5d2380d210a32d24c7309da69c8d2c5db8759/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/im2rec.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # Licensed to the Apache Software Foundation (ASF) under one
4 | # or more contributor license agreements. See the NOTICE file
5 | # distributed with this work for additional information
6 | # regarding copyright ownership. The ASF licenses this file
7 | # to you under the Apache License, Version 2.0 (the
8 | # "License"); you may not use this file except in compliance
9 | # with the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing,
14 | # software distributed under the License is distributed on an
15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 | # KIND, either express or implied. See the License for the
17 | # specific language governing permissions and limitations
18 | # under the License.
19 |
20 | from __future__ import print_function
21 | import os
22 | import sys
23 |
24 | curr_path = os.path.abspath(os.path.dirname(__file__))
25 | sys.path.append(os.path.join(curr_path, "../python"))
26 | import mxnet as mx
27 | import random
28 | import argparse
29 | import cv2
30 | import time
31 | import traceback
32 |
33 | try:
34 | import multiprocessing
35 | except ImportError:
36 | multiprocessing = None
37 |
38 | def list_image(root, recursive, exts):
39 | """Traverses the root of directory that contains images and
40 | generates image list iterator.
41 | Parameters
42 | ----------
43 | root: string
44 | recursive: bool
45 | exts: string
46 | Returns
47 | -------
48 | image iterator that contains all the images under the specified path
49 | """
50 |
51 | i = 0
52 | if recursive:
53 | cat = {}
54 | for path, dirs, files in os.walk(root, followlinks=True):
55 | dirs.sort()
56 | files.sort()
57 | for fname in files:
58 | fpath = os.path.join(path, fname)
59 | suffix = os.path.splitext(fname)[1].lower()
60 | if os.path.isfile(fpath) and (suffix in exts):
61 | if path not in cat:
62 | cat[path] = len(cat)
63 | yield (i, os.path.relpath(fpath, root), cat[path])
64 | i += 1
65 | for k, v in sorted(cat.items(), key=lambda x: x[1]):
66 | print(os.path.relpath(k, root), v)
67 | else:
68 | for fname in sorted(os.listdir(root)):
69 | fpath = os.path.join(root, fname)
70 | suffix = os.path.splitext(fname)[1].lower()
71 | if os.path.isfile(fpath) and (suffix in exts):
72 | yield (i, os.path.relpath(fpath, root), 0)
73 | i += 1
74 |
75 | def write_list(path_out, image_list):
76 | """Hepler function to write image list into the file.
77 | The format is as below,
78 | integer_image_index \t float_label_index \t path_to_image
79 | Note that the blanks around the tabs above are only for readability; the actual output has none.
80 | Parameters
81 | ----------
82 | path_out: string
83 | image_list: list
84 | """
85 | with open(path_out, 'w') as fout:
86 | for i, item in enumerate(image_list):
87 | line = '%d\t' % item[0]
88 | for j in item[2:]:
89 | line += '%f\t' % j
90 | line += '%s\n' % item[1]
91 | fout.write(line)
92 |
93 | def make_list(args):
94 | """Generates .lst file.
95 | Parameters
96 | ----------
97 | args: object that contains all the arguments
98 | """
99 | image_list = list_image(args.root, args.recursive, args.exts)
100 | image_list = list(image_list)
101 | if args.shuffle is True:
102 | random.seed(100)
103 | random.shuffle(image_list)
104 | N = len(image_list)
105 | chunk_size = (N + args.chunks - 1) // args.chunks
106 | for i in range(args.chunks):
107 | chunk = image_list[i * chunk_size:(i + 1) * chunk_size]
108 | if args.chunks > 1:
109 | str_chunk = '_%d' % i
110 | else:
111 | str_chunk = ''
112 | sep = int(chunk_size * args.train_ratio)
113 | sep_test = int(chunk_size * args.test_ratio)
114 | if args.train_ratio == 1.0:
115 | write_list(args.prefix + str_chunk + '.lst', chunk)
116 | else:
117 | if args.test_ratio:
118 | write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
119 | if args.train_ratio + args.test_ratio < 1.0:
120 | write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])
121 | write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])
122 |
123 | def read_list(path_in):
124 | """Reads the .lst file and generates corresponding iterator.
125 | Parameters
126 | ----------
127 | path_in: string
128 | Returns
129 | -------
130 | item iterator that contains information in .lst file
131 | """
132 | with open(path_in) as fin:
133 | while True:
134 | line = fin.readline()
135 | if not line:
136 | break
137 | line = [i.strip() for i in line.strip().split('\t')]
138 | line_len = len(line)
139 | # check the data format of .lst file
140 | if line_len < 3:
141 | print('lst should have at least three parts, but only has %s parts for %s' % (line_len, line))
142 | continue
143 | try:
144 | item = [int(line[0])] + [line[-1]] + [float(i) for i in line[1:-1]]
145 | except Exception as e:
146 | print('Parsing lst met error for %s, detail: %s' % (line, e))
147 | continue
148 | yield item
149 |
150 | def image_encode(args, i, item, q_out):
151 | """Reads, preprocesses, packs the image and put it back in output queue.
152 | Parameters
153 | ----------
154 | args: object
155 | i: int
156 | item: list
157 | q_out: queue
158 | """
159 | fullpath = os.path.join(args.root, item[1])
160 |
161 | if len(item) > 3 and args.pack_label:
162 | header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
163 | else:
164 | header = mx.recordio.IRHeader(0, item[2], item[0], 0)
165 |
166 | if args.pass_through:
167 | try:
168 | with open(fullpath, 'rb') as fin:
169 | img = fin.read()
170 | s = mx.recordio.pack(header, img)
171 | q_out.put((i, s, item))
172 | except Exception as e:
173 | traceback.print_exc()
174 | print('pack_img error:', item[1], e)
175 | q_out.put((i, None, item))
176 | return
177 |
178 | try:
179 | img = cv2.imread(fullpath, args.color)
180 | except:
181 | traceback.print_exc()
182 | print('imread error trying to load file: %s ' % fullpath)
183 | q_out.put((i, None, item))
184 | return
185 | if img is None:
186 | print('imread read blank (None) image for file: %s' % fullpath)
187 | q_out.put((i, None, item))
188 | return
189 | if args.center_crop:
190 | if img.shape[0] > img.shape[1]:
191 | margin = (img.shape[0] - img.shape[1]) // 2
192 | img = img[margin:margin + img.shape[1], :]
193 | else:
194 | margin = (img.shape[1] - img.shape[0]) // 2
195 | img = img[:, margin:margin + img.shape[0]]
196 | if args.resize:
197 | if img.shape[0] > img.shape[1]:
198 | newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
199 | else:
200 | newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
201 | img = cv2.resize(img, newsize)
202 |
203 | try:
204 | s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
205 | q_out.put((i, s, item))
206 | except Exception as e:
207 | traceback.print_exc()
208 | print('pack_img error on file: %s' % fullpath, e)
209 | q_out.put((i, None, item))
210 | return
211 |
212 | def read_worker(args, q_in, q_out):
213 | """Function that will be spawned to fetch the image
214 | from the input queue and put it back to output queue.
215 | Parameters
216 | ----------
217 | args: object
218 | q_in: queue
219 | q_out: queue
220 | """
221 | while True:
222 | deq = q_in.get()
223 | if deq is None:
224 | break
225 | i, item = deq
226 | image_encode(args, i, item, q_out)
227 |
228 | def write_worker(q_out, fname, working_dir):
229 | """Function that will be spawned to fetch processed image
230 | from the output queue and write to the .rec file.
231 | Parameters
232 | ----------
233 | q_out: queue
234 | fname: string
235 | working_dir: string
236 | """
237 | pre_time = time.time()
238 | count = 0
239 | fname = os.path.basename(fname)
240 | fname_rec = os.path.splitext(fname)[0] + '.rec'
241 | fname_idx = os.path.splitext(fname)[0] + '.idx'
242 | record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
243 | os.path.join(working_dir, fname_rec), 'w')
244 | buf = {}
245 | more = True
246 | while more:
247 | deq = q_out.get()
248 | if deq is not None:
249 | i, s, item = deq
250 | buf[i] = (s, item)
251 | else:
252 | more = False
253 | while count in buf:
254 | s, item = buf[count]
255 | del buf[count]
256 | if s is not None:
257 | record.write_idx(item[0], s)
258 |
259 | if count % 1000 == 0:
260 | cur_time = time.time()
261 | print('time:', cur_time - pre_time, ' count:', count)
262 | pre_time = cur_time
263 | count += 1
264 |
265 | def parse_args():
266 | """Defines all arguments.
267 | Returns
268 | -------
269 | args object that contains all the params
270 | """
271 | parser = argparse.ArgumentParser(
272 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
273 | description='Create an image list or \
274 | make a record database by reading from an image list')
275 | parser.add_argument('prefix', help='prefix of input/output lst and rec files.')
276 | parser.add_argument('root', help='path to folder containing images.')
277 |
278 | cgroup = parser.add_argument_group('Options for creating image lists')
279 | cgroup.add_argument('--list', action='store_true',
280 | help='If this is set im2rec will create image list(s) by traversing root folder\
281 | and output to .lst.\
282 | Otherwise im2rec will read .lst and create a database at .rec')
283 | cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg', '.png'],
284 | help='list of acceptable image extensions.')
285 | cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')
286 | cgroup.add_argument('--train-ratio', type=float, default=1.0,
287 | help='Ratio of images to use for training.')
288 | cgroup.add_argument('--test-ratio', type=float, default=0,
289 | help='Ratio of images to use for testing.')
290 | cgroup.add_argument('--recursive', action='store_true',
291 | help='If true recursively walk through subdirs and assign an unique label\
292 | to images in each folder. Otherwise only include images in the root folder\
293 | and give them label 0.')
294 | cgroup.add_argument('--no-shuffle', dest='shuffle', action='store_false',
295 | help='If this is passed, \
296 | im2rec will not randomize the image order in .lst')
297 | rgroup = parser.add_argument_group('Options for creating database')
298 | rgroup.add_argument('--pass-through', action='store_true',
299 | help='whether to skip transformation and save image as is')
300 | rgroup.add_argument('--resize', type=int, default=0,
301 | help='resize the shorter edge of image to the newsize, original images will\
302 | be packed by default.')
303 | rgroup.add_argument('--center-crop', action='store_true',
304 | help='specify whether to crop the center image to make it rectangular.')
305 | rgroup.add_argument('--quality', type=int, default=95,
306 | help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')
307 | rgroup.add_argument('--num-thread', type=int, default=1,
308 | help='number of thread to use for encoding. order of images will be different\
309 | from the input list if >1. the input list will be modified to match the\
310 | resulting order.')
311 | rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1],
312 | help='specify the color mode of the loaded image.\
313 | 1: Loads a color image. Any transparency of image will be neglected. It is the default flag.\
314 | 0: Loads image in grayscale mode.\
315 | -1:Loads image as such including alpha channel.')
316 | rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
317 | help='specify the encoding of the images.')
318 | rgroup.add_argument('--pack-label', action='store_true',
319 | help='Whether to also pack multi dimensional label in the record file')
320 | args = parser.parse_args()
321 | args.prefix = os.path.abspath(args.prefix)
322 | args.root = os.path.abspath(args.root)
323 | return args
324 |
325 | if __name__ == '__main__':
326 | args = parse_args()
327 | # if '--list' is passed, generate the .lst file(s)
328 | if args.list:
329 | make_list(args)
330 | # otherwise, read the .lst file(s) and create the .rec database
331 | else:
332 | if os.path.isdir(args.prefix):
333 | working_dir = args.prefix
334 | else:
335 | working_dir = os.path.dirname(args.prefix)
336 | files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)
337 | if os.path.isfile(os.path.join(working_dir, fname))]
338 | count = 0
339 | for fname in files:
340 | if fname.startswith(args.prefix) and fname.endswith('.lst'):
341 | print('Creating .rec file from', fname, 'in', working_dir)
342 | count += 1
343 | image_list = read_list(fname)
344 | if args.shuffle:
345 | image_list = list(image_list)
346 | random.seed(100)
347 | random.shuffle(image_list)
348 | # -- write_record -- #
349 | if args.num_thread > 1 and multiprocessing is not None:
350 | q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
351 | q_out = multiprocessing.Queue(1024)
352 | # define the process
353 | read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out)) \
354 | for i in range(args.num_thread)]
355 | # process images with num_thread process
356 | for p in read_process:
357 | p.start()
358 | # only use one process to write the .rec file to avoid race conditions
359 | write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir))
360 | write_process.start()
361 | # put the image list into input queue
362 | for i, item in enumerate(image_list):
363 | q_in[i % len(q_in)].put((i, item))
364 | for q in q_in:
365 | q.put(None)
366 | for p in read_process:
367 | p.join()
368 |
369 | q_out.put(None)
370 | write_process.join()
371 | else:
372 | print('multiprocessing not available, falling back to single-threaded encoding')
373 | try:
374 | import Queue as queue
375 | except ImportError:
376 | import queue
377 | q_out = queue.Queue()
378 | fname = os.path.basename(fname)
379 | fname_rec = os.path.splitext(fname)[0] + '.rec'
380 | fname_idx = os.path.splitext(fname)[0] + '.idx'
381 | record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
382 | os.path.join(working_dir, fname_rec), 'w')
383 | cnt = 0
384 | pre_time = time.time()
385 | for i, item in enumerate(image_list):
386 | image_encode(args, i, item, q_out)
387 | if q_out.empty():
388 | continue
389 | _, s, _ = q_out.get()
390 | record.write_idx(item[0], s)
391 | if cnt % 1000 == 0:
392 | cur_time = time.time()
393 | print('time:', cur_time - pre_time, ' count:', cnt)
394 | pre_time = cur_time
395 | cnt += 1
396 | if not count:
397 | print('Did not find any .lst file with prefix %s' % args.prefix)
--------------------------------------------------------------------------------
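
For reference, each line that write_list emits (and read_list parses back) follows the `index \t label \t relative/path` layout described in its docstring. An illustrative fragment (the class folder and label values are hypothetical):

```
0	7.000000	n01440764/n01440764_10026.JPEG
1	7.000000	n01440764/n01440764_10027.JPEG
```
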
/dataset/imagenet_data_dali.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | try:
3 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator
4 | from nvidia.dali.pipeline import Pipeline
5 | import nvidia.dali.ops as ops
6 | import nvidia.dali.types as types
7 | except ImportError:
8 | raise ImportError("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example.")
9 |
10 | class HybridTrainPipe(Pipeline):
11 | def __init__(self, batch_size, world_size, device_id, data_dir,
12 | data_cfg):
13 | super(HybridTrainPipe, self).__init__(batch_size,
14 | data_cfg.num_threads,
15 | device_id,
16 | seed=12 + device_id)
17 | self.input = ops.MXNetReader(path = osp.join(data_dir, data_cfg.train_pref+".rec"),
18 | index_path = osp.join(data_dir, data_cfg.train_pref+".idx"),
19 | random_shuffle = True,
20 | shard_id = device_id,
21 | num_shards = world_size)
22 | # let the user decide which pipeline works best for the network version they run
23 | self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
24 | self.rrc = ops.RandomResizedCrop(device='gpu',
25 | size=data_cfg.input_size[1],
26 | random_area=[data_cfg.random_sized.min_scale, 1.0],
27 | # interp_type=getattr(types, data_cfg.type_interp),
28 | min_filter=getattr(types, data_cfg.min_filter),
29 | mag_filter=getattr(types, data_cfg.mag_filter),
30 | minibatch_size=data_cfg.resize_batch)
31 | if data_cfg.color:
32 | self.color = ops.ColorTwist(device="gpu")
33 | self.bright = ops.Uniform(range=[0.6, 1.4])
34 | self.cont = ops.Uniform(range=[0.6, 1.4])
35 | self.sat = ops.Uniform(range=[0.6, 1.4])
36 | self.cmnp = ops.CropMirrorNormalize(device="gpu",
37 | output_layout=types.NCHW,
38 | mean=[0.485 * 255,0.456 * 255,0.406 * 255],
39 | std=[0.229 * 255,0.224 * 255,0.225 * 255])
40 | self.coin = ops.CoinFlip(probability=0.5)
41 |
42 | def define_graph(self):
43 | rng = self.coin()
44 | self.jpegs, self.labels = self.input(name="Reader")
45 | images = self.decode(self.jpegs)
46 | images = self.rrc(images)
47 | if hasattr(self, 'color'):
48 | images = self.color(images,
49 | brightness=self.bright(),
50 | contrast=self.cont(),
51 | saturation=self.sat())
52 | output = self.cmnp(images.gpu(), mirror=rng)
53 | return [output, self.labels]
54 |
55 | class HybridValPipe(Pipeline):
56 | def __init__(self, batch_size, world_size, device_id, data_dir, data_cfg):
57 | super(HybridValPipe, self).__init__(batch_size,
58 | data_cfg.num_threads,
59 | device_id,
60 | seed=12 + device_id)
61 | self.input = ops.MXNetReader(path = osp.join(data_dir, data_cfg.val_pref+".rec"),
62 | index_path = osp.join(data_dir, data_cfg.val_pref+".idx"),
63 | random_shuffle = data_cfg.val_shuffle,
64 | shard_id = device_id,
65 | num_shards = world_size)
66 | self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
67 | self.res = ops.Resize(device="gpu",
68 | resize_shorter=256,
69 | # interp_type=getattr(types, data_cfg.type_interp),
70 | min_filter=getattr(types, data_cfg.min_filter),
71 | mag_filter=getattr(types, data_cfg.mag_filter),
72 | minibatch_size=data_cfg.resize_batch)
73 | self.cmnp = ops.CropMirrorNormalize(device="gpu",
74 | output_layout=types.NCHW,
75 | crop=(data_cfg.input_size[1],)*2,
76 | mean=[0.485 * 255,0.456 * 255,0.406 * 255],
77 | std=[0.229 * 255,0.224 * 255,0.225 * 255])
78 |
79 | def define_graph(self):
80 | self.jpegs, self.labels = self.input(name="Reader")
81 | images = self.decode(self.jpegs)
82 | images = self.res(images)
83 | output = self.cmnp(images)
84 | return [output, self.labels]
85 |
86 | class DALIClassificationIterator_(DALIClassificationIterator):
87 | def next_batch(self):
88 | try:
89 | data = self.__next__()
90 | input = data[0]["data"]
91 | target = data[0]["label"].squeeze().cuda().long()
92 | except StopIteration:
93 | input, target = None, None
94 | return input, target
95 |
96 |
97 | def get_train_loader(batch_size, world_size, device_id,
98 | data_dir, data_cfg):
99 | pipe = HybridTrainPipe(batch_size=batch_size,
100 | world_size=world_size,
101 | device_id=device_id,
102 | data_dir=data_dir,
103 | data_cfg=data_cfg,)
104 | pipe.build()
105 | train_loader = DALIClassificationIterator_(pipe, reader_name="Reader",
106 | fill_last_batch=False, auto_reset=True)
107 | return train_loader
108 |
109 |
110 | def get_val_loader(batch_size, world_size, device_id,
111 | data_dir, data_cfg):
112 | pipe = HybridValPipe(batch_size=batch_size,
113 | world_size=world_size,
114 | device_id=device_id,
115 | data_dir=data_dir,
116 | data_cfg=data_cfg)
117 | pipe.build()
118 | val_loader = DALIClassificationIterator_(pipe, reader_name="Reader",
119 | fill_last_batch=False, auto_reset=True)
120 | return val_loader
121 |
122 | def get_data_loader(data_cfg, world_size, device_id, data_dir):
123 | train_loader = get_train_loader(data_cfg.batch_size,
124 | world_size, device_id,
125 | data_dir, data_cfg)
126 | val_loader = get_val_loader(data_cfg.batch_size,
127 | world_size, device_id,
128 | data_dir, data_cfg)
129 | return train_loader, val_loader
--------------------------------------------------------------------------------
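
A minimal sketch of how the loaders above are typically consumed in a training loop, assuming a single process (world_size=1); `data_cfg`, `model`, `criterion`, `optimizer`, and `num_epochs` are placeholders, and the real distributed loop lives in run_apis/trainer.py (not shown in this dump):

```
from dataset.imagenet_data_dali import get_data_loader

train_loader, val_loader = get_data_loader(data_cfg, world_size=1,
                                           device_id=0, data_dir='./imagenet-rec')
for epoch in range(num_epochs):
    while True:
        input, target = train_loader.next_batch()
        if input is None:   # next_batch() returns (None, None) at epoch end
            break           # auto_reset=True rewinds the pipeline automatically
        logits = model(input)
        loss = criterion(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
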
/dataset/mk_img_list.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | def get_list(data_path, output_path):
5 | for split in os.listdir(data_path):
6 | split_path = os.path.join(data_path, split)
7 | if not os.path.isdir(split_path):
8 | continue
9 | f = open(os.path.join(output_path, split + '_datalist'), 'a+')
10 | for sub in os.listdir(split_path):
11 | sub_path = os.path.join(split_path, sub)
12 | if not os.path.isdir(sub_path):
13 | continue
14 | for image in os.listdir(sub_path):
15 | image_name = sub + '/' + image
16 | f.writelines(image_name + '\n')
17 | f.close()
18 |
19 |
20 | if __name__ == '__main__':
21 | parser = argparse.ArgumentParser("Params")
22 | parser.add_argument('--image_path', type=str, default='', help='the path of the images')
23 | parser.add_argument('--output_path', type=str, default='', help='the output path of the image list files')
24 | args = parser.parse_args()
25 |
26 | get_list(args.image_path, args.output_path)
27 |
--------------------------------------------------------------------------------
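
Typical invocation (paths are placeholders); the script writes one `<split>_datalist` file per split directory found under `--image_path`:

```
python dataset/mk_img_list.py --image_path /path/to/imagenet --output_path ./lists
```
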
/dataset/torchvision_extension.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms as transforms
3 | from torchvision.transforms import functional as F
4 |
5 | # In this file some more transformations (beyond the ones defined in torchvision.transforms)
6 | # are added. They are particularly helpful for training on ImageNet, in the style of the transforms
7 | # used by fb.resnet https://github.com/facebook/fb.resnet.torch/blob/master/datasets/imagenet.lua
8 |
9 | #This file is taken from a proposed pull request on the torchvision github project.
10 | # At the moment that pull request has not been accepted yet, which is why it is reproduced here.
11 | #Link to the pull request: https://github.com/pytorch/vision/pull/27/files
12 |
13 | class Lighting(object):
14 |
15 | """Lighting noise(AlexNet - style PCA - based noise)"""
16 |
17 | def __init__(self, alphastd, eigval, eigvec):
18 | self.alphastd = alphastd
19 | self.eigval = eigval
20 | self.eigvec = eigvec
21 |
22 | def __call__(self, img):
23 | # img is supposed to be a torch tensor
24 |
25 | if self.alphastd == 0:
26 | return img
27 |
28 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
29 | rgb = self.eigvec.type_as(img).clone()\
30 | .mul(alpha.view(1, 3).expand(3, 3))\
31 | .mul(self.eigval.view(1, 3).expand(3, 3))\
32 | .sum(1).squeeze()
33 |
34 | return img.add(rgb.view(3, 1, 1).expand_as(img))
35 |
36 |
37 | class RandomScale(object):
38 |
39 | """ResNet style data augmentation"""
40 |
41 | def __init__(self, minSize, maxSize):
42 | self.minSize = minSize
43 | self.maxSize = maxSize
44 |
45 | def __call__(self, img):
46 |
47 | targetSz = int(round(random.uniform(self.minSize, self.maxSize)))
48 |
49 | return F.resize(img, targetSz)
50 |
51 |
--------------------------------------------------------------------------------
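
A hedged usage sketch of `Lighting`; the eigenvalue/eigenvector constants below are the ImageNet PCA statistics popularized by fb.resnet.torch and are not defined anywhere in this repository:

```
import torch
from dataset.torchvision_extension import Lighting

eigval = torch.Tensor([0.2175, 0.0188, 0.0045])
eigvec = torch.Tensor([[-0.5675,  0.7192,  0.4009],
                       [-0.5808, -0.0045, -0.8140],
                       [-0.5836, -0.6948,  0.4203]])

lighting = Lighting(alphastd=0.1, eigval=eigval, eigvec=eigvec)
img = torch.rand(3, 224, 224)   # a CHW image tensor
img = lighting(img)             # adds AlexNet-style PCA lighting noise
```
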
/imgs/dali-speed-comp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/dali-pytorch/7bd5d2380d210a32d24c7309da69c8d2c5db8759/imgs/dali-speed-comp.png
--------------------------------------------------------------------------------
/model_zoo/__init__.py:
--------------------------------------------------------------------------------
1 | from .ghostnet import ghostnet
2 | from .mobilenetv3 import mobilenetv3_large
3 | from .mobilenetv3 import mobilenetv3_small
4 |
5 | from torchvision.models import mobilenet_v2
6 | from torchvision.models import resnet18, resnet34, resnet50
--------------------------------------------------------------------------------
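
A hedged sketch of how the `model_name` strings in the training configs plausibly map onto these factories via attribute lookup (the actual dispatch lives in run_apis/, which is not shown in this dump):

```
import model_zoo

# model_name='ghostnet' in imagenet_train_ghostnet_cfg.py
model = getattr(model_zoo, 'ghostnet')()
# 'mobilenetv3_large' and 'resnet18' resolve the same way
model = getattr(model_zoo, 'mobilenetv3_large')()
model = getattr(model_zoo, 'resnet18')()
```
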
/model_zoo/ghostnet.py:
--------------------------------------------------------------------------------
1 | # 2020.06.09-Changed for building GhostNet
2 | # Huawei Technologies Co., Ltd.
3 | """
4 | From https://github.com/huawei-noah/ghostnet/blob/master/pytorch/ghostnet.py
5 | ---------------------------------------------
6 | Creates a GhostNet Model as defined in:
7 | GhostNet: More Features from Cheap Operations By Kai Han, Yunhe Wang, Qi Tian, Jianyuan Guo, Chunjing Xu, Chang Xu.
8 | https://arxiv.org/abs/1911.11907
9 | Modified from https://github.com/d-li14/mobilenetv3.pytorch and https://github.com/rwightman/pytorch-image-models
10 | """
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import math
15 |
16 |
17 | __all__ = ['ghostnet']  # matches the factory function defined below
18 |
19 |
20 | def _make_divisible(v, divisor, min_value=None):
21 | """
22 | This function is taken from the original tf repo.
23 | It ensures that all layers have a channel number that is divisible by 8
24 | It can be seen here:
25 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
26 | """
27 | if min_value is None:
28 | min_value = divisor
29 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
30 | # Make sure that round down does not go down by more than 10%.
31 | if new_v < 0.9 * v:
32 | new_v += divisor
33 | return new_v
34 |
35 |
36 | def hard_sigmoid(x, inplace: bool = False):
37 | if inplace:
38 | return x.add_(3.).clamp_(0., 6.).div_(6.)
39 | else:
40 | return F.relu6(x + 3.) / 6.
41 |
42 |
43 | class SqueezeExcite(nn.Module):
44 | def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
45 | act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_):
46 | super(SqueezeExcite, self).__init__()
47 | self.gate_fn = gate_fn
48 | reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
49 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
50 | self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
51 | self.act1 = act_layer(inplace=True)
52 | self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
53 |
54 | def forward(self, x):
55 | x_se = self.avg_pool(x)
56 | x_se = self.conv_reduce(x_se)
57 | x_se = self.act1(x_se)
58 | x_se = self.conv_expand(x_se)
59 | x = x * self.gate_fn(x_se)
60 | return x
61 |
62 |
63 | class ConvBnAct(nn.Module):
64 | def __init__(self, in_chs, out_chs, kernel_size,
65 | stride=1, act_layer=nn.ReLU):
66 | super(ConvBnAct, self).__init__()
67 | self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False)
68 | self.bn1 = nn.BatchNorm2d(out_chs)
69 | self.act1 = act_layer(inplace=True)
70 |
71 | def forward(self, x):
72 | x = self.conv(x)
73 | x = self.bn1(x)
74 | x = self.act1(x)
75 | return x
76 |
77 |
78 | class GhostModule(nn.Module):
79 | def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
80 | super(GhostModule, self).__init__()
81 | self.oup = oup
82 | init_channels = math.ceil(oup / ratio)
83 | new_channels = init_channels*(ratio-1)
84 |
85 | self.primary_conv = nn.Sequential(
86 | nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False),
87 | nn.BatchNorm2d(init_channels),
88 | nn.ReLU(inplace=True) if relu else nn.Sequential(),
89 | )
90 |
91 | self.cheap_operation = nn.Sequential(
92 | nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False),
93 | nn.BatchNorm2d(new_channels),
94 | nn.ReLU(inplace=True) if relu else nn.Sequential(),
95 | )
96 |
97 | def forward(self, x):
98 | x1 = self.primary_conv(x)
99 | x2 = self.cheap_operation(x1)
100 | out = torch.cat([x1,x2], dim=1)
101 | return out[:,:self.oup,:,:]
102 |
103 |
104 | class GhostBottleneck(nn.Module):
105 | """ Ghost bottleneck w/ optional SE"""
106 |
107 | def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
108 | stride=1, act_layer=nn.ReLU, se_ratio=0.):
109 | super(GhostBottleneck, self).__init__()
110 | has_se = se_ratio is not None and se_ratio > 0.
111 | self.stride = stride
112 |
113 | # Point-wise expansion
114 | self.ghost1 = GhostModule(in_chs, mid_chs, relu=True)
115 |
116 | # Depth-wise convolution
117 | if self.stride > 1:
118 | self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride,
119 | padding=(dw_kernel_size-1)//2,
120 | groups=mid_chs, bias=False)
121 | self.bn_dw = nn.BatchNorm2d(mid_chs)
122 |
123 | # Squeeze-and-excitation
124 | if has_se:
125 | self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio)
126 | else:
127 | self.se = None
128 |
129 | # Point-wise linear projection
130 | self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)
131 |
132 | # shortcut
133 | if (in_chs == out_chs and self.stride == 1):
134 | self.shortcut = nn.Sequential()
135 | else:
136 | self.shortcut = nn.Sequential(
137 | nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride,
138 | padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False),
139 | nn.BatchNorm2d(in_chs),
140 | nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
141 | nn.BatchNorm2d(out_chs),
142 | )
143 |
144 |
145 | def forward(self, x):
146 | residual = x
147 |
148 | # 1st ghost bottleneck
149 | x = self.ghost1(x)
150 |
151 | # Depth-wise convolution
152 | if self.stride > 1:
153 | x = self.conv_dw(x)
154 | x = self.bn_dw(x)
155 |
156 | # Squeeze-and-excitation
157 | if self.se is not None:
158 | x = self.se(x)
159 |
160 | # 2nd ghost bottleneck
161 | x = self.ghost2(x)
162 |
163 | x += self.shortcut(residual)
164 | return x
165 |
166 |
167 | class GhostNet(nn.Module):
168 | def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2,
169 | init_mode='he_fout', bn_momentum=0.1, bn_eps=1e-3):
170 | super(GhostNet, self).__init__()
171 | # setting of inverted residual blocks
172 | self.cfgs = cfgs
173 | self.dropout = dropout
174 |
175 | # building first layer
176 | output_channel = _make_divisible(16 * width, 4)
177 | self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False)
178 | self.bn1 = nn.BatchNorm2d(output_channel)
179 | self.act1 = nn.ReLU(inplace=True)
180 | input_channel = output_channel
181 |
182 | # building inverted residual blocks
183 | stages = []
184 | block = GhostBottleneck
185 | for cfg in self.cfgs:
186 | layers = []
187 | for k, exp_size, c, se_ratio, s in cfg:
188 | output_channel = _make_divisible(c * width, 4)
189 | hidden_channel = _make_divisible(exp_size * width, 4)
190 | layers.append(block(input_channel, hidden_channel, output_channel, k, s,
191 | se_ratio=se_ratio))
192 | input_channel = output_channel
193 | stages.append(nn.Sequential(*layers))
194 |
195 | output_channel = _make_divisible(exp_size * width, 4)
196 | stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1)))
197 | input_channel = output_channel
198 |
199 | self.blocks = nn.Sequential(*stages)
200 |
201 | # building last several layers
202 | output_channel = 1280
203 | self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
204 | self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True)
205 | self.act2 = nn.ReLU(inplace=True)
206 | self.classifier = nn.Linear(output_channel, num_classes)
207 | self._initialize_weights()
208 | self.set_bn_param(bn_momentum, bn_eps)
209 |
210 | def forward(self, x):
211 | x = self.conv_stem(x)
212 | x = self.bn1(x)
213 | x = self.act1(x)
214 | x = self.blocks(x)
215 | x = self.global_pool(x)
216 | x = self.conv_head(x)
217 | x = self.act2(x)
218 | x = x.view(x.size(0), -1)
219 | if self.dropout > 0.:
220 | x = F.dropout(x, p=self.dropout, training=self.training)
221 | x = self.classifier(x)
222 | return x
223 |
224 | def _initialize_weights(self):
225 | for m in self.modules():
226 | if isinstance(m, nn.Conv2d):
227 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
228 | elif isinstance(m, nn.BatchNorm2d):
229 | m.weight.data.fill_(1)
230 | m.bias.data.zero_()
231 |
232 | def set_bn_param(self, bn_momentum, bn_eps):
233 | for m in self.modules():
234 | if isinstance(m, nn.BatchNorm2d):
235 | m.momentum = bn_momentum
236 | m.eps = bn_eps
237 | return
238 |
239 |
240 | def ghostnet(**kwargs):
241 | """
242 | Constructs a GhostNet model
243 | """
244 | cfgs = [
245 | # k, t, c, SE, s
246 | # stage1
247 | [[3, 16, 16, 0, 1]],
248 | # stage2
249 | [[3, 48, 24, 0, 2]],
250 | [[3, 72, 24, 0, 1]],
251 | # stage3
252 | [[5, 72, 40, 0.25, 2]],
253 | [[5, 120, 40, 0.25, 1]],
254 | # stage4
255 | [[3, 240, 80, 0, 2]],
256 | [[3, 200, 80, 0, 1],
257 | [3, 184, 80, 0, 1],
258 | [3, 184, 80, 0, 1],
259 | [3, 480, 112, 0.25, 1],
260 | [3, 672, 112, 0.25, 1]
261 | ],
262 | # stage5
263 | [[5, 672, 160, 0.25, 2]],
264 | [[5, 960, 160, 0, 1],
265 | [5, 960, 160, 0.25, 1],
266 | [5, 960, 160, 0, 1],
267 | [5, 960, 160, 0.25, 1]
268 | ]
269 | ]
270 | return GhostNet(cfgs, **kwargs)
271 |
272 |
273 | if __name__ == '__main__':
274 |     model = ghostnet()
275 |     model.eval()
276 |     print(model)
277 |     input = torch.randn(32, 3, 320, 256)
278 |     y = model(input)
279 |     print(y.size())
--------------------------------------------------------------------------------
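
Each cfg row above reads [k, t, c, SE, s]: depthwise kernel size, expansion size, output channels, SE ratio, and stride for one GhostBottleneck. A minimal smoke test, assuming the repository root is on PYTHONPATH:

    import torch
    from model_zoo.ghostnet import ghostnet

    model = ghostnet(num_classes=1000, width=1.0)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 224, 224))
    print(out.shape)  # torch.Size([2, 1000])
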
/model_zoo/mobilenetv3.py:
--------------------------------------------------------------------------------
1 | """
2 | From https://github.com/d-li14/mobilenetv3.pytorch
3 | ----------------------------------------------------
4 | Creates a MobileNetV3 Model as defined in:
5 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019).
6 | Searching for MobileNetV3
7 | arXiv preprint arXiv:1905.02244.
8 | """
9 |
10 | import torch.nn as nn
11 | import math
12 |
13 |
14 | __all__ = ['mobilenetv3_large', 'mobilenetv3_small']
15 |
16 | def _make_divisible(v, divisor, min_value=None):
17 | """
18 | This function is taken from the original tf repo.
19 | It ensures that all layers have a channel number that is divisible by 8
20 | It can be seen here:
21 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
22 | :param v:
23 | :param divisor:
24 | :param min_value:
25 | :return:
26 | """
27 | if min_value is None:
28 | min_value = divisor
29 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
30 | # Make sure that round down does not go down by more than 10%.
31 | if new_v < 0.9 * v:
32 | new_v += divisor
33 | return new_v
34 |
35 |
36 | class h_sigmoid(nn.Module):
37 | def __init__(self, inplace=True):
38 | super(h_sigmoid, self).__init__()
39 | self.relu = nn.ReLU6(inplace=inplace)
40 |
41 | def forward(self, x):
42 | return self.relu(x + 3) / 6
43 |
44 |
45 | class h_swish(nn.Module):
46 | def __init__(self, inplace=True):
47 | super(h_swish, self).__init__()
48 | self.sigmoid = h_sigmoid(inplace=inplace)
49 |
50 | def forward(self, x):
51 | return x * self.sigmoid(x)
52 |
53 |
54 | class SELayer(nn.Module):
55 | def __init__(self, channel, reduction=4):
56 | super(SELayer, self).__init__()
57 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
58 | self.fc = nn.Sequential(
59 | nn.Linear(channel, _make_divisible(channel // reduction, 8)),
60 | nn.ReLU(inplace=True),
61 | nn.Linear(_make_divisible(channel // reduction, 8), channel),
62 | h_sigmoid()
63 | )
64 |
65 | def forward(self, x):
66 | b, c, _, _ = x.size()
67 | y = self.avg_pool(x).view(b, c)
68 | y = self.fc(y).view(b, c, 1, 1)
69 | return x * y
70 |
71 |
72 | def conv_3x3_bn(inp, oup, stride):
73 | return nn.Sequential(
74 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
75 | nn.BatchNorm2d(oup),
76 | h_swish()
77 | )
78 |
79 |
80 | def conv_1x1_bn(inp, oup):
81 | return nn.Sequential(
82 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
83 | nn.BatchNorm2d(oup),
84 | h_swish()
85 | )
86 |
87 |
88 | class InvertedResidual(nn.Module):
89 | def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
90 | super(InvertedResidual, self).__init__()
91 | assert stride in [1, 2]
92 |
93 | self.identity = stride == 1 and inp == oup
94 |
95 | if inp == hidden_dim:
96 | self.conv = nn.Sequential(
97 | # dw
98 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False),
99 | nn.BatchNorm2d(hidden_dim),
100 | h_swish() if use_hs else nn.ReLU(inplace=True),
101 | # Squeeze-and-Excite
102 | SELayer(hidden_dim) if use_se else nn.Identity(),
103 | # pw-linear
104 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
105 | nn.BatchNorm2d(oup),
106 | )
107 | else:
108 | self.conv = nn.Sequential(
109 | # pw
110 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
111 | nn.BatchNorm2d(hidden_dim),
112 | h_swish() if use_hs else nn.ReLU(inplace=True),
113 | # dw
114 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False),
115 | nn.BatchNorm2d(hidden_dim),
116 | h_swish() if use_hs else nn.ReLU(inplace=True),
117 | # Squeeze-and-Excite
118 | SELayer(hidden_dim) if use_se else nn.Identity(),
119 | # pw-linear
120 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
121 | nn.BatchNorm2d(oup),
122 | )
123 |
124 | def forward(self, x):
125 | if self.identity:
126 | return x + self.conv(x)
127 | else:
128 | return self.conv(x)
129 |
130 |
131 | class MobileNetV3(nn.Module):
132 | def __init__(self, cfgs, mode, num_classes=1000, width_mult=1.,
133 | init_mode='he_fout', bn_momentum=0.1, bn_eps=1e-3):
134 | super(MobileNetV3, self).__init__()
135 | # setting of inverted residual blocks
136 | self.cfgs = cfgs
137 | assert mode in ['large', 'small']
138 |
139 | # building first layer
140 | input_channel = _make_divisible(16 * width_mult, 8)
141 | layers = [conv_3x3_bn(3, input_channel, 2)]
142 | # building inverted residual blocks
143 | block = InvertedResidual
144 | for k, t, c, use_se, use_hs, s in self.cfgs:
145 | output_channel = _make_divisible(c * width_mult, 8)
146 | exp_size = _make_divisible(input_channel * t, 8)
147 | layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
148 | input_channel = output_channel
149 | self.features = nn.Sequential(*layers)
150 | # building last several layers
151 | self.conv = conv_1x1_bn(input_channel, exp_size)
152 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
153 | output_channel = {'large': 1280, 'small': 1024}
154 | output_channel = _make_divisible(output_channel[mode] * width_mult, 8) if width_mult > 1.0 else output_channel[mode]
155 | self.classifier = nn.Sequential(
156 | nn.Linear(exp_size, output_channel),
157 | h_swish(),
158 | nn.Dropout(0.2),
159 | nn.Linear(output_channel, num_classes),
160 | )
161 |
162 | self._initialize_weights()
163 | self.set_bn_param(bn_momentum, bn_eps)
164 |
165 | def forward(self, x):
166 | x = self.features(x)
167 | x = self.conv(x)
168 | x = self.avgpool(x)
169 | x = x.view(x.size(0), -1)
170 | x = self.classifier(x)
171 | return x
172 |
173 | def _initialize_weights(self):
174 | for m in self.modules():
175 | if isinstance(m, nn.Conv2d):
176 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
177 | if m.bias is not None:
178 | m.bias.data.zero_()
179 | elif isinstance(m, nn.BatchNorm2d):
180 | m.weight.data.fill_(1)
181 | m.bias.data.zero_()
182 | elif isinstance(m, nn.Linear):
183 | m.weight.data.normal_(0, 0.01)
184 | m.bias.data.zero_()
185 |
186 | def set_bn_param(self, bn_momentum, bn_eps):
187 | for m in self.modules():
188 | if isinstance(m, nn.BatchNorm2d):
189 | m.momentum = bn_momentum
190 | m.eps = bn_eps
191 | return
192 |
193 |
194 | def mobilenetv3_large(**kwargs):
195 | """
196 | Constructs a MobileNetV3-Large model
197 | """
198 | cfgs = [
199 | # k, t, c, SE, HS, s
200 | [3, 1, 16, 0, 0, 1],
201 | [3, 4, 24, 0, 0, 2],
202 | [3, 3, 24, 0, 0, 1],
203 | [5, 3, 40, 1, 0, 2],
204 | [5, 3, 40, 1, 0, 1],
205 | [5, 3, 40, 1, 0, 1],
206 | [3, 6, 80, 0, 1, 2],
207 | [3, 2.5, 80, 0, 1, 1],
208 | [3, 2.3, 80, 0, 1, 1],
209 | [3, 2.3, 80, 0, 1, 1],
210 | [3, 6, 112, 1, 1, 1],
211 | [3, 6, 112, 1, 1, 1],
212 | [5, 6, 160, 1, 1, 2],
213 | [5, 6, 160, 1, 1, 1],
214 | [5, 6, 160, 1, 1, 1]
215 | ]
216 | return MobileNetV3(cfgs, mode='large', **kwargs)
217 |
218 |
219 | def mobilenetv3_small(**kwargs):
220 | """
221 | Constructs a MobileNetV3-Small model
222 | """
223 | cfgs = [
224 | # k, t, c, SE, HS, s
225 | [3, 1, 16, 1, 0, 2],
226 | [3, 4.5, 24, 0, 0, 2],
227 | [3, 3.67, 24, 0, 0, 1],
228 | [5, 4, 40, 1, 1, 2],
229 | [5, 6, 40, 1, 1, 1],
230 | [5, 6, 40, 1, 1, 1],
231 | [5, 3, 48, 1, 1, 1],
232 | [5, 3, 48, 1, 1, 1],
233 | [5, 6, 96, 1, 1, 2],
234 | [5, 6, 96, 1, 1, 1],
235 | [5, 6, 96, 1, 1, 1],
236 | ]
237 |
238 | return MobileNetV3(cfgs, mode='small', **kwargs)
239 |
--------------------------------------------------------------------------------
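
The _make_divisible helper above rounds a scaled channel count to the nearest multiple of divisor, clamps it at min_value, and bumps it up by one divisor if rounding removed more than 10% of the original value. A standalone worked example (the function body is copied verbatim from the file):

    def _make_divisible(v, divisor, min_value=None):
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    print(_make_divisible(37, 8))  # 40: rounds up to the nearest multiple of 8
    print(_make_divisible(33, 8))  # 32: rounds down; 32 >= 0.9 * 33, so no bump
    print(_make_divisible(4, 8))   # 8:  clamped to min_value (= divisor)
    print(_make_divisible(11, 8))  # 16: rounds to 8, but 8 < 0.9 * 11, so bumped up
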
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/dali-pytorch/7bd5d2380d210a32d24c7309da69c8d2c5db8759/models/__init__.py
--------------------------------------------------------------------------------
/models/model_derived.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .operations import OPS
7 | from tools.utils import parse_net_config
8 |
9 |
10 | class Block(nn.Module):
11 |
12 | def __init__(self, in_ch, block_ch, head_op, stack_ops, stride, use_se=False):
13 | super(Block, self).__init__()
14 | self.head_layer = OPS[head_op](in_ch, block_ch, stride, use_se=use_se,
15 | affine=True, track_running_stats=True)
16 |
17 | modules = []
18 | for stack_op in stack_ops:
19 | modules.append(OPS[stack_op](block_ch, block_ch, 1, use_se=use_se,
20 | affine=True, track_running_stats=True))
21 | self.stack_layers = nn.Sequential(*modules)
22 |
23 | def forward(self, x):
24 | x = self.head_layer(x)
25 | x = self.stack_layers(x)
26 | return x
27 |
28 |
29 | class Conv1_1_Block(nn.Module):
30 |
31 | def __init__(self, in_ch, block_ch):
32 | super(Conv1_1_Block, self).__init__()
33 | self.conv1_1 = nn.Sequential(
34 | nn.Conv2d(in_channels=in_ch, out_channels=block_ch,
35 | kernel_size=1, stride=1, padding=0, bias=False),
36 | nn.BatchNorm2d(block_ch),
37 | nn.ReLU6(inplace=True)
38 | )
39 |
40 | def forward(self, x):
41 | return self.conv1_1(x)
42 |
43 |
44 | class MBV2_Net(nn.Module):
45 | def __init__(self, net_config, config=None):
46 | """
47 | net_config=[[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride]
48 | """
49 | super(MBV2_Net, self).__init__()
50 | self.config = config
51 | self.net_config = parse_net_config(net_config)
52 | self.in_chs = self.net_config[0][0][0]
53 | self._num_classes = 1000
54 | use_se = False
55 | if hasattr(config.optim, 'use_se'):
56 | use_se = config.optim.use_se
57 |
58 | self.input_block = nn.Sequential(
59 | nn.Conv2d(in_channels=3, out_channels=self.in_chs, kernel_size=3,
60 | stride=2, padding=1, bias=False),
61 | nn.BatchNorm2d(self.in_chs),
62 | nn.ReLU6(inplace=True)
63 | )
64 | self.blocks = nn.ModuleList()
65 | for cfg in self.net_config:
66 | if cfg[1] == 'conv1_1':
67 | continue
68 | self.blocks.append(Block(cfg[0][0], cfg[0][1],
69 | cfg[1], cfg[2], cfg[-1], use_se))
70 |
71 | if self.net_config[-1][1] == 'conv1_1':
72 | block_last_dim = self.net_config[-1][0][0]
73 | last_dim = self.net_config[-1][0][1]
74 |         else:
75 |             block_last_dim = self.net_config[-1][0][1]  # note: `last_dim` is only defined in the 'conv1_1' branch above
76 |         self.conv1_1_block = Conv1_1_Block(block_last_dim, last_dim)
77 |
78 | self.global_pooling = nn.AdaptiveAvgPool2d(1)
79 | self.classifier = nn.Linear(last_dim, self._num_classes)
80 |
81 | self.init_model()
82 | self.set_bn_param(0.1, 0.001)
83 |
84 |
85 | def forward(self,x):
86 | block_data = self.input_block(x)
87 | for i, block in enumerate(self.blocks):
88 | block_data = block(block_data)
89 | block_data = self.conv1_1_block(block_data)
90 |
91 | out = self.global_pooling(block_data)
92 | logits = self.classifier(out.view(out.size(0),-1))
93 |
94 | return logits
95 |
96 | def init_model(self, model_init='he_fout', init_div_groups=True):
97 | for m in self.modules():
98 | if isinstance(m, nn.Conv2d):
99 | if model_init == 'he_fout':
100 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
101 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
102 | if init_div_groups:
103 | n /= m.groups
104 | m.weight.data.normal_(0, math.sqrt(2. / n))
105 | elif model_init == 'he_fin':
106 | # nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
107 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
108 | if init_div_groups:
109 | n /= m.groups
110 | m.weight.data.normal_(0, math.sqrt(2. / n))
111 | else:
112 | raise NotImplementedError
113 | elif isinstance(m, nn.BatchNorm2d):
114 | m.weight.data.fill_(1)
115 | m.bias.data.zero_()
116 | elif isinstance(m, nn.Linear):
117 | if m.bias is not None:
118 | m.bias.data.zero_()
119 | elif isinstance(m, nn.BatchNorm1d):
120 | m.weight.data.fill_(1)
121 | m.bias.data.zero_()
122 |
123 | def set_bn_param(self, bn_momentum, bn_eps):
124 | for m in self.modules():
125 | if isinstance(m, nn.BatchNorm2d):
126 | m.momentum = bn_momentum
127 | m.eps = bn_eps
128 | return
129 |
130 |
131 | class RES_Net(nn.Module):
132 | def __init__(self, net_config, config=None):
133 | """
134 | net_config=[[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride]
135 | """
136 | super(RES_Net, self).__init__()
137 | self.config = config
138 | self.net_config = parse_net_config(net_config)
139 | self.in_chs = self.net_config[0][0][0]
140 | self._num_classes = 1000
141 |
142 | self.input_block = nn.Sequential(
143 | nn.Conv2d(in_channels=3, out_channels=self.in_chs, kernel_size=3,
144 | stride=2, padding=1, bias=False),
145 | nn.BatchNorm2d(self.in_chs),
146 | nn.ReLU6(inplace=True),
147 | )
148 | self.blocks = nn.ModuleList()
149 |         for cfg in self.net_config:  # avoid shadowing the `config` argument
150 |             self.blocks.append(Block(cfg[0][0], cfg[0][1],
151 |                 cfg[1], cfg[2], cfg[-1], use_se=False))
152 |
153 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
154 | if self.net_config[-1][1] == 'bottle_neck':
155 | last_dim = self.net_config[-1][0][-1] * 4
156 | else:
157 | last_dim = self.net_config[-1][0][1]
158 | self.classifier = nn.Linear(last_dim, self._num_classes)
159 |
160 | for m in self.modules():
161 | if isinstance(m, nn.Conv2d):
162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
164 | if m.affine==True:
165 | nn.init.constant_(m.weight, 1)
166 | nn.init.constant_(m.bias, 0)
167 |
168 | def forward(self, x):
169 | block_data = self.input_block(x)
170 | for i, block in enumerate(self.blocks):
171 | block_data = block(block_data)
172 |
173 | out = self.global_pooling(block_data)
174 | out = torch.flatten(out, 1)
175 | logits = self.classifier(out)
176 | return logits
177 |
178 |
179 | class PLAIN_Net(MBV2_Net):
180 | def __init__(self, *args, **kwargs):
181 | super(PLAIN_Net, self).__init__(*args, **kwargs)
182 |
--------------------------------------------------------------------------------
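
MBV2_Net and RES_Net both consume net_config entries in the format documented in their docstrings, [[in_ch, out_ch], head_op, [stack_ops], num_stack_layers, stride], with entries joined by '|' in the raw string that tools.utils.parse_net_config eval()s. A sketch with made-up channel numbers (the op names are real keys of the OPS registry in models/operations.py):

    net_config_str = (
        "[[16, 24], 'mbconv_k3_t6', ['mbconv_k3_t3'], 1, 2]|"
        "[[24, 40], 'mbconv_k5_t6', ['mbconv_k5_t3', 'mbconv_k5_t3'], 2, 1]"
    )
    parsed = [eval(entry) for entry in net_config_str.split('|')]
    for (in_ch, out_ch), head_op, stack_ops, _, stride in parsed:
        # Each entry becomes Block(in_ch, out_ch, head_op, stack_ops, stride):
        # the head op changes channels/stride, the stack ops keep both fixed.
        print(in_ch, '->', out_ch, 'via', head_op, '+', stack_ops, 'stride', stride)
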
/models/operations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | OPS = {
5 | 'mbconv_k3_t1': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=1, use_se=use_se, affine=affine, track_running_stats=track_running_stats),
6 | 'mbconv_k3_t3': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=3, use_se=use_se, affine=affine, track_running_stats=track_running_stats),
7 | 'mbconv_k3_t6': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 3, stride, 1, t=6, use_se=use_se, affine=affine, track_running_stats=track_running_stats),
8 | 'mbconv_k5_t1': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 5, stride, 2, t=1, use_se=use_se, affine=affine, track_running_stats=track_running_stats),
9 | 'mbconv_k5_t3': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 5, stride, 2, t=3, use_se=use_se, affine=affine, track_running_stats=track_running_stats),
10 | 'mbconv_k5_t6': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 5, stride, 2, t=6, use_se=use_se, affine=affine, track_running_stats=track_running_stats),
11 | 'mbconv_k7_t1': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 7, stride, 3, t=1, use_se=use_se, affine=affine, track_running_stats=track_running_stats),
12 | 'mbconv_k7_t3': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 7, stride, 3, t=3, use_se=use_se, affine=affine, track_running_stats=track_running_stats),
13 | 'mbconv_k7_t6': lambda C_in, C_out, stride, use_se, affine, track_running_stats: MBConv(C_in, C_out, 7, stride, 3, t=6, use_se=use_se, affine=affine, track_running_stats=track_running_stats),
14 | 'basic_block': lambda C_in, C_out, stride, use_se, affine, track_running_stats: BasicBlock(C_in, C_out, stride, affine=affine, track_running_stats=track_running_stats),
15 | 'bottle_neck': lambda C_in, C_out, stride, use_se, affine, track_running_stats: Bottleneck(C_in, C_out, stride, affine=affine, track_running_stats=track_running_stats),
16 | 'skip_connect': lambda C_in, C_out, stride, use_se, affine, track_running_stats: Skip(C_in, C_out, stride, affine=affine, track_running_stats=track_running_stats),
17 | }
18 |
19 |
20 | class MBConv(nn.Module):
21 | def __init__(self, C_in, C_out, kernel_size, stride, padding, t=3, affine=True,
22 | track_running_stats=True, use_se=False):
23 | super(MBConv, self).__init__()
24 | if t > 1:
25 | C_hidden = C_in*t
26 | self._expand_conv = nn.Sequential(
27 | nn.Conv2d(C_in, C_hidden, kernel_size=1, stride=1, padding=0, groups=1, bias=False),
28 | nn.BatchNorm2d(C_hidden, affine=affine, track_running_stats=track_running_stats),
29 | nn.ReLU6(inplace=True))
30 |
31 | self._depthwise_conv = nn.Sequential(
32 | nn.Conv2d(C_hidden, C_hidden, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_hidden, bias=False),
33 | nn.BatchNorm2d(C_hidden, affine=affine, track_running_stats=track_running_stats),
34 | nn.ReLU6(inplace=True))
35 |
36 | self._project_conv = nn.Sequential(
37 | nn.Conv2d(C_hidden, C_out, kernel_size=1, stride=1, padding=0, groups=1, bias=False),
38 | nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats))
39 | else:
40 | C_hidden = C_in
41 | self._expand_conv = None
42 |
43 | self._depthwise_conv = nn.Sequential(
44 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
45 | nn.BatchNorm2d(C_in, affine=affine, track_running_stats=track_running_stats),
46 | nn.ReLU6(inplace=True))
47 |
48 | self._project_conv = nn.Sequential(
49 | nn.Conv2d(C_in, C_out, 1, 1, 0, bias=False),
50 | nn.BatchNorm2d(C_out))
51 |
52 | self.se = SELayer(C_hidden) if use_se else None
53 |
54 | def forward(self, x):
55 | input_data = x
56 | if self._expand_conv is not None:
57 | x = self._expand_conv(x)
58 | x = self._depthwise_conv(x)
59 | if self.se is not None:
60 | x = self.se(x)
61 | out_data = self._project_conv(x)
62 |
63 | if out_data.shape == input_data.shape:
64 | return out_data + input_data
65 | else:
66 | return out_data
67 |
68 |
69 | class h_sigmoid(nn.Module):
70 | def __init__(self, inplace=True):
71 | super(h_sigmoid, self).__init__()
72 | self.relu = nn.ReLU6(inplace=inplace)
73 |
74 | def forward(self, x):
75 | return self.relu(x + 3) / 6
76 |
77 | def _make_divisible(v, divisor, min_value=None):
78 | """
79 | This function is taken from the original tf repo.
80 | It ensures that all layers have a channel number that is divisible by 8
81 | It can be seen here:
82 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
83 | :param v:
84 | :param divisor:
85 | :param min_value:
86 | :return:
87 | """
88 | if min_value is None:
89 | min_value = divisor
90 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
91 | # Make sure that round down does not go down by more than 10%.
92 | if new_v < 0.9 * v:
93 | new_v += divisor
94 | return new_v
95 |
96 | class SELayer(nn.Module):
97 | def __init__(self, channel, reduction=4):
98 | super(SELayer, self).__init__()
99 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
100 | self.fc = nn.Sequential(
101 | nn.Linear(channel, _make_divisible(channel // reduction, 8)),
102 | nn.ReLU(inplace=True),
103 | nn.Linear(_make_divisible(channel // reduction, 8), channel),
104 | h_sigmoid()
105 | )
106 |
107 | def forward(self, x):
108 | b, c, _, _ = x.size()
109 | y = self.avg_pool(x).view(b, c)
110 | y = self.fc(y).view(b, c, 1, 1)
111 | return x * y
112 |
113 |
114 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
115 | """3x3 convolution with padding"""
116 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
117 | padding=dilation, groups=groups, bias=False, dilation=dilation)
118 |
119 | def conv1x1(in_planes, out_planes, stride=1):
120 | """1x1 convolution"""
121 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
122 |
123 | class BasicBlock(nn.Module):
124 | def __init__(self, inplanes, planes, stride=1, groups=1,
125 | base_width=64, dilation=1, norm_layer=None,
126 | affine=True, track_running_stats=True):
127 | super(BasicBlock, self).__init__()
128 | if norm_layer is None:
129 | norm_layer = nn.BatchNorm2d
130 | if groups != 1 or base_width != 64:
131 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
132 | if dilation > 1:
133 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
134 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
135 | self.conv1 = conv3x3(inplanes, planes, stride)
136 | self.bn1 = norm_layer(planes, affine=affine, track_running_stats=track_running_stats)
137 | self.relu = nn.ReLU(inplace=True)
138 | self.conv2 = conv3x3(planes, planes)
139 | self.bn2 = norm_layer(planes, affine=affine, track_running_stats=track_running_stats)
140 | self.downsample = None
141 | if stride != 1 or inplanes != planes:
142 | self.downsample = nn.Sequential(
143 | conv1x1(inplanes, planes, stride),
144 | norm_layer(planes, affine=affine, track_running_stats=track_running_stats),
145 | )
146 |
147 | def forward(self, x):
148 | identity = x
149 | out = self.conv1(x)
150 | out = self.bn1(out)
151 | out = self.relu(out)
152 | out = self.conv2(out)
153 | out = self.bn2(out)
154 |
155 | if self.downsample is not None:
156 | identity = self.downsample(x)
157 | out += identity
158 | out = self.relu(out)
159 |
160 | return out
161 |
162 |
163 | class Bottleneck(nn.Module):
164 | def __init__(self, inplanes, planes, stride=1, affine=True, track_running_stats=True):
165 | super(Bottleneck, self).__init__()
166 |         if inplanes != 32:
167 |             inplanes *= 4  # any block past the 32-ch stem receives planes*4 channels from the preceding bottleneck
168 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
169 | self.bn1 = nn.BatchNorm2d(planes)
170 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
171 | padding=1, bias=False)
172 | self.bn2 = nn.BatchNorm2d(planes)
173 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
174 | self.bn3 = nn.BatchNorm2d(planes * 4)
175 | self.relu = nn.ReLU(inplace=True)
176 | self.stride = stride
177 | self.downsample = None
178 | if stride != 1 or inplanes != planes*4:
179 | self.downsample = nn.Sequential(
180 | conv1x1(inplanes, planes * 4, stride),
181 | nn.BatchNorm2d(planes * 4, affine=affine, track_running_stats=track_running_stats),
182 | )
183 |
184 | def forward(self, x):
185 | residual = x
186 |
187 | out = self.conv1(x)
188 | out = self.bn1(out)
189 | out = self.relu(out)
190 |
191 | out = self.conv2(out)
192 | out = self.bn2(out)
193 | out = self.relu(out)
194 |
195 | out = self.conv3(out)
196 | out = self.bn3(out)
197 |
198 | if self.downsample is not None:
199 | residual = self.downsample(x)
200 |
201 | out += residual
202 | out = self.relu(out)
203 |
204 | return out
205 |
206 |
207 | class Skip(nn.Module):
208 | def __init__(self, C_in, C_out, stride, affine=True, track_running_stats=True):
209 | super(Skip, self).__init__()
210 |         if C_in != C_out:
211 |             # the 1x1 projection absorbs both the channel change and the
212 |             # stride, so the trailing identity always runs at stride 1
213 |             skip_conv = nn.Sequential(
214 |                 nn.Conv2d(C_in, C_out, kernel_size=1, stride=stride, padding=0, groups=1, bias=False),
215 |                 nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats))
216 |             self.op = nn.Sequential(skip_conv, Identity(1))
217 |         else:
218 |             self.op = Identity(stride)
219 |
220 | def forward(self,x):
221 | return self.op(x)
222 |
223 | class Identity(nn.Module):
224 | def __init__(self, stride):
225 | super(Identity, self).__init__()
226 | self.stride = stride
227 |
228 | def forward(self, x):
229 | if self.stride == 1:
230 | return x
231 | else:
232 | return x[:, :, ::self.stride, ::self.stride]
233 |
--------------------------------------------------------------------------------
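
Every entry in the OPS dict above is a factory with the same positional contract, (C_in, C_out, stride, use_se, affine, track_running_stats), which is what lets Block look operations up purely by name. A quick shape check, assuming the repository root is importable:

    import torch
    from models.operations import OPS

    op = OPS['mbconv_k5_t6'](32, 64, 2, use_se=True, affine=True,
                             track_running_stats=True)
    x = torch.randn(1, 32, 56, 56)
    print(op(x).shape)  # torch.Size([1, 64, 28, 28]): k=5, pad=2, stride=2
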
/run_apis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/dali-pytorch/7bd5d2380d210a32d24c7309da69c8d2c5db8759/run_apis/__init__.py
--------------------------------------------------------------------------------
/run_apis/train_dist.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import ast
3 | import logging
4 | import os
5 | import sys; sys.path.append(os.path.join(sys.path[0], '..'))
6 | import time
7 |
8 | import model_zoo
9 | import numpy as np
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | import torch.distributed as dist
13 | import torch.nn as nn
14 | from dataset import imagenet_data_dali
15 | from mmcv import Config
16 | from models import model_derived
17 | from tensorboardX import SummaryWriter
18 | from tools import env, utils
19 | from tools.lr_scheduler import get_lr_scheduler
20 | from tools.multadds_count import comp_multadds
21 |
22 | from trainer import Trainer
23 |
24 |
25 |
26 | if __name__ == '__main__':
27 | parser = argparse.ArgumentParser("Train_Params")
28 | parser.add_argument('--report_freq', type=float, default=500, help='report frequency')
29 | parser.add_argument('--data_path', type=str, default='../data', help='location of the data corpus')
30 | parser.add_argument('--load_path', type=str, default='./model_path', help='model loading path')
31 | parser.add_argument('--save', type=str, default='../', help='experiment name')
32 | parser.add_argument('--tb_path', type=str, default='', help='tensorboard output path')
33 | parser.add_argument('--meas_lat', type=ast.literal_eval, default='False', help='whether to measure the latency of the model')
34 | parser.add_argument('--job_name', type=str, default='', help='job_name')
35 | parser.add_argument('--port', type=int, default=23333, help='dist port')
36 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher')
37 | parser.add_argument('--local_rank', default=0, type=int)
38 | parser.add_argument('--evaluation', type=ast.literal_eval, default='False', help='first evaluation')
39 | parser.add_argument('--config', type=str, default='', help='the file of the config')
40 | args = parser.parse_args()
41 |
42 | config = Config.fromfile(os.path.join('configs/train_cfg', args.config))
43 | if config.net_config:
44 | net_config = config.pop('net_config')
45 |
46 | # init distributed env first, since logger depends on the dist info.
47 |     if args.launcher == 'none':
48 |         distributed = False
49 |         raise EnvironmentError('non-distributed training is not supported; use --launcher pytorch/slurm/mpi')
50 | else:
51 | distributed = True
52 | os.environ['MASTER_ADDR'] = 'localhost'
53 | os.environ['MASTER_PORT'] = '%d' % args.port
54 | env.init_dist(args.launcher)
55 | local_rank = dist.get_rank()
56 | world_size = dist.get_world_size()
57 |
58 | if args.job_name != '':
59 | args.job_name = time.strftime("%Y%m%d-%H%M%S-") + args.job_name
60 | args.save = os.path.join(args.save, args.job_name)
61 | if local_rank == 0:
62 | utils.create_exp_dir(args.save)
63 | os.system('cp -r ./* '+args.save)
64 | else:
65 | args.save = os.path.join(args.save, 'output')
66 | if local_rank == 0:
67 | utils.create_exp_dir(args.save)
68 |
69 | if args.tb_path == '':
70 | args.tb_path = args.save
71 |
72 | env.get_root_logger(log_dir=args.save)
73 | cudnn.benchmark = True
74 | cudnn.enabled = True
75 |
76 | if config.train_params.use_seed:
77 | utils.set_seed(config.train_params.seed)
78 |
79 | logging.info("args = %s", args)
80 | logging.info('Training with config:')
81 | logging.info(config.pretty_text)
82 | writer = SummaryWriter(args.tb_path)
83 |
84 | if config.model_zoo.use_model_zoo:
85 | model = getattr(model_zoo, config.model_zoo.model_name)(
86 | **config.model_zoo.cfg if config.model_zoo.cfg else {})
87 | else:
88 | if os.path.isfile(os.path.join(args.load_path, 'net_config')):
89 | net_config, config.net_type = utils.load_net_config(
90 | os.path.join(args.load_path, 'net_config'))
91 | derivedNetwork = getattr(model_derived, '%s_Net' % config.net_type.upper())
92 | model = derivedNetwork(net_config, config=config)
93 |
94 | model.eval()
95 | if hasattr(model, 'net_config'):
96 | logging.info("Network Structure: \n" + '|\n'.join(map(str, model.net_config)))
97 | if args.meas_lat:
98 | latency_cpu = utils.latency_measure(model, (3, 224, 224), 1, 2000, mode='cpu')
99 | logging.info('latency_cpu (batch 1): %.2fms' % latency_cpu)
100 | latency_gpu = utils.latency_measure(model, (3, 224, 224), 32, 1000, mode='gpu')
101 | logging.info('latency_gpu (batch 32): %.2fms' % latency_gpu)
102 |
103 | params = utils.count_parameters_in_MB(model)
104 | logging.info("Params = %.2fMB" % params)
105 | mult_adds = comp_multadds(model, input_size=config.data.input_size)
106 | logging.info("Mult-Adds = %.2fMB" % mult_adds)
107 |
108 | model.cuda(local_rank)
109 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
110 |
111 | # whether to resume from a checkpoint
112 | if config.optim.if_resume:
113 | utils.load_model(model, config.optim.resume.load_path, distributed)
114 | start_epoch = config.optim.resume.load_epoch + 1
115 | else:
116 | start_epoch = 0
117 |
118 | if config.optim.label_smooth:
119 | criterion = utils.cross_entropy_with_label_smoothing
120 | else:
121 | criterion = nn.CrossEntropyLoss()
122 | criterion = criterion.cuda()
123 |
124 | if config.optim.bn_wo_wd:
125 | parameters = utils.add_weight_decay(model, config.optim.weight_decay)
126 | else:
127 | parameters = model.parameters()
128 | optimizer = torch.optim.SGD(
129 | parameters,
130 | config.optim.init_lr,
131 | momentum=config.optim.momentum,
132 | weight_decay=config.optim.weight_decay
133 | )
134 |
135 | train_loader, val_loader = imagenet_data_dali.get_data_loader(
136 | config.data, world_size, local_rank, args.data_path
137 | )
138 |
139 | scheduler = get_lr_scheduler(config, optimizer, train_loader._size)
140 |     scheduler.last_step = start_epoch * (train_loader._size // config.data.batch_size + 1) - 1
141 |
142 | trainer = Trainer(train_loader, val_loader, optimizer, criterion,
143 | scheduler, config, args.report_freq, distributed)
144 |
145 | best_epoch = [0, 0, 0] # [epoch, acc_top1, acc_top5]
146 | if args.evaluation:
147 | with torch.no_grad():
148 | val_acc_top1, val_acc_top5, batch_time, data_time = trainer.infer(model, start_epoch-1)
149 | if val_acc_top1 > best_epoch[1]:
150 | best_epoch = [start_epoch-1, val_acc_top1, val_acc_top5]
151 | logging.info('BEST EPOCH %d val_top1 %.2f val_top5 %.2f', best_epoch[0], best_epoch[1], best_epoch[2])
152 |
153 | for epoch in range(start_epoch, config.train_params.epochs):
154 | train_acc_top1, train_acc_top5, train_obj, batch_time, data_time = trainer.train(model, epoch)
155 |
156 | with torch.no_grad():
157 | val_acc_top1, val_acc_top5, batch_time, data_time = trainer.infer(model, epoch)
158 | if val_acc_top1 > best_epoch[1]:
159 | best_epoch = [epoch, val_acc_top1, val_acc_top5]
160 | if local_rank==0:
161 | utils.save(model, os.path.join(args.save, 'weights.pt'))
162 | logging.info('BEST EPOCH %d val_top1 %.2f val_top5 %.2f', best_epoch[0], best_epoch[1], best_epoch[2])
163 |
164 | if local_rank == 0:
165 | writer.add_scalar('train_acc_top1', train_acc_top1, epoch)
166 | writer.add_scalar('train_loss', train_obj, epoch)
167 | writer.add_scalar('val_acc_top1', val_acc_top1, epoch)
168 |
169 | if hasattr(model.module, 'net_config'):
170 | logging.info("Network Structure: \n" + '|\n'.join(map(str, model.module.net_config)))
171 | logging.info("Params = %.2fMB" % params)
172 | logging.info("Mult-Adds = %.2fMB" % mult_adds)
173 |
--------------------------------------------------------------------------------
/run_apis/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | from tools import env, utils
7 |
8 |
9 | class Trainer(object):
10 | def __init__(self, train_data, val_data, optimizer=None, criterion=None,
11 | scheduler=None, config=None, report_freq=None, distributed=False):
12 | self.train_data = train_data
13 | self.val_data = val_data
14 | self.optimizer = optimizer
15 | self.criterion = criterion
16 | self.scheduler = scheduler
17 | self.config = config
18 | self.report_freq = report_freq
19 | self.distributed = distributed
20 |
21 | def train(self, model, epoch):
22 | objs = utils.AverageMeter()
23 | top1 = utils.AverageMeter()
24 | top5 = utils.AverageMeter()
25 | data_time = utils.AverageMeter()
26 | batch_time = utils.AverageMeter()
27 | model.train()
28 |
29 | start = time.time()
30 | try:
31 | data = next(self.train_data)
32 | input = data[0]["data"]
33 | target = data[0]["label"].squeeze().cuda().long()
34 | except StopIteration:
35 | input, target = None, None
36 | step = 0
37 | while input is not None:
38 | data_t = time.time() - start
39 | self.scheduler.step()
40 | n = input.size(0)
41 | if step==0:
42 | logging.info('epoch %d lr %e', epoch, self.optimizer.param_groups[0]['lr'])
43 | self.optimizer.zero_grad()
44 |
45 |             logits = model(input)
46 | if self.config.optim.label_smooth:
47 | loss = self.criterion(logits, target, self.config.optim.smooth_alpha)
48 | else:
49 | loss = self.criterion(logits, target)
50 |
51 | loss.backward()
52 | if self.config.optim.use_grad_clip:
53 | nn.utils.clip_grad_norm_(model.parameters(), self.config.optim.grad_clip)
54 | self.optimizer.step()
55 |
56 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
57 |
58 | if self.distributed:
59 | loss = env.reduce_tensor(loss.data)
60 | prec1 = env.reduce_tensor(prec1)
61 | prec5 = env.reduce_tensor(prec5)
62 | torch.cuda.synchronize()
63 |
64 | batch_t = time.time() - start
65 | start = time.time()
66 | objs.update(loss.item(), n)
67 | top1.update(prec1.item(), n)
68 | top5.update(prec5.item(), n)
69 | data_time.update(data_t)
70 | batch_time.update(batch_t)
71 | if step!=0 and step % self.report_freq == 0:
72 | logging.info(
73 | 'Train epoch %03d step %03d | loss %.4f top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
74 | epoch, step, objs.avg, top1.avg, top5.avg, batch_time.avg, data_time.avg)
75 |
76 | try:
77 | data = next(self.train_data)
78 | input = data[0]["data"]
79 | target = data[0]["label"].squeeze().cuda().long()
80 | except StopIteration:
81 | input, target = None, None
82 | step += 1
83 | logging.info('EPOCH%d Train_acc top1 %.2f top5 %.2f batch_time %.3f data_time %.3f',
84 | epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg)
85 |
86 | return top1.avg, top5.avg, objs.avg, batch_time.avg, data_time.avg
87 |
88 |
89 | def infer(self, model, epoch=0):
90 | top1 = utils.AverageMeter()
91 | top5 = utils.AverageMeter()
92 | data_time = utils.AverageMeter()
93 | batch_time = utils.AverageMeter()
94 | model.eval()
95 |
96 | start = time.time()
97 | try:
98 | data = next(self.val_data)
99 | input = data[0]["data"]
100 | target = data[0]["label"].squeeze().cuda().long()
101 | except StopIteration:
102 | input, target = None, None
103 | step = 0
104 | while input is not None:
105 | step += 1
106 | data_t = time.time() - start
107 | n = input.size(0)
108 | logits = model(input)
109 |
110 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
111 | if self.distributed:
112 | prec1 = env.reduce_tensor(prec1)
113 | prec5 = env.reduce_tensor(prec5)
114 | torch.cuda.synchronize()
115 | batch_t = time.time() - start
116 | start = time.time()
117 | top1.update(prec1.item(), n)
118 | top5.update(prec5.item(), n)
119 | data_time.update(data_t)
120 | batch_time.update(batch_t)
121 |
122 | if step % self.report_freq == 0:
123 | logging.info(
124 | 'Val epoch %03d step %03d | top1_acc %.2f top5_acc %.2f | batch_time %.3f data_time %.3f',
125 | epoch, step, top1.avg, top5.avg, batch_time.avg, data_time.avg)
126 | try:
127 | data = next(self.val_data)
128 | input = data[0]["data"]
129 | target = data[0]["label"].squeeze().cuda().long()
130 | except StopIteration:
131 | input, target = None, None
132 |
133 | logging.info('EPOCH%d Valid_acc top1 %.2f top5 %.2f batch_time %.3f data_time %.3f',
134 | epoch, top1.avg, top5.avg, batch_time.avg, data_time.avg)
135 | return top1.avg, top5.avg, batch_time.avg, data_time.avg
136 |
137 |
--------------------------------------------------------------------------------
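
Both train() and infer() above drive the loader manually with next(...) and index batches as data[0]["data"] / data[0]["label"], which is the DALI plugin iterator convention (see dataset/imagenet_data_dali.py). A toy CPU stand-in, purely an illustration rather than the real pipeline, that satisfies the same protocol:

    import torch

    class FakeDaliLoader:
        """Mimics the pieces of a DALI iterator that Trainer touches."""
        def __init__(self, n_batches, batch_size=8, num_classes=10):
            self._size = n_batches * batch_size  # train_dist.py reads loader._size
            self.n, self.bs, self.nc = n_batches, batch_size, num_classes
            self.i = 0

        def __next__(self):
            if self.i == self.n:
                raise StopIteration  # Trainer ends the epoch on this
            self.i += 1
            return [{"data": torch.randn(self.bs, 3, 224, 224),
                     "label": torch.randint(0, self.nc, (self.bs, 1))}]
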
/scripts/train_dist.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | python -m torch.distributed.launch --nproc_per_node=$1 run_apis/train_dist.py \
4 | --launcher pytorch \
5 | --report_freq 400 \
6 | --data_path $2 \
7 | --port 23333 \
8 | --config $3
9 |
--------------------------------------------------------------------------------
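
For example, with a placeholder data path, an 8-GPU run of the MobileNetV3 config would be launched as

    sh scripts/train_dist.sh 8 /path/to/imagenet imagenet_train_mbv3_cfg.py

where the three positional arguments feed --nproc_per_node, --data_path, and --config in that order (the config name is resolved relative to configs/train_cfg by train_dist.py).
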
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JaminFong/dali-pytorch/7bd5d2380d210a32d24c7309da69c8d2c5db8759/tools/__init__.py
--------------------------------------------------------------------------------
/tools/env.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import os.path as osp
4 | import random
5 |
6 | import numpy as np
7 | import torch
8 | import torch.distributed as dist
9 | import torch.multiprocessing as mp
10 |
11 |
12 | def init_dist(launcher, backend='nccl', **kwargs):
13 | # if mp.get_start_method(allow_none=True) is None:
14 | # mp.set_start_method('spawn')
15 | if launcher == 'pytorch':
16 | _init_dist_pytorch(backend, **kwargs)
17 | elif launcher == 'mpi':
18 | _init_dist_mpi(backend, **kwargs)
19 | elif launcher == 'slurm':
20 | _init_dist_slurm(backend, **kwargs)
21 | else:
22 | raise ValueError('Invalid launcher type: {}'.format(launcher))
23 |
24 |
25 | def _init_dist_pytorch(backend, **kwargs):
26 | rank = int(os.environ['RANK'])
27 | num_gpus = torch.cuda.device_count()
28 | torch.cuda.set_device(rank % num_gpus)
29 | dist.init_process_group(backend=backend, **kwargs)
30 |
31 |
32 | def _init_dist_mpi(backend, **kwargs):
33 | raise NotImplementedError
34 |
35 |
36 | def _init_dist_slurm(backend, **kwargs):
37 | raise NotImplementedError
38 |
39 |
40 | def set_random_seed(seed):
41 | random.seed(seed)
42 | np.random.seed(seed)
43 | torch.manual_seed(seed)
44 | torch.cuda.manual_seed_all(seed)
45 |
46 |
47 | def get_root_logger(log_dir=None, log_level=logging.INFO):
48 | logger = logging.getLogger()
49 | if not logger.hasHandlers():
50 | logging.basicConfig(
51 | format='%(asctime)s - %(message)s',
52 | level=log_level,
53 | datefmt='%m/%d %I:%M:%S %p')
54 | rank = dist.get_rank()
55 | if rank != 0:
56 | logger.setLevel('ERROR')
57 |
58 | if log_dir and rank == 0:
59 | filename = 'train_dist.log'
60 | log_file = osp.join(log_dir, filename)
61 | _add_file_handler(logger, log_file, level=log_level)
62 | return logger
63 |
64 | def _add_file_handler(logger,
65 | filename=None,
66 | mode='w',
67 | level=logging.INFO):
68 | file_handler = logging.FileHandler(filename, mode)
69 | file_handler.setFormatter(
70 | logging.Formatter('%(asctime)s - %(message)s'))
71 | file_handler.setLevel(level)
72 | logger.addHandler(file_handler)
73 | return logger
74 |
75 | def reduce_tensor(tensor):
76 | rt = tensor.clone()
77 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
78 | rt /= dist.get_world_size()
79 | return rt
--------------------------------------------------------------------------------
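
reduce_tensor above turns a per-rank scalar into the cross-rank mean (all_reduce with SUM, then divide by world size); the Trainer uses it so logged loss and accuracy reflect all GPUs. A single-process sketch with the CPU gloo backend, just to make the call sequence concrete:

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    loss = torch.tensor(1.05)
    rt = loss.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sums the value across ranks
    rt /= dist.get_world_size()                # ...then averages it
    print(rt)  # a no-op with world_size=1; the cross-rank mean otherwise
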
/tools/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import CosineAnnealingLR
3 | from torch.optim.optimizer import Optimizer
4 | import math
5 |
6 | class CosineRestartAnnealingLR(object):
7 |     # decays per optimization step (not per epoch)
8 |     # T_max refers to the max update step
9 |
10 | def __init__(self, optimizer, T_max, lr_period, lr_step, eta_min=0, last_step=-1,
11 | use_warmup=False, warmup_mode='linear', warmup_steps=0, warmup_startlr=0,
12 | warmup_targetlr=0, use_restart=False):
13 |
14 | self.use_warmup = use_warmup
15 | self.warmup_mode = warmup_mode
16 | self.warmup_steps = warmup_steps
17 | self.warmup_startlr = warmup_startlr
18 | self.warmup_targetlr = warmup_targetlr
19 | self.use_restart = use_restart
20 | self.T_max = T_max
21 | self.eta_min = eta_min
22 |
23 | if self.use_restart == False:
24 | self.lr_period = [self.T_max - self.warmup_steps]
25 | self.lr_step = [self.warmup_steps]
26 | else:
27 | self.lr_period = lr_period
28 | self.lr_step = lr_step
29 |
30 | self.last_step = last_step
31 | self.cycle_length = self.lr_period[0]
32 | self.cur = 0
33 |
34 | if not isinstance(optimizer, Optimizer):
35 | raise TypeError('{} is not an Optimizer'.format(
36 | type(optimizer).__name__))
37 | self.optimizer = optimizer
38 | if last_step == -1:
39 | for group in optimizer.param_groups:
40 | group.setdefault('initial_lr', group['lr'])
41 | else:
42 | for i, group in enumerate(optimizer.param_groups):
43 | if 'initial_lr' not in group:
44 | raise KeyError("param 'initial_lr' is not specified "
45 | "in param_groups[{}] when resuming an optimizer".format(i))
46 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
47 |
48 |
49 | def step(self, step=None):
50 |
51 | if step is not None:
52 | self.last_step = step
53 | else:
54 | self.last_step += 1
55 |
56 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
57 | param_group['lr'] = lr
58 |
59 |
60 | def get_lr(self):
61 |
62 | lrs = []
63 | for base_lr in self.base_lrs:
64 | if self.use_warmup and self.last_step < self.warmup_steps:
65 | if self.warmup_mode == 'constant':
66 | lrs.append(self.warmup_startlr)
67 | elif self.warmup_mode =='linear':
68 | cur_lr = self.warmup_startlr + \
69 | float(self.warmup_targetlr-self.warmup_startlr)/self.warmup_steps*self.last_step
70 | lrs.append(cur_lr)
71 | else:
72 | raise NotImplementedError
73 |
74 | else:
75 | if (self.last_step) in self.lr_step:
76 | self.cycle_length = self.lr_period[self.lr_step.index(self.last_step)]
77 | self.cur = self.last_step
78 |
79 | peri_iter = self.last_step-self.cur
80 |
81 | if peri_iter <= self.cycle_length:
82 | unit_cycle = (1 + math.cos(peri_iter * math.pi / self.cycle_length)) / 2
83 | adjusted_cycle = unit_cycle * (base_lr - self.eta_min) + self.eta_min
84 | lrs.append(adjusted_cycle)
85 | else:
86 | lrs.append(self.eta_min)
87 |
88 | return lrs
89 |
90 |
91 | def display_lr_curve(self, total_steps):
92 | lrs = []
93 | for _ in range(total_steps):
94 | self.step()
95 | lrs.append(self.get_lr()[0])
96 | import matplotlib.pyplot as plt
97 | plt.plot(lrs)
98 | plt.show()
99 |
100 |
101 | def get_lr_scheduler(config, optimizer, num_examples=None):
102 |
103 | if num_examples is None:
104 | num_examples = config.data.num_examples
105 |     epoch_steps = math.ceil(num_examples / config.data.batch_size)  # true ceiling; '//' here would floor first and make ceil a no-op
106 |
107 | if config.optim.use_multi_stage:
108 | max_steps = epoch_steps * config.optim.multi_stage.stage_epochs
109 | else:
110 | max_steps = epoch_steps * config.train_params.epochs
111 |
112 | period_steps = [epoch_steps * x for x in config.optim.cosine.restart.lr_period]
113 | step_steps = [epoch_steps * x for x in config.optim.cosine.restart.lr_step]
114 |
115 | init_lr = config.optim.init_lr
116 |
117 | use_warmup = config.optim.use_warm_up
118 | if use_warmup:
119 | warmup_steps = config.optim.warm_up.epoch * epoch_steps
120 | warmup_startlr = config.optim.warm_up.init_lr
121 | warmup_targetlr = config.optim.warm_up.target_lr
122 | else:
123 | warmup_steps = 0
124 | warmup_startlr = init_lr
125 | warmup_targetlr = init_lr
126 |
127 | if config.optim.lr_schedule == 'cosine':
128 | scheduler = CosineRestartAnnealingLR(optimizer,
129 | float(max_steps),
130 | period_steps,
131 | step_steps,
132 | eta_min=config.optim.min_lr,
133 | use_warmup=use_warmup,
134 | warmup_steps=warmup_steps,
135 | warmup_startlr=warmup_startlr,
136 | warmup_targetlr=warmup_targetlr,
137 | use_restart=config.optim.cosine.use_restart)
138 | # scheduler = CosineAnnealingLR(optimizer, config.train_params.epochs, config.optim.min_lr)
139 | elif config.optim.lr_schedule == 'poly':
140 | raise NotImplementedError
141 | else:
142 | raise NotImplementedError
143 |
144 | return scheduler
145 |
146 |
--------------------------------------------------------------------------------
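
A minimal sketch of the schedule CosineRestartAnnealingLR produces without restarts: linear warmup from warmup_startlr to warmup_targetlr over warmup_steps, then a half-cosine from the base lr down to eta_min over the remaining T_max - warmup_steps steps. The numbers below are illustrative, not taken from any config:

    import torch
    from tools.lr_scheduler import CosineRestartAnnealingLR

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.5)
    sched = CosineRestartAnnealingLR(
        opt, T_max=1000, lr_period=[], lr_step=[], eta_min=0.0,
        use_warmup=True, warmup_steps=100,
        warmup_startlr=0.0, warmup_targetlr=0.5)
    for step in (0, 50, 100, 550, 1000):
        sched.step(step)
        print(step, round(opt.param_groups[0]['lr'], 4))
    # 0 -> 0.0 and 50 -> 0.25 (warmup), 100 -> 0.5 (cosine start),
    # 550 -> 0.25 (mid-cosine), 1000 -> 0.0 (eta_min)
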
/tools/multadds_count.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # Original implementation:
3 | # https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py
4 |
5 | # ---- Public functions
6 |
7 | def comp_multadds(model, input_size=(3,224,224)):
8 | input_size = (1,) + tuple(input_size)
9 | model = model.cuda()
10 | input_data = torch.randn(input_size).cuda()
11 | model = add_flops_counting_methods(model)
12 | model.start_flops_count()
13 | with torch.no_grad():
14 | _ = model(input_data)
15 |
16 | mult_adds = model.compute_average_flops_cost() / 1e6
17 | return mult_adds
18 |
19 |
20 | def comp_multadds_fw(model, input_data, use_gpu=True):
21 | model = add_flops_counting_methods(model)
22 | if use_gpu:
23 | model = model.cuda()
24 | model.start_flops_count()
25 | with torch.no_grad():
26 | output_data = model(input_data)
27 |
28 | mult_adds = model.compute_average_flops_cost() / 1e6
29 | return mult_adds, output_data
30 |
31 |
32 | def add_flops_counting_methods(net_main_module):
33 | """Adds flops counting functions to an existing model. After that
34 | the flops count should be activated and the model should be run on an input
35 | image.
36 | Example:
37 | fcn = add_flops_counting_methods(fcn)
38 | fcn = fcn.cuda().train()
39 | fcn.start_flops_count()
40 | _ = fcn(batch)
41 | fcn.compute_average_flops_cost() / 1e9 / 2 # Result in GFLOPs per image in batch
42 | Important: dividing by 2 only works for resnet models -- see below for the details
43 | of flops computation.
44 | Attention: we are counting a multiply-add as two flops in this work, because in
45 | most resnet models convolutions are bias-free (BN layers act as the bias there),
46 | so it makes sense to count multiply and add as separate flops.
47 | This is why in the above example we divide by 2 in order to be consistent with
48 | most modern benchmarks. For example, in "Spatially Adaptive Computation Time for Residual
49 | Networks" by Figurnov et al., a multiply-add was counted as two flops.
50 | This module computes the average flops, which is necessary for dynamic networks that
51 | have a different number of executed layers per input. For static networks it is enough to
52 | run the network once and get statistics (above example).
53 | Implementation:
54 | The module works by adding batch_count to the main module which tracks the sum
55 | of all batch sizes that were run through the network.
56 | Also each convolutional layer of the network tracks the overall number of flops
57 | performed.
58 | The parameters are updated with the help of registered hook-functions which
59 | are being called each time the respective layer is executed.
60 | Parameters
61 | ----------
62 | net_main_module : torch.nn.Module
63 | Main module containing network
64 | Returns
65 | -------
66 | net_main_module : torch.nn.Module
67 | Updated main module with new methods/attributes that are used
68 | to compute flops.
69 | """
70 |
71 | # adding additional methods to the existing module object,
72 | # this is done this way so that each function has access to self object
73 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
74 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
75 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
76 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
77 |
78 | net_main_module.reset_flops_count()
79 |
80 |     # Adding variables necessary for masked flops computation
81 | net_main_module.apply(add_flops_mask_variable_or_reset)
82 |
83 | return net_main_module
84 |
85 |
86 | def compute_average_flops_cost(self):
87 | """
88 | A method that will be available after add_flops_counting_methods() is called
89 | on a desired net object.
90 | Returns current mean flops consumption per image.
91 | """
92 |
93 | batches_count = self.__batch_counter__
94 |
95 | flops_sum = 0
96 |
97 | for module in self.modules():
98 |
99 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
100 | flops_sum += module.__flops__
101 |
102 |
103 | return flops_sum / batches_count
104 |
105 |
106 | def start_flops_count(self):
107 | """
108 | A method that will be available after add_flops_counting_methods() is called
109 | on a desired net object.
110 | Activates the computation of mean flops consumption per image.
111 | Call it before you run the network.
112 | """
113 |
114 | add_batch_counter_hook_function(self)
115 |
116 | self.apply(add_flops_counter_hook_function)
117 |
118 |
119 | def stop_flops_count(self):
120 | """
121 | A method that will be available after add_flops_counting_methods() is called
122 | on a desired net object.
123 | Stops computing the mean flops consumption per image.
124 | Call whenever you want to pause the computation.
125 | """
126 |
127 | remove_batch_counter_hook_function(self)
128 |
129 | self.apply(remove_flops_counter_hook_function)
130 |
131 |
132 | def reset_flops_count(self):
133 | """
134 | A method that will be available after add_flops_counting_methods() is called
135 | on a desired net object.
136 | Resets statistics computed so far.
137 | """
138 |
139 | add_batch_counter_variables_or_reset(self)
140 |
141 | self.apply(add_flops_counter_variable_or_reset)
142 |
143 |
144 | def add_flops_mask(module, mask):
145 | def add_flops_mask_func(module):
146 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
147 | module.__mask__ = mask
148 |
149 | module.apply(add_flops_mask_func)
150 |
151 |
152 | def remove_flops_mask(module):
153 | module.apply(add_flops_mask_variable_or_reset)
154 |
155 |
156 | # ---- Internal functions
157 |
158 |
159 | def conv_flops_counter_hook(conv_module, input, output):
160 | # Can have multiple inputs, getting the first one
161 | input = input[0]
162 |
163 | batch_size = input.shape[0]
164 | output_height, output_width = output.shape[2:]
165 |
166 | kernel_height, kernel_width = conv_module.kernel_size
167 | in_channels = conv_module.in_channels
168 | out_channels = conv_module.out_channels
169 |
170 | conv_per_position_flops = (kernel_height * kernel_width * in_channels * out_channels) / conv_module.groups
171 |
172 | active_elements_count = batch_size * output_height * output_width
173 |
174 | if conv_module.__mask__ is not None:
175 | # (b, 1, h, w)
176 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width)
177 | active_elements_count = flops_mask.sum()
178 |
179 | overall_conv_flops = conv_per_position_flops * active_elements_count
180 |
181 | bias_flops = 0
182 |
183 | if conv_module.bias is not None:
184 | bias_flops = out_channels * active_elements_count
185 |
186 | overall_flops = overall_conv_flops + bias_flops
187 |
188 | conv_module.__flops__ += overall_flops
189 |
190 |
191 | def linear_flops_counter_hook(linear_module, input, output):
192 |
193 | input = input[0]
194 | batch_size = input.shape[0]
195 | overall_flops = linear_module.in_features * linear_module.out_features * batch_size
196 |
197 | # bias_flops = 0
198 |
199 | # if conv_module.bias is not None:
200 | # bias_flops = out_channels * active_elements_count
201 |
202 | # overall_flops = overall_conv_flops + bias_flops
203 |
204 | linear_module.__flops__ += overall_flops
205 |
206 |
207 | def batch_counter_hook(module, input, output):
208 | # Can have multiple inputs, getting the first one
209 | input = input[0]
210 |
211 | batch_size = input.shape[0]
212 |
213 | module.__batch_counter__ += batch_size
214 |
215 |
216 | def add_batch_counter_variables_or_reset(module):
217 | module.__batch_counter__ = 0
218 |
219 |
220 | def add_batch_counter_hook_function(module):
221 | if hasattr(module, '__batch_counter_handle__'):
222 | return
223 |
224 | handle = module.register_forward_hook(batch_counter_hook)
225 | module.__batch_counter_handle__ = handle
226 |
227 |
228 | def remove_batch_counter_hook_function(module):
229 | if hasattr(module, '__batch_counter_handle__'):
230 | module.__batch_counter_handle__.remove()
231 |
232 | del module.__batch_counter_handle__
233 |
234 |
235 | def add_flops_counter_variable_or_reset(module):
236 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
237 | module.__flops__ = 0
238 |
239 |
240 | def add_flops_counter_hook_function(module):
241 | if isinstance(module, torch.nn.Conv2d):
242 | if hasattr(module, '__flops_handle__'):
243 | return
244 |
245 | handle = module.register_forward_hook(conv_flops_counter_hook)
246 | module.__flops_handle__ = handle
247 | elif isinstance(module, torch.nn.Linear):
248 |
249 | if hasattr(module, '__flops_handle__'):
250 | return
251 |
252 | handle = module.register_forward_hook(linear_flops_counter_hook)
253 | module.__flops_handle__ = handle
254 |
255 |
256 | def remove_flops_counter_hook_function(module):
257 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
258 |
259 | if hasattr(module, '__flops_handle__'):
260 | module.__flops_handle__.remove()
261 |
262 | del module.__flops_handle__
263 |
264 |
265 | # --- Masked flops counting
266 |
267 |
268 | # Also being run in the initialization
269 | def add_flops_mask_variable_or_reset(module):
270 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
271 | module.__mask__ = None
272 |
273 |
--------------------------------------------------------------------------------
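
The per-layer count in conv_flops_counter_hook reduces to k_h * k_w * C_in * C_out / groups mult-adds per output position, times batch * H_out * W_out positions (plus C_out per position if the conv has a bias). A worked example for a single 3x3 convolution, with layer sizes chosen arbitrarily:

    k, c_in, c_out, groups = 3, 64, 128, 1
    h_out = w_out = 56
    batch = 1

    per_position = k * k * c_in * c_out / groups  # 73,728 mult-adds
    total = per_position * batch * h_out * w_out  # 231,211,008
    print(total / 1e6, 'M mult-adds')             # ~231.2 M

    # The depthwise analogue (groups = c_in = c_out = 64) costs only
    # 3 * 3 * 64 * 64 / 64 = 576 mult-adds per position.
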
/tools/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | import sys
5 | import time
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.distributed as dist
11 |
12 |
13 | class AverageMeter(object):
14 | def __init__(self):
15 | self.reset()
16 |
17 | def reset(self):
18 | self.avg = 0
19 | self.sum = 0
20 | self.cnt = 0
21 |
22 | def update(self, val, n=1):
23 | self.cur = val
24 | self.sum += val * n
25 | self.cnt += n
26 | self.avg = self.sum / self.cnt
27 |
28 |
29 | def accuracy(output, target, topk=(1, 5)):
30 | maxk = max(topk)
31 | batch_size = target.size(0)
32 |
33 | _, pred = output.topk(maxk, 1, True, True)
34 | pred = pred.t()
35 | correct = pred.eq(target.view(1, -1).expand_as(pred))
36 |
37 | res = []
38 | for k in topk:
39 |         correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape, not view: the slice of a transposed tensor is non-contiguous
40 | res.append(correct_k.mul_(100.0/batch_size))
41 | return res
42 |
43 |
44 | def count_parameters_in_MB(model):
45 |     return sum(np.prod(v.size()) for name, v in model.named_parameters() if "aux" not in name)/1e6
46 |
47 |
48 | def save_checkpoint(state, is_best, save):
49 | filename = os.path.join(save, 'checkpoint.pth.tar')
50 | torch.save(state, filename)
51 | if is_best:
52 | best_filename = os.path.join(save, 'model_best.pth.tar')
53 | shutil.copyfile(filename, best_filename)
54 |
55 |
56 | def save(model, model_path):
57 | torch.save(model.state_dict(), model_path)
58 |
59 |
60 | def load_net_config(path):
61 | with open(path, 'r') as f:
62 | net_config = ''
63 | while True:
64 | line = f.readline().strip()
65 | if 'net_type' in line:
66 | net_type = line.split(': ')[-1]
67 | break
68 | else:
69 | net_config += line
70 | return net_config, net_type
71 |
72 |
73 | def load_model(model, model_path, distributed=False):
74 | logging.info('Start loading the model from ' + model_path)
75 | if 'http' in model_path:
76 | model_addr = model_path
77 | model_path = model_path.split('/')[-1]
78 | if (not distributed) or (distributed and dist.get_rank()==0):
79 | if os.path.isfile(model_path):
80 | os.system('rm ' + model_path)
81 | os.system('wget -q ' + model_addr)
82 | if distributed:
83 | dist.barrier()
84 | model.load_state_dict(torch.load(model_path,
85 | map_location='cuda:{}'.format(dist.get_rank()) if distributed else None))
86 | logging.info('Loading the model finished!')
87 |
88 |
89 | def create_exp_dir(path):
90 |     # makedirs handles nested paths and an already-existing directory
91 |     os.makedirs(path, exist_ok=True)
92 |
93 | def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.):
94 | """
95 | Label smoothing implementation.
96 | This function is taken from https://github.com/MIT-HAN-LAB/ProxylessNAS/blob/master/proxyless_nas/utils.py
97 | """
98 |
99 |     logsoftmax = nn.LogSoftmax(dim=1)  # parameter-free, so no .cuda() needed; works for CPU or GPU inputs
100 | n_classes = pred.size(1)
101 | # convert to one-hot
102 | target = torch.unsqueeze(target, 1)
103 | soft_target = torch.zeros_like(pred)
104 | soft_target.scatter_(1, target, 1)
105 | # label smoothing
106 | soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
107 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
108 |
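# Usage sketch: with label_smoothing=0.1 and 1000 classes, the target mass is
# 0.9 + 0.1/1000 on the true class and 0.1/1000 on every other class.
#
#   pred = torch.randn(8, 1000)
#   target = torch.randint(0, 1000, (8,))
#   loss = cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1)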
109 |
110 | def parse_net_config(net_config):
111 | str_configs = net_config.split('|')
112 |     return [eval(str_config) for str_config in str_configs]  # assumes trusted input; see the literal_eval sketch below
113 |
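# Safer alternative sketch (not the repo's own code): if each block config is a
# plain Python literal, ast.literal_eval parses it without executing arbitrary code.
#
#   import ast
#   def parse_net_config_safe(net_config):
#       return [ast.literal_eval(s) for s in net_config.split('|')]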
114 |
115 | def set_seed(seed):
116 | np.random.seed(seed)
117 | torch.manual_seed(seed)
118 |     torch.cuda.manual_seed_all(seed)  # seed every visible GPU, not just the current device
119 |
120 |
121 | def set_logging(save_path, log_name='log.txt'):
122 | log_format = '%(asctime)s %(message)s'
123 | date_format = '%m/%d %H:%M:%S'
124 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
125 | format=log_format, datefmt=date_format)
126 | fh = logging.FileHandler(os.path.join(save_path, log_name))
127 | fh.setFormatter(logging.Formatter(log_format, date_format))
128 | logging.getLogger().addHandler(fh)
129 |
130 |
131 | def create_save_dir(save_path, job_name):
132 | if job_name != '':
133 | job_name = time.strftime("%Y%m%d-%H%M%S-") + job_name
134 | save_path = os.path.join(save_path, job_name)
135 | create_exp_dir(save_path)
136 |         os.system('cp -r ./* '+save_path)  # snapshot the current code into the experiment dir
137 | save_path = os.path.join(save_path, 'output')
138 | create_exp_dir(save_path)
139 | else:
140 | save_path = os.path.join(save_path, 'output')
141 | create_exp_dir(save_path)
142 | return save_path, job_name
143 |
144 |
145 | def latency_measure(module, input_size, batch_size, meas_times, mode='gpu'):
146 |     assert mode in ['gpu', 'cpu']
147 |     assert meas_times > 100, 'the first 100 iterations are treated as warm-up'
148 |     latency = []
149 |     module.eval()
150 |     input_size = (batch_size,) + tuple(input_size)
151 |     input_data = torch.randn(input_size)
152 |     if mode=='gpu':
153 |         input_data = input_data.cuda()
154 |         module.cuda()
155 | 
156 |     for i in range(meas_times):
157 |         with torch.no_grad():
158 |             start = time.time()
159 |             _ = module(input_data)
160 |             if mode == 'gpu':
161 |                 torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
162 |             if i >= 100:  # discard warm-up iterations
163 |                 latency.append(time.time() - start)
164 |     return np.mean(latency) * 1e3
165 |
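# Usage sketch: GPU latency of a small block; meas_times must exceed the
# 100-iteration warm-up or the mean is taken over an empty list.
#
#   block = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.ReLU())
#   ms = latency_measure(block, input_size=(3, 224, 224), batch_size=1,
#                        meas_times=500, mode='gpu')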
166 |
167 | def latency_measure_fw(module, input_data, meas_times):
168 |     assert meas_times > 100, 'the first 100 iterations are treated as warm-up'
169 |     latency = []
170 |     module.eval()
171 |     for i in range(meas_times):
172 |         with torch.no_grad():
173 |             start = time.time()
174 |             output_data = module(input_data)
175 |             if input_data.is_cuda:
176 |                 torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
177 |             if i >= 100:  # discard warm-up iterations
178 |                 latency.append(time.time() - start)
179 |     return np.mean(latency) * 1e3, output_data
180 |
181 |
182 | def record_topk(k, rec_list, data, comp_attr, check_attr):  # check_attr is accepted but currently unused
183 |     def get_insert_idx(orig_list, data, comp_attr):  # binary search; rec_list is kept sorted descending
184 | start = 0
185 | end = len(orig_list)
186 | while start < end:
187 | mid = (start + end) // 2
188 | if data[comp_attr] < orig_list[mid][comp_attr]:
189 | start = mid + 1
190 | else:
191 | end = mid
192 | return start
193 |
194 | if_insert = False
195 | insert_idx = get_insert_idx(rec_list, data, comp_attr)
196 | if insert_idx < k:
197 | rec_list.insert(insert_idx, data)
198 | if_insert = True
199 | while len(rec_list) > k:
200 | rec_list.pop()
201 | return if_insert
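
# Usage sketch: keep the 2 best records by 'acc'; rec_list stays sorted in
# descending order of comp_attr.
#
#   best = []
#   record_topk(2, best, {'acc': 75.1, 'cfg': 'a'}, 'acc', 'cfg')  # -> True
#   record_topk(2, best, {'acc': 76.3, 'cfg': 'b'}, 'acc', 'cfg')  # -> True
#   record_topk(2, best, {'acc': 74.0, 'cfg': 'c'}, 'acc', 'cfg')  # -> False, below the top-2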
202 |
203 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
204 | decay = []
205 | no_decay = []
206 | for name, param in model.named_parameters():
207 | if not param.requires_grad:
208 | continue # frozen weights
209 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
210 | no_decay.append(param)
211 | else:
212 | decay.append(param)
213 | return [
214 | {'params': no_decay, 'weight_decay': 0.},
215 | {'params': decay, 'weight_decay': weight_decay}]
216 |
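# Usage sketch: given some nn.Module `model`, hand the two parameter groups to
# an optimizer so biases and 1-D parameters (e.g. BatchNorm affine weights)
# skip weight decay.
#
#   params = add_weight_decay(model, weight_decay=4e-5)
#   optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9)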
--------------------------------------------------------------------------------