├── .gitignore
├── LICENSE
├── README.md
├── compression
│   ├── __init__.py
│   ├── huffman_encoding.py
│   ├── pruning.py
│   └── quantization.py
├── configs
│   ├── AlexNet
│   │   ├── alexnet-compression.json
│   │   └── alexnet-pretrain.json
│   ├── LeNet300
│   │   ├── lenet-300-compression.json
│   │   └── lenet-300-pretrain.json
│   └── LeNet5
│       ├── lenet-5-cifar100.json
│       ├── lenet-5-compression.json
│       └── lenet-5-pretrain.json
├── data.py
├── environment.yml
├── logger
│   ├── __init__.py
│   ├── logger.py
│   └── logger_config.json
├── models
│   ├── __init__.py
│   ├── alexnet.py
│   ├── lenet.py
│   └── vgg.py
├── notebooks
│   ├── mnist-lenet300.ipynb
│   └── sensitivity.ipynb
├── parse_config.py
├── sensitivity.py
├── test.py
├── train.py
├── trainer
│   ├── __init__.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   ├── pruning.py
│   │   └── quantization.py
│   ├── lit_model.py
│   ├── metrics.py
│   └── trainer.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 | # project runs output and checkpoints
3 | /runs/
4 | /assets/
5 | notebooks/*.ipynb_checkpoints
6 | 
7 | config.json
8 | 
9 | # editor, os cache directory
10 | .vscode/
11 | /.idea/
12 | __MACOSX/
13 | __pycache__
14 | 
15 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Dario Cioni
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![PyTorch Lightning](https://img.shields.io/badge/PyTorch-Lightning-blueviolet)](#)
2 | 
3 | # Deep Compression
4 | 
5 | [Original Paper](https://arxiv.org/abs/1510.00149), [My slides](https://drive.google.com/file/d/1PgbEiVICWGyU-r3I94mqTNf4wtCbIM79/view?usp=sharing)
6 | 
7 | This repository is an unofficial [Pytorch Lightning](https://lightning.ai/pages/open-source/)
8 | implementation of the paper "**Deep Compression**: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding" by Song Han, Huizi Mao, and William J. Dally, 2015.
9 | It provides an implementation of the three core methods described in the paper:
10 | 
11 | - Pruning
12 | - Quantization
13 | - Huffman Encoding
14 | 
15 | These are the main results on the MNIST and [Imagenette](https://github.com/fastai/imagenette) datasets:
16 | 
17 | | Network                  | Top-1 Error (Ours) | Top-1 Error (Han et al.) | Compression Rate (Ours) | Compression Rate (Han et al.) |
18 | |--------------------------|--------------------|--------------------------|-------------------------|-------------------------------|
19 | | LeNet-300-100 Ref        | 2.0%               | 1.64%                    | -                       | -                             |
20 | | LeNet-300-100 Compressed | 1.8%               | 1.58%                    | **48X**                 | 40X                           |
21 | | LeNet-5 Ref              | 0.83%              | 0.8%                     | -                       | -                             |
22 | | LeNet-5 Compressed       | 0.74%              | 0.74%                    | **46X**                 | 39X                           |
23 | | AlexNet Ref              | 9.11%              | -                        | -                       | -                             |
24 | | AlexNet Compressed       | 3.87%              | -                        | **41X**                 | 35X                           |
25 | 
26 | This project was implemented for the **Deep Learning** exam at the University of Florence.
27 | 
28 | ## Table of contents
29 | 
30 | 
31 | * [Deep Compression](#deep-compression)
32 |   * [Table of contents](#table-of-contents)
33 |   * [Requirements](#requirements)
34 |   * [Project Structure](#project-structure)
35 |   * [Usage](#usage)
36 |     * [Models](#models)
37 |     * [Configuration file](#configuration-file)
38 |     * [Training](#training)
39 |     * [Testing](#testing)
40 |     * [Sensitivity analysis](#sensitivity-analysis)
41 |   * [Pruning](#pruning)
42 |   * [Quantization](#quantization)
43 |   * [Huffman encoding](#huffman-encoding)
44 |   * [Results](#results)
45 |     * [MNIST](#mnist)
46 |     * [Imagenette](#imagenette)
47 |   * [References](#references)
48 |   * [Acknowledgments](#acknowledgments)
49 | 
50 | 
51 | ## Requirements
52 | - pytorch
53 | - pytorch-lightning
54 | - torchmetrics
55 | - torchvision
56 | - ipykernel
57 | - jupyter
58 | - matplotlib
59 | - numpy
60 | - scipy
61 | - scikit-learn
62 | - tqdm
63 | - tensorboard
64 | 
65 | ## Project Structure
66 | ```
67 | model-compression/
68 | │
69 | ├── train.py - main script to start training
70 | ├── test.py - evaluation of a trained model
71 | │
72 | ├── compression/ - directory containing all the Deep Compression logic
73 | │   ├── pruning.py - implements ThresholdPruning and utilities for sparsity calculation
74 | │   ├── quantization.py - implements all the weight-sharing logic and utilities for compression calculation
75 | │   └── huffman_encoding.py - implements Huffman encoding
76 | │
77 | ├── configs/ - directory of saved model configurations for various datasets and models
78 | ├── config.json - a configuration file for your current experiment
79 | ├── parse_config.py - handles the config file and CLI options
80 | │
81 | ├── data.py - anything about data loading goes here
82 | │   ├── BaseDataLoader - abstract base class for data loaders
83 | │   ├── MnistDataLoader - DataLoader for MNIST
84 | │   ├── CIFAR100DataLoader - DataLoader for CIFAR-100
85 | │   └── ImagenetDataLoader - DataLoader for ImageNet-like datasets
86 | │
87 | ├── data/ - directory for storing input data
88 | │
89 | ├── models/ - directory of developed models
90 | │   ├── lenet.py - implementation of LeNet-300-100 and LeNet-5
91 | │   ├── alexnet.py - implementation of AlexNet, following the Caffe implementation
92 | │   │                https://github.com/songhan/Deep-Compression-AlexNet/blob/master/bvlc_alexnet_deploy.prototxt
93 | │   └── vgg.py - implementation of VGG-16
94 | │
95 | ├── notebooks/ - directory containing example notebooks
96 | │   ├── mnist-lenet300.ipynb - Deep Compression pipeline example on MNIST with the LeNet-300-100 FC model
97 | │   ├── mnist-lenet5.ipynb - Deep Compression pipeline example on MNIST with the LeNet-5 model
98 | │   └── ...
99 | │
100 | ├── runs/ - trained models and logs are saved here
101 | │
102 | ├── trainer/ - module containing code for training and evaluating models
103 | │   ├── callbacks/ - module containing custom callbacks for the Lightning Trainer
104 | │   │   ├── IterativePruning - custom callback extending ModelPruning, allowing fine-grained control over the pruning process
105 | │   │   └── Quantization - custom callback defining the quantization process; also handles the Huffman encoding calculation
106 | │   │
107 | │   ├── lit_model.py - Lightning wrapper for model training
108 | │   ├── metrics.py - code to define metrics
109 | │   └── trainer.py - code to configure a Lightning Trainer from a JSON configuration
110 | │
111 | ├── logger/ - module for additional console logging (Tensorboard is handled by Lightning)
112 | │   ├── logger.py
113 | │   └── logger_config.json
114 | │
115 | └── utils.py - utility functions
116 | ```
117 | 
118 | ## Usage
119 | 
120 | ### Data
121 | The data is stored by default in the `data/` directory. The data loading logic is implemented in `data.py` and can easily be extended to new datasets.
122 | 
123 | The MNIST dataset is available in the `torchvision` package and is downloaded automatically by Pytorch.
124 | The ImageNette dataset is available at [Imagenette](https://github.com/fastai/imagenette) and needs to be downloaded and extracted into the `data/` directory.
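
As a minimal sketch of how these loaders can be used outside `train.py` (batch size, split fraction and worker count are illustrative values):

```python
# Build the MNIST train loader and carve out a validation split.
from data import MnistDataLoader

train_loader = MnistDataLoader(batch_size=128, data_dir="data/", shuffle=True,
                               validation_split=0.1, num_workers=6)
valid_loader = train_loader.split_validation()

for images, labels in train_loader:
    pass  # training step goes here
```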
125 | 
126 | 
127 | ### Models
128 | The [models](models) folder contains the implementation of the following models:
129 | 
130 | - LeNet-300-100 from the original [LeNet paper](http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf)
131 | - LeNet-5, in a modified, larger version which follows the one in the [Deep Compression paper](https://arxiv.org/abs/1510.00149)
132 | - AlexNet, which follows the Caffe implementation available in the author's [repository](https://github.com/songhan/Deep-Compression-AlexNet)
133 | - VGG-16
134 | 
135 | ### Configuration file
136 | All the experiments are handled by a configuration file in `.json` format:
137 | 
138 | ````json
139 | {
140 |   "name": "Mnist_LeNet300",
141 |   "n_gpu": 1,
142 |   "arch": {
143 |     "type": "LeNet300",
144 |     "args": {
145 |       "num_classes": 10,
146 |       "grayscale": true,
147 |       "dropout_rate": 0
148 |     }
149 |   },
150 |   "data_loader": {
151 |     "type": "MnistDataLoader",
152 |     "args": {
153 |       "data_dir": "data/",
154 |       "batch_size": 128,
155 |       "shuffle": true,
156 |       "validation_split": 0.1,
157 |       "num_workers": 6,
158 |       "resize": false
159 |     }
160 |   },
161 |   "optimizer": {
162 |     "type": "SGD",
163 |     "args": {
164 |       "lr": 1e-2,
165 |       "momentum": 0.9,
166 |       "weight_decay": 1e-3,
167 |       "nesterov": true
168 |     }
169 |   },
170 |   "loss": "cross_entropy",
171 |   "metrics": [
172 |     "accuracy",
173 |     "topk_accuracy"
174 |   ],
175 |   "trainer": {
176 |     "min_epochs": 10,
177 |     "max_epochs": 20,
178 |     "save_dir": "runs/",
179 |     "verbosity": 1,
180 |     "monitor": "max val_accuracy",
181 |     "loggers": ["TensorBoard"],
182 |     "callbacks": {
183 |       "ModelCheckpoint": {
184 |         "save_top_k": 1,
185 |         "monitor": "val_accuracy",
186 |         "every_n_epochs": 5,
187 |         "mode": "max"
188 |       },
189 |       "IterativePruning": {
190 |         "pruning_schedule": {
191 |           "target_sparsity": 0.9,
192 |           "start_epoch": 0,
193 |           "prune_every": 2
194 |         },
195 |         "pruning_fn": "l1_threshold",
196 |         "parameter_names": ["weight"],
197 |         "amount": 0.6,
198 |         "use_global_unstructured": true,
199 |         "make_pruning_permanent": false,
200 |         "verbose": 2
201 |       },
202 |       "Quantization": {
203 |         "epoch": 10,
204 |         "quantization_fn": "linear_quantization",
205 |         "parameter_names": ["weight"],
206 |         "filter_layers": ["Linear"],
207 |         "bits": 6,
208 |         "verbose": 2,
209 |         "huffman_encode": true
210 |       }
211 |     }
212 |   }
213 | }
214 | ````
215 | 
216 | ### Training
217 | To train a new model from scratch, use the `-c` or `--config` option followed by the path to a JSON configuration file:
218 | ```sh
219 | $ python train.py -c config.json
220 | ```
221 | 
222 | To resume training, use the `-r` option followed by the path to a Pytorch Lightning checkpoint.
223 | 
224 | The JSON configuration file of the trained model should be placed in the same directory as the checkpoint.
225 | This is useful if you want to perform compression step by step, changing the callbacks every time.
226 | ```sh
227 | $ python train.py -r path-to-checkpoint/checkpoint.ckpt
228 | ```
229 | 
230 | ### Testing
231 | ```sh
232 | $ python test.py -r path-to-checkpoint/checkpoint.ckpt
233 | ```
234 | 
235 | ### Sensitivity analysis
236 | 
237 | ```sh
238 | $ python sensitivity.py -r path-to-checkpoint/checkpoint.ckpt
239 | ```
240 | 
241 | ## Pruning
242 | Pruning is implemented as a callback, called during training by Pytorch Lightning's [Trainer](https://lightning.ai/docs/pytorch/latest/common/trainer.html).
244 | The `IterativePruning` callback extends the [ModelPruning](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelPruning.html#lightning.pytorch.callbacks.ModelPruning) callback with further control over the pruning schedule.
245 | It allows setting a target sparsity level, pruning each layer with a different amount/threshold, and performing iterative pruning.
246 | 
247 | - `pruning_fn`: Function from the torch.nn.utils.prune module or a PyTorch BasePruningMethod subclass. Can also be a string, e.g. "l1_unstructured".
248 | - `parameter_names`: List of parameter names to be pruned in the nn.Module. Can either be "weight" or "bias".
249 | - `parameters_to_prune`: List of tuples (nn.Module, "parameter_name_string"). If unspecified, retrieves all modules in the model having `parameter_names`.
250 | - `use_global_unstructured`: Whether to apply pruning globally on the model. If `parameters_to_prune` is provided, global unstructured pruning is restricted to them.
251 | - `amount`: Quantity of parameters to prune. Can either be:
252 |     - an int specifying the exact number of parameters to prune
253 |     - a float specifying the fraction of parameters to prune
254 |     - a list of ints or floats specifying the amount to prune in each module; the length must match the number of pruned modules. This is allowed only if `use_global_unstructured=False`.
255 | - `filter_layers`: List of strings; restricts pruning to layers of a specific class ("Linear", "Conv2d", or both).
256 | 
257 | The `pruning_schedule` is provided as a dictionary in the trainer's JSON configuration and accepts the following arguments:
258 | - `epochs`: list specifying the exact epochs in which pruning is performed. If specified, overrides `start_epoch` and `prune_every`.
259 | - `start_epoch`: first epoch in which pruning is performed.
260 | - `prune_every`: performs pruning every `prune_every` epochs. Default: 1.
261 | - `target_sparsity`: prevents pruning from being applied once the model's sparsity exceeds `target_sparsity`.
262 | 
263 | The performance of pruned models was evaluated on different datasets in different settings:
264 | - One-shot pruning with retraining: prune a trained model, then retrain the weights to compensate for the accuracy loss incurred by pruning
265 | - Iterative pruning: iteratively prune and retrain the model multiple times
266 | 
267 | | Network                                   | Top-1 Error | Top-5 Error | Parameters | Compression Rate |
268 | |-------------------------------------------|-------------|-------------|------------|------------------|
269 | | LeNet-300-100 Ref                         | 2.0%        | -           | 267K       | -                |
270 | | LeNet-300-100 one-shot pruning w/ retrain | 1.83%       | -           | **22K**    | **12X**          |
271 | | LeNet-5 Ref                               | 0.83%       | -           | 429K       | -                |
272 | | LeNet-5 one-shot pruning w/ retrain       | 0.69%       | -           | **36K**    | **12X**          |
273 | | AlexNet Ref                               | 9.11%       | -           | 57M        | -                |
274 | | AlexNet one-shot pruning w/ retrain       | 2.627%      | -           | 6M         | **10X**          |
275 | | VGG16 Ref                                 | -           | -           | 61M        | -                |
276 | | VGG16 Pruned                              |             | -           |            |                  |
277 | 
278 | ## Quantization
279 | Quantization is implemented with `Quantizer`, a custom callback called by Pytorch Lightning's [Trainer](https://lightning.ai/docs/pytorch/latest/common/trainer.html).
280 | 
281 | The `Quantizer` callback implements the abstract [Callback](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback) class
282 | and runs vector quantization on Linear and Conv2d modules.
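
The quantization functions can also be applied directly, outside the Lightning callback. A minimal sketch (the model and bit-width are illustrative choices; running k-means on large unpruned layers can be slow):

```python
# Quantize every Linear layer of a LeNet-300-100 to a 2**6-entry codebook,
# then log the per-layer and total compression.
import torch.nn as nn
from models import LeNet300
from compression.quantization import linear_quantization, compression_stats

model = LeNet300(num_classes=10, grayscale=True, dropout_rate=0)
for module in model.modules():
    if isinstance(module, nn.Linear):
        linear_quantization(module, "weight", bits=6)
compression_stats(model, name="weight", idx_bits=5)
```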
283 | 
284 | Vector quantization is implemented in the `BaseQuantizationMethod` class, a new module inspired by the existing pruning pipeline in torch.nn.utils.prune.
285 | This module clusters a parameter's weights and stores them as a tensor of cluster centers plus an index matrix.
286 | 
287 | The cluster centroids can be initialized in three different ways:
288 | 
289 | - **Linear**: chooses linearly spaced values between the [ _min_, _max_ ] of the original weights
290 | - **Density-based**: chooses the centroids based on the density distribution of the weights. It linearly spaces the y-axis of the weights' CDF, finds the horizontal intersection with the CDF, and takes the corresponding value on the x-axis as the centroid.
291 | - **Random/Forgy**: randomly chooses _k_ weights from the weight matrix.
292 | 
293 | The callback applies the quantization function to each layer and accepts the following parameters:
294 | - `epoch`: an int indicating the epoch at which quantization is performed.
295 | - `quantization_fn`: Function from the [compression.quantization](compression/quantization.py) module, passed as a string. Available functions are "density_quantization", "linear_quantization" and "forgy_quantization".
296 | - `parameter_names`: List of parameter names to be quantized in the nn.Module. Can either be "weight" or "bias".
297 | - `filter_layers`: List of strings; restricts quantization to layers of a specific class ("Linear", "Conv2d", or both).
298 | - `bits`: an int indicating the number of bits used for quantization. The codebook will contain 2**bits shared weights.
299 | 
300 | | #CONV bits / #FC bits | Top-1 Error | Top-1 Error increase | Compression rate |
301 | |-----------------------|-------------|----------------------|------------------|
302 | | 32 bits / 32 bits     | 2.627%      | -                    | -                |
303 | | 8 bits / 5 bits       | 3.87%       | 1.2%                 | 41X              |
304 | | 8 bits / 4 bits       | 4.066%      | 1.4%                 | 44X              |
305 | | 4 bits / 2 bits       | 3.45%       | 0.8%                 | 61X              |
306 | 
307 | ## Huffman encoding
308 | Huffman Encoding is implemented in the [compression.huffman_encoding](compression/huffman_encoding.py) module.
309 | 
310 | This module builds the Huffman tree for the passed vector and computes the memory saving obtained by the encoding, together with the average number of bits used to encode each element of the vector.
311 | The encoding itself is not applied to the vector.
312 | 
313 | Huffman Encoding is enabled by setting the parameter `huffman_encode` to true in the `Quantization` callback.
314 | 
315 | ## Results
316 | 
317 | Here is a summary of the compression achieved by each model after pruning, quantization and Huffman Encoding.
318 | The experiments are available on Tensorboard.dev.
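
The per-layer "Weight bits (P+Q+H)" and "Index bits (P+Q+H)" columns below come from `HuffmanEncode`. A minimal sketch of how it can be invoked, on synthetic codebook indices (the skewed distribution is made up so the encoding has something to exploit):

```python
# Estimate the gain of Huffman-coding 5-bit codebook indices.
import numpy as np
from compression.huffman_encoding import HuffmanEncode

rng = np.random.default_rng(0)
freq = np.geomspace(1.0, 0.01, num=32)  # synthetic, skewed symbol frequencies
indices = rng.choice(32, size=10_000, p=freq / freq.sum())

total_bits, avg_bits = HuffmanEncode.encode(indices, bits=5)
print(f"{total_bits:.0f} bits after encoding, {avg_bits:.2f} bits/symbol on average")
```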
319 | 320 | ### MNIST 321 | 322 | - [LeNet-300-100](https://tensorboard.dev/experiment/Z7XtG6YXRdOlBaX9Ramt3g/) 323 | 324 | | Layer | # Weights | Weights % (P) | Weight bits (P+Q) | Weight bits (P+Q+H) | Index bits (P+Q) | Index bits (P+Q+H) | Compress rate (P+Q) | Compress rate (P+Q+H) | 325 | |-----------|-----------|---------------|-------------------|---------------------|------------------|--------------------|---------------------|-----------------------| 326 | | fc1 | 235K | 8% | 6 | 5.1 | 5 | 2.5 | 2.53% | 1.92% | 327 | | fc2 | 30K | 9% | 6 | 5.4 | 5 | 3.6 | 3.03% | 2.71% | 328 | | fc3 | 1K | 26% | 6 | 5.8 | 5 | 3.1 | 14.52% | 13.59% | 329 | | **Total** | 266K | 8% (12X) | 6 | | 5 | | 2.63% (38.0X) | 2.05% (48.7X) | 330 | 331 | - [LeNet-5](https://tensorboard.dev/experiment/RMyp5qxRRSyZP0zn4Oe7wA/) 332 | 333 | | Layer | # Weights | Weights % (P) | Weight bits (P+Q) | Weight bits (P+Q+H) | Index bits (P+Q) | Index bits (P+Q+H) | Compress rate (P+Q) | Compress rate (P+Q+H) | 334 | |-----------|-----------|---------------|-------------------|---------------------|------------------|--------------------|---------------------|-----------------------| 335 | | conv1 | 0.5K | 82% | 8 | 7.9 | 5 | 1.2 | 92.47% | 74.54% | 336 | | conv2 | 25K | 19% | 8 | 7.5 | 5 | 3.0 | 21.10% | 7.09% | 337 | | fc1 | 400K | 7% | 5 | 4.2 | 5 | 3.6 | 1.97% | 1.66% | 338 | | fc2 | 3K | 73% | 5 | 4.4 | 5 | 1.4 | 21.58% | 14.08% | 339 | | **Total** | 429K | 8% (12X) | | | 5 | | 3.34% (39X) | 2.15% (46X) | 340 | 341 | 342 | ### Imagenette 343 | 344 | - [AlexNet](https://tensorboard.dev/experiment/2xJrx1AYRAK4WiobyudX2Q/) 345 | 346 | | Layer | # Weights | Weights % (P) | Weight bits (P+Q) | Weight bits (P+Q+H) | Index bits (P+Q) | Index bits (P+Q+H) | Compress rate (P+Q) | Compress rate (P+Q+H) | 347 | |-----------|-----------|---------------|-------------------|---------------------|------------------|--------------------|---------------------|-----------------------| 348 | | conv1 | 35K | 84% | 8 | 7.2 | 5 | 1.2 | 32.6% | 23.61% | 349 | | conv2 | 307K | 38% | 8 | 6.8 | 5 | 2.6 | 14.33% | 11.20% | 350 | | conv3 | 885K | 35% | 8 | 6.5 | 5 | 2.7 | 13.16% | 10.13% | 351 | | conv4 | 663K | 37% | 8 | 6.6 | 5 | 2.7 | 13.9% | 10.96% | 352 | | conv5 | 442K | 37% | 8 | 6.7 | 5 | 2.7 | 13.92% | 11.06% | 353 | | fc1 | 38M | 9% | 5 | 4.0 | 5 | 4.5 | 2.53% | 2.07% | 354 | | fc2 | 17M | 9% | 5 | 4.1 | 5 | 4.6 | 2.53% | 1.99% | 355 | | fc3 | 4M | 25% | 5 | 4.4 | 5 | 3.3 | 7.11% | 5.95% | 356 | | **Total** | 58M | 11% (10X) | 5.4 | | 5 | | 3.03% (32X) | 2.43% (41X) | 357 | 358 | ## TODOs 359 | - [ ] Switch to PyTorch Lightning console commands 360 | - [ ] Switch to YAML configuration 361 | - [ ] Better log integration w/ Tensorboard & WandB 362 | - [ ] Compressed checkpoints saving 363 | 364 | ## References 365 | [[1]](https://arxiv.org/pdf/1510.00149v5.pdf) Han, Song, Huizi Mao, and William J. Dally. "Deep compression: Compressing deep neural networks with pruning, trained quantization and huffman coding." arXiv preprint arXiv:1510.00149 (2015) 366 | 367 | [[2]](https://arxiv.org/pdf/1506.02626v3.pdf) Han, Song, et al. "Learning both weights and connections for efficient neural network." 
Advances in neural information processing systems 28 (2015) 368 | 369 | ## Acknowledgments 370 | - [Pytorch](https://pytorch.org/docs/stable/nn.html#module-torch.nn.utils) for pruning library 371 | - [pytorch-template](https://github.com/victoresque/pytorch-template) for project structure and experiment logging 372 | -------------------------------------------------------------------------------- /compression/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.nn.utils.prune import l1_unstructured, ln_structured, global_unstructured, identity 2 | from compression.pruning import l1_threshold, get_pruned, ThresholdPruning 3 | from compression.quantization import is_quantized, density_quantization, forgy_quantization, linear_quantization 4 | -------------------------------------------------------------------------------- /compression/huffman_encoding.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | 5 | log = logging.getLogger(__name__) 6 | 7 | 8 | # Node of a Huffman Tree 9 | class Node: 10 | def __init__(self, probability, index, left=None, right=None): 11 | # probability of the symbol 12 | self.probability = probability 13 | # the symbol 14 | self.index = index 15 | # the left node 16 | self.left = left 17 | # the right node 18 | self.right = right 19 | # the tree direction (0 or 1) 20 | self.code = '' 21 | 22 | 23 | class HuffmanEncode: 24 | def __init__(self, bits=5): 25 | self.symbols, self.codes = {}, {} 26 | self.initial_bits = bits 27 | 28 | """ Calculates frequency of every index in data""" 29 | 30 | def frequency(self, data): 31 | indices, frequencies = np.unique(data, return_counts=True) 32 | return indices, frequencies 33 | 34 | """ Encodes the symbols by visiting the Huffman Tree """ 35 | 36 | def codify(self, node, value=''): 37 | # a huffman code for current node 38 | newValue = value + str(node.code) 39 | 40 | if node.left: 41 | self.codify(node.left, newValue) 42 | if node.right: 43 | self.codify(node.right, newValue) 44 | 45 | if not node.left and not node.right: 46 | self.codes[node.index] = newValue 47 | return self.codes 48 | 49 | def get_encoded(self, data, coding): 50 | out = [coding[e] for e in data] 51 | return ''.join([str(item) for item in out]) 52 | 53 | """ A supporting function in order to calculate the space difference between compressed and non compressed data""" 54 | 55 | def get_gain(self, data, coding): 56 | # total bit space to store the data before compression 57 | n_data = len(data) 58 | before = n_data * self.initial_bits 59 | after = 0 60 | symbols = coding.keys() 61 | for symbol in symbols: 62 | count = np.count_nonzero(data == symbol) 63 | # calculating how many bit is required for that symbol in total 64 | after += count * len(coding[symbol]) 65 | #log.debug(f" Symbol: {symbol} | count: {count:.0f} | coding length: {len(coding[symbol])}") 66 | log.debug(" Space usage before huffman encoding for {:.0f} values (in bits): {:.0f}".format(n_data, before)) 67 | log.debug(" Space usage after huffman encoding for {:.0f} values (in bits): {:.0f}".format(n_data, after)) 68 | log.info(" Average bits: {:.1f}".format(after / n_data)) 69 | return after, after / n_data 70 | 71 | @classmethod 72 | def encode(cls, data, bits=5): 73 | huffman = cls(bits=bits) 74 | symbols, frequencies = huffman.frequency(data) 75 | # print("symbols: ", symbols) 76 | # print("frequencies: ", the_probabilities) 77 | 78 | nodes = [] 79 | 80 | # 
converting symbols and their frequencies into Huffman tree nodes
81 |         for s, f in zip(symbols, frequencies):
82 |             nodes.append(Node(f, s))
83 | 
84 |         while len(nodes) > 1:
85 |             # sort all the nodes in ascending order based on their frequency
86 |             nodes = sorted(nodes, key=lambda x: x.probability)
87 | 
88 |             # pick the two smallest nodes
89 |             right = nodes[0]
90 |             left = nodes[1]
91 | 
92 |             left.code = 0
93 |             right.code = 1
94 | 
95 |             # combine the two smallest nodes to create a new internal node
96 |             new = Node(left.probability + right.probability, left.index + right.index, left, right)
97 | 
98 |             nodes.remove(left)
99 |             nodes.remove(right)
100 |             nodes.append(new)
101 | 
102 |         huffmanEncoding = huffman.codify(nodes[0])
103 |         tot_size, avg_bits = huffman.get_gain(data, huffmanEncoding)
104 |         return tot_size, avg_bits
105 | 
106 |     def decode(self, encoded, tree):
107 |         treeHead = tree
108 |         # start decoding from the root of the tree
109 |         huffmanTree = treeHead
110 |         decoded = []
111 |         for x in encoded:
112 |             if x == '1':
113 |                 huffmanTree = huffmanTree.right
114 |             elif x == '0':
115 |                 huffmanTree = huffmanTree.left
116 |             # leaves have no children: emit the symbol and restart from the root
117 |             try:
118 |                 if huffmanTree.left.index is None and huffmanTree.right.index is None:
119 |                     pass
120 |             except AttributeError:
121 |                 decoded.append(huffmanTree.index)
122 |                 huffmanTree = treeHead
123 | 
124 |         string = ''.join([str(item) for item in decoded])
125 |         return string
126 | 
--------------------------------------------------------------------------------
/compression/pruning.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Tuple
3 | 
4 | import torch
5 | from torch import nn
6 | from torch.nn.utils.prune import BasePruningMethod
7 | 
8 | log = logging.getLogger(__name__)
9 | 
10 | 
11 | # custom pruning method: prunes all weights whose magnitude falls below
12 | # `amount` standard deviations of the tensor (the paper's quality parameter)
13 | class ThresholdPruning(BasePruningMethod):
14 |     PRUNING_TYPE = "unstructured"
15 | 
16 |     def __init__(self, amount):
17 |         self.s = amount
18 | 
19 |     def compute_mask(self, tensor, default_mask):
20 |         threshold = torch.std(tensor).item() * self.s
21 |         # combine with the default mask so previously pruned weights stay pruned
22 |         return default_mask * (torch.abs(tensor) > threshold)
23 | 
24 | 
25 | # TODO: alternatively prune linear and conv2d
26 | # TODO: iteratively prune and train
27 | 
28 | def l1_threshold(module, name, amount):
29 |     ThresholdPruning.apply(module, name, amount=amount)
30 |     return module
31 | 
32 | 
33 | def get_pruned(module: nn.Module, name: str) -> Tuple[int, int]:
34 |     attr = f"{name}_mask"
35 |     if not hasattr(module, attr):
36 |         return 0, 1
37 |     mask = getattr(module, attr)
38 |     return (mask == 0).sum().item(), mask.numel()
39 | 
40 | 
41 | def sparsity_stats(model, name="weight"):
42 |     sparsity_dict = {n: get_pruned(m, name) for n, m in model.named_modules() if getattr(m, name, None) is not None}
43 |     log.info(f"Sparsity stats of `{model.__class__.__name__}` - `{name}`:")
44 |     for layer, (z, p) in sparsity_dict.items():
45 |         log.info(f"  Layer {layer}: retained {p - z}/{p} ({(p - z) / p:.2%})")
46 |     zeros, params = zip(*sparsity_dict.values())
47 |     total_params = sum(params)
48 |     total_zeros = sum(zeros)
49 |     total_retained = total_params - total_zeros
50 |     log.info(
51 |         "Total:"
52 |         f" Pruned: {total_zeros}/{total_params} ({total_zeros / total_params:.2%})"
53 |         f" Retained: {total_retained}/{total_params} ({total_retained / total_params:.2%})"
54 |         f" Compression: {total_params / total_retained:.1f} X"
55 |     )
56 |     return total_retained, total_zeros
57 | 
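A minimal usage sketch of the utilities above, outside the `IterativePruning` callback (the model and the quality parameter are illustrative choices):

```python
# Threshold-prune every Linear layer of a LeNet-300-100, then report sparsity.
import torch.nn as nn
from models import LeNet300
from compression.pruning import l1_threshold, sparsity_stats

model = LeNet300(num_classes=10, grayscale=True, dropout_rate=0)
for module in model.modules():
    if isinstance(module, nn.Linear):
        l1_threshold(module, "weight", amount=0.7)

retained, pruned = sparsity_stats(model)
```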
-------------------------------------------------------------------------------- /compression/quantization.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | from abc import ABC, abstractmethod 5 | from typing import Tuple 6 | 7 | from compression.huffman_encoding import HuffmanEncode 8 | 9 | # suppress Kmeans warning of memory leak in Windows 10 | os.environ['OMP_NUM_THREADS'] = "1" 11 | 12 | import torch.nn as nn 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.nn.utils.prune as prune 16 | 17 | import numpy as np 18 | from scipy.sparse import csr_matrix, csr_array 19 | from sklearn.cluster import KMeans 20 | 21 | log = logging.getLogger(__name__) 22 | 23 | 24 | # Quantization base class inspired on torch.nn.utils.BasePruningMethod 25 | class BaseQuantizationMethod(ABC): 26 | _tensor_name: str 27 | _shape: Tuple 28 | 29 | def __init__(self): 30 | pass 31 | 32 | def __call__(self, module, inputs): 33 | r"""Looks up the weights (stored in ``module[name + '_indices']``) 34 | from indices (stored in ``module[name + '_centers']``) 35 | and stores the result into ``module[name]`` by using 36 | :meth:`lookup_weights`. 37 | 38 | Args: 39 | module (nn.Module): module containing the tensor to prune 40 | inputs: not used. 41 | """ 42 | setattr(module, self._tensor_name, self.lookup_weights(module)) 43 | 44 | def lookup_weights(self, module): 45 | assert self._tensor_name is not None, "Module {} has to be quantized".format( 46 | module 47 | ) # this gets set in apply() 48 | indices = getattr(module, self._tensor_name + '_indices') 49 | centers = getattr(module, self._tensor_name + '_centers') 50 | weights = F.embedding(indices, centers).squeeze() 51 | ## debugging 52 | # weights.register_hook(print) 53 | if prune.is_pruned(module): 54 | mask = getattr(module, self._tensor_name + '_mask') 55 | mat = mask.detach().flatten() 56 | mat[torch.argwhere(mat)] = weights.view(-1, 1) 57 | else: 58 | mat = weights 59 | return mat.view(self._shape) 60 | 61 | @abstractmethod 62 | def initialize_clusters(self, mat, n_points): 63 | pass 64 | 65 | @classmethod 66 | def apply(cls, module, name, bits, *args, **kwargs): 67 | param = getattr(module, name).detach() 68 | # get device on which the parameter is, then move to cpu 69 | device = param.device 70 | shape = param.shape 71 | # flatten weights to accommodate conv and fc layers 72 | mat = param.cpu().view(-1, 1) 73 | 74 | # assume it is a sparse matrix, avoid to encode zeros since they are handled by pruning reparameterization 75 | mat = csr_matrix(mat) 76 | mat = mat.data 77 | if mat.shape[0] < 2 ** bits: 78 | bits = int(np.log2(mat.shape[0])) 79 | log.warning("Number of elements in weight matrix ({}) is less than number of clusters ({:d}). \ 80 | using {:d} bits for quantization." 
81 | .format(mat.shape[0], 2 ** bits, bits)) 82 | space = cls(*args, **kwargs).initialize_clusters(mat, 2 ** bits) 83 | 84 | # could do more than one initialization for better results 85 | kmeans = KMeans(n_clusters=len(space), init=space.reshape(-1, 1), n_init=1, 86 | algorithm="lloyd") 87 | kmeans.fit(mat.reshape(-1, 1)) 88 | 89 | method = cls(*args, **kwargs) 90 | # Have the quantization method remember what tensor it's been applied to and weights shape 91 | method._tensor_name = name 92 | method._shape = shape 93 | 94 | centers, indices = kmeans.cluster_centers_, kmeans.labels_ 95 | centers = torch.nn.Parameter(torch.from_numpy(centers).float().to(device)) 96 | indices = torch.from_numpy(indices).to(device) 97 | # If no reparameterization was done before (pruning), delete parameter 98 | if name in module._parameters: 99 | del module._parameters[name] 100 | # reparametrize by saving centroids and indices to `module[name + '_centers']` 101 | # and `module[name + '_indices']`... 102 | module.register_parameter(name + "_centers", centers) 103 | module.register_buffer(name + "_indices", indices) 104 | # ... and the new quantized tensor to `module[name]` 105 | setattr(module, name, method.lookup_weights(module)) 106 | # associate the quantization method to the module via a hook to 107 | # compute the function before every forward() (compile by run) 108 | module.register_forward_pre_hook(method) 109 | # print("Compression rate for layer %s: %.1f" % compression_rate(module,name,bits)) 110 | 111 | def remove(self, module): 112 | r"""Removes the quantization reparameterization from a module. The pruned 113 | parameter named ``name`` remains permanently quantized, and the parameter 114 | named and ``name+'_centers'`` is removed from the parameter list. Similarly, 115 | the buffer named ``name+'_indices'`` is removed from the buffers. 
116 | """ 117 | # before removing quantization from a tensor, it has to have been applied 118 | assert ( 119 | self._tensor_name is not None 120 | ), "Module {} has to be quantized\ 121 | before quantization can be removed".format( 122 | module 123 | ) # this gets set in apply() 124 | 125 | # to update module[name] to latest trained weights 126 | weight = self.lookup_weights(module) # masked weights 127 | 128 | # delete and reset 129 | if hasattr(module, self._tensor_name): 130 | delattr(module, self._tensor_name) 131 | del module._parameters[self._tensor_name + "_centers"] 132 | del module._buffers[self._tensor_name + "_indices"] 133 | module.register_parameter(self._tensor_name, weight.data) 134 | 135 | 136 | class LinearQuantizationMethod(BaseQuantizationMethod): 137 | def initialize_clusters(self, mat, n_points): 138 | min_ = mat.min() 139 | max_ = mat.max() 140 | space = np.linspace(min_, max_, num=n_points) 141 | return space 142 | 143 | @classmethod 144 | def apply(cls, module, name, bits, *args, **kwargs): 145 | return super(LinearQuantizationMethod, cls).apply(module, name, bits) 146 | 147 | 148 | class ForgyQuantizationMethod(BaseQuantizationMethod): 149 | def initialize_clusters(self, mat, n_points): 150 | samples = np.random.choice(mat, size=n_points, replace=False) 151 | return samples 152 | 153 | @classmethod 154 | def apply(cls, module, name, bits, *args, **kwargs): 155 | return super(ForgyQuantizationMethod, cls).apply(module, name, bits) 156 | 157 | 158 | class DensityQuantizationMethod(BaseQuantizationMethod): 159 | def initialize_clusters(self, mat, n_points): 160 | x, cdf_counts = np.unique(mat, return_counts=True) 161 | y = np.cumsum(cdf_counts) / np.sum(cdf_counts) 162 | 163 | eps = 1e-2 164 | 165 | space_y = np.linspace(y.min() + eps, y.max() - eps, n_points) 166 | 167 | idxs = [] 168 | # TODO find numpy operator to eliminate for 169 | for i in space_y: 170 | idx = np.argwhere(np.diff(np.sign(y - i)))[0] 171 | idxs.append(idx) 172 | idxs = np.stack(idxs) 173 | return x[idxs] 174 | 175 | @classmethod 176 | def apply(cls, module, name, bits, *args, **kwargs): 177 | return super(DensityQuantizationMethod, cls).apply(module, name, bits) 178 | 179 | 180 | def linear_quantization(module, name, bits): 181 | LinearQuantizationMethod.apply(module, name, bits) 182 | return module 183 | 184 | 185 | def forgy_quantization(module, name, bits): 186 | ForgyQuantizationMethod.apply(module, name, bits) 187 | return module 188 | 189 | 190 | def density_quantization(module, name, bits): 191 | DensityQuantizationMethod.apply(module, name, bits) 192 | return module 193 | 194 | 195 | def is_quantized(module): 196 | for _, submodule in module.named_modules(): 197 | for _, hook in submodule._forward_pre_hooks.items(): 198 | if isinstance(hook, BaseQuantizationMethod): 199 | return True 200 | return False 201 | 202 | 203 | def get_compression(module, name, idx_bits, huffman_encoding=False): 204 | # bits encoding weights 205 | float32_bits = 32 206 | 207 | all_weights = getattr(module, name).numel() 208 | n_weights = all_weights 209 | p_idx, q_idx, idx_size = 0, 0, 0 210 | 211 | if prune.is_pruned(module): 212 | attr = f"{name}_mask" 213 | mask = csr_array(getattr(module, attr).cpu().view(-1)) 214 | n_weights = mask.getnnz() 215 | if huffman_encoding: 216 | # use index difference of csr matrix 217 | idx_diff = np.diff(mask.indices, prepend=mask.indices[0].astype(np.int8)) 218 | # store overhead of adding placeholder zeros, then consider only indices below 2**idx_bits 219 | overhead = 
sum(map(lambda x: x // 2 ** idx_bits, idx_diff[idx_diff > 2 ** idx_bits])) 220 | idx_diff = idx_diff[idx_diff < 2 ** idx_bits] 221 | p_idx, avg_bits = HuffmanEncode.encode(idx_diff, bits=idx_bits) 222 | p_idx += overhead 223 | log.info(f" before Huffman coding: {n_weights*idx_bits:.0f} | after: {p_idx + overhead} | overhead: {overhead:.0f} | average bits: {avg_bits:.0f}") 224 | else: 225 | p_idx = n_weights * idx_bits 226 | if is_quantized(module): 227 | attr = f"{name}_centers" 228 | n_weights = getattr(module, attr).numel() 229 | attr = f"{name}_indices" 230 | idx = getattr(module, attr).view(-1) 231 | weight_bits = math.log2(n_weights) 232 | q_idx = idx.numel() * weight_bits 233 | if huffman_encoding: 234 | # use index difference of csr matrix 235 | q_idx, _ = HuffmanEncode.encode(idx.detach().cpu().numpy(), bits=weight_bits) 236 | # Note: compression formula in paper does not include the mask 237 | return all_weights * float32_bits, n_weights * float32_bits + p_idx + q_idx 238 | 239 | 240 | def compression_stats(model, name="weight", idx_bits=5, huffman_encoding=False): 241 | log.info(f"Compression stats of `{model.__class__.__name__}` - `{name}`:") 242 | compression_dict = { 243 | n: get_compression(m, name, idx_bits=idx_bits, huffman_encoding=huffman_encoding) for 244 | n, m in model.named_modules() if 245 | getattr(m, name, None) is not None} 246 | 247 | for name, (n, d) in compression_dict.items(): 248 | cr = n / d 249 | log.info(f" Layer {name}: compression rate {1 / cr:.2%} ({cr:.1f}X) ") 250 | n, d = zip(*compression_dict.values()) 251 | total_params = sum(n) 252 | total_d = sum(d) 253 | cr = total_params / total_d 254 | log.info(f"Total compression rate: {1 / cr:.2%} ({cr:.1f}X) ") 255 | return cr 256 | -------------------------------------------------------------------------------- /configs/AlexNet/alexnet-compression.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Imagenette_AlexNet", 3 | "n_gpu": 1, 4 | "arch": { 5 | "type": "AlexNet", 6 | "args": { 7 | "num_classes": 10, 8 | "dropout_rate": 0.5 9 | } 10 | }, 11 | "data_loader": { 12 | "type": "ImagenetDataLoader", 13 | "args": { 14 | "data_dir": "data/imagenette2", 15 | "batch_size": 128, 16 | "shuffle": true, 17 | "validation_split": 0.1, 18 | "num_workers": 6 19 | } 20 | }, 21 | "optimizer": { 22 | "type": "SGD", 23 | "args": { 24 | "lr": 1e-2, 25 | "momentum": 0.9, 26 | "weight_decay": 5e-4, 27 | "nesterov": true 28 | } 29 | }, 30 | "loss": "cross_entropy", 31 | "metrics": [ 32 | "accuracy", 33 | "topk_accuracy" 34 | ], 35 | "lr_scheduler": { 36 | "type": "ReduceLROnPlateau", 37 | "args": { 38 | "mode": "max", 39 | "factor": 0.1 40 | } 41 | }, 42 | "trainer": { 43 | "epochs": -1, 44 | "save_dir": "runs/", 45 | "verbosity": 1, 46 | "monitor": "max val_accuracy", 47 | "loggers": ["TensorBoard"], 48 | "callbacks": { 49 | "ModelCheckpoint": 50 | { 51 | "save_top_k": 1, 52 | "monitor": "val_accuracy", 53 | "mode": "max" 54 | }, 55 | "EarlyStopping": { 56 | "monitor": "val_accuracy", 57 | "mode": "max", 58 | "patience": 5, 59 | "min_delta": 0.00 60 | }, 61 | "IterativePruning": [ 62 | { 63 | "pruning_schedule": { 64 | "epochs": [ 65 | 0 66 | ] 67 | }, 68 | "pruning_fn": "l1_unstructured", 69 | "parameter_names": [ 70 | "weight" 71 | ], 72 | "filter_layers": [ 73 | "Linear", 74 | "Conv2d" 75 | ], 76 | "amount": [ 77 | 0.16, 78 | 0.62, 79 | 0.65, 80 | 0.63, 81 | 0.63, 82 | 0.91, 83 | 0.91, 84 | 0.75 85 | ], 86 | "use_global_unstructured": false, 87 | "verbose": 2 88 | } 
89 | ], 90 | "Quantization": [ 91 | { 92 | "epoch": 8, 93 | "quantization_fn": "linear_quantization", 94 | "parameter_names": [ 95 | "weight" 96 | ], 97 | "filter_layers": [ 98 | "Linear" 99 | ], 100 | "bits": 5, 101 | "verbose": 2, 102 | "huffman_encode": true 103 | }, 104 | { 105 | "epoch": 8, 106 | "quantization_fn": "linear_quantization", 107 | "parameter_names": [ 108 | "weight" 109 | ], 110 | "filter_layers": [ 111 | "Conv2d" 112 | ], 113 | "bits": 8, 114 | "verbose": 2, 115 | "huffman_encode": true 116 | } 117 | ] 118 | } 119 | } 120 | } -------------------------------------------------------------------------------- /configs/AlexNet/alexnet-pretrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Imagenette_AlexNet", 3 | "n_gpu": 1, 4 | "arch": { 5 | "type": "AlexNet", 6 | "args": { 7 | "num_classes": 10, 8 | "dropout_rate": 0.5 9 | } 10 | }, 11 | "data_loader": { 12 | "type": "ImagenetDataLoader", 13 | "args": { 14 | "data_dir": "data/imagenette2", 15 | "batch_size": 32, 16 | "shuffle": true, 17 | "validation_split": 0.1, 18 | "num_workers": 6 19 | } 20 | }, 21 | "optimizer": { 22 | "type": "SGD", 23 | "args": { 24 | "lr": 1e-2, 25 | "momentum": 0.9, 26 | "weight_decay": 5e-4, 27 | "nesterov": false 28 | } 29 | }, 30 | "loss": "cross_entropy", 31 | "metrics": [ 32 | "accuracy", 33 | "topk_accuracy" 34 | ], 35 | "lr_scheduler": { 36 | "type": "ReduceLROnPlateau", 37 | "args": { 38 | "mode": "max", 39 | "factor": 0.1 40 | } 41 | }, 42 | "trainer": { 43 | "min_epochs": 1, 44 | "save_dir": "runs/", 45 | "verbosity": 1, 46 | "monitor": "max val_accuracy", 47 | "loggers": ["TensorBoard"], 48 | "callbacks": { 49 | "ModelCheckpoint": 50 | { 51 | "save_top_k": 1, 52 | "every_n_epochs": 1, 53 | "monitor": "val_accuracy", 54 | "mode": "max" 55 | }, 56 | "EarlyStopping": { 57 | "monitor": "val_accuracy", 58 | "mode": "max", 59 | "patience": 5, 60 | "min_delta": 0.00 61 | } 62 | } 63 | } 64 | } -------------------------------------------------------------------------------- /configs/LeNet300/lenet-300-compression.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Mnist_LeNet300", 3 | "n_gpu": 1, 4 | "arch": { 5 | "type": "LeNet300", 6 | "args": { 7 | "num_classes": 10, 8 | "grayscale": true, 9 | "dropout_rate": 0 10 | } 11 | }, 12 | "data_loader": { 13 | "type": "MnistDataLoader", 14 | "args": { 15 | "data_dir": "data/", 16 | "batch_size": 128, 17 | "shuffle": true, 18 | "validation_split": 0.1, 19 | "num_workers": 6, 20 | "resize": false 21 | } 22 | }, 23 | "optimizer": { 24 | "type": "SGD", 25 | "args": { 26 | "lr": 1e-2, 27 | "momentum": 0.9, 28 | "weight_decay": 1e-3, 29 | "nesterov": true 30 | } 31 | }, 32 | "loss": "cross_entropy", 33 | "metrics": [ 34 | "accuracy", 35 | "topk_accuracy" 36 | ], 37 | "trainer": { 38 | "min_epochs": 20, 39 | "save_dir": "runs/", 40 | "verbosity": 1, 41 | "monitor": "max val_accuracy", 42 | "loggers": ["TensorBoard"], 43 | "callbacks": { 44 | "ModelCheckpoint": { 45 | "save_top_k": 1, 46 | "monitor": "val_accuracy", 47 | "every_n_epochs":5, 48 | "mode": "max" 49 | }, 50 | "IterativePruning": { 51 | "pruning_schedule": { 52 | "target_sparsity": 0.9, 53 | "start_epoch": 0 54 | }, 55 | "pruning_fn": "l1_unstructured", 56 | "parameter_names": ["weight"], 57 | "amount": [0.92,0.91,0.74], 58 | "use_global_unstructured": false, 59 | "verbose": 2 60 | }, 61 | "Quantization": { 62 | "epoch": 15, 63 | "quantization_fn": "density_quantization", 64 | 
"parameter_names": ["weight"], 65 | "filter_layers": ["Linear"], 66 | "bits": 6, 67 | "verbose": 2, 68 | "huffman_encode": true 69 | } 70 | } 71 | } 72 | } -------------------------------------------------------------------------------- /configs/LeNet300/lenet-300-pretrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Mnist_LeNet300", 3 | "n_gpu": 1, 4 | "arch": { 5 | "type": "LeNet300", 6 | "args": { 7 | "num_classes": 10, 8 | "grayscale": true, 9 | "dropout_rate": 0 10 | } 11 | }, 12 | "data_loader": { 13 | "type": "MnistDataLoader", 14 | "args": { 15 | "data_dir": "data/", 16 | "batch_size": 128, 17 | "shuffle": true, 18 | "validation_split": 0.1, 19 | "num_workers": 6, 20 | "resize": false 21 | } 22 | }, 23 | "optimizer": { 24 | "type": "SGD", 25 | "args": { 26 | "lr": 1e-2, 27 | "momentum": 0.9, 28 | "weight_decay": 1e-3, 29 | "nesterov": true 30 | } 31 | }, 32 | "loss": "cross_entropy", 33 | "metrics": [ 34 | "accuracy", 35 | "topk_accuracy" 36 | ], 37 | "trainer": { 38 | "epochs": -1, 39 | "save_dir": "runs/", 40 | "verbosity": 2, 41 | "monitor": "max val_accuracy", 42 | "loggers": ["TensorBoard"], 43 | "callbacks": { 44 | "ModelCheckpoint": { 45 | "save_top_k": 1, 46 | "monitor": "val_accuracy", 47 | "every_n_epochs":1, 48 | "mode": "max" 49 | }, 50 | "EarlyStopping": { 51 | "monitor": "val_accuracy", 52 | "mode": "max", 53 | "patience": 5, 54 | "min_delta": 0.00 55 | } 56 | } 57 | } 58 | } -------------------------------------------------------------------------------- /configs/LeNet5/lenet-5-cifar100.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Cifar100_LeNet5", 3 | "n_gpu": 1, 4 | "arch": { 5 | "type": "LeNet5L", 6 | "args": { 7 | "num_classes": 10, 8 | "grayscale": true, 9 | "dropout_rate": 0 10 | } 11 | }, 12 | "data_loader": { 13 | "type": "Cifar100DataLoader", 14 | "args": { 15 | "data_dir": "data/", 16 | "batch_size": 128, 17 | "shuffle": true, 18 | "validation_split": 0.1, 19 | "num_workers": 6, 20 | "resize": true 21 | } 22 | }, 23 | "optimizer": { 24 | "type": "SGD", 25 | "args": { 26 | "lr": 1e-2, 27 | "momentum": 0.9, 28 | "weight_decay": 1e-3, 29 | "nesterov": true 30 | } 31 | }, 32 | "loss": "cross_entropy", 33 | "metrics": [ 34 | "accuracy", 35 | "topk_accuracy" 36 | ], 37 | "trainer": { 38 | "min_epochs": 20, 39 | "save_dir": "runs/", 40 | "verbosity": 1, 41 | "monitor": "max val_accuracy", 42 | "loggers": ["TensorBoard"], 43 | "callbacks": { 44 | "ModelCheckpoint": { 45 | "save_top_k": 1, 46 | "monitor": "val_accuracy", 47 | "every_n_epochs":5, 48 | "mode": "max" 49 | }, 50 | "EarlyStopping": { 51 | "monitor": "val_accuracy", 52 | "mode": "max", 53 | "patience": 5, 54 | "min_delta": 0.00 55 | }, 56 | "IterativePruning": { 57 | "pruning_schedule": { 58 | "epoch": 0 59 | }, 60 | "pruning_fn": "l1_unstructured", 61 | "parameter_names": [ 62 | "weight" 63 | ], 64 | "amount": 0.5, 65 | "use_global_unstructured": true, 66 | "verbose": 2 67 | }, 68 | "Quantization": [{ 69 | "epoch": 20, 70 | "quantization_fn": "density_quantization", 71 | "parameter_names": ["weight"], 72 | "filter_layers": ["Linear"], 73 | "bits": 5, 74 | "verbose": 1, 75 | "huffman_encode": true 76 | },{ 77 | "epoch": 20, 78 | "quantization_fn": "density_quantization", 79 | "parameter_names": ["weight"], 80 | "filter_layers": ["Conv2d"], 81 | "bits": 8, 82 | "verbose": 1, 83 | "huffman_encode": true 84 | }] 85 | } 86 | } 87 | } 
-------------------------------------------------------------------------------- /configs/LeNet5/lenet-5-compression.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Mnist_LeNet5", 3 | "n_gpu": 1, 4 | "arch": { 5 | "type": "LeNet5L", 6 | "args": { 7 | "num_classes": 10, 8 | "grayscale": true, 9 | "dropout_rate": 0 10 | } 11 | }, 12 | "data_loader": { 13 | "type": "MnistDataLoader", 14 | "args": { 15 | "data_dir": "data/", 16 | "batch_size": 128, 17 | "shuffle": true, 18 | "validation_split": 0.1, 19 | "num_workers": 6, 20 | "resize": true 21 | } 22 | }, 23 | "optimizer": { 24 | "type": "SGD", 25 | "args": { 26 | "lr": 5e-3, 27 | "momentum": 0.9, 28 | "weight_decay": 5e-4, 29 | "nesterov": false 30 | } 31 | }, 32 | "loss": "cross_entropy", 33 | "metrics": [ 34 | "accuracy", 35 | "topk_accuracy" 36 | ], 37 | "trainer": { 38 | "min_epochs": 5, 39 | "max_epochs": 20, 40 | "save_dir": "runs/", 41 | "verbosity": 1, 42 | "monitor": "max val_accuracy", 43 | "loggers": [ 44 | "TensorBoard" 45 | ], 46 | "callbacks": { 47 | "EarlyStopping": { 48 | "monitor": "val_accuracy", 49 | "mode": "max", 50 | "patience": 5, 51 | "min_delta": 0.00 52 | }, 53 | "IterativePruning": [{ 54 | "pruning_schedule": { 55 | "target_sparsity": 0, 56 | "epochs": [ 57 | 0 58 | ] 59 | }, 60 | "pruning_fn": "l1_unstructured", 61 | "make_pruning_permanent": false, 62 | "filter_layers": ["Conv2d"], 63 | "parameter_names": [ 64 | "weight" 65 | ], 66 | "amount": 0.8, 67 | "use_global_unstructured": true, 68 | "prune_on_fit_start": true, 69 | "verbose": 2 70 | }, 71 | { 72 | "pruning_schedule": { 73 | "target_sparsity": 0, 74 | "epochs": [ 75 | 0 76 | ] 77 | }, 78 | "pruning_fn": "l1_unstructured", 79 | "make_pruning_permanent": false, 80 | "filter_layers": ["Linear"], 81 | "parameter_names": [ 82 | "weight" 83 | ], 84 | "amount": 0.925, 85 | "use_global_unstructured": true, 86 | "prune_on_fit_start": true, 87 | "verbose": 2 88 | }], 89 | "Quantization": [{ 90 | "epoch": 5, 91 | "quantization_fn": "linear_quantization", 92 | "parameter_names": ["weight"], 93 | "filter_layers": ["Linear"], 94 | "bits": 5, 95 | "verbose": 1, 96 | "huffman_encode": false 97 | },{ 98 | "epoch": 5, 99 | "quantization_fn": "linear_quantization", 100 | "parameter_names": ["weight"], 101 | "filter_layers": ["Conv2d"], 102 | "bits": 8, 103 | "verbose": 1, 104 | "huffman_encode": false 105 | }] 106 | } 107 | } 108 | } -------------------------------------------------------------------------------- /configs/LeNet5/lenet-5-pretrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Mnist_LeNet5", 3 | "n_gpu": 1, 4 | "arch": { 5 | "type": "LeNet5L", 6 | "args": { 7 | "num_classes": 10, 8 | "grayscale": true, 9 | "dropout_rate": 0 10 | } 11 | }, 12 | "data_loader": { 13 | "type": "MnistDataLoader", 14 | "args": { 15 | "data_dir": "data/", 16 | "batch_size": 128, 17 | "shuffle": true, 18 | "validation_split": 0.1, 19 | "num_workers": 6, 20 | "resize": true 21 | } 22 | }, 23 | "optimizer": { 24 | "type": "SGD", 25 | "args": { 26 | "lr": 1e-2, 27 | "momentum": 0.9, 28 | "weight_decay": 5e-4, 29 | "nesterov": false 30 | } 31 | }, 32 | "loss": "cross_entropy", 33 | "metrics": [ 34 | "accuracy", 35 | "topk_accuracy" 36 | ], 37 | "trainer": { 38 | "epochs": -1, 39 | "save_dir": "runs/", 40 | "verbosity": 1, 41 | "monitor": "max val_accuracy", 42 | "loggers": ["TensorBoard"], 43 | "callbacks": { 44 | "ModelCheckpoint": { 45 | "save_top_k": 1, 46 | "monitor": 
"val_accuracy", 47 | "every_n_epochs":1, 48 | "mode": "max" 49 | }, 50 | "EarlyStopping": { 51 | "monitor": "val_accuracy", 52 | "mode": "max", 53 | "patience": 5, 54 | "min_delta": 0.00 55 | } 56 | } 57 | } 58 | } -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torchvision.datasets as datasets 5 | import torchvision.transforms as transforms 6 | from torch.utils.data import SubsetRandomSampler, DataLoader, Subset, Dataset 7 | 8 | DATA_DIR = os.path.dirname(os.path.abspath(__file__)) + '/data' 9 | 10 | 11 | class MyDataset(Dataset): 12 | def __init__(self, subset, transform=None): 13 | self.subset = subset 14 | self.transform = transform 15 | 16 | def __getitem__(self, index): 17 | x, y = self.subset[index] 18 | if self.transform: 19 | x = self.transform(x) 20 | return x, y 21 | 22 | def __len__(self): 23 | return len(self.subset) 24 | 25 | 26 | class BaseDataLoader(DataLoader): 27 | """ 28 | Base class for all data loaders 29 | """ 30 | 31 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers): 32 | self.validation_split = validation_split 33 | self.shuffle = shuffle 34 | 35 | self.batch_idx = 0 36 | self.n_samples = len(dataset) 37 | 38 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 39 | 40 | self.init_kwargs = { 41 | 'dataset': dataset, 42 | 'batch_size': batch_size, 43 | 'shuffle': self.shuffle, 44 | 'num_workers': num_workers 45 | } 46 | super().__init__(sampler=self.sampler, **self.init_kwargs) 47 | 48 | def _split_sampler(self, split): 49 | if split == 0.0: 50 | return None, None 51 | 52 | idx_full = np.arange(self.n_samples) 53 | 54 | np.random.shuffle(idx_full) 55 | 56 | if isinstance(split, int): 57 | assert split > 0 58 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 
59 | len_valid = split 60 | else: 61 | len_valid = int(self.n_samples * split) 62 | 63 | valid_idx = idx_full[0:len_valid] 64 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 65 | 66 | train_sampler = SubsetRandomSampler(train_idx) 67 | valid_sampler = SubsetRandomSampler(valid_idx) 68 | 69 | # turn off shuffle option which is mutually exclusive with sampler 70 | self.shuffle = False 71 | self.n_samples = len(train_idx) 72 | 73 | return train_sampler, valid_sampler 74 | 75 | def split_validation(self): 76 | if self.valid_sampler is None: 77 | return None 78 | else: 79 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 80 | 81 | 82 | class MnistDataLoader(BaseDataLoader): 83 | def __init__(self, batch_size, data_dir=DATA_DIR, shuffle=True, validation_split=0.0, num_workers=1, training=True, 84 | resize=False): 85 | # calculated dataset mean and variance for standardization 86 | # train_set = torchvision.datasets.MNIST(root=DATA_DIR, train=True, 87 | # download=True) 88 | # 89 | # mean = train_set.data.float().mean() / 255 #0.1307 90 | # std = train_set.data.float().std() / 255 #0.3081 91 | 92 | self.resize = resize 93 | ts = [ 94 | transforms.ToTensor(), 95 | transforms.Normalize((0.1307,), (0.3081,)) 96 | ] 97 | if resize: 98 | ts.append(transforms.Resize(32, antialias=True)) 99 | transform = transforms.Compose(ts) 100 | self.data_dir = data_dir 101 | self.dataset = datasets.MNIST(root=self.data_dir, train=training, download=True, transform=transform) 102 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 103 | 104 | 105 | class Cifar100DataLoader(BaseDataLoader): 106 | # calculated dataset mean and variance for standardization 107 | def __init__(self, batch_size, data_dir=DATA_DIR, shuffle=True, validation_split=0.0, num_workers=1, training=True): 108 | train_set = datasets.CIFAR100(root=DATA_DIR, train=True, 109 | download=True) 110 | 111 | mean = train_set.data.mean() / 255 112 | std = train_set.data.std() / 255 113 | 114 | transform = transforms.Compose([ 115 | transforms.ToTensor(), 116 | transforms.Normalize(mean, std) 117 | ]) 118 | self.data_dir = data_dir 119 | self.dataset = datasets.CIFAR100(root=self.data_dir, train=training, download=True, transform=transform) 120 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 121 | 122 | 123 | class ImagenetDataLoader(BaseDataLoader): 124 | def __init__(self, batch_size, data_dir, shuffle=True, validation_split=0.0, 125 | num_workers=1, training=True): 126 | if training: 127 | self.data_dir = os.path.join(data_dir, 'train') 128 | else: 129 | self.data_dir = os.path.join(data_dir, 'val') 130 | 131 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 132 | std=[0.229, 0.224, 0.225]) 133 | 134 | # Todo: define augmentation techniques, different for train and validation. 
135 | self.transform = { 136 | 'train': transforms.Compose([ 137 | transforms.Resize(256), 138 | transforms.RandomCrop(224), 139 | transforms.RandomHorizontalFlip(), 140 | transforms.ToTensor(), 141 | normalize 142 | ]), 143 | 'val': transforms.Compose([ 144 | transforms.Resize(256), 145 | transforms.CenterCrop(224), 146 | transforms.ToTensor(), 147 | normalize 148 | ]) 149 | } 150 | 151 | dataset = datasets.ImageFolder( 152 | data_dir, 153 | transform=None 154 | ) 155 | self.n_samples = len(dataset) 156 | self.train_dataset, self.valid_dataset = self._split_dataset(dataset, validation_split) 157 | validation_split = 0.0 158 | 159 | super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers) 160 | 161 | def _split_dataset(self, dataset, split): 162 | if split == 0.0: 163 | return MyDataset(dataset, self.transform['val']), None 164 | 165 | idx_full = np.arange(self.n_samples) 166 | 167 | np.random.shuffle(idx_full) 168 | 169 | if isinstance(split, int): 170 | assert split > 0 171 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 172 | len_valid = split 173 | else: 174 | len_valid = int(self.n_samples * split) 175 | 176 | valid_idx = idx_full[0:len_valid] 177 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 178 | 179 | train_dataset = MyDataset(Subset(dataset, train_idx), self.transform['train']) 180 | valid_dataset = MyDataset(Subset(dataset, valid_idx), self.transform['val']) 181 | 182 | self.n_samples = len(train_idx) 183 | 184 | return train_dataset, valid_dataset 185 | 186 | def split_validation(self): 187 | if self.valid_dataset is None: 188 | return None 189 | else: 190 | return DataLoader(self.valid_dataset, batch_size=self.init_kwargs['batch_size'], 191 | shuffle=False, num_workers=self.init_kwargs['num_workers']) 192 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pytorch-lightning 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - pytorch==2.0.0 9 | - pytorch-cuda==11.7 10 | - pytorch-lightning==2.0.0 11 | - torchmetrics=0.11.4 12 | - torchvision==0.15.0 13 | - ipykernel==6.21.3 14 | - jupyter==1.0.0 15 | - matplotlib==3.7.1 16 | - numpy==1.24.2 17 | - scikit-learn==1.2.2 18 | - scipy==1.10.1 19 | - tqdm==4.65.0 20 | - tensorboard==2.12.0 21 | - pip: 22 | - python-graphviz==0.20.1 23 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from pytorch_lightning.loggers import * -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = read_json(log_config) 14 | # modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | logging.config.dictConfig(config) 19 | # 
removes logging from PIL plugin https://discuss.pytorch.org/t/weird-debug-messages-in-log-when-using-resnet50/166921 20 | logging.getLogger("PIL.TiffImagePlugin").setLevel(51) 21 | else: 22 | print("Warning: logging configuration file not found at {}.".format(log_config)) 23 | logging.basicConfig(level=default_level) 24 | -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.lenet import LeNet300, LeNet5, LeNet5L 2 | from models.alexnet import AlexNet 3 | from models.vgg import VGG16 4 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from compression.pruning import get_pruned 9 | 10 | 11 | class LinearWithAdjustableDropout(nn.Linear): 12 | def __init__(self, in_features, out_features, dropout_rate=0.5): 13 | super().__init__(in_features, out_features) 14 | self.dropout_rate = dropout_rate 15 | 16 | def forward(self, input: Tensor) -> Tensor: 17 | out = F.dropout(input, self.dropout_rate, self.training)  # dropout on the layer input 18 | out = super().forward(out)  # apply the linear layer to the dropped-out activations 19 | return out 20 | 21 | def adjust_dropout_rate(self, name="weight"): 22 | c_ir, c_i0 = get_pruned(self, name) 23 | self.dropout_rate = self.dropout_rate * math.sqrt(c_ir / c_i0) 24 | 25 | 26 | class AlexNet(nn.Module): 27 | def __init__(self, num_classes: int = 1000, dropout_rate: float = 0.5) -> None: 28 | super().__init__() 29 | self.features = nn.Sequential( 30 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 31 | nn.ReLU(inplace=True), 32 | nn.MaxPool2d(kernel_size=3, stride=2), 33 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 34 | nn.ReLU(inplace=True), 35 | nn.MaxPool2d(kernel_size=3, stride=2), 36 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 39 | nn.ReLU(inplace=True), 40 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 41 | nn.ReLU(inplace=True), 42 | nn.MaxPool2d(kernel_size=3, stride=2), 43 | ) 44 | self.classifier = nn.Sequential( 45 | LinearWithAdjustableDropout(256 * 6 * 6, 4096), 46 | nn.ReLU(inplace=True), 47 | LinearWithAdjustableDropout(4096, 4096), 48 | nn.ReLU(inplace=True), 49 | nn.Linear(4096, num_classes), 50 | ) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.features(x) 54 | x 
= x.view(x.size(0), 256 * 6 * 6) 55 | logits = self.classifier(x) 56 | # probs = F.softmax(logits, dim=1) 57 | return logits 58 | -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LeNet300(nn.Module): 7 | def __init__(self, num_classes, grayscale=True, dropout_rate=0): 8 | super().__init__() 9 | 10 | self.fc1 = nn.Linear(28 * 28, 300) 11 | self.fc2 = nn.Linear(300, 100) 12 | self.fc3 = nn.Linear(100, num_classes) 13 | self.dropout = nn.Dropout(dropout_rate) 14 | # activation function for hidden layers 15 | self.activation = F.relu 16 | 17 | def forward(self, x): 18 | out = torch.flatten(x, 1) 19 | out = self.activation(self.fc1(out)) 20 | out = self.dropout(out) 21 | out = self.activation(self.fc2(out)) 22 | out = self.dropout(out) 23 | out = self.fc3(out) 24 | return out 25 | 26 | 27 | 28 | class LeNet5(nn.Module): 29 | def __init__(self, num_classes, grayscale=True, dropout_rate=0): 30 | super().__init__() 31 | 32 | if grayscale: 33 | in_channels = 1 34 | else: 35 | in_channels = 3 36 | 37 | self.conv1 = nn.Conv2d(in_channels, 6 * in_channels, kernel_size=5, stride=1, padding=0) 38 | self.conv2 = nn.Conv2d(6 * in_channels, 16 * in_channels, kernel_size=5, stride=1, padding=0) 39 | self.fc = nn.Linear(16 * 5 * 5 * in_channels, 120 * in_channels) 40 | self.fc1 = nn.Linear(120 * in_channels, 84 * in_channels) 41 | self.fc2 = nn.Linear(84 * in_channels, num_classes) 42 | self.dropout = nn.Dropout(dropout_rate) 43 | # activation function for hidden layers 44 | self.activation = F.relu 45 | 46 | def forward(self, x): 47 | out = self.activation(self.conv1(x)) 48 | out = F.max_pool2d(out, kernel_size=2, stride=2) 49 | out = self.activation(self.conv2(out)) 50 | out = F.max_pool2d(out, kernel_size=2, stride=2) 51 | out = out.reshape(out.size(0), -1) 52 | out = self.dropout(out) 53 | out = self.activation(self.fc(out)) 54 | out = self.activation(self.fc1(out)) 55 | out = self.fc2(out) 56 | return out 57 | 58 | class LeNet5L(nn.Module): 59 | def __init__(self, num_classes, grayscale=True, dropout_rate=0): 60 | super().__init__() 61 | 62 | if grayscale: 63 | in_channels = 1 64 | else: 65 | in_channels = 3 66 | 67 | self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1, padding=0) 68 | self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1, padding=0) 69 | self.fc1 = nn.Linear(50 * 5 * 5, 320) 70 | self.fc2 = nn.Linear(320, num_classes) 71 | self.dropout = nn.Dropout(dropout_rate) 72 | # activation function for hidden layers 73 | self.activation = F.relu 74 | 75 | def forward(self, x): 76 | out = self.activation(self.conv1(x)) 77 | out = F.max_pool2d(out, kernel_size=2, stride=2) 78 | out = self.activation(self.conv2(out)) 79 | out = F.max_pool2d(out, kernel_size=2, stride=2) 80 | out = out.reshape(out.size(0), -1) 81 | out = self.dropout(out) 82 | out = self.activation(self.fc1(out)) 83 | out = self.dropout(out) 84 | out = self.fc2(out) 85 | return out 86 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | ########################## 7 | ### MODEL 8 | ########################## 9 | 10 | 11 | class VGG16(torch.nn.Module): 12 | 13 | def __init__(self, 
num_features, num_classes): 14 | super(VGG16, self).__init__() 15 | 16 | # calculate same padding: 17 | # (w - k + 2*p)/s + 1 = o 18 | # => p = (s(o-1) - w + k)/2 19 | 20 | self.block_1 = nn.Sequential( 21 | nn.Conv2d(in_channels=3, 22 | out_channels=64, 23 | kernel_size=(3, 3), 24 | stride=(1, 1), 25 | # (1(32-1)- 32 + 3)/2 = 1 26 | padding=1), 27 | nn.ReLU(), 28 | nn.Conv2d(in_channels=64, 29 | out_channels=64, 30 | kernel_size=(3, 3), 31 | stride=(1, 1), 32 | padding=1), 33 | nn.ReLU(), 34 | nn.MaxPool2d(kernel_size=(2, 2), 35 | stride=(2, 2)) 36 | ) 37 | 38 | self.block_2 = nn.Sequential( 39 | nn.Conv2d(in_channels=64, 40 | out_channels=128, 41 | kernel_size=(3, 3), 42 | stride=(1, 1), 43 | padding=1), 44 | nn.ReLU(), 45 | nn.Conv2d(in_channels=128, 46 | out_channels=128, 47 | kernel_size=(3, 3), 48 | stride=(1, 1), 49 | padding=1), 50 | nn.ReLU(), 51 | nn.MaxPool2d(kernel_size=(2, 2), 52 | stride=(2, 2)) 53 | ) 54 | 55 | self.block_3 = nn.Sequential( 56 | nn.Conv2d(in_channels=128, 57 | out_channels=256, 58 | kernel_size=(3, 3), 59 | stride=(1, 1), 60 | padding=1), 61 | nn.ReLU(), 62 | nn.Conv2d(in_channels=256, 63 | out_channels=256, 64 | kernel_size=(3, 3), 65 | stride=(1, 1), 66 | padding=1), 67 | nn.ReLU(), 68 | nn.Conv2d(in_channels=256, 69 | out_channels=256, 70 | kernel_size=(3, 3), 71 | stride=(1, 1), 72 | padding=1), 73 | nn.ReLU(), 74 | nn.MaxPool2d(kernel_size=(2, 2), 75 | stride=(2, 2)) 76 | ) 77 | 78 | self.block_4 = nn.Sequential( 79 | nn.Conv2d(in_channels=256, 80 | out_channels=512, 81 | kernel_size=(3, 3), 82 | stride=(1, 1), 83 | padding=1), 84 | nn.ReLU(), 85 | nn.Conv2d(in_channels=512, 86 | out_channels=512, 87 | kernel_size=(3, 3), 88 | stride=(1, 1), 89 | padding=1), 90 | nn.ReLU(), 91 | nn.Conv2d(in_channels=512, 92 | out_channels=512, 93 | kernel_size=(3, 3), 94 | stride=(1, 1), 95 | padding=1), 96 | nn.ReLU(), 97 | nn.MaxPool2d(kernel_size=(2, 2), 98 | stride=(2, 2)) 99 | ) 100 | 101 | self.block_5 = nn.Sequential( 102 | nn.Conv2d(in_channels=512, 103 | out_channels=512, 104 | kernel_size=(3, 3), 105 | stride=(1, 1), 106 | padding=1), 107 | nn.ReLU(), 108 | nn.Conv2d(in_channels=512, 109 | out_channels=512, 110 | kernel_size=(3, 3), 111 | stride=(1, 1), 112 | padding=1), 113 | nn.ReLU(), 114 | nn.Conv2d(in_channels=512, 115 | out_channels=512, 116 | kernel_size=(3, 3), 117 | stride=(1, 1), 118 | padding=1), 119 | nn.ReLU(), 120 | nn.MaxPool2d(kernel_size=(2, 2), 121 | stride=(2, 2)) 122 | ) 123 | 124 | self.classifier = nn.Sequential( 125 | nn.Linear(512, 4096), 126 | nn.ReLU(True), 127 | # nn.Dropout(p=0.5), 128 | nn.Linear(4096, 4096), 129 | nn.ReLU(True), 130 | # nn.Dropout(p=0.5), 131 | nn.Linear(4096, num_classes), 132 | ) 133 | 134 | for m in self.modules(): 135 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear): 136 | nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu') 137 | if m.bias is not None: 138 | m.bias.detach().zero_() 139 | 140 | # self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 141 | 142 | def forward(self, x): 143 | 144 | x = self.block_1(x) 145 | x = self.block_2(x) 146 | x = self.block_3(x) 147 | x = self.block_4(x) 148 | x = self.block_5(x) 149 | # x = self.avgpool(x) 150 | x = x.view(x.size(0), -1) 151 | logits = self.classifier(x) 152 | probs = F.softmax(logits, dim=1) 153 | 154 | return logits, probs -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
logging 3 | from pathlib import Path 4 | from functools import reduce, partial 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | 10 | 11 | class ConfigParser: 12 | def __init__(self, config, resume=None, modification=None, run_id=None): 13 | """ 14 | Class to parse the configuration JSON file. Handles hyperparameters for training, initialization of modules, checkpoint saving 15 | and the logging module. 16 | :param config: Dict containing configurations and hyperparameters for training, e.g. the contents of the `config.json` file. 17 | :param resume: String, path to the checkpoint being loaded. 18 | :param modification: Dict of keychain:value pairs, specifying values to be replaced in the config dict. 19 | :param run_id: Unique identifier for training processes, used to save checkpoints and training logs. Defaults to a timestamp. 20 | """ 21 | # load config file and apply modification 22 | self._config = _update_config(config, modification) 23 | self.resume = resume 24 | 25 | # set save_dir where trained model and log will be saved. 26 | save_dir = Path(self.config['trainer']['save_dir']) 27 | 28 | exper_name = self.config['name'] 29 | if run_id is None: # use timestamp as default run-id 30 | run_id = datetime.now().strftime(r'%m%d_%H%M%S') 31 | self._save_dir = save_dir / exper_name / run_id 32 | 33 | # make directory for saving checkpoints and log. 34 | exist_ok = run_id == '' 35 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok) 36 | 37 | # save updated config file to the checkpoint dir 38 | write_json(self.config, self.save_dir / 'config.json') 39 | 40 | # configure logging module 41 | setup_logging(self.save_dir) 42 | self.log_levels = { 43 | 0: logging.WARNING, 44 | 1: logging.INFO, 45 | 2: logging.DEBUG 46 | } 47 | 48 | @classmethod 49 | def from_args(cls, args, options=''): 50 | """ 51 | Initialize this class from CLI arguments. Used in train.py and test.py. 52 | """ 53 | for opt in options: 54 | args.add_argument(*opt.flags, default=None, type=opt.type) 55 | if not isinstance(args, tuple): 56 | args = args.parse_args() 57 | 58 | if args.device is not None: 59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 60 | if args.resume is not None: 61 | resume = Path(args.resume) 62 | cfg_fname = resume.parent / 'config.json' 63 | else: 64 | msg_no_cfg = "Configuration file needs to be specified. Add '-c config.json', for example." 65 | assert args.config is not None, msg_no_cfg 66 | resume = None 67 | cfg_fname = Path(args.config) 68 | 69 | config = read_json(cfg_fname) 70 | if args.config and resume: 71 | # update new config for fine-tuning 72 | config.update(read_json(args.config)) 73 | 74 | # parse custom cli options into dictionary 75 | modification = {opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options} 76 | return cls(config, resume, modification) 77 | 78 | @classmethod 79 | def from_dict(cls, dict, options=''): 80 | """ 81 | Initialize this class from a plain dict instead of parsed CLI arguments. 82 | """ 83 | 84 | if 'device' in dict: 85 | os.environ["CUDA_VISIBLE_DEVICES"] = dict['device'] 86 | if 'resume' in dict: 87 | resume = Path(dict['resume']) 88 | cfg_fname = resume.parent / 'config.json' 89 | else: 90 | msg_no_cfg = "Configuration file needs to be specified. Add '-c config.json', for example."
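Both constructors resolve the configuration the same way: a `resume` checkpoint implies reusing the `config.json` written next to it, an explicit config path is otherwise required, and when both are given the explicit config is merged on top for fine-tuning. A minimal, self-contained sketch of just that lookup (paths are hypothetical):

```python
from pathlib import Path

def resolve_config(config=None, resume=None):
    # Mirrors the config/resume resolution in ConfigParser (sketch only).
    if resume is not None:
        # reuse the config saved next to the checkpoint
        return Path(resume).parent / 'config.json', Path(resume)
    assert config is not None, "Configuration file needs to be specified."
    return Path(config), None

cfg, ckpt = resolve_config(resume='runs/LeNet300/0101_120000/last.ckpt')
print(cfg)  # runs/LeNet300/0101_120000/config.json
```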
91 | assert 'config' in dict, msg_no_cfg 92 | resume = None 93 | cfg_fname = Path(dict['config']) 94 | 95 | config = read_json(cfg_fname) 96 | if 'config' in dict and resume: 97 | # update new config for fine-tuning 98 | config.update(read_json(dict['config'])) 99 | 100 | # parse custom cli options into dictionary 101 | modification = {} 102 | return cls(config, resume, modification) 103 | 104 | def init_obj(self, name, module, *args, **kwargs): 105 | """ 106 | Finds a function handle with the name given as 'type' in config, and returns the 107 | instance initialized with corresponding arguments given. 108 | 109 | `object = config.init_obj('name', module, a, b=1)` 110 | is equivalent to 111 | `object = module.name(a, b=1)` 112 | """ 113 | module_name = self[name]['type'] 114 | module_args = dict(self[name]['args']) 115 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 116 | module_args.update(kwargs) 117 | return getattr(module, module_name)(*args, **module_args) 118 | 119 | def init_ftn(self, name, module, *args, **kwargs): 120 | """ 121 | Finds a function handle with the name given as 'type' in config, and returns the 122 | function with given arguments fixed with functools.partial. 123 | 124 | `function = config.init_ftn('name', module, a, b=1)` 125 | is equivalent to 126 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`. 127 | """ 128 | module_name = self[name]['type'] 129 | module_args = dict(self[name]['args']) 130 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 131 | module_args.update(kwargs) 132 | return partial(getattr(module, module_name), *args, **module_args) 133 | 134 | def __getitem__(self, name): 135 | """Access items like an ordinary dict.""" 136 | return self.config[name] 137 | 138 | def get_logger(self, name=None, verbosity=2): 139 | msg_verbosity = 'verbosity option {} is invalid. 
Valid options are {}.'.format(verbosity, 140 | self.log_levels.keys()) 141 | assert verbosity in self.log_levels, msg_verbosity 142 | logger = logging.getLogger(name) 143 | logger.setLevel(self.log_levels[verbosity]) 144 | return logger 145 | 146 | # setting read-only attributes 147 | @property 148 | def config(self): 149 | return self._config 150 | 151 | @property 152 | def save_dir(self): 153 | return self._save_dir 154 | 155 | @property 156 | def log_dir(self): 157 | return self._log_dir 158 | 159 | 160 | # helper functions to update config dict with custom cli options 161 | def _update_config(config, modification): 162 | if modification is None: 163 | return config 164 | 165 | for k, v in modification.items(): 166 | if v is not None: 167 | _set_by_path(config, k, v) 168 | return config 169 | 170 | 171 | def _get_opt_name(flags): 172 | for flg in flags: 173 | if flg.startswith('--'): 174 | return flg.replace('--', '') 175 | return flags[0].replace('--', '') 176 | 177 | 178 | def _set_by_path(tree, keys, value): 179 | """Set a value in a nested object in tree by sequence of keys.""" 180 | keys = keys.split(';') 181 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 182 | 183 | 184 | def _get_by_path(tree, keys): 185 | """Access a nested object in tree by sequence of keys.""" 186 | return reduce(getitem, keys, tree) 187 | -------------------------------------------------------------------------------- /sensitivity.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import csv 4 | import logging 5 | import os 6 | from collections import OrderedDict 7 | from copy import deepcopy 8 | 9 | import numpy as np 10 | import pytorch_lightning 11 | import torch 12 | from pytorch_lightning import LightningModule, Trainer 13 | from torch import nn 14 | from torch.nn.utils import prune 15 | 16 | from compression import get_pruned 17 | from parse_config import ConfigParser 18 | from trainer.lit_model import LitModel 19 | from trainer.trainer import get_trainer 20 | from utils import set_all_seeds, set_deterministic, load_compressed_checkpoint 21 | 22 | import data as module_data 23 | import models as module_arch 24 | import compression as module_compression 25 | 26 | CHECKPOINT_DIR = os.path.dirname(os.path.abspath(__file__)) + '/checkpoints' 27 | RUNS_DIR = os.path.dirname(os.path.abspath(__file__)) + '/runs' 28 | SEED = 42 29 | set_all_seeds(42) 30 | # set_deterministic() 31 | 32 | _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict) 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | def sensitivity_analysis(config, fn, amounts, name="weight"): 38 | """Perform a sensitivity test for a model's weights parameters. 39 | The model should be trained to maximum accuracy, because we aim to understand 40 | the behavior of the model's performance in relation to pruning of a specific 41 | weights tensor. 42 | By default this function will test all of the model's parameters. 43 | The return value is a sensitivities dictionary: the dictionary's 44 | key is the name (string) of the weights tensor. The value is another dictionary, 45 | where the tested sparsity-level is the key, and a (loss, top1, top5) tuple 46 | is the value. 47 | Below is an example of such a dictionary: 48 | .. 
code-block:: python 49 | {'model.fc1.weight': {0.0: (56.518, 79.07, 1.9159), 50 | 0.05: (56.492, 79.1, 1.9161), 51 | 0.10: (56.212, 78.854, 1.9315), 52 | 0.15: (35.424, 60.3, 3.0866)}, 53 | 'model.fc2.weight': {0.0: (56.518, 79.07, 1.9159), 54 | 0.05: (56.514, 79.07, 1.9159), 55 | 0.10: (56.434, 79.074, 1.9138), 56 | 0.15: (54.454, 77.854, 2.3127)} } 57 | """ 58 | 59 | sensitivities = OrderedDict() 60 | 61 | data_loader = config.init_obj('data_loader', module_data) 62 | valid_data_loader = data_loader.split_validation() 63 | 64 | model = LitModel(config, config.init_obj('arch', module_arch)) 65 | if config.resume: 66 | checkpoint = torch.load(config.resume) 67 | model.load_state_dict(checkpoint['state_dict']) 68 | 69 | current_modules = [m_name for m_name, m in model.model.named_modules() if 70 | not isinstance(m, _MODULE_CONTAINERS) and hasattr(m, name)] 71 | 72 | trainer = Trainer(default_root_dir=config.save_dir, accelerator="gpu", deterministic=True) 73 | 74 | for m_name in current_modules: 75 | 76 | sensitivity = OrderedDict() 77 | 78 | for amount in amounts: 79 | model_cpy = deepcopy(model) 80 | 81 | module = model_cpy.model.get_submodule(m_name)  # handles dotted names such as 'features.0' 82 | 83 | # Apply the pruning method to this module's parameter at the given 84 | # sparsity level, then evaluate on the validation set. 85 | 86 | fn.apply(module, name, amount) 87 | 88 | log = trainer.test(model_cpy, valid_data_loader) 89 | sensitivity[amount] = log[0] 90 | sensitivities[m_name] = sensitivity 91 | return sensitivities 92 | 93 | 94 | def pruning_sensitivity_analysis(config, amounts, train=True, params=["weight"]): 95 | sensitivity = OrderedDict() 96 | 97 | data_loader = config.init_obj('data_loader', module_data) 98 | valid_data_loader = data_loader.split_validation() 99 | 100 | model = LitModel(config, config.init_obj('arch', module_arch)) 101 | if config.resume: 102 | checkpoint = torch.load(config.resume) 103 | model.load_state_dict(checkpoint['state_dict']) 104 | 105 | for amount in amounts: 106 | trainer = get_trainer(config) 107 | model_cpy = deepcopy(model) 108 | current_modules = [m for m in model_cpy.model.modules() if not isinstance(m, _MODULE_CONTAINERS)] 109 | parameters_to_prune = [(m, p) for p in params for m in current_modules if hasattr(m, p)] 110 | 111 | prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=amount) 112 | if train: 113 | trainer.fit(model_cpy, data_loader, valid_data_loader) 114 | 115 | log = trainer.test(model_cpy, valid_data_loader) 116 | sensitivity[amount] = log[0] 117 | return sensitivity 118 | 119 | 120 | def quantization_sensitivity_analysis(config, amounts, train=True, params=["weight"]): 121 | sensitivities = OrderedDict() 122 | 123 | fns = ['linear_quantization', 'forgy_quantization', 'density_quantization'] 124 | 125 | data_loader = config.init_obj('data_loader', module_data) 126 | valid_data_loader = data_loader.split_validation() 127 | 128 | model = LitModel(config, config.init_obj('arch', module_arch)) 129 | if config.resume: 130 | checkpoint = torch.load(config.resume) 131 | load_compressed_checkpoint(model, checkpoint) 132 | trainer = Trainer(accelerator="gpu", deterministic=True) 133 | for fn in fns: 134 | sensitivity = OrderedDict() 135 | quantization_fn = getattr(module_compression.quantization, fn) 136 | for amount in amounts: 137 | 138 | model_cpy = deepcopy(model) 139 | current_modules = [m for m in model_cpy.model.modules() if not isinstance(m, _MODULE_CONTAINERS) and isinstance(m, nn.Conv2d)] 140 | parameters_to_quantize = [(m, p) for p in params for m in current_modules if 
hasattr(m, p)] 141 | 142 | for module, name in parameters_to_quantize: 143 | print("Quantizing {} into {:d} bits...".format(module, amount)) 144 | quantization_fn(module, name=name, bits=amount) 145 | 146 | if train: 147 | trainer.fit(model_cpy, data_loader, valid_data_loader) 148 | 149 | log = trainer.test(model_cpy, valid_data_loader) 150 | sensitivity[amount] = log[0] 151 | sensitivities[fn] = sensitivity 152 | return sensitivities 153 | 154 | 155 | def main(config): 156 | # sparsities = 1 - np.logspace(-2, 0, 10) 157 | # experiment = 'global pruning w/o retrain' 158 | # sensitivity = pruning_sensitivity_analysis(config, train=True, amounts=sparsities) 159 | 160 | sensitivities = quantization_sensitivity_analysis(config, train=False, amounts=list(range(1, 9))) 161 | 162 | # fig.savefig(config.save_dir / f"{config['name'].lower()}_sensitivity_analysis.png") 163 | # sensitivities = {experiment: sensitivity} 164 | fname = config.save_dir / f"{config['name'].lower()}_sensitivity_analysis.csv" 165 | 166 | with open(fname, 'w') as csv_file: 167 | writer = csv.writer(csv_file) 168 | # write the header 169 | writer.writerow(['experiment', 'sparsity', 'loss', 'top1', 'top5']) 170 | for experiment, sensitivity in sensitivities.items(): 171 | for sparsity, values in sensitivity.items(): 172 | writer.writerow([experiment] + [sparsity] + list(values.values())) 173 | 174 | # sensitivities_to_csv(sensitivities, config.save_dir / f"{config['name'].lower()}_sensitivity_analysis.csv") 175 | 176 | 177 | def plot_sensitivities(sensitivities, metric='val_accuracy'): 178 | """Create a multiplot of the sensitivities. 179 | The 'sensitivities' argument is expected to have the dict-of-dict structure 180 | described in the documentation of sensitivity_analysis. 181 | """ 182 | try: 183 | # sudo apt-get install python3-tk 184 | import matplotlib 185 | matplotlib.use('Agg') 186 | import matplotlib.pyplot as plt 187 | except ImportError: 188 | print("WARNING: Function plot_sensitivities requires package matplotlib, which " 189 | "is not installed in your execution environment.\n" 190 | "Skipping the PNG file generation") 191 | return 192 | fig = plt.figure() 193 | for param_name, sensitivity in sorted(sensitivities.items()): 194 | sense = [values[metric] for sparsity, values in sensitivity.items()] 195 | sparsities = [sparsity for sparsity, values in sensitivity.items()] 196 | 197 | plt.plot(sparsities, sense, label=param_name) 198 | plt.ylabel(metric) 199 | plt.xlabel('sparsity') 200 | plt.title('Pruning Sensitivity') 201 | plt.grid() 202 | plt.legend(loc='lower center', 203 | ncol=2, mode="expand", borderaxespad=0.) 204 | return fig 205 | 206 | 207 | def sensitivities_to_csv(sensitivities, fname): 208 | """Create a CSV file listing from the sensitivities dictionary. 209 | The 'sensitivities' argument is expected to have the dict-of-dict structure 210 | described in the documentation of sensitivity_analysis. 
211 | """ 212 | with open(fname, 'w') as csv_file: 213 | writer = csv.writer(csv_file) 214 | # write the header 215 | writer.writerow(['parameter', 'sparsity', 'loss', 'top1', 'top5']) 216 | for param_name, sensitivity in sensitivities.items(): 217 | for sparsity, values in sensitivity.items(): 218 | writer.writerow([param_name] + [sparsity] + list(values.values())) 219 | 220 | 221 | if __name__ == "__main__": 222 | args = argparse.ArgumentParser(description=__doc__, 223 | formatter_class=lambda prog: 224 | argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=52, width=90)) 225 | args.add_argument('-c', '--config', default=None, type=str, 226 | help='config file path (default: None)') 227 | args.add_argument('-r', '--resume', default=None, type=str, 228 | help='path to latest checkpoint (default: None)') 229 | args.add_argument('-d', '--device', default=None, type=str, 230 | help='indices of GPUs to enable (default: all)') 231 | 232 | # custom cli options to modify configuration from default values given in json file. 233 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 234 | options = [ 235 | CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'), 236 | CustomArgs(['--wd', '--weight_decay'], type=float, target='optimizer;args;weight_decay'), 237 | CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size') 238 | ] 239 | config = ConfigParser.from_args(args, options) 240 | main(config) 241 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import data as module_data 4 | import torch 5 | from pytorch_lightning import LightningModule, Trainer 6 | from torch import nn 7 | 8 | import models as module_arch 9 | from parse_config import ConfigParser 10 | from trainer.lit_model import LitModel 11 | from trainer.trainer import get_trainer 12 | import torch.nn.utils.prune as prune 13 | 14 | from utils import load_compressed_checkpoint 15 | 16 | _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict) 17 | 18 | 19 | def main(config): 20 | logger = config.get_logger() 21 | 22 | # Override settings 23 | config['data_loader']['args']['training'] = False 24 | config['data_loader']['args']['validation_split'] = 0.0 25 | config['data_loader']['args']['shuffle'] = False 26 | 27 | test_data_loader = config.init_obj('data_loader', module_data) 28 | print(len(test_data_loader)) 29 | 30 | model = LitModel(config, config.init_obj('arch', module_arch)) 31 | 32 | if config.resume: 33 | checkpoint = torch.load(config.resume) 34 | model = load_compressed_checkpoint(model, checkpoint) 35 | logger.info(model) 36 | 37 | trainer = Trainer(logger=None, accelerator="gpu", deterministic=True, enable_progress_bar=False, 38 | enable_model_summary=False, enable_checkpointing=False, devices=1, num_nodes=1) 39 | log = trainer.test(model, test_data_loader) 40 | print(log) 41 | 42 | 43 | if __name__ == '__main__': 44 | args = argparse.ArgumentParser(description='PyTorch Template') 45 | args.add_argument('-c', '--config', default=None, type=str, 46 | help='config file path (default: None)') 47 | args.add_argument('-r', '--resume', default=None, type=str, 48 | help='path to latest checkpoint (default: None)') 49 | args.add_argument('-d', '--device', default=None, type=str, 50 | help='indices of GPUs to enable (default: all)') 51 | 52 | config = ConfigParser.from_args(args) 
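test.py carries no usage docstring like train.py below; besides the CLI, the same evaluation can be driven programmatically through `ConfigParser.from_dict` (defined in parse_config.py above). A sketch with a hypothetical checkpoint path:

```python
from parse_config import ConfigParser
from test import main

# equivalent to: python test.py -c config.json -r <checkpoint>
config = ConfigParser.from_dict({
    'config': 'config.json',
    'resume': 'runs/LeNet300/0101_120000/checkpoints/last.ckpt',  # hypothetical path
})
main(config)
```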
53 | main(config) 54 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a model on a chosen dataset. 3 | The model and dataset are defined in a JSON file as 4 | "arch": { 5 | "type": "ModelClass", 6 | "args": { 7 | "model_args": value 8 | } 9 | } 10 | "data_loader": { 11 | "type": "DataLoaderClass", 12 | "args": { 13 | "data_dir": "data/", 14 | "batch_size": 128, 15 | "shuffle": true, 16 | "validation_split": 0.1, 17 | "num_workers": 6 18 | } 19 | } 20 | 21 | a configuration file example is config.json 22 | 23 | Usage: 24 | $ python train.py -c config.json --args 25 | Example: 26 | $ python train.py -c config.json --lr=1e-3 --wd=1e-2 --batch_size=128 27 | 28 | Models: 29 | https://codeberg.org/ciodar/model-compression/src/branch/master/models 30 | """ 31 | import argparse 32 | import collections 33 | 34 | import torch 35 | 36 | 37 | import data as module_data 38 | import models as module_arch 39 | from parse_config import ConfigParser 40 | from trainer.lit_model import LitModel 41 | from trainer.trainer import get_trainer 42 | from utils import set_all_seeds, load_compressed_checkpoint 43 | 44 | SEED = 42 45 | set_all_seeds(SEED) 46 | 47 | 48 | def main(config): 49 | logger = config.get_logger() 50 | 51 | data_loader = config.init_obj('data_loader', module_data) 52 | valid_data_loader = data_loader.split_validation() 53 | 54 | model = LitModel(config, config.init_obj('arch', module_arch)) 55 | 56 | logger.info("Start training") 57 | if config.resume: 58 | checkpoint = torch.load(config.resume) 59 | model = load_compressed_checkpoint(model,checkpoint) 60 | logger.info(model) 61 | 62 | trainer = get_trainer(config) 63 | trainer.fit(model, data_loader, valid_data_loader) 64 | 65 | 66 | if __name__ == "__main__": 67 | args = argparse.ArgumentParser(description=__doc__, 68 | formatter_class=lambda prog: 69 | argparse.ArgumentDefaultsHelpFormatter(prog, max_help_position=52, width=90)) 70 | args.add_argument('-c', '--config', default=None, type=str, 71 | help='config file path (default: None)') 72 | args.add_argument('-r', '--resume', default=None, type=str, 73 | help='path to latest checkpoint (default: None)') 74 | args.add_argument('-d', '--device', default=None, type=str, 75 | help='indices of GPUs to enable (default: all)') 76 | 77 | # custom cli options to modify configuration from default values given in json file. 
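Each `target` in the options list that follows is a ';'-separated keychain that `_set_by_path` in parse_config.py walks to overwrite a nested config value. A self-contained sketch of that override mechanism (the config contents are hypothetical):

```python
from functools import reduce
from operator import getitem

def set_by_path(tree, keys, value):
    # 'optimizer;args;lr' -> tree['optimizer']['args']['lr'] = value
    keys = keys.split(';')
    reduce(getitem, keys[:-1], tree)[keys[-1]] = value

config = {'optimizer': {'type': 'SGD', 'args': {'lr': 0.1}}}
set_by_path(config, 'optimizer;args;lr', 1e-3)  # effect of `--lr 1e-3`
assert config['optimizer']['args']['lr'] == 1e-3
```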
78 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 79 | options = [ 80 | CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'), 81 | CustomArgs(['--wd', '--weight_decay'], type=float, target='optimizer;args;weight_decay'), 82 | CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size') 83 | ] 84 | parser = ConfigParser.from_args(args, options) 85 | main(parser) 86 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciodar/deep-compression/674eb9fbf7eb028f0dd42b4c6914eb2b5cbc7df2/trainer/__init__.py -------------------------------------------------------------------------------- /trainer/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import * 2 | from trainer.callbacks.pruning import IterativePruning 3 | from trainer.callbacks.quantization import Quantization 4 | -------------------------------------------------------------------------------- /trainer/callbacks/pruning.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Optional, Callable, Dict 2 | 3 | import torch 4 | from pytorch_lightning import LightningModule 5 | from pytorch_lightning.callbacks import ModelPruning 6 | from pytorch_lightning.utilities.exceptions import MisconfigurationException 7 | from pytorch_lightning.utilities.rank_zero import rank_zero_debug 8 | 9 | import compression 10 | from compression.pruning import sparsity_stats 11 | from models.alexnet import LinearWithAdjustableDropout 12 | 13 | 14 | class IterativePruning(ModelPruning): 15 | LAYER_TYPES = ("Linear", "Conv2d") 16 | 17 | def __init__(self, pruning_fn: Union[Callable, str], pruning_schedule: Dict, 18 | amount: Union[int, float, List[int]] = None, 19 | filter_layers: Optional[List[str]] = None, use_global_unstructured: bool = True, 20 | huffman_encode: bool = False, prune_on_fit_start: bool = False, 21 | **kwargs): 22 | self._use_global_unstructured = use_global_unstructured 23 | # custom pruning function 24 | if isinstance(pruning_fn, str) and pruning_fn.lower() == "l1_threshold": 25 | pruning_fn = compression.ThresholdPruning 26 | 27 | super().__init__(amount=self._compute_amount, apply_pruning=self._check_epoch, pruning_fn=pruning_fn, 28 | use_global_unstructured=use_global_unstructured, 29 | **kwargs) 30 | 31 | self._pruning_schedule = pruning_schedule 32 | self._filter_layers = filter_layers or self.LAYER_TYPES 33 | self._filter_layers = tuple(getattr(torch.nn, c) for c in self._filter_layers) 34 | self._amount = amount 35 | self._huffman_encode = huffman_encode 36 | self._prune_on_fit_start = prune_on_fit_start 37 | 38 | if use_global_unstructured and isinstance(amount, list): 39 | raise MisconfigurationException( 40 | "`amount` should be either an int or a float when `use_global_unstructured`=True" 41 | ) 42 | 43 | def _check_epoch(self, epoch): 44 | if 'target_sparsity' in self._pruning_schedule: 45 | total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters()) 46 | stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] 47 | zeros = sum(zeros for zeros, _ in stats) 48 | if zeros / total_params > self._pruning_schedule['target_sparsity']: 49 | return False 50 | if 'epochs' in 
self._pruning_schedule: 51 | return epoch in self._pruning_schedule['epochs'] 52 | if 'start_epoch' in self._pruning_schedule and epoch >= self._pruning_schedule['start_epoch']: 53 | prune_every = self._pruning_schedule.get('prune_every', 1) 54 | return (epoch - self._pruning_schedule['start_epoch']) % prune_every == 0 55 | 56 | def _compute_amount(self, epoch): 57 | return self._amount 58 | 59 | def filter_parameters_to_prune(self, parameters_to_prune=()): 60 | # filter modules based on type (Linear or Conv2d) 61 | return list(filter(lambda p: isinstance(p[0], self._filter_layers), parameters_to_prune)) 62 | 63 | def _apply_local_pruning(self, amount: Union[int, float, List[float]]): 64 | for i, (module, name) in enumerate(self._parameters_to_prune): 65 | self.pruning_fn(module, name=name, amount=self._amount[i]) 66 | if isinstance(module, LinearWithAdjustableDropout): 67 | module.adjust_dropout_rate(name) 68 | 69 | def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 70 | if self._prune_on_fit_start: 71 | rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning") 72 | self._run_pruning(pl_module.current_epoch) 73 | 74 | if self._check_epoch(pl_module.current_epoch): 75 | tot_retained, tot_pruned = sparsity_stats(pl_module) 76 | if pl_module.logger: 77 | tensorboard = pl_module.logger.experiment 78 | tensorboard.add_scalar("sparsity", (tot_retained / (tot_pruned + tot_retained))) 79 | tensorboard.add_scalar("compression", ((tot_pruned + tot_retained) / tot_retained)) 80 | 81 | def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 82 | tot_retained, tot_pruned = sparsity_stats(pl_module) 83 | if pl_module.logger: 84 | tensorboard = pl_module.logger.experiment 85 | tensorboard.add_scalar("sparsity", (tot_retained / (tot_pruned + tot_retained))) 86 | tensorboard.add_scalar("compression", ((tot_pruned + tot_retained) / tot_retained)) 87 | 88 | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None: 89 | if self._prune_on_train_epoch_end: 90 | rank_zero_debug("`ModelPruning.on_train_epoch_end`. 
Applying pruning") 91 | self._run_pruning(pl_module.current_epoch) 92 | 93 | if self._check_epoch(pl_module.current_epoch): 94 | tot_retained, tot_pruned = sparsity_stats(pl_module) 95 | pl_module.log("sparsity", (tot_retained / (tot_pruned + tot_retained))) 96 | pl_module.log("compression", ((tot_pruned + tot_retained) / tot_retained)) 97 | -------------------------------------------------------------------------------- /trainer/callbacks/quantization.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | from typing import Union, Callable, List, Optional, Tuple, Sequence 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from lightning_utilities.core.rank_zero import rank_zero_debug 8 | from pytorch_lightning import LightningModule 9 | from pytorch_lightning.utilities.exceptions import MisconfigurationException 10 | from pytorch_lightning.callbacks import Callback 11 | 12 | import torch.nn as nn 13 | 14 | import compression.quantization 15 | from compression import quantization 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | _QUANTIZATION_FUNCTIONS = { 20 | "density_quantization": quantization.density_quantization, 21 | "forgy_quantization": quantization.forgy_quantization, 22 | "linear_quantization": quantization.linear_quantization 23 | } 24 | 25 | _QUANTIZATION_METHODS = { 26 | "density_quantization": quantization.DensityQuantizationMethod, 27 | "forgy_quantization": quantization.ForgyQuantizationMethod, 28 | "linear_quantization": quantization.LinearQuantizationMethod 29 | } 30 | 31 | _PARAM_TUPLE = Tuple[nn.Module, str] 32 | _PARAM_LIST = Sequence[_PARAM_TUPLE] 33 | _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict) 34 | 35 | 36 | class Quantization(Callback): 37 | PARAMETER_NAMES = ("weight", "bias") 38 | LAYER_TYPES = ("Linear", "Conv2d") 39 | 40 | def __init__(self, epoch, quantization_fn, parameters_to_quantize=None, parameter_names=None, 41 | bits: Union[int, List[int]] = None, filter_layers: Optional[List[str]] = None, 42 | huffman_encode: bool = False, apply_quantization: Union[bool, Callable[[int], bool]] = True 43 | , verbose: int = 0): 44 | super().__init__() 45 | 46 | self._parameters_to_quantize = parameters_to_quantize 47 | self._parameter_names = parameter_names or self.PARAMETER_NAMES 48 | self._global_kwargs = {} 49 | self._original_layers = None 50 | self._pruning_fn_name = None 51 | 52 | quantization_fn = self._create_quantization_fn(quantization_fn) 53 | 54 | self.quantization_fn = quantization_fn 55 | self._apply_quantization = apply_quantization 56 | self.bits = bits 57 | self._quantization_epoch = epoch 58 | self._quantize_on_train_epoch_end = False 59 | self._filter_layers = filter_layers or self.LAYER_TYPES 60 | self._filter_layers = tuple(getattr(torch.nn, c) for c in self._filter_layers) 61 | self._huffman_encode = huffman_encode 62 | 63 | if verbose not in (0, 1, 2): 64 | raise MisconfigurationException("`verbose` must be any of (0, 1, 2)") 65 | 66 | self._verbose = verbose 67 | 68 | def filter_parameters_to_quantize(self, parameters_to_quantize=()): 69 | return list(filter(lambda p: isinstance(p[0], self._filter_layers), parameters_to_quantize)) 70 | 71 | def _create_quantization_fn(self, quantization_fn: str, **kwargs) -> Union[ 72 | Callable, quantization.BaseQuantizationMethod]: 73 | 74 | quantization_fn = _QUANTIZATION_FUNCTIONS[quantization_fn] 75 | # save the function __name__ now because partial does not include it 
76 | # and there are issues setting the attribute manually in ddp. 77 | self._quantization_fn_name = quantization_fn.__name__ 78 | return Quantization._wrap_quantization_fn(quantization_fn, **kwargs) 79 | 80 | @staticmethod 81 | def _wrap_quantization_fn(pruning_fn, **kwargs): 82 | return partial(pruning_fn, **kwargs) 83 | 84 | def apply_quantization(self, bits: Union[int, float]) -> None: 85 | """Applies quantization to ``parameters_to_quantize``.""" 86 | for module, name in self._parameters_to_quantize: 87 | log.debug("Quantizing {} into {:d} bits...".format(module, bits)) 88 | self.quantization_fn(module, name=name, bits=bits) 89 | 90 | def setup(self, trainer: "pl.Trainer", pl_module: LightningModule, stage: str) -> None: 91 | parameters_to_quantize = self.sanitize_parameters_to_quantize( 92 | pl_module, self._parameters_to_quantize, parameter_names=self._parameter_names 93 | ) 94 | 95 | self._parameters_to_quantize = self.filter_parameters_to_quantize(parameters_to_quantize) 96 | 97 | def _run_quantization(self, current_epoch: int) -> None: 98 | self._apply_quantization = current_epoch == self._quantization_epoch 99 | if self._apply_quantization: 100 | self.apply_quantization(self.bits) 101 | 102 | def make_quantization_permanent(self, module: nn.Module) -> None: 103 | for _, module in module.named_modules(): 104 | for k in list(module._forward_pre_hooks): 105 | hook = module._forward_pre_hooks[k] 106 | if isinstance(hook, compression.quantization.BaseQuantizationMethod): 107 | hook.remove(module) 108 | del module._forward_pre_hooks[k] 109 | 110 | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None: 111 | if not trainer.sanity_checking and not self._quantize_on_train_epoch_end: 112 | rank_zero_debug("`Quantization.on_validation_epoch_end`. Applying quantization") 113 | self._run_quantization(pl_module.current_epoch) 114 | 115 | if self._apply_quantization: 116 | # TODO: move idx_bits to configuration 117 | compression = quantization.compression_stats(pl_module, idx_bits=4, 118 | huffman_encoding=self._huffman_encode) 119 | pl_module.log("compression", compression) 120 | 121 | # def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 122 | 123 | @staticmethod 124 | def sanitize_parameters_to_quantize( 125 | pl_module: LightningModule, 126 | parameters_to_quantize: Optional[_PARAM_LIST] = None, 127 | parameter_names: Optional[List[str]] = None, 128 | ) -> _PARAM_LIST: 129 | """ 130 | This function is responsible of sanitizing ``parameters_to_quantize`` and ``parameter_names``. 131 | If ``parameters_to_quantize is None``, it will be generated with all parameters of the model. 132 | Raises: 133 | MisconfigurationException: 134 | If ``parameters_to_quantize`` doesn't exist in the model, or 135 | if ``parameters_to_quantize`` is neither a list of tuple nor ``None``. 
136 | """ 137 | parameters = parameter_names or Quantization.PARAMETER_NAMES 138 | 139 | current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)] 140 | 141 | if parameters_to_quantize is None: 142 | parameters_to_quantize = [(m, p) for p in parameters for m in current_modules if hasattr(m, p)] 143 | elif ( 144 | isinstance(parameters_to_quantize, (list, tuple)) 145 | and len(parameters_to_quantize) > 0 146 | and all(len(p) == 2 for p in parameters_to_quantize) 147 | and all(isinstance(a, nn.Module) and isinstance(b, str) for a, b in parameters_to_quantize) 148 | ): 149 | missing_modules, missing_parameters = [], [] 150 | for module, name in parameters_to_quantize: 151 | if module not in current_modules: 152 | missing_modules.append(module) 153 | continue 154 | if not hasattr(module, name): 155 | missing_parameters.append(name) 156 | 157 | if missing_modules or missing_parameters: 158 | raise MisconfigurationException( 159 | "Some provided `parameters_to_tune` don't exist in the model." 160 | f" Found missing modules: {missing_modules} and missing parameters: {missing_parameters}" 161 | ) 162 | else: 163 | raise MisconfigurationException( 164 | "The provided `parameters_to_quantize` should either be list of tuple" 165 | " with 2 elements: (nn.Module, parameter_name_to_quantize) or None" 166 | ) 167 | return parameters_to_quantize 168 | -------------------------------------------------------------------------------- /trainer/lit_model.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as lit 2 | import trainer.metrics as module_metric 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | class LitModel(lit.LightningModule): 9 | def __init__(self, config, model): 10 | super().__init__() 11 | self.config = config 12 | self.model = model 13 | self.criterion = getattr(F, config['loss']) 14 | self.metric_ftns = [getattr(module_metric, met) for met in config['metrics']] 15 | 16 | 17 | def training_step(self, batch, batch_idx): 18 | data, target = batch 19 | output = self.model(data) 20 | loss = self.criterion(output, target) 21 | self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) 22 | for met in self.metric_ftns: 23 | self.log(met.__name__, met(output, target), on_step=False, on_epoch=True, prog_bar=True) 24 | return loss 25 | 26 | def validation_step(self, batch, batch_idx): 27 | x, y = batch 28 | logits = self.model(x) 29 | loss = self.criterion(logits, y) 30 | self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True) 31 | # compute metrics 32 | for met in self.metric_ftns: 33 | self.log('val_' + met.__name__, met(logits, y), on_step=False, on_epoch=True, prog_bar=True) 34 | return loss 35 | 36 | def test_step(self, batch, batch_idx): 37 | # this is the test loop 38 | x, y = batch 39 | logits = self.model(x) 40 | loss = self.criterion(logits, y) 41 | self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True) 42 | # compute metrics 43 | for met in self.metric_ftns: 44 | self.log('test_' + met.__name__, met(logits, y), on_step=False, on_epoch=True, prog_bar=True) 45 | return loss 46 | 47 | def on_validation_epoch_end(self): 48 | super().on_validation_epoch_end() 49 | # log model parameters 50 | if self.logger: 51 | tensorboard = self.logger.experiment 52 | for name, p in self.model.state_dict().items(): 53 | tensorboard.add_histogram(name, p, self.global_step) 54 | 55 | def configure_optimizers(self): 56 | trainable_params = 
filter(lambda p: p.requires_grad, self.parameters()) 57 | optimizer = self.config.init_obj('optimizer', torch.optim, trainable_params) 58 | if 'lr_scheduler' in self.config.config: 59 | lr_scheduler = self.config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer) 60 | scheduler_dict = {"scheduler": lr_scheduler, "interval": "epoch"} 61 | if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 62 | mnt = self.config['trainer']['monitor'].split()[1] 63 | scheduler_dict['monitor'] = mnt 64 | else: 65 | return optimizer 66 | return [optimizer], [scheduler_dict] 67 | -------------------------------------------------------------------------------- /trainer/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def accuracy(output, target): 4 | with torch.no_grad(): 5 | pred = torch.argmax(output, dim=1) 6 | assert pred.shape[0] == len(target) 7 | correct = 0 8 | correct += torch.sum(pred == target).item() 9 | return correct / len(target) 10 | 11 | def topk_accuracy(output, target, k=5): 12 | with torch.no_grad(): 13 | pred = torch.topk(output, k, dim=1)[1] 14 | assert pred.shape[0] == len(target) 15 | correct = 0 16 | for i in range(k): 17 | correct += torch.sum(pred[:, i] == target).item() 18 | return correct / len(target) -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from numpy import inf 4 | 5 | import pytorch_lightning as lit 6 | 7 | import trainer.callbacks as module_callback 8 | 9 | 10 | def get_trainer(config): 11 | cfg_trainer = config['trainer'] 12 | 13 | if torch.cuda.is_available(): 14 | accelerator, devices = "auto", config['n_gpu'] 15 | else: 16 | accelerator, devices = "auto", None 17 | 18 | min_epochs = cfg_trainer.get('min_epochs', 0) 19 | max_epochs = cfg_trainer.get('max_epochs', -1) 20 | if cfg_trainer.get('enable_checkpointing', True): 21 | default_root_dir = config.save_dir 22 | else: 23 | default_root_dir = None 24 | 25 | callbacks = [] 26 | if 'callbacks' in cfg_trainer: 27 | for cb, values in cfg_trainer['callbacks'].items(): 28 | if isinstance(values, list): 29 | for args in values: 30 | callback = getattr(module_callback, cb)(**args) 31 | callbacks.append(callback) 32 | else: 33 | callback = getattr(module_callback, cb)(**values) 34 | callbacks.append(callback) 35 | return lit.Trainer(min_epochs=min_epochs, max_epochs=max_epochs, callbacks=callbacks, accelerator="auto", 36 | devices=devices 37 | , default_root_dir=default_root_dir) 38 | 39 | 40 | class CompressionTrainer(lit.Trainer): 41 | def __init__(self, config, **kwargs): 42 | self.config = config 43 | super().__init__(**kwargs) 44 | 45 | # def compute_pruning_amount(self, epoch): 46 | # if epoch in config['trainer'] 47 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import operator 4 | import os 5 | import random 6 | from collections import OrderedDict 7 | from os import chdir 8 | from pathlib import Path 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from pytorch_lightning import LightningModule 15 | from torch.nn.utils import prune 16 | 17 | import compression 18 | 19 | 20 | def set_all_seeds(seed): 21 | 
os.environ["PL_GLOBAL_SEED"] = str(seed) 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | 27 | 28 | def set_deterministic(): 29 | if torch.cuda.is_available(): 30 | torch.backends.cudnn.benchmark = False 31 | torch.backends.cudnn.deterministic = True 32 | torch.use_deterministic_algorithms(True) 33 | 34 | 35 | # functions to show an image 36 | def imshow(img, one_channel=False): 37 | if one_channel: 38 | img = img.mean(dim=0) 39 | img = img / 2 + 0.5 # unnormalize 40 | npimg = img.numpy() 41 | if one_channel: 42 | plt.imshow(npimg, cmap="Greys") 43 | else: 44 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 45 | 46 | 47 | def plot_sparsity_matrix(model): 48 | # fig = plt.figure() 49 | for name, module in model.named_modules(): 50 | if isinstance(module, torch.nn.Linear): 51 | weights = module.weight.detach().cpu() 52 | plt.spy(weights, color='blue', markersize=1) 53 | plt.title(name) 54 | plt.show() 55 | # if isinstance(param, torch.nn.Conv2d): 56 | elif isinstance(module, torch.nn.Conv2d): 57 | weights = module.weight.detach().cpu() 58 | num_kernels = weights.shape[0] 59 | for k in range(num_kernels): 60 | kernel_weights = weights[k].sum(dim=0) 61 | tag = f"{name}/kernel_{k}" 62 | plt.spy(kernel_weights, color='blue', markersize=1) 63 | plt.title(tag) 64 | plt.show() 65 | 66 | # ax = fig.add_subplot(1, num_kernels, k + 1, xticks=[], yticks=[]) 67 | # ax.set_title("layer {0}/kernel_{1}".format(name, k)) 68 | # return fig 69 | 70 | def weight_histograms_conv2d(writer, step, weights, name): 71 | weights_shape = weights.shape 72 | num_kernels = weights_shape[0] 73 | for k in range(num_kernels): 74 | flattened_weights = weights[k].flatten() 75 | tag = f"{name}/kernel_{k}" 76 | if (flattened_weights != 0).any().item(): 77 | writer.add_histogram(tag, flattened_weights[flattened_weights != 0], global_step=step, bins='tensorflow') 78 | 79 | 80 | def weight_histograms_linear(writer, step, weights, name): 81 | flattened_weights = weights.flatten() 82 | tag = name 83 | writer.add_histogram(tag, flattened_weights[flattened_weights != 0], global_step=step, bins='tensorflow') 84 | # print('layer %s | std: %.3f | sparsity: %.3f%%' % ( 85 | # name, torch.std(flattened_weights), (flattened_weights == 0.).sum() / len(flattened_weights) * 100)) 86 | 87 | 88 | def weight_histograms(writer, step, model): 89 | # print("Visualizing model weights...") 90 | # Iterate over all model layers 91 | for name, module in model.named_modules(): 92 | # Compute weight histograms for appropriate layer 93 | if isinstance(module, nn.Conv2d): 94 | weights = module.weight 95 | weight_histograms_conv2d(writer, step, weights, name) 96 | elif isinstance(module, nn.Linear): 97 | weights = module.weight 98 | weight_histograms_linear(writer, step, weights, name) 99 | 100 | 101 | def plot_weight_histograms(model): 102 | for name, module in model.named_modules(): 103 | if isinstance(module, nn.Linear): 104 | weight = module.weight.data.cpu() 105 | plt.hist(weight[weight != 0], bins=30, density=True) 106 | plt.title('layer: %s' % name) 107 | plt.show() 108 | elif isinstance(module, nn.Conv2d): 109 | weight = module.weight.data.cpu() 110 | for k in range(weight.shape[0]): 111 | flattened_weights = weight[k].flatten() 112 | tag = "layer: %s/kernel_%d" % (name, k) 113 | plt.hist(flattened_weights[flattened_weights != 0], bins=30, density=True) 114 | plt.title(tag) 115 | plt.show() 116 | 117 | 118 | # def save_compressed_weights(model, save_path): 119 | # weight_dict = 
OrderedDict() 120 | # for name,module in model.named_modules(): 121 | # if prune.is_pruned(module) and not isinstance(module, type(model)): 122 | # weight_mask = getattr(module,'weight_mask') 123 | # if quantize.is_quantized(module): 124 | # indices = getattr(module,'weight_indices') 125 | # weight_mask[weight_mask==1] = indices 126 | # else: 127 | # 128 | # sparse_weight = sparse.csr_matrix(weight) if weight.shape[0] < weight.shape[1] else sparse.csc_matrix( 129 | # weight) 130 | # tensor = model.state_dict()[param_tensor] 131 | # if prune.is_pruned(tensor): 132 | # 133 | # bias = module.bias.data.cpu().numpy() 134 | # 135 | # weight_dict['%s.weight' % name] = sparse_weight 136 | # weight_dict['%s.bias' % name] = bias 137 | # torch.save(weight_dict, save_path) 138 | 139 | def read_json(fname): 140 | fname = Path(fname) 141 | with fname.open('rt') as handle: 142 | return json.load(handle, object_hook=OrderedDict) 143 | 144 | 145 | def write_json(content, fname): 146 | fname = Path(fname) 147 | with fname.open('wt') as handle: 148 | json.dump(content, handle, indent=4, sort_keys=False) 149 | 150 | 151 | def make_paths_relative_to_root(): 152 | """Always use the same, absolute (relative to root) paths 153 | which makes moving the notebooks around easier. 154 | """ 155 | top_level = Path(__file__).parent 156 | chdir(top_level) 157 | 158 | def load_compressed_checkpoint(pl_model: LightningModule, checkpoint): 159 | state_dict = checkpoint['state_dict'] 160 | pruned_parameters = [(m, mask) for m, mask in state_dict.items() if '_mask' in m] 161 | quantized_parameters = [(m, int(math.log2(t.numel()))) for m, t in state_dict.items() if ('_centers' in m)] 162 | 163 | parameters_to_prune, parameters_to_quantize = [], [] 164 | 165 | for pp, mask in pruned_parameters: 166 | module_name, param = '.'.join(pp.split('.')[0:-1]), pp.split('.')[-1] 167 | name = param.split('_mask')[0] 168 | module = operator.attrgetter(module_name)(pl_model) 169 | parameters_to_prune.append((module, param)) 170 | # Lightning cant handle this device mismatch 171 | prune.custom_from_mask(module, name, mask.to(pl_model.device)) 172 | 173 | for qp, bits in quantized_parameters: 174 | module_name, param = '.'.join(qp.split('.')[0:-1]), qp.split('.')[-1] 175 | param = param.split('_centers')[0] 176 | module = operator.attrgetter(module_name)(pl_model) 177 | parameters_to_quantize.append((module, param)) 178 | # pick any quantization type. 179 | # TODO: define identity quantization(e.g all zeros) 180 | compression.linear_quantization(module, param, bits) 181 | 182 | # finally try to load checkpoint 183 | pl_model.load_state_dict(checkpoint['state_dict']) 184 | return pl_model 185 | --------------------------------------------------------------------------------
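For reference, `load_compressed_checkpoint` above recognizes compressed parameters purely by naming convention in the checkpoint's `state_dict`: `*_mask` tensors left behind by `torch.nn.utils.prune`, and `*_centers` tensors left by the weight-sharing quantization, whose length encodes the bit width. A minimal sketch of that convention with a hypothetical entry:

```python
import math
import torch

# Hypothetical LeNet-300-100 entry of a compressed checkpoint:
#   <module>.weight_mask    - binary mask left by pruning
#   <module>.weight_centers - 2**bits shared centroids left by quantization
state_dict = {
    'model.fc1.weight_mask': torch.ones(300, 784),
    'model.fc1.weight_centers': torch.randn(64),
}
pruned = [k for k in state_dict if k.endswith('_mask')]
quantized = [(k, int(math.log2(v.numel())))  # 64 centers -> 6-bit indices
             for k, v in state_dict.items() if k.endswith('_centers')]
print(pruned)     # ['model.fc1.weight_mask']
print(quantized)  # [('model.fc1.weight_centers', 6)]
```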