├── .gitignore ├── LICENSE ├── README.md ├── caffe └── README.md ├── chainer ├── README.md ├── mnist_dataset.py ├── mnist_helper.py ├── train_mnist.py ├── train_mnist_gpu.py └── train_mnist_multi.py ├── imgs └── pytorch │ ├── data_parallel_gpu.PNG │ ├── data_parallel_time.PNG │ ├── multi_node_distribute_gpu.PNG │ ├── multi_node_distribute_gpu_node2.PNG │ ├── multi_node_distribute_time.PNG │ ├── sg_gpu.PNG │ ├── sg_time.PNG │ ├── single_node_distribute.PNG │ ├── single_node_distribute_rank0_time.PNG │ └── single_node_distribute_rank1_time.PNG ├── mxnet └── README.md ├── pytorch ├── README.md ├── data_parallel.py ├── distributed_data_parallel.py ├── model.py └── single_gpu.py ├── tensorflow └── README.md └── tensorflow2 ├── README.md ├── mnist_mirror_strategy.py ├── mnist_multi_worker_strategy.py └── mnist_single.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Pychram 7 | .idea/ 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distributed-Training-DL 2 | 各种深度学习(DL)框架分布式训练,包括:Tensorflow、Tensorflow2、Pytorch、Chainer、Caffe、Mxnet ...,欢迎大家来共同维护! 
3 | 4 | ## Tensorflow 5 | Tensorflow 分布式训练示例参见[README](./tensorflow/README.md) 6 | 7 | ## Tensorflow2 8 | Tensorflow2 分布式训练示例参见[README](./tensorflow2/README.md) 9 | 10 | ## PyTorch 11 | Pytorch 分布式训练示例参见[README](./pytorch/README.md)。其中包括[`nn.DataParallel`](./pytorch/data_parallel.py)实现方式,[`nn.parallel.DistributedDataParallel`](./pytorch/distributed_data_parallel.py)实现方式以及基于 `SLURM` HPC调度实现。 12 | 13 | ## Chainer 14 | Chainer 分布式训练示例参见[README](./chainer/README.md) 15 | 16 | ## Mxnet 17 | Mxnet 分布式训练示例参见[README](./mxnet/README.md) 18 | 19 | ## Caffe 20 | Caffe 分布式训练示例参见[README](./caffe/README.md) -------------------------------------------------------------------------------- /caffe/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/caffe/README.md -------------------------------------------------------------------------------- /chainer/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/chainer/README.md -------------------------------------------------------------------------------- /chainer/mnist_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import chainer 4 | from chainer.dataset import download 5 | from mnist_helper import make_npz 6 | from mnist_helper import preprocess_mnist 7 | 8 | def get_mnist(withlabel=True, ndim=1, scale=1, dtype=None, 9 | label_dtype=numpy.int32, rgb_format=False): 10 | 11 | dtype = chainer.get_dtype(dtype) 12 | train_raw = _retrieve_mnist_training() 13 | train = preprocess_mnist(train_raw, withlabel, ndim, scale, dtype, 14 | label_dtype, rgb_format) 15 | 16 | test_raw = _retrieve_mnist_test() 17 | test = preprocess_mnist(test_raw, withlabel, ndim, scale, dtype, 
18 | label_dtype, rgb_format) 19 | return train, test 20 | 21 | def _retrieve_mnist_training(): 22 | train_path1 = os.path.dirname(os.path.realpath(__file__)) + '/../../datasets/mnist/train-images-idx3-ubyte.gz' 23 | train_path2 = os.path.dirname(os.path.realpath(__file__)) + '/../../datasets/mnist/train-labels-idx1-ubyte.gz' 24 | train_path = [train_path1, train_path2] 25 | return _retrieve_mnist('train.npz', train_path) 26 | 27 | def _retrieve_mnist_test(): 28 | test_path1 = os.path.dirname(os.path.realpath(__file__)) + '/../../datasets/mnist/t10k-images-idx3-ubyte.gz' 29 | test_path2 = os.path.dirname(os.path.realpath(__file__)) + '/../../datasets/mnist/t10k-labels-idx1-ubyte.gz' 30 | test_path = [test_path1, test_path2] 31 | return _retrieve_mnist('test.npz', test_path) 32 | 33 | def _retrieve_mnist(name, data_paths): 34 | 35 | root = download.get_dataset_directory('./temp_dir') 36 | path = os.path.join(root, name) 37 | return download.cache_or_load_file( 38 | path, lambda path: make_npz(path, data_paths), numpy.load) 39 | -------------------------------------------------------------------------------- /chainer/mnist_helper.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import struct 3 | 4 | import numpy 5 | import six 6 | 7 | from chainer.datasets import tuple_dataset 8 | 9 | def make_npz(path, paths): 10 | 11 | x_path, y_path = paths 12 | 13 | with gzip.open(x_path, 'rb') as fx, gzip.open(y_path, 'rb') as fy: 14 | fx.read(4) 15 | fy.read(4) 16 | N, = struct.unpack('>i', fx.read(4)) 17 | if N != struct.unpack('>i', fy.read(4))[0]: 18 | raise RuntimeError('wrong pair of MNIST images and labels') 19 | fx.read(8) 20 | 21 | x = numpy.empty((N, 784), dtype=numpy.uint8) 22 | y = numpy.empty(N, dtype=numpy.uint8) 23 | 24 | for i in six.moves.range(N): 25 | y[i] = ord(fy.read(1)) 26 | for j in six.moves.range(784): 27 | x[i, j] = ord(fx.read(1)) 28 | 29 | numpy.savez_compressed(path, x=x, y=y) 30 | return {'x': 
x, 'y': y} 31 | 32 | def preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, 33 | rgb_format): 34 | images = raw['x'] 35 | 36 | if ndim == 2: 37 | images = images.reshape(-1, 28, 28) 38 | elif ndim == 3: 39 | images = images.reshape(-1, 1, 28, 28) 40 | if rgb_format: 41 | images = numpy.broadcast_to(images, 42 | (len(images), 3) + images.shape[2:]) 43 | elif ndim != 1: 44 | raise ValueError('invalid ndim for MNIST dataset') 45 | 46 | images = images.astype(image_dtype)  # astype copies, so the in-place scaling below never writes into raw['x'] or a read-only broadcast view 47 | images *= scale / 255.  # map uint8 pixel values in [0, 255] to [0, scale] 48 | 49 | if withlabel: 50 | labels = raw['y'].astype(label_dtype) 51 | return tuple_dataset.TupleDataset(images, labels) 52 | else: 53 | return images 54 | -------------------------------------------------------------------------------- /chainer/train_mnist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | 4 | import chainer 5 | import chainer.functions as F 6 | import chainer.links as L 7 | import sklearn  # NOTE(review): sklearn is not referenced anywhere in this script; the import only adds a scikit-learn install requirement 8 | from chainer import training 9 | from chainer.training import extensions 10 | import mnist_dataset 11 | 12 | # Network definition 13 | class MLP(chainer.Chain): 14 | 15 | def __init__(self, n_units, n_out): 16 | super(MLP, self).__init__() 17 | with self.init_scope(): 18 | # the size of the inputs to each layer will be inferred 19 | self.l1 = L.Linear(None, n_units) # n_in -> n_units 20 | self.l2 = L.Linear(None, n_units) # n_units -> n_units 21 | self.l3 = L.Linear(None, n_out) # n_units -> n_out 22 | 23 | def forward(self, x): 24 | h1 = F.relu(self.l1(x)) 25 | h2 = F.relu(self.l2(h1)) 26 | return self.l3(h2) 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser(description='Chainer example: MNIST') 31 | parser.add_argument('--batchsize', '-b', type=int, default=100, 32 | help='Number of images in each mini-batch') 33 | parser.add_argument('--epoch', '-e', type=int, default=20, 34 | help='Number of sweeps over the dataset to train') 35 |
parser.add_argument('--frequency', '-f', type=int, default=-1, 36 | help='Frequency of taking a snapshot')  # in epochs; -1 (default) snapshots every args.epoch epochs, i.e. once at the end 37 | parser.add_argument('--gpu', '-g', type=int, default=-1, 38 | help='GPU ID (negative value indicates CPU)') 39 | parser.add_argument('--out', '-o', default='result', 40 | help='Directory to output the result') 41 | parser.add_argument('--resume', '-r', default='', 42 | help='Resume the training from snapshot') 43 | parser.add_argument('--unit', '-u', type=int, default=1000, 44 | help='Number of units') 45 | parser.add_argument('--noplot', dest='plot', action='store_false', 46 | help='Disable PlotReport extension')  # NOTE(review): args.plot is never read while the PlotReport block below stays commented out 47 | args = parser.parse_args() 48 | 49 | print('=============================================') 50 | if args.gpu < 0: 51 | print('# gpu = {}, Program selected cpu execution!'.format(args.gpu)) 52 | else: 53 | print('# gpu = {}, Program selected gpu execution!'.format(args.gpu)) 54 | print('# number of units: {}'.format(args.unit)) 55 | print('# minibatch-size: {}'.format(args.batchsize)) 56 | print('# epoch: {}'.format(args.epoch)) 57 | print('=============================================') 58 | 59 | # Set up a neural network to train 60 | # Classifier reports softmax cross entropy loss and accuracy at every 61 | # iteration, which will be used by the PrintReport extension below.
62 | model = L.Classifier(MLP(args.unit, 10)) 63 | if args.gpu >= 0: 64 | # Make a specified GPU current 65 | chainer.backends.cuda.get_device_from_id(args.gpu).use() 66 | model.to_gpu() # Copy the model to the GPU 67 | 68 | # Setup an optimizer 69 | optimizer = chainer.optimizers.Adam() 70 | optimizer.setup(model) 71 | 72 | # Load the MNIST dataset 73 | train, test = mnist_dataset.get_mnist()  # reads the gzip files under ../../datasets/mnist relative to mnist_dataset.py 74 | 75 | train_iter = chainer.iterators.SerialIterator(train, args.batchsize) 76 | test_iter = chainer.iterators.SerialIterator(test, args.batchsize, 77 | repeat=False, shuffle=False) 78 | 79 | # Set up a trainer 80 | updater = training.updaters.StandardUpdater( 81 | train_iter, optimizer, device=args.gpu) 82 | trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) 83 | 84 | # Evaluate the model with the test dataset for each epoch 85 | trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu)) 86 | 87 | # Dump a computational graph from 'loss' variable at the first iteration 88 | # The "main" refers to the target link of the "main" optimizer.
89 | trainer.extend(extensions.dump_graph('main/loss')) 90 | 91 | # Take a snapshot for each specified epoch 92 | frequency = args.epoch if args.frequency == -1 else max(1, args.frequency) 93 | trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch')) 94 | 95 | # Write a log of evaluation statistics for each epoch 96 | trainer.extend(extensions.LogReport()) 97 | 98 | # Save two plot images to the result dir (currently disabled: the block below is commented out, so --noplot has no effect) 99 | #if args.plot and extensions.PlotReport.available(): 100 | # trainer.extend( 101 | # extensions.PlotReport(['main/loss', 'validation/main/loss'], 102 | # 'epoch', file_name='loss.png')) 103 | # trainer.extend( 104 | # extensions.PlotReport( 105 | # ['main/accuracy', 'validation/main/accuracy'], 106 | # 'epoch', file_name='accuracy.png')) 107 | 108 | # Print selected entries of the log to stdout 109 | # Here "main" refers to the target link of the "main" optimizer again, and 110 | # "validation" refers to the default name of the Evaluator extension. 111 | # Entries other than 'epoch' are reported by the Classifier link, called by 112 | # either the updater or the evaluator.
113 | trainer.extend(extensions.PrintReport( 114 | ['epoch', 'main/loss', 'validation/main/loss', 115 | 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'])) 116 | 117 | # Print a progress bar to stdout 118 | # trainer.extend(extensions.ProgressBar()) 119 | 120 | if args.resume: 121 | # Resume from a snapshot 122 | chainer.serializers.load_npz(args.resume, trainer) 123 | 124 | trainer.run() 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /chainer/train_mnist_gpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | 4 | import chainer 5 | import chainer.links as L 6 | import chainer.functions as F 7 | from chainer import training 8 | from chainer.training import extensions 9 | import mnist_dataset 10 | 11 | # Network definition 12 | class MLP(chainer.Chain): 13 | 14 | def __init__(self, n_units, n_out): 15 | super(MLP, self).__init__() 16 | with self.init_scope(): 17 | # the size of the inputs to each layer will be inferred 18 | self.l1 = L.Linear(None, n_units) # n_in -> n_units 19 | self.l2 = L.Linear(None, n_units) # n_units -> n_units 20 | self.l3 = L.Linear(None, n_out) # n_units -> n_out 21 | 22 | def forward(self, x): 23 | h1 = F.relu(self.l1(x)) 24 | h2 = F.relu(self.l2(h1)) 25 | return self.l3(h2) 26 | 27 | 28 | def main(): 29 | # This script is almost identical to train_mnist.py. The only difference is 30 | # that this script uses data-parallel computation on one or two GPUs (selected with --gpu_number). 31 | # See train_mnist.py for more details.
32 | parser = argparse.ArgumentParser(description='Chainer example: MNIST') 33 | parser.add_argument('--batchsize', '-b', type=int, default=400, 34 | help='Number of images in each mini-batch') 35 | parser.add_argument('--epoch', '-e', type=int, default=20, 36 | help='Number of sweeps over the dataset to train') 37 | parser.add_argument('--gpu', '-g', action='store_true', 38 | help='Use gpu') 39 | parser.add_argument('--gpu_number', '-n', type=int, default=1, 40 | help='Number of gpus') 41 | parser.add_argument('--out', '-o', default='result', 42 | help='Directory to output the result') 43 | parser.add_argument('--resume', '-r', default='', 44 | help='Resume the training from snapshot') 45 | parser.add_argument('--unit', '-u', type=int, default=1000, 46 | help='Number of units') 47 | args = parser.parse_args() 48 | 49 | if args.gpu: 50 | print("Use the gpu environment to perform work, please \ 51 | set the number of gpus you need to use.") 52 | if args.gpu_number == 2: 53 | print('===================================') 54 | use_gpu = {'main': 0, 'second': 1} 55 | print('# use GPU:2') 56 | print('# unit: {}'.format(args.unit)) 57 | print('# minibatch-size: {}'.format(args.batchsize)) 58 | print('# epoch: {}'.format(args.epoch)) 59 | print('===================================') 60 | elif args.gpu_number == 1: 61 | print('===================================') 62 | use_gpu = {'main': 0} 63 | print('# use GPU: 1') 64 | print('# unit: {}'.format(args.unit)) 65 | print('# minibatch-size: {}'.format(args.batchsize)) 66 | print('# epoch: {}'.format(args.epoch)) 67 | print('===================================') 68 | else: 69 | raise ValueError('please set the correct number of gpus you need to use!') 70 | else: 71 | raise ValueError('gpu env set error!') 72 | 73 | chainer.backends.cuda.get_device_from_id(0).use() 74 | 75 | model = L.Classifier(MLP(args.unit, 10)) 76 | optimizer = chainer.optimizers.Adam() 77 | optimizer.setup(model) 78 | 79 | train, test = 
mnist_dataset.get_mnist() 80 | train_iter = chainer.iterators.SerialIterator(train, args.batchsize) 81 | test_iter = chainer.iterators.SerialIterator(test, args.batchsize, 82 | repeat=False, shuffle=False) 83 | 84 | # ParallelUpdater implements the data-parallel gradient computation on 85 | # multiple GPUs. It accepts "devices" argument that specifies which GPU to 86 | # use. 87 | updater = training.updaters.ParallelUpdater( 88 | train_iter, 89 | optimizer, 90 | # The device of the name 'main' is used as a "master", while others are 91 | # used as slaves. Names other than 'main' are arbitrary. 92 | devices=use_gpu, 93 | ) 94 | trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) 95 | 96 | trainer.extend(extensions.Evaluator(test_iter, model, device=0)) 97 | trainer.extend(extensions.dump_graph('main/loss')) 98 | trainer.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch')) 99 | trainer.extend(extensions.LogReport()) 100 | trainer.extend(extensions.PrintReport( 101 | ['epoch', 'main/loss', 'validation/main/loss', 102 | 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'])) 103 | # trainer.extend(extensions.ProgressBar()) 104 | 105 | if args.resume: 106 | chainer.serializers.load_npz(args.resume, trainer) 107 | 108 | trainer.run() 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /chainer/train_mnist_multi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import print_function 3 | 4 | import argparse 5 | 6 | import chainer 7 | import chainer.functions as F 8 | import chainer.links as L 9 | from chainer import training 10 | from chainer.training import extensions 11 | 12 | import chainermn 13 | import mnist_dataset 14 | 15 | class MLP(chainer.Chain): 16 | 17 | def __init__(self, n_units, n_out): 18 | super(MLP, self).__init__( 19 | # the size of the inputs to 
each layer will be inferred 20 | l1=L.Linear(784, n_units), # n_in -> n_units 21 | l2=L.Linear(n_units, n_units), # n_units -> n_units 22 | l3=L.Linear(n_units, n_out), # n_units -> n_out 23 | ) 24 | 25 | def __call__(self, x): 26 | h1 = F.relu(self.l1(x)) 27 | h2 = F.relu(self.l2(h1)) 28 | return self.l3(h2) 29 | 30 | 31 | def main(): 32 | parser = argparse.ArgumentParser(description='ChainerMN example: MNIST') 33 | parser.add_argument('--batchsize', '-b', type=int, default=100, 34 | help='Number of images in each mini-batch') 35 | parser.add_argument('--communicator', type=str, 36 | default='pure_nccl', help='Type of communicator') 37 | parser.add_argument('--epoch', '-e', type=int, default=20, 38 | help='Number of sweeps over the dataset to train') 39 | parser.add_argument('--gpu', '-g', action='store_true', 40 | help='Use GPU') 41 | parser.add_argument('--out', '-o', default='result', 42 | help='Directory to output the result') 43 | parser.add_argument('--resume', '-r', default='', 44 | help='Resume the training from snapshot') 45 | parser.add_argument('--unit', '-u', type=int, default=1000, 46 | help='Number of units') 47 | args = parser.parse_args() 48 | 49 | # Prepare ChainerMN communicator. 
50 | 51 | if args.gpu: 52 | if args.communicator == 'naive': 53 | print("Error: 'naive' communicator does not support GPU.\n") 54 | exit(-1) 55 | comm = chainermn.create_communicator(args.communicator) 56 | device = comm.intra_rank 57 | else: 58 | if args.communicator != 'naive': 59 | print('Warning: using naive communicator ' 60 | 'because only naive supports CPU-only execution') 61 | comm = chainermn.create_communicator('naive') 62 | device = -1 63 | 64 | if comm.rank == 0: 65 | print('==========================================') 66 | print('Num process (COMM_WORLD): {}'.format(comm.size)) 67 | if args.gpu: 68 | print('Using GPUs') 69 | print('Using {} communicator'.format(args.communicator)) 70 | print('Num unit: {}'.format(args.unit)) 71 | print('Num Minibatch-size: {}'.format(args.batchsize)) 72 | print('Num epoch: {}'.format(args.epoch)) 73 | print('==========================================') 74 | 75 | model = L.Classifier(MLP(args.unit, 10)) 76 | if device >= 0: 77 | chainer.cuda.get_device_from_id(device).use() 78 | model.to_gpu() 79 | 80 | # Create a multi node optimizer from a standard Chainer optimizer. 81 | optimizer = chainermn.create_multi_node_optimizer( 82 | chainer.optimizers.Adam(), comm) 83 | optimizer.setup(model) 84 | 85 | # Split and distribute the dataset. Only worker 0 loads the whole dataset. 86 | # Datasets of worker 0 are evenly split and distributed to all workers. 
87 | if comm.rank == 0: 88 | train, test = mnist_dataset.get_mnist() 89 | else: 90 | train, test = None, None 91 | train = chainermn.scatter_dataset(train, comm, shuffle=True) 92 | test = chainermn.scatter_dataset(test, comm, shuffle=True) 93 | 94 | train_iter = chainer.iterators.SerialIterator(train, args.batchsize) 95 | test_iter = chainer.iterators.SerialIterator(test, args.batchsize, 96 | repeat=False, shuffle=False) 97 | 98 | updater = training.StandardUpdater(train_iter, optimizer, device=device) 99 | trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) 100 | 101 | # Create a multi node evaluator from a standard Chainer evaluator. 102 | evaluator = extensions.Evaluator(test_iter, model, device=device) 103 | evaluator = chainermn.create_multi_node_evaluator(evaluator, comm) 104 | trainer.extend(evaluator) 105 | 106 | # Some display and output extensions are necessary only for one worker. 107 | # (Otherwise, there would just be repeated outputs.) 108 | if comm.rank == 0: 109 | trainer.extend(extensions.dump_graph('main/loss')) 110 | trainer.extend(extensions.LogReport()) 111 | trainer.extend(extensions.PrintReport( 112 | ['epoch', 'main/loss', 'validation/main/loss', 113 | 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'])) 114 | # trainer.extend(extensions.ProgressBar()) 115 | 116 | if args.resume: 117 | chainer.serializers.load_npz(args.resume, trainer) 118 | 119 | trainer.run() 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | -------------------------------------------------------------------------------- /imgs/pytorch/data_parallel_gpu.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/imgs/pytorch/data_parallel_gpu.PNG -------------------------------------------------------------------------------- /imgs/pytorch/data_parallel_time.PNG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/imgs/pytorch/data_parallel_time.PNG -------------------------------------------------------------------------------- /imgs/pytorch/multi_node_distribute_gpu.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/imgs/pytorch/multi_node_distribute_gpu.PNG -------------------------------------------------------------------------------- /imgs/pytorch/multi_node_distribute_gpu_node2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/imgs/pytorch/multi_node_distribute_gpu_node2.PNG -------------------------------------------------------------------------------- /imgs/pytorch/multi_node_distribute_time.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/imgs/pytorch/multi_node_distribute_time.PNG -------------------------------------------------------------------------------- /imgs/pytorch/sg_gpu.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/imgs/pytorch/sg_gpu.PNG -------------------------------------------------------------------------------- /imgs/pytorch/sg_time.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/imgs/pytorch/sg_time.PNG 
--------------------------------------------------------------------------------
/imgs/pytorch/single_node_distribute.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/imgs/pytorch/single_node_distribute.PNG
--------------------------------------------------------------------------------
/imgs/pytorch/single_node_distribute_rank0_time.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/imgs/pytorch/single_node_distribute_rank0_time.PNG
--------------------------------------------------------------------------------
/imgs/pytorch/single_node_distribute_rank1_time.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/imgs/pytorch/single_node_distribute_rank1_time.PNG
--------------------------------------------------------------------------------
/mxnet/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/mxnet/README.md
--------------------------------------------------------------------------------
/pytorch/README.md:
--------------------------------------------------------------------------------
1 | # SingleNode and Distributed Implementation with PyTorch
2 |
3 | ## 介绍
4 | 这个目录主要是`PyTorch`框架的单机以及分布式训练示例。通过在`cifar10`数据集上训练`PyramidNet`模型,对比分析`PyTorch`框架单机以及分布式训练的性能,
5 | 以帮助读者快速上手`PyTorch`框架。
6 |
7 | * [`model.py`](./model.py):这个文件主要用以实现模型封装。
8 | * [`single_gpu.py`](./single_gpu.py):这个文件是单机单卡训练脚本。
9 | * [`data_parallel.py`](./data_parallel.py):这个文件是使用`DataParallel`进行数据并行的单机多卡训练脚本。
10 | *
[`distributed_data_parallel.py`](./distributed_data_parallel.py):这个文件实现分布式训练,可通过给定不同参数实现单机多卡、多机单卡以及多机多卡分布式训练。
11 |
12 | ## Requirements
13 | * python3.+
14 | * torch==1.5.0
15 | * torchvision==0.6.0
16 |
17 | ## 训练
18 |
19 | ### 单机单卡
20 | 关于单机单卡,就是使用单个节点的单个gpu进行训练,这应该也是大家最常用的训练方式。
21 |
22 | * 执行命令:
23 | ```
24 | python single_gpu.py --gpu-nums 1 --epochs 2 --batch-size 64 --train-dir /home/crise/single_gpu --dataset-dir /home/crise/cifar10 --log-interval 20 --save-model
25 | ```
26 | 上面命令也可简便执行如下:
27 | ```
28 | python single_gpu.py -g 1 -e 2 -b 64 -td /home/crise/single_gpu -dd /home/crise/cifar10 -li 20 -sm
29 | ```
30 | * 参数介绍:
31 |     * --gpu-nums: 使用gpu的数量,单卡训练时只能设为0或1,大于1时会报`ValueError`,默认值为0。
32 |     * --epochs: 最大`epoch`数量,默认值为10。
33 |     * --batch-size: batch size 大小,默认值为64。
34 |     * --train-dir: 模型参数及结果存放路径,默认值为`./train_dir`。
35 |     * --dataset-dir: 数据集存放路径,默认值为`./data`。
36 |     * --log-interval: 日志打印频率,默认值为迭代20步打印一次。
37 |     * --save-model: 是否需要存储模型,带上这个参数则存,否则不存。
38 |
39 | * 训练时间:
40 | 本来是想贴图片的,但发现贴上来很难看。可以点击[训练时长](../imgs/pytorch/sg_time.PNG)以及[GPU利用率](../imgs/pytorch/sg_gpu.PNG)查看。
41 |     * batch time: 0.255s
42 |     * epoch time: 03:20min
43 |     * gpu util: 98%
44 |
45 | ### 单机多卡
46 | 单机多卡有两种实现方式,一种是使用`DataParallel`接口实现数据并行单机多卡训练,另一种是使用`DistributedDataParallel`接口实现。
47 |
48 | #### `DataParallel`实现
49 | 这个方式在单个进程中将模型复制到多块GPU上,并把每个batch切分到各GPU上并行计算。关于该接口实现的详细介绍请参考博客[分布式训练之PyTorch](https://crisescode.github.io/blog/2020/07/31/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83%E4%B9%8BPyTorch/)。
50 |
51 | * 执行命令:
52 | ```
53 | python data_parallel.py --gpu-nums 2 --epochs 2 --batch-size 64 --train-dir /home/crise/data_parallel --dataset-dir /home/crise/cifar10 --log-interval 20 --save-model
54 | ```
55 |
56 | > 注:简化执行命令可参照单机单卡训练。
57 |
58 | * 参数介绍:参照单机单卡。
59 |
60 | * 训练时间:
61 | [训练时长](../imgs/pytorch/data_parallel_time.PNG) 与 [GPU利用率](../imgs/pytorch/data_parallel_gpu.PNG),图中能看出在同一个进程中使用了两个gpu进行训练。
62 |     * batch time: 0.170s
63 |     * epoch time: 02:18min
64 |     * gpu util: 80%
65 |
66 | #### `DistributedDataParallel` 实现
67 |
这个方式会通过`torch.multiprocessing`为每块可见的GPU启动一个训练进程,各进程之间通过`DistributedDataParallel`同步梯度。下面的示例在同一节点的两个Shell中分别启动进程,并用`CUDA_VISIBLE_DEVICES`让每个进程只使用一块GPU。
68 |
69 | * 执行命令:
70 |
71 |     * Shell 1:
72 | ```
73 | CUDA_VISIBLE_DEVICES='0' python distributed_data_parallel.py --epochs 2 --batch-size 64 --train-dir /home/crise/single_node_distribute --dataset-dir /home/crise/cifar10 --log-interval 20 --save-model --world-size 2 --rank 0
74 | ```
75 |     * 同一个节点 Shell 2 执行:
76 | ```
77 | CUDA_VISIBLE_DEVICES='1' python distributed_data_parallel.py --epochs 2 --batch-size 64 --train-dir /home/crise/single_node_distribute --dataset-dir /home/crise/cifar10 --log-interval 20 --save-model --world-size 2 --rank 1
78 | ```
79 |
80 | > 注:简化执行命令可参照单机单卡训练。
81 |
82 | * 参数介绍:基本参数参考单机单卡,新增参数如下。
83 |     * --world-size: 启动的进程总数,默认值为1。
84 |     * --rank: 当前进程序号,默认值为0。
85 |
86 | * 训练时间:[训练时长](../imgs/pytorch/single_node_distribute_rank0_time.PNG) 与 [GPU利用率](../imgs/pytorch/single_node_distribute.PNG),图中能看出有两个进程分别在不同的gpu上进行训练。
87 |     * batch time: 0.274s
88 |     * epoch time: 01:51min
89 |     * gpu0 util: 98%
90 |     * gpu1 util: 99%
91 |
92 | ### 多机多卡分布式
93 | 多机多卡分布式训练同样通过`DistributedDataParallel`接口实现:在各节点上分别启动进程,并通过`--init-method`让所有进程连接到同一个地址(本例中为`tcp://c1:20201`)。
94 | * 执行命令:
95 |
96 |     * Node 1 & Shell 1 执行:
97 | ```
98 | CUDA_VISIBLE_DEVICES='0' python distributed_data_parallel.py --epochs 2 --batch-size 64 --train-dir /home/crise/multi_node_distribute --dataset-dir /home/crise --log-interval 20 --save-model --init-method tcp://c1:20201 --world-size 4 --rank 0
99 | ```
100 |     * Node 1 & Shell 2 执行:
101 | ```
102 | CUDA_VISIBLE_DEVICES='1' python distributed_data_parallel.py --epochs 2 --batch-size 64 --train-dir /home/crise/multi_node_distribute --dataset-dir /home/crise --log-interval 20 --save-model --init-method tcp://c1:20201 --world-size 4 --rank 1
103 | ```
104 |
105 |     * Node 2 & Shell 1 执行:
106 | ```
107 | CUDA_VISIBLE_DEVICES='0' python distributed_data_parallel.py --epochs 2 --batch-size 64 --train-dir /home/crise/multi_node_distribute --dataset-dir /home/crise --log-interval 20 --save-model --init-method tcp://c1:20201 --world-size 4 --rank 2
108 | ```
109 |
110 |     * Node
2 & Shell 2 执行: 111 | ``` 112 | CUDA_VISIBLE_DEVICES='1' python distributed_data_parallel.py --epochs 2 --batch-size 64 --train-dir /home/crise/multi_node_distribute --dataset-dir /home/crise --log-interval 20 --save-model --init-method tcp://c1:20201 --world-size 4 --rank 3 113 | ``` 114 | > 注:简化执行命令可参照单机单卡训练。 115 | 116 | * 参数介绍:基本参数参考单机单卡,新增参数如下。 117 | * --world-size: 启动的进程总数,默认值为1。 118 | * --rank: 当前进程序号,默认值为0。 119 | * --init-method:初始化方式 120 | 121 | * 训练时间:[训练时长](../imgs/pytorch/multi_node_distribute_time.PNG) 与 [训练时间](../imgs/pytorch/multi_node_distribute_gpu.PNG),这只是一个节点GPU截图,另一个节点大差不差。 122 | * batch time: 0.301s 123 | * epoch time: 01:07min 124 | * gpu0 util: 99% 125 | * gpu1 util: 99% 126 | 127 | ## 性能对比 128 | 上面几种训练都是在`Tesla P100-PCIE`,显存为`16Gb`,batch_size 为64,训练5个epoch,总时长对比如下: 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /pytorch/data_parallel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import time 4 | from os import environ, mkdir, path 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torchvision.transforms as transforms 10 | from model import pyramidnet 11 | from torch.utils.data import DataLoader 12 | from torchvision.datasets import CIFAR10 13 | 14 | parser = argparse.ArgumentParser(description='PyTorch Cifar10 Data Parallel Training') 15 | 16 | parser.add_argument('--train-dir', '-td', type=str, default="./train_dir", 17 | help='the path that the model saved (default: "./train_dir")') 18 | parser.add_argument('--dataset-dir', '-dd', type=str, default="./data", 19 | help='the path of dataset (default: "./data")') 20 | parser.add_argument('--batch-size', '-b', type=int, default=64, 21 | help='input batch size for training (default: 64)') 22 | parser.add_argument('--num-workers', type=int, default=4, help='') 23 | parser.add_argument('--test-batchsize', '-tb', type=int, 
default=1000, 24 | help='input batch size for testing (default: 1000)') 25 | parser.add_argument('--epochs', '-e', type=int, default=10, 26 | help='number of epochs to train (default: 10)') 27 | parser.add_argument('--gpu-nums', '-g', type=int, default=0, 28 | help='Number of GPU in each mini-batch') 29 | parser.add_argument('--learning-rate', '--lr', type=float, default=0.1, metavar='LR', 30 | help='learning rate (default: 0.1)') 31 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 32 | help='SGD momentum (default: 0.9)') 33 | parser.add_argument('--seed', type=int, default=1, metavar='S', 34 | help='random seed (default: 1)') 35 | parser.add_argument('--log-interval', type=int, default=20, metavar='N', 36 | help='how many batches to wait before logging training status') 37 | parser.add_argument('--save-model', '-sm', action='store_true', default=False, 38 | help='For Saving the current Model') 39 | parser.add_argument('--weight-decay', '--wd', type=float, default=1e-4, metavar='W', 40 | help='weight decay(default: 1e-4)') 41 | 42 | args = parser.parse_args() 43 | 44 | 45 | def main(): 46 | # set run env 47 | if args.gpu_nums > 1: 48 | device = 'cuda' if torch.cuda.is_available() else "cpu" 49 | gpu_ids = ','.join([str(id) for id in range(args.gpu_nums)]) 50 | environ["CUDA_VISIBLE_DEVICES"] = gpu_ids 51 | else: 52 | raise ValueError("gpu-nums must be greater than 1.") 53 | 54 | print('==> Preparing data..') 55 | transforms_train = transforms.Compose([ 56 | transforms.RandomCrop(32, padding=4), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]) 60 | 61 | dataset_train = CIFAR10(root='/home/zhaopp5', train=True, download=True, 62 | transform=transforms_train) 63 | 64 | train_loader = DataLoader(dataset_train, batch_size=args.batch_size, 65 | shuffle=True, num_workers=args.num_workers) 66 | 67 | print('==> Making model..') 68 | 69 | model = 
pyramidnet() 70 | if args.gpu_nums > 1: 71 | model = nn.DataParallel(model) 72 | model = model.to(device) 73 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 74 | print('The number of parameters of model is', num_params) 75 | 76 | criterion = nn.CrossEntropyLoss() 77 | optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, 78 | momentum=args.momentum, weight_decay=args.weight_decay) 79 | 80 | for epoch in range(1, args.epochs + 1): 81 | train(epoch, model, criterion, optimizer, train_loader, device) 82 | 83 | if args.save_model: 84 | if not path.exists(args.train_dir): 85 | mkdir(args.train_dir) 86 | 87 | torch.save( 88 | model.state_dict(), 89 | path.join(args.train_dir, "data_parallel_model.pth") 90 | ) 91 | print("data parallel model has been saved.") 92 | 93 | 94 | def train(epoch, model, criterion, optimizer, train_loader, device): 95 | model.train() 96 | 97 | train_loss, correct, total = 0, 0, 0 98 | epoch_start = time.time() 99 | for batch_idx, (inputs, targets) in enumerate(train_loader): 100 | start = time.time() 101 | 102 | inputs = inputs.to(device) 103 | targets = targets.to(device) 104 | outputs = model(inputs) 105 | loss = criterion(outputs, targets) 106 | 107 | optimizer.zero_grad() 108 | loss.backward() 109 | optimizer.step() 110 | 111 | train_loss += loss.item() 112 | _, predicted = outputs.max(1) 113 | total += targets.size(0) 114 | correct += predicted.eq(targets).sum().item() 115 | 116 | acc = 100 * correct / total 117 | 118 | batch_time = time.time() - start 119 | 120 | if batch_idx % args.log_interval == 0: 121 | print('Epoch[{}]: [{}/{}]| loss: {:.3f} | acc: {:.3f} | batch time: {:.3f}s '.format( 122 | epoch, batch_idx, len(train_loader), train_loss / (batch_idx + 1), acc, batch_time)) 123 | 124 | elapse_time = time.time() - epoch_start 125 | elapse_time = datetime.timedelta(seconds=elapse_time) 126 | print("Training time {}".format(elapse_time)) 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 
131 | -------------------------------------------------------------------------------- /pytorch/distributed_data_parallel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import time 4 | from os import mkdir, path 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import torch.utils.data.distributed 12 | import torchvision.transforms as transforms 13 | from model import pyramidnet 14 | from torch.optim import lr_scheduler 15 | from torch.utils.data import DataLoader 16 | from torchvision.datasets import CIFAR10 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch Cifar10 Distributed Training') 19 | 20 | parser.add_argument('--train-dir', '-td', type=str, default="./train_dir", 21 | help='the path that the model saved (default: "./train_dir")') 22 | parser.add_argument('--dataset-dir', '-dd', type=str, default="./data", 23 | help='the path of dataset (default: "./data")') 24 | parser.add_argument('--batch-size', '-b', type=int, default=64, 25 | help='input batch size for training (default: 64)') 26 | parser.add_argument('--num-workers', type=int, default=4, help='') 27 | parser.add_argument('--test-batch-size', '-tb', type=int, default=1000, 28 | help='input batch size for testing (default: 1000)') 29 | parser.add_argument('--epochs', '-e', type=int, default=10, 30 | help='number of epochs to train (default: 10)') 31 | parser.add_argument('--gpu-nums', '-g', type=int, default=0, 32 | help='Number of GPU in each mini-batch') 33 | parser.add_argument('--learning-rate', '--lr', type=float, default=0.1, metavar='LR', 34 | help='learning rate (default: 0.1)') 35 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 36 | help='SGD momentum (default: 0.9)') 37 | parser.add_argument('--seed', type=int, default=1, metavar='S', 38 | help='random seed 
(default: 1)') 39 | parser.add_argument('--log-interval', type=int, default=20, metavar='N', 40 | help='how many batches to wait before logging training status') 41 | parser.add_argument('--save-model', '-sm', action='store_true', default=False, 42 | help='For Saving the current Model') 43 | parser.add_argument('--weight-decay', '--wd', type=float, default=1e-4, metavar='W', 44 | help='weight decay(default: 1e-4)') 45 | parser.add_argument('--init-method', default='tcp://127.0.0.1:13456', type=str, help='') 46 | parser.add_argument('--dist-backend', default='nccl', type=str, help='') 47 | parser.add_argument('--rank', default=0, type=int, help='') 48 | parser.add_argument('--world-size', default=1, type=int, help='') 49 | 50 | args = parser.parse_args() 51 | 52 | 53 | def main(): 54 | ngpus_per_node = torch.cuda.device_count() 55 | print("ngpus_per_node: ", ngpus_per_node) 56 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 57 | 58 | 59 | def main_worker(gpu, ngpus_per_node, args): 60 | # init the process group 61 | dist.init_process_group(backend=args.dist_backend, init_method=args.init_method, 62 | world_size=args.world_size, rank=args.rank) 63 | 64 | torch.cuda.set_device(gpu) 65 | 66 | print("From Rank: {}, Use GPU: {} for training".format(args.rank, gpu)) 67 | 68 | print('From Rank: {}, ==> Making model..'.format(args.rank)) 69 | net = pyramidnet() 70 | net.cuda(gpu) 71 | args.batch_size = int(args.batch_size / ngpus_per_node) 72 | print("batch_size: ", args.batch_size) 73 | 74 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpu], output_device=gpu) 75 | num_params = sum(p.numel() for p in net.parameters() if p.requires_grad) 76 | print('From Rank: {}, The number of parameters of model is'.format(args.rank), num_params) 77 | 78 | print('From Rank: {}, ==> Preparing data..'.format(args.rank)) 79 | transforms_train = transforms.Compose([ 80 | transforms.RandomCrop(32, padding=4), 81 | 
transforms.RandomHorizontalFlip(), 82 | transforms.ToTensor(), 83 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]) 84 | 85 | dataset_train = CIFAR10(root=args.dataset_dir, train=True, download=True, 86 | transform=transforms_train) 87 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train) 88 | train_loader = DataLoader(dataset_train, batch_size=args.batch_size, 89 | shuffle=(train_sampler is None), 90 | num_workers=args.num_workers, 91 | sampler=train_sampler) 92 | 93 | criterion = nn.CrossEntropyLoss().cuda(gpu) 94 | optimizer = optim.SGD(net.parameters(), lr=args.learning_rate, 95 | momentum=args.momentum, weight_decay=args.weight_decay) 96 | 97 | scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) 98 | 99 | for epoch in range(1, args.epochs + 1): 100 | train(epoch, net, criterion, optimizer, train_loader, args.rank) 101 | scheduler.step() 102 | 103 | if args.save_model: 104 | if not path.exists(args.train_dir): 105 | mkdir(args.train_dir) 106 | 107 | # if args.rank == 0: 108 | torch.save( 109 | net.module.state_dict(), 110 | path.join( 111 | args.train_dir, 112 | "distributed_data_parallel_{}.pth".format(args.rank) 113 | ) 114 | ) 115 | print("From Rank: {}, model saved.".format(args.rank)) 116 | 117 | 118 | def train(epoch, net, criterion, optimizer, train_loader, rank): 119 | net.train() 120 | 121 | train_loss, correct, total = 0, 0, 0 122 | epoch_start = time.time() 123 | for batch_idx, (inputs, targets) in enumerate(train_loader): 124 | start = time.time() 125 | 126 | inputs = inputs.cuda() 127 | targets = targets.cuda() 128 | outputs = net(inputs) 129 | loss = criterion(outputs, targets) 130 | 131 | optimizer.zero_grad() 132 | loss.backward() 133 | optimizer.step() 134 | 135 | train_loss += loss.item() 136 | _, predicted = outputs.max(1) 137 | total += targets.size(0) 138 | correct += predicted.eq(targets).sum().item() 139 | 140 | acc = 100 * correct / total 141 | 142 | batch_time = 
time.time() - start 143 | 144 | if batch_idx % args.log_interval == 0: 145 | print('From Rank: {}, Epoch:[{}][{}/{}]| loss: {:.3f} | ' 146 | 'acc: {:.3f} | batch time: {:.3f}s '.format( 147 | rank, epoch, batch_idx, len(train_loader), 148 | train_loss / (batch_idx + 1), acc, batch_time), flush=True) 149 | 150 | elapse_time = time.time() - epoch_start 151 | elapse_time = datetime.timedelta(seconds=elapse_time) 152 | print("From Rank: {}, Training time {}".format(rank, elapse_time)) 153 | 154 | 155 | if __name__ == '__main__': 156 | main() 157 | -------------------------------------------------------------------------------- /pytorch/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | # code from https://github.com/KellerJordan/ResNet-PyTorch-CIFAR10/blob/master/model.py 6 | class IdentityPadding(nn.Module): 7 | def __init__(self, in_channels, out_channels, stride=1): 8 | super(IdentityPadding, self).__init__() 9 | 10 | if stride == 2: 11 | self.pooling = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) 12 | else: 13 | self.pooling = None 14 | 15 | self.add_channels = out_channels - in_channels 16 | 17 | def forward(self, x): 18 | out = F.pad(x, (0, 0, 0, 0, 0, self.add_channels)) 19 | if self.pooling is not None: 20 | out = self.pooling(out) 21 | return out 22 | 23 | 24 | class ResidualBlock(nn.Module): 25 | def __init__(self, in_channels, out_channels, stride=1): 26 | super(ResidualBlock, self).__init__() 27 | self.bn1 = nn.BatchNorm2d(in_channels) 28 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 29 | stride=stride, padding=1, bias=False) 30 | self.bn2 = nn.BatchNorm2d(out_channels) 31 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, 32 | stride=1, padding=1, bias=False) 33 | self.bn3 = nn.BatchNorm2d(out_channels) 34 | self.relu = nn.ReLU(inplace=True) 35 | 36 | self.down_sample = IdentityPadding(in_channels, 
out_channels, stride)
37 |
38 |         self.stride = stride
39 |
40 |     def forward(self, x):
41 |         # bn1 -> conv1 -> bn2 -> relu -> conv2 -> bn3, then add the shortcut
42 |         # (input channel-padded and, for stride 2, average-pooled by IdentityPadding).
43 |         # No activation is applied after the addition.
44 |         shortcut = self.down_sample(x)
45 |         out = self.bn1(x)
46 |         out = self.conv1(out)
47 |         out = self.bn2(out)
48 |         out = self.relu(out)
49 |         out = self.conv2(out)
50 |         out = self.bn3(out)
51 |
52 |         out += shortcut
53 |         return out
54 |
55 |
56 | class PyramidNet(nn.Module):
57 |     """PyramidNet for 32x32 inputs whose channel width grows linearly block by block.
58 |
59 |     Args:
60 |         num_layers: controls blocks per stage (``get_layers`` builds ``num_layers - 1``).
61 |         alpha: total planned channel increase, spread evenly over ``3 * num_layers`` steps.
62 |         block: residual block class, called as ``block(in_channels, out_channels, stride)``.
63 |         num_classes: size of the classifier output.
64 |     """
65 |     def __init__(self, num_layers, alpha, block, num_classes=10):
66 |         super(PyramidNet, self).__init__()
67 |         self.in_channels = 16
68 |
69 |         # num_layers = (110 - 2)/6 = 18
70 |         # NOTE(review): get_layers builds num_layers - 1 blocks per stage, so the network
71 |         # is shallower than the 110 layers this formula suggests — confirm intent.
72 |         self.num_layers = num_layers
73 |         # per-block channel increment; get_layers accumulates it as a float and rounds
74 |         self.addrate = alpha / (3 * self.num_layers * 1.0)
75 |
76 |         self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3,
77 |                                stride=1, padding=1, bias=False)
78 |         self.bn1 = nn.BatchNorm2d(16)
79 |
80 |         # get_layers mutates self.in_channels / self.out_channels, so these three
81 |         # calls must run in this order.
82 |         # feature map size = 32x32
83 |         self.layer1 = self.get_layers(block, stride=1)
84 |         # feature map size = 16x16
85 |         self.layer2 = self.get_layers(block, stride=2)
86 |         # feature map size = 8x8
87 |         self.layer3 = self.get_layers(block, stride=2)
88 |
89 |         # width after the last block (left as a float by get_layers)
90 |         self.out_channels = int(round(self.out_channels))
91 |         self.bn_out = nn.BatchNorm2d(self.out_channels)
92 |         self.relu_out = nn.ReLU(inplace=True)
93 |         self.avgpool = nn.AvgPool2d(8, stride=1)
94 |         self.fc_out = nn.Linear(self.out_channels, num_classes)
95 |
96 |         for m in self.modules():
97 |             if isinstance(m, nn.Conv2d):
98 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out',
99 |                                         nonlinearity='relu')
100 |             elif isinstance(m, nn.BatchNorm2d):
101 |                 nn.init.constant_(m.weight, 1)
102 |                 nn.init.constant_(m.bias, 0)
103 |
104 |     def get_layers(self, block, stride):
105 |         """Build one stage of ``num_layers - 1`` blocks.
106 |
107 |         Only the first block uses ``stride``; later blocks use stride 1. Advances
108 |         ``self.in_channels`` / ``self.out_channels`` so the next stage keeps widening
109 |         from where this one stopped.
110 |         """
111 |         layers_list = []
112 |         for _ in range(self.num_layers - 1):
113 |             self.out_channels = self.in_channels + self.addrate
114 |             layers_list.append(block(int(round(self.in_channels)),
115 |                                      int(round(self.out_channels)),
116 |                                      stride))
117 |             self.in_channels = self.out_channels
118 |             stride = 1
119 |
120 |         return nn.Sequential(*layers_list)
121 |
122 |     def forward(self, x):
123 |         x = self.conv1(x)
124 |         x = self.bn1(x)
125 |
126 |         x = self.layer1(x)
127 |         x = self.layer2(x)
128 |         x = self.layer3(x)
129 |
130 |         x = self.bn_out(x)
131 |         x = self.relu_out(x)
132 |         x = self.avgpool(x)
133 |         x = x.view(x.size(0), -1)
134 |         x = self.fc_out(x)
135 |         return x
136 |
137 |
138 | def pyramidnet():
139 |     """Return the PyramidNet configuration the training scripts use (num_layers=18, alpha=270)."""
140 |     block = ResidualBlock
141 |     model = PyramidNet(num_layers=18, alpha=270, block=block)
142 |     return model
143 |
--------------------------------------------------------------------------------
/pytorch/single_gpu.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import time
4 | from os import mkdir, path
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torchvision.transforms as transforms
10 | from model import pyramidnet
11 | from torch.utils.data import DataLoader
12 | from torchvision.datasets import CIFAR10
13 |
14 | parser = argparse.ArgumentParser(description='PyTorch Cifar10 Single Gpu Training')
15 | parser.add_argument('--train-dir', '-td', type=str, default="./train_dir",
16 |                     help='the path that the model saved (default: "./train_dir")')
17 | parser.add_argument('--dataset-dir', '-dd', type=str, default="./data",
18 |                     help='the path of dataset (default: "./data")')
19 | parser.add_argument('--batch-size', '-b', type=int, default=64,
20 |                     help='input batch size for training (default: 64)')
21 | parser.add_argument('--num-workers', type=int, default=4, help='')
22 | parser.add_argument('--test-batchsize', '-tb', type=int, default=1000,
23 |                     help='input batch size for testing (default: 1000)')
24 | parser.add_argument('--epochs', '-e', type=int, default=10,
25 |                     help='number of epochs to train (default: 10)')
26 | parser.add_argument('--gpu-nums', '-g', type=int, default=0,
27 |                     help='Number of GPU in each mini-batch')
28 | parser.add_argument('--learning-rate', '-lr', type=float, default=0.1, metavar='LR',
29 |                     help='learning rate (default: 0.1)')
30 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
31 |                     help='SGD momentum (default: 0.9)')
32 | parser.add_argument('--seed', type=int, default=1, metavar='S',
33 |                     help='random seed (default: 1)')
34 | parser.add_argument('--log-interval', '-li', type=int, default=20, metavar='N',
35 |                     help='how many batches to wait before logging training status')
36 | parser.add_argument('--save-model', '-sm', action='store_true', default=False,
37 |                     help='For Saving the current Model')
38 | parser.add_argument('--weight-decay', '-wd', type=float, default=1e-4, metavar='W',
39 |                     help='weight decay(default: 1e-4)')
40 | args = parser.parse_args()
41 |
42 |
43 | def main():
44 |     """Train PyramidNet on CIFAR10 on a single device and optionally save the weights."""
45 |     # --gpu-nums of 0 or 1 is accepted; the device below depends only on CUDA availability.
46 |     if args.gpu_nums > 1:
47 |         raise ValueError("gpu nums must be equal to 1.")
48 |
49 |     # set run env
50 |     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
51 |
52 |     print('==> Preparing data..')
53 |     transforms_train = transforms.Compose([
54 |         transforms.RandomCrop(32, padding=4),
55 |         transforms.RandomHorizontalFlip(),
56 |         transforms.ToTensor(),
57 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
58 |
59 |     dataset_train = CIFAR10(root=args.dataset_dir, train=True, download=True,
60 |                             transform=transforms_train)
61 |
62 |     train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
63 |                               shuffle=True, num_workers=args.num_workers)
64 |
65 |     print('==> Making model..')
66 |
67 |     model = pyramidnet()
68 |     model = model.to(device)
69 |     num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
70 |     print('The number of parameters of model is', num_params)
71 |
72 |     criterion = nn.CrossEntropyLoss()
73 |     optimizer = optim.SGD(model.parameters(), lr=args.learning_rate,
74 |                           momentum=args.momentum, weight_decay=args.weight_decay)
75 |
76 |     for epoch in range(1, args.epochs + 1):
77 |         train(epoch, model, criterion, optimizer, train_loader, device)
78 |
79 |     if args.save_model:
80 |         if not path.exists(args.train_dir):
81 |             mkdir(args.train_dir)
82 |
83 |         torch.save(
84 |             model.state_dict(),
85 |             path.join(args.train_dir, "single_gpu_model.pth")
86 |         )
87 |         print("single gpu model has been saved.")
88 |
89 |
90 | def train(epoch, model, criterion, optimizer, train_loader, device):
91 |     """Run one training epoch, logging loss/accuracy every ``args.log_interval`` batches."""
92 |     model.train()
93 |
94 |     train_loss, correct, total = 0, 0, 0
95 |     epoch_start = time.time()
96 |     for batch_idx, (inputs, targets) in enumerate(train_loader):
97 |         start = time.time()
98 |
99 |         inputs = inputs.to(device)
100 |         targets = targets.to(device)
101 |         outputs = model(inputs)
102 |         loss = criterion(outputs, targets)
103 |
104 |         optimizer.zero_grad()
105 |         loss.backward()
106 |         optimizer.step()
107 |
108 |         train_loss += loss.item()
109 |         _, predicted = outputs.max(1)
110 |         total += targets.size(0)
111 |         correct += predicted.eq(targets).sum().item()
112 |
113 |         # running accuracy (%) over the batches seen so far this epoch
114 |         acc = 100 * correct / total
115 |
116 |         # measured after the loader yields, so data loading time is excluded
117 |         batch_time = time.time() - start
118 |
119 |         if batch_idx % args.log_interval == 0:
120 |             print('Epoch[{}]: [{}/{}]| loss: {:.3f} | acc: {:.3f} | batch time: {:.3f}s '.format(
121 |                 epoch, batch_idx, len(train_loader), train_loss / (batch_idx + 1), acc, batch_time))
122 |
123 |     elapse_time = time.time() - epoch_start
124 |     elapse_time = datetime.timedelta(seconds=elapse_time)
125 |     print("Training time {}".format(elapse_time))
126 |
127 |
128 | if __name__ == '__main__':
129 |     main()
130 |
--------------------------------------------------------------------------------
/tensorflow/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/tensorflow/README.md
--------------------------------------------------------------------------------
/tensorflow2/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Crisescode/distributed-training-dl/a7ff2b4a6c07a126c30eaa886cc6e8cd02a83949/tensorflow2/README.md
--------------------------------------------------------------------------------
/tensorflow2/mnist_mirror_strategy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*-coding:utf-8 -*-
3 |
4 | import os
5 | import argparse
6 |
7 | import tensorflow as tf
8 | from tensorflow.keras import datasets
9 | from tensorflow.keras import layers, models
10 | from tensorflow.keras import optimizers
11 |
12 | # created once at import time and used by Train.train to build/compile the model
13 | strategy = tf.distribute.MirroredStrategy()
14 |
15 |
16 | # create cnn model
17 | class Net(object):
18 |     """Build the small MNIST CNN; the Keras model is exposed as ``self.model``."""
19 |     def __init__(self):
20 |         model = models.Sequential()
21 |         model.add(layers.Conv2D(
22 |             32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
23 |         model.add(layers.MaxPooling2D((2, 2)))
24 |         model.add(layers.Conv2D(64, (3, 3), activation='relu'))
25 |         model.add(layers.MaxPooling2D((2, 2)))
26 |         model.add(layers.Conv2D(64, (3, 3), activation='relu'))
27 |
28 |         model.add(layers.Flatten())
29 |         model.add(layers.Dense(64, activation='relu'))
30 |         model.add(layers.Dense(10, activation='softmax'))
31 |
32 |         model.summary()
33 |
34 |         self.model = model
35 |
36 |
37 | # initial dataset
38 | class DataSet(object):
39 |     """Load MNIST from a local ``mnist.npz`` and scale pixel values to [0, 1].
40 |
41 |     Images are reshaped to (N, 28, 28, 1) to match the CNN's input layer.
42 |     """
43 |     def __init__(self):
44 |         # NOTE(review): resolved two directories above this script; confirm the file
45 |         # exists there, otherwise load_data may try to download it.
46 |         data_path = os.path.dirname(os.path.realpath(__file__)) \
47 |             + '/../../datasets/mnist/mnist.npz'
48 |         (train_images, train_labels), (test_images, test_labels) = \
49 |             datasets.mnist.load_data(path=data_path)
50 |         train_images = train_images.reshape((60000, 28, 28, 1))
51 |         test_images = test_images.reshape((10000, 28, 28, 1))
52 |
53 |         train_images, test_images = train_images / 255.0, test_images / 255.0
54 |
55 |         self.train_images, self.train_labels = train_images, train_labels
56 |         self.test_images, self.test_labels = test_images, test_labels
57 |
58 |
59 | # train and val
60 | class Train:
61 |     def __init__(self):
62 |         self.data = DataSet()
63 |
64 |     def train(self, args):
65 |         """Train under ``strategy`` and then evaluate the latest checkpoint.
66 |
67 |         Checkpoints and TensorBoard logs are written to ``args.train_dir``.
68 |         """
69 |         # Define the checkpoint directory to store the checkpoints
70 |         checkpoint_dir = args.train_dir
71 |         # Name of the checkpoint files
72 |         checkpoint_path = os.path.join(checkpoint_dir, "ckpt_{epoch}")
73 |
74 |         callbacks = [
75 |             tf.keras.callbacks.TensorBoard(log_dir=args.train_dir, histogram_freq=1),
76 |             tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
77 |                                                save_weights_only=True),
78 |         ]
79 |
80 |         # build and compile inside the scope so the strategy owns the model's variables
81 |         with strategy.scope():
82 |             model = Net().model
83 |
84 |             model.compile(optimizer=optimizers.Adam(),
85 |                           loss='sparse_categorical_crossentropy',
86 |                           metrics=['accuracy'])
87 |
88 |         model.fit(self.data.train_images, self.data.train_labels,
89 |                   batch_size=args.batch_size,
90 |                   epochs=args.epochs,
91 |                   callbacks=callbacks,
92 |                   validation_data=(self.data.test_images, self.data.test_labels))
93 |
94 |         # EVAL
95 |         model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
96 |         eval_loss, eval_acc = model.evaluate(
97 |             self.data.test_images, self.data.test_labels, verbose=2)
98 |         print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
99 |
100 |
101 | def main():
102 |     """Parse command-line options and run training."""
103 |     # training params settings
104 |     parser = argparse.ArgumentParser(description='Tensorflow 2.0 MNIST Example,'
105 |                                                  ' use Mirrorstrategy')
106 |     parser.add_argument('--train_dir', '-td', type=str, default='./train_dir',
107 |                         help='the folder of svaing model')
108 |     parser.add_argument('--batch_size', '-b', type=int, default=64,
109 |                         help='input batch size for training (default: 64)')
110 |     parser.add_argument('--test_batchsize', '-tb', type=int, default=1000,
111 |                         help='input batch size for testing (default: 1000)')
112 |     parser.add_argument('--epochs', '-e', type=int, default=10,
113 |                         help='number of epochs to train (default: 10)')
114 |     parser.add_argument('--gpu_nums', '-g', type=int, default=0,
115 |                         help='number of gpus')
116 |     parser.add_argument('--cpu_nums', '-c', type=int, default=0,
117 |                         help='number of cpus')
118 |     parser.add_argument('--learning_rate', '-lr', type=float, default=0.01,
119 |                         help='learning rate (default: 0.01)')
120 |     parser.add_argument('--momentum', type=float, default=0.5,
121 |                         help='SGD momentum (default: 0.5)')
122 |     parser.add_argument('--log_interval', type=int, default=10,
123 |                         help='how many batches to wait before logging training status')
124 |     parser.add_argument('--save_model', '-sm', action='store_true', default=False,
125 |                         help='For Saving the current Model')
126 |     # NOTE(review): Train.train only reads train_dir, batch_size and epochs;
127 |     # the other options are currently unused.
128 |
129 |     args = parser.parse_args()
130 |
131 |     app = Train()
132 |     app.train(args)
133 |
134 |
135 | if __name__ == "__main__":
136 |     main()
137 |
--------------------------------------------------------------------------------
/tensorflow2/mnist_multi_worker_strategy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*-coding:utf-8 -*-
3 |
4 | import os
5 | import json
6 | import argparse
7 |
8 | import tensorflow as tf
9 | from tensorflow.keras import datasets
10 | from tensorflow.keras import layers, models
11 | from tensorflow.keras import optimizers
12 |
13 |
14 | def set_strategy(args):  # export TF_CONFIG for this worker, then build the strategy
15 |     if args.job_name != 'worker':
16 |         raise ValueError("Multi strategy only support worker mode, please check job name")
17 |
18 |     tf_config = args.worker_hosts.split(',')  # presumably "host:port,host:port" — confirm
19 |     os.environ["TF_CONFIG"] = json.dumps({
20 |         'cluster': {
21 |             'worker': tf_config
22 |         },
23 |         'task': {'type': args.job_name, 'index': args.task_index}
24 |     })
25 |     print(os.environ["TF_CONFIG"])
26 |
27 |     strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()  # reads TF_CONFIG set above
28 |
29 |     return strategy
30 |
31 |
32 | # create cnn model
33 | class Net(object):
34 |     def __init__(self):
35 |         model = models.Sequential()
36 |         model.add(layers.Conv2D(
37 |             32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
38 |         model.add(layers.MaxPooling2D((2, 2)))
39 |         model.add(layers.Conv2D(64, (3, 3), activation='relu'))
40 |         model.add(layers.MaxPooling2D((2, 2)))
41 |
model.add(layers.Conv2D(64, (3, 3), activation='relu')) 42 | 43 | model.add(layers.Flatten()) 44 | model.add(layers.Dense(64, activation='relu')) 45 | model.add(layers.Dense(10, activation='softmax')) 46 | 47 | model.summary() 48 | 49 | self.model = model 50 | 51 | 52 | # inital dateset 53 | class DataSet(object): 54 | def __init__(self): 55 | data_path = os.path.dirname(os.path.realpath(__file__)) \ 56 | + '/../../datasets/mnist/mnist.npz' 57 | (train_images, train_labels), (test_images, test_labels) = \ 58 | datasets.mnist.load_data(path=data_path) 59 | train_images = train_images.reshape((60000, 28, 28, 1)) 60 | test_images = test_images.reshape((10000, 28, 28, 1)) 61 | 62 | train_images, test_images = train_images / 255.0, test_images / 255.0 63 | 64 | self.train_images, self.train_labels = train_images, train_labels 65 | self.test_images, self.test_labels = test_images, test_labels 66 | 67 | 68 | # train and val 69 | class Train: 70 | def __init__(self): 71 | self.data = DataSet() 72 | 73 | def train(self, args, strategy): 74 | # Define the checkpoint directory to store the checkpoints 75 | checkpoint_dir = args.train_dir 76 | # Name of the checkpoint files 77 | checkpoint_path = os.path.join(checkpoint_dir, "ckpt_{epoch}") 78 | 79 | callbacks = [ 80 | tf.keras.callbacks.TensorBoard(log_dir=args.train_dir, histogram_freq=1), 81 | tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, 82 | save_weights_only=True), 83 | ] 84 | 85 | with strategy.scope(): 86 | model = Net().model 87 | 88 | model.compile(optimizer=optimizers.Adam(), 89 | loss='sparse_categorical_crossentropy', 90 | metrics=['accuracy']) 91 | 92 | model.fit(self.data.train_images, self.data.train_labels, 93 | batch_size=args.batch_size, 94 | epochs=args.epochs, 95 | callbacks=callbacks, 96 | validation_data=(self.data.test_images, self.data.test_labels)) 97 | 98 | # EVAL 99 | model.load_weights(tf.train.latest_checkpoint(checkpoint_dir)) 100 | eval_loss, eval_acc = model.evaluate( 101 | 
self.data.test_images, self.data.test_labels, verbose=2) 102 | print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc)) 103 | 104 | 105 | def main(): 106 | # training params settings 107 | parser = argparse.ArgumentParser(description='Tensorflow 2.0 MNIST Example,' 108 | ' use Mirrorstrategy') 109 | parser.add_argument('--train_dir', '-td', type=str, default='./train_dir', 110 | help='the folder of svaing model') 111 | parser.add_argument('--batch_size', '-b', type=int, default=64, 112 | help='input batch size for training (default: 64)') 113 | parser.add_argument('--test_batchsize', '-tb', type=int, default=1000, 114 | help='input batch size for testing (default: 1000)') 115 | parser.add_argument('--epochs', '-e', type=int, default=10, 116 | help='number of epochs to train (default: 10)') 117 | parser.add_argument('--gpu_nums', '-g', type=int, default=0, 118 | help='number of gpus') 119 | parser.add_argument('--cpu_nums', '-c', type=int, default=0, 120 | help='number of cpus') 121 | parser.add_argument('--learning_rate', '-lr', type=float, default=0.01, 122 | help='learning rate (default: 0.01)') 123 | parser.add_argument('--momentum', type=float, default=0.5, 124 | help='SGD momentum (default: 0.5)') 125 | parser.add_argument('--log_interval', type=int, default=10, 126 | help='how many batches to wait before logging training status') 127 | parser.add_argument('--save_model', '-sm', action='store_true', default=False, 128 | help='For Saving the current Model') 129 | parser.add_argument('--worker_hosts', '-wh', type=str, required=True, 130 | help='Comma-separated list of hostname:port pairs') 131 | parser.add_argument('--job_name', '-j', type=str, default='worker', 132 | help='Ps or worker') 133 | parser.add_argument('--task_index', '-i', type=int, required=True, 134 | help='Index of task within the job') 135 | 136 | args = parser.parse_args() 137 | 138 | strategy = set_strategy(args) 139 | 140 | app = Train() 141 | app.train(args, strategy) 142 | 143 
| 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /tensorflow2/mnist_single.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*-coding:utf-8 -*- 3 | 4 | import os 5 | import argparse 6 | 7 | import tensorflow as tf 8 | from tensorflow.keras import datasets 9 | from tensorflow.keras import layers, models 10 | from tensorflow.keras import optimizers 11 | 12 | 13 | # create cnn model 14 | class Net(object): 15 | def __init__(self): 16 | model = models.Sequential() 17 | model.add(layers.Conv2D( 18 | 32, (3, 3), activation='relu', input_shape=(28, 28, 1))) 19 | model.add(layers.MaxPooling2D((2, 2))) 20 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 21 | model.add(layers.MaxPooling2D((2, 2))) 22 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 23 | 24 | model.add(layers.Flatten()) 25 | model.add(layers.Dense(64, activation='relu')) 26 | model.add(layers.Dense(10, activation='softmax')) 27 | 28 | model.summary() 29 | 30 | self.model = model 31 | 32 | 33 | # inital dateset 34 | class DataSet(object): 35 | def __init__(self): 36 | data_path = os.path.dirname(os.path.realpath(__file__)) \ 37 | + '/../../datasets/mnist/mnist.npz' 38 | (train_images, train_labels), (test_images, test_labels) = \ 39 | datasets.mnist.load_data(path=data_path) 40 | 41 | train_images = train_images.reshape((60000, 28, 28, 1)) 42 | test_images = test_images.reshape((10000, 28, 28, 1)) 43 | 44 | train_images, test_images = train_images / 255.0, test_images / 255.0 45 | 46 | self.train_images, self.train_labels = train_images, train_labels 47 | self.test_images, self.test_labels = test_images, test_labels 48 | 49 | 50 | class PrintLR(tf.keras.callbacks.Callback): 51 | def __init__(self, lr): 52 | super(PrintLR, self).__init__() 53 | self.lr = lr 54 | 55 | def on_epoch_end(self, epoch, logs=None): 56 | print('\nLearning rate for epoch 
{} is {}'.format(epoch + 1, self.lr)) 57 | 58 | 59 | # train and val 60 | class Train: 61 | def __init__(self): 62 | self.model = Net().model 63 | self.data = DataSet() 64 | 65 | def train(self, args): 66 | # Define the checkpoint directory to store the checkpoints 67 | checkpoint_dir = args.train_dir 68 | # Name of the checkpoint files 69 | checkpoint_path = os.path.join(checkpoint_dir, "ckpt_{epoch}") 70 | 71 | 72 | callbacks = [ 73 | tf.keras.callbacks.TensorBoard(log_dir=args.train_dir, histogram_freq=1), 74 | tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, 75 | save_weights_only=True), 76 | ] 77 | 78 | self.model.compile(optimizer=optimizers.Adam(), 79 | loss='sparse_categorical_crossentropy', 80 | metrics=['accuracy']) 81 | 82 | self.model.fit(self.data.train_images, self.data.train_labels, 83 | batch_size=args.batch_size, 84 | epochs=args.epochs, 85 | callbacks=callbacks, 86 | validation_data=(self.data.test_images, self.data.test_labels)) 87 | 88 | # EVAL 89 | self.model.load_weights(tf.train.latest_checkpoint(checkpoint_dir)) 90 | eval_loss, eval_acc = self.model.evaluate( 91 | self.data.test_images, self.data.test_labels, verbose=2) 92 | print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc)) 93 | 94 | 95 | def main(): 96 | # training params settings 97 | parser = argparse.ArgumentParser(description='Tensorflow 2.0 MNIST Example') 98 | parser.add_argument('--train_dir', '-td', type=str, default='./train_dir', 99 | help='the folder of svaing model') 100 | parser.add_argument('--batch_size', '-b', type=int, default=64, 101 | help='input batch size for training (default: 64)') 102 | parser.add_argument('--test_batchsize', '-tb', type=int, default=1000, 103 | help='input batch size for testing (default: 1000)') 104 | parser.add_argument('--epochs', '-e', type=int, default=10, 105 | help='number of epochs to train (default: 10)') 106 | parser.add_argument('--learning_rate', '-lr', type=float, default=0.01, 107 | help='learning rate 
(default: 0.01)') 108 | parser.add_argument('--momentum', type=float, default=0.5, 109 | help='SGD momentum (default: 0.5)') 110 | parser.add_argument('--log_interval', type=int, default=10, 111 | help='how many batches to wait before logging training status') 112 | parser.add_argument('--save_model', '-sm', action='store_true', default=False, 113 | help='For Saving the current Model') 114 | 115 | args = parser.parse_args() 116 | 117 | app = Train() 118 | app.train(args) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | --------------------------------------------------------------------------------