├── .gitignore
├── LICENSE
├── README.md
├── benchmark
├── __init__.py
├── compute_flops.py
├── compute_madd.py
├── compute_memory.py
├── compute_speed.py
├── model_hook.py
├── reporter.py
├── stat_tree.py
└── statistics.py
├── datasets.py
├── losses.py
├── lr_scheduler.py
├── networks.py
├── profile_example.py
├── test.py
├── test_example.sh
├── train.py
├── train_example.sh
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | checkpoints*/
132 | log.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Aber Hu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ImageNet-training
2 |
3 | PyTorch ImageNet training code with various tricks, lr schedulers, distributed training, mixed precision training, DALI dataloader, etc. We hope this repo can help with ImageNet experiments in NAS research.
4 |
5 | ## Train
6 | ```
7 | CUDA_VISIBLE_DEVICES=0 python -u train.py --train_root /path/to/imagenet/train_set --val_root /path/to/imagenet/val_set --train_list /path/to/imagenet/train_list --val_list /path/to/imagenet/val_list
8 | ```
9 |
10 | Please refer to [train_example.sh](https://github.com/AberHu/ImageNet-training/blob/master/train_example.sh) for more details.
11 |
12 | ## Test
13 | ```
14 | CUDA_VISIBLE_DEVICES=0 python -u test.py --val_root /path/to/imagenet/val_set --val_list /path/to/imagenet/val_list --weights /path/to/pretrained_weights
15 | ```
16 |
17 | Please refer to [test_example.sh](https://github.com/AberHu/ImageNet-training/blob/master/test_example.sh) for more details.
18 |
19 | ## Model Profiling
20 | Please refer to [profile_example.py](https://github.com/AberHu/ImageNet-training/blob/master/profile_example.py) for more details.
21 |
22 | ## Tested on
23 | Python == 3.7.6
24 | pytorch == 1.5.1
25 | torchvision == 0.6.1
26 | nvidia.dali == 0.22.0
27 | cuDNN == 7.6.5
28 | apex from [this link](https://github.com/NVIDIA/apex.git)
29 |
30 | ## License
31 | This repo is released under the MIT license. Please see the [LICENSE](https://github.com/AberHu/ImageNet-training/blob/master/LICENSE) file for more information.
32 |
--------------------------------------------------------------------------------
/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 | from .compute_speed import compute_speed
2 | from .compute_memory import compute_memory
3 | from .compute_madd import compute_madd
4 | from .compute_flops import compute_flops
5 | from .stat_tree import StatTree, StatNode
6 | from .model_hook import ModelHook
7 | from .statistics import ModelStat, stat
8 | from .reporter import report_format
9 |
# Public API of the benchmark package.
# NOTE: every entry needs its own comma; adjacent string literals are
# silently concatenated by Python, which would hide two names at once.
__all__ = [
    'StatTree', 'StatNode', 'ModelHook', 'ModelStat', 'stat', 'report_format',
    'compute_speed', 'compute_memory', 'compute_madd', 'compute_flops',
]
--------------------------------------------------------------------------------
/benchmark/compute_flops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import sys
5 | sys.path.append('..')
6 | from networks import HSwish, HSigmoid, Swish, Sigmoid
7 |
8 |
def compute_flops(module, inp, out):
    """Dispatch to the FLOPs counter that matches the type of ``module``.

    Returns a ``(flops, category)`` pair. For unsupported modules a notice
    is printed and ``(0, -1)`` is returned.
    """
    # The first matching row wins, exactly like an if/elif chain.
    dispatch = (
        ((nn.ReLU, nn.ReLU6, nn.PReLU, nn.LeakyReLU), compute_ReLU_flops, 'Activation'),
        (nn.ELU, compute_ELU_flops, 'Activation'),
        (Sigmoid, compute_Sigmoid_flops, 'Activation'),
        (HSigmoid, compute_HSigmoid_flops, 'Activation'),
        (Swish, compute_Swish_flops, 'Activation'),
        (HSwish, compute_HSwish_flops, 'Activation'),
        (nn.Conv2d, compute_Conv2d_flops, 'Conv2d'),
        (nn.ConvTranspose2d, compute_ConvTranspose2d_flops, 'ConvTranspose2d'),
        (nn.BatchNorm2d, compute_BatchNorm2d_flops, 'BatchNorm2d'),
        ((nn.AvgPool2d, nn.MaxPool2d), compute_Pool2d_flops, 'Pool2d'),
        ((nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d), compute_AdaptivePool2d_flops, 'Pool2d'),
        (nn.Linear, compute_Linear_flops, 'Linear'),
    )
    for module_types, counter, category in dispatch:
        if isinstance(module, module_types):
            return counter(module, inp, out), category

    print("[Flops]: {} is not supported!".format(type(module).__name__))
    return 0, -1
38 |
39 |
def compute_ReLU_flops(module, inp, out):
    """Return one operation per element of ``inp`` (batch included)."""
    assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.LeakyReLU))

    shape = inp.size()
    # batch size times the element count of the remaining dimensions
    return shape[0] * shape[1:].numel()
50 |
51 |
def compute_ELU_flops(module, inp, out):
    """Return three operations per element of ``inp`` (batch included)."""
    assert isinstance(module, nn.ELU)

    shape = inp.size()
    element_count = shape[0] * shape[1:].numel()
    return element_count * 3
63 |
64 |
def compute_Sigmoid_flops(module, inp, out):
    """Return four operations per element of ``inp`` (batch included)."""
    assert isinstance(module, Sigmoid)

    shape = inp.size()
    element_count = shape[0] * shape[1:].numel()
    return element_count * 4
76 |
77 |
def compute_HSigmoid_flops(module, inp, out):
    """Return ``2 + 1`` operations per element of ``inp`` (batch included)."""
    assert isinstance(module, HSigmoid)

    shape = inp.size()
    element_count = shape[0] * shape[1:].numel()
    return element_count * (2 + 1)
89 |
90 |
def compute_Swish_flops(module, inp, out):
    """Return ``1 + 4`` operations per element of ``inp`` (batch included)."""
    assert isinstance(module, Swish)

    shape = inp.size()
    element_count = shape[0] * shape[1:].numel()
    return element_count * (1 + 4)
102 |
103 |
def compute_HSwish_flops(module, inp, out):
    """Return ``1 + 3`` operations per element of ``inp`` (batch included)."""
    assert isinstance(module, HSwish)

    shape = inp.size()
    element_count = shape[0] * shape[1:].numel()
    return element_count * (1 + 3)
115 |
116 |
def compute_Conv2d_flops(module, inp, out):
    """Count FLOPs of a ``nn.Conv2d`` from its 4-D input and output tensors.

    The kernel cost is ``k_h * k_w * in_c * out_c // groups`` per output
    position; when the layer has a bias, one more operation per output
    element is added.
    """
    assert isinstance(module, nn.Conv2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    n, c_in = inp.size()[0], inp.size()[1]
    kernel_h, kernel_w = module.kernel_size
    c_out, h_out, w_out = out.size()[1:]

    positions = n * h_out * w_out
    per_position = kernel_h * kernel_w * c_in * c_out // module.groups

    flops = per_position * positions
    if module.bias is not None:
        flops += c_out * positions
    return flops
138 |
139 |
def compute_ConvTranspose2d_flops(module, inp, out):
    """Count FLOPs of a ``nn.ConvTranspose2d`` from its 4-D input/output.

    A transposed convolution applies its kernel once per *input* spatial
    position, so the kernel cost ``k_h * k_w * in_c * out_c // groups`` is
    multiplied by ``batch * in_h * in_w``. Counting per output position
    would over-count by roughly ``stride ** 2``. This matches the convention
    used by ``compute_ConvTranspose2d_madd``. The bias, if present, is added
    once per output element.

    Args:
        module: the ``nn.ConvTranspose2d`` layer.
        inp: 4-D input tensor of the layer.
        out: 4-D output tensor of the layer.

    Returns:
        Total FLOPs as an ``int``.
    """
    assert isinstance(module, nn.ConvTranspose2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    batch_size, in_c, in_h, in_w = inp.size()
    k_h, k_w = module.kernel_size
    out_c, out_h, out_w = out.size()[1:]
    groups = module.groups

    conv_per_position_flops = k_h * k_w * in_c * out_c // groups
    total_conv_flops = conv_per_position_flops * batch_size * in_h * in_w

    bias_flops = 0
    if module.bias is not None:
        # one addition per produced output element
        bias_flops = out_c * batch_size * out_h * out_w

    return total_conv_flops + bias_flops
161 |
162 |
def compute_BatchNorm2d_flops(module, inp, out):
    """Return one operation per input element, doubled for affine layers."""
    assert isinstance(module, nn.BatchNorm2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    element_count = np.prod(inp.shape)
    return element_count * 2 if module.affine else element_count
172 |
173 |
def compute_Pool2d_flops(module, inp, out):
    """Return ``k_h * k_w`` operations per output element of a 2-D pool."""
    assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d))
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    kernel = module.kernel_size
    # kernel_size may be a single int or an (h, w) pair
    k_h, k_w = kernel if isinstance(kernel, (tuple, list)) else (kernel, kernel)

    n = inp.size()[0]
    c_out, h_out, w_out = out.size()[1:]
    return n * c_out * h_out * w_out * k_h * k_w
188 |
189 |
def compute_AdaptivePool2d_flops(module, inp, out):
    """Return one operation per input element for adaptive 2-D pooling.

    The result is the total element count of ``inp`` (batch included).
    A kernel-size based estimate that was previously computed here was
    never returned, so that dead computation has been dropped.

    Args:
        module: ``nn.AdaptiveAvgPool2d`` or ``nn.AdaptiveMaxPool2d``.
        inp: 4-D input tensor of the layer.
        out: 4-D output tensor of the layer (only its rank is checked).
    """
    assert isinstance(module, (nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d))
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    return np.prod(inp.shape)
203 |
204 |
def compute_Linear_flops(module, inp, out):
    """Count FLOPs of a ``nn.Linear`` from its 2-D input and output.

    ``batch * in_features * out_features`` for the matrix product, plus
    ``batch * out_features`` when the layer has a bias.
    """
    assert isinstance(module, nn.Linear)
    assert len(inp.size()) == 2 and len(out.size()) == 2

    n, features_in = inp.size()[0], inp.size()[1]
    features_out = out.size()[1]

    flops = n * features_in * features_out
    if module.bias is not None:
        flops += n * features_out
    return flops
221 |
222 |
--------------------------------------------------------------------------------
/benchmark/compute_madd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import sys
5 | sys.path.append('..')
6 | from networks import HSwish, HSigmoid, Swish, Sigmoid
7 |
def compute_madd(module, inp, out):
    """Dispatch to the multiply-add counter matching ``module``'s type.

    Unsupported modules print a notice and yield ``0``.
    """
    # The first matching row wins, exactly like an if/elif chain.
    dispatch = (
        ((nn.ReLU, nn.ReLU6, nn.LeakyReLU, nn.PReLU), compute_ReLU_madd),
        (nn.ELU, compute_ELU_madd),
        (Sigmoid, compute_Sigmoid_madd),
        (HSigmoid, compute_HSigmoid_madd),
        (Swish, compute_Swish_madd),
        (HSwish, compute_HSwish_madd),
        (nn.Conv2d, compute_Conv2d_madd),
        (nn.ConvTranspose2d, compute_ConvTranspose2d_madd),
        (nn.BatchNorm2d, compute_BatchNorm2d_madd),
        (nn.Linear, compute_Linear_madd),
        (nn.MaxPool2d, compute_MaxPool2d_madd),
        (nn.AdaptiveMaxPool2d, compute_AdaptiveMaxPool2d_madd),
        (nn.AvgPool2d, compute_AvgPool2d_madd),
        (nn.AdaptiveAvgPool2d, compute_AdaptiveAvgPool2d_madd),
    )
    for module_types, counter in dispatch:
        if isinstance(module, module_types):
            return counter(module, inp, out)

    print("[MAdd]: {} is not supported!".format(type(module).__name__))
    return 0
40 |
41 |
def compute_ReLU_madd(module, inp, out):
    """Return one operation per element of a single sample of ``inp``.

    The batch dimension (dim 0) is excluded from the count.

    ``compute_madd`` routes ``nn.LeakyReLU`` and ``nn.PReLU`` here as well,
    so the type check must accept them; it previously admitted only
    ``nn.ReLU``/``nn.ReLU6`` and raised ``AssertionError`` for the others.
    """
    assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.LeakyReLU, nn.PReLU))

    # product of all non-batch dimensions (1 for a 1-D input)
    return inp.size()[1:].numel()
50 |
51 |
def compute_ELU_madd(module, inp, out):
    """Return 2 multiplies + 1 add per element of one sample of ``inp``."""
    assert isinstance(module, nn.ELU)

    per_sample = inp.size()[1:].numel()
    muls = 2 * per_sample
    adds = per_sample
    return muls + adds
62 |
63 |
def compute_Sigmoid_madd(module, inp, out):
    """Return 3 multiplies + 1 add per element of one sample of ``inp``."""
    assert isinstance(module, Sigmoid)

    per_sample = inp.size()[1:].numel()
    muls = 3 * per_sample
    adds = per_sample
    return muls + adds
74 |
75 |
def compute_HSigmoid_madd(module, inp, out):
    """Return 2 multiplies + 1 add per element of one sample of ``inp``."""
    assert isinstance(module, HSigmoid)

    per_sample = inp.size()[1:].numel()
    muls = 2 * per_sample
    adds = per_sample
    return muls + adds
86 |
87 |
def compute_Swish_madd(module, inp, out):
    """Return 4 multiplies + 1 add per element of one sample of ``inp``."""
    assert isinstance(module, Swish)

    per_sample = inp.size()[1:].numel()
    muls = per_sample + 3 * per_sample
    adds = per_sample
    return muls + adds
98 |
99 |
def compute_HSwish_madd(module, inp, out):
    """Return 3 multiplies + 1 add per element of one sample of ``inp``."""
    assert isinstance(module, HSwish)

    per_sample = inp.size()[1:].numel()
    muls = per_sample + 2 * per_sample
    adds = per_sample
    return muls + adds
110 |
111 |
def compute_Conv2d_madd(module, inp, out):
    """Count multiply-adds of a ``nn.Conv2d`` for a single sample.

    Each output element costs ``k_h * k_w * (in_c // groups)`` multiplies
    and one fewer additions, plus one addition when a bias is present.
    """
    assert isinstance(module, nn.Conv2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    c_in = inp.size()[1]
    kernel_h, kernel_w = module.kernel_size
    c_out, h_out, w_out = out.size()[1:]
    g = module.groups

    # cost of producing one output element
    muls = kernel_h * kernel_w * (c_in // g)
    adds = muls - 1 + (1 if module.bias is not None else 0)

    # number of output elements produced by one group
    outputs_per_group = h_out * w_out * (c_out // g)

    return muls * outputs_per_group * g + adds * outputs_per_group * g
132 |
133 |
def compute_ConvTranspose2d_madd(module, inp, out):
    """Count multiply-adds of a ``nn.ConvTranspose2d`` for a single sample.

    The per-kernel cost is multiplied by the *input* spatial size
    (``in_h * in_w``) and by the number of output channels.
    """
    assert isinstance(module, nn.ConvTranspose2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    c_in, h_in, w_in = inp.size()[1:]
    kernel_h, kernel_w = module.kernel_size
    c_out, _h_out, _w_out = out.size()[1:]
    g = module.groups

    muls = kernel_h * kernel_w * (c_in // g)
    adds = muls - 1 + (1 if module.bias is not None else 0)

    applications_per_group = h_in * w_in * (c_out // g)

    return muls * applications_per_group * g + adds * applications_per_group * g
153 |
154 |
def compute_BatchNorm2d_madd(module, inp, out):
    """Return four operations per element of one sample of ``inp``.

    The four steps counted are: subtract mean, divide by standard
    deviation, multiply by alpha, add beta.
    """
    assert isinstance(module, nn.BatchNorm2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    channels, height, width = inp.size()[1:]
    ops_per_element = 4
    return ops_per_element * channels * height * width
166 |
167 |
def compute_Linear_madd(module, inp, out):
    """Count multiply-adds of a ``nn.Linear`` for a single sample.

    Each output feature costs ``in_features`` multiplies and
    ``in_features - 1`` additions, plus one addition when a bias exists.
    """
    assert isinstance(module, nn.Linear)
    assert len(inp.size()) == 2 and len(out.size()) == 2

    features_in = inp.size()[1]
    features_out = out.size()[1]

    bias_add = 1 if module.bias is not None else 0
    per_output = features_in + (features_in - 1 + bias_add)
    return features_out * per_output
179 |
180 |
def compute_MaxPool2d_madd(module, inp, out):
    """Return ``k_h * k_w - 1`` operations per output element (one sample)."""
    assert isinstance(module, nn.MaxPool2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    kernel = module.kernel_size
    # kernel_size may be a single int or an (h, w) pair
    k_h, k_w = kernel if isinstance(kernel, (tuple, list)) else (kernel, kernel)
    c_out, h_out, w_out = out.size()[1:]

    per_output = k_h * k_w - 1
    return per_output * h_out * w_out * c_out
192 |
193 |
def compute_AdaptiveMaxPool2d_madd(module, inp, out):
    """Estimate operations of adaptive max pooling for a single sample.

    The kernel size is approximated as the rounded input/output size ratio;
    each output element then costs ``k_h * k_w - 1`` operations.
    """
    assert isinstance(module, nn.AdaptiveMaxPool2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    _c_in, h_in, w_in = inp.size()[1:]
    c_out, h_out, w_out = out.size()[1:]

    k_h = int(round(h_in / h_out))
    k_w = int(round(w_in / w_out))

    per_output = k_h * k_w - 1
    return per_output * h_out * w_out * c_out
204 |
205 |
def compute_AvgPool2d_madd(module, inp, out):
    """Count operations of ``nn.AvgPool2d`` for a single sample.

    Each output element costs ``k_h * k_w - 1`` additions plus one
    operation for the averaging step.
    """
    assert isinstance(module, nn.AvgPool2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    kernel = module.kernel_size
    # kernel_size may be a single int or an (h, w) pair
    k_h, k_w = kernel if isinstance(kernel, (tuple, list)) else (kernel, kernel)
    c_out, h_out, w_out = out.size()[1:]

    additions = k_h * k_w - 1
    averaging = 1
    return (additions + averaging) * h_out * w_out * c_out
220 |
221 |
def compute_AdaptiveAvgPool2d_madd(module, inp, out):
    """Estimate operations of adaptive average pooling for a single sample.

    The kernel size is approximated as the rounded input/output size ratio;
    each output element costs ``k_h * k_w - 1`` additions plus one
    operation for the averaging step.
    """
    assert isinstance(module, nn.AdaptiveAvgPool2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    _c_in, h_in, w_in = inp.size()[1:]
    c_out, h_out, w_out = out.size()[1:]

    k_h = int(round(h_in / h_out))
    k_w = int(round(w_in / w_out))

    additions = k_h * k_w - 1
    averaging = 1
    return (additions + averaging) * h_out * w_out * c_out
235 |
--------------------------------------------------------------------------------
/benchmark/compute_memory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import sys
5 | sys.path.append('..')
6 | from networks import HSwish, HSigmoid, Swish, Sigmoid
7 |
def compute_memory(module, inp, out):
    """Dispatch to the memory counter matching ``module``'s type.

    Returns whatever the selected counter returns (a ``(read, write)``
    pair). Unsupported modules print a notice and yield ``(0, 0)``.
    """
    # The first matching row wins, exactly like an if/elif chain.
    dispatch = (
        ((nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU), compute_ReLU_memory),
        (nn.PReLU, compute_PReLU_memory),
        ((Sigmoid, HSigmoid), compute_Sigmoid_memory),
        ((Swish, HSwish), compute_Swish_memory),
        (nn.Conv2d, compute_Conv2d_memory),
        (nn.ConvTranspose2d, compute_ConvTranspose2d_memory),
        (nn.BatchNorm2d, compute_BatchNorm2d_memory),
        (nn.Linear, compute_Linear_memory),
        ((nn.AvgPool2d, nn.MaxPool2d, nn.AdaptiveAvgPool2d,
          nn.AdaptiveMaxPool2d), compute_Pool2d_memory),
    )
    for module_types, counter in dispatch:
        if isinstance(module, module_types):
            return counter(module, inp, out)

    print("[Memory]: {} is not supported!".format(type(module).__name__))
    return 0, 0
33 |
34 |
def num_params(module):
    """Return the total element count of ``module``'s trainable parameters."""
    total = 0
    for param in module.parameters():
        # NOTE(review): frozen parameters (requires_grad=False) are skipped;
        # the original author questioned whether this filter is intended.
        if param.requires_grad:
            total += param.numel()
    return total
37 |
38 |
def compute_ReLU_memory(module, inp, out):
    """Return ``(read, write)`` element counts, each the size of ``inp``."""
    assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU))

    shape = inp.size()
    elements = shape[0] * shape[1:].numel()
    return (elements, elements)
46 |
47 |
def compute_PReLU_memory(module, inp, out):
    """Return ``(read, write)`` element counts for a ``nn.PReLU``.

    Reads count the per-sample input plus the module's parameters, scaled
    by the batch size; writes count every input element once.
    """
    assert isinstance(module, (nn.PReLU))

    shape = inp.size()
    n = shape[0]
    per_sample = shape[1:].numel()

    reads = n * (per_sample + num_params(module))
    writes = n * per_sample
    return (reads, writes)
55 |
56 |
def compute_Sigmoid_memory(module, inp, out):
    """Return ``(read, write)`` element counts, each the size of ``inp``."""
    assert isinstance(module, (Sigmoid, HSigmoid))

    shape = inp.size()
    elements = shape[0] * shape[1:].numel()
    return (elements, elements)
64 |
65 |
def compute_Swish_memory(module, inp, out):
    """Return ``(read, write)`` element counts for Swish / HSwish.

    The input is counted twice on the read side and once on the write side.
    """
    assert isinstance(module, (Swish, HSwish))

    shape = inp.size()
    n = shape[0]
    per_sample = shape[1:].numel()

    reads = n * (per_sample + per_sample)
    writes = n * per_sample
    return (reads, writes)
73 |
74 |
def compute_Conv2d_memory(module, inp, out):
    """Return ``(read, write)`` element counts for a ``nn.Conv2d``.

    Reads count the per-sample input plus the module's parameters (weights
    and, if present, bias), scaled by the batch size; writes count every
    output element.
    """
    assert isinstance(module, nn.Conv2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    n = inp.size()[0]
    c_out, h_out, w_out = out.size()[1:]

    reads = n * (inp.size()[1:].numel() + num_params(module))
    writes = n * c_out * h_out * w_out
    return (reads, writes)
87 |
88 |
def compute_ConvTranspose2d_memory(module, inp, out):
    """Return ``(read, write)`` element counts for a ``nn.ConvTranspose2d``.

    Reads count the per-sample input plus the module's parameters (weights
    and, if present, bias), scaled by the batch size; writes count every
    output element.
    """
    assert isinstance(module, nn.ConvTranspose2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    n = inp.size()[0]
    c_out, h_out, w_out = out.size()[1:]

    reads = n * (inp.size()[1:].numel() + num_params(module))
    writes = n * c_out * h_out * w_out
    return (reads, writes)
101 |
102 |
def compute_BatchNorm2d_memory(module, inp, out):
    """Return ``(read, write)`` element counts for a ``nn.BatchNorm2d``.

    Reads count the per-sample input plus two values per channel, scaled
    by the batch size; writes count every input element once.
    """
    assert isinstance(module, nn.BatchNorm2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    n, channels, _height, _width = inp.size()
    per_sample = inp.size()[1:].numel()

    reads = n * (per_sample + 2 * channels)
    writes = inp.size().numel()
    return (reads, writes)
111 |
112 |
def compute_Linear_memory(module, inp, out):
    """Return ``(read, write)`` element counts for a ``nn.Linear``.

    Reads count the per-sample input plus the module's parameters, scaled
    by the batch size; writes count every output element.
    """
    assert isinstance(module, nn.Linear)
    assert len(inp.size()) == 2 and len(out.size()) == 2

    n = inp.size()[0]
    reads = n * (inp.size()[1:].numel() + num_params(module))
    writes = out.size().numel()
    return (reads, writes)
121 |
122 |
def compute_Pool2d_memory(module, inp, out):
    """Return ``(read, write)`` element counts for a 2-D pooling layer.

    Reads equal the element count of ``inp``; writes equal the input batch
    size times the per-sample element count of ``out``.
    """
    pool_types = (
        nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d)
    assert isinstance(module, pool_types)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    n = inp.size()[0]
    reads = n * inp.size()[1:].numel()
    writes = n * out.size()[1:].numel()
    return (reads, writes)
131 |
--------------------------------------------------------------------------------
/benchmark/compute_speed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 |
7 |
def compute_speed(model, input_size, device='cuda:0', iteration=1000):
    """Benchmark the forward-pass speed of ``model`` and print the result.

    The model is moved to ``device``, put in eval mode, warmed up for 100
    forward passes on a random input and then timed over ``iteration``
    passes. Elapsed time, milliseconds per iteration and FPS are printed;
    nothing is returned.

    Args:
        model: the ``torch.nn.Module`` to benchmark.
        input_size: list/tuple of 4 ints giving the input tensor shape; the
            first entry is used as the batch size for the FPS figure.
        device: device string or ``torch.device`` to run on.
        iteration: number of timed forward passes.

    Side effects:
        Sets the ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` environment
        variables to ``'1'``; on CUDA, enables cuDNN with benchmark mode
        and selects ``device`` as the current CUDA device.
    """
    assert isinstance(input_size, (list, tuple))
    assert len(input_size) == 4
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ['MKL_NUM_THREADS'] = '1'

    device = torch.device(device)
    on_cuda = 'cuda' in str(device)
    if on_cuda:
        cudnn.enabled = True
        cudnn.benchmark = True
        torch.cuda.set_device(device)

    model = model.to(device)
    model.eval()

    # Already created on the target device; no extra .to() needed.
    x = torch.randn(*input_size, device=device)

    # Inference only: skip autograd bookkeeping so it does not distort the
    # timing or hold on to activation memory.
    with torch.no_grad():
        # warmup for 100 iterations
        for _ in range(100):
            model(x)

        print('=============Speed Testing=============')
        print('Device: {}'.format(str(device)))
        if on_cuda:
            # CUDA kernels run asynchronously; drain the queue before timing.
            torch.cuda.synchronize()
        # perf_counter is monotonic and has higher resolution than time.time
        t_start = time.perf_counter()
        for _ in range(iteration):
            model(x)
        if on_cuda:
            torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - t_start

    print('Elapsed time: [%.2fs / %diter]' % (elapsed_time, iteration))
    print('Speed Time: %.2fms/iter FPS: %.2f' % (
        elapsed_time / iteration * 1000, iteration * input_size[0] / elapsed_time))
45 |
46 |
47 |
--------------------------------------------------------------------------------
/benchmark/model_hook.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import OrderedDict
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .compute_madd import compute_madd
8 | from .compute_flops import compute_flops
9 | from .compute_memory import compute_memory
10 |
11 |
12 | class ModelHook(object):
13 | def __init__(self, model, input_size):
14 | assert isinstance(model, nn.Module)
15 | assert isinstance(input_size, (list, tuple))
16 |
17 | self._model = model
18 | self._input_size = input_size
19 | self._origin_call = dict() # sub module call hook
20 |
21 | self._hook_model()
22 | x = torch.rand(*self._input_size) # add module duration time
23 | self._model.eval()
24 | self._model(x)
25 |
26 | @staticmethod
27 | def _register_buffer(module):
28 | assert isinstance(module, nn.Module)
29 |
30 | if len(list(module.children())) > 0:
31 | return
32 |
33 | module.register_buffer('input_shape', torch.zeros(3).int())
34 | module.register_buffer('output_shape', torch.zeros(3).int())
35 | module.register_buffer('parameter_quantity', torch.zeros(1).int())
36 | module.register_buffer('inference_memory', torch.zeros(1).long())
37 | module.register_buffer('MAdd', torch.zeros(1).long())
38 | module.register_buffer('duration', torch.zeros(1).float())
39 | module.register_buffer('ConvFlops', torch.zeros(1).long())
40 | module.register_buffer('Flops', torch.zeros(1).long())
41 | module.register_buffer('MemRead', torch.zeros(1).long())
42 | module.register_buffer('MemWrite', torch.zeros(1).long())
43 |
44 | def _sub_module_call_hook(self):
45 | def wrap_call(module, *input, **kwargs):
46 | assert module.__class__ in self._origin_call
47 |
48 | # Itemsize for memory
49 | itemsize = input[0].detach().numpy().itemsize
50 |
51 | # !!!!!! added by Aber Hu
52 | # Duration is not accurate, since it only runs 1 time, no warmup, no mulit runs.
53 | start = time.time()
54 | output = self._origin_call[module.__class__](module, *input, **kwargs)
55 | end = time.time()
56 | module.duration = torch.from_numpy(
57 | np.array([end - start], dtype=np.float32))
58 |
59 | module.input_shape = torch.from_numpy(
60 | np.array(input[0].size()[1:], dtype=np.int32))
61 | module.output_shape = torch.from_numpy(
62 | np.array(output.size()[1:], dtype=np.int32))
63 |
64 | parameter_quantity = 0
65 | # iterate through parameters and count num params
66 | for name, p in module._parameters.items():
67 | parameter_quantity += (0 if p is None else torch.numel(p))
68 | module.parameter_quantity = torch.from_numpy(
69 | np.array([parameter_quantity], dtype=np.long))
70 |
71 | inference_memory = 1
72 | for s in output.size()[1:]:
73 | inference_memory *= s
74 | # memory += parameters_number # exclude parameter memory
75 | # shown as MB unit
76 | inference_memory = inference_memory * itemsize / (1024 ** 2)
77 | module.inference_memory = torch.from_numpy(
78 | np.array([inference_memory], dtype=np.float32))
79 |
80 | if len(input) == 1:
81 | madd = compute_madd(module, input[0], output)
82 | conv_flops = 0
83 | flops, type = compute_flops(module, input[0], output)
84 | if type == 'Conv2d':
85 | conv_flops = flops
86 | memread, memwrite = compute_memory(module, input[0], output)
87 | elif len(input) > 1:
88 | madd = compute_madd(module, input, output)
89 | conv_flops = 0
90 | flops, type = compute_flops(module, input, output)
91 | if type == 'Conv2d':
92 | conv_flops = flops
93 | memread, memwrite = compute_memory(module, input, output)
94 | else: # error
95 | madd = 0
96 | flops = 0
97 | conv_flops = 0
98 | memread, memwrite = [0, 0]
99 | module.MAdd = torch.from_numpy(
100 | np.array([madd], dtype=np.int64))
101 | module.Flops = torch.from_numpy(
102 | np.array([flops], dtype=np.int64))
103 | module.ConvFlops = torch.from_numpy(
104 | np.array([conv_flops], dtype=np.int64))
105 | module.MemRead = torch.from_numpy(
106 | np.array([memread], dtype=np.int64)*itemsize)
107 | module.MemWrite = torch.from_numpy(
108 | np.array([memwrite], dtype=np.int64)*itemsize)
109 |
110 | return output
111 |
112 | for module in self._model.modules():
113 | if len(list(module.children())) == 0 and module.__class__ not in self._origin_call:
114 | self._origin_call[module.__class__] = module.__class__.__call__
115 | module.__class__.__call__ = wrap_call
116 |
117 | def _sub_module_call_unhook(self):
118 | for module in self._model.modules():
119 | if len(list(module.children())) == 0 and module.__class__ in self._origin_call:
120 | module.__class__.__call__ = self._origin_call[module.__class__]
121 |
122 | def _hook_model(self):
123 | self._model.apply(self._register_buffer)
124 | self._sub_module_call_hook()
125 |
126 | def _unhook_model(self):
127 | self._sub_module_call_unhook()
128 |
129 | @staticmethod
130 | def _retrieve_leaf_modules(model):
131 | leaf_modules = []
132 | for name, m in model.named_modules():
133 | if len(list(m.children())) == 0:
134 | leaf_modules.append((name, m))
135 | return leaf_modules
136 |
def retrieve_leaf_modules(self):
    """Return the leaf modules as an ``OrderedDict`` keyed by module name.

    The ``(name, module)`` pairs come from ``self._retrieve_leaf_modules``
    applied to ``self._model``.
    """
    leaf_pairs = self._retrieve_leaf_modules(self._model)
    return OrderedDict(leaf_pairs)
139 |
--------------------------------------------------------------------------------
/benchmark/reporter.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | pd.set_option('display.width', 1000)
4 | pd.set_option('display.max_rows', 10000)
5 | pd.set_option('display.max_columns', 10000)
6 |
7 |
def round_value(value, binary=False):
    """Format ``value`` with a K/M/G/T magnitude suffix.

    Args:
        value: Number to format.
        binary: When true the magnitude steps are powers of 1024,
            otherwise powers of 1000.

    Returns:
        ``value`` divided by the largest step it reaches, rounded to two
        decimals and followed by the matching suffix. Values that do not
        reach the first step (including zero and negative numbers) are
        returned as ``str(value)``.
    """
    base = 1024. if binary else 1000.

    for exponent, suffix in ((4, 'T'), (3, 'G'), (2, 'M')):
        step = base ** exponent
        if value // step > 0:
            return str(round(value / step, 2)) + suffix
    if value // base > 0:
        return str(round(value / base, 2)) + 'K'
    return str(value)
20 |
21 |
def report_format(collected_nodes, brief_report=False):
    """Render the statistics of ``collected_nodes`` as a printable report.

    Args:
        collected_nodes: Iterable of node objects exposing ``name``,
            ``input_shape``, ``output_shape``, ``parameter_quantity``,
            ``inference_memory``, ``MAdd``, ``Flops``, ``ConvFlops``,
            ``MemRead``, ``MemWrite`` and ``duration``.
        brief_report: When true only the total lines are returned;
            otherwise the per-node table precedes them.

    Returns:
        The report as a single string.
    """
    def format_shape(shape):
        # Right-align every dimension in a field of width 3.
        return ' '.join('{:>3d}'.format(dim) for dim in shape)

    rows = [
        [node.name, format_shape(node.input_shape),
         format_shape(node.output_shape), node.parameter_quantity,
         node.inference_memory, node.MAdd, node.duration, node.Flops,
         node.ConvFlops, node.MemRead, node.MemWrite]
        for node in collected_nodes
    ]
    df = pd.DataFrame(rows, columns=[
        'module name', 'input shape', 'output shape',
        'params', 'memory(MB)', 'MAdd', 'duration', 'Flops',
        'ConvFlops', 'MemRead(B)', 'MemWrite(B)'])
    # The small epsilon avoids a division by zero when all durations are 0.
    df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7)
    df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)']
    total_parameters_quantity = df['params'].sum()
    total_memory = df['memory(MB)'].sum()
    total_operation_quantity = df['MAdd'].sum()
    total_flops = df['Flops'].sum()
    total_conv_flops = df['ConvFlops'].sum()
    total_duration = df['duration[%]'].sum()
    total_mread = df['MemRead(B)'].sum()
    total_mwrite = df['MemWrite(B)'].sum()
    total_memrw = df['MemR+W(B)'].sum()
    del df['duration']

    # Add Total row.  DataFrame.append was removed in pandas 2.0, so the
    # series is turned into a one-row frame (indexed 'total') and
    # concatenated instead; columns missing from it become NaN.
    total_df = pd.Series([total_parameters_quantity, total_memory,
                          total_operation_quantity, total_flops,
                          total_conv_flops, total_duration,
                          total_mread, total_mwrite, total_memrw],
                         index=['params', 'memory(MB)', 'MAdd', 'Flops',
                                'ConvFlops', 'duration[%]',
                                'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'],
                         name='total')
    df = pd.concat([df, total_df.to_frame().T])

    df = df.fillna(' ')
    df['memory(MB)'] = df['memory(MB)'].apply(lambda x: '{:,.2f}'.format(x))
    df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x))
    for column in ('params', 'MAdd', 'Flops', 'ConvFlops',
                   'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'):
        df[column] = df[column].apply(lambda x: '{:,}'.format(x))

    params_line = "Total params: {}\n".format(
        round_value(total_parameters_quantity))
    totals_block = (
        "Total memory: {:.2f}MB\n".format(total_memory)
        + "Total MAdd: {}MAdd\n".format(round_value(total_operation_quantity))
        + "Total Flops: {}Flops\n".format(round_value(total_flops))
        + "Total Flops(Conv Only): {}Flops\n".format(
            round_value(total_conv_flops))
        + "Total MemR+W: {}B\n".format(round_value(total_memrw, True))
    )

    if brief_report:
        return params_line + totals_block

    table = str(df)
    # Separator rules are as wide as the table's header line.
    rule_width = len(table.split('\n')[0])
    return (table + '\n'
            + "=" * rule_width + '\n'
            + params_line
            + "-" * rule_width + '\n'
            + totals_block)
101 |
--------------------------------------------------------------------------------
/benchmark/stat_tree.py:
--------------------------------------------------------------------------------
1 | import queue
2 |
3 |
class StatTree(object):
    """Tree of stat nodes with helpers to select nodes by granularity."""

    def __init__(self, root_node):
        assert isinstance(root_node, StatNode)

        self.root_node = root_node

    def get_same_level_max_node_depth(self, query_node):
        """Return the largest ``depth`` among ``query_node`` and its siblings.

        A node whose name equals the root node's name yields 0.
        """
        if query_node.name == self.root_node.name:
            return 0
        sibling_depths = [sibling.depth
                          for sibling in query_node.parent.children]
        return max(sibling_depths)

    def update_stat_nodes_granularity(self):
        """Set ``granularity`` on every node, visiting the tree breadth-first."""
        frontier = [self.root_node]
        cursor = 0
        while cursor < len(frontier):
            current = frontier[cursor]
            cursor += 1
            current.granularity = self.get_same_level_max_node_depth(current)
            frontier.extend(current.children)

    def get_collected_stat_nodes(self, query_granularity):
        """Return the nodes to report for ``query_granularity``.

        Granularities are refreshed first. The tree is then walked
        depth-first (children in their stored order) and a node is kept
        when its depth equals ``query_granularity`` or when
        ``depth < query_granularity <= granularity``.
        """
        self.update_stat_nodes_granularity()

        selected = []
        pending = [self.root_node]
        while pending:
            current = pending.pop()
            # Push children reversed so the first child is popped first.
            pending.extend(reversed(current.children))
            depth = current.depth
            if (depth == query_granularity
                    or depth < query_granularity <= current.granularity):
                selected.append(current)
        return selected
40 |
41 |
class StatNode(object):
    """Node of the statistics tree.

    Each node stores its own values. The additive statistics (parameters,
    memory, MAdd, Flops, ConvFlops, MemRead, MemWrite, duration) are read
    through properties that return the node's own value plus the values
    of all its children.
    """

    def __init__(self, name=str(), parent=None):
        self._name = name
        self._input_shape = None
        self._output_shape = None
        self._parameter_quantity = 0
        self._inference_memory = 0
        self._MAdd = 0
        self._MemRead = 0
        self._MemWrite = 0
        self._Flops = 0
        self._ConvFlops = 0
        self._duration = 0
        self._duration_percent = 0

        self._granularity = 1
        self._depth = 1
        self.parent = parent
        self.children = list()

    def _subtree_total(self, public_name):
        """Add the children's ``public_name`` values onto this node's own."""
        running = getattr(self, '_' + public_name)
        for child in self.children:
            running += getattr(child, public_name)
        return running

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def granularity(self):
        return self._granularity

    @granularity.setter
    def granularity(self, g):
        self._granularity = g

    @property
    def depth(self):
        """Own depth plus the largest depth found among the children."""
        if not self.children:
            return self._depth
        return self._depth + max([child.depth for child in self.children])

    @property
    def input_shape(self):
        """Own input shape for a leaf, else the first child's input shape."""
        if self.children:
            return self.children[0].input_shape
        return self._input_shape

    @input_shape.setter
    def input_shape(self, input_shape):
        assert isinstance(input_shape, (list, tuple))
        self._input_shape = input_shape

    @property
    def output_shape(self):
        """Own output shape for a leaf, else the last child's output shape."""
        if self.children:
            return self.children[-1].output_shape
        return self._output_shape

    @output_shape.setter
    def output_shape(self, output_shape):
        assert isinstance(output_shape, (list, tuple))
        self._output_shape = output_shape

    @property
    def parameter_quantity(self):
        return self._subtree_total('parameter_quantity')

    @parameter_quantity.setter
    def parameter_quantity(self, parameter_quantity):
        assert parameter_quantity >= 0
        self._parameter_quantity = parameter_quantity

    @property
    def inference_memory(self):
        return self._subtree_total('inference_memory')

    @inference_memory.setter
    def inference_memory(self, inference_memory):
        self._inference_memory = inference_memory

    @property
    def MAdd(self):
        return self._subtree_total('MAdd')

    @MAdd.setter
    def MAdd(self, MAdd):
        self._MAdd = MAdd

    @property
    def Flops(self):
        return self._subtree_total('Flops')

    @Flops.setter
    def Flops(self, Flops):
        self._Flops = Flops

    @property
    def ConvFlops(self):
        return self._subtree_total('ConvFlops')

    @ConvFlops.setter
    def ConvFlops(self, ConvFlops):
        self._ConvFlops = ConvFlops

    @property
    def MemRead(self):
        return self._subtree_total('MemRead')

    @MemRead.setter
    def MemRead(self, MemRead):
        self._MemRead = MemRead

    @property
    def MemWrite(self):
        return self._subtree_total('MemWrite')

    @MemWrite.setter
    def MemWrite(self, MemWrite):
        self._MemWrite = MemWrite

    @property
    def duration(self):
        return self._subtree_total('duration')

    @duration.setter
    def duration(self, duration):
        self._duration = duration

    def find_child_index(self, child_name):
        """Return the index of the last child named ``child_name``, or -1."""
        assert isinstance(child_name, str)

        for position in range(len(self.children) - 1, -1, -1):
            if self.children[position].name == child_name:
                return position
        return -1

    def add_child(self, node):
        """Append ``node`` unless a child with the same name already exists."""
        assert isinstance(node, StatNode)

        already_present = self.find_child_index(node.name) != -1
        if not already_present:
            self.children.append(node)
212 |
--------------------------------------------------------------------------------
/benchmark/statistics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from collections import OrderedDict
4 |
5 | from .model_hook import ModelHook
6 | from .stat_tree import StatTree, StatNode
7 | from .reporter import report_format
8 |
9 |
def get_parent_node(root_node, stat_node_name):
    """Walk from ``root_node`` to the parent of the node ``stat_node_name``.

    ``stat_node_name`` is a dot-separated path. Every proper prefix of it
    (``a``, ``a.b``, ...) must already exist as a child along the way;
    a name without dots therefore resolves to ``root_node`` itself.
    """
    assert isinstance(root_node, StatNode)

    current = root_node
    parts = stat_node_name.split('.')
    for end in range(1, len(parts)):
        prefix = '.'.join(parts[:end])
        position = current.find_child_index(prefix)
        assert position != -1
        current = current.children[position]
    return current
21 |
22 |
def convert_leaf_modules_to_stat_tree(leaf_modules):
    """Build a ``StatTree`` from an ordered mapping of leaf modules.

    Args:
        leaf_modules: ``OrderedDict`` mapping dot-separated module names to
            modules that carry the statistic tensors ``input_shape``,
            ``output_shape``, ``parameter_quantity``, ``inference_memory``,
            ``MAdd``, ``Flops``, ``ConvFlops``, ``duration``, ``MemRead``
            and ``MemWrite``.

    Returns:
        A ``StatTree`` rooted at a node named ``'root'`` that holds one
        node per name prefix; the statistics are copied onto the node of
        each full leaf name.
    """
    assert isinstance(leaf_modules, OrderedDict)

    root_node = StatNode(name='root', parent=None)
    for leaf_module_name, leaf_module in leaf_modules.items():
        names = leaf_module_name.split('.')
        last = len(names) - 1
        for i in range(len(names)):
            stat_node_name = '.'.join(names[:i + 1])
            parent_node = get_parent_node(root_node, stat_node_name)
            node = StatNode(name=stat_node_name, parent=parent_node)
            # add_child ignores the node when the parent already has a
            # child of that name (shared prefixes of earlier leaves).
            parent_node.add_child(node)
            if i == last:  # leaf module itself
                node.input_shape = leaf_module.input_shape.numpy().tolist()
                node.output_shape = leaf_module.output_shape.numpy().tolist()
                node.parameter_quantity = leaf_module.parameter_quantity.numpy()[0]
                node.inference_memory = leaf_module.inference_memory.numpy()[0]
                node.MAdd = leaf_module.MAdd.numpy()[0]
                node.Flops = leaf_module.Flops.numpy()[0]
                node.ConvFlops = leaf_module.ConvFlops.numpy()[0]
                node.duration = leaf_module.duration.numpy()[0]
                node.MemRead = leaf_module.MemRead.numpy()[0]
                node.MemWrite = leaf_module.MemWrite.numpy()[0]
    return StatTree(root_node)
48 |
49 |
class ModelStat(object):
    """Collects per-module statistics of a model and prints them as a report.

    Args:
        model: The ``nn.Module`` to analyse.
        input_size: Tuple or list with exactly four entries, forwarded to
            ``ModelHook``.
        query_granularity: Depth at which nodes are reported; values
            outside ``1..root depth`` fall back to the root depth.
        brief_report: Forwarded to ``report_format``.
    """

    def __init__(self, model, input_size, query_granularity=1, brief_report=False):
        assert isinstance(model, nn.Module)
        assert isinstance(input_size, (tuple, list)) and len(input_size) == 4
        self.model_hook = ModelHook(model, input_size)
        self.leaf_modules = self.model_hook.retrieve_leaf_modules()
        self.stat_tree = convert_leaf_modules_to_stat_tree(self.leaf_modules)
        self._brief_report = brief_report
        # The property setter applies the range clamping.
        self.query_granularity = query_granularity

    def show_report(self):
        """Print the report for the current granularity."""
        nodes = self.stat_tree.get_collected_stat_nodes(self._query_granularity)
        print(report_format(nodes, self._brief_report))

    def unhook_model(self):
        """Remove the hooks through ``self.model_hook._unhook_model``."""
        self.model_hook._unhook_model()

    @property
    def query_granularity(self):
        return self._query_granularity

    @query_granularity.setter
    def query_granularity(self, query_granularity):
        max_depth = self.stat_tree.root_node.depth
        in_range = 1 <= query_granularity <= max_depth
        self._query_granularity = query_granularity if in_range else max_depth

    @property
    def brief_report(self):
        return self._brief_report

    @brief_report.setter
    def brief_report(self, brief_report):
        self._brief_report = brief_report
90 |
91 |
def stat(model, input_size, query_granularity=1, brief_report=False):
    """Print a statistics report for ``model`` and remove the hooks again.

    Args:
        model: The ``nn.Module`` to analyse.
        input_size: Four-element tuple or list describing the input.
        query_granularity: Reporting depth, see ``ModelStat``.
        brief_report: When true only the totals are printed.
    """
    ms = ModelStat(model, input_size, query_granularity, brief_report)
    try:
        ms.show_report()
    finally:
        # Always unhook, even if building or printing the report fails.
        ms.unhook_model()
96 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from PIL import Image
4 | import torch.utils.data as data
5 | import torchvision.transforms as transforms
6 | import nvidia.dali.pipeline as pipeline
7 | import nvidia.dali.ops as ops
8 | import nvidia.dali.types as types
# Per-channel mean and standard deviation used for input normalization in
# this file: passed as-is to transforms.Normalize and, multiplied by 255, to
# the DALI CropMirrorNormalize operators.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
11 |
12 |
# If PIL emits "UserWarning: Corrupt EXIF data", use cv2_loader instead or ignore the warning.
def pil_loader(path):
    """Load the image at ``path`` with PIL and return it in RGB mode.

    The file is opened explicitly and closed when the ``with`` block ends,
    so no file handle is left for the garbage collector to clean up.
    ``convert`` is called while the file is still open.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
17 |
def cv2_loader(path):
    """Load the image at ``path`` with OpenCV and return it as an RGB PIL image.

    Raises:
        IOError: If ``cv2.imread`` returns ``None`` for ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        # Fail with the offending path instead of an opaque cvtColor error.
        raise IOError('cv2.imread failed to read image: {}'.format(path))
    # OpenCV decodes to BGR channel order; PIL expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(img)
23 |
# Loader that ImageList uses when no explicit ``loader`` argument is given.
default_loader = pil_loader
25 |
26 |
def default_list_reader(list_path):
    """Read an image list file into ``(path, label)`` tuples.

    Every non-blank line holds an image path followed by an integer label,
    separated by whitespace. The line is split once from the right, so
    paths may contain spaces and any run of spaces or tabs may separate
    the two fields. Blank lines are skipped.

    Args:
        list_path: Path of the list file.

    Returns:
        List of ``(img_path, label)`` tuples in file order.

    Raises:
        ValueError: If a non-blank line has no label field or the label is
            not an integer.
    """
    img_list = []
    with open(list_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            img_path, label = line.rsplit(None, 1)
            img_list.append((img_path, int(label)))

    return img_list
35 |
36 |
class ImageList(data.Dataset):
    """Dataset backed by a list file of ``(relative path, label)`` entries.

    Args:
        root: Directory that the listed image paths are joined onto.
        list_path: File handed to ``list_reader``.
        transform: Optional callable applied to every loaded image.
        list_reader: Callable turning ``list_path`` into the entry list.
        loader: Callable that loads an image from a full path.
    """

    def __init__(self, root, list_path, transform=None, list_reader=default_list_reader, loader=default_loader):
        self.root = root
        self.img_list = list_reader(list_path)
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, target)`` for the entry at ``index``."""
        rel_path, target = self.img_list[index]
        img = self.loader(os.path.join(self.root, rel_path))

        if not self.transform:
            return img, target
        return self.transform(img), target

    def __len__(self):
        """Number of entries in the list."""
        return len(self.img_list)
55 |
56 |
def get_train_transform(coji=False):
    """Compose the training-time transform.

    The chain is a random resized crop to 224, a horizontal flip with
    probability 0.5, optionally a color jitter (brightness, contrast and
    saturation of 0.4 when ``coji`` is true), tensor conversion and
    normalization with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(0.5),
    ]
    if coji:
        steps.append(
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4))
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])

    return transforms.Compose(steps)
71 |
72 |
def get_val_transform():
    """Compose the validation-time transform.

    Resizes to 256, center-crops to 224, converts to a tensor and
    normalizes with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
83 |
84 |
class HybridTrainPipe(pipeline.Pipeline):
    """DALI training pipeline.

    The graph reads files, decodes them with a random crop, resizes to
    ``crop`` x ``crop``, optionally applies a color twist, and finishes with
    crop/mirror/normalize.

    Args:
        batch_size, num_threads, device_id: Forwarded to ``Pipeline``; the
            pipeline seed is ``12 + device_id``.
        root: Passed to the file reader as ``file_root``.
        list_path: Passed to the file reader as ``file_list``.
        crop: Side length used for the resize and the final crop.
        shard_id, num_shards: Passed to the file reader.
        coji: When true, a ColorTwist with brightness, contrast and
            saturation drawn uniformly from [0.6, 1.4] is inserted.
        dali_cpu: When true all operators run on 'cpu'; otherwise the
            decoder uses 'mixed' and the other image operators use 'gpu'.
    """
    def __init__(self, batch_size, num_threads, device_id, root, list_path,
                 crop, shard_id, num_shards, coji=False, dali_cpu=False):
        super(HybridTrainPipe, self).__init__(batch_size,
                                              num_threads,
                                              device_id,
                                              seed=12 + device_id)
        self.read = ops.FileReader(file_root=root,
                                   file_list=list_path,
                                   shard_id=shard_id,
                                   num_shards=num_shards,
                                   random_shuffle=True,
                                   initial_fill=1024)
        # Let user decide which pipeline works
        dali_device = 'cpu' if dali_cpu else 'gpu'
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        # This padding sets the size of the internal nvJPEG buffers to be able to handle all images
        # from full-sized ImageNet without additional reallocations
        device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
        host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
        self.decode = ops.ImageDecoderRandomCrop(device=decoder_device, output_type=types.RGB,
                                                 device_memory_padding=device_memory_padding,
                                                 host_memory_padding=host_memory_padding,
                                                 random_aspect_ratio=[0.75, 1.33333333],
                                                 random_area=[0.08, 1.0],
                                                 num_attempts=100)
        self.resize = ops.Resize(device=dali_device,
                                 resize_x=crop,
                                 resize_y=crop,
                                 interp_type=types.INTERP_TRIANGULAR)
        # Mean/std are scaled by 255 here; presumably the decoded pixels are
        # in the 0..255 range at this point - TODO confirm.
        self.cmnp = ops.CropMirrorNormalize(device=dali_device,
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[x*255 for x in IMAGENET_MEAN],
                                            std=[x*255 for x in IMAGENET_STD])
        # Its output is passed as the ``mirror`` argument in define_graph.
        self.coin = ops.CoinFlip(probability=0.5)

        self.coji = coji
        if self.coji:
            self.twist = ops.ColorTwist(device=dali_device)
            self.brightness_rng = ops.Uniform(range=[1.0-0.4, 1.0+0.4])
            self.contrast_rng = ops.Uniform(range=[1.0-0.4, 1.0+0.4])
            self.saturation_rng = ops.Uniform(range=[1.0-0.4, 1.0+0.4])

    def define_graph(self):
        """Wire the operators together and return ``[imgs, targets]``."""
        rng = self.coin()
        imgs, targets = self.read(name="Reader")
        imgs = self.decode(imgs)
        imgs = self.resize(imgs)
        if self.coji:
            brightness = self.brightness_rng()
            contrast = self.contrast_rng()
            saturation = self.saturation_rng()
            imgs = self.twist(imgs, brightness=brightness, contrast=contrast, saturation=saturation)
        imgs = self.cmnp(imgs, mirror=rng)
        return [imgs, targets]
143 |
144 |
class HybridValPipe(pipeline.Pipeline):
    """DALI validation pipeline.

    The graph reads files without shuffling, decodes them, resizes the
    shorter side to ``size`` and finishes with crop/normalize.

    Args:
        batch_size, num_threads, device_id: Forwarded to ``Pipeline``; the
            pipeline seed is ``12 + device_id``.
        root: Passed to the file reader as ``file_root``.
        list_path: Passed to the file reader as ``file_list``.
        size: Target length of the shorter image side for the resize.
        crop: Side length of the final crop.
        shard_id, num_shards: Passed to the file reader.
        dali_cpu: When true all operators run on 'cpu'; otherwise the
            decoder uses 'mixed' and the other image operators use 'gpu'.
    """
    def __init__(self, batch_size, num_threads, device_id, root, list_path,
                 size, crop, shard_id, num_shards, dali_cpu=False):
        super(HybridValPipe, self).__init__(batch_size,
                                            num_threads,
                                            device_id,
                                            seed=12 + device_id)
        self.read = ops.FileReader(file_root=root,
                                   file_list=list_path,
                                   shard_id=shard_id,
                                   num_shards=num_shards,
                                   random_shuffle=False)
        # Let user decide which pipeline works
        dali_device = 'cpu' if dali_cpu else 'gpu'
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB)
        self.resize = ops.Resize(device=dali_device,
                                 resize_shorter=size,
                                 interp_type=types.INTERP_TRIANGULAR)
        # Mean/std are scaled by 255 here; presumably the decoded pixels are
        # in the 0..255 range at this point - TODO confirm.
        self.cmnp = ops.CropMirrorNormalize(device=dali_device,
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[x*255 for x in IMAGENET_MEAN],
                                            std=[x*255 for x in IMAGENET_STD])

    def define_graph(self):
        """Wire the operators together and return ``[imgs, targets]``."""
        imgs, targets = self.read(name="Reader")
        imgs = self.decode(imgs)
        imgs = self.resize(imgs)
        imgs = self.cmnp(imgs)
        return [imgs, targets]
178 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss with label smoothing.

    The one-hot target is mixed with a uniform distribution:
    ``(1 - epsilon) * one_hot + epsilon / num_classes``.

    Args:
        num_classes: Number of classes (size of dimension 1 of the logits).
        epsilon: Smoothing weight.
    """

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, xs, targets):
        """Return the smoothed loss, averaged over dimension 0 of ``xs``.

        ``targets`` holds one class index per row of ``xs``.
        """
        log_probs = self.logsoftmax(xs)
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        per_class = (-smoothed * log_probs).mean(0)

        return per_class.sum()
19 |
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import LambdaLR
3 | import warnings
4 |
5 |
class LambdaLRWithMin(LambdaLR):
    """``LambdaLR`` variant that blends towards a minimum learning rate.

    With ``f = lr_lambda(last_epoch)`` the learning rate of a group is
    ``base_lr * f + eta_min * (1 - f)``.

    Args:
        optimizer: Wrapped optimizer.
        lr_lambda: Factor function(s), as for ``LambdaLR``.
        eta_min: Learning rate reached when the factor is 0.
        last_epoch: Forwarded to ``LambdaLR``.
    """

    def __init__(self, optimizer, lr_lambda, eta_min=0, last_epoch=-1):
        # Must be set first: the parent constructor is called afterwards.
        self.eta_min = eta_min
        super(LambdaLRWithMin, self).__init__(optimizer, lr_lambda, last_epoch)

    def get_lr(self):
        """Return the current learning rate of every parameter group."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        rates = []
        for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs):
            rates.append(base_lr * lmbda(self.last_epoch)
                         + self.eta_min * (1.0 - lmbda(self.last_epoch)))
        return rates
18 |
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 |
6 |
class HSwish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``.

    Args:
        inplace: Forwarded to ``F.relu6``, which is applied to the
            temporary ``x + 3``.
    """

    def __init__(self, inplace=True):
        super(HSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        gate = F.relu6(x + 3, inplace=self.inplace)
        return x * gate / 6
15 |
class HSigmoid(nn.Module):
    """Hard-sigmoid activation: ``relu6(x + 3) / 6``.

    Args:
        inplace: Forwarded to ``F.relu6``, which is applied to the
            temporary ``x + 3``.
    """

    def __init__(self, inplace=True):
        super(HSigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        shifted = x + 3
        return F.relu6(shifted, inplace=self.inplace) / 6
24 |
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    Args:
        inplace: Accepted so the class can be constructed like the other
            activations in this file (``act_func(inplace=True)``). The
            value is stored but the computation is never done in place.
    """

    def __init__(self, inplace=False):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return x * torch.sigmoid(x)
32 |
Sigmoid = nn.Sigmoid


# Lower-case aliases of the activation classes. Within this file, ``relu``
# and ``hswish`` are passed as ``act_func`` to MBInvertedResBlock, and
# ``hswish`` / ``hsigmoid`` are instantiated directly by the network classes.
hswish = HSwish
hsigmoid = HSigmoid
swish = Swish
sigmoid = Sigmoid
relu = nn.ReLU
relu6 = nn.ReLU6
42 |
43 |
class SEModule(nn.Module):
    """Squeeze-and-excitation gate.

    The input is pooled to 1x1, passed through a 1x1 conv / BN / ReLU
    bottleneck with ``in_channels // reduction`` channels, expanded back by
    a 1x1 conv / BN, and squashed by a hard sigmoid. The result scales the
    input channel-wise.

    Args:
        in_channels: Number of input (and output) channels.
        reduction: Divisor for the bottleneck width.
    """

    def __init__(self, in_channels, reduction=4):
        super(SEModule, self).__init__()
        squeezed = in_channels // reduction
        layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, squeezed, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(squeezed),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, in_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(in_channels),
            hsigmoid(inplace=True),
        ]
        self.se = nn.Sequential(*layers)

    def forward(self, x):
        scale = self.se(x)
        return x * scale
59 |
60 |
class MBInvertedResBlock(nn.Module):
    """Mobile inverted residual block.

    Stages: optional 1x1 expansion (only when ``mid_channels`` exceeds
    ``in_channels``), depthwise conv, optional SE module, 1x1 linear
    projection. The input is added back when the stride is 1 and the
    channel count is unchanged.

    Args:
        in_channels: Channels of the block input.
        mid_channels: Channels inside the block.
        out_channels: Channels of the block output.
        kernel_size: Kernel size of the depthwise conv.
        stride: Stride of the depthwise conv.
        act_func: Activation class, instantiated with ``inplace=True``.
        with_se: Whether to insert an ``SEModule`` after the depthwise conv.
    """

    def __init__(self, in_channels, mid_channels, out_channels, kernel_size=3, stride=1, act_func=relu, with_se=False):
        super(MBInvertedResBlock, self).__init__()
        self.has_residual = (in_channels == out_channels) and (stride == 1)
        self.se = SEModule(mid_channels) if with_se else None

        expansion = None
        if mid_channels > in_channels:
            expansion = nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(mid_channels),
                act_func(inplace=True)
            )
        self.inverted_bottleneck = expansion

        depthwise = nn.Conv2d(mid_channels, mid_channels, kernel_size=kernel_size, stride=stride,
                              padding=kernel_size//2, groups=mid_channels, bias=False)
        self.depth_conv = nn.Sequential(
            depthwise,
            nn.BatchNorm2d(mid_channels),
            act_func(inplace=True)
        )

        projection = nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.point_linear = nn.Sequential(
            projection,
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        out = x if self.inverted_bottleneck is None else self.inverted_bottleneck(x)

        out = self.depth_conv(out)
        if self.se is not None:
            out = self.se(out)
        out = self.point_linear(out)

        if self.has_residual:
            out += x

        return out
103 |
104 |
class MobileNetV3_Large(nn.Module):
    """MobileNetV3-Large classifier.

    Args:
        num_classes: Output size of the final linear layer.
        dropout_rate: Dropout probability applied before the classifier
            (skipped when not greater than 0).
        zero_init_last_bn: When true, the last BatchNorm weight of every
            residual block is initialised to 0.
    """

    def __init__(self, num_classes=1000, dropout_rate=0.0, zero_init_last_bn=False):
        super(MobileNetV3_Large, self).__init__()
        self.dropout_rate = dropout_rate
        self.zero_init_last_bn = zero_init_last_bn

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish(inplace=True)

        # (in, mid, out, kernel, stride, activation, with_se) per block.
        block_settings = (
            (16, 16, 16, 3, 1, relu, False),
            (16, 64, 24, 3, 2, relu, False),
            (24, 72, 24, 3, 1, relu, False),
            (24, 72, 40, 5, 2, relu, True),
            (40, 120, 40, 5, 1, relu, True),
            (40, 120, 40, 5, 1, relu, True),
            (40, 240, 80, 3, 2, hswish, False),
            (80, 200, 80, 3, 1, hswish, False),
            (80, 184, 80, 3, 1, hswish, False),
            (80, 184, 80, 3, 1, hswish, False),
            (80, 480, 112, 3, 1, hswish, True),
            (112, 672, 112, 3, 1, hswish, True),
            (112, 672, 160, 5, 2, hswish, True),
            (160, 960, 160, 5, 1, hswish, True),
            (160, 960, 160, 5, 1, hswish, True),
        )
        self.bneck = nn.Sequential(
            *[MBInvertedResBlock(*setting) for setting in block_settings])

        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = hswish(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv3 = nn.Conv2d(960, 1280, kernel_size=1, stride=1, padding=0, bias=True)
        self.hs3 = hswish()
        self.classifier = nn.Linear(1280, num_classes)

        self._initialization()
        # self._set_bn_param(0.1, 0.001)

    def forward(self, x):
        """Run stem, bottleneck stack and head; return the class logits."""
        out = self.hs1(self.bn1(self.conv1(x)))

        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))

        out = self.hs3(self.conv3(self.avgpool(out)))
        out = out.view(out.size(0), -1)
        if self.dropout_rate > 0.0:
            out = F.dropout(out, p=self.dropout_rate, training=self.training)

        return self.classifier(out)

    def _initialization(self):
        """Initialise conv, batch-norm and linear parameters."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                init.kaiming_normal_(layer.weight, mode='fan_out')
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                init.normal_(layer.weight, std=0.001)
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)

        if not self.zero_init_last_bn:
            return
        for layer in self.modules():
            if isinstance(layer, MBInvertedResBlock) and layer.has_residual:
                init.constant_(layer.point_linear[1].weight, 0)
181 |
182 | # def _set_bn_param(self, bn_momentum, bn_eps):
183 | # for m in self.modules():
184 | # if isinstance(m, nn.BatchNorm2d):
185 | # m.momentum = bn_momentum
186 | # m.eps = bn_eps
187 |
188 |
class MobileNetV3_Small(nn.Module):
    """MobileNetV3-Small classifier.

    Args:
        num_classes: Output size of the final linear layer.
        dropout_rate: Dropout probability applied before the classifier
            (skipped when not greater than 0).
        zero_init_last_bn: When true, the last BatchNorm weight of every
            residual block is initialised to 0.
    """

    def __init__(self, num_classes=1000, dropout_rate=0.0, zero_init_last_bn=False):
        super(MobileNetV3_Small, self).__init__()
        self.dropout_rate = dropout_rate
        self.zero_init_last_bn = zero_init_last_bn

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish(inplace=True)

        # (in, mid, out, kernel, stride, activation, with_se) per block.
        block_settings = (
            (16, 16, 16, 3, 2, relu, True),
            (16, 72, 24, 3, 2, relu, False),
            (24, 88, 24, 3, 1, relu, False),
            (24, 96, 40, 5, 2, hswish, True),
            (40, 240, 40, 5, 1, hswish, True),
            (40, 240, 40, 5, 1, hswish, True),
            (40, 120, 48, 5, 1, hswish, True),
            (48, 144, 48, 5, 1, hswish, True),
            (48, 288, 96, 5, 2, hswish, True),
            (96, 576, 96, 5, 1, hswish, True),
            (96, 576, 96, 5, 1, hswish, True),
        )
        self.bneck = nn.Sequential(
            *[MBInvertedResBlock(*setting) for setting in block_settings])

        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = hswish(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv3 = nn.Conv2d(576, 1280, kernel_size=1, stride=1, padding=0, bias=True)
        self.hs3 = hswish()
        self.classifier = nn.Linear(1280, num_classes)

        self._initialization()
        # self._set_bn_param(0.1, 0.001)

    def forward(self, x):
        """Run stem, bottleneck stack and head; return the class logits."""
        out = self.hs1(self.bn1(self.conv1(x)))

        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))

        out = self.hs3(self.conv3(self.avgpool(out)))
        out = out.view(out.size(0), -1)
        if self.dropout_rate > 0.0:
            out = F.dropout(out, p=self.dropout_rate, training=self.training)

        return self.classifier(out)

    def _initialization(self):
        """Initialise conv, batch-norm and linear parameters."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                init.kaiming_normal_(layer.weight, mode='fan_out')
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                init.normal_(layer.weight, std=0.001)
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)

        if not self.zero_init_last_bn:
            return
        for layer in self.modules():
            if isinstance(layer, MBInvertedResBlock) and layer.has_residual:
                init.constant_(layer.point_linear[1].weight, 0)
261 |
262 | # def _set_bn_param(self, bn_momentum, bn_eps):
263 | # for m in self.modules():
264 | # if isinstance(m, nn.BatchNorm2d):
265 | # m.momentum = bn_momentum
266 | # m.eps = bn_eps
267 |
--------------------------------------------------------------------------------
/profile_example.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from benchmark import ModelStat, stat
4 | from benchmark import compute_speed
5 |
6 | sys.path.append('..')
7 | from networks import MobileNetV3_Small, MobileNetV3_Large
8 |
9 |
single_image_shape = (1, 3, 224, 224)
model = MobileNetV3_Small()

# query_granularity can be any int value, usually:
#   query_granularity=1  reports every leaf node
#   query_granularity=-1 only reports the root node
for granularity in (1, -1):
    stat(model, single_image_shape, query_granularity=granularity, brief_report=False)

# brief_report=True only reports the summation
stat(model, single_image_shape, query_granularity=1, brief_report=True)


# Alternatively build a ModelStat once, then change query_granularity and
# call show_report as often as needed.
ms = ModelStat(model, single_image_shape, query_granularity=1, brief_report=False)

for granularity in (-1, 1):
    ms.query_granularity = granularity
    ms.show_report()

ms.unhook_model()

# measure latency
latency_runs = (
    ((32, 3, 224, 224), 'cuda:0'),
    (single_image_shape, 'cuda:0'),
    (single_image_shape, 'cpu'),
)
for input_shape, device in latency_runs:
    compute_speed(model, input_shape, device, 1000)
36 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 | import warnings
5 | warnings.filterwarnings('ignore')
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.distributed as dist
10 | import torch.backends.cudnn as cudnn
11 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator
12 | from apex.parallel import DistributedDataParallel as DDP
13 |
14 | from utils import AverageMeter, accuracy
15 | from datasets import ImageList, pil_loader, cv2_loader
16 | from datasets import get_val_transform, HybridValPipe
17 | from networks import MobileNetV3_Large, MobileNetV3_Small
18 |
19 |
# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser(
    description="Basic Pytorch ImageNet Example. Testing.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# various paths
parser.add_argument('--val_root', type=str, required=True, help='root path to validating images')
parser.add_argument('--val_list', type=str, required=True, help='validating image list')
parser.add_argument('--weights', type=str, required=True, help='checkpoint for testing')

# testing hyper-parameters
parser.add_argument('--workers', type=int, default=8, help='number of workers to load dataset (global)')
parser.add_argument('--batch_size', type=int, default=512, help='batch size (global)')
parser.add_argument('--model', type=str, default='MobileNetV3_Large', help='type of model',
                    choices=['MobileNetV3_Large', 'MobileNetV3_Small'])
parser.add_argument('--num_classes', type=int, default=1000, help='class number of testing set')
parser.add_argument('--trans_mode', type=str, default='tv', help='mode of image transformation (tv/dali)')
parser.add_argument('--dali_cpu', action='store_true', default=False, help='runs CPU based DALI pipeline')
parser.add_argument('--ema', action='store_true', default=False, help='whether to use EMA')

# amp and DDP hyper-parameters
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--channels_last', type=str, default='False')


args, unparsed = parser.parse_known_args()
# NOTE: this used to be eval(args.channels_last), which executes arbitrary
# code taken from the command line. The flag is a boolean given as text, so
# parse it explicitly; 'True' / 'False' (any case) behave as before.
args.channels_last = str(args.channels_last).strip().lower() in ('true', '1', 'yes')

# torch.channels_last / torch.contiguous_format are looked up defensively so
# the script still runs on a torch build that lacks them (memory_format=None).
if hasattr(torch, 'channels_last') and hasattr(torch, 'contiguous_format'):
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
else:
    memory_format = None
54 |
55 |
def main():
    """Evaluate a MobileNetV3 checkpoint and print top-1 / top-5 accuracy.

    Sets up (optional) distributed execution from the ``WORLD_SIZE``
    environment variable, builds the model selected by ``args.model``, wraps
    it for (distributed) data parallelism, builds the validation loader for
    ``args.trans_mode`` ('tv' or 'dali'), loads the checkpoint and runs
    ``validate``. Results are printed on local rank 0 only.

    Raises:
        FileNotFoundError: if ``args.weights`` is not an existing file.
        Exception: if ``args.model`` or ``args.trans_mode`` is not recognised.
    """
    # Fail fast: evaluating without the requested checkpoint would silently
    # report the accuracy of a randomly initialised network.
    if not (args.weights and os.path.isfile(args.weights)):
        raise FileNotFoundError('checkpoint not found: {}'.format(args.weights))

    cudnn.enabled = True
    cudnn.benchmark = True
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    # create model
    if args.model == 'MobileNetV3_Large':
        model = MobileNetV3_Large(args.num_classes, 0.0, False)
    elif args.model == 'MobileNetV3_Small':
        model = MobileNetV3_Small(args.num_classes, 0.0, False)
    else:
        raise Exception('invalid type of model')
    model = model.cuda().to(memory_format=memory_format) if memory_format is not None else model.cuda()

    # For distributed evaluation, wrap the model with apex.parallel.DistributedDataParallel;
    # otherwise fall back to nn.DataParallel.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
    else:
        model = nn.DataParallel(model)

    # define transform and initialize dataloader
    # --batch_size and --workers are global values; split them across processes.
    batch_size = args.batch_size // args.world_size
    workers = args.workers // args.world_size
    if args.trans_mode == 'tv':
        val_transform = get_val_transform()
        val_dataset = ImageList(root=args.val_root,
                                list_path=args.val_list,
                                transform=val_transform)
        val_sampler = None
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=batch_size, num_workers=workers,
            pin_memory=True, sampler=val_sampler, shuffle=False)
    elif args.trans_mode == 'dali':
        pipe = HybridValPipe(batch_size=batch_size,
                             num_threads=workers,
                             device_id=args.local_rank,
                             root=args.val_root,
                             list_path=args.val_list,
                             size=256,
                             crop=224,
                             shard_id=args.local_rank,
                             num_shards=args.world_size,
                             dali_cpu=args.dali_cpu)
        pipe.build()
        val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader")/args.world_size))
    else:
        raise Exception('invalid image transformation mode')

    # load the weights to evaluate
    if args.local_rank == 0:
        print('loading weights from {}'.format(args.weights))
    checkpoint = torch.load(args.weights, map_location=lambda storage, loc: storage.cuda(args.gpu))
    if args.ema:
        model.load_state_dict(checkpoint['ema'])
    else:
        model.load_state_dict(checkpoint['model'])

    val_acc_top1, val_acc_top5 = validate(val_loader, model)
    if args.local_rank == 0:
        print('Val_acc_top1: {:.2f}'.format(val_acc_top1))
        print('Val_acc_top5: {:.2f}'.format(val_acc_top5))
133 |
134 |
def validate(val_loader, model):
    """Evaluate ``model`` on ``val_loader`` and return average accuracies.

    Args:
        val_loader: iterable of batches; for ``args.trans_mode == 'tv'`` each
            item is indexed as ``(images, targets)``, for ``'dali'`` as a list
            whose first element is a dict with ``'data'`` and ``'label'``.
        model: the network to evaluate; it is switched to eval mode.

    Returns:
        ``(top1.avg, top5.avg)``: running averages of the per-batch top-1 and
        top-5 precision, weighted by batch size.
    """
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()

    for data in tqdm(val_loader):
        if args.trans_mode == 'tv':
            x = data[0].cuda(non_blocking=True)
            target = data[1].cuda(non_blocking=True)
        elif args.trans_mode == 'dali':
            x = data[0]['data'].cuda(non_blocking=True)
            # reshape(-1) instead of squeeze(): squeeze() would collapse a
            # single-sample batch to a 0-dim tensor and break accuracy().
            target = data[0]['label'].reshape(-1).cuda(non_blocking=True).long()

        with torch.no_grad():
            logits = model(x)

        prec1, prec5 = accuracy(logits, target, topk=(1, 5))
        if args.distributed:
            # Average the per-rank precision values across all processes.
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        top1.update(prec1.item(), x.size(0))
        top5.update(prec5.item(), x.size(0))

    return top1.avg, top5.avg
160 |
161 |
def reduce_tensor(tensor):
    """Return a copy of ``tensor`` summed over all ranks and divided by ``args.world_size``.

    The input tensor itself is left untouched; the all-reduce runs on a clone.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged.div_(args.world_size)
167 |
168 |
# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
171 |
--------------------------------------------------------------------------------
/test_example.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python -u test.py \
2 | --val_root "Your ImageNet Val Set Path" \
3 | --val_list "ImageNet Val List" \
4 | --weights "Pretrained Weights" \
5 | --model 'MobileNetV3_Large' \
6 | --trans_mode 'tv' \
7 | --ema
8 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import glob
5 | import logging
6 | import argparse
7 | import warnings
8 | warnings.filterwarnings('ignore')
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.distributed as dist
14 | import torch.utils.data.distributed
15 | import torch.backends.cudnn as cudnn
16 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator
17 | from apex.parallel import DistributedDataParallel as DDP
18 | from apex import amp, parallel
19 |
20 | from utils import AverageMeter, EMA, accuracy, set_seed
21 | from utils import create_exp_dir, save_checkpoint, get_params
22 | from losses import CrossEntropyLabelSmooth
23 | from datasets import ImageList, pil_loader, cv2_loader
24 | from datasets import get_train_transform, get_val_transform
25 | from datasets import HybridTrainPipe, HybridValPipe
26 | from networks import MobileNetV3_Large, MobileNetV3_Small
27 | from lr_scheduler import LambdaLRWithMin
28 |
29 |
# Command-line interface of the training script.
parser = argparse.ArgumentParser(
    description="Basic Pytorch ImageNet Example. There is no tricks such as mixup/autoaug/dropblock/droppath etc.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# various paths
parser.add_argument('--train_root', type=str, required=True, help='root path to training images')
parser.add_argument('--train_list', type=str, required=True, help='training image list')
parser.add_argument('--val_root', type=str, required=True, help='root path to validating images')
parser.add_argument('--val_list', type=str, required=True, help='validating image list')
parser.add_argument('--save', type=str, default='./checkpoints/', help='model and log saving path')
parser.add_argument('--snapshot', type=str, default='', help='checkpoint for reset')

# training hyper-parameters
parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
parser.add_argument('--workers', type=int, default=8, help='number of workers to load dataset (global)')
parser.add_argument('--epochs', type=int, default=250, help='number of total training epochs')
parser.add_argument('--warmup_epochs', type=int, default=5, help='number of warmup epochs')
parser.add_argument('--batch_size', type=int, default=512, help='batch size (global)')
parser.add_argument('--lr', type=float, default=0.2, help='initial learning rate')
parser.add_argument('--lr_min', type=float, default=0.0, help='minimum learning rate')
parser.add_argument('--lr_scheduler', type=str, default='cosine_epoch', help='type of lr scheduler',
                    choices=['linear_epoch', 'linear_batch', 'cosine_epoch', 'cosine_batch', 'step_epoch', 'step_batch'])
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay (wd)')
parser.add_argument('--no_wd_bias_bn', action='store_true', default=False, help='whether to remove wd on bias and bn')
parser.add_argument('--model', type=str, default='MobileNetV3_Large', help='type of model',
                    choices=['MobileNetV3_Large', 'MobileNetV3_Small'])
parser.add_argument('--num_classes', type=int, default=1000, help='class number of training set')
parser.add_argument('--dropout_rate', type=float, default=0.0, help='dropout rate')
parser.add_argument('--zero_init_last_bn', action='store_true', default=False, help='zero initialize the last bn in each block')
parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
parser.add_argument('--trans_mode', type=str, default='tv', help='mode of image transformation (tv/dali)')
parser.add_argument('--color_jitter', action='store_true', default=False, help='apply color augmentation or not')
parser.add_argument('--dali_cpu', action='store_true', default=False, help='runs CPU based DALI pipeline')
parser.add_argument('--ema_decay', type=float, default=0.0, help='whether to use EMA')

# amp and DDP hyper-parameters
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--sync_bn', action='store_true', help='enabling apex sync BN')
parser.add_argument('--opt_level', type=str, default=None)
parser.add_argument('--keep_batchnorm_fp32', type=str, default=None)
parser.add_argument('--loss_scale', type=str, default=None)
parser.add_argument('--channels_last', type=str, default='False')

# others
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--note', type=str, default='try', help='note for this run')


args, unparsed = parser.parse_known_args()
# NOTE: this used to be eval(args.channels_last), which executes arbitrary
# code taken from the command line. The flag is a boolean given as text, so
# parse it explicitly; 'True' / 'False' (any case) behave as before.
args.channels_last = str(args.channels_last).strip().lower() in ('true', '1', 'yes')

# Each run gets its own timestamped directory below --save.
args.save = os.path.join(args.save, '{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), args.note))
if args.local_rank == 0:
    create_exp_dir(args.save, scripts_to_save=glob.glob('*.py')+glob.glob('*.sh'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
# Only rank 0 creates the experiment directory (and only rank 0 logs), so the
# log file is attached on rank 0 only. Opening it on the other ranks could
# fail because the directory may not exist for them.
if args.local_rank == 0:
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

# torch.channels_last / torch.contiguous_format are looked up defensively so
# the script still runs on a torch build that lacks them (memory_format=None).
if hasattr(torch, 'channels_last') and hasattr(torch, 'contiguous_format'):
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
else:
    memory_format = None
100 |
101 |
def main():
    """Train the selected MobileNetV3 model.

    In order: seed the RNGs, set up (optional) distributed training from the
    ``WORLD_SIZE`` environment variable, build model / criterion / SGD
    optimizer, optionally initialise apex amp, wrap the model for
    (distributed) data parallelism, optionally register an EMA of the
    weights, build the train / val loaders for ``args.trans_mode``, optionally
    resume from ``args.snapshot``, then run the epoch loop (train, validate,
    checkpoint on local rank 0, update the learning rate).

    Raises:
        Exception: if ``args.model`` or ``args.trans_mode`` is not recognised.
    """
    set_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        # Give every process its own seed, but derive it from --seed so the
        # flag is still honoured (previously the rank alone was used).
        set_seed(args.seed + args.local_rank)
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    if args.local_rank == 0:
        logging.info("args = {}".format(args))
        logging.info("unparsed_args = {}".format(unparsed))
        logging.info("distributed = {}".format(args.distributed))
        logging.info("sync_bn = {}".format(args.sync_bn))
        logging.info("opt_level = {}".format(args.opt_level))
        logging.info("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32))
        logging.info("loss_scale = {}".format(args.loss_scale))
        logging.info("CUDNN VERSION: {}".format(torch.backends.cudnn.version()))

    # create model
    if args.model == 'MobileNetV3_Large':
        model = MobileNetV3_Large(args.num_classes, args.dropout_rate, args.zero_init_last_bn)
    elif args.model == 'MobileNetV3_Small':
        model = MobileNetV3_Small(args.num_classes, args.dropout_rate, args.zero_init_last_bn)
    else:
        raise Exception('invalid type of model')
    if args.sync_bn:
        if args.local_rank == 0:
            logging.info("using apex synced BN")
        model = parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format) if memory_format is not None else model.cuda()

    # define criterion and optimizer
    if args.label_smooth > 0.0:
        criterion = CrossEntropyLabelSmooth(args.num_classes, args.label_smooth)
    else:
        criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    params = get_params(model) if args.no_wd_bias_bn else model.parameters()
    optimizer = torch.optim.SGD(params, args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Initialize Amp
    if args.opt_level is not None:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                          loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
    else:
        model = nn.DataParallel(model)

    # exponential moving average
    if args.ema_decay > 0.0:
        ema = EMA(model, args.ema_decay)
        ema.register()
    else:
        ema = None

    # define transform and initialize dataloader
    # --batch_size and --workers are global values; split them across processes.
    batch_size = args.batch_size // args.world_size
    workers = args.workers // args.world_size
    if args.trans_mode == 'tv':
        train_transform = get_train_transform(args.color_jitter)
        val_transform = get_val_transform()
        train_dataset = ImageList(root=args.train_root,
                                  list_path=args.train_list,
                                  transform=train_transform)
        val_dataset = ImageList(root=args.val_root,
                                list_path=args.val_list,
                                transform=val_transform)
        train_sampler = None
        val_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, num_workers=workers,
            pin_memory=True, sampler=train_sampler, shuffle=(train_sampler is None))
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=batch_size, num_workers=workers,
            pin_memory=True, sampler=val_sampler, shuffle=False)
        args.batches_per_epoch = len(train_loader)
    elif args.trans_mode == 'dali':
        pipe = HybridTrainPipe(batch_size=batch_size,
                               num_threads=workers,
                               device_id=args.local_rank,
                               root=args.train_root,
                               list_path=args.train_list,
                               crop=224,
                               shard_id=args.local_rank,
                               num_shards=args.world_size,
                               coji=args.color_jitter,
                               dali_cpu=args.dali_cpu)
        pipe.build()
        train_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader")/args.world_size))
        # Ceil division: a trailing partial batch still counts as one batch.
        args.batches_per_epoch = train_loader._size // train_loader.batch_size
        args.batches_per_epoch += (train_loader._size % train_loader.batch_size) != 0

        pipe = HybridValPipe(batch_size=batch_size,
                             num_threads=workers,
                             device_id=args.local_rank,
                             root=args.val_root,
                             list_path=args.val_list,
                             size=256,
                             crop=224,
                             shard_id=args.local_rank,
                             num_shards=args.world_size,
                             dali_cpu=args.dali_cpu)
        pipe.build()
        val_loader = DALIClassificationIterator(pipe, size=int(pipe.epoch_size("Reader")/args.world_size))
    else:
        raise Exception('invalid image transformation mode')

    # define learning rate scheduler
    scheduler = get_lr_scheduler(optimizer)

    best_acc_top1 = 0
    best_acc_top5 = 0
    start_epoch = 0

    # restart from snapshot
    if args.snapshot and os.path.isfile(args.snapshot):
        if args.local_rank == 0:
            logging.info('loading snapshot from {}'.format(args.snapshot))
        checkpoint = torch.load(args.snapshot, map_location=lambda storage, loc: storage.cuda(args.gpu))
        start_epoch = checkpoint['epoch']
        best_acc_top1 = checkpoint['best_acc_top1']
        best_acc_top5 = checkpoint['best_acc_top5']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # EMA may be disabled in this run (ema is None) even though the
        # snapshot carries EMA state; only restore it when both exist.
        if ema is not None and checkpoint['ema'] is not None:
            ema.load_state_dict(checkpoint['ema'])
        if args.opt_level is not None:
            amp.load_state_dict(checkpoint['amp'])
        # Rebuild the scheduler on the restored optimizer and replay the lr
        # updates of the epochs that were already trained.
        scheduler = get_lr_scheduler(optimizer)
        for epoch in range(start_epoch):
            if epoch < args.warmup_epochs:
                adjust_learning_rate(optimizer, scheduler, epoch, -1)
                warmup_lr = get_last_lr(optimizer)
                if args.local_rank == 0:
                    logging.info('Epoch: %d, Warming-up lr: %e', epoch, warmup_lr)
            else:
                current_lr = get_last_lr(optimizer)
                if args.local_rank == 0:
                    logging.info('Epoch: %d lr %e', epoch, current_lr)

            if epoch < args.warmup_epochs:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = args.lr
            else:
                if args.lr_scheduler in ['linear_epoch', 'cosine_epoch', 'step_epoch']:
                    adjust_learning_rate(optimizer, scheduler, epoch, -1)
                if args.lr_scheduler in ['linear_batch', 'cosine_batch', 'step_batch']:
                    for batch_idx in range(args.batches_per_epoch):
                        adjust_learning_rate(optimizer, scheduler, epoch, batch_idx)

    # the main loop
    for epoch in range(start_epoch, args.epochs):
        if epoch < args.warmup_epochs:
            adjust_learning_rate(optimizer, scheduler, epoch, -1)
            warmup_lr = get_last_lr(optimizer)
            if args.local_rank == 0:
                logging.info('Epoch: %d, Warming-up lr: %e', epoch, warmup_lr)
        else:
            current_lr = get_last_lr(optimizer)
            if args.local_rank == 0:
                logging.info('Epoch: %d lr %e', epoch, current_lr)

        if args.distributed and args.trans_mode == 'tv':
            train_sampler.set_epoch(epoch)

        epoch_start = time.time()
        train_acc, train_obj = train(train_loader, model, ema, criterion, optimizer, scheduler, epoch)
        if args.local_rank == 0:
            logging.info('Train_acc: %f', train_acc)

        val_acc_top1, val_acc_top5, val_obj = validate(val_loader, model, criterion)
        if args.local_rank == 0:
            logging.info('Val_acc_top1: %f', val_acc_top1)
            logging.info('Val_acc_top5: %f', val_acc_top5)
            logging.info('Epoch time: %ds.', time.time() - epoch_start)

        if args.local_rank == 0:
            is_best = False
            if val_acc_top1 > best_acc_top1:
                best_acc_top1 = val_acc_top1
                best_acc_top5 = val_acc_top5
                is_best = True
            save_checkpoint({
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'ema': ema.state_dict() if ema is not None else None,
                'best_acc_top1': best_acc_top1,
                'best_acc_top5': best_acc_top5,
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict() if args.opt_level is not None else None,
            }, is_best, args.save)

        # After a warm-up epoch the lr is reset to the base value; afterwards
        # the scheduler takes its per-epoch step (a no-op for *_batch types).
        if epoch < args.warmup_epochs:
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr
        else:
            adjust_learning_rate(optimizer, scheduler, epoch, -1)

        if args.trans_mode == 'dali':
            train_loader.reset()
            val_loader.reset()
323 |
324 |
def train(train_loader, model, ema, criterion, optimizer, scheduler, epoch):
    """Train ``model`` for one epoch.

    Args:
        train_loader: iterable of batches; for ``args.trans_mode == 'tv'`` each
            item is indexed as ``(images, targets)``, for ``'dali'`` as a list
            whose first element is a dict with ``'data'`` and ``'label'``.
        model: network to train; it is switched to train mode.
        ema: object with an ``update()`` method called after every optimizer
            step, or ``None`` to disable.
        criterion: loss callable taking ``(logits, target)``.
        optimizer: optimizer stepped once per batch.
        scheduler: lr scheduler forwarded to ``adjust_learning_rate``.
        epoch: index of the current epoch.

    Returns:
        ``(top1.avg, objs.avg)``. Both meters are only updated on every
        ``args.print_freq``-th batch, so the values are averages over those
        sampled batches rather than over the whole epoch.
    """
    objs = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    model.train()

    end = time.time()
    for batch_idx, data in enumerate(train_loader):
        data_time.update(time.time() - end)
        if args.trans_mode == 'tv':
            x = data[0].cuda(non_blocking=True)
            target = data[1].cuda(non_blocking=True)
        elif args.trans_mode == 'dali':
            x = data[0]['data'].cuda(non_blocking=True)
            # reshape(-1) instead of squeeze(): squeeze() would collapse a
            # single-sample batch to a 0-dim tensor.
            target = data[0]['label'].reshape(-1).cuda(non_blocking=True).long()

        # forward
        batch_start = time.time()
        logits = model(x)
        loss = criterion(logits, target)

        # backward
        optimizer.zero_grad()
        if args.opt_level is not None:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        if ema is not None:
            ema.update()
        batch_time.update(time.time() - batch_start)

        if batch_idx % args.print_freq == 0:
            # For better performance, don't accumulate these metrics every iteration,
            # since they may incur an allreduce and some host<->device syncs.
            prec1, prec5 = accuracy(logits, target, topk=(1, 5))
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data
            objs.update(reduced_loss.item(), x.size(0))
            top1.update(prec1.item(), x.size(0))
            top5.update(prec5.item(), x.size(0))
            torch.cuda.synchronize()

            # Wall-clock time since the previous report (0 for the first one).
            duration = 0 if batch_idx == 0 else time.time() - duration_start
            duration_start = time.time()
            if args.local_rank == 0:
                logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f Duration: %ds BTime: %.3fs DTime: %.4fs',
                             batch_idx, objs.avg, top1.avg, top5.avg, duration, batch_time.avg, data_time.avg)

        adjust_learning_rate(optimizer, scheduler, epoch, batch_idx)
        end = time.time()

    return top1.avg, objs.avg
384 |
385 |
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader``.

    Args:
        val_loader: iterable of batches; for ``args.trans_mode == 'tv'`` each
            item is indexed as ``(images, targets)``, for ``'dali'`` as a list
            whose first element is a dict with ``'data'`` and ``'label'``.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss callable taking ``(logits, target)``.

    Returns:
        ``(top1.avg, top5.avg, objs.avg)``: batch-size-weighted running
        averages of top-1 precision, top-5 precision and loss.
    """
    objs = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()

    for batch_idx, data in enumerate(val_loader):
        if args.trans_mode == 'tv':
            x = data[0].cuda(non_blocking=True)
            target = data[1].cuda(non_blocking=True)
        elif args.trans_mode == 'dali':
            x = data[0]['data'].cuda(non_blocking=True)
            # reshape(-1) instead of squeeze(): squeeze() would collapse a
            # single-sample batch to a 0-dim tensor.
            target = data[0]['label'].reshape(-1).cuda(non_blocking=True).long()

        with torch.no_grad():
            logits = model(x)
            loss = criterion(logits, target)

        prec1, prec5 = accuracy(logits, target, topk=(1, 5))
        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data
        objs.update(reduced_loss.item(), x.size(0))
        top1.update(prec1.item(), x.size(0))
        top5.update(prec5.item(), x.size(0))

        if args.local_rank == 0 and batch_idx % args.print_freq == 0:
            # Wall-clock time since the previous report (0 for the first one).
            duration = 0 if batch_idx == 0 else time.time() - duration_start
            duration_start = time.time()
            logging.info('VALIDATE Step: %03d Objs: %e R1: %f R5: %f Duration: %ds', batch_idx, objs.avg, top1.avg, top5.avg, duration)

    return top1.avg, top5.avg, objs.avg
422 |
423 |
def reduce_tensor(tensor):
    """Return a copy of ``tensor`` summed over all ranks and divided by ``args.world_size``.

    The input tensor itself is left untouched; the all-reduce runs on a clone.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged.div_(args.world_size)
429 |
430 |
def get_lr_scheduler(optimizer):
    """Build the post-warm-up learning-rate scheduler named by ``args.lr_scheduler``.

    The name has the form ``<shape>_<granularity>``: the shape is ``linear``,
    ``cosine`` or ``step``; the granularity is ``epoch`` (one scheduler step
    per epoch) or ``batch`` (one per batch). The schedule spans the epochs
    left after warm-up, expressed in the chosen granularity.

    Args:
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        The scheduler instance.

    Raises:
        Exception: if ``args.lr_scheduler`` is not one of the six known names.
        AssertionError: for ``step_*`` schedulers when ``args.lr_min <= 0``.
    """
    name = args.lr_scheduler
    known = ('linear_epoch', 'linear_batch',
             'cosine_epoch', 'cosine_batch',
             'step_epoch', 'step_batch')
    if name not in known:
        raise Exception('invalid type of lr scheduler')

    total_steps = args.epochs - args.warmup_epochs
    if name.endswith('_batch'):
        total_steps *= args.batches_per_epoch
    # Guard against epochs <= warmup_epochs, which would otherwise divide by
    # zero below; the scheduler is never stepped past warm-up in that case.
    total_steps = max(total_steps, 1)

    if name.startswith('linear'):
        lambda_func = lambda step: max(1.0 - step / float(total_steps), 0)
        return LambdaLRWithMin(optimizer, lambda_func, args.lr_min)

    if name.startswith('cosine'):
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(total_steps), args.lr_min)

    # step_*: geometric decay that reaches lr_min after total_steps steps.
    assert args.lr_min > 0.0, 'the minimum lr must be larger than 0 for "step" lr_scheduler'
    gamma = (args.lr_min / args.lr) ** (1.0 / total_steps)
    return torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma)
460 |
461 |
def get_last_lr(optimizer):
    """Return the current learning rate of the optimizer's first parameter group."""
    rates = []
    for group in optimizer.param_groups:
        rates.append(group['lr'])
    return rates[0]
465 |
466 |
def adjust_learning_rate(optimizer, scheduler, epoch, batch_idx):
    '''
    Update the learning rate, either by linear warm-up or via ``scheduler``.

    batch_idx = -1: adjusts lr per epoch
    batch_idx >= 0: adjusts lr per batch

    For ``*_epoch`` schedulers only per-epoch calls (batch_idx == -1) have an
    effect. For ``*_batch`` schedulers every call during warm-up sets the
    warm-up lr, and after warm-up only per-batch calls step the scheduler.
    '''
    per_epoch = args.lr_scheduler in ('linear_epoch', 'cosine_epoch', 'step_epoch')
    per_batch = args.lr_scheduler in ('linear_batch', 'cosine_batch', 'step_batch')
    warming_up = epoch < args.warmup_epochs

    if per_epoch and batch_idx == -1:
        if warming_up:
            new_lr = float(epoch + 1) / (args.warmup_epochs + 1) * args.lr
            for group in optimizer.param_groups:
                group['lr'] = new_lr
        else:
            scheduler.step()
    elif per_batch:
        if warming_up:
            # Position of this batch within the whole warm-up phase.
            global_idx = epoch * args.batches_per_epoch + batch_idx
            warmup_batches = args.warmup_epochs * args.batches_per_epoch
            new_lr = float(global_idx + 2) / (warmup_batches + 1) * args.lr
            for group in optimizer.param_groups:
                group['lr'] = new_lr
        elif batch_idx >= 0:
            scheduler.step()
492 |
493 |
# Run the training only when this file is executed as a script.
if __name__ == '__main__':
    main()
496 |
--------------------------------------------------------------------------------
/train_example.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1 python -u -m torch.distributed.launch --nproc_per_node=2 train.py \
2 | --train_root "Your ImageNet Train Set Path" \
3 | --val_root "Your ImageNet Val Set Path" \
4 | --train_list "ImageNet Train List" \
5 | --val_list "ImageNet Val List" \
6 | --save './checkpoints/' \
7 | --workers 16 \
8 | --epochs 250 \
9 | --warmup_epochs 5 \
10 | --batch_size 512 \
11 | --lr 0.2 \
12 | --lr_min 0.0 \
13 | --lr_scheduler 'cosine_epoch' \
14 | --momentum 0.9 \
15 | --weight_decay 3e-5 \
16 | --no_wd_bias_bn \
17 | --model 'MobileNetV3_Large' \
18 | --num_classes 1000 \
19 | --dropout_rate 0.2 \
20 | --label_smooth 0.1 \
21 | --trans_mode 'tv' \
22 | --color_jitter \
23 | --ema_decay 0.9999 \
24 | --opt_level 'O1' \
25 | --note 'try'
26 |
27 |
28 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import torch
4 | import random
5 | import numpy as np
6 |
7 |
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    The same ``seed`` is handed, in order, to ``random.seed``,
    ``np.random.seed``, ``torch.manual_seed``, ``torch.cuda.manual_seed``
    and ``torch.cuda.manual_seed_all``.

    Args:
        seed: value forwarded unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
14 |
15 |
class AverageMeter(object):
    """
    Computes and stores the average and current value
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py

    Attributes:
        val: the most recent value passed to ``update``.
        sum: running total of ``val * n`` over all updates.
        count: running total of ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """
    def __init__(self):
        # Single source of truth for the initial state.
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` in the average.

        Args:
            val: the new measurement.
            n: weight of the measurement (e.g. number of samples it covers).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n == 0 on an empty meter (e.g. an empty batch) would
        # otherwise divide by zero; report 0 until something has been counted.
        self.avg = self.sum / self.count if self.count else 0
39 |
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: score tensor; the top-k search runs along dimension 1
            (presumably shaped (batch, num_classes) -- confirm with callers).
        target: tensor of ground-truth class indices; its first dimension
            is taken as the batch size.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``:
        the percentage (0-100) of samples whose target is among the k
        highest-scoring predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Pure metric computation: no autograd graph is needed.
    with torch.no_grad():
        # Indices of the maxk highest scores per sample, best first.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` can inherit the non-contiguous layout of the
            # transposed `pred`; `view(-1)` then raises for k > 1 on recent
            # PyTorch, whereas `reshape` copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
    return res
54 |
55 |
def save_checkpoint(state, is_best, save):
    """Write ``state`` to ``<save>/checkpoint.pth.tar`` with ``torch.save``.

    When ``is_best`` is truthy the file just written is also copied to
    ``<save>/model_best.pth.tar``.

    Args:
        state: object to serialize.
        is_best: whether to additionally keep this checkpoint as the best one.
        save: directory that receives the checkpoint file(s).
    """
    latest_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, latest_path)
    if not is_best:
        return
    shutil.copyfile(latest_path, os.path.join(save, 'model_best.pth.tar'))
62 |
63 |
def create_exp_dir(path, scripts_to_save=None):
    """Create the experiment directory and optionally snapshot scripts into it.

    Safe to call repeatedly with the same ``path`` (e.g. when resuming a run,
    or from several worker processes at once): existing directories are
    reused and previously copied scripts are overwritten.

    Args:
        path: experiment directory; created with any missing parents.
        scripts_to_save: optional iterable of file paths; each file is copied
            to ``<path>/scripts/<basename>``.
    """
    # exist_ok avoids the check-then-create race between concurrent callers.
    os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        # Without exist_ok a second call for the same experiment would raise
        # FileExistsError here.
        os.makedirs(scripts_dir, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(scripts_dir, os.path.basename(script))
            shutil.copyfile(script, dst_file)
74 |
75 |
def get_params(model):
    """Split the model's parameters into two optimizer parameter groups.

    A parameter goes into the first group when its name contains
    ``'weight'`` and it has more than one dimension; every other parameter
    goes into the second group, which carries ``weight_decay=0.``.

    Args:
        model: object exposing ``named_parameters()`` and ``parameters()``.

    Returns:
        ``[{'params': [...]}, {'params': [...], 'weight_decay': 0.}]``
    """
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        is_multi_dim_weight = 'weight' in name and len(param.size()) > 1
        bucket = decayed if is_multi_dim_weight else undecayed
        bucket.append(param)

    # Sanity check: every parameter landed in exactly one group.
    total = len(list(model.parameters()))
    assert total == len(decayed) + len(undecayed)

    return [
        dict(params=decayed),
        dict(params=undecayed, weight_decay=0.),
    ]
89 |
90 |
class EMA():
    """Exponential moving average of every entry in a model's ``state_dict``.

    Usage: call ``register()`` once to snapshot the model, then ``update()``
    whenever the averaged copy should absorb the model's current state.

    Attributes:
        model: the tracked model (anything exposing ``state_dict()``).
        decay: weight kept from the previous average on each update.
        shadow: dict mapping state names to the averaged tensors.
    """
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}

    def register(self):
        """Initialise the shadow state with a copy of the model's current state."""
        for name, state in self.model.state_dict().items():
            self.shadow[name] = state.clone()

    def update(self):
        """Blend the model's current state into the shadow state.

        Floating-point entries become
        ``(1 - decay) * state + decay * shadow``. Other entries (for example
        the integer ``num_batches_tracked`` counter of batch-norm layers) are
        copied as-is: averaging them would silently turn them into
        fractional float tensors.

        Raises:
            AssertionError: if the model has a state entry that was never
                registered or loaded.
        """
        for name, state in self.model.state_dict().items():
            assert name in self.shadow
            if state.is_floating_point():
                # The arithmetic already allocates a fresh tensor, so no
                # extra clone is needed.
                self.shadow[name] = (1.0 - self.decay) * state + self.decay * self.shadow[name]
            else:
                self.shadow[name] = state.clone()

    def state_dict(self):
        """Return the shadow dict itself (not a copy)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Replace or add shadow entries with clones of the given tensors."""
        for name, state in state_dict.items():
            self.shadow[name] = state.clone()
--------------------------------------------------------------------------------