├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── config
│   ├── __init__.py
│   └── cfgs.py
├── core
│   ├── __init__.py
│   ├── metric.py
│   ├── optimizer.py
│   ├── scheduler.py
│   └── solver.py
├── data
│   ├── __init__.py
│   └── imagenet.py
├── docs
│   ├── flops.png
│   ├── params.png
│   └── sss.png
├── symbol
│   ├── __init__.py
│   ├── resnet.py
│   └── resnext.py
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "incubator-mxnet-bk"] 2 | path = incubator-mxnet-bk 3 | url = https://github.com/huangzehao/incubator-mxnet-bk/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sparse-structure-selection 2 | 3 | This code is a re-implementation of the ImageNet classification experiments in the paper [Data-Driven Sparse Structure Selection for Deep Neural Networks](https://arxiv.org/abs/1707.01213) (ECCV 2018). 4 | 5 | 6 |
7 | ![Overview of Sparse Structure Selection](docs/sss.png) 8 |
9 | 10 | ## Citation 11 | If you use our code in your research or wish to refer to the baseline results, please use the following BibTeX entry. 12 | ``` 13 | @inproceedings{SSS2018, 14 | author = {Zehao Huang and Naiyan Wang}, 15 | title = {Data-Driven Sparse Structure Selection for Deep Neural Networks}, 16 | booktitle = {ECCV}, 17 | year = {2018} 18 | } 19 | ``` 20 | 21 | ## Implementation 22 | This code is implemented with a modified [MXNet](https://github.com/huangzehao/incubator-mxnet-bk) that supports [ResNeXt-like](https://github.com/facebookresearch/ResNeXt) data augmentation. (Note that this version of MXNet does not support cuDNN 7.) 23 | 24 | ## ImageNet data preparation 25 | Download the [ImageNet](http://image-net.org/download-images) dataset and pack it into a pass-through `.rec` file, following [tornadomeet's repository](https://github.com/tornadomeet/ResNet#imagenet) but using `unchange` mode, i.e. storing the original images without resizing or re-encoding them. 26 | 27 | ## Run 28 | - Modify ```config/cfgs.py``` 29 | - Run ```python train.py``` 30 | 31 | ## Results on ImageNet-1k 32 |
33 | ![flops](docs/flops.png) 34 | ![params](docs/params.png) 35 |
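## How the sparsity is induced (sketch)

During training, each residual block's output is scaled by a learnable factor λ (`lambda_block` in `symbol/resnet.py`; `symbol/resnext.py` additionally scales each convolution group by `lambda_group`). The `APGNAG` optimizer in `core/optimizer.py` updates these factors with an accelerated proximal gradient (APG) step whose soft-thresholding drives unimportant factors to exactly zero, so the corresponding blocks or groups can be pruned afterwards. Below is a minimal NumPy sketch of that update (the equation numbers follow the comments in `core/optimizer.py`; all names here are illustrative and not part of the codebase):

```python
import numpy as np

def soft_threshold(x, alpha):
    # S_alpha(x) = sign(x) * max(|x| - alpha, 0): shrinks x toward zero and
    # maps it to exactly zero once |x| <= alpha
    return np.sign(x) * np.maximum(np.abs(x) - alpha, 0.0)

def apg_step(lam, grad, mom, lr=0.1, gamma=0.01, momentum=0.9):
    # one accelerated proximal gradient step, mirroring APGNAG.update
    mom = momentum * mom
    z = soft_threshold(lam - lr * grad, lr * gamma)  # gradient step + prox (eq. 10)
    mom = z - lam + mom                              # eq. 11
    return z + momentum * mom, mom                   # eq. 12

lam = np.array([0.8, 0.05, -0.3, 0.001])
print(soft_threshold(lam, 0.1))  # -> [ 0.7  0.  -0.2  0. ]
```

`gamma` in `config/cfgs.py` sets the strength of the sparsity penalty (larger values prune more aggressively); the actual optimizer also projects λ back to non-negative values.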
36 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from cfgs import * 2 | -------------------------------------------------------------------------------- /config/cfgs.py: -------------------------------------------------------------------------------- 1 | # mxnet version: https://github.com/huangzehao/incubator-mxnet-bk 2 | mxnet_path = 'incubator-mxnet-bk/python/' 3 | gpu_list = [0, 1, 2, 3] 4 | dataset = "imagenet" 5 | model_prefix = "resnet-50-sss-0.01" 6 | network = "resnet" 7 | depth = 50 8 | model_load_prefix = model_prefix 9 | model_load_epoch = 0 10 | retrain = False 11 | sss = True 12 | gamma = 0.01 13 | 14 | # data 15 | data_dir = 'imagenet' 16 | batch_size = 32 17 | batch_size *= len(gpu_list) 18 | kv_store = 'device' 19 | 20 | # optimizer 21 | lr = 0.1 22 | wd = 0.0001 23 | momentum = 0.9 24 | if dataset == "imagenet": 25 | lr_step = [30, 60, 90] 26 | else: 27 | lr_step = [120, 160, 240] 28 | lr_factor = 0.1 29 | begin_epoch = model_load_epoch if retrain else 0 30 | num_epoch = 100 31 | frequent = 50 32 | 33 | # network config 34 | if dataset == "imagenet": 35 | num_classes = 1000 36 | if network.startswith("res"): 37 | units_dict = {"18": [2, 2, 2, 2], 38 | "34": [3, 4, 6, 3], 39 | "50": [3, 4, 6, 3], 40 | "101": [3, 4, 23, 3], 41 | "152": [3, 8, 36, 3]} 42 | units = units_dict[str(depth)] 43 | if depth >= 50: 44 | filter_list = [64, 256, 512, 1024, 2048] 45 | bottle_neck = True 46 | else: 47 | filter_list = [64, 64, 128, 256, 512] 48 | bottle_neck = False 49 | num_stage = 4 -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangzehao/sparse-structure-selection/12ed0fd50e2995544a765690cacf0de46937f77b/core/__init__.py -------------------------------------------------------------------------------- /core/metric.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import numpy as np 3 | 4 | 5 | class KTAccMetric(mx.metric.EvalMetric): 6 | def __init__(self): 7 | super(KTAccMetric, self).__init__('accuracy') 8 | 9 | def update(self, labels, preds): 10 | pred_cf = preds[0] 11 | label_cf = labels[0] 12 | 13 | pred_label = mx.ndarray.argmax_channel(pred_cf).asnumpy().astype('int32') 14 | label = label_cf.asnumpy().astype('int32') 15 | 16 | self.sum_metric += (pred_label.flat == label.flat).sum() 17 | self.num_inst += len(pred_label.flat) 18 | 19 | 20 | class KTTopkAccMetric(mx.metric.EvalMetric): 21 | """Calculate top k predictions accuracy""" 22 | 23 | def __init__(self, **kwargs): 24 | super(KTTopkAccMetric, self).__init__('top_k_accuracy') 25 | try: 26 | self.top_k = kwargs['top_k'] 27 | except KeyError: 28 | self.top_k = 1 29 | assert (self.top_k > 1), 'Please use Accuracy if top_k is no more than 1' 30 | self.name += '_%d' % self.top_k 31 | 32 | def update(self, labels, preds): 33 | 34 | pred_label = preds[0] 35 | label = labels[0] 36 | assert (len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims' 37 | pred_label = np.argsort(pred_label.asnumpy().astype('float32'), axis=1) 38 | label = label.asnumpy().astype('int32') 39 | num_samples = pred_label.shape[0] 40 | num_dims = len(pred_label.shape) 41 | if num_dims == 1: 42 | self.sum_metric += (pred_label.flat == label.flat).sum() 43 | elif num_dims == 2: 
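            # pred_label has been argsorted along axis 1, so each row holds class
            # indices in ascending score order; the last top_k columns are therefore
            # the top-k predictions, each compared against the true label below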
44 | num_classes = pred_label.shape[1] 45 | top_k = min(num_classes, self.top_k) 46 | for j in range(top_k): 47 | self.sum_metric += (pred_label[:, num_classes - 1 - j].flat == label.flat).sum() 48 | self.num_inst += num_samples 49 | -------------------------------------------------------------------------------- /core/optimizer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import mxnet as mx 4 | from math import sqrt 5 | from mxnet.optimizer import Optimizer, SGD, clip 6 | from mxnet.ndarray import NDArray, zeros 7 | from mxnet.ndarray import sgd_update, sgd_mom_update 8 | 9 | 10 | @mx.optimizer.register 11 | class APGNAG(SGD): 12 | """APG updates for the sparsity-inducing lambda parameters; plain NAG 13 | for all other weights.""" 14 | def __init__(self, lambda_name=None, gamma=None, **kwargs): 15 | super(APGNAG, self).__init__(**kwargs) 16 | self.lambda_name = lambda_name 17 | self.gamma = gamma 18 | 19 | def update(self, index, weight, grad, state): 20 | assert(isinstance(weight, NDArray)) 21 | assert(isinstance(grad, NDArray)) 22 | lr = self._get_lr(index) 23 | wd = self._get_wd(index) 24 | self._update_count(index) 25 | 26 | grad = grad * self.rescale_grad 27 | if self.clip_gradient is not None: 28 | grad = clip(grad, -self.clip_gradient, self.clip_gradient) 29 | 30 | if self.idx2name[index].startswith(self.lambda_name): 31 | # APG update for the lambda parameters (eqs. 10-12 in the paper) 32 | if state is not None: 33 | mom = state 34 | mom[:] *= self.momentum 35 | z = weight - lr * grad # eq. 10 36 | z = self.soft_thresholding(z, lr * self.gamma) 37 | mom[:] = z - weight + mom # eq. 11 38 | weight[:] = z + self.momentum * mom # eq. 12 39 | else: 40 | assert self.momentum == 0.0 41 | # non-negative projection (lambda >= 0) 42 | weight[:] = mx.ndarray.maximum(0.0, weight[:]) 43 | if self.num_update % 1000 == 0: 44 | print self.idx2name[index], weight.asnumpy() 45 | else: # plain NAG with weight decay for all other parameters 46 | if state is not None: 47 | mom = state 48 | mom[:] *= self.momentum 49 | grad += wd * weight 50 | mom[:] += grad 51 | grad[:] += self.momentum * mom 52 | weight[:] += -lr * grad 53 | else: 54 | assert self.momentum == 0.0 55 | weight[:] += -lr * (grad + wd * weight) 56 | 57 | @staticmethod 58 | def soft_thresholding(x, alpha): 59 | return mx.ndarray.sign(x) * mx.ndarray.maximum(0.0, mx.ndarray.abs(x) - alpha) 60 | -------------------------------------------------------------------------------- /core/scheduler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import config 3 | 4 | sys.path.insert(0, config.mxnet_path) 5 | import mxnet as mx 6 | 7 | 8 | def multi_factor_scheduler(begin_epoch, epoch_size, step, factor=0.1): 9 | step_ = [epoch_size * (x - begin_epoch) for x in step if x - begin_epoch > 0] 10 | return mx.lr_scheduler.MultiFactorScheduler(step=step_, factor=factor) if len(step_) else None 11 | -------------------------------------------------------------------------------- /core/solver.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | import sys 4 | 5 | import config 6 | 7 | sys.path.insert(0, config.mxnet_path) 8 | import mxnet as mx 9 | import numpy as np 10 | from mxnet.module import Module 11 | from mxnet import metric 12 | from mxnet.model import BatchEndParam 13 | 14 | 15 | def _as_list(obj): 16 | if isinstance(obj, list): 17 | return obj 18 | else: 19 | return [obj] 20 | 21 | 22 | class Solver(object): 23 | def __init__(self, symbol, data_names, label_names, 24 | data_shapes, label_shapes, logger=logging, 25 | context=mx.cpu(),
work_load_list=None, fixed_param_names=None): 26 | self.symbol = symbol 27 | self.data_names = data_names 28 | self.label_names = label_names 29 | self.data_shapes = data_shapes 30 | self.label_shapes = label_shapes 31 | self.context = context 32 | self.work_load_list = work_load_list 33 | self.fixed_param_names = fixed_param_names 34 | 35 | if logger is None: 36 | logger = logging.getLogger() 37 | logger.setLevel(logging.INFO) 38 | self.logger = logger 39 | self.module = Module(symbol=self.symbol, data_names=self.data_names, 40 | label_names=self.label_names, logger=self.logger, 41 | context=self.context, work_load_list=self.work_load_list, 42 | fixed_param_names=self.fixed_param_names) 43 | 44 | def fit(self, train_data, eval_data=None, 45 | eval_metric='acc', validate_metric=None, 46 | work_load_list=None, epoch_end_callback=None, 47 | batch_end_callback=None, fixed_param_prefix=None, 48 | initializer=None, arg_params=None, 49 | aux_params=None, allow_missing=False, 50 | optimizer=None, optimizer_params=None, 51 | begin_epoch=0, num_epoch=None, 52 | kvstore='device'): 53 | 54 | self.module.bind(data_shapes=self.data_shapes, label_shapes=self.label_shapes, for_training=True) 55 | self.module.init_params(initializer=initializer, 56 | arg_params=arg_params, 57 | aux_params=aux_params, 58 | allow_missing=allow_missing) 59 | self.module.init_optimizer(kvstore=kvstore, 60 | optimizer=optimizer, 61 | optimizer_params=optimizer_params) 62 | 63 | if validate_metric is None: 64 | validate_metric = eval_metric 65 | if not isinstance(eval_metric, metric.EvalMetric): 66 | eval_metric = metric.create(eval_metric) 67 | 68 | # training loop 69 | for epoch in range(begin_epoch, num_epoch): 70 | tic = time.time() 71 | eval_metric.reset() 72 | nbatch = 0 73 | data_iter = iter(train_data) 74 | end_of_batch = False 75 | next_data_batch = next(data_iter) 76 | while not end_of_batch: 77 | data_batch = next_data_batch 78 | self.module.forward(data_batch, is_train=True) 79 | self.module.backward() 80 | self.module.update() 81 | 82 | try: 83 | next_data_batch = next(data_iter) 84 | except StopIteration: 85 | end_of_batch = True 86 | 87 | self.module.update_metric(eval_metric, data_batch.label) 88 | 89 | if batch_end_callback is not None: 90 | batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, 91 | eval_metric=eval_metric, 92 | locals=locals()) 93 | for callback in _as_list(batch_end_callback): 94 | callback(batch_end_params) 95 | nbatch += 1 96 | 97 | for name, val in eval_metric.get_name_value(): 98 | self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val) 99 | toc = time.time() 100 | self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic)) 101 | 102 | arg_params, aux_params = self.module.get_params() 103 | for key in arg_params: 104 | if key == 'lambda_block': 105 | self.logger.info('{}: {}'.format(key, arg_params[key].asnumpy())) 106 | if key == 'lambda_group': 107 | self.logger.info('{}: {}'.format(key, np.sum(arg_params[key].asnumpy() != 0, 1))) 108 | self.module.set_params(arg_params, aux_params) 109 | 110 | if epoch_end_callback is not None: 111 | for callback in _as_list(epoch_end_callback): 112 | callback(epoch, self.symbol, arg_params, aux_params) 113 | if eval_data: 114 | res = self.module.score(eval_data, validate_metric, 115 | score_end_callback=None, 116 | batch_end_callback=None, 117 | reset=True, 118 | epoch=epoch) 119 | for name, val in res: 120 | self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val) 121 | 122 | train_data.reset() 123 | 
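The `lambda_block` / `lambda_group` values logged at each epoch above can also be inspected offline from a saved checkpoint. A small sketch for counting how many residual blocks survive training (the checkpoint prefix and epoch are illustrative; `train.py` saves under `./model/` using `config.model_prefix`):

```python
import mxnet as mx
import numpy as np

# load a checkpoint written by train.py (prefix and epoch are illustrative)
_, arg_params, _ = mx.model.load_checkpoint('model/resnet-50-sss-0.01', 100)

lam = arg_params['lambda_block'].asnumpy()
print('blocks kept: %d / %d' % (int((lam != 0).sum()), lam.size))
```

A block whose λ has reached exactly zero contributes nothing to the output and can be removed at inference time.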
-------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from imagenet import imagenet_iterator 2 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import config 4 | 5 | sys.path.insert(0, config.mxnet_path) 6 | import mxnet as mx 7 | 8 | 9 | def imagenet_iterator(data_dir, batch_size, kv): 10 | train = mx.io.ImageRecordIter( 11 | path_imgrec = os.path.join(data_dir, "train.rec"), 12 | label_width = 1, 13 | data_name = 'data', 14 | label_name = 'softmax_label', 15 | data_shape = (3, 224, 224), 16 | batch_size = batch_size, 17 | pad = 0, 18 | fill_value = 127, 19 | facebook_aug = True, 20 | max_random_area = 1.0, 21 | min_random_area = 0.08, 22 | max_aspect_ratio = 4.0 / 3.0, 23 | min_aspect_ratio = 3.0 / 4.0, 24 | brightness = 0.4, 25 | contrast = 0.4, 26 | saturation = 0.4, 27 | mean_r = 123.68, 28 | mean_g = 116.28, 29 | mean_b = 103.53, 30 | std_r = 58.395, 31 | std_g = 57.12, 32 | std_b = 57.375, 33 | pca_noise = 0.1, 34 | scale = 1, 35 | inter_method = 2, 36 | rand_mirror = True, 37 | shuffle = True, 38 | shuffle_chunk_size = 4096, 39 | preprocess_threads = 20, 40 | prefetch_buffer = 16, 41 | num_parts = kv.num_workers, 42 | part_index = kv.rank) 43 | 44 | val = mx.io.ImageRecordIter( 45 | path_imgrec = os.path.join(data_dir, "val.rec"), 46 | label_width = 1, 47 | data_name = 'data', 48 | label_name = 'softmax_label', 49 | resize = 256, 50 | batch_size = batch_size, 51 | data_shape = (3, 224, 224), 52 | mean_r = 123.68, 53 | mean_g = 116.28, 54 | mean_b = 103.53, 55 | std_r = 58.395, 56 | std_g = 57.12, 57 | std_b = 57.375, 58 | scale = 1, 59 | inter_method = 2, 60 | rand_crop = False, 61 | rand_mirror = False, 62 | num_parts = kv.num_workers, 63 | part_index = kv.rank) 64 | 65 | num_examples = 1281167 66 | return train, val, num_examples 67 | -------------------------------------------------------------------------------- /docs/flops.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangzehao/sparse-structure-selection/12ed0fd50e2995544a765690cacf0de46937f77b/docs/flops.png -------------------------------------------------------------------------------- /docs/params.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangzehao/sparse-structure-selection/12ed0fd50e2995544a765690cacf0de46937f77b/docs/params.png -------------------------------------------------------------------------------- /docs/sss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangzehao/sparse-structure-selection/12ed0fd50e2995544a765690cacf0de46937f77b/docs/sss.png -------------------------------------------------------------------------------- /symbol/__init__.py: -------------------------------------------------------------------------------- 1 | from resnet import resnet 2 | from resnext import resnext 3 | -------------------------------------------------------------------------------- /symbol/resnet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import config 3 | 4 | sys.path.insert(0, config.mxnet_path) 5 | import mxnet as mx 6 | 7 | eps = 1e-5 8 | 9 | 10 
| def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, 11 | bn_mom=0.9, workspace=512, memonger=False, lambda_block=None): 12 | if bottle_neck: 13 | bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn1') 14 | act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') 15 | conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter * 0.25), kernel=(1, 1), stride=(1, 1), 16 | pad=(0, 0), 17 | no_bias=True, workspace=workspace, name=name + '_conv1') 18 | bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn2') 19 | act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') 20 | conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter * 0.25), kernel=(3, 3), stride=stride, 21 | pad=(1, 1), 22 | no_bias=True, workspace=workspace, name=name + '_conv2') 23 | bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn3') 24 | act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3') 25 | conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), 26 | no_bias=True, 27 | workspace=workspace, name=name + '_conv3') 28 | if config.sss: 29 | conv3 = mx.sym.broadcast_mul(conv3, lambda_block) 30 | if dim_match: 31 | shortcut = data 32 | else: 33 | shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1, 1), stride=stride, no_bias=True, 34 | workspace=workspace, name=name + '_sc') 35 | if memonger: 36 | shortcut._set_attr(mirror_stage='True') 37 | 38 | return conv3 + shortcut 39 | else: 40 | bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=eps, name=name + '_bn1') 41 | act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') 42 | conv1 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3, 3), stride=stride, pad=(1, 1), 43 | no_bias=True, workspace=workspace, name=name + '_conv1') 44 | bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=eps, name=name + '_bn2') 45 | act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') 46 | conv2 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(3, 3), stride=(1, 1), pad=(1, 1), 47 | no_bias=True, workspace=workspace, name=name + '_conv2') 48 | if dim_match: 49 | shortcut = data 50 | else: 51 | shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1, 1), stride=stride, no_bias=True, 52 | workspace=workspace, name=name + '_sc') 53 | if memonger: 54 | shortcut._set_attr(mirror_stage='True') 55 | return conv2 + shortcut 56 | 57 | 58 | def resnet(units, num_stage, filter_list, num_classes, data_type, bottle_neck=True, 59 | bn_mom=0.9, workspace=512, memonger=False): 60 | num_unit = len(units) 61 | assert (num_unit == num_stage) 62 | 63 | # declare lambda array 64 | if config.sss: 65 | num_block = 0 66 | for i in range(num_stage): 67 | num_block += 1 68 | for j in range(units[i] - 1): 69 | num_block += 1 70 | print 'block number:', num_block 71 | block_array = mx.symbol.Variable(name="lambda_block", shape=(num_block, ), dtype='float32', lr_mult=1, wd_mult=0, init=mx.initializer.One()) 72 | block_split = mx.symbol.SliceChannel(block_array, num_outputs=num_block, axis=0, squeeze_axis=0,name="lambda_block_split") 73 | 74 | num_block = 0 75 | data = mx.sym.Variable(name='data') 76 | body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2, 2), pad=(3, 3), 77 | 
no_bias=True, name="conv0", workspace=workspace) 78 | body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=eps, momentum=bn_mom, name='bn0') 79 | body = mx.sym.Activation(data=body, act_type='relu', name='relu0') 80 | body = mx.symbol.Pooling(data=body, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max') 81 | 82 | for i in range(num_stage): 83 | body = residual_unit(body, filter_list[i + 1], (1 if i == 0 else 2, 1 if i == 0 else 2), False, 84 | name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, workspace=workspace, 85 | memonger=memonger, lambda_block=block_split[num_block]) 86 | num_block += 1 87 | for j in range(units[i] - 1): 88 | body = residual_unit(body, filter_list[i + 1], (1, 1), True, name='stage%d_unit%d' % (i + 1, j + 2), 89 | bottle_neck=bottle_neck, workspace=workspace, memonger=memonger, lambda_block=block_split[num_block]) 90 | num_block += 1 91 | bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=eps, momentum=bn_mom, name='bn1') 92 | relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1') 93 | pool1 = mx.symbol.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1') 94 | flat = mx.symbol.Flatten(data=pool1) 95 | fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=num_classes, name='fc1') 96 | cls = mx.symbol.SoftmaxOutput(data=fc1, name='softmax') 97 | return cls 98 | -------------------------------------------------------------------------------- /symbol/resnext.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import config 3 | 4 | sys.path.insert(0, config.mxnet_path) 5 | import mxnet as mx 6 | 7 | eps = 1e-5 8 | 9 | 10 | def xresidual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, 11 | num_group=32, bn_mom=0.9, workspace=256, memonger=False, lambda_group=None, lambda_block=None): 12 | if config.sss: 13 | lambda_group = mx.symbol.repeat(lambda_group, repeats=int(num_filter * 0.5 / num_group)) 14 | lambda_group = mx.symbol.reshape(data=lambda_group, shape=(0,1,1,1)) 15 | if bottle_neck: 16 | conv1 = mx.sym.Convolution(data=data, num_filter=int(num_filter * 0.5), kernel=(1, 1), stride=(1, 1), 17 | pad=(0, 0), 18 | no_bias=True, workspace=workspace, name=name + '_conv1') 19 | bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn1') 20 | act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') 21 | 22 | conv2 = mx.sym.Convolution(data=act1, num_filter=int(num_filter * 0.5), num_group=num_group, kernel=(3, 3), 23 | stride=stride, pad=(1, 1), 24 | no_bias=True, workspace=workspace, name=name + '_conv2') 25 | bn2 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn2') 26 | if config.sss: 27 | bn2 = mx.sym.transpose(data=bn2, axes=(1,0,2,3)) 28 | bn2 = mx.sym.broadcast_mul(lhs=bn2,rhs=lambda_group) 29 | bn2 = mx.sym.transpose(data=bn2, axes=(1,0,2,3)) 30 | act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') 31 | 32 | conv3 = mx.sym.Convolution(data=act2, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), 33 | no_bias=True, 34 | workspace=workspace, name=name + '_conv3') 35 | bn3 = mx.sym.BatchNorm(data=conv3, fix_gamma=False, eps=eps, momentum=bn_mom, name=name + '_bn3') 36 | 37 | if dim_match: 38 | shortcut = data 39 | else: 40 | shortcut_conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(1, 1), stride=stride, 41 | no_bias=True, 42 | workspace=workspace, name=name + '_sc') 43 | shortcut = 
mx.sym.BatchNorm(data=shortcut_conv, fix_gamma=False, eps=eps, momentum=bn_mom, 44 | name=name + '_sc_bn') 45 | 46 | if memonger: 47 | shortcut._set_attr(mirror_stage='True') 48 | if config.sss: 49 | bn3 = mx.sym.broadcast_mul(bn3, lambda_block) 50 | eltwise = bn3 + shortcut 51 | return mx.sym.Activation(data=eltwise, act_type='relu', name=name + '_relu') 52 | else: 53 | conv1 = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(3, 3), stride=stride, pad=(1, 1), 54 | no_bias=True, workspace=workspace, name=name + '_conv1') 55 | bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=eps, name=name + '_bn1') 56 | act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') 57 | 58 | conv2 = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(3, 3), stride=(1, 1), pad=(1, 1), 59 | no_bias=True, workspace=workspace, name=name + '_conv2') 60 | bn2 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, momentum=bn_mom, eps=eps, name=name + '_bn2') 61 | 62 | if dim_match: 63 | shortcut = data 64 | else: 65 | shortcut_conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=(1, 1), stride=stride, 66 | no_bias=True, 67 | workspace=workspace, name=name + '_sc') 68 | shortcut = mx.sym.BatchNorm(data=shortcut_conv, fix_gamma=False, eps=eps, momentum=bn_mom, 69 | name=name + '_sc_bn') 70 | 71 | if memonger: 72 | shortcut._set_attr(mirror_stage='True') 73 | eltwise = bn2 + shortcut 74 | return mx.sym.Activation(data=eltwise, act_type='relu', name=name + '_relu') 75 | 76 | 77 | def resnext(units, num_stage, filter_list, num_classes, data_type, num_group=32, 78 | bottle_neck=True, bn_mom=0.9, workspace=256, memonger=False): 79 | num_unit = len(units) 80 | assert (num_unit == num_stage) 81 | 82 | # declare lambda array 83 | if config.sss: 84 | num_block = 0 85 | for i in range(num_stage): 86 | num_block += 1 87 | for j in range(units[i] - 1): 88 | num_block += 1 89 | print 'block number:', num_block 90 | print 'group number:', num_group 91 | block_array = mx.symbol.Variable(name="lambda_block", shape=(num_block, ), dtype='float32', lr_mult=1, wd_mult=0, init=mx.initializer.One()) 92 | block_split = mx.symbol.SliceChannel(block_array, num_outputs=num_block, axis=0, squeeze_axis=0,name="lambda_block_split") 93 | group_array = mx.symbol.Variable(name="lambda_group", shape=(num_block, num_group), dtype='float32', lr_mult=1, wd_mult=0, init=mx.initializer.One()) 94 | group_split = mx.symbol.SliceChannel(group_array, num_outputs=num_block, axis=0, squeeze_axis=1,name="lambda_group_split") 95 | 96 | num_block = 0 97 | data = mx.sym.Variable(name='data') 98 | body = mx.sym.Convolution(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2, 2), pad=(3, 3), 99 | no_bias=True, name="conv0", workspace=workspace) 100 | body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=eps, momentum=bn_mom, name='bn0') 101 | body = mx.sym.Activation(data=body, act_type='relu', name='relu0') 102 | body = mx.symbol.Pooling(data=body, kernel=(3, 3), stride=(2, 2), pad=(1, 1), pool_type='max') 103 | 104 | for i in range(num_stage): 105 | body = xresidual_unit(body, filter_list[i + 1], (1 if i == 0 else 2, 1 if i == 0 else 2), False, 106 | name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, num_group=num_group, 107 | bn_mom=bn_mom, workspace=workspace, memonger=memonger, 108 | lambda_group=group_split[num_block] if config.sss else None, 109 | lambda_block=block_split[num_block] if config.sss else None) 110 | num_block += 1 111 | for j in range(units[i] - 1): 112 | 
body = xresidual_unit(body, filter_list[i + 1], (1, 1), True, name='stage%d_unit%d' % (i + 1, j + 2), 113 | bottle_neck=bottle_neck, num_group=num_group, bn_mom=bn_mom, 114 | workspace=workspace, memonger=memonger, 115 | lambda_group=group_split[num_block] if config.sss else None, 116 | lambda_block=block_split[num_block] if config.sss else None) 117 | num_block += 1 118 | pool1 = mx.symbol.Pooling(data=body, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1') 119 | flat = mx.symbol.Flatten(data=pool1) 120 | fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=num_classes, name='fc1') 121 | cls = mx.symbol.SoftmaxOutput(data=fc1, name='softmax') 122 | return cls 123 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging, os 2 | import sys 3 | 4 | import config 5 | 6 | sys.path.insert(0, config.mxnet_path) 7 | import mxnet as mx 8 | from core.scheduler import multi_factor_scheduler 9 | from core.solver import Solver 10 | from core.metric import * 11 | from core.optimizer import * 12 | from data import * 13 | from symbol import * 14 | 15 | 16 | def main(config): 17 | # log file 18 | log_dir = "./log" 19 | if not os.path.exists(log_dir): 20 | os.mkdir(log_dir) 21 | logging.basicConfig(level=logging.INFO, 22 | format='%(asctime)s %(name)s %(levelname)s %(message)s', 23 | datefmt='%m-%d %H:%M', 24 | filename='{}/{}.log'.format(log_dir, config.model_prefix), 25 | filemode='a') 26 | console = logging.StreamHandler() 27 | console.setLevel(logging.INFO) 28 | formatter = logging.Formatter('%(name)s %(levelname)s %(message)s') 29 | console.setFormatter(formatter) 30 | logging.getLogger('').addHandler(console) 31 | # model folder 32 | model_dir = "./model" 33 | if not os.path.exists(model_dir): 34 | os.mkdir(model_dir) 35 | 36 | # set up environment 37 | devs = [mx.gpu(int(i)) for i in config.gpu_list] 38 | kv = mx.kvstore.create(config.kv_store) 39 | 40 | # set up iterator and symbol 41 | # iterator 42 | train, val, num_examples = imagenet_iterator(data_dir=config.data_dir, 43 | batch_size=config.batch_size, 44 | kv=kv) 45 | data_names = ('data',) 46 | label_names = ('softmax_label',) 47 | data_shapes = [('data', (config.batch_size, 3, 224, 224))] 48 | label_shapes = [('softmax_label', (config.batch_size,))] 49 | 50 | if config.network == 'resnet' or config.network == 'resnext': 51 | symbol = eval(config.network)(units=config.units, 52 | num_stage=config.num_stage, 53 | filter_list=config.filter_list, 54 | num_classes=config.num_classes, 55 | data_type=config.dataset, 56 | bottle_neck=config.bottle_neck) 57 | 58 | # train 59 | epoch_size = max(int(num_examples / config.batch_size / kv.num_workers), 1) 60 | if config.lr_step is not None: 61 | lr_scheduler = multi_factor_scheduler(config.begin_epoch, epoch_size, step=config.lr_step, 62 | factor=config.lr_factor) 63 | else: 64 | lr_scheduler = None 65 | 66 | optimizer_params = {'learning_rate': config.lr, 67 | 'lr_scheduler': lr_scheduler, 68 | 'wd': config.wd, 69 | 'momentum': config.momentum} 70 | optimizer = "nag" 71 | if config.sss: 72 | sss_optimizer_params = {'lambda_name': 'lambda', 73 | 'gamma': config.gamma} 74 | optimizer_params.update(sss_optimizer_params) 75 | optimizer = "apgnag" 76 | eval_metric = ['acc'] 77 | if config.dataset == "imagenet": 78 | eval_metric.append(mx.metric.create('top_k_accuracy', top_k=5)) 79 | 80 | solver = Solver(symbol=symbol, 81 | data_names=data_names, 82 | 
label_names=label_names, 83 | data_shapes=data_shapes, 84 | label_shapes=label_shapes, 85 | logger=logging, 86 | context=devs) 87 | epoch_end_callback = mx.callback.do_checkpoint("./model/" + config.model_prefix) 88 | batch_end_callback = mx.callback.Speedometer(config.batch_size, config.frequent) 89 | arg_params = None 90 | aux_params = None 91 | if config.retrain: 92 | _, arg_params, aux_params = mx.model.load_checkpoint("model/{}".format(config.model_load_prefix), 93 | config.model_load_epoch) 94 | 95 | if config.network.startswith('res'): 96 | initializer = mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2) 97 | 98 | solver.fit(train_data=train, 99 | eval_data=val, 100 | eval_metric=eval_metric, 101 | epoch_end_callback=epoch_end_callback, 102 | batch_end_callback=batch_end_callback, 103 | initializer=initializer, 104 | arg_params=arg_params, 105 | aux_params=aux_params, 106 | optimizer=optimizer, 107 | optimizer_params=optimizer_params, 108 | begin_epoch=config.begin_epoch, 109 | num_epoch=config.num_epoch, 110 | kvstore=kv) 111 | 112 | 113 | if __name__ == '__main__': 114 | main(config) 115 | --------------------------------------------------------------------------------
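As a quick sanity check that the SSS scaling factors are wired into the graph, the symbol can be built exactly as `train.py` does and its lambda arguments listed. A sketch, assuming it is run from the repository root with the settings in `config/cfgs.py` (`sss = True`):

```python
import sys
import config

sys.path.insert(0, config.mxnet_path)  # use the bundled, modified MXNet
from symbol import resnet

sym = resnet(units=config.units, num_stage=config.num_stage,
             filter_list=config.filter_list, num_classes=config.num_classes,
             data_type=config.dataset, bottle_neck=config.bottle_neck)
# with config.sss enabled, 'lambda_block' should appear among the arguments
print([name for name in sym.list_arguments() if name.startswith('lambda')])
```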