├── ACKNOWLEDGMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── args.py ├── configs ├── quantized │ ├── lcs_l.yaml │ ├── lcs_p.yaml │ └── target_bit_width.yaml ├── structured_sparsity │ ├── lcs_l.yaml │ ├── lcs_p.yaml │ ├── lec.yaml │ ├── ns.yaml │ └── us.yaml └── unstructured_sparsity │ ├── lcs_l.yaml │ ├── lcs_p.yaml │ └── target_topk.yaml ├── curve_utils.py ├── get_training_params.py ├── main.py ├── model_logging.py ├── models ├── builder.py ├── init.py ├── modules.py ├── networks │ ├── __init__.py │ ├── channel_selection.py │ ├── cpreresnet.py │ ├── model_profiling.py │ ├── resnet.py │ ├── resprune.py │ ├── utils.py │ ├── vgg.py │ └── vggprune.py ├── quantize_affine.py ├── quantized_modules.py ├── sparse_modules.py └── special_tensors.py ├── requirements.txt ├── schedulers.py ├── setup.sh ├── tools ├── format-python └── remove-whitespace.sh ├── train_curve.py ├── train_indep.py ├── train_quantized.py ├── train_structured.py ├── train_unstructured.py ├── training_params.py └── utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2021 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. 
Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Compressible Subspaces 2 | 3 | This is the official code release for our publication, [LCS: Learning Compressible Subspaces for Adaptive Network Compression at Inference Time](https://arxiv.org/abs/2110.04252). Our code is used to train and evaluate models that can be compressed in real-time after deployment, allowing for a fine-grained efficiency-accuracy trade-off. 4 | 5 | This repository hosts code to train compressible subspaces for structured sparsity, unstructured sparsity, and quantization. We support three architectures: cPreResNet20, ResNet18, and VGG19. 
Training is performed on the CIFAR-10 and ImageNet datasets. 6 | 7 | Default training configurations are provided in the `configs` folder. Note that they are automatically altered when different models and datasets are chosen through flags. See `training_params.py`. The following training parameter flags are available to all training regimes: 8 | 9 | - `--model`: Specifies the model to use. One of cpreresnet20, resnet18, or vgg19. The cpreresnet20 model is only supported on CIFAR-10, while resnet18 and vgg19 are only supported on ImageNet (see `validate_model_data` in `args.py`). 10 | - `--dataset`: Specifies the dataset to train on. One of cifar10 or imagenet. 11 | - `--imagenet_dir`: The root directory of the ImageNet dataset. Required when training on ImageNet. 12 | - `--method`: Specifies the training method. For unstructured sparsity, one of target_topk, lcs_l, lcs_p. For structured sparsity, one of lec, ns, us, lcs_l, lcs_p. For quantized models, one of target_bit_width, lcs_l, lcs_p. 13 | - `--norm`: The normalization layers to use. One of IN (instance normalization), BN (batch normalization), or GN (group normalization). 14 | - `--epochs`: The number of epochs to train for. 15 | - `--learning_rate`: The optimizer learning rate. 16 | - `--batch_size`: Training and test batch sizes. 17 | - `--momentum`: The optimizer momentum. 18 | - `--weight_decay`: The L2 regularization weight. 19 | - `--warmup_budget`: The percentage of epochs to use for the training method's warmup phase. Note that `args.py` currently defines this flag only for unstructured sparsity training. 20 | - `--test_freq`: The number of training epochs between test evaluations. Models are also saved at this frequency. 21 | 22 | The "lcs_l" training method refers to the "LCS+L" method in the paper. In this setting, we train a linear subspace where one end is optimized for efficiency, while the other end prioritizes accuracy. The "lcs_p" training method refers to the "LCS+P" method in the paper. It trains a degenerate (single-point) subspace conditioned to perform at arbitrary sparsity rates in the unstructured and structured sparsity settings, or at arbitrary bit widths in the quantized setting.
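The following sketch illustrates the LCS+L idea in the unstructured sparsity setting by evaluating a single point on a linear subspace. A scalar `alpha` in [0, 1] interpolates between two endpoint weight vectors. The same `alpha` also sets the TopK fraction, that is, how many of the largest-magnitude weights are kept.

This is a minimal sketch under stated assumptions, not the repository's implementation:

- Weights are plain Python lists.
- `alpha = 1` is the high-accuracy end, where TopK equals `topk_max`.
- The function names and the default bounds are hypothetical.

```python
def interpolate(w_eff, w_acc, alpha):
    """Point on the line from the high-efficiency end (alpha=0) to the high-accuracy end (alpha=1)."""
    return [(1.0 - alpha) * e + alpha * a for e, a in zip(w_eff, w_acc)]


def alpha_to_topk(alpha, topk_min, topk_max):
    """Linearly map alpha to the fraction of weights kept (TopK = 1 - sparsity)."""
    return topk_min + alpha * (topk_max - topk_min)


def topk_mask(weights, topk):
    """Keep the `topk` fraction of weights with the largest magnitude and zero the rest."""
    k = max(1, round(topk * len(weights)))
    by_magnitude = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)
    keep = set(by_magnitude[:k])
    return [w if i in keep else 0.0 for i, w in enumerate(weights)]


def subspace_point(w_eff, w_acc, alpha, topk_min=0.25, topk_max=1.0):
    """Hypothetical LCS+L evaluation: interpolate the endpoints, then sparsify."""
    weights = interpolate(w_eff, w_acc, alpha)
    return topk_mask(weights, alpha_to_topk(alpha, topk_min, topk_max))


w_eff = [0.5, -0.1, 0.2, -0.8]
w_acc = [0.1, 0.3, -0.4, 0.6]
print(subspace_point(w_eff, w_acc, 0.0))  # [0.0, 0.0, 0.0, -0.8]
print(subspace_point(w_eff, w_acc, 1.0))  # [0.1, 0.3, -0.4, 0.6]
```

In this picture, LCS+P would keep a single weight vector and vary only the TopK value.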
23 | 24 | ## Structured Sparsity 25 | 26 | In the structured sparsity setting, we support five training methods: 27 | 28 | 1. "lcs_l" -- This refers to the LCS+L method where one end of the linear subspace performs at high sparsity rates while the other performs at zero sparsity. 29 | 2. "lcs_p" -- This refers to the LCS+P method where we train a degenerate subspace conditioned to perform at arbitrary sparsity rates. 30 | 3. "lec" -- This refers to the method introduced in "Learning Efficient Convolutional Networks through Network Slimming" by Liu et al. (2017). We do not perform fine-tuning, as described in our paper. 31 | 4. "ns" -- This refers to the method introduced in "Slimmable Neural Networks" by Yu et al. (2018). We use a single BatchNorm to allow for evaluation at arbitrary width factors, as described in our paper. 32 | 5. "us" -- This refers to the method introduced in "Universally Slimmable Networks and Improved Training Techniques" by Yu & Huang (2019). We do not recalibrate BatchNorms (to facilitate on-device compression to arbitrary widths), as described in our paper. 33 | 34 | Training a model in the structured sparsity setting can be accomplished by running the following command: 35 | 36 | > python train_structured.py 37 | 38 | By default, the command above will train the cPreResNet20 architecture on CIFAR-10 using instance normalization layers with the LCS+L method. To specify the model, dataset, normalization, and training method, the flags `--model`, `--dataset`, `--norm`, and `--method` can be used. The following command 39 | 40 | > python train_structured.py --model resnet18 --dataset imagenet --norm IN --method lcs_p --imagenet_dir 41 | 42 | will train a ResNet18 point subspace (LCS+P) on ImageNet using instance normalization layers and the parameters from our paper.
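The "us", "lcs_l", and "lcs_p" structured methods rely on the sandwich rule from Yu & Huang (2019). At each training step, the smallest and largest width factors are always trained, together with a few randomly drawn widths in between.

The sketch below shows one way to draw such a set of width factors and to turn a width factor into a channel count. It is illustrative only. The helpers `sandwich_width_factors` and `active_channels`, the uniform sampling, and the ceiling-based rounding are assumptions, not this repository's code.

```python
import math
import random


def sandwich_width_factors(limits, num_samples, rng=random):
    """Always include both width factor limits, plus (num_samples - 2) random widths between them."""
    lower, upper = limits
    if num_samples < 2:
        raise ValueError("the sandwich rule needs at least two samples")
    middle = [rng.uniform(lower, upper) for _ in range(num_samples - 2)]
    return [upper, lower] + middle


def active_channels(width_factor, num_channels):
    """Hypothetical rounding rule: number of leading channels kept at a given width factor."""
    return max(1, min(num_channels, math.ceil(width_factor * num_channels)))


rng = random.Random(0)  # seeded so the random middle widths are reproducible
widths = sandwich_width_factors([0.25, 1.0], 4, rng)
print([active_channels(w, 16) for w in widths[:2]])  # [16, 4]
```

With `width_factor_limits: [0.25, 1.0]` and `width_factor_samples: 4`, as in the structured sparsity configs, each step would train the full and quarter widths plus two random widths.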
43 | 44 | In addition to the global flags above, the structured setting also has the following: 45 | 46 | - `--width_factors_list`: When training using the "ns" method, this sets the width factors at which the model will be trained. 47 | - `--width_factor_limits`: When training using the "us", "lcs_l", or "lcs_p" methods, this sets the lower and upper width factor limits as a comma-separated pair (e.g. `0.25,1.0`). 48 | - `--width_factor_samples`: When training using the "us", "lcs_l", or "lcs_p" methods, this sets the number of width factors sampled for the sandwich rule. Two of these samples are always the lower and upper width factor limits. 49 | - `--eval_width_factors`: Sets the width factors at which the model is evaluated, for all training methods. 50 | 51 | The command 52 | 53 | > python train_structured.py --model cpreresnet20 --dataset cifar10 --norm BN --method ns --width_factors_list 0.25,0.5,0.75,1.0 54 | 55 | will train a cPreResNet20 architecture with BN layers on CIFAR-10 via the NS method at width factors 0.25, 0.5, 0.75, and 1.0. 56 | 57 | ## Unstructured Sparsity 58 | 59 | In the unstructured sparsity setting, we support three training methods: 60 | 61 | 1. "lcs_l" -- This refers to the LCS+L method where one end of the linear subspace performs at high sparsity rates while the other performs at zero sparsity. 62 | 2. "lcs_p" -- This refers to the LCS+P method where we train a degenerate subspace conditioned to perform at arbitrary sparsity rates. 63 | 3. "target_topk" -- This trains a network optimized to perform well at a specified TopK target. 64 | 65 | Training a model in the unstructured sparsity setting can be accomplished by running the following command: 66 | 67 | > python train_unstructured.py 68 | 69 | By default, the command above will train the cPreResNet20 architecture on CIFAR-10 using group normalization layers with the LCS+L method and the parameters described in our paper. To specify the model, dataset, normalization, and training method, the flags `--model`, `--dataset`, `--norm`, and `--method` can be used.
The following command 70 | 71 | > python train_unstructured.py --model resnet18 --dataset imagenet --norm GN --method lcs_p --imagenet_dir 72 | 73 | will train a ResNet18 point subspace (LCS+P) on ImageNet using group normalization layers, again with the parameters from our paper. 74 | 75 | The command 76 | 77 | > python train_unstructured.py --model resnet18 --dataset imagenet --method target_topk --topk 0.5 --imagenet_dir 78 | 79 | will train a ResNet18 architecture on ImageNet optimized to perform at a TopK value of 0.5 (50% sparsity). 80 | 81 | In addition to the global flags above, the unstructured setting also has the following: 82 | 83 | - `--topk`: When training using the "target_topk" method, this sets the target TopK value. 84 | - `--eval_topk_grid`: The TopK values at which the model is evaluated, given as a comma-separated list. 85 | - `--topk_lower_bound`: The lower TopK bound (TopK = 1 - sparsity) used for training. For linear subspaces, one end of the line is optimized for sparsity 1 - topk_lower_bound, which corresponds to the high-efficiency endpoint. If specified, `--topk_upper_bound` and `--eval_topk_grid` must be specified as well. 86 | - `--topk_upper_bound`: The upper TopK bound (TopK = 1 - sparsity) used for training. For linear subspaces, the other end of the line is optimized for sparsity 1 - topk_upper_bound, which corresponds to the high-accuracy endpoint. If specified, `--topk_lower_bound` and `--eval_topk_grid` must be specified as well. 87 | 88 | The following command 89 | 90 | > python train_unstructured.py --model cpreresnet20 --dataset cifar10 --norm GN --method lcs_p --topk_lower_bound 0.005 --topk_upper_bound 0.05 --eval_topk_grid 0.005,0.01,0.015,0.02,0.025,0.03,0.035,0.04,0.045,0.05 91 | 92 | will train a cPreResNet20 point subspace (LCS+P) on CIFAR-10 at high sparsity: TopK values between 0.005 and 0.05, i.e., sparsity between 95% and 99.5%. 93 | 94 | ## Quantization 95 | 96 | In the quantized setting, we support three training methods: 97 | 98 | 1. "lcs_l" -- This refers to the LCS+L method where one end of the linear subspace performs at a low bit width while the other performs at a high bit width. 99 | 2.
"lcs_p" -- This refers to the LCS+P method where we train a degenerate subspace conditioned to perform at arbitrary bit widths in a range. 100 | 3. "target_bit_width" -- This trains a network optimized to perform at a specified bit width. 101 | 102 | Training a model in the structured sparsity setting can be accomplished by running the following command: 103 | 104 | > python train_quantized.py 105 | 106 | By default, the command above will train the cPreResNet20 architecture on CIFAR-10 using group normalization layers with the LCS+L method with a bit range [3,8]. To specify the model, dataset, normalization, and training method, the flags `--model`, `--dataset`, `--norm`, `--method` can be used. The following command 107 | 108 | > python train_quantized.py --model vgg19 --dataset imagenet --norm GN --method lcs_p --imagenet_dir 109 | 110 | will train a ResNet18 point subspace (LCS+P) on ImageNet using group normalization layers. 111 | 112 | In addition to the global flags above, the quantized setting also has the following: 113 | 114 | - `--bit_width`: When training using the "target_bit_width" method, this sets the target bit width. 115 | - `--eval_bit_widths`: Will evaluate models at these bit widths. 116 | - `--bit_width_limits`: This sets the upper and lower bit width bounds to use for training. 117 | 118 | The following command 119 | 120 | > python train_quantized.py --model cpreresnet20 --dataset cifar10 --norm GN --method lcs_l --bit_width_limits 3,8 --eval_bit_widths 3,4,5,6,7,8 121 | 122 | will train a linear subspace cPreResNet20 model with GN layers on the ImageNet dataset and will be optimized so that one end of the line performs at 3 bits, and the other at 8. 123 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | import argparse 6 | 7 | 8 | def gen_args(desc): 9 | parser = argparse.ArgumentParser(description=desc) 10 | 11 | parser.add_argument( 12 | "--model", 13 | type=str, 14 | default="cpreresnet20", 15 | help="Which model architecture to use. One of cpreresnet20, resnet18, vgg19", 16 | ) 17 | 18 | parser.add_argument( 19 | "--dataset", 20 | type=str, 21 | default="cifar10", 22 | help="Which dataset to use. One of cifar10, imagenet", 23 | ) 24 | 25 | parser.add_argument( 26 | "--imagenet_dir", 27 | type=str, 28 | help="The root directory of the ImageNet dataset", 29 | ) 30 | 31 | parser.add_argument("--save_dir", type=str, help="Directory to save model") 32 | 33 | parser.add_argument( 34 | "--norm", 35 | type=str, 36 | help="Normalization layer type. One of BN, IN, GN", 37 | ) 38 | 39 | parser.add_argument("--epochs", type=int, help="Number of training epochs") 40 | 41 | parser.add_argument( 42 | "--test_freq", type=int, help="Number of epochs between test evaluations" 43 | ) 44 | 45 | parser.add_argument( 46 | "--learning_rate", type=float, help="Optimizer learning rate" 47 | ) 48 | 49 | parser.add_argument( 50 | "--batch_size", type=int, help="Training/test batch size" 51 | ) 52 | 53 | parser.add_argument("--momentum", type=float, help="Optimizer momentum") 54 | 55 | parser.add_argument( 56 | "--weight_decay", type=float, help="L2 regularization parameter" 57 | ) 58 | 59 | return parser 60 | 61 | 62 | def parse_unstructured_arguments(): 63 | parser = gen_args("Unstructured Sparsity Training") 64 | 65 | parser.add_argument( 66 | "--method", 67 | type=str, 68 | default="lcs_l", 69 | help="Training method.
One of target_topk, lcs_p, lcs_l", 70 | ) 71 | 72 | parser.add_argument( 73 | "--topk", 74 | type=float, 75 | help="Target TopK value in (0,1) for target_topk training", 76 | ) 77 | 78 | parser.add_argument( 79 | "--topk_lower_bound", 80 | type=float, 81 | help="Lower bound (high efficiency) endpoint of the learned subspace", 82 | ) 83 | 84 | parser.add_argument( 85 | "--topk_upper_bound", 86 | type=float, 87 | help="Upper bound (high accuracy) endpoint of the learned subspace", 88 | ) 89 | 90 | parser.add_argument( 91 | "--warmup_budget", 92 | type=float, 93 | help="Value in range (0,100] denoting the percentage of epochs for warmup phase", 94 | ) 95 | 96 | parser.add_argument( 97 | "--eval_topk_grid", 98 | type=lambda x: [float(w) for w in x.split(",")], 99 | help="Will evaluate at these topk values", 100 | ) 101 | 102 | return parser.parse_args() 103 | 104 | 105 | def parse_structured_arguments(): 106 | parser = gen_args("Structured Sparsity Training") 107 | 108 | parser.add_argument( 109 | "--method", 110 | type=str, 111 | default="lcs_l", 112 | help="Training method. One of lec, ns, us, lcs_l, lcs_p", 113 | ) 114 | 115 | parser.add_argument( 116 | "--width_factors_list", 117 | type=lambda x: [float(w) for w in x.split(",")], 118 | help="Desired width factors for NS. Ex: --width_factors_list 0.25,0.5,0.75,1.0", 119 | ) 120 | 121 | parser.add_argument( 122 | "--width_factor_limits", 123 | type=lambda x: [float(w) for w in x.split(",")], 124 | help="Width factor lower and upper bounds for US, LCS+L, and LCS+P. Ex: --width_factor_limits 0.25,1.0", 125 | ) 126 | 127 | parser.add_argument( 128 | "--width_factor_samples", 129 | type=int, 130 | help="Number of width factor samples for the sandwich rule (US, LCS+L, LCS+P)", 131 | ) 132 | 133 | parser.add_argument( 134 | "--eval_width_factors", 135 | "--list", 136 | type=lambda x: [float(w) for w in x.split(",")], 137 | help="Width factors at which to evaluate model.
Ex: --eval_width_factors 0.25,0.5,0.75,1.0", 138 | ) 139 | 140 | return parser.parse_args() 141 | 142 | 143 | def parse_quantized_arguments(): 144 | parser = gen_args("Quantized Training") 145 | 146 | parser.add_argument( 147 | "--method", 148 | type=str, 149 | default="lcs_l", 150 | help="Training method. One of target_bit_width, lcs_p, lcs_l", 151 | ) 152 | 153 | parser.add_argument( 154 | "--bit_width", 155 | type=int, 156 | help="Target bit width after warmup phase. Used for target_bit_width training", 157 | ) 158 | 159 | parser.add_argument( 160 | "--eval_bit_widths", 161 | type=lambda x: [float(w) for w in x.split(",")], 162 | help="Bit widths at which to evaluate. Ex: --eval_bit_widths 3,4,5", 163 | ) 164 | 165 | parser.add_argument( 166 | "--bit_width_limits", 167 | type=lambda x: [float(w) for w in x.split(",")], 168 | help="Min/max bit widths for training the line. Ex: --bit_width_limits 3,8", 169 | ) 170 | 171 | return parser.parse_args() 172 | 173 | 174 | def validate_model_data(args): 175 | """ 176 | Ensures that the specified model and dataset are compatible 177 | """ 178 | implemented_models = ["cpreresnet20", "resnet18", "vgg19"] 179 | implemented_datasets = ["cifar10", "imagenet"] 180 | 181 | if args.model not in implemented_models: 182 | raise ValueError(f"{args.model} not implemented.") 183 | 184 | if args.dataset not in implemented_datasets: 185 | raise ValueError(f"{args.dataset} not implemented.") 186 | 187 | if args.dataset == "imagenet": 188 | if args.model not in ("resnet18", "vgg19"): 189 | raise ValueError( 190 | f"{args.model} does not support ImageNet. Supported models: resnet18, vgg19" 191 | ) 192 | elif args.dataset == "cifar10": 193 | if args.model not in ("cpreresnet20",): # one-element tuple: membership test, not a substring test 194 | raise ValueError( 195 | f"{args.model} does not support CIFAR.
Supported models: cpreresnet20" 196 | ) 197 | 198 | return args 199 | 200 | 201 | def validate_unstructured_params(args): 202 | """ 203 | Ensures that the specified unstructured sparsity parameters are valid 204 | """ 205 | lb = args.topk_lower_bound 206 | ub = args.topk_upper_bound 207 | 208 | if (lb is not None and ub is None) or (lb is None and ub is not None): 209 | raise ValueError("Both upper and lower TopK bounds must be specified") 210 | 211 | if lb is not None: 212 | if lb < 0: 213 | raise ValueError("TopK lower bound must be >= 0") 214 | if ub > 1: 215 | raise ValueError("TopK upper bound must be <= 1") 216 | if lb >= ub: 217 | raise ValueError("TopK lower bound must be < upper bound") 218 | if args.eval_topk_grid is None: 219 | raise ValueError( 220 | "eval_topk_grid must be specified when TopK bounds are specified" 221 | ) 222 | 223 | return args 224 | 225 | 226 | def validate_structured_params(args): 227 | """ 228 | Ensures that the specified structured sparsity parameters are valid 229 | """ 230 | if args.method in ("lec", "ns"): 231 | if args.norm is not None and args.norm != "BN": 232 | raise ValueError("LEC and NS only implemented for BN layers") 233 | 234 | if args.method == "lec" and args.model not in ("cpreresnet20", "vgg19"): 235 | raise ValueError("LEC only implemented for cpreresnet20, vgg19") 236 | 237 | return args 238 | 239 | 240 | def validate_quant_params(args): 241 | """ 242 | Ensures that the specified quantized parameters are valid 243 | """ 244 | bit_range = [3, 4, 5, 6, 7, 8] 245 | 246 | if args.method == "target_bit_width": 247 | if args.bit_width is not None and args.bit_width not in bit_range: 248 | raise ValueError("Bit width must be one of 3, 4, 5, 6, 7, 8.") 249 | 250 | if args.bit_width_limits is not None: 251 | l_bit, u_bit = args.bit_width_limits 252 | if l_bit < 3: 253 | raise ValueError("Smallest bit width must be >= 3") 254 | 255 | if u_bit > 8: 256 | raise ValueError("Largest bit width must be <= 8") 257 | 258 | return args 259 |
260 | 261 | def unstructured_args(): 262 | args = parse_unstructured_arguments() 263 | args = validate_unstructured_params(args) 264 | 265 | return validate_model_data(args) 266 | 267 | 268 | def structured_args(): 269 | args = parse_structured_arguments() 270 | args = validate_structured_params(args) 271 | 272 | return validate_model_data(args) 273 | 274 | 275 | def quantized_args(): 276 | args = parse_quantized_arguments() 277 | args = validate_quant_params(args) 278 | 279 | return validate_model_data(args) 280 | -------------------------------------------------------------------------------- /configs/quantized/lcs_l.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | parameters: 6 | regime: "quantized" 7 | dataset: "cifar10" 8 | script: "train_curve.py" 9 | save_dir: "saved_models/quantized/" 10 | # Line: 11 | num_points: 2 12 | conv_type: "LinesConv" # First convolution layer 13 | bn_type: "LinesGN" # Batchnorm for first convolution layer 14 | block_conv_type: "LinesConvBn2d" # Convolution layer for blocks 15 | block_bn_type: "LinesGN" # Convolution batchnorm layer for blocks 16 | epochs: 200 # The number of epochs to train for 17 | batch_size: 128 # Train and test batch sizes 18 | warmup_budget: 80 # percentage of epochs for "warmup" phase 19 | test_freq: 20 # Will run test after this many epochs 20 | alpha_grid: [0.166, 0.333, 0.499, 0.666, 0.833, 1] # When bit widths {3, ..., 8} 21 | learning_rate: 0.025 # Optimizer learning rate 22 | momentum: 0.9 # Optimizer momentum 23 | weight_decay: 0.0005 # L2 regularization parameter 24 | min_bits: 3 25 | max_bits: 8 26 | discrete_alpha_map: True 27 | model_config: 28 | model_class: cpreresnet20 29 | model_kwargs: 30 | channel_selection_active: False 31 | -------------------------------------------------------------------------------- 
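In `configs/quantized/lcs_l.yaml`, the `alpha_grid` comment pairs six alpha values with bit widths 3 through 8. The sketch below shows one discrete mapping that is consistent with those values. It splits [0, 1] into `max_bits - min_bits + 1` equal bins and assigns one bit width per bin, with `alpha = 1` at the 8-bit end. This is an assumption for illustration, not the repository's actual `discrete_alpha_map` logic, and `alpha_to_bits` is a hypothetical name.

```python
import math


def alpha_to_bits(alpha, min_bits=3, max_bits=8):
    """Hypothetical discrete map from a line position alpha in [0, 1] to a bit width."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    num_levels = max_bits - min_bits + 1  # 6 bit widths for the range [3, 8]
    # Bin index runs from 1 to num_levels; alpha = 0 falls into the first bin.
    level = max(1, math.ceil(alpha * num_levels))
    return min_bits + level - 1


alpha_grid = [0.166, 0.333, 0.499, 0.666, 0.833, 1]
print([alpha_to_bits(a) for a in alpha_grid])  # [3, 4, 5, 6, 7, 8]
```

Under this reading, each grid value sits just below a bin boundary (e.g. 0.166 < 1/6), so every alpha in `alpha_grid` lands on a distinct bit width.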
/configs/quantized/lcs_p.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | parameters: 6 | regime: "quantized" 7 | dataset: "cifar10" 8 | script: "train_indep.py" 9 | save_dir: "saved_models/quantized/" 10 | epochs: 200 # The number of epochs to train for 11 | batch_size: 128 # Train and test batch sizes 12 | warmup_budget: 80 # percentage of epochs for "warmup" phase 13 | test_freq: 20 # Will run test after this many epochs 14 | learning_rate: 0.025 # Optimizer learning rate 15 | momentum: 0.9 # Optimizer momentum 16 | weight_decay: 0.0005 # L2 regularization parameter 17 | conv_type: "StandardConv" # convolution layer for first layer of network 18 | bn_type: "StandardGN" # normalization layer for first layer of network 19 | block_conv_type: "ConvBn2d" # convolution layer to use for resnet blocks 20 | block_bn_type: "StandardGN" # normalization layer to use for resnet blocks 21 | random_param: True # If true, will train at max_bits during warmup and then use random bit widths in range [min_bits, max_bits] 22 | eval_param_grid: [3, 4, 5, 6, 7, 8] 23 | min_bits: 3 24 | max_bits: 8 25 | model_config: 26 | model_class: cpreresnet20 27 | model_kwargs: 28 | channel_selection_active: False 29 | -------------------------------------------------------------------------------- /configs/quantized/target_bit_width.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | # 5 | parameters: 6 | regime: "quantized" 7 | dataset: "cifar10" 8 | save_dir: "saved_models/quantized/" 9 | # Multiple trials parameters 10 | num_runs: 3 11 | script: "train_indep.py" 12 | epochs: 200 # The number of epochs to train for 13 | batch_size: 128 # Train and test batch sizes 14 | warmup_budget: 80 # percentage of epochs for "warmup" phase 15 | test_freq: 20 # Will run test after this many epochs 16 | learning_rate: 0.025 # Optimizer learning rate 17 | momentum: 0.9 # Optimizer momentum 18 | weight_decay: 0.0005 # L2 regularization parameter 19 | conv_type: "StandardConv" 20 | bn_type: "StandardBN" 21 | block_conv_type: "ConvBn2d" 22 | block_bn_type: "StandardBN" 23 | num_bits: 6 # Target number of bits 24 | random_bits: False # If true, will train with num_bits=8 during warmup and then use random bits in range [min_bits, max_bits] 25 | eval_param_grid: [3, 4, 5, 6, 7, 8] 26 | model_config: 27 | model_class: cpreresnet20 28 | model_kwargs: 29 | channel_selection_active: False 30 | -------------------------------------------------------------------------------- /configs/structured_sparsity/lcs_l.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | # 5 | parameters: 6 | regime: "us" 7 | dataset: "cifar10" 8 | script: "train_curve.py" 9 | save_dir: "saved_models/structured_sparsity/" 10 | # Line: 11 | num_points: 2 12 | conv_type: "AdaptiveConv2d" # First convolution layer 13 | bn_type: "LinesAdaptiveIN" # Batchnorm for first convolution layer 14 | block_conv_type: "AdaptiveConv2d" # Convolution layer for ResNet blocks 15 | block_bn_type: "LinesAdaptiveIN" # Convolution batchnorm layer for ResNet blocks 16 | epochs: 200 # The number of epochs to train for 17 | batch_size: 128 # Train and test batch sizes 18 | warmup_budget: 80 # percentage of epochs for "warmup" phase 19 | test_freq: 20 # Will run test after this many epochs 20 | # To evaluate at width factors [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25], we need the alpha grid to look like this. 21 | # (and we must appropriately set the width_factor_limits inside regime_params). 22 | alpha_grid: [1.0, 0.83333, 0.66667, 0.5, 0.33333, 0.166667, 0] 23 | learning_rate: 0.1 # Optimizer learning rate 24 | momentum: 0.9 # Optimizer momentum 25 | weight_decay: 0.0005 # L2 regularization parameter 26 | model_config: 27 | model_class: cpreresnet20 28 | model_kwargs: 29 | channel_selection_active: False 30 | regime_params: 31 | # eval_width_factors: [1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25] 32 | width_factor_limits: [0.25, 1.0] 33 | width_factor_samples: 4 34 | width_factor_sampling_method: sandwich 35 | apply_beta_to_norm: True 36 | -------------------------------------------------------------------------------- /configs/structured_sparsity/lcs_p.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | parameters: 6 | regime: "us" 7 | dataset: "cifar10" 8 | script: "train_indep.py" 9 | save_dir: "saved_models/structured_sparsity/" 10 | epochs: 200 # The number of epochs to train for 11 | batch_size: 128 # Train and test batch sizes 12 | warmup_budget: 80 # decays topk from 1 to target topk over this percentage of epochs 13 | test_freq: 20 # Will run test after this many epochs 14 | learning_rate: 0.1 # Optimizer learning rate 15 | momentum: 0.9 # Optimizer momentum 16 | weight_decay: 0.0005 # L2 regularization parameter 17 | conv_type: "AdaptiveConv2d" # convolution layer for first layer of network 18 | bn_type: "AdaptiveIN" # batch norm layer for first layer of network 19 | block_conv_type: "AdaptiveConv2d" # convolution layer to use for resnet blocks 20 | block_bn_type: "AdaptiveIN" # batch norm layer to use for resnet blocks 21 | topk: 1.0 # the target topk to reach after warmup phase 22 | random_topk: False # If true, will train with topk=1 during warmup and then use random topk values 23 | norm_kwargs: 24 | width_factors_list: null 25 | model_config: 26 | model_class: cpreresnet20 27 | model_kwargs: 28 | channel_selection_active: False 29 | regime_params: 30 | eval_width_factors: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25] 31 | width_factor_limits: [0.25, 1.0] 32 | width_factor_samples: 4 33 | 34 | eval_param_grid: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25] # Same as eval_width_factors 35 | -------------------------------------------------------------------------------- /configs/structured_sparsity/lec.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | parameters: 6 | regime: "lec" 7 | bn_update_factor: 0.00001 8 | dataset: "cifar10" 9 | script: "train_indep.py" 10 | save_dir: "saved_models/structured_sparsity/" 11 | epochs: 200 # The number of epochs to train for 12 | batch_size: 128 # Train and test batch sizes 13 | warmup_budget: 80 # decays topk from 1 to alpha over this percentage of epochs 14 | test_freq: 20 # Will run test after this many epochs 15 | learning_rate: 0.1 # Optimizer learning rate 16 | momentum: 0.9 # Optimizer momentum 17 | weight_decay: 0.0005 # L2 regularization parameter 18 | conv_type: "StandardConv" 19 | bn_type: "StandardBN" 20 | block_conv_type: "StandardConv" 21 | block_bn_type: "StandardBN" 22 | topk: 1.0 23 | model_config: 24 | model_class: cpreresnet20 25 | model_kwargs: 26 | channel_selection_active: True 27 | eval_param_grid: [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 1.0] 28 | -------------------------------------------------------------------------------- /configs/structured_sparsity/ns.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | parameters: 6 | regime: "ns" 7 | dataset: "cifar10" 8 | script: "train_indep.py" 9 | save_dir: "saved_models/structured_sparsity/" 10 | epochs: 200 # The number of epochs to train for 11 | batch_size: 128 # Train and test batch sizes 12 | warmup_budget: 80 # decays topk from 1 to target topk over this percentage of epochs 13 | test_freq: 20 # Will run test after this many epochs 14 | learning_rate: 0.1 # Optimizer learning rate 15 | momentum: 0.9 # Optimizer momentum 16 | weight_decay: 0.0005 # L2 regularization parameter 17 | conv_type: "AdaptiveConv2d" # convolution layer for first layer of network 18 | bn_type: "AdaptiveBN" # [sic] this norm is evaluatable at any point, and is otherwise the same as the AdaptiveIN 19 | block_conv_type: "AdaptiveConv2d" # convolution layer to use for resnet blocks 20 | block_bn_type: "AdaptiveBN" # batch norm layer to use for resnet blocks 21 | topk: 1.0 # the target topk to reach after warmup phase 22 | random_topk: False # If true, will train with topk=1 during warmup and then use random topk values 23 | norm_kwargs: 24 | track_running_stats: True 25 | width_factors_list: null 26 | builder_kwargs: 27 | width_factors_list: [1.0, 0.75, 0.50, 0.25] 28 | model_config: 29 | model_class: cpreresnet20 30 | model_kwargs: 31 | channel_selection_active: False 32 | regime_params: 33 | eval_width_factors: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25] 34 | 35 | eval_param_grid: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25] # Same as eval_width_factors 36 | -------------------------------------------------------------------------------- /configs/structured_sparsity/us.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | parameters: 6 | regime: "us" 7 | dataset: "cifar10" 8 | script: "train_indep.py" 9 | save_dir: "saved_models/structured_sparsity/" 10 | epochs: 200 # The number of epochs to train for 11 | batch_size: 128 # Train and test batch sizes 12 | warmup_budget: 80 # decays topk from 1 to target topk over this percentage of epochs 13 | test_freq: 20 # Will run test after this many epochs 14 | learning_rate: 0.1 # Optimizer learning rate 15 | momentum: 0.9 # Optimizer momentum 16 | weight_decay: 0.0005 # L2 regularization parameter 17 | conv_type: "AdaptiveConv2d" # convolution layer for first layer of network 18 | bn_type: "AdaptiveBN" # batch norm layer for first layer of network 19 | block_conv_type: "AdaptiveConv2d" # convolution layer to use for resnet blocks 20 | block_bn_type: "AdaptiveBN" # batch norm layer to use for resnet blocks 21 | topk: 1.0 # the target topk to reach after warmup phase 22 | random_topk: False # If true, will train with topk=1 during warmup and then use random topk values 23 | norm_kwargs: 24 | width_factors_list: null 25 | model_config: 26 | model_class: cpreresnet20 27 | model_kwargs: 28 | channel_selection_active: False 29 | regime_params: 30 | eval_width_factors: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25] 31 | width_factor_limits: [0.25, 1.0] 32 | width_factor_samples: 4 33 | 34 | eval_param_grid: [1.0, 0.875, 0.75, 0.625, 0.50, 0.375, 0.25] # Same as eval_width_factors 35 | -------------------------------------------------------------------------------- /configs/unstructured_sparsity/lcs_l.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | parameters: 6 | regime: "sparse" 7 | dataset: "cifar10" 8 | script: "train_curve.py" 9 | save_dir: "saved_models/unstructured_sparsity/" 10 | # Line: 11 | num_points: 2 12 | conv_type: "LinesConv" # First convolution layer 13 | bn_type: "LinesGN" # Batchnorm for first convolution layer 14 | block_conv_type: "SparseLinesConv" # Convolution layer for ResNet blocks 15 | block_bn_type: "LinesGN" # Convolution batchnorm layer for ResNet blocks 16 | epochs: 200 # The number of epochs to train for 17 | batch_size: 128 # Train and test batch sizes 18 | warmup_budget: 80 # percentage of epochs for "warmup" phase 19 | test_freq: 20 # Will run test after this many epochs 20 | alpha_grid: [0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.25, 0.5, 0.75, 0.9, 0.95, 1.0] # eval at these alphas 21 | learning_rate: 0.1 # Optimizer learning rate 22 | momentum: 0.9 # Optimizer momentum 23 | weight_decay: 0.0005 # L2 regularization parameter 24 | alpha_sampling: [0.025, 1, 0.50] # Biased endpoint sampling 25 | model_config: 26 | model_class: cpreresnet20 27 | model_kwargs: 28 | channel_selection_active: False 29 | -------------------------------------------------------------------------------- /configs/unstructured_sparsity/lcs_p.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
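`alpha_sampling: [0.025, 1, 0.50]` in `configs/unstructured_sparsity/lcs_l.yaml` is read by `alpha_sampling` in `curve_utils.py` as `[low, high, endpoint_prob]`: with probability `endpoint_prob`, one of the two endpoints is chosen uniformly; otherwise alpha is drawn uniformly from `[low, high]`. The sketch below mirrors that logic but takes an explicit `numpy.random.Generator` for reproducibility (an assumption; the repo uses the global `np.random` state). `biased_endpoint_alpha` is a hypothetical name.

```python
import numpy as np


def biased_endpoint_alpha(rng, alpha_sampling=None):
    """Sample alpha, oversampling the endpoints of [low, high].

    alpha_sampling = [low, high, endpoint_prob]; None means uniform on [0, 1).
    """
    if alpha_sampling is None:
        return rng.uniform(0, 1)
    low, high, endpoint_prob = alpha_sampling
    if rng.random() < endpoint_prob:
        # Pick one of the two endpoints with equal probability.
        return low if rng.random() < 0.5 else high
    return rng.uniform(low, high)


rng = np.random.default_rng(0)
samples = np.array(
    [biased_endpoint_alpha(rng, [0.025, 1.0, 0.5]) for _ in range(10_000)]
)
at_endpoints = np.isin(samples, [0.025, 1.0]).mean()
print(f"fraction of samples at an endpoint: {at_endpoints:.3f}")
```

With `endpoint_prob = 0.5`, roughly half of all samples land exactly on an endpoint, so each endpoint is sampled far more often than any individual interior value of alpha.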
4 | # 5 | parameters: 6 | regime: "sparse" 7 | dataset: "cifar10" 8 | script: "train_indep.py" 9 | save_dir: "saved_models/unstructured_sparsity/" 10 | epochs: 200 # The number of epochs to train for 11 | batch_size: 128 # Train and test batch sizes 12 | warmup_budget: 80 # decays topk from 1 to target topk over this percentage of epochs 13 | test_freq: 20 # Will run test after this many epochs 14 | learning_rate: 0.1 # Optimizer learning rate 15 | momentum: 0.9 # Optimizer momentum 16 | weight_decay: 0.0005 # L2 regularization parameter 17 | conv_type: "StandardConv" # convolution layer for first layer of network 18 | bn_type: "StandardGN" # batch norm layer for first layer of network 19 | block_conv_type: "SparseConv2d" # convolution layer to use for resnet blocks 20 | block_bn_type: "StandardGN" # batch norm layer to use for resnet blocks 21 | topk: 0.05 # the target topk to reach after warmup phase. If random_param true, will test at this value 22 | random_param: True # If true, will train with topk=1 during warmup and then use random topk values 23 | eval_param_grid: [0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.25, 0.5, 0.75, 0.9, 0.95, 1.0] 24 | model_config: 25 | model_class: cpreresnet20 26 | model_kwargs: 27 | channel_selection_active: False 28 | -------------------------------------------------------------------------------- /configs/unstructured_sparsity/target_topk.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | parameters: 6 | regime: "sparse" 7 | dataset: "cifar10" 8 | script: "train_indep.py" 9 | save_dir: "saved_models/unstructured_sparsity/" 10 | epochs: 200 # The number of epochs to train for 11 | batch_size: 128 # Train and test batch sizes 12 | warmup_budget: 80 # decays topk from 1 to the target over this percentage of epochs 13 | test_freq: 20 # Will run test after this many epochs 14 | learning_rate: 0.1 # Optimizer learning rate 15 | momentum: 0.9 # Optimizer momentum 16 | weight_decay: 0.0005 # L2 regularization parameter 17 | conv_type: "StandardConv" 18 | bn_type: "StandardBN" 19 | block_conv_type: "SparseConv2d" 20 | block_bn_type: "StandardBN" 21 | topk: 0.9 # the target topk to reach after warmup phase 22 | eval_param_grid: [0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 1.0] 23 | model_config: 24 | model_class: cpreresnet20 25 | model_kwargs: 26 | channel_selection_active: False 27 | -------------------------------------------------------------------------------- /curve_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | import numpy as np 6 | import torch.nn as nn 7 | 8 | 9 | def get_stats(model): 10 | norms = {} 11 | numerators = {} 12 | difs = {} 13 | cossim = 0 14 | l2 = 0 15 | num_points = 2 16 | 17 | for i in range(num_points): 18 | norms[f"{i}"] = 0.0 19 | for j in range(i + 1, num_points): 20 | numerators[f"{i}-{j}"] = 0.0 21 | difs[f"{i}-{j}"] = 0.0 22 | 23 | for m in model.modules(): 24 | if isinstance(m, nn.Conv2d): 25 | for i in range(num_points): 26 | vi = get_weight(m, i) 27 | norms[f"{i}"] += vi.pow(2).sum() 28 | for j in range(i + 1, num_points): 29 | vj = get_weight(m, j) 30 | numerators[f"{i}-{j}"] += (vi * vj).sum() 31 | difs[f"{i}-{j}"] += (vi - vj).pow(2).sum() 32 | 33 | for i in range(num_points): 34 | for j in range(i + 1, num_points): 35 | cossim += numerators[f"{i}-{j}"].pow(2) / ( 36 | norms[f"{i}"] * norms[f"{j}"] 37 | ) 38 | l2 += difs[f"{i}-{j}"] 39 | 40 | l2 = l2.pow(0.5).item() 41 | cossim = cossim.item() # note: this is the squared cosine similarity (cos^2), summed over pairs of points 42 | return cossim, l2 43 | 44 | 45 | def get_weight(m, i): 46 | if i == 0: 47 | return m.weight 48 | return getattr(m, f"weight{i}") 49 | 50 | 51 | def alpha_bit_map(alpha, **regime_params): 52 | """ 53 | Maps a continuous alpha value in [0,1] to a bit value. E.g.
54 | with min_bits=1 and max_bits=8: alpha in [0, 1/8] => 1 55 | alpha in (1/8, 2/8] => 2 56 | alpha in (2/8, 3/8] => 3 57 | alpha in (3/8, 4/8] => 4 58 | alpha in (4/8, 5/8] => 5 59 | alpha in (5/8, 6/8] => 6 60 | alpha in (6/8, 7/8] => 7 61 | alpha in (7/8, 8/8] => 8 62 | """ 63 | min_bits = regime_params["min_bits"] 64 | max_bits = regime_params["max_bits"] 65 | distinct_bits = max_bits - min_bits + 1 66 | for i in range(distinct_bits): 67 | if i / distinct_bits <= alpha <= (i + 1) / distinct_bits: 68 | return np.arange(min_bits, max_bits + 1)[i] 69 | 70 | 71 | def sample_alpha_num_bits(**regime_params): 72 | discrete = regime_params["discrete"] 73 | if discrete: 74 | min_bits = regime_params["min_bits"] 75 | max_bits = regime_params["max_bits"] 76 | distinct_bits = max_bits - min_bits + 1 77 | alpha = ( 78 | np.random.choice(np.arange(1, distinct_bits + 1)) / distinct_bits 79 | ) 80 | else: 81 | alpha = np.random.uniform(0, 1) 82 | 83 | num_bits = alpha_bit_map(alpha, **regime_params) 84 | 85 | return alpha, num_bits 86 | 87 | 88 | def alpha_sampling(**regime_params): 89 | # biased endpoint sampling 90 | if regime_params.get("alpha_sampling") is not None: 91 | low, high, endpoint_prob = regime_params.get("alpha_sampling") 92 | if np.random.rand() < endpoint_prob: 93 | # Pick an endpoint at random. 94 | if np.random.rand() < 0.5: 95 | alpha = low 96 | else: 97 | alpha = high 98 | else: 99 | alpha = np.random.uniform(low, high) 100 | else: 101 | alpha = np.random.uniform(0, 1) 102 | 103 | return alpha 104 | -------------------------------------------------------------------------------- /get_training_params.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
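For quantized `lcs_l` runs, `quantized_args_dict` builds the evaluation `alpha_grid` from `np.linspace(0, 1, range_len + 1)[1:]`, floored to three decimals, and `alpha_bit_map` assigns each alpha to the bit width whose interval (i/n, (i+1)/n] contains it (alpha = 0 goes to `min_bits`). The standalone sketch below reproduces both steps to check that every grid point lands on a distinct bit width. The function names are hypothetical and take `min_bits`/`max_bits` directly instead of `**regime_params`; unlike the original, which implicitly returns `None`, this version raises on an alpha outside [0, 1].

```python
import numpy as np


def alpha_bit_map(alpha, min_bits, max_bits):
    """Return the bit width whose alpha interval contains `alpha`.

    Intervals are checked in order with inclusive bounds, so a boundary
    value such as 1/n goes to the lower bit width.
    """
    distinct_bits = max_bits - min_bits + 1
    for i in range(distinct_bits):
        if i / distinct_bits <= alpha <= (i + 1) / distinct_bits:
            return min_bits + i
    raise ValueError(f"alpha={alpha} is outside [0, 1]")


def lcs_l_alpha_grid(min_bits, max_bits):
    """Evaluation alphas: right edge of each interval, floored to 3 decimals."""
    range_len = max_bits - min_bits + 1
    edges = np.linspace(0, 1, range_len + 1)[1:]  # drop alpha = 0
    return [float(np.floor(x * 1000) / 1000) for x in edges]


for lo, hi in [(1, 8), (3, 8)]:
    grid = lcs_l_alpha_grid(lo, hi)
    bits = [alpha_bit_map(a, lo, hi) for a in grid]
    print(lo, hi, bits)
```

This prints `1 8 [1, 2, 3, 4, 5, 6, 7, 8]` and `3 8 [3, 4, 5, 6, 7, 8]`. Flooring can only move a grid point down, and by less than 0.001, so it never crosses the previous interval boundary for any realistic bit range.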
4 | # 5 | import numpy as np 6 | 7 | import training_params 8 | import utils 9 | 10 | 11 | def gen_args_dict(args): 12 | a_dict = {} 13 | 14 | if args.epochs is not None: 15 | a_dict["epochs"] = args.epochs 16 | 17 | if args.test_freq is not None: 18 | a_dict["test_freq"] = args.test_freq 19 | 20 | if args.learning_rate is not None: 21 | a_dict["learning_rate"] = args.learning_rate 22 | 23 | if args.batch_size is not None: 24 | a_dict["batch_size"] = args.batch_size 25 | 26 | if args.momentum is not None: 27 | a_dict["momentum"] = args.momentum 28 | 29 | if args.weight_decay is not None: 30 | a_dict["weight_decay"] = args.weight_decay 31 | 32 | if args.save_dir is not None: 33 | a_dict["save_dir"] = args.save_dir 34 | 35 | if args.dataset == "imagenet": 36 | if args.imagenet_dir is not None: 37 | a_dict["dataset_dir"] = args.imagenet_dir 38 | else: 39 | raise ValueError( 40 | "ImageNet data directory must be specified with --imagenet_dir" 41 | ) 42 | 43 | return a_dict 44 | 45 | 46 | def unstructured_args_dict(args): 47 | ua_dict = gen_args_dict(args) 48 | 49 | if args.topk is not None: 50 | ua_dict["topk"] = args.topk 51 | 52 | if args.warmup_budget is not None: 53 | ua_dict["warmup_budget"] = args.warmup_budget 54 | 55 | if args.eval_topk_grid is not None: 56 | if args.method == "lcs_l": 57 | ua_dict["alpha_grid"] = args.eval_topk_grid 58 | else: 59 | ua_dict["eval_param_grid"] = args.eval_topk_grid 60 | 61 | if ( 62 | args.topk_lower_bound is not None 63 | and args.topk_upper_bound is not None 64 | and args.eval_topk_grid is not None 65 | ): 66 | if args.method == "lcs_l": 67 | ua_dict["alpha_grid"] = args.eval_topk_grid 68 | elif args.method == "lcs_p": 69 | ua_dict["eval_param_grid"] = args.eval_topk_grid 70 | 71 | ua_dict["alpha_sampling"] = [ 72 | args.topk_lower_bound, 73 | args.topk_upper_bound, 74 | 0.5, 75 | ] 76 | 77 | return ua_dict 78 | 79 | 80 | def structured_args_dict(args, base_config): 81 | sa_dict = gen_args_dict(args) 82 | 83 | if
args.width_factors_list is not None and args.method == "ns": 84 | builder_kwargs = base_config["parameters"]["builder_kwargs"] 85 | builder_kwargs["width_factors_list"] = args.width_factors_list 86 | sa_dict["builder_kwargs"] = builder_kwargs 87 | 88 | if args.width_factor_limits is not None and args.method != "ns": 89 | regime_params = base_config["parameters"]["regime_params"] 90 | regime_params["width_factor_limits"] = args.width_factor_limits 91 | sa_dict["regime_params"] = regime_params 92 | 93 | if args.width_factor_samples is not None and args.method != "ns": 94 | regime_params = sa_dict.get( 95 | "regime_params", base_config["parameters"]["regime_params"] 96 | ) 97 | regime_params["width_factor_samples"] = args.width_factor_samples 98 | sa_dict["regime_params"] = regime_params 99 | 100 | if args.eval_width_factors is not None: 101 | regime_params = sa_dict.get( 102 | "regime_params", base_config["parameters"]["regime_params"] 103 | ) 104 | if args.method in ("us", "ns", "lcs_p"): 105 | regime_params["eval_width_factors"] = args.eval_width_factors 106 | sa_dict["regime_params"] = regime_params 107 | sa_dict["eval_param_grid"] = args.eval_width_factors 108 | elif args.method == "lcs_l": 109 | w_l, w_u = regime_params["width_factor_limits"] 110 | alpha_grid = [ 111 | (w_f - w_l) / (w_u - w_l) for w_f in args.eval_width_factors 112 | ] 113 | sa_dict["alpha_grid"] = alpha_grid 114 | 115 | return sa_dict 116 | 117 | 118 | def quantized_args_dict(args): 119 | q_dict = gen_args_dict(args) 120 | 121 | if args.bit_width is not None and args.method == "target_bit_width": 122 | q_dict["num_bits"] = args.bit_width 123 | 124 | if args.eval_bit_widths is not None and args.method != "lcs_l": 125 | q_dict["eval_param_grid"] = args.eval_bit_widths 126 | 127 | if args.bit_width_limits is not None and args.method in ("lcs_p", "lcs_l"): 128 | min_bits, max_bits = [int(x) for x in args.bit_width_limits] 129 | q_dict["min_bits"] = min_bits 130 | q_dict["max_bits"] = max_bits 131 
| if args.method == "lcs_l": 132 | range_len = max_bits - min_bits + 1 133 | alpha_grid = [ 134 | np.floor(x * 1000) / 1000 135 | for x in np.linspace(0, 1, range_len + 1) 136 | ][1:] 137 | q_dict["alpha_grid"] = alpha_grid 138 | elif args.method == "lcs_p": 139 | if args.eval_bit_widths is None: 140 | eval_param_grid = np.arange(min_bits, max_bits + 1) 141 | q_dict["eval_param_grid"] = eval_param_grid 142 | 143 | return q_dict 144 | 145 | 146 | def get_config_norm(params): 147 | norm_types = ["IN", "BN", "GN"] 148 | base_bn_type = params["bn_type"] 149 | base_block_bn_type = params["block_bn_type"] 150 | for n in norm_types: 151 | if n in base_bn_type: 152 | base_bn = n 153 | if n in base_block_bn_type: 154 | base_block_bn = n 155 | 156 | for norm in norm_types: 157 | if norm in base_bn and norm in base_block_bn: 158 | return norm 159 | 160 | 161 | def get_method_config(args, setting): 162 | method = args.method 163 | 164 | base_config_dir = f"configs/{setting}/{method}.yaml" 165 | 166 | try: 167 | base_config = utils.get_yaml_config(base_config_dir) 168 | except Exception as exc: 169 | raise ValueError(f"{setting}/{method} is not a valid training configuration") from exc 170 | 171 | params = base_config["parameters"] 172 | 173 | # Update base config with specific parameters 174 | 175 | # Set normalization layers 176 | config_norm = get_config_norm(params) 177 | norm_to_use = config_norm if args.norm is None else args.norm 178 | norm_types = ["IN", "BN", "GN"] 179 | if norm_to_use not in norm_types: 180 | raise ValueError( 181 | f"Norm {norm_to_use} is not valid. Supported normalization types: IN, BN, GN."
182 | ) 183 | 184 | base_bn_type = base_config["parameters"]["bn_type"] 185 | base_block_bn_type = base_config["parameters"]["block_bn_type"] 186 | for n in norm_types: 187 | if n in base_bn_type: 188 | base_bn = n 189 | if n in base_block_bn_type: 190 | base_block_bn = n 191 | 192 | if setting == "quantized" and norm_to_use == "BN": 193 | base_config["parameters"]["bn_type"] = "QuantStandardBN" 194 | base_config["parameters"]["block_bn_type"] = "QuantStandardBN" 195 | else: 196 | base_config["parameters"]["bn_type"] = base_bn_type.replace( 197 | base_bn, norm_to_use 198 | ) 199 | base_config["parameters"]["block_bn_type"] = base_block_bn_type.replace( 200 | base_block_bn, norm_to_use 201 | ) 202 | # Set track_running_stats to True if using BN 203 | if norm_to_use == "BN": 204 | base_norm_kwargs = base_config["parameters"].get( 205 | "norm_kwargs", None 206 | ) 207 | if base_norm_kwargs is not None: 208 | base_norm_kwargs["track_running_stats"] = True 209 | else: 210 | base_config["parameters"]["norm_kwargs"] = { 211 | "track_running_stats": True 212 | } 213 | 214 | # Update dataset 215 | dataset = args.dataset 216 | base_config["parameters"]["dataset"] = dataset 217 | 218 | # Update model 219 | model_name = args.model.lower() 220 | base_config["parameters"]["model_config"]["model_class"] = model_name 221 | model_kwargs = params.get("model_kwargs", None) 222 | if model_kwargs is not None: 223 | base_model_kwargs = base_config["parameters"]["model_config"].get( 224 | "model_kwargs", None 225 | ) 226 | if base_model_kwargs is not None: 227 | base_model_kwargs.update(model_kwargs) 228 | else: 229 | base_config["parameters"]["model_config"]["model_kwargs"] = dict(model_kwargs) 230 | 231 | # Remove channel_selection_active for models without this parameter 232 | if model_name not in ("cpreresnet20",): 233 | base_model_kwargs = base_config["parameters"]["model_config"][ 234 | "model_kwargs" 235 | ] 236 | base_model_kwargs.pop("channel_selection_active", None) 237 | 238 | # For
unstructured sparsity, if using GN, set num_groups to 32 239 | # We cannot use GN with structured sparsity (number of channels isn't always 240 | # divisible by 32), we use IN instead. 241 | if norm_to_use == "GN": 242 | if setting in ("unstructured_sparsity", "quantized"): 243 | num_groups = 32 244 | else: 245 | raise NotImplementedError( 246 | f"GroupNorm disabled for setting={setting}." 247 | ) 248 | base_norm_kwargs = base_config["parameters"].get("norm_kwargs", None) 249 | if base_norm_kwargs is not None: 250 | base_norm_kwargs["num_groups"] = num_groups 251 | else: 252 | base_config["parameters"]["norm_kwargs"] = { 253 | "num_groups": num_groups 254 | } 255 | 256 | # Set default model training parameters 257 | model_training_params = training_params.model_data_params(args) 258 | base_config["parameters"].update(model_training_params) 259 | 260 | # Update training parameters with user-specified ones 261 | if setting == "unstructured_sparsity": 262 | args_dict = unstructured_args_dict(args) 263 | elif setting == "structured_sparsity": 264 | args_dict = structured_args_dict(args, base_config) 265 | elif setting == "quantized": 266 | args_dict = quantized_args_dict(args) 267 | else: 268 | args_dict = {} 269 | 270 | base_config["parameters"].update(args_dict) 271 | 272 | # Get/make save directory 273 | args_save_dir = args.save_dir 274 | if args_save_dir is None: 275 | config_save_dir = params["save_dir"] 276 | save_dir = utils.create_save_dir( 277 | config_save_dir, method, model_name, dataset, norm_to_use 278 | ) 279 | else: 280 | save_dir = utils.create_save_dir( 281 | args_save_dir, method, model_name, dataset, norm_to_use, False 282 | ) 283 | 284 | base_config["parameters"]["save_dir"] = save_dir 285 | 286 | # Print training details 287 | utils.print_train_params( 288 | base_config, setting, method, norm_to_use, save_dir 289 | ) 290 | 291 | return base_config 292 | -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | import train_curve 6 | import train_indep 7 | from get_training_params import get_method_config 8 | 9 | 10 | def train(args, setting): 11 | config = get_method_config(args, setting) 12 | if config["parameters"]["script"] == "train_curve.py": 13 | train_curve.train_model(config) 14 | else: 15 | train_indep.train_model(config) 16 | -------------------------------------------------------------------------------- /model_logging.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | import utils 12 | from curve_utils import alpha_bit_map 13 | from models.networks import model_profiling 14 | 15 | 16 | def save_model_at_epoch(model, epoch, save_dir, save_per=10): 17 | """ 18 | Saves the given model as 'model_{epoch}' 19 | """ 20 | if epoch % save_per == 0: 21 | save_path = os.path.join(save_dir, f"model_{epoch}") 22 | torch.save(model.state_dict(), save_path) 23 | else: 24 | print( 25 | f"Skipping saving at epoch {epoch} (saving every {save_per} epochs)." 
26 | ) 27 | 28 | 29 | def save(save_name, log, save_dir): 30 | save_path = os.path.join(save_dir, save_name) 31 | np.save(save_path, log) 32 | 33 | 34 | def _handle_bn_stats(metric_dict, extra_metrics, regime_params): 35 | if metric_dict is None: 36 | return None 37 | 38 | if extra_metrics is None or "bn_metrics" not in extra_metrics: 39 | print("No BN metrics found, skipping BN stats.") 40 | return metric_dict # keep logging the other metrics without BN stats 41 | 42 | bn_metrics = extra_metrics["bn_metrics"] 43 | 44 | for name, value in bn_metrics.items(): 45 | if name not in metric_dict: 46 | metric_dict[name] = [] 47 | metric_dict[name].append(value) 48 | return metric_dict 49 | 50 | 51 | def sparse_logging( 52 | model, 53 | loss, 54 | acc, 55 | model_type, 56 | param=None, 57 | metric_dict=None, 58 | extra_metrics=None, 59 | **regime_params, 60 | ): 61 | metric_dict = _handle_bn_stats(metric_dict, extra_metrics, regime_params) 62 | 63 | sparsity = utils.get_sparsity_rate(model) 64 | 65 | param_name = "alpha" if model_type == "curve" else "topk" 66 | param_print = "" if param is None else f" ({param_name} = {param})" 67 | print( 68 | f"Test set{param_print}: Average loss: {loss:.4f} | Accuracy: {acc:.4f} | Sparsity: {sparsity:.4f}" 69 | ) 70 | 71 | if metric_dict is not None: 72 | metric_dict["acc"].append(acc) 73 | metric_dict["sparsity"].append(sparsity) 74 | if param_name not in metric_dict: 75 | metric_dict[param_name] = [] 76 | metric_dict[param_name].append(param) 77 | 78 | return metric_dict 79 | 80 | 81 | def quantized_logging( 82 | model, 83 | loss, 84 | acc, 85 | model_type, 86 | param=None, 87 | metric_dict=None, 88 | extra_metrics=None, 89 | **regime_params, 90 | ): 91 | metric_dict = _handle_bn_stats(metric_dict, extra_metrics, regime_params) 92 | 93 | if model_type == "curve": 94 | param_name = "alpha" 95 | inv_alpha = regime_params.get("inv_alpha", False) 96 | if inv_alpha: 97 | num_bits = alpha_bit_map(1 - param, **regime_params) 98 | else: 99 | num_bits = alpha_bit_map(param, **regime_params) 100 | param_print = f"
(alpha/num_bits = {param}/{num_bits})" 101 | else: 102 | param_name = "num_bits" 103 | num_bits = param 104 | param_print = "" if param is None else f" ({param_name} = {num_bits})" 105 | 106 | print( 107 | f"Test set{param_print}: Average loss: {loss:.4f} | Accuracy: {acc:.4f}" 108 | ) 109 | 110 | if metric_dict is not None: 111 | metric_dict["acc"].append(acc) 112 | metric_dict["num_bits"].append(num_bits) 113 | if param_name != "num_bits": # Append alpha for curves 114 | metric_dict[param_name].append(param) 115 | 116 | return metric_dict 117 | 118 | 119 | def lec_logging( 120 | model, 121 | loss, 122 | acc, 123 | model_type, 124 | param=None, 125 | metric_dict=None, 126 | extra_metrics=None, 127 | **regime_params, 128 | ): 129 | metric_dict = _handle_bn_stats(metric_dict, extra_metrics, regime_params) 130 | 131 | sparsity = regime_params["sparsity"] 132 | 133 | param_name = "alpha" if model_type == "curve" else "topk" 134 | param_print = "" if param is None else f" ({param_name} = {param})" 135 | print( 136 | f"Test set{param_print}: Average loss: {loss:.4f} | Accuracy: {acc:.4f} | Sparsity: {sparsity:.4f}" 137 | ) 138 | 139 | if metric_dict is not None: 140 | metric_dict["acc"].append(acc) 141 | metric_dict["sparsity"].append(sparsity) 142 | if param_name not in metric_dict: 143 | metric_dict[param_name] = [] 144 | metric_dict[param_name].append(param) 145 | 146 | return metric_dict 147 | 148 | 149 | def ns_logging( 150 | model, 151 | loss, 152 | acc, 153 | model_type, 154 | param=None, 155 | metric_dict=None, 156 | extra_metrics=None, 157 | **regime_params, 158 | ): 159 | metric_dict = _handle_bn_stats(metric_dict, extra_metrics, regime_params) 160 | 161 | input_size = regime_params["input_size"] 162 | 163 | is_cuda = next(model.parameters()).is_cuda 164 | n_macs, n_params = model_profiling.model_profiling( 165 | model, input_size, input_size, use_cuda=is_cuda 166 | ) 167 | 168 | total_params = 0 169 | for module in model.modules(): 170 | if isinstance( 171 | 
module, (nn.Linear, nn.Conv2d) 172 | ): # Make sure we also count contributions from non-sparse elements. 173 | total_params += module.weight.numel() 174 | 175 | sparsity = (total_params - n_params) / total_params 176 | 177 | param_name = "width_factor" 178 | param_print = f" ({param_name} = {regime_params[param_name]})" 179 | print( 180 | f"Test set{param_print}: Average loss: {loss:.4f} | Accuracy: {acc:.4f} | Sparsity: {sparsity:.4f}" 181 | ) 182 | 183 | if metric_dict is not None: 184 | metric_dict["acc"].append(acc) 185 | metric_dict["sparsity"].append(sparsity) 186 | metric_dict["alpha"].append(param) 187 | if param_name not in metric_dict: 188 | metric_dict[param_name] = [] 189 | metric_dict[param_name].append(regime_params[param_name]) 190 | 191 | return metric_dict 192 | 193 | 194 | def us_logging(*args, **kwargs): 195 | return ns_logging(*args, **kwargs) 196 | -------------------------------------------------------------------------------- /models/builder.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from . import init 7 | from . import modules 8 | from . import quantized_modules 9 | from . 
import sparse_modules 10 | 11 | 12 | class Builder: 13 | def __init__( 14 | self, 15 | conv_type="LinesConv", 16 | bn_type="LinesBN", 17 | conv_init="kaiming_normal", 18 | norm_kwargs=None, 19 | pass_first_last=False, 20 | **conv_kwargs 21 | ): 22 | if norm_kwargs is None: 23 | norm_kwargs = {} 24 | self.pass_first_last = pass_first_last 25 | 26 | self.conv_kwargs = conv_kwargs 27 | self.norm_kwargs = norm_kwargs 28 | 29 | if hasattr(modules, bn_type): 30 | self.bn_layer = getattr(modules, bn_type) 31 | elif hasattr(sparse_modules, bn_type): 32 | self.bn_layer = getattr(sparse_modules, bn_type) 33 | elif hasattr(quantized_modules, bn_type): 34 | self.bn_layer = getattr(quantized_modules, bn_type) 35 | else: 36 | raise ValueError(f"Normalization layer {bn_type!r} not found") 37 | 38 | if hasattr(modules, conv_type): 39 | self.conv_layer = getattr(modules, conv_type) 40 | elif hasattr(sparse_modules, conv_type): 41 | self.conv_layer = getattr(sparse_modules, conv_type) 42 | elif hasattr(quantized_modules, conv_type): 43 | self.conv_layer = getattr(quantized_modules, conv_type) 44 | self.conv_kwargs[ 45 | "bn_module" 46 | ] = self.bn_layer # self.bn_layer chosen above 47 | self.conv_kwargs["norm_kwargs"] = norm_kwargs 48 | # Overwrite self.bn_layer to no-op batchnorm since handled by ConvBn 49 | self.bn_layer = getattr(quantized_modules, "NoOpBN") 50 | else: 51 | raise ValueError(f"Convolution layer {conv_type!r} not found") 52 | 53 | self.conv_init = getattr(init, conv_init) 54 | 55 | def conv( 56 | self, 57 | kernel_size, 58 | in_planes, 59 | out_planes, 60 | stride=1, 61 | groups=1, 62 | first_layer=False, 63 | last_layer=False, 64 | is_conv=False, 65 | ): 66 | conv_kwargs = self.conv_kwargs.copy() 67 | if self.pass_first_last: 68 | conv_kwargs["first_layer"] = first_layer 69 | conv_kwargs["last_layer"] = last_layer 70 | 71 | if kernel_size == 1: 72 | conv = self.conv_layer( 73 | in_planes, 74 | out_planes, 75 | kernel_size=1, 76 | stride=stride, 77 | groups=groups, bias=False, 78 | **conv_kwargs 79 | ) 80 |
        elif kernel_size == 3:
            conv = self.conv_layer(
                in_channels=in_planes,
                out_channels=out_planes,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=groups,
                bias=False,
                **conv_kwargs
            )
        elif kernel_size == 5:
            conv = self.conv_layer(
                in_planes,
                out_planes,
                kernel_size=5,
                stride=stride,
                padding=2,
                groups=groups,
                bias=False,
                **conv_kwargs
            )
        elif kernel_size == 7:
            conv = self.conv_layer(
                in_planes,
                out_planes,
                kernel_size=7,
                stride=stride,
                padding=3,
                groups=groups,
                bias=False,
                **conv_kwargs
            )
        else:
            return None

        conv.first_layer = first_layer
        conv.last_layer = last_layer
        conv.is_conv = is_conv
        self.conv_init(conv.weight)
        if hasattr(conv, "initialize"):
            conv.initialize(self.conv_init)
        return conv

    def conv1x1(
        self,
        in_planes,
        out_planes,
        stride=1,
        groups=1,
        first_layer=False,
        last_layer=False,
        is_conv=False,
    ):
        """1x1 convolution (a 1x1 kernel needs no padding)"""
        c = self.conv(
            1,
            in_planes,
            out_planes,
            stride=stride,
            groups=groups,
            first_layer=first_layer,
            last_layer=last_layer,
            is_conv=is_conv,
        )

        return c

    def conv3x3(
        self,
        in_planes,
        out_planes,
        stride=1,
        groups=1,
        first_layer=False,
        last_layer=False,
        is_conv=False,
    ):
        """3x3 convolution with padding"""
        c = self.conv(
            3,
            in_planes,
            out_planes,
            stride=stride,
            groups=groups,
            first_layer=first_layer,
            last_layer=last_layer,
            is_conv=is_conv,
        )
        return c

    def conv5x5(
        self,
        in_planes,
        out_planes,
        stride=1,
        groups=1,
        first_layer=False,
        last_layer=False,
        is_conv=False,
    ):
        """5x5 convolution with padding"""
        c = self.conv(
            5,
            in_planes,
            out_planes,
            stride=stride,
            groups=groups,
            first_layer=first_layer,
            last_layer=last_layer,
            is_conv=is_conv,
        )
        return c

    def conv7x7(
        self,
        in_planes,
        out_planes,
        stride=1,
        groups=1,
        first_layer=False,
        last_layer=False,
        is_conv=False,
    ):
        """7x7 convolution with padding"""
        c = self.conv(
            7,
            in_planes,
            out_planes,
            stride=stride,
            groups=groups,
            first_layer=first_layer,
            last_layer=last_layer,
            is_conv=is_conv,
        )
        return c

    def batchnorm(self, planes):
        return self.bn_layer(planes, **self.norm_kwargs)

--------------------------------------------------------------------------------
/models/init.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
import torch.nn as nn


def kaiming_normal(weight):
    nn.init.kaiming_normal_(
        weight,
    )

--------------------------------------------------------------------------------
/models/modules.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

# Convolutions
StandardConv = nn.Conv2d


class SubspaceConv(nn.Conv2d):
    def forward(self, x):
        # call get_weight, which samples from the subspace, then use the
        # corresponding weight.
        w = self.get_weight()
        x = F.conv2d(
            x,
            w,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x


class TwoParamConv(SubspaceConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))

    def initialize(self, initialize_fn):
        initialize_fn(self.weight)
        initialize_fn(self.weight1)


class LinesConv(TwoParamConv):
    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return w


# BatchNorms
StandardBN = nn.BatchNorm2d


class SubspaceBN(nn.BatchNorm2d):
    def forward(self, input):
        # call get_weight, which samples from the subspace, then use the
        # corresponding weight.
        w, b = self.get_weight()

        # The rest follows the BatchNorm forward pass in the PyTorch source.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(
                        self.num_batches_tracked
                    )
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (
                self.running_var is None
            )
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be
            # updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var
            if not self.training or self.track_running_stats
            else None,
            w,
            b,
            bn_training,
            exponential_average_factor,
            self.eps,
        )


class TwoParamBN(SubspaceBN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.empty([self.num_features]))
        self.bias1 = nn.Parameter(torch.empty([self.num_features]))
        torch.nn.init.ones_(self.weight1)
        torch.nn.init.zeros_(self.bias1)


class LinesBN(TwoParamBN):
    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        b = (1 - self.alpha) * self.bias + self.alpha * self.bias1
        return w, b


# InstanceNorm
def StandardIN(*args, affine=True, **kwargs):
    return nn.InstanceNorm2d(*args, affine=affine, **kwargs)


def _process_num_groups(num_groups: Union[str, int], num_channels: int) -> int:
    if num_groups == "full":
        num_groups = num_channels  # Set it equal to num_features.
    else:
        num_groups = int(num_groups)

    # If num_groups is greater than num_features, we reduce it.
    num_groups = min(num_channels, num_groups)
    return num_groups


class SubspaceIN(nn.InstanceNorm2d):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
    ) -> None:
        # Override @affine to be true.
        super().__init__(
            num_features,
            eps=eps,
            momentum=momentum,
            affine=True,
            track_running_stats=track_running_stats,
        )

    def forward(self, input):
        # call get_weight, which samples from the subspace, then use the
        # corresponding weight.
        w, b = self.get_weight()

        # The rest follows the InstanceNorm forward pass in the PyTorch source.
        assert self.running_mean is None or isinstance(
            self.running_mean, torch.Tensor
        )
        assert self.running_var is None or isinstance(
            self.running_var, torch.Tensor
        )
        return F.instance_norm(
            input,
            self.running_mean,
            self.running_var,
            w,
            b,
            self.training or not self.track_running_stats,
            self.momentum,
            self.eps,
        )


class TwoParamIN(SubspaceIN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.empty([self.num_features]))
        self.bias1 = nn.Parameter(torch.empty([self.num_features]))
        torch.nn.init.ones_(self.weight1)
        torch.nn.init.zeros_(self.bias1)


class LinesIN(TwoParamIN):
    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        b = (1 - self.alpha) * self.bias + self.alpha * self.bias1
        return w, b


# GroupNorm
def StandardGN(*args, affine=True, **kwargs):
    num_groups = kwargs.pop("num_groups", "full")
    num_groups = _process_num_groups(num_groups, args[0])
    return nn.GroupNorm(num_groups, *args, affine=affine, **kwargs)


class SubspaceGN(nn.GroupNorm):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        *,
        num_groups: Union[str, int],
    ) -> None:

        num_groups = _process_num_groups(num_groups, num_features)

        # Override @affine to be true.
        super().__init__(
            num_groups,
            num_features,
            eps=eps,
            affine=True,
        )
        self.num_features = num_features

    def forward(self, input):
        # call get_weight, which samples from the subspace, then use the
        # corresponding weight.
        w, b = self.get_weight()
        return F.group_norm(input, self.num_groups, w, b, self.eps)


class TwoParamGN(SubspaceGN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.empty([self.num_features]))
        self.bias1 = nn.Parameter(torch.empty([self.num_features]))
        torch.nn.init.ones_(self.weight1)
        torch.nn.init.zeros_(self.bias1)


class LinesGN(TwoParamGN):
    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        b = (1 - self.alpha) * self.bias + self.alpha * self.bias1
        return w, b


def _get_num_parameters(conv):
    in_channels = conv.in_channels
    out_channels = conv.out_channels

    if hasattr(conv, "in_channels_list"):
        in_channels_ratio = in_channels / max(conv.in_channels_list)
        out_channels_ratio = out_channels / max(conv.out_channels_list)
    else:
        in_channels_ratio = in_channels / conv.in_channels_max
        out_channels_ratio = out_channels / conv.out_channels_max

    ret = conv.weight.numel()
    ret = max(1, round(ret * in_channels_ratio * out_channels_ratio))
    return ret


# Adaptive modules (which adjust their number of channels at inference time).

# This code contains the norm implementation used in our unstructured sparsity
# experiments and baselines. Note that normally, we disallow storing or
# recomputing BatchNorm statistics. However, we retain the ability to store
# individual BatchNorm statistics purely for sanity-checking purposes (to ensure
# our implementation produces similar results to the Universal Slimming paper,
# when BatchNorms are stored). We don't use these results in any analysis.
class AdaptiveNorm(nn.modules.batchnorm._NormBase):
    def __init__(
        self,
        bn_class,
        bn_func,
        mode,
        *args,
        ratio=1,
        width_factors_list=None,
        **kwargs,
    ):
        assert mode in ("BatchNorm", "InstanceNorm", "GroupNorm")

        kwargs_cpy = kwargs.copy()
        try:
            track_running_stats = kwargs_cpy.pop("track_running_stats")
        except KeyError:
            track_running_stats = False

        try:
            self.num_groups = kwargs_cpy.pop("num_groups")
        except KeyError:
            self.num_groups = None

        super().__init__(
            *args,
            affine=True,
            track_running_stats=track_running_stats,
            **kwargs_cpy,
        )

        num_features = args[0]
        self.width_factors_list = width_factors_list
        self.num_features_max = num_features
        if mode == "BatchNorm" and self.width_factors_list is not None:
            print(
                "Storing extra BatchNorm layers. This should only be used "
                "for sanity checking, since it violates our goal of "
                "arbitrarily fine-grained compression levels at inference "
                "time."
            )
            self.bn = nn.ModuleList(
                [
                    bn_class(i, affine=False)
                    for i in [
                        max(1, round(self.num_features_max * width_factor))
                        for width_factor in self.width_factors_list
                    ]
                ]
            )
        if mode == "GroupNorm":
            if self.num_groups is None:
                raise ValueError("num_groups is required")
            if self.num_groups not in ("full", 1):
                # This must be "full" or 1, or the tensor might not be divisible
                # by @self.num_groups.
                raise ValueError(f"Invalid num_groups={self.num_groups}")

        self.ratio = ratio
        self.width_factor = None
        self.ignore_model_profiling = True
        self.bn_func = bn_func
        self.mode = mode

    def get_weight(self):
        return self.weight, self.bias

    def forward(self, input):
        weight, bias = self.get_weight()
        c = input.shape[1]
        if (
            self.mode == "BatchNorm"
            and self.width_factors_list is not None
            and self.width_factor in self.width_factors_list
        ):
            # Normally, we expect width_factors_list to be None or empty,
            # because we only use it when running sanity checks (e.g.
            # reproducing the original Universal Slimming results).
            idx = self.width_factors_list.index(self.width_factor)
            kwargs = {
                "input": input,
                "running_mean": self.bn[idx].running_mean[:c],
                "running_var": self.bn[idx].running_var[:c],
                "weight": weight[:c],
                "bias": bias[:c],
                "training": self.training,
                "momentum": self.momentum,
                "eps": self.eps,
            }
        elif self.mode in ("InstanceNorm", "BatchNorm"):
            # Sanity check, since we're not tracking running stats.
            running_mean = self.running_mean
            if self.running_mean is not None:
                running_mean = running_mean[:c]

            running_var = self.running_var
            if self.running_var is not None:
                running_var = running_var[:c]

            kwargs = {
                "input": input,
                "running_mean": running_mean,
                "running_var": running_var,
                "weight": weight[:c],
                "bias": bias[:c],
                "momentum": self.momentum,
                "eps": self.eps,
            }

            if self.mode == "BatchNorm":
                kwargs["training"] = self.training

        elif self.mode == "GroupNorm":
            num_groups = self.num_groups
            if num_groups == "full":
                num_groups = c
            kwargs = {
                "input": input,
                "num_groups": num_groups,
                "weight": weight[:c],
                "bias": bias[:c],
                "eps": self.eps,
            }
        else:
            raise NotImplementedError(f"Invalid mode {self.mode}.")

        return self.bn_func(**kwargs)


class AdaptiveBN(AdaptiveNorm):
    def __init__(self, *args, **kwargs):
        norm_class = nn.BatchNorm2d
        norm_func = F.batch_norm
        super().__init__(norm_class, norm_func, "BatchNorm", *args, **kwargs)


class AdaptiveIN(AdaptiveNorm):
    def __init__(self, *args, **kwargs):
        norm_class = nn.InstanceNorm2d
        norm_func = F.instance_norm
        super().__init__(norm_class, norm_func, "InstanceNorm", *args, **kwargs)


class LinesAdaptiveIN(AdaptiveIN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.Tensor(self.num_features))
        self.bias1 = nn.Parameter(torch.Tensor(self.num_features))
        torch.nn.init.ones_(self.weight1)
        torch.nn.init.zeros_(self.bias1)

    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        b = (1 - self.alpha) * self.bias + self.alpha * self.bias1
        return w, b


class AdaptiveConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        first_layer=False,
        last_layer=False,
        ratio=None,
    ):
        self.first_layer = first_layer
        self.last_layer = last_layer

        if ratio is None:
            ratio = [1, 1]

        super(AdaptiveConv2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        if groups == in_channels:
            assert in_channels == out_channels
            self.depthwise = True
        else:
            self.depthwise = False

        self.in_channels_max = in_channels
        self.out_channels_max = out_channels
        self.width_factor = None
        self.ratio = ratio

    def get_weight(self):
        return self.weight

    def forward(self, input):
        if not self.first_layer:
            self.in_channels = input.shape[1]
        if not self.last_layer:
            self.out_channels = max(
                1, round(self.out_channels_max * self.width_factor)
            )
        self.groups = self.in_channels if self.depthwise else 1
        weight = self.get_weight()
        weight = weight[: self.out_channels, : self.in_channels, :, :]
        assert self.bias is None
        bias = None
        y = nn.functional.conv2d(
            input,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return y

    def get_num_parameters(self):
        return _get_num_parameters(self)


class LinesAdaptiveConv2d(AdaptiveConv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.empty_like(self.weight))
        assert self.bias is None
        torch.nn.init.ones_(self.weight1)

    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return w

--------------------------------------------------------------------------------
/models/networks/__init__.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
from __future__ import absolute_import

from .channel_selection import *  # pylint: disable=unused-import
# CIFAR-specific
from .cpreresnet import *  # pylint: disable=unused-import
from .resnet import *  # pylint: disable=unused-import
from .vgg import *  # pylint: disable=unused-import

--------------------------------------------------------------------------------
/models/networks/channel_selection.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
# Only used for Learning Efficient Convolutions (LEC) experiments.
# Deactivated in other cases.

import numpy as np
import torch
import torch.nn as nn


class channel_selection(nn.Module):
    """
    Select channels from the output of a BatchNorm2d layer. It should be
    placed directly after a BatchNorm2d layer.

    The output shape of this layer is determined by the number of 1s in
    `self.indexes`.
    """

    def __init__(self, num_channels, active=True):
        """
        Initialize `indexes` as an all-ones vector whose length equals the
        number of channels. During pruning, the entries of `indexes` that
        correspond to pruned channels are set to 0.
        """
        super(channel_selection, self).__init__()
        self.indexes = nn.Parameter(torch.ones(num_channels))
        self.active = active

    def forward(self, input_tensor):
        """
        Parameters
        ----------
        input_tensor: (N, C, H, W) tensor. It should be the output of a
            BatchNorm2d layer.
        """
        if not self.active:
            return input_tensor
        selected_index = np.squeeze(
            np.argwhere(self.indexes.data.cpu().numpy())
        )
        if selected_index.size == 1:
            selected_index = np.resize(selected_index, (1,))
        output = input_tensor[:, selected_index, :, :]
        return output

    def __repr__(self):
        return f"{self.__class__.__name__}(active={self.active})"

--------------------------------------------------------------------------------
/models/networks/cpreresnet.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
from __future__ import absolute_import

import math

import torch.nn as nn

from .channel_selection import channel_selection

__all__ = ["cpreresnet"]

"""
Preactivation resnet with bottleneck design.
"""


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        cfg,
        stride=1,
        downsample=None,
        builder=None,
        channel_selection_active=True,
    ):
        super(Bottleneck, self).__init__()
        if builder is None:
            raise ValueError(f"Builder required, got None")

        self.bn1 = builder.batchnorm(inplanes)
        self.select = channel_selection(
            inplanes, active=channel_selection_active
        )
        self.conv1 = builder.conv1x1(cfg[0], cfg[1])
        self.bn2 = builder.batchnorm(cfg[1])
        self.conv2 = builder.conv3x3(cfg[1], cfg[2], stride=stride)
        self.bn3 = builder.batchnorm(cfg[2])
        self.conv3 = builder.conv1x1(cfg[2], planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.select(out)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class cpreresnet(nn.Module):
    def __init__(
        self,
        depth=20,
        dataset="cifar10",
        cfg=None,
        builder=None,
        block_builder=None,
        channel_selection_active=True,
    ):
        super(cpreresnet, self).__init__()
        # assert block_builder is None, "Should not provide a block_builder."
        if builder is None:
            raise ValueError(f"Expected builder, got None.")
        assert (depth - 2) % 9 == 0, "depth should be 9n+2"

        n = (depth - 2) // 9
        block = Bottleneck

        if cfg is None:
            # Construct config variable.
            cfg = [
                [16, 16, 16],
                [64, 16, 16] * (n - 1),
                [64, 32, 32],
                [128, 32, 32] * (n - 1),
                [128, 64, 64],
                [256, 64, 64] * (n - 1),
                [256],
            ]
            cfg = [item for sub_list in cfg for item in sub_list]

        self.inplanes = 16

        self.conv1 = builder.conv3x3(3, 16, first_layer=True)
        self.layer1 = self._make_layer(
            block,
            16,
            n,
            cfg=cfg[0 : 3 * n],
            builder=block_builder,
            channel_selection_active=channel_selection_active,
        )
        self.layer2 = self._make_layer(
            block,
            32,
            n,
            cfg=cfg[3 * n : 6 * n],
            stride=2,
            builder=block_builder,
            channel_selection_active=channel_selection_active,
        )
        self.layer3 = self._make_layer(
            block,
            64,
            n,
            cfg=cfg[6 * n : 9 * n],
            stride=2,
            builder=block_builder,
            channel_selection_active=channel_selection_active,
        )
        self.bn = builder.batchnorm(64 * block.expansion)
        self.select = channel_selection(
            64 * block.expansion, active=channel_selection_active
        )
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)

        # We work only with CIFAR-10
        num_categories = 10

        self.fc = builder.conv1x1(cfg[-1], num_categories, last_layer=True)

        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) and m != self.fc:
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                i = 1
                while hasattr(m, f"weight{i}"):
                    weight = getattr(m, f"weight{i}")
                    weight.data.normal_(0, math.sqrt(2.0 / n))
                    i += 1
            elif isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
                i = 1
                while hasattr(m, f"weight{i}"):
                    weight = getattr(m, f"weight{i}")
                    weight.data.fill_(0.5)
                    bias = getattr(m, f"bias{i}")
                    bias.data.zero_()
                    i += 1

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        cfg,
        stride=1,
        builder=None,
        channel_selection_active=True,
    ):
        if builder is None:
            raise ValueError(f"Expected builder, got None.")
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                builder.conv1x1(
                    self.inplanes, planes * block.expansion, stride=stride
                ),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                cfg[0:3],
                stride,
                downsample,
                builder=builder,
                channel_selection_active=channel_selection_active,
            )
        )
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    cfg[3 * i : 3 * (i + 1)],
                    builder=builder,
                    channel_selection_active=channel_selection_active,
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)  # 32x32
        x = self.layer2(x)  # 16x16
        x = self.layer3(x)  # 8x8
        x = self.bn(x)
        x = self.select(x)
        x = self.relu(x)

        x = self.avgpool(x)
        assert x.shape[2:] == (1, 1)
        x = self.fc(x)
        x = x.view(x.size(0), -1)

        return x

--------------------------------------------------------------------------------
/models/networks/model_profiling.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
import time

import numpy as np
import torch
import torch.nn as nn

model_profiling_hooks = []
model_profiling_speed_hooks = []

name_space = 95
params_space = 15
macs_space = 15
seconds_space = 15

num_forwards = 10


class Timer(object):
    def __init__(self, verbose=False):
        self.verbose = verbose
        self.start = None
        self.end = None

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.time = self.end - self.start  # In seconds.
        if self.verbose:
            print("Elapsed time: %f s." % self.time)


def get_params(module):
    """get number of params in module"""
    if hasattr(module, "width_factor"):
        # We assume it's a conv layer.
        ret = module.get_num_parameters()
    else:
        ret = np.sum([np.prod(list(w.size())) for w in module.parameters()])
    return ret


def run_forward(module, input):
    with Timer() as t:
        for _ in range(num_forwards):
            module.forward(*input)
            if input[0].is_cuda:
                torch.cuda.synchronize()
    return int(t.time * 1e9 / num_forwards)


def conv_module_name_filter(name):
    """filter module name to have a short view"""
    filters = {
        "kernel_size": "k",
        "stride": "s",
        "padding": "pad",
        "bias": "b",
        "groups": "g",
    }
    for k in filters:
        name = name.replace(k, filters[k])
    return name


def module_profiling(module, input, output, verbose):
    if not isinstance(input[0], list):
        # Some modules return a list of outputs. We usually ignore them.
        ins = input[0].size()
        outs = output.size()
    # NOTE: type() and isinstance() behave differently (isinstance also
    # matches subclasses), so please be careful.
    t = type(module)
    if isinstance(module, nn.Conv2d):
        module.n_macs = (
            ins[1]
            * outs[1]
            * module.kernel_size[0]
            * module.kernel_size[1]
            * outs[2]
            * outs[3]
            // module.groups
        ) * outs[0]
        module.n_params = get_params(module)
        module.n_seconds = run_forward(module, input)
        module.name = conv_module_name_filter(module.__repr__())
    elif isinstance(module, nn.ConvTranspose2d):
        module.n_macs = (
            ins[1]
            * outs[1]
            * module.kernel_size[0]
            * module.kernel_size[1]
            * outs[2]
            * outs[3]
            // module.groups
        ) * outs[0]
        module.n_params = get_params(module)
        module.n_seconds = run_forward(module, input)
        module.name = conv_module_name_filter(module.__repr__())
    elif isinstance(module, nn.Linear):
        module.n_macs = ins[1] * outs[1] * outs[0]
        module.n_params = get_params(module)
        module.n_seconds = run_forward(module, input)
        module.name = module.__repr__()
    elif isinstance(module, nn.AvgPool2d):
        # NOTE: this function is correct only when stride == kernel size
        module.n_macs = ins[1] * ins[2] * ins[3] * ins[0]
        module.n_params = 0
        module.n_seconds = run_forward(module, input)
        module.name = module.__repr__()
    elif isinstance(module, nn.AdaptiveAvgPool2d):
        # NOTE: this function is correct only when stride == kernel size
        module.n_macs = ins[1] * ins[2] * ins[3] * ins[0]
        module.n_params = 0
        module.n_seconds = run_forward(module, input)
        module.name = module.__repr__()
    else:
        # This works only in a depth-first traversal of modules.
        module.n_macs = 0
        module.n_params = 0
        module.n_seconds = 0
        num_children = 0
        for m in module.children():
            module.n_macs += getattr(m, "n_macs", 0)
            module.n_params += getattr(m, "n_params", 0)
            module.n_seconds += getattr(m, "n_seconds", 0)
            num_children += 1
        ignore_zeros_t = [
            nn.BatchNorm2d,
            nn.InstanceNorm2d,
            nn.Dropout2d,
            nn.Dropout,
            nn.Sequential,
            nn.ReLU6,
            nn.ReLU,
            nn.MaxPool2d,
            nn.modules.padding.ZeroPad2d,
            nn.modules.activation.Sigmoid,
        ]
        if (
            not getattr(module, "ignore_model_profiling", False)
            and module.n_macs == 0
            and t not in ignore_zeros_t
        ):
            print(
                "WARNING: leaf module {} has zero n_macs.".format(type(module))
            )
        return
    if verbose:
        print(
            module.name.ljust(name_space, " ")
            + "{:,}".format(module.n_params).rjust(params_space, " ")
            + "{:,}".format(module.n_macs).rjust(macs_space, " ")
            + "{:,}".format(module.n_seconds).rjust(seconds_space, " ")
        )
    return


def add_profiling_hooks(m, verbose):
    global model_profiling_hooks
    model_profiling_hooks.append(
        m.register_forward_hook(
            lambda m, input, output: module_profiling(
                m, input, output, verbose=verbose
            )
        )
    )


def remove_profiling_hooks():
    global model_profiling_hooks
    for h in model_profiling_hooks:
        h.remove()
    model_profiling_hooks = []


def model_profiling(
    model, height, width, batch=1, channel=3, use_cuda=True, verbose=True
):
    """PyTorch model profiling with input image size
    (batch, channel, height, width).
    The function counts the number of multiply-accumulates (n_macs) and
    parameters.
189 | 190 | Args: 191 | model: PyTorch model 192 | height: int 193 | width: int 194 | batch: int 195 | channel: int 196 | use_cuda: bool 197 | 198 | Returns: 199 | macs: int 200 | params: int 201 | 202 | """ 203 | if isinstance(model, nn.DataParallel): 204 | model = model.module 205 | model.eval() 206 | data = torch.rand(batch, channel, height, width) 207 | device = torch.device("cuda:0" if use_cuda else "cpu") 208 | model = model.to(device) 209 | data = data.to(device) 210 | model.apply(lambda m: add_profiling_hooks(m, verbose=verbose)) 211 | print( 212 | "Item".ljust(name_space, " ") 213 | + "params".rjust(params_space, " ") 214 | + "macs".rjust(macs_space, " ") 215 | + "nanosecs".rjust(seconds_space, " ") 216 | ) 217 | if verbose: 218 | print( 219 | "".center( 220 | name_space + params_space + macs_space + seconds_space, "-" 221 | ) 222 | ) 223 | model(data) 224 | if verbose: 225 | print( 226 | "".center( 227 | name_space + params_space + macs_space + seconds_space, "-" 228 | ) 229 | ) 230 | print( 231 | "Total".ljust(name_space, " ") 232 | + "{:,}".format(model.n_params).rjust(params_space, " ") 233 | + "{:,}".format(model.n_macs).rjust(macs_space, " ") 234 | + "{:,}".format(model.n_seconds).rjust(seconds_space, " ") 235 | ) 236 | remove_profiling_hooks() 237 | return model.n_macs, model.n_params 238 | -------------------------------------------------------------------------------- /models/networks/resnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
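The MAC formulas in `module_profiling` (models/networks/model_profiling.py, above) can be checked by hand. The sketch below restates the `nn.Conv2d` and `nn.Linear` branches with plain integers so it runs without PyTorch. The helper names (`conv2d_out_size`, `conv2d_macs`, `linear_macs`) are hypothetical and not part of the repository; `conv2d_out_size` follows the output-size formula documented for PyTorch's `Conv2d`, and the stem example assumes the usual padding of 3 for the 7x7 first convolution, since the builder's padding is not shown here.

```python
# Integer restatement of the Conv2d / Linear branches of module_profiling.
# Helper names are hypothetical; no PyTorch needed.


def conv2d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis, per the PyTorch Conv2d docs."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_macs(batch, c_in, c_out, k_h, k_w, h_out, w_out, groups=1):
    """ins[1] * outs[1] * k_h * k_w * outs[2] * outs[3] // groups, times batch."""
    return (c_in * c_out * k_h * k_w * h_out * w_out // groups) * batch


def linear_macs(batch, in_features, out_features):
    """ins[1] * outs[1] * outs[0] for nn.Linear."""
    return in_features * out_features * batch


# ResNet18 stem on a 224x224 RGB image: 3 -> 64 channels, 7x7, stride 2, pad 3.
side = conv2d_out_size(224, kernel=7, stride=2, padding=3)
stem = conv2d_macs(1, 3, 64, 7, 7, side, side)
print(side, f"{stem:,}")  # 112 118,013,952

# Dense vs depthwise 3x3 convolution at 64 channels and 56x56 resolution.
dense = conv2d_macs(1, 64, 64, 3, 3, 56, 56)
depthwise = conv2d_macs(1, 64, 64, 3, 3, 56, 56, groups=64)
print(f"{dense:,} vs {depthwise:,}")  # 115,605,504 vs 1,806,336
```

Because the profiler divides by `module.groups` after forming the full product, a depthwise convolution (groups equal to the channel count) is charged 1/64 of the dense cost in this example. A classifier implemented as a 1x1 convolution on a 1x1 feature map, as `ResNet.fc` is here, is charged exactly what an `nn.Linear` of the same shape would be.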
4 | # 5 | import torch.nn as nn 6 | 7 | 8 | # BasicBlock {{{ 9 | class BasicBlock(nn.Module): 10 | M = 2 11 | expansion = 1 12 | 13 | def __init__( 14 | self, 15 | builder, 16 | inplanes, 17 | planes, 18 | stride=1, 19 | downsample=None, 20 | base_width=64, 21 | ): 22 | super(BasicBlock, self).__init__() 23 | if base_width / 64 > 1: 24 | raise ValueError("Base width >64 does not work for BasicBlock") 25 | 26 | self.conv1 = builder.conv3x3(inplanes, planes, stride) 27 | self.bn1 = builder.batchnorm(planes) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv2 = builder.conv3x3(planes, planes) 30 | self.bn2 = builder.batchnorm(planes) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | if self.bn1 is not None: 39 | out = self.bn1(out) 40 | 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | 45 | if self.bn2 is not None: 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | # BasicBlock }}} 58 | 59 | # Bottleneck {{{ 60 | class Bottleneck(nn.Module): 61 | M = 3 62 | expansion = 4 63 | 64 | def __init__( 65 | self, 66 | builder, 67 | inplanes, 68 | planes, 69 | stride=1, 70 | downsample=None, 71 | base_width=64, 72 | ): 73 | super(Bottleneck, self).__init__() 74 | width = int(planes * base_width / 64) 75 | self.conv1 = builder.conv1x1(inplanes, width) 76 | self.bn1 = builder.batchnorm(width) 77 | self.conv2 = builder.conv3x3(width, width, stride=stride) 78 | self.bn2 = builder.batchnorm(width) 79 | self.conv3 = builder.conv1x1(width, planes * self.expansion) 80 | self.bn3 = builder.batchnorm(planes * self.expansion) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | residual = x 87 | 88 | out = self.conv1(x) 89 | out = self.bn1(out) 90 | out = 
self.relu(out) 91 | 92 | out = self.conv2(out) 93 | out = self.bn2(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv3(out) 97 | out = self.bn3(out) 98 | 99 | if self.downsample is not None: 100 | residual = self.downsample(x) 101 | 102 | out += residual 103 | 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | 109 | # Bottleneck }}} 110 | 111 | # ResNet {{{ 112 | class ResNet(nn.Module): 113 | def __init__( 114 | self, 115 | builder, 116 | block_builder, 117 | block, 118 | layers, 119 | num_classes=1000, 120 | base_width=64, 121 | ): 122 | self.inplanes = 64 123 | super(ResNet, self).__init__() 124 | 125 | self.base_width = base_width 126 | if self.base_width // 64 > 1: 127 | print(f"==> Using {self.base_width // 64}x wide model") 128 | 129 | self.conv1 = builder.conv7x7(3, 64, stride=2, first_layer=True) 130 | 131 | self.bn1 = builder.batchnorm(64) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 134 | self.layer1 = self._make_layer(block_builder, block, 64, layers[0]) 135 | self.layer2 = self._make_layer( 136 | block_builder, block, 128, layers[1], stride=2 137 | ) 138 | self.layer3 = self._make_layer( 139 | block_builder, block, 256, layers[2], stride=2 140 | ) 141 | self.layer4 = self._make_layer( 142 | block_builder, block, 512, layers[3], stride=2 143 | ) 144 | self.avgpool = nn.AdaptiveAvgPool2d(1) 145 | self.return_feats = False 146 | 147 | self.fc = builder.conv1x1( 148 | 512 * block.expansion, num_classes, last_layer=True 149 | ) 150 | 151 | def _make_layer(self, builder, block, planes, blocks, stride=1): 152 | downsample = None 153 | if stride != 1 or self.inplanes != planes * block.expansion: 154 | dconv = builder.conv1x1( 155 | self.inplanes, planes * block.expansion, stride=stride 156 | ) 157 | dbn = builder.batchnorm(planes * block.expansion) 158 | if dbn is not None: 159 | downsample = nn.Sequential(dconv, dbn) 160 | else: 161 | downsample = dconv 162 | 163 | layers = [] 
164 | layers.append( 165 | block( 166 | builder, 167 | self.inplanes, 168 | planes, 169 | stride, 170 | downsample, 171 | base_width=self.base_width, 172 | ) 173 | ) 174 | self.inplanes = planes * block.expansion 175 | for i in range(1, blocks): 176 | layers.append( 177 | block( 178 | builder, self.inplanes, planes, base_width=self.base_width 179 | ) 180 | ) 181 | 182 | return nn.Sequential(*layers) 183 | 184 | def forward(self, x): 185 | x = self.conv1(x) 186 | 187 | if self.bn1 is not None: 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | x = self.maxpool(x) 191 | 192 | x = self.layer1(x) 193 | x = self.layer2(x) 194 | x = self.layer3(x) 195 | x = self.layer4(x) 196 | 197 | feats = self.avgpool(x) 198 | x = self.fc(feats) 199 | x = x.view(x.size(0), -1) 200 | 201 | if self.return_feats: 202 | return x, feats.view(feats.size(0), -1) 203 | return x 204 | 205 | 206 | def _get_output_size(dataset: str): 207 | return { 208 | "cifar10": 10, 209 | "imagenet": 1000, 210 | }[dataset] 211 | 212 | 213 | # ResNet }}} 214 | def ResNet18(builder, block_builder, dataset): 215 | return ResNet( 216 | builder, 217 | block_builder, 218 | BasicBlock, 219 | [2, 2, 2, 2], 220 | _get_output_size(dataset), 221 | ) 222 | -------------------------------------------------------------------------------- /models/networks/resprune.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | # Note: this only works for cPreResNet20 (implementation provided by LEC paper). 
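To see how `ResNet._make_layer` in models/networks/resnet.py (above) decides where a projection shortcut is needed, and why `ResNet18` carries that name, the hypothetical helper below traces output shapes stage by stage without PyTorch. It assumes that `builder.conv7x7` pads by 3 and `builder.conv3x3` pads by 1 (the builder is not shown here). The `[3, 4, 6, 3]` Bottleneck layout is the standard ResNet-50 configuration; this file itself only defines `ResNet18`.

```python
# Hypothetical shape tracer for the ResNet class above (no PyTorch needed).
# Assumes conv7x7 pads by 3 and conv3x3 pads by 1, as in torchvision's ResNet.


def trace_resnet(layers, expansion, convs_per_block, image_size=224):
    """Return per-stage (channels, H, W), the number of 1x1 downsample
    projections, and the conventional depth (conv1 + block convs + fc)."""
    size = (image_size + 2 * 3 - 7) // 2 + 1  # conv1: 7x7, stride 2, pad 3
    size = (size + 2 * 1 - 3) // 2 + 1  # maxpool: 3x3, stride 2, pad 1
    inplanes, downsamples, shapes = 64, 0, []
    for planes, stride in zip((64, 128, 256, 512), (1, 2, 2, 2)):
        # Same test as ResNet._make_layer. Only the first block of a stage
        # can change resolution or width, so only it can get a projection.
        if stride != 1 or inplanes != planes * expansion:
            downsamples += 1
        size = (size + 2 * 1 - 3) // stride + 1  # strided 3x3 conv, pad 1
        inplanes = planes * expansion
        shapes.append((inplanes, size, size))
    depth = 1 + sum(layers) * convs_per_block + 1
    return shapes, downsamples, depth


# BasicBlock: expansion 1, two 3x3 convs; Bottleneck: expansion 4, three convs.
r18 = trace_resnet([2, 2, 2, 2], expansion=1, convs_per_block=2)
r50 = trace_resnet([3, 4, 6, 3], expansion=4, convs_per_block=3)
print(r18)  # ([(64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)], 3, 18)
print(r50[0][-1], r50[1], r50[2])  # (2048, 7, 7) 4 50
```

For ResNet18 the first stage keeps 64 channels at stride 1, so only stages 2 to 4 receive a 1x1 downsample. With `Bottleneck` (`expansion = 4`) even the first stage needs one, because 64 input channels do not match its 256 outputs. The final 2048 channels match the `512 * block.expansion` input width of `self.fc`.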
6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | from models import networks 12 | 13 | from .cpreresnet import * 14 | 15 | 16 | def get_slimmed_network(model, model_kwargs, cfg, cfg_mask): 17 | assert not isinstance(model, nn.DataParallel), f"Must unwrap DataParallel" 18 | 19 | is_cuda = next(model.parameters()).is_cuda 20 | print("Cfg:") 21 | print(cfg) 22 | 23 | newmodel = cpreresnet(cfg=cfg, **model_kwargs) 24 | 25 | if is_cuda: 26 | newmodel.cuda() 27 | 28 | old_modules = list(model.modules()) 29 | new_modules = list(newmodel.modules()) 30 | layer_id_in_cfg = 0 31 | start_mask = torch.ones(3) 32 | end_mask = cfg_mask[layer_id_in_cfg] 33 | conv_count = 0 34 | 35 | for layer_id in range(len(old_modules)): 36 | m0 = old_modules[layer_id] 37 | m1 = new_modules[layer_id] 38 | 39 | if isinstance(m0, (nn.modules.batchnorm._NormBase, nn.GroupNorm)): 40 | idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) 41 | if idx1.size == 1: 42 | idx1 = np.resize(idx1, (1,)) 43 | 44 | if isinstance( 45 | old_modules[layer_id + 1], networks.channel_selection 46 | ): 47 | # If the next layer is the channel selection layer, then the 48 | # current batchnorm 2d layer won't be pruned. 49 | m1.weight.data = m0.weight.data.clone() 50 | m1.bias.data = m0.bias.data.clone() 51 | if m0.running_mean is None: 52 | m1.running_mean = m0.running_mean 53 | m1.running_var = m0.running_var 54 | else: 55 | m1.running_mean.data = m0.running_mean.clone() 56 | m1.running_var.data = m0.running_var.clone() 57 | 58 | # We need to set the channel selection layer. 
59 | m2 = new_modules[layer_id + 1] 60 | m2.indexes.data.zero_() 61 | m2.indexes.data[idx1.tolist()] = 1.0 62 | 63 | layer_id_in_cfg += 1 64 | start_mask = end_mask.clone() 65 | if layer_id_in_cfg < len(cfg_mask): 66 | end_mask = cfg_mask[layer_id_in_cfg] 67 | else: 68 | m1.weight.data = m0.weight.data[idx1.tolist()].clone() 69 | m1.bias.data = m0.bias.data[idx1.tolist()].clone() 70 | if m0.running_mean is None: 71 | m1.running_mean = m0.running_mean 72 | m1.running_var = m0.running_var 73 | else: 74 | m1.running_mean = m0.running_mean[idx1.tolist()].clone() 75 | m1.running_var = m0.running_var[idx1.tolist()].clone() 76 | layer_id_in_cfg += 1 77 | start_mask = end_mask.clone() 78 | if layer_id_in_cfg < len(cfg_mask): # do not change for the final FC 79 | end_mask = cfg_mask[layer_id_in_cfg] 80 | elif isinstance(m0, nn.Conv2d) and m0 != model.fc: 81 | if conv_count == 0: 82 | m1.weight.data = m0.weight.data.clone() 83 | conv_count += 1 84 | continue 85 | if isinstance( 86 | old_modules[layer_id - 1], networks.channel_selection 87 | ) or isinstance( 88 | old_modules[layer_id - 1], 89 | (nn.modules.batchnorm._NormBase, nn.GroupNorm), 90 | ): 91 | # This covers the convolutions in the residual block. 92 | # The convolutions are either after the channel selection layer 93 | # or after the batch normalization layer. 94 | conv_count += 1 95 | idx0 = np.squeeze( 96 | np.argwhere(np.asarray(start_mask.cpu().numpy())) 97 | ) 98 | idx1 = np.squeeze( 99 | np.argwhere(np.asarray(end_mask.cpu().numpy())) 100 | ) 101 | if idx0.size == 1: 102 | idx0 = np.resize(idx0, (1,)) 103 | if idx1.size == 1: 104 | idx1 = np.resize(idx1, (1,)) 105 | w1 = m0.weight.data[:, idx0.tolist(), :, :].clone() 106 | 107 | # If the current convolution is not the last convolution in the 108 | # residual block, then we can change the number of output 109 | # channels. Currently we use `conv_count` to detect whether it 110 | # is such a convolution.
111 | if conv_count % 3 != 1: 112 | w1 = w1[idx1.tolist(), :, :, :].clone() 113 | m1.weight.data = w1.clone() 114 | continue 115 | 116 | # We need to consider the case where there are downsampling 117 | # convolutions. For these convolutions, we just copy the weights. 118 | m1.weight.data = m0.weight.data.clone() 119 | elif m0 == model.fc: 120 | idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy()))) 121 | if idx0.size == 1: 122 | idx0 = np.resize(idx0, (1,)) 123 | 124 | m1.weight.data = m0.weight.data[:, idx0].clone() 125 | 126 | assert (m1.bias is None) == (m0.bias is None) 127 | if m1.bias is not None: 128 | m1.bias.data = m0.bias.data.clone() 129 | 130 | num_parameters = sum([param.nelement() for param in newmodel.parameters()]) 131 | 132 | return num_parameters, newmodel 133 | -------------------------------------------------------------------------------- /models/networks/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def get_slim_configs(model): 10 | # Note: the model should *already* be pruned on the BN layers. This function 11 | # does not apply the pruning itself. 12 | # We expect the user to create a (pruned) copy of the model before 13 | # calling this.
14 | is_cuda = next(model.parameters()).is_cuda 15 | 16 | cfg = [] 17 | cfg_mask = [] 18 | for k, m in enumerate(model.modules()): 19 | if isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)): 20 | mask = m.weight.abs() > 0 21 | if is_cuda: 22 | mask = mask.cuda() 23 | cfg.append(int(torch.sum(mask))) 24 | cfg_mask.append(mask.clone()) 25 | print( 26 | "layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}".format( 27 | k, mask.shape[0], int(torch.sum(mask)) 28 | ) 29 | ) 30 | elif isinstance(m, nn.MaxPool2d): 31 | cfg.append("M") 32 | 33 | return cfg, cfg_mask 34 | -------------------------------------------------------------------------------- /models/networks/vgg.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | import math 6 | 7 | import torch.nn as nn 8 | 9 | __all__ = ["vgg", "vgg11", "vgg13", "vgg16", "vgg19"] 10 | 11 | defaultcfg = { 12 | 11: [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512], 13 | 13: [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512], 14 | 16: [ 15 | 64, 16 | 64, 17 | "M", 18 | 128, 19 | 128, 20 | "M", 21 | 256, 22 | 256, 23 | 256, 24 | "M", 25 | 512, 26 | 512, 27 | 512, 28 | "M", 29 | 512, 30 | 512, 31 | 512, 32 | ], 33 | 19: [ 34 | 64, 35 | 64, 36 | "M", 37 | 128, 38 | 128, 39 | "M", 40 | 256, 41 | 256, 42 | 256, 43 | 256, 44 | "M", 45 | 512, 46 | 512, 47 | 512, 48 | 512, 49 | "M", 50 | 512, 51 | 512, 52 | 512, 53 | 512, 54 | ], 55 | } 56 | 57 | 58 | class vgg(nn.Module): 59 | def __init__( 60 | self, 61 | dataset="imagenet", 62 | depth=19, 63 | init_weights=True, 64 | cfg=None, 65 | builder=None, 66 | block_builder=None, 67 | ): 68 | super(vgg, self).__init__() 69 | if cfg is None: 70 | cfg = defaultcfg[depth] 71 | 72 | self.feature = self.make_layers( 73 | cfg, True, builder=builder, block_builder=block_builder 74 | ) 75 | 76 | 
if dataset == "imagenet": 77 | num_classes = 1000 78 | else: 79 | raise NotImplementedError(f"Not implemented for dataset {dataset}") 80 | self.classifier = builder.conv1x1(cfg[-1], num_classes, last_layer=True) 81 | 82 | if init_weights: 83 | self._initialize_weights() 84 | 85 | def make_layers( 86 | self, cfg, batch_norm=False, builder=None, block_builder=None 87 | ): 88 | layers = [] 89 | in_channels = 3 90 | first_conv_layer = True 91 | for v in cfg: 92 | if v == "M": 93 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 94 | else: 95 | if first_conv_layer: 96 | my_builder = builder 97 | else: 98 | my_builder = block_builder 99 | conv2d = my_builder.conv3x3( 100 | in_channels, v, first_layer=first_conv_layer 101 | ) 102 | first_conv_layer = False 103 | if batch_norm: 104 | bn = my_builder.batchnorm(v) 105 | layers += [conv2d, bn, nn.ReLU(inplace=True)] 106 | else: 107 | layers += [conv2d, nn.ReLU(inplace=True)] 108 | in_channels = v 109 | return nn.Sequential(*layers) 110 | 111 | def forward(self, x): 112 | x = self.feature(x) 113 | x = nn.AvgPool2d(2)(x) 114 | y = self.classifier(x) 115 | y = y.view(y.size(0), -1) 116 | return y 117 | 118 | def _initialize_weights(self): 119 | for m in self.modules(): 120 | if isinstance(m, nn.Conv2d): 121 | if m.last_layer: 122 | assert m.kernel_size == (1, 1) 123 | # Fully Connected Layer 124 | i = 1 125 | while hasattr(m, f"weight{i}"): 126 | weight = getattr(m, f"weight{i}") 127 | weight.data.normal_(0, 0.01) 128 | bias = getattr(m, f"bias{i}", None) 129 | if bias is not None: 130 | bias.data.zero_() 131 | i += 1 132 | else: 133 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 134 | i = 1 135 | while hasattr(m, f"weight{i}"): 136 | weight = getattr(m, f"weight{i}") 137 | weight.data.normal_(0, math.sqrt(2.0 / n)) 138 | bias = getattr(m, f"bias{i}", None) 139 | if bias is not None: 140 | bias.data.zero_() 141 | i += 1 142 | 143 | elif isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)): 144 | i = 1 145 
| while hasattr(m, f"weight{i}"): 146 | weight = getattr(m, f"weight{i}") 147 | weight.data.fill_(0.5) 148 | bias = getattr(m, f"bias{i}", None) 149 | if bias is not None: 150 | bias.data.zero_() 151 | i += 1 152 | elif isinstance(m, nn.Linear): 153 | i = 1 154 | while hasattr(m, f"weight{i}"): 155 | weight = getattr(m, f"weight{i}") 156 | weight.data.normal_(0, 0.01) 157 | bias = getattr(m, f"bias{i}", None) 158 | if bias is not None: 159 | bias.data.zero_() 160 | i += 1 161 | 162 | 163 | def vgg11(**kwargs): 164 | return vgg(depth=11, **kwargs) 165 | 166 | 167 | def vgg13(**kwargs): 168 | return vgg(depth=13, **kwargs) 169 | 170 | 171 | def vgg16(**kwargs): 172 | return vgg(depth=16, **kwargs) 173 | 174 | 175 | def vgg19(**kwargs): 176 | return vgg(depth=19, **kwargs) 177 | -------------------------------------------------------------------------------- /models/networks/vggprune.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
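Network slimming here is a two-step process: `get_slim_configs` (models/networks/utils.py) turns BN scales that were already zeroed into a `cfg` of surviving channel counts plus per-layer masks, and `get_slimmed_network` (resprune.py and vggprune.py) copies only the surviving input and output channels into a freshly built, narrower model. The dependency-free sketch below mimics both steps on nested lists. `slim_config` and `slice_conv_weight` are hypothetical names, and each convolution kernel is reduced to a single number per (output, input) channel pair.

```python
# Dependency-free sketch of the slimming flow; names are hypothetical and each
# conv kernel is reduced to one number per (output, input) channel pair.


def slim_config(bn_gammas, pool_after):
    """Mimic get_slim_configs: a channel survives if its BN scale is nonzero.

    bn_gammas: per-layer lists of BN scale values, in module order.
    pool_after: indices of layers followed by a max-pool ("M" in cfg).
    """
    cfg, cfg_mask = [], []
    for i, gammas in enumerate(bn_gammas):
        mask = [abs(g) > 0 for g in gammas]
        cfg.append(sum(mask))
        cfg_mask.append(mask)
        if i in pool_after:
            cfg.append("M")
    return cfg, cfg_mask


def slice_conv_weight(weight, in_mask, out_mask):
    """Mimic the vggprune copy: keep weight[out][in] for surviving channels."""
    keep_in = [i for i, keep in enumerate(in_mask) if keep]
    keep_out = [o for o, keep in enumerate(out_mask) if keep]
    return [[weight[o][i] for i in keep_in] for o in keep_out]


# Two conv+BN layers with a max-pool in between; zero gammas mark pruned channels.
cfg, masks = slim_config([[0.5, 0.0, 0.2, 0.0], [0.1, 0.0, 0.3]], pool_after={0})
print(cfg)  # [2, 'M', 2]

# Layer 2's weight is 3 (out) x 4 (in); its inputs follow layer 1's output mask.
w2 = [[10 * o + i for i in range(4)] for o in range(3)]
print(slice_conv_weight(w2, in_mask=masks[0], out_mask=masks[1]))  # [[0, 2], [20, 22]]
```

The first convolution keeps all three RGB inputs, which is why both pruning functions start from `start_mask = torch.ones(3)`. The list comprehensions also sidestep a NumPy subtlety that the repository code handles explicitly: when exactly one channel survives, `np.squeeze(np.argwhere(mask))` yields a 0-d array, so the code calls `np.resize(idx, (1,))` to keep the indexed dimension instead of dropping it.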
4 | # 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .vgg import * 10 | 11 | 12 | def get_slimmed_network(model, model_kwargs, cfg, cfg_mask): 13 | assert not isinstance(model, nn.DataParallel), f"Must unwrap DataParallel" 14 | 15 | is_cuda = next(model.parameters()).is_cuda 16 | print("Cfg:") 17 | print(cfg) 18 | 19 | newmodel = vgg(cfg=cfg, **model_kwargs) 20 | 21 | if is_cuda: 22 | newmodel.cuda() 23 | 24 | old_modules = list(model.modules()) 25 | new_modules = list(newmodel.modules()) 26 | layer_id_in_cfg = 0 27 | start_mask = torch.ones(3) 28 | end_mask = cfg_mask[layer_id_in_cfg] 29 | 30 | for layer_id in range(len(old_modules)): 31 | m0 = old_modules[layer_id] 32 | m1 = new_modules[layer_id] 33 | 34 | if isinstance(m0, (nn.modules.batchnorm._NormBase, nn.GroupNorm)): 35 | idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) 36 | if idx1.size == 1: 37 | idx1 = np.resize(idx1, (1,)) 38 | m1.weight.data = m0.weight.data[idx1.tolist()].clone() 39 | m1.bias.data = m0.bias.data[idx1.tolist()].clone() 40 | m1.running_mean = m0.running_mean[idx1.tolist()].clone() 41 | m1.running_var = m0.running_var[idx1.tolist()].clone() 42 | layer_id_in_cfg += 1 43 | start_mask = end_mask.clone() 44 | if layer_id_in_cfg < len(cfg_mask): # do not change in Final FC 45 | end_mask = cfg_mask[layer_id_in_cfg] 46 | elif isinstance(m0, nn.Conv2d) and m0 != model.classifier: 47 | idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy()))) 48 | idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) 49 | print( 50 | "In shape: {:d}, Out shape {:d}.".format(idx0.size, idx1.size) 51 | ) 52 | if idx0.size == 1: 53 | idx0 = np.resize(idx0, (1,)) 54 | if idx1.size == 1: 55 | idx1 = np.resize(idx1, (1,)) 56 | w1 = m0.weight.data[:, idx0.tolist(), :, :].clone() 57 | w1 = w1[idx1.tolist(), :, :, :].clone() 58 | m1.weight.data = w1.clone() 59 | elif m0 == model.classifier: 60 | idx0 = 
np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy()))) 61 | if idx0.size == 1: 62 | idx0 = np.resize(idx0, (1,)) 63 | 64 | m1.weight.data = m0.weight.data[:, idx0].clone() 65 | 66 | assert (m1.bias is None) == (m0.bias is None) 67 | if m1.bias is not None: 68 | m1.bias.data = m0.bias.data.clone() 69 | 70 | num_parameters = sum([param.nelement() for param in newmodel.parameters()]) 71 | 72 | return num_parameters, newmodel 73 | -------------------------------------------------------------------------------- /models/quantize_affine.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | import collections 6 | import numbers 7 | from typing import Any 8 | from typing import Optional 9 | 10 | import numpy as np 11 | import torch 12 | from torch import autograd 13 | from torch import nn 14 | 15 | from .special_tensors import RepresentibleByQuantizeAffine 16 | from .special_tensors import tag_with_metadata 17 | 18 | QuantizeAffineParams2 = collections.namedtuple( 19 | "QuantizeAffineParams", ["scale", "zero_point", "num_bits"] 20 | ) 21 | 22 | INFINITY = 1e10 23 | 24 | 25 | def _validate_tensor(tensor: torch.Tensor) -> None: 26 | if torch.isnan(tensor).any(): 27 | raise ValueError("Found NaN in the tensor.") 28 | if tensor.abs().max() > INFINITY: 29 | raise ValueError( 30 | "Tensor seems to be diverging. Found a value > {}".format(INFINITY) 31 | ) 32 | 33 | 34 | def get_quantized_representation( 35 | tensor: torch.Tensor, 36 | quantize_params: QuantizeAffineParams2, 37 | ) -> torch.Tensor: 38 | """Gets the quantized representation of a float @tensor. 39 | 40 | The resulting tensor will contain the quantized values, and the quantization 41 | parameters will be tagged onto the tensor as metadata. 42 | 43 | A ValueError will be raised if the given tensor contains NaN or divergent 44 | values.
45 | 46 | Arguments: 47 | tensor (torch.Tensor): The float torch tensor to quantize. 48 | quantize_params (QuantizeAffineParams): The quantization params to 49 | quantize the tensor by. 50 | """ 51 | _validate_tensor(tensor) 52 | 53 | scale = quantize_params.scale 54 | zero_point = quantize_params.zero_point 55 | num_bits = quantize_params.num_bits 56 | if scale == 0: 57 | # Special case, all elements are zeros. 58 | if zero_point != 0: 59 | raise ValueError( 60 | "The given QuantizeAffineParams (={}) has a non-zero zero point" 61 | " with a scale of 0.".format(quantize_params) 62 | ) 63 | quantized_tensor = torch.zeros_like(tensor) 64 | tag_with_metadata(quantized_tensor, quantize_params) 65 | return quantized_tensor 66 | 67 | qmin, qmax = get_qmin_qmax(num_bits) 68 | reciprocal = 1 / scale 69 | quantized_tensor = ((tensor * reciprocal).round_() + zero_point).clamp_( 70 | qmin, qmax 71 | ) 72 | 73 | tag_with_metadata(quantized_tensor, quantize_params) 74 | return quantized_tensor 75 | 76 | 77 | def mark_quantize_affine( 78 | tensor: torch.Tensor, 79 | scale: float, 80 | zero_point: int, 81 | dtype: np.dtype = np.uint8, 82 | ) -> None: 83 | """Mark a tensor as representable by affine quantization. 84 | 85 | Arguments: 86 | tensor (torch.Tensor): The tensor to be marked as an affine-quantizable 87 | tensor. 88 | scale (float): The scale (from quantization parameters). 89 | zero_point (int): The zero_point (from quantization parameters). 90 | dtype: Stored as the num_bits field of QuantizeAffineParams2. Despite 91 | the name and the numpy.uint8 default, callers in this module pass 92 | quantize_params.num_bits (an integer bit width, e.g. 8) for this 93 | argument. 94 | """ 95 | quant_params = QuantizeAffineParams2(scale, zero_point, dtype) 96 | tag_with_metadata(tensor, RepresentibleByQuantizeAffine(quant_params)) 97 | 98 | 99 | class QuantizeAffineFunction(autograd.Function): 100 | """Simulates the effect of affine quantization during the forward pass.
101 | 102 | This function simulates the effect of quantization and subsequent 103 | dequantization (in the forward pass only). Although the affine 104 | transformation results in a different basis (e.g. int8), the output of this 105 | function will be a float Tensor representing that transformation (the 106 | dequantized Tensor). 107 | 108 | A ValueError will be raised if the input or resulting tensor contains NaN or 109 | divergent values. 110 | 111 | Arguments: 112 | input (Tensor): The input float Tensor to be quantized. 113 | quantize_params (QuantizeAffineParams2): The 114 | quantization parameters to quantize the input tensor by. 115 | """ 116 | 117 | @staticmethod 118 | def forward( 119 | ctx: Any, 120 | input: torch.Tensor, 121 | quantize_params: QuantizeAffineParams2, 122 | ) -> torch.Tensor: 123 | quantized_tensor = get_quantized_representation(input, quantize_params) 124 | dequantized_tensor = dequantize(quantized_tensor, quantize_params) 125 | 126 | mark_quantize_affine( 127 | dequantized_tensor, 128 | quantize_params.scale, 129 | quantize_params.zero_point, 130 | quantize_params.num_bits, 131 | ) 132 | return dequantized_tensor 133 | 134 | @staticmethod 135 | def backward(ctx: Any, grad_output: Any) -> Any: 136 | """Approximate the gradient as the identity (straight-through estimator).""" 137 | return grad_output, None 138 | 139 | 140 | def quantize_affine_function_continuous( 141 | input: torch.Tensor, 142 | quantize_params: QuantizeAffineParams2, 143 | ) -> torch.Tensor: 144 | quantized_tensor = get_quantized_representation(input, quantize_params) 145 | dequantized_tensor = dequantize(quantized_tensor, quantize_params) 146 | 147 | mark_quantize_affine( 148 | dequantized_tensor, 149 | quantize_params.scale, 150 | quantize_params.zero_point, 151 | quantize_params.num_bits, 152 | ) 153 | return dequantized_tensor 154 | 155 | 156 | def get_qmin_qmax(num_bits): 157 | return -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1 158 | 159 | 160 | def
get_quantization_params( 161 | rmin: float, 162 | rmax: float, 163 | num_bits: int = 8, 164 | ) -> QuantizeAffineParams2: 165 | """Returns QuantizeAffineParams2 for a data range [rmin, rmax]. 166 | 167 | The range must include 0; otherwise a ValueError is raised. The scale maps 168 | [rmin, rmax] linearly onto [qmin, qmax], and the zero_point is rounded so 169 | that 0.0 is exactly representable. 170 | 171 | Arguments: 172 | rmin (float): The data minimum point. Numbers smaller than rmin would 173 | not be representable by the quantized schema. 174 | rmax (float): The data maximum point. Numbers bigger than rmax would 175 | not be representable by the quantized schema. 176 | num_bits (int): The bit width. Values are quantized onto the signed 177 | integer range [-2**(num_bits - 1), 2**(num_bits - 1) - 1]. 178 | Defaults to 8. 179 | """ 180 | if rmin > rmax: 181 | raise ValueError("Got rmin (={}) > rmax (={}).".format(rmin, rmax)) 182 | if rmin > 0 or rmax < 0: 183 | raise ValueError( 184 | "The data range ([{}, {}]) must always include " 185 | "0.".format(rmin, rmax) 186 | ) 187 | 188 | if rmin == rmax == 0.0: 189 | # Special case: all values are zero. 190 | return QuantizeAffineParams2(0, 0, num_bits) 191 | 192 | # Scale is floating point and is (rmax - rmin) / (qmax - qmin) to map the 193 | # length of the ranges. For zero_point, we solve the following equation: 194 | # rmin = (qmin - zero_point) * scale 195 | qmin, qmax = get_qmin_qmax(num_bits) 196 | scale = (rmax - rmin) / (qmax - qmin) 197 | zero_point = qmin - (rmin / scale) 198 | zero_point = np.clip(round(zero_point), qmin, qmax).astype(np.int32) 199 | 200 | quantize_params = QuantizeAffineParams2(scale, zero_point, num_bits) 201 | # We must ensure that zero is exactly representable with these quantization 202 | # parameters. This is easy enough to add a self-check for.
203 | quantized_zero = quantize(np.array([0.0]), quantize_params) 204 | dequantized_zero = dequantize(quantized_zero, quantize_params) 205 | if dequantized_zero.item() != 0.0: 206 | raise ValueError( 207 | f"Quantization parameters are invalid: scale={scale}, zero={zero_point}. " 208 | f"Can't exactly represent zero." 209 | ) 210 | 211 | return quantize_params 212 | 213 | 214 | def quantize_affine_given_quant_params( 215 | input: torch.Tensor, 216 | quantize_params: QuantizeAffineParams2, 217 | ) -> torch.Tensor: 218 | """Get a quantizable approximation of a float tensor given quantization params. 219 | 220 | This function does not quantize the float tensor @input, but only adjusts it 221 | such that the returned float tensor has an exact quantized representation. 222 | This is a function that we want to use at training time to quantize biases 223 | and other parameters whose quantization schema is enforced by other 224 | parameters. 225 | 226 | In the forward pass, this function is equivalent to 227 | 228 | dequantize(get_quantized_representation(input, quantize_params)) 229 | 230 | However, in the backward pass, this function acts as the identity, making it 231 | suitable for use in the training forward pass. 232 | """ 233 | return QuantizeAffineFunction.apply(input, quantize_params) 234 | 235 | 236 | def quantize( 237 | arr: np.ndarray, quantize_params: QuantizeAffineParams2 238 | ) -> np.ndarray: 239 | """Quantize a floating point array with respect to the quantization params. 240 | 241 | Arguments: 242 | arr (np.ndarray): The floating point data to quantize. 243 | quantize_params (QuantizeAffineParams): The quantization parameters 244 | under which the data should be quantized. 245 | """ 246 | scale = quantize_params.scale 247 | zero_point = quantize_params.zero_point 248 | num_bits = quantize_params.num_bits 249 | if scale == 0: 250 | # Special case, all elements are zeros.
251 | if zero_point != 0: 252 | raise ValueError( 253 | "The given QuantizeAffineParams (={}) has a non-zero zero point" 254 | " with a scale of 0.".format(quantize_params) 255 | ) 256 | return np.zeros_like(arr, dtype=np.int32) 257 | 258 | qmin, qmax = get_qmin_qmax(num_bits) 259 | reciprocal = 1 / scale 260 | quantized_values = (arr * reciprocal).round() + zero_point 261 | quantized_values = quantized_values.clip(qmin, qmax) 262 | return quantized_values 263 | 264 | 265 | def dequantize( 266 | q_arr: np.ndarray, 267 | quantize_params: QuantizeAffineParams2, 268 | ) -> np.ndarray: 269 | """Dequantize a fixed point array with respect to the quantization params. 270 | 271 | Arguments: 272 | q_arr (np.ndarray): The quantized array to dequantize. Its values must 273 | have been produced with the same quantize_params. 274 | quantize_params (QuantizeAffineParams): The quantization parameters 275 | under which the data should be dequantized. 276 | """ 277 | zero_point = quantize_params.zero_point 278 | scale = quantize_params.scale 279 | return (q_arr - zero_point) * scale 280 | 281 | 282 | def quantize_affine( 283 | input: torch.Tensor, 284 | min_value: Optional[numbers.Real] = None, 285 | max_value: Optional[numbers.Real] = None, 286 | num_bits: Optional[int] = None, 287 | ) -> torch.Tensor: 288 | """Return a quantizable approximation of a float tensor @input. 289 | 290 | This function does not quantize the float tensor @input, but only adjusts it 291 | such that the returned float tensor has an exact quantized representation. 292 | This is a function that we want to use at training time to quantize weights 293 | and activations. 294 | 295 | Arguments: 296 | input (Tensor): The input float Tensor to be quantized. 297 | min_value (scalar): The running min value (possibly averaged). 298 | max_value (scalar): The running max value (possibly averaged). 299 | num_bits (int): The bit width (required).
300 | """ 301 | if num_bits is None: 302 | raise ValueError("num_bits must be supplied") 303 | 304 | if min_value is None: 305 | # Always include 0 in the calculation of min_value. 306 | min_value = min(input.min().item(), 0.0) 307 | if max_value is None: 308 | # Always include 0 in the calculation of max_value. 309 | max_value = max(input.max().item(), 0.0) 310 | 311 | quantize_params = get_quantization_params(min_value, max_value, num_bits) 312 | return QuantizeAffineFunction.apply(input, quantize_params) 313 | 314 | 315 | class QuantizeAffine(nn.Module): 316 | """PyTorch layer that applies quantize_affine to layer outputs. 317 | 318 | This layer keeps a running min and max, which are used to compute a scale 319 | and zero_point for the quantization. Note that it is not always desirable 320 | to start quantizing immediately when training begins. 321 | 322 | Arguments: 323 | momentum (scalar): The amount of averaging of min and max bounds. 324 | This value should be in the range [0.0, 1.0]. 325 | iteration_delay (scalar): The number of batches to wait before starting 326 | to quantize.
327 | """ 328 | 329 | def __init__( 330 | self, 331 | momentum=0.1, 332 | iteration_delay=0, 333 | num_bits=8, 334 | quantizer_freeze_min_max=False, 335 | ): 336 | super().__init__() 337 | self.momentum = momentum 338 | self.iteration_delay = iteration_delay 339 | self.increment_counter = False 340 | self.num_bits = num_bits 341 | self.register_buffer("running_min_value", torch.tensor(0.0)) 342 | self.register_buffer("running_max_value", torch.tensor(0.0)) 343 | self.register_buffer( 344 | "iteration_count", torch.zeros([1], dtype=torch.int32).squeeze() 345 | ) 346 | self.quantizer_freeze_min_max = quantizer_freeze_min_max 347 | 348 | def __repr__(self): 349 | return ( 350 | f"{self.__class__.__name__}(running_min=" 351 | f"{self.running_min_value}, running_max=" 352 | f"{self.running_max_value}, freeze_min_max=" 353 | f"{self.quantizer_freeze_min_max}, num_bits={self.num_bits})" 354 | ) 355 | 356 | def update_num_bits(self, num_bits): 357 | self.num_bits = num_bits 358 | 359 | def forward(self, input, recomp_bn_stats=False, override_alpha=False): 360 | if ( 361 | self.training 362 | and self.is_active() 363 | and not self.quantizer_freeze_min_max 364 | ): 365 | # Force include 0 in min_value and max_value calculation. 
366 | min_value = min(input.min().item(), 0) 367 | max_value = max(input.max().item(), 0) 368 | 369 | if self.iteration_count == self.iteration_delay: 370 | new_running_min_value = min_value 371 | new_running_max_value = max_value 372 | else: 373 | new_running_min_value = ( 374 | 1.0 - self.momentum 375 | ) * self.running_min_value.item() + self.momentum * min_value 376 | new_running_max_value = ( 377 | 1.0 - self.momentum 378 | ) * self.running_max_value.item() + self.momentum * max_value 379 | 380 | self.running_min_value.fill_(new_running_min_value) 381 | self.running_max_value.fill_(new_running_max_value) 382 | 383 | if self.is_active(): 384 | output = quantize_affine( 385 | input, 386 | self.running_min_value.item(), 387 | self.running_max_value.item(), 388 | self.num_bits, 389 | ) 390 | else: 391 | output = input 392 | 393 | if self.training and self.increment_counter: 394 | self.iteration_count.fill_(self.iteration_count.item() + 1) 395 | 396 | return output 397 | 398 | def is_active(self): 399 | if self.training: 400 | return self.iteration_count >= self.iteration_delay 401 | # If evaluating, always run quantization: 402 | return True 403 | -------------------------------------------------------------------------------- /models/quantized_modules.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
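The running-range logic of `QuantizeAffine.forward` above can be mirrored in plain Python to see how the bounds evolve. This is a minimal sketch, not the repository's code. `RunningMinMax` and `fake_quantize` are hypothetical names. The sketch advances the counter on every call, whereas the real module advances it only in training mode when `increment_counter` is set. It also ignores `quantizer_freeze_min_max`. The scale/zero-point formula is an assumed standard asymmetric scheme, because `get_quantization_params` is not shown in this excerpt.

```python
class RunningMinMax:
    """Hypothetical pure-Python mirror of QuantizeAffine's range tracking."""

    def __init__(self, momentum=0.1, iteration_delay=0):
        self.momentum = momentum
        self.iteration_delay = iteration_delay
        self.iteration_count = 0
        self.running_min = 0.0
        self.running_max = 0.0

    def is_active(self):
        return self.iteration_count >= self.iteration_delay

    def update(self, batch):
        if self.is_active():
            # Force 0 into the observed range, as QuantizeAffine.forward does.
            lo = min(min(batch), 0.0)
            hi = max(max(batch), 0.0)
            if self.iteration_count == self.iteration_delay:
                # First active batch: take the observed range directly.
                self.running_min, self.running_max = lo, hi
            else:
                # Exponential moving average weighted by momentum.
                m = self.momentum
                self.running_min = (1.0 - m) * self.running_min + m * lo
                self.running_max = (1.0 - m) * self.running_max + m * hi
        # Simplification: always advance the counter.
        self.iteration_count += 1


def fake_quantize(values, min_value, max_value, num_bits):
    """Quantize then dequantize, using an assumed asymmetric affine scheme."""
    qmax = 2 ** num_bits - 1
    scale = (max_value - min_value) / qmax if max_value > min_value else 1.0
    zero_point = int(round(-min_value / scale))
    out = []
    for v in values:
        # Round onto the integer grid and clamp to [0, qmax].
        q = min(max(int(round(v / scale)) + zero_point, 0), qmax)
        out.append((q - zero_point) * scale)
    return out


tracker = RunningMinMax(momentum=0.5, iteration_delay=1)
for batch in ([1.0, 2.0], [1.0, 2.0], [-2.0, 4.0]):
    tracker.update(batch)
print(tracker.running_min, tracker.running_max)  # -1.0 3.0
print(fake_quantize([0.0], tracker.running_min, tracker.running_max, 8))  # [0.0]
```

The first batch is skipped because of the delay, the second initializes the range to [0, 2], and the third moves it halfway toward [-2, 4]. Forcing 0 into the range is what keeps 0.0 exactly representable after quantization.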
4 | # 5 | 6 | # Modified version of https://github.com/pytorch/pytorch/blob/master/torch/nn/intrinsic/qat/modules/conv_fused.py 7 | 8 | import math 9 | from typing import TypeVar 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.intrinsic as nni 14 | from torch.nn import init 15 | from torch.nn.modules.utils import _pair 16 | from torch.nn.parameter import Parameter 17 | 18 | from .quantize_affine import QuantizeAffine 19 | from .quantize_affine import quantize_affine 20 | 21 | MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) 22 | 23 | 24 | class QuantStandardBN(nn.BatchNorm2d): 25 | def get_weight(self): 26 | return self.weight, self.bias 27 | 28 | 29 | class NoOpBN(nn.Module): 30 | def __init__(self, *args, **kwargs): 31 | super().__init__() 32 | 33 | def forward(self, input): 34 | return input 35 | 36 | 37 | class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule): 38 | 39 | _version = 2 40 | _FLOAT_MODULE = MOD 41 | 42 | def __init__( 43 | self, 44 | # ConvNd args 45 | in_channels, 46 | out_channels, 47 | kernel_size, 48 | stride, 49 | padding, 50 | dilation, 51 | transposed, 52 | output_padding, 53 | groups, 54 | bias, 55 | padding_mode, 56 | # BatchNormNd args 57 | # num_features: out_channels 58 | eps=1e-05, 59 | momentum=0.1, 60 | # affine: True 61 | # track_running_stats: True 62 | # Args for this module 63 | freeze_bn=False, 64 | dim=2, 65 | num_bits=8, 66 | iteration_delay=0, 67 | bn_module=None, 68 | norm_kwargs=None, 69 | ): 70 | nn.modules.conv._ConvNd.__init__( 71 | self, 72 | in_channels, 73 | out_channels, 74 | kernel_size, 75 | stride, 76 | padding, 77 | dilation, 78 | transposed, 79 | output_padding, 80 | groups, 81 | False, 82 | padding_mode, 83 | ) 84 | self.freeze_bn = freeze_bn if self.training else True 85 | if bn_module is not None: 86 | if "GN" in bn_module.__name__: 87 | num_groups = norm_kwargs["num_groups"] 88 | if "Line" in bn_module.__name__: 89 | self.bn = bn_module( 90 | out_channels, eps=eps, 
num_groups=num_groups 91 | ) 92 | else: 93 | self.bn = bn_module( 94 | out_channels, 95 | eps=eps, 96 | num_groups=num_groups, 97 | affine=True, 98 | ) 99 | else: 100 | self.bn = bn_module( 101 | out_channels, eps=eps, momentum=momentum, affine=True 102 | ) 103 | 104 | else: 105 | raise ValueError("BN module must be supplied") 106 | 107 | if bias: 108 | self.bias = Parameter(torch.empty(out_channels)) 109 | else: 110 | self.register_parameter("bias", None) 111 | 112 | # this needs to be called after reset_bn_parameters, 113 | # as they modify the same state 114 | if self.training: 115 | if freeze_bn: 116 | self.freeze_bn_stats() 117 | else: 118 | self.update_bn_stats() 119 | else: 120 | self.freeze_bn_stats() 121 | 122 | # Quantization functions/parameters 123 | self.num_bits = num_bits 124 | self.iteration_delay = iteration_delay 125 | self.quantize_input = QuantizeAffine( 126 | iteration_delay=iteration_delay, num_bits=self.num_bits 127 | ) 128 | 129 | def reset_running_stats(self): 130 | self.bn.reset_running_stats() 131 | 132 | def reset_bn_parameters(self): 133 | self.bn.reset_running_stats() 134 | init.uniform_(self.bn.weight) 135 | init.zeros_(self.bn.bias) 136 | # note: below is actually for conv, not BN 137 | if self.bias is not None: 138 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 139 | bound = 1 / math.sqrt(fan_in) 140 | init.uniform_(self.bias, -bound, bound) 141 | 142 | def reset_parameters(self): 143 | super(_ConvBnNd, self).reset_parameters() 144 | 145 | def update_bn_stats(self): 146 | self.freeze_bn = False 147 | self.bn.training = True 148 | return self 149 | 150 | def freeze_bn_stats(self): 151 | self.freeze_bn = True 152 | self.bn.training = False 153 | return self 154 | 155 | def update_num_bits(self, num_bits): 156 | self.num_bits = num_bits 157 | self.quantize_input.num_bits = num_bits 158 | 159 | def start_counter(self): 160 | self.quantize_input.increment_counter = True 161 | 162 | def stop_counter(self): 163 |
self.quantize_input.increment_counter = False 164 | 165 | def is_active(self): 166 | return self.quantize_input.is_active() 167 | 168 | def _forward(self, input, weight): 169 | if isinstance(self.bn, nn.BatchNorm2d): 170 | assert self.bn.running_var is not None 171 | running_std = torch.sqrt(self.bn.running_var + self.bn.eps) 172 | scale_factor = self.bn.get_weight()[0] / running_std 173 | elif isinstance(self.bn, (nn.InstanceNorm2d, nn.GroupNorm)): 174 | # We can't really apply the scale factor easily because each batch 175 | # element is weighted differently. So, we don't fuse the 176 | # InstanceNorm into the convolution. 177 | scale_factor = torch.ones(self.out_channels, device=weight.device) 178 | 179 | weight_shape = [1] * len(weight.shape) 180 | weight_shape[0] = -1 181 | bias_shape = [1] * len(weight.shape) 182 | bias_shape[1] = -1 183 | # Quantize weights 184 | # Note: weights are quantized from the beginning of training, i.e., no delay here 185 | scaled_weight = quantize_affine( 186 | weight * scale_factor.reshape(weight_shape), num_bits=self.num_bits 187 | ) 188 | # Quantize input 189 | # Note: inputs are quantized after self.iteration_delay training iterations 190 | input = self.quantize_input(input) 191 | # using zero bias here since the bias for original conv 192 | # will be added later 193 | if self.bias is not None: 194 | zero_bias = torch.zeros_like(self.bias) 195 | else: 196 | zero_bias = torch.zeros( 197 | self.out_channels, device=scaled_weight.device 198 | ) 199 | conv = self._conv_forward(input, scaled_weight, zero_bias) 200 | conv_orig = conv / scale_factor.reshape(bias_shape) 201 | if self.bias is not None: 202 | conv_orig = conv_orig + self.bias.reshape(bias_shape) 203 | conv = self.bn(conv_orig) 204 | return conv 205 | 206 | def extra_repr(self): 207 | return super(_ConvBnNd, self).extra_repr() 208 | 209 | # def forward(self, input): 210 | # return self._forward(input) 211 | 212 | def train(self, mode=True): 213 | """ 214 | Batchnorm's 
training behavior is using the self.training flag. Prevent 215 | changing it if BN is frozen. This makes sure that calling `model.train()` 216 | on a model with a frozen BN will behave properly. 217 | """ 218 | self.training = mode 219 | if not self.freeze_bn: 220 | for module in self.children(): 221 | module.train(mode) 222 | return self 223 | 224 | # ===== Serialization version history ===== 225 | # 226 | # Version 1/None 227 | # self 228 | # |--- weight : Tensor 229 | # |--- bias : Tensor 230 | # |--- gamma : Tensor 231 | # |--- beta : Tensor 232 | # |--- running_mean : Tensor 233 | # |--- running_var : Tensor 234 | # |--- num_batches_tracked : Tensor 235 | # 236 | # Version 2 237 | # self 238 | # |--- weight : Tensor 239 | # |--- bias : Tensor 240 | # |--- bn : Module 241 | # |--- weight : Tensor (moved from v1.self.gamma) 242 | # |--- bias : Tensor (moved from v1.self.beta) 243 | # |--- running_mean : Tensor (moved from v1.self.running_mean) 244 | # |--- running_var : Tensor (moved from v1.self.running_var) 245 | # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked) 246 | def _load_from_state_dict( 247 | self, 248 | state_dict, 249 | prefix, 250 | local_metadata, 251 | strict, 252 | missing_keys, 253 | unexpected_keys, 254 | error_msgs, 255 | ): 256 | version = local_metadata.get("version", None) 257 | if version is None or version == 1: 258 | # BN related parameters and buffers were moved into the BN module for v2 259 | v2_to_v1_names = { 260 | "bn.weight": "gamma", 261 | "bn.bias": "beta", 262 | "bn.running_mean": "running_mean", 263 | "bn.running_var": "running_var", 264 | "bn.num_batches_tracked": "num_batches_tracked", 265 | } 266 | 267 | for v2_name, v1_name in v2_to_v1_names.items(): 268 | if prefix + v1_name in state_dict: 269 | state_dict[prefix + v2_name] = state_dict[prefix + v1_name] 270 | state_dict.pop(prefix + v1_name) 271 | elif prefix + v2_name in state_dict: 272 | # there was a brief period where forward compatibility 
273 | # for this module was broken (between 274 | # https://github.com/pytorch/pytorch/pull/38478 275 | # and https://github.com/pytorch/pytorch/pull/38820) 276 | # and modules emitted the v2 state_dict format while 277 | # specifying that version == 1. This patches the forward 278 | # compatibility issue by allowing the v2 style entries to 279 | # be used. 280 | pass 281 | elif strict: 282 | missing_keys.append(prefix + v2_name) 283 | 284 | super(_ConvBnNd, self)._load_from_state_dict( 285 | state_dict, 286 | prefix, 287 | local_metadata, 288 | strict, 289 | missing_keys, 290 | unexpected_keys, 291 | error_msgs, 292 | ) 293 | 294 | 295 | class ConvBn2d(_ConvBnNd, nn.Conv2d): 296 | r""" 297 | A ConvBn2d module is a module fused from Conv2d and a normalization layer 298 | (e.g. BatchNorm2d, supplied via bn_module), with weights quantized by 299 | quantize_affine and inputs quantized by a QuantizeAffine module, used in 300 | quantization aware training. It combines the interface of :class:`torch.nn.Conv2d` and 301 | :class:`torch.nn.BatchNorm2d`. 302 | Similar to :class:`torch.nn.Conv2d`, with the input quantizer initialized 303 | to default settings.
304 | Attributes: 305 | freeze_bn: whether the normalization layer's running statistics are frozen 306 | quantize_input: QuantizeAffine module that quantizes the layer input 307 | """ 308 | 309 | def __init__( 310 | self, 311 | # ConvNd args 312 | in_channels, 313 | out_channels, 314 | kernel_size, 315 | stride=1, 316 | padding=0, 317 | dilation=1, 318 | groups=1, 319 | bias=None, 320 | padding_mode="zeros", 321 | # BatchNorm2d args 322 | # num_features: out_channels 323 | eps=1e-05, 324 | momentum=0.1, 325 | # affine: True 326 | # track_running_stats: True 327 | # Args for this module 328 | freeze_bn=False, 329 | num_bits=8, 330 | iteration_delay=0, 331 | bn_module=None, 332 | norm_kwargs=None, 333 | ): 334 | kernel_size = _pair(kernel_size) 335 | stride = _pair(stride) 336 | padding = _pair(padding) 337 | dilation = _pair(dilation) 338 | _ConvBnNd.__init__( 339 | self, 340 | in_channels, 341 | out_channels, 342 | kernel_size, 343 | stride, 344 | padding, 345 | dilation, 346 | False, 347 | _pair(0), 348 | groups, 349 | bias, 350 | padding_mode, 351 | eps, 352 | momentum, 353 | freeze_bn, 354 | dim=2, 355 | num_bits=num_bits, 356 | iteration_delay=iteration_delay, 357 | bn_module=bn_module, 358 | norm_kwargs=norm_kwargs, 359 | ) 360 | 361 | def forward(self, input, weight=None): 362 | if weight is None: 363 | return _ConvBnNd._forward(self, input, self.weight) 364 | else: 365 | return _ConvBnNd._forward(self, input, weight) 366 | 367 | 368 | class QuantSubspaceConvBn2d(ConvBn2d): 369 | def forward(self, x): 370 | # call get_weight, which samples from the subspace, then use the 371 | # corresponding weight.
372 | w = self.get_weight() 373 | return super().forward(x, w) 374 | 375 | 376 | class TwoParamConvBnd2d(QuantSubspaceConvBn2d): 377 | def __init__(self, *args, **kwargs): 378 | super().__init__(*args, **kwargs) 379 | self.weight1 = nn.Parameter(torch.zeros_like(self.weight)) 380 | 381 | def initialize(self, initialize_fn): 382 | initialize_fn(self.weight) 383 | initialize_fn(self.weight1) 384 | 385 | 386 | class LinesConvBn2d(TwoParamConvBnd2d): 387 | def get_weight(self): 388 | w = (1 - self.alpha) * self.weight + self.alpha * self.weight1 389 | return w 390 | -------------------------------------------------------------------------------- /models/sparse_modules.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | import torch 6 | import torch.nn as nn 7 | 8 | import utils 9 | 10 | 11 | class SparseConv2d(nn.Conv2d): 12 | """ 13 | A class for sparse 2d convolution layers 14 | """ 15 | 16 | def __init__(self, *args, **kwargs): 17 | """ 18 | Inputs: 19 | - method: The sparsification method passed to utils.apply_sparsity (default: "topk"). Other arguments are forwarded to nn.Conv2d. 20 | Attributes: 21 | - topk: A float between 0 and 1 giving the fraction of weights to keep, initialized to 1.0 (dense) and set externally. E.g. if topk = 0.3, the top 30% of weights (by magnitude) are kept and the rest are dropped. 22 | """ 23 | self.method = kwargs.pop("method", "topk") 24 | self.topk = 1.0 25 | super(SparseConv2d, self).__init__(*args, **kwargs) 26 | 27 | def apply_sparsity(self, weight: torch.Tensor) -> torch.Tensor: 28 | return utils.apply_sparsity(self.topk, weight, method=self.method) 29 | 30 | def forward(self, x, weight=None): 31 | """ 32 | Performs forward pass after passing weight tensor through apply_sparsity.
33 | """ 34 | if weight is None: 35 | return self._conv_forward( 36 | x, self.apply_sparsity(self.weight), self.bias 37 | ) 38 | else: 39 | return self._conv_forward(x, self.apply_sparsity(weight), self.bias) 40 | 41 | def __repr__(self) -> str: 42 | ret = super().__repr__() 43 | # Remove last paren. 44 | ret = ret[:-1] 45 | ret += f", method={self.method})" 46 | return ret 47 | 48 | 49 | class SparseSubspaceConv(SparseConv2d): 50 | def forward(self, x): 51 | # call get_weight, which samples from the subspace, then use the corresponding weight. 52 | w = self.get_weight() 53 | return super().forward(x, w) 54 | 55 | 56 | class SparseTwoParamConv(SparseSubspaceConv): 57 | def __init__(self, *args, **kwargs): 58 | super().__init__(*args, **kwargs) 59 | self.weight1 = nn.Parameter(torch.zeros_like(self.weight)) 60 | 61 | def initialize(self, initialize_fn): 62 | initialize_fn(self.weight) 63 | initialize_fn(self.weight1) 64 | 65 | 66 | class SparseLinesConv(SparseTwoParamConv): 67 | def get_weight(self): 68 | w = (1 - self.alpha) * self.weight + self.alpha * self.weight1 69 | return w 70 | -------------------------------------------------------------------------------- /models/special_tensors.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | """Utility functions to tag tensors with metadata. 6 | 7 | The metadata remains with the tensor under torch operations that don't change 8 | the values, e.g. .clone(), .contiguous(), .permute(), etc.
9 | """ 10 | 11 | import collections 12 | import copy 13 | from typing import Any 14 | from typing import Optional 15 | 16 | import numpy as np 17 | import torch 18 | 19 | QuantizeAffineParams2 = collections.namedtuple( 20 | "QuantizeAffineParams", ["scale", "zero_point", "num_bits"] 21 | ) 22 | 23 | 24 | class _SpecialTensor(torch.Tensor): 25 | """This class denotes special tensors. 26 | 27 | It isn't intended to be used directly, but serves as a helper for tagging 28 | tensors with metadata. 29 | 30 | It subclasses torch.Tensor so that isinstance(t, torch.Tensor) returns True 31 | for special tensors. It forbids some of the methods of torch.Tensor, and 32 | overrides a few methods used to create other tensors, to ensure the result 33 | is still special. 34 | """ 35 | 36 | _metadata = None 37 | 38 | def __getattribute__(self, attr: str) -> Any: 39 | # Disallow new_zeros, new_ones, new_full, etc. 40 | if "new_" in attr: 41 | raise AttributeError( 42 | "Invalid attr {!r} for special tensors".format(attr) 43 | ) 44 | return super().__getattribute__(attr) 45 | 46 | def detach(self) -> "_SpecialTensor": 47 | ret = super().detach() 48 | ret.__class__ = _SpecialTensor 49 | ret._metadata = self._metadata 50 | return ret 51 | 52 | @property 53 | def data(self) -> "_SpecialTensor": 54 | ret = super().data 55 | ret.__class__ = _SpecialTensor 56 | ret._metadata = self._metadata 57 | return ret 58 | 59 | def clone(self) -> "_SpecialTensor": 60 | ret = super().clone() 61 | ret.__class__ = _SpecialTensor 62 | ret._metadata = self._metadata 63 | return ret 64 | 65 | def cuda( 66 | self, device: Optional[torch.device] = None, non_blocking: bool = False 67 | ) -> "_SpecialTensor": 68 | ret = super().cuda(device, non_blocking) 69 | ret.__class__ = _SpecialTensor 70 | ret._metadata = self._metadata 71 | return ret 72 | 73 | def contiguous(self) -> "_SpecialTensor": 74 | ret = super().contiguous() 75 | ret.__class__ = _SpecialTensor 76 | ret._metadata = self._metadata 77 | return ret 78 | 79 | def
view(self, *args, **kwargs) -> "_SpecialTensor": 80 | ret = super().view(*args, **kwargs) 81 | ret.__class__ = _SpecialTensor 82 | ret._metadata = self._metadata 83 | return ret 84 | 85 | def permute(self, *args, **kwargs) -> "_SpecialTensor": 86 | ret = super().permute(*args, **kwargs) 87 | ret.__class__ = _SpecialTensor 88 | ret._metadata = self._metadata 89 | return ret 90 | 91 | def __getitem__(self, *args, **kwargs) -> "_SpecialTensor": 92 | ret = super().__getitem__(*args, **kwargs) 93 | ret.__class__ = _SpecialTensor 94 | ret._metadata = self._metadata 95 | return ret 96 | 97 | def __copy__(self) -> "_SpecialTensor": 98 | ret = copy.copy(super()) 99 | ret.__class__ = _SpecialTensor 100 | ret._metadata = self._metadata 101 | return ret 102 | 103 | 104 | def _check_type(tensor: torch.Tensor) -> None: 105 | given_type = type(tensor) 106 | if not issubclass(given_type, torch.Tensor): 107 | raise TypeError("invalid type {!r}".format(given_type)) 108 | 109 | 110 | def tag_with_metadata(tensor: torch.Tensor, metadata: Any) -> None: 111 | """Tag a tensor with metadata.""" 112 | _check_type(tensor) 113 | tensor.__class__ = _SpecialTensor 114 | tensor._metadata = metadata 115 | 116 | 117 | RepresentibleByQuantizeAffine = collections.namedtuple( 118 | "RepresentibleByQuantizeAffine", ["quant_params"] 119 | ) 120 | 121 | 122 | def mark_quantize_affine( 123 | tensor: torch.Tensor, 124 | scale: float, 125 | zero_point: int, 126 | dtype: np.dtype = np.uint8, 127 | ) -> None: 128 | """Mark a tensor as affine-quantized. 129 | 130 | See //xnorai/training/pytorch/extensions/functions:quantize_affine for more 131 | info on this method of quantization. 132 | 133 | The tensor itself can be a floating point Tensor. However, its values must 134 | be representable with @scale and @zero_point. For performance reasons, this 135 | function does not validate whether the tensor is really quantizable as it 136 | claims to be.
137 | 138 | Arguments: 139 | tensor (torch.Tensor): The tensor to be marked as affine-quantizable 140 | Tensor. 141 | scale (float): the scale (from quantization parameters). 142 | zero_point (int): The zero_point (from quantization parameters). 143 | dtype (numpy.dtype): Type of tensor when quantized (this is usually 144 | numpy.uint8, which is used for Q8). A ValueError will be thrown if 145 | the input dtype is not one of the following: 146 | {numpy.uint8, numpy.int32}. 147 | """ 148 | allowed_dtypes = [np.uint8, np.int32] 149 | if dtype not in allowed_dtypes: 150 | raise ValueError( 151 | "Provided dtype ({}) is not supported. Please use: {}".format( 152 | dtype, allowed_dtypes 153 | ) 154 | ) 155 | quant_params = QuantizeAffineParams2(scale, zero_point, dtype) 156 | tag_with_metadata(tensor, RepresentibleByQuantizeAffine(quant_params)) 157 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autoflake==1.4 2 | black==20.8b1 3 | isort==5.8.0 4 | numpy==1.19.5 5 | scipy==1.5.4 6 | torch==1.8.1+cu111 7 | torchvision==0.9.1+cu111 8 | -------------------------------------------------------------------------------- /schedulers.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
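The sparse line layers above compute their weight as `(1 - alpha) * weight + alpha * weight1` and then sparsify it. Separately, `sparse_module_updates` in train_curve.py anneals `topk` from 1.0 down to `alpha` over the warmup iterations. The sketch below strings these steps together on flat Python lists. It assumes that the default `method="topk"` keeps the largest-magnitude fraction of weights, as the `SparseConv2d` docstring describes. `utils.apply_sparsity` itself is not shown here. The helper names and the choice to round `topk * n` to get `k` are hypothetical.

```python
def topk_schedule(alpha, current_iteration, warmup_iterations):
    # Mirrors sparse_module_updates: df decays linearly from 1 to 0, so topk
    # moves from 1.0 (dense) down to alpha over the warmup period.
    df = max(1 - current_iteration / warmup_iterations, 0)
    return alpha + (1 - alpha) * df


def line_weight(w0, w1, alpha):
    # Point on the line between the two endpoint weights, as in get_weight().
    return [(1 - alpha) * a + alpha * b for a, b in zip(w0, w1)]


def apply_topk(weights, topk):
    # Assumed magnitude pruning: keep the round(topk * n) largest |w|, zero the rest.
    k = int(round(topk * len(weights)))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)
    keep = set(order[:k])
    return [w if i in keep else 0.0 for i, w in enumerate(weights)]


w0 = [1.0, -2.0, 0.5, 0.0]
w1 = [3.0, 0.0, -0.5, 6.0]
alpha = 0.5
# After warmup has finished, topk equals alpha.
topk = topk_schedule(alpha, current_iteration=1000, warmup_iterations=1000)
w = line_weight(w0, w1, alpha)
print(topk)  # 0.5
print(w)  # [2.0, -1.0, 0.0, 3.0]
print(apply_topk(w, topk))  # [2.0, 0.0, 0.0, 3.0]
```

Tying `topk` to `alpha` means each point on the line is trained at its own sparsity level. The warmup term starts every point dense and only gradually prunes it.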
4 | # 5 | import numpy as np 6 | 7 | __all__ = ["cosine_lr"] 8 | 9 | 10 | def assign_learning_rate(optimizer, new_lr): 11 | for param_group in optimizer.param_groups: 12 | param_group["lr"] = new_lr 13 | 14 | 15 | def cosine_lr(optimizer, learning_rate, *, warmup_length, epochs): 16 | def _lr_adjuster(epoch, iteration): 17 | if epoch < warmup_length: 18 | lr = _warmup_lr(learning_rate, warmup_length, epoch) 19 | else: 20 | e = epoch - warmup_length 21 | es = epochs - warmup_length 22 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * learning_rate 23 | 24 | assign_learning_rate(optimizer, lr) 25 | print(f"Assigned lr={lr}") 26 | return lr 27 | 28 | return _lr_adjuster 29 | 30 | 31 | def _warmup_lr(base_lr, warmup_length, epoch): 32 | return base_lr * (epoch + 1) / warmup_length 33 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | pip install -r requirements.txt 2 | -------------------------------------------------------------------------------- /tools/format-python: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | autoflake --in-place --remove-unused-variables --remove-all-unused-imports $1 3 | black -l 80 $1 4 | isort -sl $1 5 | tools/remove-whitespace.sh $1 6 | -------------------------------------------------------------------------------- /tools/remove-whitespace.sh: -------------------------------------------------------------------------------- 1 | sed -i '' -E 's/[ '$'\t'']+$//' "$1" 2 | -------------------------------------------------------------------------------- /train_curve.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | # Trains a standard/sparse/quantized line/poly chain/bezier curve. 
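To check what `cosine_lr` above produces, the per-epoch rule can be evaluated without an optimizer. It applies a linear warmup of `learning_rate * (epoch + 1) / warmup_length`, followed by a half cosine from `learning_rate` toward 0. This is a standalone re-statement of the same formula. The function name `lr_at` is hypothetical, and `math` is used in place of NumPy. As written, the scheduler divides by `epochs - warmup_length`, so it assumes `epochs > warmup_length`.

```python
import math


def lr_at(epoch, learning_rate, warmup_length, epochs):
    """Learning rate that cosine_lr's adjuster would assign at `epoch`."""
    if epoch < warmup_length:
        # Linear warmup: reaches learning_rate at epoch == warmup_length - 1.
        return learning_rate * (epoch + 1) / warmup_length
    e = epoch - warmup_length
    es = epochs - warmup_length  # must be > 0
    return 0.5 * (1 + math.cos(math.pi * e / es)) * learning_rate


schedule = [lr_at(ep, 0.1, warmup_length=5, epochs=100) for ep in range(100)]
print(schedule[:6])
print(schedule[-1])
```

Suppose the scheduler is built with the same `epochs` that `train_model` loops over. Because the loop calls `scheduler(epoch, None)` for `epoch` in `range(epochs)`, the last epoch uses `epoch = epochs - 1`, so the rate approaches zero without reaching it.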
7 | 8 | import random 9 | 10 | import numpy as np 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | import torch.nn as nn 14 | 15 | import curve_utils 16 | import model_logging 17 | import models.networks as models 18 | import utils 19 | from models import modules 20 | from models.builder import Builder 21 | from models.networks import resprune 22 | from models.networks import utils as network_utils 23 | from models.networks import vggprune 24 | from models.quantized_modules import ConvBn2d 25 | 26 | 27 | def sparse_module_updates(model, training=False, alpha=None, **regime_params): 28 | current_iteration = regime_params["current_iteration"] 29 | warmup_iterations = regime_params["warmup_iterations"] 30 | 31 | if alpha is None: 32 | alpha = curve_utils.alpha_sampling(**regime_params) 33 | 34 | df = np.max([1 - current_iteration / warmup_iterations, 0]) 35 | topk = alpha + (1 - alpha) * df 36 | 37 | is_standard = regime_params["is_standard"] 38 | 39 | for m in model.modules(): 40 | if isinstance(m, nn.Conv2d) or isinstance( 41 | m, (nn.modules.batchnorm._NormBase, nn.GroupNorm) 42 | ): 43 | setattr(m, f"alpha", alpha) 44 | 45 | if (not is_standard) and hasattr(m, "topk"): 46 | setattr(m, "topk", topk) 47 | 48 | if training: 49 | regime_params["current_iteration"] += 1 50 | 51 | return model, regime_params 52 | 53 | 54 | def quantized_module_updates( 55 | model, training=False, alpha=None, **regime_params 56 | ): 57 | if alpha is None: 58 | alpha, num_bits = curve_utils.sample_alpha_num_bits(**regime_params) 59 | else: 60 | num_bits = curve_utils.alpha_bit_map(alpha, **regime_params) 61 | 62 | is_standard = regime_params["is_standard"] 63 | 64 | for m in model.modules(): 65 | if isinstance(m, nn.Conv2d) or isinstance( 66 | m, (nn.modules.batchnorm._NormBase, nn.GroupNorm) 67 | ): 68 | setattr(m, f"alpha", alpha) 69 | 70 | if isinstance(m, ConvBn2d): 71 | if not is_standard: 72 | m.update_num_bits(num_bits) 73 | 74 | if training: 75 | m.start_counter() 
76 | else: 77 | m.stop_counter() 78 | 79 | return model, regime_params 80 | 81 | 82 | def _test_time_lec_update(model, **regime_params): 83 | # This requires that the topk values are already set on the model. 84 | 85 | # We create a whole new copy of the model which is pruned. 86 | model_kwargs = regime_params["model_kwargs"] 87 | fresh_copy = utils.make_fresh_copy_of_pruned_network(model, model_kwargs) 88 | cfg, cfg_mask = network_utils.get_slim_configs(fresh_copy) 89 | 90 | builder = Builder(conv_type="StandardConv", bn_type="StandardIN") 91 | 92 | try: 93 | if isinstance(model, models.cpreresnet): 94 | model_class = resprune 95 | elif isinstance(model, models.vgg.vgg): 96 | model_class = vggprune 97 | else: 98 | raise ValueError( 99 | "Model {} is not supported for LEC.".format(model) 100 | ) 101 | 102 | _, slimmed_network = model_class.get_slimmed_network( 103 | fresh_copy.module, 104 | {"builder": builder, "block_builder": builder, **model_kwargs}, 105 | cfg, 106 | cfg_mask, 107 | ) 108 | except: 109 | print( 110 | f"Something went wrong during LEC. Most likely, an entire " 111 | f"layer was deleted. Using @fresh_copy." 112 | ) 113 | slimmed_network = fresh_copy 114 | num_parameters = sum( 115 | [param.nelement() for param in slimmed_network.parameters()] 116 | ) 117 | 118 | # NOTE: DO NOT use @model here, since it has too many extra buffers in the 119 | # case of training a line. 120 | total_params = sum([param.nelement() for param in fresh_copy.parameters()]) 121 | regime_params["sparsity"] = (total_params - num_parameters) / total_params 122 | 123 | print(f"Got sparsity level of {regime_params['sparsity']}") 124 | 125 | return slimmed_network, regime_params 126 | 127 | 128 | def lec_update(model, training=False, alpha=None, **regime_params): 129 | # Note: this file is for training a *curve*, so we do the right thing for 130 | # curves here.
Namely, we apply the same update to alpha/topk as when 131 | # training a standard curve, but we gather the sparsity in a different way 132 | # since we're doing LEC rather than doing the unstructured sparsity. 133 | model, regime_params = sparse_module_updates( 134 | model, training=training, alpha=alpha, **regime_params 135 | ) 136 | 137 | if training: 138 | return model, regime_params 139 | else: 140 | return _test_time_lec_update(model, **regime_params) 141 | 142 | 143 | def us_update(model, training=False, alpha=None, **regime_params): 144 | assert alpha is not None, f"Alpha value is required." 145 | 146 | # Use alpha to compute the width factor. 147 | low_factor, high_factor = regime_params["width_factor_limits"] 148 | width_factor = low_factor + (high_factor - low_factor) * alpha 149 | regime_params["width_factor"] = width_factor 150 | 151 | for module in model.modules(): 152 | if hasattr(module, "width_factor"): 153 | module.width_factor = width_factor 154 | 155 | for m in model.modules(): 156 | if isinstance(m, nn.Conv2d) or isinstance( 157 | m, (nn.modules.batchnorm._NormBase, nn.GroupNorm) 158 | ): 159 | setattr(m, f"alpha", alpha) 160 | 161 | return model, regime_params 162 | 163 | 164 | def train( 165 | model, train_loader, optimizer, criterion, epoch, device, **regime_params 166 | ): 167 | 168 | model.zero_grad() 169 | model.train() 170 | avg_loss = 0.0 171 | 172 | regime_update_dict = { 173 | "sparse": sparse_module_updates, 174 | "quantized": quantized_module_updates, 175 | "lec": lec_update, 176 | "us": us_update, 177 | } 178 | module_update = regime_update_dict[regime_params["regime"]] 179 | 180 | for batch_idx, (data, target) in enumerate(train_loader): 181 | 182 | data, target = data.to(device, non_blocking=True), target.to( 183 | device, non_blocking=True 184 | ) 185 | 186 | optimizer.zero_grad() 187 | 188 | if regime_params["regime"] == "us": 189 | loss = 0 190 | 191 | if regime_params["width_factor_sampling_method"] == "sandwich": 192 | 
n_samples = regime_params["width_factor_samples"] 193 | 194 | assert n_samples >= 2, f"Require n_samples>=2, got {n_samples}" 195 | 196 | alphas = [0, 1] 197 | for i in range(n_samples - 2): 198 | alphas.append(random.uniform(0.0, 1.0)) 199 | elif regime_params["width_factor_sampling_method"] == "point": 200 | alphas = [random.uniform(0.0, 1.0)] 201 | else: 202 | raise NotImplementedError 203 | 204 | for alpha in alphas: 205 | model, regime_params = module_update( 206 | model, True, alpha=alpha, **regime_params 207 | ) 208 | output = model(data) 209 | loss += criterion(output, target) 210 | 211 | else: 212 | model, regime_params = module_update( 213 | model, training=True, **regime_params 214 | ) 215 | output = model(data) 216 | loss = criterion(output, target) 217 | 218 | # Application of the regularization term, equation 3. 219 | num_points = regime_params["num_points"] 220 | beta = regime_params.get("beta", 1) 221 | if beta > 0 and num_points > 1: 222 | out = random.sample([i for i in range(num_points)], 2) 223 | 224 | i, j = out[0], out[1] 225 | num = 0.0 226 | normi = 0.0 227 | normj = 0.0 228 | for m in model.modules(): 229 | # Apply beta term if we have a conv, and (optionally) if we have 230 | # a norm layer. Only apply beta term if alpha exists (e.g. it's 231 | # a line). 232 | # (We forbid an exact type match because "plain-old" Conv2d 233 | # layers [as in LEC] should not trigger this logic). 
                should_apply_beta = isinstance(m, nn.Conv2d) and not type(
                    m
                ) in (nn.Conv2d, modules.AdaptiveConv2d)
                should_apply_beta = should_apply_beta or (
                    isinstance(
                        m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)
                    )
                    and regime_params.get("apply_beta_to_norm", False)
                )
                should_apply_beta = should_apply_beta and hasattr(m, "alpha")
                if should_apply_beta:
                    vi = curve_utils.get_weight(m, i)
                    vj = curve_utils.get_weight(m, j)
                    num += (vi * vj).sum()
                    normi += vi.pow(2).sum()
                    normj += vj.pow(2).sum()
            loss += beta * (num.pow(2) / (normi * normj))

        loss.backward()
        if regime_params.get("bn_update_factor") is not None:
            loss_reg_term = (
                utils.apply_in_topk_reg(model, apply_to_bias=False)
                * regime_params["bn_update_factor"]
            )
        else:
            loss_reg_term = 0
        optimizer.step()

        avg_loss += loss.item() + loss_reg_term

        log_interval = 10

        if batch_idx % log_interval == 0:
            num_samples = batch_idx * len(data)
            dataset_size = len(train_loader.dataset)
            percent_complete = 100.0 * batch_idx / len(train_loader)

            predicted_labels = output.argmax(dim=1)
            corrects = (
                predicted_labels == target
            ).float().sum() / target.numel()

            print(
                f"Train Epoch: {epoch} [{num_samples}/{dataset_size} ({percent_complete:.0f}%)]\t"
                f"Loss: {loss.item():.6f} Correct: {corrects.item():.4f}"
            )

    model.apply(lambda m: setattr(m, "return_feats", False))

    avg_loss = avg_loss / len(train_loader)

    return avg_loss, regime_params


def test(
    model,
    alpha,
    val_loader,
    criterion,
    epoch,
    device,
    metric_dict=None,
    **regime_params,
):

    model.zero_grad()
    model.eval()
    test_loss = 0
    correct = 0

    regime_update_dict = {
        "sparse": sparse_module_updates,
        "quantized": quantized_module_updates,
        "lec": lec_update,
        "us": us_update,
    }
    logging_dict = {
        "sparse": model_logging.sparse_logging,
        "quantized": model_logging.quantized_logging,
        "lec": model_logging.lec_logging,
        "us": model_logging.us_logging,
    }
    module_update = regime_update_dict[regime_params["regime"]]
    logging = logging_dict[regime_params["regime"]]

    model, regime_params = module_update(model, alpha=alpha, **regime_params)

    # optionally add the hooks needed for tracking BN accuracy stats.
    if regime_params.get("bn_accuracy_stats", True):
        hooks, mean_dict, var_dict = utils.register_bn_tracking_hooks(model)

    with torch.no_grad():

        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.to(device, non_blocking=True), target.to(
                device, non_blocking=True
            )

            output = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()

    if regime_params.get("bn_accuracy_stats", True):
        utils.unregister_bn_tracking_hooks(hooks)
        extra_metrics = utils.get_bn_accuracy_metrics(
            model, mean_dict, var_dict
        )

        # Mean_dict and var_dict contain a mapping from modules to their
    else:
        extra_metrics = {}

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    return logging(
        model,
        test_loss,
        test_acc,
        model_type="curve",
        param=alpha,
        metric_dict=metric_dict,
        extra_metrics=extra_metrics,
        **regime_params,
    )


def train_model(config):
    # Get network, data, and training parameters
    (
        net,
        opt_params,
        train_params,
        regime_params,
        data,
    ) = utils.network_and_params(config)

    # Unpack training parameters
    epochs, test_freq, alpha_grid = train_params

    # Unpack dataset
    trainset, trainloader, testset, testloader = data

    # Unpack optimization parameters
    criterion, optimizer, scheduler = opt_params

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = net.to(device)

    if device == "cuda":
        cudnn.benchmark = True
        net = torch.nn.DataParallel(net)

    start_epoch = 0

    metric_dict = {"acc": [], "alpha": []}
    if regime_params["regime"] == "quantized":
        metric_dict["num_bits"] = []
    if regime_params["regime"] in (
        "sparse",
        "lec",
        "us",
        "ns",
    ):
        metric_dict["sparsity"] = []

    save_dir = regime_params["save_dir"]

    for epoch in range(start_epoch, start_epoch + epochs):
        scheduler(epoch, None)
        if epoch % test_freq == 0:
            for alpha in alpha_grid:
                test(
                    net,
                    alpha,
                    testloader,
                    criterion,
                    epoch,
                    device,
                    metric_dict=None,
                    **regime_params,
                )

            model_logging.save_model_at_epoch(net, epoch, save_dir)

        _, regime_params = train(
            net,
            trainloader,
            optimizer,
            criterion,
            epoch,
            device,
            **regime_params,
        )

    # Save final model
    for alpha in alpha_grid:
        metric_dict = test(
            net,
            alpha,
            testloader,
            criterion,
            epoch,
            device,
            metric_dict=metric_dict,
            **regime_params,
        )

    model_logging.save_model_at_epoch(net, epoch + 1, save_dir)
    model_logging.save("test_metrics.npy", metric_dict, save_dir)

--------------------------------------------------------------------------------
/train_indep.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch import nn

import model_logging
import models.networks as networks
import utils
from models.builder import Builder
from models.networks import resprune
from models.networks import utils as network_utils
from models.networks import vggprune
from models.quantized_modules import ConvBn2d
from models.sparse_modules import SparseConv2d


def sparse_module_updates(model, training=False, **regime_params):
    topk = regime_params["topk"]
    current_iteration = regime_params["current_iteration"]
    warmup_iterations = regime_params["warmup_iterations"]
    topk_lower, topk_upper = regime_params["alpha_sampling"][0:2]
    if regime_params["random_param"]:
        if training:
            if current_iteration < warmup_iterations:
                current_topk = topk_upper
            else:
                # Unbiased sampling
                current_topk = np.random.uniform(topk_lower, topk_upper)
        else:
            current_topk = regime_params["topk"]
    else:
        df = np.max([1 - current_iteration / warmup_iterations, 0])
        current_topk = topk + (1 - topk) * df

    for m in model.modules():
        if isinstance(m, SparseConv2d):
            setattr(m, f"topk", current_topk)

    if training:
        regime_params["current_iteration"] += 1

    return model, regime_params


def quantized_module_updates(model, training=False, **regime_params):
    set_bits = regime_params["num_bits"]

    for m in model.modules():
        if isinstance(m, ConvBn2d):
            if regime_params["random_param"]:
                if training:
                    m.start_counter()
                    if m.is_active():
                        bits = np.arange(
                            regime_params["min_bits"],
                            regime_params["max_bits"] + 1,
                        )
                        rand_bits = np.random.choice(bits)
                        m.update_num_bits(rand_bits)
                    else:
                        m.update_num_bits(regime_params["max_bits"])
                else:
                    m.stop_counter()
                    m.update_num_bits(set_bits)
            else:
                m.update_num_bits(set_bits)
                if training:
                    m.start_counter()
                else:
                    m.stop_counter()

    return model, regime_params


def lec_update(model, training=False, **regime_params):
    # The original LEC paper does the update using a global threshold, so we
    # adopt that strategy here.
    model, regime_params = sparse_module_updates(
        model, training=training, **regime_params
    )

    if training:
        return model, regime_params
    else:
        # We create a pruned copy of the model.
        model_kwargs = regime_params["model_kwargs"]
        fresh_copy = utils.make_fresh_copy_of_pruned_network(
            model, model_kwargs
        )

        # The @fresh_copy needs to have its smallest InstanceNorm parameters
        # deleted.
        topk = regime_params["topk"]
        all_weights = []
        for m in fresh_copy.modules():
            if isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
                all_weights.append(m.weight.abs())

        all_weights = torch.cat(all_weights, dim=0)
        y, i = torch.sort(all_weights)
        threshold = y[int(all_weights.shape[0] * (1.0 - topk))]

        for m in fresh_copy.modules():
            if isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
                mask = m.weight.data.clone().abs().gt(threshold).float()
                m.weight.data.mul_(mask)
                m.bias.data.mul_(mask)

        # Now that we have the sparse copy, we slim it down.
        cfg, cfg_mask = network_utils.get_slim_configs(fresh_copy)

        builder = Builder(
            conv_type="StandardConv", bn_type=regime_params["bn_type"]
        )

        try:
            if isinstance(model, nn.DataParallel):
                model = model.module

            if isinstance(model, networks.cpreresnet):
                model_class = resprune
            elif isinstance(model, networks.vgg.vgg):
                model_class = vggprune
            else:
                raise ValueError(
                    "Model {} is not supported for LEC.".format(model)
                )

            _, slimmed_network = model_class.get_slimmed_network(
                fresh_copy.module,
                {"builder": builder, "block_builder": builder, **model_kwargs},
                cfg,
                cfg_mask,
            )
        except IndexError:
            # This is the error if we eliminate a whole layer.
            print(
                f"Something went wrong during LEC - most likely, an entire "
                f"layer was deleted. Using @fresh_copy."
            )
            slimmed_network = fresh_copy
        num_parameters = sum(
            [param.nelement() for param in slimmed_network.parameters()]
        )

        # NOTE: DO NOT use @model here, since it has too many extra buffers.
        total_params = sum(
            [param.nelement() for param in fresh_copy.parameters()]
        )
        regime_params["sparsity"] = (
            total_params - num_parameters
        ) / total_params

        print(f"Got sparsity level of {regime_params['sparsity']}")

        return slimmed_network, regime_params


def ns_update(model, training=False, **regime_params):
    current_width = regime_params["width_factor"]

    for module in model.modules():
        if hasattr(module, "width_factor"):
            module.width_factor = current_width

    return model, regime_params


def us_update(*args, **kwargs):
    return ns_update(*args, **kwargs)


def train(
    model, train_loader, optimizer, criterion, epoch, device, **regime_params
):
    model.zero_grad()
    model.train()
    avg_loss = 0.0

    regime_update_dict = {
        "sparse": sparse_module_updates,
        "quantized": quantized_module_updates,
        "lec": lec_update,
        "ns": ns_update,
        "us": us_update,  # [sic]
    }
    module_update = regime_update_dict[regime_params["regime"]]

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(
            device, non_blocking=True
        )

        # Special case: you want to do multiple iterations of training.
        if regime_params["regime"] == "ns":
            loss = 0
            optimizer.zero_grad()
            for width_factor in regime_params["width_factors_list"]:
                regime_params["width_factor"] = width_factor

                model, regime_params = module_update(
                    model, True, **regime_params
                )
                output = model(data)
                loss += criterion(output, target)
        elif regime_params["regime"] == "us":
            loss = 0
            optimizer.zero_grad()

            low_factor, high_factor = regime_params["width_factor_limits"]
            n_samples = regime_params["width_factor_samples"]

            assert n_samples >= 2, f"Require n_samples>=2, got {n_samples}"

            width_factors = [low_factor, high_factor]
            for i in range(n_samples - 2):
                width_factors.append(random.uniform(low_factor, high_factor))

            for width_factor in width_factors:
                regime_params["width_factor"] = width_factor

                model, regime_params = module_update(
                    model, True, **regime_params
                )
                output = model(data)
                loss += criterion(output, target)
        else:
            model, regime_params = module_update(model, True, **regime_params)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)

        loss.backward()
        if regime_params.get("bn_update_factor") is not None:
            apply_norm_regularization(model, regime_params["bn_update_factor"])
        optimizer.step()

        avg_loss += loss.item()

        log_interval = 10

        if batch_idx % log_interval == 0:
            num_samples = batch_idx * len(data)
            dataset_size = len(train_loader.dataset)
            percent_complete = 100.0 * batch_idx / len(train_loader)

            predicted_labels = output.argmax(dim=1)
            corrects = (
                predicted_labels == target
            ).float().sum() / target.numel()

            print(
                f"Train Epoch: {epoch} [{num_samples}/{dataset_size} ({percent_complete:.0f}%)]\t"
                f"Loss: {loss.item():.6f} Correct: {corrects.item():.4f}"
            )

    model.apply(lambda m: setattr(m, "return_feats", False))
    avg_loss = avg_loss / len(train_loader)

    return avg_loss, regime_params


def apply_norm_regularization(model, s_factor):
    count = 0
    for m in model.modules():
        if isinstance(m, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
            count += 1
            m.weight.grad.data.add_(s_factor * torch.sign(m.weight.data))  # L1
    if count == 0:
        raise ValueError(f"Didn't adjust any Norms")


def test(
    model,
    param,
    val_loader,
    criterion,
    epoch,
    device,
    metric_dict=None,
    **regime_params,
):
    # Set eval param
    if regime_params["regime"] in ("sparse", "lec"):
        regime_params["topk"] = param
    elif regime_params["regime"] in ("ns", "us"):
        regime_params["width_factor"] = param
    else:
        regime_params["num_bits"] = param

    metric_dict = _test(
        model,
        param,
        val_loader,
        criterion,
        epoch,
        device,
        metric_dict=metric_dict,
        **regime_params,
    )
    return metric_dict


def _test(
    model,
    param,
    val_loader,
    criterion,
    epoch,
    device,
    metric_dict=None,
    **regime_params,
):

    model.zero_grad()
    model.eval()
    test_loss = 0
    correct = 0

    regime_update_dict = {
        "sparse": sparse_module_updates,
        "quantized": quantized_module_updates,
        "lec": lec_update,
        "ns": ns_update,
        "us": us_update,
    }
    logging_dict = {
        "sparse": model_logging.sparse_logging,
        "quantized": model_logging.quantized_logging,
        "lec": model_logging.lec_logging,
        "ns": model_logging.ns_logging,
        "us": model_logging.us_logging,
    }
    module_update = regime_update_dict[regime_params["regime"]]
    logging = logging_dict[regime_params["regime"]]

    model, regime_params = module_update(model, **regime_params)

    # optionally add the hooks needed for tracking BN accuracy stats.
    if regime_params.get("bn_accuracy_stats", True):
        hooks, mean_dict, var_dict = utils.register_bn_tracking_hooks(model)

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(val_loader):

            data, target = data.to(device, non_blocking=True), target.to(
                device, non_blocking=True
            )

            output = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()

    if regime_params.get("bn_accuracy_stats", True):
        utils.unregister_bn_tracking_hooks(hooks)
        extra_metrics = utils.get_bn_accuracy_metrics(
            model, mean_dict, var_dict
        )

        # Mean_dict and var_dict contain a mapping from modules to their
    else:
        extra_metrics = {}

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    return logging(
        model,
        test_loss,
        test_acc,
        model_type="indep",
        param=param,
        metric_dict=metric_dict,
        extra_metrics=extra_metrics,
        **regime_params,
    )


def train_model(config):
    # Get network, data, and training parameters
    (
        net,
        opt_params,
        train_params,
        regime_params,
        data,
    ) = utils.network_and_params(config)

    # Unpack training parameters
    epochs, test_freq, _ = train_params

    # Unpack dataset
    trainset, trainloader, testset, testloader = data

    # Unpack optimizer parameters
    criterion, optimizer, scheduler = opt_params

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = net.to(device)

    if device == "cuda":
        cudnn.benchmark = True
        net = torch.nn.DataParallel(net)

    start_epoch = 0

    metric_dict = {"acc": [], "alpha": []}
    if regime_params["regime"] == "quantized":
        metric_dict["num_bits"] = []
    if regime_params["regime"] in (
        "sparse",
        "lec",
        "us",
        "ns",
    ):
        metric_dict["sparsity"] = []

    # Get evaluation parameter grid
    eval_param_grid = regime_params["eval_param_grid"]

    save_dir = regime_params["save_dir"]

    # Training loop
    for epoch in range(start_epoch, start_epoch + epochs):
        scheduler(epoch, None)
        if epoch % test_freq == 0:
            for param in eval_param_grid:
                test(
                    net,
                    param,
                    testloader,
                    criterion,
                    epoch,
                    device,
                    metric_dict=None,
                    **regime_params,
                )

            model_logging.save_model_at_epoch(net, epoch, save_dir)

        _, regime_params = train(
            net,
            trainloader,
            optimizer,
            criterion,
            epoch,
            device,
            **regime_params,
        )

    # Save final model
    for param in eval_param_grid:
        metric_dict = test(
            net,
            param,
            testloader,
            criterion,
            epoch,
            device,
            metric_dict=metric_dict,
            **regime_params,
        )

    model_logging.save_model_at_epoch(net, epoch + 1, save_dir)
    model_logging.save("test_metrics.npy", metric_dict, save_dir)

--------------------------------------------------------------------------------
/train_quantized.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
import args
import main

if __name__ == "__main__":
    args = args.quantized_args()
    main.train(args, "quantized")

--------------------------------------------------------------------------------
/train_structured.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
import args
import main

if __name__ == "__main__":
    args = args.structured_args()
    main.train(args, "structured_sparsity")

--------------------------------------------------------------------------------
/train_unstructured.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
import args
import main

if __name__ == "__main__":
    args = args.unstructured_args()
    main.train(args, "unstructured_sparsity")

--------------------------------------------------------------------------------
/training_params.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
# Training parameters

DEFAULT_PARAMS = {
    "epochs": 200,
    "test_freq": 20,
    "batch_size": 128,
    "momentum": 0.9,
    "weight_decay": 0.0005,
}

DEFAULT_IMAGENET_PARAMS = {
    **DEFAULT_PARAMS,
    "epochs": 90,
    "test_freq": 1,
    "batch_size": 128,
    "weight_decay": 0.00005,
}

VGG_IMAGENET_PARAMS = {
    **DEFAULT_IMAGENET_PARAMS,
    "batch_size": 256,
}


def model_data_params(args):
    model = args.model
    dataset = args.dataset

    if dataset == "imagenet":
        if model == "vgg19":
            return VGG_IMAGENET_PARAMS
        elif model == "resnet18":
            return DEFAULT_IMAGENET_PARAMS
        else:
            raise NotImplementedError(
                f"No training parameters for {model}/{dataset}"
            )
    elif "cifar" in dataset:
        return DEFAULT_PARAMS
    else:
        raise NotImplementedError(
            f"No training parameters for {model}/{dataset}"
        )

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
import collections
import os
import time
from typing import Callable
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import yaml

import schedulers
from models import modules
from models import networks
from models.builder import Builder
from models.sparse_modules import SparseConv2d


def get_cifar10_data(batch_size):
    print("==> Preparing CIFAR-10 data...")
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    transform_test = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform_train
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )

    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform_test
    )
    # Note that we perform an analysis of BatchNorm statistics when validating,
    # so we *must* shuffle the validation set.
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )

    return trainset, trainloader, testset, testloader


def get_imagenet_data(data_dir, batch_size):
    print("==> Preparing ImageNet data...")

    traindir = os.path.join(data_dir, "training")
    valdir = os.path.join(data_dir, "validation")
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    trainset = torchvision.datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16,
        pin_memory=True,
    )

    testset = torchvision.datasets.ImageFolder(
        valdir,
        transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    # Note that we perform an analysis of BatchNorm statistics when validating,
    # so we *must* shuffle the validation set.
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16,
        pin_memory=True,
    )

    return trainset, trainloader, testset, testloader


def get_yaml_config(config_file):
    print(f"Reading config {config_file}")
    with open(config_file) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    return config


def print_train_params(config, setting, method, norm, save_dir):
    params = config["parameters"]

    model = params["model_config"]["model_class"]
    dataset = params["dataset"]

    model_name_map = {
        "cpreresnet20": "cPreResNet20",
        "resnet18": "ResNet18",
        "vgg19": "VGG19",
    }

    method_name_map = {
        "lcs_l": "LCS+L",
        "lcs_p": "LCS+P",
        "ns": "NS",
        "us": "US",
        "lec": "LEC",
        "target_topk": "TopK Target",
        "target_bit_width": "Bit Width Target",
    }

    dataset_name_map = {"cifar10": "CIFAR-10", "imagenet": "ImageNet"}

    setting_name_map = {
        "unstructured_sparsity": "Unstructured sparsity",
        "structured_sparsity": "Structured sparsity",
        "quantized": "Quantized",
    }

    msg = f"{setting_name_map[setting]} ({method_name_map[method]}) training:"
    msg += f" {model_name_map[model]} on {dataset_name_map[dataset]} w/ {norm}."

    if method == "target_topk":
        topk_target = params["topk"]
        msg += f" TopK target: {topk_target}."
    elif method == "target_bit_width":
        bit_width_target = params["num_bits"]
        msg += f" Bit width target: {bit_width_target}."

    print()
    print(msg)
    print(f"Saving to {save_dir}.")
    print()
    time.sleep(5)


def create_save_dir(base_save, method, model, dataset, norm, use_default=True):
    if use_default:
        save_dir = f"{base_save}/{method}/{model}/{dataset}/{norm}"
    else:
        save_dir = base_save

    if not os.path.isdir(save_dir):
        os.makedirs(save_dir, exist_ok=True)

    return save_dir


def network_and_params(config=None):
    """
    Returns the network and training parameters for the specified model type.
    """

    model_config = config["parameters"].get("model_config")
    model_class = model_config["model_class"]
    model_class_dict = {
        "cpreresnet20": networks.cpreresnet,
        "resnet18": networks.ResNet18,
        "vgg19": networks.vgg19,
    }
    if model_config is not None:
        # Get the class.
        if model_class in model_class_dict:
            model_class = model_class_dict[model_class]
        else:
            raise NotImplementedError(
                f"Invalid model_class={model_config['model_class']}"
            )
        if "model_kwargs" in model_config:
            extra_model_kwargs = model_config["model_kwargs"]
        else:
            extra_model_kwargs = {}
    else:
        extra_model_kwargs = {}

    # General params
    epochs = config["parameters"].get("epochs", 200)

    test_freq = config["parameters"].get("test_freq", 20)
    batch_size = config["parameters"].get("batch_size", 128)
    learning_rate = config["parameters"].get("learning_rate", 0.01)
    momentum = config["parameters"].get("momentum", 0.9)
    weight_decay = config["parameters"].get("weight_decay", 0.0005)
    warmup_budget = config["parameters"].get("warmup_budget", 80) / 100.0
    dataset = config["parameters"].get("dataset", "cifar10")
    alpha_grid = config["parameters"].get("alpha_grid", None)

    regime = config["parameters"]["regime"]

    if dataset == "cifar10":
        data = get_cifar10_data(batch_size)
    elif dataset == "imagenet":
        imagenet_dir = config["parameters"]["dataset_dir"]
        data = get_imagenet_data(imagenet_dir, batch_size)
    else:
        raise ValueError(f"Dataset {dataset} not supported")

    train_size = len(data[0])
    warmup_iterations = np.ceil(
        warmup_budget * epochs * train_size / batch_size
    )

    # Get model layers
    conv_type = config["parameters"]["conv_type"]
    bn_type = config["parameters"]["bn_type"]
    block_conv_type = config["parameters"]["block_conv_type"]
    block_bn_type = config["parameters"]["block_bn_type"]

    # Get regime-specific parameters
    regime_params = config["parameters"].get("regime_params", {})
    regime_params["regime"] = regime
    builder_parms = {}
    block_builder_params = {}

    # Append dataset to extra_model_kwargs args
    extra_model_kwargs["dataset"] = dataset

    if regime == "sparse":
        if "Sparse" not in block_conv_type:
            raise ValueError(
                "Regime set to sparse but non-sparse convolution layer received..."
            )
        regime_params["topk"] = config["parameters"].get("topk", 0.0)
        regime_params["current_iteration"] = 0
        regime_params["warmup_iterations"] = warmup_iterations
        regime_params["alpha_sampling"] = config["parameters"].get(
            "alpha_sampling", [0.025, 1, 0]
        )

        method = config["parameters"].get("method", "topk")
        block_builder_params["method"] = method
        if "Sparse" in conv_type:
            builder_parms["method"] = method

    elif regime == "lec":
        regime_params["topk"] = config["parameters"].get("topk", 0.0)
        regime_params["current_iteration"] = 0
        regime_params["warmup_iterations"] = warmup_iterations
        regime_params["alpha_sampling"] = config["parameters"].get(
            "alpha_sampling", [0, 1, 0]
        )
        regime_params["model_kwargs"] = {
            "dataset": dataset,
            **extra_model_kwargs,
        }
        regime_params["bn_update_factor"] = config["parameters"].get(
            "bn_update_factor", 0
        )

        regime_params["bn_type"] = config["parameters"]["bn_type"]

    elif regime == "ns":
        width_factors_list = config["parameters"]["builder_kwargs"][
            "width_factors_list"
        ]
        regime_params["width_factors_list"] = width_factors_list

        builder_parms["pass_first_last"] = True

        block_builder_params["pass_first_last"] = True

        if config["parameters"]["block_conv_type"] != "AdaptiveConv2d":
            block_builder_params["width_factors_list"] = width_factors_list
            builder_parms["width_factors_list"] = width_factors_list

        if "BN" in config["parameters"]["bn_type"]:
            norm_kwargs = config["parameters"].get("norm_kwargs", {})

            builder_parms["norm_kwargs"] = {
                "width_factors_list": width_factors_list,
                **norm_kwargs,
            }
            block_builder_params["norm_kwargs"] = {
                "width_factors_list": width_factors_list,
                **norm_kwargs,
            }

        regime_params["bn_type"] = config["parameters"]["bn_type"]

    elif regime == "us":
        builder_parms["pass_first_last"] = True
        block_builder_params["pass_first_last"] = True

        if "BN" in config["parameters"]["bn_type"]:
            norm_kwargs = config["parameters"]["norm_kwargs"]
            assert "width_factors_list" in norm_kwargs

            block_builder_params["norm_kwargs"] = norm_kwargs
            builder_parms["norm_kwargs"] = norm_kwargs

    elif regime == "quantized":
        if "ConvBn2d" not in block_conv_type:
            raise ValueError(
                "Regime set to quantized but non-quantized convolution layer received..."
            )
        block_builder_params["num_bits"] = config["parameters"].get(
            "num_bits", 8
        )
        block_builder_params["iteration_delay"] = warmup_iterations

        if conv_type == "ConvBn2d":
            builder_parms["num_bits"] = config["parameters"].get("num_bits", 8)
            builder_parms["iteration_delay"] = warmup_iterations

        regime_params["min_bits"] = config["parameters"].get("min_bits", 2)
        regime_params["max_bits"] = config["parameters"].get("max_bits", 8)
        regime_params["num_bits"] = config["parameters"].get("num_bits", 8)
        regime_params["discrete"] = config["parameters"].get(
            "discrete_alpha_map", False
        )

    regime_params["is_standard"] = config["parameters"].get(
        "is_standard", False
    )
    regime_params["random_param"] = config["parameters"].get(
        "random_param", False
    )
    regime_params["num_points"] = config["parameters"].get("num_points", 0)

    # Evaluation parameters for independent models
    regime_params["eval_param_grid"] = config["parameters"].get(
        "eval_param_grid", None
    )

    # If norm_kwargs haven't been set, and they are present, add them to
    # builder_params and block_builder_params.
    if "norm_kwargs" not in builder_parms:
        builder_parms["norm_kwargs"] = config["parameters"].get(
            "norm_kwargs", {}
        )
    if "norm_kwargs" not in block_builder_params:
        block_builder_params["norm_kwargs"] = config["parameters"].get(
            "norm_kwargs", {}
        )

    # Construct network
    builder = Builder(conv_type=conv_type, bn_type=bn_type, **builder_parms)
    block_builder = Builder(
        block_conv_type, block_bn_type, **block_builder_params
    )

    net = model_class(
        builder=builder, block_builder=block_builder, **extra_model_kwargs
    )

    # Input size
    regime_params["input_size"] = get_input_size(dataset)

    # Save directory
    regime_params["save_dir"] = config["parameters"]["save_dir"]

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        net.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
    )

    scheduler = schedulers.cosine_lr(
        optimizer, learning_rate, warmup_length=5, epochs=epochs
    )

    train_params = [epochs, test_freq, alpha_grid]
    opt_params = [criterion, optimizer, scheduler]

    print(f"Got network:\n{net}")

    return net, opt_params, train_params, regime_params, data


def get_net(model_dir, net, num_gpus=1):

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if device == "cuda":
        cudnn.benchmark = True
        state_dict = torch.load(model_dir)
    else:
        state_dict = torch.load(model_dir, map_location=torch.device("cpu"))

    # Set device ids for DataParallel.
    if device == "cuda" and num_gpus > 0:
        device_ids = list(range(num_gpus))
    else:
        device_ids = None
    network = torch.nn.DataParallel(net, device_ids=device_ids)
    network.load_state_dict(state_dict)

    return network


def get_sparsity_rate(model) -> float:
    total_params = 0
    sparse_params = 0
    for module in model.modules():
        if isinstance(module, SparseConv2d):
            sparse_weight = module.apply_sparsity(module.weight)
            total_params += sparse_weight.numel()
            # Count the entries zeroed out by the sparsity mask.
            sparse_params += (sparse_weight == 0).float().sum().item()
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            # Also count weights of non-sparse layers in the denominator.
            total_params += module.weight.numel()
            # Bias parameters are ignored; they are a negligible fraction.
    sparsity_rate = sparse_params / np.max([total_params, 1])
    return sparsity_rate


def apply_sparsity(
    topk, weight: torch.Tensor, return_scale_factors=False, *, method
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if method == "topk":
        return apply_topk(
            topk, weight, return_scale_factors=return_scale_factors
        )
    else:
        raise NotImplementedError(f"Invalid sparsity method {method}")


def apply_topk(topk: float, weight: torch.Tensor, return_scale_factors=False):
    """
    Keep the ``topk`` fraction of entries of ``weight`` with the largest
    magnitudes and multiply the rest by 0.
    Inputs:
    - topk: Fraction of weights to keep, in [0, 1].
    - weight: A weight tensor, e.g., module.weight
    - return_scale_factors: If True, also return the 0/1 mask that was
      applied, with the same shape as ``weight``.
    """
    # Retain only the topk weights, multiplying the rest by 0.
    frac_to_zero = 1 - topk
    with torch.no_grad():
        flat_weight = weight.flatten()
        # Want to convert it away from a special tensor, hence the float() call.
        _, idx = flat_weight.float().abs().sort()
        # @idx is a @special_tensors._SpecialTensor, but we need to convert it
        # to a normal tensor for indexing to work properly.
        idx = torch.tensor(idx, requires_grad=False)
        f = int(frac_to_zero * weight.numel())
        scale_factors = torch.ones_like(flat_weight, requires_grad=False)
        scale_factors[idx[:f]] = 0
        scale_factors = scale_factors.view_as(weight)

    ret = weight * scale_factors

    if return_scale_factors:
        return ret, scale_factors

    return ret


def apply_in_topk_reg(model, apply_to_bias=False):
    loss = None
    count = 0
    for m in model.modules():
        if isinstance(m, modules.SubspaceIN):
            count += 1
            w, b = m.get_weight()

            if loss is None:
                loss = torch.tensor(0.0, device=w.device)
            loss += w.abs().sum()

            if apply_to_bias:
                loss += b.abs().sum()

    if count == 0:
        raise ValueError("Didn't adjust any Norms")

    return loss


def get_norm_type_string(model):
    bn_count = 0
    in_count = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            bn_count += 1
        elif isinstance(m, nn.InstanceNorm2d):
            in_count += 1

    if bn_count > 0:
        assert in_count == 0, "Got both BN and IN"
        return "StandardBN"
    elif in_count > 0:
        assert bn_count == 0, "Got both BN and IN"
        return "StandardIN"
    else:
        raise ValueError("No norm layers detected.")


def make_fresh_copy_of_pruned_network(model: nn.Module, model_kwargs: Dict):
    norm_type_string = get_norm_type_string(model)
    builder = Builder(conv_type="StandardConv", bn_type=norm_type_string)

    copy = type(model.module)(
        builder=builder, block_builder=builder, **model_kwargs
    )  # type: nn.Module
    # Need to move @copy to GPU before moving to DataParallel.
    if next(model.parameters()).is_cuda:
        copy = copy.cuda()
    copy = nn.DataParallel(copy)

    state_dict = model.state_dict()
    # Drop parameters whose names end in "1"; the fresh standard copy has no
    # counterpart for them, and load_state_dict is strict by default.
    del_me = []
    for k in state_dict:
        if k.endswith("1"):
            del_me.append(k)

    for elem in del_me:
        del state_dict[elem]

    copy.load_state_dict(state_dict)

    # The only modules that need fixing are those with a get_weight() method.
    name_to_copy = {name: module for name, module in copy.named_modules()}

    for name, module in model.named_modules():
        if hasattr(module, "get_weight"):
            print(f"Adjusting weight at module {name}")

            pieces = module.get_weight()

            if isinstance(pieces, torch.Tensor):
                # A bare tensor: only the weight needs to be replaced.
                name_to_copy[name].weight.data = pieces
            elif len(pieces) == 1:
                name_to_copy[name].weight.data = pieces[0]
            else:
                assert len(pieces) == 2, f"Invalid len(pieces)={len(pieces)}"
                name_to_copy[name].weight.data = pieces[0]
                name_to_copy[name].bias.data = pieces[1]

    return copy


def get_input_size(dataset):
    return {
        "cifar10": 32,
        "imagenet": 224,
    }[dataset]


def register_bn_tracking_hook(module, mean_dict, var_dict, name):
    def forward_hook(_, x, __):
        # x is the tuple of inputs to the norm layer; assumes a 4D
        # (N, C, H, W) input, as for BatchNorm2d.
        x = x[0]
        reshaped_x = x.permute(1, 0, 2, 3).reshape(x.shape[1], -1)

        mean_dict[name].append(reshaped_x.mean(dim=1).detach().cpu())
        var_dict[name].append(reshaped_x.var(dim=1).detach().cpu())

    return module.register_forward_hook(forward_hook)


def register_bn_tracking_hooks(model: nn.Module):
    mean_dict = collections.defaultdict(list)
    var_dict = collections.defaultdict(list)

    hooks = []
    for name, module in model.named_modules():
        # NOTE: we omit GroupNorm from this check, because it doesn't track
        # running stats (there's no option to do so in PyTorch).
        if (
            isinstance(module, nn.modules.batchnorm._NormBase)
            and module.track_running_stats
        ):
            hooks.append(
                register_bn_tracking_hook(module, mean_dict, var_dict, name)
            )
    return hooks, mean_dict, var_dict


def unregister_bn_tracking_hooks(hooks: List[torch.utils.hooks.RemovableHandle]):
    for hook in hooks:
        hook.remove()


def get_bn_accuracy_metrics(model: nn.Module, mean_dict: Dict, var_dict: Dict):
    """Determine how accurate the running_mean and running_var of the BatchNorms
    are.

    Note that the test set should be shuffled during training, or these
    statistics won't be valid.

    Arguments:
        model: the network.
        mean_dict: A dictionary that looks like:
            {'layer_name': [batch1_mean, batch2_mean, ...]}
        var_dict: Similar to mean_dict, but with variances.

    Returns:
        {"bn_metrics": {...}}, holding, per layer, the mean and std over
        batches of the mean absolute deviation (MAD) between the batch
        statistics and the running statistics.
    """
    mean_results = collections.defaultdict(list)
    var_results = collections.defaultdict(list)

    name_to_module = {name: module for name, module in model.named_modules()}

    for name in mean_dict.keys():
        module = name_to_module[name]

        running_mean = module.running_mean.detach().cpu()
        assert isinstance(mean_dict[name], list)
        for batch_result in mean_dict[name]:
            num_channels = batch_result.shape[0]
            mean_abs_diff = (
                (batch_result - running_mean[:num_channels]).abs().mean().item()
            )
            mean_results[name].append(mean_abs_diff)

        running_var = module.running_var.detach().cpu()
        assert isinstance(var_dict[name], list)
        for batch_result in var_dict[name]:
            num_channels = batch_result.shape[0]
            mean_abs_diff = (
                (batch_result - running_var[:num_channels]).abs().mean().item()
            )
            var_results[name].append(mean_abs_diff)

    # For each layer, record the mean and std of the average deviations, for
    # both running_mean and running_var.
    ret = {}
    for name, stats in mean_results.items():
        ret[f"{name}_running_mean_MAD_mean"] = torch.tensor(stats).mean().item()
        ret[f"{name}_running_mean_MAD_std"] = torch.tensor(stats).std().item()
    for name, stats in var_results.items():
        ret[f"{name}_running_var_MAD_mean"] = torch.tensor(stats).mean().item()
        ret[f"{name}_running_var_MAD_std"] = torch.tensor(stats).std().item()
    return {"bn_metrics": ret}
--------------------------------------------------------------------------------
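The magnitude-based masking in `apply_topk` (sort by absolute value, zero the smallest `int((1 - topk) * numel)` entries, multiply by the 0/1 mask) and the zero-counting in `get_sparsity_rate` can be illustrated without PyTorch. The sketch below is a pure-Python approximation on a flat list; `topk_mask` and `sparsity_rate` are hypothetical names that are not part of this repository. One assumption to note: Python's `sorted` is stable, whereas `torch.sort` is not stable by default, so ties among equal magnitudes may be broken differently. Also, unlike `get_sparsity_rate`, this sketch looks at a single weight list rather than summing over all layers of a model.

```python
from typing import List, Sequence, Tuple


def topk_mask(
    topk: float, weights: Sequence[float]
) -> Tuple[List[float], List[float]]:
    """Keep the largest-magnitude `topk` fraction of `weights`; zero the rest.

    Hypothetical helper mirroring apply_topk on a flat list. Returns
    (masked_weights, scale_factors), where scale_factors is 1.0 for kept
    entries and 0.0 for pruned ones.
    """
    if not 0.0 <= topk <= 1.0:
        raise ValueError(f"topk must be in [0, 1], got {topk}")
    n = len(weights)
    # Number of entries to prune, truncated toward zero as with int(...).
    num_to_zero = int((1 - topk) * n)
    # Indices in ascending order of magnitude; the first num_to_zero are pruned.
    order = sorted(range(n), key=lambda i: abs(weights[i]))
    scale_factors = [1.0] * n
    for i in order[:num_to_zero]:
        scale_factors[i] = 0.0
    # Element-wise product, like `weight * scale_factors`.
    masked = [w * s for w, s in zip(weights, scale_factors)]
    return masked, scale_factors


def sparsity_rate(weights: Sequence[float]) -> float:
    """Fraction of exactly-zero entries, guarding against an empty input."""
    return sum(1 for w in weights if w == 0) / max(len(weights), 1)
```

For example, `topk_mask(0.5, [0.1, -2.0, 0.5, -0.05, 3.0, -0.7])` keeps `-2.0`, `3.0` and `-0.7`, returns scale factors `[0.0, 1.0, 0.0, 0.0, 1.0, 1.0]`, and `sparsity_rate` of the masked list is `0.5`.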