├── .gitignore
├── LICENSE.md
├── README.md
├── configs
│   ├── imagenet_resnet101_prune100.json
│   ├── imagenet_resnet101_prune40.json
│   ├── imagenet_resnet101_prune50.json
│   ├── imagenet_resnet101_prune50_group.json
│   ├── imagenet_resnet101_prune55.json
│   ├── imagenet_resnet101_prune75.json
│   ├── imagenet_resnet50_prune100.json
│   ├── imagenet_resnet50_prune56.json
│   ├── imagenet_resnet50_prune72.json
│   ├── imagenet_resnet50_prune81.json
│   └── imagenet_resnet50_prune91.json
├── images
│   └── resnet_result.png
├── layers
│   └── gate_layer.py
├── logger.py
├── main.py
├── models
│   ├── densenet_imagenet.py
│   ├── lenet.py
│   ├── preact_resnet.py
│   ├── resnet.py
│   └── vgg_bn.py
├── pruning_engine.py
├── requirements.txt
└── utils
    ├── group_lasso_optimizer.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | runs/
2 | .idea/
3 | *.tar.gz
4 | *.zip
5 | *.pkl
6 | *.pyc
7 | *.py~
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 | *.*[cod]
12 | .DS_Store
13 | ._*
14 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ## Creative Commons
2 |
3 | # Attribution-NonCommercial-ShareAlike 4.0 International
4 |
5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
6 |
7 | ### Using Creative Commons Public Licenses
8 |
9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
10 |
11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
12 |
13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
14 |
15 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
16 |
17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
18 |
19 | ### Section 1 – Definitions.
20 |
21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
22 |
23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
24 |
25 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License.
26 |
27 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
28 |
29 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
30 |
31 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
32 |
33 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
34 |
35 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
36 |
37 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
38 |
39 | j. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
40 |
41 | k. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
42 |
43 | l. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
44 |
45 | m. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
46 |
47 | n. __You__ means the individual or entity exercising the Licensed Rights under this Public License. __Your__ has a corresponding meaning.
48 |
49 | ### Section 2 – Scope.
50 |
51 | a. ___License grant.___
52 |
53 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
54 |
55 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
56 |
57 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
58 |
59 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
60 |
61 | 3. __Term.__ The term of this Public License is specified in Section 6(a).
62 |
63 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
64 |
65 | 5. __Downstream recipients.__
66 |
67 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
68 |
69 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply.
70 |
71 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
72 |
73 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
74 |
75 | b. ___Other rights.___
76 |
77 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
78 |
79 | 2. Patent and trademark rights are not licensed under this Public License.
80 |
81 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
82 |
83 | ### Section 3 – License Conditions.
84 |
85 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
86 |
87 | a. ___Attribution.___
88 |
89 | 1. If You Share the Licensed Material (including in modified form), You must:
90 |
91 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
92 |
93 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
94 |
95 | ii. a copyright notice;
96 |
97 | iii. a notice that refers to this Public License;
98 |
99 | iv. a notice that refers to the disclaimer of warranties;
100 |
101 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
102 |
103 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
104 |
105 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
106 |
107 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
108 |
109 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
110 |
111 | b. ___ShareAlike.___
112 |
113 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
114 |
115 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
116 |
117 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
118 |
119 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
120 |
121 | ### Section 4 – Sui Generis Database Rights.
122 |
123 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
124 |
125 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
126 |
127 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
128 |
129 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
130 |
131 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
132 |
133 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
134 |
135 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
136 |
137 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
138 |
139 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
140 |
141 | ### Section 6 – Term and Termination.
142 |
143 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
144 |
145 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
146 |
147 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
148 |
149 | 2. upon express reinstatement by the Licensor.
150 |
151 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
152 |
153 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
154 |
155 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
156 |
157 | ### Section 7 – Other Terms and Conditions.
158 |
159 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
160 |
161 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
162 |
163 | ### Section 8 – Interpretation.
164 |
165 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
166 |
167 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
168 |
169 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
170 |
171 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
172 |
173 | ```
174 | Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
175 |
176 | Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org/).
177 | ```
178 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [License CC BY-NC-SA 4.0](https://raw.githubusercontent.com/nvlabs/Taylor_pruning/master/LICENSE.md)
2 |
3 |
4 | # Importance Estimation for Neural Network Pruning
5 |
6 | This repo contains the scripts required to reproduce results from the paper:
7 |
8 | Importance Estimation for Neural Network Pruning
9 | Pavlo Molchanov, Arun Mallya, Stephen Tyree, Iuri Frosio, Jan Kautz.
10 | In CVPR 2019.
11 |
12 | ![ResNet pruning results](images/resnet_result.png)
13 |
14 | ### [License](https://raw.githubusercontent.com/nvlabs/Taylor_pruning/master/LICENSE.md)
15 |
16 | Copyright (C) 2019 NVIDIA Corporation.
17 |
18 | All rights reserved.
19 | Licensed under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (**Attribution-NonCommercial-ShareAlike 4.0 International**)
20 |
21 | The code is released for academic research use only. For commercial use, please contact [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com).
22 |
23 | ## Installation
24 |
25 | Clone this repo.
26 | ```bash
27 | git clone https://github.com/NVlabs/Taylor_pruning.git
28 | cd Taylor_pruning/
29 | ```
30 |
31 | This code requires PyTorch 1.0 and Python 3+. Install the dependencies with
32 | ```bash
33 | pip install -r requirements.txt
34 | ```
35 |
36 | For the best reproducibility of results, you will need an NVIDIA DGX-1 server with 8 V100 GPUs. During pruning and finetuning we use at most 4 GPUs.
37 |
38 |
39 | The code was tested with Python 3.6 and the following software versions:
40 |
41 | | Software | version |
42 | | ------------- |-------------|
43 | | cuDNN         | v7.5.0      |
44 | | PyTorch       | 1.0.1.post2 |
45 | | CUDA | v10.0 |
46 |
47 |
48 | ## Preparation
49 |
50 | ### Dataset preparation
51 |
52 | Pruning examples use the ImageNet-1k dataset, which needs to be downloaded beforehand.
53 | Use the standard instructions to set up ImageNet-1k for PyTorch, e.g. from [here](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5#geting-the-data).
54 |
55 | ### Models preparation
56 |
57 | We use the pretrained models provided with PyTorch; they need to be downloaded first:
58 |
59 | ```bash
60 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth
61 | wget https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
62 | mkdir ./models/pretrained/
63 | mv {resnet50-19c8e357.pth,resnet101-5d3b4d8f.pth} ./models/pretrained/
64 | ```
65 |
66 | ## Pruning examples
67 |
68 | Results in Table 3 can be reproduced with the following script running on 4 V100 GPUs:
69 |
70 | ```bash
71 | python main.py --name=runs/resnet50/resnet50_prune72 --dataset=Imagenet \
72 | --lr=0.001 --lr-decay-every=10 --momentum=0.9 --epochs=25 --batch-size=256 \
73 | --pruning=True --seed=0 --model=resnet50 --load_model=./models/pretrained/resnet50-19c8e357.pth \
74 | --mgpu=True --group_wd_coeff=1e-8 --wd=0.0 --tensorboard=True --pruning-method=22 \
75 | --data=/imagenet/ --no_grad_clip=True --pruning_config=./configs/imagenet_resnet50_prune72.json
76 | ```
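
The same command covers the ResNet-101 configs; as an illustrative, untested variant (same flags, with only the model, checkpoint, and config swapped), pruning ResNet-101 to 50% would look like:

```bash
python main.py --name=runs/resnet101/resnet101_prune50 --dataset=Imagenet \
--lr=0.001 --lr-decay-every=10 --momentum=0.9 --epochs=25 --batch-size=256 \
--pruning=True --seed=0 --model=resnet101 --load_model=./models/pretrained/resnet101-5d3b4d8f.pth \
--mgpu=True --group_wd_coeff=1e-8 --wd=0.0 --tensorboard=True --pruning-method=22 \
--data=/imagenet/ --no_grad_clip=True --pruning_config=./configs/imagenet_resnet101_prune50.json
```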
77 |
78 | Note: we run finetuning for 25 epochs and do not use weight decay. A better model can be obtained by training longer, increasing the learning rate, and using weight decay.
79 |
80 | Example of the config file `./configs/imagenet_resnet50_prune72.json`:
81 | ```json
82 | {
83 | "method": 22,
84 | "frequency" : 30,
85 | "prune_per_iteration" : 100,
86 | "maximum_pruning_iterations" : 32,
87 | "starting_neuron" : 0,
88 | "fixed_layer" : -1,
89 | "l2_normalization_per_layer": false,
90 | "rank_neurons_equally": false,
91 | "prune_neurons_max": 3200,
92 | "use_momentum": true,
93 | "pruning_silent": false,
94 |
95 | "pruning_threshold" : 100.0,
96 | "start_pruning_after_n_iterations" : 0,
97 | "push_group_down" : false,
98 | "do_iterative_pruning" : true,
99 | "fixed_criteria" : false,
100 | "seed" : 0,
101 | "pruning_momentum" : 0.9
102 | }
103 |
104 | ```
105 |
106 | We provide config files for pruning ResNet-50 to 56%, 72%, 81%, and 91% of the original model, and ResNet-101 to 40%, 50%, 55%, and 75%. The percentage is the fraction of gates that remain active after pruning.
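
In every provided config, `prune_per_iteration * maximum_pruning_iterations` equals `prune_neurons_max` (e.g. 100 * 32 = 3200 for `imagenet_resnet50_prune72.json`). A minimal sketch for sanity-checking a custom config (the path is a placeholder):

```python
import json

with open("./configs/imagenet_resnet50_prune72.json") as f:
    cfg = json.load(f)

# total gates removed = gates removed per pruning step * number of pruning steps
assert cfg["prune_per_iteration"] * cfg["maximum_pruning_iterations"] == cfg["prune_neurons_max"]
```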
107 |
108 | ## Parameter description
109 |
110 | Pruning methods (different criteria) are encoded with an integer id as follows:
111 |
112 | | method id | name | description | comment |
113 | | ------------- |-------------| -------------| -----|
114 | | 22 | Taylor_gate | *Gate after BN* in Table 2, *Taylor FO* in Table 1, *Taylor-FO-BN* in Table 3 | **Best method**|
115 | | 0 | Taylor_weight | *Conv/linear weight with Taylor FO* in Table 2 and Table 1 | |
116 | | 1 | Random | Random ranking ||
117 | | 2 | Weight norm | Weight magnitude ||
118 | | 3 | Weight_abs | Not used ||
119 | | 6 | Taylor_output | *Taylor-output* as in [27] ||
120 | | 10 | OBD | Optimal Brain Damage ||
121 | | 11 | Taylor_gate_SO | Taylor SO ||
122 | | 23 | Taylor_gate_FG | Uses per-example gradients to compute Taylor FO; *Taylor FO-FG* in Table 1, *Gate after BN - FG* in Table 2 ||
123 | | 30 | BN_weight | *BN scale* in Table 2 ||
124 | | 31 | BN_Taylor | *BN scale Taylor FO* in Table 2 ||
125 |
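For intuition, criterion 22 (Taylor_gate) scores each gate by the squared first-order Taylor term, (gradient × weight)². The sketch below is a simplified illustration, not the repository's implementation: the real accumulation, momentum, and ranking logic live in `pruning_engine.py`, and `taylor_fo_gate_importance` is a hypothetical helper.

```python
def taylor_fo_gate_importance(model):
    """Squared first-order Taylor term (grad * weight)^2 per gate channel.

    Assumes loss.backward() has populated gradients for the current minibatch.
    """
    importance = {}
    for name, module in model.named_modules():
        # gate layers are tagged with `do_not_update` (see layers/gate_layer.py)
        if getattr(module, "do_not_update", False) and module.weight.grad is not None:
            w = module.weight
            importance[name] = (w.grad * w).pow(2).detach()
    return importance
```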
126 |
127 |
128 | ### Citation
129 |
130 | If you use this code for your research, please cite our paper.
131 | ```bibtex
132 | @inproceedings{molchanov2019taylor,
133 | title={Importance Estimation for Neural Network Pruning},
134 | author={Molchanov, Pavlo and Mallya, Arun and Tyree, Stephen and Frosio, Iuri and Kautz, Jan},
135 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
136 | year={2019}
137 | }
138 | ```
139 |
140 | ## To come
141 |
142 | - Config files for DenseNet and VGG16 pruning on ImageNet
143 | - Examples of CIFAR experiments
144 | - Getting a shrunk model after pruning (for ResNet-101 only)
145 | - Oracle estimates of true importance, code to compute correlation with Oracle
146 |
147 | ## Acknowledgments
148 | This code heavily reuses the PyTorch ImageNet training example provided [here](https://github.com/pytorch/examples/tree/master/imagenet).
--------------------------------------------------------------------------------
/configs/imagenet_resnet101_prune100.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 0,
5 | "maximum_pruning_iterations" : 1,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 0,
11 | "use_momentum": true,
12 | "pruning_silent": false,
13 |
14 | "pruning_threshold" : 100.0,
15 | "start_pruning_after_n_iterations" : 0,
16 | "push_group_down" : false,
17 | "do_iterative_pruning" : true,
18 | "fixed_criteria" : false,
19 | "seed" : 0,
20 | "pruning_momentum" : 0.9
21 | }
22 |
23 |
--------------------------------------------------------------------------------
/configs/imagenet_resnet101_prune40.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 100,
5 | "maximum_pruning_iterations" : 120,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 12000,
11 | "use_momentum": true,
12 | "pruning_silent": false,
13 |
14 | "pruning_threshold" : 100.0,
15 | "start_pruning_after_n_iterations" : 0,
16 | "push_group_down" : false,
17 | "do_iterative_pruning" : true,
18 | "fixed_criteria" : false,
19 | "seed" : 0,
20 | "pruning_momentum" : 0.9
21 | }
22 |
23 |
--------------------------------------------------------------------------------
/configs/imagenet_resnet101_prune50.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 100,
5 | "maximum_pruning_iterations" : 100,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 10000,
11 | "use_momentum": true,
12 | "pruning_silent": false,
13 |
14 | "pruning_threshold" : 100.0,
15 | "start_pruning_after_n_iterations" : 0,
16 | "push_group_down" : false,
17 | "do_iterative_pruning" : true,
18 | "fixed_criteria" : false,
19 | "seed" : 0,
20 | "pruning_momentum" : 0.9
21 | }
22 |
23 |
--------------------------------------------------------------------------------
/configs/imagenet_resnet101_prune50_group.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 100,
5 | "maximum_pruning_iterations" : 100,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 10000,
11 | "use_momentum": true,
12 | "pruning_silent": true,
13 | "group_size": 32,
14 | "flops_regularization": 1e-12,
15 |
16 | "pruning_threshold" : 100.0,
17 | "start_pruning_after_n_iterations" : 0,
18 | "do_iterative_pruning" : true,
19 | "fixed_criteria" : false,
20 | "seed" : 0,
21 | "pruning_momentum" : 0.9
22 | }
23 |
24 |
--------------------------------------------------------------------------------
/configs/imagenet_resnet101_prune55.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 100,
5 | "maximum_pruning_iterations" : 90,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 9000,
11 | "use_momentum": true,
12 | "pruning_silent": false,
13 |
14 | "pruning_threshold" : 100.0,
15 | "start_pruning_after_n_iterations" : 0,
16 | "push_group_down" : false,
17 | "do_iterative_pruning" : true,
18 | "fixed_criteria" : false,
19 | "seed" : 0,
20 | "pruning_momentum" : 0.9
21 | }
22 |
23 |
--------------------------------------------------------------------------------
/configs/imagenet_resnet101_prune75.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 100,
5 | "maximum_pruning_iterations" : 50,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 5000,
11 | "use_momentum": true,
12 | "pruning_silent": false,
13 |
14 | "pruning_threshold" : 100.0,
15 | "start_pruning_after_n_iterations" : 0,
16 | "push_group_down" : false,
17 | "do_iterative_pruning" : true,
18 | "fixed_criteria" : false,
19 | "seed" : 0,
20 | "pruning_momentum" : 0.9
21 | }
22 |
23 |
--------------------------------------------------------------------------------
/configs/imagenet_resnet50_prune100.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 0,
5 | "maximum_pruning_iterations" : 1,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 0,
11 | "use_momentum": true,
12 | "pruning_silent": false,
13 |
14 | "pruning_threshold" : 100.0,
15 | "start_pruning_after_n_iterations" : 0,
16 | "push_group_down" : false,
17 | "do_iterative_pruning" : true,
18 | "fixed_criteria" : false,
19 | "seed" : 0,
20 | "pruning_momentum" : 0.9
21 | }
22 |
23 |
--------------------------------------------------------------------------------
/configs/imagenet_resnet50_prune56.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 100,
5 | "maximum_pruning_iterations" : 55,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 5500,
11 | "use_momentum": true,
12 | "pruning_silent": false,
13 |
14 | "pruning_threshold" : 100.0,
15 | "start_pruning_after_n_iterations" : 0,
16 | "push_group_down" : false,
17 | "do_iterative_pruning" : true,
18 | "fixed_criteria" : false,
19 | "seed" : 0,
20 | "pruning_momentum" : 0.9
21 | }
22 |
23 |
--------------------------------------------------------------------------------
/configs/imagenet_resnet50_prune72.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 100,
5 | "maximum_pruning_iterations" : 32,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 3200,
11 | "use_momentum": true,
12 | "pruning_silent": false,
13 |
14 | "pruning_threshold" : 100.0,
15 | "start_pruning_after_n_iterations" : 0,
16 | "push_group_down" : false,
17 | "do_iterative_pruning" : true,
18 | "fixed_criteria" : false,
19 | "seed" : 0,
20 | "pruning_momentum" : 0.9
21 | }
22 |
23 |
--------------------------------------------------------------------------------
/configs/imagenet_resnet50_prune81.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 100,
5 | "maximum_pruning_iterations" : 22,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 2200,
11 | "use_momentum": true,
12 | "pruning_silent": false,
13 |
14 | "pruning_threshold" : 100.0,
15 | "start_pruning_after_n_iterations" : 0,
16 | "push_group_down" : false,
17 | "do_iterative_pruning" : true,
18 | "fixed_criteria" : false,
19 | "seed" : 0,
20 | "pruning_momentum" : 0.9
21 | }
22 |
23 |
--------------------------------------------------------------------------------
/configs/imagenet_resnet50_prune91.json:
--------------------------------------------------------------------------------
1 | {
2 | "method": 22,
3 | "frequency" : 30,
4 | "prune_per_iteration" : 100,
5 | "maximum_pruning_iterations" : 10,
6 | "starting_neuron" : 0,
7 | "fixed_layer" : -1,
8 | "l2_normalization_per_layer": false,
9 | "rank_neurons_equally": false,
10 | "prune_neurons_max": 1000,
11 | "use_momentum": true,
12 | "pruning_silent": false,
13 |
14 | "pruning_threshold" : 100.0,
15 | "start_pruning_after_n_iterations" : 0,
16 | "push_group_down" : false,
17 | "do_iterative_pruning" : true,
18 | "fixed_criteria" : false,
19 | "seed" : 0,
20 | "pruning_momentum" : 0.9
21 | }
22 |
23 |
--------------------------------------------------------------------------------
/images/resnet_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Taylor_pruning/4b70120a8903494cdb565cd8ca97509543fd2862/images/resnet_result.png
--------------------------------------------------------------------------------
/layers/gate_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | '''
11 | Gating layer for pruning
12 | '''
13 |
14 | class GateLayer(nn.Module):
15 | def __init__(self, input_features, output_features, size_mask):
16 | super(GateLayer, self).__init__()
17 | self.input_features = input_features
18 | self.output_features = output_features
19 | self.size_mask = size_mask
20 | self.weight = nn.Parameter(torch.ones(output_features))
21 |
22 | # for simpler way to find these layers
23 | self.do_not_update = True
24 |
25 | def forward(self, input):
26 | return input*self.weight.view(*self.size_mask)
27 |
28 |     def extra_repr(self):
29 |         return 'in_features={}, out_features={}'.format(
30 |             self.input_features, self.output_features
31 |         )
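
# Illustrative usage (editor's sketch, not part of the original file): a gate is
# typically placed after a BatchNorm layer, so pruning a gate channel removes the
# corresponding BN output channel; `size_mask` reshapes the weights for broadcasting:
#
#     bn = nn.BatchNorm2d(64)
#     gate = GateLayer(64, 64, [1, -1, 1, 1])   # mask shape for NCHW feature maps
#     y = gate(bn(x))   # channels whose gate weight is 0 are effectively pruned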
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import sys
7 |
8 | #based on https://groups.google.com/forum/#!topic/comp.lang.python/0lqfVgjkc68
9 |
10 | class Logger(object):
11 | def __init__(self, filename="Default.log"):
12 | self.terminal = sys.stdout
13 | bufsize = 1
14 | self.log = open(filename, "w", buffering=bufsize)
15 |
16 | def delink(self):
17 | self.log.close()
18 | self.log = open('foo', "w")
19 | # self.write = self.writeTerminalOnly
20 |
21 | def writeTerminalOnly(self, message):
22 | self.terminal.write(message)
23 |
24 | def write(self, message):
25 | self.terminal.write(message)
26 | self.log.write(message)
27 |
28 | def flush(self):
29 | pass
30 |
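# Illustrative usage (editor's sketch, not part of the original file):
#
#     import sys
#     sys.stdout = Logger("train.log")   # print() now mirrors output to train.log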
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from __future__ import print_function
7 | import argparse
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | from torchvision import datasets, transforms
11 | import torch.backends.cudnn as cudnn
12 |
13 | # ++++++for pruning
14 | import os, sys
15 | import time
16 | from utils.utils import save_checkpoint, adjust_learning_rate, AverageMeter, accuracy, load_model_pytorch, dynamic_network_change_local, get_conv_sizes, connect_gates_with_parameters_for_flops
17 | from tensorboardX import SummaryWriter
18 |
19 | from logger import Logger
20 | from models.lenet import LeNet
21 | from models.vgg_bn import slimmingvgg as vgg11_bn
22 | from models.preact_resnet import *
23 | from pruning_engine import pytorch_pruning, PruningConfigReader, prepare_pruning_list
24 |
25 | from utils.group_lasso_optimizer import group_lasso_decay
26 |
27 | import torch.distributed as dist
28 | import torch.utils.data
29 | import torch.utils.data.distributed
30 | import torch.nn.parallel
31 |
32 | import warnings
33 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
34 |
35 | import numpy as np
36 | #++++++++end
37 |
38 | # code is based on Pytorch example for imagenet
39 | # https://github.com/pytorch/examples/tree/master/imagenet
40 |
41 |
42 | def str2bool(v):
43 | # from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse/43357954#43357954
44 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
45 | return True
46 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
47 | return False
48 | else:
49 | raise argparse.ArgumentTypeError('Boolean value expected.')
50 |
51 |
52 | def train(args, model, device, train_loader, optimizer, epoch, criterion, train_writer=None, pruning_engine=None):
53 | """Train for one epoch on the training set also performs pruning"""
54 | global global_iteration
55 | batch_time = AverageMeter()
56 | losses = AverageMeter()
57 | top1 = AverageMeter()
58 | top5 = AverageMeter()
59 | loss_tracker = 0.0
60 | acc_tracker = 0.0
61 | loss_tracker_num = 0
62 | res_pruning = 0
63 |
64 | model.train()
65 | if args.fixed_network:
66 | # if network is fixed then we put it to eval mode
67 | model.eval()
68 |
69 | end = time.time()
70 | for batch_idx, (data, target) in enumerate(train_loader):
71 | data, target = data.to(device), target.to(device)
72 | # make sure that all gradients are zero
73 | for p in model.parameters():
74 | if p.grad is not None:
75 | p.grad.detach_()
76 | p.grad.zero_()
77 |
78 | output = model(data)
79 | loss = criterion(output, target)
80 |
81 | if args.pruning:
82 |             # useful for methods 40 and 50, which compute the oracle
83 | pruning_engine.run_full_oracle(model, data, target, criterion, initial_loss=loss.item())
84 |
85 | # measure accuracy and record loss
86 | losses.update(loss.item(), data.size(0))
87 |
88 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
89 | top1.update(prec1.item(), data.size(0))
90 | top5.update(prec5.item(), data.size(0))
91 | acc_tracker += prec1.item()
92 |
93 | loss_tracker += loss.item()
94 |
95 | loss_tracker_num += 1
96 |
97 | if args.pruning:
98 | if pruning_engine.needs_hessian:
99 | pruning_engine.compute_hessian(loss)
100 |
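        # group lasso / group weight-decay step (see utils/group_lasso_optimizer.py);
        # skipped for the full-oracle method (50)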
101 |         if not (args.pruning and args.pruning_method == 50):
102 |             group_wd_optimizer.step()
103 |
104 |
105 | loss.backward()
106 |
107 |
108 | # add gradient clipping
109 | if not args.no_grad_clip:
110 | # found it useless for our experiments
111 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
112 |
113 |         # step_after will calculate flops and number of parameters left
114 |         # it needs to run before the main optimizer step,
115 |         # otherwise weight decay would make the numbers incorrect
116 | if not (args.pruning and args.pruning_method == 50):
117 | if batch_idx % args.log_interval == 0:
118 | group_wd_optimizer.step_after()
119 |
120 | optimizer.step()
121 |
122 | batch_time.update(time.time() - end)
123 | end = time.time()
124 |
125 | global_iteration = global_iteration + 1
126 |
127 | if batch_idx % args.log_interval == 0:
128 | print('Epoch: [{0}][{1}/{2}]\t'
129 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
130 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
131 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
132 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
133 | epoch, batch_idx, len(train_loader), batch_time=batch_time,
134 | loss=losses, top1=top1, top5=top5))
135 |
136 | if train_writer is not None:
137 | train_writer.add_scalar('train_loss_ave', losses.avg, global_iteration)
138 |
139 | if args.pruning:
140 | # pruning_engine.update_flops(stats=group_wd_optimizer.per_layer_per_neuron_stats)
141 | pruning_engine.do_step(loss=loss.item(), optimizer=optimizer)
142 |
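            # once the target number of pruning iterations is reached, drop the SGD
            # momentum buffers so stale momentum does not perturb the pruned network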
143 |             if args.model == "resnet20" or args.model == "resnet101" or args.dataset == "Imagenet":
144 | if (pruning_engine.maximum_pruning_iterations == pruning_engine.pruning_iterations_done) and pruning_engine.set_moment_zero:
145 | for group in optimizer.param_groups:
146 | for p in group['params']:
147 | if p.grad is None:
148 | continue
149 | param_state = optimizer.state[p]
150 | if 'momentum_buffer' in param_state:
151 | del param_state['momentum_buffer']
152 |
153 | pruning_engine.set_moment_zero = False
154 |
155 | # if not (args.pruning and args.pruning_method == 50):
156 | # if batch_idx % args.log_interval == 0:
157 | # group_wd_optimizer.step_after()
158 |
159 | if args.tensorboard and (batch_idx % args.log_interval == 0):
160 |
161 | neurons_left = int(group_wd_optimizer.get_number_neurons(print_output=args.get_flops))
162 | flops = int(group_wd_optimizer.get_number_flops(print_output=args.get_flops))
163 |
164 | train_writer.add_scalar('neurons_optimizer_left', neurons_left, global_iteration)
165 | train_writer.add_scalar('neurons_optimizer_flops_left', flops, global_iteration)
166 | else:
167 | if args.get_flops:
168 | neurons_left = int(group_wd_optimizer.get_number_neurons(print_output=args.get_flops))
169 | flops = int(group_wd_optimizer.get_number_flops(print_output=args.get_flops))
170 |
171 | if args.limit_training_batches != -1:
172 | if args.limit_training_batches < batch_idx:
173 | # return from training step, unsafe and was not tested correctly
174 | print("return from training step, unsafe and was not tested correctly")
175 | return 0
176 |
177 | # print number of parameters left:
178 | if args.tensorboard:
179 | print('neurons_optimizer_left', neurons_left, global_iteration)
180 |
181 |
182 | def validate(args, test_loader, model, device, criterion, epoch, train_writer=None):
183 | """Perform validation on the validation set"""
184 | batch_time = AverageMeter()
185 | losses = AverageMeter()
186 | top1 = AverageMeter()
187 | top5 = AverageMeter()
188 |
189 | # switch to evaluate mode
190 | model.eval()
191 |
192 | end = time.time()
193 | with torch.no_grad():
194 | for data_test in test_loader:
195 | data, target = data_test
196 |
197 | data = data.to(device)
198 |
199 | output = model(data)
200 |
201 | if args.get_inference_time:
202 | iterations_get_inference_time = 100
203 | start_get_inference_time = time.time()
204 | for it in range(iterations_get_inference_time):
205 | output = model(data)
206 | end_get_inference_time = time.time()
207 | print("time taken for %d iterations, per-iteration is: "%(iterations_get_inference_time), (end_get_inference_time - start_get_inference_time)*1000.0/float(iterations_get_inference_time), "ms")
208 |
209 | target = target.to(device)
210 | loss = criterion(output, target)
211 |
212 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
213 | losses.update(loss.item(), data.size(0))
214 | top1.update(prec1.item(), data.size(0))
215 | top5.update(prec5.item(), data.size(0))
216 |
217 | # measure elapsed time
218 | batch_time.update(time.time() - end)
219 | end = time.time()
220 |
221 | print(' * Prec@1 {top1.avg:.3f}, Prec@5 {top5.avg:.3f}, Time {batch_time.sum:.5f}, Loss: {losses.avg:.3f}'.format(top1=top1, top5=top5,batch_time=batch_time, losses = losses) )
222 | # log to TensorBoard
223 | if train_writer is not None:
224 | train_writer.add_scalar('val_loss', losses.avg, epoch)
225 | train_writer.add_scalar('val_acc', top1.avg, epoch)
226 |
227 | return top1.avg, losses.avg
228 |
229 |
230 |
231 | def main():
232 | # Training settings
233 |     parser = argparse.ArgumentParser(description='PyTorch pruning example')
234 | parser.add_argument('--batch-size', type=int, default=64, metavar='N',
235 | help='input batch size for training (default: 64)')
236 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
237 | help='input batch size for testing (default: 1000)')
238 | parser.add_argument('--world_size', type=int, default=1,
239 | help='number of GPUs to use')
    # used by dist.init_process_group() below; defaults follow the PyTorch
    # ImageNet example this script is based on
    parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='gloo', type=str,
                        help='distributed backend')
240 |
241 | parser.add_argument('--epochs', type=int, default=10, metavar='N',
242 | help='number of epochs to train (default: 10)')
243 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
244 | help='learning rate (default: 0.01)')
245 |     parser.add_argument('--wd', type=float, default=1e-4,
246 |                         help='weight decay (default: 1e-4)')
247 | parser.add_argument('--lr-decay-every', type=int, default=100,
248 | help='learning rate decay by 10 every X epochs')
249 | parser.add_argument('--lr-decay-scalar', type=float, default=0.1,
250 |                         help='scalar to multiply the learning rate by at each decay (default: 0.1)')
251 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
252 | help='SGD momentum (default: 0.5)')
253 | parser.add_argument('--no-cuda', action='store_true', default=False,
254 | help='disables CUDA training')
255 | parser.add_argument('--seed', type=int, default=1, metavar='S',
256 | help='random seed (default: 1)')
257 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
258 | help='how many batches to wait before logging training status')
259 |
260 | parser.add_argument('--run_test', default=False, type=str2bool, nargs='?',
261 | help='run test only')
262 |
263 | parser.add_argument('--limit_training_batches', type=int, default=-1,
264 | help='how many batches to do per training, -1 means as many as possible')
265 |
266 | parser.add_argument('--no_grad_clip', default=False, type=str2bool, nargs='?',
267 | help='turn off gradient clipping')
268 |
269 | parser.add_argument('--get_flops', default=False, type=str2bool, nargs='?',
270 | help='add hooks to compute flops')
271 |
272 | parser.add_argument('--get_inference_time', default=False, type=str2bool, nargs='?',
273 | help='runs valid multiple times and reports the result')
274 |
275 | parser.add_argument('--mgpu', default=False, type=str2bool, nargs='?',
276 |                         help='use data parallelization via multiple GPUs')
277 |
278 | parser.add_argument('--dataset', default="MNIST", type=str,
279 |                         help='dataset for experiment, choices: MNIST, CIFAR10, Imagenet', choices=["MNIST", "CIFAR10", "Imagenet"])
280 |
281 | parser.add_argument('--data', metavar='DIR', default='/imagenet', help='path to imagenet dataset')
282 |
283 | parser.add_argument('--model', default="lenet3", type=str,
284 | help='model selection, choices: lenet3, vgg, mobilenetv2, resnet18',
285 | choices=["lenet3", "vgg", "mobilenetv2", "resnet18", "resnet152", "resnet50", "resnet50_noskip",
286 | "resnet20", "resnet34", "resnet101", "resnet101_noskip", "densenet201_imagenet",
287 | 'densenet121_imagenet'])
288 |
289 | parser.add_argument('--tensorboard', type=str2bool, nargs='?',
290 | help='Log progress to TensorBoard')
291 |
292 | parser.add_argument('--save_models', default=True, type=str2bool, nargs='?',
293 | help='if True, models will be saved to the local folder')
294 |
295 |
296 | # ============================PRUNING added
297 | parser.add_argument('--pruning_config', default=None, type=str,
298 | help='path to pruning configuration file, will overwrite all pruning parameters in arguments')
299 |
300 | parser.add_argument('--group_wd_coeff', type=float, default=0.0,
301 | help='group weight decay')
302 | parser.add_argument('--name', default='test', type=str,
303 | help='experiment name(folder) to store logs')
304 |
305 | parser.add_argument('--augment', default=False, type=str2bool, nargs='?',
306 |                         help='enable augmentation of the training dataset (CIFAR only), default False')
307 |
308 | parser.add_argument('--load_model', default='', type=str,
309 | help='path to model weights')
310 |
311 | parser.add_argument('--pruning', default=False, type=str2bool, nargs='?',
312 |                         help='enable pruning, default False')
313 |
314 | parser.add_argument('--pruning-threshold', '--pt', default=100.0, type=float,
315 | help='Max error perc on validation set while pruning (default: 100.0 means always prune)')
316 |
317 | parser.add_argument('--pruning-momentum', default=0.0, type=float,
318 | help='Use momentum on criteria between pruning iterations, def 0.0 means no momentum')
319 |
320 | parser.add_argument('--pruning-step', default=15, type=int,
321 | help='How often to check loss and do pruning step')
322 |
323 | parser.add_argument('--prune_per_iteration', default=10, type=int,
324 | help='How many neurons to remove at each iteration')
325 |
326 | parser.add_argument('--fixed_layer', default=-1, type=int,
327 | help='Prune only a given layer with index, use -1 to prune all')
328 |
329 | parser.add_argument('--start_pruning_after_n_iterations', default=0, type=int,
330 | help='from which iteration to start pruning')
331 |
332 | parser.add_argument('--maximum_pruning_iterations', default=1e8, type=int,
333 | help='maximum pruning iterations')
334 |
335 | parser.add_argument('--starting_neuron', default=0, type=int,
336 | help='starting position for oracle pruning')
337 |
338 | parser.add_argument('--prune_neurons_max', default=-1, type=int,
339 | help='prune_neurons_max')
340 |
341 | parser.add_argument('--pruning-method', default=0, type=int,
342 | help='pruning method to be used, see readme.md')
343 |
344 | parser.add_argument('--pruning_fixed_criteria', default=False, type=str2bool, nargs='?',
345 |                         help='if True, pruning criteria are computed once and not reevaluated, default False')
346 |
347 | parser.add_argument('--fixed_network', default=False, type=str2bool, nargs='?',
348 | help='fix network for oracle or criteria computation')
349 |
350 | parser.add_argument('--zero_lr_for_epochs', default=-1, type=int,
351 | help='Learning rate will be set to 0 for given number of updates')
352 |
353 | parser.add_argument('--dynamic_network', default=False, type=str2bool, nargs='?',
354 | help='Creates a new network graph from pruned model, works with ResNet-101 only')
355 |
356 | parser.add_argument('--use_test_as_train', default=False, type=str2bool, nargs='?',
357 | help='use testing dataset instead of training')
358 |
359 | parser.add_argument('--pruning_mask_from', default='', type=str,
360 | help='path to mask file precomputed')
361 |
362 | parser.add_argument('--compute_flops', default=True, type=str2bool, nargs='?',
363 | help='if True, will run dummy inference of batch 1 before training to get conv sizes')
364 |
365 |
366 |
367 | # ============================END pruning added
368 |
369 | best_prec1 = 0
370 | global global_iteration
371 | global group_wd_optimizer
372 | global_iteration = 0
373 |
374 | args = parser.parse_args()
375 | use_cuda = not args.no_cuda and torch.cuda.is_available()
376 |
377 | torch.manual_seed(args.seed)
378 |
379 | args.distributed = args.world_size > 1
380 | if args.distributed:
381 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
382 | world_size=args.world_size, rank=0)
383 |
384 | device = torch.device("cuda" if use_cuda else "cpu")
385 |
386 | if args.model == "lenet3":
387 | model = LeNet(dataset=args.dataset)
388 | elif args.model == "vgg":
389 | model = vgg11_bn(pretrained=True)
390 | elif args.model == "resnet18":
391 | model = PreActResNet18()
392 | elif (args.model == "resnet50") or (args.model == "resnet50_noskip"):
393 | if args.dataset == "CIFAR10":
394 | model = PreActResNet50(dataset=args.dataset)
395 | else:
396 | from models.resnet import resnet50
397 | skip_gate = True
398 | if "noskip" in args.model:
399 | skip_gate = False
400 |
401 | if args.pruning_method not in [22, 40]:
402 | skip_gate = False
403 | model = resnet50(skip_gate=skip_gate)
404 | elif args.model == "resnet34":
405 | if not (args.dataset == "CIFAR10"):
406 | from models.resnet import resnet34
407 | model = resnet34()
408 | elif "resnet101" in args.model:
409 | if not (args.dataset == "CIFAR10"):
410 | from models.resnet import resnet101
411 | if args.dataset == "Imagenet":
412 | classes = 1000
413 |
414 | if "noskip" in args.model:
415 | model = resnet101(num_classes=classes, skip_gate=False)
416 | else:
417 | model = resnet101(num_classes=classes)
418 |
419 | elif args.model == "resnet20":
420 | if args.dataset == "CIFAR10":
421 |             raise NotImplementedError("resnet20 is not implemented in the current project")
422 | # from models.resnet_cifar import resnet20
423 | # model = resnet20()
424 | elif args.model == "resnet152":
425 | model = PreActResNet152()
426 | elif args.model == "densenet201_imagenet":
427 | from models.densenet_imagenet import DenseNet201
428 | model = DenseNet201(gate_types=['output_bn'], pretrained=True)
429 | elif args.model == "densenet121_imagenet":
430 | from models.densenet_imagenet import DenseNet121
431 | model = DenseNet121(gate_types=['output_bn'], pretrained=True)
432 | else:
433 | print(args.model, "model is not supported")
434 |
435 | # dataset loading section
436 | if args.dataset == "MNIST":
437 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
438 | train_loader = torch.utils.data.DataLoader(
439 | datasets.MNIST('../data', train=True, download=True,
440 | transform=transforms.Compose([
441 | transforms.ToTensor(),
442 | transforms.Normalize((0.1307,), (0.3081,))
443 | ])),
444 | batch_size=args.batch_size, shuffle=True, **kwargs)
445 | test_loader = torch.utils.data.DataLoader(
446 | datasets.MNIST('../data', train=False, transform=transforms.Compose([
447 | transforms.ToTensor(),
448 | transforms.Normalize((0.1307,), (0.3081,))
449 | ])),
450 | batch_size=args.test_batch_size, shuffle=True, **kwargs)
451 |
452 | elif args.dataset == "CIFAR10":
453 | # Data loading code
454 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
455 | std=[x/255.0 for x in [63.0, 62.1, 66.7]])
456 |
457 | if args.augment:
458 | transform_train = transforms.Compose([
459 | transforms.RandomCrop(32, padding=4),
460 | transforms.RandomHorizontalFlip(),
461 | transforms.ToTensor(),
462 | normalize,
463 | ])
464 | else:
465 | transform_train = transforms.Compose([
466 | transforms.ToTensor(),
467 | normalize,
468 | ])
469 |
470 | transform_test = transforms.Compose([
471 | transforms.ToTensor(),
472 | normalize
473 | ])
474 |
475 | kwargs = {'num_workers': 8, 'pin_memory': True}
476 | train_loader = torch.utils.data.DataLoader(
477 | datasets.CIFAR10('../data', train=True, download=True,
478 | transform=transform_train),
479 | batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
480 |
481 | test_loader = torch.utils.data.DataLoader(
482 | datasets.CIFAR10('../data', train=False, transform=transform_test),
483 | batch_size=args.test_batch_size, shuffle=True, **kwargs)
484 |
485 | elif args.dataset == "Imagenet":
486 | traindir = os.path.join(args.data, 'train')
487 | valdir = os.path.join(args.data, 'val')
488 |
489 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
490 | std=[0.229, 0.224, 0.225])
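# (the mean/std above are the standard ImageNet statistics used by torchvision
# pretrained models, so pretrained checkpoints can be evaluated without re-normalization)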
491 |
492 | train_dataset = datasets.ImageFolder(
493 | traindir,
494 | transforms.Compose([
495 | transforms.RandomResizedCrop(224),
496 | transforms.RandomHorizontalFlip(),
497 | transforms.ToTensor(),
498 | normalize,
499 | ]))
500 |
501 | if args.distributed:
502 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
503 | else:
504 | train_sampler = None
505 |
506 | kwargs = {'num_workers': 16}
507 |
508 | train_loader = torch.utils.data.DataLoader(
509 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
510 | sampler=train_sampler, pin_memory=True, **kwargs)
511 |
512 | if args.use_test_as_train:
513 | train_loader = torch.utils.data.DataLoader(
514 | datasets.ImageFolder(valdir, transforms.Compose([
515 | transforms.Resize(256),
516 | transforms.CenterCrop(224),
517 | transforms.ToTensor(),
518 | normalize,
519 | ])),
520 | batch_size=args.batch_size, shuffle=(train_sampler is None), **kwargs)
521 |
522 |
523 | test_loader = torch.utils.data.DataLoader(
524 | datasets.ImageFolder(valdir, transforms.Compose([
525 | transforms.Resize(256),
526 | transforms.CenterCrop(224),
527 | transforms.ToTensor(),
528 | normalize,
529 | ])),
530 | batch_size=args.batch_size, shuffle=False, pin_memory=True, **kwargs)
531 |
532 | ####end dataset preparation
533 |
534 | if args.dynamic_network:
535 | # attempts to load a pruned model and modify it by removing pruned channels
536 | # works for resnet101 only
537 | if (len(args.load_model) > 0) and (args.dynamic_network):
538 | if os.path.isfile(args.load_model):
539 | load_model_pytorch(model, args.load_model, args.model)
540 |
541 | else:
542 | print("=> no checkpoint found at '{}'".format(args.load_model))
543 | exit()
544 |
545 | dynamic_network_change_local(model)
546 |
547 | # save the model
548 | log_save_folder = "%s"%args.name
549 | if not os.path.exists(log_save_folder):
550 | os.makedirs(log_save_folder)
551 |
552 | if not os.path.exists("%s/models" % (log_save_folder)):
553 | os.makedirs("%s/models" % (log_save_folder))
554 |
555 | model_save_path = "%s/models/pruned.weights"%(log_save_folder)
556 | model_state_dict = model.state_dict()
557 | if args.save_models:
558 | save_checkpoint({
559 | 'state_dict': model_state_dict
560 | }, False, filename = model_save_path)
561 |
562 | print("model is defined")
563 |
564 | # aux function to get size of feature maps
565 | # First it adds hooks for each conv layer
566 | # Then runs inference with 1 image
567 | output_sizes = get_conv_sizes(args, model)
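# A minimal sketch of the hook mechanism assumed inside get_conv_sizes (the real
# helper lives in utils/ and may differ in details):
#
#   sizes = {}
#   def record(mod, inp, out):
#       sizes[mod] = out.shape[2:]            # spatial size of each conv output
#   handles = [m.register_forward_hook(record)
#              for m in model.modules() if isinstance(m, nn.Conv2d)]
#   model(torch.randn(1, 3, 224, 224))        # dummy inference with batch size 1
#   for h in handles:
#       h.remove()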
568 |
569 | if use_cuda and not args.mgpu:
570 | model = model.to(device)
571 | elif args.distributed:
572 | model.cuda()
573 | print("\n\n WARNING: distributed pruning was not verified and might not work correctly")
574 | model = torch.nn.parallel.DistributedDataParallel(model)
575 | elif args.mgpu:
576 | model = torch.nn.DataParallel(model).cuda()
577 | else:
578 | model = model.to(device)
579 |
580 | print("model is set to device: use_cuda {}, args.mgpu {}, agrs.distributed {}".format(use_cuda, args.mgpu, args.distributed))
581 |
582 | weight_decay = args.wd
583 | if args.fixed_network:
584 | weight_decay = 0.0
585 |
586 | # exclude gate parameters from optimizer updates, because we want them to stay constant at 0 or 1
587 | if 1:
588 | parameters_for_update = []
589 | parameters_for_update_named = []
590 | for name, m in model.named_parameters():
591 | if "gate" not in name:
592 | parameters_for_update.append(m)
593 | parameters_for_update_named.append((name, m))
594 | else:
595 | print("skipping parameter", name, "shape:", m.shape)
596 |
597 | total_size_params = sum([np.prod(par.shape) for par in parameters_for_update])
598 | print("Total number of parameters, w/o usage of bn consts: ", total_size_params)
599 |
600 | optimizer = optim.SGD(parameters_for_update, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay)
601 |
602 | if 1:
603 | # helper optimizer that implements group lasso (with a weight small enough not to affect training);
604 | # it is used to calculate the number of remaining flops and parameters in the network
605 | group_wd_optimizer = group_lasso_decay(parameters_for_update, group_lasso_weight=args.group_wd_coeff, named_parameters=parameters_for_update_named, output_sizes=output_sizes)
606 |
607 | cudnn.benchmark = True
608 |
609 | # define objective
610 | criterion = nn.CrossEntropyLoss()
611 |
612 | ###=======================added for pruning
613 | # logging part
614 | log_save_folder = "%s"%args.name
615 | if not os.path.exists(log_save_folder):
616 | os.makedirs(log_save_folder)
617 |
618 | if not os.path.exists("%s/models" % (log_save_folder)):
619 | os.makedirs("%s/models" % (log_save_folder))
620 |
621 | train_writer = None
622 | if args.tensorboard:
623 | try:
624 | # tensorboardX v1.6
625 | train_writer = SummaryWriter(log_dir="%s"%(log_save_folder))
626 | except:
627 | # tensorboardX v1.7
628 | train_writer = SummaryWriter(logdir="%s"%(log_save_folder))
629 |
630 | time_point = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
631 | textfile = "%s/log_%s.txt" % (log_save_folder, time_point)
632 | stdout = Logger(textfile)
633 | sys.stdout = stdout
634 | print(" ".join(sys.argv))
635 |
636 | # initializing parameters for pruning
637 | # we can add weights of different layers, or we can add gates (which multiply the output by 1 and are useful only for gradient computation)
638 | pruning_engine = None
639 | if args.pruning:
640 | pruning_settings = dict()
641 | if not (args.pruning_config is None):
642 | pruning_settings_reader = PruningConfigReader()
643 | pruning_settings_reader.read_config(args.pruning_config)
644 | pruning_settings = pruning_settings_reader.get_parameters()
645 |
646 | # overwrite parameters from config file with those from command line
647 | # needs manual entry here
648 | # user_specified = [key for key in vars(default_args).keys() if not (vars(default_args)[key]==vars(args)[key])]
649 | # argv_of_interest = ['pruning_threshold', 'pruning-momentum', 'pruning_step', 'prune_per_iteration',
650 | # 'fixed_layer', 'start_pruning_after_n_iterations', 'maximum_pruning_iterations',
651 | # 'starting_neuron', 'prune_neurons_max', 'pruning_method']
652 |
653 | has_attribute = lambda x: any([x in a for a in sys.argv])
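# has_attribute('pruning-momentum') is True iff some command-line token contains
# that substring, so it matches both '--pruning-momentum=0.9' and
# '--pruning-momentum 0.9'; this is a substring check, not exact flag parsing.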
654 |
655 | if has_attribute('pruning-momentum'):
656 | pruning_settings['pruning_momentum'] = vars(args)['pruning_momentum']
657 | if has_attribute('pruning-method'):
658 | pruning_settings['method'] = vars(args)['pruning_method']
659 |
660 | pruning_parameters_list = prepare_pruning_list(pruning_settings, model, model_name=args.model,
661 | pruning_mask_from=args.pruning_mask_from, name=args.name)
662 | print("Total pruning layers:", len(pruning_parameters_list))
663 |
664 | folder_to_write = "%s"%log_save_folder+"/"
665 | log_folder = folder_to_write
666 |
667 | pruning_engine = pytorch_pruning(pruning_parameters_list, pruning_settings=pruning_settings, log_folder=log_folder)
668 |
669 | pruning_engine.connect_tensorboard(train_writer)
670 | pruning_engine.dataset = args.dataset
671 | pruning_engine.model = args.model
672 | pruning_engine.pruning_mask_from = args.pruning_mask_from
673 | pruning_engine.load_mask()
674 | gates_to_params = connect_gates_with_parameters_for_flops(args.model, parameters_for_update_named)
675 | pruning_engine.gates_to_params = gates_to_params
676 |
677 | ###=======================end for pruning
678 | # loading model file
679 | if (len(args.load_model) > 0) and (not args.dynamic_network):
680 | if os.path.isfile(args.load_model):
681 | load_model_pytorch(model, args.load_model, args.model)
682 | else:
683 | print("=> no checkpoint found at '{}'".format(args.load_model))
684 | exit()
685 |
686 | if args.tensorboard and 0:
687 | if args.dataset == "CIFAR10":
688 | dummy_input = torch.rand(1, 3, 32, 32).to(device)
689 | elif args.dataset == "Imagenet":
690 | dummy_input = torch.rand(1, 3, 224, 224).to(device)
691 |
692 | train_writer.add_graph(model, dummy_input)
693 |
694 | for epoch in range(1, args.epochs + 1):
695 | if args.distributed:
696 | train_sampler.set_epoch(epoch)
697 | adjust_learning_rate(args, optimizer, epoch, args.zero_lr_for_epochs, train_writer)
698 |
699 | if not args.run_test and not args.get_inference_time:
700 | train(args, model, device, train_loader, optimizer, epoch, criterion, train_writer=train_writer, pruning_engine=pruning_engine)
701 |
702 | if args.pruning:
703 | # skip validation error calculation and model saving
704 | if pruning_engine.method == 50: continue
705 |
706 | # evaluate on validation set
707 | prec1, _ = validate(args, test_loader, model, device, criterion, epoch, train_writer=train_writer)
708 |
709 | # remember best prec@1 and save checkpoint
710 | is_best = prec1 > best_prec1
711 | best_prec1 = max(prec1, best_prec1)
712 | model_save_path = "%s/models/checkpoint.weights"%(log_save_folder)
713 | model_state_dict = model.state_dict()
714 | if args.save_models:
715 | save_checkpoint({
716 | 'epoch': epoch + 1,
717 | 'state_dict': model_state_dict,
718 | 'best_prec1': best_prec1,
719 | }, is_best, filename=model_save_path)
720 |
721 |
722 | if __name__ == '__main__':
723 | main()
724 |
--------------------------------------------------------------------------------
/models/densenet_imagenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | # based on https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
7 |
8 | import re
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.utils.model_zoo as model_zoo
13 | from collections import OrderedDict
14 | from layers.gate_layer import GateLayer
15 |
16 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
17 |
18 |
19 | model_urls = {
20 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
21 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
22 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
23 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
24 | }
25 |
26 |
27 |
28 | def DenseNet121(pretrained=False, **kwargs):
29 | r"""Densenet-121 model from
30 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
31 |
32 | Args:
33 | pretrained (bool): If True, returns a model pre-trained on ImageNet
34 | """
35 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
36 | **kwargs)
37 | if pretrained:
38 | # '.'s are no longer allowed in module names, but previous _DenseLayer
39 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
40 | # They are also in the checkpoints in model_urls. This pattern is used
41 | # to find such keys.
42 | pattern = re.compile(
43 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
44 | state_dict = model_zoo.load_url(model_urls['densenet121'])
45 | for key in list(state_dict.keys()):
46 | res = pattern.match(key)
47 | if res:
48 | new_key = res.group(1) + res.group(2)
49 | state_dict[new_key] = state_dict[key]
50 | del state_dict[key]
51 | model.load_state_dict(state_dict, strict=False)
52 | return model
53 |
54 |
55 | def DenseNet169(pretrained=False, **kwargs):
56 | r"""Densenet-169 model from
57 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
58 |
59 | Args:
60 | pretrained (bool): If True, returns a model pre-trained on ImageNet
61 | """
62 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
63 | **kwargs)
64 | if pretrained:
65 | # '.'s are no longer allowed in module names, but previous _DenseLayer
66 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
67 | # They are also in the checkpoints in model_urls. This pattern is used
68 | # to find such keys.
69 | pattern = re.compile(
70 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
71 | state_dict = model_zoo.load_url(model_urls['densenet169'])
72 | for key in list(state_dict.keys()):
73 | res = pattern.match(key)
74 | if res:
75 | new_key = res.group(1) + res.group(2)
76 | state_dict[new_key] = state_dict[key]
77 | del state_dict[key]
78 | model.load_state_dict(state_dict, strict=False)
79 | return model
80 |
81 |
82 | def DenseNet201(pretrained=False, **kwargs):
83 | r"""Densenet-201 model from
84 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
85 |
86 | Args:
87 | pretrained (bool): If True, returns a model pre-trained on ImageNet
88 | """
89 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
90 | **kwargs)
91 | if pretrained:
92 | # '.'s are no longer allowed in module names, but previous _DenseLayer
93 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
94 | # They are also in the checkpoints in model_urls. This pattern is used
95 | # to find such keys.
96 | pattern = re.compile(
97 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
98 | state_dict = model_zoo.load_url(model_urls['densenet201'], model_dir='models/pretrained')
99 | for key in list(state_dict.keys()):
100 | res = pattern.match(key)
101 | if res:
102 | new_key = res.group(1) + res.group(2)
103 | state_dict[new_key] = state_dict[key]
104 | del state_dict[key]
105 | model.load_state_dict(state_dict, strict=False)
106 | return model
107 |
108 |
109 | def densenet161(pretrained=False, **kwargs):
110 | r"""Densenet-161 model from
111 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
112 |
113 | Args:
114 | pretrained (bool): If True, returns a model pre-trained on ImageNet
115 | """
116 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
117 | **kwargs)
118 | if pretrained:
119 | # '.'s are no longer allowed in module names, but previous _DenseLayer
120 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
121 | # They are also in the checkpoints in model_urls. This pattern is used
122 | # to find such keys.
123 | pattern = re.compile(
124 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
125 | state_dict = model_zoo.load_url(model_urls['densenet161'])
126 | for key in list(state_dict.keys()):
127 | res = pattern.match(key)
128 | if res:
129 | new_key = res.group(1) + res.group(2)
130 | state_dict[new_key] = state_dict[key]
131 | del state_dict[key]
132 | model.load_state_dict(state_dict, strict=False)
133 | return model
134 |
135 |
136 | class _DenseLayer(nn.Sequential):
137 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, gate_types):
138 | super(_DenseLayer, self).__init__()
139 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
140 | self.add_module('relu1', nn.ReLU(inplace=True)),
141 | if 'input' in gate_types:
142 | self.add_module('gate1): (input', GateLayer(num_input_features,num_input_features,[1, -1, 1, 1]))
143 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
144 | growth_rate, kernel_size=1, stride=1, bias=False)),
145 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
146 | if 'output_bn' in gate_types:
147 | self.add_module('gate2): (output_bn', GateLayer(bn_size * growth_rate,bn_size * growth_rate,[1, -1, 1, 1]))
148 | self.add_module('relu2', nn.ReLU(inplace=True)),
149 |
150 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
151 | kernel_size=3, stride=1, padding=1, bias=False)),
152 | if 'output_conv' in gate_types:
153 | self.add_module('gate3): (output_conv', GateLayer(growth_rate,growth_rate,[1, -1, 1, 1]))
154 | self.drop_rate = drop_rate
155 |
156 | def forward(self, x):
157 | new_features = super(_DenseLayer, self).forward(x)
158 | if self.drop_rate > 0:
159 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
160 | return torch.cat([x, new_features], 1)
161 |
162 |
163 | class _DenseBlock(nn.Sequential):
164 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, gate_types):
165 | super(_DenseBlock, self).__init__()
166 | for i in range(num_layers):
167 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, gate_types)
168 | self.add_module('denselayer%d' % (i + 1), layer)
169 |
170 |
171 | class _Transition(nn.Sequential):
172 | def __init__(self, num_input_features, num_output_features, gate_types):
173 | super(_Transition, self).__init__()
174 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
175 | self.add_module('relu', nn.ReLU(inplace=True))
176 | if 'input' in gate_types:
177 | self.add_module('gate): (input', GateLayer(num_input_features,num_input_features,[1, -1, 1, 1]))
178 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
179 | kernel_size=1, stride=1, bias=False))
180 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
181 | if 'output_conv' in gate_types:
182 | self.add_module('gate): (output_conv', GateLayer(num_output_features,num_output_features,[1, -1, 1, 1]))
183 |
184 |
185 | class DenseNet(nn.Module):
186 | r"""Densenet-BC model class, based on
187 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
188 |
189 | Args:
190 | growth_rate (int) - how many filters to add each layer (`k` in paper)
191 | block_config (list of 4 ints) - how many layers in each pooling block
192 | num_init_features (int) - the number of filters to learn in the first convolution layer
193 | bn_size (int) - multiplicative factor for number of bottle neck layers
194 | (i.e. bn_size * k features in the bottleneck layer)
195 | drop_rate (float) - dropout rate after each dense layer
196 | num_classes (int) - number of classification classes
197 | """
198 |
199 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
200 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000,
201 | gate_types=['input','output_bn','output_conv','bottom','top']):
202 |
203 | super(DenseNet, self).__init__()
204 |
205 | # First convolution
206 | self.features = nn.Sequential(OrderedDict([
207 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
208 | ('norm0', nn.BatchNorm2d(num_init_features)),
209 | ('relu0', nn.ReLU(inplace=True)),
210 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
211 | ]))
212 | if 'bottom' in gate_types:
213 | self.features.add_module('gate0): (bottom', GateLayer(num_init_features,num_init_features,[1, -1, 1, 1]))
214 |
215 | # Each denseblock
216 | num_features = num_init_features
217 | for i, num_layers in enumerate(block_config):
218 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
219 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate,
220 | gate_types=gate_types)
221 | self.features.add_module('denseblock%d' % (i + 1), block)
222 | num_features = num_features + num_layers * growth_rate
223 | if i != len(block_config) - 1:
224 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2,
225 | gate_types=gate_types)
226 | self.features.add_module('transition%d' % (i + 1), trans)
227 | num_features = num_features // 2
228 |
229 | # Final batch norm
230 | self.features.add_module('norm5', nn.BatchNorm2d(num_features))
231 | if 'top' in gate_types:
232 | self.features.add_module('gate5): (top', GateLayer(num_features,num_features,[1, -1, 1, 1]))
233 |
234 | # Linear layer
235 | self.classifier = nn.Linear(num_features, num_classes)
236 |
237 | # Official init from torch repo.
238 | for m in self.modules():
239 | if isinstance(m, nn.Conv2d):
240 | nn.init.kaiming_normal_(m.weight)
241 | elif isinstance(m, nn.BatchNorm2d):
242 | nn.init.constant_(m.weight, 1)
243 | nn.init.constant_(m.bias, 0)
244 | elif isinstance(m, nn.Linear):
245 | nn.init.constant_(m.bias, 0)
246 |
247 | def forward(self, x):
248 | features = self.features(x)
249 | out = F.relu(features, inplace=True)
250 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)
251 | out = self.classifier(out)
252 | return out
253 |
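# Quick shape check (illustrative): with the default 1000-class head and the
# 7x7 average pool in forward(), a 224x224 input yields logits of size [N, 1000]:
#
#   net = DenseNet121()
#   y = net(torch.rand(1, 3, 224, 224))
#   print(y.size())  # expected: torch.Size([1, 1000])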
--------------------------------------------------------------------------------
/models/lenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | # based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/lenet.py
7 |
8 | '''LeNet in PyTorch.'''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from layers.gate_layer import GateLayer
13 |
14 |
15 | class LeNet(nn.Module):
16 | def __init__(self, dataset="CIFAR10"):
17 | super(LeNet, self).__init__()
18 | if dataset=="CIFAR10":
19 | nunits_input = 3
20 | nunits_fc = 32*5*5
21 | elif dataset=="MNIST":
22 | nunits_input = 1
23 | nunits_fc = 32*4*4
24 | self.conv1 = nn.Conv2d(nunits_input, 16, 5)
25 | self.gate1 = GateLayer(16,16,[1, -1, 1, 1])
26 | self.conv2 = nn.Conv2d(16, 32, 5)
27 | self.gate2 = GateLayer(32,32,[1, -1, 1, 1])
28 | self.fc1 = nn.Linear(nunits_fc, 120)
29 | self.gate3 = GateLayer(120,120,[1, -1])
30 | self.fc2 = nn.Linear(120, 84)
31 | self.gate4 = GateLayer(84,84,[1, -1])
32 | self.fc3 = nn.Linear(84, 10)
33 |
34 | def forward(self, x):
35 | out = F.relu(self.conv1(x))
36 | out = F.max_pool2d(out, 2)
37 | out = self.gate1(out)
38 | out = F.relu(self.conv2(out))
39 | out = F.max_pool2d(out, 2)
40 | out = self.gate2(out)
41 | out = out.view(out.size(0), -1)
42 | out = F.relu(self.fc1(out))
43 | out = self.gate3(out)
44 | out = F.relu(self.fc2(out))
45 | out = self.gate4(out)
46 | out = self.fc3(out)
47 | return out
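# Quick shape check (illustrative), mirroring the commented test() in
# models/preact_resnet.py:
#
#   net = LeNet(dataset="MNIST")
#   y = net(torch.randn(1, 1, 28, 28))
#   print(y.size())  # expected: torch.Size([1, 10])
#
# 28x28 -> conv 5x5 -> 24 -> pool -> 12 -> conv 5x5 -> 8 -> pool -> 4, so the
# flattened input to fc1 is 32*4*4, matching the MNIST branch in __init__.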
--------------------------------------------------------------------------------
/models/preact_resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | # based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py
7 |
8 | '''Pre-activation ResNet in PyTorch.
9 | Reference:
10 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
11 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027
12 | '''
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from layers.gate_layer import GateLayer
18 |
19 |
20 | def norm2d(planes, num_groups=32):
21 | if num_groups != 0:
22 | print("num_groups:{}".format(num_groups))
23 | if num_groups > 0:
24 | return GroupNorm2D(planes, num_groups)
25 | else:
26 | return nn.BatchNorm2d(planes)
27 |
28 |
29 | class PreActBlock(nn.Module):
30 | '''Pre-activation version of the BasicBlock.'''
31 | expansion = 1
32 |
33 | def __init__(self, in_planes, planes, stride=1, group_norm=0):
34 | super(PreActBlock, self).__init__()
35 | self.bn1 = norm2d(in_planes, group_norm)
36 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
37 | self.gate1 = GateLayer(planes,planes,[1, -1, 1, 1])
38 | self.bn2 = norm2d(planes, group_norm)
39 |
40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
41 | self.gate_out = GateLayer(planes,planes,[1, -1, 1, 1])
42 |
43 | if stride != 1 or in_planes != self.expansion*planes:
44 | self.shortcut = nn.Sequential(
45 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
46 | )
47 | self.gate_shortcut = GateLayer(self.expansion*planes,self.expansion*planes,[1, -1, 1, 1])
48 |
49 | def forward(self, x):
50 | out = F.relu(self.bn1(x))
51 |
52 | if hasattr(self, 'shortcut'):
53 | shortcut = self.shortcut(out)
54 | shortcut = self.gate_shortcut(shortcut)
55 | else:
56 | shortcut = x
57 |
58 | out = self.conv1(out)
59 |
60 | out = self.bn2(out)
61 | out = self.gate1(out)
62 |
63 | out = F.relu(out)
64 | out = self.conv2(out)
65 | out = self.gate_out(out)
66 |
67 | out = out + shortcut
68 | ##as a block here we might benefit with gate at this stage
69 |
70 | return out
71 |
72 |
73 | class PreActBottleneck(nn.Module):
74 | '''Pre-activation version of the original Bottleneck module.'''
75 | expansion = 4
76 |
77 | def __init__(self, in_planes, planes, stride=1, group_norm=0):
78 | super(PreActBottleneck, self).__init__()
79 |
80 | self.bn1 = norm2d(in_planes, group_norm)
81 |
82 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
83 | self.bn2 = norm2d(planes, group_norm)
84 | self.gate1 = GateLayer(planes,planes,[1, -1, 1, 1])
85 |
86 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
87 | self.bn3 = norm2d(planes, group_norm)
88 | self.gate2 = GateLayer(planes,planes,[1, -1, 1, 1])
89 |
90 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
91 | self.gate3 = GateLayer(self.expansion*planes,self.expansion*planes,[1, -1, 1, 1])
92 |
93 | if stride != 1 or in_planes != self.expansion*planes:
94 | self.shortcut = nn.Sequential(
95 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False))
96 | self.gate_shortcut = GateLayer(self.expansion*planes,self.expansion*planes,[1, -1, 1, 1])
97 |
98 | def forward(self, x):
99 | out = F.relu(self.bn1(x))
100 | input_out = out
101 |
102 | out = self.conv1(out)
103 | out = self.bn2(out)
104 | out = self.gate1(out)
105 | out = F.relu(out)
106 |
107 | out = self.conv2(out)
108 | out = self.bn3(out)
109 | out = self.gate2(out)
110 |
111 | out = F.relu(out)
112 |
113 | out = self.conv3(out)
114 | out = self.gate3(out)
115 |
116 | if hasattr(self, 'shortcut'):
117 | shortcut = self.shortcut(input_out)
118 | shortcut = self.gate_shortcut(shortcut)
119 | else:
120 | shortcut = x
121 |
122 | out = out + shortcut
123 | return out
124 |
125 |
126 | class PreActResNet(nn.Module):
127 | def __init__(self, block, num_blocks, num_classes=10, group_norm=0, dataset="CIFAR10"):
128 | super(PreActResNet, self).__init__()
129 |
130 | self.in_planes = 64
131 | self.dataset = dataset
132 |
133 | if dataset == "CIFAR10":
134 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
135 | num_classes = 10
136 | elif dataset == "Imagenet":
137 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
138 | bias=False)
139 | self.bn1 = nn.BatchNorm2d(64)
140 | self.relu = nn.ReLU(inplace=True)
141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
142 | num_classes = 1000
143 |
144 | self.gate_in = GateLayer(64, 64, [1, -1, 1, 1])
145 |
146 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, group_norm=group_norm)
147 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, group_norm=group_norm)
148 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, group_norm=group_norm)
149 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, group_norm=group_norm)
150 |
151 | if dataset == "CIFAR10":
152 | self.avgpool = nn.AvgPool2d(4, stride=1)
153 | self.linear = nn.Linear(512*block.expansion, num_classes)
154 | elif dataset == "Imagenet":
155 | self.avgpool = nn.AvgPool2d(7, stride=1)
156 | self.fc = nn.Linear(512 * block.expansion, num_classes)
157 |
158 |
159 |
160 | for m in self.modules():
161 | if isinstance(m, nn.Conv2d):
162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
163 | if isinstance(m, nn.BatchNorm2d):
164 | m.bias.data.zero_()
165 |
166 | def _make_layer(self, block, planes, num_blocks, stride, group_norm = 0):
167 | strides = [stride] + [1]*(num_blocks-1)
168 | layers = []
169 | for stride in strides:
170 | layers.append(block(self.in_planes, planes, stride, group_norm = group_norm))
171 | self.in_planes = planes * block.expansion
172 | return nn.Sequential(*layers)
173 |
174 | def forward(self, x):
175 | out = self.conv1(x)
176 | if self.dataset == "Imagenet":
177 | out = self.bn1(out)
178 | out = self.relu(out)
179 | out = self.maxpool(out)
180 |
181 | out = self.gate_in(out)
182 |
183 | out = self.layer1(out)
184 | out = self.layer2(out)
185 | out = self.layer3(out)
186 | out = self.layer4(out)
187 | out = self.avgpool(out)
188 |
189 | out = out.view(out.size(0), -1)
190 | if self.dataset == "CIFAR10":
191 | out = self.linear(out)
192 | elif self.dataset == "Imagenet":
193 | out = self.fc(out)
194 |
195 | return out
196 |
197 |
198 | def PreActResNet18(group_norm = 0):
199 | return PreActResNet(PreActBlock, [2,2,2,2], group_norm= group_norm)
200 |
201 | def PreActResNet34():
202 | return PreActResNet(PreActBlock, [3,4,6,3])
203 |
204 | def PreActResNet50(group_norm=0, dataset = "CIFAR10"):
205 | return PreActResNet(PreActBottleneck, [3,4,6,3], group_norm = group_norm, dataset = dataset)
206 |
207 | def PreActResNet101():
208 | return PreActResNet(PreActBottleneck, [3,4,23,3])
209 |
210 | def PreActResNet152():
211 | return PreActResNet(PreActBottleneck, [3,8,36,3])
212 |
213 |
214 | def test():
215 | net = PreActResNet18()
216 | y = net(torch.randn(1,3,32,32))
217 | print(y.size())
218 |
219 | # test()
220 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | # based on https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
7 |
8 | import torch.nn as nn
9 | import torch.utils.model_zoo as model_zoo
10 | from layers.gate_layer import GateLayer
11 |
12 |
13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
14 | 'resnet152']
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
23 | }
24 |
25 |
26 | def conv3x3(in_planes, out_planes, stride=1):
27 | """3x3 convolution with padding"""
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
29 | padding=1, bias=False)
30 |
31 |
32 | class BasicBlock(nn.Module):
33 | expansion = 1
34 |
35 | def __init__(self, inplanes, planes, stride=1, downsample=None, gate=None):
36 | super(BasicBlock, self).__init__()
37 | self.conv1 = conv3x3(inplanes, planes, stride)
38 | self.bn1 = nn.BatchNorm2d(planes)
39 | self.relu = nn.ReLU(inplace=True)
40 | self.gate1 = GateLayer(planes,planes,[1, -1, 1, 1])
41 | self.conv2 = conv3x3(planes, planes)
42 | self.bn2 = nn.BatchNorm2d(planes)
43 | self.gate2 = GateLayer(planes,planes,[1, -1, 1, 1])
44 | self.downsample = downsample
45 | self.stride = stride
46 | self.gate = gate
47 |
48 | def forward(self, x):
49 | residual = x
50 |
51 | out = self.conv1(x)
52 | out = self.bn1(out)
53 | out = self.gate1(out)
54 | out = self.relu(out)
55 |
56 | out = self.conv2(out)
57 | out = self.bn2(out)
58 | out = self.gate2(out)
59 |
60 | if self.downsample is not None:
61 | residual = self.downsample(x)
62 |
63 | out += residual
64 | out = self.relu(out)
65 |
66 | if self.gate is not None:
67 | out = self.gate(out)
68 |
69 | return out
70 |
71 |
72 | class Bottleneck(nn.Module):
73 | expansion = 4
74 |
75 | def __init__(self, inplanes, planes, stride=1, downsample=None, gate=None):
76 | super(Bottleneck, self).__init__()
77 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
78 | self.bn1 = nn.BatchNorm2d(planes)
79 | self.gate1 = GateLayer(planes,planes,[1, -1, 1, 1])
80 | self.relu1 = nn.ReLU(inplace=True)
81 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
82 | padding=1, bias=False)
83 | self.bn2 = nn.BatchNorm2d(planes)
84 | self.gate2 = GateLayer(planes,planes,[1, -1, 1, 1])
85 | self.relu2 = nn.ReLU(inplace=True)
86 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
87 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
88 | self.relu3 = nn.ReLU(inplace=True)
89 | self.downsample = downsample
90 | self.stride = stride
91 | self.gate = gate
92 |
93 | def forward(self, x):
94 | residual = x
95 |
96 | out = self.conv1(x)
97 | out = self.bn1(out)
98 | out = self.gate1(out)
99 |
100 | out = self.relu1(out)
101 |
102 | out = self.conv2(out)
103 | out = self.bn2(out)
104 | out = self.gate2(out)
105 | out = self.relu2(out)
106 |
107 | out = self.conv3(out)
108 | out = self.bn3(out)
109 |
110 | if self.downsample is not None:
111 | residual = self.downsample(x)
112 |
113 | out += residual
114 | out = self.relu3(out)
115 | if self.gate is not None:
116 | out = self.gate(out)
117 |
118 | return out
119 |
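# Note: gate1 and gate2 sit directly after their batch norms ("Gate after BN" in
# the method table of pruning_engine.py); the expansion conv3/bn3 has no private
# gate because its channels are gated jointly with the residual stream by the
# shared skip gate passed in as self.gate.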
120 |
121 | class ResNet(nn.Module):
122 |
123 | def __init__(self, block, layers, num_classes=1000, skip_gate = True):
124 | self.inplanes = 64
125 | super(ResNet, self).__init__()
126 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
127 | bias=False)
128 | self.bn1 = nn.BatchNorm2d(64)
129 | self.relu = nn.ReLU(inplace=True)
130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
131 |
132 | gate = skip_gate
133 | self.gate = gate
134 | if gate:
135 | # self.gate_skip1 = GateLayer(64,64,[1, -1, 1, 1])
136 | self.gate_skip64 = GateLayer(64*4,64*4,[1, -1, 1, 1])
137 | self.gate_skip128 = GateLayer(128*4,128*4,[1, -1, 1, 1])
138 | self.gate_skip256 = GateLayer(256*4,256*4,[1, -1, 1, 1])
139 | self.gate_skip512 = GateLayer(512*4,512*4,[1, -1, 1, 1])
140 | if block == BasicBlock:
141 | self.gate_skip64 = GateLayer(64, 64, [1, -1, 1, 1])
142 | self.gate_skip128 = GateLayer(128, 128, [1, -1, 1, 1])
143 | self.gate_skip256 = GateLayer(256, 256, [1, -1, 1, 1])
144 | self.gate_skip512 = GateLayer(512, 512, [1, -1, 1, 1])
145 | else:
146 | self.gate_skip64 = None
147 | self.gate_skip128 = None
148 | self.gate_skip256 = None
149 | self.gate_skip512 = None
150 |
151 | self.layer1 = self._make_layer(block, 64, layers[0], gate = self.gate_skip64)
152 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, gate=self.gate_skip128)
153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, gate=self.gate_skip256)
154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, gate=self.gate_skip512)
155 | self.avgpool = nn.AvgPool2d(7, stride=1)
156 | self.fc = nn.Linear(512 * block.expansion, num_classes)
157 |
158 | for m in self.modules():
159 | if isinstance(m, nn.Conv2d):
160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
161 | elif isinstance(m, nn.BatchNorm2d):
162 | nn.init.constant_(m.weight, 1)
163 | nn.init.constant_(m.bias, 0)
164 |
165 | def _make_layer(self, block, planes, blocks, stride=1, gate = None):
166 | downsample = None
167 | if stride != 1 or self.inplanes != planes * block.expansion:
168 | downsample = nn.Sequential(
169 | nn.Conv2d(self.inplanes, planes * block.expansion,
170 | kernel_size=1, stride=stride, bias=False),
171 | nn.BatchNorm2d(planes * block.expansion),
172 | )
173 |
174 | layers = []
175 | layers.append(block(self.inplanes, planes, stride, downsample, gate = gate))
176 |
177 | self.inplanes = planes * block.expansion
178 | for i in range(1, blocks):
179 | layers.append(block(self.inplanes, planes, gate = gate))
180 |
181 | return nn.Sequential(*layers)
182 |
183 | def forward(self, x):
184 | x = self.conv1(x)
185 | x = self.bn1(x)
186 | x = self.relu(x)
187 | x = self.maxpool(x)
188 |
189 | # if self.gate:
190 | # x=self.gate_skip1(x)
191 |
192 | x = self.layer1(x)
193 | x = self.layer2(x)
194 | x = self.layer3(x)
195 | x = self.layer4(x)
196 |
197 | x = self.avgpool(x)
198 | x = x.view(x.size(0), -1)
199 | x = self.fc(x)
200 |
201 | return x
202 |
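# Design note: a single gate instance (gate_skip64..gate_skip512) is shared by all
# blocks of a stage, so pruning a channel on the skip connection removes the same
# channel everywhere it is added into the residual stream. Because every block
# contributes a gradient to the shared gate, method 22 in pruning_engine.py divides
# the accumulated criterion for these gates by the block count of the stage
# (e.g. 3, 4, 23, 3 for resnet101).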
203 |
204 | def resnet18(pretrained=False, **kwargs):
205 | """Constructs a ResNet-18 model.
206 |
207 | Args:
208 | pretrained (bool): If True, returns a model pre-trained on ImageNet
209 | """
210 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
211 | if pretrained:
212 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
213 | return model
214 |
215 |
216 | def resnet34(pretrained=False, **kwargs):
217 | """Constructs a ResNet-34 model.
218 |
219 | Args:
220 | pretrained (bool): If True, returns a model pre-trained on ImageNet
221 | """
222 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
223 | if pretrained:
224 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
225 | return model
226 |
227 |
228 | def resnet50(pretrained=False, **kwargs):
229 | """Constructs a ResNet-50 model.
230 |
231 | Args:
232 | pretrained (bool): If True, returns a model pre-trained on ImageNet
233 | """
234 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
235 | if pretrained:
236 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
237 | return model
238 |
239 |
240 | def resnet101(pretrained=False, **kwargs):
241 | """Constructs a ResNet-101 model.
242 |
243 | Args:
244 | pretrained (bool): If True, returns a model pre-trained on ImageNet
245 | """
246 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
247 | if pretrained:
248 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
249 | return model
250 |
251 |
252 | def resnet152(pretrained=False, **kwargs):
253 | """Constructs a ResNet-152 model.
254 |
255 | Args:
256 | pretrained (bool): If True, returns a model pre-trained on ImageNet
257 | """
258 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
259 | if pretrained:
260 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
261 | return model
262 |
--------------------------------------------------------------------------------
/models/vgg_bn.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | # based on https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
7 |
8 | import torch.nn as nn
9 | import torch.utils.model_zoo as model_zoo
10 | from layers.gate_layer import GateLayer
11 |
12 | __all__ = [
13 | 'slimmingvgg',
14 | ]
15 |
16 | model_urls = {
17 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
18 | }
19 |
20 | class LinView(nn.Module):
21 | def __init__(self):
22 | super(LinView, self).__init__()
23 |
24 | def forward(self, x):
25 | return x.view(x.size(0), -1)
26 |
27 |
28 | class VGG(nn.Module):
29 |
30 | def __init__(self, features, cfg, num_classes=1000, init_weights=True):
31 | super(VGG, self).__init__()
32 | self.features = features
33 | self.classifier = nn.Sequential(
34 | nn.Linear(cfg[0] * 7 * 7, cfg[1]),
35 | nn.BatchNorm1d(cfg[1]),
36 | nn.ReLU(True),
37 | nn.Linear(cfg[1],cfg[2]),
38 | nn.BatchNorm1d(cfg[2]),
39 | nn.ReLU(True),
40 | nn.Linear(cfg[2], num_classes)
41 | )
42 | if init_weights:
43 | self._initialize_weights()
44 |
45 | def forward(self, x):
46 | x = self.features(x)
47 | x = x.view(x.size(0), -1)
48 | x = self.classifier(x)
49 | return x
50 |
51 | def _initialize_weights(self):
52 | for m in self.modules():
53 | if isinstance(m, nn.Conv2d):
54 | nn.init.kaiming_normal_(m.weight, mode='fan_out')  # in-place variant; nonlinearity left at its default
55 | if m.bias is not None:
56 | m.bias.data.zero_()
57 | elif isinstance(m, nn.BatchNorm2d):
58 | m.weight.data.fill_(0.5)
59 | m.bias.data.zero_()
60 | elif isinstance(m, nn.Linear):
61 | m.weight.data.normal_(0, 0.01)
62 | m.bias.data.zero_()
63 | elif isinstance(m, nn.BatchNorm1d):
64 | m.weight.data.fill_(0.5)
65 | m.bias.data.zero_()
66 |
67 | def make_layers(cfg, batch_norm=False):
68 | layers = []
69 | in_channels = 3
70 | for v in cfg:
71 | if v == 'M':
72 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
73 | else:
74 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
75 | if batch_norm:
76 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
77 | else:
78 | layers += [conv2d, nn.ReLU(inplace=True)]
79 | in_channels = v
80 | return nn.Sequential(*layers)
81 |
82 | cfg = {
83 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M', 4096, 4096]
84 | }
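# Note: only the conv part of cfg['A'] (everything up to the last two entries) is
# passed to make_layers in slimmingvgg below; the trailing 4096, 4096 are the fully
# connected widths consumed by config2, together with the last conv width (512)
# that sets the classifier input size of cfg[0]*7*7.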
85 |
86 |
87 | def flatten_model(old_net):
88 | """Removes nested modules. Only works for VGG."""
89 | from collections import OrderedDict
90 | module_list, counter, inserted_view = [], 0, False
91 | gate_counter = 0
92 | print("printing network")
93 | print(" Hard codded network in vgg_bn.py")
94 | for m_indx, module in enumerate(old_net.modules()):
95 | if not isinstance(module, (nn.Sequential, VGG)):
96 | print(m_indx, module)
97 | if isinstance(module, nn.Linear) and not inserted_view:
98 | module_list.append(('flatten', LinView()))
99 | inserted_view = True
100 |
101 | # features.0
102 | # classifier
103 | prefix = "features"
104 |
105 | if m_indx > 30:
106 | prefix = "classifier"
107 | if m_indx == 32:
108 | counter = 0
109 |
110 | # prefix = ""
111 |
112 | module_list.append((prefix + str(counter), module))
113 |
114 | if isinstance(module, nn.BatchNorm2d):
115 | planes = module.num_features
116 | gate = GateLayer(planes, planes, [1, -1, 1, 1])
117 | module_list.append(('gate%d' % (gate_counter), gate))
118 | print("gate ", counter, planes)
119 | gate_counter += 1
120 |
121 |
122 | if isinstance(module, nn.BatchNorm1d):
123 | planes = module.num_features
124 | gate = GateLayer(planes, planes, [1, -1])
125 | module_list.append(('gate%d' % (gate_counter), gate))
126 | print("gate ", counter, planes)
127 | gate_counter += 1
128 |
129 |
130 | counter += 1
131 | new_net = nn.Sequential(OrderedDict(module_list))
132 | return new_net
133 |
134 |
135 | def slimmingvgg(pretrained=False, config=None, **kwargs):
136 | """VGG 11-layer model (configuration "A") with batch normalization
137 |
138 | Args:
139 | pretrained (bool): If True, returns a model pre-trained on ImageNet
140 | """
141 | if pretrained:
142 | kwargs['init_weights'] = False
143 | if config is None:
144 | config = cfg['A']
145 | config2 = [config[-4],config[-2],config[-1]]
146 | model = VGG(make_layers(config[:-2], batch_norm=True), config2, **kwargs)
147 | if pretrained:
148 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']),strict=False)
149 | model = flatten_model(model)
150 | return model
151 |
--------------------------------------------------------------------------------
/pruning_engine.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from __future__ import print_function
7 | import os
8 | import time
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import torch.optim
14 | import torch.utils.data
15 |
16 | import numpy as np
17 |
18 |
19 | from copy import deepcopy
20 | import itertools
21 | import pickle
22 | import json
23 |
24 | METHOD_ENCODING = {0: "Taylor_weight", 1: "Random", 2: "Weight norm", 3: "Weight_abs",
25 | 6: "Taylor_output", 10: "OBD", 11: "Taylor_gate_SO",
26 | 22: "Taylor_gate", 23: "Taylor_gate_FG", 30: "BN_weight", 31: "BN_Taylor"}
27 |
28 |
29 | # Each method is encoded as an integer; the mapping is shown above.
30 | # Methods map to the paper as follows:
31 | # 0 - Taylor_weight - conv/linear weight with Taylor FO, in Table 2 and Table 1
32 | # 1 - Random - Random
33 | # 2 - Weight norm - Weight magnitude/ weight
34 | # 3 - Weight_abs - Not used
35 | # 6 - Taylor_output - Taylor-output as is [27]
36 | # 10- OBD - OBD
37 | # 11- Taylor_gate_SO- Taylor SO
38 | # 22- Taylor_gate - Gate after BN in Table 2, Taylor FO in Table 1
39 | # 23- Taylor_gate_FG- uses gradient per example to compute Taylor FO, Taylor FO- FG in Table 1, Gate after BN - FG in Table 2
40 | # 30- BN_weight - BN scale in Table 2
41 | # 31- BN_Taylor - BN scale Taylor FO in Table 2
42 |
43 |
44 | class PruningConfigReader(object):
45 | def __init__(self):
46 | self.pruning_settings = {}
47 | self.config = None
48 |
49 | def read_config(self, filename):
50 | # reads .json file and sets values as pruning_settings for pruning
51 |
52 | with open(filename, "r") as f:
53 | config = json.load(f)
54 |
55 | self.config = config
56 |
57 | self.read_field_value("method", 0)
58 | self.read_field_value("frequency", 500)
59 | self.read_field_value("prune_per_iteration", 2)
60 | self.read_field_value("maximum_pruning_iterations", 10000)
61 | self.read_field_value("starting_neuron", 0)
62 |
63 | self.read_field_value("fixed_layer", -1)
64 | # self.read_field_value("use_momentum", False)
65 |
66 | self.read_field_value("pruning_threshold", 100)
67 | self.read_field_value("start_pruning_after_n_iterations", 0)
68 | # self.read_field_value("use_momentum", False)
69 | self.read_field_value("do_iterative_pruning", True)
70 | self.read_field_value("fixed_criteria", False)
71 | self.read_field_value("seed", 0)
72 | self.read_field_value("pruning_momentum", 0.9)
73 | self.read_field_value("flops_regularization", 0.0)
74 | self.read_field_value("prune_neurons_max", 1)
75 |
76 | self.read_field_value("group_size", 1)
77 |
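# An illustrative config file (keys mirror the read_field_value calls above; the
# values here are made up for the example, not taken from configs/):
#
#   {
#       "method": 22,
#       "frequency": 30,
#       "prune_per_iteration": 100,
#       "maximum_pruning_iterations": 30,
#       "prune_neurons_max": 1500,
#       "pruning_threshold": 100,
#       "group_size": 1
#   }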
78 | def read_field_value(self, key, default):
79 | param = default
80 | if key in self.config:
81 | param = self.config[key]
82 |
83 | self.pruning_settings[key] = param
84 |
85 | def get_parameters(self):
86 | return self.pruning_settings
87 |
88 |
89 | class pytorch_pruning(object):
90 | def __init__(self, parameters, pruning_settings=dict(), log_folder=None):
91 | def initialize_parameter(object_name, settings, key, def_value):
92 | '''
93 | Checks whether key is present in settings and sets it as an attribute; otherwise uses the default value
94 | :param object_name: reference to the object instance
95 | :param settings: dict of settings
96 | :param key: name of the setting to look up
97 | :param def_value: default value stored in the field when the key is missing
97 | :return:
98 | void
99 | '''
100 | value = def_value
101 | if key in settings.keys():
102 | value = settings[key]
103 | setattr(object_name, key, value)
104 |
105 | # store some statistics
106 | self.min_criteria_value = 1e6
107 | self.max_criteria_value = 0.0
108 | self.median_criteria_value = 0.0
109 | self.neuron_units = 0
110 | self.all_neuron_units = 0
111 | self.pruned_neurons = 0
112 | self.gradient_norm_final = 0.0
113 | self.flops_regularization = 0.0 #not used in the paper
114 | self.pruning_iterations_done = 0
115 |
116 | # initialize_parameter(self, pruning_settings, 'use_momentum', False)
117 | initialize_parameter(self, pruning_settings, 'pruning_momentum', 0.9)
118 | initialize_parameter(self, pruning_settings, 'flops_regularization', 0.0)
119 | self.momentum_coeff = self.pruning_momentum
120 | self.use_momentum = self.pruning_momentum > 0.0
121 |
122 | initialize_parameter(self, pruning_settings, 'prune_per_iteration', 1)
123 | initialize_parameter(self, pruning_settings, 'start_pruning_after_n_iterations', 0)
124 | initialize_parameter(self, pruning_settings, 'prune_neurons_max', 0)
125 | initialize_parameter(self, pruning_settings, 'maximum_pruning_iterations', 0)
126 | initialize_parameter(self, pruning_settings, 'pruning_silent', False)
127 | initialize_parameter(self, pruning_settings, 'l2_normalization_per_layer', False)
128 | initialize_parameter(self, pruning_settings, 'fixed_criteria', False)
129 | initialize_parameter(self, pruning_settings, 'starting_neuron', 0)
130 | initialize_parameter(self, pruning_settings, 'frequency', 30)
131 | initialize_parameter(self, pruning_settings, 'pruning_threshold', 100)
132 | initialize_parameter(self, pruning_settings, 'fixed_layer', -1)
133 | initialize_parameter(self, pruning_settings, 'combination_ID', 0)
134 | initialize_parameter(self, pruning_settings, 'seed', 0)
135 | initialize_parameter(self, pruning_settings, 'group_size', 1)
136 |
137 | initialize_parameter(self, pruning_settings, 'method', 0)
138 |
139 | # Hessian related parameters
140 | self.temp_hessian = [] # list to store Hessian
141 | self.hessian_first_time = True
142 |
143 | self.parameters = list()
144 |
145 | ##get pruning parameters
146 | for parameter in parameters:
147 | parameter_value = parameter["parameter"]
148 | self.parameters.append(parameter_value)
149 |
150 | if self.fixed_layer == -1:
151 | ##prune all layers
152 | self.prune_layers = [True for parameter in self.parameters]
153 | else:
154 | ##prune only one layer
155 | self.prune_layers = [False, ]*len(self.parameters)
156 | self.prune_layers[self.fixed_layer] = True
157 |
158 | self.iterations_done = 0
159 |
160 | self.prune_network_criteria = list()
161 | self.prune_network_accomulate = {"by_layer": list(), "averaged": list(), "averaged_cpu": list()}
162 |
163 | self.pruning_gates = list()
164 | for layer in range(len(self.parameters)):
165 | self.prune_network_criteria.append(list())
166 |
167 | for key in self.prune_network_accomulate.keys():
168 | self.prune_network_accomulate[key].append(list())
169 |
170 | self.pruning_gates.append(np.ones(len(self.parameters[layer]),))
171 | layer_now_criteria = self.prune_network_criteria[-1]
172 | for unit in range(len(self.parameters[layer])):
173 | layer_now_criteria.append(0.0)
174 |
175 | # logging setup
176 | self.log_folder = log_folder
177 | self.folder_to_write_debug = self.log_folder + '/debug/'
178 | if not os.path.exists(self.folder_to_write_debug):
179 | os.makedirs(self.folder_to_write_debug)
180 |
181 | self.method_25_first_done = True
182 |
183 | if self.method == 40 or self.method == 50 or self.method == 25:
184 | self.oracle_dict = {"layer_pruning": -1, "initial_loss": 0.0, "loss_list": list(), "neuron": list(), "iterations": 0}
185 | self.method_25_first_done = False
186 |
187 | if self.method == 25:
188 | with open("./utils/study/oracle.pickle","rb") as f:
189 | oracle_list = pickle.load(f)
190 |
191 | self.oracle_dict["loss_list"] = oracle_list
192 |
193 | self.needs_hessian = False
194 | if self.method in [10, 11]:
195 | self.needs_hessian = True
196 |
197 | # useful for storing data of the experiment
198 | self.data_logger = dict()
199 | self.data_logger["pruning_neurons"] = list()
200 | self.data_logger["pruning_accuracy"] = list()
201 | self.data_logger["pruning_loss"] = list()
202 | self.data_logger["method"] = self.method
203 | self.data_logger["prune_per_iteration"] = self.prune_per_iteration
204 | self.data_logger["combination_ID"] = list()
205 | self.data_logger["fixed_layer"] = self.fixed_layer
206 | self.data_logger["frequency"] = self.frequency
207 | self.data_logger["starting_neuron"] = self.starting_neuron
208 | self.data_logger["use_momentum"] = self.use_momentum
209 |
210 | self.data_logger["time_stamp"] = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
211 |
212 | if hasattr(self, 'seed'):
213 | self.data_logger["seed"] = self.seed
214 |
215 | self.data_logger["filename"] = "%s/data_logger_seed_%d_%s.p"%(log_folder, self.data_logger["seed"], self.data_logger["time_stamp"])
216 | if self.method == 50:
217 | self.data_logger["filename"] = "%s/data_logger_seed_%d_neuron_%d_%s.p"%(log_folder, self.starting_neuron, self.data_logger["seed"], self.data_logger["time_stamp"])
218 | self.log_folder = log_folder
219 |
220 | # the rest of initializations
221 | self.pruned_neurons = self.starting_neuron
222 |
223 | self.util_loss_tracker = 0.0
224 | self.util_acc_tracker = 0.0
225 | self.util_loss_tracker_num = 0.0
226 |
227 | self.loss_tracker_exp = ExpMeter()
228 | # stores results of the pruning, 0 - unsuccessful, 1 - successful
229 | self.res_pruning = 0
230 |
231 | self.iter_step = -1
232 |
233 | self.train_writer = None
234 |
235 | self.set_moment_zero = True
236 | self.pruning_mask_from = ""
237 |
238 | def add_criteria(self):
239 | '''
240 | This method adds criteria to global list given batch stats.
241 | '''
242 |
243 | if self.fixed_criteria:
244 | if self.pruning_iterations_done > self.start_pruning_after_n_iterations:
245 | return 0
246 |
247 | for layer, if_prune in enumerate(self.prune_layers):
248 | if not if_prune:
249 | continue
250 |
251 | nunits = self.parameters[layer].size(0)
252 | eps = 1e-8
253 |
254 | if len(self.pruning_mask_from) > 0:
255 | # preload pruning mask
256 | self.method = -1
257 | criteria_for_layer = torch.from_numpy(self.loaded_mask_criteria[layer]).type(torch.FloatTensor).cuda(non_blocking=True)  # 'async' is a reserved word in Python 3.7+; use non_blocking
258 |
259 | if self.method == 0:
260 | # First order Taylor expansion on the weight
261 | criteria_for_layer = (self.parameters[layer]*self.parameters[layer].grad ).data.pow(2).view(nunits,-1).sum(dim=1)
262 | elif self.method == 1:
263 | # random pruning
264 | criteria_for_layer = np.random.uniform(low=0, high=5, size=(nunits,))
265 | elif self.method == 2:
266 | # min weight
267 | criteria_for_layer = self.parameters[layer].pow(2).view(nunits,-1).sum(dim=1).data
268 | elif self.method == 3:
269 | # weight_abs
270 | criteria_for_layer = self.parameters[layer].abs().view(nunits,-1).sum(dim=1).data
271 | elif self.method == 6:
272 | # ICLR2017 Taylor on output of the layer
273 | if 1:
274 | criteria_for_layer = self.parameters[layer].full_grad_iclr2017
275 | criteria_for_layer = criteria_for_layer / (np.linalg.norm(criteria_for_layer) + eps)
276 | elif self.method == 10:
277 | # diagonal of Hessian
278 | criteria_for_layer = (self.parameters[layer] * torch.diag(self.temp_hessian[layer])).data.view(nunits,
279 | -1).sum(
280 | dim=1)
281 | elif self.method == 11:
282 | # second order Taylor expansion for loss change in the network
283 | criteria_for_layer = (-(self.parameters[layer] * self.parameters[layer].grad).data + 0.5 * (
284 | self.parameters[layer] * self.parameters[layer] * torch.diag(
285 | self.temp_hessian[layer])).data).pow(2)
286 |
287 | elif self.method == 22:
288 | # Taylor pruning on gate
289 | criteria_for_layer = (self.parameters[layer]*self.parameters[layer].grad).data.pow(2).view(nunits, -1).sum(dim=1)
290 | if hasattr(self, "dataset"):
291 | # fix for skip connection pruning, gradient will be accumulated instead of being averaged
292 | if self.dataset == "Imagenet":
293 | if hasattr(self, "model"):
294 | if not ("noskip" in self.model):
295 | if "resnet" in self.model:
296 | mult = 3.0
297 | if layer == 1:
298 | mult = 4.0
299 | elif layer == 2:
300 | mult = 23.0 if "resnet101" in self.model else mult
301 | mult = 6.0 if "resnet34" in self.model else mult
302 | mult = 6.0 if "resnet50" in self.model else mult
303 |
304 | criteria_for_layer /= mult
305 |
306 | elif self.method == 23:
307 | # Taylor pruning on gate with computing full gradient
308 | criteria_for_layer = (self.parameters[layer].full_grad.t()).data.pow(2).view(nunits,-1).sum(dim=1)
309 |
310 | elif self.method == 30:
311 | # batch normalization based pruning
312 | # by scale (weight) of the batchnorm
313 | criteria_for_layer = (self.parameters[layer]).data.abs().view(nunits, -1).sum(dim=1)
314 |
315 | elif self.method == 31:
316 | # Taylor FO on BN
317 | if hasattr(self.parameters[layer], "bias"):
318 | criteria_for_layer = (self.parameters[layer]*self.parameters[layer].grad +
319 | self.parameters[layer].bias*self.parameters[layer].bias.grad ).data.pow(2).view(nunits,-1).sum(dim=1)
320 | else:
321 | criteria_for_layer = (
322 | self.parameters[layer] * self.parameters[layer].grad).data.pow(2).view(nunits, -1).sum(dim=1)
323 |
324 | elif self.method == 40:
325 |                 # ORACLE on the fly that reevaluates itself every pruning step
326 | criteria_for_layer = np.asarray(self.oracle_dict["loss_list"][layer]).copy()
327 | self.oracle_dict["loss_list"][layer] = list()
328 | elif self.method == 50:
329 | # combinatorial pruning - evaluates all possibilities of removing N neurons
330 | criteria_for_layer = np.asarray(self.oracle_dict["loss_list"][layer]).copy()
331 | self.oracle_dict["loss_list"][layer] = list()
332 | else:
333 | pass
334 |
335 | if self.iterations_done == 0:
336 | self.prune_network_accomulate["by_layer"][layer] = criteria_for_layer
337 | else:
338 | self.prune_network_accomulate["by_layer"][layer] += criteria_for_layer
339 |
340 | self.iterations_done += 1
341 |
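    # Illustration (a standalone sketch, not engine code): the first-order Taylor
    # criterion used by method 22 for a single gate layer. All names below are
    # toy stand-ins; only the formula matches the branch above.
    #
    #     import torch
    #     gate = torch.ones(64, requires_grad=True)   # hypothetical gate vector
    #     loss = (gate * torch.randn(64)).sum()       # stand-in for the network loss
    #     loss.backward()
    #     criterion = (gate * gate.grad).data.pow(2).view(64, -1).sum(dim=1)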
342 | @staticmethod
343 | def group_criteria(list_criteria_per_layer, group_size=1):
344 | '''
345 |         Function combines criteria per neuron into groups of size group_size.
346 |         Output is a list of groups organized by layers. Length of the output is the number of layers.
347 |         The criterion for a group is computed as the sum of its members' criteria.
348 | Input:
349 | list_criteria_per_layer - list of criteria per neuron organized per layer
350 | group_size - number of neurons per group
351 |
352 | Output:
353 |         groups - groups organized per layer. Each group element is a pair: (indices of neurons, criterion)
354 | '''
355 | groups = list()
356 |
357 | for layer in list_criteria_per_layer:
358 | layer_groups = list()
359 |             indices = np.argsort(layer)
360 |             for group_id in range(int(np.ceil(len(layer)/group_size))):
361 |                 current_group = slice(group_id*group_size, min((group_id+1)*group_size, len(layer)))
362 |                 values = [layer[ind] for ind in indices[current_group]]
363 |                 group = [indices[current_group], sum(values)]
364 |
365 | layer_groups.append(group)
366 | groups.append(layer_groups)
367 |
368 | return groups
369 |
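    # Usage sketch for group_criteria (toy numbers): with group_size=2, the four
    # neurons below are sorted by criterion and paired; each group carries
    # (member indices, summed criterion). Callable on the class or an instance.
    #
    #     per_layer = [np.array([0.3, 0.1, 0.9, 0.2])]
    #     groups = group_criteria(per_layer, group_size=2)
    #     # -> [[[array([1, 3]), 0.3], [array([0, 2]), 1.2]]]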
370 | def compute_saliency(self):
371 | '''
372 | Method performs pruning based on precomputed criteria values. Needs to run after add_criteria()
373 | '''
374 | def write_to_debug(what_write_name, what_write_value):
375 | # Aux function to store information in the text file
376 | with open(self.log_debug, 'a') as f:
377 | f.write("{} {}\n".format(what_write_name,what_write_value))
378 |
379 | def nothing(what_write_name, what_write_value):
380 | pass
381 |
382 | if self.method == 50:
383 | write_to_debug = nothing
384 |
385 |         # compute the loss since the last pruning step and decide whether to prune
386 | if self.util_loss_tracker_num > 0:
387 | validation_error = self.util_loss_tracker / self.util_loss_tracker_num
388 | validation_error_long = validation_error
389 | acc = self.util_acc_tracker / self.util_loss_tracker_num
390 | else:
391 | print("compute loss and run self.util_add_loss(loss.item()) before running this")
392 | validation_error = 0.0
393 | acc = 0.0
394 |
395 | self.util_training_loss = validation_error
396 | self.util_training_acc = acc
397 |
398 | # reset training loss tracker
399 | self.util_loss_tracker = 0.0
400 | self.util_acc_tracker = 0.0
401 | self.util_loss_tracker_num = 0
402 |
403 | if validation_error > self.pruning_threshold:
404 |             # if the error is too large then skip pruning
405 | print("skipping pruning", validation_error, "(%f)"%validation_error_long, self.pruning_threshold)
406 | if self.method != 4:
407 | self.res_pruning = -1
408 | return -1
409 |
410 | if self.maximum_pruning_iterations <= self.pruning_iterations_done:
411 | # if reached max number of pruning iterations -> exit
412 | self.res_pruning = -1
413 | return -1
414 |
415 | self.full_list_of_criteria = list()
416 |
417 | for layer, if_prune in enumerate(self.prune_layers):
418 | if not if_prune:
419 | continue
420 |
421 | if self.iterations_done > 0:
422 |                 # momentum turned out to be useless and even reduces performance
423 | contribution = self.prune_network_accomulate["by_layer"][layer] / self.iterations_done
424 | if self.pruning_iterations_done == 0 or not self.use_momentum or (self.method in [4, 40, 50]):
425 | self.prune_network_accomulate["averaged"][layer] = contribution
426 | else:
427 | # use momentum to accumulate criteria over several pruning iterations:
428 | self.prune_network_accomulate["averaged"][layer] = self.momentum_coeff*self.prune_network_accomulate["averaged"][layer]+(1.0- self.momentum_coeff)*contribution
429 |
430 | current_layer = self.prune_network_accomulate["averaged"][layer]
431 | if not (self.method in [1, 4, 40, 15, 50]):
432 | current_layer = current_layer.cpu().numpy()
433 |
434 | if self.l2_normalization_per_layer:
435 | eps = 1e-8
436 | current_layer = current_layer / (np.linalg.norm(current_layer) + eps)
437 |
438 | self.prune_network_accomulate["averaged_cpu"][layer] = current_layer
439 | else:
440 | print("First do some add_criteria iterations")
441 | exit()
442 |
443 | for unit in range(len(self.parameters[layer])):
444 | criterion_now = current_layer[unit]
445 |
446 | # make sure that pruned neurons have 0 criteria
447 | self.prune_network_criteria[layer][unit] = criterion_now * self.pruning_gates[layer][unit]
448 |
449 | if self.method == 50:
450 | self.prune_network_criteria[layer][unit] = criterion_now
451 |
452 | # count number of neurons
453 | all_neuron_units, neuron_units = self._count_number_of_neurons()
454 | self.neuron_units = neuron_units
455 | self.all_neuron_units = all_neuron_units
456 |
457 | # store criteria_result into file
458 | if not self.pruning_silent:
459 | import pickle
460 | store_criteria = self.prune_network_accomulate["averaged_cpu"]
461 | pickle.dump(store_criteria, open(self.folder_to_write_debug + "criteria_%04d.pickle"%self.pruning_iterations_done, "wb"))
462 | if self.pruning_iterations_done == 0:
463 | pickle.dump(store_criteria, open(self.log_folder + "criteria_%d.pickle"%self.method, "wb"))
464 | pickle.dump(store_criteria, open(self.log_folder + "criteria_%d_final.pickle"%self.method, "wb"))
465 |
466 | if not self.fixed_criteria:
467 | self.iterations_done = 0
468 |
469 | # create groups per layer
470 | groups = self.group_criteria(self.prune_network_criteria, group_size=self.group_size)
471 |
472 | # apply flops regularization
473 | # if self.flops_regularization > 0.0:
474 | # self.apply_flops_regularization(groups, mu=self.flops_regularization)
475 |
476 | # get an array of all criteria from groups
477 | all_criteria = np.asarray([group[1] for layer in groups for group in layer]).reshape(-1)
478 |
479 | prune_neurons_now = (self.pruned_neurons + self.prune_per_iteration)//self.group_size - 1
480 | if self.prune_neurons_max != -1:
481 | prune_neurons_now = min(len(all_criteria)-1, min(prune_neurons_now, self.prune_neurons_max//self.group_size - 1))
482 |
483 |         # adaptively estimate the threshold given the number of neurons to be removed
484 | threshold_now = np.sort(all_criteria)[prune_neurons_now]
485 |
486 | if self.method == 50:
487 | # combinatorial approach
488 | threshold_now = 0.5
489 | self.pruning_iterations_done = self.combination_ID
490 | self.data_logger["combination_ID"].append(self.combination_ID-1)
491 | self.combination_ID += 1
492 | self.reset_oracle_pruning()
493 | print("full_combinatorial: combination ", self.combination_ID)
494 |
495 | self.pruning_iterations_done += 1
496 |
497 | self.log_debug = self.folder_to_write_debug + 'debugOutput_pruning_%08d' % (
498 | self.pruning_iterations_done) + '.txt'
499 | write_to_debug("method", self.method)
500 | write_to_debug("pruned_neurons", self.pruned_neurons)
501 | write_to_debug("pruning_iterations_done", self.pruning_iterations_done)
502 | write_to_debug("neuron_units", neuron_units)
503 | write_to_debug("all_neuron_units", all_neuron_units)
504 | write_to_debug("threshold_now", threshold_now)
505 | write_to_debug("groups_total", sum([len(layer) for layer in groups]))
506 |
507 | if self.pruning_iterations_done < self.start_pruning_after_n_iterations:
508 | self.res_pruning = -1
509 | return -1
510 |
511 | for layer, if_prune in enumerate(self.prune_layers):
512 | if not if_prune:
513 | continue
514 |
515 | write_to_debug("\nLayer:", layer)
516 | write_to_debug("units:", len(self.parameters[layer]))
517 |
518 | if self.prune_per_iteration == 0:
519 | continue
520 |
521 | for group in groups[layer]:
522 | if group[1] <= threshold_now:
523 | for unit in group[0]:
524 | # do actual pruning
525 | self.pruning_gates[layer][unit] *= 0.0
526 | self.parameters[layer].data[unit] *= 0.0
527 |
528 | write_to_debug("pruned_perc:", [np.nonzero(1.0-self.pruning_gates[layer])[0].size, len(self.parameters[layer])])
529 |
530 | # count number of neurons
531 | all_neuron_units, neuron_units = self._count_number_of_neurons()
532 |
533 | self.pruned_neurons = all_neuron_units-neuron_units
534 |
535 | if self.method == 25:
536 | self.method_25_first_done = True
537 |
538 | self.threshold_now = threshold_now
539 | try:
540 | self.min_criteria_value = (all_criteria[all_criteria > 0.0]).min()
541 | self.max_criteria_value = (all_criteria[all_criteria > 0.0]).max()
542 | self.median_criteria_value = np.median(all_criteria[all_criteria > 0.0])
543 |         except ValueError:
544 | self.min_criteria_value = 0.0
545 | self.max_criteria_value = 0.0
546 | self.median_criteria_value = 0.0
547 |
548 | # set result to successful
549 | self.res_pruning = 1
550 |
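    # Sketch of the adaptive threshold used above (toy numbers): to remove k
    # groups, take the k-th smallest criterion as the cut-off.
    #
    #     all_criteria = np.array([0.5, 0.1, 0.9, 0.3])
    #     k = 2
    #     threshold_now = np.sort(all_criteria)[k - 1]   # -> 0.3; groups <= 0.3 get pruned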
551 | def _count_number_of_neurons(self):
552 | '''
553 | Function computes number of total neurons and number of active neurons
554 | :return:
555 | all_neuron_units - number of neurons considered for pruning
556 | neuron_units - number of not pruned neurons in the model
557 | '''
558 | all_neuron_units = 0
559 | neuron_units = 0
560 | for layer, if_prune in enumerate(self.prune_layers):
561 | if not if_prune:
562 | continue
563 |
564 | all_neuron_units += len( self.parameters[layer] )
565 | for unit in range(len( self.parameters[layer] )):
566 | if len(self.parameters[layer].data.size()) > 1:
567 | statistics = self.parameters[layer].data[unit].abs().sum()
568 | else:
569 | statistics = self.parameters[layer].data[unit]
570 |
571 | if statistics > 0.0:
572 | neuron_units += 1
573 |
574 | return all_neuron_units, neuron_units
575 |
576 | def set_weights_oracle_pruning(self):
577 | '''
578 | sets gates/weights to zero to evaluate pruning
579 | will reuse weights for pruning
580 | only for oracle pruning
581 | '''
582 |
583 | for layer,if_prune in enumerate(self.prune_layers_oracle):
584 | if not if_prune:
585 | continue
586 |
587 | if self.method == 40:
588 | self.parameters[layer].data = deepcopy(torch.from_numpy(self.stored_weights).cuda())
589 |
590 | for unit in range(len(self.parameters[layer])):
591 | if self.method == 40:
592 | self.pruning_gates[layer][unit] = 1.0
593 |
594 | if unit == self.oracle_unit:
595 | self.pruning_gates[layer][unit] *= 0.0
596 | self.parameters[layer].data[unit] *= 0.0
597 |
598 | # if 'momentum_buffer' in optimizer.state[self.parameters[layer]].keys():
599 | # optimizer.state[self.parameters[layer]]['momentum_buffer'][unit] *= 0.0
600 | return 1
601 |
602 | def reset_oracle_pruning(self):
603 | '''
604 | Method restores weights to original after masking for Oracle pruning
605 | :return:
606 | '''
607 | for layer, if_prune in enumerate(self.prune_layers_oracle):
608 | if not if_prune:
609 | continue
610 |
611 | if self.method == 40 or self.method == 50:
612 | self.parameters[layer].data = deepcopy(torch.from_numpy(self.stored_weights).cuda())
613 |
614 | for unit in range(len( self.parameters[layer])):
615 | if self.method == 40 or self.method == 50:
616 | self.pruning_gates[layer][unit] = 1.0
617 |
618 | def enforce_pruning(self):
619 | '''
620 |         Method sets parameters and gates to 0 for pruned neurons.
621 |         Helpful if the optimizer moves weights away from zero (due to regularization etc.)
622 | '''
623 | for layer, if_prune in enumerate(self.prune_layers):
624 | if not if_prune:
625 | continue
626 |
627 | for unit in range(len(self.parameters[layer])):
628 | if self.pruning_gates[layer][unit] == 0.0:
629 | self.parameters[layer].data[unit] *= 0.0
630 |
631 | def compute_hessian(self, loss):
632 | '''
633 | Computes Hessian per layer of the loss with respect to self.parameters, currently implemented only for gates
634 | '''
635 |
636 | if self.maximum_pruning_iterations <= self.pruning_iterations_done:
637 | # if reached max number of pruning iterations -> exit
638 | self.res_pruning = -1
639 | return -1
640 |
641 | self.temp_hessian = list()
642 | for layer_indx, parameter in enumerate(self.parameters):
643 | # print("Computing Hessian current/total layers:",layer_indx,"/",len(self.parameters))
644 | if self.prune_layers[layer_indx]:
645 | grad_params = torch.autograd.grad(loss, parameter, create_graph=True)
646 | length_grad = len(grad_params[0])
647 | hessian = torch.zeros(length_grad, length_grad)
648 |
650 | for parameter_loc in range(len(parameter)):
651 | if parameter[parameter_loc].data.cpu().numpy().sum() == 0.0:
652 | continue
653 |
654 | grad_params2 = torch.autograd.grad(grad_params[0][parameter_loc], parameter, create_graph=True)
655 | hessian[parameter_loc, :] = grad_params2[0].data
656 |
657 | else:
658 | length_grad = len(parameter)
659 | hessian = torch.zeros(length_grad, length_grad)
660 |
661 | self.temp_hessian.append(torch.FloatTensor(hessian.cpu().numpy()).cuda())
662 |
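    # Minimal double-backward sketch (independent of the engine) showing how one
    # Hessian row is obtained, as in the loop above:
    #
    #     x = torch.ones(3, requires_grad=True)
    #     loss = (x ** 3).sum()
    #     g = torch.autograd.grad(loss, x, create_graph=True)[0]     # 3*x^2
    #     row0 = torch.autograd.grad(g[0], x, retain_graph=True)[0]  # [6., 0., 0.]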
663 | def run_full_oracle(self, model, data, target, criterion, initial_loss):
664 | '''
665 | Runs oracle on all data by setting to 0 every neuron and running forward pass
666 | '''
667 |
668 | # stop adding data if needed
669 | if self.maximum_pruning_iterations <= self.pruning_iterations_done:
670 | # if reached max number of pruning iterations -> exit
671 | self.res_pruning = -1
672 | return -1
673 |
674 | if self.method == 40:
675 | # for oracle let's try to do the best possible oracle by evaluating all neurons for each batch
676 | self.oracle_dict["initial_loss"] += initial_loss
677 | self.oracle_dict["iterations"] += 1
678 |
679 | # import pdb; pdb.set_trace()
680 | if hasattr(self, 'stored_pruning'):
681 | if self.stored_pruning['use_now']:
682 | # load first list of criteria
683 |                 print("use previously computed priors")
684 | for layer_index, layer_parameters in enumerate(self.parameters):
685 |
686 |                     # start list of estimates for the layer if it is empty
687 |                     if len(self.oracle_dict["loss_list"]) < layer_index + 1:
688 |                         self.oracle_dict["loss_list"].append(list())
689 |
690 |                     if not self.prune_layers[layer_index]:
691 | continue
692 |
693 | self.oracle_dict["loss_list"][layer_index] = self.stored_pruning['criteria_start'][layer_index]
694 | self.pruned_neurons = self.stored_pruning['neuron_start']
695 | return 1
696 |
697 | # do first pass with precomputed values
698 | for layer_index, layer_parameters in enumerate(self.parameters):
699 | # start list of estimates for the layer if it is empty
700 | if len(self.oracle_dict["loss_list"]) < layer_index + 1:
701 | self.oracle_dict["loss_list"].append(list())
702 |
703 | if not self.prune_layers[layer_index]:
704 | continue
705 |             # copy original prune_layer variable that sets layers to be pruned
706 | self.prune_layers_oracle = [False, ]*len(self.parameters)
707 | self.prune_layers_oracle[layer_index] = True
708 | # store weights for future to recover
709 | self.stored_weights = deepcopy(self.parameters[layer_index].data.cpu().numpy())
710 |
711 |             for neuron_id, neuron in enumerate(layer_parameters):
712 |                 # set neuron to zero
713 |                 self.oracle_unit = neuron_id
714 |                 self.set_weights_oracle_pruning()
715 |
716 |                 if self.stored_weights[neuron_id].sum() == 0.0:
717 |                     new_loss = initial_loss
718 |                 else:
719 |                     outputs = model(data)
720 |                     loss = criterion(outputs, target)
721 |                     new_loss = loss.item()
722 |
723 |                 # oracle importance: absolute change in loss
724 |                 oracle_value = abs(initial_loss - new_loss)
725 |                 # relative loss for testing:
726 |                 # oracle_value = initial_loss - new_loss
727 |
728 |                 if len(self.oracle_dict["loss_list"][layer_index]) == 0:
729 |                     self.oracle_dict["loss_list"][layer_index] = [oracle_value, ]
730 |                 elif len(self.oracle_dict["loss_list"][layer_index]) < neuron_id+1:
731 |                     self.oracle_dict["loss_list"][layer_index].append(oracle_value)
732 |                 else:
733 |                     self.oracle_dict["loss_list"][layer_index][neuron_id] += oracle_value
734 |
735 | self.reset_oracle_pruning()
736 |
737 | elif self.method == 50:
738 | if self.pruning_iterations_done == 0:
739 | # store weights again
740 | self.stored_weights = deepcopy(self.parameters[self.fixed_layer].data.cpu().numpy())
741 |
742 | self.set_next_combination()
743 |
744 | else:
745 | pass
746 | # print("Full oracle only works with the methods: {}".format(40))
747 |
748 | def report_loss_neuron(self, training_loss, training_acc, train_writer = None, neurons_left = 0):
749 | '''
750 |         method to store statistics during pruning to the log file
751 | :param training_loss:
752 | :param training_acc:
753 | :param train_writer:
754 | :param neurons_left:
755 | :return:
756 | void
757 | '''
758 | if train_writer is not None:
759 | train_writer.add_scalar('loss_neuron', training_loss, self.all_neuron_units-self.neuron_units)
760 |
761 | self.data_logger["pruning_neurons"].append(self.all_neuron_units-self.neuron_units)
762 | self.data_logger["pruning_loss"].append(training_loss)
763 | self.data_logger["pruning_accuracy"].append(training_acc)
764 |
765 | self.write_log_file()
766 |
767 | def write_log_file(self):
768 | with open(self.data_logger["filename"], "wb") as f:
769 | pickle.dump(self.data_logger, f)
770 |
771 | def load_mask(self):
772 | '''Method loads precomputed criteria for pruning
773 | :return:
774 | '''
775 | if not len(self.pruning_mask_from)>0:
776 | print("pruning_engine.load_mask(): did not find mask file, will load nothing")
777 | else:
778 | if not os.path.isfile(self.pruning_mask_from):
779 | print("pruning_engine.load_mask(): file doesn't exist", self.pruning_mask_from)
780 | print("pruning_engine.load_mask(): check it, exit,", self.pruning_mask_from)
781 | exit()
782 |
783 | with open(self.pruning_mask_from, 'rb') as f:
784 | self.loaded_mask_criteria = pickle.load(f)
785 |
786 | print("pruning_engine.load_mask(): loaded criteria from", self.pruning_mask_from)
787 |
788 | def set_next_combination(self):
789 | '''
790 | For combinatorial pruning only
791 | '''
792 | if self.method == 50:
793 |
794 | self.oracle_dict["iterations"] += 1
795 |
796 | for layer_index, layer_parameters in enumerate(self.parameters):
797 |
798 |                 # start list of estimates for the layer if it is empty
799 |                 if len(self.oracle_dict["loss_list"]) < layer_index + 1:
800 |                     self.oracle_dict["loss_list"].append(list())
801 |
802 |                 if not self.prune_layers[layer_index]:
803 | continue
804 |
805 | nunits = len(layer_parameters)
806 |
807 | comb_num = -1
808 | found_combination = False
809 | for it in itertools.combinations(range(nunits), self.starting_neuron):
810 | comb_num += 1
811 | if comb_num == int(self.combination_ID):
812 | found_combination = True
813 | break
814 |
815 | # import pdb; pdb.set_trace()
816 | if not found_combination:
817 | print("didn't find needed combination, exit")
818 | exit()
819 |
820 |                 self.prune_layers_oracle = [False, ]*len(self.parameters)
822 | self.prune_layers_oracle[layer_index] = True
823 |
824 | criteria_for_layer = np.ones((nunits,))
825 | criteria_for_layer[list(it)] = 0.0
826 |
827 | if len(self.oracle_dict["loss_list"][layer_index]) == 0:
828 | self.oracle_dict["loss_list"][layer_index] = criteria_for_layer
829 | else:
830 | self.oracle_dict["loss_list"][layer_index] += criteria_for_layer
831 |
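    # Note on combination_ID (toy example): it simply indexes into the
    # lexicographic enumeration produced by itertools.combinations.
    #
    #     list(itertools.combinations(range(4), 2))[3]   # -> (1, 2)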
832 | def report_to_tensorboard(self, train_writer, processed_batches):
833 | '''
834 | Log data with tensorboard
835 | '''
836 | gradient_norm_final_before = self.gradient_norm_final
837 | train_writer.add_scalar('Neurons_left', self.neuron_units, processed_batches)
838 | train_writer.add_scalar('Criteria_min', self.min_criteria_value, self.pruning_iterations_done)
839 | train_writer.add_scalar('Criteria_max', self.max_criteria_value, self.pruning_iterations_done)
840 | train_writer.add_scalar('Criteria_median', self.median_criteria_value, self.pruning_iterations_done)
841 | train_writer.add_scalar('Gradient_norm_before', gradient_norm_final_before, self.pruning_iterations_done)
842 | train_writer.add_scalar('Pruning_threshold', self.threshold_now, self.pruning_iterations_done)
843 |
844 | def util_add_loss(self, training_loss_current, training_acc):
845 | # keeps track of current loss
846 | self.util_loss_tracker += training_loss_current
847 | self.util_acc_tracker += training_acc
848 | self.util_loss_tracker_num += 1
849 | self.loss_tracker_exp.update(training_loss_current)
850 | # self.acc_tracker_exp.update(training_acc)
851 |
852 | def do_step(self, loss=None, optimizer=None, neurons_left=0, training_acc=0.0):
853 | '''
854 | do one step of pruning,
855 | 1) Add importance estimate
856 | 2) checks if loss is above threshold
857 | 3) performs one step of pruning if needed
858 | '''
859 | self.iter_step += 1
860 | niter = self.iter_step
861 |
862 | # # sets pruned weights to zero
863 | # self.enforce_pruning()
864 |
865 | # stop if pruned maximum amount
866 | if self.maximum_pruning_iterations <= self.pruning_iterations_done:
867 | # exit if we pruned enough
868 | self.res_pruning = -1
869 | return -1
870 |
871 | # sets pruned weights to zero
872 | self.enforce_pruning()
873 |
874 | # compute criteria for given batch
875 | self.add_criteria()
876 |
877 | # small script to keep track of training loss since the last pruning
878 | self.util_add_loss(loss, training_acc)
879 |
880 | if ((niter-1) % self.frequency == 0) and (niter != 0) and (self.res_pruning==1):
881 | self.report_loss_neuron(self.util_training_loss, training_acc=self.util_training_acc, train_writer=self.train_writer, neurons_left=neurons_left)
882 |
883 | if niter % self.frequency == 0 and niter != 0:
884 | # do actual pruning, output: 1 - good, 0 - no pruning
885 |
886 | self.compute_saliency()
887 | self.set_momentum_zero_sgd(optimizer=optimizer)
888 |
889 | training_loss = self.util_training_loss
890 | if self.res_pruning == 1:
891 | print("Pruning: Units", self.neuron_units, "/", self.all_neuron_units, "loss", training_loss, "Zeroed", self.pruned_neurons, "criteria min:{}/max:{:2.7f}".format(self.min_criteria_value,self.max_criteria_value))
892 |
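    # Usage sketch for do_step (hypothetical training loop; engine, model and
    # criterion construction are assumed to happen elsewhere):
    #
    #     for data, target in train_loader:
    #         loss = criterion(model(data), target)
    #         optimizer.zero_grad()
    #         loss.backward()
    #         pruning_engine.do_step(loss=loss.item(), optimizer=optimizer)
    #         optimizer.step()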
893 | def set_momentum_zero_sgd(self, optimizer=None):
894 | '''
895 | Method sets momentum buffer to zero for pruned neurons. Supports SGD only.
896 | :return:
897 | void
898 | '''
899 | for layer in range(len(self.pruning_gates)):
900 | if not self.prune_layers[layer]:
901 | continue
902 | for unit in range(len(self.pruning_gates[layer])):
903 | if not self.pruning_gates[layer][unit]:
904 | continue
905 | if 'momentum_buffer' in optimizer.state[self.parameters[layer]].keys():
906 | optimizer.state[self.parameters[layer]]['momentum_buffer'][unit] *= 0.0
907 |
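    # Sketch of the buffer access above (assumes SGD with momentum): for a
    # parameter w tracked by the optimizer, the per-unit momentum entry lives at
    #
    #     optimizer.state[w]['momentum_buffer'][unit]
    #
    # and multiplying it by 0.0 prevents pruned units from being revived.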
908 | def connect_tensorboard(self, tensorboard):
909 | '''
910 | Function connects tensorboard to pruning engine
911 | '''
912 | self.tensorboard = True
913 | self.train_writer = tensorboard
914 |
915 | def update_flops(self, stats=None):
916 | '''
917 | Function updates flops for potential regularization
918 | :param stats: a list of flops per parameter
919 | :return:
920 | '''
921 | self.per_layer_flops = list()
922 | if len(stats["flops"]) < 1:
923 | return -1
924 | for pruning_param in self.gates_to_params:
925 | if isinstance(pruning_param, list):
926 | # parameter spans many blocks, will aggregate over them
927 | self.per_layer_flops.append(sum([stats['flops'][a] for a in pruning_param]))
928 | else:
929 | self.per_layer_flops.append(stats['flops'][pruning_param])
930 |
931 | def apply_flops_regularization(self, groups, mu=0.1):
932 | '''
933 |         Function applies regularization to the computed importance per layer
934 | :param groups: a list of groups organized per layer
935 | :param mu: regularization coefficient
936 | :return:
937 | '''
938 | if len(self.per_layer_flops) < 1:
939 | return -1
940 |
941 | for layer_id, layer in enumerate(groups):
942 | for group in layer:
943 | # import pdb; pdb.set_trace()
944 | total_neurons = len(group[0])
945 | group[1] = group[1] - mu*(self.per_layer_flops[layer_id]*total_neurons)
946 |
947 |
948 | def prepare_pruning_list(pruning_settings, model, model_name, pruning_mask_from='', name=''):
949 | '''
950 |     Function returns a list of parameters from the model to be considered for pruning.
951 |     Depending on the pruning method and strategy, different parameters are selected (conv kernels, BN parameters etc)
952 |     :param pruning_settings:
953 |     :param model:
954 |     :return:
955 |     '''
956 |     # Function creates a list of layers that will be pruned based on user selection
957 |
958 |     ADD_BY_GATES = True  # gates are added artificially: they have weight == 1 and are not trained, but their gradient is important. see models/lenet.py
959 | ADD_BY_WEIGHTS = ADD_BY_BN = False
960 |
961 | pruning_method = pruning_settings['method']
962 |
963 | pruning_parameters_list = list()
964 | if ADD_BY_GATES:
965 |
966 | first_step = True
967 | prev_module = None
968 | prev_module2 = None
969 | print("network structure")
970 | for module_indx, m in enumerate(model.modules()):
971 | # print(module_indx, m)
972 | if hasattr(m, "do_not_update"):
973 | m_to_add = m
974 |
975 | if (pruning_method != 23) and (pruning_method != 6):
976 | for_pruning = {"parameter": m_to_add.weight, "layer": m_to_add,
977 | "compute_criteria_from": m_to_add.weight}
978 | else:
979 | def just_hook(self, grad_input, grad_output):
980 | # getting full gradient for parameters
981 | # normal backward will provide only averaged gradient per batch
982 | # requires to store output of the layer
983 | if len(grad_output[0].shape) == 4:
984 | self.weight.full_grad = (grad_output[0] * self.output).sum(-1).sum(-1)
985 | else:
986 | self.weight.full_grad = (grad_output[0] * self.output)
987 |
988 | if pruning_method == 6:
989 | # implement ICLR2017 paper
990 | def just_hook(self, grad_input, grad_output):
991 | if len(grad_output[0].shape) == 4:
992 | self.weight.full_grad_iclr2017 = (grad_output[0] * self.output).abs().mean(-1).mean(
993 | -1).mean(0)
994 | else:
995 | self.weight.full_grad_iclr2017 = (grad_output[0] * self.output).abs().mean(0)
996 |
997 | def forward_hook(self, input, output):
998 | self.output = output
999 |
1000 |                 if not len(pruning_mask_from) > 0:
1001 |                     # register hooks only when the mask is not precomputed
1002 | m_to_add.register_forward_hook(forward_hook)
1003 | m_to_add.register_backward_hook(just_hook)
1004 |
1005 | for_pruning = {"parameter": m_to_add.weight, "layer": m_to_add,
1006 | "compute_criteria_from": m_to_add.weight}
1007 |
1008 | if pruning_method in [30, 31]:
1009 | # for densenets.
1010 | # add previous layer's value for batch norm pruning
1011 |
1012 | if isinstance(prev_module, nn.BatchNorm2d):
1013 | m_to_add = prev_module
1014 | print(m_to_add, "yes")
1015 | else:
1016 | print(m_to_add, "no")
1017 |
1018 | for_pruning = {"parameter": m_to_add.weight, "layer": m_to_add,
1019 | "compute_criteria_from": m_to_add.weight}
1020 |
1021 | if pruning_method in [24, ]:
1022 | # add previous layer's value for batch norm pruning
1023 |
1024 | if isinstance(prev_module, nn.Conv2d):
1025 | m_to_add = prev_module
1026 |
1027 | for_pruning = {"parameter": m_to_add.weight, "layer": m_to_add,
1028 | "compute_criteria_from": m_to_add.weight}
1029 |
1030 | if pruning_method in [0, 2, 3]:
1031 | # add previous layer's value for batch norm pruning
1032 |
1033 | if isinstance(prev_module2, nn.Conv2d):
1034 | print(module_indx, prev_module2, "yes")
1035 | m_to_add = prev_module2
1036 | elif isinstance(prev_module2, nn.Linear):
1037 | print(module_indx, prev_module2, "yes")
1038 | m_to_add = prev_module2
1039 | elif isinstance(prev_module, nn.Conv2d):
1040 | print(module_indx, prev_module, "yes")
1041 | m_to_add = prev_module
1042 | else:
1043 | print(module_indx, m, "no")
1044 |
1045 | for_pruning = {"parameter": m_to_add.weight, "layer": m_to_add,
1046 | "compute_criteria_from": m_to_add.weight}
1047 |
1048 | pruning_parameters_list.append(for_pruning)
1049 | prev_module2 = prev_module
1050 | prev_module = m
1051 |
1052 | if model_name == "resnet20":
1053 | # prune only even layers as in Rethinking min norm pruning
1054 | pruning_parameters_list = [d for di, d in enumerate(pruning_parameters_list) if (di % 2 == 1 and di > 0)]
1055 |
1056 |     if "prune_only_skip_connections" in name:
1057 | # will prune only skip connections (gates around them). Works with ResNets only
1058 | pruning_parameters_list = pruning_parameters_list[:4]
1059 |
1060 | return pruning_parameters_list
1061 |
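# Usage sketch for prepare_pruning_list (hypothetical arguments; the settings
# dict only needs the 'method' key used above):
#
#     settings = {"method": 22}
#     pruning_parameters = prepare_pruning_list(settings, model, "resnet50")
#     # every entry: {"parameter": ..., "layer": ..., "compute_criteria_from": ...}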
1062 | class ExpMeter(object):
1063 | """Computes and stores the average and current value"""
1064 | def __init__(self, mom = 0.9):
1065 | self.reset()
1066 | self.mom = mom
1067 |
1068 | def reset(self):
1069 | self.val = 0
1070 | self.avg = 0
1071 | self.sum = 0
1072 | self.count = 0
1073 | self.exp_avg = 0
1074 |
1075 | def update(self, val, n=1):
1076 | self.val = val
1077 | self.sum += val * n
1078 | self.count += n
1079 |         self.avg = self.mean_avg = self.sum / self.count
1080 | self.exp_avg = self.mom*self.exp_avg + (1.0 - self.mom)*self.val
1081 | if self.count == 1:
1082 | self.exp_avg = self.val
1083 |
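# Usage sketch for ExpMeter: exponential moving average of a loss stream,
# seeded with the first value.
#
#     meter = ExpMeter(mom=0.9)
#     for v in [1.0, 0.5, 0.25]:
#         meter.update(v)       # exp_avg: 1.0 -> 0.95 -> 0.88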
1084 | if __name__ == '__main__':
1085 | pass
1086 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>2.0.0
2 | numpy>=1.0.0
3 | protobuf>=3.0.0
4 | tensorboardX>=1.3
5 | torch>=1.0.1
6 | torchvision
--------------------------------------------------------------------------------
/utils/group_lasso_optimizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from torch.optim.optimizer import Optimizer
7 |
8 | PRINT_ALL = False
9 | USE_FULL = False
10 |
11 | class group_lasso_decay(Optimizer):
12 |     r"""Implements group lasso weight decay (GLWD) that pushes entire groups to 0.
13 |     Normal weight decay makes individual weights sparse; GLWD makes entire channels sparse.
14 |     Assumes we want to decay the groups related to feature maps (channels); other groupings are possible but not implemented
15 | """
16 |
17 | def __init__(self, params, group_lasso_weight=0, named_parameters=None, output_sizes=None):
18 | defaults = dict(group_lasso_weight = group_lasso_weight, total_neurons = 0)
19 |
20 | super(group_lasso_decay, self).__init__(params, defaults)
21 |
22 | self.per_layer_per_neuron_stats = {'flops': list(), 'params': list(), 'latency': list()}
23 |
24 | self.named_parameters = named_parameters
25 |
26 | self.output_sizes = None
27 | if output_sizes is not None:
28 | self.output_sizes = output_sizes
29 |
30 | def __setstate__(self, state):
31 | super(group_lasso_decay, self).__setstate__(state)
32 |
33 | def get_number_neurons(self, print_output = False):
34 | total_neurons = 0
35 | for gr_ind, group in enumerate(self.param_groups):
36 | total_neurons += group['total_neurons'].item()
37 | # if print_output:
38 | # print("Total parameters: ",total_neurons, " or ", total_neurons/1e7, " times 1e7")
39 | return total_neurons
40 |
41 | def get_number_flops(self, print_output = False):
42 | total_flops = 0
43 | for gr_ind, group in enumerate(self.param_groups):
44 | total_flops += group['total_flops'].item()
45 |
46 | total_neurons = 0
47 | for gr_ind, group in enumerate(self.param_groups):
48 | total_neurons += group['total_neurons'].item()
49 |
50 |         if print_output:
51 |             print("Flops 1e9 / params 1e7: %3.3f & %3.3f" % (total_flops/1e9, total_neurons/1e7))
54 | return total_flops
55 |
56 | def step(self, closure=None):
57 | """Applies GLWD regularization to weights.
58 | Arguments:
59 | closure (callable, optional): A closure that reevaluates the model
60 | and returns the loss.
61 | """
62 | loss = None
63 | if closure is not None:
64 | loss = closure()
65 |
66 | for group in self.param_groups:
67 | group_lasso_weight = group['group_lasso_weight']
68 | group['total_neurons'] = 0
69 | group['total_flops'] = 0
70 |
71 | for p in group['params']:
72 | param_state = self.state[p]
73 |
74 | if 1:
75 |
76 | weight_size = p.data.size()
77 |
78 | if ('group_lasso_coeff' not in param_state) or (param_state['group_lasso_coeff'].shape != p.data.shape):
79 | nunits = p.data.size(0)
80 | param_state['group_lasso_coeff'] = p.data.clone().view(nunits,-1).sum(dim=1)*0.0 + group_lasso_weight
81 |
82 | group_lasso_weight_local = param_state['group_lasso_coeff']
83 |
84 | if len(weight_size) == 4:
85 | # defined for conv layers only
86 | nunits = p.data.size(0)
87 |
88 | # let's compute denominator
89 | divider = p.data.pow(2).view(nunits,-1).sum(dim=1).pow(0.5)
90 |
91 | eps = 1e-5
92 | eps2 = 1e-13
93 | # check if divider is above threshold
94 | divider_bool = divider.gt(eps).view(-1).float()
95 |
96 | group_lasso_gradient = p.data * (group_lasso_weight_local * divider_bool /
97 | (divider + eps2)).view(nunits, 1, 1, 1).repeat(1, weight_size[1],weight_size[2],weight_size[3])
98 |
99 |                         # apply the weight decay step: p <- p - group_lasso_gradient
100 |                         p.data.sub_(group_lasso_gradient)
101 | return loss
102 |
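    # Standalone sketch of the per-channel decay above (lam is a hypothetical
    # group lasso weight; same math as the conv branch, using broadcasting):
    #
    #     lam = 1e-4
    #     w = torch.randn(8, 4, 3, 3)                        # 8 output channels
    #     norm = w.pow(2).view(8, -1).sum(dim=1).pow(0.5)    # l2 norm per channel
    #     mask = norm.gt(1e-5).float()                       # skip channels already near 0
    #     w -= w * (lam * mask / (norm + 1e-13)).view(8, 1, 1, 1)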
103 | def step_after(self, closure=None):
104 |         """Computes FLOPs and the number of neurons after accounting for zeroed-out inputs and outputs.
105 | Channels are assumed to be pruned if their l2 norm is very small or if magnitude of gradient is very small.
106 | This function does not perform weight pruning, weights are untouched.
107 |
108 | This function also calls push_biases_down which sets corresponding biases to 0.
109 |
110 | Arguments:
111 | closure (callable, optional): A closure that reevaluates the model
112 | and returns the loss.
113 | """
114 | loss = None
115 | if closure is not None:
116 | loss = closure()
117 |
118 | param_index = -1
119 | conv_param_index = -1
120 | for group in self.param_groups:
121 | group_lasso_weight = group['group_lasso_weight']
122 | group['total_neurons'] = 0
123 | group['total_flops'] = 0
124 |
125 | for p in group['params']:
126 |
127 | if group_lasso_weight != 0 or 1:
128 | weight_size = p.data.size()
129 |
130 | if (len(weight_size) == 4) or (len(weight_size) == 2) or (len(weight_size) == 1):
131 | param_index += 1
132 |                         # applies to conv, linear and BN parameters
133 | nunits = p.data.size(0)
134 | # let's compute denominator
135 | divider = p.data.pow(2).view(nunits,-1).sum(dim=1).pow(0.5)
136 |
137 | eps = 1e-4
138 | # check if divider is above threshold
139 | divider_bool = divider.gt(eps).view(-1).float()
140 |
141 | if (len(weight_size) == 4) or (len(weight_size) == 2) or (len(weight_size) == 1):
142 |                         if p.grad is not None:
143 |                             # also consider gradients: if the gradient is below a specific threshold then we claim the parameter is removed
144 | divider_grad = p.grad.data.pow(2).view(nunits, -1).sum(dim=1).pow(0.5)
145 | eps = 1e-8
146 | divider_bool_grad = divider_grad.gt(eps).view(-1).float()
147 | divider_bool = divider_bool_grad * divider_bool
148 |
149 | if (len(weight_size) == 4) or (len(weight_size) == 2):
150 | # get gradient for input:
151 | divider_grad_input = p.grad.data.pow(2).transpose(0,1).contiguous().view(p.data.size(1),-1).sum(dim=1).pow(0.5)
152 | divider_bool_grad_input = divider_grad_input.gt(eps).view(-1).float()
153 |
154 | divider_input = p.data.pow(2).transpose(0,1).contiguous().view(p.data.size(1), -1).sum(dim=1).pow(0.5)
155 | divider_bool_input = divider_input.gt(eps).view(-1).float()
156 | divider_bool_input = divider_bool_input * divider_bool_grad_input
157 |                             # if the gradient is small then mask the input channel out
158 |
159 | if USE_FULL:
160 | # reset to evaluate true number of flops and neurons
161 | # useful for full network only
162 | divider_bool = 0.0*divider_bool + 1.0
163 | divider_bool_input = 0.0*divider_bool_input + 1.0
164 |
165 | if len(weight_size) == 4:
166 | p.data.mul_(divider_bool.view(nunits,1, 1, 1).repeat(1,weight_size[1], weight_size[2], weight_size[3]))
167 | current_neurons = divider_bool.sum()*divider_bool_input.sum()*weight_size[2]* weight_size[3]
168 |
169 | if len(weight_size) == 2:
170 | current_neurons = divider_bool.sum()*divider_bool_input.sum()
171 |
172 | if len(weight_size) == 1:
173 | current_neurons = divider_bool.sum()
174 | # add mean and var over batches
175 | current_neurons = current_neurons + divider_bool.sum()
176 |
177 | group['total_neurons'] += current_neurons
178 |
179 | if len(weight_size) == 4:
180 | conv_param_index += 1
181 | input_channels = divider_bool_input.sum()
182 | output_channels = divider_bool.sum()
183 |
184 | if self.output_sizes is not None:
185 | output_height, output_width = self.output_sizes[conv_param_index][-2:]
186 | else:
187 | if hasattr(p, 'output_dims'):
188 | output_height, output_width = p.output_dims[-2:]
189 | else:
190 | output_height, output_width = 0, 0
191 |
192 | kernel_ops = weight_size[2] * weight_size[3] * input_channels
193 |
194 | params = output_channels * kernel_ops
195 |                             flops = params * output_height * output_width
196 |
197 | # add flops due to batch normalization
198 | flops = flops + output_height * output_width*3
199 |
200 | group['total_flops'] = group['total_flops'] + flops
201 |
202 | if len(weight_size) == 1:
203 | flops = len(weight_size)
204 | group['total_flops'] = group['total_flops'] + flops
205 | if len(weight_size) == 2:
206 | input_channels = divider_bool_input.sum()
207 | output_channels = divider_bool.sum()
208 | flops = input_channels * output_channels
209 | group['total_flops'] = group['total_flops'] + flops
210 |
211 | if len(self.per_layer_per_neuron_stats['flops']) <= param_index:
212 | self.per_layer_per_neuron_stats['flops'].append(flops / divider_bool.sum())
213 | self.per_layer_per_neuron_stats['params'].append(current_neurons / divider_bool.sum())
214 | # self.per_layer_per_neuron_stats['latency'].append(flops / output_channels)
215 | else:
216 | self.per_layer_per_neuron_stats['flops'][param_index] = flops / divider_bool.sum()
217 | self.per_layer_per_neuron_stats['params'][param_index] = current_neurons / divider_bool.sum()
218 | # self.per_layer_per_neuron_stats['latency'][param_index] = flops / output_channels
219 |
220 | self.push_biases_down(eps=1e-3)
221 | return loss
222 |
223 | def push_biases_down(self, eps=1e-3):
224 | '''
225 |         This function goes over parameters and sets the corresponding biases to zero;
226 |         without it, the biases of pruned channels would remain nonzero
227 | '''
228 | # first pass
229 | list_of_names = []
230 | for name, param in self.named_parameters:
231 | if "weight" in name:
232 | weight_size = param.data.shape
233 | if (len(weight_size) == 4) or (len(weight_size) == 2):
234 |                     # defined for conv and linear layers
235 | nunits = weight_size[0]
236 | # let's compute denominator
237 | divider = param.data.pow(2).view(nunits, -1).sum(dim=1).pow(0.5)
238 | divider_bool = divider.gt(eps).view(-1).float()
239 | list_of_names.append((name.replace("weight", "bias"), divider_bool))
240 |
241 | # second pass
242 | for name, param in self.named_parameters:
243 | if "bias" in name:
244 | for ind in range(len(list_of_names)):
245 | if list_of_names[ind][0] == name:
246 | param.data.mul_(list_of_names[ind][1])
247 |
248 |
249 |
250 |
251 |
252 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 |
8 |
9 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
10 | """Saves checkpoint to disk"""
11 | try:
12 | torch.save(state, filename)
13 | if is_best:
14 | torch.save(state, filename.replace("checkpoint", "best_model"))
15 |     except Exception:
16 | print("didn't save checkpoint file")
17 |
18 |
19 | def adjust_learning_rate(args, optimizer, epoch, zero_lr_for_epochs, train_writer):
20 |     """Sets the learning rate to the initial LR decayed by lr_decay_scalar every lr_decay_every epochs"""
21 | lr = args.lr * (args.lr_decay_scalar ** (epoch // args.lr_decay_every)) # * (0.1 ** (epoch // (2*args.lr_decay_every)))
22 | if zero_lr_for_epochs > -1:
23 | if epoch <= zero_lr_for_epochs:
24 | lr = 0.0
25 | # log to TensorBoard
26 | if args.tensorboard:
27 | train_writer.add_scalar('learning_rate', lr, epoch)
28 | print("learning rate adjusted:", lr, epoch)
29 | for param_group in optimizer.param_groups:
30 | param_group['lr'] = lr
31 |
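# Worked example for the schedule above: with args.lr = 0.1,
# lr_decay_scalar = 0.1 and lr_decay_every = 30, epochs 0-29 run at 0.1,
# epochs 30-59 at 0.01, epochs 60-89 at 0.001, since
# lr = 0.1 * 0.1 ** (epoch // 30).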
32 |
33 | def adjust_learning_rate_fixed(args, optimizer, epoch, zero_lr_for_epochs, train_writer):
34 |     """Multiplies the learning rate of every param group by lr_decay_scalar when epoch hits lr_decay_every"""
35 |     lr = args.lr * (args.lr_decay_scalar ** (epoch // args.lr_decay_every))
36 |     if zero_lr_for_epochs > -1:
37 |         if epoch <= zero_lr_for_epochs:
38 |             lr = 0.0
39 |     lr_scale = 1.0
40 |     if epoch == args.lr_decay_every:
41 |         lr_scale = args.lr_decay_scalar
42 |         lr = args.lr * (args.lr_decay_scalar ** (epoch // args.lr_decay_every))
43 |     # log to TensorBoard
44 |     if args.tensorboard:
45 |         train_writer.add_scalar('learning_rate', lr, epoch)
46 |     print("learning rate adjusted:", lr, epoch)
47 |     for param_group in optimizer.param_groups:
48 |         param_group['lr'] = param_group['lr']*lr_scale
49 |
50 | class AverageMeter(object):
51 | """Computes and stores the average and current value"""
52 | def __init__(self):
53 | self.reset()
54 |
55 | def reset(self):
56 | self.val = 0
57 | self.avg = 0
58 | self.sum = 0
59 | self.count = 0
60 |
61 | def update(self, val, n=1):
62 | self.val = val
63 | self.sum += val * n
64 | self.count += n
65 | self.avg = self.sum / self.count
66 |
67 |
68 | def accuracy(output, target, topk=(1,)):
69 | """Computes the precision@k for the specified values of k"""
70 | maxk = max(topk)
71 | batch_size = target.size(0)
72 |
73 | _, pred = output.topk(maxk, 1, True, True)
74 | pred = pred.t()
75 | correct = pred.eq(target.view(1, -1).expand_as(pred))
76 |
77 | res = []
78 | for k in topk:
79 | correct_k = correct[:k].view(-1).float().sum(0)
80 | res.append(correct_k.mul_(100.0 / batch_size))
81 | return res
82 |
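# Usage sketch for accuracy (dummy batch):
#
#     output = torch.randn(32, 1000)             # logits for 32 samples
#     target = torch.randint(0, 1000, (32,))
#     top1, top5 = accuracy(output, target, topk=(1, 5))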
83 |
84 | def load_model_pytorch(model, load_model, model_name):
85 | print("=> loading checkpoint '{}'".format(load_model))
86 | checkpoint = torch.load(load_model)
87 |
88 | if 'state_dict' in checkpoint.keys():
89 | load_from = checkpoint['state_dict']
90 | else:
91 | load_from = checkpoint
92 |
93 | # match_dictionaries, useful if loading model without gate:
94 | if 'module.' in list(model.state_dict().keys())[0]:
95 | if 'module.' not in list(load_from.keys())[0]:
96 | from collections import OrderedDict
97 |
98 | load_from = OrderedDict([("module.{}".format(k), v) for k, v in load_from.items()])
99 |
100 | if 'module.' not in list(model.state_dict().keys())[0]:
101 | if 'module.' in list(load_from.keys())[0]:
102 | from collections import OrderedDict
103 |
104 | load_from = OrderedDict([(k.replace("module.", ""), v) for k, v in load_from.items()])
105 |
106 | # just for vgg
107 | if model_name == "vgg":
108 | from collections import OrderedDict
109 |
110 | load_from = OrderedDict([(k.replace("features.", "features"), v) for k, v in load_from.items()])
111 | load_from = OrderedDict([(k.replace("classifier.", "classifier"), v) for k, v in load_from.items()])
112 |
113 |     # debug: print the first few keys of both state dicts
114 |     for ind, (key, item) in enumerate(model.state_dict().items()):
115 |         if ind > 10:
116 |             continue
117 |         print(key, model.state_dict()[key].shape)
118 |
119 |     print("*********")
120 |
121 |     for ind, (key, item) in enumerate(load_from.items()):
122 |         if ind > 10:
123 |             continue
124 |         print(key, load_from[key].shape)
125 |
126 | for key, item in model.state_dict().items():
127 | # if we add gate that is not in the saved file
128 | if key not in load_from:
129 | load_from[key] = item
130 |         # if loading a pretrained model
131 | if load_from[key].shape != item.shape:
132 | load_from[key] = item
133 |
134 | model.load_state_dict(load_from, strict=True)
135 |
136 |
137 | epoch_from = -1
138 | if 'epoch' in checkpoint.keys():
139 | epoch_from = checkpoint['epoch']
140 | print("=> loaded checkpoint '{}' (epoch {})"
141 | .format(load_model, epoch_from))
142 |
143 |
144 | def dynamic_network_change_local(model):
145 | '''
146 |     Method attempts to modify the network in place by removing pruned filters.
147 |     Works only with ResNet-101 for now
148 |     :param model: reference to the torch model to be modified
149 | :return:
150 | '''
151 | # change network dynamically given a pruning mask
152 |
153 | # step 1: model adjustment
154 | # lets go layer by layer and get the mask if we have parameter in pruning settings:
155 |
156 |     pruning_mask_input = None
157 | prev_model1 = None
158 | prev_model2 = None
159 | prev_model3 = None
160 |
161 | pruning_mask_indexes = None
162 | gate_track = -1
163 |
164 | skip_connections = list()
165 |
166 | current_skip = 0
167 | DO_SKIP = True
168 |
169 | gate_size = -1
170 | current_skip_mask_size = -1
171 |
172 | for module_indx, m in enumerate(model.modules()):
173 | pruning_mask_indexes = None
174 | if not hasattr(m, "do_not_update"):
175 | if isinstance(m, torch.nn.Conv2d):
176 | print("interm layer", gate_track, module_indx)
177 | print(m)
179 |                 if pruning_mask_input is not None:
180 |                     print("fixing interm layer", gate_track, module_indx)
181 |
182 |                     m.weight.data = m.weight.data[:, pruning_mask_input]
183 |
184 |                     pruning_mask_input = None
185 |                     print("weight size now", m.weight.data.shape)
186 |                     m.in_channels = m.weight.data.shape[1]
187 |                     m.out_channels = m.weight.data.shape[0]
188 |
189 |                 if DO_SKIP:
190 |                     print("doing skip connection")
191 |                     if m.weight.data.shape[0] == current_skip_mask_size:
192 |                         m.weight.data = m.weight.data[current_skip_mask]
193 |                         print("weight size after skip", m.weight.data.shape)
194 |
195 | if module_indx in [23, 63, 115, 395]:
196 | if DO_SKIP:
197 | print("fixing interm skip")
198 | m.weight.data = m.weight.data[:, prev_skip_mask]
199 |
200 | print("weight size now", m.weight.data.shape)
201 | m.in_channels = m.weight.data.shape[1]
202 | m.out_channels = m.weight.data.shape[0]
203 |
204 | if DO_SKIP:
205 | print("doing skip connection")
206 | if m.weight.data.shape[0] == current_skip_mask_size:
207 | m.weight.data = m.weight.data[current_skip_mask]
208 | print("weight size after skip", m.weight.data.shape)
209 |
210 | if isinstance(m, torch.nn.BatchNorm2d):
211 | print("interm layer BN: ", gate_track, module_indx)
212 | print(m)
213 | if DO_SKIP:
214 | print("doing skip connection")
215 | if m.weight.data.shape[0] == current_skip_mask_size:
216 | m.weight.data = m.weight.data[current_skip_mask]
217 | print("weight size after skip", m.weight.data.shape)
218 |
219 | # m.weight.data = m.weight.data[current_skip_mask]
220 | m.bias.data = m.bias.data[current_skip_mask]
221 | m.running_mean.data = m.running_mean.data[current_skip_mask]
222 | m.running_var.data = m.running_var.data[current_skip_mask]
223 | else:
224 | # keeping track of gates:
225 | gate_track += 1
226 |
227 | then_pass = False
228 | if gate_track < 4:
229 | # skipping skip connections
230 | then_pass = True
231 | skip_connections.append(m.weight)
232 | current_skip = -1
233 |
234 | if not then_pass:
235 | pruning_mask = m.weight
236 | if gate_size!=m.weight.shape[0]:
237 | current_skip += 1
238 | current_skip_mask_size = skip_connections[current_skip].data.shape[0]
239 |
240 | if skip_connections[current_skip].data.shape[0] != 2048:
241 | current_skip_mask = skip_connections[current_skip].data.nonzero().view(-1)
242 | else:
243 | current_skip_mask = (skip_connections[current_skip].data + 1.0).nonzero().view(-1)
244 | prev_skip_mask_size = 64
245 | prev_skip_mask = range(64)
246 | if current_skip > 0:
247 | prev_skip_mask_size = skip_connections[current_skip - 1].data.shape[0]
248 | prev_skip_mask = skip_connections[current_skip - 1].data.nonzero().view(-1)
249 |
250 | gate_size = m.weight.shape[0]
251 |
253 |                     print("fixing layer", gate_track, module_indx)
254 |                     print(m)
255 |                     print(pruning_mask)
256 |
257 | if 1.0 in pruning_mask:
258 | pruning_mask_indexes = pruning_mask.nonzero().view(-1)
259 | else:
260 | pruning_mask_indexes = []
261 | m.weight.data = m.weight.data[pruning_mask_indexes]
262 | for prev_model in [prev_model1, prev_model2, prev_model3]:
263 | if isinstance(prev_model, torch.nn.Conv2d):
264 | print("prev fixing layer", prev_model, gate_track, module_indx)
265 | prev_model.weight.data = prev_model.weight.data[pruning_mask_indexes]
266 | print("weight size", prev_model.weight.data.shape)
267 |
268 | if DO_SKIP:
269 | print("doing skip connection")
270 |
271 | if prev_model.weight.data.shape[1] == current_skip_mask_size:
272 | prev_model.weight.data = prev_model.weight.data[:, current_skip_mask]
273 | print("weight size", prev_model.weight.data.shape)
274 |
275 | if module_indx in [53, 105, 385]: # add one more layer for this transition
276 | print("doing skip connection")
277 |
278 | if prev_model.weight.data.shape[1] == prev_skip_mask_size:
279 | prev_model.weight.data = prev_model.weight.data[:, prev_skip_mask]
280 | print("weight size", prev_model.weight.data.shape)
281 |
282 | if isinstance(prev_model, torch.nn.BatchNorm2d):
283 | print("prev fixing layer", prev_model, gate_track, module_indx)
284 | prev_model.weight.data = prev_model.weight.data[pruning_mask_indexes]
285 | prev_model.bias.data = prev_model.bias.data[pruning_mask_indexes]
286 | prev_model.running_mean.data = prev_model.running_mean.data[pruning_mask_indexes]
287 | prev_model.running_var.data = prev_model.running_var.data[pruning_mask_indexes]
288 |
289 |                     pruning_mask_input = pruning_mask_indexes
290 |
291 | prev_model3 = prev_model2
292 | prev_model2 = prev_model1
293 | prev_model1 = m
294 |
295 | if DO_SKIP:
296 | # fix gate layers
297 | gate_track = 0
298 |
299 | for module_indx, m in enumerate(model.modules()):
300 | if hasattr(m, "do_not_update"):
301 | gate_track += 1
302 | if gate_track < 4:
303 | if m.weight.shape[0] < 2048:
304 | m.weight.data = m.weight.data[m.weight.nonzero().view(-1)]
305 |
306 | print("printing conv layers")
307 | for module_indx, m in enumerate(model.modules()):
308 | if isinstance(m, torch.nn.Conv2d):
309 | print(module_indx, "->", m.weight.data.shape)
310 |
311 | print("printing bn layers")
312 | for module_indx, m in enumerate(model.modules()):
313 | if isinstance(m, torch.nn.BatchNorm2d):
314 | print(module_indx, "->", m.weight.data.shape)
315 |
316 | print("printing gate layers")
317 | for module_indx, m in enumerate(model.modules()):
318 | if hasattr(m, "do_not_update"):
319 | print(module_indx, "->", m.weight.data.shape, m.size_mask)
320 |
321 |
322 | def add_hook_for_flops(args, model):
323 | # add output dims for FLOPs computation
324 |     for module_indx, m in enumerate(model.modules()):
325 |         if isinstance(m, torch.nn.Conv2d):
326 |             def forward_hook(self, input, output):
327 |                 self.weight.output_dims = output.shape
328 |
329 |             m.register_forward_hook(forward_hook)
331 |
332 |
333 | def get_conv_sizes(args, model):
334 | output_sizes = None
335 | if args.compute_flops:
336 | # add hooks to compute dimensions of the output tensors for conv layers
337 | add_hook_for_flops(args, model)
339 |         if args.dataset == "CIFAR10":
340 |             dummy_input = torch.rand(1, 3, 32, 32)
341 |         elif args.dataset == "Imagenet":
342 |             dummy_input = torch.rand(1, 3, 224, 224)
343 |         # run inference
344 |         with torch.no_grad():
345 |             model(dummy_input)
346 |         # store flops
347 |         output_sizes = list()
348 |         for param in model.parameters():
349 |             if hasattr(param, 'output_dims'):
350 |                 output_dims = param.output_dims
351 |                 output_sizes.append(output_dims)
352 |
353 | return output_sizes
354 |
355 |
356 | def connect_gates_with_parameters_for_flops(model_name, named_parameters):
357 | '''
358 | Function creates a mapping between gates and parameter index to map flops
359 | :return:
360 |     returns a list with the mapping; each element corresponds to a gate, entries are the matching parameter indices
361 | '''
362 | if "resnet" not in model_name:
363 | print("connect_gates_with_parameters_for_flops only supports resnet for now")
364 | return -1
365 |
366 | # for skip connections, first 4 for them:
367 | gate_to_param_map = [list() for _ in range(4)]
368 |
369 | for param_id, (name, param) in enumerate(named_parameters):
370 | if "layer" not in name:
371 | # skip because we don't prune the first layer
372 | continue
373 |
374 | if ("conv1" in name) or ("conv2" in name):
375 | gate_to_param_map.append(param_id)
376 |
377 | if "conv3" in name:
378 | # third convolution contributes to skip connection only
379 | skip_block_id = int(name[name.find("layer")+len("layer")])
380 | gate_to_param_map[skip_block_id-1].append(param_id)
381 |
382 | return gate_to_param_map
--------------------------------------------------------------------------------