├── .gitignore
├── CHANGELOG.md
├── How_to_add_your_own_algorithms.md
├── LICENSE
├── README.md
├── imagenet
│   ├── README.md
│   ├── baseline
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── experiments
│   │   │   ├── resnet_experiments.yaml
│   │   │   └── wrnet_experiments.yaml
│   │   ├── learning_schedules
│   │   │   ├── resnet_prune_schedule.yaml
│   │   │   ├── resnet_schedule.yaml
│   │   │   ├── resnet_schedule_long.yaml
│   │   │   ├── wrnet_prune_schedule.yaml
│   │   │   ├── wrnet_schedule.yaml
│   │   │   └── wrnet_schedule_long.yaml
│   │   ├── main.py
│   │   ├── model.py
│   │   ├── parameterized_tensors.py
│   │   ├── reparameterized_layers.py
│   │   ├── requirements.txt
│   │   └── run_imagenet.sh
│   └── tuned_resnet
│       ├── LICENSE
│       ├── README.md
│       ├── examples
│       │   ├── RN50_FP16_1GPU.sh
│       │   ├── RN50_FP16_4GPU.sh
│       │   ├── RN50_FP16_8GPU.sh
│       │   ├── RN50_FP16_EVAL.sh
│       │   ├── RN50_FP16_INFERENCE_BENCHMARK.sh
│       │   ├── RN50_FP32_1GPU.sh
│       │   ├── RN50_FP32_4GPU.sh
│       │   ├── RN50_FP32_8GPU.sh
│       │   ├── RN50_FP32_EVAL.sh
│       │   └── RN50_FP32_INFERENCE_BENCHMARK.sh
│       ├── img
│       │   ├── training_accuracy.png
│       │   ├── training_loss.png
│       │   └── validation_accuracy.png
│       ├── main.py
│       ├── multiproc.py
│       ├── resnet.py
│       ├── scripts
│       │   └── extract_summary.py
│       ├── smoothing.py
│       └── sparse_momentum_logs.tar.gz
├── mnist_cifar
│   ├── extensions.py
│   ├── get_results_from_logs.py
│   ├── main.py
│   └── plot_feature_histograms.py
├── plot_graphs.py
├── requirements.txt
├── results
│   ├── MNIST_compression_comparison_lenet300-100.csv
│   ├── MNIST_compression_comparison_lenet5.csv
│   ├── MNIST_sparse_summary.csv
│   ├── WRN-28-2_results_summary.csv
│   ├── dynamic_sparse
│   │   ├── calc_confidence_intervals.py
│   │   └── results.zip
│   ├── feature_data.tar.gz
│   ├── graphs.py
│   ├── imagenet_0.15626087.log
│   ├── imagenet_0.25626087.log
│   ├── sensitivity_momentum_alexnet-s.csv
│   ├── sensitivity_momentum_vgg-d.csv
│   ├── sensitivity_prune_rate_alexnet-s.csv
│   └── sensitivity_prune_rate_vgg-d.csv
├── setup.py
└── sparselearning
    ├── __init__.py
    ├── core.py
    ├── funcs.py
    ├── models.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | transformer-xl/data
107 | transformer-xl/pytorch/LM-TFM*
108 | 
109 | # vim
110 | *.swp
111 | 
112 | # pickle
113 | *.pt
114 | 
115 | # data
116 | *.data
117 | 
118 | *.tar
119 | conv_transformer/data
120 | dynamic_sparse/data
121 | dynamic_sparse/runs/*
122 | mnist_cifar/_dataset*
123 | data
124 | imagenet/partially_dense/runs*
125 | 
126 | 
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | 
2 | ## Release v0.1: Ease of use, bug fixes, and documentation.
3 | ### Bug fixes:
4 | - Fixed a bug where magnitude pruning pruned too many parameters when the weight was dense (>95% density) and the pruning rate was small (<5%).
5 | First experiments on LeNet-5 Caffe indicate that this change did not affect performance for networks that learn to have dense weights.
6 | I will replicate this across architectures to make sure this bugfix does not change performance.
7 | - Fixed instabilities in SET (sparse evolutionary training) pruning which could cause NaN values in specific circumstances.
8 | 
9 | ### Documentation:
10 | - Added basic docstring documentation.
11 | 
12 | ### Features:
13 | - MNIST/CIFAR: Separate log files are now created for different models/densities/names.
14 | - MNIST/CIFAR: Aggregate mean test accuracy with standard errors can now be automatically extracted from logs with `python get_results_from_logs.py`.
15 | 
16 | ### API:
17 | - Changed names from "death" to "prune" to be more consistent with the terminology in the paper.
18 | - Added a --verbose argument to print the parameter distribution before/after pruning at the end of each epoch. By default, the pruning distribution will no longer be printed.
19 | - Removed the --sparse flag and added a --dense flag. The default is args.dense==False, and thus sparse mode is enabled by default. To run a dense model, just pass the --dense argument.
20 | 
21 | 
22 | ## Release v0.2: FP16 support, modularity of prune/growth/redistribution algorithms.
23 | 
24 | ### Bug fixes:
25 | - Fixed a bug where global pruning would throw an error if a layer was fully dense and had a low prune rate.
26 | 
27 | ### Features:
28 | - Added FP16 support.
Any model can now be run in 16-bit by passing the [apex](https://github.com/NVIDIA/apex) `FP16_Optimizer` into the `Masking` class and replacing `loss.backward()` with `optimizer.backward(loss)`.
29 | - Added an adapted [Dynamic Sparse Reparameterization](https://arxiv.org/abs/1902.05967) [codebase](https://github.com/IntelAI/dynamic-reparameterization) that works with sparse momentum.
30 | - Added a modular architecture for growth/prune/redistribution algorithms which is decoupled from the main library. This enables you to write your own prune/growth/redistribution algorithms without touching the library internals. A tutorial on how to add your own functions was also added: [How to Add Your Own Algorithms](How_to_add_your_own_algorithms.md).
31 | 
32 | 
33 | 
34 | ## Release v1.0: Bug fixes, new ImageNet baselines.
35 | 
36 | ### Bug fixes:
37 | - Changed to boolean indexing for PyTorch 1.2 compatibility.
38 | - Fixed a bug where an error could occur in global pruning algorithms if very few weights were removed from a layer.
39 | - Removed momentum reset. This feature did not have any effect on performance and made the algorithm more complex.
40 | - Fixed an error where two layers of VGG16 were removed by the `remove_weight_partial_name()` function. Results were slightly degraded, but the weights needed for dense performance and the relative ordering compared to other methods remained the same.
41 | 
42 | ### Features:
43 | - The evaluation script can now aggregate log files organized in a folder hierarchy; results are aggregated per folder.
44 | - Added a decay schedule argument. One can choose between Linear and Cosine prune rate decay schedules.
45 | - Added a new ImageNet baseline which is based on the codebase of [Mostafa & Wang, 2019](https://arxiv.org/abs/1902.05967).
46 | - Added a max-thread argument which can be used to set the maximum total number of data loader threads for the training, validation, and test set data loaders.
47 | 
--------------------------------------------------------------------------------
/How_to_add_your_own_algorithms.md:
--------------------------------------------------------------------------------
1 | # How to Add Your Own Redistribution/Pruning/Growth Algorithms
2 | 
3 | This is a tutorial on how to add your own redistribution, pruning, and growth algorithms. The sparselearning library is built to be easily extendable in this regard. The basic steps are: (1) implement your own function, and (2) pass it as an argument to the `Masking` class.
4 | 
5 | ## General Structure of Functions
6 | 
7 | Here is the general structure of the three functions:
8 | ```python
9 | def your_redistribution(masking, name, weight, mask): return layer_importance
10 | def your_growth(masking, name, new_mask, total_regrowth, weight): return new_mask
11 | def your_pruning(masking, mask, weight, name): return pruned_mask
12 | ```
13 | 
14 | The variable `masking` is the general `Masking` class, which gives access to global and per-layer statistics that can be useful for constructing your own algorithms. `name` is the name of the layer currently being processed. `weight` and `mask` are the weight of that layer and the binary mask that indicates the sparsity pattern. In the sparselearning library, all `0` elements in `mask` correspond to `0.0` values in `weight`.
15 | 
16 | ## Accessible Variables
17 | 
18 | When you write redistribution, growth, and pruning algorithms you have access to the `Masking` class and the `name` of the current layer.
This section gives you more details on what you can access and how.
19 | 
20 | ### Access to the Optimizer
21 | 
22 | You can access the optimizer through the `masking.optimizer` variable. We can use this, for example, to get access to the momentum variables of the optimizer. This is how you can implement the momentum redistribution used in the paper:
23 | ```python
24 | def your_redistribution(masking, name, weight, mask):
25 |     momentum = masking.optimizer.state[weight]['momentum_buffer']
26 |     return momentum[mask.byte()].sum().item()
27 | 
28 | ```
29 | 
30 | Other useful quantities:
31 | ```python
32 | # Running Adam sum of squares (equivalent to RMSProp).
33 | # Can be used to calculate the variance of gradient updates for each weight.
34 | adam_sumsq = masking.optimizer.state[weight]['exp_avg_sq']
35 | ```
36 | 
37 | ### Access to Global and Layer Statistics
38 | 
39 | You can access statistics, such as the number of non-zero weights of the current layer, via the `masking` variable and the `name` of the layer. The following statistics are accessible:
40 | 
41 | ```python
42 | 
43 | # Layer statistics:
44 | non_zero_count = masking.name2nonzeros[name]
45 | zero_count = masking.name2zeros[name]
46 | normalized_layer_importance = masking.name2variance[name]
47 | number_of_pruned_weights = masking.name2removed[name]
48 | # Global network statistics:
49 | masking.total_nonzero
50 | masking.total_zero
51 | masking.total_removed
52 | ```
53 | 
54 | ## Example: Variance-based Redistribution and Pruning
55 | 
56 | ### Intuition
57 | 
58 | Here I added two example extensions for redistribution and pruning. Both examples look at the variance of the gradient. If we look at weights with high and low variance in their gradients over time, we can arrive at the following interpretations.
59 | 
60 | For high-variance weights, we can take two perspectives. The first assumes that weights with high variance lack the capacity to model the interactions in the inputs needed to classify the outputs. For example, a weight might struggle to be useful for both the digit 0 and the digit 7 when classifying MNIST and thus has high variance between these examples. If we add capacity to high-variance layers, the variance should reduce over time, meaning the new weights can now fully model the different classes (one weight for 7, one weight for 0). According to this perspective, we want to add more parameters to layers with high average variance. In other words, we want to redistribute pruned parameters to layers with high gradient variance.
61 | 
62 | The second perspective is a "potential to be useful" perspective. Here we see weights with high variance as having "potential to do the right classification, but they might just not have found the right decision boundary between classes yet". For example, a weight might have problems being useful for both the digit 7 and the digit 0, but over time it can find a feature which is useful for both classes. Thus gradient variance should reduce over time as features become more stable. If we take this perspective, then it is important to keep some medium-to-high variance weights. Low-variance weights have "settled in" and follow the gradient for a specific set of classes. These weights will not change much anymore, while high-variance weights might change a lot. So high-variance weights might have "potential", while the potential of low-variance weights is easily assessed by looking at their magnitude.
Thus, we might improve pruning if we look at both the variance of the gradient _and_ the magnitude of weights. You can find these examples in [mnist_cifar/extensions.py](mnist_cifar/extensions.py).
63 | 
64 | ### Implementation
65 | 
66 | ```python
67 | def variance_redistribution(masking, name, weight, mask):
68 |     '''Return the mean variance of existing weights.
69 | 
70 |     Intuition: Higher gradient variance means a layer does not have enough
71 |     capacity to model the inputs with the current number of weights.
72 |     Thus we want to add more weights if we have higher variance.
73 |     If the variance of the gradient stabilizes, this means
74 |     that some weights might be useless/not needed.
75 |     '''
76 |     # Adam calculates the running average of the sum of squares for us.
77 |     # This is similar to RMSProp.
78 |     if 'exp_avg_sq' not in masking.optimizer.state[weight]:
79 |         print('Variance redistribution requires the adam optimizer to be run!')
80 |         raise Exception('Variance redistribution requires the adam optimizer to be run!')
81 |     iv_adam_sumsq = torch.sqrt(masking.optimizer.state[weight]['exp_avg_sq'])
82 | 
83 |     layer_importance = iv_adam_sumsq[mask.byte()].mean().item()
84 |     return layer_importance
85 | 
86 | 
87 | def magnitude_variance_pruning(masking, mask, weight, name):
88 |     '''Prunes weights which have high gradient variance and low magnitude.
89 | 
90 |     Intuition: Weights that are large are important, but there is also a dimension
91 |     of reliability. If a large weight makes a large correct prediction 8/10 times,
92 |     is it better than a medium weight which makes a correct prediction 10/10 times?
93 |     To test this, we combine magnitude (importance) with reliability (variance of
94 |     the gradient).
95 | 
96 |     Good:
97 |     Weights with large magnitude and low gradient variance are the most important.
98 |     Weights with medium variance/magnitude are promising for improving network performance.
99 |     Bad:
100 |     Weights with large magnitude but high gradient variance hurt performance.
101 |     Weights with small magnitude and low gradient variance are useless.
102 |     Weights with small magnitude and high gradient variance cannot learn anything useful.
103 | 
104 |     We take the geometric mean of both normalized distributions to find weights to prune.
105 |     '''
106 |     # Adam calculates the running average of the sum of squares for us.
107 |     # This is similar to RMSProp. We take the inverse of this to rank
108 |     # low variance gradients higher.
109 |     if 'exp_avg_sq' not in masking.optimizer.state[weight]:
110 |         print('Magnitude variance pruning requires the adam optimizer to be run!')
111 |         raise Exception('Magnitude variance pruning requires the adam optimizer to be run!')
112 |     iv_adam_sumsq = 1./torch.sqrt(masking.optimizer.state[weight]['exp_avg_sq'])
113 | 
114 |     num_remove = math.ceil(masking.name2prune_rate[name]*masking.name2nonzeros[name])
115 | 
116 |     num_zeros = masking.name2zeros[name]
117 |     k = math.ceil(num_zeros + num_remove)
118 |     if num_remove == 0.0: return weight.data != 0.0
119 | 
120 |     max_var = iv_adam_sumsq[mask.byte()].max().item()
121 |     max_magnitude = torch.abs(weight.data[mask.byte()]).max().item()
122 |     product = ((iv_adam_sumsq/max_var)*torch.abs(weight.data)/max_magnitude)*mask
123 |     product[mask==0] = 0.0
124 | 
125 |     x, idx = torch.sort(product.view(-1))
126 |     mask.data.view(-1)[idx[:k]] = 0.0
127 |     return mask
128 | ```
129 | 
130 | ### Adding our extension to MNIST
131 | To add our new methods to the MNIST script, we simply import the newly created functions and map argument strings to them, so that our redistribution/pruning methods can be enabled by passing a specific argument to the script:
132 | ```python
133 | from extensions import magnitude_variance_pruning, variance_redistribution
134 | if args.prune == 'magnitude_variance': args.prune = magnitude_variance_pruning
135 | if args.redistribution == 'variance': args.redistribution = variance_redistribution
136 | ```
137 | 
138 | With this in place, we can run our new pruning and redistribution methods by calling the script. However, our pruning method requires the Adam optimizer, so we need to change the optimizer and the learning rate as well:
139 | ```bash
140 | python main.py --model lenet5 --optimizer adam --prune magnitude_variance --redistribution variance --verbose --lr 0.001
141 | ```
142 | 
143 | After running 10 additional iterations (add `--iters 10`) of our new method with 5% weights on MNIST with Caffe LeNet-5, we can quickly calculate the performance using the evaluation script:
144 | ```bash
145 | python get_results_from_logs.py
146 | 
147 | Accuracy. Median: 0.99300, Mean: 0.99300, Standard Error: 0.00019, Sample size: 11, 95% CI: (0.99262,0.99338)
148 | Error. Median: 0.00700, Mean: 0.00700, Standard Error: 0.00019, Sample size: 11, 95% CI: (0.00662,0.00738)
149 | Loss. Median: 0.02200, Mean: 0.02175, Standard Error: 0.00027, Sample size: 11, 95% CI: (0.02122,0.02228)
150 | 
151 | ```
152 | 
153 | Sparse momentum achieves an error of 0.0069 for this setting, and the upper end of its 95% confidence interval is 0.00739. For this setting our results thus overlap with the confidence interval of sparse momentum, meaning our new variance method is _as good_ as sparse momentum for this particular problem (Caffe LeNet-5 with 5% weights on MNIST).
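
Finally, here is a minimal sketch of how the pieces fit together if you construct the `Masking` instance yourself instead of going through the argument plumbing in `mnist_cifar/main.py`. The keyword names (`prune_rate`, `prune_rate_decay`, `prune_mode`, `growth_mode`, `redistribution_mode`) mirror the command-line flags above, and `CosineDecay`, `add_module()`, and `mask.step()` are assumed from `sparselearning/core.py`; check the constructor signature of your library version before relying on them:

```python
import torch.nn as nn
import torch.optim as optim

from sparselearning.core import Masking, CosineDecay
from extensions import magnitude_variance_pruning, variance_redistribution

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 300), nn.ReLU(), nn.Linear(300, 10))
# Our extensions read Adam state ('exp_avg_sq'), so the optimizer must be Adam.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Decay the prune rate over training; 0.3 is the initial prune rate and
# T_max the number of optimizer steps (assumed signature).
decay = CosineDecay(0.3, T_max=50000)

mask = Masking(optimizer,
               prune_rate=0.3,
               prune_rate_decay=decay,
               prune_mode=magnitude_variance_pruning,        # our custom pruning function
               growth_mode='momentum',                       # built-in growth
               redistribution_mode=variance_redistribution)  # our custom redistribution
mask.add_module(model, density=0.05)  # 5% weights

# In the training loop, step the mask instead of the optimizer so that the
# sparsity bookkeeping runs after each weight update:
#     loss.backward()
#     mask.step()
```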
154 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Tim Dettmers
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sparse Learning Library and Sparse Momentum Resources
2 | 
3 | This repo contains a sparse learning library which allows you to wrap any PyTorch neural network with a sparse mask to emulate the training of sparse neural networks. It also contains the code to replicate our work [Sparse Networks from Scratch: Faster Training without Losing Performance](https://arxiv.org/abs/1907.04840).
4 | 
5 | ## Requirements
6 | 
7 | The library requires PyTorch v1.2. You can download it via anaconda or pip; see [PyTorch/get-started](https://pytorch.org/get-started/locally/) for further information. For CUDA versions < 9.2 you need to either compile from source, or install a new CUDA version along with a compatible video driver.
8 | 
9 | ## Installation
10 | 
11 | 1. Install [PyTorch](https://pytorch.org/get-started/locally/).
12 | 2. Install other dependencies: `pip install -r requirements.txt`
13 | 3. Install the sparse learning library: `python setup.py install`
14 | 
15 | ## Basic Usage
16 | 
17 | ### MNIST & CIFAR-10 models
18 | 
19 | MNIST and CIFAR-10 code can be found in the `mnist_cifar` subfolder. You can run `python main.py --data DATASET_NAME --model MODEL_NAME` to run a model on MNIST (`--data mnist`) or CIFAR-10 (`--data cifar`).
20 | 
21 | The following models can be specified with the `--model` argument out-of-the-box:
22 | ```
23 | MNIST:
24 | 
25 | lenet5
26 | lenet300-100
27 | 
28 | CIFAR-10:
29 | 
30 | alexnet-s
31 | alexnet-b
32 | vgg-c
33 | vgg-d
34 | vgg-like
35 | wrn-28-2
36 | wrn-22-8
37 | wrn-16-8
38 | wrn-16-10
39 | ```
40 | 
41 | Beyond standard parameters like batch size and learning rate, whose usage can be seen via `python main.py --help`, the following sparse learning-specific parameters are available:
42 | ```
43 | --save-features       Resumes a saved model and saves its feature data to
44 |                       disk for plotting.
45 | --bench               Enables the benchmarking of layers and estimates
46 |                       sparse speedups.
47 | --growth GROWTH       Growth mode. Choose from: momentum, random, and
48 |                       momentum_neuron.
49 | --death DEATH         Death mode / pruning mode. Choose from: magnitude,
50 |                       SET, threshold.
51 | --redistribution REDISTRIBUTION
52 |                       Redistribution mode. Choose from: momentum, magnitude,
53 |                       nonzeros, or none.
54 | --death-rate DEATH_RATE
55 |                       The pruning rate / death rate.
56 | --density DENSITY     The density of the overall sparse network.
57 | --sparse              Enable sparse mode. Default: True.
58 | 
59 | ```
60 | 
61 | ### Running an ImageNet Model
62 | 
63 | To run ImageNet with 16-bit precision you need to install [Apex](https://github.com/NVIDIA/apex). Installing apex from pip currently does not work for me, but installing it from the repo works just fine.
64 | 
65 | The ImageNet code for sparse momentum can be found in the sub-folder `imagenet`, which contains two different ResNet-50 ImageNet models: a baseline used by Mostafa & Wang (2019), which reaches 74.9% accuracy with 100% weights, and a tuned ResNet-50 version that is identical to the baseline but uses a warmup learning rate and label smoothing, reaching 77.0% accuracy with 100% weights. The tuned version builds on [NVIDIA Deep Learning Examples: RN50v1.5](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5), while the baseline builds on [Intel/dynamic-reparameterization](https://github.com/IntelAI/dynamic-reparameterization).
66 | 
67 | ### Running Your Own Model
68 | 
69 | With the sparse learning library it is easy to run sparse momentum on your own model. All you need to do is follow this code template:
70 | 
71 | ![alt text][template]
72 | 
73 | 
74 | ## Extending the Library
75 | 
76 | It is easy to extend the library with your own functions for growth, redistribution, and pruning. See [The Extension Tutorial](https://github.com/TimDettmers/sparse_learning/blob/master/How_to_add_your_own_algorithms.md) for more information about how you can add your own functions.
77 | 
78 | [template]: https://timdettmers.com/wp-content/uploads/2019/07/code.png "Generic example usage of sparse learning library."
--------------------------------------------------------------------------------
/imagenet/README.md:
--------------------------------------------------------------------------------
1 | ## Models for ImageNet
2 | 
3 | For ImageNet there are two different ResNet-50 models. The "baseline" model replicates the results of [Mostafa & Wang, 2019](https://arxiv.org/abs/1902.05967) and is an adapted version of [IntelAI/dynamic-reparameterization](https://github.com/IntelAI/dynamic-reparameterization). This ImageNet model attains 74.9% accuracy as a dense baseline. This model can be found in the `baseline` folder and is the main model used throughout the paper for comparison against other models on ImageNet.
4 | 
5 | The second version is adapted from [NVIDIA/DeepLearningExamples](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5). It is the same model but better tuned than the baseline above. This model uses warmup learning rates and label smoothing and has a baseline accuracy of 77.0%, which is in line with recent tuned ResNet-50 baselines (see [Xie et al., 2019](https://arxiv.org/abs/1904.01569)). This codebase can be found in the `tuned_resnet` folder.
6 | 7 | -------------------------------------------------------------------------------- /imagenet/baseline/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /imagenet/baseline/README.md: -------------------------------------------------------------------------------- 1 | # Dynamic parameter reallocation in deep CNNs 2 | 3 | The code implements the experiments in the ICML 2019 submission: Parameter efficient training of deep convolutional neural networks by dynamic sparse reparameterization 4 | 5 | 6 | ## Instructions 7 | This code implements the dynamic parameterization scheme in the ICML 2019 submission: Parameter efficient training of deep convolutional neural networks by dynamic sparse reparameterization. It also implements previous dynamic parameterization schemes such as the DeepR algorithm by [Bellec at al. 2018](https://arxiv.org/abs/1711.05136) and the SET algorithm by [Mocanu et al. 2018](https://www.nature.com/articles/s41467-018-04316-3) as well as static parameterizations based on tied parameters similar to [the HashedNet paper](https://arxiv.org/abs/1504.04788). It also implements iterative pruning where it can take a dense model and prune it down to the required sparsity. 8 | 9 | The main python executable is `main.py`. Results are saved under a `./runs/` directory created at the invocation directory. An invocation of `main.py` will save various accuracy metrics as well as the model parameters in the file `./runs/{model name}_{job idx}`. Accuracy figures as well as several diagnostics are also printed out. 
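
For example, the following invocation (the first entry in `experiments/resnet_experiments.yaml`) trains the dense ResNet-50 baseline; its accuracy metrics and parameters are saved under `./runs/imagenet_resnet50_0`:

```shell
python main.py --model imagenet_resnet50 \
    --schedule-file ./learning_schedules/resnet_schedule.yaml \
    --initial-sparsity-fc 0.0 --initial-sparsity-conv 0.0 \
    --batch-size 256 --widen-factor 1 --weight-decay 1.0e-4 \
    --no-validate-train --sub-kernel-granularity 4 --job-idx 0
```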
10 | 11 | ### General usage 12 | ```shell 13 | main.py [-h] [--epochs EPOCHS] [--start-epoch START_EPOCH] 14 | [--model {mnist_mlp,cifar10_WideResNet,imagenet_resnet50}] 15 | [-b BATCH_SIZE] [--lr LR] [--momentum MOMENTUM] 16 | [--nesterov NESTEROV] [--weight-decay WEIGHT_DECAY] 17 | [--L1-loss-coeff L1_LOSS_COEFF] [--print-freq PRINT_FREQ] 18 | [--layers LAYERS] 19 | [--start-pruning-after-epoch START_PRUNING_AFTER_EPOCH] 20 | [--prune-epoch-frequency PRUNE_EPOCH_FREQUENCY] 21 | [--prune-target-sparsity-fc PRUNE_TARGET_SPARSITY_FC] 22 | [--prune-target-sparsity-conv PRUNE_TARGET_SPARSITY_CONV] 23 | [--prune-iterations PRUNE_ITERATIONS] 24 | [--post-prune-epochs POST_PRUNE_EPOCHS] 25 | [--n-prune-params N_PRUNE_PARAMS] [--threshold-prune] [--prune] 26 | [--validate-set] [--rewire-scaling] [--tied] 27 | [--rescale-tied-gradient] [--rewire] [--no-validate-train] 28 | [--DeepR] [--DeepR_eta DEEPR_ETA] 29 | [--stop-rewire-epoch STOP_REWIRE_EPOCH] [--no-batch-norm] 30 | [--rewire-fraction REWIRE_FRACTION] 31 | [--sub-kernel-granularity SUB_KERNEL_GRANULARITY] 32 | [--cubic-prune-schedule] [--sparse-resnet-downsample] 33 | [--conv-group-lasso] [--big-new-weights] 34 | [--widen-factor WIDEN_FACTOR] 35 | [--initial-sparsity-conv INITIAL_SPARSITY_CONV] 36 | [--initial-sparsity-fc INITIAL_SPARSITY_FC] [--job-idx JOB_IDX] 37 | [--no-augment] [--data DIR] [-j N] 38 | [--copy-mask-from COPY_MASK_FROM] [--resume RESUME] 39 | [--schedule-file SCHEDULE_FILE] [--name NAME] 40 | 41 | ``` 42 | Optional arguments: 43 | ``` 44 | -h, --help show this help message and exit 45 | --epochs EPOCHS number of total epochs to run 46 | --start-epoch START_EPOCH 47 | manual epoch number (useful on restarts) 48 | --model {mnist_mlp,cifar10_WideResNet,imagenet_resnet50} 49 | network name (default: mnist_mlp) 50 | -b BATCH_SIZE, --batch-size BATCH_SIZE 51 | mini-batch size (default: 100) 52 | --lr LR, --learning-rate LR 53 | initial learning rate 54 | --momentum MOMENTUM momentum 55 | --nesterov NESTEROV nesterov momentum 56 | --weight-decay WEIGHT_DECAY, --wd WEIGHT_DECAY 57 | weight decay (default: 1e-4) 58 | --L1-loss-coeff L1_LOSS_COEFF 59 | Lasso coefficient (default: 0.0) 60 | --print-freq PRINT_FREQ, -p PRINT_FREQ 61 | print frequency (default: 10) 62 | --layers LAYERS total number of layers for wide resnet (default: 28) 63 | --start-pruning-after-epoch START_PRUNING_AFTER_EPOCH 64 | Epoch after which to start pruning (default: 20) 65 | --prune-epoch-frequency PRUNE_EPOCH_FREQUENCY 66 | Intervals between prunes (default: 2) 67 | --prune-target-sparsity-fc PRUNE_TARGET_SPARSITY_FC 68 | Target sparsity when pruning fully connected layers 69 | (default: 0.98) 70 | --prune-target-sparsity-conv PRUNE_TARGET_SPARSITY_CONV 71 | Target sparsity when pruning conv layers (default: 72 | 0.5) 73 | --prune-iterations PRUNE_ITERATIONS 74 | Number of prunes. 
Set to 1 for single prune, larger 75 | than 1 for gradual pruning (default: 1) 76 | --post-prune-epochs POST_PRUNE_EPOCHS 77 | Epochs to train after pruning is done (default: 10) 78 | --n-prune-params N_PRUNE_PARAMS 79 | Number of parameters to re-allocate per re-allocation 80 | iteration (default: 600) 81 | --threshold-prune Prune based on a global threshold and not a fraction 82 | (default: False) 83 | --prune whether to use pruning or not (default: False) 84 | --validate-set whether to use a validation set to select epoch with 85 | best accuracy or not (default: False) 86 | --rewire-scaling Move weights between layers during parameter re- 87 | allocation (default: False) 88 | --tied whether to use tied weights instead of sparse ones 89 | (default: False) 90 | --rescale-tied-gradient 91 | whether to divide the gradient of tied weights by the 92 | number of their repetitions (default: False) 93 | --rewire whether to run parameter re-allocation (default: 94 | False) 95 | --no-validate-train whether to run validation on training set (default: 96 | False) 97 | --DeepR Train using deepR. prune and re-allocated weights that 98 | cross zero every iteration (default: False) 99 | --DeepR_eta DEEPR_ETA 100 | eta coefficient for DeepR (default: 0.1) 101 | --stop-rewire-epoch STOP_REWIRE_EPOCH 102 | Epoch after which to stop rewiring (default: 1000) 103 | --no-batch-norm no batch normalization in the mnist_mlp 104 | network(default: False) 105 | --rewire-fraction REWIRE_FRACTION 106 | Fraction of weight to rewire (default: 0.1) 107 | --sub-kernel-granularity SUB_KERNEL_GRANULARITY 108 | prune granularity (default: 2) 109 | --cubic-prune-schedule 110 | Use sparsity schedule following a cubic function as in 111 | Zhu et al. 2018 (instead of an exponential function). 112 | (default: False) 113 | --sparse-resnet-downsample 114 | Use sub-kernel granularity while rewiring(default: 115 | False) 116 | --conv-group-lasso Use group lasso to penalize an entire kernel 117 | patch(default: False) 118 | --big-new-weights Use weights initialized from the initial distribution 119 | for the new connections instead of zeros(default: 120 | False) 121 | --widen-factor WIDEN_FACTOR 122 | widen factor for wide resnet (default: 10) 123 | --initial-sparsity-conv INITIAL_SPARSITY_CONV 124 | Initial sparsity of conv layers(default: 0.5) 125 | --initial-sparsity-fc INITIAL_SPARSITY_FC 126 | Initial sparsity for fully connected layers(default: 127 | 0.98) 128 | --job-idx JOB_IDX job index provided by the job manager 129 | --no-augment whether to use standard data augmentation (default: 130 | use data augmentation) 131 | --data DIR path to imagenet dataset 132 | -j N, --workers N number of data loading workers (default: 8) 133 | --copy-mask-from COPY_MASK_FROM 134 | checkpoint from which to copy mask data(default: none) 135 | --resume RESUME path to latest checkpoint (default: none) 136 | --schedule-file SCHEDULE_FILE 137 | yaml file containing learning rate schedule and rewire 138 | period schedule 139 | --name NAME name of experiment 140 | 141 | ``` 142 | 143 | ### Specific experiments 144 | 145 | The two yaml files : `wrnet_experiments.yaml` and `resnet_experiments.yaml` contain YAML lists of all the invocations of the python executable needed to run the imagenet and the CIFAR10 experiments in the paper's main text and supplementary materials. 146 | 147 | ### Important notes 148 | 149 | - Code development and all experiments were done with Python 3.6 and pytorch 0.4.1. 
150 | - All experiments were conducted on NVidia TitanXP GPUs. 151 | - Imagenet experiments require multi-GPU data parallelism, which is done by default using all available GPUs specified by environment variable `CUDA_VISIBLE_DEVICES`. -------------------------------------------------------------------------------- /imagenet/baseline/experiments/resnet_experiments.yaml: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ****************************************************************************** 16 | 17 | jobs: 18 | ##Baseline dense model training 19 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.0 --initial-sparsity-conv 0.0 --batch-size 256 --widen-factor 1 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --job-idx 0 20 | 21 | 22 | ####sparsity 0.8 23 | #Thin dense model 24 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule_long.yaml --initial-sparsity-fc 0.0 --initial-sparsity-conv 0.0 --batch-size 256 --widen-factor 0.54 --epochs 200 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --job-idx 1 25 | 26 | #dynamic sparse 27 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.8 --initial-sparsity-conv 0.8 --batch-size 256 --widen-factor 1 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --rewire --threshold-prune --n-prune-params 400000 --rewire-scaling --stop-rewire-epoch 95 --job-idx 2 28 | 29 | 30 | #Hash nets 31 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.8 --initial-sparsity-conv 0.8 --batch-size 256 --widen-factor 1 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --tied --job-idx 3 32 | 33 | #static sparse 34 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule_long.yaml --initial-sparsity-fc 0.8 --initial-sparsity-conv 0.8 --batch-size 256 --widen-factor 1 --epochs 200 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --job-idx 4 35 | 36 | #Prune from dense model 37 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_prune_schedule.yaml --resume ./runs/imagenet_resnet50_0 --initial-sparsity-fc 0.0 --initial-sparsity-conv 0.0 --batch-size 256 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --prune-target-sparsity-fc 0.8 --prune-target-sparsity-conv 0.8 --prune-iterations 20 --post-prune-epochs 10 --prune --start-pruning-after-epoch 0 --widen-factor 1 --job-idx 5 38 | 39 | #DeepR 40 | - python main.py --model imagenet_resnet50 --schedule-file 
./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.8 --initial-sparsity-conv 0.8 --batch-size 256 --widen-factor 1 --no-validate-train --sub-kernel-granularity 4 --stop-rewire-epoch 95 --DeepR --DeepR_eta 0.01 --L1-loss-coeff 1.0e-5 --job-idx 6 41 | 42 | #SET 43 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.8 --initial-sparsity-conv 0.8 --batch-size 256 --widen-factor 1 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --rewire --rewire-fraction 0.089 --stop-rewire-epoch 95 --job-idx 7 44 | 45 | ##Dynamic sparse with 3x3 kernel slice re-allocation granularity 46 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.8 --initial-sparsity-conv 0.8 --batch-size 256 --widen-factor 1 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 2 --rewire --threshold-prune --n-prune-params 40000 --rewire-scaling --stop-rewire-epoch 95 --job-idx 8 47 | 48 | ##Static sparse but with structure copied from trained dynamic sparse and random initialization. 49 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule_long.yaml --initial-sparsity-fc 0.8 --initial-sparsity-conv 0.8 --widen-factor 1 --sub-kernel-granularity 4 --epochs 200 --weight-decay 1.0e-4 --copy-mask-from ./runs/imagenet_resnet50_2 --no-validate-train --job-idx 9 50 | 51 | ##Static sparse but with structure copied from trained dynamic sparse and initialization copied from the initialization of the same dynamic sparse network 52 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule_long.yaml --initial-sparsity-fc 0.8 --initial-sparsity-conv 0.8 --widen-factor 1 --sub-kernel-granularity 4 --epochs 200 --weight-decay 1.0e-4 --copy-mask-from ./runs/imagenet_resnet50_2 --resume ./runs/imagenet_resnet50_2_initial --no-validate-train --job-idx 10 53 | 54 | ###sparsity 0.9 55 | #Thin dense model 56 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule_long.yaml --initial-sparsity-fc 0.0 --initial-sparsity-conv 0.0 --batch-size 256 --widen-factor 0.46 --epochs 200 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --job-idx 11 57 | 58 | #dynamic sparse 59 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.9 --initial-sparsity-conv 0.9 --batch-size 256 --widen-factor 1 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --rewire --threshold-prune --n-prune-params 400000 --rewire-scaling --stop-rewire-epoch 95 --job-idx 12 60 | 61 | #Hash nets 62 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.9 --initial-sparsity-conv 0.9 --batch-size 256 --widen-factor 1 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --tied --job-idx 13 63 | 64 | #static sparse 65 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule_long.yaml --initial-sparsity-fc 0.9 --initial-sparsity-conv 0.9 --batch-size 256 --widen-factor 1 --epochs 200 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --job-idx 14 66 | 67 | #Prune from dense model 68 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_prune_schedule.yaml --resume ./runs/imagenet_resnet50_0 
--initial-sparsity-fc 0.0 --initial-sparsity-conv 0.0 --batch-size 256 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --prune-target-sparsity-fc 0.9 --prune-target-sparsity-conv 0.9 --prune-iterations 20 --post-prune-epochs 10 --prune --start-pruning-after-epoch 0 --widen-factor 1 --job-idx 15 69 | 70 | #DeepR 71 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.9 --initial-sparsity-conv 0.9 --batch-size 256 --widen-factor 1 --no-validate-train --sub-kernel-granularity 4 --stop-rewire-epoch 95 --DeepR --DeepR_eta 0.01 --L1-loss-coeff 1.0e-5 --job-idx 16 72 | 73 | #SET 74 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.9 --initial-sparsity-conv 0.9 --batch-size 256 --widen-factor 1 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --rewire --rewire-fraction 0.14 --stop-rewire-epoch 95 --job-idx 17 75 | 76 | ##Dynamic sparse with 3x3 kernel slice re-allocation granularity 77 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.9 --initial-sparsity-conv 0.9 --batch-size 256 --widen-factor 1 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 2 --rewire --threshold-prune --n-prune-params 40000 --rewire-scaling --stop-rewire-epoch 95 --job-idx 18 78 | 79 | ##Static sparse but with structure copied from trained dynamic sparse and random initialization. 80 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule_long.yaml --initial-sparsity-fc 0.9 --initial-sparsity-conv 0.9 --widen-factor 1 --sub-kernel-granularity 4 --epochs 200 --weight-decay 1.0e-4 --copy-mask-from ./runs/imagenet_resnet50_12 --no-validate-train --job-idx 19 81 | 82 | ##Static sparse but with structure copied from trained dynamic sparse and initialization copied from the initialization of the same dynamic sparse network 83 | - python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule_long.yaml --initial-sparsity-fc 0.9 --initial-sparsity-conv 0.9 --widen-factor 1 --sub-kernel-granularity 4 --epochs 200 --weight-decay 1.0e-4 --copy-mask-from ./runs/imagenet_resnet50_12 --resume ./runs/imagenet_resnet50_12_initial --no-validate-train --job-idx 20 84 | -------------------------------------------------------------------------------- /imagenet/baseline/learning_schedules/resnet_prune_schedule.yaml: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ****************************************************************************** 16 | 17 | lr_schedule: 18 | - 0 : 0.0 19 | - 1 : 0.01 20 | - 15 : 0.01 21 | - 25 : 0.001 22 | - 35 : 0.0001 23 | 24 | rewire_period_schedule: 25 | - 0 : 100 26 | - 25 : 200 27 | - 50 : 400 28 | - 75 : 800 29 | 30 | DeepR_temperature_schedule: 31 | - 0 : 1.0e-3 32 | - 25 : 1.0e-4 33 | - 50 : 1.0e-5 34 | - 75 : 1.0e-6 35 | -------------------------------------------------------------------------------- /imagenet/baseline/learning_schedules/resnet_schedule.yaml: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ****************************************************************************** 16 | 17 | lr_schedule: 18 | - 0 : 0.1 19 | - 30 : 0.01 20 | - 60 : 0.001 21 | - 90 : 0.0001 22 | 23 | rewire_period_schedule: 24 | - 0 : 1000 25 | - 25 : 2000 26 | - 50 : 4000 27 | - 75 : 8000 28 | 29 | #The temperature schedule is only used when the --DeepR option is specified 30 | DeepR_temperature_schedule: 31 | - 0 : 1.0e-5 32 | - 25 : 1.0e-8 33 | - 50 : 1.0e-12 34 | - 75 : 1.0e-15 35 | -------------------------------------------------------------------------------- /imagenet/baseline/learning_schedules/resnet_schedule_long.yaml: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ****************************************************************************** 16 | 17 | lr_schedule: 18 | - 0 : 0.1 19 | - 60 : 0.01 20 | - 120 : 0.001 21 | - 180 : 0.0001 22 | 23 | rewire_period_schedule: 24 | - 0 : 1000 25 | - 60 : 2000 26 | - 120 : 4000 27 | - 180 : 8000 28 | 29 | #The temperature schedule is only used when the --DeepR option is specified 30 | DeepR_temperature_schedule: 31 | - 0 : 1.0e-5 32 | - 50 : 1.0e-8 33 | - 100 : 1.0e-12 34 | - 150 : 1.0e-15 35 | -------------------------------------------------------------------------------- /imagenet/baseline/learning_schedules/wrnet_prune_schedule.yaml: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ****************************************************************************** 16 | 17 | lr_schedule: 18 | - 0 : 0.0 19 | - 1 : 0.02 20 | - 15 : 0.02 21 | - 25 : 0.004 22 | - 35 : 0.0008 23 | 24 | rewire_period_schedule: 25 | - 0 : 100 26 | - 25 : 200 27 | - 50 : 400 28 | - 75 : 800 29 | 30 | DeepR_temperature_schedule: 31 | - 0 : 1.0e-3 32 | - 25 : 1.0e-4 33 | - 50 : 1.0e-5 34 | - 75 : 1.0e-6 35 | -------------------------------------------------------------------------------- /imagenet/baseline/learning_schedules/wrnet_schedule.yaml: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ****************************************************************************** 16 | 17 | lr_schedule: 18 | - 0 : 0.1 19 | - 60 : 0.02 20 | - 120 : 0.004 21 | - 160 : 0.0008 22 | 23 | rewire_period_schedule: 24 | - 0 : 100 25 | - 25 : 200 26 | - 80 : 400 27 | - 140 : 800 28 | 29 | DeepR_temperature_schedule: 30 | - 0 : 1.0e-5 31 | - 25 : 1.0e-8 32 | - 80 : 1.0e-12 33 | - 140 : 1.0e-15 34 | -------------------------------------------------------------------------------- /imagenet/baseline/learning_schedules/wrnet_schedule_long.yaml: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ****************************************************************************** 16 | 17 | lr_schedule: 18 | - 0 : 0.1 19 | - 120 : 0.02 20 | - 220 : 0.004 21 | - 350 : 0.0008 22 | 23 | rewire_period_schedule: 24 | - 0 : 100 25 | - 25 : 200 26 | - 80 : 400 27 | - 140 : 800 28 | 29 | DeepR_temperature_schedule: 30 | - 0 : 1.0e-5 31 | - 25 : 1.0e-8 32 | - 80 : 1.0e-12 33 | - 140 : 1.0e-15 34 | -------------------------------------------------------------------------------- /imagenet/baseline/model.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ****************************************************************************** 16 | # Modifications copyright (C) 2019 Tim Dettmers 17 | # CHANGES: 18 | # - Replaced DynamicLinear layer with Linear layer in ResNet-50 to be compatible with sparselearning library. 
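# Overview: this module implements the baseline's sparse model zoo.
# DynamicNetworkBase adds layer-wise magnitude pruning and model-size
# accounting on top of nn.Module; the mnist_mlp, cifar10_WideResNet, and
# ImageNet ResNet-50 classes below are built from the DynamicLinear and
# DynamicConv2d reparameterized layers and their SparseTensor/TiedTensor weights.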
19 | 20 | import numpy as np 21 | import math 22 | import torch 23 | import torch.nn as nn 24 | from reparameterized_layers import DynamicLinear,DynamicConv2d 25 | from parameterized_tensors import SparseTensor,TiedTensor 26 | import torch.utils.model_zoo as model_zoo 27 | import torch.nn.functional as F 28 | 29 | class DynamicNetworkBase(nn.Module): 30 | def __init__(self): 31 | super(DynamicNetworkBase, self).__init__() 32 | self.split_state = False 33 | 34 | def prune(self,prune_fraction_fc,prune_fraction_conv,prune_fraction_fc_special = None): 35 | for x in [x for x in self.modules() if isinstance(x,SparseTensor)]: 36 | if x.conv_tensor: 37 | x.prune_small_connections(prune_fraction_conv) 38 | else: 39 | if x.s_tensor.size(0) == 10 and x.s_tensor.size(1) == 100: 40 | x.prune_small_connections(prune_fraction_fc_special) 41 | else: 42 | x.prune_small_connections(prune_fraction_fc) 43 | 44 | 45 | 46 | def get_model_size(self): 47 | def get_tensors_and_test(tensor_type): 48 | relevant_tensors = [x for x in self.modules() if isinstance(x,tensor_type)] 49 | relevant_params = [p for x in relevant_tensors for p in x.parameters()] 50 | is_relevant_param = lambda x : [y for y in relevant_params if x is y] 51 | 52 | return relevant_tensors,is_relevant_param 53 | 54 | sparse_tensors,is_sparse_param = get_tensors_and_test(SparseTensor) 55 | tied_tensors,is_tied_param = get_tensors_and_test(TiedTensor) 56 | 57 | 58 | sparse_params = [p for x in sparse_tensors for p in x.parameters()] 59 | is_sparse_param = lambda x : [y for y in sparse_params if x is y] 60 | 61 | 62 | sparse_size = sum([x.get_sparsity()[0].item() for x in sparse_tensors]) 63 | 64 | tied_size = 0 65 | for k in tied_tensors: 66 | unique_reps = k.weight_alloc.cpu().unique() 67 | subtensor_size = np.prod(list(k.bank.size())[1:]) 68 | tied_size += unique_reps.size(0) * subtensor_size 69 | 70 | 71 | fixed_size = sum([p.data.nelement() for p in self.parameters() if (not is_sparse_param(p) and not is_tied_param(p))]) 72 | model_size = {'sparse': sparse_size,'tied' : tied_size, 'fixed':fixed_size,'learnable':fixed_size + sparse_size + tied_size} 73 | return model_size 74 | 75 | 76 | 77 | class mnist_mlp(DynamicNetworkBase): 78 | 79 | def __init__(self, initial_sparsity = 0.98,sparse = True,no_batch_norm = False): 80 | super(mnist_mlp, self).__init__() 81 | 82 | self.fc1 = DynamicLinear(784, 300, initial_sparsity,bias = no_batch_norm,sparse = sparse) 83 | self.fc_int = DynamicLinear(300, 100, initial_sparsity,bias = no_batch_norm,sparse = sparse) 84 | #self.fc2 = DynamicLinear(100, 10, min(0.5,initial_sparsity),bias = False,sparse = sparse) 85 | self.fc2 = DynamicLinear(100, 10, initial_sparsity,bias = no_batch_norm,sparse = sparse) 86 | 87 | if no_batch_norm: 88 | self.bn1 = lambda x : x 89 | self.bn2 = lambda x : x 90 | self.bn3 = lambda x : x 91 | else: 92 | self.bn1 = nn.BatchNorm1d(300) 93 | self.bn2 = nn.BatchNorm1d(100) 94 | self.bn3 = nn.BatchNorm1d(10) 95 | 96 | 97 | def forward(self, x): 98 | x = F.relu(self.bn1(self.fc1(x.view(-1, 784)))) 99 | x = F.relu(self.bn2(self.fc_int(x))) 100 | y = self.bn3(self.fc2(x)) 101 | 102 | return y 103 | 104 | 105 | 106 | #########Definition of wide resnets 107 | 108 | class BasicBlock(nn.Module): 109 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0,widen_factor = 10,initial_sparsity = 0.5,sub_kernel_granularity = False,sparse = True, sparse_momentum=False): 110 | super(BasicBlock, self).__init__() 111 | self.bn1 = nn.BatchNorm2d(in_planes) 112 | self.relu1 = nn.ReLU(inplace=True) 
113 | 114 | self.conv1 = DynamicConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 115 | padding=1, bias=False,initial_sparsity = initial_sparsity,sub_kernel_granularity = sub_kernel_granularity,sparse = sparse) 116 | 117 | 118 | self.bn2 = nn.BatchNorm2d(out_planes) 119 | self.relu2 = nn.ReLU(inplace=True) 120 | 121 | self.conv2 = DynamicConv2d(out_planes, out_planes, kernel_size=3, stride=1, 122 | padding=1, bias=False,initial_sparsity = initial_sparsity,sub_kernel_granularity = sub_kernel_granularity,sparse = sparse) 123 | 124 | self.droprate = dropRate 125 | self.equalInOut = (in_planes == out_planes) 126 | #self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 127 | #padding=0, bias=False) or None 128 | self.convShortcut = (not self.equalInOut) and DynamicConv2d(in_planes, out_planes, kernel_size=1, stride=stride, 129 | padding=0, bias=False, initial_sparsity=initial_sparsity, sub_kernel_granularity=sub_kernel_granularity, sparse=sparse) or None 130 | def forward(self, x): 131 | if not self.equalInOut: 132 | x = self.relu1(self.bn1(x)) 133 | else: 134 | out = self.relu1(self.bn1(x)) 135 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 136 | if self.droprate > 0: 137 | out = F.dropout(out, p=self.droprate, training=self.training) 138 | out = self.conv2(out) 139 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 140 | 141 | 142 | class NetworkBlock(nn.Module): 143 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0,widen_factor = 10,initial_sparsity = 0.5,sub_kernel_granularity = False,sparse = True): 144 | super(NetworkBlock, self).__init__() 145 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate,widen_factor,initial_sparsity = initial_sparsity, 146 | sub_kernel_granularity = sub_kernel_granularity,sparse = sparse) 147 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate,widen_factor,initial_sparsity = 0.5,sub_kernel_granularity = False,sparse = True): 148 | layers = [] 149 | for i in range(int(nb_layers)): 150 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate,widen_factor = widen_factor, 151 | initial_sparsity = initial_sparsity,sub_kernel_granularity = sub_kernel_granularity,sparse = sparse)) 152 | return nn.Sequential(*layers) 153 | def forward(self, x): 154 | return self.layer(x) 155 | 156 | class cifar10_WideResNet(DynamicNetworkBase): 157 | def __init__(self, depth, num_classes=10, widen_factor=1, dropRate=0.0,initial_sparsity_conv = 0.5,initial_sparsity_fc = 0.95,sub_kernel_granularity = 4,sparse = True): 158 | super(cifar10_WideResNet, self).__init__() 159 | nChannels = np.round(np.array([16, 16*widen_factor, 32*widen_factor, 64*widen_factor])).astype('int32') 160 | assert((depth - 4) % 6 == 0) 161 | n = (depth - 4) / 6 162 | block = BasicBlock 163 | # 1st conv before any network block 164 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 165 | padding=1, bias=False) 166 | # 1st block 167 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate,widen_factor = widen_factor, 168 | initial_sparsity = initial_sparsity_conv,sub_kernel_granularity = sub_kernel_granularity,sparse = sparse) 169 | # 2nd block 170 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate,widen_factor = widen_factor, 171 | initial_sparsity = initial_sparsity_conv,sub_kernel_granularity = 
sub_kernel_granularity,sparse = sparse) 172 | # 3rd block 173 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate,widen_factor = widen_factor, 174 | initial_sparsity = initial_sparsity_conv,sub_kernel_granularity = sub_kernel_granularity,sparse = sparse) 175 | 176 | # global average pooling and classifier 177 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 178 | self.relu = nn.ReLU(inplace=True) 179 | self.fc = nn.Linear(nChannels[3],num_classes) #DynamicLinear(nChannels[3], num_classes,initial_sparsity = initial_sparsity_fc,sparse = sparse) 180 | self.nChannels = nChannels[3] 181 | self.split_state = False 182 | self.reset_parameters() 183 | 184 | def reset_parameters(self): 185 | for m in self.modules(): 186 | if isinstance(m, nn.Conv2d): 187 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 188 | m.weight.data.normal_(0, math.sqrt(2. / n)) 189 | elif isinstance(m, DynamicConv2d): 190 | n = m.kernel_size * m.kernel_size * m.n_output_maps 191 | if m.sparse: 192 | m.d_tensor.s_tensor.data.normal_(0, math.sqrt(2. / n)) 193 | else: 194 | m.d_tensor.bank.data.normal_(0, math.sqrt(2. / n)) 195 | if isinstance(m, nn.BatchNorm2d): 196 | m.weight.data.fill_(1) 197 | m.bias.data.zero_() 198 | elif isinstance(m, nn.Linear): 199 | m.bias.data.zero_() 200 | 201 | def forward(self, x): 202 | out = self.conv1(x) 203 | out = self.block1(out) 204 | out = self.block2(out) 205 | out = self.block3(out) 206 | out = self.relu(self.bn1(out)) 207 | out = F.avg_pool2d(out, 8) 208 | out = out.view(-1, self.nChannels) 209 | return self.fc(out) 210 | 211 | 212 | ###Resnet Definition 213 | class Bottleneck(nn.Module): 214 | expansion = 4 215 | 216 | def __init__(self, inplanes, planes, stride=1, downsample=None,widen_factor = 1,vanilla_conv1 = True,vanilla_conv3 = True,initial_sparsity = 0.5, 217 | sub_kernel_granularity = 4,sparse = True): 218 | super(Bottleneck, self).__init__() 219 | adjusted_planes = planes#np.round(widen_factor * planes).astype('int32') 220 | 221 | #if vanilla_conv1: 222 | if not sparse: 223 | self.conv1 = nn.Conv2d(inplanes, adjusted_planes, kernel_size=1, bias=False) 224 | self.conv3 = nn.Conv2d(adjusted_planes, planes * 4, kernel_size=1, bias=False) 225 | else: 226 | self.conv1 = DynamicConv2d(inplanes, adjusted_planes, kernel_size=1, bias=False , initial_sparsity = initial_sparsity, 227 | sub_kernel_granularity = sub_kernel_granularity,sparse = sparse) 228 | self.conv3 = DynamicConv2d(adjusted_planes, planes * 4, kernel_size=1, bias=False , initial_sparsity = initial_sparsity, 229 | sub_kernel_granularity = sub_kernel_granularity,sparse = sparse) 230 | if not sparse: 231 | self.conv2 = nn.Conv2d(adjusted_planes, adjusted_planes, kernel_size=3, stride=stride,padding=1, bias=False) 232 | else: 233 | self.conv2 = DynamicConv2d(adjusted_planes, adjusted_planes, kernel_size=3, stride=stride, 234 | padding=1, bias=False,initial_sparsity = initial_sparsity, sub_kernel_granularity = sub_kernel_granularity,sparse = sparse) 235 | 236 | 237 | self.bn1 = nn.BatchNorm2d(adjusted_planes) 238 | self.bn2 = nn.BatchNorm2d(adjusted_planes) 239 | self.bn3 = nn.BatchNorm2d(planes * 4) 240 | self.relu = nn.ReLU(inplace=True) 241 | self.downsample = downsample 242 | self.stride = stride 243 | 244 | def forward(self, x): 245 | residual = x 246 | 247 | out = self.conv1(x) 248 | out = self.bn1(out) 249 | out = self.relu(out) 250 | 251 | out = self.conv2(out) 252 | out = self.bn2(out) 253 | out = self.relu(out) 254 | 255 | out = self.conv3(out) 256 | out = self.bn3(out) 257 | 258 | if 
self.downsample is not None: 259 | residual = self.downsample(x) 260 | 261 | out += residual 262 | out = self.relu(out) 263 | 264 | return out 265 | 266 | 267 | class ResNet(DynamicNetworkBase): 268 | 269 | def __init__(self, block, layers, num_classes=1000,widen_factor = 1,vanilla_downsample = False,vanilla_conv1 = False,vanilla_conv3 = False, 270 | initial_sparsity_conv = 0.5,initial_sparsity_fc = 0.95,sub_kernel_granularity = 4,sparse = True): 271 | self.inplanes = np.round(64 * widen_factor).astype('int32') 272 | super(ResNet, self).__init__() 273 | self.widen_factor = widen_factor 274 | self.vanilla_conv1 = vanilla_conv1 275 | self.vanilla_conv3 = vanilla_conv3 276 | self.vanilla_downsample = vanilla_downsample 277 | self.initial_sparsity_conv = initial_sparsity_conv 278 | self.initial_sparsity_fc = initial_sparsity_fc 279 | self.sub_kernel_granularity = sub_kernel_granularity 280 | self.sparse = sparse 281 | 282 | if not sparse: 283 | self.conv1 = nn.Conv2d(3, np.round(64 * widen_factor).astype('int32'), kernel_size=7, stride=2, padding=3, 284 | bias=False) 285 | else: 286 | self.conv1 = DynamicConv2d(3, np.round(64 * widen_factor).astype('int32'), kernel_size=7, stride=2, padding=3, 287 | bias=False, initial_sparsity=initial_sparsity_conv, sub_kernel_granularity=sub_kernel_granularity, sparse=sparse) 288 | self.bn1 = nn.BatchNorm2d(np.round(64 * widen_factor).astype('int32')) 289 | self.relu = nn.ReLU(inplace=True) 290 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 291 | self.layer1 = self._make_layer(block, np.round(64 * widen_factor).astype('int32'), layers[0]) 292 | self.layer2 = self._make_layer(block, np.round(64 * widen_factor).astype('int32')*2, layers[1], stride=2) 293 | self.layer3 = self._make_layer(block, np.round(64 * widen_factor).astype('int32')*4, layers[2], stride=2) 294 | self.layer4 = self._make_layer(block, np.round(64 * widen_factor).astype('int32')*8, layers[3], stride=2) 295 | self.avgpool = nn.AvgPool2d(7, stride=1) 296 | if not sparse: 297 | self.fc = nn.Linear(np.round(64 * widen_factor).astype('int32') * block.expansion * 8, num_classes,bias=True) 298 | else: 299 | self.fc = DynamicLinear(np.round(64 * widen_factor).astype('int32') * block.expansion * 8, num_classes,initial_sparsity = self.initial_sparsity_fc,sparse = sparse) 300 | 301 | for m in self.modules(): 302 | if isinstance(m, nn.Conv2d): 303 | if sparse: 304 | raise Exception('Used sparse=True, but some layers are still dense.') 305 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 306 | m.weight.data.normal_(0, math.sqrt(2. / n)) 307 | elif isinstance(m, DynamicConv2d): 308 | if not sparse: 309 | raise Exception('Used sparse=False, but some layers are still sparse.') 310 | n = m.kernel_size * m.kernel_size * m.n_output_maps 311 | if m.sparse: 312 | m.d_tensor.s_tensor.data.normal_(0, math.sqrt(2. / n)) 313 | else: 314 | m.d_tensor.bank.data.normal_(0, math.sqrt(2. 
/ n)) 315 | elif isinstance(m, nn.BatchNorm2d): 316 | m.weight.data.fill_(1) 317 | m.bias.data.zero_() 318 | 319 | def _make_layer(self, block, planes, blocks, stride=1): 320 | downsample = None 321 | if stride != 1 or self.inplanes != planes * block.expansion: 322 | if not self.sparse: 323 | conv = nn.Conv2d(self.inplanes, planes * block.expansion, 324 | kernel_size=1, stride=stride, bias=False) 325 | else: 326 | conv = DynamicConv2d(self.inplanes, planes * block.expansion,kernel_size=1,stride=stride, bias=False, 327 | initial_sparsity = self.initial_sparsity_conv,sub_kernel_granularity = self.sub_kernel_granularity,sparse = self.sparse) 328 | downsample = nn.Sequential(conv, nn.BatchNorm2d(planes * block.expansion)) 329 | 330 | layers = [] 331 | layers.append(block(self.inplanes, planes, stride, downsample,widen_factor = self.widen_factor, 332 | vanilla_conv1 = self.vanilla_conv1,vanilla_conv3 = self.vanilla_conv3,initial_sparsity = self.initial_sparsity_conv, 333 | sub_kernel_granularity = self.sub_kernel_granularity,sparse = self.sparse)) 334 | self.inplanes = planes * block.expansion 335 | for i in range(1, blocks): 336 | layers.append(block(self.inplanes, planes,widen_factor = self.widen_factor, 337 | vanilla_conv1 = self.vanilla_conv1,vanilla_conv3 = self.vanilla_conv3,initial_sparsity = self.initial_sparsity_conv, 338 | sub_kernel_granularity = self.sub_kernel_granularity,sparse = self.sparse)) 339 | return nn.Sequential(*layers) 340 | 341 | def forward(self, x): 342 | x = self.conv1(x) 343 | x = self.bn1(x) 344 | x = self.relu(x) 345 | x = self.maxpool(x) 346 | 347 | x = self.layer1(x) 348 | x = self.layer2(x) 349 | x = self.layer3(x) 350 | x = self.layer4(x) 351 | 352 | x = self.avgpool(x) 353 | x = x.view(x.size(0), -1) 354 | x = self.fc(x) 355 | 356 | return x 357 | 358 | 359 | def imagenet_resnet50(widen_factor = 1,vanilla_conv1 = False,vanilla_conv3 = False,vanilla_downsample = True,decimation_factor = 8, 360 | initial_sparsity_conv = 0.5,initial_sparsity_fc = 0.95,sub_kernel_granularity = 4,sparse = True, **kwargs): 361 | """Constructs a ResNet-50 model. 362 | 363 | """ 364 | model = ResNet(Bottleneck, [3, 4, 6, 3],widen_factor = widen_factor, 365 | vanilla_conv1 = vanilla_conv1,vanilla_conv3 = vanilla_conv3,vanilla_downsample = vanilla_downsample, initial_sparsity_conv = initial_sparsity_conv, 366 | initial_sparsity_fc = initial_sparsity_fc,sub_kernel_granularity = sub_kernel_granularity,sparse = sparse,**kwargs) 367 | return model 368 | -------------------------------------------------------------------------------- /imagenet/baseline/parameterized_tensors.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # ****************************************************************************** 16 | 17 | import math 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import numpy as np 22 | from torch.autograd import Variable 23 | from torch.nn.parameter import Parameter 24 | 25 | 26 | 27 | def uniform_coverage(rank,n_features): 28 | reps = torch.zeros(n_features) 29 | place_element = torch.arange(rank) 30 | for i in np.arange(0,n_features,rank): 31 | reps[i:i+rank] = place_element[0:min(rank,n_features - i)] 32 | return reps.long() 33 | 34 | 35 | class TiedTensor(nn.Module): 36 | 37 | def __init__(self, full_tensor_size,initial_sparsity, sub_kernel_granularity = False): 38 | 39 | super(TiedTensor, self).__init__() 40 | 41 | ndim = len(full_tensor_size) 42 | assert ndim == 2 or ndim == 4, 'only 2D or 4D tensors supported' 43 | 44 | self.full_tensor_size = torch.Size(full_tensor_size) 45 | self.sub_kernel_granularity = sub_kernel_granularity 46 | 47 | n_alloc_elements = np.prod(self.full_tensor_size).item() if sub_kernel_granularity else np.prod(self.full_tensor_size[:2]).item() 48 | 49 | self.num_weights = round((1 - initial_sparsity)*n_alloc_elements) 50 | 51 | 52 | self.register_buffer('weight_alloc',torch.zeros(n_alloc_elements).long()) 53 | indices = np.arange(n_alloc_elements) 54 | np.random.shuffle(indices) 55 | self.weight_alloc[indices] = uniform_coverage(self.num_weights,n_alloc_elements) 56 | 57 | self.conv_tensor = False if ndim ==2 else True 58 | 59 | 60 | trailing_dimensions = [] if sub_kernel_granularity else self.full_tensor_size[2:] 61 | self.bank = Parameter(torch.Tensor(self.num_weights,*trailing_dimensions)) 62 | 63 | self.init_parameters() 64 | 65 | 66 | def init_parameters(self): 67 | stdv = 1 / math.sqrt(np.prod(self.full_tensor_size[1:])) 68 | 69 | self.bank.data.uniform_(-stdv, stdv) 70 | self.bank.data[0] = 0.0 71 | def extra_repr(self): 72 | 73 | return 'full tensor size={} , unique_active_weights={}, fraction_of_total_weights = {}, sub_kernel_granularity = {}'.format( 74 | self.full_tensor_size, self.num_weights,self.num_weights * 1.0 / self.weight_alloc.size(0),self.sub_kernel_granularity) 75 | 76 | def forward(self): 77 | return self.bank[self.weight_alloc].view(self.full_tensor_size) 78 | 79 | 80 | class SparseTensor(nn.Module): 81 | def __init__(self,tensor_size,initial_sparsity,sub_kernel_granularity = 4): 82 | super(SparseTensor,self).__init__() 83 | self.s_tensor = Parameter(torch.Tensor(torch.Size(tensor_size))) 84 | self.initial_sparsity = initial_sparsity 85 | self.sub_kernel_granularity = sub_kernel_granularity 86 | 87 | 88 | assert self.s_tensor.dim() == 2 or self.s_tensor.dim() == 4, "can only do 2D or 4D sparse tensors" 89 | 90 | 91 | trailing_dimensions = [1]*(4 - sub_kernel_granularity) 92 | self.register_buffer('mask',torch.Tensor(*(tensor_size[:sub_kernel_granularity] ))) 93 | 94 | self.normalize_coeff = np.prod(tensor_size[sub_kernel_granularity:]).item() 95 | 96 | 97 | self.conv_tensor = False if self.s_tensor.dim() ==2 else True 98 | 99 | self.mask.zero_() 100 | flat_mask = self.mask.view(-1) 101 | indices = np.arange(flat_mask.size(0)) 102 | np.random.shuffle(indices) 103 | flat_mask[indices[:int((1-initial_sparsity) * flat_mask.size(0) + 0.1)]] = 1 104 | 105 | self.grown_indices = None 106 | self.init_parameters() 107 | self.reinitialize_unused() 108 | 109 | self.tensor_sign = torch.sign(self.s_tensor.data.view(-1)) 110 | 111 | 112 | def reinitialize_unused(self,reinitialize_unused_to_zero = True): 113 | 
unused_positions = (self.mask < 0.5) 114 | if reinitialize_unused_to_zero: 115 | self.s_tensor.data[unused_positions] = torch.zeros(self.s_tensor.data[unused_positions].size()).to(self.s_tensor.device) 116 | else: 117 | if self.conv_tensor: 118 | n = self.s_tensor.size(0) * self.s_tensor.size(2) * self.s_tensor.size(3) 119 | self.s_tensor.data[unused_positions] = torch.zeros(self.s_tensor.data[unused_positions].size()).normal_(0, math.sqrt(2. / n)).to(self.s_tensor.device) 120 | else: 121 | stdv = 1. / math.sqrt(self.s_tensor.size(1)) 122 | self.s_tensor.data[unused_positions] = torch.zeros(self.s_tensor.data[unused_positions].size()).normal_(0, stdv).to(self.s_tensor.device) 123 | 124 | def init_parameters(self): 125 | stdv = 1 / math.sqrt(np.prod(self.s_tensor.size()[1:])) 126 | 127 | self.s_tensor.data.uniform_(-stdv, stdv) 128 | 129 | def prune_sign_change(self,reinitialize_unused_to_zero = True,enable_print = False): 130 | W_flat = self.s_tensor.data.view(-1) 131 | 132 | new_tensor_sign = torch.sign(W_flat) 133 | mask_flat = self.mask.view(-1) 134 | 135 | mask_indices = torch.nonzero(mask_flat > 0.5).view(-1) 136 | 137 | sign_change_indices = mask_indices[((new_tensor_sign[mask_indices] * self.tensor_sign[mask_indices].to(new_tensor_sign.device)) < -0.5).nonzero().view(-1)] 138 | 139 | mask_flat[sign_change_indices] = 0 140 | self.reinitialize_unused(reinitialize_unused_to_zero) 141 | 142 | cutoff = sign_change_indices.numel() 143 | 144 | if enable_print: 145 | print('pruned {} connections'.format(cutoff)) 146 | if self.grown_indices is not None and enable_print: 147 | overlap = np.intersect1d(sign_change_indices.cpu().numpy(),self.grown_indices.cpu().numpy()) 148 | print('pruned {} ({} %) just grown weights'.format(overlap.size,overlap.size * 100.0 / self.grown_indices.size(0) if self.grown_indices.size(0) > 0 else 0.0)) 149 | 150 | self.tensor_sign = new_tensor_sign 151 | return sign_change_indices 152 | 153 | 154 | def prune_small_connections(self,prune_fraction,reinitialize_unused_to_zero = True): 155 | if self.conv_tensor and self.sub_kernel_granularity < 4: 156 | W_flat = self.s_tensor.abs().sum(list(np.arange(self.sub_kernel_granularity,4))).view(-1) / self.normalize_coeff 157 | else: 158 | W_flat = self.s_tensor.data.view(-1) 159 | 160 | mask_flat = self.mask.view(-1) 161 | 162 | mask_indices = torch.nonzero(mask_flat > 0.5).view(-1) 163 | 164 | 165 | W_masked = W_flat[mask_indices] 166 | 167 | sorted_W_indices = torch.sort(torch.abs(W_masked))[1] 168 | 169 | 170 | cutoff = int(prune_fraction * W_masked.numel()) + 1 171 | 172 | mask_flat[mask_indices[sorted_W_indices[:cutoff]]] = 0 173 | self.reinitialize_unused(reinitialize_unused_to_zero) 174 | 175 | # print('pruned {} connections'.format(cutoff)) 176 | # if self.grown_indices is not None: 177 | # overlap = np.intersect1d(mask_indices[sorted_W_indices[:cutoff]].cpu().numpy(),self.grown_indices.cpu().numpy()) 178 | #print('pruned {} ({} %) just grown weights'.format(overlap.size,overlap.size * 100.0 / self.grown_indices.size(0))) 179 | 180 | return mask_indices[sorted_W_indices[:cutoff]] 181 | 182 | def prune_threshold(self,threshold,reinitialize_unused_to_zero = True): 183 | if self.conv_tensor and self.sub_kernel_granularity < 4: 184 | W_flat = self.s_tensor.abs().sum(list(np.arange(self.sub_kernel_granularity,4))).view(-1) / self.normalize_coeff 185 | else: 186 | W_flat = self.s_tensor.data.view(-1) 187 | 188 | mask_flat = self.mask.view(-1) 189 | 190 | mask_indices = torch.nonzero(mask_flat > 0.5).view(-1) 191 | 192 | 
W_masked = W_flat[mask_indices] 193 | 194 | prune_indices = (W_masked.abs() < threshold).nonzero().view(-1) 195 | 196 | 197 | if mask_indices.size(0) == prune_indices.size(0): 198 | print('removing all. keeping one') 199 | prune_indices = prune_indices[1:] 200 | 201 | 202 | mask_flat[mask_indices[prune_indices]] = 0 203 | 204 | # if mask_indices.numel() > 0 : 205 | # print('pruned {}/{}({:.2f}) connections'.format(prune_indices.numel(),mask_indices.numel(),prune_indices.numel()/mask_indices.numel())) 206 | 207 | # if self.grown_indices is not None and self.grown_indices.size(0) != 0 : 208 | # overlap = np.intersect1d(mask_indices[prune_indices].cpu().numpy(),self.grown_indices.cpu().numpy()) 209 | # print('pruned {} ({} %) just grown weights'.format(overlap.size,overlap.size * 100.0 / self.grown_indices.size(0))) 210 | 211 | self.reinitialize_unused(reinitialize_unused_to_zero) 212 | 213 | 214 | return mask_indices[prune_indices] 215 | 216 | def grow_random(self,grow_fraction,pruned_indices = None,enable_print = False,n_to_add = None): 217 | mask_flat = self.mask.view(-1) 218 | mask_zero_indices = torch.nonzero(mask_flat < 0.5).view(-1) 219 | if pruned_indices is not None: 220 | cutoff = pruned_indices.size(0) 221 | mask_zero_indices = torch.Tensor(np.setdiff1d(mask_zero_indices.cpu().numpy(),pruned_indices.cpu().numpy())).long().to(mask_zero_indices.device) 222 | else: 223 | cutoff = int(grow_fraction * mask_zero_indices.size(0)) 224 | 225 | if n_to_add is not None: 226 | cutoff = n_to_add 227 | 228 | 229 | if mask_zero_indices.numel() < cutoff: 230 | print('******no place to grow {} connections, growing {} instead'.format(cutoff,mask_zero_indices.numel())) 231 | cutoff = mask_zero_indices.numel() 232 | 233 | if enable_print: 234 | print('grown {} connections'.format(cutoff)) 235 | 236 | self.grown_indices = mask_zero_indices[torch.randperm(mask_zero_indices.numel())][:cutoff] 237 | mask_flat[self.grown_indices] = 1 238 | 239 | return cutoff 240 | 241 | def get_sparsity(self): 242 | active_elements = self.mask.sum() * np.prod(self.s_tensor.size()[self.sub_kernel_granularity:]).item() 243 | return (active_elements,1 - active_elements / self.s_tensor.numel()) 244 | 245 | 246 | def forward(self): 247 | if self.conv_tensor: 248 | return self.mask.view(*(self.mask.size() + (1,)*(4 - self.sub_kernel_granularity))) * self.s_tensor 249 | else: 250 | return self.mask * self.s_tensor 251 | 252 | 253 | def extra_repr(self): 254 | return 'full tensor size : {} , sparsity mask : {} , sub kernel granularity : {}'.format( 255 | self.s_tensor.size(), self.get_sparsity(),self.sub_kernel_granularity) 256 | 257 | 258 | -------------------------------------------------------------------------------- /imagenet/baseline/reparameterized_layers.py: -------------------------------------------------------------------------------- 1 | # ****************************************************************************** 2 | # Copyright 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ****************************************************************************** 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | from torch.nn.parameter import Parameter 21 | from parameterized_tensors import SparseTensor,TiedTensor 22 | 23 | class DynamicLinear(nn.Module): 24 | 25 | def __init__(self, in_features, out_features, initial_sparsity, bias = True , sparse = True ): 26 | super(DynamicLinear, self).__init__() 27 | self.in_features = in_features 28 | self.out_features = out_features 29 | self.initial_sparsity = initial_sparsity 30 | self.sparse = sparse 31 | 32 | if sparse: 33 | self.d_tensor = SparseTensor([out_features,in_features],initial_sparsity = initial_sparsity) 34 | else: 35 | self.d_tensor = TiedTensor([out_features,in_features],initial_sparsity = initial_sparsity) 36 | 37 | if bias: 38 | self.bias = Parameter(torch.Tensor(out_features)) 39 | else: 40 | self.bias = None 41 | 42 | self.init_parameters() 43 | # self.weight = self.d_tensor.s_tensor 44 | 45 | def init_parameters(self): 46 | if self.bias is not None: 47 | self.bias.data.zero_() 48 | self.d_tensor.init_parameters() 49 | 50 | def forward(self, input): 51 | return F.linear(input, self.d_tensor() , self.bias) 52 | 53 | def extra_repr(self): 54 | return 'in_features={}, out_features={}, initial_sparsity = {}, bias={}'.format( 55 | self.in_features, self.out_features, self.initial_sparsity,self.bias is not None) 56 | 57 | 58 | 59 | class DynamicConv2d(nn.Module): 60 | 61 | def __init__(self, 62 | n_input_maps, 63 | n_output_maps, 64 | kernel_size, 65 | stride=1, 66 | padding=0, 67 | dilation=1, 68 | groups=1, 69 | bias = True,initial_sparsity = 0.5,sub_kernel_granularity = 4,sparse=True): 70 | 71 | super(DynamicConv2d, self).__init__() 72 | 73 | if n_input_maps % groups != 0: 74 | raise ValueError('n_input_maps must be divisible by groups') 75 | 76 | self.sparse = sparse 77 | self.n_input_maps = n_input_maps 78 | self.n_output_maps = n_output_maps 79 | self.kernel_size = kernel_size 80 | 81 | 82 | if sparse: 83 | self.d_tensor = SparseTensor([n_output_maps,n_input_maps // groups, kernel_size, kernel_size],initial_sparsity = initial_sparsity,sub_kernel_granularity = sub_kernel_granularity) 84 | else: 85 | self.d_tensor = TiedTensor([n_output_maps,n_input_maps // groups, kernel_size, kernel_size],initial_sparsity = initial_sparsity,sub_kernel_granularity = sub_kernel_granularity) 86 | 87 | if bias: 88 | self.bias = Parameter(torch.Tensor(n_output_maps)) 89 | else: 90 | self.bias = None 91 | 92 | self.groups = groups 93 | self.stride = (stride,) * 2 94 | self.padding = (padding,) * 2 95 | self.dilation = (dilation,) * 2 96 | 97 | self.init_parameters() 98 | 99 | # self.weight = self.d_tensor.s_tensor 100 | 101 | def init_parameters(self): 102 | if self.bias is not None: 103 | self.bias.data.zero_() 104 | self.d_tensor.init_parameters() 105 | 106 | 107 | def forward(self, input): 108 | return F.conv2d(input, self.d_tensor(), self.bias, self.stride, self.padding, self.dilation, 109 | self.groups) 110 | 111 | def extra_repr(self): 112 | s = ('{name}({n_input_maps}, {n_output_maps}, kernel_size={kernel_size}, bias = {bias_exists}' 113 | ', stride={stride}') 114 | if self.padding != (0,) * len(self.padding): 115 | s += ', padding={padding}' 116 | if self.dilation != (1,) * len(self.dilation): 117 | s += ', dilation={dilation}' 118 | if self.groups != 1: 119 | s += 
', groups={groups}' 120 | s += ')' 121 | return s.format( 122 | name=self.__class__.__name__,bias_exists = self.bias is not None,**self.__dict__) 123 | 124 | -------------------------------------------------------------------------------- /imagenet/baseline/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | -------------------------------------------------------------------------------- /imagenet/baseline/run_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | IMAGENET_PATH="PATH TO IMAGENET FOLDER" 3 | 4 | # Pick a density 5 | DENSITY=0.25626087 6 | #DENSITY=0.15626087 7 | 8 | CUDA_VISIBLE_DEVICES=0,1,2,3 OMP_NUM_THREADS=1 python main.py --model imagenet_resnet50 --schedule-file ./learning_schedules/resnet_schedule.yaml --initial-sparsity-fc 0.0 --initial-sparsity-conv 0.0 --batch-size 256 --weight-decay 1.0e-4 --no-validate-train --sub-kernel-granularity 4 --widen-factor 1 --job-idx 13 --sparse-momentum --data $IMAGENET_PATH --verbose --density $DENSITY --prune-rate 0.2 --workers 32 --fp16 -p 300 9 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/LICENSE: -------------------------------------------------------------------------------- 1 | All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/README.md: -------------------------------------------------------------------------------- 1 | # ResNet50 v1.5 2 | 3 | ## Original source 4 | 5 | The source code is a minimal adaptation from the [NVIDIA ResNet50 v1.5 repo](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5). I just added my sparse learning wrapper and modified the execution scripts in the examples folder. 6 | 7 | ## The model 8 | The ResNet50 v1.5 model is a modified version of the [original ResNet50 v1 model](https://arxiv.org/abs/1512.03385).
9 | 10 | The difference between v1 and v1.5 is that, in the bottleneck blocks that require 11 | downsampling, v1 has stride = 2 in the first 1x1 convolution, whereas v1.5 has stride = 2 in the 3x3 convolution. 12 | 13 | This difference makes ResNet50 v1.5 slightly more accurate (~0.5% top1) than v1, but comes with a small performance drawback (~5% imgs/sec). 14 | 15 | The model is initialized as described in [Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification](https://arxiv.org/pdf/1502.01852.pdf) 16 | 17 | ## Training procedure 18 | 19 | ### Optimizer 20 | 21 | This model trains for 90 epochs with the standard ResNet v1.5 setup (a minimal PyTorch sketch of this recipe appears at the end of this section): 22 | 23 | * SGD with momentum (0.9) 24 | 25 | * Learning rate = 0.1 for 256 batch size; for other batch sizes we linearly 26 | scale the learning rate. 27 | 28 | * Learning rate decay - multiply by 0.1 after 30, 60, and 80 epochs 29 | 30 | * For bigger batch sizes (512 and up) we use linear warmup of the learning rate 31 | during the first 5 epochs, 32 | according to [Training ImageNet in 1 hour](https://arxiv.org/abs/1706.02677). 33 | 34 | * Weight decay: 1e-4 35 | 36 | * We do not apply WD on Batch Norm trainable parameters (gamma/bias) 37 | 38 | * Label Smoothing: 0.1 39 | 40 | 41 | ### Data Augmentation 42 | 43 | This model uses the following data augmentation: 44 | 45 | * For training: 46 | * Normalization 47 | * Random resized crop to 224x224 48 | * Scale from 8% to 100% 49 | * Aspect ratio from 3/4 to 4/3 50 | * Random horizontal flip 51 | 52 | * For inference: 53 | * Normalization 54 | * Scale to 256x256 55 | * Center crop to 224x224 56 | 57 | ### Other training recipes 58 | 59 | This script does not target any specific benchmark. 60 | There are changes that others have made which can speed up convergence and/or increase accuracy. 61 | 62 | One of the more popular training recipes is provided by [fast.ai](https://github.com/fastai/imagenet-fast). 63 | 64 | The fast.ai recipe introduces many changes to the training procedure, one of which is progressive resizing of the training images. 65 | 66 | The first part of training uses 128px images, the middle part uses 224px images, and the last part uses 288px images. 67 | The final validation is performed on 288px images. 68 | 69 | The training script in this repository performs validation on 224px images, just as the original paper describes. 70 | 71 | These two approaches can't be directly compared, since the fast.ai recipe requires validation on 288px images, 72 | and this recipe keeps the original assumption that validation is done on 224px images. 73 | 74 | Using 288px images means that a lot more FLOPs are needed during inference to reach the same accuracy.
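As a rough illustration, the Optimizer recipe above maps to PyTorch as in the following minimal, hypothetical sketch (the `make_optimizer` name and the `'bn'`-substring heuristic for finding Batch Norm parameters are assumptions for illustration, not this repository's code):

```python
import torch

def make_optimizer(model, batch_size, base_lr=0.1, weight_decay=1e-4):
    # Skip weight decay on Batch Norm gamma/bias; BN parameters are found
    # here by a simple name heuristic (an assumption of this sketch).
    bn_params, other_params = [], []
    for name, param in model.named_parameters():
        (bn_params if 'bn' in name else other_params).append(param)
    # Linear learning-rate scaling: lr = 0.1 at batch size 256.
    lr = base_lr * batch_size / 256
    return torch.optim.SGD(
        [{'params': bn_params, 'weight_decay': 0.0},
         {'params': other_params, 'weight_decay': weight_decay}],
        lr=lr, momentum=0.9)
```

For batch sizes of 512 and up, the warmup described above would additionally ramp `lr` linearly from zero to the scaled value over the first 5 epochs.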
75 | 76 | 77 | # Setup 78 | ## Requirements 79 | 80 | Ensure you meet the following requirements: 81 | 82 | * [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker) 83 | * [PyTorch 18.09-py3 NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch) or newer 84 | * (optional) NVIDIA Volta GPU (see section below) - for best training performance using mixed precision 85 | 86 | For more information about how to get started with NGC containers, see the 87 | following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning 88 | DGX Documentation: 89 | * [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html) 90 | * [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/dgx/user-guide/index.html#accessing_registry) 91 | * [Running PyTorch](https://docs.nvidia.com/deeplearning/dgx/pytorch-release-notes/running.html#running) 92 | 93 | ## Training using mixed precision with Tensor Cores 94 | 95 | ### Hardware requirements 96 | Training with mixed precision on NVIDIA Tensor Cores requires an 97 | [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)-based GPU. 98 | 99 | ### Software changes 100 | 101 | For information about how to train using mixed precision, see the 102 | [Mixed Precision Training paper](https://arxiv.org/abs/1710.03740) 103 | and 104 | [Training With Mixed Precision documentation](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html). 105 | 106 | For PyTorch, mixed-precision support is easily added via NVIDIA's 107 | [APEX](https://github.com/NVIDIA/apex), a PyTorch extension that contains 108 | utility libraries, such as AMP, which require minimal network code changes to 109 | leverage Tensor Core performance. 110 | 111 | # Quick start guide 112 | 113 | ## Getting the data 114 | 115 | The ResNet50 v1.5 script operates on ImageNet 1k, a widely popular image classification dataset from the ILSVRC challenge. 116 | 117 | PyTorch can work directly on JPEGs; therefore, no offline preprocessing/augmentation is needed. 118 | 119 | 1. Download the images from http://image-net.org/download-images 120 | 121 | 2. Extract the training data: 122 | ```bash 123 | mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train 124 | tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar 125 | find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done 126 | cd .. 127 | ``` 128 | 129 | 3. Extract the validation data and move the images to subfolders: 130 | ```bash 131 | mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar 132 | wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash 133 | ``` 134 | 135 | The directory in which the `train/` and `val/` directories are placed is referred to as `<path to imagenet>` in this document. 136 | 137 | ## Running training 138 | 139 | To run training for a standard configuration (1/4/8 GPUs, FP16/FP32), 140 | run one of the scripts in the `./examples` directory 141 | called `./examples/RN50_{FP16, FP32}_{1, 4, 8}GPU.sh`. 142 | 143 | Ensure imagenet is mounted in the `/data/imagenet` directory.
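Before launching a long run, it can help to sanity-check that the extracted data has the layout the scripts expect. A minimal sketch (assumes `torchvision` is installed, uses the `/data/imagenet` mount point from above, and uses the standard ImageNet normalization constants; the transforms mirror the Data Augmentation section):

```python
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Training-time transforms: random resized crop to 224x224 (scale 8%-100%,
# aspect ratio 3/4-4/3), random horizontal flip, and normalization.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(
    '/data/imagenet/train',
    transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3/4, 4/3)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
print(len(train_dataset.classes))  # expect 1000 classes for ImageNet-1k
```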
144 | 145 | To run a non-standard configuration use: 146 | 147 | * For 1 GPU 148 | * FP32 149 | `python ./main.py --arch resnet50 -c fanin --label-smoothing 0.1 <path to imagenet>` 150 | * FP16 151 | `python ./main.py --arch resnet50 -c fanin --label-smoothing 0.1 --fp16 --static-loss-scale 256 <path to imagenet>` 152 | 153 | * For multiple GPUs 154 | * FP32 155 | `python ./multiproc.py --nproc_per_node 8 ./main.py --arch resnet50 -c fanin --label-smoothing 0.1 <path to imagenet>` 156 | * FP16 157 | `python ./multiproc.py --nproc_per_node 8 ./main.py --arch resnet50 -c fanin --label-smoothing 0.1 --fp16 --static-loss-scale 256 <path to imagenet>` 158 | 159 | Use `python ./main.py -h` to obtain the list of available options in the `main.py` script. 160 | 161 | ## Running inference 162 | 163 | To run inference on a checkpointed model run: 164 | 165 | `python ./main.py --arch resnet50 --evaluate --resume <path to checkpoint> -b <batch size> <path to imagenet>` 166 | 167 | ## Benchmarking 168 | 169 | ### Training performance 170 | 171 | To benchmark training, run: 172 | 173 | * For 1 GPU 174 | * FP32 175 | `python ./main.py --arch resnet50 --benchmark-training <path to imagenet>` 176 | * FP16 177 | `python ./main.py --arch resnet50 --benchmark-training --fp16 --static-loss-scale 256 <path to imagenet>` 178 | * For multiple GPUs 179 | * FP32 180 | `python ./multiproc.py --nproc_per_node 8 ./main.py --arch resnet50 --benchmark-training <path to imagenet>` 181 | * FP16 182 | `python ./multiproc.py --nproc_per_node 8 ./main.py --arch resnet50 --benchmark-training --fp16 --static-loss-scale 256 <path to imagenet>` 183 | 184 | Each of these scripts will run 1 warmup iteration and measure the next 10 iterations. 185 | 186 | To control warmup and benchmark length, use the `--bench-warmup` and `--bench-iterations` flags. 187 | 188 | ### Inference performance 189 | 190 | To benchmark inference, run: 191 | 192 | * FP32 193 | 194 | `python ./main.py --arch resnet50 --benchmark-inference <path to imagenet>` 195 | 196 | * FP16 197 | 198 | `python ./main.py --arch resnet50 --benchmark-inference --fp16 <path to imagenet>` 199 | 200 | Each of these scripts will run 1 warmup iteration and measure the next 10 iterations. 201 | 202 | To control warmup and benchmark length, use the `--bench-warmup` and `--bench-iterations` flags. 203 | 204 | ## Training Accuracy Results 205 | 206 | The following results were obtained by running the `./examples/RN50_{FP16, FP32}_{1, 4, 8}GPU.sh` scripts in 207 | the pytorch-18.09-py3 Docker container on NVIDIA DGX-1 with 8 V100 16G GPUs.
208 | 209 | | **mixed precision top1** | **FP32 top1** | 210 | |:------------------------:|:---------------:| 211 | | 76.71 +/- 0.11 | 76.83 +/- 0.11 | 212 | 213 | | **number of GPUs** | **mixed precision training time** | **FP32 training time** | 214 | |:------------------:|:---------------------------------:|:----------------------:| 215 | | 1 | 45.4h | 89.2h | 216 | | 4 | 13.5h | 25.6h | 217 | | 8 | 8.1h | 13.9h | 218 | 219 | Here are example graphs of FP32 and FP16 training on the 8 GPU configuration: 220 | 221 | ![TrainingLoss](./img/training_loss.png) 222 | 223 | ![TrainingAccuracy](./img/training_accuracy.png) 224 | 225 | ![ValidationAccuracy](./img/validation_accuracy.png) 226 | 227 | 228 | ## Training Performance Results 229 | 230 | | **number of GPUs** | **mixed precision img/s** | **FP32 img/s** | **mixed precision speedup** | **mixed precision weak scaling** | **FP32 weak scaling** | 231 | |:------------------:|:-------------------------:|:--------------:|:---------------------------:|:--------------------------------:|:---------------------:| 232 | | 1 | 747.3 | 363.1 | 2.06 | 1.00 | 1.00 | 233 | | 4 | 2886.9 | 1375.5 | 2.1 | 3.86 | 3.79 | 234 | | 8 | 5815.8 | 2857.9 | 2.03 | 7.78 | 7.87 | 235 | 236 | 237 | ## Inference Performance Results 238 | 239 | | **batch size** | **mixed precision img/s** | **FP32 img/s** | 240 | |:--------------:|:-------------------------:|:--------------:| 241 | | 1 | 131.8 | 134.9 | 242 | | 2 | 248.7 | 260.6 | 243 | | 4 | 486.4 | 425.5 | 244 | | 8 | 908.5 | 783.6 | 245 | | 16 | 1370.6 | 998.9 | 246 | | 32 | 2287.5 | 1092.3 | 247 | | 64 | 2476.2 | 1166.6 | 248 | | 128 | 2615.6 | 1215.6 | 249 | | 256 | 2696.7 | N/A | 250 | 251 | # Changelog 252 | 253 | 1. September 2018 254 | * Initial release 255 | 2. January 2019 256 | * Added options: Label Smoothing, fan-in initialization, and skipping weight decay on batch norm gamma and bias. 257 | 258 | # Known issues 259 | 260 | There are no known issues with this model.
261 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/examples/RN50_FP16_1GPU.sh: -------------------------------------------------------------------------------- 1 | # This script launches ResNet50 training in FP16 on 1 GPU using 96 batch size (96 per GPU) 2 | # Usage ./RN50_FP16_1GPU.sh 3 | 4 | python $1/main.py -j5 -p 500 --arch resnet50 -c fanin --label-smoothing 0.1 -b 96 --lr 0.1 --epochs 90 --fp16 --static-loss-scale 256 --sparse $2 /home/tim/data/imagenet/ 5 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/examples/RN50_FP16_4GPU.sh: -------------------------------------------------------------------------------- 1 | # This script launches ResNet50 training in FP16 on 4 GPUs using 1024 batch size (256 per GPU) 2 | # Usage ./RN50_FP16_4GPU.sh 3 | 4 | python $1/multiproc.py --nproc_per_node 4 $1/main.py -j5 -p 500 --arch resnet50 -c fanin --label-smoothing 0.1 -b 256 --lr 0.4 --warmup 5 --epochs 90 --fp16 --static-loss-scale 256 $2 /home/tim/data/imagenet 5 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/examples/RN50_FP16_8GPU.sh: -------------------------------------------------------------------------------- 1 | # This script launches ResNet50 training in FP16 on 8 GPUs using 2048 batch size (256 per GPU) 2 | # Usage ./RN50_FP16_8GPU.sh 3 | 4 | python $1/multiproc.py --nproc_per_node 8 $1/main.py -j5 -p 500 --arch resnet50 -c fanin --label-smoothing 0.1 -b 256 --lr 0.8 --warmup 5 --epochs 90 --fp16 --static-loss-scale 256 $2 /data/imagenet 5 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/examples/RN50_FP16_EVAL.sh: -------------------------------------------------------------------------------- 1 | # This script evaluates ResNet50 model in FP16 using 256 batch size on 1 GPU 2 | # Usage: ./RN50_FP16_EVAL.sh 3 | 4 | python $1/main.py -j5 -p 100 --arch resnet50 -b 256 --resume $2 --evaluate --fp16 /data/imagenet 5 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/examples/RN50_FP16_INFERENCE_BENCHMARK.sh: -------------------------------------------------------------------------------- 1 | # This script launches ResNet50 inference benchmark in FP16 on 1 GPU with 256 batch size 2 | 3 | python ./main.py -j5 --arch resnet50 -b 256 --fp16 --benchmark-inference /data/imagenet 4 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/examples/RN50_FP32_1GPU.sh: -------------------------------------------------------------------------------- 1 | # This script launches ResNet50 training in FP32 on 1 GPU using 96 batch size (96 per GPU) 2 | # Usage ./RN50_FP32_1GPU.sh 3 | 4 | python $1/main.py -j5 -p 300 --arch resnet50 -c fanin --label-smoothing 0.1 -b 96 --lr 0.05 --epochs 90 --sparse $2 /home/tim/data/imagenet 5 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/examples/RN50_FP32_4GPU.sh: -------------------------------------------------------------------------------- 1 | # This script launches ResNet50 training in FP32 on 4 GPUs using 256 batch size (64 per GPU) 2 | # Usage ./RN50_FP32_4GPU.sh 3 | 4 | python $1/multiproc.py --nproc_per_node 4 $1/main.py -j5 -p 300 --arch resnet50 -c fanin --label-smoothing 0.1 -b 64 --lr 0.2 --warmup 5 --epochs 100 --sparse --density 0.2 --gather-checkpoints $2 5 |
-------------------------------------------------------------------------------- /imagenet/tuned_resnet/examples/RN50_FP32_8GPU.sh: -------------------------------------------------------------------------------- 1 | # This script launches ResNet50 training in FP32 on 8 GPUs using 1024 batch size (128 per GPU) 2 | # Usage ./RN50_FP32_8GPU.sh 3 | 4 | python $1/multiproc.py --nproc_per_node 8 $1/main.py -j5 -p 500 --arch resnet50 -c fanin --label-smoothing 0.1 -b 128 --lr 0.4 --warmup 5 --epochs 90 $2 /data/imagenet 5 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/examples/RN50_FP32_EVAL.sh: -------------------------------------------------------------------------------- 1 | # This script evaluates ResNet50 model in FP32 using 128 batch size on 1 GPU 2 | # Usage: ./RN50_FP32_EVAL.sh 3 | 4 | python $1/main.py -j5 -p 100 --arch resnet50 -b 128 --resume $2 --evaluate /data/imagenet 5 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/examples/RN50_FP32_INFERENCE_BENCHMARK.sh: -------------------------------------------------------------------------------- 1 | # This script launches ResNet50 inference benchmark in FP32 on 1 GPU with 128 batch size 2 | 3 | python ./main.py -j5 --arch resnet50 -b 128 --benchmark-inference /data/imagenet 4 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/img/training_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimDettmers/sparse_learning/e19ff16c475462d50ad247ec2927113617662a4e/imagenet/tuned_resnet/img/training_accuracy.png -------------------------------------------------------------------------------- /imagenet/tuned_resnet/img/training_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimDettmers/sparse_learning/e19ff16c475462d50ad247ec2927113617662a4e/imagenet/tuned_resnet/img/training_loss.png -------------------------------------------------------------------------------- /imagenet/tuned_resnet/img/validation_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimDettmers/sparse_learning/e19ff16c475462d50ad247ec2927113617662a4e/imagenet/tuned_resnet/img/validation_accuracy.png -------------------------------------------------------------------------------- /imagenet/tuned_resnet/multiproc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import subprocess 3 | import os 4 | import socket 5 | import time 6 | from argparse import ArgumentParser, REMAINDER 7 | 8 | import torch 9 | 10 | def parse_args(): 11 | """ 12 | Helper function parsing the command line options 13 | @retval ArgumentParser 14 | """ 15 | parser = ArgumentParser(description="PyTorch distributed training launch " 16 | "helper utility that will spawn up " 17 | "multiple distributed processes") 18 | 19 | # Optional arguments for the launch helper 20 | parser.add_argument("--nnodes", type=int, default=1, 21 | help="The number of nodes to use for distributed " 22 | "training") 23 | parser.add_argument("--node_rank", type=int, default=0, 24 | help="The rank of the node for multi-node distributed " 25 | "training") 26 | parser.add_argument("--nproc_per_node", type=int, default=1, 27 | help="The number of processes to launch on each node, " 28 | "for
GPU training, this is recommended to be set " 29 | "to the number of GPUs in your system so that " 30 | "each process can be bound to a single GPU.") 31 | parser.add_argument("--master_addr", default="127.0.0.1", type=str, 32 | help="Master node (rank 0)'s address, should be either " 33 | "the IP address or the hostname of node 0, for " 34 | "single node multi-proc training, the " 35 | "--master_addr can simply be 127.0.0.1") 36 | parser.add_argument("--master_port", default=29500, type=int, 37 | help="Master node (rank 0)'s free port that needs to " 38 | "be used for communciation during distributed " 39 | "training") 40 | 41 | # positional 42 | parser.add_argument("training_script", type=str, 43 | help="The full path to the single GPU training " 44 | "program/script to be launched in parallel, " 45 | "followed by all the arguments for the " 46 | "training script") 47 | 48 | # rest from the training program 49 | parser.add_argument('training_script_args', nargs=REMAINDER) 50 | return parser.parse_args() 51 | 52 | 53 | def main(): 54 | args = parse_args() 55 | 56 | # world size in terms of number of processes 57 | dist_world_size = args.nproc_per_node * args.nnodes 58 | 59 | # set PyTorch distributed related environmental variables 60 | current_env = os.environ.copy() 61 | current_env["MASTER_ADDR"] = args.master_addr 62 | current_env["MASTER_PORT"] = str(args.master_port) 63 | current_env["WORLD_SIZE"] = str(dist_world_size) 64 | 65 | processes = [] 66 | 67 | for local_rank in range(0, args.nproc_per_node): 68 | # each process's rank 69 | dist_rank = args.nproc_per_node * args.node_rank + local_rank 70 | current_env["RANK"] = str(dist_rank) 71 | 72 | # spawn the processes 73 | cmd = [sys.executable, 74 | "-u", 75 | args.training_script, 76 | "--local_rank={}".format(local_rank)] + args.training_script_args 77 | 78 | print(cmd) 79 | 80 | stdout = None if local_rank == 0 else open("GPU_"+str(local_rank)+".log", "w") 81 | 82 | process = subprocess.Popen(cmd, env=current_env, stdout=stdout) 83 | processes.append(process) 84 | 85 | try: 86 | up = True 87 | error = False 88 | while up and not error: 89 | up = False 90 | for p in processes: 91 | ret = p.poll() 92 | if ret is None: 93 | up = True 94 | elif ret != 0: 95 | error = True 96 | time.sleep(1) 97 | 98 | if error: 99 | for p in processes: 100 | if p.poll() is None: 101 | p.terminate() 102 | exit(1) 103 | 104 | except KeyboardInterrupt: 105 | for p in processes: 106 | p.terminate() 107 | raise 108 | except SystemExit: 109 | for p in processes: 110 | p.terminate() 111 | raise 112 | except: 113 | for p in processes: 114 | p.terminate() 115 | raise 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | __all__ = ['ResNet', 'build_resnet', 'resnet_versions', 'resnet_configs'] 6 | 7 | # ResNetBuilder {{{ 8 | 9 | class ResNetBuilder(object): 10 | def __init__(self, config): 11 | self.config = config 12 | 13 | def conv(self, kernel_size, in_planes, out_planes, stride=1): 14 | if kernel_size == 3: 15 | conv = self.config['conv']( 16 | in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | elif kernel_size == 1: 19 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 20 | bias=False) 21 | elif kernel_size == 7: 22 | conv = 
nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride, 23 | padding=3, bias=False) 24 | else: 25 | return None 26 | 27 | nn.init.kaiming_normal_(conv.weight, 28 | mode=self.config['conv_init'], 29 | nonlinearity='relu') 30 | 31 | return conv 32 | 33 | def conv3x3(self, in_planes, out_planes, stride=1): 34 | """3x3 convolution with padding""" 35 | c = self.conv(3, in_planes, out_planes, stride=stride) 36 | return c 37 | 38 | def conv1x1(self, in_planes, out_planes, stride=1): 39 | """1x1 convolution with padding""" 40 | c = self.conv(1, in_planes, out_planes, stride=stride) 41 | return c 42 | 43 | def conv7x7(self, in_planes, out_planes, stride=1): 44 | """7x7 convolution with padding""" 45 | c = self.conv(7, in_planes, out_planes, stride=stride) 46 | return c 47 | 48 | def batchnorm(self, planes): 49 | bn = nn.BatchNorm2d(planes) 50 | nn.init.constant_(bn.weight, 1) 51 | nn.init.constant_(bn.bias, 0) 52 | 53 | return bn 54 | 55 | def activation(self): 56 | return nn.ReLU(inplace=True) 57 | 58 | # ResNetBuilder }}} 59 | 60 | # BasicBlock {{{ 61 | class BasicBlock(nn.Module): 62 | expansion = 1 63 | 64 | def __init__(self, builder, inplanes, planes, stride=1, downsample=None): 65 | super(BasicBlock, self).__init__() 66 | self.conv1 = builder.conv3x3(inplanes, planes, stride) 67 | self.bn1 = builder.batchnorm(planes) 68 | self.relu = builder.activation() 69 | self.conv2 = builder.conv3x3(planes, planes) 70 | self.bn2 = builder.batchnorm(planes) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | if self.bn1 is not None: 79 | out = self.bn1(out) 80 | 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | if self.bn2 is not None: 85 | out = self.bn2(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | # BasicBlock }}} 95 | 96 | # Bottleneck {{{ 97 | class Bottleneck(nn.Module): 98 | expansion = 4 99 | 100 | def __init__(self, builder, inplanes, planes, stride=1, downsample=None): 101 | super(Bottleneck, self).__init__() 102 | self.conv1 = builder.conv1x1(inplanes, planes) 103 | self.bn1 = builder.batchnorm(planes) 104 | self.conv2 = builder.conv3x3(planes, planes, stride=stride) 105 | self.bn2 = builder.batchnorm(planes) 106 | self.conv3 = builder.conv1x1(planes, planes * self.expansion) 107 | self.bn3 = builder.batchnorm(planes * self.expansion) 108 | self.relu = builder.activation() 109 | self.downsample = downsample 110 | self.stride = stride 111 | 112 | def forward(self, x): 113 | residual = x 114 | 115 | out = self.conv1(x) 116 | if self.bn1 is not None: 117 | out = self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | if self.bn2 is not None: 122 | out = self.bn2(out) 123 | out = self.relu(out) 124 | 125 | out = self.conv3(out) 126 | if self.bn3 is not None: 127 | out = self.bn3(out) 128 | 129 | if self.downsample is not None: 130 | residual = self.downsample(x) 131 | 132 | out += residual 133 | 134 | out = self.relu(out) 135 | 136 | return out 137 | # Bottleneck }}} 138 | 139 | # ResNet {{{ 140 | class ResNet(nn.Module): 141 | def __init__(self, builder, block, layers, num_classes=1000): 142 | self.inplanes = 64 143 | super(ResNet, self).__init__() 144 | self.conv1 = builder.conv7x7(3, 64, stride=2) 145 | self.bn1 = builder.batchnorm(64) 146 | self.relu = builder.activation() 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 
| self.layer1 = self._make_layer(builder, block, 64, layers[0]) 149 | self.layer2 = self._make_layer(builder, block, 128, layers[1], stride=2) 150 | self.layer3 = self._make_layer(builder, block, 256, layers[2], stride=2) 151 | self.layer4 = self._make_layer(builder, block, 512, layers[3], stride=2) 152 | self.avgpool = nn.AdaptiveAvgPool2d(1) 153 | self.fc = nn.Linear(512 * block.expansion, num_classes) 154 | 155 | def _make_layer(self, builder, block, planes, blocks, stride=1): 156 | downsample = None 157 | if stride != 1 or self.inplanes != planes * block.expansion: 158 | dconv = builder.conv1x1(self.inplanes, planes * block.expansion, 159 | stride=stride) 160 | dbn = builder.batchnorm(planes * block.expansion) 161 | if dbn is not None: 162 | downsample = nn.Sequential(dconv, dbn) 163 | else: 164 | downsample = dconv 165 | 166 | layers = [] 167 | layers.append(block(builder, self.inplanes, planes, stride, downsample)) 168 | self.inplanes = planes * block.expansion 169 | for i in range(1, blocks): 170 | layers.append(block(builder, self.inplanes, planes)) 171 | 172 | return nn.Sequential(*layers) 173 | 174 | def forward(self, x): 175 | x = self.conv1(x) 176 | if self.bn1 is not None: 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | x = self.maxpool(x) 180 | 181 | x = self.layer1(x) 182 | x = self.layer2(x) 183 | x = self.layer3(x) 184 | x = self.layer4(x) 185 | 186 | x = self.avgpool(x) 187 | x = x.view(x.size(0), -1) 188 | x = self.fc(x) 189 | 190 | return x 191 | # ResNet }}} 192 | 193 | 194 | resnet_configs = { 195 | 'classic' : { 196 | 'conv' : nn.Conv2d, 197 | 'conv_init' : 'fan_out', 198 | }, 199 | 'fanin' : { 200 | 'conv' : nn.Conv2d, 201 | 'conv_init' : 'fan_in', 202 | }, 203 | } 204 | 205 | resnet_versions = { 206 | 'resnet18' : { 207 | 'block' : BasicBlock, 208 | 'layers' : [2, 2, 2, 2], 209 | 'num_classes' : 1000, 210 | }, 211 | 'resnet34' : { 212 | 'block' : BasicBlock, 213 | 'layers' : [3, 4, 6, 3], 214 | 'num_classes' : 1000, 215 | }, 216 | 'resnet50' : { 217 | 'block' : Bottleneck, 218 | 'layers' : [3, 4, 6, 3], 219 | 'num_classes' : 1000, 220 | }, 221 | 'resnet101' : { 222 | 'block' : Bottleneck, 223 | 'layers' : [3, 4, 23, 3], 224 | 'num_classes' : 1000, 225 | }, 226 | 'resnet152' : { 227 | 'block' : Bottleneck, 228 | 'layers' : [3, 8, 36, 3], 229 | 'num_classes' : 1000, 230 | }, 231 | } 232 | 233 | 234 | 235 | def build_resnet(version, config, model_state=None): 236 | version = resnet_versions[version] 237 | config = resnet_configs[config] 238 | 239 | builder = ResNetBuilder(config) 240 | print("Version: {}".format(version)) 241 | print("Config: {}".format(config)) 242 | model = ResNet(builder, 243 | version['block'], 244 | version['layers'], 245 | version['num_classes']) 246 | 247 | return model 248 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/scripts/extract_summary.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from pprint import pprint 5 | 6 | 7 | def parse_arguments(): 8 | parser = argparse.ArgumentParser(description='summary extractor') 9 | parser.add_argument('filename', type=Path, 10 | help='path to logfile') 11 | parser.add_argument('-H', '--human-readable', action='store_const', const=True, default=False, 12 | help='human readable') 13 | parser.add_argument('--csv', action='store_const', const=True, default=False, 14 | help='print in csv format') 15 | return parser.parse_args() 16 | 17 | def 
extract_summary(content): 18 | train_summary = [] 19 | eval_summary = [] 20 | 21 | current_epoch = -1 22 | for line in content.splitlines(): 23 | words = line.split() 24 | if line.startswith('Train '): 25 | epoch = int(words[2]) 26 | loss = float(words[-5]) 27 | top1 = float(words[-3]) 28 | top5 = float(words[-1]) 29 | 30 | current_epoch += 1 31 | assert epoch == current_epoch 32 | 33 | train_summary.append({'loss': loss, 'top1': top1, 'top5': top5}) 34 | 35 | if line.startswith('Eval '): 36 | print(words) 37 | loss = float(words[-5]) 38 | top1 = float(words[-3]) 39 | top5 = float(words[-1]) 40 | 41 | eval_summary.append({'loss': loss, 'top1': top1, 'top5': top5}) 42 | 43 | return train_summary, eval_summary 44 | 45 | def main(args): 46 | with open(str(args.filename)) as file: 47 | content = file.read() 48 | 49 | train_summary, eval_summary = extract_summary(content) 50 | 51 | 52 | if args.human_readable: 53 | print('Train summary:') 54 | pprint(train_summary) 55 | print('Eval summary:') 56 | pprint(eval_summary) 57 | elif args.csv: 58 | print('train_loss', 'train_top1', 'train_top5', 59 | 'eval_loss', 'eval_top1', 'eval_top5', 60 | sep=',', end=',\n') 61 | for summaries in zip(train_summary, eval_summary): 62 | for summary in summaries: 63 | print(summary['loss'], summary['top1'], summary['top5'], sep=',', end=',') 64 | print() 65 | else: 66 | result = {'train': train_summary, 'eval': eval_summary} 67 | print(json.dumps(result)) 68 | 69 | if __name__ == '__main__': 70 | args = parse_arguments() 71 | main(args) 72 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/smoothing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LabelSmoothing(nn.Module): 5 | """ 6 | NLL loss with label smoothing. 7 | """ 8 | def __init__(self, smoothing=0.0): 9 | """ 10 | Constructor for the LabelSmoothing module. 11 | 12 | :param smoothing: label smoothing factor 13 | """ 14 | super(LabelSmoothing, self).__init__() 15 | self.confidence = 1.0 - smoothing 16 | self.smoothing = smoothing 17 | 18 | def forward(self, x, target): 19 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 20 | 21 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 22 | nll_loss = nll_loss.squeeze(1) 23 | smooth_loss = -logprobs.mean(dim=-1) 24 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 25 | return loss.mean() 26 | 27 | -------------------------------------------------------------------------------- /imagenet/tuned_resnet/sparse_momentum_logs.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimDettmers/sparse_learning/e19ff16c475462d50ad247ec2927113617662a4e/imagenet/tuned_resnet/sparse_momentum_logs.tar.gz -------------------------------------------------------------------------------- /mnist_cifar/extensions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | # Through the masking variable we have access to the following variables/statistics. 
5 | '''
6 | Access to optimizer:
7 |     masking.optimizer
8 |
9 | Access to momentum/Adam update:
10 |     masking.get_momentum_for_weight(weight)
11 |
12 | Accessible global statistics:
13 |
14 | Layer statistics:
15 |     Non-zero count of layer:
16 |         masking.name2nonzeros[name]
17 |     Zero count of layer:
18 |         masking.name2zeros[name]
19 |     Redistribution proportion:
20 |         masking.name2variance[name]
21 |     Number of items removed through pruning:
22 |         masking.name2removed[name]
23 |
24 | Network statistics:
25 |     Total number of nonzero parameters in the network:
26 |         masking.total_nonzero = 0
27 |     Total number of zero-valued parameters in the network:
28 |         masking.total_zero = 0
29 |     Total number of parameters removed in pruning:
30 |         masking.total_removed = 0
31 | '''
32 |
33 | def your_redistribution(masking, name, weight, mask):
34 |     '''
35 |     Returns:
36 |         Layer importance    The unnormalized layer importance statistic
37 |                             for the layer "name". A higher value indicates
38 |                             that more pruned parameters are redistributed
39 |                             to this layer compared to layers with lower value.
40 |                             The values will be automatically sum-normalized
41 |                             after this step.
42 |     '''
43 |     return layer_importance
44 |
45 | #===========================================================#
46 |                         # EXAMPLE #
47 | #===========================================================#
48 | def variance_redistribution(masking, name, weight, mask):
49 |     '''Return the mean variance of existing weights.
50 |
51 |     Higher gradient variance means a layer does not have enough
52 |     capacity to model the inputs with the current number of weights.
53 |     Thus we want to add more weights if the gradient variance is high.
54 |     If the variance of the gradient stabilizes, this means
55 |     that some weights might be useless/not needed.
56 |     '''
57 |     # Adam calculates the running average of the sum of squares for us.
58 |     # This is similar to RMSProp.
59 |     if 'exp_avg_sq' not in masking.optimizer.state[weight]:
60 |         print('Variance redistribution requires the adam optimizer to be run!')
61 |         raise Exception('Variance redistribution requires the adam optimizer to be run!')
62 |     iv_adam_sumsq = torch.sqrt(masking.optimizer.state[weight]['exp_avg_sq'])
63 |
64 |     layer_importance = iv_adam_sumsq[mask.bool()].mean().item()
65 |     return layer_importance
66 |
67 |
68 | def your_pruning(masking, mask, weight, name):
69 |     """Returns:
70 |         mask    Pruned binary mask where 1s indicate active
71 |                 weights. Can be modified in-place or newly
72 |                 constructed.
73 |     """
74 |     return mask
75 |
76 | #===========================================================#
77 |                         # EXAMPLE #
78 | #===========================================================#
79 | def magnitude_variance_pruning(masking, mask, weight, name):
80 |     ''' Prunes weights which have high gradient variance and low magnitude.
81 |
82 |     Intuition: Weights that are large are important, but there is also a dimension
83 |     of reliability. If a large weight makes a large correct prediction 8/10 times,
84 |     is it better than a medium weight which makes a correct prediction 10/10 times?
85 |     To test this, we combine magnitude (importance) with reliability (variance of
86 |     the gradient).
87 |
88 |     Good:
89 |         Weights with large magnitude and low gradient variance are the most important.
90 |         Weights with medium variance/magnitude are promising for improving network performance.
91 |     Bad:
92 |         Weights with large magnitude but high gradient variance hurt performance.
93 |         Weights with small magnitude and low gradient variance are useless.
94 |         Weights with small magnitude and high gradient variance cannot learn anything useful.
95 |
96 |     We take the product of the two normalized distributions (rank-equivalent to their geometric mean) to find weights to prune.
97 |     '''
98 |     # Adam calculates the running average of the sum of squares for us.
99 |     # This is similar to RMSProp. We take the inverse of this to rank
100 |     # low-variance gradients higher.
101 |     if 'exp_avg_sq' not in masking.optimizer.state[weight]:
102 |         print('Magnitude variance pruning requires the adam optimizer to be run!')
103 |         raise Exception('Magnitude variance pruning requires the adam optimizer to be run!')
104 |     iv_adam_sumsq = 1./torch.sqrt(masking.optimizer.state[weight]['exp_avg_sq'])
105 |
106 |     num_remove = math.ceil(masking.name2prune_rate[name]*masking.name2nonzeros[name])
107 |
108 |     num_zeros = masking.name2zeros[name]
109 |     k = math.ceil(num_zeros + num_remove)
110 |     if num_remove == 0.0: return weight.data != 0.0
111 |
112 |     max_var = iv_adam_sumsq[mask.bool()].max().item()
113 |     max_magnitude = torch.abs(weight.data[mask.bool()]).max().item()
114 |     product = ((iv_adam_sumsq/max_var)*torch.abs(weight.data)/max_magnitude)*mask
115 |     product[mask==0] = 0.0
116 |
117 |     x, idx = torch.sort(product.view(-1))
118 |     mask.data.view(-1)[idx[:k]] = 0.0
119 |     return mask
120 |
121 |
122 | def your_growth(masking, name, new_mask, total_regrowth, weight):
123 |     '''
124 |     Returns:
125 |         mask    Binary mask with newly grown weights.
126 |                 1s indicate active weights in the binary mask.
127 |     '''
128 |     return new_mask
129 |
130 |
131 | -------------------------------------------------------------------------------- /mnist_cifar/get_results_from_logs.py: --------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | import argparse
5 | import hashlib
6 | import copy
7 | import shlex
8 | import sparselearning
9 | from sparselearning.core import Masking, CosineDecay, LinearDecay
10 |
11 |
12 | from os.path import join
13 |
14 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
15 | parser.add_argument('--all', action='store_true', help='Displays individual final results.')
16 | parser.add_argument('--folder-path', type=str, default=None, help='The folder to evaluate if running in folder mode.')
17 | parser.add_argument('--recursive', action='store_true', help='Apply folder-path mode to all sub-directories')
18 | parser.add_argument('--agg-config', action='store_true', help='Aggregate same configs within folders')
19 | parser.add_argument('--filter', type=str, default='', help='Filters by argument.')
20 |
21 | parser_cmd = argparse.ArgumentParser(description='PyTorch MNIST Example')
22 | parser_cmd.add_argument('--batch-size', type=int, default=100, metavar='N',help='input batch size for training (default: 100)')
23 | parser_cmd.add_argument('--test-batch-size', type=int, default=100, metavar='N',help='input batch size for testing (default: 100)')
24 | parser_cmd.add_argument('--epochs', type=int, default=100, metavar='N',help='number of epochs to train (default: 100)')
25 | parser_cmd.add_argument('--lr', type=float, default=0.1, metavar='LR',help='learning rate (default: 0.1)')
26 | parser_cmd.add_argument('--momentum', type=float, default=0.9, metavar='M',help='SGD momentum (default: 0.9)')
27 | parser_cmd.add_argument('--no-cuda', action='store', default=False,help='disables CUDA training')
28 | parser_cmd.add_argument('--seed', type=int, default=17, metavar='S', help='random seed (default: 17)')
29 | parser_cmd.add_argument('--log-interval', type=int, default=100, metavar='N',help='how many batches to wait before logging training status') 30 | parser_cmd.add_argument('--optimizer', type=str, default='sgd', help='The optimizer to use. Default: sgd. Options: sgd, adam.') 31 | parser_cmd.add_argument('--save-model', type=str, default='./models/model.pt', help='For Saving the current Model') 32 | parser_cmd.add_argument('--data', type=str, default='mnist') 33 | parser_cmd.add_argument('--decay_frequency', type=int, default=25000) 34 | parser_cmd.add_argument('--l1', type=float, default=0.0) 35 | parser_cmd.add_argument('--fp16', action='store', help='Run in fp16 mode.') 36 | parser_cmd.add_argument('--valid_split', type=float, default=0.1) 37 | parser_cmd.add_argument('--resume', type=str) 38 | parser_cmd.add_argument('--start-epoch', type=int, default=1) 39 | parser_cmd.add_argument('--model', type=str, default='') 40 | parser_cmd.add_argument('--l2', type=float, default=5.0e-4) 41 | parser_cmd.add_argument('--iters', type=int, default=1, help='How many times the model should be run after each other. Default=1') 42 | parser_cmd.add_argument('--save-features', action='store', help='Resumes a saved model and saves its feature data to disk for plotting.') 43 | parser_cmd.add_argument('--bench', action='store', help='Enables the benchmarking of layers and estimates sparse speedups') 44 | parser_cmd.add_argument('--max-threads', type=int, default=10, help='How many threads to use for data loading.') 45 | parser_cmd.add_argument('--decay-schedule', type=str, default='cosine', help='The decay schedule for the pruning rate. Default: cosine. Choose from: cosine, linear.') 46 | sparselearning.core.add_sparse_args(parser_cmd) 47 | 48 | 49 | 50 | args = parser.parse_args() 51 | 52 | if args.recursive: 53 | folders = [x[0] for x in os.walk(args.folder_path)] 54 | else: 55 | folders = [args.folder_path if args.folder_path else './logs'] 56 | 57 | def calc_and_print_data(args, accs, losses, arg): 58 | acc_std = np.std(accs, ddof=1) 59 | acc_se = acc_std/np.sqrt(len(accs)) 60 | 61 | loss_std = np.std(losses, ddof=1) 62 | loss_se = loss_std/np.sqrt(len(losses)) 63 | 64 | print('='*85) 65 | print('Test set results logs in folder: {0}'.format(folder)) 66 | print('Arguments:\n{0}\n'.format(arg)) 67 | print('Accuracy. Median: {5:.5f}, Mean: {0:.5f}, Standard Error: {1:.5f}, Sample size: {2}, 95% CI: ({3:.5f},{4:.5f})'.format(np.mean(accs), acc_se, len(accs), 68 | np.mean(accs)-(1.96*acc_se), np.mean(accs)+(1.96*acc_se), np.median(accs))) 69 | print('Error. Median: {5:.5f}, Mean: {0:.5f}, Standard Error: {1:.5f}, Sample size: {2}, 95% CI: ({3:.5f},{4:.5f})'.format(1.0-np.mean(accs), acc_se, len(accs), 70 | (1.0-np.mean(accs))-(1.96*acc_se), (1.0-np.mean(accs))+(1.96*acc_se), 1.0-np.median(accs))) 71 | print('Loss. 
Median: {5:.5f}, Mean: {0:.5f}, Standard Error: {1:.5f}, Sample size: {2}, 95% CI: ({3:.5f},{4:.5f})'.format(np.mean(losses), loss_se, len(losses), 72 | np.mean(losses)-(1.96*loss_se), np.mean(losses)+(1.96*loss_se), np.median(losses))) 73 | print('='*85) 74 | 75 | if args.all: 76 | print('Individual results:') 77 | for loss, acc in zip(losses, accs): 78 | err = 1.0-acc 79 | print('Loss: {0:.5f}, Accuracy: {1:.5f}, Error: {2:.5f}'.format(loss, acc, err)) 80 | 81 | 82 | losses = [] 83 | accs = [] 84 | hash2accs = {} 85 | hash2losses = {} 86 | hash2config = {} 87 | for folder in folders: 88 | losses = [] 89 | accs = [] 90 | for log_name in glob.iglob(join(folder, '*.log')): 91 | if not args.folder_path: 92 | losses = [] 93 | accs = [] 94 | arg = None 95 | skip = False 96 | with open(log_name) as f: 97 | for line in f: 98 | if 'Namespace' in line: 99 | arg = line[10:-2] 100 | if args.agg_config: 101 | if len(args.filter) > 0: 102 | filters = args.filter.split(' ') 103 | for f in filters: 104 | if not f in arg: skip = True 105 | arg = ('--' + arg.replace(', ', ' --')) 106 | arg = arg.replace('dense=False', 'dense') 107 | arg = arg.replace('verbose=False', 'verbose') 108 | arg = arg.replace('verbose=True', 'verbose') 109 | arg = arg.replace('_', '-') 110 | arg = arg.replace('decay-', 'decay_') 111 | arg = arg.replace('valid-', 'valid_') 112 | arg = arg.replace('decay_schedule', 'decay-schedule') 113 | 114 | arg_str = shlex.split(arg) 115 | cmd_args = parser_cmd.parse_args(arg_str) 116 | args_copy = copy.deepcopy(cmd_args) 117 | args_copy.iters = 1 118 | args_copy.verbose = False 119 | args_copy.log_interval = 1 120 | args_copy.seed = 0 121 | args_copy.fp16 = False 122 | args_copy.max_threads = 1 123 | 124 | hsval = hashlib.md5(str(args_copy).encode('utf-8')).hexdigest() 125 | if hsval not in hash2accs: 126 | hash2accs[hsval] = [] 127 | hash2losses[hsval] = [] 128 | hash2config[hsval] = str(args_copy) 129 | 130 | if skip: continue 131 | if not line.startswith('Test evaluation'): continue 132 | try: 133 | loss = float(line[31:37]) 134 | acc = float(line[61:-3])/100 135 | except: 136 | print('Could not convert number: {0}'.format(line[31:37])) 137 | 138 | if args.agg_config: 139 | hash2accs[hsval].append(acc) 140 | hash2losses[hsval].append(loss) 141 | 142 | losses.append(loss) 143 | accs.append(acc) 144 | if len(accs) == 0: continue 145 | 146 | if not args.folder_path: 147 | calc_and_print_data(args, accs, losses, arg) 148 | 149 | if args.folder_path and not args.agg_config: 150 | if len(accs) == 0: 151 | print('Test set results logs in folder {0} empty!'.format(folder)) 152 | continue 153 | 154 | calc_and_print_data(args, accs, losses, arg) 155 | 156 | if args.agg_config: 157 | for hsval in hash2accs: 158 | accs = hash2accs[hsval] 159 | losses = hash2losses[hsval] 160 | arg = hash2config[hsval] 161 | 162 | if len(accs) == 0: 163 | continue 164 | 165 | calc_and_print_data(args, accs, losses, arg) 166 | 167 | -------------------------------------------------------------------------------- /mnist_cifar/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import shutil 5 | import time 6 | import argparse 7 | import logging 8 | import hashlib 9 | import copy 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import torch.backends.cudnn as cudnn 15 | import numpy as np 16 | 17 | import sparselearning 18 | from sparselearning.core import Masking, 
CosineDecay, LinearDecay 19 | from sparselearning.models import AlexNet, VGG16, LeNet_300_100, LeNet_5_Caffe, WideResNet 20 | from sparselearning.utils import get_mnist_dataloaders, get_cifar10_dataloaders, plot_class_feature_histograms 21 | 22 | from extensions import magnitude_variance_pruning, variance_redistribution 23 | 24 | cudnn.benchmark = True 25 | cudnn.deterministic = True 26 | 27 | if not os.path.exists('./models'): os.mkdir('./models') 28 | if not os.path.exists('./logs'): os.mkdir('./logs') 29 | logger = None 30 | 31 | models = {} 32 | models['lenet5'] = (LeNet_5_Caffe,[]) 33 | models['lenet300-100'] = (LeNet_300_100,[]) 34 | models['alexnet-s'] = (AlexNet, ['s', 10]) 35 | models['alexnet-b'] = (AlexNet, ['b', 10]) 36 | models['vgg-c'] = (VGG16, ['C', 10]) 37 | models['vgg-d'] = (VGG16, ['D', 10]) 38 | models['vgg-like'] = (VGG16, ['like', 10]) 39 | models['wrn-28-2'] = (WideResNet, [28, 2, 10, 0.3]) 40 | models['wrn-22-8'] = (WideResNet, [22, 8, 10, 0.3]) 41 | models['wrn-16-8'] = (WideResNet, [16, 8, 10, 0.3]) 42 | models['wrn-16-10'] = (WideResNet, [16, 10, 10, 0.3]) 43 | 44 | def setup_logger(args): 45 | global logger 46 | if logger == None: 47 | logger = logging.getLogger() 48 | else: # wish there was a logger.close() 49 | for handler in logger.handlers[:]: # make a copy of the list 50 | logger.removeHandler(handler) 51 | 52 | args_copy = copy.deepcopy(args) 53 | # copy to get a clean hash 54 | # use the same log file hash if iterations or verbose are different 55 | # these flags do not change the results 56 | args_copy.iters = 1 57 | args_copy.verbose = False 58 | args_copy.log_interval = 1 59 | args_copy.seed = 0 60 | 61 | log_path = './logs/{0}_{1}_{2}.log'.format(args.model, args.density, hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8]) 62 | 63 | logger.setLevel(logging.INFO) 64 | formatter = logging.Formatter(fmt='%(asctime)s: %(message)s', datefmt='%H:%M:%S') 65 | 66 | fh = logging.FileHandler(log_path) 67 | fh.setFormatter(formatter) 68 | logger.addHandler(fh) 69 | 70 | def print_and_log(msg): 71 | global logger 72 | print(msg) 73 | logger.info(msg) 74 | 75 | def train(args, model, device, train_loader, optimizer, epoch, lr_scheduler, mask=None): 76 | model.train() 77 | for batch_idx, (data, target) in enumerate(train_loader): 78 | if lr_scheduler is not None: lr_scheduler.step() 79 | data, target = data.to(device), target.to(device) 80 | if args.fp16: data = data.half() 81 | optimizer.zero_grad() 82 | output = model(data) 83 | 84 | loss = F.nll_loss(output, target) 85 | 86 | if args.fp16: 87 | optimizer.backward(loss) 88 | else: 89 | loss.backward() 90 | 91 | if mask is not None: mask.step() 92 | else: optimizer.step() 93 | 94 | if batch_idx % args.log_interval == 0: 95 | print_and_log('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 96 | epoch, batch_idx * len(data), len(train_loader)*args.batch_size, 97 | 100. 
* batch_idx / len(train_loader), loss.item())) 98 | 99 | def evaluate(args, model, device, test_loader, is_test_set=False): 100 | model.eval() 101 | test_loss = 0 102 | correct = 0 103 | n = 0 104 | with torch.no_grad(): 105 | for data, target in test_loader: 106 | data, target = data.to(device), target.to(device) 107 | if args.fp16: data = data.half() 108 | model.t = target 109 | output = model(data) 110 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 111 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 112 | correct += pred.eq(target.view_as(pred)).sum().item() 113 | n += target.shape[0] 114 | 115 | test_loss /= float(n) 116 | 117 | print_and_log('\n{}: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format( 118 | 'Test evaluation' if is_test_set else 'Evaluation', 119 | test_loss, correct, n, 100. * correct / float(n))) 120 | return correct / float(n) 121 | 122 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 123 | torch.save(state, filename) 124 | if is_best: 125 | shutil.copyfile(filename, 'model_best.pth.tar') 126 | 127 | def main(): 128 | # Training settings 129 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 130 | parser.add_argument('--batch-size', type=int, default=100, metavar='N', 131 | help='input batch size for training (default: 100)') 132 | parser.add_argument('--test-batch-size', type=int, default=100, metavar='N', 133 | help='input batch size for testing (default: 100)') 134 | parser.add_argument('--epochs', type=int, default=100, metavar='N', 135 | help='number of epochs to train (default: 100)') 136 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', 137 | help='learning rate (default: 0.1)') 138 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 139 | help='SGD momentum (default: 0.9)') 140 | parser.add_argument('--no-cuda', action='store_true', default=False, 141 | help='disables CUDA training') 142 | parser.add_argument('--seed', type=int, default=17, metavar='S', help='random seed (default: 17)') 143 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 144 | help='how many batches to wait before logging training status') 145 | parser.add_argument('--optimizer', type=str, default='sgd', help='The optimizer to use. Default: sgd. Options: sgd, adam.') 146 | parser.add_argument('--save-model', type=str, default='./models/model.pt', help='For Saving the current Model') 147 | parser.add_argument('--data', type=str, default='mnist') 148 | parser.add_argument('--decay_frequency', type=int, default=25000) 149 | parser.add_argument('--l1', type=float, default=0.0) 150 | parser.add_argument('--fp16', action='store_true', help='Run in fp16 mode.') 151 | parser.add_argument('--valid_split', type=float, default=0.1) 152 | parser.add_argument('--resume', type=str) 153 | parser.add_argument('--start-epoch', type=int, default=1) 154 | parser.add_argument('--model', type=str, default='') 155 | parser.add_argument('--l2', type=float, default=5.0e-4) 156 | parser.add_argument('--iters', type=int, default=1, help='How many times the model should be run after each other. 
Default=1') 157 | parser.add_argument('--save-features', action='store_true', help='Resumes a saved model and saves its feature data to disk for plotting.') 158 | parser.add_argument('--bench', action='store_true', help='Enables the benchmarking of layers and estimates sparse speedups') 159 | parser.add_argument('--max-threads', type=int, default=10, help='How many threads to use for data loading.') 160 | parser.add_argument('--decay-schedule', type=str, default='cosine', help='The decay schedule for the pruning rate. Default: cosine. Choose from: cosine, linear.') 161 | sparselearning.core.add_sparse_args(parser) 162 | 163 | args = parser.parse_args() 164 | setup_logger(args) 165 | print_and_log(args) 166 | 167 | if args.fp16: 168 | try: 169 | from apex.fp16_utils import FP16_Optimizer 170 | except: 171 | print('WARNING: apex not installed, ignoring --fp16 option') 172 | args.fp16 = False 173 | 174 | use_cuda = not args.no_cuda and torch.cuda.is_available() 175 | device = torch.device("cuda" if use_cuda else "cpu") 176 | 177 | print_and_log('\n\n') 178 | print_and_log('='*80) 179 | torch.manual_seed(args.seed) 180 | for i in range(args.iters): 181 | print_and_log("\nIteration start: {0}/{1}\n".format(i+1, args.iters)) 182 | 183 | if args.data == 'mnist': 184 | train_loader, valid_loader, test_loader = get_mnist_dataloaders(args, validation_split=args.valid_split) 185 | else: 186 | train_loader, valid_loader, test_loader = get_cifar10_dataloaders(args, args.valid_split, max_threads=args.max_threads) 187 | 188 | if args.model not in models: 189 | print('You need to select an existing model via the --model argument. Available models include: ') 190 | for key in models: 191 | print('\t{0}'.format(key)) 192 | raise Exception('You need to select a model') 193 | else: 194 | cls, cls_args = models[args.model] 195 | model = cls(*(cls_args + [args.save_features, args.bench])).to(device) 196 | print_and_log(model) 197 | print_and_log('='*60) 198 | print_and_log(args.model) 199 | print_and_log('='*60) 200 | 201 | print_and_log('='*60) 202 | print_and_log('Prune mode: {0}'.format(args.prune)) 203 | print_and_log('Growth mode: {0}'.format(args.growth)) 204 | print_and_log('Redistribution mode: {0}'.format(args.redistribution)) 205 | print_and_log('='*60) 206 | 207 | # add custom prune/growth/redisribution here 208 | if args.prune == 'magnitude_variance': 209 | print('Using magnitude-variance pruning. Switching to Adam optimizer...') 210 | args.prune = magnitude_variance_pruning 211 | args.optimizer = 'adam' 212 | if args.redistribution == 'variance': 213 | print('Using variance redistribution. 
Switching to Adam optimizer...') 214 | args.redistribution = variance_redistribution 215 | args.optimizer = 'adam' 216 | 217 | optimizer = None 218 | if args.optimizer == 'sgd': 219 | optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.l2, nesterov=True) 220 | elif args.optimizer == 'adam': 221 | optimizer = optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.l2) 222 | else: 223 | print('Unknown optimizer: {0}'.format(args.optimizer)) 224 | raise Exception('Unknown optimizer.') 225 | 226 | lr_scheduler = optim.lr_scheduler.StepLR(optimizer, args.decay_frequency, gamma=0.1) 227 | 228 | if args.resume: 229 | if os.path.isfile(args.resume): 230 | print_and_log("=> loading checkpoint '{}'".format(args.resume)) 231 | checkpoint = torch.load(args.resume) 232 | args.start_epoch = checkpoint['epoch'] 233 | model.load_state_dict(checkpoint['state_dict']) 234 | optimizer.load_state_dict(checkpoint['optimizer']) 235 | print_and_log("=> loaded checkpoint '{}' (epoch {})" 236 | .format(args.resume, checkpoint['epoch'])) 237 | print_and_log('Testing...') 238 | evaluate(args, model, device, test_loader) 239 | model.feats = [] 240 | model.densities = [] 241 | plot_class_feature_histograms(args, model, device, train_loader, optimizer) 242 | else: 243 | print_and_log("=> no checkpoint found at '{}'".format(args.resume)) 244 | 245 | 246 | if args.fp16: 247 | print('FP16') 248 | optimizer = FP16_Optimizer(optimizer, 249 | static_loss_scale = None, 250 | dynamic_loss_scale = True, 251 | dynamic_loss_args = {'init_scale': 2 ** 16}) 252 | model = model.half() 253 | 254 | mask = None 255 | if not args.dense: 256 | if args.decay_schedule == 'cosine': 257 | decay = CosineDecay(args.prune_rate, len(train_loader)*(args.epochs)) 258 | elif args.decay_schedule == 'linear': 259 | decay = LinearDecay(args.prune_rate, len(train_loader)*(args.epochs)) 260 | mask = Masking(optimizer, decay, prune_rate=args.prune_rate, prune_mode=args.prune, growth_mode=args.growth, redistribution_mode=args.redistribution, 261 | verbose=args.verbose, fp16=args.fp16) 262 | mask.add_module(model, density=args.density) 263 | 264 | for epoch in range(1, args.epochs + 1): 265 | t0 = time.time() 266 | train(args, model, device, train_loader, optimizer, epoch, lr_scheduler, mask) 267 | 268 | if args.valid_split > 0.0: 269 | val_acc = evaluate(args, model, device, valid_loader) 270 | 271 | save_checkpoint({'epoch': epoch + 1, 272 | 'state_dict': model.state_dict(), 273 | 'optimizer' : optimizer.state_dict()}, 274 | is_best=False, filename=args.save_model) 275 | 276 | if not args.dense and epoch < args.epochs: 277 | mask.at_end_of_epoch() 278 | 279 | print_and_log('Current learning rate: {0}. 
Time taken for epoch: {1:.2f} seconds.\n'.format(optimizer.param_groups[0]['lr'], time.time() - t0)) 280 | 281 | evaluate(args, model, device, test_loader, is_test_set=True) 282 | print_and_log("\nIteration end: {0}/{1}\n".format(i+1, args.iters)) 283 | 284 | if __name__ == '__main__': 285 | main() 286 | -------------------------------------------------------------------------------- /mnist_cifar/plot_feature_histograms.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns; sns.set() 2 | import pandas as pd 3 | import numpy as np 4 | import os 5 | from matplotlib import pyplot as plt 6 | from statsmodels.formula.api import ols 7 | from statsmodels.stats.anova import anova_lm 8 | 9 | import scipy.stats as stats 10 | basepath = './results/' 11 | 12 | names = [] 13 | names.append(['AlexNet', '{0}alexnet_dense_feat_data_layer_0.npy'.format(basepath), '{0}alexnet_sparse_feat_data_layer_0.npy'.format(basepath)]) 14 | #names.append(['VGG', 'VGG_dense_feat_data_layer_0.npy', 'VGG_sparse_feat_data_layer_0.npy']) 15 | #names.append(['WRN-16-10', 'WRN-28-2_dense_feat_data_layer_0.npy', 'WRN-28-2_sparse_feat_data_layer_0.npy']) 16 | 17 | # taken from: https://www.marsja.se/three-ways-to-carry-out-2-way-anova-with-python/ 18 | def eta_squared(aov): 19 | aov['eta_sq'] = 'NaN' 20 | aov['eta_sq'] = aov[:-1]['sum_sq']/sum(aov['sum_sq']) 21 | return aov 22 | 23 | def omega_squared(aov): 24 | mse = aov['sum_sq'][-1]/aov['df'][-1] 25 | aov['omega_sq'] = 'NaN' 26 | aov['omega_sq'] = (aov[:-1]['sum_sq']-(aov[:-1]['df']*mse))/(sum(aov['sum_sq'])+mse) 27 | return aov 28 | 29 | layers = [5, 13, 24] 30 | 31 | 32 | for (name, dense_path, sparse_path), max_layers in zip(names, layers): 33 | print(name, dense_path, sparse_path) 34 | densities = np.load(sparse_path.replace('feat_data_layer_0', 'density_data')) 35 | 36 | anova_all = [] 37 | anova_all.append(['', 'y', 'layer_id', 'is_sparse']) 38 | data_id = 1 39 | sparse = [] 40 | dense = [] 41 | for layer_id in range(max_layers): 42 | 43 | dense_data = np.load(dense_path.replace('0', str(layer_id))) 44 | sparse_data = np.load(sparse_path.replace('0', str(layer_id))) 45 | density = densities[layer_id] 46 | 47 | for value in sparse_data: 48 | anova_all.append([data_id, value , layer_id, 1]) 49 | data_id += 1 50 | sparse.append(value) 51 | 52 | for value in dense_data: 53 | anova_all.append([data_id, value , layer_id, 0]) 54 | data_id += 1 55 | dense.append(value) 56 | 57 | 58 | 59 | 60 | 61 | hist, bins = np.histogram(dense_data, bins=np.linspace(0.09, 0.51, 50)) 62 | hist2, bins2 = np.histogram(sparse_data, bins=np.linspace(0.09, 0.51, 50)) 63 | 64 | xlim = np.max([np.max(hist), np.max(hist2)]) 65 | 66 | width = 0.7 * (bins[1] - bins[0]) 67 | center = (bins[:-1] + bins[1:]) / 2 68 | 69 | width2 = 0.7 * (bins2[1] - bins2[0]) 70 | center2 = (bins2[:-1] + bins2[1:]) / 2 71 | 72 | 73 | fig, axes = plt.subplots(ncols=2, sharey=True) 74 | 75 | axes[0].barh(center, hist, align='center', height=width) 76 | axes[0].set(title='Dense') 77 | axes[1].barh(center2, hist2, align='center', height=width2) 78 | axes[1].set(title='Sparse') 79 | axes[0].set_xlim(0, xlim+5) 80 | axes[1].set_xlim(0, xlim+5) 81 | 82 | #axes[0].set_xlabel('Channel Count') 83 | axes[1].set_xlabel('Channel Count', x=0.0) 84 | axes[0].set_ylabel('Class-Specialization') 85 | 86 | axes[0].invert_xaxis() 87 | #axes[0].set(yticks=y, yticklabels=states) 88 | #axes[0].yaxis.tick_right() 89 | 90 | for ax in axes.flat: 91 | ax.margins(0.00) 92 | ax.grid(True) 93 | 
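# Layout note (descriptive comment added for clarity, not from the repo docs):
# the two barh panels share the y-axis and the left panel's x-axis is inverted,
# so the dense and sparse class-specialization histograms read as a single
# back-to-back "population pyramid" over a common bin grid
# (np.linspace(0.09, 0.51, 50)) and a common x-limit.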
94 | 95 | fig.tight_layout(rect=[0, 0.01, 1, 0.97]) 96 | fig.subplots_adjust(wspace=0.0) 97 | title = '{0} Conv2D Layer {1}'.format(name, layer_id+1) 98 | 99 | plt.suptitle(title, x=0.55, y=1.0) 100 | if not os.path.exists('./feat_plots'): 101 | os.mkdir('feat_plots') 102 | plt.savefig('./feat_plots/{0}.png'.format(title)) 103 | #fig.savefig("foo.pdf", bbox_inches='tight') 104 | plt.clf() 105 | 106 | anova_all = np.array(anova_all) 107 | df = pd.DataFrame(data=anova_all[1:,1:],index=anova_all[1:,0].tolist(),columns=anova_all[0,1:].tolist()) 108 | df.colums = ['id', 'y', 'layer_id', 'is_sparse'] 109 | df = df.astype({'y' : 'float32', 'layer_id' : 'int32', 'is_sparse' : 'int32'}) 110 | formula = 'y ~ C(layer_id) + C(is_sparse) + C(layer_id)*C(is_sparse)' 111 | model = ols(formula, df).fit() 112 | aov_table = anova_lm(model, typ=1) 113 | 114 | eta_squared(aov_table) 115 | omega_squared(aov_table) 116 | print(aov_table) 117 | -------------------------------------------------------------------------------- /plot_graphs.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns; sns.set() 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | import numpy as np 5 | import matplotlib.cm as cm 6 | import hashlib 7 | 8 | from scipy import stats 9 | 10 | # diverging color-blind colors taken from: https://github.com/drammock/colorblind/blob/master/colorblind.py 11 | # @author: drmccloy 12 | # Created on Thu Sep 1 17:07:57 2016 13 | # License: MIT License 14 | def diverging_colors(n): 15 | if n < 3: 16 | raise ValueError('Minimum number of diverging colors is 3.') 17 | elif n > 11: 18 | raise ValueError('Maximum number of diverging colors is 11.') 19 | cols = ['#3D52A1', '#3A89C9', '#008BCE', '#77B7E5', '#99C7EC', '#B4DDF7', 20 | '#E6F5FE', '#FFFAD2', '#FFE3AA', '#F9BD7E', '#F5A275', '#ED875E', 21 | '#D03232', '#D24D3E', '#AE1C3E'] 22 | indices = [[4, 7, 10], 23 | [2, 5, 9, 12], 24 | [2, 5, 7, 9, 12], 25 | [1, 4, 6, 8, 10, 13], 26 | [1, 4, 6, 7, 8, 10, 13], 27 | [1, 3, 5, 6, 8, 9, 11, 13], 28 | [1, 3, 5, 6, 7, 8, 9, 11, 13], 29 | [0, 1, 3, 5, 6, 8, 9, 11, 13, 14], 30 | [0, 1, 3, 5, 6, 7, 8, 9, 11, 13, 14]] 31 | return [cols[ix] for ix in indices[n - 3]] 32 | 33 | def sequential_colors(n): 34 | if n < 3: 35 | raise ValueError('Minimum number of sequential colors is 3.') 36 | elif n > 9: 37 | raise ValueError('Maximum number of sequential colors is 9.') 38 | cols = ['#FFFFE5', '#FFFBD5', '#FFF7BC', '#FEE391', '#FED98E', '#FEC44F', 39 | '#FB9A29', '#EC7014', '#D95F0E', '#CC4C02', '#993404', '#8C2D04', 40 | '#662506'] 41 | indices = [[2, 5, 8], 42 | [1, 3, 6, 9], 43 | [1, 3, 6, 8, 10], 44 | [1, 3, 5, 6, 8, 10], 45 | [1, 3, 5, 6, 7, 9, 10], 46 | [0, 2, 3, 5, 6, 7, 9, 10], 47 | [0, 2, 3, 5, 6, 7, 9, 10, 12]] 48 | return [cols[ix] for ix in indices[n - 3]] 49 | 50 | def rainbow_colors(n): 51 | if n < 4: 52 | raise ValueError('Minimum number of rainbow colors is 4.') 53 | elif n > 12: 54 | raise ValueError('Maximum number of rainbow colors is 12.') 55 | c = ['#781C81', '#404096', '#57A3AD', '#529DB7', '#63AD99', '#6DB388', 56 | '#E39C37', '#D92120'] 57 | cols = [[c[1], c[2], '#DEA73A', c[7]], 58 | [c[1], c[3], '#7DB874', c[6], c[7]], 59 | [c[1], '#498CC2', c[4], '#BEBC48', '#E68B33', c[7]], 60 | [c[0], '#3F60AE', '#539EB6', c[5], '#CAB843', '#E78532', c[7]], 61 | [c[0], '#3F56A7', '#4B91C0', '#5FAA9F', '#91BD61', '#D8AF3D', 62 | '#E77C30', c[7]], 63 | [c[0], '#3F4EA1', '#4683C1', c[2], c[5], '#B1BE4E', '#DFA53A', 64 | '#E7742F', c[7]], 65 | [c[0], 
'#3F479B', '#4277BD', c[3], '#62AC9B', '#86BB6A', '#C7B944', 66 | c[6], '#E76D2E', c[7]], 67 | [c[0], c[1], '#416CB7', '#4D95BE', '#5BA7A7', '#6EB387', '#A1BE56', 68 | '#D3B33F', '#E59435', '#E6682D', c[7]], 69 | [c[0], '#413B93', '#4065B1', '#488BC2', '#55A1B1', c[4], '#7FB972', 70 | '#B5BD4C', '#D9AD3C', '#E68E34', '#E6642C', c[7]] 71 | ] 72 | return cols[n - 4] 73 | 74 | colors = rainbow_colors(9) 75 | 76 | def get_name2color(names, n, seed=0): 77 | name2color = {} 78 | names = np.array(names) 79 | rdm = np.random.RandomState(seed) 80 | rdm.shuffle(names) 81 | for i, name in enumerate(names): 82 | name2color[name] = colors[i] 83 | return name2color 84 | 85 | factor=100 86 | 87 | mnist = pd.read_csv('./results/MNIST_sparse_summary.csv') 88 | print(mnist['Sparsity']) 89 | mnist['Sparsity'] *= factor 90 | mnist['Full Dense'] *= factor 91 | mnist['Sparse Momentum'] *= factor 92 | mnist['Dynamic Sparse'] *= factor 93 | mnist['SET'] *= factor 94 | mnist['DEEP-R'] *= factor 95 | mnist['error1'] *= factor 96 | mnist['error2'] *= factor 97 | mnist['error3'] *= factor 98 | mnist['error4'] *= factor 99 | mnist['error5'] *= factor 100 | mnist['Sparsity'] = 100-mnist['Sparsity'] 101 | 102 | #ax = sns.lineplot(x='Sparsity', y='Full Dense',data=mnist, label='Full Dense', palette=sns.color_palette("Paired", n_colors=3)) 103 | #ax = sns.lineplot(x='Sparsity', y='Dynamic Sparse',data=mnist, label='Dynamic Sparse', palette=sns.color_palette("Paired", n_colors=3)) 104 | #ax = sns.lineplot(x='Sparsity', y='Sparse Momentum',data=mnist, label='Sparse Momentum', palette=sns.color_palette("Paired", n_colors=3)) 105 | #ax.invert_xaxis() 106 | #ax.xaxis.set_major_locator(plt.FixedLocator(mnist['Sparsity'])) 107 | 108 | percentile95 = 1.96 109 | # color blind colors; optimized for deuteranopia/protanopia; work less well for tritanopia 110 | orange = np.array([230, 159, 0, 255])/255. 111 | blue = np.array([86, 180, 233, 255])/255. 112 | purple = np.array([73, 0, 146, 255])/255. 113 | yellow = np.array([204, 121, 167, 255])/255. 
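# The error bars below use the normal approximation to a 95% confidence
# interval (a standard formula, not repo-specific):
#   CI_95 = mean +/- 1.96 * SE,   with SE = std / sqrt(n).
# percentile95 above holds the 1.96 multiplier, so e.g. a point at 98.0%
# accuracy with SE 0.1 gets whiskers at (97.804%, 98.196%).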
114 | plt.plot(mnist['Sparsity'], mnist['Full Dense'], color='black') 115 | plt.plot(mnist['Sparsity'], mnist['Dynamic Sparse'], color=blue) 116 | plt.plot(mnist['Sparsity'], mnist['Sparse Momentum'], color=orange) 117 | plt.plot(mnist['Sparsity'], mnist['SET'], color=purple) 118 | plt.plot(mnist['Sparsity'], mnist['DEEP-R'], color=yellow) 119 | plt.legend() 120 | plt.errorbar(mnist['Sparsity'], mnist['Full Dense'], yerr=mnist['error1']*percentile95, fmt='.k', capsize=5, elinewidth=1) 121 | plt.errorbar(mnist['Sparsity'], mnist['Dynamic Sparse'], yerr=mnist['error2']*percentile95, fmt='.k', ecolor=blue, capsize=5, elinewidth=1) 122 | plt.errorbar(mnist['Sparsity'], mnist['Sparse Momentum'], yerr=mnist['error3']*percentile95, fmt='.k', ecolor=orange, capsize=5, elinewidth=1) 123 | plt.errorbar(mnist['Sparsity'], mnist['SET'], yerr=mnist['error4']*percentile95, fmt='.k', ecolor=purple, capsize=5, elinewidth=1) 124 | plt.errorbar(mnist['Sparsity'], mnist['DEEP-R'], yerr=mnist['error5']*percentile95, fmt='.k', ecolor=yellow, capsize=5) 125 | 126 | #plt.yscale('log') 127 | plt.ylim(0.975*factor, 0.990*factor) 128 | plt.xlim(0.00*factor, 0.21*factor) 129 | plt.xticks([1, 2, 3, 4, 5, 10]) 130 | plt.ylabel("Test Accuracy") 131 | plt.xlabel('Weights (%)') 132 | plt.title("LeNet 300-100 on MNIST") 133 | 134 | #plt.show() 135 | plt.clf() 136 | 137 | 138 | 139 | data = pd.read_csv('./results/WRN-28-2_results_summary.csv') 140 | print(data['Sparsity']) 141 | data['Sparsity'] *= factor 142 | data['Full Dense'] /= factor 143 | data['Sparse Momentum'] /= factor 144 | data['Dynamic Sparse'] /= factor 145 | data['SET'] /= factor 146 | data['DEEP-R'] /= factor 147 | data['error1'] /= factor 148 | data['error2'] /= factor 149 | data['error3'] /= factor 150 | data['error4'] /= factor 151 | data['error5'] /= factor 152 | data['Sparsity'] = 100-data['Sparsity'] 153 | 154 | percentile95 = 1.96 155 | plt.plot(data['Sparsity'], data['Full Dense'], color='black') 156 | plt.plot(data['Sparsity'], data['Dynamic Sparse'], color=blue) 157 | plt.plot(data['Sparsity'], data['Sparse Momentum'], color=orange) 158 | plt.plot(data['Sparsity'], data['SET'], color=purple) 159 | plt.plot(data['Sparsity'], data['DEEP-R'], color=yellow) 160 | #plt.legend() 161 | plt.errorbar(data['Sparsity'], data['Full Dense'], yerr=data['error1']*percentile95, fmt='.k', capsize=5) 162 | plt.errorbar(data['Sparsity'], data['Dynamic Sparse'], yerr=data['error2']*percentile95, fmt='.k', ecolor=blue, capsize=5) 163 | plt.errorbar(data['Sparsity'], data['Sparse Momentum'], yerr=data['error3']*percentile95, fmt='.k', ecolor=orange, capsize=5) 164 | plt.errorbar(data['Sparsity'], data['SET'], yerr=data['error4']*percentile95, fmt='.k', ecolor=purple, capsize=5) 165 | plt.errorbar(data['Sparsity'], data['DEEP-R'], yerr=data['error5']*percentile95, fmt='.k', ecolor=yellow, capsize=5) 166 | 167 | plt.ylim(0.927*factor, 0.95*factor) 168 | plt.xlim(0.08*factor, 0.52*factor) 169 | plt.xticks([10, 20, 30, 40, 50]) 170 | plt.ylabel("Test Accuracy") 171 | plt.xlabel('Weights (%)') 172 | plt.title("WRN 28-2 on CIFAR-10") 173 | 174 | #plt.show() 175 | plt.clf() 176 | 177 | data_vgg = pd.read_csv('./results/sensitivity_momentum_vgg-d.csv') 178 | data_alexnet = pd.read_csv('./results/sensitivity_momentum_alexnet-s.csv') 179 | 180 | data_vgg = data_vgg.iloc[1:, :] 181 | data_alexnet = data_alexnet.iloc[1:, :] 182 | 183 | data_vgg.iloc[0:, 1:] *= 100.0 184 | data_alexnet.iloc[0:, 1:] *= 100.0 185 | 186 | data_alexnet.loc[0:, 'sparse SE'] *= 1.96 # 95% confidence 
intervals 187 | data_alexnet.loc[0:, 'dense SE'] *= 1.96 188 | 189 | data_vgg.loc[0:, 'sparse SE'] *= 1.96 # 95% confidence intervals 190 | data_vgg.loc[0:, 'dense SE'] *= 1.96 191 | 192 | print(data_vgg) 193 | print(data_alexnet) 194 | 195 | plt.plot(data_vgg['momentum'], data_vgg['sparse mean'], color='black', label='VGG Sparse momentum') 196 | plt.plot(data_vgg['momentum'], data_vgg['dense mean'], color=orange, label='VGG Dense control') 197 | #plt.plot(data_alexnet['momentum'], data_alexnet['sparse mean'], color=purple, label='AlexNet Sparse momentum') 198 | #plt.plot(data_alexnet['momentum'], data_alexnet['dense mean'], color=yellow, label='AlexNet Dense control') 199 | #plt.legend() 200 | #plt.legend(bbox_to_anchor=(0, 1), loc='center right', ncol=1) 201 | #plt.legend(bbox_to_anchor=(1.04,1), mode='expand', loc="upper left") 202 | #l1 = plt.legend(bbox_to_anchor=(1.04,1), borderaxespad=0) 203 | #l2 = plt.legend(bbox_to_anchor=(1.04,0), loc="lower left", borderaxespad=0) 204 | #l3 = plt.legend(bbox_to_anchor=(1.04,0.5), loc="center left", borderaxespad=0) 205 | #l4 = plt.legend(bbox_to_anchor=(0,-0.40,1,-0.2), loc="lower left", 206 | #mode="expand", borderaxespad=0, ncol=2) 207 | plt.legend() 208 | #l5 = plt.legend(bbox_to_anchor=(1,0), loc="lower right", 209 | #bbox_transform=fig.transFigure, ncol=3) 210 | #l6 = plt.legend(bbox_to_anchor=(0.4,0.8), loc="upper right") 211 | 212 | plt.errorbar(data_vgg['momentum'], data_vgg['sparse mean'], yerr=data_vgg['sparse SE'], fmt='.k', capsize=5) 213 | plt.errorbar(data_vgg['momentum'], data_vgg['dense mean'], yerr=data_vgg['dense SE'], fmt='.k', ecolor=orange, capsize=5) 214 | #plt.errorbar(data_alexnet['momentum'], data_alexnet['sparse mean'], yerr=data_alexnet['sparse SE'], fmt='.k', ecolor=purple, capsize=5) 215 | #plt.errorbar(data_alexnet['momentum'], data_alexnet['dense mean'], yerr=data_alexnet['dense SE'], fmt='.k', ecolor=yellow, capsize=5) 216 | #plt.fill_between(data_vgg['momentum'], data_vgg['sparse mean'] - data_vgg['sparse SE'], data_vgg['sparse mean']+data_vgg['sparse SE'])#, fmt='.k', ecolor=orange, capsize=5) 217 | #plt.fill_between(data_vgg['momentum'], data_vgg['dense mean'] - data_vgg['dense SE'], data_vgg['dense mean']+data_vgg['dense SE'])#, fmt='.k', ecolor=orange, capsize=5) 218 | 219 | #plt.ylim(0.927*factor, 0.95*factor) 220 | plt.xlim(0.49, 0.99) 221 | plt.xticks([0.95, 0.9, 0.8, 0.7, 0.6, 0.5]) 222 | plt.ylabel("Test Error") 223 | plt.xlabel('Momentum') 224 | plt.title("Momentum Parameter Sensitivity") 225 | #plt.subplots_adjust(bottom=-0.7) 226 | plt.tight_layout()#rect=[0,0.0,1.0,1]) 227 | 228 | #plt.show() 229 | plt.clf() 230 | 231 | 232 | 233 | data_alexnet.loc[0:, 'sparse mean'] -= data_alexnet.loc[2, 'sparse mean'] 234 | data_alexnet.loc[0:, 'dense mean'] -= data_alexnet.loc[2, 'dense mean'] 235 | data_vgg.loc[0:, 'sparse mean'] -= data_vgg.loc[2, 'sparse mean'] 236 | data_vgg.loc[0:, 'dense mean'] -= data_vgg.loc[2, 'dense mean'] 237 | 238 | sparse_data = [] 239 | sparse_data += data_vgg.loc[:, 'sparse mean'].tolist() 240 | sparse_data += data_alexnet.loc[:, 'sparse mean'].tolist() 241 | 242 | dense_data = [] 243 | dense_data += data_vgg.loc[:, 'dense mean'].tolist() 244 | dense_data += data_alexnet.loc[:, 'dense mean'].tolist() 245 | dense_data = np.array(dense_data) 246 | 247 | print(stats.levene(sparse_data, dense_data)) 248 | print(stats.normaltest(sparse_data)) 249 | print(stats.normaltest(dense_data)) 250 | print(stats.normaltest(np.log10(dense_data+1-dense_data.min()))) 251 | 
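# Why a Wilcoxon test next (interpretation assumed, not stated in the repo):
# stats.levene checks equality of variances and stats.normaltest is the
# D'Agostino-Pearson normality test; if normality is rejected, the Wilcoxon
# signed-rank test below is the paired, non-parametric alternative to a
# paired t-test on the matched sparse/dense sensitivity results.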
print(stats.wilcoxon(sparse_data, dense_data)) 252 | 253 | data_vgg = pd.read_csv('./results/sensitivity_prune_rate_vgg-d.csv') 254 | data_alexnet = pd.read_csv('./results/sensitivity_prune_rate_alexnet-s.csv') 255 | 256 | data_vgg.iloc[0:, 1:] *= 100.0 257 | data_alexnet.iloc[0:, 1:] *= 100.0 258 | 259 | data_alexnet.loc[0:, 'cosine SE'] *= 1.96 # 95% confidence intervals 260 | data_alexnet.loc[0:, 'linear SE'] *= 1.96 261 | 262 | data_vgg.loc[0:, 'cosine SE'] *= 1.96 # 95% confidence intervals 263 | data_vgg.loc[0:, 'linear SE'] *= 1.96 264 | 265 | plt.plot(data_vgg['prune_rate'], data_vgg['cosine mean'], color='black', label='Cosine annealing') 266 | plt.plot(data_vgg['prune_rate'], data_vgg['linear mean'], color=orange, label='Linear annealing') 267 | plt.legend() 268 | plt.plot(data_alexnet['prune_rate'], data_alexnet['cosine mean'], color='black')#, label='Cosine annealing') 269 | plt.plot(data_alexnet['prune_rate'], data_alexnet['linear mean'], color=orange)#, label='Linear annealing') 270 | plt.annotate('AlexNet-s', xy=(0.25, 13.7), xytext=(0.2, 10), 271 | arrowprops=dict(facecolor='black', shrink=0.05)) 272 | plt.annotate('VGG16-D', xy=(0.45, 7.0), xytext=(0.40, 10), 273 | arrowprops=dict(facecolor='black', shrink=0.05)) 274 | plt.errorbar(data_vgg['prune_rate'], data_vgg['cosine mean'], yerr=data_vgg['cosine SE'], fmt='.k', capsize=5) 275 | plt.errorbar(data_vgg['prune_rate'], data_vgg['linear mean'], yerr=data_vgg['linear SE'], fmt='.k', ecolor=orange, capsize=5) 276 | plt.errorbar(data_alexnet['prune_rate'], data_alexnet['cosine mean'], yerr=data_alexnet['cosine SE'], fmt='.k', capsize=5) 277 | plt.errorbar(data_alexnet['prune_rate'], data_alexnet['linear mean'], yerr=data_alexnet['linear SE'], fmt='.k', ecolor=orange, capsize=5) 278 | 279 | #plt.ylim(0.927*factor, 0.95*factor) 280 | #plt.xlim(0.49, 0.99) 281 | plt.xticks([0.7, 0.8, 0.6, 0.5, 0.4, 0.3, 0.2]) 282 | plt.ylabel("Test Error") 283 | plt.xlabel('Prune Rate') 284 | plt.title("Prune Rate Parameter Sensitivity") 285 | #plt.subplots_adjust(bottom=-0.7) 286 | plt.tight_layout()#rect=[0,0.0,1.0,1]) 287 | 288 | #plt.show() 289 | plt.clf() 290 | 291 | d = pd.read_csv('./results/MNIST_compression_comparison_lenet300-100.csv') 292 | print(d) 293 | 294 | labels = d.loc[:, 'name'].tolist()[1:] 295 | unique = [] 296 | # necessary to get same colors for the same seed 297 | for lbl in labels: 298 | if lbl not in unique: 299 | unique.append(lbl) 300 | labels = unique 301 | 302 | fig, ax = plt.subplots() 303 | 304 | #ax.set_facecolor('white') 305 | x, y = d['density'], d['error'] 306 | i = 0 307 | name2color = get_name2color(labels, len(labels), seed=4) 308 | for lbl in labels: 309 | color = name2color[lbl] 310 | if lbl == 'Sparse Momentum': continue 311 | cond = d['name'] == lbl 312 | plt.plot(x[cond], y[cond], linestyle='none', marker='o', label=lbl, color=color) 313 | 314 | cond = d['name'] == 'Sparse Momentum' 315 | plt.plot(x[cond], y[cond], color=orange, label='Sparse Momentum') 316 | plt.plot([0,9.0], [1.34, 1.34], label='Dense (100% Weights)', color='black') 317 | plt.legend() 318 | plt.errorbar(x[cond], y[cond], yerr=d['sm SE'][cond]*1.96, fmt='.k', capsize=5, ecolor=orange) 319 | plt.errorbar([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1.34]*10, yerr=[0.011*1.96]*10, fmt='.k', capsize=5, ecolor='black') 320 | 321 | names = [\ 322 | 'LeCun 1989', 323 | 'Dong 2017', 324 | 'Carreira-Perpinan 2018', 325 | 'Lee 2019', 326 | 'Ullrich 2017', 327 | 'Guo 2016', 328 | 'Han 2015', 329 | 'Lee 2019', 330 | 'Molchanov 2017', 331 | 'Gomez 2018', 332 
| 'Gomez 2018'] 333 | 334 | diff_pos = [\ 335 | (-0.7, 0.001), 336 | (-0.65, 0.07), 337 | (0, 0), 338 | (0, 0), 339 | (0, 0), 340 | (0.1, -0.13), 341 | (-0.7, 0), 342 | (0, 0), 343 | (-0.65, -0.15), 344 | (-1.0, -0.1), 345 | (-0.0, -0.0)] 346 | 347 | print(len(diff_pos), len(names)) 348 | 349 | print(d) 350 | 351 | i = 0 352 | for name, x, y, diff in zip(d.loc[:, 'author'], d.loc[:, 'density'], d.loc[:, 'error'], diff_pos): 353 | print(name) 354 | if name == 'LeCun 1989': continue 355 | if name == 'Dettmers 2019': continue 356 | #if name == 'Dong 2017': 357 | # ax.annotate(name, xy=(x, y), xytext=(0.5, 1.6), 358 | # arrowprops=dict(color='black', facecolor='black',arrowstyle="-", \ 359 | # connectionstyle="angle3", lw=1), size=10) 360 | # #arrowprops=dict(facecolor='black', shrink=0.01)) 361 | else: 362 | #color = cm.get_cmap(name=name, lut=10) 363 | ax.annotate(name, (x+diff[0]-0.01, y+diff[1]), size=10) 364 | i += 1 365 | plt.ylabel("Test Error") 366 | plt.xlabel('Weights (%)') 367 | plt.title("LeNet 300-100 on MNIST") 368 | #plt.subplots_adjust(bottom=-0.7) 369 | plt.xlim(0.8, 10.5) 370 | plt.tight_layout()#rect=[0,0.0,1.0,1]) 371 | 372 | plt.show() 373 | plt.clf() 374 | 375 | 376 | d = pd.read_csv('./results/MNIST_compression_comparison_lenet5.csv') 377 | print(d) 378 | 379 | d = d.iloc[1:, :] 380 | 381 | labels = set(d.loc[:, 'name'].tolist()) 382 | fig, ax = plt.subplots() 383 | x, y = d['density'], d['error'] 384 | for lbl in labels: 385 | if lbl == 'Sparse Momentum': continue 386 | color = name2color[lbl] 387 | cond = d['name'] == lbl 388 | plt.plot(x[cond], y[cond], linestyle='none', marker='o', label=lbl, color=color) 389 | 390 | cond = d['name'] == 'Sparse Momentum' 391 | plt.plot(x[cond], y[cond], color=orange, label='Sparse Momentum') 392 | plt.plot([0,10.0], [0.58, 0.58], label='Dense (100% Weights)', color='black') 393 | #plt.legend() 394 | plt.errorbar(x[cond], y[cond], yerr=d['sm SE'][cond]*1.96, fmt='.k', capsize=5, ecolor=orange) 395 | plt.errorbar([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [0.58]*10, yerr=[0.01*1.96]*10, fmt='.k', capsize=5, ecolor='black') 396 | 397 | names = [\ 398 | #'LeCun 1989', 399 | 'Dong 2017', 400 | 'Lee 2019', 401 | 'Ullrich 2017', 402 | 'Guo 2016', 403 | 'Han 2015', 404 | 'Lee 2019', 405 | 'Carreira-Perpinan 2018', 406 | 'Molchanov 2017', 407 | 'Gomez 2018', 408 | 'Gomez 2018'] 409 | 410 | diff_pos = [\ 411 | #(-0.7, 0.001), 412 | (-0.5, 0.0), 413 | (0, 0.03), 414 | (0, 0), 415 | (0.1, 0.00), 416 | (-0.7, -0.05), 417 | (0, 0.02), 418 | (0.2, -0.05), 419 | (-0.35, -0.09), 420 | (-1.2, 0.00), 421 | (-1.0, 0.00)] 422 | 423 | print(len(diff_pos), len(names)) 424 | 425 | for name, x, y, diff in zip(d.loc[:, 'author'], d.loc[:, 'density'], d.loc[:, 'error'], diff_pos): 426 | print(name, x, y) 427 | if name == 'Dettmers 2019': continue 428 | 429 | #if name == 'Lee 2018': 430 | # ax.annotate(name, xy=(x, y), xytext=(0.6, 1.2), 431 | # arrowprops=dict(color='black', facecolor='black',arrowstyle="-", \ 432 | # connectionstyle="arc3", lw=1), size=10) 433 | #else: 434 | # ax.annotate(name, (x+diff[0]-0.01, y+diff[1]), size=10) 435 | ax.annotate(name, (x+diff[0]-0.01, y+diff[1]), size=10) 436 | plt.ylabel("Test Error") 437 | plt.xlabel('Weights (%)') 438 | plt.xlim(0.0, 10.5) 439 | plt.title("LeNet-5 Caffe on MNIST") 440 | #plt.subplots_adjust(bottom=-0.7) 441 | plt.tight_layout()#rect=[0,0.0,1.0,1]) 442 | 443 | plt.show() 444 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | -------------------------------------------------------------------------------- /results/MNIST_compression_comparison_lenet300-100.csv: -------------------------------------------------------------------------------- 1 | name,author,density,error,sm SE 2 | Opt. Brain Damage,LeCun 1989,8,2, 3 | Layer-wise Brain Damage,Dong 2017,1.5,2, 4 | Compression via optimization,Carreira-Perpinan 2018,1,3.2, 5 | Single-shot Net. Pruning,Lee 2019,2,2.4, 6 | Soft weight-sharing,Ullrich 2017,4.4,1.9, 7 | Dyn. Network Surgery,Guo 2016,1.8,2, 8 | Learn weights&connections,Han 2015,8.3,1.6, 9 | Single-shot Net. Pruning,Lee 2019,5,1.6, 10 | Variational Dropout,Molchanov 2017,1.5,1.9, 11 | Targeted Dropout,Gomez 2018,5.0,3.13, 12 | Targeted Dropout,Gomez 2018,2.0,3.77, 13 | Sparse Momentum,Dettmers 2019,1,2.36,0.044 14 | Sparse Momentum,Dettmers 2019,1.5,2.053,0.032 15 | Sparse Momentum,Dettmers 2019,1.8,1.984,0.053 16 | Sparse Momentum,Dettmers 2019,2,1.99,0.019 17 | Sparse Momentum,Dettmers 2019,4.4,1.552,0.027 18 | Sparse Momentum,Dettmers 2019,5,1.53,0.02 19 | Sparse Momentum,Dettmers 2019,8,1.447,0.014 20 | Sparse Momentum,Dettmers 2019,8.3,1.426,0.02 21 | -------------------------------------------------------------------------------- /results/MNIST_compression_comparison_lenet5.csv: -------------------------------------------------------------------------------- 1 | name,author,density,error,sm SE 2 | Opt. Brain Damage,LeCun 1989,8,2.7, 3 | Layer-wise Brain Damage,Dong 2017,1,2.1, 4 | Single-shot Net. Pruning,Lee 2019,1,1.1, 5 | Soft weight-sharing,Ullrich 2017,0.5,1, 6 | Dyn. Network Surgery,Guo 2016,0.9,0.9, 7 | Learn weights&connections,Han 2015,9.3,0.8, 8 | Single-shot Net. 
Pruning,Lee 2019,2,0.8, 9 | Compression via optimization,Carreira-Perpinan 2018,1,1.1, 10 | Variational Dropout,Molchanov 2017,0.4,0.8, 11 | Targeted Dropout,Gomez 2018,5.0,1.96, 12 | Targeted Dropout,Gomez 2018,10.0,1.95, 13 | Sparse Momentum,Dettmers 2019,0.4,1.403,0.088 14 | Sparse Momentum,Dettmers 2019,0.5,1.144,0.024 15 | Sparse Momentum,Dettmers 2019,0.9,0.894,0.033 16 | Sparse Momentum,Dettmers 2019,1,0.83,0.04 17 | Sparse Momentum,Dettmers 2019,2,0.76,0.022 18 | Sparse Momentum,Dettmers 2019,5,0.69,0.021 19 | Sparse Momentum,Dettmers 2019,8,0.645,0.019 20 | Sparse Momentum,Dettmers 2019,9.3,0.642,0.03 21 | -------------------------------------------------------------------------------- /results/MNIST_sparse_summary.csv: -------------------------------------------------------------------------------- 1 | Sparsity,Full Dense,error1,Dynamic Sparse,error2,Sparse Momentum,error3,SET,error4,DEEP-R,error5 2 | 0.99,0.9869,0.000106094,0.9754,0.000157,0.9753,0.000469515,0.958,0.000451, 0.94, 0.0001 3 | 0.98,0.9869,0.000106094,0.9814,0.00064,0.982,0.00017966,0.9774096,0.0007068877551,0.9821,0.00135 4 | 0.97,0.9869,0.000106094,0.9820,0.00023,0.983,0.000546148,0.9806024,0.0003380612245,0.9822,0.000451 5 | 0.96,0.9869,0.000106094,0.9823,0.00023,0.9842,0.00023202,0.9809639,0.0003688265306,0.983,0.000574 6 | 0.95,0.9869,0.000106094,0.9832,0.00038,0.9853,0.000329309,0.9814458,,0.9806,0.000451 7 | 0.9,0.9869,0.000106094,0.9844,0.00045,0.9867,0.0002625101,0.9839157,0.0003688265306,0.9819,0.000451 8 | -------------------------------------------------------------------------------- /results/WRN-28-2_results_summary.csv: -------------------------------------------------------------------------------- 1 | Sparsity,Full Dense,error1,Dynamic Sparse,error2,Sparse Momentum,error3,SET,error4,DEEP-R,error5 2 | 0.9,9470,5.078176719,9359,7.8,9361.4,4.68,9328,5.5,9141,5.7 3 | 0.8,9470,5.078176719,9440,5.1,9432.4,5.49,9425,6.6,9236,8.3 4 | 0.7,9470,5.078176719,9457,5.3,9457.6,4.78,9440,8.9,9250,10.0 5 | 0.6,9470,5.078176719,9466,5.2,9467.75,5.95,9455,6.4,9280,3.8 6 | 0.5,9470,5.078176719,9473,4.6,9474.87,4.03,9476,4.0,9292,8.2 7 | -------------------------------------------------------------------------------- /results/dynamic_sparse/calc_confidence_intervals.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | 4 | for path in glob.glob("./results/*.txt"): 5 | print(path) 6 | data = [] 7 | with open(path) as f: 8 | for i, line in enumerate(f): 9 | if i == 0: 10 | sparsity = line.split(',') 11 | sparsity[0] = sparsity[0][sparsity[0].index('[')+1:] 12 | sparsity[-1] = sparsity[-1][:-2] 13 | else: 14 | runs = line.split(' ') 15 | m = np.mean(np.float32(runs)) 16 | std = np.std(np.float32(runs),ddof=1) 17 | se = std/np.sqrt(len(runs)) 18 | data.append((m, se)) 19 | 20 | for s, (m, se) in zip(sparsity, data): 21 | print(s, m, se) 22 | -------------------------------------------------------------------------------- /results/dynamic_sparse/results.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimDettmers/sparse_learning/e19ff16c475462d50ad247ec2927113617662a4e/results/dynamic_sparse/results.zip -------------------------------------------------------------------------------- /results/feature_data.tar.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TimDettmers/sparse_learning/e19ff16c475462d50ad247ec2927113617662a4e/results/feature_data.tar.gz -------------------------------------------------------------------------------- /results/graphs.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns; sns.set() 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | import numpy as np 5 | 6 | mnist_path = './MNIST_sparse_summary.csv' 7 | 8 | factor=100 9 | 10 | mnist = pd.read_csv(mnist_path) 11 | print(mnist['Sparsity']) 12 | mnist['Sparsity'] *= factor 13 | mnist['Full Dense'] *= factor 14 | mnist['Sparse Momentum'] *= factor 15 | mnist['Dynamic Sparse'] *= factor 16 | mnist['SET'] *= factor 17 | mnist['DEEP-R'] *= factor 18 | mnist['error1'] *= factor 19 | mnist['error2'] *= factor 20 | mnist['error3'] *= factor 21 | mnist['error4'] *= factor 22 | mnist['error5'] *= factor 23 | mnist['Sparsity'] = 100-mnist['Sparsity'] 24 | 25 | #ax = sns.lineplot(x='Sparsity', y='Full Dense',data=mnist, label='Full Dense', palette=sns.color_palette("Paired", n_colors=3)) 26 | #ax = sns.lineplot(x='Sparsity', y='Dynamic Sparse',data=mnist, label='Dynamic Sparse', palette=sns.color_palette("Paired", n_colors=3)) 27 | #ax = sns.lineplot(x='Sparsity', y='Sparse Momentum',data=mnist, label='Sparse Momentum', palette=sns.color_palette("Paired", n_colors=3)) 28 | #ax.invert_xaxis() 29 | #ax.xaxis.set_major_locator(plt.FixedLocator(mnist['Sparsity'])) 30 | 31 | percentile95 = 1.96 32 | # color blind colors; optimized for deuteranopia/protanopia; work less well for tritanopia 33 | orange = np.array([230, 159, 0, 255])/255. 34 | blue = np.array([86, 180, 233, 255])/255. 35 | purple = np.array([73, 0, 146, 255])/255. 36 | yellow = np.array([204, 121, 167, 255])/255. 
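# The plotting block below draws one mean-accuracy line per method and then
# overlays 95% confidence intervals as error bars (standard error scaled by
# percentile95 = 1.96, defined above).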
37 | plt.plot(mnist['Sparsity'], mnist['Full Dense'], color='black', label='Full Dense')
38 | plt.plot(mnist['Sparsity'], mnist['Dynamic Sparse'], color=blue, label='Dynamic Sparse')
39 | plt.plot(mnist['Sparsity'], mnist['Sparse Momentum'], color=orange, label='Sparse Momentum')
40 | plt.plot(mnist['Sparsity'], mnist['SET'], color=purple, label='SET')
41 | plt.plot(mnist['Sparsity'], mnist['DEEP-R'], color=yellow, label='DEEP-R')
42 | plt.legend()
43 | plt.errorbar(mnist['Sparsity'], mnist['Full Dense'], yerr=mnist['error1']*percentile95, fmt='.k', capsize=5)
44 | plt.errorbar(mnist['Sparsity'], mnist['Dynamic Sparse'], yerr=mnist['error2']*percentile95, fmt='.k', ecolor=blue, capsize=5)
45 | plt.errorbar(mnist['Sparsity'], mnist['Sparse Momentum'], yerr=mnist['error3']*percentile95, fmt='.k', ecolor=orange, capsize=5)
46 | plt.errorbar(mnist['Sparsity'], mnist['SET'], yerr=mnist['error4']*percentile95, fmt='.k', ecolor=purple, capsize=5)
47 | plt.errorbar(mnist['Sparsity'], mnist['DEEP-R'], yerr=mnist['error5']*percentile95, fmt='.k', ecolor=yellow, capsize=5)
48 |
49 | plt.ylim(0.975*factor, 0.990*factor)
50 | plt.xlim(0.00*factor, 0.11*factor)
51 | plt.xticks([1, 2, 3, 4, 5, 10])
52 | plt.ylabel("Test Accuracy")
53 | plt.xlabel('Weights (%)')
54 | plt.title("LeNet 300-100 on MNIST")
55 |
56 | plt.show()
57 | plt.clf()
58 |
59 |
60 |
61 | data = pd.read_csv('./WRN-28-2_results_summary.csv')
62 | print(data['Sparsity'])
63 | data['Sparsity'] *= factor
64 | data['Full Dense'] /= factor
65 | data['Sparse Momentum'] /= factor
66 | data['Dynamic Sparse'] /= factor
67 | data['SET'] /= factor
68 | data['DEEP-R'] /= factor
69 | data['error1'] /= factor
70 | data['error2'] /= factor
71 | data['error3'] /= factor
72 | data['error4'] /= factor
73 | data['error5'] /= factor
74 | data['Sparsity'] = 100-data['Sparsity']
75 |
76 | percentile95 = 1.96
77 | plt.plot(data['Sparsity'], data['Full Dense'], color='black')
78 | plt.plot(data['Sparsity'], data['Dynamic Sparse'], color=blue)
79 | plt.plot(data['Sparsity'], data['Sparse Momentum'], color=orange)
80 | plt.plot(data['Sparsity'], data['SET'], color=purple)
81 | plt.plot(data['Sparsity'], data['DEEP-R'], color=yellow)
82 | #plt.legend()
83 | plt.errorbar(data['Sparsity'], data['Full Dense'], yerr=data['error1']*percentile95, fmt='.k', capsize=5)
84 | plt.errorbar(data['Sparsity'], data['Dynamic Sparse'], yerr=data['error2']*percentile95, fmt='.k', ecolor=blue, capsize=5)
85 | plt.errorbar(data['Sparsity'], data['Sparse Momentum'], yerr=data['error3']*percentile95, fmt='.k', ecolor=orange, capsize=5)
86 | plt.errorbar(data['Sparsity'], data['SET'], yerr=data['error4']*percentile95, fmt='.k', ecolor=purple, capsize=5)
87 | plt.errorbar(data['Sparsity'], data['DEEP-R'], yerr=data['error5']*percentile95, fmt='.k', ecolor=yellow, capsize=5)
88 |
89 | plt.ylim(0.927*factor, 0.95*factor)
90 | plt.xlim(0.08*factor, 0.52*factor)
91 | plt.xticks([10, 20, 30, 40, 50])
92 | plt.ylabel("Test Accuracy")
93 | plt.xlabel('Weights (%)')
94 | plt.title("WRN 28-2 on CIFAR-10")
95 |
96 | plt.show()
97 | -------------------------------------------------------------------------------- /results/sensitivity_momentum_alexnet-s.csv: -------------------------------------------------------------------------------- 1 | momentum,sparse median,sparse mean,sparse SE,dense median ,dense mean,dense SE 2 | 0.99,0.23465,0.23465,0.00227,0.2298,0.22732,0.00253 3 | 0.95,0.15355,0.15292,0.00115,0.14165,0.14137,0.00138 4 | 0.9,0.145,0.14498,0.00044,0.1277,0.12683,0.00156 5 | 0.8,0.14085,0.14248,0.00154,0.12625,0.12655,0.00057 6 | 0.7,0.1466,0.1469,0.00064,0.12805,0.12793,0.00056 7 |
0.6,0.1495,0.14998,0.00109,0.1289,0.12938,0.00075 8 | 0.5,0.1528,0.15285,0.00083,0.1308,0.13102,0.00072 9 | -------------------------------------------------------------------------------- /results/sensitivity_momentum_vgg-d.csv: -------------------------------------------------------------------------------- 1 | momentum,sparse median,sparse mean,sparse SE,dense median ,dense mean,dense SE 2 | 0.99,0.3908,0.38664,0.08052,0.3101,0.36837,0.06096 3 | 0.95,0.07215,0.07172,0.00094,0.0761,0.0765,0.00082 4 | 0.9,0.0664,0.06637,0.00093,0.0652,0.06557,0.00058 5 | 0.8,0.06525,0.06525,0.0011,0.0642,0.0641,0.00057 6 | 0.7,0.06475,0.0649,0.00049,0.0651,0.06518,0.00061 7 | 0.6,0.0691,0.06903,0.0004,0.06545,0.06522,0.00042 8 | 0.5,0.06805,0.06787,0.00057,0.0658,0.06523,0.00082 9 | -------------------------------------------------------------------------------- /results/sensitivity_prune_rate_alexnet-s.csv: -------------------------------------------------------------------------------- 1 | prune_rate,cosine median,cosine mean,cosine SE,linear median,linear mean,linear SE 2 | 0.7,0.15185,0.14943,0.00188,0.14725,0.1479,0.0019 3 | 0.6,0.1477,0.14758,0.00058,0.1449,0.14615,0.00148 4 | 0.5,0.14475,0.14523,0.00052,0.146,0.1456,0.00089 5 | 0.4,0.14295,0.1428,0.00095,0.1419,0.1424,0.00102 6 | 0.3,0.1439,0.14295,0.00122,0.14325,0.14248,0.00092 7 | 0.2,0.14265,0.14217,0.00133,0.14175,0.14123,0.00118 8 | -------------------------------------------------------------------------------- /results/sensitivity_prune_rate_vgg-d.csv: -------------------------------------------------------------------------------- 1 | prune_rate,cosine median,cosine mean,cosine SE,linear median,linear mean,linear SE 2 | 0.7,0.0665,0.06712,0.00056,0.06625,0.0664,0.00027 3 | 0.6,0.06515,0.06483,0.00043,0.0654,0.06493,0.0005 4 | 0.5,0.0665,0.06673,0.00075,0.066,0.06619,0.00043 5 | 0.4,0.0651,0.06507,0.00079,0.06625,0.06563,0.0012 6 | 0.3,0.0637,0.06412,0.00053,0.06445,0.06443,0.00068 7 | 0.2,0.0639,0.06387,0.0004,0.0653,0.06567,0.00078 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | 4 | # Utility function to read the README file. 5 | # Used for the long_description. It's nice, because now 1) we have a top level 6 | # README file and 2) it's easier to type in the README file than to put a raw 7 | # string in below ... 
8 | def read(fname):
9 | return open(os.path.join(os.path.dirname(__file__), fname)).read()
10 |
11 | setup(
12 | name = "sparselearning",
13 | version = "1.0.0",
14 | author = "Tim Dettmers",
15 | author_email = "dettmers@cs.washington.edu",
16 | description = ("Sparse learning library including sparse momentum algorithm."),
17 | license = "GNU",
18 | keywords = "deep learning, sparse learning",
19 | url = "http://packages.python.org/sparselearning",
20 | packages=['sparselearning'],
21 | long_description=read('README.md'),
22 | classifiers=[
23 | "Development Status :: 3 - Alpha",
24 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
25 | ],
26 | )
27 |
28 |
-------------------------------------------------------------------------------- /sparselearning/__init__.py: --------------------------------------------------------------------------------
1 | import logging
2 | logging.getLogger(__name__).addHandler(logging.NullHandler())
3 |
-------------------------------------------------------------------------------- /sparselearning/funcs.py: --------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | '''
4 | REDISTRIBUTION
5 | '''
6 |
7 | def momentum_redistribution(masking, name, weight, mask):
8 | """Calculates momentum redistribution statistics.
9 |
10 | Args:
11 | masking Masking class with state about current
12 | layers and the entire sparse network.
13 |
14 | name The name of the layer. This can be used to
15 | access layer-specific statistics in the
16 | masking class.
17 |
18 | weight The weight of the respective sparse layer.
19 | This is a torch parameter.
20 |
21 | mask The binary mask. 1s indicate active weights.
22 |
23 | Returns:
24 | Layer Statistic The unnormalized layer statistics
25 | for the layer "name". A higher value indicates
26 | that more pruned parameters are redistributed
27 | to this layer compared to layers with a lower value.
28 | The values will be automatically sum-normalized
29 | after this step.
30 |
31 |
32 | The calculation of redistribution statistics is the first
33 | step in this sparse learning library.
34 | """
35 | grad = masking.get_momentum_for_weight(weight)
36 | mean_magnitude = torch.abs(grad[mask.bool()]).mean().item()
37 | return mean_magnitude
38 |
39 | def magnitude_redistribution(masking, name, weight, mask):
40 | mean_magnitude = torch.abs(weight)[mask.bool()].mean().item()
41 | return mean_magnitude
42 |
43 | def nonzero_redistribution(masking, name, weight, mask):
44 | nonzero = (weight != 0.0).sum().item()
45 | return nonzero
46 |
47 | def no_redistribution(masking, name, weight, mask):
48 | num_params = masking.baseline_nonzero
49 | n = weight.numel()
50 | return n/float(num_params)
51 |
52 |
53 | '''
54 | PRUNE
55 | '''
56 | def magnitude_prune(masking, mask, weight, name):
57 | """Prunes the weights with smallest magnitude.
58 |
59 | The pruning functions in this sparse learning library
60 | work by constructing a binary mask variable "mask"
61 | which prevents gradient flow to weights and also
62 | sets the weights to zero where the binary mask is 0.
63 | Thus 1s in the "mask" variable indicate where the sparse
64 | network has active weights. In this function name
65 | and masking can be used to access global statistics
66 | about the specific layer (name) and the sparse network
67 | as a whole.
68 |
69 | Args:
70 | masking Masking class with state about current
71 | layers and the entire sparse network.
72 |
73 | mask The binary mask. 1s indicate active weights.
74 |
75 | weight The weight of the respective sparse layer.
76 | This is a torch parameter.
77 |
78 | name The name of the layer. This can be used to
79 | access layer-specific statistics in the
80 | masking class.
81 |
82 | Returns:
83 | mask Pruned binary mask where 1s indicate active
84 | weights. Can be modified in-place or newly
85 | constructed.
86 |
87 | Accessible global statistics:
88 |
89 | Layer statistics:
90 | Non-zero count of layer:
91 | masking.name2nonzeros[name]
92 | Zero count of layer:
93 | masking.name2zeros[name]
94 | Redistribution proportion:
95 | masking.name2variance[name]
96 | Number of items removed through pruning:
97 | masking.name2removed[name]
98 |
99 | Network statistics:
100 | Total number of nonzero parameters in the network:
101 | masking.total_nonzero
102 | Total number of zero-valued parameters in the network:
103 | masking.total_zero
104 | Total number of parameters removed in pruning:
105 | masking.total_removed
106 | """
107 | num_remove = math.ceil(masking.name2prune_rate[name]*masking.name2nonzeros[name])
108 | num_zeros = masking.name2zeros[name]
109 | k = math.ceil(num_zeros + num_remove)
110 | if num_remove == 0.0: return weight.data != 0.0
111 |
112 | x, idx = torch.sort(torch.abs(weight.data.view(-1)))
113 | mask.data.view(-1)[idx[:k]] = 0.0
114 | return mask
115 |
116 | def global_magnitude_prune(masking):
117 | prune_rate = 0.0
118 | for name in masking.name2prune_rate:
119 | if name in masking.masks:
120 | prune_rate = masking.name2prune_rate[name]
121 | tokill = math.ceil(prune_rate*masking.baseline_nonzero)
122 | total_removed = 0
123 | prev_removed = 0
124 | while total_removed < tokill*(1.0-masking.tolerance) or (total_removed > tokill*(1.0+masking.tolerance)):
125 | total_removed = 0
126 | for module in masking.modules:
127 | for name, weight in module.named_parameters():
128 | if name not in masking.masks: continue
129 | remain = (torch.abs(weight.data) > masking.prune_threshold).sum().item()
130 | total_removed += masking.name2nonzeros[name] - remain
131 |
132 | if prev_removed == total_removed: break
133 | prev_removed = total_removed
134 | if total_removed > tokill*(1.0+masking.tolerance):
135 | masking.prune_threshold *= 1.0-masking.increment
136 | masking.increment *= 0.99
137 | elif total_removed < tokill*(1.0-masking.tolerance):
138 | masking.prune_threshold *= 1.0+masking.increment
139 | masking.increment *= 0.99
140 |
141 | for module in masking.modules:
142 | for name, weight in module.named_parameters():
143 | if name not in masking.masks: continue
144 | masking.masks[name][:] = torch.abs(weight.data) > masking.prune_threshold
145 |
146 | return int(total_removed)
147 |
148 |
149 | def magnitude_and_negativity_prune(masking, mask, weight, name):
150 | num_remove = math.ceil(masking.name2prune_rate[name]*masking.name2nonzeros[name])
151 | if num_remove == 0.0: return weight.data != 0.0
152 |
153 | num_zeros = masking.name2zeros[name]
154 | k = math.ceil(num_zeros + (num_remove/2.0))
155 |
156 | # remove all weights whose absolute value is smaller than the threshold
157 | x, idx = torch.sort(torch.abs(weight.data.view(-1)))
158 | mask.data.view(-1)[idx[:k]] = 0.0
159 |
160 | # remove the most negative weights (they sort to the front in ascending order)
161 | x, idx = torch.sort(weight.data.view(-1))
162 | mask.data.view(-1)[idx[:math.ceil(num_remove/2.0)]] = 0.0
163 |
164 | return mask
165 |
166 | '''
167 | GROWTH
168 | '''
169 |
170 | def random_growth(masking, name, new_mask, total_regrowth, weight):
171 | n = (new_mask==0).sum().item()
172 | if n == 0: return
new_mask
173 | expected_growth_probability = (total_regrowth/n)
174 | new_weights = torch.rand(new_mask.shape).cuda() < expected_growth_probability
175 | return new_mask.bool() | new_weights
176 |
177 | def momentum_growth(masking, name, new_mask, total_regrowth, weight):
178 | """Grows weights in places where the momentum is largest.
179 |
180 | Growth functions in the sparse learning library work by
181 | changing 0s to 1s in a binary mask which will enable
182 | gradient flow. Regrown weights default to a value of 0 and
183 | this can be changed in this function. The number of parameters
184 | to be regrown is determined by the total_regrowth
185 | parameter. The masking object in conjunction with the name
186 | of the layer enables access to further statistics
187 | and objects that allow more flexibility to implement
188 | custom growth functions.
189 |
190 | Args:
191 | masking Masking class with state about current
192 | layers and the entire sparse network.
193 |
194 | name The name of the layer. This can be used to
195 | access layer-specific statistics in the
196 | masking class.
197 |
198 | new_mask The binary mask. 1s indicate active weights.
199 | This binary mask has already been pruned in the
200 | pruning step that precedes the growth step.
201 |
202 | total_regrowth This variable determines the number of
203 | parameters to regrow in this function.
204 | It is automatically determined by the
205 | redistribution function and algorithms
206 | internal to the sparselearning library.
207 |
208 | weight The weight of the respective sparse layer.
209 | This is a torch parameter.
210 |
211 | Returns:
212 | mask Binary mask with newly grown weights.
213 | 1s indicate active weights in the binary mask.
214 |
215 | Access to optimizer:
216 | masking.optimizer
217 |
218 | Access to momentum/Adam update:
219 | masking.get_momentum_for_weight(weight)
220 |
221 | Accessible global statistics:
222 |
223 | Layer statistics:
224 | Non-zero count of layer:
225 | masking.name2nonzeros[name]
226 | Zero count of layer:
227 | masking.name2zeros[name]
228 | Redistribution proportion:
229 | masking.name2variance[name]
230 | Number of items removed through pruning:
231 | masking.name2removed[name]
232 |
233 | Network statistics:
234 | Total number of nonzero parameters in the network:
235 | masking.total_nonzero
236 | Total number of zero-valued parameters in the network:
237 | masking.total_zero
238 | Total number of parameters removed in pruning:
239 | masking.total_removed
240 | """
241 | grad = masking.get_momentum_for_weight(weight)
242 | if grad.dtype == torch.float16:
243 | grad = grad*(new_mask==0).half()
244 | else:
245 | grad = grad*(new_mask==0).float()
246 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
247 | new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0
248 |
249 | return new_mask
250 |
251 | def momentum_neuron_growth(masking, name, new_mask, total_regrowth, weight):
252 | grad = masking.get_momentum_for_weight(weight)
253 |
254 | M = torch.abs(grad)
255 | if len(M.shape) == 2: sum_dim = [1]
256 | elif len(M.shape) == 4: sum_dim = [1, 2, 3]
257 |
258 | v = M.mean(sum_dim).data
259 | v /= v.sum()
260 |
261 | slots_per_neuron = (new_mask==0).sum(sum_dim)
262 |
263 | M = M*(new_mask==0).float()
264 | for i, fraction in enumerate(v):
265 | neuron_regrowth = math.floor(fraction.item()*total_regrowth)
266 | available = slots_per_neuron[i].item()
267 |
268 | y, idx = torch.sort(M[i].flatten())
269 | if neuron_regrowth > available:
270 | neuron_regrowth = available
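# y is sorted in ascending order, so y[-neuron_regrowth] below is the
# neuron_regrowth-th largest momentum magnitude within this neuron's slice;
# entries strictly above that threshold are regrown. The two guards below
# skip neurons whose threshold is zero or whose regrowth target is below 10.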
271 | # TODO: Work into more stable growth method 272 | threshold = y[-(neuron_regrowth)].item() 273 | if threshold == 0.0: continue 274 | if neuron_regrowth < 10: continue 275 | new_mask[i] = new_mask[i] | (M[i] > threshold) 276 | 277 | return new_mask 278 | 279 | 280 | def global_momentum_growth(masking, total_regrowth): 281 | togrow = total_regrowth 282 | total_grown = 0 283 | last_grown = 0 284 | while total_grown < togrow*(1.0-masking.tolerance) or (total_grown > togrow*(1.0+masking.tolerance)): 285 | total_grown = 0 286 | total_possible = 0 287 | for module in masking.modules: 288 | for name, weight in module.named_parameters(): 289 | if name not in masking.masks: continue 290 | 291 | new_mask = masking.masks[name] 292 | grad = masking.get_momentum_for_weight(weight) 293 | grad = grad*(new_mask==0).float() 294 | possible = (grad !=0.0).sum().item() 295 | total_possible += possible 296 | grown = (torch.abs(grad.data) > masking.growth_threshold).sum().item() 297 | total_grown += grown 298 | if total_grown == last_grown: break 299 | last_grown = total_grown 300 | 301 | 302 | if total_grown > togrow*(1.0+masking.tolerance): 303 | masking.growth_threshold *= 1.02 304 | #masking.growth_increment *= 0.95 305 | elif total_grown < togrow*(1.0-masking.tolerance): 306 | masking.growth_threshold *= 0.98 307 | #masking.growth_increment *= 0.95 308 | 309 | total_new_nonzeros = 0 310 | for module in masking.modules: 311 | for name, weight in module.named_parameters(): 312 | if name not in masking.masks: continue 313 | 314 | new_mask = masking.masks[name] 315 | grad = masking.get_momentum_for_weight(weight) 316 | grad = grad*(new_mask==0).float() 317 | masking.masks[name][:] = (new_mask.bool() | (torch.abs(grad.data) > masking.growth_threshold)).float() 318 | total_new_nonzeros += new_mask.sum().item() 319 | return total_new_nonzeros 320 | 321 | 322 | 323 | 324 | prune_funcs = {} 325 | prune_funcs['magnitude'] = magnitude_prune 326 | prune_funcs['SET'] = magnitude_and_negativity_prune 327 | prune_funcs['global_magnitude'] = global_magnitude_prune 328 | 329 | growth_funcs = {} 330 | growth_funcs['random'] = random_growth 331 | growth_funcs['momentum'] = momentum_growth 332 | growth_funcs['momentum_neuron'] = momentum_neuron_growth 333 | growth_funcs['global_momentum_growth'] = global_momentum_growth 334 | 335 | redistribution_funcs = {} 336 | redistribution_funcs['momentum'] = momentum_redistribution 337 | redistribution_funcs['nonzero'] = nonzero_redistribution 338 | redistribution_funcs['magnitude'] = magnitude_redistribution 339 | redistribution_funcs['none'] = no_redistribution 340 | -------------------------------------------------------------------------------- /sparselearning/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from torchvision import datasets, transforms 7 | 8 | class DatasetSplitter(torch.utils.data.Dataset): 9 | """This splitter makes sure that we always use the same training/validation split""" 10 | def __init__(self,parent_dataset,split_start=-1,split_end= -1): 11 | split_start = split_start if split_start != -1 else 0 12 | split_end = split_end if split_end != -1 else len(parent_dataset) 13 | assert split_start <= len(parent_dataset) - 1 and split_end <= len(parent_dataset) and split_start < split_end , "invalid dataset split" 14 | 15 | self.parent_dataset = parent_dataset 16 | self.split_start = split_start 17 | self.split_end = 
split_end
18 |
19 | def __len__(self):
20 | return self.split_end - self.split_start
21 |
22 |
23 | def __getitem__(self,index):
24 | assert index < len(self),"index out of bounds in split_dataset"
25 | return self.parent_dataset[index + self.split_start]
26 |
27 | def get_cifar10_dataloaders(args, validation_split=0.0, max_threads=10):
28 | """Creates augmented train, validation, and test data loaders."""
29 |
30 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
31 | (0.2023, 0.1994, 0.2010))
32 |
33 | train_transform = transforms.Compose([
34 | transforms.ToTensor(),
35 | transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
36 | (4,4,4,4),mode='reflect').squeeze()),
37 | transforms.ToPILImage(),
38 | transforms.RandomCrop(32),
39 | transforms.RandomHorizontalFlip(),
40 | transforms.ToTensor(),
41 | normalize,
42 | ])
43 |
44 | test_transform = transforms.Compose([
45 | transforms.ToTensor(),
46 | normalize
47 | ])
48 |
49 | full_dataset = datasets.CIFAR10('_dataset', True, train_transform, download=True)
50 | test_dataset = datasets.CIFAR10('_dataset', False, test_transform, download=False)
51 |
52 |
53 | # we need at least two threads
54 | max_threads = 2 if max_threads < 2 else max_threads
55 | if max_threads >= 6:
56 | val_threads = 2
57 | train_threads = max_threads - val_threads
58 | else:
59 | val_threads = 1
60 | train_threads = max_threads - 1
61 |
62 |
63 | valid_loader = None
64 | if validation_split > 0.0:
65 | split = int(np.floor((1.0-validation_split) * len(full_dataset)))
66 | train_dataset = DatasetSplitter(full_dataset,split_end=split)
67 | val_dataset = DatasetSplitter(full_dataset,split_start=split)
68 | train_loader = torch.utils.data.DataLoader(
69 | train_dataset,
70 | args.batch_size,
71 | num_workers=train_threads,
72 | pin_memory=True, shuffle=True)
73 | valid_loader = torch.utils.data.DataLoader(
74 | val_dataset,
75 | args.test_batch_size,
76 | num_workers=val_threads,
77 | pin_memory=True)
78 | else:
79 | train_loader = torch.utils.data.DataLoader(
80 | full_dataset,
81 | args.batch_size,
82 | num_workers=8,
83 | pin_memory=True, shuffle=True)
84 |
85 | print('Train loader length', len(train_loader))
86 |
87 | test_loader = torch.utils.data.DataLoader(
88 | test_dataset,
89 | args.test_batch_size,
90 | shuffle=False,
91 | num_workers=1,
92 | pin_memory=True)
93 |
94 | return train_loader, valid_loader, test_loader
95 |
96 |
97 | def get_mnist_dataloaders(args, validation_split=0.0):
98 | """Creates train, validation, and test data loaders."""
99 | normalize = transforms.Normalize((0.1307,), (0.3081,))
100 | transform = transforms.Compose([transforms.ToTensor(),normalize])
101 |
102 | full_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
103 | test_dataset = datasets.MNIST('../data', train=False, transform=transform)
104 |
105 | dataset_size = len(full_dataset)
106 | indices = list(range(dataset_size))
107 | split = int(np.floor(validation_split * dataset_size))
108 |
109 | valid_loader = None
110 | if validation_split > 0.0:
111 | split = int(np.floor((1.0-validation_split) * len(full_dataset)))
112 | train_dataset = DatasetSplitter(full_dataset,split_end=split)
113 | val_dataset = DatasetSplitter(full_dataset,split_start=split)
114 | train_loader = torch.utils.data.DataLoader(
115 | train_dataset,
116 | args.batch_size,
117 | num_workers=8,
118 | pin_memory=True, shuffle=True)
119 | valid_loader = torch.utils.data.DataLoader(
120 | val_dataset,
121 | args.test_batch_size,
122 | num_workers=2,
123 |
pin_memory=True) 124 | else: 125 | train_loader = torch.utils.data.DataLoader( 126 | full_dataset, 127 | args.batch_size, 128 | num_workers=8, 129 | pin_memory=True, shuffle=True) 130 | 131 | print('Train loader length', len(train_loader)) 132 | 133 | test_loader = torch.utils.data.DataLoader( 134 | test_dataset, 135 | args.test_batch_size, 136 | shuffle=False, 137 | num_workers=1, 138 | pin_memory=True) 139 | 140 | return train_loader, valid_loader, test_loader 141 | 142 | 143 | def plot_class_feature_histograms(args, model, device, test_loader, optimizer): 144 | if not os.path.exists('./results'): os.mkdir('./results') 145 | model.eval() 146 | agg = {} 147 | num_classes = 10 148 | feat_id = 0 149 | sparse = not args.dense 150 | model_name = 'alexnet' 151 | #model_name = 'vgg' 152 | #model_name = 'wrn' 153 | 154 | 155 | densities = None 156 | for batch_idx, (data, target) in enumerate(test_loader): 157 | if batch_idx % 100 == 0: print(batch_idx,'/', len(test_loader)) 158 | with torch.no_grad(): 159 | #if batch_idx == 10: break 160 | data, target = data.to(device), target.to(device) 161 | for cls in range(num_classes): 162 | #print('=='*50) 163 | #print('CLASS {0}'.format(cls)) 164 | model.t = target 165 | sub_data = data[target == cls] 166 | 167 | output = model(sub_data) 168 | 169 | feats = model.feats 170 | if densities is None: 171 | densities = [] 172 | densities += model.densities 173 | 174 | if len(agg) == 0: 175 | for feat_id, feat in enumerate(feats): 176 | agg[feat_id] = [] 177 | #print(feat.shape) 178 | for i in range(feat.shape[1]): 179 | agg[feat_id].append(np.zeros((num_classes,))) 180 | 181 | for feat_id, feat in enumerate(feats): 182 | map_contributions = torch.abs(feat).sum([0, 2, 3]) 183 | for map_id in range(map_contributions.shape[0]): 184 | #print(feat_id, map_id, cls) 185 | #print(len(agg), len(agg[feat_id]), len(agg[feat_id][map_id]), len(feats)) 186 | agg[feat_id][map_id][cls] += map_contributions[map_id].item() 187 | 188 | del model.feats[:] 189 | del model.densities[:] 190 | model.feats = [] 191 | model.densities = [] 192 | 193 | if sparse: 194 | np.save('./results/{0}_sparse_density_data'.format(model_name), densities) 195 | 196 | for feat_id, map_data in agg.items(): 197 | data = np.array(map_data) 198 | #print(feat_id, data) 199 | full_contribution = data.sum() 200 | #print(full_contribution, data) 201 | contribution_per_channel = ((1.0/full_contribution)*data.sum(1)) 202 | #print('pre', data.shape[0]) 203 | channels = data.shape[0] 204 | #data = data[contribution_per_channel > 0.001] 205 | 206 | channel_density = np.cumsum(np.sort(contribution_per_channel)) 207 | print(channel_density) 208 | idx = np.argsort(contribution_per_channel) 209 | 210 | threshold_idx = np.searchsorted(channel_density, 0.05) 211 | print(data.shape, 'pre') 212 | data = data[idx[threshold_idx:]] 213 | print(data.shape, 'post') 214 | 215 | #perc = np.percentile(contribution_per_channel[contribution_per_channel > 0.0], 10) 216 | #print(contribution_per_channel, perc, feat_id) 217 | #data = data[contribution_per_channel > perc] 218 | #print(contribution_per_channel[contribution_per_channel < perc].sum()) 219 | #print('post', data.shape[0]) 220 | normed_data = np.max(data/np.sum(data,1).reshape(-1, 1), 1) 221 | #normed_data = (data/np.sum(data,1).reshape(-1, 1) > 0.2).sum(1) 222 | #counts, bins = np.histogram(normed_data, bins=4, range=(0, 4)) 223 | np.save('./results/{2}_{1}_feat_data_layer_{0}'.format(feat_id, 'sparse' if sparse else 'dense', model_name), normed_data) 224 | #plt.ylim(0, 
channels/2.0) 225 | ##plt.hist(normed_data, bins=range(0, 5)) 226 | #plt.hist(normed_data, bins=[(i+20)/float(200) for i in range(180)]) 227 | #plt.xlim(0.1, 0.5) 228 | #if sparse: 229 | # plt.title("Sparse: Conv2D layer {0}".format(feat_id)) 230 | # plt.savefig('./output/feat_histo/layer_{0}_sp.png'.format(feat_id)) 231 | #else: 232 | # plt.title("Dense: Conv2D layer {0}".format(feat_id)) 233 | # plt.savefig('./output/feat_histo/layer_{0}_d.png'.format(feat_id)) 234 | #plt.clf() 235 | --------------------------------------------------------------------------------
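The prune_funcs, growth_funcs, and redistribution_funcs dictionaries at the end of sparselearning/funcs.py are the extension points for new algorithms (see How_to_add_your_own_algorithms.md). Below is a minimal sketch of a custom growth function, assuming only the signature documented in funcs.py, growth(masking, name, new_mask, total_regrowth, weight) -> new_mask. The name random_exact_growth is illustrative and not part of the library; unlike the built-in random_growth, which flips each empty position with probability total_regrowth/n, this variant regrows an exact number of positions. How a registered name is then selected at runtime depends on the training script's command-line flags.

import torch
from sparselearning.funcs import growth_funcs

def random_exact_growth(masking, name, new_mask, total_regrowth, weight):
    # Flat indices of inactive weights (0s in the binary mask).
    zero_idx = (new_mask.flatten() == 0).nonzero().flatten()
    if zero_idx.numel() == 0: return new_mask
    # Activate exactly total_regrowth inactive positions, chosen uniformly at random.
    perm = torch.randperm(zero_idx.numel(), device=zero_idx.device)
    new_mask.data.view(-1)[zero_idx[perm[:total_regrowth]]] = 1.0
    return new_mask

# Register under a new name alongside the built-in 'random'/'momentum' modes.
growth_funcs['random_exact'] = random_exact_growth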