├── ACKNOWLEDGMENTS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── args.py
├── configs
│   ├── quantized
│   │   ├── lcs_l.yaml
│   │   ├── lcs_p.yaml
│   │   └── target_bit_width.yaml
│   ├── structured_sparsity
│   │   ├── lcs_l.yaml
│   │   ├── lcs_p.yaml
│   │   ├── lec.yaml
│   │   ├── ns.yaml
│   │   └── us.yaml
│   └── unstructured_sparsity
│       ├── lcs_l.yaml
│       ├── lcs_p.yaml
│       └── target_topk.yaml
├── curve_utils.py
├── get_training_params.py
├── main.py
├── model_logging.py
├── models
│   ├── builder.py
│   ├── init.py
│   ├── modules.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── channel_selection.py
│   │   ├── cpreresnet.py
│   │   ├── model_profiling.py
│   │   ├── resnet.py
│   │   ├── resprune.py
│   │   ├── utils.py
│   │   ├── vgg.py
│   │   └── vggprune.py
│   ├── quantize_affine.py
│   ├── quantized_modules.py
│   ├── sparse_modules.py
│   └── special_tensors.py
├── requirements.txt
├── schedulers.py
├── setup.sh
├── tools
│   ├── format-python
│   └── remove-whitespace.sh
├── train_curve.py
├── train_indep.py
├── train_quantized.py
├── train_structured.py
├── train_unstructured.py
├── training_params.py
└── utils.py
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2021 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Compressible Subspaces
2 |
3 | This is the official code release for our publication, [LCS: Learning Compressible Subspaces for Adaptive Network Compression at Inference Time](https://arxiv.org/abs/2110.04252). Our code is used to train and evaluate models that can be compressed in real-time after deployment, allowing for a fine-grained efficiency-accuracy trade-off.
4 |
5 | This repository hosts code to train compressible subspaces for structured sparsity, unstructured sparsity, and quantization. We support three architectures: cPreResNet20, ResNet18, and VGG19. Training is performed on the CIFAR-10 and ImageNet datasets.
6 |
7 | Default training configurations are provided in the `configs` folder. Note that they are automatically altered when different models and datasets are chosen through flags. See `training_params.py`. The following training parameter flags are available to all training regimes:
8 |
9 | - `--model`: Specifies the model to use. One of cpreresnet20, resnet18, or vgg19.
10 | - `--dataset`: Specifies the dataset to train on. One of cifar10 or imagenet.
11 | - `--imagenet_dir`: When using the imagenet dataset, the root directory of the dataset must be specified.
12 | - `--method`: Specifies the training method. For unstructured sparsity, one of target_topk, lcs_l, lcs_p. For structured sparsity, one of lec, ns, us, lcs_l, lcs_p. For quantized models, one of target_bit_width, lcs_l, lcs_p.
13 | - `--norm`: The normalization layers to use. One of IN (instance normalization), BN (batch normalization), or GN (group normalization).
14 | - `--epochs`: The number of epochs to train for.
15 | - `--learning_rate`: The optimizer learning rate.
16 | - `--batch_size`: Training and test batch sizes.
17 | - `--momentum`: The optimizer momentum.
18 | - `--weight_decay`: The L2 regularization weight.
19 | - `--warmup_budget`: The percentage of epochs to use for the training method warmup phase.
20 | - `--test_freq`: The number of training epochs to wait between test evaluation. Will also save models at this frequency.
21 |
22 | The "lcs_l" training method refers to the "LCS+L" method in the paper. In this setting, we train a linear subspace where one end is optimized for efficiency, while the other end prioritizes accuracy. The "lcs_p" training method refers to the "LCS+P" method in the paper and trains a degenerate subspace conditioned to perform at arbitrary sparsity rates in the unstructured and structured sparsity settings, or bit widths in the quantized setting.
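
As a rough illustration of the linear subspace idea (not the repository's implementation), a point on the line can be pictured as a convex combination of two endpoint weight sets. The names `interpolate`, `w_efficient`, and `w_accurate` are hypothetical, and which endpoint corresponds to which value of `alpha` is an assumption made only for this sketch.

```python
def interpolate(w_efficient, w_accurate, alpha):
    """Return (1 - alpha) * w_efficient + alpha * w_accurate, elementwise.

    Hypothetical helper: the real code stores and mixes endpoint weights
    inside its own layer classes, which are not reproduced here.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return [(1.0 - alpha) * a + alpha * b for a, b in zip(w_efficient, w_accurate)]


# Two toy "endpoints" of the line (made-up numbers).
w_efficient = [0.0, 2.0, -1.0]
w_accurate = [1.0, 4.0, 3.0]

# A point halfway along the line.
midpoint = interpolate(w_efficient, w_accurate, 0.5)
```

Sweeping `alpha` from 0 to 1 moves from one endpoint to the other; in the paper's setting, the compression level applied at each point would vary along with it.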
23 |
24 | ## Structured Sparsity
25 |
26 | In the structured sparsity setting, we support five training methods:
27 |
28 | 1. "lcs_l" -- This refers to the LCS+L method where one end of the linear subspace performs at high sparsity rates while the other performs at zero sparsity.
29 | 2. "lcs_p" -- This refers to the LCS+P method where we train a degenerate subspace conditioned to perform at arbitrary sparsity rates.
30 | 3. "lec" -- This refers to the method introduced in "Learning Efficient Convolutional Networks through Network Slimming" by Liu et al. (2017). We do not perform fine-tuning, as described in our paper.
31 | 4. "ns" -- This refers to the method introduced in "Slimmable Neural Networks" by Yu et al. (2018). We use a single BatchNorm to allow for evaluation at arbitrary width factors, as described in our paper.
32 | 5. "us" -- This refers to the method introduced in "Universally Slimmable Networks and Improved Training Techniques" by Yu & Huang (2019). We do not recalibrate BatchNorms (to facilitate on-device compression to arbitrary widths), as described in our paper.
33 |
34 | Training a model in the structured sparsity setting can be accomplished by running the following command:
35 |
36 | > python train_structured.py
37 |
38 | By default, the command above will train the cPreResNet20 architecture on CIFAR-10 using instance normalization layers with the LCS+L method. To specify the model, dataset, normalization, and training method, the flags `--model`, `--dataset`, `--norm`, `--method` can be used. The following command
39 |
40 | > python train_structured.py --model resnet18 --dataset imagenet --norm IN --method lcs_p --imagenet_dir
41 |
42 | will train a ResNet18 point subspace (LCS+P) on ImageNet using instance normalization layers and the parameters from our paper.
43 |
44 | In addition to the global flags above, the structured setting also has the following:
45 |
46 | - `--width_factors_list`: When training using the "ns" method, this sets the width factors at which the model will be trained.
47 | - `--width_factor_limits`: When training using the "us", "lcs_l", or "lcs_p" methods, sets the lower and upper width factor limits.
48 | - `--width_factor_samples`: When training using the "us", "lcs_l", or "lcs_p" methods, sets the number of samples to use for the sandwich rule. Two of these will be the samples from the width factor limits.
49 | - `--eval_width_factors`: Sets the width factors to evaluate the model for all training methods.
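
The interaction between `--width_factor_limits` and `--width_factor_samples` can be sketched as follows. This is a minimal illustration assuming the remaining samples are drawn uniformly between the limits; the function name `sample_width_factors` is hypothetical and the repository's sampling code may differ.

```python
import random


def sample_width_factors(limits, num_samples, rng=random):
    """Sandwich-rule style sampling (illustrative only).

    Always includes the lower and upper limits, then fills the remaining
    slots with width factors drawn uniformly between them (an assumption).
    """
    low, high = limits
    if num_samples < 2:
        raise ValueError("need at least 2 samples to include both limits")
    interior = [rng.uniform(low, high) for _ in range(num_samples - 2)]
    return [low, high] + interior


# Four samples between 0.25 and 1.0: the two limits plus two random widths.
factors = sample_width_factors((0.25, 1.0), 4, random.Random(0))
```

With `--width_factor_samples 4`, two of the four width factors are the limits themselves and only the other two vary from step to step.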
50 |
51 | The command
52 |
53 | > python train_structured.py --model cpreresnet20 --dataset cifar10 --norm BN --method ns --width_factors_list 0.25,0.5,0.75,1.0
54 |
55 | will train a cPreResNet20 architecture on CIFAR-10 via the NS method.
56 |
57 | ## Unstructured Sparsity
58 |
59 | In the unstructured sparsity setting, we support three training methods:
60 |
61 | 1. "lcs_l" -- This refers to the LCS+L method where one end of the linear subspace performs at high sparsity rates while the other performs at zero sparsity.
62 | 2. "lcs_p" -- This refers to the LCS+P method where we train a degenerate subspace conditioned to perform at arbitrary sparsity rates.
63 | 3. "target_topk" -- This trains a network optimized to perform well at a specified TopK target.
64 |
65 | Training a model in the unstructured sparsity setting can be accomplished by running the following command:
66 |
67 | > python train_unstructured.py
68 |
69 | By default, the command above will train the cPreResNet20 architecture on CIFAR-10 using group normalization layers with the LCS+L method and the parameters described in our paper. To specify the model, dataset, normalization, and training method, the flags `--model`, `--dataset`, `--norm`, `--method` can be used. The following command
70 |
71 | > python train_unstructured.py --model resnet18 --dataset imagenet --norm GN --method lcs_p --imagenet_dir
72 |
73 | will train a ResNet18 point subspace (LCS+P) on ImageNet using group normalization layers, again with the parameters from our paper.
74 |
75 | The command
76 |
77 | > python train_unstructured.py --model resnet18 --dataset imagenet --method target_topk --topk 0.5 --imagenet_dir
78 |
79 | will train a ResNet18 architecture on ImageNet optimized to perform at a TopK value of 0.5.
80 |
81 | In addition to the global flags above, the unstructured setting also has the following:
82 |
83 | - `--topk`: When training using the "target_topk" method, this sets the target TopK value.
84 | - `--eval_topk_grid`: Will evaluate the model at these TopK values.
85 | - `--topk_lower_bound`: The lower bound TopK value (1-sparsity) to be used for training. For linear subspaces, one end of the line will be optimized for sparsity 1-topk_lower_bound which corresponds to the high accuracy endpoint. Note: If specified, eval_topk_grid must be specified as well.
86 | - `--topk_upper_bound`: The upper bound TopK value (1-sparsity) to be used for training. For linear subspaces, one end of the line will be optimized for sparsity 1-topk_upper_bound which corresponds to the high efficiency endpoint. Note: If specified, eval_topk_grid must be specified as well.
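
To make the relation between a TopK value and sparsity concrete, here is a minimal sketch. It assumes weights are kept by largest magnitude and that the number kept is rounded to the nearest integer; both are assumptions for illustration, and `topk_mask` is a hypothetical name rather than a function from this repository.

```python
def topk_mask(weights, topk):
    """Return a 0/1 mask keeping the `topk` fraction of largest-magnitude weights."""
    if not 0.0 < topk <= 1.0:
        raise ValueError("topk must lie in (0, 1]")
    num_keep = max(1, round(topk * len(weights)))
    # Indices sorted by decreasing magnitude; the first num_keep survive.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)
    keep = set(order[:num_keep])
    return [1 if i in keep else 0 for i in range(len(weights))]


weights = [0.1, -0.9, 0.05, 0.4]
mask = topk_mask(weights, 0.5)

# Sparsity is the fraction of weights that were zeroed out, i.e. 1 - TopK.
sparsity = 1 - sum(mask) / len(mask)
```

Here a TopK value of 0.5 keeps two of the four weights, which corresponds to a sparsity of 0.5.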
87 |
88 | The following command
89 |
90 | > python train_unstructured.py --model cpreresnet20 --dataset cifar10 --norm GN --method lcs_p --topk_lower_bound 0.005 --topk_upper_bound 0.05 --eval_topk_grid 0.005,0.01,0.015,0.02,0.025,0.03,0.035,0.04,0.045,0.05
91 |
92 | will train a point subspace with high sparsity.
93 |
94 | ## Quantization
95 |
96 | In the quantized setting, we support three training methods:
97 |
98 | 1. "lcs_l" -- This refers to the LCS+L method where one end of the linear subspace performs at a low bit width while the other performs at a high bit width.
99 | 2. "lcs_p" -- This refers to the LCS+P method where we train a degenerate subspace conditioned to perform at arbitrary bit widths in a range.
100 | 3. "target_bit_width" -- This trains a network optimized to perform at a specified bit width.
101 |
102 | Training a model in the quantized setting can be accomplished by running the following command:
103 |
104 | > python train_quantized.py
105 |
106 | By default, the command above will train the cPreResNet20 architecture on CIFAR-10 using group normalization layers with the LCS+L method with a bit range [3,8]. To specify the model, dataset, normalization, and training method, the flags `--model`, `--dataset`, `--norm`, `--method` can be used. The following command
107 |
108 | > python train_quantized.py --model vgg19 --dataset imagenet --norm GN --method lcs_p --imagenet_dir
109 |
110 | will train a VGG19 point subspace (LCS+P) on ImageNet using group normalization layers.
111 |
112 | In addition to the global flags above, the quantized setting also has the following:
113 |
114 | - `--bit_width`: When training using the "target_bit_width" method, this sets the target bit width.
115 | - `--eval_bit_widths`: Will evaluate models at these bit widths.
116 | - `--bit_width_limits`: This sets the upper and lower bit width bounds to use for training.
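
The comma-separated form of `--bit_width_limits` and the range check applied to it can be sketched as below. The parsing and the bounds of 3 and 8 mirror `args.py`; the helper names `parse_float_list` and `check_bit_width_limits` are hypothetical and exist only for this example.

```python
def parse_float_list(text):
    """Turn a string such as "3,8" into [3.0, 8.0]."""
    return [float(w) for w in text.split(",")]


def check_bit_width_limits(limits, lowest=3, highest=8):
    """Reject limits that fall outside the supported bit width range."""
    l_bit, u_bit = limits
    if l_bit < lowest:
        raise ValueError(f"Smallest bit width must be >= {lowest}")
    if u_bit > highest:
        raise ValueError(f"Largest bit width must be <= {highest}")
    return limits


# Equivalent to passing `--bit_width_limits 3,8` on the command line.
limits = check_bit_width_limits(parse_float_list("3,8"))
```

Passing a value such as `2,8` would raise a `ValueError` in this sketch, because the lower limit falls below 3.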
117 |
118 | The following command
119 |
120 | > python train_quantized.py --model cpreresnet20 --dataset cifar10 --norm GN --method lcs_l --bit_width_limits 3,8 --eval_bit_widths 3,4,5,6,7,8
121 |
122 | will train a linear subspace cPreResNet20 model with GN layers on the CIFAR-10 dataset, optimized so that one end of the line performs at 3 bits and the other at 8.
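
The default config `configs/quantized/lcs_l.yaml` lists `alpha_grid: [0.166, 0.333, 0.499, 0.666, 0.833, 1]` for bit widths 3 through 8. Those values are consistent with evenly spaced positions along the line, which the sketch below reproduces. The formula is inferred from the numbers in the config, not taken from the training code, and `bits_to_alpha` is a hypothetical name.

```python
def bits_to_alpha(bits, min_bits=3, max_bits=8):
    """Map a bit width to a position in (0, 1] on the line (inferred formula)."""
    return (bits - min_bits + 1) / (max_bits - min_bits + 1)


# Positions for bit widths 3, 4, ..., 8.
alphas = [bits_to_alpha(b) for b in range(3, 9)]

# Values listed in configs/quantized/lcs_l.yaml, for comparison.
config_alpha_grid = [0.166, 0.333, 0.499, 0.666, 0.833, 1]
```

Each computed value agrees with the corresponding config entry to within rounding (for example, 1/6 versus 0.166).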
123 |
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import argparse
6 |
7 |
8 | def gen_args(desc):
9 |     parser = argparse.ArgumentParser(description=desc)
10 |
11 |     parser.add_argument(
12 |         "--model",
13 |         type=str,
14 |         default="cpreresnet20",
15 |         help="Which model architecture to use. One of cpreresnet20, resnet18, vgg19",
16 |     )
17 |
18 |     parser.add_argument(
19 |         "--dataset",
20 |         type=str,
21 |         default="cifar10",
22 |         help="Which dataset to use. One of cifar10, imagenet",
23 |     )
24 |
25 |     parser.add_argument(
26 |         "--imagenet_dir",
27 |         type=str,
28 |         help="The root directory to the ImageNet dataset",
29 |     )
30 |
31 |     parser.add_argument("--save_dir", type=str, help="Directory to save model")
32 |
33 |     parser.add_argument(
34 |         "--norm",
35 |         type=str,
36 |         help="Layer normalization type. One of BN, IN, GN",
37 |     )
38 |
39 |     parser.add_argument("--epochs", type=int, help="Number of training epochs")
40 |
41 |     parser.add_argument(
42 |         "--test_freq", type=int, help="Number of epochs between testing"
43 |     )
44 |
45 |     parser.add_argument(
46 |         "--learning_rate", type=float, help="Optimizer learning rate"
47 |     )
48 |
49 |     parser.add_argument(
50 |         "--batch_size", type=int, help="Training/test batch size"
51 |     )
52 |
53 |     parser.add_argument("--momentum", type=float, help="Optimizer momentum")
54 |
55 |     parser.add_argument(
56 |         "--weight_decay", type=float, help="L2 regularization parameter"
57 |     )
58 |
59 |     return parser
60 |
61 |
62 | def parse_unstructured_arguments():
63 |     parser = gen_args("Unstructured Sparsity Training")
64 |
65 |     parser.add_argument(
66 |         "--method",
67 |         type=str,
68 |         default="lcs_l",
69 |         help="Training method. One of target_topk, lcs_p, lcs_l",
70 |     )
71 |
72 |     parser.add_argument(
73 |         "--topk",
74 |         type=float,
75 |         help="Target topk (0,1) for BN training",
76 |     )
77 |
78 |     parser.add_argument(
79 |         "--topk_lower_bound",
80 |         type=float,
81 |         help="Lower bound (high accuracy) endpoint of the learned subspace",
82 |     )
83 |
84 |     parser.add_argument(
85 |         "--topk_upper_bound",
86 |         type=float,
87 |         help="Upper bound (high efficiency) endpoint of the learned subspace",
88 |     )
89 |
90 |     parser.add_argument(
91 |         "--warmup_budget",
92 |         type=float,
93 |         help="Value in range (0,100] denoting the percentage of epochs for warmup phase",
94 |     )
95 |
96 |     parser.add_argument(
97 |         "--eval_topk_grid",
98 |         type=lambda x: [float(w) for w in x.split(",")],
99 |         help="Will evaluate at these topk values",
100 |     )
101 |
102 |     return parser.parse_args()
103 |
104 |
105 | def parse_structured_arguments():
106 |     parser = gen_args("Structured Sparsity Training")
107 |
108 |     parser.add_argument(
109 |         "--method",
110 |         type=str,
111 |         default="lcs_l",
112 |         help="Training method. One of lec, ns, us, lcs_l, lcs_p",
113 |     )
114 |
115 |     parser.add_argument(
116 |         "--width_factors_list",
117 |         type=lambda x: [float(w) for w in x.split(",")],
118 |         help="Desired width factors for NS. Ex: --width_factors_list 0.25,0.5,0.75,1.0",
119 |     )
120 |
121 |     parser.add_argument(
122 |         "--width_factor_limits",
123 |         type=lambda x: [float(w) for w in x.split(",")],
124 |         help="US width factor lower and upper bounds. Ex: --width_factor_limits 0.25,1.0",
125 |     )
126 |
127 |     parser.add_argument(
128 |         "--width_factor_samples",
129 |         type=int,
130 |         help="Number of width factor samples for US sandwich rule",
131 |     )
132 |
133 |     parser.add_argument(
134 |         "--eval_width_factors",
135 |         "--list",
136 |         type=lambda x: [float(w) for w in x.split(",")],
137 |         help="Width factors at which to evaluate model. Ex: --eval_width_factors 0.25,0.5,0.75,1.0",
138 |     )
139 |
140 |     return parser.parse_args()
141 |
142 |
143 | def parse_quantized_arguments():
144 |     parser = gen_args("Quantized Training")
145 |
146 |     parser.add_argument(
147 |         "--method",
148 |         type=str,
149 |         default="lcs_l",
150 |         help="Training method. One of target_bit_width, lcs_p, lcs_l",
151 |     )
152 |
153 |     parser.add_argument(
154 |         "--bit_width",
155 |         type=int,
156 |         help="Target bit width after warmup phase. Used for target_bit_width training",
157 |     )
158 |
159 |     parser.add_argument(
160 |         "--eval_bit_widths",
161 |         type=lambda x: [float(w) for w in x.split(",")],
162 |         help="Number of bits at which to evaluate. Ex: --eval_bit_widths 3,4,5",
163 |     )
164 |
165 |     parser.add_argument(
166 |         "--bit_width_limits",
167 |         type=lambda x: [float(w) for w in x.split(",")],
168 |         help="Min/max number of bits to train line. Ex: --bit_width_limits 3,8",
169 |     )
170 |
171 |     return parser.parse_args()
172 |
173 |
174 | def validate_model_data(args):
175 |     """
176 |     Ensures that the specified model and dataset are compatible
177 |     """
178 |     implemented_models = ["cpreresnet20", "resnet18", "vgg19"]
179 |     implemented_datasets = ["cifar10", "imagenet"]
180 |
181 |     if args.model not in implemented_models:
182 |         raise ValueError(f"{args.model} not implemented.")
183 |
184 |     if args.dataset not in implemented_datasets:
185 |         raise ValueError(f"{args.dataset} not implemented.")
186 |
187 |     if args.dataset == "imagenet":
188 |         if args.model not in ("resnet18", "vgg19"):
189 |             raise ValueError(
190 |                 f"{args.model} does not support ImageNet. Supported models: resnet18, vgg19"
191 |             )
192 |     elif args.dataset == "cifar10":
193 |         if args.model not in ("cpreresnet20",):
194 |             raise ValueError(
195 |                 f"{args.model} does not support CIFAR. Supported models: cpreresnet20"
196 |             )
197 |
198 |     return args
199 |
200 |
201 | def validate_unstructured_params(args):
202 |     """
203 |     Ensures that the specified unstructured sparsity parameters are valid
204 |     """
205 |     lb = args.topk_lower_bound
206 |     ub = args.topk_upper_bound
207 |
208 |     if (lb is not None and ub is None) or (lb is None and ub is not None):
209 |         raise ValueError("Both upper and lower TopK bounds must be specified")
210 |
211 |     if lb is not None:
212 |         if lb < 0:
213 |             raise ValueError("TopK lower bound must be >= 0")
214 |         if ub > 1:
215 |             raise ValueError("TopK upper bound must be <= 1")
216 |         if lb >= ub:
217 |             raise ValueError("TopK lower bound must be < upper bound")
218 |         if args.eval_topk_grid is None:
219 |             raise ValueError(
220 |                 "eval_topk_grid must be specified when TopK bounds are"
221 |             )
222 |
223 |     return args
224 |
225 |
226 | def validate_structured_params(args):
227 |     """
228 |     Ensures that the specified structured sparsity parameters are valid
229 |     """
230 |     if args.method in ("lec", "ns"):
231 |         if args.norm is not None and args.norm != "BN":
232 |             raise ValueError("LEC and NS only implemented for BN layers")
233 |
234 |     if args.method == "lec" and args.model not in ("cpreresnet20", "vgg19"):
235 |         raise ValueError("LEC only implemented for cpreresnet20, vgg19")
236 |
237 |     return args
238 |
239 |
240 | def validate_quant_params(args):
241 |     """
242 |     Ensures that the specified quantized parameters are valid
243 |     """
244 |     bit_range = [3, 4, 5, 6, 7, 8]
245 |
246 |     if args.method == "target_bit_width":
247 |         if args.bit_width is not None and args.bit_width not in bit_range:
248 |             raise ValueError("Bit width must be one of 3, 4, 5, 6, 7, 8.")
249 |
250 |     if args.bit_width_limits is not None:
251 |         l_bit, u_bit = args.bit_width_limits
252 |         if l_bit < 3:
253 |             raise ValueError("Smallest bit width must be >= 3")
254 |
255 |         if u_bit > 8:
256 |             raise ValueError("Largest bit width must be <= 8")
257 |
258 |     return args
259 |
260 |
261 | def unstructured_args():
262 |     args = parse_unstructured_arguments()
263 |     args = validate_unstructured_params(args)
264 |
265 |     return validate_model_data(args)
266 |
267 |
268 | def structured_args():
269 |     args = parse_structured_arguments()
270 |     args = validate_structured_params(args)
271 |
272 |     return validate_model_data(args)
273 |
274 |
275 | def quantized_args():
276 |     args = parse_quantized_arguments()
277 |     args = validate_quant_params(args)
278 |
279 |     return validate_model_data(args)
280 |
--------------------------------------------------------------------------------
/configs/quantized/lcs_l.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "quantized"
7 | dataset: "cifar10"
8 | script: "train_curve.py"
9 | save_dir: "saved_models/quantized/"
10 | # Line:
11 | num_points: 2
12 | conv_type: "LinesConv" # First convolution layer
13 | bn_type: "LinesGN" # Batchnorm for first convolution layer
14 | block_conv_type: "LinesConvBn2d" # Convolution layer for blocks
15 | block_bn_type: "LinesGN" # Convolution batchnorm layer for blocks
16 | epochs: 200 # The number of epochs to train for
17 | batch_size: 128 # Train and test batch sizes
18 | warmup_budget: 80 # percentage of epochs for "warmup" phase
19 | test_freq: 20 # Will run test after this many epochs
20 | alpha_grid: [0.166, 0.333, 0.499, 0.666, 0.833, 1] # When bit widths {3, ..., 8}
21 | learning_rate: 0.025 # Optimizer learning rate
22 | momentum: 0.9 # Optimizer momentum
23 | weight_decay: 0.0005 # L2 regularization parameter
24 | min_bits: 3
25 | max_bits: 8
26 | discrete_alpha_map: True
27 | model_config:
28 | model_class: cpreresnet20
29 | model_kwargs:
30 | channel_selection_active: False
31 |
--------------------------------------------------------------------------------
/configs/quantized/lcs_p.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "quantized"
7 | dataset: "cifar10"
8 | script: "train_indep.py"
9 | save_dir: "saved_models/quantized/"
10 | epochs: 200 # The number of epochs to train for
11 | batch_size: 128 # Train and test batch sizes
12 | warmup_budget: 80 # decays topk from 1 to target topk over this percentage of epochs
13 | test_freq: 20 # Will run test after this many epochs
14 | learning_rate: 0.025 # Optimizer learning rate
15 | momentum: 0.9 # Optimizer momentum
16 | weight_decay: 0.0005 # L2 regularization parameter
17 | conv_type: "StandardConv" # convolution layer for first layer of network
18 | bn_type: "StandardGN" # batch norm layer for first layer of network
19 | block_conv_type: "ConvBn2d" # convolution layer to use for resnet blocks
20 | block_bn_type: "StandardGN" # batch norm layer to use for resnet blocks
21 | random_param: True # If true, will train with topk=1 during warmup and then use random topk values
22 | eval_param_grid: [3, 4, 5, 6, 7, 8]
23 | min_bits: 3
24 | max_bits: 8
25 | model_config:
26 | model_class: cpreresnet20
27 | model_kwargs:
28 | channel_selection_active: False
29 |
--------------------------------------------------------------------------------
/configs/quantized/target_bit_width.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "quantized"
7 | dataset: "cifar10"
8 | save_dir: "saved_models/quantized/"
9 | # Multiple trials parameters
10 | num_runs: 3
11 | script: "train_indep.py"
12 | epochs: 200 # The number of epochs to train for
13 | batch_size: 128 # Train and test batch sizes
14 | warmup_budget: 80 # decays topk from 1 to alpha over this percentage of epochs
15 | test_freq: 20 # Will run test after this many epochs
16 | learning_rate: 0.025 # Optimizer learning rate
17 | momentum: 0.9 # Optimizer momentum
18 | weight_decay: 0.0005 # L2 regularization parameter
19 | conv_type: "StandardConv"
20 | bn_type: "StandardBN"
21 | block_conv_type: "ConvBn2d"
22 | block_bn_type: "StandardBN"
23 | num_bits: 6 # Target number of bits
24 | random_bits: False # If true, will train with num_bits=8 during warmup and then use random bits in range [min_bits, max_bits]
25 | eval_param_grid: [3, 4, 5, 6, 7, 8]
26 | model_config:
27 | model_class: cpreresnet20
28 | model_kwargs:
29 | channel_selection_active: False
30 |
--------------------------------------------------------------------------------
/configs/structured_sparsity/lcs_l.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "us"
7 | dataset: "cifar10"
8 | script: "train_curve.py"
9 | save_dir: "saved_models/structured_sparsity/"
10 | # Line:
11 | num_points: 2
12 | conv_type: "AdaptiveConv2d" # First convolution layer
13 | bn_type: "LinesAdaptiveIN" # Batchnorm for first convolution layer
14 | block_conv_type: "AdaptiveConv2d" # Convolution layer for ResNet blocks
15 | block_bn_type: "LinesAdaptiveIN" # Convolution batchnorm layer for ResNet blocks
16 | epochs: 200 # The number of epochs to train for
17 | batch_size: 128 # Train and test batch sizes
18 | warmup_budget: 80 # percentage of epochs for "warmup" phase
19 | test_freq: 20 # Will run test after this many epochs
20 | # To evaluate at width factors [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25], we need the alpha grid to look like this.
21 | # (and we must appropriately set the width_factor_limits inside regime_params).
22 | alpha_grid: [1.0, 0.83333, 0.66667, 0.5, 0.33333, 0.166667, 0]
23 | learning_rate: 0.1 # Optimizer learning rate
24 | momentum: 0.9 # Optimizer momentum
25 | weight_decay: 0.0005 # L2 regularization parameter
26 | model_config:
27 | model_class: cpreresnet20
28 | model_kwargs:
29 | channel_selection_active: False
30 | regime_params:
31 | # eval_width_factors: [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25]
32 | width_factor_limits: [0.25, 1.0]
33 | width_factor_samples: 4
34 | width_factor_sampling_method: sandwich
35 | apply_beta_to_norm: True
36 |
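The comment above `alpha_grid` says the grid is chosen so that evaluation lands on specific width factors. A minimal sketch of that rescaling follows, assuming alpha is a linear map of the width factor onto [0, 1] between `width_factor_limits` (the same formula `structured_args_dict` uses). The helper name `width_factors_to_alphas` is hypothetical and not part of the repository.

```python
def width_factors_to_alphas(width_factors, limits):
    """Linearly rescale width factors in [low, high] to alphas in [0, 1]."""
    low, high = limits
    if high <= low:
        raise ValueError("limits must satisfy low < high")
    return [(w - low) / (high - low) for w in width_factors]


# Values taken from the config above.
eval_width_factors = [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25]
alphas = width_factors_to_alphas(eval_width_factors, (0.25, 1.0))
print([round(a, 5) for a in alphas])
# → [1.0, 0.83333, 0.66667, 0.5, 0.33333, 0.16667, 0.0]
```

Up to rounding, the result matches the `alpha_grid` listed in the config.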
--------------------------------------------------------------------------------
/configs/structured_sparsity/lcs_p.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "us"
7 | dataset: "cifar10"
8 | script: "train_indep.py"
9 | save_dir: "saved_models/structured_sparsity/"
10 | epochs: 200 # The number of epochs to train for
11 | batch_size: 128 # Train and test batch sizes
12 | warmup_budget: 80 # decays topk from 1 to target topk over this percentage of epochs
13 | test_freq: 20 # Will run test after this many epochs
14 | learning_rate: 0.1 # Optimizer learning rate
15 | momentum: 0.9 # Optimizer momentum
16 | weight_decay: 0.0005 # L2 regularization parameter
17 | conv_type: "AdaptiveConv2d" # convolution layer for first layer of network
18 | bn_type: "AdaptiveIN" # batch norm layer for first layer of network
19 | block_conv_type: "AdaptiveConv2d" # convolution layer to use for resnet blocks
20 | block_bn_type: "AdaptiveIN" # batch norm layer to use for resnet blocks
21 | topk: 1.0 # the target topk to reach after warmup phase
22 | random_topk: False # If true, will train with topk=1 during warmup and then use random topk values
23 | norm_kwargs:
24 | width_factors_list: null
25 | model_config:
26 | model_class: cpreresnet20
27 | model_kwargs:
28 | channel_selection_active: False
29 | regime_params:
30 | eval_width_factors: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25]
31 | width_factor_limits: [0.25, 1.0]
32 | width_factor_samples: 4
33 |
34 | eval_param_grid: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25] # Same as eval_width_factors
35 |
--------------------------------------------------------------------------------
/configs/structured_sparsity/lec.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "lec"
7 | bn_update_factor: 0.00001
8 | dataset: "cifar10"
9 | script: "train_indep.py"
10 | save_dir: "saved_models/structured_sparsity/"
11 | epochs: 200 # The number of epochs to train for
12 | batch_size: 128 # Train and test batch sizes
13 | warmup_budget: 80 # decays topk from 1 to alpha over this percentage of epochs
14 | test_freq: 20 # Will run test after this many epochs
15 | learning_rate: 0.1 # Optimizer learning rate
16 | momentum: 0.9 # Optimizer momentum
17 | weight_decay: 0.0005 # L2 regularization parameter
18 | conv_type: "StandardConv"
19 | bn_type: "StandardBN"
20 | block_conv_type: "StandardConv"
21 | block_bn_type: "StandardBN"
22 | topk: 1.0
23 | model_config:
24 | model_class: cpreresnet20
25 | model_kwargs:
26 | channel_selection_active: True
27 | eval_param_grid: [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 1.0]
28 |
--------------------------------------------------------------------------------
/configs/structured_sparsity/ns.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "ns"
7 | dataset: "cifar10"
8 | script: "train_indep.py"
9 | save_dir: "saved_models/structured_sparsity/"
10 | epochs: 200 # The number of epochs to train for
11 | batch_size: 128 # Train and test batch sizes
12 | warmup_budget: 80 # decays topk from 1 to target topk over this percentage of epochs
13 | test_freq: 20 # Will run test after this many epochs
14 | learning_rate: 0.1 # Optimizer learning rate
15 | momentum: 0.9 # Optimizer momentum
16 | weight_decay: 0.0005 # L2 regularization parameter
17 | conv_type: "AdaptiveConv2d" # convolution layer for first layer of network
18 | bn_type: "AdaptiveBN" # [sic] this norm is evaluatable at any point, and is otherwise the same as the AdaptiveIN
19 | block_conv_type: "AdaptiveConv2d" # convolution layer to use for resnet blocks
20 | block_bn_type: "AdaptiveBN" # batch norm layer to use for resnet blocks
21 | topk: 1.0 # the target topk to reach after warmup phase
22 | random_topk: False # If true, will train with topk=1 during warmup and then use random topk values
23 | norm_kwargs:
24 | track_running_stats: True
25 | width_factors_list: null
26 | builder_kwargs:
27 | width_factors_list: [1.0, 0.75, 0.50, 0.25]
28 | model_config:
29 | model_class: cpreresnet20
30 | model_kwargs:
31 | channel_selection_active: False
32 | regime_params:
33 | eval_width_factors: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25]
34 |
35 | eval_param_grid: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25] # Same as eval_width_factors
36 |
--------------------------------------------------------------------------------
/configs/structured_sparsity/us.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "us"
7 | dataset: "cifar10"
8 | script: "train_indep.py"
9 | save_dir: "saved_models/structured_sparsity/"
10 | epochs: 200 # The number of epochs to train for
11 | batch_size: 128 # Train and test batch sizes
12 | warmup_budget: 80 # decays topk from 1 to target topk over this percentage of epochs
13 | test_freq: 20 # Will run test after this many epochs
14 | learning_rate: 0.1 # Optimizer learning rate
15 | momentum: 0.9 # Optimizer momentum
16 | weight_decay: 0.0005 # L2 regularization parameter
17 | conv_type: "AdaptiveConv2d" # convolution layer for first layer of network
18 | bn_type: "AdaptiveBN" # batch norm layer for first layer of network
19 | block_conv_type: "AdaptiveConv2d" # convolution layer to use for resnet blocks
20 | block_bn_type: "AdaptiveBN" # batch norm layer to use for resnet blocks
21 | topk: 1.0 # the target topk to reach after warmup phase
22 | random_topk: False # If true, will train with topk=1 during warmup and then use random topk values
23 | norm_kwargs:
24 | width_factors_list: null
25 | model_config:
26 | model_class: cpreresnet20
27 | model_kwargs:
28 | channel_selection_active: False
29 | regime_params:
30 | eval_width_factors: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25]
31 | width_factor_limits: [0.25, 1.0]
32 | width_factor_samples: 4
33 |
34 | eval_param_grid: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25] # Same as eval_width_factors
35 |
--------------------------------------------------------------------------------
/configs/unstructured_sparsity/lcs_l.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "sparse"
7 | dataset: "cifar10"
8 | script: "train_curve.py"
9 | save_dir: "saved_models/unstructured_sparsity/"
10 | # Line:
11 | num_points: 2
12 | conv_type: "LinesConv" # First convolution layer
13 | bn_type: "LinesGN" # Normalization layer for first convolution layer
14 | block_conv_type: "SparseLinesConv" # Convolution layer for ResNet blocks
15 | block_bn_type: "LinesGN" # Normalization layer for ResNet blocks
16 | epochs: 200 # The number of epochs to train for
17 | batch_size: 128 # Train and test batch sizes
18 | warmup_budget: 80 # percentage of epochs for "warmup" phase
19 | test_freq: 20 # Will run test after this many epochs
20 | alpha_grid: [0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.25, 0.5, 0.75, 0.9, 0.95, 1.0] # eval at these alphas
21 | learning_rate: 0.1 # Optimizer learning rate
22 | momentum: 0.9 # Optimizer momentum
23 | weight_decay: 0.0005 # L2 regularization parameter
24 | alpha_sampling: [0.025, 1, 0.50] # Biased endpoint sampling
25 | model_config:
26 | model_class: cpreresnet20
27 | model_kwargs:
28 | channel_selection_active: False
29 |
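The `alpha_sampling: [0.025, 1, 0.50]` entry is read as `[low, high, endpoint_prob]` by `alpha_sampling` in `curve_utils.py`. A stdlib-only sketch of that biased endpoint sampling follows. It swaps NumPy's global RNG for Python's `random` module, and the function name `sample_alpha` and the seeded generator are assumptions made for illustration.

```python
import random


def sample_alpha(low, high, endpoint_prob, rng=random):
    """Return an endpoint with probability endpoint_prob,
    otherwise a uniform draw from [low, high]."""
    if rng.random() < endpoint_prob:
        # Pick one of the two endpoints with equal probability.
        return low if rng.random() < 0.5 else high
    return rng.uniform(low, high)


rng = random.Random(0)  # seeded only to make the demo repeatable
samples = [sample_alpha(0.025, 1.0, 0.5, rng) for _ in range(10_000)]
endpoint_share = sum(s in (0.025, 1.0) for s in samples) / len(samples)
print(f"share of draws at an endpoint: {endpoint_share:.2f}")
```

With `endpoint_prob = 0.5`, roughly half of the draws should fall exactly on `low` or `high`, which concentrates training on the two ends of the line.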
--------------------------------------------------------------------------------
/configs/unstructured_sparsity/lcs_p.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "sparse"
7 | dataset: "cifar10"
8 | script: "train_indep.py"
9 | save_dir: "saved_models/unstructured_sparsity/"
10 | epochs: 200 # The number of epochs to train for
11 | batch_size: 128 # Train and test batch sizes
12 | warmup_budget: 80 # decays topk from 1 to target topk over this percentage of epochs
13 | test_freq: 20 # Will run test after this many epochs
14 | learning_rate: 0.1 # Optimizer learning rate
15 | momentum: 0.9 # Optimizer momentum
16 | weight_decay: 0.0005 # L2 regularization parameter
17 | conv_type: "StandardConv" # convolution layer for first layer of network
18 | bn_type: "StandardGN" # batch norm layer for first layer of network
19 | block_conv_type: "SparseConv2d" # convolution layer to use for resnet blocks
20 | block_bn_type: "StandardGN" # batch norm layer to use for resnet blocks
21 | topk: 0.05 # the target topk to reach after warmup phase. If random_param true, will test at this value
22 | random_param: True # If true, will train with topk=1 during warmup and then use random topk values
23 | eval_param_grid: [0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.25, 0.5, 0.75, 0.9, 0.95, 1.0]
24 | model_config:
25 | model_class: cpreresnet20
26 | model_kwargs:
27 | channel_selection_active: False
28 |
--------------------------------------------------------------------------------
/configs/unstructured_sparsity/target_topk.yaml:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | parameters:
6 | regime: "sparse"
7 | dataset: "cifar10"
8 | script: "train_indep.py"
9 | save_dir: "saved_models/unstructured_sparsity/"
10 | epochs: 200 # The number of epochs to train for
11 | batch_size: 128 # Train and test batch sizes
12 | warmup_budget: 80 # decays topk from 1 to the target over this percentage of epochs
13 | test_freq: 20 # Will run test after this many epochs
14 | learning_rate: 0.1 # Optimizer learning rate
15 | momentum: 0.9 # Optimizer momentum
16 | weight_decay: 0.0005 # L2 regularization parameter
17 | conv_type: "StandardConv"
18 | bn_type: "StandardBN"
19 | block_conv_type: "SparseConv2d"
20 | block_bn_type: "StandardBN"
21 | topk: 0.9 # the target topk to reach after warmup phase
22 | eval_param_grid: [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 1.0]
23 | model_config:
24 | model_class: cpreresnet20
25 | model_kwargs:
26 | channel_selection_active: False
27 |
--------------------------------------------------------------------------------
/curve_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import numpy as np
6 | import torch.nn as nn
7 |
8 |
9 | def get_stats(model):
10 | norms = {}
11 | numerators = {}
12 | difs = {}
13 | cossim = 0
14 | l2 = 0
15 | num_points = 2
16 |
17 | for i in range(num_points):
18 | norms[f"{i}"] = 0.0
19 | for j in range(i + 1, num_points):
20 | numerators[f"{i}-{j}"] = 0.0
21 | difs[f"{i}-{j}"] = 0.0
22 |
23 | for m in model.modules():
24 | if isinstance(m, nn.Conv2d):
25 | for i in range(num_points):
26 | vi = get_weight(m, i)
27 | norms[f"{i}"] += vi.pow(2).sum()
28 | for j in range(i + 1, num_points):
29 | vj = get_weight(m, j)
30 | numerators[f"{i}-{j}"] += (vi * vj).sum()
31 | difs[f"{i}-{j}"] += (vi - vj).pow(2).sum()
32 |
33 | for i in range(num_points):
34 | for j in range(i + 1, num_points):
35 | cossim += numerators[f"{i}-{j}"].pow(2) / (
36 | norms[f"{i}"] * norms[f"{j}"]
37 | )
38 | l2 += difs[f"{i}-{j}"]
39 |
40 | l2 = l2.pow(0.5).item()
41 | cossim = cossim.item()
42 | return cossim, l2
43 |
44 |
45 | def get_weight(m, i):
46 | if i == 0:
47 | return m.weight
48 | return getattr(m, f"weight{i}")
49 |
50 |
51 | def alpha_bit_map(alpha, **regime_params):
52 | """
53 | Maps a continuous alpha value in [0, 1] to a bit width. E.g. with min_bits=1, max_bits=8:
54 | alpha in [0, 1/8] => 1
55 | alpha in (1/8, 2/8] => 2
56 | alpha in (2/8, 3/8] => 3
57 | alpha in (3/8, 4/8] => 4
58 | alpha in (4/8, 5/8] => 5
59 | alpha in (5/8, 6/8] => 6
60 | alpha in (6/8, 7/8] => 7
61 | alpha in (7/8, 8/8] => 8
62 | """
63 | min_bits = regime_params["min_bits"]
64 | max_bits = regime_params["max_bits"]
65 | distinct_bits = max_bits - min_bits + 1
66 | for i in range(distinct_bits):
67 | if i / distinct_bits <= alpha <= (i + 1) / distinct_bits:
68 | return np.arange(min_bits, max_bits + 1)[i]
69 |
70 |
71 | def sample_alpha_num_bits(**regime_params):
72 | discrete = regime_params["discrete"]
73 | if discrete:
74 | min_bits = regime_params["min_bits"]
75 | max_bits = regime_params["max_bits"]
76 | distinct_bits = max_bits - min_bits + 1
77 | alpha = (
78 | np.random.choice(np.arange(1, distinct_bits + 1)) / distinct_bits
79 | )
80 | else:
81 | alpha = np.random.uniform(0, 1)
82 |
83 | num_bits = alpha_bit_map(alpha, **regime_params)
84 |
85 | return alpha, num_bits
86 |
87 |
88 | def alpha_sampling(**regime_params):
89 | # biased endpoint sampling
90 | if regime_params.get("alpha_sampling") is not None:
91 | low, high, endpoint_prob = regime_params.get("alpha_sampling")
92 | if np.random.rand() < endpoint_prob:
93 | # Pick an endpoint at random.
94 | if np.random.rand() < 0.5:
95 | alpha = low
96 | else:
97 | alpha = high
98 | else:
99 | alpha = np.random.uniform(low, high)
100 | else:
101 | alpha = np.random.uniform(0, 1)
102 |
103 | return alpha
104 |
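The bucket boundaries used by `alpha_bit_map` can be checked in isolation. The sketch below mirrors its loop without NumPy. The name `alpha_to_bits` and the explicit `None` return for out-of-range input are assumptions for illustration. Because the loop returns the first bucket whose closed interval contains alpha, a value that sits exactly on a boundary maps to the lower bit width.

```python
def alpha_to_bits(alpha, min_bits, max_bits):
    """Return the bit width of the first bucket containing alpha,
    or None if alpha lies outside [0, 1]."""
    distinct_bits = max_bits - min_bits + 1
    for i in range(distinct_bits):
        # Closed on both sides; the first match wins on shared boundaries.
        if i / distinct_bits <= alpha <= (i + 1) / distinct_bits:
            return min_bits + i
    return None


for alpha in (0.0, 0.125, 0.2, 0.25, 1.0):
    print(alpha, alpha_to_bits(alpha, min_bits=1, max_bits=8))
```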
--------------------------------------------------------------------------------
/get_training_params.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import numpy as np
6 |
7 | import training_params
8 | import utils
9 |
10 |
11 | def gen_args_dict(args):
12 | a_dict = {}
13 |
14 | if args.epochs is not None:
15 | a_dict["epochs"] = args.epochs
16 |
17 | if args.test_freq is not None:
18 | a_dict["test_freq"] = args.test_freq
19 |
20 | if args.learning_rate is not None:
21 | a_dict["learning_rate"] = args.learning_rate
22 |
23 | if args.batch_size is not None:
24 | a_dict["batch_size"] = args.batch_size
25 |
26 | if args.momentum is not None:
27 | a_dict["momentum"] = args.momentum
28 |
29 | if args.weight_decay is not None:
30 | a_dict["weight_decay"] = args.weight_decay
31 |
32 | if args.save_dir is not None:
33 | a_dict["save_dir"] = args.save_dir
34 |
35 | if args.dataset == "imagenet":
36 | if args.imagenet_dir is not None:
37 | a_dict["dataset_dir"] = args.imagenet_dir
38 | else:
39 | raise ValueError(
40 | "ImageNet data directory must be specified via --imagenet_dir."
41 | )
42 |
43 | return a_dict
44 |
45 |
46 | def unstructured_args_dict(args):
47 | ua_dict = gen_args_dict(args)
48 |
49 | if args.topk is not None:
50 | ua_dict["topk"] = args.topk
51 |
52 | if args.warmup_budget is not None:
53 | ua_dict["warmup_budget"] = args.warmup_budget
54 |
55 | if args.eval_topk_grid is not None:
56 | if args.method == "lcs_l":
57 | ua_dict["alpha_grid"] = args.eval_topk_grid
58 | else:
59 | ua_dict["eval_param_grid"] = args.eval_topk_grid
60 |
61 | if (
62 | args.topk_lower_bound is not None
63 | and args.topk_upper_bound is not None
64 | and args.eval_topk_grid is not None
65 | ):
66 | if args.method == "lcs_l":
67 | ua_dict["alpha_grid"] = args.eval_topk_grid
68 | elif args.method == "lcs_p":
69 | ua_dict["eval_param_grid"] = args.eval_topk_grid
70 |
71 | ua_dict["alpha_sampling"] = [
72 | args.topk_lower_bound,
73 | args.topk_upper_bound,
74 | 0.5,
75 | ]
76 |
77 | return ua_dict
78 |
79 |
80 | def structured_args_dict(args, base_config):
81 | sa_dict = gen_args_dict(args)
82 |
83 | if args.width_factors_list is not None and args.method == "ns":
84 | builder_kwargs = base_config["parameters"]["builder_kwargs"]
85 | builder_kwargs["width_factors_list"] = args.width_factors_list
86 | sa_dict["builder_kwargs"] = builder_kwargs
87 |
88 | if args.width_factor_limits is not None and args.method != "ns":
89 | regime_params = base_config["parameters"]["regime_params"]
90 | regime_params["width_factor_limits"] = args.width_factor_limits
91 | sa_dict["regime_params"] = regime_params
92 |
93 | if args.width_factor_samples is not None and args.method != "ns":
94 | regime_params = sa_dict.get(
95 | "regime_params", base_config["parameters"]["regime_params"]
96 | )
97 | regime_params["width_factor_samples"] = args.width_factor_samples
98 | sa_dict["regime_params"] = regime_params
99 |
100 | if args.eval_width_factors is not None:
101 | regime_params = sa_dict.get(
102 | "regime_params", base_config["parameters"]["regime_params"]
103 | )
104 | if args.method in ("us", "ns", "lcs_p"):
105 | regime_params["eval_width_factors"] = args.eval_width_factors
106 | sa_dict["regime_params"] = regime_params
107 | sa_dict["eval_param_grid"] = args.eval_width_factors
108 | elif args.method == "lcs_l":
109 | w_l, w_u = regime_params["width_factor_limits"]
110 | alpha_grid = [
111 | (w_f - w_l) / (w_u - w_l) for w_f in args.eval_width_factors
112 | ]
113 | sa_dict["alpha_grid"] = alpha_grid
114 |
115 | return sa_dict
116 |
117 |
118 | def quantized_args_dict(args):
119 | q_dict = gen_args_dict(args)
120 |
121 | if args.bit_width is not None and args.method == "target_bit_width":
122 | q_dict["num_bits"] = args.bit_width
123 |
124 | if args.eval_bit_widths is not None and args.method != "lcs_l":
125 | q_dict["eval_param_grid"] = args.eval_bit_widths
126 |
127 | if args.bit_width_limits is not None and args.method in ("lcs_p", "lcs_l"):
128 | min_bits, max_bits = [int(x) for x in args.bit_width_limits]
129 | q_dict["min_bits"] = min_bits
130 | q_dict["max_bits"] = max_bits
131 | if args.method == "lcs_l":
132 | range_len = max_bits - min_bits + 1
133 | alpha_grid = [
134 | np.floor(x * 1000) / 1000
135 | for x in np.linspace(0, 1, range_len + 1)
136 | ][1:]
137 | q_dict["alpha_grid"] = alpha_grid
138 | elif args.method == "lcs_p":
139 | if args.eval_bit_widths is None:
140 | eval_param_grid = np.arange(min_bits, max_bits + 1)
141 | q_dict["eval_param_grid"] = eval_param_grid
142 |
143 | return q_dict
144 |
145 |
146 | def get_config_norm(params):
147 | norm_types = ["IN", "BN", "GN"]
148 | base_bn_type = params["bn_type"]
149 | base_block_bn_type = params["block_bn_type"]
150 | for n in norm_types:
151 | if n in base_bn_type:
152 | base_bn = n
153 | if n in base_block_bn_type:
154 | base_block_bn = n
155 |
156 | for norm in norm_types:
157 | if norm in base_bn and norm in base_block_bn:
158 | return norm
159 |
160 |
161 | def get_method_config(args, setting):
162 | method = args.method
163 |
164 | base_config_dir = f"configs/{setting}/{method}.yaml"
165 |
166 | try:
167 | base_config = utils.get_yaml_config(base_config_dir)
168 | except Exception:
169 | raise ValueError(f"{setting}/{method} not valid training configuration")
170 |
171 | params = base_config["parameters"]
172 |
173 | # Update base config with specific parameters
174 |
175 | # Set normalization layers
176 | config_norm = get_config_norm(params)
177 | norm_to_use = config_norm if args.norm is None else args.norm
178 | norm_types = ["IN", "BN", "GN"]
179 | if norm_to_use not in norm_types:
180 | raise ValueError(
181 | f"Norm {norm_to_use} not valid. Supported normalization types: IN, BN, GN."
182 | )
183 |
184 | base_bn_type = base_config["parameters"]["bn_type"]
185 | base_block_bn_type = base_config["parameters"]["block_bn_type"]
186 | for n in norm_types:
187 | if n in base_bn_type:
188 | base_bn = n
189 | if n in base_block_bn_type:
190 | base_block_bn = n
191 |
192 | if setting == "quantized" and norm_to_use == "BN":
193 | base_config["parameters"]["bn_type"] = "QuantStandardBN"
194 | base_config["parameters"]["block_bn_type"] = "QuantStandardBN"
195 | else:
196 | base_config["parameters"]["bn_type"] = base_bn_type.replace(
197 | base_bn, norm_to_use
198 | )
199 | base_config["parameters"]["block_bn_type"] = base_block_bn_type.replace(
200 | base_block_bn, norm_to_use
201 | )
202 | # Set track_running_stats to True if using BN
203 | if norm_to_use == "BN":
204 | base_norm_kwargs = base_config["parameters"].get(
205 | "norm_kwargs", None
206 | )
207 | if base_norm_kwargs is not None:
208 | base_norm_kwargs["track_running_stats"] = True
209 | else:
210 | base_config["parameters"]["norm_kwargs"] = {
211 | "track_running_stats": True
212 | }
213 |
214 | # Update dataset
215 | dataset = args.dataset
216 | base_config["parameters"]["dataset"] = dataset
217 |
218 | # Update model
219 | model_name = args.model.lower()
220 | base_config["parameters"]["model_config"]["model_class"] = model_name
221 | model_kwargs = params.get("model_kwargs", None)
222 | if model_kwargs is not None:
223 | base_model_kwargs = base_config["parameters"]["model_config"].get(
224 | "model_kwargs", None
225 | )
226 | if base_model_kwargs is not None:
227 | base_model_kwargs.update(model_kwargs)
228 | else:
229 | base_config["parameters"]["model_config"]["model_kwargs"] = {}
230 |
231 | # Remove channel_selection_active for models without this parameter
232 | if model_name not in ("cpreresnet20",):
233 | base_model_kwargs = base_config["parameters"]["model_config"][
234 | "model_kwargs"
235 | ]
236 | base_model_kwargs.pop("channel_selection_active", None)
237 |
238 | # For unstructured sparsity, if using GN, set num_groups to 32
239 | # We cannot use GN with structured sparsity (number of channels isn't always
240 | # divisible by 32), we use IN instead.
241 | if norm_to_use == "GN":
242 | if setting in ("unstructured_sparsity", "quantized"):
243 | num_groups = 32
244 | else:
245 | raise NotImplementedError(
246 | f"GroupNorm disabled for setting={setting}."
247 | )
248 | base_norm_kwargs = base_config["parameters"].get("norm_kwargs", None)
249 | if base_norm_kwargs is not None:
250 | base_norm_kwargs["num_groups"] = num_groups
251 | else:
252 | base_config["parameters"]["norm_kwargs"] = {
253 | "num_groups": num_groups
254 | }
255 |
256 | # Set default model training parameters
257 | model_training_params = training_params.model_data_params(args)
258 | base_config["parameters"].update(model_training_params)
259 |
260 | # Update training parameters with user-specified ones
261 | if setting == "unstructured_sparsity":
262 | args_dict = unstructured_args_dict(args)
263 | elif setting == "structured_sparsity":
264 | args_dict = structured_args_dict(args, base_config)
265 | elif setting == "quantized":
266 | args_dict = quantized_args_dict(args)
267 | else:
268 | args_dict = {}
269 |
270 | base_config["parameters"].update(args_dict)
271 |
272 | # Get/make save directory
273 | args_save_dir = args.save_dir
274 | if args_save_dir is None:
275 | config_save_dir = params["save_dir"]
276 | save_dir = utils.create_save_dir(
277 | config_save_dir, method, model_name, dataset, norm_to_use
278 | )
279 | else:
280 | save_dir = utils.create_save_dir(
281 | args_save_dir, method, model_name, dataset, norm_to_use, False
282 | )
283 |
284 | base_config["parameters"]["save_dir"] = save_dir
285 |
286 | # Print training details
287 | utils.print_train_params(
288 | base_config, setting, method, norm_to_use, save_dir
289 | )
290 |
291 | return base_config
292 |
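For the quantized `lcs_l` method, `quantized_args_dict` derives `alpha_grid` from the bit-width limits: one point per bit width, each truncated to three decimals. A stdlib-only sketch of that computation follows. It replaces `np.linspace` with plain division, and the helper name `bit_width_alpha_grid` is hypothetical.

```python
import math


def bit_width_alpha_grid(min_bits, max_bits):
    """One alpha per bit width: k / n for k = 1..n, truncated to 3 decimals."""
    range_len = max_bits - min_bits + 1
    points = [k / range_len for k in range(1, range_len + 1)]
    return [math.floor(x * 1000) / 1000 for x in points]


print(bit_width_alpha_grid(3, 8))
# → [0.166, 0.333, 0.5, 0.666, 0.833, 1.0]
```

Truncating rather than rounding keeps each grid point at or below the upper edge of its bucket, which appears to be what lets `alpha_bit_map` resolve each point to a distinct bit width.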
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import train_curve
6 | import train_indep
7 | from get_training_params import get_method_config
8 |
9 |
10 | def train(args, setting):
11 | config = get_method_config(args, setting)
12 | if config["parameters"]["script"] == "train_curve.py":
13 | train_curve.train_model(config)
14 | else:
15 | train_indep.train_model(config)
16 |
--------------------------------------------------------------------------------
/model_logging.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import os
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 | import utils
12 | from curve_utils import alpha_bit_map
13 | from models.networks import model_profiling
14 |
15 |
16 | def save_model_at_epoch(model, epoch, save_dir, save_per=10):
17 | """
18 | Saves the model's state_dict as 'model_{epoch}' in save_dir when epoch is a multiple of save_per.
19 | """
20 | if epoch % save_per == 0:
21 | save_path = os.path.join(save_dir, f"model_{epoch}")
22 | torch.save(model.state_dict(), save_path)
23 | else:
24 | print(
25 | f"Skipping saving at epoch {epoch} (saving every {save_per} epochs)."
26 | )
27 |
28 |
29 | def save(save_name, log, save_dir):
30 | save_path = os.path.join(save_dir, save_name)
31 | np.save(save_path, log)
32 |
33 |
34 | def _handle_bn_stats(metric_dict, extra_metrics, regime_params):
35 | if metric_dict is None:
36 | return None
37 |
38 | if extra_metrics is None or "bn_metrics" not in extra_metrics:
39 | print("No BN metrics found, skipping BN stats.")
40 | return None
41 |
42 | bn_metrics = extra_metrics["bn_metrics"]
43 |
44 | for name, value in bn_metrics.items():
45 | if name not in metric_dict:
46 | metric_dict[name] = []
47 | metric_dict[name].append(value)
48 | return metric_dict
49 |
50 |
51 | def sparse_logging(
52 | model,
53 | loss,
54 | acc,
55 | model_type,
56 | param=None,
57 | metric_dict=None,
58 | extra_metrics=None,
59 | **regime_params,
60 | ):
61 | metric_dict = _handle_bn_stats(metric_dict, extra_metrics, regime_params)
62 |
63 | sparsity = utils.get_sparsity_rate(model)
64 |
65 | param_name = "alpha" if model_type == "curve" else "topk"
66 | param_print = "" if param is None else f" ({param_name} = {param})"
67 | print(
68 | f"Test set{param_print}: Average loss: {loss:.4f} | Accuracy: {acc:.4f} | Sparsity: {sparsity:.4f}"
69 | )
70 |
71 | if metric_dict is not None:
72 | metric_dict["acc"].append(acc)
73 | metric_dict["sparsity"].append(sparsity)
74 | if param_name not in metric_dict:
75 | metric_dict[param_name] = []
76 | metric_dict[param_name].append(param)
77 |
78 | return metric_dict
79 |
80 |
81 | def quantized_logging(
82 | model,
83 | loss,
84 | acc,
85 | model_type,
86 | param=None,
87 | metric_dict=None,
88 | extra_metrics=None,
89 | **regime_params,
90 | ):
91 | metric_dict = _handle_bn_stats(metric_dict, extra_metrics, regime_params)
92 |
93 | if model_type == "curve":
94 | param_name = "alpha"
95 | inv_alpha = regime_params.get("inv_alpha", False)
96 | if inv_alpha:
97 | num_bits = alpha_bit_map(1 - param, **regime_params)
98 | else:
99 | num_bits = alpha_bit_map(param, **regime_params)
100 | param_print = f" (alpha/num_bits = {param}/{num_bits})"
101 | else:
102 | param_name = "num_bits"
103 | num_bits = param
104 | param_print = "" if param is None else f" ({param_name} = {num_bits})"
105 |
106 | print(
107 | f"Test set{param_print}: Average loss: {loss:.4f} | Accuracy: {acc:.4f}"
108 | )
109 |
110 | if metric_dict is not None:
111 | metric_dict["acc"].append(acc)
112 | metric_dict["num_bits"].append(num_bits)
113 | if param_name != "num_bits": # Append alpha for curves
114 | metric_dict[param_name].append(param)
115 |
116 | return metric_dict
117 |
118 |
119 | def lec_logging(
120 | model,
121 | loss,
122 | acc,
123 | model_type,
124 | param=None,
125 | metric_dict=None,
126 | extra_metrics=None,
127 | **regime_params,
128 | ):
129 | metric_dict = _handle_bn_stats(metric_dict, extra_metrics, regime_params)
130 |
131 | sparsity = regime_params["sparsity"]
132 |
133 | param_name = "alpha" if model_type == "curve" else "topk"
134 | param_print = "" if param is None else f" ({param_name} = {param})"
135 | print(
136 | f"Test set{param_print}: Average loss: {loss:.4f} | Accuracy: {acc:.4f} | Sparsity: {sparsity:.4f}"
137 | )
138 |
139 | if metric_dict is not None:
140 | metric_dict["acc"].append(acc)
141 | metric_dict["sparsity"].append(sparsity)
142 | if param_name not in metric_dict:
143 | metric_dict[param_name] = []
144 | metric_dict[param_name].append(param)
145 |
146 | return metric_dict
147 |
148 |
149 | def ns_logging(
150 | model,
151 | loss,
152 | acc,
153 | model_type,
154 | param=None,
155 | metric_dict=None,
156 | extra_metrics=None,
157 | **regime_params,
158 | ):
159 | metric_dict = _handle_bn_stats(metric_dict, extra_metrics, regime_params)
160 |
161 | input_size = regime_params["input_size"]
162 |
163 | is_cuda = next(model.parameters()).is_cuda
164 | n_macs, n_params = model_profiling.model_profiling(
165 | model, input_size, input_size, use_cuda=is_cuda
166 | )
167 |
168 | total_params = 0
169 | for module in model.modules():
170 | if isinstance(
171 | module, (nn.Linear, nn.Conv2d)
172 | ): # Make sure we also count contributions from non-sparse elements.
173 | total_params += module.weight.numel()
174 |
175 | sparsity = (total_params - n_params) / total_params
176 |
177 | param_name = "width_factor"
178 | param_print = f" ({param_name} = {regime_params[param_name]})"
179 | print(
180 | f"Test set{param_print}: Average loss: {loss:.4f} | Accuracy: {acc:.4f} | Sparsity: {sparsity:.4f}"
181 | )
182 |
183 | if metric_dict is not None:
184 | metric_dict["acc"].append(acc)
185 | metric_dict["sparsity"].append(sparsity)
186 | metric_dict["alpha"].append(param)
187 | if param_name not in metric_dict:
188 | metric_dict[param_name] = []
189 | metric_dict[param_name].append(regime_params[param_name])
190 |
191 | return metric_dict
192 |
193 |
194 | def us_logging(*args, **kwargs):
195 | return ns_logging(*args, **kwargs)
196 |
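The logging functions above share two small patterns: sparsity is computed as the fraction of dense parameters that are no longer active, and each metric is appended to a per-name list that is created on first use. A minimal sketch of both follows, assuming the parameter counts are already known. The helper names `sparsity_rate` and `record` are hypothetical and do not exist in the repository.

```python
def sparsity_rate(total_params, active_params):
    """Fraction of dense parameters that are not active."""
    if total_params <= 0:
        raise ValueError("total_params must be positive")
    return (total_params - active_params) / total_params


def record(metric_dict, name, value):
    """Append value under name, creating the list on first use."""
    metric_dict.setdefault(name, []).append(value)
    return metric_dict


metrics = {}
record(metrics, "acc", 0.91)
record(metrics, "sparsity", sparsity_rate(1000, 250))
record(metrics, "width_factor", 0.5)
print(metrics)
```

Using `setdefault` avoids the `KeyError` that a bare `metric_dict[name].append(...)` raises when the key has not been initialized.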
--------------------------------------------------------------------------------
/models/builder.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | from . import init
7 | from . import modules
8 | from . import quantized_modules
9 | from . import sparse_modules
10 |
11 |
12 | class Builder:
13 | def __init__(
14 | self,
15 | conv_type="LinesConv",
16 | bn_type="LinesBN",
17 | conv_init="kaiming_normal",
18 | norm_kwargs=None,
19 | pass_first_last=False,
20 | **conv_kwargs
21 | ):
22 | if norm_kwargs is None:
23 | norm_kwargs = {}
24 | self.pass_first_last = pass_first_last
25 |
26 | self.conv_kwargs = conv_kwargs
27 | self.norm_kwargs = norm_kwargs
28 |
29 | if hasattr(modules, bn_type):
30 | self.bn_layer = getattr(modules, bn_type)
31 | elif hasattr(sparse_modules, bn_type):
32 | self.bn_layer = getattr(sparse_modules, bn_type)
33 | elif hasattr(quantized_modules, bn_type):
34 | self.bn_layer = getattr(quantized_modules, bn_type)
35 | else:
36 | raise ValueError("Normalization layer not found")
37 |
38 | if hasattr(modules, conv_type):
39 | self.conv_layer = getattr(modules, conv_type)
40 | elif hasattr(sparse_modules, conv_type):
41 | self.conv_layer = getattr(sparse_modules, conv_type)
42 | elif hasattr(quantized_modules, conv_type):
43 | self.conv_layer = getattr(quantized_modules, conv_type)
44 | self.conv_kwargs[
45 | "bn_module"
46 | ] = self.bn_layer # self.bn_layer chosen above
47 | self.conv_kwargs["norm_kwargs"] = norm_kwargs
48 | # Overwrite self.bn_layer to no-op batchnorm since handled by ConvBn
49 | self.bn_layer = getattr(quantized_modules, "NoOpBN")
50 | else:
51 | raise ValueError("Convolution layer not found")
52 |
53 | self.conv_init = getattr(init, conv_init)
54 |
55 | def conv(
56 | self,
57 | kernel_size,
58 | in_planes,
59 | out_planes,
60 | stride=1,
61 | groups=1,
62 | first_layer=False,
63 | last_layer=False,
64 | is_conv=False,
65 | ):
66 | conv_kwargs = self.conv_kwargs.copy()
67 | if self.pass_first_last:
68 | conv_kwargs["first_layer"] = first_layer
69 | conv_kwargs["last_layer"] = last_layer
70 |
71 | if kernel_size == 1:
72 | conv = self.conv_layer(
73 | in_planes,
74 | out_planes,
75 | kernel_size=1,
76 | stride=stride,
77 | bias=False,
78 | **conv_kwargs
79 | )
80 | elif kernel_size == 3:
81 | conv = self.conv_layer(
82 | in_channels=in_planes,
83 | out_channels=out_planes,
84 | kernel_size=3,
85 | stride=stride,
86 | padding=1,
87 | groups=groups,
88 | bias=False,
89 | **conv_kwargs
90 | )
91 | elif kernel_size == 5:
92 | conv = self.conv_layer(
93 | in_planes,
94 | out_planes,
95 | kernel_size=5,
96 | stride=stride,
97 | padding=2,
98 | groups=groups,
99 | bias=False,
100 | **conv_kwargs
101 | )
102 | elif kernel_size == 7:
103 | conv = self.conv_layer(
104 | in_planes,
105 | out_planes,
106 | kernel_size=7,
107 | stride=stride,
108 | padding=3,
109 | groups=groups,
110 | bias=False,
111 | **conv_kwargs
112 | )
113 | else:
114 | return None
115 |
116 | conv.first_layer = first_layer
117 | conv.last_layer = last_layer
118 | conv.is_conv = is_conv
119 | self.conv_init(conv.weight)
120 | if hasattr(conv, "initialize"):
121 | conv.initialize(self.conv_init)
122 | return conv
123 |
124 | def conv1x1(
125 | self,
126 | in_planes,
127 | out_planes,
128 | stride=1,
129 | groups=1,
130 | first_layer=False,
131 | last_layer=False,
132 | is_conv=False,
133 | ):
134 | """1x1 convolution (no padding)"""
135 | c = self.conv(
136 | 1,
137 | in_planes,
138 | out_planes,
139 | stride=stride,
140 | groups=groups,
141 | first_layer=first_layer,
142 | last_layer=last_layer,
143 | is_conv=is_conv,
144 | )
145 |
146 | return c
147 |
148 | def conv3x3(
149 | self,
150 | in_planes,
151 | out_planes,
152 | stride=1,
153 | groups=1,
154 | first_layer=False,
155 | last_layer=False,
156 | is_conv=False,
157 | ):
158 | """3x3 convolution with padding"""
159 | c = self.conv(
160 | 3,
161 | in_planes,
162 | out_planes,
163 | stride=stride,
164 | groups=groups,
165 | first_layer=first_layer,
166 | last_layer=last_layer,
167 | is_conv=is_conv,
168 | )
169 | return c
170 |
171 | def conv5x5(
172 | self,
173 | in_planes,
174 | out_planes,
175 | stride=1,
176 | groups=1,
177 | first_layer=False,
178 | last_layer=False,
179 | is_conv=False,
180 | ):
181 | """5x5 convolution with padding"""
182 | c = self.conv(
183 | 5,
184 | in_planes,
185 | out_planes,
186 | stride=stride,
187 | groups=groups,
188 | first_layer=first_layer,
189 | last_layer=last_layer,
190 | is_conv=is_conv,
191 | )
192 | return c
193 |
194 | def conv7x7(
195 | self,
196 | in_planes,
197 | out_planes,
198 | stride=1,
199 | groups=1,
200 | first_layer=False,
201 | last_layer=False,
202 | is_conv=False,
203 | ):
204 | """7x7 convolution with padding"""
205 | c = self.conv(
206 | 7,
207 | in_planes,
208 | out_planes,
209 | stride=stride,
210 | groups=groups,
211 | first_layer=first_layer,
212 | last_layer=last_layer,
213 | is_conv=is_conv,
214 | )
215 | return c
216 |
217 | def batchnorm(self, planes):
218 | return self.bn_layer(planes, **self.norm_kwargs)
219 |
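The `conv` method above pairs each supported odd kernel size with a fixed padding (1 with 0, 3 with 1, 5 with 2, 7 with 3). A minimal sketch of why those values keep the spatial size unchanged at stride 1 follows. It uses the standard convolution output-size formula with dilation 1; the helper names `conv_output_size` and `same_padding` are hypothetical and are not part of this repository.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution with dilation 1 (standard formula)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def same_padding(kernel_size):
    """Padding that preserves the spatial size at stride 1 for an odd kernel.

    Hypothetical helper: Builder.conv hard-codes these values per branch.
    """
    if kernel_size % 2 == 0:
        raise ValueError("only odd kernel sizes are handled in this sketch")
    return kernel_size // 2


# The four kernel sizes that Builder.conv accepts, applied to a 32x32 input.
sizes = {
    k: conv_output_size(32, k, stride=1, padding=same_padding(k))
    for k in (1, 3, 5, 7)
}
print(sizes)
```

With stride 2 and a 3x3 kernel, the same formula halves a 32-pixel side to 16, which matches the `# 16x16` comment in the `cpreresnet` forward pass.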
--------------------------------------------------------------------------------
/models/init.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import torch.nn as nn
6 |
7 |
8 | def kaiming_normal(weight):
9 | nn.init.kaiming_normal_(
10 | weight,
11 | )
12 |
--------------------------------------------------------------------------------
/models/modules.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | from typing import Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | # Convolutions
12 | StandardConv = nn.Conv2d
13 |
14 |
15 | class SubspaceConv(nn.Conv2d):
16 | def forward(self, x):
17 | # call get_weight, which samples from the subspace, then use the
18 | # corresponding weight.
19 | w = self.get_weight()
20 | x = F.conv2d(
21 | x,
22 | w,
23 | self.bias,
24 | self.stride,
25 | self.padding,
26 | self.dilation,
27 | self.groups,
28 | )
29 | return x
30 |
31 |
32 | class TwoParamConv(SubspaceConv):
33 | def __init__(self, *args, **kwargs):
34 | super().__init__(*args, **kwargs)
35 | self.weight1 = nn.Parameter(torch.zeros_like(self.weight))
36 |
37 | def initialize(self, initialize_fn):
38 | initialize_fn(self.weight)
39 | initialize_fn(self.weight1)
40 |
41 |
42 | class LinesConv(TwoParamConv):
43 | def get_weight(self):
44 | w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
45 | return w
46 |
47 |
48 | # BatchNorms
49 | StandardBN = nn.BatchNorm2d
50 |
51 |
52 | class SubspaceBN(nn.BatchNorm2d):
53 | def forward(self, input):
54 | # call get_weight, which samples from the subspace, then use the
55 | # corresponding weight.
56 | w, b = self.get_weight()
57 |
58 | # The rest is code in the PyTorch source forward pass for batchnorm.
59 | if self.momentum is None:
60 | exponential_average_factor = 0.0
61 | else:
62 | exponential_average_factor = self.momentum
63 |
64 | if self.training and self.track_running_stats:
65 | if self.num_batches_tracked is not None:
66 | self.num_batches_tracked = self.num_batches_tracked + 1
67 | if self.momentum is None: # use cumulative moving average
68 | exponential_average_factor = 1.0 / float(
69 | self.num_batches_tracked
70 | )
71 | else: # use exponential moving average
72 | exponential_average_factor = self.momentum
73 | if self.training:
74 | bn_training = True
75 | else:
76 | bn_training = (self.running_mean is None) and (
77 | self.running_var is None
78 | )
79 | return F.batch_norm(
80 | input,
81 | # If buffers are not to be tracked, ensure that they won't be
82 | # updated
83 | self.running_mean
84 | if not self.training or self.track_running_stats
85 | else None,
86 | self.running_var
87 | if not self.training or self.track_running_stats
88 | else None,
89 | w,
90 | b,
91 | bn_training,
92 | exponential_average_factor,
93 | self.eps,
94 | )
95 |
96 |
97 | class TwoParamBN(SubspaceBN):
98 | def __init__(self, *args, **kwargs):
99 | super().__init__(*args, **kwargs)
100 | self.weight1 = nn.Parameter(torch.empty([self.num_features]))
101 | self.bias1 = nn.Parameter(torch.empty([self.num_features]))
102 | torch.nn.init.ones_(self.weight1)
103 | torch.nn.init.zeros_(self.bias1)
104 |
105 |
106 | class LinesBN(TwoParamBN):
107 | def get_weight(self):
108 | w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
109 | b = (1 - self.alpha) * self.bias + self.alpha * self.bias1
110 | return w, b
111 |
112 |
113 | # InstanceNorm
114 | def StandardIN(*args, affine=True, **kwargs):
115 | return nn.InstanceNorm2d(*args, affine=affine, **kwargs)
116 |
117 |
118 | def _process_num_groups(num_groups: Union[str, int], num_channels: int) -> int:
119 | if num_groups == "full":
120 | num_groups = num_channels # Set it equal to num_features.
121 | else:
122 | num_groups = int(num_groups)
123 |
124 | # If num_groups is greater than num_features, we reduce it.
125 | num_groups = min(num_channels, num_groups)
126 | return num_groups
127 |
128 |
129 | class SubspaceIN(nn.InstanceNorm2d):
130 | def __init__(
131 | self,
132 | num_features: int,
133 | eps: float = 1e-5,
134 | momentum: float = 0.1,
135 | affine: bool = False,
136 | track_running_stats: bool = False,
137 | ) -> None:
138 | # Override @affine to be true.
139 | super().__init__(
140 | num_features,
141 | eps=eps,
142 | momentum=momentum,
143 | affine=True,
144 | track_running_stats=track_running_stats,
145 | )
146 |
147 | def forward(self, input):
148 | # call get_weight, which samples from the subspace, then use the
149 | # corresponding weight.
150 | w, b = self.get_weight()
151 |
152 | # The rest is code in the PyTorch source forward pass for instancenorm.
153 | assert self.running_mean is None or isinstance(
154 | self.running_mean, torch.Tensor
155 | )
156 | assert self.running_var is None or isinstance(
157 | self.running_var, torch.Tensor
158 | )
159 | return F.instance_norm(
160 | input,
161 | self.running_mean,
162 | self.running_var,
163 | w,
164 | b,
165 | self.training or not self.track_running_stats,
166 | self.momentum,
167 | self.eps,
168 | )
169 |
170 |
171 | class TwoParamIN(SubspaceIN):
172 | def __init__(self, *args, **kwargs):
173 | super().__init__(*args, **kwargs)
174 | self.weight1 = nn.Parameter(torch.empty([self.num_features]))
175 | self.bias1 = nn.Parameter(torch.empty([self.num_features]))
176 | torch.nn.init.ones_(self.weight1)
177 | torch.nn.init.zeros_(self.bias1)
178 |
179 |
180 | class LinesIN(TwoParamIN):
181 | def get_weight(self):
182 | w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
183 | b = (1 - self.alpha) * self.bias + self.alpha * self.bias1
184 | return w, b
185 |
186 |
187 | # GroupNorm
188 | def StandardGN(*args, affine=True, **kwargs):
189 | num_groups = kwargs.pop("num_groups", "full")
190 | num_groups = _process_num_groups(num_groups, args[0])
191 | return nn.GroupNorm(num_groups, *args, affine=affine, **kwargs)
192 |
193 |
194 | class SubspaceGN(nn.GroupNorm):
195 | def __init__(
196 | self,
197 | num_features: int,
198 | eps: float = 1e-5,
199 | *,
200 | num_groups: Union[str, int],
201 | ) -> None:
202 |
203 | num_groups = _process_num_groups(num_groups, num_features)
204 |
205 | # Override @affine to be true.
206 | super().__init__(
207 | num_groups,
208 | num_features,
209 | eps=eps,
210 | affine=True,
211 | )
212 | self.num_features = num_features
213 |
214 | def forward(self, input):
215 | # call get_weight, which samples from the subspace, then use the
216 | # corresponding weight.
217 | w, b = self.get_weight()
218 | return F.group_norm(input, self.num_groups, w, b, self.eps)
219 |
220 |
221 | class TwoParamGN(SubspaceGN):
222 | def __init__(self, *args, **kwargs):
223 | super().__init__(*args, **kwargs)
224 | self.weight1 = nn.Parameter(torch.empty([self.num_features]))
225 | self.bias1 = nn.Parameter(torch.empty([self.num_features]))
226 | torch.nn.init.ones_(self.weight1)
227 | torch.nn.init.zeros_(self.bias1)
228 |
229 |
230 | class LinesGN(TwoParamGN):
231 | def get_weight(self):
232 | w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
233 | b = (1 - self.alpha) * self.bias + self.alpha * self.bias1
234 | return w, b
235 |
236 |
237 | def _get_num_parameters(conv):
238 | in_channels = conv.in_channels
239 | out_channels = conv.out_channels
240 |
241 | if hasattr(conv, "in_channels_list"):
242 | in_channels_ratio = in_channels / max(conv.in_channels_list)
243 | out_channels_ratio = out_channels / max(conv.out_channels_list)
244 | else:
245 | in_channels_ratio = in_channels / conv.in_channels_max
246 | out_channels_ratio = out_channels / conv.out_channels_max
247 |
248 | ret = conv.weight.numel()
249 | ret = max(1, round(ret * in_channels_ratio * out_channels_ratio))
250 | return ret
251 |
252 |
253 | # Adaptive modules (which adjust their number of channels at inference time).
254 |
255 | # This code contains the norm implementation used in our unstructured sparsity
256 | # experiments and baselines. Note that normally, we disallow storing or
257 | # recomputing BatchNorm statistics. However, we retain the ability to store
258 | # individual BatchNorm statistics purely for sanity-checking purposes (to ensure
259 | # our implementation produces similar results to the Universal Slimming paper,
260 | # when BatchNorms are stored). But, we don't use these results in any analysis.
261 | class AdaptiveNorm(nn.modules.batchnorm._NormBase):
262 | def __init__(
263 | self,
264 | bn_class,
265 | bn_func,
266 | mode,
267 | *args,
268 | ratio=1,
269 | width_factors_list=None,
270 | **kwargs,
271 | ):
272 | assert mode in ("BatchNorm", "InstanceNorm", "GroupNorm")
273 |
274 | kwargs_cpy = kwargs.copy()
275 | try:
276 | track_running_stats = kwargs_cpy.pop("track_running_stats")
277 | except KeyError:
278 | track_running_stats = False
279 |
280 | try:
281 | self.num_groups = kwargs_cpy.pop("num_groups")
282 | except KeyError:
283 | self.num_groups = None
284 |
285 | super().__init__(
286 | *args,
287 | affine=True,
288 | track_running_stats=track_running_stats,
289 | **kwargs_cpy,
290 | )
291 |
292 | num_features = args[0]
293 | self.width_factors_list = width_factors_list
294 | self.num_features_max = num_features
295 | if mode == "BatchNorm" and self.width_factors_list is not None:
296 | print(
297 | "Storing extra BatchNorm layers. This should only be used "
298 | "for sanity checking, since it violates our goal of "
299 | "arbitrarily fine-grained compression levels at inference "
300 | "time."
301 | )
302 | self.bn = nn.ModuleList(
303 | [
304 | bn_class(i, affine=False)
305 | for i in [
306 | max(1, round(self.num_features_max * width_factor))
307 | for width_factor in self.width_factors_list
308 | ]
309 | ]
310 | )
311 | if mode == "GroupNorm":
312 | if self.num_groups is None:
313 | raise ValueError("num_groups is required")
314 | if self.num_groups not in ("full", 1):
315 | # This must be "full" or 1, or the tensor might not be divisible
316 | # by @self.num_groups.
317 | raise ValueError(f"Invalid num_groups={self.num_groups}")
318 |
319 | self.ratio = ratio
320 | self.width_factor = None
321 | self.ignore_model_profiling = True
322 | self.bn_func = bn_func
323 | self.mode = mode
324 |
325 | def get_weight(self):
326 | return self.weight, self.bias
327 |
328 | def forward(self, input):
329 | weight, bias = self.get_weight()
330 | c = input.shape[1]
331 | if (
332 | self.mode == "BatchNorm"
333 | and self.width_factors_list is not None
334 | and self.width_factor in self.width_factors_list
335 | ):
336 | # Normally, we expect width_factors_list to be None, because we
337 | # only want to use it if we are running sanity checks (e.g.
338 | # reproducing the original performance).
339 | idx = self.width_factors_list.index(self.width_factor)
340 | kwargs = {
341 | "input": input,
342 | "running_mean": self.bn[idx].running_mean[:c],
343 | "running_var": self.bn[idx].running_var[:c],
344 | "weight": weight[:c],
345 | "bias": bias[:c],
346 | "training": self.training,
347 | "momentum": self.momentum,
348 | "eps": self.eps,
349 | }
350 | elif self.mode in ("InstanceNorm", "BatchNorm"):
351 | # Sanity check, since we're not tracking running stats.
352 | running_mean = self.running_mean
353 | if self.running_mean is not None:
354 | running_mean = running_mean[:c]
355 |
356 | running_var = self.running_var
357 | if self.running_var is not None:
358 | running_var = running_var[:c]
359 |
360 | kwargs = {
361 | "input": input,
362 | "running_mean": running_mean,
363 | "running_var": running_var,
364 | "weight": weight[:c],
365 | "bias": bias[:c],
366 | "momentum": self.momentum,
367 | "eps": self.eps,
368 | }
369 |
370 | if self.mode == "BatchNorm":
371 | kwargs["training"] = self.training
372 |
373 | elif self.mode == "GroupNorm":
374 | num_groups = self.num_groups
375 | if num_groups == "full":
376 | num_groups = c
377 | kwargs = {
378 | "input": input,
379 | "num_groups": num_groups,
380 | "weight": weight[:c],
381 | "bias": bias[:c],
382 | "eps": self.eps,
383 | }
384 | else:
385 | raise NotImplementedError(f"Invalid mode {self.mode}.")
386 |
387 | return self.bn_func(**kwargs)
388 |
389 |
390 | class AdaptiveBN(AdaptiveNorm):
391 | def __init__(self, *args, **kwargs):
392 | norm_class = nn.BatchNorm2d
393 | norm_func = F.batch_norm
394 | super().__init__(norm_class, norm_func, "BatchNorm", *args, **kwargs)
395 |
396 |
397 | class AdaptiveIN(AdaptiveNorm):
398 | def __init__(self, *args, **kwargs):
399 | norm_class = nn.InstanceNorm2d
400 | norm_func = F.instance_norm
401 | super().__init__(norm_class, norm_func, "InstanceNorm", *args, **kwargs)
402 |
403 |
404 | class LinesAdaptiveIN(AdaptiveIN):
405 | def __init__(self, *args, **kwargs):
406 | super().__init__(*args, **kwargs)
407 | self.weight1 = nn.Parameter(torch.Tensor(self.num_features))
408 | self.bias1 = nn.Parameter(torch.Tensor(self.num_features))
409 | torch.nn.init.ones_(self.weight1)
410 | torch.nn.init.zeros_(self.bias1)
411 |
412 | def get_weight(self):
413 | w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
414 | b = (1 - self.alpha) * self.bias + self.alpha * self.bias1
415 | return w, b
416 |
417 |
418 | class AdaptiveConv2d(nn.Conv2d):
419 | def __init__(
420 | self,
421 | in_channels,
422 | out_channels,
423 | kernel_size,
424 | stride=1,
425 | padding=0,
426 | dilation=1,
427 | groups=1,
428 | bias=False,
429 | first_layer=False,
430 | last_layer=False,
431 | ratio=None,
432 | ):
433 | self.first_layer = first_layer
434 | self.last_layer = last_layer
435 |
436 | if ratio is None:
437 | ratio = [1, 1]
438 |
439 | super(AdaptiveConv2d, self).__init__(
440 | in_channels,
441 | out_channels,
442 | kernel_size,
443 | stride=stride,
444 | padding=padding,
445 | dilation=dilation,
446 | groups=groups,
447 | bias=bias,
448 | )
449 |
450 | if groups == in_channels:
451 | assert in_channels == out_channels
452 | self.depthwise = True
453 | else:
454 | self.depthwise = False
455 |
456 | self.in_channels_max = in_channels
457 | self.out_channels_max = out_channels
458 | self.width_factor = None
459 | self.ratio = ratio
460 |
461 | def get_weight(self):
462 | return self.weight
463 |
464 | def forward(self, input):
465 | if not self.first_layer:
466 | self.in_channels = input.shape[1]
467 | if not self.last_layer:
468 | self.out_channels = max(
469 | 1, round(self.out_channels_max * self.width_factor)
470 | )
471 | self.groups = self.in_channels if self.depthwise else 1
472 | weight = self.get_weight()
473 | weight = weight[: self.out_channels, : self.in_channels, :, :]
474 | assert self.bias is None
475 | bias = None
476 | y = nn.functional.conv2d(
477 | input,
478 | weight,
479 | bias,
480 | self.stride,
481 | self.padding,
482 | self.dilation,
483 | self.groups,
484 | )
485 | return y
486 |
487 | def get_num_parameters(self):
488 | return _get_num_parameters(self)
489 |
490 |
491 | class LinesAdaptiveConv2d(AdaptiveConv2d):
492 | def __init__(self, *args, **kwargs):
493 | super().__init__(*args, **kwargs)
494 | self.weight1 = nn.Parameter(torch.empty_like(self.weight))
495 | assert self.bias is None
496 | torch.nn.init.ones_(self.weight1)
497 |
498 | def get_weight(self):
499 | w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
500 | return w
501 |
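The `Lines*` classes above all compute their effective parameters as a convex combination of two endpoints, `(1 - alpha) * weight + alpha * weight1`, and `AdaptiveConv2d.forward` derives its active output channel count as `max(1, round(out_channels_max * width_factor))`. The sketch below reproduces both calculations on plain Python lists and integers so the arithmetic is easy to check. The function names `interpolate` and `active_channels` are hypothetical. It also assumes that `alpha` and `width_factor` are set on each module by the training code, which is not shown in this file.

```python
def interpolate(w0, w1, alpha):
    """Point on the line between two parameter vectors.

    alpha = 0 returns w0 (the `weight` endpoint) and alpha = 1 returns w1
    (the `weight1` endpoint), as in LinesConv.get_weight.
    """
    if len(w0) != len(w1):
        raise ValueError("endpoints must have the same length")
    return [(1 - alpha) * a + alpha * b for a, b in zip(w0, w1)]


def active_channels(max_channels, width_factor):
    """Number of channels kept at a given width factor, never fewer than one."""
    return max(1, round(max_channels * width_factor))


w0 = [0.0, 2.0]
w1 = [4.0, 6.0]
midpoint = interpolate(w0, w1, 0.5)
print(midpoint)
print(active_channels(64, 0.25))
```

Because the interpolation is applied inside `get_weight`, gradients reach both endpoints whenever `alpha` lies strictly between 0 and 1.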
--------------------------------------------------------------------------------
/models/networks/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | from __future__ import absolute_import
6 |
7 | from .channel_selection import * # pylint: disable=unused-import
8 | # CIFAR-specific
9 | from .cpreresnet import * # pylint: disable=unused-import
10 | from .resnet import * # pylint: disable=unused-import
11 | from .vgg import * # pylint: disable=unused-import
12 |
--------------------------------------------------------------------------------
/models/networks/channel_selection.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | # Only used for Learning Efficient Convolutions (LEC) experiments.
6 | # Deactivated in other cases.
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 |
12 |
13 | class channel_selection(nn.Module):
14 | """
15 | Select channels from the output of a BatchNorm2d layer. It should be placed
16 | directly after the BatchNorm2d layer.
17 |
18 | The output shape of this layer is determined by the number of 1s in
19 | `self.indexes`.
20 | """
21 |
22 | def __init__(self, num_channels, active=True):
23 | """
24 | Initialize `indexes` as an all-ones vector whose length equals the number of channels.
25 | During pruning, the entries of `indexes` that correspond to the channels to be pruned are set to 0.
26 | """
27 | super(channel_selection, self).__init__()
28 | self.indexes = nn.Parameter(torch.ones(num_channels))
29 | self.active = active
30 |
31 | def forward(self, input_tensor):
32 | """
33 | Parameter
34 | ---------
35 | input_tensor: (N,C,H,W). It should be the output of BatchNorm2d layer.
36 | """
37 | if not self.active:
38 | return input_tensor
39 | selected_index = np.squeeze(
40 | np.argwhere(self.indexes.data.cpu().numpy())
41 | )
42 | if selected_index.size == 1:
43 | selected_index = np.resize(selected_index, (1,))
44 | output = input_tensor[:, selected_index, :, :]
45 | return output
46 |
47 | def __repr__(self):
48 | return f"{self.__class__.__name__}(active={self.active})"
49 |
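The `forward` method above keeps only the channels whose entry in `indexes` is non-zero, and it special-cases a single surviving channel so that the index stays one-dimensional. A minimal sketch of that selection on plain Python lists follows. The function names `selected_channels` and `select` are hypothetical, and the list-based data stands in for the NumPy and tensor indexing used in the class.

```python
def selected_channels(mask):
    """Indices of channels to keep: positions where the mask is non-zero.

    Always returns a list, so the single-channel case needs no special
    handling (the class above uses np.resize for that purpose).
    """
    return [i for i, value in enumerate(mask) if value != 0]


def select(channels, mask):
    """Keep only the channels whose mask entry is non-zero."""
    if len(channels) != len(mask):
        raise ValueError("mask length must equal the number of channels")
    return [channels[i] for i in selected_channels(mask)]


mask = [1.0, 0.0, 1.0, 1.0]
kept = select(["c0", "c1", "c2", "c3"], mask)
print(selected_channels(mask), kept)
```

The output width therefore equals the number of non-zero mask entries, which is why the layer that follows must be built with a matching input channel count.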
--------------------------------------------------------------------------------
/models/networks/cpreresnet.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | from __future__ import absolute_import
6 |
7 | import math
8 |
9 | import torch.nn as nn
10 |
11 | from .channel_selection import channel_selection
12 |
13 | __all__ = ["cpreresnet"]
14 |
15 | """
16 | Preactivation resnet with bottleneck design.
17 | """
18 |
19 |
20 | class Bottleneck(nn.Module):
21 | expansion = 4
22 |
23 | def __init__(
24 | self,
25 | inplanes,
26 | planes,
27 | cfg,
28 | stride=1,
29 | downsample=None,
30 | builder=None,
31 | channel_selection_active=True,
32 | ):
33 | super(Bottleneck, self).__init__()
34 | if builder is None:
35 | raise ValueError(f"Builder required, got None")
36 |
37 | self.bn1 = builder.batchnorm(inplanes)
38 | self.select = channel_selection(
39 | inplanes, active=channel_selection_active
40 | )
41 | self.conv1 = builder.conv1x1(cfg[0], cfg[1])
42 | self.bn2 = builder.batchnorm(cfg[1])
43 | self.conv2 = builder.conv3x3(cfg[1], cfg[2], stride=stride)
44 | self.bn3 = builder.batchnorm(cfg[2])
45 | self.conv3 = builder.conv1x1(cfg[2], planes * 4)
46 | self.relu = nn.ReLU(inplace=True)
47 | self.downsample = downsample
48 | self.stride = stride
49 |
50 | def forward(self, x):
51 | residual = x
52 |
53 | out = self.bn1(x)
54 | out = self.select(out)
55 | out = self.relu(out)
56 | out = self.conv1(out)
57 |
58 | out = self.bn2(out)
59 | out = self.relu(out)
60 | out = self.conv2(out)
61 |
62 | out = self.bn3(out)
63 | out = self.relu(out)
64 | out = self.conv3(out)
65 |
66 | if self.downsample is not None:
67 | residual = self.downsample(x)
68 |
69 | out += residual
70 |
71 | return out
72 |
73 |
74 | class cpreresnet(nn.Module):
75 | def __init__(
76 | self,
77 | depth=20,
78 | dataset="cifar10",
79 | cfg=None,
80 | builder=None,
81 | block_builder=None,
82 | channel_selection_active=True,
83 | ):
84 | super(cpreresnet, self).__init__()
85 | # assert block_builder is None, "Should not provide a block_builder."
86 | if builder is None:
87 | raise ValueError(f"Expected builder, got None.")
88 | assert (depth - 2) % 9 == 0, "depth should be 9n+2"
89 |
90 | n = (depth - 2) // 9
91 | block = Bottleneck
92 |
93 | if cfg is None:
94 | # Construct config variable.
95 | cfg = [
96 | [16, 16, 16],
97 | [64, 16, 16] * (n - 1),
98 | [64, 32, 32],
99 | [128, 32, 32] * (n - 1),
100 | [128, 64, 64],
101 | [256, 64, 64] * (n - 1),
102 | [256],
103 | ]
104 | cfg = [item for sub_list in cfg for item in sub_list]
105 |
106 | self.inplanes = 16
107 |
108 | self.conv1 = builder.conv3x3(3, 16, first_layer=True)
109 | self.layer1 = self._make_layer(
110 | block,
111 | 16,
112 | n,
113 | cfg=cfg[0 : 3 * n],
114 | builder=block_builder,
115 | channel_selection_active=channel_selection_active,
116 | )
117 | self.layer2 = self._make_layer(
118 | block,
119 | 32,
120 | n,
121 | cfg=cfg[3 * n : 6 * n],
122 | stride=2,
123 | builder=block_builder,
124 | channel_selection_active=channel_selection_active,
125 | )
126 | self.layer3 = self._make_layer(
127 | block,
128 | 64,
129 | n,
130 | cfg=cfg[6 * n : 9 * n],
131 | stride=2,
132 | builder=block_builder,
133 | channel_selection_active=channel_selection_active,
134 | )
135 | self.bn = builder.batchnorm(64 * block.expansion)
136 | self.select = channel_selection(
137 | 64 * block.expansion, active=channel_selection_active
138 | )
139 | self.relu = nn.ReLU(inplace=True)
140 | self.avgpool = nn.AvgPool2d(8)
141 |
142 | # We work only with CIFAR-10
143 | num_categories = 10
144 |
145 | self.fc = builder.conv1x1(cfg[-1], num_categories, last_layer=True)
146 |
147 | # Weight initialization
148 | for m in self.modules():
149 | if isinstance(m, nn.Conv2d) and m != self.fc:
150 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
151 | i = 1
152 | while hasattr(m, f"weight{i}"):
153 | weight = getattr(m, f"weight{i}")
154 | weight.data.normal_(0, math.sqrt(2.0 / n))
155 | i += 1
156 | elif isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
157 | i = 1
158 | while hasattr(m, f"weight{i}"):
159 | weight = getattr(m, f"weight{i}")
160 | weight.data.fill_(0.5)
161 | bias = getattr(m, f"bias{i}")
162 | bias.data.zero_()
163 | i += 1
164 |
165 | def _make_layer(
166 | self,
167 | block,
168 | planes,
169 | blocks,
170 | cfg,
171 | stride=1,
172 | builder=None,
173 | channel_selection_active=True,
174 | ):
175 | if builder is None:
176 | raise ValueError(f"Expected builder, got None.")
177 | downsample = None
178 | if stride != 1 or self.inplanes != planes * block.expansion:
179 | downsample = nn.Sequential(
180 | builder.conv1x1(
181 | self.inplanes, planes * block.expansion, stride=stride
182 | ),
183 | )
184 |
185 | layers = []
186 | layers.append(
187 | block(
188 | self.inplanes,
189 | planes,
190 | cfg[0:3],
191 | stride,
192 | downsample,
193 | builder=builder,
194 | channel_selection_active=channel_selection_active,
195 | )
196 | )
197 | self.inplanes = planes * block.expansion
198 | for i in range(1, blocks):
199 | layers.append(
200 | block(
201 | self.inplanes,
202 | planes,
203 | cfg[3 * i : 3 * (i + 1)],
204 | builder=builder,
205 | channel_selection_active=channel_selection_active,
206 | )
207 | )
208 |
209 | return nn.Sequential(*layers)
210 |
211 | def forward(self, x):
212 | x = self.conv1(x)
213 |
214 | x = self.layer1(x) # 32x32
215 | x = self.layer2(x) # 16x16
216 | x = self.layer3(x) # 8x8
217 | x = self.bn(x)
218 | x = self.select(x)
219 | x = self.relu(x)
220 |
221 | x = self.avgpool(x)
222 | assert x.shape[2:] == (1, 1)
223 | x = self.fc(x)
224 | x = x.view(x.size(0), -1)
225 |
226 | return x
227 |
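When `cfg` is `None`, the `cpreresnet` constructor above builds a flat list of channel counts from `depth`, then hands each stage a slice of `3 * n` entries. The sketch below reproduces only that list construction so the resulting layout can be inspected without PyTorch. The function name `build_cfg` is hypothetical, and it raises `ValueError` where the original uses an `assert`.

```python
def build_cfg(depth):
    """Flat channel configuration for a bottleneck pre-activation ResNet.

    Mirrors the default built in cpreresnet.__init__: three entries per
    bottleneck block, plus one final entry for the classifier input.
    """
    if (depth - 2) % 9 != 0:
        raise ValueError("depth should be 9n+2")
    n = (depth - 2) // 9
    nested = [
        [16, 16, 16],
        [64, 16, 16] * (n - 1),
        [64, 32, 32],
        [128, 32, 32] * (n - 1),
        [128, 64, 64],
        [256, 64, 64] * (n - 1),
        [256],
    ]
    return [item for sub_list in nested for item in sub_list]


cfg = build_cfg(20)  # depth 20 gives n = 2 blocks per stage
n = (20 - 2) // 9
stage1 = cfg[0 : 3 * n]
print(len(cfg), stage1, cfg[-1])
```

The list has `9 * n + 1` entries in total. The last one, `cfg[-1]`, is the input width passed to the final 1x1 convolution `self.fc`.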
--------------------------------------------------------------------------------
/models/networks/model_profiling.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import time
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 |
11 | model_profiling_hooks = []
12 | model_profiling_speed_hooks = []
13 |
14 | name_space = 95
15 | params_space = 15
16 | macs_space = 15
17 | seconds_space = 15
18 |
19 | num_forwards = 10
20 |
21 |
22 | class Timer(object):
23 | def __init__(self, verbose=False):
24 | self.verbose = verbose
25 | self.start = None
26 | self.end = None
27 |
28 | def __enter__(self):
29 | self.start = time.time()
30 | return self
31 |
32 | def __exit__(self, *args):
33 | self.end = time.time()
34 | self.time = self.end - self.start
35 | if self.verbose:
36 | print("Elapsed time: %f s." % self.time)
37 |
38 |
39 | def get_params(module):
40 | """get number of params in module"""
41 | if hasattr(module, "width_factor"):
42 | # We assume it's a conv layer.
43 | ret = module.get_num_parameters()
44 | else:
45 | ret = np.sum([np.prod(list(w.size())) for w in module.parameters()])
46 | return ret
47 |
48 |
49 | def run_forward(module, input):
50 | with Timer() as t:
51 | for _ in range(num_forwards):
52 | module.forward(*input)
53 | if input[0].is_cuda:
54 | torch.cuda.synchronize()
55 | return int(t.time * 1e9 / num_forwards)
56 |
57 |
58 | def conv_module_name_filter(name):
59 | """filter module name to have a short view"""
60 | filters = {
61 | "kernel_size": "k",
62 | "stride": "s",
63 | "padding": "pad",
64 | "bias": "b",
65 | "groups": "g",
66 | }
67 | for k in filters:
68 | name = name.replace(k, filters[k])
69 | return name
70 |
71 |
72 | def module_profiling(module, input, output, verbose):
73 | if not isinstance(input[0], list):
74 | # Some modules return a list of outputs. We usually ignore them.
75 | ins = input[0].size()
76 | outs = output.size()
77 | # NOTE: type() and isinstance() differ (type() ignores subclassing), so be
78 | # careful which one is used below.
79 | t = type(module)
80 | if isinstance(module, nn.Conv2d):
81 | module.n_macs = (
82 | ins[1]
83 | * outs[1]
84 | * module.kernel_size[0]
85 | * module.kernel_size[1]
86 | * outs[2]
87 | * outs[3]
88 | // module.groups
89 | ) * outs[0]
90 | module.n_params = get_params(module)
91 | module.n_seconds = run_forward(module, input)
92 | module.name = conv_module_name_filter(module.__repr__())
93 | elif isinstance(module, nn.ConvTranspose2d):
94 | module.n_macs = (
95 | ins[1]
96 | * outs[1]
97 | * module.kernel_size[0]
98 | * module.kernel_size[1]
99 | * outs[2]
100 | * outs[3]
101 | // module.groups
102 | ) * outs[0]
103 | module.n_params = get_params(module)
104 | module.n_seconds = run_forward(module, input)
105 | module.name = conv_module_name_filter(module.__repr__())
106 | elif isinstance(module, nn.Linear):
107 | module.n_macs = ins[1] * outs[1] * outs[0]
108 | module.n_params = get_params(module)
109 | module.n_seconds = run_forward(module, input)
110 | module.name = module.__repr__()
111 | elif isinstance(module, nn.AvgPool2d):
112 | # NOTE: this function is correct only when stride == kernel size
113 | module.n_macs = ins[1] * ins[2] * ins[3] * ins[0]
114 | module.n_params = 0
115 | module.n_seconds = run_forward(module, input)
116 | module.name = module.__repr__()
117 | elif isinstance(module, nn.AdaptiveAvgPool2d):
118 | # NOTE: this function is correct only when stride == kernel size
119 | module.n_macs = ins[1] * ins[2] * ins[3] * ins[0]
120 | module.n_params = 0
121 | module.n_seconds = run_forward(module, input)
122 | module.name = module.__repr__()
123 | else:
124 | # This works only with a depth-first traversal of modules.
125 | module.n_macs = 0
126 | module.n_params = 0
127 | module.n_seconds = 0
128 | num_children = 0
129 | for m in module.children():
130 | module.n_macs += getattr(m, "n_macs", 0)
131 | module.n_params += getattr(m, "n_params", 0)
132 | module.n_seconds += getattr(m, "n_seconds", 0)
133 | num_children += 1
134 | ignore_zeros_t = [
135 | nn.BatchNorm2d,
136 | nn.InstanceNorm2d,
137 | nn.Dropout2d,
138 | nn.Dropout,
139 | nn.Sequential,
140 | nn.ReLU6,
141 | nn.ReLU,
142 | nn.MaxPool2d,
143 | nn.modules.padding.ZeroPad2d,
144 | nn.modules.activation.Sigmoid,
145 | ]
146 | if (
147 | not getattr(module, "ignore_model_profiling", False)
148 | and module.n_macs == 0
149 | and t not in ignore_zeros_t
150 | ):
151 | print(
152 | "WARNING: leaf module {} has zero n_macs.".format(type(module))
153 | )
154 | return
155 | if verbose:
156 | print(
157 | module.name.ljust(name_space, " ")
158 | + "{:,}".format(module.n_params).rjust(params_space, " ")
159 | + "{:,}".format(module.n_macs).rjust(macs_space, " ")
160 | + "{:,}".format(module.n_seconds).rjust(seconds_space, " ")
161 | )
162 | return
163 |
164 |
165 | def add_profiling_hooks(m, verbose):
166 | global model_profiling_hooks
167 | model_profiling_hooks.append(
168 | m.register_forward_hook(
169 | lambda m, input, output: module_profiling(
170 | m, input, output, verbose=verbose
171 | )
172 | )
173 | )
174 |
175 |
176 | def remove_profiling_hooks():
177 | global model_profiling_hooks
178 | for h in model_profiling_hooks:
179 | h.remove()
180 | model_profiling_hooks = []
181 |
182 |
183 | def model_profiling(
184 | model, height, width, batch=1, channel=3, use_cuda=True, verbose=True
185 | ):
186 | """Pytorch model profiling with input image size
187 | (batch, channel, height, width).
188 | The function examines the number of multiply-accumulates (n_macs).
189 |
190 | Args:
191 | model: pytorch model
192 | height: int
193 | width: int
194 | batch: int
195 | channel: int
196 | use_cuda: bool
197 |
198 | Returns:
199 | macs: int
200 | params: int
201 |
202 | """
203 | if isinstance(model, nn.DataParallel):
204 | model = model.module
205 | model.eval()
206 | data = torch.rand(batch, channel, height, width)
207 | device = torch.device("cuda:0" if use_cuda else "cpu")
208 | model = model.to(device)
209 | data = data.to(device)
210 | model.apply(lambda m: add_profiling_hooks(m, verbose=verbose))
211 | print(
212 | "Item".ljust(name_space, " ")
213 | + "params".rjust(params_space, " ")
214 | + "macs".rjust(macs_space, " ")
215 | + "nanosecs".rjust(seconds_space, " ")
216 | )
217 | if verbose:
218 | print(
219 | "".center(
220 | name_space + params_space + macs_space + seconds_space, "-"
221 | )
222 | )
223 | model(data)
224 | if verbose:
225 | print(
226 | "".center(
227 | name_space + params_space + macs_space + seconds_space, "-"
228 | )
229 | )
230 | print(
231 | "Total".ljust(name_space, " ")
232 | + "{:,}".format(model.n_params).rjust(params_space, " ")
233 | + "{:,}".format(model.n_macs).rjust(macs_space, " ")
234 | + "{:,}".format(model.n_seconds).rjust(seconds_space, " ")
235 | )
236 | remove_profiling_hooks()
237 | return model.n_macs, model.n_params
238 |
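As a rough cross-check of the per-layer counts printed by `model_profiling`, the convolution and linear MAC formulas used in `module_profiling` can be reproduced in plain Python. This is a minimal sketch, not part of the repository: the helper names `conv2d_macs` and `linear_macs` are hypothetical, and shapes are passed in by hand instead of being read from forward hooks. The example assumes a 224x224 input whose 7x7 stem convolution (stride 2) produces a 112x112 output.

```python
def conv2d_macs(batch, in_channels, out_channels, kernel_hw, out_hw, groups=1):
    """MACs of a 2D convolution, following the formula in module_profiling."""
    k_h, k_w = kernel_hw
    out_h, out_w = out_hw
    # Per-sample cost: every output position of every output channel needs
    # (in_channels / groups) * k_h * k_w multiply-accumulates.
    per_sample = (
        in_channels * out_channels * k_h * k_w * out_h * out_w // groups
    )
    return per_sample * batch


def linear_macs(batch, in_features, out_features):
    """MACs of a fully connected layer: one per weight, per sample."""
    return in_features * out_features * batch


# Hypothetical example: 7x7 stem conv, 3 -> 64 channels, 112x112 output.
stem = conv2d_macs(1, 3, 64, (7, 7), (112, 112))
print(f"{stem:,}")  # 118,013,952
```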
--------------------------------------------------------------------------------
/models/networks/resnet.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import torch.nn as nn
6 |
7 |
8 | # BasicBlock {{{
9 | class BasicBlock(nn.Module):
10 | M = 2
11 | expansion = 1
12 |
13 | def __init__(
14 | self,
15 | builder,
16 | inplanes,
17 | planes,
18 | stride=1,
19 | downsample=None,
20 | base_width=64,
21 | ):
22 | super(BasicBlock, self).__init__()
23 | if base_width / 64 > 1:
24 | raise ValueError("Base width >64 does not work for BasicBlock")
25 |
26 | self.conv1 = builder.conv3x3(inplanes, planes, stride)
27 | self.bn1 = builder.batchnorm(planes)
28 | self.relu = nn.ReLU(inplace=True)
29 | self.conv2 = builder.conv3x3(planes, planes)
30 | self.bn2 = builder.batchnorm(planes)
31 | self.downsample = downsample
32 | self.stride = stride
33 |
34 | def forward(self, x):
35 | residual = x
36 |
37 | out = self.conv1(x)
38 | if self.bn1 is not None:
39 | out = self.bn1(out)
40 |
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 |
45 | if self.bn2 is not None:
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | # BasicBlock }}}
58 |
59 | # Bottleneck {{{
60 | class Bottleneck(nn.Module):
61 | M = 3
62 | expansion = 4
63 |
64 | def __init__(
65 | self,
66 | builder,
67 | inplanes,
68 | planes,
69 | stride=1,
70 | downsample=None,
71 | base_width=64,
72 | ):
73 | super(Bottleneck, self).__init__()
74 | width = int(planes * base_width / 64)
75 | self.conv1 = builder.conv1x1(inplanes, width)
76 | self.bn1 = builder.batchnorm(width)
77 | self.conv2 = builder.conv3x3(width, width, stride=stride)
78 | self.bn2 = builder.batchnorm(width)
79 | self.conv3 = builder.conv1x1(width, planes * self.expansion)
80 | self.bn3 = builder.batchnorm(planes * self.expansion)
81 | self.relu = nn.ReLU(inplace=True)
82 | self.downsample = downsample
83 | self.stride = stride
84 |
85 | def forward(self, x):
86 | residual = x
87 |
88 | out = self.conv1(x)
89 | out = self.bn1(out)
90 | out = self.relu(out)
91 |
92 | out = self.conv2(out)
93 | out = self.bn2(out)
94 | out = self.relu(out)
95 |
96 | out = self.conv3(out)
97 | out = self.bn3(out)
98 |
99 | if self.downsample is not None:
100 | residual = self.downsample(x)
101 |
102 | out += residual
103 |
104 | out = self.relu(out)
105 |
106 | return out
107 |
108 |
109 | # Bottleneck }}}
110 |
111 | # ResNet {{{
112 | class ResNet(nn.Module):
113 | def __init__(
114 | self,
115 | builder,
116 | block_builder,
117 | block,
118 | layers,
119 | num_classes=1000,
120 | base_width=64,
121 | ):
122 | self.inplanes = 64
123 | super(ResNet, self).__init__()
124 |
125 | self.base_width = base_width
126 | if self.base_width // 64 > 1:
127 | print(f"==> Using {self.base_width // 64}x wide model")
128 |
129 | self.conv1 = builder.conv7x7(3, 64, stride=2, first_layer=True)
130 |
131 | self.bn1 = builder.batchnorm(64)
132 | self.relu = nn.ReLU(inplace=True)
133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
134 | self.layer1 = self._make_layer(block_builder, block, 64, layers[0])
135 | self.layer2 = self._make_layer(
136 | block_builder, block, 128, layers[1], stride=2
137 | )
138 | self.layer3 = self._make_layer(
139 | block_builder, block, 256, layers[2], stride=2
140 | )
141 | self.layer4 = self._make_layer(
142 | block_builder, block, 512, layers[3], stride=2
143 | )
144 | self.avgpool = nn.AdaptiveAvgPool2d(1)
145 | self.return_feats = False
146 |
147 | self.fc = builder.conv1x1(
148 | 512 * block.expansion, num_classes, last_layer=True
149 | )
150 |
151 | def _make_layer(self, builder, block, planes, blocks, stride=1):
152 | downsample = None
153 | if stride != 1 or self.inplanes != planes * block.expansion:
154 | dconv = builder.conv1x1(
155 | self.inplanes, planes * block.expansion, stride=stride
156 | )
157 | dbn = builder.batchnorm(planes * block.expansion)
158 | if dbn is not None:
159 | downsample = nn.Sequential(dconv, dbn)
160 | else:
161 | downsample = dconv
162 |
163 | layers = []
164 | layers.append(
165 | block(
166 | builder,
167 | self.inplanes,
168 | planes,
169 | stride,
170 | downsample,
171 | base_width=self.base_width,
172 | )
173 | )
174 | self.inplanes = planes * block.expansion
175 | for i in range(1, blocks):
176 | layers.append(
177 | block(
178 | builder, self.inplanes, planes, base_width=self.base_width
179 | )
180 | )
181 |
182 | return nn.Sequential(*layers)
183 |
184 | def forward(self, x):
185 | x = self.conv1(x)
186 |
187 | if self.bn1 is not None:
188 | x = self.bn1(x)
189 | x = self.relu(x)
190 | x = self.maxpool(x)
191 |
192 | x = self.layer1(x)
193 | x = self.layer2(x)
194 | x = self.layer3(x)
195 | x = self.layer4(x)
196 |
197 | feats = self.avgpool(x)
198 | x = self.fc(feats)
199 | x = x.view(x.size(0), -1)
200 |
201 | if self.return_feats:
202 | return x, feats.view(feats.size(0), -1)
203 | return x
204 |
205 |
206 | def _get_output_size(dataset: str):
207 | return {
208 | "cifar10": 10,
209 | "imagenet": 1000,
210 | }[dataset]
211 |
212 |
213 | # ResNet }}}
214 | def ResNet18(builder, block_builder, dataset):
215 | return ResNet(
216 | builder,
217 | block_builder,
218 | BasicBlock,
219 | [2, 2, 2, 2],
220 | _get_output_size(dataset),
221 | )
222 |
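The downsample rule in `ResNet._make_layer` (a projection branch is created whenever the stride is not 1 or the channel count changes) can be traced without building any modules. The following is a minimal sketch under that reading of the code; `stage_plan` is a hypothetical helper that only tracks `inplanes` for the first block of each stage.

```python
def stage_plan(expansion, widths=(64, 128, 256, 512)):
    """Return (inplanes, out_planes, stride, has_downsample) per stage.

    Only the first block of each stage is described, since later blocks
    keep the channel count and use stride 1.
    """
    inplanes = 64  # channels after the stem, as in ResNet.__init__
    plan = []
    for i, planes in enumerate(widths):
        stride = 1 if i == 0 else 2
        out_planes = planes * expansion
        # Same condition as in _make_layer.
        has_downsample = stride != 1 or inplanes != out_planes
        plan.append((inplanes, out_planes, stride, has_downsample))
        inplanes = out_planes
    return plan


basic = stage_plan(expansion=1)       # BasicBlock.expansion
bottleneck = stage_plan(expansion=4)  # Bottleneck.expansion
print(basic[0])       # (64, 64, 1, False)
print(bottleneck[0])  # (64, 256, 1, True)
```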
--------------------------------------------------------------------------------
/models/networks/resprune.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | # Note: this only works for cPreResNet20 (implementation provided by LEC paper).
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 |
11 | from models import networks
12 |
13 | from .cpreresnet import *
14 |
15 |
16 | def get_slimmed_network(model, model_kwargs, cfg, cfg_mask):
17 | assert not isinstance(model, nn.DataParallel), "Must unwrap DataParallel"
18 |
19 | is_cuda = next(model.parameters()).is_cuda
20 | print("Cfg:")
21 | print(cfg)
22 |
23 | newmodel = cpreresnet(cfg=cfg, **model_kwargs)
24 |
25 | if is_cuda:
26 | newmodel.cuda()
27 |
28 | old_modules = list(model.modules())
29 | new_modules = list(newmodel.modules())
30 | layer_id_in_cfg = 0
31 | start_mask = torch.ones(3)
32 | end_mask = cfg_mask[layer_id_in_cfg]
33 | conv_count = 0
34 |
35 | for layer_id in range(len(old_modules)):
36 | m0 = old_modules[layer_id]
37 | m1 = new_modules[layer_id]
38 |
39 | if isinstance(m0, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
40 | idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
41 | if idx1.size == 1:
42 | idx1 = np.resize(idx1, (1,))
43 |
44 | if isinstance(
45 | old_modules[layer_id + 1], networks.channel_selection
46 | ):
47 | # If the next layer is the channel selection layer, then the
48 | # current batchnorm 2d layer won't be pruned.
49 | m1.weight.data = m0.weight.data.clone()
50 | m1.bias.data = m0.bias.data.clone()
51 | if m0.running_mean is None:
52 | m1.running_mean = m0.running_mean
53 | m1.running_var = m0.running_var
54 | else:
55 | m1.running_mean.data = m0.running_mean.clone()
56 | m1.running_var.data = m0.running_var.clone()
57 |
58 | # We need to set the channel selection layer.
59 | m2 = new_modules[layer_id + 1]
60 | m2.indexes.data.zero_()
61 | m2.indexes.data[idx1.tolist()] = 1.0
62 |
63 | layer_id_in_cfg += 1
64 | start_mask = end_mask.clone()
65 | if layer_id_in_cfg < len(cfg_mask):
66 | end_mask = cfg_mask[layer_id_in_cfg]
67 | else:
68 | m1.weight.data = m0.weight.data[idx1.tolist()].clone()
69 | m1.bias.data = m0.bias.data[idx1.tolist()].clone()
70 | if m0.running_mean is None:
71 | m1.running_mean = m0.running_mean
72 | m1.running_var = m0.running_var
73 | else:
74 | m1.running_mean = m0.running_mean[idx1.tolist()].clone()
75 | m1.running_var = m0.running_var[idx1.tolist()].clone()
76 | layer_id_in_cfg += 1
77 | start_mask = end_mask.clone()
78 | if layer_id_in_cfg < len(cfg_mask): # do not change in Final FC
79 | end_mask = cfg_mask[layer_id_in_cfg]
80 | elif isinstance(m0, nn.Conv2d) and m0 != model.fc:
81 | if conv_count == 0:
82 | m1.weight.data = m0.weight.data.clone()
83 | conv_count += 1
84 | continue
85 | if isinstance(
86 | old_modules[layer_id - 1], networks.channel_selection
87 | ) or isinstance(
88 | old_modules[layer_id - 1],
89 | (nn.modules.batchnorm._NormBase, nn.GroupNorm),
90 | ):
91 | # This covers the convolutions in the residual block.
92 | # The convolutions are either after the channel selection layer
93 | # or after the batch normalization layer.
94 | conv_count += 1
95 | idx0 = np.squeeze(
96 | np.argwhere(np.asarray(start_mask.cpu().numpy()))
97 | )
98 | idx1 = np.squeeze(
99 | np.argwhere(np.asarray(end_mask.cpu().numpy()))
100 | )
101 | if idx0.size == 1:
102 | idx0 = np.resize(idx0, (1,))
103 | if idx1.size == 1:
104 | idx1 = np.resize(idx1, (1,))
105 | w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
106 |
107 | # If the current convolution is not the last convolution in the
108 | # residual block, then we can change the number of output
109 | # channels. Currently we use `conv_count` to detect whether it
110 | # is such convolution.
111 | if conv_count % 3 != 1:
112 | w1 = w1[idx1.tolist(), :, :, :].clone()
113 | m1.weight.data = w1.clone()
114 | continue
115 |
116 | # We need to consider the case where there are downsampling
117 | # convolutions. For these convolutions, we just copy the weights.
118 | m1.weight.data = m0.weight.data.clone()
119 | elif m0 == model.fc:
120 | idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
121 | if idx0.size == 1:
122 | idx0 = np.resize(idx0, (1,))
123 |
124 | m1.weight.data = m0.weight.data[:, idx0].clone()
125 |
126 | assert (m1.bias is None) == (m0.bias is None)
127 | if m1.bias is not None:
128 | m1.bias.data = m0.bias.data.clone()
129 |
130 | num_parameters = sum([param.nelement() for param in newmodel.parameters()])
131 |
132 | return num_parameters, newmodel
133 |
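The mask-to-index conversion that `get_slimmed_network` repeats for `start_mask` and `end_mask` can be isolated with NumPy alone. The sketch below is illustrative only: `mask_to_indices` and `slice_conv_weight` are hypothetical helpers, and a plain array stands in for a convolution weight of shape (out, in, kH, kW). It shows why the `idx.size == 1` branch exists: `np.squeeze` collapses a single kept channel to a 0-d array, which must be restored to 1-D before indexing.

```python
import numpy as np


def mask_to_indices(mask):
    """Indices of kept channels; always 1-D, even for one kept channel."""
    idx = np.squeeze(np.argwhere(np.asarray(mask)))
    if idx.size == 1:
        # squeeze() turned a single index into a 0-d array; make it 1-D again.
        idx = np.resize(idx, (1,))
    return idx


def slice_conv_weight(weight, in_mask, out_mask):
    """Keep only the input and output channels selected by the masks."""
    idx_in = mask_to_indices(in_mask)
    idx_out = mask_to_indices(out_mask)
    w = weight[:, idx_in.tolist(), :, :]
    return w[idx_out.tolist(), :, :, :]


# Hypothetical weight with 4 output and 3 input channels, 1x1 kernel.
weight = np.arange(4 * 3, dtype=np.float32).reshape(4, 3, 1, 1)
pruned = slice_conv_weight(weight, in_mask=[1, 0, 1], out_mask=[0, 1, 0, 0])
print(pruned.shape)  # (1, 2, 1, 1)
```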
--------------------------------------------------------------------------------
/models/networks/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | def get_slim_configs(model):
10 | # Note: the model should *already* be pruned on the BN layers. This function
11 | # will not apply the pruning part.
12 | # We expect the user to create a (pruned) copy of the model first before
13 | # calling this.
14 | is_cuda = next(model.parameters()).is_cuda
15 |
16 | cfg = []
17 | cfg_mask = []
18 | for k, m in enumerate(model.modules()):
19 | if isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
20 | mask = m.weight.abs() > 0
21 | if is_cuda:
22 | mask = mask.cuda()
23 | cfg.append(int(torch.sum(mask)))
24 | cfg_mask.append(mask.clone())
25 | print(
26 | "layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}".format(
27 | k, mask.shape[0], int(torch.sum(mask))
28 | )
29 | )
30 | elif isinstance(m, nn.MaxPool2d):
31 | cfg.append("M")
32 |
33 | return cfg, cfg_mask
34 |
--------------------------------------------------------------------------------
/models/networks/vgg.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import math
6 |
7 | import torch.nn as nn
8 |
9 | __all__ = ["vgg", "vgg11", "vgg13", "vgg16", "vgg19"]
10 |
11 | defaultcfg = {
12 | 11: [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512],
13 | 13: [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512],
14 | 16: [
15 | 64,
16 | 64,
17 | "M",
18 | 128,
19 | 128,
20 | "M",
21 | 256,
22 | 256,
23 | 256,
24 | "M",
25 | 512,
26 | 512,
27 | 512,
28 | "M",
29 | 512,
30 | 512,
31 | 512,
32 | ],
33 | 19: [
34 | 64,
35 | 64,
36 | "M",
37 | 128,
38 | 128,
39 | "M",
40 | 256,
41 | 256,
42 | 256,
43 | 256,
44 | "M",
45 | 512,
46 | 512,
47 | 512,
48 | 512,
49 | "M",
50 | 512,
51 | 512,
52 | 512,
53 | 512,
54 | ],
55 | }
56 |
57 |
58 | class vgg(nn.Module):
59 | def __init__(
60 | self,
61 | dataset="imagenet",
62 | depth=19,
63 | init_weights=True,
64 | cfg=None,
65 | builder=None,
66 | block_builder=None,
67 | ):
68 | super(vgg, self).__init__()
69 | if cfg is None:
70 | cfg = defaultcfg[depth]
71 |
72 | self.feature = self.make_layers(
73 | cfg, True, builder=builder, block_builder=block_builder
74 | )
75 |
76 | if dataset == "imagenet":
77 | num_classes = 1000
78 | else:
79 | raise NotImplementedError(f"Not implemented for dataset {dataset}")
80 | self.classifier = builder.conv1x1(cfg[-1], num_classes, last_layer=True)
81 |
82 | if init_weights:
83 | self._initialize_weights()
84 |
85 | def make_layers(
86 | self, cfg, batch_norm=False, builder=None, block_builder=None
87 | ):
88 | layers = []
89 | in_channels = 3
90 | first_conv_layer = True
91 | for v in cfg:
92 | if v == "M":
93 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
94 | else:
95 | if first_conv_layer:
96 | my_builder = builder
97 | else:
98 | my_builder = block_builder
99 | conv2d = my_builder.conv3x3(
100 | in_channels, v, first_layer=first_conv_layer
101 | )
102 | first_conv_layer = False
103 | if batch_norm:
104 | bn = my_builder.batchnorm(v)
105 | layers += [conv2d, bn, nn.ReLU(inplace=True)]
106 | else:
107 | layers += [conv2d, nn.ReLU(inplace=True)]
108 | in_channels = v
109 | return nn.Sequential(*layers)
110 |
111 | def forward(self, x):
112 | x = self.feature(x)
113 | x = nn.AvgPool2d(2)(x)
114 | y = self.classifier(x)
115 | y = y.view(y.size(0), -1)
116 | return y
117 |
118 | def _initialize_weights(self):
119 | for m in self.modules():
120 | if isinstance(m, nn.Conv2d):
121 | if m.last_layer:
122 | assert m.kernel_size == (1, 1)
123 | # Fully Connected Layer
124 | i = 1
125 | while hasattr(m, f"weight{i}"):
126 | weight = getattr(m, f"weight{i}")
127 | weight.data.normal_(0, 0.01)
128 | bias = getattr(m, f"bias{i}", None)
129 | if bias is not None:
130 | bias.data.zero_()
131 | i += 1
132 | else:
133 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
134 | i = 1
135 | while hasattr(m, f"weight{i}"):
136 | weight = getattr(m, f"weight{i}")
137 | weight.data.normal_(0, math.sqrt(2.0 / n))
138 | bias = getattr(m, f"bias{i}", None)
139 | if bias is not None:
140 | bias.data.zero_()
141 | i += 1
142 |
143 | elif isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
144 | i = 1
145 | while hasattr(m, f"weight{i}"):
146 | weight = getattr(m, f"weight{i}")
147 | weight.data.fill_(0.5)
148 | bias = getattr(m, f"bias{i}", None)
149 | if bias is not None:
150 | bias.data.zero_()
151 | i += 1
152 | elif isinstance(m, nn.Linear):
153 | i = 1
154 | while hasattr(m, f"weight{i}"):
155 | weight = getattr(m, f"weight{i}")
156 | weight.data.normal_(0, 0.01)
157 | bias = getattr(m, f"bias{i}", None)
158 | if bias is not None:
159 | bias.data.zero_()
160 | i += 1
161 |
162 |
163 | def vgg11(**kwargs):
164 | return vgg(depth=11, **kwargs)
165 |
166 |
167 | def vgg13(**kwargs):
168 | return vgg(depth=13, **kwargs)
169 |
170 |
171 | def vgg16(**kwargs):
172 | return vgg(depth=16, **kwargs)
173 |
174 |
175 | def vgg19(**kwargs):
176 | return vgg(depth=19, **kwargs)
177 |
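How `make_layers` walks a `cfg` list, and the standard deviation that `_initialize_weights` uses for non-final convolutions, can be checked without PyTorch. This is a minimal sketch; `conv_plan` and `he_std` are hypothetical helpers that only mirror the channel bookkeeping and the `sqrt(2 / n)` formula, with `n = kernel_h * kernel_w * out_channels`.

```python
import math


def conv_plan(cfg, in_channels=3):
    """Return the (in, out) channel pair of every conv and the pool count."""
    convs, pools = [], 0
    for v in cfg:
        if v == "M":
            pools += 1  # "M" marks a 2x2 max-pooling layer
        else:
            convs.append((in_channels, v))
            in_channels = v  # the next conv consumes this conv's output
    return convs, pools


def he_std(kernel_size, out_channels):
    """Standard deviation used for non-final conv weights."""
    n = kernel_size[0] * kernel_size[1] * out_channels
    return math.sqrt(2.0 / n)


cfg11 = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512]
convs, pools = conv_plan(cfg11)
print(len(convs), pools)  # 8 4
print(convs[0], convs[-1])  # (3, 64) (512, 512)
```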
--------------------------------------------------------------------------------
/models/networks/vggprune.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 |
9 | from .vgg import *
10 |
11 |
12 | def get_slimmed_network(model, model_kwargs, cfg, cfg_mask):
13 | assert not isinstance(model, nn.DataParallel), "Must unwrap DataParallel"
14 |
15 | is_cuda = next(model.parameters()).is_cuda
16 | print("Cfg:")
17 | print(cfg)
18 |
19 | newmodel = vgg(cfg=cfg, **model_kwargs)
20 |
21 | if is_cuda:
22 | newmodel.cuda()
23 |
24 | old_modules = list(model.modules())
25 | new_modules = list(newmodel.modules())
26 | layer_id_in_cfg = 0
27 | start_mask = torch.ones(3)
28 | end_mask = cfg_mask[layer_id_in_cfg]
29 |
30 | for layer_id in range(len(old_modules)):
31 | m0 = old_modules[layer_id]
32 | m1 = new_modules[layer_id]
33 |
34 | if isinstance(m0, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
35 | idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
36 | if idx1.size == 1:
37 | idx1 = np.resize(idx1, (1,))
38 | m1.weight.data = m0.weight.data[idx1.tolist()].clone()
39 | m1.bias.data = m0.bias.data[idx1.tolist()].clone()
40 | m1.running_mean = m0.running_mean[idx1.tolist()].clone()
41 | m1.running_var = m0.running_var[idx1.tolist()].clone()
42 | layer_id_in_cfg += 1
43 | start_mask = end_mask.clone()
44 | if layer_id_in_cfg < len(cfg_mask): # do not change in Final FC
45 | end_mask = cfg_mask[layer_id_in_cfg]
46 | elif isinstance(m0, nn.Conv2d) and m0 != model.classifier:
47 | idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
48 | idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
49 | print(
50 | "In shape: {:d}, Out shape {:d}.".format(idx0.size, idx1.size)
51 | )
52 | if idx0.size == 1:
53 | idx0 = np.resize(idx0, (1,))
54 | if idx1.size == 1:
55 | idx1 = np.resize(idx1, (1,))
56 | w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
57 | w1 = w1[idx1.tolist(), :, :, :].clone()
58 | m1.weight.data = w1.clone()
59 | elif m0 == model.classifier:
60 | idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
61 | if idx0.size == 1:
62 | idx0 = np.resize(idx0, (1,))
63 |
64 | m1.weight.data = m0.weight.data[:, idx0].clone()
65 |
66 | assert (m1.bias is None) == (m0.bias is None)
67 | if m1.bias is not None:
68 | m1.bias.data = m0.bias.data.clone()
69 |
70 | num_parameters = sum([param.nelement() for param in newmodel.parameters()])
71 |
72 | return num_parameters, newmodel
73 |
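The bias consistency check in `get_slimmed_network` compares two `is None` tests. Python chains comparison operators, so an unparenthesized `a is None == b is None` is evaluated as `(a is None) and (None == b) and (b is None)`, which is only true when both operands are `None`. The sketch below contrasts the two spellings; both function names are hypothetical.

```python
def both_or_neither_none(a, b):
    """True when a and b are both None or both not None."""
    return (a is None) == (b is None)


def chained(a, b):
    # Without parentheses the operators chain:
    # (a is None) and (None == b) and (b is None)
    return a is None == b is None


print(both_or_neither_none(1.0, 2.0))  # True
print(chained(1.0, 2.0))  # False
print(chained(None, None))  # True
```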
--------------------------------------------------------------------------------
/models/quantize_affine.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import collections
6 | import numbers
7 | from typing import Any
8 | from typing import Optional
9 |
10 | import numpy as np
11 | import torch
12 | from torch import autograd
13 | from torch import nn
14 |
15 | from .special_tensors import RepresentibleByQuantizeAffine
16 | from .special_tensors import tag_with_metadata
17 |
18 | QuantizeAffineParams2 = collections.namedtuple(
19 | "QuantizeAffineParams", ["scale", "zero_point", "num_bits"]
20 | )
21 |
22 | INFINITY = 1e10
23 |
24 |
25 | def _validate_tensor(tensor: torch.Tensor) -> None:
26 | if torch.isnan(tensor).any():
27 | raise ValueError("Found NaN in the tensor.")
28 | if tensor.abs().max() > INFINITY:
29 | raise ValueError(
30 | "Tensor seems to be diverging. Found a value > {}".format(INFINITY)
31 | )
32 |
33 |
34 | def get_quantized_representation(
35 | tensor: torch.Tensor,
36 | quantize_params: QuantizeAffineParams2,
37 | ) -> torch.Tensor:
38 | """Gets the quantized representation of a float @tensor.
39 |
40 | The resulting tensor will contain the quantized values and the quantization
41 | parameters will be tagged with the tensor as a special tensor.
42 |
43 | A ValueError will be raised if the given tensor contains NaN or divergent
44 | values.
45 |
46 | Arguments:
47 | tensor (torch.Tensor): The float torch tensor to quantize.
48 | quantize_params (QuantizeAffineParams): The quantization params to
49 | quantize the tensor by.
50 | """
51 | _validate_tensor(tensor)
52 |
53 | scale = quantize_params.scale
54 | zero_point = quantize_params.zero_point
55 | num_bits = quantize_params.num_bits
56 | if scale == 0:
57 | # Special case, all elements are zeros.
58 | if zero_point != 0:
59 | raise ValueError(
60 | "The given QuantizeAffineParams (={}) has a non-zero zero point"
61 | " with a scale of 0.".format(quantize_params)
62 | )
63 | quantized_tensor = torch.zeros_like(tensor)
64 | tag_with_metadata(quantized_tensor, quantize_params)
65 | return quantized_tensor
66 |
67 | qmin, qmax = get_qmin_qmax(num_bits)
68 | reciprocal = 1 / scale
69 | quantized_tensor = ((tensor * reciprocal).round_() + zero_point).clamp_(
70 | qmin, qmax
71 | )
72 |
73 | tag_with_metadata(quantized_tensor, quantize_params)
74 | return quantized_tensor
75 |
76 |
77 | def mark_quantize_affine(
78 | tensor: torch.Tensor,
79 | scale: float,
80 | zero_point: int,
81 | dtype: np.dtype = np.uint8,
82 | ) -> None:
83 | """Mark a tensor as quantized with affine.
84 |
85 | Arguments:
86 | tensor (torch.Tensor): The tensor to be marked as affine-quantizable
87 | Tensor.
88 | scale (float): the scale (from quantization parameters).
89 | zero_point (int): The zero_point (from quantization parameters).
90 | dtype (numpy.dtype): Type of tensor when quantized (this is usually
91 | numpy.uint8, which is used for Q8). A ValueError will be thrown if
92 | the input dtype is not one of the following:
93 | {numpy.uint8, numpy.int32}.
94 | """
95 | quant_params = QuantizeAffineParams2(scale, zero_point, dtype)
96 | tag_with_metadata(tensor, RepresentibleByQuantizeAffine(quant_params))
97 |
98 |
99 | class QuantizeAffineFunction(autograd.Function):
100 | """Simulates effect of affine quantization during forward pass.
101 |
102 | This function simulates the effect of quantization and subsequent
103 | dequantization (in the forward pass only). Although the affine
104 | transformation results in a different basis (e.g. uint8), the output of this
105 | function will be a float Tensor representing that transformation (the
106 | dequantized Tensor).
107 |
108 | A ValueError will be raised if the input or resulting tensor contains NaN or
109 | divergent values.
110 |
111 | Arguments:
112 | input (Tensor): The input float Tensor to be quantized.
113 | quantize_params (quantize_affine_util.QuantizeAffineParams): The
114 | quantization parameter to quantize the input tensor by.
115 | """
116 |
117 | @staticmethod
118 | def forward(
119 | ctx: Any,
120 | input: torch.Tensor,
121 | quantize_params: QuantizeAffineParams2,
122 | ) -> torch.Tensor:
123 | quantized_tensor = get_quantized_representation(input, quantize_params)
124 | dequantized_tensor = dequantize(quantized_tensor, quantize_params)
125 |
126 | mark_quantize_affine(
127 | dequantized_tensor,
128 | quantize_params.scale,
129 | quantize_params.zero_point,
130 | quantize_params.num_bits,
131 | )
132 | return dequantized_tensor
133 |
134 | @staticmethod
135 | def backward(ctx: Any, grad_output: Any) -> Any:
136 | """We will approximate the gradient as the identity"""
137 | return grad_output, None
138 |
139 |
140 | def quantize_affine_function_continuous(
141 | input: torch.Tensor,
142 | quantize_params: QuantizeAffineParams2,
143 | ) -> torch.Tensor:
144 | quantized_tensor = get_quantized_representation(input, quantize_params)
145 | dequantized_tensor = dequantize(quantized_tensor, quantize_params)
146 |
147 | mark_quantize_affine(
148 | dequantized_tensor,
149 | quantize_params.scale,
150 | quantize_params.zero_point,
151 | quantize_params.num_bits,
152 | )
153 | return dequantized_tensor
154 |
155 |
156 | def get_qmin_qmax(num_bits):
157 | return -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
158 |
159 |
160 | def get_quantization_params(
161 | rmin: float,
162 | rmax: float,
163 | num_bits: int = 8,
164 | ) -> QuantizeAffineParams2:
165 | """Returns QuantizeAffineParams for a data range [rmin, rmax].
166 |
167 | The range must include 0; otherwise a ValueError is raised. The scale
168 | and zero_point are picked such that the quantization error is
169 | minimized.
170 |
171 | Arguments:
172 | rmin (float): The data minimum point. Numbers smaller than rmin would
173 | not be representable by the quantized schema.
174 | rmax (float): The data maximum point. Numbers bigger than rmax would
175 | not be representable by the quantized schema.
176 | num_bits (optional, int): The number of bits used to represent the
177 | individual numbers after quantization; it determines the range
178 | [qmin, qmax] returned by get_qmin_qmax. Defaults to 8.
179 | """
180 | if rmin > rmax:
181 | raise ValueError("Got rmin (={}) > rmax (={}).".format(rmin, rmax))
182 | if rmin > 0 or rmax < 0:
183 | raise ValueError(
184 | "The data range ([{}, {}]) must always include "
185 | "0.".format(rmin, rmax)
186 | )
187 |
188 | if rmin == rmax == 0.0:
189 | # Special case: all values are zero.
190 | return QuantizeAffineParams2(0, 0, num_bits)
191 |
192 | # Scale is floating point and is (rmax - rmin) / (qmax - qmin) to map the
193 | # length of the ranges. For zero_point, we solve the following equation:
194 | # rmin = (qmin - zero_point) * scale
195 | qmin, qmax = get_qmin_qmax(num_bits)
196 | scale = (rmax - rmin) / (qmax - qmin)
197 | zero_point = qmin - (rmin / scale)
198 | zero_point = np.clip(round(zero_point), qmin, qmax).astype(np.int32)
199 |
200 | quantize_params = QuantizeAffineParams2(scale, zero_point, num_bits)
201 | # We must ensure that zero is exactly representable with these quantization
202 | # parameters. This is easy enough to add a self-check for.
203 | quantized_zero = quantize(np.array([0.0]), quantize_params)
204 | dequantized_zero = dequantize(quantized_zero, quantize_params)
205 | if dequantized_zero.item() != 0.0:
206 | raise ValueError(
207 | f"Quantization parameters are invalid: scale={scale}, zero={zero_point}. "
208 | f"Can't exactly represent zero."
209 | )
210 |
211 | return quantize_params
212 |
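The arithmetic in `get_quantization_params` can be followed on scalars. This is a minimal sketch, not the module's API: it re-implements the same formulas in plain Python with hypothetical names (`qmin_qmax`, `quantization_params`, `quantize_scalar`, `dequantize_scalar`), and it uses Python's built-in `round` and `min`/`max` in place of the NumPy and PyTorch calls. The range [-0.5, 1.5] is an arbitrary choice for illustration.

```python
def qmin_qmax(num_bits):
    """Signed integer range for num_bits, as in get_qmin_qmax."""
    return -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1


def quantization_params(rmin, rmax, num_bits=8):
    """Scale and zero point mapping [rmin, rmax] onto [qmin, qmax]."""
    qmin, qmax = qmin_qmax(num_bits)
    scale = (rmax - rmin) / (qmax - qmin)
    # Solve rmin = (qmin - zero_point) * scale, then round and clip.
    zero_point = min(max(round(qmin - rmin / scale), qmin), qmax)
    return scale, zero_point


def quantize_scalar(x, scale, zero_point, num_bits=8):
    qmin, qmax = qmin_qmax(num_bits)
    return min(max(round(x * (1 / scale)) + zero_point, qmin), qmax)


def dequantize_scalar(q, scale, zero_point):
    return (q - zero_point) * scale


scale, zero_point = quantization_params(-0.5, 1.5, num_bits=8)
print(zero_point)  # -64
# Zero must survive a quantize/dequantize round trip exactly.
print(dequantize_scalar(quantize_scalar(0.0, scale, zero_point), scale, zero_point))  # 0.0
# Values outside the range saturate at qmax.
print(quantize_scalar(10.0, scale, zero_point))  # 127
```

With 8 bits the step size is `scale = 2 / 255`, so any value inside the range is reproduced to within half a step after the round trip.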
213 |
214 | def quantize_affine_given_quant_params(
215 | input: torch.Tensor,
216 | quantize_params: QuantizeAffineParams2,
217 | ) -> torch.Tensor:
218 | """Get a quantizable approximation of a float tensor given quantize param.
219 |
220 | This function does not quantize the float tensor @input, but only adjusts it
221 | such that the returned float tensor has an exact quantized representation.
222 | This is a function that we want to use at training time to quantize biases
223 | and other parameters whose quantization schema is enforced by other
224 | parameters.
225 |
226 | In forward pass, this function is equivalent to
227 |
228 | dequantize(get_quantized_representation(input, quantize_param))
229 |
230 | However, in backward pass, this function operates as identity, making it
231 | ideal to be a part of the training forward pass.
232 | """
233 | return QuantizeAffineFunction.apply(input, quantize_params)
234 |
235 |
236 | def quantize(
237 | arr: np.ndarray, quantize_params: QuantizeAffineParams2
238 | ) -> np.ndarray:
239 | """Quantize a floating point array with respect to the quantization params.
240 |
241 | Arguments:
242 | arr (np.ndarray): The floating point data to quantize.
243 | quantize_params (QuantizeAffineParams): The quantization parameters
244 | under which the data should be quantized.
245 | """
246 | scale = quantize_params.scale
247 | zero_point = quantize_params.zero_point
248 | num_bits = quantize_params.num_bits
249 | if scale == 0:
250 | # Special case, all elements are zeros.
251 | if zero_point != 0:
252 | raise ValueError(
253 | "The given QuantizeAffineParams (={}) has a non-zero zero point"
254 | " with a scale of 0.".format(quantize_params)
255 | )
256 | return np.zeros_like(arr, dtype=np.int32)
257 |
258 | qmin, qmax = get_qmin_qmax(num_bits)
259 | reciprocal = 1 / scale
260 | quantized_values = (arr * reciprocal).round() + zero_point
261 | quantized_values = quantized_values.clip(qmin, qmax)
262 | return quantized_values
263 |
264 |
265 | def dequantize(
266 | q_arr: np.ndarray,
267 | quantize_params: QuantizeAffineParams2,
268 | ) -> np.ndarray:
269 | """Dequantize a fixed point array with respect to the quantization params.
270 |
271 | Arguments:
272 | q_arr (np.ndarray): The quantized array to dequantize. Its dtype must
273 | match quantize_params.
274 | quantize_params (QuantizeAffineParams): The quantization parameters
275 | under which the data should be dequantized.
276 | """
277 | zero_point = quantize_params.zero_point
278 | scale = quantize_params.scale
279 | return (q_arr - zero_point) * scale
280 |
281 |
282 | def quantize_affine(
283 | input: torch.Tensor,
284 | min_value: Optional[numbers.Real] = None,
285 | max_value: Optional[numbers.Real] = None,
286 | num_bits: Optional[int] = None,
287 | ) -> torch.Tensor:
288 | """Return a quantizable approximation of a float tensor @input.
289 |
290 | This function does not quantize the float tensor @input, but only adjusts it
291 | such that the returned float tensor has an exact quantized representation.
292 | This is a function that we want to use at training time to quantize weights
293 | and activations.
294 |
295 | Arguments:
296 | input (Tensor): The input float Tensor to be quantized.
297 | min_value (scalar): The running min value (possibly averaged).
298 | max_value (scalar): The running max value (possibly averaged).
299 | num_bits (int): The number of bits.
300 | """
301 | if num_bits is None:
302 | raise ValueError("num_bits must be supplied")
303 |
304 | if min_value is None:
305 | # Force include 0 in our calculation of min_value.
306 | min_value = min(input.min().item(), 0.0)
307 | if max_value is None:
308 | # Force include 0 in our calculation of max_value.
309 | max_value = max(input.max().item(), 0.0)
310 |
311 | quantize_params = get_quantization_params(min_value, max_value, num_bits)
312 | return QuantizeAffineFunction.apply(input, quantize_params)
313 |
314 |
315 | class QuantizeAffine(nn.Module):
316 | """Pytorch quantize_affine layer for quantizing layer outputs.
317 |
318 | This layer will keep a running max and min, which is used to compute a scale
319 | and zero_point for the quantization. Note that it is not always desirable
320 | to start the quantization immediately while training.
321 |
322 | Arguments:
323 | momentum (scalar): The amount of averaging of min and max bounds.
324 | This value should be in the range [0.0, 1.0].
325 | iteration_delay (scalar): The number of batches to wait before starting
326 | to quantize.
327 | """
328 |
329 | def __init__(
330 | self,
331 | momentum=0.1,
332 | iteration_delay=0,
333 | num_bits=8,
334 | quantizer_freeze_min_max=False,
335 | ):
336 | super().__init__()
337 | self.momentum = momentum
338 | self.iteration_delay = iteration_delay
339 | self.increment_counter = False
340 | self.num_bits = num_bits
341 | self.register_buffer("running_min_value", torch.tensor(0.0))
342 | self.register_buffer("running_max_value", torch.tensor(0.0))
343 | self.register_buffer(
344 | "iteration_count", torch.zeros([1], dtype=torch.int32).squeeze()
345 | )
346 | self.quantizer_freeze_min_max = quantizer_freeze_min_max
347 |
348 | def __repr__(self):
349 | return (
350 | f"{self.__class__.__name__}(running_min="
351 | f"{self.running_min_value}, running_max="
352 | f"{self.running_max_value}, freeze_min_max="
353 | f"{self.quantizer_freeze_min_max}, num_bits={self.num_bits})"
354 | )
355 |
356 | def update_num_bits(self, num_bits):
357 | self.num_bits = num_bits
358 |
359 | def forward(self, input, recomp_bn_stats=False, override_alpha=False):
360 | if (
361 | self.training
362 | and self.is_active()
363 | and not self.quantizer_freeze_min_max
364 | ):
365 | # Force include 0 in min_value and max_value calculation.
366 | min_value = min(input.min().item(), 0)
367 | max_value = max(input.max().item(), 0)
368 |
369 | if self.iteration_count == self.iteration_delay:
370 | new_running_min_value = min_value
371 | new_running_max_value = max_value
372 | else:
373 | new_running_min_value = (
374 | 1.0 - self.momentum
375 | ) * self.running_min_value.item() + self.momentum * min_value
376 | new_running_max_value = (
377 | 1.0 - self.momentum
378 | ) * self.running_max_value.item() + self.momentum * max_value
379 |
380 | self.running_min_value.fill_(new_running_min_value)
381 | self.running_max_value.fill_(new_running_max_value)
382 |
383 | if self.is_active():
384 | output = quantize_affine(
385 | input,
386 | self.running_min_value.item(),
387 | self.running_max_value.item(),
388 | self.num_bits,
389 | )
390 | else:
391 | output = input
392 |
393 | if self.training and self.increment_counter:
394 | self.iteration_count.fill_(self.iteration_count.item() + 1)
395 |
396 | return output
397 |
398 | def is_active(self):
399 | if self.training:
400 | return self.iteration_count >= self.iteration_delay
401 | # If evaluating, always run quantization:
402 | return True
403 |
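The quantize/dequantize pair above can be sketched without NumPy or PyTorch. This is a minimal illustration only: `get_qmin_qmax` and `get_quantization_params` are not shown in this part of the file, so the unsigned range `[0, 2**num_bits - 1]` and the scale/zero-point formulas below are assumptions, and the `sketch_*` names are hypothetical.

```python
def sketch_qmin_qmax(num_bits):
    # Assumption: unsigned integer range, e.g. [0, 15] for 4 bits.
    return 0, 2 ** num_bits - 1


def sketch_quantization_params(min_value, max_value, num_bits):
    # Assumption: map [min_value, max_value] (forced to include 0)
    # linearly onto [qmin, qmax].
    qmin, qmax = sketch_qmin_qmax(num_bits)
    min_value = min(min_value, 0.0)
    max_value = max(max_value, 0.0)
    if max_value == min_value:
        return 0.0, 0
    scale = (max_value - min_value) / (qmax - qmin)
    zero_point = int(round(qmin - min_value / scale))
    return scale, max(qmin, min(qmax, zero_point))


def sketch_quantize(values, scale, zero_point, num_bits):
    # Mirrors quantize(): multiply by 1 / scale, round, shift, clip.
    qmin, qmax = sketch_qmin_qmax(num_bits)
    reciprocal = 1 / scale
    return [
        max(qmin, min(qmax, round(v * reciprocal) + zero_point))
        for v in values
    ]


def sketch_dequantize(q_values, scale, zero_point):
    # Mirrors dequantize(): (q - zero_point) * scale.
    return [(q - zero_point) * scale for q in q_values]


values = [-1.0, -0.3, 0.0, 0.6, 2.75]
scale, zero_point = sketch_quantization_params(-1.0, 2.75, num_bits=4)
q = sketch_quantize(values, scale, zero_point, num_bits=4)
deq = sketch_dequantize(q, scale, zero_point)
```

Under these assumptions the step size is 0.25, zero maps exactly to the integer zero point, and every dequantized value lies within half a step of its input. Quantizing `deq` again returns the same integers, which is the property `quantize_affine` relies on when it returns a float tensor with an exact quantized representation.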
--------------------------------------------------------------------------------
/models/quantized_modules.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | # Modified version of https://github.com/pytorch/pytorch/blob/master/torch/nn/intrinsic/qat/modules/conv_fused.py
7 |
8 | import math
9 | from typing import TypeVar
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.intrinsic as nni
14 | from torch.nn import init
15 | from torch.nn.modules.utils import _pair
16 | from torch.nn.parameter import Parameter
17 |
18 | from .quantize_affine import QuantizeAffine
19 | from .quantize_affine import quantize_affine
20 |
21 | MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)
22 |
23 |
24 | class QuantStandardBN(nn.BatchNorm2d):
25 | def get_weight(self):
26 | return self.weight, self.bias
27 |
28 |
29 | class NoOpBN(nn.Module):
30 | def __init__(self, *args, **kwargs):
31 | super().__init__()
32 |
33 | def forward(self, input):
34 | return input
35 |
36 |
37 | class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
38 |
39 | _version = 2
40 | _FLOAT_MODULE = MOD
41 |
42 | def __init__(
43 | self,
44 | # ConvNd args
45 | in_channels,
46 | out_channels,
47 | kernel_size,
48 | stride,
49 | padding,
50 | dilation,
51 | transposed,
52 | output_padding,
53 | groups,
54 | bias,
55 | padding_mode,
56 | # BatchNormNd args
57 | # num_features: out_channels
58 | eps=1e-05,
59 | momentum=0.1,
60 | # affine: True
61 | # track_running_stats: True
62 | # Args for this module
63 | freeze_bn=False,
64 | dim=2,
65 | num_bits=8,
66 | iteration_delay=0,
67 | bn_module=None,
68 | norm_kwargs=None,
69 | ):
70 | nn.modules.conv._ConvNd.__init__(
71 | self,
72 | in_channels,
73 | out_channels,
74 | kernel_size,
75 | stride,
76 | padding,
77 | dilation,
78 | transposed,
79 | output_padding,
80 | groups,
81 | False,
82 | padding_mode,
83 | )
84 | self.freeze_bn = freeze_bn if self.training else True
85 | if bn_module is not None:
86 | if "GN" in bn_module.__name__:
87 | num_groups = norm_kwargs["num_groups"]
88 | if "Line" in bn_module.__name__:
89 | self.bn = bn_module(
90 | out_channels, eps=eps, num_groups=num_groups
91 | )
92 | else:
93 | self.bn = bn_module(
94 | out_channels,
95 | eps=eps,
96 | num_groups=num_groups,
97 | affine=True,
98 | )
99 | else:
100 | self.bn = bn_module(
101 | out_channels, eps=eps, momentum=momentum, affine=True
102 | )
103 |
104 | else:
105 | raise ValueError("BN module must be supplied")
106 |
107 | if bias:
108 | self.bias = Parameter(torch.empty(out_channels))
109 | else:
110 | self.register_parameter("bias", None)
111 |
112 | # this needs to be called after reset_bn_parameters,
113 | # as they modify the same state
114 | if self.training:
115 | if freeze_bn:
116 | self.freeze_bn_stats()
117 | else:
118 | self.update_bn_stats()
119 | else:
120 | self.freeze_bn_stats()
121 |
122 | # Quantization functions/parameters
123 | self.num_bits = num_bits
124 | self.iteration_delay = iteration_delay
125 | self.quantize_input = QuantizeAffine(
126 | iteration_delay=iteration_delay, num_bits=self.num_bits
127 | )
128 |
129 | def reset_running_stats(self):
130 | self.bn.reset_running_stats()
131 |
132 | def reset_bn_parameters(self):
133 | self.bn.reset_running_stats()
134 | init.uniform_(self.bn.weight)
135 | init.zeros_(self.bn.bias)
136 | # note: below is actually for conv, not BN
137 | if self.bias is not None:
138 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
139 | bound = 1 / math.sqrt(fan_in)
140 | init.uniform_(self.bias, -bound, bound)
141 |
142 | def reset_parameters(self):
143 | super(_ConvBnNd, self).reset_parameters()
144 |
145 | def update_bn_stats(self):
146 | self.freeze_bn = False
147 | self.bn.training = True
148 | return self
149 |
150 | def freeze_bn_stats(self):
151 | self.freeze_bn = True
152 | self.bn.training = False
153 | return self
154 |
155 | def update_num_bits(self, num_bits):
156 | self.num_bits = num_bits
157 | self.quantize_input.num_bits = num_bits
158 |
159 | def start_counter(self):
160 | self.quantize_input.increment_counter = True
161 |
162 | def stop_counter(self):
163 | self.quantize_input.increment_counter = False
164 |
165 | def is_active(self):
166 | return self.quantize_input.is_active()
167 |
168 | def _forward(self, input, weight):
169 | if isinstance(self.bn, nn.BatchNorm2d):
170 | assert self.bn.running_var is not None
171 | running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
172 | scale_factor = self.bn.get_weight()[0] / running_std
173 | elif isinstance(self.bn, (nn.InstanceNorm2d, nn.GroupNorm)):
174 | # We can't really apply the scale factor easily because each batch
175 | # element is weighted differently. So, we don't fuse InstanceNorm or
176 | # GroupNorm into the convolution.
177 | scale_factor = torch.ones(self.out_channels, device=weight.device)
178 |
179 | weight_shape = [1] * len(weight.shape)
180 | weight_shape[0] = -1
181 | bias_shape = [1] * len(weight.shape)
182 | bias_shape[1] = -1
183 | # Quantize weights
184 | # Note: weights are quantized from the beginning of training, i.e., no delay here
185 | scaled_weight = quantize_affine(
186 | weight * scale_factor.reshape(weight_shape), num_bits=self.num_bits
187 | )
188 | # Quantize input
189 | # Note: inputs are quantized after self.iteration_delay training iterations
190 | input = self.quantize_input(input)
191 | # using zero bias here since the bias for original conv
192 | # will be added later
193 | if self.bias is not None:
194 | zero_bias = torch.zeros_like(self.bias)
195 | else:
196 | zero_bias = torch.zeros(
197 | self.out_channels, device=scaled_weight.device
198 | )
199 | conv = self._conv_forward(input, scaled_weight, zero_bias)
200 | conv_orig = conv / scale_factor.reshape(bias_shape)
201 | if self.bias is not None:
202 | conv_orig = conv_orig + self.bias.reshape(bias_shape)
203 | conv = self.bn(conv_orig)
204 | return conv
205 |
206 | def extra_repr(self):
207 | return super(_ConvBnNd, self).extra_repr()
208 |
209 | # def forward(self, input):
210 | # return self._forward(input)
211 |
212 | def train(self, mode=True):
213 | """
214 | Batchnorm's training behavior is using the self.training flag. Prevent
215 | changing it if BN is frozen. This makes sure that calling `model.train()`
216 | on a model with a frozen BN will behave properly.
217 | """
218 | self.training = mode
219 | if not self.freeze_bn:
220 | for module in self.children():
221 | module.train(mode)
222 | return self
223 |
224 | # ===== Serialization version history =====
225 | #
226 | # Version 1/None
227 | # self
228 | # |--- weight : Tensor
229 | # |--- bias : Tensor
230 | # |--- gamma : Tensor
231 | # |--- beta : Tensor
232 | # |--- running_mean : Tensor
233 | # |--- running_var : Tensor
234 | # |--- num_batches_tracked : Tensor
235 | #
236 | # Version 2
237 | # self
238 | # |--- weight : Tensor
239 | # |--- bias : Tensor
240 | # |--- bn : Module
241 | # |--- weight : Tensor (moved from v1.self.gamma)
242 | # |--- bias : Tensor (moved from v1.self.beta)
243 | # |--- running_mean : Tensor (moved from v1.self.running_mean)
244 | # |--- running_var : Tensor (moved from v1.self.running_var)
245 | # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
246 | def _load_from_state_dict(
247 | self,
248 | state_dict,
249 | prefix,
250 | local_metadata,
251 | strict,
252 | missing_keys,
253 | unexpected_keys,
254 | error_msgs,
255 | ):
256 | version = local_metadata.get("version", None)
257 | if version is None or version == 1:
258 | # BN related parameters and buffers were moved into the BN module for v2
259 | v2_to_v1_names = {
260 | "bn.weight": "gamma",
261 | "bn.bias": "beta",
262 | "bn.running_mean": "running_mean",
263 | "bn.running_var": "running_var",
264 | "bn.num_batches_tracked": "num_batches_tracked",
265 | }
266 |
267 | for v2_name, v1_name in v2_to_v1_names.items():
268 | if prefix + v1_name in state_dict:
269 | state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
270 | state_dict.pop(prefix + v1_name)
271 | elif prefix + v2_name in state_dict:
272 | # there was a brief period where forward compatibility
273 | # for this module was broken (between
274 | # https://github.com/pytorch/pytorch/pull/38478
275 | # and https://github.com/pytorch/pytorch/pull/38820)
276 | # and modules emitted the v2 state_dict format while
277 | # specifying that version == 1. This patches the forward
278 | # compatibility issue by allowing the v2 style entries to
279 | # be used.
280 | pass
281 | elif strict:
282 | missing_keys.append(prefix + v2_name)
283 |
284 | super(_ConvBnNd, self)._load_from_state_dict(
285 | state_dict,
286 | prefix,
287 | local_metadata,
288 | strict,
289 | missing_keys,
290 | unexpected_keys,
291 | error_msgs,
292 | )
293 |
294 |
295 | class ConvBn2d(_ConvBnNd, nn.Conv2d):
296 | r"""
297 | A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
298 | attached with FakeQuantize modules for weight,
299 | used in quantization aware training.
300 | We combined the interface of :class:`torch.nn.Conv2d` and
301 | :class:`torch.nn.BatchNorm2d`.
302 | Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
303 | to default.
304 | Attributes:
305 | freeze_bn:
306 | weight_fake_quant: fake quant module for weight
307 | """
308 |
309 | def __init__(
310 | self,
311 | # ConvNd args
312 | in_channels,
313 | out_channels,
314 | kernel_size,
315 | stride=1,
316 | padding=0,
317 | dilation=1,
318 | groups=1,
319 | bias=None,
320 | padding_mode="zeros",
321 | # BatchNorm2d args
322 | # num_features: out_channels
323 | eps=1e-05,
324 | momentum=0.1,
325 | # affine: True
326 | # track_running_stats: True
327 | # Args for this module
328 | freeze_bn=False,
329 | num_bits=8,
330 | iteration_delay=0,
331 | bn_module=None,
332 | norm_kwargs=None,
333 | ):
334 | kernel_size = _pair(kernel_size)
335 | stride = _pair(stride)
336 | padding = _pair(padding)
337 | dilation = _pair(dilation)
338 | _ConvBnNd.__init__(
339 | self,
340 | in_channels,
341 | out_channels,
342 | kernel_size,
343 | stride,
344 | padding,
345 | dilation,
346 | False,
347 | _pair(0),
348 | groups,
349 | bias,
350 | padding_mode,
351 | eps,
352 | momentum,
353 | freeze_bn,
354 | dim=2,
355 | num_bits=num_bits,
356 | iteration_delay=iteration_delay,
357 | bn_module=bn_module,
358 | norm_kwargs=norm_kwargs,
359 | )
360 |
361 | def forward(self, input, weight=None):
362 | if weight is None:
363 | return _ConvBnNd._forward(self, input, self.weight)
364 | else:
365 | return _ConvBnNd._forward(self, input, weight)
366 |
367 |
368 | class QuantSubspaceConvBn2d(ConvBn2d):
369 | def forward(self, x):
370 | # call get_weight, which samples from the subspace, then use the
371 | # corresponding weight.
372 | w = self.get_weight()
373 | return super().forward(x, w)
374 |
375 |
376 | class TwoParamConvBnd2d(QuantSubspaceConvBn2d):
377 | def __init__(self, *args, **kwargs):
378 | super().__init__(*args, **kwargs)
379 | self.weight1 = nn.Parameter(torch.zeros_like(self.weight))
380 |
381 | def initialize(self, initialize_fn):
382 | initialize_fn(self.weight)
383 | initialize_fn(self.weight1)
384 |
385 |
386 | class LinesConvBn2d(TwoParamConvBnd2d):
387 | def get_weight(self):
388 | w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
389 | return w
390 |
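The weight scaling in `_ConvBnNd._forward` can be checked on a toy case. The sketch below replaces the convolution with a per-channel dot product (a stand-in for a 1x1 convolution on a single pixel) and leaves out quantization entirely; the helper names and numbers are hypothetical and chosen only for illustration.

```python
import math


def fold_scale(gamma, running_var, eps=1e-5):
    # Per-output-channel factor gamma / sqrt(running_var + eps).
    return [g / math.sqrt(v + eps) for g, v in zip(gamma, running_var)]


def linear(x, weight):
    # Stand-in for a 1x1 convolution: one dot product per output channel.
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


gamma = [1.5, 0.5]
running_var = [4.0, 0.25]
weight = [[0.2, -0.4, 0.1], [0.3, 0.0, -0.6]]
x = [1.0, 2.0, 3.0]

scale = fold_scale(gamma, running_var)
# Scale each output channel's weights, as done before quantizing them.
scaled_weight = [[w * s for w in row] for row, s in zip(weight, scale)]
out_scaled = linear(x, scaled_weight)
# Undo the scaling on the output before it is handed to the norm layer.
out_orig = [o / s for o, s in zip(out_scaled, scale)]
reference = linear(x, weight)
```

Without quantization the scale and its inverse cancel, so `out_orig` matches `reference` up to rounding. In the module, the difference is that the weights are quantized in their scaled form, which is the form they would take once the norm layer is folded into the convolution.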
--------------------------------------------------------------------------------
/models/sparse_modules.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | import torch.nn as nn
7 |
8 | import utils
9 |
10 |
11 | class SparseConv2d(nn.Conv2d):
12 | """
13 | A class for sparse 2d convolution layers
14 | """
15 |
16 | def __init__(self, *args, **kwargs):
17 | """
18 | Inputs:
19 | - topk: A float between 0 and 1 indicating the fraction of top values
20 | to keep. E.g. if topk = 0.3, will keep the top 30% of weights (by
21 | magnitude) and drop the rest
22 | """
23 | self.method = kwargs.pop("method", "topk")
24 | self.topk = 1.0
25 | super(SparseConv2d, self).__init__(*args, **kwargs)
26 |
27 | def apply_sparsity(self, weight: torch.Tensor) -> torch.Tensor:
28 | return utils.apply_sparsity(self.topk, weight, method=self.method)
29 |
30 | def forward(self, x, weight=None):
31 | """
32 | Performs forward pass after passing weight tensor through apply_sparsity. Iterations are incremented only on train
33 | """
34 | if weight is None:
35 | return self._conv_forward(
36 | x, self.apply_sparsity(self.weight), self.bias
37 | )
38 | else:
39 | return self._conv_forward(x, self.apply_sparsity(weight), self.bias)
40 |
41 | def __repr__(self) -> str:
42 | ret = super().__repr__()
43 | # Remove last paren.
44 | ret = ret[:-1]
45 | ret += f", method={self.method})"
46 | return ret
47 |
48 |
49 | class SparseSubspaceConv(SparseConv2d):
50 | def forward(self, x):
51 | # call get_weight, which samples from the subspace, then use the corresponding weight.
52 | w = self.get_weight()
53 | return super().forward(x, w)
54 |
55 |
56 | class SparseTwoParamConv(SparseSubspaceConv):
57 | def __init__(self, *args, **kwargs):
58 | super().__init__(*args, **kwargs)
59 | self.weight1 = nn.Parameter(torch.zeros_like(self.weight))
60 |
61 | def initialize(self, initialize_fn):
62 | initialize_fn(self.weight)
63 | initialize_fn(self.weight1)
64 |
65 |
66 | class SparseLinesConv(SparseTwoParamConv):
67 | def get_weight(self):
68 | w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
69 | return w
70 |
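The `topk` docstring above describes magnitude-based pruning. `utils.apply_sparsity` is not shown in this part of the repository, so the following is only a sketch of what the "topk" method plausibly does on a flat list of weights; the function name and the rounding rule are assumptions.

```python
def sketch_apply_topk(weights, topk):
    # Keep the `topk` fraction of entries with the largest magnitude
    # and zero out the rest. Rounding to the nearest count is an assumption.
    n_keep = int(round(topk * len(weights)))
    order = sorted(
        range(len(weights)), key=lambda i: abs(weights[i]), reverse=True
    )
    keep = set(order[:n_keep])
    return [w if i in keep else 0.0 for i, w in enumerate(weights)]


weights = [0.1, -0.9, 0.4, -0.05, 0.7, 0.2, -0.3, 0.0, 0.6, -0.8]
pruned = sketch_apply_topk(weights, topk=0.3)
dense = sketch_apply_topk(weights, topk=1.0)
```

With `topk=0.3`, only the three largest-magnitude entries survive; with `topk=1.0` (the default set in `SparseConv2d.__init__`) the weights pass through unchanged.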
--------------------------------------------------------------------------------
/models/special_tensors.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | """Utility functions to tag tensors with metadata.
6 |
7 | The metadata remains with the tensor under torch operations that don't change
8 | the values, e.g. .clone(), .contiguous(), .permute(), etc.
9 | """
10 |
11 | import collections
12 | import copy
13 | from typing import Any
14 | from typing import Optional
15 |
16 | import numpy as np
17 | import torch
18 |
19 | QuantizeAffineParams2 = collections.namedtuple(
20 | "QuantizeAffineParams", ["scale", "zero_point", "num_bits"]
21 | )
22 |
23 |
24 | class _SpecialTensor(torch.Tensor):
25 | """This class denotes special tensors.
26 |
27 | It isn't intended to be used directly, but serves as a helper for tagging
28 | tensors with metadata.
29 |
30 | It subclasses torch.Tensor so that isinstance(t, torch.Tensor) returns True
31 | for special tensors. It forbids some of the methods of torch.Tensor, and
32 | overrides a few methods used to create other tensors, to ensure the result
33 | is still special.
34 | """
35 |
36 | _metadata = None
37 |
38 | def __getattribute__(self, attr: str) -> Any:
39 | # Disallow new_zeros, new_ones, new_full, etc.
40 | if "new_" in attr:
41 | raise AttributeError(
42 | "Invalid attr {!r} for special tensors".format(attr)
43 | )
44 | return super().__getattribute__(attr)
45 |
46 | def detach(self) -> "_SpecialTensor":
47 | ret = super().detach()
48 | ret.__class__ = _SpecialTensor
49 | ret._metadata = self._metadata
50 | return ret
51 |
52 | @property
53 | def data(self) -> "_SpecialTensor":
54 | ret = super().data
55 | ret.__class__ = _SpecialTensor
56 | ret._metadata = self._metadata
57 | return ret
58 |
59 | def clone(self) -> "_SpecialTensor":
60 | ret = super().clone()
61 | ret.__class__ = _SpecialTensor
62 | ret._metadata = self._metadata
63 | return ret
64 |
65 | def cuda(
66 | self, device: Optional[torch.device] = None, non_blocking: bool = False
67 | ) -> "_SpecialTensor":
68 | ret = super().cuda(device=device, non_blocking=non_blocking)
69 | ret.__class__ = _SpecialTensor
70 | ret._metadata = self._metadata
71 | return ret
72 |
73 | def contiguous(self) -> "_SpecialTensor":
74 | ret = super().contiguous()
75 | ret.__class__ = _SpecialTensor
76 | ret._metadata = self._metadata
77 | return ret
78 |
79 | def view(self, *args, **kwargs) -> "_SpecialTensor":
80 | ret = super().view(*args, **kwargs)
81 | ret.__class__ = _SpecialTensor
82 | ret._metadata = self._metadata
83 | return ret
84 |
85 | def permute(self, *args, **kwargs) -> "_SpecialTensor":
86 | ret = super().permute(*args, **kwargs)
87 | ret.__class__ = _SpecialTensor
88 | ret._metadata = self._metadata
89 | return ret
90 |
91 | def __getitem__(self, *args, **kwargs) -> "_SpecialTensor":
92 | ret = super().__getitem__(*args, **kwargs)
93 | ret.__class__ = _SpecialTensor
94 | ret._metadata = self._metadata
95 | return ret
96 |
97 | def __copy__(self) -> "_SpecialTensor":
98 | ret = copy.copy(super())
99 | ret.__class__ = _SpecialTensor
100 | ret._metadata = self._metadata
101 | return ret
102 |
103 |
104 | def _check_type(tensor: torch.Tensor) -> None:
105 | given_type = type(tensor)
106 | if not issubclass(given_type, torch.Tensor):
107 | raise TypeError("invalid type {!r}".format(given_type))
108 |
109 |
110 | def tag_with_metadata(tensor: torch.Tensor, metadata: Any) -> None:
111 | """Tag a tensor with metadata."""
112 | _check_type(tensor)
113 | tensor.__class__ = _SpecialTensor
114 | tensor._metadata = metadata
115 |
116 |
117 | RepresentibleByQuantizeAffine = collections.namedtuple(
118 | "RepresentibleByQuantizeAffine", ["quant_params"]
119 | )
120 |
121 |
122 | def mark_quantize_affine(
123 | tensor: torch.Tensor,
124 | scale: float,
125 | zero_point: int,
126 | dtype: np.dtype = np.uint8,
127 | ) -> None:
128 | """Mark a tensor as quantized with affine.
129 |
130 | See //xnorai/training/pytorch/extensions/functions:quantize_affine for more
131 | info on this method of quantization.
132 |
133 | The tensor itself can be a floating point Tensor. However, its values must
134 | be representable with @scale and @zero_point. This function, for performance
135 | reasons, does not validate whether the tensor is really quantizable as it
136 | claims to be.
137 |
138 | Arguments:
139 | tensor (torch.Tensor): The tensor to be marked as affine-quantizable
140 | Tensor.
141 | scale (float): the scale (from quantization parameters).
142 | zero_point (int): The zero_point (from quantization parameters).
143 | dtype (numpy.dtype): Type of tensor when quantized (this is usually
144 | numpy.uint8, which is used for Q8). A ValueError will be thrown if
145 | the input dtype is not one of the following:
146 | {numpy.uint8, numpy.int32}.
147 | """
148 | allowed_dtypes = [np.uint8, np.int32]
149 | if dtype not in allowed_dtypes:
150 | raise ValueError(
151 | "Provided dtype ({}) is not supported. Please use: {}".format(
152 | dtype, allowed_dtypes
153 | )
154 | )
155 | quant_params = QuantizeAffineParams2(scale, zero_point, dtype)
156 | tag_with_metadata(tensor, RepresentibleByQuantizeAffine(quant_params))
157 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | autoflake==1.4
2 | black==20.8b1
3 | isort==5.8.0
4 | numpy==1.19.5
5 | scipy==1.5.4
6 | torch==1.8.1+cu111
7 | torchvision==0.9.1+cu111
8 |
--------------------------------------------------------------------------------
/schedulers.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import numpy as np
6 |
7 | __all__ = ["cosine_lr"]
8 |
9 |
10 | def assign_learning_rate(optimizer, new_lr):
11 | for param_group in optimizer.param_groups:
12 | param_group["lr"] = new_lr
13 |
14 |
15 | def cosine_lr(optimizer, learning_rate, *, warmup_length, epochs):
16 | def _lr_adjuster(epoch, iteration):
17 | if epoch < warmup_length:
18 | lr = _warmup_lr(learning_rate, warmup_length, epoch)
19 | else:
20 | e = epoch - warmup_length
21 | es = epochs - warmup_length
22 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * learning_rate
23 |
24 | assign_learning_rate(optimizer, lr)
25 | print(f"Assigned lr={lr}")
26 | return lr
27 |
28 | return _lr_adjuster
29 |
30 |
31 | def _warmup_lr(base_lr, warmup_length, epoch):
32 | return base_lr * (epoch + 1) / warmup_length
33 |
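The schedule produced by `cosine_lr` can be tabulated without an optimizer. The sketch below restates the same two formulas as a pure function; `cosine_lr_value` is a hypothetical name, and the hyperparameters are arbitrary.

```python
import math


def cosine_lr_value(base_lr, warmup_length, epochs, epoch):
    # Linear warmup: reaches base_lr at epoch == warmup_length - 1.
    if epoch < warmup_length:
        return base_lr * (epoch + 1) / warmup_length
    # Cosine decay over the remaining epochs.
    e = epoch - warmup_length
    es = epochs - warmup_length
    return 0.5 * (1 + math.cos(math.pi * e / es)) * base_lr


schedule = [cosine_lr_value(0.1, 5, 15, epoch) for epoch in range(15)]
```

With a base rate of 0.1, 5 warmup epochs and 15 epochs in total, the rate climbs from 0.02 to 0.1 during warmup, stays at 0.1 for the first post-warmup epoch, and falls to 0.05 halfway through the decay phase (epoch 10).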
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | pip install -r requirements.txt
2 |
--------------------------------------------------------------------------------
/tools/format-python:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | autoflake --in-place --remove-unused-variables --remove-all-unused-imports "$1"
3 | black -l 80 "$1"
4 | isort -sl "$1"
5 | tools/remove-whitespace.sh "$1"
6 |
--------------------------------------------------------------------------------
/tools/remove-whitespace.sh:
--------------------------------------------------------------------------------
1 | sed -i '' -E 's/[ '$'\t'']+$//' "$1"
2 |
--------------------------------------------------------------------------------
/train_curve.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | # Trains a standard/sparse/quantized line/poly chain/bezier curve.
7 |
8 | import random
9 |
10 | import numpy as np
11 | import torch
12 | import torch.backends.cudnn as cudnn
13 | import torch.nn as nn
14 |
15 | import curve_utils
16 | import model_logging
17 | import models.networks as models
18 | import utils
19 | from models import modules
20 | from models.builder import Builder
21 | from models.networks import resprune
22 | from models.networks import utils as network_utils
23 | from models.networks import vggprune
24 | from models.quantized_modules import ConvBn2d
25 |
26 |
27 | def sparse_module_updates(model, training=False, alpha=None, **regime_params):
28 | current_iteration = regime_params["current_iteration"]
29 | warmup_iterations = regime_params["warmup_iterations"]
30 |
31 | if alpha is None:
32 | alpha = curve_utils.alpha_sampling(**regime_params)
33 |
34 | df = np.max([1 - current_iteration / warmup_iterations, 0])
35 | topk = alpha + (1 - alpha) * df
36 |
37 | is_standard = regime_params["is_standard"]
38 |
39 | for m in model.modules():
40 | if isinstance(m, nn.Conv2d) or isinstance(
41 | m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)
42 | ):
43 | setattr(m, "alpha", alpha)
44 |
45 | if (not is_standard) and hasattr(m, "topk"):
46 | setattr(m, "topk", topk)
47 |
48 | if training:
49 | regime_params["current_iteration"] += 1
50 |
51 | return model, regime_params
52 |
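The `topk` value computed in `sparse_module_updates` anneals from fully dense to the sampled `alpha` over the warmup period. A minimal sketch of that arithmetic, using the built-in `max` in place of `np.max` (the function name is hypothetical):

```python
def topk_schedule(alpha, current_iteration, warmup_iterations):
    # df decays linearly from 1 to 0 during warmup, then stays at 0.
    df = max(1 - current_iteration / warmup_iterations, 0)
    # topk starts at 1.0 (keep everything) and ends at alpha.
    return alpha + (1 - alpha) * df


start = topk_schedule(0.25, 0, 100)
middle = topk_schedule(0.25, 50, 100)
end = topk_schedule(0.25, 100, 100)
after = topk_schedule(0.25, 200, 100)
```

For `alpha = 0.25` and 100 warmup iterations, the kept fraction is 1.0 at the start, 0.625 halfway through, and 0.25 from the end of warmup onward.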
53 |
54 | def quantized_module_updates(
55 | model, training=False, alpha=None, **regime_params
56 | ):
57 | if alpha is None:
58 | alpha, num_bits = curve_utils.sample_alpha_num_bits(**regime_params)
59 | else:
60 | num_bits = curve_utils.alpha_bit_map(alpha, **regime_params)
61 |
62 | is_standard = regime_params["is_standard"]
63 |
64 | for m in model.modules():
65 | if isinstance(m, nn.Conv2d) or isinstance(
66 | m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)
67 | ):
68 | setattr(m, "alpha", alpha)
69 |
70 | if isinstance(m, ConvBn2d):
71 | if not is_standard:
72 | m.update_num_bits(num_bits)
73 |
74 | if training:
75 | m.start_counter()
76 | else:
77 | m.stop_counter()
78 |
79 | return model, regime_params
80 |
81 |
82 | def _test_time_lec_update(model, **regime_params):
83 | # This requires that the topk values are already set on the model.
84 |
85 | # We create a whole new copy of the model which is pruned.
86 | model_kwargs = regime_params["model_kwargs"]
87 | fresh_copy = utils.make_fresh_copy_of_pruned_network(model, model_kwargs)
88 | cfg, cfg_mask = network_utils.get_slim_configs(fresh_copy)
89 |
90 | builder = Builder(conv_type="StandardConv", bn_type="StandardIN")
91 |
92 | try:
93 | if isinstance(model, models.cpreresnet):
94 | model_class = resprune
95 | elif isinstance(model, models.vgg.vgg):
96 | model_class = vggprune
97 | else:
98 | raise ValueError(
99 | "Model {} is not supported for LEC.".format(model)
100 | )
101 |
102 | _, slimmed_network = model_class.get_slimmed_network(
103 | fresh_copy.module,
104 | {"builder": builder, "block_builder": builder, **model_kwargs},
105 | cfg,
106 | cfg_mask,
107 | )
108 | except Exception:
109 | print(
110 | f"Something went wrong during LEC. Most likely, an entire "
111 | f"layer was deleted. Using @fresh_copy."
112 | )
113 | slimmed_network = fresh_copy
114 | num_parameters = sum(
115 | [param.nelement() for param in slimmed_network.parameters()]
116 | )
117 |
118 | # NOTE: DO NOT use @model here, since it has too many extra buffers in the
119 | # case of training a line.
120 | total_params = sum([param.nelement() for param in fresh_copy.parameters()])
121 | regime_params["sparsity"] = (total_params - num_parameters) / total_params
122 |
123 | print(f"Got sparsity level of {regime_params['sparsity']}")
124 |
125 | return slimmed_network, regime_params
126 |
127 |
128 | def lec_update(model, training=False, alpha=None, **regime_params):
129 | # Note: this file is for training a *curve*, so we do the right thing for
130 | # curves here. Namely, we apply the same update to alpha/topk as when
131 | # training a standard curve, but we gather the sparsity in a different way
132 | # since we're doing LEC rather than doing the unstructured sparsity.
133 | model, regime_params = sparse_module_updates(
134 | model, training=training, alpha=alpha, **regime_params
135 | )
136 |
137 | if training:
138 | return model, regime_params
139 | else:
140 | return _test_time_lec_update(model, **regime_params)
141 |
142 |
143 | def us_update(model, training=False, alpha=None, **regime_params):
144 | assert alpha is not None, f"Alpha value is required."
145 |
146 | # Use alpha to compute the width factor.
147 | low_factor, high_factor = regime_params["width_factor_limits"]
148 | width_factor = low_factor + (high_factor - low_factor) * alpha
149 | regime_params["width_factor"] = width_factor
150 |
151 | for module in model.modules():
152 | if hasattr(module, "width_factor"):
153 | module.width_factor = width_factor
154 |
155 | for m in model.modules():
156 | if isinstance(m, nn.Conv2d) or isinstance(
157 | m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)
158 | ):
159 | setattr(m, "alpha", alpha)
160 |
161 | return model, regime_params
162 |
163 |
164 | def train(
165 | model, train_loader, optimizer, criterion, epoch, device, **regime_params
166 | ):
167 |
168 | model.zero_grad()
169 | model.train()
170 | avg_loss = 0.0
171 |
172 | regime_update_dict = {
173 | "sparse": sparse_module_updates,
174 | "quantized": quantized_module_updates,
175 | "lec": lec_update,
176 | "us": us_update,
177 | }
178 | module_update = regime_update_dict[regime_params["regime"]]
179 |
180 | for batch_idx, (data, target) in enumerate(train_loader):
181 |
182 | data, target = data.to(device, non_blocking=True), target.to(
183 | device, non_blocking=True
184 | )
185 |
186 | optimizer.zero_grad()
187 |
188 | if regime_params["regime"] == "us":
189 | loss = 0
190 |
191 | if regime_params["width_factor_sampling_method"] == "sandwich":
192 | n_samples = regime_params["width_factor_samples"]
193 |
194 | assert n_samples >= 2, f"Require n_samples>=2, got {n_samples}"
195 |
196 | alphas = [0, 1]
197 | for i in range(n_samples - 2):
198 | alphas.append(random.uniform(0.0, 1.0))
199 | elif regime_params["width_factor_sampling_method"] == "point":
200 | alphas = [random.uniform(0.0, 1.0)]
201 | else:
202 | raise NotImplementedError
203 |
204 | for alpha in alphas:
205 | model, regime_params = module_update(
206 | model, True, alpha=alpha, **regime_params
207 | )
208 | output = model(data)
209 | loss += criterion(output, target)
210 |
211 | else:
212 | model, regime_params = module_update(
213 | model, training=True, **regime_params
214 | )
215 | output = model(data)
216 | loss = criterion(output, target)
217 |
218 | # Application of the regularization term, equation 3.
219 | num_points = regime_params["num_points"]
220 | beta = regime_params.get("beta", 1)
221 | if beta > 0 and num_points > 1:
222 | out = random.sample([i for i in range(num_points)], 2)
223 |
224 | i, j = out[0], out[1]
225 | num = 0.0
226 | normi = 0.0
227 | normj = 0.0
228 | for m in model.modules():
229 | # Apply beta term if we have a conv, and (optionally) if we have
230 | # a norm layer. Only apply beta term if alpha exists (e.g. it's
231 | # a line).
232 | # (We forbid an exact type match because "plain-old" Conv2d
233 | # layers [as in LEC] should not trigger this logic).
234 | should_apply_beta = isinstance(m, nn.Conv2d) and type(
235 | m
236 | ) not in (nn.Conv2d, modules.AdaptiveConv2d)
237 | should_apply_beta = should_apply_beta or (
238 | isinstance(
239 | m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)
240 | )
241 | and regime_params.get("apply_beta_to_norm", False)
242 | )
243 | should_apply_beta = should_apply_beta and hasattr(m, "alpha")
244 | if should_apply_beta:
245 | vi = curve_utils.get_weight(m, i)
246 | vj = curve_utils.get_weight(m, j)
247 | num += (vi * vj).sum()
248 | normi += vi.pow(2).sum()
249 | normj += vj.pow(2).sum()
250 | loss += beta * (num.pow(2) / (normi * normj))
251 |
252 | loss.backward()
253 | if regime_params.get("bn_update_factor") is not None:
254 | loss_reg_term = (
255 | utils.apply_in_topk_reg(model, apply_to_bias=False)
256 | * regime_params["bn_update_factor"]
257 | )
258 | else:
259 | loss_reg_term = 0
260 | optimizer.step()
261 |
262 | avg_loss += loss.item() + loss_reg_term
263 |
264 | log_interval = 10
265 |
266 | if batch_idx % log_interval == 0:
267 | num_samples = batch_idx * len(data)
268 | dataset_size = len(train_loader.dataset)
269 | percent_complete = 100.0 * batch_idx / len(train_loader)
270 |
271 | predicted_labels = output.argmax(dim=1)
272 | corrects = (
273 | predicted_labels == target
274 | ).float().sum() / target.numel()
275 |
276 | print(
277 | f"Train Epoch: {epoch} [{num_samples}/{dataset_size} ({percent_complete:.0f}%)]\t"
278 | f"Loss: {loss.item():.6f} Correct: {corrects.item():.4f}"
279 | )
280 |
281 | model.apply(lambda m: setattr(m, "return_feats", False))
282 |
283 | avg_loss = avg_loss / len(train_loader)
284 |
285 | return avg_loss, regime_params
286 |
287 |
288 | def test(
289 | model,
290 | alpha,
291 | val_loader,
292 | criterion,
293 | epoch,
294 | device,
295 | metric_dict=None,
296 | **regime_params,
297 | ):
298 |
299 | model.zero_grad()
300 | model.eval()
301 | test_loss = 0
302 | correct = 0
303 |
304 | regime_update_dict = {
305 | "sparse": sparse_module_updates,
306 | "quantized": quantized_module_updates,
307 | "lec": lec_update,
308 | "us": us_update,
309 | }
310 | logging_dict = {
311 | "sparse": model_logging.sparse_logging,
312 | "quantized": model_logging.quantized_logging,
313 | "lec": model_logging.lec_logging,
314 | "us": model_logging.us_logging,
315 | }
316 | module_update = regime_update_dict[regime_params["regime"]]
317 | logging = logging_dict[regime_params["regime"]]
318 |
319 | model, regime_params = module_update(model, alpha=alpha, **regime_params)
320 |
321 | # optionally add the hooks needed for tracking BN accuracy stats.
322 | if regime_params.get("bn_accuracy_stats", True):
323 | hooks, mean_dict, var_dict = utils.register_bn_tracking_hooks(model)
324 |
325 | with torch.no_grad():
326 |
327 | for batch_idx, (data, target) in enumerate(val_loader):
328 | data, target = data.to(device, non_blocking=True), target.to(
329 | device, non_blocking=True
330 | )
331 |
332 | output = model(data)
333 | test_loss += criterion(output, target).item()
334 |
335 | # get the index of the max log-probability
336 | pred = output.argmax(dim=1, keepdim=True)
337 |
338 | correct += pred.eq(target.view_as(pred)).sum().item()
339 |
340 | if regime_params.get("bn_accuracy_stats", True):
341 | utils.unregister_bn_tracking_hooks(hooks)
342 | extra_metrics = utils.get_bn_accuracy_metrics(
343 | model, mean_dict, var_dict
344 | )
345 |
346 | # mean_dict and var_dict map each tracked module to its recorded statistics.
347 | else:
348 | extra_metrics = {}
349 |
350 | test_loss /= len(val_loader)
351 | test_acc = float(correct) / len(val_loader.dataset)
352 | return logging(
353 | model,
354 | test_loss,
355 | test_acc,
356 | model_type="curve",
357 | param=alpha,
358 | metric_dict=metric_dict,
359 | extra_metrics=extra_metrics,
360 | **regime_params,
361 | )
362 |
363 |
364 | def train_model(config):
365 | # Get network, data, and training parameters
366 | (
367 | net,
368 | opt_params,
369 | train_params,
370 | regime_params,
371 | data,
372 | ) = utils.network_and_params(config)
373 |
374 | # Unpack training parameters
375 | epochs, test_freq, alpha_grid = train_params
376 |
377 | # Unpack dataset
378 | trainset, trainloader, testset, testloader = data
379 |
380 | # Unpack optimization parameters
381 | criterion, optimizer, scheduler = opt_params
382 |
383 | device = "cuda" if torch.cuda.is_available() else "cpu"
384 | net = net.to(device)
385 |
386 | if device == "cuda":
387 | cudnn.benchmark = True
388 | net = torch.nn.DataParallel(net)
389 |
390 | start_epoch = 0
391 |
392 | metric_dict = {"acc": [], "alpha": []}
393 | if regime_params["regime"] == "quantized":
394 | metric_dict["num_bits"] = []
395 | if regime_params["regime"] in (
396 | "sparse",
397 | "lec",
398 | "us",
399 | "ns",
400 | ):
401 | metric_dict["sparsity"] = []
402 |
403 | save_dir = regime_params["save_dir"]
404 |
405 | for epoch in range(start_epoch, start_epoch + epochs):
406 | scheduler(epoch, None)
407 | if epoch % test_freq == 0:
408 | for alpha in alpha_grid:
409 | test(
410 | net,
411 | alpha,
412 | testloader,
413 | criterion,
414 | epoch,
415 | device,
416 | metric_dict=None,
417 | **regime_params,
418 | )
419 |
420 | model_logging.save_model_at_epoch(net, epoch, save_dir)
421 |
422 | _, regime_params = train(
423 | net,
424 | trainloader,
425 | optimizer,
426 | criterion,
427 | epoch,
428 | device,
429 | **regime_params,
430 | )
431 |
432 | # Evaluate the final model over the alpha grid, then save it and the metrics
433 | for alpha in alpha_grid:
434 | metric_dict = test(
435 | net,
436 | alpha,
437 | testloader,
438 | criterion,
439 | epoch,
440 | device,
441 | metric_dict=metric_dict,
442 | **regime_params,
443 | )
444 |
445 | model_logging.save_model_at_epoch(net, epoch + 1, save_dir)
446 | model_logging.save("test_metrics.npy", metric_dict, save_dir)
447 |
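
The regularization term applied in `train` above (the "equation 3" comment) accumulates a dot product and two squared norms over all eligible layers, then adds `beta * num**2 / (normi * normj)` to the loss, i.e. the squared cosine similarity between the weights at two sampled endpoints. A minimal standalone sketch of that quantity follows. It uses plain Python lists instead of torch tensors, and the function name `squared_cosine_penalty` and the list-of-lists layout are illustrative assumptions, not part of the repository.

```python
def squared_cosine_penalty(weights_i, weights_j):
    """Squared cosine similarity between two endpoints' weights.

    Each argument is a list of per-layer weight lists. This flat layout is a
    hypothetical stand-in for the tensors returned by curve_utils.get_weight.
    """
    num = 0.0
    norm_i = 0.0
    norm_j = 0.0
    # Accumulate over layers first, then combine once at the end,
    # mirroring the loop over model.modules() in train().
    for layer_i, layer_j in zip(weights_i, weights_j):
        num += sum(a * b for a, b in zip(layer_i, layer_j))
        norm_i += sum(a * a for a in layer_i)
        norm_j += sum(b * b for b in layer_j)
    return num ** 2 / (norm_i * norm_j)


# Orthogonal endpoints incur no penalty; parallel endpoints incur the maximum.
orthogonal = squared_cosine_penalty([[1.0, 0.0], [0.0]], [[0.0, 1.0], [0.0]])
parallel = squared_cosine_penalty([[1.0, 2.0], [3.0]], [[2.0, 4.0], [6.0]])
print(orthogonal, parallel)
```

Because the sums run over every layer before the division, the penalty measures the angle between the two full concatenated weight vectors rather than averaging per-layer angles.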
--------------------------------------------------------------------------------
/train_indep.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import random
6 |
7 | import numpy as np
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | from torch import nn
11 |
12 | import model_logging
13 | import models.networks as networks
14 | import utils
15 | from models.builder import Builder
16 | from models.networks import resprune
17 | from models.networks import utils as network_utils
18 | from models.networks import vggprune
19 | from models.quantized_modules import ConvBn2d
20 | from models.sparse_modules import SparseConv2d
21 |
22 |
23 | def sparse_module_updates(model, training=False, **regime_params):
24 | topk = regime_params["topk"]
25 | current_iteration = regime_params["current_iteration"]
26 | warmup_iterations = regime_params["warmup_iterations"]
27 | topk_lower, topk_upper = regime_params["alpha_sampling"][0:2]
28 | if regime_params["random_param"]:
29 | if training:
30 | if current_iteration < warmup_iterations:
31 | current_topk = topk_upper
32 | else:
33 | # Unbiased sampling
34 | current_topk = np.random.uniform(topk_lower, topk_upper)
35 | else:
36 | current_topk = regime_params["topk"]
37 | else:
38 | df = np.max([1 - current_iteration / warmup_iterations, 0])
39 | current_topk = topk + (1 - topk) * df
40 |
41 | for m in model.modules():
42 | if isinstance(m, SparseConv2d):
43 | setattr(m, "topk", current_topk)
44 |
45 | if training:
46 | regime_params["current_iteration"] += 1
47 |
48 | return model, regime_params
49 |
50 |
51 | def quantized_module_updates(model, training=False, **regime_params):
52 | set_bits = regime_params["num_bits"]
53 |
54 | for m in model.modules():
55 | if isinstance(m, ConvBn2d):
56 | if regime_params["random_param"]:
57 | if training:
58 | m.start_counter()
59 | if m.is_active():
60 | bits = np.arange(
61 | regime_params["min_bits"],
62 | regime_params["max_bits"] + 1,
63 | )
64 | rand_bits = np.random.choice(bits)
65 | m.update_num_bits(rand_bits)
66 | else:
67 | m.update_num_bits(regime_params["max_bits"])
68 | else:
69 | m.stop_counter()
70 | m.update_num_bits(set_bits)
71 | else:
72 | m.update_num_bits(set_bits)
73 | if training:
74 | m.start_counter()
75 | else:
76 | m.stop_counter()
77 |
78 | return model, regime_params
79 |
80 |
81 | def lec_update(model, training=False, **regime_params):
82 | # The original LEC paper does the update using a global threshold, so we
83 | # adopt that strategy here.
84 | model, regime_params = sparse_module_updates(
85 | model, training=training, **regime_params
86 | )
87 |
88 | if training:
89 | return model, regime_params
90 | else:
91 | # We create a pruned copy of the model.
92 | model_kwargs = regime_params["model_kwargs"]
93 | fresh_copy = utils.make_fresh_copy_of_pruned_network(
94 | model, model_kwargs
95 | )
96 |
97 | # The @fresh_copy needs the smallest-magnitude scale parameters of its
98 | # norm layers masked to zero (along with the matching biases).
99 | topk = regime_params["topk"]
100 | all_weights = []
101 | for m in fresh_copy.modules():
102 | if isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
103 | all_weights.append(m.weight.abs())
104 |
105 | all_weights = torch.cat(all_weights, dim=0)
106 | y, i = torch.sort(all_weights)
107 | threshold = y[int(all_weights.shape[0] * (1.0 - topk))]
108 |
109 | for m in fresh_copy.modules():
110 | if isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
111 | mask = m.weight.data.clone().abs().gt(threshold).float()
112 | m.weight.data.mul_(mask)
113 | m.bias.data.mul_(mask)
114 |
115 | # Now that we have the sparse copy, we slim it down.
116 | cfg, cfg_mask = network_utils.get_slim_configs(fresh_copy)
117 |
118 | builder = Builder(
119 | conv_type="StandardConv", bn_type=regime_params["bn_type"]
120 | )
121 |
122 | try:
123 | if isinstance(model, nn.DataParallel):
124 | model = model.module
125 |
126 | if isinstance(model, networks.cpreresnet):
127 | model_class = resprune
128 | elif isinstance(model, networks.vgg.vgg):
129 | model_class = vggprune
130 | else:
131 | raise ValueError(
132 | "Model {} is not supported for LEC.".format(model)
133 | )
134 |
135 | _, slimmed_network = model_class.get_slimmed_network(
136 | fresh_copy.module,
137 | {"builder": builder, "block_builder": builder, **model_kwargs},
138 | cfg,
139 | cfg_mask,
140 | )
141 | except IndexError:
142 | # This is the error if we eliminate a whole layer.
143 | print(
144 | "Something went wrong during LEC - most likely, an entire "
145 | "layer was deleted. Using @fresh_copy."
146 | )
147 | slimmed_network = fresh_copy
148 | num_parameters = sum(
149 | [param.nelement() for param in slimmed_network.parameters()]
150 | )
151 |
152 | # NOTE: DO NOT use @model here, since it has too many extra buffers.
153 | total_params = sum(
154 | [param.nelement() for param in fresh_copy.parameters()]
155 | )
156 | regime_params["sparsity"] = (
157 | total_params - num_parameters
158 | ) / total_params
159 |
160 | print(f"Got sparsity level of {regime_params['sparsity']}")
161 |
162 | return slimmed_network, regime_params
163 |
164 |
165 | def ns_update(model, training=False, **regime_params):
166 | current_width = regime_params["width_factor"]
167 |
168 | for module in model.modules():
169 | if hasattr(module, "width_factor"):
170 | module.width_factor = current_width
171 |
172 | return model, regime_params
173 |
174 |
175 | def us_update(*args, **kwargs):
176 | return ns_update(*args, **kwargs)
177 |
178 |
179 | def train(
180 | model, train_loader, optimizer, criterion, epoch, device, **regime_params
181 | ):
182 | model.zero_grad()
183 | model.train()
184 | avg_loss = 0.0
185 |
186 | regime_update_dict = {
187 | "sparse": sparse_module_updates,
188 | "quantized": quantized_module_updates,
189 | "lec": lec_update,
190 | "ns": ns_update,
191 | "us": us_update, # [sic]
192 | }
193 | module_update = regime_update_dict[regime_params["regime"]]
194 |
195 | for batch_idx, (data, target) in enumerate(train_loader):
196 | data, target = data.to(device, non_blocking=True), target.to(
197 | device, non_blocking=True
198 | )
199 |
200 | # Special case: accumulate the loss over several forward passes, one per width factor.
201 | if regime_params["regime"] == "ns":
202 | loss = 0
203 | optimizer.zero_grad()
204 | for width_factor in regime_params["width_factors_list"]:
205 | regime_params["width_factor"] = width_factor
206 |
207 | model, regime_params = module_update(
208 | model, True, **regime_params
209 | )
210 | output = model(data)
211 | loss += criterion(output, target)
212 | elif regime_params["regime"] == "us":
213 | loss = 0
214 | optimizer.zero_grad()
215 |
216 | low_factor, high_factor = regime_params["width_factor_limits"]
217 | n_samples = regime_params["width_factor_samples"]
218 |
219 | assert n_samples >= 2, f"Require n_samples>=2, got {n_samples}"
220 |
221 | width_factors = [low_factor, high_factor]
222 | for i in range(n_samples - 2):
223 | width_factors.append(random.uniform(low_factor, high_factor))
224 |
225 | for width_factor in width_factors:
226 | regime_params["width_factor"] = width_factor
227 |
228 | model, regime_params = module_update(
229 | model, True, **regime_params
230 | )
231 | output = model(data)
232 | loss += criterion(output, target)
233 | else:
234 | model, regime_params = module_update(model, True, **regime_params)
235 |
236 | optimizer.zero_grad()
237 | output = model(data)
238 | loss = criterion(output, target)
239 |
240 | loss.backward()
241 | if regime_params.get("bn_update_factor") is not None:
242 | apply_norm_regularization(model, regime_params["bn_update_factor"])
243 | optimizer.step()
244 |
245 | avg_loss += loss.item()
246 |
247 | log_interval = 10
248 |
249 | if batch_idx % log_interval == 0:
250 | num_samples = batch_idx * len(data)
251 | dataset_size = len(train_loader.dataset)
252 | percent_complete = 100.0 * batch_idx / len(train_loader)
253 |
254 | predicted_labels = output.argmax(dim=1)
255 | corrects = (
256 | predicted_labels == target
257 | ).float().sum() / target.numel()
258 |
259 | print(
260 | f"Train Epoch: {epoch} [{num_samples}/{dataset_size} ({percent_complete:.0f}%)]\t"
261 | f"Loss: {loss.item():.6f} Correct: {corrects.item():.4f}"
262 | )
263 |
264 | model.apply(lambda m: setattr(m, "return_feats", False))
265 | avg_loss = avg_loss / len(train_loader)
266 |
267 | return avg_loss, regime_params
268 |
269 |
270 | def apply_norm_regularization(model, s_factor):
271 | count = 0
272 | for m in model.modules():
273 | if isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
274 | count += 1
275 | m.weight.grad.data.add_(s_factor * torch.sign(m.weight.data)) # L1
276 | if count == 0:
277 | raise ValueError("Didn't adjust any Norms")
278 |
279 |
280 | def test(
281 | model,
282 | param,
283 | val_loader,
284 | criterion,
285 | epoch,
286 | device,
287 | metric_dict=None,
288 | **regime_params,
289 | ):
290 | # Set eval param
291 | if regime_params["regime"] in ("sparse", "lec"):
292 | regime_params["topk"] = param
293 | elif regime_params["regime"] in ("ns", "us"):
294 | regime_params["width_factor"] = param
295 | else:
296 | regime_params["num_bits"] = param
297 |
298 | metric_dict = _test(
299 | model,
300 | param,
301 | val_loader,
302 | criterion,
303 | epoch,
304 | device,
305 | metric_dict=metric_dict,
306 | **regime_params,
307 | )
308 | return metric_dict
309 |
310 |
311 | def _test(
312 | model,
313 | param,
314 | val_loader,
315 | criterion,
316 | epoch,
317 | device,
318 | metric_dict=None,
319 | **regime_params,
320 | ):
321 |
322 | model.zero_grad()
323 | model.eval()
324 | test_loss = 0
325 | correct = 0
326 |
327 | regime_update_dict = {
328 | "sparse": sparse_module_updates,
329 | "quantized": quantized_module_updates,
330 | "lec": lec_update,
331 | "ns": ns_update,
332 | "us": us_update,
333 | }
334 | logging_dict = {
335 | "sparse": model_logging.sparse_logging,
336 | "quantized": model_logging.quantized_logging,
337 | "lec": model_logging.lec_logging,
338 | "ns": model_logging.ns_logging,
339 | "us": model_logging.us_logging,
340 | }
341 | module_update = regime_update_dict[regime_params["regime"]]
342 | logging = logging_dict[regime_params["regime"]]
343 |
344 | model, regime_params = module_update(model, **regime_params)
345 |
346 | # optionally add the hooks needed for tracking BN accuracy stats.
347 | if regime_params.get("bn_accuracy_stats", True):
348 | hooks, mean_dict, var_dict = utils.register_bn_tracking_hooks(model)
349 |
350 | with torch.no_grad():
351 | for batch_idx, (data, target) in enumerate(val_loader):
352 |
353 | data, target = data.to(device, non_blocking=True), target.to(
354 | device, non_blocking=True
355 | )
356 |
357 | output = model(data)
358 | test_loss += criterion(output, target).item()
359 |
360 | # get the index of the max log-probability
361 | pred = output.argmax(dim=1, keepdim=True)
362 |
363 | correct += pred.eq(target.view_as(pred)).sum().item()
364 |
365 | if regime_params.get("bn_accuracy_stats", True):
366 | utils.unregister_bn_tracking_hooks(hooks)
367 | extra_metrics = utils.get_bn_accuracy_metrics(
368 | model, mean_dict, var_dict
369 | )
370 |
371 | # mean_dict and var_dict map each tracked module to its recorded statistics.
372 | else:
373 | extra_metrics = {}
374 |
375 | test_loss /= len(val_loader)
376 | test_acc = float(correct) / len(val_loader.dataset)
377 | return logging(
378 | model,
379 | test_loss,
380 | test_acc,
381 | model_type="indep",
382 | param=param,
383 | metric_dict=metric_dict,
384 | extra_metrics=extra_metrics,
385 | **regime_params,
386 | )
387 |
388 |
389 | def train_model(config):
390 | # Get network, data, and training parameters
391 | (
392 | net,
393 | opt_params,
394 | train_params,
395 | regime_params,
396 | data,
397 | ) = utils.network_and_params(config)
398 |
399 | # Unpack training parameters
400 | epochs, test_freq, _ = train_params
401 |
402 | # Unpack dataset
403 | trainset, trainloader, testset, testloader = data
404 |
405 | # Unpack optimizer parameters
406 | criterion, optimizer, scheduler = opt_params
407 |
408 | device = "cuda" if torch.cuda.is_available() else "cpu"
409 | net = net.to(device)
410 |
411 | if device == "cuda":
412 | cudnn.benchmark = True
413 | net = torch.nn.DataParallel(net)
414 |
415 | start_epoch = 0
416 |
417 | metric_dict = {"acc": [], "alpha": []}
418 | if regime_params["regime"] == "quantized":
419 | metric_dict["num_bits"] = []
420 | if regime_params["regime"] in (
421 | "sparse",
422 | "lec",
423 | "us",
424 | "ns",
425 | ):
426 | metric_dict["sparsity"] = []
427 |
428 | # Get evaluation parameter grid
429 | eval_param_grid = regime_params["eval_param_grid"]
430 |
431 | save_dir = regime_params["save_dir"]
432 |
433 | # Training loop
434 | for epoch in range(start_epoch, start_epoch + epochs):
435 | scheduler(epoch, None)
436 | if epoch % test_freq == 0:
437 | for param in eval_param_grid:
438 | test(
439 | net,
440 | param,
441 | testloader,
442 | criterion,
443 | epoch,
444 | device,
445 | metric_dict=None,
446 | **regime_params,
447 | )
448 |
449 | model_logging.save_model_at_epoch(net, epoch, save_dir)
450 |
451 | _, regime_params = train(
452 | net,
453 | trainloader,
454 | optimizer,
455 | criterion,
456 | epoch,
457 | device,
458 | **regime_params,
459 | )
460 |
461 | # Evaluate the final model over the evaluation grid, then save it and the metrics
462 | for param in eval_param_grid:
463 | metric_dict = test(
464 | net,
465 | param,
466 | testloader,
467 | criterion,
468 | epoch,
469 | device,
470 | metric_dict=metric_dict,
471 | **regime_params,
472 | )
473 |
474 | model_logging.save_model_at_epoch(net, epoch + 1, save_dir)
475 | model_logging.save("test_metrics.npy", metric_dict, save_dir)
476 |
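
When `random_param` is false, `sparse_module_updates` above anneals the kept fraction linearly from 1 (dense) down to the target `topk` over `warmup_iterations`, then holds it there. The sketch below isolates that schedule as a pure function so the values can be inspected. The function name `topk_at_iteration` and the sample numbers are illustrative assumptions; the formula itself is the one in the listing.

```python
def topk_at_iteration(target_topk, current_iteration, warmup_iterations):
    """Linear warmup: start fully dense (1.0) and decay to target_topk."""
    # Fraction of the warmup still remaining, clamped at zero once it ends.
    decay = max(1 - current_iteration / warmup_iterations, 0)
    return target_topk + (1 - target_topk) * decay


# Hypothetical run: target of 0.5 with a 100-iteration warmup.
schedule = [topk_at_iteration(0.5, it, 100) for it in (0, 50, 100, 200)]
print(schedule)
```

With these inputs the schedule starts at 1.0, passes through 0.75 halfway, and stays at 0.5 from iteration 100 onward.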
--------------------------------------------------------------------------------
/train_quantized.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import args
6 | import main
7 |
8 | if __name__ == "__main__":
9 | args = args.quantized_args()
10 | main.train(args, "quantized")
11 |
--------------------------------------------------------------------------------
/train_structured.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import args
6 | import main
7 |
8 | if __name__ == "__main__":
9 | args = args.structured_args()
10 | main.train(args, "structured_sparsity")
11 |
--------------------------------------------------------------------------------
/train_unstructured.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import args
6 | import main
7 |
8 | if __name__ == "__main__":
9 | args = args.unstructured_args()
10 | main.train(args, "unstructured_sparsity")
11 |
--------------------------------------------------------------------------------
/training_params.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | # Default training hyperparameters
6 |
7 | DEFAULT_PARAMS = {
8 | "epochs": 200,
9 | "test_freq": 20,
10 | "batch_size": 128,
11 | "momentum": 0.9,
12 | "weight_decay": 0.0005,
13 | }
14 |
15 | DEFAULT_IMAGENET_PARAMS = {
16 | **DEFAULT_PARAMS,
17 | "epochs": 90,
18 | "test_freq": 1,
19 | "batch_size": 128,
20 | "weight_decay": 0.00005,
21 | }
22 |
23 | VGG_IMAGENET_PARAMS = {
24 | **DEFAULT_IMAGENET_PARAMS,
25 | "batch_size": 256,
26 | }
27 |
28 |
29 | def model_data_params(args):
30 | model = args.model
31 | dataset = args.dataset
32 |
33 | if dataset == "imagenet":
34 | if model == "vgg19":
35 | return VGG_IMAGENET_PARAMS
36 | elif model == "resnet18":
37 | return DEFAULT_IMAGENET_PARAMS
38 | else:
39 | raise NotImplementedError(
40 | f"No training parameters for {model}/{dataset}"
41 | )
42 | elif "cifar" in dataset:
43 | return DEFAULT_PARAMS
44 | else:
45 | raise NotImplementedError(
46 | f"No training parameters for {model}/{dataset}"
47 | )
48 |
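
The parameter dictionaries above are layered with `**` unpacking: later keys override earlier ones, and keys that are not overridden (such as `momentum`) are inherited. A small sketch of the resulting values, with the dictionary contents copied from the listing and the lowercase variable names chosen here for illustration only:

```python
# Values copied from DEFAULT_PARAMS, DEFAULT_IMAGENET_PARAMS, VGG_IMAGENET_PARAMS.
default_params = {
    "epochs": 200,
    "test_freq": 20,
    "batch_size": 128,
    "momentum": 0.9,
    "weight_decay": 0.0005,
}

# Keys listed after the unpacked dict replace the inherited values.
imagenet_params = {
    **default_params,
    "epochs": 90,
    "test_freq": 1,
    "batch_size": 128,
    "weight_decay": 0.00005,
}

vgg_imagenet_params = {**imagenet_params, "batch_size": 256}

# momentum is inherited unchanged; batch_size is overridden at the last layer.
print(vgg_imagenet_params)
```

Unpacking builds a new dictionary each time, so the base dictionaries are left unmodified by the overrides.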
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import collections
6 | import os
7 | import time
8 | from typing import Callable
9 | from typing import Dict
10 | from typing import List
11 | from typing import Tuple
12 | from typing import Union
13 |
14 | import numpy as np
15 | import torch
16 | import torch.backends.cudnn as cudnn
17 | import torch.nn as nn
18 | import torch.optim as optim
19 | import torchvision
20 | import torchvision.transforms as transforms
21 | import yaml
22 |
23 | import schedulers
24 | from models import modules
25 | from models import networks
26 | from models.builder import Builder
27 | from models.sparse_modules import SparseConv2d
28 |
29 |
30 | def get_cifar10_data(batch_size):
31 | print("==> Preparing CIFAR-10 data...")
32 | transform_train = transforms.Compose(
33 | [
34 | transforms.RandomCrop(32, padding=4),
35 | transforms.RandomHorizontalFlip(),
36 | transforms.ToTensor(),
37 | transforms.Normalize(
38 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
39 | ),
40 | ]
41 | )
42 |
43 | transform_test = transforms.Compose(
44 | [
45 | transforms.ToTensor(),
46 | transforms.Normalize(
47 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
48 | ),
49 | ]
50 | )
51 |
52 | trainset = torchvision.datasets.CIFAR10(
53 | root="./data", train=True, download=True, transform=transform_train
54 | )
55 | trainloader = torch.utils.data.DataLoader(
56 | trainset,
57 | batch_size=batch_size,
58 | shuffle=True,
59 | num_workers=2,
60 | pin_memory=True,
61 | )
62 |
63 | testset = torchvision.datasets.CIFAR10(
64 | root="./data", train=False, download=True, transform=transform_test
65 | )
66 | # Note that we analyze BatchNorm statistics when validating, so we
67 | # *must* shuffle the validation set.
68 | testloader = torch.utils.data.DataLoader(
69 | testset,
70 | batch_size=batch_size,
71 | shuffle=True,
72 | num_workers=2,
73 | pin_memory=True,
74 | )
75 |
76 | return trainset, trainloader, testset, testloader
77 |
78 |
79 | def get_imagenet_data(data_dir, batch_size):
80 | print("==> Preparing ImageNet data...")
81 |
82 | traindir = os.path.join(data_dir, "training")
83 | valdir = os.path.join(data_dir, "validation")
84 | normalize = transforms.Normalize(
85 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
86 | )
87 |
88 | trainset = torchvision.datasets.ImageFolder(
89 | traindir,
90 | transforms.Compose(
91 | [
92 | transforms.RandomResizedCrop(224),
93 | transforms.RandomHorizontalFlip(),
94 | transforms.ToTensor(),
95 | normalize,
96 | ]
97 | ),
98 | )
99 |
100 | trainloader = torch.utils.data.DataLoader(
101 | trainset,
102 | batch_size=batch_size,
103 | shuffle=True,
104 | num_workers=16,
105 | pin_memory=True,
106 | )
107 |
108 | testset = torchvision.datasets.ImageFolder(
109 | valdir,
110 | transforms.Compose(
111 | [
112 | transforms.Resize(256),
113 | transforms.CenterCrop(224),
114 | transforms.ToTensor(),
115 | normalize,
116 | ]
117 | ),
118 | )
119 |
120 | # Note that we analyze BatchNorm statistics when validating, so we
121 | # *must* shuffle the validation set.
122 | testloader = torch.utils.data.DataLoader(
123 | testset,
124 | batch_size=batch_size,
125 | shuffle=True,
126 | num_workers=16,
127 | pin_memory=True,
128 | )
129 |
130 | return trainset, trainloader, testset, testloader
131 |
132 |
133 | def get_yaml_config(config_file):
134 | print(f"Reading config {config_file}")
135 | with open(config_file) as f:
136 | config = yaml.load(f, Loader=yaml.FullLoader)
137 |
138 | return config
139 |
140 |
141 | def print_train_params(config, setting, method, norm, save_dir):
142 | params = config["parameters"]
143 |
144 | model = params["model_config"]["model_class"]
145 | dataset = params["dataset"]
146 |
147 | model_name_map = {
148 | "cpreresnet20": "cPreResNet20",
149 | "resnet18": "ResNet18",
150 | "vgg19": "VGG19",
151 | }
152 |
153 | method_name_map = {
154 | "lcs_l": "LCS+L",
155 | "lcs_p": "LCS+P",
156 | "ns": "NS",
157 | "us": "US",
158 | "lec": "LEC",
159 | "target_topk": "TopK Target",
160 | "target_bit_width": "Bit Width Target",
161 | }
162 |
163 | dataset_name_map = {"cifar10": "CIFAR-10", "imagenet": "ImageNet"}
164 |
165 | setting_name_map = {
166 | "unstructured_sparsity": "Unstructured sparsity",
167 | "structured_sparsity": "Structured sparsity",
168 | "quantized": "Quantized",
169 | }
170 |
171 | msg = f"{setting_name_map[setting]} ({method_name_map[method]}) training:"
172 | msg += f" {model_name_map[model]} on {dataset_name_map[dataset]} w/ {norm}."
173 |
174 | if method == "target_topk":
175 | topk_target = params["topk"]
176 | msg += f" TopK target: {topk_target}."
177 | elif method == "target_bit_width":
178 | bit_width_target = params["num_bits"]
179 | msg += f" Bit width target: {bit_width_target}."
180 |
181 | print()
182 | print(msg)
183 | print(f"Saving to {save_dir}.")
184 | print()
185 | time.sleep(5)
186 |
187 |
188 | def create_save_dir(base_save, method, model, dataset, norm, use_default=True):
189 | if use_default:
190 | save_dir = f"{base_save}/{method}/{model}/{dataset}/{norm}"
191 | else:
192 | save_dir = base_save
193 |
194 | if not os.path.isdir(save_dir):
195 | os.makedirs(save_dir, exist_ok=True)
196 |
197 | return save_dir
198 |
199 |
200 | def network_and_params(config=None):
201 | """
202 | Returns the network and training parameters for the specified model type.
203 | """
204 |
205 | model_config = config["parameters"].get("model_config")
206 | model_class = model_config["model_class"]
207 | model_class_dict = {
208 | "cpreresnet20": networks.cpreresnet,
209 | "resnet18": networks.ResNet18,
210 | "vgg19": networks.vgg19,
211 | }
212 | if model_config is not None:
213 | # Get the class.
214 | if model_class in model_class_dict:
215 | model_class = model_class_dict[model_class]
216 | else:
217 | raise NotImplementedError(
218 | f"Invalid model_class={model_config['model_class']}"
219 | )
220 | if "model_kwargs" in model_config:
221 | extra_model_kwargs = model_config["model_kwargs"]
222 | else:
223 | extra_model_kwargs = {}
224 | else:
225 | extra_model_kwargs = {}
226 |
227 | # General params
228 | epochs = config["parameters"].get("epochs", 200)
229 |
230 | test_freq = config["parameters"].get("test_freq", 20)
231 | batch_size = config["parameters"].get("batch_size", 128)
232 | learning_rate = config["parameters"].get("learning_rate", 0.01)
233 | momentum = config["parameters"].get("momentum", 0.9)
234 | weight_decay = config["parameters"].get("weight_decay", 0.0005)
235 | warmup_budget = config["parameters"].get("warmup_budget", 80) / 100.0
236 | dataset = config["parameters"].get("dataset", "cifar10")
237 | alpha_grid = config["parameters"].get("alpha_grid", None)
238 |
239 | regime = config["parameters"]["regime"]
240 |
241 | if dataset == "cifar10":
242 | data = get_cifar10_data(batch_size)
243 | elif dataset == "imagenet":
244 | imagenet_dir = config["parameters"]["dataset_dir"]
245 | data = get_imagenet_data(imagenet_dir, batch_size)
246 | else:
247 | raise ValueError(f"Dataset {dataset} not supported")
248 |
249 | train_size = len(data[0])
250 | warmup_iterations = np.ceil(
251 | warmup_budget * epochs * train_size / batch_size
252 | )
253 |
254 | # Get model layers
255 | conv_type = config["parameters"]["conv_type"]
256 | bn_type = config["parameters"]["bn_type"]
257 | block_conv_type = config["parameters"]["block_conv_type"]
258 | block_bn_type = config["parameters"]["block_bn_type"]
259 |
260 | # Get regime-specific parameters
261 | regime_params = config["parameters"].get("regime_params", {})
262 | regime_params["regime"] = regime
263 | builder_parms = {}
264 | block_builder_params = {}
265 |
266 | # Append dataset to extra_model_kwargs args
267 | extra_model_kwargs["dataset"] = dataset
268 |
269 | if regime == "sparse":
270 | if "Sparse" not in block_conv_type:
271 | raise ValueError(
272 | "Regime set to sparse but non-sparse convolution layer received..."
273 | )
274 | regime_params["topk"] = config["parameters"].get("topk", 0.0)
275 | regime_params["current_iteration"] = 0
276 | regime_params["warmup_iterations"] = warmup_iterations
277 | regime_params["alpha_sampling"] = config["parameters"].get(
278 | "alpha_sampling", [0.025, 1, 0]
279 | )
280 |
281 | method = config["parameters"].get("method", "topk")
282 | block_builder_params["method"] = method
283 | if "Sparse" in conv_type:
284 | builder_parms["method"] = method
285 |
286 | elif regime == "lec":
287 | regime_params["topk"] = config["parameters"].get("topk", 0.0)
288 | regime_params["current_iteration"] = 0
289 | regime_params["warmup_iterations"] = warmup_iterations
290 | regime_params["alpha_sampling"] = config["parameters"].get(
291 | "alpha_sampling", [0, 1, 0]
292 | )
293 | regime_params["model_kwargs"] = {
294 | "dataset": dataset,
295 | **extra_model_kwargs,
296 | }
297 | regime_params["bn_update_factor"] = config["parameters"].get(
298 | "bn_update_factor", 0
299 | )
300 |
301 | regime_params["bn_type"] = config["parameters"]["bn_type"]
302 |
303 | elif regime == "ns":
304 | width_factors_list = config["parameters"]["builder_kwargs"][
305 | "width_factors_list"
306 | ]
307 | regime_params["width_factors_list"] = width_factors_list
308 |
309 | builder_parms["pass_first_last"] = True
310 |
311 | block_builder_params["pass_first_last"] = True
312 |
313 | if config["parameters"]["block_conv_type"] != "AdaptiveConv2d":
314 | block_builder_params["width_factors_list"] = width_factors_list
315 | builder_parms["width_factors_list"] = width_factors_list
316 |
317 | if "BN" in config["parameters"]["bn_type"]:
318 | norm_kwargs = config["parameters"].get("norm_kwargs", {})
319 |
320 | builder_parms["norm_kwargs"] = {
321 | "width_factors_list": width_factors_list,
322 | **norm_kwargs,
323 | }
324 | block_builder_params["norm_kwargs"] = {
325 | "width_factors_list": width_factors_list,
326 | **norm_kwargs,
327 | }
328 |
329 | regime_params["bn_type"] = config["parameters"]["bn_type"]
330 |
331 | elif regime == "us":
332 | builder_parms["pass_first_last"] = True
333 | block_builder_params["pass_first_last"] = True
334 |
335 | if "BN" in config["parameters"]["bn_type"]:
336 | norm_kwargs = config["parameters"]["norm_kwargs"]
337 | assert "width_factors_list" in norm_kwargs
338 |
339 | block_builder_params["norm_kwargs"] = norm_kwargs
340 | builder_parms["norm_kwargs"] = norm_kwargs
341 |
342 | elif regime == "quantized":
343 | if "ConvBn2d" not in block_conv_type:
344 | raise ValueError(
345 | "Regime set to quantized but non-quantized convolution layer received..."
346 | )
347 | block_builder_params["num_bits"] = config["parameters"].get(
348 | "num_bits", 8
349 | )
350 | block_builder_params["iteration_delay"] = warmup_iterations
351 |
352 | if conv_type == "ConvBn2d":
353 | builder_parms["num_bits"] = config["parameters"].get("num_bits", 8)
354 | builder_parms["iteration_delay"] = warmup_iterations
355 |
356 | regime_params["min_bits"] = config["parameters"].get("min_bits", 2)
357 | regime_params["max_bits"] = config["parameters"].get("max_bits", 8)
358 | regime_params["num_bits"] = config["parameters"].get("num_bits", 8)
359 | regime_params["discrete"] = config["parameters"].get(
360 | "discrete_alpha_map", False
361 | )
362 |
363 | regime_params["is_standard"] = config["parameters"].get(
364 | "is_standard", False
365 | )
366 | regime_params["random_param"] = config["parameters"].get(
367 | "random_param", False
368 | )
369 | regime_params["num_points"] = config["parameters"].get("num_points", 0)
370 |
371 | # Evaluation parameters for independent models
372 | regime_params["eval_param_grid"] = config["parameters"].get(
373 | "eval_param_grid", None
374 | )
375 |
376 | # If norm_kwargs haven't been set and they are present in the config, add
377 | # them to builder_parms and block_builder_params.
378 | if "norm_kwargs" not in builder_parms:
379 | builder_parms["norm_kwargs"] = config["parameters"].get(
380 | "norm_kwargs", {}
381 | )
382 | if "norm_kwargs" not in block_builder_params:
383 | block_builder_params["norm_kwargs"] = config["parameters"].get(
384 | "norm_kwargs", {}
385 | )
386 |
387 | # Construct network
388 | builder = Builder(conv_type=conv_type, bn_type=bn_type, **builder_parms)
389 | block_builder = Builder(
390 | block_conv_type, block_bn_type, **block_builder_params
391 | )
392 |
393 | net = model_class(
394 | builder=builder, block_builder=block_builder, **extra_model_kwargs
395 | )
396 |
397 | # Input size
398 | regime_params["input_size"] = get_input_size(dataset)
399 |
400 | # Save directory
401 | regime_params["save_dir"] = config["parameters"]["save_dir"]
402 |
403 | criterion = nn.CrossEntropyLoss()
404 | optimizer = optim.SGD(
405 | net.parameters(),
406 | lr=learning_rate,
407 | momentum=momentum,
408 | weight_decay=weight_decay,
409 | )
410 |
411 | scheduler = schedulers.cosine_lr(
412 | optimizer, learning_rate, warmup_length=5, epochs=epochs
413 | )
414 |
415 | train_params = [epochs, test_freq, alpha_grid]
416 | opt_params = [criterion, optimizer, scheduler]
417 |
418 | print(f"Got network:\n{net}")
419 |
420 | return net, opt_params, train_params, regime_params, data
421 |
422 |
423 | def get_net(model_dir, net, num_gpus=1):
424 |
425 | device = "cuda" if torch.cuda.is_available() else "cpu"
426 |
427 | if device == "cuda":
428 | cudnn.benchmark = True
429 | state_dict = torch.load(model_dir)
430 | else:
431 | state_dict = torch.load(model_dir, map_location=torch.device("cpu"))
432 |
433 | # Set device ids for DataParallel.
434 | if device == "cuda" and num_gpus > 0:
435 | device_ids = list(range(num_gpus))
436 | else:
437 | device_ids = None
438 | network = torch.nn.DataParallel(net, device_ids=device_ids)
439 | network.load_state_dict(state_dict)
440 |
441 | return network
442 |
443 |
444 | def get_sparsity_rate(model) -> float:
445 | total_params = 0
446 | sparse_params = 0
447 | for module in model.modules():
448 | if isinstance(module, SparseConv2d):
449 | sparse_weight = module.apply_sparsity(module.weight)
450 | total_params += sparse_weight.numel()
451 | sparse_params += (
452 | (sparse_weight == 0).float().sum().item()
453 | ) # changed fuse weight to sparse weight
454 | elif isinstance(
455 | module, (nn.Linear, nn.Conv2d)
456 | ): # Make sure we also count contributions from non-sparse elements.
457 | total_params += module.weight.numel()
458 | # Note: bias parameters are a drop in the bucket.
459 | sparsity_rate = sparse_params / np.max([total_params, 1])
460 | return sparsity_rate
461 |
462 |
463 | def apply_sparsity(
464 | topk, weight: torch.Tensor, return_scale_factors=False, *, method
465 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
466 | if method == "topk":
467 | return apply_topk(
468 | topk, weight, return_scale_factors=return_scale_factors
469 | )
470 | else:
471 | raise NotImplementedError(f"Invalid sparsity method {method}")
472 |
473 |
474 | def apply_topk(topk: float, weight: torch.Tensor, return_scale_factors=False):
475 | """
476 | Given a weight tensor, retains the top @topk fraction of the weights (by
477 | absolute value) and multiplies the rest by 0.
478 | Inputs:
479 | - weight: A weight tensor, e.g., self.weight
480 | """
481 | # Retain only the topk weights, multiplying the rest by 0.
482 | frac_to_zero = 1 - topk
483 | with torch.no_grad():
484 | flat_weight = weight.flatten()
485 | # Want to convert it away from a special tensor, hence the float() call.
486 | _, idx = flat_weight.float().abs().sort()
487 | # @idx is a @special_tensors._SpecialTensor, but we need to convert it
488 | # to a normal tensor for indexing to work properly.
489 | idx = torch.tensor(idx, requires_grad=False)
490 | f = int(frac_to_zero * weight.numel())
491 | scale_factors = torch.ones_like(flat_weight, requires_grad=False)
492 | scale_factors[idx[:f]] = 0
493 | scale_factors = scale_factors.view_as(weight)
494 |
495 | ret = weight * scale_factors
496 |
497 | if return_scale_factors:
498 | return ret, scale_factors
499 |
500 | return ret
501 |
502 |
503 | def apply_in_topk_reg(model, apply_to_bias=False):
504 | loss = None
505 | count = 0
506 | for m in model.modules():
507 | if isinstance(m, modules.SubspaceIN):
508 | count += 1
509 | w, b = m.get_weight()
510 |
511 | if loss is None:
512 | loss = torch.tensor(0.0, device=w.device)
513 | loss += w.abs().sum()
514 |
515 | if apply_to_bias:
516 | loss += b.abs().sum()
517 |
518 | if count == 0:
519 | raise ValueError("Didn't adjust any Norms")
520 |
521 | return loss
522 |
523 |
524 | def get_norm_type_string(model):
525 | bn_count = 0
526 | in_count = 0
527 | for m in model.modules():
528 | if isinstance(m, nn.BatchNorm2d):
529 | bn_count += 1
530 | elif isinstance(m, nn.InstanceNorm2d):
531 | in_count += 1
532 |
533 | if bn_count > 0:
534 | assert in_count == 0, "Got both BN and IN"
535 | return "StandardBN"
536 | elif in_count > 0:
537 | assert bn_count == 0, "Got both BN and IN"
538 | return "StandardIN"
539 | else:
540 | raise ValueError("No norm layers detected.")
541 |
542 |
543 | def make_fresh_copy_of_pruned_network(model: nn.Module, model_kwargs: Dict):
544 | norm_type_string = get_norm_type_string(model)
545 | builder = Builder(conv_type="StandardConv", bn_type=norm_type_string)
546 |
547 | copy = type(model.module)(
548 | builder=builder, block_builder=builder, **model_kwargs
549 | ) # type: nn.Module
550 | # Need to move @copy to GPU before moving to DataParallel.
551 | if next(model.parameters()).is_cuda:
552 | copy = copy.cuda()
553 | copy = nn.DataParallel(copy)
554 |
555 | state_dict = model.state_dict()
556 | del_me = []
557 | for k, v in state_dict.items():
558 | if k.endswith("1"):
559 | del_me.append(k)
560 |
561 | for elem in del_me:
562 | del state_dict[elem]
563 |
564 | copy.load_state_dict(state_dict)
565 |
566 | # The only part we should need to fix are modules with a get_weight()
567 | # function.
568 | name_to_copy = {name: module for name, module in copy.named_modules()}
569 |
570 | for name, module in model.named_modules():
571 | if hasattr(module, "get_weight"):
572 | print(f"Adjusting weight at module {name}")
573 |
574 | pieces = module.get_weight()
575 |
576 | if len(pieces) == 1:
577 | name_to_copy[name].weight.data = pieces
578 | else:
579 | assert len(pieces) == 2, f"Invalid len(pieces)={len(pieces)}"
580 | name_to_copy[name].weight.data = pieces[0]
581 | name_to_copy[name].bias.data = pieces[1]
582 |
583 | return copy
584 |
585 |
586 | def get_input_size(dataset):
587 | return {
588 | "cifar10": 32,
589 | "imagenet": 224,
590 | }[dataset]
591 |
592 |
593 | def register_bn_tracking_hook(module, mean_dict, var_dict, name):
594 | def forward_hook(_, x, __):
595 | x = x[0]
596 | reshaped_x = x.permute(1, 0, 2, 3).reshape(x.shape[1], -1)
597 |
598 | mean_dict[name].append(reshaped_x.mean(dim=1).detach().cpu())
599 | var_dict[name].append(reshaped_x.var(dim=1).detach().cpu())
600 |
601 | return module.register_forward_hook(forward_hook)
602 |
603 |
604 | def register_bn_tracking_hooks(model: nn.Module):
605 | mean_dict = collections.defaultdict(list)
606 | var_dict = collections.defaultdict(list)
607 |
608 | hooks = []
609 | for name, module in model.named_modules():
610 | # NOTE: we omit GroupNorm from this check, because it doesn't track
611 | # running stats (there's no option to do so in PyTorch).
612 | if (
613 | isinstance(module, nn.modules.batchnorm._NormBase)
614 | and module.track_running_stats
615 | ):
616 | hooks.append(
617 | register_bn_tracking_hook(module, mean_dict, var_dict, name)
618 | )
619 | return hooks, mean_dict, var_dict
620 |
621 |
622 | def unregister_bn_tracking_hooks(hooks: List[Callable]):
623 | for hook in hooks:
624 | hook.remove()
625 |
626 |
627 | def get_bn_accuracy_metrics(model: nn.Module, mean_dict: Dict, var_dict: Dict):
628 | """Determine how accurate the running_mean and running_var of the BatchNorms
629 | are.
630 |
631 | Note that the test set should be shuffled during training, or these
632 | statistics won't be valid.
633 |
634 | Arguments:
635 | model: the network.
636 | mean_dict: A dictionary that looks like:
637 | {'layer_name': [batch1_mean, batch2_mean, ...]}
638 | var_dict: Similar to mean_dict, but with variances.
639 | """
640 | mean_results = collections.defaultdict(list)
641 | var_results = collections.defaultdict(list)
642 |
643 | name_to_module = {name: module for name, module in model.named_modules()}
644 |
645 | for name in mean_dict.keys():
646 | module = name_to_module[name]
647 |
648 | running_mean = module.running_mean.detach().cpu()
649 | assert isinstance(mean_dict[name], list)
650 | for batch_result in mean_dict[name]:
651 | num_channels = batch_result.shape[0]
652 | mean_abs_diff = (
653 | (batch_result - running_mean[:num_channels]).abs().mean().item()
654 | )
655 | mean_results[name].append(mean_abs_diff)
656 |
657 | running_var = module.running_var.detach().cpu()
658 | assert isinstance(var_dict[name], list)
659 | for batch_result in var_dict[name]:
660 | num_channels = batch_result.shape[0]
661 | mean_abs_diff = (
662 | (batch_result - running_var[:num_channels]).abs().mean().item()
663 | )
664 | var_results[name].append(mean_abs_diff)
665 |
666 | # For each layer, record the mean and std of the average deviations, for
667 | # both running_mean and running_var.
668 | ret = {}
669 | for name, stats in mean_results.items():
670 | ret[f"{name}_running_mean_MAD_mean"] = torch.tensor(stats).mean().item()
671 | ret[f"{name}_running_mean_MAD_std"] = torch.tensor(stats).std().item()
672 | for name, stats in var_results.items():
673 | ret[f"{name}_running_var_MAD_mean"] = torch.tensor(stats).mean().item()
674 | ret[f"{name}_running_var_MAD_std"] = torch.tensor(stats).std().item()
675 | return {"bn_metrics": ret}
676 |
--------------------------------------------------------------------------------
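The magnitude-based masking performed by `apply_topk` in the file above can be illustrated in isolation. The following is a minimal sketch, not the repository's implementation: the helper name `topk_mask` and the example tensor are hypothetical, and it assumes plain `torch.Tensor` inputs rather than the repository's special tensor types.

```python
import torch


def topk_mask(weight: torch.Tensor, topk: float) -> torch.Tensor:
    """Return a 0/1 mask keeping the `topk` fraction of largest-magnitude entries.

    Hypothetical helper for illustration; mirrors the masking idea in apply_topk.
    """
    flat = weight.detach().abs().flatten()
    # Number of entries to zero out, truncated toward zero as in apply_topk.
    num_to_zero = int((1 - topk) * flat.numel())
    # argsort is ascending, so the leading indices point at the smallest magnitudes.
    order = torch.argsort(flat)
    mask = torch.ones_like(flat)
    mask[order[:num_to_zero]] = 0
    return mask.view_as(weight)


# Example tensor (made up for this sketch).
weight = torch.tensor([[0.5, -2.0], [0.1, 3.0]])
mask = topk_mask(weight, topk=0.5)
# Multiplying by the mask zeroes the two smallest-magnitude entries (0.5 and 0.1).
sparse_weight = weight * mask
print(sparse_weight)
```

Building the mask without tracking gradients and then multiplying it into `weight` means the retained entries still receive gradients, which appears to be the reason `apply_topk` computes its scale factors under `torch.no_grad()` and applies them outside that block.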