├── .gitignore ├── LICENSE.md ├── README.md ├── configs ├── imagenet_resnet101_prune100.json ├── imagenet_resnet101_prune40.json ├── imagenet_resnet101_prune50.json ├── imagenet_resnet101_prune50_group.json ├── imagenet_resnet101_prune55.json ├── imagenet_resnet101_prune75.json ├── imagenet_resnet50_prune100.json ├── imagenet_resnet50_prune56.json ├── imagenet_resnet50_prune72.json ├── imagenet_resnet50_prune81.json └── imagenet_resnet50_prune91.json ├── images └── resnet_result.png ├── layers └── gate_layer.py ├── logger.py ├── main.py ├── models ├── densenet_imagenet.py ├── lenet.py ├── preact_resnet.py ├── resnet.py └── vgg_bn.py ├── pruning_engine.py ├── requirements.txt └── utils ├── group_lasso_optimizer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | runs/ 2 | .idea/ 3 | *.tar.gz 4 | *.zip 5 | *.pkl 6 | *.pyc 7 | *.py~ 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | *.*[cod] 12 | .DS_Store 13 | ._* 14 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ## creative commons 2 | 3 | # Attribution-NonCommercial-ShareAlike 4.0 International 4 | 5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 6 | 7 | ### Using Creative Commons Public Licenses 8 | 9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 10 | 11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 12 | 13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. 
Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 14 | 15 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License 16 | 17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 18 | 19 | ### Section 1 – Definitions. 20 | 21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 22 | 23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 24 | 25 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License. 26 | 27 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 28 | 29 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 30 | 31 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 32 | 33 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike. 34 | 35 | h. 
__Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 36 | 37 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 38 | 39 | j. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 40 | 41 | k. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 42 | 43 | l. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 44 | 45 | m. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 46 | 47 | n. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 48 | 49 | ### Section 2 – Scope. 50 | 51 | a. ___License grant.___ 52 | 53 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 54 | 55 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 56 | 57 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 58 | 59 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 60 | 61 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 62 | 63 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 64 | 65 | 5. __Downstream recipients.__ 66 | 67 | A. 
__Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 68 | 69 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply. 70 | 71 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 72 | 73 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 74 | 75 | b. ___Other rights.___ 76 | 77 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 78 | 79 | 2. Patent and trademark rights are not licensed under this Public License. 80 | 81 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 82 | 83 | ### Section 3 – License Conditions. 84 | 85 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 86 | 87 | a. ___Attribution.___ 88 | 89 | 1. If You Share the Licensed Material (including in modified form), You must: 90 | 91 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 92 | 93 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 94 | 95 | ii. a copyright notice; 96 | 97 | iii. a notice that refers to this Public License; 98 | 99 | iv. a notice that refers to the disclaimer of warranties; 100 | 101 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 102 | 103 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 104 | 105 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 106 | 107 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 108 | 109 | 3. 
If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 110 | 111 | b. ___ShareAlike.___ 112 | 113 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 114 | 115 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License. 116 | 117 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 118 | 119 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 120 | 121 | ### Section 4 – Sui Generis Database Rights. 122 | 123 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 124 | 125 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 126 | 127 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 128 | 129 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 130 | 131 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 132 | 133 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 134 | 135 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 136 | 137 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 138 | 139 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 
140 | 141 | ### Section 6 – Term and Termination. 142 | 143 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 144 | 145 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 146 | 147 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 148 | 149 | 2. upon express reinstatement by the Licensor. 150 | 151 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 152 | 153 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 154 | 155 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 156 | 157 | ### Section 7 – Other Terms and Conditions. 158 | 159 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 160 | 161 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 162 | 163 | ### Section 8 – Interpretation. 164 | 165 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 166 | 167 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 168 | 169 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 170 | 171 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 172 | 173 | ``` 174 | Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. 
For the avoidance of doubt, this paragraph does not form part of the public licenses. 175 | 176 | Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org/). 177 | ``` 178 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://raw.githubusercontent.com/nvlabs/SPADE/master/LICENSE.md) 2 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 3 | 4 | # Importance Estimation for Neural Network Pruning 5 | 6 | This repo contains the scripts required to reproduce the results from the paper: 7 | 8 | Importance Estimation for Neural Network Pruning
9 | Pavlo Molchanov, Arun Mallya, Stephen Tyree, Iuri Frosio, Jan Kautz.
10 | In CVPR 2019. 11 | 12 | ![ResNet results](images/resnet_result.png "ResNet results") 13 | 14 | ### [License](https://raw.githubusercontent.com/nvlabs/Taylor_pruning/master/LICENSE.md) 15 | 16 | Copyright (C) 2019 NVIDIA Corporation. 17 | 18 | All rights reserved. 19 | Licensed under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (**Attribution-NonCommercial-ShareAlike 4.0 International**) 20 | 21 | The code is released for academic research use only. For commercial use, please contact [researchinquiries@nvidia.com](researchinquiries@nvidia.com). 22 | 23 | ## Installation 24 | 25 | Clone this repo. 26 | ```bash 27 | git clone https://github.com/NVlabs/Taylor_pruning.git 28 | cd Taylor_pruning/ 29 | ``` 30 | 31 | This code requires PyTorch 1.0 and Python 3+. Please install dependencies by 32 | ```bash 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | For the best reproducibility of results you will need an NVIDIA DGX-1 server with 8 V100 GPUs. During pruning and finetuning we use at most 4 GPUs. 37 | 38 | 39 | The code was tested with Python 3.6 and the following software versions: 40 | 41 | | Software | version | 42 | | ------------- |-------------| 43 | | cuDNN | v7500 | 44 | | PyTorch | 1.0.1.post2 | 45 | | CUDA | v10.0 | 46 | 47 | 48 | ## Preparation 49 | 50 | ### Dataset preparation 51 | 52 | Pruning examples use the ImageNet 1k dataset, which needs to be downloaded beforehand. 53 | Use the standard instructions to set up ImageNet 1k for PyTorch, e.g. from [here](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5#geting-the-data). 54 | 55 | ### Models preparation 56 | 57 | We use pretrained models provided with PyTorch; they need to be downloaded first: 58 | 59 | ``` 60 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth 61 | wget https://download.pytorch.org/models/resnet101-5d3b4d8f.pth 62 | mkdir ./models/pretrained/ 63 | mv {resnet50-19c8e357.pth,resnet101-5d3b4d8f.pth} ./models/pretrained/ 64 | ``` 65 | 66 | ## Pruning examples 67 | 68 | Results in Table 3 can be reproduced with the following script running on 4 V100 GPUs: 69 | 70 | ``` 71 | python main.py --name=runs/resnet50/resnet50_prune72 --dataset=Imagenet \ 72 | --lr=0.001 --lr-decay-every=10 --momentum=0.9 --epochs=25 --batch-size=256 \ 73 | --pruning=True --seed=0 --model=resnet50 --load_model=./models/pretrained/resnet50-19c8e357.pth \ 74 | --mgpu=True --group_wd_coeff=1e-8 --wd=0.0 --tensorboard=True --pruning-method=22 \ 75 | --data=/imagenet/ --no_grad_clip=True --pruning_config=./configs/imagenet_resnet50_prune72.json 76 | ``` 77 | 78 | Note: we run finetuning for 25 epochs and do not use weight decay. A better model can be obtained by training longer, increasing the learning rate, and using weight decay.
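The learning-rate schedule in the command above is a simple step decay controlled by `--lr`, `--lr-decay-every`, and `--lr-decay-scalar` (default 0.1). A rough sketch of the equivalent rule is below; the authoritative implementation is `adjust_learning_rate` in `utils/utils.py`, which may differ in details.

```python
# Rough sketch of the step-decay schedule implied by the flags above;
# adjust_learning_rate in utils/utils.py is the actual implementation.
def step_decay_lr(base_lr, epoch, decay_every=10, decay_scalar=0.1):
    # multiply the base learning rate by decay_scalar once per decay_every epochs
    return base_lr * (decay_scalar ** (epoch // decay_every))

# With --lr=0.001 --lr-decay-every=10 and 25 epochs:
# epochs 0-9 -> 1e-3, epochs 10-19 -> 1e-4, epochs 20-24 -> 1e-5
```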
79 | 80 | Example of a config file, `./configs/imagenet_resnet50_prune72.json`: 81 | ```json 82 | { 83 | "method": 22, 84 | "frequency" : 30, 85 | "prune_per_iteration" : 100, 86 | "maximum_pruning_iterations" : 32, 87 | "starting_neuron" : 0, 88 | "fixed_layer" : -1, 89 | "l2_normalization_per_layer": false, 90 | "rank_neurons_equally": false, 91 | "prune_neurons_max": 3200, 92 | "use_momentum": true, 93 | "pruning_silent": false, 94 | 95 | "pruning_threshold" : 100.0, 96 | "start_pruning_after_n_iterations" : 0, 97 | "push_group_down" : false, 98 | "do_iterative_pruning" : true, 99 | "fixed_criteria" : false, 100 | "seed" : 0, 101 | "pruning_momentum" : 0.9 102 | } 103 | 104 | ``` 105 | 106 | We provide config files for pruning to 56%, 72%, 81%, and 91% of the original ResNet-50 and to 40%, 50%, 55%, and 75% of ResNet-101. The percentage is the ratio of gates that remain active after pruning. 107 | 108 | ## Parameter description 109 | 110 | Pruning methods (different criteria) are encoded with integer ids as follows: 111 | 112 | | method id | name | description | comment | 113 | | ------------- |-------------| -------------| -----| 114 | | 22 | Taylor_gate | *Gate after BN* in Table 2, *Taylor FO* in Table 1, *Taylor-FO-BN* in Table 3 | **Best method**| 115 | | 0 | Taylor_weight| *Conv/linear weight with Taylor FO* in Table 2 and Table 1 | | 116 | | 1 | Random | Random|| 117 | | 2 | Weight norm | Weight magnitude / weight|| 118 | | 3 | Weight_abs | Not used|| 119 | | 6 | Taylor_output | *Taylor-output* as in [27]|| 120 | | 10 | OBD | Optimal Brain Damage|| 121 | | 11 | Taylor_gate_SO| Taylor SO|| 122 | | 23 | Taylor_gate_FG| uses per-example gradients to compute Taylor FO; *Taylor FO - FG* in Table 1, *Gate after BN - FG* in Table 2|| 123 | | 30 | BN_weight | *BN scale* in Table 2|| 124 | | 31 | BN_Taylor | *BN scale Taylor FO* in Table 2|| 125 | 126 | 127 | 128 | ### Citation 129 | 130 | If you use this code for your research, please cite our paper. 131 | ``` 132 | @inproceedings{molchanov2019taylor, 133 | title={Importance Estimation for Neural Network Pruning}, 134 | author={Molchanov, Pavlo and Mallya, Arun and Tyree, Stephen and Frosio, Iuri and Kautz, Jan}, 135 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 136 | year={2019} 137 | } 138 | ``` 139 | 140 | ## To come 141 | 142 | - Config files for DenseNet and VGG16 pruning on ImageNet 143 | - Examples of CIFAR experiments 144 | - Getting the shrunk model after pruning (ResNet-101 only) 145 | - Oracle estimates of true importance, code to compute correlation with Oracle 146 | 147 | ## Acknowledgments 148 | This code heavily reuses the PyTorch ImageNet training example provided [here](https://github.com/pytorch/examples/tree/master/imagenet).
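### Reading the pruning configs (sketch)

The pruning configs are plain JSON. Below is a minimal loading sketch; it is a simplified stand-in for the `PruningConfigReader` that `main.py` actually uses (via `read_config`), which may validate or transform fields differently.

```python
import json

# Simplified stand-in for PruningConfigReader.read_config in pruning_engine.py;
# shown only to illustrate the structure of the provided config files.
with open("./configs/imagenet_resnet50_prune72.json") as f:
    cfg = json.load(f)

# In all provided configs the totals are consistent:
# prune_per_iteration * maximum_pruning_iterations == prune_neurons_max
assert cfg["prune_per_iteration"] * cfg["maximum_pruning_iterations"] == cfg["prune_neurons_max"]

print("method id:", cfg["method"])  # 22 == Taylor_gate, the recommended criterion
```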
-------------------------------------------------------------------------------- /configs/imagenet_resnet101_prune100.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 0, 5 | "maximum_pruning_iterations" : 1, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | "l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 0, 11 | "use_momentum": true, 12 | "pruning_silent": false, 13 | 14 | "pruning_threshold" : 100.0, 15 | "start_pruning_after_n_iterations" : 0, 16 | "push_group_down" : false, 17 | "do_iterative_pruning" : true, 18 | "fixed_criteria" : false, 19 | "seed" : 0, 20 | "pruning_momentum" : 0.9 21 | } 22 | 23 | -------------------------------------------------------------------------------- /configs/imagenet_resnet101_prune40.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 100, 5 | "maximum_pruning_iterations" : 120, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | "l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 12000, 11 | "use_momentum": true, 12 | "pruning_silent": false, 13 | 14 | "pruning_threshold" : 100.0, 15 | "start_pruning_after_n_iterations" : 0, 16 | "push_group_down" : false, 17 | "do_iterative_pruning" : true, 18 | "fixed_criteria" : false, 19 | "seed" : 0, 20 | "pruning_momentum" : 0.9 21 | } 22 | 23 | -------------------------------------------------------------------------------- /configs/imagenet_resnet101_prune50.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 100, 5 | "maximum_pruning_iterations" : 100, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | "l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 10000, 11 | "use_momentum": true, 12 | "pruning_silent": false, 13 | 14 | "pruning_threshold" : 100.0, 15 | "start_pruning_after_n_iterations" : 0, 16 | "push_group_down" : false, 17 | "do_iterative_pruning" : true, 18 | "fixed_criteria" : false, 19 | "seed" : 0, 20 | "pruning_momentum" : 0.9 21 | } 22 | 23 | -------------------------------------------------------------------------------- /configs/imagenet_resnet101_prune50_group.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 100, 5 | "maximum_pruning_iterations" : 100, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | "l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 10000, 11 | "use_momentum": true, 12 | "pruning_silent": true, 13 | "group_size": 32, 14 | "flops_regularization": 1e-12, 15 | 16 | "pruning_threshold" : 100.0, 17 | "start_pruning_after_n_iterations" : 0, 18 | "do_iterative_pruning" : true, 19 | "fixed_criteria" : false, 20 | "seed" : 0, 21 | "pruning_momentum" : 0.9 22 | } 23 | 24 | -------------------------------------------------------------------------------- /configs/imagenet_resnet101_prune55.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 100, 5 | "maximum_pruning_iterations" : 90, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | 
"l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 9000, 11 | "use_momentum": true, 12 | "pruning_silent": false, 13 | 14 | "pruning_threshold" : 100.0, 15 | "start_pruning_after_n_iterations" : 0, 16 | "push_group_down" : false, 17 | "do_iterative_pruning" : true, 18 | "fixed_criteria" : false, 19 | "seed" : 0, 20 | "pruning_momentum" : 0.9 21 | } 22 | 23 | -------------------------------------------------------------------------------- /configs/imagenet_resnet101_prune75.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 100, 5 | "maximum_pruning_iterations" : 50, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | "l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 5000, 11 | "use_momentum": true, 12 | "pruning_silent": false, 13 | 14 | "pruning_threshold" : 100.0, 15 | "start_pruning_after_n_iterations" : 0, 16 | "push_group_down" : false, 17 | "do_iterative_pruning" : true, 18 | "fixed_criteria" : false, 19 | "seed" : 0, 20 | "pruning_momentum" : 0.9 21 | } 22 | 23 | -------------------------------------------------------------------------------- /configs/imagenet_resnet50_prune100.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 0, 5 | "maximum_pruning_iterations" : 1, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | "l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 0, 11 | "use_momentum": true, 12 | "pruning_silent": false, 13 | 14 | "pruning_threshold" : 100.0, 15 | "start_pruning_after_n_iterations" : 0, 16 | "push_group_down" : false, 17 | "do_iterative_pruning" : true, 18 | "fixed_criteria" : false, 19 | "seed" : 0, 20 | "pruning_momentum" : 0.9 21 | } 22 | 23 | -------------------------------------------------------------------------------- /configs/imagenet_resnet50_prune56.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 100, 5 | "maximum_pruning_iterations" : 55, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | "l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 5500, 11 | "use_momentum": true, 12 | "pruning_silent": false, 13 | 14 | "pruning_threshold" : 100.0, 15 | "start_pruning_after_n_iterations" : 0, 16 | "push_group_down" : false, 17 | "do_iterative_pruning" : true, 18 | "fixed_criteria" : false, 19 | "seed" : 0, 20 | "pruning_momentum" : 0.9 21 | } 22 | 23 | -------------------------------------------------------------------------------- /configs/imagenet_resnet50_prune72.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 100, 5 | "maximum_pruning_iterations" : 32, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | "l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 3200, 11 | "use_momentum": true, 12 | "pruning_silent": false, 13 | 14 | "pruning_threshold" : 100.0, 15 | "start_pruning_after_n_iterations" : 0, 16 | "push_group_down" : false, 17 | "do_iterative_pruning" : true, 18 | "fixed_criteria" : false, 19 | "seed" : 0, 20 | "pruning_momentum" : 0.9 21 | } 22 | 23 | 
-------------------------------------------------------------------------------- /configs/imagenet_resnet50_prune81.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 100, 5 | "maximum_pruning_iterations" : 22, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | "l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 2200, 11 | "use_momentum": true, 12 | "pruning_silent": false, 13 | 14 | "pruning_threshold" : 100.0, 15 | "start_pruning_after_n_iterations" : 0, 16 | "push_group_down" : false, 17 | "do_iterative_pruning" : true, 18 | "fixed_criteria" : false, 19 | "seed" : 0, 20 | "pruning_momentum" : 0.9 21 | } 22 | 23 | -------------------------------------------------------------------------------- /configs/imagenet_resnet50_prune91.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": 22, 3 | "frequency" : 30, 4 | "prune_per_iteration" : 100, 5 | "maximum_pruning_iterations" : 10, 6 | "starting_neuron" : 0, 7 | "fixed_layer" : -1, 8 | "l2_normalization_per_layer": false, 9 | "rank_neurons_equally": false, 10 | "prune_neurons_max": 1000, 11 | "use_momentum": true, 12 | "pruning_silent": false, 13 | 14 | "pruning_threshold" : 100.0, 15 | "start_pruning_after_n_iterations" : 0, 16 | "push_group_down" : false, 17 | "do_iterative_pruning" : true, 18 | "fixed_criteria" : false, 19 | "seed" : 0, 20 | "pruning_momentum" : 0.9 21 | } 22 | 23 | -------------------------------------------------------------------------------- /images/resnet_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Taylor_pruning/4b70120a8903494cdb565cd8ca97509543fd2862/images/resnet_result.png -------------------------------------------------------------------------------- /layers/gate_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | ''' 11 | Gating layer for pruning 12 | ''' 13 | 14 | class GateLayer(nn.Module): 15 | def __init__(self, input_features, output_features, size_mask): 16 | super(GateLayer, self).__init__() 17 | self.input_features = input_features 18 | self.output_features = output_features 19 | self.size_mask = size_mask 20 | self.weight = nn.Parameter(torch.ones(output_features)) 21 | 22 | # flag for a simpler way to find these layers during pruning 23 | self.do_not_update = True 24 | 25 | def forward(self, input): 26 | return input*self.weight.view(*self.size_mask) 27 | 28 | def extra_repr(self): 29 | return 'in_features={}, out_features={}'.format( 30 | self.input_features, self.output_features 31 | ) -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """ 5 | 6 | import sys 7 | 8 | #based on https://groups.google.com/forum/#!topic/comp.lang.python/0lqfVgjkc68 9 | 10 | class Logger(object): 11 | def __init__(self, filename="Default.log"): 12 | self.terminal = sys.stdout 13 | bufsize = 1 14 | self.log = open(filename, "w", buffering=bufsize) 15 | 16 | def delink(self): 17 | self.log.close() 18 | self.log = open('foo', "w") 19 | # self.write = self.writeTerminalOnly 20 | 21 | def writeTerminalOnly(self, message): 22 | self.terminal.write(message) 23 | 24 | def write(self, message): 25 | self.terminal.write(message) 26 | self.log.write(message) 27 | 28 | def flush(self): 29 | pass 30 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from __future__ import print_function 7 | import argparse 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | import torch.backends.cudnn as cudnn 12 | 13 | # ++++++for pruning 14 | import os, sys 15 | import time 16 | from utils.utils import save_checkpoint, adjust_learning_rate, AverageMeter, accuracy, load_model_pytorch, dynamic_network_change_local, get_conv_sizes, connect_gates_with_parameters_for_flops 17 | from tensorboardX import SummaryWriter 18 | 19 | from logger import Logger 20 | from models.lenet import LeNet 21 | from models.vgg_bn import slimmingvgg as vgg11_bn 22 | from models.preact_resnet import * 23 | from pruning_engine import pytorch_pruning, PruningConfigReader, prepare_pruning_list 24 | 25 | from utils.group_lasso_optimizer import group_lasso_decay 26 | 27 | import torch.distributed as dist 28 | import torch.utils.data 29 | import torch.utils.data.distributed 30 | import torch.nn.parallel 31 | 32 | import warnings 33 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 34 | 35 | import numpy as np 36 | #++++++++end 37 | 38 | # code is based on Pytorch example for imagenet 39 | # https://github.com/pytorch/examples/tree/master/imagenet 40 | 41 | 42 | def str2bool(v): 43 | # from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse/43357954#43357954 44 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 45 | return True 46 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 47 | return False 48 | else: 49 | raise argparse.ArgumentTypeError('Boolean value expected.') 50 | 51 | 52 | def train(args, model, device, train_loader, optimizer, epoch, criterion, train_writer=None, pruning_engine=None): 53 | """Train for one epoch on the training set also performs pruning""" 54 | global global_iteration 55 | batch_time = AverageMeter() 56 | losses = AverageMeter() 57 | top1 = AverageMeter() 58 | top5 = AverageMeter() 59 | loss_tracker = 0.0 60 | acc_tracker = 0.0 61 | loss_tracker_num = 0 62 | res_pruning = 0 63 | 64 | model.train() 65 | if args.fixed_network: 66 | # if network is fixed then we put it to eval mode 67 | model.eval() 68 | 69 | end = time.time() 70 | for batch_idx, (data, target) in enumerate(train_loader): 71 | data, target = data.to(device), target.to(device) 72 | # make sure that all gradients are zero 73 | for p in model.parameters(): 74 | if p.grad is not None: 75 | p.grad.detach_() 76 | p.grad.zero_() 77 | 78 | output = model(data) 79 | loss = 
criterion(output, target) 80 | 81 | if args.pruning: 82 | # useful for method 40 and 50 that calculate oracle 83 | pruning_engine.run_full_oracle(model, data, target, criterion, initial_loss=loss.item()) 84 | 85 | # measure accuracy and record loss 86 | losses.update(loss.item(), data.size(0)) 87 | 88 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 89 | top1.update(prec1.item(), data.size(0)) 90 | top5.update(prec5.item(), data.size(0)) 91 | acc_tracker += prec1.item() 92 | 93 | loss_tracker += loss.item() 94 | 95 | loss_tracker_num += 1 96 | 97 | if args.pruning: 98 | if pruning_engine.needs_hessian: 99 | pruning_engine.compute_hessian(loss) 100 | 101 | if not (args.pruning and args.pruning_method == 50): 102 | group_wd_optimizer.step() 103 | 104 | 105 | loss.backward() 106 | 107 | 108 | # add gradient clipping 109 | if not args.no_grad_clip: 110 | # found it useless for our experiments 111 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 112 | 113 | # step_after will calculate flops and number of parameters left 114 | # needs to be launched before the main optimizer, 115 | # otherwise weight decay will make numbers not correct 116 | if not (args.pruning and args.pruning_method == 50): 117 | if batch_idx % args.log_interval == 0: 118 | group_wd_optimizer.step_after() 119 | 120 | optimizer.step() 121 | 122 | batch_time.update(time.time() - end) 123 | end = time.time() 124 | 125 | global_iteration = global_iteration + 1 126 | 127 | if batch_idx % args.log_interval == 0: 128 | print('Epoch: [{0}][{1}/{2}]\t' 129 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 130 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 131 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 132 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 133 | epoch, batch_idx, len(train_loader), batch_time=batch_time, 134 | loss=losses, top1=top1, top5=top5)) 135 | 136 | if train_writer is not None: 137 | train_writer.add_scalar('train_loss_ave', losses.avg, global_iteration) 138 | 139 | if args.pruning: 140 | # pruning_engine.update_flops(stats=group_wd_optimizer.per_layer_per_neuron_stats) 141 | pruning_engine.do_step(loss=loss.item(), optimizer=optimizer) 142 | 143 | if args.model == "resnet20" or args.model == "resnet101" or args.dataset == "Imagenet": 144 | if (pruning_engine.maximum_pruning_iterations == pruning_engine.pruning_iterations_done) and pruning_engine.set_moment_zero: 145 | for group in optimizer.param_groups: 146 | for p in group['params']: 147 | if p.grad is None: 148 | continue 149 | param_state = optimizer.state[p] 150 | if 'momentum_buffer' in param_state: 151 | del param_state['momentum_buffer'] 152 | 153 | pruning_engine.set_moment_zero = False 154 | 155 | # if not (args.pruning and args.pruning_method == 50): 156 | # if batch_idx % args.log_interval == 0: 157 | # group_wd_optimizer.step_after() 158 | 159 | if args.tensorboard and (batch_idx % args.log_interval == 0): 160 | 161 | neurons_left = int(group_wd_optimizer.get_number_neurons(print_output=args.get_flops)) 162 | flops = int(group_wd_optimizer.get_number_flops(print_output=args.get_flops)) 163 | 164 | train_writer.add_scalar('neurons_optimizer_left', neurons_left, global_iteration) 165 | train_writer.add_scalar('neurons_optimizer_flops_left', flops, global_iteration) 166 | else: 167 | if args.get_flops: 168 | neurons_left = int(group_wd_optimizer.get_number_neurons(print_output=args.get_flops)) 169 | flops = int(group_wd_optimizer.get_number_flops(print_output=args.get_flops)) 170 | 171 | if args.limit_training_batches != -1: 
172 | if args.limit_training_batches < batch_idx: 173 | # return from training step, unsafe and was not tested correctly 174 | print("return from training step, unsafe and was not tested correctly") 175 | return 0 176 | 177 | # print number of parameters left: 178 | if args.tensorboard: 179 | print('neurons_optimizer_left', neurons_left, global_iteration) 180 | 181 | 182 | def validate(args, test_loader, model, device, criterion, epoch, train_writer=None): 183 | """Perform validation on the validation set""" 184 | batch_time = AverageMeter() 185 | losses = AverageMeter() 186 | top1 = AverageMeter() 187 | top5 = AverageMeter() 188 | 189 | # switch to evaluate mode 190 | model.eval() 191 | 192 | end = time.time() 193 | with torch.no_grad(): 194 | for data_test in test_loader: 195 | data, target = data_test 196 | 197 | data = data.to(device) 198 | 199 | output = model(data) 200 | 201 | if args.get_inference_time: 202 | iterations_get_inference_time = 100 203 | start_get_inference_time = time.time() 204 | for it in range(iterations_get_inference_time): 205 | output = model(data) 206 | end_get_inference_time = time.time() 207 | print("time taken for %d iterations, per-iteration is: "%(iterations_get_inference_time), (end_get_inference_time - start_get_inference_time)*1000.0/float(iterations_get_inference_time), "ms") 208 | 209 | target = target.to(device) 210 | loss = criterion(output, target) 211 | 212 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 213 | losses.update(loss.item(), data.size(0)) 214 | top1.update(prec1.item(), data.size(0)) 215 | top5.update(prec5.item(), data.size(0)) 216 | 217 | # measure elapsed time 218 | batch_time.update(time.time() - end) 219 | end = time.time() 220 | 221 | print(' * Prec@1 {top1.avg:.3f}, Prec@5 {top5.avg:.3f}, Time {batch_time.sum:.5f}, Loss: {losses.avg:.3f}'.format(top1=top1, top5=top5, batch_time=batch_time, losses=losses)) 222 | # log to TensorBoard 223 | if train_writer is not None: 224 | train_writer.add_scalar('val_loss', losses.avg, epoch) 225 | train_writer.add_scalar('val_acc', top1.avg, epoch) 226 | 227 | return top1.avg, losses.avg 228 | 229 | 230 | 231 | def main(): 232 | # Training settings 233 | parser = argparse.ArgumentParser(description='PyTorch pruning example (MNIST, CIFAR10, ImageNet)') 234 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 235 | help='input batch size for training (default: 64)') 236 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 237 | help='input batch size for testing (default: 1000)') 238 | parser.add_argument('--world_size', type=int, default=1, 239 | help='number of GPUs to use') 240 | 241 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 242 | help='number of epochs to train (default: 10)') 243 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 244 | help='learning rate (default: 0.01)') 245 | parser.add_argument('--wd', type=float, default=1e-4, 246 | help='weight decay (default: 1e-4)') 247 | parser.add_argument('--lr-decay-every', type=int, default=100, 248 | help='learning rate decay by 10 every X epochs') 249 | parser.add_argument('--lr-decay-scalar', type=float, default=0.1, 250 | help='multiplicative factor for learning rate decay (default: 0.1)') 251 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 252 | help='SGD momentum (default: 0.5)') 253 | parser.add_argument('--no-cuda', action='store_true', default=False, 254 | help='disables CUDA training') 255 | parser.add_argument('--seed', type=int, default=1, metavar='S', 256 | help='random seed (default: 
1)') 257 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 258 | help='how many batches to wait before logging training status') 259 | 260 | parser.add_argument('--run_test', default=False, type=str2bool, nargs='?', 261 | help='run test only') 262 | 263 | parser.add_argument('--limit_training_batches', type=int, default=-1, 264 | help='how many batches to do per training, -1 means as many as possible') 265 | 266 | parser.add_argument('--no_grad_clip', default=False, type=str2bool, nargs='?', 267 | help='turn off gradient clipping') 268 | 269 | parser.add_argument('--get_flops', default=False, type=str2bool, nargs='?', 270 | help='add hooks to compute flops') 271 | 272 | parser.add_argument('--get_inference_time', default=False, type=str2bool, nargs='?', 273 | help='runs valid multiple times and reports the result') 274 | 275 | parser.add_argument('--mgpu', default=False, type=str2bool, nargs='?', 276 | help='use data parallelization via multiple GPUs') 277 | 278 | parser.add_argument('--dataset', default="MNIST", type=str, 279 | help='dataset for experiment, choices: MNIST, CIFAR10, Imagenet', choices=["MNIST", "CIFAR10", "Imagenet"]) 280 | 281 | parser.add_argument('--data', metavar='DIR', default='/imagenet', help='path to imagenet dataset') 282 | 283 | parser.add_argument('--model', default="lenet3", type=str, 284 | help='model selection, see the full list in choices', 285 | choices=["lenet3", "vgg", "mobilenetv2", "resnet18", "resnet152", "resnet50", "resnet50_noskip", 286 | "resnet20", "resnet34", "resnet101", "resnet101_noskip", "densenet201_imagenet", 287 | 'densenet121_imagenet']) 288 | 289 | parser.add_argument('--tensorboard', type=str2bool, nargs='?', 290 | help='Log progress to TensorBoard') 291 | 292 | parser.add_argument('--save_models', default=True, type=str2bool, nargs='?', 293 | help='if True, models will be saved to the local folder') 294 | 295 | 296 | # ============================PRUNING added 297 | parser.add_argument('--pruning_config', default=None, type=str, 298 | help='path to pruning configuration file, will overwrite all pruning parameters in arguments') 299 | 300 | parser.add_argument('--group_wd_coeff', type=float, default=0.0, 301 | help='group weight decay') 302 | parser.add_argument('--name', default='test', type=str, 303 | help='experiment name (folder) to store logs') 304 | 305 | parser.add_argument('--augment', default=False, type=str2bool, nargs='?', 306 | help='enable or disable augmentation of the training dataset, CIFAR only, default False') 307 | 308 | parser.add_argument('--load_model', default='', type=str, 309 | help='path to model weights') 310 | 311 | parser.add_argument('--pruning', default=False, type=str2bool, nargs='?', 312 | help='enable or disable pruning, default False') 313 | 314 | parser.add_argument('--pruning-threshold', '--pt', default=100.0, type=float, 315 | help='Max error perc on validation set while pruning (default: 100.0 means always prune)') 316 | 317 | parser.add_argument('--pruning-momentum', default=0.0, type=float, 318 | help='Use momentum on criteria between pruning iterations, default 0.0 means no momentum') 319 | 320 | parser.add_argument('--pruning-step', default=15, type=int, 321 | help='How often to check loss and do pruning step') 322 | 323 | parser.add_argument('--prune_per_iteration', default=10, type=int, 324 | help='How many neurons to remove at each iteration') 325 | 326 | parser.add_argument('--fixed_layer', default=-1, type=int, 327 | help='Prune only a given layer with index, use -1 to prune all') 
328 | 329 | parser.add_argument('--start_pruning_after_n_iterations', default=0, type=int, 330 | help='from which iteration to start pruning') 331 | 332 | parser.add_argument('--maximum_pruning_iterations', default=1e8, type=int, 333 | help='maximum pruning iterations') 334 | 335 | parser.add_argument('--starting_neuron', default=0, type=int, 336 | help='starting position for oracle pruning') 337 | 338 | parser.add_argument('--prune_neurons_max', default=-1, type=int, 339 | help='prune_neurons_max') 340 | 341 | parser.add_argument('--pruning-method', default=0, type=int, 342 | help='pruning method to be used, see readme.md') 343 | 344 | parser.add_argument('--pruning_fixed_criteria', default=False, type=str2bool, nargs='?', 345 | help='enable or not criteria reevaluation, def False') 346 | 347 | parser.add_argument('--fixed_network', default=False, type=str2bool, nargs='?', 348 | help='fix network for oracle or criteria computation') 349 | 350 | parser.add_argument('--zero_lr_for_epochs', default=-1, type=int, 351 | help='Learning rate will be set to 0 for given number of updates') 352 | 353 | parser.add_argument('--dynamic_network', default=False, type=str2bool, nargs='?', 354 | help='Creates a new network graph from pruned model, works with ResNet-101 only') 355 | 356 | parser.add_argument('--use_test_as_train', default=False, type=str2bool, nargs='?', 357 | help='use testing dataset instead of training') 358 | 359 | parser.add_argument('--pruning_mask_from', default='', type=str, 360 | help='path to mask file precomputed') 361 | 362 | parser.add_argument('--compute_flops', default=True, type=str2bool, nargs='?', 363 | help='if True, will run dummy inference of batch 1 before training to get conv sizes') 364 | 365 | 366 | 367 | # ============================END pruning added 368 | 369 | best_prec1 = 0 370 | global global_iteration 371 | global group_wd_optimizer 372 | global_iteration = 0 373 | 374 | args = parser.parse_args() 375 | use_cuda = not args.no_cuda and torch.cuda.is_available() 376 | 377 | torch.manual_seed(args.seed) 378 | 379 | args.distributed = args.world_size > 1 380 | if args.distributed: 381 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 382 | world_size=args.world_size, rank=0) 383 | 384 | device = torch.device("cuda" if use_cuda else "cpu") 385 | 386 | if args.model == "lenet3": 387 | model = LeNet(dataset=args.dataset) 388 | elif args.model == "vgg": 389 | model = vgg11_bn(pretrained=True) 390 | elif args.model == "resnet18": 391 | model = PreActResNet18() 392 | elif (args.model == "resnet50") or (args.model == "resnet50_noskip"): 393 | if args.dataset == "CIFAR10": 394 | model = PreActResNet50(dataset=args.dataset) 395 | else: 396 | from models.resnet import resnet50 397 | skip_gate = True 398 | if "noskip" in args.model: 399 | skip_gate = False 400 | 401 | if args.pruning_method not in [22, 40]: 402 | skip_gate = False 403 | model = resnet50(skip_gate=skip_gate) 404 | elif args.model == "resnet34": 405 | if not (args.dataset == "CIFAR10"): 406 | from models.resnet import resnet34 407 | model = resnet34() 408 | elif "resnet101" in args.model: 409 | if not (args.dataset == "CIFAR10"): 410 | from models.resnet import resnet101 411 | if args.dataset == "Imagenet": 412 | classes = 1000 413 | 414 | if "noskip" in args.model: 415 | model = resnet101(num_classes=classes, skip_gate=False) 416 | else: 417 | model = resnet101(num_classes=classes) 418 | 419 | elif args.model == "resnet20": 420 | if args.dataset == "CIFAR10": 421 | 
raise NotImplementedError("resnet20 is not implemented in the current project") 422 | # from models.resnet_cifar import resnet20 423 | # model = resnet20() 424 | elif args.model == "resnet152": 425 | model = PreActResNet152() 426 | elif args.model == "densenet201_imagenet": 427 | from models.densenet_imagenet import DenseNet201 428 | model = DenseNet201(gate_types=['output_bn'], pretrained=True) 429 | elif args.model == "densenet121_imagenet": 430 | from models.densenet_imagenet import DenseNet121 431 | model = DenseNet121(gate_types=['output_bn'], pretrained=True) 432 | else: 433 | print(args.model, "model is not supported") 434 | 435 | # dataset loading section 436 | if args.dataset == "MNIST": 437 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 438 | train_loader = torch.utils.data.DataLoader( 439 | datasets.MNIST('../data', train=True, download=True, 440 | transform=transforms.Compose([ 441 | transforms.ToTensor(), 442 | transforms.Normalize((0.1307,), (0.3081,)) 443 | ])), 444 | batch_size=args.batch_size, shuffle=True, **kwargs) 445 | test_loader = torch.utils.data.DataLoader( 446 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 447 | transforms.ToTensor(), 448 | transforms.Normalize((0.1307,), (0.3081,)) 449 | ])), 450 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 451 | 452 | elif args.dataset == "CIFAR10": 453 | # Data loading code 454 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 455 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 456 | 457 | if args.augment: 458 | transform_train = transforms.Compose([ 459 | transforms.RandomCrop(32, padding=4), 460 | transforms.RandomHorizontalFlip(), 461 | transforms.ToTensor(), 462 | normalize, 463 | ]) 464 | else: 465 | transform_train = transforms.Compose([ 466 | transforms.ToTensor(), 467 | normalize, 468 | ]) 469 | 470 | transform_test = transforms.Compose([ 471 | transforms.ToTensor(), 472 | normalize 473 | ]) 474 | 475 | kwargs = {'num_workers': 8, 'pin_memory': True} 476 | train_loader = torch.utils.data.DataLoader( 477 | datasets.CIFAR10('../data', train=True, download=True, 478 | transform=transform_train), 479 | batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs) 480 | 481 | test_loader = torch.utils.data.DataLoader( 482 | datasets.CIFAR10('../data', train=False, transform=transform_test), 483 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 484 | 485 | elif args.dataset == "Imagenet": 486 | traindir = os.path.join(args.data, 'train') 487 | valdir = os.path.join(args.data, 'val') 488 | 489 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 490 | std=[0.229, 0.224, 0.225]) 491 | 492 | train_dataset = datasets.ImageFolder( 493 | traindir, 494 | transforms.Compose([ 495 | transforms.RandomResizedCrop(224), 496 | transforms.RandomHorizontalFlip(), 497 | transforms.ToTensor(), 498 | normalize, 499 | ])) 500 | 501 | if args.distributed: 502 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 503 | else: 504 | train_sampler = None 505 | 506 | kwargs = {'num_workers': 16} 507 | 508 | train_loader = torch.utils.data.DataLoader( 509 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 510 | sampler=train_sampler, pin_memory=True, **kwargs) 511 | 512 | if args.use_test_as_train: 513 | train_loader = torch.utils.data.DataLoader( 514 | datasets.ImageFolder(valdir, transforms.Compose([ 515 | transforms.Resize(256), 516 | transforms.CenterCrop(224), 517 | 
transforms.ToTensor(), 518 | normalize, 519 | ])), 520 | batch_size=args.batch_size, shuffle=(train_sampler is None), **kwargs) 521 | 522 | 523 | test_loader = torch.utils.data.DataLoader( 524 | datasets.ImageFolder(valdir, transforms.Compose([ 525 | transforms.Resize(256), 526 | transforms.CenterCrop(224), 527 | transforms.ToTensor(), 528 | normalize, 529 | ])), 530 | batch_size=args.batch_size, shuffle=False, pin_memory=True, **kwargs) 531 | 532 | ####end dataset preparation 533 | 534 | if args.dynamic_network: 535 | # attempts to load a pruned model and modify it by removing pruned channels 536 | # works for resnet101 only 537 | if (len(args.load_model) > 0) and (args.dynamic_network): 538 | if os.path.isfile(args.load_model): 539 | load_model_pytorch(model, args.load_model, args.model) 540 | 541 | else: 542 | print("=> no checkpoint found at '{}'".format(args.load_model)) 543 | exit() 544 | 545 | dynamic_network_change_local(model) 546 | 547 | # save the model 548 | log_save_folder = "%s"%args.name 549 | if not os.path.exists(log_save_folder): 550 | os.makedirs(log_save_folder) 551 | 552 | if not os.path.exists("%s/models" % (log_save_folder)): 553 | os.makedirs("%s/models" % (log_save_folder)) 554 | 555 | model_save_path = "%s/models/pruned.weights"%(log_save_folder) 556 | model_state_dict = model.state_dict() 557 | if args.save_models: 558 | save_checkpoint({ 559 | 'state_dict': model_state_dict 560 | }, False, filename = model_save_path) 561 | 562 | print("model is defined") 563 | 564 | # aux function to get size of feature maps 565 | # First it adds hooks for each conv layer 566 | # Then runs inference with 1 image 567 | output_sizes = get_conv_sizes(args, model) 568 | 569 | if use_cuda and not args.mgpu: 570 | model = model.to(device) 571 | elif args.distributed: 572 | model.cuda() 573 | print("\n\n WARNING: distributed pruning was not verified and might not work correctly") 574 | model = torch.nn.parallel.DistributedDataParallel(model) 575 | elif args.mgpu: 576 | model = torch.nn.DataParallel(model).cuda() 577 | else: 578 | model = model.to(device) 579 | 580 | print("model is set to device: use_cuda {}, args.mgpu {}, args.distributed {}".format(use_cuda, args.mgpu, args.distributed)) 581 | 582 | weight_decay = args.wd 583 | if args.fixed_network: 584 | weight_decay = 0.0 585 | 586 | # remove updates from gate layers, because we want them to be 0 or 1 constantly 587 | if 1: 588 | parameters_for_update = [] 589 | parameters_for_update_named = [] 590 | for name, m in model.named_parameters(): 591 | if "gate" not in name: 592 | parameters_for_update.append(m) 593 | parameters_for_update_named.append((name, m)) 594 | else: 595 | print("skipping parameter", name, "shape:", m.shape) 596 | 597 | total_size_params = sum([np.prod(par.shape) for par in parameters_for_update]) 598 | print("Total number of parameters, w/o usage of bn consts: ", total_size_params) 599 | 600 | optimizer = optim.SGD(parameters_for_update, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay) 601 | 602 | if 1: 603 | # helping optimizer to implement group lasso (with very small weight that doesn't affect training) 604 | # will be used to calculate number of remaining flops and parameters in the network 605 | group_wd_optimizer = group_lasso_decay(parameters_for_update, group_lasso_weight=args.group_wd_coeff, named_parameters=parameters_for_update_named, output_sizes=output_sizes) 606 | 607 | cudnn.benchmark = True 608 | 609 | # define objective 610 | criterion = nn.CrossEntropyLoss() 611 | 612 | 
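# ----------------------------------------------------------------------
# Aside: the "gate" parameters skipped above belong to GateLayer modules
# that the models insert after conv/BN blocks. The real implementation is
# in layers/gate_layer.py (not reproduced in this listing); a minimal
# sketch of the idea, with assumed names, is:
#
#     class GateLayer(nn.Module):
#         def __init__(self, input_features, output_features, size_mask):
#             super(GateLayer, self).__init__()
#             # one multiplier per channel, initialized to 1 (identity)
#             self.weight = nn.Parameter(torch.ones(output_features))
#             self.size_mask = size_mask  # e.g. [1, -1, 1, 1] for conv maps
#
#         def forward(self, input):
#             # broadcast multiply; zeroing weight[i] removes channel i
#             return input * self.weight.view(*self.size_mask)
#
# Since the gates are excluded from the SGD update above, they stay at
# exactly 1 until the pruning engine zeroes selected entries.
# ----------------------------------------------------------------------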
###=======================added for pruning 613 | # logging part 614 | log_save_folder = "%s"%args.name 615 | if not os.path.exists(log_save_folder): 616 | os.makedirs(log_save_folder) 617 | 618 | if not os.path.exists("%s/models" % (log_save_folder)): 619 | os.makedirs("%s/models" % (log_save_folder)) 620 | 621 | train_writer = None 622 | if args.tensorboard: 623 | try: 624 | # tensorboardX v1.6 625 | train_writer = SummaryWriter(log_dir="%s"%(log_save_folder)) 626 | except: 627 | # tensorboardX v1.7 628 | train_writer = SummaryWriter(logdir="%s"%(log_save_folder)) 629 | 630 | time_point = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) 631 | textfile = "%s/log_%s.txt" % (log_save_folder, time_point) 632 | stdout = Logger(textfile) 633 | sys.stdout = stdout 634 | print(" ".join(sys.argv)) 635 | 636 | # initializing parameters for pruning 637 | # we can add weights of different layers or we can add gates (multiplies output with 1, useful only for gradient computation) 638 | pruning_engine = None 639 | if args.pruning: 640 | pruning_settings = dict() 641 | if not (args.pruning_config is None): 642 | pruning_settings_reader = PruningConfigReader() 643 | pruning_settings_reader.read_config(args.pruning_config) 644 | pruning_settings = pruning_settings_reader.get_parameters() 645 | 646 | # overwrite parameters from config file with those from command line 647 | # needs manual entry here 648 | # user_specified = [key for key in vars(default_args).keys() if not (vars(default_args)[key]==vars(args)[key])] 649 | # argv_of_interest = ['pruning_threshold', 'pruning-momentum', 'pruning_step', 'prune_per_iteration', 650 | # 'fixed_layer', 'start_pruning_after_n_iterations', 'maximum_pruning_iterations', 651 | # 'starting_neuron', 'prune_neurons_max', 'pruning_method'] 652 | 653 | has_attribute = lambda x: any([x in a for a in sys.argv]) 654 | 655 | if has_attribute('pruning-momentum'): 656 | pruning_settings['pruning_momentum'] = vars(args)['pruning_momentum'] 657 | if has_attribute('pruning-method'): 658 | pruning_settings['method'] = vars(args)['pruning_method'] 659 | 660 | pruning_parameters_list = prepare_pruning_list(pruning_settings, model, model_name=args.model, 661 | pruning_mask_from=args.pruning_mask_from, name=args.name) 662 | print("Total pruning layers:", len(pruning_parameters_list)) 663 | 664 | folder_to_write = "%s"%log_save_folder+"/" 665 | log_folder = folder_to_write 666 | 667 | pruning_engine = pytorch_pruning(pruning_parameters_list, pruning_settings=pruning_settings, log_folder=log_folder) 668 | 669 | pruning_engine.connect_tensorboard(train_writer) 670 | pruning_engine.dataset = args.dataset 671 | pruning_engine.model = args.model 672 | pruning_engine.pruning_mask_from = args.pruning_mask_from 673 | pruning_engine.load_mask() 674 | gates_to_params = connect_gates_with_parameters_for_flops(args.model, parameters_for_update_named) 675 | pruning_engine.gates_to_params = gates_to_params 676 | 677 | ###=======================end for pruning 678 | # loading model file 679 | if (len(args.load_model) > 0) and (not args.dynamic_network): 680 | if os.path.isfile(args.load_model): 681 | load_model_pytorch(model, args.load_model, args.model) 682 | else: 683 | print("=> no checkpoint found at '{}'".format(args.load_model)) 684 | exit() 685 | 686 | if args.tensorboard and 0: 687 | if args.dataset == "CIFAR10": 688 | dummy_input = torch.rand(1, 3, 32, 32).to(device) 689 | elif args.dataset == "Imagenet": 690 | dummy_input = torch.rand(1, 3, 224, 224).to(device) 691 | 692 | 
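# NOTE: the "and 0" in the condition above keeps this graph export
# permanently disabled; drop it to log the model graph to TensorBoard.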
train_writer.add_graph(model, dummy_input) 693 | 694 | for epoch in range(1, args.epochs + 1): 695 | if args.distributed: 696 | train_sampler.set_epoch(epoch) 697 | adjust_learning_rate(args, optimizer, epoch, args.zero_lr_for_epochs, train_writer) 698 | 699 | if not args.run_test and not args.get_inference_time: 700 | train(args, model, device, train_loader, optimizer, epoch, criterion, train_writer=train_writer, pruning_engine=pruning_engine) 701 | 702 | if args.pruning: 703 | # skip validation error calculation and model saving 704 | if pruning_engine.method == 50: continue 705 | 706 | # evaluate on validation set 707 | prec1, _ = validate(args, test_loader, model, device, criterion, epoch, train_writer=train_writer) 708 | 709 | # remember best prec@1 and save checkpoint 710 | is_best = prec1 > best_prec1 711 | best_prec1 = max(prec1, best_prec1) 712 | model_save_path = "%s/models/checkpoint.weights"%(log_save_folder) 713 | model_state_dict = model.state_dict() 714 | if args.save_models: 715 | save_checkpoint({ 716 | 'epoch': epoch + 1, 717 | 'state_dict': model_state_dict, 718 | 'best_prec1': best_prec1, 719 | }, is_best, filename=model_save_path) 720 | 721 | 722 | if __name__ == '__main__': 723 | main() 724 | -------------------------------------------------------------------------------- /models/densenet_imagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | # based on https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py 7 | 8 | import re 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.utils.model_zoo as model_zoo 13 | from collections import OrderedDict 14 | from layers.gate_layer import GateLayer 15 | 16 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 17 | 18 | 19 | model_urls = { 20 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 21 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 22 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 23 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 24 | } 25 | 26 | 27 | 28 | def DenseNet121(pretrained=False, **kwargs): 29 | r"""Densenet-121 model from 30 | `"Densely Connected Convolutional Networks" `_ 31 | 32 | Args: 33 | pretrained (bool): If True, returns a model pre-trained on ImageNet 34 | """ 35 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 36 | **kwargs) 37 | if pretrained: 38 | # '.'s are no longer allowed in module names, but previous _DenseLayer 39 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 40 | # They are also in the checkpoints in model_urls. This pattern is used 41 | # to find such keys. 
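# Illustrative example of the key remapping performed below:
#   'features.denseblock1.denselayer1.norm.1.weight'
#   -> 'features.denseblock1.denselayer1.norm1.weight'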
42 | pattern = re.compile( 43 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 44 | state_dict = model_zoo.load_url(model_urls['densenet121']) 45 | for key in list(state_dict.keys()): 46 | res = pattern.match(key) 47 | if res: 48 | new_key = res.group(1) + res.group(2) 49 | state_dict[new_key] = state_dict[key] 50 | del state_dict[key] 51 | model.load_state_dict(state_dict, strict=False) 52 | return model 53 | 54 | 55 | def DenseNet169(pretrained=False, **kwargs): 56 | r"""Densenet-169 model from 57 | `"Densely Connected Convolutional Networks" `_ 58 | 59 | Args: 60 | pretrained (bool): If True, returns a model pre-trained on ImageNet 61 | """ 62 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 63 | **kwargs) 64 | if pretrained: 65 | # '.'s are no longer allowed in module names, but previous _DenseLayer 66 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 67 | # They are also in the checkpoints in model_urls. This pattern is used 68 | # to find such keys. 69 | pattern = re.compile( 70 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 71 | state_dict = model_zoo.load_url(model_urls['densenet169']) 72 | for key in list(state_dict.keys()): 73 | res = pattern.match(key) 74 | if res: 75 | new_key = res.group(1) + res.group(2) 76 | state_dict[new_key] = state_dict[key] 77 | del state_dict[key] 78 | model.load_state_dict(state_dict, strict=False) 79 | return model 80 | 81 | 82 | def DenseNet201(pretrained=False, **kwargs): 83 | r"""Densenet-201 model from 84 | `"Densely Connected Convolutional Networks" `_ 85 | 86 | Args: 87 | pretrained (bool): If True, returns a model pre-trained on ImageNet 88 | """ 89 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 90 | **kwargs) 91 | if pretrained: 92 | # '.'s are no longer allowed in module names, but previous _DenseLayer 93 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 94 | # They are also in the checkpoints in model_urls. This pattern is used 95 | # to find such keys. 96 | pattern = re.compile( 97 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 98 | state_dict = model_zoo.load_url(model_urls['densenet201'], model_dir='models/pretrained') 99 | for key in list(state_dict.keys()): 100 | res = pattern.match(key) 101 | if res: 102 | new_key = res.group(1) + res.group(2) 103 | state_dict[new_key] = state_dict[key] 104 | del state_dict[key] 105 | model.load_state_dict(state_dict, strict=False) 106 | return model 107 | 108 | 109 | def densenet161(pretrained=False, **kwargs): 110 | r"""Densenet-161 model from 111 | `"Densely Connected Convolutional Networks" `_ 112 | 113 | Args: 114 | pretrained (bool): If True, returns a model pre-trained on ImageNet 115 | """ 116 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 117 | **kwargs) 118 | if pretrained: 119 | # '.'s are no longer allowed in module names, but previous _DenseLayer 120 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 121 | # They are also in the checkpoints in model_urls. This pattern is used 122 | # to find such keys. 
123 | pattern = re.compile( 124 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 125 | state_dict = model_zoo.load_url(model_urls['densenet161']) 126 | for key in list(state_dict.keys()): 127 | res = pattern.match(key) 128 | if res: 129 | new_key = res.group(1) + res.group(2) 130 | state_dict[new_key] = state_dict[key] 131 | del state_dict[key] 132 | model.load_state_dict(state_dict, strict=False) 133 | return model 134 | 135 | 136 | class _DenseLayer(nn.Sequential): 137 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, gate_types): 138 | super(_DenseLayer, self).__init__() 139 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 140 | self.add_module('relu1', nn.ReLU(inplace=True)), 141 | if 'input' in gate_types: 142 | self.add_module('gate1): (input', GateLayer(num_input_features,num_input_features,[1, -1, 1, 1])) 143 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 144 | growth_rate, kernel_size=1, stride=1, bias=False)), 145 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 146 | if 'output_bn' in gate_types: 147 | self.add_module('gate2): (output_bn', GateLayer(bn_size * growth_rate,bn_size * growth_rate,[1, -1, 1, 1])) 148 | self.add_module('relu2', nn.ReLU(inplace=True)), 149 | 150 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 151 | kernel_size=3, stride=1, padding=1, bias=False)), 152 | if 'output_conv' in gate_types: 153 | self.add_module('gate3): (output_conv', GateLayer(growth_rate,growth_rate,[1, -1, 1, 1])) 154 | self.drop_rate = drop_rate 155 | 156 | def forward(self, x): 157 | new_features = super(_DenseLayer, self).forward(x) 158 | if self.drop_rate > 0: 159 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 160 | return torch.cat([x, new_features], 1) 161 | 162 | 163 | class _DenseBlock(nn.Sequential): 164 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, gate_types): 165 | super(_DenseBlock, self).__init__() 166 | for i in range(num_layers): 167 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, gate_types) 168 | self.add_module('denselayer%d' % (i + 1), layer) 169 | 170 | 171 | class _Transition(nn.Sequential): 172 | def __init__(self, num_input_features, num_output_features, gate_types): 173 | super(_Transition, self).__init__() 174 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 175 | self.add_module('relu', nn.ReLU(inplace=True)) 176 | if 'input' in gate_types: 177 | self.add_module('gate): (input', GateLayer(num_input_features,num_input_features,[1, -1, 1, 1])) 178 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 179 | kernel_size=1, stride=1, bias=False)) 180 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 181 | if 'output_conv' in gate_types: 182 | self.add_module('gate): (output_conv', GateLayer(num_output_features,num_output_features,[1, -1, 1, 1])) 183 | 184 | 185 | class DenseNet(nn.Module): 186 | r"""Densenet-BC model class, based on 187 | `"Densely Connected Convolutional Networks" `_ 188 | 189 | Args: 190 | growth_rate (int) - how many filters to add each layer (`k` in paper) 191 | block_config (list of 4 ints) - how many layers in each pooling block 192 | num_init_features (int) - the number of filters to learn in the first convolution layer 193 | bn_size (int) - multiplicative factor for number of bottle neck layers 194 | (i.e. 
bn_size * k features in the bottleneck layer) 195 | drop_rate (float) - dropout rate after each dense layer 196 | num_classes (int) - number of classification classes 197 | """ 198 | 199 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 200 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, 201 | gate_types=['input','output_bn','output_conv','bottom','top']): 202 | 203 | super(DenseNet, self).__init__() 204 | 205 | # First convolution 206 | self.features = nn.Sequential(OrderedDict([ 207 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 208 | ('norm0', nn.BatchNorm2d(num_init_features)), 209 | ('relu0', nn.ReLU(inplace=True)), 210 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 211 | ])) 212 | if 'bottom' in gate_types: 213 | self.features.add_module('gate0): (bottom', GateLayer(num_init_features,num_init_features,[1, -1, 1, 1])) 214 | 215 | # Each denseblock 216 | num_features = num_init_features 217 | for i, num_layers in enumerate(block_config): 218 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 219 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, 220 | gate_types=gate_types) 221 | self.features.add_module('denseblock%d' % (i + 1), block) 222 | num_features = num_features + num_layers * growth_rate 223 | if i != len(block_config) - 1: 224 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, 225 | gate_types=gate_types) 226 | self.features.add_module('transition%d' % (i + 1), trans) 227 | num_features = num_features // 2 228 | 229 | # Final batch norm 230 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 231 | if 'top' in gate_types: 232 | self.features.add_module('gate5): (top', GateLayer(num_features,num_features,[1, -1, 1, 1])) 233 | 234 | # Linear layer 235 | self.classifier = nn.Linear(num_features, num_classes) 236 | 237 | # Official init from torch repo. 238 | for m in self.modules(): 239 | if isinstance(m, nn.Conv2d): 240 | nn.init.kaiming_normal_(m.weight) 241 | elif isinstance(m, nn.BatchNorm2d): 242 | nn.init.constant_(m.weight, 1) 243 | nn.init.constant_(m.bias, 0) 244 | elif isinstance(m, nn.Linear): 245 | nn.init.constant_(m.bias, 0) 246 | 247 | def forward(self, x): 248 | features = self.features(x) 249 | out = F.relu(features, inplace=True) 250 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 251 | out = self.classifier(out) 252 | return out 253 | -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | # based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/lenet.py 7 | 8 | '''LeNet in PyTorch.''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from layers.gate_layer import GateLayer 13 | 14 | 15 | class LeNet(nn.Module): 16 | def __init__(self, dataset="CIFAR10"): 17 | super(LeNet, self).__init__() 18 | if dataset=="CIFAR10": 19 | nunits_input = 3 20 | nuintis_fc = 32*5*5 21 | elif dataset=="MNIST": 22 | nunits_input = 1 23 | nuintis_fc = 32*4*4 24 | self.conv1 = nn.Conv2d(nunits_input, 16, 5) 25 | self.gate1 = GateLayer(16,16,[1, -1, 1, 1]) 26 | self.conv2 = nn.Conv2d(16, 32, 5) 27 | self.gate2 = GateLayer(32,32,[1, -1, 1, 1]) 28 | self.fc1 = nn.Linear(nuintis_fc, 120) 29 | self.gate3 = GateLayer(120,120,[1, -1]) 30 | self.fc2 = nn.Linear(120, 84) 31 | self.gate4 = GateLayer(84,84,[1, -1]) 32 | self.fc3 = nn.Linear(84, 10) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.conv1(x)) 36 | out = F.max_pool2d(out, 2) 37 | out = self.gate1(out) 38 | out = F.relu(self.conv2(out)) 39 | out = F.max_pool2d(out, 2) 40 | out = self.gate2(out) 41 | out = out.view(out.size(0), -1) 42 | out = F.relu(self.fc1(out)) 43 | out = self.gate3(out) 44 | out = F.relu(self.fc2(out)) 45 | out = self.gate4(out) 46 | out = self.fc3(out) 47 | return out -------------------------------------------------------------------------------- /models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | # based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py 7 | 8 | '''Pre-activation ResNet in PyTorch. 9 | Reference: 10 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 11 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 12 | ''' 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from layers.gate_layer import GateLayer 18 | 19 | 20 | def norm2d(planes, num_groups=32): 21 | if num_groups != 0: 22 | print("num_groups:{}".format(num_groups)) 23 | if num_groups > 0: 24 | return GroupNorm2D(planes, num_groups) 25 | else: 26 | return nn.BatchNorm2d(planes) 27 | 28 | 29 | class PreActBlock(nn.Module): 30 | '''Pre-activation version of the BasicBlock.''' 31 | expansion = 1 32 | 33 | def __init__(self, in_planes, planes, stride=1, group_norm=0): 34 | super(PreActBlock, self).__init__() 35 | self.bn1 = norm2d(in_planes, group_norm) 36 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 37 | self.gate1 = GateLayer(planes,planes,[1, -1, 1, 1]) 38 | self.bn2 = norm2d(planes, group_norm) 39 | 40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 41 | self.gate_out = GateLayer(planes,planes,[1, -1, 1, 1]) 42 | 43 | if stride != 1 or in_planes != self.expansion*planes: 44 | self.shortcut = nn.Sequential( 45 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 46 | ) 47 | self.gate_shortcut = GateLayer(self.expansion*planes,self.expansion*planes,[1, -1, 1, 1]) 48 | 49 | def forward(self, x): 50 | out = F.relu(self.bn1(x)) 51 | 52 | if hasattr(self, 'shortcut'): 53 | shortcut = self.shortcut(out) 54 | shortcut = self.gate_shortcut(shortcut) 55 | else: 56 | shortcut = x 57 | 58 | out = self.conv1(out) 59 | 60 | out = self.bn2(out) 61 | out = self.gate1(out) 62 | 63 | out = F.relu(out) 64 | out = self.conv2(out) 65 | out = self.gate_out(out) 66 | 67 | out = out + shortcut 68 | ##as a block here we might benefit with gate at this stage 69 | 70 | return out 71 | 72 | 73 | class PreActBottleneck(nn.Module): 74 | '''Pre-activation version of the original Bottleneck module.''' 75 | expansion = 4 76 | 77 | def __init__(self, in_planes, planes, stride=1, group_norm=0): 78 | super(PreActBottleneck, self).__init__() 79 | 80 | self.bn1 = norm2d(in_planes, group_norm) 81 | 82 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 83 | self.bn2 = norm2d(planes, group_norm) 84 | self.gate1 = GateLayer(planes,planes,[1, -1, 1, 1]) 85 | 86 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 87 | self.bn3 = norm2d(planes, group_norm) 88 | self.gate2 = GateLayer(planes,planes,[1, -1, 1, 1]) 89 | 90 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 91 | self.gate3 = GateLayer(self.expansion*planes,self.expansion*planes,[1, -1, 1, 1]) 92 | 93 | if stride != 1 or in_planes != self.expansion*planes: 94 | self.shortcut = nn.Sequential( 95 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 96 | self.gate_shortcut = GateLayer(self.expansion*planes,self.expansion*planes,[1, -1, 1, 1]) 97 | 98 | def forward(self, x): 99 | out = F.relu(self.bn1(x)) 100 | input_out = out 101 | 102 | out = self.conv1(out) 103 | out = self.bn2(out) 104 | out = self.gate1(out) 105 | out = F.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn3(out) 109 | out = self.gate2(out) 110 | 111 | out = F.relu(out) 112 | 113 | out = self.conv3(out) 114 | out = self.gate3(out) 115 | 116 | if hasattr(self, 'shortcut'): 117 | shortcut = self.shortcut(input_out) 118 | shortcut = self.gate_shortcut(shortcut) 119 | else: 120 | shortcut = x 121 | 122 | out = out + shortcut 123 | 
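# note: unlike the branch outputs (gate1-gate3) and the projected
# shortcut, the residual sum itself is not gated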
return out 124 | 125 | 126 | class PreActResNet(nn.Module): 127 | def __init__(self, block, num_blocks, num_classes=10, group_norm=0, dataset="CIFAR10"): 128 | super(PreActResNet, self).__init__() 129 | 130 | self.in_planes = 64 131 | self.dataset = dataset 132 | 133 | if dataset == "CIFAR10": 134 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 135 | num_classes = 10 136 | elif dataset == "Imagenet": 137 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 138 | bias=False) 139 | self.bn1 = nn.BatchNorm2d(64) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 142 | num_classes = 1000 143 | 144 | self.gate_in = GateLayer(64, 64, [1, -1, 1, 1]) 145 | 146 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, group_norm=group_norm) 147 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, group_norm=group_norm) 148 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, group_norm=group_norm) 149 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, group_norm=group_norm) 150 | 151 | if dataset == "CIFAR10": 152 | self.avgpool = nn.AvgPool2d(4, stride=1) 153 | self.linear = nn.Linear(512*block.expansion, num_classes) 154 | elif dataset == "Imagenet": 155 | self.avgpool = nn.AvgPool2d(7, stride=1) 156 | self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | 158 | 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | if isinstance(m, nn.BatchNorm2d): 164 | m.bias.data.zero_() 165 | 166 | def _make_layer(self, block, planes, num_blocks, stride, group_norm = 0): 167 | strides = [stride] + [1]*(num_blocks-1) 168 | layers = [] 169 | for stride in strides: 170 | layers.append(block(self.in_planes, planes, stride, group_norm = group_norm)) 171 | self.in_planes = planes * block.expansion 172 | return nn.Sequential(*layers) 173 | 174 | def forward(self, x): 175 | out = self.conv1(x) 176 | if self.dataset == "Imagenet": 177 | out = self.bn1(out) 178 | out = self.relu(out) 179 | out = self.maxpool(out) 180 | 181 | out = self.gate_in(out) 182 | 183 | out = self.layer1(out) 184 | out = self.layer2(out) 185 | out = self.layer3(out) 186 | out = self.layer4(out) 187 | out = self.avgpool(out) 188 | 189 | out = out.view(out.size(0), -1) 190 | if self.dataset == "CIFAR10": 191 | out = self.linear(out) 192 | elif self.dataset == "Imagenet": 193 | out = self.fc(out) 194 | 195 | return out 196 | 197 | 198 | def PreActResNet18(group_norm = 0): 199 | return PreActResNet(PreActBlock, [2,2,2,2], group_norm= group_norm) 200 | 201 | def PreActResNet34(): 202 | return PreActResNet(PreActBlock, [3,4,6,3]) 203 | 204 | def PreActResNet50(group_norm=0, dataset = "CIFAR10"): 205 | return PreActResNet(PreActBottleneck, [3,4,6,3], group_norm = group_norm, dataset = dataset) 206 | 207 | def PreActResNet101(): 208 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 209 | 210 | def PreActResNet152(): 211 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 212 | 213 | 214 | def test(): 215 | net = PreActResNet18() 216 | y = net((torch.randn(1,3,32,32))) 217 | print(y.size()) 218 | 219 | # test() 220 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | # based on https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 7 | 8 | import torch.nn as nn 9 | import torch.utils.model_zoo as model_zoo 10 | from layers.gate_layer import GateLayer 11 | 12 | 13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 14 | 'resnet152'] 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None, gate=None): 36 | super(BasicBlock, self).__init__() 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = nn.BatchNorm2d(planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.gate1 = GateLayer(planes,planes,[1, -1, 1, 1]) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.gate2 = GateLayer(planes,planes,[1, -1, 1, 1]) 44 | self.downsample = downsample 45 | self.stride = stride 46 | self.gate = gate 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.gate1(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | out = self.gate2(out) 59 | 60 | if self.downsample is not None: 61 | residual = self.downsample(x) 62 | 63 | out += residual 64 | out = self.relu(out) 65 | 66 | if self.gate is not None: 67 | out = self.gate(out) 68 | 69 | return out 70 | 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 4 74 | 75 | def __init__(self, inplanes, planes, stride=1, downsample=None, gate=None): 76 | super(Bottleneck, self).__init__() 77 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(planes) 79 | self.gate1 = GateLayer(planes,planes,[1, -1, 1, 1]) 80 | self.relu1 = nn.ReLU(inplace=True) 81 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 82 | padding=1, bias=False) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | self.gate2 = GateLayer(planes,planes,[1, -1, 1, 1]) 85 | self.relu2 = nn.ReLU(inplace=True) 86 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 87 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 88 | self.relu3 = nn.ReLU(inplace=True) 89 | self.downsample = downsample 90 | self.stride = stride 91 | self.gate = gate 92 | 93 | def forward(self, x): 94 | residual = x 95 | 96 | out = self.conv1(x) 97 | out = self.bn1(out) 98 | out = self.gate1(out) 99 | 100 | out = self.relu1(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.gate2(out) 105 | out = self.relu2(out) 106 | 107 | out = self.conv3(out) 108 | out = self.bn3(out) 109 | 110 | if self.downsample is not None: 111 | residual = self.downsample(x) 112 | 113 | out += residual 114 | out = self.relu3(out) 115 | if self.gate is not None: 116 | out = 
self.gate(out) 117 | 118 | return out 119 | 120 | 121 | class ResNet(nn.Module): 122 | 123 | def __init__(self, block, layers, num_classes=1000, skip_gate = True): 124 | self.inplanes = 64 125 | super(ResNet, self).__init__() 126 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 127 | bias=False) 128 | self.bn1 = nn.BatchNorm2d(64) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | 132 | gate = skip_gate 133 | self.gate = gate 134 | if gate: 135 | # self.gate_skip1 = GateLayer(64,64,[1, -1, 1, 1]) 136 | self.gate_skip64 = GateLayer(64*4,64*4,[1, -1, 1, 1]) 137 | self.gate_skip128 = GateLayer(128*4,128*4,[1, -1, 1, 1]) 138 | self.gate_skip256 = GateLayer(256*4,256*4,[1, -1, 1, 1]) 139 | self.gate_skip512 = GateLayer(512*4,512*4,[1, -1, 1, 1]) 140 | if block == BasicBlock: 141 | self.gate_skip64 = GateLayer(64, 64, [1, -1, 1, 1]) 142 | self.gate_skip128 = GateLayer(128, 128, [1, -1, 1, 1]) 143 | self.gate_skip256 = GateLayer(256, 256, [1, -1, 1, 1]) 144 | self.gate_skip512 = GateLayer(512, 512, [1, -1, 1, 1]) 145 | else: 146 | self.gate_skip64 = None 147 | self.gate_skip128 = None 148 | self.gate_skip256 = None 149 | self.gate_skip512 = None 150 | 151 | self.layer1 = self._make_layer(block, 64, layers[0], gate = self.gate_skip64) 152 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, gate=self.gate_skip128) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, gate=self.gate_skip256) 154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, gate=self.gate_skip512) 155 | self.avgpool = nn.AvgPool2d(7, stride=1) 156 | self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 161 | elif isinstance(m, nn.BatchNorm2d): 162 | nn.init.constant_(m.weight, 1) 163 | nn.init.constant_(m.bias, 0) 164 | 165 | def _make_layer(self, block, planes, blocks, stride=1, gate = None): 166 | downsample = None 167 | if stride != 1 or self.inplanes != planes * block.expansion: 168 | downsample = nn.Sequential( 169 | nn.Conv2d(self.inplanes, planes * block.expansion, 170 | kernel_size=1, stride=stride, bias=False), 171 | nn.BatchNorm2d(planes * block.expansion), 172 | ) 173 | 174 | layers = [] 175 | layers.append(block(self.inplanes, planes, stride, downsample, gate = gate)) 176 | 177 | self.inplanes = planes * block.expansion 178 | for i in range(1, blocks): 179 | layers.append(block(self.inplanes, planes, gate = gate)) 180 | 181 | return nn.Sequential(*layers) 182 | 183 | def forward(self, x): 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | x = self.maxpool(x) 188 | 189 | # if self.gate: 190 | # x=self.gate_skip1(x) 191 | 192 | x = self.layer1(x) 193 | x = self.layer2(x) 194 | x = self.layer3(x) 195 | x = self.layer4(x) 196 | 197 | x = self.avgpool(x) 198 | x = x.view(x.size(0), -1) 199 | x = self.fc(x) 200 | 201 | return x 202 | 203 | 204 | def resnet18(pretrained=False, **kwargs): 205 | """Constructs a ResNet-18 model. 206 | 207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 211 | if pretrained: 212 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 213 | return model 214 | 215 | 216 | def resnet34(pretrained=False, **kwargs): 217 | """Constructs a ResNet-34 model. 
218 | 219 | Args: 220 | pretrained (bool): If True, returns a model pre-trained on ImageNet 221 | """ 222 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 223 | if pretrained: 224 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 225 | return model 226 | 227 | 228 | def resnet50(pretrained=False, **kwargs): 229 | """Constructs a ResNet-50 model. 230 | 231 | Args: 232 | pretrained (bool): If True, returns a model pre-trained on ImageNet 233 | """ 234 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 235 | if pretrained: 236 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 237 | return model 238 | 239 | 240 | def resnet101(pretrained=False, **kwargs): 241 | """Constructs a ResNet-101 model. 242 | 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | """ 246 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 247 | if pretrained: 248 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 249 | return model 250 | 251 | 252 | def resnet152(pretrained=False, **kwargs): 253 | """Constructs a ResNet-152 model. 254 | 255 | Args: 256 | pretrained (bool): If True, returns a model pre-trained on ImageNet 257 | """ 258 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 259 | if pretrained: 260 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 261 | return model 262 | -------------------------------------------------------------------------------- /models/vgg_bn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | # based on https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 7 | 8 | import torch.nn as nn 9 | import torch.utils.model_zoo as model_zoo 10 | from layers.gate_layer import GateLayer 11 | 12 | __all__ = [ 13 | 'slimmingvgg', 14 | ] 15 | 16 | model_urls = { 17 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 18 | } 19 | 20 | class LinView(nn.Module): 21 | def __init__(self): 22 | super(LinView, self).__init__() 23 | 24 | def forward(self, x): 25 | return x.view(x.size(0), -1) 26 | 27 | 28 | class VGG(nn.Module): 29 | 30 | def __init__(self, features, cfg, num_classes=1000, init_weights=True): 31 | super(VGG, self).__init__() 32 | self.features = features 33 | self.classifier = nn.Sequential( 34 | nn.Linear(cfg[0] * 7 * 7, cfg[1]), 35 | nn.BatchNorm1d(cfg[1]), 36 | nn.ReLU(True), 37 | nn.Linear(cfg[1],cfg[2]), 38 | nn.BatchNorm1d(cfg[2]), 39 | nn.ReLU(True), 40 | nn.Linear(cfg[2], num_classes) 41 | ) 42 | if init_weights: 43 | self._initialize_weights() 44 | 45 | def forward(self, x): 46 | x = self.features(x) 47 | x = x.view(x.size(0), -1) 48 | x = self.classifier(x) 49 | return x 50 | 51 | def _initialize_weights(self): 52 | for m in self.modules(): 53 | if isinstance(m, nn.Conv2d): 54 | nn.init.kaiming_normal_(m.weight, mode='fan_out')#, nonlinearity='relu') 55 | if m.bias is not None: 56 | m.bias.data.zero_() 57 | elif isinstance(m, nn.BatchNorm2d): 58 | m.weight.data.fill_(0.5) 59 | m.bias.data.zero_() 60 | elif isinstance(m, nn.Linear): 61 | m.weight.data.normal_(0, 0.01) 62 | m.bias.data.zero_() 63 | elif isinstance(m, nn.BatchNorm1d): 64 | m.weight.data.fill_(0.5) 65 | m.bias.data.zero_() 66 | 67 | def make_layers(cfg, batch_norm=False): 68 | layers = [] 69 | in_channels = 3 70 | for v in cfg: 71 | 
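# 'M' entries add a max-pooling stage; integer entries add a 3x3 conv
# (optionally followed by BN) and a ReLU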
if v == 'M': 72 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 73 | else: 74 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 75 | if batch_norm: 76 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 77 | else: 78 | layers += [conv2d, nn.ReLU(inplace=True)] 79 | in_channels = v 80 | return nn.Sequential(*layers) 81 | 82 | cfg = { 83 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M', 4096, 4096] 84 | } 85 | 86 | 87 | def flatten_model(old_net): 88 | """Removes nested modules. Only works for VGG.""" 89 | from collections import OrderedDict 90 | module_list, counter, inserted_view = [], 0, False 91 | gate_counter = 0 92 | print("printing network") 93 | print(" Hard-coded network in vgg_bn.py") 94 | for m_indx, module in enumerate(old_net.modules()): 95 | if not isinstance(module, (nn.Sequential, VGG)): 96 | print(m_indx, module) 97 | if isinstance(module, nn.Linear) and not inserted_view: 98 | module_list.append(('flatten', LinView())) 99 | inserted_view = True 100 | 101 | # features.0 102 | # classifier 103 | prefix = "features" 104 | 105 | if m_indx > 30: 106 | prefix = "classifier" 107 | if m_indx == 32: 108 | counter = 0 109 | 110 | # prefix = "" 111 | 112 | module_list.append((prefix + str(counter), module)) 113 | 114 | if isinstance(module, nn.BatchNorm2d): 115 | planes = module.num_features 116 | gate = GateLayer(planes, planes, [1, -1, 1, 1]) 117 | module_list.append(('gate%d' % (gate_counter), gate)) 118 | print("gate ", counter, planes) 119 | gate_counter += 1 120 | 121 | 122 | if isinstance(module, nn.BatchNorm1d): 123 | planes = module.num_features 124 | gate = GateLayer(planes, planes, [1, -1]) 125 | module_list.append(('gate%d' % (gate_counter), gate)) 126 | print("gate ", counter, planes) 127 | gate_counter += 1 128 | 129 | 130 | counter += 1 131 | new_net = nn.Sequential(OrderedDict(module_list)) 132 | return new_net 133 | 134 | 135 | def slimmingvgg(pretrained=False, config=None, **kwargs): 136 | """VGG 11-layer model (configuration "A") with batch normalization 137 | 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on ImageNet 140 | """ 141 | if pretrained: 142 | kwargs['init_weights'] = False 143 | if config is None: 144 | config = cfg['A'] 145 | config2 = [config[-4],config[-2],config[-1]] 146 | model = VGG(make_layers(config[:-2], batch_norm=True), config2, **kwargs) 147 | if pretrained: 148 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']),strict=False) 149 | model = flatten_model(model) 150 | return model 151 | -------------------------------------------------------------------------------- /pruning_engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | from __future__ import print_function 7 | import os 8 | import time 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.optim 14 | import torch.utils.data 15 | 16 | import numpy as np 17 | 18 | 19 | from copy import deepcopy 20 | import itertools 21 | import pickle 22 | import json 23 | 24 | METHOD_ENCODING = {0: "Taylor_weight", 1: "Random", 2: "Weight norm", 3: "Weight_abs", 25 | 6: "Taylor_output", 10: "OBD", 11: "Taylor_gate_SO", 26 | 22: "Taylor_gate", 23: "Taylor_gate_FG", 30: "BN_weight", 31: "BN_Taylor"} 27 | 28 | 29 | # Method is encoded as an integer that mapping is shown above. 30 | # Methods map to the paper as follows: 31 | # 0 - Taylor_weight - Conv weight/conv/linear weight with Taylor FO In Table 2 and Table 1 32 | # 1 - Random - Random 33 | # 2 - Weight norm - Weight magnitude/ weight 34 | # 3 - Weight_abs - Not used 35 | # 6 - Taylor_output - Taylor-output as is [27] 36 | # 10- OBD - OBD 37 | # 11- Taylor_gate_SO- Taylor SO 38 | # 22- Taylor_gate - Gate after BN in Table 2, Taylor FO in Table 1 39 | # 23- Taylor_gate_FG- uses gradient per example to compute Taylor FO, Taylor FO- FG in Table 1, Gate after BN - FG in Table 2 40 | # 30- BN_weight - BN scale in Table 2 41 | # 31- BN_Taylor - BN scale Taylor FO in Table 2 42 | 43 | 44 | class PruningConfigReader(object): 45 | def __init__(self): 46 | self.pruning_settings = {} 47 | self.config = None 48 | 49 | def read_config(self, filename): 50 | # reads .json file and sets values as pruning_settings for pruning 51 | 52 | with open(filename, "r") as f: 53 | config = json.load(f) 54 | 55 | self.config = config 56 | 57 | self.read_field_value("method", 0) 58 | self.read_field_value("frequency", 500) 59 | self.read_field_value("prune_per_iteration", 2) 60 | self.read_field_value("maximum_pruning_iterations", 10000) 61 | self.read_field_value("starting_neuron", 0) 62 | 63 | self.read_field_value("fixed_layer", -1) 64 | # self.read_field_value("use_momentum", False) 65 | 66 | self.read_field_value("pruning_threshold", 100) 67 | self.read_field_value("start_pruning_after_n_iterations", 0) 68 | # self.read_field_value("use_momentum", False) 69 | self.read_field_value("do_iterative_pruning", True) 70 | self.read_field_value("fixed_criteria", False) 71 | self.read_field_value("seed", 0) 72 | self.read_field_value("pruning_momentum", 0.9) 73 | self.read_field_value("flops_regularization", 0.0) 74 | self.read_field_value("prune_neurons_max", 1) 75 | 76 | self.read_field_value("group_size", 1) 77 | 78 | def read_field_value(self, key, default): 79 | param = default 80 | if key in self.config: 81 | param = self.config[key] 82 | 83 | self.pruning_settings[key] = param 84 | 85 | def get_parameters(self): 86 | return self.pruning_settings 87 | 88 | 89 | class pytorch_pruning(object): 90 | def __init__(self, parameters, pruning_settings=dict(), log_folder=None): 91 | def initialize_parameter(object_name, settings, key, def_value): 92 | ''' 93 | Function check if key is in the settings and sets it, otherwise puts default momentum 94 | :param object_name: reference to the object instance 95 | :param settings: dict of settings 96 | :param def_value: def value for the parameter to be putted into the field if it doesn't work 97 | :return: 98 | void 99 | ''' 100 | value = def_value 101 | if key in settings.keys(): 102 | value = settings[key] 103 | setattr(object_name, key, value) 104 | 105 | # store some statistics 106 | self.min_criteria_value = 1e6 107 | self.max_criteria_value = 0.0 
108 | self.median_criteria_value = 0.0 109 | self.neuron_units = 0 110 | self.all_neuron_units = 0 111 | self.pruned_neurons = 0 112 | self.gradient_norm_final = 0.0 113 | self.flops_regularization = 0.0 #not used in the paper 114 | self.pruning_iterations_done = 0 115 | 116 | # initialize_parameter(self, pruning_settings, 'use_momentum', False) 117 | initialize_parameter(self, pruning_settings, 'pruning_momentum', 0.9) 118 | initialize_parameter(self, pruning_settings, 'flops_regularization', 0.0) 119 | self.momentum_coeff = self.pruning_momentum 120 | self.use_momentum = self.pruning_momentum > 0.0 121 | 122 | initialize_parameter(self, pruning_settings, 'prune_per_iteration', 1) 123 | initialize_parameter(self, pruning_settings, 'start_pruning_after_n_iterations', 0) 124 | initialize_parameter(self, pruning_settings, 'prune_neurons_max', 0) 125 | initialize_parameter(self, pruning_settings, 'maximum_pruning_iterations', 0) 126 | initialize_parameter(self, pruning_settings, 'pruning_silent', False) 127 | initialize_parameter(self, pruning_settings, 'l2_normalization_per_layer', False) 128 | initialize_parameter(self, pruning_settings, 'fixed_criteria', False) 129 | initialize_parameter(self, pruning_settings, 'starting_neuron', 0) 130 | initialize_parameter(self, pruning_settings, 'frequency', 30) 131 | initialize_parameter(self, pruning_settings, 'pruning_threshold', 100) 132 | initialize_parameter(self, pruning_settings, 'fixed_layer', -1) 133 | initialize_parameter(self, pruning_settings, 'combination_ID', 0) 134 | initialize_parameter(self, pruning_settings, 'seed', 0) 135 | initialize_parameter(self, pruning_settings, 'group_size', 1) 136 | 137 | initialize_parameter(self, pruning_settings, 'method', 0) 138 | 139 | # Hessian related parameters 140 | self.temp_hessian = [] # list to store Hessian 141 | self.hessian_first_time = True 142 | 143 | self.parameters = list() 144 | 145 | ##get pruning parameters 146 | for parameter in parameters: 147 | parameter_value = parameter["parameter"] 148 | self.parameters.append(parameter_value) 149 | 150 | if self.fixed_layer == -1: 151 | ##prune all layers 152 | self.prune_layers = [True for parameter in self.parameters] 153 | else: 154 | ##prune only one layer 155 | self.prune_layers = [False, ]*len(self.parameters) 156 | self.prune_layers[self.fixed_layer] = True 157 | 158 | self.iterations_done = 0 159 | 160 | self.prune_network_criteria = list() 161 | self.prune_network_accomulate = {"by_layer": list(), "averaged": list(), "averaged_cpu": list()} 162 | 163 | self.pruning_gates = list() 164 | for layer in range(len(self.parameters)): 165 | self.prune_network_criteria.append(list()) 166 | 167 | for key in self.prune_network_accomulate.keys(): 168 | self.prune_network_accomulate[key].append(list()) 169 | 170 | self.pruning_gates.append(np.ones(len(self.parameters[layer]),)) 171 | layer_now_criteria = self.prune_network_criteria[-1] 172 | for unit in range(len(self.parameters[layer])): 173 | layer_now_criteria.append(0.0) 174 | 175 | # logging setup 176 | self.log_folder = log_folder 177 | self.folder_to_write_debug = self.log_folder + '/debug/' 178 | if not os.path.exists(self.folder_to_write_debug): 179 | os.makedirs(self.folder_to_write_debug) 180 | 181 | self.method_25_first_done = True 182 | 183 | if self.method == 40 or self.method == 50 or self.method == 25: 184 | self.oracle_dict = {"layer_pruning": -1, "initial_loss": 0.0, "loss_list": list(), "neuron": list(), "iterations": 0} 185 | self.method_25_first_done = False 186 | 187 | if 
self.method == 25: 188 | with open("./utils/study/oracle.pickle","rb") as f: 189 | oracle_list = pickle.load(f) 190 | 191 | self.oracle_dict["loss_list"] = oracle_list 192 | 193 | self.needs_hessian = False 194 | if self.method in [10, 11]: 195 | self.needs_hessian = True 196 | 197 | # useful for storing data of the experiment 198 | self.data_logger = dict() 199 | self.data_logger["pruning_neurons"] = list() 200 | self.data_logger["pruning_accuracy"] = list() 201 | self.data_logger["pruning_loss"] = list() 202 | self.data_logger["method"] = self.method 203 | self.data_logger["prune_per_iteration"] = self.prune_per_iteration 204 | self.data_logger["combination_ID"] = list() 205 | self.data_logger["fixed_layer"] = self.fixed_layer 206 | self.data_logger["frequency"] = self.frequency 207 | self.data_logger["starting_neuron"] = self.starting_neuron 208 | self.data_logger["use_momentum"] = self.use_momentum 209 | 210 | self.data_logger["time_stamp"] = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) 211 | 212 | if hasattr(self, 'seed'): 213 | self.data_logger["seed"] = self.seed 214 | 215 | self.data_logger["filename"] = "%s/data_logger_seed_%d_%s.p"%(log_folder, self.data_logger["seed"], self.data_logger["time_stamp"]) 216 | if self.method == 50: 217 | self.data_logger["filename"] = "%s/data_logger_seed_%d_neuron_%d_%s.p"%(log_folder, self.starting_neuron, self.data_logger["seed"], self.data_logger["time_stamp"]) 218 | self.log_folder = log_folder 219 | 220 | # the rest of initializations 221 | self.pruned_neurons = self.starting_neuron 222 | 223 | self.util_loss_tracker = 0.0 224 | self.util_acc_tracker = 0.0 225 | self.util_loss_tracker_num = 0.0 226 | 227 | self.loss_tracker_exp = ExpMeter() 228 | # stores results of the pruning, 0 - unsuccessful, 1 - successful 229 | self.res_pruning = 0 230 | 231 | self.iter_step = -1 232 | 233 | self.train_writer = None 234 | 235 | self.set_moment_zero = True 236 | self.pruning_mask_from = "" 237 | 238 | def add_criteria(self): 239 | ''' 240 | This method adds criteria to global list given batch stats. 
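For example, with method 22 (Taylor on gates) the per-unit value added
here is (g * dL/dg)^2, the squared first-order Taylor estimate of the loss
change from zeroing that gate. The method is typically called once per
mini-batch; values are summed into prune_network_accomulate["by_layer"]
and later averaged over iterations_done by compute_saliency().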
241 | ''' 242 | 243 | if self.fixed_criteria: 244 | if self.pruning_iterations_done > self.start_pruning_after_n_iterations: 245 | return 0 246 | 247 | for layer, if_prune in enumerate(self.prune_layers): 248 | if not if_prune: 249 | continue 250 | 251 | nunits = self.parameters[layer].size(0) 252 | eps = 1e-8 253 | 254 | if len(self.pruning_mask_from) > 0: 255 | # preload pruning mask 256 | self.method = -1 257 | criteria_for_layer = torch.from_numpy(self.loaded_mask_criteria[layer]).type(torch.FloatTensor).cuda(non_blocking=True) 258 | 259 | if self.method == 0: 260 | # First order Taylor expansion on the weight 261 | criteria_for_layer = (self.parameters[layer]*self.parameters[layer].grad ).data.pow(2).view(nunits,-1).sum(dim=1) 262 | elif self.method == 1: 263 | # random pruning 264 | criteria_for_layer = np.random.uniform(low=0, high=5, size=(nunits,)) 265 | elif self.method == 2: 266 | # min weight 267 | criteria_for_layer = self.parameters[layer].pow(2).view(nunits,-1).sum(dim=1).data 268 | elif self.method == 3: 269 | # weight_abs 270 | criteria_for_layer = self.parameters[layer].abs().view(nunits,-1).sum(dim=1).data 271 | elif self.method == 6: 272 | # ICLR2017 Taylor on output of the layer 273 | if 1: 274 | criteria_for_layer = self.parameters[layer].full_grad_iclr2017 275 | criteria_for_layer = criteria_for_layer / (np.linalg.norm(criteria_for_layer) + eps) 276 | elif self.method == 10: 277 | # diagonal of Hessian 278 | criteria_for_layer = (self.parameters[layer] * torch.diag(self.temp_hessian[layer])).data.view(nunits, 279 | -1).sum( 280 | dim=1) 281 | elif self.method == 11: 282 | # second order Taylor expansion for loss change in the network 283 | criteria_for_layer = (-(self.parameters[layer] * self.parameters[layer].grad).data + 0.5 * ( 284 | self.parameters[layer] * self.parameters[layer] * torch.diag( 285 | self.temp_hessian[layer])).data).pow(2) 286 | 287 | elif self.method == 22: 288 | # Taylor pruning on gate 289 | criteria_for_layer = (self.parameters[layer]*self.parameters[layer].grad).data.pow(2).view(nunits, -1).sum(dim=1) 290 | if hasattr(self, "dataset"): 291 | # fix for skip connection pruning, gradient will be accumulated instead of being averaged 292 | if self.dataset == "Imagenet": 293 | if hasattr(self, "model"): 294 | if not ("noskip" in self.model): 295 | if "resnet" in self.model: 296 | mult = 3.0 297 | if layer == 1: 298 | mult = 4.0 299 | elif layer == 2: 300 | mult = 23.0 if "resnet101" in self.model else mult 301 | mult = 6.0 if "resnet34" in self.model else mult 302 | mult = 6.0 if "resnet50" in self.model else mult 303 | 304 | criteria_for_layer /= mult 305 | 306 | elif self.method == 23: 307 | # Taylor pruning on gate with computing full gradient 308 | criteria_for_layer = (self.parameters[layer].full_grad.t()).data.pow(2).view(nunits,-1).sum(dim=1) 309 | 310 | elif self.method == 30: 311 | # batch normalization based pruning 312 | # by scale (weight) of the batchnorm 313 | criteria_for_layer = (self.parameters[layer]).data.abs().view(nunits, -1).sum(dim=1) 314 | 315 | elif self.method == 31: 316 | # Taylor FO on BN 317 | if hasattr(self.parameters[layer], "bias"): 318 | criteria_for_layer = (self.parameters[layer]*self.parameters[layer].grad + 319 | self.parameters[layer].bias*self.parameters[layer].bias.grad ).data.pow(2).view(nunits,-1).sum(dim=1) 320 | else: 321 | criteria_for_layer = ( 322 | self.parameters[layer] * self.parameters[layer].grad).data.pow(2).view(nunits, -1).sum(dim=1) 323 | 324 | elif self.method == 40: 325 | # ORACLE on the fly 
that reevaluates itself every pruning step 326 | criteria_for_layer = np.asarray(self.oracle_dict["loss_list"][layer]).copy() 327 | self.oracle_dict["loss_list"][layer] = list() 328 | elif self.method == 50: 329 | # combinatorial pruning - evaluates all possibilities of removing N neurons 330 | criteria_for_layer = np.asarray(self.oracle_dict["loss_list"][layer]).copy() 331 | self.oracle_dict["loss_list"][layer] = list() 332 | else: 333 | pass 334 | 335 | if self.iterations_done == 0: 336 | self.prune_network_accomulate["by_layer"][layer] = criteria_for_layer 337 | else: 338 | self.prune_network_accomulate["by_layer"][layer] += criteria_for_layer 339 | 340 | self.iterations_done += 1 341 | 342 | @staticmethod 343 | def group_criteria(list_criteria_per_layer, group_size=1): 344 | ''' 345 | Combines criteria per neuron into groups of size group_size. 346 | Output is a list of groups organized by layers. The length of the output is the number of layers. 347 | The criterion for the group is computed as the sum of the members' criteria. 348 | Input: 349 | list_criteria_per_layer - list of criteria per neuron organized per layer 350 | group_size - number of neurons per group 351 | 352 | Output: 353 | groups - groups organized per layer. Each group element is a tuple of 2: (index of neurons, criterion) 354 | ''' 355 | groups = list() 356 | 357 | for layer in list_criteria_per_layer: 358 | layer_groups = list() 359 | indices = np.argsort(layer) 360 | for group_id in range(int(np.ceil(len(layer)/group_size))): 361 | current_group = slice(group_id*group_size, min((group_id+1)*group_size, len(layer))) 362 | values = [layer[ind] for ind in indices[current_group]] 363 | group = [indices[current_group], sum(values)] 364 | 365 | layer_groups.append(group) 366 | groups.append(layer_groups) 367 | 368 | return groups 369 | 370 | def compute_saliency(self): 371 | ''' 372 | Method performs pruning based on precomputed criteria values. 
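Neuron groups (see group_criteria) whose averaged criterion falls at or
below an adaptively chosen threshold are zeroed out, both in pruning_gates
and in the underlying parameter tensor; with, e.g., prune_per_iteration=2
and group_size=1, threshold_now becomes the 2nd-smallest criterion on the
first pruning step.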
Needs to run after add_criteria() 373 | ''' 374 | def write_to_debug(what_write_name, what_write_value): 375 | # Aux function to store information in the text file 376 | with open(self.log_debug, 'a') as f: 377 | f.write("{} {}\n".format(what_write_name,what_write_value)) 378 | 379 | def nothing(what_write_name, what_write_value): 380 | pass 381 | 382 | if self.method == 50: 383 | write_to_debug = nothing 384 | 385 | # compute loss since the last pruning and decide if to prune: 386 | if self.util_loss_tracker_num > 0: 387 | validation_error = self.util_loss_tracker / self.util_loss_tracker_num 388 | validation_error_long = validation_error 389 | acc = self.util_acc_tracker / self.util_loss_tracker_num 390 | else: 391 | print("compute loss and run self.util_add_loss(loss.item()) before running this") 392 | validation_error = 0.0 393 | acc = 0.0 394 | 395 | self.util_training_loss = validation_error 396 | self.util_training_acc = acc 397 | 398 | # reset training loss tracker 399 | self.util_loss_tracker = 0.0 400 | self.util_acc_tracker = 0.0 401 | self.util_loss_tracker_num = 0 402 | 403 | if validation_error > self.pruning_threshold: 404 | ## if error is big then skip pruning 405 | print("skipping pruning", validation_error, "(%f)"%validation_error_long, self.pruning_threshold) 406 | if self.method != 4: 407 | self.res_pruning = -1 408 | return -1 409 | 410 | if self.maximum_pruning_iterations <= self.pruning_iterations_done: 411 | # if reached max number of pruning iterations -> exit 412 | self.res_pruning = -1 413 | return -1 414 | 415 | self.full_list_of_criteria = list() 416 | 417 | for layer, if_prune in enumerate(self.prune_layers): 418 | if not if_prune: 419 | continue 420 | 421 | if self.iterations_done > 0: 422 | # momentum turned to be useless and even reduces performance 423 | contribution = self.prune_network_accomulate["by_layer"][layer] / self.iterations_done 424 | if self.pruning_iterations_done == 0 or not self.use_momentum or (self.method in [4, 40, 50]): 425 | self.prune_network_accomulate["averaged"][layer] = contribution 426 | else: 427 | # use momentum to accumulate criteria over several pruning iterations: 428 | self.prune_network_accomulate["averaged"][layer] = self.momentum_coeff*self.prune_network_accomulate["averaged"][layer]+(1.0- self.momentum_coeff)*contribution 429 | 430 | current_layer = self.prune_network_accomulate["averaged"][layer] 431 | if not (self.method in [1, 4, 40, 15, 50]): 432 | current_layer = current_layer.cpu().numpy() 433 | 434 | if self.l2_normalization_per_layer: 435 | eps = 1e-8 436 | current_layer = current_layer / (np.linalg.norm(current_layer) + eps) 437 | 438 | self.prune_network_accomulate["averaged_cpu"][layer] = current_layer 439 | else: 440 | print("First do some add_criteria iterations") 441 | exit() 442 | 443 | for unit in range(len(self.parameters[layer])): 444 | criterion_now = current_layer[unit] 445 | 446 | # make sure that pruned neurons have 0 criteria 447 | self.prune_network_criteria[layer][unit] = criterion_now * self.pruning_gates[layer][unit] 448 | 449 | if self.method == 50: 450 | self.prune_network_criteria[layer][unit] = criterion_now 451 | 452 | # count number of neurons 453 | all_neuron_units, neuron_units = self._count_number_of_neurons() 454 | self.neuron_units = neuron_units 455 | self.all_neuron_units = all_neuron_units 456 | 457 | # store criteria_result into file 458 | if not self.pruning_silent: 459 | import pickle 460 | store_criteria = self.prune_network_accomulate["averaged_cpu"] 461 | 
pickle.dump(store_criteria, open(self.folder_to_write_debug + "criteria_%04d.pickle"%self.pruning_iterations_done, "wb")) 462 | if self.pruning_iterations_done == 0: 463 | pickle.dump(store_criteria, open(self.log_folder + "criteria_%d.pickle"%self.method, "wb")) 464 | pickle.dump(store_criteria, open(self.log_folder + "criteria_%d_final.pickle"%self.method, "wb")) 465 | 466 | if not self.fixed_criteria: 467 | self.iterations_done = 0 468 | 469 | # create groups per layer 470 | groups = self.group_criteria(self.prune_network_criteria, group_size=self.group_size) 471 | 472 | # apply flops regularization 473 | # if self.flops_regularization > 0.0: 474 | # self.apply_flops_regularization(groups, mu=self.flops_regularization) 475 | 476 | # get an array of all criteria from groups 477 | all_criteria = np.asarray([group[1] for layer in groups for group in layer]).reshape(-1) 478 | 479 | prune_neurons_now = (self.pruned_neurons + self.prune_per_iteration)//self.group_size - 1 480 | if self.prune_neurons_max != -1: 481 | prune_neurons_now = min(len(all_criteria)-1, min(prune_neurons_now, self.prune_neurons_max//self.group_size - 1)) 482 | 483 | # adaptively estimate threshold given a number of neurons to be removed 484 | threshold_now = np.sort(all_criteria)[prune_neurons_now] 485 | 486 | if self.method == 50: 487 | # combinatorial approach 488 | threshold_now = 0.5 489 | self.pruning_iterations_done = self.combination_ID 490 | self.data_logger["combination_ID"].append(self.combination_ID-1) 491 | self.combination_ID += 1 492 | self.reset_oracle_pruning() 493 | print("full_combinatorial: combination ", self.combination_ID) 494 | 495 | self.pruning_iterations_done += 1 496 | 497 | self.log_debug = self.folder_to_write_debug + 'debugOutput_pruning_%08d' % ( 498 | self.pruning_iterations_done) + '.txt' 499 | write_to_debug("method", self.method) 500 | write_to_debug("pruned_neurons", self.pruned_neurons) 501 | write_to_debug("pruning_iterations_done", self.pruning_iterations_done) 502 | write_to_debug("neuron_units", neuron_units) 503 | write_to_debug("all_neuron_units", all_neuron_units) 504 | write_to_debug("threshold_now", threshold_now) 505 | write_to_debug("groups_total", sum([len(layer) for layer in groups])) 506 | 507 | if self.pruning_iterations_done < self.start_pruning_after_n_iterations: 508 | self.res_pruning = -1 509 | return -1 510 | 511 | for layer, if_prune in enumerate(self.prune_layers): 512 | if not if_prune: 513 | continue 514 | 515 | write_to_debug("\nLayer:", layer) 516 | write_to_debug("units:", len(self.parameters[layer])) 517 | 518 | if self.prune_per_iteration == 0: 519 | continue 520 | 521 | for group in groups[layer]: 522 | if group[1] <= threshold_now: 523 | for unit in group[0]: 524 | # do actual pruning 525 | self.pruning_gates[layer][unit] *= 0.0 526 | self.parameters[layer].data[unit] *= 0.0 527 | 528 | write_to_debug("pruned_perc:", [np.nonzero(1.0-self.pruning_gates[layer])[0].size, len(self.parameters[layer])]) 529 | 530 | # count number of neurons 531 | all_neuron_units, neuron_units = self._count_number_of_neurons() 532 | 533 | self.pruned_neurons = all_neuron_units-neuron_units 534 | 535 | if self.method == 25: 536 | self.method_25_first_done = True 537 | 538 | self.threshold_now = threshold_now 539 | try: 540 | self.min_criteria_value = (all_criteria[all_criteria > 0.0]).min() 541 | self.max_criteria_value = (all_criteria[all_criteria > 0.0]).max() 542 | self.median_criteria_value = np.median(all_criteria[all_criteria > 0.0]) 543 | except: 544 | 
            self.min_criteria_value = 0.0
545 |             self.max_criteria_value = 0.0
546 |             self.median_criteria_value = 0.0
547 | 
548 |         # set result to successful
549 |         self.res_pruning = 1
550 | 
551 |     def _count_number_of_neurons(self):
552 |         '''
553 |         Function computes the total number of neurons and the number of active (not pruned) neurons
554 |         :return:
555 |         all_neuron_units - number of neurons considered for pruning
556 |         neuron_units - number of neurons in the model that are not pruned
557 |         '''
558 |         all_neuron_units = 0
559 |         neuron_units = 0
560 |         for layer, if_prune in enumerate(self.prune_layers):
561 |             if not if_prune:
562 |                 continue
563 | 
564 |             all_neuron_units += len(self.parameters[layer])
565 |             for unit in range(len(self.parameters[layer])):
566 |                 if len(self.parameters[layer].data.size()) > 1:
567 |                     statistics = self.parameters[layer].data[unit].abs().sum()
568 |                 else:
569 |                     statistics = self.parameters[layer].data[unit]
570 | 
571 |                 if statistics > 0.0:
572 |                     neuron_units += 1
573 | 
574 |         return all_neuron_units, neuron_units
575 | 
576 |     def set_weights_oracle_pruning(self):
577 |         '''
578 |         Sets gates/weights to zero to evaluate the effect of pruning a single neuron.
579 |         Reuses the stored weights to restore the layer between evaluations.
580 |         Used for oracle pruning only.
581 |         '''
582 | 
583 |         for layer, if_prune in enumerate(self.prune_layers_oracle):
584 |             if not if_prune:
585 |                 continue
586 | 
587 |             if self.method == 40:
588 |                 self.parameters[layer].data = deepcopy(torch.from_numpy(self.stored_weights).cuda())
589 | 
590 |             for unit in range(len(self.parameters[layer])):
591 |                 if self.method == 40:
592 |                     self.pruning_gates[layer][unit] = 1.0
593 | 
594 |                 if unit == self.oracle_unit:
595 |                     self.pruning_gates[layer][unit] *= 0.0
596 |                     self.parameters[layer].data[unit] *= 0.0
597 | 
598 |                     # if 'momentum_buffer' in optimizer.state[self.parameters[layer]].keys():
599 |                     #     optimizer.state[self.parameters[layer]]['momentum_buffer'][unit] *= 0.0
600 |         return 1
601 | 
602 |     def reset_oracle_pruning(self):
603 |         '''
604 |         Method restores the original weights after masking for oracle pruning
605 |         :return:
606 |         '''
607 |         for layer, if_prune in enumerate(self.prune_layers_oracle):
608 |             if not if_prune:
609 |                 continue
610 | 
611 |             if self.method == 40 or self.method == 50:
612 |                 self.parameters[layer].data = deepcopy(torch.from_numpy(self.stored_weights).cuda())
613 | 
614 |             for unit in range(len(self.parameters[layer])):
615 |                 if self.method == 40 or self.method == 50:
616 |                     self.pruning_gates[layer][unit] = 1.0
617 | 
618 |     def enforce_pruning(self):
619 |         '''
620 |         Method sets parameters and gates to 0 for pruned neurons.
621 |         Helpful if the optimizer moves weights away from zero (due to regularization, momentum, etc.).
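        Called at the beginning of every do_step(), before new criteria are accumulated.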
622 |         '''
623 |         for layer, if_prune in enumerate(self.prune_layers):
624 |             if not if_prune:
625 |                 continue
626 | 
627 |             for unit in range(len(self.parameters[layer])):
628 |                 if self.pruning_gates[layer][unit] == 0.0:
629 |                     self.parameters[layer].data[unit] *= 0.0
630 | 
631 |     def compute_hessian(self, loss):
632 |         '''
633 |         Computes a per-layer Hessian of the loss with respect to self.parameters; currently implemented only for gates
634 |         '''
635 | 
636 |         if self.maximum_pruning_iterations <= self.pruning_iterations_done:
637 |             # if reached max number of pruning iterations -> exit
638 |             self.res_pruning = -1
639 |             return -1
640 | 
641 |         self.temp_hessian = list()
642 |         for layer_indx, parameter in enumerate(self.parameters):
643 |             # print("Computing Hessian current/total layers:", layer_indx, "/", len(self.parameters))
644 |             if self.prune_layers[layer_indx]:
645 |                 grad_params = torch.autograd.grad(loss, parameter, create_graph=True)
646 |                 length_grad = len(grad_params[0])
647 |                 hessian = torch.zeros(length_grad, length_grad)
648 | 
649 | 
650 |                 for parameter_loc in range(len(parameter)):
651 |                     if parameter[parameter_loc].data.cpu().numpy().sum() == 0.0:
652 |                         continue
653 | 
654 |                     grad_params2 = torch.autograd.grad(grad_params[0][parameter_loc], parameter, create_graph=True)
655 |                     hessian[parameter_loc, :] = grad_params2[0].data
656 | 
657 |             else:
658 |                 length_grad = len(parameter)
659 |                 hessian = torch.zeros(length_grad, length_grad)
660 | 
661 |             self.temp_hessian.append(torch.FloatTensor(hessian.cpu().numpy()).cuda())
662 | 
663 |     def run_full_oracle(self, model, data, target, criterion, initial_loss):
664 |         '''
665 |         Runs the oracle on the given batch by zeroing out every neuron in turn and running a forward pass
666 |         '''
667 | 
668 |         # stop adding data if needed
669 |         if self.maximum_pruning_iterations <= self.pruning_iterations_done:
670 |             # if reached max number of pruning iterations -> exit
671 |             self.res_pruning = -1
672 |             return -1
673 | 
674 |         if self.method == 40:
675 |             # for the oracle, do the best possible estimate by evaluating all neurons for each batch
676 |             self.oracle_dict["initial_loss"] += initial_loss
677 |             self.oracle_dict["iterations"] += 1
678 | 
679 |             # import pdb; pdb.set_trace()
680 |             if hasattr(self, 'stored_pruning'):
681 |                 if self.stored_pruning['use_now']:
682 |                     # load the first list of criteria
683 |                     print("use previously computed priors")
684 |                     for layer_index, layer_parameters in enumerate(self.parameters):
685 | 
686 |                         # start the list of estimates for the layer if it is empty
687 |                         if len(self.oracle_dict["loss_list"]) < layer_index + 1:
688 |                             self.oracle_dict["loss_list"].append(list())
689 | 
690 |                         if not self.prune_layers[layer_index]:
691 |                             continue
692 | 
693 |                         self.oracle_dict["loss_list"][layer_index] = self.stored_pruning['criteria_start'][layer_index]
694 |                     self.pruned_neurons = self.stored_pruning['neuron_start']
695 |                     return 1
696 | 
697 |             # do the first pass with precomputed values
698 |             for layer_index, layer_parameters in enumerate(self.parameters):
699 |                 # start the list of estimates for the layer if it is empty
700 |                 if len(self.oracle_dict["loss_list"]) < layer_index + 1:
701 |                     self.oracle_dict["loss_list"].append(list())
702 | 
703 |                 if not self.prune_layers[layer_index]:
704 |                     continue
705 |                 # copy the original prune_layers variable that selects layers to be pruned
706 |                 self.prune_layers_oracle = [False, ]*len(self.parameters)
707 |                 self.prune_layers_oracle[layer_index] = True
708 |                 # store the weights to recover them later
709 |                 self.stored_weights = deepcopy(self.parameters[layer_index].data.cpu().numpy())
710 | 
711 |                 for neuron_id, neuron in enumerate(layer_parameters):
712 |                     # set the neuron to zero
713 |                     self.oracle_unit = neuron_id
714 |                     self.set_weights_oracle_pruning()
715 | 
716 |                     if self.stored_weights[neuron_id].sum() == 0.0:
717 |                         new_loss = initial_loss
718 |                     else:
719 |                         outputs = model(data)
720 |                         loss = criterion(outputs, target)
721 |                         new_loss = loss.item()
722 | 
723 |                     # define the importance as the absolute change in loss
724 |                     oracle_value = abs(initial_loss - new_loss)
725 |                     # relative loss for testing:
726 |                     # oracle_value = initial_loss - new_loss
727 | 
728 |                     if len(self.oracle_dict["loss_list"][layer_index]) == 0:
729 |                         self.oracle_dict["loss_list"][layer_index] = [oracle_value, ]
730 |                     elif len(self.oracle_dict["loss_list"][layer_index]) < neuron_id+1:
731 |                         self.oracle_dict["loss_list"][layer_index].append(oracle_value)
732 |                     else:
733 |                         self.oracle_dict["loss_list"][layer_index][neuron_id] += oracle_value
734 | 
735 |                 self.reset_oracle_pruning()
736 | 
737 |         elif self.method == 50:
738 |             if self.pruning_iterations_done == 0:
739 |                 # store weights again
740 |                 self.stored_weights = deepcopy(self.parameters[self.fixed_layer].data.cpu().numpy())
741 | 
742 |             self.set_next_combination()
743 | 
744 |         else:
745 |             pass
746 |             # print("Full oracle only works with the methods: {}".format(40))
747 | 
748 |     def report_loss_neuron(self, training_loss, training_acc, train_writer=None, neurons_left=0):
749 |         '''
750 |         Method stores statistics during pruning to the log file
751 |         :param training_loss:
752 |         :param training_acc:
753 |         :param train_writer:
754 |         :param neurons_left:
755 |         :return:
756 |         void
757 |         '''
758 |         if train_writer is not None:
759 |             train_writer.add_scalar('loss_neuron', training_loss, self.all_neuron_units-self.neuron_units)
760 | 
761 |         self.data_logger["pruning_neurons"].append(self.all_neuron_units-self.neuron_units)
762 |         self.data_logger["pruning_loss"].append(training_loss)
763 |         self.data_logger["pruning_accuracy"].append(training_acc)
764 | 
765 |         self.write_log_file()
766 | 
767 |     def write_log_file(self):
768 |         with open(self.data_logger["filename"], "wb") as f:
769 |             pickle.dump(self.data_logger, f)
770 | 
771 |     def load_mask(self):
772 |         '''Method loads precomputed criteria for pruning
773 |         :return:
774 |         '''
775 |         if not len(self.pruning_mask_from) > 0:
776 |             print("pruning_engine.load_mask(): did not find mask file, will load nothing")
777 |         else:
778 |             if not os.path.isfile(self.pruning_mask_from):
779 |                 print("pruning_engine.load_mask(): file doesn't exist", self.pruning_mask_from)
780 |                 print("pruning_engine.load_mask(): check it, exit,", self.pruning_mask_from)
781 |                 exit()
782 | 
783 |             with open(self.pruning_mask_from, 'rb') as f:
784 |                 self.loaded_mask_criteria = pickle.load(f)
785 | 
786 |             print("pruning_engine.load_mask(): loaded criteria from", self.pruning_mask_from)
787 | 
788 |     def set_next_combination(self):
789 |         '''
790 |         For combinatorial pruning only
791 |         '''
792 |         if self.method == 50:
793 | 
794 |             self.oracle_dict["iterations"] += 1
795 | 
796 |             for layer_index, layer_parameters in enumerate(self.parameters):
797 | 
798 |                 # start the list of estimates for the layer if it is empty
799 |                 if len(self.oracle_dict["loss_list"]) < layer_index + 1:
800 |                     self.oracle_dict["loss_list"].append(list())
801 | 
802 |                 if not self.prune_layers[layer_index]:
803 |                     continue
804 | 
805 |                 nunits = len(layer_parameters)
806 | 
807 |                 comb_num = -1
808 |                 found_combination = False
809 |                 for it in itertools.combinations(range(nunits),
self.starting_neuron): 810 | comb_num += 1 811 | if comb_num == int(self.combination_ID): 812 | found_combination = True 813 | break 814 | 815 | # import pdb; pdb.set_trace() 816 | if not found_combination: 817 | print("didn't find needed combination, exit") 818 | exit() 819 | 820 | self.prune_layers_oracle = self.prune_layers.copy() 821 | self.prune_layers_oracle = [False,]*len(self.parameters) 822 | self.prune_layers_oracle[layer_index] = True 823 | 824 | criteria_for_layer = np.ones((nunits,)) 825 | criteria_for_layer[list(it)] = 0.0 826 | 827 | if len(self.oracle_dict["loss_list"][layer_index]) == 0: 828 | self.oracle_dict["loss_list"][layer_index] = criteria_for_layer 829 | else: 830 | self.oracle_dict["loss_list"][layer_index] += criteria_for_layer 831 | 832 | def report_to_tensorboard(self, train_writer, processed_batches): 833 | ''' 834 | Log data with tensorboard 835 | ''' 836 | gradient_norm_final_before = self.gradient_norm_final 837 | train_writer.add_scalar('Neurons_left', self.neuron_units, processed_batches) 838 | train_writer.add_scalar('Criteria_min', self.min_criteria_value, self.pruning_iterations_done) 839 | train_writer.add_scalar('Criteria_max', self.max_criteria_value, self.pruning_iterations_done) 840 | train_writer.add_scalar('Criteria_median', self.median_criteria_value, self.pruning_iterations_done) 841 | train_writer.add_scalar('Gradient_norm_before', gradient_norm_final_before, self.pruning_iterations_done) 842 | train_writer.add_scalar('Pruning_threshold', self.threshold_now, self.pruning_iterations_done) 843 | 844 | def util_add_loss(self, training_loss_current, training_acc): 845 | # keeps track of current loss 846 | self.util_loss_tracker += training_loss_current 847 | self.util_acc_tracker += training_acc 848 | self.util_loss_tracker_num += 1 849 | self.loss_tracker_exp.update(training_loss_current) 850 | # self.acc_tracker_exp.update(training_acc) 851 | 852 | def do_step(self, loss=None, optimizer=None, neurons_left=0, training_acc=0.0): 853 | ''' 854 | do one step of pruning, 855 | 1) Add importance estimate 856 | 2) checks if loss is above threshold 857 | 3) performs one step of pruning if needed 858 | ''' 859 | self.iter_step += 1 860 | niter = self.iter_step 861 | 862 | # # sets pruned weights to zero 863 | # self.enforce_pruning() 864 | 865 | # stop if pruned maximum amount 866 | if self.maximum_pruning_iterations <= self.pruning_iterations_done: 867 | # exit if we pruned enough 868 | self.res_pruning = -1 869 | return -1 870 | 871 | # sets pruned weights to zero 872 | self.enforce_pruning() 873 | 874 | # compute criteria for given batch 875 | self.add_criteria() 876 | 877 | # small script to keep track of training loss since the last pruning 878 | self.util_add_loss(loss, training_acc) 879 | 880 | if ((niter-1) % self.frequency == 0) and (niter != 0) and (self.res_pruning==1): 881 | self.report_loss_neuron(self.util_training_loss, training_acc=self.util_training_acc, train_writer=self.train_writer, neurons_left=neurons_left) 882 | 883 | if niter % self.frequency == 0 and niter != 0: 884 | # do actual pruning, output: 1 - good, 0 - no pruning 885 | 886 | self.compute_saliency() 887 | self.set_momentum_zero_sgd(optimizer=optimizer) 888 | 889 | training_loss = self.util_training_loss 890 | if self.res_pruning == 1: 891 | print("Pruning: Units", self.neuron_units, "/", self.all_neuron_units, "loss", training_loss, "Zeroed", self.pruned_neurons, "criteria min:{}/max:{:2.7f}".format(self.min_criteria_value,self.max_criteria_value)) 892 | 893 | def 
set_momentum_zero_sgd(self, optimizer=None):
894 |         '''
895 |         Method sets the momentum buffer to zero for pruned neurons. Supports SGD only.
896 |         :return:
897 |         void
898 |         '''
899 |         for layer in range(len(self.pruning_gates)):
900 |             if not self.prune_layers[layer]:
901 |                 continue
902 |             for unit in range(len(self.pruning_gates[layer])):
903 |                 if self.pruning_gates[layer][unit]:
904 |                     continue  # only neurons with closed (zeroed) gates get their momentum reset
905 |                 if 'momentum_buffer' in optimizer.state[self.parameters[layer]].keys():
906 |                     optimizer.state[self.parameters[layer]]['momentum_buffer'][unit] *= 0.0
907 | 
908 |     def connect_tensorboard(self, tensorboard):
909 |         '''
910 |         Function connects tensorboard to the pruning engine
911 |         '''
912 |         self.tensorboard = True
913 |         self.train_writer = tensorboard
914 | 
915 |     def update_flops(self, stats=None):
916 |         '''
917 |         Function updates per-layer flops for potential regularization
918 |         :param stats: a list of flops per parameter
919 |         :return:
920 |         '''
921 |         self.per_layer_flops = list()
922 |         if len(stats["flops"]) < 1:
923 |             return -1
924 |         for pruning_param in self.gates_to_params:
925 |             if isinstance(pruning_param, list):
926 |                 # the parameter spans many blocks, aggregate over them
927 |                 self.per_layer_flops.append(sum([stats['flops'][a] for a in pruning_param]))
928 |             else:
929 |                 self.per_layer_flops.append(stats['flops'][pruning_param])
930 | 
931 |     def apply_flops_regularization(self, groups, mu=0.1):
932 |         '''
933 |         Function applies FLOPs regularization to the computed importance per layer
934 |         :param groups: a list of groups organized per layer
935 |         :param mu: regularization coefficient
936 |         :return:
937 |         '''
938 |         if len(self.per_layer_flops) < 1:
939 |             return -1
940 | 
941 |         for layer_id, layer in enumerate(groups):
942 |             for group in layer:
943 |                 # import pdb; pdb.set_trace()
944 |                 total_neurons = len(group[0])
945 |                 group[1] = group[1] - mu*(self.per_layer_flops[layer_id]*total_neurons)
946 | 
947 | 
948 | def prepare_pruning_list(pruning_settings, model, model_name, pruning_mask_from='', name=''):
949 |     '''
950 |     Function returns a list of parameters from the model to be considered for pruning.
951 |     Depending on the pruning method and strategy, different parameters are selected (conv kernels, BN parameters etc.)
952 |     :param pruning_settings:
953 |     :param model:
954 |     :return:
955 |     '''
956 |     # Function creates a list of layers that will be pruned based on user selection
957 | 
958 |     ADD_BY_GATES = True  # gates are added artificially; they have weight == 1 and are not trained, but their gradient is important, see models/lenet.py
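    # A gate layer (layers/gate_layer.py) is, roughly, a per-channel multiplicative
    # identity flagged with `do_not_update` so optimizers skip it; a minimal sketch,
    # not the exact implementation:
    #
    #   class GateLayer(nn.Module):
    #       def __init__(self, size):
    #           super(GateLayer, self).__init__()
    #           self.do_not_update = True
    #           self.weight = nn.Parameter(torch.ones(size))
    #       def forward(self, input):
    #           return input * self.weight.view(1, -1, 1, 1)
    #
    # Pruning a neuron then amounts to zeroing one entry of `weight`.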
959 |     ADD_BY_WEIGHTS = ADD_BY_BN = False
960 | 
961 |     pruning_method = pruning_settings['method']
962 | 
963 |     pruning_parameters_list = list()
964 |     if ADD_BY_GATES:
965 | 
966 |         first_step = True
967 |         prev_module = None
968 |         prev_module2 = None
969 |         print("network structure")
970 |         for module_indx, m in enumerate(model.modules()):
971 |             # print(module_indx, m)
972 |             if hasattr(m, "do_not_update"):
973 |                 m_to_add = m
974 | 
975 |                 if (pruning_method != 23) and (pruning_method != 6):
976 |                     for_pruning = {"parameter": m_to_add.weight, "layer": m_to_add,
977 |                                    "compute_criteria_from": m_to_add.weight}
978 |                 else:
979 |                     def just_hook(self, grad_input, grad_output):
980 |                         # get the full (per-example) gradient for the parameters;
981 |                         # a normal backward pass provides only the gradient averaged over the batch,
982 |                         # and this requires the output of the layer to be stored
983 |                         if len(grad_output[0].shape) == 4:
984 |                             self.weight.full_grad = (grad_output[0] * self.output).sum(-1).sum(-1)
985 |                         else:
986 |                             self.weight.full_grad = (grad_output[0] * self.output)
987 | 
988 |                     if pruning_method == 6:
989 |                         # implement the ICLR2017 paper criterion
990 |                         def just_hook(self, grad_input, grad_output):
991 |                             if len(grad_output[0].shape) == 4:
992 |                                 self.weight.full_grad_iclr2017 = (grad_output[0] * self.output).abs().mean(-1).mean(
993 |                                     -1).mean(0)
994 |                             else:
995 |                                 self.weight.full_grad_iclr2017 = (grad_output[0] * self.output).abs().mean(0)
996 | 
997 |                     def forward_hook(self, input, output):
998 |                         self.output = output
999 | 
1000 |                     if not len(pruning_mask_from) > 0:
1001 |                         # register the hooks only if the mask is not precomputed
1002 |                         m_to_add.register_forward_hook(forward_hook)
1003 |                         m_to_add.register_backward_hook(just_hook)
1004 | 
1005 |                     for_pruning = {"parameter": m_to_add.weight, "layer": m_to_add,
1006 |                                    "compute_criteria_from": m_to_add.weight}
1007 | 
1008 |                 if pruning_method in [30, 31]:
1009 |                     # for DenseNets.
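                    # Methods 30/31 score channels on the BatchNorm layer that precedes the
                    # gate (30: |BN scale|, 31: first-order Taylor on the BN scale and bias,
                    # see add_criteria above), so the BN weight replaces the gate weight as
                    # the parameter the criteria are computed from.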
1010 | # add previous layer's value for batch norm pruning 1011 | 1012 | if isinstance(prev_module, nn.BatchNorm2d): 1013 | m_to_add = prev_module 1014 | print(m_to_add, "yes") 1015 | else: 1016 | print(m_to_add, "no") 1017 | 1018 | for_pruning = {"parameter": m_to_add.weight, "layer": m_to_add, 1019 | "compute_criteria_from": m_to_add.weight} 1020 | 1021 | if pruning_method in [24, ]: 1022 | # add previous layer's value for batch norm pruning 1023 | 1024 | if isinstance(prev_module, nn.Conv2d): 1025 | m_to_add = prev_module 1026 | 1027 | for_pruning = {"parameter": m_to_add.weight, "layer": m_to_add, 1028 | "compute_criteria_from": m_to_add.weight} 1029 | 1030 | if pruning_method in [0, 2, 3]: 1031 | # add previous layer's value for batch norm pruning 1032 | 1033 | if isinstance(prev_module2, nn.Conv2d): 1034 | print(module_indx, prev_module2, "yes") 1035 | m_to_add = prev_module2 1036 | elif isinstance(prev_module2, nn.Linear): 1037 | print(module_indx, prev_module2, "yes") 1038 | m_to_add = prev_module2 1039 | elif isinstance(prev_module, nn.Conv2d): 1040 | print(module_indx, prev_module, "yes") 1041 | m_to_add = prev_module 1042 | else: 1043 | print(module_indx, m, "no") 1044 | 1045 | for_pruning = {"parameter": m_to_add.weight, "layer": m_to_add, 1046 | "compute_criteria_from": m_to_add.weight} 1047 | 1048 | pruning_parameters_list.append(for_pruning) 1049 | prev_module2 = prev_module 1050 | prev_module = m 1051 | 1052 | if model_name == "resnet20": 1053 | # prune only even layers as in Rethinking min norm pruning 1054 | pruning_parameters_list = [d for di, d in enumerate(pruning_parameters_list) if (di % 2 == 1 and di > 0)] 1055 | 1056 | if ("prune_only_skip_connections" in name) and 1: 1057 | # will prune only skip connections (gates around them). Works with ResNets only 1058 | pruning_parameters_list = pruning_parameters_list[:4] 1059 | 1060 | return pruning_parameters_list 1061 | 1062 | class ExpMeter(object): 1063 | """Computes and stores the average and current value""" 1064 | def __init__(self, mom = 0.9): 1065 | self.reset() 1066 | self.mom = mom 1067 | 1068 | def reset(self): 1069 | self.val = 0 1070 | self.avg = 0 1071 | self.sum = 0 1072 | self.count = 0 1073 | self.exp_avg = 0 1074 | 1075 | def update(self, val, n=1): 1076 | self.val = val 1077 | self.sum += val * n 1078 | self.count += n 1079 | self.mean_avg = self.sum / self.count 1080 | self.exp_avg = self.mom*self.exp_avg + (1.0 - self.mom)*self.val 1081 | if self.count == 1: 1082 | self.exp_avg = self.val 1083 | 1084 | if __name__ == '__main__': 1085 | pass 1086 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>2.0.0 2 | numpy>=1.0.0 3 | protobuf>=3.0.0 4 | tensorboardX>=1.3 5 | torch>=1.0.1 6 | torchvision -------------------------------------------------------------------------------- /utils/group_lasso_optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from torch.optim.optimizer import Optimizer 7 | 8 | PRINT_ALL = False 9 | USE_FULL = False 10 | 11 | class group_lasso_decay(Optimizer): 12 | r"""Implements group lasso weight decay (GLWD) that pushed entire group to 0. 
13 | Normal weight decay makes weight sparse, GLWD will make sparse channel. 14 | Assumes we want to decay the group related to feature maps (channels), other groups are possible but not implemented 15 | """ 16 | 17 | def __init__(self, params, group_lasso_weight=0, named_parameters=None, output_sizes=None): 18 | defaults = dict(group_lasso_weight = group_lasso_weight, total_neurons = 0) 19 | 20 | super(group_lasso_decay, self).__init__(params, defaults) 21 | 22 | self.per_layer_per_neuron_stats = {'flops': list(), 'params': list(), 'latency': list()} 23 | 24 | self.named_parameters = named_parameters 25 | 26 | self.output_sizes = None 27 | if output_sizes is not None: 28 | self.output_sizes = output_sizes 29 | 30 | def __setstate__(self, state): 31 | super(group_lasso_decay, self).__setstate__(state) 32 | 33 | def get_number_neurons(self, print_output = False): 34 | total_neurons = 0 35 | for gr_ind, group in enumerate(self.param_groups): 36 | total_neurons += group['total_neurons'].item() 37 | # if print_output: 38 | # print("Total parameters: ",total_neurons, " or ", total_neurons/1e7, " times 1e7") 39 | return total_neurons 40 | 41 | def get_number_flops(self, print_output = False): 42 | total_flops = 0 43 | for gr_ind, group in enumerate(self.param_groups): 44 | total_flops += group['total_flops'].item() 45 | 46 | total_neurons = 0 47 | for gr_ind, group in enumerate(self.param_groups): 48 | total_neurons += group['total_neurons'].item() 49 | 50 | if print_output: 51 | # print("Total flops: ",total_flops, " or ", total_flops/1e9, " times 1e9") 52 | 53 | print("Flops 1e 9/params 1e7: %3.3f & %3.3f"%(total_flops/1e9, total_neurons/1e7)) 54 | return total_flops 55 | 56 | def step(self, closure=None): 57 | """Applies GLWD regularization to weights. 58 | Arguments: 59 | closure (callable, optional): A closure that reevaluates the model 60 | and returns the loss. 61 | """ 62 | loss = None 63 | if closure is not None: 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | group_lasso_weight = group['group_lasso_weight'] 68 | group['total_neurons'] = 0 69 | group['total_flops'] = 0 70 | 71 | for p in group['params']: 72 | param_state = self.state[p] 73 | 74 | if 1: 75 | 76 | weight_size = p.data.size() 77 | 78 | if ('group_lasso_coeff' not in param_state) or (param_state['group_lasso_coeff'].shape != p.data.shape): 79 | nunits = p.data.size(0) 80 | param_state['group_lasso_coeff'] = p.data.clone().view(nunits,-1).sum(dim=1)*0.0 + group_lasso_weight 81 | 82 | group_lasso_weight_local = param_state['group_lasso_coeff'] 83 | 84 | if len(weight_size) == 4: 85 | # defined for conv layers only 86 | nunits = p.data.size(0) 87 | 88 | # let's compute denominator 89 | divider = p.data.pow(2).view(nunits,-1).sum(dim=1).pow(0.5) 90 | 91 | eps = 1e-5 92 | eps2 = 1e-13 93 | # check if divider is above threshold 94 | divider_bool = divider.gt(eps).view(-1).float() 95 | 96 | group_lasso_gradient = p.data * (group_lasso_weight_local * divider_bool / 97 | (divider + eps2)).view(nunits, 1, 1, 1).repeat(1, weight_size[1],weight_size[2],weight_size[3]) 98 | 99 | # apply weight decay step: 100 | p.data.add_(-1.0, group_lasso_gradient) 101 | return loss 102 | 103 | def step_after(self, closure=None): 104 | """Computes FLOPS and number of neurons after considering zeroed out input and outputs. 105 | Channels are assumed to be pruned if their l2 norm is very small or if magnitude of gradient is very small. 106 | This function does not perform weight pruning, weights are untouched. 
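        (Channels whose weight norm is already below the threshold are snapped to exact zero;
        no additional channels are removed here.)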
107 | 
108 |         This function also calls push_biases_down, which sets the corresponding biases to 0.
109 | 
110 |         Arguments:
111 |             closure (callable, optional): A closure that reevaluates the model
112 |                 and returns the loss.
113 |         """
114 |         loss = None
115 |         if closure is not None:
116 |             loss = closure()
117 | 
118 |         param_index = -1
119 |         conv_param_index = -1
120 |         for group in self.param_groups:
121 |             group_lasso_weight = group['group_lasso_weight']
122 |             group['total_neurons'] = 0
123 |             group['total_flops'] = 0
124 | 
125 |             for p in group['params']:
126 | 
127 |                 if group_lasso_weight != 0 or 1:
128 |                     weight_size = p.data.size()
129 | 
130 |                     if (len(weight_size) == 4) or (len(weight_size) == 2) or (len(weight_size) == 1):
131 |                         param_index += 1
132 |                         # defined for conv, linear and BN layers
133 |                         nunits = p.data.size(0)
134 |                         # let's compute the denominator
135 |                         divider = p.data.pow(2).view(nunits,-1).sum(dim=1).pow(0.5)
136 | 
137 |                         eps = 1e-4
138 |                         # check if the divider is above the threshold
139 |                         divider_bool = divider.gt(eps).view(-1).float()
140 | 
141 |                     if (len(weight_size) == 4) or (len(weight_size) == 2) or (len(weight_size) == 1):
142 |                         if not (p.grad is None):
143 |                             # consider gradients as well: if the gradient is below a specific threshold then we claim the parameter to be removed
144 |                             divider_grad = p.grad.data.pow(2).view(nunits, -1).sum(dim=1).pow(0.5)
145 |                             eps = 1e-8
146 |                             divider_bool_grad = divider_grad.gt(eps).view(-1).float()
147 |                             divider_bool = divider_bool_grad * divider_bool
148 | 
149 |                     if (len(weight_size) == 4) or (len(weight_size) == 2):
150 |                         # get the gradient for the input:
151 |                         divider_grad_input = p.grad.data.pow(2).transpose(0,1).contiguous().view(p.data.size(1),-1).sum(dim=1).pow(0.5)
152 |                         divider_bool_grad_input = divider_grad_input.gt(eps).view(-1).float()
153 | 
154 |                         divider_input = p.data.pow(2).transpose(0,1).contiguous().view(p.data.size(1), -1).sum(dim=1).pow(0.5)
155 |                         divider_bool_input = divider_input.gt(eps).view(-1).float()
156 |                         divider_bool_input = divider_bool_input * divider_bool_grad_input
157 |                         # an input channel with a small gradient is considered removed as well
158 | 
159 |                     if USE_FULL:
160 |                         # reset to evaluate the true number of flops and neurons
161 |                         # useful for the full network only
162 |                         divider_bool = 0.0*divider_bool + 1.0
163 |                         divider_bool_input = 0.0*divider_bool_input + 1.0
164 | 
165 |                     if len(weight_size) == 4:
166 |                         p.data.mul_(divider_bool.view(nunits, 1, 1, 1).repeat(1, weight_size[1], weight_size[2], weight_size[3]))
167 |                         current_neurons = divider_bool.sum()*divider_bool_input.sum()*weight_size[2]*weight_size[3]
168 | 
169 |                     if len(weight_size) == 2:
170 |                         current_neurons = divider_bool.sum()*divider_bool_input.sum()
171 | 
172 |                     if len(weight_size) == 1:
173 |                         current_neurons = divider_bool.sum()
174 |                         # add mean and var over batches
175 |                         current_neurons = current_neurons + divider_bool.sum()
176 | 
177 |                     group['total_neurons'] += current_neurons
178 | 
179 |                     if len(weight_size) == 4:
180 |                         conv_param_index += 1
181 |                         input_channels = divider_bool_input.sum()
182 |                         output_channels = divider_bool.sum()
183 | 
184 |                         if self.output_sizes is not None:
185 |                             output_height, output_width = self.output_sizes[conv_param_index][-2:]
186 |                         else:
187 |                             if hasattr(p, 'output_dims'):
188 |                                 output_height, output_width = p.output_dims[-2:]
189 |                             else:
190 |                                 output_height, output_width = 0, 0
191 | 
192 |                         kernel_ops = weight_size[2] * weight_size[3] * input_channels
193 | 
194 |                         params = output_channels * kernel_ops
195 |                         flops = params * output_height * output_width
196 | 
197 |                         # add flops due to batch normalization
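                        # The conv estimate above is flops = C_out * (k_h*k_w*C_in) * H_out * W_out,
                        # counted over surviving channels only; e.g. a 3x3 conv with 64 -> 128
                        # remaining channels on a 56x56 output map gives 128*(3*3*64)*56*56,
                        # roughly 2.3e8 multiply-adds. The term below adds about 3 ops per
                        # output pixel for batch normalization.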
198 | flops = flops + output_height * output_width*3 199 | 200 | group['total_flops'] = group['total_flops'] + flops 201 | 202 | if len(weight_size) == 1: 203 | flops = len(weight_size) 204 | group['total_flops'] = group['total_flops'] + flops 205 | if len(weight_size) == 2: 206 | input_channels = divider_bool_input.sum() 207 | output_channels = divider_bool.sum() 208 | flops = input_channels * output_channels 209 | group['total_flops'] = group['total_flops'] + flops 210 | 211 | if len(self.per_layer_per_neuron_stats['flops']) <= param_index: 212 | self.per_layer_per_neuron_stats['flops'].append(flops / divider_bool.sum()) 213 | self.per_layer_per_neuron_stats['params'].append(current_neurons / divider_bool.sum()) 214 | # self.per_layer_per_neuron_stats['latency'].append(flops / output_channels) 215 | else: 216 | self.per_layer_per_neuron_stats['flops'][param_index] = flops / divider_bool.sum() 217 | self.per_layer_per_neuron_stats['params'][param_index] = current_neurons / divider_bool.sum() 218 | # self.per_layer_per_neuron_stats['latency'][param_index] = flops / output_channels 219 | 220 | self.push_biases_down(eps=1e-3) 221 | return loss 222 | 223 | def push_biases_down(self, eps=1e-3): 224 | ''' 225 | This function goes over parameters and sets according biases to zero, 226 | without this function biases will not be zero 227 | ''' 228 | # first pass 229 | list_of_names = [] 230 | for name, param in self.named_parameters: 231 | if "weight" in name: 232 | weight_size = param.data.shape 233 | if (len(weight_size) == 4) or (len(weight_size) == 2): 234 | # defined for conv layers only 235 | nunits = weight_size[0] 236 | # let's compute denominator 237 | divider = param.data.pow(2).view(nunits, -1).sum(dim=1).pow(0.5) 238 | divider_bool = divider.gt(eps).view(-1).float() 239 | list_of_names.append((name.replace("weight", "bias"), divider_bool)) 240 | 241 | # second pass 242 | for name, param in self.named_parameters: 243 | if "bias" in name: 244 | for ind in range(len(list_of_names)): 245 | if list_of_names[ind][0] == name: 246 | param.data.mul_(list_of_names[ind][1]) 247 | 248 | 249 | 250 | 251 | 252 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import torch 7 | 8 | 9 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 10 | """Saves checkpoint to disk""" 11 | try: 12 | torch.save(state, filename) 13 | if is_best: 14 | torch.save(state, filename.replace("checkpoint", "best_model")) 15 | except: 16 | print("didn't save checkpoint file") 17 | 18 | 19 | def adjust_learning_rate(args, optimizer, epoch, zero_lr_for_epochs, train_writer): 20 | """Sets the learning rate to the initial LR decayed by 10 after 150 and 225 epochs""" 21 | lr = args.lr * (args.lr_decay_scalar ** (epoch // args.lr_decay_every)) # * (0.1 ** (epoch // (2*args.lr_decay_every))) 22 | if zero_lr_for_epochs > -1: 23 | if epoch <= zero_lr_for_epochs: 24 | lr = 0.0 25 | # log to TensorBoard 26 | if args.tensorboard: 27 | train_writer.add_scalar('learning_rate', lr, epoch) 28 | print("learning rate adjusted:", lr, epoch) 29 | for param_group in optimizer.param_groups: 30 | param_group['lr'] = lr 31 | 32 | 33 | def adjust_learning_rate_fixed(args, optimizer, epoch, zero_lr_for_epochs, train_writer): 34 | """Sets the learning rate to the initial LR decayed by 10 after 150 and 225 epochs""" 35 | lr = args.lr * (args.lr_decay_scalar ** (epoch // args.lr_decay_every)) 36 | if zero_lr_for_epochs > -1: 37 | if epoch <= zero_lr_for_epochs: 38 | lr = 0.0 39 | # log to TensorBoard 40 | if epoch == args.lr_decay_every: 41 | lr_scale = args.lr_decay_scalar 42 | lr = args.lr * (args.lr_decay_scalar ** (epoch // args.lr_decay_every)) 43 | if args.tensorboard: 44 | train_writer.add_scalar('learning_rate', lr, epoch) 45 | print("learning rate adjusted:", lr, epoch) 46 | for param_group in optimizer.param_groups: 47 | param_group['lr'] = param_group['lr']*lr_scale 48 | 49 | 50 | class AverageMeter(object): 51 | """Computes and stores the average and current value""" 52 | def __init__(self): 53 | self.reset() 54 | 55 | def reset(self): 56 | self.val = 0 57 | self.avg = 0 58 | self.sum = 0 59 | self.count = 0 60 | 61 | def update(self, val, n=1): 62 | self.val = val 63 | self.sum += val * n 64 | self.count += n 65 | self.avg = self.sum / self.count 66 | 67 | 68 | def accuracy(output, target, topk=(1,)): 69 | """Computes the precision@k for the specified values of k""" 70 | maxk = max(topk) 71 | batch_size = target.size(0) 72 | 73 | _, pred = output.topk(maxk, 1, True, True) 74 | pred = pred.t() 75 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 76 | 77 | res = [] 78 | for k in topk: 79 | correct_k = correct[:k].view(-1).float().sum(0) 80 | res.append(correct_k.mul_(100.0 / batch_size)) 81 | return res 82 | 83 | 84 | def load_model_pytorch(model, load_model, model_name): 85 | print("=> loading checkpoint '{}'".format(load_model)) 86 | checkpoint = torch.load(load_model) 87 | 88 | if 'state_dict' in checkpoint.keys(): 89 | load_from = checkpoint['state_dict'] 90 | else: 91 | load_from = checkpoint 92 | 93 | # match_dictionaries, useful if loading model without gate: 94 | if 'module.' in list(model.state_dict().keys())[0]: 95 | if 'module.' not in list(load_from.keys())[0]: 96 | from collections import OrderedDict 97 | 98 | load_from = OrderedDict([("module.{}".format(k), v) for k, v in load_from.items()]) 99 | 100 | if 'module.' not in list(model.state_dict().keys())[0]: 101 | if 'module.' 
102 |             from collections import OrderedDict
103 | 
104 |             load_from = OrderedDict([(k.replace("module.", ""), v) for k, v in load_from.items()])
105 | 
106 |     # just for vgg
107 |     if model_name == "vgg":
108 |         from collections import OrderedDict
109 | 
110 |         load_from = OrderedDict([(k.replace("features.", "features"), v) for k, v in load_from.items()])
111 |         load_from = OrderedDict([(k.replace("classifier.", "classifier"), v) for k, v in load_from.items()])
112 | 
113 |     if 1:  # debug: print the first few keys of both state dicts
114 |         for ind, (key, item) in enumerate(model.state_dict().items()):
115 |             if ind > 10:
116 |                 continue
117 |             print(key, model.state_dict()[key].shape)
118 | 
119 |         print("*********")
120 | 
121 |         for ind, (key, item) in enumerate(load_from.items()):
122 |             if ind > 10:
123 |                 continue
124 |             print(key, load_from[key].shape)
125 | 
126 |     for key, item in model.state_dict().items():
127 |         # if we add a gate that is not in the saved file
128 |         if key not in load_from:
129 |             load_from[key] = item
130 |         # if loading a pretrained model with mismatching shapes
131 |         if load_from[key].shape != item.shape:
132 |             load_from[key] = item
133 | 
134 |     model.load_state_dict(load_from, strict=True)
135 | 
136 | 
137 |     epoch_from = -1
138 |     if 'epoch' in checkpoint.keys():
139 |         epoch_from = checkpoint['epoch']
140 |     print("=> loaded checkpoint '{}' (epoch {})"
141 |           .format(load_model, epoch_from))
142 | 
143 | 
144 | def dynamic_network_change_local(model):
145 |     '''
146 |     Method attempts to modify the network in place by removing pruned filters.
147 |     Currently works with ResNet-101 only.
148 |     :param model: reference to torch model to be modified
149 |     :return:
150 |     '''
151 |     # change network dynamically given a pruning mask
152 | 
153 |     # step 1: model adjustment
154 |     # let's go layer by layer and get the mask if we have a parameter in pruning settings:
155 | 
156 |     pruning_maks_input = None
157 |     prev_model1 = None
158 |     prev_model2 = None
159 |     prev_model3 = None
160 | 
161 |     pruning_mask_indexes = None
162 |     gate_track = -1
163 | 
164 |     skip_connections = list()
165 | 
166 |     current_skip = 0
167 |     DO_SKIP = True
168 | 
169 |     gate_size = -1
170 |     current_skip_mask_size = -1
171 | 
172 |     for module_indx, m in enumerate(model.modules()):
173 |         pruning_mask_indexes = None
174 |         if not hasattr(m, "do_not_update"):
175 |             if isinstance(m, torch.nn.Conv2d):
176 |                 print("interm layer", gate_track, module_indx)
177 |                 print(m)
178 |                 if 1:
179 |                     if pruning_maks_input is not None:
180 |                         print("fixing interm layer", gate_track, module_indx)
181 | 
182 |                         m.weight.data = m.weight.data[:, pruning_maks_input]
183 | 
184 |                         pruning_maks_input = None
185 |                         print("weight size now", m.weight.data.shape)
186 |                         m.in_channels = m.weight.data.shape[1]
187 |                         m.out_channels = m.weight.data.shape[0]
188 | 
189 |                 if DO_SKIP:
190 |                     print("doing skip connection")
191 |                     if m.weight.data.shape[0] == current_skip_mask_size:
192 |                         m.weight.data = m.weight.data[current_skip_mask]
193 |                         print("weight size after skip", m.weight.data.shape)
194 | 
195 |                 if module_indx in [23, 63, 115, 395]:
196 |                     if DO_SKIP:
197 |                         print("fixing interm skip")
198 |                         m.weight.data = m.weight.data[:, prev_skip_mask]
199 | 
200 |                         print("weight size now", m.weight.data.shape)
201 |                         m.in_channels = m.weight.data.shape[1]
202 |                         m.out_channels = m.weight.data.shape[0]
203 | 
204 |                     if DO_SKIP:
205 |                         print("doing skip connection")
206 |                         if m.weight.data.shape[0] == current_skip_mask_size:
207 |                             m.weight.data = m.weight.data[current_skip_mask]
208 |                             print("weight size after skip", m.weight.data.shape)
209 | 
210 |             if isinstance(m,
torch.nn.BatchNorm2d): 211 | print("interm layer BN: ", gate_track, module_indx) 212 | print(m) 213 | if DO_SKIP: 214 | print("doing skip connection") 215 | if m.weight.data.shape[0] == current_skip_mask_size: 216 | m.weight.data = m.weight.data[current_skip_mask] 217 | print("weight size after skip", m.weight.data.shape) 218 | 219 | # m.weight.data = m.weight.data[current_skip_mask] 220 | m.bias.data = m.bias.data[current_skip_mask] 221 | m.running_mean.data = m.running_mean.data[current_skip_mask] 222 | m.running_var.data = m.running_var.data[current_skip_mask] 223 | else: 224 | # keeping track of gates: 225 | gate_track += 1 226 | 227 | then_pass = False 228 | if gate_track < 4: 229 | # skipping skip connections 230 | then_pass = True 231 | skip_connections.append(m.weight) 232 | current_skip = -1 233 | 234 | if not then_pass: 235 | pruning_mask = m.weight 236 | if gate_size!=m.weight.shape[0]: 237 | current_skip += 1 238 | current_skip_mask_size = skip_connections[current_skip].data.shape[0] 239 | 240 | if skip_connections[current_skip].data.shape[0] != 2048: 241 | current_skip_mask = skip_connections[current_skip].data.nonzero().view(-1) 242 | else: 243 | current_skip_mask = (skip_connections[current_skip].data + 1.0).nonzero().view(-1) 244 | prev_skip_mask_size = 64 245 | prev_skip_mask = range(64) 246 | if current_skip > 0: 247 | prev_skip_mask_size = skip_connections[current_skip - 1].data.shape[0] 248 | prev_skip_mask = skip_connections[current_skip - 1].data.nonzero().view(-1) 249 | 250 | gate_size = m.weight.shape[0] 251 | 252 | if 1: 253 | print("fixing layer", gate_track, module_indx) 254 | print(m) 255 | print(pruning_mask) 256 | 257 | if 1.0 in pruning_mask: 258 | pruning_mask_indexes = pruning_mask.nonzero().view(-1) 259 | else: 260 | pruning_mask_indexes = [] 261 | m.weight.data = m.weight.data[pruning_mask_indexes] 262 | for prev_model in [prev_model1, prev_model2, prev_model3]: 263 | if isinstance(prev_model, torch.nn.Conv2d): 264 | print("prev fixing layer", prev_model, gate_track, module_indx) 265 | prev_model.weight.data = prev_model.weight.data[pruning_mask_indexes] 266 | print("weight size", prev_model.weight.data.shape) 267 | 268 | if DO_SKIP: 269 | print("doing skip connection") 270 | 271 | if prev_model.weight.data.shape[1] == current_skip_mask_size: 272 | prev_model.weight.data = prev_model.weight.data[:, current_skip_mask] 273 | print("weight size", prev_model.weight.data.shape) 274 | 275 | if module_indx in [53, 105, 385]: # add one more layer for this transition 276 | print("doing skip connection") 277 | 278 | if prev_model.weight.data.shape[1] == prev_skip_mask_size: 279 | prev_model.weight.data = prev_model.weight.data[:, prev_skip_mask] 280 | print("weight size", prev_model.weight.data.shape) 281 | 282 | if isinstance(prev_model, torch.nn.BatchNorm2d): 283 | print("prev fixing layer", prev_model, gate_track, module_indx) 284 | prev_model.weight.data = prev_model.weight.data[pruning_mask_indexes] 285 | prev_model.bias.data = prev_model.bias.data[pruning_mask_indexes] 286 | prev_model.running_mean.data = prev_model.running_mean.data[pruning_mask_indexes] 287 | prev_model.running_var.data = prev_model.running_var.data[pruning_mask_indexes] 288 | 289 | pruning_maks_input = pruning_mask_indexes 290 | 291 | prev_model3 = prev_model2 292 | prev_model2 = prev_model1 293 | prev_model1 = m 294 | 295 | if DO_SKIP: 296 | # fix gate layers 297 | gate_track = 0 298 | 299 | for module_indx, m in enumerate(model.modules()): 300 | if hasattr(m, "do_not_update"): 301 | 
gate_track += 1 302 | if gate_track < 4: 303 | if m.weight.shape[0] < 2048: 304 | m.weight.data = m.weight.data[m.weight.nonzero().view(-1)] 305 | 306 | print("printing conv layers") 307 | for module_indx, m in enumerate(model.modules()): 308 | if isinstance(m, torch.nn.Conv2d): 309 | print(module_indx, "->", m.weight.data.shape) 310 | 311 | print("printing bn layers") 312 | for module_indx, m in enumerate(model.modules()): 313 | if isinstance(m, torch.nn.BatchNorm2d): 314 | print(module_indx, "->", m.weight.data.shape) 315 | 316 | print("printing gate layers") 317 | for module_indx, m in enumerate(model.modules()): 318 | if hasattr(m, "do_not_update"): 319 | print(module_indx, "->", m.weight.data.shape, m.size_mask) 320 | 321 | 322 | def add_hook_for_flops(args, model): 323 | # add output dims for FLOPs computation 324 | if 1: 325 | for module_indx, m in enumerate(model.modules()): 326 | if isinstance(m, torch.nn.Conv2d): 327 | def forward_hook(self, input, output): 328 | self.weight.output_dims = output.shape 329 | 330 | m.register_forward_hook(forward_hook) 331 | 332 | 333 | def get_conv_sizes(args, model): 334 | output_sizes = None 335 | if args.compute_flops: 336 | # add hooks to compute dimensions of the output tensors for conv layers 337 | add_hook_for_flops(args, model) 338 | if 1: 339 | if args.dataset=="CIFAR10": 340 | dummy_input = torch.rand(1, 3, 32, 32) 341 | elif args.dataset=="Imagenet": 342 | dummy_input = torch.rand(1, 3, 224, 224) 343 | # run inference 344 | with torch.no_grad(): 345 | model(dummy_input) 346 | # store flops 347 | output_sizes = list() 348 | for param in model.parameters(): 349 | if hasattr(param, 'output_dims'): 350 | output_dims = param.output_dims 351 | output_sizes.append(output_dims) 352 | 353 | return output_sizes 354 | 355 | 356 | def connect_gates_with_parameters_for_flops(model_name, named_parameters): 357 | ''' 358 | Function creates a mapping between gates and parameter index to map flops 359 | :return: 360 | returns a list with mapping, each element is a gate id, entries are corresponding parameters 361 | ''' 362 | if "resnet" not in model_name: 363 | print("connect_gates_with_parameters_for_flops only supports resnet for now") 364 | return -1 365 | 366 | # for skip connections, first 4 for them: 367 | gate_to_param_map = [list() for _ in range(4)] 368 | 369 | for param_id, (name, param) in enumerate(named_parameters): 370 | if "layer" not in name: 371 | # skip because we don't prune the first layer 372 | continue 373 | 374 | if ("conv1" in name) or ("conv2" in name): 375 | gate_to_param_map.append(param_id) 376 | 377 | if "conv3" in name: 378 | # third convolution contributes to skip connection only 379 | skip_block_id = int(name[name.find("layer")+len("layer")]) 380 | gate_to_param_map[skip_block_id-1].append(param_id) 381 | 382 | return gate_to_param_map --------------------------------------------------------------------------------
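A minimal usage sketch tying the pieces above together (hypothetical: the engine's constructor and the config reader live in the part of pruning_engine.py not shown here, so `PruningEngine` and `pruning_settings` below are assumed names; do_step(), prepare_pruning_list() and the gate mechanism are the real entry points):

    # build the list of gate parameters to prune, then drive pruning from the train loop
    pruning_parameters = prepare_pruning_list(pruning_settings, model, model_name="resnet50")
    pruning_engine = PruningEngine(pruning_parameters, pruning_settings)  # assumed constructor

    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        # accumulates criteria every batch and prunes once per `frequency` iterations
        pruning_engine.do_step(loss=loss.item(), optimizer=optimizer)
        optimizer.step()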