├── README.md
├── environment.yml
├── test_dataparallel.py
└── train_model.py

/README.md:
--------------------------------------------------------------------------------
# Tutorial to implement different versions of all-reduce

## Create and activate conda environment
1. Install the Anaconda Python virtual environment manager (miniconda is recommended)
2. Create the conda environment and install the required packages: `conda env create -n allreduce_env -f environment.yml`
3. Activate the environment: `conda activate allreduce_env`

## Install PyCharm
https://www.jetbrains.com/pycharm/download/#section=windows

## Modify allreduce implementation
Search for `# Modify gradient allreduce here` in `train_model.py` and update the code there. Replace the star-allreduce code with a ring-allreduce as described in:
1. https://towardsdatascience.com/distributed-deep-learning-with-horovod-2d1eea004cb2
2. https://towardsdatascience.com/visual-intuition-on-ring-allreduce-for-distributed-deep-learning-d1f34b4911da

A minimal single-process sketch of the ring exchange pattern is appended at the end of this document for orientation.

## Run tests to make sure that the all-reduce produces correct weights
```bash
python test_dataparallel.py
```

## Run neural network training on the MNIST dataset
Training is run for the reference and data-parallel models simultaneously.
```bash
python train_model.py
```

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: mnist
channels:
  - conda-forge
  - defaults
dependencies:
  - atomicwrites=1.4.0=pyh9f0ad1d_0
  - attrs=20.3.0=pyhd3deb0d_0
  - ca-certificates=2020.12.5=h5b45459_0
  - certifi=2020.12.5=py38haa244fe_1
  - colorama=0.4.4=pyh9f0ad1d_0
  - iniconfig=1.1.1=pyh9f0ad1d_0
  - intel-openmp=2021.2.0=h57928b3_616
  - libblas=3.9.0=8_mkl
  - libcblas=3.9.0=8_mkl
  - liblapack=3.9.0=8_mkl
  - mkl=2020.4=hb70f87d_311
  - more-itertools=8.7.0=pyhd8ed1ab_0
  - numpy=1.20.2=py38h09042cb_0
  - openssl=1.1.1k=h8ffe710_0
  - packaging=20.9=pyh44b312d_0
  - pip=21.0.1=pyhd8ed1ab_0
  - pluggy=0.13.1=py38haa244fe_4
  - py=1.10.0=pyhd3deb0d_0
  - pyparsing=2.4.7=pyh9f0ad1d_0
  - pytest=6.2.3=py38haa244fe_0
  - python=3.8.8=h7840368_0_cpython
  - python_abi=3.8=1_cp38
  - setuptools=49.6.0=py38haa244fe_3
  - sqlite=3.35.4=h8ffe710_0
  - toml=0.10.2=pyhd8ed1ab_0
  - vc=14.2=hb210afc_4
  - vs2015_runtime=14.28.29325=h5e1d092_4
  - wheel=0.36.2=pyhd3deb0d_0
  - wincertstore=0.2=py38haa244fe_1006
  - pip:
    - pillow==8.2.0
    - torch==1.8.1
    - torch-testing==0.0.2
    - torchvision==0.9.1
    - typing-extensions==3.7.4.3

--------------------------------------------------------------------------------
/test_dataparallel.py:
--------------------------------------------------------------------------------
import unittest

import torch_testing as tt

from train_model import Trainer, parse_args


class TestDataParallel(unittest.TestCase):

    def test_dataparallel(self):
        args = parse_args(external_args=[])
        trainer = Trainer(args)

        trainer.reference_model.train(False)
        trainer.dataparallel_model.train(False)

        def _compare_models():
            for i_layer, (ref_np, dp_np) in enumerate(zip(
                    trainer.reference_model.named_parameters(),
                    trainer.dataparallel_model.named_parameters())):

                if i_layer == 0:
                    print(ref_np[0], dp_np[0])
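                    # Debug dump for the first layer only: show the weights and
                    # gradients of the reference model and of the data-parallel
                    # master model side by side.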
                    print("Weights:")
                    print(ref_np[1].data[0, 0, ...])
                    print(dp_np[1].data[0, 0, ...])
                    print("Grads:")
                    if ref_np[1].grad is not None:
                        print(ref_np[1].grad[0, 0, ...])
                    else:
                        print("None")
                    if dp_np[1].grad is not None:
                        print(dp_np[1].grad[0, 0, ...])
                    else:
                        print("None")
                    print("")

                rtol = 2e-2
                atol = 1e-7
                tt.assert_allclose(ref_np[1].data, dp_np[1].data, rtol=rtol, atol=atol)
                if ref_np[1].grad is not None and dp_np[1].grad is not None:
                    tt.assert_allclose(ref_np[1].grad, dp_np[1].grad, rtol=rtol)

        def _check_dp_models_equal():
            dp_model = trainer.dataparallel_model
            for i_model, model in enumerate(dp_model.models):
                if i_model == dp_model.master_model_idx:
                    continue
                master_model_params = dp_model.models[dp_model.master_model_idx].parameters()
                model_params = model.parameters()
                for i_layer, (master_param, secondary_param) in \
                        enumerate(zip(master_model_params, model_params)):
                    if i_layer == 0:
                        print(f"Master model and model {i_model}")
                        print(master_param[0, 0, ...])
                        print(secondary_param[0, 0, ...])
                    # Important that after all-reduced gradients are applied,
                    # all replica weights are bit-exactly equal even as float32 values!
                    tt.assert_equal(master_param, secondary_param)

        print("Before step")
        _compare_models()
        _check_dp_models_equal()

        for batch_idx, (data, target) in enumerate(trainer.train_loader):
            data, target = data.to(trainer.device), target.to(trainer.device)

            step_info_ref = trainer.reference_model.step(data, target)
            ref_loss = step_info_ref["loss"]

            step_info_dp = trainer.dataparallel_model.step(data, target)
            dp_loss = step_info_dp["loss"]

            print("After step")
            print(f"Loss, reference={ref_loss} dp={dp_loss}")
            _compare_models()
            _check_dp_models_equal()

            break

        return


if __name__ == "__main__":
    unittest.main()

--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
import argparse
from functools import partial
import contextlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


class GenericModel:
    def __init__(self):
        self.is_train = True

    def train(self, is_train: bool = True):
        self.is_train = is_train


class ReferenceModel(GenericModel):
    def __init__(self, net_factory, optimizer_factory, lr_scheduler_factory, loss_functor):
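        # Single-replica baseline: one network, one optimizer and one LR scheduler,
        # trained on the full batch. It serves as the ground truth that the
        # data-parallel wrapper is compared against in test_dataparallel.py.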
        super().__init__()

        self.net_factory = net_factory
        self.optimizer_factory = optimizer_factory
        self.lr_scheduler_factory = lr_scheduler_factory
        self.loss_functor = loss_functor

        self.model: nn.Module = net_factory()
        self.optimizer = optimizer_factory(self.model.parameters())
        self.scheduler = self.lr_scheduler_factory(self.optimizer)

    def step(self, data, target, no_grad=False, dry_run=False):
        if not no_grad:
            self.optimizer.zero_grad()
        with torch.no_grad() if no_grad else contextlib.nullcontext():
            output = self.model(data)
            loss = self.loss_functor(output, target)
            if not no_grad:
                loss.backward()
        if not dry_run and not no_grad:
            self.optimizer.step()
        return dict(loss=loss.item(), output=output.detach())

    def get_model(self):
        return self.model

    def named_parameters(self):
        return self.model.named_parameters()

    def train(self, is_train: bool = True):
        super().train(is_train)
        self.model.train(is_train)

    def lr_scheduler_step(self):
        self.scheduler.step()


class DataparallelModel(GenericModel):
    def __init__(self, net_factory, optimizer_factory, lr_scheduler_factory,
                 loss_functor, num_replicas, reference_model=None):
        super().__init__()

        self.net_factory = net_factory
        self.optimizer_factory = optimizer_factory
        self.lr_scheduler_factory = lr_scheduler_factory
        self.loss_functor = loss_functor
        self.num_replicas = num_replicas

        models = []
        optimizers = []
        schedulers = []
        for i_replica in range(num_replicas):
            model = net_factory()
            models.append(model)
            optimizer = self.optimizer_factory(model.parameters())
            optimizers.append(optimizer)
            scheduler = self.lr_scheduler_factory(optimizer)
            schedulers.append(scheduler)

        self.models = models
        self.optimizers = optimizers
        self.schedulers = schedulers

        self.master_model_idx = 0

        if reference_model is not None:
            # If a reference model is given, broadcast it
            for param_group, ref_param in \
                    zip(self.param_group_gen(), reference_model.named_parameters()):
                for param in param_group:
                    param.data[...] = ref_param[1].data[...]
        else:
            # If there is no reference model, broadcast weights of the master model
            for param_group in self.param_group_gen():
                assert self.master_model_idx == 0
                for param in param_group[1:]:
                    param.data[...] = param_group[self.master_model_idx].data[...]
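        # After this initial broadcast every replica starts from bit-identical
        # weights, mirroring the parameter broadcast done at startup in real
        # data-parallel training.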

    def param_group_gen(self):
        param_groups = [m.parameters() for m in self.models]
        for group in zip(*param_groups):
            yield group

    def step(self, data, target, no_grad=False, dry_run=False):

        assert len(data) % self.num_replicas == 0
        offset = len(data) // self.num_replicas

        losses = []
        outputs = []
        for i_replica, (model, optimizer) in enumerate(zip(self.models, self.optimizers)):
            data_rep = data[i_replica*offset:(i_replica+1)*offset]
            target_rep = target[i_replica*offset:(i_replica+1)*offset]
            if not no_grad:
                optimizer.zero_grad()
            with torch.no_grad() if no_grad else contextlib.nullcontext():
                output = model(data_rep)
                loss = self.loss_functor(output, target_rep)
                if not no_grad:
                    loss.backward()
            losses.append(loss.item())
            outputs.append(output.detach())

        total_loss = np.mean(np.array(losses))

        outputs = torch.cat(outputs, dim=0)

        if not no_grad:
            for param_group in self.param_group_gen():
                param_group_data = tuple(p.grad for p in param_group)

                # Modify gradient allreduce here
                # Below is a star-allreduce implementation. Replace it with your own.
                reduced_tensor = torch.mean(torch.stack(param_group_data, dim=0), dim=0)
                for grad in param_group_data:
                    grad[...] = reduced_tensor[...]

        if not dry_run and not no_grad:
            for i_replica, (model, optimizer) in enumerate(zip(self.models, self.optimizers)):
                optimizer.step()

        # check all replica weights are identical

        return dict(loss=total_loss, pred=outputs)

    def named_parameters(self):
        assert len(self.models) > 0
        return self.models[self.master_model_idx].named_parameters()

    def get_model(self):
        assert len(self.models) > 0
        return self.models[self.master_model_idx]

    def train(self, is_train: bool = True):
        super().train(is_train)
        for model in self.models:
            model.train(is_train)

    def lr_scheduler_step(self):
        for scheduler in self.schedulers:
            scheduler.step()


class Trainer:
    def __init__(self, args):
        self.args = args

        use_cuda = not args.no_cuda and torch.cuda.is_available()

        torch.manual_seed(args.seed)

        self.device = torch.device("cuda" if use_cuda else "cpu")

        train_kwargs = {'batch_size': args.batch_size}
        test_kwargs = {'batch_size': args.test_batch_size}
        if use_cuda:
            cuda_kwargs = {'num_workers': 1,
                           'pin_memory': True,
                           'shuffle': True}
            train_kwargs.update(cuda_kwargs)
            test_kwargs.update(cuda_kwargs)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        # Use the MNIST train split for training and the held-out test split for validation
        self.dataset_train = datasets.MNIST('../data', train=True, download=True, transform=transform)
        self.dataset_val = datasets.MNIST('../data', train=False, transform=transform)

        shrink_dataset = True
        if shrink_dataset:
            train_data_size = 2000
            val_data_size = 1000

            def _shrink_dataset(dataset, size):
                dataset.data = dataset.data[:size]
                dataset.targets = dataset.targets[:size]

            _shrink_dataset(self.dataset_train, train_data_size)
            _shrink_dataset(self.dataset_val, val_data_size)

        self.train_loader = torch.utils.data.DataLoader(self.dataset_train, **train_kwargs)
        self.test_loader = torch.utils.data.DataLoader(self.dataset_val, **test_kwargs)
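        # The factory functions below are passed to both model wrappers so that the
        # data-parallel wrapper can build an independent network, optimizer and LR
        # scheduler for every replica.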

        def net_factory():
            return Net()

        def optimizer_factory(params):
            return optim.Adadelta(params, lr=args.lr)

        self.loss_func = partial(F.nll_loss, reduction="mean")

        def lr_scheduler_factory(optimizer):
            return StepLR(optimizer, step_size=1, gamma=args.gamma)

        self.reference_model = ReferenceModel(net_factory, optimizer_factory,
                                              lr_scheduler_factory, self.loss_func)

        num_replicas = args.num_replicas
        self.dataparallel_model = DataparallelModel(net_factory, optimizer_factory, lr_scheduler_factory,
                                                    self.loss_func, num_replicas,
                                                    reference_model=self.reference_model.get_model())

    def train(self):
        for epoch in range(1, self.args.epochs + 1):
            self.train_epoch(epoch)
            self.test()
            self.reference_model.lr_scheduler_step()
            self.dataparallel_model.lr_scheduler_step()

        if self.args.save_model:
            # Save the two models under different names so the checkpoints do not overwrite each other
            torch.save(self.reference_model.get_model().state_dict(), "mnist_cnn_ref.pt")
            torch.save(self.dataparallel_model.get_model().state_dict(), "mnist_cnn_dp.pt")

    def train_epoch(self, epoch):
        self.reference_model.train(True)
        self.dataparallel_model.train(True)

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)

            step_info_ref = self.reference_model.step(data, target)
            ref_loss = step_info_ref["loss"]

            step_info_dp = self.dataparallel_model.step(data, target)
            dp_loss = step_info_dp["loss"]

            if batch_idx % self.args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tRef loss: {:.6f}\tDP loss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), ref_loss, dp_loss))

    def test(self):
        self.dataparallel_model.train(False)

        test_loss = 0
        correct = 0
        for data, target in self.test_loader:
            data, target = data.to(self.device), target.to(self.device)
            step_info_dp = self.dataparallel_model.step(data, target, no_grad=True)
            test_loss += step_info_dp["loss"] * len(data)  # step loss is a batch mean; accumulate it as a per-sample sum
            pred = step_info_dp["pred"].argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.test_loader.dataset)

        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(self.test_loader.dataset),
            100. * correct / len(self.test_loader.dataset)))
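# parse_args() accepts an explicit argument list so that the test suite can build a
# default configuration without touching sys.argv (see test_dataparallel.py).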
def parse_args(external_args=None):
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--num-replicas', type=int, default=4, metavar='N',
                        help='number of dataparallel replicas (default: 4)')
    if external_args is not None:
        args = parser.parse_args(external_args)
    else:
        args = parser.parse_args()
    return args


def main():
    args = parse_args()
    trainer = Trainer(args)
    trainer.train()


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
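For the "Modify allreduce implementation" step in README.md, the sketch below illustrates the ring-allreduce chunk-exchange pattern described in the linked articles, written as a single in-process function over a list of equally shaped tensors (one gradient buffer per replica), in the same spirit as the in-process replicas of `train_model.py`. The function name `ring_allreduce_` and the way it would be wired into `DataparallelModel.step` are assumptions for illustration only; treat this as a minimal sketch of the algorithm's structure, not as the tutorial's reference solution.

```python
import torch


def ring_allreduce_(tensors):
    """In-place ring-allreduce over a list of equally shaped, contiguous tensors.

    Each tensor plays the role of one replica's gradient buffer; afterwards every
    tensor holds the element-wise mean of all inputs, matching the star-allreduce.
    """
    n = len(tensors)
    if n == 1:
        return tensors

    # Split each replica's flattened buffer into exactly n views sharing storage with
    # the original tensor (lists, not tuples, so in-place `+=` item updates are legal).
    chunks = [list(torch.tensor_split(t.view(-1), n)) for t in tensors]

    # Reduce-scatter: at step s, replica r adds its chunk (r - s) mod n into its right
    # neighbour's copy. After n - 1 steps, replica r holds the full sum of chunk (r + 1) mod n.
    for step in range(n - 1):
        for rank in range(n):
            chunk_idx = (rank - step) % n
            chunks[(rank + 1) % n][chunk_idx] += chunks[rank][chunk_idx]

    # All-gather: circulate the completed chunks around the ring so that every replica
    # ends up with a copy of every fully reduced chunk.
    for step in range(n - 1):
        for rank in range(n):
            chunk_idx = (rank + 1 - step) % n
            chunks[(rank + 1) % n][chunk_idx].copy_(chunks[rank][chunk_idx])

    # Turn the sums into means so the result matches torch.mean in the star-allreduce.
    for t in tensors:
        t.div_(n)
    return tensors
```

Inside `DataparallelModel.step` this could take the place of the `torch.mean(torch.stack(...))` lines, for example as `ring_allreduce_([p.grad for p in param_group])`. The intent of `python test_dataparallel.py` is unchanged: after each optimizer step all replica weights must stay bit-identical, and the master replica must stay within tolerance of the reference model.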