├── graphics └── processes-gpus.png ├── src ├── mnist.py ├── mnist-mixed.py └── mnist-distributed.py └── ddp_tutorial.md /graphics/processes-gpus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangkky/distributed_tutorial/HEAD/graphics/processes-gpus.png -------------------------------------------------------------------------------- /src/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | import argparse 4 | import torch.multiprocessing as mp 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import torch 8 | import torch.nn as nn 9 | import torch.distributed as dist 10 | from apex.parallel import DistributedDataParallel as DDP 11 | from apex import amp 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', 17 | help='number of data loading workers (default: 4)') 18 | parser.add_argument('-g', '--gpus', default=1, type=int, 19 | help='number of gpus per node') 20 | parser.add_argument('-nr', '--nr', default=0, type =int, 21 | help='ranking within the nodes') 22 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 23 | help='number of total epochs to run') 24 | args = parser.parse_args() 25 | train(0, args) 26 | 27 | 28 | class ConvNet(nn.Module): 29 | def __init__(self, num_classes=10): 30 | super(ConvNet, self).__init__() 31 | self.layer1 = nn.Sequential( 32 | nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 33 | nn.BatchNorm2d(16), 34 | nn.ReLU(), 35 | nn.MaxPool2d(kernel_size=2, stride=2)) 36 | self.layer2 = nn.Sequential( 37 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 38 | nn.BatchNorm2d(32), 39 | nn.ReLU(), 40 | nn.MaxPool2d(kernel_size=2, stride=2)) 41 | self.fc = nn.Linear(7*7*32, num_classes) 42 | 43 | def forward(self, x): 44 | out = self.layer1(x) 45 
| out = self.layer2(out) 46 | out = out.reshape(out.size(0), -1) 47 | out = self.fc(out) 48 | return out 49 | 50 | 51 | def train(gpu, args): 52 | model = ConvNet() 53 | torch.cuda.set_device(gpu) 54 | model.cuda(gpu) 55 | batch_size = 100 56 | # define loss function (criterion) and optimizer 57 | criterion = nn.CrossEntropyLoss().cuda(gpu) 58 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 59 | # Data loading code 60 | train_dataset = torchvision.datasets.MNIST(root='./data', 61 | train=True, 62 | transform=transforms.ToTensor(), 63 | download=True) 64 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 65 | batch_size=batch_size, 66 | shuffle=True, 67 | num_workers=0, 68 | pin_memory=True) 69 | 70 | start = datetime.now() 71 | total_step = len(train_loader) 72 | for epoch in range(args.epochs): 73 | for i, (images, labels) in enumerate(train_loader): 74 | images = images.cuda(non_blocking=True) 75 | labels = labels.cuda(non_blocking=True) 76 | # Forward pass 77 | outputs = model(images) 78 | loss = criterion(outputs, labels) 79 | 80 | # Backward and optimize 81 | optimizer.zero_grad() 82 | loss.backward() 83 | optimizer.step() 84 | if (i + 1) % 100 == 0 and gpu == 0: 85 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, total_step, 86 | loss.item())) 87 | if gpu == 0: 88 | print("Training complete in: " + str(datetime.now() - start)) 89 | 90 | 91 | if __name__ == '__main__': 92 | main() -------------------------------------------------------------------------------- /src/mnist-mixed.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | import argparse 4 | import torch.multiprocessing as mp 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import torch 8 | import torch.nn as nn 9 | import torch.distributed as dist 10 | from apex.parallel import DistributedDataParallel as DDP 11 | from apex import 
amp 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', 17 | help='number of data loading workers (default: 4)') 18 | parser.add_argument('-g', '--gpus', default=1, type=int, 19 | help='number of gpus per node') 20 | parser.add_argument('-nr', '--nr', default=0, type=int, 21 | help='ranking within the nodes') 22 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 23 | help='number of total epochs to run') 24 | args = parser.parse_args() 25 | args.world_size = args.gpus * args.nodes 26 | os.environ['MASTER_ADDR'] = 'localhost' 27 | os.environ['MASTER_PORT'] = '8888' 28 | mp.spawn(train, nprocs=args.gpus, args=(args,)) 29 | 30 | 31 | class ConvNet(nn.Module): 32 | def __init__(self, num_classes=10): 33 | super(ConvNet, self).__init__() 34 | self.layer1 = nn.Sequential( 35 | nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 36 | nn.BatchNorm2d(16), 37 | nn.ReLU(), 38 | nn.MaxPool2d(kernel_size=2, stride=2)) 39 | self.layer2 = nn.Sequential( 40 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 41 | nn.BatchNorm2d(32), 42 | nn.ReLU(), 43 | nn.MaxPool2d(kernel_size=2, stride=2)) 44 | self.fc = nn.Linear(7*7*32, num_classes) 45 | 46 | def forward(self, x): 47 | out = self.layer1(x) 48 | out = self.layer2(out) 49 | out = out.reshape(out.size(0), -1) 50 | out = self.fc(out) 51 | return out 52 | 53 | 54 | def train(gpu, args): 55 | rank = args.nr * args.gpus + gpu 56 | dist.init_process_group( 57 | backend='nccl', 58 | init_method='env://', 59 | world_size=args.world_size, 60 | rank=rank) 61 | torch.manual_seed(0) 62 | model = ConvNet() 63 | torch.cuda.set_device(gpu) 64 | model.cuda(gpu) 65 | batch_size = 100 66 | # define loss function (criterion) and optimizer 67 | criterion = nn.CrossEntropyLoss().cuda(gpu) 68 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 69 | # Wrap the model 70 | model, optimizer = amp.initialize(model, optimizer, 
opt_level='O2') 71 | model = DDP(model) 72 | # Data loading code 73 | train_dataset = torchvision.datasets.MNIST( 74 | root='./data', 75 | train=True, 76 | transform=transforms.ToTensor(), 77 | download=True 78 | ) 79 | train_sampler = torch.utils.data.distributed.DistributedSampler( 80 | train_dataset, 81 | num_replicas=args.world_size, 82 | rank=rank) 83 | train_loader = torch.utils.data.DataLoader( 84 | dataset=train_dataset, 85 | batch_size=batch_size, 86 | shuffle=False, 87 | num_workers=0, 88 | pin_memory=True, 89 | sampler=train_sampler 90 | ) 91 | 92 | start = datetime.now() 93 | total_step = len(train_loader) 94 | for epoch in range(args.epochs): 95 | for i, (images, labels) in enumerate(train_loader): 96 | images = images.cuda(non_blocking=True) 97 | labels = labels.cuda(non_blocking=True) 98 | # Forward pass 99 | outputs = model(images) 100 | loss = criterion(outputs, labels) 101 | 102 | # Backward and optimize 103 | optimizer.zero_grad() 104 | with amp.scale_loss(loss, optimizer) as scaled_loss: 105 | scaled_loss.backward() 106 | optimizer.step() 107 | if (i + 1) % 100 == 0 and gpu == 0: 108 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format( 109 | epoch + 1, 110 | args.epochs, 111 | i + 1, 112 | total_step, 113 | loss.item()) 114 | ) 115 | if gpu == 0: 116 | print("Training complete in: " + str(datetime.now() - start)) 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /src/mnist-distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | import argparse 4 | import torch.multiprocessing as mp 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import torch 8 | import torch.nn as nn 9 | import torch.distributed as dist 10 | from apex.parallel import DistributedDataParallel as DDP 11 | from apex import amp 12 | 13 | 14 | def main(): 15 | parser 
= argparse.ArgumentParser() 16 | parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', 17 | help='number of data loading workers (default: 4)') 18 | parser.add_argument('-g', '--gpus', default=1, type=int, 19 | help='number of gpus per node') 20 | parser.add_argument('-nr', '--nr', default=0, type=int, 21 | help='ranking within the nodes') 22 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 23 | help='number of total epochs to run') 24 | args = parser.parse_args() 25 | args.world_size = args.gpus * args.nodes 26 | os.environ['MASTER_ADDR'] = '10.57.23.164' 27 | os.environ['MASTER_PORT'] = '8888' 28 | mp.spawn(train, nprocs=args.gpus, args=(args,)) 29 | 30 | 31 | class ConvNet(nn.Module): 32 | def __init__(self, num_classes=10): 33 | super(ConvNet, self).__init__() 34 | self.layer1 = nn.Sequential( 35 | nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 36 | nn.BatchNorm2d(16), 37 | nn.ReLU(), 38 | nn.MaxPool2d(kernel_size=2, stride=2)) 39 | self.layer2 = nn.Sequential( 40 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 41 | nn.BatchNorm2d(32), 42 | nn.ReLU(), 43 | nn.MaxPool2d(kernel_size=2, stride=2)) 44 | self.fc = nn.Linear(7*7*32, num_classes) 45 | 46 | def forward(self, x): 47 | out = self.layer1(x) 48 | out = self.layer2(out) 49 | out = out.reshape(out.size(0), -1) 50 | out = self.fc(out) 51 | return out 52 | 53 | 54 | def train(gpu, args): 55 | rank = args.nr * args.gpus + gpu 56 | dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank) 57 | torch.manual_seed(0) 58 | model = ConvNet() 59 | torch.cuda.set_device(gpu) 60 | model.cuda(gpu) 61 | batch_size = 100 62 | # define loss function (criterion) and optimizer 63 | criterion = nn.CrossEntropyLoss().cuda(gpu) 64 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 65 | # Wrap the model 66 | model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu]) 67 | # Data loading code 68 | train_dataset = 
torchvision.datasets.MNIST(root='./data', 69 | train=True, 70 | transform=transforms.ToTensor(), 71 | download=True) 72 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, 73 | num_replicas=args.world_size, 74 | rank=rank) 75 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 76 | batch_size=batch_size, 77 | shuffle=False, 78 | num_workers=0, 79 | pin_memory=True, 80 | sampler=train_sampler) 81 | 82 | start = datetime.now() 83 | total_step = len(train_loader) 84 | for epoch in range(args.epochs): 85 | for i, (images, labels) in enumerate(train_loader): 86 | images = images.cuda(non_blocking=True) 87 | labels = labels.cuda(non_blocking=True) 88 | # Forward pass 89 | outputs = model(images) 90 | loss = criterion(outputs, labels) 91 | 92 | # Backward and optimize 93 | optimizer.zero_grad() 94 | loss.backward() 95 | optimizer.step() 96 | if (i + 1) % 100 == 0 and gpu == 0: 97 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, total_step, 98 | loss.item())) 99 | if gpu == 0: 100 | print("Training complete in: " + str(datetime.now() - start)) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /ddp_tutorial.md: -------------------------------------------------------------------------------- 1 | # Distributed data parallel training in Pytorch 2 | 3 | ## Motivation 4 | 5 | The easiest way to speed up neural network training is to use a GPU, which provides large speedups over CPUs on the types of calculations (matrix multiplies and additions) that are common in neural networks. As the model or dataset gets bigger, one GPU quickly becomes insufficient. For example, big language models such as [BERT](https://arxiv.org/abs/1810.04805) and [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) are trained on hundreds of GPUs. 
To train on multiple GPUs, we must have a way to split the model and data between different GPUs and to coordinate the training. 6 | 7 | 8 | ### Why distributed data parallel? 9 | 10 | I like to implement my models in Pytorch because I find it has the best balance between control and ease of use of the major neural-net frameworks. Pytorch has two ways to split models and data across multiple GPUs: [`nn.DataParallel`](https://pytorch.org/docs/stable/nn.html#dataparallel) and [`nn.DistributedDataParallel`](https://pytorch.org/docs/stable/nn.html#distributeddataparallel). `nn.DataParallel` is easier to use (just wrap the model and run your training script). However, because it uses one process to compute the model weights and then distribute them to each GPU during each batch, networking quickly becomes a bottleneck and GPU utilization is often very low. Furthermore, `nn.DataParallel` requires that all the GPUs be on the same node and doesn't work with [Apex](https://nvidia.github.io/apex/amp.html) for [mixed-precision](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) training. 11 | 12 | ### The existing documentation is insufficient 13 | 14 | In general, the Pytorch documentation is thorough and clear, especially in version 1.0.x. I taught myself Pytorch almost entirely from the documentation and tutorials: this is definitely much more a reflection on Pytorch's ease of use and excellent documentation than it is of any special ability on my part. So I was very surprised when I spent some time trying to figure out how to use `DistributedDataParallel` and found all of the examples and tutorials to be some combination of inaccessible, incomplete, or overloaded with irrelevant features. 15 | 16 | Pytorch provides a [tutorial](https://pytorch.org/tutorials/beginner/aws_distributed_training_tutorial.html) on distributed training using AWS, which does a pretty good job of showing you how to set things up on the AWS side.
However, the rest of it is a bit messy, as it spends a lot of time showing how to calculate metrics for some reason before going back to showing how to wrap your model and launch the processes. It also doesn't describe what `nn.DistributedDataParallel` does, which makes the relevant code blocks difficult to follow. 17 | 18 | The [tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html) on writing distributed applications in Pytorch has much more detail than necessary for a first pass and is not accessible to somebody without a strong background in multiprocessing in Python. It spends a lot of time replicating the functionality in `nn.DistributedDataParallel`. However, it doesn't give a high-level overview of what it does and provides no insight into how to *use* it. 19 | 20 | 21 | There's also a Pytorch [tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) on getting started with distributed data parallel. This one shows how to do some setup, but doesn't explain what the setup is for, and then shows some code to split a model across GPUs and do one optimization step. Unfortunately, I'm pretty sure the code as written won't run (the function names don't match up) and furthermore it doesn't tell you *how* to run the code. Like the previous tutorial, it also doesn't give a high-level overview of how distributed training works. 22 | 23 | The closest thing to a minimal working example (MWE) that Pytorch provides is the [Imagenet](https://github.com/pytorch/examples/tree/master/imagenet) training example. Unfortunately, that example also demonstrates pretty much every other feature Pytorch has, so it's difficult to pick out what pertains to distributed, multi-GPU training. 24 | 25 | Apex provides its own [version](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) of the Pytorch Imagenet example.
The documentation there tells you that Apex's version of `nn.DistributedDataParallel` is a drop-in replacement for Pytorch's, which is only helpful after learning how to use Pytorch's. 26 | 27 | This [tutorial](http://www.telesens.co/2019/04/04/distributed-data-parallel-training-using-pytorch-on-aws/) has a good description of what's going on under the hood and how it's different from `nn.DataParallel`. However, it doesn't have code examples of how to use `nn.DistributedDataParallel`. 28 | 29 | ## Outline 30 | 31 | This tutorial is really directed at people who are already familiar with training neural network models in Pytorch, and I won't go over any of those parts of the code. I'll begin by summarizing the big picture. I then show a minimum working example of training on MNIST using one GPU. I modify this example to train on multiple GPUs, possibly across multiple nodes, and explain the changes line by line. Importantly, I also explain how to run the code. As a bonus, I also demonstrate how to use Apex to do easy mixed-precision distributed training. 32 | 33 | ## The big picture 34 | 35 | Multiprocessing with `DistributedDataParallel` duplicates the model across multiple GPUs, each of which is controlled by one process. (If you want, you can have each process control multiple GPUs, but that should obviously be slower than having one GPU per process. It's also possible to have multiple worker processes that fetch data for each GPU, but I'm going to leave that out for the sake of simplicity.) The GPUs can all be on the same node or spread across multiple nodes. Every process does identical tasks, and each process communicates with all the others. Only gradients are passed between the processes/GPUs so that network communication is less of a bottleneck. 36 | 37 | ![figure](graphics/processes-gpus.png) 38 | 39 | During training, each process loads its own minibatches from disk and passes them to its GPU.
Each GPU does its own forward pass and then its own backward pass, and the gradients are all-reduced across the GPUs. The backward pass runs from the last layer to the first, and a layer's gradients are final as soon as they have been computed, so the gradient all-reduce for those layers runs concurrently with the rest of the backward pass to further alleviate the networking bottleneck. At the end of the backward pass, every process has the averaged gradients, ensuring that the model weights stay synchronized. 40 | 41 | All this requires that the multiple processes, possibly on multiple nodes, are synchronized and communicate. Pytorch does this through its [`distributed.init_process_group`](https://pytorch.org/docs/stable/distributed.html#initialization) function. This function needs to know where to find process 0, so that all the processes can sync up, and the total number of processes to expect. Each individual process also needs to know the total number of processes as well as its rank within the processes and which GPU to use. It's common to call the total number of processes the *world size*. Finally, each process needs to know which slice of the data to work on so that the batches are non-overlapping. Pytorch provides [`torch.utils.data.distributed.DistributedSampler`](https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html) to accomplish this. 42 | 43 | ## Minimum working examples with explanations 44 | 45 | To demonstrate how to do this, I'll create an example that [trains on MNIST](https://github.com/yangkky/distributed_tutorial/blob/master/src/mnist.py), and then modify it to run on [multiple GPUs across multiple nodes](https://github.com/yangkky/distributed_tutorial/blob/master/src/mnist-distributed.py), and finally to also allow [mixed-precision training](https://github.com/yangkky/distributed_tutorial/blob/master/src/mnist-mixed.py). 46 | 47 | ### Without multiprocessing 48 | 49 | First, we import everything we need.
50 | 51 | ```python {.line-numbers} 52 | import os 53 | from datetime import datetime 54 | import argparse 55 | import torch.multiprocessing as mp 56 | import torchvision 57 | import torchvision.transforms as transforms 58 | import torch 59 | import torch.nn as nn 60 | import torch.distributed as dist 61 | from apex.parallel import DistributedDataParallel as DDP 62 | from apex import amp 63 | ``` 64 | 65 | We define a very simple convolutional model for predicting MNIST. 66 | 67 | ```python 68 | class ConvNet(nn.Module): 69 | def __init__(self, num_classes=10): 70 | super(ConvNet, self).__init__() 71 | self.layer1 = nn.Sequential( 72 | nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 73 | nn.BatchNorm2d(16), 74 | nn.ReLU(), 75 | nn.MaxPool2d(kernel_size=2, stride=2)) 76 | self.layer2 = nn.Sequential( 77 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 78 | nn.BatchNorm2d(32), 79 | nn.ReLU(), 80 | nn.MaxPool2d(kernel_size=2, stride=2)) 81 | self.fc = nn.Linear(7*7*32, num_classes) 82 | 83 | def forward(self, x): 84 | out = self.layer1(x) 85 | out = self.layer2(out) 86 | out = out.reshape(out.size(0), -1) 87 | out = self.fc(out) 88 | return out 89 | ``` 90 | 91 | The `main()` function will take in some arguments and run the training function. 92 | 93 | ```python 94 | def main(): 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N') 97 | parser.add_argument('-g', '--gpus', default=1, type=int, 98 | help='number of gpus per node') 99 | parser.add_argument('-nr', '--nr', default=0, type=int, 100 | help='ranking within the nodes') 101 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 102 | help='number of total epochs to run') 103 | args = parser.parse_args() 104 | train(0, args) 105 | ``` 106 | 107 | And here's the train function. 
108 | 109 | ```python 110 | def train(gpu, args): 111 | model = ConvNet() 112 | torch.cuda.set_device(gpu) 113 | model.cuda(gpu) 114 | batch_size = 100 115 | # define loss function (criterion) and optimizer 116 | criterion = nn.CrossEntropyLoss().cuda(gpu) 117 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 118 | # Data loading code 119 | train_dataset = torchvision.datasets.MNIST(root='./data', 120 | train=True, 121 | transform=transforms.ToTensor(), 122 | download=True) 123 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 124 | batch_size=batch_size, 125 | shuffle=True, 126 | num_workers=0, 127 | pin_memory=True) 128 | 129 | start = datetime.now() 130 | total_step = len(train_loader) 131 | for epoch in range(args.epochs): 132 | for i, (images, labels) in enumerate(train_loader): 133 | images = images.cuda(non_blocking=True) 134 | labels = labels.cuda(non_blocking=True) 135 | # Forward pass 136 | outputs = model(images) 137 | loss = criterion(outputs, labels) 138 | 139 | # Backward and optimize 140 | optimizer.zero_grad() 141 | loss.backward() 142 | optimizer.step() 143 | if (i + 1) % 100 == 0 and gpu == 0: 144 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, total_step, 145 | loss.item())) 146 | if gpu == 0: 147 | print("Training complete in: " + str(datetime.now() - start)) 148 | ``` 149 | 150 | Finally, we want to make sure the `main()` function gets called. 151 | 152 | ```python 153 | if __name__ == '__main__': 154 | main() 155 | ``` 156 | 157 | There's definitely some extra stuff in here (the number of gpus and nodes, for example) that we don't need yet, but it's helpful to put the whole skeleton in place. 158 | 159 | We can run this code by opening a terminal and typing `python src/mnist.py -n 1 -g 1 -nr 0`, which will train on a single gpu on a single node. 
160 | 161 | ### With multiprocessing 162 | 163 | To do this with multiprocessing, we need a script that will launch a process for every GPU. Each process needs to know which GPU to use, and where it ranks amongst all the processes that are running. We'll need to run the script on each node. 164 | 165 | Let's take a look at the changes to each function. I've fenced off the new code to make it easy to find. 166 | 167 | ```python 168 | def main(): 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N') 171 | parser.add_argument('-g', '--gpus', default=1, type=int, 172 | help='number of gpus per node') 173 | parser.add_argument('-nr', '--nr', default=0, type=int, 174 | help='ranking within the nodes') 175 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 176 | help='number of total epochs to run') 177 | args = parser.parse_args() 178 | ######################################################### 179 | args.world_size = args.gpus * args.nodes # 180 | os.environ['MASTER_ADDR'] = '10.57.23.164' # 181 | os.environ['MASTER_PORT'] = '8888' # 182 | mp.spawn(train, nprocs=args.gpus, args=(args,)) # 183 | ######################################################### 184 | ``` 185 | 186 | I hand-waved over the arguments in the last section, but now we actually need them. 187 | 188 | - `args.nodes` is the total number of nodes we're going to use. 189 | - `args.gpus` is the number of gpus on each node. 190 | - `args.nr` is the rank of the current node within all the nodes, and goes from 0 to `args.nodes` - 1. 191 | 192 | Now, let's go through the new changes line by line: 193 | 194 | Line 12: Based on the number of nodes and gpus per node, we can calculate the `world_size`, or the total number of processes to run, which is equal to the total number of gpus because we're assigning one gpu to every process. 195 | 196 | Line 13: This tells `torch.distributed` (through the `env://` initialization method used in `train`) what IP address to look at for process 0.
It needs this so that all the processes can sync up initially. 197 | 198 | Line 14: Likewise, this is the port to use when looking for process 0. 199 | 200 | Line 15: Now, instead of running the train function once, we will spawn `args.gpus` processes, each of which runs `train(i, args)`, where `i` goes from 0 to `args.gpus` - 1. Remember, we run the `main()` function on each node, so that in total there will be `args.nodes` * `args.gpus` = `args.world_size` processes. 201 | 202 | Instead of lines 13 and 14, I could have run `export MASTER_ADDR=10.57.23.164` and `export MASTER_PORT=8888` in the terminal. 203 | 204 | Next, let's look at the modifications to `train`. I'll fence the new lines again. 205 | 206 | ```python 207 | def train(gpu, args): 208 | ###################################################################### 209 | rank = args.nr * args.gpus + gpu 210 | dist.init_process_group( 211 | backend='nccl', 212 | init_method='env://', 213 | world_size=args.world_size, 214 | rank=rank 215 | ) 216 | ###################################################################### 217 | 218 | model = ConvNet() 219 | torch.cuda.set_device(gpu) 220 | model.cuda(gpu) 221 | batch_size = 100 222 | # define loss function (criterion) and optimizer 223 | criterion = nn.CrossEntropyLoss().cuda(gpu) 224 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 225 | 226 | ###################################################################### 227 | # Wrap the model 228 | model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu]) 229 | ###################################################################### 230 | 231 | # Data loading code 232 | train_dataset = torchvision.datasets.MNIST(root='./data', 233 | train=True, 234 | transform=transforms.ToTensor(), 235 | download=True) 236 | 237 | ###################################################################### 238 | train_sampler = torch.utils.data.distributed.DistributedSampler( 239 | train_dataset, 240 | 
num_replicas=args.world_size, 241 | rank=rank 242 | ) 243 | ###################################################################### 244 | 245 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 246 | batch_size=batch_size, 247 | ###################################################################### 248 | shuffle=False, # 249 | ###################################################################### 250 | num_workers=0, 251 | pin_memory=True, 252 | ###################################################################### 253 | sampler=train_sampler) # 254 | ###################################################################### 255 | 256 | start = datetime.now() 257 | total_step = len(train_loader) 258 | for epoch in range(args.epochs): 259 | for i, (images, labels) in enumerate(train_loader): 260 | images = images.cuda(non_blocking=True) 261 | labels = labels.cuda(non_blocking=True) 262 | # Forward pass 263 | outputs = model(images) 264 | loss = criterion(outputs, labels) 265 | 266 | # Backward and optimize 267 | optimizer.zero_grad() 268 | loss.backward() 269 | optimizer.step() 270 | if (i + 1) % 100 == 0 and gpu == 0: 271 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, total_step, 272 | loss.item())) 273 | if gpu == 0: 274 | print("Training complete in: " + str(datetime.now() - start)) 275 | ``` 276 | 277 | Line 3: This is the global rank of the process within all of the processes (one process per GPU). We'll use it on line 8. 278 | 279 | Lines 4 - 9: Initialize the process and join up with the other processes. This is "blocking," meaning that no process will continue until all processes have joined. I'm using the `nccl` backend here because the [pytorch docs](https://pytorch.org/docs/stable/distributed.html) say it's the fastest of the available ones. The `init_method` tells the process group where to look for some settings.
In this case, it's looking at environment variables for the `MASTER_ADDR` and `MASTER_PORT`, which we set within `main`. I could have set the world size there as well, as `WORLD_SIZE`, but I'm choosing to set it here as a keyword argument, along with the global rank of the current process. 280 | 281 | Line 22: Wrap the model as a [`DistributedDataParallel`](https://pytorch.org/docs/stable/nn.html#distributeddataparallel) model. This replicates the model onto the GPU for the process. 282 | 283 | Lines 32-36: The [`torch.utils.data.distributed.DistributedSampler`](https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html) makes sure that each process gets a different slice of the training data. 284 | 285 | Lines 42 and 47: Use the `DistributedSampler` instead of shuffling the usual way. The sampler only changes its shuffle order between epochs if you call `train_sampler.set_epoch(epoch)` at the start of every epoch; the snippet above omits that call, so every epoch sees the same order. 286 | 287 | To run this on, say, 4 nodes with 8 GPUs each, we need 4 terminals (one on each node). On node 0 (as set by line 13 in `main`): 288 | 289 | ```python src/mnist-distributed.py -n 4 -g 8 -nr 0``` 290 | 291 | Then, on the other nodes: 292 | 293 | ```python src/mnist-distributed.py -n 4 -g 8 -nr i``` 294 | 295 | for $i \in \{1, 2, 3\}$. In other words, we run this script on each node, telling it to launch `args.gpus` processes that sync with each other before training begins. 296 | 297 | Note that the effective batch size is now the per-GPU batch size (the value in the script) multiplied by the total number of GPUs (the world size). 298 | 299 | 300 | ### With Apex for mixed precision 301 | 302 | Mixed precision training (training in a combination of float (FP32) and half (FP16) precision) allows us to use larger batch sizes and take advantage of NVIDIA [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensorcore/) for faster computation. AWS [p3](https://aws.amazon.com/ec2/instance-types/p3/) instances use NVIDIA Tesla V100 GPUs with Tensor Cores. We only need to change the `train` function.
For the sake of concision, I've taken out the data loading code and the code after the backwards pass from the example here, replacing them with `...`, but they are still in the [full script](https://github.com/yangkky/distributed_tutorial/blob/master/src/mnist-mixed.py). 303 | 304 | ```python 305 | def train(gpu, args): 306 | rank = args.nr * args.gpus + gpu 307 | dist.init_process_group( 308 | backend='nccl', 309 | init_method='env://', 310 | world_size=args.world_size, 311 | rank=rank) 312 | 313 | model = ConvNet() 314 | torch.cuda.set_device(gpu) 315 | model.cuda(gpu) 316 | batch_size = 100 317 | # define loss function (criterion) and optimizer 318 | criterion = nn.CrossEntropyLoss().cuda(gpu) 319 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 320 | # Wrap the model 321 | ###################################################################### 322 | model, optimizer = amp.initialize(model, optimizer, opt_level='O2') 323 | model = DDP(model) 324 | ###################################################################### 325 | # Data loading code 326 | ... 327 | start = datetime.now() 328 | total_step = len(train_loader) 329 | for epoch in range(args.epochs): 330 | for i, (images, labels) in enumerate(train_loader): 331 | images = images.cuda(non_blocking=True) 332 | labels = labels.cuda(non_blocking=True) 333 | # Forward pass 334 | outputs = model(images) 335 | loss = criterion(outputs, labels) 336 | 337 | # Backward and optimize 338 | optimizer.zero_grad() 339 | ###################################################################### 340 | with amp.scale_loss(loss, optimizer) as scaled_loss: 341 | scaled_loss.backward() 342 | ###################################################################### 343 | optimizer.step() 344 | ... 345 | ``` 346 | 347 | Line 18: [`amp.initialize`](https://nvidia.github.io/apex/amp.html#unified-api) wraps the model and optimizer for mixed precision training.
Note that the model must already be on the correct GPU before calling `amp.initialize`. The `opt_level` goes from `O0`, which uses all floats, through `O3`, which uses half-precision throughout. `O1` and `O2` are different degrees of mixed-precision, the details of which can be found in the Apex [documentation](https://nvidia.github.io/apex/amp.html#opt-levels-and-properties). Yes, the first character in all those codes is a capital letter 'O', while the second character is a number. Yes, if you use a zero instead, you will get a baffling error message. 348 | 349 | Line 19: [`apex.parallel.DistributedDataParallel`](https://nvidia.github.io/apex/parallel.html) is a drop-in replacement for `nn.DistributedDataParallel`. We no longer have to specify the GPUs because Apex only allows one GPU per process. It also assumes that the script calls `torch.cuda.set_device(local_rank)` (line 10) before moving the model to GPU. 350 | 351 | Lines 36-37: Mixed-precision training requires that the loss is [scaled](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) in order to prevent the gradients from underflowing. Apex does this automatically. 352 | 353 | This script is run the same way as the distributed training script. 354 | 355 | ## Acknowledgments 356 | 357 | Many thanks to the computational team at VL56 for all your work on various parts of this. I'd like to especially thank Stephen Kottman, who got an MWE up while I was still trying to figure out how multiprocessing in Python works, and then explained it to me, and Andy Beam, who greatly improved the first draft of this tutorial. --------------------------------------------------------------------------------