├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── configs
│   ├── __init__.py
│   ├── cifar
│   │   ├── __init__.py
│   │   ├── resnet110.py
│   │   └── resnet20.py
│   ├── dgc
│   │   ├── __init__.py
│   │   ├── fp16.py
│   │   ├── int32.py
│   │   ├── mm.py
│   │   ├── nm.py
│   │   ├── wm0.py
│   │   ├── wm5.py
│   │   └── wm5o.py
│   └── imagenet
│       ├── __init__.py
│       ├── cosine.py
│       ├── resnet18.py
│       ├── resnet50.py
│       └── vgg16_bn.py
├── data
│   ├── .gitignore
│   └── docs
│       ├── cifar-10.png
│       ├── resnet.png
│       ├── speedup.png
│       └── teaser.png
├── dgc
│   ├── __init__.py
│   ├── clip_grad.py
│   ├── compression.py
│   ├── horovod
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── compression.py
│   │   ├── horovod.june.6.6b77884.patch
│   │   └── optimizer.py
│   ├── memory.py
│   └── optim
│       ├── __init__.py
│       └── sgd.py
├── requirements.txt
├── script
│   ├── cifar.resnet110.sh
│   ├── cifar.resnet20.sh
│   ├── imagenet.resnet50.sh
│   └── imagenet.vgg16.sh
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .vscode/
132 | logs/
133 | runs/
134 | *.swp
135 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "torchpack"]
2 | path = torchpack
3 | url = https://github.com/synxlin/mini-torchpack.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2024] Yujun Lin
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Gradient Compression [[arXiv]](https://arxiv.org/pdf/1712.01887.pdf)
2 |
3 | ```
4 | @inproceedings{lin2018dgc,
5 | title={{Deep Gradient Compression: Reducing the communication bandwidth for distributed training}},
6 | author={Lin, Yujun and Han, Song and Mao, Huizi and Wang, Yu and Dally, William J},
7 | booktitle={The International Conference on Learning Representations},
8 | year={2018}
9 | }
10 | ```
11 |
12 | ## Overview
13 |
14 | We release the PyTorch code of [Deep Gradient Compression](https://arxiv.org/pdf/1712.01887.pdf).
15 |
16 |
17 | 
18 | Figure 1. Deep Gradient Compression (DGC) can reduce the communication bandwidth (transmit fewer gradients by pruning away small gradients), improve the scalability, and speed up distributed training.
19 |
20 |
21 | 
22 | Figure 2. DGC maintains accuracy: learning curves of ResNet (the gradient sparsity is 99.9%).
23 |
24 | 
25 | Figure 3. DGC improves the scalability: speedup measured on NVIDIA TITAN RTX 2080Ti GPU cluster with 25 Gbps Ethernet.
26 |
27 |
28 | ## Content
29 | - [Prerequisites](#prerequisites)
30 | - [Code](#code)
31 | - [Training](#training)
32 | - [Known Issues and TODOs](#known-issues-and-todos)
33 |
34 | ## Prerequisites
35 |
36 | The code is built with the following libraries (see [requirements.txt](requirements.txt)):
37 | - Python >= 3.7
38 | - [PyTorch](https://github.com/pytorch/pytorch) >= 1.5
39 | - [Horovod](https://github.com/horovod/horovod) >= 0.19.4
40 | - [numpy](https://github.com/numpy/numpy)
41 | - [tensorboardX](https://github.com/lanpa/tensorboardX) >= 1.2
42 | - [tqdm](https://github.com/tqdm/tqdm)
43 | - [openmpi](https://www.open-mpi.org/software/ompi/) >= 4.0
44 |
45 | ## Code
46 |
47 | The core code to implement DGC is in [dgc/compression.py](dgc/compression.py) and [dgc/memory.py](dgc/memory.py).
48 |
49 | - Gradient Accumulation and Momentum Correction
50 | ```python
51 | mmt = self.momentums[name]
52 | vec = self.velocities[name]
53 | if self.nesterov:
54 | mmt.add_(grad).mul_(self.momentum)
55 | vec.add_(mmt).add_(grad)
56 | else:
57 | mmt.mul_(self.momentum).add_(grad)
58 | vec.add_(mmt)
59 | return vec
60 | ```
61 |
62 | - Sparsification
63 | ```python
64 | importance = tensor.abs()
65 | # sampling
66 | sample_start = random.randint(0, sample_stride - 1)
67 | samples = importance[sample_start::sample_stride]
68 | # thresholding
69 | threshold = torch.min(torch.topk(samples, top_k_samples, 0, largest=True, sorted=False)[0])
70 | mask = torch.ge(importance, threshold)
71 | indices = mask.nonzero().view(-1)
72 | ```
73 |
74 | ## Training
75 | We use [Horovod](https://github.com/horovod/horovod) to run distributed training:
76 | - run on a machine with *N* GPUs,
77 | ```bash
78 | horovodrun -np N python train.py --configs [config files]
79 | ```
80 | e.g., ResNet-20 on the CIFAR-10 dataset with 8 GPUs:
81 | ```bash
82 | # fp16 values, int32 indices
83 | # warmup coeff: [0.25, 0.063, 0.015, 0.004, 0.001] -> 0.001
84 | horovodrun -np 8 python train.py --configs configs/cifar/resnet20.py \
85 | configs/dgc/wm5.py configs/dgc/fp16.py configs/dgc/int32.py
86 | ```
87 | - run on *K* machines with *N* GPUs each,
88 | ```bash
89 | mpirun -np [K*N] -H server0:N,server1:N,...,serverK:N \
90 | -bind-to none -map-by slot -x NCCL_DEBUG=INFO \
91 | -x LD_LIBRARY_PATH -x PATH -mca pml ob1 \
92 | -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo \
93 | python train.py --configs [config files]
94 | ```
95 | e.g., ResNet-50 on the ImageNet dataset with 4 machines, 8 GPUs each:
96 | ```bash
97 | # fp32 values, int64 indices, no warmup
98 | mpirun -np 32 -H server0:8,server1:8,server2:8,server3:8 \
99 | -bind-to none -map-by slot -x NCCL_DEBUG=INFO \
100 | -x LD_LIBRARY_PATH -x PATH -mca pml ob1 \
101 | -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo \
102 | python train.py --configs configs/imagenet/resnet50.py \
103 | configs/dgc/wm0.py
104 | ```
105 | For more information on `horovodrun`, please refer to the Horovod documentation.
106 |
107 | You can modify or add config files under [configs](configs) to change the training settings. You can also override individual config values from the command line:
108 | ```bash
109 | python train.py --configs [config files] --[config name] [config value] --suffix [suffix of experiment directory]
110 | ```
111 | e.g.,
112 | ```bash
113 | horovodrun -np 8 python train.py --configs configs/cifar/resnet20.py \
114 | configs/dgc/wm5.py --configs.train.num_epochs 500 --suffix .e500
115 | ```
116 |
117 | Here are some reproduced results using a **0.1%** compression ratio (*i.e.*, `configs.train.compression.compress_ratio = 0.001`):
118 | | #GPUs | Batch Size | #Sparsified Nodes | ResNet-50 Top-1 (%) | VGG-16-BN Top-1 (%) | LR Scheduler |
119 | |:-----:|:----------:|:-----------------:|:---------:|:---------:|:------------:|
120 | | - | - | - | [76.2](https://pytorch.org/docs/stable/torchvision/models.html) | [73.4](https://pytorch.org/docs/stable/torchvision/models.html) | - |
121 | | 8 | 256 | 8 | 76.6 | 74.1 | MultiStep |
122 | | 16 | 512 | 16 | 76.5 | 73.8 | MultiStep |
123 | | 32 | 1024 | 32 | 76.3 | 73.3 | MultiStep |
124 | | 32 | 1024 | 32 | 76.7 | 74.4 | Cosine |
125 | | 64 | 2048 | 64 | 76.8 | 74.2 | Cosine |
126 | | 64 | 2048 | 8 | 76.6 | 73.8 | Cosine |
127 | | 128 | 4096 | 16 | 76.4 | 73.1 | Cosine |
128 | | 256 | 8192 | 32 | 75.9 | 71.7 | Cosine |
129 |
130 | ## Known Issues and TODOs
131 |
132 | - **Backend**: We currently only support the OpenMPI backend. We encountered errors when calling `allgather` with the NCCL2 backend: the `allgather`ed data are occasionally corrupted; however, when we set `CUDA_LAUNCH_BLOCKING=1` for debugging, everything works well.
133 | - **#Sparsified Nodes**: We currently treat each GPU as an independent node. However, communication is rarely a bottleneck within a single machine. A better strategy would be to `allreduce` dense gradients within each machine and `allgather` sparse gradients across machines.
134 |   - For accuracy/convergence verification, we can simulate this by setting `configs.train.num_batches_per_step` to the desired number of GPUs per machine (see the accuracy table for batch size = 4096/8192).
135 | - **Sparsification Granularity**: We naively perform fine-grained (*i.e.*, element-wise) top-k to select gradients, so the communication suffers from an increased `allgather` data volume as the number of nodes grows.
136 |   - [Sun *et al.*](https://arxiv.org/pdf/1902.06855.pdf) modified the process with coarse-grained sparsification: gradients are partitioned into chunks, and the chunks selected according to the `allreduce`d L1-norm of each chunk are themselves `allreduce`d, which eliminates the `allgather` and avoids the problem.
137 | - **Data Encoding**: We did not perform any data quantization/encoding before transmission. Data encoding can further reduce the data volume.
138 | - **Overhead**: Performing the sparsification (especially the adaptive thresholding) in C/C++ may further reduce the DGC overhead.
139 |
140 | ## License
141 |
142 | This repository is released under the Apache 2.0 license. See [LICENSE](LICENSE) for additional details.
143 |
144 |
145 | ## Acknowledgement
146 | - Our implementation is modified from [grace](https://github.com/sands-lab/grace), a unified framework for compressed distributed training algorithms.
147 |
--------------------------------------------------------------------------------
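The momentum-correction and sparsification snippets shown in the README above are excerpts from the DGC memory and compressor code. Below is a minimal, single-process sketch that chains the same two steps on a toy gradient; the helper names (`momentum_correct`, `sparsify`) and the fixed `sample_stride` are illustrative only, not part of the repository's API.

```python
import random
import torch


def momentum_correct(grad, mmt, vec, momentum=0.9, nesterov=False):
    # local momentum correction, mirroring the memory snippet in the README
    if nesterov:
        mmt.add_(grad).mul_(momentum)
        vec.add_(mmt).add_(grad)
    else:
        mmt.mul_(momentum).add_(grad)
        vec.add_(mmt)
    return vec


def sparsify(tensor, compress_ratio=0.001, sample_stride=4):
    # top-k sparsification via sampled thresholding, mirroring the README snippet
    tensor = tensor.view(-1)
    importance = tensor.abs()
    sample_start = random.randint(0, sample_stride - 1)
    samples = importance[sample_start::sample_stride]
    top_k_samples = max(1, int(samples.numel() * compress_ratio))
    threshold = torch.min(torch.topk(samples, top_k_samples, 0, largest=True, sorted=False)[0])
    mask = torch.ge(importance, threshold)
    indices = mask.nonzero().view(-1)
    return tensor[indices], indices


if __name__ == '__main__':
    grad = torch.randn(10000)
    mmt, vec = torch.zeros_like(grad), torch.zeros_like(grad)
    corrected = momentum_correct(grad, mmt, vec)
    values, indices = sparsify(corrected, compress_ratio=0.001)
    print(f'transmit {values.numel()} / {grad.numel()} elements')
```

In the real implementation the surviving `(values, indices)` pairs are what get `allgather`ed, while everything below the threshold stays in the local momentum/velocity buffers for later rounds.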
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torchpack.mtpack.utils.config import Config, configs
4 | from torchpack.mtpack.meters import TopKClassMeter
5 |
6 | from dgc.horovod.compression import Compression
7 |
8 |
9 | configs.seed = 42
10 | configs.data = Config()
11 | configs.data.num_threads_per_worker = 4
12 |
13 | # criterion
14 | configs.train = Config()
15 | configs.train.dgc = False
16 | configs.train.compression = Compression.none
17 | configs.train.criterion = Config(torch.nn.CrossEntropyLoss)
18 |
19 | # optimizer
20 | configs.train.optimizer = Config(torch.optim.SGD)
21 | configs.train.optimizer.momentum = 0.9
22 |
23 | # scheduler
24 | configs.train.schedule_lr_per_epoch = True
25 | configs.train.warmup_lr_epochs = 5
26 |
27 | # metrics
28 | configs.train.metric = 'acc/test_top1'
29 | configs.train.meters = Config()
30 | configs.train.meters['acc/{}_top1'] = Config(TopKClassMeter)
31 | configs.train.meters['acc/{}_top1'].k = 1
32 | configs.train.meters['acc/{}_top5'] = Config(TopKClassMeter)
33 | configs.train.meters['acc/{}_top5'].k = 5
34 |
--------------------------------------------------------------------------------
/configs/cifar/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torchpack.mtpack.datasets.vision import CIFAR
4 | from torchpack.mtpack.utils.config import Config, configs
5 |
6 | # dataset
7 | configs.dataset = Config(CIFAR)
8 | configs.dataset.root = './data/cifar10'
9 | configs.dataset.num_classes = 10
10 | configs.dataset.image_size = 32
11 |
12 | # training
13 | configs.train.num_epochs = 200
14 | configs.train.batch_size = 128
15 |
16 | # optimizer
17 | configs.train.optimizer.lr = 0.1
18 | configs.train.optimizer.weight_decay = 1e-4
19 |
20 | # scheduler
21 | configs.train.scheduler = Config(torch.optim.lr_scheduler.CosineAnnealingLR)
22 | configs.train.scheduler.T_max = configs.train.num_epochs - configs.train.warmup_lr_epochs
23 |
--------------------------------------------------------------------------------
/configs/cifar/resnet110.py:
--------------------------------------------------------------------------------
1 | from torchpack.mtpack.models.vision.resnet import resnet110
2 |
3 | from torchpack.mtpack.utils.config import Config, configs
4 |
5 | # model
6 | configs.model = Config(resnet110)
7 | configs.model.num_classes = configs.dataset.num_classes
8 |
--------------------------------------------------------------------------------
/configs/cifar/resnet20.py:
--------------------------------------------------------------------------------
1 | from torchpack.mtpack.models.vision.resnet import resnet20
2 |
3 | from torchpack.mtpack.utils.config import Config, configs
4 |
5 | # model
6 | configs.model = Config(resnet20)
7 | configs.model.num_classes = configs.dataset.num_classes
8 |
--------------------------------------------------------------------------------
/configs/dgc/__init__.py:
--------------------------------------------------------------------------------
1 | from torchpack.mtpack.utils.config import Config, configs
2 |
3 | from dgc.compression import DGCCompressor
4 | from dgc.memory import DGCSGDMemory
5 | from dgc.optim import DGCSGD
6 |
7 |
8 | configs.train.dgc = True
9 | configs.train.compression = Config(DGCCompressor)
10 | configs.train.compression.compress_ratio = 0.001
11 | configs.train.compression.sample_ratio = 0.01
12 | configs.train.compression.strided_sample = True
13 | configs.train.compression.compress_upper_bound = 1.3
14 | configs.train.compression.compress_lower_bound = 0.8
15 | configs.train.compression.max_adaptation_iters = 10
16 | configs.train.compression.resample = True
17 |
18 | old_optimizer = configs.train.optimizer
19 | configs.train.optimizer = Config(DGCSGD)
20 | for k, v in old_optimizer.items():
21 | configs.train.optimizer[k] = v
22 |
23 | configs.train.compression.memory = Config(DGCSGDMemory)
24 | configs.train.compression.memory.momentum = configs.train.optimizer.momentum
25 |
--------------------------------------------------------------------------------
/configs/dgc/fp16.py:
--------------------------------------------------------------------------------
1 | from torchpack.mtpack.utils.config import Config, configs
2 |
3 | configs.train.compression.fp16_values = True
4 |
--------------------------------------------------------------------------------
/configs/dgc/int32.py:
--------------------------------------------------------------------------------
1 | from torchpack.mtpack.utils.config import Config, configs
2 |
3 | configs.train.compression.int32_indices = True
4 |
--------------------------------------------------------------------------------
/configs/dgc/mm.py:
--------------------------------------------------------------------------------
1 | from torchpack.mtpack.utils.config import Config, configs
2 |
3 | configs.train.compression.memory.momentum_masking = True
4 |
--------------------------------------------------------------------------------
/configs/dgc/nm.py:
--------------------------------------------------------------------------------
1 | from torchpack.mtpack.utils.config import Config, configs
2 |
3 | configs.train.compression.memory.momentum_masking = False
4 |
--------------------------------------------------------------------------------
/configs/dgc/wm0.py:
--------------------------------------------------------------------------------
1 | from torchpack.mtpack.utils.config import Config, configs
2 |
3 | configs.train.compression.warmup_epochs = 0
4 |
--------------------------------------------------------------------------------
/configs/dgc/wm5.py:
--------------------------------------------------------------------------------
1 | from torchpack.mtpack.utils.config import Config, configs
2 |
3 | configs.train.compression.warmup_epochs = 5
4 |
--------------------------------------------------------------------------------
/configs/dgc/wm5o.py:
--------------------------------------------------------------------------------
1 | from torchpack.mtpack.utils.config import Config, configs
2 |
3 | configs.train.compression.warmup_epochs = 5
4 | configs.train.compression.warmup_coeff = [1, 1, 1, 1, 1]
5 |
--------------------------------------------------------------------------------
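The interplay of `warmup_epochs` and `warmup_coeff` is easiest to see numerically. Here is a small pure-Python sketch, re-derived from `DGCCompressor.warmup_compress_ratio` in `dgc/compression.py`, of the per-epoch compress ratio under `wm5.py` (default exponential warmup) versus `wm5o.py` (stay dense during warmup); the function name `warmup_schedule` is illustrative only.

```python
def warmup_schedule(base_ratio, warmup_epochs, warmup_coeff=None, num_epochs=8):
    # mirrors DGCCompressor.warmup_compress_ratio in dgc/compression.py
    if warmup_coeff is None:
        warmup_coeff = base_ratio ** (1.0 / (warmup_epochs + 1))
    ratios = []
    for epoch in range(num_epochs):
        if warmup_epochs > 0 and epoch < warmup_epochs:
            if isinstance(warmup_coeff, (tuple, list)):
                ratios.append(warmup_coeff[epoch])
            else:
                ratios.append(max(warmup_coeff ** (epoch + 1), base_ratio))
        else:
            ratios.append(base_ratio)
    return ratios


# wm5.py: exponentially tighten the ratio towards the 0.001 target
print(warmup_schedule(0.001, warmup_epochs=5))
# wm5o.py: keep ratio 1.0 (i.e. dense allreduce) during warmup, then jump to 0.001
print(warmup_schedule(0.001, warmup_epochs=5, warmup_coeff=[1, 1, 1, 1, 1]))
```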
/configs/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torchpack.mtpack.datasets.vision import ImageNet
4 | from torchpack.mtpack.utils.config import Config, configs
5 |
6 | # dataset
7 | configs.dataset = Config(ImageNet)
8 | configs.dataset.root = '/dataset/imagenet'
9 | configs.dataset.num_classes = 1000
10 | configs.dataset.image_size = 224
11 |
12 | # training
13 | configs.train.num_epochs = 90
14 | configs.train.batch_size = 32
15 |
16 | # optimizer
17 | configs.train.optimize_bn_separately = False
18 | configs.train.optimizer.lr = 0.0125
19 | configs.train.optimizer.weight_decay = 5e-5
20 |
21 | # scheduler
22 | configs.train.scheduler = Config(torch.optim.lr_scheduler.MultiStepLR)
23 | configs.train.scheduler.milestones = [e - configs.train.warmup_lr_epochs
24 | for e in [30, 60, 80]]
25 | configs.train.scheduler.gamma = 0.1
26 |
--------------------------------------------------------------------------------
/configs/imagenet/cosine.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torchpack.mtpack.utils.config import Config, configs
4 |
5 | # scheduler
6 | configs.train.scheduler = Config(torch.optim.lr_scheduler.CosineAnnealingLR)
7 | configs.train.scheduler.T_max = configs.train.num_epochs - configs.train.warmup_lr_epochs
8 |
--------------------------------------------------------------------------------
/configs/imagenet/resnet18.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import resnet18
2 |
3 | from torchpack.mtpack.utils.config import Config, configs
4 |
5 | configs.train.batch_size = 64
6 | configs.train.optimizer.lr = 0.025
7 |
8 | # model
9 | configs.model = Config(resnet18)
10 | configs.model.num_classes = configs.dataset.num_classes
11 | configs.model.zero_init_residual = True
12 |
--------------------------------------------------------------------------------
/configs/imagenet/resnet50.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import resnet50
2 |
3 | from torchpack.mtpack.utils.config import Config, configs
4 |
5 | configs.train.optimizer.weight_decay = 1e-4
6 | configs.train.optimizer.nesterov = True
7 | configs.train.optimize_bn_separately = True
8 |
9 | # model
10 | configs.model = Config(resnet50)
11 | configs.model.num_classes = configs.dataset.num_classes
12 | configs.model.zero_init_residual = True
13 |
--------------------------------------------------------------------------------
/configs/imagenet/vgg16_bn.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import vgg16_bn
2 |
3 | from torchpack.mtpack.utils.config import Config, configs
4 |
5 | # model
6 | configs.model = Config(vgg16_bn)
7 | configs.model.num_classes = configs.dataset.num_classes
8 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !docs/
3 | !docs/*
4 | !.gitignore
5 |
--------------------------------------------------------------------------------
/data/docs/cifar-10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synxlin/deep-gradient-compression/24673d45d11bba08bd4554b4585cae65e3dcf6f1/data/docs/cifar-10.png
--------------------------------------------------------------------------------
/data/docs/resnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synxlin/deep-gradient-compression/24673d45d11bba08bd4554b4585cae65e3dcf6f1/data/docs/resnet.png
--------------------------------------------------------------------------------
/data/docs/speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synxlin/deep-gradient-compression/24673d45d11bba08bd4554b4585cae65e3dcf6f1/data/docs/speedup.png
--------------------------------------------------------------------------------
/data/docs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synxlin/deep-gradient-compression/24673d45d11bba08bd4554b4585cae65e3dcf6f1/data/docs/teaser.png
--------------------------------------------------------------------------------
/dgc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/synxlin/deep-gradient-compression/24673d45d11bba08bd4554b4585cae65e3dcf6f1/dgc/__init__.py
--------------------------------------------------------------------------------
/dgc/clip_grad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch._six import inf
3 |
4 | from horovod.torch import allreduce_
5 |
6 | __all__ = ['clip_grad_norm_', 'clip_grad_value_', 'clip_grad_value_by_global_norm_', 'clip_grad_norm_2_by_global_']
7 |
8 |
9 | # code modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py
10 | def clip_grad_norm_(grad, max_norm, norm_type=2):
11 | max_norm = float(max_norm)
12 | norm_type = float(norm_type)
13 | if norm_type == inf:
14 | total_norm = grad.data.abs().max()
15 | else:
16 | total_norm = grad.data.norm(norm_type)
17 | clip_coef = max_norm / (total_norm + 1e-6)
18 | if clip_coef < 1:
19 | grad.data.mul_(clip_coef)
20 | return grad
21 |
22 |
23 | def clip_grad_value_(grad, clip_value):
24 | clip_value = float(clip_value)
25 | grad.data.clamp_(min=-clip_value, max=clip_value)
26 |
27 |
28 | # code modified from https://github.com/sands-lab/grace/blob/master/grace_dl/torch/memory/dgc.py
29 | def clip_grad_value_by_global_norm_(grad, name=None):
30 | grad_square_sum = torch.sum(grad.square())
31 | clip_value = torch.sqrt(allreduce_(grad_square_sum, average=True, name=name))
32 | grad.data.clamp_(min=-clip_value, max=clip_value)
33 |
34 |
35 | def clip_grad_norm_2_by_global_(grad, max_norm, name=None):
36 | max_norm = float(max_norm)
37 | grad_square_sum = torch.sum(grad.square())
38 | total_norm = torch.sqrt(allreduce_(grad_square_sum, average=True, name=name))
39 | clip_coef = max_norm / (total_norm + 1e-6)
40 | if clip_coef < 1:
41 | grad.data.mul_(clip_coef)
42 | return grad
43 |
--------------------------------------------------------------------------------
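The two `*_by_global_*` helpers above differ from the local variants only in that the squared-gradient sum is averaged across workers via `allreduce_` before taking the square root. A minimal, Horovod-free sketch that emulates this averaging over a list of simulated per-worker gradients (the name `simulate_global_clip` is hypothetical):

```python
import torch


def simulate_global_clip(worker_grads):
    # emulate allreduce_(grad_square_sum, average=True): mean of the per-worker sums of squares
    global_sq = torch.stack([g.square().sum() for g in worker_grads]).mean()
    clip_value = torch.sqrt(global_sq).item()
    for g in worker_grads:
        # same clamp as clip_grad_value_by_global_norm_, applied on every simulated worker
        g.clamp_(min=-clip_value, max=clip_value)
    return clip_value


# three simulated workers with gradients of very different scales
grads = [torch.randn(1000) * scale for scale in (0.1, 1.0, 10.0)]
clip_value = simulate_global_clip(grads)
print(f'clip value: {clip_value:.3f}, '
      f'max |grad| after clipping: {max(g.abs().max().item() for g in grads):.3f}')
```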
/dgc/compression.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import torch
5 |
6 | import horovod.torch as hvd
7 | from horovod.torch.mpi_ops import Average
8 | from horovod.torch.mpi_ops import allreduce_async_
9 | from horovod.torch.mpi_ops import allgather_async as allgather_async_
10 | from horovod.torch.mpi_ops import synchronize as synchronize_
11 |
12 | from dgc.memory import Memory
13 |
14 | __all__ = ['DGCCompressor']
15 |
16 |
17 | class DGCCompressor:
18 | def __init__(self, compress_ratio, memory=None,
19 | sample_ratio=0.01, strided_sample=True,
20 | compress_upper_bound=1.3, compress_lower_bound=0.8, max_adaptation_iters=10, resample=True,
21 | fp16_values=False, int32_indices=False,
22 | warmup_epochs=-1, warmup_coeff=None):
23 | self.world_size = hvd.size()
24 | self.op = Average
25 | self.fp16_values = fp16_values
26 | self.int32_indices = int32_indices
27 |
28 | self.base_compress_ratio = self.compress_ratio = \
29 | compress_ratio if compress_ratio <= 1.0 else 1.0 / compress_ratio
30 | self.memory = Memory if memory is None else memory
31 | self.warmup_epochs = warmup_epochs
32 | if self.warmup_epochs > 0:
33 | if warmup_coeff is None:
34 | self.warmup_coeff = self.base_compress_ratio \
35 | ** (1. / (self.warmup_epochs + 1))
36 | else:
37 | if isinstance(warmup_coeff, (tuple, list)):
38 | assert len(warmup_coeff) >= self.warmup_epochs
39 | for wc in warmup_coeff:
40 | assert 0 < wc <= 1
41 | else:
42 | assert 0 < warmup_coeff <= 1
43 | self.warmup_coeff = warmup_coeff
44 | else:
45 | self.warmup_coeff = 1
46 |
47 | self.sample_ratio = min(max(sample_ratio, 0.01), 1.0)
48 | self.strided_sample = strided_sample
49 | self.compress_upper_bound = compress_upper_bound
50 | self.compress_lower_bound = compress_lower_bound
51 | self.max_adaptation_iters = max_adaptation_iters
52 | self.resample = resample
53 |
54 | self.attributes = {}
55 |
56 | def initialize(self, named_parameters):
57 | if hvd.rank() == 0:
58 | print("=> initializing dgc compressor")
59 | for name, param in named_parameters:
60 | if torch.is_tensor(param):
61 | numel = param.numel()
62 | shape = list(param.size())
63 | else:
64 | assert isinstance(param, (list, tuple))
65 | numel, shape = param[0], param[1]
66 | if self.sample_ratio < 1.0:
67 | pct_numel = int(math.ceil(numel * self.sample_ratio))
68 | cpr_numel = int(math.ceil(2 / self.compress_ratio))
69 | if numel <= cpr_numel:
70 | if hvd.rank() == 0:
71 | print(f'Warning: {name} with {numel} elements transmits 1 gradient element')
72 | sample_stride = 1
73 | num_samples = numel
74 | else:
75 | sample_stride = int(math.ceil(numel / max(pct_numel, cpr_numel) / 32)) * 32 + 1
76 | num_samples = numel // sample_stride
77 | while num_samples < max(pct_numel, cpr_numel):
78 | sample_stride = sample_stride - 8
79 | num_samples = numel // sample_stride
80 | else:
81 | sample_stride = 1
82 | num_samples = numel
83 | top_k_samples = int(math.ceil(num_samples * self.compress_ratio))
84 | num_selects = int(math.ceil(numel * self.compress_ratio))
85 | self.attributes[name] = (numel, shape, num_selects, num_samples, top_k_samples, sample_stride)
86 | if hvd.rank() == 0:
87 | print(f' {name:<25}: transmit {num_selects} / {numel} elements of shape {shape}\n'
88 | f' {" " * 25} threshold {top_k_samples} / {num_samples} samples'
89 | f' {f"at stride {sample_stride}" if self.strided_sample else "uniformly"}')
90 |
91 | def warmup_compress_ratio(self, epoch):
92 | if self.warmup_epochs > 0:
93 | if epoch < self.warmup_epochs:
94 | if isinstance(self.warmup_coeff, (tuple, list)):
95 | compress_ratio = self.warmup_coeff[epoch]
96 | else:
97 | compress_ratio = max(self.warmup_coeff ** (epoch + 1),
98 | self.base_compress_ratio)
99 | else:
100 | compress_ratio = self.base_compress_ratio
101 | else:
102 | compress_ratio = self.base_compress_ratio
103 | if compress_ratio != self.compress_ratio:
104 | if hvd.rank() == 0:
105 | print(f'update compress ratio: {compress_ratio}')
106 | self.compress_ratio = compress_ratio
107 | self.initialize(self.attributes.items())
108 |
109 | def _sparsify(self, tensor, name):
110 | tensor = tensor.view(-1)
111 | numel, shape, num_selects, num_samples, top_k_samples, sample_stride = self.attributes[name]
112 |
113 | importance = tensor.abs()
114 | if numel == num_samples:
115 | samples = importance
116 | else:
117 | if self.strided_sample:
118 | sample_start = random.randint(0, sample_stride - 1)
119 | samples = importance[sample_start::sample_stride]
120 | else:
121 | samples = importance[torch.randint(0, numel, (num_samples, ), device=tensor.device)]
122 |
123 | threshold = torch.min(torch.topk(samples, top_k_samples, 0, largest=True, sorted=False)[0])
124 | mask = torch.ge(importance, threshold)
125 | indices = mask.nonzero().view(-1)
126 | num_indices = indices.numel()
127 |
128 | if numel > num_samples:
129 | # code modified from https://github.com/sands-lab/grace/blob/master/grace_dl/torch/compressor/dgc.py
130 | for _ in range(self.max_adaptation_iters):
131 | if num_indices > num_selects:
132 | if num_indices > num_selects * self.compress_upper_bound:
133 | if self.resample:
134 | indices = indices[
135 | torch.topk(importance[indices], num_selects,
136 | 0, largest=True, sorted=False)[1]
137 | ]
138 | break
139 | else:
140 | threshold = threshold * self.compress_upper_bound
141 | else:
142 | break
143 | elif num_indices < self.compress_lower_bound * num_selects:
144 | threshold = threshold * self.compress_lower_bound
145 | else:
146 | break
147 | mask = torch.ge(importance, threshold)
148 | indices = mask.nonzero().view(-1)
149 | num_indices = indices.numel()
150 |
151 | indices = indices[:num_selects]
152 | values = tensor[indices]
153 | return values, indices, numel, shape, num_selects
154 |
155 | def compress(self, tensor, name):
156 | if self.compress_ratio < 1.0 and name in self.attributes:
157 | # compress
158 | tensor_compensated = self.memory.compensate(
159 | tensor, name, accumulate=True)
160 | values, indices, numel, shape, num_selects = \
161 | self._sparsify(tensor_compensated, name)
162 | self.memory.update(name, (indices, ))
163 | indices = indices.view(-1, 1)
164 | values = values.view(-1, 1)
165 |
166 | ctx = (name, numel, shape, values.dtype, indices.dtype,
167 | tensor.data.view(numel))
168 | if self.fp16_values and values.dtype.is_floating_point:
169 | values = values.type(torch.float16)
170 | if self.int32_indices and not indices.dtype.is_floating_point:
171 | indices = indices.type(torch.int32)
172 | return (values, indices), ctx
173 | else:
174 | ctx = (name, None, None, tensor.dtype, None, None)
175 | if self.fp16_values and tensor.dtype.is_floating_point:
176 | tensor = tensor.type(torch.float16)
177 | return tensor, ctx
178 |
179 | def decompress(self, tensor, ctx):
180 | name, numel, shape, vdtype, idtype, grad = ctx
181 | if self.compress_ratio < 1.0 and name in self.attributes:
182 | # decompress
183 | assert isinstance(tensor, (list, tuple))
184 | values, indices = tensor
185 | values = values.view(-1)
186 | indices = indices.view(-1)
187 | if self.fp16_values and vdtype.is_floating_point:
188 | values = values.type(vdtype)
189 | if self.int32_indices and not idtype.is_floating_point:
190 | indices = indices.type(idtype)
191 | grad.zero_().index_put_([indices], values, accumulate=True)
192 | if self.op == Average:
193 | grad.mul_(1. / self.world_size)
194 | return grad.view(shape)
195 | else:
196 | if self.fp16_values and vdtype.is_floating_point:
197 | tensor = tensor.type(vdtype)
198 | return self.memory.compensate(tensor, name, accumulate=False)
199 |
200 | def communicate(self, tensor_compressed, name, op):
201 | self.op = op
202 | if self.compress_ratio < 1.0 and name in self.attributes:
203 | return [allgather_async_(t, name=f'{name}.t{e}')
204 | for e, t in enumerate(tensor_compressed)]
205 | else:
206 | return allreduce_async_(tensor_compressed, name=name, op=op)
207 |
208 | def synchronize(self, handle):
209 | if isinstance(handle, (tuple, list)):
210 | return [synchronize_(h) for h in handle]
211 | else:
212 | return synchronize_(handle)
213 |
--------------------------------------------------------------------------------
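To make the sparse `communicate`/`decompress` path above concrete: each worker `allgather`s its `(values, indices)` pair, and `decompress` scatter-adds the concatenated values into a zeroed gradient buffer before averaging by the world size. Below is a single-process emulation, with the allgather faked by concatenating per-worker tensors (the function name is illustrative only).

```python
import torch


def emulate_sparse_allgather_decompress(per_worker, numel, world_size):
    # per_worker: list of (values, indices) pairs, one per simulated worker
    values = torch.cat([v for v, _ in per_worker])
    indices = torch.cat([i for _, i in per_worker])
    grad = torch.zeros(numel)
    # indices selected by several workers are accumulated, as in decompress()
    grad.index_put_([indices], values, accumulate=True)
    grad.mul_(1.0 / world_size)  # op == Average
    return grad


# two simulated workers, each transmitting 3 of 10 gradient elements
w0 = (torch.tensor([0.5, -0.2, 0.1]), torch.tensor([0, 3, 7]))
w1 = (torch.tensor([0.4, 0.3, -0.6]), torch.tensor([0, 5, 9]))
print(emulate_sparse_allgather_decompress([w0, w1], numel=10, world_size=2))
```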
/dgc/horovod/README.md:
--------------------------------------------------------------------------------
1 | # Horovod Patch
2 |
3 | We applied a [patch](horovod.june.6.6b77884.patch) to Horovod at [this commit](https://github.com/horovod/horovod/tree/6b77884daf92649ecf031fcc8ff29697bbea0132).
4 | Nonetheless, we copied the modified files into this directory so that you don't have to patch the Horovod source code.
5 |
6 | The modification is very subtle:
7 |
8 | - class `Compressor` takes `name` as an additional argument when compressing a tensor.
9 |
10 | - class `DistributedOptimizer` performs `communicate()` of its compression member if possible, instead of always calling `allreduce_async()`.
11 |
--------------------------------------------------------------------------------
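The second change is implemented with a `getattr` fallback (see the patch and `optimizer.py` below), so a compressor that does not define `communicate`/`synchronize`, such as the stock `NoneCompressor`, keeps the usual `allreduce_async_`/`synchronize` path. A tiny standalone illustration of the pattern, with dummy stand-ins for the Horovod ops:

```python
def allreduce_async_(tensor, name=None, op=None):
    # stand-in for horovod.torch.mpi_ops.allreduce_async_
    return f'allreduce({name})'


class NoneCompressor:
    # no custom communication: the optimizer falls back to allreduce
    pass


class DGCLikeCompressor:
    # provides its own communicate(), so the optimizer uses it instead
    def communicate(self, tensor, name=None, op=None):
        return f'allgather({name})'


for compression in (NoneCompressor(), DGCLikeCompressor()):
    communicate = getattr(compression, 'communicate', allreduce_async_)
    print(type(compression).__name__, '->', communicate(None, name='conv1.weight'))
```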
/dgc/horovod/__init__.py:
--------------------------------------------------------------------------------
1 | from dgc.horovod.compression import Compressor, Compression
2 | from dgc.horovod.optimizer import DistributedOptimizer
3 |
--------------------------------------------------------------------------------
/dgc/horovod/compression.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Uber Technologies, Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Gradient compression algorithms."""
16 |
17 | import torch
18 |
19 | __all__ = ['Compressor', 'Compression']
20 |
21 |
22 | class Compressor(object):
23 | """Interface for compressing and decompressing a given tensor."""
24 | @staticmethod
25 | def compress(tensor, name=None):
26 | """Compresses a tensor and returns it with the context needed to decompress it."""
27 | pass
28 |
29 | @staticmethod
30 | def decompress(tensor, ctx):
31 | """Decompress the tensor with the given context."""
32 | pass
33 |
34 |
35 | class NoneCompressor(Compressor):
36 | """Default no-op compression."""
37 | @staticmethod
38 | def compress(tensor, name=None):
39 | """Returns the tensor unmodified."""
40 | return tensor, None
41 |
42 | @staticmethod
43 | def decompress(tensor, ctx):
44 | """Returns the tensor unmodified."""
45 | return tensor
46 |
47 |
48 | class FP16Compressor(Compressor):
49 | """Compress all floating point gradients to 16-bit."""
50 | @staticmethod
51 | def compress(tensor, name=None):
52 | """Downcasts the tensor to 16-bit."""
53 | tensor_compressed = tensor
54 | if tensor.dtype.is_floating_point:
55 | # Only allow compression from other floating point types
56 | tensor_compressed = tensor.type(torch.float16)
57 | return tensor_compressed, tensor.dtype
58 |
59 | @staticmethod
60 | def decompress(tensor, ctx):
61 | """Upcasts the tensor to the initialization dtype."""
62 | tensor_decompressed = tensor
63 | dtype = ctx
64 | if dtype.is_floating_point:
65 | tensor_decompressed = tensor.type(dtype)
66 | return tensor_decompressed
67 |
68 |
69 | class Compression(object):
70 | """Optional gradient compression algorithm used during allreduce."""
71 |
72 | """Do not compress the gradients. This is the default."""
73 | none = NoneCompressor
74 |
75 | """Compress all floating point gradients to 16-bit."""
76 | fp16 = FP16Compressor
77 |
--------------------------------------------------------------------------------
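A quick round trip with the FP16 compressor defined above, assuming this repository (and Horovod, which `dgc/horovod/__init__.py` imports) is on your `PYTHONPATH`:

```python
import torch
from dgc.horovod.compression import Compression

grad = torch.randn(4, dtype=torch.float32)
compressed, ctx = Compression.fp16.compress(grad, name='fc.weight')  # ctx is the original dtype
restored = Compression.fp16.decompress(compressed, ctx)
print(compressed.dtype, restored.dtype)   # torch.float16 torch.float32
print((grad - restored).abs().max())      # small fp16 rounding error
```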
/dgc/horovod/horovod.june.6.6b77884.patch:
--------------------------------------------------------------------------------
1 | diff --git a/horovod/torch/compression.py b/horovod/torch/compression.py
2 | index 75ce91e..ca160ce 100644
3 | --- a/horovod/torch/compression.py
4 | +++ b/horovod/torch/compression.py
5 | @@ -20,7 +20,7 @@ import torch
6 | class Compressor(object):
7 | """Interface for compressing and decompressing a given tensor."""
8 | @staticmethod
9 | - def compress(tensor):
10 | + def compress(tensor, name=None):
11 | """Compresses a tensor and returns it with the context needed to decompress it."""
12 | pass
13 |
14 | @@ -33,7 +33,7 @@ class Compressor(object):
15 | class NoneCompressor(Compressor):
16 | """Default no-op compression."""
17 | @staticmethod
18 | - def compress(tensor):
19 | + def compress(tensor, name=None):
20 | """Returns the tensor unmodified."""
21 | return tensor, None
22 |
23 | @@ -46,7 +46,7 @@ class NoneCompressor(Compressor):
24 | class FP16Compressor(Compressor):
25 | """Compress all floating point gradients to 16-bit."""
26 | @staticmethod
27 | - def compress(tensor):
28 | + def compress(tensor, name=None):
29 | """Downcasts the tensor to 16-bit."""
30 | tensor_compressed = tensor
31 | if tensor.dtype.is_floating_point:
32 | diff --git a/horovod/torch/optimizer.py b/horovod/torch/optimizer.py
33 | index c8def4b..d954146 100644
34 | --- a/horovod/torch/optimizer.py
35 | +++ b/horovod/torch/optimizer.py
36 | @@ -23,7 +23,7 @@ import torch
37 |
38 | from horovod.torch.compression import Compression
39 | from horovod.torch.mpi_ops import allreduce_async_
40 | -from horovod.torch.mpi_ops import synchronize
41 | +from horovod.torch.mpi_ops import synchronize as synchronize_
42 | from horovod.torch.mpi_ops import size
43 | from horovod.torch.mpi_ops import Average, Adasum
44 |
45 | @@ -33,6 +33,8 @@ class _DistributedOptimizer(torch.optim.Optimizer):
46 | backward_passes_per_step=1, op=Average):
47 | super(self.__class__, self).__init__(params)
48 | self._compression = compression
49 | + self._communicate_ = getattr(self._compression, 'communicate', allreduce_async_)
50 | + self._synchronize_ = getattr(self._compression, 'synchronize', synchronize_)
51 |
52 | if named_parameters is not None:
53 | named_parameters = list(named_parameters)
54 | @@ -111,9 +113,9 @@ class _DistributedOptimizer(torch.optim.Optimizer):
55 | def _allreduce_grad_async(self, p):
56 | name = self._parameter_names.get(p)
57 | tensor = p.grad
58 | - tensor_compressed, ctx = self._compression.compress(tensor)
59 | + tensor_compressed, ctx = self._compression.compress(tensor, name)
60 |
61 | - handle = allreduce_async_(tensor_compressed, name=name, op=self.op)
62 | + handle = self._communicate_(tensor_compressed, name=name, op=self.op)
63 | return handle, ctx
64 |
65 | def _make_hook(self, p):
66 | @@ -140,13 +142,12 @@ class _DistributedOptimizer(torch.optim.Optimizer):
67 | handle, ctx = self._allreduce_grad_async(p)
68 | self._handles[p] = (handle, ctx)
69 |
70 | - for p, value in self._handles.items():
71 | - handle, ctx = value
72 | + for p, (handle, ctx) in self._handles.items():
73 | if handle is None:
74 | handle, ctx = self._allreduce_grad_async(p)
75 | self._handles[p] = (handle, ctx)
76 | - for p, (handle, _) in self._handles.items():
77 | - output = synchronize(handle)
78 | + for p, (handle, ctx) in self._handles.items():
79 | + output = self._synchronize_(handle)
80 | self._allreduce_delay[p] = self.backward_passes_per_step
81 | p.grad.set_(self._compression.decompress(output, ctx))
82 | self._handles.clear()
83 | @@ -200,6 +201,8 @@ class _DistributedAdasumOptimizer(torch.optim.Optimizer):
84 | super(self.__class__, self).__init__(params)
85 |
86 | self._compression = compression
87 | + self._communicate_ = getattr(self._compression, 'communicate', allreduce_async_)
88 | + self._synchronize_ = getattr(self._compression, 'synchronize', synchronize_)
89 |
90 | if named_parameters is not None:
91 | named_parameters = list(named_parameters)
92 | @@ -298,8 +301,8 @@ class _DistributedAdasumOptimizer(torch.optim.Optimizer):
93 | p.data.sub_(start)
94 |
95 | # allreduce as before
96 | - tensor_compressed, ctx = self._compression.compress(p)
97 | - handle = allreduce_async_(tensor_compressed.data, name=name, op=Adasum)
98 | + tensor_compressed, ctx = self._compression.compress(p, name)
99 | + handle = self._communicate_(tensor_compressed.data, name=name, op=Adasum)
100 |
101 | # reset stashed parameters
102 | for stashed, group in zip(stashed_params, self.param_groups):
103 | @@ -348,7 +351,7 @@ class _DistributedAdasumOptimizer(torch.optim.Optimizer):
104 | if not handle:
105 | handle, ctx = self._allreduce_grad_async(p)
106 | self._handles[p] = (handle, ctx)
107 | - delta = synchronize(handle)
108 | + delta = self._synchronize_(handle)
109 | delta = self._compression.decompress(delta, ctx)
110 | start = self._starting_models[p]
111 | start.data.add_(delta.data)
112 |
--------------------------------------------------------------------------------
/dgc/horovod/optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
2 | # Modifications copyright Microsoft
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | import os
18 | import warnings
19 |
20 | from contextlib import contextmanager
21 |
22 | import torch
23 |
24 | from horovod.torch.mpi_ops import allreduce_async_
25 | from horovod.torch.mpi_ops import synchronize as synchronize_
26 | from horovod.torch.mpi_ops import size
27 | from horovod.torch.mpi_ops import Average, Adasum
28 |
29 | from .compression import Compression
30 |
31 | __all__ = ['DistributedOptimizer']
32 |
33 |
34 | class _DistributedOptimizer(torch.optim.Optimizer):
35 | def __init__(self, params, named_parameters, compression,
36 | backward_passes_per_step=1, op=Average):
37 | super(self.__class__, self).__init__(params)
38 | self._compression = compression
39 | self._communicate_ = getattr(self._compression, 'communicate', allreduce_async_)
40 | self._synchronize_ = getattr(self._compression, 'synchronize', synchronize_)
41 |
42 | if named_parameters is not None:
43 | named_parameters = list(named_parameters)
44 | else:
45 | named_parameters = [('allreduce.noname.%s' % i, v)
46 | for param_group in self.param_groups
47 | for i, v in enumerate(param_group['params'])]
48 | # make sure that named_parameters are tuples
49 | if any([not isinstance(p, tuple) for p in named_parameters]):
50 | raise ValueError('named_parameters should be a sequence of '
51 | 'tuples (name, parameter), usually produced by '
52 | 'model.named_parameters().')
53 |
54 | dups = _DistributedOptimizer.find_duplicates([k for k, _ in named_parameters])
55 | if len(dups) > 0:
56 | raise ValueError('Parameter names in named_parameters must be unique. '
57 | 'Found duplicates: %s' % ', '.join(dups))
58 |
59 | all_param_ids = {id(v)
60 | for param_group in self.param_groups
61 | for v in param_group['params']}
62 | named_param_ids = {id(v) for k, v in named_parameters}
63 | unnamed_param_ids = all_param_ids - named_param_ids
64 | if len(unnamed_param_ids):
65 | raise ValueError('named_parameters was specified, but one or more model '
66 | 'parameters were not named. Python object ids: '
67 | '%s' % ', '.join(str(id) for id in unnamed_param_ids))
68 |
69 | self._parameter_names = {v: k for k, v in sorted(named_parameters)}
70 | self.backward_passes_per_step = backward_passes_per_step
71 | self._allreduce_delay = {v: self.backward_passes_per_step
72 | for _, v in sorted(named_parameters)}
73 | self.op = op
74 | self._handles = {}
75 | self._grad_accs = []
76 | self._requires_update = set()
77 | self._synchronized = False
78 | self._should_synchronize = True
79 | if size() > 1 or os.environ.get('HOROVOD_ELASTIC') == '1':
80 | self._register_hooks()
81 |
82 | def load_state_dict(self, *args, **kwargs):
83 | self._handles = {}
84 | self._synchronized = False
85 | self._should_synchronize = True
86 | for p in self._allreduce_delay:
87 | self._allreduce_delay[p] = self.backward_passes_per_step
88 | super(self.__class__, self).load_state_dict(*args, **kwargs)
89 |
90 | @staticmethod
91 | def find_duplicates(lst):
92 | seen = set()
93 | dups = set()
94 | for el in lst:
95 | if el in seen:
96 | dups.add(el)
97 | seen.add(el)
98 | return dups
99 |
100 | def set_backward_passes_per_step(self, passes):
101 | self.backward_passes_per_step = passes
102 | for p in self._allreduce_delay:
103 | self._allreduce_delay[p] = self.backward_passes_per_step
104 |
105 | def _register_hooks(self):
106 | for param_group in self.param_groups:
107 | for p in param_group['params']:
108 | if p.requires_grad:
109 | p.grad = p.data.new(p.size()).zero_()
110 | self._requires_update.add(p)
111 | p_tmp = p.expand_as(p)
112 | grad_acc = p_tmp.grad_fn.next_functions[0][0]
113 | grad_acc.register_hook(self._make_hook(p))
114 | self._grad_accs.append(grad_acc)
115 |
116 | def _allreduce_grad_async(self, p):
117 | name = self._parameter_names.get(p)
118 | tensor_compressed, ctx = self._compression.compress(p.grad, name)
119 |
120 | handle = self._communicate_(tensor_compressed, name=name, op=self.op)
121 | return handle, ctx
122 |
123 | def _make_hook(self, p):
124 | def hook(*ignore):
125 | if p in self._handles and self._handles[p][0] is not None:
126 | if self._allreduce_delay[p] <= 0:
127 | raise AssertionError(
128 | "Gradients were computed more than "
129 | "backward_passes_per_step times before call "
130 | "to step(). Increase backward_passes_per_step to "
131 | "accumulate gradients locally.")
132 | assert not p.grad.requires_grad
133 | assert self._allreduce_delay[p] > 0
134 | handle, ctx = None, None
135 | self._allreduce_delay[p] -= 1
136 | if self._allreduce_delay[p] == 0:
137 | handle, ctx = self._allreduce_grad_async(p)
138 | self._handles[p] = (handle, ctx)
139 | return hook
140 |
141 | def synchronize(self):
142 | missing_p = self._requires_update - set(self._handles.keys())
143 | for p in missing_p:
144 | handle, ctx = self._allreduce_grad_async(p)
145 | self._handles[p] = (handle, ctx)
146 |
147 | for p, (handle, ctx) in self._handles.items():
148 | if handle is None:
149 | handle, ctx = self._allreduce_grad_async(p)
150 | self._handles[p] = (handle, ctx)
151 | for p, (handle, ctx) in self._handles.items():
152 | output = self._synchronize_(handle)
153 | self._allreduce_delay[p] = self.backward_passes_per_step
154 | p.grad.set_(self._compression.decompress(output, ctx))
155 | self._handles.clear()
156 |
157 | self._synchronized = True
158 |
159 | @contextmanager
160 | def skip_synchronize(self):
161 | """
162 | A context manager used to specify that optimizer.step() should
163 | not perform synchronization.
164 |
165 |         It's typically used in the following pattern:
166 |
167 | .. code-block:: python
168 |
169 | optimizer.synchronize()
170 | with optimizer.skip_synchronize():
171 | optimizer.step()
172 | """
173 | self._should_synchronize = False
174 | try:
175 | yield
176 | finally:
177 | self._should_synchronize = True
178 |
179 | def step(self, closure=None):
180 | if self._should_synchronize:
181 | if self._synchronized:
182 | warnings.warn("optimizer.step() called without "
183 | "optimizer.skip_synchronize() context after "
184 | "optimizer.synchronize(). This can cause training "
185 | "slowdown. You may want to consider using "
186 | "optimizer.skip_synchronize() context if you use "
187 | "optimizer.synchronize() in your code.")
188 | self.synchronize()
189 | self._synchronized = False
190 | return super(self.__class__, self).step(closure)
191 |
192 | def zero_grad(self):
193 | if self._handles:
194 | raise AssertionError("optimizer.zero_grad() was called after loss.backward() "
195 | "but before optimizer.step() or optimizer.synchronize(). "
196 | "This is prohibited as it can cause a race condition.")
197 | return super(self.__class__, self).zero_grad()
198 |
199 |
200 | class _DistributedAdasumOptimizer(torch.optim.Optimizer):
201 | def __init__(self, params, named_parameters, compression,
202 | backward_passes_per_step=1):
203 | super(self.__class__, self).__init__(params)
204 |
205 | self._compression = compression
206 | self._communicate_ = getattr(self._compression, 'communicate', allreduce_async_)
207 | self._synchronize_ = getattr(self._compression, 'synchronize', synchronize_)
208 |
209 | if named_parameters is not None:
210 | named_parameters = list(named_parameters)
211 | else:
212 | named_parameters = [('allreduce.noname.%s' % i, v)
213 | for param_group in self.param_groups
214 | for i, v in enumerate(param_group['params'])]
215 |
216 | # make sure that named_parameters are tuples
217 | if any([not isinstance(p, tuple) for p in named_parameters]):
218 | raise ValueError('named_parameters should be a sequence of '
219 | 'tuples (name, parameter), usually produced by '
220 | 'model.named_parameters().')
221 |
222 | dups = _DistributedOptimizer.find_duplicates([k for k, _ in named_parameters])
223 | if len(dups) > 0:
224 | raise ValueError('Parameter names in named_parameters must be unique. '
225 | 'Found duplicates: %s' % ', '.join(dups))
226 |
227 | all_param_ids = {id(v)
228 | for param_group in self.param_groups
229 | for v in param_group['params']}
230 | named_param_ids = {id(v) for k, v in named_parameters}
231 | unnamed_param_ids = all_param_ids - named_param_ids
232 | if len(unnamed_param_ids):
233 | raise ValueError('named_parameters was specified, but one or more model '
234 | 'parameters were not named. Python object ids: '
235 | '%s' % ', '.join(str(id) for id in unnamed_param_ids))
236 |
237 | self._parameter_names = {v: k for k, v in sorted(named_parameters)}
238 | self.backward_passes_per_step = backward_passes_per_step
239 | self._allreduce_delay = {v: self.backward_passes_per_step
240 | for _, v in sorted(named_parameters)}
241 | self._handles = {}
242 | self._grad_accs = []
243 | self._requires_update = set()
244 | self._synchronized = False
245 | self._should_synchronize = True
246 |
247 | self._starting_models = {
248 | p : torch.zeros_like(p, requires_grad=False)
249 | for _, p in named_parameters
250 | }
251 |
252 | self._register_hooks()
253 |
254 | def set_backward_passes_per_step(self, passes):
255 | self.backward_passes_per_step = passes
256 | for p in self._allreduce_delay:
257 | self._allreduce_delay[p] = self.backward_passes_per_step
258 |
259 | def _register_hooks(self):
260 | for param_group in self.param_groups:
261 | for p in param_group['params']:
262 | if p.requires_grad:
263 | p.grad = p.data.new(p.size()).zero_()
264 | self._requires_update.add(p)
265 | p_tmp = p.expand_as(p)
266 | grad_acc = p_tmp.grad_fn.next_functions[0][0]
267 | grad_acc.register_hook(self._make_hook(p))
268 | self._grad_accs.append(grad_acc)
269 |
270 | def _allreduce_grad_async(self, p):
271 | # Delta optimizer implements this logic:
272 | # start = current.copy()
273 | # step() -> computes 'current - \alpha.f(g)' where f is
274 | # optimizer logic and g is the gradient
275 | # delta = current-start
276 | # allreduce_(delta)
277 | # start += delta
278 | # current = start
279 |         # In order to support this logic using function hooks to improve performance,
280 | # we do:
281 | # delta = (start - \alpha.f(g)) - start
282 | # = -\alpha.f(g)
283 | # set start to zero and step computes -\alpha.f(g)
284 | # where f is the underlying optimizer logic
285 |
286 | name = self._parameter_names.get(p)
287 | start = self._starting_models[p]
288 |
289 | stashed_params = []
290 | for group in self.param_groups:
291 | stashed_params.append(group['params'])
292 | # only want to step on p
293 | if any([p is v for v in group['params']]):
294 | group['params'] = [p]
295 | else:
296 | group['params'] = []
297 |
298 | start.data.copy_(p)
299 |
300 | super(self.__class__, self).step()
301 |
302 | # compute delta = curr - start
303 | p.data.sub_(start)
304 |
305 | # allreduce as before
306 | tensor_compressed, ctx = self._compression.compress(p, name)
307 | handle = self._communicate_(tensor_compressed.data, name=name, op=Adasum)
308 |
309 | # reset stashed parameters
310 | for stashed, group in zip(stashed_params, self.param_groups):
311 | group['params'] = stashed
312 |
313 | return handle, ctx
314 |
315 | def _make_hook(self, p):
316 | def hook(*ignore):
317 | if p in self._handles and self._handles[p][0] is not None:
318 | if self._allreduce_delay[p] <= 0:
319 | raise AssertionError(
320 | "Gradients were computed more than "
321 | "backward_passes_per_step times before call "
322 | "to step(). Increase backward_passes_per_step to "
323 | "accumulate gradients locally.")
324 | assert not p.grad.requires_grad
325 | assert self._allreduce_delay[p] > 0
326 | handle, ctx = None, None
327 | self._allreduce_delay[p] -= 1
328 | if self._allreduce_delay[p] == 0:
329 | handle, ctx = self._allreduce_grad_async(p)
330 | self._handles[p] = (handle, ctx)
331 | return hook
332 |
333 | def synchronize(self):
334 | pass
335 |
336 | @contextmanager
337 | def skip_synchronize(self):
338 | raise AssertionError("Skipping synchronization is not supported when using Adasum optimizer.")
339 |
340 | def step(self, closure=None):
341 | loss = None
342 | if closure is not None:
343 | loss = closure()
344 |
345 | missing_p = self._requires_update - set(self._handles.keys())
346 | for p in missing_p:
347 | handle, ctx = self._allreduce_grad_async(p)
348 | self._handles[p] = (handle, ctx)
349 |
350 | for p, (handle, ctx) in self._handles.items():
351 |             # This means step() was called before backward_passes_per_step backward passes finished.
352 |             # We do a synchronous allreduce here.
353 | if not handle:
354 | handle, ctx = self._allreduce_grad_async(p)
355 | self._handles[p] = (handle, ctx)
356 | delta = self._synchronize_(handle)
357 | delta = self._compression.decompress(delta, ctx)
358 | start = self._starting_models[p]
359 | start.data.add_(delta.data)
360 | p.data.copy_(start)
361 | self._allreduce_delay[p] = self.backward_passes_per_step
362 | self._handles.clear()
363 | return loss
364 |
365 | def zero_grad(self):
366 | if self._handles:
367 | raise AssertionError("optimizer.zero_grad() was called after loss.backward() "
368 | "but before optimizer.step() or optimizer.synchronize(). "
369 | "This is prohibited as it can cause a race condition.")
370 | return super(self.__class__, self).zero_grad()
371 |
372 |
373 | def DistributedOptimizer(optimizer, named_parameters=None,
374 | compression=Compression.none,
375 | backward_passes_per_step=1,
376 | op=Average):
377 | """
378 | An optimizer that wraps another torch.optim.Optimizer, using an allreduce to
379 | combine gradient values before applying gradients to model weights.
380 |
381 |     Allreduce operations are executed in parallel with each other, after each gradient is
382 |     computed by ``loss.backward()``. The ``step()`` method ensures that all allreduce
383 |     operations are finished before applying gradients to the model.
384 |
385 | DistributedOptimizer exposes the ``synchronize()`` method, which forces allreduce operations
386 | to finish before continuing the execution. It's useful in conjunction with gradient
387 | clipping, or other operations that modify gradients in place before ``step()`` is executed.
388 | Make sure to use ``optimizer.skip_synchronize()`` if you're calling ``synchronize()``
389 | in your code.
390 |
391 | Example of gradient clipping:
392 |
393 | .. code-block:: python
394 |
395 | output = model(data)
396 | loss = F.nll_loss(output, target)
397 | loss.backward()
398 | optimizer.synchronize()
399 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
400 | with optimizer.skip_synchronize():
401 | optimizer.step()
402 |
403 | Arguments:
404 | optimizer: Optimizer to use for computing gradients and applying updates.
405 | named_parameters: A mapping between parameter names and values. Used for naming of
406 | allreduce operations. Typically just ``model.named_parameters()``.
407 | compression: Compression algorithm used during allreduce to reduce the amount
408 |                      of data sent during each parameter update step. Defaults to
409 | not using compression.
410 | backward_passes_per_step: Number of expected backward passes to perform
411 | before calling step()/synchronize(). This
412 | allows accumulating gradients over multiple
413 | mini-batches before reducing and applying them.
414 | op: The reduction operation to use when combining gradients across different ranks.
415 | """
416 | # We dynamically create a new class that inherits from the optimizer that was passed in.
417 | # The goal is to override the `step()` method with an allreduce implementation.
418 |
419 | if op != Adasum or size() == 1:
420 | cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
421 | dict(_DistributedOptimizer.__dict__))
422 | return cls(optimizer.param_groups, named_parameters, compression, backward_passes_per_step, op)
423 | else:
424 | cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
425 | dict(_DistributedAdasumOptimizer.__dict__))
426 | return cls(optimizer.param_groups, named_parameters, compression, backward_passes_per_step)
427 |
--------------------------------------------------------------------------------
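A minimal usage sketch for the DistributedOptimizer factory above (illustrative: the toy model, data shapes, and hyper-parameters are placeholders; train.py at the end of this listing shows the full pipeline with a configured DGCCompressor):

    import horovod.torch as hvd
    import torch
    import torch.nn.functional as F

    from dgc.horovod.compression import Compression   # same module optimizer.py imports from
    from dgc.horovod.optimizer import DistributedOptimizer

    hvd.init()
    torch.cuda.set_device(hvd.local_rank())

    model = torch.nn.Linear(32, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1 * hvd.size())

    # Compression.none keeps the plain allreduce path; train.py passes a DGCCompressor here.
    optimizer = DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters(),
        compression=Compression.none,
        backward_passes_per_step=1, op=hvd.Average)

    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    for _ in range(10):
        inputs = torch.randn(64, 32).cuda()
        targets = torch.randint(0, 10, (64,)).cuda()
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()

Because the factory returns a dynamically created subclass of the wrapped optimizer's class (the type(...) call above), the wrapped object keeps the original optimizer's interface, including param_groups and state_dict().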
/dgc/memory.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import horovod.torch as hvd
4 |
5 | __all__ = ['Memory', 'DGCSGDMemory']
6 |
7 |
8 | # code modified from https://github.com/sands-lab/grace/blob/master/grace_dl/torch/memory/dgc.py
9 | class Memory:
10 | @staticmethod
11 | def initialize(*args, **kwargs):
12 | pass
13 |
14 | @staticmethod
15 | def compensate(tensor, *args, **kwargs):
16 | return tensor
17 |
18 | @staticmethod
19 | def update(*args, **kwargs):
20 | pass
21 |
22 | @staticmethod
23 | def state_dict():
24 | return None
25 |
26 | @staticmethod
27 | def load_state_dict(state_dict):
28 | pass
29 |
30 |
31 | class DGCSGDMemory(Memory):
32 |     """Memory for momentum correction in DGC for the momentum SGD optimizer."""
33 | def __init__(self, momentum=0.9, nesterov=False,
34 | gradient_clipping=None, momentum_masking=True):
35 | self.gradient_clipping = gradient_clipping
36 | self.momentum_masking = momentum_masking
37 |
38 | self.momentum = momentum
39 | self.nesterov = nesterov
40 | self.momentums = {}
41 | self.velocities = {}
42 |
43 | def initialize(self, named_parameters):
44 | if hvd.rank() == 0:
45 | print("=> initializing dgc sgd memory")
46 | for name, param in named_parameters:
47 | self.momentums[name] = torch.zeros_like(param.data)
48 | self.velocities[name] = torch.zeros_like(param.data)
49 |
50 | def compensate(self, grad, name, accumulate=True):
51 | """Update the velocities with the momentums."""
52 | if self.gradient_clipping is not None:
53 | grad = self.gradient_clipping(grad)
54 | mmt = self.momentums[name]
55 | if accumulate:
56 | vec = self.velocities[name]
57 | if self.nesterov:
58 | mmt.add_(grad).mul_(self.momentum)
59 | vec.add_(mmt).add_(grad)
60 | else:
61 | mmt.mul_(self.momentum).add_(grad)
62 | vec.add_(mmt)
63 | return vec
64 | else:
65 | if self.nesterov:
66 | mmt.add_(grad).mul_(self.momentum)
67 | return mmt.add(grad)
68 | else:
69 | mmt.mul_(self.momentum).add_(grad)
70 | return mmt.clone() # TODO: save this clone
71 |
72 | def update(self, name, ctx):
73 |         """Zero the momentums and velocities at the indices that were just sent (momentum masking)."""
74 | indices = ctx[0]
75 | if self.momentum_masking:
76 | self.momentums[name].view(-1).index_fill_(0, indices, 0)
77 | self.velocities[name].view(-1).index_fill_(0, indices, 0)
78 |
79 | def state_dict(self):
80 | return dict(momentums=self.momentums, velocities=self.velocities)
81 |
82 | def load_state_dict(self, state_dict):
83 | momentums = state_dict['momentums']
84 | velocities = state_dict['velocities']
85 | for name in self.momentums.keys():
86 | if name in momentums:
87 | self.momentums[name] = momentums[name]
88 | self.velocities[name] = velocities[name]
89 |
--------------------------------------------------------------------------------
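A rough sketch of the compensate -> sparsify -> update cycle that DGCSGDMemory supports. The top-k selection and the (indices, values) layout of ctx are illustrative assumptions here; in this repository the actual selection logic lives in dgc/compression.py, and update() only relies on ctx[0] being the flattened indices that were transmitted:

    import horovod.torch as hvd
    import torch

    from dgc.memory import DGCSGDMemory

    hvd.init()  # initialize() checks hvd.rank(), so Horovod must be initialized

    param = torch.nn.Parameter(torch.randn(4, 4))
    memory = DGCSGDMemory(momentum=0.9, nesterov=False)
    memory.initialize([('weight', param)])

    grad = torch.randn_like(param)
    vec = memory.compensate(grad, 'weight')        # velocity accumulated with momentum

    # illustrative top-k sparsification of the velocity
    k = 4
    indices = vec.abs().view(-1).topk(k).indices
    values = vec.view(-1)[indices]

    # zero the momentum/velocity entries at the positions that were just sent
    memory.update('weight', (indices, values))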
/dgc/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from dgc.optim.sgd import DGCSGD
2 |
--------------------------------------------------------------------------------
/dgc/optim/sgd.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
2 |
3 | import torch
4 | from torch.optim.optimizer import Optimizer, required
5 |
6 | __all__ = ['DGCSGD']
7 |
8 |
9 | class DGCSGD(Optimizer):
10 | def __init__(self, params, lr=required, momentum=0, dampening=0,
11 | weight_decay=0, nesterov=False):
12 | if lr is not required and lr < 0.0:
13 | raise ValueError("Invalid learning rate: {}".format(lr))
14 | if momentum < 0.0:
15 | raise ValueError("Invalid momentum value: {}".format(momentum))
16 | if weight_decay < 0.0:
17 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
18 |
19 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
20 | weight_decay=weight_decay, nesterov=nesterov)
21 | if nesterov and (momentum <= 0 or dampening != 0):
22 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
23 | super(DGCSGD, self).__init__(params, defaults)
24 |
25 | def __setstate__(self, state):
26 | super(DGCSGD, self).__setstate__(state)
27 | for group in self.param_groups:
28 | group.setdefault('nesterov', False)
29 |
30 | @torch.no_grad()
31 | def step(self, closure=None):
32 | """Performs a single optimization step.
33 | Arguments:
34 | closure (callable, optional): A closure that reevaluates the model
35 | and returns the loss.
36 | """
37 | loss = None
38 | if closure is not None:
39 | with torch.enable_grad():
40 | loss = closure()
41 |
42 | for group in self.param_groups:
43 | weight_decay = group['weight_decay']
44 | momentum = group['momentum']
45 | dampening = group['dampening']
46 | nesterov = group['nesterov']
47 |
48 | for p in group['params']:
49 | if p.grad is None:
50 | continue
51 | if weight_decay != 0:
52 | d_p = weight_decay * p.data
53 | if momentum != 0:
54 | param_state = self.state[p]
55 | if 'momentum_buffer' not in param_state:
56 | buf = param_state['momentum_buffer'] = d_p
57 | else:
58 | buf = param_state['momentum_buffer']
59 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
60 | if nesterov:
61 | d_p = d_p.add(buf, alpha=momentum)
62 | else:
63 | d_p = buf
64 | d_p = d_p.add(p.grad)
65 | else:
66 | d_p = p.grad
67 |
68 | p.add_(d_p, alpha=-group['lr'])
69 |
70 | return loss
71 |
72 |
--------------------------------------------------------------------------------
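A note on the local update above, derived from the code: the momentum buffer in DGCSGD.step() accumulates only the weight-decay term, and the raw gradient is added on top afterwards; with weight_decay == 0 the step reduces to plain p -= lr * grad. Gradient momentum is presumably applied instead by DGCSGDMemory's momentum correction before compression. After the first step, the non-Nesterov update is roughly:

    buf = momentum * buf + (1 - dampening) * weight_decay * p
    p   = p - lr * (buf + grad)      # Nesterov: p - lr * (weight_decay * p + momentum * buf + grad)
    # and with weight_decay == 0:    p = p - lr * grad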
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | horovod
3 | tensorboardX
4 | numpy
5 | tqdm
--------------------------------------------------------------------------------
/script/cifar.resnet110.sh:
--------------------------------------------------------------------------------
1 | # fp32 values, int64 indices, warmup coeff: [0.25, 0.063, 0.015, 0.004, 0.001] -> 0.001, no momentum masking
2 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet110.py configs/dgc/wm5.py
3 |
4 | # fp32 values, int64 indices, warmup coeff: [1, 1, 1, 1, 1] -> 0.001, no momentum masking
5 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet110.py configs/dgc/wm5o.py
6 |
7 | # fp16 values, int32 indices, warmup coeff: [0.25, 0.063, 0.015, 0.004, 0.001] -> 0.001, no momentum masking
8 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet110.py configs/dgc/wm5.py configs/dgc/fp16.py configs/dgc/int32.py
9 |
10 | # fp16 values, int32 indices, warmup coeff: [1, 1, 1, 1, 1] -> 0.001, no momentum masking
11 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet110.py configs/dgc/wm5o.py configs/dgc/fp16.py configs/dgc/int32.py
12 |
--------------------------------------------------------------------------------
/script/cifar.resnet20.sh:
--------------------------------------------------------------------------------
1 | # fp32 values, int64 indices, warmup coeff: [0.25, 0.063, 0.015, 0.004, 0.001] -> 0.001
2 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet20.py configs/dgc/wm5.py
3 |
4 | # fp32 values, int64 indices, warmup coeff: [1, 1, 1, 1, 1] -> 0.001
5 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet20.py configs/dgc/wm5m.py
6 |
7 | # fp16 values, int32 indices, warmup coeff: [0.25, 0.063, 0.015, 0.004, 0.001] -> 0.001
8 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet20.py configs/dgc/wm5.py configs/dgc/fp16.py configs/dgc/int32.py
9 |
10 | # fp16 values, int32 indices, warmup coeff: [1, 1, 1, 1, 1] -> 0.001
11 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/cifar/resnet20.py configs/dgc/wm5m.py configs/dgc/fp16.py configs/dgc/int32.py
12 |
--------------------------------------------------------------------------------
/script/imagenet.resnet50.sh:
--------------------------------------------------------------------------------
1 | # fp32 values, int64 indices, no warmup
2 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/resnet50.py configs/dgc/wm0.py
3 |
4 | # fp32 values, int64 indices, cosine, no warmup
5 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/resnet50.py configs/imagenet/cosine.py configs/dgc/wm0.py
6 |
7 | # fp16 values, int32 indices, no warmup
8 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/resnet50.py configs/dgc/wm0.py configs/dgc/fp16.py configs/dgc/int32.py
9 |
10 | # fp16 values, int32 indices, cosine, no warmup
11 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/resnet50.py configs/imagenet/cosine.py configs/dgc/wm0.py configs/dgc/fp16.py configs/dgc/int32.py
--------------------------------------------------------------------------------
/script/imagenet.vgg16.sh:
--------------------------------------------------------------------------------
1 | # fp32 values, int64 indices, no warmup
2 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/vgg16_bn.py configs/dgc/wm0.py
3 |
4 | # fp16 values, int32 indices, no warmup
5 | mpirun -np ${1} -H ${2} -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_exclude docker0,lo python train.py --configs configs/imagenet/vgg16_bn.py configs/dgc/wm0.py configs/dgc/fp16.py configs/dgc/int32.py
6 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | import random
5 | import shutil
6 |
7 | import numpy as np
8 | import horovod.torch as hvd
9 | import torch
10 | import torch.nn as nn
11 | import torch.backends.cudnn as cudnn
12 | import torch.multiprocessing as mp
13 | from tqdm import tqdm
14 |
15 | from torchpack.mtpack.utils.config import Config, configs
16 |
17 | from dgc.horovod.optimizer import DistributedOptimizer
18 | from dgc.compression import DGCCompressor
19 |
20 |
21 | def main():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--configs', nargs='+')
24 | parser.add_argument('--devices', default='gpu')
25 | parser.add_argument('--evaluate', action='store_true')
26 | parser.add_argument('--suffix', default='')
27 | args, opts = parser.parse_known_args()
28 |
29 | ##################
30 | # Update configs #
31 | ##################
32 |
33 | printr(f'==> loading configs from {args.configs}')
34 | Config.update_from_modules(*args.configs)
35 | Config.update_from_arguments(*opts)
36 |
37 | if args.devices is not None and args.devices != 'cpu':
38 | configs.device = 'cuda'
39 | # Horovod: pin GPU to local rank.
40 | torch.cuda.set_device(hvd.local_rank())
41 | cudnn.benchmark = True
42 | else:
43 | configs.device = 'cpu'
44 |
45 | if 'seed' in configs and configs.seed is not None:
46 | random.seed(configs.seed)
47 | np.random.seed(configs.seed)
48 | torch.manual_seed(configs.seed)
49 | if configs.device == 'cuda' and configs.get('deterministic', True):
50 | cudnn.deterministic = True
51 | cudnn.benchmark = False
52 |
53 | configs.train.num_batches_per_step = \
54 | configs.train.get('num_batches_per_step', 1)
55 |
56 | configs.train.save_path = get_save_path(*args.configs) \
57 | + f'{args.suffix}.np{hvd.size()}'
58 | printr(f'[train.save_path] = {configs.train.save_path}')
59 | checkpoint_path = os.path.join(configs.train.save_path, 'checkpoints')
60 | configs.train.checkpoint_path = os.path.join(
61 | checkpoint_path, f'e{"{epoch}"}-r{hvd.rank()}.pth'
62 | )
63 | configs.train.latest_pth_path = os.path.join(
64 | checkpoint_path, f'latest-r{hvd.rank()}.pth'
65 | )
66 | configs.train.best_pth_path = os.path.join(
67 | checkpoint_path, f'best-r{hvd.rank()}.pth'
68 | )
69 | os.makedirs(checkpoint_path, exist_ok=True)
70 |
71 | if args.evaluate:
72 | configs.train.latest_pth_path = configs.train.best_pth_path
73 |
74 | printr(configs)
75 |
76 | #####################################################################
77 | # Initialize DataLoaders, Model, Criterion, LRScheduler & Optimizer #
78 | #####################################################################
79 |
80 | printr(f'\n==> creating dataset "{configs.dataset}"')
81 | dataset = configs.dataset()
82 | # Horovod: limit # of CPU threads to be used per worker.
83 | torch.set_num_threads(configs.data.num_threads_per_worker)
84 | loader_kwargs = {'num_workers': configs.data.num_threads_per_worker,
85 | 'pin_memory': True} if configs.device == 'cuda' else {}
86 | # When supported, use 'forkserver' to spawn dataloader workers
87 | # instead of 'fork' to prevent issues with Infiniband implementations
88 | # that are not fork-safe
89 | if (loader_kwargs.get('num_workers', 0) > 0 and
90 | hasattr(mp, '_supports_context') and
91 | mp._supports_context and
92 | 'forkserver' in mp.get_all_start_methods()):
93 | loader_kwargs['multiprocessing_context'] = 'forkserver'
94 |     printr(f'\n==> loading dataset "{loader_kwargs}"')
95 | samplers, loaders = {}, {}
96 | for split in dataset:
97 | # Horovod: use DistributedSampler to partition data among workers.
98 | # Manually specify `num_replicas=hvd.size()` and `rank=hvd.rank()`.
99 | samplers[split] = torch.utils.data.distributed.DistributedSampler(
100 | dataset[split], num_replicas=hvd.size(), rank=hvd.rank())
101 | loaders[split] = torch.utils.data.DataLoader(
102 | dataset[split], batch_size=configs.train.batch_size * (
103 | configs.train.num_batches_per_step if split == 'train' else 1),
104 | sampler=samplers[split],
105 | drop_last=(configs.train.num_batches_per_step > 1
106 | and split == 'train'),
107 | **loader_kwargs
108 | )
109 |
110 | printr(f'\n==> creating model "{configs.model}"')
111 | model = configs.model()
112 |     model = model.to(configs.device)
113 |
114 | criterion = configs.train.criterion().to(configs.device)
115 | # Horovod: scale learning rate by the number of GPUs.
116 | configs.train.base_lr = configs.train.optimizer.lr
117 | configs.train.optimizer.lr *= (configs.train.num_batches_per_step
118 | * hvd.size())
119 | printr(f'\n==> creating optimizer "{configs.train.optimizer}"')
120 |
121 | if configs.train.optimize_bn_separately:
122 | optimizer = configs.train.optimizer([
123 | dict(params=get_common_parameters(model)),
124 | dict(params=get_bn_parameters(model), weight_decay=0)
125 | ])
126 | else:
127 | optimizer = configs.train.optimizer(model.parameters())
128 |
129 | # Horovod: (optional) compression algorithm.
130 | printr(f'\n==> creating compression "{configs.train.compression}"')
131 | if configs.train.dgc:
132 | printr(f'\n==> initializing dgc compression')
133 | configs.train.compression.memory = configs.train.compression.memory()
134 | compression = configs.train.compression()
135 | compression.memory.initialize(model.named_parameters())
136 | cpr_parameters = {}
137 | for name, param in model.named_parameters():
138 | if param.dim() > 1:
139 | cpr_parameters[name] = param
140 | compression.initialize(cpr_parameters.items())
141 | else:
142 | compression = configs.train.compression()
143 |
144 | # Horovod: wrap optimizer with DistributedOptimizer.
145 | optimizer = DistributedOptimizer(
146 | optimizer, named_parameters=model.named_parameters(),
147 | compression=compression,
148 | backward_passes_per_step=configs.train.num_batches_per_step,
149 | op=hvd.Average
150 | )
151 |
152 | # resume from checkpoint
153 | last_epoch, best_metric = -1, None
154 | if os.path.exists(configs.train.latest_pth_path):
155 | printr(f'\n[resume_path] = {configs.train.latest_pth_path}')
156 | checkpoint = torch.load(configs.train.latest_pth_path)
157 | if 'model' in checkpoint:
158 | model.load_state_dict(checkpoint.pop('model'))
159 | if 'optimizer' in checkpoint:
160 | optimizer.load_state_dict(checkpoint.pop('optimizer'))
161 | if configs.train.dgc and 'compression' in checkpoint:
162 | compression.memory.load_state_dict(checkpoint.pop('compression'))
163 | last_epoch = checkpoint.get('epoch', last_epoch)
164 | best_metric = checkpoint.get('meters', {}).get(
165 | f'{configs.train.metric}_best', best_metric)
166 | # Horovod: broadcast parameters.
167 | hvd.broadcast_parameters(model.state_dict(), root_rank=0)
168 | else:
169 | printr('\n==> train from scratch')
170 | # Horovod: broadcast parameters & optimizer state.
171 |         printr('\n==> broadcasting parameters and optimizer state')
172 | hvd.broadcast_parameters(model.state_dict(), root_rank=0)
173 | hvd.broadcast_optimizer_state(optimizer, root_rank=0)
174 |
175 | num_steps_per_epoch = len(loaders['train'])
176 | if 'scheduler' in configs.train and configs.train.scheduler is not None:
177 | if configs.train.schedule_lr_per_epoch:
178 | last = max(last_epoch - configs.train.warmup_lr_epochs - 1, -1)
179 | else:
180 | last = max((last_epoch - configs.train.warmup_lr_epochs + 1)
181 | * num_steps_per_epoch - 2, -1)
182 | scheduler = configs.train.scheduler(optimizer, last_epoch=last)
183 | else:
184 | scheduler = None
185 |
186 | ############
187 | # Training #
188 | ############
189 |
190 | meters = evaluate(model, device=configs.device, meters=configs.train.meters,
191 | loader=loaders['test'], split='test')
192 | for k, meter in meters.items():
193 |         printr(f'[{k}] = {meter:.2f}')
194 | if args.evaluate or last_epoch >= configs.train.num_epochs:
195 | return
196 |
197 | if hvd.rank() == 0:
198 | from tensorboardX import SummaryWriter
199 | writer = SummaryWriter(configs.train.save_path)
200 | else:
201 | writer = None
202 |
203 | for current_epoch in range(last_epoch + 1, configs.train.num_epochs):
204 | printr(f'\n==> training epoch {current_epoch}'
205 | f'/{configs.train.num_epochs}')
206 |
207 | if configs.train.dgc:
208 | compression.warmup_compress_ratio(current_epoch)
209 |
210 | train(model=model, loader=loaders['train'],
211 | device=configs.device, epoch=current_epoch,
212 | sampler=samplers['train'], criterion=criterion,
213 | optimizer=optimizer, scheduler=scheduler,
214 | batch_size=configs.train.batch_size,
215 | num_batches_per_step=configs.train.num_batches_per_step,
216 | num_steps_per_epoch=num_steps_per_epoch,
217 | warmup_lr_epochs=configs.train.warmup_lr_epochs,
218 | schedule_lr_per_epoch=configs.train.schedule_lr_per_epoch,
219 | writer=writer, quiet=hvd.rank() != 0)
220 |
221 | meters = dict()
222 | for split, loader in loaders.items():
223 | if split != 'train':
224 | meters.update(evaluate(model, loader=loader,
225 | device=configs.device,
226 | meters=configs.train.meters,
227 | split=split, quiet=hvd.rank() != 0))
228 |
229 | best = False
230 | if 'metric' in configs.train and configs.train.metric is not None:
231 | if best_metric is None or best_metric < meters[configs.train.metric]:
232 | best_metric, best = meters[configs.train.metric], True
233 | meters[configs.train.metric + '_best'] = best_metric
234 |
235 | if writer is not None:
236 | num_inputs = ((current_epoch + 1) * num_steps_per_epoch
237 | * configs.train.num_batches_per_step
238 | * configs.train.batch_size * hvd.size())
239 | print('')
240 | for k, meter in meters.items():
241 |                 print(f'[{k}] = {meter:.2f}')
242 | writer.add_scalar(k, meter, num_inputs)
243 |
244 | checkpoint = {
245 | 'epoch': current_epoch,
246 | 'model': model.state_dict(),
247 | 'optimizer': optimizer.state_dict(),
248 | 'meters': meters,
249 | 'compression': compression.memory.state_dict() \
250 | if configs.train.dgc else None
251 | }
252 |
253 | # save checkpoint
254 | checkpoint_path = \
255 | configs.train.checkpoint_path.format(epoch=current_epoch)
256 | torch.save(checkpoint, checkpoint_path)
257 | shutil.copyfile(checkpoint_path, configs.train.latest_pth_path)
258 | if best:
259 | shutil.copyfile(checkpoint_path, configs.train.best_pth_path)
260 | if current_epoch >= 3:
261 | os.remove(
262 | configs.train.checkpoint_path.format(epoch=current_epoch - 3)
263 | )
264 | printr(f'[save_path] = {checkpoint_path}')
265 |
266 |
267 | def train(model, loader, device, epoch, sampler, criterion, optimizer,
268 | scheduler, batch_size, num_batches_per_step, num_steps_per_epoch, warmup_lr_epochs, schedule_lr_per_epoch, writer=None, quiet=True):
269 | step_size = num_batches_per_step * batch_size
270 | num_inputs = epoch * num_steps_per_epoch * step_size * hvd.size()
271 | _r_num_batches_per_step = 1.0 / num_batches_per_step
272 |
273 | sampler.set_epoch(epoch)
274 | model.train()
275 | for step, (inputs, targets) in enumerate(tqdm(
276 | loader, desc='train', ncols=0, disable=quiet)):
277 | adjust_learning_rate(scheduler, epoch=epoch, step=step,
278 | num_steps_per_epoch=num_steps_per_epoch,
279 | warmup_lr_epochs=warmup_lr_epochs,
280 | schedule_lr_per_epoch=schedule_lr_per_epoch)
281 |
282 | inputs = inputs.to(device, non_blocking=True)
283 | targets = targets.to(device, non_blocking=True)
284 | optimizer.zero_grad()
285 |
286 | loss = torch.tensor([0.0])
287 | for b in range(0, step_size, batch_size):
288 | _inputs = inputs[b:b+batch_size]
289 | _targets = targets[b:b+batch_size]
290 | _outputs = model(_inputs)
291 | _loss = criterion(_outputs, _targets)
292 | _loss.mul_(_r_num_batches_per_step)
293 | _loss.backward()
294 | loss += _loss.item()
295 | optimizer.step()
296 |
297 | # write train loss log
298 | loss = hvd.allreduce(loss, name='loss').item()
299 | if writer is not None:
300 | num_inputs += step_size * hvd.size()
301 | writer.add_scalar('loss/train', loss, num_inputs)
302 |
303 |
304 | def evaluate(model, loader, device, meters, split='test', quiet=True):
305 | _meters = {}
306 | for k, meter in meters.items():
307 | _meters[k.format(split)] = meter()
308 | meters = _meters
309 |
310 | model.eval()
311 |
312 | with torch.no_grad():
313 | for inputs, targets in tqdm(loader, desc=split, ncols=0, disable=quiet):
314 | inputs = inputs.to(device, non_blocking=True)
315 | targets = targets.to(device, non_blocking=True)
316 |
317 | outputs = model(inputs)
318 | for meter in meters.values():
319 | meter.update(outputs, targets)
320 |
321 | for k, meter in meters.items():
322 | data = meter.data()
323 | for dk, d in data.items():
324 | data[dk] = \
325 | hvd.allreduce(torch.tensor([d]), name=dk, op=hvd.Sum).item()
326 | meter.set(data)
327 | meters[k] = meter.compute()
328 | return meters
329 |
330 |
331 | # Horovod: using `lr = base_lr * hvd.size()` from the very beginning
332 | # leads to worse final accuracy.
333 | # Scale the learning rate `lr = base_lr` ---> `lr = base_lr * hvd.size()`
334 | # during the first `warmup_lr_epochs` epochs. See https://arxiv.org/abs/1706.02677.
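# Illustrative example of the warmup factor computed below: with hvd.size() == 8
# and warmup_lr_epochs == 5, the factor grows linearly from 1/8 at (epoch 0, step 0)
# to 1.0 by the start of epoch 5, i.e. the lr ramps from one eighth of the scaled
# value up to the full scaled value.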
335 | def adjust_learning_rate(scheduler, epoch, step, num_steps_per_epoch,
336 | warmup_lr_epochs=0, schedule_lr_per_epoch=False):
337 | if epoch < warmup_lr_epochs:
338 | size = hvd.size()
339 | epoch += step / num_steps_per_epoch
340 | factor = (epoch * (size - 1) / warmup_lr_epochs + 1) / size
341 | for param_group, base_lr in zip(scheduler.optimizer.param_groups,
342 | scheduler.base_lrs):
343 | param_group['lr'] = base_lr * factor
344 | elif schedule_lr_per_epoch and (step > 0 or epoch == 0):
345 | return
346 | elif epoch == warmup_lr_epochs and step == 0:
347 | for param_group, base_lr in zip(scheduler.optimizer.param_groups,
348 | scheduler.base_lrs):
349 | param_group['lr'] = base_lr
350 | return
351 | else:
352 | scheduler.step()
353 |
354 | def get_bn_parameters(module):
355 | def get_members_fn(m):
356 | if isinstance(m, nn.BatchNorm2d):
357 | return m._parameters.items()
358 | else:
359 | return dict()
360 | gen = module._named_members(get_members_fn=get_members_fn)
361 | for _, elem in gen:
362 | yield elem
363 |
364 |
365 | def get_common_parameters(module):
366 | def get_members_fn(m):
367 | if isinstance(m, nn.BatchNorm2d):
368 | return dict()
369 | else:
370 | for n, p in m._parameters.items():
371 | yield n, p
372 |
373 | gen = module._named_members(get_members_fn=get_members_fn)
374 | for _, elem in gen:
375 | yield elem
376 |
377 |
378 | def get_save_path(*configs, prefix='runs'):
379 | memo = dict()
380 | for c in configs:
381 | cmemo = memo
382 | c = c.replace('configs/', '').replace('.py', '').split('/')
383 | for m in c:
384 | if m not in cmemo:
385 | cmemo[m] = dict()
386 | cmemo = cmemo[m]
387 |
388 | def get_str(m, p):
389 | n = len(m)
390 | if n > 1:
391 | p += '['
392 | for i, (k, v) in enumerate(m.items()):
393 | p += k
394 | if len(v) > 0:
395 | p += '.'
396 | p = get_str(v, p)
397 | if n > 1 and i < n - 1:
398 | p += '+'
399 | if n > 1:
400 | p += ']'
401 | return p
402 |
403 | return os.path.join(prefix, get_str(memo, ''))
404 |
405 |
406 | def printr(*args, **kwargs):
407 | if hvd.rank() == 0:
408 | print(*args, **kwargs)
409 |
410 |
411 | if __name__ == '__main__':
412 | hvd.init()
413 | main()
414 |
--------------------------------------------------------------------------------