├── .gitignore ├── LICENSE ├── README.md ├── benchmark ├── __init__.py ├── compute_flops.py ├── compute_madd.py ├── compute_memory.py ├── compute_speed.py ├── model_hook.py ├── reporter.py ├── stat_tree.py └── statistics.py ├── datasets.py ├── losses.py ├── lr_scheduler.py ├── networks.py ├── profile_example.py ├── test.py ├── test_example.sh ├── train.py ├── train_example.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | checkpoints*/ 132 | log.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Aber Hu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ImageNet-training 2 | 3 | PyTorch ImageNet training code with various tricks, learning-rate schedulers, distributed training, mixed-precision training, a DALI dataloader, etc. We hope this repo can help with ImageNet experiments in NAS research. 4 | 5 | ## Train 6 | ``` 7 | CUDA_VISIBLE_DEVICES=0 python -u train.py --train_root /path/to/imagenet/train_set --val_root /path/to/imagenet/val_set --train_list /path/to/imagenet/train_list --val_list /path/to/imagenet/val_list 8 | ``` 9 | 10 | Please refer to [train_example.sh](https://github.com/AberHu/ImageNet-training/blob/master/train_example.sh) for more details. 11 | 12 | ## Test 13 | ``` 14 | CUDA_VISIBLE_DEVICES=0 python -u test.py --val_root /path/to/imagenet/val_set --val_list /path/to/imagenet/val_list --weights /path/to/pretrained_weights 15 | ``` 16 | 17 | Please refer to [test_example.sh](https://github.com/AberHu/ImageNet-training/blob/master/test_example.sh) for more details. 18 | 19 | ## Model Profiling 20 | Please refer to [profile_example.py](https://github.com/AberHu/ImageNet-training/blob/master/profile_example.py) for more details. 21 | 22 | ## Tested on 23 | Python == 3.7.6
24 | pytorch == 1.5.1
25 | torchvision == 0.6.1
26 | nvidia.dali == 0.22.0
27 | cuDNN == 7.6.5
28 | apex from [this link](https://github.com/NVIDIA/apex.git) 29 | 30 | ## License 31 | This repo is released under the MIT license. Please see the [LICENSE](https://github.com/AberHu/ImageNet-training/blob/master/LICENSE) file for more information. 32 | -------------------------------------------------------------------------------- /benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | from .compute_speed import compute_speed 2 | from .compute_memory import compute_memory 3 | from .compute_madd import compute_madd 4 | from .compute_flops import compute_flops 5 | from .stat_tree import StatTree, StatNode 6 | from .model_hook import ModelHook 7 | from .statistics import ModelStat, stat 8 | from .reporter import report_format 9 | 10 | __all__ = [ 'StatTree', 'StatNode', 'ModelHook', 'ModelStat', 'stat', 'report_format' 11 | 'compute_speed', 'compute_memory', 'compute_madd', 'compute_flops'] -------------------------------------------------------------------------------- /benchmark/compute_flops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import sys 5 | sys.path.append('..') 6 | from networks import HSwish, HSigmoid, Swish, Sigmoid 7 | 8 | 9 | def compute_flops(module, inp, out): 10 | if isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.LeakyReLU)): 11 | return compute_ReLU_flops(module, inp, out), 'Activation' 12 | elif isinstance(module, nn.ELU): 13 | return compute_ELU_flops(module, inp, out), 'Activation' 14 | elif isinstance(module, Sigmoid): 15 | return compute_Sigmoid_flops(module, inp, out), 'Activation' 16 | elif isinstance(module, HSigmoid): 17 | return compute_HSigmoid_flops(module, inp, out), 'Activation' 18 | elif isinstance(module, Swish): 19 | return compute_Swish_flops(module, inp, out), 'Activation' 20 | elif isinstance(module, HSwish): 21 | return compute_HSwish_flops(module, inp, 
out), 'Activation' 22 | elif isinstance(module, nn.Conv2d): 23 | return compute_Conv2d_flops(module, inp, out), 'Conv2d' 24 | elif isinstance(module, nn.ConvTranspose2d): 25 | return compute_ConvTranspose2d_flops(module, inp, out), 'ConvTranspose2d' 26 | elif isinstance(module, nn.BatchNorm2d): 27 | return compute_BatchNorm2d_flops(module, inp, out), 'BatchNorm2d' 28 | elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)): 29 | return compute_Pool2d_flops(module, inp, out), 'Pool2d' 30 | elif isinstance(module, (nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d)): 31 | return compute_AdaptivePool2d_flops(module, inp, out), 'Pool2d' 32 | elif isinstance(module, nn.Linear): 33 | return compute_Linear_flops(module, inp, out), 'Linear' 34 | else: 35 | print("[Flops]: {} is not supported!".format(type(module).__name__)) 36 | return 0, -1 37 | pass 38 | 39 | 40 | def compute_ReLU_flops(module, inp, out): 41 | assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.LeakyReLU)) 42 | 43 | batch_size = inp.size()[0] 44 | active_elements_count = batch_size 45 | 46 | for s in inp.size()[1:]: 47 | active_elements_count *= s 48 | 49 | return active_elements_count 50 | 51 | 52 | def compute_ELU_flops(module, inp, out): 53 | assert isinstance(module, nn.ELU) 54 | 55 | batch_size = inp.size()[0] 56 | active_elements_count = batch_size 57 | 58 | for s in inp.size()[1:]: 59 | active_elements_count *= s 60 | active_elements_count *= 3 61 | 62 | return active_elements_count 63 | 64 | 65 | def compute_Sigmoid_flops(module, inp, out): 66 | assert isinstance(module, Sigmoid) 67 | 68 | batch_size = inp.size()[0] 69 | active_elements_count = batch_size 70 | 71 | for s in inp.size()[1:]: 72 | active_elements_count *= s 73 | active_elements_count *= 4 74 | 75 | return active_elements_count 76 | 77 | 78 | def compute_HSigmoid_flops(module, inp, out): 79 | assert isinstance(module, HSigmoid) 80 | 81 | batch_size = inp.size()[0] 82 | active_elements_count = batch_size 83 | 84 | for s in 
inp.size()[1:]: 85 | active_elements_count *= s 86 | active_elements_count *= (2 + 1) 87 | 88 | return active_elements_count 89 | 90 | 91 | def compute_Swish_flops(module, inp, out): 92 | assert isinstance(module, Swish) 93 | 94 | batch_size = inp.size()[0] 95 | active_elements_count = batch_size 96 | 97 | for s in inp.size()[1:]: 98 | active_elements_count *= s 99 | active_elements_count *= (1 + 4) 100 | 101 | return active_elements_count 102 | 103 | 104 | def compute_HSwish_flops(module, inp, out): 105 | assert isinstance(module, HSwish) 106 | 107 | batch_size = inp.size()[0] 108 | active_elements_count = batch_size 109 | 110 | for s in inp.size()[1:]: 111 | active_elements_count *= s 112 | active_elements_count *= (1 + 3) 113 | 114 | return active_elements_count 115 | 116 | 117 | def compute_Conv2d_flops(module, inp, out): 118 | # Can have multiple inputs, getting the first one 119 | assert isinstance(module, nn.Conv2d) 120 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 121 | 122 | batch_size = inp.size()[0] 123 | in_c = inp.size()[1] 124 | k_h, k_w = module.kernel_size 125 | out_c, out_h, out_w = out.size()[1:] 126 | groups = module.groups 127 | 128 | conv_per_position_flops = k_h * k_w * in_c * out_c // groups 129 | active_elements_count = batch_size * out_h * out_w 130 | total_conv_flops = conv_per_position_flops * active_elements_count 131 | 132 | bias_flops = 0 133 | if module.bias is not None: 134 | bias_flops = out_c * active_elements_count 135 | 136 | total_flops = total_conv_flops + bias_flops 137 | return total_flops 138 | 139 | 140 | def compute_ConvTranspose2d_flops(module, inp, out): 141 | # Can have multiple inputs, getting the first one 142 | assert isinstance(module, nn.ConvTranspose2d) 143 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 144 | 145 | batch_size = inp.size()[0] 146 | in_c = inp.size()[1] 147 | k_h, k_w = module.kernel_size 148 | out_c, out_h, out_w = out.size()[1:] 149 | groups = 
module.groups 150 | 151 | conv_per_position_flops = k_h * k_w * in_c * out_c // groups 152 | active_elements_count = batch_size * out_h * out_w 153 | total_conv_flops = conv_per_position_flops * active_elements_count 154 | 155 | bias_flops = 0 156 | if module.bias is not None: 157 | bias_flops = out_c * active_elements_count 158 | 159 | total_flops = total_conv_flops + bias_flops 160 | return total_flops 161 | 162 | 163 | def compute_BatchNorm2d_flops(module, inp, out): 164 | assert isinstance(module, nn.BatchNorm2d) 165 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 166 | 167 | bn_flops = np.prod(inp.shape) 168 | if module.affine: 169 | bn_flops *= 2 170 | 171 | return bn_flops 172 | 173 | 174 | def compute_Pool2d_flops(module, inp, out): 175 | assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d)) 176 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 177 | 178 | if isinstance(module.kernel_size, (tuple, list)): 179 | k_h, k_w = module.kernel_size 180 | else: 181 | k_h, k_w = module.kernel_size, module.kernel_size 182 | out_c, out_h, out_w = out.size()[1:] 183 | batch_size = inp.size()[0] 184 | 185 | pool_flops = batch_size * out_c * out_h * out_w * k_h * k_w 186 | 187 | return pool_flops 188 | 189 | 190 | def compute_AdaptivePool2d_flops(module, inp, out): 191 | assert isinstance(module, (nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d)) 192 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 193 | 194 | inp_c, inp_h, inp_w = inp.size()[1:] 195 | out_c, out_h, out_w = out.size()[1:] 196 | k_h = int(round(inp_h / out_h)) 197 | k_w = int(round(inp_w / out_w)) 198 | batch_size = inp.size()[0] 199 | 200 | adaptive_pool_flops = batch_size * out_c * out_h * out_w * k_h * k_w 201 | 202 | return np.prod(inp.shape) 203 | 204 | 205 | def compute_Linear_flops(module, inp, out): 206 | assert isinstance(module, nn.Linear) 207 | assert len(inp.size()) == 2 and len(out.size()) == 2 208 | 209 | batch_size = inp.size()[0] 210 
| num_in_features = inp.size()[1] 211 | num_out_features = out.size()[1] 212 | 213 | total_fc_flops = batch_size * num_in_features * num_out_features 214 | 215 | bias_flops = 0 216 | if module.bias is not None: 217 | bias_flops = batch_size * num_out_features 218 | 219 | total_flops = total_fc_flops + bias_flops 220 | return total_flops 221 | 222 | -------------------------------------------------------------------------------- /benchmark/compute_madd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import sys 5 | sys.path.append('..') 6 | from networks import HSwish, HSigmoid, Swish, Sigmoid 7 | 8 | def compute_madd(module, inp, out): 9 | if isinstance(module, (nn.ReLU, nn.ReLU6, nn.LeakyReLU, nn.PReLU)): 10 | return compute_ReLU_madd(module, inp, out) 11 | elif isinstance(module, nn.ELU): 12 | return compute_ELU_madd(module, inp, out) 13 | elif isinstance(module, Sigmoid): 14 | return compute_Sigmoid_madd(module, inp, out) 15 | elif isinstance(module, HSigmoid): 16 | return compute_HSigmoid_madd(module, inp, out) 17 | elif isinstance(module, Swish): 18 | return compute_Swish_madd(module, inp, out) 19 | elif isinstance(module, HSwish): 20 | return compute_HSwish_madd(module, inp, out) 21 | elif isinstance(module, nn.Conv2d): 22 | return compute_Conv2d_madd(module, inp, out) 23 | elif isinstance(module, nn.ConvTranspose2d): 24 | return compute_ConvTranspose2d_madd(module, inp, out) 25 | elif isinstance(module, nn.BatchNorm2d): 26 | return compute_BatchNorm2d_madd(module, inp, out) 27 | elif isinstance(module, nn.Linear): 28 | return compute_Linear_madd(module, inp, out) 29 | elif isinstance(module, nn.MaxPool2d): 30 | return compute_MaxPool2d_madd(module, inp, out) 31 | elif isinstance(module, nn.AdaptiveMaxPool2d): 32 | return compute_AdaptiveMaxPool2d_madd(module, inp, out) 33 | elif isinstance(module, nn.AvgPool2d): 34 | return compute_AvgPool2d_madd(module, 
inp, out) 35 | elif isinstance(module, nn.AdaptiveAvgPool2d): 36 | return compute_AdaptiveAvgPool2d_madd(module, inp, out) 37 | else: 38 | print("[MAdd]: {} is not supported!".format(type(module).__name__)) 39 | return 0 40 | 41 | 42 | def compute_ReLU_madd(module, inp, out): 43 | assert isinstance(module, (nn.ReLU, nn.ReLU6)) 44 | 45 | count = 1 46 | for i in inp.size()[1:]: 47 | count *= i 48 | 49 | return count 50 | 51 | 52 | def compute_ELU_madd(module, inp, out): 53 | assert isinstance(module, nn.ELU) 54 | 55 | count = 1 56 | for i in inp.size()[1:]: 57 | count *= i 58 | total_mul = count + count 59 | total_add = count 60 | 61 | return total_mul + total_add 62 | 63 | 64 | def compute_Sigmoid_madd(module, inp, out): 65 | assert isinstance(module, Sigmoid) 66 | 67 | count = 1 68 | for i in inp.size()[1:]: 69 | count *= i 70 | total_mul = count + count + count 71 | total_add = count 72 | 73 | return total_mul + total_add 74 | 75 | 76 | def compute_HSigmoid_madd(module, inp, out): 77 | assert isinstance(module, HSigmoid) 78 | 79 | count = 1 80 | for i in inp.size()[1:]: 81 | count *= i 82 | total_mul = count + (count) 83 | total_add = count 84 | 85 | return total_mul + total_add 86 | 87 | 88 | def compute_Swish_madd(module, inp, out): 89 | assert isinstance(module, Swish) 90 | 91 | count = 1 92 | for i in inp.size()[1:]: 93 | count *= i 94 | total_mul = count + (count + count + count) 95 | total_add = 0 + (count) 96 | 97 | return total_mul + total_add 98 | 99 | 100 | def compute_HSwish_madd(module, inp, out): 101 | assert isinstance(module, HSwish) 102 | 103 | count = 1 104 | for i in inp.size()[1:]: 105 | count *= i 106 | total_mul = count + (count + count) 107 | total_add = 0 + (count) 108 | 109 | return total_mul + total_add 110 | 111 | 112 | def compute_Conv2d_madd(module, inp, out): 113 | assert isinstance(module, nn.Conv2d) 114 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 115 | 116 | in_c = inp.size()[1] 117 | k_h, k_w = 
module.kernel_size 118 | out_c, out_h, out_w = out.size()[1:] 119 | groups = module.groups 120 | 121 | # ops per output element 122 | kernel_mul = k_h * k_w * (in_c // groups) 123 | kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1) 124 | 125 | kernel_mul_group = kernel_mul * out_h * out_w * (out_c // groups) 126 | kernel_add_group = kernel_add * out_h * out_w * (out_c // groups) 127 | 128 | total_mul = kernel_mul_group * groups 129 | total_add = kernel_add_group * groups 130 | 131 | return total_mul + total_add 132 | 133 | 134 | def compute_ConvTranspose2d_madd(module, inp, out): 135 | assert isinstance(module, nn.ConvTranspose2d) 136 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 137 | 138 | in_c, in_h, in_w = inp.size()[1:] 139 | k_h, k_w = module.kernel_size 140 | out_c, out_h, out_w = out.size()[1:] 141 | groups = module.groups 142 | 143 | kernel_mul = k_h * k_w * (in_c // groups) 144 | kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1) 145 | 146 | kernel_mul_group = kernel_mul * in_h * in_w * (out_c // groups) 147 | kernel_add_group = kernel_add * in_h * in_w * (out_c // groups) 148 | 149 | total_mul = kernel_mul_group * groups 150 | total_add = kernel_add_group * groups 151 | 152 | return total_mul + total_add 153 | 154 | 155 | def compute_BatchNorm2d_madd(module, inp, out): 156 | assert isinstance(module, nn.BatchNorm2d) 157 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 158 | 159 | in_c, in_h, in_w = inp.size()[1:] 160 | 161 | # 1. sub mean 162 | # 2. div standard deviation 163 | # 3. mul alpha 164 | # 4. 
add beta 165 | return 4 * in_c * in_h * in_w 166 | 167 | 168 | def compute_Linear_madd(module, inp, out): 169 | assert isinstance(module, nn.Linear) 170 | assert len(inp.size()) == 2 and len(out.size()) == 2 171 | 172 | num_in_features = inp.size()[1] 173 | num_out_features = out.size()[1] 174 | 175 | mul = num_in_features 176 | add = num_in_features - 1 + (0 if module.bias is None else 1) 177 | 178 | return num_out_features * (mul + add) 179 | 180 | 181 | def compute_MaxPool2d_madd(module, inp, out): 182 | assert isinstance(module, nn.MaxPool2d) 183 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 184 | 185 | if isinstance(module.kernel_size, (tuple, list)): 186 | k_h, k_w = module.kernel_size 187 | else: 188 | k_h, k_w = module.kernel_size, module.kernel_size 189 | out_c, out_h, out_w = out.size()[1:] 190 | 191 | return (k_h * k_w - 1) * out_h * out_w * out_c 192 | 193 | 194 | def compute_AdaptiveMaxPool2d_madd(module, inp, out): 195 | assert isinstance(module, nn.AdaptiveMaxPool2d) 196 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 197 | 198 | in_c, in_h, in_w = inp.size()[1:] 199 | out_c, out_h, out_w = out.size()[1:] 200 | k_h = int(round(in_h / out_h)) 201 | k_w = int(round(in_w / out_w)) 202 | 203 | return (k_h * k_w - 1) * out_h * out_w * out_c 204 | 205 | 206 | def compute_AvgPool2d_madd(module, inp, out): 207 | assert isinstance(module, nn.AvgPool2d) 208 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 209 | 210 | if isinstance(module.kernel_size, (tuple, list)): 211 | k_h, k_w = module.kernel_size 212 | else: 213 | k_h, k_w = module.kernel_size, module.kernel_size 214 | out_c, out_h, out_w = out.size()[1:] 215 | 216 | kernel_add = k_h * k_w - 1 217 | kernel_avg = 1 218 | 219 | return (kernel_add + kernel_avg) * out_h * out_w * out_c 220 | 221 | 222 | def compute_AdaptiveAvgPool2d_madd(module, inp, out): 223 | assert isinstance(module, nn.AdaptiveAvgPool2d) 224 | assert len(inp.size()) == 4 and 
len(inp.size()) == len(out.size()) 225 | 226 | in_c, in_h, in_w = inp.size()[1:] 227 | out_c, out_h, out_w = out.size()[1:] 228 | k_h = int(round(in_h / out_h)) 229 | k_w = int(round(in_w / out_w)) 230 | 231 | kernel_add = k_h * k_w - 1 232 | kernel_avg = 1 233 | 234 | return (kernel_add + kernel_avg) * out_h * out_w * out_c 235 | -------------------------------------------------------------------------------- /benchmark/compute_memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import sys 5 | sys.path.append('..') 6 | from networks import HSwish, HSigmoid, Swish, Sigmoid 7 | 8 | def compute_memory(module, inp, out): 9 | if isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU)): 10 | return compute_ReLU_memory(module, inp, out) 11 | elif isinstance(module, nn.PReLU): 12 | return compute_PReLU_memory(module, inp, out) 13 | elif isinstance(module, (Sigmoid, HSigmoid)): 14 | return compute_Sigmoid_memory(module, inp, out) 15 | elif isinstance(module, (Swish, HSwish)): 16 | return compute_Swish_memory(module, inp, out) 17 | elif isinstance(module, nn.Conv2d): 18 | return compute_Conv2d_memory(module, inp, out) 19 | elif isinstance(module, nn.ConvTranspose2d): 20 | return compute_ConvTranspose2d_memory(module, inp, out) 21 | elif isinstance(module, nn.BatchNorm2d): 22 | return compute_BatchNorm2d_memory(module, inp, out) 23 | elif isinstance(module, nn.Linear): 24 | return compute_Linear_memory(module, inp, out) 25 | elif isinstance(module, ( 26 | nn.AvgPool2d, nn.MaxPool2d, nn.AdaptiveAvgPool2d, 27 | nn.AdaptiveMaxPool2d)): 28 | return compute_Pool2d_memory(module, inp, out) 29 | else: 30 | print("[Memory]: {} is not supported!".format(type(module).__name__)) 31 | return 0, 0 32 | pass 33 | 34 | 35 | def num_params(module): 36 | return sum(p.numel() for p in module.parameters() if p.requires_grad) # why conditioned if p.requires_grad ??? 
37 | 38 | 39 | def compute_ReLU_memory(module, inp, out): 40 | assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU)) 41 | batch_size = inp.size()[0] 42 | mread = batch_size * inp.size()[1:].numel() 43 | mwrite = batch_size * inp.size()[1:].numel() 44 | 45 | return (mread, mwrite) 46 | 47 | 48 | def compute_PReLU_memory(module, inp, out): 49 | assert isinstance(module, (nn.PReLU)) 50 | batch_size = inp.size()[0] 51 | mread = batch_size * (inp.size()[1:].numel() + num_params(module)) 52 | mwrite = batch_size * inp.size()[1:].numel() 53 | 54 | return (mread, mwrite) 55 | 56 | 57 | def compute_Sigmoid_memory(module, inp, out): 58 | assert isinstance(module, (Sigmoid, HSigmoid)) 59 | batch_size = inp.size()[0] 60 | mread = batch_size * inp.size()[1:].numel() 61 | mwrite = batch_size * inp.size()[1:].numel() 62 | 63 | return (mread, mwrite) 64 | 65 | 66 | def compute_Swish_memory(module, inp, out): 67 | assert isinstance(module, (Swish, HSwish)) 68 | batch_size = inp.size()[0] 69 | mread = batch_size * (inp.size()[1:].numel() + inp.size()[1:].numel()) 70 | mwrite = batch_size * inp.size()[1:].numel() 71 | 72 | return (mread, mwrite) 73 | 74 | 75 | def compute_Conv2d_memory(module, inp, out): 76 | assert isinstance(module, nn.Conv2d) 77 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 78 | 79 | batch_size = inp.size()[0] 80 | in_c = inp.size()[1] 81 | out_c, out_h, out_w = out.size()[1:] 82 | 83 | # This includes weighs with bias if the module contains it. 
84 | mread = batch_size * (inp.size()[1:].numel() + num_params(module)) 85 | mwrite = batch_size * out_c * out_h * out_w 86 | return (mread, mwrite) 87 | 88 | 89 | def compute_ConvTranspose2d_memory(module, inp, out): 90 | assert isinstance(module, nn.ConvTranspose2d) 91 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 92 | 93 | batch_size = inp.size()[0] 94 | in_c = inp.size()[1] 95 | out_c, out_h, out_w = out.size()[1:] 96 | 97 | # This includes weighs with bias if the module contains it. 98 | mread = batch_size * (inp.size()[1:].numel() + num_params(module)) 99 | mwrite = batch_size * out_c * out_h * out_w 100 | return (mread, mwrite) 101 | 102 | 103 | def compute_BatchNorm2d_memory(module, inp, out): 104 | assert isinstance(module, nn.BatchNorm2d) 105 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 106 | batch_size, in_c, in_h, in_w = inp.size() 107 | 108 | mread = batch_size * (inp.size()[1:].numel() + 2 * in_c) 109 | mwrite = inp.size().numel() 110 | return (mread, mwrite) 111 | 112 | 113 | def compute_Linear_memory(module, inp, out): 114 | assert isinstance(module, nn.Linear) 115 | assert len(inp.size()) == 2 and len(out.size()) == 2 116 | batch_size = inp.size()[0] 117 | mread = batch_size * (inp.size()[1:].numel() + num_params(module)) 118 | mwrite = out.size().numel() 119 | 120 | return (mread, mwrite) 121 | 122 | 123 | def compute_Pool2d_memory(module, inp, out): 124 | assert isinstance(module, ( 125 | nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d)) 126 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 127 | batch_size = inp.size()[0] 128 | mread = batch_size * inp.size()[1:].numel() 129 | mwrite = batch_size * out.size()[1:].numel() 130 | return (mread, mwrite) 131 | -------------------------------------------------------------------------------- /benchmark/compute_speed.py: -------------------------------------------------------------------------------- 1 | import os 2 
| import sys 3 | import time 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | 7 | 8 | def compute_speed(model, input_size, device='cuda:0', iteration=1000): 9 | assert isinstance(input_size, (list, tuple)) 10 | assert len(input_size) == 4 11 | os.environ['OMP_NUM_THREADS'] = '1' 12 | os.environ['MKL_NUM_THREADS'] = '1' 13 | 14 | device = torch.device(device) 15 | if 'cuda' in str(device): 16 | cudnn.enabled = True 17 | cudnn.benchmark = True 18 | torch.cuda.set_device(device) 19 | 20 | model = model.to(device) 21 | model.eval() 22 | 23 | x = torch.randn(*input_size, device=device) 24 | x.to(device) 25 | 26 | # warmup for 100 iterations 27 | for _ in range(100): 28 | model(x) 29 | 30 | print('=============Speed Testing=============') 31 | print('Device: {}'.format(str(device))) 32 | if 'cuda' in str(device): 33 | torch.cuda.synchronize() # wait for cuda to finish (cuda is asynchronous!) 34 | torch.cuda.synchronize() 35 | t_start = time.time() 36 | for _ in range(iteration): 37 | model(x) 38 | if 'cuda' in str(device): 39 | torch.cuda.synchronize() # wait for cuda to finish (cuda is asynchronous!) 
40 | torch.cuda.synchronize() 41 | elapsed_time = time.time() - t_start 42 | print('Elapsed time: [%.2fs / %diter]' % (elapsed_time, iteration)) 43 | print('Speed Time: %.2fms/iter FPS: %.2f' % ( 44 | elapsed_time / iteration * 1000, iteration * input_size[0] / elapsed_time)) 45 | 46 | 47 | -------------------------------------------------------------------------------- /benchmark/model_hook.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import OrderedDict 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .compute_madd import compute_madd 8 | from .compute_flops import compute_flops 9 | from .compute_memory import compute_memory 10 | 11 | 12 | class ModelHook(object): 13 | def __init__(self, model, input_size): 14 | assert isinstance(model, nn.Module) 15 | assert isinstance(input_size, (list, tuple)) 16 | 17 | self._model = model 18 | self._input_size = input_size 19 | self._origin_call = dict() # sub module call hook 20 | 21 | self._hook_model() 22 | x = torch.rand(*self._input_size) # add module duration time 23 | self._model.eval() 24 | self._model(x) 25 | 26 | @staticmethod 27 | def _register_buffer(module): 28 | assert isinstance(module, nn.Module) 29 | 30 | if len(list(module.children())) > 0: 31 | return 32 | 33 | module.register_buffer('input_shape', torch.zeros(3).int()) 34 | module.register_buffer('output_shape', torch.zeros(3).int()) 35 | module.register_buffer('parameter_quantity', torch.zeros(1).int()) 36 | module.register_buffer('inference_memory', torch.zeros(1).long()) 37 | module.register_buffer('MAdd', torch.zeros(1).long()) 38 | module.register_buffer('duration', torch.zeros(1).float()) 39 | module.register_buffer('ConvFlops', torch.zeros(1).long()) 40 | module.register_buffer('Flops', torch.zeros(1).long()) 41 | module.register_buffer('MemRead', torch.zeros(1).long()) 42 | module.register_buffer('MemWrite', torch.zeros(1).long()) 43 | 44 | def 
_sub_module_call_hook(self): 45 | def wrap_call(module, *input, **kwargs): 46 | assert module.__class__ in self._origin_call 47 | 48 | # Itemsize for memory 49 | itemsize = input[0].detach().numpy().itemsize 50 | 51 | # !!!!!! added by Aber Hu 52 | # Duration is not accurate, since it only runs 1 time, no warmup, no mulit runs. 53 | start = time.time() 54 | output = self._origin_call[module.__class__](module, *input, **kwargs) 55 | end = time.time() 56 | module.duration = torch.from_numpy( 57 | np.array([end - start], dtype=np.float32)) 58 | 59 | module.input_shape = torch.from_numpy( 60 | np.array(input[0].size()[1:], dtype=np.int32)) 61 | module.output_shape = torch.from_numpy( 62 | np.array(output.size()[1:], dtype=np.int32)) 63 | 64 | parameter_quantity = 0 65 | # iterate through parameters and count num params 66 | for name, p in module._parameters.items(): 67 | parameter_quantity += (0 if p is None else torch.numel(p)) 68 | module.parameter_quantity = torch.from_numpy( 69 | np.array([parameter_quantity], dtype=np.long)) 70 | 71 | inference_memory = 1 72 | for s in output.size()[1:]: 73 | inference_memory *= s 74 | # memory += parameters_number # exclude parameter memory 75 | # shown as MB unit 76 | inference_memory = inference_memory * itemsize / (1024 ** 2) 77 | module.inference_memory = torch.from_numpy( 78 | np.array([inference_memory], dtype=np.float32)) 79 | 80 | if len(input) == 1: 81 | madd = compute_madd(module, input[0], output) 82 | conv_flops = 0 83 | flops, type = compute_flops(module, input[0], output) 84 | if type == 'Conv2d': 85 | conv_flops = flops 86 | memread, memwrite = compute_memory(module, input[0], output) 87 | elif len(input) > 1: 88 | madd = compute_madd(module, input, output) 89 | conv_flops = 0 90 | flops, type = compute_flops(module, input, output) 91 | if type == 'Conv2d': 92 | conv_flops = flops 93 | memread, memwrite = compute_memory(module, input, output) 94 | else: # error 95 | madd = 0 96 | flops = 0 97 | conv_flops = 0 98 | 
memread, memwrite = [0, 0] 99 | module.MAdd = torch.from_numpy( 100 | np.array([madd], dtype=np.int64)) 101 | module.Flops = torch.from_numpy( 102 | np.array([flops], dtype=np.int64)) 103 | module.ConvFlops = torch.from_numpy( 104 | np.array([conv_flops], dtype=np.int64)) 105 | module.MemRead = torch.from_numpy( 106 | np.array([memread], dtype=np.int64)*itemsize) 107 | module.MemWrite = torch.from_numpy( 108 | np.array([memwrite], dtype=np.int64)*itemsize) 109 | 110 | return output 111 | 112 | for module in self._model.modules(): 113 | if len(list(module.children())) == 0 and module.__class__ not in self._origin_call: 114 | self._origin_call[module.__class__] = module.__class__.__call__ 115 | module.__class__.__call__ = wrap_call 116 | 117 | def _sub_module_call_unhook(self): 118 | for module in self._model.modules(): 119 | if len(list(module.children())) == 0 and module.__class__ in self._origin_call: 120 | module.__class__.__call__ = self._origin_call[module.__class__] 121 | 122 | def _hook_model(self): 123 | self._model.apply(self._register_buffer) 124 | self._sub_module_call_hook() 125 | 126 | def _unhook_model(self): 127 | self._sub_module_call_unhook() 128 | 129 | @staticmethod 130 | def _retrieve_leaf_modules(model): 131 | leaf_modules = [] 132 | for name, m in model.named_modules(): 133 | if len(list(m.children())) == 0: 134 | leaf_modules.append((name, m)) 135 | return leaf_modules 136 | 137 | def retrieve_leaf_modules(self): 138 | return OrderedDict(self._retrieve_leaf_modules(self._model)) 139 | -------------------------------------------------------------------------------- /benchmark/reporter.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | pd.set_option('display.width', 1000) 4 | pd.set_option('display.max_rows', 10000) 5 | pd.set_option('display.max_columns', 10000) 6 | 7 | 8 | def round_value(value, binary=False): 9 | divisor = 1024. if binary else 1000. 
10 | 11 | if value // divisor ** 4 > 0: 12 | return str(round(value / divisor ** 4, 2)) + 'T' 13 | elif value // divisor ** 3 > 0: 14 | return str(round(value / divisor ** 3, 2)) + 'G' 15 | elif value // divisor ** 2 > 0: 16 | return str(round(value / divisor ** 2, 2)) + 'M' 17 | elif value // divisor > 0: 18 | return str(round(value / divisor, 2)) + 'K' 19 | return str(value) 20 | 21 | 22 | def report_format(collected_nodes, brief_report=False): 23 | data = list() 24 | for node in collected_nodes: 25 | name = node.name 26 | input_shape = ' '.join(['{:>3d}'] * len(node.input_shape)).format( 27 | *[e for e in node.input_shape]) 28 | output_shape = ' '.join(['{:>3d}'] * len(node.output_shape)).format( 29 | *[e for e in node.output_shape]) 30 | parameter_quantity = node.parameter_quantity 31 | inference_memory = node.inference_memory 32 | MAdd = node.MAdd 33 | Flops = node.Flops 34 | ConvFlops = node.ConvFlops 35 | mread = node.MemRead 36 | mwrite = node.MemWrite 37 | duration = node.duration 38 | data.append([name, input_shape, output_shape, parameter_quantity, 39 | inference_memory, MAdd, duration, Flops, ConvFlops, mread, mwrite]) 40 | df = pd.DataFrame(data) 41 | df.columns = ['module name', 'input shape', 'output shape', 42 | 'params', 'memory(MB)', 'MAdd', 'duration', 'Flops', 43 | 'ConvFlops', 'MemRead(B)', 'MemWrite(B)'] 44 | df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7) 45 | df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)'] 46 | total_parameters_quantity = df['params'].sum() 47 | total_memory = df['memory(MB)'].sum() 48 | total_operation_quantity = df['MAdd'].sum() 49 | total_flops = df['Flops'].sum() 50 | total_conv_flops = df['ConvFlops'].sum() 51 | total_duration = df['duration[%]'].sum() 52 | total_mread = df['MemRead(B)'].sum() 53 | total_mwrite = df['MemWrite(B)'].sum() 54 | total_memrw = df['MemR+W(B)'].sum() 55 | del df['duration'] 56 | 57 | # Add Total row 58 | total_df = pd.Series([total_parameters_quantity, total_memory, 
59 | total_operation_quantity, total_flops, 60 | total_conv_flops, total_duration, 61 | total_mread, total_mwrite, total_memrw], 62 | index=['params', 'memory(MB)', 'MAdd', 'Flops', 63 | 'ConvFlops', 'duration[%]', 64 | 'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'], 65 | name='total') 66 | df = df.append(total_df) 67 | 68 | df = df.fillna(' ') 69 | df['params'] = df['params'].apply(lambda x: '{:,}'.format(x)) 70 | df['memory(MB)'] = df['memory(MB)'].apply(lambda x: '{:,.2f}'.format(x)) 71 | df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x)) 72 | df['MAdd'] = df['MAdd'].apply(lambda x: '{:,}'.format(x)) 73 | df['Flops'] = df['Flops'].apply(lambda x: '{:,}'.format(x)) 74 | df['ConvFlops'] = df['ConvFlops'].apply(lambda x: '{:,}'.format(x)) 75 | df['MemRead(B)'] = df['MemRead(B)'].apply(lambda x: '{:,}'.format(x)) 76 | df['MemWrite(B)'] = df['MemWrite(B)'].apply(lambda x: '{:,}'.format(x)) 77 | df['MemR+W(B)'] = df['MemR+W(B)'].apply(lambda x: '{:,}'.format(x)) 78 | 79 | if not brief_report: 80 | summary = str(df) + '\n' 81 | summary += "=" * len(str(df).split('\n')[0]) 82 | summary += '\n' 83 | summary += "Total params: {}\n".format(round_value(total_parameters_quantity)) 84 | 85 | summary += "-" * len(str(df).split('\n')[0]) 86 | summary += '\n' 87 | summary += "Total memory: {:.2f}MB\n".format(total_memory) 88 | summary += "Total MAdd: {}MAdd\n".format(round_value(total_operation_quantity)) 89 | summary += "Total Flops: {}Flops\n".format(round_value(total_flops)) 90 | summary += "Total Flops(Conv Only): {}Flops\n".format(round_value(total_conv_flops)) 91 | summary += "Total MemR+W: {}B\n".format(round_value(total_memrw, True)) 92 | else: 93 | summary = "Total params: {}\n".format(round_value(total_parameters_quantity)) 94 | summary += "Total memory: {:.2f}MB\n".format(total_memory) 95 | summary += "Total MAdd: {}MAdd\n".format(round_value(total_operation_quantity)) 96 | summary += "Total Flops: {}Flops\n".format(round_value(total_flops)) 97 | 
summary += "Total Flops(Conv Only): {}Flops\n".format(round_value(total_conv_flops)) 98 | summary += "Total MemR+W: {}B\n".format(round_value(total_memrw, True)) 99 | 100 | return summary 101 | -------------------------------------------------------------------------------- /benchmark/stat_tree.py: -------------------------------------------------------------------------------- 1 | import queue 2 | 3 | 4 | class StatTree(object): 5 | def __init__(self, root_node): 6 | assert isinstance(root_node, StatNode) 7 | 8 | self.root_node = root_node 9 | 10 | def get_same_level_max_node_depth(self, query_node): 11 | if query_node.name == self.root_node.name: 12 | return 0 13 | same_level_depth = max([child.depth for child in query_node.parent.children]) 14 | return same_level_depth 15 | 16 | def update_stat_nodes_granularity(self): 17 | q = queue.Queue() 18 | q.put(self.root_node) 19 | while not q.empty(): 20 | node = q.get() 21 | node.granularity = self.get_same_level_max_node_depth(node) 22 | for child in node.children: 23 | q.put(child) 24 | 25 | def get_collected_stat_nodes(self, query_granularity): 26 | self.update_stat_nodes_granularity() 27 | 28 | collected_nodes = [] 29 | stack = list() 30 | stack.append(self.root_node) 31 | while len(stack) > 0: 32 | node = stack.pop() 33 | for child in reversed(node.children): 34 | stack.append(child) 35 | if node.depth == query_granularity: 36 | collected_nodes.append(node) 37 | if node.depth < query_granularity <= node.granularity: 38 | collected_nodes.append(node) 39 | return collected_nodes 40 | 41 | 42 | class StatNode(object): 43 | def __init__(self, name=str(), parent=None): 44 | self._name = name 45 | self._input_shape = None 46 | self._output_shape = None 47 | self._parameter_quantity = 0 48 | self._inference_memory = 0 49 | self._MAdd = 0 50 | self._MemRead = 0 51 | self._MemWrite = 0 52 | self._Flops = 0 53 | self._ConvFlops = 0 54 | self._duration = 0 55 | self._duration_percent = 0 56 | 57 | self._granularity = 1 58 | 
self._depth = 1 59 | self.parent = parent 60 | self.children = list() 61 | 62 | @property 63 | def name(self): 64 | return self._name 65 | 66 | @name.setter 67 | def name(self, name): 68 | self._name = name 69 | 70 | @property 71 | def granularity(self): 72 | return self._granularity 73 | 74 | @granularity.setter 75 | def granularity(self, g): 76 | self._granularity = g 77 | 78 | @property 79 | def depth(self): 80 | d = self._depth 81 | if len(self.children) > 0: 82 | d += max([child.depth for child in self.children]) 83 | return d 84 | 85 | @property 86 | def input_shape(self): 87 | if len(self.children) == 0: # leaf 88 | return self._input_shape 89 | else: 90 | return self.children[0].input_shape 91 | 92 | @input_shape.setter 93 | def input_shape(self, input_shape): 94 | assert isinstance(input_shape, (list, tuple)) 95 | self._input_shape = input_shape 96 | 97 | @property 98 | def output_shape(self): 99 | if len(self.children) == 0: # leaf 100 | return self._output_shape 101 | else: 102 | return self.children[-1].output_shape 103 | 104 | @output_shape.setter 105 | def output_shape(self, output_shape): 106 | assert isinstance(output_shape, (list, tuple)) 107 | self._output_shape = output_shape 108 | 109 | @property 110 | def parameter_quantity(self): 111 | total_parameter_quantity = self._parameter_quantity 112 | for child in self.children: 113 | total_parameter_quantity += child.parameter_quantity 114 | return total_parameter_quantity 115 | 116 | @parameter_quantity.setter 117 | def parameter_quantity(self, parameter_quantity): 118 | assert parameter_quantity >= 0 119 | self._parameter_quantity = parameter_quantity 120 | 121 | @property 122 | def inference_memory(self): 123 | total_inference_memory = self._inference_memory 124 | for child in self.children: 125 | total_inference_memory += child.inference_memory 126 | return total_inference_memory 127 | 128 | @inference_memory.setter 129 | def inference_memory(self, inference_memory): 130 | self._inference_memory = 
inference_memory 131 | 132 | @property 133 | def MAdd(self): 134 | total_MAdd = self._MAdd 135 | for child in self.children: 136 | total_MAdd += child.MAdd 137 | return total_MAdd 138 | 139 | @MAdd.setter 140 | def MAdd(self, MAdd): 141 | self._MAdd = MAdd 142 | 143 | @property 144 | def Flops(self): 145 | total_Flops = self._Flops 146 | for child in self.children: 147 | total_Flops += child.Flops 148 | return total_Flops 149 | 150 | @Flops.setter 151 | def Flops(self, Flops): 152 | self._Flops = Flops 153 | 154 | @property 155 | def ConvFlops(self): 156 | total_ConvFlops = self._ConvFlops 157 | for child in self.children: 158 | total_ConvFlops += child.ConvFlops 159 | return total_ConvFlops 160 | 161 | @ConvFlops.setter 162 | def ConvFlops(self, ConvFlops): 163 | self._ConvFlops = ConvFlops 164 | 165 | @property 166 | def MemRead(self): 167 | total_MemRead = self._MemRead 168 | for child in self.children: 169 | total_MemRead += child.MemRead 170 | return total_MemRead 171 | 172 | @MemRead.setter 173 | def MemRead(self, MemRead): 174 | self._MemRead = MemRead 175 | 176 | @property 177 | def MemWrite(self): 178 | total_MemWrite = self._MemWrite 179 | for child in self.children: 180 | total_MemWrite += child.MemWrite 181 | return total_MemWrite 182 | 183 | @MemWrite.setter 184 | def MemWrite(self, MemWrite): 185 | self._MemWrite = MemWrite 186 | 187 | @property 188 | def duration(self): 189 | total_duration = self._duration 190 | for child in self.children: 191 | total_duration += child.duration 192 | return total_duration 193 | 194 | @duration.setter 195 | def duration(self, duration): 196 | self._duration = duration 197 | 198 | def find_child_index(self, child_name): 199 | assert isinstance(child_name, str) 200 | 201 | index = -1 202 | for i in range(len(self.children)): 203 | if child_name == self.children[i].name: 204 | index = i 205 | return index 206 | 207 | def add_child(self, node): 208 | assert isinstance(node, StatNode) 209 | 210 | if 
self.find_child_index(node.name) == -1: # not exist 211 | self.children.append(node) 212 | -------------------------------------------------------------------------------- /benchmark/statistics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | 5 | from .model_hook import ModelHook 6 | from .stat_tree import StatTree, StatNode 7 | from .reporter import report_format 8 | 9 | 10 | def get_parent_node(root_node, stat_node_name): 11 | assert isinstance(root_node, StatNode) 12 | 13 | node = root_node 14 | names = stat_node_name.split('.') 15 | for i in range(len(names) - 1): 16 | node_name = '.'.join(names[0:i+1]) 17 | child_index = node.find_child_index(node_name) 18 | assert child_index != -1 19 | node = node.children[child_index] 20 | return node 21 | 22 | 23 | def convert_leaf_modules_to_stat_tree(leaf_modules): 24 | assert isinstance(leaf_modules, OrderedDict) 25 | 26 | create_index = 1 27 | root_node = StatNode(name='root', parent=None) 28 | for leaf_module_name, leaf_module in leaf_modules.items(): 29 | names = leaf_module_name.split('.') 30 | for i in range(len(names)): 31 | create_index += 1 32 | stat_node_name = '.'.join(names[0:i+1]) 33 | parent_node = get_parent_node(root_node, stat_node_name) 34 | node = StatNode(name=stat_node_name, parent=parent_node) 35 | parent_node.add_child(node) 36 | if i == len(names) - 1: # leaf module itself 37 | node.input_shape = leaf_module.input_shape.numpy().tolist() 38 | node.output_shape = leaf_module.output_shape.numpy().tolist() 39 | node.parameter_quantity = leaf_module.parameter_quantity.numpy()[0] 40 | node.inference_memory = leaf_module.inference_memory.numpy()[0] 41 | node.MAdd = leaf_module.MAdd.numpy()[0] 42 | node.Flops = leaf_module.Flops.numpy()[0] 43 | node.ConvFlops = leaf_module.ConvFlops.numpy()[0] 44 | node.duration = leaf_module.duration.numpy()[0] 45 | node.MemRead = 
leaf_module.MemRead.numpy()[0] 46 | node.MemWrite = leaf_module.MemWrite.numpy()[0] 47 | return StatTree(root_node) 48 | 49 | 50 | class ModelStat(object): 51 | def __init__(self, model, input_size, query_granularity=1, brief_report=False): 52 | assert isinstance(model, nn.Module) 53 | assert isinstance(input_size, (tuple, list)) and len(input_size) == 4 54 | self.model_hook = ModelHook(model, input_size) 55 | self.leaf_modules = self.model_hook.retrieve_leaf_modules() 56 | self.stat_tree = convert_leaf_modules_to_stat_tree(self.leaf_modules) 57 | self._brief_report = brief_report 58 | 59 | if 1 <= query_granularity <= self.stat_tree.root_node.depth: 60 | self._query_granularity = query_granularity 61 | else: 62 | self._query_granularity = self.stat_tree.root_node.depth 63 | 64 | def show_report(self): 65 | collected_nodes = self.stat_tree.get_collected_stat_nodes(self._query_granularity) 66 | report = report_format(collected_nodes, self._brief_report) 67 | print(report) 68 | 69 | def unhook_model(self): 70 | self.model_hook._unhook_model() 71 | 72 | @property 73 | def query_granularity(self): 74 | return self._query_granularity 75 | 76 | @query_granularity.setter 77 | def query_granularity(self, query_granularity): 78 | if 1 <= query_granularity <= self.stat_tree.root_node.depth: 79 | self._query_granularity = query_granularity 80 | else: 81 | self._query_granularity = self.stat_tree.root_node.depth 82 | 83 | @property 84 | def brief_report(self): 85 | return self._brief_report 86 | 87 | @brief_report.setter 88 | def brief_report(self, brief_report): 89 | self._brief_report = brief_report 90 | 91 | 92 | def stat(model, input_size, query_granularity=1, brief_report=False): 93 | ms = ModelStat(model, input_size, query_granularity, brief_report) 94 | ms.show_report() 95 | ms.unhook_model() 96 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 
1 | import os 2 | import cv2 3 | from PIL import Image 4 | import torch.utils.data as data 5 | import torchvision.transforms as transforms 6 | import nvidia.dali.pipeline as pipeline 7 | import nvidia.dali.ops as ops 8 | import nvidia.dali.types as types 9 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 10 | IMAGENET_STD = [0.229, 0.224, 0.225] 11 | 12 | 13 | # If get UserWarning: Corrupt EXIF data, use cv2_loader or ignore warnings 14 | def pil_loader(path): 15 | img = Image.open(path).convert('RGB') 16 | return img 17 | 18 | def cv2_loader(path): 19 | img = cv2.imread(path, cv2.IMREAD_COLOR) 20 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 21 | img = Image.fromarray(img) 22 | return img 23 | 24 | default_loader = pil_loader 25 | 26 | 27 | def default_list_reader(list_path): 28 | img_list = [] 29 | with open(list_path, 'r') as f: 30 | for line in f.readlines(): 31 | img_path, label = line.strip().split(' ') 32 | img_list.append((img_path, int(label))) 33 | 34 | return img_list 35 | 36 | 37 | class ImageList(data.Dataset): 38 | def __init__(self, root, list_path, transform=None, list_reader=default_list_reader, loader=default_loader): 39 | self.root = root 40 | self.img_list = list_reader(list_path) 41 | self.transform = transform 42 | self.loader = loader 43 | 44 | def __getitem__(self, index): 45 | img_path, target = self.img_list[index] 46 | img = self.loader(os.path.join(self.root, img_path)) 47 | 48 | if self.transform: 49 | img = self.transform(img) 50 | 51 | return img, target 52 | 53 | def __len__(self): 54 | return len(self.img_list) 55 | 56 | 57 | def get_train_transform(coji=False): 58 | transform_list = [ 59 | transforms.RandomResizedCrop(224), 60 | transforms.RandomHorizontalFlip(0.5), 61 | ] 62 | if coji: 63 | transform_list += [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),] 64 | transform_list += [ 65 | transforms.ToTensor(), 66 | transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), 67 | ] 68 | train_transform = 
transforms.Compose(transform_list) 69 | 70 | return train_transform 71 | 72 | 73 | def get_val_transform(): 74 | transform_list = [ 75 | transforms.Resize(256), 76 | transforms.CenterCrop(224), 77 | transforms.ToTensor(), 78 | transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), 79 | ] 80 | val_transform = transforms.Compose(transform_list) 81 | 82 | return val_transform 83 | 84 | 85 | class HybridTrainPipe(pipeline.Pipeline): 86 | def __init__(self, batch_size, num_threads, device_id, root, list_path, 87 | crop, shard_id, num_shards, coji=False, dali_cpu=False): 88 | super(HybridTrainPipe, self).__init__(batch_size, 89 | num_threads, 90 | device_id, 91 | seed=12 + device_id) 92 | self.read = ops.FileReader(file_root=root, 93 | file_list=list_path, 94 | shard_id=shard_id, 95 | num_shards=num_shards, 96 | random_shuffle=True, 97 | initial_fill=1024) 98 | # Let user decide which pipeline works 99 | dali_device = 'cpu' if dali_cpu else 'gpu' 100 | decoder_device = 'cpu' if dali_cpu else 'mixed' 101 | # This padding sets the size of the internal nvJPEG buffers to be able to handle all images 102 | # from full-sized ImageNet without additional reallocations 103 | device_memory_padding = 211025920 if decoder_device == 'mixed' else 0 104 | host_memory_padding = 140544512 if decoder_device == 'mixed' else 0 105 | self.decode = ops.ImageDecoderRandomCrop(device=decoder_device, output_type=types.RGB, 106 | device_memory_padding=device_memory_padding, 107 | host_memory_padding=host_memory_padding, 108 | random_aspect_ratio=[0.75, 1.33333333], 109 | random_area=[0.08, 1.0], 110 | num_attempts=100) 111 | self.resize = ops.Resize(device=dali_device, 112 | resize_x=crop, 113 | resize_y=crop, 114 | interp_type=types.INTERP_TRIANGULAR) 115 | self.cmnp = ops.CropMirrorNormalize(device=dali_device, 116 | output_dtype=types.FLOAT, 117 | output_layout=types.NCHW, 118 | crop=(crop, crop), 119 | image_type=types.RGB, 120 | mean=[x*255 for x in IMAGENET_MEAN], 121 | std=[x*255 for x in 
IMAGENET_STD]) 122 | self.coin = ops.CoinFlip(probability=0.5) 123 | 124 | self.coji = coji 125 | if self.coji: 126 | self.twist = ops.ColorTwist(device=dali_device) 127 | self.brightness_rng = ops.Uniform(range=[1.0-0.4, 1.0+0.4]) 128 | self.contrast_rng = ops.Uniform(range=[1.0-0.4, 1.0+0.4]) 129 | self.saturation_rng = ops.Uniform(range=[1.0-0.4, 1.0+0.4]) 130 | 131 | def define_graph(self): 132 | rng = self.coin() 133 | imgs, targets = self.read(name="Reader") 134 | imgs = self.decode(imgs) 135 | imgs = self.resize(imgs) 136 | if self.coji: 137 | brightness = self.brightness_rng() 138 | contrast = self.contrast_rng() 139 | saturation = self.saturation_rng() 140 | imgs = self.twist(imgs, brightness=brightness, contrast=contrast, saturation=saturation) 141 | imgs = self.cmnp(imgs, mirror=rng) 142 | return [imgs, targets] 143 | 144 | 145 | class HybridValPipe(pipeline.Pipeline): 146 | def __init__(self, batch_size, num_threads, device_id, root, list_path, 147 | size, crop, shard_id, num_shards, dali_cpu=False): 148 | super(HybridValPipe, self).__init__(batch_size, 149 | num_threads, 150 | device_id, 151 | seed=12 + device_id) 152 | self.read = ops.FileReader(file_root=root, 153 | file_list=list_path, 154 | shard_id=shard_id, 155 | num_shards=num_shards, 156 | random_shuffle=False) 157 | # Let user decide which pipeline works 158 | dali_device = 'cpu' if dali_cpu else 'gpu' 159 | decoder_device = 'cpu' if dali_cpu else 'mixed' 160 | self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB) 161 | self.resize = ops.Resize(device=dali_device, 162 | resize_shorter=size, 163 | interp_type=types.INTERP_TRIANGULAR) 164 | self.cmnp = ops.CropMirrorNormalize(device=dali_device, 165 | output_dtype=types.FLOAT, 166 | output_layout=types.NCHW, 167 | crop=(crop, crop), 168 | image_type=types.RGB, 169 | mean=[x*255 for x in IMAGENET_MEAN], 170 | std=[x*255 for x in IMAGENET_STD]) 171 | 172 | def define_graph(self): 173 | imgs, targets = 
self.read(name="Reader") 174 | imgs = self.decode(imgs) 175 | imgs = self.resize(imgs) 176 | imgs = self.cmnp(imgs) 177 | return [imgs, targets] 178 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CrossEntropyLabelSmooth(nn.Module): 6 | def __init__(self, num_classes, epsilon): 7 | super(CrossEntropyLabelSmooth, self).__init__() 8 | self.num_classes = num_classes 9 | self.epsilon = epsilon 10 | self.logsoftmax = nn.LogSoftmax(dim=1) 11 | 12 | def forward(self, xs, targets): 13 | log_probs = self.logsoftmax(xs) 14 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 15 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 16 | loss = (-targets * log_probs).mean(0).sum() 17 | 18 | return loss 19 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import LambdaLR 3 | import warnings 4 | 5 | 6 | class LambdaLRWithMin(LambdaLR): 7 | def __init__(self, optimizer, lr_lambda, eta_min=0, last_epoch=-1): 8 | self.eta_min = eta_min 9 | super(LambdaLRWithMin, self).__init__(optimizer, lr_lambda, last_epoch) 10 | 11 | def get_lr(self): 12 | if not self._get_lr_called_within_step: 13 | warnings.warn("To get the last learning rate computed by the scheduler, " 14 | "please use `get_last_lr()`.") 15 | 16 | return [base_lr * lmbda(self.last_epoch) + self.eta_min * (1.0 - lmbda(self.last_epoch)) 17 | for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] 18 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as 
nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | 6 | 7 | class HSwish(nn.Module): 8 | def __init__(self, inplace=True): 9 | super(HSwish, self).__init__() 10 | self.inplace = inplace 11 | 12 | def forward(self, x): 13 | out = x * F.relu6(x + 3, inplace=self.inplace) / 6 14 | return out 15 | 16 | class HSigmoid(nn.Module): 17 | def __init__(self, inplace=True): 18 | super(HSigmoid, self).__init__() 19 | self.inplace = inplace 20 | 21 | def forward(self, x): 22 | out = F.relu6(x + 3, inplace=self.inplace) / 6 23 | return out 24 | 25 | class Swish(nn.Module): 26 | def __init__(self): 27 | super(Swish, self).__init__() 28 | 29 | def forward(self, x): 30 | out = x * F.sigmoid(x) 31 | return out 32 | 33 | Sigmoid = nn.Sigmoid 34 | 35 | 36 | hswish = HSwish 37 | hsigmoid = HSigmoid 38 | swish = Swish 39 | sigmoid = Sigmoid 40 | relu = nn.ReLU 41 | relu6 = nn.ReLU6 42 | 43 | 44 | class SEModule(nn.Module): 45 | def __init__(self, in_channels, reduction=4): 46 | super(SEModule, self).__init__() 47 | self.se = nn.Sequential( 48 | nn.AdaptiveAvgPool2d(1), 49 | nn.Conv2d(in_channels, in_channels//reduction, kernel_size=1, stride=1, padding=0, bias=True), 50 | nn.BatchNorm2d(in_channels//reduction), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(in_channels//reduction, in_channels, kernel_size=1, stride=1, padding=0, bias=True), 53 | nn.BatchNorm2d(in_channels), 54 | hsigmoid(inplace=True) 55 | ) 56 | 57 | def forward(self, x): 58 | return x * self.se(x) 59 | 60 | 61 | class MBInvertedResBlock(nn.Module): 62 | def __init__(self, in_channels, mid_channels, out_channels, kernel_size=3, stride=1, act_func=relu, with_se=False): 63 | super(MBInvertedResBlock, self).__init__() 64 | self.has_residual = (in_channels == out_channels) and (stride == 1) 65 | self.se = SEModule(mid_channels) if with_se else None 66 | 67 | if mid_channels > in_channels: 68 | self.inverted_bottleneck = nn.Sequential( 69 | nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, 
padding=0, bias=False), 70 | nn.BatchNorm2d(mid_channels), 71 | act_func(inplace=True) 72 | ) 73 | else: 74 | self.inverted_bottleneck = None 75 | self.depth_conv = nn.Sequential( 76 | nn.Conv2d(mid_channels, mid_channels, kernel_size=kernel_size, stride=stride, 77 | padding=kernel_size//2, groups=mid_channels, bias=False), 78 | nn.BatchNorm2d(mid_channels), 79 | act_func(inplace=True) 80 | ) 81 | self.point_linear = nn.Sequential( 82 | nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), 83 | nn.BatchNorm2d(out_channels) 84 | ) 85 | 86 | def forward(self, x): 87 | res = x 88 | 89 | if self.inverted_bottleneck is not None: 90 | out = self.inverted_bottleneck(x) 91 | else: 92 | out = x 93 | 94 | out = self.depth_conv(out) 95 | if self.se is not None: 96 | out = self.se(out) 97 | out = self.point_linear(out) 98 | 99 | if self.has_residual: 100 | out += res 101 | 102 | return out 103 | 104 | 105 | class MobileNetV3_Large(nn.Module): 106 | def __init__(self, num_classes=1000, dropout_rate=0.0, zero_init_last_bn=False): 107 | super(MobileNetV3_Large, self).__init__() 108 | self.dropout_rate = dropout_rate 109 | self.zero_init_last_bn = zero_init_last_bn 110 | 111 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False) 112 | self.bn1 = nn.BatchNorm2d(16) 113 | self.hs1 = hswish(inplace=True) 114 | self.bneck = nn.Sequential( 115 | MBInvertedResBlock(16, 16, 16, 3, 1, relu, False), 116 | MBInvertedResBlock(16, 64, 24, 3, 2, relu, False), 117 | MBInvertedResBlock(24, 72, 24, 3, 1, relu, False), 118 | MBInvertedResBlock(24, 72, 40, 5, 2, relu, True), 119 | MBInvertedResBlock(40, 120, 40, 5, 1, relu, True), 120 | MBInvertedResBlock(40, 120, 40, 5, 1, relu, True), 121 | MBInvertedResBlock(40, 240, 80, 3, 2, hswish, False), 122 | MBInvertedResBlock(80, 200, 80, 3, 1, hswish, False), 123 | MBInvertedResBlock(80, 184, 80, 3, 1, hswish, False), 124 | MBInvertedResBlock(80, 184, 80, 3, 1, hswish, False), 125 | 
MBInvertedResBlock(80, 480, 112, 3, 1, hswish, True), 126 | MBInvertedResBlock(112, 672, 112, 3, 1, hswish, True), 127 | MBInvertedResBlock(112, 672, 160, 5, 2, hswish, True), 128 | MBInvertedResBlock(160, 960, 160, 5, 1, hswish, True), 129 | MBInvertedResBlock(160, 960, 160, 5, 1, hswish, True), 130 | ) 131 | self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False) 132 | self.bn2 = nn.BatchNorm2d(960) 133 | self.hs2 = hswish(inplace=True) 134 | self.avgpool = nn.AdaptiveAvgPool2d(1) 135 | self.conv3 = nn.Conv2d(960, 1280, kernel_size=1, stride=1, padding=0, bias=True) 136 | self.hs3 = hswish() 137 | self.classifier = nn.Linear(1280, num_classes) 138 | 139 | self._initialization() 140 | # self._set_bn_param(0.1, 0.001) 141 | 142 | def forward(self, x): 143 | out = self.conv1(x) 144 | out = self.bn1(out) 145 | out = self.hs1(out) 146 | 147 | out = self.bneck(out) 148 | out = self.conv2(out) 149 | out = self.bn2(out) 150 | out = self.hs2(out) 151 | 152 | out = self.avgpool(out) 153 | out = self.conv3(out) 154 | out = self.hs3(out) 155 | out = out.view(out.size(0), -1) 156 | if self.dropout_rate > 0.0: 157 | out = F.dropout(out, p=self.dropout_rate, training=self.training) 158 | out = self.classifier(out) 159 | 160 | return out 161 | 162 | def _initialization(self): 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | init.kaiming_normal_(m.weight, mode='fan_out') 166 | if m.bias is not None: 167 | init.constant_(m.bias, 0) 168 | elif isinstance(m, nn.BatchNorm2d): 169 | init.constant_(m.weight, 1) 170 | init.constant_(m.bias, 0) 171 | elif isinstance(m, nn.Linear): 172 | init.normal_(m.weight, std=0.001) 173 | if m.bias is not None: 174 | init.constant_(m.bias, 0) 175 | 176 | if self.zero_init_last_bn: 177 | for mname, m in self.named_modules(): 178 | if isinstance(m, MBInvertedResBlock): 179 | if m.has_residual: 180 | init.constant_(m.point_linear[1].weight, 0) 181 | 182 | # def _set_bn_param(self, bn_momentum, bn_eps): 183 
| # for m in self.modules(): 184 | # if isinstance(m, nn.BatchNorm2d): 185 | # m.momentum = bn_momentum 186 | # m.eps = bn_eps 187 | 188 | 189 | class MobileNetV3_Small(nn.Module): 190 | def __init__(self, num_classes=1000, dropout_rate=0.0, zero_init_last_bn=False): 191 | super(MobileNetV3_Small, self).__init__() 192 | self.dropout_rate = dropout_rate 193 | self.zero_init_last_bn = zero_init_last_bn 194 | 195 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False) 196 | self.bn1 = nn.BatchNorm2d(16) 197 | self.hs1 = hswish(inplace=True) 198 | self.bneck = nn.Sequential( 199 | MBInvertedResBlock(16, 16, 16, 3, 2, relu, True), 200 | MBInvertedResBlock(16, 72, 24, 3, 2, relu, False), 201 | MBInvertedResBlock(24, 88, 24, 3, 1, relu, False), 202 | MBInvertedResBlock(24, 96, 40, 5, 2, hswish, True), 203 | MBInvertedResBlock(40, 240, 40, 5, 1, hswish, True), 204 | MBInvertedResBlock(40, 240, 40, 5, 1, hswish, True), 205 | MBInvertedResBlock(40, 120, 48, 5, 1, hswish, True), 206 | MBInvertedResBlock(48, 144, 48, 5, 1, hswish, True), 207 | MBInvertedResBlock(48, 288, 96, 5, 2, hswish, True), 208 | MBInvertedResBlock(96, 576, 96, 5, 1, hswish, True), 209 | MBInvertedResBlock(96, 576, 96, 5, 1, hswish, True), 210 | ) 211 | self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False) 212 | self.bn2 = nn.BatchNorm2d(576) 213 | self.hs2 = hswish(inplace=True) 214 | self.avgpool = nn.AdaptiveAvgPool2d(1) 215 | self.conv3 = nn.Conv2d(576, 1280, kernel_size=1, stride=1, padding=0, bias=True) 216 | self.hs3 = hswish() 217 | self.classifier = nn.Linear(1280, num_classes) 218 | 219 | self._initialization() 220 | # self._set_bn_param(0.1, 0.001) 221 | 222 | def forward(self, x): 223 | out = self.conv1(x) 224 | out = self.bn1(out) 225 | out = self.hs1(out) 226 | 227 | out = self.bneck(out) 228 | out = self.conv2(out) 229 | out = self.bn2(out) 230 | out = self.hs2(out) 231 | 232 | out = self.avgpool(out) 233 | out = self.conv3(out) 234 | out = 
self.hs3(out) 235 | out = out.view(out.size(0), -1) 236 | if self.dropout_rate > 0.0: 237 | out = F.dropout(out, p=self.dropout_rate, training=self.training) 238 | out = self.classifier(out) 239 | 240 | return out 241 | 242 | def _initialization(self): 243 | for m in self.modules(): 244 | if isinstance(m, nn.Conv2d): 245 | init.kaiming_normal_(m.weight, mode='fan_out') 246 | if m.bias is not None: 247 | init.constant_(m.bias, 0) 248 | elif isinstance(m, nn.BatchNorm2d): 249 | init.constant_(m.weight, 1) 250 | init.constant_(m.bias, 0) 251 | elif isinstance(m, nn.Linear): 252 | init.normal_(m.weight, std=0.001) 253 | if m.bias is not None: 254 | init.constant_(m.bias, 0) 255 | 256 | if self.zero_init_last_bn: 257 | for mname, m in self.named_modules(): 258 | if isinstance(m, MBInvertedResBlock): 259 | if m.has_residual: 260 | init.constant_(m.point_linear[1].weight, 0) 261 | 262 | # def _set_bn_param(self, bn_momentum, bn_eps): 263 | # for m in self.modules(): 264 | # if isinstance(m, nn.BatchNorm2d): 265 | # m.momentum = bn_momentum 266 | # m.eps = bn_eps 267 | -------------------------------------------------------------------------------- /profile_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from benchmark import ModelStat, stat 4 | from benchmark import compute_speed 5 | 6 | sys.path.append('..') 7 | from networks import MobileNetV3_Small, MobileNetV3_Large 8 | 9 | 10 | model = MobileNetV3_Small() 11 | 12 | # query_granularity can be any int value, usually: 13 | # query_granularity=1 reports every leaf node 14 | # query_granularity=-1 only reports the root node 15 | stat(model, (1, 3, 224, 224), query_granularity=1, brief_report=False) 16 | stat(model, (1, 3, 224, 224), query_granularity=-1, brief_report=False) 17 | 18 | # brief_report=True only reports the summation 19 | stat(model, (1, 3, 224, 224), query_granularity=1, brief_report=True) 20 | 21 | 22 | # can also initialize 
ModelStat, set the query_granularity and then show_report 23 | ms = ModelStat(model, (1, 3, 224, 224), query_granularity=1, brief_report=False) 24 | 25 | ms.query_granularity = -1 26 | ms.show_report() 27 | ms.query_granularity = 1 28 | ms.show_report() 29 | 30 | ms.unhook_model() 31 | 32 | # measure latency 33 | compute_speed(model, (32, 3, 224, 224), 'cuda:0', 1000) 34 | compute_speed(model, (1, 3, 224, 224), 'cuda:0', 1000) 35 | compute_speed(model, (1, 3, 224, 224), 'cpu', 1000) 36 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | import warnings 5 | warnings.filterwarnings('ignore') 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.distributed as dist 10 | import torch.backends.cudnn as cudnn 11 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator 12 | from apex.parallel import DistributedDataParallel as DDP 13 | 14 | from utils import AverageMeter, accuracy 15 | from datasets import ImageList, pil_loader, cv2_loader 16 | from datasets import get_val_transform, HybridValPipe 17 | from networks import MobileNetV3_Large, MobileNetV3_Small 18 | 19 | 20 | parser = argparse.ArgumentParser( 21 | description="Basic Pytorch ImageNet Example. 
Testing.", 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | 24 | # various paths 25 | parser.add_argument('--val_root', type=str, required=True, help='root path to validating images') 26 | parser.add_argument('--val_list', type=str, required=True, help='validating image list') 27 | parser.add_argument('--weights', type=str, required=True, help='checkpoint for testing') 28 | 29 | # testing hyper-parameters 30 | parser.add_argument('--workers', type=int, default=8, help='number of workers to load dataset (global)') 31 | parser.add_argument('--batch_size', type=int, default=512, help='batch size (global)') 32 | parser.add_argument('--model', type=str, default='MobileNetV3_Large', help='type of model', 33 | choices=['MobileNetV3_Large', 'MobileNetV3_Small']) 34 | parser.add_argument('--num_classes', type=int, default=1000, help='class number of testing set') 35 | parser.add_argument('--trans_mode', type=str, default='tv', help='mode of image transformation (tv/dali)') 36 | parser.add_argument('--dali_cpu', action='store_true', default=False, help='runs CPU based DALI pipeline') 37 | parser.add_argument('--ema', action='store_true', default=False, help='whether to use EMA') 38 | 39 | # amp and DDP hyper-parameters 40 | parser.add_argument('--local_rank', type=int, default=0) 41 | parser.add_argument('--channels_last', type=str, default='False') 42 | 43 | 44 | args, unparsed = parser.parse_known_args() 45 | args.channels_last = eval(args.channels_last) 46 | 47 | if hasattr(torch, 'channels_last') and hasattr(torch, 'contiguous_format'): 48 | if args.channels_last: 49 | memory_format = torch.channels_last 50 | else: 51 | memory_format = torch.contiguous_format 52 | else: 53 | memory_format = None 54 | 55 | 56 | def main(): 57 | cudnn.enabled=True 58 | cudnn.benchmark = True 59 | args.distributed = False 60 | if 'WORLD_SIZE' in os.environ: 61 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 62 | args.gpu = 0 63 | args.world_size = 1 64 | if 
args.distributed: 65 | args.gpu = args.local_rank 66 | torch.cuda.set_device(args.gpu) 67 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 68 | args.world_size = torch.distributed.get_world_size() 69 | 70 | # create model 71 | if args.model == 'MobileNetV3_Large': 72 | model = MobileNetV3_Large(args.num_classes, 0.0, False) 73 | elif args.model == 'MobileNetV3_Small': 74 | model = MobileNetV3_Small(args.num_classes, 0.0, False) 75 | else: 76 | raise Exception('invalid type of model') 77 | model = model.cuda().to(memory_format=memory_format) if memory_format is not None else model.cuda() 78 | 79 | # For distributed training, wrap the model with apex.parallel.DistributedDataParallel. 80 | # This must be done AFTER the call to amp.initialize. 81 | if args.distributed: 82 | # By default, apex.parallel.DistributedDataParallel overlaps communication with 83 | # computation in the backward pass. 84 | # delay_allreduce delays all communication to the end of the backward pass. 
85 | model = DDP(model, delay_allreduce=True) 86 | else: 87 | model = nn.DataParallel(model) 88 | 89 | # define transform and initialize dataloader 90 | batch_size = args.batch_size // args.world_size 91 | workers = args.workers // args.world_size 92 | if args.trans_mode == 'tv': 93 | val_transform = get_val_transform() 94 | val_dataset = ImageList(root=args.val_root, 95 | list_path=args.val_list, 96 | transform=val_transform) 97 | val_sampler = None 98 | if args.distributed: 99 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False) 100 | val_loader = torch.utils.data.DataLoader( 101 | val_dataset, batch_size=batch_size, num_workers=workers, 102 | pin_memory=True, sampler=val_sampler, shuffle=False) 103 | elif args.trans_mode == 'dali': 104 | pipe = HybridValPipe(batch_size=batch_size, 105 | num_threads=workers, 106 | device_id=args.local_rank, 107 | root=args.val_root, 108 | list_path=args.val_list, 109 | size=256, 110 | crop=224, 111 | shard_id=args.local_rank, 112 | num_shards=args.world_size, 113 | dali_cpu=args.dali_cpu) 114 | pipe.build() 115 | val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader")/args.world_size)) 116 | else: 117 | raise Exception('invalid image transformation mode') 118 | 119 | # restart from weights 120 | if args.weights and os.path.isfile(args.weights): 121 | if args.local_rank == 0: 122 | print('loading weights from {}'.format(args.weights)) 123 | checkpoint = torch.load(args.weights, map_location=lambda storage,loc: storage.cuda(args.gpu)) 124 | if args.ema: 125 | model.load_state_dict(checkpoint['ema']) 126 | else: 127 | model.load_state_dict(checkpoint['model']) 128 | 129 | val_acc_top1, val_acc_top5 = validate(val_loader, model) 130 | if args.local_rank == 0: 131 | print('Val_acc_top1: {:.2f}'.format(val_acc_top1)) 132 | print('Val_acc_top5: {:.2f}'.format(val_acc_top5)) 133 | 134 | 135 | def validate(val_loader, model): 136 | top1 = AverageMeter() 137 | top5 = 
AverageMeter() 138 | 139 | model.eval() 140 | 141 | for data in tqdm(val_loader): 142 | if args.trans_mode == 'tv': 143 | x = data[0].cuda(non_blocking=True) 144 | target = data[1].cuda(non_blocking=True) 145 | elif args.trans_mode == 'dali': 146 | x = data[0]['data'].cuda(non_blocking=True) 147 | target = data[0]['label'].squeeze().cuda(non_blocking=True).long() 148 | 149 | with torch.no_grad(): 150 | logits = model(x) 151 | 152 | prec1, prec5 = accuracy(logits, target, topk=(1, 5)) 153 | if args.distributed: 154 | prec1 = reduce_tensor(prec1) 155 | prec5 = reduce_tensor(prec5) 156 | top1.update(prec1.item(), x.size(0)) 157 | top5.update(prec5.item(), x.size(0)) 158 | 159 | return top1.avg, top5.avg 160 | 161 | 162 | def reduce_tensor(tensor): 163 | rt = tensor.clone() 164 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 165 | rt /= args.world_size 166 | return rt 167 | 168 | 169 | if __name__ == '__main__': 170 | main() 171 | -------------------------------------------------------------------------------- /test_example.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python -u test.py \ 2 | --val_root "Your ImageNet Val Set Path" \ 3 | --val_list "ImageNet Val List" \ 4 | --weights "Pretrained Weights" \ 5 | --model 'MobileNetV3_Large' \ 6 | --trans_mode 'tv' \ 7 | --ema 8 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import glob 5 | import logging 6 | import argparse 7 | import warnings 8 | warnings.filterwarnings('ignore') 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.distributed as dist 14 | import torch.utils.data.distributed 15 | import torch.backends.cudnn as cudnn 16 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator 17 | from apex.parallel import 
DistributedDataParallel as DDP 18 | from apex import amp, parallel 19 | 20 | from utils import AverageMeter, EMA, accuracy, set_seed 21 | from utils import create_exp_dir, save_checkpoint, get_params 22 | from losses import CrossEntropyLabelSmooth 23 | from datasets import ImageList, pil_loader, cv2_loader 24 | from datasets import get_train_transform, get_val_transform 25 | from datasets import HybridTrainPipe, HybridValPipe 26 | from networks import MobileNetV3_Large, MobileNetV3_Small 27 | from lr_scheduler import LambdaLRWithMin 28 | 29 | 30 | parser = argparse.ArgumentParser( 31 | description="Basic Pytorch ImageNet Example. There is no tricks such as mixup/autoaug/dropblock/droppath etc.", 32 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 33 | 34 | # various paths 35 | parser.add_argument('--train_root', type=str, required=True, help='root path to training images') 36 | parser.add_argument('--train_list', type=str, required=True, help='training image list') 37 | parser.add_argument('--val_root', type=str, required=True, help='root path to validating images') 38 | parser.add_argument('--val_list', type=str, required=True, help='validating image list') 39 | parser.add_argument('--save', type=str, default='./checkpoints/', help='model and log saving path') 40 | parser.add_argument('--snapshot', type=str, default='', help='checkpoint for reset') 41 | 42 | # training hyper-parameters 43 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency') 44 | parser.add_argument('--workers', type=int, default=8, help='number of workers to load dataset (global)') 45 | parser.add_argument('--epochs', type=int, default=250, help='number of total training epochs') 46 | parser.add_argument('--warmup_epochs', type=int, default=5, help='number of warmup epochs') 47 | parser.add_argument('--batch_size', type=int, default=512, help='batch size (global)') 48 | parser.add_argument('--lr', type=float, default=0.2, help='initial learning rate') 49 | 
parser.add_argument('--lr_min', type=float, default=0.0, help='minimum learning rate') 50 | parser.add_argument('--lr_scheduler', type=str, default='cosine_epoch', help='type of lr scheduler', 51 | choices=['linear_epoch', 'linear_batch', 'cosine_epoch', 'cosine_batch', 'step_epoch', 'step_batch']) 52 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 53 | parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay (wd)') 54 | parser.add_argument('--no_wd_bias_bn', action='store_true', default=False, help='whether to remove wd on bias and bn') 55 | parser.add_argument('--model', type=str, default='MobileNetV3_Large', help='type of model', 56 | choices=['MobileNetV3_Large', 'MobileNetV3_Small']) 57 | parser.add_argument('--num_classes', type=int, default=1000, help='class number of training set') 58 | parser.add_argument('--dropout_rate', type=float, default=0.0, help='dropout rate') 59 | parser.add_argument('--zero_init_last_bn', action='store_true', default=False, help='zero initialize the last bn in each block') 60 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing') 61 | parser.add_argument('--trans_mode', type=str, default='tv', help='mode of image transformation (tv/dali)') 62 | parser.add_argument('--color_jitter', action='store_true', default=False, help='apply color augmentation or not') 63 | parser.add_argument('--dali_cpu', action='store_true', default=False, help='runs CPU based DALI pipeline') 64 | parser.add_argument('--ema_decay', type=float, default=0.0, help='whether to use EMA') 65 | 66 | # amp and DDP hyper-parameters 67 | parser.add_argument('--local_rank', type=int, default=0) 68 | parser.add_argument('--sync_bn', action='store_true', help='enabling apex sync BN') 69 | parser.add_argument('--opt_level', type=str, default=None) 70 | parser.add_argument('--keep_batchnorm_fp32', type=str, default=None) 71 | parser.add_argument('--loss_scale', type=str, 
default=None) 72 | parser.add_argument('--channels_last', type=str, default='False') 73 | 74 | # others 75 | parser.add_argument('--seed', type=int, default=2, help='random seed') 76 | parser.add_argument('--note', type=str, default='try', help='note for this run') 77 | 78 | 79 | args, unparsed = parser.parse_known_args() 80 | args.channels_last = eval(args.channels_last) 81 | 82 | args.save = os.path.join(args.save, '{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), args.note)) 83 | if args.local_rank == 0: 84 | create_exp_dir(args.save, scripts_to_save=glob.glob('*.py')+glob.glob('*.sh')) 85 | 86 | log_format = '%(asctime)s %(message)s' 87 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 88 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 89 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt')) 90 | fh.setFormatter(logging.Formatter(log_format)) 91 | logging.getLogger().addHandler(fh) 92 | 93 | if hasattr(torch, 'channels_last') and hasattr(torch, 'contiguous_format'): 94 | if args.channels_last: 95 | memory_format = torch.channels_last 96 | else: 97 | memory_format = torch.contiguous_format 98 | else: 99 | memory_format = None 100 | 101 | 102 | def main(): 103 | set_seed(args.seed) 104 | cudnn.enabled=True 105 | cudnn.benchmark = True 106 | args.distributed = False 107 | if 'WORLD_SIZE' in os.environ: 108 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 109 | args.gpu = 0 110 | args.world_size = 1 111 | if args.distributed: 112 | set_seed(args.local_rank) 113 | args.gpu = args.local_rank 114 | torch.cuda.set_device(args.gpu) 115 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 116 | args.world_size = torch.distributed.get_world_size() 117 | if args.local_rank == 0: 118 | logging.info("args = {}".format(args)) 119 | logging.info("unparsed_args = {}".format(unparsed)) 120 | logging.info("distributed = {}".format(args.distributed)) 121 | logging.info("sync_bn = {}".format(args.sync_bn)) 122 | 
logging.info("opt_level = {}".format(args.opt_level)) 123 | logging.info("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32)) 124 | logging.info("loss_scale = {}".format(args.loss_scale)) 125 | logging.info("CUDNN VERSION: {}".format(torch.backends.cudnn.version())) 126 | 127 | # create model 128 | if args.model == 'MobileNetV3_Large': 129 | model = MobileNetV3_Large(args.num_classes, args.dropout_rate, args.zero_init_last_bn) 130 | elif args.model == 'MobileNetV3_Small': 131 | model = MobileNetV3_Small(args.num_classes, args.dropout_rate, args.zero_init_last_bn) 132 | else: 133 | raise Exception('invalid type of model') 134 | if args.sync_bn: 135 | if args.local_rank == 0: logging.info("using apex synced BN") 136 | model = parallel.convert_syncbn_model(model) 137 | model = model.cuda().to(memory_format=memory_format) if memory_format is not None else model.cuda() 138 | 139 | # define criterion and optimizer 140 | if args.label_smooth > 0.0: 141 | criterion = CrossEntropyLabelSmooth(args.num_classes, args.label_smooth) 142 | else: 143 | criterion = nn.CrossEntropyLoss() 144 | criterion = criterion.cuda() 145 | 146 | params = get_params(model) if args.no_wd_bias_bn else model.parameters() 147 | optimizer = torch.optim.SGD(params, args.lr, 148 | momentum=args.momentum, 149 | weight_decay=args.weight_decay) 150 | # Initialize Amp 151 | if args.opt_level is not None: 152 | model, optimizer = amp.initialize(model, optimizer, 153 | opt_level=args.opt_level, 154 | keep_batchnorm_fp32=args.keep_batchnorm_fp32, 155 | loss_scale=args.loss_scale) 156 | 157 | # For distributed training, wrap the model with apex.parallel.DistributedDataParallel. 158 | # This must be done AFTER the call to amp.initialize. 159 | if args.distributed: 160 | # By default, apex.parallel.DistributedDataParallel overlaps communication with 161 | # computation in the backward pass. 162 | # delay_allreduce delays all communication to the end of the backward pass. 
163 | model = DDP(model, delay_allreduce=True) 164 | else: 165 | model = nn.DataParallel(model) 166 | 167 | # exponential moving average 168 | if args.ema_decay > 0.0: 169 | ema = EMA(model, args.ema_decay) 170 | ema.register() 171 | else: 172 | ema = None 173 | 174 | # define transform and initialize dataloader 175 | batch_size = args.batch_size // args.world_size 176 | workers = args.workers // args.world_size 177 | if args.trans_mode == 'tv': 178 | train_transform = get_train_transform(args.color_jitter) 179 | val_transform = get_val_transform() 180 | train_dataset = ImageList(root=args.train_root, 181 | list_path=args.train_list, 182 | transform=train_transform) 183 | val_dataset = ImageList(root=args.val_root, 184 | list_path=args.val_list, 185 | transform=val_transform) 186 | train_sampler = None 187 | val_sampler = None 188 | if args.distributed: 189 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True) 190 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False) 191 | train_loader = torch.utils.data.DataLoader( 192 | train_dataset, batch_size=batch_size, num_workers=workers, 193 | pin_memory=True, sampler=train_sampler, shuffle=(train_sampler is None)) 194 | val_loader = torch.utils.data.DataLoader( 195 | val_dataset, batch_size=batch_size, num_workers=workers, 196 | pin_memory=True, sampler=val_sampler, shuffle=False) 197 | args.batches_per_epoch = len(train_loader) 198 | elif args.trans_mode == 'dali': 199 | pipe = HybridTrainPipe(batch_size=batch_size, 200 | num_threads=workers, 201 | device_id=args.local_rank, 202 | root=args.train_root, 203 | list_path=args.train_list, 204 | crop=224, 205 | shard_id=args.local_rank, 206 | num_shards=args.world_size, 207 | coji=args.color_jitter, 208 | dali_cpu=args.dali_cpu) 209 | pipe.build() 210 | train_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader")/args.world_size)) 211 | args.batches_per_epoch = 
train_loader._size // train_loader.batch_size 212 | args.batches_per_epoch += (train_loader._size % train_loader.batch_size) != 0 213 | 214 | pipe = HybridValPipe(batch_size=batch_size, 215 | num_threads=workers, 216 | device_id=args.local_rank, 217 | root=args.val_root, 218 | list_path=args.val_list, 219 | size=256, 220 | crop=224, 221 | shard_id=args.local_rank, 222 | num_shards=args.world_size, 223 | dali_cpu=args.dali_cpu) 224 | pipe.build() 225 | val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader")/args.world_size)) 226 | else: 227 | raise Exception('invalid image transformation mode') 228 | 229 | # define learning rate scheduler 230 | scheduler = get_lr_scheduler(optimizer) 231 | 232 | best_acc_top1 = 0 233 | best_acc_top5 = 0 234 | start_epoch = 0 235 | 236 | # restart from snapshot 237 | if args.snapshot and os.path.isfile(args.snapshot): 238 | if args.local_rank == 0: 239 | logging.info('loading snapshot from {}'.format(args.snapshot)) 240 | checkpoint = torch.load(args.snapshot, map_location=lambda storage,loc: storage.cuda(args.gpu)) 241 | start_epoch = checkpoint['epoch'] 242 | best_acc_top1 = checkpoint['best_acc_top1'] 243 | best_acc_top5 = checkpoint['best_acc_top5'] 244 | model.load_state_dict(checkpoint['model']) 245 | optimizer.load_state_dict(checkpoint['optimizer']) 246 | if checkpoint['ema'] is not None: 247 | ema.load_state_dict(checkpoint['ema']) 248 | if args.opt_level is not None: 249 | amp.load_state_dict(checkpoint['amp']) 250 | scheduler = get_lr_scheduler(optimizer) 251 | for epoch in range(start_epoch): 252 | if epoch < args.warmup_epochs: 253 | adjust_learning_rate(optimizer, scheduler, epoch, -1) 254 | warmup_lr = get_last_lr(optimizer) 255 | if args.local_rank == 0: 256 | logging.info('Epoch: %d, Warming-up lr: %e', epoch, warmup_lr) 257 | else: 258 | current_lr = get_last_lr(optimizer) 259 | if args.local_rank == 0: 260 | logging.info('Epoch: %d lr %e', epoch, current_lr) 261 | 262 | if epoch < 
args.warmup_epochs: 263 | for param_group in optimizer.param_groups: 264 | param_group['lr'] = args.lr 265 | else: 266 | if args.lr_scheduler in ['linear_epoch', 'cosine_epoch', 'step_epoch']: 267 | adjust_learning_rate(optimizer, scheduler, epoch, -1) 268 | if args.lr_scheduler in ['linear_batch', 'cosine_batch', 'step_batch']: 269 | for batch_idx in range(args.batches_per_epoch): 270 | adjust_learning_rate(optimizer, scheduler, epoch, batch_idx) 271 | 272 | # the main loop 273 | for epoch in range(start_epoch, args.epochs): 274 | if epoch < args.warmup_epochs: 275 | adjust_learning_rate(optimizer, scheduler, epoch, -1) 276 | warmup_lr = get_last_lr(optimizer) 277 | if args.local_rank == 0: 278 | logging.info('Epoch: %d, Warming-up lr: %e', epoch, warmup_lr) 279 | else: 280 | current_lr = get_last_lr(optimizer) 281 | if args.local_rank == 0: 282 | logging.info('Epoch: %d lr %e', epoch, current_lr) 283 | 284 | if args.distributed and args.trans_mode == 'tv': 285 | train_sampler.set_epoch(epoch) 286 | 287 | epoch_start = time.time() 288 | train_acc, train_obj = train(train_loader, model, ema, criterion, optimizer, scheduler, epoch) 289 | if args.local_rank == 0: 290 | logging.info('Train_acc: %f', train_acc) 291 | 292 | val_acc_top1, val_acc_top5, val_obj = validate(val_loader, model, criterion) 293 | if args.local_rank == 0: 294 | logging.info('Val_acc_top1: %f', val_acc_top1) 295 | logging.info('Val_acc_top5: %f', val_acc_top5) 296 | logging.info('Epoch time: %ds.', time.time() - epoch_start) 297 | 298 | if args.local_rank == 0: 299 | is_best = False 300 | if val_acc_top1 > best_acc_top1: 301 | best_acc_top1 = val_acc_top1 302 | best_acc_top5 = val_acc_top5 303 | is_best = True 304 | save_checkpoint({ 305 | 'epoch': epoch + 1, 306 | 'model': model.state_dict(), 307 | 'ema': ema.state_dict() if ema is not None else None, 308 | 'best_acc_top1': best_acc_top1, 309 | 'best_acc_top5': best_acc_top5, 310 | 'optimizer' : optimizer.state_dict(), 311 | 'amp': 
amp.state_dict() if args.opt_level is not None else None, 312 | }, is_best, args.save) 313 | 314 | if epoch < args.warmup_epochs: 315 | for param_group in optimizer.param_groups: 316 | param_group['lr'] = args.lr 317 | else: 318 | adjust_learning_rate(optimizer, scheduler, epoch, -1) 319 | 320 | if args.trans_mode == 'dali': 321 | train_loader.reset() 322 | val_loader.reset() 323 | 324 | 325 | def train(train_loader, model, ema, criterion, optimizer, scheduler, epoch): 326 | objs = AverageMeter() 327 | top1 = AverageMeter() 328 | top5 = AverageMeter() 329 | batch_time = AverageMeter() 330 | data_time = AverageMeter() 331 | model.train() 332 | 333 | end = time.time() 334 | for batch_idx, data in enumerate(train_loader): 335 | data_time.update(time.time() - end) 336 | if args.trans_mode == 'tv': 337 | x = data[0].cuda(non_blocking=True) 338 | target = data[1].cuda(non_blocking=True) 339 | elif args.trans_mode == 'dali': 340 | x = data[0]['data'].cuda(non_blocking=True) 341 | target = data[0]['label'].squeeze().cuda(non_blocking=True).long() 342 | 343 | # forward 344 | batch_start = time.time() 345 | logits = model(x) 346 | loss = criterion(logits, target) 347 | 348 | # backward 349 | optimizer.zero_grad() 350 | if args.opt_level is not None: 351 | with amp.scale_loss(loss, optimizer) as scaled_loss: 352 | scaled_loss.backward() 353 | else: 354 | loss.backward() 355 | optimizer.step() 356 | if ema is not None: ema.update() 357 | batch_time.update(time.time() - batch_start) 358 | 359 | if batch_idx % args.print_freq == 0: 360 | # For better performance, don't accumulate these metrics every iteration, 361 | # since they may incur an allreduce and some host<->device syncs. 
362 | prec1, prec5 = accuracy(logits, target, topk=(1, 5)) 363 | if args.distributed: 364 | reduced_loss = reduce_tensor(loss.data) 365 | prec1 = reduce_tensor(prec1) 366 | prec5 = reduce_tensor(prec5) 367 | else: 368 | reduced_loss = loss.data 369 | objs.update(reduced_loss.item(), x.size(0)) 370 | top1.update(prec1.item(), x.size(0)) 371 | top5.update(prec5.item(), x.size(0)) 372 | torch.cuda.synchronize() 373 | 374 | duration = 0 if batch_idx == 0 else time.time() - duration_start 375 | duration_start = time.time() 376 | if args.local_rank == 0: 377 | logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f Duration: %ds BTime: %.3fs DTime: %.4fs', 378 | batch_idx, objs.avg, top1.avg, top5.avg, duration, batch_time.avg, data_time.avg) 379 | 380 | adjust_learning_rate(optimizer, scheduler, epoch, batch_idx) 381 | end = time.time() 382 | 383 | return top1.avg, objs.avg 384 | 385 | 386 | def validate(val_loader, model, criterion): 387 | objs = AverageMeter() 388 | top1 = AverageMeter() 389 | top5 = AverageMeter() 390 | 391 | model.eval() 392 | 393 | for batch_idx, data in enumerate(val_loader): 394 | if args.trans_mode == 'tv': 395 | x = data[0].cuda(non_blocking=True) 396 | target = data[1].cuda(non_blocking=True) 397 | elif args.trans_mode == 'dali': 398 | x = data[0]['data'].cuda(non_blocking=True) 399 | target = data[0]['label'].squeeze().cuda(non_blocking=True).long() 400 | 401 | with torch.no_grad(): 402 | logits = model(x) 403 | loss = criterion(logits, target) 404 | 405 | prec1, prec5 = accuracy(logits, target, topk=(1, 5)) 406 | if args.distributed: 407 | reduced_loss = reduce_tensor(loss.data) 408 | prec1 = reduce_tensor(prec1) 409 | prec5 = reduce_tensor(prec5) 410 | else: 411 | reduced_loss = loss.data 412 | objs.update(reduced_loss.item(), x.size(0)) 413 | top1.update(prec1.item(), x.size(0)) 414 | top5.update(prec5.item(), x.size(0)) 415 | 416 | if args.local_rank == 0 and batch_idx % args.print_freq == 0: 417 | duration = 0 if batch_idx == 0 else 
time.time() - duration_start 418 | duration_start = time.time() 419 | logging.info('VALIDATE Step: %03d Objs: %e R1: %f R5: %f Duration: %ds', batch_idx, objs.avg, top1.avg, top5.avg, duration) 420 | 421 | return top1.avg, top5.avg, objs.avg 422 | 423 | 424 | def reduce_tensor(tensor): 425 | rt = tensor.clone() 426 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 427 | rt /= args.world_size 428 | return rt 429 | 430 | 431 | def get_lr_scheduler(optimizer): 432 | if args.lr_scheduler == 'linear_epoch': 433 | total_steps = args.epochs - args.warmup_epochs 434 | lambda_func = lambda step: max(1.0-step/float(total_steps), 0) 435 | scheduler = LambdaLRWithMin(optimizer, lambda_func, args.lr_min) 436 | elif args.lr_scheduler == 'linear_batch': 437 | total_steps = (args.epochs - args.warmup_epochs) * args.batches_per_epoch 438 | lambda_func = lambda step: max(1.0-step/float(total_steps), 0) 439 | scheduler = LambdaLRWithMin(optimizer, lambda_func, args.lr_min) 440 | elif args.lr_scheduler == 'cosine_epoch': 441 | total_steps = args.epochs - args.warmup_epochs 442 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(total_steps), args.lr_min) 443 | elif args.lr_scheduler == 'cosine_batch': 444 | total_steps = (args.epochs - args.warmup_epochs) * args.batches_per_epoch 445 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(total_steps), args.lr_min) 446 | elif args.lr_scheduler == 'step_epoch': 447 | assert args.lr_min > 0.0, 'the minimum lr must be larger than 0 for "step" lr_scheduler' 448 | total_steps = args.epochs - args.warmup_epochs 449 | gamma = (args.lr_min / args.lr) ** (1.0 / total_steps) 450 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma) 451 | elif args.lr_scheduler == 'step_batch': 452 | assert args.lr_min > 0.0, 'the minimum lr must be larger than 0 for "step" lr_scheduler' 453 | total_steps = (args.epochs - args.warmup_epochs) * args.batches_per_epoch 454 | gamma = (args.lr_min / args.lr) ** (1.0 / 
total_steps) 455 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma) 456 | else: 457 | raise Exception('invalid type fo lr scheduler') 458 | 459 | return scheduler 460 | 461 | 462 | def get_last_lr(optimizer): 463 | last_lrs = [param_group['lr'] for param_group in optimizer.param_groups] 464 | return last_lrs[0] 465 | 466 | 467 | def adjust_learning_rate(optimizer, scheduler, epoch, batch_idx): 468 | ''' 469 | batch_idx = -1: adjusts lr per epoch 470 | batch_idx >= 0: adjusts lr per batch 471 | ''' 472 | if args.lr_scheduler in ['linear_epoch', 'cosine_epoch', 'step_epoch']: 473 | if epoch < args.warmup_epochs: 474 | if batch_idx == -1: 475 | warmup_lr = float(epoch + 1) / (args.warmup_epochs + 1) * args.lr 476 | for param_group in optimizer.param_groups: 477 | param_group['lr'] = warmup_lr 478 | else: 479 | if batch_idx == -1: 480 | scheduler.step() 481 | 482 | if args.lr_scheduler in ['linear_batch', 'cosine_batch', 'step_batch']: 483 | if epoch < args.warmup_epochs: 484 | batch_idx = epoch * args.batches_per_epoch + batch_idx 485 | total_batches = args.warmup_epochs * args.batches_per_epoch 486 | warmup_lr = float(batch_idx + 2) / (total_batches + 1) * args.lr 487 | for param_group in optimizer.param_groups: 488 | param_group['lr'] = warmup_lr 489 | else: 490 | if batch_idx >= 0: 491 | scheduler.step() 492 | 493 | 494 | if __name__ == '__main__': 495 | main() 496 | -------------------------------------------------------------------------------- /train_example.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 python -u -m torch.distributed.launch --nproc_per_node=2 train.py \ 2 | --train_root "Your ImageNet Train Set Path" \ 3 | --val_root "Your ImageNet Val Set Path" \ 4 | --train_list "ImageNet Train List" \ 5 | --val_list "ImageNet Val List" \ 6 | --save './checkpoints/' \ 7 | --workers 16 \ 8 | --epochs 250 \ 9 | --warmup_epochs 5 \ 10 | --batch_size 512 \ 11 | --lr 0.2 \ 12 | 
--lr_min 0.0 \ 13 | --lr_scheduler 'cosine_epoch' \ 14 | --momentum 0.9 \ 15 | --weight_decay 3e-5 \ 16 | --no_wd_bias_bn \ 17 | --model 'MobileNetV3_Large' \ 18 | --num_classes 1000 \ 19 | --dropout_rate 0.2 \ 20 | --label_smooth 0.1 \ 21 | --trans_mode 'tv' \ 22 | --color_jitter \ 23 | --ema_decay 0.9999 \ 24 | --opt_level 'O1' \ 25 | --note 'try' 26 | 27 | 28 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import random 5 | import numpy as np 6 | 7 | 8 | def set_seed(seed): 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | 15 | 16 | class AverageMeter(object): 17 | """ 18 | Computes and stores the average and current value 19 | Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py 20 | """ 21 | def __init__(self): 22 | self.val = 0 23 | self.avg = 0 24 | self.sum = 0 25 | self.count = 0 26 | 27 | def reset(self): 28 | self.val = 0 29 | self.avg = 0 30 | self.sum = 0 31 | self.count = 0 32 | 33 | def update(self, val, n=1): 34 | self.val = val 35 | self.sum += val * n 36 | self.count += n 37 | self.avg = self.sum / self.count 38 | 39 | 40 | def accuracy(output, target, topk=(1,)): 41 | """ Computes the precision@k for the specified values of k """ 42 | maxk = max(topk) 43 | batch_size = target.size(0) 44 | 45 | _, pred = output.topk(maxk, 1, True, True) 46 | pred = pred.t() 47 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 48 | 49 | res = [] 50 | for k in topk: 51 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 52 | res.append(correct_k.mul_(100.0 / batch_size)) 53 | return res 54 | 55 | 56 | def save_checkpoint(state, is_best, save): 57 | filename = os.path.join(save, 'checkpoint.pth.tar') 58 | torch.save(state, filename) 59 | if 
is_best: 60 | best_filename = os.path.join(save, 'model_best.pth.tar') 61 | shutil.copyfile(filename, best_filename) 62 | 63 | 64 | def create_exp_dir(path, scripts_to_save=None): 65 | if not os.path.exists(path): 66 | os.makedirs(path) 67 | print('Experiment dir : {}'.format(path)) 68 | 69 | if scripts_to_save is not None: 70 | os.makedirs(os.path.join(path, 'scripts')) 71 | for script in scripts_to_save: 72 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 73 | shutil.copyfile(script, dst_file) 74 | 75 | 76 | def get_params(model): 77 | params_no_weight_decay = [] 78 | params_weight_decay = [] 79 | for pname, p in model.named_parameters(): 80 | if pname.find('weight') >= 0 and len(p.size()) > 1: 81 | # print('include ', pname, p.size()) 82 | params_weight_decay.append(p) 83 | else: 84 | # print('not include ', pname, p.size()) 85 | params_no_weight_decay.append(p) 86 | assert len(list(model.parameters())) == len(params_weight_decay) + len(params_no_weight_decay) 87 | params = [dict(params=params_weight_decay), dict(params=params_no_weight_decay, weight_decay=0.)] 88 | return params 89 | 90 | 91 | class EMA(): 92 | def __init__(self, model, decay): 93 | self.model = model 94 | self.decay = decay 95 | self.shadow = {} 96 | 97 | def register(self): 98 | for name, state in self.model.state_dict().items(): 99 | self.shadow[name] = state.clone() 100 | 101 | def update(self): 102 | for name, state in self.model.state_dict().items(): 103 | assert name in self.shadow 104 | new_average = (1.0 - self.decay) * state + self.decay * self.shadow[name] 105 | self.shadow[name] = new_average.clone() 106 | del new_average 107 | 108 | def state_dict(self): 109 | return self.shadow 110 | 111 | def load_state_dict(self, state_dict): 112 | for name, state in state_dict.items(): 113 | self.shadow[name] = state.clone() --------------------------------------------------------------------------------