├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── configs ├── __init__.py ├── cifar │ ├── __init__.py │ ├── resnet110.py │ └── resnet20.py ├── dgc │ ├── __init__.py │ ├── fp16.py │ ├── int32.py │ ├── mm.py │ ├── nm.py │ ├── wm0.py │ ├── wm5.py │ └── wm5o.py └── imagenet │ ├── __init__.py │ ├── cosine.py │ ├── resnet18.py │ ├── resnet50.py │ └── vgg16_bn.py ├── data ├── .gitignore └── docs │ ├── cifar-10.png │ ├── resnet.png │ ├── speedup.png │ └── teaser.png ├── dgc ├── __init__.py ├── clip_grad.py ├── compression.py ├── horovod │ ├── README.md │ ├── __init__.py │ ├── compression.py │ ├── horovod.june.6.6b77884.patch │ └── optimizer.py ├── memory.py └── optim │ ├── __init__.py │ └── sgd.py ├── requirements.txt ├── script ├── cifar.resnet110.sh ├── cifar.resnet20.sh ├── imagenet.resnet50.sh └── imagenet.vgg16.sh └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .vscode/ 132 | logs/ 133 | runs/ 134 | *.swp 135 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "torchpack"] 2 | path = torchpack 3 | url = https://github.com/synxlin/mini-torchpack.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2024] Yujun Lin 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Gradient Compression [[arXiv]](https://arxiv.org/pdf/1712.01887.pdf) 2 | 3 | ``` 4 | @inproceedings{lin2018dgc, 5 | title={{Deep Gradient Compression: Reducing the communication bandwidth for distributed training}}, 6 | author={Lin, Yujun and Han, Song and Mao, Huizi and Wang, Yu and Dally, William J}, 7 | booktitle={The International Conference on Learning Representations}, 8 | year={2018} 9 | } 10 | ``` 11 | 12 | ## Overview 13 | 14 | We release the PyTorch code of the [Deep Gradient Compression](https://arxiv.org/pdf/1712.01887.pdf). 15 | 16 |

17 | ![teaser](data/docs/teaser.png)
18 | Figure 1. Deep Gradient Compression (DGC) can reduce the communication bandwidth (transmit fewer gradients by pruning away small gradients), improve the scalability, and speed up distributed training.
19 |
20 | ![cifar-10](data/docs/cifar-10.png)
21 | ![resnet](data/docs/resnet.png)
22 | Figure 2. DGC maintains accuracy: learning curves of ResNet (the gradient sparsity is 99.9%).
23 |
24 | ![speedup](data/docs/speedup.png)
25 | Figure 3. DGC improves the scalability: speedup measured on NVIDIA TITAN RTX 2080Ti GPU cluster with 25 Gbps Ethernet.
26 |
27 | 28 | ## Content 29 | - [Prerequisites](#prerequisites) 30 | - [Code](#code) 31 | - [Training](#training) 32 | - [Known Issues and TODOs](#known-issues-and-todos) 33 | 34 | ## Prerequisites 35 | 36 | The code is built with following libraries (see [requirements.txt](requirements.txt)): 37 | - Python >= 3.7 38 | - [PyTorch](https://github.com/pytorch/pytorch) >= 1.5 39 | - [Horovod](https://github.com/horovod/horovod) >= 0.19.4 40 | - [numpy](https://github.com/numpy/numpy) 41 | - [tensorboardX](https://github.com/lanpa/tensorboardX) >= 1.2 42 | - [tqdm](https://github.com/tqdm/tqdm) 43 | - [openmpi](https://www.open-mpi.org/software/ompi/) >= 4.0 44 | 45 | ## Code 46 | 47 | The core code to implement DGC is in [dgc/compression.py](dgc/compression.py) and [dgc/memory.py](dgc/memory.py). 48 | 49 | - Gradient Accumulation and Momentum Correction 50 | ```python 51 | mmt = self.momentums[name] 52 | vec = self.velocities[name] 53 | if self.nesterov: 54 | mmt.add_(grad).mul_(self.momentum) 55 | vec.add_(mmt).add_(grad) 56 | else: 57 | mmt.mul_(self.momentum).add_(grad) 58 | vec.add_(mmt) 59 | return vec 60 | ``` 61 | 62 | - Sparsification 63 | ```python 64 | importance = tensor.abs() 65 | # sampling 66 | sample_start = random.randint(0, sample_stride - 1) 67 | samples = importance[sample_start::sample_stride] 68 | # thresholding 69 | threshold = torch.min(torch.topk(samples, top_k_samples, 0, largest=True, sorted=False)[0]) 70 | mask = torch.ge(importance, threshold) 71 | indices = mask.nonzero().view(-1) 72 | ``` 73 | 74 | ## Training 75 | We use [Horovod](https://github.com/horovod/horovod) to run distributed training: 76 | - run on a machine with *N* GPUs, 77 | ```bash 78 | horovodrun -np N python train.py --configs [config files] 79 | ``` 80 | e.g., resnet-20 on cifar-10 dataset with 8 GPUs: 81 | ```bash 82 | # fp16 values, int32 indices 83 | # warmup coeff: [0.25, 0.063, 0.015, 0.004, 0.001] -> 0.001 84 | horovodrun -np 8 python train.py --configs configs/cifar/resnet20.py \ 85 | configs/dgc/wm5.py configs/dgc/fp16.py configs/dgc/int32.py 86 | ``` 87 | - run on *K* machines with *N* GPUs each, 88 | ```bash 89 | mpirun -np [K*N] -H server0:N,server1:N,...,serverK:N \ 90 | -bind-to none -map-by slot -x NCCL_DEBUG=INFO \ 91 | -x LD_LIBRARY_PATH -x PATH -mca pml ob1 \ 92 | -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo \ 93 | python train.py --configs [config files] 94 | ``` 95 | e.g., resnet-50 on ImageNet dataset with 4 machines with 8 GPUs each, 96 | ```bash 97 | # fp32 values, int64 indices, no warmup 98 | mpirun -np 32 -H server0:8,server1:8,server2:8,server3:8 \ 99 | -bind-to none -map-by slot -x NCCL_DEBUG=INFO \ 100 | -x LD_LIBRARY_PATH -x PATH -mca pml ob1 \ 101 | -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo \ 102 | python train.py --configs configs/imagenet/resnet50.py \ 103 | configs/dgc/wm0.py 104 | ``` 105 | For more information on horovodrun, please read horovod documentations. 106 | 107 | You can modify/add new config files under [configs](configs) to change training settings. 
You can also modify some trivial configs in the command:
108 | ```bash
109 | python train.py --configs [config files] --[config name] [config value] --suffix [suffix of experiment directory]
110 | ```
111 | e.g.,
112 | ```bash
113 | horovodrun -np 8 python train.py --configs configs/cifar/resnet20.py \
114 | configs/dgc/wm5.py --configs.train.num_epochs 500 --suffix .e500
115 | ```
116 |
117 | Here are some reproduced results using a **0.1%** compression ratio (*i.e.*, `configs.train.compression.compress_ratio = 0.001`):
118 | | #GPUs | Batch Size | #Sparsified Nodes | ResNet-50 | VGG-16-BN | LR Scheduler |
119 | |:-----:|:----------:|:-----------------:|:---------:|:---------:|:------------:|
120 | | - | - | - | [76.2](https://pytorch.org/docs/stable/torchvision/models.html) | [73.4](https://pytorch.org/docs/stable/torchvision/models.html) | - |
121 | | 8 | 256 | 8 | 76.6 | 74.1 | MultiStep |
122 | | 16 | 512 | 16 | 76.5 | 73.8 | MultiStep |
123 | | 32 | 1024 | 32 | 76.3 | 73.3 | MultiStep |
124 | | 32 | 1024 | 32 | 76.7 | 74.4 | Cosine |
125 | | 64 | 2048 | 64 | 76.8 | 74.2 | Cosine |
126 | | 64 | 2048 | 8 | 76.6 | 73.8 | Cosine |
127 | | 128 | 4096 | 16 | 76.4 | 73.1 | Cosine |
128 | | 256 | 8192 | 32 | 75.9 | 71.7 | Cosine |
129 |
130 | ## Known Issues and TODOs
131 |
132 | - **Backend**: We currently only support the OpenMPI backend. We encountered some errors when calling `allgather` with the NCCL2 backend: the `allgather`ed data contain random values once in a while; if we set `CUDA_LAUNCH_BLOCKING=1` for debugging, everything works well.
133 | - **#Sparsified Nodes**: We currently treat each GPU as an independent node. However, communication is rarely a bottleneck within one machine. A better strategy would be to `allreduce` dense gradients intra-machine and `allgather` sparse gradients across machines.
134 |   - For accuracy/convergence verification, we can simulate this by setting `configs.train.num_batches_per_step` to the desired #GPUs per machine (see the accuracy table for batch size = 4096/8192).
135 | - **Sparsification Granularity**: We naively perform fine-grained (*i.e.*, element-wise) top-k to select gradients, so the communication suffers from increased `allgather` data volume as #nodes increases.
136 |   - [Sun *et al.*](https://arxiv.org/pdf/1902.06855.pdf) modified the process with coarse-grained sparsification: gradients are partitioned into chunks, the per-chunk L1-norms are `allreduce`d, and only the chunks selected by those norms are `allreduce`d, which gets rid of the `allgather` and avoids this problem (see the sketch in the appendix at the end of this README).
137 | - **Data Encoding**: We did not perform any data quantization/encoding before transmission. Data encoding can further reduce the transmitted data volume.
138 | - **Overhead**: Performing the sparsification (especially the adaptive thresholding) in C/C++ may further reduce the DGC overhead.
139 |
140 | ## License
141 |
142 | This repository is released under the Apache 2.0 license. See [LICENSE](LICENSE) for additional details.
143 |
144 |
145 | ## Acknowledgement
146 | - Our implementation is modified from [grace](https://github.com/sands-lab/grace), a unified framework for all sorts of compressed distributed training algorithms.
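
## Appendix: Simplified Sketches

The two snippets in the [Code](#code) section show the heart of DGC: momentum-corrected local accumulation followed by sampled top-k thresholding. Below is a minimal, single-process sketch that puts them together on a random tensor. It is only an illustration: the function name `dgc_local_step` and its default arguments are made up for this example, and the actual implementation (which adds warm-up, adaptive threshold re-selection, momentum masking, and Horovod communication) lives in [dgc/compression.py](dgc/compression.py) and [dgc/memory.py](dgc/memory.py).

```python
import torch


def dgc_local_step(grad, momentum_buf, velocity_buf, momentum=0.9,
                   compress_ratio=0.001, sample_ratio=0.01):
    # Momentum correction: accumulate momentum-corrected gradients locally
    # (the non-Nesterov branch of the snippet in the Code section).
    momentum_buf.mul_(momentum).add_(grad)
    velocity_buf.add_(momentum_buf)

    # Sparsification: estimate the top-k threshold on a random subsample,
    # then keep every accumulated element whose magnitude reaches it.
    flat = velocity_buf.view(-1)
    importance = flat.abs()
    num_samples = max(1, int(flat.numel() * sample_ratio))
    samples = importance[torch.randint(0, flat.numel(), (num_samples,))]
    top_k_samples = max(1, int(num_samples * compress_ratio))
    threshold = torch.topk(samples, top_k_samples, largest=True, sorted=False)[0].min()
    indices = torch.ge(importance, threshold).nonzero().view(-1)
    values = flat[indices]

    # Only the selected (value, index) pairs would be communicated (allgather);
    # the accumulation is cleared at those positions, so small gradients keep
    # accumulating locally until they grow large enough to be transmitted.
    momentum_buf.view(-1)[indices] = 0
    velocity_buf.view(-1)[indices] = 0
    return values, indices


# toy usage on a random "gradient" with 1M elements
grad = torch.randn(1 << 20)
mmt, vec = torch.zeros_like(grad), torch.zeros_like(grad)
values, indices = dgc_local_step(grad, mmt, vec)
print(f'transmitting {indices.numel()} / {grad.numel()} elements')
```

Clearing the accumulation only at the transmitted indices is what lets small gradients build up locally across iterations instead of being dropped.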
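The **Sparsification Granularity** item above mentions the coarse-grained alternative of [Sun *et al.*](https://arxiv.org/pdf/1902.06855.pdf). The sketch below is a rough, hypothetical illustration of that idea rather than code from this repository or from their paper: the chunk size, `keep_ratio`, and function name are invented for the example, and `hvd.allreduce` stands in for the reduction primitive a real implementation would tune.

```python
import torch
import horovod.torch as hvd


def chunked_sparse_allreduce(grad, chunk_size=2048, keep_ratio=0.001):
    # Partition the flattened gradient into fixed-size chunks (pad the tail).
    flat = grad.view(-1)
    pad = (-flat.numel()) % chunk_size
    chunks = torch.nn.functional.pad(flat, (0, pad)).view(-1, chunk_size)

    # Workers agree on which chunks matter by allreducing per-chunk L1 norms;
    # every worker then selects the same chunks, so no allgather is needed.
    norms = hvd.allreduce(chunks.abs().sum(dim=1), average=True, name='chunk_norms')
    num_keep = max(1, int(chunks.size(0) * keep_ratio))
    selected = torch.topk(norms, num_keep, largest=True, sorted=False)[1]

    # Only the selected chunks are (densely) allreduced; a full implementation
    # would keep the unselected chunks in local error-feedback memory.
    reduced = hvd.allreduce(chunks[selected].contiguous(), average=True, name='chunk_values')
    out = torch.zeros_like(chunks)
    out[selected] = reduced
    return out.view(-1)[:flat.numel()].view_as(grad)


if __name__ == '__main__':
    hvd.init()
    print(chunked_sparse_allreduce(torch.randn(1 << 20)).norm())
```

Because every worker derives the same `selected` set from the allreduced norms, the transmitted payload is dense and fixed-size, which avoids the `allgather` volume growing with the number of nodes.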
147 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchpack.mtpack.utils.config import Config, configs 4 | from torchpack.mtpack.meters import TopKClassMeter 5 | 6 | from dgc.horovod.compression import Compression 7 | 8 | 9 | configs.seed = 42 10 | configs.data = Config() 11 | configs.data.num_threads_per_worker = 4 12 | 13 | # criterion 14 | configs.train = Config() 15 | configs.train.dgc = False 16 | configs.train.compression = Compression.none 17 | configs.train.criterion = Config(torch.nn.CrossEntropyLoss) 18 | 19 | # optimizer 20 | configs.train.optimizer = Config(torch.optim.SGD) 21 | configs.train.optimizer.momentum = 0.9 22 | 23 | # scheduler 24 | configs.train.schedule_lr_per_epoch = True 25 | configs.train.warmup_lr_epochs = 5 26 | 27 | # metrics 28 | configs.train.metric = 'acc/test_top1' 29 | configs.train.meters = Config() 30 | configs.train.meters['acc/{}_top1'] = Config(TopKClassMeter) 31 | configs.train.meters['acc/{}_top1'].k = 1 32 | configs.train.meters['acc/{}_top5'] = Config(TopKClassMeter) 33 | configs.train.meters['acc/{}_top5'].k = 5 34 | -------------------------------------------------------------------------------- /configs/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchpack.mtpack.datasets.vision import CIFAR 4 | from torchpack.mtpack.utils.config import Config, configs 5 | 6 | # dataset 7 | configs.dataset = Config(CIFAR) 8 | configs.dataset.root = './data/cifar10' 9 | configs.dataset.num_classes = 10 10 | configs.dataset.image_size = 32 11 | 12 | # training 13 | configs.train.num_epochs = 200 14 | configs.train.batch_size = 128 15 | 16 | # optimizer 17 | configs.train.optimizer.lr = 0.1 18 | configs.train.optimizer.weight_decay = 1e-4 19 | 20 | # scheduler 21 | configs.train.scheduler = Config(torch.optim.lr_scheduler.CosineAnnealingLR) 22 | configs.train.scheduler.T_max = configs.train.num_epochs - configs.train.warmup_lr_epochs 23 | -------------------------------------------------------------------------------- /configs/cifar/resnet110.py: -------------------------------------------------------------------------------- 1 | from torchpack.mtpack.models.vision.resnet import resnet110 2 | 3 | from torchpack.mtpack.utils.config import Config, configs 4 | 5 | # model 6 | configs.model = Config(resnet110) 7 | configs.model.num_classes = configs.dataset.num_classes 8 | -------------------------------------------------------------------------------- /configs/cifar/resnet20.py: -------------------------------------------------------------------------------- 1 | from torchpack.mtpack.models.vision.resnet import resnet20 2 | 3 | from torchpack.mtpack.utils.config import Config, configs 4 | 5 | # model 6 | configs.model = Config(resnet20) 7 | configs.model.num_classes = configs.dataset.num_classes 8 | -------------------------------------------------------------------------------- /configs/dgc/__init__.py: -------------------------------------------------------------------------------- 1 | from torchpack.mtpack.utils.config import Config, configs 2 | 3 | from dgc.compression import DGCCompressor 4 | from dgc.memory import DGCSGDMemory 5 | from dgc.optim import DGCSGD 6 | 7 | 8 | configs.train.dgc = True 9 | configs.train.compression = Config(DGCCompressor) 10 | configs.train.compression.compress_ratio = 0.001 11 | 
configs.train.compression.sample_ratio = 0.01 12 | configs.train.compression.strided_sample = True 13 | configs.train.compression.compress_upper_bound = 1.3 14 | configs.train.compression.compress_lower_bound = 0.8 15 | configs.train.compression.max_adaptation_iters = 10 16 | configs.train.compression.resample = True 17 | 18 | old_optimizer = configs.train.optimizer 19 | configs.train.optimizer = Config(DGCSGD) 20 | for k, v in old_optimizer.items(): 21 | configs.train.optimizer[k] = v 22 | 23 | configs.train.compression.memory = Config(DGCSGDMemory) 24 | configs.train.compression.memory.momentum = configs.train.optimizer.momentum 25 | -------------------------------------------------------------------------------- /configs/dgc/fp16.py: -------------------------------------------------------------------------------- 1 | from torchpack.mtpack.utils.config import Config, configs 2 | 3 | configs.train.compression.fp16_values = True 4 | -------------------------------------------------------------------------------- /configs/dgc/int32.py: -------------------------------------------------------------------------------- 1 | from torchpack.mtpack.utils.config import Config, configs 2 | 3 | configs.train.compression.int32_indices = True 4 | -------------------------------------------------------------------------------- /configs/dgc/mm.py: -------------------------------------------------------------------------------- 1 | from torchpack.mtpack.utils.config import Config, configs 2 | 3 | configs.train.compression.memory.momentum_masking = True 4 | -------------------------------------------------------------------------------- /configs/dgc/nm.py: -------------------------------------------------------------------------------- 1 | from torchpack.mtpack.utils.config import Config, configs 2 | 3 | configs.train.compression.memory.momentum_masking = False 4 | -------------------------------------------------------------------------------- /configs/dgc/wm0.py: -------------------------------------------------------------------------------- 1 | from torchpack.mtpack.utils.config import Config, configs 2 | 3 | configs.train.compression.warmup_epochs = 0 4 | -------------------------------------------------------------------------------- /configs/dgc/wm5.py: -------------------------------------------------------------------------------- 1 | from torchpack.mtpack.utils.config import Config, configs 2 | 3 | configs.train.compression.warmup_epochs = 5 4 | -------------------------------------------------------------------------------- /configs/dgc/wm5o.py: -------------------------------------------------------------------------------- 1 | from torchpack.mtpack.utils.config import Config, configs 2 | 3 | configs.train.compression.warmup_epochs = 5 4 | configs.train.compression.warmup_coeff = [1, 1, 1, 1, 1] 5 | -------------------------------------------------------------------------------- /configs/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchpack.mtpack.datasets.vision import ImageNet 4 | from torchpack.mtpack.utils.config import Config, configs 5 | 6 | # dataset 7 | configs.dataset = Config(ImageNet) 8 | configs.dataset.root = '/dataset/imagenet' 9 | configs.dataset.num_classes = 1000 10 | configs.dataset.image_size = 224 11 | 12 | # training 13 | configs.train.num_epochs = 90 14 | configs.train.batch_size = 32 15 | 16 | # optimizer 17 | configs.train.optimize_bn_separately = False 18 | configs.train.optimizer.lr 
= 0.0125 19 | configs.train.optimizer.weight_decay = 5e-5 20 | 21 | # scheduler 22 | configs.train.scheduler = Config(torch.optim.lr_scheduler.MultiStepLR) 23 | configs.train.scheduler.milestones = [e - configs.train.warmup_lr_epochs 24 | for e in [30, 60, 80]] 25 | configs.train.scheduler.gamma = 0.1 26 | -------------------------------------------------------------------------------- /configs/imagenet/cosine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchpack.mtpack.utils.config import Config, configs 4 | 5 | # scheduler 6 | configs.train.scheduler = Config(torch.optim.lr_scheduler.CosineAnnealingLR) 7 | configs.train.scheduler.T_max = configs.train.num_epochs - configs.train.warmup_lr_epochs 8 | -------------------------------------------------------------------------------- /configs/imagenet/resnet18.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import resnet18 2 | 3 | from torchpack.mtpack.utils.config import Config, configs 4 | 5 | configs.train.batch_size = 64 6 | configs.train.optimizer.lr = 0.025 7 | 8 | # model 9 | configs.model = Config(resnet18) 10 | configs.model.num_classes = configs.dataset.num_classes 11 | configs.model.zero_init_residual = True 12 | -------------------------------------------------------------------------------- /configs/imagenet/resnet50.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import resnet50 2 | 3 | from torchpack.mtpack.utils.config import Config, configs 4 | 5 | configs.train.optimizer.weight_decay = 1e-4 6 | configs.train.optimizer.nesterov = True 7 | configs.train.optimize_bn_separately = True 8 | 9 | # model 10 | configs.model = Config(resnet50) 11 | configs.model.num_classes = configs.dataset.num_classes 12 | configs.model.zero_init_residual = True 13 | -------------------------------------------------------------------------------- /configs/imagenet/vgg16_bn.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import vgg16_bn 2 | 3 | from torchpack.mtpack.utils.config import Config, configs 4 | 5 | # model 6 | configs.model = Config(vgg16_bn) 7 | configs.model.num_classes = configs.dataset.num_classes 8 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !docs/ 3 | !docs/* 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /data/docs/cifar-10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synxlin/deep-gradient-compression/24673d45d11bba08bd4554b4585cae65e3dcf6f1/data/docs/cifar-10.png -------------------------------------------------------------------------------- /data/docs/resnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synxlin/deep-gradient-compression/24673d45d11bba08bd4554b4585cae65e3dcf6f1/data/docs/resnet.png -------------------------------------------------------------------------------- /data/docs/speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synxlin/deep-gradient-compression/24673d45d11bba08bd4554b4585cae65e3dcf6f1/data/docs/speedup.png 
-------------------------------------------------------------------------------- /data/docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synxlin/deep-gradient-compression/24673d45d11bba08bd4554b4585cae65e3dcf6f1/data/docs/teaser.png -------------------------------------------------------------------------------- /dgc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/synxlin/deep-gradient-compression/24673d45d11bba08bd4554b4585cae65e3dcf6f1/dgc/__init__.py -------------------------------------------------------------------------------- /dgc/clip_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._six import inf 3 | 4 | from horovod.torch import allreduce_ 5 | 6 | __all__ = ['clip_grad_norm_', 'clip_grad_value_', 'clip_grad_value_by_global_norm_', 'clip_grad_norm_2_by_global_'] 7 | 8 | 9 | # code modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py 10 | def clip_grad_norm_(grad, max_norm, norm_type=2): 11 | max_norm = float(max_norm) 12 | norm_type = float(norm_type) 13 | if norm_type == inf: 14 | total_norm = grad.data.abs().max() 15 | else: 16 | total_norm = grad.data.norm(norm_type) 17 | clip_coef = max_norm / (total_norm + 1e-6) 18 | if clip_coef < 1: 19 | grad.data.mul_(clip_coef) 20 | return grad 21 | 22 | 23 | def clip_grad_value_(grad, clip_value): 24 | clip_value = float(clip_value) 25 | grad.data.clamp_(min=-clip_value, max=clip_value) 26 | 27 | 28 | # code modified from https://github.com/sands-lab/grace/blob/master/grace_dl/torch/memory/dgc.py 29 | def clip_grad_value_by_global_norm_(grad, name=None): 30 | grad_square_sum = torch.sum(grad.square()) 31 | clip_value = torch.sqrt(allreduce_(grad_square_sum, average=True, name=name)) 32 | grad.data.clamp_(min=-clip_value, max=clip_value) 33 | 34 | 35 | def clip_grad_norm_2_by_global_(grad, max_norm, name=None): 36 | max_norm = float(max_norm) 37 | grad_square_sum = torch.sum(grad.square()) 38 | total_norm = torch.sqrt(allreduce_(grad_square_sum, average=True, name=name)) 39 | clip_coef = max_norm / (total_norm + 1e-6) 40 | if clip_coef < 1: 41 | grad.data.mul_(clip_coef) 42 | return grad 43 | -------------------------------------------------------------------------------- /dgc/compression.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import torch 5 | 6 | import horovod.torch as hvd 7 | from horovod.torch.mpi_ops import Average 8 | from horovod.torch.mpi_ops import allreduce_async_ 9 | from horovod.torch.mpi_ops import allgather_async as allgather_async_ 10 | from horovod.torch.mpi_ops import synchronize as synchronize_ 11 | 12 | from dgc.memory import Memory 13 | 14 | __all__ = ['DGCCompressor'] 15 | 16 | 17 | class DGCCompressor: 18 | def __init__(self, compress_ratio, memory=None, 19 | sample_ratio=0.01, strided_sample=True, 20 | compress_upper_bound=1.3, compress_lower_bound=0.8, max_adaptation_iters=10, resample=True, 21 | fp16_values=False, int32_indices=False, 22 | warmup_epochs=-1, warmup_coeff=None): 23 | self.world_size = hvd.size() 24 | self.op = Average 25 | self.fp16_values = fp16_values 26 | self.int32_indices = int32_indices 27 | 28 | self.base_compress_ratio = self.compress_ratio = \ 29 | compress_ratio if compress_ratio <= 1.0 else 1.0 / compress_ratio 30 | self.memory = 
Memory if memory is None else memory 31 | self.warmup_epochs = warmup_epochs 32 | if self.warmup_epochs > 0: 33 | if warmup_coeff is None: 34 | self.warmup_coeff = self.base_compress_ratio \ 35 | ** (1. / (self.warmup_epochs + 1)) 36 | else: 37 | if isinstance(warmup_coeff, (tuple, list)): 38 | assert len(warmup_coeff) >= self.warmup_epochs 39 | for wc in warmup_coeff: 40 | assert 0 < wc <= 1 41 | else: 42 | assert 0 < warmup_coeff <= 1 43 | self.warmup_coeff = warmup_coeff 44 | else: 45 | self.warmup_coeff = 1 46 | 47 | self.sample_ratio = min(max(sample_ratio, 0.01), 1.0) 48 | self.strided_sample = strided_sample 49 | self.compress_upper_bound = compress_upper_bound 50 | self.compress_lower_bound = compress_lower_bound 51 | self.max_adaptation_iters = max_adaptation_iters 52 | self.resample = resample 53 | 54 | self.attributes = {} 55 | 56 | def initialize(self, named_parameters): 57 | if hvd.rank() == 0: 58 | print("=> initializing dgc compressor") 59 | for name, param in named_parameters: 60 | if torch.is_tensor(param): 61 | numel = param.numel() 62 | shape = list(param.size()) 63 | else: 64 | assert isinstance(param, (list, tuple)) 65 | numel, shape = param[0], param[1] 66 | if self.sample_ratio < 1.0: 67 | pct_numel = int(math.ceil(numel * self.sample_ratio)) 68 | cpr_numel = int(math.ceil(2 / self.compress_ratio)) 69 | if numel <= cpr_numel: 70 | if hvd.rank() == 0: 71 | print(f'Warning: {name} with {numel} elements transmits 1 gradient element') 72 | sample_stride = 1 73 | num_samples = numel 74 | else: 75 | sample_stride = int(math.ceil(numel / max(pct_numel, cpr_numel) / 32)) * 32 + 1 76 | num_samples = numel // sample_stride 77 | while num_samples < max(pct_numel, cpr_numel): 78 | sample_stride = sample_stride - 8 79 | num_samples = numel // sample_stride 80 | else: 81 | sample_stride = 1 82 | num_samples = numel 83 | top_k_samples = int(math.ceil(num_samples * self.compress_ratio)) 84 | num_selects = int(math.ceil(numel * self.compress_ratio)) 85 | self.attributes[name] = (numel, shape, num_selects, num_samples, top_k_samples, sample_stride) 86 | if hvd.rank() == 0: 87 | print(f' {name:<25}: transmit {num_selects} / {numel} elements of shape {shape}\n' 88 | f' {" " * 25} threshold {top_k_samples} / {num_samples} samples' 89 | f' {f"at stride {sample_stride}" if self.strided_sample else "uniformly"}') 90 | 91 | def warmup_compress_ratio(self, epoch): 92 | if self.warmup_epochs > 0: 93 | if epoch < self.warmup_epochs: 94 | if isinstance(self.warmup_coeff, (tuple, list)): 95 | compress_ratio = self.warmup_coeff[epoch] 96 | else: 97 | compress_ratio = max(self.warmup_coeff ** (epoch + 1), 98 | self.base_compress_ratio) 99 | else: 100 | compress_ratio = self.base_compress_ratio 101 | else: 102 | compress_ratio = self.base_compress_ratio 103 | if compress_ratio != self.compress_ratio: 104 | if hvd.rank() == 0: 105 | print(f'update compress ratio: {compress_ratio}') 106 | self.compress_ratio = compress_ratio 107 | self.initialize(self.attributes.items()) 108 | 109 | def _sparsify(self, tensor, name): 110 | tensor = tensor.view(-1) 111 | numel, shape, num_selects, num_samples, top_k_samples, sample_stride = self.attributes[name] 112 | 113 | importance = tensor.abs() 114 | if numel == num_samples: 115 | samples = importance 116 | else: 117 | if self.strided_sample: 118 | sample_start = random.randint(0, sample_stride - 1) 119 | samples = importance[sample_start::sample_stride] 120 | else: 121 | samples = importance[torch.randint(0, numel, (num_samples, ), device=tensor.device)] 122 | 
123 | threshold = torch.min(torch.topk(samples, top_k_samples, 0, largest=True, sorted=False)[0]) 124 | mask = torch.ge(importance, threshold) 125 | indices = mask.nonzero().view(-1) 126 | num_indices = indices.numel() 127 | 128 | if numel > num_samples: 129 | # code modified from https://github.com/sands-lab/grace/blob/master/grace_dl/torch/compressor/dgc.py 130 | for _ in range(self.max_adaptation_iters): 131 | if num_indices > num_selects: 132 | if num_indices > num_selects * self.compress_upper_bound: 133 | if self.resample: 134 | indices = indices[ 135 | torch.topk(importance[indices], num_selects, 136 | 0, largest=True, sorted=False)[1] 137 | ] 138 | break 139 | else: 140 | threshold = threshold * self.compress_upper_bound 141 | else: 142 | break 143 | elif num_indices < self.compress_lower_bound * num_selects: 144 | threshold = threshold * self.compress_lower_bound 145 | else: 146 | break 147 | mask = torch.ge(importance, threshold) 148 | indices = mask.nonzero().view(-1) 149 | num_indices = indices.numel() 150 | 151 | indices = indices[:num_selects] 152 | values = tensor[indices] 153 | return values, indices, numel, shape, num_selects 154 | 155 | def compress(self, tensor, name): 156 | if self.compress_ratio < 1.0 and name in self.attributes: 157 | # compress 158 | tensor_compensated = self.memory.compensate( 159 | tensor, name, accumulate=True) 160 | values, indices, numel, shape, num_selects = \ 161 | self._sparsify(tensor_compensated, name) 162 | self.memory.update(name, (indices, )) 163 | indices = indices.view(-1, 1) 164 | values = values.view(-1, 1) 165 | 166 | ctx = (name, numel, shape, values.dtype, indices.dtype, 167 | tensor.data.view(numel)) 168 | if self.fp16_values and values.dtype.is_floating_point: 169 | values = values.type(torch.float16) 170 | if self.int32_indices and not indices.dtype.is_floating_point: 171 | indices = indices.type(torch.int32) 172 | return (values, indices), ctx 173 | else: 174 | ctx = (name, None, None, tensor.dtype, None, None) 175 | if self.fp16_values and tensor.dtype.is_floating_point: 176 | tensor = tensor.type(torch.float16) 177 | return tensor, ctx 178 | 179 | def decompress(self, tensor, ctx): 180 | name, numel, shape, vdtype, idtype, grad = ctx 181 | if self.compress_ratio < 1.0 and name in self.attributes: 182 | # decompress 183 | assert isinstance(tensor, (list, tuple)) 184 | values, indices = tensor 185 | values = values.view(-1) 186 | indices = indices.view(-1) 187 | if self.fp16_values and vdtype.is_floating_point: 188 | values = values.type(vdtype) 189 | if self.int32_indices and not idtype.is_floating_point: 190 | indices = indices.type(idtype) 191 | grad.zero_().index_put_([indices], values, accumulate=True) 192 | if self.op == Average: 193 | grad.mul_(1. 
/ self.world_size) 194 | return grad.view(shape) 195 | else: 196 | if self.fp16_values and vdtype.is_floating_point: 197 | tensor = tensor.type(vdtype) 198 | return self.memory.compensate(tensor, name, accumulate=False) 199 | 200 | def communicate(self, tensor_compressed, name, op): 201 | self.op = op 202 | if self.compress_ratio < 1.0 and name in self.attributes: 203 | return [allgather_async_(t, name=f'{name}.t{e}') 204 | for e, t in enumerate(tensor_compressed)] 205 | else: 206 | return allreduce_async_(tensor_compressed, name=name, op=op) 207 | 208 | def synchronize(self, handle): 209 | if isinstance(handle, (tuple, list)): 210 | return [synchronize_(h) for h in handle] 211 | else: 212 | return synchronize_(handle) 213 | -------------------------------------------------------------------------------- /dgc/horovod/README.md: -------------------------------------------------------------------------------- 1 | # Horovod Patch 2 | 3 | We applied [patch](horovod.june.6.6b77884.patch) to Horovod at [this commit](https://github.com/horovod/horovod/tree/6b77884daf92649ecf031fcc8ff29697bbea0132). 4 | Nonetheless, we copied the modified files in this directory so that you don't have to patch Horovod source code. 5 | 6 | The modification is very subtile: 7 | 8 | - class `Compressor` will take `name` as another argument when compressing a tensor. 9 | 10 | - class `DistributedOptimizer` will perform `communicate()` of its compression member if possible, instead of always using `allreduce_async()`. 11 | -------------------------------------------------------------------------------- /dgc/horovod/__init__.py: -------------------------------------------------------------------------------- 1 | from dgc.horovod.compression import Compressor, Compression 2 | from dgc.horovod.optimizer import DistributedOptimizer 3 | -------------------------------------------------------------------------------- /dgc/horovod/compression.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Uber Technologies, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Gradient compression algorithms.""" 16 | 17 | import torch 18 | 19 | __all__ = ['Compressor', 'Compression'] 20 | 21 | 22 | class Compressor(object): 23 | """Interface for compressing and decompressing a given tensor.""" 24 | @staticmethod 25 | def compress(tensor, name=None): 26 | """Compresses a tensor and returns it with the context needed to decompress it.""" 27 | pass 28 | 29 | @staticmethod 30 | def decompress(tensor, ctx): 31 | """Decompress the tensor with the given context.""" 32 | pass 33 | 34 | 35 | class NoneCompressor(Compressor): 36 | """Default no-op compression.""" 37 | @staticmethod 38 | def compress(tensor, name=None): 39 | """Returns the tensor unmodified.""" 40 | return tensor, None 41 | 42 | @staticmethod 43 | def decompress(tensor, ctx): 44 | """Returns the tensor unmodified.""" 45 | return tensor 46 | 47 | 48 | class FP16Compressor(Compressor): 49 | """Compress all floating point gradients to 16-bit.""" 50 | @staticmethod 51 | def compress(tensor, name=None): 52 | """Downcasts the tensor to 16-bit.""" 53 | tensor_compressed = tensor 54 | if tensor.dtype.is_floating_point: 55 | # Only allow compression from other floating point types 56 | tensor_compressed = tensor.type(torch.float16) 57 | return tensor_compressed, tensor.dtype 58 | 59 | @staticmethod 60 | def decompress(tensor, ctx): 61 | """Upcasts the tensor to the initialization dtype.""" 62 | tensor_decompressed = tensor 63 | dtype = ctx 64 | if dtype.is_floating_point: 65 | tensor_decompressed = tensor.type(dtype) 66 | return tensor_decompressed 67 | 68 | 69 | class Compression(object): 70 | """Optional gradient compression algorithm used during allreduce.""" 71 | 72 | """Do not compress the gradients. 
This is the default.""" 73 | none = NoneCompressor 74 | 75 | """Compress all floating point gradients to 16-bit.""" 76 | fp16 = FP16Compressor 77 | -------------------------------------------------------------------------------- /dgc/horovod/horovod.june.6.6b77884.patch: -------------------------------------------------------------------------------- 1 | diff --git a/horovod/torch/compression.py b/horovod/torch/compression.py 2 | index 75ce91e..ca160ce 100644 3 | --- a/horovod/torch/compression.py 4 | +++ b/horovod/torch/compression.py 5 | @@ -20,7 +20,7 @@ import torch 6 | class Compressor(object): 7 | """Interface for compressing and decompressing a given tensor.""" 8 | @staticmethod 9 | - def compress(tensor): 10 | + def compress(tensor, name=None): 11 | """Compresses a tensor and returns it with the context needed to decompress it.""" 12 | pass 13 | 14 | @@ -33,7 +33,7 @@ class Compressor(object): 15 | class NoneCompressor(Compressor): 16 | """Default no-op compression.""" 17 | @staticmethod 18 | - def compress(tensor): 19 | + def compress(tensor, name=None): 20 | """Returns the tensor unmodified.""" 21 | return tensor, None 22 | 23 | @@ -46,7 +46,7 @@ class NoneCompressor(Compressor): 24 | class FP16Compressor(Compressor): 25 | """Compress all floating point gradients to 16-bit.""" 26 | @staticmethod 27 | - def compress(tensor): 28 | + def compress(tensor, name=None): 29 | """Downcasts the tensor to 16-bit.""" 30 | tensor_compressed = tensor 31 | if tensor.dtype.is_floating_point: 32 | diff --git a/horovod/torch/optimizer.py b/horovod/torch/optimizer.py 33 | index c8def4b..d954146 100644 34 | --- a/horovod/torch/optimizer.py 35 | +++ b/horovod/torch/optimizer.py 36 | @@ -23,7 +23,7 @@ import torch 37 | 38 | from horovod.torch.compression import Compression 39 | from horovod.torch.mpi_ops import allreduce_async_ 40 | -from horovod.torch.mpi_ops import synchronize 41 | +from horovod.torch.mpi_ops import synchronize as synchronize_ 42 | from horovod.torch.mpi_ops import size 43 | from horovod.torch.mpi_ops import Average, Adasum 44 | 45 | @@ -33,6 +33,8 @@ class _DistributedOptimizer(torch.optim.Optimizer): 46 | backward_passes_per_step=1, op=Average): 47 | super(self.__class__, self).__init__(params) 48 | self._compression = compression 49 | + self._communicate_ = getattr(self._compression, 'communicate', allreduce_async_) 50 | + self._synchronize_ = getattr(self._compression, 'synchronize', synchronize_) 51 | 52 | if named_parameters is not None: 53 | named_parameters = list(named_parameters) 54 | @@ -111,9 +113,9 @@ class _DistributedOptimizer(torch.optim.Optimizer): 55 | def _allreduce_grad_async(self, p): 56 | name = self._parameter_names.get(p) 57 | tensor = p.grad 58 | - tensor_compressed, ctx = self._compression.compress(tensor) 59 | + tensor_compressed, ctx = self._compression.compress(tensor, name) 60 | 61 | - handle = allreduce_async_(tensor_compressed, name=name, op=self.op) 62 | + handle = self._communicate_(tensor_compressed, name=name, op=self.op) 63 | return handle, ctx 64 | 65 | def _make_hook(self, p): 66 | @@ -140,13 +142,12 @@ class _DistributedOptimizer(torch.optim.Optimizer): 67 | handle, ctx = self._allreduce_grad_async(p) 68 | self._handles[p] = (handle, ctx) 69 | 70 | - for p, value in self._handles.items(): 71 | - handle, ctx = value 72 | + for p, (handle, ctx) in self._handles.items(): 73 | if handle is None: 74 | handle, ctx = self._allreduce_grad_async(p) 75 | self._handles[p] = (handle, ctx) 76 | - for p, (handle, _) in self._handles.items(): 77 | - output = 
synchronize(handle) 78 | + for p, (handle, ctx) in self._handles.items(): 79 | + output = self._synchronize_(handle) 80 | self._allreduce_delay[p] = self.backward_passes_per_step 81 | p.grad.set_(self._compression.decompress(output, ctx)) 82 | self._handles.clear() 83 | @@ -200,6 +201,8 @@ class _DistributedAdasumOptimizer(torch.optim.Optimizer): 84 | super(self.__class__, self).__init__(params) 85 | 86 | self._compression = compression 87 | + self._communicate_ = getattr(self._compression, 'communicate', allreduce_async_) 88 | + self._synchronize_ = getattr(self._compression, 'synchronize', synchronize_) 89 | 90 | if named_parameters is not None: 91 | named_parameters = list(named_parameters) 92 | @@ -298,8 +301,8 @@ class _DistributedAdasumOptimizer(torch.optim.Optimizer): 93 | p.data.sub_(start) 94 | 95 | # allreduce as before 96 | - tensor_compressed, ctx = self._compression.compress(p) 97 | - handle = allreduce_async_(tensor_compressed.data, name=name, op=Adasum) 98 | + tensor_compressed, ctx = self._compression.compress(p, name) 99 | + handle = self._communicate_(tensor_compressed.data, name=name, op=Adasum) 100 | 101 | # reset stashed parameters 102 | for stashed, group in zip(stashed_params, self.param_groups): 103 | @@ -348,7 +351,7 @@ class _DistributedAdasumOptimizer(torch.optim.Optimizer): 104 | if not handle: 105 | handle, ctx = self._allreduce_grad_async(p) 106 | self._handles[p] = (handle, ctx) 107 | - delta = synchronize(handle) 108 | + delta = self._synchronize_(handle) 109 | delta = self._compression.decompress(delta, ctx) 110 | start = self._starting_models[p] 111 | start.data.add_(delta.data) 112 | -------------------------------------------------------------------------------- /dgc/horovod/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Uber Technologies, Inc. All Rights Reserved. 2 | # Modifications copyright Microsoft 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | import os 18 | import warnings 19 | 20 | from contextlib import contextmanager 21 | 22 | import torch 23 | 24 | from horovod.torch.mpi_ops import allreduce_async_ 25 | from horovod.torch.mpi_ops import synchronize as synchronize_ 26 | from horovod.torch.mpi_ops import size 27 | from horovod.torch.mpi_ops import Average, Adasum 28 | 29 | from .compression import Compression 30 | 31 | __all__ = ['DistributedOptimizer'] 32 | 33 | 34 | class _DistributedOptimizer(torch.optim.Optimizer): 35 | def __init__(self, params, named_parameters, compression, 36 | backward_passes_per_step=1, op=Average): 37 | super(self.__class__, self).__init__(params) 38 | self._compression = compression 39 | self._communicate_ = getattr(self._compression, 'communicate', allreduce_async_) 40 | self._synchronize_ = getattr(self._compression, 'synchronize', synchronize_) 41 | 42 | if named_parameters is not None: 43 | named_parameters = list(named_parameters) 44 | else: 45 | named_parameters = [('allreduce.noname.%s' % i, v) 46 | for param_group in self.param_groups 47 | for i, v in enumerate(param_group['params'])] 48 | # make sure that named_parameters are tuples 49 | if any([not isinstance(p, tuple) for p in named_parameters]): 50 | raise ValueError('named_parameters should be a sequence of ' 51 | 'tuples (name, parameter), usually produced by ' 52 | 'model.named_parameters().') 53 | 54 | dups = _DistributedOptimizer.find_duplicates([k for k, _ in named_parameters]) 55 | if len(dups) > 0: 56 | raise ValueError('Parameter names in named_parameters must be unique. ' 57 | 'Found duplicates: %s' % ', '.join(dups)) 58 | 59 | all_param_ids = {id(v) 60 | for param_group in self.param_groups 61 | for v in param_group['params']} 62 | named_param_ids = {id(v) for k, v in named_parameters} 63 | unnamed_param_ids = all_param_ids - named_param_ids 64 | if len(unnamed_param_ids): 65 | raise ValueError('named_parameters was specified, but one or more model ' 66 | 'parameters were not named. 
Python object ids: ' 67 | '%s' % ', '.join(str(id) for id in unnamed_param_ids)) 68 | 69 | self._parameter_names = {v: k for k, v in sorted(named_parameters)} 70 | self.backward_passes_per_step = backward_passes_per_step 71 | self._allreduce_delay = {v: self.backward_passes_per_step 72 | for _, v in sorted(named_parameters)} 73 | self.op = op 74 | self._handles = {} 75 | self._grad_accs = [] 76 | self._requires_update = set() 77 | self._synchronized = False 78 | self._should_synchronize = True 79 | if size() > 1 or os.environ.get('HOROVOD_ELASTIC') == '1': 80 | self._register_hooks() 81 | 82 | def load_state_dict(self, *args, **kwargs): 83 | self._handles = {} 84 | self._synchronized = False 85 | self._should_synchronize = True 86 | for p in self._allreduce_delay: 87 | self._allreduce_delay[p] = self.backward_passes_per_step 88 | super(self.__class__, self).load_state_dict(*args, **kwargs) 89 | 90 | @staticmethod 91 | def find_duplicates(lst): 92 | seen = set() 93 | dups = set() 94 | for el in lst: 95 | if el in seen: 96 | dups.add(el) 97 | seen.add(el) 98 | return dups 99 | 100 | def set_backward_passes_per_step(self, passes): 101 | self.backward_passes_per_step = passes 102 | for p in self._allreduce_delay: 103 | self._allreduce_delay[p] = self.backward_passes_per_step 104 | 105 | def _register_hooks(self): 106 | for param_group in self.param_groups: 107 | for p in param_group['params']: 108 | if p.requires_grad: 109 | p.grad = p.data.new(p.size()).zero_() 110 | self._requires_update.add(p) 111 | p_tmp = p.expand_as(p) 112 | grad_acc = p_tmp.grad_fn.next_functions[0][0] 113 | grad_acc.register_hook(self._make_hook(p)) 114 | self._grad_accs.append(grad_acc) 115 | 116 | def _allreduce_grad_async(self, p): 117 | name = self._parameter_names.get(p) 118 | tensor_compressed, ctx = self._compression.compress(p.grad, name) 119 | 120 | handle = self._communicate_(tensor_compressed, name=name, op=self.op) 121 | return handle, ctx 122 | 123 | def _make_hook(self, p): 124 | def hook(*ignore): 125 | if p in self._handles and self._handles[p][0] is not None: 126 | if self._allreduce_delay[p] <= 0: 127 | raise AssertionError( 128 | "Gradients were computed more than " 129 | "backward_passes_per_step times before call " 130 | "to step(). Increase backward_passes_per_step to " 131 | "accumulate gradients locally.") 132 | assert not p.grad.requires_grad 133 | assert self._allreduce_delay[p] > 0 134 | handle, ctx = None, None 135 | self._allreduce_delay[p] -= 1 136 | if self._allreduce_delay[p] == 0: 137 | handle, ctx = self._allreduce_grad_async(p) 138 | self._handles[p] = (handle, ctx) 139 | return hook 140 | 141 | def synchronize(self): 142 | missing_p = self._requires_update - set(self._handles.keys()) 143 | for p in missing_p: 144 | handle, ctx = self._allreduce_grad_async(p) 145 | self._handles[p] = (handle, ctx) 146 | 147 | for p, (handle, ctx) in self._handles.items(): 148 | if handle is None: 149 | handle, ctx = self._allreduce_grad_async(p) 150 | self._handles[p] = (handle, ctx) 151 | for p, (handle, ctx) in self._handles.items(): 152 | output = self._synchronize_(handle) 153 | self._allreduce_delay[p] = self.backward_passes_per_step 154 | p.grad.set_(self._compression.decompress(output, ctx)) 155 | self._handles.clear() 156 | 157 | self._synchronized = True 158 | 159 | @contextmanager 160 | def skip_synchronize(self): 161 | """ 162 | A context manager used to specify that optimizer.step() should 163 | not perform synchronization. 
164 | 165 | It's typically used in a following pattern: 166 | 167 | .. code-block:: python 168 | 169 | optimizer.synchronize() 170 | with optimizer.skip_synchronize(): 171 | optimizer.step() 172 | """ 173 | self._should_synchronize = False 174 | try: 175 | yield 176 | finally: 177 | self._should_synchronize = True 178 | 179 | def step(self, closure=None): 180 | if self._should_synchronize: 181 | if self._synchronized: 182 | warnings.warn("optimizer.step() called without " 183 | "optimizer.skip_synchronize() context after " 184 | "optimizer.synchronize(). This can cause training " 185 | "slowdown. You may want to consider using " 186 | "optimizer.skip_synchronize() context if you use " 187 | "optimizer.synchronize() in your code.") 188 | self.synchronize() 189 | self._synchronized = False 190 | return super(self.__class__, self).step(closure) 191 | 192 | def zero_grad(self): 193 | if self._handles: 194 | raise AssertionError("optimizer.zero_grad() was called after loss.backward() " 195 | "but before optimizer.step() or optimizer.synchronize(). " 196 | "This is prohibited as it can cause a race condition.") 197 | return super(self.__class__, self).zero_grad() 198 | 199 | 200 | class _DistributedAdasumOptimizer(torch.optim.Optimizer): 201 | def __init__(self, params, named_parameters, compression, 202 | backward_passes_per_step=1): 203 | super(self.__class__, self).__init__(params) 204 | 205 | self._compression = compression 206 | self._communicate_ = getattr(self._compression, 'communicate', allreduce_async_) 207 | self._synchronize_ = getattr(self._compression, 'synchronize', synchronize_) 208 | 209 | if named_parameters is not None: 210 | named_parameters = list(named_parameters) 211 | else: 212 | named_parameters = [('allreduce.noname.%s' % i, v) 213 | for param_group in self.param_groups 214 | for i, v in enumerate(param_group['params'])] 215 | 216 | # make sure that named_parameters are tuples 217 | if any([not isinstance(p, tuple) for p in named_parameters]): 218 | raise ValueError('named_parameters should be a sequence of ' 219 | 'tuples (name, parameter), usually produced by ' 220 | 'model.named_parameters().') 221 | 222 | dups = _DistributedOptimizer.find_duplicates([k for k, _ in named_parameters]) 223 | if len(dups) > 0: 224 | raise ValueError('Parameter names in named_parameters must be unique. ' 225 | 'Found duplicates: %s' % ', '.join(dups)) 226 | 227 | all_param_ids = {id(v) 228 | for param_group in self.param_groups 229 | for v in param_group['params']} 230 | named_param_ids = {id(v) for k, v in named_parameters} 231 | unnamed_param_ids = all_param_ids - named_param_ids 232 | if len(unnamed_param_ids): 233 | raise ValueError('named_parameters was specified, but one or more model ' 234 | 'parameters were not named. 
Python object ids: ' 235 | '%s' % ', '.join(str(id) for id in unnamed_param_ids)) 236 | 237 | self._parameter_names = {v: k for k, v in sorted(named_parameters)} 238 | self.backward_passes_per_step = backward_passes_per_step 239 | self._allreduce_delay = {v: self.backward_passes_per_step 240 | for _, v in sorted(named_parameters)} 241 | self._handles = {} 242 | self._grad_accs = [] 243 | self._requires_update = set() 244 | self._synchronized = False 245 | self._should_synchronize = True 246 | 247 | self._starting_models = { 248 | p : torch.zeros_like(p, requires_grad=False) 249 | for _, p in named_parameters 250 | } 251 | 252 | self._register_hooks() 253 | 254 | def set_backward_passes_per_step(self, passes): 255 | self.backward_passes_per_step = passes 256 | for p in self._allreduce_delay: 257 | self._allreduce_delay[p] = self.backward_passes_per_step 258 | 259 | def _register_hooks(self): 260 | for param_group in self.param_groups: 261 | for p in param_group['params']: 262 | if p.requires_grad: 263 | p.grad = p.data.new(p.size()).zero_() 264 | self._requires_update.add(p) 265 | p_tmp = p.expand_as(p) 266 | grad_acc = p_tmp.grad_fn.next_functions[0][0] 267 | grad_acc.register_hook(self._make_hook(p)) 268 | self._grad_accs.append(grad_acc) 269 | 270 | def _allreduce_grad_async(self, p): 271 | # Delta optimizer implements this logic: 272 | # start = current.copy() 273 | # step() -> computes 'current - \alpha.f(g)' where f is 274 | # optimizer logic and g is the gradient 275 | # delta = current-start 276 | # allreduce_(delta) 277 | # start += delta 278 | # current = start 279 | # In order to support this logic using a function hook to improve performance, 280 | # we do: 281 | # delta = (start - \alpha.f(g)) - start 282 | # = -\alpha.f(g) 283 | # set start to zero and step computes -\alpha.f(g) 284 | # where f is the underlying optimizer logic 285 | 286 | name = self._parameter_names.get(p) 287 | start = self._starting_models[p] 288 | 289 | stashed_params = [] 290 | for group in self.param_groups: 291 | stashed_params.append(group['params']) 292 | # only want to step on p 293 | if any([p is v for v in group['params']]): 294 | group['params'] = [p] 295 | else: 296 | group['params'] = [] 297 | 298 | start.data.copy_(p) 299 | 300 | super(self.__class__, self).step() 301 | 302 | # compute delta = curr - start 303 | p.data.sub_(start) 304 | 305 | # allreduce as before 306 | tensor_compressed, ctx = self._compression.compress(p, name) 307 | handle = self._communicate_(tensor_compressed.data, name=name, op=Adasum) 308 | 309 | # reset stashed parameters 310 | for stashed, group in zip(stashed_params, self.param_groups): 311 | group['params'] = stashed 312 | 313 | return handle, ctx 314 | 315 | def _make_hook(self, p): 316 | def hook(*ignore): 317 | if p in self._handles and self._handles[p][0] is not None: 318 | if self._allreduce_delay[p] <= 0: 319 | raise AssertionError( 320 | "Gradients were computed more than " 321 | "backward_passes_per_step times before call " 322 | "to step(). 
Increase backward_passes_per_step to " 323 | "accumulate gradients locally.") 324 | assert not p.grad.requires_grad 325 | assert self._allreduce_delay[p] > 0 326 | handle, ctx = None, None 327 | self._allreduce_delay[p] -= 1 328 | if self._allreduce_delay[p] == 0: 329 | handle, ctx = self._allreduce_grad_async(p) 330 | self._handles[p] = (handle, ctx) 331 | return hook 332 | 333 | def synchronize(self): 334 | pass 335 | 336 | @contextmanager 337 | def skip_synchronize(self): 338 | raise AssertionError("Skipping synchronization is not supported when using Adasum optimizer.") 339 | 340 | def step(self, closure=None): 341 | loss = None 342 | if closure is not None: 343 | loss = closure() 344 | 345 | missing_p = self._requires_update - set(self._handles.keys()) 346 | for p in missing_p: 347 | handle, ctx = self._allreduce_grad_async(p) 348 | self._handles[p] = (handle, ctx) 349 | 350 | for p, (handle, ctx) in self._handles.items(): 351 | # This means step() was called before backward_passes_per_step backward passes finished. 352 | # We do a synchronous allreduce here. 353 | if not handle: 354 | handle, ctx = self._allreduce_grad_async(p) 355 | self._handles[p] = (handle, ctx) 356 | delta = self._synchronize_(handle) 357 | delta = self._compression.decompress(delta, ctx) 358 | start = self._starting_models[p] 359 | start.data.add_(delta.data) 360 | p.data.copy_(start) 361 | self._allreduce_delay[p] = self.backward_passes_per_step 362 | self._handles.clear() 363 | return loss 364 | 365 | def zero_grad(self): 366 | if self._handles: 367 | raise AssertionError("optimizer.zero_grad() was called after loss.backward() " 368 | "but before optimizer.step() or optimizer.synchronize(). " 369 | "This is prohibited as it can cause a race condition.") 370 | return super(self.__class__, self).zero_grad() 371 | 372 | 373 | def DistributedOptimizer(optimizer, named_parameters=None, 374 | compression=Compression.none, 375 | backward_passes_per_step=1, 376 | op=Average): 377 | """ 378 | An optimizer that wraps another torch.optim.Optimizer, using an allreduce to 379 | combine gradient values before applying gradients to model weights. 380 | 381 | Allreduce operations are executed after each gradient is computed by ``loss.backward()`` 382 | in parallel with each other. The ``step()`` method ensures that all allreduce operations are 383 | finished before applying gradients to the model. 384 | 385 | DistributedOptimizer exposes the ``synchronize()`` method, which forces allreduce operations 386 | to finish before continuing the execution. It's useful in conjunction with gradient 387 | clipping, or other operations that modify gradients in place before ``step()`` is executed. 388 | Make sure to use ``optimizer.skip_synchronize()`` if you're calling ``synchronize()`` 389 | in your code. 390 | 391 | Example of gradient clipping: 392 | 393 | .. code-block:: python 394 | 395 | output = model(data) 396 | loss = F.nll_loss(output, target) 397 | loss.backward() 398 | optimizer.synchronize() 399 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 400 | with optimizer.skip_synchronize(): 401 | optimizer.step() 402 | 403 | Arguments: 404 | optimizer: Optimizer to use for computing gradients and applying updates. 405 | named_parameters: A mapping between parameter names and values. Used for naming of 406 | allreduce operations. Typically just ``model.named_parameters()``. 407 | compression: Compression algorithm used during allreduce to reduce the amount 408 | of data sent during each parameter update step. 
Defaults to 409 | not using compression. 410 | backward_passes_per_step: Number of expected backward passes to perform 411 | before calling step()/synchronize(). This 412 | allows accumulating gradients over multiple 413 | mini-batches before reducing and applying them. 414 | op: The reduction operation to use when combining gradients across different ranks. 415 | """ 416 | # We dynamically create a new class that inherits from the optimizer that was passed in. 417 | # The goal is to override the `step()` method with an allreduce implementation. 418 | 419 | if op != Adasum or size() == 1: 420 | cls = type(optimizer.__class__.__name__, (optimizer.__class__,), 421 | dict(_DistributedOptimizer.__dict__)) 422 | return cls(optimizer.param_groups, named_parameters, compression, backward_passes_per_step, op) 423 | else: 424 | cls = type(optimizer.__class__.__name__, (optimizer.__class__,), 425 | dict(_DistributedAdasumOptimizer.__dict__)) 426 | return cls(optimizer.param_groups, named_parameters, compression, backward_passes_per_step) 427 | -------------------------------------------------------------------------------- /dgc/memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import horovod.torch as hvd 4 | 5 | __all__ = ['Memory', 'DGCSGDMemory'] 6 | 7 | 8 | # code modified from https://github.com/sands-lab/grace/blob/master/grace_dl/torch/memory/dgc.py 9 | class Memory: 10 | @staticmethod 11 | def initialize(*args, **kwargs): 12 | pass 13 | 14 | @staticmethod 15 | def compensate(tensor, *args, **kwargs): 16 | return tensor 17 | 18 | @staticmethod 19 | def update(*args, **kwargs): 20 | pass 21 | 22 | @staticmethod 23 | def state_dict(): 24 | return None 25 | 26 | @staticmethod 27 | def load_state_dict(state_dict): 28 | pass 29 | 30 | 31 | class DGCSGDMemory(Memory): 32 | """ Memory for momentum correction in DGC for momentum SGD optimizer""" 33 | def __init__(self, momentum=0.9, nesterov=False, 34 | gradient_clipping=None, momentum_masking=True): 35 | self.gradient_clipping = gradient_clipping 36 | self.momentum_masking = momentum_masking 37 | 38 | self.momentum = momentum 39 | self.nesterov = nesterov 40 | self.momentums = {} 41 | self.velocities = {} 42 | 43 | def initialize(self, named_parameters): 44 | if hvd.rank() == 0: 45 | print("=> initializing dgc sgd memory") 46 | for name, param in named_parameters: 47 | self.momentums[name] = torch.zeros_like(param.data) 48 | self.velocities[name] = torch.zeros_like(param.data) 49 | 50 | def compensate(self, grad, name, accumulate=True): 51 | """Update the velocities with the momentums.""" 52 | if self.gradient_clipping is not None: 53 | grad = self.gradient_clipping(grad) 54 | mmt = self.momentums[name] 55 | if accumulate: 56 | vec = self.velocities[name] 57 | if self.nesterov: 58 | mmt.add_(grad).mul_(self.momentum) 59 | vec.add_(mmt).add_(grad) 60 | else: 61 | mmt.mul_(self.momentum).add_(grad) 62 | vec.add_(mmt) 63 | return vec 64 | else: 65 | if self.nesterov: 66 | mmt.add_(grad).mul_(self.momentum) 67 | return mmt.add(grad) 68 | else: 69 | mmt.mul_(self.momentum).add_(grad) 70 | return mmt.clone() # TODO: save this clone 71 | 72 | def update(self, name, ctx): 73 | """Update the momentums.""" 74 | indices = ctx[0] 75 | if self.momentum_masking: 76 | self.momentums[name].view(-1).index_fill_(0, indices, 0) 77 | self.velocities[name].view(-1).index_fill_(0, indices, 0) 78 | 79 | def state_dict(self): 80 | return dict(momentums=self.momentums, velocities=self.velocities) 81 | 82 | 
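# Illustrative flow (a sketch, not code from this repo) of how this memory is driven
# around a compressor; `memory`, `compressor`, `p` and `name` are placeholder names, and
# the compress(tensor, name) -> (tensor_compressed, ctx) signature is the one used by
# dgc/horovod/optimizer.py above:
#   vec = memory.compensate(p.grad, name)                     # momentum-corrected gradient
#   tensor_compressed, ctx = compressor.compress(vec, name)   # sparsify / encode
#   memory.update(name, ctx)                                  # zero momentum/velocity at the sent indices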
def load_state_dict(self, state_dict): 83 | momentums = state_dict['momentums'] 84 | velocities = state_dict['velocities'] 85 | for name in self.momentums.keys(): 86 | if name in momentums: 87 | self.momentums[name] = momentums[name] 88 | self.velocities[name] = velocities[name] 89 | -------------------------------------------------------------------------------- /dgc/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from dgc.optim.sgd import DGCSGD 2 | -------------------------------------------------------------------------------- /dgc/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer, required 5 | 6 | __all__ = ['DGCSGD'] 7 | 8 | 9 | class DGCSGD(Optimizer): 10 | def __init__(self, params, lr=required, momentum=0, dampening=0, 11 | weight_decay=0, nesterov=False): 12 | if lr is not required and lr < 0.0: 13 | raise ValueError("Invalid learning rate: {}".format(lr)) 14 | if momentum < 0.0: 15 | raise ValueError("Invalid momentum value: {}".format(momentum)) 16 | if weight_decay < 0.0: 17 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 18 | 19 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 20 | weight_decay=weight_decay, nesterov=nesterov) 21 | if nesterov and (momentum <= 0 or dampening != 0): 22 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 23 | super(DGCSGD, self).__init__(params, defaults) 24 | 25 | def __setstate__(self, state): 26 | super(DGCSGD, self).__setstate__(state) 27 | for group in self.param_groups: 28 | group.setdefault('nesterov', False) 29 | 30 | @torch.no_grad() 31 | def step(self, closure=None): 32 | """Performs a single optimization step. 33 | Arguments: 34 | closure (callable, optional): A closure that reevaluates the model 35 | and returns the loss. 
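        Example (a minimal sketch; ``model``, ``criterion`` and the hyper-parameter
        values below are placeholders, and in this repo the optimizer is normally
        wrapped by ``dgc.horovod.optimizer.DistributedOptimizer``):

        .. code-block:: python

            optimizer = DGCSGD(model.parameters(), lr=0.1, momentum=0.9,
                               weight_decay=1e-4, nesterov=True)
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()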
36 | """ 37 | loss = None 38 | if closure is not None: 39 | with torch.enable_grad(): 40 | loss = closure() 41 | 42 | for group in self.param_groups: 43 | weight_decay = group['weight_decay'] 44 | momentum = group['momentum'] 45 | dampening = group['dampening'] 46 | nesterov = group['nesterov'] 47 | 48 | for p in group['params']: 49 | if p.grad is None: 50 | continue 51 | if weight_decay != 0: 52 | d_p = weight_decay * p.data 53 | if momentum != 0: 54 | param_state = self.state[p] 55 | if 'momentum_buffer' not in param_state: 56 | buf = param_state['momentum_buffer'] = d_p 57 | else: 58 | buf = param_state['momentum_buffer'] 59 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 60 | if nesterov: 61 | d_p = d_p.add(buf, alpha=momentum) 62 | else: 63 | d_p = buf 64 | d_p = d_p.add(p.grad) 65 | else: 66 | d_p = p.grad 67 | 68 | p.add_(d_p, alpha=-group['lr']) 69 | 70 | return loss 71 | 72 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | horovod 3 | tensorboardX 4 | numpy 5 | tqdm -------------------------------------------------------------------------------- /script/cifar.resnet110.sh: -------------------------------------------------------------------------------- 1 | # fp32 values, int64 indices, warmup coeff: [0.25, 0.063, 0.015, 0.004, 0.001] -> 0.001, no momentum masking 2 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet110.py configs/dgc/wm5.py 3 | 4 | # fp32 values, int64 indices, wamup coeff: [1, 1, 1, 1, 1] -> 0.001, no momentum masking 5 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet110.py configs/dgc/wm5o.py 6 | 7 | # fp16 values, int32 indices, warmup coeff: [0.25, 0.063, 0.015, 0.004, 0.001] -> 0.001, no momentum masking 8 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet110.py configs/dgc/wm5.py configs/dgc/fp16.py configs/dgc/int32.py 9 | 10 | # fp16 values, int32 indices, wamup coeff: [1, 1, 1, 1, 1] -> 0.001, no momentum masking 11 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet110.py configs/dgc/wm5o.py configs/dgc/fp16.py configs/dgc/int32.py 12 | -------------------------------------------------------------------------------- /script/cifar.resnet20.sh: -------------------------------------------------------------------------------- 1 | # fp32 values, int64 indices, warmup coeff: [0.25, 0.063, 0.015, 0.004, 0.001] -> 0.001 2 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet20.py configs/dgc/wm5.py 3 | 4 | # fp32 values, int64 indices, wamup coeff: [1, 1, 1, 1, 1] -> 0.001 5 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude 
docker0,lo python train.py --configs configs/cifar/resnet20.py configs/dgc/wm5m.py 6 | 7 | # fp16 values, int32 indices, warmup coeff: [0.25, 0.063, 0.015, 0.004, 0.001] -> 0.001 8 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet20.py configs/dgc/wm5.py configs/dgc/fp16.py configs/dgc/int32.py 9 | 10 | # fp16 values, int32 indices, warmup coeff: [1, 1, 1, 1, 1] -> 0.001 11 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet20.py configs/dgc/wm5m.py configs/dgc/fp16.py configs/dgc/int32.py 12 | -------------------------------------------------------------------------------- /script/imagenet.resnet50.sh: -------------------------------------------------------------------------------- 1 | # fp32 values, int64 indices, no warmup 2 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/resnet50.py configs/dgc/wm0.py 3 | 4 | # fp32 values, int64 indices, cosine, no warmup 5 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/resnet50.py configs/imagenet/cosine.py configs/dgc/wm0.py 6 | 7 | # fp16 values, int32 indices, no warmup 8 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/resnet50.py configs/dgc/wm0.py configs/dgc/fp16.py configs/dgc/int32.py
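# Illustrative invocation of these scripts (argument values are placeholders, not taken from this repo):
#   bash script/imagenet.resnet50.sh 8 "node1:4,node2:4"
# where ${1} is the total number of MPI processes passed to -np and ${2} is the host list passed to -H.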
9 | 10 | # fp16 values, int32 indices, cosine, no warmup 11 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/resnet50.py configs/imagenet/cosine.py configs/dgc/wm0.py configs/dgc/fp16.py configs/dgc/int32.py -------------------------------------------------------------------------------- /script/imagenet.vgg16.sh: -------------------------------------------------------------------------------- 1 | # fp32 values, int64 indices, no warmup 2 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/vgg16_bn.py configs/dgc/wm0.py 3 | 4 | # fp16 values, int32 indices, no warmup 5 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/vgg16_bn.py configs/dgc/wm0.py configs/dgc/fp16.py configs/dgc/int32.py 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import random 5 | import shutil 6 | 7 | import numpy as np 8 | import horovod.torch as hvd 9 | import torch 10 | import torch.nn as nn 11 | import torch.backends.cudnn as cudnn 12 | import torch.multiprocessing as mp 13 | from tqdm import tqdm 14 | 15 | from torchpack.mtpack.utils.config import Config, configs 16 | 17 | from dgc.horovod.optimizer import DistributedOptimizer 18 | from dgc.compression import DGCCompressor 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--configs', nargs='+') 24 | parser.add_argument('--devices', default='gpu') 25 | parser.add_argument('--evaluate', action='store_true') 26 | parser.add_argument('--suffix', default='') 27 | args, opts = parser.parse_known_args() 28 | 29 | ################## 30 | # Update configs # 31 | ################## 32 | 33 | printr(f'==> loading configs from {args.configs}') 34 | Config.update_from_modules(*args.configs) 35 | Config.update_from_arguments(*opts) 36 | 37 | if args.devices is not None and args.devices != 'cpu': 38 | configs.device = 'cuda' 39 | # Horovod: pin GPU to local rank. 
40 | torch.cuda.set_device(hvd.local_rank()) 41 | cudnn.benchmark = True 42 | else: 43 | configs.device = 'cpu' 44 | 45 | if 'seed' in configs and configs.seed is not None: 46 | random.seed(configs.seed) 47 | np.random.seed(configs.seed) 48 | torch.manual_seed(configs.seed) 49 | if configs.device == 'cuda' and configs.get('deterministic', True): 50 | cudnn.deterministic = True 51 | cudnn.benchmark = False 52 | 53 | configs.train.num_batches_per_step = \ 54 | configs.train.get('num_batches_per_step', 1) 55 | 56 | configs.train.save_path = get_save_path(*args.configs) \ 57 | + f'{args.suffix}.np{hvd.size()}' 58 | printr(f'[train.save_path] = {configs.train.save_path}') 59 | checkpoint_path = os.path.join(configs.train.save_path, 'checkpoints') 60 | configs.train.checkpoint_path = os.path.join( 61 | checkpoint_path, f'e{"{epoch}"}-r{hvd.rank()}.pth' 62 | ) 63 | configs.train.latest_pth_path = os.path.join( 64 | checkpoint_path, f'latest-r{hvd.rank()}.pth' 65 | ) 66 | configs.train.best_pth_path = os.path.join( 67 | checkpoint_path, f'best-r{hvd.rank()}.pth' 68 | ) 69 | os.makedirs(checkpoint_path, exist_ok=True) 70 | 71 | if args.evaluate: 72 | configs.train.latest_pth_path = configs.train.best_pth_path 73 | 74 | printr(configs) 75 | 76 | ##################################################################### 77 | # Initialize DataLoaders, Model, Criterion, LRScheduler & Optimizer # 78 | ##################################################################### 79 | 80 | printr(f'\n==> creating dataset "{configs.dataset}"') 81 | dataset = configs.dataset() 82 | # Horovod: limit # of CPU threads to be used per worker. 83 | torch.set_num_threads(configs.data.num_threads_per_worker) 84 | loader_kwargs = {'num_workers': configs.data.num_threads_per_worker, 85 | 'pin_memory': True} if configs.device == 'cuda' else {} 86 | # When supported, use 'forkserver' to spawn dataloader workers 87 | # instead of 'fork' to prevent issues with Infiniband implementations 88 | # that are not fork-safe 89 | if (loader_kwargs.get('num_workers', 0) > 0 and 90 | hasattr(mp, '_supports_context') and 91 | mp._supports_context and 92 | 'forkserver' in mp.get_all_start_methods()): 93 | loader_kwargs['multiprocessing_context'] = 'forkserver' 94 | printr(f'\n==> loading dataset "{loader_kwargs}""') 95 | samplers, loaders = {}, {} 96 | for split in dataset: 97 | # Horovod: use DistributedSampler to partition data among workers. 98 | # Manually specify `num_replicas=hvd.size()` and `rank=hvd.rank()`. 99 | samplers[split] = torch.utils.data.distributed.DistributedSampler( 100 | dataset[split], num_replicas=hvd.size(), rank=hvd.rank()) 101 | loaders[split] = torch.utils.data.DataLoader( 102 | dataset[split], batch_size=configs.train.batch_size * ( 103 | configs.train.num_batches_per_step if split == 'train' else 1), 104 | sampler=samplers[split], 105 | drop_last=(configs.train.num_batches_per_step > 1 106 | and split == 'train'), 107 | **loader_kwargs 108 | ) 109 | 110 | printr(f'\n==> creating model "{configs.model}"') 111 | model = configs.model() 112 | model = model.cuda() 113 | 114 | criterion = configs.train.criterion().to(configs.device) 115 | # Horovod: scale learning rate by the number of GPUs. 
116 | configs.train.base_lr = configs.train.optimizer.lr 117 | configs.train.optimizer.lr *= (configs.train.num_batches_per_step 118 | * hvd.size()) 119 | printr(f'\n==> creating optimizer "{configs.train.optimizer}"') 120 | 121 | if configs.train.optimize_bn_separately: 122 | optimizer = configs.train.optimizer([ 123 | dict(params=get_common_parameters(model)), 124 | dict(params=get_bn_parameters(model), weight_decay=0) 125 | ]) 126 | else: 127 | optimizer = configs.train.optimizer(model.parameters()) 128 | 129 | # Horovod: (optional) compression algorithm. 130 | printr(f'\n==> creating compression "{configs.train.compression}"') 131 | if configs.train.dgc: 132 | printr(f'\n==> initializing dgc compression') 133 | configs.train.compression.memory = configs.train.compression.memory() 134 | compression = configs.train.compression() 135 | compression.memory.initialize(model.named_parameters()) 136 | cpr_parameters = {} 137 | for name, param in model.named_parameters(): 138 | if param.dim() > 1: 139 | cpr_parameters[name] = param 140 | compression.initialize(cpr_parameters.items()) 141 | else: 142 | compression = configs.train.compression() 143 | 144 | # Horovod: wrap optimizer with DistributedOptimizer. 145 | optimizer = DistributedOptimizer( 146 | optimizer, named_parameters=model.named_parameters(), 147 | compression=compression, 148 | backward_passes_per_step=configs.train.num_batches_per_step, 149 | op=hvd.Average 150 | ) 151 | 152 | # resume from checkpoint 153 | last_epoch, best_metric = -1, None 154 | if os.path.exists(configs.train.latest_pth_path): 155 | printr(f'\n[resume_path] = {configs.train.latest_pth_path}') 156 | checkpoint = torch.load(configs.train.latest_pth_path) 157 | if 'model' in checkpoint: 158 | model.load_state_dict(checkpoint.pop('model')) 159 | if 'optimizer' in checkpoint: 160 | optimizer.load_state_dict(checkpoint.pop('optimizer')) 161 | if configs.train.dgc and 'compression' in checkpoint: 162 | compression.memory.load_state_dict(checkpoint.pop('compression')) 163 | last_epoch = checkpoint.get('epoch', last_epoch) 164 | best_metric = checkpoint.get('meters', {}).get( 165 | f'{configs.train.metric}_best', best_metric) 166 | # Horovod: broadcast parameters. 167 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 168 | else: 169 | printr('\n==> train from scratch') 170 | # Horovod: broadcast parameters & optimizer state. 
171 | printr('\n==> broadcasting parameters and optimizer state') 172 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 173 | hvd.broadcast_optimizer_state(optimizer, root_rank=0) 174 | 175 | num_steps_per_epoch = len(loaders['train']) 176 | if 'scheduler' in configs.train and configs.train.scheduler is not None: 177 | if configs.train.schedule_lr_per_epoch: 178 | last = max(last_epoch - configs.train.warmup_lr_epochs - 1, -1) 179 | else: 180 | last = max((last_epoch - configs.train.warmup_lr_epochs + 1) 181 | * num_steps_per_epoch - 2, -1) 182 | scheduler = configs.train.scheduler(optimizer, last_epoch=last) 183 | else: 184 | scheduler = None 185 | 186 | ############ 187 | # Training # 188 | ############ 189 | 190 | meters = evaluate(model, device=configs.device, meters=configs.train.meters, 191 | loader=loaders['test'], split='test') 192 | for k, meter in meters.items(): 193 | printr(f'[{k}] = {meter:2f}') 194 | if args.evaluate or last_epoch >= configs.train.num_epochs: 195 | return 196 | 197 | if hvd.rank() == 0: 198 | from tensorboardX import SummaryWriter 199 | writer = SummaryWriter(configs.train.save_path) 200 | else: 201 | writer = None 202 | 203 | for current_epoch in range(last_epoch + 1, configs.train.num_epochs): 204 | printr(f'\n==> training epoch {current_epoch}' 205 | f'/{configs.train.num_epochs}') 206 | 207 | if configs.train.dgc: 208 | compression.warmup_compress_ratio(current_epoch) 209 | 210 | train(model=model, loader=loaders['train'], 211 | device=configs.device, epoch=current_epoch, 212 | sampler=samplers['train'], criterion=criterion, 213 | optimizer=optimizer, scheduler=scheduler, 214 | batch_size=configs.train.batch_size, 215 | num_batches_per_step=configs.train.num_batches_per_step, 216 | num_steps_per_epoch=num_steps_per_epoch, 217 | warmup_lr_epochs=configs.train.warmup_lr_epochs, 218 | schedule_lr_per_epoch=configs.train.schedule_lr_per_epoch, 219 | writer=writer, quiet=hvd.rank() != 0) 220 | 221 | meters = dict() 222 | for split, loader in loaders.items(): 223 | if split != 'train': 224 | meters.update(evaluate(model, loader=loader, 225 | device=configs.device, 226 | meters=configs.train.meters, 227 | split=split, quiet=hvd.rank() != 0)) 228 | 229 | best = False 230 | if 'metric' in configs.train and configs.train.metric is not None: 231 | if best_metric is None or best_metric < meters[configs.train.metric]: 232 | best_metric, best = meters[configs.train.metric], True 233 | meters[configs.train.metric + '_best'] = best_metric 234 | 235 | if writer is not None: 236 | num_inputs = ((current_epoch + 1) * num_steps_per_epoch 237 | * configs.train.num_batches_per_step 238 | * configs.train.batch_size * hvd.size()) 239 | print('') 240 | for k, meter in meters.items(): 241 | print(f'[{k}] = {meter:2f}') 242 | writer.add_scalar(k, meter, num_inputs) 243 | 244 | checkpoint = { 245 | 'epoch': current_epoch, 246 | 'model': model.state_dict(), 247 | 'optimizer': optimizer.state_dict(), 248 | 'meters': meters, 249 | 'compression': compression.memory.state_dict() \ 250 | if configs.train.dgc else None 251 | } 252 | 253 | # save checkpoint 254 | checkpoint_path = \ 255 | configs.train.checkpoint_path.format(epoch=current_epoch) 256 | torch.save(checkpoint, checkpoint_path) 257 | shutil.copyfile(checkpoint_path, configs.train.latest_pth_path) 258 | if best: 259 | shutil.copyfile(checkpoint_path, configs.train.best_pth_path) 260 | if current_epoch >= 3: 261 | os.remove( 262 | configs.train.checkpoint_path.format(epoch=current_epoch - 3) 263 | ) 264 | 
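# note: together with the copies above, only the three most recent per-epoch checkpoint
# files are kept on disk; the 'latest' and 'best' copies are maintained separately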
printr(f'[save_path] = {checkpoint_path}') 265 | 266 | 267 | def train(model, loader, device, epoch, sampler, criterion, optimizer, 268 | scheduler, batch_size, num_batches_per_step, num_steps_per_epoch, warmup_lr_epochs, schedule_lr_per_epoch, writer=None, quiet=True): 269 | step_size = num_batches_per_step * batch_size 270 | num_inputs = epoch * num_steps_per_epoch * step_size * hvd.size() 271 | _r_num_batches_per_step = 1.0 / num_batches_per_step 272 | 273 | sampler.set_epoch(epoch) 274 | model.train() 275 | for step, (inputs, targets) in enumerate(tqdm( 276 | loader, desc='train', ncols=0, disable=quiet)): 277 | adjust_learning_rate(scheduler, epoch=epoch, step=step, 278 | num_steps_per_epoch=num_steps_per_epoch, 279 | warmup_lr_epochs=warmup_lr_epochs, 280 | schedule_lr_per_epoch=schedule_lr_per_epoch) 281 | 282 | inputs = inputs.to(device, non_blocking=True) 283 | targets = targets.to(device, non_blocking=True) 284 | optimizer.zero_grad() 285 | 286 | loss = torch.tensor([0.0]) 287 | for b in range(0, step_size, batch_size): 288 | _inputs = inputs[b:b+batch_size] 289 | _targets = targets[b:b+batch_size] 290 | _outputs = model(_inputs) 291 | _loss = criterion(_outputs, _targets) 292 | _loss.mul_(_r_num_batches_per_step) 293 | _loss.backward() 294 | loss += _loss.item() 295 | optimizer.step() 296 | 297 | # write train loss log 298 | loss = hvd.allreduce(loss, name='loss').item() 299 | if writer is not None: 300 | num_inputs += step_size * hvd.size() 301 | writer.add_scalar('loss/train', loss, num_inputs) 302 | 303 | 304 | def evaluate(model, loader, device, meters, split='test', quiet=True): 305 | _meters = {} 306 | for k, meter in meters.items(): 307 | _meters[k.format(split)] = meter() 308 | meters = _meters 309 | 310 | model.eval() 311 | 312 | with torch.no_grad(): 313 | for inputs, targets in tqdm(loader, desc=split, ncols=0, disable=quiet): 314 | inputs = inputs.to(device, non_blocking=True) 315 | targets = targets.to(device, non_blocking=True) 316 | 317 | outputs = model(inputs) 318 | for meter in meters.values(): 319 | meter.update(outputs, targets) 320 | 321 | for k, meter in meters.items(): 322 | data = meter.data() 323 | for dk, d in data.items(): 324 | data[dk] = \ 325 | hvd.allreduce(torch.tensor([d]), name=dk, op=hvd.Sum).item() 326 | meter.set(data) 327 | meters[k] = meter.compute() 328 | return meters 329 | 330 | 331 | # Horovod: using `lr = base_lr * hvd.size()` from the very beginning 332 | # leads to worse final accuracy. 333 | # Scale the learning rate `lr = base_lr` ---> `lr = base_lr * hvd.size()` 334 | # during the first five epochs. See https://arxiv.org/abs/1706.02677. 
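# Worked example of the warmup factor computed below (illustrative numbers, not taken from the configs):
# with size = hvd.size() = 8 and warmup_lr_epochs = 5, and epoch including the fractional
# step offset (epoch += step / num_steps_per_epoch),
#   factor = (epoch * (size - 1) / warmup_lr_epochs + 1) / size
#   epoch = 0.0 -> factor = 1/8   = 0.125    (the `lr = base_lr` end of the ramp described above)
#   epoch = 2.5 -> factor = 4.5/8 = 0.5625
#   epoch = 5.0 -> factor = 8/8   = 1.0      (the `lr = base_lr * hvd.size()` end)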
335 | def adjust_learning_rate(scheduler, epoch, step, num_steps_per_epoch, 336 | warmup_lr_epochs=0, schedule_lr_per_epoch=False): 337 | if epoch < warmup_lr_epochs: 338 | size = hvd.size() 339 | epoch += step / num_steps_per_epoch 340 | factor = (epoch * (size - 1) / warmup_lr_epochs + 1) / size 341 | for param_group, base_lr in zip(scheduler.optimizer.param_groups, 342 | scheduler.base_lrs): 343 | param_group['lr'] = base_lr * factor 344 | elif schedule_lr_per_epoch and (step > 0 or epoch == 0): 345 | return 346 | elif epoch == warmup_lr_epochs and step == 0: 347 | for param_group, base_lr in zip(scheduler.optimizer.param_groups, 348 | scheduler.base_lrs): 349 | param_group['lr'] = base_lr 350 | return 351 | else: 352 | scheduler.step() 353 | 354 | def get_bn_parameters(module): 355 | def get_members_fn(m): 356 | if isinstance(m, nn.BatchNorm2d): 357 | return m._parameters.items() 358 | else: 359 | return dict() 360 | gen = module._named_members(get_members_fn=get_members_fn) 361 | for _, elem in gen: 362 | yield elem 363 | 364 | 365 | def get_common_parameters(module): 366 | def get_members_fn(m): 367 | if isinstance(m, nn.BatchNorm2d): 368 | return dict() 369 | else: 370 | for n, p in m._parameters.items(): 371 | yield n, p 372 | 373 | gen = module._named_members(get_members_fn=get_members_fn) 374 | for _, elem in gen: 375 | yield elem 376 | 377 | 378 | def get_save_path(*configs, prefix='runs'): 379 | memo = dict() 380 | for c in configs: 381 | cmemo = memo 382 | c = c.replace('configs/', '').replace('.py', '').split('/') 383 | for m in c: 384 | if m not in cmemo: 385 | cmemo[m] = dict() 386 | cmemo = cmemo[m] 387 | 388 | def get_str(m, p): 389 | n = len(m) 390 | if n > 1: 391 | p += '[' 392 | for i, (k, v) in enumerate(m.items()): 393 | p += k 394 | if len(v) > 0: 395 | p += '.' 396 | p = get_str(v, p) 397 | if n > 1 and i < n - 1: 398 | p += '+' 399 | if n > 1: 400 | p += ']' 401 | return p 402 | 403 | return os.path.join(prefix, get_str(memo, '')) 404 | 405 | 406 | def printr(*args, **kwargs): 407 | if hvd.rank() == 0: 408 | print(*args, **kwargs) 409 | 410 | 411 | if __name__ == '__main__': 412 | hvd.init() 413 | main() 414 | --------------------------------------------------------------------------------