├── data
│   ├── image.png
│   └── distributed.py
├── run.sh
├── from_mingfei
│   ├── run.sh
│   ├── bs128_n1.log
│   ├── bs128_n2.log
│   └── mnist_dist.py
├── bs128_n1.log
├── bs128_n2_allreduce_sum.log
├── bs128_n2_allreduce_average.log
├── README.md
└── mnist_dist.py

/data/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xhzhao/PyTorch-MPI-DDP-example/HEAD/data/image.png
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/bin/sh

#salloc -p skx-6148-debug -N 2

#n1
#python mnist_dist.py

#n2
mpirun -n 2 python -u mnist_dist.py # this works OK
--------------------------------------------------------------------------------
/from_mingfei/run.sh:
--------------------------------------------------------------------------------
#!/bin/sh

#salloc -p skx-6148-debug -N 2

#n1
#python mnist_dist.py

#n2
mpirun -n 2 python -u mnist_dist.py # this works OK
--------------------------------------------------------------------------------
/bs128_n1.log:
--------------------------------------------------------------------------------
[0/1] first broadcast start
[0/1] first broadcast done
[0/1] Epoch 0 Loss 1.303139 Global batch size 128 on 1 ranks
[0/1] Epoch 1 Loss 0.550393 Global batch size 128 on 1 ranks
[0/1] Epoch 2 Loss 0.427802 Global batch size 128 on 1 ranks
[0/1] Epoch 3 Loss 0.362163 Global batch size 128 on 1 ranks
[0/1] Epoch 4 Loss 0.319644 Global batch size 128 on 1 ranks
[0/1] Epoch 5 Loss 0.291402 Global batch size 128 on 1 ranks
[0/1] Epoch 6 Loss 0.268326 Global batch size 128 on 1 ranks
[0/1] Epoch 7 Loss 0.257067 Global batch size 128 on 1 ranks
[0/1] Epoch 8 Loss 0.239572 Global batch size 128 on 1 ranks
[0/1] Epoch 9 Loss 0.221816 Global batch size 128 on 1 ranks

--------------------------------------------------------------------------------
/from_mingfei/bs128_n1.log:
--------------------------------------------------------------------------------
xhzhao@xhzhao-ub:~/tools/PyTorch-MPI-DDP-example/from_mingfei$ python mnist_dist.py
[0/1] Epoch 0 Loss 1.316212 Global batch size 128 on 1 ranks
[0/1] Epoch 1 Loss 0.546703 Global batch size 128 on 1 ranks
[0/1] Epoch 2 Loss 0.425036 Global batch size 128 on 1 ranks
[0/1] Epoch 3 Loss 0.359292 Global batch size 128 on 1 ranks
[0/1] Epoch 4 Loss 0.315650 Global batch size 128 on 1 ranks
[0/1] Epoch 5 Loss 0.290439 Global batch size 128 on 1 ranks
[0/1] Epoch 6 Loss 0.263221 Global batch size 128 on 1 ranks
[0/1] Epoch 7 Loss 0.251407 Global batch size 128 on 1 ranks
[0/1] Epoch 8 Loss 0.238773 Global batch size 128 on 1 ranks
[0/1] Epoch 9 Loss 0.225228 Global batch size 128 on 1 ranks

--------------------------------------------------------------------------------
/from_mingfei/bs128_n2.log:
--------------------------------------------------------------------------------
xhzhao@xhzhao-ub:~/tools/PyTorch-MPI-DDP-example/from_mingfei$ ./run.sh
[0/2] Epoch 0 Loss 0.970655 Global batch size 128 on 2 ranks
[1/2] Epoch 0 Loss 0.957377 Global batch size 128 on 2 ranks
[0/2] Epoch 1 Loss 0.409037 Global batch size 128 on 2 ranks
[1/2] Epoch 1 Loss 0.407893 Global batch size 128 on 2 ranks
[0/2] Epoch 2 Loss 0.315330 Global batch size 128 on 2 ranks
[1/2] Epoch 2 Loss 0.310589 Global batch size 128 on 2 ranks
[0/2] Epoch 3 Loss 0.269583 Global batch size 128 on 2 ranks
[1/2] Epoch 3 Loss 0.268614 Global batch size 128 on 2 ranks
[0/2] Epoch 4 Loss 0.245311 Global batch size 128 on 2 ranks
[1/2] Epoch 4 Loss 0.238192 Global batch size 128 on 2 ranks
[0/2] Epoch 5 Loss 0.223034 Global batch size 128 on 2 ranks
[1/2] Epoch 5 Loss 0.221292 Global batch size 128 on 2 ranks
[0/2] Epoch 6 Loss 0.207834 Global batch size 128 on 2 ranks
[1/2] Epoch 6 Loss 0.204415 Global batch size 128 on 2 ranks
[0/2] Epoch 7 Loss 0.199458 Global batch size 128 on 2 ranks
[1/2] Epoch 7 Loss 0.192848 Global batch size 128 on 2 ranks
[0/2] Epoch 8 Loss 0.184656 Global batch size 128 on 2 ranks
[1/2] Epoch 8 Loss 0.179757 Global batch size 128 on 2 ranks
[0/2] Epoch 9 Loss 0.177797 Global batch size 128 on 2 ranks
[1/2] Epoch 9 Loss 0.174681 Global batch size 128 on 2 ranks

--------------------------------------------------------------------------------
/bs128_n2_allreduce_sum.log:
--------------------------------------------------------------------------------
xhzhao@xhzhao-ub:~/tools/PyTorch-MPI-DDP-example$ ./run.sh
[1/2] first broadcast start
[0/2] first broadcast start
[0/2] first broadcast done
[1/2] first broadcast done
[0/2] Epoch 0 Loss 0.956670 Global batch size 128 on 2 ranks
[1/2] Epoch 0 Loss 0.961719 Global batch size 128 on 2 ranks
[0/2] Epoch 1 Loss 0.405841 Global batch size 128 on 2 ranks
[1/2] Epoch 1 Loss 0.397616 Global batch size 128 on 2 ranks
[1/2] Epoch 2 Loss 0.318296 Global batch size 128 on 2 ranks
[0/2] Epoch 2 Loss 0.317754 Global batch size 128 on 2 ranks
[0/2] Epoch 3 Loss 0.272213 Global batch size 128 on 2 ranks
[1/2] Epoch 3 Loss 0.271408 Global batch size 128 on 2 ranks
[0/2] Epoch 4 Loss 0.242495 Global batch size 128 on 2 ranks
[1/2] Epoch 4 Loss 0.242542 Global batch size 128 on 2 ranks
[0/2] Epoch 5 Loss 0.219803 Global batch size 128 on 2 ranks
[1/2] Epoch 5 Loss 0.216617 Global batch size 128 on 2 ranks
[0/2] Epoch 6 Loss 0.206899 Global batch size 128 on 2 ranks
[1/2] Epoch 6 Loss 0.201472 Global batch size 128 on 2 ranks
[1/2] Epoch 7 Loss 0.197499 Global batch size 128 on 2 ranks
[0/2] Epoch 7 Loss 0.194447 Global batch size 128 on 2 ranks
[0/2] Epoch 8 Loss 0.185426 Global batch size 128 on 2 ranks
[1/2] Epoch 8 Loss 0.181102 Global batch size 128 on 2 ranks
[0/2] Epoch 9 Loss 0.177293 Global batch size 128 on 2 ranks
[1/2] Epoch 9 Loss 0.171482 Global batch size 128 on 2 ranks
--------------------------------------------------------------------------------
/bs128_n2_allreduce_average.log:
--------------------------------------------------------------------------------
xhzhao@xhzhao-ub:~/tools/PyTorch-MPI-DDP-example$ ./run.sh
[1/2] first broadcast start
[0/2] first broadcast start
[0/2] first broadcast done
[1/2] first broadcast done
[0/2] Epoch 0 Loss 1.313255 Global batch size 128 on 2 ranks
[1/2] Epoch 0 Loss 1.315885 Global batch size 128 on 2 ranks
[0/2] Epoch 1 Loss 0.542634 Global batch size 128 on 2 ranks
[1/2] Epoch 1 Loss 0.536465 Global batch size 128 on 2 ranks
[0/2] Epoch 2 Loss 0.427217 Global batch size 128 on 2 ranks
[1/2] Epoch 2 Loss 0.426838 Global batch size 128 on 2 ranks
[0/2] Epoch 3 Loss 0.360737 Global batch size 128 on 2 ranks
[1/2] Epoch 3 Loss 0.358833 Global batch size 128 on 2 ranks
[0/2] Epoch 4 Loss 0.319203 Global batch size 128 on 2 ranks
[1/2] Epoch 4 Loss 0.316800 Global batch size 128 on 2 ranks
[0/2] Epoch 5 Loss 0.286114 Global batch size 128 on 2 ranks
[1/2] Epoch 5 Loss 0.283350 Global batch size 128 on 2 ranks
[0/2] Epoch 6 Loss 0.268937 Global batch size 128 on 2 ranks
[1/2] Epoch 6 Loss 0.264468 Global batch size 128 on 2 ranks
[0/2] Epoch 7 Loss 0.250410 Global batch size 128 on 2 ranks
[1/2] Epoch 7 Loss 0.256080 Global batch size 128 on 2 ranks
[0/2] Epoch 8 Loss 0.237510 Global batch size 128 on 2 ranks
[1/2] Epoch 8 Loss 0.233641 Global batch size 128 on 2 ranks
[0/2] Epoch 9 Loss 0.232977 Global batch size 128 on 2 ranks
[1/2] Epoch 9 Loss 0.224614 Global batch size 128 on 2 ranks
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch-MPI-DDP-example
This repository's goal is to enable MPI-based DDP in PyTorch. Out of the box, PyTorch [DDP](http://pytorch.org/docs/master/nn.html#torch.nn.parallel.DistributedDataParallel) only supports the nccl and gloo backends.

You can enable distributed PyTorch training on the MPI backend by changing only 2 lines (a minimal sketch follows this file):
1. add a DistributedSampler to your DataLoader
2. pass your model to DistributedDataParallel

The usage is exactly the same as torch.nn.parallel.DistributedDataParallel().
See the imagenet example here: https://github.com/pytorch/examples/blob/master/imagenet/main.py#L88

### Requirements
* PyTorch: built from source (v0.3.1 is recommended)


### Usage
bash run.sh

### Strong vs Weak Scaling
This repository implements strong scaling for MNIST, which means the global batch size stays fixed no matter how many ranks we use. See more about strong vs weak scaling on [wiki](https://en.wikipedia.org/wiki/Scalability#Weak_versus_strong_scaling).
Since this is a strong-scaling example, we should average the gradients after the all_reduce, which is what [torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L338) does.

Our experiment result:
![result](https://github.com/xhzhao/PyTorch-MPI-DDP-example/blob/master/data/image.png)

### More examples based on the [PyTorch examples](https://github.com/pytorch/examples/)
* mnist: https://github.com/xhzhao/examples/tree/master/mnist
* imagenet: https://github.com/xhzhao/examples/tree/master/imagenet
--------------------------------------------------------------------------------
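To make the README's two-line change concrete, here is a minimal sketch condensed from mnist_dist.py further down. It is not a file in this repo: the toy linear model and the batch size of 64 are placeholders for your own model and settings, and DistributedDataParallel is the MPI-friendly wrapper defined in data/distributed.py below.

import torch
import torch.utils.data
import torch.utils.data.distributed
import torch.distributed as dist
from torchvision import datasets, transforms

from data.distributed import DistributedDataParallel  # wrapper shipped in this repo

dist.init_process_group(backend='mpi')

dataset = datasets.MNIST('./data', train=True, download=True,
                         transform=transforms.ToTensor())

# (1) add a DistributedSampler so every rank draws a different shard of the data
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, sampler=sampler)

# (2) wrap the model: rank 0's weights are broadcast on the first forward call
#     and gradients are all-reduced (and averaged) after every backward pass.
#     A toy linear model stands in here for the real Net from mnist_dist.py.
model = DistributedDataParallel(torch.nn.Linear(784, 10))

# train as usual; see mnist_dist.py for the full loop

Launch it the same way run.sh launches the real script, e.g. `mpirun -n 2 python -u <your_script>.py`.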
/data/distributed.py:
--------------------------------------------------------------------------------
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module

'''
This version of DistributedDataParallel is designed to be used in conjunction with the DistributedSampler.
You can enable distributed PyTorch training on the MPI backend by changing only 2 lines:
1. add a DistributedSampler to your DataLoader
2. pass your model to DistributedDataParallel
The usage is exactly the same as torch.nn.parallel.DistributedDataParallel().
See the imagenet example here: https://github.com/pytorch/examples/blob/master/imagenet/main.py#L88

Parameters are broadcast from rank 0 to the other processes on the first forward call,
and gradients are all-reduced at the end of the backward pass.
'''


class DistributedDataParallel(Module):
    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()
        self.module = module
        self.first_call = True

        def allreduce_params():
            if (self.needs_reduction):
                self.needs_reduction = False
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)

                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    # flatten all gradients of one dtype into a single buffer,
                    # all-reduce once, then average by the world size
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        for param in list(self.module.parameters()):
            def allreduce_hook(*unused):
                # defer the all-reduce until the whole backward pass has finished
                param._execution_engine.queue_callback(allreduce_params)

            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def weight_broadcast(self):
        # broadcast rank 0's parameters so every rank starts from the same weights
        for param in self.module.parameters():
            dist.broadcast(param.data, 0)
        """
        for p in self.module.state_dict().values():
            if not torch.is_tensor(p):
                continue
            dist.broadcast(p, 0)
        """

    def forward(self, *inputs, **kwargs):
        if self.first_call:
            print("first broadcast start")
            self.weight_broadcast()
            self.first_call = False
            print("first broadcast done")
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)
--------------------------------------------------------------------------------
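The wrapper above sums gradients with all_reduce and then divides by the world size; as the README's strong-scaling note says, sum-then-divide is exactly a mean over ranks. A small local sketch (no MPI needed; three plain tensors stand in for the same gradient on three ranks) illustrates the equivalence:

import torch

# Pretend these are the same parameter's gradient on three different ranks.
rank_grads = [torch.FloatTensor([1.0, 2.0]),
              torch.FloatTensor([3.0, 4.0]),
              torch.FloatTensor([5.0, 6.0])]
world_size = len(rank_grads)

# dist.all_reduce with the default SUM op would leave this on every rank ...
summed = rank_grads[0] + rank_grads[1] + rank_grads[2]
# ... and dividing by the world size, as allreduce_params does, gives the mean.
averaged = summed / world_size

print(averaged)                              # 3, 4
print(torch.stack(rank_grads).mean(0))       # 3, 4 -- the same thing

Skipping the division effectively multiplies the gradient (and hence the learning rate) by the number of ranks, which is the difference between bs128_n2_allreduce_sum.log and bs128_n2_allreduce_average.log.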
""" 35 | 36 | def __init__(self, data, index): 37 | self.data = data 38 | self.index = index 39 | 40 | def __len__(self): 41 | return len(self.index) 42 | 43 | def __getitem__(self, index): 44 | data_idx = self.index[index] 45 | return self.data[data_idx] 46 | 47 | 48 | class DataPartitioner(object): 49 | """ Partitions a dataset into different chuncks. """ 50 | 51 | def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234): 52 | self.data = data 53 | self.partitions = [] 54 | rng = Random() 55 | rng.seed(seed) 56 | data_len = len(data) 57 | indexes = [x for x in range(0, data_len)] 58 | rng.shuffle(indexes) 59 | """ 60 | Be cautious about index shuffle, this is performed on each rank 61 | The shuffled index must be unique across all ranks 62 | Theoretically with the same seed Random() generates the same sequence 63 | This might not be true in rare cases 64 | You can add an additional synchronization for 'indexes', just for safety 65 | Anyway, this won't take too much time 66 | e.g. 67 | dist.broadcast(indexes, 0) 68 | """ 69 | for frac in sizes: 70 | part_len = int(frac * data_len) 71 | self.partitions.append(indexes[0:part_len]) 72 | indexes = indexes[part_len:] 73 | 74 | def use(self, partition): 75 | return Partition(self.data, self.partitions[partition]) 76 | 77 | 78 | class Net(nn.Module): 79 | """ Network architecture. """ 80 | 81 | def __init__(self): 82 | super(Net, self).__init__() 83 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 84 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 85 | self.conv2_drop = nn.Dropout2d() 86 | self.fc1 = nn.Linear(320, 50) 87 | self.fc2 = nn.Linear(50, 10) 88 | 89 | def forward(self, x): 90 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 91 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 92 | x = x.view(-1, 320) 93 | x = F.relu(self.fc1(x)) 94 | x = F.dropout(x, training=self.training) 95 | x = self.fc2(x) 96 | return F.log_softmax(x, dim=1) 97 | 98 | 99 | def partition_dataset(): 100 | """ Partitioning MNIST """ 101 | dataset = datasets.MNIST( 102 | './data', 103 | train=True, 104 | download=True, 105 | transform=transforms.Compose([ 106 | transforms.ToTensor(), 107 | transforms.Normalize((0.1307, ), (0.3081, )) 108 | ])) 109 | size = dist.get_world_size() 110 | bsz = gbatch_size / float(size) 111 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 112 | train_set = torch.utils.data.DataLoader( 113 | dataset, batch_size=bsz, shuffle=(train_sampler is None), sampler=train_sampler) 114 | return train_set, bsz 115 | 116 | def run(rank, size): 117 | """ Distributed Synchronous SGD Example """ 118 | torch.manual_seed(1234) 119 | train_set, bsz = partition_dataset() 120 | model = Net() 121 | model = DistributedDataParallel(model) 122 | optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) 123 | 124 | num_batches = ceil(len(train_set.dataset) / (float(bsz) * dist.get_world_size())) 125 | #print("num_batches = ", num_batches) 126 | for epoch in range(10): 127 | epoch_loss = 0.0 128 | for data, target in train_set: 129 | data, target = Variable(data), Variable(target) 130 | optimizer.zero_grad() 131 | output = model(data) 132 | loss = F.nll_loss(output, target) 133 | epoch_loss += loss.data[0] 134 | loss.backward() 135 | optimizer.step() 136 | print('Epoch {} Loss {:.6f} Global batch size {} on {} ranks'.format( 137 | epoch, epoch_loss / num_batches, gbatch_size, dist.get_world_size())) 138 | 139 | def init_print(rank, size, debug_print=True): 140 | if not debug_print: 141 | """ In case run on hundreds of nodes, 
/from_mingfei/mnist_dist.py:
--------------------------------------------------------------------------------
"""
Synchronous SGD training on MNIST
Use the distributed MPI backend

PyTorch distributed tutorial:
http://pytorch.org/tutorials/intermediate/dist_tuto.html

This example makes the following updates upon the tutorial:
1. Add params sync at beginning of each epoch
2. Allreduce gradients across ranks, not averaging
3. Sync the shuffled index during data partition
4. Remove torch.multiprocessing in __main__
"""
import os
import sys
import torch
import torch.utils.data  # needed for torch.utils.data.DataLoader below
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from math import ceil
from random import Random
from torch.multiprocessing import Process
from torch.autograd import Variable
from torchvision import datasets, transforms

gbatch_size = 128

class Partition(object):
    """ Dataset-like object, but only access a subset of it. """

    def __init__(self, data, index):
        self.data = data
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        data_idx = self.index[index]
        return self.data[data_idx]


class DataPartitioner(object):
    """ Partitions a dataset into different chunks. """

    def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):
        self.data = data
        self.partitions = []
        rng = Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = [x for x in range(0, data_len)]
        rng.shuffle(indexes)
        """
        Be cautious about the index shuffle: it is performed on each rank,
        and the shuffled index order must be identical across all ranks.
        Theoretically, Random() with the same seed generates the same sequence,
        but this might not be true in rare cases.
        You can add an additional synchronization for 'indexes', just for safety;
        it won't take much time, e.g.
        dist.broadcast(indexes, 0)
        """
        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[0:part_len])
            indexes = indexes[part_len:]

    def use(self, partition):
        return Partition(self.data, self.partitions[partition])

""" 77 | 78 | def __init__(self): 79 | super(Net, self).__init__() 80 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 81 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 82 | self.conv2_drop = nn.Dropout2d() 83 | self.fc1 = nn.Linear(320, 50) 84 | self.fc2 = nn.Linear(50, 10) 85 | 86 | def forward(self, x): 87 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 88 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 89 | x = x.view(-1, 320) 90 | x = F.relu(self.fc1(x)) 91 | x = F.dropout(x, training=self.training) 92 | x = self.fc2(x) 93 | return F.log_softmax(x, dim=1) 94 | 95 | 96 | def partition_dataset(): 97 | """ Partitioning MNIST """ 98 | dataset = datasets.MNIST( 99 | './data', 100 | train=True, 101 | download=True, 102 | transform=transforms.Compose([ 103 | transforms.ToTensor(), 104 | transforms.Normalize((0.1307, ), (0.3081, )) 105 | ])) 106 | size = dist.get_world_size() 107 | bsz = gbatch_size / float(size) 108 | partition_sizes = [1.0 / size for _ in range(size)] 109 | partition = DataPartitioner(dataset, partition_sizes) 110 | partition = partition.use(dist.get_rank()) 111 | train_set = torch.utils.data.DataLoader( 112 | partition, batch_size=bsz, shuffle=True) 113 | return train_set, bsz 114 | 115 | def sync_params(model): 116 | """ broadcast rank 0 parameter to all ranks """ 117 | for param in model.parameters(): 118 | dist.broadcast(param.data, 0) 119 | 120 | def sync_grads(model): 121 | """ all_reduce grads from all ranks """ 122 | for param in model.parameters(): 123 | dist.all_reduce(param.grad.data) 124 | 125 | def run(rank, size): 126 | """ Distributed Synchronous SGD Example """ 127 | torch.manual_seed(1234) 128 | train_set, bsz = partition_dataset() 129 | model = Net() 130 | model = model 131 | optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) 132 | 133 | num_batches = ceil(len(train_set.dataset) / float(bsz)) 134 | #print("num_batches = ", num_batches) 135 | for epoch in range(10): 136 | epoch_loss = 0.0 137 | # make sure we have the same parameters for all ranks 138 | sync_params(model) 139 | for data, target in train_set: 140 | data, target = Variable(data), Variable(target) 141 | optimizer.zero_grad() 142 | output = model(data) 143 | loss = F.nll_loss(output, target) 144 | epoch_loss += loss.data[0] 145 | loss.backward() 146 | # all_reduce grads 147 | sync_grads(model) 148 | optimizer.step() 149 | print('Epoch {} Loss {:.6f} Global batch size {} on {} ranks'.format( 150 | epoch, epoch_loss / num_batches, gbatch_size, dist.get_world_size())) 151 | 152 | def init_print(rank, size, debug_print=True): 153 | if not debug_print: 154 | """ In case run on hundreds of nodes, you may want to mute all the nodes except master """ 155 | if rank > 0: 156 | sys.stdout = open(os.devnull, 'w') 157 | sys.stderr = open(os.devnull, 'w') 158 | else: 159 | # labelled print with info of [rank/size] 160 | old_out = sys.stdout 161 | class LabeledStdout: 162 | def __init__(self, rank, size): 163 | self._r = rank 164 | self._s = size 165 | self.flush = sys.stdout.flush 166 | 167 | def write(self, x): 168 | if x == '\n': 169 | old_out.write(x) 170 | else: 171 | old_out.write('[%d/%d] %s' % (self._r, self._s, x)) 172 | 173 | sys.stdout = LabeledStdout(rank, size) 174 | 175 | if __name__ == "__main__": 176 | dist.init_process_group(backend='mpi') 177 | size = dist.get_world_size() 178 | rank = dist.get_rank() 179 | init_print(rank, size) 180 | 181 | run(rank, size) 182 | --------------------------------------------------------------------------------