├── README.md
├── images
│   └── system_overview.jpg
├── src
│   ├── README.md
│   ├── data
│   │   └── data_prepare.py
│   ├── data_prepare.sh
│   ├── distributed_evaluator.py
│   ├── distributed_nn.py
│   ├── distributed_worker.py
│   ├── evaluate_pytorch.sh
│   ├── launch.sh
│   ├── model_ops
│   │   ├── __init__.py
│   │   ├── lenet.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── nn_ops.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adam.py
│   │   └── sgd.py
│   ├── run_pytorch_dist.sh
│   ├── run_pytorch_single.sh
│   ├── sync_replicas_master_nn.py
│   └── util.py
└── tools
    ├── README.md
    ├── conda_install.sh
    ├── config
    ├── hosts
    ├── hosts_address
    ├── hosts_alias
    ├── install.sh
    ├── killall.sh
    ├── local_script.sh
    ├── openmpi_install.sh
    ├── pre_run.sh
    ├── pytorch_ec2.py
    ├── remote_script.sh
    └── update_git_dir.sh

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# PyTorch-parameter-server
Implementation of synchronous distributed machine learning in the [Parameter Server](https://www.cs.cmu.edu/~muli/file/parameter_server_nips14.pdf) setup, using [PyTorch's distributed communication library](https://pytorch.org/docs/stable/distributed.html), i.e. `torch.distributed`.

All functionality in this repository is essentially a replication of [ps_pytorch](https://github.com/hwang595/ps_pytorch). However, instead of using `mpi4py`, all communication and model training are handled by PyTorch itself.

## Contents

1. [Motivations](#motivations)
2. [System design](#system-design)
3. [Basic usages](#basic-usages)
4. [How to prepare datasets](#prepare-datasets)
5. [How to launch a distributed task](#job-launching)
6. [Future work](#future-work)

## Motivations:
1. PyTorch provides easy-to-use APIs with a dynamic computational graph.
2. Although [mpi4py](https://github.com/mpi4py/mpi4py) provides a good Python binding for any distribution of MPI and flexible communication operations, converting data back and forth (e.g. `torch.Tensor` <--> `numpy.array`) incurs heavy overhead throughout the training process.
3. PyTorch supports [NCCL](https://developer.nvidia.com/nccl) as a communication backend, which makes distributed training on GPU clusters efficient and scalable.

## System Design:
1. Parameter Server (PS): this node stores the global model, which workers pull at the beginning of each iteration, and synchronizes all workers into the next iteration by broadcasting the global step (we implement this stage using `torch.distributed.broadcast`). At a user-defined frequency, the PS saves the current model as a checkpoint to a shared file system (NFS in our setup) for model evaluation.
2. Workers sample data points (mini-batches) from their local datasets (we do not move data between nodes, to preserve data locality), compute gradients, and ship them back to the PS (this stage is implemented using `torch.distributed.gather`).
3. The evaluator reads checkpoints from the shared directory and runs model evaluation. Note that only the test set is stored on the evaluator node.
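The sketch below illustrates, in simplified form, the broadcast/gather round trip described above. It is not the repository's exact code (that logic lives in `src/sync_replicas_master_nn.py` on the PS side and `src/distributed_worker.py` on the worker side); the backend, step count, learning rate, and dummy-gradient handling are placeholder choices for illustration only.

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist

def run_round_trip(model: nn.Module, num_steps: int = 10, lr: float = 0.01):
    # RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT are assumed to be set by the
    # launcher, as in src/distributed_nn.py.
    dist.init_process_group(backend="gloo",
                            rank=int(os.environ["RANK"]),
                            world_size=int(os.environ["WORLD_SIZE"]))
    rank, world_size = dist.get_rank(), dist.get_world_size()
    for _ in range(num_steps):
        # (1) Rank 0 (the PS) broadcasts the current weights; workers overwrite theirs.
        for p in model.parameters():
            dist.broadcast(p.data, src=0)
        if rank == 0:
            # (2) The PS gathers one gradient tensor per process for every layer,
            #     averages the worker contributions, and applies an SGD-style update.
            for p in model.parameters():
                bufs = [torch.zeros_like(p.data) for _ in range(world_size)]
                dist.gather(torch.zeros_like(p.data), gather_list=bufs, dst=0)
                avg_grad = sum(bufs[1:]) / (world_size - 1)  # slot 0 holds the PS's dummy tensor
                p.data.add_(avg_grad, alpha=-lr)
        else:
            # (2') Workers: run forward/backward on a local mini-batch (omitted here),
            #      then ship each layer's gradient to the PS.
            for p in model.parameters():
                grad = p.grad if p.grad is not None else torch.zeros_like(p.data)
                dist.gather(grad, dst=0)
```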
## Basic Usages
### Dependencies:
Anaconda is highly recommended for installing the dependencies of this project. Assuming a machine with conda set up, you can run
```
bash ./tools/pre_run.sh
```
to install all dependencies needed.
### Single Machine:
The code base provided in this repository can be run on a single machine, in which case multiple CPU processes are launched and each process is assigned a role as Parameter Server (usually the process with rank 0) or worker. To do this, one can simply follow the "Single-Node multi-process distributed training" part of [this tutorial](https://pytorch.org/docs/stable/distributed.html#launch-utility). We provide a script (`run_pytorch_single.sh`) that does the job for you. One can simply run
```
bash ./src/run_pytorch_single.sh
```

### Cluster Setup:
For running on a distributed cluster, the first thing you need to do is launch AWS EC2 instances.
#### Launching Instances:
[This script](https://github.com/hwang595/PyTorch-parameter-server/blob/master/tools/pytorch_ec2.py) helps you launch EC2 instances automatically, but before running it you should follow [the instructions](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html) to set up the AWS CLI on your local machine.
After that, please edit this part of `./tools/pytorch_ec2.py`:
``` python
cfg = Cfg({
    "name" : "PS_PYTORCH",                # Unique name for this specific configuration
    "key_name": "NameOfKeyFile",          # Necessary to ssh into created instances
    # Cluster topology
    "n_masters" : 1,                      # Should always be 1
    "n_workers" : 8,
    "num_replicas_to_aggregate" : "8",    # deprecated, not necessary
    "method" : "spot",
    # Region specification
    "region" : "us-west-2",
    "availability_zone" : "us-west-2b",
    # Machine type - instance type configuration.
    "master_type" : "m4.2xlarge",
    "worker_type" : "m4.2xlarge",
    # please only use this AMI for pytorch
    "image_id": "ami-xxxxxxxx",           # id of AMI
    # Launch specifications
    "spot_price" : "0.15",                # Has to be a string
    # SSH configuration
    "ssh_username" : "ubuntu",            # For sshing. E.G: ssh ssh_username@hostname
    "path_to_keyfile" : "/dir/to/NameOfKeyFile.pem",

    # NFS configuration
    # To set up these values, go to Services > ElasticFileSystem > Create new filesystem, and follow the directions.
    #"nfs_ip_address" : "172.31.3.173",   # us-west-2c
    #"nfs_ip_address" : "172.31.35.0",    # us-west-2a
    "nfs_ip_address" : "172.31.14.225",   # us-west-2b
    "nfs_mount_point" : "/home/ubuntu/shared",   # NFS base dir
```
For setting everything up on the EC2 cluster, the easiest way is to set up one machine, create an AMI from it, and use that AMI's id as `image_id` in `pytorch_ec2.py`. Then launch EC2 instances by running
```
python ./tools/pytorch_ec2.py launch
```
After all launched instances are ready (this may take a while), get the private IPs of the instances with
```
python ./tools/pytorch_ec2.py get_hosts
```
This writes the IPs into a file named `hosts_address`, which looks like
```
172.31.16.226 (${PS_IP})
172.31.27.245
172.31.29.131
172.31.18.108
...
```
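The parameter server's address is simply the first entry in `hosts_address`. A hypothetical helper (not part of `./tools`; it assumes `get_hosts` wrote the file into `./tools/`) to capture it as `${PS_IP}` for the next step:
```
# Grab the first entry of hosts_address to use as ${PS_IP} below.
PS_IP=$(head -n 1 ./tools/hosts_address | awk '{print $1}')
echo "Parameter server private IP: ${PS_IP}"
```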
After generating the `hosts_address` file for all EC2 instances, running the following command copies your keyfile to the parameter server (PS) instance, whose address is always the first one in `hosts_address`. `local_script.sh` also does some basic configuration, e.g. cloning this git repo:
```
bash ./tools/local_script.sh ${PS_IP}
```
#### SSH related:
At this stage, you should ssh to the PS instance; all subsequent operations happen on the PS. In the PS setting, the PS must be able to ssh to every compute node; [this part](https://github.com/hwang595/PyTorch-parameter-server/blob/master/tools/remote_script.sh#L8-L22) does that job for you when you run (after sshing to the PS)
```
bash ./tools/remote_script.sh
```

## Prepare Datasets
Download, split, and transform the datasets with (`./tools/remote_script.sh` does this for you):
```
bash ./src/data_prepare.sh
```
One can simply extend the script `./src/data/data_prepare.py` to support any dataset provided by [torchvision](https://github.com/pytorch/vision).

## Job Launching
Tasks are required to be launched from the PS (or master) instance. `launch.sh` (which calls `./src/run_pytorch_dist.sh` on every node over ssh) wraps the job-launching process up. Commonly used options (arguments) are listed below:

| Argument | Comments |
| ----------------------------- | ---------------------------------------- |
| `n` | Number of processes (size of the cluster), e.g. with P compute nodes and 1 PS, n = P+1. |
| `lr` | Initial learning rate. |
| `momentum` | Momentum value to use. |
| `max-steps` | Maximum number of iterations to train. |
| `epochs` | Maximum number of epochs to train (somewhat redundant with `max-steps`). |
| `network` | Type of deep neural net; currently `LeNet`, `ResNet-18/32/50/110/152`, and `VGG`s are supported. |
| `dataset` | Dataset used for training. |
| `batch-size` | Batch size used by the optimization algorithm. |
| `eval-freq` | Frequency (in iterations) at which the model is checkpointed for evaluation. |
| `enable-gpu` | Train on CPU/GPU; if CPU, leave this argument empty. |
| `train-dir` | Directory in which model checkpoints are saved for evaluation. |

## Model Evaluation
The [distributed evaluator](https://github.com/hwang595/PyTorch-parameter-server/blob/master/src/distributed_evaluator.py) fetches model checkpoints from the shared directory and evaluates the model on the validation set.
To evaluate a model, you can run
```
bash ./src/evaluate_pytorch.sh
```
with the desired arguments.

Evaluation arguments are listed below:

| Argument | Comments |
| ----------------------------- | ---------------------------------------- |
| `eval-batch-size` | Batch size (on the validation set) used during model evaluation. |
| `eval-freq` | Frequency (in iterations) at which the model is evaluated; should be set to the same value as in [run_pytorch_dist.sh](https://github.com/hwang595/PyTorch-parameter-server/blob/master/src/run_pytorch_dist.sh). |
| `network` | Type of deep neural net; should be set to the same value as in [run_pytorch_dist.sh](https://github.com/hwang595/PyTorch-parameter-server/blob/master/src/run_pytorch_dist.sh). |
| `dataset` | Dataset used for training; should be set to the same value as in [run_pytorch_dist.sh](https://github.com/hwang595/PyTorch-parameter-server/blob/master/src/run_pytorch_dist.sh). |
| `model-dir` | Directory from which model checkpoints are read for evaluation; should point to the same location as `train-dir` in [run_pytorch_dist.sh](https://github.com/hwang595/PyTorch-parameter-server/blob/master/src/run_pytorch_dist.sh). |
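For reference, the training and evaluation arguments are tied together by a simple checkpoint naming convention (mirroring `_generate_model_path` in `src/distributed_worker.py` and `_model_dir_generator` in `src/distributed_evaluator.py`): a checkpoint named `model_step_<step>` is written into `train-dir` every `eval-freq` steps, and the evaluator polls for exactly those paths, which is why `eval-freq` must match and `model-dir` must point at the same directory as `train-dir`. A minimal sketch:

```python
# Checkpoint naming scheme shared by the trainer and the evaluator.
def checkpoint_path(model_dir, step):
    # model_dir is expected to end with a trailing "/" (e.g. "/home/ubuntu/MPI_shared/")
    return model_dir + "model_step_" + str(step)

# With --train-dir=/home/ubuntu/MPI_shared/ and --eval-freq=50, the evaluator looks for
# /home/ubuntu/MPI_shared/model_step_50, model_step_100, ... as they appear on NFS.
print(checkpoint_path("/home/ubuntu/MPI_shared/", 50))
```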
## Future work:
(Please note that this project is still in an early alpha version.)
1. Overlap computation (forward and backward passes) with communication to gain better speedup.
2. Support an asynchronous communication mode, i.e. [backup workers](https://arxiv.org/pdf/1604.00981.pdf).

## Contact:
Any contribution to this repo is highly appreciated.
If you encounter any issue using the code base provided here, please feel free to open an issue or email [Hongyi Wang](https://hwang595.github.io/) at hongyiwang@cs.wisc.edu directly.

--------------------------------------------------------------------------------
/images/system_overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang595/PyTorch-parameter-server/08726c9ed718fe0ee65c032801f632decf79ec79/images/system_overview.jpg

--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------

Do remember to convert `numpy.ndarray`s to `numpy.float64` before pushing them into the MPI send/receive buffer; otherwise there will be data-type conversion issues between numpy and MPI.

# To use it on a single machine:
```
python single_machine.py --dataset=MNIST/Cifar10 --network=LeNet/Resnet --batch-size=${BATCH_SIZE}
```

# To use it on a distributed cluster:
```
mpirun -n ${NUM_WORKERS} --hostfile=${HOST_DIR} python distributed_nn.py --dataset=MNIST/Cifar10 --network=LeNet/Resnet --batch-size=${BATCH_SIZE}
```

# Run the whole thing automatically
The first thing you need to do is launch AWS EC2 instances; you can do that with `tools/pytorch_ec2.py` by running the following command:
```
python pytorch_ec2.py launch
```
After the launch command has executed and all instances are initialized (this may take several minutes), you need to fetch the host addresses:
```
python pytorch_ec2.py get_hosts
```
Then copy the essential configuration files and hosts files using the public address of the master node (the first address in the `hosts` file):
```
sh local_script.sh ${MASTER_PUB_ADDR}
```
After that, log in to the master node manually and run the remote script under the `$HOME` dir:
```
sh remote_script.sh
```
This script will do the cluster setup and data preparation work for you.
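Per the top-level README, `./tools/remote_script.sh` takes care of running `./src/data_prepare.sh`, which executes the `data_prepare.py` script listed below; the README also notes that the script can be extended to any dataset shipped with torchvision. A minimal sketch of such an extension (FashionMNIST and its normalization statistics are chosen here purely for illustration):

```python
import torch
from torchvision import datasets, transforms

# Pre-download FashionMNIST the same way data_prepare.py handles MNIST/CIFAR.
fashion_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.2860,), (0.3530,)),  # approximate channel statistics
])
training_set_fashion = datasets.FashionMNIST('./fashion_mnist_data', train=True,
                                             download=True, transform=fashion_transform)
test_set_fashion = datasets.FashionMNIST('./fashion_mnist_data', train=False,
                                         download=True, transform=fashion_transform)
train_loader_fashion = torch.utils.data.DataLoader(training_set_fashion, batch_size=128, shuffle=True)
test_loader_fashion = torch.utils.data.DataLoader(test_set_fashion, batch_size=100, shuffle=True)
```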
31 | -------------------------------------------------------------------------------- /src/data/data_prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Since we need to initialize the dataset in parallel, pre-download data is necessary 3 | This script will do the job for you 4 | """ 5 | import torch 6 | from torchvision import datasets, transforms 7 | 8 | 9 | if __name__ == "__main__": 10 | training_set_mnist = datasets.MNIST('./mnist_data', train=True, download=True, 11 | transform=transforms.Compose([ 12 | transforms.ToTensor(), 13 | transforms.Normalize((0.1307,), (0.3081,))])) 14 | train_loader_mnist = torch.utils.data.DataLoader(training_set_mnist, batch_size=128, shuffle=True) 15 | test_loader_mnist = torch.utils.data.DataLoader( 16 | datasets.MNIST('./mnist_data', train=False, transform=transforms.Compose([ 17 | transforms.ToTensor(), 18 | transforms.Normalize((0.1307,), (0.3081,)) 19 | ])), batch_size=100, shuffle=True) 20 | trainset_cifar10 = datasets.CIFAR10(root='./cifar10_data', train=True, 21 | download=True, transform=transforms.Compose([ 22 | transforms.ToTensor(), 23 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 24 | ])) 25 | train_loader_cifar10 = torch.utils.data.DataLoader(trainset_cifar10, batch_size=128, 26 | shuffle=True) 27 | test_loader_cifar10 = torch.utils.data.DataLoader( 28 | datasets.CIFAR10('./cifar10_data', train=False, transform=transforms.Compose([ 29 | transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 30 | ])), batch_size=100, shuffle=True) 31 | # load training and test set here: 32 | training_set = datasets.CIFAR100(root='./cifar100_data', train=True, 33 | download=True, transform=transforms.Compose([ 34 | transforms.ToTensor(), 35 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 36 | ])) 37 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=128, 38 | shuffle=True) 39 | testset = datasets.CIFAR100(root='./cifar100_data', train=False, 40 | download=True, transform=transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 43 | ])) 44 | test_loader = torch.utils.data.DataLoader(testset, batch_size=1000, 45 | shuffle=False) 46 | 47 | training_set = datasets.SVHN('./svhn_data', split='train', download=True, transform=transforms.Compose([ 48 | transforms.RandomCrop(32, padding=4), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.ToTensor(), 51 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 52 | ])) 53 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=128, 54 | shuffle=True) 55 | transform_test = transforms.Compose([ 56 | transforms.ToTensor(), 57 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 58 | ]) 59 | testset = datasets.SVHN(root='./svhn_data', split='test', 60 | download=True, transform=transform_test) 61 | test_loader = torch.utils.data.DataLoader(testset, batch_size=1000, 62 | shuffle=False) -------------------------------------------------------------------------------- /src/data_prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python ./data/data_prepare.py -------------------------------------------------------------------------------- /src/distributed_evaluator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path 3 | import time 4 | import argparse 
5 | from datetime import datetime 6 | import copy 7 | 8 | import numpy as np 9 | 10 | from nn_ops import NN_Trainer 11 | 12 | import torch 13 | from torch.autograd import Variable 14 | import torch.nn.functional as F 15 | from torchvision import datasets, transforms 16 | from torch.utils.data import DataLoader 17 | 18 | from model_ops.lenet import LeNet, LeNetSplit 19 | from model_ops.resnet import * 20 | from model_ops.resnet_split import * 21 | from util import build_model 22 | 23 | 24 | def accuracy(output, target, topk=(1,)): 25 | """Computes the precision@k for the specified values of k""" 26 | maxk = max(topk) 27 | batch_size = target.size(0) 28 | 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 32 | res = [] 33 | for k in topk: 34 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 35 | res.append(correct_k.mul_(100.0 / batch_size)) 36 | return res 37 | 38 | def add_fit_args(parser): 39 | """ 40 | parser : argparse.ArgumentParser 41 | return a parser added with args required by fit 42 | """ 43 | # Validation settings 44 | parser.add_argument('--eval-batch-size', type=int, default=10000, metavar='N', 45 | help='the batch size when doing model validation, complete at once on default') 46 | parser.add_argument('--eval-freq', type=int, default=50, metavar='N', 47 | help='it determines per how many step the model should be evaluated') 48 | parser.add_argument('--model-dir', type=str, default='output/models/', metavar='N', 49 | help='directory to save the temp model during the training process for evaluation') 50 | parser.add_argument('--dataset', type=str, default='MNIST', metavar='N', 51 | help='which dataset used in training, MNIST and Cifar10 supported currently') 52 | parser.add_argument('--network', type=str, default='LeNet', metavar='N', 53 | help='which kind of network we are going to use, support LeNet and ResNet currently') 54 | args = parser.parse_args() 55 | return args 56 | 57 | class DistributedEvaluator(NN_Trainer): 58 | ''' 59 | The DistributedEvaluator aims at providing a seperate node in the distributed cluster to evaluate 60 | the model on validation/test set and return the results 61 | In this version, the DistributedEvaluator will only load the model from the dir where the master 62 | save the model and do the evaluation task based on a user defined frequency 63 | ''' 64 | def __init__(self, **kwargs): 65 | self._cur_step = 0 66 | self._model_dir = kwargs['model_dir'] 67 | self._eval_freq = int(kwargs['eval_freq']) 68 | self._eval_batch_size = kwargs['eval_batch_size'] 69 | self.network_config = kwargs['network'] 70 | # this one is going to be used to avoid fetch the weights for multiple times 71 | self._layer_cur_step = [] 72 | 73 | def evaluate(self, validation_loader): 74 | # init objective to fetch at the begining 75 | self._next_step_to_fetch = self._cur_step + self._eval_freq 76 | self._num_batch_per_epoch = len(validation_loader) / self._eval_batch_size 77 | # check if next temp model exsits, if not we wait here else we continue to do the model evaluation 78 | while True: 79 | model_dir_=self._model_dir_generator(self._next_step_to_fetch) 80 | if os.path.isfile(model_dir_): 81 | self._load_model(model_dir_) 82 | print("Evaluator evaluating results on step {}".format(self._next_step_to_fetch)) 83 | self._evaluate_model(validation_loader) 84 | self._next_step_to_fetch += self._eval_freq 85 | else: 86 | # TODO(hwang): sleep appropriate period of time make sure to tune this 
parameter 87 | time.sleep(10) 88 | 89 | def _evaluate_model(self, test_loader): 90 | self.network.eval() 91 | test_loss = 0 92 | correct = 0 93 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0 94 | for data, y_batch in test_loader: 95 | data, target = Variable(data), Variable(y_batch) 96 | output = self.network(data) 97 | test_loss += F.nll_loss(F.log_softmax(output), target, size_average=False).item() 98 | prec1_tmp, prec5_tmp = accuracy(output.detach(), y_batch, topk=(1, 5)) 99 | prec1_counter_ += prec1_tmp.numpy()[0] 100 | prec5_counter_ += prec5_tmp.numpy()[0] 101 | batch_counter_ += 1 102 | prec1 = prec1_counter_ / batch_counter_ 103 | prec5 = prec5_counter_ / batch_counter_ 104 | test_loss /= len(test_loader.dataset) 105 | print('Test set: Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(test_loss, prec1, prec5)) 106 | 107 | def _load_model(self, file_path): 108 | self.network = build_model(self.network_config, num_classes=10) 109 | with open(file_path, "rb") as f_: 110 | self.network.load_state_dict(torch.load(f_)) 111 | 112 | def _model_dir_generator(self, next_step_to_fetch): 113 | return self._model_dir+"model_step_"+str(next_step_to_fetch) 114 | 115 | if __name__ == "__main__": 116 | # this is only a simple test case 117 | args = add_fit_args(argparse.ArgumentParser(description='PyTorch Distributed Evaluator')) 118 | 119 | # load training and test set here: 120 | if args.dataset == "MNIST": 121 | test_loader = torch.utils.data.DataLoader( 122 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 123 | transforms.ToTensor(), 124 | transforms.Normalize((0.1307,), (0.3081,)) 125 | ])), batch_size=args.eval_batch_size, shuffle=True) 126 | elif args.dataset == "Cifar10": 127 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 128 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 129 | transform_test = transforms.Compose([ 130 | transforms.ToTensor(), 131 | normalize]) 132 | testset = datasets.CIFAR10(root='./cifar10_data', train=False, 133 | download=True, transform=transform_test) 134 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size, 135 | shuffle=False) 136 | 137 | kwargs_evaluator={ 138 | 'network':args.network, 139 | 'model_dir':args.model_dir, 140 | 'eval_freq':args.eval_freq, 141 | 'eval_batch_size':args.eval_batch_size} 142 | evaluator_nn = DistributedEvaluator(**kwargs_evaluator) 143 | evaluator_nn.evaluate(validation_loader=test_loader) 144 | print("I am worker: {} in all {} workers".format(worker_fc_nn.rank, worker_fc_nn.world_size)) -------------------------------------------------------------------------------- /src/distributed_nn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import math 5 | import threading 6 | import argparse 7 | import time 8 | import os 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import torch.distributed as dist 14 | from torch.autograd import Variable 15 | from torch import nn 16 | import torch.nn.functional as F 17 | 18 | from nn_ops import NN_Trainer, accuracy 19 | from data_loader_ops.my_data_loader import DataLoader 20 | 21 | from distributed_worker import * 22 | from sync_replicas_master_nn import * 23 | 24 | 25 | def add_fit_args(parser): 26 | """ 27 | parser : argparse.ArgumentParser 28 | return a parser added with args required by fit 29 | """ 30 | # Training settings 31 | parser.add_argument('--batch-size', type=int, default=128, 
metavar='N', 32 | help='input batch size for training (default: 64)') 33 | parser.add_argument('--test-batch-size', type=int, default=500, metavar='N', 34 | help='input batch size for testing (default: 1000)') 35 | parser.add_argument('--epochs', type=int, default=100, metavar='N', 36 | help='number of epochs to train (default: 10)') 37 | parser.add_argument('--max-steps', type=int, default=10000, metavar='N', 38 | help='the maximum number of iterations') 39 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 40 | help='learning rate (default: 0.01)') 41 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 42 | help='SGD momentum (default: 0.5)') 43 | parser.add_argument('--no-cuda', action='store_true', default=False, 44 | help='disables CUDA training') 45 | parser.add_argument('--seed', type=int, default=1, metavar='S', 46 | help='random seed (default: 1)') 47 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 48 | help='how many batches to wait before logging training status') 49 | parser.add_argument('--network', type=str, default='LeNet', metavar='N', 50 | help='which kind of network we are going to use, support LeNet and ResNet currently') 51 | parser.add_argument('--mode', type=str, default='normal', metavar='N', 52 | help='determine if we kill the stragglers or just implement normal training') 53 | parser.add_argument('--kill-threshold', type=float, default=7.0, metavar='KT', 54 | help='timeout threshold which triggers the killing process (default: 7s)') 55 | parser.add_argument('--dataset', type=str, default='MNIST', metavar='N', 56 | help='which dataset used in training, MNIST and Cifar10 supported currently') 57 | parser.add_argument('--comm-type', type=str, default='Bcast', metavar='N', 58 | help='which kind of method we use during the mode fetching stage') 59 | parser.add_argument('--num-aggregate', type=int, default=5, metavar='N', 60 | help='how many number of gradients we wish to gather at each iteration') 61 | parser.add_argument('--eval-freq', type=int, default=50, metavar='N', 62 | help='it determines per how many step the model should be evaluated') 63 | parser.add_argument('--train-dir', type=str, default='output/models/', metavar='N', 64 | help='directory to save the temp model during the training process for evaluation') 65 | parser.add_argument('--compress-grad', type=str, default='compress', metavar='N', 66 | help='compress/none indicate if we compress the gradient matrix before communication') 67 | parser.add_argument('--gather-type', type=str, default='gather', metavar='N', 68 | help='gather/non to specify the type of comm used (MPI.Gather or point-to-point comm)') 69 | parser.add_argument('--enable-gpu', type=bool, default=False, help='whether to use gradient approx method') 70 | # TODO(hwang), check what's this 71 | parser.add_argument("--local_rank", type=int) 72 | args = parser.parse_args() 73 | return args 74 | 75 | if __name__ == "__main__": 76 | rank = int(os.environ['RANK']) 77 | world_size = int(os.environ['WORLD_SIZE']) 78 | master_addr = os.environ['MASTER_ADDR'] 79 | master_port = os.environ['MASTER_PORT'] 80 | 81 | print(rank, world_size) 82 | dist.init_process_group(backend='gloo', world_size=world_size, rank=rank) 83 | 84 | args = add_fit_args(argparse.ArgumentParser(description='PyTorch MNIST Single Machine Test')) 85 | 86 | train_loader, test_loader = prepare_data(args) 87 | 88 | device = torch.device("cuda" if args.enable_gpu else "cpu") 89 | 90 | kwargs_master = { 91 | 
'world_size':world_size, 92 | 'batch_size':args.batch_size, 93 | 'learning_rate':args.lr, 94 | 'max_epochs':args.epochs, 95 | 'momentum':args.momentum, 96 | 'network':args.network, 97 | 'comm_method':args.comm_type, 98 | 'kill_threshold': args.num_aggregate, 99 | 'timeout_threshold':args.kill_threshold, 100 | 'eval_freq':args.eval_freq, 101 | 'train_dir':args.train_dir, 102 | 'max_steps':args.max_steps, 103 | 'compress_grad':args.compress_grad, 104 | 'gather_type':args.gather_type, 105 | 'device':device} 106 | 107 | kwargs_worker = { 108 | 'rank':rank, 109 | 'batch_size':args.batch_size, 110 | 'learning_rate':args.lr, 111 | 'max_epochs':args.epochs, 112 | 'momentum':args.momentum, 113 | 'network':args.network, 114 | 'comm_method':args.comm_type, 115 | 'kill_threshold':args.kill_threshold, 116 | 'eval_freq':args.eval_freq, 117 | 'train_dir':args.train_dir, 118 | 'max_steps':args.max_steps, 119 | 'compress_grad':args.compress_grad, 120 | 'gather_type':args.gather_type, 121 | 'device':device} 122 | 123 | if rank == 0: 124 | master_fc_nn = SyncReplicasMaster_NN(**kwargs_master) 125 | if args.dataset == 'Cifar100': 126 | master_fc_nn.build_model(num_classes=100) 127 | else: 128 | master_fc_nn.build_model(num_classes=10) 129 | print("I am the master: the world size is {}, cur step: {}".format(master_fc_nn.world_size, master_fc_nn.cur_step)) 130 | master_fc_nn.start() 131 | print("Done sending messages to workers!") 132 | else: 133 | worker_fc_nn = DistributedWorker(**kwargs_worker) 134 | if args.dataset == 'Cifar100': 135 | worker_fc_nn.build_model(num_classes=100) 136 | else: 137 | worker_fc_nn.build_model(num_classes=10) 138 | print("I am worker: {} in all {} workers, next step: {}".format(worker_fc_nn.rank, worker_fc_nn.world_size-1, worker_fc_nn.next_step)) 139 | worker_fc_nn.train(train_loader=train_loader, test_loader=test_loader) 140 | print("Now the next step is: {}".format(worker_fc_nn.next_step)) -------------------------------------------------------------------------------- /src/distributed_worker.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | 4 | from nn_ops import NN_Trainer 5 | from util import * 6 | 7 | import torch 8 | from torch.autograd import Variable 9 | import torch.distributed as dist 10 | 11 | import time 12 | from datetime import datetime 13 | import copy 14 | import logging 15 | from sys import getsizeof 16 | 17 | STEP_START_ = 1 18 | TAG_LIST_ = [i*30 for i in range(50000)] 19 | 20 | logging.basicConfig() 21 | logger = logging.getLogger() 22 | logger.setLevel(logging.INFO) 23 | 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the precision@k for the specified values of k""" 27 | maxk = max(topk) 28 | batch_size = target.size(0) 29 | 30 | _, pred = output.topk(maxk, 1, True, True) 31 | pred = pred.t() 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | res = [] 34 | for k in topk: 35 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 36 | res.append(correct_k.mul_(100.0 / batch_size)) 37 | return res 38 | 39 | class ModelBuffer(object): 40 | def __init__(self, network): 41 | """ 42 | this class is used to save model weights received from parameter server 43 | current step for each layer of model will also be updated here to make sure 44 | the model is always up-to-date 45 | """ 46 | super(ModelBuffer, self).__init__() 47 | self.recv_buf = [] 48 | self.layer_cur_step = [] 49 | self.layer_shape = [] 50 | ''' 51 | initialize 
space to receive model from parameter server 52 | ''' 53 | # consider we don't want to update the param of `BatchNorm` layer right now 54 | # we temporirially deprecate the foregoing version and only update the model 55 | # parameters 56 | for param_idx, param in enumerate(network.parameters()): 57 | self.recv_buf.append(torch.zeros(param.size())) 58 | 59 | 60 | class DistributedWorker(NN_Trainer): 61 | def __init__(self, **kwargs): 62 | super(NN_Trainer, self).__init__() 63 | 64 | self.cur_step = 0 65 | self.next_step = 0 # we will fetch this one from parameter server 66 | 67 | self.rank = kwargs['rank'] 68 | self.batch_size = kwargs['batch_size'] 69 | self.max_epochs = kwargs['max_epochs'] 70 | self.momentum = kwargs['momentum'] 71 | self.lr = kwargs['learning_rate'] 72 | self._max_steps = kwargs['max_steps'] 73 | self.network_config = kwargs['network'] 74 | self.comm_type = kwargs['comm_method'] 75 | self.kill_threshold = kwargs['kill_threshold'] 76 | self._eval_batch_size = 100 77 | self._eval_freq = kwargs['eval_freq'] 78 | self._train_dir = kwargs['train_dir'] 79 | self._compress_grad = kwargs['compress_grad'] 80 | self._gather_type = kwargs['gather_type'] 81 | self._device = kwargs['device'] 82 | 83 | # this one is going to be used to avoid fetch the weights for multiple times 84 | self._layer_cur_step = [] 85 | 86 | def build_model(self, num_classes=10): 87 | self.network = build_model(self.network_config, num_classes) 88 | # set up optimizer 89 | self.optimizer = torch.optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum) 90 | self.criterion = nn.CrossEntropyLoss() 91 | # assign a buffer for receiving models from parameter server 92 | self.init_recv_buf() 93 | 94 | self.network.to(self._device) 95 | 96 | def train(self, train_loader, test_loader): 97 | # the first step we need to do here is to sync fetch the inital worl_step from the parameter server 98 | # we still need to make sure the value we fetched from parameter server is 1 99 | 100 | # number of batches in one epoch 101 | iteration_last_step=0 102 | iter_start_time=0 103 | first = True 104 | 105 | logger.info("Worker {}: starting training".format(self.rank)) 106 | # start the training process 107 | for num_epoch in range(self.max_epochs): 108 | for batch_idx, (train_image_batch, train_label_batch) in enumerate(train_loader): 109 | 110 | iter_start_time = time.time() 111 | # worker exit task 112 | if self.cur_step == self._max_steps: 113 | break 114 | X_batch, y_batch = train_image_batch.to(self._device), train_label_batch.to(self._device) 115 | 116 | # bcast communication stage 117 | fetch_weight_start = time.time() 118 | self._fetch_weight() 119 | fetch_weight_dur = time.time() - fetch_weight_start 120 | 121 | comp_start = time.time() 122 | self._train_init() 123 | loss, logits = self._forward(X_batch, y_batch) 124 | loss.backward() 125 | comp_dur = time.time() - comp_start 126 | 127 | prec1, prec5 = accuracy(logits.detach(), train_label_batch.long(), topk=(1, 5)) 128 | gather_start = time.time() 129 | self._send_grads() 130 | gather_dur = time.time() - gather_start 131 | 132 | 133 | log_format = 'Worker: {}, Step: {}, Epoch: {} [{}/{} ({:.0f}%)], Loss: {:.4f}, Time Cost: {:.4f}, FetchWeight: {:.4f}, Computation: {:.4f}, GatherTime: {:.4f}, Acc: {:.4f}' 134 | logger.info(log_format.format(self.rank, 135 | self.cur_step, num_epoch, batch_idx * self.batch_size, len(train_loader.dataset), 136 | (100. 
* (batch_idx * self.batch_size) / len(train_loader.dataset)), loss.item(), 137 | time.time()-iter_start_time, fetch_weight_dur, comp_dur, gather_dur, prec1.numpy()[0])) 138 | 139 | if self.cur_step%self._eval_freq == 0: 140 | self._save_model(file_path=self._generate_model_path()) 141 | 142 | def init_recv_buf(self): 143 | self.model_recv_buf = ModelBuffer(self.network) 144 | 145 | def _train_init(self): 146 | self.network.train() 147 | self.optimizer.zero_grad() 148 | 149 | def _forward(self, X_batch, y_batch): 150 | logits = self.network(X_batch) 151 | return self.criterion(logits, y_batch), logits 152 | 153 | def _fetch_weight(self): 154 | for layer_idx, layer in enumerate(self.model_recv_buf.recv_buf): 155 | dist.broadcast(self.model_recv_buf.recv_buf[layer_idx], src=0) 156 | self.model_update(self.model_recv_buf.recv_buf) 157 | # Note that at here we update the global step 158 | self.cur_step += 1 159 | 160 | def update_step(self): 161 | '''update local (global) step on worker''' 162 | changed = (self.cur_step != self.next_step) 163 | self.cur_step = self.next_step 164 | return changed 165 | 166 | def model_update(self, weights_to_update): 167 | """write model fetched from parameter server to local model""" 168 | new_state_dict = {} 169 | model_counter_ = 0 170 | for param_idx,(key_name, param) in enumerate(self.network.state_dict().items()): 171 | # handle the case that `running_mean` and `running_var` contained in `BatchNorm` layer 172 | if "running_mean" in key_name or "running_var" in key_name or "num_batches_tracked" in key_name: 173 | tmp_dict={key_name: param} 174 | else: 175 | assert param.size() == weights_to_update[model_counter_].size() 176 | tmp_dict = {key_name: weights_to_update[model_counter_].to(self._device)} 177 | model_counter_ += 1 178 | new_state_dict.update(tmp_dict) 179 | self.network.load_state_dict(new_state_dict) 180 | 181 | def _send_grads(self): 182 | for p_index, p in enumerate(self.network.parameters()): 183 | # fetch the grad we need 184 | if self._device.type == "cuda": 185 | grad = p.grad.to(torch.device("cpu")).detach() 186 | else: 187 | grad = p.grad.detach() 188 | 189 | dist.gather(grad, [], dst=0) 190 | 191 | def _evaluate_model(self, test_loader): 192 | self.network.eval() 193 | test_loss = 0 194 | correct = 0 195 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0 196 | for data, y_batch in test_loader: 197 | data, target = data.to(self._device), y_batch.to(self._device) 198 | 199 | output = self.network(data) 200 | test_loss += F.nll_loss(F.log_softmax(output), target, size_average=False).item() # sum up batch loss 201 | 202 | prec1_tmp, prec5_tmp = accuracy(output.detach(), y_batch, topk=(1, 5)) 203 | 204 | if self._device.type == 'cuda': 205 | prec1_counter_ += prec1_tmp.to(torch.device("cpu")).numpy()[0] 206 | prec5_counter_ += prec5_tmp.to(torch.device("cpu")).numpy()[0] 207 | else: 208 | prec1_counter_ += prec1_tmp.numpy()[0] 209 | prec5_counter_ += prec5_tmp.numpy()[0] 210 | 211 | batch_counter_ += 1 212 | prec1 = prec1_counter_ / batch_counter_ 213 | prec5 = prec5_counter_ / batch_counter_ 214 | test_loss /= len(test_loader.dataset) 215 | print('Test set: Step: {}, Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(self.cur_step, 216 | test_loss, prec1, prec5)) 217 | 218 | def _generate_model_path(self): 219 | return self._train_dir+"model_step_"+str(self.cur_step) 220 | 221 | def _save_model(self, file_path): 222 | with open(file_path, "wb") as f_: 223 | torch.save(self.network.state_dict(), f_) 224 | return 225 | 226 | if __name__ 
== "__main__": 227 | # this is only a simple test case 228 | comm = MPI.COMM_WORLD 229 | rank = comm.Get_rank() 230 | world_size = comm.Get_size() 231 | worker_fc_nn = WorkerFC_NN(comm=comm, world_size=world_size, rank=rank) 232 | print("I am worker: {} in all {} workers".format(worker_fc_nn.rank, worker_fc_nn.world_size)) -------------------------------------------------------------------------------- /src/evaluate_pytorch.sh: -------------------------------------------------------------------------------- 1 | python distributed_evaluator.py \ 2 | --eval-batch-size=10000 \ 3 | --eval-freq=50 \ 4 | --network=ResNet18 \ 5 | --dataset=Cifar10 \ 6 | --model-dir=/home/ubuntu/MPI_shared/ -------------------------------------------------------------------------------- /src/launch.sh: -------------------------------------------------------------------------------- 1 | KEY_PEM_NAME=HongyiScript.pem 2 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts` 3 | MASTER_PUB_IP="$1" 4 | WORKING_DIR=${HOME}/ps_real_pytorch/src 5 | 6 | for i in $(seq 1 $DEEPLEARNING_WORKERS_COUNT); 7 | do 8 | ssh -i ${HOME}/.ssh/${KEY_PEM_NAME} deeplearning-worker${i} "cd ${WORKING_DIR}; nohup bash ${WORKING_DIR}/run_pytorch_dist.sh \"$((${i}-1))\" \"${DEEPLEARNING_WORKERS_COUNT}\" \"${MASTER_PUB_IP}\" &>/dev/null &" 9 | done -------------------------------------------------------------------------------- /src/model_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from . import lenet, resnet, resnet_split, vgg 2 | 3 | __all__ = ['lenet', 'resnet', 'resnet_split', 'vgg'] -------------------------------------------------------------------------------- /src/model_ops/lenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | import pandas as pd 6 | import numpy as np 7 | from torch.autograd import Variable 8 | 9 | from mpi4py import MPI 10 | 11 | import sys 12 | sys.path.insert(0, '../compression') 13 | from compression import g_compress 14 | 15 | # we use LeNet here for our simple case 16 | class LeNet(nn.Module): 17 | def __init__(self): 18 | super(LeNet, self).__init__() 19 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 20 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 21 | self.fc1 = nn.Linear(4*4*50, 500) 22 | self.fc2 = nn.Linear(500, 10) 23 | self.ceriation = nn.CrossEntropyLoss() 24 | def forward(self, x): 25 | x = self.conv1(x) 26 | x = F.max_pool2d(x, 2, 2) 27 | x = F.relu(x) 28 | x = self.conv2(x) 29 | x = F.max_pool2d(x, 2, 2) 30 | x = F.relu(x) 31 | x = x.view(-1, 4*4*50) 32 | x = self.fc1(x) 33 | x = self.fc2(x) 34 | #loss = self.ceriation(x, target) 35 | return x 36 | def name(self): 37 | return 'lenet' 38 | 39 | class LeNetSplit(nn.Module): 40 | ''' 41 | this is a module that we split the module and do backward process layer by layer 42 | please don't call this module for normal uses, this is a hack and run slower than 43 | the automatic chain rule version 44 | ''' 45 | def __init__(self): 46 | super(LeNetSplit, self).__init__() 47 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 48 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 49 | self.fc1 = nn.Linear(4*4*50, 500) 50 | self.fc2 = nn.Linear(500, 10) 51 | 52 | self.maxpool2d = nn.MaxPool2d(2, stride=2) 53 | self.relu = nn.ReLU() 54 | 55 | self.full_modules = [self.conv1, self.conv2, self.fc1, self.fc2] 56 | self._init_channel_index = len(self.full_modules)*2 57 | 58 | self.criterion = nn.CrossEntropyLoss() 59 | 60 | def 
forward(self, x): 61 | self.output = [] 62 | self.input = [] 63 | x = Variable(x.data, requires_grad=True) 64 | self.input.append(x) 65 | x = self.conv1(x) 66 | self.output.append(x) 67 | 68 | x = Variable(x.data, requires_grad=True) 69 | self.input.append(x) 70 | x = self.maxpool2d(x) 71 | self.output.append(x) 72 | 73 | x = Variable(x.data, requires_grad=True) 74 | self.input.append(x) 75 | x = self.relu(x) 76 | self.output.append(x) 77 | 78 | x = Variable(x.data, requires_grad=True) 79 | self.input.append(x) 80 | x = self.conv2(x) 81 | self.output.append(x) 82 | 83 | x = Variable(x.data, requires_grad=True) 84 | self.input.append(x) 85 | x = self.maxpool2d(x) 86 | self.output.append(x) 87 | 88 | x = Variable(x.data, requires_grad=True) 89 | self.input.append(x) 90 | x = self.relu(x) 91 | self.output.append(x) 92 | 93 | x = x.view(-1, 4*4*50) 94 | 95 | x = Variable(x.data, requires_grad=True) 96 | self.input.append(x) 97 | x = self.fc1(x) 98 | self.output.append(x) 99 | 100 | x = Variable(x.data, requires_grad=True) 101 | self.input.append(x) 102 | x = self.fc2(x) 103 | self.output.append(x) 104 | return x 105 | 106 | @property 107 | def fetch_init_channel_index(self): 108 | return self._init_channel_index 109 | 110 | def backward_normal(self, g, communicator, req_send_check, cur_step, compress_grad): 111 | mod_avail_index = len(self.full_modules)-1 112 | #channel_index = len(self.full_modules)*2-2 113 | channel_index = self._init_channel_index - 2 114 | mod_counters_ = [0]*len(self.full_modules) 115 | for i, output in reversed(list(enumerate(self.output))): 116 | req_send_check[-1].wait() 117 | if i == (len(self.output) - 1): 118 | # for last node, use g 119 | output.backward(g) 120 | # get gradient here after some sanity checks: 121 | tmp_grad = self.full_modules[mod_avail_index].weight.grad 122 | if not pd.isnull(tmp_grad): 123 | grads = tmp_grad.data.numpy().astype(np.float64) 124 | ############################################################################################### 125 | if compress_grad == 'compress': 126 | _compressed_grad = g_compress(grads) 127 | req_isend = communicator.isend(_compressed_grad, dest=0, tag=88+channel_index) 128 | else: 129 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 130 | ################################################################################################ 131 | req_send_check.append(req_isend) 132 | # update counters 133 | mod_avail_index-=1 134 | channel_index-=1 135 | else: 136 | continue 137 | else: 138 | output.backward(self.input[i+1].grad.data) 139 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 140 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 141 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias): 142 | # we always send bias first 143 | if mod_counters_[mod_avail_index] == 0: 144 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 145 | ############################################################################################### 146 | if compress_grad == 'compress': 147 | _compressed_grad = g_compress(grads) 148 | req_isend = communicator.isend(_compressed_grad, dest=0, tag=88+channel_index) 149 | else: 150 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 151 | ################################################################################################ 152 | req_send_check.append(req_isend) 153 | channel_index-=1 154 | mod_counters_[mod_avail_index]+=1 155 | elif mod_counters_[mod_avail_index] 
== 1: 156 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 157 | ############################################################################################### 158 | if compress_grad == 'compress': 159 | _compressed_grad = g_compress(grads) 160 | req_isend = communicator.isend(_compressed_grad, dest=0, tag=88+channel_index) 161 | else: 162 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 163 | ################################################################################################ 164 | req_send_check.append(req_isend) 165 | channel_index-=1 166 | mod_counters_[mod_avail_index]+=1 167 | # update counters 168 | mod_avail_index-=1 169 | else: 170 | continue 171 | if mod_counters_[0] == 1: 172 | req_send_check[-1].wait() 173 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 174 | ############################################################################################### 175 | if compress_grad == 'compress': 176 | _compressed_grad = g_compress(grads) 177 | req_isend = communicator.isend(_compressed_grad, dest=0, tag=88+channel_index) 178 | else: 179 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 180 | ################################################################################################ 181 | req_send_check.append(req_isend) 182 | # for debugging here: 183 | for req in req_send_check: 184 | req.wait() 185 | return req_send_check 186 | 187 | def backward_signal_kill(self, g, communicator, req_send_check, cur_step): 188 | ''' 189 | This killer is triggered by signals bcasting from master, channel of 190 | signal is kept checking by each worker to determine if they're the 191 | straggler 192 | ''' 193 | mod_avail_index = len(self.full_modules)-1 194 | channel_index = self._init_channel_index - 2 195 | mod_counters_ = [0]*len(self.full_modules) 196 | 197 | # should kill flag 198 | should_kill = False 199 | 200 | for i, output in reversed(list(enumerate(self.output))): 201 | ############################ killing process on workers ##################################### 202 | for _ in range(10000): 203 | status = MPI.Status() 204 | communicator.Iprobe(0, 77, status) 205 | if status.Get_source() == 0: 206 | print("Worker {}, Cur Step: {} I'm the straggler, killing myself!".format(communicator.Get_rank(), cur_step)) 207 | tmp = communicator.recv(source=0, tag=77) 208 | should_kill = True 209 | break 210 | if should_kill: 211 | break 212 | ############################################################################################ 213 | 214 | if i == (len(self.output) - 1): 215 | # for last node, use g 216 | output.backward(g) 217 | # get gradient here after some sanity checks: 218 | tmp_grad = self.full_modules[mod_avail_index].weight.grad 219 | if not pd.isnull(tmp_grad): 220 | grads = tmp_grad.data.numpy().astype(np.float64) 221 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 222 | req_send_check.append(req_isend) 223 | # update counters 224 | mod_avail_index-=1 225 | channel_index-=1 226 | else: 227 | continue 228 | else: 229 | output.backward(self.input[i+1].grad.data) 230 | tmp_grad_weight = self.full_modules[mod_avail_index].weight.grad 231 | tmp_grad_bias = self.full_modules[mod_avail_index].bias.grad 232 | if not pd.isnull(tmp_grad_weight) and not pd.isnull(tmp_grad_bias): 233 | # we always send bias first 234 | if mod_counters_[mod_avail_index] == 0: 235 | grads = tmp_grad_bias.data.numpy().astype(np.float64) 236 | req_isend = 
communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 237 | req_send_check.append(req_isend) 238 | channel_index-=1 239 | mod_counters_[mod_avail_index]+=1 240 | elif mod_counters_[mod_avail_index] == 1: 241 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 242 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 243 | req_send_check.append(req_isend) 244 | channel_index-=1 245 | mod_counters_[mod_avail_index]+=1 246 | # update counters 247 | mod_avail_index-=1 248 | else: 249 | continue 250 | if mod_counters_[0] == 1: 251 | grads = tmp_grad_weight.data.numpy().astype(np.float64) 252 | req_isend = communicator.Isend([grads, MPI.DOUBLE], dest=0, tag=88+channel_index) 253 | req_send_check.append(req_isend) 254 | return req_send_check 255 | 256 | def backward_timeout_kill(self, g, communicator, req_send_check): 257 | """do we even need this?""" 258 | pass -------------------------------------------------------------------------------- /src/model_ops/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from torch.autograd import Variable 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(ResNet, self).__init__() 70 | 
self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.linear = nn.Linear(512*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | return out 98 | 99 | 100 | def ResNet18(num_classes): 101 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes) 102 | 103 | def ResNet34(num_classes): 104 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes) 105 | 106 | def ResNet50(): 107 | return ResNet(Bottleneck, [3,4,6,3]) 108 | 109 | def ResNet101(): 110 | return ResNet(Bottleneck, [3,4,23,3]) 111 | 112 | def ResNet152(): 113 | return ResNet(Bottleneck, [3,8,36,3]) -------------------------------------------------------------------------------- /src/model_ops/vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision.git 3 | ''' 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | class VGG(nn.Module): 16 | ''' 17 | VGG model 18 | ''' 19 | def __init__(self, features, num_classes=10): 20 | super(VGG, self).__init__() 21 | self.features = features 22 | self.classifier = nn.Sequential( 23 | nn.Dropout(), 24 | nn.Linear(512, 512), 25 | nn.ReLU(True), 26 | nn.Dropout(), 27 | nn.Linear(512, 512), 28 | nn.ReLU(True), 29 | nn.Linear(512, num_classes), 30 | ) 31 | # Initialize weights 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 35 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 36 | m.bias.data.zero_() 37 | 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | x = x.view(x.size(0), -1) 42 | x = self.classifier(x) 43 | return x 44 | 45 | 46 | def make_layers(cfg, batch_norm=False): 47 | layers = [] 48 | in_channels = 3 49 | for v in cfg: 50 | if v == 'M': 51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 52 | else: 53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 54 | if batch_norm: 55 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 56 | else: 57 | layers += [conv2d, nn.ReLU(inplace=True)] 58 | in_channels = v 59 | return nn.Sequential(*layers) 60 | 61 | 62 | cfg = { 63 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 64 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 65 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 66 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 67 | 512, 512, 512, 512, 'M'], 68 | } 69 | 70 | 71 | def vgg11(): 72 | """VGG 11-layer model (configuration "A")""" 73 | return VGG(make_layers(cfg['A'])) 74 | 75 | 76 | def vgg11_bn(num_classes=10): 77 | """VGG 11-layer model (configuration "A") with batch normalization""" 78 | return VGG(make_layers(cfg['A'], batch_norm=True), num_classes=num_classes) 79 | 80 | 81 | def vgg13(): 82 | """VGG 13-layer model (configuration "B")""" 83 | return VGG(make_layers(cfg['B'])) 84 | 85 | 86 | def vgg13_bn(): 87 | """VGG 13-layer model (configuration "B") with batch normalization""" 88 | return VGG(make_layers(cfg['B'], batch_norm=True)) 89 | 90 | 91 | def vgg16(): 92 | """VGG 16-layer model (configuration "D")""" 93 | return VGG(make_layers(cfg['D'])) 94 | 95 | 96 | def vgg16_bn(): 97 | """VGG 16-layer model (configuration "D") with batch normalization""" 98 | return VGG(make_layers(cfg['D'], batch_norm=True)) 99 | 100 | 101 | def vgg19(): 102 | """VGG 19-layer model (configuration "E")""" 103 | return VGG(make_layers(cfg['E'])) 104 | 105 | 106 | def vgg19_bn(): 107 | """VGG 19-layer model (configuration 'E') with batch normalization""" 108 | return VGG(make_layers(cfg['E'], batch_norm=True)) -------------------------------------------------------------------------------- /src/nn_ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | from model_ops.lenet import LeNet 10 | from model_ops.resnet import * 11 | 12 | '''this is a trial example, we use MNIST on LeNet for simple test here''' 13 | def accuracy(output, target, topk=(1,)): 14 | """Computes the precision@k for the specified values of k""" 15 | maxk = max(topk) 16 | batch_size = target.size(0) 17 | 18 | _, pred = output.topk(maxk, 1, True, True) 19 | pred = pred.t() 20 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 21 | 22 | res = [] 23 | for k in topk: 24 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 25 | res.append(correct_k.mul_(100.0 / batch_size)) 26 | return res 27 | 28 | class NN_Trainer(object): 29 | def __init__(self, **kwargs): 30 | self.batch_size = kwargs['batch_size'] 31 | self.lr = kwargs['learning_rate'] 32 | self.max_epochs = kwargs['max_epochs'] 33 | self.momentum = kwargs['momentum'] 34 | self.network_config = kwargs['network'] 35 | 36 | def build_model(self): 37 | # build network 38 | if self.network_config == "LeNet": 39 | 
self.network=LeNet() 40 | elif self.network_config == "ResNet": 41 | #self.network=ResNet18() 42 | self.network=ResNetSplit18(1) 43 | # set up optimizer 44 | self.optimizer = torch.optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum) 45 | self.criterion = torch.nn.CrossEntropyLoss() 46 | 47 | def train_and_validate(self, train_loader, test_loader): 48 | # iterate of epochs 49 | for i in range(self.max_epochs): 50 | # change back to training mode 51 | self.network.train() 52 | for batch_idx, (data, y_batch) in enumerate(train_loader): 53 | iter_start_time = time.time() 54 | data, target = Variable(data), Variable(y_batch) 55 | self.optimizer.zero_grad() 56 | ################# backward on normal model ############################ 57 | 58 | logits = self.network(data) 59 | loss = self.criterion(logits, target) 60 | loss.backward() 61 | ####################################################################### 62 | 63 | ################ backward on splitted model ########################### 64 | #logits = self.network(data) 65 | #logits_1 = Variable(logits.data, requires_grad=True) 66 | #loss = self.criterion(logits_1, target) 67 | #loss.backward() 68 | #self.network.backward_single(logits_1.grad) 69 | ####################################################################### 70 | tmp_time_0 = time.time() 71 | 72 | duration_backward = time.time()-tmp_time_0 73 | 74 | tmp_time_1 = time.time() 75 | self.optimizer.step() 76 | duration_update = time.time()-tmp_time_1 77 | 78 | # calculate training accuracy 79 | prec1, prec5 = accuracy(logits.data, y_batch, topk=(1, 5)) 80 | # load the training info 81 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Prec@1: {} Prec@5: {} Time Cost: {}'.format( 82 | i, batch_idx * len(data), len(train_loader.dataset), 83 | 100. * batch_idx / len(train_loader), loss.data[0], 84 | prec1.numpy()[0], 85 | prec5.numpy()[0], time.time()-iter_start_time)) 86 | # we evaluate the model performance on end of each epoch 87 | self.validate(test_loader) 88 | 89 | def validate(self, test_loader): 90 | self.network.eval() 91 | test_loss = 0 92 | correct = 0 93 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0 94 | for data, y_batch in test_loader: 95 | data, target = Variable(data, volatile=True), Variable(y_batch) 96 | output = self.network(data) 97 | test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss 98 | prec1_tmp, prec5_tmp = accuracy(output.data, y_batch, topk=(1, 5)) 99 | prec1_counter_ += prec1_tmp.numpy()[0] 100 | prec5_counter_ += prec5_tmp.numpy()[0] 101 | batch_counter_ += 1 102 | prec1 = prec1_counter_ / batch_counter_ 103 | prec5 = prec5_counter_ / batch_counter_ 104 | test_loss /= len(test_loader.dataset) 105 | print('Test set: Average loss: {:.4f}, Prec@1: {} Prec@5: {}'.format(test_loss, prec1, prec5)) 106 | -------------------------------------------------------------------------------- /src/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from . import adam, sgd 2 | 3 | __all__ = ['adam', 'sgd'] -------------------------------------------------------------------------------- /src/optim/adam.py: -------------------------------------------------------------------------------- 1 | ''' 2 | modified version of Adam optimizer 3 | by Hongyi Wang 4 | ''' 5 | import sys 6 | 7 | import math 8 | import torch 9 | from torch.optim import Optimizer 10 | 11 | 12 | class Adam(Optimizer): 13 | """Implements Adam algorithm. 
14 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | .. _Adam\: A Method for Stochastic Optimization: 27 | https://arxiv.org/abs/1412.6980 28 | .. _On the Convergence of Adam and Beyond: 29 | https://openreview.net/forum?id=ryQu7f-RZ 30 | """ 31 | 32 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 33 | weight_decay=0, amsgrad=False): 34 | defaults = dict(lr=lr, betas=betas, eps=eps, 35 | weight_decay=weight_decay, amsgrad=amsgrad) 36 | super(Adam, self).__init__(params, defaults) 37 | 38 | def step(self, grads, closure=None): 39 | """Performs a single optimization step. 40 | Arguments: 41 | closure (callable, optional): A closure that reevaluates the model 42 | and returns the loss. 43 | """ 44 | loss = None 45 | if closure is not None: 46 | loss = closure() 47 | 48 | for group in self.param_groups: 49 | for i,p in enumerate(group['params']): 50 | grad = torch.from_numpy(grads[i]).float() 51 | if grad.is_sparse: 52 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 53 | amsgrad = group['amsgrad'] 54 | 55 | state = self.state[p] 56 | 57 | # State initialization 58 | if len(state) == 0: 59 | state['step'] = 0 60 | # Exponential moving average of gradient values 61 | state['exp_avg'] = torch.zeros_like(p.data) 62 | # Exponential moving average of squared gradient values 63 | state['exp_avg_sq'] = torch.zeros_like(p.data) 64 | if amsgrad: 65 | # Maintains max of all exp. moving avg. of sq. grad. values 66 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 67 | 68 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 69 | if amsgrad: 70 | max_exp_avg_sq = state['max_exp_avg_sq'] 71 | beta1, beta2 = group['betas'] 72 | 73 | state['step'] += 1 74 | 75 | if group['weight_decay'] != 0: 76 | grad = grad.add(group['weight_decay'], p.data) 77 | 78 | # Decay the first and second moment running average coefficient 79 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 80 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 81 | if amsgrad: 82 | # Maintains the maximum of all 2nd moment running avg. till now 83 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 84 | # Use the max. for normalizing running avg. 
of gradient 85 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 86 | else: 87 | denom = exp_avg_sq.sqrt().add_(group['eps']) 88 | 89 | bias_correction1 = 1 - beta1 ** state['step'] 90 | bias_correction2 = 1 - beta2 ** state['step'] 91 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 92 | 93 | p.data.addcdiv_(-step_size, exp_avg, denom) 94 | 95 | return loss -------------------------------------------------------------------------------- /src/optim/sgd.py: -------------------------------------------------------------------------------- 1 | ''' 2 | modified version of SGD optimizer 3 | by Hongyi Wang 4 | ''' 5 | import sys 6 | 7 | import torch 8 | from torch.optim import Optimizer 9 | 10 | 11 | class SGD(Optimizer): 12 | r"""Implements stochastic gradient descent (optionally with momentum). 13 | Nesterov momentum is based on the formula from 14 | `On the importance of initialization and momentum in deep learning`__. 15 | Args: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float): learning rate 19 | momentum (float, optional): momentum factor (default: 0) 20 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 21 | dampening (float, optional): dampening for momentum (default: 0) 22 | nesterov (bool, optional): enables Nesterov momentum (default: False) 23 | Example: 24 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 25 | >>> optimizer.zero_grad() 26 | >>> loss_fn(model(input), target).backward() 27 | >>> optimizer.step() 28 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 29 | .. note:: 30 | The implementation of SGD with Momentum/Nesterov subtly differs from 31 | Sutskever et. al. and implementations in some other frameworks. 32 | Considering the specific case of Momentum, the update can be written as 33 | .. math:: 34 | v = \rho * v + g \\ 35 | p = p - lr * v 36 | where p, g, v and :math:`\rho` denote the parameters, gradient, 37 | velocity, and momentum respectively. 38 | This is in contrast to Sutskever et. al. and 39 | other frameworks which employ an update of the form 40 | .. math:: 41 | v = \rho * v + lr * g \\ 42 | p = p - v 43 | The Nesterov version is analogously modified. 44 | """ 45 | 46 | def __init__(self, params, lr=0.1, momentum=0, dampening=0, 47 | weight_decay=0, nesterov=False): 48 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 49 | weight_decay=weight_decay, nesterov=nesterov) 50 | if nesterov and (momentum <= 0 or dampening != 0): 51 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 52 | super(SGD, self).__init__(params, defaults) 53 | 54 | def __setstate__(self, state): 55 | super(SGD, self).__setstate__(state) 56 | for group in self.param_groups: 57 | group.setdefault('nesterov', False) 58 | 59 | def step(self, grads, closure=None): 60 | """Performs a single optimization step. 61 | Arguments: 62 | closure (callable, optional): A closure that reevaluates the model 63 | and returns the loss. 
64 | """ 65 | loss = None 66 | if closure is not None: 67 | loss = closure() 68 | 69 | for group in self.param_groups: 70 | weight_decay = group['weight_decay'] 71 | momentum = group['momentum'] 72 | dampening = group['dampening'] 73 | nesterov = group['nesterov'] 74 | 75 | for i,p in enumerate(group['params']): 76 | d_p = grads[i] 77 | if weight_decay != 0: 78 | d_p.add_(weight_decay, p.data) 79 | if momentum != 0: 80 | param_state = self.state[p] 81 | if 'momentum_buffer' not in param_state: 82 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 83 | buf.mul_(momentum).add_(d_p) 84 | else: 85 | buf = param_state['momentum_buffer'] 86 | buf.mul_(momentum).add_(1 - dampening, d_p) 87 | if nesterov: 88 | d_p = d_p.add(momentum, buf) 89 | else: 90 | d_p = buf 91 | p.data.add_(-group['lr'], d_p) 92 | return loss -------------------------------------------------------------------------------- /src/run_pytorch_dist.sh: -------------------------------------------------------------------------------- 1 | NODE_RANK="$1" 2 | NNODE="$2" 3 | MASTER_IP="$3" 4 | SRC_DIR=${HOME}/ps_real_pytorch/src 5 | 6 | echo ${MASTER_IP} 7 | sudo /home/ubuntu/anaconda3/bin/python -m torch.distributed.launch \ 8 | --nproc_per_node=1 \ 9 | --nnodes=${NNODE} --node_rank=${NODE_RANK} --master_addr="${MASTER_IP}" --master_port=1234 \ 10 | ${SRC_DIR}/distributed_nn.py \ 11 | --lr=0.1 \ 12 | --momentum=0.9 \ 13 | --max-steps=100000 \ 14 | --epochs=100 \ 15 | --network=ResNet18 \ 16 | --dataset=Cifar10 \ 17 | --batch-size=64 \ 18 | --comm-type=Bcast \ 19 | --num-aggregate=2 \ 20 | --mode=normal \ 21 | --eval-freq=2000 \ 22 | --gather-type=gather \ 23 | --compress-grad=compress \ 24 | --enable-gpu= \ 25 | --train-dir=/home/ubuntu > out_node_${NODE_RANK} 2>&1 -------------------------------------------------------------------------------- /src/run_pytorch_single.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=3 \ 3 | distributed_nn.py \ 4 | --lr=0.01 \ 5 | --momentum=0.9 \ 6 | --max-steps=100000 \ 7 | --epochs=100 \ 8 | --network=LeNet \ 9 | --dataset=MNIST \ 10 | --batch-size=128 \ 11 | --comm-type=Bcast \ 12 | --num-aggregate=5 \ 13 | --mode=normal \ 14 | --eval-freq=20 \ 15 | --gather-type=gather \ 16 | --compress-grad=compress \ 17 | --enable-gpu= \ 18 | --train-dir=/home/ubuntu -------------------------------------------------------------------------------- /src/sync_replicas_master_nn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | import copy 4 | from sys import getsizeof 5 | import logging 6 | from functools import reduce 7 | 8 | import numpy as np 9 | 10 | from nn_ops import NN_Trainer 11 | from util import * 12 | from optim.adam import Adam 13 | from optim.sgd import SGD 14 | 15 | import torch 16 | import torch.distributed as dist 17 | 18 | STEP_START_ = 1 19 | 20 | logging.basicConfig() 21 | logger = logging.getLogger() 22 | logger.setLevel(logging.INFO) 23 | 24 | def update_params_dist_version(param, avg_grad, learning_rate): 25 | ''' 26 | update the network layer by layer 27 | ''' 28 | assert param.shape == avg_grad.shape 29 | param -= learning_rate * avg_grad 30 | return param 31 | 32 | 33 | def accuracy(output, target, topk=(1,)): 34 | """Computes the precision@k for the specified values of k""" 35 | maxk = max(topk) 36 | batch_size = target.size(0) 37 | 38 | _, pred = output.topk(maxk, 1, True, True) 39 
| pred = pred.t() 40 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 41 | 42 | res = [] 43 | for k in topk: 44 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 45 | res.append(correct_k.mul_(100.0 / batch_size)) 46 | return res 47 | 48 | 49 | class GradientAccumulator(object): 50 | '''a simple class to implement gradient aggregator like the `Conditional Accumulators` in tensorflow''' 51 | def __init__(self, module, world_size, mode='None'): 52 | super(GradientAccumulator, self).__init__() 53 | # we will update this counter dynamically during the training process 54 | # the length of this counter should be number of fc layers in the network 55 | # we used list to contain gradients of layers 56 | self.gradient_aggregate_counter = [] 57 | self.model_index_range = [] 58 | self.gradient_aggregator = [] 59 | self._mode = mode 60 | 61 | for param_idx, param in enumerate(module.parameters()): 62 | tmp_aggregator = [] 63 | for worker_idx in range(world_size): 64 | tmp_aggregator.append(torch.zeros(param.size())) 65 | # initialize the gradient aggragator 66 | self.gradient_aggregator.append(tmp_aggregator) 67 | self.gradient_aggregate_counter.append(0) 68 | self.model_index_range.append(param_idx) 69 | 70 | def meset_everything(self): 71 | self._meset_grad_counter() 72 | self._meset_grad_aggregator() 73 | 74 | def _meset_grad_counter(self): 75 | self.gradient_aggregate_counter = [0 for _ in self.gradient_aggregate_counter] 76 | 77 | def _meset_grad_aggregator(self): 78 | ''' 79 | reset the buffers in grad accumulator, not sure if this is necessary 80 | ''' 81 | if self._mode == 'compress': 82 | pass 83 | else: 84 | for i, tmp_aggregator in enumerate(self.gradient_aggregator): 85 | for j, buf in enumerate(tmp_aggregator): 86 | self.gradient_aggregator[i][j] = np.zeros(self.gradient_aggregator[i][j].shape) 87 | 88 | 89 | class SyncReplicasMaster_NN(NN_Trainer): 90 | def __init__(self, **kwargs): 91 | super(NN_Trainer, self).__init__() 92 | '''master node here, no rank needed since the rank will always be 0 for master node''' 93 | self.world_size = kwargs['world_size'] 94 | self.cur_step = STEP_START_ 95 | self.lr = kwargs['learning_rate'] 96 | self.momentum = kwargs['momentum'] 97 | self.network_config = kwargs['network'] 98 | self.comm_type = kwargs['comm_method'] 99 | self._timeout_threshold = kwargs['timeout_threshold'] 100 | 101 | self._num_workers = self.world_size - 1 102 | # used to aggregate tmp gradients, the length is the same as # of fc layer 103 | self._grad_aggregate_buffer = [] 104 | self._model_shapes = [] 105 | self._first_grad_received = False 106 | self._eval_freq = kwargs['eval_freq'] 107 | self._train_dir = kwargs['train_dir'] 108 | self._expected_grad_to_recv = kwargs['kill_threshold'] 109 | self._max_steps = kwargs['max_steps'] 110 | self._compress_grad = kwargs['compress_grad'] 111 | self._gather_type = kwargs['gather_type'] 112 | self._device = kwargs['device'] 113 | 114 | ############ will be deprecated soon ############################# 115 | self._eval_batch_size = 1000 116 | 117 | def build_model(self, num_classes=10): 118 | self.network = build_model(self.network_config, num_classes) 119 | self.optimizer = SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum) 120 | # assign a gradient accumulator to collect gradients from workers 121 | self.grad_accumulator = GradientAccumulator(self.network, self.world_size, self._compress_grad) 122 | self.init_model_shapes() 123 | #self.network.to(self._device) 124 | 
self.network.to(torch.device("cpu")) 125 | 126 | def start(self): 127 | for i in range(1, self._max_steps+1): 128 | # switch back to training mode 129 | self.network.train() 130 | self._first_grad_received = False 131 | enough_gradients_received = False 132 | 133 | logger.info("Master node is entering step: {}".format(i)) 134 | 135 | self._bcast_weight() 136 | 137 | self._recv_grads() 138 | 139 | self._model_update() 140 | 141 | self.cur_step += 1 142 | 143 | def init_model_shapes(self): 144 | for param_idx, param in enumerate(self.network.parameters()): 145 | self._model_shapes.append(param.size()) 146 | self._grad_aggregate_buffer.append(np.zeros(param.size())) 147 | 148 | def _model_update(self): 149 | # gradient shipped from workers are averaged and update the model 150 | self._grad_aggregate_buffer = [x / self._num_workers for x in self._grad_aggregate_buffer] 151 | self.optimizer.step(grads=self._grad_aggregate_buffer) 152 | 153 | def _bcast_weight(self): 154 | for layer_idx, layer in enumerate(self.network.parameters()): 155 | layer_weight = layer.detach() 156 | dist.broadcast(layer_weight, src=0) 157 | 158 | def aggregate_gradient(self, layer_idx, gradient): 159 | self._grad_aggregate_buffer[layer_idx] = reduce((lambda x, y: x + y), gradient[1:]) 160 | 161 | def _recv_grads(self): 162 | for layer_idx, layer in enumerate(self.network.parameters()): 163 | dummpy_grad = self.grad_accumulator.gradient_aggregator[layer_idx][0] 164 | dist.gather(dummpy_grad, self.grad_accumulator.gradient_aggregator[layer_idx], dst=0) 165 | self.aggregate_gradient(layer_idx=layer_idx, gradient=self.grad_accumulator.gradient_aggregator[layer_idx]) 166 | 167 | def _generate_model_path(self): 168 | return self._train_dir+"model_step_"+str(self.cur_step) 169 | 170 | def _save_model(self, file_path): 171 | with open(file_path, "wb") as f_: 172 | torch.save(self.network.state_dict(), f_) 173 | return 174 | 175 | def _evaluate_model(self, validation_loader): 176 | self.network.eval() 177 | prec1_counter_ = prec5_counter_ = batch_counter_ = 0 178 | # which indicate an epoch based validation is done 179 | while validation_loader.dataset.epochs_completed <= self._epoch_counter: 180 | eval_image_batch, eval_label_batch = validation_loader.next_batch(batch_size=self._eval_batch_size) 181 | X_batch, y_batch = Variable(eval_image_batch.float()), Variable(eval_label_batch.long()) 182 | output = self.network(X_batch) 183 | prec1_tmp, prec5_tmp = accuracy(output.detach(), eval_label_batch.long(), topk=(1, 5)) 184 | prec1_counter_ += prec1_tmp 185 | prec5_counter_ += prec5_tmp 186 | batch_counter_ += 1 187 | prec1 = prec1_counter_ / batch_counter_ 188 | prec5 = prec5_counter_ / batch_counter_ 189 | self._epoch_counter = validation_loader.dataset.epochs_completed 190 | logger.info('Testset Performance: Cur Step:{} Prec@1: {} Prec@5: {}'.format(self.cur_step, prec1.numpy()[0], prec5.numpy()[0])) -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | 3 | from model_ops.lenet import LeNet 4 | from model_ops.resnet import * 5 | from model_ops.vgg import * 6 | 7 | def build_model(model_name, num_classes): 8 | # build network 9 | if model_name == "LeNet": 10 | return LeNet() 11 | elif model_name == "ResNet18": 12 | return ResNet18(num_classes) 13 | elif model_name == "ResNet34": 14 | return ResNet34() 15 | elif model_name == "ResNet50": 16 | return 
ResNet50() 17 | elif model_name == "VGG11": 18 | return vgg11_bn(num_classes) 19 | 20 | def prepare_data(args): 21 | # load training and test set here: 22 | if args.dataset == "MNIST": 23 | training_set = datasets.MNIST('./mnist_data', train=True, download=True, 24 | transform=transforms.Compose([ 25 | transforms.ToTensor(), 26 | transforms.Normalize((0.1307,), (0.3081,))])) 27 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size, shuffle=True) 28 | test_loader = torch.utils.data.DataLoader( 29 | datasets.MNIST('./mnist_data', train=False, transform=transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize((0.1307,), (0.3081,)) 32 | ])), batch_size=args.test_batch_size, shuffle=True) 33 | elif args.dataset == "Cifar10": 34 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 35 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 36 | transform_train = transforms.Compose([ 37 | transforms.ToTensor(), 38 | transforms.Lambda(lambda x: F.pad( 39 | Variable(x.unsqueeze(0), requires_grad=False), 40 | (4,4,4,4),mode='reflect').data.squeeze()), 41 | transforms.ToPILImage(), 42 | transforms.RandomCrop(32), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.ToTensor(), 45 | normalize, 46 | ]) 47 | # data prep for test set 48 | transform_test = transforms.Compose([ 49 | transforms.ToTensor(), 50 | normalize]) 51 | # load training and test set here: 52 | training_set = datasets.CIFAR10(root='./cifar10_data', train=True, 53 | download=True, transform=transform_train) 54 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size, 55 | shuffle=True) 56 | testset = datasets.CIFAR10(root='./cifar10_data', train=False, 57 | download=True, transform=transform_test) 58 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, 59 | shuffle=False) 60 | elif args.dataset == 'Cifar100': 61 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 62 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 63 | transform_train = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Lambda(lambda x: F.pad( 66 | Variable(x.unsqueeze(0), requires_grad=False), 67 | (4,4,4,4),mode='reflect').data.squeeze()), 68 | transforms.ToPILImage(), 69 | transforms.RandomCrop(32), 70 | transforms.RandomHorizontalFlip(), 71 | transforms.ToTensor(), 72 | normalize, 73 | ]) 74 | # data prep for test set 75 | transform_test = transforms.Compose([ 76 | transforms.ToTensor(), 77 | normalize]) 78 | # load training and test set here: 79 | training_set = datasets.CIFAR100(root='./cifar100_data', train=True, 80 | download=True, transform=transform_train) 81 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=args.batch_size, 82 | shuffle=True) 83 | testset = datasets.CIFAR100(root='./cifar100_data', train=False, 84 | download=True, transform=transform_test) 85 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, 86 | shuffle=False) 87 | # SVHN dataset 88 | elif args.dataset == 'SVHN': 89 | training_set = datasets.SVHN('./svhn_data', split='train', transform=transforms.Compose([ 90 | transforms.RandomCrop(32, padding=4), 91 | transforms.RandomHorizontalFlip(), 92 | transforms.ToTensor(), 93 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 94 | ])) 95 | train_loader = torch.utils.data.DataLoader(training_set, batch_size=128, 96 | shuffle=True) 97 | transform_test = transforms.Compose([ 98 | transforms.ToTensor(), 
99 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
100 | ])
101 | testset = datasets.SVHN(root='./svhn_data', split='test',
102 | download=True, transform=transform_test)
103 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
104 | shuffle=False)
105 | return train_loader, test_loader
--------------------------------------------------------------------------------
/tools/README.md:
--------------------------------------------------------------------------------
1 | # AWS EC2 AMI for distributed PyTorch:
2 | `ami-ebc83c93` (with pytorch/pytorch vision installed and configured with CUDA 8.0 and cuDNN 7)
3 | 
4 | # Enable `MPI` backend in pytorch distributed
5 | The first thing you need to do is to remove your current version of `pytorch` and build it from source (we may not need this later, but for now `MPI` is not enabled automatically in the prebuilt binaries).
6 | 
7 | To build pytorch from source, you can follow the guidance here (https://github.com/pytorch/pytorch#from-source). But there are a few things you should be careful about.
8 | 
9 | 1. make sure you're in your `conda env` when you run `python setup.py install` by
10 | ```
11 | source /home/user_name/anaconda[2 or 3]/bin/activate ~/anaconda[2 or 3]
12 | ```
13 | otherwise, pytorch will be built in your system lib directory rather than the conda lib directory.
14 | 
15 | 2. make sure you use CUDA (version >= 7.5) and have cuDNN (version >= 7.0) installed. I have a quick and easy way to do this in this github repo (https://github.com/hwang595/distributed-MXNet). Even though the version specified there is `CUDA 7.5`, this method will still install `CUDA 8.0` for you, which does not matter here.
16 | 
17 | To check whether your `MPI` backend is enabled, just run:
18 | ```
19 | import torch
20 | torch.distributed.init_process_group(backend='mpi')
21 | ```
22 | 
23 | After this, we also need to have `TorchVision` built from source by running the following commands:
24 | ```
25 | git clone https://github.com/pytorch/vision.git
26 | source anaconda[2 or 3]/bin/activate ~/anaconda[2 or 3]
27 | python vision/setup.py install
28 | ```
29 | 
30 | Note:
31 | To add your `.pem` file using `ssh-add`, run the following command first:
32 | ```
33 | eval `ssh-agent -s`
34 | ```
35 | 
--------------------------------------------------------------------------------
/tools/conda_install.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # setup Anaconda env
3 | wget https://repo.continuum.io/archive/Anaconda2-5.1.0-Linux-x86_64.sh
4 | bash Anaconda2-5.1.0-Linux-x86_64.sh -b -p ~/anaconda
5 | rm Anaconda2-5.1.0-Linux-x86_64.sh
6 | echo 'export PATH="~/anaconda/bin:$PATH"' >> ~/.bashrc
7 | 
8 | # Refresh basically
9 | source .bashrc
--------------------------------------------------------------------------------
/tools/config:
--------------------------------------------------------------------------------
1 | Host *
2 | StrictHostKeyChecking no
3 | 
--------------------------------------------------------------------------------
/tools/hosts:
--------------------------------------------------------------------------------
1 | 172.31.19.212 deeplearning-worker1
2 | 172.31.25.216 deeplearning-worker2
3 | 172.31.16.210 deeplearning-worker3
4 | 172.31.18.242 deeplearning-worker4
5 | 172.31.30.165 deeplearning-worker5
6 | 172.31.19.35 deeplearning-worker6
7 | 172.31.21.31 deeplearning-worker7
8 | 172.31.21.232 deeplearning-worker8
9 | 172.31.26.254
deeplearning-worker9 10 | 172.31.29.227 deeplearning-worker10 11 | 172.31.21.65 deeplearning-worker11 12 | 172.31.19.22 deeplearning-worker12 13 | 172.31.26.181 deeplearning-worker13 14 | 172.31.23.5 deeplearning-worker14 15 | 172.31.23.177 deeplearning-worker15 16 | 172.31.29.128 deeplearning-worker16 17 | 172.31.21.63 deeplearning-worker17 18 | 172.31.20.152 deeplearning-worker18 19 | 172.31.24.123 deeplearning-worker19 20 | 172.31.28.206 deeplearning-worker20 21 | 172.31.25.60 deeplearning-worker21 22 | 172.31.21.191 deeplearning-worker22 23 | 172.31.18.144 deeplearning-worker23 24 | 172.31.21.28 deeplearning-worker24 25 | 172.31.22.235 deeplearning-worker25 26 | 172.31.16.15 deeplearning-worker26 27 | 172.31.18.223 deeplearning-worker27 28 | 172.31.17.46 deeplearning-worker28 29 | 172.31.31.45 deeplearning-worker29 30 | 172.31.27.44 deeplearning-worker30 31 | 172.31.17.249 deeplearning-worker31 32 | 172.31.22.20 deeplearning-worker32 33 | 172.31.22.226 deeplearning-worker33 34 | 172.31.19.15 deeplearning-worker34 35 | 172.31.21.118 deeplearning-worker35 36 | 172.31.19.49 deeplearning-worker36 37 | 172.31.21.136 deeplearning-worker37 38 | 172.31.25.85 deeplearning-worker38 39 | 172.31.18.69 deeplearning-worker39 40 | 172.31.16.134 deeplearning-worker40 41 | 172.31.25.83 deeplearning-worker41 42 | 172.31.16.223 deeplearning-worker42 43 | 172.31.20.120 deeplearning-worker43 44 | 172.31.17.89 deeplearning-worker44 45 | 172.31.18.20 deeplearning-worker45 46 | 172.31.31.170 deeplearning-worker46 47 | 172.31.19.76 deeplearning-worker47 48 | 172.31.23.122 deeplearning-worker48 49 | 172.31.24.201 deeplearning-worker49 50 | 172.31.18.63 deeplearning-worker50 51 | 172.31.26.191 deeplearning-worker51 52 | 172.31.24.249 deeplearning-worker52 53 | 172.31.22.91 deeplearning-worker53 54 | 172.31.16.114 deeplearning-worker54 55 | 172.31.26.5 deeplearning-worker55 56 | 172.31.20.194 deeplearning-worker56 57 | 172.31.30.60 deeplearning-worker57 58 | 172.31.23.145 deeplearning-worker58 59 | 172.31.18.13 deeplearning-worker59 60 | 172.31.26.9 deeplearning-worker60 61 | 172.31.21.173 deeplearning-worker61 62 | 172.31.20.40 deeplearning-worker62 63 | 172.31.18.128 deeplearning-worker63 64 | 172.31.28.200 deeplearning-worker64 65 | 172.31.27.141 deeplearning-worker65 66 | -------------------------------------------------------------------------------- /tools/hosts_address: -------------------------------------------------------------------------------- 1 | 172.31.19.212 2 | 172.31.25.216 3 | 172.31.16.210 4 | 172.31.18.242 5 | 172.31.30.165 6 | 172.31.19.35 7 | 172.31.21.31 8 | 172.31.21.232 9 | 172.31.26.254 10 | 172.31.29.227 11 | 172.31.21.65 12 | 172.31.19.22 13 | 172.31.26.181 14 | 172.31.23.5 15 | 172.31.23.177 16 | 172.31.29.128 17 | 172.31.21.63 18 | 172.31.20.152 19 | 172.31.24.123 20 | 172.31.28.206 21 | 172.31.25.60 22 | 172.31.21.191 23 | 172.31.18.144 24 | 172.31.21.28 25 | 172.31.22.235 26 | 172.31.16.15 27 | 172.31.18.223 28 | 172.31.17.46 29 | 172.31.31.45 30 | 172.31.27.44 31 | 172.31.17.249 32 | 172.31.22.20 33 | 172.31.22.226 34 | 172.31.19.15 35 | 172.31.21.118 36 | 172.31.19.49 37 | 172.31.21.136 38 | 172.31.25.85 39 | 172.31.18.69 40 | 172.31.16.134 41 | 172.31.25.83 42 | 172.31.16.223 43 | 172.31.20.120 44 | 172.31.17.89 45 | 172.31.18.20 46 | 172.31.31.170 47 | 172.31.19.76 48 | 172.31.23.122 49 | 172.31.24.201 50 | 172.31.18.63 51 | 172.31.26.191 52 | 172.31.24.249 53 | 172.31.22.91 54 | 172.31.16.114 55 | 172.31.26.5 56 | 172.31.20.194 57 | 172.31.30.60 58 | 172.31.23.145 59 | 
172.31.18.13 60 | 172.31.26.9 61 | 172.31.21.173 62 | 172.31.20.40 63 | 172.31.18.128 64 | 172.31.28.200 65 | 172.31.27.141 66 | -------------------------------------------------------------------------------- /tools/hosts_alias: -------------------------------------------------------------------------------- 1 | deeplearning-worker1 2 | deeplearning-worker2 3 | deeplearning-worker3 4 | deeplearning-worker4 5 | deeplearning-worker5 6 | deeplearning-worker6 7 | deeplearning-worker7 8 | deeplearning-worker8 9 | deeplearning-worker9 10 | deeplearning-worker10 11 | deeplearning-worker11 12 | deeplearning-worker12 13 | deeplearning-worker13 14 | deeplearning-worker14 15 | deeplearning-worker15 16 | deeplearning-worker16 17 | deeplearning-worker17 18 | deeplearning-worker18 19 | deeplearning-worker19 20 | deeplearning-worker20 21 | deeplearning-worker21 22 | deeplearning-worker22 23 | deeplearning-worker23 24 | deeplearning-worker24 25 | deeplearning-worker25 26 | deeplearning-worker26 27 | deeplearning-worker27 28 | deeplearning-worker28 29 | deeplearning-worker29 30 | deeplearning-worker30 31 | deeplearning-worker31 32 | deeplearning-worker32 33 | deeplearning-worker33 34 | deeplearning-worker34 35 | deeplearning-worker35 36 | deeplearning-worker36 37 | deeplearning-worker37 38 | deeplearning-worker38 39 | deeplearning-worker39 40 | deeplearning-worker40 41 | deeplearning-worker41 42 | deeplearning-worker42 43 | deeplearning-worker43 44 | deeplearning-worker44 45 | deeplearning-worker45 46 | deeplearning-worker46 47 | deeplearning-worker47 48 | deeplearning-worker48 49 | deeplearning-worker49 50 | deeplearning-worker50 51 | deeplearning-worker51 52 | deeplearning-worker52 53 | deeplearning-worker53 54 | deeplearning-worker54 55 | deeplearning-worker55 56 | deeplearning-worker56 57 | deeplearning-worker57 58 | deeplearning-worker58 59 | deeplearning-worker59 60 | deeplearning-worker60 61 | deeplearning-worker61 62 | deeplearning-worker62 63 | deeplearning-worker63 64 | deeplearning-worker64 65 | deeplearning-worker65 66 | -------------------------------------------------------------------------------- /tools/install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # append hosts 4 | sudo bash -c "cat hosts >> /etc/hosts" 5 | cp config ~/.ssh/ 6 | 7 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts` 8 | 9 | git config --global user.name hwang595 10 | git config --global user.email hongyiwang@cs.wisc.edu 11 | 12 | sudo apt-get update 13 | sudo apt-get install pdsh -y 14 | pdsh -R ssh -w deeplearning-worker[1-$DEEPLEARNING_WORKERS_COUNT] "sudo apt-get update; sudo apt-get install pdsh -y" 15 | -------------------------------------------------------------------------------- /tools/killall.sh: -------------------------------------------------------------------------------- 1 | KEY_PEM_NAME=~/.ssh/HongyiScript2.pem 2 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts` 3 | 4 | sudo bash -c "cat hosts >> /etc/hosts" 5 | cp config ~/.ssh/ 6 | 7 | for i in $(seq 2 $DEEPLEARNING_WORKERS_COUNT); 8 | do 9 | ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'killall python' 10 | done -------------------------------------------------------------------------------- /tools/local_script.sh: -------------------------------------------------------------------------------- 1 | KEY_PEM_DIR=/home/hwang/My_Code/AWS/HongyiScript.pem 2 | KEY_PEM_NAME=HongyiScript.pem 3 | PUB_IP_ADDR="$1" 4 | echo "Public address of master node: ${PUB_IP_ADDR}" 5 | 6 | ssh 
-o "StrictHostKeyChecking no" ubuntu@${PUB_IP_ADDR} 7 | scp -i ${KEY_PEM_DIR} ${KEY_PEM_DIR} ubuntu@${PUB_IP_ADDR}:~/.ssh 8 | scp -i ${KEY_PEM_DIR} hosts hosts_address config ubuntu@${PUB_IP_ADDR}:~/ 9 | scp -i ${KEY_PEM_DIR} -r /home/hwang/My_Code/ps_real_pytorch ubuntu@${PUB_IP_ADDR}:~/ 10 | ssh -i ${KEY_PEM_DIR} ubuntu@${PUB_IP_ADDR} 'cp ps_real_pytorch/tools/remote_script.sh ~/' 11 | ssg -i ${KEY_PEM_DIR} ubuntu@${PUB_IP_ADDR} 'cp hosts ps_real_pytorch/src/' 12 | -------------------------------------------------------------------------------- /tools/openmpi_install.sh: -------------------------------------------------------------------------------- 1 | # configure, download, and install OpenMPI 2 | sudo apt-get update 3 | sudo apt-get -y install gcc g++ make 4 | sudo apt-get update 5 | wget https://www.open-mpi.org/software/ompi/v3.0/downloads/openmpi-3.0.1.tar.gz 6 | tar -xvf openmpi-* 7 | rm -f openmpi-3.0.1.tar.gz 8 | cd ~/openmpi-3.0.1 9 | ./configure --prefix="/home/$USER/.openmpi" 10 | make 11 | sudo make install 12 | export PATH="$PATH:/home/$USER/.openmpi/bin" 13 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/$USER/.openmpi/lib/" -------------------------------------------------------------------------------- /tools/pre_run.sh: -------------------------------------------------------------------------------- 1 | conda update -y -n base conda 2 | conda install pytorch torchvision -y -c pytorch 3 | conda install -y -c anaconda python-blosc 4 | conda install -y -c anaconda mpi4py -------------------------------------------------------------------------------- /tools/pytorch_ec2.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import threading 4 | import Queue 5 | import paramiko as pm 6 | import boto3 7 | import time 8 | import json 9 | import os 10 | 11 | 12 | class Cfg(dict): 13 | 14 | def __getitem__(self, item): 15 | item = dict.__getitem__(self, item) 16 | if type(item) == type([]): 17 | return [x % self if type(x) == type("") else x for x in item] 18 | if type(item) == type(""): 19 | return item % self 20 | return item 21 | 22 | cfg = Cfg({ 23 | "name" : "Timeout", # Unique name for this specific configuration 24 | "key_name": "HongyiScript", # Necessary to ssh into created instances 25 | # Cluster topology 26 | "n_masters" : 1, # Should always be 1 27 | "n_workers" : 64, 28 | "num_replicas_to_aggregate" : "8", 29 | "method" : "spot", 30 | # Region speficiation 31 | "region" : "us-west-2", 32 | "availability_zone" : "us-west-2b", 33 | # Machine type - instance type configuration. 34 | "master_type" : "m5.2xlarge", 35 | "worker_type" : "m5.2xlarge", 36 | # please only use this AMI for pytorch 37 | "image_id": "ami-0d3c2f72d1499d30b", 38 | # Launch specifications 39 | "spot_price" : "0.15", # Has to be a string 40 | # SSH configuration 41 | "ssh_username" : "ubuntu", # For sshing. E.G: ssh ssh_username@hostname 42 | "path_to_keyfile" : "/home/hwang/My_Code/AWS/HongyiScript.pem", 43 | 44 | # NFS configuration 45 | # To set up these values, go to Services > ElasticFileSystem > Create new filesystem, and follow the directions. 46 | #"nfs_ip_address" : "172.31.3.173", # us-west-2c 47 | #"nfs_ip_address" : "172.31.35.0", # us-west-2a 48 | "nfs_ip_address" : "172.31.14.225", # us-west-2b 49 | "nfs_mount_point" : "/home/ubuntu/shared", # NFS base dir 50 | "base_out_dir" : "%(nfs_mount_point)s/%(name)s", # Master writes checkpoints to this directory. 
Outfiles are written to this directory. 51 | "setup_commands" : 52 | [ 53 | # "sudo rm -rf %(base_out_dir)s", 54 | "mkdir %(base_out_dir)s", 55 | ], 56 | # Command specification 57 | # Master pre commands are run only by the master 58 | "master_pre_commands" : 59 | [ 60 | "cd my_mxnet", 61 | "git fetch && git reset --hard origin/master", 62 | "cd cifar10", 63 | "ls", 64 | # "cd distributed_tensorflow/DistributedResNet", 65 | # "git fetch && git reset --hard origin/master", 66 | ], 67 | # Pre commands are run on every machine before the actual training. 68 | "pre_commands" : 69 | [ 70 | "cd my_mxnet", 71 | "git fetch && git reset --hard origin/master", 72 | "cd cifar10", 73 | ], 74 | # Model configuration 75 | "batch_size" : "32", 76 | "max_steps" : "2000", 77 | "initial_learning_rate" : ".001", 78 | "learning_rate_decay_factor" : ".95", 79 | "num_epochs_per_decay" : "1.0", 80 | # Train command specifies how the ps/workers execute tensorflow. 81 | # PS_HOSTS - special string replaced with actual list of ps hosts. 82 | # TASK_ID - special string replaced with actual task index. 83 | # JOB_NAME - special string replaced with actual job name. 84 | # WORKER_HOSTS - special string replaced with actual list of worker hosts 85 | # ROLE_ID - special string replaced with machine's identity (E.G: master, worker0, worker1, ps, etc) 86 | # %(...)s - Inserts self referential string value. 87 | "train_commands" : 88 | [ 89 | "echo ========= Start ===========" 90 | ], 91 | }) 92 | 93 | def mxnet_ec2_run(argv, configuration): 94 | client = boto3.client("ec2", region_name=configuration["region"]) 95 | ec2 = boto3.resource("ec2", region_name=configuration["region"]) 96 | 97 | def sleep_a_bit(): 98 | time.sleep(5) 99 | 100 | def summarize_instances(instances): 101 | instance_type_to_instance_map = {} 102 | for instance in sorted(instances, key=lambda x:x.id): 103 | typ = instance.instance_type 104 | if typ not in instance_type_to_instance_map: 105 | instance_type_to_instance_map[typ] = [] 106 | instance_type_to_instance_map[typ].append(instance) 107 | 108 | for type in instance_type_to_instance_map: 109 | print("Type\t", type) 110 | for instance in instance_type_to_instance_map[type]: 111 | print("instance\t", instance, "\t", instance.public_ip_address) 112 | print 113 | 114 | for k,v in instance_type_to_instance_map.items(): 115 | print("%s - %d running" % (k, len(v))) 116 | 117 | return instance_type_to_instance_map 118 | 119 | def summarize_idle_instances(argv): 120 | print("Idle instances: (Idle = not running tensorflow)") 121 | summarize_instances(get_idle_instances()) 122 | 123 | def summarize_running_instances(argv): 124 | print("Running instances: ") 125 | summarize_instances(ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}])) 126 | 127 | # Terminate all request. 
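# Illustrative sketch only (not invoked anywhere in this script): the polling
# loops below, flagged with "TODO: Use waiter class", could likely be replaced
# by boto3's built-in EC2 waiters, roughly along the lines of
#
#     waiter = client.get_waiter("spot_instance_request_fulfilled")
#     waiter.wait(SpotInstanceRequestIds=spot_request_ids)
#     client.get_waiter("instance_status_ok").wait(InstanceIds=instance_ids)
#
# The exact waiter names and arguments should be checked against the boto3 EC2
# documentation for the client version in use.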
128 | def terminate_all_requests(): 129 | spot_requests = client.describe_spot_instance_requests() 130 | spot_request_ids = [] 131 | for spot_request in spot_requests["SpotInstanceRequests"]: 132 | if spot_request["State"] != "cancelled" and spot_request["LaunchSpecification"]["KeyName"] == configuration["key_name"]: 133 | spot_request_id = spot_request["SpotInstanceRequestId"] 134 | spot_request_ids.append(spot_request_id) 135 | 136 | if len(spot_request_ids) != 0: 137 | print("Terminating spot requests: %s" % " ".join([str(x) for x in spot_request_ids])) 138 | client.cancel_spot_instance_requests(SpotInstanceRequestIds=spot_request_ids) 139 | 140 | # Wait until all are cancelled. 141 | # TODO: Use waiter class 142 | done = False 143 | while not done: 144 | print("Waiting for all spot requests to be terminated...") 145 | done = True 146 | spot_requests = client.describe_spot_instance_requests() 147 | states = [x["State"] for x in spot_requests["SpotInstanceRequests"] if x["LaunchSpecification"]["KeyName"] == configuration["key_name"]] 148 | for state in states: 149 | if state != "cancelled": 150 | done = False 151 | sleep_a_bit() 152 | 153 | # Terminate all instances in the configuration 154 | # Note: all_instances = ec2.instances.all() to get all intances 155 | def terminate_all_instances(): 156 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}]) 157 | all_instance_ids = [x.id for x in live_instances] 158 | print([x.id for x in live_instances]) 159 | if len(all_instance_ids) != 0: 160 | print("Terminating instances: %s" % (" ".join([str(x) for x in all_instance_ids]))) 161 | client.terminate_instances(InstanceIds=all_instance_ids) 162 | 163 | # Wait until all are terminated 164 | # TODO: Use waiter class 165 | done = False 166 | while not done: 167 | print("Waiting for all instances to be terminated...") 168 | done = True 169 | instances = ec2.instances.all() 170 | for instance in instances: 171 | if instance.state == "active": 172 | done = False 173 | sleep_a_bit() 174 | 175 | # Launch instances as specified in the configuration. 176 | def launch_instances(): 177 | method = "spot" 178 | if "method" in configuration.keys(): 179 | method = configuration["method"] 180 | worker_instance_type, worker_count = configuration["worker_type"], configuration["n_workers"] 181 | master_instance_type, master_count = configuration["master_type"], configuration["n_masters"] 182 | specs = [(worker_instance_type, worker_count), 183 | (master_instance_type, master_count)] 184 | for (instance_type, count) in specs: 185 | launch_specs = {"KeyName" : configuration["key_name"], 186 | "ImageId" : configuration["image_id"], 187 | "InstanceType" : instance_type, 188 | "Placement" : {"AvailabilityZone":configuration["availability_zone"]}, 189 | "SecurityGroups": ["default"]} 190 | if method == "spot": 191 | # TODO: EBS optimized? 
(Will incur extra hourly cost) 192 | client.request_spot_instances(InstanceCount=count, 193 | LaunchSpecification=launch_specs, 194 | SpotPrice=configuration["spot_price"]) 195 | elif method == "reserved": 196 | client.run_instances(ImageId=launch_specs["ImageId"], 197 | MinCount=count, 198 | MaxCount=count, 199 | KeyName=launch_specs["KeyName"], 200 | InstanceType=launch_specs["InstanceType"], 201 | Placement=launch_specs["Placement"], 202 | SecurityGroups=launch_specs["SecurityGroups"]) 203 | else: 204 | print("Unknown method: %s" % method) 205 | sys.exit(-1) 206 | 207 | 208 | # TODO: use waiter class? 209 | def wait_until_running_instances_initialized(): 210 | done = False 211 | while not done: 212 | print("Waiting for instances to be initialized...") 213 | done = True 214 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}]) 215 | ids = [x.id for x in live_instances] 216 | resps_list = [client.describe_instance_status(InstanceIds=ids[i:i+50]) for i in range(0, len(ids), 50)] 217 | statuses = [] 218 | for resp in resps_list: 219 | statuses += [x["InstanceStatus"]["Status"] for x in resp["InstanceStatuses"]] 220 | #resps = client.describe_instance_status(InstanceIds=ids) 221 | #for resp in resps["InstanceStatuses"]: 222 | # if resp["InstanceStatus"]["Status"] != "ok": 223 | # done = False 224 | print(statuses) 225 | done = statuses.count("ok") == len(statuses) 226 | if len(ids) <= 0: 227 | done = False 228 | sleep_a_bit() 229 | 230 | # Waits until status requests are all fulfilled. 231 | # Prints out status of request in between time waits. 232 | # TODO: Use waiter class 233 | def wait_until_instance_request_status_fulfilled(): 234 | requests_fulfilled = False 235 | n_active_or_open = 0 236 | while not requests_fulfilled or n_active_or_open == 0: 237 | requests_fulfilled = True 238 | statuses = client.describe_spot_instance_requests() 239 | print("InstanceRequestId, InstanceType, SpotPrice, State - Status : StatusMessage") 240 | print("-------------------------------------------") 241 | n_active_or_open = 0 242 | for instance_request in statuses["SpotInstanceRequests"]: 243 | if instance_request["LaunchSpecification"]["KeyName"] != configuration["key_name"]: 244 | continue 245 | sid = instance_request["SpotInstanceRequestId"] 246 | machine_type = instance_request["LaunchSpecification"]["InstanceType"] 247 | price = instance_request["SpotPrice"] 248 | state = instance_request["State"] 249 | status, status_string = instance_request["Status"]["Code"], instance_request["Status"]["Message"] 250 | if state == "active" or state == "open": 251 | n_active_or_open += 1 252 | print("%s, %s, %s, %s - %s : %s" % (sid, machine_type, price, state, status, status_string)) 253 | if state != "active": 254 | requests_fulfilled = False 255 | print("-------------------------------------------") 256 | sleep_a_bit() 257 | 258 | # Create a client to the instance 259 | def connect_client(instance): 260 | client = pm.SSHClient() 261 | host = instance.public_ip_address 262 | client.set_missing_host_key_policy(pm.AutoAddPolicy()) 263 | client.connect(host, username=configuration["ssh_username"], key_filename=configuration["path_to_keyfile"]) 264 | return client 265 | 266 | # Takes a list of commands (E.G: ["ls", "cd models"] 267 | # and executes command on instance, returning the stdout. 268 | # Executes everything in one session, and returns all output from all the commands. 
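# A minimal usage sketch for the helper defined below (assuming `instance` is a
# running boto3 ec2.Instance reachable with the configured keyfile):
#
#     out = run_ssh_commands(instance, ["cd ps_real_pytorch/src", "ls"])
#     print(out)
#
# Note that the passed-in list is mutated (an "exit" is appended) and the call
# retries until an SSH session succeeds, so a permanently unreachable host will
# make it loop forever.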
269 | def run_ssh_commands(instance, commands): 270 | done = False 271 | while not done: 272 | try: 273 | print("Instance %s, Running ssh commands:\n%s\n" % (instance.public_ip_address, "\n".join(commands))) 274 | 275 | # Always need to exit 276 | commands.append("exit") 277 | 278 | # Set up ssh client 279 | client = connect_client(instance) 280 | 281 | # Clear the stdout from ssh'ing in 282 | # For each command perform command and read stdout 283 | commandstring = "\n".join(commands) 284 | stdin, stdout, stderr = client.exec_command(commandstring) 285 | output = stdout.read() 286 | 287 | # Close down 288 | stdout.close() 289 | stdin.close() 290 | client.close() 291 | done = True 292 | except Exception as e: 293 | done = False 294 | print(e.message) 295 | return output 296 | 297 | def run_ssh_commands_parallel(instance, commands, q): 298 | output = run_ssh_commands(instance, commands) 299 | q.put((instance, output)) 300 | 301 | # Checks whether instance is idle. Assumed that instance is up and running. 302 | # An instance is idle if it is not running tensorflow... 303 | # Returns a tuple of (instance, is_instance_idle). We return a tuple for multithreading ease. 304 | def is_instance_idle(q, instance): 305 | python_processes = run_ssh_commands(instance, ["ps aux | grep python"]) 306 | q.put((instance, not "ps_hosts" in python_processes and not "ps_workers" in python_processes)) 307 | 308 | # Idle instances are running instances that are not running the inception model. 309 | # We check whether an instance is running the inception model by ssh'ing into a running machine, 310 | # and checking whether python is running. 311 | def get_idle_instances(): 312 | live_instances = ec2.instances.filter( 313 | Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, 314 | {'Name': 'key-name', 'Values': [configuration["key_name"]]}]) 315 | threads = [] 316 | q = Queue.Queue() 317 | 318 | # Run commands in parallel, writing to the queue 319 | for instance in live_instances: 320 | t = threading.Thread(target=is_instance_idle, args=(q, instance)) 321 | t.daemon = True 322 | t.start() 323 | threads.append(t) 324 | 325 | # Wait for threads to finish 326 | for thread in threads: 327 | thread.join() 328 | 329 | # Collect idle instances 330 | idle_instances = [] 331 | while not q.empty(): 332 | instance, is_idle = q.get() 333 | if is_idle: 334 | idle_instances.append(instance) 335 | 336 | return idle_instances 337 | 338 | def get_instance_requirements(): 339 | # Get the requirements given the specification of worker/master/etc machine types 340 | worker_instance_type, worker_count = configuration["worker_type"], configuration["n_workers"] 341 | master_instance_type, master_count = configuration["master_type"], configuration["n_masters"] 342 | specs = [(worker_instance_type, worker_count), 343 | (master_instance_type, master_count)] 344 | reqs = {} 345 | for (type_needed, count_needed) in specs: 346 | if type_needed not in reqs: 347 | reqs[type_needed] = 0 348 | reqs[type_needed] += count_needed 349 | return reqs 350 | 351 | # Returns whether the idle instances satisfy the specs of the configuration. 352 | def check_idle_instances_satisfy_configuration(): 353 | # Create a map of instance types to instances of that type 354 | idle_instances = get_idle_instances() 355 | instance_type_to_instance_map = summarize_instances(idle_instances) 356 | 357 | # Get instance requirements 358 | reqs = get_instance_requirements() 359 | 360 | # Check the requirements are satisfied. 
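# For example, with the configuration at the top of this file (1 m5.2xlarge
# master plus 64 m5.2xlarge workers), `reqs` comes out as {"m5.2xlarge": 65},
# so at least 65 idle m5.2xlarge instances must be available for the check to pass.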
361 | print("Checking whether # of running instances satisfies the configuration...") 362 | for k,v in instance_type_to_instance_map.items(): 363 | n_required = 0 if k not in reqs else reqs[k] 364 | print("%s - %d running vs %d required" % (k,len(v),n_required)) 365 | if len(v) < n_required: 366 | print("Error, running instances failed to satisfy configuration requirements") 367 | sys.exit(0) 368 | print("Success, running instances satisfy configuration requirement") 369 | 370 | def shut_everything_down(argv): 371 | terminate_all_requests() 372 | terminate_all_instances() 373 | 374 | def run_mxnet_grid_search(argv, port=1334): 375 | # Check idle instances satisfy configs 376 | check_idle_instances_satisfy_configuration() 377 | 378 | # Get idle instances 379 | idle_instances = get_idle_instances() 380 | 381 | # Assign instances for worker/ps/etc 382 | instance_type_to_instance_map = summarize_instances(idle_instances) 383 | specs = { 384 | "master" : {"instance_type" : configuration["master_type"], 385 | "n_required" : configuration["n_masters"]}, 386 | "worker" : {"instance_type" : configuration["worker_type"], 387 | "n_required" : configuration["n_workers"]} 388 | } 389 | machine_assignments = { 390 | "master" : [], 391 | "worker" : [] 392 | } 393 | for role, requirement in sorted(specs.items(), key=lambda x:x[0]): 394 | instance_type_for_role = requirement["instance_type"] 395 | n_instances_needed = requirement["n_required"] 396 | instances_to_assign, rest = instance_type_to_instance_map[instance_type_for_role][:n_instances_needed], instance_type_to_instance_map[instance_type_for_role][n_instances_needed:] 397 | instance_type_to_instance_map[instance_type_for_role] = rest 398 | machine_assignments[role] = instances_to_assign 399 | 400 | # Construct the host strings necessary for running the inception command. 401 | # Note we use private ip addresses to avoid EC2 transfer costs. 
402 | worker_host_string = ",".join([x.private_ip_address+":"+str(port) for x in machine_assignments["master"] + machine_assignments["worker"]]) 403 | 404 | # Create a map of command&machine assignments 405 | command_machine_assignments = {} 406 | setup_machine_assignments = {} 407 | 408 | # Construct the master command 409 | command_machine_assignments["master"] = {"instance" : machine_assignments["master"][0], "commands" : list(configuration["master_pre_commands"])} 410 | # setup_machine_assignments["master"] = {"instance" : machine_assignments["master"][0], "commands" : list(configuration["setup_commands"])} 411 | for command_string in configuration["train_commands"]: 412 | command_machine_assignments["master"]["commands"].append(command_string.replace("JOB_NAME", "worker").replace("WORKER_HOSTS", worker_host_string).replace("ROLE_ID", "master")) 413 | print(command_machine_assignments) 414 | 415 | # Construct the worker commands 416 | for worker_id, instance in enumerate(machine_assignments["worker"]): 417 | name = "worker_%d" % worker_id 418 | command_machine_assignments[name] = {"instance" : instance, 419 | "commands" : list(configuration["pre_commands"])} 420 | for command_string in configuration["train_commands"]: 421 | command_machine_assignments[name]["commands"].append(command_string.replace("TASK_ID", "%d" % (worker_id+1)).replace("JOB_NAME", "worker").replace("WORKER_HOSTS", worker_host_string).replace("ROLE_ID", name)) 422 | 423 | print(command_machine_assignments) 424 | 425 | # Run the commands via ssh in parallel 426 | threads = [] 427 | q = Queue.Queue() 428 | 429 | for name, command_and_machine in setup_machine_assignments.items(): 430 | instance = command_and_machine["instance"] 431 | commands = command_and_machine["commands"] 432 | print("-----------------------") 433 | print("Pre Command: %s\n" % " ".join(commands)) 434 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q)) 435 | t.start() 436 | threads.append(t) 437 | 438 | # Wait until commands are all finished 439 | for t in threads: 440 | t.join() 441 | 442 | threads = [] 443 | q = Queue.Queue() 444 | 445 | running_process = 0 446 | for name, command_and_machine in command_machine_assignments.items(): 447 | instance = command_and_machine["instance"] 448 | neo_commands = "python train_cifar10.py --running_mode=grid_search --gpus=0 "\ 449 | "--running_process={} "\ 450 | "--batch-size={} "\ 451 | "--dir={}/grid_search> {}/grid_search/batch_size_{}/running_{}_process.out 2>&1 &".format( 452 | running_process, 453 | configuration['batch_size'], 454 | configuration['nfs_mount_point'], 455 | configuration['nfs_mount_point'], 456 | configuration['batch_size'], 457 | running_process) 458 | 459 | commands = command_and_machine["commands"] 460 | commands.append('mkdir {}/grid_search'.format(configuration['nfs_mount_point'])) 461 | commands.append('mkdir {}/grid_search/batch_size_{}'.format( 462 | configuration['nfs_mount_point'], 463 | configuration['batch_size'])) 464 | commands.append(neo_commands) 465 | 466 | print("-----------------------") 467 | print("Command: %s\n" % " ".join(commands)) 468 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q)) 469 | t.start() 470 | threads.append(t) 471 | running_process += 1 472 | 473 | # Wait until commands are all finished 474 | for t in threads: 475 | t.join() 476 | 477 | # Print the output 478 | while not q.empty(): 479 | instance, output = q.get() 480 | print(instance.public_ip_address) 481 | print(output) 482 | 
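# At this point `command_machine_assignments` maps role names ("master",
# "worker_0", ..., "worker_63" for the configuration above) to dicts of the form
# {"instance": <boto3 ec2.Instance>, "commands": [<shell command>, ...]}.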
483 | # Debug print 484 | instances = [] 485 | print("\n--------------------------------------------------\n") 486 | print("Machine assignments:") 487 | print("------------------------") 488 | for name, command_and_machine in command_machine_assignments.items(): 489 | instance = command_and_machine["instance"] 490 | instances.append(instance) 491 | commands = command_and_machine["commands"] 492 | ssh_command = "ssh -i %s %s@%s" % (configuration["path_to_keyfile"], configuration["ssh_username"], instance.public_ip_address) 493 | print("%s - %s" % (name, instance.instance_id)) 494 | print("To ssh: %s" % ssh_command) 495 | print("------------------------") 496 | 497 | # Print out list of instance ids (which will be useful in selctively stopping inception 498 | # for given instances. 499 | instance_cluster_string = ",".join([x.instance_id for x in instances]) 500 | print("\nInstances cluster string: %s" % instance_cluster_string) 501 | 502 | # Print out the id of the configuration file 503 | cluster_save = { 504 | "configuration" : configuration, 505 | "name" : configuration["name"], 506 | "command_machine_assignments" : command_machine_assignments, 507 | "cluster_string" : instance_cluster_string 508 | } 509 | 510 | return cluster_save 511 | 512 | 513 | def run_mxnet_loss_curve(argv, port=1334): 514 | # Check idle instances satisfy configs 515 | check_idle_instances_satisfy_configuration() 516 | 517 | # Get idle instances 518 | idle_instances = get_idle_instances() 519 | 520 | # Assign instances for worker/ps/etc 521 | instance_type_to_instance_map = summarize_instances(idle_instances) 522 | specs = { 523 | "master" : {"instance_type" : configuration["master_type"], 524 | "n_required" : configuration["n_masters"]}, 525 | "worker" : {"instance_type" : configuration["worker_type"], 526 | "n_required" : configuration["n_workers"]} 527 | } 528 | machine_assignments = { 529 | "master" : [], 530 | "worker" : [] 531 | } 532 | for role, requirement in sorted(specs.items(), key=lambda x:x[0]): 533 | instance_type_for_role = requirement["instance_type"] 534 | n_instances_needed = requirement["n_required"] 535 | instances_to_assign, rest = instance_type_to_instance_map[instance_type_for_role][:n_instances_needed], instance_type_to_instance_map[instance_type_for_role][n_instances_needed:] 536 | instance_type_to_instance_map[instance_type_for_role] = rest 537 | machine_assignments[role] = instances_to_assign 538 | 539 | # Construct the host strings necessary for running the inception command. 540 | # Note we use private ip addresses to avoid EC2 transfer costs. 
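# As in run_mxnet_grid_search above, the master's private address comes first in
# the host string; the per-machine (batch size, learning rate) pairs for this
# loss-curve run are taken from batch_size_list / learning_rate_list further below.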
541 | worker_host_string = ",".join([x.private_ip_address+":"+str(port) for x in machine_assignments["master"] + machine_assignments["worker"]]) 542 | 543 | # Create a map of command&machine assignments 544 | command_machine_assignments = {} 545 | setup_machine_assignments = {} 546 | 547 | # Construct the master command 548 | command_machine_assignments["master"] = {"instance" : machine_assignments["master"][0], "commands" : list(configuration["master_pre_commands"])} 549 | # setup_machine_assignments["master"] = {"instance" : machine_assignments["master"][0], "commands" : list(configuration["setup_commands"])} 550 | for command_string in configuration["train_commands"]: 551 | command_machine_assignments["master"]["commands"].append(command_string.replace("JOB_NAME", "worker").replace("WORKER_HOSTS", worker_host_string).replace("ROLE_ID", "master")) 552 | print(command_machine_assignments) 553 | 554 | # Construct the worker commands 555 | for worker_id, instance in enumerate(machine_assignments["worker"]): 556 | name = "worker_%d" % worker_id 557 | command_machine_assignments[name] = {"instance" : instance, 558 | "commands" : list(configuration["pre_commands"])} 559 | for command_string in configuration["train_commands"]: 560 | command_machine_assignments[name]["commands"].append(command_string.replace("TASK_ID", "%d" % (worker_id+1)).replace("JOB_NAME", "worker").replace("WORKER_HOSTS", worker_host_string).replace("ROLE_ID", name)) 561 | 562 | print(command_machine_assignments) 563 | 564 | # Run the commands via ssh in parallel 565 | threads = [] 566 | q = Queue.Queue() 567 | 568 | for name, command_and_machine in setup_machine_assignments.items(): 569 | instance = command_and_machine["instance"] 570 | commands = command_and_machine["commands"] 571 | print("-----------------------") 572 | print("Pre Command: %s\n" % " ".join(commands)) 573 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q)) 574 | t.start() 575 | threads.append(t) 576 | 577 | # Wait until commands are all finished 578 | for t in threads: 579 | t.join() 580 | 581 | threads = [] 582 | q = Queue.Queue() 583 | 584 | batch_size_list = [4, 32, 50, 100, 500, 1000] 585 | learning_rate_list = [0.046, 0.05, 0.068, 0.068, 0.048, 0.086] 586 | running_process = 0 587 | for name, command_and_machine in command_machine_assignments.items(): 588 | instance = command_and_machine["instance"] 589 | neo_commands = "python train_cifar10.py --running_mode=training --gpus=0 "\ 590 | "--batch-size={} "\ 591 | "--lr={} "\ 592 | "--model-prefix={}/model_checkpoints/batch_size_{} "\ 593 | "--dir={}/loss_curve > "\ 594 | "{}/loss_curve/running_batch_size_{}.out 2>&1 &".format( 595 | batch_size_list[running_process], 596 | learning_rate_list[running_process], 597 | configuration['nfs_mount_point'], 598 | batch_size_list[running_process], 599 | configuration['nfs_mount_point'], 600 | configuration['nfs_mount_point'], 601 | batch_size_list[running_process]) 602 | 603 | commands = command_and_machine["commands"] 604 | commands.append('mkdir {}/model_checkpoints/'.format(configuration['nfs_mount_point'])) 605 | commands.append('mkdir {}/loss_curve'.format(configuration['nfs_mount_point'])) 606 | commands.append(neo_commands) 607 | 608 | print("-----------------------") 609 | print("Command: %s\n" % " ".join(commands)) 610 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q)) 611 | t.start() 612 | threads.append(t) 613 | running_process += 1 614 | 615 | # Wait until commands are all 
finished
616 | for t in threads:
617 | t.join()
618 |
619 | # Print the output
620 | while not q.empty():
621 | instance, output = q.get()
622 | print(instance.public_ip_address)
623 | print(output)
624 |
625 | # Debug print
626 | instances = []
627 | print("\n--------------------------------------------------\n")
628 | print("Machine assignments:")
629 | print("------------------------")
630 | for name, command_and_machine in command_machine_assignments.items():
631 | instance = command_and_machine["instance"]
632 | instances.append(instance)
633 | commands = command_and_machine["commands"]
634 | ssh_command = "ssh -i %s %s@%s" % (configuration["path_to_keyfile"], configuration["ssh_username"], instance.public_ip_address)
635 | print("%s - %s" % (name, instance.instance_id))
636 | print("To ssh: %s" % ssh_command)
637 | print("------------------------")
638 |
639 | # Print out the list of instance ids (useful for selectively stopping training
640 | # on given instances).
641 | instance_cluster_string = ",".join([x.instance_id for x in instances])
642 | print("\nInstances cluster string: %s" % instance_cluster_string)
643 |
644 | # Bundle the configuration and machine assignments so the caller can save/reuse them
645 | cluster_save = {
646 | "configuration" : configuration,
647 | "name" : configuration["name"],
648 | "command_machine_assignments" : command_machine_assignments,
649 | "cluster_string" : instance_cluster_string
650 | }
651 |
652 | return cluster_save
653 |
654 |
655 |
656 | def get_hosts(argv, port=22):
657 | # Check that the idle instances satisfy the configuration
658 | check_idle_instances_satisfy_configuration()
659 |
660 | # Get idle instances
661 | idle_instances = get_idle_instances()
662 |
663 | # Assign instances for worker/ps/etc
664 | instance_type_to_instance_map = summarize_instances(idle_instances)
665 | specs = {
666 | "master" : {"instance_type" : configuration["master_type"],
667 | "n_required" : configuration["n_masters"]},
668 | "worker" : {"instance_type" : configuration["worker_type"],
669 | "n_required" : configuration["n_workers"]}
670 | }
671 | machine_assignments = {
672 | "master" : [],
673 | "worker" : []
674 | }
675 | for role, requirement in sorted(specs.items(), key=lambda x:x[0]):
676 | instance_type_for_role = requirement["instance_type"]
677 | n_instances_needed = requirement["n_required"]
678 | instances_to_assign, rest = instance_type_to_instance_map[instance_type_for_role][:n_instances_needed], instance_type_to_instance_map[instance_type_for_role][n_instances_needed:]
679 | instance_type_to_instance_map[instance_type_for_role] = rest
680 | machine_assignments[role] = instances_to_assign
681 |
682 | # Construct the host string and write out the host files used by the shell tools.
683 | # Note we use private ip addresses to avoid EC2 transfer costs.
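# The writes below produce three files consumed by the shell tools in ./tools
# (illustrative contents for a hypothetical private IP 172.31.0.10 as the first host):
#   hosts         -> "172.31.0.10<TAB>deeplearning-worker1"  (appended to /etc/hosts by remote_script.sh)
#   hosts_alias   -> "deeplearning-worker1"
#   hosts_address -> "172.31.0.10"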
684 | worker_host_string = ",".join([x.private_ip_address+":"+str(port) for x in machine_assignments["master"] + machine_assignments["worker"]]) 685 | hosts_out = open('hosts', 'w') 686 | print('master ip ', machine_assignments['master'][0].public_ip_address) 687 | count = 0 688 | for instance in machine_assignments["master"] + machine_assignments["worker"]: 689 | count += 1 690 | print('{}\tdeeplearning-worker{}'.format(instance.private_ip_address, count), end='\n', file=hosts_out) 691 | hosts_out.flush() 692 | hosts_out.close() 693 | 694 | hosts_alias_out = open('hosts_alias', 'w') 695 | count = 0 696 | for _ in machine_assignments["master"] + machine_assignments["worker"]: 697 | count += 1 698 | print('deeplearning-worker{}'.format(count), end='\n', file=hosts_alias_out) 699 | hosts_alias_out.flush() 700 | hosts_alias_out.close() 701 | 702 | hosts_alias_out = open('hosts_address', 'w') 703 | count = 0 704 | for instance in machine_assignments["master"] + machine_assignments["worker"]: 705 | count += 1 706 | print('{}'.format(instance.private_ip_address), end='\n', file=hosts_alias_out) 707 | hosts_alias_out.flush() 708 | hosts_alias_out.close() 709 | 710 | # # Create a map of command&machine assignments 711 | # command_machine_assignments = {} 712 | # setup_machine_assignments = {} 713 | # 714 | # # Construct the master command 715 | # command_machine_assignments["master"] = {"instance" : machine_assignments["master"][0], "commands" : list(configuration["master_pre_commands"])} 716 | # for command_string in configuration["train_commands"]: 717 | # command_machine_assignments["master"]["commands"].append(command_string.replace("JOB_NAME", "worker").replace("WORKER_HOSTS", worker_host_string).replace("ROLE_ID", "master")) 718 | # print(command_machine_assignments) 719 | # 720 | # # Construct the worker commands 721 | # for worker_id, instance in enumerate(machine_assignments["worker"]): 722 | # name = "worker_%d" % worker_id 723 | # command_machine_assignments[name] = {"instance" : instance, 724 | # "commands" : list(configuration["pre_commands"])} 725 | # for command_string in configuration["train_commands"]: 726 | # command_machine_assignments[name]["commands"].append(command_string.replace("TASK_ID", "%d" % (worker_id+1)).replace("JOB_NAME", "worker").replace("WORKER_HOSTS", worker_host_string).replace("ROLE_ID", name)) 727 | # 728 | # print(command_machine_assignments) 729 | # 730 | # # Run the commands via ssh in parallel 731 | # threads = [] 732 | # q = Queue.Queue() 733 | # 734 | # for name, command_and_machine in setup_machine_assignments.items(): 735 | # instance = command_and_machine["instance"] 736 | # commands = command_and_machine["commands"] 737 | # print("-----------------------") 738 | # print("Pre Command: %s\n" % " ".join(commands)) 739 | # t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q)) 740 | # t.start() 741 | # threads.append(t) 742 | # 743 | # # Wait until commands are all finished 744 | # for t in threads: 745 | # t.join() 746 | # 747 | # threads = [] 748 | # q = Queue.Queue() 749 | # 750 | # batch_size_list = [4, 32, 50, 100, 500, 1000] 751 | # learning_rate_list = [0.046, 0.05, 0.068, 0.068, 0.048, 0.086] 752 | # running_process = 0 753 | # for name, command_and_machine in command_machine_assignments.items(): 754 | # instance = command_and_machine["instance"] 755 | # neo_commands = "python train_cifar10.py --running_mode=training --gpus=0 "\ 756 | # "--batch-size={} "\ 757 | # "--lr={} "\ 758 | # 
"--model-prefix={}/model_checkpoints/batch_size_{} "\ 759 | # "--dir={}/loss_curve > "\ 760 | # "{}/loss_curve/running_batch_size_{}.out 2>&1 &".format( 761 | # batch_size_list[running_process], 762 | # learning_rate_list[running_process], 763 | # configuration['nfs_mount_point'], 764 | # batch_size_list[running_process], 765 | # configuration['nfs_mount_point'], 766 | # configuration['nfs_mount_point'], 767 | # batch_size_list[running_process]) 768 | # 769 | # commands = command_and_machine["commands"] 770 | # commands.append('mkdir {}/model_checkpoints/'.format(configuration['nfs_mount_point'])) 771 | # commands.append('mkdir {}/loss_curve'.format(configuration['nfs_mount_point'])) 772 | # commands.append(neo_commands) 773 | # 774 | # print("-----------------------") 775 | # print("Command: %s\n" % " ".join(commands)) 776 | # t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q)) 777 | # t.start() 778 | # threads.append(t) 779 | # running_process += 1 780 | # 781 | # # Wait until commands are all finished 782 | # for t in threads: 783 | # t.join() 784 | # 785 | # # Print the output 786 | # while not q.empty(): 787 | # instance, output = q.get() 788 | # print(instance.public_ip_address) 789 | # print(output) 790 | # 791 | # # Debug print 792 | # instances = [] 793 | # print("\n--------------------------------------------------\n") 794 | # print("Machine assignments:") 795 | # print("------------------------") 796 | # for name, command_and_machine in command_machine_assignments.items(): 797 | # instance = command_and_machine["instance"] 798 | # instances.append(instance) 799 | # commands = command_and_machine["commands"] 800 | # ssh_command = "ssh -i %s %s@%s" % (configuration["path_to_keyfile"], configuration["ssh_username"], instance.public_ip_address) 801 | # print("%s - %s" % (name, instance.instance_id)) 802 | # print("To ssh: %s" % ssh_command) 803 | # print("------------------------") 804 | # 805 | # # Print out list of instance ids (which will be useful in selctively stopping inception 806 | # # for given instances. 
807 | # instance_cluster_string = ",".join([x.instance_id for x in instances])
808 | # print("\nInstances cluster string: %s" % instance_cluster_string)
809 | #
810 | # # Print out the id of the configuration file
811 | # cluster_save = {
812 | # "configuration" : configuration,
813 | # "name" : configuration["name"],
814 | # "command_machine_assignments" : command_machine_assignments,
815 | # "cluster_string" : instance_cluster_string
816 | # }
817 | #
818 | # return cluster_save
819 | return
820 |
821 | def kill_python(argv):
822 | if len(argv) != 3:
823 | print("Usage: python pytorch_ec2.py kill_python instance_id1,instance_id2,id3...")
824 | sys.exit(0)
825 | cluster_instance_string = argv[2]
826 | instance_ids_to_shutdown = cluster_instance_string.split(",")
827 |
828 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}])
829 | threads = []
830 | q = Queue.Queue()
831 | for instance in live_instances:
832 | if instance.instance_id in instance_ids_to_shutdown:
833 | commands = ["sudo pkill -9 python"]  # force-kill any Python (training) processes on this instance
834 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q))
835 | t.start()
836 | threads.append(t)
837 | for thread in threads:
838 | thread.join()
839 | summarize_idle_instances(None)
840 |
841 | def kill_all_python(argv):
842 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}])
843 | threads = []
844 | q = Queue.Queue()
845 | for instance in live_instances:
846 | commands = ["sudo pkill -9 python"]
847 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q))
848 | t.start()
849 | threads.append(t)
850 | for thread in threads:
851 | thread.join()
852 | summarize_idle_instances(None)
853 |
854 | def run_command(argv, quiet=False):
855 | if len(argv) != 4:
856 | print("Usage: python pytorch_ec2.py run_command instance_id1,instance_id2,id3... command")
command") 857 | sys.exit(0) 858 | cluster_instance_string = argv[2] 859 | command = argv[3] 860 | instance_ids_to_run_command = cluster_instance_string.split(",") 861 | 862 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}]) 863 | threads = [] 864 | q = Queue.Queue() 865 | for instance in live_instances: 866 | if instance.instance_id in instance_ids_to_run_command: 867 | commands = [command] 868 | t = threading.Thread(target=run_ssh_commands_parallel, args=(instance, commands, q)) 869 | t.start() 870 | threads.append(t) 871 | for thread in threads: 872 | thread.join() 873 | 874 | while not q.empty(): 875 | instance, output = q.get() 876 | if not quiet: 877 | print(instance, output) 878 | 879 | # Setup nfs on all instances 880 | def setup_nfs(): 881 | print("Clearing previous nfs file system...") 882 | live_instances = ec2.instances.filter(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}, {'Name': 'key-name', 'Values': [configuration["key_name"]]}]) 883 | live_instances_string = ",".join([x.instance_id for x in live_instances]) 884 | rm_command = "sudo rm -rf %s" % configuration["nfs_mount_point"] 885 | argv = ["python", "inception_ec2.py", live_instances_string, rm_command] 886 | # argv = ["python", "inception_ec2.py", live_instances_string] 887 | run_command(argv, quiet=True) 888 | 889 | print("Installing nfs on all running instances...") 890 | update_command = "sudo apt-get -y update" 891 | install_nfs_command = "sudo apt-get -y install nfs-common" 892 | create_mount_command = "mkdir %s" % configuration["nfs_mount_point"] 893 | setup_nfs_command = "sudo mount -t nfs4 -o nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2 %s:/ %s" % (configuration["nfs_ip_address"], configuration["nfs_mount_point"]) 894 | reduce_permissions_command = "sudo chmod 777 %s " % configuration["nfs_mount_point"] 895 | command = update_command + " && " + install_nfs_command + " && " + create_mount_command + " && " + setup_nfs_command + " && " + reduce_permissions_command 896 | 897 | # pretty hackish 898 | argv = ["python", "inception_ec2.py", live_instances_string, command] 899 | run_command(argv, quiet=True) 900 | return 901 | 902 | # Launch instances as specified by the configuration. 903 | # We also want a shared filesystem to write model checkpoints. 904 | # For simplicity we will have the user specify the filesystem via the config. 905 | def launch(argv): 906 | method = "spot" 907 | if "method" in configuration: 908 | method = configuration["method"] 909 | launch_instances() 910 | if method == "spot": 911 | wait_until_instance_request_status_fulfilled() 912 | wait_until_running_instances_initialized() 913 | print('setup nfs') 914 | setup_nfs() 915 | 916 | def clean_launch_and_run(argv): 917 | # 1. Kills all instances in region 918 | # 2. Kills all requests in region 919 | # 3. Launches requests 920 | # 5. Waits until launch requests have all been satisfied, 921 | # printing status outputs in the meanwhile 922 | # 4. Checks that configuration has been satisfied 923 | # 5. 
924 | shut_everything_down(None)
925 | launch(None)
926 | return run_mxnet_grid_search(None)
927 |
928 | def help(hmap):
929 | print("Usage: python pytorch_ec2.py [command]")
930 | print("Commands:")
931 | for k,v in hmap.items():
932 | print("%s - %s" % (k,v))
933 |
934 | ###################################
935 | # pytorch_ec2 main starting point #
936 | ###################################
937 |
938 | command_map = {
939 | "launch" : launch,
940 | "clean_launch_and_run" : clean_launch_and_run,
941 | "shutdown" : shut_everything_down,
942 | "run_mxnet_grid_search": run_mxnet_grid_search,
943 | "run_mxnet_loss_curve": run_mxnet_loss_curve,
944 | "get_hosts": get_hosts,
945 | "kill_all_python" : kill_all_python,
946 | "list_idle_instances" : summarize_idle_instances,
947 | "list_running_instances" : summarize_running_instances,
948 | "kill_python" : kill_python,
949 | "run_command" : run_command,
950 | "setup_nfs": setup_nfs,
951 | }
952 | help_map = {
953 | "launch" : "Launch instances (and set up NFS) as specified by the configuration.",
954 | "clean_launch_and_run" : "Shut everything down, launch instances, wait until requests fulfilled, check that configuration is fulfilled, and launch and run the MXNet grid-search job.",
955 | "shutdown" : "Shut everything down by cancelling all instance requests, and terminating all instances.",
956 | "list_idle_instances" : "Lists all idle instances. Idle instances are running instances that are not currently running a training process.",
957 | "list_running_instances" : "Lists all running instances.",
958 | "run_mxnet_grid_search": "Runs the MXNet CIFAR-10 batch-size grid search on the assigned master/worker instances, writing logs to the shared NFS directory.",
959 | "run_mxnet_loss_curve": "Runs MXNet CIFAR-10 training with a fixed list of batch sizes and learning rates to collect loss curves on the shared NFS directory.",
960 | "setup_nfs": "Installs NFS and mounts the shared filesystem on all running instances.",
961 | "kill_all_python" : "Kills Python training processes on ALL instances.",
962 | "kill_python" : "Kills Python training processes on instances indicated by instance id string separated by ',' (no spaces).",
963 | "run_command" : "Runs given command on instances selected by instance id string, separated by ','.",
964 | }
965 |
966 | if len(argv) < 2:
967 | help(help_map)
968 | sys.exit(0)
969 |
970 | command = argv[1]
971 | return command_map[command](argv)
972 |
973 | if __name__ == "__main__":
974 | print(cfg)
975 | mxnet_ec2_run(sys.argv, cfg)
976 |
--------------------------------------------------------------------------------
/tools/remote_script.sh:
--------------------------------------------------------------------------------
1 | KEY_PEM_NAME=HongyiScript.pem
2 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`
3 | #cd ~/ps_pytorch/src/
4 | #bash ../tools/pre_run.sh
5 | #bash data_prepare.sh
6 | #cd ~
7 |
8 | sudo bash -c "cat hosts >> /etc/hosts"
9 | cp config ~/.ssh/
10 |
11 | cd ~/.ssh
12 | eval `ssh-agent -s`
13 | ssh-add ${KEY_PEM_NAME}
14 | ssh-keygen -t rsa -b 4096 -C "hongyiwang.hdu@gmail.com"
15 |
16 | for i in $(seq 2 $DEEPLEARNING_WORKERS_COUNT);
17 | do
18 | scp -i ${KEY_PEM_NAME} id_rsa.pub deeplearning-worker${i}:~/.ssh
19 | #ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'git clone https://github.com/hwang595/ps_pytorch.git; cd ~/.ssh; cat id_rsa.pub >> authorized_keys; bash ~/ps_pytorch/tools/pre_run.sh'
20 | ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'cd ~/.ssh; cat id_rsa.pub >> authorized_keys'
21 | scp -i ${KEY_PEM_NAME} -r /home/ubuntu/ps_real_pytorch deeplearning-worker${i}:~/
22 | #scp -i ${KEY_PEM_NAME} -r ~/ps_pytorch deeplearning-worker${i}:~
23 | echo "Done writing public key to worker: deeplearning-worker${i}"
24 | done
25 |
--------------------------------------------------------------------------------
/tools/update_git_dir.sh:
--------------------------------------------------------------------------------
1 | cd ~
2 | KEY_PEM_NAME=HongyiScript.pem
3 | export DEEPLEARNING_WORKERS_COUNT=`wc -l < hosts`
4 |
5 | sudo bash -c "cat hosts >> /etc/hosts"
6 |
7 | for i in $(seq 2 $DEEPLEARNING_WORKERS_COUNT);
8 | do
9 | ssh -i ${KEY_PEM_NAME} deeplearning-worker${i} 'git config --global user.name hwang595; git config --global user.email hongyiwang@cs.wisc.edu; cd ~/ps_pytorch; git pull'
10 | echo "Done pulling git repo on worker: deeplearning-worker${i}"
11 | done
--------------------------------------------------------------------------------