├── .gitignore ├── LICENSE ├── README.md ├── ds_config.json ├── mnist_accelerate.py ├── mnist_ddp.py ├── mnist_ddp_mlflow.py ├── mnist_ddp_profiler.py ├── mnist_deepspeed.py ├── mnist_lightning_ddp.py ├── mnist_mp.py ├── requirements.txt ├── run-accelerate-gpu4.sh ├── run-accelerate-gpu8.sh ├── run-ddp-gpu1-mlflow.sh ├── run-ddp-gpu1-profiler.sh ├── run-ddp-gpu4-mlflow.sh ├── run-ddp-gpu4-smi-logging.sh ├── run-ddp-gpu4.sh ├── run-ddp-gpu8.sh ├── run-deepspeed-gpu4.sh ├── run-deepspeed-gpu8.sh ├── run-lightning-gpu4.sh └── run-lightning-gpu8.sh /.gitignore: -------------------------------------------------------------------------------- 1 | slurm*out 2 | *~ 3 | /data/ 4 | /MNIST/ 5 | /lightning_logs/ 6 | /*.ckpt 7 | /logs 8 | /mlruns 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 CSC - IT Center for Science 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch multi-GPU and multi-node examples for CSC's supercomputers 2 | 3 | [PyTorch distributed][pytorch_dist], and in particular 4 | `DistributedDataParallel` (DDP), offers a nice way of running 5 | multi-GPU and multi-node PyTorch jobs. Unfortunately, the PyTorch 6 | documentation has been a bit lacking in this area, and examples found 7 | online can often be out-of-date. 8 | 9 | To make using DDP on CSC's supercomputers easier, we have created a 10 | set of examples on how to run simple DDP jobs on the cluster. Included 11 | are also examples with other frameworks, such as [PyTorch 12 | Lightning][lightning] and [DeepSpeed][deepspeed]. 13 | 14 | All examples train a simple CNN on MNIST. Scripts have been provided 15 | for the Puhti supercomputer, but can be used on other systems with 16 | minor modifications. 17 | 18 | For larger examples, see also our [Machine learning benchmarks 19 | repository](https://github.com/mvsjober/ml-benchmarks). 20 | 21 | Finally, you might also be interested in [CSC's machine learning 22 | guide](https://docs.csc.fi/support/tutorials/ml-guide/) and in 23 | particular the section on [Multi-GPU and multi-node 24 | machine learning](https://docs.csc.fi/support/tutorials/ml-multi/). 
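As a quick orientation, all of the DDP examples follow the same basic pattern as `mnist_ddp.py`: initialize the process group, pick the GPU from the `LOCAL_RANK` environment variable set by `torchrun`, wrap the model in `DistributedDataParallel`, and shard the data with `DistributedSampler`. The minimal sketch below condenses that pattern into a stand-alone script; the file name `ddp_sketch.py` and the tiny linear model are illustrative placeholders only, not part of this repository.

```python
# ddp_sketch.py -- condensed sketch of the pattern used in mnist_ddp.py
# (hypothetical file; launch with: torchrun --standalone --nproc_per_node=4 ddp_sketch.py)
import os
import torch
import torch.distributed as dist
import torchvision.transforms as transforms
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST

dist.init_process_group(backend='nccl')      # one process per GPU, NCCL for GPU communication
local_rank = int(os.environ['LOCAL_RANK'])   # set by torchrun for each local process
torch.cuda.set_device(local_rank)

# Stand-in model (the real examples use a small CNN); DDP averages gradients across ranks
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10)).cuda()
model = DistributedDataParallel(model, device_ids=[local_rank])

dataset = MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
sampler = DistributedSampler(dataset)        # each rank gets its own shard of the data
loader = DataLoader(dataset, batch_size=100, sampler=sampler, shuffle=False)

criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

for images, labels in loader:                # one pass over the data, to show the training step
    outputs = model(images.cuda(non_blocking=True))
    loss = criterion(outputs, labels.cuda(non_blocking=True))
    optimizer.zero_grad()
    loss.backward()                          # gradients are all-reduced across ranks here
    optimizer.step()

dist.destroy_process_group()
```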
25 | 26 | 27 | ## Multi-GPU, single-node 28 | 29 | The simplest case is using all four GPUs on a single node on Puhti. 30 | 31 | ```bash 32 | sbatch run-ddp-gpu4.sh 33 | ``` 34 | 35 | ## Multi-GPU, multi-node 36 | 37 | Example using two nodes, with four GPUs on each, giving a total of 8 GPUs (again, on Puhti): 38 | 39 | ```bash 40 | sbatch run-ddp-gpu8.sh 41 | ``` 42 | 43 | 44 | ## PyTorch Lightning examples 45 | 46 | Multi-GPU and multi-node jobs are even easier with [PyTorch 47 | Lightning][lightning]. The [official PyTorch Lightning documentation 48 | now has relatively good Slurm 49 | instructions](https://lightning.ai/docs/pytorch/stable/clouds/cluster_advanced.html?highlight=slurm), 50 | although they have to be modified a bit for Puhti. 51 | 52 | Four GPUs on a single node on Puhti: 53 | 54 | ```bash 55 | sbatch run-lightning-gpu4.sh 56 | ``` 57 | 58 | Two nodes, 8 GPUs in total on Puhti: 59 | 60 | ```bash 61 | sbatch run-lightning-gpu8.sh 62 | ``` 63 | 64 | ## DeepSpeed examples 65 | 66 | [DeepSpeed][deepspeed] should work on Puhti and Mahti with the 67 | [PyTorch module](https://docs.csc.fi/apps/pytorch/) (from version 1.10 68 | onwards). 69 | 70 | Single-node with four GPUs (Puhti): 71 | 72 | ```bash 73 | sbatch run-deepspeed-gpu4.sh 74 | ``` 75 | 76 | Here we are using Slurm to launch a single process, which then uses DeepSpeed's 77 | launcher to start four processes (one for each GPU). 78 | 79 | Two nodes, 8 GPUs in total (Puhti): 80 | 81 | ```bash 82 | sbatch run-deepspeed-gpu8.sh 83 | ``` 84 | 85 | Note that we are using Slurm's `srun` to launch four processes on each node 86 | (one per GPU), and instead of DeepSpeed's launcher we are relying on MPI to 87 | provide DeepSpeed with the information it needs to communicate between all the processes. 88 | 89 | 90 | [pytorch_dist]: https://pytorch.org/tutorials/beginner/dist_overview.html 91 | [ddp]: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html 92 | [lightning]: https://www.pytorchlightning.ai/ 93 | [deepspeed]: https://www.deepspeed.ai/ 94 | -------------------------------------------------------------------------------- /ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 400, 3 | "steps_per_print": 2000, 4 | "optimizer": { 5 | "type": "Adam", 6 | "params": { 7 | "lr": 0.001, 8 | "betas": [ 9 | 0.8, 10 | 0.999 11 | ], 12 | "eps": 1e-8, 13 | "weight_decay": 3e-7 14 | } 15 | }, 16 | "scheduler": { 17 | "type": "WarmupLR", 18 | "params": { 19 | "warmup_min_lr": 0, 20 | "warmup_max_lr": 0.001, 21 | "warmup_num_steps": 1000 22 | } 23 | }, 24 | "gradient_clipping": 1.0, 25 | "prescale_gradients": false, 26 | "wall_clock_breakdown": false 27 | } 28 | -------------------------------------------------------------------------------- /mnist_accelerate.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import argparse 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | # import torch.distributed as dist 7 | import torchvision.transforms as transforms 8 | from torchvision.datasets import MNIST 9 | # from torch.utils.data.distributed import DistributedSampler 10 | # from torch.nn.parallel import DistributedDataParallel 11 | from torch.utils.data import DataLoader 12 | from accelerate import Accelerator 13 | 14 | 15 | class ConvNet(nn.Module): 16 | def __init__(self, num_classes=10): 17 | super(ConvNet, self).__init__() 18 | self.layer1 = nn.Sequential( 19 | nn.Conv2d(1, 16, 
kernel_size=5, stride=1, padding=2), 20 | nn.BatchNorm2d(16), 21 | nn.ReLU(), 22 | nn.MaxPool2d(kernel_size=2, stride=2)) 23 | self.layer2 = nn.Sequential( 24 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 25 | nn.BatchNorm2d(32), 26 | nn.ReLU(), 27 | nn.MaxPool2d(kernel_size=2, stride=2)) 28 | self.fc = nn.Linear(7*7*32, num_classes) 29 | 30 | def forward(self, x): 31 | out = self.layer1(x) 32 | out = self.layer2(out) 33 | out = out.reshape(out.size(0), -1) 34 | out = self.fc(out) 35 | return out 36 | 37 | 38 | def train(num_epochs): 39 | accelerator = Accelerator() 40 | # dist.init_process_group(backend='nccl') 41 | 42 | torch.manual_seed(0) 43 | # local_rank = int(os.environ['LOCAL_RANK']) 44 | # torch.cuda.set_device(local_rank) 45 | 46 | #verbose = dist.get_rank() == 0 # print only on global_rank==0 47 | verbose = accelerator.is_main_process # print only in main process 48 | 49 | model = ConvNet().cuda() 50 | batch_size = 100 51 | 52 | criterion = nn.CrossEntropyLoss().cuda() 53 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 54 | 55 | # model = DistributedDataParallel(model, device_ids=[local_rank]) 56 | 57 | train_dataset = MNIST(root='./data', train=True, 58 | transform=transforms.ToTensor(), download=True) 59 | # train_sampler = DistributedSampler(train_dataset) 60 | # train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 61 | # shuffle=False, num_workers=0, pin_memory=True, 62 | # sampler=train_sampler) 63 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 64 | shuffle=False, num_workers=0, pin_memory=True) 65 | 66 | train_loader, model, optimizer = accelerator.prepare(train_loader, model, 67 | optimizer) 68 | 69 | start = datetime.now() 70 | for epoch in range(num_epochs): 71 | tot_loss = 0 72 | for i, (images, labels) in enumerate(train_loader): 73 | # images = images.cuda(non_blocking=True) 74 | # labels = labels.cuda(non_blocking=True) 75 | 76 | outputs = model(images) 77 | loss = criterion(outputs, labels) 78 | 79 | accelerator.backward(loss) 80 | optimizer.step() 81 | # loss.backward() 82 | optimizer.zero_grad() 83 | 84 | tot_loss += loss.item() 85 | 86 | if verbose: 87 | print('Epoch [{}/{}], average loss: {:.4f}'.format( 88 | epoch + 1, 89 | num_epochs, 90 | tot_loss / (i+1))) 91 | if verbose: 92 | print("Training completed in: " + str(datetime.now() - start)) 93 | 94 | 95 | def main(): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 98 | help='number of total epochs to run') 99 | args = parser.parse_args() 100 | 101 | train(args.epochs) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /mnist_ddp.py: -------------------------------------------------------------------------------- 1 | # Based on multiprocessing example from 2 | # https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html 3 | 4 | from datetime import datetime 5 | import argparse 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.distributed as dist 10 | import torchvision.transforms as transforms 11 | from torchvision.datasets import MNIST 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.nn.parallel import DistributedDataParallel 14 | from torch.utils.data import DataLoader 15 | 16 | 17 | class ConvNet(nn.Module): 18 | def __init__(self, num_classes=10): 19 | super(ConvNet, self).__init__() 20 | self.layer1 = nn.Sequential( 21 
| nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 22 | nn.BatchNorm2d(16), 23 | nn.ReLU(), 24 | nn.MaxPool2d(kernel_size=2, stride=2)) 25 | self.layer2 = nn.Sequential( 26 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 27 | nn.BatchNorm2d(32), 28 | nn.ReLU(), 29 | nn.MaxPool2d(kernel_size=2, stride=2)) 30 | self.fc = nn.Linear(7*7*32, num_classes) 31 | 32 | def forward(self, x): 33 | out = self.layer1(x) 34 | out = self.layer2(out) 35 | out = out.reshape(out.size(0), -1) 36 | out = self.fc(out) 37 | return out 38 | 39 | 40 | def train(num_epochs): 41 | dist.init_process_group(backend='nccl') 42 | 43 | torch.manual_seed(0) 44 | local_rank = int(os.environ['LOCAL_RANK']) 45 | torch.cuda.set_device(local_rank) 46 | 47 | verbose = dist.get_rank() == 0 # print only on global_rank==0 48 | 49 | model = ConvNet().cuda() 50 | batch_size = 100 51 | 52 | criterion = nn.CrossEntropyLoss().cuda() 53 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 54 | 55 | model = DistributedDataParallel(model, device_ids=[local_rank]) 56 | 57 | train_dataset = MNIST(root='./data', train=True, 58 | transform=transforms.ToTensor(), download=True) 59 | train_sampler = DistributedSampler(train_dataset) 60 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 61 | shuffle=False, num_workers=0, pin_memory=True, 62 | sampler=train_sampler) 63 | 64 | start = datetime.now() 65 | for epoch in range(num_epochs): 66 | tot_loss = 0 67 | for i, (images, labels) in enumerate(train_loader): 68 | images = images.cuda(non_blocking=True) 69 | labels = labels.cuda(non_blocking=True) 70 | 71 | outputs = model(images) 72 | loss = criterion(outputs, labels) 73 | 74 | optimizer.zero_grad() 75 | loss.backward() 76 | optimizer.step() 77 | 78 | tot_loss += loss.item() 79 | 80 | if verbose: 81 | print('Epoch [{}/{}], average loss: {:.4f}'.format( 82 | epoch + 1, 83 | num_epochs, 84 | tot_loss / (i+1))) 85 | if verbose: 86 | print("Training completed in: " + str(datetime.now() - start)) 87 | 88 | 89 | def main(): 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 92 | help='number of total epochs to run') 93 | args = parser.parse_args() 94 | 95 | train(args.epochs) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /mnist_ddp_mlflow.py: -------------------------------------------------------------------------------- 1 | # Based on multiprocessing example from 2 | # https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html 3 | 4 | from datetime import datetime 5 | import argparse 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.distributed as dist 10 | import torchvision.transforms as transforms 11 | from torchvision.datasets import MNIST 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.nn.parallel import DistributedDataParallel 14 | from torch.utils.data import DataLoader 15 | import mlflow 16 | 17 | class ConvNet(nn.Module): 18 | def __init__(self, num_classes=10): 19 | super(ConvNet, self).__init__() 20 | self.layer1 = nn.Sequential( 21 | nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 22 | nn.BatchNorm2d(16), 23 | nn.ReLU(), 24 | nn.MaxPool2d(kernel_size=2, stride=2)) 25 | self.layer2 = nn.Sequential( 26 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 27 | nn.BatchNorm2d(32), 28 | nn.ReLU(), 29 | nn.MaxPool2d(kernel_size=2, stride=2)) 30 | self.fc = nn.Linear(7*7*32, 
num_classes) 31 | 32 | def forward(self, x): 33 | out = self.layer1(x) 34 | out = self.layer2(out) 35 | out = out.reshape(out.size(0), -1) 36 | out = self.fc(out) 37 | return out 38 | 39 | 40 | def train(num_epochs): 41 | dist.init_process_group(backend='nccl') 42 | 43 | torch.manual_seed(0) 44 | local_rank = int(os.environ['LOCAL_RANK']) 45 | torch.cuda.set_device(local_rank) 46 | 47 | verbose = dist.get_rank() == 0 # print only on global_rank==0 48 | if verbose: 49 | mlflow.set_tracking_uri("/scratch/project_2001659/mvsjober/mlruns") 50 | #mlflow.set_tracking_uri("sqlite:////scratch/project_2001659/mvsjober/mlruns.db") 51 | #mlflow.set_tracking_uri('https://mats-mlflow2.rahtiapp.fi/') 52 | 53 | mlflow.start_run(run_name=os.getenv("SLURM_JOB_ID")) 54 | 55 | model = ConvNet().cuda() 56 | batch_size = 100 57 | 58 | criterion = nn.CrossEntropyLoss().cuda() 59 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 60 | 61 | model = DistributedDataParallel(model, device_ids=[local_rank]) 62 | 63 | train_dataset = MNIST(root='./data', train=True, 64 | transform=transforms.ToTensor(), download=True) 65 | train_sampler = DistributedSampler(train_dataset) 66 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 67 | shuffle=False, num_workers=0, pin_memory=True, 68 | sampler=train_sampler) 69 | 70 | start = datetime.now() 71 | for epoch in range(num_epochs): 72 | tot_loss = 0 73 | for i, (images, labels) in enumerate(train_loader): 74 | images = images.cuda(non_blocking=True) 75 | labels = labels.cuda(non_blocking=True) 76 | 77 | outputs = model(images) 78 | loss = criterion(outputs, labels) 79 | 80 | optimizer.zero_grad() 81 | loss.backward() 82 | optimizer.step() 83 | 84 | tot_loss += loss.item() 85 | 86 | if verbose: 87 | tot_loss = tot_loss / (i+1) 88 | mlflow.log_metric("loss", tot_loss) 89 | 90 | print('Epoch [{}/{}], average loss: {:.4f}'.format( 91 | epoch + 1, 92 | num_epochs, 93 | tot_loss)) 94 | if verbose: 95 | print("Training completed in: " + str(datetime.now() - start)) 96 | 97 | 98 | def main(): 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 101 | help='number of total epochs to run') 102 | args = parser.parse_args() 103 | 104 | train(args.epochs) 105 | 106 | 107 | if __name__ == '__main__': 108 | main() 109 | -------------------------------------------------------------------------------- /mnist_ddp_profiler.py: -------------------------------------------------------------------------------- 1 | # Based on multiprocessing example from 2 | # https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html 3 | 4 | from datetime import datetime 5 | import argparse 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.distributed as dist 10 | import torchvision.transforms as transforms 11 | from torchvision.datasets import MNIST 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.nn.parallel import DistributedDataParallel 14 | from torch.utils.data import DataLoader 15 | from torch.profiler import profile, record_function, ProfilerActivity 16 | 17 | 18 | class ConvNet(nn.Module): 19 | def __init__(self, num_classes=10): 20 | super(ConvNet, self).__init__() 21 | self.layer1 = nn.Sequential( 22 | nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 23 | nn.BatchNorm2d(16), 24 | nn.ReLU(), 25 | nn.MaxPool2d(kernel_size=2, stride=2)) 26 | self.layer2 = nn.Sequential( 27 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 28 | 
nn.BatchNorm2d(32), 29 | nn.ReLU(), 30 | nn.MaxPool2d(kernel_size=2, stride=2)) 31 | self.fc = nn.Linear(7*7*32, num_classes) 32 | 33 | def forward(self, x): 34 | out = self.layer1(x) 35 | out = self.layer2(out) 36 | out = out.reshape(out.size(0), -1) 37 | out = self.fc(out) 38 | return out 39 | 40 | 41 | def train(num_epochs): 42 | dist.init_process_group(backend='nccl') 43 | 44 | torch.manual_seed(0) 45 | local_rank = int(os.environ['LOCAL_RANK']) 46 | torch.cuda.set_device(local_rank) 47 | 48 | verbose = dist.get_rank() == 0 # print only on global_rank==0 49 | 50 | prof = profile( 51 | schedule=torch.profiler.schedule( 52 | skip_first=10, 53 | wait=5, 54 | warmup=1, 55 | active=3, 56 | repeat=1), 57 | on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs/profiler'), 58 | activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], 59 | record_shapes=True, # record shapes of operator inputs 60 | profile_memory=True, # track tensor memory allocation/deallocation 61 | with_stack=True, # record source code information 62 | with_flops=True, # estimate FLOPS of operators 63 | ) 64 | 65 | model = ConvNet().cuda() 66 | batch_size = 100 67 | 68 | criterion = nn.CrossEntropyLoss().cuda() 69 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 70 | 71 | model = DistributedDataParallel(model, device_ids=[local_rank]) 72 | 73 | train_dataset = MNIST(root='./data', train=True, 74 | transform=transforms.ToTensor(), download=True) 75 | train_sampler = DistributedSampler(train_dataset) 76 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 77 | shuffle=False, num_workers=0, pin_memory=True, 78 | sampler=train_sampler) 79 | 80 | start = datetime.now() 81 | prof.start() 82 | for epoch in range(num_epochs): 83 | tot_loss = 0 84 | for i, (images, labels) in enumerate(train_loader): 85 | images = images.cuda(non_blocking=True) 86 | labels = labels.cuda(non_blocking=True) 87 | 88 | outputs = model(images) 89 | loss = criterion(outputs, labels) 90 | 91 | optimizer.zero_grad() 92 | loss.backward() 93 | optimizer.step() 94 | 95 | prof.step() 96 | 97 | tot_loss += loss.item() 98 | 99 | if verbose: 100 | print('Epoch [{}/{}], average loss: {:.4f}'.format( 101 | epoch + 1, 102 | num_epochs, 103 | tot_loss / (i+1))) 104 | prof.stop() 105 | 106 | if verbose: 107 | print("Training completed in: " + str(datetime.now() - start)) 108 | 109 | 110 | def main(): 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 113 | help='number of total epochs to run') 114 | args = parser.parse_args() 115 | 116 | train(args.epochs) 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /mnist_deepspeed.py: -------------------------------------------------------------------------------- 1 | # Based on multiprocessing example from 2 | # https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html 3 | 4 | from datetime import datetime 5 | import argparse 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms as transforms 10 | from torchvision.datasets import MNIST 11 | import deepspeed 12 | from datetime import timedelta 13 | 14 | 15 | class ConvNet(nn.Module): 16 | def __init__(self, num_classes=10): 17 | super(ConvNet, self).__init__() 18 | self.layer1 = nn.Sequential( 19 | nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 20 | nn.BatchNorm2d(16), 21 | nn.ReLU(), 22 | nn.MaxPool2d(kernel_size=2, stride=2)) 23 
| self.layer2 = nn.Sequential( 24 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 25 | nn.BatchNorm2d(32), 26 | nn.ReLU(), 27 | nn.MaxPool2d(kernel_size=2, stride=2)) 28 | self.fc = nn.Linear(7*7*32, num_classes) 29 | 30 | def forward(self, x): 31 | out = self.layer1(x) 32 | out = self.layer2(out) 33 | out = out.reshape(out.size(0), -1) 34 | out = self.fc(out) 35 | return out 36 | 37 | 38 | def train(args): 39 | num_epochs = args.epochs 40 | local_rank = args.local_rank 41 | if local_rank == -1: 42 | local_rank = int(os.environ.get('PMIX_RANK', -1)) 43 | 44 | deepspeed.init_distributed(timeout=timedelta(minutes=5)) 45 | 46 | torch.manual_seed(0) 47 | model = ConvNet() 48 | 49 | criterion = nn.CrossEntropyLoss().cuda() 50 | 51 | train_dataset = MNIST(root='./data', train=True, 52 | transform=transforms.ToTensor(), download=True) 53 | 54 | model_engine, optimizer, train_loader, __ = deepspeed.initialize( 55 | args=args, model=model, model_parameters=model.parameters(), 56 | training_data=train_dataset) 57 | 58 | start = datetime.now() 59 | for epoch in range(num_epochs): 60 | tot_loss = 0 61 | for i, data in enumerate(train_loader): 62 | images = data[0].to(model_engine.local_rank) 63 | labels = data[1].to(model_engine.local_rank) 64 | 65 | outputs = model_engine(images) 66 | loss = criterion(outputs, labels) 67 | 68 | model_engine.backward(loss) 69 | model_engine.step() 70 | 71 | tot_loss += loss.item() 72 | 73 | if local_rank == 0: 74 | print('Epoch [{}/{}], average loss: {:.4f}'.format( 75 | epoch + 1, 76 | num_epochs, 77 | tot_loss / (i+1))) 78 | 79 | if local_rank == 0: 80 | print("Training completed in: " + str(datetime.now() - start)) 81 | 82 | 83 | def main(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 86 | help='number of total epochs to run') 87 | parser.add_argument('--local_rank', type=int, default=-1, 88 | help='local rank passed from distributed launcher') 89 | 90 | parser = deepspeed.add_config_arguments(parser) 91 | 92 | args = parser.parse_args() 93 | 94 | train(args) 95 | 96 | 97 | if __name__ == '__main__': 98 | main() 99 | -------------------------------------------------------------------------------- /mnist_lightning_ddp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms as transforms 6 | import torch.nn.functional as F 7 | from torchvision.datasets import MNIST 8 | from torch.utils.data import DataLoader 9 | import lightning as L 10 | #import pytorch_lightning as L 11 | import mlflow 12 | 13 | class LitConvNet(L.LightningModule): 14 | def __init__(self, num_classes=10): 15 | super().__init__() 16 | self.layer1 = nn.Sequential( 17 | nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 18 | nn.BatchNorm2d(16), 19 | nn.ReLU(), 20 | nn.MaxPool2d(kernel_size=2, stride=2)) 21 | self.layer2 = nn.Sequential( 22 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 23 | nn.BatchNorm2d(32), 24 | nn.ReLU(), 25 | nn.MaxPool2d(kernel_size=2, stride=2)) 26 | self.fc = nn.Linear(7*7*32, num_classes) 27 | 28 | def forward(self, x): 29 | out = self.layer1(x) 30 | out = self.layer2(out) 31 | out = out.reshape(out.size(0), -1) 32 | out = self.fc(out) 33 | return out 34 | 35 | def training_step(self, batch, batch_idx): 36 | images, labels = batch 37 | outputs = self(images) 38 | loss = F.cross_entropy(outputs, labels) 39 | self.log("train_loss", loss) 40 | return 
loss 41 | 42 | def configure_optimizers(self): 43 | optimizer = torch.optim.SGD(self.parameters(), 1e-4) 44 | return optimizer 45 | 46 | 47 | def main(): 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--gpus', default=1, type=int, metavar='N', 50 | help='number of GPUs per node') 51 | parser.add_argument('--nodes', default=1, type=int, metavar='N', 52 | help='number of nodes') 53 | parser.add_argument('--epochs', default=2, type=int, metavar='N', 54 | help='maximum number of epochs to run') 55 | args = parser.parse_args() 56 | 57 | mlflow.autolog() 58 | 59 | print("Using PyTorch {} and Lightning {}".format(torch.__version__, L.__version__)) 60 | 61 | batch_size = 100 62 | dataset = MNIST(os.getcwd(), download=True, 63 | transform=transforms.ToTensor()) 64 | train_loader = DataLoader(dataset, batch_size=batch_size, 65 | num_workers=10, pin_memory=True) 66 | 67 | convnet = LitConvNet() 68 | 69 | trainer = L.Trainer(devices=args.gpus, 70 | num_nodes=args.nodes, 71 | max_epochs=args.epochs, 72 | accelerator='gpu', 73 | strategy='ddp') 74 | 75 | from datetime import datetime 76 | t0 = datetime.now() 77 | trainer.fit(convnet, train_loader) 78 | dt = datetime.now() - t0 79 | print('Training took {}'.format(dt)) 80 | 81 | trainer.save_checkpoint("lightning_model.ckpt") 82 | 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /mnist_mp.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from torch.utils.data import DataLoader 3 | from torchvision.datasets import MNIST 4 | import multiprocessing as mp 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class ConvNet(nn.Module): 11 | def __init__(self, num_classes=10): 12 | super(ConvNet, self).__init__() 13 | self.layer1 = nn.Sequential( 14 | nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 15 | nn.BatchNorm2d(16), 16 | nn.ReLU(), 17 | nn.MaxPool2d(kernel_size=2, stride=2)) 18 | self.layer2 = nn.Sequential( 19 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 20 | nn.BatchNorm2d(32), 21 | nn.ReLU(), 22 | nn.MaxPool2d(kernel_size=2, stride=2)) 23 | self.fc = nn.Linear(7*7*32, num_classes) 24 | 25 | def forward(self, x): 26 | out = self.layer1(x) 27 | out = self.layer2(out) 28 | out = out.reshape(out.size(0), -1) 29 | out = self.fc(out) 30 | return out 31 | 32 | 33 | def train(batch_size): 34 | num_epochs = 100 35 | 36 | torch.manual_seed(0) 37 | verbose = True 38 | 39 | model = ConvNet().cuda() 40 | 41 | criterion = nn.CrossEntropyLoss().cuda() 42 | optimizer = torch.optim.SGD(model.parameters(), 1e-4) 43 | 44 | train_dataset = MNIST(root='./data', train=True, 45 | transform=transforms.ToTensor(), download=True) 46 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 47 | shuffle=False, num_workers=0, pin_memory=True) 48 | 49 | start = datetime.now() 50 | for epoch in range(num_epochs): 51 | tot_loss = 0 52 | for i, (images, labels) in enumerate(train_loader): 53 | images = images.cuda(non_blocking=True) 54 | labels = labels.cuda(non_blocking=True) 55 | 56 | outputs = model(images) 57 | loss = criterion(outputs, labels) 58 | 59 | optimizer.zero_grad() 60 | loss.backward() 61 | optimizer.step() 62 | 63 | tot_loss += loss.item() 64 | 65 | if verbose: 66 | print('Epoch [{}/{}], batch_size={} average loss: {:.4f}'.format( 67 | epoch + 1, 68 | num_epochs, 69 | batch_size, 70 | tot_loss / (i+1))) 71 | 
if verbose: 72 | print("Training completed in: " + str(datetime.now() - start)) 73 | 74 | 75 | if __name__ == '__main__': 76 | bs_list = [16, 32, 64, 128] 77 | num_processes = 4 78 | 79 | with mp.Pool(processes=num_processes) as pool: 80 | pool.map(train, bs_list) 81 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.2 2 | torchvision>=0.17.1 3 | deepspeed>=0.13.3 4 | mpi4py>=3.1.5 5 | mlflow>=2.10.2 6 | accelerate>=0.27.2 7 | lightning>=2.2 8 | -------------------------------------------------------------------------------- /run-accelerate-gpu4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --ntasks=1 5 | #SBATCH --cpus-per-task=40 6 | #SBATCH --mem=320G 7 | #SBATCH --time=15 8 | #SBATCH --gres=gpu:v100:4 9 | 10 | module purge 11 | module load pytorch 12 | 13 | #pip install --user accelerate 14 | 15 | srun apptainer_wrapper exec accelerate launch --multi_gpu --num_processes=4 --num_machines=1 \ 16 | --mixed_precision=bf16 --dynamo_backend=no \ 17 | mnist_accelerate.py --epochs=100 18 | -------------------------------------------------------------------------------- /run-accelerate-gpu8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --nodes=2 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --cpus-per-task=40 7 | #SBATCH --mem=320G 8 | #SBATCH --time=15 9 | #SBATCH --gres=gpu:v100:4 10 | 11 | module purge 12 | module load pytorch 13 | 14 | GPUS_PER_NODE=4 15 | MASTER_ADDR=$(hostname -i) 16 | MASTER_PORT=12802 17 | 18 | # Note: --machine_rank must be evaluated on each node, hence the LAUNCH_CMD setup 19 | export LAUNCH_CMD=" 20 | accelerate launch \ 21 | --multi_gpu --mixed_precision no \ 22 | --num_machines=${SLURM_NNODES} \ 23 | --num_processes=$(expr ${SLURM_NNODES} \* ${GPUS_PER_NODE}) \ 24 | --machine_rank=\${SLURM_NODEID} \ 25 | --main_process_ip=${MASTER_ADDR} \ 26 | --main_process_port=${MASTER_PORT} \ 27 | mnist_accelerate.py --epochs=100 \ 28 | " 29 | echo ${LAUNCH_CMD} 30 | srun singularity_wrapper exec bash -c "${LAUNCH_CMD}" 31 | -------------------------------------------------------------------------------- /run-ddp-gpu1-mlflow.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --ntasks=1 5 | #SBATCH --cpus-per-task=10 6 | #SBATCH --mem=64G 7 | #SBATCH --time=15 8 | #SBATCH --gres=gpu:v100:1 9 | 10 | module purge 11 | module load pytorch 12 | 13 | # Old way with torch.distributed.run 14 | # srun python3 -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp.py --epochs=100 15 | 16 | # New way with torchrun 17 | srun torchrun --standalone --nnodes=1 --nproc_per_node=1 mnist_ddp_mlflow.py --epochs=100 18 | -------------------------------------------------------------------------------- /run-ddp-gpu1-profiler.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --ntasks=1 5 | #SBATCH --cpus-per-task=10 6 | #SBATCH --mem=64G 7 | #SBATCH --time=15 8 | #SBATCH --gres=gpu:v100:1 9 | 10 | module 
purge 11 | module load pytorch 12 | 13 | # Old way with torch.distributed.run 14 | # srun python3 -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp.py --epochs=100 15 | 16 | # New way with torchrun 17 | srun torchrun --standalone --nnodes=1 --nproc_per_node=1 mnist_ddp_profiler.py --epochs=100 18 | -------------------------------------------------------------------------------- /run-ddp-gpu4-mlflow.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --ntasks=1 5 | #SBATCH --cpus-per-task=40 6 | #SBATCH --mem=0 7 | #SBATCH --time=15 8 | #SBATCH --gres=gpu:v100:4 9 | 10 | module purge 11 | module load pytorch 12 | 13 | # Old way with torch.distributed.run 14 | # srun python3 -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp.py --epochs=100 15 | 16 | # New way with torchrun 17 | srun torchrun --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp_mlflow.py --epochs=100 18 | -------------------------------------------------------------------------------- /run-ddp-gpu4-smi-logging.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --ntasks=1 5 | #SBATCH --cpus-per-task=40 6 | #SBATCH --mem=0 7 | #SBATCH --time=15 8 | #SBATCH --gres=gpu:v100:4 9 | 10 | module purge 11 | module load pytorch 12 | 13 | # Old way with torch.distributed.run 14 | # srun python3 -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp.py --epochs=100 15 | 16 | nvidia-smi --query-gpu=index,temperature.gpu,utilization.gpu,utilization.memory --format=csv -l > gpu.log & 17 | BG_PID=$! 
18 | 19 | # New way with torchrun 20 | srun torchrun --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp.py --epochs=2 21 | 22 | kill $BG_PID 23 | -------------------------------------------------------------------------------- /run-ddp-gpu4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --ntasks=1 5 | #SBATCH --cpus-per-task=40 6 | #SBATCH --mem=0 7 | #SBATCH --time=15 8 | #SBATCH --gres=gpu:v100:4 9 | 10 | module purge 11 | module load pytorch 12 | 13 | # Old way with torch.distributed.run 14 | # srun python3 -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp.py --epochs=100 15 | 16 | # New way with torchrun 17 | srun torchrun --standalone --nnodes=1 --nproc_per_node=4 mnist_ddp.py --epochs=100 18 | -------------------------------------------------------------------------------- /run-ddp-gpu8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --nodes=2 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --cpus-per-task=40 7 | #SBATCH --mem=0 8 | #SBATCH --time=15 9 | #SBATCH --gres=gpu:v100:4 10 | 11 | module purge 12 | module load pytorch 13 | 14 | export RDZV_HOST=$(hostname) 15 | export RDZV_PORT=29400 16 | 17 | # Old way with torch.distributed.run 18 | # srun python3 -m torch.distributed.run \ 19 | # --nnodes=$SLURM_JOB_NUM_NODES \ 20 | # --nproc_per_node=4 \ 21 | # --rdzv_id=$SLURM_JOB_ID \ 22 | # --rdzv_backend=c10d \ 23 | # --rdzv_endpoint="$RDZV_HOST:$RDZV_PORT" \ 24 | # mnist_ddp.py --epochs=100 25 | 26 | # New way with torchrun 27 | srun torchrun \ 28 | --nnodes=$SLURM_JOB_NUM_NODES \ 29 | --nproc_per_node=4 \ 30 | --rdzv_id=$SLURM_JOB_ID \ 31 | --rdzv_backend=c10d \ 32 | --rdzv_endpoint="$RDZV_HOST:$RDZV_PORT" \ 33 | mnist_ddp.py --epochs=100 34 | -------------------------------------------------------------------------------- /run-deepspeed-gpu4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --ntasks=1 5 | #SBATCH --cpus-per-task=40 6 | #SBATCH --mem=0 7 | #SBATCH --time=15 8 | #SBATCH --gres=gpu:v100:4 9 | 10 | module purge 11 | module load pytorch 12 | 13 | srun singularity_wrapper exec deepspeed mnist_deepspeed.py --epochs=100 \ 14 | --deepspeed --deepspeed_config ds_config.json 15 | -------------------------------------------------------------------------------- /run-deepspeed-gpu8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --nodes=2 5 | #SBATCH --ntasks-per-node=4 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --mem=0 8 | #SBATCH --time=15 9 | #SBATCH --gres=gpu:v100:4 10 | 11 | module purge 12 | module load pytorch 13 | 14 | srun python3 mnist_deepspeed.py --epochs=100 \ 15 | --deepspeed --deepspeed_config ds_config.json 16 | -------------------------------------------------------------------------------- /run-lightning-gpu4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=4 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --mem=0 8 | #SBATCH --time=15 9 | 
#SBATCH --gres=gpu:v100:4 10 | 11 | module purge 12 | module load pytorch 13 | 14 | srun python3 mnist_lightning_ddp.py --gpus=4 --epochs=100 15 | -------------------------------------------------------------------------------- /run-lightning-gpu8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_2001659 3 | #SBATCH --partition=gputest 4 | #SBATCH --nodes=2 5 | #SBATCH --ntasks-per-node=4 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --mem=0 8 | #SBATCH --time=15 9 | #SBATCH --gres=gpu:v100:4 10 | 11 | module purge 12 | module load pytorch 13 | 14 | srun python3 mnist_lightning_ddp.py --gpus=4 --nodes=2 --epochs=100 15 | --------------------------------------------------------------------------------