├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── config
│   ├── __init__.py
│   └── cfgs.py
├── core
│   ├── __init__.py
│   ├── metric.py
│   ├── optimizer.py
│   ├── scheduler.py
│   └── solver.py
├── data
│   ├── __init__.py
│   └── imagenet.py
├── docs
│   ├── flops.png
│   ├── params.png
│   └── sss.png
├── symbol
│   ├── __init__.py
│   ├── resnet.py
│   └── resnext.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "incubator-mxnet-bk"]
2 | path = incubator-mxnet-bk
3 | url = https://github.com/huangzehao/incubator-mxnet-bk/
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sparse-structure-selection
2 |
3 | This code is a re-implementation of the ImageNet classification experiments in the paper [Data-Driven Sparse Structure Selection for Deep Neural Networks](https://arxiv.org/abs/1707.01213) (ECCV 2018).
4 |
5 |
6 |
7 | ![sss](docs/sss.png)
8 |
9 |
10 | ## Citation
11 | If you use our code in your research or wish to refer to the baseline results, please use the following BibTeX entry.
12 | ```
13 | @inproceedings{SSS2018,
14 |   author    = {Zehao Huang and Naiyan Wang},
15 |   title     = {Data-Driven Sparse Structure Selection for Deep Neural Networks},
16 |   booktitle = {ECCV},
17 |   year      = {2018}
18 | }
19 | ```
20 |
21 | ## Implementation
22 | This code is implemented with a modified [MXNet](https://github.com/huangzehao/incubator-mxnet-bk) that supports [ResNeXt-like](https://github.com/facebookresearch/ResNeXt) data augmentation. (Note: this version of MXNet does not support cuDNN 7.)
23 |
24 | ## ImageNet data preparation
25 | Download the [ImageNet](http://image-net.org/download-images) dataset and create a pass-through RecordIO (`.rec`) file, following [tornadomeet's repository](https://github.com/tornadomeet/ResNet#imagenet) but using the `unchanged` mode (no re-encoding).
26 |
27 | ## Run
28 | - Modify `config/cfgs.py`
29 | - Run `python train.py`
30 |
31 | ## Results on ImageNet-1k
32 |
33 | ![flops](docs/flops.png)
34 | ![params](docs/params.png)
35 |
36 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | from cfgs import *
2 |
--------------------------------------------------------------------------------
/config/cfgs.py:
--------------------------------------------------------------------------------
1 | # mxnet version: https://github.com/huangzehao/incubator-mxnet-bk
2 | mxnet_path = 'incubator-mxnet-bk/python/'
3 | gpu_list = [0, 1, 2, 3]
4 | dataset = "imagenet"
5 | model_prefix = "resnet-50-sss-0.01"
6 | network = "resnet"
7 | depth = 50
8 | model_load_prefix = model_prefix
9 | model_load_epoch = 0
10 | retrain = False
11 | sss = True
12 | gamma = 0.01
13 |
14 | # data
15 | data_dir = 'imagenet'
16 | batch_size = 32
17 | batch_size *= len(gpu_list)
18 | kv_store = 'device'
19 |
20 | # optimizer
21 | lr = 0.1
22 | wd = 0.0001
23 | momentum = 0.9
24 | if dataset == "imagenet":
25 | lr_step = [30, 60, 90]
26 | else:
27 | lr_step = [120, 160, 240]
28 | lr_factor = 0.1
29 | begin_epoch = model_load_epoch if retrain else 0
30 | num_epoch = 100
31 | frequent = 50
32 |
33 | # network config
34 | if dataset == "imagenet":
35 | num_classes = 1000
36 | if network.startswith("res"):
37 | units_dict = {"18": [2, 2, 2, 2],
38 | "34": [3, 4, 6, 3],
39 | "50": [3, 4, 6, 3],
40 | "101": [3, 4, 23, 3],
41 | "152": [3, 8, 36, 3]}
42 | units = units_dict[str(depth)]
43 | if depth >= 50:
44 | filter_list = [64, 256, 512, 1024, 2048]
45 | bottle_neck = True
46 | else:
47 | filter_list = [64, 64, 128, 256, 512]
48 | bottle_neck = False
49 | num_stage = 4
--------------------------------------------------------------------------------
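
A note on the derived values: with the defaults above, the per-GPU batch of 32 becomes an effective batch of 128 across the four listed GPUs, and depth 50 selects bottleneck units [3, 4, 6, 3] (16 residual blocks). A minimal standalone sketch (not repo code) that reproduces this arithmetic:

```python
# Sketch only: mirrors the derivations in config/cfgs.py for the defaults above.
gpu_list = [0, 1, 2, 3]
batch_size = 32 * len(gpu_list)   # per-GPU batch 32 -> effective batch 128

depth = 50
units_dict = {"18": [2, 2, 2, 2], "34": [3, 4, 6, 3], "50": [3, 4, 6, 3],
              "101": [3, 4, 23, 3], "152": [3, 8, 36, 3]}
units = units_dict[str(depth)]    # [3, 4, 6, 3]
bottle_neck = depth >= 50         # bottleneck blocks from ResNet-50 upward

print(batch_size, units, sum(units), bottle_neck)  # 128 [3, 4, 6, 3] 16 True
```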
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangzehao/sparse-structure-selection/12ed0fd50e2995544a765690cacf0de46937f77b/core/__init__.py
--------------------------------------------------------------------------------
/core/metric.py:
--------------------------------------------------------------------------------
1 | import mxnet as mx
2 | import numpy as np
3 |
4 |
5 | class KTAccMetric(mx.metric.EvalMetric):
6 | def __init__(self):
7 | super(KTAccMetric, self).__init__('accuracy')
8 |
9 | def update(self, labels, preds):
10 | pred_cf = preds[0]
11 | label_cf = labels[0]
12 |
13 | pred_label = mx.ndarray.argmax_channel(pred_cf).asnumpy().astype('int32')
14 | label = label_cf.asnumpy().astype('int32')
15 |
16 | self.sum_metric += (pred_label.flat == label.flat).sum()
17 | self.num_inst += len(pred_label.flat)
18 |
19 |
20 | class KTTopkAccMetric(mx.metric.EvalMetric):
21 | """Calculate top k predictions accuracy"""
22 |
23 | def __init__(self, **kwargs):
24 | super(KTTopkAccMetric, self).__init__('top_k_accuracy')
25 | try:
26 | self.top_k = kwargs['top_k']
27 | except KeyError:
28 | self.top_k = 1
29 | assert (self.top_k > 1), 'Please use Accuracy if top_k is no more than 1'
30 | self.name += '_%d' % self.top_k
31 |
32 | def update(self, labels, preds):
33 |
34 | pred_label = preds[0]
35 | label = labels[0]
36 | assert (len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims'
37 | pred_label = np.argsort(pred_label.asnumpy().astype('float32'), axis=1)
38 | label = label.asnumpy().astype('int32')
39 | num_samples = pred_label.shape[0]
40 | num_dims = len(pred_label.shape)
41 | if num_dims == 1:
42 | self.sum_metric += (pred_label.flat == label.flat).sum()
43 | elif num_dims == 2:
44 | num_classes = pred_label.shape[1]
45 | top_k = min(num_classes, self.top_k)
46 | for j in range(top_k):
47 | self.sum_metric += (pred_label[:, num_classes - 1 - j].flat == label.flat).sum()
48 | self.num_inst += num_samples
49 |
--------------------------------------------------------------------------------
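
To make the argsort-based logic above concrete: `np.argsort` ranks classes by ascending score, so the top-k predictions sit in the last k columns. A standalone NumPy sketch (illustration only, not repo code):

```python
import numpy as np

preds = np.array([[0.1, 0.6, 0.3],     # ranked ascending: [0, 2, 1]
                  [0.5, 0.2, 0.3]])    # ranked ascending: [1, 2, 0]
labels = np.array([2, 1])

ranked = np.argsort(preds, axis=1)
top_k = 2
hits = sum((ranked[:, -1 - j] == labels).sum() for j in range(top_k))
print(float(hits) / len(labels))       # 0.5: sample 0 hits (label 2 in top-2), sample 1 misses
```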
/core/optimizer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | import mxnet as mx
4 | from math import sqrt
5 | from mxnet.optimizer import Optimizer, SGD, clip
6 | from mxnet.ndarray import NDArray, zeros
7 | from mxnet.ndarray import sgd_update, sgd_mom_update
8 |
9 |
10 | @mx.optimizer.register
11 | class APGNAG(SGD):
12 | """APG and NAG.
13 | """
14 | def __init__(self, lambda_name=None, gamma=None, **kwargs):
15 | super(APGNAG, self).__init__(**kwargs)
16 | self.lambda_name = lambda_name
17 | self.gamma = gamma
18 |
19 | def update(self, index, weight, grad, state):
20 | assert(isinstance(weight, NDArray))
21 | assert(isinstance(grad, NDArray))
22 | lr = self._get_lr(index)
23 | wd = self._get_wd(index)
24 | self._update_count(index)
25 |
26 | grad = grad * self.rescale_grad
27 | if self.clip_gradient is not None:
28 | grad = clip(grad, -self.clip_gradient, self.clip_gradient)
29 |
30 |         if self.idx2name[index].startswith(self.lambda_name):
31 |             # APG update for the sparsity lambdas (Eqs. 10-12 in the paper)
32 |             if state is not None:
33 |                 mom = state
34 |                 mom[:] *= self.momentum
35 |                 z = weight - lr * grad  # Eq. (10): gradient step
36 |                 z = self.soft_thresholding(z, lr * self.gamma)  # proximal (soft-thresholding) step
37 |                 mom[:] = z - weight + mom  # Eq. (11)
38 |                 weight[:] = z + self.momentum * mom  # Eq. (12)
39 |             else:
40 |                 assert self.momentum == 0.0
41 |             # keep the lambdas non-negative
42 |             weight[:] = mx.ndarray.maximum(0.0, weight[:])
43 |             if self.num_update % 1000 == 0:
44 |                 print('{}: {}'.format(self.idx2name[index], weight.asnumpy()))
45 | else:
46 | if state is not None:
47 | mom = state
48 | mom[:] *= self.momentum
49 | grad += wd * weight
50 | mom[:] += grad
51 | grad[:] += self.momentum * mom
52 | weight[:] += -lr * grad
53 | else:
54 | assert self.momentum == 0.0
55 | weight[:] += -lr * (grad + wd * weight)
56 |
57 | @staticmethod
58 | def soft_thresholding(input, alpha):
59 | return mx.ndarray.sign(input) * mx.ndarray.maximum(0.0, mx.ndarray.abs(input) - alpha)
60 |
--------------------------------------------------------------------------------
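
The APG branch above follows Eqs. (10)-(12) of the paper: a gradient step, soft-thresholding (the proximal operator of the l1 penalty scaled by gamma), then a momentum update, with the lambdas finally clipped to be non-negative. A standalone NumPy sketch of one such update, with made-up numbers (illustration only, not repo code):

```python
import numpy as np

def soft_thresholding(x, alpha):
    # proximal operator of alpha * ||x||_1
    return np.sign(x) * np.maximum(0.0, np.abs(x) - alpha)

def apg_step(weight, mom, grad, lr=0.1, momentum=0.9, gamma=0.01):
    mom = momentum * mom
    z = weight - lr * grad                    # Eq. (10): gradient step
    z = soft_thresholding(z, lr * gamma)      # proximal step for the sparsity penalty
    mom = z - weight + mom                    # Eq. (11)
    weight = z + momentum * mom               # Eq. (12)
    return np.maximum(0.0, weight), mom       # lambdas are constrained non-negative

w, m = np.array([0.5, 0.002, -0.3]), np.zeros(3)
w, m = apg_step(w, m, np.array([0.1, 0.0, 0.2]))
print(w)  # [0.4791 0.0001 0.] -- small/negative lambdas get pushed to exactly zero
```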
/core/scheduler.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import config
3 |
4 | sys.path.insert(0, config.mxnet_path)
5 | import mxnet as mx
6 |
7 |
8 | def multi_factor_scheduler(begin_epoch, epoch_size, step, factor=0.1):
9 | step_ = [epoch_size * (x - begin_epoch) for x in step if x - begin_epoch > 0]
10 | return mx.lr_scheduler.MultiFactorScheduler(step=step_, factor=factor) if len(step_) else None
11 |
--------------------------------------------------------------------------------
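
Concretely, with the ImageNet defaults (1281167 training images, effective batch size 128, `lr_step = [30, 60, 90]`) and `begin_epoch = 0`, the scheduler decays the learning rate at iterations 300270, 600540 and 900810. A quick sketch of that step arithmetic (assumed numbers from the config above):

```python
begin_epoch = 0
epoch_size = 1281167 // 128   # ~10009 iterations per epoch
step = [30, 60, 90]           # decay epochs from config/cfgs.py

step_ = [epoch_size * (x - begin_epoch) for x in step if x - begin_epoch > 0]
print(step_)                  # [300270, 600540, 900810]
```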
/core/solver.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import sys
4 |
5 | import config
6 |
7 | sys.path.insert(0, config.mxnet_path)
8 | import mxnet as mx
9 | import numpy as np
10 | from mxnet.module import Module
11 | from mxnet import metric
12 | from mxnet.model import BatchEndParam
13 |
14 |
15 | def _as_list(obj):
16 | if isinstance(obj, list):
17 | return obj
18 | else:
19 | return [obj]
20 |
21 |
22 | class Solver(object):
23 | def __init__(self, symbol, data_names, label_names,
24 | data_shapes, label_shapes, logger=logging,
25 | context=mx.cpu(), work_load_list=None, fixed_param_names=None):
26 | self.symbol = symbol
27 | self.data_names = data_names
28 | self.label_names = label_names
29 | self.data_shapes = data_shapes
30 | self.label_shapes = label_shapes
31 | self.context = context
32 | self.work_load_list = work_load_list
33 | self.fixed_param_names = fixed_param_names
34 |
35 | if logger is None:
36 | logger = logging.getLogger()
37 | logger.setLevel(logging.INFO)
38 | self.logger = logger
39 | self.module = Module(symbol=self.symbol, data_names=self.data_names,
40 | label_names=self.label_names, logger=self.logger,
41 | context=self.context, work_load_list=self.work_load_list,
42 | fixed_param_names=self.fixed_param_names)
43 |
44 | def fit(self, train_data, eval_data=None,
45 | eval_metric='acc', validate_metric=None,
46 | work_load_list=None, epoch_end_callback=None,
47 | batch_end_callback=None, fixed_param_prefix=None,
48 | initializer=None, arg_params=None,
49 | aux_params=None, allow_missing=False,
50 | optimizer=None, optimizer_params=None,
51 | begin_epoch=0, num_epoch=None,
52 | kvstore='device'):
53 |
54 | self.module.bind(data_shapes=self.data_shapes, label_shapes=self.label_shapes, for_training=True)
55 | self.module.init_params(initializer=initializer,
56 | arg_params=arg_params,
57 | aux_params=aux_params,
58 | allow_missing=allow_missing)
59 | self.module.init_optimizer(kvstore=kvstore,
60 | optimizer=optimizer,
61 | optimizer_params=optimizer_params)
62 |
63 | if validate_metric is None:
64 | validate_metric = eval_metric
65 | if not isinstance(eval_metric, metric.EvalMetric):
66 | eval_metric = metric.create(eval_metric)
67 |
68 | # training loop
69 | for epoch in range(begin_epoch, num_epoch):
70 | tic = time.time()
71 | eval_metric.reset()
72 | nbatch = 0
73 | data_iter = iter(train_data)
74 | end_of_batch = False
75 | next_data_batch = next(data_iter)
76 | while not end_of_batch:
77 | data_batch = next_data_batch
78 | self.module.forward(data_batch, is_train=True)
79 | self.module.backward()
80 | self.module.update()
81 |
82 | try:
83 | next_data_batch = next(data_iter)
84 | except StopIteration:
85 | end_of_batch = True
86 |
87 | self.module.update_metric(eval_metric, data_batch.label)
88 |
89 | if batch_end_callback is not None:
90 | batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
91 | eval_metric=eval_metric,
92 | locals=locals())
93 | for callback in _as_list(batch_end_callback):
94 | callback(batch_end_params)
95 | nbatch += 1
96 |
97 | for name, val in eval_metric.get_name_value():
98 | self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
99 | toc = time.time()
100 | self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))
101 |
102 | arg_params, aux_params = self.module.get_params()
103 | for key in arg_params:
104 | if key == 'lambda_block':
105 | self.logger.info('{}: {}'.format(key, arg_params[key].asnumpy()))
106 | if key == 'lambda_group':
107 | self.logger.info('{}: {}'.format(key, np.sum(arg_params[key].asnumpy() != 0, 1)))
108 | self.module.set_params(arg_params, aux_params)
109 |
110 | if epoch_end_callback is not None:
111 | for callback in _as_list(epoch_end_callback):
112 | callback(epoch, self.symbol, arg_params, aux_params)
113 | if eval_data:
114 | res = self.module.score(eval_data, validate_metric,
115 | score_end_callback=None,
116 | batch_end_callback=None,
117 | reset=True,
118 | epoch=epoch)
119 | for name, val in res:
120 | self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)
121 |
122 | train_data.reset()
123 |
--------------------------------------------------------------------------------
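
One detail worth noting in `fit`: the loop pulls batch N+1 immediately after processing batch N, so the end of the epoch is detected one step ahead and data loading can overlap with computation. A minimal sketch of that look-ahead pattern (independent of MXNet; `process` is a hypothetical stand-in for forward/backward/update):

```python
def run_epoch(batches, process):
    # Look-ahead iteration as in Solver.fit: process the current batch, then
    # immediately try to pull the next one so the final batch is flagged early.
    data_iter = iter(batches)
    end_of_batch = False
    next_batch = next(data_iter)
    while not end_of_batch:
        batch = next_batch
        process(batch)        # stands in for forward / backward / update
        try:
            next_batch = next(data_iter)
        except StopIteration:
            end_of_batch = True

def show(batch):
    print('processed batch', batch)

run_epoch([0, 1, 2], show)
```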
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from imagenet import imagenet_iterator
2 |
--------------------------------------------------------------------------------
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import config
4 |
5 | sys.path.insert(0, config.mxnet_path)
6 | import mxnet as mx
7 |
8 |
9 | def imagenet_iterator(data_dir, batch_size, kv):
10 | train = mx.io.ImageRecordIter(
11 | path_imgrec = os.path.join(data_dir, "train.rec"),
12 | label_width = 1,
13 | data_name = 'data',
14 | label_name = 'softmax_label',
15 | data_shape = (3, 224, 224),
16 | batch_size = batch_size,
17 | pad = 0,
18 | fill_value = 127,
19 | facebook_aug = True,
20 | max_random_area = 1.0,
21 | min_random_area = 0.08,
22 | max_aspect_ratio = 4.0 / 3.0,
23 | min_aspect_ratio = 3.0 / 4.0,
24 | brightness = 0.4,
25 | contrast = 0.4,
26 | saturation = 0.4,
27 | mean_r = 123.68,
28 | mean_g = 116.28,
29 | mean_b = 103.53,
30 | std_r = 58.395,
31 | std_g = 57.12,
32 | std_b = 57.375,
33 | pca_noise = 0.1,
34 | scale = 1,
35 | inter_method = 2,
36 | rand_mirror = True,
37 | shuffle = True,
38 | shuffle_chunk_size = 4096,
39 | preprocess_threads = 20,
40 | prefetch_buffer = 16,
41 | num_parts = kv.num_workers,
42 | part_index = kv.rank)
43 |
44 | val = mx.io.ImageRecordIter(
45 | path_imgrec = os.path.join(data_dir, "val.rec"),
46 | label_width = 1,
47 | data_name = 'data',
48 | label_name = 'softmax_label',
49 | resize = 256,
50 | batch_size = batch_size,
51 | data_shape = (3, 224, 224),
52 | mean_r = 123.68,
53 | mean_g = 116.28,
54 | mean_b = 103.53,
55 | std_r = 58.395,
56 | std_g = 57.12,
57 | std_b = 57.375,
58 | scale = 1,
59 | inter_method = 2,
60 | rand_crop = False,
61 | rand_mirror = False,
62 | num_parts = kv.num_workers,
63 | part_index = kv.rank)
64 |
65 | num_examples = 1281167
66 | return train, val, num_examples
67 |
--------------------------------------------------------------------------------
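
For reference, a usage sketch (run from the repo root; `facebook_aug` is an option of the forked MXNet, so this only works against that fork, and it expects `imagenet/train.rec` and `imagenet/val.rec` to exist):

```python
import sys
import config

sys.path.insert(0, config.mxnet_path)   # the forked MXNet, as elsewhere in the repo
import mxnet as mx
from data import imagenet_iterator

kv = mx.kvstore.create('device')
train, val, num_examples = imagenet_iterator(data_dir='imagenet', batch_size=128, kv=kv)
batch = next(iter(train))
print(batch.data[0].shape)               # (128, 3, 224, 224)
```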
/docs/flops.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangzehao/sparse-structure-selection/12ed0fd50e2995544a765690cacf0de46937f77b/docs/flops.png
--------------------------------------------------------------------------------
/docs/params.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangzehao/sparse-structure-selection/12ed0fd50e2995544a765690cacf0de46937f77b/docs/params.png
--------------------------------------------------------------------------------
/docs/sss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangzehao/sparse-structure-selection/12ed0fd50e2995544a765690cacf0de46937f77b/docs/sss.png
--------------------------------------------------------------------------------
/symbol/__init__.py:
--------------------------------------------------------------------------------
1 | from resnet import resnet
2 | from resnext import resnext
3 |
--------------------------------------------------------------------------------
/symbol/resnet.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import config
3 |
4 | sys.path.insert(0, config.mxnet_path)
5 | import mxnet as mx
6 |
7 | eps = 1e-5
8 |
9 |
10 | def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True,
11 | bn_mom=0.9, workspace=512, memonger=False, lambda_block=None):
12 | if bottle_neck:
13 | bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn1')
14 | act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
15 | conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter * 0.25), kernel=(1, 1), stride=(1, 1),
16 | pad=(0, 0),
17 | no_bias=True, workspace=workspace, name=name + '_conv1')
18 | bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn2')
19 | act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
20 | conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter * 0.25), kernel=(3, 3), stride=stride,
21 | pad=(1, 1),
22 | no_bias=True, workspace=workspace, name=name + '_conv2')
23 | bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn3')
24 | act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
25 | conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0),
26 | no_bias=True,
27 | workspace=workspace, name=name + '_conv3')
28 | if config.sss:
29 | conv3 = mx.sym.broadcast_mul(conv3, lambda_block)
30 | if dim_match:
31 | shortcut = data
32 | else:
33 | shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1, 1), stride=stride, no_bias=True,
34 | workspace=workspace, name=name + '_sc')
35 | if memonger:
36 | shortcut._set_attr(mirror_stage='True')
37 |
38 | return conv3 + shortcut
39 | else:
40 | bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=eps, name=name + '_bn1')
41 | act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
42 | conv1 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3, 3), stride=stride, pad=(1, 1),
43 | no_bias=True, workspace=workspace, name=name + '_conv1')
44 | bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=eps, name=name + '_bn2')
45 | act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
46 | conv2 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
47 | no_bias=True, workspace=workspace, name=name + '_conv2')
48 | if dim_match:
49 | shortcut = data
50 | else:
51 | shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1, 1), stride=stride, no_bias=True,
52 | workspace=workspace, name=name + '_sc')
53 | if memonger:
54 | shortcut._set_attr(mirror_stage='True')
55 | return conv2 + shortcut
56 |
57 |
58 | def resnet(units, num_stage, filter_list, num_classes, data_type, bottle_neck=True,
59 | bn_mom=0.9, workspace=512, memonger=False):
60 | num_unit = len(units)
61 | assert (num_unit == num_stage)
62 |
63 | # declare lambda array
64 | if config.sss:
65 | num_block = 0
66 | for i in range(num_stage):
67 | num_block += 1
68 | for j in range(units[i] - 1):
69 | num_block += 1
70 |         print('block number: {}'.format(num_block))
71 | block_array = mx.symbol.Variable(name="lambda_block", shape=(num_block, ), dtype='float32', lr_mult=1, wd_mult=0, init=mx.initializer.One())
72 | block_split = mx.symbol.SliceChannel(block_array, num_outputs=num_block, axis=0, squeeze_axis=0,name="lambda_block_split")
73 |
74 | num_block = 0
75 | data = mx.sym.Variable(name='data')
76 | body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2, 2), pad=(3, 3),
77 | no_bias=True, name="conv0", workspace=workspace)
78 | body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=eps, momentum=bn_mom, name='bn0')
79 | body = mx.sym.Activation(data=body, act_type='relu', name='relu0')
80 | body = mx.symbol.Pooling(data=body, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max')
81 |
82 |     for i in range(num_stage):
83 |         body = residual_unit(body, filter_list[i + 1], (1 if i == 0 else 2, 1 if i == 0 else 2), False,
84 |                              name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, workspace=workspace,
85 |                              memonger=memonger, lambda_block=block_split[num_block] if config.sss else None)
86 |         num_block += 1
87 |         for j in range(units[i] - 1):
88 |             body = residual_unit(body, filter_list[i + 1], (1, 1), True, name='stage%d_unit%d' % (i + 1, j + 2),
89 |                                  bottle_neck=bottle_neck, workspace=workspace, memonger=memonger, lambda_block=block_split[num_block] if config.sss else None)
90 |             num_block += 1
91 | bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=eps, momentum=bn_mom, name='bn1')
92 | relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1')
93 | pool1 = mx.symbol.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
94 | flat = mx.symbol.Flatten(data=pool1)
95 | fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
96 | cls = mx.symbol.SoftmaxOutput(data=fc1, name='softmax')
97 | return cls
98 |
--------------------------------------------------------------------------------
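
In the SSS setup each residual unit's conv-branch output is multiplied by a scalar lambda before the skip connection, so driving a lambda to zero prunes the entire block. A NumPy illustration of that block-scaling idea (not repo code; `branch` is a hypothetical stand-in for the conv path):

```python
import numpy as np

def residual_unit(x, branch, lam):
    # y = lam * F(x) + x: a zero lambda reduces the unit to the identity map,
    # which is how SSS prunes whole residual blocks
    return lam * branch(x) + x

def branch(x):
    return 0.5 * x                           # stand-in for the conv path F(x)

x = np.ones((2, 4))
print(residual_unit(x, branch, 1.0)[0, 0])   # 1.5 -> block active
print(residual_unit(x, branch, 0.0)[0, 0])   # 1.0 -> block pruned, output == x
```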
/symbol/resnext.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import config
3 |
4 | sys.path.insert(0, config.mxnet_path)
5 | import mxnet as mx
6 |
7 | eps = 1e-5
8 |
9 |
10 | def xresidual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True,
11 | num_group=32, bn_mom=0.9, workspace=256, memonger=False, lambda_group=None, lambda_block=None):
12 | if config.sss:
13 | lambda_group = mx.symbol.repeat(lambda_group, repeats=int(num_filter * 0.5 / num_group))
14 | lambda_group = mx.symbol.reshape(data=lambda_group, shape=(0,1,1,1))
15 | if bottle_neck:
16 | conv1 = mx.sym.Convolution(data=data, num_filter=int(num_filter * 0.5), kernel=(1, 1), stride=(1, 1),
17 | pad=(0, 0),
18 | no_bias=True, workspace=workspace, name=name + '_conv1')
19 | bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn1')
20 | act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
21 |
22 | conv2 = mx.sym.Convolution(data=act1, num_filter=int(num_filter * 0.5), num_group=num_group, kernel=(3, 3),
23 | stride=stride, pad=(1, 1),
24 | no_bias=True, workspace=workspace, name=name + '_conv2')
25 | bn2 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn2')
26 | if config.sss:
27 | bn2 = mx.sym.transpose(data=bn2, axes=(1,0,2,3))
28 | bn2 = mx.sym.broadcast_mul(lhs=bn2,rhs=lambda_group)
29 | bn2 = mx.sym.transpose(data=bn2, axes=(1,0,2,3))
30 | act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
31 |
32 | conv3 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0),
33 | no_bias=True,
34 | workspace=workspace, name=name + '_conv3')
35 | bn3 = mx.sym.BatchNorm(data=conv3, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn3')
36 |
37 | if dim_match:
38 | shortcut = data
39 | else:
40 | shortcut_conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(1, 1), stride=stride,
41 | no_bias=True,
42 | workspace=workspace, name=name + '_sc')
43 | shortcut = mx.sym.BatchNorm(data=shortcut_conv, fix_gamma=False, eps=eps, momentum=bn_mom,
44 | name=name + '_sc_bn')
45 |
46 | if memonger:
47 | shortcut._set_attr(mirror_stage='True')
48 | if config.sss:
49 | bn3 = mx.sym.broadcast_mul(bn3, lambda_block)
50 | eltwise = bn3 + shortcut
51 | return mx.sym.Activation(data=eltwise, act_type='relu', name=name + '_relu')
52 | else:
53 | conv1 = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(3, 3), stride=stride, pad=(1, 1),
54 | no_bias=True, workspace=workspace, name=name + '_conv1')
55 | bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=eps, name=name + '_bn1')
56 | act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
57 |
58 | conv2 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
59 | no_bias=True, workspace=workspace, name=name + '_conv2')
60 | bn2 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, momentum=bn_mom, eps=eps, name=name + '_bn2')
61 |
62 | if dim_match:
63 | shortcut = data
64 | else:
65 | shortcut_conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(1, 1), stride=stride,
66 | no_bias=True,
67 | workspace=workspace, name=name + '_sc')
68 | shortcut = mx.sym.BatchNorm(data=shortcut_conv, fix_gamma=False, eps=eps, momentum=bn_mom,
69 | name=name + '_sc_bn')
70 |
71 | if memonger:
72 | shortcut._set_attr(mirror_stage='True')
73 | eltwise = bn2 + shortcut
74 | return mx.sym.Activation(data=eltwise, act_type='relu', name=name + '_relu')
75 |
76 |
77 | def resnext(units, num_stage, filter_list, num_classes, data_type, num_group=32,
78 | bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False):
79 | num_unit = len(units)
80 | assert (num_unit == num_stage)
81 |
82 | # declare lambda array
83 | if config.sss:
84 | num_block = 0
85 | for i in range(num_stage):
86 | num_block += 1
87 | for j in range(units[i] - 1):
88 | num_block += 1
89 |         print('block number: {}'.format(num_block))
90 |         print('group number: {}'.format(num_group))
91 | block_array = mx.symbol.Variable(name="lambda_block", shape=(num_block, ), dtype='float32', lr_mult=1, wd_mult=0, init=mx.initializer.One())
92 | block_split = mx.symbol.SliceChannel(block_array, num_outputs=num_block, axis=0, squeeze_axis=0,name="lambda_block_split")
93 | group_array = mx.symbol.Variable(name="lambda_group", shape=(num_block, num_group), dtype='float32', lr_mult=1, wd_mult=0, init=mx.initializer.One())
94 | group_split = mx.symbol.SliceChannel(group_array, num_outputs=num_block, axis=0, squeeze_axis=1,name="lambda_group_split")
95 |
96 | num_block = 0
97 | data = mx.sym.Variable(name='data')
98 | body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2, 2), pad=(3, 3),
99 | no_bias=True, name="conv0", workspace=workspace)
100 | body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=eps, momentum=bn_mom, name='bn0')
101 | body = mx.sym.Activation(data=body, act_type='relu', name='relu0')
102 | body = mx.symbol.Pooling(data=body, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max')
103 |
104 | for i in range(num_stage):
105 | body = xresidual_unit(body, filter_list[i + 1], (1 if i == 0 else 2, 1 if i == 0 else 2), False,
106 | name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, num_group=num_group,
107 | bn_mom=bn_mom, workspace=workspace, memonger=memonger,
108 | lambda_group=group_split[num_block] if config.sss else None,
109 | lambda_block=block_split[num_block] if config.sss else None)
110 | num_block += 1
111 | for j in range(units[i] - 1):
112 | body = xresidual_unit(body, filter_list[i + 1], (1, 1), True, name='stage%d_unit%d' % (i + 1, j + 2),
113 | bottle_neck=bottle_neck, num_group=num_group, bn_mom=bn_mom,
114 | workspace=workspace, memonger=memonger,
115 | lambda_group=group_split[num_block] if config.sss else None,
116 | lambda_block=block_split[num_block] if config.sss else None)
117 | num_block += 1
118 | pool1 = mx.symbol.Pooling(data=body, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
119 | flat = mx.symbol.Flatten(data=pool1)
120 | fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
121 | cls = mx.symbol.SoftmaxOutput(data=fc1, name='softmax')
122 | return cls
123 |
--------------------------------------------------------------------------------
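
The transpose pair around `broadcast_mul` in `xresidual_unit` exists because `lambda_group` holds one scale per channel group while broadcasting matches trailing axes: moving channels to axis 0 lets a (c, 1, 1, 1) scale vector multiply the whole (c, n, h, w) tensor, after which the layout is restored. The same channel-group scaling in NumPy (illustration only, not repo code):

```python
import numpy as np

n, c, h, w = 2, 8, 3, 3
num_group = 4                                     # -> 2 channels per group
x = np.ones((n, c, h, w))                         # NCHW feature map
lambda_group = np.array([1.0, 0.5, 0.0, 2.0])     # one scale per group

scales = np.repeat(lambda_group, c // num_group)  # per-channel scales, shape (c,)
y = np.transpose(x, (1, 0, 2, 3))                 # channels first: (c, n, h, w)
y = y * scales.reshape(c, 1, 1, 1)                # broadcast one scale per channel
y = np.transpose(y, (1, 0, 2, 3))                 # back to NCHW
print(y[0, :, 0, 0])                              # [1. 1. 0.5 0.5 0. 0. 2. 2.] -- group 2 pruned
```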
/train.py:
--------------------------------------------------------------------------------
1 | import logging, os
2 | import sys
3 |
4 | import config
5 |
6 | sys.path.insert(0, config.mxnet_path)
7 | import mxnet as mx
8 | from core.scheduler import multi_factor_scheduler
9 | from core.solver import Solver
10 | from core.metric import *
11 | from core.optimizer import *
12 | from data import *
13 | from symbol import *
14 |
15 |
16 | def main(config):
17 | # log file
18 | log_dir = "./log"
19 | if not os.path.exists(log_dir):
20 | os.mkdir(log_dir)
21 | logging.basicConfig(level=logging.INFO,
22 | format='%(asctime)s %(name)s %(levelname)s %(message)s',
23 | datefmt='%m-%d %H:%M',
24 | filename='{}/{}.log'.format(log_dir, config.model_prefix),
25 | filemode='a')
26 | console = logging.StreamHandler()
27 | console.setLevel(logging.INFO)
28 | formatter = logging.Formatter('%(name)s %(levelname)s %(message)s')
29 | console.setFormatter(formatter)
30 | logging.getLogger('').addHandler(console)
31 | # model folder
32 | model_dir = "./model"
33 | if not os.path.exists(model_dir):
34 | os.mkdir(model_dir)
35 |
36 | # set up environment
37 | devs = [mx.gpu(int(i)) for i in config.gpu_list]
38 | kv = mx.kvstore.create(config.kv_store)
39 |
40 | # set up iterator and symbol
41 | # iterator
42 | train, val, num_examples = imagenet_iterator(data_dir=config.data_dir,
43 | batch_size=config.batch_size,
44 | kv=kv)
45 | data_names = ('data',)
46 | label_names = ('softmax_label',)
47 | data_shapes = [('data', (config.batch_size, 3, 224, 224))]
48 | label_shapes = [('softmax_label', (config.batch_size,))]
49 |
50 | if config.network == 'resnet' or config.network == 'resnext':
51 | symbol = eval(config.network)(units=config.units,
52 | num_stage=config.num_stage,
53 | filter_list=config.filter_list,
54 | num_classes=config.num_classes,
55 | data_type=config.dataset,
56 | bottle_neck=config.bottle_neck)
57 |
58 | # train
59 | epoch_size = max(int(num_examples / config.batch_size / kv.num_workers), 1)
60 | if config.lr_step is not None:
61 | lr_scheduler = multi_factor_scheduler(config.begin_epoch, epoch_size, step=config.lr_step,
62 | factor=config.lr_factor)
63 | else:
64 | lr_scheduler = None
65 |
66 | optimizer_params = {'learning_rate': config.lr,
67 | 'lr_scheduler': lr_scheduler,
68 | 'wd': config.wd,
69 | 'momentum': config.momentum}
70 | optimizer = "nag"
71 | if config.sss:
72 | sss_optimizer_params = {'lambda_name': 'lambda',
73 | 'gamma': config.gamma}
74 | optimizer_params.update(sss_optimizer_params)
75 | optimizer = "apgnag"
76 | eval_metric = ['acc']
77 | if config.dataset == "imagenet":
78 | eval_metric.append(mx.metric.create('top_k_accuracy', top_k=5))
79 |
80 | solver = Solver(symbol=symbol,
81 | data_names=data_names,
82 | label_names=label_names,
83 | data_shapes=data_shapes,
84 | label_shapes=label_shapes,
85 | logger=logging,
86 | context=devs)
87 | epoch_end_callback = mx.callback.do_checkpoint("./model/" + config.model_prefix)
88 | batch_end_callback = mx.callback.Speedometer(config.batch_size, config.frequent)
89 | arg_params = None
90 | aux_params = None
91 | if config.retrain:
92 | _, arg_params, aux_params = mx.model.load_checkpoint("model/{}".format(config.model_load_prefix),
93 | config.model_load_epoch)
94 |
95 | if config.network.startswith('res'):
96 | initializer = mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2)
97 |
98 | solver.fit(train_data=train,
99 | eval_data=val,
100 | eval_metric=eval_metric,
101 | epoch_end_callback=epoch_end_callback,
102 | batch_end_callback=batch_end_callback,
103 | initializer=initializer,
104 | arg_params=arg_params,
105 | aux_params=aux_params,
106 | optimizer=optimizer,
107 | optimizer_params=optimizer_params,
108 | begin_epoch=config.begin_epoch,
109 | num_epoch=config.num_epoch,
110 | kvstore=kv)
111 |
112 |
113 | if __name__ == '__main__':
114 | main(config)
115 |
--------------------------------------------------------------------------------