├── LICENSE ├── NOTICE ├── README.md ├── conda-env.txt ├── experiments │   ├── datasets.py │   ├── experiments.py │   └── models.py ├── pruning │   ├── mask_networks.py │   └── pruning_algos.py ├── train_cifar_tiny_imagenet.py └── train_imagenet.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020-present NAVER Corp. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | --------------------------------------------------------------------------------
26 | 27 | ===== 28 | 29 | kuangliu/pytorch-cifar 30 | https://github.com/kuangliu/pytorch-cifar 31 | 32 | 33 | MIT License 34 | 35 | Copyright (c) 2017 liukuang 36 | 37 | Permission is hereby granted, free of charge, to any person obtaining a copy 38 | of this software and associated documentation files (the "Software"), to deal 39 | in the Software without restriction, including without limitation the rights 40 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 41 | copies of the Software, and to permit persons to whom the Software is 42 | furnished to do so, subject to the following conditions: 43 | 44 | The above copyright notice and this permission notice shall be included in all 45 | copies or substantial portions of the Software. 46 | 47 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 48 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 49 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 50 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 51 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 52 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 53 | SOFTWARE. 54 | 55 | ===== 56 | 57 | pytorch/examples 58 | https://github.com/pytorch/examples 59 | 60 | 61 | BSD 3-Clause License 62 | 63 | Copyright (c) 2017, 64 | All rights reserved. 65 | 66 | Redistribution and use in source and binary forms, with or without 67 | modification, are permitted provided that the following conditions are met: 68 | 69 | * Redistributions of source code must retain the above copyright notice, this 70 | list of conditions and the following disclaimer. 71 | 72 | * Redistributions in binary form must reproduce the above copyright notice, 73 | this list of conditions and the following disclaimer in the documentation 74 | and/or other materials provided with the distribution. 75 | 76 | * Neither the name of the copyright holder nor the names of its 77 | contributors may be used to endorse or promote products derived from 78 | this software without specific prior written permission. 79 | 80 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 81 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 82 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 83 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 84 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 85 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 86 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 87 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 88 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 89 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
90 | 91 | ===== 92 | 93 | alecwangcq/GraSP 94 | https://github.com/alecwangcq/GraSP 95 | 96 | 97 | MIT License 98 | 99 | Copyright (c) 2019 100 | 101 | Permission is hereby granted, free of charge, to any person obtaining a copy 102 | of this software and associated documentation files (the "Software"), to deal 103 | in the Software without restriction, including without limitation the rights 104 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 105 | copies of the Software, and to permit persons to whom the Software is 106 | furnished to do so, subject to the following conditions: 107 | 108 | The above copyright notice and this permission notice shall be included in all 109 | copies or substantial portions of the Software. 110 | 111 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 112 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 113 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 114 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 115 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 116 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 117 | SOFTWARE. 118 | 119 | ===== 120 | 121 | mi-lad/snip 122 | https://github.com/mi-lad/snip 123 | 124 | 125 | 126 | MIT License 127 | 128 | Copyright (c) 2019 Milad Alizadeh 129 | 130 | Permission is hereby granted, free of charge, to any person obtaining a copy 131 | of this software and associated documentation files (the "Software"), to deal 132 | in the Software without restriction, including without limitation the rights 133 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 134 | copies of the Software, and to permit persons to whom the Software is 135 | furnished to do so, subject to the following conditions: 136 | 137 | The above copyright notice and this permission notice shall be included in all 138 | copies or substantial portions of the Software. 139 | 140 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 141 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 142 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 143 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 144 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 145 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 146 | SOFTWARE. 147 | 148 | ===== 149 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FORCE 2 | Official implementation of the FORCE algorithm from *Progressive Skeletonization: Trimming more fat from a network at initialization* (https://arxiv.org/abs/2006.09081) 3 | 4 | ## Requirements 5 | You can create a conda environment with all the necessary libraries with the command: 6 | 7 | conda create --name myenv --file conda-env.txt 8 | 9 | You will need to download the Tiny-Imagenet and Imagenet datasets. For Tiny-Imagenet, modify the path specified in `experiments/datasets.py`. For Imagenet, you will need to specify the dataset path when running the code.
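For reference, the Tiny-Imagenet location is hard-coded as the `root` variable inside `get_tiny_imagenet_train_valid_loader` in `experiments/datasets.py`, and the loaders expect that directory to contain the extracted `tiny-imagenet-200` folder. A minimal sketch of the edit (the path below is a placeholder):

```python
# experiments/datasets.py -- inside get_tiny_imagenet_train_valid_loader
root = '/data/'  # placeholder: must contain tiny-imagenet-200/train and tiny-imagenet-200/val
```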
10 | 11 | ## Examples 12 | 13 | ### CIFAR and Tiny Imagenet datasets 14 | To run an experiment with CIFAR10/100 or Tiny Imagenet datasets, run: 15 | 16 | python train_cifar_tiny_imagenet.py --network_name vgg19 --pruning_factor 0.01 --prune_method 1 --dataset CIFAR10 --num_steps 60 --mode exp --num_batches 1 17 | 18 | Alternatively, you can change the `--dataset` option with `CIFAR100` or `tiny_imagenet`. Bear in mind that for `CIFAR100` you will need to set `--num_batches` to 10 and for `tiny_imagenet` to 20. You may also change the architecture and use `resnet50` for instance. 19 | 20 | ### Imagenet 21 | To run an experiment with Imagenet run: 22 | 23 | python train_imagenet.py /path_to_dataset/ --network-name resnet50 --pruning_factor 0.05 --prune_method 1 --num_steps 60 --mode exp --num_batches 40 --epochs 90 24 | 25 | ## License 26 | 27 | ``` 28 | Copyright (c) 2020-present NAVER Corp. 29 | 30 | 31 | Permission is hereby granted, free of charge, to any person obtaining a copy 32 | of this software and associated documentation files (the "Software"), to deal 33 | in the Software without restriction, including without limitation the rights 34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 35 | copies of the Software, and to permit persons to whom the Software is 36 | furnished to do so, subject to the following conditions: 37 | 38 | The above copyright notice and this permission notice shall be included in 39 | all copies or substantial portions of the Software. 40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 47 | THE SOFTWARE. 
48 | ``` 49 | -------------------------------------------------------------------------------- /conda-env.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-main.tar.bz2 6 | https://repo.anaconda.com/pkgs/main/linux-64/_tflow_select-2.3.0-mkl.conda 7 | https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2020.1.1-0.conda 8 | https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2019.1-144.conda 9 | https://repo.anaconda.com/pkgs/main/linux-64/libgfortran-ng-7.3.0-hdf63c60_0.conda 10 | https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-8.2.0-hdf63c60_1.conda 11 | https://conda.anaconda.org/conda-forge/linux-64/pandoc-2.9.1.1-0.tar.bz2 12 | https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-8.2.0-hdf63c60_1.conda 13 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-2019.1-144.conda 14 | https://conda.anaconda.org/conda-forge/linux-64/blas-1.0-mkl.tar.bz2 15 | https://repo.anaconda.com/pkgs/main/linux-64/c-ares-1.15.0-h7b6447c_1001.conda 16 | https://repo.anaconda.com/pkgs/main/linux-64/expat-2.2.6-he6710b0_0.conda 17 | https://repo.anaconda.com/pkgs/main/linux-64/icu-58.2-h9c2bf20_1.conda 18 | https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9b-h024ee3a_2.conda 19 | https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.2.1-hd88cf55_4.conda 20 | https://conda.anaconda.org/conda-forge/linux-64/libsodium-1.0.17-h516909a_0.tar.bz2 21 | https://repo.anaconda.com/pkgs/main/linux-64/libuuid-1.0.3-h1bed415_2.conda 22 | https://repo.anaconda.com/pkgs/main/linux-64/libxcb-1.13-h1bed415_1.conda 23 | https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.1-he6710b0_1.conda 24 | https://repo.anaconda.com/pkgs/main/linux-64/openssl-1.1.1g-h7b6447c_0.conda 25 | https://repo.anaconda.com/pkgs/main/linux-64/pcre-8.43-he6710b0_0.conda 26 | https://repo.anaconda.com/pkgs/main/linux-64/xz-5.2.4-h14c3975_4.conda 27 | https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.11-h7b6447c_3.conda 28 | https://repo.anaconda.com/pkgs/main/linux-64/glib-2.63.1-h5a9c865_0.conda 29 | https://repo.anaconda.com/pkgs/main/linux-64/hdf5-1.10.4-hb1b8bf9_0.conda 30 | https://repo.anaconda.com/pkgs/main/linux-64/libedit-3.1.20170329-h6b74fdf_2.conda 31 | https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.35-hbc83047_0.conda 32 | https://repo.anaconda.com/pkgs/main/linux-64/libprotobuf-3.11.2-hd408876_0.conda 33 | https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.0.9-he85c1e1_2.conda 34 | https://repo.anaconda.com/pkgs/main/linux-64/libxml2-2.9.9-hea5a465_1.conda 35 | https://repo.anaconda.com/pkgs/main/linux-64/readline-7.0-h7b6447c_5.conda 36 | https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.8-hbc83047_0.conda 37 | https://conda.anaconda.org/conda-forge/linux-64/zeromq-4.3.2-he1b5a44_2.tar.bz2 38 | https://repo.anaconda.com/pkgs/main/linux-64/dbus-1.13.12-h746ee38_0.conda 39 | https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.9.1-h8a8886c_1.conda 40 | https://repo.anaconda.com/pkgs/main/linux-64/gstreamer-1.14.0-hb453b48_1.conda 41 | https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.26.0-h7b6447c_0.conda 42 | https://repo.anaconda.com/pkgs/main/linux-64/fontconfig-2.13.0-h9420a91_0.conda 43 | https://repo.anaconda.com/pkgs/main/linux-64/gst-plugins-base-1.14.0-hbbd80ab_1.conda 44 | 
https://repo.anaconda.com/pkgs/main/linux-64/python-3.7.1-h0371630_7.conda 45 | https://repo.anaconda.com/pkgs/main/linux-64/astor-0.8.0-py37_0.conda 46 | https://conda.anaconda.org/conda-forge/noarch/attrs-19.3.0-py_0.tar.bz2 47 | https://repo.anaconda.com/pkgs/main/linux-64/backcall-0.1.0-py37_0.conda 48 | https://repo.anaconda.com/pkgs/main/linux-64/certifi-2020.4.5.1-py37_0.conda 49 | https://conda.anaconda.org/conda-forge/noarch/decorator-4.4.1-py_0.tar.bz2 50 | https://conda.anaconda.org/conda-forge/noarch/defusedxml-0.6.0-py_0.tar.bz2 51 | https://conda.anaconda.org/conda-forge/noarch/easydict-1.9-py_0.tar.bz2 52 | https://conda.anaconda.org/conda-forge/linux-64/entrypoints-0.3-py37_1000.tar.bz2 53 | https://conda.anaconda.org/conda-forge/noarch/gast-0.3.2-py_0.tar.bz2 54 | https://repo.anaconda.com/pkgs/main/linux-64/ipython_genutils-0.2.0-py37_0.conda 55 | https://conda.anaconda.org/conda-forge/noarch/json5-0.8.5-py_0.tar.bz2 56 | https://repo.anaconda.com/pkgs/main/linux-64/kiwisolver-1.1.0-py37he6710b0_0.conda 57 | https://conda.anaconda.org/conda-forge/linux-64/markupsafe-1.1.1-py37h516909a_0.tar.bz2 58 | https://conda.anaconda.org/conda-forge/linux-64/mistune-0.8.4-py37h516909a_1000.tar.bz2 59 | https://conda.anaconda.org/conda-forge/noarch/more-itertools-8.1.0-py_0.tar.bz2 60 | https://repo.anaconda.com/pkgs/main/linux-64/ninja-1.8.2-py37h6bb024c_1.conda 61 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.15.4-py37hde5b4d6_0.conda 62 | https://repo.anaconda.com/pkgs/main/linux-64/olefile-0.46-py37_0.conda 63 | https://conda.anaconda.org/conda-forge/noarch/pandocfilters-1.4.2-py_1.tar.bz2 64 | https://conda.anaconda.org/conda-forge/noarch/parso-0.5.2-py_0.tar.bz2 65 | https://repo.anaconda.com/pkgs/main/linux-64/pickleshare-0.7.5-py37_0.conda 66 | https://conda.anaconda.org/conda-forge/noarch/prometheus_client-0.7.1-py_0.tar.bz2 67 | https://conda.anaconda.org/conda-forge/linux-64/ptyprocess-0.6.0-py37_0.tar.bz2 68 | https://repo.anaconda.com/pkgs/main/linux-64/pycparser-2.19-py37_0.conda 69 | https://repo.anaconda.com/pkgs/main/noarch/pyparsing-2.4.6-py_0.conda 70 | https://repo.anaconda.com/pkgs/main/noarch/pytz-2019.3-py_0.tar.bz2 71 | https://conda.anaconda.org/conda-forge/linux-64/pyzmq-18.1.1-py37h1768529_0.tar.bz2 72 | https://repo.anaconda.com/pkgs/main/linux-64/qt-5.9.7-h5867ecd_1.conda 73 | https://conda.anaconda.org/conda-forge/noarch/send2trash-1.5.0-py_0.tar.bz2 74 | https://repo.anaconda.com/pkgs/main/linux-64/sip-4.19.8-py37hf484d3e_0.conda 75 | https://repo.anaconda.com/pkgs/main/linux-64/six-1.12.0-py37_0.conda 76 | https://repo.anaconda.com/pkgs/main/linux-64/termcolor-1.1.0-py37_1.conda 77 | https://conda.anaconda.org/conda-forge/noarch/testpath-0.4.4-py_0.tar.bz2 78 | https://conda.anaconda.org/conda-forge/linux-64/tornado-6.0.3-py37h516909a_0.tar.bz2 79 | https://repo.anaconda.com/pkgs/main/linux-64/tqdm-4.28.1-py37h28b3542_0.conda 80 | https://repo.anaconda.com/pkgs/main/linux-64/wcwidth-0.1.7-py37_0.conda 81 | https://conda.anaconda.org/conda-forge/noarch/webencodings-0.5.1-py_1.tar.bz2 82 | https://conda.anaconda.org/conda-forge/noarch/werkzeug-1.0.0-py_0.tar.bz2 83 | https://repo.anaconda.com/pkgs/main/linux-64/wrapt-1.11.2-py37h7b6447c_0.conda 84 | https://conda.anaconda.org/conda-forge/linux-64/absl-py-0.8.1-py37_0.tar.bz2 85 | https://repo.anaconda.com/pkgs/main/linux-64/cffi-1.11.5-py37he75722e_1.conda 86 | https://repo.anaconda.com/pkgs/main/linux-64/cycler-0.10.0-py37_0.conda 87 | 
https://conda.anaconda.org/conda-forge/noarch/google-pasta-0.1.8-py_0.tar.bz2 88 | https://conda.anaconda.org/conda-forge/linux-64/jedi-0.15.2-py37_0.tar.bz2 89 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.0.6-py37hd81dba3_0.conda 90 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_random-1.0.2-py37hd81dba3_0.conda 91 | https://conda.anaconda.org/conda-forge/linux-64/pexpect-4.7.0-py37_0.tar.bz2 92 | https://repo.anaconda.com/pkgs/main/linux-64/pillow-5.3.0-py37h34e0f95_0.conda 93 | https://repo.anaconda.com/pkgs/main/linux-64/pyqt-5.9.2-py37h05f1152_2.conda 94 | https://conda.anaconda.org/conda-forge/linux-64/pyrsistent-0.15.7-py37h516909a_0.tar.bz2 95 | https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.8.1-py_0.tar.bz2 96 | https://conda.anaconda.org/conda-forge/linux-64/setuptools-40.6.2-py37_0.tar.bz2 97 | https://conda.anaconda.org/conda-forge/linux-64/terminado-0.8.3-py37_0.tar.bz2 98 | https://conda.anaconda.org/conda-forge/linux-64/traitlets-4.3.3-py37_0.tar.bz2 99 | https://conda.anaconda.org/conda-forge/noarch/zipp-0.6.0-py_0.tar.bz2 100 | https://conda.anaconda.org/conda-forge/noarch/bleach-3.1.0-py_0.tar.bz2 101 | https://conda.anaconda.org/conda-forge/linux-64/grpcio-1.23.0-py37he9ae1f9_0.tar.bz2 102 | https://conda.anaconda.org/conda-forge/linux-64/importlib_metadata-1.4.0-py37_0.tar.bz2 103 | https://conda.anaconda.org/conda-forge/noarch/jinja2-2.10.3-py_0.tar.bz2 104 | https://conda.anaconda.org/conda-forge/linux-64/jupyter_core-4.6.1-py37_0.tar.bz2 105 | https://conda.anaconda.org/conda-forge/noarch/markdown-3.2.1-py_0.tar.bz2 106 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.15.4-py37h7e9f1db_0.conda 107 | https://conda.anaconda.org/conda-forge/linux-64/protobuf-3.11.2-py37he1b5a44_0.tar.bz2 108 | https://conda.anaconda.org/conda-forge/noarch/pygments-2.5.2-py_0.tar.bz2 109 | https://conda.anaconda.org/conda-forge/linux-64/wheel-0.32.3-py37_0.tar.bz2 110 | https://repo.anaconda.com/pkgs/main/linux-64/h5py-2.9.0-py37h7918eee_0.conda 111 | https://conda.anaconda.org/conda-forge/linux-64/jsonschema-3.2.0-py37_0.tar.bz2 112 | https://conda.anaconda.org/conda-forge/linux-64/jupyter_client-5.3.4-py37_0.tar.bz2 113 | https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-3.0.2-py37h5429711_0.conda 114 | https://repo.anaconda.com/pkgs/main/linux-64/pandas-1.0.3-py37h0573a6f_0.conda 115 | https://repo.anaconda.com/pkgs/main/linux-64/pip-18.1-py37_0.conda 116 | https://conda.anaconda.org/conda-forge/noarch/prompt_toolkit-3.0.2-py_0.tar.bz2 117 | https://conda.anaconda.org/pytorch/linux-64/pytorch-1.0.0-py3.7_cuda9.0.176_cudnn7.4.1_1.tar.bz2 118 | https://repo.anaconda.com/pkgs/main/linux-64/scipy-1.2.1-py37h7c811a0_0.conda 119 | https://conda.anaconda.org/conda-forge/linux-64/tensorboard-1.14.0-py37_0.tar.bz2 120 | https://conda.anaconda.org/pytorch/linux-64/ignite-0.1.2-py37_0.tar.bz2 121 | https://repo.anaconda.com/pkgs/main/linux-64/ipython-7.11.1-py37h39e3cac_0.conda 122 | https://repo.anaconda.com/pkgs/main/noarch/keras-applications-1.0.8-py_0.tar.bz2 123 | https://repo.anaconda.com/pkgs/main/noarch/keras-preprocessing-1.1.0-py_1.tar.bz2 124 | https://conda.anaconda.org/conda-forge/noarch/nbformat-5.0.3-py_0.tar.bz2 125 | https://repo.anaconda.com/pkgs/main/linux-64/scikit-learn-0.20.3-py37hd81dba3_0.conda 126 | https://repo.anaconda.com/pkgs/main/noarch/seaborn-0.10.1-py_0.conda 127 | https://conda.anaconda.org/pytorch/noarch/torchvision-0.2.1-py_2.tar.bz2 128 | 
https://conda.anaconda.org/conda-forge/linux-64/ipykernel-5.1.3-py37h5ca1d4c_0.tar.bz2 129 | https://conda.anaconda.org/conda-forge/linux-64/nbconvert-5.6.1-py37_0.tar.bz2 130 | https://repo.anaconda.com/pkgs/main/linux-64/tensorflow-base-1.14.0-mkl_py37h7ce6ba3_0.conda 131 | https://conda.anaconda.org/conda-forge/linux-64/notebook-6.0.1-py37_0.tar.bz2 132 | https://repo.anaconda.com/pkgs/main/noarch/tensorflow-estimator-1.14.0-py_0.tar.bz2 133 | https://conda.anaconda.org/conda-forge/noarch/jupyterlab_server-1.0.6-py_0.tar.bz2 134 | https://repo.anaconda.com/pkgs/main/linux-64/tensorflow-1.14.0-mkl_py37h45c423b_0.conda 135 | https://conda.anaconda.org/conda-forge/noarch/jupyterlab-1.2.4-py_0.tar.bz2 136 | -------------------------------------------------------------------------------- /experiments/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | from torch.utils.data import DataLoader 6 | from torchvision.datasets import CIFAR10 7 | from torchvision.transforms import Compose, ToTensor, Normalize 8 | 9 | from torch.utils.data.sampler import SubsetRandomSampler 10 | 11 | 12 | #################################################################### 13 | ###################### CIFAR ####################### 14 | #################################################################### 15 | 16 | 17 | def get_cifar_train_valid_loader(batch_size, 18 | augment, 19 | random_seed, 20 | valid_size=0.1, 21 | shuffle=True, 22 | num_workers=1, 23 | pin_memory=False, 24 | dataset_name='CIFAR10'): 25 | """ 26 | Utility function for loading and returning train and valid 27 | multi-process iterators over the CIFAR-10/100 datasets. The 28 | dataset is selected with the `dataset_name` argument. 29 | If using CUDA, num_workers should be set to 1 and pin_memory to True. 30 | Params 31 | ------ 32 | - dataset_name: which dataset to load, either 'CIFAR10' or 'CIFAR100'. 33 | - batch_size: how many samples per batch to load. 34 | - augment: whether to apply the data augmentation scheme 35 | mentioned in the paper. Only applied on the train split. 36 | - random_seed: fix seed for reproducibility. 37 | - valid_size: percentage split of the training set used for 38 | the validation set. Should be a float in the range [0, 1]. 39 | - shuffle: whether to shuffle the train/validation indices. 40 | - num_workers: number of subprocesses to use when loading the dataset. 41 | - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to 42 | True if using GPU. 43 | Returns 44 | ------- 45 | - train_loader: training set iterator. 46 | - valid_loader: validation set iterator. 47 | """ 48 | error_msg = "[!] valid_size should be in the range [0, 1]."
49 | assert ((valid_size >= 0) and (valid_size <= 1)), error_msg 50 | 51 | if dataset_name == 'CIFAR10': 52 | normalize = transforms.Normalize( 53 | mean=[0.4914, 0.4822, 0.4465], 54 | std=[0.2023, 0.1994, 0.2010], 55 | ) 56 | elif dataset_name == 'CIFAR100': 57 | normalize = transforms.Normalize( 58 | mean=[0.5071, 0.4867, 0.4408], 59 | std=[0.2675, 0.2565, 0.2761], 60 | ) 61 | 62 | # define transforms 63 | valid_transform = transforms.Compose([ 64 | transforms.ToTensor(), 65 | normalize, 66 | ]) 67 | if augment: 68 | train_transform = transforms.Compose([ 69 | transforms.RandomCrop(32, padding=4), 70 | transforms.RandomHorizontalFlip(), 71 | transforms.ToTensor(), 72 | normalize, 73 | ]) 74 | else: 75 | train_transform = transforms.Compose([ 76 | transforms.ToTensor(), 77 | normalize, 78 | ]) 79 | 80 | # load the dataset 81 | data_dir = '~/data' 82 | 83 | if dataset_name == 'CIFAR10': 84 | train_dataset = torchvision.datasets.CIFAR10( 85 | root=data_dir, train=True, 86 | download=True, transform=train_transform, 87 | ) 88 | 89 | valid_dataset = torchvision.datasets.CIFAR10( 90 | root=data_dir, train=True, 91 | download=True, transform=valid_transform, 92 | ) 93 | elif dataset_name == 'CIFAR100': 94 | train_dataset = torchvision.datasets.CIFAR100( 95 | root=data_dir, train=True, 96 | download=True, transform=train_transform, 97 | ) 98 | 99 | valid_dataset = torchvision.datasets.CIFAR100( 100 | root=data_dir, train=True, 101 | download=True, transform=valid_transform, 102 | ) 103 | 104 | num_train = len(train_dataset) 105 | indices = list(range(num_train)) 106 | split = int(np.floor(valid_size * num_train)) 107 | 108 | if shuffle: 109 | np.random.seed(random_seed) 110 | np.random.shuffle(indices) 111 | 112 | train_idx, valid_idx = indices[split:], indices[:split] 113 | train_sampler = SubsetRandomSampler(train_idx) 114 | valid_sampler = SubsetRandomSampler(valid_idx) 115 | 116 | train_loader = DataLoader( 117 | train_dataset, batch_size=batch_size, sampler=train_sampler, 118 | num_workers=num_workers, pin_memory=pin_memory, 119 | ) 120 | valid_loader = DataLoader( 121 | valid_dataset, batch_size=batch_size, sampler=valid_sampler, 122 | num_workers=num_workers, pin_memory=pin_memory, 123 | ) 124 | 125 | return (train_loader, valid_loader) 126 | 127 | 128 | def get_cifar_test_loader(batch_size, 129 | shuffle=False, 130 | num_workers=1, 131 | pin_memory=False, 132 | dataset_name='CIFAR10'): 133 | """ 134 | Utility function for loading and returning a multi-process 135 | test iterator over the CIFAR-10/100 datasets. 136 | If using CUDA, num_workers should be set to 1 and pin_memory to True. 137 | Params 138 | ------ 139 | - dataset_name: which dataset to load, either 'CIFAR10' or 'CIFAR100'. 140 | - batch_size: how many samples per batch to load. 141 | - shuffle: whether to shuffle the dataset after every epoch. 142 | - num_workers: number of subprocesses to use when loading the dataset. 143 | - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to 144 | True if using GPU. 145 | Returns 146 | ------- 147 | - data_loader: test set iterator.
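    Example (a sketch; the batch size and memory pinning are illustrative):
        test_loader = get_cifar_test_loader(batch_size=128, pin_memory=True,
                                            dataset_name='CIFAR100')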
148 | """ 149 | if dataset_name=='CIFAR10': 150 | normalize = transforms.Normalize( 151 | mean=[0.4914, 0.4822, 0.4465], 152 | std=[0.2023, 0.1994, 0.2010], 153 | ) 154 | elif dataset_name=='CIFAR100': 155 | normalize = transforms.Normalize( 156 | mean=[0.5071, 0.4867, 0.4408], 157 | std=[0.2675, 0.2565, 0.2761], 158 | ) 159 | 160 | # define transform 161 | transform = transforms.Compose([ 162 | transforms.ToTensor(), 163 | normalize, 164 | ]) 165 | 166 | # load the dataset 167 | data_dir = '~/data' 168 | if dataset_name == 'CIFAR10': 169 | dataset = torchvision.datasets.CIFAR10( 170 | root=data_dir, train=False, 171 | download=True, transform=transform, 172 | ) 173 | elif dataset_name == 'CIFAR100': 174 | dataset = torchvision.datasets.CIFAR100( 175 | root=data_dir, train=False, 176 | download=True, transform=transform, 177 | ) 178 | 179 | 180 | data_loader = DataLoader( 181 | dataset, batch_size=batch_size, shuffle=shuffle, 182 | num_workers=num_workers, pin_memory=pin_memory, 183 | ) 184 | 185 | return data_loader 186 | 187 | #################################################################### 188 | ########################## Tiny Imagenet ########################### 189 | #################################################################### 190 | 191 | def get_tiny_imagenet_train_valid_loader(batch_size, 192 | augment, 193 | shuffle=True, 194 | num_workers=1): 195 | 196 | root = '/path_to_tiny-imagenet_dataset/' 197 | tiny_mean = [0.48024578664982126, 0.44807218089384643, 0.3975477478649648] 198 | tiny_std = [0.2769864069088257, 0.26906448510256, 0.282081906210584] 199 | transform_train = transforms.Compose([ 200 | transforms.RandomCrop(64, padding=4), 201 | transforms.RandomHorizontalFlip(), 202 | transforms.ToTensor(), 203 | transforms.Normalize(tiny_mean, tiny_std)]) 204 | 205 | transform_test = transforms.Compose([ 206 | transforms.ToTensor(), 207 | transforms.Normalize(tiny_mean, tiny_std)]) 208 | trainset = torchvision.datasets.ImageFolder(root + '/tiny-imagenet-200/train', 209 | transform=transform_train) 210 | testset = torchvision.datasets.ImageFolder(root + '/tiny-imagenet-200/val', 211 | transform=transform_test) 212 | trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=shuffle, 213 | num_workers=num_workers) 214 | testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, 215 | num_workers=num_workers) 216 | 217 | return trainloader, testloader -------------------------------------------------------------------------------- /experiments/experiments.py: -------------------------------------------------------------------------------- 1 | ''' 2 | FORCE 3 | Copyright (c) 2020-present NAVER Corp. 4 | MIT license 5 | ''' 6 | 7 | import torch 8 | import torch.optim as optim 9 | import torchvision 10 | import torch.nn.functional as F 11 | 12 | from .models import * 13 | from .datasets import * 14 | 15 | network_name_module = { 16 | 'resnet34': resnet34, 17 | 'resnet50': resnet50, 18 | 'resnet110': resnet110, 19 | } 20 | 21 | dataset_num_classes = { 22 | 'CIFAR10': 10, 23 | 'CIFAR100': 100, 24 | 'tiny_imagenet': 200 25 | } 26 | 27 | def vgg_cifar_experiment(device, network_name, dataset, frac_data_for_train=0.9): 28 | """ 29 | Util function to generate necessary components to train VGG network 30 | on CIFAR10/100 datasets. 
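    Example of typical usage (a sketch; `device` is assumed to be a torch.device):
        (net, optimiser, lr_scheduler, train_loader, val_loader,
         test_loader, loss, epochs) = vgg_cifar_experiment(device, 'vgg19', 'CIFAR10')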
31 | """ 32 | 33 | INIT_LR = 0.1 34 | BATCH_SIZE = 128 35 | milestones=[150, 250] 36 | EPOCHS = 350 37 | WEIGHT_DECAY_RATE = 0.0005 38 | 39 | if network_name == 'vgg19': 40 | depth = 19 41 | elif network_name == 'vgg16': 42 | depth = 16 43 | else: 44 | raise NotImplementedError 45 | net = VGG(dataset=dataset, depth=depth).to(device) 46 | 47 | optimiser = optim.SGD(net.parameters(), 48 | lr=INIT_LR, 49 | momentum=0.9, 50 | weight_decay=WEIGHT_DECAY_RATE) 51 | 52 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimiser, milestones=milestones, gamma=0.1) 53 | 54 | train_loader, val_loader = get_cifar_train_valid_loader( 55 | batch_size=BATCH_SIZE, 56 | augment=True, 57 | random_seed=1, 58 | valid_size=1-frac_data_for_train, 59 | pin_memory=False, 60 | dataset_name=dataset 61 | ) 62 | 63 | test_loader = get_cifar_test_loader( 64 | batch_size=BATCH_SIZE, 65 | pin_memory=False, 66 | dataset_name=dataset 67 | ) 68 | 69 | loss = F.cross_entropy 70 | return net, optimiser, lr_scheduler, train_loader, val_loader, test_loader, loss, EPOCHS 71 | 72 | 73 | def vgg_tiny_imagenet_experiment(device, network_name, dataset): 74 | """ 75 | Util function to generate necessary components to train VGG network 76 | on Tiny Imagenet dataset. 77 | """ 78 | 79 | INIT_LR = 0.1 80 | BATCH_SIZE = 128 81 | milestones=[150, 225] 82 | EPOCHS = 300 83 | WEIGHT_DECAY_RATE = 0.0005 84 | 85 | if network_name == 'vgg19': 86 | depth = 19 87 | elif network_name == 'vgg16': 88 | depth = 16 89 | else: 90 | raise NotImplementedError 91 | net = VGG(dataset=dataset, depth=depth).to(device) 92 | 93 | optimiser = optim.SGD(net.parameters(), 94 | lr=INIT_LR, 95 | momentum=0.9, 96 | weight_decay=WEIGHT_DECAY_RATE) 97 | 98 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimiser, milestones=milestones, gamma=0.1) 99 | 100 | train_loader, test_loader = get_tiny_imagenet_train_valid_loader(BATCH_SIZE, 101 | augment=True, 102 | shuffle=True, 103 | num_workers=8) 104 | val_loader = None 105 | 106 | loss = F.cross_entropy 107 | return net, optimiser, lr_scheduler, train_loader, val_loader, test_loader, loss, EPOCHS 108 | 109 | 110 | def resnet_tiny_imagenet_experiment(device, network_name, dataset, in_planes): 111 | """ 112 | Util function to generate necessary components to train resnet network 113 | on Tiny Imagenet dataset. 114 | """ 115 | 116 | INIT_LR = 0.1 117 | BATCH_SIZE = 128 118 | milestones=[150, 225] 119 | EPOCHS = 300 120 | WEIGHT_DECAY_RATE = 0.0005 121 | 122 | print(network_name) 123 | num_classes = dataset_num_classes[dataset] 124 | network_name = network_name.split('stable')[-1] 125 | net = network_name_module[network_name](num_classes=num_classes, stable_resnet=False, 126 | in_planes=in_planes).to(device) 127 | 128 | optimiser = optim.SGD(net.parameters(), 129 | lr=INIT_LR, 130 | momentum=0.9, 131 | weight_decay=WEIGHT_DECAY_RATE) 132 | 133 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimiser, milestones=milestones, gamma=0.1) 134 | 135 | train_loader, test_loader = get_tiny_imagenet_train_valid_loader(BATCH_SIZE, 136 | augment=True, 137 | shuffle=True, 138 | num_workers=8) 139 | val_loader = None 140 | 141 | loss = F.cross_entropy 142 | return net, optimiser, lr_scheduler, train_loader, val_loader, test_loader, loss, EPOCHS 143 | 144 | 145 | def resnet_cifar_experiment(device, network_name, dataset_name, optimiser_name="sgd", 146 | frac_data_for_train=0.9, stable_resnet=False, in_planes=64): 147 | """ 148 | Util function to generate necessary components to train resnet network 149 | on CIFAR10/100 datasets. 
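    Example (a sketch; enables the Stable ResNet scaling implemented in models.py):
        (net, optimiser, scheduler, train_loader, val_loader, test_loader,
         loss, epochs) = resnet_cifar_experiment(device, 'resnet34', 'CIFAR100',
                                                 stable_resnet=True)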
150 | """ 151 | 152 | INIT_LR = 0.1 153 | BATCH_SIZE = 128 154 | milestones = [150, 250] 155 | EPOCHS = 350 156 | WEIGHT_DECAY_RATE = 0.0005 157 | 158 | print(network_name) 159 | num_classes = dataset_num_classes[dataset_name] 160 | network_name = network_name.split('stable')[-1] 161 | net = network_name_module[network_name](num_classes=num_classes, stable_resnet=stable_resnet, 162 | in_planes=in_planes).to(device) 163 | torch.backends.cudnn.benchmark = True 164 | 165 | if optimiser_name == "sgd": 166 | optimiser = optim.SGD(net.parameters(), 167 | lr=INIT_LR, 168 | momentum=0.9, 169 | weight_decay=WEIGHT_DECAY_RATE) 170 | elif optimiser_name == "adam": 171 | optimiser = optim.Adam(net.parameters(), 172 | lr=INIT_LR, 173 | weight_decay=WEIGHT_DECAY_RATE) 174 | 175 | scheduler = optim.lr_scheduler.MultiStepLR(optimiser, milestones=milestones, gamma=0.1) 176 | 177 | train_loader, val_loader = get_cifar_train_valid_loader( 178 | batch_size=BATCH_SIZE, 179 | augment=True, 180 | random_seed=1, 181 | valid_size=1-frac_data_for_train, 182 | pin_memory=False, 183 | dataset_name=dataset_name 184 | ) 185 | 186 | test_loader = get_cifar_test_loader( 187 | batch_size=BATCH_SIZE, 188 | pin_memory=False, 189 | dataset_name=dataset_name 190 | ) 191 | 192 | loss = F.cross_entropy 193 | return net, optimiser, scheduler, train_loader, val_loader, test_loader, loss, EPOCHS 194 | 195 | 196 | def train_cross_entropy(epoch, model, train_loader, optimizer, device, writer, LOG_INTERVAL=20): 197 | ''' 198 | Util method for training with cross entropy loss. 199 | ''' 200 | # Signalling the model that it is in training mode 201 | model.train() 202 | train_loss = 0 203 | for batch_idx, (data, labels) in enumerate(train_loader): 204 | # Loading the data onto the GPU 205 | data = data.to(device) 206 | labels = labels.to(device) 207 | 208 | optimizer.zero_grad() 209 | 210 | logits = model(data) 211 | loss = F.cross_entropy(logits, labels) 212 | 213 | loss.backward() 214 | # torch.nn.utils.clip_grad_norm(model.parameters(), 2) 215 | train_loss += loss.item() 216 | optimizer.step() 217 | 218 | if batch_idx % LOG_INTERVAL == 0: 219 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 220 | epoch, batch_idx * len(data), len(train_loader) * len(data), 221 | 100. 
* batch_idx / len(train_loader), loss.item())) 222 | writer.add_scalar("training/loss", loss.item(), 223 | epoch*len(train_loader)+batch_idx) 224 | 225 | print('====> Epoch: {} Average loss: {:.4f}'.format( 226 | epoch, train_loss / len(train_loader))) 227 | return train_loss / len(train_loader) 228 | -------------------------------------------------------------------------------- /experiments/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.nn.init as init 5 | 6 | 7 | #################################################################### 8 | ###################### Resnet ####################### 9 | #################################################################### 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, in_planes, planes, stride=1, L=1): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | 21 | # Normalising factor derived in Stable Resnet paper 22 | # https://arxiv.org/pdf/2002.08797.pdf 23 | self.factor = L**(-0.5) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride != 1 or in_planes != self.expansion*planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 29 | nn.BatchNorm2d(self.expansion*planes) 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = self.bn2(self.conv2(out)) 35 | out = out*self.factor + self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class Bottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1, L=1): 44 | super(Bottleneck, self).__init__() 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 51 | 52 | # Normalising factor derived in Stable Resnet paper 53 | # https://arxiv.org/pdf/2002.08797.pdf 54 | self.factor = L**(-0.5) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion*planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 60 | nn.BatchNorm2d(self.expansion*planes) 61 | ) 62 | 63 | def forward(self, x): 64 | out = F.relu(self.bn1(self.conv1(x))) 65 | out = F.relu(self.bn2(self.conv2(out))) 66 | out = self.bn3(self.conv3(out)) 67 | out = out*self.factor + self.shortcut(x) 68 | out = F.relu(out) 69 | return out 70 | 71 | 72 | class ResNet(nn.Module): 73 | def __init__(self, block, num_blocks, num_classes=10, temp=1.0, in_planes=64, stable_resnet=False): 74 | super(ResNet, self).__init__() 75 | self.in_planes = in_planes 76 | if stable_resnet: 77 | # Total number of blocks for Stable ResNet 78 | # https://arxiv.org/pdf/2002.08797.pdf 79 | L = 0 80 | for x in num_blocks: 81 | L+=x 82 | self.L = L 83 | else: 84 | self.L = 1 85 | 86 | self.masks = None 87 | 88 | self.conv1 = nn.Conv2d(3, 
self.in_planes, kernel_size=3, stride=1, padding=1, bias=False) 89 | self.bn1 = nn.BatchNorm2d(self.in_planes) 90 | self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1) 91 | self.layer2 = self._make_layer(block, in_planes*2, num_blocks[1], stride=2) 92 | self.layer3 = self._make_layer(block, in_planes*4, num_blocks[2], stride=2) 93 | self.layer4 = self._make_layer(block, in_planes*8, num_blocks[3], stride=2) 94 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 95 | self.fc = nn.Linear(in_planes*8*block.expansion, num_classes) 96 | self.temp = temp 97 | 98 | def _make_layer(self, block, planes, num_blocks, stride): 99 | strides = [stride] + [1]*(num_blocks-1) 100 | layers = [] 101 | for stride in strides: 102 | layers.append(block(self.in_planes, planes, stride, self.L)) 103 | self.in_planes = planes * block.expansion 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | out = F.relu(self.bn1(self.conv1(x))) 108 | out = self.layer1(out) 109 | out = self.layer2(out) 110 | out = self.layer3(out) 111 | out = self.layer4(out) 112 | out = self.avgpool(out) 113 | out = torch.flatten(out, 1) 114 | out = self.fc(out) / self.temp 115 | 116 | return out 117 | 118 | 119 | def resnet18(temp=1.0, **kwargs): 120 | model = ResNet(BasicBlock, [2, 2, 2, 2], temp=temp, **kwargs) 121 | return model 122 | 123 | def resnet34(temp=1.0, **kwargs): 124 | model = ResNet(BasicBlock, [3, 4, 6, 3], temp=temp, **kwargs) 125 | return model 126 | 127 | def resnet50(temp=1.0, **kwargs): 128 | model = ResNet(Bottleneck, [3, 4, 6, 3], temp=temp, **kwargs) 129 | return model 130 | 131 | def resnet101(temp=1.0, **kwargs): 132 | model = ResNet(Bottleneck, [3, 4, 23, 3], temp=temp, **kwargs) 133 | return model 134 | 135 | def resnet110(temp=1.0, **kwargs): 136 | model = ResNet(Bottleneck, [3, 4, 26, 3], temp=temp, **kwargs) 137 | return model 138 | 139 | def resnet152(temp=1.0, **kwargs): 140 | model = ResNet(Bottleneck, [3, 8, 36, 3], temp=temp, **kwargs) 141 | return model 142 | 143 | 144 | #################################################################### 145 | ####################### VGG ################################### 146 | #################################################################### 147 | 148 | defaultcfg = { 149 | 11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 150 | 13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 151 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 152 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 153 | } 154 | 155 | 156 | class VGG(nn.Module): 157 | def __init__(self, dataset='CIFAR10', depth=19, cfg=None, affine=True, batchnorm=True, 158 | init_weights=True): 159 | super(VGG, self).__init__() 160 | if cfg is None: 161 | cfg = defaultcfg[depth] 162 | self._AFFINE = affine 163 | self.feature = self.make_layers(cfg, batchnorm) 164 | self.dataset = dataset 165 | if dataset == 'CIFAR10': 166 | num_classes = 10 167 | elif dataset == 'CIFAR100': 168 | num_classes = 100 169 | elif dataset == 'tiny_imagenet': 170 | num_classes = 200 171 | else: 172 | raise NotImplementedError("Unsupported dataset " + dataset) 173 | self.classifier = nn.Linear(cfg[-1], num_classes) 174 | if init_weights: 175 | self.apply(weights_init) 176 | 177 | 178 | def make_layers(self, cfg, batch_norm=False): 179 | layers = [] 180 | in_channels = 3 181 | for v in cfg: 182 | if v == 'M': 183 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 184 
| else: 185 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 186 | if batch_norm: 187 | layers += [conv2d, nn.BatchNorm2d(v, affine=self._AFFINE), nn.ReLU(inplace=True)] 188 | else: 189 | layers += [conv2d, nn.ReLU(inplace=True)] 190 | in_channels = v 191 | return nn.Sequential(*layers) 192 | 193 | def forward(self, x): 194 | x = self.feature(x) 195 | if self.dataset == 'tiny_imagenet': 196 | x = nn.AvgPool2d(4)(x) 197 | else: 198 | x = nn.AvgPool2d(2)(x) 199 | x = x.view(x.size(0), -1) 200 | y = self.classifier(x) 201 | return y 202 | 203 | 204 | def weights_init(m): 205 | # print('=> weights init') 206 | if isinstance(m, nn.Conv2d): 207 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 208 | # nn.init.normal_(m.weight, 0, 0.1) 209 | if m.bias is not None: 210 | m.bias.data.zero_() 211 | elif isinstance(m, nn.Linear): 212 | # nn.init.xavier_normal(m.weight) 213 | nn.init.normal_(m.weight, 0, 0.01) 214 | nn.init.constant_(m.bias, 0) 215 | elif isinstance(m, nn.BatchNorm2d): 216 | # Note that BN's running_var/mean are 217 | # already initialized to 1 and 0 respectively. 218 | if m.weight is not None: 219 | m.weight.data.fill_(1.0) 220 | if m.bias is not None: 221 | m.bias.data.zero_() -------------------------------------------------------------------------------- /pruning/mask_networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | # from IPython import embed 4 | 5 | def apply_prune_mask(net, keep_masks, apply_hooks=True): 6 | """ 7 | Function that takes a network and a list of masks and applies them to the relevant layers. 8 | mask[i] == 0 --> Prune parameter 9 | mask[i] == 1 --> Keep parameter 10 | 11 | If apply_hooks == False, then set weight to 0 but do not block the gradient. 12 | This is used for the FORCE algorithm, which sparsifies the net instead of pruning it. 13 | """ 14 | 15 | # Before I can zip() layers and pruning masks I need to make sure they match 16 | # one-to-one by removing all non-prunable modules: 17 | prunable_layers = filter( 18 | lambda layer: isinstance(layer, nn.Conv2d) or isinstance( 19 | layer, nn.Linear), net.modules()) 20 | 21 | # List of hooks to be applied on the gradients. It's useful to save them in order to remove 22 | # them later 23 | hook_handlers = [] 24 | for layer, keep_mask in zip(prunable_layers, keep_masks): 25 | assert (layer.weight.shape == keep_mask.shape) 26 | 27 | def hook_factory(keep_mask): 28 | """ 29 | The hook function can't be defined directly here because of Python's 30 | late binding which would result in all hooks getting the very last 31 | mask! Getting it through another function forces early binding. 32 | """ 33 | def hook(grads): 34 | return grads * keep_mask 35 | 36 | return hook 37 | 38 | # Step 1: Set the masked weights to zero (Biases are ignored) 39 | layer.weight.data[keep_mask == 0.] = 0. 40 | 41 | # Step 2: Make sure their gradients remain zero (not with FORCE) 42 | if apply_hooks: 43 | hook_handlers.append(layer.weight.register_hook(hook_factory(keep_mask))) 44 | 45 | return hook_handlers 46 | -------------------------------------------------------------------------------- /pruning/pruning_algos.py: -------------------------------------------------------------------------------- 1 | ''' 2 | FORCE 3 | Copyright (c) 2020-present NAVER Corp.
4 | MIT license 5 | ''' 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import copy 11 | import numpy as np 12 | # from IPython import embed 13 | 14 | import warnings 15 | 16 | from .mask_networks import apply_prune_mask 17 | 18 | 19 | #################################################### 20 | ############### Get saliencies ################## 21 | #################################################### 22 | 23 | def get_average_gradients(net, train_dataloader, device, num_batches=-1): 24 | """ 25 | Function to compute gradients and average them over several batches. 26 | 27 | num_batches: Number of batches to be used to approximate the gradients. 28 | When set to -1, uses the whole training set. 29 | 30 | Returns a list of tensors, with gradients for each prunable layer. 31 | """ 32 | 33 | # Prepare list to store gradients 34 | gradients = [] 35 | for layer in net.modules(): 36 | # Select only prunable layers 37 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): 38 | gradients.append(0) 39 | 40 | # Take a whole epoch 41 | count_batch = 0 42 | for batch_idx in range(len(train_dataloader)): 43 | inputs, targets = next(iter(train_dataloader)) 44 | inputs = inputs.to(device) 45 | targets = targets.to(device) 46 | 47 | # Compute gradients (but don't apply them) 48 | net.zero_grad() 49 | outputs = net.forward(inputs) 50 | loss = F.nll_loss(outputs, targets) 51 | loss.backward() 52 | 53 | # Store gradients 54 | counter = 0 55 | for layer in net.modules(): 56 | # Select only prunable layers 57 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): 58 | gradients[counter] += layer.weight.grad 59 | counter += 1 60 | count_batch += 1 61 | if batch_idx == num_batches - 1: 62 | break 63 | avg_gradients = [x / count_batch for x in gradients] 64 | 65 | return avg_gradients 66 | 67 | 68 | def get_average_saliencies(net, train_dataloader, device, prune_method=3, num_batches=-1, 69 | original_weights=None): 70 | """ 71 | Get saliencies with averaged gradients. 72 | 73 | num_batches: Number of batches to be used to approximate the gradients. 74 | When set to -1, uses the whole training set. 75 | 76 | prune_method: Which method to use to prune the layers, refer to https://arxiv.org/abs/2006.09081. 77 | 1: Use Iter SNIP. 78 | 2: Use GRASP-It. 79 | 3: Use FORCE (default). 80 | 4: Random (random pruning baseline). 81 | 82 | Returns a list of tensors with saliencies for each weight. 
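    Example (a sketch; `net`, `train_loader` and `device` are assumed to exist):
        # Iter SNIP saliencies |w * g|, with gradients averaged over 10 batches
        saliency = get_average_saliencies(net, train_loader, device,
                                          prune_method=1, num_batches=10)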
83 | """ 84 | 85 | def pruning_criteria(method): 86 | if method == 2: 87 | # GRASP-It method 88 | result = layer_weight_grad**2 # Custom gradient norm approximation 89 | elif method == 4: 90 | result = torch.rand_like(layer_weight) # Randomly pruning weights 91 | else: 92 | # FORCE / Iter SNIP method 93 | result = torch.abs(layer_weight * layer_weight_grad) 94 | return result 95 | 96 | if prune_method != 4: # No need to compute gradients for random pruning 97 | gradients = get_average_gradients(net, train_dataloader, device, num_batches) 98 | saliency = [] 99 | idx = 0 100 | for layer in net.modules(): 101 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): 102 | if prune_method == 3: 103 | layer_weight = original_weights[idx] 104 | else: 105 | layer_weight = layer.weight 106 | if prune_method != 4: # No need to compute gradients for random pruning 107 | layer_weight_grad = gradients[idx] 108 | idx += 1 109 | saliency.append(pruning_criteria(prune_method)) 110 | 111 | return saliency 112 | 113 | ################################################### 114 | ############# Iterative pruning ################### 115 | ################################################### 116 | 117 | def get_mask(saliency, pruning_factor): 118 | """ 119 | Given a list of saliencies and a pruning factor (sparsity), 120 | returns a list with binary tensors which correspond to pruning masks. 121 | """ 122 | # Gather all scores in a single vector and normalise 123 | all_scores = torch.cat([torch.flatten(x) for x in saliency]) 124 | 125 | num_params_to_keep = int(len(all_scores) * pruning_factor) 126 | threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True) 127 | acceptable_score = threshold[-1] 128 | 129 | keep_masks = [] 130 | for m in saliency: 131 | keep_masks.append((m >= acceptable_score).float()) 132 | return keep_masks 133 | 134 | def iterative_pruning(ori_net, train_dataloader, device, pruning_factor=0.1, 135 | prune_method=3, num_steps=10, 136 | mode='exp', num_batches=1): 137 | """ 138 | Function to gradually remove weights from a network, recomputing the saliency at each step. 139 | 140 | pruning_factor: Fraction of remaining weights (globally) after pruning. 141 | 142 | prune_method: Which method to use to prune the layers, refer to https://arxiv.org/abs/2006.09081. 143 | 1: Use Iter SNIP. 144 | 2: Use GRASP-It. 145 | 3: Use FORCE (default). 146 | 4: Random (random pruning baseline). 147 | 148 | num_steps: Number of iterations to do when pruning progressively (should be >= 1). 149 | 150 | mode: Mode of choosing the sparsity decay schedule. One of 'exp', 'linear' 151 | 152 | num_batches: Number of batches to be used to approximate the gradients (should be -1 or >= 1). 153 | When set to -1, uses the whole training set. 154 | 155 | Returns a list of binary tensors which correspond to the final pruning mask. 156 | """ 157 | # Let's create a copy of the network to make sure we don't affect the training later 158 | net = copy.deepcopy(ori_net) 159 | 160 | if prune_method == 4 and num_steps > 1: 161 | message = 'The selected pruning variant (Random) is not meant to perform iterative pruning' 162 | warnings.warn(message, UserWarning, stacklevel=2) 163 | 164 | if prune_method == 3: 165 | # If we want to apply FORCE we need to save the original (dense) weights 166 | # to compute the saliency of sparsified connections. 
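        # Concretely, FORCE scores every connection as |w_original * grad|, where the
        # gradient is computed through the current sparsified network; because pruned
        # weights keep a non-zero score, connections removed at an earlier step can be
        # recovered at a later one.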
167 | original_weights = [] 168 | for layer in net.modules(): 169 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): 170 | original_weights.append(layer.weight.detach()) 171 | else: 172 | original_weights = None 173 | 174 | # Choose a decay mode for sparsity (exponential should be used unless you know what you 175 | # are doing) 176 | if mode == 'linear': 177 | pruning_steps = [1 - ((x + 1) * (1 - pruning_factor) / num_steps) for x in range(num_steps)] 178 | 179 | elif mode == 'exp': 180 | pruning_steps = [np.exp(0 - ((x + 1) * (0 - np.log(pruning_factor)) / num_steps)) for x in range(num_steps)] 181 | 182 | mask = None 183 | hook_handlers = None 184 | 185 | for perc in pruning_steps: 186 | saliency = [] 187 | saliency = get_average_saliencies(net, train_dataloader, device, 188 | prune_method=prune_method, 189 | num_batches=num_batches, 190 | original_weights=original_weights) 191 | torch.cuda.empty_cache() 192 | 193 | # Make sure the saliencies of previously deleted weights are set to the global minimum 194 | # so they do not get picked again. 195 | if mask is not None and prune_method < 3: 196 | min_saliency = get_minimum_saliency(saliency) 197 | for ii in range(len(saliency)): 198 | saliency[ii][mask[ii] == 0.] = min_saliency 199 | 200 | if hook_handlers is not None: 201 | for h in hook_handlers: 202 | h.remove() 203 | 204 | mask = [] 205 | mask = get_mask(saliency, perc) 206 | 207 | # To use FORCE, go back to the dense network so unmasked 208 | # weights can recover 209 | if prune_method == 3: 210 | net = copy.deepcopy(ori_net) 211 | apply_prune_mask(net, mask, apply_hooks=False) 212 | else: 213 | hook_handlers = apply_prune_mask(net, mask, apply_hooks=True) 214 | 215 | p = check_global_pruning(mask) 216 | print(f'Global pruning {round(float(p),5)}') 217 | 218 | return mask 219 | 220 | def check_global_pruning(mask): 221 | "Compute fraction of unpruned weights in a mask" 222 | flattened_mask = torch.cat([torch.flatten(x) for x in mask]) 223 | return flattened_mask.mean() 224 | 225 | def get_minimum_saliency(saliency): 226 | "Compute minimum value of saliency globally" 227 | flattened_saliency = torch.cat([torch.flatten(x) for x in saliency]) 228 | return flattened_saliency.min() 229 | 230 | def get_maximum_saliency(saliency): 231 | "Compute maximum value of saliency globally" 232 | flattened_saliency = torch.cat([torch.flatten(x) for x in saliency]) 233 | return flattened_saliency.max() 234 | 235 | 236 | #################################################################### 237 | ###################### UTILS ################################# 238 | #################################################################### 239 | 240 | def get_force_saliency(net, mask, train_dataloader, device, num_batches): 241 | """ 242 | Given a dense network and a pruning mask, compute the FORCE saliency. 243 | """ 244 | net = copy.deepcopy(net) 245 | apply_prune_mask(net, mask, apply_hooks=True) 246 | saliencies = get_average_saliencies(net, train_dataloader, device, 247 | 1, num_batches=num_batches) 248 | torch.cuda.empty_cache() 249 | s = sum_unmasked_saliency(saliencies, mask) 250 | torch.cuda.empty_cache() 251 | return s 252 | 253 | def sum_unmasked_saliency(variable, mask): 254 | "Util to sum all unmasked (mask==1) components" 255 | V = 0 256 | for v, m in zip(variable, mask): 257 | V += v[m > 0].sum() 258 | return V.detach().cpu() 259 | 260 | def get_gradient_norm(net, mask, train_dataloader, device, num_batches): 261 | "Given a dense network, compute the gradient norm after applying the pruning mask."
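    # Note: the hooks registered by apply_prune_mask zero out the gradients of masked
    # weights, so the squared-gradient norm below only counts surviving connections.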
262 | net = copy.deepcopy(net) 263 | apply_prune_mask(net, mask) 264 | gradients = get_average_gradients(net, train_dataloader, device, num_batches) 265 | torch.cuda.empty_cache() 266 | norm = 0 267 | for g in gradients: 268 | norm += (g**2).sum().detach().cpu().numpy() 269 | return norm 270 | -------------------------------------------------------------------------------- /train_cifar_tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | FORCE 3 | Copyright (c) 2020-present NAVER Corp. 4 | MIT license 5 | ''' 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from tensorboardX import SummaryWriter 12 | from ignite.engine import create_supervised_evaluator 13 | from ignite.metrics import Accuracy, Loss 14 | 15 | from pruning.pruning_algos import iterative_pruning 16 | from experiments.experiments import * 17 | from pruning.mask_networks import apply_prune_mask 18 | 19 | import os 20 | import argparse 21 | import random 22 | # from IPython import embed 23 | 24 | def parseArgs(): 25 | 26 | parser = argparse.ArgumentParser( 27 | description="Training CIFAR / Tiny-Imagenet.", 28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 29 | 30 | parser.add_argument("--pruning_factor", type=float, default=0.01, dest="pruning_factor", 31 | help='Fraction of connections after pruning') 32 | 33 | parser.add_argument("--prune_method", type=int, default=3, dest="prune_method", 34 | help="""Which pruning method to use: 35 | 1->Iter SNIP 36 | 2->GRASP-It 37 | 3->FORCE (default). """) 38 | 39 | parser.add_argument("--dataset", type=str, default='CIFAR10', 40 | dest="dataset_name", help='Dataset to train on') 41 | 42 | parser.add_argument("--network_name", type=str, default='resnet50', dest="network_name", 43 | help='Model to train') 44 | 45 | parser.add_argument("--num_steps", type=int, default=10, 46 | help='Number of steps to use with iterative pruning') 47 | 48 | parser.add_argument("--mode", type=str, default='exp', 49 | help='Mode of creating the iterative pruning steps one of "linear" or "exp".') 50 | 51 | parser.add_argument("--num_batches", type=int, default=1, 52 | help='''Number of batches to be used when computing the gradient. 53 | If set to -1 they will be averaged over the whole dataset.''') 54 | 55 | parser.add_argument("--save_interval", type=int, default=50, 56 | dest="save_interval", help="Number of epochs between model checkpoints.") 57 | 58 | parser.add_argument("--save_loc", type=str, default='saved_models/', 59 | dest="save_loc", help='Path where to save the model') 60 | 61 | parser.add_argument("--opt", type=str, default='sgd', 62 | dest="optimiser", 63 | help='Choice of optimisation algorithm') 64 | 65 | parser.add_argument("--saved_model_name", type=str, default="cnn.model", 66 | dest="saved_model_name", help="Filename of the pre-trained model") 67 | 68 | parser.add_argument("--frac-train-data", type=float, default=0.9, dest="frac_data_for_train", 69 | help='Fraction of data used for training (only applied in CIFAR)') 70 | 71 | parser.add_argument("--init", type=str, default='normal_kaiming', 72 | help='Which initialization method to use') 73 | 74 | parser.add_argument("--in_planes", type=int, default=64, 75 | help='''Number of input planes in Resnet. 
    return parser.parse_args()


LOG_INTERVAL = 20
REPEAT_WITH_DIFFERENT_SEED = 3  # Number of initialize-prune-train trials (minimum of 1)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args = parseArgs()


def train(seed):

    # Set manual seed
    torch.manual_seed(seed)

    if 'resnet' in args.network_name:
        stable_resnet = False
        if 'stable' in args.network_name:
            stable_resnet = True
        if 'CIFAR' in args.dataset_name:
            [net, optimiser, lr_scheduler,
             train_loader, val_loader,
             test_loader, loss, EPOCHS] = resnet_cifar_experiment(device, args.network_name,
                                                                  args.dataset_name, args.optimiser,
                                                                  args.frac_data_for_train,
                                                                  stable_resnet, args.in_planes)
        elif 'tiny_imagenet' in args.dataset_name:
            [net, optimiser, lr_scheduler,
             train_loader, val_loader,
             test_loader, loss, EPOCHS] = resnet_tiny_imagenet_experiment(device, args.network_name,
                                                                          args.dataset_name, args.in_planes)
        else:
            raise ValueError(f"Unrecognised dataset {args.dataset_name}")

    elif 'vgg' in args.network_name or 'VGG' in args.network_name:
        if 'tiny_imagenet' in args.dataset_name:
            [net, optimiser, lr_scheduler,
             train_loader, val_loader,
             test_loader, loss, EPOCHS] = vgg_tiny_imagenet_experiment(device, args.network_name,
                                                                       args.dataset_name)
        else:
            [net, optimiser, lr_scheduler,
             train_loader, val_loader,
             test_loader, loss, EPOCHS] = vgg_cifar_experiment(device, args.network_name,
                                                               args.dataset_name, args.frac_data_for_train)
    else:
        raise ValueError(f"Unrecognised network {args.network_name}")

    # Initialize network
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
            if args.init == 'normal_kaiming':
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            elif args.init == 'normal_kaiming_fout':
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu', mode='fan_out')
            elif args.init == 'normal_xavier':
                nn.init.xavier_normal_(layer.weight)
            elif args.init == 'orthogonal':
                nn.init.orthogonal_(layer.weight)
            else:
                raise ValueError(f"Unrecognised initialisation parameter {args.init}")

    ############################################################################
    ####################        Pruning at init        ########################
    ############################################################################
    pruning_factor = args.pruning_factor
    keep_masks = []
    if pruning_factor != 1:
        print(f'Pruning network iteratively for {args.num_steps} steps')
        keep_masks = iterative_pruning(net, train_loader, device, pruning_factor,
                                       prune_method=args.prune_method,
                                       num_steps=args.num_steps,
                                       mode=args.mode, num_batches=args.num_batches)
        apply_prune_mask(net, keep_masks)
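        # keep_masks is a per-layer list of binary tensors (1 = keep). Beyond
        # zeroing the masked weights, apply_prune_mask (pruning/mask_networks.py)
        # is what keeps them at zero during subsequent training, via the hooks
        # it registers.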
    ############################################################################
    ####################           Training             ########################
    ############################################################################
    evaluator = create_supervised_evaluator(net, {
        'accuracy': Accuracy(),
        'cross_entropy': Loss(loss)
    }, device)

    run_name = (args.network_name + '_' + args.dataset_name + '_spars' +
                str(1 - pruning_factor) + '_variant' + str(args.prune_method) +
                '_train-frac' + str(args.frac_data_for_train) +
                f'_steps{args.num_steps}_{args.mode}' + f'_{args.init}' +
                f'_batch{args.num_batches}' + f'_rseed_{seed}')

    writer_name = 'runs/' + run_name
    writer = SummaryWriter(writer_name)

    iterations = 0
    for epoch in range(0, EPOCHS):
        # NB: the scheduler is stepped at the start of each epoch, following the
        # pre-1.1 PyTorch convention this code was written against.
        lr_scheduler.step()
        train_loss = train_cross_entropy(epoch, net, train_loader, optimiser, device,
                                         writer, LOG_INTERVAL=LOG_INTERVAL)
        iterations += len(train_loader)
        # Evaluate
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        # Save history
        avg_accuracy = metrics['accuracy']
        avg_cross_entropy = metrics['cross_entropy']
        writer.add_scalar("test/loss", avg_cross_entropy, iterations)
        writer.add_scalar("test/accuracy", avg_accuracy, iterations)

        # Save model checkpoints (periodically, and always at the end of training)
        if (epoch + 1) % args.save_interval == 0 or (epoch + 1) == EPOCHS:
            if not os.path.exists(args.save_loc):
                os.makedirs(args.save_loc)
            save_name = args.save_loc + run_name + '_cross_entropy_' + str(epoch + 1) + '.model'
            torch.save(net.state_dict(), save_name)


if __name__ == '__main__':

    # Sample REPEAT_WITH_DIFFERENT_SEED seeds from a pool 300 times that size,
    # so repeated trials cover a wide range of seeds.
    seeds = list(range(300 * REPEAT_WITH_DIFFERENT_SEED))
    random.shuffle(seeds)

    for seed in seeds[:REPEAT_WITH_DIFFERENT_SEED]:
        train(seed)
--------------------------------------------------------------------------------
/train_imagenet.py:
--------------------------------------------------------------------------------
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from tensorboardX import SummaryWriter
from pruning.pruning_algos import iterative_pruning
from pruning.mask_networks import apply_prune_mask
from torchvision.models.vgg import vgg16_bn, vgg19_bn

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

#####################################################################
################  Arguments from pytorch code  #####################
#####################################################################
parser.add_argument('data', metavar='DIR',
                    help='path to imagenet dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet50)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=50, type=int,
                    metavar='N', help='print frequency (default: 50)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training.')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

#####################################################################
################     Arguments for pruning     ######################
#####################################################################
parser.add_argument("--network_name", type=str, default='resnet50', dest="network_name",
                    help='Model to train')

parser.add_argument("--pruning_factor", type=float, default=0.1, dest="pruning_factor",
                    help='Fraction of connections retained after pruning')

parser.add_argument("--prune_method", type=int, default=3, dest="prune_method",
                    help="""Which pruning method to use:
                    1->Iter SNIP
                    2->GRASP-It
                    3->FORCE (default)
                    4->Random (random pruning baseline).
                    """)
101 | """) 102 | 103 | parser.add_argument("--num_steps", type=int, default=10, 104 | help='Number of steps to use with iterative pruning') 105 | 106 | parser.add_argument("--mode", type=str, default='exp', 107 | help='Mode of creating the iterative pruning steps one of "linear" or "exp".') 108 | 109 | parser.add_argument("--num_batches", type=int, default=1, 110 | help='''Number of batches to be used when computing the gradient. 111 | If set to -1 they will be averaged over the whole dataset.''') 112 | 113 | parser.add_argument("--save_interval", type=int, default=50, 114 | dest="save_interval", help="Number of epochs between model checkpoints.") 115 | 116 | parser.add_argument("--save_loc", type=str, default='saved_models/', 117 | dest="save_loc", help='Path where to save the model') 118 | 119 | parser.add_argument("--opt", type=str, default='sgd', 120 | dest="optimiser", 121 | help='Choice of optimisation algorithm') 122 | 123 | parser.add_argument("--saved_model_name", type=str, default="cnn.model", 124 | dest="saved_model_name", help="Filename of the pre-trained model") 125 | 126 | parser.add_argument("--frac-train-data", type=float, default=0.9, dest="frac_data_for_train", 127 | help='Fraction of data used for training (only applied in CIFAR)') 128 | 129 | parser.add_argument("--init", type=str, default='normal_kaiming', 130 | help='Which initialization method to use') 131 | 132 | seed = 1 # We fix the random seed 133 | device = torch.device("cuda") 134 | 135 | ############################################################################# 136 | 137 | def main(): 138 | args = parser.parse_args() 139 | args.arch = args.network_name 140 | if args.seed is not None: 141 | random.seed(args.seed) 142 | torch.manual_seed(args.seed) 143 | cudnn.deterministic = True 144 | warnings.warn('You have chosen to seed training. ' 145 | 'This will turn on the CUDNN deterministic setting, ' 146 | 'which can slow down your training considerably! ' 147 | 'You may see unexpected behavior when restarting ' 148 | 'from checkpoints.') 149 | 150 | if args.gpu is not None: 151 | warnings.warn('You have chosen a specific GPU. 
    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # create model (batch-norm VGG variants are used throughout)
    if args.network_name == 'vgg19':
        model = vgg19_bn(pretrained=False)
    elif args.network_name == 'vgg16':
        model = vgg16_bn(pretrained=False)
    elif 'resnet' in args.network_name:
        model = models.__dict__[args.arch](pretrained=False)
    else:
        raise NotImplementedError

    # Initialize network
    for layer in model.modules():
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
            if args.init == 'normal_kaiming':
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            elif args.init == 'normal_kaiming_fout':
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu', mode='fan_out')
            elif args.init == 'normal_xavier':
                nn.init.xavier_normal_(layer.weight)
            elif args.init == 'orthogonal':
                nn.init.orthogonal_(layer.weight)
            else:
                raise ValueError(f"Unrecognised initialisation parameter {args.init}")

    # NB: the model is wrapped in single-node DataParallel unconditionally;
    # DistributedDataParallel is not used even when the distributed flags are set.
    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
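                # map_location loads the checkpoint tensors straight onto the
                # chosen GPU instead of the device they were saved from.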
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    #############################
    ####    Pruning code     ####
    #############################

    pruning_factor = args.pruning_factor
    keep_masks = []
    if pruning_factor != 1:
        print(f'Pruning network iteratively for {args.num_steps} steps')
        keep_masks = iterative_pruning(model, train_loader, device, pruning_factor,
                                       prune_method=args.prune_method,
                                       num_steps=args.num_steps,
                                       mode=args.mode, num_batches=args.num_batches)

        apply_prune_mask(model, keep_masks)

    # File where to save training history
    run_name = (args.network_name + '_IMAGENET' + '_spars' +
                str(1 - pruning_factor) + '_variant' + str(args.prune_method) +
                '_train-frac' + str(args.frac_data_for_train) +
                f'_steps{args.num_steps}_{args.mode}' + f'_{args.init}' +
                f'_batch{args.num_batches}' + f'_rseed_{seed}')
    writer_name = 'runs/' + run_name
    writer = SummaryWriter(writer_name)

    if args.evaluate:
        validate(val_loader, model, criterion, args, writer, 0)
        return

    iterations = 0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # Train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args, writer)

        # Evaluate on validation set
        iterations = epoch * len(train_loader)
        acc1 = validate(val_loader, model, criterion, args, writer, iterations)
        # Save checkpoint (every save_interval epochs, and at the end of training)
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if (epoch + 1) % args.save_interval == 0 or (epoch + 1) == args.epochs:
                if not os.path.exists('saved_models/'):
                    os.makedirs('saved_models/')
                save_name = 'saved_models/' + run_name + '_cross_entropy_' + str(epoch + 1) + '.model'
                torch.save(model.state_dict(), save_name)


def train(train_loader, model, criterion, optimizer, epoch, args, writer):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    total_batches = len(train_loader)
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            writer.add_scalar("training/loss", losses.avg, epoch * total_batches + i)


def validate(val_loader, model, criterion, args, writer, iterations):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
        writer.add_scalar("test/loss", losses.avg, iterations)
        writer.add_scalar("test/accuracy", top1.avg, iterations)
        writer.add_scalar("test/top5", top5.avg, iterations)

    return top1.avg

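# Usage example: m = AverageMeter('Loss', ':.4e'); m.update(0.7, n=batch_size)
# then str(m) renders as "Loss 7.0000e-01 (<running average>)".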
current value""" 428 | def __init__(self, name, fmt=':f'): 429 | self.name = name 430 | self.fmt = fmt 431 | self.reset() 432 | 433 | def reset(self): 434 | self.val = 0 435 | self.avg = 0 436 | self.sum = 0 437 | self.count = 0 438 | 439 | def update(self, val, n=1): 440 | self.val = val 441 | self.sum += val * n 442 | self.count += n 443 | self.avg = self.sum / self.count 444 | 445 | def __str__(self): 446 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 447 | return fmtstr.format(**self.__dict__) 448 | 449 | 450 | class ProgressMeter(object): 451 | def __init__(self, num_batches, meters, prefix=""): 452 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 453 | self.meters = meters 454 | self.prefix = prefix 455 | 456 | def display(self, batch): 457 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 458 | entries += [str(meter) for meter in self.meters] 459 | print('\t'.join(entries)) 460 | 461 | def _get_batch_fmtstr(self, num_batches): 462 | num_digits = len(str(num_batches // 1)) 463 | fmt = '{:' + str(num_digits) + 'd}' 464 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 465 | 466 | 467 | def adjust_learning_rate(optimizer, epoch, args): 468 | """Sets the learning rate to the initial LR decayed by 10 at 1/3 and 2/3 of training""" 469 | interval = args.epochs // 3 470 | lr = args.lr * (0.1 ** (epoch // interval)) 471 | for param_group in optimizer.param_groups: 472 | param_group['lr'] = lr 473 | 474 | 475 | def accuracy(output, target, topk=(1,)): 476 | """Computes the accuracy over the k top predictions for the specified values of k""" 477 | with torch.no_grad(): 478 | maxk = max(topk) 479 | batch_size = target.size(0) 480 | 481 | _, pred = output.topk(maxk, 1, True, True) 482 | pred = pred.t() 483 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 484 | 485 | res = [] 486 | for k in topk: 487 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 488 | res.append(correct_k.mul_(100.0 / batch_size)) 489 | return res 490 | 491 | 492 | if __name__ == '__main__': 493 | main() 494 | --------------------------------------------------------------------------------