├── ps-algo.png
├── architecture.jpg
├── install-dependencies.sh
├── README.md
└── public-asgd.py
/ps-algo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xbfu/PyTorch-ParameterServer/HEAD/ps-algo.png
--------------------------------------------------------------------------------
/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xbfu/PyTorch-ParameterServer/HEAD/architecture.jpg
--------------------------------------------------------------------------------
/install-dependencies.sh:
--------------------------------------------------------------------------------
1 | sudo apt-get update
2 | sudo apt-get install -y python3-pip
3 | pip3 install torch==1.9.0
4 | pip3 install Pillow==8.2.0
5 | pip3 install torchvision==0.10.0
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch-ParameterServer
2 | An implementation of the parameter server (PS) framework [1] based on Remote Procedure Call (RPC) in PyTorch [2].
3 |
4 | ## Table of Contents
5 |
6 | - [PS-based Architecture](#ps-based-architecture)
7 | - [Implementation](#implementation)
8 | - [Environments](#environments)
9 | - [Quick Start](#quick-start)
10 | - [Download the code](#download-the-code)
11 | - [Install dependencies](#install-dependencies)
12 | - [Prepare datasets](#prepare-datasets)
13 | - [Train](#train)
14 | - [Performance](#performance)
15 | - [Usage](#usage)
16 | - [References](#references)
17 |
18 | ## PS-based Architecture
19 |
20 | The figure [3] below shows the PS-based architecture. The architecture consists of two logical entities: one or more PSs and multiple workers. The dataset is partitioned among the workers, while the PS maintains the model parameters. During training, each worker pulls the model parameters from the PS, computes gradients on a mini-batch from its data partition, and pushes the gradients to the PS. The PS updates the model parameters with the gradients from the workers according to a synchronization strategy and sends the updated parameters back to the workers. The pseudocode [1] for this architecture is shown below.
21 | ![PS-based architecture](architecture.jpg)
22 | ![Pseudocode of the PS-based architecture](ps-algo.png)
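
For readers who prefer runnable code to pseudocode, the sketch below condenses the same pull/compute/push cycle into a single process without any RPC. It is illustrative only: `ToyPS`, `push_pull`, and the synthetic data are stand-ins invented for this sketch and are not part of this repository.

```python
# Single-process sketch of the PS pull/compute/push cycle (illustrative only).
import torch
import torch.nn as nn


class ToyPS:
    """Holds the parameters and applies gradients pushed by a worker (plain SGD)."""

    def __init__(self, model, lr=0.01):
        self.model = model
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)

    def push_pull(self, grads):
        # "push": apply the worker's gradients to the PS copy of the parameters
        for p, g in zip(self.model.parameters(), grads):
            p.grad = g
        self.optimizer.step()
        self.optimizer.zero_grad()
        # "pull": hand the updated parameters back to the worker
        return self.model.state_dict()


def worker_step(ps, model, data, target):
    """One worker iteration: compute gradients on a mini-batch, push them, pull new parameters."""
    loss = nn.CrossEntropyLoss()(model(data), target)
    loss.backward()
    grads = [p.grad.clone() for p in model.parameters()]
    model.load_state_dict(ps.push_pull(grads))  # receive updated parameters
    model.zero_grad()
    return loss.item()


if __name__ == "__main__":
    ps = ToyPS(nn.Linear(8, 2))        # stand-in for ResNet50 on the PS
    worker_model = nn.Linear(8, 2)     # the worker's local copy
    worker_model.load_state_dict(ps.model.state_dict())
    for step in range(3):
        data, target = torch.randn(4, 8), torch.randint(0, 2, (4,))
        print(f"step {step}: loss = {worker_step(ps, worker_model, data, target):.4f}")
```

In the real implementation the worker and the PS run in different processes, so the push/pull step becomes a remote call (see `public-asgd.py`).
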
23 | ## Implementation
24 | This code is based on `torch.distributed.rpc` [4]. It is used to train ResNet50 [5] on the Imagenette dataset [6], a subset of ImageNet [7], with one PS (rank 0) and four workers (ranks 1, 2, 3, 4); other model architectures can be selected with the `--model` flag.
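
The essential `torch.distributed.rpc` pattern is: every process calls `rpc.init_rpc`, rank 0 wraps the `ParameterServer` object in an `RRef`, and the workers reach it through `rpc_sync`/`rpc_async`. The self-contained sketch below shows that pattern with two local processes and a toy object standing in for the PS; `Counter`, `worker_loop`, and port `29601` are made up for the example, and the real code additionally uses `@rpc.functions.async_execution` so the PS can reply through a `torch.futures.Future`.

```python
# Minimal sketch of the RPC pattern used in public-asgd.py: rank 0 owns an object,
# the worker reaches it through an RRef and rpc_sync (toy example, not the repo code).
import os
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


class Counter:
    """Stands in for the ParameterServer: its state lives on rank 0 only."""

    def __init__(self):
        self.value = 0

    @staticmethod
    def add(counter_rref, amount):
        self = counter_rref.local_value()  # executes on the owner (rank 0)
        self.value += amount
        return self.value


def worker_loop(counter_rref):
    for step in range(3):
        # analogous to pushing gradients and pulling parameters, but with a counter
        value = rpc.rpc_sync(to=counter_rref.owner(), func=Counter.add, args=(counter_rref, 1))
        print(f"worker1 step {step}: counter = {value}")


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29601"  # arbitrary free port for this toy example
    name = "PS0" if rank == 0 else f"worker{rank}"
    rpc.init_rpc(name, rank=rank, world_size=world_size)

    if rank == 0:
        counter_rref = rpc.RRef(Counter())
        # ask the worker to run its loop against the object owned by rank 0
        rpc.rpc_async(to="worker1", func=worker_loop, args=(counter_rref,)).wait()

    rpc.shutdown()  # blocks until all outstanding RPC work is finished


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)
```
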
25 | ## Environments
26 | The code was developed under the following configuration:
27 | - Server: an AWS EC2 g3.16xlarge instance with 4 NVIDIA Tesla M60 GPUs
28 | - System: Ubuntu 18.04
29 | - Software: python==3.6.9, torch==1.9.0, torchvision==0.10.0
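
To quickly verify that the installed versions match the configuration above, a check along these lines can be used (illustrative snippet, not part of the repository):

```python
# Print the library versions and visible GPUs (illustrative environment check).
import sys
import torch
import torchvision

print("python     :", sys.version.split()[0])
print("torch      :", torch.__version__)
print("torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available(), "| GPUs:", torch.cuda.device_count())
```
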
30 | ## Quick Start
31 | ### Download the code
32 | ```bash
33 | git clone https://github.com/xbfu/PyTorch-ParameterServer.git
34 | ```
35 | ### Install dependencies
36 | ```bash
37 | cd PyTorch-ParameterServer
38 | sudo sh install-dependencies.sh
39 | ```
40 | ### Prepare datasets
41 | ```bash
42 | cd PyTorch-ParameterServer
43 | wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
44 | tar -zxf imagenette2.tgz
45 | ```
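
Extracting the archive creates `imagenette2/train` and `imagenette2/val`. The snippet below is an illustrative sanity check that `torchvision` can read them; note that `public-asgd.py` defaults to `--data_dir=./imagenette2/val`, so pass `--data_dir=./imagenette2/train` to train on the full training split.

```python
# Verify that both dataset splits load as torchvision ImageFolder datasets (illustrative).
from torchvision import datasets

for split in ("train", "val"):
    ds = datasets.ImageFolder(root=f"./imagenette2/{split}")
    print(f"{split}: {len(ds)} images across {len(ds.classes)} classes")
```
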
46 | ### Train
47 | For the PS:
48 | ```bash
49 | python3 public-asgd.py --rank=0 --world_size=5
50 | ```
51 | For the workers:
52 | ```bash
53 | python3 public-asgd.py --rank=r --world_size=5
54 | ```
55 | `r=1,2,3,4` is the rank of each worker. `--world_size` is the total number of processes (one PS plus the number of workers), so it is 5 here. The default model is `resnet18`; pass `--model=resnet50` to train ResNet50.
56 |
57 | ### Performance
58 | Sync Mode | Training Time (seconds)
59 | :-: | :-:
60 | Single | 858
61 | Sync | 533
62 | Async | 268
63 | ## Usage
64 | ### On one machine with multiple GPUs
65 | For the PS:
66 | ```bash
67 | python3 public-asgd.py --rank=0 --world_size=5
68 | ```
69 | For the workers:
70 | ```bash
71 | python3 public-asgd.py --rank=r --world_size=5
72 | ```
73 | `r=1,2,3,4` is the rank of each worker; worker `r` runs on GPU `r-1`.
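
As a convenience, a tiny launcher like the one below (hypothetical, not part of this repository) starts the PS and all four workers on one machine and waits for them to finish:

```python
# Hypothetical one-machine launcher: spawn the PS (rank 0) and four workers (ranks 1-4).
import subprocess
import sys

WORLD_SIZE = 5  # 1 PS + 4 workers
procs = [
    subprocess.Popen([sys.executable, "public-asgd.py", f"--rank={r}", f"--world_size={WORLD_SIZE}"])
    for r in range(WORLD_SIZE)
]
for p in procs:
    p.wait()
```
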
74 |
75 | ### On multiple machines
76 | For the PS:
77 | ```bash
78 | python3 public-asgd.py --rank=0 --world_size=5 --master_addr=12.34.56.78
79 | ```
80 | For the workers:
81 | ```bash
82 | python3 public-asgd.py --rank=r --world_size=5 --master_addr=12.34.56.78
83 | ```
84 | `r=1,2,3,4` is the rank of each worker, and `12.34.56.78` is the IP address of the machine running the PS. Every process must use the same `--master_addr` and `--master_port`.
85 |
86 | ## References
87 | [1]. Li, M., Andersen, D. G., Park, J. W., et al. (2014). [Scaling distributed machine learning with the parameter server](https://www.usenix.org/system/files/conference/osdi14/osdi14-paper-li_mu.pdf). In 11th USENIX Symposium on Operating Systems Design and Implementation (OSDI 14) (pp. 583-598).
88 | [2]. PyTorch. https://pytorch.org/.
89 | [3]. Sergeev, A., & Del Balso, M. (2018). [Horovod: fast and easy distributed deep learning in TensorFlow](https://arxiv.org/abs/1802.05799). arXiv preprint arXiv:1802.05799.
90 | [4]. Distributed RPC Framework. https://pytorch.org/docs/1.9.0/rpc.html.
91 | [5]. He, K., Zhang, X., Ren, S., & Sun, J. (2016). [Deep residual learning for image recognition](https://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf). In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 770-778).
92 | [6]. Imagenette. https://github.com/fastai/imagenette.
93 | [7]. ImageNet. https://image-net.org/.
94 |
--------------------------------------------------------------------------------
/public-asgd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import threading
4 | import time
5 |
6 | import torch
7 | from torch import optim
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 | import torch.distributed.rpc as rpc
11 | from torchvision import transforms, datasets, models
12 |
13 | model_dict = {'resnet18': models.resnet18, 'resnet50': models.resnet50, 'vgg16': models.vgg16, 'alexnet': models.alexnet,
14 | 'googlenet': models.googlenet, 'inception': models.inception_v3,
15 | 'densenet121': models.densenet121, 'mobilenet': models.mobilenet_v2}
16 |
17 |
18 | class ParameterServer(object):
19 |     """
20 | The parameter server (PS) updates model parameters with gradients from the workers
21 | and sends the updated parameters back to the workers.
22 | """
23 | def __init__(self, model, num_workers, lr):
24 | self.lock = threading.Lock()
25 | self.future_model = torch.futures.Future()
26 | self.num_workers = num_workers
27 | # initialize model parameters
28 | assert model in model_dict.keys(), \
29 | f'model {model} is not in the model list: {list(model_dict.keys())}'
30 | self.model = model_dict[model](num_classes=10)
31 | # zero gradients
32 | for p in self.model.parameters():
33 | p.grad = torch.zeros_like(p)
34 | self.optimizer = optim.SGD(self.model.parameters(), lr=lr, momentum=0.9)
35 |
36 | def get_model(self):
37 | return self.model
38 |
39 | @staticmethod
40 |     @rpc.functions.async_execution  # this RPC returns a torch.futures.Future; the caller gets the result once set_result() is called
41 | def update_and_fetch_model(ps_rref, grads, worker_rank):
42 | self = ps_rref.local_value()
43 | with self.lock:
44 | print(f'PS updates parameters based on gradients from worker{worker_rank}')
45 | # update model parameters
46 | for p, g in zip(self.model.parameters(), grads):
47 | p.grad = g
48 | self.optimizer.step()
49 | self.optimizer.zero_grad()
50 |
51 | fut = self.future_model
52 |
53 | fut.set_result(self.model)
54 | self.future_model = torch.futures.Future()
55 |
56 | return fut
57 |
58 |
59 | def run_worker(ps_rref, rank, data_dir, batch_size, num_epochs):
60 | """
61 | A worker pulls model parameters from the PS, computes gradients on a mini-batch
62 | from its data partition, and pushes the gradients to the PS.
63 | """
64 |
65 | # prepare dataset
66 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
67 | transform = transforms.Compose(
68 | [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
69 | train_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
70 |     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
71 |
72 |     # map worker rank r (1..num_workers) to GPU r-1; fall back to CPU if CUDA is unavailable
73 |     device_id = rank - 1
74 |     # device_id = 0  # uncomment to pin every worker to GPU 0 (e.g., on a single-GPU machine)
75 |     device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu")
76 | criterion = nn.CrossEntropyLoss()
77 |
78 | # get initial model from the PS
79 | m = ps_rref.rpc_sync().get_model().to(device)
80 |
81 | print(f'worker{rank} starts training')
82 | tt0 = time.time()
83 |
84 | for i in range(num_epochs):
85 | for batch_idx, (data, target) in enumerate(train_loader):
86 | data, target = data.to(device), target.to(device)
87 | output = m(data)
88 | loss = criterion(output, target)
89 | loss.backward()
90 |
91 | print("worker{:d} | Epoch:{:3d} | Batch: {:3d} | Loss: {:6.2f}"
92 | .format(rank, (i + 1), (batch_idx + 1), loss.item()))
93 |
94 | # send gradients to the PS and fetch updated model parameters
95 | m = rpc.rpc_sync(to=ps_rref.owner(),
96 | func=ParameterServer.update_and_fetch_model,
97 | args=(ps_rref, [p.grad for p in m.cpu().parameters()], rank)
98 | ).to(device)
99 |
100 | tt1 = time.time()
101 |
102 | print("Time: {:.2f} seconds".format((tt1 - tt0)))
103 |
104 |
105 | def main():
106 | parser = argparse.ArgumentParser(description="Train models on Imagenette under ASGD")
107 |     parser.add_argument("--model", type=str, default="resnet18", help="Model architecture to train.")
108 | parser.add_argument("--rank", type=int, default=1, help="Global rank of this process.")
109 |     parser.add_argument("--world_size", type=int, default=3, help="Total number of processes (1 PS + the number of workers).")
110 | parser.add_argument("--data_dir", type=str, default="./imagenette2/val", help="The location of dataset.")
111 | parser.add_argument("--master_addr", type=str, default="localhost", help="Address of master.")
112 | parser.add_argument("--master_port", type=str, default="29600", help="Port that master is listening on.")
113 | parser.add_argument("--batch_size", type=int, default=64, help="Batch size of each worker during training.")
114 | parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
115 | parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
116 |
117 | args = parser.parse_args()
118 |
119 | os.environ['MASTER_ADDR'] = args.master_addr
120 | os.environ['MASTER_PORT'] = args.master_port
121 |
122 |     options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout=0)  # rpc_timeout=0 means no RPC timeout
123 |
124 | if args.rank == 0:
125 | """
126 | initialize PS and run workers
127 | """
128 | print(f"PS{args.rank} initializing")
129 | rpc.init_rpc(f"PS{args.rank}", rank=args.rank, world_size=args.world_size, rpc_backend_options=options)
130 | print(f"PS{args.rank} initialized")
131 |
132 | ps_rref = rpc.RRef(ParameterServer(args.model, args.world_size, args.lr))
133 |
134 | futs = []
135 | for r in range(1, args.world_size):
136 | worker = f'worker{r}'
137 | futs.append(rpc.rpc_async(to=worker,
138 | func=run_worker,
139 | args=(ps_rref, r, args.data_dir, args.batch_size, args.num_epochs)))
140 |
141 | torch.futures.wait_all(futs)
142 |         print("Finished training")
143 |
144 | else:
145 | """
146 | initialize workers
147 | """
148 | print(f"worker{args.rank} initializing")
149 | rpc.init_rpc(f"worker{args.rank}", rank=args.rank, world_size=args.world_size, rpc_backend_options=options)
150 | print(f"worker{args.rank} initialized")
151 |
152 | rpc.shutdown()
153 |
154 |
155 | if __name__ == "__main__":
156 | main()
157 |
--------------------------------------------------------------------------------