├── README.md
├── images
│   └── system_overview.jpg
├── src
│   ├── README.md
│   ├── data
│   │   └── data_prepare.py
│   ├── data_prepare.sh
│   ├── distributed_evaluator.py
│   ├── distributed_nn.py
│   ├── distributed_worker.py
│   ├── evaluate_pytorch.sh
│   ├── launch.sh
│   ├── model_ops
│   │   ├── __init__.py
│   │   ├── lenet.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── nn_ops.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adam.py
│   │   └── sgd.py
│   ├── run_pytorch_dist.sh
│   ├── run_pytorch_single.sh
│   ├── sync_replicas_master_nn.py
│   └── util.py
└── tools
    ├── README.md
    ├── conda_install.sh
    ├── config
    ├── hosts
    ├── hosts_address
    ├── hosts_alias
    ├── install.sh
    ├── killall.sh
    ├── local_script.sh
    ├── openmpi_install.sh
    ├── pre_run.sh
    ├── pytorch_ec2.py
    ├── remote_script.sh
    └── update_git_dir.sh
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch-parameter-server
2 | Implementation of synchronous distributed machine learning in a [Parameter Server](https://www.cs.cmu.edu/~muli/file/parameter_server_nips14.pdf) setup using [PyTorch's distributed communication library](https://pytorch.org/docs/stable/distributed.html), i.e. `torch.distributed`.
3 |
4 | All functionality in this repository is essentially a replication of [ps_pytorch](https://github.com/hwang595/ps_pytorch). However, instead of using `mpi4py`, all communication and model training are handled by PyTorch itself.
5 |
6 | ## Contents
7 |
8 | 1. [Motivations](#motivations)
9 | 2. [System design](#system-design)
10 | 3. [Basic usages](#basic-usages)
11 | 4. [How to prepare datasets](#prepare-datasets)
12 | 5. [How to launch a distributed task](#job-launching)
13 | 6. [Future work](#future-work)
14 |
15 | ## Motivations:
16 | 1. PyTorch provides easy-to-use APIs with a dynamic computational graph.
17 | 2. Although [mpi4py](https://github.com/mpi4py/mpi4py) provides a good Python binding for any MPI distribution and flexible communication operations, converting data back and forth (e.g. `torch.Tensor` <--> `numpy.array`) incurs heavy overhead throughout the training process.
18 | 3. PyTorch supports [NCCL](https://developer.nvidia.com/nccl) as its communication backend, which makes distributed training on GPU clusters efficient and scalable.
19 |
20 | ## System Design:
21 | 1. Parameter Server: this node synchronizes all workers into the next iteration by broadcasting the global step, and it stores the global model, which workers pull at the beginning of each iteration (we implement this stage using `torch.distributed.broadcast`). At a user-defined frequency, the Parameter Server saves the current model as a checkpoint to a shared file system (NFS in our setup) for model evaluation.
22 | 2. Workers: each worker samples data points (mini-batches) from its local dataset (we do not move data between nodes, to preserve data locality), computes gradients, and ships them back to the Parameter Server (this stage is implemented using `torch.distributed.gather`); a minimal sketch of this communication pattern follows the list.
23 | 3. Evaluator: this node reads checkpoints from the shared directory and runs model evaluation. Note that only the test set is stored on the evaluator node.
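The following is a minimal, self-contained sketch of the communication pattern described above (not the repository's exact code): rank 0 plays the Parameter Server and the remaining ranks play workers, using the `gloo` backend on a single machine. The tensor sizes and the hard-coded `MASTER_ADDR`/`MASTER_PORT` are placeholders.

``` python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # a single tensor standing in for one layer of the model
    param = torch.zeros(4)

    if rank == 0:
        param += 1.0                                            # pretend this is the global model
        dist.broadcast(param, src=0)                            # PS pushes the model to all workers
        grads = [torch.zeros(4) for _ in range(world_size)]
        dist.gather(torch.zeros(4), gather_list=grads, dst=0)   # PS collects the gradients
        print("PS aggregated grad:", sum(grads[1:]) / (world_size - 1))
    else:
        dist.broadcast(param, src=0)                            # worker pulls the global model
        grad = param * rank                                     # pretend this is a computed gradient
        dist.gather(grad, dst=0)                                # worker ships its gradient to the PS

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 3
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```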
24 |
25 |

26 |
27 | ## Basic Usages
28 | ### Dependencies:
29 | Anaconda is highly recommended for installing the dependencies of this project. Assuming conda is set up on the machine, you can run
30 | ```
31 | bash ./tools/pre_run.sh
32 | ```
33 | to install all required dependencies.
34 | ### Single Machine:
35 | The code base in this repository can also run on a single machine, in which case multiple CPU processes are launched and each process is assigned the role of Parameter Server (usually the process with rank 0) or worker. To do this, simply follow the "Single-Node multi-process distributed training" part of [this tutorial](https://pytorch.org/docs/stable/distributed.html#launch-utility). We provide a script (`run_pytorch_single.sh`) that does the job for you. Simply run
36 | ```
37 | bash ./src/run_pytorch_single.sh
38 | ```
39 |
40 | ### Cluster Setup:
41 | To run on a distributed cluster, the first thing you need to do is launch AWS EC2 instances.
42 | #### Launching Instances:
43 | [This script](https://github.com/hwang595/PyTorch-parameter-server/blob/master/tools/pytorch_ec2.py) helps you launch EC2 instances automatically, but before running it you should follow [the instructions](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html) to set up the AWS CLI on your local machine.
44 | After that, please edit this part of `./tools/pytorch_ec2.py`:
45 | ``` python
46 | cfg = Cfg({
47 | "name" : "PS_PYTORCH", # Unique name for this specific configuration
48 | "key_name": "NameOfKeyFile", # Necessary to ssh into created instances
49 | # Cluster topology
50 | "n_masters" : 1, # Should always be 1
51 | "n_workers" : 8,
52 | "num_replicas_to_aggregate" : "8", # deprecated, not necessary
53 | "method" : "spot",
54 | # Region specification
55 | "region" : "us-west-2",
56 | "availability_zone" : "us-west-2b",
57 | # Machine type - instance type configuration.
58 | "master_type" : "m4.2xlarge",
59 | "worker_type" : "m4.2xlarge",
60 | # please only use this AMI for pytorch
61 | "image_id": "ami-xxxxxxxx", # id of AMI
62 | # Launch specifications
63 | "spot_price" : "0.15", # Has to be a string
64 | # SSH configuration
65 | "ssh_username" : "ubuntu", # For sshing. E.G: ssh ssh_username@hostname
66 | "path_to_keyfile" : "/dir/to/NameOfKeyFile.pem",
67 |
68 | # NFS configuration
69 | # To set up these values, go to Services > ElasticFileSystem > Create new filesystem, and follow the directions.
70 | #"nfs_ip_address" : "172.31.3.173", # us-west-2c
71 | #"nfs_ip_address" : "172.31.35.0", # us-west-2a
72 | "nfs_ip_address" : "172.31.14.225", # us-west-2b
73 | "nfs_mount_point" : "/home/ubuntu/shared", # NFS base dir
74 | ```
75 | The easiest way to set everything up on the EC2 cluster is to configure one machine and create an AMI from it. Then use that AMI id as `image_id` in `pytorch_ec2.py`, and launch the EC2 instances by running
76 | ```
77 | python ./tools/pytorch_ec2.py launch
78 | ```
79 | After all launched instances are ready (this may take a while), get the private IPs of the instances by running
80 | ```
81 | python ./tools/pytorch_ec2.py get_hosts
82 | ```
83 | This writes the IPs into a file named `hosts_address`, which looks like
84 | ```
85 | 172.31.16.226 (${PS_IP})
86 | 172.31.27.245
87 | 172.31.29.131
88 | 172.31.18.108
89 | ...
90 | ```
91 | After generating `hosts_address` for all EC2 instances, running the following command will copy your keyfile to the parameter server (PS) instance, whose address is always the first one in `hosts_address`. `local_script.sh` will also do some basic configuration, e.g. cloning this git repo:
92 | ```
93 | bash ./tools/local_script.sh ${PS_IP}
94 | ```
95 | #### SSH related:
96 | At this stage, you should ssh into the PS instance; all further operations happen on the PS. In the PS setting, the PS must be able to ssh into every compute node. [This part](https://github.com/hwang595/PyTorch-parameter-server/blob/master/tools/remote_script.sh#L8-L22) does the job for you; after sshing into the PS, run
97 | ```
98 | bash ./tools/remote_script.sh
99 | ```
100 |
101 | ## Prepare Datasets
102 | Download, split, and transform the datasets by running (`./tools/remote_script.sh` also does this for you)
103 | ```
104 | bash ./src/data_prepare.sh
105 | ```
106 | One can simply extend the script `./src/data/data_prepare.py` to support any dataset provided by [torchvision](https://github.com/pytorch/vision), for example as sketched below.
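As an example, here is a sketch (not part of the current script) of how one might extend `data_prepare.py` to also pre-download `FashionMNIST`; the normalization statistics shown are commonly used values, not ones taken from this repository.

``` python
import torch
from torchvision import datasets, transforms

if __name__ == "__main__":
    transform = transforms.Compose([
        transforms.ToTensor(),
        # commonly used FashionMNIST mean/std; adjust as needed
        transforms.Normalize((0.2860,), (0.3530,)),
    ])
    # pre-download both splits so every node finds the data on local disk at training time
    train_set = datasets.FashionMNIST('./fashion_mnist_data', train=True,
                                      download=True, transform=transform)
    test_set = datasets.FashionMNIST('./fashion_mnist_data', train=False,
                                     download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False)
```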
107 |
108 | ## Job Launching
109 | Jobs need to be launched from the PS (or master) instance. `launch.sh` (which calls `./src/run_pytorch_dist.sh`) wraps the job-launching process up. Commonly used options (arguments) are listed as follows:
110 |
111 | | Argument | Comments |
112 | | ----------------------------- | ---------------------------------------- |
113 | | `n` | Number of processes (size of the cluster), e.g. with P compute nodes and 1 PS, n=P+1. |
114 | | `lr` | Initial learning rate to use. |
115 | | `momentum` | Momentum value to use. |
116 | | `max-steps` | The maximum number of iterations to train. |
117 | | `epochs` | The maximum number of epochs to train (somewhat redundant with `max-steps`). |
118 | | `network` | Type of deep neural net; currently `LeNet`, `ResNet-18/34/50/101/152`, and `VGGs` are supported. |
119 | | `dataset` | Dataset used for training. |
120 | | `batch-size` | Batch size for optimization algorithms. |
121 | | `eval-freq` | Frequency (in iterations) at which the model is evaluated. |
122 | | `enable-gpu` | Train on CPU/GPU; for CPU, leave this argument empty. |
123 | |`train-dir`|Directory to save model checkpoints for evaluation. |
124 |
125 | ## Model Evaluation
126 | The [distributed evaluator](https://github.com/hwang595/PyTorch-parameter-server/blob/master/src/distributed_evaluator.py) fetches model checkpoints from the shared directory and evaluates the model on the validation set.
127 | To evaluate the model, you can run
128 | ```
129 | bash ./src/evaluate_pytorch.sh
130 | ```
131 | with the arguments specified below.
132 |
133 | Evaluation arguments are listed as follows:
134 |
135 | | Argument | Comments |
136 | | ----------------------------- | ---------------------------------------- |
137 | | `eval-batch-size` | Batch size (on validation set) used during model evaluation. |
138 | | `eval-freq` | Frequency (in iterations) at which the model is evaluated; should be set to the same value as in [run_pytorch_dist.sh](https://github.com/hwang595/ps_pytorch/blob/master/src/run_pytorch.sh). |
139 | | `network` | Type of deep neural net; should be set to the same value as in [run_pytorch_dist.sh](https://github.com/hwang595/PyTorch-parameter-server/blob/master/src/run_pytorch_dist.sh). |
140 | | `dataset` | Dataset used for training; should be set to the same value as in [run_pytorch_dist.sh](https://github.com/hwang595/PyTorch-parameter-server/blob/master/src/run_pytorch_dist.sh). |
141 | | `model-dir` | Directory to save model checkpoints for evaluation, should be set to the same value as [run_pytorch_dist.sh](https://github.com/hwang595/PyTorch-parameter-server/blob/master/src/run_pytorch_dist.sh). |
142 |
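These values must match the training run because the worker and the evaluator share a simple file-name contract for checkpoints: the worker saves its `state_dict` as `<train-dir>model_step_<global step>` (see `_save_model`/`_generate_model_path` in `distributed_worker.py`), and the evaluator polls the shared directory for the next expected step (see `_model_dir_generator`/`_load_model` in `distributed_evaluator.py`). A minimal sketch of that contract, with a hypothetical local path standing in for the shared NFS directory:

``` python
import os
import time
import torch
from torch import nn

model_dir = "/tmp/shared/"   # hypothetical stand-in for the shared train-dir/model-dir
os.makedirs(model_dir, exist_ok=True)
eval_freq = 50

# worker side: every `eval_freq` steps, dump the state_dict as model_step_<global step>
network = nn.Linear(4, 2)    # stand-in for LeNet/ResNet/VGG
cur_step = 50
torch.save(network.state_dict(), os.path.join(model_dir, "model_step_" + str(cur_step)))

# evaluator side: poll the shared directory until the expected checkpoint appears, then load it
next_step_to_fetch = eval_freq
path = os.path.join(model_dir, "model_step_" + str(next_step_to_fetch))
while not os.path.isfile(path):
    time.sleep(10)
network.load_state_dict(torch.load(path))
```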
143 | ## Future work:
144 | (Please note that this project is still in an early alpha version.)
145 | 1. Overlapping computation (forward prop and backprop) with communication to gain better speedup.
146 | 2. Support an asynchronous communication mode, i.e. [Backup Workers](https://arxiv.org/pdf/1604.00981.pdf).
147 |
148 | ## Contact:
149 | Any contribution to this repo is highly appreciated.
150 | If you encounter any issues using the code base provided here, please feel free to open an issue or email [Hongyi Wang](https://hwang595.github.io/) at (hongyiwang@cs.wisc.edu) directly.
151 |
--------------------------------------------------------------------------------
/images/system_overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/PyTorch-parameter-server/08726c9ed718fe0ee65c032801f632decf79ec79/images/system_overview.jpg
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | Do remember to cast `numpy.ndarray`s to `numpy.float64` before pushing them into MPI send/receive buffers; otherwise there will be data type mismatches between numpy and MPI. A minimal illustration follows.
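The sketch below assumes `mpi4py` is installed and mirrors the `Isend([..., MPI.DOUBLE])` pattern used in `model_ops/lenet.py`; the parameter and the tag are placeholders for illustration only.

```
import numpy as np
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD

# a stand-in parameter with a gradient, just for illustration
param = torch.nn.Parameter(torch.randn(4))
param.sum().backward()

# cast to float64 so the dtype matches MPI.DOUBLE on the receiving side
grads = param.grad.data.numpy().astype(np.float64)

if comm.Get_rank() != 0:
    req = comm.Isend([grads, MPI.DOUBLE], dest=0, tag=88)
    req.wait()
else:
    recv_buf = np.zeros_like(grads)
    for src in range(1, comm.Get_size()):
        comm.Recv([recv_buf, MPI.DOUBLE], source=src, tag=88)
```
Run under `mpirun -n 2`, this mimics one worker shipping a float64 gradient buffer to rank 0.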
2 |
3 | # To use it on a single machine:
4 | ```
5 | python single_machine.py --dataset=MNIST/Cifar10 --network=LeNet/Resnet --batch-size=${BATCH_SIZE}
6 | ```
7 |
8 | # To use it on a distributed cluster:
9 | ```
10 | mpirun -n ${NUM_WORKERS} --hostfile=${HOST_DIR} python distributed_nn.py --dataset=MNIST/Cifar10 --network=LeNet/Resnet --batch-size=${BATCH_SIZE}
11 | ```
12 |
13 | # Run the whole thing automatically
14 | The first thing you need to do is launch AWS EC2 instances; you can do that using `tools/pytorch_ec2.py` by running the following command:
15 | ```
16 | python pytorch_ec2.py launch
17 | ```
18 | After the launch command is executed and all instances are initialized (this may take several minutes), you need to fetch the host addresses:
19 | ```
20 | python pytorch_ec2.py get_hosts
21 | ```
22 | Then copy the essential configuration and hosts files, using the public address of the master node (the first address in the `hosts` file):
23 | ```
24 | sh local_script.sh ${MASTER_PUB_ADDR}
25 | ```
26 | After that, log into the master node manually and run the remote script under the `$HOME` dir:
27 | ```
28 | sh remote_script.sh
29 | ```
30 | This script will do the cluster setup and data preparation for you.
31 |
--------------------------------------------------------------------------------
/src/data/data_prepare.py:
--------------------------------------------------------------------------------
1 | """
2 | Since we need to initialize the dataset in parallel, pre-downloading the data is necessary.
3 | This script does that job for you.
4 | """
5 | import torch
6 | from torchvision import datasets, transforms
7 |
8 |
9 | if __name__ == "__main__":
10 | training_set_mnist = datasets.MNIST('./mnist_data', train=True, download=True,
11 | transform=transforms.Compose([
12 | transforms.ToTensor(),
13 | transforms.Normalize((0.1307,), (0.3081,))]))
14 | train_loader_mnist = torch.utils.data.DataLoader(training_set_mnist, batch_size=128, shuffle=True)
15 | test_loader_mnist = torch.utils.data.DataLoader(
16 | datasets.MNIST('./mnist_data', train=False, transform=transforms.Compose([
17 | transforms.ToTensor(),
18 | transforms.Normalize((0.1307,), (0.3081,))
19 | ])), batch_size=100, shuffle=True)
20 | trainset_cifar10 = datasets.CIFAR10(root='./cifar10_data', train=True,
21 | download=True, transform=transforms.Compose([
22 | transforms.ToTensor(),
23 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
24 | ]))
25 | train_loader_cifar10 = torch.utils.data.DataLoader(trainset_cifar10, batch_size=128,
26 | shuffle=True)
27 | test_loader_cifar10 = torch.utils.data.DataLoader(
28 | datasets.CIFAR10('./cifar10_data', train=False, transform=transforms.Compose([
29 | transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
30 | ])), batch_size=100, shuffle=True)
31 | # load training and test set here:
32 | training_set = datasets.CIFAR100(root='./cifar100_data', train=True,
33 | download=True, transform=transforms.Compose([
34 | transforms.ToTensor(),
35 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
36 | ]))
37 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=128,
38 | shuffle=True)
39 | testset = datasets.CIFAR100(root='./cifar100_data', train=False,
40 | download=True, transform=transforms.Compose([
41 | transforms.ToTensor(),
42 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
43 | ]))
44 | test_loader = torch.utils.data.DataLoader(testset, batch_size=1000,
45 | shuffle=False)
46 |
47 | training_set = datasets.SVHN('./svhn_data', split='train', download=True, transform=transforms.Compose([
48 | transforms.RandomCrop(32, padding=4),
49 | transforms.RandomHorizontalFlip(),
50 | transforms.ToTensor(),
51 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
52 | ]))
53 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=128,
54 | shuffle=True)
55 | transform_test = transforms.Compose([
56 | transforms.ToTensor(),
57 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
58 | ])
59 | testset = datasets.SVHN(root='./svhn_data', split='test',
60 | download=True, transform=transform_test)
61 | test_loader = torch.utils.data.DataLoader(testset, batch_size=1000,
62 | shuffle=False)
--------------------------------------------------------------------------------
/src/data_prepare.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | python ./data/data_prepare.py
--------------------------------------------------------------------------------
/src/distributed_evaluator.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os.path
3 | import time
4 | import argparse
5 | from datetime import datetime
6 | import copy
7 |
8 | import numpy as np
9 |
10 | from nn_ops import NN_Trainer
11 |
12 | import torch
13 | from torch.autograd import Variable
14 | import torch.nn.functional as F
15 | from torchvision import datasets, transforms
16 | from torch.utils.data import DataLoader
17 |
18 | from model_ops.lenet import LeNet, LeNetSplit
19 | from model_ops.resnet import *
20 | from model_ops.resnet_split import *
21 | from util import build_model
22 |
23 |
24 | def accuracy(output, target, topk=(1,)):
25 | """Computes the precision@k for the specified values of k"""
26 | maxk = max(topk)
27 | batch_size = target.size(0)
28 |
29 | _, pred = output.topk(maxk, 1, True, True)
30 | pred = pred.t()
31 | correct = pred.eq(target.view(1, -1).expand_as(pred))
32 | res = []
33 | for k in topk:
34 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
35 | res.append(correct_k.mul_(100.0 / batch_size))
36 | return res
37 |
38 | def add_fit_args(parser):
39 | """
40 | parser : argparse.ArgumentParser
41 | return a parser added with args required by fit
42 | """
43 | # Validation settings
44 | parser.add_argument('--eval-batch-size', type=int, default=10000, metavar='N',
45 | help='the batch size when doing model validation, complete at once on default')
46 | parser.add_argument('--eval-freq', type=int, default=50, metavar='N',
47 | help='it determines per how many step the model should be evaluated')
48 | parser.add_argument('--model-dir', type=str, default='output/models/', metavar='N',
49 | help='directory to save the temp model during the training process for evaluation')
50 | parser.add_argument('--dataset', type=str, default='MNIST', metavar='N',
51 | help='which dataset used in training, MNIST and Cifar10 supported currently')
52 | parser.add_argument('--network', type=str, default='LeNet', metavar='N',
53 | help='which kind of network we are going to use, support LeNet and ResNet currently')
54 | args = parser.parse_args()
55 | return args
56 |
57 | class DistributedEvaluator(NN_Trainer):
58 | '''
59 | The DistributedEvaluator provides a separate node in the distributed cluster to evaluate
60 | the model on the validation/test set and return the results.
61 | In this version, the DistributedEvaluator only loads the model from the dir where the master
62 | saves the model, and does the evaluation task at a user-defined frequency.
63 | '''
64 | def __init__(self, **kwargs):
65 | self._cur_step = 0
66 | self._model_dir = kwargs['model_dir']
67 | self._eval_freq = int(kwargs['eval_freq'])
68 | self._eval_batch_size = kwargs['eval_batch_size']
69 | self.network_config = kwargs['network']
70 | # this one is used to avoid fetching the weights multiple times
71 | self._layer_cur_step = []
72 |
73 | def evaluate(self, validation_loader):
74 | # init the objective step to fetch at the beginning
75 | self._next_step_to_fetch = self._cur_step + self._eval_freq
76 | self._num_batch_per_epoch = len(validation_loader) / self._eval_batch_size
77 | # check if the next temp model exists; if not we wait here, else we continue to do the model evaluation
78 | while True:
79 | model_dir_=self._model_dir_generator(self._next_step_to_fetch)
80 | if os.path.isfile(model_dir_):
81 | self._load_model(model_dir_)
82 | print("Evaluator evaluating results on step {}".format(self._next_step_to_fetch))
83 | self._evaluate_model(validation_loader)
84 | self._next_step_to_fetch += self._eval_freq
85 | else:
86 | # TODO(hwang): sleep for an appropriate period of time; make sure to tune this parameter
87 | time.sleep(10)
88 |
89 | def _evaluate_model(self, test_loader):
90 | self.network.eval()
91 | test_loss = 0
92 | correct = 0
93 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0
94 | for data, y_batch in test_loader:
95 | data, target = Variable(data), Variable(y_batch)
96 | output = self.network(data)
97 | test_loss += F.nll_loss(F.log_softmax(output), target, size_average=False).item()
98 | prec1_tmp, prec5_tmp = accuracy(output.detach(), y_batch, topk=(1, 5))
99 | prec1_counter_ += prec1_tmp.numpy()[0]
100 | prec5_counter_ += prec5_tmp.numpy()[0]
101 | batch_counter_ += 1
102 | prec1 = prec1_counter_ / batch_counter_
103 | prec5 = prec5_counter_ / batch_counter_
104 | test_loss /= len(test_loader.dataset)
105 | print('Test set: Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(test_loss, prec1, prec5))
106 |
107 | def _load_model(self, file_path):
108 | self.network = build_model(self.network_config, num_classes=10)
109 | with open(file_path, "rb") as f_:
110 | self.network.load_state_dict(torch.load(f_))
111 |
112 | def _model_dir_generator(self, next_step_to_fetch):
113 | return self._model_dir+"model_step_"+str(next_step_to_fetch)
114 |
115 | if __name__ == "__main__":
116 | # this is only a simple test case
117 | args = add_fit_args(argparse.ArgumentParser(description='PyTorch Distributed Evaluator'))
118 |
119 | # load training and test set here:
120 | if args.dataset == "MNIST":
121 | test_loader = torch.utils.data.DataLoader(
122 | datasets.MNIST('../data', train=False, transform=transforms.Compose([
123 | transforms.ToTensor(),
124 | transforms.Normalize((0.1307,), (0.3081,))
125 | ])), batch_size=args.eval_batch_size, shuffle=True)
126 | elif args.dataset == "Cifar10":
127 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
128 | std=[x/255.0 for x in [63.0, 62.1, 66.7]])
129 | transform_test = transforms.Compose([
130 | transforms.ToTensor(),
131 | normalize])
132 | testset = datasets.CIFAR10(root='./cifar10_data', train=False,
133 | download=True, transform=transform_test)
134 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size,
135 | shuffle=False)
136 |
137 | kwargs_evaluator={
138 | 'network':args.network,
139 | 'model_dir':args.model_dir,
140 | 'eval_freq':args.eval_freq,
141 | 'eval_batch_size':args.eval_batch_size}
142 | evaluator_nn = DistributedEvaluator(**kwargs_evaluator)
143 | evaluator_nn.evaluate(validation_loader=test_loader)
--------------------------------------------------------------------------------
/src/distributed_nn.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import sys
4 | import math
5 | import threading
6 | import argparse
7 | import time
8 | import os
9 |
10 | import numpy as np
11 |
12 | import torch
13 | import torch.distributed as dist
14 | from torch.autograd import Variable
15 | from torch import nn
16 | import torch.nn.functional as F
17 |
18 | from nn_ops import NN_Trainer, accuracy
19 | from data_loader_ops.my_data_loader import DataLoader
20 |
21 | from distributed_worker import *
22 | from sync_replicas_master_nn import *
23 |
24 |
25 | def add_fit_args(parser):
26 | """
27 | parser : argparse.ArgumentParser
28 | return a parser added with args required by fit
29 | """
30 | # Training settings
31 | parser.add_argument('--batch-size', type=int, default=128, metavar='N',
32 | help='input batch size for training (default: 128)')
33 | parser.add_argument('--test-batch-size', type=int, default=500, metavar='N',
34 | help='input batch size for testing (default: 500)')
35 | parser.add_argument('--epochs', type=int, default=100, metavar='N',
36 | help='number of epochs to train (default: 100)')
37 | parser.add_argument('--max-steps', type=int, default=10000, metavar='N',
38 | help='the maximum number of iterations')
39 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
40 | help='learning rate (default: 0.01)')
41 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
42 | help='SGD momentum (default: 0.5)')
43 | parser.add_argument('--no-cuda', action='store_true', default=False,
44 | help='disables CUDA training')
45 | parser.add_argument('--seed', type=int, default=1, metavar='S',
46 | help='random seed (default: 1)')
47 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
48 | help='how many batches to wait before logging training status')
49 | parser.add_argument('--network', type=str, default='LeNet', metavar='N',
50 | help='which kind of network we are going to use, support LeNet and ResNet currently')
51 | parser.add_argument('--mode', type=str, default='normal', metavar='N',
52 | help='determine if we kill the stragglers or just implement normal training')
53 | parser.add_argument('--kill-threshold', type=float, default=7.0, metavar='KT',
54 | help='timeout threshold which triggers the killing process (default: 7s)')
55 | parser.add_argument('--dataset', type=str, default='MNIST', metavar='N',
56 | help='which dataset used in training, MNIST and Cifar10 supported currently')
57 | parser.add_argument('--comm-type', type=str, default='Bcast', metavar='N',
58 | help='which kind of method we use during the model fetching stage')
59 | parser.add_argument('--num-aggregate', type=int, default=5, metavar='N',
60 | help='how many number of gradients we wish to gather at each iteration')
61 | parser.add_argument('--eval-freq', type=int, default=50, metavar='N',
62 | help='it determines per how many step the model should be evaluated')
63 | parser.add_argument('--train-dir', type=str, default='output/models/', metavar='N',
64 | help='directory to save the temp model during the training process for evaluation')
65 | parser.add_argument('--compress-grad', type=str, default='compress', metavar='N',
66 | help='compress/none indicate if we compress the gradient matrix before communication')
67 | parser.add_argument('--gather-type', type=str, default='gather', metavar='N',
68 | help='gather/non to specify the type of comm used (MPI.Gather or point-to-point comm)')
69 | parser.add_argument('--enable-gpu', type=bool, default=False, help='whether to train on GPU; leave empty for CPU')
70 | # TODO(hwang): check what this is
71 | parser.add_argument("--local_rank", type=int)
72 | args = parser.parse_args()
73 | return args
74 |
75 | if __name__ == "__main__":
76 | rank = int(os.environ['RANK'])
77 | world_size = int(os.environ['WORLD_SIZE'])
78 | master_addr = os.environ['MASTER_ADDR']
79 | master_port = os.environ['MASTER_PORT']
80 |
81 | print(rank, world_size)
82 | dist.init_process_group(backend='gloo', world_size=world_size, rank=rank)
83 |
84 | args = add_fit_args(argparse.ArgumentParser(description='PyTorch MNIST Single Machine Test'))
85 |
86 | train_loader, test_loader = prepare_data(args)
87 |
88 | device = torch.device("cuda" if args.enable_gpu else "cpu")
89 |
90 | kwargs_master = {
91 | 'world_size':world_size,
92 | 'batch_size':args.batch_size,
93 | 'learning_rate':args.lr,
94 | 'max_epochs':args.epochs,
95 | 'momentum':args.momentum,
96 | 'network':args.network,
97 | 'comm_method':args.comm_type,
98 | 'kill_threshold': args.num_aggregate,
99 | 'timeout_threshold':args.kill_threshold,
100 | 'eval_freq':args.eval_freq,
101 | 'train_dir':args.train_dir,
102 | 'max_steps':args.max_steps,
103 | 'compress_grad':args.compress_grad,
104 | 'gather_type':args.gather_type,
105 | 'device':device}
106 |
107 | kwargs_worker = {
108 | 'rank':rank,
109 | 'batch_size':args.batch_size,
110 | 'learning_rate':args.lr,
111 | 'max_epochs':args.epochs,
112 | 'momentum':args.momentum,
113 | 'network':args.network,
114 | 'comm_method':args.comm_type,
115 | 'kill_threshold':args.kill_threshold,
116 | 'eval_freq':args.eval_freq,
117 | 'train_dir':args.train_dir,
118 | 'max_steps':args.max_steps,
119 | 'compress_grad':args.compress_grad,
120 | 'gather_type':args.gather_type,
121 | 'device':device}
122 |
123 | if rank == 0:
124 | master_fc_nn = SyncReplicasMaster_NN(**kwargs_master)
125 | if args.dataset == 'Cifar100':
126 | master_fc_nn.build_model(num_classes=100)
127 | else:
128 | master_fc_nn.build_model(num_classes=10)
129 | print("I am the master: the world size is {}, cur step: {}".format(master_fc_nn.world_size, master_fc_nn.cur_step))
130 | master_fc_nn.start()
131 | print("Done sending messages to workers!")
132 | else:
133 | worker_fc_nn = DistributedWorker(**kwargs_worker)
134 | if args.dataset == 'Cifar100':
135 | worker_fc_nn.build_model(num_classes=100)
136 | else:
137 | worker_fc_nn.build_model(num_classes=10)
138 | print("I am worker: {} in all {} workers, next step: {}".format(worker_fc_nn.rank, worker_fc_nn.world_size-1, worker_fc_nn.next_step))
139 | worker_fc_nn.train(train_loader=train_loader, test_loader=test_loader)
140 | print("Now the next step is: {}".format(worker_fc_nn.next_step))
--------------------------------------------------------------------------------
/src/distributed_worker.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 |
4 | from nn_ops import NN_Trainer
5 | from util import *
6 |
7 | import torch
8 | from torch.autograd import Variable
9 | import torch.distributed as dist
10 |
11 | import time
12 | from datetime import datetime
13 | import copy
14 | import logging
15 | from sys import getsizeof
16 |
17 | STEP_START_ = 1
18 | TAG_LIST_ = [i*30 for i in range(50000)]
19 |
20 | logging.basicConfig()
21 | logger = logging.getLogger()
22 | logger.setLevel(logging.INFO)
23 |
24 |
25 | def accuracy(output, target, topk=(1,)):
26 | """Computes the precision@k for the specified values of k"""
27 | maxk = max(topk)
28 | batch_size = target.size(0)
29 |
30 | _, pred = output.topk(maxk, 1, True, True)
31 | pred = pred.t()
32 | correct = pred.eq(target.view(1, -1).expand_as(pred))
33 | res = []
34 | for k in topk:
35 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
36 | res.append(correct_k.mul_(100.0 / batch_size))
37 | return res
38 |
39 | class ModelBuffer(object):
40 | def __init__(self, network):
41 | """
42 | this class is used to save model weights received from parameter server
43 | current step for each layer of model will also be updated here to make sure
44 | the model is always up-to-date
45 | """
46 | super(ModelBuffer, self).__init__()
47 | self.recv_buf = []
48 | self.layer_cur_step = []
49 | self.layer_shape = []
50 | '''
51 | initialize space to receive model from parameter server
52 | '''
53 | # considering we don't want to update the params of the `BatchNorm` layer right now,
54 | # we temporarily deprecate the foregoing version and only update the model
55 | # parameters
56 | for param_idx, param in enumerate(network.parameters()):
57 | self.recv_buf.append(torch.zeros(param.size()))
58 |
59 |
60 | class DistributedWorker(NN_Trainer):
61 | def __init__(self, **kwargs):
62 | super(NN_Trainer, self).__init__()
63 |
64 | self.cur_step = 0
65 | self.next_step = 0 # we will fetch this one from parameter server
66 |
67 | self.rank = kwargs['rank']
68 | self.batch_size = kwargs['batch_size']
69 | self.max_epochs = kwargs['max_epochs']
70 | self.momentum = kwargs['momentum']
71 | self.lr = kwargs['learning_rate']
72 | self._max_steps = kwargs['max_steps']
73 | self.network_config = kwargs['network']
74 | self.comm_type = kwargs['comm_method']
75 | self.kill_threshold = kwargs['kill_threshold']
76 | self._eval_batch_size = 100
77 | self._eval_freq = kwargs['eval_freq']
78 | self._train_dir = kwargs['train_dir']
79 | self._compress_grad = kwargs['compress_grad']
80 | self._gather_type = kwargs['gather_type']
81 | self._device = kwargs['device']
82 |
83 | # this one is used to avoid fetching the weights multiple times
84 | self._layer_cur_step = []
85 |
86 | def build_model(self, num_classes=10):
87 | self.network = build_model(self.network_config, num_classes)
88 | # set up optimizer
89 | self.optimizer = torch.optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum)
90 | self.criterion = nn.CrossEntropyLoss()
91 | # assign a buffer for receiving models from parameter server
92 | self.init_recv_buf()
93 |
94 | self.network.to(self._device)
95 |
96 | def train(self, train_loader, test_loader):
97 | # the first thing we need to do here is to synchronously fetch the initial global step from the parameter server
98 | # we still need to make sure the value we fetched from parameter server is 1
99 |
100 | # number of batches in one epoch
101 | iteration_last_step=0
102 | iter_start_time=0
103 | first = True
104 |
105 | logger.info("Worker {}: starting training".format(self.rank))
106 | # start the training process
107 | for num_epoch in range(self.max_epochs):
108 | for batch_idx, (train_image_batch, train_label_batch) in enumerate(train_loader):
109 |
110 | iter_start_time = time.time()
111 | # worker exit task
112 | if self.cur_step == self._max_steps:
113 | break
114 | X_batch, y_batch = train_image_batch.to(self._device), train_label_batch.to(self._device)
115 |
116 | # bcast communication stage
117 | fetch_weight_start = time.time()
118 | self._fetch_weight()
119 | fetch_weight_dur = time.time() - fetch_weight_start
120 |
121 | comp_start = time.time()
122 | self._train_init()
123 | loss, logits = self._forward(X_batch, y_batch)
124 | loss.backward()
125 | comp_dur = time.time() - comp_start
126 |
127 | prec1, prec5 = accuracy(logits.detach(), train_label_batch.long(), topk=(1, 5))
128 | gather_start = time.time()
129 | self._send_grads()
130 | gather_dur = time.time() - gather_start
131 |
132 |
133 | log_format = 'Worker: {}, Step: {}, Epoch: {} [{}/{} ({:.0f}%)], Loss: {:.4f}, Time Cost: {:.4f}, FetchWeight: {:.4f}, Computation: {:.4f}, GatherTime: {:.4f}, Acc: {:.4f}'
134 | logger.info(log_format.format(self.rank,
135 | self.cur_step, num_epoch, batch_idx * self.batch_size, len(train_loader.dataset),
136 | (100. * (batch_idx * self.batch_size) / len(train_loader.dataset)), loss.item(),
137 | time.time()-iter_start_time, fetch_weight_dur, comp_dur, gather_dur, prec1.numpy()[0]))
138 |
139 | if self.cur_step%self._eval_freq == 0:
140 | self._save_model(file_path=self._generate_model_path())
141 |
142 | def init_recv_buf(self):
143 | self.model_recv_buf = ModelBuffer(self.network)
144 |
145 | def _train_init(self):
146 | self.network.train()
147 | self.optimizer.zero_grad()
148 |
149 | def _forward(self, X_batch, y_batch):
150 | logits = self.network(X_batch)
151 | return self.criterion(logits, y_batch), logits
152 |
153 | def _fetch_weight(self):
154 | for layer_idx, layer in enumerate(self.model_recv_buf.recv_buf):
155 | dist.broadcast(self.model_recv_buf.recv_buf[layer_idx], src=0)
156 | self.model_update(self.model_recv_buf.recv_buf)
157 | # Note that here we update the global step
158 | self.cur_step += 1
159 |
160 | def update_step(self):
161 | '''update local (global) step on worker'''
162 | changed = (self.cur_step != self.next_step)
163 | self.cur_step = self.next_step
164 | return changed
165 |
166 | def model_update(self, weights_to_update):
167 | """write model fetched from parameter server to local model"""
168 | new_state_dict = {}
169 | model_counter_ = 0
170 | for param_idx,(key_name, param) in enumerate(self.network.state_dict().items()):
171 | # handle the case that `running_mean` and `running_var` contained in `BatchNorm` layer
172 | if "running_mean" in key_name or "running_var" in key_name or "num_batches_tracked" in key_name:
173 | tmp_dict={key_name: param}
174 | else:
175 | assert param.size() == weights_to_update[model_counter_].size()
176 | tmp_dict = {key_name: weights_to_update[model_counter_].to(self._device)}
177 | model_counter_ += 1
178 | new_state_dict.update(tmp_dict)
179 | self.network.load_state_dict(new_state_dict)
180 |
181 | def _send_grads(self):
182 | for p_index, p in enumerate(self.network.parameters()):
183 | # fetch the grad we need
184 | if self._device.type == "cuda":
185 | grad = p.grad.to(torch.device("cpu")).detach()
186 | else:
187 | grad = p.grad.detach()
188 |
189 | dist.gather(grad, [], dst=0)
190 |
191 | def _evaluate_model(self, test_loader):
192 | self.network.eval()
193 | test_loss = 0
194 | correct = 0
195 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0
196 | for data, y_batch in test_loader:
197 | data, target = data.to(self._device), y_batch.to(self._device)
198 |
199 | output = self.network(data)
200 | test_loss += F.nll_loss(F.log_softmax(output), target, size_average=False).item() # sum up batch loss
201 |
202 | prec1_tmp, prec5_tmp = accuracy(output.detach(), y_batch, topk=(1, 5))
203 |
204 | if self._device.type == 'cuda':
205 | prec1_counter_ += prec1_tmp.to(torch.device("cpu")).numpy()[0]
206 | prec5_counter_ += prec5_tmp.to(torch.device("cpu")).numpy()[0]
207 | else:
208 | prec1_counter_ += prec1_tmp.numpy()[0]
209 | prec5_counter_ += prec5_tmp.numpy()[0]
210 |
211 | batch_counter_ += 1
212 | prec1 = prec1_counter_ / batch_counter_
213 | prec5 = prec5_counter_ / batch_counter_
214 | test_loss /= len(test_loader.dataset)
215 | print('Test set: Step: {}, Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(self.cur_step,
216 | test_loss, prec1, prec5))
217 |
218 | def _generate_model_path(self):
219 | return self._train_dir+"model_step_"+str(self.cur_step)
220 |
221 | def _save_model(self, file_path):
222 | with open(file_path, "wb") as f_:
223 | torch.save(self.network.state_dict(), f_)
224 | return
225 |
226 | if __name__ == "__main__":
227 | # this is only a simple test case
228 | comm = MPI.COMM_WORLD
229 | rank = comm.Get_rank()
230 | world_size = comm.Get_size()
231 | worker_fc_nn = WorkerFC_NN(comm=comm, world_size=world_size, rank=rank)
232 | print("I am worker: {} in all {} workers".format(worker_fc_nn.rank, worker_fc_nn.world_size))
--------------------------------------------------------------------------------
/src/evaluate_pytorch.sh:
--------------------------------------------------------------------------------
1 | python distributed_evaluator.py \
2 | --eval-batch-size=10000 \
3 | --eval-freq=50 \
4 | --network=ResNet18 \
5 | --dataset=Cifar10 \
6 | --model-dir=/home/ubuntu/MPI_shared/
--------------------------------------------------------------------------------
/src/launch.sh:
--------------------------------------------------------------------------------
1 | KEY_PEM_NAME=HongyiScript.pem
2 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`
3 | MASTER_PUB_IP="$1"
4 | WORKING_DIR=${HOME}/ps_real_pytorch/src
5 |
6 | for i in $(seq 1 $DEEPLEARNING_WORKERS_COUNT);
7 | do
8 | ssh -i ${HOME}/.ssh/${KEY_PEM_NAME} deeplearning-worker${i} "cd ${WORKING_DIR}; nohup bash ${WORKING_DIR}/run_pytorch_dist.sh \"$((${i}-1))\" \"${DEEPLEARNING_WORKERS_COUNT}\" \"${MASTER_PUB_IP}\" &>/dev/null &"
9 | done
--------------------------------------------------------------------------------
/src/model_ops/__init__.py:
--------------------------------------------------------------------------------
1 | from . import lenet, resnet, resnet_split, vgg
2 |
3 | __all__ = ['lenet', 'resnet', 'resnet_split', 'vgg']
--------------------------------------------------------------------------------
/src/model_ops/lenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | import pandas as pd
6 | import numpy as np
7 | from torch.autograd import Variable
8 |
9 | from mpi4py import MPI
10 |
11 | import sys
12 | sys.path.insert(0, '../compression')
13 | from compression import g_compress
14 |
15 | # we use LeNet here for our simple case
16 | class LeNet(nn.Module):
17 | def __init__(self):
18 | super(LeNet, self).__init__()
19 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
20 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
21 | self.fc1 = nn.Linear(4*4*50, 500)
22 | self.fc2 = nn.Linear(500, 10)
23 | self.ceriation = nn.CrossEntropyLoss()
24 | def forward(self, x):
25 | x = self.conv1(x)
26 | x = F.max_pool2d(x, 2, 2)
27 | x = F.relu(x)
28 | x = self.conv2(x)
29 | x = F.max_pool2d(x, 2, 2)
30 | x = F.relu(x)
31 | x = x.view(-1, 4*4*50)
32 | x = self.fc1(x)
33 | x = self.fc2(x)
34 | #loss = self.ceriation(x, target)
35 | return x
36 | def name(self):
37 | return 'lenet'
38 |
39 | class LeNetSplit(nn.Module):
40 | '''
41 | this module splits the model and runs the backward pass layer by layer;
42 | please don't use it for normal training, this is a hack and runs slower than
43 | the automatic chain-rule version
44 | '''
45 | def __init__(self):
46 | super(LeNetSplit, self).__init__()
47 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
48 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
49 | self.fc1 = nn.Linear(4*4*50, 500)
50 | self.fc2 = nn.Linear(500, 10)
51 |
52 | self.maxpool2d = nn.MaxPool2d(2, stride=2)
53 | self.relu = nn.ReLU()
54 |
55 | self.full_modules = [self.conv1, self.conv2, self.fc1, self.fc2]
56 | self._init_channel_index = len(self.full_modules)*2
57 |
58 | self.criterion = nn.CrossEntropyLoss()
59 |
60 | def forward(self, x):
61 | self.output = []
62 | self.input = []
63 | x = Variable(x.data, requires_grad=True)
64 | self.input.append(x)
65 | x = self.conv1(x)
66 | self.output.append(x)
67 |
68 | x = Variable(x.data, requires_grad=True)
69 | self.input.append(x)
70 | x = self.maxpool2d(x)
71 | self.output.append(x)
72 |
73 | x = Variable(x.data, requires_grad=True)
74 | self.input.append(x)
75 | x = self.relu(x)
76 | self.output.append(x)
77 |
78 | x = Variable(x.data, requires_grad=True)
79 | self.input.append(x)
80 | x = self.conv2(x)
81 | self.output.append(x)
82 |
83 | x = Variable(x.data, requires_grad=True)
84 | self.input.append(x)
85 | x = self.maxpool2d(x)
86 | self.output.append(x)
87 |
88 | x = Variable(x.data, requires_grad=True)
89 | self.input.append(x)
90 | x = self.relu(x)
91 | self.output.append(x)
92 |
93 | x = x.view(-1, 4*4*50)
94 |
95 | x = Variable(x.data, requires_grad=True)
96 | self.input.append(x)
97 | x = self.fc1(x)
98 | self.output.append(x)
99 |
100 | x = Variable(x.data, requires_grad=True)
101 | self.input.append(x)
102 | x = self.fc2(x)
103 | self.output.append(x)
104 | return x
105 |
106 | @property
107 | def fetch_init_channel_index(self):
108 | return self._init_channel_index
109 |
110 | def backward_normal(self, g, communicator, req_send_check, cur_step, compress_grad):
111 | mod_avail_index = len(self.full_modules)-1
112 | #channel_index = len(self.full_modules)*2-2
113 | channel_index = self._init_channel_index - 2
114 | mod_counters_ = [0]*len(self.full_modules)
115 | for i, output in reversed(list(enumerate(self.output))):
116 | req_send_check[-1].wait()
117 | if i == (len(self.output) - 1):
118 | # for last node, use g
119 | output.backward(g)
120 | # get gradient here after some sanity checks:
121 | tmp_grad = self.full_modules[mod_avail_index].weight.grad
122 | if not pd.isnull(tmp_grad):
123 | grads = tmp_grad.data.numpy().astype(np.float64)
124 | ###############################################################################################
125 | if compress_grad == 'compress':
126 | _compressed_grad = g_compress(grads)
127 | req_isend = communicator.isend(_compressed_grad, dest=0, tag=88+channel_index)
128 | else:
129 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
130 | ################################################################################################
131 | req_send_check.append(req_isend)
132 | # update counters
133 | mod_avail_index-=1
134 | channel_index-=1
135 | else:
136 | continue
137 | else:
138 | output.backward(self.input[i+1].grad.data)
139 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
140 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
141 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias):
142 | # we always send bias first
143 | if mod_counters_[mod_avail_index] == 0:
144 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
145 | ###############################################################################################
146 | if compress_grad == 'compress':
147 | _compressed_grad = g_compress(grads)
148 | req_isend = communicator.isend(_compressed_grad, dest=0, tag=88+channel_index)
149 | else:
150 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
151 | ################################################################################################
152 | req_send_check.append(req_isend)
153 | channel_index-=1
154 | mod_counters_[mod_avail_index]+=1
155 | elif mod_counters_[mod_avail_index] == 1:
156 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
157 | ###############################################################################################
158 | if compress_grad == 'compress':
159 | _compressed_grad = g_compress(grads)
160 | req_isend = communicator.isend(_compressed_grad, dest=0, tag=88+channel_index)
161 | else:
162 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
163 | ################################################################################################
164 | req_send_check.append(req_isend)
165 | channel_index-=1
166 | mod_counters_[mod_avail_index]+=1
167 | # update counters
168 | mod_avail_index-=1
169 | else:
170 | continue
171 | if mod_counters_[0] == 1:
172 | req_send_check[-1].wait()
173 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
174 | ###############################################################################################
175 | if compress_grad == 'compress':
176 | _compressed_grad = g_compress(grads)
177 | req_isend = communicator.isend(_compressed_grad, dest=0, tag=88+channel_index)
178 | else:
179 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
180 | ################################################################################################
181 | req_send_check.append(req_isend)
182 | # for debugging here:
183 | for req in req_send_check:
184 | req.wait()
185 | return req_send_check
186 |
187 | def backward_signal_kill(self, g, communicator, req_send_check, cur_step):
188 | '''
189 | This killer is triggered by signals broadcast from the master; each worker keeps
190 | checking the signal channel to determine whether it is the
191 | straggler
192 | '''
193 | mod_avail_index = len(self.full_modules)-1
194 | channel_index = self._init_channel_index - 2
195 | mod_counters_ = [0]*len(self.full_modules)
196 |
197 | # should kill flag
198 | should_kill = False
199 |
200 | for i, output in reversed(list(enumerate(self.output))):
201 | ############################ killing process on workers #####################################
202 | for _ in range(10000):
203 | status = MPI.Status()
204 | communicator.Iprobe(0, 77, status)
205 | if status.Get_source() == 0:
206 | print("Worker {}, Cur Step: {} I'm the straggler, killing myself!".format(communicator.Get_rank(), cur_step))
207 | tmp = communicator.recv(source=0, tag=77)
208 | should_kill = True
209 | break
210 | if should_kill:
211 | break
212 | ############################################################################################
213 |
214 | if i == (len(self.output) - 1):
215 | # for last node, use g
216 | output.backward(g)
217 | # get gradient here after some sanity checks:
218 | tmp_grad = self.full_modules[mod_avail_index].weight.grad
219 | if not pd.isnull(tmp_grad):
220 | grads = tmp_grad.data.numpy().astype(np.float64)
221 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
222 | req_send_check.append(req_isend)
223 | # update counters
224 | mod_avail_index-=1
225 | channel_index-=1
226 | else:
227 | continue
228 | else:
229 | output.backward(self.input[i+1].grad.data)
230 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad
231 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad
232 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias):
233 | # we always send bias first
234 | if mod_counters_[mod_avail_index] == 0:
235 | grads = tmp_grad_bias.data.numpy().astype(np.float64)
236 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
237 | req_send_check.append(req_isend)
238 | channel_index-=1
239 | mod_counters_[mod_avail_index]+=1
240 | elif mod_counters_[mod_avail_index] == 1:
241 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
242 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
243 | req_send_check.append(req_isend)
244 | channel_index-=1
245 | mod_counters_[mod_avail_index]+=1
246 | # update counters
247 | mod_avail_index-=1
248 | else:
249 | continue
250 | if mod_counters_[0] == 1:
251 | grads = tmp_grad_weight.data.numpy().astype(np.float64)
252 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index)
253 | req_send_check.append(req_isend)
254 | return req_send_check
255 |
256 | def backward_timeout_kill(self, g, communicator, req_send_check):
257 | """do we even need this?"""
258 | pass
--------------------------------------------------------------------------------
/src/model_ops/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | For Pre-activation ResNet, see 'preact_resnet.py'.
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from torch.autograd import Variable
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, in_planes, planes, stride=1):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 |
24 | self.shortcut = nn.Sequential()
25 | if stride != 1 or in_planes != self.expansion*planes:
26 | self.shortcut = nn.Sequential(
27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
28 | nn.BatchNorm2d(self.expansion*planes)
29 | )
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = self.bn2(self.conv2(out))
34 | out += self.shortcut(x)
35 | out = F.relu(out)
36 | return out
37 |
38 |
39 | class Bottleneck(nn.Module):
40 | expansion = 4
41 |
42 | def __init__(self, in_planes, planes, stride=1):
43 | super(Bottleneck, self).__init__()
44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
45 | self.bn1 = nn.BatchNorm2d(planes)
46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
47 | self.bn2 = nn.BatchNorm2d(planes)
48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
50 |
51 | self.shortcut = nn.Sequential()
52 | if stride != 1 or in_planes != self.expansion*planes:
53 | self.shortcut = nn.Sequential(
54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
55 | nn.BatchNorm2d(self.expansion*planes)
56 | )
57 |
58 | def forward(self, x):
59 | out = F.relu(self.bn1(self.conv1(x)))
60 | out = F.relu(self.bn2(self.conv2(out)))
61 | out = self.bn3(self.conv3(out))
62 | out += self.shortcut(x)
63 | out = F.relu(out)
64 | return out
65 |
66 |
67 | class ResNet(nn.Module):
68 | def __init__(self, block, num_blocks, num_classes=10):
69 | super(ResNet, self).__init__()
70 | self.in_planes = 64
71 |
72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
73 | self.bn1 = nn.BatchNorm2d(64)
74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
78 | self.linear = nn.Linear(512*block.expansion, num_classes)
79 |
80 | def _make_layer(self, block, planes, num_blocks, stride):
81 | strides = [stride] + [1]*(num_blocks-1)
82 | layers = []
83 | for stride in strides:
84 | layers.append(block(self.in_planes, planes, stride))
85 | self.in_planes = planes * block.expansion
86 | return nn.Sequential(*layers)
87 |
88 | def forward(self, x):
89 | out = F.relu(self.bn1(self.conv1(x)))
90 | out = self.layer1(out)
91 | out = self.layer2(out)
92 | out = self.layer3(out)
93 | out = self.layer4(out)
94 | out = F.avg_pool2d(out, 4)
95 | out = out.view(out.size(0), -1)
96 | out = self.linear(out)
97 | return out
98 |
99 |
100 | def ResNet18(num_classes):
101 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes)
102 |
103 | def ResNet34(num_classes):
104 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes)
105 |
106 | def ResNet50():
107 | return ResNet(Bottleneck, [3,4,6,3])
108 |
109 | def ResNet101():
110 | return ResNet(Bottleneck, [3,4,23,3])
111 |
112 | def ResNet152():
113 | return ResNet(Bottleneck, [3,8,36,3])
--------------------------------------------------------------------------------
/src/model_ops/vgg.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://github.com/pytorch/vision.git
3 | '''
4 | import math
5 |
6 | import torch.nn as nn
7 | import torch.nn.init as init
8 |
9 | __all__ = [
10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
11 | 'vgg19_bn', 'vgg19',
12 | ]
13 |
14 |
15 | class VGG(nn.Module):
16 | '''
17 | VGG model
18 | '''
19 | def __init__(self, features, num_classes=10):
20 | super(VGG, self).__init__()
21 | self.features = features
22 | self.classifier = nn.Sequential(
23 | nn.Dropout(),
24 | nn.Linear(512, 512),
25 | nn.ReLU(True),
26 | nn.Dropout(),
27 | nn.Linear(512, 512),
28 | nn.ReLU(True),
29 | nn.Linear(512, num_classes),
30 | )
31 | # Initialize weights
32 | for m in self.modules():
33 | if isinstance(m, nn.Conv2d):
34 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
35 | m.weight.data.normal_(0, math.sqrt(2. / n))
36 | m.bias.data.zero_()
37 |
38 |
39 | def forward(self, x):
40 | x = self.features(x)
41 | x = x.view(x.size(0), -1)
42 | x = self.classifier(x)
43 | return x
44 |
45 |
46 | def make_layers(cfg, batch_norm=False):
47 | layers = []
48 | in_channels = 3
49 | for v in cfg:
50 | if v == 'M':
51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
52 | else:
53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
54 | if batch_norm:
55 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
56 | else:
57 | layers += [conv2d, nn.ReLU(inplace=True)]
58 | in_channels = v
59 | return nn.Sequential(*layers)
60 |
61 |
62 | cfg = {
63 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
64 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
65 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
66 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
67 | 512, 512, 512, 512, 'M'],
68 | }
69 |
70 |
71 | def vgg11():
72 | """VGG 11-layer model (configuration "A")"""
73 | return VGG(make_layers(cfg['A']))
74 |
75 |
76 | def vgg11_bn(num_classes=10):
77 | """VGG 11-layer model (configuration "A") with batch normalization"""
78 | return VGG(make_layers(cfg['A'], batch_norm=True), num_classes=num_classes)
79 |
80 |
81 | def vgg13():
82 | """VGG 13-layer model (configuration "B")"""
83 | return VGG(make_layers(cfg['B']))
84 |
85 |
86 | def vgg13_bn():
87 | """VGG 13-layer model (configuration "B") with batch normalization"""
88 | return VGG(make_layers(cfg['B'], batch_norm=True))
89 |
90 |
91 | def vgg16():
92 | """VGG 16-layer model (configuration "D")"""
93 | return VGG(make_layers(cfg['D']))
94 |
95 |
96 | def vgg16_bn():
97 | """VGG 16-layer model (configuration "D") with batch normalization"""
98 | return VGG(make_layers(cfg['D'], batch_norm=True))
99 |
100 |
101 | def vgg19():
102 | """VGG 19-layer model (configuration "E")"""
103 | return VGG(make_layers(cfg['E']))
104 |
105 |
106 | def vgg19_bn():
107 | """VGG 19-layer model (configuration 'E') with batch normalization"""
108 | return VGG(make_layers(cfg['E'], batch_norm=True))
--------------------------------------------------------------------------------
/src/nn_ops.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | from torch.autograd import Variable
8 |
9 | from model_ops.lenet import LeNet
10 | from model_ops.resnet import *
11 |
12 | '''this is a trial example, we use MNIST on LeNet for simple test here'''
13 | def accuracy(output, target, topk=(1,)):
14 | """Computes the precision@k for the specified values of k"""
15 | maxk = max(topk)
16 | batch_size = target.size(0)
17 |
18 | _, pred = output.topk(maxk, 1, True, True)
19 | pred = pred.t()
20 | correct = pred.eq(target.view(1, -1).expand_as(pred))
21 |
22 | res = []
23 | for k in topk:
24 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
25 | res.append(correct_k.mul_(100.0 / batch_size))
26 | return res
27 |
28 | class NN_Trainer(object):
29 | def __init__(self, **kwargs):
30 | self.batch_size = kwargs['batch_size']
31 | self.lr = kwargs['learning_rate']
32 | self.max_epochs = kwargs['max_epochs']
33 | self.momentum = kwargs['momentum']
34 | self.network_config = kwargs['network']
35 |
36 | def build_model(self):
37 | # build network
38 | if self.network_config == "LeNet":
39 | self.network=LeNet()
40 | elif self.network_config == "ResNet":
41 | #self.network=ResNet18()
42 | self.network=ResNetSplit18(1)
43 | # set up optimizer
44 | self.optimizer = torch.optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum)
45 | self.criterion = torch.nn.CrossEntropyLoss()
46 |
47 | def train_and_validate(self, train_loader, test_loader):
48 | # iterate of epochs
49 | for i in range(self.max_epochs):
50 | # change back to training mode
51 | self.network.train()
52 | for batch_idx, (data, y_batch) in enumerate(train_loader):
53 | iter_start_time = time.time()
54 | data, target = Variable(data), Variable(y_batch)
55 | self.optimizer.zero_grad()
56 | ################# backward on normal model ############################
57 |
58 | logits = self.network(data)
59 | loss = self.criterion(logits, target)
60 | loss.backward()
61 | #######################################################################
62 |
63 | ################ backward on splitted model ###########################
64 | #logits = self.network(data)
65 | #logits_1 = Variable(logits.data, requires_grad=True)
66 | #loss = self.criterion(logits_1, target)
67 | #loss.backward()
68 | #self.network.backward_single(logits_1.grad)
69 | #######################################################################
70 | tmp_time_0 = time.time()
71 |
72 | duration_backward = time.time()-tmp_time_0
73 |
74 | tmp_time_1 = time.time()
75 | self.optimizer.step()
76 | duration_update = time.time()-tmp_time_1
77 |
78 | # calculate training accuracy
79 | prec1, prec5 = accuracy(logits.data, y_batch, topk=(1, 5))
80 | # load the training info
81 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Prec@1: {} Prec@5: {} Time Cost: {}'.format(
82 | i, batch_idx * len(data), len(train_loader.dataset),
83 | 100. * batch_idx / len(train_loader), loss.data[0],
84 | prec1.numpy()[0],
85 | prec5.numpy()[0], time.time()-iter_start_time))
86 | # we evaluate the model performance on end of each epoch
87 | self.validate(test_loader)
88 |
89 | def validate(self, test_loader):
90 | self.network.eval()
91 | test_loss = 0
92 | correct = 0
93 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0
94 | for data, y_batch in test_loader:
95 | data, target = Variable(data, volatile=True), Variable(y_batch)
96 | output = self.network(data)
97 | test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
98 | prec1_tmp, prec5_tmp = accuracy(output.data, y_batch, topk=(1, 5))
99 | prec1_counter_ += prec1_tmp.numpy()[0]
100 | prec5_counter_ += prec5_tmp.numpy()[0]
101 | batch_counter_ += 1
102 | prec1 = prec1_counter_ / batch_counter_
103 | prec5 = prec5_counter_ / batch_counter_
104 | test_loss /= len(test_loader.dataset)
105 | print('Test set: Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(test_loss, prec1, prec5))
106 |
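107 | # Usage sketch (added comment; not in the original file): with MNIST loaders built
108 | # elsewhere, e.g. via util.prepare_data, a single-process run of this trainer would
109 | # look roughly like:
110 | #   trainer = NN_Trainer(batch_size=128, learning_rate=0.01, max_epochs=10,
111 | #                        momentum=0.9, network="LeNet")
112 | #   trainer.build_model()
113 | #   trainer.train_and_validate(train_loader, test_loader)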
--------------------------------------------------------------------------------
/src/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from . import adam, sgd
2 |
3 | __all__ = ['adam', 'sgd']
--------------------------------------------------------------------------------
/src/optim/adam.py:
--------------------------------------------------------------------------------
1 | '''
2 | modified version of Adam optimizer
3 | by Hongyi Wang
4 | '''
5 | import sys
6 |
7 | import math
8 | import torch
9 | from torch.optim import Optimizer
10 |
11 |
12 | class Adam(Optimizer):
13 | """Implements Adam algorithm.
14 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
15 | Arguments:
16 | params (iterable): iterable of parameters to optimize or dicts defining
17 | parameter groups
18 | lr (float, optional): learning rate (default: 1e-3)
19 | betas (Tuple[float, float], optional): coefficients used for computing
20 | running averages of gradient and its square (default: (0.9, 0.999))
21 | eps (float, optional): term added to the denominator to improve
22 | numerical stability (default: 1e-8)
23 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
25 | algorithm from the paper `On the Convergence of Adam and Beyond`_
26 | .. _Adam\: A Method for Stochastic Optimization:
27 | https://arxiv.org/abs/1412.6980
28 | .. _On the Convergence of Adam and Beyond:
29 | https://openreview.net/forum?id=ryQu7f-RZ
30 | """
31 |
32 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
33 | weight_decay=0, amsgrad=False):
34 | defaults = dict(lr=lr, betas=betas, eps=eps,
35 | weight_decay=weight_decay, amsgrad=amsgrad)
36 | super(Adam, self).__init__(params, defaults)
37 |
38 | def step(self, grads, closure=None):
39 | """Performs a single optimization step.
40 | Arguments:
41 | closure (callable, optional): A closure that reevaluates the model
42 | and returns the loss.
43 | """
44 | loss = None
45 | if closure is not None:
46 | loss = closure()
47 |
48 | for group in self.param_groups:
49 | for i,p in enumerate(group['params']):
50 | grad = torch.from_numpy(grads[i]).float()
51 | if grad.is_sparse:
52 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
53 | amsgrad = group['amsgrad']
54 |
55 | state = self.state[p]
56 |
57 | # State initialization
58 | if len(state) == 0:
59 | state['step'] = 0
60 | # Exponential moving average of gradient values
61 | state['exp_avg'] = torch.zeros_like(p.data)
62 | # Exponential moving average of squared gradient values
63 | state['exp_avg_sq'] = torch.zeros_like(p.data)
64 | if amsgrad:
65 | # Maintains max of all exp. moving avg. of sq. grad. values
66 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
67 |
68 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
69 | if amsgrad:
70 | max_exp_avg_sq = state['max_exp_avg_sq']
71 | beta1, beta2 = group['betas']
72 |
73 | state['step'] += 1
74 |
75 | if group['weight_decay'] != 0:
76 | grad = grad.add(group['weight_decay'], p.data)
77 |
78 | # Decay the first and second moment running average coefficient
79 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
80 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
81 | if amsgrad:
82 | # Maintains the maximum of all 2nd moment running avg. till now
83 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
84 | # Use the max. for normalizing running avg. of gradient
85 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
86 | else:
87 | denom = exp_avg_sq.sqrt().add_(group['eps'])
88 |
89 | bias_correction1 = 1 - beta1 ** state['step']
90 | bias_correction2 = 1 - beta2 ** state['step']
91 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
92 |
93 | p.data.addcdiv_(-step_size, exp_avg, denom)
94 |
95 | return loss
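96 |
97 | # Note (added comment): unlike torch.optim.Adam, step() here expects `grads`, a list of
98 | # per-parameter numpy arrays (one entry per element of group['params']) shipped from the
99 | # workers; each is converted with torch.from_numpy above before the usual Adam update.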
--------------------------------------------------------------------------------
/src/optim/sgd.py:
--------------------------------------------------------------------------------
1 | '''
2 | modified version of SGD optimizer
3 | by Hongyi Wang
4 | '''
5 | import sys
6 |
7 | import torch
8 | from torch.optim import Optimizer
9 |
10 |
11 | class SGD(Optimizer):
12 | r"""Implements stochastic gradient descent (optionally with momentum).
13 | Nesterov momentum is based on the formula from
14 | `On the importance of initialization and momentum in deep learning`__.
15 | Args:
16 | params (iterable): iterable of parameters to optimize or dicts defining
17 | parameter groups
18 | lr (float): learning rate
19 | momentum (float, optional): momentum factor (default: 0)
20 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
21 | dampening (float, optional): dampening for momentum (default: 0)
22 | nesterov (bool, optional): enables Nesterov momentum (default: False)
23 | Example:
24 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
25 | >>> optimizer.zero_grad()
26 | >>> loss_fn(model(input), target).backward()
27 | >>> optimizer.step()
28 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
29 | .. note::
30 | The implementation of SGD with Momentum/Nesterov subtly differs from
31 | Sutskever et. al. and implementations in some other frameworks.
32 | Considering the specific case of Momentum, the update can be written as
33 | .. math::
34 | v = \rho * v + g \\
35 | p = p - lr * v
36 | where p, g, v and :math:`\rho` denote the parameters, gradient,
37 | velocity, and momentum respectively.
38 | This is in contrast to Sutskever et. al. and
39 | other frameworks which employ an update of the form
40 | .. math::
41 | v = \rho * v + lr * g \\
42 | p = p - v
43 | The Nesterov version is analogously modified.
44 | """
45 |
46 | def __init__(self, params, lr=0.1, momentum=0, dampening=0,
47 | weight_decay=0, nesterov=False):
48 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
49 | weight_decay=weight_decay, nesterov=nesterov)
50 | if nesterov and (momentum <= 0 or dampening != 0):
51 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
52 | super(SGD, self).__init__(params, defaults)
53 |
54 | def __setstate__(self, state):
55 | super(SGD, self).__setstate__(state)
56 | for group in self.param_groups:
57 | group.setdefault('nesterov', False)
58 |
59 | def step(self, grads, closure=None):
60 | """Performs a single optimization step.
61 | Arguments:
62 | closure (callable, optional): A closure that reevaluates the model
63 | and returns the loss.
64 | """
65 | loss = None
66 | if closure is not None:
67 | loss = closure()
68 |
69 | for group in self.param_groups:
70 | weight_decay = group['weight_decay']
71 | momentum = group['momentum']
72 | dampening = group['dampening']
73 | nesterov = group['nesterov']
74 |
75 | for i,p in enumerate(group['params']):
76 | d_p = grads[i]
77 | if weight_decay != 0:
78 | d_p.add_(weight_decay, p.data)
79 | if momentum != 0:
80 | param_state = self.state[p]
81 | if 'momentum_buffer' not in param_state:
82 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
83 | buf.mul_(momentum).add_(d_p)
84 | else:
85 | buf = param_state['momentum_buffer']
86 | buf.mul_(momentum).add_(1 - dampening, d_p)
87 | if nesterov:
88 | d_p = d_p.add(momentum, buf)
89 | else:
90 | d_p = buf
91 | p.data.add_(-group['lr'], d_p)
92 | return loss
--------------------------------------------------------------------------------
/src/run_pytorch_dist.sh:
--------------------------------------------------------------------------------
1 | NODE_RANK="$1"
2 | NNODE="$2"
3 | MASTER_IP="$3"
4 | SRC_DIR=${HOME}/ps_real_pytorch/src
5 |
6 | echo ${MASTER_IP}
7 | sudo /home/ubuntu/anaconda3/bin/python -m torch.distributed.launch \
8 | --nproc_per_node=1 \
9 | --nnodes=${NNODE} --node_rank=${NODE_RANK} --master_addr="${MASTER_IP}" --master_port=1234 \
10 | ${SRC_DIR}/distributed_nn.py \
11 | --lr=0.1 \
12 | --momentum=0.9 \
13 | --max-steps=100000 \
14 | --epochs=100 \
15 | --network=ResNet18 \
16 | --dataset=Cifar10 \
17 | --batch-size=64 \
18 | --comm-type=Bcast \
19 | --num-aggregate=2 \
20 | --mode=normal \
21 | --eval-freq=2000 \
22 | --gather-type=gather \
23 | --compress-grad=compress \
24 | --enable-gpu= \
25 | --train-dir=/home/ubuntu > out_node_${NODE_RANK} 2>&1
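26 |
27 | # Usage sketch (added comment): bash ./run_pytorch_dist.sh <NODE_RANK> <NNODE> <MASTER_IP>,
28 | # run once per machine with a unique NODE_RANK (typically 0 on the parameter-server node,
29 | # which is also the node whose IP is passed as MASTER_IP).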
--------------------------------------------------------------------------------
/src/run_pytorch_single.sh:
--------------------------------------------------------------------------------
1 | python -m torch.distributed.launch \
2 | --nproc_per_node=3 \
3 | distributed_nn.py \
4 | --lr=0.01 \
5 | --momentum=0.9 \
6 | --max-steps=100000 \
7 | --epochs=100 \
8 | --network=LeNet \
9 | --dataset=MNIST \
10 | --batch-size=128 \
11 | --comm-type=Bcast \
12 | --num-aggregate=5 \
13 | --mode=normal \
14 | --eval-freq=20 \
15 | --gather-type=gather \
16 | --compress-grad=compress \
17 | --enable-gpu= \
18 | --train-dir=/home/ubuntu
--------------------------------------------------------------------------------
/src/sync_replicas_master_nn.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import time
3 | import copy
4 | from sys import getsizeof
5 | import logging
6 | from functools import reduce
7 |
8 | import numpy as np
9 |
10 | from nn_ops import NN_Trainer
11 | from util import *
12 | from optim.adam import Adam
13 | from optim.sgd import SGD
14 |
15 | import torch
16 | import torch.distributed as dist
17 |
18 | STEP_START_ = 1
19 |
20 | logging.basicConfig()
21 | logger = logging.getLogger()
22 | logger.setLevel(logging.INFO)
23 |
24 | def update_params_dist_version(param, avg_grad, learning_rate):
25 | '''
26 | update the network layer by layer
27 | '''
28 | assert param.shape == avg_grad.shape
29 | param -= learning_rate * avg_grad
30 | return param
31 |
32 |
33 | def accuracy(output, target, topk=(1,)):
34 | """Computes the precision@k for the specified values of k"""
35 | maxk = max(topk)
36 | batch_size = target.size(0)
37 |
38 | _, pred = output.topk(maxk, 1, True, True)
39 | pred = pred.t()
40 | correct = pred.eq(target.view(1, -1).expand_as(pred))
41 |
42 | res = []
43 | for k in topk:
44 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
45 | res.append(correct_k.mul_(100.0 / batch_size))
46 | return res
47 |
48 |
49 | class GradientAccumulator(object):
50 | '''a simple class to implement gradient aggregator like the `Conditional Accumulators` in tensorflow'''
51 | def __init__(self, module, world_size, mode='None'):
52 | super(GradientAccumulator, self).__init__()
53 | # we will update this counter dynamically during the training process
54 | # the length of this counter should be number of fc layers in the network
55 | # we used list to contain gradients of layers
56 | self.gradient_aggregate_counter = []
57 | self.model_index_range = []
58 | self.gradient_aggregator = []
59 | self._mode = mode
60 |
61 | for param_idx, param in enumerate(module.parameters()):
62 | tmp_aggregator = []
63 | for worker_idx in range(world_size):
64 | tmp_aggregator.append(torch.zeros(param.size()))
65 | # initialize the gradient aggragator
66 | self.gradient_aggregator.append(tmp_aggregator)
67 | self.gradient_aggregate_counter.append(0)
68 | self.model_index_range.append(param_idx)
69 |
70 | def meset_everything(self):
71 | self._meset_grad_counter()
72 | self._meset_grad_aggregator()
73 |
74 | def _meset_grad_counter(self):
75 | self.gradient_aggregate_counter = [0 for _ in self.gradient_aggregate_counter]
76 |
77 | def _meset_grad_aggregator(self):
78 | '''
79 | reset the buffers in grad accumulator, not sure if this is necessary
80 | '''
81 | if self._mode == 'compress':
82 | pass
83 | else:
84 | for i, tmp_aggregator in enumerate(self.gradient_aggregator):
85 | for j, buf in enumerate(tmp_aggregator):
86 | self.gradient_aggregator[i][j] = np.zeros(self.gradient_aggregator[i][j].shape)
87 |
88 |
89 | class SyncReplicasMaster_NN(NN_Trainer):
90 | def __init__(self, **kwargs):
91 | super(NN_Trainer, self).__init__()
92 | '''master node here, no rank needed since the rank will always be 0 for master node'''
93 | self.world_size = kwargs['world_size']
94 | self.cur_step = STEP_START_
95 | self.lr = kwargs['learning_rate']
96 | self.momentum = kwargs['momentum']
97 | self.network_config = kwargs['network']
98 | self.comm_type = kwargs['comm_method']
99 | self._timeout_threshold = kwargs['timeout_threshold']
100 |
101 | self._num_workers = self.world_size - 1
102 | # used to aggregate tmp gradients, the length is the same as # of fc layer
103 | self._grad_aggregate_buffer = []
104 | self._model_shapes = []
105 | self._first_grad_received = False
106 | self._eval_freq = kwargs['eval_freq']
107 | self._train_dir = kwargs['train_dir']
108 | self._expected_grad_to_recv = kwargs['kill_threshold']
109 | self._max_steps = kwargs['max_steps']
110 | self._compress_grad = kwargs['compress_grad']
111 | self._gather_type = kwargs['gather_type']
112 | self._device = kwargs['device']
113 |
114 | ############ will be deprecated soon #############################
115 | self._eval_batch_size = 1000
116 |
117 | def build_model(self, num_classes=10):
118 | self.network = build_model(self.network_config, num_classes)
119 | self.optimizer = SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum)
120 | # assign a gradient accumulator to collect gradients from workers
121 | self.grad_accumulator = GradientAccumulator(self.network, self.world_size, self._compress_grad)
122 | self.init_model_shapes()
123 | #self.network.to(self._device)
124 | self.network.to(torch.device("cpu"))
125 |
126 | def start(self):
127 | for i in range(1, self._max_steps+1):
128 | # switch back to training mode
129 | self.network.train()
130 | self._first_grad_received = False
131 | enough_gradients_received = False
132 |
133 | logger.info("Master node is entering step: {}".format(i))
134 |
135 | self._bcast_weight()
136 |
137 | self._recv_grads()
138 |
139 | self._model_update()
140 |
141 | self.cur_step += 1
142 |
143 | def init_model_shapes(self):
144 | for param_idx, param in enumerate(self.network.parameters()):
145 | self._model_shapes.append(param.size())
146 | self._grad_aggregate_buffer.append(np.zeros(param.size()))
147 |
148 | def _model_update(self):
149 | # gradient shipped from workers are averaged and update the model
150 | self._grad_aggregate_buffer = [x / self._num_workers for x in self._grad_aggregate_buffer]
151 | self.optimizer.step(grads=self._grad_aggregate_buffer)
152 |
153 | def _bcast_weight(self):
154 | for layer_idx, layer in enumerate(self.network.parameters()):
155 | layer_weight = layer.detach()
156 | dist.broadcast(layer_weight, src=0)
157 |
158 | def aggregate_gradient(self, layer_idx, gradient):
159 | self._grad_aggregate_buffer[layer_idx] = reduce((lambda x, y: x + y), gradient[1:])
160 |
161 | def _recv_grads(self):
162 | for layer_idx, layer in enumerate(self.network.parameters()):
163 | dummpy_grad = self.grad_accumulator.gradient_aggregator[layer_idx][0]
164 | dist.gather(dummpy_grad, self.grad_accumulator.gradient_aggregator[layer_idx], dst=0)
165 | self.aggregate_gradient(layer_idx=layer_idx, gradient=self.grad_accumulator.gradient_aggregator[layer_idx])
166 |
167 | def _generate_model_path(self):
168 | return self._train_dir+"model_step_"+str(self.cur_step)
169 |
170 | def _save_model(self, file_path):
171 | with open(file_path, "wb") as f_:
172 | torch.save(self.network.state_dict(), f_)
173 | return
174 |
175 | def _evaluate_model(self, validation_loader):
176 | self.network.eval()
177 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0
178 | # which indicate an epoch based validation is done
179 | while validation_loader.dataset.epochs_completed <= self._epoch_counter:
180 | eval_image_batch, eval_label_batch = validation_loader.next_batch(batch_size=self._eval_batch_size)
181 | X_batch, y_batch = Variable(eval_image_batch.float()), Variable(eval_label_batch.long())
182 | output = self.network(X_batch)
183 | prec1_tmp, prec5_tmp = accuracy(output.detach(), eval_label_batch.long(), topk=(1, 5))
184 | prec1_counter_ += prec1_tmp
185 | prec5_counter_ += prec5_tmp
186 | batch_counter_ += 1
187 | prec1 = prec1_counter_ / batch_counter_
188 | prec5 = prec5_counter_ / batch_counter_
189 | self._epoch_counter = validation_loader.dataset.epochs_completed
190 | logger.info('Testset Performance: Cur Step:{} Prec@1: {} Prec@5: {}'.format(self.cur_step, prec1.numpy()[0], prec5.numpy()[0]))
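191 |
192 | # Protocol note (added comment, hedged): _bcast_weight/_recv_grads above assume that every
193 | # worker joins the matching collectives each step, i.e. calls dist.broadcast(param, src=0)
194 | # per layer to receive the global weights and dist.gather(grad, dst=0) per layer to ship
195 | # its gradient back to rank 0 (the master).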
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets, transforms
2 | import torch.utils.data; import torch.nn.functional as F; from torch.autograd import Variable  # explicit imports for the DataLoader, F.pad and Variable usages below
3 | from model_ops.lenet import LeNet
4 | from model_ops.resnet import *
5 | from model_ops.vgg import *
6 |
7 | def build_model(model_name, num_classes):
8 | # build network
9 | if model_name == "LeNet":
10 | return LeNet()
11 | elif model_name == "ResNet18":
12 | return ResNet18(num_classes)
13 | elif model_name == "ResNet34":
14 | return ResNet34()
15 | elif model_name == "ResNet50":
16 | return ResNet50()
17 | elif model_name == "VGG11":
18 | return vgg11_bn(num_classes)
19 |
20 | def prepare_data(args):
21 | # load training and test set here:
22 | if args.dataset == "MNIST":
23 | training_set = datasets.MNIST('./mnist_data', train=True, download=True,
24 | transform=transforms.Compose([
25 | transforms.ToTensor(),
26 | transforms.Normalize((0.1307,), (0.3081,))]))
27 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size, shuffle=True)
28 | test_loader = torch.utils.data.DataLoader(
29 | datasets.MNIST('./mnist_data', train=False, transform=transforms.Compose([
30 | transforms.ToTensor(),
31 | transforms.Normalize((0.1307,), (0.3081,))
32 | ])), batch_size=args.test_batch_size, shuffle=True)
33 | elif args.dataset == "Cifar10":
34 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
35 | std=[x/255.0 for x in [63.0, 62.1, 66.7]])
36 | transform_train = transforms.Compose([
37 | transforms.ToTensor(),
38 | transforms.Lambda(lambda x: F.pad(
39 | Variable(x.unsqueeze(0), requires_grad=False),
40 | (4,4,4,4),mode='reflect').data.squeeze()),
41 | transforms.ToPILImage(),
42 | transforms.RandomCrop(32),
43 | transforms.RandomHorizontalFlip(),
44 | transforms.ToTensor(),
45 | normalize,
46 | ])
47 | # data prep for test set
48 | transform_test = transforms.Compose([
49 | transforms.ToTensor(),
50 | normalize])
51 | # load training and test set here:
52 | training_set = datasets.CIFAR10(root='./cifar10_data', train=True,
53 | download=True, transform=transform_train)
54 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size,
55 | shuffle=True)
56 | testset = datasets.CIFAR10(root='./cifar10_data', train=False,
57 | download=True, transform=transform_test)
58 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
59 | shuffle=False)
60 | elif args.dataset == 'Cifar100':
61 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
62 | std=[x/255.0 for x in [63.0, 62.1, 66.7]])
63 | transform_train = transforms.Compose([
64 | transforms.ToTensor(),
65 | transforms.Lambda(lambda x: F.pad(
66 | Variable(x.unsqueeze(0), requires_grad=False),
67 | (4,4,4,4),mode='reflect').data.squeeze()),
68 | transforms.ToPILImage(),
69 | transforms.RandomCrop(32),
70 | transforms.RandomHorizontalFlip(),
71 | transforms.ToTensor(),
72 | normalize,
73 | ])
74 | # data prep for test set
75 | transform_test = transforms.Compose([
76 | transforms.ToTensor(),
77 | normalize])
78 | # load training and test set here:
79 | training_set = datasets.CIFAR100(root='./cifar100_data', train=True,
80 | download=True, transform=transform_train)
81 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size,
82 | shuffle=True)
83 | testset = datasets.CIFAR100(root='./cifar100_data', train=False,
84 | download=True, transform=transform_test)
85 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
86 | shuffle=False)
87 | # SVHN dataset
88 | elif args.dataset == 'SVHN':
89 | training_set = datasets.SVHN('./svhn_data', split='train', transform=transforms.Compose([
90 | transforms.RandomCrop(32, padding=4),
91 | transforms.RandomHorizontalFlip(),
92 | transforms.ToTensor(),
93 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
94 | ]))
95 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=128,
96 | shuffle=True)
97 | transform_test = transforms.Compose([
98 | transforms.ToTensor(),
99 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
100 | ])
101 | testset = datasets.SVHN(root='./svhn_data', split='test',
102 | download=True, transform=transform_test)
103 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
104 | shuffle=False)
105 | return train_loader, test_loader
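106 |
107 | # Usage note (added comment): `args` must provide at least `dataset`, `batch_size` and
108 | # `test_batch_size`; for `dataset` values other than the four handled above, the loaders
109 | # are never assigned and the return statement would raise an UnboundLocalError.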
--------------------------------------------------------------------------------
/tools/README.md:
--------------------------------------------------------------------------------
1 | # AWS EC2 AMI for distributed PyTorch:
2 | `ami-ebc83c93` (with pytorch and torchvision installed, configured with CUDA 8.0 and cuDNN 7)
3 |
4 | # Enable `MPI` backend in pytorch distributed
5 | The first thing you need to do is to remove your current version of `pytorch` and build it from source (we may not need this step in the future, but for now `MPI` is not enabled automatically in the prebuilt binaries).
6 |
7 | To build pytorch from source, you can follow the guidance here (https://github.com/pytorch/pytorch#from-source). But there are a few things you should be careful about.
8 |
9 | 1. make sure you're in your `conda env` when you run `python setup.py install`, by running
10 | ```
11 | source /home/user_name/anaconda[2 or 3]/bin/activate ~/anaconda[2 or 3]
12 | ```
13 | otherwise, pytorch will be installed into your system library directory rather than your conda environment's library directory.
14 |
15 | 2. make sure you have CUDA (version >= 7.5) and cuDNN (version >= 7.0) installed. A quick and easy way to do this is described in this GitHub repo (https://github.com/hwang595/distributed-MXNet). Although the version specified there is `CUDA 7.5`, that method will actually install `CUDA 8.0`, which is fine for our purposes.
16 |
17 | To check whether your `MPI` backend is enabled, just run the following (a fuller sanity-check sketch is given at the end of this note):
18 | ```
19 | import torch
20 | torch.distributed.init_process_group(backend='mpi')
21 | ```
22 |
23 | After this, we also need to build `torchvision` from source by running the following commands:
24 | ```
25 | git clone https://github.com/pytorch/vision.git
26 | source anaconda[2 or 3]/bin/activate ~/anaconda[2 or 3]
27 | cd vision && python setup.py install
28 | ```
29 |
30 | Note:
31 | To add your `.pem` file using `ssh-add` (e.g. `ssh-add ~/.ssh/<your_key>.pem`), run the following command first:
32 | ```
33 | eval `ssh-agent -s`
34 | ```
35 |
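36 | # Quick sanity check for the `MPI` backend
37 | The snippet below is a minimal sketch (not part of the original scripts) for verifying that your source build picked up `MPI`. It assumes you save it as `check_mpi.py`, that `mpirun` from your OpenMPI install is on `PATH`, and that your PyTorch version provides `torch.distributed.is_mpi_available`:
38 | ```
39 | import torch
40 | import torch.distributed as dist
41 |
42 | assert dist.is_mpi_available(), "PyTorch was not built with MPI support"
43 | dist.init_process_group(backend='mpi')  # rank/world size are taken from MPI itself
44 | # every rank contributes its own rank; after all_reduce (SUM) each rank
45 | # should see 0 + 1 + ... + (world_size - 1)
46 | t = torch.ones(1) * dist.get_rank()
47 | dist.all_reduce(t)
48 | print("rank {}/{} sees {}".format(dist.get_rank(), dist.get_world_size(), t.item()))
49 | ```
50 | Run it with, e.g., `mpirun -np 2 python check_mpi.py`; both ranks should print `1.0`.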
--------------------------------------------------------------------------------
/tools/conda_install.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # setup Anaconda env
3 | wget https://repo.continuum.io/archive/Anaconda2-5.1.0-Linux-x86_64.sh
4 | bash Anaconda2-5.1.0-Linux-x86_64.sh -b -p ~/anaconda
5 | rm Anaconda2-5.1.0-Linux-x86_64.sh
6 | echo 'export PATH="~/anaconda/bin:$PATH"' >> ~/.bashrc
7 |
8 | # Reload the shell configuration so that conda is on PATH
9 | source ~/.bashrc
--------------------------------------------------------------------------------
/tools/config:
--------------------------------------------------------------------------------
1 | Host *
2 | StrictHostKeyChecking no
3 |
--------------------------------------------------------------------------------
/tools/hosts:
--------------------------------------------------------------------------------
1 | 172.31.19.212 deeplearning-worker1
2 | 172.31.25.216 deeplearning-worker2
3 | 172.31.16.210 deeplearning-worker3
4 | 172.31.18.242 deeplearning-worker4
5 | 172.31.30.165 deeplearning-worker5
6 | 172.31.19.35 deeplearning-worker6
7 | 172.31.21.31 deeplearning-worker7
8 | 172.31.21.232 deeplearning-worker8
9 | 172.31.26.254 deeplearning-worker9
10 | 172.31.29.227 deeplearning-worker10
11 | 172.31.21.65 deeplearning-worker11
12 | 172.31.19.22 deeplearning-worker12
13 | 172.31.26.181 deeplearning-worker13
14 | 172.31.23.5 deeplearning-worker14
15 | 172.31.23.177 deeplearning-worker15
16 | 172.31.29.128 deeplearning-worker16
17 | 172.31.21.63 deeplearning-worker17
18 | 172.31.20.152 deeplearning-worker18
19 | 172.31.24.123 deeplearning-worker19
20 | 172.31.28.206 deeplearning-worker20
21 | 172.31.25.60 deeplearning-worker21
22 | 172.31.21.191 deeplearning-worker22
23 | 172.31.18.144 deeplearning-worker23
24 | 172.31.21.28 deeplearning-worker24
25 | 172.31.22.235 deeplearning-worker25
26 | 172.31.16.15 deeplearning-worker26
27 | 172.31.18.223 deeplearning-worker27
28 | 172.31.17.46 deeplearning-worker28
29 | 172.31.31.45 deeplearning-worker29
30 | 172.31.27.44 deeplearning-worker30
31 | 172.31.17.249 deeplearning-worker31
32 | 172.31.22.20 deeplearning-worker32
33 | 172.31.22.226 deeplearning-worker33
34 | 172.31.19.15 deeplearning-worker34
35 | 172.31.21.118 deeplearning-worker35
36 | 172.31.19.49 deeplearning-worker36
37 | 172.31.21.136 deeplearning-worker37
38 | 172.31.25.85 deeplearning-worker38
39 | 172.31.18.69 deeplearning-worker39
40 | 172.31.16.134 deeplearning-worker40
41 | 172.31.25.83 deeplearning-worker41
42 | 172.31.16.223 deeplearning-worker42
43 | 172.31.20.120 deeplearning-worker43
44 | 172.31.17.89 deeplearning-worker44
45 | 172.31.18.20 deeplearning-worker45
46 | 172.31.31.170 deeplearning-worker46
47 | 172.31.19.76 deeplearning-worker47
48 | 172.31.23.122 deeplearning-worker48
49 | 172.31.24.201 deeplearning-worker49
50 | 172.31.18.63 deeplearning-worker50
51 | 172.31.26.191 deeplearning-worker51
52 | 172.31.24.249 deeplearning-worker52
53 | 172.31.22.91 deeplearning-worker53
54 | 172.31.16.114 deeplearning-worker54
55 | 172.31.26.5 deeplearning-worker55
56 | 172.31.20.194 deeplearning-worker56
57 | 172.31.30.60 deeplearning-worker57
58 | 172.31.23.145 deeplearning-worker58
59 | 172.31.18.13 deeplearning-worker59
60 | 172.31.26.9 deeplearning-worker60
61 | 172.31.21.173 deeplearning-worker61
62 | 172.31.20.40 deeplearning-worker62
63 | 172.31.18.128 deeplearning-worker63
64 | 172.31.28.200 deeplearning-worker64
65 | 172.31.27.141 deeplearning-worker65
66 |
--------------------------------------------------------------------------------
/tools/hosts_address:
--------------------------------------------------------------------------------
1 | 172.31.19.212
2 | 172.31.25.216
3 | 172.31.16.210
4 | 172.31.18.242
5 | 172.31.30.165
6 | 172.31.19.35
7 | 172.31.21.31
8 | 172.31.21.232
9 | 172.31.26.254
10 | 172.31.29.227
11 | 172.31.21.65
12 | 172.31.19.22
13 | 172.31.26.181
14 | 172.31.23.5
15 | 172.31.23.177
16 | 172.31.29.128
17 | 172.31.21.63
18 | 172.31.20.152
19 | 172.31.24.123
20 | 172.31.28.206
21 | 172.31.25.60
22 | 172.31.21.191
23 | 172.31.18.144
24 | 172.31.21.28
25 | 172.31.22.235
26 | 172.31.16.15
27 | 172.31.18.223
28 | 172.31.17.46
29 | 172.31.31.45
30 | 172.31.27.44
31 | 172.31.17.249
32 | 172.31.22.20
33 | 172.31.22.226
34 | 172.31.19.15
35 | 172.31.21.118
36 | 172.31.19.49
37 | 172.31.21.136
38 | 172.31.25.85
39 | 172.31.18.69
40 | 172.31.16.134
41 | 172.31.25.83
42 | 172.31.16.223
43 | 172.31.20.120
44 | 172.31.17.89
45 | 172.31.18.20
46 | 172.31.31.170
47 | 172.31.19.76
48 | 172.31.23.122
49 | 172.31.24.201
50 | 172.31.18.63
51 | 172.31.26.191
52 | 172.31.24.249
53 | 172.31.22.91
54 | 172.31.16.114
55 | 172.31.26.5
56 | 172.31.20.194
57 | 172.31.30.60
58 | 172.31.23.145
59 | 172.31.18.13
60 | 172.31.26.9
61 | 172.31.21.173
62 | 172.31.20.40
63 | 172.31.18.128
64 | 172.31.28.200
65 | 172.31.27.141
66 |
--------------------------------------------------------------------------------
/tools/hosts_alias:
--------------------------------------------------------------------------------
1 | deeplearning-worker1
2 | deeplearning-worker2
3 | deeplearning-worker3
4 | deeplearning-worker4
5 | deeplearning-worker5
6 | deeplearning-worker6
7 | deeplearning-worker7
8 | deeplearning-worker8
9 | deeplearning-worker9
10 | deeplearning-worker10
11 | deeplearning-worker11
12 | deeplearning-worker12
13 | deeplearning-worker13
14 | deeplearning-worker14
15 | deeplearning-worker15
16 | deeplearning-worker16
17 | deeplearning-worker17
18 | deeplearning-worker18
19 | deeplearning-worker19
20 | deeplearning-worker20
21 | deeplearning-worker21
22 | deeplearning-worker22
23 | deeplearning-worker23
24 | deeplearning-worker24
25 | deeplearning-worker25
26 | deeplearning-worker26
27 | deeplearning-worker27
28 | deeplearning-worker28
29 | deeplearning-worker29
30 | deeplearning-worker30
31 | deeplearning-worker31
32 | deeplearning-worker32
33 | deeplearning-worker33
34 | deeplearning-worker34
35 | deeplearning-worker35
36 | deeplearning-worker36
37 | deeplearning-worker37
38 | deeplearning-worker38
39 | deeplearning-worker39
40 | deeplearning-worker40
41 | deeplearning-worker41
42 | deeplearning-worker42
43 | deeplearning-worker43
44 | deeplearning-worker44
45 | deeplearning-worker45
46 | deeplearning-worker46
47 | deeplearning-worker47
48 | deeplearning-worker48
49 | deeplearning-worker49
50 | deeplearning-worker50
51 | deeplearning-worker51
52 | deeplearning-worker52
53 | deeplearning-worker53
54 | deeplearning-worker54
55 | deeplearning-worker55
56 | deeplearning-worker56
57 | deeplearning-worker57
58 | deeplearning-worker58
59 | deeplearning-worker59
60 | deeplearning-worker60
61 | deeplearning-worker61
62 | deeplearning-worker62
63 | deeplearning-worker63
64 | deeplearning-worker64
65 | deeplearning-worker65
66 |
--------------------------------------------------------------------------------
/tools/install.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # append hosts
4 | sudo bash -c "cat hosts >> /etc/hosts"
5 | cp config ~/.ssh/
6 |
7 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`
8 |
9 | git config --global user.name hwang595
10 | git config --global user.email hongyiwang@cs.wisc.edu
11 |
12 | sudo apt-get update
13 | sudo apt-get install pdsh -y
14 | pdsh -R ssh -w deeplearning-worker[1-$DEEPLEARNING_WORKERS_COUNT] "sudo apt-get update; sudo apt-get install pdsh -y"
15 |
--------------------------------------------------------------------------------
/tools/killall.sh:
--------------------------------------------------------------------------------
1 | KEY_PEM_NAME=~/.ssh/HongyiScript2.pem
2 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`
3 |
4 | sudo bash -c "cat hosts >> /etc/hosts"
5 | cp config ~/.ssh/
6 |
7 | for i in $(seq 2 $DEEPLEARNING_WORKERS_COUNT);
8 | do
9 | ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'killall python'
10 | done
--------------------------------------------------------------------------------
/tools/local_script.sh:
--------------------------------------------------------------------------------
1 | KEY_PEM_DIR=/home/hwang/My_Code/AWS/HongyiScript.pem
2 | KEY_PEM_NAME=HongyiScript.pem
3 | PUB_IP_ADDR="$1"
4 | echo "Public address of master node: ${PUB_IP_ADDR}"
5 |
6 | ssh -o "StrictHostKeyChecking no" ubuntu@${PUB_IP_ADDR} exit  # pre-accept the host key without keeping an interactive session open
7 | scp -i ${KEY_PEM_DIR} ${KEY_PEM_DIR} ubuntu@${PUB_IP_ADDR}:~/.ssh
8 | scp -i ${KEY_PEM_DIR} hosts hosts_address config ubuntu@${PUB_IP_ADDR}:~/
9 | scp -i ${KEY_PEM_DIR} -r /home/hwang/My_Code/ps_real_pytorch ubuntu@${PUB_IP_ADDR}:~/
10 | ssh -i ${KEY_PEM_DIR} ubuntu@${PUB_IP_ADDR} 'cp ps_real_pytorch/tools/remote_script.sh ~/'
11 | ssh -i ${KEY_PEM_DIR} ubuntu@${PUB_IP_ADDR} 'cp hosts ps_real_pytorch/src/'
12 |
--------------------------------------------------------------------------------
/tools/openmpi_install.sh:
--------------------------------------------------------------------------------
1 | # configure, download, and install OpenMPI
2 | sudo apt-get update
3 | sudo apt-get -y install gcc g++ make
4 | sudo apt-get update
5 | wget https://www.open-mpi.org/software/ompi/v3.0/downloads/openmpi-3.0.1.tar.gz
6 | tar -xvf openmpi-*
7 | rm -f openmpi-3.0.1.tar.gz
8 | cd ~/openmpi-3.0.1
9 | ./configure --prefix="/home/$USER/.openmpi"
10 | make
11 | sudo make install
12 | export PATH="$PATH:/home/$USER/.openmpi/bin"
13 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/$USER/.openmpi/lib/"
--------------------------------------------------------------------------------
/tools/pre_run.sh:
--------------------------------------------------------------------------------
1 | conda update -y -n base conda
2 | conda install pytorch torchvision -y -c pytorch
3 | conda install -y -c anaconda python-blosc
4 | conda install -y -c anaconda mpi4py
--------------------------------------------------------------------------------
/tools/pytorch_ec2.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import threading
4 | import Queue
5 | import paramiko as pm
6 | import boto3
7 | import time
8 | import json
9 | import os
10 |
11 |
12 | class Cfg(dict):
13 |
14 | def __getitem__(self, item):
15 | item = dict.__getitem__(self, item)
16 | if type(item) == type([]):
17 | return [x % self if type(x) == type("") else x for x in item]
18 | if type(item) == type(""):
19 | return item % self
20 | return item
21 |
22 | cfg = Cfg({
23 | "name" : "Timeout", # Unique name for this specific configuration
24 | "key_name": "HongyiScript", # Necessary to ssh into created instances
25 | # Cluster topology
26 | "n_masters" : 1, # Should always be 1
27 | "n_workers" : 64,
28 | "num_replicas_to_aggregate" : "8",
29 | "method" : "spot",
30 | # Region specification
31 | "region" : "us-west-2",
32 | "availability_zone" : "us-west-2b",
33 | # Machine type - instance type configuration.
34 | "master_type" : "m5.2xlarge",
35 | "worker_type" : "m5.2xlarge",
36 | # please only use this AMI for pytorch
37 | "image_id": "ami-0d3c2f72d1499d30b",
38 | # Launch specifications
39 | "spot_price" : "0.15", # Has to be a string
40 | # SSH configuration
41 | "ssh_username" : "ubuntu", # For sshing. E.G: ssh ssh_username@hostname
42 | "path_to_keyfile" : "/home/hwang/My_Code/AWS/HongyiScript.pem",
43 |
44 | # NFS configuration
45 | # To set up these values, go to Services > ElasticFileSystem > Create new filesystem, and follow the directions.
46 | #"nfs_ip_address" : "172.31.3.173", # us-west-2c
47 | #"nfs_ip_address" : "172.31.35.0", # us-west-2a
48 | "nfs_ip_address" : "172.31.14.225", # us-west-2b
49 | "nfs_mount_point" : "/home/ubuntu/shared", # NFS base dir
50 | "base_out_dir" : "%(nfs_mount_point)s/%(name)s", # Master writes checkpoints to this directory. Outfiles are written to this directory.
51 | "setup_commands" :
52 | [
53 | # "sudo rm -rf %(base_out_dir)s",
54 | "mkdir %(base_out_dir)s",
55 | ],
56 | # Command specification
57 | # Master pre commands are run only by the master
58 | "master_pre_commands" :
59 | [
60 | "cd my_mxnet",
61 | "git fetch && git reset --hard origin/master",
62 | "cd cifar10",
63 | "ls",
64 | # "cd distributed_tensorflow/DistributedResNet",
65 | # "git fetch && git reset --hard origin/master",
66 | ],
67 | # Pre commands are run on every machine before the actual training.
68 | "pre_commands" :
69 | [
70 | "cd my_mxnet",
71 | "git fetch && git reset --hard origin/master",
72 | "cd cifar10",
73 | ],
74 | # Model configuration
75 | "batch_size" : "32",
76 | "max_steps" : "2000",
77 | "initial_learning_rate" : ".001",
78 | "learning_rate_decay_factor" : ".95",
79 | "num_epochs_per_decay" : "1.0",
80 | # Train command specifies how the ps/workers execute tensorflow.
81 | # PS_HOSTS - special string replaced with actual list of ps hosts.
82 | # TASK_ID - special string replaced with actual task index.
83 | # JOB_NAME - special string replaced with actual job name.
84 | # WORKER_HOSTS - special string replaced with actual list of worker hosts
85 | # ROLE_ID - special string replaced with machine's identity (E.G: master, worker0, worker1, ps, etc)
86 | # %(...)s - Inserts self referential string value.
87 | "train_commands" :
88 | [
89 | "echo ========= Start ==========="
90 | ],
91 | })
92 |
93 | def mxnet_ec2_run(argv, configuration):
94 | client = boto3.client("ec2", region_name=configuration["region"])
95 | ec2 = boto3.resource("ec2", region_name=configuration["region"])
96 |
97 | def sleep_a_bit():
98 | time.sleep(5)
99 |
100 | def summarize_instances(instances):
101 | instance_type_to_instance_map = {}
102 | for instance in sorted(instances, key=lambda x:x.id):
103 | typ = instance.instance_type
104 | if typ not in instance_type_to_instance_map:
105 | instance_type_to_instance_map[typ] = []
106 | instance_type_to_instance_map[typ].append(instance)
107 |
108 | for type in instance_type_to_instance_map:
109 | print("Type\t", type)
110 | for instance in instance_type_to_instance_map[type]:
111 | print("instance\t", instance, "\t", instance.public_ip_address)
112 | print()
113 |
114 | for k,v in instance_type_to_instance_map.items():
115 | print("%s - %d running" % (k, len(v)))
116 |
117 | return instance_type_to_instance_map
118 |
119 | def summarize_idle_instances(argv):
120 | print("Idle instances: (Idle = not running tensorflow)")
121 | summarize_instances(get_idle_instances())
122 |
123 | def summarize_running_instances(argv):
124 | print("Running instances: ")
125 | summarize_instances(ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}]))
126 |
127 | # Terminate all request.
128 | def terminate_all_requests():
129 | spot_requests = client.describe_spot_instance_requests()
130 | spot_request_ids = []
131 | for spot_request in spot_requests["SpotInstanceRequests"]:
132 | if spot_request["State"] != "cancelled" and spot_request["LaunchSpecification"]["KeyName"] == configuration["key_name"]:
133 | spot_request_id = spot_request["SpotInstanceRequestId"]
134 | spot_request_ids.append(spot_request_id)
135 |
136 | if len(spot_request_ids) != 0:
137 | print("Terminating spot requests: %s" % " ".join([str(x) for x in spot_request_ids]))
138 | client.cancel_spot_instance_requests(SpotInstanceRequestIds=spot_request_ids)
139 |
140 | # Wait until all are cancelled.
141 | # TODO: Use waiter class
142 | done = False
143 | while not done:
144 | print("Waiting for all spot requests to be terminated...")
145 | done = True
146 | spot_requests = client.describe_spot_instance_requests()
147 | states = [x["State"] for x in spot_requests["SpotInstanceRequests"] if x["LaunchSpecification"]["KeyName"] == configuration["key_name"]]
148 | for state in states:
149 | if state != "cancelled":
150 | done = False
151 | sleep_a_bit()
152 |
153 | # Terminate all instances in the configuration
154 | # Note: all_instances = ec2.instances.all() to get all intances
155 | def terminate_all_instances():
156 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}])
157 | all_instance_ids = [x.id for x in live_instances]
158 | print([x.id for x in live_instances])
159 | if len(all_instance_ids) != 0:
160 | print("Terminating instances: %s" % (" ".join([str(x) for x in all_instance_ids])))
161 | client.terminate_instances(InstanceIds=all_instance_ids)
162 |
163 | # Wait until all are terminated
164 | # TODO: Use waiter class
165 | done = False
166 | while not done:
167 | print("Waiting for all instances to be terminated...")
168 | done = True
169 | instances = ec2.instances.all()
170 | for instance in instances:
171 | if instance.state == "active":
172 | done = False
173 | sleep_a_bit()
174 |
175 | # Launch instances as specified in the configuration.
176 | def launch_instances():
177 | method = "spot"
178 | if "method" in configuration.keys():
179 | method = configuration["method"]
180 | worker_instance_type, worker_count = configuration["worker_type"], configuration["n_workers"]
181 | master_instance_type, master_count = configuration["master_type"], configuration["n_masters"]
182 | specs = [(worker_instance_type, worker_count),
183 | (master_instance_type, master_count)]
184 | for (instance_type, count) in specs:
185 | launch_specs = {"KeyName" : configuration["key_name"],
186 | "ImageId" : configuration["image_id"],
187 | "InstanceType" : instance_type,
188 | "Placement" : {"AvailabilityZone":configuration["availability_zone"]},
189 | "SecurityGroups": ["default"]}
190 | if method == "spot":
191 | # TODO: EBS optimized? (Will incur extra hourly cost)
192 | client.request_spot_instances(InstanceCount=count,
193 | LaunchSpecification=launch_specs,
194 | SpotPrice=configuration["spot_price"])
195 | elif method == "reserved":
196 | client.run_instances(ImageId=launch_specs["ImageId"],
197 | MinCount=count,
198 | MaxCount=count,
199 | KeyName=launch_specs["KeyName"],
200 | InstanceType=launch_specs["InstanceType"],
201 | Placement=launch_specs["Placement"],
202 | SecurityGroups=launch_specs["SecurityGroups"])
203 | else:
204 | print("Unknown method: %s" % method)
205 | sys.exit(-1)
206 |
207 |
208 | # TODO: use waiter class?
209 | def wait_until_running_instances_initialized():
210 | done = False
211 | while not done:
212 | print("Waiting for instances to be initialized...")
213 | done = True
214 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}])
215 | ids = [x.id for x in live_instances]
216 | resps_list = [client.describe_instance_status(InstanceIds=ids[i:i+50]) for i in range(0, len(ids), 50)]
217 | statuses = []
218 | for resp in resps_list:
219 | statuses += [x["InstanceStatus"]["Status"] for x in resp["InstanceStatuses"]]
220 | #resps = client.describe_instance_status(InstanceIds=ids)
221 | #for resp in resps["InstanceStatuses"]:
222 | # if resp["InstanceStatus"]["Status"] != "ok":
223 | # done = False
224 | print(statuses)
225 | done = statuses.count("ok") == len(statuses)
226 | if len(ids) <= 0:
227 | done = False
228 | sleep_a_bit()
229 |
230 | # Waits until status requests are all fulfilled.
231 | # Prints out status of request in between time waits.
232 | # TODO: Use waiter class
233 | def wait_until_instance_request_status_fulfilled():
234 | requests_fulfilled = False
235 | n_active_or_open = 0
236 | while not requests_fulfilled or n_active_or_open == 0:
237 | requests_fulfilled = True
238 | statuses = client.describe_spot_instance_requests()
239 | print("InstanceRequestId, InstanceType, SpotPrice, State - Status : StatusMessage")
240 | print("-------------------------------------------")
241 | n_active_or_open = 0
242 | for instance_request in statuses["SpotInstanceRequests"]:
243 | if instance_request["LaunchSpecification"]["KeyName"] != configuration["key_name"]:
244 | continue
245 | sid = instance_request["SpotInstanceRequestId"]
246 | machine_type = instance_request["LaunchSpecification"]["InstanceType"]
247 | price = instance_request["SpotPrice"]
248 | state = instance_request["State"]
249 | status, status_string = instance_request["Status"]["Code"], instance_request["Status"]["Message"]
250 | if state == "active" or state == "open":
251 | n_active_or_open += 1
252 | print("%s, %s, %s, %s - %s : %s" % (sid, machine_type, price, state, status, status_string))
253 | if state != "active":
254 | requests_fulfilled = False
255 | print("-------------------------------------------")
256 | sleep_a_bit()
257 |
258 | # Create a client to the instance
259 | def connect_client(instance):
260 | client = pm.SSHClient()
261 | host = instance.public_ip_address
262 | client.set_missing_host_key_policy(pm.AutoAddPolicy())
263 | client.connect(host, username=configuration["ssh_username"], key_filename=configuration["path_to_keyfile"])
264 | return client
265 |
266 | # Takes a list of commands (E.G: ["ls", "cd models"]
267 | # and executes command on instance, returning the stdout.
268 | # Executes everything in one session, and returns all output from all the commands.
269 | def run_ssh_commands(instance, commands):
270 | done = False
271 | while not done:
272 | try:
273 | print("Instance %s, Running ssh commands:\n%s\n" % (instance.public_ip_address, "\n".join(commands)))
274 |
275 | # Always need to exit
276 | commands.append("exit")
277 |
278 | # Set up ssh client
279 | client = connect_client(instance)
280 |
281 | # Clear the stdout from ssh'ing in
282 | # For each command perform command and read stdout
283 | commandstring = "\n".join(commands)
284 | stdin, stdout, stderr = client.exec_command(commandstring)
285 | output = stdout.read()
286 |
287 | # Close down
288 | stdout.close()
289 | stdin.close()
290 | client.close()
291 | done = True
292 | except Exception as e:
293 | done = False
294 | print(e.message)
295 | return output
296 |
297 | def run_ssh_commands_parallel(instance, commands, q):
298 | output = run_ssh_commands(instance, commands)
299 | q.put((instance, output))
300 |
301 | # Checks whether instance is idle. Assumed that instance is up and running.
302 | # An instance is idle if it is not running tensorflow...
303 | # Returns a tuple of (instance, is_instance_idle). We return a tuple for multithreading ease.
304 | def is_instance_idle(q, instance):
305 | python_processes = run_ssh_commands(instance, ["ps aux | grep python"])
306 | q.put((instance, not "ps_hosts" in python_processes and not "ps_workers" in python_processes))
307 |
308 | # Idle instances are running instances that are not running the inception model.
309 | # We check whether an instance is running the inception model by ssh'ing into a running machine,
310 | # and checking whether python is running.
311 | def get_idle_instances():
312 | live_instances = ec2.instances.filter(
313 | Filters=[{'Name': 'instance-state-name', 'Values': ['running']},
314 | {'Name': 'key-name', 'Values': [configuration["key_name"]]}])
315 | threads = []
316 | q = Queue.Queue()
317 |
318 | # Run commands in parallel, writing to the queue
319 | for instance in live_instances:
320 | t = threading.Thread(target=is_instance_idle, args=(q, instance))
321 | t.daemon = True
322 | t.start()
323 | threads.append(t)
324 |
325 | # Wait for threads to finish
326 | for thread in threads:
327 | thread.join()
328 |
329 | # Collect idle instances
330 | idle_instances = []
331 | while not q.empty():
332 | instance, is_idle = q.get()
333 | if is_idle:
334 | idle_instances.append(instance)
335 |
336 | return idle_instances
337 |
338 | def get_instance_requirements():
339 | # Get the requirements given the specification of worker/master/etc machine types
340 | worker_instance_type, worker_count = configuration["worker_type"], configuration["n_workers"]
341 | master_instance_type, master_count = configuration["master_type"], configuration["n_masters"]
342 | specs = [(worker_instance_type, worker_count),
343 | (master_instance_type, master_count)]
344 | reqs = {}
345 | for (type_needed, count_needed) in specs:
346 | if type_needed not in reqs:
347 | reqs[type_needed] = 0
348 | reqs[type_needed] += count_needed
349 | return reqs
350 |
351 | # Returns whether the idle instances satisfy the specs of the configuration.
352 | def check_idle_instances_satisfy_configuration():
353 | # Create a map of instance types to instances of that type
354 | idle_instances = get_idle_instances()
355 | instance_type_to_instance_map = summarize_instances(idle_instances)
356 |
357 | # Get instance requirements
358 | reqs = get_instance_requirements()
359 |
360 | # Check the requirements are satisfied.
361 | print("Checking whether # of running instances satisfies the configuration...")
362 | for k,v in instance_type_to_instance_map.items():
363 | n_required = 0 if k not in reqs else reqs[k]
364 | print("%s - %d running vs %d required" % (k,len(v),n_required))
365 | if len(v) < n_required:
366 | print("Error, running instances failed to satisfy configuration requirements")
367 | sys.exit(0)
368 | print("Success, running instances satisfy configuration requirement")
369 |
370 | def shut_everything_down(argv):
371 | terminate_all_requests()
372 | terminate_all_instances()
373 |
374 | def run_mxnet_grid_search(argv, port=1334):
375 | # Check idle instances satisfy configs
376 | check_idle_instances_satisfy_configuration()
377 |
378 | # Get idle instances
379 | idle_instances = get_idle_instances()
380 |
381 | # Assign instances for worker/ps/etc
382 | instance_type_to_instance_map = summarize_instances(idle_instances)
383 | specs = {
384 | "master" : {"instance_type" : configuration["master_type"],
385 | "n_required" : configuration["n_masters"]},
386 | "worker" : {"instance_type" : configuration["worker_type"],
387 | "n_required" : configuration["n_workers"]}
388 | }
389 | machine_assignments = {
390 | "master" : [],
391 | "worker" : []
392 | }
393 | for role, requirement in sorted(specs.items(), key=lambda x:x[0]):
394 | instance_type_for_role = requirement["instance_type"]
395 | n_instances_needed = requirement["n_required"]
396 | instances_to_assign, rest = instance_type_to_instance_map[instance_type_for_role][:n_instances_needed], instance_type_to_instance_map[instance_type_for_role][n_instances_needed:]
397 | instance_type_to_instance_map[instance_type_for_role] = rest
398 | machine_assignments[role] = instances_to_assign
399 |
400 | # Construct the host strings necessary for running the inception command.
401 | # Note we use private ip addresses to avoid EC2 transfer costs.
402 | worker_host_string = ",".join([x.private_ip_address+":"+str(port) for x in machine_assignments["master"] + machine_assignments["worker"]])
403 |
404 | # Create a map of command&machine assignments
405 | command_machine_assignments = {}
406 | setup_machine_assignments = {}
407 |
408 | # Construct the master command
409 | command_machine_assignments["master"] = {"instance" : machine_assignments["master"][0], "commands" : list(configuration["master_pre_commands"])}
410 | # setup_machine_assignments["master"] = {"instance" : machine_assignments["master"][0], "commands" : list(configuration["setup_commands"])}
411 | for command_string in configuration["train_commands"]:
412 | command_machine_assignments["master"]["commands"].append(command_string.replace("JOB_NAME", "worker").replace("WORKER_HOSTS", worker_host_string).replace("ROLE_ID", "master"))
413 | print(command_machine_assignments)
414 |
415 | # Construct the worker commands
416 | for worker_id, instance in enumerate(machine_assignments["worker"]):
417 | name = "worker_%d" % worker_id
418 | command_machine_assignments[name] = {"instance" : instance,
419 | "commands" : list(configuration["pre_commands"])}
420 | for command_string in configuration["train_commands"]:
421 | command_machine_assignments[name]["commands"].append(command_string.replace("TASK_ID", "%d" % (worker_id+1)).replace("JOB_NAME", "worker").replace("WORKER_HOSTS", worker_host_string).replace("ROLE_ID", name))
422 |
423 | print(command_machine_assignments)
424 |
425 | # Run the commands via ssh in parallel
426 | threads = []
427 | q = Queue.Queue()
428 |
429 | for name, command_and_machine in setup_machine_assignments.items():
430 | instance = command_and_machine["instance"]
431 | commands = command_and_machine["commands"]
432 | print("-----------------------")
433 | print("Pre Command: %s\n" % " ".join(commands))
434 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q))
435 | t.start()
436 | threads.append(t)
437 |
438 | # Wait until commands are all finished
439 | for t in threads:
440 | t.join()
441 |
442 | threads = []
443 | q = Queue.Queue()
444 |
445 | running_process = 0
446 | for name, command_and_machine in command_machine_assignments.items():
447 | instance = command_and_machine["instance"]
448 | neo_commands = "python train_cifar10.py --running_mode=grid_search --gpus=0 "\
449 | "--running_process={} "\
450 | "--batch-size={} "\
451 | "--dir={}/grid_search> {}/grid_search/batch_size_{}/running_{}_process.out 2>&1 &".format(
452 | running_process,
453 | configuration['batch_size'],
454 | configuration['nfs_mount_point'],
455 | configuration['nfs_mount_point'],
456 | configuration['batch_size'],
457 | running_process)
458 |
459 | commands = command_and_machine["commands"]
460 | commands.append('mkdir {}/grid_search'.format(configuration['nfs_mount_point']))
461 | commands.append('mkdir {}/grid_search/batch_size_{}'.format(
462 | configuration['nfs_mount_point'],
463 | configuration['batch_size']))
464 | commands.append(neo_commands)
465 |
466 | print("-----------------------")
467 | print("Command: %s\n" % " ".join(commands))
468 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q))
469 | t.start()
470 | threads.append(t)
471 | running_process += 1
472 |
473 | # Wait until commands are all finished
474 | for t in threads:
475 | t.join()
476 |
477 | # Print the output
478 | while not q.empty():
479 | instance, output = q.get()
480 | print(instance.public_ip_address)
481 | print(output)
482 |
483 | # Debug print
484 | instances = []
485 | print("\n--------------------------------------------------\n")
486 | print("Machine assignments:")
487 | print("------------------------")
488 | for name, command_and_machine in command_machine_assignments.items():
489 | instance = command_and_machine["instance"]
490 | instances.append(instance)
491 | commands = command_and_machine["commands"]
492 | ssh_command = "ssh -i %s %s@%s" % (configuration["path_to_keyfile"], configuration["ssh_username"], instance.public_ip_address)
493 | print("%s - %s" % (name, instance.instance_id))
494 | print("To ssh: %s" % ssh_command)
495 | print("------------------------")
496 |
497 |         # Print out the list of instance ids (which will be useful in selectively stopping
498 |         # training on given instances).
499 | instance_cluster_string = ",".join([x.instance_id for x in instances])
500 | print("\nInstances cluster string: %s" % instance_cluster_string)
501 |
502 |         # Bundle the configuration, machine assignments, and cluster string to return to the caller
503 | cluster_save = {
504 | "configuration" : configuration,
505 | "name" : configuration["name"],
506 | "command_machine_assignments" : command_machine_assignments,
507 | "cluster_string" : instance_cluster_string
508 | }
509 |
510 | return cluster_save
511 |
512 |
513 | def run_mxnet_loss_curve(argv, port=1334):
514 | # Check idle instances satisfy configs
515 | check_idle_instances_satisfy_configuration()
516 |
517 | # Get idle instances
518 | idle_instances = get_idle_instances()
519 |
520 | # Assign instances for worker/ps/etc
521 | instance_type_to_instance_map = summarize_instances(idle_instances)
522 | specs = {
523 | "master" : {"instance_type" : configuration["master_type"],
524 | "n_required" : configuration["n_masters"]},
525 | "worker" : {"instance_type" : configuration["worker_type"],
526 | "n_required" : configuration["n_workers"]}
527 | }
528 | machine_assignments = {
529 | "master" : [],
530 | "worker" : []
531 | }
532 | for role, requirement in sorted(specs.items(), key=lambda x:x[0]):
533 | instance_type_for_role = requirement["instance_type"]
534 | n_instances_needed = requirement["n_required"]
535 | instances_to_assign, rest = instance_type_to_instance_map[instance_type_for_role][:n_instances_needed], instance_type_to_instance_map[instance_type_for_role][n_instances_needed:]
536 | instance_type_to_instance_map[instance_type_for_role] = rest
537 | machine_assignments[role] = instances_to_assign
538 |
539 |         # Construct the host strings necessary for running the training command.
540 | # Note we use private ip addresses to avoid EC2 transfer costs.
541 | worker_host_string = ",".join([x.private_ip_address+":"+str(port) for x in machine_assignments["master"] + machine_assignments["worker"]])
542 |
543 | # Create a map of command&machine assignments
544 | command_machine_assignments = {}
545 | setup_machine_assignments = {}
546 |
547 | # Construct the master command
548 | command_machine_assignments["master"] = {"instance" : machine_assignments["master"][0], "commands" : list(configuration["master_pre_commands"])}
549 | # setup_machine_assignments["master"] = {"instance" : machine_assignments["master"][0], "commands" : list(configuration["setup_commands"])}
550 | for command_string in configuration["train_commands"]:
551 | command_machine_assignments["master"]["commands"].append(command_string.replace("JOB_NAME", "worker").replace("WORKER_HOSTS", worker_host_string).replace("ROLE_ID", "master"))
552 | print(command_machine_assignments)
553 |
554 | # Construct the worker commands
555 | for worker_id, instance in enumerate(machine_assignments["worker"]):
556 | name = "worker_%d" % worker_id
557 | command_machine_assignments[name] = {"instance" : instance,
558 | "commands" : list(configuration["pre_commands"])}
559 | for command_string in configuration["train_commands"]:
560 | command_machine_assignments[name]["commands"].append(command_string.replace("TASK_ID", "%d" % (worker_id+1)).replace("JOB_NAME", "worker").replace("WORKER_HOSTS", worker_host_string).replace("ROLE_ID", name))
561 |
562 | print(command_machine_assignments)
563 |
564 | # Run the commands via ssh in parallel
565 | threads = []
566 | q = Queue.Queue()
567 |
568 | for name, command_and_machine in setup_machine_assignments.items():
569 | instance = command_and_machine["instance"]
570 | commands = command_and_machine["commands"]
571 | print("-----------------------")
572 | print("Pre Command: %s\n" % " ".join(commands))
573 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q))
574 | t.start()
575 | threads.append(t)
576 |
577 | # Wait until commands are all finished
578 | for t in threads:
579 | t.join()
580 |
581 | threads = []
582 | q = Queue.Queue()
583 |
584 |         batch_size_list = [4, 32, 50, 100, 500, 1000]  # one batch size per launched training process
585 |         learning_rate_list = [0.046, 0.05, 0.068, 0.068, 0.048, 0.086]  # learning rate paired with the batch size above
586 | running_process = 0
587 | for name, command_and_machine in command_machine_assignments.items():
588 | instance = command_and_machine["instance"]
589 | neo_commands = "python train_cifar10.py --running_mode=training --gpus=0 "\
590 | "--batch-size={} "\
591 | "--lr={} "\
592 | "--model-prefix={}/model_checkpoints/batch_size_{} "\
593 | "--dir={}/loss_curve > "\
594 | "{}/loss_curve/running_batch_size_{}.out 2>&1 &".format(
595 | batch_size_list[running_process],
596 | learning_rate_list[running_process],
597 | configuration['nfs_mount_point'],
598 | batch_size_list[running_process],
599 | configuration['nfs_mount_point'],
600 | configuration['nfs_mount_point'],
601 | batch_size_list[running_process])
602 |
603 | commands = command_and_machine["commands"]
604 | commands.append('mkdir {}/model_checkpoints/'.format(configuration['nfs_mount_point']))
605 | commands.append('mkdir {}/loss_curve'.format(configuration['nfs_mount_point']))
606 | commands.append(neo_commands)
607 |
608 | print("-----------------------")
609 | print("Command: %s\n" % " ".join(commands))
610 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q))
611 | t.start()
612 | threads.append(t)
613 | running_process += 1
614 |
615 | # Wait until commands are all finished
616 | for t in threads:
617 | t.join()
618 |
619 | # Print the output
620 | while not q.empty():
621 | instance, output = q.get()
622 | print(instance.public_ip_address)
623 | print(output)
624 |
625 | # Debug print
626 | instances = []
627 | print("\n--------------------------------------------------\n")
628 | print("Machine assignments:")
629 | print("------------------------")
630 | for name, command_and_machine in command_machine_assignments.items():
631 | instance = command_and_machine["instance"]
632 | instances.append(instance)
633 | commands = command_and_machine["commands"]
634 | ssh_command = "ssh -i %s %s@%s" % (configuration["path_to_keyfile"], configuration["ssh_username"], instance.public_ip_address)
635 | print("%s - %s" % (name, instance.instance_id))
636 | print("To ssh: %s" % ssh_command)
637 | print("------------------------")
638 |
639 |         # Print out the list of instance ids (which will be useful in selectively stopping
640 |         # training on given instances).
641 | instance_cluster_string = ",".join([x.instance_id for x in instances])
642 | print("\nInstances cluster string: %s" % instance_cluster_string)
643 |
644 |         # Bundle the configuration, machine assignments, and cluster string to return to the caller
645 | cluster_save = {
646 | "configuration" : configuration,
647 | "name" : configuration["name"],
648 | "command_machine_assignments" : command_machine_assignments,
649 | "cluster_string" : instance_cluster_string
650 | }
651 |
652 | return cluster_save
653 |
654 |
655 |
656 | def get_hosts(argv, port=22):
657 | # Check idle instances satisfy configs
658 | check_idle_instances_satisfy_configuration()
659 |
660 | # Get idle instances
661 | idle_instances = get_idle_instances()
662 |
663 | # Assign instances for worker/ps/etc
664 | instance_type_to_instance_map = summarize_instances(idle_instances)
665 | specs = {
666 | "master" : {"instance_type" : configuration["master_type"],
667 | "n_required" : configuration["n_masters"]},
668 | "worker" : {"instance_type" : configuration["worker_type"],
669 | "n_required" : configuration["n_workers"]}
670 | }
671 | machine_assignments = {
672 | "master" : [],
673 | "worker" : []
674 | }
675 | for role, requirement in sorted(specs.items(), key=lambda x:x[0]):
676 | instance_type_for_role = requirement["instance_type"]
677 | n_instances_needed = requirement["n_required"]
678 | instances_to_assign, rest = instance_type_to_instance_map[instance_type_for_role][:n_instances_needed], instance_type_to_instance_map[instance_type_for_role][n_instances_needed:]
679 | instance_type_to_instance_map[instance_type_for_role] = rest
680 | machine_assignments[role] = instances_to_assign
681 |
682 |         # Construct the host strings necessary for running the training command.
683 | # Note we use private ip addresses to avoid EC2 transfer costs.
684 | worker_host_string = ",".join([x.private_ip_address+":"+str(port) for x in machine_assignments["master"] + machine_assignments["worker"]])
685 | hosts_out = open('hosts', 'w')
686 | print('master ip ', machine_assignments['master'][0].public_ip_address)
687 | count = 0
688 | for instance in machine_assignments["master"] + machine_assignments["worker"]:
689 | count += 1
690 | print('{}\tdeeplearning-worker{}'.format(instance.private_ip_address, count), end='\n', file=hosts_out)
691 | hosts_out.flush()
692 | hosts_out.close()
693 |
694 | hosts_alias_out = open('hosts_alias', 'w')
695 | count = 0
696 | for _ in machine_assignments["master"] + machine_assignments["worker"]:
697 | count += 1
698 | print('deeplearning-worker{}'.format(count), end='\n', file=hosts_alias_out)
699 | hosts_alias_out.flush()
700 | hosts_alias_out.close()
701 |
702 | hosts_alias_out = open('hosts_address', 'w')
703 | count = 0
704 | for instance in machine_assignments["master"] + machine_assignments["worker"]:
705 | count += 1
706 | print('{}'.format(instance.private_ip_address), end='\n', file=hosts_alias_out)
707 | hosts_alias_out.flush()
708 | hosts_alias_out.close()
709 |
819 | return
820 |
821 | def kill_python(argv):
822 | if len(argv) != 3:
823 |             print("Usage: python pytorch_ec2.py kill_python instance_id1,instance_id2,id3...")
824 | sys.exit(0)
825 | cluster_instance_string = argv[2]
826 | instance_ids_to_shutdown = cluster_instance_string.split(",")
827 |
828 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}])
829 | threads = []
830 | q = Queue.Queue()
831 | for instance in live_instances:
832 | if instance.instance_id in instance_ids_to_shutdown:
833 | commands = ["sudo pkill -9 python"]
834 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q))
835 | t.start()
836 | threads.append(t)
837 | for thread in threads:
838 | thread.join()
839 | summarize_idle_instances(None)
840 |
841 | def kill_all_python(argv):
842 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}])
843 | threads = []
844 | q = Queue.Queue()
845 | for instance in live_instances:
846 | commands = ["sudo pkill -9 python"]
847 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q))
848 | t.start()
849 | threads.append(t)
850 | for thread in threads:
851 | thread.join()
852 | summarize_idle_instances(None)
853 |
854 | def run_command(argv, quiet=False):
855 | if len(argv) != 4:
856 |             print("Usage: python pytorch_ec2.py run_command instance_id1,instance_id2,id3... command")
857 | sys.exit(0)
858 | cluster_instance_string = argv[2]
859 | command = argv[3]
860 | instance_ids_to_run_command = cluster_instance_string.split(",")
861 |
862 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}])
863 | threads = []
864 | q = Queue.Queue()
865 | for instance in live_instances:
866 | if instance.instance_id in instance_ids_to_run_command:
867 | commands = [command]
868 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q))
869 | t.start()
870 | threads.append(t)
871 | for thread in threads:
872 | thread.join()
873 |
874 | while not q.empty():
875 | instance, output = q.get()
876 | if not quiet:
877 | print(instance, output)
878 |
879 |     # Set up NFS (the shared filesystem) on all running instances
880 | def setup_nfs():
881 | print("Clearing previous nfs file system...")
882 |         live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}])
883 | live_instances_string = ",".join([x.instance_id for x in live_instances])
884 | rm_command = "sudo rm -rf %s" % configuration["nfs_mount_point"]
885 | argv = ["python", "inception_ec2.py", live_instances_string, rm_command]
886 | # argv = ["python", "inception_ec2.py", live_instances_string]
887 | run_command(argv, quiet=True)
888 |
889 | print("Installing nfs on all running instances...")
890 | update_command = "sudo apt-get -y update"
891 | install_nfs_command = "sudo apt-get -y install nfs-common"
892 |         create_mount_command = "mkdir -p %s" % configuration["nfs_mount_point"]  # -p: do not abort the chained command if the directory already exists
893 | setup_nfs_command = "sudo mount -t nfs4 -o nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2 %s:/ %s" % (configuration["nfs_ip_address"], configuration["nfs_mount_point"])
894 | reduce_permissions_command = "sudo chmod 777 %s " % configuration["nfs_mount_point"]
895 | command = update_command + " && " + install_nfs_command + " && " + create_mount_command + " && " + setup_nfs_command + " && " + reduce_permissions_command
896 |
897 | # pretty hackish
898 | argv = ["python", "inception_ec2.py", live_instances_string, command]
899 | run_command(argv, quiet=True)
900 | return
901 |
902 | # Launch instances as specified by the configuration.
903 | # We also want a shared filesystem to write model checkpoints.
904 | # For simplicity we will have the user specify the filesystem via the config.
905 | def launch(argv):
906 | method = "spot"
907 | if "method" in configuration:
908 | method = configuration["method"]
909 | launch_instances()
910 | if method == "spot":
911 | wait_until_instance_request_status_fulfilled()
912 | wait_until_running_instances_initialized()
913 | print('setup nfs')
914 | setup_nfs()
915 |
916 | def clean_launch_and_run(argv):
917 |         # 1. Kills all instances in the region
918 |         # 2. Kills all instance requests in the region
919 |         # 3. Launches new instance requests
920 |         # 4. Waits until the launch requests have all been satisfied,
921 |         #    printing status output in the meanwhile
922 |         # 5. Checks that the configuration has been satisfied
923 |         # 6. Runs the grid-search training job
924 | shut_everything_down(None)
925 | launch(None)
926 | return run_mxnet_grid_search(None)
927 |
928 | def help(hmap):
929 |         print("Usage: python pytorch_ec2.py [command]")
930 | print("Commands:")
931 | for k,v in hmap.items():
932 | print("%s - %s" % (k,v))
933 |
934 |     ###################################
935 |     # pytorch_ec2 main starting point #
936 |     ###################################
937 |
938 | command_map = {
939 | "launch" : launch,
940 | "clean_launch_and_run" : clean_launch_and_run,
941 | "shutdown" : shut_everything_down,
942 | "run_mxnet_grid_search": run_mxnet_grid_search,
943 | "run_mxnet_loss_curve": run_mxnet_loss_curve,
944 | "get_hosts": get_hosts,
945 | "kill_all_python" : kill_all_python,
946 | "list_idle_instances" : summarize_idle_instances,
947 | "list_running_instances" : summarize_running_instances,
948 | "kill_python" : kill_python,
949 | "run_command" : run_command,
950 | "setup_nfs": setup_nfs,
951 | }
952 | help_map = {
953 | "launch" : "Launch instances",
954 |         "clean_launch_and_run" : "Shut everything down, launch instances, wait until requests are fulfilled, check that the configuration is fulfilled, and launch and run the training job.",
955 |         "shutdown" : "Shut everything down by cancelling all instance requests, and terminating all instances.",
956 |         "list_idle_instances" : "Lists all idle instances. Idle instances are running instances that are not currently running a training job.",
957 |         "list_running_instances" : "Lists all running instances.",
958 |         "run_mxnet_grid_search": "Runs train_cifar10.py in grid-search mode on the assigned instances, writing logs to the shared NFS directory.",
959 |         "run_mxnet_loss_curve": "Runs train_cifar10.py in training mode with preset batch-size/learning-rate pairs to collect loss curves on the shared NFS directory.",
960 |         "setup_nfs": "Installs NFS on all running instances and mounts the shared filesystem at the configured mount point.",
961 |         "kill_all_python" : "Kills Python training processes on ALL instances.",
962 |         "kill_python" : "Kills Python training processes on the instances indicated by an instance id string separated by ',' (no spaces).",
963 |         "run_command" : "Runs the given command on instances selected by an instance id string, separated by ','.",
964 | }
965 |
966 | if len(argv) < 2:
967 | help(help_map)
968 | sys.exit(0)
969 |
970 | command = argv[1]
971 | return command_map[command](argv)
972 |
973 | if __name__ == "__main__":
974 | print(cfg)
975 | mxnet_ec2_run(sys.argv, cfg)
976 |
--------------------------------------------------------------------------------
/tools/remote_script.sh:
--------------------------------------------------------------------------------
1 | KEY_PEM_NAME=HongyiScript.pem
2 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`
3 | #cd ~/ps_pytorch/src/
4 | #bash ../tools/pre_run.sh
5 | #bash data_prepare.sh
6 | #cd ~
7 |
8 | sudo bash -c "cat hosts >> /etc/hosts"
9 | cp config ~/.ssh/
10 |
11 | cd ~/.ssh
12 | eval `ssh-agent -s`
13 | ssh-add ${KEY_PEM_NAME}
14 | ssh-keygen -t rsa -b 4096 -C "hongyiwang.hdu@gmail.com"
15 |
16 | for i in $(seq 2 $DEEPLEARNING_WORKERS_COUNT);
17 | do
18 | scp -i ${KEY_PEM_NAME} id_rsa.pub deeplearning-worker${i}:~/.ssh
19 | #ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'git clone https://github.com/hwang595/ps_pytorch.git; cd ~/.ssh; cat id_rsa.pub >> authorized_keys; bash ~/ps_pytorch/tools/pre_run.sh'
20 | ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'cd ~/.ssh; cat id_rsa.pub >> authorized_keys'
21 | scp -i ${KEY_PEM_NAME} -r /home/ubuntu/ps_real_pytorch deeplearning-worker${i}:~/
22 | #scp -i ${KEY_PEM_NAME} -r ~/ps_pytorch deeplearning-worker${i}:~
23 | echo "Done writing public key to worker: deeplearning-worker${i}"
24 | done
25 |
--------------------------------------------------------------------------------
/tools/update_git_dir.sh:
--------------------------------------------------------------------------------
1 | cd ~
2 | KEY_PEM_NAME=HongyiScript.pem
3 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`
4 |
5 | sudo bash -c "cat hosts >> /etc/hosts"
6 |
7 | for i in $(seq 2 $DEEPLEARNING_WORKERS_COUNT);
8 | do
9 | ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'git config --global user.name hwang595; git config --global user.email hongyiwang@cs.wisc.edu; cd ~/ps_pytorch; git pull'
10 | echo "Done pull git repo on worker: deeplearning-worker${i}"
11 | done
--------------------------------------------------------------------------------